diff options
328 files changed, 10530 insertions, 5516 deletions
diff --git a/.buildkite/hooks/post-command b/.buildkite/hooks/post-command index 8af1369a6..c4c6fc90c 100644 --- a/.buildkite/hooks/post-command +++ b/.buildkite/hooks/post-command @@ -51,6 +51,9 @@ if test "${BUILDKITE_COMMAND_EXIT_STATUS}" -ne "0"; then # Attempt to clear the cache and shut down. make clean || echo "make clean failed with code $?" make bazel-shutdown || echo "make bazel-shutdown failed with code $?" + # Attempt to clear any Go cache. + sudo rm -rf "${HOME}/.cache/go-build" + sudo rm -rf "${HOME}/go" fi # Kill any running containers (clear state). @@ -151,10 +151,6 @@ nogo: ## Surfaces all nogo findings. @$(call run,//tools/github $(foreach dir,$(BUILD_ROOTS),-path=$(CURDIR)/$(dir)) -dry-run nogo) .PHONY: nogo -gazelle: ## Runs gazelle to update WORKSPACE. - @$(call run,//:gazelle update-repos -from_file=go.mod -prune) -.PHONY: gazelle - ## ## Canonical build and test targets. ## @@ -8,17 +8,17 @@ load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository") http_file( name = "google_root_pem", urls = [ - "https://pki.goog/roots.pem" + "https://pki.goog/roots.pem", ], ) # Bazel/starlark utilities. http_archive( name = "bazel_skylib", - sha256 = "97e70364e9249702246c0e9444bccdc4b847bed1eb03c5a3ece4f83dfe6abc44", + sha256 = "1c531376ac7e5a180e0237938a2536de0c54d93f5c278634818e0efc952dd56c", urls = [ - "https://mirror.bazel.build/github.com/bazelbuild/bazel-skylib/releases/download/1.0.2/bazel-skylib-1.0.2.tar.gz", - "https://github.com/bazelbuild/bazel-skylib/releases/download/1.0.2/bazel-skylib-1.0.2.tar.gz", + "https://github.com/bazelbuild/bazel-skylib/releases/download/1.0.3/bazel-skylib-1.0.3.tar.gz", + "https://mirror.bazel.build/github.com/bazelbuild/bazel-skylib/releases/download/1.0.3/bazel-skylib-1.0.3.tar.gz", ], ) @@ -36,14 +36,16 @@ http_archive( name = "io_bazel_rules_go", patch_args = ["-p1"], patches = [ + # Ensure we don't destroy the facts visibility. 
+ "//tools:rules_go_visibility.patch", # Newer versions of the rules_go rules will automatically strip test # binaries of symbols, which we don't want. - "//tools:rules_go.patch", + "//tools:rules_go_symbols.patch", ], - sha256 = "8e9434015ff8f3d6962cb8f016230ea7acc1ac402b760a8d66ff54dc11673ca6", + sha256 = "7904dbecbaffd068651916dce77ff3437679f9d20e1a7956bff43826e7645fcc", urls = [ - "https://mirror.bazel.build/github.com/bazelbuild/rules_go/releases/download/v0.24.10/rules_go-v0.24.10.tar.gz", - "https://github.com/bazelbuild/rules_go/releases/download/v0.24.10/rules_go-v0.24.10.tar.gz", + "https://mirror.bazel.build/github.com/bazelbuild/rules_go/releases/download/v0.25.1/rules_go-v0.25.1.tar.gz", + "https://github.com/bazelbuild/rules_go/releases/download/v0.25.1/rules_go-v0.25.1.tar.gz", ], ) @@ -51,14 +53,18 @@ http_archive( name = "bazel_gazelle", patch_args = ["-p1"], patches = [ + # Fix permissions for facts for go_library, not just tool library. + # This is actually a no-op with the hacky patch above, but should + # slightly future proof this mechanism. + "//tools:bazel_gazelle_generate.patch", # False positive output complaining about Go logrus versions spam the # logs. Strip this message in this case. Does not affect control flow. 
- "//tools:bazel_gazelle.patch", + "//tools:bazel_gazelle_noise.patch", ], - sha256 = "b85f48fa105c4403326e9525ad2b2cc437babaa6e15a3fc0b1dbab0ab064bc7c", + sha256 = "222e49f034ca7a1d1231422cdb67066b885819885c356673cb1f72f748a3c9d4", urls = [ - "https://mirror.bazel.build/github.com/bazelbuild/bazel-gazelle/releases/download/v0.22.2/bazel-gazelle-v0.22.2.tar.gz", - "https://github.com/bazelbuild/bazel-gazelle/releases/download/v0.22.2/bazel-gazelle-v0.22.2.tar.gz", + "https://mirror.bazel.build/github.com/bazelbuild/bazel-gazelle/releases/download/v0.22.3/bazel-gazelle-v0.22.3.tar.gz", + "https://github.com/bazelbuild/bazel-gazelle/releases/download/v0.22.3/bazel-gazelle-v0.22.3.tar.gz", ], ) @@ -66,21 +72,20 @@ load("@io_bazel_rules_go//go:deps.bzl", "go_register_toolchains", "go_rules_depe go_rules_dependencies() -go_register_toolchains(go_version = "1.15.2") +go_register_toolchains(go_version = "1.15.7") load("@bazel_gazelle//:deps.bzl", "gazelle_dependencies", "go_repository") gazelle_dependencies() -# The com_google_protobuf repository below would trigger downloading a older -# version of org_golang_x_sys. If putting this repository statment in a place -# after that of the com_google_protobuf, this statement will not work as -# expected to download a new version of org_golang_x_sys. +# Some repository below has a transitive dependency on this repository. This +# declaration must precede any later declaration that transitively depends on +# an older version, since only the first declaration is considered. go_repository( name = "org_golang_x_sys", importpath = "golang.org/x/sys", - sum = "h1:xhmwyvizuTgC2qz7ZlMluP20uW+C3Rm0FD/WLDX8884=", - version = "v0.0.0-20200323222414-85ca7c5b95cd", + sum = "h1:myAQVi0cGEoqQVR5POX+8RR2mrocKqNN1hmeMqhX27k=", + version = "v0.0.0-20210119212857-b64e53b001e4", ) # Load C++ rules. @@ -111,16 +116,72 @@ cc_crosstool(name = "crosstool") # Load protobuf dependencies. 
http_archive( name = "rules_proto", - sha256 = "602e7161d9195e50246177e7c55b2f39950a9cf7366f74ed5f22fd45750cd208", - strip_prefix = "rules_proto-97d8af4dc474595af3900dd85cb3a29ad28cc313", + sha256 = "2a20fd8af3cad3fbab9fd3aec4a137621e0c31f858af213a7ae0f997723fc4a9", + strip_prefix = "rules_proto-a0761ed101b939e19d83b2da5f59034bffc19c12", urls = [ - "https://mirror.bazel.build/github.com/bazelbuild/rules_proto/archive/97d8af4dc474595af3900dd85cb3a29ad28cc313.tar.gz", - "https://github.com/bazelbuild/rules_proto/archive/97d8af4dc474595af3900dd85cb3a29ad28cc313.tar.gz", + "https://mirror.bazel.build/github.com/bazelbuild/rules_proto/archive/a0761ed101b939e19d83b2da5f59034bffc19c12.tar.gz", + "https://github.com/bazelbuild/rules_proto/archive/a0761ed101b939e19d83b2da5f59034bffc19c12.tar.gz", ], ) load("@rules_proto//proto:repositories.bzl", "rules_proto_dependencies", "rules_proto_toolchains") +go_repository( + name = "com_github_go_gl_glfw", + importpath = "github.com/go-gl/glfw", + sum = "h1:QbL/5oDUmRBzO9/Z7Seo6zf912W/a6Sr4Eu0G/3Jho0=", + version = "v0.0.0-20190409004039-e6da0acd62b1", +) + +go_repository( + name = "com_github_google_go_github_v32", + importpath = "github.com/google/go-github/v32", + sum = "h1:GWkQOdXqviCPx7Q7Fj+KyPoGm4SwHRh8rheoPhd27II=", + version = "v32.1.0", +) + +go_repository( + name = "com_github_google_martian_v3", + importpath = "github.com/google/martian/v3", + sum = "h1:wCKgOCHuUEVfsaQLpPSJb7VdYCdTVZQAuOdYm1yc/60=", + version = "v3.1.0", +) + +go_repository( + name = "io_rsc_quote_v3", + importpath = "rsc.io/quote/v3", + sum = "h1:9JKUTTIUgS6kzR9mK1YuGKv6Nl+DijDNIc0ghT58FaY=", + version = "v3.1.0", +) + +go_repository( + name = "io_rsc_sampler", + importpath = "rsc.io/sampler", + sum = "h1:7uVkIFmeBqHfdjD+gZwtXXI+RODJ2Wc4O7MPEh/QiW4=", + version = "v1.3.0", +) + +go_repository( + name = "org_golang_x_term", + importpath = "golang.org/x/term", + sum = "h1:v+OssWQX+hTHEmOBgwxdZxK4zHq3yOs8F9J7mk0PY8E=", + version = 
"v0.0.0-20201126162022-7de9c90e9dd1", +) + +go_repository( + name = "com_github_hashicorp_errwrap", + importpath = "github.com/hashicorp/errwrap", + sum = "h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA=", + version = "v1.0.0", +) + +go_repository( + name = "com_github_hashicorp_go_multierror", + importpath = "github.com/hashicorp/go-multierror", + sum = "h1:B9UzwGQJehnUY1yNrnwREHc3fGbC2xefo8g4TbElacI=", + version = "v1.1.0", +) + rules_proto_dependencies() rules_proto_toolchains() @@ -129,11 +190,11 @@ rules_proto_toolchains() # See releases at https://releases.bazel.build/bazel-toolchains.html http_archive( name = "bazel_toolchains", - sha256 = "144290c4166bd67e76a54f96cd504ed86416ca3ca82030282760f0823c10be48", - strip_prefix = "bazel-toolchains-3.1.1", + sha256 = "1adf5db506a7e3c465a26988514cfc3971af6d5b3c2218925cd6e71ee443fc3f", + strip_prefix = "bazel-toolchains-4.0.0", urls = [ - "https://github.com/bazelbuild/bazel-toolchains/releases/download/3.1.1/bazel-toolchains-3.1.1.tar.gz", - "https://mirror.bazel.build/github.com/bazelbuild/bazel-toolchains/releases/download/3.1.1/bazel-toolchains-3.1.1.tar.gz", + "https://github.com/bazelbuild/bazel-toolchains/releases/download/4.0.0/bazel-toolchains-4.0.0.tar.gz", + "https://mirror.bazel.build/github.com/bazelbuild/bazel-toolchains/releases/download/4.0.0/bazel-toolchains-4.0.0.tar.gz", ], ) @@ -223,8 +284,8 @@ http_file( go_repository( name = "com_github_sirupsen_logrus", importpath = "github.com/sirupsen/logrus", - sum = "h1:UBcNElsrwanuuMsnGSlYmtmgbb23qDR5dG+6X6Oo89I=", - version = "v1.6.0", + sum = "h1:ShrD1U9pZB12TX0cVy0DtePoCH97K8EtX+mg7ZARUtM=", + version = "v1.7.0", ) go_repository( @@ -252,8 +313,8 @@ go_repository( go_repository( name = "com_github_golang_mock", importpath = "github.com/golang/mock", - sum = "h1:qGJ6qTW+x6xX/my+8YUVl4WNpX9B7+/l2tRsHGZ7f2s=", - version = "v1.3.1", + sum = "h1:l75CXGRSwbaYNpl/Z2X1XIIAMSCquvXgpVZDhwEIJsc=", + version = "v1.4.4", ) go_repository( @@ -266,8 +327,8 @@ 
go_repository( go_repository( name = "com_github_google_uuid", importpath = "github.com/google/uuid", - sum = "h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY=", - version = "v1.1.1", + sum = "h1:EVhdT+1Kseyi1/pUmXKaFxYsDNy9RQYkMWRH68J/W7Y=", + version = "v1.1.2", ) go_repository( @@ -316,8 +377,8 @@ go_repository( name = "org_golang_google_grpc", build_file_proto_mode = "disable", importpath = "google.golang.org/grpc", - sum = "h1:2pJjwYOdkZ9HlN4sWRYBg9ttH5bCOlsueaM+b/oYjwo=", - version = "v1.29.0", + sum = "h1:cb+I9RwgcErlwAuOVnGhJ2d3YrcdwGXw+RPArsTWot4=", + version = "v1.36.0-dev.0.20210122012134-2c42474aca0c", ) go_repository( @@ -337,29 +398,29 @@ go_repository( go_repository( name = "org_golang_x_mod", importpath = "golang.org/x/mod", - sum = "h1:RM4zey1++hCTbCVQfnWeKs9/IEsaBLA8vTkd0WVtmH4=", - version = "v0.3.0", + sum = "h1:8pl+sMODzuvGJkmj2W4kZihvVb5mKm8pB/X44PIQHv8=", + version = "v0.4.0", ) go_repository( name = "org_golang_x_net", importpath = "golang.org/x/net", - sum = "h1:VvcQYSHwXgi7W+TpUR6A9g6Up98WAHf3f/ulnJ62IyA=", - version = "v0.0.0-20200822124328-c89045814202", + sum = "h1:iFwSg7t5GZmB/Q5TjiEAsdoLDrdJRC1RiF2WhuV29Qw=", + version = "v0.0.0-20201224014010-6772e930b67b", ) go_repository( name = "org_golang_x_sync", importpath = "golang.org/x/sync", - sum = "h1:qwRHBd0NqMbJxfbotnDhm2ByMI1Shq4Y6oRJo21SGJA=", - version = "v0.0.0-20200625203802-6e8e738ad208", + sum = "h1:SQFwaSi55rU7vdNs9Yr0Z324VNlrF+0wMqRXT4St8ck=", + version = "v0.0.0-20201020160332-67f06af15bc9", ) go_repository( name = "org_golang_x_text", importpath = "golang.org/x/text", - sum = "h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs=", - version = "v0.3.2", + sum = "h1:0YWbFKbhXG/wIiuHDSKpS0Iy7FSA+u45VtBMfQcFTTc=", + version = "v0.3.4", ) go_repository( @@ -372,8 +433,8 @@ go_repository( go_repository( name = "org_golang_x_tools", importpath = "golang.org/x/tools", - sum = "h1:K+nJoPcImWk+ZGPHOKkDocKcQPACCz8usiCiVQYfXsk=", - version = "v0.0.0-20201021000207-d49c4edd7d96", + sum = 
"h1:po9/4sTYwZU9lPhi1tOrb4hCv3qrhiQ77LZfGa2OjwY=", + version = "v0.1.0", ) go_repository( @@ -393,15 +454,15 @@ go_repository( go_repository( name = "com_github_golang_protobuf", importpath = "github.com/golang/protobuf", - sum = "h1:ZFgWrT+bLgsYPirOnRfKLYJLvssAegOj/hgyMFdJZe0=", - version = "v1.4.1", + sum = "h1:JjCZWpVbqXDqFVmTfYWEVTMIYrL/NPdPSCHPJ0T/raM=", + version = "v1.4.3", ) go_repository( name = "org_golang_x_oauth2", importpath = "golang.org/x/oauth2", - sum = "h1:TzXSXBo42m9gQenoE3b9BGiEpg5IG2JkU5FkPIawgtw=", - version = "v0.0.0-20200107190931-bf48bf16ab8d", + sum = "h1:Lm4OryKCca1vehdsWogr9N4t7NfZxLbJoc/H0w4K4S4=", + version = "v0.0.0-20201208152858-08078c50e5b5", ) go_repository( @@ -470,8 +531,8 @@ go_repository( go_repository( name = "com_github_stretchr_testify", importpath = "github.com/stretchr/testify", - sum = "h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=", - version = "v1.4.0", + sum = "h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4=", + version = "v1.5.1", ) go_repository( @@ -484,8 +545,8 @@ go_repository( go_repository( name = "com_github_microsoft_go_winio", importpath = "github.com/Microsoft/go-winio", - sum = "h1:9pygWVFqbY9lPxM0peffumuVDyMuIMzNLyO9uFjJuQo=", - version = "v0.4.15-0.20200908182639-5b44b70ab3ab", + sum = "h1:FtSW/jqD+l4ba5iPBj9CODVtgfYAD8w2wS923g/cFDk=", + version = "v0.4.16", ) go_repository( @@ -512,22 +573,22 @@ go_repository( go_repository( name = "com_google_cloud_go", importpath = "cloud.google.com/go", - sum = "h1:Fvo/6MiAbwmQpsq5YFRo8O6TC40m9MK4Xh/oN07rIlo=", - version = "v0.52.1-0.20200122224058-0482b626c726", + sum = "h1:XgtDnVJRCPEUG21gjFiRPz4zI1Mjg16R+NYQjfmU4XY=", + version = "v0.75.0", ) go_repository( name = "io_opencensus_go", importpath = "go.opencensus.io", - sum = "h1:75k/FF0Q2YM8QYo07VPddOLBslDt1MZOdEslOHvmzAs=", - version = "v0.22.2", + sum = "h1:dntmOdLpSpHlVqbW5Eay97DelsZHe+55D+xC6i0dDS0=", + version = "v0.22.5", ) go_repository( name = "co_honnef_go_tools", importpath = "honnef.co/go/tools", - sum 
= "h1:W18jzjh8mfPez+AwGLxmOImucz/IFjpNlrKVnaj2YVc=", - version = "v0.0.1-2020.1.6", + sum = "h1:EVDuO03OCZwpV2t/tLLxPmPiomagMoBOgfPt0FM+4IY=", + version = "v0.1.1", ) go_repository( @@ -554,8 +615,8 @@ go_repository( go_repository( name = "com_github_cncf_udpa_go", importpath = "github.com/cncf/udpa/go", - sum = "h1:WBZRG4aNOuI15bLRrCgN8fCq8E5Xuty6jGbmSNEvSsU=", - version = "v0.0.0-20191209042840-269d4d468f6f", + sum = "h1:cqQfy1jclcSy/FwLjemeg3SR1yaINm74aQyupQ0Bl8M=", + version = "v0.0.0-20201120205902-5459f2c99403", ) go_repository( @@ -569,15 +630,15 @@ go_repository( go_repository( name = "com_github_containerd_console", importpath = "github.com/containerd/console", - sum = "h1:GdiIYd8ZDOrT++e1NjhSD4rGt9zaJukHm4rt5F4mRQc=", - version = "v0.0.0-20191206165004-02ecf6a7291e", + sum = "h1:u7SFAJyRqWcG6ogaMAx3KjSTy1e3hT9QxqX7Jco7dRc=", + version = "v1.0.1", ) go_repository( name = "com_github_containerd_continuity", importpath = "github.com/containerd/continuity", - sum = "h1:jEIoR0aA5GogXZ8pP3DUzE+zrhaF6/1rYZy+7KkYEWM=", - version = "v0.0.0-20200928162600-f2cc35102c2a", + sum = "h1:6ejg6Lkk8dskcM7wQ28gONkukbQkM4qpj4RnYbpFzrI=", + version = "v0.0.0-20201208142359-180525291bb7", ) go_repository( @@ -597,15 +658,8 @@ go_repository( go_repository( name = "com_github_containerd_ttrpc", importpath = "github.com/containerd/ttrpc", - sum = "h1:+jgiLE5QylzgADj0Yldb4id1NQNRrDOROj7KDvY9PEc=", - version = "v0.0.0-20200121165050-0be804eadb15", -) - -go_repository( - name = "com_github_coreos_go_systemd", - importpath = "github.com/coreos/go-systemd", - sum = "h1:iW4rZ826su+pqaw19uhpSCzhj44qo35pNgKFGqzDKkU=", - version = "v0.0.0-20191104093116-d3cd4ed1dbcf", + sum = "h1:2/O3oTZN36q2xRolk0a2WWGgh7/Vf/liElg5hFYLX9U=", + version = "v1.0.2", ) go_repository( @@ -625,8 +679,8 @@ go_repository( go_repository( name = "com_github_envoyproxy_go_control_plane", importpath = "github.com/envoyproxy/go-control-plane", - sum = "h1:rEvIZUSZ3fx39WIi3JkQqQBitGwpELBIYWeBVh6wn+E=", - version = 
"v0.9.4", + sum = "h1:EmNYJhPYy0pOFjCx2PrgtaBXmee0iUX9hLlxE1xHOJE=", + version = "v0.9.9-0.20201210154907-fd9021fe5dad", ) go_repository( @@ -637,13 +691,6 @@ go_repository( ) go_repository( - name = "com_github_godbus_dbus", - importpath = "github.com/godbus/dbus", - sum = "h1:BWhy2j3IXJhjCbC68FptL43tDKIq8FladmaTs3Xs7Z8=", - version = "v0.0.0-20190422162347-ade71ed3457e", -) - -go_repository( name = "com_github_gogo_googleapis", importpath = "github.com/gogo/googleapis", sum = "h1:zgVt4UpGxcqVOw97aRGxT4svlcmdK35fynLNctY32zI=", @@ -667,15 +714,8 @@ go_repository( go_repository( name = "com_github_google_go_cmp", importpath = "github.com/google/go-cmp", - sum = "h1:pJfrTSHC+QpCQplFZqzlwihfc+0Oty0ViHPHPxXj0SI=", - version = "v0.5.3-0.20201020212313-ab46b8bd0abd", -) - -go_repository( - name = "com_github_google_go_github_v28", - importpath = "github.com/google/go-github/v28", - sum = "h1:zOOUQavr8D4AZrcV4ylUpbGa5j3jfeslN6Xculz3tVU=", - version = "v28.1.2-0.20191108005307-e555eab49ce8", + sum = "h1:L8R9j+yAqZuZjsqh/z+F1NCffTKKLShY6zXTItVIZ8M=", + version = "v0.5.4", ) go_repository( @@ -716,8 +756,8 @@ go_repository( go_repository( name = "com_github_microsoft_hcsshim", importpath = "github.com/Microsoft/hcsshim", - sum = "h1:ZfF0+zZeYdzMIVMZHKtDKJvLHj76XCuVae/jNkjj0IA=", - version = "v0.8.6", + sum = "h1:lbPVK25c1cu5xTLITwpUcxoA9vKrKErASPYygvouJns=", + version = "v0.8.14", ) go_repository( @@ -730,8 +770,8 @@ go_repository( go_repository( name = "com_github_opencontainers_runtime_spec", importpath = "github.com/opencontainers/runtime-spec", - sum = "h1:Pyp2f/uuhJIcUgnIeZaAbwOcyNz8TBlEe6mPpC8kXq8=", - version = "v1.0.2-0.20181111125026-1722abf79c2f", + sum = "h1:UfAcuLBJB9Coz72x1hgl8O5RVzTdNiaglX6v2DM6FI0=", + version = "v1.0.2", ) go_repository( @@ -800,15 +840,15 @@ go_repository( go_repository( name = "org_golang_google_appengine", importpath = "google.golang.org/appengine", - sum = "h1:tycE03LOZYQNhDpS27tcQdAzLCVMaj7QT2SXxebnpCM=", - version = "v1.6.5", + sum = 
"h1:FZR1q0exgwxzPzp/aF+VccGrSfxfPpkBqjIIEq3ru6c=", + version = "v1.6.7", ) go_repository( name = "org_golang_google_genproto", importpath = "google.golang.org/genproto", - sum = "h1:wDju+RU97qa0FZT0QnZDg9Uc2dH0Ql513kFvHocz+WM=", - version = "v0.0.0-20200117163144-32f20d992d24", + sum = "h1:n7yjMkxUgbEahYENvAGVlxMUW8TF/KEavLez31znfDw=", + version = "v0.0.0-20210108203827-ffc7fda8c3d7", ) go_repository( @@ -821,15 +861,15 @@ go_repository( go_repository( name = "org_golang_x_exp", importpath = "golang.org/x/exp", - sum = "h1:zQpM52jfKHG6II1ISZY1ZcpygvuSFZpLwfluuF89XOg=", - version = "v0.0.0-20191227195350-da58074b4299", + sum = "h1:QE6XYQK6naiK1EPAe1g/ILLxN5RBoH5xkJk3CqlMI/Y=", + version = "v0.0.0-20200224162631-6cc2880d07d6", ) go_repository( name = "org_golang_x_lint", importpath = "golang.org/x/lint", - sum = "h1:J5lckAjkw6qYlOZNj90mLYNTEKDvWeuc1yieZ8qUzUE=", - version = "v0.0.0-20191125180803-fdd1cda4f05f", + sum = "h1:2M3HP5CCK1Si9FQhwnzYhXdG6DXeebvUHFpre8QvbyI=", + version = "v0.0.0-20201208152925-83fdc39ff7b5", ) go_repository( @@ -870,15 +910,15 @@ go_repository( go_repository( name = "com_github_go_gl_glfw_v3_3_glfw", importpath = "github.com/go-gl/glfw/v3.3/glfw", - sum = "h1:b+9H1GAsx5RsjvDFLoS5zkNBzIQMuVKUYQDmxU3N5XE=", - version = "v0.0.0-20191125211704-12ad95a8df72", + sum = "h1:WtGNWLvXpe6ZudgnXrq0barxBImvnnJoMEhXAzcbM0I=", + version = "v0.0.0-20200222043503-6f7a984d4dc4", ) go_repository( name = "com_github_golang_groupcache", importpath = "github.com/golang/groupcache", - sum = "h1:5ZkaAPbicIKTF2I64qf5Fh8Aa83Q/dnOafMYV0OMwjA=", - version = "v0.0.0-20191227052852-215e87163ea7", + sum = "h1:1r7pUrabqp18hOBcwBwiTsbnFeTZHV9eER/QT5JVZxY=", + version = "v0.0.0-20200121045136-8c9f03a8e57e", ) go_repository( @@ -891,8 +931,8 @@ go_repository( go_repository( name = "com_github_google_pprof", importpath = "github.com/google/pprof", - sum = "h1:LR89qFljJ48s990kEKGsk213yIJDPI4205OKOzbURK8=", - version = "v0.0.0-20201218002935-b9804c9f04c2", + sum = 
"h1:LB1NXQJhB+dF+5kZVLIn85HJqGvK3zKHIku3bdy3IRc=", + version = "v0.0.0-20210115211752-39141e76b647", ) go_repository( @@ -912,8 +952,8 @@ go_repository( go_repository( name = "com_github_ianlancetaylor_demangle", importpath = "github.com/ianlancetaylor/demangle", - sum = "h1:UDMh68UUwekSh5iP2OMhRRZJiiBccgV7axzUG8vi56c=", - version = "v0.0.0-20181102032728-5e5cf60278f6", + sum = "h1:mV02weKRL81bEnm8A0HT1/CAelMQDBuQIfLw8n+d6xI=", + version = "v0.0.0-20200824232613-28f6c0f3b639", ) go_repository( @@ -954,8 +994,8 @@ go_repository( go_repository( name = "org_golang_google_api", importpath = "google.golang.org/api", - sum = "h1:yzlyyDW/J0w8yNFJIhiAJy4kq74S+1DOLdawELNxFMA=", - version = "v0.15.0", + sum = "h1:l2Nfbl2GPXdWorv+dT2XfinX2jOOw4zv1VhLstx+6rE=", + version = "v0.36.0", ) go_repository( @@ -982,71 +1022,50 @@ go_repository( go_repository( name = "com_github_vishvananda_netns", importpath = "github.com/vishvananda/netns", - sum = "h1:4hwBBUfQCFe3Cym0ZtKyq7L16eZUtYKs+BaHDN6mAns=", - version = "v0.0.0-20200728191858-db3c7e526aae", + sum = "h1:p4VB7kIXpOQvVn1ZaTIVp+3vuYAXFe3OJEvjbUYJLaA=", + version = "v0.0.0-20210104183010-2eb08e3e575f", ) go_repository( name = "com_google_cloud_go_bigquery", importpath = "cloud.google.com/go/bigquery", - sum = "h1:hL+ycaJpVE9M7nLoiXb/Pn10ENE2u+oddxbD8uu0ZVU=", - version = "v1.0.1", + sum = "h1:PQcPefKFdaIzjQFbiyOgAqyx8q5djaE7x9Sqe712DPA=", + version = "v1.8.0", ) go_repository( name = "com_google_cloud_go_datastore", importpath = "cloud.google.com/go/datastore", - sum = "h1:Kt+gOPPp2LEPWp8CSfxhsM8ik9CcyE/gYu+0r+RnZvM=", - version = "v1.0.0", + sum = "h1:/May9ojXjRkPBNVrq+oWLqmWCkr4OU5uRY29bu0mRyQ=", + version = "v1.1.0", ) go_repository( name = "com_google_cloud_go_pubsub", importpath = "cloud.google.com/go/pubsub", - sum = "h1:W9tAK3E57P75u0XLLR82LZyw8VpAnhmyTOxW9qzmyj8=", - version = "v1.0.1", + sum = "h1:ukjixP1wl0LpnZ6LWtZJ0mX5tBmjp1f8Sqer8Z2OMUU=", + version = "v1.3.1", ) go_repository( name = "com_google_cloud_go_storage", 
importpath = "cloud.google.com/go/storage", - sum = "h1:VV2nUM3wwLLGh9lSABFgZMjInyUbJeaRSE64WuAIQ+4=", - version = "v1.0.0", -) - -go_repository( - name = "com_github_hashicorp_errwrap", - importpath = "github.com/hashicorp/errwrap", - sum = "h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA=", - version = "v1.0.0", -) - -go_repository( - name = "com_github_hashicorp_go_multierror", - importpath = "github.com/hashicorp/go-multierror", - sum = "h1:iVjPR7a6H0tWELX5NxNe7bYopibicUzc7uPribsnS6o=", - version = "v1.0.0", -) - -go_repository( - name = "com_github_dpjacques_clockwork", - importpath = "github.com/dpjacques/clockwork", - sum = "h1:l+j1wSnHcimOzeeKxtspsl6tCBTyikdYxcWqFZ+Ho2c=", - version = "v0.1.1-0.20200827220843-c1f524b839be", + sum = "h1:STgFzyU5/8miMl0//zKh2aQeTyeaUH3WN9bSUiJ09bA=", + version = "v1.10.0", ) go_repository( name = "com_github_cilium_ebpf", importpath = "github.com/cilium/ebpf", - sum = "h1:i8+1fuPLjSgAYXUyBlHNhFwjcfAsP4ufiuH1+PWkyDU=", - version = "v0.0.0-20200110133405-4032b1d8aae3", + sum = "h1:Fv93L3KKckEcEHR3oApXVzyBTDA8WAm6VXhPE00N3f8=", + version = "v0.2.0", ) go_repository( name = "com_github_coreos_go_systemd_v22", importpath = "github.com/coreos/go-systemd/v22", - sum = "h1:XJIw/+VlJ+87J+doOxznsAWIdmWuViOVhkQamW5YV28=", - version = "v22.0.0", + sum = "h1:kq/SbG2BCKLkDKkjQf5OWwKWUKj1lgs3lFI4PxnR5lg=", + version = "v22.1.0", ) go_repository( @@ -1433,8 +1452,8 @@ go_repository( go_repository( name = "com_github_xeipuuv_gojsonpointer", importpath = "github.com/xeipuuv/gojsonpointer", - sum = "h1:zGWFAtiMcyryUHoUjUJX0/lt1H2+i2Ka2n+D3DImSNo=", - version = "v0.0.0-20190905194746-02993c407bfb", + sum = "h1:J9EGpcZtP0E/raorCMxlFGSTBrsSlaDGf3jU/qvAE2c=", + version = "v0.0.0-20180127040702-4e3ac2762d5f", ) go_repository( @@ -5,51 +5,42 @@ go 1.15 replace github.com/Sirupsen/logrus => github.com/sirupsen/logrus v1.6.0 require ( - cloud.google.com/go v0.52.1-0.20200122224058-0482b626c726 // indirect - github.com/Microsoft/go-winio 
v0.4.15-0.20200908182639-5b44b70ab3ab // indirect - github.com/Microsoft/hcsshim v0.8.6 // indirect + cloud.google.com/go v0.75.0 // indirect + github.com/Microsoft/go-winio v0.4.16 // indirect + github.com/Microsoft/hcsshim v0.8.14 // indirect github.com/cenkalti/backoff v1.1.1-0.20190506075156-2146c9339422 // indirect - github.com/cilium/ebpf v0.0.0-20200110133405-4032b1d8aae3 // indirect - github.com/containerd/cgroups v0.0.0-20201119153540-4cbc285b3327 + github.com/containerd/cgroups v0.0.0-20201119153540-4cbc285b3327 // indirect + github.com/containerd/console v1.0.1 // indirect github.com/containerd/containerd v1.3.9 // indirect - github.com/containerd/continuity v0.0.0-20200928162600-f2cc35102c2a // indirect + github.com/containerd/continuity v0.0.0-20201208142359-180525291bb7 // indirect github.com/containerd/fifo v0.0.0-20191213151349-ff969a566b00 // indirect github.com/containerd/go-runc v0.0.0-20200220073739-7016d3ce2328 // indirect + github.com/containerd/ttrpc v1.0.2 // indirect github.com/containerd/typeurl v0.0.0-20200205145503-b45ef1f1f737 // indirect - github.com/coreos/go-systemd v0.0.0-20191104093116-d3cd4ed1dbcf // indirect - github.com/coreos/go-systemd/v22 v22.0.0 // indirect - github.com/cpuguy83/go-md2man/v2 v2.0.0 // indirect github.com/docker/distribution v2.7.1-0.20190205005809-0d3efadf0154+incompatible // indirect github.com/docker/docker v1.4.2-0.20191028175130-9e7d5ac5ea55 // indirect github.com/docker/go-connections v0.3.0 // indirect github.com/docker/go-events v0.0.0-20190806004212-e31b211e4f1c // indirect - github.com/docker/go-units v0.4.0 // indirect - github.com/dpjacques/clockwork v0.1.1-0.20200827220843-c1f524b839be // indirect - github.com/godbus/dbus v0.0.0-20190422162347-ade71ed3457e // indirect github.com/gofrs/flock v0.6.1-0.20180915234121-886344bea079 // indirect github.com/gogo/googleapis v1.4.0 // indirect - github.com/google/go-cmp v0.5.3-0.20201020212313-ab46b8bd0abd // indirect - github.com/google/go-github/v28 
v28.1.2-0.20191108005307-e555eab49ce8 // indirect + github.com/google/go-github/v32 v32.1.0 // indirect + github.com/google/pprof v0.0.0-20210115211752-39141e76b647 // indirect github.com/google/subcommands v1.0.2-0.20190508160503-636abe8753b8 // indirect - github.com/hashicorp/go-multierror v1.0.0 // indirect + github.com/hashicorp/go-multierror v1.1.0 // indirect github.com/kr/pty v1.1.4-0.20190131011033-7dc38fb350b1 // indirect - github.com/mattbaird/jsonpatch v0.0.0-20171005235357-81af80346b1a + github.com/mattbaird/jsonpatch v0.0.0-20171005235357-81af80346b1a // indirect github.com/mohae/deepcopy v0.0.0-20170308212314-bb9b5e7adda9 // indirect github.com/opencontainers/image-spec v1.0.1 // indirect github.com/opencontainers/runc v0.1.1 // indirect - github.com/opencontainers/runtime-spec v1.0.2-0.20181111125026-1722abf79c2f // indirect github.com/pborman/uuid v1.2.0 // indirect github.com/syndtr/gocapability v0.0.0-20180916011248-d98352740cb2 // indirect - github.com/urfave/cli v1.22.2 // indirect github.com/vishvananda/netlink v1.0.1-0.20190930145447-2ec5bdc52b86 // indirect - github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae // indirect + github.com/vishvananda/netns v0.0.0-20210104183010-2eb08e3e575f // indirect + github.com/xeipuuv/gojsonschema v1.2.0 // indirect go.uber.org/multierr v1.6.0 // indirect - golang.org/x/time v0.0.0-20191024005414-555d28b269f0 // indirect - golang.org/x/tools v0.0.0-20201021000207-d49c4edd7d96 // indirect - google.golang.org/grpc v1.29.0 // indirect + google.golang.org/grpc v1.36.0-dev.0.20210122012134-2c42474aca0c // indirect google.golang.org/protobuf v1.25.1-0.20201020201750-d3470999428b // indirect - gotest.tools v2.2.0+incompatible // indirect - k8s.io/api v0.16.13 - k8s.io/apimachinery v0.16.14-rc.0 - k8s.io/client-go v0.16.13 + honnef.co/go/tools v0.1.1 // indirect + k8s.io/apimachinery v0.16.14-rc.0 // indirect + k8s.io/client-go v0.16.13 // indirect ) @@ -6,12 +6,34 @@ cloud.google.com/go v0.44.1/go.mod 
h1:iSa0KzasP4Uvy3f1mN/7PiObzGgflwredwwASm/v6A cloud.google.com/go v0.44.2/go.mod h1:60680Gw3Yr4ikxnPRS/oxxkBccT6SA1yMk63TGekxKY= cloud.google.com/go v0.45.1/go.mod h1:RpBamKRgapWJb87xiFSdk4g1CME7QZg3uwTez+TSTjc= cloud.google.com/go v0.46.3/go.mod h1:a6bKKbmY7er1mI7TEI4lsAkts/mkhTSZK8w33B4RAg0= -cloud.google.com/go v0.52.1-0.20200122224058-0482b626c726 h1:Fvo/6MiAbwmQpsq5YFRo8O6TC40m9MK4Xh/oN07rIlo= -cloud.google.com/go v0.52.1-0.20200122224058-0482b626c726/go.mod h1:pXajvRH/6o3+F9jDHZWQ5PbGhn+o8w9qiu/CffaVdO4= +cloud.google.com/go v0.50.0/go.mod h1:r9sluTvynVuxRIOHXQEHMFffphuXHOMZMycpNR5e6To= +cloud.google.com/go v0.52.0/go.mod h1:pXajvRH/6o3+F9jDHZWQ5PbGhn+o8w9qiu/CffaVdO4= +cloud.google.com/go v0.53.0/go.mod h1:fp/UouUEsRkN6ryDKNW/Upv/JBKnv6WDthjR6+vze6M= +cloud.google.com/go v0.54.0/go.mod h1:1rq2OEkV3YMf6n/9ZvGWI3GWw0VoqH/1x2nd8Is/bPc= +cloud.google.com/go v0.56.0/go.mod h1:jr7tqZxxKOVYizybht9+26Z/gUq7tiRzu+ACVAMbKVk= +cloud.google.com/go v0.57.0/go.mod h1:oXiQ6Rzq3RAkkY7N6t3TcE6jE+CIBBbA36lwQ1JyzZs= +cloud.google.com/go v0.62.0/go.mod h1:jmCYTdRCQuc1PHIIJ/maLInMho30T/Y0M4hTdTShOYc= +cloud.google.com/go v0.65.0/go.mod h1:O5N8zS7uWy9vkA9vayVHs65eM1ubvY4h553ofrNHObY= +cloud.google.com/go v0.72.0/go.mod h1:M+5Vjvlc2wnp6tjzE102Dw08nGShTscUx2nZMufOKPI= +cloud.google.com/go v0.75.0 h1:XgtDnVJRCPEUG21gjFiRPz4zI1Mjg16R+NYQjfmU4XY= +cloud.google.com/go v0.75.0/go.mod h1:VGuuCn7PG0dwsd5XPVm2Mm3wlh3EL55/79EKB6hlPTY= cloud.google.com/go/bigquery v1.0.1/go.mod h1:i/xbL2UlR5RvWAURpBYZTtm/cXjCha9lbfbpx4poX+o= +cloud.google.com/go/bigquery v1.3.0/go.mod h1:PjpwJnslEMmckchkHFfq+HTD2DmtT67aNFKH1/VBDHE= +cloud.google.com/go/bigquery v1.4.0/go.mod h1:S8dzgnTigyfTmLBfrtrhyYhwRxG72rYxvftPBK2Dvzc= +cloud.google.com/go/bigquery v1.5.0/go.mod h1:snEHRnqQbz117VIFhE8bmtwIDY80NLUZUMb4Nv6dBIg= +cloud.google.com/go/bigquery v1.7.0/go.mod h1://okPTzCYNXSlb24MZs83e2Do+h+VXtc4gLoIoXIAPc= +cloud.google.com/go/bigquery v1.8.0/go.mod h1:J5hqkt3O0uAFnINi6JXValWIb1v0goeZM77hZzJN/fQ= 
cloud.google.com/go/datastore v1.0.0/go.mod h1:LXYbyblFSglQ5pkeyhO+Qmw7ukd3C+pD7TKLgZqpHYE= +cloud.google.com/go/datastore v1.1.0/go.mod h1:umbIZjpQpHh4hmRpGhH4tLFup+FVzqBi1b3c64qFpCk= cloud.google.com/go/pubsub v1.0.1/go.mod h1:R0Gpsv3s54REJCy4fxDixWD93lHJMoZTyQ2kNxGRt3I= +cloud.google.com/go/pubsub v1.1.0/go.mod h1:EwwdRX2sKPjnvnqCa270oGRyludottCI76h+R3AArQw= +cloud.google.com/go/pubsub v1.2.0/go.mod h1:jhfEVHT8odbXTkndysNHCcx0awwzvfOlguIAii9o8iA= +cloud.google.com/go/pubsub v1.3.1/go.mod h1:i+ucay31+CNRpDW4Lu78I4xXG+O1r/MAHgjpRVR+TSU= cloud.google.com/go/storage v1.0.0/go.mod h1:IhtSnM/ZTZV8YYJWCY8RULGVqBDmpoyjwiyrjsg+URw= +cloud.google.com/go/storage v1.5.0/go.mod h1:tpKbwo567HUNpVclU5sGELwQWBDZ8gh0ZeosJ0Rtdos= +cloud.google.com/go/storage v1.6.0/go.mod h1:N7U0C8pVQ/+NIKOBQyamJIeKQKkZ+mxpohlUTyfDhBk= +cloud.google.com/go/storage v1.8.0/go.mod h1:Wv1Oy7z6Yz3DshWRJFhqM/UCfaWIRTdp0RXyy7KQOVs= +cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9ullr3+Kg0= dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= github.com/Azure/go-autorest/autorest v0.9.0/go.mod h1:xyHB1BMZT0cuDHU7I0+g046+BFDTQ8rEZB0s4Yfa6bI= github.com/Azure/go-autorest/autorest/adal v0.5.0/go.mod h1:8Z9fGy2MpX0PvDjB1pEgQTmVqjGhiHBW7RJJEciWzS0= @@ -23,61 +45,59 @@ github.com/Azure/go-autorest/tracing v0.5.0/go.mod h1:r/s2XiOKccPW3HrqB+W0TQzfbt github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= -github.com/Microsoft/go-winio v0.4.14 h1:+hMXMk01us9KgxGb7ftKQt2Xpf5hH/yky+TDA+qxleU= -github.com/Microsoft/go-winio v0.4.14/go.mod h1:qXqCSQ3Xa7+6tgxaGTIe4Kpcdsi+P8jBhyzoq1bpyYA= -github.com/Microsoft/go-winio v0.4.15-0.20190919025122-fc70bd9a86b5/go.mod 
h1:tTuCMEN+UleMWgg9dVx4Hu52b1bJo+59jBh3ajtinzw= -github.com/Microsoft/go-winio v0.4.15-0.20200908182639-5b44b70ab3ab h1:9pygWVFqbY9lPxM0peffumuVDyMuIMzNLyO9uFjJuQo= -github.com/Microsoft/go-winio v0.4.15-0.20200908182639-5b44b70ab3ab/go.mod h1:tTuCMEN+UleMWgg9dVx4Hu52b1bJo+59jBh3ajtinzw= -github.com/Microsoft/hcsshim v0.8.6/go.mod h1:Op3hHsoHPAvb6lceZHDtd9OkTew38wNoXnJs8iY7rUg= -github.com/Microsoft/hcsshim v0.8.7/go.mod h1:OHd7sQqRFrYd3RmSgbgji+ctCwkbq2wbEYNSzOYtcBQ= -github.com/Microsoft/hcsshim v0.8.8/go.mod h1:5692vkUqntj1idxauYlpoINNKeqCiG6Sg38RRsjT5y8= -github.com/Microsoft/hcsshim v0.8.9/go.mod h1:5692vkUqntj1idxauYlpoINNKeqCiG6Sg38RRsjT5y8= -github.com/Microsoft/hcsshim v0.8.10 h1:k5wTrpnVU2/xv8ZuzGkbXVd3js5zJ8RnumPo5RxiIxU= -github.com/Microsoft/hcsshim v0.8.10/go.mod h1:g5uw8EV2mAlzqe94tfNBNdr89fnbD/n3HV0OhsddkmM= +github.com/Microsoft/go-winio v0.4.16-0.20201130162521-d1ffc52c7331/go.mod h1:XB6nPKklQyQ7GC9LdcBEcBl8PF76WugXOPRXwdLnMv0= +github.com/Microsoft/go-winio v0.4.16 h1:FtSW/jqD+l4ba5iPBj9CODVtgfYAD8w2wS923g/cFDk= +github.com/Microsoft/go-winio v0.4.16/go.mod h1:XB6nPKklQyQ7GC9LdcBEcBl8PF76WugXOPRXwdLnMv0= +github.com/Microsoft/hcsshim v0.8.14 h1:lbPVK25c1cu5xTLITwpUcxoA9vKrKErASPYygvouJns= +github.com/Microsoft/hcsshim v0.8.14/go.mod h1:NtVKoYxQuTLx6gEq0L96c9Ju4JbRJ4nY2ow3VK6a9Lg= github.com/NYTimes/gziphandler v0.0.0-20170623195520-56545f4a5d46/go.mod h1:3wb06e3pkSAbeQ52E9H9iFoQsEEwGN64994WTCIhntQ= github.com/PuerkitoBio/purell v1.0.0/go.mod h1:c11w/QuzBsJSee3cPx9rAFu61PvFxuPbtSwDGJws/X0= github.com/PuerkitoBio/urlesc v0.0.0-20160726150825-5bd2802263f2/go.mod h1:uGdkoq3SwY9Y+13GIhn11/XLaGBb4BfwItxLd5jeuXE= -github.com/blang/semver v3.1.0+incompatible/go.mod h1:kRBLl5iJ+tD4TcOOxsy/0fnwebNt5EWlYSAyrTnjyyk= github.com/cenkalti/backoff v1.1.1-0.20190506075156-2146c9339422 h1:8eZxmY1yvxGHzdzTEhI09npjMVGzNAdrqzruTX6jcK4= github.com/cenkalti/backoff v1.1.1-0.20190506075156-2146c9339422/go.mod h1:b6Nc7NRH5C4aCISLry0tLnTjcuTEvoiqcWDdsU0sOGM= 
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= +github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e h1:fY5BOSpyZCqRo5OhCuC+XN+r/bBCmeuuJtjz+bCNIf8= github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= github.com/cilium/ebpf v0.0.0-20200110133405-4032b1d8aae3/go.mod h1:MA5e5Lr8slmEg9bt0VpxxWqJlO4iwu3FBdHUzV7wQVg= +github.com/cilium/ebpf v0.2.0/go.mod h1:To2CFviqOWL/M0gIMsvSMlqe7em/l1ALkX1PyjrX2Qs= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= -github.com/containerd/cgroups v0.0.0-20181219155423-39b18af02c41 h1:5yg0k8gqOssNLsjjCtXIADoPbAtUtQZJfC8hQ4r2oFY= -github.com/containerd/cgroups v0.0.0-20181219155423-39b18af02c41/go.mod h1:X9rLEHIqSf/wfK8NsPqxJmeZgW4pcfzdXITDrUSJ6uI= -github.com/containerd/cgroups v0.0.0-20200531161412-0dbf7f05ba59 h1:qWj4qVYZ95vLWwqyNJCQg7rDsG5wPdze0UaPolH7DUk= +github.com/cncf/udpa/go v0.0.0-20200629203442-efcf912fb354/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= +github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= github.com/containerd/cgroups v0.0.0-20200531161412-0dbf7f05ba59/go.mod h1:pA0z1pT8KYB3TCXK/ocprsh7MAkoW8bZVzPdih9snmM= github.com/containerd/cgroups v0.0.0-20201119153540-4cbc285b3327 h1:7grrpcfCtbZLsjtB0DgMuzs1umsJmpzaHMZ6cO6iAWw= github.com/containerd/cgroups v0.0.0-20201119153540-4cbc285b3327/go.mod h1:ZJeTFisyysqgcCdecO57Dj79RfL0LNeGiFUqLYQRYLE= github.com/containerd/console v0.0.0-20180822173158-c12b1e7919c1/go.mod h1:Tj/on1eG8kiEhd0+fhSDzsPAFESxzBBvdyEgyryXffw= 
-github.com/containerd/console v0.0.0-20191206165004-02ecf6a7291e h1:GdiIYd8ZDOrT++e1NjhSD4rGt9zaJukHm4rt5F4mRQc= github.com/containerd/console v0.0.0-20191206165004-02ecf6a7291e/go.mod h1:8Pf4gM6VEbTNRIT26AyyU7hxdQU3MvAvxVI0sc00XBE= +github.com/containerd/console v1.0.1 h1:u7SFAJyRqWcG6ogaMAx3KjSTy1e3hT9QxqX7Jco7dRc= +github.com/containerd/console v1.0.1/go.mod h1:XUsP6YE/mKtz6bxc+I8UiKKTP04qjQL4qcS3XoQ5xkw= +github.com/containerd/containerd v1.3.2/go.mod h1:bC6axHOhabU15QhwfG7w5PipXdVtMXFTttgp+kVtyUA= +github.com/containerd/containerd v1.3.9 h1:K2U/F4jGAMBqeUssfgJRbFuomLcS2Fxo1vR3UM/Mbh8= github.com/containerd/containerd v1.3.9/go.mod h1:bC6axHOhabU15QhwfG7w5PipXdVtMXFTttgp+kVtyUA= github.com/containerd/continuity v0.0.0-20190426062206-aaeac12a7ffc/go.mod h1:GL3xCUCBDV3CZiTSEKksMWbLE66hEyuu9qyDOOqM47Y= -github.com/containerd/continuity v0.0.0-20200928162600-f2cc35102c2a h1:jEIoR0aA5GogXZ8pP3DUzE+zrhaF6/1rYZy+7KkYEWM= -github.com/containerd/continuity v0.0.0-20200928162600-f2cc35102c2a/go.mod h1:W0qIOTD7mp2He++YVq+kgfXezRYqzP1uDuMVH1bITDY= +github.com/containerd/continuity v0.0.0-20201208142359-180525291bb7 h1:6ejg6Lkk8dskcM7wQ28gONkukbQkM4qpj4RnYbpFzrI= +github.com/containerd/continuity v0.0.0-20201208142359-180525291bb7/go.mod h1:kR3BEg7bDFaEddKm54WSmrol1fKWDU1nKYkgrcgZT7Y= github.com/containerd/fifo v0.0.0-20190226154929-a9fb20d87448/go.mod h1:ODA38xgv3Kuk8dQz2ZQXpnv/UZZUHUCL7pnLehbXgQI= github.com/containerd/fifo v0.0.0-20191213151349-ff969a566b00 h1:lsjC5ENBl+Zgf38+B0ymougXFp0BaubeIVETltYZTQw= github.com/containerd/fifo v0.0.0-20191213151349-ff969a566b00/go.mod h1:jPQ2IAeZRCYxpS/Cm1495vGFww6ecHmMk1YJH2Q5ln0= github.com/containerd/go-runc v0.0.0-20180907222934-5a6d9f37cfa3/go.mod h1:IV7qH3hrUgRmyYrtgEeGWJfWbgcHL9CSRruz2Vqcph0= github.com/containerd/go-runc v0.0.0-20200220073739-7016d3ce2328 h1:PRTagVMbJcCezLcHXe8UJvR1oBzp2lG3CEumeFOLOds= github.com/containerd/go-runc v0.0.0-20200220073739-7016d3ce2328/go.mod h1:PpyHrqVs8FTi9vpyHwPwiNEGaACDxT/N/pLcvMSRA9g= 
+github.com/containerd/ttrpc v0.0.0-20190828154514-0e0f228740de/go.mod h1:PvCDdDGpgqzQIzDW1TphrGLssLDZp2GuS+X5DkEJB8o= +github.com/containerd/ttrpc v1.0.2 h1:2/O3oTZN36q2xRolk0a2WWGgh7/Vf/liElg5hFYLX9U= +github.com/containerd/ttrpc v1.0.2/go.mod h1:UAxOpgT9ziI0gJrmKvgcZivgxOp8iFPSk8httJEt98Y= github.com/containerd/typeurl v0.0.0-20180627222232-a93fcdb778cd/go.mod h1:Cm3kwCdlkCfMSHURc+r6fwoGH6/F1hH3S4sg0rLFWPc= github.com/containerd/typeurl v0.0.0-20200205145503-b45ef1f1f737 h1:HovfQDS/K3Mr7eyS0QJLxE1CbVUhjZCl6g3OhFJgP1o= github.com/containerd/typeurl v0.0.0-20200205145503-b45ef1f1f737/go.mod h1:TB1hUtrpaiO88KEK56ijojHS1+NeF0izUACaJW2mdXg= -github.com/coreos/go-systemd v0.0.0-20191104093116-d3cd4ed1dbcf h1:iW4rZ826su+pqaw19uhpSCzhj44qo35pNgKFGqzDKkU= -github.com/coreos/go-systemd v0.0.0-20191104093116-d3cd4ed1dbcf/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= -github.com/coreos/go-systemd/v22 v22.0.0 h1:XJIw/+VlJ+87J+doOxznsAWIdmWuViOVhkQamW5YV28= github.com/coreos/go-systemd/v22 v22.0.0/go.mod h1:xO0FLkIi5MaZafQlIrOotqXZ90ih+1atmu1JpKERPPk= +github.com/coreos/go-systemd/v22 v22.1.0 h1:kq/SbG2BCKLkDKkjQf5OWwKWUKj1lgs3lFI4PxnR5lg= +github.com/coreos/go-systemd/v22 v22.1.0/go.mod h1:xO0FLkIi5MaZafQlIrOotqXZ90ih+1atmu1JpKERPPk= github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= github.com/cpuguy83/go-md2man/v2 v2.0.0/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= github.com/davecgh/go-spew v0.0.0-20151105211317-5215b55f46b2/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= github.com/docker/distribution 
v2.7.1-0.20190205005809-0d3efadf0154+incompatible h1:dvc1KSkIYTVjZgHf/CTC2diTYC8PzhaA5sFISRfNVrE= @@ -91,26 +111,26 @@ github.com/docker/go-events v0.0.0-20190806004212-e31b211e4f1c/go.mod h1:Uw6Uezg github.com/docker/go-units v0.4.0 h1:3uh0PgVws3nIA0Q+MwDC8yjEPf9zjRfZZWXZYDct3Tw= github.com/docker/go-units v0.4.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= github.com/docker/spdystream v0.0.0-20160310174837-449fdfce4d96/go.mod h1:Qh8CwZgvJUkLughtfhJv5dyTYa91l1fOUCrgjqmcifM= -github.com/dpjacques/clockwork v0.1.1-0.20200827220843-c1f524b839be h1:l+j1wSnHcimOzeeKxtspsl6tCBTyikdYxcWqFZ+Ho2c= -github.com/dpjacques/clockwork v0.1.1-0.20200827220843-c1f524b839be/go.mod h1:D8mP2A8vVT2GkXqPorSBmhnshhkFBYgzhA90KmJt25Y= github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/elazarl/goproxy v0.0.0-20170405201442-c4fc26588b6e/go.mod h1:/Zj4wYkgs4iZTTu3o/KG3Itv/qCCa8VVMlb3i9OVuzc= github.com/emicklei/go-restful v0.0.0-20170410110728-ff4f55a20633/go.mod h1:otzb+WCGbkyDHkqmQmT5YD2WR4BBwUdeQoFo8l/7tVs= github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= +github.com/envoyproxy/go-control-plane v0.9.7/go.mod h1:cwu0lG7PUMfa9snN8LXBig5ynNVH9qI8YYLbd1fK2po= +github.com/envoyproxy/go-control-plane v0.9.9-0.20201210154907-fd9021fe5dad/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= github.com/evanphx/json-patch v4.2.0+incompatible/go.mod h1:50XU6AFN0ol/bzJsmQLiYLvXMP4fmwYFNcr97nuDLSk= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/ghodss/yaml v0.0.0-20150909031657-73d445a93680/go.mod 
h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= +github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= +github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-logr/logr v0.1.0/go.mod h1:ixOQHD9gLJUVQQ2ZOR7zLEifBX6tGkNJF4QyIY7sIas= github.com/go-openapi/jsonpointer v0.0.0-20160704185906-46af16f9f7b1/go.mod h1:+35s3my2LFTysnkMfxsJBAMHj/DoqoB9knIWoYG/Vk0= github.com/go-openapi/jsonreference v0.0.0-20160704190145-13c6e3589ad9/go.mod h1:W3Z9FmVs9qj+KR4zFKmDPGiLdk1D9Rlm7cyMvf57TTg= github.com/go-openapi/spec v0.0.0-20160808142527-6aced65f8501/go.mod h1:J8+jY1nAiCcj+friV/PDoE1/3eeccG9LYBs0tYvLOWc= github.com/go-openapi/swag v0.0.0-20160704191624-1d0bd113de87/go.mod h1:DXUve3Dpr1UfpPtxFw+EFuQ41HhCWZfha5jSVRG7C7I= -github.com/godbus/dbus v0.0.0-20190422162347-ade71ed3457e h1:BWhy2j3IXJhjCbC68FptL43tDKIq8FladmaTs3Xs7Z8= -github.com/godbus/dbus v0.0.0-20190422162347-ade71ed3457e/go.mod h1:bBOAhwG1umN6/6ZUMtDFBMQR8jRg9O75tm9K00oMsK4= github.com/godbus/dbus/v5 v5.0.3 h1:ZqHaoEF7TBzh4jzPmqVhE/5A1z9of6orkAe5uHoAeME= github.com/godbus/dbus/v5 v5.0.3/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/gofrs/flock v0.6.1-0.20180915234121-886344bea079 h1:JFTFz3HZTGmgMz4E1TabNBNJljROSYgja1b4l50FNVs= @@ -123,24 +143,33 @@ github.com/gogo/protobuf v1.3.1/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXP github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/groupcache v0.0.0-20160516000752-02826c3e7903/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= -github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7 
h1:5ZkaAPbicIKTF2I64qf5Fh8Aa83Q/dnOafMYV0OMwjA= github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= +github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e h1:1r7pUrabqp18hOBcwBwiTsbnFeTZHV9eER/QT5JVZxY= +github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= -github.com/golang/mock v1.3.1 h1:qGJ6qTW+x6xX/my+8YUVl4WNpX9B7+/l2tRsHGZ7f2s= github.com/golang/mock v1.3.1/go.mod h1:sBzyDLLjw3U8JLTeZvSv8jJB+tU5PVekmnlKIyFUx0Y= +github.com/golang/mock v1.4.0/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw= +github.com/golang/mock v1.4.1/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw= +github.com/golang/mock v1.4.3/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw= +github.com/golang/mock v1.4.4 h1:l75CXGRSwbaYNpl/Z2X1XIIAMSCquvXgpVZDhwEIJsc= +github.com/golang/mock v1.4.4/go.mod h1:l3mdAwkq5BuhzHwde/uurv3sEJeZMXNpwsxVWU71h+4= github.com/golang/protobuf v0.0.0-20161109072736-4bd1920723d7/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= +github.com/golang/protobuf v1.3.4/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= +github.com/golang/protobuf v1.3.5/go.mod h1:6O5/vntMXwX2lRkT1hjjk0nAC1IDOTvTlVgjlRvqsdk= github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod 
h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= -github.com/golang/protobuf v1.4.1 h1:ZFgWrT+bLgsYPirOnRfKLYJLvssAegOj/hgyMFdJZe0= github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= +github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/golang/protobuf v1.4.3 h1:JjCZWpVbqXDqFVmTfYWEVTMIYrL/NPdPSCHPJ0T/raM= +github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/btree v1.0.0 h1:0udJVsspx3VBr5FwtLhQQtuAsVc79tTq0ocGIPAU6qo= github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= @@ -148,43 +177,57 @@ github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5a github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.4.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.3-0.20201020212313-ab46b8bd0abd h1:pJfrTSHC+QpCQplFZqzlwihfc+0Oty0ViHPHPxXj0SI= -github.com/google/go-cmp v0.5.3-0.20201020212313-ab46b8bd0abd/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-github/v28 v28.1.2-0.20191108005307-e555eab49ce8 h1:zOOUQavr8D4AZrcV4ylUpbGa5j3jfeslN6Xculz3tVU= -github.com/google/go-github/v28 
v28.1.2-0.20191108005307-e555eab49ce8/go.mod h1:g82e6OHbJ0WYrYeOrid1MMfHAtqjxBz+N74tfAt9KrQ= +github.com/google/go-cmp v0.5.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.4 h1:L8R9j+yAqZuZjsqh/z+F1NCffTKKLShY6zXTItVIZ8M= +github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-github/v32 v32.1.0 h1:GWkQOdXqviCPx7Q7Fj+KyPoGm4SwHRh8rheoPhd27II= +github.com/google/go-github/v32 v32.1.0/go.mod h1:rIEpZD9CTDQwDK9GDrtMTycQNA4JU3qBsCizh3q2WCI= github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck= github.com/google/gofuzz v0.0.0-20161122191042-44d81051d367/go.mod h1:HP5RmnzzSNb993RKQDq4+1A4ia9nllfqcQFTQJedwGI= github.com/google/gofuzz v1.0.0 h1:A8PeW59pxE9IoFRqBp37U+mSNaQoZ46F1f0f863XSXw= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= +github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= +github.com/google/martian/v3 v3.1.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= github.com/google/pprof v0.0.0-20190515194954-54271f7e092f/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= github.com/google/pprof v0.0.0-20191218002539-d4f498aebedc/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= +github.com/google/pprof v0.0.0-20200212024743-f11f1df84d12/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= +github.com/google/pprof v0.0.0-20200229191704-1ebb73c60ed3/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= +github.com/google/pprof v0.0.0-20200430221834-fc25d7d30c6d/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= +github.com/google/pprof 
v0.0.0-20200708004538-1a94d8640e99/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= +github.com/google/pprof v0.0.0-20201023163331-3e6fc7fc9c4c/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= +github.com/google/pprof v0.0.0-20201218002935-b9804c9f04c2/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= +github.com/google/pprof v0.0.0-20210115211752-39141e76b647 h1:LB1NXQJhB+dF+5kZVLIn85HJqGvK3zKHIku3bdy3IRc= +github.com/google/pprof v0.0.0-20210115211752-39141e76b647/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/google/subcommands v1.0.2-0.20190508160503-636abe8753b8 h1:8nlgEAjIalk6uj/CGKCdOO8CQqTeysvcW4RFZ6HbkGM= github.com/google/subcommands v1.0.2-0.20190508160503-636abe8753b8/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk= github.com/google/uuid v1.0.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/google/uuid v1.1.1 h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY= github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.1.2 h1:EVhdT+1Kseyi1/pUmXKaFxYsDNy9RQYkMWRH68J/W7Y= +github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk= -github.com/googleapis/gnostic v0.0.0-20170729233727-0c5108395e2d h1:7XGaL1e6bYS1yIonGp9761ExpPPV1ui0SAC59Yube9k= github.com/googleapis/gnostic v0.0.0-20170729233727-0c5108395e2d/go.mod h1:sJBsCZ4ayReDTBIg8b9dl28c5xFWyhBTVRp3pOg5EKY= github.com/gophercloud/gophercloud v0.1.0/go.mod h1:vxM41WHh5uqHVBMZHzuwNOHh8XEoIEcSTewFxm1c5g8= github.com/gregjones/httpcache v0.0.0-20180305231024-9cad4c3443a7/go.mod h1:FecbI9+v66THATjSRHfNgh1IVFe/9kFxbXtjV0ctIMA= github.com/hashicorp/errwrap v1.0.0 
h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= -github.com/hashicorp/go-multierror v1.0.0 h1:iVjPR7a6H0tWELX5NxNe7bYopibicUzc7uPribsnS6o= -github.com/hashicorp/go-multierror v1.0.0/go.mod h1:dHtQlpGsu+cZNNAkkCN/P3hoUDHhCYQXV3UM06sGGrk= +github.com/hashicorp/go-multierror v1.1.0 h1:B9UzwGQJehnUY1yNrnwREHc3fGbC2xefo8g4TbElacI= +github.com/hashicorp/go-multierror v1.1.0/go.mod h1:spPvp8C1qA32ftKqdAHm4hHTbPw+vmowP0z+KUhOZdA= github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= +github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639 h1:mV02weKRL81bEnm8A0HT1/CAelMQDBuQIfLw8n+d6xI= +github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= github.com/imdario/mergo v0.3.5/go.mod h1:2EnlNZ0deacrJVfApfmtdGgDfMuh/nq6Ok1EcJh5FfA= github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= github.com/json-iterator/go v0.0.0-20180612202835-f2b4162afba3/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= -github.com/json-iterator/go v1.1.7 h1:KfgG9LzI+pYjr4xvmz/5H4FXjokeP+rlHLhv3iH62Fo= github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= github.com/jstemmer/go-junit-report v0.9.1 h1:6QPYqodiu3GuPL+7mfx+NwDdp2eTkp9IfEUpgAwUN0o= @@ -192,9 +235,7 @@ github.com/jstemmer/go-junit-report v0.9.1/go.mod h1:Brl9GWCQeLvo8nXZwPNNblvFj/X github.com/kisielk/errcheck v1.2.0/go.mod 
h1:/BMXB+zMLi60iA8Vv6Ksmxu/1UDYcXs4uQLJ+jE2L00= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= -github.com/konsorten/go-windows-terminal-sequences v1.0.2 h1:DB17ag19krx9CFsz4o3enTrPXyIXCl+2iCXH/aMAp9s= github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= -github.com/konsorten/go-windows-terminal-sequences v1.0.3 h1:CE8S1cTafDpPvMhIxNJKvHsGVBgn1xWYf1NbHQhywc8= github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= @@ -205,11 +246,9 @@ github.com/mailru/easyjson v0.0.0-20160728113105-d5b7844b561a/go.mod h1:C1wdFJiN github.com/mattbaird/jsonpatch v0.0.0-20171005235357-81af80346b1a h1:+J2gw7Bw77w/fbK7wnNJJDKmw1IbWft2Ul5BzrG1Qm8= github.com/mattbaird/jsonpatch v0.0.0-20171005235357-81af80346b1a/go.mod h1:M1qoD/MqPgTZIk0EWKB38wE28ACRfVcn+cU08jyArI0= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= -github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v0.0.0-20180320133207-05fbef0ca5da/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= -github.com/modern-go/reflect2 v1.0.1 h1:9f412s+6RmYXLWZSEzVVgPGK7C2PphHj5RJrvfx9AWI= github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/mohae/deepcopy 
v0.0.0-20170308212314-bb9b5e7adda9 h1:Sha2bQdoWE5YQPTlJOL31rmce94/tYi113SlFo1xQ2c= github.com/mohae/deepcopy v0.0.0-20170308212314-bb9b5e7adda9/go.mod h1:TaXosZuwdSHYgviHp1DAtfrULt5eUgsSMsZf+YrPgl8= @@ -229,8 +268,6 @@ github.com/opencontainers/runc v0.0.0-20190115041553-12f6a991201f/go.mod h1:qT5X github.com/opencontainers/runc v0.1.1 h1:GlxAyO6x8rfZYN9Tt0Kti5a/cP41iuiO2yYT0IJGY8Y= github.com/opencontainers/runc v0.1.1/go.mod h1:qT5XzbpPznkRYVz/mWwUaVBUv2rmF59PVA73FjuZG0U= github.com/opencontainers/runtime-spec v1.0.1/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/C5hL+LL3ko9hs6T5lQ0= -github.com/opencontainers/runtime-spec v1.0.2-0.20181111125026-1722abf79c2f h1:Pyp2f/uuhJIcUgnIeZaAbwOcyNz8TBlEe6mPpC8kXq8= -github.com/opencontainers/runtime-spec v1.0.2-0.20181111125026-1722abf79c2f/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/C5hL+LL3ko9hs6T5lQ0= github.com/opencontainers/runtime-spec v1.0.2 h1:UfAcuLBJB9Coz72x1hgl8O5RVzTdNiaglX6v2DM6FI0= github.com/opencontainers/runtime-spec v1.0.2/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/C5hL+LL3ko9hs6T5lQ0= github.com/pborman/uuid v1.2.0 h1:J7Q5mO4ysT1dv8hyrUGHb9+ooztCXu1D8MY8DZYsu3g= @@ -240,7 +277,6 @@ github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINE github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v0.0.0-20151028094244-d8ed2627bdf0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/procfs v0.0.0-20180125133057-cb4147076ac7/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= @@ -249,35 +285,46 @@ github.com/rogpeppe/go-internal v1.3.0/go.mod 
h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFR github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q= -github.com/sirupsen/logrus v1.4.2 h1:SPIRibHv4MatM3XXNO2BJeFLZwZ2LvZgfQ5+UNI2im4= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= -github.com/sirupsen/logrus v1.6.0 h1:UBcNElsrwanuuMsnGSlYmtmgbb23qDR5dG+6X6Oo89I= github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88= +github.com/sirupsen/logrus v1.7.0 h1:ShrD1U9pZB12TX0cVy0DtePoCH97K8EtX+mg7ZARUtM= +github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= github.com/spf13/afero v1.2.2/go.mod h1:9ZxEEn6pIJ8Rxe320qSDBk6AsU0r9pR7Q4OcevTdifk= github.com/spf13/cobra v0.0.2-0.20171109065643-2da4a54c5cee/go.mod h1:1l0Ry5zgKvJasoi3XT1TypsSe7PqH0Sj9dhYf7v3XqQ= github.com/spf13/pflag v0.0.0-20170130214245-9ff6c6923cff/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= github.com/spf13/pflag v1.0.1-0.20171106142849-4c012f6dcd95/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= -github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v0.0.0-20151208002404-e3a8ff8ce365/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= -github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= 
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/syndtr/gocapability v0.0.0-20180916011248-d98352740cb2 h1:b6uOv7YOFK0TYG7HtkIgExQo+2RdLuwRft63jn2HWj8= github.com/syndtr/gocapability v0.0.0-20180916011248-d98352740cb2/go.mod h1:hkRG7XYTFWNJGYcbNJQlaLq0fg1yr4J4t/NcTQtrfww= github.com/urfave/cli v1.22.2/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0= github.com/vishvananda/netlink v1.0.1-0.20190930145447-2ec5bdc52b86 h1:7SWt9pGCMaw+N1ZhRsaLKaYNviFhxambdoaoYlDqz1w= github.com/vishvananda/netlink v1.0.1-0.20190930145447-2ec5bdc52b86/go.mod h1:+SR5DhBJrl6ZM7CoCKvpw5BKroDKQ+PJqOg65H/2ktk= -github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae h1:4hwBBUfQCFe3Cym0ZtKyq7L16eZUtYKs+BaHDN6mAns= -github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0= +github.com/vishvananda/netns v0.0.0-20210104183010-2eb08e3e575f h1:p4VB7kIXpOQvVn1ZaTIVp+3vuYAXFe3OJEvjbUYJLaA= +github.com/vishvananda/netns v0.0.0-20210104183010-2eb08e3e575f/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0= +github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f h1:J9EGpcZtP0E/raorCMxlFGSTBrsSlaDGf3jU/qvAE2c= +github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU= +github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 h1:EzJWgHovont7NscjpAxXsDA8S8BMYve8Y5+7cuRE7R0= +github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415/go.mod h1:GwrjFmJcFw6At/Gs6z4yjiIwzuJ1/+UwLxMQDVQXShQ= +github.com/xeipuuv/gojsonschema v1.2.0 h1:LhYJRs+L4fBtjZUfuSZIKGeVu0QRy8e5Xi7D17UxZ74= +github.com/xeipuuv/gojsonschema v1.2.0/go.mod h1:anYRn/JVcOK2ZgGU+IjEV4nwlhoK5sQluxsYJ78Id3Y= +github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +github.com/yuin/goldmark 
v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU= go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8= -go.opencensus.io v0.22.2 h1:75k/FF0Q2YM8QYo07VPddOLBslDt1MZOdEslOHvmzAs= go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= +go.opencensus.io v0.22.3/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= +go.opencensus.io v0.22.4/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= +go.opencensus.io v0.22.5 h1:dntmOdLpSpHlVqbW5Eay97DelsZHe+55D+xC6i0dDS0= +go.opencensus.io v0.22.5/go.mod h1:5pWMHQbX5EPX2/62yrJeAkowc+lfs/XD7Uxpq3pI6kk= go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw= go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/multierr v1.6.0 h1:y6IPFStTAIT5Ytl7/XYmHvzXQ7S3g/IeZW9hyZ5thw4= @@ -288,14 +335,17 @@ golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8U golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200220183623-bac4c82f6975/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 h1:psW17arqaxU48Z5kZ0CQnkZWQJsqcURM6tKiBApRjXI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp 
v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= golang.org/x/exp v0.0.0-20190829153037-c13cbed26979/go.mod h1:86+5VVa7VpoJ4kLfm080zCjGlMRFzhUhsZKEZO7MGek= -golang.org/x/exp v0.0.0-20191227195350-da58074b4299 h1:zQpM52jfKHG6II1ISZY1ZcpygvuSFZpLwfluuF89XOg= +golang.org/x/exp v0.0.0-20191030013958-a1ab85dbe136/go.mod h1:JXzH8nQsPlswgeRAPE3MuO9GYsAcnJvJ4vnMwN/5qkY= +golang.org/x/exp v0.0.0-20191129062945-2f5052295587/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4= golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4= +golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4= +golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM= +golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU= golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= @@ -304,15 +354,22 @@ golang.org/x/lint v0.0.0-20190301231843-5614ed5bae6f/go.mod h1:UVdnD1Gm6xHRNCYTk golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/lint v0.0.0-20190409202823-959b441ac422/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/lint v0.0.0-20190909230951-414d861bb4ac/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= -golang.org/x/lint v0.0.0-20191125180803-fdd1cda4f05f h1:J5lckAjkw6qYlOZNj90mLYNTEKDvWeuc1yieZ8qUzUE= +golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/lint v0.0.0-20191125180803-fdd1cda4f05f/go.mod 
h1:5qLYkcX4OjUUV8bRuDixDT3tpyyb+LUpUlRWLxfhWrs= +golang.org/x/lint v0.0.0-20200130185559-910be7a94367/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= +golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= +golang.org/x/lint v0.0.0-20201208152925-83fdc39ff7b5 h1:2M3HP5CCK1Si9FQhwnzYhXdG6DXeebvUHFpre8QvbyI= +golang.org/x/lint v0.0.0-20201208152925-83fdc39ff7b5/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= golang.org/x/mobile v0.0.0-20190312151609-d3739f865fa6/go.mod h1:z+o9i4GpDbdi3rU15maQ/Ox0txvL9dWGYEHz965HBQE= golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028/go.mod h1:E/iHnbuqvinMTCcRqshq8CkpyQDoeVncDDYHnLhea+o= golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc= golang.org/x/mod v0.1.0/go.mod h1:0QHyrYULN0/3qlju5TqG8bIK38QM8yzMo5ekMj3DlcY= golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= -golang.org/x/mod v0.3.0 h1:RM4zey1++hCTbCVQfnWeKs9/IEsaBLA8vTkd0WVtmH4= +golang.org/x/mod v0.1.1-0.20191107180719-034126e5016b/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= +golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.4.0 h1:8pl+sMODzuvGJkmj2W4kZihvVb5mKm8pB/X44PIQHv8= +golang.org/x/mod v0.4.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/net v0.0.0-20170114055629-f2499483f923/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -325,23 +382,46 @@ golang.org/x/net v0.0.0-20190501004415-9ce7a6920f09/go.mod h1:t9HGtf8HONx5eT2rtn golang.org/x/net v0.0.0-20190503192946-f4e77d36d62c/go.mod 
h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20190628185345-da137c7871d7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20190724013045-ca1201d0de80/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20191004110552-13f9640d40b9/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20200822124328-c89045814202 h1:VvcQYSHwXgi7W+TpUR6A9g6Up98WAHf3f/ulnJ62IyA= +golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200222125558-5a598a2470a0/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200301022130-244492dfa37a/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= +golang.org/x/net v0.0.0-20200501053045-e0ff5e5a1de5/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= +golang.org/x/net v0.0.0-20200506145744-7e3656a0809f/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= +golang.org/x/net v0.0.0-20200513185701-a91f0712d120/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= +golang.org/x/net v0.0.0-20200520182314-0ba52f642ac2/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= +golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= +golang.org/x/net 
v0.0.0-20200707034311-ab3426394381/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= +golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/net v0.0.0-20201031054903-ff519b6c9102/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/net v0.0.0-20201224014010-6772e930b67b h1:iFwSg7t5GZmB/Q5TjiEAsdoLDrdJRC1RiF2WhuV29Qw= +golang.org/x/net v0.0.0-20201224014010-6772e930b67b/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= -golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d h1:TzXSXBo42m9gQenoE3b9BGiEpg5IG2JkU5FkPIawgtw= +golang.org/x/oauth2 v0.0.0-20191202225959-858c2ad4c8b6/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= +golang.org/x/oauth2 v0.0.0-20200902213428-5d25da1a8d43/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= +golang.org/x/oauth2 v0.0.0-20201109201403-9fd604954f58/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= +golang.org/x/oauth2 v0.0.0-20201208152858-08078c50e5b5 h1:Lm4OryKCca1vehdsWogr9N4t7NfZxLbJoc/H0w4K4S4= +golang.org/x/oauth2 v0.0.0-20201208152858-08078c50e5b5/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod 
h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208 h1:qwRHBd0NqMbJxfbotnDhm2ByMI1Shq4Y6oRJo21SGJA= +golang.org/x/sync v0.0.0-20200317015054-43a5402ce75a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9 h1:SQFwaSi55rU7vdNs9Yr0Z324VNlrF+0wMqRXT4St8ck= +golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20170830134202-bb24a47a89ea/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -355,22 +435,48 @@ golang.org/x/sys v0.0.0-20190502145724-3ef323f4f1fd/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20190507160741-ecd444e8653b/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190606165138-5da285871e9c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190624142023-c5567b49c5d0/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190726091711-fc99dfbffb4e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys 
v0.0.0-20191001151750-bb3f8db39f24/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191022100944-742c48ecaeb7/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191210023423-ac6580df4449/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191228213918-04cbcbbfeed8/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200113162924-86b910548bc1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200120151820-655fe14d7479/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200122134326-e047566fdf82/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200124204421-9fbb57f87de9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200212091648-12a6c2dcc1e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200217220822-9197077df867/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd h1:xhmwyvizuTgC2qz7ZlMluP20uW+C3Rm0FD/WLDX8884= +golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200331124033-c3d80250170d/go.mod 
h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200501052902-10377860bb8e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200511232937-7e40ca221e25/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200515095857-1151b9dac4a9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200523222454-059865788121/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200803210538-64077c9b5642/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200905004654-be1d3432aa8f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200916030750-2334cc1a136f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201201145000-ef89a241ccb3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4 h1:myAQVi0cGEoqQVR5POX+8RR2mrocKqNN1hmeMqhX27k= +golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.0.0-20160726164857-2910a502d2bf/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= golang.org/x/text v0.3.2/go.mod 
h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.4 h1:0YWbFKbhXG/wIiuHDSKpS0Iy7FSA+u45VtBMfQcFTTc= +golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs= @@ -392,11 +498,37 @@ golang.org/x/tools v0.0.0-20190628153133-6cdbf07be9d0/go.mod h1:/rFqwRUd4F7ZHNgw golang.org/x/tools v0.0.0-20190816200558-6889da9d5479/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20190911174233-4f2ddba30aff/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191012152004-8de300cfc20a/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191113191852-77e3bb0ad9e7/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191115202509-3a792d9c32b2/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191125144606-a911d9008d1f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191130070609-6e064ea0cf2d/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191216173652-a0e659d51361/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= +golang.org/x/tools v0.0.0-20191227053925-7b8e75db28f4/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= golang.org/x/tools v0.0.0-20200117161641-43d50277825c/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/tools v0.0.0-20201021000207-d49c4edd7d96 
h1:K+nJoPcImWk+ZGPHOKkDocKcQPACCz8usiCiVQYfXsk= -golang.org/x/tools v0.0.0-20201021000207-d49c4edd7d96/go.mod h1:z6u4i615ZeAfBE4XtMziQW1fSVJXACjjbWkB/mvPzlU= +golang.org/x/tools v0.0.0-20200122220014-bf1340f18c4a/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= +golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= +golang.org/x/tools v0.0.0-20200204074204-1cc6d1ef6c74/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= +golang.org/x/tools v0.0.0-20200207183749-b753a1ba74fa/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= +golang.org/x/tools v0.0.0-20200212150539-ea181f53ac56/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= +golang.org/x/tools v0.0.0-20200224181240-023911ca70b2/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= +golang.org/x/tools v0.0.0-20200227222343-706bc42d1f0d/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= +golang.org/x/tools v0.0.0-20200304193943-95d2e580d8eb/go.mod h1:o4KQGtdN14AW+yjsvvwRTJJuXz8XRtIHtEnmAXLyFUw= +golang.org/x/tools v0.0.0-20200312045724-11d5b4c81c7d/go.mod h1:o4KQGtdN14AW+yjsvvwRTJJuXz8XRtIHtEnmAXLyFUw= +golang.org/x/tools v0.0.0-20200331025713-a30bf2db82d4/go.mod h1:Sl4aGygMT6LrqrWclx+PTx3U+LnKx/seiNR+3G19Ar8= +golang.org/x/tools v0.0.0-20200501065659-ab2804fb9c9d/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= +golang.org/x/tools v0.0.0-20200512131952-2bc93b1c0c88/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= +golang.org/x/tools v0.0.0-20200515010526-7d3b6ebf133d/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= +golang.org/x/tools v0.0.0-20200618134242-20370b0cb4b2/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= +golang.org/x/tools v0.0.0-20200729194436-6467de6f59a7/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= +golang.org/x/tools v0.0.0-20200804011535-6c149bb5ef0d/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= +golang.org/x/tools v0.0.0-20200825202427-b303f430e36d/go.mod 
h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= +golang.org/x/tools v0.0.0-20200904185747-39188db58858/go.mod h1:Cj7w3i3Rnn0Xh82ur9kSqwfTHTeVxaDqrfMjpcNT6bE= +golang.org/x/tools v0.0.0-20201110124207-079ba7bd75cd/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= +golang.org/x/tools v0.0.0-20201201161351-ac6f37ff4c2a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= +golang.org/x/tools v0.0.0-20210108195828-e2f9c7f1fc8e/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= +golang.org/x/tools v0.1.0 h1:po9/4sTYwZU9lPhi1tOrb4hCv3qrhiQ77LZfGa2OjwY= +golang.org/x/tools v0.1.0/go.mod h1:xkSsbof2nBLbhDlRMhhhyNLN/zl3eTqcnHD5viDpcZ0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -406,13 +538,29 @@ google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEt google.golang.org/api v0.7.0/go.mod h1:WtwebWUNSVBH/HAw79HIFXZNqEvBhG+Ra+ax0hx3E3M= google.golang.org/api v0.8.0/go.mod h1:o4eAsZoiT+ibD93RtjEohWalFOjRDx6CVaqeizhEnKg= google.golang.org/api v0.9.0/go.mod h1:o4eAsZoiT+ibD93RtjEohWalFOjRDx6CVaqeizhEnKg= +google.golang.org/api v0.13.0/go.mod h1:iLdEw5Ide6rF15KTC1Kkl0iskquN2gFfn9o9XIsbkAI= +google.golang.org/api v0.14.0/go.mod h1:iLdEw5Ide6rF15KTC1Kkl0iskquN2gFfn9o9XIsbkAI= google.golang.org/api v0.15.0/go.mod h1:iLdEw5Ide6rF15KTC1Kkl0iskquN2gFfn9o9XIsbkAI= +google.golang.org/api v0.17.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/sfE= +google.golang.org/api v0.18.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/sfE= +google.golang.org/api v0.19.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/sfE= +google.golang.org/api v0.20.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/sfE= +google.golang.org/api v0.22.0/go.mod 
h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/sfE= +google.golang.org/api v0.24.0/go.mod h1:lIXQywCXRcnZPGlsd8NbLnOjtAoL6em04bJ9+z0MncE= +google.golang.org/api v0.28.0/go.mod h1:lIXQywCXRcnZPGlsd8NbLnOjtAoL6em04bJ9+z0MncE= +google.golang.org/api v0.29.0/go.mod h1:Lcubydp8VUV7KeIHD9z2Bys/sm/vGKnG1UHuDBSrHWM= +google.golang.org/api v0.30.0/go.mod h1:QGmEvQ87FHZNiUVJkT14jQNYJ4ZJjdRF23ZXz5138Fc= +google.golang.org/api v0.35.0/go.mod h1:/XrVsuzM0rZmrsbjJutiuftIzeuTQcEeaYcSk/mQ1dg= +google.golang.org/api v0.36.0 h1:l2Nfbl2GPXdWorv+dT2XfinX2jOOw4zv1VhLstx+6rE= +google.golang.org/api v0.36.0/go.mod h1:+z5ficQTmoYpPn8LCUNVpK5I7hwkpjbcgqA7I34qYtE= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/appengine v1.5.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/appengine v1.6.1/go.mod h1:i06prIuMbXzDqacNJfV5OdTW448YApPu5ww/cMBSeb0= -google.golang.org/appengine v1.6.5 h1:tycE03LOZYQNhDpS27tcQdAzLCVMaj7QT2SXxebnpCM= google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= +google.golang.org/appengine v1.6.6/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= +google.golang.org/appengine v1.6.7 h1:FZR1q0exgwxzPzp/aF+VccGrSfxfPpkBqjIIEq3ru6c= +google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= google.golang.org/genproto v0.0.0-20190307195333-5fe7a883aa19/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= google.golang.org/genproto v0.0.0-20190418145605-e7d98fc518a7/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= @@ -421,9 +569,33 @@ google.golang.org/genproto v0.0.0-20190502173448-54afdca5d873/go.mod h1:VzzqZJRn google.golang.org/genproto v0.0.0-20190801165951-fa694d86fc64/go.mod 
h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= google.golang.org/genproto v0.0.0-20190911173649-1774047e7e51/go.mod h1:IbNlFCBrqXvoKpeg0TB2l7cyZUmoaFKYIwrEpbDKLA8= +google.golang.org/genproto v0.0.0-20191108220845-16a3f7862a1a/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= +google.golang.org/genproto v0.0.0-20191115194625-c23dd37a84c9/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= +google.golang.org/genproto v0.0.0-20191216164720-4f79533eabd1/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= +google.golang.org/genproto v0.0.0-20191230161307-f3c370f40bfb/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= google.golang.org/genproto v0.0.0-20200115191322-ca5a22157cba/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= -google.golang.org/genproto v0.0.0-20200117163144-32f20d992d24 h1:wDju+RU97qa0FZT0QnZDg9Uc2dH0Ql513kFvHocz+WM= google.golang.org/genproto v0.0.0-20200117163144-32f20d992d24/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= +google.golang.org/genproto v0.0.0-20200122232147-0452cf42e150/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= +google.golang.org/genproto v0.0.0-20200204135345-fa8e72b47b90/go.mod h1:GmwEX6Z4W5gMy59cAlVYjN9JhxgbQH6Gn+gFDQe2lzA= +google.golang.org/genproto v0.0.0-20200212174721-66ed5ce911ce/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= +google.golang.org/genproto v0.0.0-20200224152610-e50cd9704f63/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= +google.golang.org/genproto v0.0.0-20200228133532-8c2c7df3a383/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= +google.golang.org/genproto v0.0.0-20200305110556-506484158171/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= +google.golang.org/genproto v0.0.0-20200312145019-da6875a35672/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= +google.golang.org/genproto v0.0.0-20200331122359-1ee6d9798940/go.mod 
h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= +google.golang.org/genproto v0.0.0-20200430143042-b979b6f78d84/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= +google.golang.org/genproto v0.0.0-20200511104702-f5ebc3bea380/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= +google.golang.org/genproto v0.0.0-20200515170657-fc4c6c6a6587/go.mod h1:YsZOwe1myG/8QRHRsmBRE1LrgQY60beZKjly0O1fX9U= +google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= +google.golang.org/genproto v0.0.0-20200618031413-b414f8b61790/go.mod h1:jDfRM7FcilCzHH/e9qn6dsT145K34l5v+OpcnNgKAAA= +google.golang.org/genproto v0.0.0-20200729003335-053ba62fc06f/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= +google.golang.org/genproto v0.0.0-20200804131852-c06518451d9c/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= +google.golang.org/genproto v0.0.0-20200825200019-8632dd797987/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= +google.golang.org/genproto v0.0.0-20200904004341-0bd0a958aa1d/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= +google.golang.org/genproto v0.0.0-20201109203340-2640f1f9cdfb/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= +google.golang.org/genproto v0.0.0-20201201144952-b05cb90ed32e/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= +google.golang.org/genproto v0.0.0-20210108203827-ffc7fda8c3d7 h1:n7yjMkxUgbEahYENvAGVlxMUW8TF/KEavLez31znfDw= +google.golang.org/genproto v0.0.0-20210108203827-ffc7fda8c3d7/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM= @@ -431,14 +603,27 @@ google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyac google.golang.org/grpc v1.23.1/go.mod 
h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= google.golang.org/grpc v1.26.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= -google.golang.org/grpc v1.29.0 h1:2pJjwYOdkZ9HlN4sWRYBg9ttH5bCOlsueaM+b/oYjwo= -google.golang.org/grpc v1.29.0/go.mod h1:itym6AZVZYACWQqET3MqgPpjcuV5QH3BxFS3IjizoKk= +google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= +google.golang.org/grpc v1.27.1/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= +google.golang.org/grpc v1.28.0/go.mod h1:rpkK4SK4GF4Ach/+MFLZUBavHOvF2JJB5uozKKal+60= +google.golang.org/grpc v1.29.1/go.mod h1:itym6AZVZYACWQqET3MqgPpjcuV5QH3BxFS3IjizoKk= +google.golang.org/grpc v1.30.0/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak= +google.golang.org/grpc v1.31.0/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak= +google.golang.org/grpc v1.31.1/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak= +google.golang.org/grpc v1.33.2/go.mod h1:JMHMWHQWaTccqQQlmk3MJZS+GWXOdAesneDmEnv2fbc= +google.golang.org/grpc v1.34.0/go.mod h1:WotjhfgOW/POjDeRt8vscBtXq+2VjORFy659qA51WJ8= +google.golang.org/grpc v1.36.0-dev.0.20210122012134-2c42474aca0c h1:cb+I9RwgcErlwAuOVnGhJ2d3YrcdwGXw+RPArsTWot4= +google.golang.org/grpc v1.36.0-dev.0.20210122012134-2c42474aca0c/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAGRRjU= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= google.golang.org/protobuf 
v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.24.0/go.mod h1:r/3tXBNzIEhYS9I1OUVjXDlt8tc493IdKGjtUeSXeh4= +google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= google.golang.org/protobuf v1.25.1-0.20201020201750-d3470999428b h1:jEdfCm+8YTWSYgU4L7Nq0jjU+q9RxIhi0cXLTY+Ih3A= google.golang.org/protobuf v1.25.1-0.20201020201750-d3470999428b/go.mod h1:hFxJC2f0epmp1elRCiEGJTKAWbwxZ2nvqZdHl3FQXCY= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= @@ -457,8 +642,11 @@ honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWh honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= -honnef.co/go/tools v0.0.1-2019.2.3 h1:3JgtbtFHMiCmsznwGVTUWbgGov+pVqnlf1dEJTNAXeM= honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= +honnef.co/go/tools v0.0.1-2020.1.3/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k= +honnef.co/go/tools v0.0.1-2020.1.4/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k= +honnef.co/go/tools v0.1.1 h1:EVDuO03OCZwpV2t/tLLxPmPiomagMoBOgfPt0FM+4IY= +honnef.co/go/tools v0.1.1/go.mod h1:NgwopIslSNH47DimFoV78dnkksY2EFtX0ajyb3K/las= k8s.io/api v0.16.13 h1:/RE6SNxrws72vzEJsCil3WSR2T9gUlYYoRxnJyZiexs= k8s.io/api v0.16.13/go.mod h1:QWu8UWSTiuQZMMeYjwLs6ILu5O74qKSJ0c+4vrchDxs= k8s.io/apimachinery v0.16.13/go.mod h1:4HMHS3mDHtVttspuuhrJ1GGr/0S9B6iWYWZ57KnnZqQ= @@ -472,9 +660,9 @@ k8s.io/klog 
v0.3.0/go.mod h1:Gq+BEi5rUBO/HRz0bTSXDUcqjScdoY3a9IHpCEIOOfk= k8s.io/klog v1.0.0 h1:Pt+yjF5aB1xDSVbau4VsWe+dQNzA0qv1LlXdC2dF6Q8= k8s.io/klog v1.0.0/go.mod h1:4Bi6QPql/J/LkTDqv7R/cd3hPo4k2DG6Ptcz060Ez5I= k8s.io/kube-openapi v0.0.0-20200410163147-594e756bea31/go.mod h1:1TqjTSzOxsLGIKfj0lK8EeCP7K1iUG65v09OM0/WG5E= -k8s.io/utils v0.0.0-20190801114015-581e00157fb1 h1:+ySTxfHnfzZb9ys375PXNlLhkJPLKgHajBU0N62BDvE= k8s.io/utils v0.0.0-20190801114015-581e00157fb1/go.mod h1:sZAwmy6armz5eXlNoLmJcl4F1QuKu7sr+mFQ0byX7Ew= rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8= +rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0= +rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA= sigs.k8s.io/structured-merge-diff v0.0.0-20190525122527-15d366b2352e/go.mod h1:wWxsB5ozmmv/SG7nM11ayaAW51xMvak/t1r0CSlcokI= -sigs.k8s.io/yaml v1.1.0 h1:4A07+ZFc2wgJwo8YNlQpr1rVlgUDlxXHhPJciaPY5gs= sigs.k8s.io/yaml v1.1.0/go.mod h1:UJmg0vDUVViEyp3mgSv9WPwZCDxu4rQW1olrI1uml+o= diff --git a/images/basic/hostoverlaytest/Dockerfile b/images/basic/hostoverlaytest/Dockerfile deleted file mode 100644 index 6cef1a542..000000000 --- a/images/basic/hostoverlaytest/Dockerfile +++ /dev/null @@ -1,8 +0,0 @@ -FROM ubuntu:bionic - -WORKDIR /root -COPY . . - -RUN apt-get update && apt-get install -y gcc -RUN gcc -O2 -o test_copy_up test_copy_up.c -RUN gcc -O2 -o test_rewinddir test_rewinddir.c diff --git a/images/basic/integrationtest/Dockerfile.x86_64 b/images/basic/integrationtest/Dockerfile.x86_64 new file mode 100644 index 000000000..e80e17527 --- /dev/null +++ b/images/basic/integrationtest/Dockerfile.x86_64 @@ -0,0 +1,7 @@ +FROM ubuntu:bionic + +WORKDIR /root +COPY . . 
+RUN chmod +x *.sh + +RUN apt-get update && apt-get install -y gcc iputils-ping iproute2 diff --git a/images/basic/hostoverlaytest/copy_up_testfile.txt b/images/basic/integrationtest/copy_up_testfile.txt index e4188c841..e4188c841 100644 --- a/images/basic/hostoverlaytest/copy_up_testfile.txt +++ b/images/basic/integrationtest/copy_up_testfile.txt diff --git a/images/basic/linktest/link_test.c b/images/basic/integrationtest/link_test.c index 45ab00abe..45ab00abe 100644 --- a/images/basic/linktest/link_test.c +++ b/images/basic/integrationtest/link_test.c diff --git a/images/basic/ping4test/ping4.sh b/images/basic/integrationtest/ping4.sh index 2a343712a..2a343712a 100644 --- a/images/basic/ping4test/ping4.sh +++ b/images/basic/integrationtest/ping4.sh diff --git a/images/basic/ping6test/ping6.sh b/images/basic/integrationtest/ping6.sh index 4268951d0..4268951d0 100644 --- a/images/basic/ping6test/ping6.sh +++ b/images/basic/integrationtest/ping6.sh diff --git a/images/basic/hostoverlaytest/test_copy_up.c b/images/basic/integrationtest/test_copy_up.c index 010b261dc..010b261dc 100644 --- a/images/basic/hostoverlaytest/test_copy_up.c +++ b/images/basic/integrationtest/test_copy_up.c diff --git a/images/basic/hostoverlaytest/test_rewinddir.c b/images/basic/integrationtest/test_rewinddir.c index f1a4085e1..f1a4085e1 100644 --- a/images/basic/hostoverlaytest/test_rewinddir.c +++ b/images/basic/integrationtest/test_rewinddir.c diff --git a/images/basic/linktest/Dockerfile b/images/basic/linktest/Dockerfile deleted file mode 100644 index baebc9b76..000000000 --- a/images/basic/linktest/Dockerfile +++ /dev/null @@ -1,7 +0,0 @@ -FROM ubuntu:bionic - -WORKDIR /root -COPY . . 
- -RUN apt-get update && apt-get install -y gcc -RUN gcc -O2 -o link_test link_test.c diff --git a/images/basic/ping4test/Dockerfile b/images/basic/ping4test/Dockerfile deleted file mode 100644 index 1536be376..000000000 --- a/images/basic/ping4test/Dockerfile +++ /dev/null @@ -1,7 +0,0 @@ -FROM ubuntu:bionic - -WORKDIR /root -COPY ping4.sh . -RUN chmod +x ping4.sh - -RUN apt-get update && apt-get install -y iputils-ping diff --git a/images/basic/ping6test/Dockerfile b/images/basic/ping6test/Dockerfile deleted file mode 100644 index cb740bd60..000000000 --- a/images/basic/ping6test/Dockerfile +++ /dev/null @@ -1,7 +0,0 @@ -FROM ubuntu:bionic - -WORKDIR /root -COPY ping6.sh . -RUN chmod +x ping6.sh - -RUN apt-get update && apt-get install -y iputils-ping iproute2 diff --git a/images/default/Dockerfile b/images/default/Dockerfile index 19b340237..5f652f2c3 100644 --- a/images/default/Dockerfile +++ b/images/default/Dockerfile @@ -24,6 +24,6 @@ RUN curl https://dl.google.com/dl/cloudsdk/channels/rapid/downloads/google-cloud ln -s /google-cloud-sdk/bin/gcloud /usr/bin/gcloud # Download the official bazel binary. The APT repository isn't used because there is not packages for arm64. -RUN sh -c 'curl -o /usr/local/bin/bazel https://releases.bazel.build/3.5.1/release/bazel-3.5.1-linux-$(uname -m | sed s/aarch64/arm64/) && chmod ugo+x /usr/local/bin/bazel' +RUN sh -c 'curl -o /usr/local/bin/bazel https://releases.bazel.build/4.0.0/release/bazel-4.0.0-linux-$(uname -m | sed s/aarch64/arm64/) && chmod ugo+x /usr/local/bin/bazel' WORKDIR /workspace ENTRYPOINT ["/usr/local/bin/bazel"] diff --git a/images/syzkaller/Dockerfile b/images/syzkaller/Dockerfile new file mode 100644 index 000000000..df6680f40 --- /dev/null +++ b/images/syzkaller/Dockerfile @@ -0,0 +1,9 @@ +FROM gcr.io/syzkaller/env + +RUN apt update && apt install -y git vim strace gdb procps + +WORKDIR /syzkaller/gopath/src/github.com/google/syzkaller + +RUN git init . 
&& git remote add origin https://github.com/google/syzkaller && git fetch origin && git checkout origin/master && make + +ENTRYPOINT ./bin/syz-manager --config /tmp/syzkaller/syzkaller.cfg diff --git a/images/syzkaller/README.md b/images/syzkaller/README.md new file mode 100644 index 000000000..1eac474f3 --- /dev/null +++ b/images/syzkaller/README.md @@ -0,0 +1,25 @@ +syzkaller is an unsupervised coverage-guided kernel fuzzer. + +* [Github](https://github.com/google/syzkaller) +* [gVisor dashboard](https://syzkaller.appspot.com/gvisor) + +# How to run syzkaller. + +* Build the syzkaller docker image `make load-syzkaller` +* Build runsc and place it in /tmp/syzkaller. `make RUNTIME_DIR=/tmp/syzkaller + refresh` +* Copy the syzkaller config in /tmp/syzkaller `cp + images/syzkaller/default-gvisor-config.cfg /tmp/syzkaller/syzkaller.cfg` +* Run syzkaller `docker run --privileged -it --rm -v + /tmp/syzkaller:/tmp/syzkaller gvisor.dev/images/syzkaller:latest` + +# How to run a syz repro. + +* Repeat all steps except the last one from the previous section.
+ +* Save a syzkaller repro in /tmp/syzkaller/repro + +* Run syz-repro `docker run --privileged -it --rm -v + /tmp/syzkaller:/tmp/syzkaller --entrypoint="" + gvisor.dev/images/syzkaller:latest ./bin/syz-repro -config + /tmp/syzkaller/syzkaller.cfg /tmp/syzkaller/repro` diff --git a/images/syzkaller/default-gvisor-config.cfg b/images/syzkaller/default-gvisor-config.cfg new file mode 100644 index 000000000..c69641c21 --- /dev/null +++ b/images/syzkaller/default-gvisor-config.cfg @@ -0,0 +1,15 @@ +{ + "name": "gvisor", + "target": "linux/amd64", + "http": ":80", + "workdir": "/tmp/syzkaller/workdir/", + "image": "/tmp/syzkaller/runsc", + "syzkaller": "/syzkaller/gopath/src/github.com/google/syzkaller", + "cover": false, + "procs": 1, + "type": "gvisor", + "vm": { + "count": 1, + "runsc_args": "--debug --network none --platform ptrace --vfs2 --fuse -net-raw -watchdog-action=panic" + } +} @@ -126,10 +126,10 @@ analyzers: - ".*_test.go" # Exclude tests. - "pkg/flipcall/.*_unsafe.go" # Special case. - pkg/gohacks/gohacks_unsafe.go # Special case. + - pkg/ring0/pagetables/allocator_unsafe.go # Special case. - pkg/sentry/fs/fsutil/host_file_mapper_unsafe.go # Special case. - pkg/sentry/platform/kvm/bluepill_unsafe.go # Special case. - pkg/sentry/platform/kvm/machine_unsafe.go # Special case. - - pkg/sentry/platform/ring0/pagetables/allocator_unsafe.go # Special case. - pkg/sentry/platform/safecopy/safecopy_unsafe.go # Special case. - pkg/sentry/vfs/mount_unsafe.go # Special case. - pkg/state/decode_unsafe.go # Special case. @@ -155,4 +155,6 @@ analyzers: SA5011: internal: exclude: - - pkg/sentry/fs/fdpipe/pipe_opener_test.go # False positive. 
+ # https://github.com/dominikh/go-tools/issues/924 + - pkg/sentry/fs/fdpipe/pipe_opener_test.go + - pkg/tcpip/tests/integration/link_resolution_test.go diff --git a/pkg/abi/linux/socket.go b/pkg/abi/linux/socket.go index 8591acbf2..185eee0bb 100644 --- a/pkg/abi/linux/socket.go +++ b/pkg/abi/linux/socket.go @@ -347,26 +347,57 @@ const SizeOfLinger = 8 // // +marshal type TCPInfo struct { - State uint8 - CaState uint8 + // State is the state of the connection. + State uint8 + + // CaState is the congestion control state. + CaState uint8 + + // Retransmits is the number of retransmissions triggered by RTO. Retransmits uint8 - Probes uint8 - Backoff uint8 - Options uint8 - // WindowScale is the combination of snd_wscale (first 4 bits) and rcv_wscale (second 4 bits) + + // Probes is the number of unanswered zero window probes. + Probes uint8 + + // Backoff indicates exponential backoff. + Backoff uint8 + + // Options indicates the options enabled for the connection. + Options uint8 + + // WindowScale is the combination of snd_wscale (first 4 bits) and + // rcv_wscale (second 4 bits) WindowScale uint8 - // DeliveryRateAppLimited is a boolean and only the first bit is meaningful. + + // DeliveryRateAppLimited is a boolean and only the first bit is + // meaningful. DeliveryRateAppLimited uint8 - RTO uint32 - ATO uint32 + // RTO is the retransmission timeout. + RTO uint32 + + // ATO is the acknowledgement timeout interval. + ATO uint32 + + // SndMss is the send maximum segment size. SndMss uint32 + + // RcvMss is the receive maximum segment size. RcvMss uint32 + // Unacked is the number of packets sent but not acknowledged. Unacked uint32 - Sacked uint32 - Lost uint32 + + // Sacked is the number of packets which are selectively acknowledged. + Sacked uint32 + + // Lost is the number of packets marked as lost. + Lost uint32 + + // Retrans is the number of retransmitted packets. Retrans uint32 + + // Fackets is not used and is always zero. Fackets uint32 // Times.
@@ -385,37 +416,78 @@ type TCPInfo struct { Advmss uint32 Reordering uint32 - RcvRTT uint32 + // RcvRTT is the receiver round trip time. + RcvRTT uint32 + + // RcvSpace is the current buffer space available for receiving data. RcvSpace uint32 + // TotalRetrans is the total number of retransmits seen since the start + // of the connection. TotalRetrans uint32 - PacingRate uint64 + // PacingRate is the pacing rate in bytes per second. + PacingRate uint64 + + // MaxPacingRate is the maximum pacing rate. MaxPacingRate uint64 + // BytesAcked is RFC4898 tcpEStatsAppHCThruOctetsAcked. BytesAcked uint64 + // BytesReceived is RFC4898 tcpEStatsAppHCThruOctetsReceived. BytesReceived uint64 + // SegsOut is RFC4898 tcpEStatsPerfSegsOut. SegsOut uint32 + // SegsIn is RFC4898 tcpEStatsPerfSegsIn. SegsIn uint32 + // NotSentBytes is the amount of bytes in the write queue that are not + // yet sent. NotSentBytes uint32 - MinRTT uint32 + + // MinRTT is the minimum round trip time seen in the connection. + MinRTT uint32 + // DataSegsIn is RFC4898 tcpEStatsDataSegsIn. DataSegsIn uint32 + // DataSegsOut is RFC4898 tcpEStatsDataSegsOut. DataSegsOut uint32 + // DeliveryRate is the most recent delivery rate in bytes per second. DeliveryRate uint64 // BusyTime is the time in microseconds busy sending data. BusyTime uint64 + // RwndLimited is the time in microseconds limited by receive window. RwndLimited uint64 + // SndBufLimited is the time in microseconds limited by send buffer. SndBufLimited uint64 + + // Delivered is the total data packets delivered including retransmits. + Delivered uint32 + + // DeliveredCE is the total ECE marked data packets delivered including + // retransmits. + DeliveredCE uint32 + + // BytesSent is RFC4898 tcpEStatsPerfHCDataOctetsOut. + BytesSent uint64 + + // BytesRetrans is RFC4898 tcpEStatsPerfOctetsRetrans. + BytesRetrans uint64 + + // DSACKDups is RFC4898 tcpEStatsStackDSACKDups. 
+ DSACKDups uint32 + + // ReordSeen is the number of reordering events seen since the start of + // the connection. + ReordSeen uint32 } // SizeOfTCPInfo is the binary size of a TCPInfo struct. diff --git a/pkg/abi/linux/tcp.go b/pkg/abi/linux/tcp.go index 2a8d4708b..1a3c0916f 100644 --- a/pkg/abi/linux/tcp.go +++ b/pkg/abi/linux/tcp.go @@ -59,3 +59,12 @@ const ( MAX_TCP_KEEPINTVL = 32767 MAX_TCP_KEEPCNT = 127 ) + +// Congestion control states from include/uapi/linux/tcp.h. +const ( + TCP_CA_Open = 0 + TCP_CA_Disorder = 1 + TCP_CA_CWR = 2 + TCP_CA_Recovery = 3 + TCP_CA_Loss = 4 +) diff --git a/pkg/sentry/platform/ring0/BUILD b/pkg/ring0/BUILD index 2852b7387..d1b14efdb 100644 --- a/pkg/sentry/platform/ring0/BUILD +++ b/pkg/ring0/BUILD @@ -43,16 +43,16 @@ arch_genrule( name = "entry_impl_amd64", srcs = ["entry_amd64.s"], outs = ["entry_impl_amd64.s"], - cmd = "(echo -e '// build +amd64\\n' && QEMU $(location //pkg/sentry/platform/ring0/gen_offsets) && cat $(location entry_amd64.s)) > $@", - tools = ["//pkg/sentry/platform/ring0/gen_offsets"], + cmd = "(echo -e '// build +amd64\\n' && QEMU $(location //pkg/ring0/gen_offsets) && cat $(location entry_amd64.s)) > $@", + tools = ["//pkg/ring0/gen_offsets"], ) arch_genrule( name = "entry_impl_arm64", srcs = ["entry_arm64.s"], outs = ["entry_impl_arm64.s"], - cmd = "(echo -e '// build +arm64\\n' && QEMU $(location //pkg/sentry/platform/ring0/gen_offsets) && cat $(location entry_arm64.s)) > $@", - tools = ["//pkg/sentry/platform/ring0/gen_offsets"], + cmd = "(echo -e '// build +arm64\\n' && QEMU $(location //pkg/ring0/gen_offsets) && cat $(location entry_arm64.s)) > $@", + tools = ["//pkg/ring0/gen_offsets"], ) go_library( @@ -77,9 +77,9 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/cpuid", + "//pkg/ring0/pagetables", "//pkg/safecopy", "//pkg/sentry/arch", - "//pkg/sentry/platform/ring0/pagetables", "//pkg/usermem", ], ) diff --git a/pkg/sentry/platform/ring0/aarch64.go b/pkg/ring0/aarch64.go index 
3bda594f9..3bda594f9 100644 --- a/pkg/sentry/platform/ring0/aarch64.go +++ b/pkg/ring0/aarch64.go diff --git a/pkg/sentry/platform/ring0/defs.go b/pkg/ring0/defs.go index f9765771e..e2561e4c2 100644 --- a/pkg/sentry/platform/ring0/defs.go +++ b/pkg/ring0/defs.go @@ -15,8 +15,8 @@ package ring0 import ( + "gvisor.dev/gvisor/pkg/ring0/pagetables" "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" ) // Kernel is a global kernel object. diff --git a/pkg/sentry/platform/ring0/defs_amd64.go b/pkg/ring0/defs_amd64.go index 7a2275558..ceddf719d 100644 --- a/pkg/sentry/platform/ring0/defs_amd64.go +++ b/pkg/ring0/defs_amd64.go @@ -17,7 +17,6 @@ package ring0 import ( - "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" "gvisor.dev/gvisor/pkg/usermem" ) diff --git a/pkg/sentry/platform/ring0/defs_arm64.go b/pkg/ring0/defs_arm64.go index a014dcbc0..dcb255fc8 100644 --- a/pkg/sentry/platform/ring0/defs_arm64.go +++ b/pkg/ring0/defs_arm64.go @@ -17,7 +17,6 @@ package ring0 import ( - "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" "gvisor.dev/gvisor/pkg/usermem" ) diff --git a/pkg/sentry/platform/ring0/entry_amd64.go b/pkg/ring0/entry_amd64.go index d87b1fd00..d87b1fd00 100644 --- a/pkg/sentry/platform/ring0/entry_amd64.go +++ b/pkg/ring0/entry_amd64.go diff --git a/pkg/sentry/platform/ring0/entry_amd64.s b/pkg/ring0/entry_amd64.s index f59747df3..f59747df3 100644 --- a/pkg/sentry/platform/ring0/entry_amd64.s +++ b/pkg/ring0/entry_amd64.s diff --git a/pkg/sentry/platform/ring0/entry_arm64.go b/pkg/ring0/entry_arm64.go index 62a93f3d6..62a93f3d6 100644 --- a/pkg/sentry/platform/ring0/entry_arm64.go +++ b/pkg/ring0/entry_arm64.go diff --git a/pkg/sentry/platform/ring0/entry_arm64.s b/pkg/ring0/entry_arm64.s index b2bb18257..b2bb18257 100644 --- a/pkg/sentry/platform/ring0/entry_arm64.s +++ b/pkg/ring0/entry_arm64.s diff --git a/pkg/sentry/platform/ring0/gen_offsets/BUILD b/pkg/ring0/gen_offsets/BUILD index 
a9703baf6..15b93d61c 100644 --- a/pkg/sentry/platform/ring0/gen_offsets/BUILD +++ b/pkg/ring0/gen_offsets/BUILD @@ -7,14 +7,14 @@ go_template_instance( name = "defs_impl_arm64", out = "defs_impl_arm64.go", package = "main", - template = "//pkg/sentry/platform/ring0:defs_arm64", + template = "//pkg/ring0:defs_arm64", ) go_template_instance( name = "defs_impl_amd64", out = "defs_impl_amd64.go", package = "main", - template = "//pkg/sentry/platform/ring0:defs_amd64", + template = "//pkg/ring0:defs_amd64", ) go_binary( @@ -28,13 +28,13 @@ go_binary( # pass the sentry deps test. system_malloc = True, visibility = [ + "//pkg/ring0:__pkg__", "//pkg/sentry/platform/kvm:__pkg__", - "//pkg/sentry/platform/ring0:__pkg__", ], deps = [ "//pkg/cpuid", + "//pkg/ring0/pagetables", "//pkg/sentry/arch", - "//pkg/sentry/platform/ring0/pagetables", "//pkg/usermem", ], ) diff --git a/pkg/sentry/platform/ring0/gen_offsets/main.go b/pkg/ring0/gen_offsets/main.go index a4927da2f..a4927da2f 100644 --- a/pkg/sentry/platform/ring0/gen_offsets/main.go +++ b/pkg/ring0/gen_offsets/main.go diff --git a/pkg/sentry/platform/ring0/kernel.go b/pkg/ring0/kernel.go index 292f9d0cc..292f9d0cc 100644 --- a/pkg/sentry/platform/ring0/kernel.go +++ b/pkg/ring0/kernel.go diff --git a/pkg/sentry/platform/ring0/kernel_amd64.go b/pkg/ring0/kernel_amd64.go index 36a60700e..36a60700e 100644 --- a/pkg/sentry/platform/ring0/kernel_amd64.go +++ b/pkg/ring0/kernel_amd64.go diff --git a/pkg/sentry/platform/ring0/kernel_arm64.go b/pkg/ring0/kernel_arm64.go index c05284641..c05284641 100644 --- a/pkg/sentry/platform/ring0/kernel_arm64.go +++ b/pkg/ring0/kernel_arm64.go diff --git a/pkg/sentry/platform/ring0/kernel_unsafe.go b/pkg/ring0/kernel_unsafe.go index 16955ad91..16955ad91 100644 --- a/pkg/sentry/platform/ring0/kernel_unsafe.go +++ b/pkg/ring0/kernel_unsafe.go diff --git a/pkg/sentry/platform/ring0/lib_amd64.go b/pkg/ring0/lib_amd64.go index 0ec5c3bc5..0ec5c3bc5 100644 --- a/pkg/sentry/platform/ring0/lib_amd64.go 
+++ b/pkg/ring0/lib_amd64.go diff --git a/pkg/sentry/platform/ring0/lib_amd64.s b/pkg/ring0/lib_amd64.s index 2fe83568a..2fe83568a 100644 --- a/pkg/sentry/platform/ring0/lib_amd64.s +++ b/pkg/ring0/lib_amd64.s diff --git a/pkg/sentry/platform/ring0/lib_arm64.go b/pkg/ring0/lib_arm64.go index a490bf3af..a490bf3af 100644 --- a/pkg/sentry/platform/ring0/lib_arm64.go +++ b/pkg/ring0/lib_arm64.go diff --git a/pkg/sentry/platform/ring0/lib_arm64.s b/pkg/ring0/lib_arm64.s index e39b32841..e39b32841 100644 --- a/pkg/sentry/platform/ring0/lib_arm64.s +++ b/pkg/ring0/lib_arm64.s diff --git a/pkg/sentry/platform/ring0/offsets_amd64.go b/pkg/ring0/offsets_amd64.go index ca4075b09..ca4075b09 100644 --- a/pkg/sentry/platform/ring0/offsets_amd64.go +++ b/pkg/ring0/offsets_amd64.go diff --git a/pkg/sentry/platform/ring0/offsets_arm64.go b/pkg/ring0/offsets_arm64.go index 164db6d5a..164db6d5a 100644 --- a/pkg/sentry/platform/ring0/offsets_arm64.go +++ b/pkg/ring0/offsets_arm64.go diff --git a/pkg/sentry/platform/ring0/pagetables/BUILD b/pkg/ring0/pagetables/BUILD index 9e3539e4c..65a978cbb 100644 --- a/pkg/sentry/platform/ring0/pagetables/BUILD +++ b/pkg/ring0/pagetables/BUILD @@ -9,7 +9,10 @@ package(licenses = ["notice"]) # architecture builds. 
go_template( name = "generic_walker_%s" % arch, - srcs = ["walker_%s.go" % arch], + srcs = [ + "walker_generic.go", + "walker_%s.go" % arch, + ], opt_types = [ "Visitor", ], @@ -50,6 +53,7 @@ go_library( "pcids_x86.go", "walker_amd64.go", "walker_arm64.go", + "walker_generic.go", ":walker_empty_amd64", ":walker_empty_arm64", ":walker_lookup_amd64", @@ -60,8 +64,8 @@ go_library( ":walker_unmap_arm64", ], visibility = [ + "//pkg/ring0:__subpackages__", "//pkg/sentry/platform/kvm:__subpackages__", - "//pkg/sentry/platform/ring0:__subpackages__", ], deps = [ "//pkg/sync", diff --git a/pkg/sentry/platform/ring0/pagetables/allocator.go b/pkg/ring0/pagetables/allocator.go index 8d75b7599..8d75b7599 100644 --- a/pkg/sentry/platform/ring0/pagetables/allocator.go +++ b/pkg/ring0/pagetables/allocator.go diff --git a/pkg/sentry/platform/ring0/pagetables/allocator_unsafe.go b/pkg/ring0/pagetables/allocator_unsafe.go index d08bfdeb3..d08bfdeb3 100644 --- a/pkg/sentry/platform/ring0/pagetables/allocator_unsafe.go +++ b/pkg/ring0/pagetables/allocator_unsafe.go diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables.go b/pkg/ring0/pagetables/pagetables.go index 7605d0cb2..8c0a6aa82 100644 --- a/pkg/sentry/platform/ring0/pagetables/pagetables.go +++ b/pkg/ring0/pagetables/pagetables.go @@ -60,6 +60,7 @@ type PageTables struct { // Init initializes a set of PageTables. // +// +checkescape:hard,stack //go:nosplit func (p *PageTables) Init(allocator Allocator) { p.Allocator = allocator @@ -92,7 +93,6 @@ func NewWithUpper(a Allocator, upperSharedPageTables *PageTables, upperStart uin } p.InitArch(a) - return p } @@ -112,7 +112,7 @@ type mapVisitor struct { // visit is used for map. 
// //go:nosplit -func (v *mapVisitor) visit(start uintptr, pte *PTE, align uintptr) { +func (v *mapVisitor) visit(start uintptr, pte *PTE, align uintptr) bool { p := v.physical + (start - uintptr(v.target)) if pte.Valid() && (pte.Address() != p || pte.Opts() != v.opts) { v.prev = true @@ -122,9 +122,10 @@ func (v *mapVisitor) visit(start uintptr, pte *PTE, align uintptr) { // install a valid entry here, however we must zap any existing // entry to ensure this happens. pte.Clear() - return + return true } pte.Set(p, v.opts) + return true } //go:nosplit @@ -140,7 +141,6 @@ func (*mapVisitor) requiresSplit() bool { return true } // Precondition: addr & length must be page-aligned, their sum must not overflow. // // +checkescape:hard,stack -// //go:nosplit func (p *PageTables) Map(addr usermem.Addr, length uintptr, opts MapOpts, physical uintptr) bool { if p.readOnlyShared { @@ -158,9 +158,6 @@ func (p *PageTables) Map(addr usermem.Addr, length uintptr, opts MapOpts, physic length = p.upperStart - uintptr(addr) } } - if !opts.AccessType.Any() { - return p.Unmap(addr, length) - } w := mapWalker{ pageTables: p, visitor: mapVisitor{ @@ -187,9 +184,10 @@ func (*unmapVisitor) requiresSplit() bool { return true } // visit unmaps the given entry. // //go:nosplit -func (v *unmapVisitor) visit(start uintptr, pte *PTE, align uintptr) { +func (v *unmapVisitor) visit(start uintptr, pte *PTE, align uintptr) bool { pte.Clear() v.count++ + return true } // Unmap unmaps the given range. @@ -199,7 +197,6 @@ func (v *unmapVisitor) visit(start uintptr, pte *PTE, align uintptr) { // Precondition: addr & length must be page-aligned, their sum must not overflow. // // +checkescape:hard,stack -// //go:nosplit func (p *PageTables) Unmap(addr usermem.Addr, length uintptr) bool { if p.readOnlyShared { @@ -241,8 +238,9 @@ func (*emptyVisitor) requiresSplit() bool { return false } // visit unmaps the given entry. 
// //go:nosplit -func (v *emptyVisitor) visit(start uintptr, pte *PTE, align uintptr) { +func (v *emptyVisitor) visit(start uintptr, pte *PTE, align uintptr) bool { v.count++ + return true } // IsEmpty checks if the given range is empty. @@ -250,7 +248,6 @@ func (v *emptyVisitor) visit(start uintptr, pte *PTE, align uintptr) { // Precondition: addr & length must be page-aligned. // // +checkescape:hard,stack -// //go:nosplit func (p *PageTables) IsEmpty(addr usermem.Addr, length uintptr) bool { w := emptyWalker{ @@ -262,20 +259,28 @@ func (p *PageTables) IsEmpty(addr usermem.Addr, length uintptr) bool { // lookupVisitor is used for lookup. type lookupVisitor struct { - target uintptr // Input. - physical uintptr // Output. - opts MapOpts // Output. + target uintptr // Input & Output. + findFirst bool // Input. + physical uintptr // Output. + size uintptr // Output. + opts MapOpts // Output. } // visit matches the given address. // //go:nosplit -func (v *lookupVisitor) visit(start uintptr, pte *PTE, align uintptr) { +func (v *lookupVisitor) visit(start uintptr, pte *PTE, align uintptr) bool { if !pte.Valid() { - return + // If looking for the first, then we just keep iterating until + // we find a valid entry. + return v.findFirst } - v.physical = pte.Address() + (start - uintptr(v.target)) + // Is this within the current range? + v.target = start + v.physical = pte.Address() + v.size = (align + 1) v.opts = pte.Opts() + return false } //go:nosplit @@ -286,20 +291,29 @@ func (*lookupVisitor) requiresSplit() bool { return false } // Lookup returns the physical address for the given virtual address. // -// +checkescape:hard,stack +// If findFirst is true, then the next valid address after addr is returned. +// If findFirst is false, then only a mapping for addr will be returned. +// +// Note that if size is zero, then no matching entry was found. 
// +// +checkescape:hard,stack //go:nosplit -func (p *PageTables) Lookup(addr usermem.Addr) (physical uintptr, opts MapOpts) { +func (p *PageTables) Lookup(addr usermem.Addr, findFirst bool) (virtual usermem.Addr, physical, size uintptr, opts MapOpts) { mask := uintptr(usermem.PageSize - 1) - offset := uintptr(addr) & mask + addr &^= usermem.Addr(mask) w := lookupWalker{ pageTables: p, visitor: lookupVisitor{ - target: uintptr(addr &^ usermem.Addr(mask)), + target: uintptr(addr), + findFirst: findFirst, }, } - w.iterateRange(uintptr(addr), uintptr(addr)+1) - return w.visitor.physical + offset, w.visitor.opts + end := ^usermem.Addr(0) &^ usermem.Addr(mask) + if !findFirst { + end = addr + 1 + } + w.iterateRange(uintptr(addr), uintptr(end)) + return usermem.Addr(w.visitor.target), w.visitor.physical, w.visitor.size, w.visitor.opts } // MarkReadOnlyShared marks the pagetables read-only and can be shared. diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go b/pkg/ring0/pagetables/pagetables_aarch64.go index 520161755..163a3aea3 100644 --- a/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go +++ b/pkg/ring0/pagetables/pagetables_aarch64.go @@ -156,12 +156,7 @@ func (p *PTE) IsSect() bool { // //go:nosplit func (p *PTE) Set(addr uintptr, opts MapOpts) { - if !opts.AccessType.Any() { - p.Clear() - return - } - v := (addr &^ optionMask) | protDefault | nG | readOnly - + v := (addr &^ optionMask) | nG | readOnly | protDefault if p.IsSect() { // Note that this is inherited from the previous instance. Set // does not change the value of Sect. See above. @@ -169,6 +164,10 @@ func (p *PTE) Set(addr uintptr, opts MapOpts) { } else { v |= typePage } + if !opts.AccessType.Any() { + // Leave as non-valid if no access is available. 
+ v &^= pteValid + } if opts.Global { v = v &^ nG diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go b/pkg/ring0/pagetables/pagetables_amd64.go index 4bdde8448..a217f404c 100644 --- a/pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go +++ b/pkg/ring0/pagetables/pagetables_amd64.go @@ -43,6 +43,7 @@ const ( // InitArch does some additional initialization related to the architecture. // +// +checkescape:hard,stack //go:nosplit func (p *PageTables) InitArch(allocator Allocator) { if p.upperSharedPageTables != nil { @@ -50,6 +51,7 @@ func (p *PageTables) InitArch(allocator Allocator) { } } +//go:nosplit func pgdIndex(upperStart uintptr) uintptr { if upperStart&(pgdSize-1) != 0 { panic("upperStart should be pgd size aligned") diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_amd64_test.go b/pkg/ring0/pagetables/pagetables_amd64_test.go index 54e8e554f..54e8e554f 100644 --- a/pkg/sentry/platform/ring0/pagetables/pagetables_amd64_test.go +++ b/pkg/ring0/pagetables/pagetables_amd64_test.go diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go b/pkg/ring0/pagetables/pagetables_arm64.go index ad0e30c88..fef7a0fd1 100644 --- a/pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go +++ b/pkg/ring0/pagetables/pagetables_arm64.go @@ -44,6 +44,7 @@ const ( // InitArch does some additional initialization related to the architecture. 
// +// +checkescape:hard,stack //go:nosplit func (p *PageTables) InitArch(allocator Allocator) { if p.upperSharedPageTables != nil { diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_arm64_test.go b/pkg/ring0/pagetables/pagetables_arm64_test.go index 2f73d424f..2f73d424f 100644 --- a/pkg/sentry/platform/ring0/pagetables/pagetables_arm64_test.go +++ b/pkg/ring0/pagetables/pagetables_arm64_test.go diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_test.go b/pkg/ring0/pagetables/pagetables_test.go index 5c88d087d..772f4fc5e 100644 --- a/pkg/sentry/platform/ring0/pagetables/pagetables_test.go +++ b/pkg/ring0/pagetables/pagetables_test.go @@ -34,7 +34,7 @@ type checkVisitor struct { failed string // Output. } -func (v *checkVisitor) visit(start uintptr, pte *PTE, align uintptr) { +func (v *checkVisitor) visit(start uintptr, pte *PTE, align uintptr) bool { v.found = append(v.found, mapping{ start: start, length: align + 1, @@ -43,7 +43,7 @@ func (v *checkVisitor) visit(start uintptr, pte *PTE, align uintptr) { }) if v.failed != "" { // Don't keep looking for errors. 
- return + return false } if v.current >= len(v.expected) { @@ -58,6 +58,7 @@ func (v *checkVisitor) visit(start uintptr, pte *PTE, align uintptr) { v.failed = "opts didn't match" } v.current++ + return true } func (*checkVisitor) requiresAlloc() bool { return false } diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_x86.go b/pkg/ring0/pagetables/pagetables_x86.go index 157438d9b..32edd2f0a 100644 --- a/pkg/sentry/platform/ring0/pagetables/pagetables_x86.go +++ b/pkg/ring0/pagetables/pagetables_x86.go @@ -137,7 +137,10 @@ func (p *PTE) Set(addr uintptr, opts MapOpts) { p.Clear() return } - v := (addr &^ optionMask) | present | accessed + v := (addr &^ optionMask) + if opts.AccessType.Any() { + v |= present | accessed + } if opts.User { v |= user } diff --git a/pkg/sentry/platform/ring0/pagetables/pcids.go b/pkg/ring0/pagetables/pcids.go index 964496aac..964496aac 100644 --- a/pkg/sentry/platform/ring0/pagetables/pcids.go +++ b/pkg/ring0/pagetables/pcids.go diff --git a/pkg/sentry/platform/ring0/pagetables/pcids_aarch64.go b/pkg/ring0/pagetables/pcids_aarch64.go index fbfd41d83..fbfd41d83 100644 --- a/pkg/sentry/platform/ring0/pagetables/pcids_aarch64.go +++ b/pkg/ring0/pagetables/pcids_aarch64.go diff --git a/pkg/sentry/platform/ring0/pagetables/pcids_aarch64.s b/pkg/ring0/pagetables/pcids_aarch64.s index e9d62d768..e9d62d768 100644 --- a/pkg/sentry/platform/ring0/pagetables/pcids_aarch64.s +++ b/pkg/ring0/pagetables/pcids_aarch64.s diff --git a/pkg/sentry/platform/ring0/pagetables/pcids_x86.go b/pkg/ring0/pagetables/pcids_x86.go index 91fc5e8dd..91fc5e8dd 100644 --- a/pkg/sentry/platform/ring0/pagetables/pcids_x86.go +++ b/pkg/ring0/pagetables/pcids_x86.go diff --git a/pkg/sentry/platform/ring0/pagetables/walker_amd64.go b/pkg/ring0/pagetables/walker_amd64.go index 8f9dacd93..eb4fbcc31 100644 --- a/pkg/sentry/platform/ring0/pagetables/walker_amd64.go +++ b/pkg/ring0/pagetables/walker_amd64.go @@ -16,104 +16,10 @@ package pagetables -// Visitor is a 
generic type. -type Visitor interface { - // visit is called on each PTE. - visit(start uintptr, pte *PTE, align uintptr) - - // requiresAlloc indicates that new entries should be allocated within - // the walked range. - requiresAlloc() bool - - // requiresSplit indicates that entries in the given range should be - // split if they are huge or jumbo pages. - requiresSplit() bool -} - -// Walker walks page tables. -type Walker struct { - // pageTables are the tables to walk. - pageTables *PageTables - - // Visitor is the set of arguments. - visitor Visitor -} - -// iterateRange iterates over all appropriate levels of page tables for the given range. -// -// If requiresAlloc is true, then Set _must_ be called on all given PTEs. The -// exception is super pages. If a valid super page (huge or jumbo) cannot be -// installed, then the walk will continue to individual entries. -// -// This algorithm will attempt to maximize the use of super pages whenever -// possible. Whether a super page is provided will be clear through the range -// provided in the callback. -// -// Note that if requiresAlloc is true, then no gaps will be present. However, -// if alloc is not set, then the iteration will likely be full of gaps. -// -// Note that this function should generally be avoided in favor of Map, Unmap, -// etc. when not necessary. -// -// Precondition: start must be page-aligned. -// -// Precondition: start must be less than end. -// -// Precondition: If requiresAlloc is true, then start and end should not span -// non-canonical ranges. If they do, a panic will result. 
-// -//go:nosplit -func (w *Walker) iterateRange(start, end uintptr) { - if start%pteSize != 0 { - panic("unaligned start") - } - if end < start { - panic("start > end") - } - if start < lowerTop { - if end <= lowerTop { - w.iterateRangeCanonical(start, end) - } else if end > lowerTop && end <= upperBottom { - if w.visitor.requiresAlloc() { - panic("alloc spans non-canonical range") - } - w.iterateRangeCanonical(start, lowerTop) - } else { - if w.visitor.requiresAlloc() { - panic("alloc spans non-canonical range") - } - w.iterateRangeCanonical(start, lowerTop) - w.iterateRangeCanonical(upperBottom, end) - } - } else if start < upperBottom { - if end <= upperBottom { - if w.visitor.requiresAlloc() { - panic("alloc spans non-canonical range") - } - } else { - if w.visitor.requiresAlloc() { - panic("alloc spans non-canonical range") - } - w.iterateRangeCanonical(upperBottom, end) - } - } else { - w.iterateRangeCanonical(start, end) - } -} - -// next returns the next address quantized by the given size. -// -//go:nosplit -func next(start uintptr, size uintptr) uintptr { - start &= ^(size - 1) - start += size - return start -} - // iterateRangeCanonical walks a canonical range. // //go:nosplit -func (w *Walker) iterateRangeCanonical(start, end uintptr) { +func (w *Walker) iterateRangeCanonical(start, end uintptr) bool { for pgdIndex := uint16((start & pgdMask) >> pgdShift); start < end && pgdIndex < entriesPerPage; pgdIndex++ { var ( pgdEntry = &w.pageTables.root[pgdIndex] @@ -127,10 +33,10 @@ func (w *Walker) iterateRangeCanonical(start, end uintptr) { } // Allocate a new pgd. - pudEntries = w.pageTables.Allocator.NewPTEs() + pudEntries = w.pageTables.Allocator.NewPTEs() // escapes: depends on allocator. pgdEntry.setPageTable(w.pageTables, pudEntries) } else { - pudEntries = w.pageTables.Allocator.LookupPTEs(pgdEntry.Address()) + pudEntries = w.pageTables.Allocator.LookupPTEs(pgdEntry.Address()) // escapes: see above. } // Map the next level. 
@@ -155,7 +61,9 @@ func (w *Walker) iterateRangeCanonical(start, end uintptr) { // new page for the pmd. if start&(pudSize-1) == 0 && end-start >= pudSize { pudEntry.SetSuper() - w.visitor.visit(uintptr(start), pudEntry, pudSize-1) + if !w.visitor.visit(uintptr(start&^(pudSize-1)), pudEntry, pudSize-1) { + return false + } if pudEntry.Valid() { start = next(start, pudSize) continue @@ -163,14 +71,14 @@ func (w *Walker) iterateRangeCanonical(start, end uintptr) { } // Allocate a new pud. - pmdEntries = w.pageTables.Allocator.NewPTEs() + pmdEntries = w.pageTables.Allocator.NewPTEs() // escapes: see above. pudEntry.setPageTable(w.pageTables, pmdEntries) } else if pudEntry.IsSuper() { // Does this page need to be split? if w.visitor.requiresSplit() && (start&(pudSize-1) != 0 || end < next(start, pudSize)) { // Install the relevant entries. - pmdEntries = w.pageTables.Allocator.NewPTEs() + pmdEntries = w.pageTables.Allocator.NewPTEs() // escapes: see above. for index := uint16(0); index < entriesPerPage; index++ { pmdEntries[index].SetSuper() pmdEntries[index].Set( @@ -180,7 +88,9 @@ func (w *Walker) iterateRangeCanonical(start, end uintptr) { pudEntry.setPageTable(w.pageTables, pmdEntries) } else { // A super page to be checked directly. - w.visitor.visit(uintptr(start), pudEntry, pudSize-1) + if !w.visitor.visit(uintptr(start&^(pudSize-1)), pudEntry, pudSize-1) { + return false + } // Might have been cleared. if !pudEntry.Valid() { @@ -192,7 +102,7 @@ func (w *Walker) iterateRangeCanonical(start, end uintptr) { continue } } else { - pmdEntries = w.pageTables.Allocator.LookupPTEs(pudEntry.Address()) + pmdEntries = w.pageTables.Allocator.LookupPTEs(pudEntry.Address()) // escapes: see above. } // Map the next level, since this is valid. @@ -216,7 +126,9 @@ func (w *Walker) iterateRangeCanonical(start, end uintptr) { // As above, we can skip allocating a new page. 
if start&(pmdSize-1) == 0 && end-start >= pmdSize { pmdEntry.SetSuper() - w.visitor.visit(uintptr(start), pmdEntry, pmdSize-1) + if !w.visitor.visit(uintptr(start&^(pmdSize-1)), pmdEntry, pmdSize-1) { + return false + } if pmdEntry.Valid() { start = next(start, pmdSize) continue @@ -224,7 +136,7 @@ func (w *Walker) iterateRangeCanonical(start, end uintptr) { } // Allocate a new pmd. - pteEntries = w.pageTables.Allocator.NewPTEs() + pteEntries = w.pageTables.Allocator.NewPTEs() // escapes: see above. pmdEntry.setPageTable(w.pageTables, pteEntries) } else if pmdEntry.IsSuper() { @@ -240,7 +152,9 @@ func (w *Walker) iterateRangeCanonical(start, end uintptr) { pmdEntry.setPageTable(w.pageTables, pteEntries) } else { // A huge page to be checked directly. - w.visitor.visit(uintptr(start), pmdEntry, pmdSize-1) + if !w.visitor.visit(uintptr(start&^(pmdSize-1)), pmdEntry, pmdSize-1) { + return false + } // Might have been cleared. if !pmdEntry.Valid() { @@ -252,7 +166,7 @@ func (w *Walker) iterateRangeCanonical(start, end uintptr) { continue } } else { - pteEntries = w.pageTables.Allocator.LookupPTEs(pmdEntry.Address()) + pteEntries = w.pageTables.Allocator.LookupPTEs(pmdEntry.Address()) // escapes: see above. } // Map the next level, since this is valid. @@ -269,11 +183,10 @@ func (w *Walker) iterateRangeCanonical(start, end uintptr) { } // At this point, we are guaranteed that start%pteSize == 0. - w.visitor.visit(uintptr(start), pteEntry, pteSize-1) - if !pteEntry.Valid() { - if w.visitor.requiresAlloc() { - panic("PTE not set after iteration with requiresAlloc!") - } + if !w.visitor.visit(uintptr(start&^(pteSize-1)), pteEntry, pteSize-1) { + return false + } + if !pteEntry.Valid() && !w.visitor.requiresAlloc() { clearPTEEntries++ } @@ -285,7 +198,7 @@ func (w *Walker) iterateRangeCanonical(start, end uintptr) { // Check if we no longer need this page. 
if clearPTEEntries == entriesPerPage { pmdEntry.Clear() - w.pageTables.Allocator.FreePTEs(pteEntries) + w.pageTables.Allocator.FreePTEs(pteEntries) // escapes: see above. clearPMDEntries++ } } @@ -293,7 +206,7 @@ func (w *Walker) iterateRangeCanonical(start, end uintptr) { // Check if we no longer need this page. if clearPMDEntries == entriesPerPage { pudEntry.Clear() - w.pageTables.Allocator.FreePTEs(pmdEntries) + w.pageTables.Allocator.FreePTEs(pmdEntries) // escapes: see above. clearPUDEntries++ } } @@ -301,7 +214,8 @@ func (w *Walker) iterateRangeCanonical(start, end uintptr) { // Check if we no longer need this page. if clearPUDEntries == entriesPerPage { pgdEntry.Clear() - w.pageTables.Allocator.FreePTEs(pudEntries) + w.pageTables.Allocator.FreePTEs(pudEntries) // escapes: see above. } } + return true } diff --git a/pkg/sentry/platform/ring0/pagetables/walker_arm64.go b/pkg/ring0/pagetables/walker_arm64.go index c261d393a..5ed881c7a 100644 --- a/pkg/sentry/platform/ring0/pagetables/walker_arm64.go +++ b/pkg/ring0/pagetables/walker_arm64.go @@ -16,104 +16,10 @@ package pagetables -// Visitor is a generic type. -type Visitor interface { - // visit is called on each PTE. - visit(start uintptr, pte *PTE, align uintptr) - - // requiresAlloc indicates that new entries should be allocated within - // the walked range. - requiresAlloc() bool - - // requiresSplit indicates that entries in the given range should be - // split if they are huge or jumbo pages. - requiresSplit() bool -} - -// Walker walks page tables. -type Walker struct { - // pageTables are the tables to walk. - pageTables *PageTables - - // Visitor is the set of arguments. - visitor Visitor -} - -// iterateRange iterates over all appropriate levels of page tables for the given range. -// -// If requiresAlloc is true, then Set _must_ be called on all given PTEs. The -// exception is sect pages. 
If a valid sect page (huge or jumbo) cannot be -// installed, then the walk will continue to individual entries. -// -// This algorithm will attempt to maximize the use of sect pages whenever -// possible. Whether a sect page is provided will be clear through the range -// provided in the callback. -// -// Note that if requiresAlloc is true, then no gaps will be present. However, -// if alloc is not set, then the iteration will likely be full of gaps. -// -// Note that this function should generally be avoided in favor of Map, Unmap, -// etc. when not necessary. -// -// Precondition: start must be page-aligned. -// -// Precondition: start must be less than end. -// -// Precondition: If requiresAlloc is true, then start and end should not span -// non-canonical ranges. If they do, a panic will result. -// -//go:nosplit -func (w *Walker) iterateRange(start, end uintptr) { - if start%pteSize != 0 { - panic("unaligned start") - } - if end < start { - panic("start > end") - } - if start < lowerTop { - if end <= lowerTop { - w.iterateRangeCanonical(start, end) - } else if end > lowerTop && end <= upperBottom { - if w.visitor.requiresAlloc() { - panic("alloc spans non-canonical range") - } - w.iterateRangeCanonical(start, lowerTop) - } else { - if w.visitor.requiresAlloc() { - panic("alloc spans non-canonical range") - } - w.iterateRangeCanonical(start, lowerTop) - w.iterateRangeCanonical(upperBottom, end) - } - } else if start < upperBottom { - if end <= upperBottom { - if w.visitor.requiresAlloc() { - panic("alloc spans non-canonical range") - } - } else { - if w.visitor.requiresAlloc() { - panic("alloc spans non-canonical range") - } - w.iterateRangeCanonical(upperBottom, end) - } - } else { - w.iterateRangeCanonical(start, end) - } -} - -// next returns the next address quantized by the given size. 
-// -//go:nosplit -func next(start uintptr, size uintptr) uintptr { - start &= ^(size - 1) - start += size - return start -} - // iterateRangeCanonical walks a canonical range. // //go:nosplit -func (w *Walker) iterateRangeCanonical(start, end uintptr) { +func (w *Walker) iterateRangeCanonical(start, end uintptr) bool { pgdEntryIndex := w.pageTables.root if start >= upperBottom { pgdEntryIndex = w.pageTables.archPageTables.root @@ -160,7 +66,9 @@ func (w *Walker) iterateRangeCanonical(start, end uintptr) { // new page for the pmd. if start&(pudSize-1) == 0 && end-start >= pudSize { pudEntry.SetSect() - w.visitor.visit(uintptr(start), pudEntry, pudSize-1) + if !w.visitor.visit(uintptr(start), pudEntry, pudSize-1) { + return false + } if pudEntry.Valid() { start = next(start, pudSize) continue @@ -185,7 +93,9 @@ func (w *Walker) iterateRangeCanonical(start, end uintptr) { pudEntry.setPageTable(w.pageTables, pmdEntries) } else { // A sect page to be checked directly. - w.visitor.visit(uintptr(start), pudEntry, pudSize-1) + if !w.visitor.visit(uintptr(start), pudEntry, pudSize-1) { + return false + } // Might have been cleared. if !pudEntry.Valid() { @@ -222,7 +132,9 @@ func (w *Walker) iterateRangeCanonical(start, end uintptr) { // As above, we can skip allocating a new page. if start&(pmdSize-1) == 0 && end-start >= pmdSize { pmdEntry.SetSect() - w.visitor.visit(uintptr(start), pmdEntry, pmdSize-1) + if !w.visitor.visit(uintptr(start), pmdEntry, pmdSize-1) { + return false + } if pmdEntry.Valid() { start = next(start, pmdSize) continue @@ -246,7 +158,9 @@ func (w *Walker) iterateRangeCanonical(start, end uintptr) { pmdEntry.setPageTable(w.pageTables, pteEntries) } else { // A huge page to be checked directly. - w.visitor.visit(uintptr(start), pmdEntry, pmdSize-1) + if !w.visitor.visit(uintptr(start), pmdEntry, pmdSize-1) { + return false + } // Might have been cleared. 
if !pmdEntry.Valid() { @@ -276,7 +190,9 @@ func (w *Walker) iterateRangeCanonical(start, end uintptr) { } // At this point, we are guaranteed that start%pteSize == 0. - w.visitor.visit(uintptr(start), pteEntry, pteSize-1) + if !w.visitor.visit(uintptr(start), pteEntry, pteSize-1) { + return false + } if !pteEntry.Valid() { if w.visitor.requiresAlloc() { panic("PTE not set after iteration with requiresAlloc!") @@ -311,4 +227,5 @@ func (w *Walker) iterateRangeCanonical(start, end uintptr) { w.pageTables.Allocator.FreePTEs(pudEntries) } } + return true } diff --git a/pkg/ring0/pagetables/walker_generic.go b/pkg/ring0/pagetables/walker_generic.go new file mode 100644 index 000000000..34fba7b84 --- /dev/null +++ b/pkg/ring0/pagetables/walker_generic.go @@ -0,0 +1,110 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pagetables + +// Visitor is a generic type. +type Visitor interface { + // visit is called on each PTE. The returned boolean indicates whether + // the walk should continue. + visit(start uintptr, pte *PTE, align uintptr) bool + + // requiresAlloc indicates that new entries should be allocated within + // the walked range. + requiresAlloc() bool + + // requiresSplit indicates that entries in the given range should be + // split if they are huge or jumbo pages. + requiresSplit() bool +} + +// Walker walks page tables. +type Walker struct { + // pageTables are the tables to walk. 
+ pageTables *PageTables + + // Visitor is the set of arguments. + visitor Visitor +} + +// iterateRange iterates over all appropriate levels of page tables for the given range. +// +// If requiresAlloc is true, then Set _must_ be called on all given PTEs. The +// exception is super pages. If a valid super page (huge or jumbo) cannot be +// installed, then the walk will continue to individual entries. +// +// This algorithm will attempt to maximize the use of super/sect pages whenever +// possible. Whether a super page is provided will be clear through the range +// provided in the callback. +// +// Note that if requiresAlloc is true, then no gaps will be present. However, +// if alloc is not set, then the iteration will likely be full of gaps. +// +// Note that this function should generally be avoided in favor of Map, Unmap, +// etc. when not necessary. +// +// Precondition: start must be page-aligned. +// Precondition: start must be less than end. +// Precondition: If requiresAlloc is true, then start and end should not span +// non-canonical ranges. If they do, a panic will result. 
+// +//go:nosplit +func (w *Walker) iterateRange(start, end uintptr) { + if start%pteSize != 0 { + panic("unaligned start") + } + if end < start { + panic("start > end") + } + if start < lowerTop { + if end <= lowerTop { + w.iterateRangeCanonical(start, end) + } else if end > lowerTop && end <= upperBottom { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + w.iterateRangeCanonical(start, lowerTop) + } else { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + if !w.iterateRangeCanonical(start, lowerTop) { + return + } + w.iterateRangeCanonical(upperBottom, end) + } + } else if start < upperBottom { + if end <= upperBottom { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + } else { + if w.visitor.requiresAlloc() { + panic("alloc spans non-canonical range") + } + w.iterateRangeCanonical(upperBottom, end) + } + } else { + w.iterateRangeCanonical(start, end) + } +} + +// next returns the next address quantized by the given size. +// +//go:nosplit +func next(start uintptr, size uintptr) uintptr { + start &= ^(size - 1) + start += size + return start +} diff --git a/pkg/sentry/platform/ring0/ring0.go b/pkg/ring0/ring0.go index cdeb1b43a..cdeb1b43a 100644 --- a/pkg/sentry/platform/ring0/ring0.go +++ b/pkg/ring0/ring0.go diff --git a/pkg/sentry/platform/ring0/x86.go b/pkg/ring0/x86.go index 34fbc1c35..34fbc1c35 100644 --- a/pkg/sentry/platform/ring0/x86.go +++ b/pkg/ring0/x86.go diff --git a/pkg/sentry/control/proc.go b/pkg/sentry/control/proc.go index 1d88db12f..de7a0f3ab 100644 --- a/pkg/sentry/control/proc.go +++ b/pkg/sentry/control/proc.go @@ -404,3 +404,16 @@ func ttyName(tty *kernel.TTY) string { } return fmt.Sprintf("pts/%d", tty.Index) } + +// ContainerUsage retrieves per-container CPU usage. 
+func ContainerUsage(kr *kernel.Kernel) map[string]uint64 { + cusage := make(map[string]uint64) + for _, tg := range kr.TaskSet().Root.ThreadGroups() { + // We want each tg's usage including reaped children. + cid := tg.Leader().ContainerID() + stats := tg.CPUStats() + stats.Accumulate(tg.JoinedChildCPUStats()) + cusage[cid] += uint64(stats.UserTime.Nanoseconds()) + uint64(stats.SysTime.Nanoseconds()) + } + return cusage +} diff --git a/pkg/sentry/fs/host/socket.go b/pkg/sentry/fs/host/socket.go index a2f3d5918..07b4fb70f 100644 --- a/pkg/sentry/fs/host/socket.go +++ b/pkg/sentry/fs/host/socket.go @@ -257,7 +257,7 @@ func (c *ConnectedEndpoint) Passcred() bool { } // GetLocalAddress implements transport.ConnectedEndpoint.GetLocalAddress. -func (c *ConnectedEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { +func (c *ConnectedEndpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) { return tcpip.FullAddress{Addr: tcpip.Address(c.path)}, nil } diff --git a/pkg/sentry/fsimpl/gofer/special_file.go b/pkg/sentry/fsimpl/gofer/special_file.go index 089955a96..ae972fcb5 100644 --- a/pkg/sentry/fsimpl/gofer/special_file.go +++ b/pkg/sentry/fsimpl/gofer/special_file.go @@ -299,10 +299,15 @@ func (fd *specialFileFD) pwrite(ctx context.Context, src usermem.IOSequence, off src = src.TakeFirst64(limit) } - // Do a buffered write. See rationale in PRead. if d.cachedMetadataAuthoritative() { - d.touchCMtime() + if fd.isRegularFile { + d.touchCMtimeLocked() + } else { + d.touchCMtime() + } } + + // Do a buffered write. See rationale in PRead. 
buf := make([]byte, src.NumBytes()) copied, copyErr := src.CopyIn(ctx, buf) if copied == 0 && copyErr != nil { diff --git a/pkg/sentry/fsimpl/host/socket.go b/pkg/sentry/fsimpl/host/socket.go index 60acc367f..72aa535f8 100644 --- a/pkg/sentry/fsimpl/host/socket.go +++ b/pkg/sentry/fsimpl/host/socket.go @@ -201,7 +201,7 @@ func (c *ConnectedEndpoint) Passcred() bool { } // GetLocalAddress implements transport.ConnectedEndpoint.GetLocalAddress. -func (c *ConnectedEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { +func (c *ConnectedEndpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) { return tcpip.FullAddress{Addr: tcpip.Address(c.addr)}, nil } diff --git a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go index eac578f25..8139bff76 100644 --- a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go +++ b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go @@ -371,6 +371,8 @@ type OrderedChildrenOptions struct { // OrderedChildren may modify the tracked children. This applies to // operations related to rename, unlink and rmdir. If an OrderedChildren is // not writable, these operations all fail with EPERM. + // + // Note that writable users must implement the sticky bit (I_SVTX). Writable bool } @@ -556,7 +558,6 @@ func (o *OrderedChildren) Unlink(ctx context.Context, name string, child Inode) return err } - // TODO(gvisor.dev/issue/3027): Check sticky bit before removing. o.removeLocked(name) return nil } @@ -603,8 +604,8 @@ func (o *OrderedChildren) Rename(ctx context.Context, oldname, newname string, c if err := o.checkExistingLocked(oldname, child); err != nil { return err } + o.removeLocked(oldname) - // TODO(gvisor.dev/issue/3027): Check sticky bit before removing. 
dst.replaceChildLocked(ctx, newname, child) return nil } diff --git a/pkg/sentry/kernel/pipe/vfs.go b/pkg/sentry/kernel/pipe/vfs.go index 3b6336e94..09c0ccaf2 100644 --- a/pkg/sentry/kernel/pipe/vfs.go +++ b/pkg/sentry/kernel/pipe/vfs.go @@ -368,17 +368,15 @@ func (fd *VFSPipeFD) CopyInTo(ctx context.Context, ars usermem.AddrRangeSeq, dst }) } -// CopyOutFrom implements usermem.IO.CopyOutFrom. +// CopyOutFrom implements usermem.IO.CopyOutFrom. Note that it is the caller's +// responsibility to call fd.pipe.Notify(waiter.EventIn) after the write is +// completed. // // Preconditions: fd.pipe.mu must be locked. func (fd *VFSPipeFD) CopyOutFrom(ctx context.Context, ars usermem.AddrRangeSeq, src safemem.Reader, opts usermem.IOOpts) (int64, error) { - n, err := fd.pipe.writeLocked(ars.NumBytes(), func(dsts safemem.BlockSeq) (uint64, error) { + return fd.pipe.writeLocked(ars.NumBytes(), func(dsts safemem.BlockSeq) (uint64, error) { return src.ReadToBlocks(dsts) }) - if n > 0 { - fd.pipe.Notify(waiter.EventIn) - } - return n, err } // SwapUint32 implements usermem.IO.SwapUint32. 
diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD index 8ce411102..b3290917e 100644 --- a/pkg/sentry/platform/kvm/BUILD +++ b/pkg/sentry/platform/kvm/BUILD @@ -45,14 +45,14 @@ go_library( "//pkg/cpuid", "//pkg/log", "//pkg/procid", + "//pkg/ring0", + "//pkg/ring0/pagetables", "//pkg/safecopy", "//pkg/seccomp", "//pkg/sentry/arch", "//pkg/sentry/memmap", "//pkg/sentry/platform", "//pkg/sentry/platform/interrupt", - "//pkg/sentry/platform/ring0", - "//pkg/sentry/platform/ring0/pagetables", "//pkg/sentry/time", "//pkg/sync", "//pkg/usermem", @@ -75,11 +75,11 @@ go_test( "requires-kvm", ], deps = [ + "//pkg/ring0", + "//pkg/ring0/pagetables", "//pkg/sentry/arch", "//pkg/sentry/platform", "//pkg/sentry/platform/kvm/testutil", - "//pkg/sentry/platform/ring0", - "//pkg/sentry/platform/ring0/pagetables", "//pkg/sentry/time", "//pkg/usermem", ], @@ -89,6 +89,6 @@ genrule( name = "bluepill_impl_amd64", srcs = ["bluepill_amd64.s"], outs = ["bluepill_impl_amd64.s"], - cmd = "(echo -e '// build +amd64\\n' && $(location //pkg/sentry/platform/ring0/gen_offsets) && cat $(SRCS)) > $@", - tools = ["//pkg/sentry/platform/ring0/gen_offsets"], + cmd = "(echo -e '// build +amd64\\n' && $(location //pkg/ring0/gen_offsets) && cat $(SRCS)) > $@", + tools = ["//pkg/ring0/gen_offsets"], ) diff --git a/pkg/sentry/platform/kvm/address_space.go b/pkg/sentry/platform/kvm/address_space.go index af5c5e191..25c21e843 100644 --- a/pkg/sentry/platform/kvm/address_space.go +++ b/pkg/sentry/platform/kvm/address_space.go @@ -18,9 +18,9 @@ import ( "sync/atomic" "gvisor.dev/gvisor/pkg/atomicbitops" + "gvisor.dev/gvisor/pkg/ring0/pagetables" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/platform" - "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/usermem" ) diff --git a/pkg/sentry/platform/kvm/bluepill.go b/pkg/sentry/platform/kvm/bluepill.go index 4b23f7803..2c970162e 100644 --- 
a/pkg/sentry/platform/kvm/bluepill.go +++ b/pkg/sentry/platform/kvm/bluepill.go @@ -19,9 +19,9 @@ import ( "reflect" "syscall" + "gvisor.dev/gvisor/pkg/ring0" "gvisor.dev/gvisor/pkg/safecopy" "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/platform/ring0" ) // bluepill enters guest mode. diff --git a/pkg/sentry/platform/kvm/bluepill_allocator.go b/pkg/sentry/platform/kvm/bluepill_allocator.go index 9485e1301..1825edc3a 100644 --- a/pkg/sentry/platform/kvm/bluepill_allocator.go +++ b/pkg/sentry/platform/kvm/bluepill_allocator.go @@ -17,7 +17,7 @@ package kvm import ( "fmt" - "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" + "gvisor.dev/gvisor/pkg/ring0/pagetables" ) type allocator struct { diff --git a/pkg/sentry/platform/kvm/bluepill_amd64.go b/pkg/sentry/platform/kvm/bluepill_amd64.go index ddc1554d5..83a4766fb 100644 --- a/pkg/sentry/platform/kvm/bluepill_amd64.go +++ b/pkg/sentry/platform/kvm/bluepill_amd64.go @@ -19,8 +19,8 @@ package kvm import ( "syscall" + "gvisor.dev/gvisor/pkg/ring0" "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/platform/ring0" ) var ( diff --git a/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go b/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go index f8ccb7430..0063e947b 100644 --- a/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go +++ b/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go @@ -20,8 +20,8 @@ import ( "syscall" "unsafe" + "gvisor.dev/gvisor/pkg/ring0" "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/platform/ring0" ) // dieArchSetup initializes the state for dieTrampoline. 
diff --git a/pkg/sentry/platform/kvm/bluepill_arm64.go b/pkg/sentry/platform/kvm/bluepill_arm64.go index 1f09813ba..35298135a 100644 --- a/pkg/sentry/platform/kvm/bluepill_arm64.go +++ b/pkg/sentry/platform/kvm/bluepill_arm64.go @@ -19,8 +19,8 @@ package kvm import ( "syscall" + "gvisor.dev/gvisor/pkg/ring0" "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/platform/ring0" ) var ( diff --git a/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go index 4d912769a..dbbf2a897 100644 --- a/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go +++ b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go @@ -20,8 +20,8 @@ import ( "syscall" "unsafe" + "gvisor.dev/gvisor/pkg/ring0" "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/platform/ring0" ) // fpsimdPtr returns a fpsimd64 for the given address. diff --git a/pkg/sentry/platform/kvm/context.go b/pkg/sentry/platform/kvm/context.go index 17268d127..aeae01dbd 100644 --- a/pkg/sentry/platform/kvm/context.go +++ b/pkg/sentry/platform/kvm/context.go @@ -18,10 +18,10 @@ import ( "sync/atomic" pkgcontext "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/ring0" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/platform/interrupt" - "gvisor.dev/gvisor/pkg/sentry/platform/ring0" "gvisor.dev/gvisor/pkg/usermem" ) diff --git a/pkg/sentry/platform/kvm/kvm.go b/pkg/sentry/platform/kvm/kvm.go index 5979aef97..7bdf57436 100644 --- a/pkg/sentry/platform/kvm/kvm.go +++ b/pkg/sentry/platform/kvm/kvm.go @@ -20,9 +20,9 @@ import ( "os" "syscall" + "gvisor.dev/gvisor/pkg/ring0" + "gvisor.dev/gvisor/pkg/ring0/pagetables" "gvisor.dev/gvisor/pkg/sentry/platform" - "gvisor.dev/gvisor/pkg/sentry/platform/ring0" - "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/usermem" ) diff --git a/pkg/sentry/platform/kvm/kvm_amd64.go 
b/pkg/sentry/platform/kvm/kvm_amd64.go index 093497bc4..b9ed4a706 100644 --- a/pkg/sentry/platform/kvm/kvm_amd64.go +++ b/pkg/sentry/platform/kvm/kvm_amd64.go @@ -18,7 +18,7 @@ package kvm import ( "gvisor.dev/gvisor/pkg/cpuid" - "gvisor.dev/gvisor/pkg/sentry/platform/ring0" + "gvisor.dev/gvisor/pkg/ring0" ) // userRegs represents KVM user registers. diff --git a/pkg/sentry/platform/kvm/kvm_amd64_test.go b/pkg/sentry/platform/kvm/kvm_amd64_test.go index c0b4fd374..76fc594a0 100644 --- a/pkg/sentry/platform/kvm/kvm_amd64_test.go +++ b/pkg/sentry/platform/kvm/kvm_amd64_test.go @@ -19,11 +19,11 @@ package kvm import ( "testing" + "gvisor.dev/gvisor/pkg/ring0" + "gvisor.dev/gvisor/pkg/ring0/pagetables" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/platform/kvm/testutil" - "gvisor.dev/gvisor/pkg/sentry/platform/ring0" - "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" ) func TestSegments(t *testing.T) { diff --git a/pkg/sentry/platform/kvm/kvm_arm64.go b/pkg/sentry/platform/kvm/kvm_arm64.go index 9db1db4e9..b73340f0e 100644 --- a/pkg/sentry/platform/kvm/kvm_arm64.go +++ b/pkg/sentry/platform/kvm/kvm_arm64.go @@ -17,8 +17,8 @@ package kvm import ( + "gvisor.dev/gvisor/pkg/ring0" "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/platform/ring0" ) type kvmOneReg struct { diff --git a/pkg/sentry/platform/kvm/kvm_test.go b/pkg/sentry/platform/kvm/kvm_test.go index a650877d6..11ca1f0ea 100644 --- a/pkg/sentry/platform/kvm/kvm_test.go +++ b/pkg/sentry/platform/kvm/kvm_test.go @@ -22,11 +22,11 @@ import ( "testing" "time" + "gvisor.dev/gvisor/pkg/ring0" + "gvisor.dev/gvisor/pkg/ring0/pagetables" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/platform/kvm/testutil" - "gvisor.dev/gvisor/pkg/sentry/platform/ring0" - "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" ktime "gvisor.dev/gvisor/pkg/sentry/time" 
"gvisor.dev/gvisor/pkg/usermem" ) diff --git a/pkg/sentry/platform/kvm/machine.go b/pkg/sentry/platform/kvm/machine.go index e2fffc99b..1ece1b8d8 100644 --- a/pkg/sentry/platform/kvm/machine.go +++ b/pkg/sentry/platform/kvm/machine.go @@ -23,8 +23,8 @@ import ( "gvisor.dev/gvisor/pkg/atomicbitops" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/procid" - "gvisor.dev/gvisor/pkg/sentry/platform/ring0" - "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" + "gvisor.dev/gvisor/pkg/ring0" + "gvisor.dev/gvisor/pkg/ring0/pagetables" ktime "gvisor.dev/gvisor/pkg/sentry/time" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/usermem" diff --git a/pkg/sentry/platform/kvm/machine_amd64.go b/pkg/sentry/platform/kvm/machine_amd64.go index 8e03c310d..59c752d73 100644 --- a/pkg/sentry/platform/kvm/machine_amd64.go +++ b/pkg/sentry/platform/kvm/machine_amd64.go @@ -24,10 +24,10 @@ import ( "syscall" "gvisor.dev/gvisor/pkg/cpuid" + "gvisor.dev/gvisor/pkg/ring0" + "gvisor.dev/gvisor/pkg/ring0/pagetables" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/platform" - "gvisor.dev/gvisor/pkg/sentry/platform/ring0" - "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" ktime "gvisor.dev/gvisor/pkg/sentry/time" "gvisor.dev/gvisor/pkg/usermem" ) diff --git a/pkg/sentry/platform/kvm/machine_arm64.go b/pkg/sentry/platform/kvm/machine_arm64.go index aa2d21748..7d7857067 100644 --- a/pkg/sentry/platform/kvm/machine_arm64.go +++ b/pkg/sentry/platform/kvm/machine_arm64.go @@ -17,10 +17,10 @@ package kvm import ( + "gvisor.dev/gvisor/pkg/ring0" + "gvisor.dev/gvisor/pkg/ring0/pagetables" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/platform" - "gvisor.dev/gvisor/pkg/sentry/platform/ring0" - "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" "gvisor.dev/gvisor/pkg/usermem" ) diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go index a466acf4d..dca0cdb60 100644 --- 
a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go +++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go @@ -23,10 +23,10 @@ import ( "syscall" "unsafe" + "gvisor.dev/gvisor/pkg/ring0" + "gvisor.dev/gvisor/pkg/ring0/pagetables" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/platform" - "gvisor.dev/gvisor/pkg/sentry/platform/ring0" - "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" "gvisor.dev/gvisor/pkg/usermem" ) diff --git a/pkg/sentry/platform/kvm/physical_map.go b/pkg/sentry/platform/kvm/physical_map.go index f7fa2f98d..8bdec93ae 100644 --- a/pkg/sentry/platform/kvm/physical_map.go +++ b/pkg/sentry/platform/kvm/physical_map.go @@ -20,7 +20,7 @@ import ( "syscall" "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/pkg/sentry/platform/ring0" + "gvisor.dev/gvisor/pkg/ring0" "gvisor.dev/gvisor/pkg/usermem" ) diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 7065a0e46..69693f263 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -251,11 +251,11 @@ var errStackType = syserr.New("expected but did not receive a netstack.Stack", l type commonEndpoint interface { // GetLocalAddress implements tcpip.Endpoint.GetLocalAddress and // transport.Endpoint.GetLocalAddress. - GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) + GetLocalAddress() (tcpip.FullAddress, tcpip.Error) // GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress and // transport.Endpoint.GetRemoteAddress. - GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) + GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) // Readiness implements tcpip.Endpoint.Readiness and // transport.Endpoint.Readiness. @@ -263,19 +263,19 @@ type commonEndpoint interface { // SetSockOpt implements tcpip.Endpoint.SetSockOpt and // transport.Endpoint.SetSockOpt. 
- SetSockOpt(tcpip.SettableSocketOption) *tcpip.Error + SetSockOpt(tcpip.SettableSocketOption) tcpip.Error // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt and // transport.Endpoint.SetSockOptInt. - SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error + SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error // GetSockOpt implements tcpip.Endpoint.GetSockOpt and // transport.Endpoint.GetSockOpt. - GetSockOpt(tcpip.GettableSocketOption) *tcpip.Error + GetSockOpt(tcpip.GettableSocketOption) tcpip.Error // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt and // transport.Endpoint.GetSockOpt. - GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) + GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) // State returns a socket's lifecycle state. The returned value is // protocol-specific and is primarily used for diagnostics. @@ -283,7 +283,7 @@ type commonEndpoint interface { // LastError implements tcpip.Endpoint.LastError and // transport.Endpoint.LastError. - LastError() *tcpip.Error + LastError() tcpip.Error // SocketOptions implements tcpip.Endpoint.SocketOptions and // transport.Endpoint.SocketOptions. 
@@ -442,7 +442,7 @@ func (s *SocketOperations) WriteTo(ctx context.Context, _ *fs.File, dst io.Write func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) { r := src.Reader(ctx) n, err := s.Endpoint.Write(r, tcpip.WriteOptions{}) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { return 0, syserror.ErrWouldBlock } if err != nil { @@ -459,17 +459,24 @@ func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IO var _ tcpip.Payloader = (*limitedPayloader)(nil) type limitedPayloader struct { - io.LimitedReader + inner io.LimitedReader + err error } -func (l limitedPayloader) Len() int { - return int(l.N) +func (l *limitedPayloader) Read(p []byte) (int, error) { + n, err := l.inner.Read(p) + l.err = err + return n, err +} + +func (l *limitedPayloader) Len() int { + return int(l.inner.N) } // ReadFrom implements fs.FileOperations.ReadFrom. func (s *SocketOperations) ReadFrom(ctx context.Context, _ *fs.File, r io.Reader, count int64) (int64, error) { f := limitedPayloader{ - LimitedReader: io.LimitedReader{ + inner: io.LimitedReader{ R: r, N: count, }, @@ -479,8 +486,8 @@ func (s *SocketOperations) ReadFrom(ctx context.Context, _ *fs.File, r io.Reader // so we can't release the lock while copying data. 
Atomic: true, }) - if err == tcpip.ErrBadBuffer { - err = nil + if _, ok := err.(*tcpip.ErrBadBuffer); ok { + return n, f.err } return n, syserr.TranslateNetstackError(err).ToError() } @@ -526,7 +533,7 @@ func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool if family == linux.AF_UNSPEC { err := s.Endpoint.Disconnect() - if err == tcpip.ErrNotSupported { + if _, ok := err.(*tcpip.ErrNotSupported); ok { return syserr.ErrAddressFamilyNotSupported } return syserr.TranslateNetstackError(err) @@ -548,15 +555,16 @@ func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool s.EventRegister(&e, waiter.EventOut) defer s.EventUnregister(&e) - if err := s.Endpoint.Connect(addr); err != tcpip.ErrConnectStarted && err != tcpip.ErrAlreadyConnecting { + switch err := s.Endpoint.Connect(addr); err.(type) { + case *tcpip.ErrConnectStarted, *tcpip.ErrAlreadyConnecting: + case *tcpip.ErrNoPortAvailable: if (s.family == unix.AF_INET || s.family == unix.AF_INET6) && s.skType == linux.SOCK_STREAM { // TCP unlike UDP returns EADDRNOTAVAIL when it can't // find an available local ephemeral port. - if err == tcpip.ErrNoPortAvailable { - return syserr.ErrAddressNotAvailable - } + return syserr.ErrAddressNotAvailable } - + return syserr.TranslateNetstackError(err) + default: return syserr.TranslateNetstackError(err) } @@ -614,16 +622,16 @@ func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { // Issue the bind request to the endpoint. err := s.Endpoint.Bind(addr) - if err == tcpip.ErrNoPortAvailable { + if _, ok := err.(*tcpip.ErrNoPortAvailable); ok { // Bind always returns EADDRINUSE irrespective of if the specified port was // already bound or if an ephemeral port was requested but none were // available. 
// - // tcpip.ErrNoPortAvailable is mapped to EAGAIN in syserr package because + // *tcpip.ErrNoPortAvailable is mapped to EAGAIN in syserr package because // UDP connect returns EAGAIN on ephemeral port exhaustion. // // TCP connect returns EADDRNOTAVAIL on ephemeral port exhaustion. - err = tcpip.ErrPortInUse + err = &tcpip.ErrPortInUse{} } return syserr.TranslateNetstackError(err) @@ -646,7 +654,8 @@ func (s *socketOpsCommon) blockingAccept(t *kernel.Task, peerAddr *tcpip.FullAdd // Try to accept the connection again; if it fails, then wait until we // get a notification. for { - if ep, wq, err := s.Endpoint.Accept(peerAddr); err != tcpip.ErrWouldBlock { + ep, wq, err := s.Endpoint.Accept(peerAddr) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { return ep, wq, syserr.TranslateNetstackError(err) } @@ -665,7 +674,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, } ep, wq, terr := s.Endpoint.Accept(peerAddr) if terr != nil { - if terr != tcpip.ErrWouldBlock || !blocking { + if _, ok := terr.(*tcpip.ErrWouldBlock); !ok || !blocking { return 0, nil, 0, syserr.TranslateNetstackError(terr) } @@ -1098,6 +1107,29 @@ func getSockOptTCP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name, // TODO(b/64800844): Translate fields once they are added to // tcpip.TCPInfoOption. info := linux.TCPInfo{} + switch v.CcState { + case tcpip.RTORecovery: + info.CaState = linux.TCP_CA_Loss + case tcpip.FastRecovery, tcpip.SACKRecovery: + info.CaState = linux.TCP_CA_Recovery + case tcpip.Disorder: + info.CaState = linux.TCP_CA_Disorder + case tcpip.Open: + info.CaState = linux.TCP_CA_Open + } + info.RTO = uint32(v.RTO / time.Microsecond) + info.RTT = uint32(v.RTT / time.Microsecond) + info.RTTVar = uint32(v.RTTVar / time.Microsecond) + info.SndSsthresh = v.SndSsthresh + info.SndCwnd = v.SndCwnd + + // In netstack reorderSeen is updated only when RACK is enabled. 
+ // We only track whether the reordering is seen, which is + // different than Linux where reorderSeen is not specific to + // RACK and is incremented when a reordering event is seen. + if v.ReorderSeen { + info.ReordSeen = 1 + } // Linux truncates the output binary to outLen. buf := t.CopyScratchBuffer(info.SizeBytes()) @@ -2534,7 +2566,7 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq defer s.readMu.Unlock() res, err := s.Endpoint.Read(w, readOptions) - if err == tcpip.ErrBadBuffer && dst.NumBytes() == 0 { + if _, ok := err.(*tcpip.ErrBadBuffer); ok && dst.NumBytes() == 0 { err = nil } if err != nil { @@ -2634,9 +2666,9 @@ func (s *socketOpsCommon) dequeueErr() *tcpip.SockError { } // Update socket error to reflect ICMP errors in queue. - if nextErr := so.PeekErr(); nextErr != nil && nextErr.ErrOrigin.IsICMPErr() { + if nextErr := so.PeekErr(); nextErr != nil && nextErr.Cause.Origin().IsICMPErr() { so.SetLastError(nextErr.Err) - } else if err.ErrOrigin.IsICMPErr() { + } else if err.Cause.Origin().IsICMPErr() { so.SetLastError(nil) } return err @@ -2790,13 +2822,15 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b if flags&linux.MSG_DONTWAIT != 0 { return int(total), syserr.TranslateNetstackError(err) } - switch err { + block := true + switch err.(type) { case nil: - if total == src.NumBytes() { - break - } - fallthrough - case tcpip.ErrWouldBlock: + block = total != src.NumBytes() + case *tcpip.ErrWouldBlock: + default: + block = false + } + if block { if ch == nil { // We'll have to block. Register for notification and keep trying to // send all the data. 
diff --git a/pkg/sentry/socket/netstack/netstack_vfs2.go b/pkg/sentry/socket/netstack/netstack_vfs2.go index 3bbdf552e..24922c400 100644 --- a/pkg/sentry/socket/netstack/netstack_vfs2.go +++ b/pkg/sentry/socket/netstack/netstack_vfs2.go @@ -130,7 +130,7 @@ func (s *SocketVFS2) Write(ctx context.Context, src usermem.IOSequence, opts vfs r := src.Reader(ctx) n, err := s.Endpoint.Write(r, tcpip.WriteOptions{}) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { return 0, syserror.ErrWouldBlock } if err != nil { @@ -154,7 +154,7 @@ func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, block } ep, wq, terr := s.Endpoint.Accept(peerAddr) if terr != nil { - if terr != tcpip.ErrWouldBlock || !blocking { + if _, ok := terr.(*tcpip.ErrWouldBlock); !ok || !blocking { return 0, nil, 0, syserr.TranslateNetstackError(terr) } diff --git a/pkg/sentry/socket/netstack/provider.go b/pkg/sentry/socket/netstack/provider.go index c847ff1c7..2515dda80 100644 --- a/pkg/sentry/socket/netstack/provider.go +++ b/pkg/sentry/socket/netstack/provider.go @@ -118,7 +118,7 @@ func (p *provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (* // Create the endpoint. var ep tcpip.Endpoint - var e *tcpip.Error + var e tcpip.Error wq := &waiter.Queue{} if stype == linux.SOCK_RAW { ep, e = eps.Stack.NewRawEndpoint(transProto, p.netProto, wq, associated) diff --git a/pkg/sentry/socket/netstack/provider_vfs2.go b/pkg/sentry/socket/netstack/provider_vfs2.go index 0af805246..ba1cc79e9 100644 --- a/pkg/sentry/socket/netstack/provider_vfs2.go +++ b/pkg/sentry/socket/netstack/provider_vfs2.go @@ -62,7 +62,7 @@ func (p *providerVFS2) Socket(t *kernel.Task, stype linux.SockType, protocol int // Create the endpoint. 
var ep tcpip.Endpoint - var e *tcpip.Error + var e tcpip.Error wq := &waiter.Queue{} if stype == linux.SOCK_RAW { ep, e = eps.Stack.NewRawEndpoint(transProto, p.netProto, wq, associated) diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go index 97729dacc..cc535d794 100644 --- a/pkg/sentry/socket/socket.go +++ b/pkg/sentry/socket/socket.go @@ -81,10 +81,10 @@ func sockErrCmsgToLinux(sockErr *tcpip.SockError) linux.SockErrCMsg { ee := linux.SockExtendedErr{ Errno: uint32(syserr.TranslateNetstackError(sockErr.Err).ToLinux().Number()), - Origin: errOriginToLinux(sockErr.ErrOrigin), - Type: sockErr.ErrType, - Code: sockErr.ErrCode, - Info: sockErr.ErrInfo, + Origin: errOriginToLinux(sockErr.Cause.Origin()), + Type: sockErr.Cause.Type(), + Code: sockErr.Cause.Code(), + Info: sockErr.Cause.Info(), } switch sockErr.NetProto { diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go index b011082dc..fc5b823b0 100644 --- a/pkg/sentry/socket/unix/transport/connectioned.go +++ b/pkg/sentry/socket/unix/transport/connectioned.go @@ -48,7 +48,7 @@ type ConnectingEndpoint interface { Type() linux.SockType // GetLocalAddress returns the bound path. - GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) + GetLocalAddress() (tcpip.FullAddress, tcpip.Error) // Locker protects the following methods. While locked, only the holder of // the lock can change the return value of the protected methods. diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index 0e3889c6d..70227bbd2 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -169,32 +169,32 @@ type Endpoint interface { Type() linux.SockType // GetLocalAddress returns the address to which the endpoint is bound. 
- GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) + GetLocalAddress() (tcpip.FullAddress, tcpip.Error) // GetRemoteAddress returns the address to which the endpoint is // connected. - GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) + GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) // SetSockOpt sets a socket option. - SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error + SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error // SetSockOptInt sets a socket option for simple cases when a value has // the int type. - SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error + SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error // GetSockOpt gets a socket option. - GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error + GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error // GetSockOptInt gets a socket option for simple cases when a return // value has the int type. - GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) + GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) // State returns the current state of the socket, as represented by Linux in // procfs. State() uint32 // LastError clears and returns the last error reported by the endpoint. - LastError() *tcpip.Error + LastError() tcpip.Error // SocketOptions returns the structure which contains all the socket // level options. @@ -580,7 +580,7 @@ type ConnectedEndpoint interface { Passcred() bool // GetLocalAddress implements Endpoint.GetLocalAddress. - GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) + GetLocalAddress() (tcpip.FullAddress, tcpip.Error) // Send sends a single message. This method does not block. // @@ -640,7 +640,7 @@ type connectedEndpoint struct { Passcred() bool // GetLocalAddress implements Endpoint.GetLocalAddress. - GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) + GetLocalAddress() (tcpip.FullAddress, tcpip.Error) // Type implements Endpoint.Type. 
Type() linux.SockType @@ -655,7 +655,7 @@ func (e *connectedEndpoint) Passcred() bool { } // GetLocalAddress implements ConnectedEndpoint.GetLocalAddress. -func (e *connectedEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { +func (e *connectedEndpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) { return e.endpoint.GetLocalAddress() } @@ -836,11 +836,11 @@ func (e *baseEndpoint) SendMsg(ctx context.Context, data [][]byte, c ControlMess } // SetSockOpt sets a socket option. -func (e *baseEndpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { +func (e *baseEndpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { return nil } -func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { +func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { switch opt { case tcpip.ReceiveBufferSizeOption: default: @@ -855,34 +855,34 @@ func (e *baseEndpoint) IsUnixSocket() bool { } // GetSendBufferSize implements tcpip.SocketOptionsHandler.GetSendBufferSize. 
-func (e *baseEndpoint) GetSendBufferSize() (int64, *tcpip.Error) { +func (e *baseEndpoint) GetSendBufferSize() (int64, tcpip.Error) { e.Lock() defer e.Unlock() if !e.Connected() { - return -1, tcpip.ErrNotConnected + return -1, &tcpip.ErrNotConnected{} } v := e.connected.SendMaxQueueSize() if v < 0 { - return -1, tcpip.ErrQueueSizeNotSupported + return -1, &tcpip.ErrQueueSizeNotSupported{} } return v, nil } -func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { +func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { switch opt { case tcpip.ReceiveQueueSizeOption: v := 0 e.Lock() if !e.Connected() { e.Unlock() - return -1, tcpip.ErrNotConnected + return -1, &tcpip.ErrNotConnected{} } v = int(e.receiver.RecvQueuedSize()) e.Unlock() if v < 0 { - return -1, tcpip.ErrQueueSizeNotSupported + return -1, &tcpip.ErrQueueSizeNotSupported{} } return v, nil @@ -890,12 +890,12 @@ func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { e.Lock() if !e.Connected() { e.Unlock() - return -1, tcpip.ErrNotConnected + return -1, &tcpip.ErrNotConnected{} } v := e.connected.SendQueuedSize() e.Unlock() if v < 0 { - return -1, tcpip.ErrQueueSizeNotSupported + return -1, &tcpip.ErrQueueSizeNotSupported{} } return int(v), nil @@ -903,29 +903,29 @@ func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { e.Lock() if e.receiver == nil { e.Unlock() - return -1, tcpip.ErrNotConnected + return -1, &tcpip.ErrNotConnected{} } v := e.receiver.RecvMaxQueueSize() e.Unlock() if v < 0 { - return -1, tcpip.ErrQueueSizeNotSupported + return -1, &tcpip.ErrQueueSizeNotSupported{} } return int(v), nil default: log.Warningf("Unsupported socket option: %d", opt) - return -1, tcpip.ErrUnknownProtocolOption + return -1, &tcpip.ErrUnknownProtocolOption{} } } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. 
-func (e *baseEndpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { +func (e *baseEndpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { log.Warningf("Unsupported socket option: %T", opt) - return tcpip.ErrUnknownProtocolOption + return &tcpip.ErrUnknownProtocolOption{} } // LastError implements Endpoint.LastError. -func (*baseEndpoint) LastError() *tcpip.Error { +func (*baseEndpoint) LastError() tcpip.Error { return nil } @@ -965,7 +965,7 @@ func (e *baseEndpoint) Shutdown(flags tcpip.ShutdownFlags) *syserr.Error { } // GetLocalAddress returns the bound path. -func (e *baseEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { +func (e *baseEndpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) { e.Lock() defer e.Unlock() return tcpip.FullAddress{Addr: tcpip.Address(e.path)}, nil @@ -973,14 +973,14 @@ func (e *baseEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { // GetRemoteAddress returns the local address of the connected endpoint (if // available). -func (e *baseEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { +func (e *baseEndpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) { e.Lock() c := e.connected e.Unlock() if c != nil { return c.GetLocalAddress() } - return tcpip.FullAddress{}, tcpip.ErrNotConnected + return tcpip.FullAddress{}, &tcpip.ErrNotConnected{} } // Release implements BoundEndpoint.Release. diff --git a/pkg/sentry/syscalls/linux/error.go b/pkg/sentry/syscalls/linux/error.go index dab6207c0..d1778d029 100644 --- a/pkg/sentry/syscalls/linux/error.go +++ b/pkg/sentry/syscalls/linux/error.go @@ -134,8 +134,8 @@ func handleIOErrorImpl(t *kernel.Task, partialResult bool, err, intr error, op s // Similar to EPIPE. Return what we wrote this time, and let // ENOSPC be returned on the next call. return true, nil - case syserror.ECONNRESET: - // For TCP sendfile connections, we may have a reset. 
But we + case syserror.ECONNRESET, syserror.ETIMEDOUT: + // For TCP sendfile connections, we may have a reset or timeout. But we // should just return n as the result. return true, nil case syserror.ErrWouldBlock: diff --git a/pkg/sentry/syscalls/linux/vfs2/fd.go b/pkg/sentry/syscalls/linux/vfs2/fd.go index e39f074f2..1a31898e8 100644 --- a/pkg/sentry/syscalls/linux/vfs2/fd.go +++ b/pkg/sentry/syscalls/linux/vfs2/fd.go @@ -123,6 +123,15 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall } defer file.DecRef(t) + if file.StatusFlags()&linux.O_PATH != 0 { + switch cmd { + case linux.F_DUPFD, linux.F_DUPFD_CLOEXEC, linux.F_GETFD, linux.F_SETFD, linux.F_GETFL: + // allowed + default: + return 0, nil, syserror.EBADF + } + } + switch cmd { case linux.F_DUPFD, linux.F_DUPFD_CLOEXEC: minfd := args[2].Int() @@ -395,6 +404,10 @@ func Fadvise64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys } defer file.DecRef(t) + if file.StatusFlags()&linux.O_PATH != 0 { + return 0, nil, syserror.EBADF + } + // If the FD refers to a pipe or FIFO, return error. if _, isPipe := file.Impl().(*pipe.VFSPipeFD); isPipe { return 0, nil, syserror.ESPIPE diff --git a/pkg/sentry/syscalls/linux/vfs2/ioctl.go b/pkg/sentry/syscalls/linux/vfs2/ioctl.go index 20c264fef..c7c3fed57 100644 --- a/pkg/sentry/syscalls/linux/vfs2/ioctl.go +++ b/pkg/sentry/syscalls/linux/vfs2/ioctl.go @@ -32,6 +32,10 @@ func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall } defer file.DecRef(t) + if file.StatusFlags()&linux.O_PATH != 0 { + return 0, nil, syserror.EBADF + } + // Handle ioctls that apply to all FDs. 
switch args[1].Int() { case linux.FIONCLEX: diff --git a/pkg/sentry/syscalls/linux/vfs2/sync.go b/pkg/sentry/syscalls/linux/vfs2/sync.go index 6e9b599e2..1f8a5878c 100644 --- a/pkg/sentry/syscalls/linux/vfs2/sync.go +++ b/pkg/sentry/syscalls/linux/vfs2/sync.go @@ -36,6 +36,10 @@ func Syncfs(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal } defer file.DecRef(t) + if file.StatusFlags()&linux.O_PATH != 0 { + return 0, nil, syserror.EBADF + } + return 0, nil, file.SyncFS(t) } diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD index a3868bf16..df4990854 100644 --- a/pkg/sentry/vfs/BUILD +++ b/pkg/sentry/vfs/BUILD @@ -83,6 +83,7 @@ go_library( "mount.go", "mount_namespace_refs.go", "mount_unsafe.go", + "opath.go", "options.go", "pathname.go", "permissions.go", diff --git a/pkg/sentry/vfs/opath.go b/pkg/sentry/vfs/opath.go new file mode 100644 index 000000000..39fbac987 --- /dev/null +++ b/pkg/sentry/vfs/opath.go @@ -0,0 +1,139 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vfs + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/memmap" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" +) + +// opathFD implements vfs.FileDescriptionImpl for a file description opened with O_PATH. 
+// +// +stateify savable +type opathFD struct { + vfsfd FileDescription + FileDescriptionDefaultImpl + NoLockFD +} + +// Release implements vfs.FileDescriptionImpl.Release. +func (fd *opathFD) Release(context.Context) { + // noop +} + +// Allocate implements vfs.FileDescriptionImpl.Allocate. +func (fd *opathFD) Allocate(ctx context.Context, mode, offset, length uint64) error { + return syserror.EBADF +} + +// PRead implements vfs.FileDescriptionImpl.PRead. +func (fd *opathFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts ReadOptions) (int64, error) { + return 0, syserror.EBADF +} + +// Read implements vfs.FileDescriptionImpl.Read. +func (fd *opathFD) Read(ctx context.Context, dst usermem.IOSequence, opts ReadOptions) (int64, error) { + return 0, syserror.EBADF +} + +// PWrite implements vfs.FileDescriptionImpl.PWrite. +func (fd *opathFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts WriteOptions) (int64, error) { + return 0, syserror.EBADF +} + +// Write implements vfs.FileDescriptionImpl.Write. +func (fd *opathFD) Write(ctx context.Context, src usermem.IOSequence, opts WriteOptions) (int64, error) { + return 0, syserror.EBADF +} + +// Ioctl implements vfs.FileDescriptionImpl.Ioctl. +func (fd *opathFD) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) { + return 0, syserror.EBADF +} + +// IterDirents implements vfs.FileDescriptionImpl.IterDirents. +func (fd *opathFD) IterDirents(ctx context.Context, cb IterDirentsCallback) error { + return syserror.EBADF +} + +// Seek implements vfs.FileDescriptionImpl.Seek. +func (fd *opathFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) { + return 0, syserror.EBADF +} + +// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap. +func (fd *opathFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error { + return syserror.EBADF +} + +// ListXattr implements vfs.FileDescriptionImpl.ListXattr. 
+func (fd *opathFD) ListXattr(ctx context.Context, size uint64) ([]string, error) { + return nil, syserror.EBADF +} + +// GetXattr implements vfs.FileDescriptionImpl.GetXattr. +func (fd *opathFD) GetXattr(ctx context.Context, opts GetXattrOptions) (string, error) { + return "", syserror.EBADF +} + +// SetXattr implements vfs.FileDescriptionImpl.SetXattr. +func (fd *opathFD) SetXattr(ctx context.Context, opts SetXattrOptions) error { + return syserror.EBADF +} + +// RemoveXattr implements vfs.FileDescriptionImpl.RemoveXattr. +func (fd *opathFD) RemoveXattr(ctx context.Context, name string) error { + return syserror.EBADF +} + +// Sync implements vfs.FileDescriptionImpl.Sync. +func (fd *opathFD) Sync(ctx context.Context) error { + return syserror.EBADF +} + +// SetStat implements vfs.FileDescriptionImpl.SetStat. +func (fd *opathFD) SetStat(ctx context.Context, opts SetStatOptions) error { + return syserror.EBADF +} + +// Stat implements vfs.FileDescriptionImpl.Stat. +func (fd *opathFD) Stat(ctx context.Context, opts StatOptions) (linux.Statx, error) { + vfsObj := fd.vfsfd.vd.mount.vfs + rp := vfsObj.getResolvingPath(auth.CredentialsFromContext(ctx), &PathOperation{ + Root: fd.vfsfd.vd, + Start: fd.vfsfd.vd, + }) + stat, err := fd.vfsfd.vd.mount.fs.impl.StatAt(ctx, rp, opts) + vfsObj.putResolvingPath(ctx, rp) + return stat, err +} + +// StatFS returns metadata for the filesystem containing the file represented +// by fd. 
+func (fd *opathFD) StatFS(ctx context.Context) (linux.Statfs, error) { + vfsObj := fd.vfsfd.vd.mount.vfs + rp := vfsObj.getResolvingPath(auth.CredentialsFromContext(ctx), &PathOperation{ + Root: fd.vfsfd.vd, + Start: fd.vfsfd.vd, + }) + statfs, err := fd.vfsfd.vd.mount.fs.impl.StatFSAt(ctx, rp) + vfsObj.putResolvingPath(ctx, rp) + return statfs, err +} diff --git a/pkg/sentry/vfs/options.go b/pkg/sentry/vfs/options.go index bc79e5ecc..c9907843c 100644 --- a/pkg/sentry/vfs/options.go +++ b/pkg/sentry/vfs/options.go @@ -129,7 +129,7 @@ type OpenOptions struct { // // FilesystemImpls are responsible for implementing the following flags: // O_RDONLY, O_WRONLY, O_RDWR, O_APPEND, O_CREAT, O_DIRECT, O_DSYNC, - // O_EXCL, O_NOATIME, O_NOCTTY, O_NONBLOCK, O_PATH, O_SYNC, O_TMPFILE, and + // O_EXCL, O_NOATIME, O_NOCTTY, O_NONBLOCK, O_SYNC, O_TMPFILE, and // O_TRUNC. VFS is responsible for handling O_DIRECTORY, O_LARGEFILE, and // O_NOFOLLOW. VFS users are responsible for handling O_CLOEXEC, since file // descriptors are mostly outside the scope of VFS. 
diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go index 6fd1bb0b2..0aff2dd92 100644 --- a/pkg/sentry/vfs/vfs.go +++ b/pkg/sentry/vfs/vfs.go @@ -425,6 +425,18 @@ func (vfs *VirtualFilesystem) OpenAt(ctx context.Context, creds *auth.Credential rp.mustBeDir = true rp.mustBeDirOrig = true } + if opts.Flags&linux.O_PATH != 0 { + vd, err := vfs.GetDentryAt(ctx, creds, pop, &GetDentryOptions{}) + if err != nil { + return nil, err + } + fd := &opathFD{} + if err := fd.vfsfd.Init(fd, opts.Flags, vd.Mount(), vd.Dentry(), &FileDescriptionOptions{}); err != nil { + return nil, err + } + vd.DecRef(ctx) + return &fd.vfsfd, err + } for { fd, err := rp.mount.fs.impl.OpenAt(ctx, rp, *opts) if err == nil { diff --git a/pkg/sync/BUILD b/pkg/sync/BUILD index 28e62abbb..2e2395807 100644 --- a/pkg/sync/BUILD +++ b/pkg/sync/BUILD @@ -67,6 +67,7 @@ go_library( ], marshal = False, stateify = False, + visibility = ["//:sandbox"], deps = [ "//pkg/goid", ], diff --git a/pkg/syserr/netstack.go b/pkg/syserr/netstack.go index cb8981633..a6a91e064 100644 --- a/pkg/syserr/netstack.go +++ b/pkg/syserr/netstack.go @@ -23,95 +23,114 @@ import ( // Mapping for tcpip.Error types. 
var ( - ErrUnknownProtocol = New(tcpip.ErrUnknownProtocol.String(), linux.EINVAL) - ErrUnknownNICID = New(tcpip.ErrUnknownNICID.String(), linux.ENODEV) - ErrUnknownDevice = New(tcpip.ErrUnknownDevice.String(), linux.ENODEV) - ErrUnknownProtocolOption = New(tcpip.ErrUnknownProtocolOption.String(), linux.ENOPROTOOPT) - ErrDuplicateNICID = New(tcpip.ErrDuplicateNICID.String(), linux.EEXIST) - ErrDuplicateAddress = New(tcpip.ErrDuplicateAddress.String(), linux.EEXIST) - ErrBadLinkEndpoint = New(tcpip.ErrBadLinkEndpoint.String(), linux.EINVAL) - ErrAlreadyBound = New(tcpip.ErrAlreadyBound.String(), linux.EINVAL) - ErrInvalidEndpointState = New(tcpip.ErrInvalidEndpointState.String(), linux.EINVAL) - ErrAlreadyConnecting = New(tcpip.ErrAlreadyConnecting.String(), linux.EALREADY) - ErrNoPortAvailable = New(tcpip.ErrNoPortAvailable.String(), linux.EAGAIN) - ErrPortInUse = New(tcpip.ErrPortInUse.String(), linux.EADDRINUSE) - ErrBadLocalAddress = New(tcpip.ErrBadLocalAddress.String(), linux.EADDRNOTAVAIL) - ErrClosedForSend = New(tcpip.ErrClosedForSend.String(), linux.EPIPE) - ErrClosedForReceive = New(tcpip.ErrClosedForReceive.String(), nil) - ErrTimeout = New(tcpip.ErrTimeout.String(), linux.ETIMEDOUT) - ErrAborted = New(tcpip.ErrAborted.String(), linux.EPIPE) - ErrConnectStarted = New(tcpip.ErrConnectStarted.String(), linux.EINPROGRESS) - ErrDestinationRequired = New(tcpip.ErrDestinationRequired.String(), linux.EDESTADDRREQ) - ErrNotSupported = New(tcpip.ErrNotSupported.String(), linux.EOPNOTSUPP) - ErrQueueSizeNotSupported = New(tcpip.ErrQueueSizeNotSupported.String(), linux.ENOTTY) - ErrNoSuchFile = New(tcpip.ErrNoSuchFile.String(), linux.ENOENT) - ErrInvalidOptionValue = New(tcpip.ErrInvalidOptionValue.String(), linux.EINVAL) - ErrBroadcastDisabled = New(tcpip.ErrBroadcastDisabled.String(), linux.EACCES) - ErrNotPermittedNet = New(tcpip.ErrNotPermitted.String(), linux.EPERM) - ErrBadBuffer = New(tcpip.ErrBadBuffer.String(), linux.EFAULT) + ErrUnknownProtocol = 
New((&tcpip.ErrUnknownProtocol{}).String(), linux.EINVAL) + ErrUnknownNICID = New((&tcpip.ErrUnknownNICID{}).String(), linux.ENODEV) + ErrUnknownDevice = New((&tcpip.ErrUnknownDevice{}).String(), linux.ENODEV) + ErrUnknownProtocolOption = New((&tcpip.ErrUnknownProtocolOption{}).String(), linux.ENOPROTOOPT) + ErrDuplicateNICID = New((&tcpip.ErrDuplicateNICID{}).String(), linux.EEXIST) + ErrDuplicateAddress = New((&tcpip.ErrDuplicateAddress{}).String(), linux.EEXIST) + ErrAlreadyBound = New((&tcpip.ErrAlreadyBound{}).String(), linux.EINVAL) + ErrInvalidEndpointState = New((&tcpip.ErrInvalidEndpointState{}).String(), linux.EINVAL) + ErrAlreadyConnecting = New((&tcpip.ErrAlreadyConnecting{}).String(), linux.EALREADY) + ErrNoPortAvailable = New((&tcpip.ErrNoPortAvailable{}).String(), linux.EAGAIN) + ErrPortInUse = New((&tcpip.ErrPortInUse{}).String(), linux.EADDRINUSE) + ErrBadLocalAddress = New((&tcpip.ErrBadLocalAddress{}).String(), linux.EADDRNOTAVAIL) + ErrClosedForSend = New((&tcpip.ErrClosedForSend{}).String(), linux.EPIPE) + ErrClosedForReceive = New((&tcpip.ErrClosedForReceive{}).String(), nil) + ErrTimeout = New((&tcpip.ErrTimeout{}).String(), linux.ETIMEDOUT) + ErrAborted = New((&tcpip.ErrAborted{}).String(), linux.EPIPE) + ErrConnectStarted = New((&tcpip.ErrConnectStarted{}).String(), linux.EINPROGRESS) + ErrDestinationRequired = New((&tcpip.ErrDestinationRequired{}).String(), linux.EDESTADDRREQ) + ErrNotSupported = New((&tcpip.ErrNotSupported{}).String(), linux.EOPNOTSUPP) + ErrQueueSizeNotSupported = New((&tcpip.ErrQueueSizeNotSupported{}).String(), linux.ENOTTY) + ErrNoSuchFile = New((&tcpip.ErrNoSuchFile{}).String(), linux.ENOENT) + ErrInvalidOptionValue = New((&tcpip.ErrInvalidOptionValue{}).String(), linux.EINVAL) + ErrBroadcastDisabled = New((&tcpip.ErrBroadcastDisabled{}).String(), linux.EACCES) + ErrNotPermittedNet = New((&tcpip.ErrNotPermitted{}).String(), linux.EPERM) + ErrBadBuffer = New((&tcpip.ErrBadBuffer{}).String(), linux.EFAULT) ) -var 
netstackErrorTranslations map[string]*Error - -func addErrMapping(tcpipErr *tcpip.Error, netstackErr *Error) { - key := tcpipErr.String() - if _, ok := netstackErrorTranslations[key]; ok { - panic(fmt.Sprintf("duplicate error key: %s", key)) - } - netstackErrorTranslations[key] = netstackErr -} - -func init() { - netstackErrorTranslations = make(map[string]*Error) - addErrMapping(tcpip.ErrUnknownProtocol, ErrUnknownProtocol) - addErrMapping(tcpip.ErrUnknownNICID, ErrUnknownNICID) - addErrMapping(tcpip.ErrUnknownDevice, ErrUnknownDevice) - addErrMapping(tcpip.ErrUnknownProtocolOption, ErrUnknownProtocolOption) - addErrMapping(tcpip.ErrDuplicateNICID, ErrDuplicateNICID) - addErrMapping(tcpip.ErrDuplicateAddress, ErrDuplicateAddress) - addErrMapping(tcpip.ErrNoRoute, ErrNoRoute) - addErrMapping(tcpip.ErrBadLinkEndpoint, ErrBadLinkEndpoint) - addErrMapping(tcpip.ErrAlreadyBound, ErrAlreadyBound) - addErrMapping(tcpip.ErrInvalidEndpointState, ErrInvalidEndpointState) - addErrMapping(tcpip.ErrAlreadyConnecting, ErrAlreadyConnecting) - addErrMapping(tcpip.ErrAlreadyConnected, ErrAlreadyConnected) - addErrMapping(tcpip.ErrNoPortAvailable, ErrNoPortAvailable) - addErrMapping(tcpip.ErrPortInUse, ErrPortInUse) - addErrMapping(tcpip.ErrBadLocalAddress, ErrBadLocalAddress) - addErrMapping(tcpip.ErrClosedForSend, ErrClosedForSend) - addErrMapping(tcpip.ErrClosedForReceive, ErrClosedForReceive) - addErrMapping(tcpip.ErrWouldBlock, ErrWouldBlock) - addErrMapping(tcpip.ErrConnectionRefused, ErrConnectionRefused) - addErrMapping(tcpip.ErrTimeout, ErrTimeout) - addErrMapping(tcpip.ErrAborted, ErrAborted) - addErrMapping(tcpip.ErrConnectStarted, ErrConnectStarted) - addErrMapping(tcpip.ErrDestinationRequired, ErrDestinationRequired) - addErrMapping(tcpip.ErrNotSupported, ErrNotSupported) - addErrMapping(tcpip.ErrQueueSizeNotSupported, ErrQueueSizeNotSupported) - addErrMapping(tcpip.ErrNotConnected, ErrNotConnected) - addErrMapping(tcpip.ErrConnectionReset, ErrConnectionReset) - 
addErrMapping(tcpip.ErrConnectionAborted, ErrConnectionAborted) - addErrMapping(tcpip.ErrNoSuchFile, ErrNoSuchFile) - addErrMapping(tcpip.ErrInvalidOptionValue, ErrInvalidOptionValue) - addErrMapping(tcpip.ErrBadAddress, ErrBadAddress) - addErrMapping(tcpip.ErrNetworkUnreachable, ErrNetworkUnreachable) - addErrMapping(tcpip.ErrMessageTooLong, ErrMessageTooLong) - addErrMapping(tcpip.ErrNoBufferSpace, ErrNoBufferSpace) - addErrMapping(tcpip.ErrBroadcastDisabled, ErrBroadcastDisabled) - addErrMapping(tcpip.ErrNotPermitted, ErrNotPermittedNet) - addErrMapping(tcpip.ErrAddressFamilyNotSupported, ErrAddressFamilyNotSupported) - addErrMapping(tcpip.ErrBadBuffer, ErrBadBuffer) -} - // TranslateNetstackError converts an error from the tcpip package to a sentry // internal error. -func TranslateNetstackError(err *tcpip.Error) *Error { - if err == nil { +func TranslateNetstackError(err tcpip.Error) *Error { + switch err.(type) { + case nil: return nil + case *tcpip.ErrUnknownProtocol: + return ErrUnknownProtocol + case *tcpip.ErrUnknownNICID: + return ErrUnknownNICID + case *tcpip.ErrUnknownDevice: + return ErrUnknownDevice + case *tcpip.ErrUnknownProtocolOption: + return ErrUnknownProtocolOption + case *tcpip.ErrDuplicateNICID: + return ErrDuplicateNICID + case *tcpip.ErrDuplicateAddress: + return ErrDuplicateAddress + case *tcpip.ErrNoRoute: + return ErrNoRoute + case *tcpip.ErrAlreadyBound: + return ErrAlreadyBound + case *tcpip.ErrInvalidEndpointState: + return ErrInvalidEndpointState + case *tcpip.ErrAlreadyConnecting: + return ErrAlreadyConnecting + case *tcpip.ErrAlreadyConnected: + return ErrAlreadyConnected + case *tcpip.ErrNoPortAvailable: + return ErrNoPortAvailable + case *tcpip.ErrPortInUse: + return ErrPortInUse + case *tcpip.ErrBadLocalAddress: + return ErrBadLocalAddress + case *tcpip.ErrClosedForSend: + return ErrClosedForSend + case *tcpip.ErrClosedForReceive: + return ErrClosedForReceive + case *tcpip.ErrWouldBlock: + return ErrWouldBlock + case 
*tcpip.ErrConnectionRefused: + return ErrConnectionRefused + case *tcpip.ErrTimeout: + return ErrTimeout + case *tcpip.ErrAborted: + return ErrAborted + case *tcpip.ErrConnectStarted: + return ErrConnectStarted + case *tcpip.ErrDestinationRequired: + return ErrDestinationRequired + case *tcpip.ErrNotSupported: + return ErrNotSupported + case *tcpip.ErrQueueSizeNotSupported: + return ErrQueueSizeNotSupported + case *tcpip.ErrNotConnected: + return ErrNotConnected + case *tcpip.ErrConnectionReset: + return ErrConnectionReset + case *tcpip.ErrConnectionAborted: + return ErrConnectionAborted + case *tcpip.ErrNoSuchFile: + return ErrNoSuchFile + case *tcpip.ErrInvalidOptionValue: + return ErrInvalidOptionValue + case *tcpip.ErrBadAddress: + return ErrBadAddress + case *tcpip.ErrNetworkUnreachable: + return ErrNetworkUnreachable + case *tcpip.ErrMessageTooLong: + return ErrMessageTooLong + case *tcpip.ErrNoBufferSpace: + return ErrNoBufferSpace + case *tcpip.ErrBroadcastDisabled: + return ErrBroadcastDisabled + case *tcpip.ErrNotPermitted: + return ErrNotPermittedNet + case *tcpip.ErrAddressFamilyNotSupported: + return ErrAddressFamilyNotSupported + case *tcpip.ErrBadBuffer: + return ErrBadBuffer + default: + panic(fmt.Sprintf("unknown error %T", err)) } - se, ok := netstackErrorTranslations[err.String()] - if !ok { - panic("Unknown error: " + err.String()) - } - return se } diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD index e7924e5c2..f979d22f0 100644 --- a/pkg/tcpip/BUILD +++ b/pkg/tcpip/BUILD @@ -18,6 +18,7 @@ go_template_instance( go_library( name = "tcpip", srcs = [ + "errors.go", "sock_err_list.go", "socketops.go", "tcpip.go", diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go index 7c7495c30..c188aaa18 100644 --- a/pkg/tcpip/adapters/gonet/gonet.go +++ b/pkg/tcpip/adapters/gonet/gonet.go @@ -248,7 +248,7 @@ func NewTCPConn(wq *waiter.Queue, ep tcpip.Endpoint) *TCPConn { func (l *TCPListener) Accept() (net.Conn, error) { n, wq, err 
:= l.ep.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Create wait queue entry that notifies a channel. waitEntry, notifyCh := waiter.NewChannelEntry(nil) l.wq.EventRegister(&waitEntry, waiter.EventIn) @@ -257,7 +257,7 @@ func (l *TCPListener) Accept() (net.Conn, error) { for { n, wq, err = l.ep.Accept(nil) - if err != tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { break } @@ -298,14 +298,14 @@ func commonRead(b []byte, ep tcpip.Endpoint, wq *waiter.Queue, deadline <-chan s opts := tcpip.ReadOptions{NeedRemoteAddr: addr != nil} res, err := ep.Read(&w, opts) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Create wait queue entry that notifies a channel. waitEntry, notifyCh := waiter.NewChannelEntry(nil) wq.EventRegister(&waitEntry, waiter.EventIn) defer wq.EventUnregister(&waitEntry) for { res, err = ep.Read(&w, opts) - if err != tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { break } select { @@ -316,7 +316,7 @@ func commonRead(b []byte, ep tcpip.Endpoint, wq *waiter.Queue, deadline <-chan s } } - if err == tcpip.ErrClosedForReceive { + if _, ok := err.(*tcpip.ErrClosedForReceive); ok { return 0, io.EOF } @@ -356,7 +356,7 @@ func (c *TCPConn) Write(b []byte) (int, error) { } // We must handle two soft failure conditions simultaneously: - // 1. Write may write nothing and return tcpip.ErrWouldBlock. + // 1. Write may write nothing and return *tcpip.ErrWouldBlock. // If this happens, we need to register for notifications if we have // not already and wait to try again. // 2. 
Write may write fewer than the full number of bytes and return @@ -376,9 +376,9 @@ func (c *TCPConn) Write(b []byte) (int, error) { r.Reset(b[nbytes:]) n, err := c.ep.Write(&r, tcpip.WriteOptions{}) nbytes += int(n) - switch err { + switch err.(type) { case nil: - case tcpip.ErrWouldBlock: + case *tcpip.ErrWouldBlock: if ch == nil { entry, ch = waiter.NewChannelEntry(nil) @@ -495,7 +495,7 @@ func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress, } err = ep.Connect(addr) - if err == tcpip.ErrConnectStarted { + if _, ok := err.(*tcpip.ErrConnectStarted); ok { select { case <-ctx.Done(): ep.Close() @@ -649,7 +649,7 @@ func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (int, error) { var r bytes.Reader r.Reset(b) n, err := c.ep.Write(&r, writeOptions) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Create wait queue entry that notifies a channel. waitEntry, notifyCh := waiter.NewChannelEntry(nil) c.wq.EventRegister(&waitEntry, waiter.EventOut) @@ -662,7 +662,7 @@ func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (int, error) { } n, err = c.ep.Write(&r, writeOptions) - if err != tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { break } } diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go index b196324c7..2b3ea4bdf 100644 --- a/pkg/tcpip/adapters/gonet/gonet_test.go +++ b/pkg/tcpip/adapters/gonet/gonet_test.go @@ -58,7 +58,7 @@ func TestTimeouts(t *testing.T) { } } -func newLoopbackStack() (*stack.Stack, *tcpip.Error) { +func newLoopbackStack() (*stack.Stack, tcpip.Error) { // Create the stack and add a NIC. 
s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, @@ -94,7 +94,7 @@ type testConnection struct { ep tcpip.Endpoint } -func connect(s *stack.Stack, addr tcpip.FullAddress) (*testConnection, *tcpip.Error) { +func connect(s *stack.Stack, addr tcpip.FullAddress) (*testConnection, tcpip.Error) { wq := &waiter.Queue{} ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) if err != nil { @@ -105,7 +105,7 @@ func connect(s *stack.Stack, addr tcpip.FullAddress) (*testConnection, *tcpip.Er wq.EventRegister(&entry, waiter.EventOut) err = ep.Connect(addr) - if err == tcpip.ErrConnectStarted { + if _, ok := err.(*tcpip.ErrConnectStarted); ok { <-ch err = ep.LastError() } @@ -660,11 +660,13 @@ func TestTCPDialError(t *testing.T) { ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) addr := tcpip.FullAddress{NICID, ip, 11211} - _, err := DialTCP(s, addr, ipv4.ProtocolNumber) - got, ok := err.(*net.OpError) - want := tcpip.ErrNoRoute - if !ok || got.Err.Error() != want.String() { - t.Errorf("Got DialTCP() = %v, want = %v", err, tcpip.ErrNoRoute) + switch _, err := DialTCP(s, addr, ipv4.ProtocolNumber); err := err.(type) { + case *net.OpError: + if err.Err.Error() != (&tcpip.ErrNoRoute{}).String() { + t.Errorf("got DialTCP() = %s, want = %s", err, &tcpip.ErrNoRoute{}) + } + default: + t.Errorf("got DialTCP(...) 
= %v, want %s", err, &tcpip.ErrNoRoute{}) } } diff --git a/pkg/tcpip/buffer/BUILD b/pkg/tcpip/buffer/BUILD index c9bcf9326..23aa0ad05 100644 --- a/pkg/tcpip/buffer/BUILD +++ b/pkg/tcpip/buffer/BUILD @@ -7,6 +7,7 @@ go_library( srcs = [ "prependable.go", "view.go", + "view_unsafe.go", ], visibility = ["//visibility:public"], ) diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go index 91cc62cc8..b05e81526 100644 --- a/pkg/tcpip/buffer/view.go +++ b/pkg/tcpip/buffer/view.go @@ -239,6 +239,16 @@ func (vv *VectorisedView) Size() int { return vv.size } +// MemSize returns the estimation size of the vv in memory, including backing +// buffer data. +func (vv *VectorisedView) MemSize() int { + var size int + for _, v := range vv.views { + size += cap(v) + } + return size + cap(vv.views)*viewStructSize + vectorisedViewStructSize +} + // ToView returns a single view containing the content of the vectorised view. // // If the vectorised view contains a single view, that view will be returned diff --git a/pkg/tcpip/buffer/view_test.go b/pkg/tcpip/buffer/view_test.go index e7f7cc9f1..78b2faa26 100644 --- a/pkg/tcpip/buffer/view_test.go +++ b/pkg/tcpip/buffer/view_test.go @@ -20,6 +20,7 @@ import ( "io" "reflect" "testing" + "unsafe" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -578,3 +579,15 @@ func TestAppendView(t *testing.T) { } } } + +func TestMemSize(t *testing.T) { + const perViewCap = 128 + views := make([]buffer.View, 2, 32) + views[0] = make(buffer.View, 10, perViewCap) + views[1] = make(buffer.View, 20, perViewCap) + vv := buffer.NewVectorisedView(30, views) + want := int(unsafe.Sizeof(vv)) + cap(views)*int(unsafe.Sizeof(views)) + 2*perViewCap + if got := vv.MemSize(); got != want { + t.Errorf("vv.MemSize() = %d, want %d", got, want) + } +} diff --git a/pkg/tcpip/buffer/view_unsafe.go b/pkg/tcpip/buffer/view_unsafe.go new file mode 100644 index 000000000..75ccd40f8 --- /dev/null +++ b/pkg/tcpip/buffer/view_unsafe.go @@ -0,0 +1,22 
@@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package buffer + +import "unsafe" + +const ( + vectorisedViewStructSize = int(unsafe.Sizeof(VectorisedView{})) + viewStructSize = int(unsafe.Sizeof(View{})) +) diff --git a/pkg/tcpip/errors.go b/pkg/tcpip/errors.go new file mode 100644 index 000000000..af46da1d2 --- /dev/null +++ b/pkg/tcpip/errors.go @@ -0,0 +1,538 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcpip + +import ( + "fmt" +) + +// Error represents an error in the netstack error space. +// +// The error interface is intentionally omitted to avoid loss of type +// information that would occur if these errors were passed as error. +type Error interface { + isError() + + // IgnoreStats indicates whether this error should be included in failure + // counts in tcpip.Stats structs. 
+ IgnoreStats() bool + + fmt.Stringer +} + +// ErrAborted indicates the operation was aborted. +// +// +stateify savable +type ErrAborted struct{} + +func (*ErrAborted) isError() {} + +// IgnoreStats implements Error. +func (*ErrAborted) IgnoreStats() bool { + return false +} +func (*ErrAborted) String() string { + return "operation aborted" +} + +// ErrAddressFamilyNotSupported indicates the operation does not support the +// given address family. +// +// +stateify savable +type ErrAddressFamilyNotSupported struct{} + +func (*ErrAddressFamilyNotSupported) isError() {} + +// IgnoreStats implements Error. +func (*ErrAddressFamilyNotSupported) IgnoreStats() bool { + return false +} +func (*ErrAddressFamilyNotSupported) String() string { + return "address family not supported by protocol" +} + +// ErrAlreadyBound indicates the endpoint is already bound. +// +// +stateify savable +type ErrAlreadyBound struct{} + +func (*ErrAlreadyBound) isError() {} + +// IgnoreStats implements Error. +func (*ErrAlreadyBound) IgnoreStats() bool { + return true +} +func (*ErrAlreadyBound) String() string { return "endpoint already bound" } + +// ErrAlreadyConnected indicates the endpoint is already connected. +// +// +stateify savable +type ErrAlreadyConnected struct{} + +func (*ErrAlreadyConnected) isError() {} + +// IgnoreStats implements Error. +func (*ErrAlreadyConnected) IgnoreStats() bool { + return true +} +func (*ErrAlreadyConnected) String() string { return "endpoint is already connected" } + +// ErrAlreadyConnecting indicates the endpoint is already connecting. +// +// +stateify savable +type ErrAlreadyConnecting struct{} + +func (*ErrAlreadyConnecting) isError() {} + +// IgnoreStats implements Error. +func (*ErrAlreadyConnecting) IgnoreStats() bool { + return true +} +func (*ErrAlreadyConnecting) String() string { return "endpoint is already connecting" } + +// ErrBadAddress indicates a bad address was provided. 
+// +// +stateify savable +type ErrBadAddress struct{} + +func (*ErrBadAddress) isError() {} + +// IgnoreStats implements Error. +func (*ErrBadAddress) IgnoreStats() bool { + return false +} +func (*ErrBadAddress) String() string { return "bad address" } + +// ErrBadBuffer indicates a bad buffer was provided. +// +// +stateify savable +type ErrBadBuffer struct{} + +func (*ErrBadBuffer) isError() {} + +// IgnoreStats implements Error. +func (*ErrBadBuffer) IgnoreStats() bool { + return false +} +func (*ErrBadBuffer) String() string { return "bad buffer" } + +// ErrBadLocalAddress indicates a bad local address was provided. +// +// +stateify savable +type ErrBadLocalAddress struct{} + +func (*ErrBadLocalAddress) isError() {} + +// IgnoreStats implements Error. +func (*ErrBadLocalAddress) IgnoreStats() bool { + return false +} +func (*ErrBadLocalAddress) String() string { return "bad local address" } + +// ErrBroadcastDisabled indicates broadcast is not enabled on the endpoint. +// +// +stateify savable +type ErrBroadcastDisabled struct{} + +func (*ErrBroadcastDisabled) isError() {} + +// IgnoreStats implements Error. +func (*ErrBroadcastDisabled) IgnoreStats() bool { + return false +} +func (*ErrBroadcastDisabled) String() string { return "broadcast socket option disabled" } + +// ErrClosedForReceive indicates the endpoint is closed for incoming data. +// +// +stateify savable +type ErrClosedForReceive struct{} + +func (*ErrClosedForReceive) isError() {} + +// IgnoreStats implements Error. +func (*ErrClosedForReceive) IgnoreStats() bool { + return false +} +func (*ErrClosedForReceive) String() string { return "endpoint is closed for receive" } + +// ErrClosedForSend indicates the endpoint is closed for outgoing data. +// +// +stateify savable +type ErrClosedForSend struct{} + +func (*ErrClosedForSend) isError() {} + +// IgnoreStats implements Error. 
+func (*ErrClosedForSend) IgnoreStats() bool { + return false +} +func (*ErrClosedForSend) String() string { return "endpoint is closed for send" } + +// ErrConnectStarted indicates the endpoint is connecting asynchronously. +// +// +stateify savable +type ErrConnectStarted struct{} + +func (*ErrConnectStarted) isError() {} + +// IgnoreStats implements Error. +func (*ErrConnectStarted) IgnoreStats() bool { + return true +} +func (*ErrConnectStarted) String() string { return "connection attempt started" } + +// ErrConnectionAborted indicates the connection was aborted. +// +// +stateify savable +type ErrConnectionAborted struct{} + +func (*ErrConnectionAborted) isError() {} + +// IgnoreStats implements Error. +func (*ErrConnectionAborted) IgnoreStats() bool { + return false +} +func (*ErrConnectionAborted) String() string { return "connection aborted" } + +// ErrConnectionRefused indicates the connection was refused. +// +// +stateify savable +type ErrConnectionRefused struct{} + +func (*ErrConnectionRefused) isError() {} + +// IgnoreStats implements Error. +func (*ErrConnectionRefused) IgnoreStats() bool { + return false +} +func (*ErrConnectionRefused) String() string { return "connection was refused" } + +// ErrConnectionReset indicates the connection was reset. +// +// +stateify savable +type ErrConnectionReset struct{} + +func (*ErrConnectionReset) isError() {} + +// IgnoreStats implements Error. +func (*ErrConnectionReset) IgnoreStats() bool { + return false +} +func (*ErrConnectionReset) String() string { return "connection reset by peer" } + +// ErrDestinationRequired indicates the operation requires a destination +// address, and one was not provided. +// +// +stateify savable +type ErrDestinationRequired struct{} + +func (*ErrDestinationRequired) isError() {} + +// IgnoreStats implements Error. 
+func (*ErrDestinationRequired) IgnoreStats() bool { + return false +} +func (*ErrDestinationRequired) String() string { return "destination address is required" } + +// ErrDuplicateAddress indicates the operation encountered a duplicate address. +// +// +stateify savable +type ErrDuplicateAddress struct{} + +func (*ErrDuplicateAddress) isError() {} + +// IgnoreStats implements Error. +func (*ErrDuplicateAddress) IgnoreStats() bool { + return false +} +func (*ErrDuplicateAddress) String() string { return "duplicate address" } + +// ErrDuplicateNICID indicates the operation encountered a duplicate NIC ID. +// +// +stateify savable +type ErrDuplicateNICID struct{} + +func (*ErrDuplicateNICID) isError() {} + +// IgnoreStats implements Error. +func (*ErrDuplicateNICID) IgnoreStats() bool { + return false +} +func (*ErrDuplicateNICID) String() string { return "duplicate nic id" } + +// ErrInvalidEndpointState indicates the endpoint is in an invalid state. +// +// +stateify savable +type ErrInvalidEndpointState struct{} + +func (*ErrInvalidEndpointState) isError() {} + +// IgnoreStats implements Error. +func (*ErrInvalidEndpointState) IgnoreStats() bool { + return false +} +func (*ErrInvalidEndpointState) String() string { return "endpoint is in invalid state" } + +// ErrInvalidOptionValue indicates an invalid option value was provided. +// +// +stateify savable +type ErrInvalidOptionValue struct{} + +func (*ErrInvalidOptionValue) isError() {} + +// IgnoreStats implements Error. +func (*ErrInvalidOptionValue) IgnoreStats() bool { + return false +} +func (*ErrInvalidOptionValue) String() string { return "invalid option value specified" } + +// ErrMalformedHeader indicates the operation encountered a malformed header. +// +// +stateify savable +type ErrMalformedHeader struct{} + +func (*ErrMalformedHeader) isError() {} + +// IgnoreStats implements Error. 
+func (*ErrMalformedHeader) IgnoreStats() bool { + return false +} +func (*ErrMalformedHeader) String() string { return "header is malformed" } + +// ErrMessageTooLong indicates the operation encountered a message whose length +// exceeds the maximum permitted. +// +// +stateify savable +type ErrMessageTooLong struct{} + +func (*ErrMessageTooLong) isError() {} + +// IgnoreStats implements Error. +func (*ErrMessageTooLong) IgnoreStats() bool { + return false +} +func (*ErrMessageTooLong) String() string { return "message too long" } + +// ErrNetworkUnreachable indicates the operation is not able to reach the +// destination network. +// +// +stateify savable +type ErrNetworkUnreachable struct{} + +func (*ErrNetworkUnreachable) isError() {} + +// IgnoreStats implements Error. +func (*ErrNetworkUnreachable) IgnoreStats() bool { + return false +} +func (*ErrNetworkUnreachable) String() string { return "network is unreachable" } + +// ErrNoBufferSpace indicates no buffer space is available. +// +// +stateify savable +type ErrNoBufferSpace struct{} + +func (*ErrNoBufferSpace) isError() {} + +// IgnoreStats implements Error. +func (*ErrNoBufferSpace) IgnoreStats() bool { + return false +} +func (*ErrNoBufferSpace) String() string { return "no buffer space available" } + +// ErrNoPortAvailable indicates no port could be allocated for the operation. +// +// +stateify savable +type ErrNoPortAvailable struct{} + +func (*ErrNoPortAvailable) isError() {} + +// IgnoreStats implements Error. +func (*ErrNoPortAvailable) IgnoreStats() bool { + return false +} +func (*ErrNoPortAvailable) String() string { return "no ports are available" } + +// ErrNoRoute indicates the operation is not able to find a route to the +// destination. +// +// +stateify savable +type ErrNoRoute struct{} + +func (*ErrNoRoute) isError() {} + +// IgnoreStats implements Error. 
+func (*ErrNoRoute) IgnoreStats() bool { + return false +} +func (*ErrNoRoute) String() string { return "no route" } + +// ErrNoSuchFile is used to indicate that ENOENT should be returned to the +// calling application. +// +// +stateify savable +type ErrNoSuchFile struct{} + +func (*ErrNoSuchFile) isError() {} + +// IgnoreStats implements Error. +func (*ErrNoSuchFile) IgnoreStats() bool { + return false +} +func (*ErrNoSuchFile) String() string { return "no such file" } + +// ErrNotConnected indicates the endpoint is not connected. +// +// +stateify savable +type ErrNotConnected struct{} + +func (*ErrNotConnected) isError() {} + +// IgnoreStats implements Error. +func (*ErrNotConnected) IgnoreStats() bool { + return false +} +func (*ErrNotConnected) String() string { return "endpoint not connected" } + +// ErrNotPermitted indicates the operation is not permitted. +// +// +stateify savable +type ErrNotPermitted struct{} + +func (*ErrNotPermitted) isError() {} + +// IgnoreStats implements Error. +func (*ErrNotPermitted) IgnoreStats() bool { + return false +} +func (*ErrNotPermitted) String() string { return "operation not permitted" } + +// ErrNotSupported indicates the operation is not supported. +// +// +stateify savable +type ErrNotSupported struct{} + +func (*ErrNotSupported) isError() {} + +// IgnoreStats implements Error. +func (*ErrNotSupported) IgnoreStats() bool { + return false +} +func (*ErrNotSupported) String() string { return "operation not supported" } + +// ErrPortInUse indicates the provided port is in use. +// +// +stateify savable +type ErrPortInUse struct{} + +func (*ErrPortInUse) isError() {} + +// IgnoreStats implements Error. +func (*ErrPortInUse) IgnoreStats() bool { + return false +} +func (*ErrPortInUse) String() string { return "port is in use" } + +// ErrQueueSizeNotSupported indicates the endpoint does not allow queue size +// operation. 
+// +// +stateify savable +type ErrQueueSizeNotSupported struct{} + +func (*ErrQueueSizeNotSupported) isError() {} + +// IgnoreStats implements Error. +func (*ErrQueueSizeNotSupported) IgnoreStats() bool { + return false +} +func (*ErrQueueSizeNotSupported) String() string { return "queue size querying not supported" } + +// ErrTimeout indicates the operation timed out. +// +// +stateify savable +type ErrTimeout struct{} + +func (*ErrTimeout) isError() {} + +// IgnoreStats implements Error. +func (*ErrTimeout) IgnoreStats() bool { + return false +} +func (*ErrTimeout) String() string { return "operation timed out" } + +// ErrUnknownDevice indicates an unknown device identifier was provided. +// +// +stateify savable +type ErrUnknownDevice struct{} + +func (*ErrUnknownDevice) isError() {} + +// IgnoreStats implements Error. +func (*ErrUnknownDevice) IgnoreStats() bool { + return false +} +func (*ErrUnknownDevice) String() string { return "unknown device" } + +// ErrUnknownNICID indicates an unknown NIC ID was provided. +// +// +stateify savable +type ErrUnknownNICID struct{} + +func (*ErrUnknownNICID) isError() {} + +// IgnoreStats implements Error. +func (*ErrUnknownNICID) IgnoreStats() bool { + return false +} +func (*ErrUnknownNICID) String() string { return "unknown nic id" } + +// ErrUnknownProtocol indicates an unknown protocol was requested. +// +// +stateify savable +type ErrUnknownProtocol struct{} + +func (*ErrUnknownProtocol) isError() {} + +// IgnoreStats implements Error. +func (*ErrUnknownProtocol) IgnoreStats() bool { + return false +} +func (*ErrUnknownProtocol) String() string { return "unknown protocol" } + +// ErrUnknownProtocolOption indicates an unknown protocol option was provided. +// +// +stateify savable +type ErrUnknownProtocolOption struct{} + +func (*ErrUnknownProtocolOption) isError() {} + +// IgnoreStats implements Error. 
+func (*ErrUnknownProtocolOption) IgnoreStats() bool { + return false +} +func (*ErrUnknownProtocolOption) String() string { return "unknown option for protocol" } + +// ErrWouldBlock indicates the operation would block. +// +// +stateify savable +type ErrWouldBlock struct{} + +func (*ErrWouldBlock) isError() {} + +// IgnoreStats implements Error. +func (*ErrWouldBlock) IgnoreStats() bool { + return true +} +func (*ErrWouldBlock) String() string { return "operation would block" } diff --git a/pkg/tcpip/faketime/BUILD b/pkg/tcpip/faketime/BUILD index 114d43df3..bb9d44aff 100644 --- a/pkg/tcpip/faketime/BUILD +++ b/pkg/tcpip/faketime/BUILD @@ -6,10 +6,7 @@ go_library( name = "faketime", srcs = ["faketime.go"], visibility = ["//visibility:public"], - deps = [ - "//pkg/tcpip", - "@com_github_dpjacques_clockwork//:go_default_library", - ], + deps = ["//pkg/tcpip"], ) go_test( diff --git a/pkg/tcpip/faketime/faketime.go b/pkg/tcpip/faketime/faketime.go index f7a4fbde1..fb819d7a8 100644 --- a/pkg/tcpip/faketime/faketime.go +++ b/pkg/tcpip/faketime/faketime.go @@ -17,10 +17,10 @@ package faketime import ( "container/heap" + "fmt" "sync" "time" - "github.com/dpjacques/clockwork" "gvisor.dev/gvisor/pkg/tcpip" ) @@ -44,38 +44,85 @@ func (*NullClock) AfterFunc(time.Duration, func()) tcpip.Timer { return nil } +type notificationChannels struct { + mu struct { + sync.Mutex + + ch []<-chan struct{} + } +} + +func (n *notificationChannels) add(ch <-chan struct{}) { + n.mu.Lock() + defer n.mu.Unlock() + n.mu.ch = append(n.mu.ch, ch) +} + +// wait returns once all the notification channels are readable. +// +// Channels that are added while waiting on existing channels will be waited on +// as well. +func (n *notificationChannels) wait() { + for { + n.mu.Lock() + ch := n.mu.ch + n.mu.ch = nil + n.mu.Unlock() + + if len(ch) == 0 { + break + } + + for _, c := range ch { + <-c + } + } +} + // ManualClock implements tcpip.Clock and only advances manually with Advance // method. 
type ManualClock struct { - clock clockwork.FakeClock + // runningTimers tracks the completion of timer callbacks that began running + // immediately upon their scheduling. It is used to ensure the proper ordering + // of timer callback dispatch. + runningTimers notificationChannels + + mu struct { + sync.RWMutex - // mu protects the fields below. - mu sync.RWMutex + // now is the current (fake) time of the clock. + now time.Time - // times is min-heap of times. A heap is used for quick retrieval of the next - // upcoming time of scheduled work. - times *timeHeap + // times is min-heap of times. + times timeHeap - // waitGroups stores one WaitGroup for all work scheduled to execute at the - // same time via AfterFunc. This allows parallel execution of all functions - // passed to AfterFunc scheduled for the same time. - waitGroups map[time.Time]*sync.WaitGroup + // timers holds the timers scheduled for each time. + timers map[time.Time]map[*manualTimer]struct{} + } } // NewManualClock creates a new ManualClock instance. func NewManualClock() *ManualClock { - return &ManualClock{ - clock: clockwork.NewFakeClock(), - times: &timeHeap{}, - waitGroups: make(map[time.Time]*sync.WaitGroup), - } + c := &ManualClock{} + + c.mu.Lock() + defer c.mu.Unlock() + + // Set the initial time to a non-zero value since the zero value is used to + // detect inactive timers. + c.mu.now = time.Unix(0, 0) + c.mu.timers = make(map[time.Time]map[*manualTimer]struct{}) + + return c } var _ tcpip.Clock = (*ManualClock)(nil) // NowNanoseconds implements tcpip.Clock.NowNanoseconds. func (mc *ManualClock) NowNanoseconds() int64 { - return mc.clock.Now().UnixNano() + mc.mu.RLock() + defer mc.mu.RUnlock() + return mc.mu.now.UnixNano() } // NowMonotonic implements tcpip.Clock.NowMonotonic. @@ -85,128 +132,203 @@ func (mc *ManualClock) NowMonotonic() int64 { // AfterFunc implements tcpip.Clock.AfterFunc. 
func (mc *ManualClock) AfterFunc(d time.Duration, f func()) tcpip.Timer { - until := mc.clock.Now().Add(d) - wg := mc.addWait(until) - return &manualTimer{ + mt := &manualTimer{ clock: mc, - until: until, - timer: mc.clock.AfterFunc(d, func() { - defer wg.Done() - f() - }), + f: f, } -} -// addWait adds an additional wait to the WaitGroup for parallel execution of -// all work scheduled for t. Returns a reference to the WaitGroup modified. -func (mc *ManualClock) addWait(t time.Time) *sync.WaitGroup { - mc.mu.RLock() - wg, ok := mc.waitGroups[t] - mc.mu.RUnlock() + mc.mu.Lock() + defer mc.mu.Unlock() + + mt.mu.Lock() + defer mt.mu.Unlock() - if ok { - wg.Add(1) - return wg + mc.resetTimerLocked(mt, d) + return mt +} + +// resetTimerLocked schedules a timer to be fired after the given duration. +// +// Precondition: mc.mu and mt.mu must be locked. +func (mc *ManualClock) resetTimerLocked(mt *manualTimer, d time.Duration) { + if !mt.mu.firesAt.IsZero() { + panic("tried to reset an active timer") } - mc.mu.Lock() - heap.Push(mc.times, t) - mc.mu.Unlock() + t := mc.mu.now.Add(d) - wg = &sync.WaitGroup{} - wg.Add(1) + if !mc.mu.now.Before(t) { + // If the timer is scheduled to fire immediately, call its callback + // in a new goroutine immediately. + // + // It needs to be called in its own goroutine to escape its current + // execution context - like an actual timer. + ch := make(chan struct{}) + mc.runningTimers.add(ch) - mc.mu.Lock() - mc.waitGroups[t] = wg - mc.mu.Unlock() + go func() { + defer close(ch) + + mt.f() + }() - return wg + return + } + + mt.mu.firesAt = t + + timers, ok := mc.mu.timers[t] + if !ok { + timers = make(map[*manualTimer]struct{}) + mc.mu.timers[t] = timers + heap.Push(&mc.mu.times, t) + } + + timers[mt] = struct{}{} } -// removeWait removes a wait from the WaitGroup for parallel execution of all -// work scheduled for t. 
-func (mc *ManualClock) removeWait(t time.Time) { - mc.mu.RLock() - defer mc.mu.RUnlock() +// stopTimerLocked stops a timer from firing. +// +// Precondition: mc.mu and mt.mu must be locked. +func (mc *ManualClock) stopTimerLocked(mt *manualTimer) { + t := mt.mu.firesAt + mt.mu.firesAt = time.Time{} + + if t.IsZero() { + panic("tried to stop an inactive timer") + } - wg := mc.waitGroups[t] - wg.Done() + timers, ok := mc.mu.timers[t] + if !ok { + err := fmt.Sprintf("tried to stop an active timer but the clock does not have anything scheduled for the timer @ t = %s %p\nScheduled timers @:", t.UTC(), mt) + for t := range mc.mu.timers { + err += fmt.Sprintf("%s\n", t.UTC()) + } + panic(err) + } + + if _, ok := timers[mt]; !ok { + panic(fmt.Sprintf("did not have an entry in timers for an active timer @ t = %s", t.UTC())) + } + + delete(timers, mt) + + if len(timers) == 0 { + delete(mc.mu.timers, t) + } } // Advance executes all work that have been scheduled to execute within d from -// the current time. Blocks until all work has completed execution. +// the current time. Blocks until all work has completed execution. func (mc *ManualClock) Advance(d time.Duration) { - // Block until all the work is done - until := mc.clock.Now().Add(d) - for { - mc.mu.Lock() - if mc.times.Len() == 0 { - mc.mu.Unlock() - break - } + // We spawn goroutines for timers that were scheduled to fire at the time of + // being reset. Wait for those goroutines to complete before proceeding so + // that timer callbacks are called in the right order. 
+ mc.runningTimers.wait() - t := heap.Pop(mc.times).(time.Time) + mc.mu.Lock() + defer mc.mu.Unlock() + + until := mc.mu.now.Add(d) + for mc.mu.times.Len() > 0 { + t := heap.Pop(&mc.mu.times).(time.Time) if t.After(until) { // No work to do - heap.Push(mc.times, t) - mc.mu.Unlock() + heap.Push(&mc.mu.times, t) break } - mc.mu.Unlock() - diff := t.Sub(mc.clock.Now()) - mc.clock.Advance(diff) + timers := mc.mu.timers[t] + delete(mc.mu.timers, t) + + mc.mu.now = t + + // Mark the timers as inactive since they will be fired. + // + // This needs to be done while holding mc's lock because we remove the entry + // in the map of timers for the current time. If an attempt to stop a + // timer is made after mc's lock was dropped but before the timer is + // marked inactive, we would panic since no entry exists for the time when + // the timer was expected to fire. + for mt := range timers { + mt.mu.Lock() + mt.mu.firesAt = time.Time{} + mt.mu.Unlock() + } - mc.mu.RLock() - wg := mc.waitGroups[t] - mc.mu.RUnlock() + // Release the lock before calling the timer's callback fn since the + // callback fn might try to schedule a timer which requires obtaining + // mc's lock. + mc.mu.Unlock() - wg.Wait() + for mt := range timers { + mt.f() + } + // The timer callbacks may have scheduled a timer to fire immediately. + // We spawn goroutines for these timers and need to wait for them to + // finish before proceeding so that timer callbacks are called in the + // right order. 
+ mc.runningTimers.wait() mc.mu.Lock() - delete(mc.waitGroups, t) - mc.mu.Unlock() } - if now := mc.clock.Now(); until.After(now) { - mc.clock.Advance(until.Sub(now)) + + mc.mu.now = until +} + +func (mc *ManualClock) resetTimer(mt *manualTimer, d time.Duration) { + mc.mu.Lock() + defer mc.mu.Unlock() + + mt.mu.Lock() + defer mt.mu.Unlock() + + if !mt.mu.firesAt.IsZero() { + mc.stopTimerLocked(mt) } + + mc.resetTimerLocked(mt, d) +} + +func (mc *ManualClock) stopTimer(mt *manualTimer) bool { + mc.mu.Lock() + defer mc.mu.Unlock() + + mt.mu.Lock() + defer mt.mu.Unlock() + + if mt.mu.firesAt.IsZero() { + return false + } + + mc.stopTimerLocked(mt) + return true } type manualTimer struct { clock *ManualClock - timer clockwork.Timer + f func() - mu sync.RWMutex - until time.Time + mu struct { + sync.Mutex + + // firesAt is the time when the timer will fire. + // + // Zero only when the timer is not active. + firesAt time.Time + } } var _ tcpip.Timer = (*manualTimer)(nil) // Reset implements tcpip.Timer.Reset. -func (t *manualTimer) Reset(d time.Duration) { - if !t.timer.Reset(d) { - return - } - - t.mu.Lock() - defer t.mu.Unlock() - - t.clock.removeWait(t.until) - t.until = t.clock.clock.Now().Add(d) - t.clock.addWait(t.until) +func (mt *manualTimer) Reset(d time.Duration) { + mt.clock.resetTimer(mt, d) } // Stop implements tcpip.Timer.Stop. 
-func (t *manualTimer) Stop() bool { - if !t.timer.Stop() { - return false - } - - t.mu.RLock() - defer t.mu.RUnlock() - - t.clock.removeWait(t.until) - return true +func (mt *manualTimer) Stop() bool { + return mt.clock.stopTimer(mt) } type timeHeap []time.Time diff --git a/pkg/tcpip/header/icmpv4.go b/pkg/tcpip/header/icmpv4.go index 5f9b8e9e2..f840a4322 100644 --- a/pkg/tcpip/header/icmpv4.go +++ b/pkg/tcpip/header/icmpv4.go @@ -16,7 +16,6 @@ package header import ( "encoding/binary" - "fmt" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -208,16 +207,3 @@ func ICMPv4Checksum(h ICMPv4, vv buffer.VectorisedView) uint16 { return ^xsum } - -// ICMPOriginFromNetProto returns the appropriate SockErrOrigin to use when -// a packet having a `net` header causing an ICMP error. -func ICMPOriginFromNetProto(net tcpip.NetworkProtocolNumber) tcpip.SockErrOrigin { - switch net { - case IPv4ProtocolNumber: - return tcpip.SockExtErrorOriginICMP - case IPv6ProtocolNumber: - return tcpip.SockExtErrorOriginICMP6 - default: - panic(fmt.Sprintf("unsupported net proto to extract ICMP error origin: %d", net)) - } -} diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go index 5580d6a78..f2403978c 100644 --- a/pkg/tcpip/header/ipv6.go +++ b/pkg/tcpip/header/ipv6.go @@ -453,9 +453,9 @@ const ( ) // ScopeForIPv6Address returns the scope for an IPv6 address. 
-func ScopeForIPv6Address(addr tcpip.Address) (IPv6AddressScope, *tcpip.Error) { +func ScopeForIPv6Address(addr tcpip.Address) (IPv6AddressScope, tcpip.Error) { if len(addr) != IPv6AddressSize { - return GlobalScope, tcpip.ErrBadAddress + return GlobalScope, &tcpip.ErrBadAddress{} } switch { diff --git a/pkg/tcpip/header/ipv6_test.go b/pkg/tcpip/header/ipv6_test.go index e3fbd64f3..f10f446a6 100644 --- a/pkg/tcpip/header/ipv6_test.go +++ b/pkg/tcpip/header/ipv6_test.go @@ -299,7 +299,7 @@ func TestScopeForIPv6Address(t *testing.T) { name string addr tcpip.Address scope header.IPv6AddressScope - err *tcpip.Error + err tcpip.Error }{ { name: "Unique Local", @@ -329,15 +329,15 @@ func TestScopeForIPv6Address(t *testing.T) { name: "IPv4", addr: "\x01\x02\x03\x04", scope: header.GlobalScope, - err: tcpip.ErrBadAddress, + err: &tcpip.ErrBadAddress{}, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { got, err := header.ScopeForIPv6Address(test.addr) - if err != test.err { - t.Errorf("got header.IsV6UniqueLocalAddress(%s) = (_, %v), want = (_, %v)", test.addr, err, test.err) + if diff := cmp.Diff(test.err, err); diff != "" { + t.Errorf("unexpected error from header.IsV6UniqueLocalAddress(%s), (-want, +got):\n%s", test.addr, diff) } if got != test.scope { t.Errorf("got header.IsV6UniqueLocalAddress(%s) = (%d, _), want = (%d, _)", test.addr, got, test.scope) diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go index a068d93a4..cd76272de 100644 --- a/pkg/tcpip/link/channel/channel.go +++ b/pkg/tcpip/link/channel/channel.go @@ -229,7 +229,7 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress { } // WritePacket stores outbound packets into the channel. 
-func (e *Endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { +func (e *Endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { p := PacketInfo{ Pkt: pkt, Proto: protocol, @@ -243,7 +243,7 @@ func (e *Endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip } // WritePackets stores outbound packets into the channel. -func (e *Endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *Endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { n := 0 for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { p := PacketInfo{ diff --git a/pkg/tcpip/link/ethernet/ethernet.go b/pkg/tcpip/link/ethernet/ethernet.go index 2f2d9d4ac..d873766a6 100644 --- a/pkg/tcpip/link/ethernet/ethernet.go +++ b/pkg/tcpip/link/ethernet/ethernet.go @@ -61,13 +61,13 @@ func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities { } // WritePacket implements stack.LinkEndpoint. -func (e *Endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { +func (e *Endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { e.AddHeader(e.Endpoint.LinkAddress(), r.RemoteLinkAddress, proto, pkt) return e.Endpoint.WritePacket(r, gso, proto, pkt) } // WritePackets implements stack.LinkEndpoint. 
-func (e *Endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *Endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, tcpip.Error) { linkAddr := e.Endpoint.LinkAddress() for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index f86c383d8..0164d851b 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -57,7 +57,7 @@ import ( // linkDispatcher reads packets from the link FD and dispatches them to the // NetworkDispatcher. type linkDispatcher interface { - dispatch() (bool, *tcpip.Error) + dispatch() (bool, tcpip.Error) } // PacketDispatchMode are the various supported methods of receiving and @@ -118,7 +118,7 @@ type endpoint struct { // closed is a function to be called when the FD's peer (if any) closes // its end of the communication pipe. - closed func(*tcpip.Error) + closed func(tcpip.Error) inboundDispatchers []linkDispatcher dispatcher stack.NetworkDispatcher @@ -149,7 +149,7 @@ type Options struct { // ClosedFunc is a function to be called when an endpoint's peer (if // any) closes its end of the communication pipe. - ClosedFunc func(*tcpip.Error) + ClosedFunc func(tcpip.Error) // Address is the link address for this endpoint. Only used if // EthernetHeader is true. @@ -411,7 +411,7 @@ func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.Net // WritePacket writes outbound packets to the file descriptor. If it is not // currently writable, the packet is dropped. 
-func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { if e.hdrSize > 0 { e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt) } @@ -451,7 +451,7 @@ func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip return rawfile.NonBlockingWriteIovec(fd, builder.Build()) } -func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, *tcpip.Error) { +func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, tcpip.Error) { // Send a batch of packets through batchFD. mmsgHdrs := make([]rawfile.MMsgHdr, 0, len(batch)) for _, pkt := range batch { @@ -518,7 +518,7 @@ func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, *tc // - pkt.EgressRoute // - pkt.GSOOptions // - pkt.NetworkProtocolNumber -func (e *endpoint) WritePackets(_ stack.RouteInfo, _ *stack.GSO, pkts stack.PacketBufferList, _ tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *endpoint) WritePackets(_ stack.RouteInfo, _ *stack.GSO, pkts stack.PacketBufferList, _ tcpip.NetworkProtocolNumber) (int, tcpip.Error) { // Preallocate to avoid repeated reallocation as we append to batch. // batchSz is 47 because when SWGSO is in use then a single 65KB TCP // segment can get split into 46 segments of 1420 bytes and a single 216 @@ -562,13 +562,13 @@ func viewsEqual(vs1, vs2 []buffer.View) bool { } // InjectOutobund implements stack.InjectableEndpoint.InjectOutbound. -func (e *endpoint) InjectOutbound(dest tcpip.Address, packet []byte) *tcpip.Error { +func (e *endpoint) InjectOutbound(dest tcpip.Address, packet []byte) tcpip.Error { return rawfile.NonBlockingWrite(e.fds[0], packet) } // dispatchLoop reads packets from the file descriptor in a loop and dispatches // them to the network stack. 
-func (e *endpoint) dispatchLoop(inboundDispatcher linkDispatcher) *tcpip.Error { +func (e *endpoint) dispatchLoop(inboundDispatcher linkDispatcher) tcpip.Error { for { cont, err := inboundDispatcher.dispatch() if err != nil || !cont { diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go index e2985cb84..e82371798 100644 --- a/pkg/tcpip/link/fdbased/endpoint_test.go +++ b/pkg/tcpip/link/fdbased/endpoint_test.go @@ -95,7 +95,7 @@ func newContext(t *testing.T, opt *Options) *context { } done := make(chan struct{}, 2) - opt.ClosedFunc = func(*tcpip.Error) { + opt.ClosedFunc = func(tcpip.Error) { done <- struct{}{} } diff --git a/pkg/tcpip/link/fdbased/mmap.go b/pkg/tcpip/link/fdbased/mmap.go index c475dda20..a2b63fe6b 100644 --- a/pkg/tcpip/link/fdbased/mmap.go +++ b/pkg/tcpip/link/fdbased/mmap.go @@ -129,7 +129,7 @@ type packetMMapDispatcher struct { ringOffset int } -func (d *packetMMapDispatcher) readMMappedPacket() ([]byte, *tcpip.Error) { +func (d *packetMMapDispatcher) readMMappedPacket() ([]byte, tcpip.Error) { hdr := tPacketHdr(d.ringBuffer[d.ringOffset*tpFrameSize:]) for hdr.tpStatus()&tpStatusUser == 0 { event := rawfile.PollEvent{ @@ -163,7 +163,7 @@ func (d *packetMMapDispatcher) readMMappedPacket() ([]byte, *tcpip.Error) { // dispatch reads packets from an mmaped ring buffer and dispatches them to the // network stack. -func (d *packetMMapDispatcher) dispatch() (bool, *tcpip.Error) { +func (d *packetMMapDispatcher) dispatch() (bool, tcpip.Error) { pkt, err := d.readMMappedPacket() if err != nil { return false, err diff --git a/pkg/tcpip/link/fdbased/packet_dispatchers.go b/pkg/tcpip/link/fdbased/packet_dispatchers.go index edab110b5..ecae1ad2d 100644 --- a/pkg/tcpip/link/fdbased/packet_dispatchers.go +++ b/pkg/tcpip/link/fdbased/packet_dispatchers.go @@ -139,7 +139,7 @@ func newReadVDispatcher(fd int, e *endpoint) (linkDispatcher, error) { } // dispatch reads one packet from the file descriptor and dispatches it. 
-func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) { +func (d *readVDispatcher) dispatch() (bool, tcpip.Error) { n, err := rawfile.BlockingReadv(d.fd, d.buf.nextIovecs()) if n == 0 || err != nil { return false, err @@ -226,7 +226,7 @@ func newRecvMMsgDispatcher(fd int, e *endpoint) (linkDispatcher, error) { // recvMMsgDispatch reads more than one packet at a time from the file // descriptor and dispatches it. -func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) { +func (d *recvMMsgDispatcher) dispatch() (bool, tcpip.Error) { // Fill message headers. for k := range d.msgHdrs { if d.msgHdrs[k].Msg.Iovlen > 0 { diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go index ac6a6be87..691467870 100644 --- a/pkg/tcpip/link/loopback/loopback.go +++ b/pkg/tcpip/link/loopback/loopback.go @@ -76,7 +76,7 @@ func (*endpoint) Wait() {} // WritePacket implements stack.LinkEndpoint.WritePacket. It delivers outbound // packets to the network-layer dispatcher. -func (e *endpoint) WritePacket(_ stack.RouteInfo, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(_ stack.RouteInfo, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { // Construct data as the unparsed portion for the loopback packet. data := buffer.NewVectorisedView(pkt.Size(), pkt.Views()) @@ -92,7 +92,7 @@ func (e *endpoint) WritePacket(_ stack.RouteInfo, _ *stack.GSO, protocol tcpip.N } // WritePackets implements stack.LinkEndpoint.WritePackets. 
-func (e *endpoint) WritePackets(stack.RouteInfo, *stack.GSO, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *endpoint) WritePackets(stack.RouteInfo, *stack.GSO, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error) { panic("not implemented") } diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go index 316f508e6..668f72eee 100644 --- a/pkg/tcpip/link/muxed/injectable.go +++ b/pkg/tcpip/link/muxed/injectable.go @@ -87,10 +87,10 @@ func (m *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, // WritePackets writes outbound packets to the appropriate // LinkInjectableEndpoint based on the RemoteAddress. HandleLocal only works if // r.RemoteAddress has a route registered in this endpoint. -func (m *InjectableEndpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (m *InjectableEndpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { endpoint, ok := m.routes[r.RemoteAddress] if !ok { - return 0, tcpip.ErrNoRoute + return 0, &tcpip.ErrNoRoute{} } return endpoint.WritePackets(r, gso, pkts, protocol) } @@ -98,19 +98,19 @@ func (m *InjectableEndpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkt // WritePacket writes outbound packets to the appropriate LinkInjectableEndpoint // based on the RemoteAddress. HandleLocal only works if r.RemoteAddress has a // route registered in this endpoint. 
-func (m *InjectableEndpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { +func (m *InjectableEndpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { if endpoint, ok := m.routes[r.RemoteAddress]; ok { return endpoint.WritePacket(r, gso, protocol, pkt) } - return tcpip.ErrNoRoute + return &tcpip.ErrNoRoute{} } // InjectOutbound writes outbound packets to the appropriate // LinkInjectableEndpoint based on the dest address. -func (m *InjectableEndpoint) InjectOutbound(dest tcpip.Address, packet []byte) *tcpip.Error { +func (m *InjectableEndpoint) InjectOutbound(dest tcpip.Address, packet []byte) tcpip.Error { endpoint, ok := m.routes[dest] if !ok { - return tcpip.ErrNoRoute + return &tcpip.ErrNoRoute{} } return endpoint.InjectOutbound(dest, packet) } diff --git a/pkg/tcpip/link/nested/nested.go b/pkg/tcpip/link/nested/nested.go index 814a54f23..97ad9fdd5 100644 --- a/pkg/tcpip/link/nested/nested.go +++ b/pkg/tcpip/link/nested/nested.go @@ -113,12 +113,12 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress { } // WritePacket implements stack.LinkEndpoint. -func (e *Endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { +func (e *Endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { return e.child.WritePacket(r, gso, protocol, pkt) } // WritePackets implements stack.LinkEndpoint. 
-func (e *Endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *Endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { return e.child.WritePackets(r, gso, pkts, protocol) } diff --git a/pkg/tcpip/link/packetsocket/endpoint.go b/pkg/tcpip/link/packetsocket/endpoint.go index c95cdd681..6cbe18a56 100644 --- a/pkg/tcpip/link/packetsocket/endpoint.go +++ b/pkg/tcpip/link/packetsocket/endpoint.go @@ -35,13 +35,13 @@ func New(lower stack.LinkEndpoint) stack.LinkEndpoint { } // WritePacket implements stack.LinkEndpoint.WritePacket. -func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { e.Endpoint.DeliverOutboundPacket(r.RemoteLinkAddress, r.LocalLinkAddress, protocol, pkt) return e.Endpoint.WritePacket(r, gso, protocol, pkt) } // WritePackets implements stack.LinkEndpoint.WritePackets. 
-func (e *endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, tcpip.Error) { for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { e.Endpoint.DeliverOutboundPacket(r.RemoteLinkAddress, r.LocalLinkAddress, pkt.NetworkProtocolNumber, pkt) } diff --git a/pkg/tcpip/link/pipe/pipe.go b/pkg/tcpip/link/pipe/pipe.go index 36aa9055c..bbe84f220 100644 --- a/pkg/tcpip/link/pipe/pipe.go +++ b/pkg/tcpip/link/pipe/pipe.go @@ -75,7 +75,7 @@ func (e *Endpoint) deliverPackets(r stack.RouteInfo, proto tcpip.NetworkProtocol } // WritePacket implements stack.LinkEndpoint. -func (e *Endpoint) WritePacket(r stack.RouteInfo, _ *stack.GSO, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { +func (e *Endpoint) WritePacket(r stack.RouteInfo, _ *stack.GSO, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { if e.linked.IsAttached() { var pkts stack.PacketBufferList pkts.PushBack(pkt) @@ -86,7 +86,7 @@ func (e *Endpoint) WritePacket(r stack.RouteInfo, _ *stack.GSO, proto tcpip.Netw } // WritePackets implements stack.LinkEndpoint. -func (e *Endpoint) WritePackets(r stack.RouteInfo, _ *stack.GSO, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *Endpoint) WritePackets(r stack.RouteInfo, _ *stack.GSO, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, tcpip.Error) { if e.linked.IsAttached() { e.deliverPackets(r, proto, pkts) } diff --git a/pkg/tcpip/link/qdisc/fifo/endpoint.go b/pkg/tcpip/link/qdisc/fifo/endpoint.go index 03efba606..128ef6e87 100644 --- a/pkg/tcpip/link/qdisc/fifo/endpoint.go +++ b/pkg/tcpip/link/qdisc/fifo/endpoint.go @@ -150,7 +150,7 @@ func (e *endpoint) GSOMaxSize() uint32 { } // WritePacket implements stack.LinkEndpoint.WritePacket. 
-func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { // WritePacket caller's do not set the following fields in PacketBuffer // so we populate them here. pkt.EgressRoute = r @@ -158,7 +158,7 @@ func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip pkt.NetworkProtocolNumber = protocol d := e.dispatchers[int(pkt.Hash)%len(e.dispatchers)] if !d.q.enqueue(pkt) { - return tcpip.ErrNoBufferSpace + return &tcpip.ErrNoBufferSpace{} } d.newPacketWaker.Assert() return nil @@ -171,7 +171,7 @@ func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip // - pkt.EgressRoute // - pkt.GSOOptions // - pkt.NetworkProtocolNumber -func (e *endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { enqueued := 0 for pkt := pkts.Front(); pkt != nil; { d := e.dispatchers[int(pkt.Hash)%len(e.dispatchers)] @@ -180,7 +180,7 @@ func (e *endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.Pa if enqueued > 0 { d.newPacketWaker.Assert() } - return enqueued, tcpip.ErrNoBufferSpace + return enqueued, &tcpip.ErrNoBufferSpace{} } pkt = nxt enqueued++ diff --git a/pkg/tcpip/link/rawfile/BUILD b/pkg/tcpip/link/rawfile/BUILD index 6c410c5a6..e1047da50 100644 --- a/pkg/tcpip/link/rawfile/BUILD +++ b/pkg/tcpip/link/rawfile/BUILD @@ -27,5 +27,6 @@ go_test( library = "rawfile", deps = [ "//pkg/tcpip", + "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/tcpip/link/rawfile/errors.go b/pkg/tcpip/link/rawfile/errors.go index 604868fd8..406b97709 
100644 --- a/pkg/tcpip/link/rawfile/errors.go +++ b/pkg/tcpip/link/rawfile/errors.go @@ -17,7 +17,6 @@ package rawfile import ( - "fmt" "syscall" "gvisor.dev/gvisor/pkg/tcpip" @@ -25,48 +24,54 @@ import ( const maxErrno = 134 -var translations [maxErrno]*tcpip.Error - // TranslateErrno translate an errno from the syscall package into a -// *tcpip.Error. +// tcpip.Error. // // Valid, but unrecognized errnos will be translated to -// tcpip.ErrInvalidEndpointState (EINVAL). -func TranslateErrno(e syscall.Errno) *tcpip.Error { - if e > 0 && e < syscall.Errno(len(translations)) { - if err := translations[e]; err != nil { - return err - } - } - return tcpip.ErrInvalidEndpointState -} - -func addTranslation(host syscall.Errno, trans *tcpip.Error) { - if translations[host] != nil { - panic(fmt.Sprintf("duplicate translation for host errno %q (%d)", host.Error(), host)) +// *tcpip.ErrInvalidEndpointState (EINVAL). +func TranslateErrno(e syscall.Errno) tcpip.Error { + switch e { + case syscall.EEXIST: + return &tcpip.ErrDuplicateAddress{} + case syscall.ENETUNREACH: + return &tcpip.ErrNoRoute{} + case syscall.EINVAL: + return &tcpip.ErrInvalidEndpointState{} + case syscall.EALREADY: + return &tcpip.ErrAlreadyConnecting{} + case syscall.EISCONN: + return &tcpip.ErrAlreadyConnected{} + case syscall.EADDRINUSE: + return &tcpip.ErrPortInUse{} + case syscall.EADDRNOTAVAIL: + return &tcpip.ErrBadLocalAddress{} + case syscall.EPIPE: + return &tcpip.ErrClosedForSend{} + case syscall.EWOULDBLOCK: + return &tcpip.ErrWouldBlock{} + case syscall.ECONNREFUSED: + return &tcpip.ErrConnectionRefused{} + case syscall.ETIMEDOUT: + return &tcpip.ErrTimeout{} + case syscall.EINPROGRESS: + return &tcpip.ErrConnectStarted{} + case syscall.EDESTADDRREQ: + return &tcpip.ErrDestinationRequired{} + case syscall.ENOTSUP: + return &tcpip.ErrNotSupported{} + case syscall.ENOTTY: + return &tcpip.ErrQueueSizeNotSupported{} + case syscall.ENOTCONN: + return &tcpip.ErrNotConnected{} + case 
syscall.ECONNRESET: + return &tcpip.ErrConnectionReset{} + case syscall.ECONNABORTED: + return &tcpip.ErrConnectionAborted{} + case syscall.EMSGSIZE: + return &tcpip.ErrMessageTooLong{} + case syscall.ENOBUFS: + return &tcpip.ErrNoBufferSpace{} + default: + return &tcpip.ErrInvalidEndpointState{} } - translations[host] = trans -} - -func init() { - addTranslation(syscall.EEXIST, tcpip.ErrDuplicateAddress) - addTranslation(syscall.ENETUNREACH, tcpip.ErrNoRoute) - addTranslation(syscall.EINVAL, tcpip.ErrInvalidEndpointState) - addTranslation(syscall.EALREADY, tcpip.ErrAlreadyConnecting) - addTranslation(syscall.EISCONN, tcpip.ErrAlreadyConnected) - addTranslation(syscall.EADDRINUSE, tcpip.ErrPortInUse) - addTranslation(syscall.EADDRNOTAVAIL, tcpip.ErrBadLocalAddress) - addTranslation(syscall.EPIPE, tcpip.ErrClosedForSend) - addTranslation(syscall.EWOULDBLOCK, tcpip.ErrWouldBlock) - addTranslation(syscall.ECONNREFUSED, tcpip.ErrConnectionRefused) - addTranslation(syscall.ETIMEDOUT, tcpip.ErrTimeout) - addTranslation(syscall.EINPROGRESS, tcpip.ErrConnectStarted) - addTranslation(syscall.EDESTADDRREQ, tcpip.ErrDestinationRequired) - addTranslation(syscall.ENOTSUP, tcpip.ErrNotSupported) - addTranslation(syscall.ENOTTY, tcpip.ErrQueueSizeNotSupported) - addTranslation(syscall.ENOTCONN, tcpip.ErrNotConnected) - addTranslation(syscall.ECONNRESET, tcpip.ErrConnectionReset) - addTranslation(syscall.ECONNABORTED, tcpip.ErrConnectionAborted) - addTranslation(syscall.EMSGSIZE, tcpip.ErrMessageTooLong) - addTranslation(syscall.ENOBUFS, tcpip.ErrNoBufferSpace) } diff --git a/pkg/tcpip/link/rawfile/errors_test.go b/pkg/tcpip/link/rawfile/errors_test.go index e4cdc66bd..61aea1744 100644 --- a/pkg/tcpip/link/rawfile/errors_test.go +++ b/pkg/tcpip/link/rawfile/errors_test.go @@ -20,34 +20,35 @@ import ( "syscall" "testing" + "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" ) func TestTranslateErrno(t *testing.T) { for _, test := range []struct { errno syscall.Errno - 
translated *tcpip.Error + translated tcpip.Error }{ { errno: syscall.Errno(0), - translated: tcpip.ErrInvalidEndpointState, + translated: &tcpip.ErrInvalidEndpointState{}, }, { errno: syscall.Errno(maxErrno), - translated: tcpip.ErrInvalidEndpointState, + translated: &tcpip.ErrInvalidEndpointState{}, }, { errno: syscall.Errno(514), - translated: tcpip.ErrInvalidEndpointState, + translated: &tcpip.ErrInvalidEndpointState{}, }, { errno: syscall.EEXIST, - translated: tcpip.ErrDuplicateAddress, + translated: &tcpip.ErrDuplicateAddress{}, }, } { got := TranslateErrno(test.errno) - if got != test.translated { - t.Errorf("TranslateErrno(%q) = %q, want %q", test.errno, got, test.translated) + if diff := cmp.Diff(test.translated, got); diff != "" { + t.Errorf("unexpected result from TranslateErrno(%q), (-want, +got):\n%s", test.errno, diff) } } } diff --git a/pkg/tcpip/link/rawfile/rawfile_unsafe.go b/pkg/tcpip/link/rawfile/rawfile_unsafe.go index f4c32c2da..06f3ee21e 100644 --- a/pkg/tcpip/link/rawfile/rawfile_unsafe.go +++ b/pkg/tcpip/link/rawfile/rawfile_unsafe.go @@ -52,7 +52,7 @@ func GetMTU(name string) (uint32, error) { // NonBlockingWrite writes the given buffer to a file descriptor. It fails if // partial data is written. -func NonBlockingWrite(fd int, buf []byte) *tcpip.Error { +func NonBlockingWrite(fd int, buf []byte) tcpip.Error { var ptr unsafe.Pointer if len(buf) > 0 { ptr = unsafe.Pointer(&buf[0]) @@ -68,7 +68,7 @@ func NonBlockingWrite(fd int, buf []byte) *tcpip.Error { // NonBlockingWriteIovec writes iovec to a file descriptor in a single syscall. // It fails if partial data is written. 
-func NonBlockingWriteIovec(fd int, iovec []syscall.Iovec) *tcpip.Error { +func NonBlockingWriteIovec(fd int, iovec []syscall.Iovec) tcpip.Error { iovecLen := uintptr(len(iovec)) _, _, e := syscall.RawSyscall(syscall.SYS_WRITEV, uintptr(fd), uintptr(unsafe.Pointer(&iovec[0])), iovecLen) if e != 0 { @@ -78,7 +78,7 @@ func NonBlockingWriteIovec(fd int, iovec []syscall.Iovec) *tcpip.Error { } // NonBlockingSendMMsg sends multiple messages on a socket. -func NonBlockingSendMMsg(fd int, msgHdrs []MMsgHdr) (int, *tcpip.Error) { +func NonBlockingSendMMsg(fd int, msgHdrs []MMsgHdr) (int, tcpip.Error) { n, _, e := syscall.RawSyscall6(unix.SYS_SENDMMSG, uintptr(fd), uintptr(unsafe.Pointer(&msgHdrs[0])), uintptr(len(msgHdrs)), syscall.MSG_DONTWAIT, 0, 0) if e != 0 { return 0, TranslateErrno(e) @@ -97,7 +97,7 @@ type PollEvent struct { // BlockingRead reads from a file descriptor that is set up as non-blocking. If // no data is available, it will block in a poll() syscall until the file // descriptor becomes readable. -func BlockingRead(fd int, b []byte) (int, *tcpip.Error) { +func BlockingRead(fd int, b []byte) (int, tcpip.Error) { for { n, _, e := syscall.RawSyscall(syscall.SYS_READ, uintptr(fd), uintptr(unsafe.Pointer(&b[0])), uintptr(len(b))) if e == 0 { @@ -119,7 +119,7 @@ func BlockingRead(fd int, b []byte) (int, *tcpip.Error) { // BlockingReadv reads from a file descriptor that is set up as non-blocking and // stores the data in a list of iovecs buffers. If no data is available, it will // block in a poll() syscall until the file descriptor becomes readable. -func BlockingReadv(fd int, iovecs []syscall.Iovec) (int, *tcpip.Error) { +func BlockingReadv(fd int, iovecs []syscall.Iovec) (int, tcpip.Error) { for { n, _, e := syscall.RawSyscall(syscall.SYS_READV, uintptr(fd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(len(iovecs))) if e == 0 { @@ -149,7 +149,7 @@ type MMsgHdr struct { // and stores the received messages in a slice of MMsgHdr structures. 
If no data // is available, it will block in a poll() syscall until the file descriptor // becomes readable. -func BlockingRecvMMsg(fd int, msgHdrs []MMsgHdr) (int, *tcpip.Error) { +func BlockingRecvMMsg(fd int, msgHdrs []MMsgHdr) (int, tcpip.Error) { for { n, _, e := syscall.RawSyscall6(syscall.SYS_RECVMMSG, uintptr(fd), uintptr(unsafe.Pointer(&msgHdrs[0])), uintptr(len(msgHdrs)), syscall.MSG_DONTWAIT, 0, 0) if e == 0 { diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go index 6c937c858..2599bc406 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem.go +++ b/pkg/tcpip/link/sharedmem/sharedmem.go @@ -203,7 +203,7 @@ func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.Net // WritePacket writes outbound packets to the file descriptor. If it is not // currently writable, the packet is dropped. -func (e *endpoint) WritePacket(r stack.RouteInfo, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(r stack.RouteInfo, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt) views := pkt.Views() @@ -213,14 +213,14 @@ func (e *endpoint) WritePacket(r stack.RouteInfo, _ *stack.GSO, protocol tcpip.N e.mu.Unlock() if !ok { - return tcpip.ErrWouldBlock + return &tcpip.ErrWouldBlock{} } return nil } // WritePackets implements stack.LinkEndpoint.WritePackets. 
-func (*endpoint) WritePackets(stack.RouteInfo, *stack.GSO, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (*endpoint) WritePackets(stack.RouteInfo, *stack.GSO, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error) { panic("not implemented") } diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go index 23242b9e0..d480ad656 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go @@ -425,8 +425,9 @@ func TestFillTxQueue(t *testing.T) { ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), Data: buf.ToVectorisedView(), }) - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != want { - t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) + err := c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got WritePacket(...) = %v, want %s", err, &tcpip.ErrWouldBlock{}) } } @@ -493,8 +494,9 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) { ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), Data: buf.ToVectorisedView(), }) - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != want { - t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) + err := c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got WritePacket(...) 
= %v, want %s", err, &tcpip.ErrWouldBlock{}) } } @@ -538,8 +540,8 @@ func TestFillTxMemory(t *testing.T) { Data: buf.ToVectorisedView(), }) err := c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt) - if want := tcpip.ErrWouldBlock; err != want { - t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got WritePacket(...) = %v, want %s", err, &tcpip.ErrWouldBlock{}) } } @@ -579,8 +581,9 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) { ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), Data: buffer.NewView(bufferSize).ToVectorisedView(), }) - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != want { - t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) + err := c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got WritePacket(...) = %v, want %s", err, &tcpip.ErrWouldBlock{}) } } diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index 5859851d8..bd2b8d4bf 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -187,7 +187,7 @@ func (e *endpoint) dumpPacket(dir direction, gso *stack.GSO, protocol tcpip.Netw // WritePacket implements the stack.LinkEndpoint interface. It is called by // higher-level protocols to write packets; it just logs the packet and // forwards the request to the lower endpoint. 
-func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { e.dumpPacket(directionSend, gso, protocol, pkt) return e.Endpoint.WritePacket(r, gso, protocol, pkt) } @@ -195,7 +195,7 @@ func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip // WritePackets implements the stack.LinkEndpoint interface. It is called by // higher-level protocols to write packets; it just logs the packet and // forwards the request to the lower endpoint. -func (e *endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { e.dumpPacket(directionSend, gso, protocol, pkt) } diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go index bfac358f4..3829ca9c9 100644 --- a/pkg/tcpip/link/tun/device.go +++ b/pkg/tcpip/link/tun/device.go @@ -149,10 +149,10 @@ func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkE err := s.CreateNICWithOptions(endpoint.nicID, endpoint, stack.NICOptions{ Name: endpoint.name, }) - switch err { + switch err.(type) { case nil: return endpoint, nil - case tcpip.ErrDuplicateNICID: + case *tcpip.ErrDuplicateNICID: // Race detected: A NIC has been created in between. 
continue default: diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go index 30f1ad540..20259b285 100644 --- a/pkg/tcpip/link/waitable/waitable.go +++ b/pkg/tcpip/link/waitable/waitable.go @@ -108,7 +108,7 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress { // WritePacket implements stack.LinkEndpoint.WritePacket. It is called by // higher-level protocols to write packets. It only forwards packets to the // lower endpoint if Wait or WaitWrite haven't been called. -func (e *Endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { +func (e *Endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { if !e.writeGate.Enter() { return nil } @@ -121,7 +121,7 @@ func (e *Endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip // WritePackets implements stack.LinkEndpoint.WritePackets. It is called by // higher-level protocols to write packets. It only forwards packets to the // lower endpoint if Wait or WaitWrite haven't been called. 
-func (e *Endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *Endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { if !e.writeGate.Enter() { return pkts.Len(), nil } diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go index b139de7dd..e368a9eaa 100644 --- a/pkg/tcpip/link/waitable/waitable_test.go +++ b/pkg/tcpip/link/waitable/waitable_test.go @@ -69,13 +69,13 @@ func (e *countedEndpoint) LinkAddress() tcpip.LinkAddress { return e.linkAddr } -func (e *countedEndpoint) WritePacket(stack.RouteInfo, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) *tcpip.Error { +func (e *countedEndpoint) WritePacket(stack.RouteInfo, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) tcpip.Error { e.writeCount++ return nil } // WritePackets implements stack.LinkEndpoint.WritePackets. 
-func (e *countedEndpoint) WritePackets(_ stack.RouteInfo, _ *stack.GSO, pkts stack.PacketBufferList, _ tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *countedEndpoint) WritePackets(_ stack.RouteInfo, _ *stack.GSO, pkts stack.PacketBufferList, _ tcpip.NetworkProtocolNumber) (int, tcpip.Error) { e.writeCount += pkts.Len() return pkts.Len(), nil } diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD index 9ebf31b78..0caa65251 100644 --- a/pkg/tcpip/network/BUILD +++ b/pkg/tcpip/network/BUILD @@ -25,5 +25,6 @@ go_test( "//pkg/tcpip/transport/icmp", "//pkg/tcpip/transport/tcp", "//pkg/tcpip/transport/udp", + "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD index c7ab876bf..933845269 100644 --- a/pkg/tcpip/network/arp/BUILD +++ b/pkg/tcpip/network/arp/BUILD @@ -10,7 +10,6 @@ go_library( ], visibility = ["//visibility:public"], deps = [ - "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 6bc8c5c02..0d7fadc31 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -22,7 +22,6 @@ import ( "reflect" "sync/atomic" - "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -35,6 +34,8 @@ const ( ProtocolNumber = header.ARPProtocolNumber ) +var _ stack.LinkAddressResolver = (*endpoint)(nil) + // ARP endpoints need to implement stack.NetworkEndpoint because the stack // considers the layer above the link-layer a network layer; the only // facility provided by the stack to deliver packets to a layer above @@ -49,15 +50,13 @@ type endpoint struct { // Must be accessed using atomic operations. 
enabled uint32 - nic stack.NetworkInterface - linkAddrCache stack.LinkAddressCache - nud stack.NUDHandler - stats sharedStats + nic stack.NetworkInterface + stats sharedStats } -func (e *endpoint) Enable() *tcpip.Error { +func (e *endpoint) Enable() tcpip.Error { if !e.nic.Enabled() { - return tcpip.ErrNotPermitted + return &tcpip.ErrNotPermitted{} } e.setEnabled(true) @@ -101,12 +100,10 @@ func (e *endpoint) MaxHeaderLength() uint16 { return e.nic.MaxHeaderLength() + header.ARPSize } -func (e *endpoint) Close() { - e.protocol.forgetEndpoint(e.nic.ID()) -} +func (*endpoint) Close() {} -func (*endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, *stack.PacketBuffer) *tcpip.Error { - return tcpip.ErrNotSupported +func (*endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, *stack.PacketBuffer) tcpip.Error { + return &tcpip.ErrNotSupported{} } // NetworkProtocolNumber implements stack.NetworkEndpoint.NetworkProtocolNumber. @@ -115,12 +112,12 @@ func (*endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { } // WritePackets implements stack.NetworkEndpoint.WritePackets. 
-func (*endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList, stack.NetworkHeaderParams) (int, *tcpip.Error) { - return 0, tcpip.ErrNotSupported +func (*endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList, stack.NetworkHeaderParams) (int, tcpip.Error) { + return 0, &tcpip.ErrNotSupported{} } -func (*endpoint) WriteHeaderIncludedPacket(*stack.Route, *stack.PacketBuffer) *tcpip.Error { - return tcpip.ErrNotSupported +func (*endpoint) WriteHeaderIncludedPacket(*stack.Route, *stack.PacketBuffer) tcpip.Error { + return &tcpip.ErrNotSupported{} } func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { @@ -151,10 +148,12 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { remoteAddr := tcpip.Address(h.ProtocolAddressSender()) remoteLinkAddr := tcpip.LinkAddress(h.HardwareAddressSender()) - if e.nud == nil { - e.linkAddrCache.AddLinkAddress(remoteAddr, remoteLinkAddr) - } else { - e.nud.HandleProbe(remoteAddr, ProtocolNumber, remoteLinkAddr, e.protocol) + switch err := e.nic.HandleNeighborProbe(header.IPv4ProtocolNumber, remoteAddr, remoteLinkAddr); err.(type) { + case nil: + case *tcpip.ErrNotSupported: + // The stack may support ARP but the NIC may not need link resolution. + default: + panic(fmt.Sprintf("unexpected error when informing NIC of neighbor probe message: %s", err)) } respPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -195,14 +194,9 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { addr := tcpip.Address(h.ProtocolAddressSender()) linkAddr := tcpip.LinkAddress(h.HardwareAddressSender()) - if e.nud == nil { - e.linkAddrCache.AddLinkAddress(addr, linkAddr) - return - } - // The solicited, override, and isRouter flags are not available for ARP; // they are only available for IPv6 Neighbor Advertisements. 
- e.nud.HandleConfirmation(addr, linkAddr, stack.ReachabilityConfirmationFlags{ + switch err := e.nic.HandleNeighborConfirmation(header.IPv4ProtocolNumber, addr, linkAddr, stack.ReachabilityConfirmationFlags{ // Solicited and unsolicited (also referred to as gratuitous) ARP Replies // are handled equivalently to a solicited Neighbor Advertisement. Solicited: true, @@ -211,7 +205,13 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { Override: false, // ARP does not distinguish between router and non-router hosts. IsRouter: false, - }) + }); err.(type) { + case nil: + case *tcpip.ErrNotSupported: + // The stack may support ARP but the NIC may not need link resolution. + default: + panic(fmt.Sprintf("unexpected error when informing NIC of neighbor confirmation message: %s", err)) + } } } @@ -221,19 +221,10 @@ func (e *endpoint) Stats() stack.NetworkEndpointStats { } var _ stack.NetworkProtocol = (*protocol)(nil) -var _ stack.LinkAddressResolver = (*protocol)(nil) // protocol implements stack.NetworkProtocol and stack.LinkAddressResolver. type protocol struct { stack *stack.Stack - - mu struct { - sync.RWMutex - - // eps is keyed by NICID to allow protocol methods to retrieve the correct - // endpoint depending on the NIC. 
- eps map[tcpip.NICID]*endpoint - } } func (p *protocol) Number() tcpip.NetworkProtocolNumber { return ProtocolNumber } @@ -244,12 +235,10 @@ func (*protocol) ParseAddresses(buffer.View) (src, dst tcpip.Address) { return "", "" } -func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.LinkAddressCache, nud stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint { +func (p *protocol) NewEndpoint(nic stack.NetworkInterface, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint { e := &endpoint{ - protocol: p, - nic: nic, - linkAddrCache: linkAddrCache, - nud: nud, + protocol: p, + nic: nic, } tcpip.InitStatCounters(reflect.ValueOf(&e.stats.localStats).Elem()) @@ -257,60 +246,43 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.L stackStats := p.stack.Stats() e.stats.arp.init(&e.stats.localStats.ARP, &stackStats.ARP) - p.mu.Lock() - p.mu.eps[nic.ID()] = e - p.mu.Unlock() - return e } -func (p *protocol) forgetEndpoint(nicID tcpip.NICID) { - p.mu.Lock() - defer p.mu.Unlock() - delete(p.mu.eps, nicID) -} - // LinkAddressProtocol implements stack.LinkAddressResolver.LinkAddressProtocol. -func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { +func (*endpoint) LinkAddressProtocol() tcpip.NetworkProtocolNumber { return header.IPv4ProtocolNumber } // LinkAddressRequest implements stack.LinkAddressResolver.LinkAddressRequest. 
-func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic stack.NetworkInterface) *tcpip.Error { - nicID := nic.ID() - - p.mu.Lock() - netEP, ok := p.mu.eps[nicID] - p.mu.Unlock() - if !ok { - return tcpip.ErrNotConnected - } +func (e *endpoint) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) tcpip.Error { + nicID := e.nic.ID() - stats := netEP.stats.arp + stats := e.stats.arp if len(remoteLinkAddr) == 0 { remoteLinkAddr = header.EthernetBroadcastAddress } if len(localAddr) == 0 { - addr, ok := p.stack.GetMainNICAddress(nicID, header.IPv4ProtocolNumber) + addr, ok := e.protocol.stack.GetMainNICAddress(nicID, header.IPv4ProtocolNumber) if !ok { - return tcpip.ErrUnknownNICID + return &tcpip.ErrUnknownNICID{} } if len(addr.Address) == 0 { stats.outgoingRequestInterfaceHasNoLocalAddressErrors.Increment() - return tcpip.ErrNetworkUnreachable + return &tcpip.ErrNetworkUnreachable{} } localAddr = addr.Address - } else if p.stack.CheckLocalAddress(nicID, header.IPv4ProtocolNumber, localAddr) == 0 { + } else if e.protocol.stack.CheckLocalAddress(nicID, header.IPv4ProtocolNumber, localAddr) == 0 { stats.outgoingRequestBadLocalAddressErrors.Increment() - return tcpip.ErrBadLocalAddress + return &tcpip.ErrBadLocalAddress{} } pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(nic.MaxHeaderLength()) + header.ARPSize, + ReserveHeaderBytes: int(e.nic.MaxHeaderLength()) + header.ARPSize, }) h := header.ARP(pkt.NetworkHeader().Push(header.ARPSize)) pkt.NetworkProtocolNumber = ProtocolNumber @@ -318,14 +290,14 @@ func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot h.SetOp(header.ARPRequest) // TODO(gvisor.dev/issue/4582): check copied length once TAP devices have a // link address. 
- _ = copy(h.HardwareAddressSender(), nic.LinkAddress()) + _ = copy(h.HardwareAddressSender(), e.nic.LinkAddress()) if n := copy(h.ProtocolAddressSender(), localAddr); n != header.IPv4AddressSize { panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.IPv4AddressSize)) } if n := copy(h.ProtocolAddressTarget(), targetAddr); n != header.IPv4AddressSize { panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.IPv4AddressSize)) } - if err := nic.WritePacketToRemote(remoteLinkAddr, nil /* gso */, ProtocolNumber, pkt); err != nil { + if err := e.nic.WritePacketToRemote(remoteLinkAddr, nil /* gso */, ProtocolNumber, pkt); err != nil { stats.outgoingRequestsDropped.Increment() return err } @@ -334,7 +306,7 @@ func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot } // ResolveStaticAddress implements stack.LinkAddressResolver.ResolveStaticAddress. -func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { +func (*endpoint) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { if addr == header.IPv4Broadcast { return header.EthernetBroadcastAddress, true } @@ -345,13 +317,13 @@ func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bo } // SetOption implements stack.NetworkProtocol.SetOption. -func (*protocol) SetOption(tcpip.SettableNetworkProtocolOption) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption +func (*protocol) SetOption(tcpip.SettableNetworkProtocolOption) tcpip.Error { + return &tcpip.ErrUnknownProtocolOption{} } // Option implements stack.NetworkProtocol.Option. -func (*protocol) Option(tcpip.GettableNetworkProtocolOption) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption +func (*protocol) Option(tcpip.GettableNetworkProtocolOption) tcpip.Error { + return &tcpip.ErrUnknownProtocolOption{} } // Close implements stack.TransportProtocol.Close. 
@@ -369,9 +341,5 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu func NewProtocol(s *stack.Stack) stack.NetworkProtocol { return &protocol{ stack: s, - mu: struct { - sync.RWMutex - eps map[tcpip.NICID]*endpoint - }{eps: make(map[tcpip.NICID]*endpoint)}, } } diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index 001fca727..24357e15d 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -125,8 +125,8 @@ func (d *arpDispatcher) OnNeighborRemoved(nicID tcpip.NICID, entry stack.Neighbo func (d *arpDispatcher) waitForEvent(ctx context.Context, want eventInfo) error { select { case got := <-d.C: - if diff := cmp.Diff(got, want, cmp.AllowUnexported(got), cmpopts.IgnoreFields(stack.NeighborEntry{}, "UpdatedAtNanos")); diff != "" { - return fmt.Errorf("got invalid event (-got +want):\n%s", diff) + if diff := cmp.Diff(want, got, cmp.AllowUnexported(got), cmpopts.IgnoreFields(stack.NeighborEntry{}, "UpdatedAtNanos")); diff != "" { + return fmt.Errorf("got invalid event (-want +got):\n%s", diff) } case <-ctx.Done(): return fmt.Errorf("%s for %s", ctx.Err(), want) @@ -491,9 +491,9 @@ func TestDirectRequestWithNeighborCache(t *testing.T) { t.Fatal(err) } - neighbors, err := c.s.Neighbors(nicID) + neighbors, err := c.s.Neighbors(nicID, ipv4.ProtocolNumber) if err != nil { - t.Fatalf("c.s.Neighbors(%d): %s", nicID, err) + t.Fatalf("c.s.Neighbors(%d, %d): %s", nicID, ipv4.ProtocolNumber, err) } neighborByAddr := make(map[tcpip.Address]stack.NeighborEntry) @@ -530,52 +530,19 @@ func TestDirectRequestWithNeighborCache(t *testing.T) { } } -var _ stack.NetworkInterface = (*testInterface)(nil) +var _ stack.LinkEndpoint = (*testLinkEndpoint)(nil) -type testInterface struct { +type testLinkEndpoint struct { stack.LinkEndpoint - nicID tcpip.NICID - - writeErr *tcpip.Error -} - -func (t *testInterface) ID() tcpip.NICID { - return t.nicID -} - -func (*testInterface) IsLoopback() bool { - 
return false -} - -func (*testInterface) Name() string { - return "" -} - -func (*testInterface) Enabled() bool { - return true -} - -func (*testInterface) Promiscuous() bool { - return false -} - -func (t *testInterface) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { - return t.LinkEndpoint.WritePacket(r.Fields(), gso, protocol, pkt) -} - -func (t *testInterface) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { - return t.LinkEndpoint.WritePackets(r.Fields(), gso, pkts, protocol) + writeErr tcpip.Error } -func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { +func (t *testLinkEndpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { if t.writeErr != nil { return t.writeErr } - var r stack.RouteInfo - r.NetProto = protocol - r.RemoteLinkAddress = remoteLinkAddr return t.LinkEndpoint.WritePacket(r, gso, protocol, pkt) } @@ -589,8 +556,8 @@ func TestLinkAddressRequest(t *testing.T) { nicAddr tcpip.Address localAddr tcpip.Address remoteLinkAddr tcpip.LinkAddress - linkErr *tcpip.Error - expectedErr *tcpip.Error + linkErr tcpip.Error + expectedErr tcpip.Error expectedLocalAddr tcpip.Address expectedRemoteLinkAddr tcpip.LinkAddress expectedRequestsSent uint64 @@ -651,7 +618,7 @@ func TestLinkAddressRequest(t *testing.T) { nicAddr: stackAddr, localAddr: testAddr, remoteLinkAddr: remoteLinkAddr, - expectedErr: tcpip.ErrBadLocalAddress, + expectedErr: &tcpip.ErrBadLocalAddress{}, expectedRequestsSent: 0, expectedRequestBadLocalAddressErrors: 1, expectedRequestInterfaceHasNoLocalAddressErrors: 0, @@ -662,7 +629,7 @@ func TestLinkAddressRequest(t *testing.T) { nicAddr: stackAddr, localAddr: testAddr, remoteLinkAddr: "", - 
expectedErr: tcpip.ErrBadLocalAddress, + expectedErr: &tcpip.ErrBadLocalAddress{}, expectedRequestsSent: 0, expectedRequestBadLocalAddressErrors: 1, expectedRequestInterfaceHasNoLocalAddressErrors: 0, @@ -673,7 +640,7 @@ func TestLinkAddressRequest(t *testing.T) { nicAddr: "", localAddr: "", remoteLinkAddr: remoteLinkAddr, - expectedErr: tcpip.ErrNetworkUnreachable, + expectedErr: &tcpip.ErrNetworkUnreachable{}, expectedRequestsSent: 0, expectedRequestBadLocalAddressErrors: 0, expectedRequestInterfaceHasNoLocalAddressErrors: 1, @@ -684,7 +651,7 @@ func TestLinkAddressRequest(t *testing.T) { nicAddr: "", localAddr: "", remoteLinkAddr: "", - expectedErr: tcpip.ErrNetworkUnreachable, + expectedErr: &tcpip.ErrNetworkUnreachable{}, expectedRequestsSent: 0, expectedRequestBadLocalAddressErrors: 0, expectedRequestInterfaceHasNoLocalAddressErrors: 1, @@ -695,8 +662,8 @@ func TestLinkAddressRequest(t *testing.T) { nicAddr: stackAddr, localAddr: stackAddr, remoteLinkAddr: remoteLinkAddr, - linkErr: tcpip.ErrInvalidEndpointState, - expectedErr: tcpip.ErrInvalidEndpointState, + linkErr: &tcpip.ErrInvalidEndpointState{}, + expectedErr: &tcpip.ErrInvalidEndpointState{}, expectedRequestsSent: 0, expectedRequestBadLocalAddressErrors: 0, expectedRequestInterfaceHasNoLocalAddressErrors: 0, @@ -709,31 +676,31 @@ func TestLinkAddressRequest(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol}, }) - p := s.NetworkProtocolInstance(arp.ProtocolNumber) - linkRes, ok := p.(stack.LinkAddressResolver) - if !ok { - t.Fatal("expected ARP protocol to implement stack.LinkAddressResolver") - } - linkEP := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr) - if err := s.CreateNIC(nicID, linkEP); err != nil { + if err := s.CreateNIC(nicID, &testLinkEndpoint{LinkEndpoint: linkEP, writeErr: test.linkErr}); err != nil { t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) } + ep, err := s.GetNetworkEndpoint(nicID, 
arp.ProtocolNumber) + if err != nil { + t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, arp.ProtocolNumber, err) + } + linkRes, ok := ep.(stack.LinkAddressResolver) + if !ok { + t.Fatalf("expected %T to implement stack.LinkAddressResolver", ep) + } + if len(test.nicAddr) != 0 { if err := s.AddAddress(nicID, ipv4.ProtocolNumber, test.nicAddr); err != nil { t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, test.nicAddr, err) } } - // We pass a test network interface to LinkAddressRequest with the same - // NIC ID and link endpoint used by the NIC we created earlier so that we - // can mock a link address request and observe the packets sent to the - // link endpoint even though the stack uses the real NIC to validate the - // local address. - iface := testInterface{LinkEndpoint: linkEP, nicID: nicID, writeErr: test.linkErr} - if err := linkRes.LinkAddressRequest(remoteAddr, test.localAddr, test.remoteLinkAddr, &iface); err != test.expectedErr { - t.Fatalf("got p.LinkAddressRequest(%s, %s, %s, _) = %s, want = %s", remoteAddr, test.localAddr, test.remoteLinkAddr, err, test.expectedErr) + { + err := linkRes.LinkAddressRequest(remoteAddr, test.localAddr, test.remoteLinkAddr) + if diff := cmp.Diff(test.expectedErr, err); diff != "" { + t.Fatalf("unexpected error from p.LinkAddressRequest(%s, %s, %s, _), (-want, +got):\n%s", remoteAddr, test.localAddr, test.remoteLinkAddr, diff) + } } if got := s.Stats().ARP.OutgoingRequestsSent.Value(); got != test.expectedRequestsSent { @@ -781,18 +748,3 @@ func TestLinkAddressRequest(t *testing.T) { }) } } - -func TestLinkAddressRequestWithoutNIC(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol}, - }) - p := s.NetworkProtocolInstance(arp.ProtocolNumber) - linkRes, ok := p.(stack.LinkAddressResolver) - if !ok { - t.Fatal("expected ARP protocol to implement stack.LinkAddressResolver") - } - - if err := 
linkRes.LinkAddressRequest(remoteAddr, "", remoteLinkAddr, &testInterface{nicID: nicID}); err != tcpip.ErrNotConnected { - t.Fatalf("got p.LinkAddressRequest(%s, %s, %s, _) = %s, want = %s", remoteAddr, "", remoteLinkAddr, err, tcpip.ErrNotConnected) - } -} diff --git a/pkg/tcpip/network/arp/stats_test.go b/pkg/tcpip/network/arp/stats_test.go index 036fdf739..65c708ac4 100644 --- a/pkg/tcpip/network/arp/stats_test.go +++ b/pkg/tcpip/network/arp/stats_test.go @@ -34,55 +34,13 @@ func (t *testInterface) ID() tcpip.NICID { return t.nicID } -func knownNICIDs(proto *protocol) []tcpip.NICID { - var nicIDs []tcpip.NICID - - for k := range proto.mu.eps { - nicIDs = append(nicIDs, k) - } - - return nicIDs -} - -func TestClearEndpointFromProtocolOnClose(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - }) - proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol) - nic := testInterface{nicID: 1} - ep := proto.NewEndpoint(&nic, nil, nil, nil).(*endpoint) - var nicIDs []tcpip.NICID - - proto.mu.Lock() - foundEP, hasEndpointBeforeClose := proto.mu.eps[nic.ID()] - nicIDs = knownNICIDs(proto) - proto.mu.Unlock() - - if !hasEndpointBeforeClose { - t.Fatalf("expected to find the nic id %d in the protocol's endpoint map (%v)", nic.ID(), nicIDs) - } - if foundEP != ep { - t.Fatalf("found an incorrect endpoint mapped to nic id %d", nic.ID()) - } - - ep.Close() - - proto.mu.Lock() - _, hasEndpointAfterClose := proto.mu.eps[nic.ID()] - nicIDs = knownNICIDs(proto) - proto.mu.Unlock() - if hasEndpointAfterClose { - t.Fatalf("unexpectedly found an endpoint mapped to the nic id %d in the protocol's known nic ids (%v)", nic.ID(), nicIDs) - } -} - func TestMultiCounterStatsInitialization(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, }) proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol) var nic testInterface - ep := proto.NewEndpoint(&nic, nil, nil, 
nil).(*endpoint) + ep := proto.NewEndpoint(&nic, nil).(*endpoint) // At this point, the Stack's stats and the NetworkEndpoint's stats are // expected to be bound by a MultiCounterStat. refStack := s.Stats() diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go index 1af87d713..243738951 100644 --- a/pkg/tcpip/network/fragmentation/fragmentation.go +++ b/pkg/tcpip/network/fragmentation/fragmentation.go @@ -84,7 +84,7 @@ type Fragmentation struct { lowLimit int reassemblers map[FragmentID]*reassembler rList reassemblerList - size int + memSize int timeout time.Duration blockSize uint16 clock tcpip.Clock @@ -156,22 +156,22 @@ func NewFragmentation(blockSize uint16, highMemoryLimit, lowMemoryLimit int, rea // the protocol to identify a fragment. func (f *Fragmentation) Process( id FragmentID, first, last uint16, more bool, proto uint8, pkt *stack.PacketBuffer) ( - buffer.VectorisedView, uint8, bool, error) { + *stack.PacketBuffer, uint8, bool, error) { if first > last { - return buffer.VectorisedView{}, 0, false, fmt.Errorf("first=%d is greater than last=%d: %w", first, last, ErrInvalidArgs) + return nil, 0, false, fmt.Errorf("first=%d is greater than last=%d: %w", first, last, ErrInvalidArgs) } if first%f.blockSize != 0 { - return buffer.VectorisedView{}, 0, false, fmt.Errorf("first=%d is not a multiple of block size=%d: %w", first, f.blockSize, ErrInvalidArgs) + return nil, 0, false, fmt.Errorf("first=%d is not a multiple of block size=%d: %w", first, f.blockSize, ErrInvalidArgs) } fragmentSize := last - first + 1 if more && fragmentSize%f.blockSize != 0 { - return buffer.VectorisedView{}, 0, false, fmt.Errorf("fragment size=%d bytes is not a multiple of block size=%d on non-final fragment: %w", fragmentSize, f.blockSize, ErrInvalidArgs) + return nil, 0, false, fmt.Errorf("fragment size=%d bytes is not a multiple of block size=%d on non-final fragment: %w", fragmentSize, f.blockSize, ErrInvalidArgs) } if l := 
pkt.Data.Size(); l != int(fragmentSize) { - return buffer.VectorisedView{}, 0, false, fmt.Errorf("got fragment size=%d bytes not equal to the expected fragment size=%d bytes (first=%d last=%d): %w", l, fragmentSize, first, last, ErrInvalidArgs) + return nil, 0, false, fmt.Errorf("got fragment size=%d bytes not equal to the expected fragment size=%d bytes (first=%d last=%d): %w", l, fragmentSize, first, last, ErrInvalidArgs) } f.mu.Lock() @@ -190,24 +190,24 @@ func (f *Fragmentation) Process( } f.mu.Unlock() - res, firstFragmentProto, done, consumed, err := r.process(first, last, more, proto, pkt) + resPkt, firstFragmentProto, done, memConsumed, err := r.process(first, last, more, proto, pkt) if err != nil { // We probably got an invalid sequence of fragments. Just // discard the reassembler and move on. f.mu.Lock() f.release(r, false /* timedOut */) f.mu.Unlock() - return buffer.VectorisedView{}, 0, false, fmt.Errorf("fragmentation processing error: %w", err) + return nil, 0, false, fmt.Errorf("fragmentation processing error: %w", err) } f.mu.Lock() - f.size += consumed + f.memSize += memConsumed if done { f.release(r, false /* timedOut */) } // Evict reassemblers if we are consuming more memory than highLimit until // we reach lowLimit. 
- if f.size > f.highLimit { - for f.size > f.lowLimit { + if f.memSize > f.highLimit { + for f.memSize > f.lowLimit { tail := f.rList.Back() if tail == nil { break @@ -216,7 +216,7 @@ func (f *Fragmentation) Process( } } f.mu.Unlock() - return res, firstFragmentProto, done, nil + return resPkt, firstFragmentProto, done, nil } func (f *Fragmentation) release(r *reassembler, timedOut bool) { @@ -228,10 +228,10 @@ func (f *Fragmentation) release(r *reassembler, timedOut bool) { delete(f.reassemblers, r.id) f.rList.Remove(r) - f.size -= r.size - if f.size < 0 { - log.Printf("memory counter < 0 (%d), this is an accounting bug that requires investigation", f.size) - f.size = 0 + f.memSize -= r.memSize + if f.memSize < 0 { + log.Printf("memory counter < 0 (%d), this is an accounting bug that requires investigation", f.memSize) + f.memSize = 0 } if h := f.timeoutHandler; timedOut && h != nil { diff --git a/pkg/tcpip/network/fragmentation/fragmentation_test.go b/pkg/tcpip/network/fragmentation/fragmentation_test.go index 3a79688a8..905bbc19b 100644 --- a/pkg/tcpip/network/fragmentation/fragmentation_test.go +++ b/pkg/tcpip/network/fragmentation/fragmentation_test.go @@ -16,7 +16,6 @@ package fragmentation import ( "errors" - "reflect" "testing" "time" @@ -112,20 +111,20 @@ func TestFragmentationProcess(t *testing.T) { f := NewFragmentation(minBlockSize, 1024, 512, reassembleTimeout, &faketime.NullClock{}, nil) firstFragmentProto := c.in[0].proto for i, in := range c.in { - vv, proto, done, err := f.Process(in.id, in.first, in.last, in.more, in.proto, in.pkt) + resPkt, proto, done, err := f.Process(in.id, in.first, in.last, in.more, in.proto, in.pkt) if err != nil { t.Fatalf("f.Process(%+v, %d, %d, %t, %d, %#v) failed: %s", in.id, in.first, in.last, in.more, in.proto, in.pkt, err) } - if !reflect.DeepEqual(vv, c.out[i].vv) { - t.Errorf("got Process(%+v, %d, %d, %t, %d, %#v) = (%X, _, _, _), want = (%X, _, _, _)", - in.id, in.first, in.last, in.more, in.proto, in.pkt, 
vv.ToView(), c.out[i].vv.ToView()) - } if done != c.out[i].done { t.Errorf("got Process(%+v, %d, %d, %t, %d, _) = (_, _, %t, _), want = (_, _, %t, _)", in.id, in.first, in.last, in.more, in.proto, done, c.out[i].done) } if c.out[i].done { + if diff := cmp.Diff(c.out[i].vv.ToOwnedView(), resPkt.Data.ToOwnedView()); diff != "" { + t.Errorf("got Process(%+v, %d, %d, %t, %d, %#v) result mismatch (-want, +got):\n%s", + in.id, in.first, in.last, in.more, in.proto, in.pkt, diff) + } if firstFragmentProto != proto { t.Errorf("got Process(%+v, %d, %d, %t, %d, _) = (_, %d, _, _), want = (_, %d, _, _)", in.id, in.first, in.last, in.more, in.proto, proto, firstFragmentProto) @@ -173,9 +172,17 @@ func TestReassemblingTimeout(t *testing.T) { // reassembly is done after the fragment is processd. expectDone bool - // sizeAfterEvent is the expected size of the fragmentation instance after - // the event. - sizeAfterEvent int + // memSizeAfterEvent is the expected memory size of the fragmentation + // instance after the event. 
+ memSizeAfterEvent int + } + + memSizeOfFrags := func(frags ...*fragment) int { + var size int + for _, frag := range frags { + size += pkt(len(frag.data), frag.data).MemSize() + } + return size } half1 := &fragment{first: 0, last: 0, more: true, data: "0"} @@ -189,16 +196,16 @@ func TestReassemblingTimeout(t *testing.T) { name: "half1 and half2 are reassembled successfully", events: []event{ { - name: "half1", - fragment: half1, - expectDone: false, - sizeAfterEvent: 1, + name: "half1", + fragment: half1, + expectDone: false, + memSizeAfterEvent: memSizeOfFrags(half1), }, { - name: "half2", - fragment: half2, - expectDone: true, - sizeAfterEvent: 0, + name: "half2", + fragment: half2, + expectDone: true, + memSizeAfterEvent: 0, }, }, }, @@ -206,36 +213,36 @@ func TestReassemblingTimeout(t *testing.T) { name: "half1 timeout, half2 timeout", events: []event{ { - name: "half1", - fragment: half1, - expectDone: false, - sizeAfterEvent: 1, + name: "half1", + fragment: half1, + expectDone: false, + memSizeAfterEvent: memSizeOfFrags(half1), }, { - name: "half1 just before reassembly timeout", - clockAdvance: reassemblyTimeout - 1, - sizeAfterEvent: 1, + name: "half1 just before reassembly timeout", + clockAdvance: reassemblyTimeout - 1, + memSizeAfterEvent: memSizeOfFrags(half1), }, { - name: "half1 reassembly timeout", - clockAdvance: 1, - sizeAfterEvent: 0, + name: "half1 reassembly timeout", + clockAdvance: 1, + memSizeAfterEvent: 0, }, { - name: "half2", - fragment: half2, - expectDone: false, - sizeAfterEvent: 1, + name: "half2", + fragment: half2, + expectDone: false, + memSizeAfterEvent: memSizeOfFrags(half2), }, { - name: "half2 just before reassembly timeout", - clockAdvance: reassemblyTimeout - 1, - sizeAfterEvent: 1, + name: "half2 just before reassembly timeout", + clockAdvance: reassemblyTimeout - 1, + memSizeAfterEvent: memSizeOfFrags(half2), }, { - name: "half2 reassembly timeout", - clockAdvance: 1, - sizeAfterEvent: 0, + name: "half2 reassembly 
timeout", + clockAdvance: 1, + memSizeAfterEvent: 0, }, }, }, @@ -255,8 +262,8 @@ func TestReassemblingTimeout(t *testing.T) { t.Fatalf("%s: got done = %t, want = %t", event.name, done, event.expectDone) } } - if got, want := f.size, event.sizeAfterEvent; got != want { - t.Errorf("%s: got f.size = %d, want = %d", event.name, got, want) + if got, want := f.memSize, event.memSizeAfterEvent; got != want { + t.Errorf("%s: got f.memSize = %d, want = %d", event.name, got, want) } } }) @@ -264,7 +271,9 @@ func TestReassemblingTimeout(t *testing.T) { } func TestMemoryLimits(t *testing.T) { - f := NewFragmentation(minBlockSize, 3, 1, reassembleTimeout, &faketime.NullClock{}, nil) + lowLimit := pkt(1, "0").MemSize() + highLimit := 3 * lowLimit // Allow at most 3 such packets. + f := NewFragmentation(minBlockSize, highLimit, lowLimit, reassembleTimeout, &faketime.NullClock{}, nil) // Send first fragment with id = 0. f.Process(FragmentID{ID: 0}, 0, 0, true, 0xFF, pkt(1, "0")) // Send first fragment with id = 1. @@ -288,15 +297,14 @@ func TestMemoryLimits(t *testing.T) { } func TestMemoryLimitsIgnoresDuplicates(t *testing.T) { - f := NewFragmentation(minBlockSize, 1, 0, reassembleTimeout, &faketime.NullClock{}, nil) + memSize := pkt(1, "0").MemSize() + f := NewFragmentation(minBlockSize, memSize, 0, reassembleTimeout, &faketime.NullClock{}, nil) // Send first fragment with id = 0. f.Process(FragmentID{}, 0, 0, true, 0xFF, pkt(1, "0")) // Send the same packet again. 
f.Process(FragmentID{}, 0, 0, true, 0xFF, pkt(1, "0")) - got := f.size - want := 1 - if got != want { + if got, want := f.memSize, memSize; got != want { t.Errorf("Wrong size, duplicates are not handled correctly: got=%d, want=%d.", got, want) } } diff --git a/pkg/tcpip/network/fragmentation/reassembler.go b/pkg/tcpip/network/fragmentation/reassembler.go index 9b20bb1d8..933d63d32 100644 --- a/pkg/tcpip/network/fragmentation/reassembler.go +++ b/pkg/tcpip/network/fragmentation/reassembler.go @@ -20,7 +20,6 @@ import ( "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -29,13 +28,15 @@ type hole struct { last uint16 filled bool final bool - data buffer.View + // pkt is the fragment packet if hole is filled. We keep the whole pkt rather + // than the fragmented payload to prevent binding to specific buffer types. + pkt *stack.PacketBuffer } type reassembler struct { reassemblerEntry id FragmentID - size int + memSize int proto uint8 mu sync.Mutex holes []hole @@ -59,18 +60,18 @@ func newReassembler(id FragmentID, clock tcpip.Clock) *reassembler { return r } -func (r *reassembler) process(first, last uint16, more bool, proto uint8, pkt *stack.PacketBuffer) (buffer.VectorisedView, uint8, bool, int, error) { +func (r *reassembler) process(first, last uint16, more bool, proto uint8, pkt *stack.PacketBuffer) (*stack.PacketBuffer, uint8, bool, int, error) { r.mu.Lock() defer r.mu.Unlock() if r.done { // A concurrent goroutine might have already reassembled // the packet and emptied the heap while this goroutine // was waiting on the mutex. We don't have to do anything in this case. 
- return buffer.VectorisedView{}, 0, false, 0, nil + return nil, 0, false, 0, nil } var holeFound bool - var consumed int + var memConsumed int for i := range r.holes { currentHole := &r.holes[i] @@ -90,12 +91,12 @@ func (r *reassembler) process(first, last uint16, more bool, proto uint8, pkt *s // https://github.com/torvalds/linux/blob/38525c6/net/ipv4/inet_fragment.c#L349 if first < currentHole.first || currentHole.last < last { // Incoming fragment only partially fits in the free hole. - return buffer.VectorisedView{}, 0, false, 0, ErrFragmentOverlap + return nil, 0, false, 0, ErrFragmentOverlap } if !more { if !currentHole.final || currentHole.filled && currentHole.last != last { // We have another final fragment, which does not perfectly overlap. - return buffer.VectorisedView{}, 0, false, 0, ErrFragmentConflict + return nil, 0, false, 0, ErrFragmentConflict } } @@ -124,16 +125,15 @@ func (r *reassembler) process(first, last uint16, more bool, proto uint8, pkt *s }) currentHole.final = false } - v := pkt.Data.ToOwnedView() - consumed = v.Size() - r.size += consumed + memConsumed = pkt.MemSize() + r.memSize += memConsumed // Update the current hole to precisely match the incoming fragment. r.holes[i] = hole{ first: first, last: last, filled: true, final: currentHole.final, - data: v, + pkt: pkt, } r.filled++ // For IPv6, it is possible to have different Protocol values between @@ -153,25 +153,24 @@ func (r *reassembler) process(first, last uint16, more bool, proto uint8, pkt *s } if !holeFound { // Incoming fragment is beyond end. - return buffer.VectorisedView{}, 0, false, 0, ErrFragmentConflict + return nil, 0, false, 0, ErrFragmentConflict } // Check if all the holes have been filled and we are ready to reassemble. 
if r.filled < len(r.holes) { - return buffer.VectorisedView{}, 0, false, consumed, nil + return nil, 0, false, memConsumed, nil } sort.Slice(r.holes, func(i, j int) bool { return r.holes[i].first < r.holes[j].first }) - var size int - views := make([]buffer.View, 0, len(r.holes)) - for _, hole := range r.holes { - views = append(views, hole.data) - size += hole.data.Size() + resPkt := r.holes[0].pkt + for i := 1; i < len(r.holes); i++ { + fragPkt := r.holes[i].pkt + fragPkt.Data.ReadToVV(&resPkt.Data, fragPkt.Data.Size()) } - return buffer.NewVectorisedView(size, views), r.proto, true, consumed, nil + return resPkt, r.proto, true, memConsumed, nil } func (r *reassembler) checkDoneOrMark() bool { diff --git a/pkg/tcpip/network/fragmentation/reassembler_test.go b/pkg/tcpip/network/fragmentation/reassembler_test.go index 2ff03eeeb..214a93709 100644 --- a/pkg/tcpip/network/fragmentation/reassembler_test.go +++ b/pkg/tcpip/network/fragmentation/reassembler_test.go @@ -15,6 +15,7 @@ package fragmentation import ( + "bytes" "math" "testing" @@ -44,16 +45,21 @@ func TestReassemblerProcess(t *testing.T) { return payload } - pkt := func(size int) *stack.PacketBuffer { + pkt := func(sizes ...int) *stack.PacketBuffer { + var vv buffer.VectorisedView + for _, size := range sizes { + vv.AppendView(v(size)) + } return stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: v(size).ToVectorisedView(), + Data: vv, }) } var tests = []struct { - name string - params []processParams - want []hole + name string + params []processParams + want []hole + wantPkt *stack.PacketBuffer }{ { name: "No fragments", @@ -64,7 +70,7 @@ func TestReassemblerProcess(t *testing.T) { name: "One fragment at beginning", params: []processParams{{first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil}}, want: []hole{ - {first: 0, last: 1, filled: true, final: false, data: v(2)}, + {first: 0, last: 1, filled: true, final: false, pkt: pkt(2)}, {first: 2, last: math.MaxUint16, filled: 
false, final: true}, }, }, @@ -72,7 +78,7 @@ func TestReassemblerProcess(t *testing.T) { name: "One fragment in the middle", params: []processParams{{first: 1, last: 2, more: true, pkt: pkt(2), wantDone: false, wantError: nil}}, want: []hole{ - {first: 1, last: 2, filled: true, final: false, data: v(2)}, + {first: 1, last: 2, filled: true, final: false, pkt: pkt(2)}, {first: 0, last: 0, filled: false, final: false}, {first: 3, last: math.MaxUint16, filled: false, final: true}, }, @@ -81,7 +87,7 @@ func TestReassemblerProcess(t *testing.T) { name: "One fragment at the end", params: []processParams{{first: 1, last: 2, more: false, pkt: pkt(2), wantDone: false, wantError: nil}}, want: []hole{ - {first: 1, last: 2, filled: true, final: true, data: v(2)}, + {first: 1, last: 2, filled: true, final: true, pkt: pkt(2)}, {first: 0, last: 0, filled: false}, }, }, @@ -89,8 +95,9 @@ func TestReassemblerProcess(t *testing.T) { name: "One fragment completing a packet", params: []processParams{{first: 0, last: 1, more: false, pkt: pkt(2), wantDone: true, wantError: nil}}, want: []hole{ - {first: 0, last: 1, filled: true, final: true, data: v(2)}, + {first: 0, last: 1, filled: true, final: true}, }, + wantPkt: pkt(2), }, { name: "Two fragments completing a packet", @@ -99,9 +106,10 @@ func TestReassemblerProcess(t *testing.T) { {first: 2, last: 3, more: false, pkt: pkt(2), wantDone: true, wantError: nil}, }, want: []hole{ - {first: 0, last: 1, filled: true, final: false, data: v(2)}, - {first: 2, last: 3, filled: true, final: true, data: v(2)}, + {first: 0, last: 1, filled: true, final: false}, + {first: 2, last: 3, filled: true, final: true}, }, + wantPkt: pkt(2, 2), }, { name: "Two fragments completing a packet with a duplicate", @@ -111,9 +119,10 @@ func TestReassemblerProcess(t *testing.T) { {first: 2, last: 3, more: false, pkt: pkt(2), wantDone: true, wantError: nil}, }, want: []hole{ - {first: 0, last: 1, filled: true, final: false, data: v(2)}, - {first: 2, last: 3, filled: 
true, final: true, data: v(2)}, + {first: 0, last: 1, filled: true, final: false}, + {first: 2, last: 3, filled: true, final: true}, }, + wantPkt: pkt(2, 2), }, { name: "Two fragments completing a packet with a partial duplicate", @@ -123,9 +132,10 @@ func TestReassemblerProcess(t *testing.T) { {first: 4, last: 5, more: false, pkt: pkt(2), wantDone: true, wantError: nil}, }, want: []hole{ - {first: 0, last: 3, filled: true, final: false, data: v(4)}, - {first: 4, last: 5, filled: true, final: true, data: v(2)}, + {first: 0, last: 3, filled: true, final: false}, + {first: 4, last: 5, filled: true, final: true}, }, + wantPkt: pkt(4, 2), }, { name: "Two overlapping fragments", @@ -134,7 +144,7 @@ func TestReassemblerProcess(t *testing.T) { {first: 5, last: 15, more: false, pkt: pkt(11), wantDone: false, wantError: ErrFragmentOverlap}, }, want: []hole{ - {first: 0, last: 10, filled: true, final: false, data: v(11)}, + {first: 0, last: 10, filled: true, final: false, pkt: pkt(11)}, {first: 11, last: math.MaxUint16, filled: false, final: true}, }, }, @@ -145,7 +155,7 @@ func TestReassemblerProcess(t *testing.T) { {first: 0, last: 9, more: false, pkt: pkt(10), wantDone: false, wantError: ErrFragmentConflict}, }, want: []hole{ - {first: 10, last: 14, filled: true, final: true, data: v(5)}, + {first: 10, last: 14, filled: true, final: true, pkt: pkt(5)}, {first: 0, last: 9, filled: false, final: false}, }, }, @@ -156,7 +166,7 @@ func TestReassemblerProcess(t *testing.T) { {first: 10, last: 14, more: false, pkt: pkt(5), wantDone: false, wantError: nil}, }, want: []hole{ - {first: 5, last: 14, filled: true, final: true, data: v(10)}, + {first: 5, last: 14, filled: true, final: true, pkt: pkt(10)}, {first: 0, last: 4, filled: false, final: false}, }, }, @@ -167,7 +177,7 @@ func TestReassemblerProcess(t *testing.T) { {first: 10, last: 13, more: false, pkt: pkt(4), wantDone: false, wantError: ErrFragmentConflict}, }, want: []hole{ - {first: 5, last: 14, filled: true, final: 
true, data: v(10)}, + {first: 5, last: 14, filled: true, final: true, pkt: pkt(10)}, {first: 0, last: 4, filled: false, final: false}, }, }, @@ -176,14 +186,47 @@ func TestReassemblerProcess(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { r := newReassembler(FragmentID{}, &faketime.NullClock{}) + var resPkt *stack.PacketBuffer + var isDone bool for _, param := range test.params { - _, _, done, _, err := r.process(param.first, param.last, param.more, proto, param.pkt) + pkt, _, done, _, err := r.process(param.first, param.last, param.more, proto, param.pkt) if done != param.wantDone || err != param.wantError { t.Errorf("got r.process(%d, %d, %t, %d, _) = (_, _, %t, _, %v), want = (%t, %v)", param.first, param.last, param.more, proto, done, err, param.wantDone, param.wantError) } + if done { + resPkt = pkt + isDone = true + } + } + + ignorePkt := func(a, b *stack.PacketBuffer) bool { return true } + cmpPktData := func(a, b *stack.PacketBuffer) bool { + if a == nil || b == nil { + return a == b + } + return bytes.Equal(a.Data.ToOwnedView(), b.Data.ToOwnedView()) } - if diff := cmp.Diff(test.want, r.holes, cmp.AllowUnexported(hole{})); diff != "" { - t.Errorf("r.holes mismatch (-want +got):\n%s", diff) + + if isDone { + if diff := cmp.Diff( + test.want, r.holes, + cmp.AllowUnexported(hole{}), + // Do not compare pkt in hole. Data will be altered. 
+ cmp.Comparer(ignorePkt), + ); diff != "" { + t.Errorf("r.holes mismatch (-want +got):\n%s", diff) + } + if diff := cmp.Diff(test.wantPkt, resPkt, cmp.Comparer(cmpPktData)); diff != "" { + t.Errorf("Reassembled pkt mismatch (-want +got):\n%s", diff) + } + } else { + if diff := cmp.Diff( + test.want, r.holes, + cmp.AllowUnexported(hole{}), + cmp.Comparer(cmpPktData), + ); diff != "" { + t.Errorf("r.holes mismatch (-want +got):\n%s", diff) + } } }) } diff --git a/pkg/tcpip/network/ip/generic_multicast_protocol.go b/pkg/tcpip/network/ip/generic_multicast_protocol.go index f2f0e069c..b9f129728 100644 --- a/pkg/tcpip/network/ip/generic_multicast_protocol.go +++ b/pkg/tcpip/network/ip/generic_multicast_protocol.go @@ -126,6 +126,16 @@ type multicastGroupState struct { // // Must not be nil. delayedReportJob *tcpip.Job + + // delayedReportJobFiresAt is the time when the delayed report job will fire. + // + // A zero value indicates that the job is not scheduled. + delayedReportJobFiresAt time.Time +} + +func (m *multicastGroupState) cancelDelayedReportJob() { + m.delayedReportJob.Cancel() + m.delayedReportJobFiresAt = time.Time{} } // GenericMulticastProtocolOptions holds options for the generic multicast @@ -174,10 +184,10 @@ type MulticastGroupProtocol interface { // // Returns false if the caller should queue the report to be sent later. Note, // returning false does not mean that the receiver hit an error. - SendReport(groupAddress tcpip.Address) (sent bool, err *tcpip.Error) + SendReport(groupAddress tcpip.Address) (sent bool, err tcpip.Error) // SendLeave sends a multicast leave for the specified group address. 
- SendLeave(groupAddress tcpip.Address) *tcpip.Error + SendLeave(groupAddress tcpip.Address) tcpip.Error } // GenericMulticastProtocolState is the per interface generic multicast protocol @@ -428,7 +438,7 @@ func (g *GenericMulticastProtocolState) HandleReportLocked(groupAddress tcpip.Ad // on that interface, it stops its timer and does not send a Report for // that address, thus suppressing duplicate reports on the link. if info, ok := g.memberships[groupAddress]; ok && info.state.isDelayingMember() { - info.delayedReportJob.Cancel() + info.cancelDelayedReportJob() info.lastToSendReport = false info.state = idleMember g.memberships[groupAddress] = info @@ -603,7 +613,7 @@ func (g *GenericMulticastProtocolState) transitionToNonMemberLocked(groupAddress return } - info.delayedReportJob.Cancel() + info.cancelDelayedReportJob() g.maybeSendLeave(groupAddress, info.lastToSendReport) info.lastToSendReport = false info.state = nonMember @@ -645,14 +655,24 @@ func (g *GenericMulticastProtocolState) setDelayTimerForAddressRLocked(groupAddr // If a timer for any address is already running, it is reset to the new // random value only if the requested Maximum Response Delay is less than // the remaining value of the running timer. + now := time.Unix(0 /* seconds */, g.opts.Clock.NowNanoseconds()) if info.state == delayingMember { - // TODO: Reset the timer if time remaining is greater than maxResponseTime. - return + if info.delayedReportJobFiresAt.IsZero() { + panic(fmt.Sprintf("delayed report unscheduled while in the delaying member state; group = %s", groupAddress)) + } + + if info.delayedReportJobFiresAt.Sub(now) <= maxResponseTime { + // The timer is scheduled to fire before the maximum response time so we + // leave our timer as is. 
+ return + } } info.state = delayingMember - info.delayedReportJob.Cancel() - info.delayedReportJob.Schedule(g.calculateDelayTimerDuration(maxResponseTime)) + info.cancelDelayedReportJob() + maxResponseTime = g.calculateDelayTimerDuration(maxResponseTime) + info.delayedReportJob.Schedule(maxResponseTime) + info.delayedReportJobFiresAt = now.Add(maxResponseTime) } // calculateDelayTimerDuration returns a random time between (0, maxRespTime]. diff --git a/pkg/tcpip/network/ip/generic_multicast_protocol_test.go b/pkg/tcpip/network/ip/generic_multicast_protocol_test.go index 85593f211..60eaea37e 100644 --- a/pkg/tcpip/network/ip/generic_multicast_protocol_test.go +++ b/pkg/tcpip/network/ip/generic_multicast_protocol_test.go @@ -141,7 +141,7 @@ func (m *mockMulticastGroupProtocol) Enabled() bool { // SendReport implements ip.MulticastGroupProtocol. // // Precondition: m.mu must be locked. -func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) (bool, *tcpip.Error) { +func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) (bool, tcpip.Error) { if m.mu.TryLock() { m.mu.Unlock() m.t.Fatalf("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending report for %s", groupAddress) @@ -158,7 +158,7 @@ func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) (boo // SendLeave implements ip.MulticastGroupProtocol. // // Precondition: m.mu must be locked. 
-func (m *mockMulticastGroupProtocol) SendLeave(groupAddress tcpip.Address) *tcpip.Error { +func (m *mockMulticastGroupProtocol) SendLeave(groupAddress tcpip.Address) tcpip.Error { if m.mu.TryLock() { m.mu.Unlock() m.t.Fatalf("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending leave for %s", groupAddress) @@ -408,40 +408,46 @@ func TestHandleReport(t *testing.T) { func TestHandleQuery(t *testing.T) { tests := []struct { - name string - queryAddr tcpip.Address - maxDelay time.Duration - expectReportsFor []tcpip.Address + name string + queryAddr tcpip.Address + maxDelay time.Duration + expectQueriedReportsFor []tcpip.Address + expectDelayedReportsFor []tcpip.Address }{ { - name: "Unpecified empty", - queryAddr: "", - maxDelay: 0, - expectReportsFor: []tcpip.Address{addr1, addr2}, + name: "Unpecified empty", + queryAddr: "", + maxDelay: 0, + expectQueriedReportsFor: []tcpip.Address{addr1, addr2}, + expectDelayedReportsFor: nil, }, { - name: "Unpecified any", - queryAddr: "\x00", - maxDelay: 1, - expectReportsFor: []tcpip.Address{addr1, addr2}, + name: "Unpecified any", + queryAddr: "\x00", + maxDelay: 1, + expectQueriedReportsFor: []tcpip.Address{addr1, addr2}, + expectDelayedReportsFor: nil, }, { - name: "Specified", - queryAddr: addr1, - maxDelay: 2, - expectReportsFor: []tcpip.Address{addr1}, + name: "Specified", + queryAddr: addr1, + maxDelay: 2, + expectQueriedReportsFor: []tcpip.Address{addr1}, + expectDelayedReportsFor: []tcpip.Address{addr2}, }, { - name: "Specified all-nodes", - queryAddr: addr3, - maxDelay: 3, - expectReportsFor: nil, + name: "Specified all-nodes", + queryAddr: addr3, + maxDelay: 3, + expectQueriedReportsFor: nil, + expectDelayedReportsFor: []tcpip.Address{addr1, addr2}, }, { - name: "Specified other", - queryAddr: addr4, - maxDelay: 4, - expectReportsFor: nil, + name: "Specified other", + queryAddr: addr4, + maxDelay: 4, + expectQueriedReportsFor: nil, + expectDelayedReportsFor: 
[]tcpip.Address{addr1, addr2}, }, } @@ -469,20 +475,20 @@ func TestHandleQuery(t *testing.T) { if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } - // Generic multicast protocol timers are expected to take the job mutex. - clock.Advance(maxUnsolicitedReportDelay) - if diff := mgp.check([]tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - // Receiving a query should make us schedule a new delayed report if it - // is a query directed at us or a general query. + // Receiving a query should make us reschedule our delayed report timer + // to some time within the new max response delay. mgp.handleQuery(test.queryAddr, test.maxDelay) - if len(test.expectReportsFor) != 0 { - clock.Advance(test.maxDelay) - if diff := mgp.check(test.expectReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } + clock.Advance(test.maxDelay) + if diff := mgp.check(test.expectQueriedReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // The groups that were not affected by the query should still send a + // report after the max unsolicited report delay. + clock.Advance(maxUnsolicitedReportDelay) + if diff := mgp.check(test.expectDelayedReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } // Should have no more messages to send. 
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 2a6ec19dc..6a1f11a36 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -18,6 +18,7 @@ import ( "strings" "testing" + "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -58,6 +59,14 @@ var localIPv6AddrWithPrefix = tcpip.AddressWithPrefix{ PrefixLen: 120, } +type transportError struct { + origin tcpip.SockErrOrigin + typ uint8 + code uint8 + info uint32 + kind stack.TransportErrorKind +} + // testObject implements two interfaces: LinkEndpoint and TransportDispatcher. // The former is used to pretend that it's a link endpoint so that we can // inspect packets written by the network endpoints. The latter is used to @@ -73,8 +82,7 @@ type testObject struct { srcAddr tcpip.Address dstAddr tcpip.Address v4 bool - typ stack.ControlType - extra uint32 + transErr transportError dataCalls int controlCalls int @@ -118,16 +126,23 @@ func (t *testObject) DeliverTransportPacket(protocol tcpip.TransportProtocolNumb return stack.TransportPacketHandled } -// DeliverTransportControlPacket is called by network endpoints after parsing +// DeliverTransportError is called by network endpoints after parsing // incoming control (ICMP) packets. This is used by the test object to verify // that the results of the parsing are expected. 
-func (t *testObject) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { +func (t *testObject) DeliverTransportError(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, transErr stack.TransportError, pkt *stack.PacketBuffer) { t.checkValues(trans, pkt.Data, remote, local) - if typ != t.typ { - t.t.Errorf("typ = %v, want %v", typ, t.typ) - } - if extra != t.extra { - t.t.Errorf("extra = %v, want %v", extra, t.extra) + if diff := cmp.Diff( + t.transErr, + transportError{ + origin: transErr.Origin(), + typ: transErr.Type(), + code: transErr.Code(), + info: transErr.Info(), + kind: transErr.Kind(), + }, + cmp.AllowUnexported(transportError{}), + ); diff != "" { + t.t.Errorf("transport error mismatch (-want +got):\n%s", diff) } t.controlCalls++ } @@ -167,7 +182,7 @@ func (*testObject) Wait() {} // WritePacket is called by network endpoints after producing a packet and // writing it to the link endpoint. This is used by the test object to verify // that the produced packet is as expected. -func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { +func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { var prot tcpip.TransportProtocolNumber var srcAddr tcpip.Address var dstAddr tcpip.Address @@ -189,7 +204,7 @@ func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.Ne } // WritePackets implements stack.LinkEndpoint.WritePackets. 
-func (*testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (*testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { panic("not implemented") } @@ -203,7 +218,7 @@ func (*testObject) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.Net panic("not implemented") } -func buildIPv4Route(local, remote tcpip.Address) (*stack.Route, *tcpip.Error) { +func buildIPv4Route(local, remote tcpip.Address) (*stack.Route, tcpip.Error) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, @@ -219,7 +234,7 @@ func buildIPv4Route(local, remote tcpip.Address) (*stack.Route, *tcpip.Error) { return s.FindRoute(nicID, local, remote, ipv4.ProtocolNumber, false /* multicastLoop */) } -func buildIPv6Route(local, remote tcpip.Address) (*stack.Route, *tcpip.Error) { +func buildIPv6Route(local, remote tcpip.Address) (*stack.Route, tcpip.Error) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, @@ -306,8 +321,16 @@ func (t *testInterface) setEnabled(v bool) { t.mu.disabled = !v } -func (*testInterface) WritePacketToRemote(tcpip.LinkAddress, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) *tcpip.Error { - return tcpip.ErrNotSupported +func (*testInterface) WritePacketToRemote(tcpip.LinkAddress, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) tcpip.Error { + return &tcpip.ErrNotSupported{} +} + +func (*testInterface) HandleNeighborProbe(tcpip.NetworkProtocolNumber, tcpip.Address, tcpip.LinkAddress) tcpip.Error { + return nil +} + +func (*testInterface) HandleNeighborConfirmation(tcpip.NetworkProtocolNumber, 
tcpip.Address, tcpip.LinkAddress, stack.ReachabilityConfirmationFlags) tcpip.Error { + return nil } func TestSourceAddressValidation(t *testing.T) { @@ -464,7 +487,7 @@ func TestEnableWhenNICDisabled(t *testing.T) { // We pass nil for all parameters except the NetworkInterface and Stack // since Enable only depends on these. - ep := p.NewEndpoint(&nic, nil, nil, nil) + ep := p.NewEndpoint(&nic, nil) // The endpoint should initially be disabled, regardless the NIC's enabled // status. @@ -479,8 +502,9 @@ func TestEnableWhenNICDisabled(t *testing.T) { // Attempting to enable the endpoint while the NIC is disabled should // fail. nic.setEnabled(false) - if err := ep.Enable(); err != tcpip.ErrNotPermitted { - t.Fatalf("got ep.Enable() = %s, want = %s", err, tcpip.ErrNotPermitted) + err := ep.Enable() + if _, ok := err.(*tcpip.ErrNotPermitted); !ok { + t.Fatalf("got ep.Enable() = %s, want = %s", err, &tcpip.ErrNotPermitted{}) } // ep should consider the NIC's enabled status when determining its own // enabled status so we "enable" the NIC to read just the endpoint's @@ -525,7 +549,7 @@ func TestIPv4Send(t *testing.T) { v4: true, }, } - ep := proto.NewEndpoint(&nic, nil, nil, nil) + ep := proto.NewEndpoint(&nic, nil) defer ep.Close() // Allocate and initialize the payload view. 
@@ -659,7 +683,7 @@ func TestReceive(t *testing.T) { v4: test.v4, }, } - ep := s.NetworkProtocolInstance(test.protoNum).NewEndpoint(&nic, nil, nil, &nic.testObject) + ep := s.NetworkProtocolInstance(test.protoNum).NewEndpoint(&nic, &nic.testObject) defer ep.Close() if err := ep.Enable(); err != nil { @@ -692,24 +716,81 @@ func TestReceive(t *testing.T) { } func TestIPv4ReceiveControl(t *testing.T) { - const mtu = 0xbeef - header.IPv4MinimumSize + const ( + mtu = 0xbeef - header.IPv4MinimumSize + dataLen = 8 + ) + cases := []struct { name string expectedCount int fragmentOffset uint16 code header.ICMPv4Code - expectedTyp stack.ControlType - expectedExtra uint32 + transErr transportError trunc int }{ - {"FragmentationNeeded", 1, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, 0}, - {"Truncated (10 bytes missing)", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, 10}, - {"Truncated (missing IPv4 header)", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, header.IPv4MinimumSize + 8}, - {"Truncated (missing 'extra info')", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, 4 + header.IPv4MinimumSize + 8}, - {"Truncated (missing ICMP header)", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, header.ICMPv4MinimumSize + header.IPv4MinimumSize + 8}, - {"Port unreachable", 1, 0, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 0}, - {"Non-zero fragment offset", 0, 100, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 0}, - {"Zero-length packet", 0, 0, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv4MinimumSize + header.ICMPv4MinimumSize + 8}, + { + name: "FragmentationNeeded", + expectedCount: 1, + fragmentOffset: 0, + code: header.ICMPv4FragmentationNeeded, + transErr: transportError{ + origin: tcpip.SockExtErrorOriginICMP, + typ: uint8(header.ICMPv4DstUnreachable), + code: 
uint8(header.ICMPv4FragmentationNeeded), + info: mtu, + kind: stack.PacketTooBigTransportError, + }, + trunc: 0, + }, + { + name: "Truncated (missing IPv4 header)", + expectedCount: 0, + fragmentOffset: 0, + code: header.ICMPv4FragmentationNeeded, + trunc: header.IPv4MinimumSize + header.ICMPv4MinimumSize, + }, + { + name: "Truncated (partial offending packet's IP header)", + expectedCount: 0, + fragmentOffset: 0, + code: header.ICMPv4FragmentationNeeded, + trunc: header.IPv4MinimumSize + header.ICMPv4MinimumSize + header.IPv4MinimumSize - 1, + }, + { + name: "Truncated (partial offending packet's data)", + expectedCount: 0, + fragmentOffset: 0, + code: header.ICMPv4FragmentationNeeded, + trunc: header.ICMPv4MinimumSize + header.ICMPv4MinimumSize + header.IPv4MinimumSize + dataLen - 1, + }, + { + name: "Port unreachable", + expectedCount: 1, + fragmentOffset: 0, + code: header.ICMPv4PortUnreachable, + transErr: transportError{ + origin: tcpip.SockExtErrorOriginICMP, + typ: uint8(header.ICMPv4DstUnreachable), + code: uint8(header.ICMPv4PortUnreachable), + kind: stack.DestinationPortUnreachableTransportError, + }, + trunc: 0, + }, + { + name: "Non-zero fragment offset", + expectedCount: 0, + fragmentOffset: 100, + code: header.ICMPv4PortUnreachable, + trunc: 0, + }, + { + name: "Zero-length packet", + expectedCount: 0, + fragmentOffset: 100, + code: header.ICMPv4PortUnreachable, + trunc: 2*header.IPv4MinimumSize + header.ICMPv4MinimumSize + dataLen, + }, } for _, c := range cases { t.Run(c.name, func(t *testing.T) { @@ -720,7 +801,7 @@ func TestIPv4ReceiveControl(t *testing.T) { t: t, }, } - ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject) + ep := proto.NewEndpoint(&nic, &nic.testObject) defer ep.Close() if err := ep.Enable(); err != nil { @@ -728,7 +809,7 @@ func TestIPv4ReceiveControl(t *testing.T) { } const dataOffset = header.IPv4MinimumSize*2 + header.ICMPv4MinimumSize - view := buffer.NewView(dataOffset + 8) + view := buffer.NewView(dataOffset + 
dataLen) // Create the outer IPv4 header. ip := header.IPv4(view) @@ -775,8 +856,7 @@ func TestIPv4ReceiveControl(t *testing.T) { nic.testObject.srcAddr = remoteIPv4Addr nic.testObject.dstAddr = localIPv4Addr nic.testObject.contents = view[dataOffset:] - nic.testObject.typ = c.expectedTyp - nic.testObject.extra = c.expectedExtra + nic.testObject.transErr = c.transErr addressableEndpoint, ok := ep.(stack.AddressableEndpoint) if !ok { @@ -809,7 +889,7 @@ func TestIPv4FragmentationReceive(t *testing.T) { v4: true, }, } - ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject) + ep := proto.NewEndpoint(&nic, &nic.testObject) defer ep.Close() if err := ep.Enable(); err != nil { @@ -904,7 +984,7 @@ func TestIPv6Send(t *testing.T) { t: t, }, } - ep := proto.NewEndpoint(&nic, nil, nil, nil) + ep := proto.NewEndpoint(&nic, nil) defer ep.Close() if err := ep.Enable(); err != nil { @@ -943,30 +1023,112 @@ func TestIPv6Send(t *testing.T) { } func TestIPv6ReceiveControl(t *testing.T) { + const ( + mtu = 0xffff + outerSrcAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaa" + dataLen = 8 + ) + newUint16 := func(v uint16) *uint16 { return &v } - const mtu = 0xffff - const outerSrcAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaa" + portUnreachableTransErr := transportError{ + origin: tcpip.SockExtErrorOriginICMP6, + typ: uint8(header.ICMPv6DstUnreachable), + code: uint8(header.ICMPv6PortUnreachable), + kind: stack.DestinationPortUnreachableTransportError, + } + cases := []struct { name string expectedCount int fragmentOffset *uint16 typ header.ICMPv6Type code header.ICMPv6Code - expectedTyp stack.ControlType - expectedExtra uint32 + transErr transportError trunc int }{ - {"PacketTooBig", 1, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, 0}, - {"Truncated (10 bytes missing)", 0, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, 10}, - {"Truncated (missing IPv6 header)", 0, nil, 
header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, header.IPv6MinimumSize + 8}, - {"Truncated PacketTooBig (missing 'extra info')", 0, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, 4 + header.IPv6MinimumSize + 8}, - {"Truncated (missing ICMP header)", 0, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, header.ICMPv6PacketTooBigMinimumSize + header.IPv6MinimumSize + 8}, - {"Port unreachable", 1, nil, header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 0}, - {"Truncated DstUnreachable (missing 'extra info')", 0, nil, header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 4 + header.IPv6MinimumSize + 8}, - {"Fragmented, zero offset", 1, newUint16(0), header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 0}, - {"Non-zero fragment offset", 0, newUint16(100), header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 0}, - {"Zero-length packet", 0, nil, header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv6MinimumSize + header.ICMPv6DstUnreachableMinimumSize + 8}, + { + name: "PacketTooBig", + expectedCount: 1, + fragmentOffset: nil, + typ: header.ICMPv6PacketTooBig, + code: header.ICMPv6UnusedCode, + transErr: transportError{ + origin: tcpip.SockExtErrorOriginICMP6, + typ: uint8(header.ICMPv6PacketTooBig), + code: uint8(header.ICMPv6UnusedCode), + info: mtu, + kind: stack.PacketTooBigTransportError, + }, + trunc: 0, + }, + { + name: "Truncated (missing offending packet's IPv6 header)", + expectedCount: 0, + fragmentOffset: nil, + typ: header.ICMPv6PacketTooBig, + code: header.ICMPv6UnusedCode, + trunc: header.IPv6MinimumSize + header.ICMPv6PacketTooBigMinimumSize, + }, + { + name: "Truncated PacketTooBig (partial offending packet's IPv6 header)", + expectedCount: 0, + fragmentOffset: nil, + typ: header.ICMPv6PacketTooBig, + 
code: header.ICMPv6UnusedCode, + trunc: header.IPv6MinimumSize + header.ICMPv6PacketTooBigMinimumSize + header.IPv6MinimumSize - 1, + }, + { + name: "Truncated (partial offending packet's data)", + expectedCount: 0, + fragmentOffset: nil, + typ: header.ICMPv6PacketTooBig, + code: header.ICMPv6UnusedCode, + trunc: header.IPv6MinimumSize + header.ICMPv6PacketTooBigMinimumSize + header.IPv6MinimumSize + dataLen - 1, + }, + { + name: "Port unreachable", + expectedCount: 1, + fragmentOffset: nil, + typ: header.ICMPv6DstUnreachable, + code: header.ICMPv6PortUnreachable, + transErr: portUnreachableTransErr, + trunc: 0, + }, + { + name: "Truncated DstPortUnreachable (partial offending packet's IP header)", + expectedCount: 0, + fragmentOffset: nil, + typ: header.ICMPv6DstUnreachable, + code: header.ICMPv6PortUnreachable, + trunc: header.IPv6MinimumSize + header.ICMPv6DstUnreachableMinimumSize + header.IPv6MinimumSize - 1, + }, + { + name: "DstPortUnreachable for Fragmented, zero offset", + expectedCount: 1, + fragmentOffset: newUint16(0), + typ: header.ICMPv6DstUnreachable, + code: header.ICMPv6PortUnreachable, + transErr: portUnreachableTransErr, + trunc: 0, + }, + { + name: "DstPortUnreachable for Non-zero fragment offset", + expectedCount: 0, + fragmentOffset: newUint16(100), + typ: header.ICMPv6DstUnreachable, + code: header.ICMPv6PortUnreachable, + transErr: portUnreachableTransErr, + trunc: 0, + }, + { + name: "Zero-length packet", + expectedCount: 0, + fragmentOffset: nil, + typ: header.ICMPv6DstUnreachable, + code: header.ICMPv6PortUnreachable, + trunc: 2*header.IPv6MinimumSize + header.ICMPv6DstUnreachableMinimumSize + dataLen, + }, } for _, c := range cases { t.Run(c.name, func(t *testing.T) { @@ -977,7 +1139,7 @@ func TestIPv6ReceiveControl(t *testing.T) { t: t, }, } - ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject) + ep := proto.NewEndpoint(&nic, &nic.testObject) defer ep.Close() if err := ep.Enable(); err != nil { @@ -988,7 +1150,7 @@ func 
TestIPv6ReceiveControl(t *testing.T) { if c.fragmentOffset != nil { dataOffset += header.IPv6FragmentHeaderSize } - view := buffer.NewView(dataOffset + 8) + view := buffer.NewView(dataOffset + dataLen) // Create the outer IPv6 header. ip := header.IPv6(view) @@ -1039,8 +1201,7 @@ func TestIPv6ReceiveControl(t *testing.T) { nic.testObject.srcAddr = remoteIPv6Addr nic.testObject.dstAddr = localIPv6Addr nic.testObject.contents = view[dataOffset:] - nic.testObject.typ = c.expectedTyp - nic.testObject.extra = c.expectedExtra + nic.testObject.transErr = c.transErr // Set ICMPv6 checksum. icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIPv6Addr, buffer.VectorisedView{})) @@ -1122,7 +1283,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { remoteAddr tcpip.Address pktGen func(*testing.T, tcpip.Address) buffer.VectorisedView checker func(*testing.T, *stack.PacketBuffer, tcpip.Address) - expectedErr *tcpip.Error + expectedErr tcpip.Error }{ { name: "IPv4", @@ -1187,7 +1348,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { ip.SetHeaderLength(header.IPv4MinimumSize - 1) return hdr.View().ToVectorisedView() }, - expectedErr: tcpip.ErrMalformedHeader, + expectedErr: &tcpip.ErrMalformedHeader{}, }, { name: "IPv4 too small", @@ -1205,7 +1366,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { }) return buffer.View(ip[:len(ip)-1]).ToVectorisedView() }, - expectedErr: tcpip.ErrMalformedHeader, + expectedErr: &tcpip.ErrMalformedHeader{}, }, { name: "IPv4 minimum size", @@ -1465,7 +1626,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { }) return buffer.View(ip[:len(ip)-1]).ToVectorisedView() }, - expectedErr: tcpip.ErrMalformedHeader, + expectedErr: &tcpip.ErrMalformedHeader{}, }, } @@ -1506,10 +1667,13 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { } defer r.Release() - if err := r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: test.pktGen(t, subTest.srcAddr), - })); err != test.expectedErr { - 
t.Fatalf("got r.WriteHeaderIncludedPacket(_) = %s, want = %s", err, test.expectedErr) + { + err := r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: test.pktGen(t, subTest.srcAddr), + })) + if diff := cmp.Diff(test.expectedErr, err); diff != "" { + t.Fatalf("unexpected error from r.WriteHeaderIncludedPacket(_), (-want, +got):\n%s", diff) + } } if test.expectedErr != nil { diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD index 330a7d170..9713c4448 100644 --- a/pkg/tcpip/network/ipv4/BUILD +++ b/pkg/tcpip/network/ipv4/BUILD @@ -44,6 +44,7 @@ go_test( "//pkg/tcpip/network/testutil", "//pkg/tcpip/stack", "//pkg/tcpip/transport/icmp", + "//pkg/tcpip/transport/raw", "//pkg/tcpip/transport/tcp", "//pkg/tcpip/transport/udp", "//pkg/waiter", diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index 6bb97c46a..74e70e283 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -23,11 +23,108 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/stack" ) +// icmpv4DestinationUnreachableSockError is a general ICMPv4 Destination +// Unreachable error. +// +// +stateify savable +type icmpv4DestinationUnreachableSockError struct{} + +// Origin implements tcpip.SockErrorCause. +func (*icmpv4DestinationUnreachableSockError) Origin() tcpip.SockErrOrigin { + return tcpip.SockExtErrorOriginICMP +} + +// Type implements tcpip.SockErrorCause. +func (*icmpv4DestinationUnreachableSockError) Type() uint8 { + return uint8(header.ICMPv4DstUnreachable) +} + +// Info implements tcpip.SockErrorCause. +func (*icmpv4DestinationUnreachableSockError) Info() uint32 { + return 0 +} + +var _ stack.TransportError = (*icmpv4DestinationHostUnreachableSockError)(nil) + +// icmpv4DestinationHostUnreachableSockError is an ICMPv4 Destination Host +// Unreachable error. +// +// It indicates that a packet was not able to reach the destination host. 
+// +// +stateify savable +type icmpv4DestinationHostUnreachableSockError struct { + icmpv4DestinationUnreachableSockError +} + +// Code implements tcpip.SockErrorCause. +func (*icmpv4DestinationHostUnreachableSockError) Code() uint8 { + return uint8(header.ICMPv4HostUnreachable) +} + +// Kind implements stack.TransportError. +func (*icmpv4DestinationHostUnreachableSockError) Kind() stack.TransportErrorKind { + return stack.DestinationHostUnreachableTransportError +} + +var _ stack.TransportError = (*icmpv4DestinationPortUnreachableSockError)(nil) + +// icmpv4DestinationPortUnreachableSockError is an ICMPv4 Destination Port +// Unreachable error. +// +// It indicates that a packet reached the destination host, but the transport +// protocol was not active on the destination port. +// +// +stateify savable +type icmpv4DestinationPortUnreachableSockError struct { + icmpv4DestinationUnreachableSockError +} + +// Code implements tcpip.SockErrorCause. +func (*icmpv4DestinationPortUnreachableSockError) Code() uint8 { + return uint8(header.ICMPv4PortUnreachable) +} + +// Kind implements stack.TransportError. +func (*icmpv4DestinationPortUnreachableSockError) Kind() stack.TransportErrorKind { + return stack.DestinationPortUnreachableTransportError +} + +var _ stack.TransportError = (*icmpv4FragmentationNeededSockError)(nil) + +// icmpv4FragmentationNeededSockError is an ICMPv4 Destination Unreachable error +// due to fragmentation being required but the packet was set to not be +// fragmented. +// +// It indicates that a link exists on the path to the destination with an MTU +// that is too small to carry the packet. +// +// +stateify savable +type icmpv4FragmentationNeededSockError struct { + icmpv4DestinationUnreachableSockError + + mtu uint32 +} + +// Code implements tcpip.SockErrorCause. +func (*icmpv4FragmentationNeededSockError) Code() uint8 { + return uint8(header.ICMPv4FragmentationNeeded) +} + +// Info implements tcpip.SockErrorCause. 
+func (e *icmpv4FragmentationNeededSockError) Info() uint32 { + return e.mtu +} + +// Kind implements stack.TransportError. +func (*icmpv4FragmentationNeededSockError) Kind() stack.TransportErrorKind { + return stack.PacketTooBigTransportError +} + // handleControl handles the case when an ICMP error packet contains the headers // of the original packet that caused the ICMP one to be sent. This information // is used to find out which transport endpoint must be notified about the ICMP // packet. We only expect the payload, not the enclosing ICMP packet. -func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { +func (e *endpoint) handleControl(errInfo stack.TransportError, pkt *stack.PacketBuffer) { h, ok := pkt.Data.PullUp(header.IPv4MinimumSize) if !ok { return @@ -54,10 +151,10 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack return } - // Skip the ip header, then deliver control message. + // Skip the ip header, then deliver the error. 
pkt.Data.TrimFront(hlen) p := hdr.TransportProtocol() - e.dispatcher.DeliverTransportControlPacket(srcAddr, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) + e.dispatcher.DeliverTransportError(srcAddr, hdr.DestinationAddress(), ProtocolNumber, p, errInfo, pkt) } func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { @@ -222,19 +319,16 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { pkt.Data.TrimFront(header.ICMPv4MinimumSize) switch h.Code() { case header.ICMPv4HostUnreachable: - e.handleControl(stack.ControlNoRoute, 0, pkt) - + e.handleControl(&icmpv4DestinationHostUnreachableSockError{}, pkt) case header.ICMPv4PortUnreachable: - e.handleControl(stack.ControlPortUnreachable, 0, pkt) - + e.handleControl(&icmpv4DestinationPortUnreachableSockError{}, pkt) case header.ICMPv4FragmentationNeeded: networkMTU, err := calculateNetworkMTU(uint32(h.MTU()), header.IPv4MinimumSize) if err != nil { networkMTU = 0 } - e.handleControl(stack.ControlPacketTooBig, networkMTU, pkt) + e.handleControl(&icmpv4FragmentationNeededSockError{mtu: networkMTU}, pkt) } - case header.ICMPv4SrcQuench: received.srcQuench.Increment() @@ -310,7 +404,7 @@ func (*icmpReasonParamProblem) isICMPReason() {} // the problematic packet. It incorporates as much of that packet as // possible as well as any error metadata as is available. returnError // expects pkt to hold a valid IPv4 packet as per the wire format. 
-func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error { +func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip.Error { origIPHdr := header.IPv4(pkt.NetworkHeader().View()) origIPHdrSrc := origIPHdr.SourceAddress() origIPHdrDst := origIPHdr.DestinationAddress() @@ -376,7 +470,7 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi netEP, ok := p.mu.eps[pkt.NICID] p.mu.Unlock() if !ok { - return tcpip.ErrNotConnected + return &tcpip.ErrNotConnected{} } sent := netEP.stats.icmp.packetsSent diff --git a/pkg/tcpip/network/ipv4/igmp.go b/pkg/tcpip/network/ipv4/igmp.go index 4550aacd6..acc126c3b 100644 --- a/pkg/tcpip/network/ipv4/igmp.go +++ b/pkg/tcpip/network/ipv4/igmp.go @@ -103,7 +103,7 @@ func (igmp *igmpState) Enabled() bool { // SendReport implements ip.MulticastGroupProtocol. // // Precondition: igmp.ep.mu must be read locked. -func (igmp *igmpState) SendReport(groupAddress tcpip.Address) (bool, *tcpip.Error) { +func (igmp *igmpState) SendReport(groupAddress tcpip.Address) (bool, tcpip.Error) { igmpType := header.IGMPv2MembershipReport if igmp.v1Present() { igmpType = header.IGMPv1MembershipReport @@ -114,7 +114,7 @@ func (igmp *igmpState) SendReport(groupAddress tcpip.Address) (bool, *tcpip.Erro // SendLeave implements ip.MulticastGroupProtocol. // // Precondition: igmp.ep.mu must be read locked. -func (igmp *igmpState) SendLeave(groupAddress tcpip.Address) *tcpip.Error { +func (igmp *igmpState) SendLeave(groupAddress tcpip.Address) tcpip.Error { // As per RFC 2236 Section 6, Page 8: "If the interface state says the // Querier is running IGMPv1, this action SHOULD be skipped. 
If the flag // saying we were the last host to report is cleared, this action MAY be @@ -215,6 +215,11 @@ func (igmp *igmpState) setV1Present(v bool) { } } +func (igmp *igmpState) resetV1Present() { + igmp.igmpV1Job.Cancel() + igmp.setV1Present(false) +} + // handleMembershipQuery handles a membership query. // // Precondition: igmp.ep.mu must be locked. @@ -242,7 +247,7 @@ func (igmp *igmpState) handleMembershipReport(groupAddress tcpip.Address) { // writePacket assembles and sends an IGMP packet. // // Precondition: igmp.ep.mu must be read locked. -func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip.Address, igmpType header.IGMPType) (bool, *tcpip.Error) { +func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip.Address, igmpType header.IGMPType) (bool, tcpip.Error) { igmpData := header.IGMP(buffer.NewView(header.IGMPReportMinimumSize)) igmpData.SetType(igmpType) igmpData.SetGroupAddress(groupAddress) @@ -293,7 +298,7 @@ func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip // messages. // // If the group already exists in the membership map, returns -// tcpip.ErrDuplicateAddress. +// *tcpip.ErrDuplicateAddress. // // Precondition: igmp.ep.mu must be locked. func (igmp *igmpState) joinGroup(groupAddress tcpip.Address) { @@ -312,13 +317,13 @@ func (igmp *igmpState) isInGroup(groupAddress tcpip.Address) bool { // if required. // // Precondition: igmp.ep.mu must be locked. -func (igmp *igmpState) leaveGroup(groupAddress tcpip.Address) *tcpip.Error { +func (igmp *igmpState) leaveGroup(groupAddress tcpip.Address) tcpip.Error { // LeaveGroup returns false only if the group was not joined. 
if igmp.genericMulticastProtocol.LeaveGroupLocked(groupAddress) { return nil } - return tcpip.ErrBadLocalAddress + return &tcpip.ErrBadLocalAddress{} } // softLeaveAll leaves all groups from the perspective of IGMP, but remains diff --git a/pkg/tcpip/network/ipv4/igmp_test.go b/pkg/tcpip/network/ipv4/igmp_test.go index 1ee573ac8..95fd75ab7 100644 --- a/pkg/tcpip/network/ipv4/igmp_test.go +++ b/pkg/tcpip/network/ipv4/igmp_test.go @@ -101,10 +101,10 @@ func createAndInjectIGMPPacket(e *channel.Endpoint, igmpType header.IGMPType, ma }) } -// TestIgmpV1Present tests the handling of the case where an IGMPv1 router is -// present on the network. The IGMP stack will then send IGMPv1 Membership -// reports for backwards compatibility. -func TestIgmpV1Present(t *testing.T) { +// TestIGMPV1Present tests the node's ability to fallback to V1 when a V1 +// router is detected. V1 present status is expected to be reset when the NIC +// cycles. +func TestIGMPV1Present(t *testing.T) { e, s, clock := createStack(t, true) if err := s.AddAddress(nicID, ipv4.ProtocolNumber, addr); err != nil { t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, addr, err) @@ -116,14 +116,16 @@ func TestIgmpV1Present(t *testing.T) { // This NIC will send an IGMPv2 report immediately, before this test can get // the IGMPv1 General Membership Query in. 
- p, ok := e.Read() - if !ok { - t.Fatal("unable to Read IGMP packet, expected V2MembershipReport") - } - if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 1 { - t.Fatalf("got V2MembershipReport messages sent = %d, want = 1", got) + { + p, ok := e.Read() + if !ok { + t.Fatal("unable to Read IGMP packet, expected V2MembershipReport") + } + if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 1 { + t.Fatalf("got V2MembershipReport messages sent = %d, want = 1", got) + } + validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr) } - validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr) if t.Failed() { t.FailNow() } @@ -145,19 +147,38 @@ func TestIgmpV1Present(t *testing.T) { // Verify the solicited Membership Report is sent. Now that this NIC has seen // an IGMPv1 query, it should send an IGMPv1 Membership Report. - p, ok = e.Read() - if ok { + if p, ok := e.Read(); ok { t.Fatalf("sent unexpected packet, expected V1MembershipReport only after advancing the clock = %+v", p.Pkt) } clock.Advance(ipv4.UnsolicitedReportIntervalMax) - p, ok = e.Read() - if !ok { - t.Fatal("unable to Read IGMP packet, expected V1MembershipReport") - } - if got := s.Stats().IGMP.PacketsSent.V1MembershipReport.Value(); got != 1 { - t.Fatalf("got V1MembershipReport messages sent = %d, want = 1", got) + { + p, ok := e.Read() + if !ok { + t.Fatal("unable to Read IGMP packet, expected V1MembershipReport") + } + if got := s.Stats().IGMP.PacketsSent.V1MembershipReport.Value(); got != 1 { + t.Fatalf("got V1MembershipReport messages sent = %d, want = 1", got) + } + validateIgmpPacket(t, p, multicastAddr, header.IGMPv1MembershipReport, 0, multicastAddr) + } + + // Cycling the interface should reset the V1 present flag. 
+ if err := s.DisableNIC(nicID); err != nil { + t.Fatalf("s.DisableNIC(%d): %s", nicID, err) + } + if err := s.EnableNIC(nicID); err != nil { + t.Fatalf("s.EnableNIC(%d): %s", nicID, err) + } + { + p, ok := e.Read() + if !ok { + t.Fatal("unable to Read IGMP packet, expected V2MembershipReport") + } + if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 2 { + t.Fatalf("got V2MembershipReport messages sent = %d, want = 2", got) + } + validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr) } - validateIgmpPacket(t, p, multicastAddr, header.IGMPv1MembershipReport, 0, multicastAddr) } func TestSendQueuedIGMPReports(t *testing.T) { diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index a05275a5b..b2d626107 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -101,11 +101,11 @@ func (e *endpoint) HandleLinkResolutionFailure(pkt *stack.PacketBuffer) { // Use the same control type as an ICMPv4 destination host unreachable error // since the host is considered unreachable if we cannot resolve the link // address to the next hop. - e.handleControl(stack.ControlNoRoute, 0, pkt) + e.handleControl(&icmpv4DestinationHostUnreachableSockError{}, pkt) } // NewEndpoint creates a new ipv4 endpoint. -func (p *protocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCache, _ stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint { +func (p *protocol) NewEndpoint(nic stack.NetworkInterface, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint { e := &endpoint{ nic: nic, dispatcher: dispatcher, @@ -137,14 +137,14 @@ func (p *protocol) forgetEndpoint(nicID tcpip.NICID) { } // Enable implements stack.NetworkEndpoint. -func (e *endpoint) Enable() *tcpip.Error { +func (e *endpoint) Enable() tcpip.Error { e.mu.Lock() defer e.mu.Unlock() // If the NIC is not enabled, the endpoint can't do anything meaningful so // don't enable the endpoint. 
if !e.nic.Enabled() { - return tcpip.ErrNotPermitted + return &tcpip.ErrNotPermitted{} } // If the endpoint is already enabled, there is nothing for it to do. @@ -212,7 +212,9 @@ func (e *endpoint) disableLocked() { } // The endpoint may have already left the multicast group. - if err := e.leaveGroupLocked(header.IPv4AllSystems); err != nil && err != tcpip.ErrBadLocalAddress { + switch err := e.leaveGroupLocked(header.IPv4AllSystems); err.(type) { + case nil, *tcpip.ErrBadLocalAddress: + default: panic(fmt.Sprintf("unexpected error when leaving group = %s: %s", header.IPv4AllSystems, err)) } @@ -221,10 +223,18 @@ func (e *endpoint) disableLocked() { e.mu.igmp.softLeaveAll() // The address may have already been removed. - if err := e.mu.addressableEndpointState.RemovePermanentAddress(ipv4BroadcastAddr.Address); err != nil && err != tcpip.ErrBadLocalAddress { + switch err := e.mu.addressableEndpointState.RemovePermanentAddress(ipv4BroadcastAddr.Address); err.(type) { + case nil, *tcpip.ErrBadLocalAddress: + default: panic(fmt.Sprintf("unexpected error when removing address = %s: %s", ipv4BroadcastAddr.Address, err)) } + // Reset the IGMP V1 present flag. + // + // If the node comes back up on the same network, it will re-learn that it + // needs to perform IGMPv1. 
+ e.mu.igmp.resetV1Present() + if !e.setEnabled(false) { panic("should have only done work to disable the endpoint if it was enabled") } @@ -256,7 +266,7 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { return e.protocol.Number() } -func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams, options header.IPv4OptionsSerializer) *tcpip.Error { +func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams, options header.IPv4OptionsSerializer) tcpip.Error { hdrLen := header.IPv4MinimumSize var optLen int if options != nil { @@ -264,12 +274,12 @@ func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.Packet } hdrLen += optLen if hdrLen > header.IPv4MaximumHeaderSize { - return tcpip.ErrMessageTooLong + return &tcpip.ErrMessageTooLong{} } ip := header.IPv4(pkt.NetworkHeader().Push(hdrLen)) length := pkt.Size() if length > math.MaxUint16 { - return tcpip.ErrMessageTooLong + return &tcpip.ErrMessageTooLong{} } // RFC 6864 section 4.3 mandates uniqueness of ID values for non-atomic // datagrams. Since the DF bit is never being set here, all datagrams @@ -294,7 +304,7 @@ func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.Packet // fragment. It returns the number of fragments handled and the number of // fragments left to be processed. The IP header must already be present in the // original packet. -func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU uint32, pkt *stack.PacketBuffer, handler func(*stack.PacketBuffer) *tcpip.Error) (int, int, *tcpip.Error) { +func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU uint32, pkt *stack.PacketBuffer, handler func(*stack.PacketBuffer) tcpip.Error) (int, int, tcpip.Error) { // Round the MTU down to align to 8 bytes. 
fragmentPayloadSize := networkMTU &^ 7 networkHeader := header.IPv4(pkt.NetworkHeader().View()) @@ -314,7 +324,7 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU ui } // WritePacket writes a packet to the given destination address and protocol. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) tcpip.Error { if err := e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* options */); err != nil { return err } @@ -353,7 +363,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw return e.writePacket(r, gso, pkt, false /* headerIncluded */) } -func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, headerIncluded bool) *tcpip.Error { +func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, headerIncluded bool) tcpip.Error { if r.Loop&stack.PacketLoop != 0 { pkt := pkt.CloneToInbound() if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { @@ -377,7 +387,7 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet } if packetMustBeFragmented(pkt, networkMTU, gso) { - sent, remain, err := e.handleFragments(r, gso, networkMTU, pkt, func(fragPkt *stack.PacketBuffer) *tcpip.Error { + sent, remain, err := e.handleFragments(r, gso, networkMTU, pkt, func(fragPkt *stack.PacketBuffer) tcpip.Error { // TODO(gvisor.dev/issue/3884): Evaluate whether we want to send each // fragment one by one using WritePacket() (current strategy) or if we // want to create a PacketBufferList from the fragments and feed it to @@ -398,7 +408,7 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet } // WritePackets implements stack.NetworkEndpoint.WritePackets. 
-func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, *tcpip.Error) { +func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, tcpip.Error) { if r.Loop&stack.PacketLoop != 0 { panic("multiple packets in local loop") } @@ -423,7 +433,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe // Keep track of the packet that is about to be fragmented so it can be // removed once the fragmentation is done. originalPkt := pkt - if _, _, err := e.handleFragments(r, gso, networkMTU, pkt, func(fragPkt *stack.PacketBuffer) *tcpip.Error { + if _, _, err := e.handleFragments(r, gso, networkMTU, pkt, func(fragPkt *stack.PacketBuffer) tcpip.Error { // Modify the packet list in place with the new fragments. pkts.InsertAfter(pkt, fragPkt) pkt = fragPkt @@ -488,22 +498,22 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe } // WriteHeaderIncludedPacket implements stack.NetworkEndpoint. -func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) tcpip.Error { // The packet already has an IP header, but there are a few required // checks. h, ok := pkt.Data.PullUp(header.IPv4MinimumSize) if !ok { - return tcpip.ErrMalformedHeader + return &tcpip.ErrMalformedHeader{} } hdrLen := header.IPv4(h).HeaderLength() if hdrLen < header.IPv4MinimumSize { - return tcpip.ErrMalformedHeader + return &tcpip.ErrMalformedHeader{} } h, ok = pkt.Data.PullUp(int(hdrLen)) if !ok { - return tcpip.ErrMalformedHeader + return &tcpip.ErrMalformedHeader{} } ip := header.IPv4(h) @@ -541,14 +551,14 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu // wire format. 
We also want to check if the header's fields are valid before // sending the packet. if !parse.IPv4(pkt) || !header.IPv4(pkt.NetworkHeader().View()).IsValid(pktSize) { - return tcpip.ErrMalformedHeader + return &tcpip.ErrMalformedHeader{} } return e.writePacket(r, nil /* gso */, pkt, true /* headerIncluded */) } // forwardPacket attempts to forward a packet to its final destination. -func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error { h := header.IPv4(pkt.NetworkHeader().View()) ttl := h.TTL() if ttl == 0 { @@ -568,7 +578,7 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) *tcpip.Error { networkEndpoint.(*endpoint).handlePacket(pkt) return nil } - if err != tcpip.ErrBadAddress { + if _, ok := err.(*tcpip.ErrBadAddress); !ok { return err } @@ -730,7 +740,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { } proto := h.Protocol() - data, _, ready, err := e.protocol.fragmentation.Process( + resPkt, _, ready, err := e.protocol.fragmentation.Process( // As per RFC 791 section 2.3, the identification value is unique // for a source-destination pair and protocol. fragmentation.FragmentID{ @@ -753,7 +763,8 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { if !ready { return } - pkt.Data = data + pkt = resPkt + h = header.IPv4(pkt.NetworkHeader().View()) // The reassembler doesn't take care of fixing up the header, so we need // to do it here. @@ -825,7 +836,7 @@ func (e *endpoint) Close() { } // AddAndAcquirePermanentAddress implements stack.AddressableEndpoint. 
-func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, *tcpip.Error) { +func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, tcpip.Error) { e.mu.Lock() defer e.mu.Unlock() @@ -837,7 +848,7 @@ func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, p } // RemovePermanentAddress implements stack.AddressableEndpoint. -func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) *tcpip.Error { +func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() return e.mu.addressableEndpointState.RemovePermanentAddress(addr) @@ -894,7 +905,7 @@ func (e *endpoint) PermanentAddresses() []tcpip.AddressWithPrefix { } // JoinGroup implements stack.GroupAddressableEndpoint. -func (e *endpoint) JoinGroup(addr tcpip.Address) *tcpip.Error { +func (e *endpoint) JoinGroup(addr tcpip.Address) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() return e.joinGroupLocked(addr) @@ -903,9 +914,9 @@ func (e *endpoint) JoinGroup(addr tcpip.Address) *tcpip.Error { // joinGroupLocked is like JoinGroup but with locking requirements. // // Precondition: e.mu must be locked. -func (e *endpoint) joinGroupLocked(addr tcpip.Address) *tcpip.Error { +func (e *endpoint) joinGroupLocked(addr tcpip.Address) tcpip.Error { if !header.IsV4MulticastAddress(addr) { - return tcpip.ErrBadAddress + return &tcpip.ErrBadAddress{} } e.mu.igmp.joinGroup(addr) @@ -913,7 +924,7 @@ func (e *endpoint) joinGroupLocked(addr tcpip.Address) *tcpip.Error { } // LeaveGroup implements stack.GroupAddressableEndpoint. 
-func (e *endpoint) LeaveGroup(addr tcpip.Address) *tcpip.Error { +func (e *endpoint) LeaveGroup(addr tcpip.Address) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() return e.leaveGroupLocked(addr) @@ -922,7 +933,7 @@ func (e *endpoint) LeaveGroup(addr tcpip.Address) *tcpip.Error { // leaveGroupLocked is like LeaveGroup but with locking requirements. // // Precondition: e.mu must be locked. -func (e *endpoint) leaveGroupLocked(addr tcpip.Address) *tcpip.Error { +func (e *endpoint) leaveGroupLocked(addr tcpip.Address) tcpip.Error { return e.mu.igmp.leaveGroup(addr) } @@ -995,24 +1006,24 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { } // SetOption implements NetworkProtocol.SetOption. -func (p *protocol) SetOption(option tcpip.SettableNetworkProtocolOption) *tcpip.Error { +func (p *protocol) SetOption(option tcpip.SettableNetworkProtocolOption) tcpip.Error { switch v := option.(type) { case *tcpip.DefaultTTLOption: p.SetDefaultTTL(uint8(*v)) return nil default: - return tcpip.ErrUnknownProtocolOption + return &tcpip.ErrUnknownProtocolOption{} } } // Option implements NetworkProtocol.Option. -func (p *protocol) Option(option tcpip.GettableNetworkProtocolOption) *tcpip.Error { +func (p *protocol) Option(option tcpip.GettableNetworkProtocolOption) tcpip.Error { switch v := option.(type) { case *tcpip.DefaultTTLOption: *v = tcpip.DefaultTTLOption(p.DefaultTTL()) return nil default: - return tcpip.ErrUnknownProtocolOption + return &tcpip.ErrUnknownProtocolOption{} } } @@ -1058,9 +1069,9 @@ func (p *protocol) SetForwarding(v bool) { // calculateNetworkMTU calculates the network-layer payload MTU based on the // link-layer payload mtu. 
-func calculateNetworkMTU(linkMTU, networkHeaderSize uint32) (uint32, *tcpip.Error) { +func calculateNetworkMTU(linkMTU, networkHeaderSize uint32) (uint32, tcpip.Error) { if linkMTU < header.IPv4MinimumMTU { - return 0, tcpip.ErrInvalidEndpointState + return 0, &tcpip.ErrInvalidEndpointState{} } // As per RFC 791 section 3.1, an IPv4 header cannot exceed 60 bytes in @@ -1068,7 +1079,7 @@ func calculateNetworkMTU(linkMTU, networkHeaderSize uint32) (uint32, *tcpip.Erro // The maximal internet header is 60 octets, and a typical internet header // is 20 octets, allowing a margin for headers of higher level protocols. if networkHeaderSize > header.IPv4MaximumHeaderSize { - return 0, tcpip.ErrMalformedHeader + return 0, &tcpip.ErrMalformedHeader{} } networkMTU := linkMTU diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index dac7cbfd4..a296bed79 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -38,6 +38,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/network/testutil" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" + "gvisor.dev/gvisor/pkg/tcpip/transport/raw" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/waiter" @@ -78,8 +79,11 @@ func TestExcludeBroadcast(t *testing.T) { defer ep.Close() // Cannot connect using a broadcast address as the source. - if err := ep.Connect(randomAddr); err != tcpip.ErrNoRoute { - t.Errorf("got ep.Connect(...) = %v, want = %v", err, tcpip.ErrNoRoute) + { + err := ep.Connect(randomAddr) + if _, ok := err.(*tcpip.ErrNoRoute); !ok { + t.Errorf("got ep.Connect(...) = %v, want = %v", err, &tcpip.ErrNoRoute{}) + } } // However, we can bind to a broadcast address to listen. 
@@ -1376,8 +1380,8 @@ func TestFragmentationErrors(t *testing.T) { payloadSize int allowPackets int outgoingErrors int - mockError *tcpip.Error - wantError *tcpip.Error + mockError tcpip.Error + wantError tcpip.Error }{ { description: "No frag", @@ -1386,8 +1390,8 @@ func TestFragmentationErrors(t *testing.T) { transportHeaderLength: 0, allowPackets: 0, outgoingErrors: 1, - mockError: tcpip.ErrAborted, - wantError: tcpip.ErrAborted, + mockError: &tcpip.ErrAborted{}, + wantError: &tcpip.ErrAborted{}, }, { description: "Error on first frag", @@ -1396,8 +1400,8 @@ func TestFragmentationErrors(t *testing.T) { transportHeaderLength: 0, allowPackets: 0, outgoingErrors: 3, - mockError: tcpip.ErrAborted, - wantError: tcpip.ErrAborted, + mockError: &tcpip.ErrAborted{}, + wantError: &tcpip.ErrAborted{}, }, { description: "Error on second frag", @@ -1406,8 +1410,8 @@ func TestFragmentationErrors(t *testing.T) { transportHeaderLength: 0, allowPackets: 1, outgoingErrors: 2, - mockError: tcpip.ErrAborted, - wantError: tcpip.ErrAborted, + mockError: &tcpip.ErrAborted{}, + wantError: &tcpip.ErrAborted{}, }, { description: "Error on first frag MTU smaller than header", @@ -1416,8 +1420,8 @@ func TestFragmentationErrors(t *testing.T) { payloadSize: 500, allowPackets: 0, outgoingErrors: 4, - mockError: tcpip.ErrAborted, - wantError: tcpip.ErrAborted, + mockError: &tcpip.ErrAborted{}, + wantError: &tcpip.ErrAborted{}, }, { description: "Error when MTU is smaller than IPv4 minimum MTU", @@ -1427,7 +1431,7 @@ func TestFragmentationErrors(t *testing.T) { allowPackets: 0, outgoingErrors: 1, mockError: nil, - wantError: tcpip.ErrInvalidEndpointState, + wantError: &tcpip.ErrInvalidEndpointState{}, }, } @@ -1441,8 +1445,8 @@ func TestFragmentationErrors(t *testing.T) { TTL: ttl, TOS: stack.DefaultTOS, }, pkt) - if err != ft.wantError { - t.Errorf("got WritePacket(_, _, _) = %s, want = %s", err, ft.wantError) + if diff := cmp.Diff(ft.wantError, err); diff != "" { + t.Fatalf("unexpected error 
from r.WritePacket(_, _, _), (-want, +got):\n%s", diff) } if got := int(r.Stats().IP.PacketsSent.Value()); got != ft.allowPackets { t.Errorf("got r.Stats().IP.PacketsSent.Value() = %d, want = %d", got, ft.allowPackets) @@ -2055,7 +2059,7 @@ func TestReceiveFragments(t *testing.T) { // the fragment block size of 8 (RFC 791 section 3.1 page 14). ipv4Payload3Addr1ToAddr2 := udpGen(127, 3, addr1, addr2) udpPayload3Addr1ToAddr2 := ipv4Payload3Addr1ToAddr2[header.UDPMinimumSize:] - // Used to test the max reassembled payload length (65,535 octets). + // Used to test the max reassembled IPv4 payload length. ipv4Payload4Addr1ToAddr2 := udpGen(header.UDPMaximumSize-header.UDPMinimumSize, 4, addr1, addr2) udpPayload4Addr1ToAddr2 := ipv4Payload4Addr1ToAddr2[header.UDPMinimumSize:] @@ -2403,6 +2407,7 @@ func TestReceiveFragments(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + RawFactory: raw.EndpointFactory{}, }) e := channel.New(0, 1280, tcpip.LinkAddress("\xf0\x00")) if err := s.CreateNIC(nicID, e); err != nil { @@ -2428,6 +2433,13 @@ func TestReceiveFragments(t *testing.T) { t.Fatalf("Bind(%+v): %s", bindAddr, err) } + // Bring up a raw endpoint so we can examine network headers. + epRaw, err := s.NewRawEndpoint(udp.ProtocolNumber, header.IPv4ProtocolNumber, &wq, true /* associated */) + if err != nil { + t.Fatalf("NewRawEndpoint(%d, %d, _, true): %s", udp.ProtocolNumber, header.IPv4ProtocolNumber, err) + } + defer epRaw.Close() + // Prepare and send the fragments. for _, frag := range test.fragments { hdr := buffer.NewPrependable(header.IPv4MinimumSize) @@ -2459,10 +2471,11 @@ func TestReceiveFragments(t *testing.T) { } for i, expectedPayload := range test.expectedPayloads { + // Check UDP payload delivered by UDP endpoint. 
var buf bytes.Buffer result, err := ep.Read(&buf, tcpip.ReadOptions{}) if err != nil { - t.Fatalf("(i=%d) Read: %s", i, err) + t.Fatalf("(i=%d) ep.Read: %s", i, err) } if diff := cmp.Diff(tcpip.ReadResult{ Count: len(expectedPayload), @@ -2471,12 +2484,30 @@ func TestReceiveFragments(t *testing.T) { t.Errorf("(i=%d) ep.Read: unexpected result (-want +got):\n%s", i, diff) } if diff := cmp.Diff(expectedPayload, buf.Bytes()); diff != "" { - t.Errorf("(i=%d) got UDP payload mismatch (-want +got):\n%s", i, diff) + t.Errorf("(i=%d) ep.Read: UDP payload mismatch (-want +got):\n%s", i, diff) + } + + // Check IPv4 header in packet delivered by raw endpoint. + buf.Reset() + result, err = epRaw.Read(&buf, tcpip.ReadOptions{}) + if err != nil { + t.Fatalf("(i=%d) epRaw.Read: %s", i, err) + } + // Reassambly does not take care of checksum. Here we write our own + // check routine instead of using checker.IPv4. + ip := header.IPv4(buf.Bytes()) + for _, check := range []checker.NetworkChecker{ + checker.FragmentFlags(0), + checker.FragmentOffset(0), + checker.IPFullLength(uint16(header.IPv4MinimumSize + header.UDPMinimumSize + len(expectedPayload))), + } { + check(t, []header.Network{ip}) } } - if res, err := ep.Read(ioutil.Discard, tcpip.ReadOptions{}); err != tcpip.ErrWouldBlock { - t.Fatalf("(last) got Read = (%v, %v), want = (_, %s)", res, err, tcpip.ErrWouldBlock) + res, err := ep.Read(ioutil.Discard, tcpip.ReadOptions{}) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("(last) got Read = (%#v, %v), want = (_, %s)", res, err, &tcpip.ErrWouldBlock{}) } }) } @@ -2554,11 +2585,11 @@ func TestWriteStats(t *testing.T) { // Parameterize the tests to run with both WritePacket and WritePackets. 
writers := []struct { name string - writePackets func(*stack.Route, stack.PacketBufferList) (int, *tcpip.Error) + writePackets func(*stack.Route, stack.PacketBufferList) (int, tcpip.Error) }{ { name: "WritePacket", - writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, *tcpip.Error) { + writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, tcpip.Error) { nWritten := 0 for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { if err := rt.WritePacket(nil, stack.NetworkHeaderParams{}, pkt); err != nil { @@ -2570,7 +2601,7 @@ func TestWriteStats(t *testing.T) { }, }, { name: "WritePackets", - writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, *tcpip.Error) { + writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, tcpip.Error) { return rt.WritePackets(nil, pkts, stack.NetworkHeaderParams{}) }, }, @@ -2580,7 +2611,7 @@ func TestWriteStats(t *testing.T) { t.Run(writer.name, func(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ep := testutil.NewMockLinkEndpoint(header.IPv4MinimumMTU, tcpip.ErrInvalidEndpointState, test.allowPackets) + ep := testutil.NewMockLinkEndpoint(header.IPv4MinimumMTU, &tcpip.ErrInvalidEndpointState{}, test.allowPackets) rt := buildRoute(t, ep) var pkts stack.PacketBufferList diff --git a/pkg/tcpip/network/ipv4/stats_test.go b/pkg/tcpip/network/ipv4/stats_test.go index b28e7dcde..fbbc6e69c 100644 --- a/pkg/tcpip/network/ipv4/stats_test.go +++ b/pkg/tcpip/network/ipv4/stats_test.go @@ -50,7 +50,7 @@ func TestClearEndpointFromProtocolOnClose(t *testing.T) { }) proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol) nic := testInterface{nicID: 1} - ep := proto.NewEndpoint(&nic, nil, nil, nil).(*endpoint) + ep := proto.NewEndpoint(&nic, nil).(*endpoint) var nicIDs []tcpip.NICID proto.mu.Lock() @@ -82,7 +82,7 @@ func TestMultiCounterStatsInitialization(t *testing.T) { }) proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol) var nic 
testInterface - ep := proto.NewEndpoint(&nic, nil, nil, nil).(*endpoint) + ep := proto.NewEndpoint(&nic, nil).(*endpoint) // At this point, the Stack's stats and the NetworkEndpoint's stats are // expected to be bound by a MultiCounterStat. refStack := s.Stats() diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 95efada3a..dcfd93bab 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -23,11 +23,136 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/stack" ) +// icmpv6DestinationUnreachableSockError is a general ICMPv6 Destination +// Unreachable error. +// +// +stateify savable +type icmpv6DestinationUnreachableSockError struct{} + +// Origin implements tcpip.SockErrorCause. +func (*icmpv6DestinationUnreachableSockError) Origin() tcpip.SockErrOrigin { + return tcpip.SockExtErrorOriginICMP6 +} + +// Type implements tcpip.SockErrorCause. +func (*icmpv6DestinationUnreachableSockError) Type() uint8 { + return uint8(header.ICMPv6DstUnreachable) +} + +// Info implements tcpip.SockErrorCause. +func (*icmpv6DestinationUnreachableSockError) Info() uint32 { + return 0 +} + +var _ stack.TransportError = (*icmpv6DestinationNetworkUnreachableSockError)(nil) + +// icmpv6DestinationNetworkUnreachableSockError is an ICMPv6 Destination Network +// Unreachable error. +// +// It indicates that the destination network is unreachable. +// +// +stateify savable +type icmpv6DestinationNetworkUnreachableSockError struct { + icmpv6DestinationUnreachableSockError +} + +// Code implements tcpip.SockErrorCause. +func (*icmpv6DestinationNetworkUnreachableSockError) Code() uint8 { + return uint8(header.ICMPv6NetworkUnreachable) +} + +// Kind implements stack.TransportError. 
+func (*icmpv6DestinationNetworkUnreachableSockError) Kind() stack.TransportErrorKind { + return stack.DestinationNetworkUnreachableTransportError +} + +var _ stack.TransportError = (*icmpv6DestinationPortUnreachableSockError)(nil) + +// icmpv6DestinationPortUnreachableSockError is an ICMPv6 Destination Port +// Unreachable error. +// +// It indicates that a packet reached the destination host, but the transport +// protocol was not active on the destination port. +// +// +stateify savable +type icmpv6DestinationPortUnreachableSockError struct { + icmpv6DestinationUnreachableSockError +} + +// Code implements tcpip.SockErrorCause. +func (*icmpv6DestinationPortUnreachableSockError) Code() uint8 { + return uint8(header.ICMPv6PortUnreachable) +} + +// Kind implements stack.TransportError. +func (*icmpv6DestinationPortUnreachableSockError) Kind() stack.TransportErrorKind { + return stack.DestinationPortUnreachableTransportError +} + +var _ stack.TransportError = (*icmpv6DestinationAddressUnreachableSockError)(nil) + +// icmpv6DestinationAddressUnreachableSockError is an ICMPv6 Destination Address +// Unreachable error. +// +// It indicates that a packet was not able to reach the destination. +// +// +stateify savable +type icmpv6DestinationAddressUnreachableSockError struct { + icmpv6DestinationUnreachableSockError +} + +// Code implements tcpip.SockErrorCause. +func (*icmpv6DestinationAddressUnreachableSockError) Code() uint8 { + return uint8(header.ICMPv6AddressUnreachable) +} + +// Kind implements stack.TransportError. +func (*icmpv6DestinationAddressUnreachableSockError) Kind() stack.TransportErrorKind { + return stack.DestinationHostUnreachableTransportError +} + +var _ stack.TransportError = (*icmpv6PacketTooBigSockError)(nil) + +// icmpv6PacketTooBigSockError is an ICMPv6 Packet Too Big error. +// +// It indicates that a link exists on the path to the destination with an MTU +// that is too small to carry the packet. 
+// +// +stateify savable +type icmpv6PacketTooBigSockError struct { + mtu uint32 +} + +// Origin implements tcpip.SockErrorCause. +func (*icmpv6PacketTooBigSockError) Origin() tcpip.SockErrOrigin { + return tcpip.SockExtErrorOriginICMP6 +} + +// Type implements tcpip.SockErrorCause. +func (*icmpv6PacketTooBigSockError) Type() uint8 { + return uint8(header.ICMPv6PacketTooBig) +} + +// Code implements tcpip.SockErrorCause. +func (*icmpv6PacketTooBigSockError) Code() uint8 { + return uint8(header.ICMPv6UnusedCode) +} + +// Info implements tcpip.SockErrorCause. +func (e *icmpv6PacketTooBigSockError) Info() uint32 { + return e.mtu +} + +// Kind implements stack.TransportError. +func (*icmpv6PacketTooBigSockError) Kind() stack.TransportErrorKind { + return stack.PacketTooBigTransportError +} + // handleControl handles the case when an ICMP packet contains the headers of // the original packet that caused the ICMP one to be sent. This information is // used to find out which transport endpoint must be notified about the ICMP // packet. -func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { +func (e *endpoint) handleControl(transErr stack.TransportError, pkt *stack.PacketBuffer) { h, ok := pkt.Data.PullUp(header.IPv6MinimumSize) if !ok { return @@ -67,8 +192,7 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack p = fragHdr.TransportProtocol() } - // Deliver the control packet to the transport endpoint. 
- e.dispatcher.DeliverTransportControlPacket(src, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) + e.dispatcher.DeliverTransportError(src, hdr.DestinationAddress(), ProtocolNumber, p, transErr, pkt) } // getLinkAddrOption searches NDP options for a given link address option using @@ -175,7 +299,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { if err != nil { networkMTU = 0 } - e.handleControl(stack.ControlPacketTooBig, networkMTU, pkt) + e.handleControl(&icmpv6PacketTooBigSockError{mtu: networkMTU}, pkt) case header.ICMPv6DstUnreachable: received.dstUnreachable.Increment() @@ -187,11 +311,10 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { pkt.Data.TrimFront(header.ICMPv6DstUnreachableMinimumSize) switch header.ICMPv6(hdr).Code() { case header.ICMPv6NetworkUnreachable: - e.handleControl(stack.ControlNetworkUnreachable, 0, pkt) + e.handleControl(&icmpv6DestinationNetworkUnreachableSockError{}, pkt) case header.ICMPv6PortUnreachable: - e.handleControl(stack.ControlPortUnreachable, 0, pkt) + e.handleControl(&icmpv6DestinationPortUnreachableSockError{}, pkt) } - case header.ICMPv6NeighborSolicit: received.neighborSolicit.Increment() if !isNDPValid() || pkt.Data.Size() < header.ICMPv6NeighborSolicitMinimumSize { @@ -237,7 +360,9 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { // // TODO(gvisor.dev/issue/4046): Handle the scenario when a duplicate // address is detected for an assigned address. 
- if err := e.dupTentativeAddrDetected(targetAddr); err != nil && err != tcpip.ErrBadAddress && err != tcpip.ErrInvalidEndpointState { + switch err := e.dupTentativeAddrDetected(targetAddr); err.(type) { + case nil, *tcpip.ErrBadAddress, *tcpip.ErrInvalidEndpointState: + default: panic(fmt.Sprintf("unexpected error handling duplicate tentative address: %s", err)) } } @@ -287,10 +412,14 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { } else if unspecifiedSource { received.invalid.Increment() return - } else if e.nud != nil { - e.nud.HandleProbe(srcAddr, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol) } else { - e.linkAddrCache.AddLinkAddress(srcAddr, sourceLinkAddr) + switch err := e.nic.HandleNeighborProbe(ProtocolNumber, srcAddr, sourceLinkAddr); err.(type) { + case nil: + case *tcpip.ErrNotSupported: + // The stack may support ICMPv6 but the NIC may not need link resolution. + default: + panic(fmt.Sprintf("unexpected error when informing NIC of neighbor probe message: %s", err)) + } } // As per RFC 4861 section 7.1.1: @@ -413,10 +542,12 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { // // TODO(gvisor.dev/issue/4046): Handle the scenario when a duplicate // address is detected for an assigned address. - if err := e.dupTentativeAddrDetected(targetAddr); err != nil && err != tcpip.ErrBadAddress && err != tcpip.ErrInvalidEndpointState { + switch err := e.dupTentativeAddrDetected(targetAddr); err.(type) { + case nil, *tcpip.ErrBadAddress, *tcpip.ErrInvalidEndpointState: + return + default: panic(fmt.Sprintf("unexpected error handling duplicate tentative address: %s", err)) } - return } it, err := na.Options().Iter(false /* check */) @@ -441,20 +572,30 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { return } - // If the NA message has the target link layer option, update the link - // address cache with the link address for the target of the message. 
- if e.nud == nil { - if len(targetLinkAddr) != 0 { - e.linkAddrCache.AddLinkAddress(targetAddr, targetLinkAddr) - } + // As per RFC 4861 section 7.1.2: + // A node MUST silently discard any received Neighbor Advertisement + // messages that do not satisfy all of the following validity checks: + // ... + // - If the IP Destination Address is a multicast address the + // Solicited flag is zero. + if header.IsV6MulticastAddress(dstAddr) && na.SolicitedFlag() { + received.invalid.Increment() return } - e.nud.HandleConfirmation(targetAddr, targetLinkAddr, stack.ReachabilityConfirmationFlags{ + // If the NA message has the target link layer option, update the link + // address cache with the link address for the target of the message. + switch err := e.nic.HandleNeighborConfirmation(ProtocolNumber, targetAddr, targetLinkAddr, stack.ReachabilityConfirmationFlags{ Solicited: na.SolicitedFlag(), Override: na.OverrideFlag(), IsRouter: na.RouterFlag(), - }) + }); err.(type) { + case nil: + case *tcpip.ErrNotSupported: + // The stack may support ICMPv6 but the NIC may not need link resolution. + default: + panic(fmt.Sprintf("unexpected error when informing NIC of neighbor confirmation message: %s", err)) + } case header.ICMPv6EchoRequest: received.echoRequest.Increment() @@ -560,10 +701,14 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { return } - if e.nud != nil { - // A RS with a specified source IP address modifies the NUD state - // machine in the same way a reachability probe would. - e.nud.HandleProbe(srcAddr, ProtocolNumber, sourceLinkAddr, e.protocol) + // A RS with a specified source IP address modifies the neighbor table + // in the same way a regular probe would. + switch err := e.nic.HandleNeighborProbe(ProtocolNumber, srcAddr, sourceLinkAddr); err.(type) { + case nil: + case *tcpip.ErrNotSupported: + // The stack may support ICMPv6 but the NIC may not need link resolution. 
+ default: + panic(fmt.Sprintf("unexpected error when informing NIC of neighbor probe message: %s", err)) } } @@ -612,8 +757,14 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { // If the RA has the source link layer option, update the link address // cache with the link address for the advertised router. - if len(sourceLinkAddr) != 0 && e.nud != nil { - e.nud.HandleProbe(routerAddr, ProtocolNumber, sourceLinkAddr, e.protocol) + if len(sourceLinkAddr) != 0 { + switch err := e.nic.HandleNeighborProbe(ProtocolNumber, routerAddr, sourceLinkAddr); err.(type) { + case nil: + case *tcpip.ErrNotSupported: + // The stack may support ICMPv6 but the NIC may not need link resolution. + default: + panic(fmt.Sprintf("unexpected error when informing NIC of neighbor probe message: %s", err)) + } } e.mu.Lock() @@ -679,24 +830,13 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { } } -var _ stack.LinkAddressResolver = (*protocol)(nil) - // LinkAddressProtocol implements stack.LinkAddressResolver. -func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { +func (*endpoint) LinkAddressProtocol() tcpip.NetworkProtocolNumber { return header.IPv6ProtocolNumber } // LinkAddressRequest implements stack.LinkAddressResolver. 
-func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic stack.NetworkInterface) *tcpip.Error { - nicID := nic.ID() - - p.mu.Lock() - netEP, ok := p.mu.eps[nicID] - p.mu.Unlock() - if !ok { - return tcpip.ErrNotConnected - } - +func (e *endpoint) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) tcpip.Error { remoteAddr := targetAddr if len(remoteLinkAddr) == 0 { remoteAddr = header.SolicitedNodeAddr(targetAddr) @@ -704,22 +844,22 @@ func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot } if len(localAddr) == 0 { - addressEndpoint := netEP.AcquireOutgoingPrimaryAddress(remoteAddr, false /* allowExpired */) + addressEndpoint := e.AcquireOutgoingPrimaryAddress(remoteAddr, false /* allowExpired */) if addressEndpoint == nil { - return tcpip.ErrNetworkUnreachable + return &tcpip.ErrNetworkUnreachable{} } localAddr = addressEndpoint.AddressWithPrefix().Address - } else if p.stack.CheckLocalAddress(nicID, ProtocolNumber, localAddr) == 0 { - return tcpip.ErrBadLocalAddress + } else if e.protocol.stack.CheckLocalAddress(e.nic.ID(), ProtocolNumber, localAddr) == 0 { + return &tcpip.ErrBadLocalAddress{} } optsSerializer := header.NDPOptionsSerializer{ - header.NDPSourceLinkLayerAddressOption(nic.LinkAddress()), + header.NDPSourceLinkLayerAddressOption(e.nic.LinkAddress()), } neighborSolicitSize := header.ICMPv6NeighborSolicitMinimumSize + optsSerializer.Length() pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(nic.MaxHeaderLength()) + header.IPv6FixedHeaderSize + neighborSolicitSize, + ReserveHeaderBytes: int(e.nic.MaxHeaderLength()) + header.IPv6FixedHeaderSize + neighborSolicitSize, }) pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber packet := header.ICMPv6(pkt.TransportHeader().Push(neighborSolicitSize)) @@ -736,9 +876,9 @@ func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot 
panic(fmt.Sprintf("failed to add IP header: %s", err)) } - stat := netEP.stats.icmp.packetsSent + stat := e.stats.icmp.packetsSent - if err := nic.WritePacketToRemote(remoteLinkAddr, nil /* gso */, ProtocolNumber, pkt); err != nil { + if err := e.nic.WritePacketToRemote(remoteLinkAddr, nil /* gso */, ProtocolNumber, pkt); err != nil { stat.dropped.Increment() return err } @@ -748,7 +888,7 @@ func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot } // ResolveStaticAddress implements stack.LinkAddressResolver. -func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { +func (*endpoint) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { if header.IsV6MulticastAddress(addr) { return header.EthernetAddressFromMulticastIPv6Address(addr), true } @@ -813,7 +953,7 @@ func (*icmpReasonReassemblyTimeout) isICMPReason() {} // returnError takes an error descriptor and generates the appropriate ICMP // error packet for IPv6 and sends it. 
-func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error { +func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip.Error { origIPHdr := header.IPv6(pkt.NetworkHeader().View()) origIPHdrSrc := origIPHdr.SourceAddress() origIPHdrDst := origIPHdr.DestinationAddress() @@ -884,7 +1024,7 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi netEP, ok := p.mu.eps[pkt.NICID] p.mu.Unlock() if !ok { - return tcpip.ErrNotConnected + return &tcpip.ErrNotConnected{} } sent := netEP.stats.icmp.packetsSent diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index 641c60b7c..92f9ee2c2 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -79,7 +79,7 @@ func (*stubLinkEndpoint) LinkAddress() tcpip.LinkAddress { return "" } -func (*stubLinkEndpoint) WritePacket(stack.RouteInfo, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) *tcpip.Error { +func (*stubLinkEndpoint) WritePacket(stack.RouteInfo, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) tcpip.Error { return nil } @@ -93,35 +93,14 @@ func (*stubDispatcher) DeliverTransportPacket(tcpip.TransportProtocolNumber, *st return stack.TransportPacketHandled } -var _ stack.LinkAddressCache = (*stubLinkAddressCache)(nil) - -type stubLinkAddressCache struct{} - -func (*stubLinkAddressCache) AddLinkAddress(tcpip.Address, tcpip.LinkAddress) {} - -type stubNUDHandler struct { - probeCount int - confirmationCount int -} - -var _ stack.NUDHandler = (*stubNUDHandler)(nil) - -func (s *stubNUDHandler) HandleProbe(tcpip.Address, tcpip.NetworkProtocolNumber, tcpip.LinkAddress, stack.LinkAddressResolver) { - s.probeCount++ -} - -func (s *stubNUDHandler) HandleConfirmation(tcpip.Address, tcpip.LinkAddress, stack.ReachabilityConfirmationFlags) { - s.confirmationCount++ -} - -func (*stubNUDHandler) HandleUpperLevelConfirmation(tcpip.Address) { -} - var _ 
stack.NetworkInterface = (*testInterface)(nil) type testInterface struct { stack.LinkEndpoint + probeCount int + confirmationCount int + nicID tcpip.NICID } @@ -145,21 +124,31 @@ func (*testInterface) Promiscuous() bool { return false } -func (t *testInterface) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { +func (t *testInterface) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { return t.LinkEndpoint.WritePacket(r.Fields(), gso, protocol, pkt) } -func (t *testInterface) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (t *testInterface) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { return t.LinkEndpoint.WritePackets(r.Fields(), gso, pkts, protocol) } -func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { +func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { var r stack.RouteInfo r.NetProto = protocol r.RemoteLinkAddress = remoteLinkAddr return t.LinkEndpoint.WritePacket(r, gso, protocol, pkt) } +func (t *testInterface) HandleNeighborProbe(tcpip.NetworkProtocolNumber, tcpip.Address, tcpip.LinkAddress) tcpip.Error { + t.probeCount++ + return nil +} + +func (t *testInterface) HandleNeighborConfirmation(tcpip.NetworkProtocolNumber, tcpip.Address, tcpip.LinkAddress, stack.ReachabilityConfirmationFlags) tcpip.Error { + t.confirmationCount++ + return nil +} + func TestICMPCounts(t *testing.T) { tests := []struct { name string @@ -202,7 +191,7 @@ func TestICMPCounts(t *testing.T) { if netProto == nil { t.Fatalf("cannot 
find protocol instance for network protocol %d", ProtocolNumber) } - ep := netProto.NewEndpoint(&testInterface{}, &stubLinkAddressCache{}, &stubNUDHandler{}, &stubDispatcher{}) + ep := netProto.NewEndpoint(&testInterface{}, &stubDispatcher{}) defer ep.Close() if err := ep.Enable(); err != nil { @@ -360,7 +349,7 @@ func TestICMPCountsWithNeighborCache(t *testing.T) { if netProto == nil { t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber) } - ep := netProto.NewEndpoint(&testInterface{}, nil, &stubNUDHandler{}, &stubDispatcher{}) + ep := netProto.NewEndpoint(&testInterface{}, &stubDispatcher{}) defer ep.Close() if err := ep.Enable(); err != nil { @@ -573,12 +562,19 @@ func newTestContext(t *testing.T) *testContext { }}, ) - return c -} + t.Cleanup(func() { + if err := c.s0.RemoveNIC(nicID); err != nil { + t.Errorf("c.s0.RemoveNIC(%d): %s", nicID, err) + } + if err := c.s1.RemoveNIC(nicID); err != nil { + t.Errorf("c.s1.RemoveNIC(%d): %s", nicID, err) + } -func (c *testContext) cleanup() { - c.linkEP0.Close() - c.linkEP1.Close() + c.linkEP0.Close() + c.linkEP1.Close() + }) + + return c } type routeArgs struct { @@ -628,7 +624,6 @@ func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header. 
func TestLinkResolution(t *testing.T) { c := newTestContext(t) - defer c.cleanup() r, err := c.s0.FindRoute(nicID, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */) if err != nil { @@ -1283,7 +1278,7 @@ func TestLinkAddressRequest(t *testing.T) { localAddr tcpip.Address remoteLinkAddr tcpip.LinkAddress - expectedErr *tcpip.Error + expectedErr tcpip.Error expectedRemoteAddr tcpip.Address expectedRemoteLinkAddr tcpip.LinkAddress }{ @@ -1321,23 +1316,23 @@ func TestLinkAddressRequest(t *testing.T) { name: "Unicast with unassigned address", localAddr: lladdr1, remoteLinkAddr: linkAddr1, - expectedErr: tcpip.ErrBadLocalAddress, + expectedErr: &tcpip.ErrBadLocalAddress{}, }, { name: "Multicast with unassigned address", localAddr: lladdr1, remoteLinkAddr: "", - expectedErr: tcpip.ErrBadLocalAddress, + expectedErr: &tcpip.ErrBadLocalAddress{}, }, { name: "Unicast with no local address available", remoteLinkAddr: linkAddr1, - expectedErr: tcpip.ErrNetworkUnreachable, + expectedErr: &tcpip.ErrNetworkUnreachable{}, }, { name: "Multicast with no local address available", remoteLinkAddr: "", - expectedErr: tcpip.ErrNetworkUnreachable, + expectedErr: &tcpip.ErrNetworkUnreachable{}, }, } @@ -1346,28 +1341,32 @@ func TestLinkAddressRequest(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, }) - p := s.NetworkProtocolInstance(ProtocolNumber) - linkRes, ok := p.(stack.LinkAddressResolver) - if !ok { - t.Fatalf("expected IPv6 protocol to implement stack.LinkAddressResolver") - } linkEP := channel.New(defaultChannelSize, defaultMTU, linkAddr0) if err := s.CreateNIC(nicID, linkEP); err != nil { t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) } + + ep, err := s.GetNetworkEndpoint(nicID, ProtocolNumber) + if err != nil { + t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, ProtocolNumber, err) + } + linkRes, ok := ep.(stack.LinkAddressResolver) + if !ok { + t.Fatalf("expected %T to implement stack.LinkAddressResolver", 
ep) + } + if len(test.nicAddr) != 0 { if err := s.AddAddress(nicID, ProtocolNumber, test.nicAddr); err != nil { t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ProtocolNumber, test.nicAddr, err) } } - // We pass a test network interface to LinkAddressRequest with the same NIC - // ID and link endpoint used by the NIC we created earlier so that we can - // mock a link address request and observe the packets sent to the link - // endpoint even though the stack uses the real NIC. - if err := linkRes.LinkAddressRequest(lladdr0, test.localAddr, test.remoteLinkAddr, &testInterface{LinkEndpoint: linkEP, nicID: nicID}); err != test.expectedErr { - t.Errorf("got p.LinkAddressRequest(%s, %s, %s, _) = %s, want = %s", lladdr0, test.localAddr, test.remoteLinkAddr, err, test.expectedErr) + { + err := linkRes.LinkAddressRequest(lladdr0, test.localAddr, test.remoteLinkAddr) + if diff := cmp.Diff(test.expectedErr, err); diff != "" { + t.Fatalf("unexpected error from p.LinkAddressRequest(%s, %s, %s, _), (-want, +got):\n%s", lladdr0, test.localAddr, test.remoteLinkAddr, diff) + } } if test.expectedErr != nil { @@ -1797,8 +1796,9 @@ func TestCallsToNeighborCache(t *testing.T) { if netProto == nil { t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber) } - nudHandler := &stubNUDHandler{} - ep := netProto.NewEndpoint(&testInterface{LinkEndpoint: channel.New(0, header.IPv6MinimumMTU, linkAddr0)}, &stubLinkAddressCache{}, nudHandler, &stubDispatcher{}) + + testInterface := testInterface{LinkEndpoint: channel.New(0, header.IPv6MinimumMTU, linkAddr0)} + ep := netProto.NewEndpoint(&testInterface, &stubDispatcher{}) defer ep.Close() if err := ep.Enable(); err != nil { @@ -1833,11 +1833,11 @@ func TestCallsToNeighborCache(t *testing.T) { ep.HandlePacket(pkt) // Confirm the endpoint calls the correct NUDHandler method. 
- if nudHandler.probeCount != test.wantProbeCount { - t.Errorf("got nudHandler.probeCount = %d, want = %d", nudHandler.probeCount, test.wantProbeCount) + if testInterface.probeCount != test.wantProbeCount { + t.Errorf("got testInterface.probeCount = %d, want = %d", testInterface.probeCount, test.wantProbeCount) } - if nudHandler.confirmationCount != test.wantConfirmationCount { - t.Errorf("got nudHandler.confirmationCount = %d, want = %d", nudHandler.confirmationCount, test.wantConfirmationCount) + if testInterface.confirmationCount != test.wantConfirmationCount { + t.Errorf("got testInterface.confirmationCount = %d, want = %d", testInterface.confirmationCount, test.wantConfirmationCount) } }) } diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index d658f9bcb..c2e8c3ea7 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -164,6 +164,7 @@ func getLabel(addr tcpip.Address) uint8 { panic(fmt.Sprintf("should have a label for address = %s", addr)) } +var _ stack.LinkAddressResolver = (*endpoint)(nil) var _ stack.LinkResolvableNetworkEndpoint = (*endpoint)(nil) var _ stack.GroupAddressableEndpoint = (*endpoint)(nil) var _ stack.AddressableEndpoint = (*endpoint)(nil) @@ -172,13 +173,11 @@ var _ stack.NDPEndpoint = (*endpoint)(nil) var _ NDPEndpoint = (*endpoint)(nil) type endpoint struct { - nic stack.NetworkInterface - linkAddrCache stack.LinkAddressCache - nud stack.NUDHandler - dispatcher stack.TransportDispatcher - protocol *protocol - stack *stack.Stack - stats sharedStats + nic stack.NetworkInterface + dispatcher stack.TransportDispatcher + protocol *protocol + stack *stack.Stack + stats sharedStats // enabled is set to 1 when the endpoint is enabled and 0 when it is // disabled. 
@@ -236,7 +235,7 @@ func (e *endpoint) HandleLinkResolutionFailure(pkt *stack.PacketBuffer) { }) pkt.NICID = e.nic.ID() pkt.NetworkProtocolNumber = ProtocolNumber - e.handleControl(stack.ControlAddressUnreachable, 0, pkt) + e.handleControl(&icmpv6DestinationAddressUnreachableSockError{}, pkt) } // onAddressAssignedLocked handles an address being assigned. @@ -307,17 +306,17 @@ func (e *endpoint) hasTentativeAddr(addr tcpip.Address) bool { // dupTentativeAddrDetected removes the tentative address if it exists. If the // address was generated via SLAAC, an attempt is made to generate a new // address. -func (e *endpoint) dupTentativeAddrDetected(addr tcpip.Address) *tcpip.Error { +func (e *endpoint) dupTentativeAddrDetected(addr tcpip.Address) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() addressEndpoint := e.getAddressRLocked(addr) if addressEndpoint == nil { - return tcpip.ErrBadAddress + return &tcpip.ErrBadAddress{} } if addressEndpoint.GetKind() != stack.PermanentTentative { - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } // If the address is a SLAAC address, do not invalidate its SLAAC prefix as an @@ -369,14 +368,14 @@ func (e *endpoint) transitionForwarding(forwarding bool) { } // Enable implements stack.NetworkEndpoint. -func (e *endpoint) Enable() *tcpip.Error { +func (e *endpoint) Enable() tcpip.Error { e.mu.Lock() defer e.mu.Unlock() // If the NIC is not enabled, the endpoint can't do anything meaningful so // don't enable the endpoint. if !e.nic.Enabled() { - return tcpip.ErrNotPermitted + return &tcpip.ErrNotPermitted{} } // If the endpoint is already enabled, there is nothing for it to do. @@ -418,7 +417,7 @@ func (e *endpoint) Enable() *tcpip.Error { // // Addresses may have aleady completed DAD but in the time since the endpoint // was last enabled, other devices may have acquired the same addresses. 
- var err *tcpip.Error + var err tcpip.Error e.mu.addressableEndpointState.ForEachEndpoint(func(addressEndpoint stack.AddressEndpoint) bool { addr := addressEndpoint.AddressWithPrefix().Address if !header.IsV6UnicastAddress(addr) { @@ -499,7 +498,9 @@ func (e *endpoint) disableLocked() { e.stopDADForPermanentAddressesLocked() // The endpoint may have already left the multicast group. - if err := e.leaveGroupLocked(header.IPv6AllNodesMulticastAddress); err != nil && err != tcpip.ErrBadLocalAddress { + switch err := e.leaveGroupLocked(header.IPv6AllNodesMulticastAddress); err.(type) { + case nil, *tcpip.ErrBadLocalAddress: + default: panic(fmt.Sprintf("unexpected error when leaving group = %s: %s", header.IPv6AllNodesMulticastAddress, err)) } @@ -555,11 +556,11 @@ func (e *endpoint) MaxHeaderLength() uint16 { return e.nic.MaxHeaderLength() + header.IPv6MinimumSize } -func addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams, extensionHeaders header.IPv6ExtHdrSerializer) *tcpip.Error { +func addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams, extensionHeaders header.IPv6ExtHdrSerializer) tcpip.Error { extHdrsLen := extensionHeaders.Length() length := pkt.Size() + extensionHeaders.Length() if length > math.MaxUint16 { - return tcpip.ErrMessageTooLong + return &tcpip.ErrMessageTooLong{} } ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize + extHdrsLen)) ip.Encode(&header.IPv6Fields{ @@ -585,7 +586,7 @@ func packetMustBeFragmented(pkt *stack.PacketBuffer, networkMTU uint32, gso *sta // fragments left to be processed. The IP header must already be present in the // original packet. The transport header protocol number is required to avoid // parsing the IPv6 extension headers. 
-func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU uint32, pkt *stack.PacketBuffer, transProto tcpip.TransportProtocolNumber, handler func(*stack.PacketBuffer) *tcpip.Error) (int, int, *tcpip.Error) { +func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU uint32, pkt *stack.PacketBuffer, transProto tcpip.TransportProtocolNumber, handler func(*stack.PacketBuffer) tcpip.Error) (int, int, tcpip.Error) { networkHeader := header.IPv6(pkt.NetworkHeader().View()) // TODO(gvisor.dev/issue/3912): Once the Authentication or ESP Headers are @@ -598,13 +599,13 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU ui // of 8 as per RFC 8200 section 4.5: // Each complete fragment, except possibly the last ("rightmost") one, is // an integer multiple of 8 octets long. - return 0, 1, tcpip.ErrMessageTooLong + return 0, 1, &tcpip.ErrMessageTooLong{} } if fragmentPayloadLen < uint32(pkt.TransportHeader().View().Size()) { // As per RFC 8200 Section 4.5, the Transport Header is expected to be small // enough to fit in the first fragment. - return 0, 1, tcpip.ErrMessageTooLong + return 0, 1, &tcpip.ErrMessageTooLong{} } pf := fragmentation.MakePacketFragmenter(pkt, fragmentPayloadLen, calculateFragmentReserve(pkt)) @@ -624,7 +625,7 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU ui } // WritePacket writes a packet to the given destination address and protocol. 
-func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) tcpip.Error { if err := addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* extensionHeaders */); err != nil { return err } @@ -662,7 +663,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw return e.writePacket(r, gso, pkt, params.Protocol, false /* headerIncluded */) } -func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, protocol tcpip.TransportProtocolNumber, headerIncluded bool) *tcpip.Error { +func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, protocol tcpip.TransportProtocolNumber, headerIncluded bool) tcpip.Error { if r.Loop&stack.PacketLoop != 0 { pkt := pkt.CloneToInbound() if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { @@ -685,7 +686,7 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet } if packetMustBeFragmented(pkt, networkMTU, gso) { - sent, remain, err := e.handleFragments(r, gso, networkMTU, pkt, protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error { + sent, remain, err := e.handleFragments(r, gso, networkMTU, pkt, protocol, func(fragPkt *stack.PacketBuffer) tcpip.Error { // TODO(gvisor.dev/issue/3884): Evaluate whether we want to send each // fragment one by one using WritePacket() (current strategy) or if we // want to create a PacketBufferList from the fragments and feed it to @@ -707,7 +708,7 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet } // WritePackets implements stack.NetworkEndpoint.WritePackets. 
-func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, *tcpip.Error) { +func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, tcpip.Error) { if r.Loop&stack.PacketLoop != 0 { panic("not implemented") } @@ -731,7 +732,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe // Keep track of the packet that is about to be fragmented so it can be // removed once the fragmentation is done. originalPkt := pb - if _, _, err := e.handleFragments(r, gso, networkMTU, pb, params.Protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error { + if _, _, err := e.handleFragments(r, gso, networkMTU, pb, params.Protocol, func(fragPkt *stack.PacketBuffer) tcpip.Error { // Modify the packet list in place with the new fragments. pkts.InsertAfter(pb, fragPkt) pb = fragPkt @@ -798,11 +799,11 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe } // WriteHeaderIncludedPacket implements stack.NetworkEndpoint. -func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) tcpip.Error { // The packet already has an IP header, but there are a few required checks. h, ok := pkt.Data.PullUp(header.IPv6MinimumSize) if !ok { - return tcpip.ErrMalformedHeader + return &tcpip.ErrMalformedHeader{} } ip := header.IPv6(h) @@ -827,14 +828,14 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu // sending the packet. 
proto, _, _, _, ok := parse.IPv6(pkt) if !ok || !header.IPv6(pkt.NetworkHeader().View()).IsValid(pktSize) { - return tcpip.ErrMalformedHeader + return &tcpip.ErrMalformedHeader{} } return e.writePacket(r, nil /* gso */, pkt, proto, true /* headerIncluded */) } // forwardPacket attempts to forward a packet to its final destination. -func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error { h := header.IPv6(pkt.NetworkHeader().View()) hopLimit := h.HopLimit() if hopLimit <= 1 { @@ -856,7 +857,7 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) *tcpip.Error { networkEndpoint.(*endpoint).handlePacket(pkt) return nil } - if err != tcpip.ErrBadAddress { + if _, ok := err.(*tcpip.ErrBadAddress); !ok { return err } @@ -1165,7 +1166,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { // Note that pkt doesn't have its transport header set after reassembly, // and won't until DeliverNetworkPacket sets it. - data, proto, ready, err := e.protocol.fragmentation.Process( + resPkt, proto, ready, err := e.protocol.fragmentation.Process( // IPv6 ignores the Protocol field since the ID only needs to be unique // across source-destination pairs, as per RFC 8200 section 4.5. fragmentation.FragmentID{ @@ -1186,7 +1187,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { } if ready { - pkt.Data = data + pkt = resPkt // We create a new iterator with the reassembled packet because we could // have more extension headers in the reassembled payload, as per RFC @@ -1330,7 +1331,7 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { } // AddAndAcquirePermanentAddress implements stack.AddressableEndpoint. 
-func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, *tcpip.Error) { +func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, tcpip.Error) { // TODO(b/169350103): add checks here after making sure we no longer receive // an empty address. e.mu.Lock() @@ -1345,7 +1346,7 @@ func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, p // solicited-node multicast group and start duplicate address detection. // // Precondition: e.mu must be write locked. -func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, *tcpip.Error) { +func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, tcpip.Error) { addressEndpoint, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, peb, configType, deprecated) if err != nil { return nil, err @@ -1374,13 +1375,13 @@ func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPre } // RemovePermanentAddress implements stack.AddressableEndpoint. 
-func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) *tcpip.Error { +func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() addressEndpoint := e.getAddressRLocked(addr) if addressEndpoint == nil || !addressEndpoint.GetKind().IsPermanent() { - return tcpip.ErrBadLocalAddress + return &tcpip.ErrBadLocalAddress{} } return e.removePermanentEndpointLocked(addressEndpoint, true) @@ -1390,7 +1391,7 @@ func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) *tcpip.Error { // it works with a stack.AddressEndpoint. // // Precondition: e.mu must be write locked. -func (e *endpoint) removePermanentEndpointLocked(addressEndpoint stack.AddressEndpoint, allowSLAACInvalidation bool) *tcpip.Error { +func (e *endpoint) removePermanentEndpointLocked(addressEndpoint stack.AddressEndpoint, allowSLAACInvalidation bool) tcpip.Error { addr := addressEndpoint.AddressWithPrefix() unicast := header.IsV6UnicastAddress(addr.Address) if unicast { @@ -1415,12 +1416,12 @@ func (e *endpoint) removePermanentEndpointLocked(addressEndpoint stack.AddressEn } snmc := header.SolicitedNodeAddr(addr.Address) + err := e.leaveGroupLocked(snmc) // The endpoint may have already left the multicast group. - if err := e.leaveGroupLocked(snmc); err != nil && err != tcpip.ErrBadLocalAddress { - return err + if _, ok := err.(*tcpip.ErrBadLocalAddress); ok { + err = nil } - - return nil + return err } // hasPermanentAddressLocked returns true if the endpoint has a permanent @@ -1630,7 +1631,7 @@ func (e *endpoint) PermanentAddresses() []tcpip.AddressWithPrefix { } // JoinGroup implements stack.GroupAddressableEndpoint. 
-func (e *endpoint) JoinGroup(addr tcpip.Address) *tcpip.Error { +func (e *endpoint) JoinGroup(addr tcpip.Address) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() return e.joinGroupLocked(addr) @@ -1639,9 +1640,9 @@ func (e *endpoint) JoinGroup(addr tcpip.Address) *tcpip.Error { // joinGroupLocked is like JoinGroup but with locking requirements. // // Precondition: e.mu must be locked. -func (e *endpoint) joinGroupLocked(addr tcpip.Address) *tcpip.Error { +func (e *endpoint) joinGroupLocked(addr tcpip.Address) tcpip.Error { if !header.IsV6MulticastAddress(addr) { - return tcpip.ErrBadAddress + return &tcpip.ErrBadAddress{} } e.mu.mld.joinGroup(addr) @@ -1649,7 +1650,7 @@ func (e *endpoint) joinGroupLocked(addr tcpip.Address) *tcpip.Error { } // LeaveGroup implements stack.GroupAddressableEndpoint. -func (e *endpoint) LeaveGroup(addr tcpip.Address) *tcpip.Error { +func (e *endpoint) LeaveGroup(addr tcpip.Address) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() return e.leaveGroupLocked(addr) @@ -1658,7 +1659,7 @@ func (e *endpoint) LeaveGroup(addr tcpip.Address) *tcpip.Error { // leaveGroupLocked is like LeaveGroup but with locking requirements. // // Precondition: e.mu must be locked. -func (e *endpoint) leaveGroupLocked(addr tcpip.Address) *tcpip.Error { +func (e *endpoint) leaveGroupLocked(addr tcpip.Address) tcpip.Error { return e.mu.mld.leaveGroup(addr) } @@ -1730,13 +1731,11 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { } // NewEndpoint creates a new ipv6 endpoint. 
-func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.LinkAddressCache, nud stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint { +func (p *protocol) NewEndpoint(nic stack.NetworkInterface, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint { e := &endpoint{ - nic: nic, - linkAddrCache: linkAddrCache, - nud: nud, - dispatcher: dispatcher, - protocol: p, + nic: nic, + dispatcher: dispatcher, + protocol: p, } e.mu.Lock() e.mu.addressableEndpointState.Init(e) @@ -1762,24 +1761,24 @@ func (p *protocol) forgetEndpoint(nicID tcpip.NICID) { } // SetOption implements NetworkProtocol.SetOption. -func (p *protocol) SetOption(option tcpip.SettableNetworkProtocolOption) *tcpip.Error { +func (p *protocol) SetOption(option tcpip.SettableNetworkProtocolOption) tcpip.Error { switch v := option.(type) { case *tcpip.DefaultTTLOption: p.SetDefaultTTL(uint8(*v)) return nil default: - return tcpip.ErrUnknownProtocolOption + return &tcpip.ErrUnknownProtocolOption{} } } // Option implements NetworkProtocol.Option. -func (p *protocol) Option(option tcpip.GettableNetworkProtocolOption) *tcpip.Error { +func (p *protocol) Option(option tcpip.GettableNetworkProtocolOption) tcpip.Error { switch v := option.(type) { case *tcpip.DefaultTTLOption: *v = tcpip.DefaultTTLOption(p.DefaultTTL()) return nil default: - return tcpip.ErrUnknownProtocolOption + return &tcpip.ErrUnknownProtocolOption{} } } @@ -1842,9 +1841,9 @@ func (p *protocol) SetForwarding(v bool) { // link-layer payload MTU and the length of every IPv6 header. // Note that this is different than the Payload Length field of the IPv6 header, // which includes the length of the extension headers. 
-func calculateNetworkMTU(linkMTU, networkHeadersLen uint32) (uint32, *tcpip.Error) { +func calculateNetworkMTU(linkMTU, networkHeadersLen uint32) (uint32, tcpip.Error) { if linkMTU < header.IPv6MinimumMTU { - return 0, tcpip.ErrInvalidEndpointState + return 0, &tcpip.ErrInvalidEndpointState{} } // As per RFC 7112 section 5, we should discard packets if their IPv6 header @@ -1855,7 +1854,7 @@ func calculateNetworkMTU(linkMTU, networkHeadersLen uint32) (uint32, *tcpip.Erro // bytes ensures that the header chain length does not exceed the IPv6 // minimum MTU. if networkHeadersLen > header.IPv6MinimumMTU { - return 0, tcpip.ErrMalformedHeader + return 0, &tcpip.ErrMalformedHeader{} } networkMTU := linkMTU - uint32(networkHeadersLen) diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index 5276878a0..1c6c37c91 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -996,8 +996,9 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { } // Should not have any more UDP packets. 
- if res, err := ep.Read(ioutil.Discard, tcpip.ReadOptions{}); err != tcpip.ErrWouldBlock { - t.Fatalf("got Read = (%v, %v), want = (_, %s)", res, err, tcpip.ErrWouldBlock) + res, err := ep.Read(ioutil.Discard, tcpip.ReadOptions{}) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got Read = (%v, %v), want = (_, %s)", res, err, &tcpip.ErrWouldBlock{}) } }) } @@ -1988,8 +1989,9 @@ func TestReceiveIPv6Fragments(t *testing.T) { } } - if res, err := ep.Read(ioutil.Discard, tcpip.ReadOptions{}); err != tcpip.ErrWouldBlock { - t.Fatalf("(last) got Read = (%v, %v), want = (_, %s)", res, err, tcpip.ErrWouldBlock) + res, err := ep.Read(ioutil.Discard, tcpip.ReadOptions{}) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("(last) got Read = (%v, %v), want = (_, %s)", res, err, &tcpip.ErrWouldBlock{}) } }) } @@ -2472,11 +2474,11 @@ func TestWriteStats(t *testing.T) { writers := []struct { name string - writePackets func(*stack.Route, stack.PacketBufferList) (int, *tcpip.Error) + writePackets func(*stack.Route, stack.PacketBufferList) (int, tcpip.Error) }{ { name: "WritePacket", - writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, *tcpip.Error) { + writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, tcpip.Error) { nWritten := 0 for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { if err := rt.WritePacket(nil, stack.NetworkHeaderParams{}, pkt); err != nil { @@ -2488,7 +2490,7 @@ func TestWriteStats(t *testing.T) { }, }, { name: "WritePackets", - writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, *tcpip.Error) { + writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, tcpip.Error) { return rt.WritePackets(nil, pkts, stack.NetworkHeaderParams{}) }, }, @@ -2498,7 +2500,7 @@ func TestWriteStats(t *testing.T) { t.Run(writer.name, func(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ep := testutil.NewMockLinkEndpoint(header.IPv6MinimumMTU, 
tcpip.ErrInvalidEndpointState, test.allowPackets) + ep := testutil.NewMockLinkEndpoint(header.IPv6MinimumMTU, &tcpip.ErrInvalidEndpointState{}, test.allowPackets) rt := buildRoute(t, ep) var pkts stack.PacketBufferList for i := 0; i < nPackets; i++ { @@ -2597,7 +2599,7 @@ func TestClearEndpointFromProtocolOnClose(t *testing.T) { }) proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol) var nic testInterface - ep := proto.NewEndpoint(&nic, nil, nil, nil).(*endpoint) + ep := proto.NewEndpoint(&nic, nil).(*endpoint) var nicIDs []tcpip.NICID proto.mu.Lock() @@ -2832,8 +2834,8 @@ func TestFragmentationErrors(t *testing.T) { payloadSize int allowPackets int outgoingErrors int - mockError *tcpip.Error - wantError *tcpip.Error + mockError tcpip.Error + wantError tcpip.Error }{ { description: "No frag", @@ -2842,8 +2844,8 @@ func TestFragmentationErrors(t *testing.T) { transHdrLen: 0, allowPackets: 0, outgoingErrors: 1, - mockError: tcpip.ErrAborted, - wantError: tcpip.ErrAborted, + mockError: &tcpip.ErrAborted{}, + wantError: &tcpip.ErrAborted{}, }, { description: "Error on first frag", @@ -2852,8 +2854,8 @@ func TestFragmentationErrors(t *testing.T) { transHdrLen: 0, allowPackets: 0, outgoingErrors: 3, - mockError: tcpip.ErrAborted, - wantError: tcpip.ErrAborted, + mockError: &tcpip.ErrAborted{}, + wantError: &tcpip.ErrAborted{}, }, { description: "Error on second frag", @@ -2862,8 +2864,8 @@ func TestFragmentationErrors(t *testing.T) { transHdrLen: 0, allowPackets: 1, outgoingErrors: 2, - mockError: tcpip.ErrAborted, - wantError: tcpip.ErrAborted, + mockError: &tcpip.ErrAborted{}, + wantError: &tcpip.ErrAborted{}, }, { description: "Error when MTU is smaller than transport header", @@ -2873,7 +2875,7 @@ func TestFragmentationErrors(t *testing.T) { allowPackets: 0, outgoingErrors: 1, mockError: nil, - wantError: tcpip.ErrMessageTooLong, + wantError: &tcpip.ErrMessageTooLong{}, }, { description: "Error when MTU is smaller than IPv6 minimum MTU", @@ -2883,7 +2885,7 
@@ func TestFragmentationErrors(t *testing.T) { allowPackets: 0, outgoingErrors: 1, mockError: nil, - wantError: tcpip.ErrInvalidEndpointState, + wantError: &tcpip.ErrInvalidEndpointState{}, }, } @@ -2897,8 +2899,8 @@ func TestFragmentationErrors(t *testing.T) { TTL: ttl, TOS: stack.DefaultTOS, }, pkt) - if err != ft.wantError { - t.Errorf("got WritePacket(_, _, _) = %s, want = %s", err, ft.wantError) + if diff := cmp.Diff(ft.wantError, err); diff != "" { + t.Errorf("unexpected error from WritePacket(_, _, _), (-want, +got):\n%s", diff) } if got := int(r.Stats().IP.PacketsSent.Value()); got != ft.allowPackets { t.Errorf("got r.Stats().IP.PacketsSent.Value() = %d, want = %d", got, ft.allowPackets) @@ -3073,7 +3075,7 @@ func TestMultiCounterStatsInitialization(t *testing.T) { }) proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol) var nic testInterface - ep := proto.NewEndpoint(&nic, nil, nil, nil).(*endpoint) + ep := proto.NewEndpoint(&nic, nil).(*endpoint) // At this point, the Stack's stats and the NetworkEndpoint's stats are // supposed to be bound. refStack := s.Stats() diff --git a/pkg/tcpip/network/ipv6/mld.go b/pkg/tcpip/network/ipv6/mld.go index c376016e9..2cc0dfebd 100644 --- a/pkg/tcpip/network/ipv6/mld.go +++ b/pkg/tcpip/network/ipv6/mld.go @@ -68,14 +68,14 @@ func (mld *mldState) Enabled() bool { // SendReport implements ip.MulticastGroupProtocol. // // Precondition: mld.ep.mu must be read locked. -func (mld *mldState) SendReport(groupAddress tcpip.Address) (bool, *tcpip.Error) { +func (mld *mldState) SendReport(groupAddress tcpip.Address) (bool, tcpip.Error) { return mld.writePacket(groupAddress, groupAddress, header.ICMPv6MulticastListenerReport) } // SendLeave implements ip.MulticastGroupProtocol. // // Precondition: mld.ep.mu must be read locked. 
-func (mld *mldState) SendLeave(groupAddress tcpip.Address) *tcpip.Error { +func (mld *mldState) SendLeave(groupAddress tcpip.Address) tcpip.Error { _, err := mld.writePacket(header.IPv6AllRoutersMulticastAddress, groupAddress, header.ICMPv6MulticastListenerDone) return err } @@ -112,7 +112,7 @@ func (mld *mldState) handleMulticastListenerReport(mldHdr header.MLD) { // joinGroup handles joining a new group and sending and scheduling the required // messages. // -// If the group is already joined, returns tcpip.ErrDuplicateAddress. +// If the group is already joined, returns *tcpip.ErrDuplicateAddress. // // Precondition: mld.ep.mu must be locked. func (mld *mldState) joinGroup(groupAddress tcpip.Address) { @@ -131,13 +131,13 @@ func (mld *mldState) isInGroup(groupAddress tcpip.Address) bool { // required. // // Precondition: mld.ep.mu must be locked. -func (mld *mldState) leaveGroup(groupAddress tcpip.Address) *tcpip.Error { +func (mld *mldState) leaveGroup(groupAddress tcpip.Address) tcpip.Error { // LeaveGroup returns false only if the group was not joined. if mld.genericMulticastProtocol.LeaveGroupLocked(groupAddress) { return nil } - return tcpip.ErrBadLocalAddress + return &tcpip.ErrBadLocalAddress{} } // softLeaveAll leaves all groups from the perspective of MLD, but remains @@ -166,7 +166,7 @@ func (mld *mldState) sendQueuedReports() { // writePacket assembles and sends an MLD packet. // // Precondition: mld.ep.mu must be read locked. 
-func (mld *mldState) writePacket(destAddress, groupAddress tcpip.Address, mldType header.ICMPv6Type) (bool, *tcpip.Error) { +func (mld *mldState) writePacket(destAddress, groupAddress tcpip.Address, mldType header.ICMPv6Type) (bool, tcpip.Error) { sentStats := mld.ep.stats.icmp.packetsSent var mldStat tcpip.MultiCounterStat switch mldType { diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go index ca4ff621d..d7dde1767 100644 --- a/pkg/tcpip/network/ipv6/ndp.go +++ b/pkg/tcpip/network/ipv6/ndp.go @@ -241,7 +241,7 @@ type NDPDispatcher interface { // // This function is not permitted to block indefinitely. This function // is also not permitted to call into the stack. - OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) + OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err tcpip.Error) // OnDefaultRouterDiscovered is called when a new default router is // discovered. Implementations must return true if the newly discovered @@ -614,10 +614,10 @@ type slaacPrefixState struct { // tentative. // // The IPv6 endpoint that ndp belongs to MUST be locked. -func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressEndpoint stack.AddressEndpoint) *tcpip.Error { +func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressEndpoint stack.AddressEndpoint) tcpip.Error { // addr must be a valid unicast IPv6 address. 
if !header.IsV6UnicastAddress(addr) { - return tcpip.ErrAddressFamilyNotSupported + return &tcpip.ErrAddressFamilyNotSupported{} } if addressEndpoint.GetKind() != stack.PermanentTentative { @@ -666,7 +666,7 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE dadDone := remaining == 0 - var err *tcpip.Error + var err tcpip.Error if !dadDone { err = ndp.sendDADPacket(addr, addressEndpoint) } @@ -717,7 +717,7 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE // addr. // // addr must be a tentative IPv6 address on ndp's IPv6 endpoint. -func (ndp *ndpState) sendDADPacket(addr tcpip.Address, addressEndpoint stack.AddressEndpoint) *tcpip.Error { +func (ndp *ndpState) sendDADPacket(addr tcpip.Address, addressEndpoint stack.AddressEndpoint) tcpip.Error { snmc := header.SolicitedNodeAddr(addr) icmp := header.ICMPv6(buffer.NewView(header.ICMPv6NeighborSolicitMinimumSize)) diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index 7a22309e5..8edaa9508 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -63,7 +63,7 @@ func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address, useNeig t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber) } - ep := netProto.NewEndpoint(&testInterface{}, &stubLinkAddressCache{}, &stubNUDHandler{}, &stubDispatcher{}) + ep := netProto.NewEndpoint(&testInterface{}, &stubDispatcher{}) if err := ep.Enable(); err != nil { t.Fatalf("ep.Enable(): %s", err) } @@ -90,7 +90,7 @@ type testNDPDispatcher struct { addr tcpip.Address } -func (*testNDPDispatcher) OnDuplicateAddressDetectionStatus(tcpip.NICID, tcpip.Address, bool, *tcpip.Error) { +func (*testNDPDispatcher) OnDuplicateAddressDetectionStatus(tcpip.NICID, tcpip.Address, bool, tcpip.Error) { } func (t *testNDPDispatcher) OnDefaultRouterDiscovered(_ tcpip.NICID, addr tcpip.Address) bool { @@ -167,10 +167,10 @@ type 
linkResolutionResult struct { ok bool } -// TestNeighorSolicitationWithSourceLinkLayerOption tests that receiving a +// TestNeighborSolicitationWithSourceLinkLayerOption tests that receiving a // valid NDP NS message with the Source Link Layer Address option results in a // new entry in the link address cache for the sender of the message. -func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) { +func TestNeighborSolicitationWithSourceLinkLayerOption(t *testing.T) { const nicID = 1 tests := []struct { @@ -199,6 +199,7 @@ func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) { NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, }) e := channel.New(0, 1280, linkAddr0) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } @@ -242,17 +243,19 @@ func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) { }) wantInvalid := uint64(0) - wantErr := (*tcpip.Error)(nil) wantSucccess := true if len(test.expectedLinkAddr) == 0 { wantInvalid = 1 - wantErr = tcpip.ErrWouldBlock wantSucccess = false + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, _) = %s, want = %s", nicID, lladdr1, lladdr0, ProtocolNumber, err, &tcpip.ErrWouldBlock{}) + } + } else { + if err != nil { + t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, _) = %s, want = nil", nicID, lladdr1, lladdr0, ProtocolNumber, err) + } } - if err != wantErr { - t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, _) = %s, want = %s", nicID, lladdr1, lladdr0, ProtocolNumber, err, wantErr) - } if diff := cmp.Diff(stack.LinkResolutionResult{LinkAddress: test.expectedLinkAddr, Success: wantSucccess}, <-ch); diff != "" { t.Errorf("linkResolutionResult mismatch (-want +got):\n%s", diff) } @@ -263,11 +266,11 @@ func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) { } } -// 
TestNeighorSolicitationWithSourceLinkLayerOptionUsingNeighborCache tests +// TestNeighborSolicitationWithSourceLinkLayerOptionUsingNeighborCache tests // that receiving a valid NDP NS message with the Source Link Layer Address // option results in a new entry in the link address cache for the sender of // the message. -func TestNeighorSolicitationWithSourceLinkLayerOptionUsingNeighborCache(t *testing.T) { +func TestNeighborSolicitationWithSourceLinkLayerOptionUsingNeighborCache(t *testing.T) { const nicID = 1 tests := []struct { @@ -335,18 +338,18 @@ func TestNeighorSolicitationWithSourceLinkLayerOptionUsingNeighborCache(t *testi Data: hdr.View().ToVectorisedView(), })) - neighbors, err := s.Neighbors(nicID) + neighbors, err := s.Neighbors(nicID, ProtocolNumber) if err != nil { - t.Fatalf("s.Neighbors(%d): %s", nicID, err) + t.Fatalf("s.Neighbors(%d, %d): %s", nicID, ProtocolNumber, err) } neighborByAddr := make(map[tcpip.Address]stack.NeighborEntry) for _, n := range neighbors { if existing, ok := neighborByAddr[n.Addr]; ok { if diff := cmp.Diff(existing, n); diff != "" { - t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry (-existing +got):\n%s", nicID, diff) + t.Fatalf("s.Neighbors(%d, %d) returned unexpected duplicate neighbor entry (-existing +got):\n%s", nicID, ProtocolNumber, diff) } - t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry: %#v", nicID, existing) + t.Fatalf("s.Neighbors(%d, %d) returned unexpected duplicate neighbor entry: %#v", nicID, ProtocolNumber, existing) } neighborByAddr[n.Addr] = n } @@ -380,7 +383,7 @@ func TestNeighorSolicitationWithSourceLinkLayerOptionUsingNeighborCache(t *testi } } -func TestNeighorSolicitationResponse(t *testing.T) { +func TestNeighborSolicitationResponse(t *testing.T) { const nicID = 1 nicAddr := lladdr0 remoteAddr := lladdr1 @@ -719,10 +722,10 @@ func TestNeighorSolicitationResponse(t *testing.T) { } } -// TestNeighorAdvertisementWithTargetLinkLayerOption tests that 
receiving a +// TestNeighborAdvertisementWithTargetLinkLayerOption tests that receiving a // valid NDP NA message with the Target Link Layer Address option results in a // new entry in the link address cache for the target of the message. -func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) { +func TestNeighborAdvertisementWithTargetLinkLayerOption(t *testing.T) { const nicID = 1 tests := []struct { @@ -756,8 +759,10 @@ func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + UseLinkAddrCache: true, }) e := channel.New(0, 1280, linkAddr0) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } @@ -801,17 +806,19 @@ func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) { }) wantInvalid := uint64(0) - wantErr := (*tcpip.Error)(nil) wantSucccess := true if len(test.expectedLinkAddr) == 0 { wantInvalid = 1 - wantErr = tcpip.ErrWouldBlock wantSucccess = false + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, _) = %s, want = %s", nicID, lladdr1, lladdr0, ProtocolNumber, err, &tcpip.ErrWouldBlock{}) + } + } else { + if err != nil { + t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, _) = %s, want = nil", nicID, lladdr1, lladdr0, ProtocolNumber, err) + } } - if err != wantErr { - t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, _) = %s, want = %s", nicID, lladdr1, lladdr0, ProtocolNumber, err, wantErr) - } if diff := cmp.Diff(stack.LinkResolutionResult{LinkAddress: test.expectedLinkAddr, Success: wantSucccess}, <-ch); diff != "" { t.Errorf("linkResolutionResult mismatch (-want +got):\n%s", diff) } @@ -822,11 +829,11 @@ func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) { } } -// 
TestNeighorAdvertisementWithTargetLinkLayerOptionUsingNeighborCache tests +// TestNeighborAdvertisementWithTargetLinkLayerOptionUsingNeighborCache tests // that receiving a valid NDP NA message with the Target Link Layer Address // option does not result in a new entry in the neighbor cache for the target // of the message. -func TestNeighorAdvertisementWithTargetLinkLayerOptionUsingNeighborCache(t *testing.T) { +func TestNeighborAdvertisementWithTargetLinkLayerOptionUsingNeighborCache(t *testing.T) { const nicID = 1 tests := []struct { @@ -901,18 +908,18 @@ func TestNeighorAdvertisementWithTargetLinkLayerOptionUsingNeighborCache(t *test Data: hdr.View().ToVectorisedView(), })) - neighbors, err := s.Neighbors(nicID) + neighbors, err := s.Neighbors(nicID, ProtocolNumber) if err != nil { - t.Fatalf("s.Neighbors(%d): %s", nicID, err) + t.Fatalf("s.Neighbors(%d, %d): %s", nicID, ProtocolNumber, err) } neighborByAddr := make(map[tcpip.Address]stack.NeighborEntry) for _, n := range neighbors { if existing, ok := neighborByAddr[n.Addr]; ok { if diff := cmp.Diff(existing, n); diff != "" { - t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry (-existing +got):\n%s", nicID, diff) + t.Fatalf("s.Neighbors(%d, %d) returned unexpected duplicate neighbor entry (-existing +got):\n%s", nicID, ProtocolNumber, diff) } - t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry: %#v", nicID, existing) + t.Fatalf("s.Neighbors(%d, %d) returned unexpected duplicate neighbor entry: %#v", nicID, ProtocolNumber, existing) } neighborByAddr[n.Addr] = n } @@ -1168,6 +1175,118 @@ func TestNDPValidation(t *testing.T) { } +// TestNeighborAdvertisementValidation tests that the NIC validates received +// Neighbor Advertisements. +// +// In particular, if the IP Destination Address is a multicast address, and the +// Solicited flag is not zero, the Neighbor Advertisement is invalid and should +// be discarded. 
+func TestNeighborAdvertisementValidation(t *testing.T) { + tests := []struct { + name string + ipDstAddr tcpip.Address + solicitedFlag bool + valid bool + }{ + { + name: "Multicast IP destination address with Solicited flag set", + ipDstAddr: header.IPv6AllNodesMulticastAddress, + solicitedFlag: true, + valid: false, + }, + { + name: "Multicast IP destination address with Solicited flag unset", + ipDstAddr: header.IPv6AllNodesMulticastAddress, + solicitedFlag: false, + valid: true, + }, + { + name: "Unicast IP destination address with Solicited flag set", + ipDstAddr: lladdr0, + solicitedFlag: true, + valid: true, + }, + { + name: "Unicast IP destination address with Solicited flag unset", + ipDstAddr: lladdr0, + solicitedFlag: false, + valid: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + UseNeighborCache: true, + }) + e := channel.New(0, header.IPv6MinimumMTU, linkAddr0) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err) + } + + ndpNASize := header.ICMPv6NeighborAdvertMinimumSize + hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize) + pkt := header.ICMPv6(hdr.Prepend(ndpNASize)) + pkt.SetType(header.ICMPv6NeighborAdvert) + na := header.NDPNeighborAdvert(pkt.MessageBody()) + na.SetTargetAddress(lladdr1) + na.SetSolicitedFlag(test.solicitedFlag) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, test.ipDstAddr, buffer.VectorisedView{})) + payloadLength := hdr.UsedLength() + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + TransportProtocol: header.ICMPv6ProtocolNumber, + 
HopLimit: 255, + SrcAddr: lladdr1, + DstAddr: test.ipDstAddr, + }) + + stats := s.Stats().ICMP.V6.PacketsReceived + invalid := stats.Invalid + rxNA := stats.NeighborAdvert + + if got := rxNA.Value(); got != 0 { + t.Fatalf("got rxNA = %d, want = 0", got) + } + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + + e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + })) + + if got := rxNA.Value(); got != 1 { + t.Fatalf("got rxNA = %d, want = 1", got) + } + var wantInvalid uint64 = 1 + if test.valid { + wantInvalid = 0 + } + if got := invalid.Value(); got != wantInvalid { + t.Fatalf("got invalid = %d, want = %d", got, wantInvalid) + } + // As per RFC 4861 section 7.2.5: + // When a valid Neighbor Advertisement is received ... + // If no entry exists, the advertisement SHOULD be silently discarded. + // There is no need to create an entry if none exists, since the + // recipient has apparently not initiated any communication with the + // target. + if neighbors, err := s.Neighbors(nicID, ProtocolNumber); err != nil { + t.Fatalf("s.Neighbors(%d, %d): %s", nicID, ProtocolNumber, err) + } else if len(neighbors) != 0 { + t.Fatalf("got len(neighbors) = %d, want = 0; neighbors = %#v", len(neighbors), neighbors) + } + }) + } +} + // TestRouterAdvertValidation tests that when the NIC is configured to handle // NDP Router Advertisement packets, it validates the Router Advertisement // properly before handling them. 
diff --git a/pkg/tcpip/network/testutil/testutil.go b/pkg/tcpip/network/testutil/testutil.go index 9bd009374..f5fa77b65 100644 --- a/pkg/tcpip/network/testutil/testutil.go +++ b/pkg/tcpip/network/testutil/testutil.go @@ -35,7 +35,7 @@ type MockLinkEndpoint struct { WrittenPackets []*stack.PacketBuffer mtu uint32 - err *tcpip.Error + err tcpip.Error allowPackets int } @@ -43,7 +43,7 @@ type MockLinkEndpoint struct { // // err is the error that will be returned once allowPackets packets are written // to the endpoint. -func NewMockLinkEndpoint(mtu uint32, err *tcpip.Error, allowPackets int) *MockLinkEndpoint { +func NewMockLinkEndpoint(mtu uint32, err tcpip.Error, allowPackets int) *MockLinkEndpoint { return &MockLinkEndpoint{ mtu: mtu, err: err, @@ -64,7 +64,7 @@ func (*MockLinkEndpoint) MaxHeaderLength() uint16 { return 0 } func (*MockLinkEndpoint) LinkAddress() tcpip.LinkAddress { return "" } // WritePacket implements LinkEndpoint.WritePacket. -func (ep *MockLinkEndpoint) WritePacket(_ stack.RouteInfo, _ *stack.GSO, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { +func (ep *MockLinkEndpoint) WritePacket(_ stack.RouteInfo, _ *stack.GSO, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { if ep.allowPackets == 0 { return ep.err } @@ -74,7 +74,7 @@ func (ep *MockLinkEndpoint) WritePacket(_ stack.RouteInfo, _ *stack.GSO, _ tcpip } // WritePackets implements LinkEndpoint.WritePackets. 
-func (ep *MockLinkEndpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (ep *MockLinkEndpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { var n int for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { diff --git a/pkg/tcpip/ports/BUILD b/pkg/tcpip/ports/BUILD index 2bad05a2e..57abec5c9 100644 --- a/pkg/tcpip/ports/BUILD +++ b/pkg/tcpip/ports/BUILD @@ -18,5 +18,6 @@ go_test( library = ":ports", deps = [ "//pkg/tcpip", + "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go index d87193650..11dbdbbcf 100644 --- a/pkg/tcpip/ports/ports.go +++ b/pkg/tcpip/ports/ports.go @@ -329,7 +329,7 @@ func NewPortManager() *PortManager { // possible ephemeral ports, allowing the caller to decide whether a given port // is suitable for its needs, and stopping when a port is found or an error // occurs. -func (s *PortManager) PickEphemeralPort(testPort func(p uint16) (bool, *tcpip.Error)) (port uint16, err *tcpip.Error) { +func (s *PortManager) PickEphemeralPort(testPort func(p uint16) (bool, tcpip.Error)) (port uint16, err tcpip.Error) { offset := uint32(rand.Int31n(numEphemeralPorts)) return s.pickEphemeralPort(offset, numEphemeralPorts, testPort) } @@ -348,7 +348,7 @@ func (s *PortManager) incPortHint() { // iterates over all ephemeral ports, allowing the caller to decide whether a // given port is suitable for its needs and stopping when a port is found or an // error occurs. 
-func (s *PortManager) PickEphemeralPortStable(offset uint32, testPort func(p uint16) (bool, *tcpip.Error)) (port uint16, err *tcpip.Error) { +func (s *PortManager) PickEphemeralPortStable(offset uint32, testPort func(p uint16) (bool, tcpip.Error)) (port uint16, err tcpip.Error) { p, err := s.pickEphemeralPort(s.portHint()+offset, numEphemeralPorts, testPort) if err == nil { s.incPortHint() @@ -361,7 +361,7 @@ func (s *PortManager) PickEphemeralPortStable(offset uint32, testPort func(p uin // and iterates over the number of ports specified by count and allows the // caller to decide whether a given port is suitable for its needs, and stopping // when a port is found or an error occurs. -func (s *PortManager) pickEphemeralPort(offset, count uint32, testPort func(p uint16) (bool, *tcpip.Error)) (port uint16, err *tcpip.Error) { +func (s *PortManager) pickEphemeralPort(offset, count uint32, testPort func(p uint16) (bool, tcpip.Error)) (port uint16, err tcpip.Error) { for i := uint32(0); i < count; i++ { port = uint16(FirstEphemeral + (offset+i)%count) ok, err := testPort(port) @@ -374,7 +374,7 @@ func (s *PortManager) pickEphemeralPort(offset, count uint32, testPort func(p ui } } - return 0, tcpip.ErrNoPortAvailable + return 0, &tcpip.ErrNoPortAvailable{} } // IsPortAvailable tests if the given port is available on all given protocols. @@ -404,7 +404,7 @@ func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumb // An optional testPort closure can be passed in which if provided will be used // to test if the picked port can be used. The function should return true if // the port is safe to use, false otherwise. 
-func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dest tcpip.FullAddress, testPort func(port uint16) bool) (reservedPort uint16, err *tcpip.Error) { +func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dest tcpip.FullAddress, testPort func(port uint16) bool) (reservedPort uint16, err tcpip.Error) { s.mu.Lock() defer s.mu.Unlock() @@ -414,17 +414,17 @@ func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transp // protocols. if port != 0 { if !s.reserveSpecificPort(networks, transport, addr, port, flags, bindToDevice, dst) { - return 0, tcpip.ErrPortInUse + return 0, &tcpip.ErrPortInUse{} } if testPort != nil && !testPort(port) { s.releasePortLocked(networks, transport, addr, port, flags.Bits(), bindToDevice, dst) - return 0, tcpip.ErrPortInUse + return 0, &tcpip.ErrPortInUse{} } return port, nil } // A port wasn't specified, so try to find one. 
- return s.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) { + return s.PickEphemeralPort(func(p uint16) (bool, tcpip.Error) { if !s.reserveSpecificPort(networks, transport, addr, p, flags, bindToDevice, dst) { return false, nil } diff --git a/pkg/tcpip/ports/ports_test.go b/pkg/tcpip/ports/ports_test.go index 4bc949fd8..e70fbb72b 100644 --- a/pkg/tcpip/ports/ports_test.go +++ b/pkg/tcpip/ports/ports_test.go @@ -18,6 +18,7 @@ import ( "math/rand" "testing" + "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" ) @@ -32,7 +33,7 @@ const ( type portReserveTestAction struct { port uint16 ip tcpip.Address - want *tcpip.Error + want tcpip.Error flags Flags release bool device tcpip.NICID @@ -50,19 +51,19 @@ func TestPortReservation(t *testing.T) { {port: 80, ip: fakeIPAddress, want: nil}, {port: 80, ip: fakeIPAddress1, want: nil}, /* N.B. Order of tests matters! */ - {port: 80, ip: anyIPAddress, want: tcpip.ErrPortInUse}, - {port: 80, ip: fakeIPAddress, want: tcpip.ErrPortInUse, flags: Flags{LoadBalanced: true}}, + {port: 80, ip: anyIPAddress, want: &tcpip.ErrPortInUse{}}, + {port: 80, ip: fakeIPAddress, want: &tcpip.ErrPortInUse{}, flags: Flags{LoadBalanced: true}}, }, }, { tname: "bind to inaddr any", actions: []portReserveTestAction{ {port: 22, ip: anyIPAddress, want: nil}, - {port: 22, ip: fakeIPAddress, want: tcpip.ErrPortInUse}, + {port: 22, ip: fakeIPAddress, want: &tcpip.ErrPortInUse{}}, /* release fakeIPAddress, but anyIPAddress is still inuse */ {port: 22, ip: fakeIPAddress, release: true}, - {port: 22, ip: fakeIPAddress, want: tcpip.ErrPortInUse}, - {port: 22, ip: fakeIPAddress, want: tcpip.ErrPortInUse, flags: Flags{LoadBalanced: true}}, + {port: 22, ip: fakeIPAddress, want: &tcpip.ErrPortInUse{}}, + {port: 22, ip: fakeIPAddress, want: &tcpip.ErrPortInUse{}, flags: Flags{LoadBalanced: true}}, /* Release port 22 from any IP address, then try to reserve fake IP address on 22 */ {port: 22, ip: anyIPAddress, want: nil, release: true}, {port: 22, 
ip: fakeIPAddress, want: nil}, @@ -80,8 +81,8 @@ func TestPortReservation(t *testing.T) { {port: 25, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, {port: 25, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 25, ip: fakeIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse}, - {port: 25, ip: anyIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse}, + {port: 25, ip: fakeIPAddress, flags: Flags{}, want: &tcpip.ErrPortInUse{}}, + {port: 25, ip: anyIPAddress, flags: Flags{}, want: &tcpip.ErrPortInUse{}}, {port: 25, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, }, @@ -91,14 +92,14 @@ func TestPortReservation(t *testing.T) { {port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, {port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: anyIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse}, + {port: 24, ip: anyIPAddress, flags: Flags{}, want: &tcpip.ErrPortInUse{}}, + {port: 24, ip: fakeIPAddress, flags: Flags{}, want: &tcpip.ErrPortInUse{}}, {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, release: true, want: nil}, {port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, release: true}, - {port: 24, ip: anyIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse}, + {port: 24, ip: anyIPAddress, flags: Flags{}, want: &tcpip.ErrPortInUse{}}, {port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, release: true}, {port: 24, ip: anyIPAddress, flags: Flags{}, want: nil}, @@ -107,7 +108,7 @@ func TestPortReservation(t *testing.T) { tname: "bind twice with device fails", actions: []portReserveTestAction{ {port: 24, ip: fakeIPAddress, device: 3, want: nil}, - {port: 24, ip: fakeIPAddress, device: 3, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 3, want: 
&tcpip.ErrPortInUse{}}, }, }, { tname: "bind to device", @@ -119,50 +120,50 @@ func TestPortReservation(t *testing.T) { tname: "bind to device and then without device", actions: []portReserveTestAction{ {port: 24, ip: fakeIPAddress, device: 123, want: nil}, - {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 0, want: &tcpip.ErrPortInUse{}}, }, }, { tname: "bind without device", actions: []portReserveTestAction{ {port: 24, ip: fakeIPAddress, want: nil}, - {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 123, want: &tcpip.ErrPortInUse{}}, + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: &tcpip.ErrPortInUse{}}, + {port: 24, ip: fakeIPAddress, want: &tcpip.ErrPortInUse{}}, + {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: &tcpip.ErrPortInUse{}}, }, }, { tname: "bind with device", actions: []portReserveTestAction{ {port: 24, ip: fakeIPAddress, device: 123, want: nil}, - {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 123, want: &tcpip.ErrPortInUse{}}, + {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: &tcpip.ErrPortInUse{}}, + {port: 24, ip: fakeIPAddress, device: 0, want: &tcpip.ErrPortInUse{}}, + {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, 
want: &tcpip.ErrPortInUse{}}, {port: 24, ip: fakeIPAddress, device: 456, flags: Flags{LoadBalanced: true}, want: nil}, {port: 24, ip: fakeIPAddress, device: 789, want: nil}, - {port: 24, ip: fakeIPAddress, want: tcpip.ErrPortInUse}, - {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, want: &tcpip.ErrPortInUse{}}, + {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: &tcpip.ErrPortInUse{}}, }, }, { tname: "bind with reuseport", actions: []portReserveTestAction{ {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 123, want: &tcpip.ErrPortInUse{}}, {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 0, want: &tcpip.ErrPortInUse{}}, {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: nil}, }, }, { tname: "binding with reuseport and device", actions: []portReserveTestAction{ {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 123, want: &tcpip.ErrPortInUse{}}, {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 0, want: &tcpip.ErrPortInUse{}}, {port: 24, ip: fakeIPAddress, device: 456, flags: Flags{LoadBalanced: true}, want: nil}, {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: nil}, {port: 24, ip: fakeIPAddress, device: 789, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 999, want: tcpip.ErrPortInUse}, + {port: 
24, ip: fakeIPAddress, device: 999, want: &tcpip.ErrPortInUse{}}, }, }, { tname: "mixing reuseport and not reuseport by binding to device", @@ -177,14 +178,14 @@ func TestPortReservation(t *testing.T) { actions: []portReserveTestAction{ {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil}, {port: 24, ip: fakeIPAddress, device: 456, want: nil}, - {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: &tcpip.ErrPortInUse{}}, }, }, { tname: "bind and release", actions: []portReserveTestAction{ {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil}, {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 345, flags: Flags{}, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 345, flags: Flags{}, want: &tcpip.ErrPortInUse{}}, {port: 24, ip: fakeIPAddress, device: 789, flags: Flags{LoadBalanced: true}, want: nil}, // Release the bind to device 0 and try again. 
@@ -195,7 +196,7 @@ func TestPortReservation(t *testing.T) { tname: "bind twice with reuseport once", actions: []portReserveTestAction{ {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: &tcpip.ErrPortInUse{}}, }, }, { tname: "release an unreserved device", @@ -213,16 +214,16 @@ func TestPortReservation(t *testing.T) { tname: "bind with reuseaddr", actions: []portReserveTestAction{ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 123, want: &tcpip.ErrPortInUse{}}, {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{MostRecent: true}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 0, want: &tcpip.ErrPortInUse{}}, {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{MostRecent: true}, want: nil}, }, }, { tname: "bind twice with reuseaddr once", actions: []portReserveTestAction{ {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil}, - {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{MostRecent: true}, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{MostRecent: true}, want: &tcpip.ErrPortInUse{}}, }, }, { tname: "bind with reuseaddr and reuseport", @@ -236,14 +237,14 @@ func TestPortReservation(t *testing.T) { actions: []portReserveTestAction{ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: &tcpip.ErrPortInUse{}}, }, 
}, { tname: "bind with reuseaddr and reuseport, and then reuseport", actions: []portReserveTestAction{ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: &tcpip.ErrPortInUse{}}, }, }, { tname: "bind with reuseaddr and reuseport twice, and then reuseaddr", @@ -264,14 +265,14 @@ func TestPortReservation(t *testing.T) { actions: []portReserveTestAction{ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: nil}, {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: &tcpip.ErrPortInUse{}}, }, }, { tname: "bind with reuseport, and then reuseaddr and reuseport", actions: []portReserveTestAction{ {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil}, {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: &tcpip.ErrPortInUse{}}, }, }, { tname: "bind tuple with reuseaddr, and then wildcard with reuseaddr", @@ -283,7 +284,7 @@ func TestPortReservation(t *testing.T) { tname: "bind tuple with reuseaddr, and then wildcard", actions: []portReserveTestAction{ {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: nil}, - {port: 24, ip: fakeIPAddress, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, want: &tcpip.ErrPortInUse{}}, }, }, { tname: "bind wildcard with reuseaddr, and then tuple with 
reuseaddr", @@ -295,7 +296,7 @@ func TestPortReservation(t *testing.T) { tname: "bind tuple with reuseaddr, and then wildcard", actions: []portReserveTestAction{ {port: 24, ip: fakeIPAddress, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: &tcpip.ErrPortInUse{}}, }, }, { tname: "bind two tuples with reuseaddr", @@ -313,7 +314,7 @@ func TestPortReservation(t *testing.T) { tname: "bind wildcard, and then tuple with reuseaddr", actions: []portReserveTestAction{ {port: 24, ip: fakeIPAddress, dest: tcpip.FullAddress{}, want: nil}, - {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: &tcpip.ErrPortInUse{}}, }, }, { tname: "bind wildcard twice with reuseaddr", @@ -333,8 +334,8 @@ func TestPortReservation(t *testing.T) { continue } gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port, test.flags, test.device, test.dest, nil /* testPort */) - if err != test.want { - t.Fatalf("ReservePort(.., .., %s, %d, %+v, %d, %v) = %v, want %v", test.ip, test.port, test.flags, test.device, test.dest, err, test.want) + if diff := cmp.Diff(test.want, err); diff != "" { + t.Fatalf("unexpected error from ReservePort(.., .., %s, %d, %+v, %d, %v), (-want, +got):\n%s", test.ip, test.port, test.flags, test.device, test.dest, diff) } if test.port == 0 && (gotPort == 0 || gotPort < FirstEphemeral) { t.Fatalf("ReservePort(.., .., .., 0, ..) 
= %d, want port number >= %d to be picked", gotPort, FirstEphemeral) @@ -345,30 +346,29 @@ func TestPortReservation(t *testing.T) { } func TestPickEphemeralPort(t *testing.T) { - customErr := &tcpip.Error{} for _, test := range []struct { name string - f func(port uint16) (bool, *tcpip.Error) - wantErr *tcpip.Error + f func(port uint16) (bool, tcpip.Error) + wantErr tcpip.Error wantPort uint16 }{ { name: "no-port-available", - f: func(port uint16) (bool, *tcpip.Error) { + f: func(port uint16) (bool, tcpip.Error) { return false, nil }, - wantErr: tcpip.ErrNoPortAvailable, + wantErr: &tcpip.ErrNoPortAvailable{}, }, { name: "port-tester-error", - f: func(port uint16) (bool, *tcpip.Error) { - return false, customErr + f: func(port uint16) (bool, tcpip.Error) { + return false, &tcpip.ErrBadBuffer{} }, - wantErr: customErr, + wantErr: &tcpip.ErrBadBuffer{}, }, { name: "only-port-16042-available", - f: func(port uint16) (bool, *tcpip.Error) { + f: func(port uint16) (bool, tcpip.Error) { if port == FirstEphemeral+42 { return true, nil } @@ -378,49 +378,52 @@ func TestPickEphemeralPort(t *testing.T) { }, { name: "only-port-under-16000-available", - f: func(port uint16) (bool, *tcpip.Error) { + f: func(port uint16) (bool, tcpip.Error) { if port < FirstEphemeral { return true, nil } return false, nil }, - wantErr: tcpip.ErrNoPortAvailable, + wantErr: &tcpip.ErrNoPortAvailable{}, }, } { t.Run(test.name, func(t *testing.T) { pm := NewPortManager() - if port, err := pm.PickEphemeralPort(test.f); port != test.wantPort || err != test.wantErr { - t.Errorf("PickEphemeralPort(..) = (port %d, err %v); want (port %d, err %v)", port, err, test.wantPort, test.wantErr) + port, err := pm.PickEphemeralPort(test.f) + if diff := cmp.Diff(test.wantErr, err); diff != "" { + t.Fatalf("unexpected error from PickEphemeralPort(..), (-want, +got):\n%s", diff) + } + if port != test.wantPort { + t.Errorf("got PickEphemeralPort(..) 
= (%d, nil); want (%d, nil)", port, test.wantPort) } }) } } func TestPickEphemeralPortStable(t *testing.T) { - customErr := &tcpip.Error{} for _, test := range []struct { name string - f func(port uint16) (bool, *tcpip.Error) - wantErr *tcpip.Error + f func(port uint16) (bool, tcpip.Error) + wantErr tcpip.Error wantPort uint16 }{ { name: "no-port-available", - f: func(port uint16) (bool, *tcpip.Error) { + f: func(port uint16) (bool, tcpip.Error) { return false, nil }, - wantErr: tcpip.ErrNoPortAvailable, + wantErr: &tcpip.ErrNoPortAvailable{}, }, { name: "port-tester-error", - f: func(port uint16) (bool, *tcpip.Error) { - return false, customErr + f: func(port uint16) (bool, tcpip.Error) { + return false, &tcpip.ErrBadBuffer{} }, - wantErr: customErr, + wantErr: &tcpip.ErrBadBuffer{}, }, { name: "only-port-16042-available", - f: func(port uint16) (bool, *tcpip.Error) { + f: func(port uint16) (bool, tcpip.Error) { if port == FirstEphemeral+42 { return true, nil } @@ -430,20 +433,24 @@ func TestPickEphemeralPortStable(t *testing.T) { }, { name: "only-port-under-16000-available", - f: func(port uint16) (bool, *tcpip.Error) { + f: func(port uint16) (bool, tcpip.Error) { if port < FirstEphemeral { return true, nil } return false, nil }, - wantErr: tcpip.ErrNoPortAvailable, + wantErr: &tcpip.ErrNoPortAvailable{}, }, } { t.Run(test.name, func(t *testing.T) { pm := NewPortManager() portOffset := uint32(rand.Int31n(int32(numEphemeralPorts))) - if port, err := pm.PickEphemeralPortStable(portOffset, test.f); port != test.wantPort || err != test.wantErr { - t.Errorf("PickEphemeralPort(..) = (port %d, err %v); want (port %d, err %v)", port, err, test.wantPort, test.wantErr) + port, err := pm.PickEphemeralPortStable(portOffset, test.f) + if diff := cmp.Diff(test.wantErr, err); diff != "" { + t.Fatalf("unexpected error from PickEphemeralPort(..), (-want, +got):\n%s", diff) + } + if port != test.wantPort { + t.Errorf("got PickEphemeralPort(..) 
= (%d, nil); want (%d, nil)", port, test.wantPort) } }) } diff --git a/pkg/tcpip/sample/tun_tcp_connect/main.go b/pkg/tcpip/sample/tun_tcp_connect/main.go index 3d9954c84..856ea998d 100644 --- a/pkg/tcpip/sample/tun_tcp_connect/main.go +++ b/pkg/tcpip/sample/tun_tcp_connect/main.go @@ -175,7 +175,7 @@ func main() { waitEntry, notifyCh := waiter.NewChannelEntry(nil) wq.EventRegister(&waitEntry, waiter.EventOut) terr := ep.Connect(remote) - if terr == tcpip.ErrConnectStarted { + if _, ok := terr.(*tcpip.ErrConnectStarted); ok { fmt.Println("Connect is pending...") <-notifyCh terr = ep.LastError() @@ -198,11 +198,11 @@ func main() { for { _, err := ep.Read(os.Stdout, tcpip.ReadOptions{}) if err != nil { - if err == tcpip.ErrClosedForReceive { + if _, ok := err.(*tcpip.ErrClosedForReceive); ok { break } - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { <-notifyCh continue } diff --git a/pkg/tcpip/sample/tun_tcp_echo/main.go b/pkg/tcpip/sample/tun_tcp_echo/main.go index ae9cf44e7..9b23df3a9 100644 --- a/pkg/tcpip/sample/tun_tcp_echo/main.go +++ b/pkg/tcpip/sample/tun_tcp_echo/main.go @@ -51,7 +51,7 @@ type endpointWriter struct { } type tcpipError struct { - inner *tcpip.Error + inner tcpip.Error } func (e *tcpipError) Error() string { @@ -89,7 +89,7 @@ func echo(wq *waiter.Queue, ep tcpip.Endpoint) { for { _, err := ep.Read(&w, tcpip.ReadOptions{}) if err != nil { - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { <-notifyCh continue } @@ -217,7 +217,7 @@ func main() { for { n, wq, err := ep.Accept(nil) if err != nil { - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { <-notifyCh continue } diff --git a/pkg/tcpip/socketops.go b/pkg/tcpip/socketops.go index 7eabbc599..1e00144a5 100644 --- a/pkg/tcpip/socketops.go +++ b/pkg/tcpip/socketops.go @@ -46,16 +46,16 @@ type SocketOptionsHandler interface { OnCorkOptionSet(v bool) // LastError is invoked when SO_ERROR is read for an 
endpoint. - LastError() *Error + LastError() Error // UpdateLastError updates the endpoint specific last error field. - UpdateLastError(err *Error) + UpdateLastError(err Error) // HasNIC is invoked to check if the NIC is valid for SO_BINDTODEVICE. HasNIC(v int32) bool // GetSendBufferSize is invoked to get the SO_SNDBUFSIZE. - GetSendBufferSize() (int64, *Error) + GetSendBufferSize() (int64, Error) // IsUnixSocket is invoked to check if the socket is of unix domain. IsUnixSocket() bool @@ -83,12 +83,12 @@ func (*DefaultSocketOptionsHandler) OnDelayOptionSet(bool) {} func (*DefaultSocketOptionsHandler) OnCorkOptionSet(bool) {} // LastError implements SocketOptionsHandler.LastError. -func (*DefaultSocketOptionsHandler) LastError() *Error { +func (*DefaultSocketOptionsHandler) LastError() Error { return nil } // UpdateLastError implements SocketOptionsHandler.UpdateLastError. -func (*DefaultSocketOptionsHandler) UpdateLastError(*Error) {} +func (*DefaultSocketOptionsHandler) UpdateLastError(Error) {} // HasNIC implements SocketOptionsHandler.HasNIC. func (*DefaultSocketOptionsHandler) HasNIC(int32) bool { @@ -96,7 +96,7 @@ func (*DefaultSocketOptionsHandler) HasNIC(int32) bool { } // GetSendBufferSize implements SocketOptionsHandler.GetSendBufferSize. -func (*DefaultSocketOptionsHandler) GetSendBufferSize() (int64, *Error) { +func (*DefaultSocketOptionsHandler) GetSendBufferSize() (int64, Error) { return 0, nil } @@ -109,11 +109,11 @@ func (*DefaultSocketOptionsHandler) IsUnixSocket() bool { // implemented by the stack. type StackHandler interface { // Option allows retrieving stack wide options. - Option(option interface{}) *Error + Option(option interface{}) Error // TransportProtocolOption allows retrieving individual protocol level // option values. 
- TransportProtocolOption(proto TransportProtocolNumber, option GettableTransportProtocolOption) *Error + TransportProtocolOption(proto TransportProtocolNumber, option GettableTransportProtocolOption) Error } // SocketOptions contains all the variables which store values for SOL_SOCKET, @@ -238,7 +238,7 @@ func storeAtomicBool(addr *uint32, v bool) { } // SetLastError sets the last error for a socket. -func (so *SocketOptions) SetLastError(err *Error) { +func (so *SocketOptions) SetLastError(err Error) { so.handler.UpdateLastError(err) } @@ -423,7 +423,7 @@ func (so *SocketOptions) SetRecvError(v bool) { } // GetLastError gets value for SO_ERROR option. -func (so *SocketOptions) GetLastError() *Error { +func (so *SocketOptions) GetLastError() Error { return so.handler.LastError() } @@ -473,6 +473,48 @@ func (origin SockErrOrigin) IsICMPErr() bool { return origin == SockExtErrorOriginICMP || origin == SockExtErrorOriginICMP6 } +// SockErrorCause is the cause of a socket error. +type SockErrorCause interface { + // Origin is the source of the error. + Origin() SockErrOrigin + + // Type is the origin specific type of error. + Type() uint8 + + // Code is the origin and type specific error code. + Code() uint8 + + // Info is any extra information about the error. + Info() uint32 +} + +// LocalSockError is a socket error that originated from the local host. +// +// +stateify savable +type LocalSockError struct { + info uint32 +} + +// Origin implements SockErrorCause. +func (*LocalSockError) Origin() SockErrOrigin { + return SockExtErrorOriginLocal +} + +// Type implements SockErrorCause. +func (*LocalSockError) Type() uint8 { + return 0 +} + +// Code implements SockErrorCause. +func (*LocalSockError) Code() uint8 { + return 0 +} + +// Info implements SockErrorCause. +func (l *LocalSockError) Info() uint32 { + return l.info +} + // SockError represents a queue entry in the per-socket error queue. 
// // +stateify savable @@ -480,15 +522,9 @@ type SockError struct { sockErrorEntry // Err is the error caused by the errant packet. - Err *Error - // ErrOrigin indicates the error origin. - ErrOrigin SockErrOrigin - // ErrType is the type in the ICMP header. - ErrType uint8 - // ErrCode is the code in the ICMP header. - ErrCode uint8 - // ErrInfo is additional info about the error. - ErrInfo uint32 + Err Error + // Cause is the detailed cause of the error. + Cause SockErrorCause // Payload is the errant packet's payload. Payload []byte @@ -538,14 +574,13 @@ func (so *SocketOptions) QueueErr(err *SockError) { } // QueueLocalErr queues a local error onto the local queue. -func (so *SocketOptions) QueueLocalErr(err *Error, net NetworkProtocolNumber, info uint32, dst FullAddress, payload []byte) { +func (so *SocketOptions) QueueLocalErr(err Error, net NetworkProtocolNumber, info uint32, dst FullAddress, payload []byte) { so.QueueErr(&SockError{ - Err: err, - ErrOrigin: SockExtErrorOriginLocal, - ErrInfo: info, - Payload: payload, - Dst: dst, - NetProto: net, + Err: err, + Cause: &LocalSockError{info: info}, + Payload: payload, + Dst: dst, + NetProto: net, }) } @@ -555,9 +590,9 @@ func (so *SocketOptions) GetBindToDevice() int32 { } // SetBindToDevice sets value for SO_BINDTODEVICE option. -func (so *SocketOptions) SetBindToDevice(bindToDevice int32) *Error { +func (so *SocketOptions) SetBindToDevice(bindToDevice int32) Error { if !so.handler.HasNIC(bindToDevice) { - return ErrUnknownDevice + return &ErrUnknownDevice{} } atomic.StoreInt32(&so.bindToDevice, bindToDevice) @@ -565,7 +600,7 @@ func (so *SocketOptions) SetBindToDevice(bindToDevice int32) *Error { } // GetSendBufferSize gets value for SO_SNDBUF option. 
-func (so *SocketOptions) GetSendBufferSize() (int64, *Error) { +func (so *SocketOptions) GetSendBufferSize() (int64, Error) { if so.handler.IsUnixSocket() { return so.handler.GetSendBufferSize() } diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index bb30556cf..ee23c9b98 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -72,6 +72,7 @@ go_library( "nud.go", "packet_buffer.go", "packet_buffer_list.go", + "packet_buffer_unsafe.go", "pending_packets.go", "rand.go", "registration.go", diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go index cd423bf71..e5590ecc0 100644 --- a/pkg/tcpip/stack/addressable_endpoint_state.go +++ b/pkg/tcpip/stack/addressable_endpoint_state.go @@ -117,7 +117,7 @@ func (a *AddressableEndpointState) releaseAddressStateLocked(addrState *addressS } // AddAndAcquirePermanentAddress implements AddressableEndpoint. -func (a *AddressableEndpointState) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated bool) (AddressEndpoint, *tcpip.Error) { +func (a *AddressableEndpointState) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated bool) (AddressEndpoint, tcpip.Error) { a.mu.Lock() defer a.mu.Unlock() ep, err := a.addAndAcquireAddressLocked(addr, peb, configType, deprecated, true /* permanent */) @@ -143,10 +143,10 @@ func (a *AddressableEndpointState) AddAndAcquirePermanentAddress(addr tcpip.Addr // AddAndAcquireTemporaryAddress adds a temporary address. // -// Returns tcpip.ErrDuplicateAddress if the address exists. +// Returns *tcpip.ErrDuplicateAddress if the address exists. // // The temporary address's endpoint is acquired and returned. 
-func (a *AddressableEndpointState) AddAndAcquireTemporaryAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior) (AddressEndpoint, *tcpip.Error) { +func (a *AddressableEndpointState) AddAndAcquireTemporaryAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior) (AddressEndpoint, tcpip.Error) { a.mu.Lock() defer a.mu.Unlock() ep, err := a.addAndAcquireAddressLocked(addr, peb, AddressConfigStatic, false /* deprecated */, false /* permanent */) @@ -176,11 +176,11 @@ func (a *AddressableEndpointState) AddAndAcquireTemporaryAddress(addr tcpip.Addr // If the addressable endpoint already has the address in a non-permanent state, // and addAndAcquireAddressLocked is adding a permanent address, that address is // promoted in place and its properties set to the properties provided. If the -// address already exists in any other state, then tcpip.ErrDuplicateAddress is +// address already exists in any other state, then *tcpip.ErrDuplicateAddress is // returned, regardless the kind of address that is being added. // // Precondition: a.mu must be write locked. -func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated, permanent bool) (*addressState, *tcpip.Error) { +func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated, permanent bool) (*addressState, tcpip.Error) { // attemptAddToPrimary is false when the address is already in the primary // address list. attemptAddToPrimary := true @@ -190,7 +190,7 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address // We are adding a non-permanent address but the address exists. No need // to go any further since we can only promote existing temporary/expired // addresses to permanent. 
- return nil, tcpip.ErrDuplicateAddress + return nil, &tcpip.ErrDuplicateAddress{} } addrState.mu.Lock() @@ -198,7 +198,7 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address addrState.mu.Unlock() // We are adding a permanent address but a permanent address already // exists. - return nil, tcpip.ErrDuplicateAddress + return nil, &tcpip.ErrDuplicateAddress{} } if addrState.mu.refs == 0 { @@ -293,7 +293,7 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address } // RemovePermanentAddress implements AddressableEndpoint. -func (a *AddressableEndpointState) RemovePermanentAddress(addr tcpip.Address) *tcpip.Error { +func (a *AddressableEndpointState) RemovePermanentAddress(addr tcpip.Address) tcpip.Error { a.mu.Lock() defer a.mu.Unlock() return a.removePermanentAddressLocked(addr) @@ -303,10 +303,10 @@ func (a *AddressableEndpointState) RemovePermanentAddress(addr tcpip.Address) *t // requirements. // // Precondition: a.mu must be write locked. -func (a *AddressableEndpointState) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error { +func (a *AddressableEndpointState) removePermanentAddressLocked(addr tcpip.Address) tcpip.Error { addrState, ok := a.mu.endpoints[addr] if !ok { - return tcpip.ErrBadLocalAddress + return &tcpip.ErrBadLocalAddress{} } return a.removePermanentEndpointLocked(addrState) @@ -314,10 +314,10 @@ func (a *AddressableEndpointState) removePermanentAddressLocked(addr tcpip.Addre // RemovePermanentEndpoint removes the passed endpoint if it is associated with // a and permanent. 
-func (a *AddressableEndpointState) RemovePermanentEndpoint(ep AddressEndpoint) *tcpip.Error { +func (a *AddressableEndpointState) RemovePermanentEndpoint(ep AddressEndpoint) tcpip.Error { addrState, ok := ep.(*addressState) if !ok || addrState.addressableEndpointState != a { - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } a.mu.Lock() @@ -329,9 +329,9 @@ func (a *AddressableEndpointState) RemovePermanentEndpoint(ep AddressEndpoint) * // requirements. // // Precondition: a.mu must be write locked. -func (a *AddressableEndpointState) removePermanentEndpointLocked(addrState *addressState) *tcpip.Error { +func (a *AddressableEndpointState) removePermanentEndpointLocked(addrState *addressState) tcpip.Error { if !addrState.GetKind().IsPermanent() { - return tcpip.ErrBadLocalAddress + return &tcpip.ErrBadLocalAddress{} } addrState.SetKind(PermanentExpired) @@ -574,9 +574,11 @@ func (a *AddressableEndpointState) Cleanup() { defer a.mu.Unlock() for _, ep := range a.mu.endpoints { - // removePermanentEndpointLocked returns tcpip.ErrBadLocalAddress if ep is + // removePermanentEndpointLocked returns *tcpip.ErrBadLocalAddress if ep is // not a permanent address. - if err := a.removePermanentEndpointLocked(ep); err != nil && err != tcpip.ErrBadLocalAddress { + switch err := a.removePermanentEndpointLocked(ep); err.(type) { + case nil, *tcpip.ErrBadLocalAddress: + default: panic(fmt.Sprintf("unexpected error from removePermanentEndpointLocked(%s): %s", ep.addr, err)) } } diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index 5e649cca6..54617f2e6 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -198,15 +198,15 @@ type bucket struct { // TCP header. // // Preconditions: pkt.NetworkHeader() is valid. 
-func packetToTupleID(pkt *PacketBuffer) (tupleID, *tcpip.Error) { +func packetToTupleID(pkt *PacketBuffer) (tupleID, tcpip.Error) { netHeader := pkt.Network() if netHeader.TransportProtocol() != header.TCPProtocolNumber { - return tupleID{}, tcpip.ErrUnknownProtocol + return tupleID{}, &tcpip.ErrUnknownProtocol{} } tcpHeader := header.TCP(pkt.TransportHeader().View()) if len(tcpHeader) < header.TCPMinimumSize { - return tupleID{}, tcpip.ErrUnknownProtocol + return tupleID{}, &tcpip.ErrUnknownProtocol{} } return tupleID{ @@ -617,7 +617,7 @@ func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bucket int, now time.Time) bo return true } -func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, *tcpip.Error) { +func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, tcpip.Error) { // Lookup the connection. The reply's original destination // describes the original address. tid := tupleID{ @@ -631,10 +631,10 @@ func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.Networ conn, _ := ct.connForTID(tid) if conn == nil { // Not a tracked connection. - return "", 0, tcpip.ErrNotConnected + return "", 0, &tcpip.ErrNotConnected{} } else if conn.manip == manipNone { // Unmanipulated connection. - return "", 0, tcpip.ErrInvalidOptionValue + return "", 0, &tcpip.ErrInvalidOptionValue{} } return conn.original.dstAddr, conn.original.dstPort, nil diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go index d29c9a49b..c24f56ece 100644 --- a/pkg/tcpip/stack/forwarding_test.go +++ b/pkg/tcpip/stack/forwarding_test.go @@ -41,6 +41,7 @@ const ( protocolNumberOffset = 2 ) +var _ LinkAddressResolver = (*fwdTestNetworkEndpoint)(nil) var _ NetworkEndpoint = (*fwdTestNetworkEndpoint)(nil) // fwdTestNetworkEndpoint is a network-layer protocol endpoint. 
@@ -55,7 +56,7 @@ type fwdTestNetworkEndpoint struct { dispatcher TransportDispatcher } -func (*fwdTestNetworkEndpoint) Enable() *tcpip.Error { +func (*fwdTestNetworkEndpoint) Enable() tcpip.Error { return nil } @@ -112,7 +113,7 @@ func (f *fwdTestNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNu return f.proto.Number() } -func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error { +func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) tcpip.Error { // Add the protocol's header to the packet and send it to the link // endpoint. b := pkt.NetworkHeader().Push(fwdTestNetHeaderLen) @@ -124,14 +125,14 @@ func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkH } // WritePackets implements LinkEndpoint.WritePackets. -func (*fwdTestNetworkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, *tcpip.Error) { +func (*fwdTestNetworkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, tcpip.Error) { panic("not implemented") } -func (f *fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) *tcpip.Error { +func (f *fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) tcpip.Error { // The network header should not already be populated. if _, ok := pkt.NetworkHeader().Consume(fwdTestNetHeaderLen); !ok { - return tcpip.ErrMalformedHeader + return &tcpip.ErrMalformedHeader{} } return f.nic.WritePacket(r, nil /* gso */, fwdTestNetNumber, pkt) @@ -153,7 +154,6 @@ type fwdTestNetworkEndpointStats struct{} // IsNetworkEndpointStats implements stack.NetworkEndpointStats. 
func (*fwdTestNetworkEndpointStats) IsNetworkEndpointStats() {} -var _ LinkAddressResolver = (*fwdTestNetworkProtocol)(nil) var _ NetworkProtocol = (*fwdTestNetworkProtocol)(nil) // fwdTestNetworkProtocol is a network-layer protocol that implements Address @@ -161,10 +161,9 @@ var _ NetworkProtocol = (*fwdTestNetworkProtocol)(nil) type fwdTestNetworkProtocol struct { stack *Stack - addrCache *linkAddrCache - neigh *neighborCache + neighborTable neighborTable addrResolveDelay time.Duration - onLinkAddressResolved func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) + onLinkAddressResolved func(neighborTable, tcpip.Address, tcpip.LinkAddress) onResolveStaticAddress func(tcpip.Address) (tcpip.LinkAddress, bool) mu struct { @@ -197,7 +196,7 @@ func (*fwdTestNetworkProtocol) Parse(pkt *PacketBuffer) (tcpip.TransportProtocol return tcpip.TransportProtocolNumber(netHeader[protocolNumberOffset]), true, true } -func (f *fwdTestNetworkProtocol) NewEndpoint(nic NetworkInterface, _ LinkAddressCache, _ NUDHandler, dispatcher TransportDispatcher) NetworkEndpoint { +func (f *fwdTestNetworkProtocol) NewEndpoint(nic NetworkInterface, dispatcher TransportDispatcher) NetworkEndpoint { e := &fwdTestNetworkEndpoint{ nic: nic, proto: f, @@ -207,35 +206,35 @@ func (f *fwdTestNetworkProtocol) NewEndpoint(nic NetworkInterface, _ LinkAddress return e } -func (*fwdTestNetworkProtocol) SetOption(tcpip.SettableNetworkProtocolOption) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption +func (*fwdTestNetworkProtocol) SetOption(tcpip.SettableNetworkProtocolOption) tcpip.Error { + return &tcpip.ErrUnknownProtocolOption{} } -func (*fwdTestNetworkProtocol) Option(tcpip.GettableNetworkProtocolOption) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption +func (*fwdTestNetworkProtocol) Option(tcpip.GettableNetworkProtocolOption) tcpip.Error { + return &tcpip.ErrUnknownProtocolOption{} } func (*fwdTestNetworkProtocol) Close() {} func (*fwdTestNetworkProtocol) 
Wait() {} -func (f *fwdTestNetworkProtocol) LinkAddressRequest(addr, _ tcpip.Address, remoteLinkAddr tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error { - if f.onLinkAddressResolved != nil { - time.AfterFunc(f.addrResolveDelay, func() { - f.onLinkAddressResolved(f.addrCache, f.neigh, addr, remoteLinkAddr) +func (f *fwdTestNetworkEndpoint) LinkAddressRequest(addr, _ tcpip.Address, remoteLinkAddr tcpip.LinkAddress) tcpip.Error { + if fn := f.proto.onLinkAddressResolved; fn != nil { + time.AfterFunc(f.proto.addrResolveDelay, func() { + fn(f.proto.neighborTable, addr, remoteLinkAddr) }) } return nil } -func (f *fwdTestNetworkProtocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { - if f.onResolveStaticAddress != nil { - return f.onResolveStaticAddress(addr) +func (f *fwdTestNetworkEndpoint) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { + if fn := f.proto.onResolveStaticAddress; fn != nil { + return fn(addr) } return "", false } -func (*fwdTestNetworkProtocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { +func (*fwdTestNetworkEndpoint) LinkAddressProtocol() tcpip.NetworkProtocolNumber { return fwdTestNetNumber } @@ -319,7 +318,7 @@ func (e *fwdTestLinkEndpoint) LinkAddress() tcpip.LinkAddress { return e.linkAddr } -func (e fwdTestLinkEndpoint) WritePacket(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error { +func (e fwdTestLinkEndpoint) WritePacket(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error { p := fwdTestPacketInfo{ RemoteLinkAddress: r.RemoteLinkAddress, LocalLinkAddress: r.LocalLinkAddress, @@ -335,7 +334,7 @@ func (e fwdTestLinkEndpoint) WritePacket(r RouteInfo, gso *GSO, protocol tcpip.N } // WritePackets stores outbound packets into the channel. 
-func (e *fwdTestLinkEndpoint) WritePackets(r RouteInfo, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *fwdTestLinkEndpoint) WritePackets(r RouteInfo, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { n := 0 for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { e.WritePacket(r, gso, protocol, pkt) @@ -401,11 +400,9 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol, useNeighborC if !ok { t.Fatal("NIC 2 does not exist") } - if useNeighborCache { - // Control the neighbor cache for NIC 2. - proto.neigh = nic.neigh - } else { - proto.addrCache = nic.linkAddrCache + + if l, ok := nic.linkAddrResolvers[fwdTestNetNumber]; ok { + proto.neighborTable = l.neighborTable } // Route all packets to NIC 2. @@ -482,43 +479,35 @@ func TestForwardingWithFakeResolver(t *testing.T) { tests := []struct { name string useNeighborCache bool - proto *fwdTestNetworkProtocol }{ { name: "linkAddrCache", useNeighborCache: false, - proto: &fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) { - // Any address will be resolved to the link address "c". 
- cache.AddLinkAddress(addr, "c") - }, - }, }, { name: "neighborCache", useNeighborCache: true, - proto: &fwdTestNetworkProtocol{ + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + proto := fwdTestNetworkProtocol{ addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) { + onLinkAddressResolved: func(neigh neighborTable, addr tcpip.Address, linkAddr tcpip.LinkAddress) { t.Helper() - if len(remoteLinkAddr) != 0 { - t.Fatalf("got remoteLinkAddr=%q, want unspecified", remoteLinkAddr) + if len(linkAddr) != 0 { + t.Fatalf("got linkAddr=%q, want unspecified", linkAddr) } // Any address will be resolved to the link address "c". - neigh.HandleConfirmation(addr, "c", ReachabilityConfirmationFlags{ + neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{ Solicited: true, Override: false, IsRouter: false, }) }, - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ep1, ep2 := fwdTestNetFactory(t, test.proto, test.useNeighborCache) + } + ep1, ep2 := fwdTestNetFactory(t, &proto, test.useNeighborCache) // Inject an inbound packet to address 3 on NIC 1, and see if it is // forwarded to NIC 2. @@ -573,7 +562,7 @@ func TestForwardingWithNoResolver(t *testing.T) { func TestForwardingResolutionFailsForQueuedPackets(t *testing.T) { proto := &fwdTestNetworkProtocol{ addrResolveDelay: 50 * time.Millisecond, - onLinkAddressResolved: func(*linkAddrCache, *neighborCache, tcpip.Address, tcpip.LinkAddress) { + onLinkAddressResolved: func(neighborTable, tcpip.Address, tcpip.LinkAddress) { // Don't resolve the link address. 
}, } @@ -606,49 +595,38 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { tests := []struct { name string useNeighborCache bool - proto *fwdTestNetworkProtocol }{ { name: "linkAddrCache", useNeighborCache: false, - proto: &fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) { - // Only packets to address 3 will be resolved to the - // link address "c". - if addr == "\x03" { - cache.AddLinkAddress(addr, "c") - } - }, - }, }, { name: "neighborCache", useNeighborCache: true, - proto: &fwdTestNetworkProtocol{ + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + proto := fwdTestNetworkProtocol{ addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) { + onLinkAddressResolved: func(neigh neighborTable, addr tcpip.Address, linkAddr tcpip.LinkAddress) { t.Helper() - if len(remoteLinkAddr) != 0 { - t.Fatalf("got remoteLinkAddr=%q, want unspecified", remoteLinkAddr) + if len(linkAddr) != 0 { + t.Fatalf("got linkAddr=%q, want unspecified", linkAddr) } // Only packets to address 3 will be resolved to the // link address "c". if addr == "\x03" { - neigh.HandleConfirmation(addr, "c", ReachabilityConfirmationFlags{ + neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{ Solicited: true, Override: false, IsRouter: false, }) } }, - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ep1, ep2 := fwdTestNetFactory(t, test.proto, test.useNeighborCache) + } + ep1, ep2 := fwdTestNetFactory(t, &proto, test.useNeighborCache) // Inject an inbound packet to address 4 on NIC 1. This packet should // not be forwarded. 
@@ -693,43 +671,35 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { tests := []struct { name string useNeighborCache bool - proto *fwdTestNetworkProtocol }{ { name: "linkAddrCache", useNeighborCache: false, - proto: &fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) { - // Any packets will be resolved to the link address "c". - cache.AddLinkAddress(addr, "c") - }, - }, }, { name: "neighborCache", useNeighborCache: true, - proto: &fwdTestNetworkProtocol{ + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + proto := fwdTestNetworkProtocol{ addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) { + onLinkAddressResolved: func(neigh neighborTable, addr tcpip.Address, linkAddr tcpip.LinkAddress) { t.Helper() - if len(remoteLinkAddr) != 0 { - t.Fatalf("got remoteLinkAddr=%q, want unspecified", remoteLinkAddr) + if len(linkAddr) != 0 { + t.Fatalf("got linkAddr=%q, want unspecified", linkAddr) } // Any packets will be resolved to the link address "c". - neigh.HandleConfirmation(addr, "c", ReachabilityConfirmationFlags{ + neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{ Solicited: true, Override: false, IsRouter: false, }) }, - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ep1, ep2 := fwdTestNetFactory(t, test.proto, test.useNeighborCache) + } + ep1, ep2 := fwdTestNetFactory(t, &proto, test.useNeighborCache) // Inject two inbound packets to address 3 on NIC 1. 
for i := 0; i < 2; i++ { @@ -769,43 +739,35 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) { tests := []struct { name string useNeighborCache bool - proto *fwdTestNetworkProtocol }{ { name: "linkAddrCache", useNeighborCache: false, - proto: &fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) { - // Any packets will be resolved to the link address "c". - cache.AddLinkAddress(addr, "c") - }, - }, }, { name: "neighborCache", useNeighborCache: true, - proto: &fwdTestNetworkProtocol{ + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + proto := fwdTestNetworkProtocol{ addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) { + onLinkAddressResolved: func(neigh neighborTable, addr tcpip.Address, linkAddr tcpip.LinkAddress) { t.Helper() - if len(remoteLinkAddr) != 0 { - t.Fatalf("got remoteLinkAddr=%q, want unspecified", remoteLinkAddr) + if len(linkAddr) != 0 { + t.Fatalf("got linkAddr=%q, want unspecified", linkAddr) } // Any packets will be resolved to the link address "c". - neigh.HandleConfirmation(addr, "c", ReachabilityConfirmationFlags{ + neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{ Solicited: true, Override: false, IsRouter: false, }) }, - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ep1, ep2 := fwdTestNetFactory(t, test.proto, test.useNeighborCache) + } + ep1, ep2 := fwdTestNetFactory(t, &proto, test.useNeighborCache) for i := 0; i < maxPendingPacketsPerResolution+5; i++ { // Inject inbound 'maxPendingPacketsPerResolution + 5' packets on NIC 1. 
@@ -864,38 +826,31 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { { name: "linkAddrCache", useNeighborCache: false, - proto: &fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) { - // Any packets will be resolved to the link address "c". - cache.AddLinkAddress(addr, "c") - }, - }, }, { name: "neighborCache", useNeighborCache: true, - proto: &fwdTestNetworkProtocol{ + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + proto := fwdTestNetworkProtocol{ addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) { + onLinkAddressResolved: func(neigh neighborTable, addr tcpip.Address, linkAddr tcpip.LinkAddress) { t.Helper() - if len(remoteLinkAddr) != 0 { - t.Fatalf("got remoteLinkAddr=%q, want unspecified", remoteLinkAddr) + if len(linkAddr) != 0 { + t.Fatalf("got linkAddr=%q, want unspecified", linkAddr) } // Any packets will be resolved to the link address "c". - neigh.HandleConfirmation(addr, "c", ReachabilityConfirmationFlags{ + neigh.handleConfirmation(addr, "c", ReachabilityConfirmationFlags{ Solicited: true, Override: false, IsRouter: false, }) }, - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - ep1, ep2 := fwdTestNetFactory(t, test.proto, test.useNeighborCache) + } + ep1, ep2 := fwdTestNetFactory(t, &proto, test.useNeighborCache) for i := 0; i < maxPendingResolutions+5; i++ { // Inject inbound 'maxPendingResolutions + 5' packets on NIC 1. 
diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 04af933a6..63832c200 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -229,7 +229,7 @@ func (it *IPTables) GetTable(id TableID, ipv6 bool) Table { // ReplaceTable replaces or inserts table by name. It panics when an invalid id // is provided. -func (it *IPTables) ReplaceTable(id TableID, table Table, ipv6 bool) *tcpip.Error { +func (it *IPTables) ReplaceTable(id TableID, table Table, ipv6 bool) tcpip.Error { it.mu.Lock() defer it.mu.Unlock() // If iptables is being enabled, initialize the conntrack table and @@ -483,11 +483,11 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx // OriginalDst returns the original destination of redirected connections. It // returns an error if the connection doesn't exist or isn't redirected. -func (it *IPTables) OriginalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, *tcpip.Error) { +func (it *IPTables) OriginalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, tcpip.Error) { it.mu.RLock() defer it.mu.RUnlock() if !it.modified { - return "", 0, tcpip.ErrNotConnected + return "", 0, &tcpip.ErrNotConnected{} } return it.connections.originalDst(epID, netProto) } diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go index ba6d56a7d..5b6b58b1d 100644 --- a/pkg/tcpip/stack/linkaddrcache.go +++ b/pkg/tcpip/stack/linkaddrcache.go @@ -24,8 +24,6 @@ import ( const linkAddrCacheSize = 512 // max cache entries -var _ LinkAddressCache = (*linkAddrCache)(nil) - // linkAddrCache is a fixed-sized cache mapping IP addresses to link addresses. // // The entries are stored in a ring buffer, oldest entry replaced first. @@ -34,6 +32,8 @@ var _ LinkAddressCache = (*linkAddrCache)(nil) type linkAddrCache struct { nic *NIC + linkRes LinkAddressResolver + // ageLimit is how long a cache entry is valid for. 
ageLimit time.Duration @@ -45,7 +45,7 @@ type linkAddrCache struct { // resolved before failing. resolutionAttempts int - cache struct { + mu struct { sync.Mutex table map[tcpip.Address]*linkAddrEntry lru linkAddrEntryList @@ -83,32 +83,32 @@ type linkAddrEntry struct { cache *linkAddrCache - // TODO(gvisor.dev/issue/5150): move these fields under mu. - // mu protects the fields below. - mu sync.RWMutex + mu struct { + sync.RWMutex - addr tcpip.Address - linkAddr tcpip.LinkAddress - expiration time.Time - s entryState + addr tcpip.Address + linkAddr tcpip.LinkAddress + expiration time.Time + s entryState - // done is closed when address resolution is complete. It is nil iff s is - // incomplete and resolution is not yet in progress. - done chan struct{} + // done is closed when address resolution is complete. It is nil iff s is + // incomplete and resolution is not yet in progress. + done chan struct{} - // onResolve is called with the result of address resolution. - onResolve []func(LinkResolutionResult) + // onResolve is called with the result of address resolution. + onResolve []func(LinkResolutionResult) + } } func (e *linkAddrEntry) notifyCompletionLocked(linkAddr tcpip.LinkAddress) { res := LinkResolutionResult{LinkAddress: linkAddr, Success: len(linkAddr) != 0} - for _, callback := range e.onResolve { + for _, callback := range e.mu.onResolve { callback(res) } - e.onResolve = nil - if ch := e.done; ch != nil { + e.mu.onResolve = nil + if ch := e.mu.done; ch != nil { close(ch) - e.done = nil + e.mu.done = nil // Dequeue the pending packets in a new goroutine to not hold up the current // goroutine as writing packets may be a costly operation. 
// @@ -129,30 +129,30 @@ func (e *linkAddrEntry) notifyCompletionLocked(linkAddr tcpip.LinkAddress) { // // Precondition: e.mu must be locked func (e *linkAddrEntry) changeStateLocked(ns entryState, expiration time.Time) { - if e.s == incomplete && ns == ready { - e.notifyCompletionLocked(e.linkAddr) + if e.mu.s == incomplete && ns == ready { + e.notifyCompletionLocked(e.mu.linkAddr) } - if expiration.IsZero() || expiration.After(e.expiration) { - e.expiration = expiration + if expiration.IsZero() || expiration.After(e.mu.expiration) { + e.mu.expiration = expiration } - e.s = ns + e.mu.s = ns } // add adds a k -> v mapping to the cache. -func (c *linkAddrCache) AddLinkAddress(k tcpip.Address, v tcpip.LinkAddress) { +func (c *linkAddrCache) add(k tcpip.Address, v tcpip.LinkAddress) { // Calculate expiration time before acquiring the lock, since expiration is // relative to the time when information was learned, rather than when it // happened to be inserted into the cache. expiration := time.Now().Add(c.ageLimit) - c.cache.Lock() + c.mu.Lock() entry := c.getOrCreateEntryLocked(k) - c.cache.Unlock() - entry.mu.Lock() defer entry.mu.Unlock() - entry.linkAddr = v + c.mu.Unlock() + + entry.mu.linkAddr = v entry.changeStateLocked(ready, expiration) } @@ -166,18 +166,18 @@ func (c *linkAddrCache) AddLinkAddress(k tcpip.Address, v tcpip.LinkAddress) { // cache is not full, a new entry with state incomplete is allocated and // returned. 
func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.Address) *linkAddrEntry { - if entry, ok := c.cache.table[k]; ok { - c.cache.lru.Remove(entry) - c.cache.lru.PushFront(entry) + if entry, ok := c.mu.table[k]; ok { + c.mu.lru.Remove(entry) + c.mu.lru.PushFront(entry) return entry } var entry *linkAddrEntry - if len(c.cache.table) == linkAddrCacheSize { - entry = c.cache.lru.Back() + if len(c.mu.table) == linkAddrCacheSize { + entry = c.mu.lru.Back() entry.mu.Lock() - delete(c.cache.table, entry.addr) - c.cache.lru.Remove(entry) + delete(c.mu.table, entry.mu.addr) + c.mu.lru.Remove(entry) // Wake waiters and mark the soon-to-be-reused entry as expired. entry.notifyCompletionLocked("" /* linkAddr */) @@ -188,53 +188,55 @@ func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.Address) *linkAddrEntry { *entry = linkAddrEntry{ cache: c, - addr: k, - s: incomplete, } - c.cache.table[k] = entry - c.cache.lru.PushFront(entry) + entry.mu.Lock() + entry.mu.addr = k + entry.mu.s = incomplete + entry.mu.Unlock() + c.mu.table[k] = entry + c.mu.lru.PushFront(entry) return entry } -// get reports any known link address for k. -func (c *linkAddrCache) get(k tcpip.Address, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) { - c.cache.Lock() - defer c.cache.Unlock() - entry := c.getOrCreateEntryLocked(k) +// get reports any known link address for addr. +func (c *linkAddrCache) get(addr, localAddr tcpip.Address, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, tcpip.Error) { + c.mu.Lock() + entry := c.getOrCreateEntryLocked(addr) entry.mu.Lock() defer entry.mu.Unlock() + c.mu.Unlock() - switch s := entry.s; s { + switch s := entry.mu.s; s { case ready: - if !time.Now().After(entry.expiration) { + if !time.Now().After(entry.mu.expiration) { // Not expired. 
if onResolve != nil { - onResolve(LinkResolutionResult{LinkAddress: entry.linkAddr, Success: true}) + onResolve(LinkResolutionResult{LinkAddress: entry.mu.linkAddr, Success: true}) } - return entry.linkAddr, nil, nil + return entry.mu.linkAddr, nil, nil } entry.changeStateLocked(incomplete, time.Time{}) fallthrough case incomplete: if onResolve != nil { - entry.onResolve = append(entry.onResolve, onResolve) + entry.mu.onResolve = append(entry.mu.onResolve, onResolve) } - if entry.done == nil { - entry.done = make(chan struct{}) - go c.startAddressResolution(k, linkRes, localAddr, nic, entry.done) // S/R-SAFE: link non-savable; wakers dropped synchronously. + if entry.mu.done == nil { + entry.mu.done = make(chan struct{}) + go c.startAddressResolution(addr, localAddr, entry.mu.done) // S/R-SAFE: link non-savable; wakers dropped synchronously. } - return entry.linkAddr, entry.done, tcpip.ErrWouldBlock + return entry.mu.linkAddr, entry.mu.done, &tcpip.ErrWouldBlock{} default: panic(fmt.Sprintf("invalid cache entry state: %s", s)) } } -func (c *linkAddrCache) startAddressResolution(k tcpip.Address, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, done <-chan struct{}) { +func (c *linkAddrCache) startAddressResolution(k tcpip.Address, localAddr tcpip.Address, done <-chan struct{}) { for i := 0; ; i++ { // Send link request, then wait for the timeout limit and check // whether the request succeeded. - linkRes.LinkAddressRequest(k, localAddr, "" /* linkAddr */, nic) + c.linkRes.LinkAddressRequest(k, localAddr, "" /* linkAddr */) select { case now := <-time.After(c.resolutionTimeout): @@ -251,9 +253,9 @@ func (c *linkAddrCache) startAddressResolution(k tcpip.Address, linkRes LinkAddr // succeeded and mark the entry accordingly. Returns true if request can stop, // false if another request should be sent. 
func (c *linkAddrCache) checkLinkRequest(now time.Time, k tcpip.Address, attempt int) bool { - c.cache.Lock() - defer c.cache.Unlock() - entry, ok := c.cache.table[k] + c.mu.Lock() + defer c.mu.Unlock() + entry, ok := c.mu.table[k] if !ok { // Entry was evicted from the cache. return true @@ -261,7 +263,7 @@ func (c *linkAddrCache) checkLinkRequest(now time.Time, k tcpip.Address, attempt entry.mu.Lock() defer entry.mu.Unlock() - switch s := entry.s; s { + switch s := entry.mu.s; s { case ready: // Entry was made ready by resolver. case incomplete: @@ -271,20 +273,87 @@ func (c *linkAddrCache) checkLinkRequest(now time.Time, k tcpip.Address, attempt } // Max number of retries reached, delete entry. entry.notifyCompletionLocked("" /* linkAddr */) - delete(c.cache.table, k) + delete(c.mu.table, k) default: panic(fmt.Sprintf("invalid cache entry state: %s", s)) } return true } -func newLinkAddrCache(nic *NIC, ageLimit, resolutionTimeout time.Duration, resolutionAttempts int) *linkAddrCache { - c := &linkAddrCache{ +func (c *linkAddrCache) init(nic *NIC, ageLimit, resolutionTimeout time.Duration, resolutionAttempts int, linkRes LinkAddressResolver) { + *c = linkAddrCache{ nic: nic, + linkRes: linkRes, ageLimit: ageLimit, resolutionTimeout: resolutionTimeout, resolutionAttempts: resolutionAttempts, } - c.cache.table = make(map[tcpip.Address]*linkAddrEntry, linkAddrCacheSize) - return c + + c.mu.Lock() + c.mu.table = make(map[tcpip.Address]*linkAddrEntry, linkAddrCacheSize) + c.mu.Unlock() +} + +var _ neighborTable = (*linkAddrCache)(nil) + +func (*linkAddrCache) neighbors() ([]NeighborEntry, tcpip.Error) { + return nil, &tcpip.ErrNotSupported{} +} + +func (c *linkAddrCache) addStaticEntry(addr tcpip.Address, linkAddr tcpip.LinkAddress) { + c.add(addr, linkAddr) +} + +func (*linkAddrCache) remove(addr tcpip.Address) tcpip.Error { + return &tcpip.ErrNotSupported{} +} + +func (*linkAddrCache) removeAll() tcpip.Error { + return &tcpip.ErrNotSupported{} +} + +func (c 
*linkAddrCache) handleProbe(addr tcpip.Address, linkAddr tcpip.LinkAddress) { + if len(linkAddr) != 0 { + // NUD allows probes without a link address but linkAddrCache + // is a simple neighbor table which does not implement NUD. + // + // As per RFC 4861 section 4.3, + // + // Source link-layer address + // The link-layer address for the sender. MUST NOT be + // included when the source IP address is the + // unspecified address. Otherwise, on link layers + // that have addresses this option MUST be included in + // multicast solicitations and SHOULD be included in + // unicast solicitations. + c.add(addr, linkAddr) + } +} + +func (c *linkAddrCache) handleConfirmation(addr tcpip.Address, linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) { + if len(linkAddr) != 0 { + // NUD allows confirmations without a link address but linkAddrCache + // is a simple neighbor table which does not implement NUD. + // + // As per RFC 4861 section 4.4, + // + // Target link-layer address + // The link-layer address for the target, i.e., the + // sender of the advertisement. This option MUST be + // included on link layers that have addresses when + // responding to multicast solicitations. When + // responding to a unicast Neighbor Solicitation this + // option SHOULD be included. 
+ c.add(addr, linkAddr) + } +} + +func (c *linkAddrCache) handleUpperLevelConfirmation(tcpip.Address) {} + +func (*linkAddrCache) nudConfig() (NUDConfigurations, tcpip.Error) { + return NUDConfigurations{}, &tcpip.ErrNotSupported{} +} + +func (*linkAddrCache) setNUDConfig(NUDConfigurations) tcpip.Error { + return &tcpip.ErrNotSupported{} } diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go index 88fbbf3fe..9e7f331c9 100644 --- a/pkg/tcpip/stack/linkaddrcache_test.go +++ b/pkg/tcpip/stack/linkaddrcache_test.go @@ -48,7 +48,7 @@ type testLinkAddressResolver struct { onLinkAddressRequest func() } -func (r *testLinkAddressResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error { +func (r *testLinkAddressResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress) tcpip.Error { // TODO(gvisor.dev/issue/5141): Use a fake clock. time.AfterFunc(r.delay, func() { r.fakeRequest(targetAddr) }) if f := r.onLinkAddressRequest; f != nil { @@ -60,7 +60,7 @@ func (r *testLinkAddressResolver) LinkAddressRequest(targetAddr, _ tcpip.Address func (r *testLinkAddressResolver) fakeRequest(addr tcpip.Address) { for _, ta := range testAddrs { if ta.addr == addr { - r.cache.AddLinkAddress(ta.addr, ta.linkAddr) + r.cache.add(ta.addr, ta.linkAddr) break } } @@ -77,13 +77,13 @@ func (*testLinkAddressResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumbe return 1 } -func getBlocking(c *linkAddrCache, addr tcpip.Address, linkRes LinkAddressResolver) (tcpip.LinkAddress, *tcpip.Error) { +func getBlocking(c *linkAddrCache, addr tcpip.Address) (tcpip.LinkAddress, tcpip.Error) { var attemptedResolution bool for { - got, ch, err := c.get(addr, linkRes, "", nil, nil) - if err == tcpip.ErrWouldBlock { + got, ch, err := c.get(addr, "", nil) + if _, ok := err.(*tcpip.ErrWouldBlock); ok { if attemptedResolution { - return got, tcpip.ErrTimeout + return got, &tcpip.ErrTimeout{} } 
attemptedResolution = true <-ch @@ -100,50 +100,52 @@ func newEmptyNIC() *NIC { } func TestCacheOverflow(t *testing.T) { - c := newLinkAddrCache(newEmptyNIC(), 1<<63-1, 1*time.Second, 3) + var c linkAddrCache + c.init(newEmptyNIC(), 1<<63-1, 1*time.Second, 3, nil) for i := len(testAddrs) - 1; i >= 0; i-- { e := testAddrs[i] - c.AddLinkAddress(e.addr, e.linkAddr) - got, _, err := c.get(e.addr, nil, "", nil, nil) + c.add(e.addr, e.linkAddr) + got, _, err := c.get(e.addr, "", nil) if err != nil { - t.Errorf("insert %d, c.get(%s, nil, '', nil, nil): %s", i, e.addr, err) + t.Errorf("insert %d, c.get(%s, '', nil): %s", i, e.addr, err) } if got != e.linkAddr { - t.Errorf("insert %d, got c.get(%s, nil, '', nil, nil) = %s, want = %s", i, e.addr, got, e.linkAddr) + t.Errorf("insert %d, got c.get(%s, '', nil) = %s, want = %s", i, e.addr, got, e.linkAddr) } } // Expect to find at least half of the most recent entries. for i := 0; i < linkAddrCacheSize/2; i++ { e := testAddrs[i] - got, _, err := c.get(e.addr, nil, "", nil, nil) + got, _, err := c.get(e.addr, "", nil) if err != nil { - t.Errorf("check %d, c.get(%s, nil, '', nil, nil): %s", i, e.addr, err) + t.Errorf("check %d, c.get(%s, '', nil): %s", i, e.addr, err) } if got != e.linkAddr { - t.Errorf("check %d, got c.get(%s, nil, '', nil, nil) = %s, want = %s", i, e.addr, got, e.linkAddr) + t.Errorf("check %d, got c.get(%s, '', nil) = %s, want = %s", i, e.addr, got, e.linkAddr) } } // The earliest entries should no longer be in the cache. 
- c.cache.Lock() - defer c.cache.Unlock() + c.mu.Lock() + defer c.mu.Unlock() for i := len(testAddrs) - 1; i >= len(testAddrs)-linkAddrCacheSize; i-- { e := testAddrs[i] - if entry, ok := c.cache.table[e.addr]; ok { - t.Errorf("unexpected entry at c.cache.table[%s]: %#v", e.addr, entry) + if entry, ok := c.mu.table[e.addr]; ok { + t.Errorf("unexpected entry at c.mu.table[%s]: %#v", e.addr, entry) } } } func TestCacheConcurrent(t *testing.T) { - c := newLinkAddrCache(newEmptyNIC(), 1<<63-1, 1*time.Second, 3) - linkRes := &testLinkAddressResolver{cache: c} + var c linkAddrCache + linkRes := &testLinkAddressResolver{cache: &c} + c.init(newEmptyNIC(), 1<<63-1, 1*time.Second, 3, linkRes) var wg sync.WaitGroup for r := 0; r < 16; r++ { wg.Add(1) go func() { for _, e := range testAddrs { - c.AddLinkAddress(e.addr, e.linkAddr) + c.add(e.addr, e.linkAddr) } wg.Done() }() @@ -154,54 +156,57 @@ func TestCacheConcurrent(t *testing.T) { // can fit in the cache, so our eviction strategy requires that // the last entry be present and the first be missing. 
e := testAddrs[len(testAddrs)-1] - got, _, err := c.get(e.addr, linkRes, "", nil, nil) + got, _, err := c.get(e.addr, "", nil) if err != nil { - t.Errorf("c.get(%s, _, '', nil, nil): %s", e.addr, err) + t.Errorf("c.get(%s, '', nil): %s", e.addr, err) } if got != e.linkAddr { - t.Errorf("got c.get(%s, _, '', nil, nil) = %s, want = %s", e.addr, got, e.linkAddr) + t.Errorf("got c.get(%s, '', nil) = %s, want = %s", e.addr, got, e.linkAddr) } e = testAddrs[0] - c.cache.Lock() - defer c.cache.Unlock() - if entry, ok := c.cache.table[e.addr]; ok { - t.Errorf("unexpected entry at c.cache.table[%s]: %#v", e.addr, entry) + c.mu.Lock() + defer c.mu.Unlock() + if entry, ok := c.mu.table[e.addr]; ok { + t.Errorf("unexpected entry at c.mu.table[%s]: %#v", e.addr, entry) } } func TestCacheAgeLimit(t *testing.T) { - c := newLinkAddrCache(newEmptyNIC(), 1*time.Millisecond, 1*time.Second, 3) - linkRes := &testLinkAddressResolver{cache: c} + var c linkAddrCache + linkRes := &testLinkAddressResolver{cache: &c} + c.init(newEmptyNIC(), 1*time.Millisecond, 1*time.Second, 3, linkRes) e := testAddrs[0] - c.AddLinkAddress(e.addr, e.linkAddr) + c.add(e.addr, e.linkAddr) time.Sleep(50 * time.Millisecond) - if _, _, err := c.get(e.addr, linkRes, "", nil, nil); err != tcpip.ErrWouldBlock { - t.Errorf("got c.get(%s, _, '', nil, nil) = %s, want = ErrWouldBlock", e.addr, err) + _, _, err := c.get(e.addr, "", nil) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Errorf("got c.get(%s, '', nil) = %s, want = ErrWouldBlock", e.addr, err) } } func TestCacheReplace(t *testing.T) { - c := newLinkAddrCache(newEmptyNIC(), 1<<63-1, 1*time.Second, 3) + var c linkAddrCache + c.init(newEmptyNIC(), 1<<63-1, 1*time.Second, 3, nil) e := testAddrs[0] l2 := e.linkAddr + "2" - c.AddLinkAddress(e.addr, e.linkAddr) - got, _, err := c.get(e.addr, nil, "", nil, nil) + c.add(e.addr, e.linkAddr) + got, _, err := c.get(e.addr, "", nil) if err != nil { - t.Errorf("c.get(%s, nil, '', nil, nil): %s", e.addr, err) + 
t.Errorf("c.get(%s, '', nil): %s", e.addr, err) } if got != e.linkAddr { - t.Errorf("got c.get(%s, nil, '', nil, nil) = %s, want = %s", e.addr, got, e.linkAddr) + t.Errorf("got c.get(%s, '', nil) = %s, want = %s", e.addr, got, e.linkAddr) } - c.AddLinkAddress(e.addr, l2) - got, _, err = c.get(e.addr, nil, "", nil, nil) + c.add(e.addr, l2) + got, _, err = c.get(e.addr, "", nil) if err != nil { - t.Errorf("c.get(%s, nil, '', nil, nil): %s", e.addr, err) + t.Errorf("c.get(%s, '', nil): %s", e.addr, err) } if got != l2 { - t.Errorf("got c.get(%s, nil, '', nil, nil) = %s, want = %s", e.addr, got, l2) + t.Errorf("got c.get(%s, '', nil) = %s, want = %s", e.addr, got, l2) } } @@ -212,34 +217,36 @@ func TestCacheResolution(t *testing.T) { // // Using a large resolution timeout decreases the probability of experiencing // this race condition and does not affect how long this test takes to run. - c := newLinkAddrCache(newEmptyNIC(), 1<<63-1, math.MaxInt64, 1) - linkRes := &testLinkAddressResolver{cache: c} + var c linkAddrCache + linkRes := &testLinkAddressResolver{cache: &c} + c.init(newEmptyNIC(), 1<<63-1, math.MaxInt64, 1, linkRes) for i, ta := range testAddrs { - got, err := getBlocking(c, ta.addr, linkRes) + got, err := getBlocking(&c, ta.addr) if err != nil { - t.Errorf("check %d, getBlocking(_, %s, _): %s", i, ta.addr, err) + t.Errorf("check %d, getBlocking(_, %s): %s", i, ta.addr, err) } if got != ta.linkAddr { - t.Errorf("check %d, got getBlocking(_, %s, _) = %s, want = %s", i, ta.addr, got, ta.linkAddr) + t.Errorf("check %d, got getBlocking(_, %s) = %s, want = %s", i, ta.addr, got, ta.linkAddr) } } // Check that after resolved, address stays in the cache and never returns WouldBlock. 
for i := 0; i < 10; i++ { e := testAddrs[len(testAddrs)-1] - got, _, err := c.get(e.addr, linkRes, "", nil, nil) + got, _, err := c.get(e.addr, "", nil) if err != nil { - t.Errorf("c.get(%s, _, '', nil, nil): %s", e.addr, err) + t.Errorf("c.get(%s, '', nil): %s", e.addr, err) } if got != e.linkAddr { - t.Errorf("got c.get(%s, _, '', nil, nil) = %s, want = %s", e.addr, got, e.linkAddr) + t.Errorf("got c.get(%s, '', nil) = %s, want = %s", e.addr, got, e.linkAddr) } } } func TestCacheResolutionFailed(t *testing.T) { - c := newLinkAddrCache(newEmptyNIC(), 1<<63-1, 10*time.Millisecond, 5) - linkRes := &testLinkAddressResolver{cache: c} + var c linkAddrCache + linkRes := &testLinkAddressResolver{cache: &c} + c.init(newEmptyNIC(), 1<<63-1, 10*time.Millisecond, 5, linkRes) var requestCount uint32 linkRes.onLinkAddressRequest = func() { @@ -248,19 +255,20 @@ func TestCacheResolutionFailed(t *testing.T) { // First, sanity check that resolution is working... e := testAddrs[0] - got, err := getBlocking(c, e.addr, linkRes) + got, err := getBlocking(&c, e.addr) if err != nil { - t.Errorf("getBlocking(_, %s, _): %s", e.addr, err) + t.Errorf("getBlocking(_, %s): %s", e.addr, err) } if got != e.linkAddr { - t.Errorf("got getBlocking(_, %s, _) = %s, want = %s", e.addr, got, e.linkAddr) + t.Errorf("got getBlocking(_, %s) = %s, want = %s", e.addr, got, e.linkAddr) } before := atomic.LoadUint32(&requestCount) e.addr += "2" - if a, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrTimeout { - t.Errorf("got getBlocking(_, %s, _) = (%s, %s), want = (_, %s)", e.addr, a, err, tcpip.ErrTimeout) + a, err := getBlocking(&c, e.addr) + if _, ok := err.(*tcpip.ErrTimeout); !ok { + t.Errorf("got getBlocking(_, %s) = (%s, %s), want = (_, %s)", e.addr, a, err, &tcpip.ErrTimeout{}) } if got, want := int(atomic.LoadUint32(&requestCount)-before), c.resolutionAttempts; got != want { @@ -271,11 +279,13 @@ func TestCacheResolutionFailed(t *testing.T) { func TestCacheResolutionTimeout(t *testing.T) { 
resolverDelay := 500 * time.Millisecond expiration := resolverDelay / 10 - c := newLinkAddrCache(newEmptyNIC(), expiration, 1*time.Millisecond, 3) - linkRes := &testLinkAddressResolver{cache: c, delay: resolverDelay} + var c linkAddrCache + linkRes := &testLinkAddressResolver{cache: &c, delay: resolverDelay} + c.init(newEmptyNIC(), expiration, 1*time.Millisecond, 3, linkRes) e := testAddrs[0] - if a, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrTimeout { - t.Errorf("got getBlocking(_, %s, _) = (%s, %s), want = (_, %s)", e.addr, a, err, tcpip.ErrTimeout) + a, err := getBlocking(&c, e.addr) + if _, ok := err.(*tcpip.ErrTimeout); !ok { + t.Errorf("got getBlocking(_, %s) = (%s, %s), want = (_, %s)", e.addr, a, err, &tcpip.ErrTimeout{}) } } diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index d7bbb25ea..0238605af 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -104,7 +104,7 @@ type ndpDADEvent struct { nicID tcpip.NICID addr tcpip.Address resolved bool - err *tcpip.Error + err tcpip.Error } type ndpRouterEvent struct { @@ -174,7 +174,7 @@ type ndpDispatcher struct { } // Implements ipv6.NDPDispatcher.OnDuplicateAddressDetectionStatus. -func (n *ndpDispatcher) OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) { +func (n *ndpDispatcher) OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err tcpip.Error) { if n.dadC != nil { n.dadC <- ndpDADEvent{ nicID, @@ -311,7 +311,7 @@ func (l *channelLinkWithHeaderLength) MaxHeaderLength() uint16 { // Check e to make sure that the event is for addr on nic with ID 1, and the // resolved flag set to resolved with the specified err. 
-func checkDADEvent(e ndpDADEvent, nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) string { +func checkDADEvent(e ndpDADEvent, nicID tcpip.NICID, addr tcpip.Address, resolved bool, err tcpip.Error) string { return cmp.Diff(ndpDADEvent{nicID: nicID, addr: addr, resolved: resolved, err: err}, e, cmp.AllowUnexported(e)) } @@ -465,8 +465,8 @@ func TestDADResolve(t *testing.T) { // tentative address. { r, err := s.FindRoute(nicID, "", addr2, header.IPv6ProtocolNumber, false) - if err != tcpip.ErrNoRoute { - t.Errorf("got FindRoute(%d, '', %s, %d, false) = (%+v, %v), want = (_, %s)", nicID, addr2, header.IPv6ProtocolNumber, r, err, tcpip.ErrNoRoute) + if _, ok := err.(*tcpip.ErrNoRoute); !ok { + t.Errorf("got FindRoute(%d, '', %s, %d, false) = (%+v, %v), want = (_, %s)", nicID, addr2, header.IPv6ProtocolNumber, r, err, &tcpip.ErrNoRoute{}) } if r != nil { r.Release() @@ -474,8 +474,8 @@ func TestDADResolve(t *testing.T) { } { r, err := s.FindRoute(nicID, addr1, addr2, header.IPv6ProtocolNumber, false) - if err != tcpip.ErrNoRoute { - t.Errorf("got FindRoute(%d, %s, %s, %d, false) = (%+v, %v), want = (_, %s)", nicID, addr1, addr2, header.IPv6ProtocolNumber, r, err, tcpip.ErrNoRoute) + if _, ok := err.(*tcpip.ErrNoRoute); !ok { + t.Errorf("got FindRoute(%d, %s, %s, %d, false) = (%+v, %v), want = (_, %s)", nicID, addr1, addr2, header.IPv6ProtocolNumber, r, err, &tcpip.ErrNoRoute{}) } if r != nil { r.Release() @@ -2796,14 +2796,8 @@ func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID, useN NIC: nicID, }}) - if useNeighborCache { - if err := s.AddStaticNeighbor(nicID, llAddr3, linkAddr3); err != nil { - t.Fatalf("s.AddStaticNeighbor(%d, %s, %s): %s", nicID, llAddr3, linkAddr3, err) - } - } else { - if err := s.AddLinkAddress(nicID, llAddr3, linkAddr3); err != nil { - t.Fatalf("s.AddLinkAddress(%d, %s, %s): %s", nicID, llAddr3, linkAddr3, err) - } + if err := s.AddStaticNeighbor(nicID, ipv6.ProtocolNumber, llAddr3, linkAddr3); err 
!= nil { + t.Fatalf("s.AddStaticNeighbor(%d, %d, %s, %s): %s", nicID, ipv6.ProtocolNumber, llAddr3, linkAddr3, err) } return ndpDisp, e, s } @@ -3222,8 +3216,11 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { defer ep.Close() ep.SocketOptions().SetV6Only(true) - if err := ep.Connect(dstAddr); err != tcpip.ErrNoRoute { - t.Errorf("got ep.Connect(%+v) = %s, want = %s", dstAddr, err, tcpip.ErrNoRoute) + { + err := ep.Connect(dstAddr) + if _, ok := err.(*tcpip.ErrNoRoute); !ok { + t.Errorf("got ep.Connect(%+v) = %s, want = %s", dstAddr, err, &tcpip.ErrNoRoute{}) + } } }) } diff --git a/pkg/tcpip/stack/neighbor_cache.go b/pkg/tcpip/stack/neighbor_cache.go index 204196d00..7e3132058 100644 --- a/pkg/tcpip/stack/neighbor_cache.go +++ b/pkg/tcpip/stack/neighbor_cache.go @@ -42,11 +42,10 @@ type NeighborStats struct { // 2. Static entries are explicitly added by a user and have no expiration. // Their state is always Static. The amount of static entries stored in the // cache is unbounded. -// -// neighborCache implements NUDHandler. type neighborCache struct { - nic *NIC - state *NUDState + nic *NIC + state *NUDState + linkRes LinkAddressResolver // mu protects the fields below. mu sync.RWMutex @@ -62,8 +61,6 @@ type neighborCache struct { } } -var _ NUDHandler = (*neighborCache)(nil) - // getOrCreateEntry retrieves a cache entry associated with addr. The // returned entry is always refreshed in the cache (it is reachable via the // map, and its place is bumped in LRU). @@ -73,7 +70,7 @@ var _ NUDHandler = (*neighborCache)(nil) // reset to state incomplete, and returned. If no matching entry exists and the // cache is not full, a new entry with state incomplete is allocated and // returned. 
-func (n *neighborCache) getOrCreateEntry(remoteAddr tcpip.Address, linkRes LinkAddressResolver) *neighborEntry { +func (n *neighborCache) getOrCreateEntry(remoteAddr tcpip.Address) *neighborEntry { n.mu.Lock() defer n.mu.Unlock() @@ -89,7 +86,7 @@ func (n *neighborCache) getOrCreateEntry(remoteAddr tcpip.Address, linkRes LinkA // The entry that needs to be created must be dynamic since all static // entries are directly added to the cache via addStaticEntry. - entry := newNeighborEntry(n.nic, remoteAddr, n.state, linkRes) + entry := newNeighborEntry(n, remoteAddr, n.state) if n.dynamic.count == neighborCacheSize { e := n.dynamic.lru.Back() e.mu.Lock() @@ -126,8 +123,8 @@ func (n *neighborCache) getOrCreateEntry(remoteAddr tcpip.Address, linkRes LinkA // packet prompting NUD/link address resolution. // // TODO(gvisor.dev/issue/5151): Don't return the neighbor entry. -func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver, onResolve func(LinkResolutionResult)) (NeighborEntry, <-chan struct{}, *tcpip.Error) { - entry := n.getOrCreateEntry(remoteAddr, linkRes) +func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, onResolve func(LinkResolutionResult)) (NeighborEntry, <-chan struct{}, tcpip.Error) { + entry := n.getOrCreateEntry(remoteAddr) entry.mu.Lock() defer entry.mu.Unlock() @@ -154,7 +151,7 @@ func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkA entry.done = make(chan struct{}) } entry.handlePacketQueuedLocked(localAddr) - return entry.neigh, entry.done, tcpip.ErrWouldBlock + return entry.neigh, entry.done, &tcpip.ErrWouldBlock{} default: panic(fmt.Sprintf("Invalid cache entry state: %s", s)) } @@ -206,7 +203,7 @@ func (n *neighborCache) addStaticEntry(addr tcpip.Address, linkAddr tcpip.LinkAd entry.mu.Unlock() } - n.cache[addr] = newStaticNeighborEntry(n.nic, addr, linkAddr, n.state) + n.cache[addr] = newStaticNeighborEntry(n, addr, linkAddr, n.state) } // removeEntry removes a 
dynamic or static entry by address from the neighbor @@ -263,27 +260,45 @@ func (n *neighborCache) setConfig(config NUDConfigurations) { n.state.SetConfig(config) } -// HandleProbe implements NUDHandler.HandleProbe by following the logic defined -// in RFC 4861 section 7.2.3. Validation of the probe is expected to be handled -// by the caller. -func (n *neighborCache) HandleProbe(remoteAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes LinkAddressResolver) { - entry := n.getOrCreateEntry(remoteAddr, linkRes) +var _ neighborTable = (*neighborCache)(nil) + +func (n *neighborCache) neighbors() ([]NeighborEntry, tcpip.Error) { + return n.entries(), nil +} + +func (n *neighborCache) get(addr, localAddr tcpip.Address, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, tcpip.Error) { + entry, ch, err := n.entry(addr, localAddr, onResolve) + return entry.LinkAddr, ch, err +} + +func (n *neighborCache) remove(addr tcpip.Address) tcpip.Error { + if !n.removeEntry(addr) { + return &tcpip.ErrBadAddress{} + } + + return nil +} + +func (n *neighborCache) removeAll() tcpip.Error { + n.clear() + return nil +} + +// handleProbe handles a neighbor probe as defined by RFC 4861 section 7.2.3. +// +// Validation of the probe is expected to be handled by the caller. +func (n *neighborCache) handleProbe(remoteAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) { + entry := n.getOrCreateEntry(remoteAddr) entry.mu.Lock() entry.handleProbeLocked(remoteLinkAddr) entry.mu.Unlock() } -// HandleConfirmation implements NUDHandler.HandleConfirmation by following the -// logic defined in RFC 4861 section 7.2.5. +// handleConfirmation handles a neighbor confirmation as defined by +// RFC 4861 section 7.2.5. 
// -// TODO(gvisor.dev/issue/2277): To protect against ARP poisoning and other -// attacks against NDP functions, Secure Neighbor Discovery (SEND) Protocol -// should be deployed where preventing access to the broadcast segment might -// not be possible. SEND uses RSA key pairs to produce cryptographically -// generated addresses, as defined in RFC 3972, Cryptographically Generated -// Addresses (CGA). This ensures that the claimed source of an NDP message is -// the owner of the claimed address. -func (n *neighborCache) HandleConfirmation(addr tcpip.Address, linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) { +// Validation of the confirmation is expected to be handled by the caller. +func (n *neighborCache) handleConfirmation(addr tcpip.Address, linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) { n.mu.RLock() entry, ok := n.cache[addr] n.mu.RUnlock() @@ -297,10 +312,9 @@ func (n *neighborCache) HandleConfirmation(addr tcpip.Address, linkAddr tcpip.Li // no matching entry for the remote address. } -// HandleUpperLevelConfirmation implements -// NUDHandler.HandleUpperLevelConfirmation by following the logic defined in -// RFC 4861 section 7.3.1. -func (n *neighborCache) HandleUpperLevelConfirmation(addr tcpip.Address) { +// handleUpperLevelConfirmation processes a confirmation of reachablity from +// some protocol that operates at a layer above the IP/link layer. 
+func (n *neighborCache) handleUpperLevelConfirmation(addr tcpip.Address) { n.mu.RLock() entry, ok := n.cache[addr] n.mu.RUnlock() @@ -310,3 +324,12 @@ func (n *neighborCache) HandleUpperLevelConfirmation(addr tcpip.Address) { entry.mu.Unlock() } } + +func (n *neighborCache) nudConfig() (NUDConfigurations, tcpip.Error) { + return n.config(), nil +} + +func (n *neighborCache) setNUDConfig(c NUDConfigurations) tcpip.Error { + n.setConfig(c) + return nil +} diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go index 6723aef9b..b489b5e08 100644 --- a/pkg/tcpip/stack/neighbor_cache_test.go +++ b/pkg/tcpip/stack/neighbor_cache_test.go @@ -76,10 +76,15 @@ func entryDiffOptsWithSort() []cmp.Option { })) } -func newTestNeighborCache(nudDisp NUDDispatcher, config NUDConfigurations, clock tcpip.Clock) *neighborCache { +func newTestNeighborResolver(nudDisp NUDDispatcher, config NUDConfigurations, clock tcpip.Clock) *testNeighborResolver { config.resetInvalidFields() rng := rand.New(rand.NewSource(time.Now().UnixNano())) - neigh := &neighborCache{ + linkRes := &testNeighborResolver{ + clock: clock, + entries: newTestEntryStore(), + delay: typicalLatency, + } + linkRes.neigh = &neighborCache{ nic: &NIC{ stack: &Stack{ clock: clock, @@ -88,11 +93,11 @@ func newTestNeighborCache(nudDisp NUDDispatcher, config NUDConfigurations, clock id: 1, stats: makeNICStats(), }, - state: NewNUDState(config, rng), - cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize), + state: NewNUDState(config, rng), + linkRes: linkRes, + cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize), } - neigh.nic.neigh = neigh - return neigh + return linkRes } // testEntryStore contains a set of IP to NeighborEntry mappings. 
@@ -194,7 +199,7 @@ type testNeighborResolver struct { var _ LinkAddressResolver = (*testNeighborResolver)(nil) -func (r *testNeighborResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error { +func (r *testNeighborResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress) tcpip.Error { if !r.dropReplies { // Delay handling the request to emulate network latency. r.clock.AfterFunc(r.delay, func() { @@ -212,7 +217,7 @@ func (r *testNeighborResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ // fakeRequest emulates handling a response for a link address request. func (r *testNeighborResolver) fakeRequest(addr tcpip.Address) { if entry, ok := r.entries.entryByAddr(addr); ok { - r.neigh.HandleConfirmation(addr, entry.LinkAddr, ReachabilityConfirmationFlags{ + r.neigh.handleConfirmation(addr, entry.LinkAddr, ReachabilityConfirmationFlags{ Solicited: true, Override: false, IsRouter: false, @@ -242,17 +247,17 @@ func TestNeighborCacheGetConfig(t *testing.T) { nudDisp := testNUDDispatcher{} c := DefaultNUDConfigurations() clock := faketime.NewManualClock() - neigh := newTestNeighborCache(&nudDisp, c, clock) + linkRes := newTestNeighborResolver(&nudDisp, c, clock) - if got, want := neigh.config(), c; got != want { - t.Errorf("got neigh.config() = %+v, want = %+v", got, want) + if got, want := linkRes.neigh.config(), c; got != want { + t.Errorf("got linkRes.neigh.config() = %+v, want = %+v", got, want) } // No events should have been dispatched. 
nudDisp.mu.Lock() defer nudDisp.mu.Unlock() - if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff([]testEntryEventInfo(nil), nudDisp.events); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } } @@ -260,21 +265,21 @@ func TestNeighborCacheSetConfig(t *testing.T) { nudDisp := testNUDDispatcher{} c := DefaultNUDConfigurations() clock := faketime.NewManualClock() - neigh := newTestNeighborCache(&nudDisp, c, clock) + linkRes := newTestNeighborResolver(&nudDisp, c, clock) c.MinRandomFactor = 1 c.MaxRandomFactor = 1 - neigh.setConfig(c) + linkRes.neigh.setConfig(c) - if got, want := neigh.config(), c; got != want { - t.Errorf("got neigh.config() = %+v, want = %+v", got, want) + if got, want := linkRes.neigh.config(), c; got != want { + t.Errorf("got linkRes.neigh.config() = %+v, want = %+v", got, want) } // No events should have been dispatched. 
nudDisp.mu.Lock() defer nudDisp.mu.Unlock() - if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff([]testEntryEventInfo(nil), nudDisp.events); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } } @@ -282,21 +287,15 @@ func TestNeighborCacheEntry(t *testing.T) { c := DefaultNUDConfigurations() nudDisp := testNUDDispatcher{} clock := faketime.NewManualClock() - neigh := newTestNeighborCache(&nudDisp, c, clock) - store := newTestEntryStore() - linkRes := &testNeighborResolver{ - clock: clock, - neigh: neigh, - entries: store, - delay: typicalLatency, - } + linkRes := newTestNeighborResolver(&nudDisp, c, clock) - entry, ok := store.entry(0) + entry, ok := linkRes.entries.entry(0) if !ok { - t.Fatal("store.entry(0) not found") + t.Fatal("linkRes.entries.entry(0) not found") } - if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + _, _, err := linkRes.neigh.entry(entry.Addr, "", nil) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Errorf("got linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } clock.Advance(typicalLatency) @@ -321,22 +320,22 @@ func TestNeighborCacheEntry(t *testing.T) { }, } nudDisp.mu.Lock() - diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...) 
nudDisp.events = nil nudDisp.mu.Unlock() if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } - if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != nil { - t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil, nil): %s", entry.Addr, err) + if _, _, err := linkRes.neigh.entry(entry.Addr, "", nil); err != nil { + t.Fatalf("unexpected error from linkRes.neigh.entry(%s, '', nil): %s", entry.Addr, err) } // No more events should have been dispatched. nudDisp.mu.Lock() defer nudDisp.mu.Unlock() - if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff([]testEntryEventInfo(nil), nudDisp.events); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } } @@ -345,22 +344,16 @@ func TestNeighborCacheRemoveEntry(t *testing.T) { nudDisp := testNUDDispatcher{} clock := faketime.NewManualClock() - neigh := newTestNeighborCache(&nudDisp, config, clock) - store := newTestEntryStore() - linkRes := &testNeighborResolver{ - clock: clock, - neigh: neigh, - entries: store, - delay: typicalLatency, - } + linkRes := newTestNeighborResolver(&nudDisp, config, clock) - entry, ok := store.entry(0) + entry, ok := linkRes.entries.entry(0) if !ok { - t.Fatal("store.entry(0) not found") + t.Fatal("linkRes.entries.entry(0) not found") } - if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + _, _, err := linkRes.neigh.entry(entry.Addr, "", nil) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Errorf("got linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } clock.Advance(typicalLatency) @@ -385,14 +378,14 @@ func TestNeighborCacheRemoveEntry(t *testing.T) 
{ }, } nudDisp.mu.Lock() - diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...) nudDisp.events = nil nudDisp.mu.Unlock() if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } - neigh.removeEntry(entry.Addr) + linkRes.neigh.removeEntry(entry.Addr) { wantEvents := []testEntryEventInfo{ @@ -407,22 +400,23 @@ func TestNeighborCacheRemoveEntry(t *testing.T) { }, } nudDisp.mu.Lock() - diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...) nudDisp.mu.Unlock() if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } } - if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + { + _, _, err := linkRes.neigh.entry(entry.Addr, "", nil) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Errorf("got linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + } } } type testContext struct { clock *faketime.ManualClock - neigh *neighborCache - store *testEntryStore linkRes *testNeighborResolver nudDisp *testNUDDispatcher } @@ -430,19 +424,10 @@ type testContext struct { func newTestContext(c NUDConfigurations) testContext { nudDisp := &testNUDDispatcher{} clock := faketime.NewManualClock() - neigh := newTestNeighborCache(nudDisp, c, clock) - store := newTestEntryStore() - linkRes := &testNeighborResolver{ - clock: clock, - neigh: neigh, - entries: store, - delay: typicalLatency, - } + linkRes := newTestNeighborResolver(nudDisp, c, clock) return testContext{ clock: clock, - neigh: neigh, - store: store, linkRes: linkRes, nudDisp: nudDisp, } @@ -456,16 
+441,17 @@ type overflowOptions struct { func (c *testContext) overflowCache(opts overflowOptions) error { // Fill the neighbor cache to capacity to verify the LRU eviction strategy is // working properly after the entry removal. - for i := opts.startAtEntryIndex; i < c.store.size(); i++ { + for i := opts.startAtEntryIndex; i < c.linkRes.entries.size(); i++ { // Add a new entry - entry, ok := c.store.entry(i) + entry, ok := c.linkRes.entries.entry(i) if !ok { - return fmt.Errorf("c.store.entry(%d) not found", i) + return fmt.Errorf("c.linkRes.entries.entry(%d) not found", i) } - if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock { - return fmt.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + _, _, err := c.linkRes.neigh.entry(entry.Addr, "", nil) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + return fmt.Errorf("got c.linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } - c.clock.Advance(c.neigh.config().RetransmitTimer) + c.clock.Advance(c.linkRes.neigh.config().RetransmitTimer) var wantEvents []testEntryEventInfo @@ -473,9 +459,9 @@ func (c *testContext) overflowCache(opts overflowOptions) error { // LRU eviction strategy. Note that the number of static entries should not // affect the total number of dynamic entries that can be added. if i >= neighborCacheSize+opts.startAtEntryIndex { - removedEntry, ok := c.store.entry(i - neighborCacheSize) + removedEntry, ok := c.linkRes.entries.entry(i - neighborCacheSize) if !ok { - return fmt.Errorf("store.entry(%d) not found", i-neighborCacheSize) + return fmt.Errorf("linkRes.entries.entry(%d) not found", i-neighborCacheSize) } wantEvents = append(wantEvents, testEntryEventInfo{ EventType: entryTestRemoved, @@ -506,11 +492,11 @@ func (c *testContext) overflowCache(opts overflowOptions) error { }) c.nudDisp.mu.Lock() - diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) 
+ diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...) c.nudDisp.events = nil c.nudDisp.mu.Unlock() if diff != "" { - return fmt.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + return fmt.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } } @@ -518,10 +504,10 @@ func (c *testContext) overflowCache(opts overflowOptions) error { // by entries() is nondeterministic, so entries have to be sorted before // comparison. wantUnsortedEntries := opts.wantStaticEntries - for i := c.store.size() - neighborCacheSize; i < c.store.size(); i++ { - entry, ok := c.store.entry(i) + for i := c.linkRes.entries.size() - neighborCacheSize; i < c.linkRes.entries.size(); i++ { + entry, ok := c.linkRes.entries.entry(i) if !ok { - return fmt.Errorf("c.store.entry(%d) not found", i) + return fmt.Errorf("c.linkRes.entries.entry(%d) not found", i) } wantEntry := NeighborEntry{ Addr: entry.Addr, @@ -531,15 +517,15 @@ func (c *testContext) overflowCache(opts overflowOptions) error { wantUnsortedEntries = append(wantUnsortedEntries, wantEntry) } - if diff := cmp.Diff(c.neigh.entries(), wantUnsortedEntries, entryDiffOptsWithSort()...); diff != "" { - return fmt.Errorf("neighbor entries mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantUnsortedEntries, c.linkRes.neigh.entries(), entryDiffOptsWithSort()...); diff != "" { + return fmt.Errorf("neighbor entries mismatch (-want, +got):\n%s", diff) } // No more events should have been dispatched. 
c.nudDisp.mu.Lock() defer c.nudDisp.mu.Unlock() - if diff := cmp.Diff(c.nudDisp.events, []testEntryEventInfo(nil)); diff != "" { - return fmt.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff([]testEntryEventInfo(nil), c.nudDisp.events); diff != "" { + return fmt.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } return nil @@ -575,14 +561,15 @@ func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) { c := newTestContext(config) // Add a dynamic entry - entry, ok := c.store.entry(0) + entry, ok := c.linkRes.entries.entry(0) if !ok { - t.Fatal("c.store.entry(0) not found") + t.Fatal("c.linkRes.entries.entry(0) not found") } - if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock { - t.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + _, _, err := c.linkRes.neigh.entry(entry.Addr, "", nil) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Errorf("got c.linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } - c.clock.Advance(c.neigh.config().RetransmitTimer) + c.clock.Advance(c.linkRes.neigh.config().RetransmitTimer) wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, @@ -603,15 +590,15 @@ func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) { }, } c.nudDisp.mu.Lock() - diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...) 
c.nudDisp.events = nil c.nudDisp.mu.Unlock() if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } // Remove the entry - c.neigh.removeEntry(entry.Addr) + c.linkRes.neigh.removeEntry(entry.Addr) { wantEvents := []testEntryEventInfo{ @@ -626,11 +613,11 @@ func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) { }, } c.nudDisp.mu.Lock() - diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...) c.nudDisp.events = nil c.nudDisp.mu.Unlock() if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } } @@ -650,12 +637,12 @@ func TestNeighborCacheDuplicateStaticEntryWithSameLinkAddress(t *testing.T) { c := newTestContext(config) // Add a static entry - entry, ok := c.store.entry(0) + entry, ok := c.linkRes.entries.entry(0) if !ok { - t.Fatal("c.store.entry(0) not found") + t.Fatal("c.linkRes.entries.entry(0) not found") } staticLinkAddr := entry.LinkAddr + "static" - c.neigh.addStaticEntry(entry.Addr, staticLinkAddr) + c.linkRes.neigh.addStaticEntry(entry.Addr, staticLinkAddr) wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, @@ -668,21 +655,21 @@ func TestNeighborCacheDuplicateStaticEntryWithSameLinkAddress(t *testing.T) { }, } c.nudDisp.mu.Lock() - diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...) 
c.nudDisp.events = nil c.nudDisp.mu.Unlock() if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } // Remove the static entry that was just added - c.neigh.addStaticEntry(entry.Addr, staticLinkAddr) + c.linkRes.neigh.addStaticEntry(entry.Addr, staticLinkAddr) // No more events should have been dispatched. c.nudDisp.mu.Lock() defer c.nudDisp.mu.Unlock() - if diff := cmp.Diff(c.nudDisp.events, []testEntryEventInfo(nil)); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff([]testEntryEventInfo(nil), c.nudDisp.events); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } } @@ -694,12 +681,12 @@ func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T) c := newTestContext(config) // Add a static entry - entry, ok := c.store.entry(0) + entry, ok := c.linkRes.entries.entry(0) if !ok { - t.Fatal("c.store.entry(0) not found") + t.Fatal("c.linkRes.entries.entry(0) not found") } staticLinkAddr := entry.LinkAddr + "static" - c.neigh.addStaticEntry(entry.Addr, staticLinkAddr) + c.linkRes.neigh.addStaticEntry(entry.Addr, staticLinkAddr) wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, @@ -712,16 +699,16 @@ func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T) }, } c.nudDisp.mu.Lock() - diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...) 
c.nudDisp.events = nil c.nudDisp.mu.Unlock() if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } // Add a duplicate entry with a different link address staticLinkAddr += "duplicate" - c.neigh.addStaticEntry(entry.Addr, staticLinkAddr) + c.linkRes.neigh.addStaticEntry(entry.Addr, staticLinkAddr) { wantEvents := []testEntryEventInfo{ { @@ -736,8 +723,8 @@ func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T) } c.nudDisp.mu.Lock() defer c.nudDisp.mu.Unlock() - if diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } } } @@ -756,12 +743,12 @@ func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) { c := newTestContext(config) // Add a static entry - entry, ok := c.store.entry(0) + entry, ok := c.linkRes.entries.entry(0) if !ok { - t.Fatal("c.store.entry(0) not found") + t.Fatal("c.linkRes.entries.entry(0) not found") } staticLinkAddr := entry.LinkAddr + "static" - c.neigh.addStaticEntry(entry.Addr, staticLinkAddr) + c.linkRes.neigh.addStaticEntry(entry.Addr, staticLinkAddr) wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, @@ -774,15 +761,15 @@ func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) { }, } c.nudDisp.mu.Lock() - diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...) 
c.nudDisp.events = nil c.nudDisp.mu.Unlock() if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } // Remove the static entry that was just added - c.neigh.removeEntry(entry.Addr) + c.linkRes.neigh.removeEntry(entry.Addr) { wantEvents := []testEntryEventInfo{ { @@ -796,11 +783,11 @@ func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) { }, } c.nudDisp.mu.Lock() - diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...) c.nudDisp.events = nil c.nudDisp.mu.Unlock() if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } } @@ -826,12 +813,13 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) { c := newTestContext(config) // Add a dynamic entry - entry, ok := c.store.entry(0) + entry, ok := c.linkRes.entries.entry(0) if !ok { - t.Fatal("c.store.entry(0) not found") + t.Fatal("c.linkRes.entries.entry(0) not found") } - if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock { - t.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + _, _, err := c.linkRes.neigh.entry(entry.Addr, "", nil) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Errorf("got c.linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } c.clock.Advance(typicalLatency) wantEvents := []testEntryEventInfo{ @@ -854,16 +842,16 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) { }, } c.nudDisp.mu.Lock() - diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...) 
c.nudDisp.events = nil c.nudDisp.mu.Unlock() if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } // Override the entry with a static one using the same address staticLinkAddr := entry.LinkAddr + "static" - c.neigh.addStaticEntry(entry.Addr, staticLinkAddr) + c.linkRes.neigh.addStaticEntry(entry.Addr, staticLinkAddr) { wantEvents := []testEntryEventInfo{ { @@ -886,11 +874,11 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) { }, } c.nudDisp.mu.Lock() - diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...) c.nudDisp.events = nil c.nudDisp.mu.Unlock() if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } } @@ -918,22 +906,22 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { c := newTestContext(config) - entry, ok := c.store.entry(0) + entry, ok := c.linkRes.entries.entry(0) if !ok { - t.Fatal("c.store.entry(0) not found") + t.Fatal("c.linkRes.entries.entry(0) not found") } - c.neigh.addStaticEntry(entry.Addr, entry.LinkAddr) - e, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil) + c.linkRes.neigh.addStaticEntry(entry.Addr, entry.LinkAddr) + e, _, err := c.linkRes.neigh.entry(entry.Addr, "", nil) if err != nil { - t.Errorf("unexpected error from c.neigh.entry(%s, \"\", _, nil, nil): %s", entry.Addr, err) + t.Errorf("unexpected error from c.linkRes.neigh.entry(%s, \"\", nil): %s", entry.Addr, err) } want := NeighborEntry{ Addr: entry.Addr, LinkAddr: entry.LinkAddr, State: Static, } - if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { - t.Errorf("c.neigh.entry(%s, \"\", _, nil, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) + if diff := cmp.Diff(want, e, entryDiffOpts()...); diff != "" { + 
t.Errorf("c.linkRes.neigh.entry(%s, \"\", nil) mismatch (-want, +got):\n%s", entry.Addr, diff) } wantEvents := []testEntryEventInfo{ @@ -948,11 +936,11 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { }, } c.nudDisp.mu.Lock() - diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...) c.nudDisp.events = nil c.nudDisp.mu.Unlock() if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } opts := overflowOptions{ @@ -975,22 +963,16 @@ func TestNeighborCacheClear(t *testing.T) { nudDisp := testNUDDispatcher{} clock := faketime.NewManualClock() - neigh := newTestNeighborCache(&nudDisp, config, clock) - store := newTestEntryStore() - linkRes := &testNeighborResolver{ - clock: clock, - neigh: neigh, - entries: store, - delay: typicalLatency, - } + linkRes := newTestNeighborResolver(&nudDisp, config, clock) // Add a dynamic entry. - entry, ok := store.entry(0) + entry, ok := linkRes.entries.entry(0) if !ok { - t.Fatal("store.entry(0) not found") + t.Fatal("linkRes.entries.entry(0) not found") } - if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + _, _, err := linkRes.neigh.entry(entry.Addr, "", nil) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Errorf("got linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } clock.Advance(typicalLatency) @@ -1014,15 +996,15 @@ func TestNeighborCacheClear(t *testing.T) { }, } nudDisp.mu.Lock() - diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...) 
nudDisp.events = nil nudDisp.mu.Unlock() if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } // Add a static entry. - neigh.addStaticEntry(entryTestAddr1, entryTestLinkAddr1) + linkRes.neigh.addStaticEntry(entryTestAddr1, entryTestLinkAddr1) { wantEvents := []testEntryEventInfo{ @@ -1037,16 +1019,16 @@ func TestNeighborCacheClear(t *testing.T) { }, } nudDisp.mu.Lock() - diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...) nudDisp.events = nil nudDisp.mu.Unlock() if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } } // Clear should remove both dynamic and static entries. - neigh.clear() + linkRes.neigh.clear() // Remove events dispatched from clear() have no deterministic order so they // need to be sorted beforehand. 
@@ -1072,8 +1054,8 @@ func TestNeighborCacheClear(t *testing.T) { } nudDisp.mu.Lock() defer nudDisp.mu.Unlock() - if diff := cmp.Diff(nudDisp.events, wantUnsortedEvents, eventDiffOptsWithSort()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantUnsortedEvents, nudDisp.events, eventDiffOptsWithSort()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } } @@ -1090,12 +1072,13 @@ func TestNeighborCacheClearThenOverflow(t *testing.T) { c := newTestContext(config) // Add a dynamic entry - entry, ok := c.store.entry(0) + entry, ok := c.linkRes.entries.entry(0) if !ok { - t.Fatal("c.store.entry(0) not found") + t.Fatal("c.linkRes.entries.entry(0) not found") } - if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock { - t.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + _, _, err := c.linkRes.neigh.entry(entry.Addr, "", nil) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Errorf("got c.linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } c.clock.Advance(typicalLatency) wantEvents := []testEntryEventInfo{ @@ -1118,15 +1101,15 @@ func TestNeighborCacheClearThenOverflow(t *testing.T) { }, } c.nudDisp.mu.Lock() - diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...) c.nudDisp.events = nil c.nudDisp.mu.Unlock() if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } // Clear the cache. - c.neigh.clear() + c.linkRes.neigh.clear() { wantEvents := []testEntryEventInfo{ { @@ -1140,11 +1123,11 @@ func TestNeighborCacheClearThenOverflow(t *testing.T) { }, } c.nudDisp.mu.Lock() - diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) 
+ diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...) c.nudDisp.events = nil c.nudDisp.mu.Unlock() if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } } @@ -1165,18 +1148,11 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { nudDisp := testNUDDispatcher{} clock := faketime.NewManualClock() - neigh := newTestNeighborCache(&nudDisp, config, clock) - store := newTestEntryStore() - linkRes := &testNeighborResolver{ - clock: clock, - neigh: neigh, - entries: store, - delay: typicalLatency, - } + linkRes := newTestNeighborResolver(&nudDisp, config, clock) - frequentlyUsedEntry, ok := store.entry(0) + frequentlyUsedEntry, ok := linkRes.entries.entry(0) if !ok { - t.Fatal("store.entry(0) not found") + t.Fatal("linkRes.entries.entry(0) not found") } // The following logic is very similar to overflowCache, but @@ -1184,23 +1160,23 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { // Fill the neighbor cache to capacity for i := 0; i < neighborCacheSize; i++ { - entry, ok := store.entry(i) + entry, ok := linkRes.entries.entry(i) if !ok { - t.Fatalf("store.entry(%d) not found", i) + t.Fatalf("linkRes.entries.entry(%d) not found", i) } - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" { t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) - if err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Errorf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } clock.Advance(typicalLatency) select { case <-ch: default: - 
t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) + t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr) } wantEvents := []testEntryEventInfo{ { @@ -1222,47 +1198,47 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { }, } nudDisp.mu.Lock() - diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...) nudDisp.events = nil nudDisp.mu.Unlock() if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } } // Keep adding more entries - for i := neighborCacheSize; i < store.size(); i++ { + for i := neighborCacheSize; i < linkRes.entries.size(); i++ { // Periodically refresh the frequently used entry if i%(neighborCacheSize/2) == 0 { - if _, _, err := neigh.entry(frequentlyUsedEntry.Addr, "", linkRes, nil); err != nil { - t.Errorf("unexpected error from neigh.entry(%s, '', _, nil, nil): %s", frequentlyUsedEntry.Addr, err) + if _, _, err := linkRes.neigh.entry(frequentlyUsedEntry.Addr, "", nil); err != nil { + t.Errorf("unexpected error from linkRes.neigh.entry(%s, '', nil): %s", frequentlyUsedEntry.Addr, err) } } - entry, ok := store.entry(i) + entry, ok := linkRes.entries.entry(i) if !ok { - t.Fatalf("store.entry(%d) not found", i) + t.Fatalf("linkRes.entries.entry(%d) not found", i) } - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" { t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) - if err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + if _, ok := 
err.(*tcpip.ErrWouldBlock); !ok { + t.Errorf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } clock.Advance(typicalLatency) select { case <-ch: default: - t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) + t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr) } // An entry should have been removed, as per the LRU eviction strategy - removedEntry, ok := store.entry(i - neighborCacheSize + 1) + removedEntry, ok := linkRes.entries.entry(i - neighborCacheSize + 1) if !ok { - t.Fatalf("store.entry(%d) not found", i-neighborCacheSize+1) + t.Fatalf("linkRes.entries.entry(%d) not found", i-neighborCacheSize+1) } wantEvents := []testEntryEventInfo{ { @@ -1293,11 +1269,11 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { }, } nudDisp.mu.Lock() - diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...) 
nudDisp.events = nil nudDisp.mu.Unlock() if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } } @@ -1312,10 +1288,10 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { }, } - for i := store.size() - neighborCacheSize + 1; i < store.size(); i++ { - entry, ok := store.entry(i) + for i := linkRes.entries.size() - neighborCacheSize + 1; i < linkRes.entries.size(); i++ { + entry, ok := linkRes.entries.entry(i) if !ok { - t.Fatalf("store.entry(%d) not found", i) + t.Fatalf("linkRes.entries.entry(%d) not found", i) } wantEntry := NeighborEntry{ Addr: entry.Addr, @@ -1325,15 +1301,15 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { wantUnsortedEntries = append(wantUnsortedEntries, wantEntry) } - if diff := cmp.Diff(neigh.entries(), wantUnsortedEntries, entryDiffOptsWithSort()...); diff != "" { - t.Errorf("neighbor entries mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantUnsortedEntries, linkRes.neigh.entries(), entryDiffOptsWithSort()...); diff != "" { + t.Errorf("neighbor entries mismatch (-want, +got):\n%s", diff) } // No more events should have been dispatched. 
nudDisp.mu.Lock() defer nudDisp.mu.Unlock() - if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff([]testEntryEventInfo(nil), nudDisp.events); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } } @@ -1344,24 +1320,19 @@ func TestNeighborCacheConcurrent(t *testing.T) { nudDisp := testNUDDispatcher{} clock := faketime.NewManualClock() - neigh := newTestNeighborCache(&nudDisp, config, clock) - store := newTestEntryStore() - linkRes := &testNeighborResolver{ - clock: clock, - neigh: neigh, - entries: store, - delay: typicalLatency, - } + linkRes := newTestNeighborResolver(&nudDisp, config, clock) - storeEntries := store.entries() + storeEntries := linkRes.entries.entries() for _, entry := range storeEntries { var wg sync.WaitGroup for r := 0; r < concurrentProcesses; r++ { wg.Add(1) go func(entry NeighborEntry) { defer wg.Done() - if e, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != nil && err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, '', _, nil, nil) = (%+v, _, %s), want (_, _, nil) or (_, _, %s)", entry.Addr, e, err, tcpip.ErrWouldBlock) + switch e, _, err := linkRes.neigh.entry(entry.Addr, "", nil); err.(type) { + case nil, *tcpip.ErrWouldBlock: + default: + t.Errorf("got linkRes.neigh.entry(%s, '', nil) = (%+v, _, %s), want (_, _, nil) or (_, _, %s)", entry.Addr, e, err, &tcpip.ErrWouldBlock{}) } }(entry) } @@ -1379,10 +1350,10 @@ func TestNeighborCacheConcurrent(t *testing.T) { // The order of entries reported by entries() is nondeterministic, so entries // have to be sorted before comparison. 
var wantUnsortedEntries []NeighborEntry - for i := store.size() - neighborCacheSize; i < store.size(); i++ { - entry, ok := store.entry(i) + for i := linkRes.entries.size() - neighborCacheSize; i < linkRes.entries.size(); i++ { + entry, ok := linkRes.entries.entry(i) if !ok { - t.Errorf("store.entry(%d) not found", i) + t.Errorf("linkRes.entries.entry(%d) not found", i) } wantEntry := NeighborEntry{ Addr: entry.Addr, @@ -1392,8 +1363,8 @@ func TestNeighborCacheConcurrent(t *testing.T) { wantUnsortedEntries = append(wantUnsortedEntries, wantEntry) } - if diff := cmp.Diff(neigh.entries(), wantUnsortedEntries, entryDiffOptsWithSort()...); diff != "" { - t.Errorf("neighbor entries mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantUnsortedEntries, linkRes.neigh.entries(), entryDiffOptsWithSort()...); diff != "" { + t.Errorf("neighbor entries mismatch (-want, +got):\n%s", diff) } } @@ -1402,41 +1373,34 @@ func TestNeighborCacheReplace(t *testing.T) { nudDisp := testNUDDispatcher{} clock := faketime.NewManualClock() - neigh := newTestNeighborCache(&nudDisp, config, clock) - store := newTestEntryStore() - linkRes := &testNeighborResolver{ - clock: clock, - neigh: neigh, - entries: store, - delay: typicalLatency, - } + linkRes := newTestNeighborResolver(&nudDisp, config, clock) // Add an entry - entry, ok := store.entry(0) + entry, ok := linkRes.entries.entry(0) if !ok { - t.Fatal("store.entry(0) not found") + t.Fatal("linkRes.entries.entry(0) not found") } - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" { t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) - if err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + if _, ok := 
err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } clock.Advance(typicalLatency) select { case <-ch: default: - t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) + t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr) } // Verify the entry exists { - e, _, err := neigh.entry(entry.Addr, "", linkRes, nil) + e, _, err := linkRes.neigh.entry(entry.Addr, "", nil) if err != nil { - t.Errorf("unexpected error from neigh.entry(%s, '', _, _, nil): %s", entry.Addr, err) + t.Errorf("unexpected error from linkRes.neigh.entry(%s, '', _, _, nil): %s", entry.Addr, err) } if t.Failed() { t.FailNow() @@ -1446,22 +1410,22 @@ func TestNeighborCacheReplace(t *testing.T) { LinkAddr: entry.LinkAddr, State: Reachable, } - if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, '', _, _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) + if diff := cmp.Diff(want, e, entryDiffOpts()...); diff != "" { + t.Errorf("linkRes.neigh.entry(%s, '', _, _, nil) mismatch (-want, +got):\n%s", entry.Addr, diff) } } // Notify of a link address change var updatedLinkAddr tcpip.LinkAddress { - entry, ok := store.entry(1) + entry, ok := linkRes.entries.entry(1) if !ok { - t.Fatal("store.entry(1) not found") + t.Fatal("linkRes.entries.entry(1) not found") } updatedLinkAddr = entry.LinkAddr } - store.set(0, updatedLinkAddr) - neigh.HandleConfirmation(entry.Addr, updatedLinkAddr, ReachabilityConfirmationFlags{ + linkRes.entries.set(0, updatedLinkAddr) + linkRes.neigh.handleConfirmation(entry.Addr, updatedLinkAddr, ReachabilityConfirmationFlags{ Solicited: false, Override: true, IsRouter: false, @@ -1471,35 +1435,35 @@ func TestNeighborCacheReplace(t *testing.T) { // // Verify the entry's new link address and the new state. 
{ - e, _, err := neigh.entry(entry.Addr, "", linkRes, nil) + e, _, err := linkRes.neigh.entry(entry.Addr, "", nil) if err != nil { - t.Fatalf("neigh.entry(%s, '', _, nil, nil): %s", entry.Addr, err) + t.Fatalf("linkRes.neigh.entry(%s, '', nil): %s", entry.Addr, err) } want := NeighborEntry{ Addr: entry.Addr, LinkAddr: updatedLinkAddr, State: Delay, } - if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) + if diff := cmp.Diff(want, e, entryDiffOpts()...); diff != "" { + t.Errorf("linkRes.neigh.entry(%s, '', nil) mismatch (-want, +got):\n%s", entry.Addr, diff) } clock.Advance(config.DelayFirstProbeTime + typicalLatency) } // Verify that the neighbor is now reachable. { - e, _, err := neigh.entry(entry.Addr, "", linkRes, nil) + e, _, err := linkRes.neigh.entry(entry.Addr, "", nil) clock.Advance(typicalLatency) if err != nil { - t.Errorf("unexpected error from neigh.entry(%s, '', _, nil, nil): %s", entry.Addr, err) + t.Errorf("unexpected error from linkRes.neigh.entry(%s, '', nil): %s", entry.Addr, err) } want := NeighborEntry{ Addr: entry.Addr, LinkAddr: updatedLinkAddr, State: Reachable, } - if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) + if diff := cmp.Diff(want, e, entryDiffOpts()...); diff != "" { + t.Errorf("linkRes.neigh.entry(%s, '', nil) mismatch (-want, +got):\n%s", entry.Addr, diff) } } } @@ -1509,54 +1473,47 @@ func TestNeighborCacheResolutionFailed(t *testing.T) { nudDisp := testNUDDispatcher{} clock := faketime.NewManualClock() - neigh := newTestNeighborCache(&nudDisp, config, clock) - store := newTestEntryStore() + linkRes := newTestNeighborResolver(&nudDisp, config, clock) var requestCount uint32 - linkRes := &testNeighborResolver{ - clock: clock, - neigh: neigh, - entries: store, - delay: typicalLatency, - onLinkAddressRequest: func() { - 
atomic.AddUint32(&requestCount, 1) - }, + linkRes.onLinkAddressRequest = func() { + atomic.AddUint32(&requestCount, 1) } - entry, ok := store.entry(0) + entry, ok := linkRes.entries.entry(0) if !ok { - t.Fatal("store.entry(0) not found") + t.Fatal("linkRes.entries.entry(0) not found") } // First, sanity check that resolution is working { - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" { t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) - if err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } clock.Advance(typicalLatency) select { case <-ch: default: - t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) + t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr) } } - got, _, err := neigh.entry(entry.Addr, "", linkRes, nil) + got, _, err := linkRes.neigh.entry(entry.Addr, "", nil) if err != nil { - t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil, nil): %s", entry.Addr, err) + t.Fatalf("unexpected error from linkRes.neigh.entry(%s, '', nil): %s", entry.Addr, err) } want := NeighborEntry{ Addr: entry.Addr, LinkAddr: entry.LinkAddr, State: Reachable, } - if diff := cmp.Diff(got, want, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) + if diff := cmp.Diff(want, got, entryDiffOpts()...); diff != "" { + t.Errorf("linkRes.neigh.entry(%s, '', nil) mismatch (-want, +got):\n%s", entry.Addr, diff) 
} // Verify address resolution fails for an unknown address. @@ -1564,24 +1521,24 @@ func TestNeighborCacheResolutionFailed(t *testing.T) { entry.Addr += "2" { - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { if diff := cmp.Diff(LinkResolutionResult{Success: false}, r); diff != "" { t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) - if err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } waitFor := config.DelayFirstProbeTime + typicalLatency*time.Duration(config.MaxMulticastProbes) clock.Advance(waitFor) select { case <-ch: default: - t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) + t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr) } } - maxAttempts := neigh.config().MaxUnicastProbes + maxAttempts := linkRes.neigh.config().MaxUnicastProbes if got, want := atomic.LoadUint32(&requestCount)-before, maxAttempts; got != want { t.Errorf("got link address request count = %d, want = %d", got, want) } @@ -1595,27 +1552,22 @@ func TestNeighborCacheResolutionTimeout(t *testing.T) { config.RetransmitTimer = time.Millisecond // small enough to cause timeout clock := faketime.NewManualClock() - neigh := newTestNeighborCache(nil, config, clock) - store := newTestEntryStore() - linkRes := &testNeighborResolver{ - clock: clock, - neigh: neigh, - entries: store, - delay: time.Minute, // large enough to cause timeout - } + linkRes := newTestNeighborResolver(nil, config, clock) + // large enough to cause timeout + linkRes.delay = time.Minute - entry, ok := store.entry(0) + entry, 
ok := linkRes.entries.entry(0) if !ok { - t.Fatal("store.entry(0) not found") + t.Fatal("linkRes.entries.entry(0) not found") } - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { if diff := cmp.Diff(LinkResolutionResult{Success: false}, r); diff != "" { t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) - if err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes) clock.Advance(waitFor) @@ -1623,7 +1575,7 @@ func TestNeighborCacheResolutionTimeout(t *testing.T) { select { case <-ch: default: - t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) + t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr) } } @@ -1632,31 +1584,24 @@ func TestNeighborCacheResolutionTimeout(t *testing.T) { func TestNeighborCacheRetryResolution(t *testing.T) { config := DefaultNUDConfigurations() clock := faketime.NewManualClock() - neigh := newTestNeighborCache(nil, config, clock) - store := newTestEntryStore() - linkRes := &testNeighborResolver{ - clock: clock, - neigh: neigh, - entries: store, - delay: typicalLatency, - // Simulate a faulty link. - dropReplies: true, - } + linkRes := newTestNeighborResolver(nil, config, clock) + // Simulate a faulty link. + linkRes.dropReplies = true - entry, ok := store.entry(0) + entry, ok := linkRes.entries.entry(0) if !ok { - t.Fatal("store.entry(0) not found") + t.Fatal("linkRes.entries.entry(0) not found") } // Perform address resolution with a faulty link, which will fail. 
{ - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { if diff := cmp.Diff(LinkResolutionResult{Success: false}, r); diff != "" { t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) - if err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes) clock.Advance(waitFor) @@ -1664,7 +1609,7 @@ func TestNeighborCacheRetryResolution(t *testing.T) { select { case <-ch: default: - t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) + t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr) } } @@ -1676,20 +1621,20 @@ func TestNeighborCacheRetryResolution(t *testing.T) { State: Failed, }, } - if diff := cmp.Diff(neigh.entries(), wantEntries, entryDiffOptsWithSort()...); diff != "" { + if diff := cmp.Diff(linkRes.neigh.entries(), wantEntries, entryDiffOptsWithSort()...); diff != "" { t.Fatalf("neighbor entries mismatch (-got, +want):\n%s", diff) } // Retry address resolution with a working link. 
linkRes.dropReplies = false { - incompleteEntry, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + incompleteEntry, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" { t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) - if err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got linkRes.neigh.entry(%s, '', _) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } if incompleteEntry.State != Incomplete { t.Fatalf("got entry.State = %s, want = %s", incompleteEntry.State, Incomplete) @@ -1701,9 +1646,9 @@ func TestNeighborCacheRetryResolution(t *testing.T) { if !ok { t.Fatal("expected successful address resolution") } - reachableEntry, _, err := neigh.entry(entry.Addr, "", linkRes, nil) + reachableEntry, _, err := linkRes.neigh.entry(entry.Addr, "", nil) if err != nil { - t.Fatalf("neigh.entry(%s, '', _, _, nil): %v", entry.Addr, err) + t.Fatalf("linkRes.neigh.entry(%s, '', _, _, nil): %v", entry.Addr, err) } if reachableEntry.Addr != entry.Addr { t.Fatalf("got entry.Addr = %s, want = %s", reachableEntry.Addr, entry.Addr) @@ -1715,7 +1660,7 @@ func TestNeighborCacheRetryResolution(t *testing.T) { t.Fatalf("got entry.State = %s, want = %s", reachableEntry.State.String(), Reachable.String()) } default: - t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) + t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr) } } } @@ -1724,42 +1669,36 @@ func BenchmarkCacheClear(b *testing.B) { b.StopTimer() config := DefaultNUDConfigurations() clock := &tcpip.StdClock{} - neigh := newTestNeighborCache(nil, config, clock) - store := 
newTestEntryStore() - linkRes := &testNeighborResolver{ - clock: clock, - neigh: neigh, - entries: store, - delay: 0, - } + linkRes := newTestNeighborResolver(nil, config, clock) + linkRes.delay = 0 // Clear for every possible size of the cache for cacheSize := 0; cacheSize < neighborCacheSize; cacheSize++ { // Fill the neighbor cache to capacity. for i := 0; i < cacheSize; i++ { - entry, ok := store.entry(i) + entry, ok := linkRes.entries.entry(i) if !ok { - b.Fatalf("store.entry(%d) not found", i) + b.Fatalf("linkRes.entries.entry(%d) not found", i) } - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" { b.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) - if err != tcpip.ErrWouldBlock { - b.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + b.Fatalf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } select { case <-ch: default: - b.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) + b.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr) } } b.StartTimer() - neigh.clear() + linkRes.neigh.clear() b.StopTimer() } } diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go index 53ac9bb6e..b05f96d4f 100644 --- a/pkg/tcpip/stack/neighbor_entry.go +++ b/pkg/tcpip/stack/neighbor_entry.go @@ -77,11 +77,7 @@ const ( type neighborEntry struct { neighborEntryEntry - nic *NIC - - // linkRes provides the functionality to send reachability probes, used in - // Neighbor Unreachability Detection. 
- linkRes LinkAddressResolver + cache *neighborCache // nudState points to the Neighbor Unreachability Detection configuration. nudState *NUDState @@ -106,10 +102,9 @@ type neighborEntry struct { // state, Unknown. Transition out of Unknown by calling either // `handlePacketQueuedLocked` or `handleProbeLocked` on the newly created // neighborEntry. -func newNeighborEntry(nic *NIC, remoteAddr tcpip.Address, nudState *NUDState, linkRes LinkAddressResolver) *neighborEntry { +func newNeighborEntry(cache *neighborCache, remoteAddr tcpip.Address, nudState *NUDState) *neighborEntry { return &neighborEntry{ - nic: nic, - linkRes: linkRes, + cache: cache, nudState: nudState, neigh: NeighborEntry{ Addr: remoteAddr, @@ -121,18 +116,18 @@ func newNeighborEntry(nic *NIC, remoteAddr tcpip.Address, nudState *NUDState, li // newStaticNeighborEntry creates a neighbor cache entry starting at the // Static state. The entry can only transition out of Static by directly // calling `setStateLocked`. -func newStaticNeighborEntry(nic *NIC, addr tcpip.Address, linkAddr tcpip.LinkAddress, state *NUDState) *neighborEntry { +func newStaticNeighborEntry(cache *neighborCache, addr tcpip.Address, linkAddr tcpip.LinkAddress, state *NUDState) *neighborEntry { entry := NeighborEntry{ Addr: addr, LinkAddr: linkAddr, State: Static, - UpdatedAtNanos: nic.stack.clock.NowNanoseconds(), + UpdatedAtNanos: cache.nic.stack.clock.NowNanoseconds(), } - if nic.stack.nudDisp != nil { - nic.stack.nudDisp.OnNeighborAdded(nic.id, entry) + if nudDisp := cache.nic.stack.nudDisp; nudDisp != nil { + nudDisp.OnNeighborAdded(cache.nic.id, entry) } return &neighborEntry{ - nic: nic, + cache: cache, nudState: state, neigh: entry, } @@ -158,7 +153,7 @@ func (e *neighborEntry) notifyCompletionLocked(succeeded bool) { // is resolved (which ends up obtaining the entry's lock) while holding the // link resolution queue's lock. Dequeuing packets in a new goroutine avoids // a lock ordering violation. 
- go e.nic.linkResQueue.dequeue(ch, e.neigh.LinkAddr, succeeded) + go e.cache.nic.linkResQueue.dequeue(ch, e.neigh.LinkAddr, succeeded) } } @@ -167,8 +162,8 @@ func (e *neighborEntry) notifyCompletionLocked(succeeded bool) { // // Precondition: e.mu MUST be locked. func (e *neighborEntry) dispatchAddEventLocked() { - if nudDisp := e.nic.stack.nudDisp; nudDisp != nil { - nudDisp.OnNeighborAdded(e.nic.id, e.neigh) + if nudDisp := e.cache.nic.stack.nudDisp; nudDisp != nil { + nudDisp.OnNeighborAdded(e.cache.nic.id, e.neigh) } } @@ -177,8 +172,8 @@ func (e *neighborEntry) dispatchAddEventLocked() { // // Precondition: e.mu MUST be locked. func (e *neighborEntry) dispatchChangeEventLocked() { - if nudDisp := e.nic.stack.nudDisp; nudDisp != nil { - nudDisp.OnNeighborChanged(e.nic.id, e.neigh) + if nudDisp := e.cache.nic.stack.nudDisp; nudDisp != nil { + nudDisp.OnNeighborChanged(e.cache.nic.id, e.neigh) } } @@ -187,8 +182,8 @@ func (e *neighborEntry) dispatchChangeEventLocked() { // // Precondition: e.mu MUST be locked. func (e *neighborEntry) dispatchRemoveEventLocked() { - if nudDisp := e.nic.stack.nudDisp; nudDisp != nil { - nudDisp.OnNeighborRemoved(e.nic.id, e.neigh) + if nudDisp := e.cache.nic.stack.nudDisp; nudDisp != nil { + nudDisp.OnNeighborRemoved(e.cache.nic.id, e.neigh) } } @@ -206,7 +201,7 @@ func (e *neighborEntry) cancelJobLocked() { // // Precondition: e.mu MUST be locked. 
func (e *neighborEntry) removeLocked() { - e.neigh.UpdatedAtNanos = e.nic.stack.clock.NowNanoseconds() + e.neigh.UpdatedAtNanos = e.cache.nic.stack.clock.NowNanoseconds() e.dispatchRemoveEventLocked() e.cancelJobLocked() e.notifyCompletionLocked(false /* succeeded */) @@ -222,7 +217,7 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { prev := e.neigh.State e.neigh.State = next - e.neigh.UpdatedAtNanos = e.nic.stack.clock.NowNanoseconds() + e.neigh.UpdatedAtNanos = e.cache.nic.stack.clock.NowNanoseconds() config := e.nudState.Config() switch next { @@ -230,14 +225,14 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { panic(fmt.Sprintf("should never transition to Incomplete with setStateLocked; neigh = %#v, prev state = %s", e.neigh, prev)) case Reachable: - e.job = e.nic.stack.newJob(&e.mu, func() { + e.job = e.cache.nic.stack.newJob(&e.mu, func() { e.setStateLocked(Stale) e.dispatchChangeEventLocked() }) e.job.Schedule(e.nudState.ReachableTime()) case Delay: - e.job = e.nic.stack.newJob(&e.mu, func() { + e.job = e.cache.nic.stack.newJob(&e.mu, func() { e.setStateLocked(Probe) e.dispatchChangeEventLocked() }) @@ -254,14 +249,14 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { return } - if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, "" /* localAddr */, e.neigh.LinkAddr, e.nic); err != nil { + if err := e.cache.linkRes.LinkAddressRequest(e.neigh.Addr, "" /* localAddr */, e.neigh.LinkAddr); err != nil { e.dispatchRemoveEventLocked() e.setStateLocked(Failed) return } retryCounter++ - e.job = e.nic.stack.newJob(&e.mu, sendUnicastProbe) + e.job = e.cache.nic.stack.newJob(&e.mu, sendUnicastProbe) e.job.Schedule(config.RetransmitTimer) } @@ -269,7 +264,7 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { // for finishing the state transition. This is necessary to avoid // deadlock where sending and processing probes are done synchronously, // such as loopback and integration tests. 
- e.job = e.nic.stack.newJob(&e.mu, sendUnicastProbe) + e.job = e.cache.nic.stack.newJob(&e.mu, sendUnicastProbe) e.job.Schedule(immediateDuration) case Failed: @@ -292,12 +287,12 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) { switch e.neigh.State { case Failed: - e.nic.stats.Neighbor.FailedEntryLookups.Increment() + e.cache.nic.stats.Neighbor.FailedEntryLookups.Increment() fallthrough case Unknown: e.neigh.State = Incomplete - e.neigh.UpdatedAtNanos = e.nic.stack.clock.NowNanoseconds() + e.neigh.UpdatedAtNanos = e.cache.nic.stack.clock.NowNanoseconds() e.dispatchAddEventLocked() @@ -340,7 +335,7 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) { // address SHOULD be placed in the IP Source Address of the outgoing // solicitation. // - if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, localAddr, "", e.nic); err != nil { + if err := e.cache.linkRes.LinkAddressRequest(e.neigh.Addr, localAddr, ""); err != nil { // There is no need to log the error here; the NUD implementation may // assume a working link. A valid link should be the responsibility of // the NIC/stack.LinkEndpoint. @@ -350,7 +345,7 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) { } retryCounter++ - e.job = e.nic.stack.newJob(&e.mu, sendMulticastProbe) + e.job = e.cache.nic.stack.newJob(&e.mu, sendMulticastProbe) e.job.Schedule(config.RetransmitTimer) } @@ -358,7 +353,7 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) { // for finishing the state transition. This is necessary to avoid // deadlock where sending and processing probes are done synchronously, // such as loopback and integration tests. 
- e.job = e.nic.stack.newJob(&e.mu, sendMulticastProbe) + e.job = e.cache.nic.stack.newJob(&e.mu, sendMulticastProbe) e.job.Schedule(immediateDuration) case Stale: @@ -504,7 +499,7 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla // // TODO(gvisor.dev/issue/4085): Remove the special casing we do for IPv6 // here. - ep, ok := e.nic.networkEndpoints[header.IPv6ProtocolNumber] + ep, ok := e.cache.nic.networkEndpoints[header.IPv6ProtocolNumber] if !ok { panic(fmt.Sprintf("have a neighbor entry for an IPv6 router but no IPv6 network endpoint")) } diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go index ec34ffa5a..57cfbdb8b 100644 --- a/pkg/tcpip/stack/neighbor_entry_test.go +++ b/pkg/tcpip/stack/neighbor_entry_test.go @@ -193,7 +193,7 @@ func (p entryTestProbeInfo) String() string { // LinkAddressRequest sends a request for the LinkAddress of addr. Broadcasts // to the local network if linkAddr is the zero value. -func (r *entryTestLinkResolver) LinkAddressRequest(targetAddr, localAddr tcpip.Address, linkAddr tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error { +func (r *entryTestLinkResolver) LinkAddressRequest(targetAddr, localAddr tcpip.Address, linkAddr tcpip.LinkAddress) tcpip.Error { p := entryTestProbeInfo{ RemoteAddress: targetAddr, RemoteLinkAddress: linkAddr, @@ -230,22 +230,30 @@ func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *e }, stats: makeNICStats(), } + netEP := (&testIPv6Protocol{}).NewEndpoint(&nic, nil) nic.networkEndpoints = map[tcpip.NetworkProtocolNumber]NetworkEndpoint{ - header.IPv6ProtocolNumber: (&testIPv6Protocol{}).NewEndpoint(&nic, nil, nil, nil), + header.IPv6ProtocolNumber: netEP, } rng := rand.New(rand.NewSource(time.Now().UnixNano())) nudState := NewNUDState(c, rng) - linkRes := entryTestLinkResolver{} - entry := newNeighborEntry(&nic, entryTestAddr1 /* remoteAddr */, nudState, &linkRes) - + var linkRes entryTestLinkResolver // 
Stub out the neighbor cache to verify deletion from the cache. - nic.neigh = &neighborCache{ - nic: &nic, - state: nudState, - cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize), + neigh := &neighborCache{ + nic: &nic, + state: nudState, + linkRes: &linkRes, + cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize), + } + l := linkResolver{ + resolver: &linkRes, + neighborTable: neigh, + } + entry := newNeighborEntry(neigh, entryTestAddr1 /* remoteAddr */, nudState) + neigh.cache[entryTestAddr1] = entry + nic.linkAddrResolvers = map[tcpip.NetworkProtocolNumber]linkResolver{ + header.IPv6ProtocolNumber: l, } - nic.neigh.cache[entryTestAddr1] = entry return entry, &disp, &linkRes, clock } @@ -266,16 +274,16 @@ func TestEntryInitiallyUnknown(t *testing.T) { // No probes should have been sent. linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, []entryTestProbeInfo(nil)) + diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } // No events should have been dispatched. nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff([]testEntryEventInfo(nil), nudDisp.events); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -299,16 +307,16 @@ func TestEntryUnknownToUnknownWhenConfirmationWithUnknownAddress(t *testing.T) { // No probes should have been sent. 
linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, []entryTestProbeInfo(nil)) + diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } // No events should have been dispatched. nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff([]testEntryEventInfo(nil), nudDisp.events); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -333,10 +341,10 @@ func TestEntryUnknownToIncomplete(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } wantEvents := []testEntryEventInfo{ @@ -352,10 +360,10 @@ func TestEntryUnknownToIncomplete(t *testing.T) { } { nudDisp.mu.Lock() - diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...) nudDisp.mu.Unlock() if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } } } @@ -374,10 +382,10 @@ func TestEntryUnknownToStale(t *testing.T) { // No probes should have been sent. 
runImmediatelyScheduledJobs(clock) linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, []entryTestProbeInfo(nil)) + diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } wantEvents := []testEntryEventInfo{ @@ -392,8 +400,8 @@ func TestEntryUnknownToStale(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -427,11 +435,11 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.probes = nil linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -453,10 +461,10 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } } @@ -483,8 +491,8 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, 
eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() @@ -515,10 +523,10 @@ func TestEntryIncompleteToReachable(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -553,8 +561,8 @@ func TestEntryIncompleteToReachable(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -579,10 +587,10 @@ func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -620,8 +628,8 @@ func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -646,10 +654,10 @@ func TestEntryIncompleteToStaleWhenUnsolicitedConfirmation(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, 
linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -684,8 +692,8 @@ func TestEntryIncompleteToStaleWhenUnsolicitedConfirmation(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -710,10 +718,10 @@ func TestEntryIncompleteToStaleWhenProbe(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -744,8 +752,8 @@ func TestEntryIncompleteToStaleWhenProbe(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -785,10 +793,10 @@ func TestEntryIncompleteToFailed(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } wantEvents := []testEntryEventInfo{ @@ -812,8 +820,8 @@ func TestEntryIncompleteToFailed(t *testing.T) { }, } 
nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() @@ -835,7 +843,7 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) { c := DefaultNUDConfigurations() e, nudDisp, linkRes, clock := entryTestSetup(c) - ipv6EP := e.nic.networkEndpoints[header.IPv6ProtocolNumber].(*testIPv6Endpoint) + ipv6EP := e.cache.nic.networkEndpoints[header.IPv6ProtocolNumber].(*testIPv6Endpoint) e.mu.Lock() e.handlePacketQueuedLocked(entryTestAddr2) @@ -850,10 +858,10 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -903,8 +911,8 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() @@ -932,10 +940,10 @@ func TestEntryStaysReachableWhenProbeWithSameAddress(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, 
+got):\n%s", diff) } e.mu.Lock() @@ -977,8 +985,8 @@ func TestEntryStaysReachableWhenProbeWithSameAddress(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -1005,10 +1013,10 @@ func TestEntryReachableToStaleWhenTimeout(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -1054,8 +1062,8 @@ func TestEntryReachableToStaleWhenTimeout(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() @@ -1083,10 +1091,10 @@ func TestEntryReachableToStaleWhenProbeWithDifferentAddress(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -1134,8 +1142,8 @@ func TestEntryReachableToStaleWhenProbeWithDifferentAddress(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, 
+want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -1157,10 +1165,10 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddress(t *testing.T) }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -1212,8 +1220,8 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddress(t *testing.T) }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -1235,10 +1243,10 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddressAndOverride(t }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -1290,8 +1298,8 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddressAndOverride(t }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -1313,10 +1321,10 @@ func 
TestEntryStaysStaleWhenProbeWithSameAddress(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -1358,8 +1366,8 @@ func TestEntryStaysStaleWhenProbeWithSameAddress(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -1381,10 +1389,10 @@ func TestEntryStaleToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -1439,8 +1447,8 @@ func TestEntryStaleToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -1462,10 +1470,10 @@ func TestEntryStaleToReachableWhenSolicitedConfirmationWithoutAddress(t *testing }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver 
probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -1520,8 +1528,8 @@ func TestEntryStaleToReachableWhenSolicitedConfirmationWithoutAddress(t *testing }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -1543,10 +1551,10 @@ func TestEntryStaleToStaleWhenOverrideConfirmation(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -1601,8 +1609,8 @@ func TestEntryStaleToStaleWhenOverrideConfirmation(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -1624,10 +1632,10 @@ func TestEntryStaleToStaleWhenProbeUpdateAddress(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -1678,8 +1686,8 @@ func TestEntryStaleToStaleWhenProbeUpdateAddress(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := 
cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -1701,10 +1709,10 @@ func TestEntryStaleToDelay(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -1752,8 +1760,8 @@ func TestEntryStaleToDelay(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -1780,10 +1788,10 @@ func TestEntryDelayToReachableWhenUpperLevelConfirmation(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -1851,8 +1859,8 @@ func TestEntryDelayToReachableWhenUpperLevelConfirmation(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ 
-1880,10 +1888,10 @@ func TestEntryDelayToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -1958,8 +1966,8 @@ func TestEntryDelayToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -1987,10 +1995,10 @@ func TestEntryDelayToReachableWhenSolicitedConfirmationWithoutAddress(t *testing }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -2065,8 +2073,8 @@ func TestEntryDelayToReachableWhenSolicitedConfirmationWithoutAddress(t *testing }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -2088,10 +2096,10 @@ func TestEntryStaysDelayWhenOverrideConfirmationWithSameAddress(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) 
linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -2147,8 +2155,8 @@ func TestEntryStaysDelayWhenOverrideConfirmationWithSameAddress(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -2170,10 +2178,10 @@ func TestEntryDelayToStaleWhenProbeWithDifferentAddress(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -2231,8 +2239,8 @@ func TestEntryDelayToStaleWhenProbeWithDifferentAddress(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -2254,10 +2262,10 @@ func TestEntryDelayToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -2319,8 +2327,8 @@ func 
TestEntryDelayToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -2343,11 +2351,11 @@ func TestEntryDelayToProbe(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.probes = nil linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } } @@ -2372,10 +2380,10 @@ func TestEntryDelayToProbe(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } } @@ -2418,8 +2426,8 @@ func TestEntryDelayToProbe(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() @@ -2448,11 +2456,11 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.probes = nil linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address 
resolver probes mismatch (-want, +got):\n%s", diff) } } @@ -2474,10 +2482,10 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } } @@ -2539,8 +2547,8 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -2563,11 +2571,11 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.probes = nil linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } } @@ -2589,10 +2597,10 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } } @@ -2658,8 +2666,8 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher 
events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -2682,11 +2690,11 @@ func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.probes = nil linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } } @@ -2709,10 +2717,10 @@ func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } } @@ -2772,8 +2780,8 @@ func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -2806,10 +2814,10 @@ func TestEntryUnknownToStaleToProbeToReachable(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ -2878,8 +2886,8 @@ func 
TestEntryUnknownToStaleToProbeToReachable(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -2907,11 +2915,11 @@ func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.probes = nil linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } } @@ -2933,10 +2941,10 @@ func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } } @@ -3015,8 +3023,8 @@ func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -3044,11 +3052,11 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.probes = nil linkRes.mu.Unlock() if diff != "" { 
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } } @@ -3070,10 +3078,10 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } } @@ -3149,8 +3157,8 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -3178,11 +3186,11 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithoutAddress(t *testing }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.probes = nil linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } } @@ -3204,10 +3212,10 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithoutAddress(t *testing }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } } @@ -3283,8 +3291,8 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithoutAddress(t *testing }, } 
nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -3309,11 +3317,11 @@ func TestEntryProbeToFailed(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.probes = nil linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } } @@ -3336,11 +3344,11 @@ func TestEntryProbeToFailed(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.probes = nil linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probe #%d mismatch (-got, +want):\n%s", i+1, diff) + t.Fatalf("link address resolver probe #%d mismatch (-want, +got):\n%s", i+1, diff) } e.mu.Lock() @@ -3406,8 +3414,8 @@ func TestEntryProbeToFailed(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } @@ -3449,10 +3457,10 @@ func TestEntryFailedToIncomplete(t *testing.T) { }, } linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) + diff := cmp.Diff(wantProbes, linkRes.probes) linkRes.mu.Unlock() if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff) } e.mu.Lock() @@ 
-3498,8 +3506,8 @@ func TestEntryFailedToIncomplete(t *testing.T) { }, } nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } nudDisp.mu.Unlock() } diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 1bbfe6213..41a489047 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -16,7 +16,6 @@ package stack import ( "fmt" - "math/rand" "reflect" "sync/atomic" @@ -25,8 +24,37 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/header" ) +type neighborTable interface { + neighbors() ([]NeighborEntry, tcpip.Error) + addStaticEntry(tcpip.Address, tcpip.LinkAddress) + get(addr, localAddr tcpip.Address, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, tcpip.Error) + remove(tcpip.Address) tcpip.Error + removeAll() tcpip.Error + + handleProbe(tcpip.Address, tcpip.LinkAddress) + handleConfirmation(tcpip.Address, tcpip.LinkAddress, ReachabilityConfirmationFlags) + handleUpperLevelConfirmation(tcpip.Address) + + nudConfig() (NUDConfigurations, tcpip.Error) + setNUDConfig(NUDConfigurations) tcpip.Error +} + var _ NetworkInterface = (*NIC)(nil) +type linkResolver struct { + resolver LinkAddressResolver + + neighborTable neighborTable +} + +func (l *linkResolver) getNeighborLinkAddress(addr, localAddr tcpip.Address, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, tcpip.Error) { + return l.neighborTable.get(addr, localAddr, onResolve) +} + +func (l *linkResolver) confirmReachable(addr tcpip.Address) { + l.neighborTable.handleUpperLevelConfirmation(addr) +} + // NIC represents a "network interface card" to which the networking stack is // attached. 
type NIC struct { @@ -38,11 +66,11 @@ type NIC struct { context NICContext stats NICStats - neigh *neighborCache // The network endpoints themselves may be modified by calling the interface's // methods, but the map reference and entries must be constant. - networkEndpoints map[tcpip.NetworkProtocolNumber]NetworkEndpoint + networkEndpoints map[tcpip.NetworkProtocolNumber]NetworkEndpoint + linkAddrResolvers map[tcpip.NetworkProtocolNumber]linkResolver // enabled is set to 1 when the NIC is enabled and 0 when it is disabled. // @@ -53,8 +81,6 @@ type NIC struct { // complete. linkResQueue packetsPendingLinkResolution - linkAddrCache *linkAddrCache - mu struct { sync.RWMutex spoofing bool @@ -133,35 +159,18 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC nic := &NIC{ LinkEndpoint: ep, - stack: stack, - id: id, - name: name, - context: ctx, - stats: makeNICStats(), - networkEndpoints: make(map[tcpip.NetworkProtocolNumber]NetworkEndpoint), + stack: stack, + id: id, + name: name, + context: ctx, + stats: makeNICStats(), + networkEndpoints: make(map[tcpip.NetworkProtocolNumber]NetworkEndpoint), + linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]linkResolver), } nic.linkResQueue.init(nic) - nic.linkAddrCache = newLinkAddrCache(nic, ageLimit, resolutionTimeout, resolutionAttempts) nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber]*packetEndpointList) - // Check for Neighbor Unreachability Detection support. - var nud NUDHandler - if ep.Capabilities()&CapabilityResolutionRequired != 0 && len(stack.linkAddrResolvers) != 0 && stack.useNeighborCache { - rng := rand.New(rand.NewSource(stack.clock.NowNanoseconds())) - nic.neigh = &neighborCache{ - nic: nic, - state: NewNUDState(stack.nudConfigs, rng), - cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize), - } - - // An interface value that holds a nil pointer but non-nil type is not the - // same as the nil interface. 
Because of this, nud must only be assignd if - // nic.neigh is non-nil since a nil reference to a neighborCache is not - // valid. - // - // See https://golang.org/doc/faq#nil_error for more information. - nud = nic.neigh - } + resolutionRequired := ep.Capabilities()&CapabilityResolutionRequired != 0 // Register supported packet and network endpoint protocols. for _, netProto := range header.Ethertypes { @@ -170,7 +179,32 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC for _, netProto := range stack.networkProtocols { netNum := netProto.Number() nic.mu.packetEPs[netNum] = new(packetEndpointList) - nic.networkEndpoints[netNum] = netProto.NewEndpoint(nic, nic.linkAddrCache, nud, nic) + + netEP := netProto.NewEndpoint(nic, nic) + nic.networkEndpoints[netNum] = netEP + + if resolutionRequired { + if r, ok := netEP.(LinkAddressResolver); ok { + l := linkResolver{ + resolver: r, + } + + if stack.useNeighborCache { + l.neighborTable = &neighborCache{ + nic: nic, + state: NewNUDState(stack.nudConfigs, stack.randomGenerator), + linkRes: r, + + cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize), + } + } else { + cache := new(linkAddrCache) + cache.init(nic, ageLimit, resolutionTimeout, resolutionAttempts, r) + l.neighborTable = cache + } + nic.linkAddrResolvers[r.LinkAddressProtocol()] = l + } + } } nic.LinkEndpoint.Attach(nic) @@ -223,16 +257,19 @@ func (n *NIC) disableLocked() { for _, ep := range n.networkEndpoints { ep.Disable() - } - // Clear the neighbour table (including static entries) as we cannot guarantee - // that the current neighbour table will be valid when the NIC is enabled - // again. 
- // - // This matches linux's behaviour at the time of writing: - // https://github.com/torvalds/linux/blob/71c061d2443814de15e177489d5cc00a4a253ef3/net/core/neighbour.c#L371 - if err := n.clearNeighbors(); err != nil && err != tcpip.ErrNotSupported { - panic(fmt.Sprintf("n.clearNeighbors(): %s", err)) + // Clear the neighbour table (including static entries) as we cannot + // guarantee that the current neighbour table will be valid when the NIC is + // enabled again. + // + // This matches linux's behaviour at the time of writing: + // https://github.com/torvalds/linux/blob/71c061d2443814de15e177489d5cc00a4a253ef3/net/core/neighbour.c#L371 + netProto := ep.NetworkProtocolNumber() + switch err := n.clearNeighbors(netProto); err.(type) { + case nil, *tcpip.ErrNotSupported: + default: + panic(fmt.Sprintf("n.clearNeighbors(%d): %s", netProto, err)) + } } if !n.setEnabled(false) { @@ -246,7 +283,7 @@ func (n *NIC) disableLocked() { // address (ff02::1), start DAD for permanent addresses, and start soliciting // routers if the stack is not operating as a router. If the stack is also // configured to auto-generate a link-local address, one will be generated. -func (n *NIC) enable() *tcpip.Error { +func (n *NIC) enable() tcpip.Error { n.mu.Lock() defer n.mu.Unlock() @@ -266,7 +303,7 @@ func (n *NIC) enable() *tcpip.Error { // remove detaches NIC from the link endpoint and releases network endpoint // resources. This guarantees no packets between this NIC and the network // stack. -func (n *NIC) remove() *tcpip.Error { +func (n *NIC) remove() tcpip.Error { n.mu.Lock() defer n.mu.Unlock() @@ -302,40 +339,63 @@ func (n *NIC) IsLoopback() bool { } // WritePacket implements NetworkLinkEndpoint. 
-func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error { +func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error { _, err := n.enqueuePacketBuffer(r, gso, protocol, pkt) return err } -func (n *NIC) enqueuePacketBuffer(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) (int, *tcpip.Error) { - // As per relevant RFCs, we should queue packets while we wait for link - // resolution to complete. - // - // RFC 1122 section 2.3.2.2 (for IPv4): - // The link layer SHOULD save (rather than discard) at least - // one (the latest) packet of each set of packets destined to - // the same unresolved IP address, and transmit the saved - // packet when the address has been resolved. - // - // RFC 4861 section 7.2.2 (for IPv6): - // While waiting for address resolution to complete, the sender MUST, for - // each neighbor, retain a small queue of packets waiting for address - // resolution to complete. The queue MUST hold at least one packet, and MAY - // contain more. However, the number of queued packets per neighbor SHOULD - // be limited to some small value. When a queue overflows, the new arrival - // SHOULD replace the oldest entry. Once address resolution completes, the - // node transmits any queued packets. 
- return n.linkResQueue.enqueue(r, gso, protocol, pkt) + +func (n *NIC) writePacketBuffer(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) (int, tcpip.Error) { + switch pkt := pkt.(type) { + case *PacketBuffer: + if err := n.writePacket(r, gso, protocol, pkt); err != nil { + return 0, err + } + return 1, nil + case *PacketBufferList: + return n.writePackets(r, gso, protocol, *pkt) + default: + panic(fmt.Sprintf("unrecognized pending packet buffer type = %T", pkt)) + } +} + +func (n *NIC) enqueuePacketBuffer(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) (int, tcpip.Error) { + routeInfo, _, err := r.resolvedFields(nil) + switch err.(type) { + case nil: + return n.writePacketBuffer(routeInfo, gso, protocol, pkt) + case *tcpip.ErrWouldBlock: + // As per relevant RFCs, we should queue packets while we wait for link + // resolution to complete. + // + // RFC 1122 section 2.3.2.2 (for IPv4): + // The link layer SHOULD save (rather than discard) at least + // one (the latest) packet of each set of packets destined to + // the same unresolved IP address, and transmit the saved + // packet when the address has been resolved. + // + // RFC 4861 section 7.2.2 (for IPv6): + // While waiting for address resolution to complete, the sender MUST, for + // each neighbor, retain a small queue of packets waiting for address + // resolution to complete. The queue MUST hold at least one packet, and + // MAY contain more. However, the number of queued packets per neighbor + // SHOULD be limited to some small value. When a queue overflows, the new + // arrival SHOULD replace the oldest entry. Once address resolution + // completes, the node transmits any queued packets. + return n.linkResQueue.enqueue(r, gso, protocol, pkt) + default: + return 0, err + } } // WritePacketToRemote implements NetworkInterface. 
-func (n *NIC) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error { +func (n *NIC) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error { var r RouteInfo r.NetProto = protocol r.RemoteLinkAddress = remoteLinkAddr return n.writePacket(r, gso, protocol, pkt) } -func (n *NIC) writePacket(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error { +func (n *NIC) writePacket(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error { // WritePacket takes ownership of pkt, calculate numBytes first. numBytes := pkt.Size() @@ -352,11 +412,11 @@ func (n *NIC) writePacket(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolN } // WritePackets implements NetworkLinkEndpoint. -func (n *NIC) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (n *NIC) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { return n.enqueuePacketBuffer(r, gso, protocol, &pkts) } -func (n *NIC) writePackets(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkts PacketBufferList) (int, *tcpip.Error) { +func (n *NIC) writePackets(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkts PacketBufferList) (int, tcpip.Error) { for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { pkt.EgressRoute = r pkt.GSOOptions = gso @@ -472,15 +532,15 @@ func (n *NIC) getAddressOrCreateTempInner(protocol tcpip.NetworkProtocolNumber, // addAddress adds a new address to n, so that it starts accepting packets // targeted at the given address (and network protocol). 
-func (n *NIC) addAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) *tcpip.Error { +func (n *NIC) addAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) tcpip.Error { ep, ok := n.networkEndpoints[protocolAddress.Protocol] if !ok { - return tcpip.ErrUnknownProtocol + return &tcpip.ErrUnknownProtocol{} } addressableEndpoint, ok := ep.(AddressableEndpoint) if !ok { - return tcpip.ErrNotSupported + return &tcpip.ErrNotSupported{} } addressEndpoint, err := addressableEndpoint.AddAndAcquirePermanentAddress(protocolAddress.AddressWithPrefix, peb, AddressConfigStatic, false /* deprecated */) @@ -544,72 +604,75 @@ func (n *NIC) primaryAddress(proto tcpip.NetworkProtocolNumber) tcpip.AddressWit } // removeAddress removes an address from n. -func (n *NIC) removeAddress(addr tcpip.Address) *tcpip.Error { +func (n *NIC) removeAddress(addr tcpip.Address) tcpip.Error { for _, ep := range n.networkEndpoints { addressableEndpoint, ok := ep.(AddressableEndpoint) if !ok { continue } - if err := addressableEndpoint.RemovePermanentAddress(addr); err == tcpip.ErrBadLocalAddress { + switch err := addressableEndpoint.RemovePermanentAddress(addr); err.(type) { + case *tcpip.ErrBadLocalAddress: continue - } else { + default: return err } } - return tcpip.ErrBadLocalAddress + return &tcpip.ErrBadLocalAddress{} } -func (n *NIC) getNeighborLinkAddress(addr, localAddr tcpip.Address, linkRes LinkAddressResolver, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) { - if n.neigh != nil { - entry, ch, err := n.neigh.entry(addr, localAddr, linkRes, onResolve) - return entry.LinkAddr, ch, err +func (n *NIC) getLinkAddress(addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, onResolve func(LinkResolutionResult)) tcpip.Error { + linkRes, ok := n.linkAddrResolvers[protocol] + if !ok { + return &tcpip.ErrNotSupported{} } - return n.linkAddrCache.get(addr, linkRes, localAddr, n, onResolve) + if 
linkAddr, ok := linkRes.resolver.ResolveStaticAddress(addr); ok { + onResolve(LinkResolutionResult{LinkAddress: linkAddr, Success: true}) + return nil + } + + _, _, err := linkRes.getNeighborLinkAddress(addr, localAddr, onResolve) + return err } -func (n *NIC) neighbors() ([]NeighborEntry, *tcpip.Error) { - if n.neigh == nil { - return nil, tcpip.ErrNotSupported +func (n *NIC) neighbors(protocol tcpip.NetworkProtocolNumber) ([]NeighborEntry, tcpip.Error) { + if linkRes, ok := n.linkAddrResolvers[protocol]; ok { + return linkRes.neighborTable.neighbors() } - return n.neigh.entries(), nil + return nil, &tcpip.ErrNotSupported{} } -func (n *NIC) addStaticNeighbor(addr tcpip.Address, linkAddress tcpip.LinkAddress) *tcpip.Error { - if n.neigh == nil { - return tcpip.ErrNotSupported +func (n *NIC) addStaticNeighbor(addr tcpip.Address, protocol tcpip.NetworkProtocolNumber, linkAddress tcpip.LinkAddress) tcpip.Error { + if linkRes, ok := n.linkAddrResolvers[protocol]; ok { + linkRes.neighborTable.addStaticEntry(addr, linkAddress) + return nil } - n.neigh.addStaticEntry(addr, linkAddress) - return nil + return &tcpip.ErrNotSupported{} } -func (n *NIC) removeNeighbor(addr tcpip.Address) *tcpip.Error { - if n.neigh == nil { - return tcpip.ErrNotSupported +func (n *NIC) removeNeighbor(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.Error { + if linkRes, ok := n.linkAddrResolvers[protocol]; ok { + return linkRes.neighborTable.remove(addr) } - if !n.neigh.removeEntry(addr) { - return tcpip.ErrBadAddress - } - return nil + return &tcpip.ErrNotSupported{} } -func (n *NIC) clearNeighbors() *tcpip.Error { - if n.neigh == nil { - return tcpip.ErrNotSupported +func (n *NIC) clearNeighbors(protocol tcpip.NetworkProtocolNumber) tcpip.Error { + if linkRes, ok := n.linkAddrResolvers[protocol]; ok { + return linkRes.neighborTable.removeAll() } - n.neigh.clear() - return nil + return &tcpip.ErrNotSupported{} } // joinGroup adds a new endpoint for the given multicast address, 
if none // exists yet. Otherwise it just increments its count. -func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error { +func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.Error { // TODO(b/143102137): When implementing MLD, make sure MLD packets are // not sent unless a valid link-local address is available for use on n // as an MLD packet's source address must be a link-local address as @@ -617,12 +680,12 @@ func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address ep, ok := n.networkEndpoints[protocol] if !ok { - return tcpip.ErrNotSupported + return &tcpip.ErrNotSupported{} } gep, ok := ep.(GroupAddressableEndpoint) if !ok { - return tcpip.ErrNotSupported + return &tcpip.ErrNotSupported{} } return gep.JoinGroup(addr) @@ -630,15 +693,15 @@ func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address // leaveGroup decrements the count for the given multicast address, and when it // reaches zero removes the endpoint for this address. -func (n *NIC) leaveGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error { +func (n *NIC) leaveGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.Error { ep, ok := n.networkEndpoints[protocol] if !ok { - return tcpip.ErrNotSupported + return &tcpip.ErrNotSupported{} } gep, ok := ep.(GroupAddressableEndpoint) if !ok { - return tcpip.ErrNotSupported + return &tcpip.ErrNotSupported{} } return gep.LeaveGroup(addr) @@ -848,9 +911,8 @@ func (n *NIC) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt } } -// DeliverTransportControlPacket delivers control packets to the appropriate -// transport protocol endpoint. -func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt *PacketBuffer) { +// DeliverTransportError implements TransportDispatcher. 
+func (n *NIC) DeliverTransportError(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, transErr TransportError, pkt *PacketBuffer) { state, ok := n.stack.transportProtocols[trans] if !ok { return @@ -872,7 +934,7 @@ func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcp } id := TransportEndpointID{srcPort, local, dstPort, remote} - if n.stack.demux.deliverControlPacket(n, net, trans, typ, extra, pkt, id) { + if n.stack.demux.deliverError(n, net, trans, transErr, pkt, id) { return } } @@ -888,33 +950,34 @@ func (n *NIC) Name() string { } // nudConfigs gets the NUD configurations for n. -func (n *NIC) nudConfigs() (NUDConfigurations, *tcpip.Error) { - if n.neigh == nil { - return NUDConfigurations{}, tcpip.ErrNotSupported +func (n *NIC) nudConfigs(protocol tcpip.NetworkProtocolNumber) (NUDConfigurations, tcpip.Error) { + if linkRes, ok := n.linkAddrResolvers[protocol]; ok { + return linkRes.neighborTable.nudConfig() } - return n.neigh.config(), nil + + return NUDConfigurations{}, &tcpip.ErrNotSupported{} } // setNUDConfigs sets the NUD configurations for n. // // Note, if c contains invalid NUD configuration values, it will be fixed to // use default values for the erroneous values. 
-func (n *NIC) setNUDConfigs(c NUDConfigurations) *tcpip.Error { - if n.neigh == nil { - return tcpip.ErrNotSupported +func (n *NIC) setNUDConfigs(protocol tcpip.NetworkProtocolNumber, c NUDConfigurations) tcpip.Error { + if linkRes, ok := n.linkAddrResolvers[protocol]; ok { + c.resetInvalidFields() + return linkRes.neighborTable.setNUDConfig(c) } - c.resetInvalidFields() - n.neigh.setConfig(c) - return nil + + return &tcpip.ErrNotSupported{} } -func (n *NIC) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) *tcpip.Error { +func (n *NIC) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) tcpip.Error { n.mu.Lock() defer n.mu.Unlock() eps, ok := n.mu.packetEPs[netProto] if !ok { - return tcpip.ErrNotSupported + return &tcpip.ErrNotSupported{} } eps.add(ep) @@ -941,3 +1004,23 @@ func (n *NIC) isValidForOutgoing(ep AssignableAddressEndpoint) bool { n.mu.RUnlock() return n.Enabled() && ep.IsAssigned(spoofing) } + +// HandleNeighborProbe implements NetworkInterface. +func (n *NIC) HandleNeighborProbe(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, linkAddr tcpip.LinkAddress) tcpip.Error { + if l, ok := n.linkAddrResolvers[protocol]; ok { + l.neighborTable.handleProbe(addr, linkAddr) + return nil + } + + return &tcpip.ErrNotSupported{} +} + +// HandleNeighborConfirmation implements NetworkInterface. 
+func (n *NIC) HandleNeighborConfirmation(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) tcpip.Error { + if l, ok := n.linkAddrResolvers[protocol]; ok { + l.neighborTable.handleConfirmation(addr, linkAddr, flags) + return nil + } + + return &tcpip.ErrNotSupported{} +} diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go index be5df7b01..9992d6eb4 100644 --- a/pkg/tcpip/stack/nic_test.go +++ b/pkg/tcpip/stack/nic_test.go @@ -39,7 +39,7 @@ type testIPv6Endpoint struct { invalidatedRtr tcpip.Address } -func (*testIPv6Endpoint) Enable() *tcpip.Error { +func (*testIPv6Endpoint) Enable() tcpip.Error { return nil } @@ -65,21 +65,21 @@ func (e *testIPv6Endpoint) MaxHeaderLength() uint16 { } // WritePacket implements NetworkEndpoint.WritePacket. -func (*testIPv6Endpoint) WritePacket(*Route, *GSO, NetworkHeaderParams, *PacketBuffer) *tcpip.Error { +func (*testIPv6Endpoint) WritePacket(*Route, *GSO, NetworkHeaderParams, *PacketBuffer) tcpip.Error { return nil } // WritePackets implements NetworkEndpoint.WritePackets. -func (*testIPv6Endpoint) WritePackets(*Route, *GSO, PacketBufferList, NetworkHeaderParams) (int, *tcpip.Error) { +func (*testIPv6Endpoint) WritePackets(*Route, *GSO, PacketBufferList, NetworkHeaderParams) (int, tcpip.Error) { // Our tests don't use this so we don't support it. - return 0, tcpip.ErrNotSupported + return 0, &tcpip.ErrNotSupported{} } // WriteHeaderIncludedPacket implements // NetworkEndpoint.WriteHeaderIncludedPacket. -func (*testIPv6Endpoint) WriteHeaderIncludedPacket(*Route, *PacketBuffer) *tcpip.Error { +func (*testIPv6Endpoint) WriteHeaderIncludedPacket(*Route, *PacketBuffer) tcpip.Error { // Our tests don't use this so we don't support it. - return tcpip.ErrNotSupported + return &tcpip.ErrNotSupported{} } // HandlePacket implements NetworkEndpoint.HandlePacket. 
@@ -111,8 +111,6 @@ type testIPv6EndpointStats struct{} // IsNetworkEndpointStats implements stack.NetworkEndpointStats. func (*testIPv6EndpointStats) IsNetworkEndpointStats() {} -var _ LinkAddressResolver = (*testIPv6Protocol)(nil) - // We use this instead of ipv6.protocol because the ipv6 package depends on // the stack package which this test lives in, causing a cyclic dependency. type testIPv6Protocol struct{} @@ -139,7 +137,7 @@ func (*testIPv6Protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) } // NewEndpoint implements NetworkProtocol.NewEndpoint. -func (p *testIPv6Protocol) NewEndpoint(nic NetworkInterface, _ LinkAddressCache, _ NUDHandler, _ TransportDispatcher) NetworkEndpoint { +func (p *testIPv6Protocol) NewEndpoint(nic NetworkInterface, _ TransportDispatcher) NetworkEndpoint { e := &testIPv6Endpoint{ nic: nic, protocol: p, @@ -149,12 +147,12 @@ func (p *testIPv6Protocol) NewEndpoint(nic NetworkInterface, _ LinkAddressCache, } // SetOption implements NetworkProtocol.SetOption. -func (*testIPv6Protocol) SetOption(tcpip.SettableNetworkProtocolOption) *tcpip.Error { +func (*testIPv6Protocol) SetOption(tcpip.SettableNetworkProtocolOption) tcpip.Error { return nil } // Option implements NetworkProtocol.Option. -func (*testIPv6Protocol) Option(tcpip.GettableNetworkProtocolOption) *tcpip.Error { +func (*testIPv6Protocol) Option(tcpip.GettableNetworkProtocolOption) tcpip.Error { return nil } @@ -169,24 +167,6 @@ func (*testIPv6Protocol) Parse(*PacketBuffer) (tcpip.TransportProtocolNumber, bo return 0, false, false } -// LinkAddressProtocol implements LinkAddressResolver. -func (*testIPv6Protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { - return header.IPv6ProtocolNumber -} - -// LinkAddressRequest implements LinkAddressResolver. -func (*testIPv6Protocol) LinkAddressRequest(_, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error { - return nil -} - -// ResolveStaticAddress implements LinkAddressResolver. 
-func (*testIPv6Protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { - if header.IsV6MulticastAddress(addr) { - return header.EthernetAddressFromMulticastIPv6Address(addr), true - } - return "", false -} - func TestDisabledRxStatsWhenNICDisabled(t *testing.T) { // When the NIC is disabled, the only field that matters is the stats field. // This test is limited to stats counter checks. diff --git a/pkg/tcpip/stack/nud.go b/pkg/tcpip/stack/nud.go index 12d67409a..5a94e9ac6 100644 --- a/pkg/tcpip/stack/nud.go +++ b/pkg/tcpip/stack/nud.go @@ -161,25 +161,6 @@ type ReachabilityConfirmationFlags struct { IsRouter bool } -// NUDHandler communicates external events to the Neighbor Unreachability -// Detection state machine, which is implemented per-interface. This is used by -// network endpoints to inform the Neighbor Cache of probes and confirmations. -type NUDHandler interface { - // HandleProbe processes an incoming neighbor probe (e.g. ARP request or - // Neighbor Solicitation for ARP or NDP, respectively). Validation of the - // probe needs to be performed before calling this function since the - // Neighbor Cache doesn't have access to view the NIC's assigned addresses. - HandleProbe(remoteAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes LinkAddressResolver) - - // HandleConfirmation processes an incoming neighbor confirmation (e.g. ARP - // reply or Neighbor Advertisement for ARP or NDP, respectively). - HandleConfirmation(addr tcpip.Address, linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) - - // HandleUpperLevelConfirmation processes an incoming upper-level protocol - // (e.g. TCP acknowledgements) reachability confirmation. - HandleUpperLevelConfirmation(addr tcpip.Address) -} - // NUDConfigurations is the NUD configurations for the netstack. This is used // by the neighbor cache to operate the NUD state machine on each device in the // local network. 
diff --git a/pkg/tcpip/stack/nud_test.go b/pkg/tcpip/stack/nud_test.go index 7bca1373e..e9acef6a2 100644 --- a/pkg/tcpip/stack/nud_test.go +++ b/pkg/tcpip/stack/nud_test.go @@ -19,7 +19,9 @@ import ( "testing" "time" + "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -52,80 +54,146 @@ func (f *fakeRand) Float32() float32 { return f.num } -// TestSetNUDConfigurationFailsForBadNICID tests to make sure we get an error if -// we attempt to update NUD configurations using an invalid NICID. -func TestSetNUDConfigurationFailsForBadNICID(t *testing.T) { - s := stack.New(stack.Options{ - // A neighbor cache is required to store NUDConfigurations. The networking - // stack will only allocate neighbor caches if a protocol providing link - // address resolution is specified (e.g. ARP or IPv6). - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - UseNeighborCache: true, - }) +func TestNUDFunctions(t *testing.T) { + const nicID = 1 - // No NIC with ID 1 yet. 
- config := stack.NUDConfigurations{} - if err := s.SetNUDConfigurations(1, config); err != tcpip.ErrUnknownNICID { - t.Fatalf("got s.SetNDPConfigurations(1, %+v) = %v, want = %s", config, err, tcpip.ErrUnknownNICID) + tests := []struct { + name string + nicID tcpip.NICID + netProtoFactory []stack.NetworkProtocolFactory + extraLinkCapabilities stack.LinkEndpointCapabilities + expectedErr tcpip.Error + }{ + { + name: "Invalid NICID", + nicID: nicID + 1, + netProtoFactory: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, + extraLinkCapabilities: stack.CapabilityResolutionRequired, + expectedErr: &tcpip.ErrUnknownNICID{}, + }, + { + name: "No network protocol", + nicID: nicID, + expectedErr: &tcpip.ErrNotSupported{}, + }, + { + name: "With IPv6", + nicID: nicID, + netProtoFactory: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, + expectedErr: &tcpip.ErrNotSupported{}, + }, + { + name: "With resolution capability", + nicID: nicID, + extraLinkCapabilities: stack.CapabilityResolutionRequired, + expectedErr: &tcpip.ErrNotSupported{}, + }, + { + name: "With IPv6 and resolution capability", + nicID: nicID, + netProtoFactory: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, + extraLinkCapabilities: stack.CapabilityResolutionRequired, + }, } -} -// TestNUDConfigurationFailsForNotSupported tests to make sure we get a -// NotSupported error if we attempt to retrieve NUD configurations when the -// stack doesn't support NUD. -// -// The stack will report to not support NUD if a neighbor cache for a given NIC -// is not allocated. The networking stack will only allocate neighbor caches if -// a protocol providing link address resolution is specified (e.g. ARP, IPv6). 
-func TestNUDConfigurationFailsForNotSupported(t *testing.T) { - const nicID = 1 + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + clock := faketime.NewManualClock() + s := stack.New(stack.Options{ + NUDConfigs: stack.DefaultNUDConfigurations(), + UseNeighborCache: true, + NetworkProtocols: test.netProtoFactory, + Clock: clock, + }) - e := channel.New(0, 1280, linkAddr1) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + e := channel.New(0, 0, linkAddr1) + e.LinkEPCapabilities &^= stack.CapabilityResolutionRequired + e.LinkEPCapabilities |= test.extraLinkCapabilities - s := stack.New(stack.Options{ - NUDConfigs: stack.DefaultNUDConfigurations(), - UseNeighborCache: true, - }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - if _, err := s.NUDConfigurations(nicID); err != tcpip.ErrNotSupported { - t.Fatalf("got s.NDPConfigurations(%d) = %v, want = %s", nicID, err, tcpip.ErrNotSupported) - } -} + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } -// TestNUDConfigurationFailsForNotSupported tests to make sure we get a -// NotSupported error if we attempt to set NUD configurations when the stack -// doesn't support NUD. -// -// The stack will report to not support NUD if a neighbor cache for a given NIC -// is not allocated. The networking stack will only allocate neighbor caches if -// a protocol providing link address resolution is specified (e.g. ARP, IPv6). 
-func TestSetNUDConfigurationFailsForNotSupported(t *testing.T) { - const nicID = 1 + configs := stack.DefaultNUDConfigurations() + configs.BaseReachableTime = time.Hour - e := channel.New(0, 1280, linkAddr1) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + { + err := s.SetNUDConfigurations(test.nicID, ipv6.ProtocolNumber, configs) + if diff := cmp.Diff(test.expectedErr, err); diff != "" { + t.Errorf("s.SetNUDConfigurations(%d, %d, _) error mismatch (-want +got):\n%s", test.nicID, ipv6.ProtocolNumber, diff) + } + } - s := stack.New(stack.Options{ - NUDConfigs: stack.DefaultNUDConfigurations(), - UseNeighborCache: true, - }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } + { + gotConfigs, err := s.NUDConfigurations(test.nicID, ipv6.ProtocolNumber) + if diff := cmp.Diff(test.expectedErr, err); diff != "" { + t.Errorf("s.NUDConfigurations(%d, %d) error mismatch (-want +got):\n%s", test.nicID, ipv6.ProtocolNumber, diff) + } else if test.expectedErr == nil { + if diff := cmp.Diff(configs, gotConfigs); diff != "" { + t.Errorf("got configs mismatch (-want +got):\n%s", diff) + } + } + } - config := stack.NUDConfigurations{} - if err := s.SetNUDConfigurations(nicID, config); err != tcpip.ErrNotSupported { - t.Fatalf("got s.SetNDPConfigurations(%d, %+v) = %v, want = %s", nicID, config, err, tcpip.ErrNotSupported) + for _, addr := range []tcpip.Address{llAddr1, llAddr2} { + { + err := s.AddStaticNeighbor(test.nicID, ipv6.ProtocolNumber, addr, linkAddr1) + if diff := cmp.Diff(test.expectedErr, err); diff != "" { + t.Errorf("s.AddStaticNeighbor(%d, %d, %s, %s) error mismatch (-want +got):\n%s", test.nicID, ipv6.ProtocolNumber, addr, linkAddr1, diff) + } + } + } + + { + wantErr := test.expectedErr + for i := 0; i < 2; i++ { + { + err := s.RemoveNeighbor(test.nicID, ipv6.ProtocolNumber, llAddr1) + if diff := cmp.Diff(wantErr, err); diff != "" { + t.Errorf("s.RemoveNeighbor(%d, %d, '') error mismatch (-want 
+got):\n%s", test.nicID, ipv6.ProtocolNumber, diff) + } + } + + if test.expectedErr != nil { + break + } + + // Removing a neighbor that does not exist should give us a bad address + // error. + wantErr = &tcpip.ErrBadAddress{} + } + } + + { + neighbors, err := s.Neighbors(test.nicID, ipv6.ProtocolNumber) + if diff := cmp.Diff(test.expectedErr, err); diff != "" { + t.Errorf("s.Neigbors(%d, %d) error mismatch (-want +got):\n%s", test.nicID, ipv6.ProtocolNumber, diff) + } else if test.expectedErr == nil { + if diff := cmp.Diff( + []stack.NeighborEntry{{Addr: llAddr2, LinkAddr: linkAddr1, State: stack.Static, UpdatedAtNanos: clock.NowNanoseconds()}}, + neighbors, + ); diff != "" { + t.Errorf("neighbors mismatch (-want +got):\n%s", diff) + } + } + } + + { + err := s.ClearNeighbors(test.nicID, ipv6.ProtocolNumber) + if diff := cmp.Diff(test.expectedErr, err); diff != "" { + t.Errorf("s.ClearNeigbors(%d, %d) error mismatch (-want +got):\n%s", test.nicID, ipv6.ProtocolNumber, diff) + } else if test.expectedErr == nil { + if neighbors, err := s.Neighbors(test.nicID, ipv6.ProtocolNumber); err != nil { + t.Errorf("s.Neighbors(%d, %d): %s", test.nicID, ipv6.ProtocolNumber, err) + } else if len(neighbors) != 0 { + t.Errorf("got len(neighbors) = %d, want = 0; neighbors = %#v", len(neighbors), neighbors) + } + } + } + }) } } -// TestDefaultNUDConfigurationIsValid verifies that calling -// resetInvalidFields() on the result of DefaultNUDConfigurations() does not -// change anything. DefaultNUDConfigurations() should return a valid -// NUDConfigurations. 
func TestDefaultNUDConfigurations(t *testing.T) { const nicID = 1 @@ -143,12 +211,12 @@ func TestDefaultNUDConfigurations(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - c, err := s.NUDConfigurations(nicID) + c, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber) if err != nil { - t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err) } if got, want := c, stack.DefaultNUDConfigurations(); got != want { - t.Errorf("got stack.NUDConfigurations(%d) = %+v, want = %+v", nicID, got, want) + t.Errorf("got stack.NUDConfigurations(%d, %d) = %+v, want = %+v", nicID, ipv6.ProtocolNumber, got, want) } } @@ -198,9 +266,9 @@ func TestNUDConfigurationsBaseReachableTime(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - sc, err := s.NUDConfigurations(nicID) + sc, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber) if err != nil { - t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err) } if got := sc.BaseReachableTime; got != test.want { t.Errorf("got BaseReachableTime = %q, want = %q", got, test.want) @@ -255,9 +323,9 @@ func TestNUDConfigurationsMinRandomFactor(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - sc, err := s.NUDConfigurations(nicID) + sc, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber) if err != nil { - t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err) } if got := sc.MinRandomFactor; got != test.want { t.Errorf("got MinRandomFactor = %f, want = %f", got, test.want) @@ -335,9 +403,9 @@ func TestNUDConfigurationsMaxRandomFactor(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { 
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - sc, err := s.NUDConfigurations(nicID) + sc, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber) if err != nil { - t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err) } if got := sc.MaxRandomFactor; got != test.want { t.Errorf("got MaxRandomFactor = %f, want = %f", got, test.want) @@ -397,9 +465,9 @@ func TestNUDConfigurationsRetransmitTimer(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - sc, err := s.NUDConfigurations(nicID) + sc, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber) if err != nil { - t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err) } if got := sc.RetransmitTimer; got != test.want { t.Errorf("got RetransmitTimer = %q, want = %q", got, test.want) @@ -449,9 +517,9 @@ func TestNUDConfigurationsDelayFirstProbeTime(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - sc, err := s.NUDConfigurations(nicID) + sc, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber) if err != nil { - t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err) } if got := sc.DelayFirstProbeTime; got != test.want { t.Errorf("got DelayFirstProbeTime = %q, want = %q", got, test.want) @@ -501,9 +569,9 @@ func TestNUDConfigurationsMaxMulticastProbes(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - sc, err := s.NUDConfigurations(nicID) + sc, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber) if err != nil { - t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err) } 
if got := sc.MaxMulticastProbes; got != test.want { t.Errorf("got MaxMulticastProbes = %q, want = %q", got, test.want) @@ -553,9 +621,9 @@ func TestNUDConfigurationsMaxUnicastProbes(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - sc, err := s.NUDConfigurations(nicID) + sc, err := s.NUDConfigurations(nicID, ipv6.ProtocolNumber) if err != nil { - t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + t.Fatalf("got stack.NUDConfigurations(%d, %d) = %s", nicID, ipv6.ProtocolNumber, err) } if got := sc.MaxUnicastProbes; got != test.want { t.Errorf("got MaxUnicastProbes = %q, want = %q", got, test.want) diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 9d4fc3e48..4f013b212 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -187,6 +187,12 @@ func (pk *PacketBuffer) Size() int { return pk.HeaderSize() + pk.Data.Size() } +// MemSize returns the estimation size of the pk in memory, including backing +// buffer data. +func (pk *PacketBuffer) MemSize() int { + return pk.HeaderSize() + pk.Data.MemSize() + packetBufferStructSize +} + // Views returns the underlying storage of the whole packet. func (pk *PacketBuffer) Views() []buffer.View { // Optimization for outbound packets that headers are in pk.header. diff --git a/pkg/tcpip/stack/packet_buffer_unsafe.go b/pkg/tcpip/stack/packet_buffer_unsafe.go new file mode 100644 index 000000000..ee3d47270 --- /dev/null +++ b/pkg/tcpip/stack/packet_buffer_unsafe.go @@ -0,0 +1,19 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import "unsafe" + +const packetBufferStructSize = int(unsafe.Sizeof(PacketBuffer{})) diff --git a/pkg/tcpip/stack/pending_packets.go b/pkg/tcpip/stack/pending_packets.go index c4769b17e..1c651e216 100644 --- a/pkg/tcpip/stack/pending_packets.go +++ b/pkg/tcpip/stack/pending_packets.go @@ -114,26 +114,12 @@ func (f *packetsPendingLinkResolution) dequeue(ch <-chan struct{}, linkAddr tcpi } } -func (f *packetsPendingLinkResolution) writePacketBuffer(r RouteInfo, gso *GSO, proto tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) (int, *tcpip.Error) { - switch pkt := pkt.(type) { - case *PacketBuffer: - if err := f.nic.writePacket(r, gso, proto, pkt); err != nil { - return 0, err - } - return 1, nil - case *PacketBufferList: - return f.nic.writePackets(r, gso, proto, *pkt) - default: - panic(fmt.Sprintf("unrecognized pending packet buffer type = %T", pkt)) - } -} - // enqueue a packet to be sent once link resolution completes. // // If the maximum number of pending resolutions is reached, the packets // associated with the oldest link resolution will be dequeued as if they failed // link resolution. 
-func (f *packetsPendingLinkResolution) enqueue(r *Route, gso *GSO, proto tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) (int, *tcpip.Error) { +func (f *packetsPendingLinkResolution) enqueue(r *Route, gso *GSO, proto tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) (int, tcpip.Error) { f.mu.Lock() // Make sure we attempt resolution while holding f's lock so that we avoid // a race where link resolution completes before we enqueue the packets. @@ -146,13 +132,13 @@ func (f *packetsPendingLinkResolution) enqueue(r *Route, gso *GSO, proto tcpip.N // To make sure B does not interleave with A and C, we make sure A and C are // done while holding the lock. routeInfo, ch, err := r.resolvedFields(nil) - switch err { + switch err.(type) { case nil: // The route resolved immediately, so we don't need to wait for link // resolution to send the packet. f.mu.Unlock() - return f.writePacketBuffer(routeInfo, gso, proto, pkt) - case tcpip.ErrWouldBlock: + return f.nic.writePacketBuffer(routeInfo, gso, proto, pkt) + case *tcpip.ErrWouldBlock: // We need to wait for link resolution to complete. default: f.mu.Unlock() @@ -225,7 +211,7 @@ func (f *packetsPendingLinkResolution) dequeuePackets(packets []pendingPacket, l for _, p := range packets { if success { p.routeInfo.RemoteLinkAddress = linkAddr - _, _ = f.writePacketBuffer(p.routeInfo, p.gso, p.proto, p.pkt) + _, _ = f.nic.writePacketBuffer(p.routeInfo, p.gso, p.proto, p.pkt) } else { f.incrementOutgoingPacketErrors(p.proto, p.pkt) diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 33df192aa..d589f798d 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -49,31 +49,6 @@ type TransportEndpointID struct { RemoteAddress tcpip.Address } -// ControlType is the type of network control message. -type ControlType int - -// The following are the allowed values for ControlType values. -// TODO(http://gvisor.dev/issue/3210): Support time exceeded messages. 
-const ( - // ControlAddressUnreachable indicates that an IPv6 packet did not reach its - // destination as the destination address was unreachable. - // - // This maps to the ICMPv6 Destination Ureachable Code 3 error; see - // RFC 4443 section 3.1 for more details. - ControlAddressUnreachable ControlType = iota - ControlNetworkUnreachable - // ControlNoRoute indicates that an IPv4 packet did not reach its destination - // because the destination host was unreachable. - // - // This maps to the ICMPv4 Destination Ureachable Code 1 error; see - // RFC 791's Destination Unreachable Message section (page 4) for more - // details. - ControlNoRoute - ControlPacketTooBig - ControlPortUnreachable - ControlUnknown -) - // NetworkPacketInfo holds information about a network layer packet. type NetworkPacketInfo struct { // LocalAddressBroadcast is true if the packet's local address is a broadcast @@ -81,6 +56,39 @@ type NetworkPacketInfo struct { LocalAddressBroadcast bool } +// TransportErrorKind enumerates error types that are handled by the transport +// layer. +type TransportErrorKind int + +const ( + // PacketTooBigTransportError indicates that a packet did not reach its + // destination because a link on the path to the destination had an MTU that + // was too small to carry the packet. + PacketTooBigTransportError TransportErrorKind = iota + + // DestinationHostUnreachableTransportError indicates that the destination + // host was unreachable. + DestinationHostUnreachableTransportError + + // DestinationPortUnreachableTransportError indicates that a packet reached + // the destination host, but the transport protocol was not active on the + // destination port. + DestinationPortUnreachableTransportError + + // DestinationNetworkUnreachableTransportError indicates that the destination + // network was unreachable. + DestinationNetworkUnreachableTransportError +) + +// TransportError is a marker interface for errors that may be handled by the +// transport layer. 
+type TransportError interface { + tcpip.SockErrorCause + + // Kind returns the type of the transport error. + Kind() TransportErrorKind +} + // TransportEndpoint is the interface that needs to be implemented by transport // protocol (e.g., tcp, udp) endpoints that can handle packets. type TransportEndpoint interface { @@ -93,10 +101,10 @@ type TransportEndpoint interface { // HandlePacket takes ownership of the packet. HandlePacket(TransportEndpointID, *PacketBuffer) - // HandleControlPacket is called by the stack when new control (e.g. - // ICMP) packets arrive to this transport endpoint. - // HandleControlPacket takes ownership of pkt. - HandleControlPacket(typ ControlType, extra uint32, pkt *PacketBuffer) + // HandleError is called when the transport endpoint receives an error. + // + // HandleError takes ownership of the packet buffer. + HandleError(TransportError, *PacketBuffer) // Abort initiates an expedited endpoint teardown. It puts the endpoint // in a closed state and frees all resources associated with it. This @@ -172,10 +180,10 @@ type TransportProtocol interface { Number() tcpip.TransportProtocolNumber // NewEndpoint creates a new endpoint of the transport protocol. - NewEndpoint(netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) + NewEndpoint(netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) // NewRawEndpoint creates a new raw endpoint of the transport protocol. - NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) + NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) // MinimumPacketSize returns the minimum valid packet size of this // transport protocol. The stack automatically drops any packets smaller @@ -184,7 +192,7 @@ type TransportProtocol interface { // ParsePorts returns the source and destination ports stored in a // packet of this protocol. 
- ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) + ParsePorts(v buffer.View) (src, dst uint16, err tcpip.Error) // HandleUnknownDestinationPacket handles packets targeted at this // protocol that don't match any existing endpoint. For example, @@ -197,12 +205,12 @@ type TransportProtocol interface { // SetOption allows enabling/disabling protocol specific features. // SetOption returns an error if the option is not supported or the // provided option value is invalid. - SetOption(option tcpip.SettableTransportProtocolOption) *tcpip.Error + SetOption(option tcpip.SettableTransportProtocolOption) tcpip.Error // Option allows retrieving protocol specific option values. // Option returns an error if the option is not supported or the // provided option value is invalid. - Option(option tcpip.GettableTransportProtocolOption) *tcpip.Error + Option(option tcpip.GettableTransportProtocolOption) tcpip.Error // Close requests that any worker goroutines owned by the protocol // stop. @@ -248,14 +256,11 @@ type TransportDispatcher interface { // DeliverTransportPacket takes ownership of the packet. DeliverTransportPacket(tcpip.TransportProtocolNumber, *PacketBuffer) TransportPacketDisposition - // DeliverTransportControlPacket delivers control packets to the - // appropriate transport protocol endpoint. - // - // pkt.NetworkHeader must be set before calling - // DeliverTransportControlPacket. + // DeliverTransportError delivers an error to the appropriate transport + // endpoint. // - // DeliverTransportControlPacket takes ownership of pkt. - DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt *PacketBuffer) + // DeliverTransportError takes ownership of the packet buffer. 
+ DeliverTransportError(local, remote tcpip.Address, _ tcpip.NetworkProtocolNumber, _ tcpip.TransportProtocolNumber, _ TransportError, _ *PacketBuffer) } // PacketLooping specifies where an outbound packet should be sent. @@ -289,10 +294,10 @@ type NetworkHeaderParams struct { // endpoints may associate themselves with the same identifier (group address). type GroupAddressableEndpoint interface { // JoinGroup joins the specified group. - JoinGroup(group tcpip.Address) *tcpip.Error + JoinGroup(group tcpip.Address) tcpip.Error // LeaveGroup attempts to leave the specified group. - LeaveGroup(group tcpip.Address) *tcpip.Error + LeaveGroup(group tcpip.Address) tcpip.Error // IsInGroup returns true if the endpoint is a member of the specified group. IsInGroup(group tcpip.Address) bool @@ -440,17 +445,17 @@ func (k AddressKind) IsPermanent() bool { type AddressableEndpoint interface { // AddAndAcquirePermanentAddress adds the passed permanent address. // - // Returns tcpip.ErrDuplicateAddress if the address exists. + // Returns *tcpip.ErrDuplicateAddress if the address exists. // // Acquires and returns the AddressEndpoint for the added address. - AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated bool) (AddressEndpoint, *tcpip.Error) + AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated bool) (AddressEndpoint, tcpip.Error) // RemovePermanentAddress removes the passed address if it is a permanent // address. // - // Returns tcpip.ErrBadLocalAddress if the endpoint does not have the passed + // Returns *tcpip.ErrBadLocalAddress if the endpoint does not have the passed // permanent address. - RemovePermanentAddress(addr tcpip.Address) *tcpip.Error + RemovePermanentAddress(addr tcpip.Address) tcpip.Error // MainAddress returns the endpoint's primary permanent address. 
MainAddress() tcpip.AddressWithPrefix @@ -512,14 +517,14 @@ type NetworkInterface interface { Promiscuous() bool // WritePacketToRemote writes the packet to the given remote link address. - WritePacketToRemote(tcpip.LinkAddress, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) *tcpip.Error + WritePacketToRemote(tcpip.LinkAddress, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) tcpip.Error // WritePacket writes a packet with the given protocol through the given // route. // // WritePacket takes ownership of the packet buffer. The packet buffer's // network and transport header must be set. - WritePacket(*Route, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) *tcpip.Error + WritePacket(*Route, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) tcpip.Error // WritePackets writes packets with the given protocol through the given // route. Must not be called with an empty list of packet buffers. @@ -529,7 +534,18 @@ type NetworkInterface interface { // Right now, WritePackets is used only when the software segmentation // offload is enabled. If it will be used for something else, syscall filters // may need to be updated. - WritePackets(*Route, *GSO, PacketBufferList, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) + WritePackets(*Route, *GSO, PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error) + + // HandleNeighborProbe processes an incoming neighbor probe (e.g. ARP + // request or NDP Neighbor Solicitation). + // + // HandleNeighborProbe assumes that the probe is valid for the network + // interface the probe was received on. + HandleNeighborProbe(tcpip.NetworkProtocolNumber, tcpip.Address, tcpip.LinkAddress) tcpip.Error + + // HandleNeighborConfirmation processes an incoming neighbor confirmation + // (e.g. ARP reply or NDP Neighbor Advertisement). + HandleNeighborConfirmation(tcpip.NetworkProtocolNumber, tcpip.Address, tcpip.LinkAddress, ReachabilityConfirmationFlags) tcpip.Error } // LinkResolvableNetworkEndpoint handles link resolution events. 
@@ -547,8 +563,8 @@ type NetworkEndpoint interface { // Must only be called when the stack is in a state that allows the endpoint // to send and receive packets. // - // Returns tcpip.ErrNotPermitted if the endpoint cannot be enabled. - Enable() *tcpip.Error + // Returns *tcpip.ErrNotPermitted if the endpoint cannot be enabled. + Enable() tcpip.Error // Enabled returns true if the endpoint is enabled. Enabled() bool @@ -574,16 +590,16 @@ type NetworkEndpoint interface { // WritePacket writes a packet to the given destination address and // protocol. It takes ownership of pkt. pkt.TransportHeader must have // already been set. - WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error + WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) tcpip.Error // WritePackets writes packets to the given destination address and // protocol. pkts must not be zero length. It takes ownership of pkts and // underlying packets. - WritePackets(r *Route, gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, *tcpip.Error) + WritePackets(r *Route, gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, tcpip.Error) // WriteHeaderIncludedPacket writes a packet that includes a network // header to the given destination address. It takes ownership of pkt. - WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) *tcpip.Error + WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) tcpip.Error // HandlePacket is called by the link layer when new packets arrive to // this network endpoint. It sets pkt.NetworkHeader. @@ -649,17 +665,17 @@ type NetworkProtocol interface { ParseAddresses(v buffer.View) (src, dst tcpip.Address) // NewEndpoint creates a new endpoint of this protocol. 
- NewEndpoint(nic NetworkInterface, linkAddrCache LinkAddressCache, nud NUDHandler, dispatcher TransportDispatcher) NetworkEndpoint + NewEndpoint(nic NetworkInterface, dispatcher TransportDispatcher) NetworkEndpoint // SetOption allows enabling/disabling protocol specific features. // SetOption returns an error if the option is not supported or the // provided option value is invalid. - SetOption(option tcpip.SettableNetworkProtocolOption) *tcpip.Error + SetOption(option tcpip.SettableNetworkProtocolOption) tcpip.Error // Option allows retrieving protocol specific option values. // Option returns an error if the option is not supported or the // provided option value is invalid. - Option(option tcpip.GettableNetworkProtocolOption) *tcpip.Error + Option(option tcpip.GettableNetworkProtocolOption) tcpip.Error // Close requests that any worker goroutines owned by the protocol // stop. @@ -796,7 +812,7 @@ type LinkEndpoint interface { // To participate in transparent bridging, a LinkEndpoint implementation // should call eth.Encode with header.EthernetFields.SrcAddr set to // r.LocalLinkAddress if it is provided. - WritePacket(RouteInfo, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) *tcpip.Error + WritePacket(RouteInfo, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) tcpip.Error // WritePackets writes packets with the given protocol and route. Must not be // called with an empty list of packet buffers. @@ -806,7 +822,7 @@ type LinkEndpoint interface { // Right now, WritePackets is used only when the software segmentation // offload is enabled. If it will be used for something else, syscall filters // may need to be updated. - WritePackets(RouteInfo, *GSO, PacketBufferList, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) + WritePackets(RouteInfo, *GSO, PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error) } // InjectableLinkEndpoint is a LinkEndpoint where inbound packets are @@ -821,19 +837,15 @@ type InjectableLinkEndpoint interface { // link. 
// // dest is used by endpoints with multiple raw destinations. - InjectOutbound(dest tcpip.Address, packet []byte) *tcpip.Error + InjectOutbound(dest tcpip.Address, packet []byte) tcpip.Error } -// A LinkAddressResolver is an extension to a NetworkProtocol that -// can resolve link addresses. +// A LinkAddressResolver handles link address resolution for a network protocol. type LinkAddressResolver interface { // LinkAddressRequest sends a request for the link address of the target // address. The request is broadcasted on the local network if a remote link // address is not provided. - // - // The request is sent from the passed network interface. If the interface - // local address is unspecified, any interface local address may be used. - LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic NetworkInterface) *tcpip.Error + LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) tcpip.Error // ResolveStaticAddress attempts to resolve address without sending // requests. It either resolves the name immediately or returns the @@ -847,22 +859,16 @@ type LinkAddressResolver interface { LinkAddressProtocol() tcpip.NetworkProtocolNumber } -// A LinkAddressCache caches link addresses. -type LinkAddressCache interface { - // AddLinkAddress adds a link address to the cache. - AddLinkAddress(addr tcpip.Address, linkAddr tcpip.LinkAddress) -} - // RawFactory produces endpoints for writing various types of raw packets. type RawFactory interface { // NewUnassociatedEndpoint produces endpoints for writing packets not // associated with a particular transport protocol. Such endpoints can // be used to write arbitrary packets that include the network header. 
- NewUnassociatedEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) + NewUnassociatedEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) // NewPacketEndpoint produces endpoints for reading and writing packets // that include network and (when cooked is false) link layer headers. - NewPacketEndpoint(stack *Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) + NewPacketEndpoint(stack *Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) } // GSOType is the type of GSO segments. diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index d9a8554e2..bab55ce49 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -53,7 +53,7 @@ type Route struct { // linkRes is set if link address resolution is enabled for this protocol on // the route's NIC. 
- linkRes LinkAddressResolver + linkRes linkResolver } type routeInfo struct { @@ -174,7 +174,7 @@ func makeRoute(netProto tcpip.NetworkProtocolNumber, gateway, localAddr, remoteA } if r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilityResolutionRequired != 0 { - if linkRes, ok := r.outgoingNIC.stack.linkAddrResolvers[r.NetProto]; ok { + if linkRes, ok := r.outgoingNIC.linkAddrResolvers[r.NetProto]; ok { r.linkRes = linkRes } } @@ -184,11 +184,11 @@ func makeRoute(netProto tcpip.NetworkProtocolNumber, gateway, localAddr, remoteA return r } - if r.linkRes == nil { + if r.linkRes.resolver == nil { return r } - if linkAddr, ok := r.linkRes.ResolveStaticAddress(r.RemoteAddress); ok { + if linkAddr, ok := r.linkRes.resolver.ResolveStaticAddress(r.RemoteAddress); ok { r.ResolveWith(linkAddr) return r } @@ -331,7 +331,7 @@ type ResolvedFieldsResult struct { // // Note, the route will not cache the remote link address when address // resolution completes. -func (r *Route) ResolvedFields(afterResolve func(ResolvedFieldsResult)) *tcpip.Error { +func (r *Route) ResolvedFields(afterResolve func(ResolvedFieldsResult)) tcpip.Error { _, _, err := r.resolvedFields(afterResolve) return err } @@ -342,7 +342,7 @@ func (r *Route) ResolvedFields(afterResolve func(ResolvedFieldsResult)) *tcpip.E // // The route's fields will also be returned, regardless of whether address // resolution is required or not. 
-func (r *Route) resolvedFields(afterResolve func(ResolvedFieldsResult)) (RouteInfo, <-chan struct{}, *tcpip.Error) { +func (r *Route) resolvedFields(afterResolve func(ResolvedFieldsResult)) (RouteInfo, <-chan struct{}, tcpip.Error) { r.mu.RLock() fields := r.fieldsLocked() resolutionRequired := r.isResolutionRequiredRLocked() @@ -354,11 +354,6 @@ func (r *Route) resolvedFields(afterResolve func(ResolvedFieldsResult)) (RouteIn return fields, nil, nil } - nextAddr := r.NextHop - if nextAddr == "" { - nextAddr = r.RemoteAddress - } - // If specified, the local address used for link address resolution must be an // address on the outgoing interface. var linkAddressResolutionRequestLocalAddr tcpip.Address @@ -367,7 +362,7 @@ func (r *Route) resolvedFields(afterResolve func(ResolvedFieldsResult)) (RouteIn } afterResolveFields := fields - linkAddr, ch, err := r.outgoingNIC.getNeighborLinkAddress(nextAddr, linkAddressResolutionRequestLocalAddr, r.linkRes, func(r LinkResolutionResult) { + linkAddr, ch, err := r.linkRes.getNeighborLinkAddress(r.nextHop(), linkAddressResolutionRequestLocalAddr, func(r LinkResolutionResult) { if afterResolve != nil { if r.Success { afterResolveFields.RemoteLinkAddress = r.LinkAddress @@ -382,6 +377,13 @@ func (r *Route) resolvedFields(afterResolve func(ResolvedFieldsResult)) (RouteIn return fields, ch, err } +func (r *Route) nextHop() tcpip.Address { + if len(r.NextHop) == 0 { + return r.RemoteAddress + } + return r.NextHop +} + // local returns true if the route is a local route. 
func (r *Route) local() bool { return r.Loop == PacketLoop || r.outgoingNIC.IsLoopback() @@ -398,7 +400,7 @@ func (r *Route) IsResolutionRequired() bool { } func (r *Route) isResolutionRequiredRLocked() bool { - return len(r.mu.remoteLinkAddress) == 0 && r.linkRes != nil && r.isValidForOutgoingRLocked() && !r.local() + return len(r.mu.remoteLinkAddress) == 0 && r.linkRes.resolver != nil && r.isValidForOutgoingRLocked() && !r.local() } func (r *Route) isValidForOutgoing() bool { @@ -427,9 +429,9 @@ func (r *Route) isValidForOutgoingRLocked() bool { } // WritePacket writes the packet through the given route. -func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error { +func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) tcpip.Error { if !r.isValidForOutgoing() { - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } return r.outgoingNIC.getNetworkEndpoint(r.NetProto).WritePacket(r, gso, params, pkt) @@ -437,9 +439,9 @@ func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt *PacketBuf // WritePackets writes a list of n packets through the given route and returns // the number of packets written. -func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, *tcpip.Error) { +func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, tcpip.Error) { if !r.isValidForOutgoing() { - return 0, tcpip.ErrInvalidEndpointState + return 0, &tcpip.ErrInvalidEndpointState{} } return r.outgoingNIC.getNetworkEndpoint(r.NetProto).WritePackets(r, gso, pkts, params) @@ -447,9 +449,9 @@ func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHead // WriteHeaderIncludedPacket writes a packet already containing a network // header through the given route. 
-func (r *Route) WriteHeaderIncludedPacket(pkt *PacketBuffer) *tcpip.Error { +func (r *Route) WriteHeaderIncludedPacket(pkt *PacketBuffer) tcpip.Error { if !r.isValidForOutgoing() { - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } return r.outgoingNIC.getNetworkEndpoint(r.NetProto).WriteHeaderIncludedPacket(r, pkt) @@ -519,3 +521,14 @@ func (r *Route) IsOutboundBroadcast() bool { // Only IPv4 has a notion of broadcast. return r.isV4Broadcast(r.RemoteAddress) } + +// ConfirmReachable informs the network/link layer that the neighbour used for +// the route is reachable. +// +// "Reachable" is defined as having full-duplex communication between the +// local and remote ends of the route. +func (r *Route) ConfirmReachable() { + if r.linkRes.resolver != nil { + r.linkRes.confirmReachable(r.nextHop()) + } +} diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index e9c5db4c3..a51d758d0 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -76,12 +76,16 @@ type TCPCubicState struct { // TCPRACKState is used to hold a copy of the internal RACK state when the // TCPProbeFunc is invoked. type TCPRACKState struct { - XmitTime time.Time - EndSequence seqnum.Value - FACK seqnum.Value - RTT time.Duration - Reord bool - DSACKSeen bool + XmitTime time.Time + EndSequence seqnum.Value + FACK seqnum.Value + RTT time.Duration + Reord bool + DSACKSeen bool + ReoWnd time.Duration + ReoWndIncr uint8 + ReoWndPersist int8 + RTTSeq seqnum.Value } // TCPEndpointID is the unique 4 tuple that identifies a given endpoint. @@ -372,7 +376,6 @@ func (u *uniqueIDGenerator) UniqueID() uint64 { type Stack struct { transportProtocols map[tcpip.TransportProtocolNumber]*transportProtocolState networkProtocols map[tcpip.NetworkProtocolNumber]NetworkProtocol - linkAddrResolvers map[tcpip.NetworkProtocolNumber]LinkAddressResolver // rawFactory creates raw endpoints. If nil, raw endpoints are // disabled. 
It is set during Stack creation and is immutable. @@ -382,6 +385,15 @@ type Stack struct { stats tcpip.Stats + // LOCK ORDERING: mu > route.mu. + route struct { + mu struct { + sync.RWMutex + + table []tcpip.Route + } + } + mu sync.RWMutex nics map[tcpip.NICID]*NIC @@ -389,11 +401,6 @@ type Stack struct { cleanupEndpointsMu sync.Mutex cleanupEndpoints map[TransportEndpoint]struct{} - // route is the route table passed in by the user via SetRouteTable(), - // it is used by FindRoute() to build a route for a specific - // destination. - routeTable []tcpip.Route - *ports.PortManager // If not nil, then any new endpoints will have this probe function @@ -429,6 +436,8 @@ type Stack struct { // useNeighborCache indicates whether ARP and NDP packets should be handled // by the NIC's neighborCache instead of linkAddrCache. + // + // TODO(gvisor.dev/issue/4658): Remove this field. useNeighborCache bool // nudDisp is the NUD event dispatcher that is used to send the netstack @@ -449,6 +458,18 @@ type Stack struct { // receiveBufferSize holds the min/default/max receive buffer sizes for // endpoints other than TCP. receiveBufferSize ReceiveBufferSizeOption + + // tcpInvalidRateLimit is the maximal rate for sending duplicate + // acknowledgements in response to incoming TCP packets that are for an existing + // connection but that are invalid due to any of the following reasons: + // + // a) out-of-window sequence number. + // b) out-of-window acknowledgement number. + // c) PAWS check failure (when implemented). + // + // This is required to prevent potential ACK loops. + // Setting this to 0 will disable all rate limiting. + tcpInvalidRateLimit time.Duration } // UniqueID is an abstract generator of unique identifiers. @@ -495,13 +516,17 @@ type Options struct { // NUDConfigs is the default NUD configurations used by interfaces. 
NUDConfigs NUDConfigurations - // UseNeighborCache indicates whether ARP and NDP packets should be handled - // by the Neighbor Unreachability Detection (NUD) state machine. This flag - // also enables the APIs for inspecting and modifying the neighbor table via - // NUDDispatcher and the following Stack methods: Neighbors, RemoveNeighbor, - // and ClearNeighbors. + // UseNeighborCache is unused. + // + // TODO(gvisor.dev/issue/4658): Remove this field. UseNeighborCache bool + // UseLinkAddrCache indicates that the legacy link address cache should be + // used for link resolution. + // + // TODO(gvisor.dev/issue/4658): Remove this field. + UseLinkAddrCache bool + // NUDDisp is the NUD event dispatcher that an integrator can provide to // receive NUD related events. NUDDisp NUDDispatcher @@ -552,7 +577,7 @@ type TransportEndpointInfo struct { // incompatible with the receiver. // // Preconditon: the parent endpoint mu must be held while calling this method. -func (t *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, v6only bool) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) { +func (t *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, v6only bool) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) { netProto := t.NetProto switch len(addr.Addr) { case header.IPv4AddressSize: @@ -570,11 +595,11 @@ func (t *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, v6onl switch len(t.ID.LocalAddress) { case header.IPv4AddressSize: if len(addr.Addr) == header.IPv6AddressSize { - return tcpip.FullAddress{}, 0, tcpip.ErrInvalidEndpointState + return tcpip.FullAddress{}, 0, &tcpip.ErrInvalidEndpointState{} } case header.IPv6AddressSize: if len(addr.Addr) == header.IPv4AddressSize { - return tcpip.FullAddress{}, 0, tcpip.ErrNetworkUnreachable + return tcpip.FullAddress{}, 0, &tcpip.ErrNetworkUnreachable{} } } @@ -582,10 +607,10 @@ func (t *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, 
v6onl case netProto == t.NetProto: case netProto == header.IPv4ProtocolNumber && t.NetProto == header.IPv6ProtocolNumber: if v6only { - return tcpip.FullAddress{}, 0, tcpip.ErrNoRoute + return tcpip.FullAddress{}, 0, &tcpip.ErrNoRoute{} } default: - return tcpip.FullAddress{}, 0, tcpip.ErrInvalidEndpointState + return tcpip.FullAddress{}, 0, &tcpip.ErrInvalidEndpointState{} } return addr, netProto, nil @@ -631,7 +656,6 @@ func New(opts Options) *Stack { s := &Stack{ transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState), networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol), - linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]LinkAddressResolver), nics: make(map[tcpip.NICID]*NIC), cleanupEndpoints: make(map[TransportEndpoint]struct{}), PortManager: ports.NewPortManager(), @@ -642,7 +666,7 @@ func New(opts Options) *Stack { icmpRateLimiter: NewICMPRateLimiter(), seed: generateRandUint32(), nudConfigs: opts.NUDConfigs, - useNeighborCache: opts.UseNeighborCache, + useNeighborCache: !opts.UseLinkAddrCache, uniqueIDGenerator: opts.UniqueID, nudDisp: opts.NUDDisp, randomGenerator: mathrand.New(randSrc), @@ -656,15 +680,13 @@ func New(opts Options) *Stack { Default: DefaultBufferSize, Max: DefaultMaxBufferSize, }, + tcpInvalidRateLimit: defaultTCPInvalidRateLimit, } // Add specified network protocols. for _, netProtoFactory := range opts.NetworkProtocols { netProto := netProtoFactory(s) s.networkProtocols[netProto.Number()] = netProto - if r, ok := netProto.(LinkAddressResolver); ok { - s.linkAddrResolvers[r.LinkAddressProtocol()] = r - } } // Add specified transport protocols. @@ -698,10 +720,10 @@ func (s *Stack) UniqueID() uint64 { // options. This method returns an error if the protocol is not supported or // option is not supported by the protocol implementation or the provided value // is incorrect. 
-func (s *Stack) SetNetworkProtocolOption(network tcpip.NetworkProtocolNumber, option tcpip.SettableNetworkProtocolOption) *tcpip.Error { +func (s *Stack) SetNetworkProtocolOption(network tcpip.NetworkProtocolNumber, option tcpip.SettableNetworkProtocolOption) tcpip.Error { netProto, ok := s.networkProtocols[network] if !ok { - return tcpip.ErrUnknownProtocol + return &tcpip.ErrUnknownProtocol{} } return netProto.SetOption(option) } @@ -715,10 +737,10 @@ func (s *Stack) SetNetworkProtocolOption(network tcpip.NetworkProtocolNumber, op // if err != nil { // ... // } -func (s *Stack) NetworkProtocolOption(network tcpip.NetworkProtocolNumber, option tcpip.GettableNetworkProtocolOption) *tcpip.Error { +func (s *Stack) NetworkProtocolOption(network tcpip.NetworkProtocolNumber, option tcpip.GettableNetworkProtocolOption) tcpip.Error { netProto, ok := s.networkProtocols[network] if !ok { - return tcpip.ErrUnknownProtocol + return &tcpip.ErrUnknownProtocol{} } return netProto.Option(option) } @@ -727,10 +749,10 @@ func (s *Stack) NetworkProtocolOption(network tcpip.NetworkProtocolNumber, optio // options. This method returns an error if the protocol is not supported or // option is not supported by the protocol implementation or the provided value // is incorrect. -func (s *Stack) SetTransportProtocolOption(transport tcpip.TransportProtocolNumber, option tcpip.SettableTransportProtocolOption) *tcpip.Error { +func (s *Stack) SetTransportProtocolOption(transport tcpip.TransportProtocolNumber, option tcpip.SettableTransportProtocolOption) tcpip.Error { transProtoState, ok := s.transportProtocols[transport] if !ok { - return tcpip.ErrUnknownProtocol + return &tcpip.ErrUnknownProtocol{} } return transProtoState.proto.SetOption(option) } @@ -742,10 +764,10 @@ func (s *Stack) SetTransportProtocolOption(transport tcpip.TransportProtocolNumb // if err := s.TransportProtocolOption(tcpip.TCPProtocolNumber, &v); err != nil { // ... 
// } -func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber, option tcpip.GettableTransportProtocolOption) *tcpip.Error { +func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber, option tcpip.GettableTransportProtocolOption) tcpip.Error { transProtoState, ok := s.transportProtocols[transport] if !ok { - return tcpip.ErrUnknownProtocol + return &tcpip.ErrUnknownProtocol{} } return transProtoState.proto.Option(option) } @@ -778,15 +800,15 @@ func (s *Stack) Stats() tcpip.Stats { // SetForwarding enables or disables packet forwarding between NICs for the // passed protocol. -func (s *Stack) SetForwarding(protocolNum tcpip.NetworkProtocolNumber, enable bool) *tcpip.Error { +func (s *Stack) SetForwarding(protocolNum tcpip.NetworkProtocolNumber, enable bool) tcpip.Error { protocol, ok := s.networkProtocols[protocolNum] if !ok { - return tcpip.ErrUnknownProtocol + return &tcpip.ErrUnknownProtocol{} } forwardingProtocol, ok := protocol.(ForwardingNetworkProtocol) if !ok { - return tcpip.ErrNotSupported + return &tcpip.ErrNotSupported{} } forwardingProtocol.SetForwarding(enable) @@ -814,45 +836,44 @@ func (s *Stack) Forwarding(protocolNum tcpip.NetworkProtocolNumber) bool { // // This method takes ownership of the table. func (s *Stack) SetRouteTable(table []tcpip.Route) { - s.mu.Lock() - defer s.mu.Unlock() - - s.routeTable = table + s.route.mu.Lock() + defer s.route.mu.Unlock() + s.route.mu.table = table } // GetRouteTable returns the route table which is currently in use. func (s *Stack) GetRouteTable() []tcpip.Route { - s.mu.Lock() - defer s.mu.Unlock() - return append([]tcpip.Route(nil), s.routeTable...) + s.route.mu.RLock() + defer s.route.mu.RUnlock() + return append([]tcpip.Route(nil), s.route.mu.table...) } // AddRoute appends a route to the route table. 
func (s *Stack) AddRoute(route tcpip.Route) { - s.mu.Lock() - defer s.mu.Unlock() - s.routeTable = append(s.routeTable, route) + s.route.mu.Lock() + defer s.route.mu.Unlock() + s.route.mu.table = append(s.route.mu.table, route) } // RemoveRoutes removes matching routes from the route table. func (s *Stack) RemoveRoutes(match func(tcpip.Route) bool) { - s.mu.Lock() - defer s.mu.Unlock() + s.route.mu.Lock() + defer s.route.mu.Unlock() var filteredRoutes []tcpip.Route - for _, route := range s.routeTable { + for _, route := range s.route.mu.table { if !match(route) { filteredRoutes = append(filteredRoutes, route) } } - s.routeTable = filteredRoutes + s.route.mu.table = filteredRoutes } // NewEndpoint creates a new transport layer endpoint of the given protocol. -func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { t, ok := s.transportProtocols[transport] if !ok { - return nil, tcpip.ErrUnknownProtocol + return nil, &tcpip.ErrUnknownProtocol{} } return t.proto.NewEndpoint(network, waiterQueue) @@ -861,9 +882,9 @@ func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcp // NewRawEndpoint creates a new raw transport layer endpoint of the given // protocol. Raw endpoints receive all traffic for a given protocol regardless // of address. 
-func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, *tcpip.Error) { +func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, tcpip.Error) { if s.rawFactory == nil { - return nil, tcpip.ErrNotPermitted + return nil, &tcpip.ErrNotPermitted{} } if !associated { @@ -872,7 +893,7 @@ func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network t, ok := s.transportProtocols[transport] if !ok { - return nil, tcpip.ErrUnknownProtocol + return nil, &tcpip.ErrUnknownProtocol{} } return t.proto.NewRawEndpoint(network, waiterQueue) @@ -880,9 +901,9 @@ func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network // NewPacketEndpoint creates a new packet endpoint listening for the given // netProto. -func (s *Stack) NewPacketEndpoint(cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (s *Stack) NewPacketEndpoint(cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { if s.rawFactory == nil { - return nil, tcpip.ErrNotPermitted + return nil, &tcpip.ErrNotPermitted{} } return s.rawFactory.NewPacketEndpoint(s, cooked, netProto, waiterQueue) @@ -913,20 +934,20 @@ type NICOptions struct { // NICs can be configured. // // LinkEndpoint.Attach will be called to bind ep with a NetworkDispatcher. -func (s *Stack) CreateNICWithOptions(id tcpip.NICID, ep LinkEndpoint, opts NICOptions) *tcpip.Error { +func (s *Stack) CreateNICWithOptions(id tcpip.NICID, ep LinkEndpoint, opts NICOptions) tcpip.Error { s.mu.Lock() defer s.mu.Unlock() // Make sure id is unique. if _, ok := s.nics[id]; ok { - return tcpip.ErrDuplicateNICID + return &tcpip.ErrDuplicateNICID{} } // Make sure name is unique, unless unnamed. 
if opts.Name != "" { for _, n := range s.nics { if n.Name() == opts.Name { - return tcpip.ErrDuplicateNICID + return &tcpip.ErrDuplicateNICID{} } } } @@ -942,7 +963,7 @@ func (s *Stack) CreateNICWithOptions(id tcpip.NICID, ep LinkEndpoint, opts NICOp // CreateNIC creates a NIC with the provided id and LinkEndpoint and calls // LinkEndpoint.Attach to bind ep with a NetworkDispatcher. -func (s *Stack) CreateNIC(id tcpip.NICID, ep LinkEndpoint) *tcpip.Error { +func (s *Stack) CreateNIC(id tcpip.NICID, ep LinkEndpoint) tcpip.Error { return s.CreateNICWithOptions(id, ep, NICOptions{}) } @@ -960,26 +981,26 @@ func (s *Stack) GetLinkEndpointByName(name string) LinkEndpoint { // EnableNIC enables the given NIC so that the link-layer endpoint can start // delivering packets to it. -func (s *Stack) EnableNIC(id tcpip.NICID) *tcpip.Error { +func (s *Stack) EnableNIC(id tcpip.NICID) tcpip.Error { s.mu.RLock() defer s.mu.RUnlock() nic, ok := s.nics[id] if !ok { - return tcpip.ErrUnknownNICID + return &tcpip.ErrUnknownNICID{} } return nic.enable() } // DisableNIC disables the given NIC. -func (s *Stack) DisableNIC(id tcpip.NICID) *tcpip.Error { +func (s *Stack) DisableNIC(id tcpip.NICID) tcpip.Error { s.mu.RLock() defer s.mu.RUnlock() nic, ok := s.nics[id] if !ok { - return tcpip.ErrUnknownNICID + return &tcpip.ErrUnknownNICID{} } nic.disable() @@ -1000,7 +1021,7 @@ func (s *Stack) CheckNIC(id tcpip.NICID) bool { } // RemoveNIC removes NIC and all related routes from the network stack. -func (s *Stack) RemoveNIC(id tcpip.NICID) *tcpip.Error { +func (s *Stack) RemoveNIC(id tcpip.NICID) tcpip.Error { s.mu.Lock() defer s.mu.Unlock() @@ -1010,25 +1031,26 @@ func (s *Stack) RemoveNIC(id tcpip.NICID) *tcpip.Error { // removeNICLocked removes NIC and all related routes from the network stack. // // s.mu must be locked. 
-func (s *Stack) removeNICLocked(id tcpip.NICID) *tcpip.Error { +func (s *Stack) removeNICLocked(id tcpip.NICID) tcpip.Error { nic, ok := s.nics[id] if !ok { - return tcpip.ErrUnknownNICID + return &tcpip.ErrUnknownNICID{} } delete(s.nics, id) // Remove routes in-place. n tracks the number of routes written. + s.route.mu.Lock() n := 0 - for i, r := range s.routeTable { - s.routeTable[i] = tcpip.Route{} + for i, r := range s.route.mu.table { + s.route.mu.table[i] = tcpip.Route{} if r.NIC != id { // Keep this route. - s.routeTable[n] = r + s.route.mu.table[n] = r n++ } } - - s.routeTable = s.routeTable[:n] + s.route.mu.table = s.route.mu.table[:n] + s.route.mu.Unlock() return nic.remove() } @@ -1118,13 +1140,13 @@ type NICStateFlags struct { } // AddAddress adds a new network-layer address to the specified NIC. -func (s *Stack) AddAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error { +func (s *Stack) AddAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.Error { return s.AddAddressWithOptions(id, protocol, addr, CanBePrimaryEndpoint) } // AddAddressWithPrefix is the same as AddAddress, but allows you to specify // the address prefix. -func (s *Stack) AddAddressWithPrefix(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.AddressWithPrefix) *tcpip.Error { +func (s *Stack) AddAddressWithPrefix(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.AddressWithPrefix) tcpip.Error { ap := tcpip.ProtocolAddress{ Protocol: protocol, AddressWithPrefix: addr, @@ -1134,16 +1156,16 @@ func (s *Stack) AddAddressWithPrefix(id tcpip.NICID, protocol tcpip.NetworkProto // AddProtocolAddress adds a new network-layer protocol address to the // specified NIC. 
-func (s *Stack) AddProtocolAddress(id tcpip.NICID, protocolAddress tcpip.ProtocolAddress) *tcpip.Error { +func (s *Stack) AddProtocolAddress(id tcpip.NICID, protocolAddress tcpip.ProtocolAddress) tcpip.Error { return s.AddProtocolAddressWithOptions(id, protocolAddress, CanBePrimaryEndpoint) } // AddAddressWithOptions is the same as AddAddress, but allows you to specify // whether the new endpoint can be primary or not. -func (s *Stack) AddAddressWithOptions(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, peb PrimaryEndpointBehavior) *tcpip.Error { +func (s *Stack) AddAddressWithOptions(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, peb PrimaryEndpointBehavior) tcpip.Error { netProto, ok := s.networkProtocols[protocol] if !ok { - return tcpip.ErrUnknownProtocol + return &tcpip.ErrUnknownProtocol{} } return s.AddProtocolAddressWithOptions(id, tcpip.ProtocolAddress{ Protocol: protocol, @@ -1156,13 +1178,13 @@ func (s *Stack) AddAddressWithOptions(id tcpip.NICID, protocol tcpip.NetworkProt // AddProtocolAddressWithOptions is the same as AddProtocolAddress, but allows // you to specify whether the new endpoint can be primary or not. -func (s *Stack) AddProtocolAddressWithOptions(id tcpip.NICID, protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) *tcpip.Error { +func (s *Stack) AddProtocolAddressWithOptions(id tcpip.NICID, protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) tcpip.Error { s.mu.RLock() defer s.mu.RUnlock() nic, ok := s.nics[id] if !ok { - return tcpip.ErrUnknownNICID + return &tcpip.ErrUnknownNICID{} } return nic.addAddress(protocolAddress, peb) @@ -1170,7 +1192,7 @@ func (s *Stack) AddProtocolAddressWithOptions(id tcpip.NICID, protocolAddress tc // RemoveAddress removes an existing network-layer address from the specified // NIC. 
-func (s *Stack) RemoveAddress(id tcpip.NICID, addr tcpip.Address) *tcpip.Error { +func (s *Stack) RemoveAddress(id tcpip.NICID, addr tcpip.Address) tcpip.Error { s.mu.RLock() defer s.mu.RUnlock() @@ -1178,7 +1200,7 @@ func (s *Stack) RemoveAddress(id tcpip.NICID, addr tcpip.Address) *tcpip.Error { return nic.removeAddress(addr) } - return tcpip.ErrUnknownNICID + return &tcpip.ErrUnknownNICID{} } // AllAddresses returns a map of NICIDs to their protocol addresses (primary @@ -1308,7 +1330,7 @@ func (s *Stack) findLocalRouteRLocked(localAddressNICID tcpip.NICID, localAddr, // If no local address is provided, the stack will select a local address. If no // remote address is provided, the stack wil use a remote address equal to the // local address. -func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber, multicastLoop bool) (*Route, *tcpip.Error) { +func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber, multicastLoop bool) (*Route, tcpip.Error) { s.mu.RLock() defer s.mu.RUnlock() @@ -1344,48 +1366,58 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n } if isLoopback { - return nil, tcpip.ErrBadLocalAddress + return nil, &tcpip.ErrBadLocalAddress{} } - return nil, tcpip.ErrNetworkUnreachable + return nil, &tcpip.ErrNetworkUnreachable{} } canForward := s.Forwarding(netProto) && !header.IsV6LinkLocalAddress(localAddr) && !isLinkLocal // Find a route to the remote with the route table. 
var chosenRoute tcpip.Route - for _, route := range s.routeTable { - if len(remoteAddr) != 0 && !route.Destination.Contains(remoteAddr) { - continue - } + if r := func() *Route { + s.route.mu.RLock() + defer s.route.mu.RUnlock() - nic, ok := s.nics[route.NIC] - if !ok || !nic.Enabled() { - continue - } + for _, route := range s.route.mu.table { + if len(remoteAddr) != 0 && !route.Destination.Contains(remoteAddr) { + continue + } - if id == 0 || id == route.NIC { - if addressEndpoint := s.getAddressEP(nic, localAddr, remoteAddr, netProto); addressEndpoint != nil { - var gateway tcpip.Address - if needRoute { - gateway = route.Gateway - } - r := constructAndValidateRoute(netProto, addressEndpoint, nic /* outgoingNIC */, nic /* outgoingNIC */, gateway, localAddr, remoteAddr, s.handleLocal, multicastLoop) - if r == nil { - panic(fmt.Sprintf("non-forwarding route validation failed with route table entry = %#v, id = %d, localAddr = %s, remoteAddr = %s", route, id, localAddr, remoteAddr)) + nic, ok := s.nics[route.NIC] + if !ok || !nic.Enabled() { + continue + } + + if id == 0 || id == route.NIC { + if addressEndpoint := s.getAddressEP(nic, localAddr, remoteAddr, netProto); addressEndpoint != nil { + var gateway tcpip.Address + if needRoute { + gateway = route.Gateway + } + r := constructAndValidateRoute(netProto, addressEndpoint, nic /* outgoingNIC */, nic /* outgoingNIC */, gateway, localAddr, remoteAddr, s.handleLocal, multicastLoop) + if r == nil { + panic(fmt.Sprintf("non-forwarding route validation failed with route table entry = %#v, id = %d, localAddr = %s, remoteAddr = %s", route, id, localAddr, remoteAddr)) + } + return r } - return r, nil } - } - // If the stack has forwarding enabled and we haven't found a valid route to - // the remote address yet, keep track of the first valid route. We keep - // iterating because we prefer routes that let us use a local address that - // is assigned to the outgoing interface. 
There is no requirement to do this - // from any RFC but simply a choice made to better follow a strong host - // model which the netstack follows at the time of writing. - if canForward && chosenRoute == (tcpip.Route{}) { - chosenRoute = route + // If the stack has forwarding enabled and we haven't found a valid route + // to the remote address yet, keep track of the first valid route. We + // keep iterating because we prefer routes that let us use a local + // address that is assigned to the outgoing interface. There is no + // requirement to do this from any RFC but simply a choice made to better + // follow a strong host model which the netstack follows at the time of + // writing. + if canForward && chosenRoute == (tcpip.Route{}) { + chosenRoute = route + } } + + return nil + }(); r != nil { + return r, nil } if chosenRoute != (tcpip.Route{}) { @@ -1412,7 +1444,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n } } - return nil, tcpip.ErrNoRoute + return nil, &tcpip.ErrNoRoute{} } if id == 0 { @@ -1432,12 +1464,12 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n } if needRoute { - return nil, tcpip.ErrNoRoute + return nil, &tcpip.ErrNoRoute{} } if header.IsV6LoopbackAddress(remoteAddr) { - return nil, tcpip.ErrBadLocalAddress + return nil, &tcpip.ErrBadLocalAddress{} } - return nil, tcpip.ErrNetworkUnreachable + return nil, &tcpip.ErrNetworkUnreachable{} } // CheckNetworkProtocol checks if a given network protocol is enabled in the @@ -1483,13 +1515,13 @@ func (s *Stack) CheckLocalAddress(nicID tcpip.NICID, protocol tcpip.NetworkProto } // SetPromiscuousMode enables or disables promiscuous mode in the given NIC. 
-func (s *Stack) SetPromiscuousMode(nicID tcpip.NICID, enable bool) *tcpip.Error { +func (s *Stack) SetPromiscuousMode(nicID tcpip.NICID, enable bool) tcpip.Error { s.mu.RLock() defer s.mu.RUnlock() nic, ok := s.nics[nicID] if !ok { - return tcpip.ErrUnknownNICID + return &tcpip.ErrUnknownNICID{} } nic.setPromiscuousMode(enable) @@ -1499,13 +1531,13 @@ func (s *Stack) SetPromiscuousMode(nicID tcpip.NICID, enable bool) *tcpip.Error // SetSpoofing enables or disables address spoofing in the given NIC, allowing // endpoints to bind to any address in the NIC. -func (s *Stack) SetSpoofing(nicID tcpip.NICID, enable bool) *tcpip.Error { +func (s *Stack) SetSpoofing(nicID tcpip.NICID, enable bool) tcpip.Error { s.mu.RLock() defer s.mu.RUnlock() nic, ok := s.nics[nicID] if !ok { - return tcpip.ErrUnknownNICID + return &tcpip.ErrUnknownNICID{} } nic.setSpoofing(enable) @@ -1513,20 +1545,6 @@ func (s *Stack) SetSpoofing(nicID tcpip.NICID, enable bool) *tcpip.Error { return nil } -// AddLinkAddress adds a link address for the neighbor on the specified NIC. -func (s *Stack) AddLinkAddress(nicID tcpip.NICID, neighbor tcpip.Address, linkAddr tcpip.LinkAddress) *tcpip.Error { - s.mu.RLock() - defer s.mu.RUnlock() - - nic, ok := s.nics[nicID] - if !ok { - return tcpip.ErrUnknownNICID - } - - nic.linkAddrCache.AddLinkAddress(neighbor, linkAddr) - return nil -} - // LinkResolutionResult is the result of a link address resolution attempt. type LinkResolutionResult struct { LinkAddress tcpip.LinkAddress @@ -1549,93 +1567,82 @@ type LinkResolutionResult struct { // If specified, the local address must be an address local to the interface // the neighbor cache belongs to. The local address is the source address of // a packet prompting NUD/link address resolution. 
-func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, onResolve func(LinkResolutionResult)) *tcpip.Error { +func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, onResolve func(LinkResolutionResult)) tcpip.Error { s.mu.RLock() nic, ok := s.nics[nicID] s.mu.RUnlock() if !ok { - return tcpip.ErrUnknownNICID - } - - linkRes, ok := s.linkAddrResolvers[protocol] - if !ok { - return tcpip.ErrNotSupported - } - - if linkAddr, ok := linkRes.ResolveStaticAddress(addr); ok { - onResolve(LinkResolutionResult{LinkAddress: linkAddr, Success: true}) - return nil + return &tcpip.ErrUnknownNICID{} } - _, _, err := nic.getNeighborLinkAddress(addr, localAddr, linkRes, onResolve) - return err + return nic.getLinkAddress(addr, localAddr, protocol, onResolve) } // Neighbors returns all IP to MAC address associations. -func (s *Stack) Neighbors(nicID tcpip.NICID) ([]NeighborEntry, *tcpip.Error) { +func (s *Stack) Neighbors(nicID tcpip.NICID, protocol tcpip.NetworkProtocolNumber) ([]NeighborEntry, tcpip.Error) { s.mu.RLock() nic, ok := s.nics[nicID] s.mu.RUnlock() if !ok { - return nil, tcpip.ErrUnknownNICID + return nil, &tcpip.ErrUnknownNICID{} } - return nic.neighbors() + return nic.neighbors(protocol) } // AddStaticNeighbor statically associates an IP address to a MAC address. 
-func (s *Stack) AddStaticNeighbor(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) *tcpip.Error { +func (s *Stack) AddStaticNeighbor(nicID tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, linkAddr tcpip.LinkAddress) tcpip.Error { s.mu.RLock() nic, ok := s.nics[nicID] s.mu.RUnlock() if !ok { - return tcpip.ErrUnknownNICID + return &tcpip.ErrUnknownNICID{} } - return nic.addStaticNeighbor(addr, linkAddr) + return nic.addStaticNeighbor(addr, protocol, linkAddr) } // RemoveNeighbor removes an IP to MAC address association previously created // either automically or by AddStaticNeighbor. Returns ErrBadAddress if there // is no association with the provided address. -func (s *Stack) RemoveNeighbor(nicID tcpip.NICID, addr tcpip.Address) *tcpip.Error { +func (s *Stack) RemoveNeighbor(nicID tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.Error { s.mu.RLock() nic, ok := s.nics[nicID] s.mu.RUnlock() if !ok { - return tcpip.ErrUnknownNICID + return &tcpip.ErrUnknownNICID{} } - return nic.removeNeighbor(addr) + return nic.removeNeighbor(protocol, addr) } // ClearNeighbors removes all IP to MAC address associations. -func (s *Stack) ClearNeighbors(nicID tcpip.NICID) *tcpip.Error { +func (s *Stack) ClearNeighbors(nicID tcpip.NICID, protocol tcpip.NetworkProtocolNumber) tcpip.Error { s.mu.RLock() nic, ok := s.nics[nicID] s.mu.RUnlock() if !ok { - return tcpip.ErrUnknownNICID + return &tcpip.ErrUnknownNICID{} } - return nic.clearNeighbors() + return nic.clearNeighbors(protocol) } // RegisterTransportEndpoint registers the given endpoint with the stack // transport dispatcher. Received packets that match the provided id will be // delivered to the given endpoint; specifying a nic is optional, but // nic-specific IDs have precedence over global ones. 
-func (s *Stack) RegisterTransportEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { +func (s *Stack) RegisterTransportEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error { return s.demux.registerEndpoint(netProtos, protocol, id, ep, flags, bindToDevice) } // CheckRegisterTransportEndpoint checks if an endpoint can be registered with // the stack transport dispatcher. -func (s *Stack) CheckRegisterTransportEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { +func (s *Stack) CheckRegisterTransportEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error { return s.demux.checkEndpoint(netProtos, protocol, id, flags, bindToDevice) } @@ -1672,7 +1679,7 @@ func (s *Stack) FindTransportEndpoint(netProto tcpip.NetworkProtocolNumber, tran // RegisterRawTransportEndpoint registers the given endpoint with the stack // transport dispatcher. Received packets that match the provided transport // protocol will be delivered to the given endpoint. 
-func (s *Stack) RegisterRawTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) *tcpip.Error { +func (s *Stack) RegisterRawTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) tcpip.Error { return s.demux.registerRawEndpoint(netProto, transProto, ep) } @@ -1782,7 +1789,7 @@ func (s *Stack) Resume() { // RegisterPacketEndpoint registers ep with the stack, causing it to receive // all traffic of the specified netProto on the given NIC. If nicID is 0, it // receives traffic from every NIC. -func (s *Stack) RegisterPacketEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) *tcpip.Error { +func (s *Stack) RegisterPacketEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) tcpip.Error { s.mu.Lock() defer s.mu.Unlock() @@ -1801,7 +1808,7 @@ func (s *Stack) RegisterPacketEndpoint(nicID tcpip.NICID, netProto tcpip.Network // Capture on a specific device. nic, ok := s.nics[nicID] if !ok { - return tcpip.ErrUnknownNICID + return &tcpip.ErrUnknownNICID{} } if err := nic.registerPacketEndpoint(netProto, ep); err != nil { return err @@ -1839,12 +1846,12 @@ func (s *Stack) unregisterPacketEndpointLocked(nicID tcpip.NICID, netProto tcpip // WritePacketToRemote writes a payload on the specified NIC using the provided // network protocol and remote link address. 
-func (s *Stack) WritePacketToRemote(nicID tcpip.NICID, remote tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, payload buffer.VectorisedView) *tcpip.Error { +func (s *Stack) WritePacketToRemote(nicID tcpip.NICID, remote tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, payload buffer.VectorisedView) tcpip.Error { s.mu.Lock() nic, ok := s.nics[nicID] s.mu.Unlock() if !ok { - return tcpip.ErrUnknownDevice + return &tcpip.ErrUnknownDevice{} } pkt := NewPacketBuffer(PacketBufferOptions{ ReserveHeaderBytes: int(nic.MaxHeaderLength()), @@ -1909,37 +1916,37 @@ func (s *Stack) RemoveTCPProbe() { } // JoinGroup joins the given multicast group on the given NIC. -func (s *Stack) JoinGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NICID, multicastAddr tcpip.Address) *tcpip.Error { +func (s *Stack) JoinGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NICID, multicastAddr tcpip.Address) tcpip.Error { s.mu.RLock() defer s.mu.RUnlock() if nic, ok := s.nics[nicID]; ok { return nic.joinGroup(protocol, multicastAddr) } - return tcpip.ErrUnknownNICID + return &tcpip.ErrUnknownNICID{} } // LeaveGroup leaves the given multicast group on the given NIC. -func (s *Stack) LeaveGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NICID, multicastAddr tcpip.Address) *tcpip.Error { +func (s *Stack) LeaveGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NICID, multicastAddr tcpip.Address) tcpip.Error { s.mu.RLock() defer s.mu.RUnlock() if nic, ok := s.nics[nicID]; ok { return nic.leaveGroup(protocol, multicastAddr) } - return tcpip.ErrUnknownNICID + return &tcpip.ErrUnknownNICID{} } // IsInGroup returns true if the NIC with ID nicID has joined the multicast // group multicastAddr. 
-func (s *Stack) IsInGroup(nicID tcpip.NICID, multicastAddr tcpip.Address) (bool, *tcpip.Error) { +func (s *Stack) IsInGroup(nicID tcpip.NICID, multicastAddr tcpip.Address) (bool, tcpip.Error) { s.mu.RLock() defer s.mu.RUnlock() if nic, ok := s.nics[nicID]; ok { return nic.isInGroup(multicastAddr), nil } - return false, tcpip.ErrUnknownNICID + return false, &tcpip.ErrUnknownNICID{} } // IPTables returns the stack's iptables. @@ -1979,45 +1986,45 @@ func (s *Stack) AllowICMPMessage() bool { // GetNetworkEndpoint returns the NetworkEndpoint with the specified protocol // number installed on the specified NIC. -func (s *Stack) GetNetworkEndpoint(nicID tcpip.NICID, proto tcpip.NetworkProtocolNumber) (NetworkEndpoint, *tcpip.Error) { +func (s *Stack) GetNetworkEndpoint(nicID tcpip.NICID, proto tcpip.NetworkProtocolNumber) (NetworkEndpoint, tcpip.Error) { s.mu.Lock() defer s.mu.Unlock() nic, ok := s.nics[nicID] if !ok { - return nil, tcpip.ErrUnknownNICID + return nil, &tcpip.ErrUnknownNICID{} } return nic.getNetworkEndpoint(proto), nil } // NUDConfigurations gets the per-interface NUD configurations. -func (s *Stack) NUDConfigurations(id tcpip.NICID) (NUDConfigurations, *tcpip.Error) { +func (s *Stack) NUDConfigurations(id tcpip.NICID, proto tcpip.NetworkProtocolNumber) (NUDConfigurations, tcpip.Error) { s.mu.RLock() nic, ok := s.nics[id] s.mu.RUnlock() if !ok { - return NUDConfigurations{}, tcpip.ErrUnknownNICID + return NUDConfigurations{}, &tcpip.ErrUnknownNICID{} } - return nic.nudConfigs() + return nic.nudConfigs(proto) } // SetNUDConfigurations sets the per-interface NUD configurations. // // Note, if c contains invalid NUD configuration values, it will be fixed to // use default values for the erroneous values. 
-func (s *Stack) SetNUDConfigurations(id tcpip.NICID, c NUDConfigurations) *tcpip.Error { +func (s *Stack) SetNUDConfigurations(id tcpip.NICID, proto tcpip.NetworkProtocolNumber, c NUDConfigurations) tcpip.Error { s.mu.RLock() nic, ok := s.nics[id] s.mu.RUnlock() if !ok { - return tcpip.ErrUnknownNICID + return &tcpip.ErrUnknownNICID{} } - return nic.setNUDConfigs(c) + return nic.setNUDConfigs(proto, c) } // Seed returns a 32 bit value that can be used as a seed value for port @@ -2056,7 +2063,7 @@ func generateRandInt64() int64 { } // FindNetworkEndpoint returns the network endpoint for the given address. -func (s *Stack) FindNetworkEndpoint(netProto tcpip.NetworkProtocolNumber, address tcpip.Address) (NetworkEndpoint, *tcpip.Error) { +func (s *Stack) FindNetworkEndpoint(netProto tcpip.NetworkProtocolNumber, address tcpip.Address) (NetworkEndpoint, tcpip.Error) { s.mu.RLock() defer s.mu.RUnlock() @@ -2068,7 +2075,7 @@ func (s *Stack) FindNetworkEndpoint(netProto tcpip.NetworkProtocolNumber, addres addressEndpoint.DecRef() return nic.getNetworkEndpoint(netProto), nil } - return nil, tcpip.ErrBadAddress + return nil, &tcpip.ErrBadAddress{} } // FindNICNameFromID returns the name of the NIC for the given NICID. diff --git a/pkg/tcpip/stack/stack_options.go b/pkg/tcpip/stack/stack_options.go index 92e70f94e..3066f4ffd 100644 --- a/pkg/tcpip/stack/stack_options.go +++ b/pkg/tcpip/stack/stack_options.go @@ -15,6 +15,8 @@ package stack import ( + "time" + "gvisor.dev/gvisor/pkg/tcpip" ) @@ -29,6 +31,10 @@ const ( // DefaultMaxBufferSize is the default maximum permitted size of a // send/receive buffer. DefaultMaxBufferSize = 4 << 20 // 4 MiB + + // defaultTCPInvalidRateLimit is the default value for + // stack.TCPInvalidRateLimit. 
+ defaultTCPInvalidRateLimit = 500 * time.Millisecond ) // ReceiveBufferSizeOption is used by stack.(Stack*).Option/SetOption to @@ -39,18 +45,22 @@ type ReceiveBufferSizeOption struct { Max int } +// TCPInvalidRateLimitOption is used by stack.(Stack*).Option/SetOption to get/set +// stack.tcpInvalidRateLimit. +type TCPInvalidRateLimitOption time.Duration + // SetOption allows setting stack wide options. -func (s *Stack) SetOption(option interface{}) *tcpip.Error { +func (s *Stack) SetOption(option interface{}) tcpip.Error { switch v := option.(type) { case tcpip.SendBufferSizeOption: // Make sure we don't allow lowering the buffer below minimum // required for stack to work. if v.Min < MinBufferSize { - return tcpip.ErrInvalidOptionValue + return &tcpip.ErrInvalidOptionValue{} } if v.Default < v.Min || v.Default > v.Max { - return tcpip.ErrInvalidOptionValue + return &tcpip.ErrInvalidOptionValue{} } s.mu.Lock() @@ -62,11 +72,11 @@ func (s *Stack) SetOption(option interface{}) *tcpip.Error { // Make sure we don't allow lowering the buffer below minimum // required for stack to work. if v.Min < MinBufferSize { - return tcpip.ErrInvalidOptionValue + return &tcpip.ErrInvalidOptionValue{} } if v.Default < v.Min || v.Default > v.Max { - return tcpip.ErrInvalidOptionValue + return &tcpip.ErrInvalidOptionValue{} } s.mu.Lock() @@ -74,13 +84,22 @@ func (s *Stack) SetOption(option interface{}) *tcpip.Error { s.mu.Unlock() return nil + case TCPInvalidRateLimitOption: + if v < 0 { + return &tcpip.ErrInvalidOptionValue{} + } + s.mu.Lock() + s.tcpInvalidRateLimit = time.Duration(v) + s.mu.Unlock() + return nil + default: - return tcpip.ErrUnknownProtocolOption + return &tcpip.ErrUnknownProtocolOption{} } } // Option allows retrieving stack wide options. 
-func (s *Stack) Option(option interface{}) *tcpip.Error { +func (s *Stack) Option(option interface{}) tcpip.Error { switch v := option.(type) { case *tcpip.SendBufferSizeOption: s.mu.RLock() @@ -94,7 +113,13 @@ func (s *Stack) Option(option interface{}) *tcpip.Error { s.mu.RUnlock() return nil + case *TCPInvalidRateLimitOption: + s.mu.RLock() + *v = TCPInvalidRateLimitOption(s.tcpInvalidRateLimit) + s.mu.RUnlock() + return nil + default: - return tcpip.ErrUnknownProtocolOption + return &tcpip.ErrUnknownProtocolOption{} } } diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 0f02f1d53..b641a4aaa 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -31,6 +31,7 @@ import ( "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/loopback" @@ -90,7 +91,7 @@ type fakeNetworkEndpoint struct { dispatcher stack.TransportDispatcher } -func (f *fakeNetworkEndpoint) Enable() *tcpip.Error { +func (f *fakeNetworkEndpoint) Enable() tcpip.Error { f.mu.Lock() defer f.mu.Unlock() f.mu.enabled = true @@ -137,12 +138,15 @@ func (f *fakeNetworkEndpoint) HandlePacket(pkt *stack.PacketBuffer) { return } pkt.Data.TrimFront(fakeNetHeaderLen) - f.dispatcher.DeliverTransportControlPacket( + f.dispatcher.DeliverTransportError( tcpip.Address(nb[srcAddrOffset:srcAddrOffset+1]), tcpip.Address(nb[dstAddrOffset:dstAddrOffset+1]), fakeNetNumber, tcpip.TransportProtocolNumber(nb[protocolNumberOffset]), - stack.ControlPortUnreachable, 0, pkt) + // Nothing checks the error. 
+ nil, /* transport error */ + pkt, + ) return } @@ -162,7 +166,7 @@ func (f *fakeNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumbe return f.proto.Number() } -func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { +func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) tcpip.Error { // Increment the sent packet count in the protocol descriptor. f.proto.sendPacketCount[int(r.RemoteAddress[0])%len(f.proto.sendPacketCount)]++ @@ -185,12 +189,12 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (*fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, *tcpip.Error) { +func (*fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, tcpip.Error) { panic("not implemented") } -func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error { - return tcpip.ErrNotSupported +func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) tcpip.Error { + return &tcpip.ErrNotSupported{} } func (f *fakeNetworkEndpoint) Close() { @@ -243,7 +247,7 @@ func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Addres return tcpip.Address(v[srcAddrOffset : srcAddrOffset+1]), tcpip.Address(v[dstAddrOffset : dstAddrOffset+1]) } -func (f *fakeNetworkProtocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCache, _ stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint { +func (f *fakeNetworkProtocol) NewEndpoint(nic stack.NetworkInterface, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint { e := 
&fakeNetworkEndpoint{ nic: nic, proto: f, @@ -253,23 +257,23 @@ func (f *fakeNetworkProtocol) NewEndpoint(nic stack.NetworkInterface, _ stack.Li return e } -func (f *fakeNetworkProtocol) SetOption(option tcpip.SettableNetworkProtocolOption) *tcpip.Error { +func (f *fakeNetworkProtocol) SetOption(option tcpip.SettableNetworkProtocolOption) tcpip.Error { switch v := option.(type) { case *tcpip.DefaultTTLOption: f.defaultTTL = uint8(*v) return nil default: - return tcpip.ErrUnknownProtocolOption + return &tcpip.ErrUnknownProtocolOption{} } } -func (f *fakeNetworkProtocol) Option(option tcpip.GettableNetworkProtocolOption) *tcpip.Error { +func (f *fakeNetworkProtocol) Option(option tcpip.GettableNetworkProtocolOption) tcpip.Error { switch v := option.(type) { case *tcpip.DefaultTTLOption: *v = tcpip.DefaultTTLOption(f.defaultTTL) return nil default: - return tcpip.ErrUnknownProtocolOption + return &tcpip.ErrUnknownProtocolOption{} } } @@ -418,7 +422,7 @@ func TestNetworkReceive(t *testing.T) { } } -func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) *tcpip.Error { +func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) tcpip.Error { r, err := s.FindRoute(0, "", addr, fakeNetNumber, false /* multicastLoop */) if err != nil { return err @@ -427,7 +431,7 @@ func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) *tcpip.Erro return send(r, payload) } -func send(r *stack.Route, payload buffer.View) *tcpip.Error { +func send(r *stack.Route, payload buffer.View) tcpip.Error { return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: int(r.MaxHeaderLength()), Data: payload.ToVectorisedView(), @@ -456,14 +460,14 @@ func testSend(t *testing.T, r *stack.Route, ep *channel.Endpoint, payload buffer } } -func testFailingSend(t *testing.T, r *stack.Route, ep *channel.Endpoint, payload buffer.View, wantErr 
*tcpip.Error) { +func testFailingSend(t *testing.T, r *stack.Route, ep *channel.Endpoint, payload buffer.View, wantErr tcpip.Error) { t.Helper() if gotErr := send(r, payload); gotErr != wantErr { t.Errorf("send failed: got = %s, want = %s ", gotErr, wantErr) } } -func testFailingSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) { +func testFailingSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View, wantErr tcpip.Error) { t.Helper() if gotErr := sendTo(s, addr, payload); gotErr != wantErr { t.Errorf("sendto failed: got = %s, want = %s ", gotErr, wantErr) @@ -600,8 +604,8 @@ func testRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr, func testNoRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr tcpip.Address) { _, err := s.FindRoute(nic, srcAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) - if err != tcpip.ErrNoRoute { - t.Fatalf("FindRoute returned unexpected error, got = %v, want = %s", err, tcpip.ErrNoRoute) + if _, ok := err.(*tcpip.ErrNoRoute); !ok { + t.Fatalf("FindRoute returned unexpected error, got = %v, want = %s", err, &tcpip.ErrNoRoute{}) } } @@ -649,8 +653,9 @@ func TestDisableUnknownNIC(t *testing.T) { NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) - if err := s.DisableNIC(1); err != tcpip.ErrUnknownNICID { - t.Fatalf("got s.DisableNIC(1) = %v, want = %s", err, tcpip.ErrUnknownNICID) + err := s.DisableNIC(1) + if _, ok := err.(*tcpip.ErrUnknownNICID); !ok { + t.Fatalf("got s.DisableNIC(1) = %v, want = %s", err, &tcpip.ErrUnknownNICID{}) } } @@ -708,8 +713,9 @@ func TestRemoveUnknownNIC(t *testing.T) { NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) - if err := s.RemoveNIC(1); err != tcpip.ErrUnknownNICID { - t.Fatalf("got s.RemoveNIC(1) = %v, want = %s", err, tcpip.ErrUnknownNICID) + err := s.RemoveNIC(1) + if _, ok := err.(*tcpip.ErrUnknownNICID); 
!ok { + t.Fatalf("got s.RemoveNIC(1) = %v, want = %s", err, &tcpip.ErrUnknownNICID{}) } } @@ -752,8 +758,8 @@ func TestRemoveNIC(t *testing.T) { func TestRouteWithDownNIC(t *testing.T) { tests := []struct { name string - downFn func(s *stack.Stack, nicID tcpip.NICID) *tcpip.Error - upFn func(s *stack.Stack, nicID tcpip.NICID) *tcpip.Error + downFn func(s *stack.Stack, nicID tcpip.NICID) tcpip.Error + upFn func(s *stack.Stack, nicID tcpip.NICID) tcpip.Error }{ { name: "Disabled NIC", @@ -911,15 +917,15 @@ func TestRouteWithDownNIC(t *testing.T) { if err := test.downFn(s, nicID1); err != nil { t.Fatalf("test.downFn(_, %d): %s", nicID1, err) } - testFailingSend(t, r1, ep1, buf, tcpip.ErrInvalidEndpointState) + testFailingSend(t, r1, ep1, buf, &tcpip.ErrInvalidEndpointState{}) testSend(t, r2, ep2, buf) // Writes with Routes that use NIC2 after being brought down should fail. if err := test.downFn(s, nicID2); err != nil { t.Fatalf("test.downFn(_, %d): %s", nicID2, err) } - testFailingSend(t, r1, ep1, buf, tcpip.ErrInvalidEndpointState) - testFailingSend(t, r2, ep2, buf, tcpip.ErrInvalidEndpointState) + testFailingSend(t, r1, ep1, buf, &tcpip.ErrInvalidEndpointState{}) + testFailingSend(t, r2, ep2, buf, &tcpip.ErrInvalidEndpointState{}) if upFn := test.upFn; upFn != nil { // Writes with Routes that use NIC1 after being brought up should @@ -932,7 +938,7 @@ func TestRouteWithDownNIC(t *testing.T) { t.Fatalf("test.upFn(_, %d): %s", nicID1, err) } testSend(t, r1, ep1, buf) - testFailingSend(t, r2, ep2, buf, tcpip.ErrInvalidEndpointState) + testFailingSend(t, r2, ep2, buf, &tcpip.ErrInvalidEndpointState{}) } }) } @@ -1057,11 +1063,12 @@ func TestAddressRemoval(t *testing.T) { t.Fatal("RemoveAddress failed:", err) } testFailingRecv(t, fakeNet, localAddrByte, ep, buf) - testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) + testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) // Check that removing the same address fails. 
- if err := s.RemoveAddress(1, localAddr); err != tcpip.ErrBadLocalAddress { - t.Fatalf("RemoveAddress returned unexpected error, got = %v, want = %s", err, tcpip.ErrBadLocalAddress) + err := s.RemoveAddress(1, localAddr) + if _, ok := err.(*tcpip.ErrBadLocalAddress); !ok { + t.Fatalf("RemoveAddress returned unexpected error, got = %v, want = %s", err, &tcpip.ErrBadLocalAddress{}) } } @@ -1108,12 +1115,15 @@ func TestAddressRemovalWithRouteHeld(t *testing.T) { t.Fatal("RemoveAddress failed:", err) } testFailingRecv(t, fakeNet, localAddrByte, ep, buf) - testFailingSend(t, r, ep, nil, tcpip.ErrInvalidEndpointState) - testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) + testFailingSend(t, r, ep, nil, &tcpip.ErrInvalidEndpointState{}) + testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) // Check that removing the same address fails. - if err := s.RemoveAddress(1, localAddr); err != tcpip.ErrBadLocalAddress { - t.Fatalf("RemoveAddress returned unexpected error, got = %v, want = %s", err, tcpip.ErrBadLocalAddress) + { + err := s.RemoveAddress(1, localAddr) + if _, ok := err.(*tcpip.ErrBadLocalAddress); !ok { + t.Fatalf("RemoveAddress returned unexpected error, got = %v, want = %s", err, &tcpip.ErrBadLocalAddress{}) + } } } @@ -1207,7 +1217,7 @@ func TestEndpointExpiration(t *testing.T) { // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. // testSendTo(t, s, remoteAddr, ep, nil) } else { - testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) + testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) } // 2. Add Address, everything should work. @@ -1235,7 +1245,7 @@ func TestEndpointExpiration(t *testing.T) { // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. // testSendTo(t, s, remoteAddr, ep, nil) } else { - testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) + testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) } // 4. 
Add Address back, everything should work again. @@ -1274,8 +1284,8 @@ func TestEndpointExpiration(t *testing.T) { testSend(t, r, ep, nil) testSendTo(t, s, remoteAddr, ep, nil) } else { - testFailingSend(t, r, ep, nil, tcpip.ErrInvalidEndpointState) - testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) + testFailingSend(t, r, ep, nil, &tcpip.ErrInvalidEndpointState{}) + testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) } // 7. Add Address back, everything should work again. @@ -1311,7 +1321,7 @@ func TestEndpointExpiration(t *testing.T) { // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. // testSendTo(t, s, remoteAddr, ep, nil) } else { - testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) + testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) } }) } @@ -1354,8 +1364,8 @@ func TestPromiscuousMode(t *testing.T) { // Check that we can't get a route as there is no local address. _, err := s.FindRoute(0, "", "\x02", fakeNetNumber, false /* multicastLoop */) - if err != tcpip.ErrNoRoute { - t.Fatalf("FindRoute returned unexpected error: got = %v, want = %s", err, tcpip.ErrNoRoute) + if _, ok := err.(*tcpip.ErrNoRoute); !ok { + t.Fatalf("FindRoute returned unexpected error: got = %v, want = %s", err, &tcpip.ErrNoRoute{}) } // Set promiscuous mode to false, then check that packet can't be @@ -1561,7 +1571,7 @@ func TestSpoofingNoAddress(t *testing.T) { t.Errorf("FindRoute succeeded with route %+v when it should have failed", r) } // Sending a packet fails. - testFailingSendTo(t, s, dstAddr, ep, nil, tcpip.ErrNoRoute) + testFailingSendTo(t, s, dstAddr, ep, nil, &tcpip.ErrNoRoute{}) // With address spoofing enabled, FindRoute permits any address to be used // as the source. @@ -1611,8 +1621,11 @@ func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) { s.SetRouteTable([]tcpip.Route{}) // If there is no endpoint, it won't work. 
- if _, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */); err != tcpip.ErrNetworkUnreachable { - t.Fatalf("got FindRoute(1, %s, %s, %d) = %s, want = %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable) + { + _, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) + if _, ok := err.(*tcpip.ErrNetworkUnreachable); !ok { + t.Fatalf("got FindRoute(1, %s, %s, %d) = %s, want = %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, &tcpip.ErrNetworkUnreachable{}) + } } protoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: tcpip.AddressWithPrefix{header.IPv4Any, 0}} @@ -1631,8 +1644,11 @@ func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) { } // If the NIC doesn't exist, it won't work. - if _, err := s.FindRoute(2, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */); err != tcpip.ErrNetworkUnreachable { - t.Fatalf("got FindRoute(2, %v, %v, %d) = %v want = %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable) + { + _, err := s.FindRoute(2, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) + if _, ok := err.(*tcpip.ErrNetworkUnreachable); !ok { + t.Fatalf("got FindRoute(2, %v, %v, %d) = %v want = %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, &tcpip.ErrNetworkUnreachable{}) + } } } @@ -1774,9 +1790,9 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) { anyAddr = header.IPv6Any } - want := tcpip.ErrNetworkUnreachable + var want tcpip.Error = &tcpip.ErrNetworkUnreachable{} if tc.routeNeeded { - want = tcpip.ErrNoRoute + want = &tcpip.ErrNoRoute{} } // If there is no endpoint, it won't work. 
@@ -1790,8 +1806,8 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) { if r, err := s.FindRoute(1, anyAddr, tc.address, fakeNetNumber, false /* multicastLoop */); tc.routeNeeded { // Route table is empty but we need a route, this should cause an error. - if err != tcpip.ErrNoRoute { - t.Fatalf("got FindRoute(1, %v, %v, %v) = %v, want = %v", anyAddr, tc.address, fakeNetNumber, err, tcpip.ErrNoRoute) + if _, ok := err.(*tcpip.ErrNoRoute); !ok { + t.Fatalf("got FindRoute(1, %v, %v, %v) = %v, want = %v", anyAddr, tc.address, fakeNetNumber, err, &tcpip.ErrNoRoute{}) } } else { if err != nil { @@ -2115,7 +2131,7 @@ func TestCreateNICWithOptions(t *testing.T) { type callArgsAndExpect struct { nicID tcpip.NICID opts stack.NICOptions - err *tcpip.Error + err tcpip.Error } tests := []struct { @@ -2133,7 +2149,7 @@ func TestCreateNICWithOptions(t *testing.T) { { nicID: tcpip.NICID(1), opts: stack.NICOptions{Name: "eth2"}, - err: tcpip.ErrDuplicateNICID, + err: &tcpip.ErrDuplicateNICID{}, }, }, }, @@ -2148,7 +2164,7 @@ func TestCreateNICWithOptions(t *testing.T) { { nicID: tcpip.NICID(2), opts: stack.NICOptions{Name: "lo"}, - err: tcpip.ErrDuplicateNICID, + err: &tcpip.ErrDuplicateNICID{}, }, }, }, @@ -2178,7 +2194,7 @@ func TestCreateNICWithOptions(t *testing.T) { { nicID: tcpip.NICID(1), opts: stack.NICOptions{}, - err: tcpip.ErrDuplicateNICID, + err: &tcpip.ErrDuplicateNICID{}, }, }, }, @@ -3297,14 +3313,14 @@ func TestStackReceiveBufferSizeOption(t *testing.T) { testCases := []struct { name string rs stack.ReceiveBufferSizeOption - err *tcpip.Error + err tcpip.Error }{ // Invalid configurations. 
- {"min_below_zero", stack.ReceiveBufferSizeOption{Min: -1, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue}, - {"min_zero", stack.ReceiveBufferSizeOption{Min: 0, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue}, - {"default_below_min", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin - 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue}, - {"default_above_max", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin}, tcpip.ErrInvalidOptionValue}, - {"max_below_min", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue}, + {"min_below_zero", stack.ReceiveBufferSizeOption{Min: -1, Default: sMin, Max: sMin}, &tcpip.ErrInvalidOptionValue{}}, + {"min_zero", stack.ReceiveBufferSizeOption{Min: 0, Default: sMin, Max: sMin}, &tcpip.ErrInvalidOptionValue{}}, + {"default_below_min", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin - 1, Max: sMin - 1}, &tcpip.ErrInvalidOptionValue{}}, + {"default_above_max", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin}, &tcpip.ErrInvalidOptionValue{}}, + {"max_below_min", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin - 1}, &tcpip.ErrInvalidOptionValue{}}, // Valid Configurations {"in_ascending_order", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 2}, nil}, @@ -3337,14 +3353,14 @@ func TestStackSendBufferSizeOption(t *testing.T) { testCases := []struct { name string ss tcpip.SendBufferSizeOption - err *tcpip.Error + err tcpip.Error }{ // Invalid configurations. 
- {"min_below_zero", tcpip.SendBufferSizeOption{Min: -1, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue}, - {"min_zero", tcpip.SendBufferSizeOption{Min: 0, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue}, - {"default_below_min", tcpip.SendBufferSizeOption{Min: 0, Default: sMin - 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue}, - {"default_above_max", tcpip.SendBufferSizeOption{Min: 0, Default: sMin + 1, Max: sMin}, tcpip.ErrInvalidOptionValue}, - {"max_below_min", tcpip.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue}, + {"min_below_zero", tcpip.SendBufferSizeOption{Min: -1, Default: sMin, Max: sMin}, &tcpip.ErrInvalidOptionValue{}}, + {"min_zero", tcpip.SendBufferSizeOption{Min: 0, Default: sMin, Max: sMin}, &tcpip.ErrInvalidOptionValue{}}, + {"default_below_min", tcpip.SendBufferSizeOption{Min: 0, Default: sMin - 1, Max: sMin - 1}, &tcpip.ErrInvalidOptionValue{}}, + {"default_above_max", tcpip.SendBufferSizeOption{Min: 0, Default: sMin + 1, Max: sMin}, &tcpip.ErrInvalidOptionValue{}}, + {"max_below_min", tcpip.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin - 1}, &tcpip.ErrInvalidOptionValue{}}, // Valid Configurations {"in_ascending_order", tcpip.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 2}, nil}, @@ -3356,11 +3372,12 @@ func TestStackSendBufferSizeOption(t *testing.T) { t.Run(tc.name, func(t *testing.T) { s := stack.New(stack.Options{}) defer s.Close() - if err := s.SetOption(tc.ss); err != tc.err { - t.Fatalf("s.SetOption(%+v) = %v, want: %v", tc.ss, err, tc.err) + err := s.SetOption(tc.ss) + if diff := cmp.Diff(tc.err, err); diff != "" { + t.Fatalf("unexpected error from s.SetOption(%+v), (-want, +got):\n%s", tc.ss, diff) } - var ss tcpip.SendBufferSizeOption if tc.err == nil { + var ss tcpip.SendBufferSizeOption if err := s.Option(&ss); err != nil { t.Fatalf("s.Option(%+v) = %v, want: nil", ss, err) } @@ -3919,7 +3936,7 @@ func TestFindRouteWithForwarding(t 
*testing.T) { addrNIC tcpip.NICID localAddr tcpip.Address - findRouteErr *tcpip.Error + findRouteErr tcpip.Error dependentOnForwarding bool }{ { @@ -3928,7 +3945,7 @@ func TestFindRouteWithForwarding(t *testing.T) { forwardingEnabled: false, addrNIC: nicID1, localAddr: fakeNetCfg.nic2Addr, - findRouteErr: tcpip.ErrNoRoute, + findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, { @@ -3937,7 +3954,7 @@ func TestFindRouteWithForwarding(t *testing.T) { forwardingEnabled: true, addrNIC: nicID1, localAddr: fakeNetCfg.nic2Addr, - findRouteErr: tcpip.ErrNoRoute, + findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, { @@ -3946,7 +3963,7 @@ func TestFindRouteWithForwarding(t *testing.T) { forwardingEnabled: false, addrNIC: nicID1, localAddr: fakeNetCfg.nic1Addr, - findRouteErr: tcpip.ErrNoRoute, + findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, { @@ -3982,7 +3999,7 @@ func TestFindRouteWithForwarding(t *testing.T) { forwardingEnabled: false, addrNIC: nicID2, localAddr: fakeNetCfg.nic1Addr, - findRouteErr: tcpip.ErrNoRoute, + findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, { @@ -3991,7 +4008,7 @@ func TestFindRouteWithForwarding(t *testing.T) { forwardingEnabled: true, addrNIC: nicID2, localAddr: fakeNetCfg.nic1Addr, - findRouteErr: tcpip.ErrNoRoute, + findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, { @@ -4015,7 +4032,7 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg: fakeNetCfg, forwardingEnabled: false, localAddr: fakeNetCfg.nic1Addr, - findRouteErr: tcpip.ErrNoRoute, + findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, { @@ -4031,7 +4048,7 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg: ipv6LinkLocalNIC1WithGlobalRemote, forwardingEnabled: false, addrNIC: nicID1, - findRouteErr: tcpip.ErrNoRoute, + findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, { @@ -4039,7 +4056,7 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg: 
ipv6LinkLocalNIC1WithGlobalRemote, forwardingEnabled: true, addrNIC: nicID1, - findRouteErr: tcpip.ErrNoRoute, + findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, { @@ -4047,7 +4064,7 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg: ipv6LinkLocalNIC1WithGlobalRemote, forwardingEnabled: false, localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic1Addr, - findRouteErr: tcpip.ErrNoRoute, + findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, { @@ -4055,7 +4072,7 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg: ipv6LinkLocalNIC1WithGlobalRemote, forwardingEnabled: true, localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic1Addr, - findRouteErr: tcpip.ErrNoRoute, + findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, { @@ -4087,7 +4104,7 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg: ipv6GlobalNIC1WithLinkLocalRemote, forwardingEnabled: false, localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic1Addr, - findRouteErr: tcpip.ErrNetworkUnreachable, + findRouteErr: &tcpip.ErrNetworkUnreachable{}, dependentOnForwarding: false, }, { @@ -4095,7 +4112,7 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg: ipv6GlobalNIC1WithLinkLocalRemote, forwardingEnabled: true, localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic1Addr, - findRouteErr: tcpip.ErrNetworkUnreachable, + findRouteErr: &tcpip.ErrNetworkUnreachable{}, dependentOnForwarding: false, }, { @@ -4103,7 +4120,7 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote, forwardingEnabled: false, localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1Addr, - findRouteErr: tcpip.ErrNetworkUnreachable, + findRouteErr: &tcpip.ErrNetworkUnreachable{}, dependentOnForwarding: false, }, { @@ -4111,7 +4128,7 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote, forwardingEnabled: true, localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1Addr, - findRouteErr: 
tcpip.ErrNetworkUnreachable, + findRouteErr: &tcpip.ErrNetworkUnreachable{}, dependentOnForwarding: false, }, { @@ -4166,8 +4183,8 @@ func TestFindRouteWithForwarding(t *testing.T) { if r != nil { defer r.Release() } - if err != test.findRouteErr { - t.Fatalf("FindRoute(%d, %s, %s, %d, false) = %s, want = %s", test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, err, test.findRouteErr) + if diff := cmp.Diff(test.findRouteErr, err); diff != "" { + t.Fatalf("unexpected error from FindRoute(%d, %s, %s, %d, false), (-want, +got):\n%s", test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, diff) } if test.findRouteErr != nil { @@ -4214,8 +4231,11 @@ func TestFindRouteWithForwarding(t *testing.T) { if err := s.SetForwarding(test.netCfg.proto, false); err != nil { t.Fatalf("SetForwarding(%d, false): %s", test.netCfg.proto, err) } - if err := send(r, data); err != tcpip.ErrInvalidEndpointState { - t.Fatalf("got send(_, _) = %s, want = %s", err, tcpip.ErrInvalidEndpointState) + { + err := send(r, data) + if _, ok := err.(*tcpip.ErrInvalidEndpointState); !ok { + t.Fatalf("got send(_, _) = %s, want = %s", err, &tcpip.ErrInvalidEndpointState{}) + } } if n := ep1.Drain(); n != 0 { t.Errorf("got %d unexpected packets from ep1", n) @@ -4277,8 +4297,9 @@ func TestWritePacketToRemote(t *testing.T) { } t.Run("InvalidNICID", func(t *testing.T) { - if got, want := s.WritePacketToRemote(234, linkAddr2, header.IPv4ProtocolNumber, buffer.View([]byte{1}).ToVectorisedView()), tcpip.ErrUnknownDevice; got != want { - t.Fatalf("s.WritePacketToRemote(_, _, _, _) = %s, want = %s", got, want) + err := s.WritePacketToRemote(234, linkAddr2, header.IPv4ProtocolNumber, buffer.View([]byte{1}).ToVectorisedView()) + if _, ok := err.(*tcpip.ErrUnknownDevice); !ok { + t.Fatalf("s.WritePacketToRemote(_, _, _, _) = %s, want = %s", err, &tcpip.ErrUnknownDevice{}) } pkt, ok := e.Read() if got, want := ok, false; got != want { @@ -4296,9 +4317,11 @@ func 
TestClearNeighborCacheOnNICDisable(t *testing.T) { linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") ) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, UseNeighborCache: true, + Clock: clock, }) e := channel.New(0, 0, "") e.LinkEPCapabilities |= stack.CapabilityResolutionRequired @@ -4306,36 +4329,56 @@ func TestClearNeighborCacheOnNICDisable(t *testing.T) { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddStaticNeighbor(nicID, ipv4Addr, linkAddr); err != nil { - t.Fatalf("s.AddStaticNeighbor(%d, %s, %s): %s", nicID, ipv4Addr, linkAddr, err) - } - if err := s.AddStaticNeighbor(nicID, ipv6Addr, linkAddr); err != nil { - t.Fatalf("s.AddStaticNeighbor(%d, %s, %s): %s", nicID, ipv6Addr, linkAddr, err) + addrs := []struct { + proto tcpip.NetworkProtocolNumber + addr tcpip.Address + }{ + { + proto: ipv4.ProtocolNumber, + addr: ipv4Addr, + }, + { + proto: ipv6.ProtocolNumber, + addr: ipv6Addr, + }, } - if neighbors, err := s.Neighbors(nicID); err != nil { - t.Fatalf("s.Neighbors(%d): %s", nicID, err) - } else if len(neighbors) != 2 { - t.Fatalf("got len(neighbors) = %d, want = 2; neighbors = %#v", len(neighbors), neighbors) + for _, addr := range addrs { + if err := s.AddStaticNeighbor(nicID, addr.proto, addr.addr, linkAddr); err != nil { + t.Fatalf("s.AddStaticNeighbor(%d, %d, %s, %s): %s", nicID, addr.proto, addr.addr, linkAddr, err) + } + + if neighbors, err := s.Neighbors(nicID, addr.proto); err != nil { + t.Fatalf("s.Neighbors(%d, %d): %s", nicID, addr.proto, err) + } else if diff := cmp.Diff( + []stack.NeighborEntry{{Addr: addr.addr, LinkAddr: linkAddr, State: stack.Static, UpdatedAtNanos: clock.NowNanoseconds()}}, + neighbors, + ); diff != "" { + t.Fatalf("proto=%d neighbors mismatch (-want +got):\n%s", addr.proto, diff) + } } // Disabling the NIC should clear the neighbor table. 
if err := s.DisableNIC(nicID); err != nil { t.Fatalf("s.DisableNIC(%d): %s", nicID, err) } - if neighbors, err := s.Neighbors(nicID); err != nil { - t.Fatalf("s.Neighbors(%d): %s", nicID, err) - } else if len(neighbors) != 0 { - t.Fatalf("got len(neighbors) = %d, want = 0; neighbors = %#v", len(neighbors), neighbors) + for _, addr := range addrs { + if neighbors, err := s.Neighbors(nicID, addr.proto); err != nil { + t.Fatalf("s.Neighbors(%d, %d): %s", nicID, addr.proto, err) + } else if len(neighbors) != 0 { + t.Fatalf("got proto=%d len(neighbors) = %d, want = 0; neighbors = %#v", addr.proto, len(neighbors), neighbors) + } } // Enabling the NIC should have an empty neighbor table. if err := s.EnableNIC(nicID); err != nil { t.Fatalf("s.EnableNIC(%d): %s", nicID, err) } - if neighbors, err := s.Neighbors(nicID); err != nil { - t.Fatalf("s.Neighbors(%d): %s", nicID, err) - } else if len(neighbors) != 0 { - t.Fatalf("got len(neighbors) = %d, want = 0; neighbors = %#v", len(neighbors), neighbors) + for _, addr := range addrs { + if neighbors, err := s.Neighbors(nicID, addr.proto); err != nil { + t.Fatalf("s.Neighbors(%d, %d): %s", nicID, addr.proto, err) + } else if len(neighbors) != 0 { + t.Fatalf("got proto=%d len(neighbors) = %d, want = 0; neighbors = %#v", addr.proto, len(neighbors), neighbors) + } } } @@ -4352,11 +4395,17 @@ func TestGetLinkAddressErrors(t *testing.T) { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.GetLinkAddress(unknownNICID, "", "", ipv4.ProtocolNumber, nil); err != tcpip.ErrUnknownNICID { - t.Errorf("got s.GetLinkAddress(%d, '', '', %d, nil) = %s, want = %s", unknownNICID, ipv4.ProtocolNumber, err, tcpip.ErrUnknownNICID) + { + err := s.GetLinkAddress(unknownNICID, "", "", ipv4.ProtocolNumber, nil) + if _, ok := err.(*tcpip.ErrUnknownNICID); !ok { + t.Errorf("got s.GetLinkAddress(%d, '', '', %d, nil) = %s, want = %s", unknownNICID, ipv4.ProtocolNumber, err, &tcpip.ErrUnknownNICID{}) + } } - if err := s.GetLinkAddress(nicID, "", 
"", ipv4.ProtocolNumber, nil); err != tcpip.ErrNotSupported { - t.Errorf("got s.GetLinkAddress(%d, '', '', %d, nil) = %s, want = %s", unknownNICID, ipv4.ProtocolNumber, err, tcpip.ErrNotSupported) + { + err := s.GetLinkAddress(nicID, "", "", ipv4.ProtocolNumber, nil) + if _, ok := err.(*tcpip.ErrNotSupported); !ok { + t.Errorf("got s.GetLinkAddress(%d, '', '', %d, nil) = %s, want = %s", unknownNICID, ipv4.ProtocolNumber, err, &tcpip.ErrNotSupported{}) + } } } @@ -4368,7 +4417,9 @@ func TestStaticGetLinkAddress(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, }) - if err := s.CreateNIC(nicID, channel.New(0, 0, "")); err != nil { + e := channel.New(0, 0, "") + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index 07b2818d2..7d8d0851e 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -182,9 +182,8 @@ func (epsByNIC *endpointsByNIC) handlePacket(id TransportEndpointID, pkt *Packet epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. } -// handleControlPacket delivers a control packet to the transport endpoint -// identified by id. -func (epsByNIC *endpointsByNIC) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, pkt *PacketBuffer) { +// handleError delivers an error to the transport endpoint identified by id. +func (epsByNIC *endpointsByNIC) handleError(n *NIC, id TransportEndpointID, transErr TransportError, pkt *PacketBuffer) { epsByNIC.mu.RLock() defer epsByNIC.mu.RUnlock() @@ -200,12 +199,12 @@ func (epsByNIC *endpointsByNIC) handleControlPacket(n *NIC, id TransportEndpoint // broadcast like we are doing with handlePacket above? 
// multiPortEndpoints are guaranteed to have at least one element. - selectEndpoint(id, mpep, epsByNIC.seed).HandleControlPacket(typ, extra, pkt) + selectEndpoint(id, mpep, epsByNIC.seed).HandleError(transErr, pkt) } // registerEndpoint returns true if it succeeds. It fails and returns // false if ep already has an element with the same key. -func (epsByNIC *endpointsByNIC) registerEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, t TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { +func (epsByNIC *endpointsByNIC) registerEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, t TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error { epsByNIC.mu.Lock() defer epsByNIC.mu.Unlock() @@ -222,7 +221,7 @@ func (epsByNIC *endpointsByNIC) registerEndpoint(d *transportDemuxer, netProto t return multiPortEp.singleRegisterEndpoint(t, flags) } -func (epsByNIC *endpointsByNIC) checkEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { +func (epsByNIC *endpointsByNIC) checkEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error { epsByNIC.mu.RLock() defer epsByNIC.mu.RUnlock() @@ -294,7 +293,7 @@ func newTransportDemuxer(stack *Stack) *transportDemuxer { // registerEndpoint registers the given endpoint with the dispatcher such that // packets that match the endpoint ID are delivered to it. 
-func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { +func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error { for i, n := range netProtos { if err := d.singleRegisterEndpoint(n, protocol, id, ep, flags, bindToDevice); err != nil { d.unregisterEndpoint(netProtos[:i], protocol, id, ep, flags, bindToDevice) @@ -306,7 +305,7 @@ func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNum } // checkEndpoint checks if an endpoint can be registered with the dispatcher. -func (d *transportDemuxer) checkEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { +func (d *transportDemuxer) checkEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error { for _, n := range netProtos { if err := d.singleCheckEndpoint(n, protocol, id, flags, bindToDevice); err != nil { return err @@ -403,7 +402,7 @@ func (ep *multiPortEndpoint) handlePacketAll(id TransportEndpointID, pkt *Packet // singleRegisterEndpoint tries to add an endpoint to the multiPortEndpoint // list. The list might be empty already. 
-func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, flags ports.Flags) *tcpip.Error { +func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, flags ports.Flags) tcpip.Error { ep.mu.Lock() defer ep.mu.Unlock() @@ -412,7 +411,7 @@ func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, flags p if len(ep.endpoints) != 0 { // If it was previously bound, we need to check if we can bind again. if ep.flags.TotalRefs() > 0 && bits&ep.flags.IntersectionRefs() == 0 { - return tcpip.ErrPortInUse + return &tcpip.ErrPortInUse{} } } @@ -422,7 +421,7 @@ func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, flags p return nil } -func (ep *multiPortEndpoint) singleCheckEndpoint(flags ports.Flags) *tcpip.Error { +func (ep *multiPortEndpoint) singleCheckEndpoint(flags ports.Flags) tcpip.Error { ep.mu.RLock() defer ep.mu.RUnlock() @@ -431,7 +430,7 @@ func (ep *multiPortEndpoint) singleCheckEndpoint(flags ports.Flags) *tcpip.Error if len(ep.endpoints) != 0 { // If it was previously bound, we need to check if we can bind again. if ep.flags.TotalRefs() > 0 && bits&ep.flags.IntersectionRefs() == 0 { - return tcpip.ErrPortInUse + return &tcpip.ErrPortInUse{} } } @@ -456,7 +455,7 @@ func (ep *multiPortEndpoint) unregisterEndpoint(t TransportEndpoint, flags ports return len(ep.endpoints) == 0 } -func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { +func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error { if id.RemotePort != 0 { // SO_REUSEPORT only applies to bound/listening endpoints. 
flags.LoadBalanced = false @@ -464,7 +463,7 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol eps, ok := d.protocol[protocolIDs{netProto, protocol}] if !ok { - return tcpip.ErrUnknownProtocol + return &tcpip.ErrUnknownProtocol{} } eps.mu.Lock() @@ -482,7 +481,7 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol return epsByNIC.registerEndpoint(d, netProto, protocol, ep, flags, bindToDevice) } -func (d *transportDemuxer) singleCheckEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { +func (d *transportDemuxer) singleCheckEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error { if id.RemotePort != 0 { // SO_REUSEPORT only applies to bound/listening endpoints. flags.LoadBalanced = false @@ -490,7 +489,7 @@ func (d *transportDemuxer) singleCheckEndpoint(netProto tcpip.NetworkProtocolNum eps, ok := d.protocol[protocolIDs{netProto, protocol}] if !ok { - return tcpip.ErrUnknownProtocol + return &tcpip.ErrUnknownProtocol{} } eps.mu.RLock() @@ -596,9 +595,11 @@ func (d *transportDemuxer) deliverRawPacket(protocol tcpip.TransportProtocolNumb return foundRaw } -// deliverControlPacket attempts to deliver the given control packet. Returns -// true if it found an endpoint, false otherwise. -func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt *PacketBuffer, id TransportEndpointID) bool { +// deliverError attempts to deliver the given error to the appropriate transport +// endpoint. +// +// Returns true if the error was delivered. 
+func (d *transportDemuxer) deliverError(n *NIC, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, transErr TransportError, pkt *PacketBuffer, id TransportEndpointID) bool { eps, ok := d.protocol[protocolIDs{net, trans}] if !ok { return false @@ -611,7 +612,7 @@ func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtoco return false } - ep.handleControlPacket(n, id, typ, extra, pkt) + ep.handleError(n, id, transErr, pkt) return true } @@ -649,10 +650,10 @@ func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolN // that packets of the appropriate protocol are delivered to it. A single // packet can be sent to one or more raw endpoints along with a non-raw // endpoint. -func (d *transportDemuxer) registerRawEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) *tcpip.Error { +func (d *transportDemuxer) registerRawEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) tcpip.Error { eps, ok := d.protocol[protocolIDs{netProto, transProto}] if !ok { - return tcpip.ErrNotSupported + return &tcpip.ErrNotSupported{} } eps.mu.Lock() diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go index de4b5fbdc..10cbbe589 100644 --- a/pkg/tcpip/stack/transport_demuxer_test.go +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -175,9 +175,9 @@ func TestTransportDemuxerRegister(t *testing.T) { for _, test := range []struct { name string proto tcpip.NetworkProtocolNumber - want *tcpip.Error + want tcpip.Error }{ - {"failure", ipv6.ProtocolNumber, tcpip.ErrUnknownProtocol}, + {"failure", ipv6.ProtocolNumber, &tcpip.ErrUnknownProtocol{}}, {"success", ipv4.ProtocolNumber, nil}, } { t.Run(test.name, func(t *testing.T) { @@ -294,7 +294,7 @@ func TestBindToDeviceDistribution(t *testing.T) { defer wq.EventUnregister(&we) defer close(ch) - var err *tcpip.Error + var 
err tcpip.Error ep, err := c.s.NewEndpoint(udp.ProtocolNumber, netProtoNum, &wq) if err != nil { t.Fatalf("NewEndpoint failed: %s", err) diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index c49427c4c..bebf4e6b5 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -87,18 +87,18 @@ func (*fakeTransportEndpoint) Readiness(mask waiter.EventMask) waiter.EventMask return mask } -func (*fakeTransportEndpoint) Read(io.Writer, tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) { +func (*fakeTransportEndpoint) Read(io.Writer, tcpip.ReadOptions) (tcpip.ReadResult, tcpip.Error) { return tcpip.ReadResult{}, nil } -func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) { +func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) { if len(f.route.RemoteAddress) == 0 { - return 0, tcpip.ErrNoRoute + return 0, &tcpip.ErrNoRoute{} } v := make([]byte, p.Len()) if _, err := io.ReadFull(p, v); err != nil { - return 0, tcpip.ErrBadBuffer + return 0, &tcpip.ErrBadBuffer{} } pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -114,37 +114,37 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions } // SetSockOpt sets a socket option. Currently not supported. -func (*fakeTransportEndpoint) SetSockOpt(tcpip.SettableSocketOption) *tcpip.Error { - return tcpip.ErrInvalidEndpointState +func (*fakeTransportEndpoint) SetSockOpt(tcpip.SettableSocketOption) tcpip.Error { + return &tcpip.ErrInvalidEndpointState{} } // SetSockOptInt sets a socket option. Currently not supported. -func (*fakeTransportEndpoint) SetSockOptInt(tcpip.SockOptInt, int) *tcpip.Error { - return tcpip.ErrInvalidEndpointState +func (*fakeTransportEndpoint) SetSockOptInt(tcpip.SockOptInt, int) tcpip.Error { + return &tcpip.ErrInvalidEndpointState{} } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. 
-func (*fakeTransportEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { - return -1, tcpip.ErrUnknownProtocolOption +func (*fakeTransportEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { + return -1, &tcpip.ErrUnknownProtocolOption{} } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (*fakeTransportEndpoint) GetSockOpt(tcpip.GettableSocketOption) *tcpip.Error { - return tcpip.ErrInvalidEndpointState +func (*fakeTransportEndpoint) GetSockOpt(tcpip.GettableSocketOption) tcpip.Error { + return &tcpip.ErrInvalidEndpointState{} } // Disconnect implements tcpip.Endpoint.Disconnect. -func (*fakeTransportEndpoint) Disconnect() *tcpip.Error { - return tcpip.ErrNotSupported +func (*fakeTransportEndpoint) Disconnect() tcpip.Error { + return &tcpip.ErrNotSupported{} } -func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { +func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) tcpip.Error { f.peerAddr = addr.Addr // Find the route. r, err := f.proto.stack.FindRoute(addr.NIC, "", addr.Addr, fakeNetNumber, false /* multicastLoop */) if err != nil { - return tcpip.ErrNoRoute + return &tcpip.ErrNoRoute{} } // Try to register so that we can start receiving packets. 
@@ -164,22 +164,22 @@ func (f *fakeTransportEndpoint) UniqueID() uint64 { return f.uniqueID } -func (*fakeTransportEndpoint) ConnectEndpoint(e tcpip.Endpoint) *tcpip.Error { +func (*fakeTransportEndpoint) ConnectEndpoint(e tcpip.Endpoint) tcpip.Error { return nil } -func (*fakeTransportEndpoint) Shutdown(tcpip.ShutdownFlags) *tcpip.Error { +func (*fakeTransportEndpoint) Shutdown(tcpip.ShutdownFlags) tcpip.Error { return nil } func (*fakeTransportEndpoint) Reset() { } -func (*fakeTransportEndpoint) Listen(int) *tcpip.Error { +func (*fakeTransportEndpoint) Listen(int) tcpip.Error { return nil } -func (f *fakeTransportEndpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { +func (f *fakeTransportEndpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpip.Error) { if len(f.acceptQueue) == 0 { return nil, nil, nil } @@ -188,7 +188,7 @@ func (f *fakeTransportEndpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *wai return a, nil, nil } -func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) *tcpip.Error { +func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) tcpip.Error { if err := f.proto.stack.RegisterTransportEndpoint( []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, @@ -203,11 +203,11 @@ func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) *tcpip.Error { return nil } -func (*fakeTransportEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { +func (*fakeTransportEndpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) { return tcpip.FullAddress{}, nil } -func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { +func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) { return tcpip.FullAddress{}, nil } @@ -237,7 +237,7 @@ func (f *fakeTransportEndpoint) HandlePacket(id stack.TransportEndpointID, pkt * f.acceptQueue = append(f.acceptQueue, ep) } -func (f *fakeTransportEndpoint) HandleControlPacket(stack.ControlType, uint32, 
*stack.PacketBuffer) { +func (f *fakeTransportEndpoint) HandleError(stack.TransportError, *stack.PacketBuffer) { // Increment the number of received control packets. f.proto.controlCount++ } @@ -252,7 +252,7 @@ func (*fakeTransportEndpoint) Resume(*stack.Stack) {} func (*fakeTransportEndpoint) Wait() {} -func (*fakeTransportEndpoint) LastError() *tcpip.Error { +func (*fakeTransportEndpoint) LastError() tcpip.Error { return nil } @@ -280,19 +280,19 @@ func (*fakeTransportProtocol) Number() tcpip.TransportProtocolNumber { return fakeTransNumber } -func (f *fakeTransportProtocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (f *fakeTransportProtocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { return newFakeTransportEndpoint(f, netProto, f.stack), nil } -func (*fakeTransportProtocol) NewRawEndpoint(tcpip.NetworkProtocolNumber, *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { - return nil, tcpip.ErrUnknownProtocol +func (*fakeTransportProtocol) NewRawEndpoint(tcpip.NetworkProtocolNumber, *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { + return nil, &tcpip.ErrUnknownProtocol{} } func (*fakeTransportProtocol) MinimumPacketSize() int { return fakeTransHeaderLen } -func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err *tcpip.Error) { +func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err tcpip.Error) { return 0, 0, nil } @@ -300,23 +300,23 @@ func (*fakeTransportProtocol) HandleUnknownDestinationPacket(stack.TransportEndp return stack.UnknownDestinationPacketHandled } -func (f *fakeTransportProtocol) SetOption(option tcpip.SettableTransportProtocolOption) *tcpip.Error { +func (f *fakeTransportProtocol) SetOption(option tcpip.SettableTransportProtocolOption) tcpip.Error { switch v := option.(type) { case *tcpip.TCPModerateReceiveBufferOption: f.opts.good = bool(*v) return nil default: - return 
tcpip.ErrUnknownProtocolOption + return &tcpip.ErrUnknownProtocolOption{} } } -func (f *fakeTransportProtocol) Option(option tcpip.GettableTransportProtocolOption) *tcpip.Error { +func (f *fakeTransportProtocol) Option(option tcpip.GettableTransportProtocolOption) tcpip.Error { switch v := option.(type) { case *tcpip.TCPModerateReceiveBufferOption: *v = tcpip.TCPModerateReceiveBufferOption(f.opts.good) return nil default: - return tcpip.ErrUnknownProtocolOption + return &tcpip.ErrUnknownProtocolOption{} } } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 812ee36ed..c500a0d1c 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -47,141 +47,6 @@ import ( // Using header.IPv4AddressSize would cause an import cycle. const ipv4AddressSize = 4 -// Error represents an error in the netstack error space. Using a special type -// ensures that errors outside of this space are not accidentally introduced. -// -// All errors must have unique msg strings. -// -// +stateify savable -type Error struct { - msg string - - ignoreStats bool -} - -// String implements fmt.Stringer.String. -func (e *Error) String() string { - if e == nil { - return "<nil>" - } - return e.msg -} - -// IgnoreStats indicates whether this error type should be included in failure -// counts in tcpip.Stats structs. -func (e *Error) IgnoreStats() bool { - return e.ignoreStats -} - -// Errors that can be returned by the network stack. 
-var ( - ErrUnknownProtocol = &Error{msg: "unknown protocol"} - ErrUnknownNICID = &Error{msg: "unknown nic id"} - ErrUnknownDevice = &Error{msg: "unknown device"} - ErrUnknownProtocolOption = &Error{msg: "unknown option for protocol"} - ErrDuplicateNICID = &Error{msg: "duplicate nic id"} - ErrDuplicateAddress = &Error{msg: "duplicate address"} - ErrNoRoute = &Error{msg: "no route"} - ErrBadLinkEndpoint = &Error{msg: "bad link layer endpoint"} - ErrAlreadyBound = &Error{msg: "endpoint already bound", ignoreStats: true} - ErrInvalidEndpointState = &Error{msg: "endpoint is in invalid state"} - ErrAlreadyConnecting = &Error{msg: "endpoint is already connecting", ignoreStats: true} - ErrAlreadyConnected = &Error{msg: "endpoint is already connected", ignoreStats: true} - ErrNoPortAvailable = &Error{msg: "no ports are available"} - ErrPortInUse = &Error{msg: "port is in use"} - ErrBadLocalAddress = &Error{msg: "bad local address"} - ErrClosedForSend = &Error{msg: "endpoint is closed for send"} - ErrClosedForReceive = &Error{msg: "endpoint is closed for receive"} - ErrWouldBlock = &Error{msg: "operation would block", ignoreStats: true} - ErrConnectionRefused = &Error{msg: "connection was refused"} - ErrTimeout = &Error{msg: "operation timed out"} - ErrAborted = &Error{msg: "operation aborted"} - ErrConnectStarted = &Error{msg: "connection attempt started", ignoreStats: true} - ErrDestinationRequired = &Error{msg: "destination address is required"} - ErrNotSupported = &Error{msg: "operation not supported"} - ErrQueueSizeNotSupported = &Error{msg: "queue size querying not supported"} - ErrNotConnected = &Error{msg: "endpoint not connected"} - ErrConnectionReset = &Error{msg: "connection reset by peer"} - ErrConnectionAborted = &Error{msg: "connection aborted"} - ErrNoSuchFile = &Error{msg: "no such file"} - ErrInvalidOptionValue = &Error{msg: "invalid option value specified"} - ErrBadAddress = &Error{msg: "bad address"} - ErrNetworkUnreachable = &Error{msg: "network is 
unreachable"} - ErrMessageTooLong = &Error{msg: "message too long"} - ErrNoBufferSpace = &Error{msg: "no buffer space available"} - ErrBroadcastDisabled = &Error{msg: "broadcast socket option disabled"} - ErrNotPermitted = &Error{msg: "operation not permitted"} - ErrAddressFamilyNotSupported = &Error{msg: "address family not supported by protocol"} - ErrMalformedHeader = &Error{msg: "header is malformed"} - ErrBadBuffer = &Error{msg: "bad buffer"} -) - -var messageToError map[string]*Error - -var populate sync.Once - -// StringToError converts an error message to the error. -func StringToError(s string) *Error { - populate.Do(func() { - var errors = []*Error{ - ErrUnknownProtocol, - ErrUnknownNICID, - ErrUnknownDevice, - ErrUnknownProtocolOption, - ErrDuplicateNICID, - ErrDuplicateAddress, - ErrNoRoute, - ErrBadLinkEndpoint, - ErrAlreadyBound, - ErrInvalidEndpointState, - ErrAlreadyConnecting, - ErrAlreadyConnected, - ErrNoPortAvailable, - ErrPortInUse, - ErrBadLocalAddress, - ErrClosedForSend, - ErrClosedForReceive, - ErrWouldBlock, - ErrConnectionRefused, - ErrTimeout, - ErrAborted, - ErrConnectStarted, - ErrDestinationRequired, - ErrNotSupported, - ErrQueueSizeNotSupported, - ErrNotConnected, - ErrConnectionReset, - ErrConnectionAborted, - ErrNoSuchFile, - ErrInvalidOptionValue, - ErrBadAddress, - ErrNetworkUnreachable, - ErrMessageTooLong, - ErrNoBufferSpace, - ErrBroadcastDisabled, - ErrNotPermitted, - ErrAddressFamilyNotSupported, - ErrMalformedHeader, - ErrBadBuffer, - } - - messageToError = make(map[string]*Error) - for _, e := range errors { - if messageToError[e.String()] != nil { - panic("tcpip errors with duplicated message: " + e.String()) - } - messageToError[e.String()] = e - } - }) - - e, ok := messageToError[s] - if !ok { - panic("unknown error message: " + s) - } - - return e -} - // Errors related to Subnet var ( errSubnetLengthMismatch = errors.New("subnet length of address and mask differ") @@ -633,7 +498,7 @@ type Endpoint interface { // If 
non-zero number of bytes are successfully read and written to dst, err // must be nil. Otherwise, if dst failed to write anything, ErrBadBuffer // should be returned. - Read(dst io.Writer, opts ReadOptions) (res ReadResult, err *Error) + Read(io.Writer, ReadOptions) (ReadResult, Error) // Write writes data to the endpoint's peer. This method does not block if // the data cannot be written. @@ -648,7 +513,7 @@ type Endpoint interface { // stream (TCP) Endpoints may return partial writes, and even then only // in the case where writing additional data would block. Other Endpoints // will either write the entire message or return an error. - Write(Payloader, WriteOptions) (int64, *Error) + Write(Payloader, WriteOptions) (int64, Error) // Connect connects the endpoint to its peer. Specifying a NIC is // optional. @@ -662,21 +527,21 @@ type Endpoint interface { // connected returns nil. Calling connect again results in ErrAlreadyConnected. // Anything else -- the attempt to connect failed. // - // If address.Addr is empty, this means that Enpoint has to be + // If address.Addr is empty, this means that Endpoint has to be // disconnected if this is supported, otherwise // ErrAddressFamilyNotSupported must be returned. - Connect(address FullAddress) *Error + Connect(address FullAddress) Error // Disconnect disconnects the endpoint from its peer. - Disconnect() *Error + Disconnect() Error // Shutdown closes the read and/or write end of the endpoint connection // to its peer. - Shutdown(flags ShutdownFlags) *Error + Shutdown(flags ShutdownFlags) Error // Listen puts the endpoint in "listen" mode, which allows it to accept // new connections. - Listen(backlog int) *Error + Listen(backlog int) Error // Accept returns a new endpoint if a peer has established a connection // to an endpoint previously set to listen mode. 
This method does not @@ -686,36 +551,36 @@ type Endpoint interface { // // If peerAddr is not nil then it is populated with the peer address of the // returned endpoint. - Accept(peerAddr *FullAddress) (Endpoint, *waiter.Queue, *Error) + Accept(peerAddr *FullAddress) (Endpoint, *waiter.Queue, Error) // Bind binds the endpoint to a specific local address and port. // Specifying a NIC is optional. - Bind(address FullAddress) *Error + Bind(address FullAddress) Error // GetLocalAddress returns the address to which the endpoint is bound. - GetLocalAddress() (FullAddress, *Error) + GetLocalAddress() (FullAddress, Error) // GetRemoteAddress returns the address to which the endpoint is // connected. - GetRemoteAddress() (FullAddress, *Error) + GetRemoteAddress() (FullAddress, Error) // Readiness returns the current readiness of the endpoint. For example, // if waiter.EventIn is set, the endpoint is immediately readable. Readiness(mask waiter.EventMask) waiter.EventMask // SetSockOpt sets a socket option. - SetSockOpt(opt SettableSocketOption) *Error + SetSockOpt(opt SettableSocketOption) Error // SetSockOptInt sets a socket option, for simple cases where a value // has the int type. - SetSockOptInt(opt SockOptInt, v int) *Error + SetSockOptInt(opt SockOptInt, v int) Error // GetSockOpt gets a socket option. - GetSockOpt(opt GettableSocketOption) *Error + GetSockOpt(opt GettableSocketOption) Error // GetSockOptInt gets a socket option for simple cases where a return // value has the int type. - GetSockOptInt(SockOptInt) (int, *Error) + GetSockOptInt(SockOptInt) (int, Error) // State returns a socket's lifecycle state. The returned value is // protocol-specific and is primarily used for diagnostics. @@ -738,7 +603,7 @@ type Endpoint interface { SetOwner(owner PacketOwner) // LastError clears and returns the last error reported by the endpoint. - LastError() *Error + LastError() Error // SocketOptions returns the structure which contains all the socket // level options. 
@@ -993,12 +858,54 @@ type SettableSocketOption interface { isSettableSocketOption() } +// CongestionControlState indicates the current congestion control state for +// TCP sender. +type CongestionControlState int + +const ( + // Open indicates that the sender is receiving acks in order and + // no loss or dupACK's etc have been detected. + Open CongestionControlState = iota + // RTORecovery indicates that an RTO has occurred and the sender + // has entered an RTO based recovery phase. + RTORecovery + // FastRecovery indicates that the sender has entered FastRecovery + // based on receiving nDupAck's. This state is entered only when + // SACK is not in use. + FastRecovery + // SACKRecovery indicates that the sender has entered SACK based + // recovery. + SACKRecovery + // Disorder indicates the sender either received some SACK blocks + // or dupACK's. + Disorder +) + // TCPInfoOption is used by GetSockOpt to expose TCP statistics. // // TODO(b/64800844): Add and populate stat fields. type TCPInfoOption struct { - RTT time.Duration + // RTT is the smoothed round trip time. + RTT time.Duration + + // RTTVar is the round trip time variation. RTTVar time.Duration + + // RTO is the retransmission timeout for the endpoint. + RTO time.Duration + + // CcState is the congestion control state. + CcState CongestionControlState + + // SndCwnd is the congestion window, in packets. + SndCwnd uint32 + + // SndSsthresh is the threshold between slow start and congestion + // avoidance. + SndSsthresh uint32 + + // ReorderSeen indicates if reordering is seen in the endpoint. 
+ ReorderSeen bool } func (*TCPInfoOption) isGettableSocketOption() {} diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD index 218b218e7..71695b630 100644 --- a/pkg/tcpip/tests/integration/BUILD +++ b/pkg/tcpip/tests/integration/BUILD @@ -17,6 +17,7 @@ go_test( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/checker", + "//pkg/tcpip/faketime", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/ethernet", diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go index aedf1845e..38e1881c7 100644 --- a/pkg/tcpip/tests/integration/forward_test.go +++ b/pkg/tcpip/tests/integration/forward_test.go @@ -38,96 +38,207 @@ import ( var _ stack.NetworkDispatcher = (*endpointWithDestinationCheck)(nil) var _ stack.LinkEndpoint = (*endpointWithDestinationCheck)(nil) -// newEthernetEndpoint returns an ethernet link endpoint that wraps an inner -// link endpoint and checks the destination link address before delivering -// network packets to the network dispatcher. -// -// See ethernet.Endpoint for more details. -func newEthernetEndpoint(ep stack.LinkEndpoint) *endpointWithDestinationCheck { - var e endpointWithDestinationCheck - e.Endpoint.Init(ethernet.New(ep), &e) - return &e -} - -// endpointWithDestinationCheck is a link endpoint that checks the destination -// link address before delivering network packets to the network dispatcher. -type endpointWithDestinationCheck struct { - nested.Endpoint -} - -// DeliverNetworkPacket implements stack.NetworkDispatcher. 
-func (e *endpointWithDestinationCheck) DeliverNetworkPacket(src, dst tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { - if dst == e.Endpoint.LinkAddress() || dst == header.EthernetBroadcastAddress || header.IsMulticastEthernetAddress(dst) { - e.Endpoint.DeliverNetworkPacket(src, dst, proto, pkt) - } -} - -func TestForwarding(t *testing.T) { - const ( - host1NICID = 1 - routerNICID1 = 2 - routerNICID2 = 3 - host2NICID = 4 - - listenPort = 8080 - ) +const ( + host1NICID = 1 + routerNICID1 = 2 + routerNICID2 = 3 + host2NICID = 4 +) - host1IPv4Addr := tcpip.ProtocolAddress{ +var ( + host1IPv4Addr = tcpip.ProtocolAddress{ Protocol: ipv4.ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("192.168.0.2").To4()), PrefixLen: 24, }, } - routerNIC1IPv4Addr := tcpip.ProtocolAddress{ + routerNIC1IPv4Addr = tcpip.ProtocolAddress{ Protocol: ipv4.ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("192.168.0.1").To4()), PrefixLen: 24, }, } - routerNIC2IPv4Addr := tcpip.ProtocolAddress{ + routerNIC2IPv4Addr = tcpip.ProtocolAddress{ Protocol: ipv4.ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("10.0.0.1").To4()), PrefixLen: 8, }, } - host2IPv4Addr := tcpip.ProtocolAddress{ + host2IPv4Addr = tcpip.ProtocolAddress{ Protocol: ipv4.ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("10.0.0.2").To4()), PrefixLen: 8, }, } - host1IPv6Addr := tcpip.ProtocolAddress{ + host1IPv6Addr = tcpip.ProtocolAddress{ Protocol: ipv6.ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("a::2").To16()), PrefixLen: 64, }, } - routerNIC1IPv6Addr := tcpip.ProtocolAddress{ + routerNIC1IPv6Addr = tcpip.ProtocolAddress{ Protocol: ipv6.ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("a::1").To16()), PrefixLen: 64, }, 
} - routerNIC2IPv6Addr := tcpip.ProtocolAddress{ + routerNIC2IPv6Addr = tcpip.ProtocolAddress{ Protocol: ipv6.ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("b::1").To16()), PrefixLen: 64, }, } - host2IPv6Addr := tcpip.ProtocolAddress{ + host2IPv6Addr = tcpip.ProtocolAddress{ Protocol: ipv6.ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("b::2").To16()), PrefixLen: 64, }, } +) + +func setupRoutedStacks(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack) { + host1NIC, routerNIC1 := pipe.New(linkAddr1, linkAddr2) + routerNIC2, host2NIC := pipe.New(linkAddr3, linkAddr4) + + if err := host1Stack.CreateNIC(host1NICID, newEthernetEndpoint(host1NIC)); err != nil { + t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err) + } + if err := routerStack.CreateNIC(routerNICID1, newEthernetEndpoint(routerNIC1)); err != nil { + t.Fatalf("routerStack.CreateNIC(%d, _): %s", routerNICID1, err) + } + if err := routerStack.CreateNIC(routerNICID2, newEthernetEndpoint(routerNIC2)); err != nil { + t.Fatalf("routerStack.CreateNIC(%d, _): %s", routerNICID2, err) + } + if err := host2Stack.CreateNIC(host2NICID, newEthernetEndpoint(host2NIC)); err != nil { + t.Fatalf("host2Stack.CreateNIC(%d, _): %s", host2NICID, err) + } + + if err := routerStack.SetForwarding(ipv4.ProtocolNumber, true); err != nil { + t.Fatalf("routerStack.SetForwarding(%d): %s", ipv4.ProtocolNumber, err) + } + if err := routerStack.SetForwarding(ipv6.ProtocolNumber, true); err != nil { + t.Fatalf("routerStack.SetForwarding(%d): %s", ipv6.ProtocolNumber, err) + } + + if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv4Addr); err != nil { + t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv4Addr, err) + } + if err := routerStack.AddProtocolAddress(routerNICID1, routerNIC1IPv4Addr); err != nil { + t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID1, 
routerNIC1IPv4Addr, err) + } + if err := routerStack.AddProtocolAddress(routerNICID2, routerNIC2IPv4Addr); err != nil { + t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID2, routerNIC2IPv4Addr, err) + } + if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv4Addr); err != nil { + t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv4Addr, err) + } + if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv6Addr); err != nil { + t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv6Addr, err) + } + if err := routerStack.AddProtocolAddress(routerNICID1, routerNIC1IPv6Addr); err != nil { + t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID1, routerNIC1IPv6Addr, err) + } + if err := routerStack.AddProtocolAddress(routerNICID2, routerNIC2IPv6Addr); err != nil { + t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID2, routerNIC2IPv6Addr, err) + } + if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv6Addr); err != nil { + t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv6Addr, err) + } + + host1Stack.SetRouteTable([]tcpip.Route{ + { + Destination: host1IPv4Addr.AddressWithPrefix.Subnet(), + NIC: host1NICID, + }, + { + Destination: host1IPv6Addr.AddressWithPrefix.Subnet(), + NIC: host1NICID, + }, + { + Destination: host2IPv4Addr.AddressWithPrefix.Subnet(), + Gateway: routerNIC1IPv4Addr.AddressWithPrefix.Address, + NIC: host1NICID, + }, + { + Destination: host2IPv6Addr.AddressWithPrefix.Subnet(), + Gateway: routerNIC1IPv6Addr.AddressWithPrefix.Address, + NIC: host1NICID, + }, + }) + routerStack.SetRouteTable([]tcpip.Route{ + { + Destination: routerNIC1IPv4Addr.AddressWithPrefix.Subnet(), + NIC: routerNICID1, + }, + { + Destination: routerNIC1IPv6Addr.AddressWithPrefix.Subnet(), + NIC: routerNICID1, + }, + { + Destination: routerNIC2IPv4Addr.AddressWithPrefix.Subnet(), + NIC: routerNICID2, + }, + { + Destination: 
routerNIC2IPv6Addr.AddressWithPrefix.Subnet(), + NIC: routerNICID2, + }, + }) + host2Stack.SetRouteTable([]tcpip.Route{ + { + Destination: host2IPv4Addr.AddressWithPrefix.Subnet(), + NIC: host2NICID, + }, + { + Destination: host2IPv6Addr.AddressWithPrefix.Subnet(), + NIC: host2NICID, + }, + { + Destination: host1IPv4Addr.AddressWithPrefix.Subnet(), + Gateway: routerNIC2IPv4Addr.AddressWithPrefix.Address, + NIC: host2NICID, + }, + { + Destination: host1IPv6Addr.AddressWithPrefix.Subnet(), + Gateway: routerNIC2IPv6Addr.AddressWithPrefix.Address, + NIC: host2NICID, + }, + }) +} + +// newEthernetEndpoint returns an ethernet link endpoint that wraps an inner +// link endpoint and checks the destination link address before delivering +// network packets to the network dispatcher. +// +// See ethernet.Endpoint for more details. +func newEthernetEndpoint(ep stack.LinkEndpoint) *endpointWithDestinationCheck { + var e endpointWithDestinationCheck + e.Endpoint.Init(ethernet.New(ep), &e) + return &e +} + +// endpointWithDestinationCheck is a link endpoint that checks the destination +// link address before delivering network packets to the network dispatcher. +type endpointWithDestinationCheck struct { + nested.Endpoint +} + +// DeliverNetworkPacket implements stack.NetworkDispatcher. 
+func (e *endpointWithDestinationCheck) DeliverNetworkPacket(src, dst tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + if dst == e.Endpoint.LinkAddress() || dst == header.EthernetBroadcastAddress || header.IsMulticastEthernetAddress(dst) { + e.Endpoint.DeliverNetworkPacket(src, dst, proto, pkt) + } +} + +func TestForwarding(t *testing.T) { + const listenPort = 8080 type endpointAndAddresses struct { serverEP tcpip.Endpoint @@ -229,7 +340,7 @@ func TestForwarding(t *testing.T) { subTests := []struct { name string proto tcpip.TransportProtocolNumber - expectedConnectErr *tcpip.Error + expectedConnectErr tcpip.Error setupServerSide func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) needRemoteAddr bool }{ @@ -250,7 +361,7 @@ func TestForwarding(t *testing.T) { { name: "TCP", proto: tcp.ProtocolNumber, - expectedConnectErr: tcpip.ErrConnectStarted, + expectedConnectErr: &tcpip.ErrConnectStarted{}, setupServerSide: func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) { t.Helper() @@ -260,7 +371,7 @@ func TestForwarding(t *testing.T) { var addr tcpip.FullAddress for { newEP, wq, err := ep.Accept(&addr) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { <-ch continue } @@ -294,113 +405,7 @@ func TestForwarding(t *testing.T) { host1Stack := stack.New(stackOpts) routerStack := stack.New(stackOpts) host2Stack := stack.New(stackOpts) - - host1NIC, routerNIC1 := pipe.New(linkAddr1, linkAddr2) - routerNIC2, host2NIC := pipe.New(linkAddr3, linkAddr4) - - if err := host1Stack.CreateNIC(host1NICID, newEthernetEndpoint(host1NIC)); err != nil { - t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err) - } - if err := routerStack.CreateNIC(routerNICID1, newEthernetEndpoint(routerNIC1)); err != nil { - t.Fatalf("routerStack.CreateNIC(%d, _): %s", routerNICID1, err) - } - if err := 
routerStack.CreateNIC(routerNICID2, newEthernetEndpoint(routerNIC2)); err != nil { - t.Fatalf("routerStack.CreateNIC(%d, _): %s", routerNICID2, err) - } - if err := host2Stack.CreateNIC(host2NICID, newEthernetEndpoint(host2NIC)); err != nil { - t.Fatalf("host2Stack.CreateNIC(%d, _): %s", host2NICID, err) - } - - if err := routerStack.SetForwarding(ipv4.ProtocolNumber, true); err != nil { - t.Fatalf("routerStack.SetForwarding(%d): %s", ipv4.ProtocolNumber, err) - } - if err := routerStack.SetForwarding(ipv6.ProtocolNumber, true); err != nil { - t.Fatalf("routerStack.SetForwarding(%d): %s", ipv6.ProtocolNumber, err) - } - - if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv4Addr); err != nil { - t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv4Addr, err) - } - if err := routerStack.AddProtocolAddress(routerNICID1, routerNIC1IPv4Addr); err != nil { - t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID1, routerNIC1IPv4Addr, err) - } - if err := routerStack.AddProtocolAddress(routerNICID2, routerNIC2IPv4Addr); err != nil { - t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID2, routerNIC2IPv4Addr, err) - } - if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv4Addr); err != nil { - t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv4Addr, err) - } - if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv6Addr); err != nil { - t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv6Addr, err) - } - if err := routerStack.AddProtocolAddress(routerNICID1, routerNIC1IPv6Addr); err != nil { - t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID1, routerNIC1IPv6Addr, err) - } - if err := routerStack.AddProtocolAddress(routerNICID2, routerNIC2IPv6Addr); err != nil { - t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID2, routerNIC2IPv6Addr, err) - } - if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv6Addr); err 
!= nil { - t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv6Addr, err) - } - - host1Stack.SetRouteTable([]tcpip.Route{ - { - Destination: host1IPv4Addr.AddressWithPrefix.Subnet(), - NIC: host1NICID, - }, - { - Destination: host1IPv6Addr.AddressWithPrefix.Subnet(), - NIC: host1NICID, - }, - { - Destination: host2IPv4Addr.AddressWithPrefix.Subnet(), - Gateway: routerNIC1IPv4Addr.AddressWithPrefix.Address, - NIC: host1NICID, - }, - { - Destination: host2IPv6Addr.AddressWithPrefix.Subnet(), - Gateway: routerNIC1IPv6Addr.AddressWithPrefix.Address, - NIC: host1NICID, - }, - }) - routerStack.SetRouteTable([]tcpip.Route{ - { - Destination: routerNIC1IPv4Addr.AddressWithPrefix.Subnet(), - NIC: routerNICID1, - }, - { - Destination: routerNIC1IPv6Addr.AddressWithPrefix.Subnet(), - NIC: routerNICID1, - }, - { - Destination: routerNIC2IPv4Addr.AddressWithPrefix.Subnet(), - NIC: routerNICID2, - }, - { - Destination: routerNIC2IPv6Addr.AddressWithPrefix.Subnet(), - NIC: routerNICID2, - }, - }) - host2Stack.SetRouteTable([]tcpip.Route{ - { - Destination: host2IPv4Addr.AddressWithPrefix.Subnet(), - NIC: host2NICID, - }, - { - Destination: host2IPv6Addr.AddressWithPrefix.Subnet(), - NIC: host2NICID, - }, - { - Destination: host1IPv4Addr.AddressWithPrefix.Subnet(), - Gateway: routerNIC2IPv4Addr.AddressWithPrefix.Address, - NIC: host2NICID, - }, - { - Destination: host1IPv6Addr.AddressWithPrefix.Subnet(), - Gateway: routerNIC2IPv6Addr.AddressWithPrefix.Address, - NIC: host2NICID, - }, - }) + setupRoutedStacks(t, host1Stack, routerStack, host2Stack) epsAndAddrs := test.epAndAddrs(t, host1Stack, routerStack, host2Stack, subTest.proto) defer epsAndAddrs.serverEP.Close() @@ -415,8 +420,11 @@ func TestForwarding(t *testing.T) { t.Fatalf("epsAndAddrs.clientEP.Bind(%#v): %s", clientAddr, err) } - if err := epsAndAddrs.clientEP.Connect(serverAddr); err != subTest.expectedConnectErr { - t.Fatalf("got epsAndAddrs.clientEP.Connect(%#v) = %s, want = %s", serverAddr, 
err, subTest.expectedConnectErr) + { + err := epsAndAddrs.clientEP.Connect(serverAddr) + if diff := cmp.Diff(subTest.expectedConnectErr, err); diff != "" { + t.Fatalf("unexpected error from epsAndAddrs.clientEP.Connect(%#v), (-want, +got):\n%s", serverAddr, diff) + } } if addr, err := epsAndAddrs.clientEP.GetLocalAddress(); err != nil { t.Fatalf("epsAndAddrs.clientEP.GetLocalAddress(): %s", err) diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go index f85164c5b..f2301a9e6 100644 --- a/pkg/tcpip/tests/integration/link_resolution_test.go +++ b/pkg/tcpip/tests/integration/link_resolution_test.go @@ -19,12 +19,14 @@ import ( "fmt" "net" "testing" + "time" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" + "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/pipe" "gvisor.dev/gvisor/pkg/tcpip/network/arp" @@ -245,6 +247,14 @@ func TestPing(t *testing.T) { } } +type transportError struct { + origin tcpip.SockErrOrigin + typ uint8 + code uint8 + info uint32 + kind stack.TransportErrorKind +} + func TestTCPLinkResolutionFailure(t *testing.T) { const ( host1NICID = 1 @@ -255,8 +265,9 @@ func TestTCPLinkResolutionFailure(t *testing.T) { name string netProto tcpip.NetworkProtocolNumber remoteAddr tcpip.Address - expectedWriteErr *tcpip.Error + expectedWriteErr tcpip.Error sockError tcpip.SockError + transErr transportError }{ { name: "IPv4 with resolvable remote", @@ -274,12 +285,9 @@ func TestTCPLinkResolutionFailure(t *testing.T) { name: "IPv4 without resolvable remote", netProto: ipv4.ProtocolNumber, remoteAddr: ipv4Addr3.AddressWithPrefix.Address, - expectedWriteErr: tcpip.ErrNoRoute, + expectedWriteErr: &tcpip.ErrNoRoute{}, sockError: tcpip.SockError{ - Err: tcpip.ErrNoRoute, - ErrType: byte(header.ICMPv4DstUnreachable), - 
ErrCode: byte(header.ICMPv4HostUnreachable), - ErrOrigin: tcpip.SockExtErrorOriginICMP, + Err: &tcpip.ErrNoRoute{}, Dst: tcpip.FullAddress{ NIC: host1NICID, Addr: ipv4Addr3.AddressWithPrefix.Address, @@ -291,17 +299,20 @@ func TestTCPLinkResolutionFailure(t *testing.T) { }, NetProto: ipv4.ProtocolNumber, }, + transErr: transportError{ + origin: tcpip.SockExtErrorOriginICMP, + typ: uint8(header.ICMPv4DstUnreachable), + code: uint8(header.ICMPv4HostUnreachable), + kind: stack.DestinationHostUnreachableTransportError, + }, }, { name: "IPv6 without resolvable remote", netProto: ipv6.ProtocolNumber, remoteAddr: ipv6Addr3.AddressWithPrefix.Address, - expectedWriteErr: tcpip.ErrNoRoute, + expectedWriteErr: &tcpip.ErrNoRoute{}, sockError: tcpip.SockError{ - Err: tcpip.ErrNoRoute, - ErrType: byte(header.ICMPv6DstUnreachable), - ErrCode: byte(header.ICMPv6AddressUnreachable), - ErrOrigin: tcpip.SockExtErrorOriginICMP6, + Err: &tcpip.ErrNoRoute{}, Dst: tcpip.FullAddress{ NIC: host1NICID, Addr: ipv6Addr3.AddressWithPrefix.Address, @@ -313,6 +324,12 @@ func TestTCPLinkResolutionFailure(t *testing.T) { }, NetProto: ipv6.ProtocolNumber, }, + transErr: transportError{ + origin: tcpip.SockExtErrorOriginICMP6, + typ: uint8(header.ICMPv6DstUnreachable), + code: uint8(header.ICMPv6AddressUnreachable), + kind: stack.DestinationHostUnreachableTransportError, + }, }, } @@ -355,18 +372,24 @@ func TestTCPLinkResolutionFailure(t *testing.T) { remoteAddr := listenerAddr remoteAddr.Addr = test.remoteAddr - if err := clientEP.Connect(remoteAddr); err != tcpip.ErrConnectStarted { - t.Fatalf("got clientEP.Connect(%#v) = %s, want = %s", remoteAddr, err, tcpip.ErrConnectStarted) + { + err := clientEP.Connect(remoteAddr) + if _, ok := err.(*tcpip.ErrConnectStarted); !ok { + t.Fatalf("got clientEP.Connect(%#v) = %s, want = %s", remoteAddr, err, &tcpip.ErrConnectStarted{}) + } } // Wait for an error due to link resolution failing, or the endpoint to be // writable. 
<-ch - var r bytes.Reader - r.Reset([]byte{0}) - var wOpts tcpip.WriteOptions - if n, err := clientEP.Write(&r, wOpts); err != test.expectedWriteErr { - t.Errorf("got clientEP.Write(_, %#v) = (%d, %s), want = (_, %s)", wOpts, n, err, test.expectedWriteErr) + { + var r bytes.Reader + r.Reset([]byte{0}) + var wOpts tcpip.WriteOptions + _, err := clientEP.Write(&r, wOpts) + if diff := cmp.Diff(test.expectedWriteErr, err); diff != "" { + t.Errorf("unexpected error from clientEP.Write(_, %#v), (-want, +got):\n%s", wOpts, diff) + } } if test.expectedWriteErr == nil { @@ -380,14 +403,17 @@ func TestTCPLinkResolutionFailure(t *testing.T) { sockErrCmpOpts := []cmp.Option{ cmpopts.IgnoreUnexported(tcpip.SockError{}), - cmp.Comparer(func(a, b *tcpip.Error) bool { + cmp.Comparer(func(a, b tcpip.Error) bool { // tcpip.Error holds an unexported field but the errors netstack uses // are pre defined so we can simply compare pointers. return a == b }), - // Ignore the payload since we do not know the TCP seq/ack numbers. checker.IgnoreCmpPath( + // Ignore the payload since we do not know the TCP seq/ack numbers. "Payload", + // Ignore the cause since we will compare its properties separately + // since the concrete type of the cause is unknown. 
+ "Cause", ), } @@ -399,6 +425,24 @@ func TestTCPLinkResolutionFailure(t *testing.T) { if diff := cmp.Diff(&test.sockError, sockErr, sockErrCmpOpts...); diff != "" { t.Errorf("socket error mismatch (-want +got):\n%s", diff) } + + transErr, ok := sockErr.Cause.(stack.TransportError) + if !ok { + t.Fatalf("socket error cause is not a transport error; cause = %#v", sockErr.Cause) + } + if diff := cmp.Diff( + test.transErr, + transportError{ + origin: transErr.Origin(), + typ: transErr.Type(), + code: transErr.Code(), + info: transErr.Info(), + kind: transErr.Kind(), + }, + cmp.AllowUnexported(transportError{}), + ); diff != "" { + t.Errorf("socket error mismatch (-want +got):\n%s", diff) + } }) } } @@ -453,10 +497,11 @@ func TestGetLinkAddress(t *testing.T) { host1Stack, _ := setupStack(t, stackOpts, host1NICID, host2NICID) ch := make(chan stack.LinkResolutionResult, 1) - if err := host1Stack.GetLinkAddress(host1NICID, test.remoteAddr, "", test.netProto, func(r stack.LinkResolutionResult) { + err := host1Stack.GetLinkAddress(host1NICID, test.remoteAddr, "", test.netProto, func(r stack.LinkResolutionResult) { ch <- r - }); err != tcpip.ErrWouldBlock { - t.Fatalf("got host1Stack.GetLinkAddress(%d, %s, '', %d, _) = %s, want = %s", host1NICID, test.remoteAddr, test.netProto, err, tcpip.ErrWouldBlock) + }) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got host1Stack.GetLinkAddress(%d, %s, '', %d, _) = %s, want = %s", host1NICID, test.remoteAddr, test.netProto, err, &tcpip.ErrWouldBlock{}) } wantRes := stack.LinkResolutionResult{Success: test.expectedOk} if test.expectedOk { @@ -570,10 +615,11 @@ func TestRouteResolvedFields(t *testing.T) { wantUnresolvedRouteInfo := wantRouteInfo wantUnresolvedRouteInfo.RemoteLinkAddress = "" - if err := r.ResolvedFields(func(r stack.ResolvedFieldsResult) { + err := r.ResolvedFields(func(r stack.ResolvedFieldsResult) { ch <- r - }); err != tcpip.ErrWouldBlock { - t.Errorf("got r.ResolvedFields(_) = %s, want = %s", err, 
tcpip.ErrWouldBlock) + }) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Errorf("got r.ResolvedFields(_) = %s, want = %s", err, &tcpip.ErrWouldBlock{}) } if diff := cmp.Diff(stack.ResolvedFieldsResult{RouteInfo: wantRouteInfo, Success: test.expectedSuccess}, <-ch, cmp.AllowUnexported(stack.RouteInfo{})); diff != "" { t.Errorf("route resolve result mismatch (-want +got):\n%s", diff) @@ -616,7 +662,7 @@ func TestWritePacketsLinkResolution(t *testing.T) { name string netProto tcpip.NetworkProtocolNumber remoteAddr tcpip.Address - expectedWriteErr *tcpip.Error + expectedWriteErr tcpip.Error }{ { name: "IPv4", @@ -703,7 +749,7 @@ func TestWritePacketsLinkResolution(t *testing.T) { var rOpts tcpip.ReadOptions res, err := serverEP.Read(&writer, rOpts) if err != nil { - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Should not have anymore bytes to read after we read the sent // number of bytes. if count == len(data) { @@ -728,3 +774,439 @@ func TestWritePacketsLinkResolution(t *testing.T) { }) } } + +type eventType int + +const ( + entryAdded eventType = iota + entryChanged + entryRemoved +) + +func (t eventType) String() string { + switch t { + case entryAdded: + return "add" + case entryChanged: + return "change" + case entryRemoved: + return "remove" + default: + return fmt.Sprintf("unknown (%d)", t) + } +} + +type eventInfo struct { + eventType eventType + nicID tcpip.NICID + entry stack.NeighborEntry +} + +func (e eventInfo) String() string { + return fmt.Sprintf("%s event for NIC #%d, %#v", e.eventType, e.nicID, e.entry) +} + +var _ stack.NUDDispatcher = (*nudDispatcher)(nil) + +type nudDispatcher struct { + c chan eventInfo +} + +func (d *nudDispatcher) OnNeighborAdded(nicID tcpip.NICID, entry stack.NeighborEntry) { + e := eventInfo{ + eventType: entryAdded, + nicID: nicID, + entry: entry, + } + d.c <- e +} + +func (d *nudDispatcher) OnNeighborChanged(nicID tcpip.NICID, entry stack.NeighborEntry) { + e := eventInfo{ + 
eventType: entryChanged, + nicID: nicID, + entry: entry, + } + d.c <- e +} + +func (d *nudDispatcher) OnNeighborRemoved(nicID tcpip.NICID, entry stack.NeighborEntry) { + e := eventInfo{ + eventType: entryRemoved, + nicID: nicID, + entry: entry, + } + d.c <- e +} + +func (d *nudDispatcher) waitForEvent(want eventInfo) error { + if diff := cmp.Diff(want, <-d.c, cmp.AllowUnexported(eventInfo{}), cmpopts.IgnoreFields(stack.NeighborEntry{}, "UpdatedAtNanos")); diff != "" { + return fmt.Errorf("got invalid event (-want +got):\n%s", diff) + } + return nil +} + +// TestTCPConfirmNeighborReachability tests that TCP informs layers beneath it +// that the neighbor used for a route is reachable. +func TestTCPConfirmNeighborReachability(t *testing.T) { + tests := []struct { + name string + netProto tcpip.NetworkProtocolNumber + remoteAddr tcpip.Address + neighborAddr tcpip.Address + getEndpoints func(*testing.T, *stack.Stack, *stack.Stack, *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) + isHost1Listener bool + }{ + { + name: "IPv4 active connection through neighbor", + netProto: ipv4.ProtocolNumber, + remoteAddr: host2IPv4Addr.AddressWithPrefix.Address, + neighborAddr: routerNIC1IPv4Addr.AddressWithPrefix.Address, + getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + var listenerWQ waiter.Queue + listenerEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ) + if err != nil { + t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) + } + + var clientWQ waiter.Queue + clientWE, clientCH := waiter.NewChannelEntry(nil) + clientWQ.EventRegister(&clientWE, waiter.EventOut) + clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ) + if err != nil { + listenerEP.Close() + t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) + } + + return 
listenerEP, clientEP, clientCH + }, + }, + { + name: "IPv6 active connection through neighbor", + netProto: ipv6.ProtocolNumber, + remoteAddr: host2IPv6Addr.AddressWithPrefix.Address, + neighborAddr: routerNIC1IPv6Addr.AddressWithPrefix.Address, + getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + var listenerWQ waiter.Queue + listenerEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ) + if err != nil { + t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) + } + + var clientWQ waiter.Queue + clientWE, clientCH := waiter.NewChannelEntry(nil) + clientWQ.EventRegister(&clientWE, waiter.EventOut) + clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ) + if err != nil { + listenerEP.Close() + t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) + } + + return listenerEP, clientEP, clientCH + }, + }, + { + name: "IPv4 active connection to neighbor", + netProto: ipv4.ProtocolNumber, + remoteAddr: routerNIC1IPv4Addr.AddressWithPrefix.Address, + neighborAddr: routerNIC1IPv4Addr.AddressWithPrefix.Address, + getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + var listenerWQ waiter.Queue + listenerEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ) + if err != nil { + t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) + } + + var clientWQ waiter.Queue + clientWE, clientCH := waiter.NewChannelEntry(nil) + clientWQ.EventRegister(&clientWE, waiter.EventOut) + clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ) + if err != nil { + listenerEP.Close() + t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) + } + + 
return listenerEP, clientEP, clientCH + }, + }, + { + name: "IPv6 active connection to neighbor", + netProto: ipv6.ProtocolNumber, + remoteAddr: routerNIC1IPv6Addr.AddressWithPrefix.Address, + neighborAddr: routerNIC1IPv6Addr.AddressWithPrefix.Address, + getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + var listenerWQ waiter.Queue + listenerEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ) + if err != nil { + t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) + } + + var clientWQ waiter.Queue + clientWE, clientCH := waiter.NewChannelEntry(nil) + clientWQ.EventRegister(&clientWE, waiter.EventOut) + clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ) + if err != nil { + listenerEP.Close() + t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) + } + + return listenerEP, clientEP, clientCH + }, + }, + { + name: "IPv4 passive connection to neighbor", + netProto: ipv4.ProtocolNumber, + remoteAddr: host1IPv4Addr.AddressWithPrefix.Address, + neighborAddr: routerNIC1IPv4Addr.AddressWithPrefix.Address, + getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + var listenerWQ waiter.Queue + listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ) + if err != nil { + t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) + } + + var clientWQ waiter.Queue + clientWE, clientCH := waiter.NewChannelEntry(nil) + clientWQ.EventRegister(&clientWE, waiter.EventOut) + clientEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ) + if err != nil { + listenerEP.Close() + t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) + 
} + + return listenerEP, clientEP, clientCH + }, + isHost1Listener: true, + }, + { + name: "IPv6 passive connection to neighbor", + netProto: ipv6.ProtocolNumber, + remoteAddr: host1IPv6Addr.AddressWithPrefix.Address, + neighborAddr: routerNIC1IPv6Addr.AddressWithPrefix.Address, + getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + var listenerWQ waiter.Queue + listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ) + if err != nil { + t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) + } + + var clientWQ waiter.Queue + clientWE, clientCH := waiter.NewChannelEntry(nil) + clientWQ.EventRegister(&clientWE, waiter.EventOut) + clientEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ) + if err != nil { + listenerEP.Close() + t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) + } + + return listenerEP, clientEP, clientCH + }, + isHost1Listener: true, + }, + { + name: "IPv4 passive connection through neighbor", + netProto: ipv4.ProtocolNumber, + remoteAddr: host1IPv4Addr.AddressWithPrefix.Address, + neighborAddr: routerNIC1IPv4Addr.AddressWithPrefix.Address, + getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + var listenerWQ waiter.Queue + listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ) + if err != nil { + t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) + } + + var clientWQ waiter.Queue + clientWE, clientCH := waiter.NewChannelEntry(nil) + clientWQ.EventRegister(&clientWE, waiter.EventOut) + clientEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ) + if err != nil { + listenerEP.Close() + t.Fatalf("host2Stack.NewEndpoint(%d, %d, 
_): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) + } + + return listenerEP, clientEP, clientCH + }, + isHost1Listener: true, + }, + { + name: "IPv6 passive connection through neighbor", + netProto: ipv6.ProtocolNumber, + remoteAddr: host1IPv6Addr.AddressWithPrefix.Address, + neighborAddr: routerNIC1IPv6Addr.AddressWithPrefix.Address, + getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + var listenerWQ waiter.Queue + listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ) + if err != nil { + t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) + } + + var clientWQ waiter.Queue + clientWE, clientCH := waiter.NewChannelEntry(nil) + clientWQ.EventRegister(&clientWE, waiter.EventOut) + clientEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ) + if err != nil { + listenerEP.Close() + t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) + } + + return listenerEP, clientEP, clientCH + }, + isHost1Listener: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + clock := faketime.NewManualClock() + nudDisp := nudDispatcher{ + c: make(chan eventInfo, 3), + } + stackOpts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, + Clock: clock, + UseNeighborCache: true, + } + host1StackOpts := stackOpts + host1StackOpts.NUDDisp = &nudDisp + + host1Stack := stack.New(host1StackOpts) + routerStack := stack.New(stackOpts) + host2Stack := stack.New(stackOpts) + setupRoutedStacks(t, host1Stack, routerStack, host2Stack) + + // Add a reachable dynamic entry to our neighbor table for the remote. 
+ { + ch := make(chan stack.LinkResolutionResult, 1) + err := host1Stack.GetLinkAddress(host1NICID, test.neighborAddr, "", test.netProto, func(r stack.LinkResolutionResult) { + ch <- r + }) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got host1Stack.GetLinkAddress(%d, %s, '', %d, _) = %s, want = %s", host1NICID, test.neighborAddr, test.netProto, err, &tcpip.ErrWouldBlock{}) + } + if diff := cmp.Diff(stack.LinkResolutionResult{LinkAddress: linkAddr2, Success: true}, <-ch); diff != "" { + t.Fatalf("link resolution mismatch (-want +got):\n%s", diff) + } + } + if err := nudDisp.waitForEvent(eventInfo{ + eventType: entryAdded, + nicID: host1NICID, + entry: stack.NeighborEntry{State: stack.Incomplete, Addr: test.neighborAddr}, + }); err != nil { + t.Fatalf("error waiting for initial NUD event: %s", err) + } + if err := nudDisp.waitForEvent(eventInfo{ + eventType: entryChanged, + nicID: host1NICID, + entry: stack.NeighborEntry{State: stack.Reachable, Addr: test.neighborAddr, LinkAddr: linkAddr2}, + }); err != nil { + t.Fatalf("error waiting for reachable NUD event: %s", err) + } + + // Wait for the remote's neighbor entry to be stale before creating a + // TCP connection from host1 to some remote. + nudConfigs, err := host1Stack.NUDConfigurations(host1NICID, test.netProto) + if err != nil { + t.Fatalf("host1Stack.NUDConfigurations(%d, %d): %s", host1NICID, test.netProto, err) + } + // The maximum reachable time for a neighbor is some maximum random factor + // applied to the base reachable time. + // + // See NUDConfigurations.BaseReachableTime for more information. 
+ maxReachableTime := time.Duration(float32(nudConfigs.BaseReachableTime) * nudConfigs.MaxRandomFactor) + clock.Advance(maxReachableTime) + if err := nudDisp.waitForEvent(eventInfo{ + eventType: entryChanged, + nicID: host1NICID, + entry: stack.NeighborEntry{State: stack.Stale, Addr: test.neighborAddr, LinkAddr: linkAddr2}, + }); err != nil { + t.Fatalf("error waiting for stale NUD event: %s", err) + } + + listenerEP, clientEP, clientCH := test.getEndpoints(t, host1Stack, routerStack, host2Stack) + defer listenerEP.Close() + defer clientEP.Close() + listenerAddr := tcpip.FullAddress{Addr: test.remoteAddr, Port: 1234} + if err := listenerEP.Bind(listenerAddr); err != nil { + t.Fatalf("listenerEP.Bind(%#v): %s", listenerAddr, err) + } + if err := listenerEP.Listen(1); err != nil { + t.Fatalf("listenerEP.Listen(1): %s", err) + } + { + err := clientEP.Connect(listenerAddr) + if _, ok := err.(*tcpip.ErrConnectStarted); !ok { + t.Fatalf("got clientEP.Connect(%#v) = %s, want = %s", listenerAddr, err, &tcpip.ErrConnectStarted{}) + } + } + + // Wait for the TCP handshake to complete then make sure the neighbor is + // reachable without entering the probe state as TCP should provide NUD + // with confirmation that the neighbor is reachable (indicated by a + // successful 3-way handshake). + <-clientCH + if err := nudDisp.waitForEvent(eventInfo{ + eventType: entryChanged, + nicID: host1NICID, + entry: stack.NeighborEntry{State: stack.Delay, Addr: test.neighborAddr, LinkAddr: linkAddr2}, + }); err != nil { + t.Fatalf("error waiting for delay NUD event: %s", err) + } + if err := nudDisp.waitForEvent(eventInfo{ + eventType: entryChanged, + nicID: host1NICID, + entry: stack.NeighborEntry{State: stack.Reachable, Addr: test.neighborAddr, LinkAddr: linkAddr2}, + }); err != nil { + t.Fatalf("error waiting for reachable NUD event: %s", err) + } + + // Wait for the neighbor to be stale again then send data to the remote. 
+ // + // On successful transmission, the neighbor should become reachable + // without probing the neighbor as a TCP ACK would be received which is an + // indication of the neighbor being reachable. + clock.Advance(maxReachableTime) + if err := nudDisp.waitForEvent(eventInfo{ + eventType: entryChanged, + nicID: host1NICID, + entry: stack.NeighborEntry{State: stack.Stale, Addr: test.neighborAddr, LinkAddr: linkAddr2}, + }); err != nil { + t.Fatalf("error waiting for stale NUD event: %s", err) + } + var r bytes.Reader + r.Reset([]byte{0}) + var wOpts tcpip.WriteOptions + if _, err := clientEP.Write(&r, wOpts); err != nil { + t.Errorf("clientEP.Write(_, %#v): %s", wOpts, err) + } + if err := nudDisp.waitForEvent(eventInfo{ + eventType: entryChanged, + nicID: host1NICID, + entry: stack.NeighborEntry{State: stack.Delay, Addr: test.neighborAddr, LinkAddr: linkAddr2}, + }); err != nil { + t.Fatalf("error waiting for delay NUD event: %s", err) + } + if test.isHost1Listener { + // If host1 is not the client, host1 does not send any data so TCP + // has no way to know it is making forward progress. Because of this, + // TCP should not mark the route reachable and NUD should go through the + // probe state. 
+ clock.Advance(nudConfigs.DelayFirstProbeTime) + if err := nudDisp.waitForEvent(eventInfo{ + eventType: entryChanged, + nicID: host1NICID, + entry: stack.NeighborEntry{State: stack.Probe, Addr: test.neighborAddr, LinkAddr: linkAddr2}, + }); err != nil { + t.Fatalf("error waiting for probe NUD event: %s", err) + } + } + if err := nudDisp.waitForEvent(eventInfo{ + eventType: entryChanged, + nicID: host1NICID, + entry: stack.NeighborEntry{State: stack.Reachable, Addr: test.neighborAddr, LinkAddr: linkAddr2}, + }); err != nil { + t.Fatalf("error waiting for reachable NUD event: %s", err) + } + }) + } +} diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go index 761283b66..ab67762ef 100644 --- a/pkg/tcpip/tests/integration/loopback_test.go +++ b/pkg/tcpip/tests/integration/loopback_test.go @@ -37,7 +37,7 @@ var _ ipv6.NDPDispatcher = (*ndpDispatcher)(nil) type ndpDispatcher struct{} -func (*ndpDispatcher) OnDuplicateAddressDetectionStatus(tcpip.NICID, tcpip.Address, bool, *tcpip.Error) { +func (*ndpDispatcher) OnDuplicateAddressDetectionStatus(tcpip.NICID, tcpip.Address, bool, tcpip.Error) { } func (*ndpDispatcher) OnDefaultRouterDiscovered(tcpip.NICID, tcpip.Address) bool { @@ -262,8 +262,8 @@ func TestLoopbackAcceptAllInSubnetUDP(t *testing.T) { if diff := cmp.Diff(data, buf.Bytes()); diff != "" { t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff) } - } else if err != tcpip.ErrWouldBlock { - t.Fatalf("got rep.Read = (%v, %s) [with data %x], want = (_, %s)", res, err, buf.Bytes(), tcpip.ErrWouldBlock) + } else if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got rep.Read = (%v, %s) [with data %x], want = (_, %s)", res, err, buf.Bytes(), &tcpip.ErrWouldBlock{}) } }) } @@ -322,11 +322,14 @@ func TestLoopbackSubnetLifetimeBoundToAddr(t *testing.T) { if err := s.RemoveAddress(nicID, protoAddr.AddressWithPrefix.Address); err != nil { t.Fatalf("s.RemoveAddress(%d, %s): %s", nicID, 
protoAddr.AddressWithPrefix.Address, err) } - if err := r.WritePacket(nil /* gso */, params, stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(r.MaxHeaderLength()), - Data: data.ToVectorisedView(), - })); err != tcpip.ErrInvalidEndpointState { - t.Fatalf("got r.WritePacket(nil, %#v, _) = %s, want = %s", params, err, tcpip.ErrInvalidEndpointState) + { + err := r.WritePacket(nil /* gso */, params, stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength()), + Data: data.ToVectorisedView(), + })) + if _, ok := err.(*tcpip.ErrInvalidEndpointState); !ok { + t.Fatalf("got r.WritePacket(nil, %#v, _) = %s, want = %s", params, err, &tcpip.ErrInvalidEndpointState{}) + } } } @@ -470,13 +473,17 @@ func TestLoopbackAcceptAllInSubnetTCP(t *testing.T) { Addr: test.dstAddr, Port: localPort, } - if err := connectingEndpoint.Connect(connectAddr); err != tcpip.ErrConnectStarted { - t.Fatalf("connectingEndpoint.Connect(%#v): %s", connectAddr, err) + { + err := connectingEndpoint.Connect(connectAddr) + if _, ok := err.(*tcpip.ErrConnectStarted); !ok { + t.Fatalf("connectingEndpoint.Connect(%#v): %s", connectAddr, err) + } } if !test.expectAccept { - if _, _, err := listeningEndpoint.Accept(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got listeningEndpoint.Accept(nil) = %s, want = %s", err, tcpip.ErrWouldBlock) + _, _, err := listeningEndpoint.Accept(nil) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got listeningEndpoint.Accept(nil) = %s, want = %s", err, &tcpip.ErrWouldBlock{}) } return } diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go index 9cc12fa58..d685fdd36 100644 --- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go +++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go @@ -479,8 +479,8 @@ func TestIncomingMulticastAndBroadcast(t *testing.T) { if diff := cmp.Diff(data, buf.Bytes()); diff != "" { t.Errorf("got 
UDP payload mismatch (-want +got):\n%s", diff) } - } else if err != tcpip.ErrWouldBlock { - t.Fatalf("got Read = (%v, %s) [with data %x], want = (_, %s)", res, err, buf.Bytes(), tcpip.ErrWouldBlock) + } else if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got Read = (%v, %s) [with data %x], want = (_, %s)", res, err, buf.Bytes(), &tcpip.ErrWouldBlock{}) } }) } @@ -761,8 +761,11 @@ func TestUDPAddRemoveMembershipSocketOption(t *testing.T) { if err := ep.SetSockOpt(&removeOpt); err != nil { t.Fatalf("ep.SetSockOpt(&%#v): %s", removeOpt, err) } - if _, err := ep.Read(&buf, tcpip.ReadOptions{}); err != tcpip.ErrWouldBlock { - t.Fatalf("got ep.Read = (_, %s), want = (_, %s)", err, tcpip.ErrWouldBlock) + { + _, err := ep.Read(&buf, tcpip.ReadOptions{}) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got ep.Read = (_, %s), want = (_, %s)", err, &tcpip.ErrWouldBlock{}) + } } }) } diff --git a/pkg/tcpip/tests/integration/route_test.go b/pkg/tcpip/tests/integration/route_test.go index 35ee7437a..9654c9527 100644 --- a/pkg/tcpip/tests/integration/route_test.go +++ b/pkg/tcpip/tests/integration/route_test.go @@ -81,7 +81,7 @@ func TestLocalPing(t *testing.T) { linkEndpoint func() stack.LinkEndpoint localAddr tcpip.Address icmpBuf func(*testing.T) buffer.View - expectedConnectErr *tcpip.Error + expectedConnectErr tcpip.Error checkLinkEndpoint func(t *testing.T, e stack.LinkEndpoint) }{ { @@ -126,7 +126,7 @@ func TestLocalPing(t *testing.T) { netProto: ipv4.ProtocolNumber, linkEndpoint: loopback.New, icmpBuf: ipv4ICMPBuf, - expectedConnectErr: tcpip.ErrNoRoute, + expectedConnectErr: &tcpip.ErrNoRoute{}, checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {}, }, { @@ -135,7 +135,7 @@ func TestLocalPing(t *testing.T) { netProto: ipv6.ProtocolNumber, linkEndpoint: loopback.New, icmpBuf: ipv6ICMPBuf, - expectedConnectErr: tcpip.ErrNoRoute, + expectedConnectErr: &tcpip.ErrNoRoute{}, checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {}, }, { @@ -144,7 
+144,7 @@ func TestLocalPing(t *testing.T) { netProto: ipv4.ProtocolNumber, linkEndpoint: channelEP, icmpBuf: ipv4ICMPBuf, - expectedConnectErr: tcpip.ErrNoRoute, + expectedConnectErr: &tcpip.ErrNoRoute{}, checkLinkEndpoint: channelEPCheck, }, { @@ -153,7 +153,7 @@ func TestLocalPing(t *testing.T) { netProto: ipv6.ProtocolNumber, linkEndpoint: channelEP, icmpBuf: ipv6ICMPBuf, - expectedConnectErr: tcpip.ErrNoRoute, + expectedConnectErr: &tcpip.ErrNoRoute{}, checkLinkEndpoint: channelEPCheck, }, } @@ -186,8 +186,11 @@ func TestLocalPing(t *testing.T) { defer ep.Close() connAddr := tcpip.FullAddress{Addr: test.localAddr} - if err := ep.Connect(connAddr); err != test.expectedConnectErr { - t.Fatalf("got ep.Connect(%#v) = %s, want = %s", connAddr, err, test.expectedConnectErr) + { + err := ep.Connect(connAddr) + if diff := cmp.Diff(test.expectedConnectErr, err); diff != "" { + t.Fatalf("unexpected error from ep.Connect(%#v), (-want, +got):\n%s", connAddr, diff) + } } if test.expectedConnectErr != nil { @@ -263,12 +266,12 @@ func TestLocalUDP(t *testing.T) { subTests := []struct { name string addAddress bool - expectedWriteErr *tcpip.Error + expectedWriteErr tcpip.Error }{ { name: "Unassigned local address", addAddress: false, - expectedWriteErr: tcpip.ErrNoRoute, + expectedWriteErr: &tcpip.ErrNoRoute{}, }, { name: "Assigned local address", diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index e4bcd3120..f5e1a6e45 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -84,7 +84,7 @@ type endpoint struct { ops tcpip.SocketOptions } -func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { ep := &endpoint{ 
stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{ @@ -159,14 +159,14 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { } // Read implements tcpip.Endpoint.Read. -func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) { +func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, tcpip.Error) { e.rcvMu.Lock() if e.rcvList.Empty() { - err := tcpip.ErrWouldBlock + var err tcpip.Error = &tcpip.ErrWouldBlock{} if e.rcvClosed { e.stats.ReadErrors.ReadClosed.Increment() - err = tcpip.ErrClosedForReceive + err = &tcpip.ErrClosedForReceive{} } e.rcvMu.Unlock() return tcpip.ReadResult{}, err @@ -193,7 +193,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult n, err := p.data.ReadTo(dst, opts.Peek) if n == 0 && err != nil { - return res, tcpip.ErrBadBuffer + return res, &tcpip.ErrBadBuffer{} } res.Count = n return res, nil @@ -204,7 +204,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult // reacquire the mutex in exclusive mode. // // Returns true for retry if preparation should be retried. -func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpip.Error) { +func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip.Error) { switch e.state { case stateInitial: case stateConnected: @@ -212,11 +212,11 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi case stateBound: if to == nil { - return false, tcpip.ErrDestinationRequired + return false, &tcpip.ErrDestinationRequired{} } return false, nil default: - return false, tcpip.ErrInvalidEndpointState + return false, &tcpip.ErrInvalidEndpointState{} } e.mu.RUnlock() @@ -241,18 +241,18 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi // Write writes data to the endpoint's peer. This method does not block // if the data cannot be written. 
-func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) { +func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) { n, err := e.write(p, opts) - switch err { + switch err.(type) { case nil: e.stats.PacketsSent.Increment() - case tcpip.ErrMessageTooLong, tcpip.ErrInvalidOptionValue: + case *tcpip.ErrMessageTooLong, *tcpip.ErrInvalidOptionValue: e.stats.WriteErrors.InvalidArgs.Increment() - case tcpip.ErrClosedForSend: + case *tcpip.ErrClosedForSend: e.stats.WriteErrors.WriteClosed.Increment() - case tcpip.ErrInvalidEndpointState: + case *tcpip.ErrInvalidEndpointState: e.stats.WriteErrors.InvalidEndpointState.Increment() - case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable: + case *tcpip.ErrNoRoute, *tcpip.ErrBroadcastDisabled, *tcpip.ErrNetworkUnreachable: // Errors indicating any problem with IP routing of the packet. e.stats.SendErrors.NoRoute.Increment() default: @@ -262,10 +262,10 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc return n, err } -func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) { +func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) { // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.) if opts.More { - return 0, tcpip.ErrInvalidOptionValue + return 0, &tcpip.ErrInvalidOptionValue{} } to := opts.To @@ -275,7 +275,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc // If we've shutdown with SHUT_WR we are in an invalid state for sending. if e.shutdownFlags&tcpip.ShutdownWrite != 0 { - return 0, tcpip.ErrClosedForSend + return 0, &tcpip.ErrClosedForSend{} } // Prepare for write. 
@@ -297,7 +297,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc nicID := to.NIC if e.BindNICID != 0 { if nicID != 0 && nicID != e.BindNICID { - return 0, tcpip.ErrNoRoute + return 0, &tcpip.ErrNoRoute{} } nicID = e.BindNICID @@ -320,10 +320,10 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc v := make([]byte, p.Len()) if _, err := io.ReadFull(p, v); err != nil { - return 0, tcpip.ErrBadBuffer + return 0, &tcpip.ErrBadBuffer{} } - var err *tcpip.Error + var err tcpip.Error switch e.NetProto { case header.IPv4ProtocolNumber: err = send4(route, e.ID.LocalPort, v, e.ttl, e.owner) @@ -340,12 +340,12 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc } // SetSockOpt sets a socket option. -func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { +func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { return nil } // SetSockOptInt sets a socket option. Currently not supported. -func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { +func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { switch opt { case tcpip.TTLOption: e.mu.Lock() @@ -357,7 +357,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. -func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { +func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { switch opt { case tcpip.ReceiveQueueSizeOption: v := 0 @@ -382,18 +382,18 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { return v, nil default: - return -1, tcpip.ErrUnknownProtocolOption + return -1, &tcpip.ErrUnknownProtocolOption{} } } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. 
-func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption +func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { + return &tcpip.ErrUnknownProtocolOption{} } -func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpip.PacketOwner) *tcpip.Error { +func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpip.PacketOwner) tcpip.Error { if len(data) < header.ICMPv4MinimumSize { - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -411,7 +411,7 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpi // Linux performs these basic checks. if icmpv4.Type() != header.ICMPv4Echo || icmpv4.Code() != 0 { - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } icmpv4.SetChecksum(0) @@ -425,9 +425,9 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpi return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt) } -func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Error { +func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) tcpip.Error { if len(data) < header.ICMPv6EchoMinimumSize { - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -442,7 +442,7 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err data = data[header.ICMPv6MinimumSize:] if icmpv6.Type() != header.ICMPv6EchoRequest || icmpv6.Code() != 0 { - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } dataVV := data.ToVectorisedView() @@ -457,7 +457,7 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err // checkV4MappedLocked 
determines the effective network protocol and converts // addr to its canonical form. -func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) { +func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) { unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, false /* v6only */) if err != nil { return tcpip.FullAddress{}, 0, err @@ -466,12 +466,12 @@ func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddres } // Disconnect implements tcpip.Endpoint.Disconnect. -func (*endpoint) Disconnect() *tcpip.Error { - return tcpip.ErrNotSupported +func (*endpoint) Disconnect() tcpip.Error { + return &tcpip.ErrNotSupported{} } // Connect connects the endpoint to its peer. Specifying a NIC is optional. -func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { +func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() @@ -486,12 +486,12 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { } if nicID != 0 && nicID != e.BindNICID { - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } nicID = e.BindNICID default: - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } addr, netProto, err := e.checkV4MappedLocked(addr) @@ -536,19 +536,19 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { } // ConnectEndpoint is not supported. -func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error { - return tcpip.ErrInvalidEndpointState +func (*endpoint) ConnectEndpoint(tcpip.Endpoint) tcpip.Error { + return &tcpip.ErrInvalidEndpointState{} } // Shutdown closes the read and/or write end of the endpoint connection // to its peer. 
-func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { +func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() e.shutdownFlags |= flags if e.state != stateConnected { - return tcpip.ErrNotConnected + return &tcpip.ErrNotConnected{} } if flags&tcpip.ShutdownRead != 0 { @@ -566,16 +566,16 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { } // Listen is not supported by UDP, it just fails. -func (*endpoint) Listen(int) *tcpip.Error { - return tcpip.ErrNotSupported +func (*endpoint) Listen(int) tcpip.Error { + return &tcpip.ErrNotSupported{} } // Accept is not supported by UDP, it just fails. -func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { - return nil, nil, tcpip.ErrNotSupported +func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpip.Error) { + return nil, nil, &tcpip.ErrNotSupported{} } -func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) { +func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.Error) { if id.LocalPort != 0 { // The endpoint already has a local port, just attempt to // register it. @@ -584,13 +584,13 @@ func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.Networ } // We need to find a port for the endpoint. 
- _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) { + _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, tcpip.Error) { id.LocalPort = p err := e.stack.RegisterTransportEndpoint(netProtos, e.TransProto, id, e, ports.Flags{}, 0 /* bindtodevice */) - switch err { + switch err.(type) { case nil: return true, nil - case tcpip.ErrPortInUse: + case *tcpip.ErrPortInUse: return false, nil default: return false, err @@ -600,11 +600,11 @@ func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.Networ return id, err } -func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { +func (e *endpoint) bindLocked(addr tcpip.FullAddress) tcpip.Error { // Don't allow binding once endpoint is not in the initial state // anymore. if e.state != stateInitial { - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } addr, netProto, err := e.checkV4MappedLocked(addr) @@ -620,7 +620,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { if len(addr.Addr) != 0 { // A local address was specified, verify that it's valid. if e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) == 0 { - return tcpip.ErrBadLocalAddress + return &tcpip.ErrBadLocalAddress{} } } @@ -648,7 +648,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { // Bind binds the endpoint to a specific local address and port. // Specifying a NIC is optional. -func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { +func (e *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() @@ -664,7 +664,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { } // GetLocalAddress returns the address to which the endpoint is bound. 
-func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { +func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) { e.mu.RLock() defer e.mu.RUnlock() @@ -676,12 +676,12 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { } // GetRemoteAddress returns the address to which the endpoint is connected. -func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { +func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) { e.mu.RLock() defer e.mu.RUnlock() if e.state != stateConnected { - return tcpip.FullAddress{}, tcpip.ErrNotConnected + return tcpip.FullAddress{}, &tcpip.ErrNotConnected{} } return tcpip.FullAddress{ @@ -778,9 +778,8 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB } } -// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (e *endpoint) HandleControlPacket(typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { -} +// HandleError implements stack.TransportEndpoint. +func (*endpoint) HandleError(stack.TransportError, *stack.PacketBuffer) {} // State implements tcpip.Endpoint.State. The ICMP endpoint currently doesn't // expose internal socket state. @@ -806,7 +805,7 @@ func (e *endpoint) Stats() tcpip.EndpointStats { func (*endpoint) Wait() {} // LastError implements tcpip.Endpoint.LastError. 
-func (*endpoint) LastError() *tcpip.Error { +func (*endpoint) LastError() tcpip.Error { return nil } diff --git a/pkg/tcpip/transport/icmp/endpoint_state.go b/pkg/tcpip/transport/icmp/endpoint_state.go index afe96998a..c9fa9974a 100644 --- a/pkg/tcpip/transport/icmp/endpoint_state.go +++ b/pkg/tcpip/transport/icmp/endpoint_state.go @@ -75,7 +75,7 @@ func (e *endpoint) Resume(s *stack.Stack) { return } - var err *tcpip.Error + var err tcpip.Error if e.state == stateConnected { e.route, err = e.stack.FindRoute(e.RegisterNICID, e.BindAddr, e.ID.RemoteAddress, e.NetProto, false /* multicastLoop */) if err != nil { @@ -85,7 +85,7 @@ func (e *endpoint) Resume(s *stack.Stack) { e.ID.LocalAddress = e.route.LocalAddress } else if len(e.ID.LocalAddress) != 0 { // stateBound if e.stack.CheckLocalAddress(e.RegisterNICID, e.NetProto, e.ID.LocalAddress) == 0 { - panic(tcpip.ErrBadLocalAddress) + panic(&tcpip.ErrBadLocalAddress{}) } } diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go index 3820e5dc7..47f7dd1cb 100644 --- a/pkg/tcpip/transport/icmp/protocol.go +++ b/pkg/tcpip/transport/icmp/protocol.go @@ -59,18 +59,18 @@ func (p *protocol) netProto() tcpip.NetworkProtocolNumber { // NewEndpoint creates a new icmp endpoint. It implements // stack.TransportProtocol.NewEndpoint. -func (p *protocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (p *protocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { if netProto != p.netProto() { - return nil, tcpip.ErrUnknownProtocol + return nil, &tcpip.ErrUnknownProtocol{} } return newEndpoint(p.stack, netProto, p.number, waiterQueue) } // NewRawEndpoint creates a new raw icmp endpoint. It implements // stack.TransportProtocol.NewRawEndpoint. 
-func (p *protocol) NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (p *protocol) NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { if netProto != p.netProto() { - return nil, tcpip.ErrUnknownProtocol + return nil, &tcpip.ErrUnknownProtocol{} } return raw.NewEndpoint(p.stack, netProto, p.number, waiterQueue) } @@ -87,7 +87,7 @@ func (p *protocol) MinimumPacketSize() int { } // ParsePorts in case of ICMP sets src to 0, dst to ICMP ID, and err to nil. -func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { +func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err tcpip.Error) { switch p.number { case ProtocolNumber4: hdr := header.ICMPv4(v) @@ -106,13 +106,13 @@ func (*protocol) HandleUnknownDestinationPacket(stack.TransportEndpointID, *stac } // SetOption implements stack.TransportProtocol.SetOption. -func (*protocol) SetOption(tcpip.SettableTransportProtocolOption) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption +func (*protocol) SetOption(tcpip.SettableTransportProtocolOption) tcpip.Error { + return &tcpip.ErrUnknownProtocolOption{} } // Option implements stack.TransportProtocol.Option. -func (*protocol) Option(tcpip.GettableTransportProtocolOption) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption +func (*protocol) Option(tcpip.GettableTransportProtocolOption) tcpip.Error { + return &tcpip.ErrUnknownProtocolOption{} } // Close implements stack.TransportProtocol.Close. diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index d48877677..73bb66830 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -86,15 +86,15 @@ type endpoint struct { boundNIC tcpip.NICID // lastErrorMu protects lastError. 
- lastErrorMu sync.Mutex `state:"nosave"` - lastError *tcpip.Error `state:".(string)"` + lastErrorMu sync.Mutex `state:"nosave"` + lastError tcpip.Error // ops is used to get socket level options. ops tcpip.SocketOptions } // NewEndpoint returns a new packet endpoint. -func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { ep := &endpoint{ stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{ @@ -159,16 +159,16 @@ func (ep *endpoint) Close() { func (ep *endpoint) ModerateRecvBuf(copied int) {} // Read implements tcpip.Endpoint.Read. -func (ep *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) { +func (ep *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, tcpip.Error) { ep.rcvMu.Lock() // If there's no data to read, return that read would block or that the // endpoint is closed. if ep.rcvList.Empty() { - err := tcpip.ErrWouldBlock + var err tcpip.Error = &tcpip.ErrWouldBlock{} if ep.rcvClosed { ep.stats.ReadErrors.ReadClosed.Increment() - err = tcpip.ErrClosedForReceive + err = &tcpip.ErrClosedForReceive{} } ep.rcvMu.Unlock() return tcpip.ReadResult{}, err @@ -198,49 +198,49 @@ func (ep *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResul n, err := packet.data.ReadTo(dst, opts.Peek) if n == 0 && err != nil { - return res, tcpip.ErrBadBuffer + return res, &tcpip.ErrBadBuffer{} } res.Count = n return res, nil } -func (*endpoint) Write(tcpip.Payloader, tcpip.WriteOptions) (int64, *tcpip.Error) { +func (*endpoint) Write(tcpip.Payloader, tcpip.WriteOptions) (int64, tcpip.Error) { // TODO(gvisor.dev/issue/173): Implement. - return 0, tcpip.ErrInvalidOptionValue + return 0, &tcpip.ErrInvalidOptionValue{} } // Disconnect implements tcpip.Endpoint.Disconnect. 
Packet sockets cannot be // disconnected, and this function always returns tpcip.ErrNotSupported. -func (*endpoint) Disconnect() *tcpip.Error { - return tcpip.ErrNotSupported +func (*endpoint) Disconnect() tcpip.Error { + return &tcpip.ErrNotSupported{} } // Connect implements tcpip.Endpoint.Connect. Packet sockets cannot be -// connected, and this function always returnes tcpip.ErrNotSupported. -func (*endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { - return tcpip.ErrNotSupported +// connected, and this function always returnes *tcpip.ErrNotSupported. +func (*endpoint) Connect(addr tcpip.FullAddress) tcpip.Error { + return &tcpip.ErrNotSupported{} } // Shutdown implements tcpip.Endpoint.Shutdown. Packet sockets cannot be used -// with Shutdown, and this function always returns tcpip.ErrNotSupported. -func (*endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { - return tcpip.ErrNotSupported +// with Shutdown, and this function always returns *tcpip.ErrNotSupported. +func (*endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error { + return &tcpip.ErrNotSupported{} } // Listen implements tcpip.Endpoint.Listen. Packet sockets cannot be used with -// Listen, and this function always returns tcpip.ErrNotSupported. -func (*endpoint) Listen(backlog int) *tcpip.Error { - return tcpip.ErrNotSupported +// Listen, and this function always returns *tcpip.ErrNotSupported. +func (*endpoint) Listen(backlog int) tcpip.Error { + return &tcpip.ErrNotSupported{} } // Accept implements tcpip.Endpoint.Accept. Packet sockets cannot be used with -// Accept, and this function always returns tcpip.ErrNotSupported. -func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { - return nil, nil, tcpip.ErrNotSupported +// Accept, and this function always returns *tcpip.ErrNotSupported. 
+func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpip.Error) { + return nil, nil, &tcpip.ErrNotSupported{} } // Bind implements tcpip.Endpoint.Bind. -func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { +func (ep *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error { // TODO(gvisor.dev/issue/173): Add Bind support. // "By default, all packets of the specified protocol type are passed @@ -274,14 +274,14 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { } // GetLocalAddress implements tcpip.Endpoint.GetLocalAddress. -func (*endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { - return tcpip.FullAddress{}, tcpip.ErrNotSupported +func (*endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) { + return tcpip.FullAddress{}, &tcpip.ErrNotSupported{} } // GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress. -func (*endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { +func (*endpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) { // Even a connected socket doesn't return a remote address. - return tcpip.FullAddress{}, tcpip.ErrNotConnected + return tcpip.FullAddress{}, &tcpip.ErrNotConnected{} } // Readiness implements tcpip.Endpoint.Readiness. @@ -303,19 +303,19 @@ func (ep *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // SetSockOpt implements tcpip.Endpoint.SetSockOpt. Packet sockets cannot be // used with SetSockOpt, and this function always returns -// tcpip.ErrNotSupported. -func (ep *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { +// *tcpip.ErrNotSupported. +func (ep *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { switch opt.(type) { case *tcpip.SocketDetachFilterOption: return nil default: - return tcpip.ErrUnknownProtocolOption + return &tcpip.ErrUnknownProtocolOption{} } } // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. 
-func (ep *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { +func (ep *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { switch opt { case tcpip.ReceiveBufferSizeOption: // Make sure the receive buffer size is within the min and max @@ -336,11 +336,11 @@ func (ep *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { return nil default: - return tcpip.ErrUnknownProtocolOption + return &tcpip.ErrUnknownProtocolOption{} } } -func (ep *endpoint) LastError() *tcpip.Error { +func (ep *endpoint) LastError() tcpip.Error { ep.lastErrorMu.Lock() defer ep.lastErrorMu.Unlock() @@ -350,19 +350,19 @@ func (ep *endpoint) LastError() *tcpip.Error { } // UpdateLastError implements tcpip.SocketOptionsHandler.UpdateLastError. -func (ep *endpoint) UpdateLastError(err *tcpip.Error) { +func (ep *endpoint) UpdateLastError(err tcpip.Error) { ep.lastErrorMu.Lock() ep.lastError = err ep.lastErrorMu.Unlock() } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (ep *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { - return tcpip.ErrNotSupported +func (ep *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { + return &tcpip.ErrNotSupported{} } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. 
-func (ep *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { +func (ep *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { switch opt { case tcpip.ReceiveQueueSizeOption: v := 0 @@ -381,7 +381,7 @@ func (ep *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { return v, nil default: - return -1, tcpip.ErrUnknownProtocolOption + return -1, &tcpip.ErrUnknownProtocolOption{} } } diff --git a/pkg/tcpip/transport/packet/endpoint_state.go b/pkg/tcpip/transport/packet/endpoint_state.go index 4d98fb051..ece662c0d 100644 --- a/pkg/tcpip/transport/packet/endpoint_state.go +++ b/pkg/tcpip/transport/packet/endpoint_state.go @@ -68,24 +68,6 @@ func (ep *endpoint) afterLoad() { // TODO(gvisor.dev/173): Once bind is supported, choose the right NIC. if err := ep.stack.RegisterPacketEndpoint(0, ep.netProto, ep); err != nil { - panic(*err) + panic(err) } } - -// saveLastError is invoked by stateify. -func (ep *endpoint) saveLastError() string { - if ep.lastError == nil { - return "" - } - - return ep.lastError.String() -} - -// loadLastError is invoked by stateify. -func (ep *endpoint) loadLastError(s string) { - if s == "" { - return - } - - ep.lastError = tcpip.StringToError(s) -} diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 6c6d45188..9c9ccc0ff 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -93,13 +93,13 @@ type endpoint struct { } // NewEndpoint returns a raw endpoint for the given protocols. 
-func NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { return newEndpoint(stack, netProto, transProto, waiterQueue, true /* associated */) } -func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, *tcpip.Error) { +func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, tcpip.Error) { if netProto != header.IPv4ProtocolNumber && netProto != header.IPv6ProtocolNumber { - return nil, tcpip.ErrUnknownProtocol + return nil, &tcpip.ErrUnknownProtocol{} } e := &endpoint{ @@ -189,16 +189,16 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { } // Read implements tcpip.Endpoint.Read. -func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) { +func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, tcpip.Error) { e.rcvMu.Lock() // If there's no data to read, return that read would block or that the // endpoint is closed. 
if e.rcvList.Empty() { - err := tcpip.ErrWouldBlock + var err tcpip.Error = &tcpip.ErrWouldBlock{} if e.rcvClosed { e.stats.ReadErrors.ReadClosed.Increment() - err = tcpip.ErrClosedForReceive + err = &tcpip.ErrClosedForReceive{} } e.rcvMu.Unlock() return tcpip.ReadResult{}, err @@ -225,37 +225,37 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult n, err := pkt.data.ReadTo(dst, opts.Peek) if n == 0 && err != nil { - return res, tcpip.ErrBadBuffer + return res, &tcpip.ErrBadBuffer{} } res.Count = n return res, nil } // Write implements tcpip.Endpoint.Write. -func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) { +func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) { // We can create, but not write to, unassociated IPv6 endpoints. if !e.associated && e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber { - return 0, tcpip.ErrInvalidOptionValue + return 0, &tcpip.ErrInvalidOptionValue{} } if opts.To != nil { // Raw sockets do not support sending to a IPv4 address on a IPv6 endpoint. 
if e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber && len(opts.To.Addr) != header.IPv6AddressSize { - return 0, tcpip.ErrInvalidOptionValue + return 0, &tcpip.ErrInvalidOptionValue{} } } n, err := e.write(p, opts) - switch err { + switch err.(type) { case nil: e.stats.PacketsSent.Increment() - case tcpip.ErrMessageTooLong, tcpip.ErrInvalidOptionValue: + case *tcpip.ErrMessageTooLong, *tcpip.ErrInvalidOptionValue: e.stats.WriteErrors.InvalidArgs.Increment() - case tcpip.ErrClosedForSend: + case *tcpip.ErrClosedForSend: e.stats.WriteErrors.WriteClosed.Increment() - case tcpip.ErrInvalidEndpointState: + case *tcpip.ErrInvalidEndpointState: e.stats.WriteErrors.InvalidEndpointState.Increment() - case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable: + case *tcpip.ErrNoRoute, *tcpip.ErrBroadcastDisabled, *tcpip.ErrNetworkUnreachable: // Errors indicating any problem with IP routing of the packet. e.stats.SendErrors.NoRoute.Increment() default: @@ -265,22 +265,22 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc return n, err } -func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) { +func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) { // MSG_MORE is unimplemented. This also means that MSG_EOR is a no-op. 
if opts.More { - return 0, tcpip.ErrInvalidOptionValue + return 0, &tcpip.ErrInvalidOptionValue{} } e.mu.RLock() defer e.mu.RUnlock() if e.closed { - return 0, tcpip.ErrInvalidEndpointState + return 0, &tcpip.ErrInvalidEndpointState{} } payloadBytes := make([]byte, p.Len()) if _, err := io.ReadFull(p, payloadBytes); err != nil { - return 0, tcpip.ErrBadBuffer + return 0, &tcpip.ErrBadBuffer{} } // If this is an unassociated socket and callee provided a nonzero @@ -288,7 +288,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc if e.ops.GetHeaderIncluded() { ip := header.IPv4(payloadBytes) if !ip.IsValid(len(payloadBytes)) { - return 0, tcpip.ErrInvalidOptionValue + return 0, &tcpip.ErrInvalidOptionValue{} } dstAddr := ip.DestinationAddress() // Update dstAddr with the address in the IP header, unless @@ -309,7 +309,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc // If the user doesn't specify a destination, they should have // connected to another address. if !e.connected { - return 0, tcpip.ErrDestinationRequired + return 0, &tcpip.ErrDestinationRequired{} } return e.finishWrite(payloadBytes, e.route) @@ -319,7 +319,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc // goes through a different NIC than the endpoint was bound to. nic := opts.To.NIC if e.bound && nic != 0 && nic != e.BindNICID { - return 0, tcpip.ErrNoRoute + return 0, &tcpip.ErrNoRoute{} } // Find the route to the destination. If BindAddress is 0, @@ -336,7 +336,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc // finishWrite writes the payload to a route. It resolves the route if // necessary. It's really just a helper to make defer unnecessary in Write. 
-func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, *tcpip.Error) { +func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, tcpip.Error) { if e.ops.GetHeaderIncluded() { pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buffer.View(payloadBytes).ToVectorisedView(), @@ -363,22 +363,22 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, } // Disconnect implements tcpip.Endpoint.Disconnect. -func (*endpoint) Disconnect() *tcpip.Error { - return tcpip.ErrNotSupported +func (*endpoint) Disconnect() tcpip.Error { + return &tcpip.ErrNotSupported{} } // Connect implements tcpip.Endpoint.Connect. -func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { +func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error { // Raw sockets do not support connecting to a IPv4 address on a IPv6 endpoint. if e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber && len(addr.Addr) != header.IPv6AddressSize { - return tcpip.ErrAddressFamilyNotSupported + return &tcpip.ErrAddressFamilyNotSupported{} } e.mu.Lock() defer e.mu.Unlock() if e.closed { - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } nic := addr.NIC @@ -393,7 +393,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { } else if addr.NIC != e.BindNICID { // We're bound and addr specifies a NIC. They must be // the same. - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } } @@ -424,34 +424,34 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { } // Shutdown implements tcpip.Endpoint.Shutdown. It's a noop for raw sockets. 
-func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { +func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() if !e.connected { - return tcpip.ErrNotConnected + return &tcpip.ErrNotConnected{} } return nil } // Listen implements tcpip.Endpoint.Listen. -func (*endpoint) Listen(backlog int) *tcpip.Error { - return tcpip.ErrNotSupported +func (*endpoint) Listen(backlog int) tcpip.Error { + return &tcpip.ErrNotSupported{} } // Accept implements tcpip.Endpoint.Accept. -func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { - return nil, nil, tcpip.ErrNotSupported +func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpip.Error) { + return nil, nil, &tcpip.ErrNotSupported{} } // Bind implements tcpip.Endpoint.Bind. -func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { +func (e *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() // If a local address was specified, verify that it's valid. if len(addr.Addr) != 0 && e.stack.CheckLocalAddress(e.RegisterNICID, e.NetProto, addr.Addr) == 0 { - return tcpip.ErrBadLocalAddress + return &tcpip.ErrBadLocalAddress{} } if e.associated { @@ -471,14 +471,14 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { } // GetLocalAddress implements tcpip.Endpoint.GetLocalAddress. -func (*endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { - return tcpip.FullAddress{}, tcpip.ErrNotSupported +func (*endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) { + return tcpip.FullAddress{}, &tcpip.ErrNotSupported{} } // GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress. -func (*endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { +func (*endpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) { // Even a connected socket doesn't return a remote address. 
- return tcpip.FullAddress{}, tcpip.ErrNotConnected + return tcpip.FullAddress{}, &tcpip.ErrNotConnected{} } // Readiness implements tcpip.Endpoint.Readiness. @@ -499,18 +499,18 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { } // SetSockOpt implements tcpip.Endpoint.SetSockOpt. -func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { +func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { switch opt.(type) { case *tcpip.SocketDetachFilterOption: return nil default: - return tcpip.ErrUnknownProtocolOption + return &tcpip.ErrUnknownProtocolOption{} } } // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. -func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { +func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { switch opt { case tcpip.ReceiveBufferSizeOption: // Make sure the receive buffer size is within the min and max @@ -531,17 +531,17 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { return nil default: - return tcpip.ErrUnknownProtocolOption + return &tcpip.ErrUnknownProtocolOption{} } } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption +func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { + return &tcpip.ErrUnknownProtocolOption{} } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. 
-func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { +func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { switch opt { case tcpip.ReceiveQueueSizeOption: v := 0 @@ -560,7 +560,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { return v, nil default: - return -1, tcpip.ErrUnknownProtocolOption + return -1, &tcpip.ErrUnknownProtocolOption{} } } @@ -680,7 +680,7 @@ func (e *endpoint) Stats() tcpip.EndpointStats { func (*endpoint) Wait() {} // LastError implements tcpip.Endpoint.LastError. -func (*endpoint) LastError() *tcpip.Error { +func (*endpoint) LastError() tcpip.Error { return nil } diff --git a/pkg/tcpip/transport/raw/endpoint_state.go b/pkg/tcpip/transport/raw/endpoint_state.go index 65c64d99f..263ec5146 100644 --- a/pkg/tcpip/transport/raw/endpoint_state.go +++ b/pkg/tcpip/transport/raw/endpoint_state.go @@ -73,7 +73,7 @@ func (e *endpoint) Resume(s *stack.Stack) { // If the endpoint is connected, re-connect. if e.connected { - var err *tcpip.Error + var err tcpip.Error // TODO(gvisor.dev/issue/4906): Properly restore the route with the right // remote address. We used to pass e.remote.RemoteAddress which was // effectively the empty address but since moving e.route to hold a pointer @@ -89,7 +89,7 @@ func (e *endpoint) Resume(s *stack.Stack) { // If the endpoint is bound, re-bind. if e.bound { if e.stack.CheckLocalAddress(e.RegisterNICID, e.NetProto, e.BindAddr) == 0 { - panic(tcpip.ErrBadLocalAddress) + panic(&tcpip.ErrBadLocalAddress{}) } } diff --git a/pkg/tcpip/transport/raw/protocol.go b/pkg/tcpip/transport/raw/protocol.go index f30aa2a4a..e393b993d 100644 --- a/pkg/tcpip/transport/raw/protocol.go +++ b/pkg/tcpip/transport/raw/protocol.go @@ -25,11 +25,11 @@ import ( type EndpointFactory struct{} // NewUnassociatedEndpoint implements stack.RawFactory.NewUnassociatedEndpoint. 
-func (EndpointFactory) NewUnassociatedEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (EndpointFactory) NewUnassociatedEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { return newEndpoint(stack, netProto, transProto, waiterQueue, false /* associated */) } // NewPacketEndpoint implements stack.RawFactory.NewPacketEndpoint. -func (EndpointFactory) NewPacketEndpoint(stack *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (EndpointFactory) NewPacketEndpoint(stack *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { return packet.NewEndpoint(stack, cooked, netProto, waiterQueue) } diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index e475c36f3..842c1622b 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -199,7 +199,7 @@ func (l *listenContext) isCookieValid(id stack.TransportEndpointID, cookie seqnu // createConnectingEndpoint creates a new endpoint in a connecting state, with // the connection parameters given by the arguments. -func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, *tcpip.Error) { +func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, tcpip.Error) { // Create a new endpoint. 
netProto := l.netProto if netProto == 0 { @@ -242,7 +242,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i // On success, a handshake h is returned with h.ep.mu held. // // Precondition: if l.listenEP != nil, l.listenEP.mu must be locked. -func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*handshake, *tcpip.Error) { +func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*handshake, tcpip.Error) { // Create new endpoint. irs := s.sequenceNumber isn := generateSecureISN(s.id, l.stack.Seed()) @@ -267,7 +267,7 @@ func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, q ep.mu.Unlock() ep.Close() - return nil, tcpip.ErrConnectionAborted + return nil, &tcpip.ErrConnectionAborted{} } l.addPendingEndpoint(ep) @@ -281,7 +281,7 @@ func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, q l.removePendingEndpoint(ep) - return nil, tcpip.ErrConnectionAborted + return nil, &tcpip.ErrConnectionAborted{} } deferAccept = l.listenEP.deferAccept @@ -313,7 +313,7 @@ func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, q // established endpoint is returned with e.mu held. // // Precondition: if l.listenEP != nil, l.listenEP.mu must be locked. -func (l *listenContext) performHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*endpoint, *tcpip.Error) { +func (l *listenContext) performHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*endpoint, tcpip.Error) { h, err := l.startHandshake(s, opts, queue, owner) if err != nil { return nil, err @@ -467,7 +467,7 @@ func (e *endpoint) notifyAborted() { // cookies to accept connections. // // Precondition: if ctx.listenEP != nil, ctx.listenEP.mu must be locked. 
-func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header.TCPSynOptions) *tcpip.Error { +func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header.TCPSynOptions) tcpip.Error { defer s.decRef() h, err := ctx.startHandshake(s, opts, &waiter.Queue{}, e.owner) @@ -522,7 +522,7 @@ func (e *endpoint) acceptQueueIsFull() bool { // and needs to handle it. // // Precondition: if ctx.listenEP != nil, ctx.listenEP.mu must be locked. -func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) *tcpip.Error { +func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Error { e.rcvListMu.Lock() rcvClosed := e.rcvClosed e.rcvListMu.Unlock() diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 62954d7e4..34a631b53 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -226,7 +226,7 @@ func (h *handshake) checkAck(s *segment) bool { // synSentState handles a segment received when the TCP 3-way handshake is in // the SYN-SENT state. -func (h *handshake) synSentState(s *segment) *tcpip.Error { +func (h *handshake) synSentState(s *segment) tcpip.Error { // RFC 793, page 37, states that in the SYN-SENT state, a reset is // acceptable if the ack field acknowledges the SYN. if s.flagIsSet(header.TCPFlagRst) { @@ -237,7 +237,7 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error { h.ep.workerCleanup = true // Although the RFC above calls out ECONNRESET, Linux actually returns // ECONNREFUSED here so we do as well. - return tcpip.ErrConnectionRefused + return &tcpip.ErrConnectionRefused{} } return nil } @@ -314,12 +314,12 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error { // synRcvdState handles a segment received when the TCP 3-way handshake is in // the SYN-RCVD state. 
-func (h *handshake) synRcvdState(s *segment) *tcpip.Error { +func (h *handshake) synRcvdState(s *segment) tcpip.Error { if s.flagIsSet(header.TCPFlagRst) { // RFC 793, page 37, states that in the SYN-RCVD state, a reset // is acceptable if the sequence number is in the window. if s.sequenceNumber.InWindow(h.ackNum, h.rcvWnd) { - return tcpip.ErrConnectionRefused + return &tcpip.ErrConnectionRefused{} } return nil } @@ -333,7 +333,9 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error { // number and "After sending the acknowledgment, drop the unacceptable // segment and return." if !s.sequenceNumber.InWindow(h.ackNum, h.rcvWnd) { - h.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck, h.iss+1, h.ackNum, h.rcvWnd) + if h.ep.allowOutOfWindowAck() { + h.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck, h.iss+1, h.ackNum, h.rcvWnd) + } return nil } @@ -349,7 +351,7 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error { h.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagRst|header.TCPFlagAck, seq, ack, 0) if !h.active { - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } h.resetState() @@ -412,7 +414,7 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error { return nil } -func (h *handshake) handleSegment(s *segment) *tcpip.Error { +func (h *handshake) handleSegment(s *segment) tcpip.Error { h.sndWnd = s.window if !s.flagIsSet(header.TCPFlagSyn) && h.sndWndScale > 0 { h.sndWnd <<= uint8(h.sndWndScale) @@ -429,7 +431,7 @@ func (h *handshake) handleSegment(s *segment) *tcpip.Error { // processSegments goes through the segment queue and processes up to // maxSegmentsPerWake (if they're available). -func (h *handshake) processSegments() *tcpip.Error { +func (h *handshake) processSegments() tcpip.Error { for i := 0; i < maxSegmentsPerWake; i++ { s := h.ep.segmentQueue.dequeue() if s == nil { @@ -505,7 +507,7 @@ func (h *handshake) start() { } // complete completes the TCP 3-way handshake initiated by h.start(). 
-func (h *handshake) complete() *tcpip.Error { +func (h *handshake) complete() tcpip.Error { // Set up the wakers. var s sleep.Sleeper resendWaker := sleep.Waker{} @@ -555,7 +557,7 @@ func (h *handshake) complete() *tcpip.Error { case wakerForNotification: n := h.ep.fetchNotifications() if (n&notifyClose)|(n&notifyAbort) != 0 { - return tcpip.ErrAborted + return &tcpip.ErrAborted{} } if n&notifyDrain != 0 { for !h.ep.segmentQueue.empty() { @@ -593,19 +595,19 @@ type backoffTimer struct { t *time.Timer } -func newBackoffTimer(timeout, maxTimeout time.Duration, f func()) (*backoffTimer, *tcpip.Error) { +func newBackoffTimer(timeout, maxTimeout time.Duration, f func()) (*backoffTimer, tcpip.Error) { if timeout > maxTimeout { - return nil, tcpip.ErrTimeout + return nil, &tcpip.ErrTimeout{} } bt := &backoffTimer{timeout: timeout, maxTimeout: maxTimeout} bt.t = time.AfterFunc(timeout, f) return bt, nil } -func (bt *backoffTimer) reset() *tcpip.Error { +func (bt *backoffTimer) reset() tcpip.Error { bt.timeout *= 2 if bt.timeout > MaxRTO { - return tcpip.ErrTimeout + return &tcpip.ErrTimeout{} } bt.t.Reset(bt.timeout) return nil @@ -706,7 +708,7 @@ type tcpFields struct { txHash uint32 } -func (e *endpoint) sendSynTCP(r *stack.Route, tf tcpFields, opts header.TCPSynOptions) *tcpip.Error { +func (e *endpoint) sendSynTCP(r *stack.Route, tf tcpFields, opts header.TCPSynOptions) tcpip.Error { tf.opts = makeSynOptions(opts) // We ignore SYN send errors and let the callers re-attempt send. 
if err := e.sendTCP(r, tf, buffer.VectorisedView{}, nil); err != nil { @@ -716,7 +718,7 @@ func (e *endpoint) sendSynTCP(r *stack.Route, tf tcpFields, opts header.TCPSynOp return nil } -func (e *endpoint) sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO) *tcpip.Error { +func (e *endpoint) sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO) tcpip.Error { tf.txHash = e.txHash if err := sendTCP(r, tf, data, gso, e.owner); err != nil { e.stats.SendErrors.SegmentSendToNetworkFailed.Increment() @@ -755,7 +757,7 @@ func buildTCPHdr(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso *sta } } -func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO, owner tcpip.PacketOwner) *tcpip.Error { +func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO, owner tcpip.PacketOwner) tcpip.Error { // We need to shallow clone the VectorisedView here as ReadToView will // split the VectorisedView and Trim underlying views as it splits. Not // doing the clone here will cause the underlying views of data itself @@ -803,7 +805,7 @@ func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso // sendTCP sends a TCP segment with the provided options via the provided // network endpoint and under the provided identity. -func sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO, owner tcpip.PacketOwner) *tcpip.Error { +func sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO, owner tcpip.PacketOwner) tcpip.Error { optLen := len(tf.opts) if tf.rcvWnd > math.MaxUint16 { tf.rcvWnd = math.MaxUint16 @@ -875,7 +877,7 @@ func (e *endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte { } // sendRaw sends a TCP segment to the endpoint's peer. 
-func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size) *tcpip.Error { +func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size) tcpip.Error { var sackBlocks []header.SACKBlock if e.EndpointState() == StateEstablished && e.rcv.pendingRcvdSegments.Len() > 0 && (flags&header.TCPFlagAck != 0) { sackBlocks = e.sack.Blocks[:e.sack.NumBlocks] @@ -941,12 +943,14 @@ func (e *endpoint) handleClose() { // error code and sends a RST if and only if the error is not ErrConnectionReset // indicating that the connection is being reset due to receiving a RST. This // method must only be called from the protocol goroutine. -func (e *endpoint) resetConnectionLocked(err *tcpip.Error) { +func (e *endpoint) resetConnectionLocked(err tcpip.Error) { // Only send a reset if the connection is being aborted for a reason // other than receiving a reset. e.setEndpointState(StateError) e.hardError = err - if err != tcpip.ErrConnectionReset && err != tcpip.ErrTimeout { + switch err.(type) { + case *tcpip.ErrConnectionReset, *tcpip.ErrTimeout: + default: // The exact sequence number to be used for the RST is the same as the // one used by Linux. We need to handle the case of window being shrunk // which can cause sndNxt to be outside the acceptable window on the @@ -1056,7 +1060,7 @@ func (e *endpoint) drainClosingSegmentQueue() { } } -func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) { +func (e *endpoint) handleReset(s *segment) (ok bool, err tcpip.Error) { if e.rcv.acceptable(s.sequenceNumber, 0) { // RFC 793, page 37 states that "in all states // except SYN-SENT, all reset (RST) segments are @@ -1084,7 +1088,7 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) { // delete the TCB, and return. 
case StateCloseWait: e.transitionToStateCloseLocked() - e.hardError = tcpip.ErrAborted + e.hardError = &tcpip.ErrAborted{} e.notifyProtocolGoroutine(notifyTickleWorker) return false, nil default: @@ -1097,14 +1101,14 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) { // handleSegment is invoked from the processor goroutine // rather than the worker goroutine. e.notifyProtocolGoroutine(notifyResetByPeer) - return false, tcpip.ErrConnectionReset + return false, &tcpip.ErrConnectionReset{} } } return true, nil } // handleSegments processes all inbound segments. -func (e *endpoint) handleSegments(fastPath bool) *tcpip.Error { +func (e *endpoint) handleSegments(fastPath bool) tcpip.Error { checkRequeue := true for i := 0; i < maxSegmentsPerWake; i++ { if e.EndpointState().closed() { @@ -1151,7 +1155,7 @@ func (e *endpoint) probeSegment() { // handleSegment handles a given segment and notifies the worker goroutine if // if the connection should be terminated. -func (e *endpoint) handleSegment(s *segment) (cont bool, err *tcpip.Error) { +func (e *endpoint) handleSegment(s *segment) (cont bool, err tcpip.Error) { // Invoke the tcp probe if installed. The tcp probe function will update // the TCPEndpointState after the segment is processed. defer e.probeSegment() @@ -1183,8 +1187,7 @@ func (e *endpoint) handleSegment(s *segment) (cont bool, err *tcpip.Error) { // endpoint MUST terminate its connection. The local TCP endpoint // should then rely on SYN retransmission from the remote end to // re-establish the connection. - - e.snd.sendAck() + e.snd.maybeSendOutOfWindowAck(s) } else if s.flagIsSet(header.TCPFlagAck) { // Patch the window size in the segment according to the // send window scale. @@ -1225,7 +1228,7 @@ func (e *endpoint) handleSegment(s *segment) (cont bool, err *tcpip.Error) { // keepaliveTimerExpired is called when the keepaliveTimer fires. We send TCP // keepalive packets periodically when the connection is idle. 
If we don't hear // from the other side after a number of tries, we terminate the connection. -func (e *endpoint) keepaliveTimerExpired() *tcpip.Error { +func (e *endpoint) keepaliveTimerExpired() tcpip.Error { userTimeout := e.userTimeout e.keepalive.Lock() @@ -1239,13 +1242,13 @@ func (e *endpoint) keepaliveTimerExpired() *tcpip.Error { if userTimeout != 0 && time.Since(e.rcv.lastRcvdAckTime) >= userTimeout && e.keepalive.unacked > 0 { e.keepalive.Unlock() e.stack.Stats().TCP.EstablishedTimedout.Increment() - return tcpip.ErrTimeout + return &tcpip.ErrTimeout{} } if e.keepalive.unacked >= e.keepalive.count { e.keepalive.Unlock() e.stack.Stats().TCP.EstablishedTimedout.Increment() - return tcpip.ErrTimeout + return &tcpip.ErrTimeout{} } // RFC1122 4.2.3.6: TCP keepalive is a dataless ACK with @@ -1289,7 +1292,7 @@ func (e *endpoint) disableKeepaliveTimer() { // protocolMainLoop is the main loop of the TCP protocol. It runs in its own // goroutine and is responsible for sending segments and handling received // segments. -func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{}) *tcpip.Error { +func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{}) tcpip.Error { e.mu.Lock() var closeTimer *time.Timer var closeWaker sleep.Waker @@ -1335,6 +1338,14 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ } } + // Reaching this point means that we successfully completed the 3-way + // handshake with our peer. + // + // Completing the 3-way handshake is an indication that the route is valid + // and the remote is reachable as the only way we can complete a handshake + // is if our SYN reached the remote and their ACK reached us. + e.route.ConfirmReachable() + drained := e.drainDone != nil if drained { close(e.drainDone) @@ -1347,25 +1358,25 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ // wakes up. 
funcs := []struct { w *sleep.Waker - f func() *tcpip.Error + f func() tcpip.Error }{ { w: &e.sndWaker, - f: func() *tcpip.Error { + f: func() tcpip.Error { e.handleWrite() return nil }, }, { w: &e.sndCloseWaker, - f: func() *tcpip.Error { + f: func() tcpip.Error { e.handleClose() return nil }, }, { w: &closeWaker, - f: func() *tcpip.Error { + f: func() tcpip.Error { // This means the socket is being closed due // to the TCP-FIN-WAIT2 timeout was hit. Just // mark the socket as closed. @@ -1376,10 +1387,10 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ }, { w: &e.snd.resendWaker, - f: func() *tcpip.Error { + f: func() tcpip.Error { if !e.snd.retransmitTimerExpired() { e.stack.Stats().TCP.EstablishedTimedout.Increment() - return tcpip.ErrTimeout + return &tcpip.ErrTimeout{} } return nil }, @@ -1390,7 +1401,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ }, { w: &e.newSegmentWaker, - f: func() *tcpip.Error { + f: func() tcpip.Error { return e.handleSegments(false /* fastPath */) }, }, @@ -1400,7 +1411,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ }, { w: &e.notificationWaker, - f: func() *tcpip.Error { + f: func() tcpip.Error { n := e.fetchNotifications() if n&notifyNonZeroReceiveWindow != 0 { e.rcv.nonZeroWindow() @@ -1417,11 +1428,11 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ } if n&notifyReset != 0 || n&notifyAbort != 0 { - return tcpip.ErrConnectionAborted + return &tcpip.ErrConnectionAborted{} } if n&notifyResetByPeer != 0 { - return tcpip.ErrConnectionReset + return &tcpip.ErrConnectionReset{} } if n&notifyClose != 0 && closeTimer == nil { @@ -1500,7 +1511,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ // Main loop. Handle segments until both send and receive ends of the // connection have completed. 
- cleanupOnError := func(err *tcpip.Error) { + cleanupOnError := func(err tcpip.Error) { e.stack.Stats().TCP.CurrentConnected.Decrement() e.workerCleanup = true if err != nil { diff --git a/pkg/tcpip/transport/tcp/cubic.go b/pkg/tcpip/transport/tcp/cubic.go index 7b1f5e763..1975f1a44 100644 --- a/pkg/tcpip/transport/tcp/cubic.go +++ b/pkg/tcpip/transport/tcp/cubic.go @@ -178,8 +178,8 @@ func (c *cubicState) getCwnd(packetsAcked, sndCwnd int, srtt time.Duration) int return int(cwnd) } -// HandleNDupAcks implements congestionControl.HandleNDupAcks. -func (c *cubicState) HandleNDupAcks() { +// HandleLossDetected implements congestionControl.HandleLossDetected. +func (c *cubicState) HandleLossDetected() { // See: https://tools.ietf.org/html/rfc8312#section-4.5 c.numCongestionEvents++ c.t = time.Now() diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go index 809c88732..2d90246e4 100644 --- a/pkg/tcpip/transport/tcp/dual_stack_test.go +++ b/pkg/tcpip/transport/tcp/dual_stack_test.go @@ -37,7 +37,7 @@ func TestV4MappedConnectOnV6Only(t *testing.T) { // Start connection attempt, it must fail. 
err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV4MappedAddr, Port: context.TestPort}) - if err != tcpip.ErrNoRoute { + if _, ok := err.(*tcpip.ErrNoRoute); !ok { t.Fatalf("Unexpected return value from Connect: %v", err) } } @@ -49,7 +49,7 @@ func testV4Connect(t *testing.T, c *context.Context, checkers ...checker.Network defer c.WQ.EventUnregister(&we) err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV4MappedAddr, Port: context.TestPort}) - if err != tcpip.ErrConnectStarted { + if _, ok := err.(*tcpip.ErrConnectStarted); !ok { t.Fatalf("Unexpected return value from Connect: %v", err) } @@ -156,7 +156,7 @@ func testV6Connect(t *testing.T, c *context.Context, checkers ...checker.Network defer c.WQ.EventUnregister(&we) err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV6Addr, Port: context.TestPort}) - if err != tcpip.ErrConnectStarted { + if _, ok := err.(*tcpip.ErrConnectStarted); !ok { t.Fatalf("Unexpected return value from Connect: %v", err) } @@ -391,7 +391,7 @@ func testV4Accept(t *testing.T, c *context.Context) { defer c.WQ.EventUnregister(&we) nep, _, err := c.EP.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -525,7 +525,7 @@ func TestV6AcceptOnV6(t *testing.T) { defer c.WQ.EventUnregister(&we) var addr tcpip.FullAddress _, _, err := c.EP.Accept(&addr) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -549,7 +549,7 @@ func TestV4AcceptOnV4(t *testing.T) { defer c.Cleanup() // Create TCP endpoint. 
- var err *tcpip.Error + var err tcpip.Error c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) if err != nil { t.Fatalf("NewEndpoint failed: %v", err) @@ -613,7 +613,7 @@ func testV4ListenClose(t *testing.T, c *context.Context) { c.WQ.EventRegister(&we, waiter.EventIn) defer c.WQ.EventUnregister(&we) nep, _, err := c.EP.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -635,7 +635,7 @@ func TestV4ListenCloseOnV4(t *testing.T) { defer c.Cleanup() // Create TCP endpoint. - var err *tcpip.Error + var err tcpip.Error c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) if err != nil { t.Fatalf("NewEndpoint failed: %v", err) diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index b6bd6d455..4e5a6089f 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -386,12 +386,12 @@ type endpoint struct { // hardError is meaningful only when state is stateError. It stores the // error to be returned when read/write syscalls are called and the // endpoint is in this state. hardError is protected by endpoint mu. - hardError *tcpip.Error `state:".(string)"` + hardError tcpip.Error // lastError represents the last error that the endpoint reported; // access to it is protected by the following mutex. - lastErrorMu sync.Mutex `state:"nosave"` - lastError *tcpip.Error `state:".(string)"` + lastErrorMu sync.Mutex `state:"nosave"` + lastError tcpip.Error // rcvReadMu synchronizes calls to Read. // @@ -688,6 +688,10 @@ type endpoint struct { // ops is used to get socket level options. ops tcpip.SocketOptions + + // lastOutOfWindowAckTime is the time at which an ACK was sent in response + // to an out of window segment being received by this endpoint. 
+ lastOutOfWindowAckTime time.Time `state:".(unixTime)"` } // UniqueID implements stack.TransportEndpoint.UniqueID. @@ -1059,7 +1063,7 @@ func (e *endpoint) Close() { if isResetState { // Close the endpoint without doing full shutdown and // send a RST. - e.resetConnectionLocked(tcpip.ErrConnectionAborted) + e.resetConnectionLocked(&tcpip.ErrConnectionAborted{}) e.closeNoShutdownLocked() // Wake up worker to close the endpoint. @@ -1293,14 +1297,14 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { } // Preconditions: e.mu must be held to call this function. -func (e *endpoint) hardErrorLocked() *tcpip.Error { +func (e *endpoint) hardErrorLocked() tcpip.Error { err := e.hardError e.hardError = nil return err } // Preconditions: e.mu must be held to call this function. -func (e *endpoint) lastErrorLocked() *tcpip.Error { +func (e *endpoint) lastErrorLocked() tcpip.Error { e.lastErrorMu.Lock() defer e.lastErrorMu.Unlock() err := e.lastError @@ -1309,7 +1313,7 @@ func (e *endpoint) lastErrorLocked() *tcpip.Error { } // LastError implements tcpip.Endpoint.LastError. -func (e *endpoint) LastError() *tcpip.Error { +func (e *endpoint) LastError() tcpip.Error { e.LockUser() defer e.UnlockUser() if err := e.hardErrorLocked(); err != nil { @@ -1319,7 +1323,7 @@ func (e *endpoint) LastError() *tcpip.Error { } // UpdateLastError implements tcpip.SocketOptionsHandler.UpdateLastError. -func (e *endpoint) UpdateLastError(err *tcpip.Error) { +func (e *endpoint) UpdateLastError(err tcpip.Error) { e.LockUser() e.lastErrorMu.Lock() e.lastError = err @@ -1328,7 +1332,7 @@ func (e *endpoint) UpdateLastError(err *tcpip.Error) { } // Read implements tcpip.Endpoint.Read. 
-func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) { +func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, tcpip.Error) { e.rcvReadMu.Lock() defer e.rcvReadMu.Unlock() @@ -1337,7 +1341,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult // can remove segments from the list through commitRead(). first, last, serr := e.startRead() if serr != nil { - if serr == tcpip.ErrClosedForReceive { + if _, ok := serr.(*tcpip.ErrClosedForReceive); ok { e.stats.ReadErrors.ReadClosed.Increment() } return tcpip.ReadResult{}, serr @@ -1377,7 +1381,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult // If something is read, we must report it. Report error when nothing is read. if done == 0 && err != nil { - return tcpip.ReadResult{}, tcpip.ErrBadBuffer + return tcpip.ReadResult{}, &tcpip.ErrBadBuffer{} } return tcpip.ReadResult{ Count: done, @@ -1389,7 +1393,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult // inclusive range of segments that can be read. // // Precondition: e.rcvReadMu must be held. -func (e *endpoint) startRead() (first, last *segment, err *tcpip.Error) { +func (e *endpoint) startRead() (first, last *segment, err tcpip.Error) { e.LockUser() defer e.UnlockUser() @@ -1398,7 +1402,7 @@ func (e *endpoint) startRead() (first, last *segment, err *tcpip.Error) { // on a receive. It can expect to read any data after the handshake // is complete. RFC793, section 3.9, p58. 
if e.EndpointState() == StateSynSent { - return nil, nil, tcpip.ErrWouldBlock + return nil, nil, &tcpip.ErrWouldBlock{} } // The endpoint can be read if it's connected, or if it's already closed @@ -1414,17 +1418,17 @@ func (e *endpoint) startRead() (first, last *segment, err *tcpip.Error) { if err := e.hardErrorLocked(); err != nil { return nil, nil, err } - return nil, nil, tcpip.ErrClosedForReceive + return nil, nil, &tcpip.ErrClosedForReceive{} } e.stats.ReadErrors.NotConnected.Increment() - return nil, nil, tcpip.ErrNotConnected + return nil, nil, &tcpip.ErrNotConnected{} } if e.rcvBufUsed == 0 { if e.rcvClosed || !e.EndpointState().connected() { - return nil, nil, tcpip.ErrClosedForReceive + return nil, nil, &tcpip.ErrClosedForReceive{} } - return nil, nil, tcpip.ErrWouldBlock + return nil, nil, &tcpip.ErrWouldBlock{} } return e.rcvList.Front(), e.rcvList.Back(), nil @@ -1476,39 +1480,39 @@ func (e *endpoint) commitRead(done int) *segment { // moment. If the endpoint is not writable then it returns an error // indicating the reason why it's not writable. // Caller must hold e.mu and e.sndBufMu -func (e *endpoint) isEndpointWritableLocked() (int, *tcpip.Error) { +func (e *endpoint) isEndpointWritableLocked() (int, tcpip.Error) { // The endpoint cannot be written to if it's not connected. switch s := e.EndpointState(); { case s == StateError: if err := e.hardErrorLocked(); err != nil { return 0, err } - return 0, tcpip.ErrClosedForSend + return 0, &tcpip.ErrClosedForSend{} case !s.connecting() && !s.connected(): - return 0, tcpip.ErrClosedForSend + return 0, &tcpip.ErrClosedForSend{} case s.connecting(): // As per RFC793, page 56, a send request arriving when in connecting // state, can be queued to be completed after the state becomes // connected. Return an error code for the caller of endpoint Write to // try again, until the connection handshake is complete. 
- return 0, tcpip.ErrWouldBlock + return 0, &tcpip.ErrWouldBlock{} } // Check if the connection has already been closed for sends. if e.sndClosed { - return 0, tcpip.ErrClosedForSend + return 0, &tcpip.ErrClosedForSend{} } sndBufSize := e.getSendBufferSize() avail := sndBufSize - e.sndBufUsed if avail <= 0 { - return 0, tcpip.ErrWouldBlock + return 0, &tcpip.ErrWouldBlock{} } return avail, nil } // Write writes data to the endpoint's peer. -func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) { +func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) { // Linux completely ignores any address passed to sendto(2) for TCP sockets // (without the MSG_FASTOPEN flag). Corking is unimplemented, so opts.More // and opts.EndOfRecord are also ignored. @@ -1516,7 +1520,7 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc e.LockUser() defer e.UnlockUser() - nextSeg, n, err := func() (*segment, int, *tcpip.Error) { + nextSeg, n, err := func() (*segment, int, tcpip.Error) { e.sndBufMu.Lock() defer e.sndBufMu.Unlock() @@ -1526,7 +1530,7 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc return nil, 0, err } - v, err := func() ([]byte, *tcpip.Error) { + v, err := func() ([]byte, tcpip.Error) { // We can release locks while copying data. // // This is not possible if atomic is set, because we can't allow the @@ -1549,7 +1553,7 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc } v := make([]byte, avail) if _, err := io.ReadFull(p, v); err != nil { - return nil, tcpip.ErrBadBuffer + return nil, &tcpip.ErrBadBuffer{} } return v, nil }() @@ -1702,7 +1706,7 @@ func (e *endpoint) getSendBufferSize() int { } // SetSockOptInt sets a socket option. 
-func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { +func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { // Lower 2 bits represents ECN bits. RFC 3168, section 23.1 const inetECNMask = 3 @@ -1730,7 +1734,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { case tcpip.MaxSegOption: userMSS := v if userMSS < header.TCPMinimumMSS || userMSS > header.TCPMaximumMSS { - return tcpip.ErrInvalidOptionValue + return &tcpip.ErrInvalidOptionValue{} } e.LockUser() e.userMSS = uint16(userMSS) @@ -1741,7 +1745,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { // Return not supported if attempting to set this option to // anything other than path MTU discovery disabled. if v != tcpip.PMTUDiscoveryDont { - return tcpip.ErrNotSupported + return &tcpip.ErrNotSupported{} } case tcpip.ReceiveBufferSizeOption: @@ -1801,7 +1805,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { case tcpip.TCPSynCountOption: if v < 1 || v > 255 { - return tcpip.ErrInvalidOptionValue + return &tcpip.ErrInvalidOptionValue{} } e.LockUser() e.maxSynRetries = uint8(v) @@ -1817,7 +1821,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { return nil default: e.UnlockUser() - return tcpip.ErrInvalidOptionValue + return &tcpip.ErrInvalidOptionValue{} } } var rs tcpip.TCPReceiveBufferSizeRangeOption @@ -1838,7 +1842,7 @@ func (e *endpoint) HasNIC(id int32) bool { } // SetSockOpt sets a socket option. -func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { +func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { switch v := opt.(type) { case *tcpip.KeepaliveIdleOption: e.keepalive.Lock() @@ -1884,7 +1888,7 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { // Linux returns ENOENT when an invalid congestion // control algorithm is specified. 
- return tcpip.ErrNoSuchFile + return &tcpip.ErrNoSuchFile{} case *tcpip.TCPLingerTimeoutOption: e.LockUser() @@ -1927,13 +1931,13 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { } // readyReceiveSize returns the number of bytes ready to be received. -func (e *endpoint) readyReceiveSize() (int, *tcpip.Error) { +func (e *endpoint) readyReceiveSize() (int, tcpip.Error) { e.LockUser() defer e.UnlockUser() // The endpoint cannot be in listen state. if e.EndpointState() == StateListen { - return 0, tcpip.ErrInvalidEndpointState + return 0, &tcpip.ErrInvalidEndpointState{} } e.rcvListMu.Lock() @@ -1943,7 +1947,7 @@ func (e *endpoint) readyReceiveSize() (int, *tcpip.Error) { } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. -func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { +func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { switch opt { case tcpip.KeepaliveCountOption: e.keepalive.Lock() @@ -2007,24 +2011,38 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { return 1, nil default: - return -1, tcpip.ErrUnknownProtocolOption + return -1, &tcpip.ErrUnknownProtocolOption{} } } +func (e *endpoint) getTCPInfo() tcpip.TCPInfoOption { + info := tcpip.TCPInfoOption{} + e.LockUser() + snd := e.snd + if snd != nil { + // We do not calculate RTT before sending the data packets. If + // the connection did not send and receive data, then RTT will + // be zero. + snd.rtt.Lock() + info.RTT = snd.rtt.srtt + info.RTTVar = snd.rtt.rttvar + snd.rtt.Unlock() + + info.RTO = snd.rto + info.CcState = snd.state + info.SndSsthresh = uint32(snd.sndSsthresh) + info.SndCwnd = uint32(snd.sndCwnd) + info.ReorderSeen = snd.rc.reorderSeen + } + e.UnlockUser() + return info +} + // GetSockOpt implements tcpip.Endpoint.GetSockOpt. 
-func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { +func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { switch o := opt.(type) { case *tcpip.TCPInfoOption: - *o = tcpip.TCPInfoOption{} - e.LockUser() - snd := e.snd - e.UnlockUser() - if snd != nil { - snd.rtt.Lock() - o.RTT = snd.rtt.srtt - o.RTTVar = snd.rtt.rttvar - snd.rtt.Unlock() - } + *o = e.getTCPInfo() case *tcpip.KeepaliveIdleOption: e.keepalive.Lock() @@ -2070,14 +2088,14 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { } default: - return tcpip.ErrUnknownProtocolOption + return &tcpip.ErrUnknownProtocolOption{} } return nil } // checkV4MappedLocked determines the effective network protocol and converts // addr to its canonical form. -func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) { +func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) { unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.ops.GetV6Only()) if err != nil { return tcpip.FullAddress{}, 0, err @@ -2086,18 +2104,20 @@ func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddres } // Disconnect implements tcpip.Endpoint.Disconnect. -func (*endpoint) Disconnect() *tcpip.Error { - return tcpip.ErrNotSupported +func (*endpoint) Disconnect() tcpip.Error { + return &tcpip.ErrNotSupported{} } // Connect connects the endpoint to its peer. -func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { +func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error { err := e.connect(addr, true, true) - if err != nil && !err.IgnoreStats() { - // Connect failed. Let's wake up any waiters. 
- e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) - e.stack.Stats().TCP.FailedConnectionAttempts.Increment() - e.stats.FailedConnectionAttempts.Increment() + if err != nil { + if !err.IgnoreStats() { + // Connect failed. Let's wake up any waiters. + e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) + e.stack.Stats().TCP.FailedConnectionAttempts.Increment() + e.stats.FailedConnectionAttempts.Increment() + } } return err } @@ -2108,7 +2128,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { // created (so no new handshaking is done); for stack-accepted connections not // yet accepted by the app, they are restored without running the main goroutine // here. -func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tcpip.Error { +func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcpip.Error { e.LockUser() defer e.UnlockUser() @@ -2127,7 +2147,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc return nil } // Otherwise return that it's already connected. - return tcpip.ErrAlreadyConnected + return &tcpip.ErrAlreadyConnected{} } nicID := addr.NIC @@ -2140,7 +2160,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc } if nicID != 0 && nicID != e.boundNICID { - return tcpip.ErrNoRoute + return &tcpip.ErrNoRoute{} } nicID = e.boundNICID @@ -2152,16 +2172,16 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc case StateConnecting, StateSynSent, StateSynRecv: // A connection request has already been issued but hasn't completed // yet. 
- return tcpip.ErrAlreadyConnecting + return &tcpip.ErrAlreadyConnecting{} case StateError: if err := e.hardErrorLocked(); err != nil { return err } - return tcpip.ErrConnectionAborted + return &tcpip.ErrConnectionAborted{} default: - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } // Find a route to the desired destination. @@ -2217,12 +2237,12 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc } bindToDevice := tcpip.NICID(e.ops.GetBindToDevice()) - if _, err := e.stack.PickEphemeralPortStable(portOffset, func(p uint16) (bool, *tcpip.Error) { + if _, err := e.stack.PickEphemeralPortStable(portOffset, func(p uint16) (bool, tcpip.Error) { if sameAddr && p == e.ID.RemotePort { return false, nil } if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, bindToDevice, addr, nil /* testPort */); err != nil { - if err != tcpip.ErrPortInUse || !reuse { + if _, ok := err.(*tcpip.ErrPortInUse); !ok || !reuse { return false, nil } transEPID := e.ID @@ -2268,7 +2288,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc id.LocalPort = p if err := e.stack.RegisterTransportEndpoint(netProtos, ProtocolNumber, id, e, e.portFlags, bindToDevice); err != nil { e.stack.ReleasePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, bindToDevice, addr) - if err == tcpip.ErrPortInUse { + if _, ok := err.(*tcpip.ErrPortInUse); ok { return false, nil } return false, err @@ -2323,23 +2343,23 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc go e.protocolMainLoop(handshake, nil) // S/R-SAFE: will be drained before save. } - return tcpip.ErrConnectStarted + return &tcpip.ErrConnectStarted{} } // ConnectEndpoint is not supported. 
-func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error { - return tcpip.ErrInvalidEndpointState +func (*endpoint) ConnectEndpoint(tcpip.Endpoint) tcpip.Error { + return &tcpip.ErrInvalidEndpointState{} } // Shutdown closes the read and/or write end of the endpoint connection to its // peer. -func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { +func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error { e.LockUser() defer e.UnlockUser() return e.shutdownLocked(flags) } -func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) *tcpip.Error { +func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) tcpip.Error { e.shutdownFlags |= flags switch { case e.EndpointState().connected(): @@ -2354,7 +2374,7 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) *tcpip.Error { // If we're fully closed and we have unread data we need to abort // the connection with a RST. if e.shutdownFlags&tcpip.ShutdownWrite != 0 && rcvBufUsed > 0 { - e.resetConnectionLocked(tcpip.ErrConnectionAborted) + e.resetConnectionLocked(&tcpip.ErrConnectionAborted{}) // Wake up worker to terminate loop. e.notifyProtocolGoroutine(notifyTickleWorker) return nil @@ -2368,7 +2388,7 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) *tcpip.Error { // Already closed. e.sndBufMu.Unlock() if e.EndpointState() == StateTimeWait { - return tcpip.ErrNotConnected + return &tcpip.ErrNotConnected{} } return nil } @@ -2401,22 +2421,24 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) *tcpip.Error { } return nil default: - return tcpip.ErrNotConnected + return &tcpip.ErrNotConnected{} } } // Listen puts the endpoint in "listen" mode, which allows it to accept // new connections. 
-func (e *endpoint) Listen(backlog int) *tcpip.Error { +func (e *endpoint) Listen(backlog int) tcpip.Error { err := e.listen(backlog) - if err != nil && !err.IgnoreStats() { - e.stack.Stats().TCP.FailedConnectionAttempts.Increment() - e.stats.FailedConnectionAttempts.Increment() + if err != nil { + if !err.IgnoreStats() { + e.stack.Stats().TCP.FailedConnectionAttempts.Increment() + e.stats.FailedConnectionAttempts.Increment() + } } return err } -func (e *endpoint) listen(backlog int) *tcpip.Error { +func (e *endpoint) listen(backlog int) tcpip.Error { e.LockUser() defer e.UnlockUser() @@ -2434,7 +2456,7 @@ func (e *endpoint) listen(backlog int) *tcpip.Error { // Adjust the size of the channel iff we can fix // existing pending connections into the new one. if len(e.acceptedChan) > backlog { - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } if cap(e.acceptedChan) == backlog { return nil @@ -2466,7 +2488,7 @@ func (e *endpoint) listen(backlog int) *tcpip.Error { // Endpoint must be bound before it can transition to listen mode. if e.EndpointState() != StateBound { e.stats.ReadErrors.InvalidEndpointState.Increment() - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } // Register the endpoint. @@ -2506,7 +2528,7 @@ func (e *endpoint) startAcceptedLoop() { // to an endpoint previously set to listen mode. // // addr if not-nil will contain the peer address of the returned endpoint. -func (e *endpoint) Accept(peerAddr *tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { +func (e *endpoint) Accept(peerAddr *tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpip.Error) { e.LockUser() defer e.UnlockUser() @@ -2515,7 +2537,7 @@ func (e *endpoint) Accept(peerAddr *tcpip.FullAddress) (tcpip.Endpoint, *waiter. e.rcvListMu.Unlock() // Endpoint must be in listen state before it can accept connections. 
if rcvClosed || e.EndpointState() != StateListen { - return nil, nil, tcpip.ErrInvalidEndpointState + return nil, nil, &tcpip.ErrInvalidEndpointState{} } // Get the new accepted endpoint. @@ -2526,7 +2548,7 @@ func (e *endpoint) Accept(peerAddr *tcpip.FullAddress) (tcpip.Endpoint, *waiter. case n = <-e.acceptedChan: e.acceptCond.Signal() default: - return nil, nil, tcpip.ErrWouldBlock + return nil, nil, &tcpip.ErrWouldBlock{} } if peerAddr != nil { *peerAddr = n.getRemoteAddress() @@ -2535,19 +2557,19 @@ func (e *endpoint) Accept(peerAddr *tcpip.FullAddress) (tcpip.Endpoint, *waiter. } // Bind binds the endpoint to a specific local port and optionally address. -func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) { +func (e *endpoint) Bind(addr tcpip.FullAddress) (err tcpip.Error) { e.LockUser() defer e.UnlockUser() return e.bindLocked(addr) } -func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { +func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err tcpip.Error) { // Don't allow binding once endpoint is not in the initial state // anymore. This is because once the endpoint goes into a connected or // listen state, it is already bound. if e.EndpointState() != StateInitial { - return tcpip.ErrAlreadyBound + return &tcpip.ErrAlreadyBound{} } e.BindAddr = addr.Addr @@ -2575,7 +2597,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { if len(addr.Addr) != 0 { nic = e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) if nic == 0 { - return tcpip.ErrBadLocalAddress + return &tcpip.ErrBadLocalAddress{} } e.ID.LocalAddress = addr.Addr } @@ -2616,7 +2638,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { } // GetLocalAddress returns the address to which the endpoint is bound. 
-func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { +func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) { e.LockUser() defer e.UnlockUser() @@ -2628,12 +2650,12 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { } // GetRemoteAddress returns the address to which the endpoint is connected. -func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { +func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) { e.LockUser() defer e.UnlockUser() if !e.EndpointState().connected() { - return tcpip.FullAddress{}, tcpip.ErrNotConnected + return tcpip.FullAddress{}, &tcpip.ErrNotConnected{} } return e.getRemoteAddress(), nil @@ -2665,7 +2687,7 @@ func (e *endpoint) enqueueSegment(s *segment) bool { return true } -func (e *endpoint) onICMPError(err *tcpip.Error, errType byte, errCode byte, extra uint32, pkt *stack.PacketBuffer) { +func (e *endpoint) onICMPError(err tcpip.Error, transErr stack.TransportError, pkt *stack.PacketBuffer) { // Update last error first. e.lastErrorMu.Lock() e.lastError = err @@ -2674,11 +2696,8 @@ func (e *endpoint) onICMPError(err *tcpip.Error, errType byte, errCode byte, ext // Update the error queue if IP_RECVERR is enabled. if e.SocketOptions().GetRecvError() { e.SocketOptions().QueueErr(&tcpip.SockError{ - Err: err, - ErrOrigin: header.ICMPOriginFromNetProto(pkt.NetworkProtocolNumber), - ErrType: errType, - ErrCode: errCode, - ErrInfo: extra, + Err: err, + Cause: transErr, // Linux passes the payload with the TCP header. We don't know if the TCP // header even exists, it may not for fragmented packets. Payload: pkt.Data.ToView(), @@ -2700,27 +2719,26 @@ func (e *endpoint) onICMPError(err *tcpip.Error, errType byte, errCode byte, ext e.notifyProtocolGoroutine(notifyError) } -// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. 
-func (e *endpoint) HandleControlPacket(typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { - switch typ { - case stack.ControlPacketTooBig: +// HandleError implements stack.TransportEndpoint. +func (e *endpoint) HandleError(transErr stack.TransportError, pkt *stack.PacketBuffer) { + handlePacketTooBig := func(mtu uint32) { e.sndBufMu.Lock() e.packetTooBigCount++ - if v := int(extra); v < e.sndMTU { + if v := int(mtu); v < e.sndMTU { e.sndMTU = v } e.sndBufMu.Unlock() - e.notifyProtocolGoroutine(notifyMTUChanged) + } - case stack.ControlNoRoute: - e.onICMPError(tcpip.ErrNoRoute, byte(header.ICMPv4DstUnreachable), byte(header.ICMPv4HostUnreachable), extra, pkt) - - case stack.ControlAddressUnreachable: - e.onICMPError(tcpip.ErrNoRoute, byte(header.ICMPv6DstUnreachable), byte(header.ICMPv6AddressUnreachable), extra, pkt) - - case stack.ControlNetworkUnreachable: - e.onICMPError(tcpip.ErrNetworkUnreachable, byte(header.ICMPv6DstUnreachable), byte(header.ICMPv6NetworkUnreachable), extra, pkt) + // TODO(gvisor.dev/issues/5270): Handle all transport errors. 
+ switch transErr.Kind() { + case stack.PacketTooBigTransportError: + handlePacketTooBig(transErr.Info()) + case stack.DestinationHostUnreachableTransportError: + e.onICMPError(&tcpip.ErrNoRoute{}, transErr, pkt) + case stack.DestinationNetworkUnreachableTransportError: + e.onICMPError(&tcpip.ErrNetworkUnreachable{}, transErr, pkt) } } @@ -3013,12 +3031,16 @@ func (e *endpoint) completeState() stack.TCPEndpointState { rc := &e.snd.rc s.Sender.RACKState = stack.TCPRACKState{ - XmitTime: rc.xmitTime, - EndSequence: rc.endSequence, - FACK: rc.fack, - RTT: rc.rtt, - Reord: rc.reorderSeen, - DSACKSeen: rc.dsackSeen, + XmitTime: rc.xmitTime, + EndSequence: rc.endSequence, + FACK: rc.fack, + RTT: rc.rtt, + Reord: rc.reorderSeen, + DSACKSeen: rc.dsackSeen, + ReoWnd: rc.reoWnd, + ReoWndIncr: rc.reoWndIncr, + ReoWndPersist: rc.reoWndPersist, + RTTSeq: rc.rttSeq, } return s } @@ -3107,3 +3129,19 @@ func GetTCPSendBufferLimits(s tcpip.StackHandler) tcpip.SendBufferSizeOption { Max: ss.Max, } } + +// allowOutOfWindowAck returns true if an out-of-window ACK can be sent now. 
+func (e *endpoint) allowOutOfWindowAck() bool { + var limit stack.TCPInvalidRateLimitOption + if err := e.stack.Option(&limit); err != nil { + panic(fmt.Sprintf("e.stack.Option(%+v) failed with error: %s", limit, err)) + } + + now := time.Now() + if now.Sub(e.lastOutOfWindowAckTime) < time.Duration(limit) { + return false + } + + e.lastOutOfWindowAckTime = now + return true +} diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index 4a01c81b4..e4368026f 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -59,7 +59,7 @@ func (e *endpoint) beforeSave() { Err: fmt.Errorf("endpoint cannot be saved in connected state: local %s:%d, remote %s:%d", e.ID.LocalAddress, e.ID.LocalPort, e.ID.RemoteAddress, e.ID.RemotePort), }) } - e.resetConnectionLocked(tcpip.ErrConnectionAborted) + e.resetConnectionLocked(&tcpip.ErrConnectionAborted{}) e.mu.Unlock() e.Close() e.mu.Lock() @@ -232,7 +232,8 @@ func (e *endpoint) Resume(s *stack.Stack) { // Reset the scoreboard to reinitialize the sack information as // we do not restore SACK information. 
e.scoreboard.Reset() - if err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.ID.RemotePort}, false, e.workerRunning); err != tcpip.ErrConnectStarted { + err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.ID.RemotePort}, false, e.workerRunning) + if _, ok := err.(*tcpip.ErrConnectStarted); !ok { panic("endpoint connecting failed: " + err.String()) } e.mu.Lock() @@ -269,7 +270,8 @@ func (e *endpoint) Resume(s *stack.Stack) { connectedLoading.Wait() listenLoading.Wait() bind() - if err := e.Connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.ID.RemotePort}); err != tcpip.ErrConnectStarted { + err := e.Connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.ID.RemotePort}) + if _, ok := err.(*tcpip.ErrConnectStarted); !ok { panic("endpoint connecting failed: " + err.String()) } connectingLoading.Done() @@ -296,24 +298,6 @@ func (e *endpoint) Resume(s *stack.Stack) { } } -// saveLastError is invoked by stateify. -func (e *endpoint) saveLastError() string { - if e.lastError == nil { - return "" - } - - return e.lastError.String() -} - -// loadLastError is invoked by stateify. -func (e *endpoint) loadLastError(s string) { - if s == "" { - return - } - - e.lastError = tcpip.StringToError(s) -} - // saveRecentTSTime is invoked by stateify. func (e *endpoint) saveRecentTSTime() unixTime { return unixTime{e.recentTSTime.Unix(), e.recentTSTime.UnixNano()} @@ -324,22 +308,14 @@ func (e *endpoint) loadRecentTSTime(unix unixTime) { e.recentTSTime = time.Unix(unix.second, unix.nano) } -// saveHardError is invoked by stateify. -func (e *endpoint) saveHardError() string { - if e.hardError == nil { - return "" - } - - return e.hardError.String() +// saveLastOutOfWindowAckTime is invoked by stateify. 
+func (e *endpoint) saveLastOutOfWindowAckTime() unixTime { + return unixTime{e.lastOutOfWindowAckTime.Unix(), e.lastOutOfWindowAckTime.UnixNano()} } -// loadHardError is invoked by stateify. -func (e *endpoint) loadHardError(s string) { - if s == "" { - return - } - - e.hardError = tcpip.StringToError(s) +// loadLastOutOfWindowAckTime is invoked by stateify. +func (e *endpoint) loadLastOutOfWindowAckTime(unix unixTime) { + e.lastOutOfWindowAckTime = time.Unix(unix.second, unix.nano) } // saveMeasureTime is invoked by stateify. diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go index 596178625..2f9fe7ee0 100644 --- a/pkg/tcpip/transport/tcp/forwarder.go +++ b/pkg/tcpip/transport/tcp/forwarder.go @@ -143,12 +143,12 @@ func (r *ForwarderRequest) Complete(sendReset bool) { // CreateEndpoint creates a TCP endpoint for the connection request, performing // the 3-way handshake in the process. -func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { r.mu.Lock() defer r.mu.Unlock() if r.segment == nil { - return nil, tcpip.ErrInvalidEndpointState + return nil, &tcpip.ErrInvalidEndpointState{} } f := r.forwarder diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 1720370c9..04012cd40 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -161,13 +161,13 @@ func (*protocol) Number() tcpip.TransportProtocolNumber { } // NewEndpoint creates a new tcp endpoint. -func (p *protocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (p *protocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { return newEndpoint(p.stack, netProto, waiterQueue), nil } // NewRawEndpoint creates a new raw TCP endpoint. 
Raw TCP sockets are currently // unsupported. It implements stack.TransportProtocol.NewRawEndpoint. -func (p *protocol) NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (p *protocol) NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { return raw.NewEndpoint(p.stack, netProto, header.TCPProtocolNumber, waiterQueue) } @@ -178,7 +178,7 @@ func (*protocol) MinimumPacketSize() int { // ParsePorts returns the source and destination ports stored in the given tcp // packet. -func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { +func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err tcpip.Error) { h := header.TCP(v) return h.SourcePort(), h.DestinationPort(), nil } @@ -216,7 +216,7 @@ func (p *protocol) HandleUnknownDestinationPacket(id stack.TransportEndpointID, // replyWithReset replies to the given segment with a reset segment. // // If the passed TTL is 0, then the route's default TTL will be used. -func replyWithReset(stack *stack.Stack, s *segment, tos, ttl uint8) *tcpip.Error { +func replyWithReset(stack *stack.Stack, s *segment, tos, ttl uint8) tcpip.Error { route, err := stack.FindRoute(s.nicID, s.dstAddr, s.srcAddr, s.netProto, false /* multicastLoop */) if err != nil { return err @@ -261,7 +261,7 @@ func replyWithReset(stack *stack.Stack, s *segment, tos, ttl uint8) *tcpip.Error } // SetOption implements stack.TransportProtocol.SetOption. 
-func (p *protocol) SetOption(option tcpip.SettableTransportProtocolOption) *tcpip.Error { +func (p *protocol) SetOption(option tcpip.SettableTransportProtocolOption) tcpip.Error { switch v := option.(type) { case *tcpip.TCPSACKEnabled: p.mu.Lock() @@ -283,7 +283,7 @@ func (p *protocol) SetOption(option tcpip.SettableTransportProtocolOption) *tcpi case *tcpip.TCPSendBufferSizeRangeOption: if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max { - return tcpip.ErrInvalidOptionValue + return &tcpip.ErrInvalidOptionValue{} } p.mu.Lock() p.sendBufferSize = *v @@ -292,7 +292,7 @@ func (p *protocol) SetOption(option tcpip.SettableTransportProtocolOption) *tcpi case *tcpip.TCPReceiveBufferSizeRangeOption: if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max { - return tcpip.ErrInvalidOptionValue + return &tcpip.ErrInvalidOptionValue{} } p.mu.Lock() p.recvBufferSize = *v @@ -310,7 +310,7 @@ func (p *protocol) SetOption(option tcpip.SettableTransportProtocolOption) *tcpi } // linux returns ENOENT when an invalid congestion control // is specified. 
- return tcpip.ErrNoSuchFile + return &tcpip.ErrNoSuchFile{} case *tcpip.TCPModerateReceiveBufferOption: p.mu.Lock() @@ -340,7 +340,7 @@ func (p *protocol) SetOption(option tcpip.SettableTransportProtocolOption) *tcpi case *tcpip.TCPTimeWaitReuseOption: if *v < tcpip.TCPTimeWaitReuseDisabled || *v > tcpip.TCPTimeWaitReuseLoopbackOnly { - return tcpip.ErrInvalidOptionValue + return &tcpip.ErrInvalidOptionValue{} } p.mu.Lock() p.timeWaitReuse = *v @@ -381,7 +381,7 @@ func (p *protocol) SetOption(option tcpip.SettableTransportProtocolOption) *tcpi case *tcpip.TCPSynRetriesOption: if *v < 1 || *v > 255 { - return tcpip.ErrInvalidOptionValue + return &tcpip.ErrInvalidOptionValue{} } p.mu.Lock() p.synRetries = uint8(*v) @@ -389,12 +389,12 @@ func (p *protocol) SetOption(option tcpip.SettableTransportProtocolOption) *tcpi return nil default: - return tcpip.ErrUnknownProtocolOption + return &tcpip.ErrUnknownProtocolOption{} } } // Option implements stack.TransportProtocol.Option. -func (p *protocol) Option(option tcpip.GettableTransportProtocolOption) *tcpip.Error { +func (p *protocol) Option(option tcpip.GettableTransportProtocolOption) tcpip.Error { switch v := option.(type) { case *tcpip.TCPSACKEnabled: p.mu.RLock() @@ -493,7 +493,7 @@ func (p *protocol) Option(option tcpip.GettableTransportProtocolOption) *tcpip.E return nil default: - return tcpip.ErrUnknownProtocolOption + return &tcpip.ErrUnknownProtocolOption{} } } diff --git a/pkg/tcpip/transport/tcp/rack.go b/pkg/tcpip/transport/tcp/rack.go index 307bacca5..e862f159e 100644 --- a/pkg/tcpip/transport/tcp/rack.go +++ b/pkg/tcpip/transport/tcp/rack.go @@ -22,12 +22,21 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/seqnum" ) -// wcDelayedACKTimeout is the recommended maximum delayed ACK timer value as -// defined in https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.5. -// It stands for worst case delayed ACK timer (WCDelAckT). 
When FlightSize is -// 1, PTO is inflated by WCDelAckT time to compensate for a potential long -// delayed ACK timer at the receiver. -const wcDelayedACKTimeout = 200 * time.Millisecond +const ( + // wcDelayedACKTimeout is the recommended maximum delayed ACK timer + // value as defined in the RFC. It stands for worst case delayed ACK + // timer (WCDelAckT). When FlightSize is 1, PTO is inflated by + // WCDelAckT time to compensate for a potential long delayed ACK timer + // at the receiver. + // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.5. + wcDelayedACKTimeout = 200 * time.Millisecond + + // tcpRACKRecoveryThreshold is the number of loss recoveries for which + // the reorder window is inflated and after that the reorder window is + // reset to its initial value of minRTT/4. + // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2. + tcpRACKRecoveryThreshold = 16 +) // RACK is a loss detection algorithm used in TCP to detect packet loss and // reordering using transmission timestamp of the packets instead of packet or @@ -44,6 +53,11 @@ type rackControl struct { // endSequence is the ending TCP sequence number of rackControl.seg. endSequence seqnum.Value + // exitedRecovery indicates if the connection is exiting loss recovery. + // This flag is set if the sender is leaving the recovery after + // receiving an ACK and is reset during updating of reorder window. + exitedRecovery bool + // fack is the highest selectively or cumulatively acknowledged // sequence. fack seqnum.Value @@ -51,15 +65,30 @@ type rackControl struct { // minRTT is the estimated minimum RTT of the connection. minRTT time.Duration + // reorderSeen indicates if reordering has been detected on this + // connection. + reorderSeen bool + + // reoWnd is the reordering window time used for recording packet + // transmission times. It is used to defer the moment at which RACK + // marks a packet lost. 
+ reoWnd time.Duration + + // reoWndIncr is the multiplier applied to adjust reorder window. + reoWndIncr uint8 + + // reoWndPersist is the number of loss recoveries before resetting + // reorder window. + reoWndPersist int8 + // rtt is the RTT of the most recently delivered packet on the // connection (either cumulatively acknowledged or selectively // acknowledged) that was not marked invalid as a possible spurious // retransmission. rtt time.Duration - // reorderSeen indicates if reordering has been detected on this - // connection. - reorderSeen bool + // rttSeq is the SND.NXT when rtt is updated. + rttSeq seqnum.Value // xmitTime is the latest transmission timestamp of rackControl.seg. xmitTime time.Time `state:".(unixTime)"` @@ -75,29 +104,36 @@ type rackControl struct { // tlpHighRxt the value of sender.sndNxt at the time of sending // a TLP retransmission. tlpHighRxt seqnum.Value + + // snd is a reference to the sender. + snd *sender } // init initializes RACK specific fields. -func (rc *rackControl) init() { +func (rc *rackControl) init(snd *sender, iss seqnum.Value) { + rc.fack = iss + rc.reoWndIncr = 1 + rc.snd = snd rc.probeTimer.init(&rc.probeWaker) } // update will update the RACK related fields when an ACK has been received. -// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2 -func (rc *rackControl) update(seg *segment, ackSeg *segment, offset uint32) { +// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-09#section-6.2 +func (rc *rackControl) update(seg *segment, ackSeg *segment) { rtt := time.Now().Sub(seg.xmitTime) + tsOffset := rc.snd.ep.tsOffset // If the ACK is for a retransmitted packet, do not update if it is a // spurious inference which is determined by below checks: - // 1. When Timestamping option is available, if the TSVal is less than the - // transmit time of the most recent retransmitted packet. + // 1. 
When Timestamping option is available, if the TSVal is less than + // the transmit time of the most recent retransmitted packet. // 2. When RTT calculated for the packet is less than the smoothed RTT // for the connection. // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2 // step 2 if seg.xmitCount > 1 { if ackSeg.parsedOptions.TS && ackSeg.parsedOptions.TSEcr != 0 { - if ackSeg.parsedOptions.TSEcr < tcpTimeStamp(seg.xmitTime, offset) { + if ackSeg.parsedOptions.TSEcr < tcpTimeStamp(seg.xmitTime, tsOffset) { return } } @@ -149,9 +185,8 @@ func (rc *rackControl) detectReorder(seg *segment) { } } -// setDSACKSeen updates rack control if duplicate SACK is seen by the connection. -func (rc *rackControl) setDSACKSeen() { - rc.dsackSeen = true +func (rc *rackControl) setDSACKSeen(dsackSeen bool) { + rc.dsackSeen = dsackSeen } // shouldSchedulePTO dictates whether we should schedule a PTO or not. @@ -162,7 +197,7 @@ func (s *sender) shouldSchedulePTO() bool { // The connection supports SACK. s.ep.sackPermitted && // The connection is not in loss recovery. - (s.state != RTORecovery && s.state != SACKRecovery) && + (s.state != tcpip.RTORecovery && s.state != tcpip.SACKRecovery) && // The connection has no SACKed sequences in the SACK scoreboard. s.ep.scoreboard.Sacked() == 0 } @@ -193,7 +228,7 @@ func (s *sender) schedulePTO() { // probeTimerExpired is the same as TLP_send_probe() as defined in // https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.5.2. -func (s *sender) probeTimerExpired() *tcpip.Error { +func (s *sender) probeTimerExpired() tcpip.Error { if !s.rc.probeTimer.checkExpiration() { return nil } @@ -266,9 +301,88 @@ func (s *sender) detectTLPRecovery(ack seqnum.Value, rcvdSeg *segment) { // Step 2. Either the original packet or the retransmission (in the // form of a probe) was lost. Invoke a congestion control response // equivalent to fast recovery. 
- s.cc.HandleNDupAcks() + s.cc.HandleLossDetected() s.enterRecovery() s.leaveRecovery() } } } + +// updateRACKReorderWindow updates the reorder window. +// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2 +// * Step 4: Update RACK reordering window +// To handle the prevalent small degree of reordering, RACK.reo_wnd serves as +// an allowance for settling time before marking a packet lost. RACK starts +// initially with a conservative window of min_RTT/4. If no reordering has +// been observed RACK uses reo_wnd of zero during loss recovery, in order to +// retransmit quickly, or when the number of DUPACKs exceeds the classic +// DUPACKthreshold. +func (rc *rackControl) updateRACKReorderWindow(ackSeg *segment) { + dsackSeen := rc.dsackSeen + snd := rc.snd + + // React to DSACK once per round trip. + // If SND.UNA < RACK.rtt_seq: + // RACK.dsack = false + if snd.sndUna.LessThan(rc.rttSeq) { + dsackSeen = false + } + + // If RACK.dsack: + // RACK.reo_wnd_incr += 1 + // RACK.dsack = false + // RACK.rtt_seq = SND.NXT + // RACK.reo_wnd_persist = 16 + if dsackSeen { + rc.reoWndIncr++ + dsackSeen = false + rc.rttSeq = snd.sndNxt + rc.reoWndPersist = tcpRACKRecoveryThreshold + } else if rc.exitedRecovery { + // Else if exiting loss recovery: + // RACK.reo_wnd_persist -= 1 + // If RACK.reo_wnd_persist <= 0: + // RACK.reo_wnd_incr = 1 + rc.reoWndPersist-- + if rc.reoWndPersist <= 0 { + rc.reoWndIncr = 1 + } + rc.exitedRecovery = false + } + + // Reorder window is zero during loss recovery, or when the number of + // DUPACKs exceeds the classic DUPACKthreshold. 
+ // If RACK.reord is FALSE: + // If in loss recovery: (If in fast or timeout recovery) + // RACK.reo_wnd = 0 + // Return + // Else if RACK.pkts_sacked >= RACK.dupthresh: + // RACK.reo_wnd = 0 + // return + if !rc.reorderSeen { + if snd.state == tcpip.RTORecovery || snd.state == tcpip.SACKRecovery { + rc.reoWnd = 0 + return + } + + if snd.sackedOut >= nDupAckThreshold { + rc.reoWnd = 0 + return + } + } + + // Calculate reorder window. + // RACK.reo_wnd = RACK.min_RTT / 4 * RACK.reo_wnd_incr + // RACK.reo_wnd = min(RACK.reo_wnd, SRTT) + snd.rtt.Lock() + srtt := snd.rtt.srtt + snd.rtt.Unlock() + rc.reoWnd = time.Duration((int64(rc.minRTT) / 4) * int64(rc.reoWndIncr)) + if srtt < rc.reoWnd { + rc.reoWnd = srtt + } +} + +func (rc *rackControl) exitRecovery() { + rc.exitedRecovery = true +} diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index 405a6dce7..a5c82b8fa 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -347,7 +347,7 @@ func (r *receiver) updateRTT() { r.ep.rcvListMu.Unlock() } -func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, closed bool) (drop bool, err *tcpip.Error) { +func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, closed bool) (drop bool, err tcpip.Error) { r.ep.rcvListMu.Lock() rcvClosed := r.ep.rcvClosed || r.closed r.ep.rcvListMu.Unlock() @@ -385,7 +385,7 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo // fails, we ignore the packet: // https://github.com/torvalds/linux/blob/v5.8/net/ipv4/tcp_input.c#L5591 if r.ep.snd.sndNxt.LessThan(s.ackNumber) { - r.ep.snd.sendAck() + r.ep.snd.maybeSendOutOfWindowAck(s) return true, nil } @@ -395,7 +395,7 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo // trigger a RST. 
endDataSeq := s.sequenceNumber.Add(seqnum.Size(s.data.Size())) if state != StateCloseWait && rcvClosed && r.rcvNxt.LessThan(endDataSeq) { - return true, tcpip.ErrConnectionAborted + return true, &tcpip.ErrConnectionAborted{} } if state == StateFinWait1 { break @@ -424,7 +424,7 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo // the last actual data octet in a segment in // which it occurs. if closed && (!s.flagIsSet(header.TCPFlagFin) || s.sequenceNumber.Add(s.logicalLen()) != r.rcvNxt+1) { - return true, tcpip.ErrConnectionAborted + return true, &tcpip.ErrConnectionAborted{} } } @@ -443,7 +443,7 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo // handleRcvdSegment handles TCP segments directed at the connection managed by // r as they arrive. It is called by the protocol main loop. -func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) { +func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err tcpip.Error) { state := r.ep.EndpointState() closed := r.ep.closed @@ -454,7 +454,7 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) { // send an ACK and stop further processing of the segment. // This is according to RFC 793, page 68. if !r.acceptable(segSeq, segLen) { - r.ep.snd.sendAck() + r.ep.snd.maybeSendOutOfWindowAck(s) return true, nil } diff --git a/pkg/tcpip/transport/tcp/reno.go b/pkg/tcpip/transport/tcp/reno.go index f83ebc717..ff39780a5 100644 --- a/pkg/tcpip/transport/tcp/reno.go +++ b/pkg/tcpip/transport/tcp/reno.go @@ -79,10 +79,10 @@ func (r *renoState) Update(packetsAcked int) { r.updateCongestionAvoidance(packetsAcked) } -// HandleNDupAcks implements congestionControl.HandleNDupAcks. -func (r *renoState) HandleNDupAcks() { - // A retransmit was triggered due to nDupAckThreshold - // being hit. Reduce our slow start threshold. +// HandleLossDetected implements congestionControl.HandleLossDetected. 
+func (r *renoState) HandleLossDetected() { + // A retransmit was triggered due to nDupAckThreshold or when RACK + // detected loss. Reduce our slow start threshold. r.reduceSlowStartThreshold() } diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index 079d90848..463a259b7 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -48,34 +48,13 @@ const ( MaxRetries = 15 ) -// ccState indicates the current congestion control state for this sender. -type ccState int - -const ( - // Open indicates that the sender is receiving acks in order and - // no loss or dupACK's etc have been detected. - Open ccState = iota - // RTORecovery indicates that an RTO has occurred and the sender - // has entered an RTO based recovery phase. - RTORecovery - // FastRecovery indicates that the sender has entered FastRecovery - // based on receiving nDupAck's. This state is entered only when - // SACK is not in use. - FastRecovery - // SACKRecovery indicates that the sender has entered SACK based - // recovery. - SACKRecovery - // Disorder indicates the sender either received some SACK blocks - // or dupACK's. - Disorder -) - // congestionControl is an interface that must be implemented by any supported // congestion control algorithm. type congestionControl interface { - // HandleNDupAcks is invoked when sender.dupAckCount >= nDupAckThreshold - // just before entering fast retransmit. - HandleNDupAcks() + // HandleLossDetected is invoked when the loss is detected by RACK or + // sender.dupAckCount >= nDupAckThreshold just before entering fast + // retransmit. + HandleLossDetected() // HandleRTOExpired is invoked when the retransmit timer expires. HandleRTOExpired() @@ -204,7 +183,7 @@ type sender struct { maxSentAck seqnum.Value // state is the current state of congestion control for this endpoint. - state ccState + state tcpip.CongestionControlState // cc is the congestion control algorithm in use for this sender. 
cc congestionControl @@ -280,14 +259,9 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint highRxt: iss, rescueRxt: iss, }, - rc: rackControl{ - fack: iss, - }, gso: ep.gso != nil, } - s.rc.init() - if s.gso { s.ep.gso.MSS = uint16(maxPayloadSize) } @@ -295,6 +269,7 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint s.cc = s.initCongestionControl(ep.cc) s.lr = s.initLossRecovery() + s.rc.init(s, iss) // A negative sndWndScale means that no scaling is in use, otherwise we // store the scaling value. @@ -593,7 +568,7 @@ func (s *sender) retransmitTimerExpired() bool { s.leaveRecovery() } - s.state = RTORecovery + s.state = tcpip.RTORecovery s.cc.HandleRTOExpired() // Mark the next segment to be sent as the first unacknowledged one and @@ -1018,7 +993,7 @@ func (s *sender) sendData() { // "A TCP SHOULD set cwnd to no more than RW before beginning // transmission if the TCP has not sent data in the interval exceeding // the retrasmission timeout." - if !s.fr.active && s.state != RTORecovery && time.Now().Sub(s.lastSendTime) > s.rto { + if !s.fr.active && s.state != tcpip.RTORecovery && time.Now().Sub(s.lastSendTime) > s.rto { if s.sndCwnd > InitialCwnd { s.sndCwnd = InitialCwnd } @@ -1062,14 +1037,14 @@ func (s *sender) enterRecovery() { s.fr.highRxt = s.sndUna s.fr.rescueRxt = s.sndUna if s.ep.sackPermitted { - s.state = SACKRecovery + s.state = tcpip.SACKRecovery s.ep.stack.Stats().TCP.SACKRecovery.Increment() // Set TLPRxtOut to false according to // https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.6.1. s.rc.tlpRxtOut = false return } - s.state = FastRecovery + s.state = tcpip.FastRecovery s.ep.stack.Stats().TCP.FastRecovery.Increment() } @@ -1080,7 +1055,6 @@ func (s *sender) leaveRecovery() { // Deflate cwnd. It had been artificially inflated when new dups arrived. 
s.sndCwnd = s.sndSsthresh - s.cc.PostRecovery() } @@ -1166,7 +1140,7 @@ func (s *sender) detectLoss(seg *segment) (fastRetransmit bool) { s.fr.highRxt = s.sndUna - 1 // Do run SetPipe() to calculate the outstanding segments. s.SetPipe() - s.state = Disorder + s.state = tcpip.Disorder return false } @@ -1179,7 +1153,7 @@ func (s *sender) detectLoss(seg *segment) (fastRetransmit bool) { s.dupAckCount = 0 return false } - s.cc.HandleNDupAcks() + s.cc.HandleLossDetected() s.enterRecovery() s.dupAckCount = 0 return true @@ -1217,11 +1191,13 @@ func (s *sender) isDupAck(seg *segment) bool { // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2 // steps 2 and 3. func (s *sender) walkSACK(rcvdSeg *segment) { + s.rc.setDSACKSeen(false) + // Look for DSACK block. idx := 0 n := len(rcvdSeg.parsedOptions.SACKBlocks) if checkDSACK(rcvdSeg) { - s.rc.setDSACKSeen() + s.rc.setDSACKSeen(true) idx = 1 n-- } @@ -1242,7 +1218,7 @@ func (s *sender) walkSACK(rcvdSeg *segment) { for _, sb := range sackBlocks { for seg != nil && seg.sequenceNumber.LessThan(sb.End) && seg.xmitCount != 0 { if sb.Start.LessThanEq(seg.sequenceNumber) && !seg.acked { - s.rc.update(seg, rcvdSeg, s.ep.tsOffset) + s.rc.update(seg, rcvdSeg) s.rc.detectReorder(seg) seg.acked = true s.sackedOut += s.pCount(seg, s.maxPayloadSize) @@ -1412,6 +1388,17 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { acked := s.sndUna.Size(ack) s.sndUna = ack + // The remote ACK-ing at least 1 byte is an indication that we have a + // full-duplex connection to the remote as the only way we will receive an + // ACK is if the remote received data that we previously sent. + // + // As of writing, linux seems to only confirm a route as reachable when + // forward progress is made which is indicated by an ACK that removes data + // from the retransmit queue. 
+ if acked > 0 { + s.ep.route.ConfirmReachable() + } + ackLeft := acked originalOutstanding := s.outstanding for ackLeft > 0 { @@ -1435,7 +1422,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { // Update the RACK fields if SACK is enabled. if s.ep.sackPermitted && !seg.acked { - s.rc.update(seg, rcvdSeg, s.ep.tsOffset) + s.rc.update(seg, rcvdSeg) s.rc.detectReorder(seg) } @@ -1464,7 +1451,11 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { if !s.fr.active { s.cc.Update(originalOutstanding - s.outstanding) if s.fr.last.LessThan(s.sndUna) { - s.state = Open + s.state = tcpip.Open + // Update RACK when we are exiting fast or RTO + // recovery as described in the RFC + // draft-ietf-tcpm-rack-08 Section-7.2 Step 4. + s.rc.exitRecovery() } } @@ -1488,6 +1479,12 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { } } + // Update RACK reorder window. + // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2 + // * Upon receiving an ACK: + // * Step 4: Update RACK reordering window + s.rc.updateRACKReorderWindow(rcvdSeg) + // Now that we've popped all acknowledged data from the retransmit // queue, retransmit if needed. if s.fr.active { @@ -1508,7 +1505,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { } // sendSegment sends the specified segment. -func (s *sender) sendSegment(seg *segment) *tcpip.Error { +func (s *sender) sendSegment(seg *segment) tcpip.Error { if seg.xmitCount > 0 { s.ep.stack.Stats().TCP.Retransmits.Increment() s.ep.stats.SendErrors.Retransmits.Increment() @@ -1539,7 +1536,7 @@ func (s *sender) sendSegment(seg *segment) *tcpip.Error { // sendSegmentFromView sends a new segment containing the given payload, flags // and sequence number. 
-func (s *sender) sendSegmentFromView(data buffer.VectorisedView, flags byte, seq seqnum.Value) *tcpip.Error { +func (s *sender) sendSegmentFromView(data buffer.VectorisedView, flags byte, seq seqnum.Value) tcpip.Error { s.lastSendTime = time.Now() if seq == s.rttMeasureSeqNum { s.rttMeasureTime = s.lastSendTime @@ -1552,3 +1549,13 @@ func (s *sender) sendSegmentFromView(data buffer.VectorisedView, flags byte, seq return s.ep.sendRaw(data, flags, seq, rcvNxt, rcvWnd) } + +// maybeSendOutOfWindowAck sends an ACK if we are not being rate limited +// currently. +func (s *sender) maybeSendOutOfWindowAck(seg *segment) { + // Data packets are unlikely to be part of an ACK loop. So always send + // an ACK for a packet w/ data. + if seg.payloadSize() > 0 || s.ep.allowOutOfWindowAck() { + s.sendAck() + } +} diff --git a/pkg/tcpip/transport/tcp/tcp_rack_test.go b/pkg/tcpip/transport/tcp/tcp_rack_test.go index af915203b..a6a26b705 100644 --- a/pkg/tcpip/transport/tcp/tcp_rack_test.go +++ b/pkg/tcpip/transport/tcp/tcp_rack_test.go @@ -16,6 +16,7 @@ package tcp_test import ( "bytes" + "fmt" "testing" "time" @@ -534,3 +535,64 @@ func TestRACKWithInvalidDSACKBlock(t *testing.T) { // ACK before the test completes. <-probeDone } + +func addReorderWindowCheckerProbe(c *context.Context, numACK int, probeDone chan error) { + var n int + c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) { + // Validate that RACK detects DSACK. 
+ n++ + if n < numACK { + return + } + + if state.Sender.RACKState.ReoWnd == 0 || state.Sender.RACKState.ReoWnd > state.Sender.SRTT { + probeDone <- fmt.Errorf("got RACKState.ReoWnd: %v, expected it to be greater than 0 and less than %v", state.Sender.RACKState.ReoWnd, state.Sender.SRTT) + return + } + + if state.Sender.RACKState.ReoWndIncr != 1 { + probeDone <- fmt.Errorf("got RACKState.ReoWndIncr: %v, want: 1", state.Sender.RACKState.ReoWndIncr) + return + } + + if state.Sender.RACKState.ReoWndPersist > 0 { + probeDone <- fmt.Errorf("got RACKState.ReoWndPersist: %v, want: greater than 0", state.Sender.RACKState.ReoWndPersist) + return + } + probeDone <- nil + }) +} + +func TestRACKCheckReorderWindow(t *testing.T) { + c := context.New(t, uint32(mtu)) + defer c.Cleanup() + + probeDone := make(chan error) + const ackNumToVerify = 3 + addReorderWindowCheckerProbe(c, ackNumToVerify, probeDone) + + const numPackets = 7 + sendAndReceive(t, c, numPackets) + + // Send ACK for #1 packet. + bytesRead := maxPayload + seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + c.SendAck(seq, bytesRead) + + // Missing [2-6] packets and SACK #7 packet. + seq = seqnum.Value(context.TestInitialSequenceNumber).Add(1) + start := c.IRS.Add(1 + seqnum.Size(6*maxPayload)) + end := start.Add(seqnum.Size(maxPayload)) + c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) + + // Received delayed packets [2-6] which indicates there is reordering + // in the connection. + bytesRead += 6 * maxPayload + c.SendAck(seq, bytesRead) + + // Wait for the probe function to finish processing the ACK before the + // test completes. 
+ if err := <-probeDone; err != nil { + t.Fatalf("unexpected values for RACK variables: %v", err) + } +} diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 87ff2b909..cd3c4a027 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -48,7 +48,7 @@ type endpointTester struct { } // CheckReadError issues a read to the endpoint and checking for an error. -func (e *endpointTester) CheckReadError(t *testing.T, want *tcpip.Error) { +func (e *endpointTester) CheckReadError(t *testing.T, want tcpip.Error) { t.Helper() res, got := e.ep.Read(ioutil.Discard, tcpip.ReadOptions{}) if got != want { @@ -87,7 +87,7 @@ func (e *endpointTester) CheckReadFull(t *testing.T, count int, notifyRead <-cha } for w.N != 0 { _, err := e.ep.Read(&w, tcpip.ReadOptions{}) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for receive to be notified. select { case <-notifyRead: @@ -128,8 +128,11 @@ func TestGiveUpConnect(t *testing.T) { wq.EventRegister(&waitEntry, waiter.EventHUp) defer wq.EventUnregister(&waitEntry) - if err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { - t.Fatalf("got ep.Connect(...) = %s, want = %s", err, tcpip.ErrConnectStarted) + { + err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) + if _, ok := err.(*tcpip.ErrConnectStarted); !ok { + t.Fatalf("got ep.Connect(...) = %v, want = %s", err, &tcpip.ErrConnectStarted{}) + } } // Close the connection, wait for completion. @@ -140,8 +143,11 @@ func TestGiveUpConnect(t *testing.T) { // Call Connect again to retreive the handshake failure status // and stats updates. - if err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrAborted { - t.Fatalf("got ep.Connect(...) 
= %s, want = %s", err, tcpip.ErrAborted) + { + err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) + if _, ok := err.(*tcpip.ErrAborted); !ok { + t.Fatalf("got ep.Connect(...) = %v, want = %s", err, &tcpip.ErrAborted{}) + } } if got := c.Stack().Stats().TCP.FailedConnectionAttempts.Value(); got != 1 { @@ -194,8 +200,11 @@ func TestActiveFailedConnectionAttemptIncrement(t *testing.T) { c.EP = ep want := stats.TCP.FailedConnectionAttempts.Value() + 1 - if err := c.EP.Connect(tcpip.FullAddress{NIC: 2, Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrNoRoute { - t.Errorf("got c.EP.Connect(...) = %s, want = %s", err, tcpip.ErrNoRoute) + { + err := c.EP.Connect(tcpip.FullAddress{NIC: 2, Addr: context.TestAddr, Port: context.TestPort}) + if _, ok := err.(*tcpip.ErrNoRoute); !ok { + t.Errorf("got c.EP.Connect(...) = %v, want = %s", err, &tcpip.ErrNoRoute{}) + } } if got := stats.TCP.FailedConnectionAttempts.Value(); got != want { @@ -211,7 +220,7 @@ func TestCloseWithoutConnect(t *testing.T) { defer c.Cleanup() // Create TCP endpoint. - var err *tcpip.Error + var err tcpip.Error c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) if err != nil { t.Fatalf("NewEndpoint failed: %s", err) @@ -384,7 +393,7 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. 
select { case <-ch: @@ -925,8 +934,11 @@ func TestUserSuppliedMSSOnConnect(t *testing.T) { ws := tcp.FindWndScale(seqnum.Size(rcvBufSize)) connectAddr := tcpip.FullAddress{Addr: ip.connectAddr, Port: context.TestPort} - if err := c.EP.Connect(connectAddr); err != tcpip.ErrConnectStarted { - t.Fatalf("Connect(%+v): %s", connectAddr, err) + { + err := c.EP.Connect(connectAddr) + if _, ok := err.(*tcpip.ErrConnectStarted); !ok { + t.Fatalf("Connect(%+v): %s", connectAddr, err) + } } // Receive SYN packet with our user supplied MSS. @@ -1442,7 +1454,8 @@ func TestConnectBindToDevice(t *testing.T) { c.WQ.EventRegister(&waitEntry, waiter.EventOut) defer c.WQ.EventUnregister(&waitEntry) - if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { + err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) + if _, ok := err.(*tcpip.ErrConnectStarted); !ok { t.Fatalf("unexpected return value from Connect: %s", err) } @@ -1502,8 +1515,9 @@ func TestSynSent(t *testing.T) { defer c.WQ.EventUnregister(&waitEntry) addr := tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort} - if err := c.EP.Connect(addr); err != tcpip.ErrConnectStarted { - t.Fatalf("got Connect(%+v) = %s, want %s", addr, err, tcpip.ErrConnectStarted) + err := c.EP.Connect(addr) + if _, ok := err.(*tcpip.ErrConnectStarted); !ok { + t.Fatalf("got Connect(%+v) = %v, want %s", addr, err, &tcpip.ErrConnectStarted{}) } // Receive SYN packet. 
@@ -1548,9 +1562,9 @@ func TestSynSent(t *testing.T) { ept := endpointTester{c.EP} if test.reset { - ept.CheckReadError(t, tcpip.ErrConnectionRefused) + ept.CheckReadError(t, &tcpip.ErrConnectionRefused{}) } else { - ept.CheckReadError(t, tcpip.ErrAborted) + ept.CheckReadError(t, &tcpip.ErrAborted{}) } if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { @@ -1576,7 +1590,7 @@ func TestOutOfOrderReceive(t *testing.T) { defer c.WQ.EventUnregister(&we) ept := endpointTester{c.EP} - ept.CheckReadError(t, tcpip.ErrWouldBlock) + ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) // Send second half of data first, with seqnum 3 ahead of expected. data := []byte{1, 2, 3, 4, 5, 6} @@ -1601,7 +1615,7 @@ func TestOutOfOrderReceive(t *testing.T) { // Wait 200ms and check that no data has been received. time.Sleep(200 * time.Millisecond) - ept.CheckReadError(t, tcpip.ErrWouldBlock) + ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) // Send the first 3 bytes now. c.SendPacket(data[:3], &context.Headers{ @@ -1640,7 +1654,7 @@ func TestOutOfOrderFlood(t *testing.T) { c.CreateConnected(789, 30000, rcvBufSz) ept := endpointTester{c.EP} - ept.CheckReadError(t, tcpip.ErrWouldBlock) + ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) // Send 100 packets before the actual one that is expected. 
data := []byte{1, 2, 3, 4, 5, 6} @@ -1716,7 +1730,7 @@ func TestRstOnCloseWithUnreadData(t *testing.T) { defer c.WQ.EventUnregister(&we) ept := endpointTester{c.EP} - ept.CheckReadError(t, tcpip.ErrWouldBlock) + ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) data := []byte{1, 2, 3} c.SendPacket(data, &context.Headers{ @@ -1784,7 +1798,7 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { defer c.WQ.EventUnregister(&we) ept := endpointTester{c.EP} - ept.CheckReadError(t, tcpip.ErrWouldBlock) + ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) data := []byte{1, 2, 3} c.SendPacket(data, &context.Headers{ @@ -1866,13 +1880,13 @@ func TestShutdownRead(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) ept := endpointTester{c.EP} - ept.CheckReadError(t, tcpip.ErrWouldBlock) + ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) if err := c.EP.Shutdown(tcpip.ShutdownRead); err != nil { t.Fatalf("Shutdown failed: %s", err) } - ept.CheckReadError(t, tcpip.ErrClosedForReceive) + ept.CheckReadError(t, &tcpip.ErrClosedForReceive{}) var want uint64 = 1 if got := c.EP.Stats().(*tcp.Stats).ReadErrors.ReadClosed.Value(); got != want { t.Fatalf("got EP stats Stats.ReadErrors.ReadClosed got %d want %d", got, want) @@ -1891,7 +1905,7 @@ func TestFullWindowReceive(t *testing.T) { defer c.WQ.EventUnregister(&we) ept := endpointTester{c.EP} - ept.CheckReadError(t, tcpip.ErrWouldBlock) + ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) // Fill up the window w/ tcp.SegOverheadFactor*rcvBufSz as netstack multiplies // the provided buffer value by tcp.SegOverheadFactor to calculate the actual @@ -2052,7 +2066,7 @@ func TestNoWindowShrinking(t *testing.T) { defer c.WQ.EventUnregister(&we) ept := endpointTester{c.EP} - ept.CheckReadError(t, tcpip.ErrWouldBlock) + ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) // Send a 1 byte payload so that we can record the current receive window. // Send a payload of half the size of rcvBufSize. 
@@ -2370,7 +2384,7 @@ func TestScaledWindowAccept(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -2443,7 +2457,7 @@ func TestNonScaledWindowAccept(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -2958,7 +2972,7 @@ func TestSetTTL(t *testing.T) { c := context.New(t, 65535) defer c.Cleanup() - var err *tcpip.Error + var err tcpip.Error c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) if err != nil { t.Fatalf("NewEndpoint failed: %s", err) @@ -2968,8 +2982,11 @@ func TestSetTTL(t *testing.T) { t.Fatalf("SetSockOptInt(TTLOption, %d) failed: %s", wantTTL, err) } - if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { - t.Fatalf("unexpected return value from Connect: %s", err) + { + err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) + if _, ok := err.(*tcpip.ErrConnectStarted); !ok { + t.Fatalf("unexpected return value from Connect: %s", err) + } } // Receive SYN packet. @@ -3029,7 +3046,7 @@ func TestPassiveSendMSSLessThanMTU(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -3085,7 +3102,7 @@ func TestSynCookiePassiveSendMSSLessThanMTU(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. 
select { case <-ch: @@ -3110,9 +3127,9 @@ func TestForwarderSendMSSLessThanMTU(t *testing.T) { defer c.Cleanup() s := c.Stack() - ch := make(chan *tcpip.Error, 1) + ch := make(chan tcpip.Error, 1) f := tcp.NewForwarder(s, 65536, 10, func(r *tcp.ForwarderRequest) { - var err *tcpip.Error + var err tcpip.Error c.EP, err = r.CreateEndpoint(&c.WQ) ch <- err }) @@ -3141,7 +3158,7 @@ func TestSynOptionsOnActiveConnect(t *testing.T) { defer c.Cleanup() // Create TCP endpoint. - var err *tcpip.Error + var err tcpip.Error c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) if err != nil { t.Fatalf("NewEndpoint failed: %s", err) @@ -3160,8 +3177,11 @@ func TestSynOptionsOnActiveConnect(t *testing.T) { c.WQ.EventRegister(&we, waiter.EventOut) defer c.WQ.EventUnregister(&we) - if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { - t.Fatalf("got c.EP.Connect(...) = %s, want = %s", err, tcpip.ErrConnectStarted) + { + err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) + if _, ok := err.(*tcpip.ErrConnectStarted); !ok { + t.Fatalf("got c.EP.Connect(...) = %v, want = %s", err, &tcpip.ErrConnectStarted{}) + } } // Receive SYN packet. @@ -3271,22 +3291,23 @@ func TestReceiveOnResetConnection(t *testing.T) { loop: for { - switch _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}); err { - case tcpip.ErrWouldBlock: + switch _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}); err.(type) { + case *tcpip.ErrWouldBlock: select { case <-ch: // Expect the state to be StateError and subsequent Reads to fail with HardError. 
- if _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}); err != tcpip.ErrConnectionReset { - t.Fatalf("got c.EP.Read() = %s, want = %s", err, tcpip.ErrConnectionReset) + _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}) + if _, ok := err.(*tcpip.ErrConnectionReset); !ok { + t.Fatalf("got c.EP.Read() = %v, want = %s", err, &tcpip.ErrConnectionReset{}) } break loop case <-time.After(1 * time.Second): t.Fatalf("Timed out waiting for reset to arrive") } - case tcpip.ErrConnectionReset: + case *tcpip.ErrConnectionReset: break loop default: - t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrConnectionReset) + t.Fatalf("got c.EP.Read(nil) = %v, want = %s", err, &tcpip.ErrConnectionReset{}) } } @@ -3325,8 +3346,9 @@ func TestSendOnResetConnection(t *testing.T) { // Try to write. var r bytes.Reader r.Reset(make([]byte, 10)) - if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != tcpip.ErrConnectionReset { - t.Fatalf("got c.EP.Write(...) = %s, want = %s", err, tcpip.ErrConnectionReset) + _, err := c.EP.Write(&r, tcpip.WriteOptions{}) + if _, ok := err.(*tcpip.ErrConnectionReset); !ok { + t.Fatalf("got c.EP.Write(...) = %v, want = %s", err, &tcpip.ErrConnectionReset{}) } } @@ -4184,7 +4206,7 @@ func TestReadAfterClosedState(t *testing.T) { defer c.WQ.EventUnregister(&we) ept := endpointTester{c.EP} - ept.CheckReadError(t, tcpip.ErrWouldBlock) + ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) // Shutdown immediately for write, check that we get a FIN. if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { @@ -4263,10 +4285,13 @@ func TestReadAfterClosedState(t *testing.T) { // Now that we drained the queue, check that functions fail with the // right error code. 
- ept.CheckReadError(t, tcpip.ErrClosedForReceive) + ept.CheckReadError(t, &tcpip.ErrClosedForReceive{}) var buf bytes.Buffer - if _, err := c.EP.Read(&buf, tcpip.ReadOptions{Peek: true}); err != tcpip.ErrClosedForReceive { - t.Fatalf("c.EP.Read(_, {Peek: true}) = %v, %s; want _, %s", res, err, tcpip.ErrClosedForReceive) + { + _, err := c.EP.Read(&buf, tcpip.ReadOptions{Peek: true}) + if _, ok := err.(*tcpip.ErrClosedForReceive); !ok { + t.Fatalf("c.EP.Read(_, {Peek: true}) = %v, %s; want _, %s", res, err, &tcpip.ErrClosedForReceive{}) + } } } @@ -4277,7 +4302,7 @@ func TestReusePort(t *testing.T) { defer c.Cleanup() // First case, just an endpoint that was bound. - var err *tcpip.Error + var err tcpip.Error c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) if err != nil { t.Fatalf("NewEndpoint failed; %s", err) @@ -4307,8 +4332,11 @@ func TestReusePort(t *testing.T) { if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { t.Fatalf("Bind failed: %s", err) } - if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { - t.Fatalf("got c.EP.Connect(...) = %s, want = %s", err, tcpip.ErrConnectStarted) + { + err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) + if _, ok := err.(*tcpip.ErrConnectStarted); !ok { + t.Fatalf("got c.EP.Connect(...) 
= %v, want = %s", err, &tcpip.ErrConnectStarted{}) + } } c.EP.Close() @@ -4515,11 +4543,11 @@ func TestBindToDeviceOption(t *testing.T) { testActions := []struct { name string setBindToDevice *tcpip.NICID - setBindToDeviceError *tcpip.Error + setBindToDeviceError tcpip.Error getBindToDevice int32 }{ {"GetDefaultValue", nil, nil, 0}, - {"BindToNonExistent", nicIDPtr(999), tcpip.ErrUnknownDevice, 0}, + {"BindToNonExistent", nicIDPtr(999), &tcpip.ErrUnknownDevice{}, 0}, {"BindToExistent", nicIDPtr(321), nil, 321}, {"UnbindToDevice", nicIDPtr(0), nil, 0}, } @@ -4539,7 +4567,7 @@ func TestBindToDeviceOption(t *testing.T) { } } -func makeStack() (*stack.Stack, *tcpip.Error) { +func makeStack() (*stack.Stack, tcpip.Error) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ ipv4.NewProtocol, @@ -4609,8 +4637,11 @@ func TestSelfConnect(t *testing.T) { wq.EventRegister(&waitEntry, waiter.EventOut) defer wq.EventUnregister(&waitEntry) - if err := ep.Connect(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != tcpip.ErrConnectStarted { - t.Fatalf("got ep.Connect(...) = %s, want = %s", err, tcpip.ErrConnectStarted) + { + err := ep.Connect(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}) + if _, ok := err.(*tcpip.ErrConnectStarted); !ok { + t.Fatalf("got ep.Connect(...) = %v, want = %s", err, &tcpip.ErrConnectStarted{}) + } } <-notifyCh @@ -4762,9 +4793,9 @@ func TestConnectAvoidsBoundPorts(t *testing.T) { t.Fatalf("Bind(%d) failed: %s", i, err) } } - want := tcpip.ErrConnectStarted + var want tcpip.Error = &tcpip.ErrConnectStarted{} if collides { - want = tcpip.ErrNoPortAvailable + want = &tcpip.ErrNoPortAvailable{} } if err := makeEP(candidateNetwork).Connect(tcpip.FullAddress{Addr: address(t, candidateAddressType, false), Port: 31337}); err != want { t.Fatalf("got ep.Connect(..) 
= %s, want = %s", err, want) @@ -4889,11 +4920,11 @@ func TestTCPEndpointProbe(t *testing.T) { func TestStackSetCongestionControl(t *testing.T) { testCases := []struct { cc tcpip.CongestionControlOption - err *tcpip.Error + err tcpip.Error }{ {"reno", nil}, {"cubic", nil}, - {"blahblah", tcpip.ErrNoSuchFile}, + {"blahblah", &tcpip.ErrNoSuchFile{}}, } for _, tc := range testCases { @@ -4975,11 +5006,11 @@ func TestStackSetAvailableCongestionControl(t *testing.T) { func TestEndpointSetCongestionControl(t *testing.T) { testCases := []struct { cc tcpip.CongestionControlOption - err *tcpip.Error + err tcpip.Error }{ {"reno", nil}, {"cubic", nil}, - {"blahblah", tcpip.ErrNoSuchFile}, + {"blahblah", &tcpip.ErrNoSuchFile{}}, } for _, connected := range []bool{false, true} { @@ -4989,7 +5020,7 @@ func TestEndpointSetCongestionControl(t *testing.T) { defer c.Cleanup() // Create TCP endpoint. - var err *tcpip.Error + var err tcpip.Error c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) if err != nil { t.Fatalf("NewEndpoint failed: %s", err) @@ -5085,7 +5116,7 @@ func TestKeepalive(t *testing.T) { // Check that the connection is still alive. ept := endpointTester{c.EP} - ept.CheckReadError(t, tcpip.ErrWouldBlock) + ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) // Send some data and wait before ACKing it. Keepalives should be disabled // during this period. @@ -5176,7 +5207,7 @@ func TestKeepalive(t *testing.T) { t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout.Value() = %d, want = 1", got) } - ept.CheckReadError(t, tcpip.ErrTimeout) + ept.CheckReadError(t, &tcpip.ErrTimeout{}) if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got) @@ -5283,7 +5314,7 @@ func TestListenBacklogFull(t *testing.T) { defer c.Cleanup() // Create TCP endpoint. 
- var err *tcpip.Error + var err tcpip.Error c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) if err != nil { t.Fatalf("NewEndpoint failed: %s", err) @@ -5326,7 +5357,7 @@ func TestListenBacklogFull(t *testing.T) { for i := 0; i < listenBacklog; i++ { _, _, err = c.EP.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -5343,7 +5374,7 @@ func TestListenBacklogFull(t *testing.T) { // Now verify that there are no more connections that can be accepted. _, _, err = c.EP.Accept(nil) - if err != tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { select { case <-ch: t.Fatalf("unexpected endpoint delivered on Accept: %+v", c.EP) @@ -5355,7 +5386,7 @@ func TestListenBacklogFull(t *testing.T) { executeHandshake(t, c, context.TestPort+lastPortOffset, false /*synCookieInUse */) newEP, _, err := c.EP.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -5598,7 +5629,7 @@ func TestListenSynRcvdQueueFull(t *testing.T) { defer c.Cleanup() // Create TCP endpoint. - var err *tcpip.Error + var err tcpip.Error c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) if err != nil { t.Fatalf("NewEndpoint failed: %s", err) @@ -5673,7 +5704,7 @@ func TestListenSynRcvdQueueFull(t *testing.T) { defer c.WQ.EventUnregister(&we) newEP, _, err := c.EP.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -5709,7 +5740,7 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) { } // Create TCP endpoint. 
- var err *tcpip.Error + var err tcpip.Error c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) if err != nil { t.Fatalf("NewEndpoint failed: %s", err) @@ -5750,7 +5781,7 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) { defer c.WQ.EventUnregister(&we) _, _, err = c.EP.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -5766,7 +5797,7 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) { // Now verify that there are no more connections that can be accepted. _, _, err = c.EP.Accept(nil) - if err != tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { select { case <-ch: t.Fatalf("unexpected endpoint delivered on Accept: %+v", c.EP) @@ -5780,7 +5811,7 @@ func TestSYNRetransmit(t *testing.T) { defer c.Cleanup() // Create TCP endpoint. - var err *tcpip.Error + var err tcpip.Error c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) if err != nil { t.Fatalf("NewEndpoint failed: %s", err) @@ -5824,7 +5855,7 @@ func TestSynRcvdBadSeqNumber(t *testing.T) { defer c.Cleanup() // Create TCP endpoint. - var err *tcpip.Error + var err tcpip.Error c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) if err != nil { t.Fatalf("NewEndpoint failed: %s", err) @@ -5899,12 +5930,13 @@ func TestSynRcvdBadSeqNumber(t *testing.T) { }) newEP, _, err := c.EP.Accept(nil) - - if err != nil && err != tcpip.ErrWouldBlock { + switch err.(type) { + case nil, *tcpip.ErrWouldBlock: + default: t.Fatalf("Accept failed: %s", err) } - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Try to accept the connections in the backlog. 
we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventIn) @@ -5972,7 +6004,7 @@ func TestPassiveConnectionAttemptIncrement(t *testing.T) { // Verify that there is only one acceptable connection at this point. _, _, err = c.EP.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -6042,7 +6074,7 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) { // Now check that there is one acceptable connections. _, _, err = c.EP.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -6074,7 +6106,7 @@ func TestEndpointBindListenAcceptState(t *testing.T) { } ept := endpointTester{ep} - ept.CheckReadError(t, tcpip.ErrNotConnected) + ept.CheckReadError(t, &tcpip.ErrNotConnected{}) if got := ep.Stats().(*tcp.Stats).ReadErrors.NotConnected.Value(); got != 1 { t.Errorf("got EP stats Stats.ReadErrors.NotConnected got %d want %d", got, 1) } @@ -6094,7 +6126,7 @@ func TestEndpointBindListenAcceptState(t *testing.T) { defer wq.EventUnregister(&we) aep, _, err := ep.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. 
select { case <-ch: @@ -6110,8 +6142,11 @@ func TestEndpointBindListenAcceptState(t *testing.T) { if got, want := tcp.EndpointState(aep.State()), tcp.StateEstablished; got != want { t.Errorf("unexpected endpoint state: want %s, got %s", want, got) } - if err := aep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrAlreadyConnected { - t.Errorf("unexpected error attempting to call connect on an established endpoint, got: %s, want: %s", err, tcpip.ErrAlreadyConnected) + { + err := aep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) + if _, ok := err.(*tcpip.ErrAlreadyConnected); !ok { + t.Errorf("unexpected error attempting to call connect on an established endpoint, got: %v, want: %s", err, &tcpip.ErrAlreadyConnected{}) + } } // Listening endpoint remains in listen state. if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want { @@ -6230,7 +6265,7 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) { // window increases to the full available buffer size. for { _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { break } } @@ -6267,6 +6302,13 @@ func TestReceiveBufferAutoTuning(t *testing.T) { // Enable Auto-tuning. stk := c.Stack() + // Disable out of window rate limiting for this test by setting it to 0 as we + // use out of window ACKs to measure the advertised window. + var tcpInvalidRateLimit stack.TCPInvalidRateLimitOption + if err := stk.SetOption(tcpInvalidRateLimit); err != nil { + t.Fatalf("e.stack.SetOption(%#v) = %s", tcpInvalidRateLimit, err) + } + const receiveBufferSize = 80 << 10 // 80KB. 
const maxReceiveBufferSize = receiveBufferSize * 10 { @@ -6354,7 +6396,7 @@ func TestReceiveBufferAutoTuning(t *testing.T) { totalCopied := 0 for { res, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { break } totalCopied += res.Count @@ -6546,7 +6588,7 @@ func TestTCPTimeWaitRSTIgnored(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -6665,7 +6707,7 @@ func TestTCPTimeWaitOutOfOrder(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -6772,7 +6814,7 @@ func TestTCPTimeWaitNewSyn(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -6862,7 +6904,7 @@ func TestTCPTimeWaitNewSyn(t *testing.T) { // Try to accept the connection. c.EP, _, err = ep.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -6936,7 +6978,7 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: @@ -7086,7 +7128,7 @@ func TestTCPCloseWithData(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. 
select { case <-ch: @@ -7277,7 +7319,7 @@ func TestTCPUserTimeout(t *testing.T) { ) ept := endpointTester{c.EP} - ept.CheckReadError(t, tcpip.ErrTimeout) + ept.CheckReadError(t, &tcpip.ErrTimeout{}) if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want { t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %d, want = %d", got, want) @@ -7321,7 +7363,7 @@ func TestKeepaliveWithUserTimeout(t *testing.T) { // Check that the connection is still alive. ept := endpointTester{c.EP} - ept.CheckReadError(t, tcpip.ErrWouldBlock) + ept.CheckReadError(t, &tcpip.ErrWouldBlock{}) // Now receive 1 keepalives, but don't ACK it. b := c.GetPacket() @@ -7360,7 +7402,7 @@ func TestKeepaliveWithUserTimeout(t *testing.T) { ), ) - ept.CheckReadError(t, tcpip.ErrTimeout) + ept.CheckReadError(t, &tcpip.ErrTimeout{}) if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want { t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %d, want = %d", got, want) } @@ -7515,8 +7557,9 @@ func TestTCPDeferAccept(t *testing.T) { irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */) - if _, _, err := c.EP.Accept(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Accept(nil) = %s, want: %s", err, tcpip.ErrWouldBlock) + _, _, err := c.EP.Accept(nil) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got c.EP.Accept(nil) = %v, want: %s", err, &tcpip.ErrWouldBlock{}) } // Send data. This should result in an acceptable endpoint. 
@@ -7573,8 +7616,9 @@ func TestTCPDeferAcceptTimeout(t *testing.T) { irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */) - if _, _, err := c.EP.Accept(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Accept(nil) = %s, want: %s", err, tcpip.ErrWouldBlock) + _, _, err := c.EP.Accept(nil) + if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + t.Fatalf("got c.EP.Accept(nil) = %v, want: %s", err, &tcpip.ErrWouldBlock{}) } // Sleep for a little of the tcpDeferAccept timeout. @@ -7696,13 +7740,13 @@ func TestSetStackTimeWaitReuse(t *testing.T) { s := c.Stack() testCases := []struct { v int - err *tcpip.Error + err tcpip.Error }{ {int(tcpip.TCPTimeWaitReuseDisabled), nil}, {int(tcpip.TCPTimeWaitReuseGlobal), nil}, {int(tcpip.TCPTimeWaitReuseLoopbackOnly), nil}, - {int(tcpip.TCPTimeWaitReuseLoopbackOnly) + 1, tcpip.ErrInvalidOptionValue}, - {int(tcpip.TCPTimeWaitReuseDisabled) - 1, tcpip.ErrInvalidOptionValue}, + {int(tcpip.TCPTimeWaitReuseLoopbackOnly) + 1, &tcpip.ErrInvalidOptionValue{}}, + {int(tcpip.TCPTimeWaitReuseDisabled) - 1, &tcpip.ErrInvalidOptionValue{}}, } for _, tc := range testCases { diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index ee55f030c..b1cb9a324 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -586,7 +586,7 @@ func (c *Context) ReceiveNonBlockingAndCheckPacket(data []byte, offset, size int // is true then it sets the IP_V6ONLY option on the socket to make it a IPv6 // only endpoint instead of a default dual stack socket. 
func (c *Context) CreateV6Endpoint(v6only bool) { - var err *tcpip.Error + var err tcpip.Error c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &c.WQ) if err != nil { c.t.Fatalf("NewEndpoint failed: %v", err) @@ -689,7 +689,8 @@ func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte) c.WQ.EventRegister(&waitEntry, waiter.EventOut) defer c.WQ.EventUnregister(&waitEntry) - if err := c.EP.Connect(tcpip.FullAddress{Addr: TestAddr, Port: TestPort}); err != tcpip.ErrConnectStarted { + err := c.EP.Connect(tcpip.FullAddress{Addr: TestAddr, Port: TestPort}) + if _, ok := err.(*tcpip.ErrConnectStarted); !ok { c.t.Fatalf("Unexpected return value from Connect: %v", err) } @@ -749,7 +750,7 @@ func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte) // Create creates a TCP endpoint. func (c *Context) Create(epRcvBuf int) { // Create TCP endpoint. - var err *tcpip.Error + var err tcpip.Error c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) if err != nil { c.t.Fatalf("NewEndpoint failed: %v", err) @@ -887,7 +888,7 @@ func (r *RawEndpoint) VerifyACKHasSACK(sackBlocks []header.SACKBlock) { // It also verifies where required(eg.Timestamp) that the ACK to the SYN-ACK // does not carry an option that was not requested. func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) *RawEndpoint { - var err *tcpip.Error + var err tcpip.Error c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) if err != nil { c.t.Fatalf("c.s.NewEndpoint(tcp, ipv4...) = %v", err) @@ -903,7 +904,7 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) * testFullAddr := tcpip.FullAddress{Addr: TestAddr, Port: TestPort} err = c.EP.Connect(testFullAddr) - if err != tcpip.ErrConnectStarted { + if _, ok := err.(*tcpip.ErrConnectStarted); !ok { c.t.Fatalf("c.ep.Connect(%v) = %v", testFullAddr, err) } // Receive SYN packet. 
@@ -1054,7 +1055,7 @@ func (c *Context) AcceptWithOptions(wndScale int, synOptions header.TCPSynOption defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. select { case <-ch: diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 4988ba29b..afd8f4d39 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -109,8 +109,8 @@ type endpoint struct { multicastNICID tcpip.NICID portFlags ports.Flags - lastErrorMu sync.Mutex `state:"nosave"` - lastError *tcpip.Error `state:".(string)"` + lastErrorMu sync.Mutex `state:"nosave"` + lastError tcpip.Error // Values used to reserve a port or register a transport endpoint. // (which ever happens first). @@ -215,7 +215,7 @@ func (e *endpoint) UniqueID() uint64 { return e.uniqueID } -func (e *endpoint) LastError() *tcpip.Error { +func (e *endpoint) LastError() tcpip.Error { e.lastErrorMu.Lock() defer e.lastErrorMu.Unlock() @@ -225,7 +225,7 @@ func (e *endpoint) LastError() *tcpip.Error { } // UpdateLastError implements tcpip.SocketOptionsHandler.UpdateLastError. -func (e *endpoint) UpdateLastError(err *tcpip.Error) { +func (e *endpoint) UpdateLastError(err tcpip.Error) { e.lastErrorMu.Lock() e.lastError = err e.lastErrorMu.Unlock() @@ -282,7 +282,7 @@ func (e *endpoint) Close() { func (e *endpoint) ModerateRecvBuf(copied int) {} // Read implements tcpip.Endpoint.Read. 
-func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) { +func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, tcpip.Error) { if err := e.LastError(); err != nil { return tcpip.ReadResult{}, err } @@ -290,10 +290,10 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult e.rcvMu.Lock() if e.rcvList.Empty() { - err := tcpip.ErrWouldBlock + var err tcpip.Error = &tcpip.ErrWouldBlock{} if e.rcvClosed { e.stats.ReadErrors.ReadClosed.Increment() - err = tcpip.ErrClosedForReceive + err = &tcpip.ErrClosedForReceive{} } e.rcvMu.Unlock() return tcpip.ReadResult{}, err @@ -340,7 +340,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult n, err := p.data.ReadTo(dst, opts.Peek) if n == 0 && err != nil { - return res, tcpip.ErrBadBuffer + return res, &tcpip.ErrBadBuffer{} } res.Count = n return res, nil @@ -351,7 +351,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult // reacquire the mutex in exclusive mode. // // Returns true for retry if preparation should be retried. 
-func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpip.Error) { +func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip.Error) { switch e.EndpointState() { case StateInitial: case StateConnected: @@ -359,11 +359,11 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi case StateBound: if to == nil { - return false, tcpip.ErrDestinationRequired + return false, &tcpip.ErrDestinationRequired{} } return false, nil default: - return false, tcpip.ErrInvalidEndpointState + return false, &tcpip.ErrInvalidEndpointState{} } e.mu.RUnlock() @@ -389,7 +389,7 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi // connectRoute establishes a route to the specified interface or the // configured multicast interface if no interface is specified and the // specified address is a multicast address. -func (e *endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (*stack.Route, tcpip.NICID, *tcpip.Error) { +func (e *endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (*stack.Route, tcpip.NICID, tcpip.Error) { localAddr := e.ID.LocalAddress if e.isBroadcastOrMulticast(nicID, netProto, localAddr) { // A packet can only originate from a unicast address (i.e., an interface). @@ -415,18 +415,18 @@ func (e *endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netPr // Write writes data to the endpoint's peer. This method does not block // if the data cannot be written. 
-func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) { +func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) { n, err := e.write(p, opts) - switch err { + switch err.(type) { case nil: e.stats.PacketsSent.Increment() - case tcpip.ErrMessageTooLong, tcpip.ErrInvalidOptionValue: + case *tcpip.ErrMessageTooLong, *tcpip.ErrInvalidOptionValue: e.stats.WriteErrors.InvalidArgs.Increment() - case tcpip.ErrClosedForSend: + case *tcpip.ErrClosedForSend: e.stats.WriteErrors.WriteClosed.Increment() - case tcpip.ErrInvalidEndpointState: + case *tcpip.ErrInvalidEndpointState: e.stats.WriteErrors.InvalidEndpointState.Increment() - case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable: + case *tcpip.ErrNoRoute, *tcpip.ErrBroadcastDisabled, *tcpip.ErrNetworkUnreachable: // Errors indicating any problem with IP routing of the packet. e.stats.SendErrors.NoRoute.Increment() default: @@ -436,14 +436,14 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc return n, err } -func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) { +func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) { if err := e.LastError(); err != nil { return 0, err } // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.) if opts.More { - return 0, tcpip.ErrInvalidOptionValue + return 0, &tcpip.ErrInvalidOptionValue{} } to := opts.To @@ -459,7 +459,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc // If we've shutdown with SHUT_WR we are in an invalid state for sending. if e.shutdownFlags&tcpip.ShutdownWrite != 0 { - return 0, tcpip.ErrClosedForSend + return 0, &tcpip.ErrClosedForSend{} } // Prepare for write. 
@@ -480,9 +480,12 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc // Reject destination address if it goes through a different // NIC than the endpoint was bound to. nicID := to.NIC + if nicID == 0 { + nicID = tcpip.NICID(e.ops.GetBindToDevice()) + } if e.BindNICID != 0 { if nicID != 0 && nicID != e.BindNICID { - return 0, tcpip.ErrNoRoute + return 0, &tcpip.ErrNoRoute{} } nicID = e.BindNICID @@ -490,7 +493,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc if to.Port == 0 { // Port 0 is an invalid port to send to. - return 0, tcpip.ErrInvalidEndpointState + return 0, &tcpip.ErrInvalidEndpointState{} } dst, netProto, err := e.checkV4MappedLocked(*to) @@ -509,19 +512,19 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc } if !e.ops.GetBroadcast() && route.IsOutboundBroadcast() { - return 0, tcpip.ErrBroadcastDisabled + return 0, &tcpip.ErrBroadcastDisabled{} } v := make([]byte, p.Len()) if _, err := io.ReadFull(p, v); err != nil { - return 0, tcpip.ErrBadBuffer + return 0, &tcpip.ErrBadBuffer{} } if len(v) > header.UDPMaximumPacketSize { // Payload can't possibly fit in a packet. so := e.SocketOptions() if so.GetRecvError() { so.QueueLocalErr( - tcpip.ErrMessageTooLong, + &tcpip.ErrMessageTooLong{}, route.NetProto, header.UDPMaximumPacketSize, tcpip.FullAddress{ @@ -532,7 +535,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc v, ) } - return 0, tcpip.ErrMessageTooLong + return 0, &tcpip.ErrMessageTooLong{} } ttl := e.ttl @@ -582,13 +585,13 @@ func (e *endpoint) OnReusePortSet(v bool) { } // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. -func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { +func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { switch opt { case tcpip.MTUDiscoverOption: // Return not supported if the value is not disabling path // MTU discovery. 
if v != tcpip.PMTUDiscoveryDont { - return tcpip.ErrNotSupported + return &tcpip.ErrNotSupported{} } case tcpip.MulticastTTLOption: @@ -640,7 +643,7 @@ func (e *endpoint) HasNIC(id int32) bool { } // SetSockOpt implements tcpip.Endpoint.SetSockOpt. -func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { +func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { switch v := opt.(type) { case *tcpip.MulticastInterfaceOption: e.mu.Lock() @@ -662,17 +665,17 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { if nic != 0 { if !e.stack.CheckNIC(nic) { - return tcpip.ErrBadLocalAddress + return &tcpip.ErrBadLocalAddress{} } } else { nic = e.stack.CheckLocalAddress(0, netProto, addr) if nic == 0 { - return tcpip.ErrBadLocalAddress + return &tcpip.ErrBadLocalAddress{} } } if e.BindNICID != 0 && e.BindNICID != nic { - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } e.multicastNICID = nic @@ -680,7 +683,7 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { case *tcpip.AddMembershipOption: if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) { - return tcpip.ErrInvalidOptionValue + return &tcpip.ErrInvalidOptionValue{} } nicID := v.NIC @@ -696,7 +699,7 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { nicID = e.stack.CheckLocalAddress(nicID, e.NetProto, v.InterfaceAddr) } if nicID == 0 { - return tcpip.ErrUnknownDevice + return &tcpip.ErrUnknownDevice{} } memToInsert := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr} @@ -705,7 +708,7 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { defer e.mu.Unlock() if _, ok := e.multicastMemberships[memToInsert]; ok { - return tcpip.ErrPortInUse + return &tcpip.ErrPortInUse{} } if err := e.stack.JoinGroup(e.NetProto, nicID, v.MulticastAddr); err != nil { @@ -716,7 +719,7 @@ func (e *endpoint) 
SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { case *tcpip.RemoveMembershipOption: if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) { - return tcpip.ErrInvalidOptionValue + return &tcpip.ErrInvalidOptionValue{} } nicID := v.NIC @@ -731,7 +734,7 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { nicID = e.stack.CheckLocalAddress(nicID, e.NetProto, v.InterfaceAddr) } if nicID == 0 { - return tcpip.ErrUnknownDevice + return &tcpip.ErrUnknownDevice{} } memToRemove := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr} @@ -740,7 +743,7 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { defer e.mu.Unlock() if _, ok := e.multicastMemberships[memToRemove]; !ok { - return tcpip.ErrBadLocalAddress + return &tcpip.ErrBadLocalAddress{} } if err := e.stack.LeaveGroup(e.NetProto, nicID, v.MulticastAddr); err != nil { @@ -756,7 +759,7 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. -func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { +func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { switch opt { case tcpip.IPv4TOSOption: e.mu.RLock() @@ -803,12 +806,12 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { return v, nil default: - return -1, tcpip.ErrUnknownProtocolOption + return -1, &tcpip.ErrUnknownProtocolOption{} } } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. 
-func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { +func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { switch o := opt.(type) { case *tcpip.MulticastInterfaceOption: e.mu.Lock() @@ -819,14 +822,14 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { e.mu.Unlock() default: - return tcpip.ErrUnknownProtocolOption + return &tcpip.ErrUnknownProtocolOption{} } return nil } // sendUDP sends a UDP segment via the provided network endpoint and under the // provided identity. -func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8, useDefaultTTL bool, tos uint8, owner tcpip.PacketOwner, noChecksum bool) *tcpip.Error { +func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8, useDefaultTTL bool, tos uint8, owner tcpip.PacketOwner, noChecksum bool) tcpip.Error { pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: header.UDPMinimumSize + int(r.MaxHeaderLength()), Data: data, @@ -876,7 +879,7 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u // checkV4MappedLocked determines the effective network protocol and converts // addr to its canonical form. -func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) { +func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) { unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.ops.GetV6Only()) if err != nil { return tcpip.FullAddress{}, 0, err @@ -885,7 +888,7 @@ func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddres } // Disconnect implements tcpip.Endpoint.Disconnect. 
-func (e *endpoint) Disconnect() *tcpip.Error { +func (e *endpoint) Disconnect() tcpip.Error { e.mu.Lock() defer e.mu.Unlock() @@ -903,7 +906,7 @@ func (e *endpoint) Disconnect() *tcpip.Error { // Exclude ephemerally bound endpoints. if e.BindNICID != 0 || e.ID.LocalAddress == "" { - var err *tcpip.Error + var err tcpip.Error id = stack.TransportEndpointID{ LocalPort: e.ID.LocalPort, LocalAddress: e.ID.LocalAddress, @@ -934,10 +937,10 @@ func (e *endpoint) Disconnect() *tcpip.Error { } // Connect connects the endpoint to its peer. Specifying a NIC is optional. -func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { +func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error { if addr.Port == 0 { // We don't support connecting to port zero. - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } e.mu.Lock() @@ -954,12 +957,12 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { } if nicID != 0 && nicID != e.BindNICID { - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } nicID = e.BindNICID default: - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } addr, netProto, err := e.checkV4MappedLocked(addr) @@ -1029,20 +1032,20 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { } // ConnectEndpoint is not supported. -func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error { - return tcpip.ErrInvalidEndpointState +func (*endpoint) ConnectEndpoint(tcpip.Endpoint) tcpip.Error { + return &tcpip.ErrInvalidEndpointState{} } // Shutdown closes the read and/or write end of the endpoint connection // to its peer. -func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { +func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() // A socket in the bound state can still receive multicast messages, // so we need to notify waiters on shutdown. 
if state := e.EndpointState(); state != StateBound && state != StateConnected { - return tcpip.ErrNotConnected + return &tcpip.ErrNotConnected{} } e.shutdownFlags |= flags @@ -1062,16 +1065,16 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { } // Listen is not supported by UDP, it just fails. -func (*endpoint) Listen(int) *tcpip.Error { - return tcpip.ErrNotSupported +func (*endpoint) Listen(int) tcpip.Error { + return &tcpip.ErrNotSupported{} } // Accept is not supported by UDP, it just fails. -func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { - return nil, nil, tcpip.ErrNotSupported +func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpip.Error) { + return nil, nil, &tcpip.ErrNotSupported{} } -func (e *endpoint) registerWithStack(netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.NICID, *tcpip.Error) { +func (e *endpoint) registerWithStack(netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.NICID, tcpip.Error) { bindToDevice := tcpip.NICID(e.ops.GetBindToDevice()) if e.ID.LocalPort == 0 { port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.portFlags, bindToDevice, tcpip.FullAddress{}, nil /* testPort */) @@ -1090,11 +1093,11 @@ func (e *endpoint) registerWithStack(netProtos []tcpip.NetworkProtocolNumber, id return id, bindToDevice, err } -func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { +func (e *endpoint) bindLocked(addr tcpip.FullAddress) tcpip.Error { // Don't allow binding once endpoint is not in the initial state // anymore. 
if e.EndpointState() != StateInitial { - return tcpip.ErrInvalidEndpointState + return &tcpip.ErrInvalidEndpointState{} } addr, netProto, err := e.checkV4MappedLocked(addr) @@ -1118,7 +1121,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { // A local unicast address was specified, verify that it's valid. nicID = e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) if nicID == 0 { - return tcpip.ErrBadLocalAddress + return &tcpip.ErrBadLocalAddress{} } } @@ -1148,7 +1151,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { // Bind binds the endpoint to a specific local address and port. // Specifying a NIC is optional. -func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { +func (e *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() @@ -1164,7 +1167,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { } // GetLocalAddress returns the address to which the endpoint is bound. -func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { +func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) { e.mu.RLock() defer e.mu.RUnlock() @@ -1181,12 +1184,12 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { } // GetRemoteAddress returns the address to which the endpoint is connected. 
-func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { +func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) { e.mu.RLock() defer e.mu.RUnlock() if e.EndpointState() != StateConnected { - return tcpip.FullAddress{}, tcpip.ErrNotConnected + return tcpip.FullAddress{}, &tcpip.ErrNotConnected{} } return tcpip.FullAddress{ @@ -1319,7 +1322,7 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB } } -func (e *endpoint) onICMPError(err *tcpip.Error, errType byte, errCode byte, extra uint32, pkt *stack.PacketBuffer) { +func (e *endpoint) onICMPError(err tcpip.Error, transErr stack.TransportError, pkt *stack.PacketBuffer) { // Update last error first. e.lastErrorMu.Lock() e.lastError = err @@ -1335,12 +1338,9 @@ func (e *endpoint) onICMPError(err *tcpip.Error, errType byte, errCode byte, ext } e.SocketOptions().QueueErr(&tcpip.SockError{ - Err: err, - ErrOrigin: header.ICMPOriginFromNetProto(pkt.NetworkProtocolNumber), - ErrType: errType, - ErrCode: errCode, - ErrInfo: extra, - Payload: payload, + Err: err, + Cause: transErr, + Payload: payload, Dst: tcpip.FullAddress{ NIC: pkt.NICID, Addr: e.ID.RemoteAddress, @@ -1359,24 +1359,13 @@ func (e *endpoint) onICMPError(err *tcpip.Error, errType byte, errCode byte, ext e.waiterQueue.Notify(waiter.EventErr) } -// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (e *endpoint) HandleControlPacket(typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { - if typ == stack.ControlPortUnreachable { +// HandleError implements stack.TransportEndpoint. +func (e *endpoint) HandleError(transErr stack.TransportError, pkt *stack.PacketBuffer) { + // TODO(gvisor.dev/issues/5270): Handle all transport errors. 
+ switch transErr.Kind() { + case stack.DestinationPortUnreachableTransportError: if e.EndpointState() == StateConnected { - var errType byte - var errCode byte - switch pkt.NetworkProtocolNumber { - case header.IPv4ProtocolNumber: - errType = byte(header.ICMPv4DstUnreachable) - errCode = byte(header.ICMPv4PortUnreachable) - case header.IPv6ProtocolNumber: - errType = byte(header.ICMPv6DstUnreachable) - errCode = byte(header.ICMPv6PortUnreachable) - default: - panic(fmt.Sprintf("unsupported net proto for infering ICMP type and code: %d", pkt.NetworkProtocolNumber)) - } - e.onICMPError(tcpip.ErrConnectionRefused, errType, errCode, extra, pkt) - return + e.onICMPError(&tcpip.ErrConnectionRefused{}, transErr, pkt) } } } diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go index feb53b553..21a6aa460 100644 --- a/pkg/tcpip/transport/udp/endpoint_state.go +++ b/pkg/tcpip/transport/udp/endpoint_state.go @@ -37,24 +37,6 @@ func (u *udpPacket) loadData(data buffer.VectorisedView) { u.data = data } -// saveLastError is invoked by stateify. -func (e *endpoint) saveLastError() string { - if e.lastError == nil { - return "" - } - - return e.lastError.String() -} - -// loadLastError is invoked by stateify. -func (e *endpoint) loadLastError(s string) { - if s == "" { - return - } - - e.lastError = tcpip.StringToError(s) -} - // beforeSave is invoked by stateify. func (e *endpoint) beforeSave() { // Stop incoming packets from being handled (and mutate endpoint state). 
@@ -114,7 +96,7 @@ func (e *endpoint) Resume(s *stack.Stack) { netProto = header.IPv6ProtocolNumber } - var err *tcpip.Error + var err tcpip.Error if state == StateConnected { e.route, err = e.stack.FindRoute(e.RegisterNICID, e.ID.LocalAddress, e.ID.RemoteAddress, netProto, e.ops.GetMulticastLoop()) if err != nil { @@ -123,7 +105,7 @@ func (e *endpoint) Resume(s *stack.Stack) { } else if len(e.ID.LocalAddress) != 0 && !e.isBroadcastOrMulticast(e.RegisterNICID, netProto, e.ID.LocalAddress) { // stateBound // A local unicast address is specified, verify that it's valid. if e.stack.CheckLocalAddress(e.RegisterNICID, netProto, e.ID.LocalAddress) == 0 { - panic(tcpip.ErrBadLocalAddress) + panic(&tcpip.ErrBadLocalAddress{}) } } diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go index aae794506..705ad1f64 100644 --- a/pkg/tcpip/transport/udp/forwarder.go +++ b/pkg/tcpip/transport/udp/forwarder.go @@ -69,7 +69,7 @@ func (r *ForwarderRequest) ID() stack.TransportEndpointID { } // CreateEndpoint creates a connected UDP endpoint for the session request. -func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { netHdr := r.pkt.Network() route, err := r.stack.FindRoute(r.pkt.NICID, netHdr.DestinationAddress(), netHdr.SourceAddress(), r.pkt.NetworkProtocolNumber, false /* multicastLoop */) if err != nil { diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index 91420edd3..427fdd0c9 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -54,13 +54,13 @@ func (*protocol) Number() tcpip.TransportProtocolNumber { } // NewEndpoint creates a new udp endpoint. 
-func (p *protocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (p *protocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { return newEndpoint(p.stack, netProto, waiterQueue), nil } // NewRawEndpoint creates a new raw UDP endpoint. It implements // stack.TransportProtocol.NewRawEndpoint. -func (p *protocol) NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (p *protocol) NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { return raw.NewEndpoint(p.stack, netProto, header.UDPProtocolNumber, waiterQueue) } @@ -71,7 +71,7 @@ func (*protocol) MinimumPacketSize() int { // ParsePorts returns the source and destination ports stored in the given udp // packet. -func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { +func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err tcpip.Error) { h := header.UDP(v) return h.SourcePort(), h.DestinationPort(), nil } @@ -94,13 +94,13 @@ func (p *protocol) HandleUnknownDestinationPacket(id stack.TransportEndpointID, } // SetOption implements stack.TransportProtocol.SetOption. -func (*protocol) SetOption(tcpip.SettableTransportProtocolOption) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption +func (*protocol) SetOption(tcpip.SettableTransportProtocolOption) tcpip.Error { + return &tcpip.ErrUnknownProtocolOption{} } // Option implements stack.TransportProtocol.Option. -func (*protocol) Option(tcpip.GettableTransportProtocolOption) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption +func (*protocol) Option(tcpip.GettableTransportProtocolOption) tcpip.Error { + return &tcpip.ErrUnknownProtocolOption{} } // Close implements stack.TransportProtocol.Close. 
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index c4794e876..5d81dbb94 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -353,7 +353,7 @@ func (c *testContext) cleanup() { func (c *testContext) createEndpoint(proto tcpip.NetworkProtocolNumber) { c.t.Helper() - var err *tcpip.Error + var err tcpip.Error c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, proto, &c.wq) if err != nil { c.t.Fatal("NewEndpoint failed: ", err) @@ -555,11 +555,11 @@ func TestBindToDeviceOption(t *testing.T) { testActions := []struct { name string setBindToDevice *tcpip.NICID - setBindToDeviceError *tcpip.Error + setBindToDeviceError tcpip.Error getBindToDevice int32 }{ {"GetDefaultValue", nil, nil, 0}, - {"BindToNonExistent", nicIDPtr(999), tcpip.ErrUnknownDevice, 0}, + {"BindToNonExistent", nicIDPtr(999), &tcpip.ErrUnknownDevice{}, 0}, {"BindToExistent", nicIDPtr(321), nil, 321}, {"UnbindToDevice", nicIDPtr(0), nil, 0}, } @@ -599,7 +599,7 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expe var buf bytes.Buffer res, err := c.ep.Read(&buf, tcpip.ReadOptions{NeedRemoteAddr: true}) - if err == tcpip.ErrWouldBlock { + if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for data to become available. select { case <-ch: @@ -703,8 +703,11 @@ func TestBindReservedPort(t *testing.T) { t.Fatalf("NewEndpoint failed: %s", err) } defer ep.Close() - if got, want := ep.Bind(addr), tcpip.ErrPortInUse; got != want { - t.Fatalf("got ep.Bind(...) = %s, want = %s", got, want) + { + err := ep.Bind(addr) + if _, ok := err.(*tcpip.ErrPortInUse); !ok { + t.Fatalf("got ep.Bind(...) = %s, want = %s", err, &tcpip.ErrPortInUse{}) + } } } @@ -716,8 +719,11 @@ func TestBindReservedPort(t *testing.T) { defer ep.Close() // We can't bind ipv4-any on the port reserved by the connected endpoint // above, since the endpoint is dual-stack. 
- if got, want := ep.Bind(tcpip.FullAddress{Port: addr.Port}), tcpip.ErrPortInUse; got != want { - t.Fatalf("got ep.Bind(...) = %s, want = %s", got, want) + { + err := ep.Bind(tcpip.FullAddress{Port: addr.Port}) + if _, ok := err.(*tcpip.ErrPortInUse); !ok { + t.Fatalf("got ep.Bind(...) = %s, want = %s", err, &tcpip.ErrPortInUse{}) + } } // We can bind an ipv4 address on this port, though. if err := ep.Bind(tcpip.FullAddress{Addr: stackAddr, Port: addr.Port}); err != nil { @@ -806,11 +812,11 @@ func TestV4ReadSelfSource(t *testing.T) { for _, tt := range []struct { name string handleLocal bool - wantErr *tcpip.Error + wantErr tcpip.Error wantInvalidSource uint64 }{ {"HandleLocal", false, nil, 0}, - {"NoHandleLocal", true, tcpip.ErrWouldBlock, 1}, + {"NoHandleLocal", true, &tcpip.ErrWouldBlock{}, 1}, } { t.Run(tt.name, func(t *testing.T) { c := newDualTestContextWithOptions(t, defaultMTU, stack.Options{ @@ -959,7 +965,7 @@ func TestV4ReadBroadcastOnBoundToWildcard(t *testing.T) { // testFailingWrite sends a packet of the given test flow into the UDP endpoint // and verifies it fails with the provided error code. -func testFailingWrite(c *testContext, flow testFlow, wantErr *tcpip.Error) { +func testFailingWrite(c *testContext, flow testFlow, wantErr tcpip.Error) { c.t.Helper() // Take a snapshot of the stats to validate them at the end of the test. epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() @@ -1092,7 +1098,7 @@ func TestDualWriteConnectedToV6(t *testing.T) { testWrite(c, unicastV6) // Write to V4 mapped address. - testFailingWrite(c, unicastV4in6, tcpip.ErrNetworkUnreachable) + testFailingWrite(c, unicastV4in6, &tcpip.ErrNetworkUnreachable{}) const want = 1 if got := c.ep.Stats().(*tcpip.TransportEndpointStats).SendErrors.NoRoute.Value(); got != want { c.t.Fatalf("Endpoint stat not updated. got %d want %d", got, want) @@ -1113,7 +1119,7 @@ func TestDualWriteConnectedToV4Mapped(t *testing.T) { testWrite(c, unicastV4in6) // Write to v6 address. 
- testFailingWrite(c, unicastV6, tcpip.ErrInvalidEndpointState) + testFailingWrite(c, unicastV6, &tcpip.ErrInvalidEndpointState{}) } func TestV4WriteOnV6Only(t *testing.T) { @@ -1123,7 +1129,7 @@ func TestV4WriteOnV6Only(t *testing.T) { c.createEndpointForFlow(unicastV6Only) // Write to V4 mapped address. - testFailingWrite(c, unicastV4in6, tcpip.ErrNoRoute) + testFailingWrite(c, unicastV4in6, &tcpip.ErrNoRoute{}) } func TestV6WriteOnBoundToV4Mapped(t *testing.T) { @@ -1138,7 +1144,7 @@ func TestV6WriteOnBoundToV4Mapped(t *testing.T) { } // Write to v6 address. - testFailingWrite(c, unicastV6, tcpip.ErrInvalidEndpointState) + testFailingWrite(c, unicastV6, &tcpip.ErrInvalidEndpointState{}) } func TestV6WriteOnConnected(t *testing.T) { @@ -1197,8 +1203,11 @@ func TestWriteOnConnectedInvalidPort(t *testing.T) { c.t.Fatalf("c.ep.Write(...) wrote %d bytes, want %d bytes", got, want) } - if err := c.ep.LastError(); err != tcpip.ErrConnectionRefused { - c.t.Fatalf("expected c.ep.LastError() == ErrConnectionRefused, got: %+v", err) + { + err := c.ep.LastError() + if _, ok := err.(*tcpip.ErrConnectionRefused); !ok { + c.t.Fatalf("expected c.ep.LastError() == ErrConnectionRefused, got: %+v", err) + } } }) } @@ -1605,7 +1614,7 @@ func TestTTL(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{p}, }) - ep := s.NetworkProtocolInstance(n).NewEndpoint(&testInterface{}, nil, nil, nil) + ep := s.NetworkProtocolInstance(n).NewEndpoint(&testInterface{}, nil) wantTTL = ep.DefaultTTL() ep.Close() } @@ -2308,21 +2317,21 @@ func TestShutdownWrite(t *testing.T) { t.Fatalf("Shutdown failed: %s", err) } - testFailingWrite(c, unicastV6, tcpip.ErrClosedForSend) + testFailingWrite(c, unicastV6, &tcpip.ErrClosedForSend{}) } -func (c *testContext) checkEndpointWriteStats(incr uint64, want tcpip.TransportEndpointStats, err *tcpip.Error) { +func (c *testContext) checkEndpointWriteStats(incr uint64, want tcpip.TransportEndpointStats, err tcpip.Error) { 
got := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() - switch err { + switch err.(type) { case nil: want.PacketsSent.IncrementBy(incr) - case tcpip.ErrMessageTooLong, tcpip.ErrInvalidOptionValue: + case *tcpip.ErrMessageTooLong, *tcpip.ErrInvalidOptionValue: want.WriteErrors.InvalidArgs.IncrementBy(incr) - case tcpip.ErrClosedForSend: + case *tcpip.ErrClosedForSend: want.WriteErrors.WriteClosed.IncrementBy(incr) - case tcpip.ErrInvalidEndpointState: + case *tcpip.ErrInvalidEndpointState: want.WriteErrors.InvalidEndpointState.IncrementBy(incr) - case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable: + case *tcpip.ErrNoRoute, *tcpip.ErrBroadcastDisabled, *tcpip.ErrNetworkUnreachable: want.SendErrors.NoRoute.IncrementBy(incr) default: want.SendErrors.SendToNetworkFailed.IncrementBy(incr) @@ -2332,11 +2341,11 @@ func (c *testContext) checkEndpointWriteStats(incr uint64, want tcpip.TransportE } } -func (c *testContext) checkEndpointReadStats(incr uint64, want tcpip.TransportEndpointStats, err *tcpip.Error) { +func (c *testContext) checkEndpointReadStats(incr uint64, want tcpip.TransportEndpointStats, err tcpip.Error) { got := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() - switch err { - case nil, tcpip.ErrWouldBlock: - case tcpip.ErrClosedForReceive: + switch err.(type) { + case nil, *tcpip.ErrWouldBlock: + case *tcpip.ErrClosedForReceive: want.ReadErrors.ReadClosed.IncrementBy(incr) default: c.t.Errorf("Endpoint error missing stats update err %v", err) @@ -2509,14 +2518,26 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { Port: 80, } opts := tcpip.WriteOptions{To: &to} - expectedErrWithoutBcastOpt := tcpip.ErrBroadcastDisabled + expectedErrWithoutBcastOpt := func(err tcpip.Error) tcpip.Error { + if _, ok := err.(*tcpip.ErrBroadcastDisabled); ok { + return nil + } + return &tcpip.ErrBroadcastDisabled{} + } if !test.requiresBroadcastOpt { expectedErrWithoutBcastOpt = nil } r.Reset(data) - if n, err := ep.Write(&r, opts); err != 
expectedErrWithoutBcastOpt { - t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, %s)", opts, n, err, expectedErrWithoutBcastOpt) + { + n, err := ep.Write(&r, opts) + if expectedErrWithoutBcastOpt != nil { + if want := expectedErrWithoutBcastOpt(err); want != nil { + t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, %s)", opts, n, err, want) + } + } else if err != nil { + t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, nil)", opts, n, err) + } } ep.SocketOptions().SetBroadcast(true) @@ -2529,8 +2550,15 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { ep.SocketOptions().SetBroadcast(false) r.Reset(data) - if n, err := ep.Write(&r, opts); err != expectedErrWithoutBcastOpt { - t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, %s)", opts, n, err, expectedErrWithoutBcastOpt) + { + n, err := ep.Write(&r, opts) + if expectedErrWithoutBcastOpt != nil { + if want := expectedErrWithoutBcastOpt(err); want != nil { + t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, %s)", opts, n, err, want) + } + } else if err != nil { + t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, nil)", opts, n, err) + } } }) } diff --git a/runsc/boot/events.go b/runsc/boot/events.go index 422f4da00..0814b2a69 100644 --- a/runsc/boot/events.go +++ b/runsc/boot/events.go @@ -15,21 +15,30 @@ package boot import ( - "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/control" "gvisor.dev/gvisor/pkg/sentry/usage" ) +// EventOut is the return type of the Event command. +type EventOut struct { + Event Event `json:"event"` + + // ContainerUsage maps each container ID to its total CPU usage. + ContainerUsage map[string]uint64 `json:"containerUsage"` +} + // Event struct for encoding the event data to JSON. Corresponds to runc's // main.event struct. 
type Event struct { - Type string `json:"type"` - ID string `json:"id"` - Data interface{} `json:"data,omitempty"` + Type string `json:"type"` + ID string `json:"id"` + Data Stats `json:"data"` } // Stats is the runc specific stats structure for stability when encoding and // decoding stats. type Stats struct { + CPU CPU `json:"cpu"` Memory Memory `json:"memory"` Pids Pids `json:"pids"` } @@ -58,24 +67,42 @@ type Memory struct { Raw map[string]uint64 `json:"raw,omitempty"` } -// Event gets the events from the container. -func (cm *containerManager) Event(_ *struct{}, out *Event) error { - stats := &Stats{} - stats.populateMemory(cm.l.k) - stats.populatePIDs(cm.l.k) - *out = Event{Type: "stats", Data: stats} - return nil +// CPU contains stats on the CPU. +type CPU struct { + Usage CPUUsage `json:"usage"` +} + +// CPUUsage contains stats on CPU usage. +type CPUUsage struct { + Kernel uint64 `json:"kernel,omitempty"` + User uint64 `json:"user,omitempty"` + Total uint64 `json:"total,omitempty"` + PerCPU []uint64 `json:"percpu,omitempty"` } -func (s *Stats) populateMemory(k *kernel.Kernel) { - mem := k.MemoryFile() +// Event gets the events from the container. +func (cm *containerManager) Event(_ *struct{}, out *EventOut) error { + *out = EventOut{ + Event: Event{ + Type: "stats", + }, + } + + // Memory usage. + // TODO(gvisor.dev/issue/172): Per-container accounting. + mem := cm.l.k.MemoryFile() mem.UpdateUsage() _, totalUsage := usage.MemoryAccounting.Copy() - s.Memory.Usage = MemoryEntry{ + out.Event.Data.Memory.Usage = MemoryEntry{ Usage: totalUsage, } -} -func (s *Stats) populatePIDs(k *kernel.Kernel) { - s.Pids.Current = uint64(len(k.TaskSet().Root.ThreadGroups())) + // PIDs. + // TODO(gvisor.dev/issue/172): Per-container accounting. + out.Event.Data.Pids.Current = uint64(len(cm.l.k.TaskSet().Root.ThreadGroups())) + + // CPU usage by container. 
+ out.ContainerUsage = control.ContainerUsage(cm.l.k) + + return nil } diff --git a/runsc/boot/loader.go b/runsc/boot/loader.go index d37528ee7..77a7c530b 100644 --- a/runsc/boot/loader.go +++ b/runsc/boot/loader.go @@ -102,7 +102,7 @@ type containerInfo struct { goferFDs []*fd.FD } -// Loader keeps state needed to start the kernel and run the container.. +// Loader keeps state needed to start the kernel and run the container. type Loader struct { // k is the kernel. k *kernel.Kernel diff --git a/runsc/cgroup/cgroup.go b/runsc/cgroup/cgroup.go index 13c6a16a0..797c1c2bc 100644 --- a/runsc/cgroup/cgroup.go +++ b/runsc/cgroup/cgroup.go @@ -203,6 +203,19 @@ func LoadPaths(pid string) (map[string]string, error) { } func loadPathsHelper(cgroup io.Reader) (map[string]string, error) { + // For nested containers, in /proc/self/cgroup we see paths from host, + // which don't exist in container, so recover the container paths here by + // double-checking with /proc/pid/mountinfo + mountinfo, err := os.Open("/proc/self/mountinfo") + if err != nil { + return nil, err + } + defer mountinfo.Close() + + return loadPathsHelperWithMountinfo(cgroup, mountinfo) +} + +func loadPathsHelperWithMountinfo(cgroup, mountinfo io.Reader) (map[string]string, error) { paths := make(map[string]string) scanner := bufio.NewScanner(cgroup) @@ -225,6 +238,31 @@ func loadPathsHelper(cgroup io.Reader) (map[string]string, error) { if err := scanner.Err(); err != nil { return nil, err } + + mfScanner := bufio.NewScanner(mountinfo) + for mfScanner.Scan() { + txt := mfScanner.Text() + fields := strings.Fields(txt) + if len(fields) < 9 || fields[len(fields)-3] != "cgroup" { + continue + } + for _, opt := range strings.Split(fields[len(fields)-1], ",") { + // Remove prefix for cgroups with no controller, eg. systemd. 
+ opt = strings.TrimPrefix(opt, "name=") + if cgroupPath, ok := paths[opt]; ok { + root := fields[3] + relCgroupPath, err := filepath.Rel(root, cgroupPath) + if err != nil { + return nil, err + } + paths[opt] = relCgroupPath + } + } + } + if err := mfScanner.Err(); err != nil { + return nil, err + } + return paths, nil } @@ -243,8 +281,13 @@ func New(spec *specs.Spec) (*Cgroup, error) { if spec.Linux == nil || spec.Linux.CgroupsPath == "" { return nil, nil } + return NewFromPath(spec.Linux.CgroupsPath) +} + +// NewFromPath creates a new Cgroup instance. +func NewFromPath(cgroupsPath string) (*Cgroup, error) { var parents map[string]string - if !filepath.IsAbs(spec.Linux.CgroupsPath) { + if !filepath.IsAbs(cgroupsPath) { var err error parents, err = LoadPaths("self") if err != nil { @@ -253,7 +296,7 @@ func New(spec *specs.Spec) (*Cgroup, error) { } own := make(map[string]bool) return &Cgroup{ - Name: spec.Linux.CgroupsPath, + Name: cgroupsPath, Parents: parents, Own: own, }, nil @@ -351,6 +394,9 @@ func (c *Cgroup) Join() (func(), error) { undo = func() { for _, path := range undoPaths { log.Debugf("Restoring cgroup %q", path) + // Writing the value 0 to a cgroup.procs file causes + // the writing process to be moved to the corresponding + // cgroup. - cgroups(7). if err := setValue(path, "cgroup.procs", "0"); err != nil { log.Warningf("Error restoring cgroup %q: %v", path, err) } @@ -361,6 +407,9 @@ func (c *Cgroup) Join() (func(), error) { for key, cfg := range controllers { path := c.makePath(key) log.Debugf("Joining cgroup %q", path) + // Writing the value 0 to a cgroup.procs file causes the + // writing process to be moved to the corresponding cgroup. + // - cgroups(7). if err := setValue(path, "cgroup.procs", "0"); err != nil { if cfg.optional && os.IsNotExist(err) { continue @@ -388,6 +437,16 @@ func (c *Cgroup) CPUQuota() (float64, error) { return float64(quota) / float64(period), nil } +// CPUUsage returns the total CPU usage of the cgroup. 
+func (c *Cgroup) CPUUsage() (uint64, error) { + path := c.makePath("cpuacct") + usage, err := getValue(path, "cpuacct.usage") + if err != nil { + return 0, err + } + return strconv.ParseUint(strings.TrimSpace(usage), 10, 64) +} + // NumCPU returns the number of CPUs configured in 'cpuset/cpuset.cpus'. func (c *Cgroup) NumCPU() (int, error) { path := c.makePath("cpuset") diff --git a/runsc/cgroup/cgroup_test.go b/runsc/cgroup/cgroup_test.go index 931144cf9..48d71cfa6 100644 --- a/runsc/cgroup/cgroup_test.go +++ b/runsc/cgroup/cgroup_test.go @@ -25,6 +25,39 @@ import ( "gvisor.dev/gvisor/pkg/test/testutil" ) +var debianMountinfo = ` +35 24 0:30 / /sys/fs/cgroup ro shared:9 - tmpfs tmpfs ro +36 35 0:31 / /sys/fs/cgroup/unified rw shared:10 - cgroup2 cgroup2 rw +37 35 0:32 / /sys/fs/cgroup/systemd rw - cgroup cgroup rw,name=systemd +41 35 0:36 / /sys/fs/cgroup/cpu,cpuacct rw shared:16 - cgroup cgroup rw,cpu,cpuacct +42 35 0:37 / /sys/fs/cgroup/freezer rw shared:17 - cgroup cgroup rw,freezer +43 35 0:38 / /sys/fs/cgroup/hugetlb rw shared:18 - cgroup cgroup rw,hugetlb +44 35 0:39 / /sys/fs/cgroup/cpuset rw shared:19 - cgroup cgroup rw,cpuset +45 35 0:40 / /sys/fs/cgroup/net_cls,net_prio rw shared:20 - cgroup cgroup rw,net_cls,net_prio +46 35 0:41 / /sys/fs/cgroup/pids rw shared:21 - cgroup cgroup rw,pids +47 35 0:42 / /sys/fs/cgroup/perf_event rw shared:22 - cgroup cgroup rw,perf_event +48 35 0:43 / /sys/fs/cgroup/memory rw shared:23 - cgroup cgroup rw,memory +49 35 0:44 / /sys/fs/cgroup/blkio rw shared:24 - cgroup cgroup rw,blkio +50 35 0:45 / /sys/fs/cgroup/devices rw shared:25 - cgroup cgroup rw,devices +51 35 0:46 / /sys/fs/cgroup/rdma rw shared:26 - cgroup cgroup rw,rdma +` + +var dindMountinfo = ` +1305 1304 0:64 / /sys/fs/cgroup rw - tmpfs tmpfs rw,mode=755 +1306 1305 0:32 /docker/136 /sys/fs/cgroup/systemd ro master:11 - cgroup cgroup rw,xattr,name=systemd +1307 1305 0:36 /docker/136 /sys/fs/cgroup/cpu,cpuacct ro master:16 - cgroup cgroup rw,cpu,cpuacct +1308 
1305 0:37 /docker/136 /sys/fs/cgroup/freezer ro master:17 - cgroup cgroup rw,freezer +1309 1305 0:38 /docker/136 /sys/fs/cgroup/hugetlb ro master:18 - cgroup cgroup rw,hugetlb +1310 1305 0:39 /docker/136 /sys/fs/cgroup/cpuset ro master:19 - cgroup cgroup rw,cpuset +1311 1305 0:40 /docker/136 /sys/fs/cgroup/net_cls,net_prio ro master:20 - cgroup cgroup rw,net_cls,net_prio +1312 1305 0:41 /docker/136 /sys/fs/cgroup/pids ro master:21 - cgroup cgroup rw,pids +1313 1305 0:42 /docker/136 /sys/fs/cgroup/perf_event ro master:22 - cgroup cgroup rw,perf_event +1314 1305 0:43 /docker/136 /sys/fs/cgroup/memory ro master:23 - cgroup cgroup rw,memory +1316 1305 0:44 /docker/136 /sys/fs/cgroup/blkio ro master:24 - cgroup cgroup rw,blkio +1317 1305 0:45 /docker/136 /sys/fs/cgroup/devices ro master:25 - cgroup cgroup rw,devices +1318 1305 0:46 / /sys/fs/cgroup/rdma ro master:26 - cgroup cgroup rw,rdma +` + func TestUninstallEnoent(t *testing.T) { c := Cgroup{ // set a non-existent name @@ -653,60 +686,110 @@ func TestPids(t *testing.T) { func TestLoadPaths(t *testing.T) { for _, tc := range []struct { - name string - cgroups string - want map[string]string - err string + name string + cgroups string + mountinfo string + want map[string]string + err string }{ { - name: "abs-path", - cgroups: "0:ctr:/path", - want: map[string]string{"ctr": "/path"}, + name: "abs-path-unknown-controller", + cgroups: "0:ctr:/path", + mountinfo: debianMountinfo, + want: map[string]string{"ctr": "/path"}, }, { - name: "rel-path", - cgroups: "0:ctr:rel-path", - want: map[string]string{"ctr": "rel-path"}, + name: "rel-path", + cgroups: "0:ctr:rel-path", + mountinfo: debianMountinfo, + want: map[string]string{"ctr": "rel-path"}, }, { - name: "non-controller", - cgroups: "0:name=systemd:/path", - want: map[string]string{"systemd": "/path"}, + name: "non-controller", + cgroups: "0:name=systemd:/path", + mountinfo: debianMountinfo, + want: map[string]string{"systemd": "path"}, }, { - name: "empty", + name: 
"empty", + mountinfo: debianMountinfo, }, { name: "multiple", cgroups: "0:ctr0:/path0\n" + "1:ctr1:/path1\n" + "2::/empty\n", + mountinfo: debianMountinfo, want: map[string]string{ "ctr0": "/path0", "ctr1": "/path1", }, }, { - name: "missing-field", - cgroups: "0:nopath\n", - err: "invalid cgroups file", + name: "missing-field", + cgroups: "0:nopath\n", + mountinfo: debianMountinfo, + err: "invalid cgroups file", }, { - name: "too-many-fields", - cgroups: "0:ctr:/path:extra\n", - err: "invalid cgroups file", + name: "too-many-fields", + cgroups: "0:ctr:/path:extra\n", + mountinfo: debianMountinfo, + err: "invalid cgroups file", }, { name: "multiple-malformed", cgroups: "0:ctr0:/path0\n" + "1:ctr1:/path1\n" + "2:\n", - err: "invalid cgroups file", + mountinfo: debianMountinfo, + err: "invalid cgroups file", + }, + { + name: "nested-cgroup", + cgroups: `9:memory:/docker/136 +2:cpu,cpuacct:/docker/136 +1:name=systemd:/docker/136 +0::/system.slice/containerd.service`, + mountinfo: dindMountinfo, + // we want relative path to /sys/fs/cgroup inside the nested container. 
+ // Subcroup inside the container will be created at /sys/fs/cgroup/cpu + // This will be /sys/fs/cgroup/cpu/docker/136/CGROUP_NAME + // outside the container + want: map[string]string{ + "memory": ".", + "cpu": ".", + "cpuacct": ".", + "systemd": ".", + }, + }, + { + name: "nested-cgroup-submount", + cgroups: "9:memory:/docker/136/test", + mountinfo: dindMountinfo, + want: map[string]string{ + "memory": "test", + }, + }, + { + name: "invalid-mount-info", + cgroups: "0:memory:/path", + mountinfo: "41 35 0:36 / /sys/fs/cgroup/memory rw shared:16 - invalid", + want: map[string]string{ + "memory": "/path", + }, + }, + { + name: "invalid-rel-path-in-proc-cgroup", + cgroups: "9:memory:./invalid", + mountinfo: dindMountinfo, + err: "can't make ./invalid relative to /docker/136", }, } { t.Run(tc.name, func(t *testing.T) { r := strings.NewReader(tc.cgroups) - got, err := loadPathsHelper(r) + mountinfo := strings.NewReader(tc.mountinfo) + got, err := loadPathsHelperWithMountinfo(r, mountinfo) if len(tc.err) == 0 { if err != nil { t.Fatalf("Unexpected error: %v", err) diff --git a/runsc/cmd/events.go b/runsc/cmd/events.go index 75b0aac8d..06f00e8e7 100644 --- a/runsc/cmd/events.go +++ b/runsc/cmd/events.go @@ -93,9 +93,9 @@ func (evs *Events) Execute(ctx context.Context, f *flag.FlagSet, args ...interfa // err must be preserved because it is used below when breaking // out of the loop. - b, err := json.Marshal(ev) + b, err := json.Marshal(ev.Event) if err != nil { - log.Warningf("Error while marshalling event %v: %v", ev, err) + log.Warningf("Error while marshalling event %v: %v", ev.Event, err) } else { os.Stdout.Write(b) } diff --git a/runsc/container/console_test.go b/runsc/container/console_test.go index 1b0fdebd6..7a3d5a523 100644 --- a/runsc/container/console_test.go +++ b/runsc/container/console_test.go @@ -122,7 +122,7 @@ func receiveConsolePTY(srv *unet.ServerSocket) (*os.File, error) { // Test that an pty FD is sent over the console socket if one is provided. 
func TestConsoleSocket(t *testing.T) { - for name, conf := range configsWithVFS2(t, all...) { + for name, conf := range configs(t, all...) { t.Run(name, func(t *testing.T) { spec := testutil.NewSpecWithArgs("true") spec.Process.Terminal = true @@ -164,7 +164,7 @@ func TestConsoleSocket(t *testing.T) { // Test that an pty FD is sent over the console socket if one is provided. func TestMultiContainerConsoleSocket(t *testing.T) { - for name, conf := range configsWithVFS2(t, all...) { + for name, conf := range configs(t, all...) { t.Run(name, func(t *testing.T) { rootDir, cleanup, err := testutil.SetupRootDir() if err != nil { @@ -495,7 +495,7 @@ func TestJobControlSignalRootContainer(t *testing.T) { // Test that terminal works with root and sub-containers. func TestMultiContainerTerminal(t *testing.T) { - for name, conf := range configsWithVFS2(t, all...) { + for name, conf := range configs(t, all...) { t.Run(name, func(t *testing.T) { rootDir, cleanup, err := testutil.SetupRootDir() if err != nil { diff --git a/runsc/container/container.go b/runsc/container/container.go index 5a0f8d5dc..aae64ae1c 100644 --- a/runsc/container/container.go +++ b/runsc/container/container.go @@ -486,12 +486,20 @@ func (c *Container) Execute(args *control.ExecArgs) (int32, error) { } // Event returns events for the container. -func (c *Container) Event() (*boot.Event, error) { +func (c *Container) Event() (*boot.EventOut, error) { log.Debugf("Getting events for container, cid: %s", c.ID) if err := c.requireStatus("get events for", Created, Running, Paused); err != nil { return nil, err } - return c.Sandbox.Event(c.ID) + event, err := c.Sandbox.Event(c.ID) + if err != nil { + return nil, err + } + + // Some stats can utilize host cgroups for accuracy. 
+ c.populateStats(event) + + return event, nil } // SandboxPid returns the Pid of the sandbox the container is running in, or -1 if the @@ -1110,3 +1118,54 @@ func setOOMScoreAdj(pid int, scoreAdj int) error { } return nil } + +// populateStats populates event with stats estimates based on cgroups and the +// sentry's accounting. +// TODO(gvisor.dev/issue/172): This is an estimation; we should do more +// detailed accounting. +func (c *Container) populateStats(event *boot.EventOut) { + // The events command, when run for all running containers, should + // account for the full cgroup CPU usage. We split cgroup usage + // proportionally according to the sentry-internal usage measurements, + // only counting Running containers. + log.Warningf("event.ContainerUsage: %v", event.ContainerUsage) + var containerUsage uint64 + var allContainersUsage uint64 + for ID, usage := range event.ContainerUsage { + allContainersUsage += usage + if ID == c.ID { + containerUsage = usage + } + } + + cgroup, err := c.Sandbox.FindCgroup() + if err != nil { + // No cgroup, so rely purely on the sentry's accounting. + log.Warningf("events: no cgroups") + event.Event.Data.CPU.Usage.Total = containerUsage + return + } + + // Get the host cgroup CPU usage. + cgroupsUsage, err := cgroup.CPUUsage() + if err != nil { + // No cgroup usage, so rely purely on the sentry's accounting. + log.Warningf("events: failed when getting cgroup CPU usage for container: %v", err) + event.Event.Data.CPU.Usage.Total = containerUsage + return + } + + // If the sentry reports no CPU usage, fall back on cgroups and + // split usage equally across containers. + if allContainersUsage == 0 { + log.Warningf("events: no sentry CPU usage reported") + allContainersUsage = cgroupsUsage + containerUsage = cgroupsUsage / uint64(len(event.ContainerUsage)) + } + + log.Warningf("%f, %f, %f", containerUsage, cgroupsUsage, allContainersUsage) + // Scaling can easily overflow a uint64 (e.g. 
a containerUsage and + // cgroupsUsage of 16 seconds each will overflow), so use floats. + event.Event.Data.CPU.Usage.Total = uint64(float64(containerUsage) * (float64(cgroupsUsage) / float64(allContainersUsage))) + return +} diff --git a/runsc/container/container_test.go b/runsc/container/container_test.go index 3bbf86534..d50bbcd9f 100644 --- a/runsc/container/container_test.go +++ b/runsc/container/container_test.go @@ -312,8 +312,7 @@ var ( all = append(noOverlay, overlay) ) -// configs generates different configurations to run tests. -func configs(t *testing.T, opts ...configOption) map[string]*config.Config { +func configsHelper(t *testing.T, opts ...configOption) map[string]*config.Config { // Always load the default config. cs := make(map[string]*config.Config) testutil.TestConfig(t) @@ -339,10 +338,12 @@ func configs(t *testing.T, opts ...configOption) map[string]*config.Config { return cs } -// TODO(gvisor.dev/issue/1624): Merge with configs when VFS2 is the default. -func configsWithVFS2(t *testing.T, opts ...configOption) map[string]*config.Config { - all := configs(t, opts...) - for key, value := range configs(t, opts...) { +// configs generates different configurations to run tests. +// +// TODO(gvisor.dev/issue/1624): Remove VFS1 dimension. +func configs(t *testing.T, opts ...configOption) map[string]*config.Config { + all := configsHelper(t, opts...) + for key, value := range configsHelper(t, opts...) { value.VFS2 = true all[key+"VFS2"] = value } @@ -358,7 +359,7 @@ func TestLifecycle(t *testing.T) { childReaper.Start() defer childReaper.Stop() - for name, conf := range configsWithVFS2(t, all...) { + for name, conf := range configs(t, all...) { t.Run(name, func(t *testing.T) { // The container will just sleep for a long time. We will kill it before // it finishes sleeping. @@ -529,7 +530,7 @@ func TestExePath(t *testing.T) { t.Fatalf("error making directory: %v", err) } - for name, conf := range configsWithVFS2(t, all...) 
{ + for name, conf := range configs(t, all...) { t.Run(name, func(t *testing.T) { for _, test := range []struct { path string @@ -654,7 +655,7 @@ func doAppExitStatus(t *testing.T, vfs2 bool) { // TestExec verifies that a container can exec a new program. func TestExec(t *testing.T) { - for name, conf := range configsWithVFS2(t, all...) { + for name, conf := range configs(t, all...) { t.Run(name, func(t *testing.T) { dir, err := ioutil.TempDir(testutil.TmpDir(), "exec-test") if err != nil { @@ -783,7 +784,7 @@ func TestExec(t *testing.T) { // TestExecProcList verifies that a container can exec a new program and it // shows correcly in the process list. func TestExecProcList(t *testing.T) { - for name, conf := range configsWithVFS2(t, all...) { + for name, conf := range configs(t, all...) { t.Run(name, func(t *testing.T) { const uid = 343 spec := testutil.NewSpecWithArgs("sleep", "100") @@ -854,7 +855,7 @@ func TestExecProcList(t *testing.T) { // TestKillPid verifies that we can signal individual exec'd processes. func TestKillPid(t *testing.T) { - for name, conf := range configsWithVFS2(t, all...) { + for name, conf := range configs(t, all...) { t.Run(name, func(t *testing.T) { app, err := testutil.FindFile("test/cmd/test_app/test_app") if err != nil { @@ -930,7 +931,6 @@ func TestKillPid(t *testing.T) { // number after the last number from the checkpointed container. func TestCheckpointRestore(t *testing.T) { // Skip overlay because test requires writing to host file. - // TODO(gvisor.dev/issue/1663): Add VFS when S/R support is added. for name, conf := range configs(t, noOverlay...) { t.Run(name, func(t *testing.T) { dir, err := ioutil.TempDir(testutil.TmpDir(), "checkpoint-test") @@ -1092,7 +1092,6 @@ func TestCheckpointRestore(t *testing.T) { // with filesystem Unix Domain Socket use. func TestUnixDomainSockets(t *testing.T) { // Skip overlay because test requires writing to host file. - // TODO(gvisor.dev/issue/1663): Add VFS when S/R support is added. 
for name, conf := range configs(t, noOverlay...) { t.Run(name, func(t *testing.T) { // UDS path is limited to 108 chars for compatibility with older systems. @@ -1230,7 +1229,7 @@ func TestUnixDomainSockets(t *testing.T) { // recreated. Then it resumes the container, verify that the file gets created // again. func TestPauseResume(t *testing.T) { - for name, conf := range configsWithVFS2(t, noOverlay...) { + for name, conf := range configs(t, noOverlay...) { t.Run(name, func(t *testing.T) { tmpDir, err := ioutil.TempDir(testutil.TmpDir(), "lock") if err != nil { @@ -1373,7 +1372,7 @@ func TestCapabilities(t *testing.T) { uid := auth.KUID(os.Getuid() + 1) gid := auth.KGID(os.Getgid() + 1) - for name, conf := range configsWithVFS2(t, all...) { + for name, conf := range configs(t, all...) { t.Run(name, func(t *testing.T) { spec := testutil.NewSpecWithArgs("sleep", "100") rootDir, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf) @@ -1446,7 +1445,7 @@ func TestCapabilities(t *testing.T) { // TestRunNonRoot checks that sandbox can be configured when running as // non-privileged user. func TestRunNonRoot(t *testing.T) { - for name, conf := range configsWithVFS2(t, noOverlay...) { + for name, conf := range configs(t, noOverlay...) { t.Run(name, func(t *testing.T) { spec := testutil.NewSpecWithArgs("/bin/true") @@ -1490,7 +1489,7 @@ func TestRunNonRoot(t *testing.T) { // TestMountNewDir checks that runsc will create destination directory if it // doesn't exit. func TestMountNewDir(t *testing.T) { - for name, conf := range configsWithVFS2(t, all...) { + for name, conf := range configs(t, all...) { t.Run(name, func(t *testing.T) { root, err := ioutil.TempDir(testutil.TmpDir(), "root") if err != nil { @@ -1521,7 +1520,7 @@ func TestMountNewDir(t *testing.T) { } func TestReadonlyRoot(t *testing.T) { - for name, conf := range configsWithVFS2(t, all...) { + for name, conf := range configs(t, all...) 
{ t.Run(name, func(t *testing.T) { spec := testutil.NewSpecWithArgs("sleep", "100") spec.Root.Readonly = true @@ -1569,7 +1568,7 @@ func TestReadonlyRoot(t *testing.T) { } func TestReadonlyMount(t *testing.T) { - for name, conf := range configsWithVFS2(t, all...) { + for name, conf := range configs(t, all...) { t.Run(name, func(t *testing.T) { dir, err := ioutil.TempDir(testutil.TmpDir(), "ro-mount") if err != nil { @@ -1628,7 +1627,7 @@ func TestReadonlyMount(t *testing.T) { } func TestUIDMap(t *testing.T) { - for name, conf := range configsWithVFS2(t, noOverlay...) { + for name, conf := range configs(t, noOverlay...) { t.Run(name, func(t *testing.T) { testDir, err := ioutil.TempDir(testutil.TmpDir(), "test-mount") if err != nil { @@ -1916,7 +1915,7 @@ func TestUserLog(t *testing.T) { } func TestWaitOnExitedSandbox(t *testing.T) { - for name, conf := range configsWithVFS2(t, all...) { + for name, conf := range configs(t, all...) { t.Run(name, func(t *testing.T) { // Run a shell that sleeps for 1 second and then exits with a // non-zero code. @@ -2058,7 +2057,7 @@ func doDestroyStartingTest(t *testing.T, vfs2 bool) { } func TestCreateWorkingDir(t *testing.T) { - for name, conf := range configsWithVFS2(t, all...) { + for name, conf := range configs(t, all...) { t.Run(name, func(t *testing.T) { tmpDir, err := ioutil.TempDir(testutil.TmpDir(), "cwd-create") if err != nil { @@ -2173,7 +2172,7 @@ func TestMountPropagation(t *testing.T) { } func TestMountSymlink(t *testing.T) { - for name, conf := range configsWithVFS2(t, all...) { + for name, conf := range configs(t, all...) 
{ t.Run(name, func(t *testing.T) { dir, err := ioutil.TempDir(testutil.TmpDir(), "mount-symlink") if err != nil { diff --git a/runsc/container/multi_container_test.go b/runsc/container/multi_container_test.go index bc802e075..173332cc2 100644 --- a/runsc/container/multi_container_test.go +++ b/runsc/container/multi_container_test.go @@ -15,7 +15,6 @@ package container import ( - "encoding/json" "fmt" "io/ioutil" "math" @@ -132,7 +131,7 @@ func createSharedMount(mount specs.Mount, name string, pod ...*specs.Spec) { // TestMultiContainerSanity checks that it is possible to run 2 dead-simple // containers in the same sandbox. func TestMultiContainerSanity(t *testing.T) { - for name, conf := range configsWithVFS2(t, all...) { + for name, conf := range configs(t, all...) { t.Run(name, func(t *testing.T) { rootDir, cleanup, err := testutil.SetupRootDir() if err != nil { @@ -170,7 +169,7 @@ func TestMultiContainerSanity(t *testing.T) { // TestMultiPIDNS checks that it is possible to run 2 dead-simple // containers in the same sandbox with different pidns. func TestMultiPIDNS(t *testing.T) { - for name, conf := range configsWithVFS2(t, all...) { + for name, conf := range configs(t, all...) { t.Run(name, func(t *testing.T) { rootDir, cleanup, err := testutil.SetupRootDir() if err != nil { @@ -215,7 +214,7 @@ func TestMultiPIDNS(t *testing.T) { // TestMultiPIDNSPath checks the pidns path. func TestMultiPIDNSPath(t *testing.T) { - for name, conf := range configsWithVFS2(t, all...) { + for name, conf := range configs(t, all...) { t.Run(name, func(t *testing.T) { rootDir, cleanup, err := testutil.SetupRootDir() if err != nil { @@ -322,8 +321,8 @@ func TestMultiContainerWait(t *testing.T) { } } -// TestExecWait ensures what we can wait containers and individual processes in the -// sandbox that have already exited. +// TestExecWait ensures what we can wait on containers and individual processes +// in the sandbox that have already exited. 
func TestExecWait(t *testing.T) { rootDir, cleanup, err := testutil.SetupRootDir() if err != nil { @@ -448,7 +447,7 @@ func TestMultiContainerMount(t *testing.T) { // TestMultiContainerSignal checks that it is possible to signal individual // containers without killing the entire sandbox. func TestMultiContainerSignal(t *testing.T) { - for name, conf := range configsWithVFS2(t, all...) { + for name, conf := range configs(t, all...) { t.Run(name, func(t *testing.T) { rootDir, cleanup, err := testutil.SetupRootDir() if err != nil { @@ -548,7 +547,7 @@ func TestMultiContainerDestroy(t *testing.T) { t.Fatal("error finding test_app:", err) } - for name, conf := range configsWithVFS2(t, all...) { + for name, conf := range configs(t, all...) { t.Run(name, func(t *testing.T) { rootDir, cleanup, err := testutil.SetupRootDir() if err != nil { @@ -1042,7 +1041,7 @@ func TestMultiContainerContainerDestroyStress(t *testing.T) { // Test that pod shared mounts are properly mounted in 2 containers and that // changes from one container is reflected in the other. func TestMultiContainerSharedMount(t *testing.T) { - for name, conf := range configsWithVFS2(t, all...) { + for name, conf := range configs(t, all...) { t.Run(name, func(t *testing.T) { rootDir, cleanup, err := testutil.SetupRootDir() if err != nil { @@ -1155,7 +1154,7 @@ func TestMultiContainerSharedMount(t *testing.T) { // Test that pod mounts are mounted as readonly when requested. func TestMultiContainerSharedMountReadonly(t *testing.T) { - for name, conf := range configsWithVFS2(t, all...) { + for name, conf := range configs(t, all...) { t.Run(name, func(t *testing.T) { rootDir, cleanup, err := testutil.SetupRootDir() if err != nil { @@ -1220,7 +1219,7 @@ func TestMultiContainerSharedMountReadonly(t *testing.T) { // Test that shared pod mounts continue to work after container is restarted. func TestMultiContainerSharedMountRestart(t *testing.T) { - for name, conf := range configsWithVFS2(t, all...) 
{ + for name, conf := range configs(t, all...) { t.Run(name, func(t *testing.T) { rootDir, cleanup, err := testutil.SetupRootDir() if err != nil { @@ -1329,7 +1328,7 @@ func TestMultiContainerSharedMountRestart(t *testing.T) { // Test that unsupported pod mounts options are ignored when matching master and // replica mounts. func TestMultiContainerSharedMountUnsupportedOptions(t *testing.T) { - for name, conf := range configsWithVFS2(t, all...) { + for name, conf := range configs(t, all...) { t.Run(name, func(t *testing.T) { rootDir, cleanup, err := testutil.SetupRootDir() if err != nil { @@ -1663,7 +1662,7 @@ func TestMultiContainerRunNonRoot(t *testing.T) { func TestMultiContainerHomeEnvDir(t *testing.T) { // NOTE: Don't use overlay since we need changes to persist to the temp dir // outside the sandbox. - for testName, conf := range configsWithVFS2(t, noOverlay...) { + for testName, conf := range configs(t, noOverlay...) { t.Run(testName, func(t *testing.T) { rootDir, cleanup, err := testutil.SetupRootDir() @@ -1743,8 +1742,9 @@ func TestMultiContainerEvent(t *testing.T) { // Setup the containers. sleep := []string{"/bin/sleep", "100"} + busy := []string{"/bin/bash", "-c", "i=0 ; while true ; do (( i += 1 )) ; done"} quick := []string{"/bin/true"} - podSpec, ids := createSpecs(sleep, sleep, quick) + podSpec, ids := createSpecs(sleep, busy, quick) containers, cleanup, err := startContainers(conf, podSpec, ids) if err != nil { t.Fatalf("error starting containers: %v", err) @@ -1755,37 +1755,58 @@ func TestMultiContainerEvent(t *testing.T) { t.Logf("Running containerd %s", cont.ID) } - // Wait for last container to stabilize the process count that is checked - // further below. + // Wait for last container to stabilize the process count that is + // checked further below. 
if ws, err := containers[2].Wait(); err != nil || ws != 0 { t.Fatalf("Container.Wait, status: %v, err: %v", ws, err) } + expectedPL := []*control.Process{ + newProcessBuilder().Cmd("sleep").Process(), + } + if err := waitForProcessList(containers[0], expectedPL); err != nil { + t.Errorf("failed to wait for sleep to start: %v", err) + } + expectedPL = []*control.Process{ + newProcessBuilder().Cmd("bash").Process(), + } + if err := waitForProcessList(containers[1], expectedPL); err != nil { + t.Errorf("failed to wait for bash to start: %v", err) + } // Check events for running containers. + var prevUsage uint64 for _, cont := range containers[:2] { - evt, err := cont.Event() + ret, err := cont.Event() if err != nil { t.Errorf("Container.Events(): %v", err) } + evt := ret.Event if want := "stats"; evt.Type != want { - t.Errorf("Wrong event type, want: %s, got :%s", want, evt.Type) + t.Errorf("Wrong event type, want: %s, got: %s", want, evt.Type) } if cont.ID != evt.ID { - t.Errorf("Wrong container ID, want: %s, got :%s", cont.ID, evt.ID) + t.Errorf("Wrong container ID, want: %s, got: %s", cont.ID, evt.ID) } - // Event.Data is an interface, so it comes from the wire was - // map[string]string. Marshal and unmarshall again to the correc type. - data, err := json.Marshal(evt.Data) - if err != nil { - t.Fatalf("invalid event data: %v", err) + // One process per remaining container. + if got, want := evt.Data.Pids.Current, uint64(2); got != want { + t.Errorf("Wrong number of PIDs, want: %d, got: %d", want, got) } - var stats boot.Stats - if err := json.Unmarshal(data, &stats); err != nil { - t.Fatalf("invalid event data: %v", err) + + // Both remaining containers should have nonzero usage, and + // 'busy' should have higher usage than 'sleep'. + usage := evt.Data.CPU.Usage.Total + if usage == 0 { + t.Errorf("Running container should report nonzero CPU usage, but got %d", usage) } - // One process per remaining container. 
- if want := uint64(2); stats.Pids.Current != want { - t.Errorf("Wrong number of PIDs, want: %d, got :%d", want, stats.Pids.Current) + if usage <= prevUsage { + t.Errorf("Expected container %s to use more than %d ns of CPU, but used %d", cont.ID, prevUsage, usage) + } + t.Logf("Container %s usage: %d", cont.ID, usage) + prevUsage = usage + + // The exited container should have a usage of zero. + if exited := ret.ContainerUsage[containers[2].ID]; exited != 0 { + t.Errorf("Exited container should report 0 CPU usage, but got %d", exited) } } diff --git a/runsc/container/state_file.go b/runsc/container/state_file.go index dfbf1f2d3..c46322ba4 100644 --- a/runsc/container/state_file.go +++ b/runsc/container/state_file.go @@ -49,7 +49,7 @@ type LoadOpts struct { // Returns ErrNotExist if no container is found. Returns error in case more than // one containers matching the ID prefix is found. func Load(rootDir string, id FullID, opts LoadOpts) (*Container, error) { - //log.Debugf("Load container, rootDir: %q, partial cid: %s", rootDir, partialID) + log.Debugf("Load container, rootDir: %q, id: %+v, opts: %+v", rootDir, id, opts) if !opts.Exact { var err error id, err = findContainerID(rootDir, id.ContainerID) diff --git a/runsc/mitigate/BUILD b/runsc/mitigate/BUILD index 9a9546577..3b0342d18 100644 --- a/runsc/mitigate/BUILD +++ b/runsc/mitigate/BUILD @@ -8,6 +8,7 @@ go_library( "cpu.go", "mitigate.go", ], + deps = ["@in_gopkg_yaml_v2//:go_default_library"], ) go_test( @@ -15,4 +16,5 @@ go_test( size = "small", srcs = ["cpu_test.go"], library = ":mitigate", + deps = ["@com_github_google_go_cmp//cmp:go_default_library"], ) diff --git a/runsc/mitigate/cpu.go b/runsc/mitigate/cpu.go index 113b98159..ae4ce9579 100644 --- a/runsc/mitigate/cpu.go +++ b/runsc/mitigate/cpu.go @@ -16,6 +16,7 @@ package mitigate import ( "fmt" + "io/ioutil" "regexp" "strconv" "strings" @@ -31,16 +32,104 @@ const ( ) const ( - processorKey = "processor" - vendorIDKey = "vendor_id" - cpuFamilyKey = 
"cpu family" - modelKey = "model" - coreIDKey = "core id" - bugsKey = "bugs" + processorKey = "processor" + vendorIDKey = "vendor_id" + cpuFamilyKey = "cpu family" + modelKey = "model" + physicalIDKey = "physical id" + coreIDKey = "core id" + bugsKey = "bugs" ) -// getCPUSet returns cpu structs from reading /proc/cpuinfo. -func getCPUSet(data string) ([]*cpu, error) { +const ( + cpuOnlineTemplate = "/sys/devices/system/cpu/cpu%d/online" +) + +// cpuSet contains a map of all CPUs on the system, mapped +// by Physical ID and CoreIDs. threads with the same +// Core and Physical ID are Hyperthread pairs. +type cpuSet map[cpuID]*threadGroup + +// newCPUSet creates a CPUSet from data read from /proc/cpuinfo. +func newCPUSet(data []byte, vulnerable func(*thread) bool) (cpuSet, error) { + processors, err := getThreads(string(data)) + if err != nil { + return nil, err + } + + set := make(cpuSet) + for _, p := range processors { + // Each ID is of the form physicalID:coreID. Hyperthread pairs + // have identical physical and core IDs. We need to match + // Hyperthread pairs so that we can shutdown all but one per + // pair. + core, ok := set[p.id] + if !ok { + core = &threadGroup{} + set[p.id] = core + } + core.isVulnerable = core.isVulnerable || vulnerable(p) + core.threads = append(core.threads, p) + } + return set, nil +} + +// String implements the String method for CPUSet. +func (c cpuSet) String() string { + ret := "" + for _, tg := range c { + ret += fmt.Sprintf("%s\n", tg) + } + return ret +} + +// getRemainingList returns the list of threads that will remain active +// after mitigation. +func (c cpuSet) getRemainingList() []*thread { + threads := make([]*thread, 0, len(c)) + for _, core := range c { + // If we're vulnerable, take only one thread from the pair. + if core.isVulnerable { + threads = append(threads, core.threads[0]) + continue + } + // Otherwise don't shutdown anything. + threads = append(threads, core.threads...) 
+ } + return threads +} + +// getShutdownList returns the list of threads that will be shutdown on +// mitigation. +func (c cpuSet) getShutdownList() []*thread { + threads := make([]*thread, 0) + for _, core := range c { + // Only if we're vulnerable do shutdown anything. In this case, + // shutdown all but the first entry. + if core.isVulnerable && len(core.threads) > 1 { + threads = append(threads, core.threads[1:]...) + } + } + return threads +} + +// threadGroup represents Hyperthread pairs on the same physical/core ID. +type threadGroup struct { + threads []*thread + isVulnerable bool +} + +// String implements the String method for threadGroup. +func (c *threadGroup) String() string { + ret := fmt.Sprintf("ThreadGroup:\nIsVulnerable: %t\n", c.isVulnerable) + for _, processor := range c.threads { + ret += fmt.Sprintf("%s\n", processor) + } + return ret +} + +// getThreads returns threads structs from reading /proc/cpuinfo. +func getThreads(data string) ([]*thread, error) { // Each processor entry should start with the // processor key. Find the beginings of each. r := buildRegex(processorKey, `\d+`) @@ -56,13 +145,13 @@ func getCPUSet(data string) ([]*cpu, error) { // indexes (e.g. data[index[i], index[i+1]]). // There should be len(indicies) - 1 CPUs // since the last index is the end of the string. - var cpus = make([]*cpu, 0, len(indices)-1) + var cpus = make([]*thread, 0, len(indices)-1) // Find each string that represents a CPU. These begin "processor". for i := 1; i < len(indices); i++ { start := indices[i-1][0] end := indices[i][0] // Parse the CPU entry, which should be between start/end. - c, err := getCPU(data[start:end]) + c, err := newThread(data[start:end]) if err != nil { return nil, err } @@ -71,18 +160,25 @@ func getCPUSet(data string) ([]*cpu, error) { return cpus, nil } +// cpuID for each thread is defined by the physical and +// core IDs. If equal, two threads are Hyperthread pairs. 
+type cpuID struct { + physicalID int64 + coreID int64 +} + // type cpu represents pertinent info about a cpu. -type cpu struct { +type thread struct { processorNumber int64 // the processor number of this CPU. vendorID string // the vendorID of CPU (e.g. AuthenticAMD). cpuFamily int64 // CPU family number (e.g. 6 for CascadeLake/Skylake). model int64 // CPU model number (e.g. 85 for CascadeLake/Skylake). - coreID int64 // This CPU's core id to match Hyperthread Pairs + id cpuID // id for this thread bugs map[string]struct{} // map of vulnerabilities parsed from the 'bugs' field. } -// getCPU parses a CPU from a single cpu entry from /proc/cpuinfo. -func getCPU(data string) (*cpu, error) { +// newThread parses a CPU from a single cpu entry from /proc/cpuinfo. +func newThread(data string) (*thread, error) { processor, err := parseProcessor(data) if err != nil { return nil, err @@ -103,6 +199,11 @@ func getCPU(data string) (*cpu, error) { return nil, err } + physicalID, err := parsePhysicalID(data) + if err != nil { + return nil, err + } + coreID, err := parseCoreID(data) if err != nil { return nil, err @@ -113,16 +214,41 @@ func getCPU(data string) (*cpu, error) { return nil, err } - return &cpu{ + return &thread{ processorNumber: processor, vendorID: vendorID, cpuFamily: cpuFamily, model: model, - coreID: coreID, - bugs: bugs, + id: cpuID{ + physicalID: physicalID, + coreID: coreID, + }, + bugs: bugs, }, nil } +// String implements the String method for thread. +func (t *thread) String() string { + template := `CPU: %d +CPU ID: %+v +Vendor: %s +Family/Model: %d/%d +Bugs: %s +` + bugs := make([]string, 0) + for bug := range t.bugs { + bugs = append(bugs, bug) + } + + return fmt.Sprintf(template, t.processorNumber, t.id, t.vendorID, t.cpuFamily, t.model, strings.Join(bugs, ",")) +} + +// shutdown turns off the CPU by writing 0 to /sys/devices/system/cpu/cpu{N}/online. 
+func (t *thread) shutdown() error { + cpuPath := fmt.Sprintf(cpuOnlineTemplate, t.processorNumber) + return ioutil.WriteFile(cpuPath, []byte{'0'}, 0644) +} + // List of pertinent side channel vulnerablilites. // For mds, see: https://www.kernel.org/doc/html/latest/admin-guide/hw-vuln/mds.html. var vulnerabilities = []string{ @@ -134,35 +260,46 @@ var vulnerabilities = []string{ } // isVulnerable checks if a CPU is vulnerable to pertinent bugs. -func (c *cpu) isVulnerable() bool { +func (t *thread) isVulnerable() bool { for _, bug := range vulnerabilities { - if _, ok := c.bugs[bug]; ok { + if _, ok := t.bugs[bug]; ok { return true } } return false } +// isActive checks if a CPU is active from /sys/devices/system/cpu/cpu{N}/online +// If the file does not exist (ioutil returns an error), we assume the CPU is on. +func (t *thread) isActive() bool { + cpuPath := fmt.Sprintf(cpuOnlineTemplate, t.processorNumber) + data, err := ioutil.ReadFile(cpuPath) + if err != nil { + return true + } + return len(data) > 0 && data[0] != '0' +} + // similarTo checks family/model/bugs fields for equality of two // processors. -func (c *cpu) similarTo(other *cpu) bool { - if c.vendorID != other.vendorID { +func (t *thread) similarTo(other *thread) bool { + if t.vendorID != other.vendorID { return false } - if other.cpuFamily != c.cpuFamily { + if other.cpuFamily != t.cpuFamily { return false } - if other.model != c.model { + if other.model != t.model { return false } - if len(other.bugs) != len(c.bugs) { + if len(other.bugs) != len(t.bugs) { return false } - for bug := range c.bugs { + for bug := range t.bugs { if _, ok := other.bugs[bug]; !ok { return false } @@ -190,6 +327,11 @@ func parseModel(data string) (int64, error) { return parseIntegerResult(data, modelKey) } +// parsePhysicalID parses the physical id field. +func parsePhysicalID(data string) (int64, error) { + return parseIntegerResult(data, physicalIDKey) +} + // parseCoreID parses the core id field. 
func parseCoreID(data string) (int64, error) { return parseIntegerResult(data, coreIDKey) diff --git a/runsc/mitigate/cpu_test.go b/runsc/mitigate/cpu_test.go index 77b714a02..21c12f586 100644 --- a/runsc/mitigate/cpu_test.go +++ b/runsc/mitigate/cpu_test.go @@ -15,26 +15,163 @@ package mitigate import ( + "fmt" "io/ioutil" "strings" "testing" ) -// CPU info for a Intel CascadeLake processor. Both Skylake and CascadeLake have -// the same family/model numbers, but with different bugs (e.g. skylake has -// cpu_meltdown). -var cascadeLake = &cpu{ - vendorID: "GenuineIntel", - cpuFamily: 6, - model: 85, - bugs: map[string]struct{}{ - "spectre_v1": struct{}{}, - "spectre_v2": struct{}{}, - "spec_store_bypass": struct{}{}, - mds: struct{}{}, - swapgs: struct{}{}, - taa: struct{}{}, - }, +// cpuTestCase represents data from CPUs that will be mitigated. +type cpuTestCase struct { + name string + vendorID string + family int + model int + modelName string + bugs string + physicalCores int + cores int + threadsPerCore int +} + +var cascadeLake4 = cpuTestCase{ + name: "CascadeLake", + vendorID: "GenuineIntel", + family: 6, + model: 85, + modelName: "Intel(R) Xeon(R) CPU", + bugs: "spectre_v1 spectre_v2 spec_store_bypass mds swapgs taa", + physicalCores: 1, + cores: 2, + threadsPerCore: 2, +} + +var haswell2 = cpuTestCase{ + name: "Haswell", + vendorID: "GenuineIntel", + family: 6, + model: 63, + modelName: "Intel(R) Xeon(R) CPU", + bugs: "cpu_meltdown spectre_v1 spectre_v2 spec_store_bypass l1tf mds swapgs", + physicalCores: 1, + cores: 1, + threadsPerCore: 2, +} + +var haswell2core = cpuTestCase{ + name: "Haswell2Physical", + vendorID: "GenuineIntel", + family: 6, + model: 63, + modelName: "Intel(R) Xeon(R) CPU", + bugs: "cpu_meltdown spectre_v1 spectre_v2 spec_store_bypass l1tf mds swapgs", + physicalCores: 2, + cores: 1, + threadsPerCore: 1, +} + +var amd8 = cpuTestCase{ + name: "AMD", + vendorID: "AuthenticAMD", + family: 23, + model: 49, + modelName: "AMD EPYC 7B12", + 
bugs: "sysret_ss_attrs spectre_v1 spectre_v2 spec_store_bypass", + physicalCores: 4, + cores: 1, + threadsPerCore: 2, +} + +// makeCPUString makes a string formated like /proc/cpuinfo for each cpuTestCase +func (tc cpuTestCase) makeCPUString() string { + template := `processor : %d +vendor_id : %s +cpu family : %d +model : %d +model name : %s +physical id : %d +core id : %d +cpu cores : %d +bugs : %s +` + ret := `` + for i := 0; i < tc.physicalCores; i++ { + for j := 0; j < tc.cores; j++ { + for k := 0; k < tc.threadsPerCore; k++ { + processorNum := (i*tc.cores+j)*tc.threadsPerCore + k + ret += fmt.Sprintf(template, + processorNum, /*processor*/ + tc.vendorID, /*vendor_id*/ + tc.family, /*cpu family*/ + tc.model, /*model*/ + tc.modelName, /*model name*/ + i, /*physical id*/ + j, /*core id*/ + tc.cores*tc.physicalCores, /*cpu cores*/ + tc.bugs /*bugs*/) + } + } + } + return ret +} + +// TestMockCPUSet tests mock cpu test cases against the cpuSet functions. +func TestMockCPUSet(t *testing.T) { + for _, tc := range []struct { + testCase cpuTestCase + isVulnerable bool + }{ + { + testCase: amd8, + isVulnerable: false, + }, + { + testCase: haswell2, + isVulnerable: true, + }, + { + testCase: haswell2core, + isVulnerable: true, + }, + + { + testCase: cascadeLake4, + isVulnerable: true, + }, + } { + t.Run(tc.testCase.name, func(t *testing.T) { + data := tc.testCase.makeCPUString() + vulnerable := func(t *thread) bool { + return t.isVulnerable() + } + set, err := newCPUSet([]byte(data), vulnerable) + if err != nil { + t.Fatalf("Failed to ") + } + remaining := set.getRemainingList() + // In the non-vulnerable case, no cores should be shutdown so all should remain. 
+ want := tc.testCase.physicalCores * tc.testCase.cores * tc.testCase.threadsPerCore + if tc.isVulnerable { + want = tc.testCase.physicalCores * tc.testCase.cores + } + + if want != len(remaining) { + t.Fatalf("Failed to shutdown the correct number of cores: want: %d got: %d", want, len(remaining)) + } + + if !tc.isVulnerable { + return + } + + // If the set is vulnerable, we expect only 1 thread per hyperthread pair. + for _, r := range remaining { + if _, ok := set[r.id]; !ok { + t.Fatalf("Entry %+v not in map, there must be two entries in the same thread group.", r) + } + delete(set, r.id) + } + }) + } } // TestGetCPU tests basic parsing of single CPU strings from reading @@ -44,15 +181,19 @@ func TestGetCPU(t *testing.T) { vendor_id : GenuineIntel cpu family : 6 model : 85 +physical id: 0 core id : 0 bugs : cpu_meltdown spectre_v1 spectre_v2 spec_store_bypass l1tf mds swapgs taa itlb_multihit ` - want := cpu{ + want := thread{ processorNumber: 0, vendorID: "GenuineIntel", cpuFamily: 6, model: 85, - coreID: 0, + id: cpuID{ + physicalID: 0, + coreID: 0, + }, bugs: map[string]struct{}{ "cpu_meltdown": struct{}{}, "spectre_v1": struct{}{}, @@ -66,7 +207,7 @@ bugs : cpu_meltdown spectre_v1 spectre_v2 spec_store_bypass l1tf mds swapgs taa }, } - got, err := getCPU(data) + got, err := newThread(data) if err != nil { t.Fatalf("getCpu failed with error: %v", err) } @@ -81,7 +222,7 @@ bugs : cpu_meltdown spectre_v1 spectre_v2 spec_store_bypass l1tf mds swapgs taa } func TestInvalid(t *testing.T) { - result, err := getCPUSet(`something not a processor`) + result, err := getThreads(`something not a processor`) if err == nil { t.Fatalf("getCPU set didn't return an error: %+v", result) } @@ -148,7 +289,7 @@ cache_alignment : 64 address sizes : 46 bits physical, 48 bits virtual power management: ` - cpuSet, err := getCPUSet(data) + cpuSet, err := getThreads(data) if err != nil { t.Fatalf("getCPUSet failed: %v", err) } @@ -158,7 +299,7 @@ power management: t.Fatalf("Num CPU 
mismatch: want: %d, got: %d", wantCPULen, len(cpuSet)) } - wantCPU := cpu{ + wantCPU := thread{ vendorID: "GenuineIntel", cpuFamily: 6, model: 63, @@ -187,7 +328,11 @@ func TestReadFile(t *testing.T) { t.Fatalf("Failed to read cpuinfo: %v", err) } - set, err := getCPUSet(string(data)) + vulnerable := func(t *thread) bool { + return t.isVulnerable() + } + + set, err := newCPUSet(data, vulnerable) if err != nil { t.Fatalf("Failed to parse CPU data %v\n%s", err, data) } @@ -196,9 +341,7 @@ func TestReadFile(t *testing.T) { t.Fatalf("Failed to parse any CPUs: %d", len(set)) } - for _, c := range set { - t.Logf("CPU: %+v: %t", c, c.isVulnerable()) - } + t.Log(set) } // TestVulnerable tests if the isVulnerable method is correct @@ -332,17 +475,13 @@ power management:` cpuString: skylake, vulnerable: true, }, { - name: "cascadeLake", - cpuString: cascade, - vulnerable: false, - }, { name: "amd", cpuString: amd, vulnerable: false, }, } { t.Run(tc.name, func(t *testing.T) { - set, err := getCPUSet(tc.cpuString) + set, err := getThreads(tc.cpuString) if err != nil { t.Fatalf("Failed to getCPUSet:%v\n %s", err, tc.cpuString) } @@ -353,9 +492,6 @@ power management:` for _, c := range set { got := func() bool { - if cascadeLake.similarTo(c) { - return false - } return c.isVulnerable() }() diff --git a/runsc/sandbox/sandbox.go b/runsc/sandbox/sandbox.go index 266bc0bdc..7fe65c7ba 100644 --- a/runsc/sandbox/sandbox.go +++ b/runsc/sandbox/sandbox.go @@ -308,6 +308,22 @@ func (s *Sandbox) Processes(cid string) ([]*control.Process, error) { return pl, nil } +// FindCgroup returns the sandbox's Cgroup, or an error if it does not have one. +func (s *Sandbox) FindCgroup() (*cgroup.Cgroup, error) { + paths, err := cgroup.LoadPaths(strconv.Itoa(s.Pid)) + if err != nil { + return nil, err + } + // runsc places sandboxes in the same cgroup for each controller, so we + // pick an arbitrary controller here to get the cgroup path. 
+ const controller = "cpuacct" + controllerPath, ok := paths[controller] + if !ok { + return nil, fmt.Errorf("no %q controller found", controller) + } + return cgroup.NewFromPath(controllerPath) +} + // Execute runs the specified command in the container. It returns the PID of // the newly created process. func (s *Sandbox) Execute(args *control.ExecArgs) (int32, error) { @@ -327,7 +343,7 @@ func (s *Sandbox) Execute(args *control.ExecArgs) (int32, error) { } // Event retrieves stats about the sandbox such as memory and CPU utilization. -func (s *Sandbox) Event(cid string) (*boot.Event, error) { +func (s *Sandbox) Event(cid string) (*boot.EventOut, error) { log.Debugf("Getting events for container %q in sandbox %q", cid, s.ID) conn, err := s.sandboxConnect() if err != nil { @@ -335,13 +351,13 @@ func (s *Sandbox) Event(cid string) (*boot.Event, error) { } defer conn.Close() - var e boot.Event + var e boot.EventOut // TODO(b/129292330): Pass in the container id (cid) here. The sandbox // should return events only for that container. if err := conn.Call(boot.ContainerEvent, nil, &e); err != nil { return nil, fmt.Errorf("retrieving event data from sandbox: %v", err) } - e.ID = cid + e.Event.ID = cid return &e, nil } diff --git a/test/e2e/integration_test.go b/test/e2e/integration_test.go index d07ed6ba5..aaffabfd0 100644 --- a/test/e2e/integration_test.go +++ b/test/e2e/integration_test.go @@ -434,18 +434,7 @@ func TestTmpMount(t *testing.T) { // runsc to hide the incoherence of FDs opened before and after overlayfs // copy-up on the host. 
func TestHostOverlayfsCopyUp(t *testing.T) { - ctx := context.Background() - d := dockerutil.MakeContainer(ctx, t) - defer d.CleanUp(ctx) - - if got, err := d.Run(ctx, dockerutil.RunOpts{ - Image: "basic/hostoverlaytest", - WorkDir: "/root", - }, "./test_copy_up"); err != nil { - t.Fatalf("docker run failed: %v", err) - } else if got != "" { - t.Errorf("test failed:\n%s", got) - } + runIntegrationTest(t, nil, "sh", "-c", "gcc -O2 -o test_copy_up test_copy_up.c && ./test_copy_up") } // TestHostOverlayfsRewindDir tests that rewinddir() "causes the directory @@ -460,36 +449,14 @@ func TestHostOverlayfsCopyUp(t *testing.T) { // automated tests yield newly-added files from readdir() even if the fsgofer // does not explicitly rewinddir(), but overlayfs does not. func TestHostOverlayfsRewindDir(t *testing.T) { - ctx := context.Background() - d := dockerutil.MakeContainer(ctx, t) - defer d.CleanUp(ctx) - - if got, err := d.Run(ctx, dockerutil.RunOpts{ - Image: "basic/hostoverlaytest", - WorkDir: "/root", - }, "./test_rewinddir"); err != nil { - t.Fatalf("docker run failed: %v", err) - } else if got != "" { - t.Errorf("test failed:\n%s", got) - } + runIntegrationTest(t, nil, "sh", "-c", "gcc -O2 -o test_rewinddir test_rewinddir.c && ./test_rewinddir") } // Basic test for linkat(2). Syscall tests requires CAP_DAC_READ_SEARCH and it // cannot use tricks like userns as root. For this reason, run a basic link test // to ensure some coverage. func TestLink(t *testing.T) { - ctx := context.Background() - d := dockerutil.MakeContainer(ctx, t) - defer d.CleanUp(ctx) - - if got, err := d.Run(ctx, dockerutil.RunOpts{ - Image: "basic/linktest", - WorkDir: "/root", - }, "./link_test"); err != nil { - t.Fatalf("docker run failed: %v", err) - } else if got != "" { - t.Errorf("test failed:\n%s", got) - } + runIntegrationTest(t, nil, "sh", "-c", "gcc -O2 -o link_test link_test.c && ./link_test") } // This test ensures we can run ping without errors. 
@@ -500,17 +467,7 @@ func TestPing4Loopback(t *testing.T) { t.Skip("hostnet only supports TCP/UDP sockets, so ping is not supported.") } - ctx := context.Background() - d := dockerutil.MakeContainer(ctx, t) - defer d.CleanUp(ctx) - - if got, err := d.Run(ctx, dockerutil.RunOpts{ - Image: "basic/ping4test", - }, "/root/ping4.sh"); err != nil { - t.Fatalf("docker run failed: %s", err) - } else if got != "" { - t.Errorf("test failed:\n%s", got) - } + runIntegrationTest(t, nil, "./ping4.sh") } // This test ensures we can enable ipv6 on loopback and run ping6 without @@ -522,20 +479,25 @@ func TestPing6Loopback(t *testing.T) { t.Skip("hostnet only supports TCP/UDP sockets, so ping6 is not supported.") } + // The CAP_NET_ADMIN capability is required to use the `ip` utility, which + // we use to enable ipv6 on loopback. + // + // By default, ipv6 loopback is not enabled by runsc, because docker does + // not assign an ipv6 address to the test container. + runIntegrationTest(t, []string{"NET_ADMIN"}, "./ping6.sh") +} + +func runIntegrationTest(t *testing.T, capAdd []string, args ...string) { ctx := context.Background() d := dockerutil.MakeContainer(ctx, t) defer d.CleanUp(ctx) if got, err := d.Run(ctx, dockerutil.RunOpts{ - Image: "basic/ping6test", - // The CAP_NET_ADMIN capability is required to use the `ip` utility, which - // we use to enable ipv6 on loopback. - // - // By default, ipv6 loopback is not enabled by runsc, because docker does - // not assign an ipv6 address to the test container. 
- CapAdd: []string{"NET_ADMIN"}, - }, "/root/ping6.sh"); err != nil { - t.Fatalf("docker run failed: %s", err) + Image: "basic/integrationtest", + WorkDir: "/root", + CapAdd: capAdd, + }, args...); err != nil { + t.Fatalf("docker run failed: %v", err) } else if got != "" { t.Errorf("test failed:\n%s", got) } diff --git a/test/packetimpact/dut/BUILD b/test/packetimpact/dut/BUILD index ccf1c735f..0be14ca3e 100644 --- a/test/packetimpact/dut/BUILD +++ b/test/packetimpact/dut/BUILD @@ -14,6 +14,7 @@ cc_binary( grpcpp, "//test/packetimpact/proto:posix_server_cc_grpc_proto", "//test/packetimpact/proto:posix_server_cc_proto", + "@com_google_absl//absl/strings:str_format", ], ) @@ -24,5 +25,6 @@ cc_binary( grpcpp, "//test/packetimpact/proto:posix_server_cc_grpc_proto", "//test/packetimpact/proto:posix_server_cc_proto", + "@com_google_absl//absl/strings:str_format", ], ) diff --git a/test/packetimpact/dut/posix_server.cc b/test/packetimpact/dut/posix_server.cc index 4de8540f6..eba21df12 100644 --- a/test/packetimpact/dut/posix_server.cc +++ b/test/packetimpact/dut/posix_server.cc @@ -16,6 +16,7 @@ #include <getopt.h> #include <netdb.h> #include <netinet/in.h> +#include <poll.h> #include <stdio.h> #include <stdlib.h> #include <string.h> @@ -30,6 +31,7 @@ #include "include/grpcpp/security/server_credentials.h" #include "include/grpcpp/server_builder.h" #include "include/grpcpp/server_context.h" +#include "absl/strings/str_format.h" #include "test/packetimpact/proto/posix_server.grpc.pb.h" #include "test/packetimpact/proto/posix_server.pb.h" @@ -256,6 +258,44 @@ class PosixImpl final : public posix_server::Posix::Service { return ::grpc::Status::OK; } + ::grpc::Status Poll(::grpc::ServerContext *context, + const ::posix_server::PollRequest *request, + ::posix_server::PollResponse *response) override { + std::vector<struct pollfd> pfds; + pfds.reserve(request->pfds_size()); + for (const auto &pfd : request->pfds()) { + pfds.push_back({ + .fd = pfd.fd(), + .events = 
static_cast<short>(pfd.events()), + }); + } + int ret = ::poll(pfds.data(), pfds.size(), request->timeout_millis()); + + response->set_ret(ret); + if (ret < 0) { + response->set_errno_(errno); + } else { + // Only pollfds that have non-empty revents are returned, the client can't + // rely on indexes of the request array. + for (const auto &pfd : pfds) { + if (pfd.revents) { + auto *proto_pfd = response->add_pfds(); + proto_pfd->set_fd(pfd.fd); + proto_pfd->set_events(pfd.revents); + } + } + if (int ready = response->pfds_size(); ret != ready) { + return ::grpc::Status( + ::grpc::StatusCode::INTERNAL, + absl::StrFormat( + "poll's return value(%d) doesn't match the number of " + "file descriptors that are actually ready(%d)", + ret, ready)); + } + } + return ::grpc::Status::OK; + } + ::grpc::Status Send(::grpc::ServerContext *context, const ::posix_server::SendRequest *request, ::posix_server::SendResponse *response) override { diff --git a/test/packetimpact/proto/posix_server.proto b/test/packetimpact/proto/posix_server.proto index f32ed54ef..b4c68764a 100644 --- a/test/packetimpact/proto/posix_server.proto +++ b/test/packetimpact/proto/posix_server.proto @@ -142,6 +142,25 @@ message ListenResponse { int32 errno_ = 2; // "errno" may fail to compile in c++. } +// The events field is overloaded: when used for request, it is copied into the +// events field of posix struct pollfd; when used for response, it is filled by +// the revents field from the posix struct pollfd. +message PollFd { + int32 fd = 1; + uint32 events = 2; +} + +message PollRequest { + repeated PollFd pfds = 1; + int32 timeout_millis = 2; +} + +message PollResponse { + int32 ret = 1; + int32 errno_ = 2; // "errno" may fail to compile in c++. + repeated PollFd pfds = 3; +} + message SendRequest { int32 sockfd = 1; bytes buf = 2; @@ -226,6 +245,10 @@ service Posix { rpc GetSockOpt(GetSockOptRequest) returns (GetSockOptResponse); // Call listen() on the DUT. 
rpc Listen(ListenRequest) returns (ListenResponse); + // Call poll() on the DUT. Only pollfds that have non-empty revents are + // returned, the only way to tie the response back to the original request + // is using the fd number. + rpc Poll(PollRequest) returns (PollResponse); // Call send() on the DUT. rpc Send(SendRequest) returns (SendResponse); // Call sendto() on the DUT. diff --git a/test/packetimpact/runner/defs.bzl b/test/packetimpact/runner/defs.bzl index 5c3c569de..a7c46781f 100644 --- a/test/packetimpact/runner/defs.bzl +++ b/test/packetimpact/runner/defs.bzl @@ -175,9 +175,6 @@ ALL_TESTS = [ name = "udp_discard_mcast_source_addr", ), PacketimpactTestInfo( - name = "udp_recv_mcast_bcast", - ), - PacketimpactTestInfo( name = "udp_any_addr_recv_unicast", ), PacketimpactTestInfo( @@ -281,6 +278,9 @@ ALL_TESTS = [ name = "tcp_rack", expect_netstack_failure = True, ), + PacketimpactTestInfo( + name = "tcp_info", + ), ] def validate_all_tests(): diff --git a/test/packetimpact/testbench/connections.go b/test/packetimpact/testbench/connections.go index 576577310..1453ac232 100644 --- a/test/packetimpact/testbench/connections.go +++ b/test/packetimpact/testbench/connections.go @@ -1008,6 +1008,13 @@ func (conn *UDPIPv4) LocalAddr(t *testing.T) *unix.SockaddrInet4 { return sa } +// SrcPort returns the source port of this connection. +func (conn *UDPIPv4) SrcPort(t *testing.T) uint16 { + t.Helper() + + return *conn.udpState(t).out.SrcPort +} + // Send sends a packet with reasonable defaults, potentially overriding the UDP // layer and adding additionLayers. func (conn *UDPIPv4) Send(t *testing.T, udp UDP, additionalLayers ...Layer) { @@ -1024,6 +1031,11 @@ func (conn *UDPIPv4) SendIP(t *testing.T, ip IPv4, udp UDP, additionalLayers ... (*Connection)(conn).send(t, Layers{&ip, &udp}, additionalLayers...) } +// SendFrame sends a frame on the wire and updates the state of all layers. 
+func (conn *UDPIPv4) SendFrame(t *testing.T, overrideLayers Layers, additionalLayers ...Layer) { + (*Connection)(conn).send(t, overrideLayers, additionalLayers...) +} + // Expect expects a frame with the UDP layer matching the provided UDP within // the timeout specified. If it doesn't arrive in time, an error is returned. func (conn *UDPIPv4) Expect(t *testing.T, udp UDP, timeout time.Duration) (*UDP, error) { @@ -1053,6 +1065,14 @@ func (conn *UDPIPv4) ExpectData(t *testing.T, udp UDP, payload Payload, timeout return (*Connection)(conn).ExpectFrame(t, expected, timeout) } +// ExpectFrame expects a frame that matches the provided Layers within the +// timeout specified. If it doesn't arrive in time, an error is returned. +func (conn *UDPIPv4) ExpectFrame(t *testing.T, frame Layers, timeout time.Duration) (Layers, error) { + t.Helper() + + return (*Connection)(conn).ExpectFrame(t, frame, timeout) +} + // Close frees associated resources held by the UDPIPv4 connection. func (conn *UDPIPv4) Close(t *testing.T) { t.Helper() @@ -1136,6 +1156,13 @@ func (conn *UDPIPv6) LocalAddr(t *testing.T, zoneID uint32) *unix.SockaddrInet6 return sa } +// SrcPort returns the source port of this connection. +func (conn *UDPIPv6) SrcPort(t *testing.T) uint16 { + t.Helper() + + return *conn.udpState(t).out.SrcPort +} + // Send sends a packet with reasonable defaults, potentially overriding the UDP // layer and adding additionLayers. func (conn *UDPIPv6) Send(t *testing.T, udp UDP, additionalLayers ...Layer) { @@ -1152,6 +1179,11 @@ func (conn *UDPIPv6) SendIPv6(t *testing.T, ip IPv6, udp UDP, additionalLayers . (*Connection)(conn).send(t, Layers{&ip, &udp}, additionalLayers...) } +// SendFrame sends a frame on the wire and updates the state of all layers. +func (conn *UDPIPv6) SendFrame(t *testing.T, overrideLayers Layers, additionalLayers ...Layer) { + (*Connection)(conn).send(t, overrideLayers, additionalLayers...) 
+} + // Expect expects a frame with the UDP layer matching the provided UDP within // the timeout specified. If it doesn't arrive in time, an error is returned. func (conn *UDPIPv6) Expect(t *testing.T, udp UDP, timeout time.Duration) (*UDP, error) { @@ -1181,6 +1213,14 @@ func (conn *UDPIPv6) ExpectData(t *testing.T, udp UDP, payload Payload, timeout return (*Connection)(conn).ExpectFrame(t, expected, timeout) } +// ExpectFrame expects a frame that matches the provided Layers within the +// timeout specified. If it doesn't arrive in time, an error is returned. +func (conn *UDPIPv6) ExpectFrame(t *testing.T, frame Layers, timeout time.Duration) (Layers, error) { + t.Helper() + + return (*Connection)(conn).ExpectFrame(t, frame, timeout) +} + // Close frees associated resources held by the UDPIPv6 connection. func (conn *UDPIPv6) Close(t *testing.T) { t.Helper() diff --git a/test/packetimpact/testbench/dut.go b/test/packetimpact/testbench/dut.go index 66a0255b8..aedcf6013 100644 --- a/test/packetimpact/testbench/dut.go +++ b/test/packetimpact/testbench/dut.go @@ -486,6 +486,56 @@ func (dut *DUT) ListenWithErrno(ctx context.Context, t *testing.T, sockfd, backl return resp.GetRet(), syscall.Errno(resp.GetErrno_()) } +// Poll calls poll on the DUT and causes a fatal test failure if it doesn't +// succeed. If more control over error handling is needed, use PollWithErrno. +// Only pollfds with non-empty revents are returned, the only way to tie the +// response back to the original request is using the fd number. +func (dut *DUT) Poll(t *testing.T, pfds []unix.PollFd, timeout time.Duration) []unix.PollFd { + t.Helper() + + ctx := context.Background() + var cancel context.CancelFunc + if timeout >= 0 { + ctx, cancel = context.WithTimeout(ctx, timeout+RPCTimeout) + defer cancel() + } + ret, result, err := dut.PollWithErrno(ctx, t, pfds, timeout) + if ret < 0 { + t.Fatalf("failed to poll: %s", err) + } + return result +} + +// PollWithErrno calls poll on the DUT. 
+func (dut *DUT) PollWithErrno(ctx context.Context, t *testing.T, pfds []unix.PollFd, timeout time.Duration) (int32, []unix.PollFd, error) { + t.Helper() + + req := pb.PollRequest{ + TimeoutMillis: int32(timeout.Milliseconds()), + } + for _, pfd := range pfds { + req.Pfds = append(req.Pfds, &pb.PollFd{ + Fd: pfd.Fd, + Events: uint32(pfd.Events), + }) + } + resp, err := dut.posixServer.Poll(ctx, &req) + if err != nil { + t.Fatalf("failed to call Poll: %s", err) + } + if ret, npfds := resp.GetRet(), len(resp.GetPfds()); ret >= 0 && int(ret) != npfds { + t.Fatalf("nonsensical poll response: ret(%d) != len(pfds)(%d)", ret, npfds) + } + var result []unix.PollFd + for _, protoPfd := range resp.GetPfds() { + result = append(result, unix.PollFd{ + Fd: protoPfd.GetFd(), + Revents: int16(protoPfd.GetEvents()), + }) + } + return resp.GetRet(), result, syscall.Errno(resp.GetErrno_()) +} + // Send calls send on the DUT and causes a fatal test failure if it doesn't // succeed. If more control over the timeout or error handling is needed, use // SendWithErrno. 
@@ -544,7 +594,7 @@ func (dut *DUT) SendToWithErrno(ctx context.Context, t *testing.T, sockfd int32, } resp, err := dut.posixServer.SendTo(ctx, &req) if err != nil { - t.Fatalf("faled to call SendTo: %s", err) + t.Fatalf("failed to call SendTo: %s", err) } return resp.GetRet(), syscall.Errno(resp.GetErrno_()) } diff --git a/test/packetimpact/tests/BUILD b/test/packetimpact/tests/BUILD index 6c6f2bdf7..baa3ae5e9 100644 --- a/test/packetimpact/tests/BUILD +++ b/test/packetimpact/tests/BUILD @@ -38,18 +38,6 @@ packetimpact_testbench( ) packetimpact_testbench( - name = "udp_recv_mcast_bcast", - srcs = ["udp_recv_mcast_bcast_test.go"], - deps = [ - "//pkg/tcpip", - "//pkg/tcpip/header", - "//test/packetimpact/testbench", - "@com_github_google_go_cmp//cmp:go_default_library", - "@org_golang_x_sys//unix:go_default_library", - ], -) - -packetimpact_testbench( name = "udp_any_addr_recv_unicast", srcs = ["udp_any_addr_recv_unicast_test.go"], deps = [ @@ -340,6 +328,8 @@ packetimpact_testbench( name = "udp_send_recv_dgram", srcs = ["udp_send_recv_dgram_test.go"], deps = [ + "//pkg/tcpip", + "//pkg/tcpip/header", "//test/packetimpact/testbench", "@com_github_google_go_cmp//cmp:go_default_library", "@org_golang_x_sys//unix:go_default_library", @@ -390,6 +380,19 @@ packetimpact_testbench( ], ) +packetimpact_testbench( + name = "tcp_info", + srcs = ["tcp_info_test.go"], + deps = [ + "//pkg/abi/linux", + "//pkg/binary", + "//pkg/tcpip/header", + "//pkg/usermem", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + validate_all_tests() [packetimpact_go_test( diff --git a/test/packetimpact/tests/ipv4_fragment_reassembly_test.go b/test/packetimpact/tests/ipv4_fragment_reassembly_test.go index d2203082d..ee050e2c6 100644 --- a/test/packetimpact/tests/ipv4_fragment_reassembly_test.go +++ b/test/packetimpact/tests/ipv4_fragment_reassembly_test.go @@ -45,8 +45,6 @@ func TestIPv4FragmentReassembly(t *testing.T) { ipPayloadLen int fragments 
[]fragmentInfo expectReply bool - skip bool - skipReason string }{ { description: "basic reassembly", @@ -78,8 +76,6 @@ func TestIPv4FragmentReassembly(t *testing.T) { {offset: 2000, size: 1000, id: 7, more: 0}, }, expectReply: true, - skip: true, - skipReason: "gvisor.dev/issues/4971", }, { description: "fragment subset", @@ -91,8 +87,6 @@ func TestIPv4FragmentReassembly(t *testing.T) { {offset: 2000, size: 1000, id: 8, more: 0}, }, expectReply: true, - skip: true, - skipReason: "gvisor.dev/issues/4971", }, { description: "fragment overlap", @@ -104,16 +98,10 @@ func TestIPv4FragmentReassembly(t *testing.T) { {offset: 2000, size: 1000, id: 9, more: 0}, }, expectReply: false, - skip: true, - skipReason: "gvisor.dev/issues/4971", }, } for _, test := range tests { - if test.skip { - t.Skip("%s test skipped: %s", test.description, test.skipReason) - continue - } t.Run(test.description, func(t *testing.T) { dut := testbench.NewDUT(t) conn := dut.Net.NewIPv4Conn(t, testbench.IPv4{}, testbench.IPv4{}) diff --git a/test/packetimpact/tests/tcp_info_test.go b/test/packetimpact/tests/tcp_info_test.go new file mode 100644 index 000000000..69275e54b --- /dev/null +++ b/test/packetimpact/tests/tcp_info_test.go @@ -0,0 +1,109 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package tcp_info_test + +import ( + "flag" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + testbench.Initialize(flag.CommandLine) +} + +func TestTCPInfo(t *testing.T) { + // Create a socket, listen, TCP connect, and accept. + dut := testbench.NewDUT(t) + listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(t, listenFD) + + conn := dut.Net.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) + defer conn.Close(t) + conn.Connect(t) + + acceptFD, _ := dut.Accept(t, listenFD) + defer dut.Close(t, acceptFD) + + // Send and receive sample data. + sampleData := []byte("Sample Data") + samplePayload := &testbench.Payload{Bytes: sampleData} + dut.Send(t, acceptFD, sampleData, 0) + if _, err := conn.ExpectData(t, &testbench.TCP{}, samplePayload, time.Second); err != nil { + t.Fatalf("expected a packet with payload %v: %s", samplePayload, err) + } + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + + info := linux.TCPInfo{} + infoBytes := dut.GetSockOpt(t, acceptFD, unix.SOL_TCP, unix.TCP_INFO, int32(linux.SizeOfTCPInfo)) + if got, want := len(infoBytes), linux.SizeOfTCPInfo; got != want { + t.Fatalf("expected %T, got %d bytes want %d bytes", info, got, want) + } + binary.Unmarshal(infoBytes, usermem.ByteOrder, &info) + + rtt := time.Duration(info.RTT) * time.Microsecond + rttvar := time.Duration(info.RTTVar) * time.Microsecond + rto := time.Duration(info.RTO) * time.Microsecond + if rtt == 0 || rttvar == 0 || rto == 0 { + t.Errorf("expected rtt(%v), rttvar(%v) and rto(%v) to be greater than zero", rtt, rttvar, rto) + } + if info.ReordSeen != 0 { + t.Errorf("expected the connection to not have any reordering, got: %v want: 0", info.ReordSeen) + } + if 
info.SndCwnd == 0 { + t.Errorf("expected send congestion window to be greater than zero") + } + if info.CaState != linux.TCP_CA_Open { + t.Errorf("expected the connection to be in open state, got: %v want: %v", info.CaState, linux.TCP_CA_Open) + } + + if t.Failed() { + t.FailNow() + } + + // Check the congestion control state and send congestion window after + // retransmission timeout. + seq := testbench.Uint32(uint32(*conn.RemoteSeqNum(t))) + dut.Send(t, acceptFD, sampleData, 0) + if _, err := conn.ExpectData(t, &testbench.TCP{}, samplePayload, time.Second); err != nil { + t.Fatalf("expected a packet with payload %v: %s", samplePayload, err) + } + + // Expect retransmission of the packet within 1.5*RTO. + timeout := time.Duration(float64(info.RTO)*1.5) * time.Microsecond + if _, err := conn.ExpectData(t, &testbench.TCP{SeqNum: seq}, samplePayload, timeout); err != nil { + t.Fatalf("expected a packet with payload %v: %s", samplePayload, err) + } + + info = linux.TCPInfo{} + infoBytes = dut.GetSockOpt(t, acceptFD, unix.SOL_TCP, unix.TCP_INFO, int32(linux.SizeOfTCPInfo)) + if got, want := len(infoBytes), linux.SizeOfTCPInfo; got != want { + t.Fatalf("expected %T, got %d bytes want %d bytes", info, got, want) + } + binary.Unmarshal(infoBytes, usermem.ByteOrder, &info) + if info.CaState != linux.TCP_CA_Loss { + t.Errorf("expected the connection to be in loss recovery, got: %v want: %v", info.CaState, linux.TCP_CA_Loss) + } + if info.SndCwnd != 1 { + t.Errorf("expected send congestion window to be 1, got: %v %v", info.SndCwnd) + } +} diff --git a/test/packetimpact/tests/tcp_noaccept_close_rst_test.go b/test/packetimpact/tests/tcp_noaccept_close_rst_test.go index f0af5352d..c874a8912 100644 --- a/test/packetimpact/tests/tcp_noaccept_close_rst_test.go +++ b/test/packetimpact/tests/tcp_noaccept_close_rst_test.go @@ -34,6 +34,21 @@ func TestTcpNoAcceptCloseReset(t *testing.T) { conn := dut.Net.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: 
&remotePort}) conn.Connect(t) defer conn.Close(t) + // We need to wait for POLLIN event on listenFd to know the connection is + // established. Otherwise there could be a race when we issue the Close + // command prior to the DUT receiving the last ack of the handshake and + // it will only respond RST instead of RST+ACK. + timeout := time.Second + pfds := dut.Poll(t, []unix.PollFd{{Fd: listenFd, Events: unix.POLLIN}}, timeout) + if n := len(pfds); n != 1 { + t.Fatalf("poll returned %d ready file descriptors, expected 1", n) + } + if readyFd := pfds[0].Fd; readyFd != listenFd { + t.Fatalf("poll returned an fd %d that was not requested (%d)", readyFd, listenFd) + } + if got, want := pfds[0].Revents, int16(unix.POLLIN); got&want == 0 { + t.Fatalf("poll returned no events in our interest, got: %#b, want: %#b", got, want) + } dut.Close(t, listenFd) if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst | header.TCPFlagAck)}, 1*time.Second); err != nil { t.Fatalf("expected a RST-ACK packet but got none: %s", err) diff --git a/test/packetimpact/tests/tcp_outside_the_window_test.go b/test/packetimpact/tests/tcp_outside_the_window_test.go index 1b041932a..8909a348e 100644 --- a/test/packetimpact/tests/tcp_outside_the_window_test.go +++ b/test/packetimpact/tests/tcp_outside_the_window_test.go @@ -84,6 +84,24 @@ func TestTCPOutsideTheWindow(t *testing.T) { if tt.expectACK && err != nil { t.Fatalf("expected an ACK packet within %s but got none: %s", timeout, err) } + // Data packets w/o SYN bits are always acked by Linux. Netstack ACK's data packets + // always right now. So only send a second segment and test for no ACK for packets + // with no data. + if tt.expectACK && tt.payload == nil { + // Sending another out-of-window segment immediately should not trigger + // an ACK if less than 500ms(default rate limit for out-of-window ACKs) + // has passed since the last ACK was sent. 
+ t.Logf("sending another segment") + conn.Send(t, testbench.TCP{ + Flags: testbench.Uint8(tt.tcpFlags), + SeqNum: testbench.Uint32(uint32(conn.LocalSeqNum(t).Add(windowSize))), + }, tt.payload...) + timeout := 3 * time.Second + gotACK, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), AckNum: localSeqNum}, timeout) + if err == nil { + t.Fatalf("expected no ACK packet but got one: %s", gotACK) + } + } if !tt.expectACK && gotACK != nil { t.Fatalf("expected no ACK packet within %s but got one: %s", timeout, gotACK) } diff --git a/test/packetimpact/tests/tcp_rack_test.go b/test/packetimpact/tests/tcp_rack_test.go index 0a2381c97..fb2a4cc90 100644 --- a/test/packetimpact/tests/tcp_rack_test.go +++ b/test/packetimpact/tests/tcp_rack_test.go @@ -70,8 +70,11 @@ func closeSACKConnection(t *testing.T, dut testbench.DUT, conn testbench.TCPIPv4 func getRTTAndRTO(t *testing.T, dut testbench.DUT, acceptFd int32) (rtt, rto time.Duration) { info := linux.TCPInfo{} - ret := dut.GetSockOpt(t, acceptFd, unix.SOL_TCP, unix.TCP_INFO, int32(linux.SizeOfTCPInfo)) - binary.Unmarshal(ret, usermem.ByteOrder, &info) + infoBytes := dut.GetSockOpt(t, acceptFd, unix.SOL_TCP, unix.TCP_INFO, int32(linux.SizeOfTCPInfo)) + if got, want := len(infoBytes), linux.SizeOfTCPInfo; got != want { + t.Fatalf("expected %T, got %d bytes want %d bytes", info, got, want) + } + binary.Unmarshal(infoBytes, usermem.ByteOrder, &info) return time.Duration(info.RTT) * time.Microsecond, time.Duration(info.RTO) * time.Microsecond } @@ -219,3 +222,200 @@ func TestRACKTLPWithSACK(t *testing.T) { } closeSACKConnection(t, dut, conn, acceptFd, listenFd) } + +// TestRACKWithoutReorder tests that without reordering RACK will retransmit the +// lost packets after reorder timer expires. +func TestRACKWithoutReorder(t *testing.T) { + dut, conn, acceptFd, listenFd := createSACKConnection(t) + seqNum1 := *conn.RemoteSeqNum(t) + + // Send ACK for data packets to establish RTT. 
+ sendAndReceive(t, dut, conn, numPktsForRTT, acceptFd, true /* sendACK */) + seqNum1.UpdateForward(seqnum.Size(numPktsForRTT * payloadSize)) + + // We are not sending ACK for these packets. + const numPkts = 4 + sendAndReceive(t, dut, conn, numPkts, acceptFd, false /* sendACK */) + + // SACK for [3,4] packets. + sackBlock := make([]byte, 40) + start := seqNum1.Add(seqnum.Size(2 * payloadSize)) + end := start.Add(seqnum.Size(2 * payloadSize)) + sbOff := 0 + sbOff += header.EncodeNOP(sackBlock[sbOff:]) + sbOff += header.EncodeNOP(sackBlock[sbOff:]) + sbOff += header.EncodeSACKBlocks([]header.SACKBlock{{ + start, end, + }}, sackBlock[sbOff:]) + time.Sleep(simulatedRTT) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), AckNum: testbench.Uint32(uint32(seqNum1)), Options: sackBlock[:sbOff]}) + + // RACK marks #1 and #2 packets as lost and retransmits both after + // RTT + reorderWindow. The reorderWindow initially will be a small + // fraction of RTT. + rtt, _ := getRTTAndRTO(t, dut, acceptFd) + timeout := 2 * rtt + for i, sn := 0, seqNum1; i < 2; i++ { + if _, err := conn.Expect(t, testbench.TCP{SeqNum: testbench.Uint32(uint32(sn))}, timeout); err != nil { + t.Fatalf("expected payload was not received: %s", err) + } + sn.UpdateForward(seqnum.Size(payloadSize)) + } + closeSACKConnection(t, dut, conn, acceptFd, listenFd) +} + +// TestRACKWithReorder tests that RACK will retransmit segments when there is +// reordering in the connection and reorder timer expires. +func TestRACKWithReorder(t *testing.T) { + dut, conn, acceptFd, listenFd := createSACKConnection(t) + seqNum1 := *conn.RemoteSeqNum(t) + + // Send ACK for data packets to establish RTT. + sendAndReceive(t, dut, conn, numPktsForRTT, acceptFd, true /* sendACK */) + seqNum1.UpdateForward(seqnum.Size(numPktsForRTT * payloadSize)) + + // We are not sending ACK for these packets. 
+ const numPkts = 4 + sendAndReceive(t, dut, conn, numPkts, acceptFd, false /* sendACK */) + + time.Sleep(simulatedRTT) + // SACK in reverse order for the connection to detect reorder. + var start seqnum.Value + var end seqnum.Value + for i := 0; i < numPkts-1; i++ { + sackBlock := make([]byte, 40) + sbOff := 0 + start = seqNum1.Add(seqnum.Size((numPkts - i - 1) * payloadSize)) + end = start.Add(seqnum.Size((i + 1) * payloadSize)) + sackBlock = make([]byte, 40) + sbOff = 0 + sbOff += header.EncodeNOP(sackBlock[sbOff:]) + sbOff += header.EncodeNOP(sackBlock[sbOff:]) + sbOff += header.EncodeSACKBlocks([]header.SACKBlock{{ + start, end, + }}, sackBlock[sbOff:]) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), AckNum: testbench.Uint32(uint32(seqNum1)), Options: sackBlock[:sbOff]}) + } + + // Send a DSACK block indicating both original and retransmitted + // packets are received, RACK will increase the reordering window on + // every DSACK. + dsackBlock := make([]byte, 40) + dbOff := 0 + start = seqNum1 + end = start.Add(seqnum.Size(2 * payloadSize)) + dbOff += header.EncodeNOP(dsackBlock[dbOff:]) + dbOff += header.EncodeNOP(dsackBlock[dbOff:]) + dbOff += header.EncodeSACKBlocks([]header.SACKBlock{{ + start, end, + }}, dsackBlock[dbOff:]) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), AckNum: testbench.Uint32(uint32(seqNum1 + numPkts*payloadSize)), Options: dsackBlock[:dbOff]}) + + seqNum1.UpdateForward(seqnum.Size(numPkts * payloadSize)) + sendTime := time.Now() + sendAndReceive(t, dut, conn, numPkts, acceptFd, false /* sendACK */) + + time.Sleep(simulatedRTT) + // Send SACK for [2-5] packets. 
+ sackBlock := make([]byte, 40) + sbOff := 0 + start = seqNum1.Add(seqnum.Size(payloadSize)) + end = start.Add(seqnum.Size(3 * payloadSize)) + sbOff += header.EncodeNOP(sackBlock[sbOff:]) + sbOff += header.EncodeNOP(sackBlock[sbOff:]) + sbOff += header.EncodeSACKBlocks([]header.SACKBlock{{ + start, end, + }}, sackBlock[sbOff:]) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), AckNum: testbench.Uint32(uint32(seqNum1)), Options: sackBlock[:sbOff]}) + + // Expect the retransmission of #1 packet after RTT+ReorderWindow. + if _, err := conn.Expect(t, testbench.TCP{SeqNum: testbench.Uint32(uint32(seqNum1))}, time.Second); err != nil { + t.Fatalf("expected payload was not received: %s", err) + } + rtt, _ := getRTTAndRTO(t, dut, acceptFd) + diff := time.Now().Sub(sendTime) + if diff < rtt { + t.Fatalf("expected payload was received too sonn, within RTT") + } + + closeSACKConnection(t, dut, conn, acceptFd, listenFd) +} + +// TestRACKWithLostRetransmission tests that RACK will not enter RTO when a +// retransmitted segment is lost and enters fast recovery. +func TestRACKWithLostRetransmission(t *testing.T) { + dut, conn, acceptFd, listenFd := createSACKConnection(t) + seqNum1 := *conn.RemoteSeqNum(t) + + // Send ACK for data packets to establish RTT. + sendAndReceive(t, dut, conn, numPktsForRTT, acceptFd, true /* sendACK */) + seqNum1.UpdateForward(seqnum.Size(numPktsForRTT * payloadSize)) + + // We are not sending ACK for these packets. + const numPkts = 5 + sendAndReceive(t, dut, conn, numPkts, acceptFd, false /* sendACK */) + + // SACK for [2-5] packets. 
+ sackBlock := make([]byte, 40) + start := seqNum1.Add(seqnum.Size(payloadSize)) + end := start.Add(seqnum.Size(4 * payloadSize)) + sbOff := 0 + sbOff += header.EncodeNOP(sackBlock[sbOff:]) + sbOff += header.EncodeNOP(sackBlock[sbOff:]) + sbOff += header.EncodeSACKBlocks([]header.SACKBlock{{ + start, end, + }}, sackBlock[sbOff:]) + time.Sleep(simulatedRTT) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), AckNum: testbench.Uint32(uint32(seqNum1)), Options: sackBlock[:sbOff]}) + + // RACK marks #1 packet as lost and retransmits it after + // RTT + reorderWindow. The reorderWindow is bounded between a small + // fraction of RTT and 1 RTT. + rtt, _ := getRTTAndRTO(t, dut, acceptFd) + timeout := 2 * rtt + if _, err := conn.Expect(t, testbench.TCP{SeqNum: testbench.Uint32(uint32(seqNum1))}, timeout); err != nil { + t.Fatalf("expected payload was not received: %s", err) + } + + // Send #6 packet. + payload := make([]byte, payloadSize) + dut.Send(t, acceptFd, payload, 0) + gotOne, err := conn.Expect(t, testbench.TCP{SeqNum: testbench.Uint32(uint32(seqNum1 + 5*payloadSize))}, time.Second) + if err != nil { + t.Fatalf("Expect #6: %s", err) + } + if gotOne == nil { + t.Fatalf("#6: expected a packet within a second but got none") + } + + // SACK for [2-6] packets. + sackBlock1 := make([]byte, 40) + start = seqNum1.Add(seqnum.Size(payloadSize)) + end = start.Add(seqnum.Size(5 * payloadSize)) + sbOff1 := 0 + sbOff1 += header.EncodeNOP(sackBlock1[sbOff1:]) + sbOff1 += header.EncodeNOP(sackBlock1[sbOff1:]) + sbOff1 += header.EncodeSACKBlocks([]header.SACKBlock{{ + start, end, + }}, sackBlock1[sbOff1:]) + time.Sleep(simulatedRTT) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), AckNum: testbench.Uint32(uint32(seqNum1)), Options: sackBlock1[:sbOff1]}) + + // Expect re-retransmission of #1 packet without entering an RTO. 
+ if _, err := conn.Expect(t, testbench.TCP{SeqNum: testbench.Uint32(uint32(seqNum1))}, timeout); err != nil { + t.Fatalf("expected payload was not received: %s", err) + } + + // Check the congestion control state. + info := linux.TCPInfo{} + infoBytes := dut.GetSockOpt(t, acceptFd, unix.SOL_TCP, unix.TCP_INFO, int32(linux.SizeOfTCPInfo)) + if got, want := len(infoBytes), linux.SizeOfTCPInfo; got != want { + t.Fatalf("expected %T, got %d bytes want %d bytes", info, got, want) + } + binary.Unmarshal(infoBytes, usermem.ByteOrder, &info) + if info.CaState != linux.TCP_CA_Recovery { + t.Fatalf("expected connection to be in fast recovery, want: %v got: %v", linux.TCP_CA_Recovery, info.CaState) + } + + closeSACKConnection(t, dut, conn, acceptFd, listenFd) +} diff --git a/test/packetimpact/tests/tcp_zero_window_probe_retransmit_test.go b/test/packetimpact/tests/tcp_zero_window_probe_retransmit_test.go index 1ab9ee1b2..b15b8fc25 100644 --- a/test/packetimpact/tests/tcp_zero_window_probe_retransmit_test.go +++ b/test/packetimpact/tests/tcp_zero_window_probe_retransmit_test.go @@ -66,33 +66,39 @@ func TestZeroWindowProbeRetransmit(t *testing.T) { probeSeq := testbench.Uint32(uint32(*conn.RemoteSeqNum(t) - 1)) ackProbe := testbench.Uint32(uint32(*conn.RemoteSeqNum(t))) - startProbeDuration := time.Second - current := startProbeDuration - first := time.Now() // Ask the dut to send out data. dut.Send(t, acceptFd, sampleData, 0) + + var prev time.Duration // Expect the dut to keep the connection alive as long as the remote is // acknowledging the zero-window probes. - for i := 0; i < 5; i++ { + for i := 1; i <= 5; i++ { start := time.Now() // Expect zero-window probe with a timeout which is a function of the typical // first retransmission time. The retransmission times is supposed to // exponentially increase. 
- if _, err := conn.ExpectData(t, &testbench.TCP{SeqNum: probeSeq}, nil, 2*current); err != nil { + if _, err := conn.ExpectData(t, &testbench.TCP{SeqNum: probeSeq}, nil, time.Duration(i)*time.Second); err != nil { t.Fatalf("expected a probe with sequence number %d: loop %d", probeSeq, i) } - if i == 0 { - startProbeDuration = time.Now().Sub(first) - current = 2 * startProbeDuration + if i == 1 { + // Skip the first probe as computing transmit time for that is + // non-deterministic because of the arbitrary time taken for + // the dut to receive a send command and issue a send. continue } - // Check if the probes came at exponentially increasing intervals. - if got, want := time.Since(start), current-startProbeDuration; got < want { + + // Check if the time taken to receive the probe from the dut is + // increasing exponentially. To avoid flakes, use a correction + // factor for the expected duration which accounts for any + // scheduling non-determinism. + const timeCorrection = 200 * time.Millisecond + got := time.Since(start) + if want := (2 * prev) - timeCorrection; prev != 0 && got < want { t.Errorf("got zero probe %d after %s, want >= %s", i, got, want) } + prev = got // Acknowledge the zero-window probes from the dut. conn.Send(t, testbench.TCP{AckNum: ackProbe, Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)}) - current *= 2 } // Advertize non-zero window. conn.Send(t, testbench.TCP{AckNum: ackProbe, Flags: testbench.Uint8(header.TCPFlagAck)}) diff --git a/test/packetimpact/tests/udp_recv_mcast_bcast_test.go b/test/packetimpact/tests/udp_recv_mcast_bcast_test.go deleted file mode 100644 index b29c07825..000000000 --- a/test/packetimpact/tests/udp_recv_mcast_bcast_test.go +++ /dev/null @@ -1,115 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package udp_recv_mcast_bcast_test - -import ( - "context" - "flag" - "fmt" - "net" - "syscall" - "testing" - - "github.com/google/go-cmp/cmp" - "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/test/packetimpact/testbench" -) - -func init() { - testbench.Initialize(flag.CommandLine) -} - -func TestUDPRecvMcastBcast(t *testing.T) { - dut := testbench.NewDUT(t) - subnetBcastAddr := broadcastAddr(dut.Net.RemoteIPv4, net.CIDRMask(dut.Net.IPv4PrefixLength, 32)) - for _, v := range []struct { - bound, to net.IP - }{ - {bound: net.IPv4zero, to: subnetBcastAddr}, - {bound: net.IPv4zero, to: net.IPv4bcast}, - {bound: net.IPv4zero, to: net.IPv4allsys}, - - {bound: subnetBcastAddr, to: subnetBcastAddr}, - - // FIXME(gvisor.dev/issue/4896): Previously by the time subnetBcastAddr is - // created, IPv4PrefixLength is still 0 because genPseudoFlags is not called - // yet, it was only called in NewDUT, so the test didn't do what the author - // original intended to and becomes failing because we process all flags at - // the very beginning. 
- // - // {bound: subnetBcastAddr, to: net.IPv4bcast}, - - {bound: net.IPv4bcast, to: net.IPv4bcast}, - {bound: net.IPv4allsys, to: net.IPv4allsys}, - } { - t.Run(fmt.Sprintf("bound=%s,to=%s", v.bound, v.to), func(t *testing.T) { - boundFD, remotePort := dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, v.bound) - defer dut.Close(t, boundFD) - conn := dut.Net.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort}) - defer conn.Close(t) - - payload := testbench.GenerateRandomPayload(t, 1<<10 /* 1 KiB */) - conn.SendIP( - t, - testbench.IPv4{DstAddr: testbench.Address(tcpip.Address(v.to.To4()))}, - testbench.UDP{}, - &testbench.Payload{Bytes: payload}, - ) - got, want := dut.Recv(t, boundFD, int32(len(payload)+1), 0), payload - if diff := cmp.Diff(want, got); diff != "" { - t.Errorf("received payload does not match sent payload, diff (-want, +got):\n%s", diff) - } - }) - } -} - -func TestUDPDoesntRecvMcastBcastOnUnicastAddr(t *testing.T) { - dut := testbench.NewDUT(t) - boundFD, remotePort := dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, dut.Net.RemoteIPv4) - dut.SetSockOptTimeval(t, boundFD, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &unix.Timeval{Sec: 1, Usec: 0}) - defer dut.Close(t, boundFD) - conn := dut.Net.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort}) - defer conn.Close(t) - - for _, to := range []net.IP{ - broadcastAddr(dut.Net.RemoteIPv4, net.CIDRMask(dut.Net.IPv4PrefixLength, 32)), - net.IPv4(255, 255, 255, 255), - net.IPv4(224, 0, 0, 1), - } { - t.Run(fmt.Sprint("to=%s", to), func(t *testing.T) { - payload := testbench.GenerateRandomPayload(t, 1<<10 /* 1 KiB */) - conn.SendIP( - t, - testbench.IPv4{DstAddr: testbench.Address(tcpip.Address(to.To4()))}, - testbench.UDP{}, - &testbench.Payload{Bytes: payload}, - ) - ret, payload, errno := dut.RecvWithErrno(context.Background(), t, boundFD, 100, 0) - if errno != syscall.EAGAIN || errno != syscall.EWOULDBLOCK { - 
t.Errorf("Recv got unexpected result, ret=%d, payload=%q, errno=%s", ret, payload, errno) - } - }) - } -} - -func broadcastAddr(ip net.IP, mask net.IPMask) net.IP { - result := make(net.IP, net.IPv4len) - ip4 := ip.To4() - for i := range ip4 { - result[i] = ip4[i] | ^mask[i] - } - return result -} diff --git a/test/packetimpact/tests/udp_send_recv_dgram_test.go b/test/packetimpact/tests/udp_send_recv_dgram_test.go index 7ee2c8014..6e45cb143 100644 --- a/test/packetimpact/tests/udp_send_recv_dgram_test.go +++ b/test/packetimpact/tests/udp_send_recv_dgram_test.go @@ -15,13 +15,18 @@ package udp_send_recv_dgram_test import ( + "context" "flag" + "fmt" "net" + "syscall" "testing" "time" "github.com/google/go-cmp/cmp" "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/test/packetimpact/testbench" ) @@ -30,74 +35,295 @@ func init() { } type udpConn interface { - Send(*testing.T, testbench.UDP, ...testbench.Layer) - ExpectData(*testing.T, testbench.UDP, testbench.Payload, time.Duration) (testbench.Layers, error) - Drain(*testing.T) + SrcPort(*testing.T) uint16 + SendFrame(*testing.T, testbench.Layers, ...testbench.Layer) + ExpectFrame(*testing.T, testbench.Layers, time.Duration) (testbench.Layers, error) Close(*testing.T) } +type testCase struct { + bindTo, sendTo net.IP + sendToBroadcast, bindToDevice, expectData bool +} + func TestUDP(t *testing.T) { dut := testbench.NewDUT(t) + subnetBcast := func() net.IP { + subnet := (&tcpip.AddressWithPrefix{ + Address: tcpip.Address(dut.Net.RemoteIPv4.To4()), + PrefixLen: dut.Net.IPv4PrefixLength, + }).Subnet() + return net.IP(subnet.Broadcast()) + }() - for _, isIPv4 := range []bool{true, false} { - ipVersionName := "IPv6" - if isIPv4 { - ipVersionName = "IPv4" - } - t.Run(ipVersionName, func(t *testing.T) { - var addr net.IP - if isIPv4 { - addr = dut.Net.RemoteIPv4 - } else { - addr = dut.Net.RemoteIPv6 + t.Run("Send", func(t *testing.T) { + var testCases []testCase 
+ // Test every valid combination of bound/unbound, broadcast/multicast/unicast + // bound/destination address, and bound/not-bound to device. + for _, bindTo := range []net.IP{ + nil, // Do not bind. + net.IPv4zero, + net.IPv4bcast, + net.IPv4allsys, + subnetBcast, + dut.Net.RemoteIPv4, + dut.Net.RemoteIPv6, + } { + for _, sendTo := range []net.IP{ + net.IPv4bcast, + net.IPv4allsys, + subnetBcast, + dut.Net.LocalIPv4, + dut.Net.LocalIPv6, + } { + // Cannot send to an IPv4 address from a socket bound to IPv6 (except for IPv4-mapped IPv6), + // and viceversa. + if bindTo != nil && ((bindTo.To4() == nil) != (sendTo.To4() == nil)) { + continue + } + for _, bindToDevice := range []bool{true, false} { + expectData := true + switch { + case bindTo.Equal(dut.Net.RemoteIPv4): + // If we're explicitly bound to an interface's unicast address, + // packets are always sent on that interface. + case bindToDevice: + // If we're explicitly bound to an interface, packets are always + // sent on that interface. + case !sendTo.Equal(net.IPv4bcast) && !sendTo.IsMulticast(): + // If we're not sending to limited broadcast or multicast, the route table + // will be consulted and packets will be sent on the correct interface. 
+ default: + expectData = false + } + testCases = append( + testCases, + testCase{ + bindTo: bindTo, + sendTo: sendTo, + sendToBroadcast: sendTo.Equal(subnetBcast) || sendTo.Equal(net.IPv4bcast), + bindToDevice: bindToDevice, + expectData: expectData, + }, + ) + } } - boundFD, remotePort := dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, addr) - defer dut.Close(t, boundFD) - - var conn udpConn - var localAddr unix.Sockaddr - if isIPv4 { - v4Conn := dut.Net.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort}) - localAddr = v4Conn.LocalAddr(t) - conn = &v4Conn - } else { - v6Conn := dut.Net.NewUDPIPv6(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort}) - localAddr = v6Conn.LocalAddr(t, dut.Net.RemoteDevID) - conn = &v6Conn - } - defer conn.Close(t) - - testCases := []struct { - name string - payload []byte - }{ - {"emptypayload", nil}, - {"small payload", []byte("hello world")}, - {"1kPayload", testbench.GenerateRandomPayload(t, 1<<10)}, - // Even though UDP allows larger dgrams we don't test it here as - // they need to be fragmented and written out as individual - // frames. 
+ } + for _, tc := range testCases { + boundTestCaseName := "unbound" + if tc.bindTo != nil { + boundTestCaseName = fmt.Sprintf("bindTo=%s", tc.bindTo) } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - t.Run("Send", func(t *testing.T) { - conn.Send(t, testbench.UDP{}, &testbench.Payload{Bytes: tc.payload}) - got, want := dut.Recv(t, boundFD, int32(len(tc.payload)+1), 0), tc.payload - if diff := cmp.Diff(want, got); diff != "" { - t.Fatalf("received payload does not match sent payload, diff (-want, +got):\n%s", diff) + t.Run(fmt.Sprintf("%s/sendTo=%s/bindToDevice=%t/expectData=%t", boundTestCaseName, tc.sendTo, tc.bindToDevice, tc.expectData), func(t *testing.T) { + runTestCase( + t, + dut, + tc, + func(t *testing.T, dut testbench.DUT, conn udpConn, socketFD int32, tc testCase, payload []byte, layers testbench.Layers) { + var destSockaddr unix.Sockaddr + if sendTo4 := tc.sendTo.To4(); sendTo4 != nil { + addr := unix.SockaddrInet4{ + Port: int(conn.SrcPort(t)), + } + copy(addr.Addr[:], sendTo4) + destSockaddr = &addr + } else { + addr := unix.SockaddrInet6{ + Port: int(conn.SrcPort(t)), + ZoneId: dut.Net.RemoteDevID, + } + copy(addr.Addr[:], tc.sendTo.To16()) + destSockaddr = &addr } - }) - t.Run("Recv", func(t *testing.T) { - conn.Drain(t) - if got, want := int(dut.SendTo(t, boundFD, tc.payload, 0, localAddr)), len(tc.payload); got != want { - t.Fatalf("short write got: %d, want: %d", got, want) + if got, want := dut.SendTo(t, socketFD, payload, 0, destSockaddr), len(payload); int(got) != want { + t.Fatalf("got dut.SendTo = %d, want %d", got, want) } - if _, err := conn.ExpectData(t, testbench.UDP{SrcPort: &remotePort}, testbench.Payload{Bytes: tc.payload}, time.Second); err != nil { + layers = append(layers, &testbench.Payload{ + Bytes: payload, + }) + _, err := conn.ExpectFrame(t, layers, time.Second) + + if !tc.expectData && err == nil { + t.Fatal("received unexpected packet, socket is not bound to device") + } + if err != nil && 
tc.expectData { t.Fatal(err) } - }) - }) + }, + ) + }) + } + }) + t.Run("Recv", func(t *testing.T) { + // Test every valid combination of broadcast/multicast/unicast + // bound/destination address, and bound/not-bound to device. + var testCases []testCase + for _, addr := range []net.IP{ + net.IPv4bcast, + net.IPv4allsys, + dut.Net.RemoteIPv4, + dut.Net.RemoteIPv6, + } { + for _, bindToDevice := range []bool{true, false} { + testCases = append( + testCases, + testCase{ + bindTo: addr, + sendTo: addr, + sendToBroadcast: addr.Equal(subnetBcast) || addr.Equal(net.IPv4bcast), + bindToDevice: bindToDevice, + expectData: true, + }, + ) } - }) + } + for _, bindTo := range []net.IP{ + net.IPv4zero, + subnetBcast, + dut.Net.RemoteIPv4, + } { + for _, sendTo := range []net.IP{ + subnetBcast, + net.IPv4bcast, + net.IPv4allsys, + } { + // TODO(gvisor.dev/issue/4896): Add bindTo=subnetBcast/sendTo=IPv4bcast + // and bindTo=subnetBcast/sendTo=IPv4allsys test cases. + if bindTo.Equal(subnetBcast) && (sendTo.Equal(net.IPv4bcast) || sendTo.IsMulticast()) { + continue + } + // Expect that a socket bound to a unicast address does not receive + // packets sent to an address other than the bound unicast address. + // + // Note: we cannot use net.IP.IsGlobalUnicast to test this condition + // because IsGlobalUnicast does not check whether the address is the + // subnet broadcast, and returns true in that case. 
+ expectData := !bindTo.Equal(dut.Net.RemoteIPv4) || sendTo.Equal(dut.Net.RemoteIPv4) + for _, bindToDevice := range []bool{true, false} { + testCases = append( + testCases, + testCase{ + bindTo: bindTo, + sendTo: sendTo, + sendToBroadcast: sendTo.Equal(subnetBcast) || sendTo.Equal(net.IPv4bcast), + bindToDevice: bindToDevice, + expectData: expectData, + }, + ) + } + } + } + for _, tc := range testCases { + t.Run(fmt.Sprintf("bindTo=%s/sendTo=%s/bindToDevice=%t/expectData=%t", tc.bindTo, tc.sendTo, tc.bindToDevice, tc.expectData), func(t *testing.T) { + runTestCase( + t, + dut, + tc, + func(t *testing.T, dut testbench.DUT, conn udpConn, socketFD int32, tc testCase, payload []byte, layers testbench.Layers) { + conn.SendFrame(t, layers, &testbench.Payload{Bytes: payload}) + + if tc.expectData { + got, want := dut.Recv(t, socketFD, int32(len(payload)+1), 0), payload + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("received payload does not match sent payload, diff (-want, +got):\n%s", diff) + } + } else { + // Expected receive error, set a short receive timeout. 
+ dut.SetSockOptTimeval( + t, + socketFD, + unix.SOL_SOCKET, + unix.SO_RCVTIMEO, + &unix.Timeval{ + Sec: 1, + Usec: 0, + }, + ) + ret, recvPayload, errno := dut.RecvWithErrno(context.Background(), t, socketFD, 100, 0) + if errno != syscall.EAGAIN || errno != syscall.EWOULDBLOCK { + t.Errorf("Recv got unexpected result, ret=%d, payload=%q, errno=%s", ret, recvPayload, errno) + } + } + }, + ) + }) + } + }) +} + +func runTestCase( + t *testing.T, + dut testbench.DUT, + tc testCase, + runTc func(t *testing.T, dut testbench.DUT, conn udpConn, socketFD int32, tc testCase, payload []byte, layers testbench.Layers), +) { + var ( + socketFD int32 + outgoingUDP, incomingUDP testbench.UDP + ) + if tc.bindTo != nil { + var remotePort uint16 + socketFD, remotePort = dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, tc.bindTo) + outgoingUDP.DstPort = &remotePort + incomingUDP.SrcPort = &remotePort + } else { + // An unbound socket will auto-bind to INADDR_ANY and a random + // port on sendto. + socketFD = dut.Socket(t, unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_UDP) + } + defer dut.Close(t, socketFD) + if tc.bindToDevice { + dut.SetSockOpt(t, socketFD, unix.SOL_SOCKET, unix.SO_BINDTODEVICE, []byte(dut.Net.RemoteDevName)) + } + + var ethernetLayer testbench.Ether + if tc.sendToBroadcast { + dut.SetSockOptInt(t, socketFD, unix.SOL_SOCKET, unix.SO_BROADCAST, 1) + + // When sending to broadcast (subnet or limited), the expected ethernet + // address is also broadcast. 
+ ethernetBroadcastAddress := header.EthernetBroadcastAddress + ethernetLayer.DstAddr = &ethernetBroadcastAddress + } else if tc.sendTo.IsMulticast() { + ethernetMulticastAddress := header.EthernetAddressFromMulticastIPv4Address(tcpip.Address(tc.sendTo.To4())) + ethernetLayer.DstAddr = &ethernetMulticastAddress + } + expectedLayers := testbench.Layers{&ethernetLayer} + + var conn udpConn + if sendTo4 := tc.sendTo.To4(); sendTo4 != nil { + v4Conn := dut.Net.NewUDPIPv4(t, outgoingUDP, incomingUDP) + conn = &v4Conn + expectedLayers = append( + expectedLayers, + &testbench.IPv4{ + DstAddr: testbench.Address(tcpip.Address(sendTo4)), + }, + ) + } else { + v6Conn := dut.Net.NewUDPIPv6(t, outgoingUDP, incomingUDP) + conn = &v6Conn + expectedLayers = append( + expectedLayers, + &testbench.IPv6{ + DstAddr: testbench.Address(tcpip.Address(tc.sendTo)), + }, + ) + } + defer conn.Close(t) + + expectedLayers = append(expectedLayers, &incomingUDP) + for _, v := range []struct { + name string + payload []byte + }{ + {"emptypayload", nil}, + {"small payload", []byte("hello world")}, + {"1kPayload", testbench.GenerateRandomPayload(t, 1<<10)}, + // Even though UDP allows larger dgrams we don't test it here as + // they need to be fragmented and written out as individual + // frames. 
+ } { + runTc(t, dut, conn, socketFD, tc, v.payload, expectedLayers) } } diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD index 6ee2b73c1..e43f30ba3 100644 --- a/test/syscalls/BUILD +++ b/test/syscalls/BUILD @@ -557,6 +557,10 @@ syscall_test( ) syscall_test( + test = "//test/syscalls/linux:setgid_test", +) + +syscall_test( add_overlay = True, test = "//test/syscalls/linux:splice_test", ) diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index 0da295e2d..80e2837f8 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -482,7 +482,9 @@ cc_binary( "//test/util:fs_util", "@com_google_absl//absl/strings", gtest, + "//test/util:logging", "//test/util:mount_util", + "//test/util:multiprocess_util", "//test/util:temp_path", "//test/util:test_main", "//test/util:test_util", @@ -582,6 +584,7 @@ cc_binary( "//test/util:eventfd_util", "//test/util:file_descriptor", gtest, + "//test/util:fs_util", "//test/util:posix_error", "//test/util:temp_path", "//test/util:test_main", @@ -671,6 +674,7 @@ cc_binary( gtest, "//test/util:logging", "//test/util:memory_util", + "//test/util:multiprocess_util", "//test/util:signal_util", "//test/util:test_main", "//test/util:test_util", @@ -1379,6 +1383,7 @@ cc_binary( "//test/util:file_descriptor", "//test/util:fs_util", gtest, + "//test/util:posix_error", "//test/util:temp_path", "//test/util:temp_umask", "//test/util:test_main", @@ -2142,6 +2147,24 @@ cc_binary( ) cc_binary( + name = "setgid_test", + testonly = 1, + srcs = ["setgid.cc"], + linkstatic = 1, + deps = [ + "//test/util:capability_util", + "//test/util:cleanup", + "//test/util:fs_util", + "//test/util:posix_error", + "//test/util:temp_path", + "//test/util:test_main", + "//test/util:test_util", + "@com_google_absl//absl/strings", + gtest, + ], +) + +cc_binary( name = "splice_test", testonly = 1, srcs = ["splice.cc"], @@ -3826,6 +3849,8 @@ cc_binary( "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/strings", gtest, + 
"//test/util:cleanup", + "//test/util:multiprocess_util", "//test/util:posix_error", "//test/util:test_main", "//test/util:test_util", @@ -4081,6 +4106,7 @@ cc_binary( "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", gtest, + "//test/util:cleanup", "//test/util:test_main", "//test/util:test_util", ], diff --git a/test/syscalls/linux/chmod.cc b/test/syscalls/linux/chmod.cc index a06b5cfd6..8233df0f8 100644 --- a/test/syscalls/linux/chmod.cc +++ b/test/syscalls/linux/chmod.cc @@ -98,6 +98,42 @@ TEST(ChmodTest, FchmodatBadF) { ASSERT_THAT(fchmodat(-1, "foo", 0444, 0), SyscallFailsWithErrno(EBADF)); } +TEST(ChmodTest, FchmodFileWithOpath) { + SKIP_IF(IsRunningWithVFS1()); + auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_PATH)); + + ASSERT_THAT(fchmod(fd.get(), 0444), SyscallFailsWithErrno(EBADF)); +} + +TEST(ChmodTest, FchmodDirWithOpath) { + SKIP_IF(IsRunningWithVFS1()); + const auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + const auto fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path(), O_DIRECTORY | O_PATH)); + + ASSERT_THAT(fchmod(fd.get(), 0444), SyscallFailsWithErrno(EBADF)); +} + +TEST(ChmodTest, FchmodatWithOpath) { + SKIP_IF(IsRunningWithVFS1()); + // Drop capabilities that allow us to override file permissions. 
+ ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); + + auto temp_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + + const auto parent_fd = ASSERT_NO_ERRNO_AND_VALUE( + Open(GetAbsoluteTestTmpdir().c_str(), O_PATH | O_DIRECTORY)); + + ASSERT_THAT( + fchmodat(parent_fd.get(), std::string(Basename(temp_file.path())).c_str(), + 0444, 0), + SyscallSucceeds()); + + EXPECT_THAT(open(temp_file.path().c_str(), O_RDWR), + SyscallFailsWithErrno(EACCES)); +} + TEST(ChmodTest, FchmodatNotDir) { ASSERT_THAT(fchmodat(-1, "", 0444, 0), SyscallFailsWithErrno(ENOENT)); } diff --git a/test/syscalls/linux/chown.cc b/test/syscalls/linux/chown.cc index 5530ad18f..ff0d39343 100644 --- a/test/syscalls/linux/chown.cc +++ b/test/syscalls/linux/chown.cc @@ -48,6 +48,36 @@ TEST(ChownTest, FchownatBadF) { ASSERT_THAT(fchownat(-1, "fff", 0, 0, 0), SyscallFailsWithErrno(EBADF)); } +TEST(ChownTest, FchownFileWithOpath) { + SKIP_IF(IsRunningWithVFS1()); + auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_PATH)); + + ASSERT_THAT(fchown(fd.get(), geteuid(), getegid()), + SyscallFailsWithErrno(EBADF)); +} + +TEST(ChownTest, FchownDirWithOpath) { + SKIP_IF(IsRunningWithVFS1()); + const auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + const auto fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path(), O_DIRECTORY | O_PATH)); + + ASSERT_THAT(fchown(fd.get(), geteuid(), getegid()), + SyscallFailsWithErrno(EBADF)); +} + +TEST(ChownTest, FchownatWithOpath) { + SKIP_IF(IsRunningWithVFS1()); + const auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir.path())); + const auto dirfd = + ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path(), O_DIRECTORY | O_PATH)); + ASSERT_THAT( + fchownat(dirfd.get(), file.path().c_str(), geteuid(), getegid(), 0), + SyscallSucceeds()); +} + TEST(ChownTest, FchownatEmptyPath) { const auto dir = 
ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); const auto fd = @@ -209,6 +239,14 @@ INSTANTIATE_TEST_SUITE_P( owner, group, 0); MaybeSave(); return errorFromReturn("fchownat-dirfd", rc); + }, + [](const std::string& path, uid_t owner, gid_t group) -> PosixError { + ASSIGN_OR_RETURN_ERRNO(auto dirfd, Open(std::string(Dirname(path)), + O_DIRECTORY | O_PATH)); + int rc = fchownat(dirfd.get(), std::string(Basename(path)).c_str(), + owner, group, 0); + MaybeSave(); + return errorFromReturn("fchownat-opathdirfd", rc); })); } // namespace diff --git a/test/syscalls/linux/chroot.cc b/test/syscalls/linux/chroot.cc index 85ec013d5..fab79d300 100644 --- a/test/syscalls/linux/chroot.cc +++ b/test/syscalls/linux/chroot.cc @@ -32,7 +32,9 @@ #include "test/util/cleanup.h" #include "test/util/file_descriptor.h" #include "test/util/fs_util.h" +#include "test/util/logging.h" #include "test/util/mount_util.h" +#include "test/util/multiprocess_util.h" #include "test/util/temp_path.h" #include "test/util/test_util.h" @@ -47,17 +49,20 @@ namespace { TEST(ChrootTest, Success) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_CHROOT))); - auto temp_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - EXPECT_THAT(chroot(temp_dir.path().c_str()), SyscallSucceeds()); + const auto rest = [] { + auto temp_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + TEST_CHECK_SUCCESS(chroot(temp_dir.path().c_str())); + }; + EXPECT_THAT(InForkedProcess(rest), IsPosixErrorOkAndHolds(0)); } TEST(ChrootTest, PermissionDenied) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_CHROOT))); - // CAP_DAC_READ_SEARCH and CAP_DAC_OVERRIDE may override Execute permission on - // directories. - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); - ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); + // CAP_DAC_READ_SEARCH and CAP_DAC_OVERRIDE may override Execute permission + // on directories. 
+ AutoCapability cap_search(CAP_DAC_READ_SEARCH, false); + AutoCapability cap_override(CAP_DAC_OVERRIDE, false); auto temp_dir = ASSERT_NO_ERRNO_AND_VALUE( TempPath::CreateDirWith(GetAbsoluteTestTmpdir(), 0666 /* mode */)); @@ -78,8 +83,10 @@ TEST(ChrootTest, NotExist) { } TEST(ChrootTest, WithoutCapability) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SETPCAP))); + // Unset CAP_SYS_CHROOT. - ASSERT_NO_ERRNO(SetCapability(CAP_SYS_CHROOT, false)); + AutoCapability cap(CAP_SYS_CHROOT, false); auto temp_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); EXPECT_THAT(chroot(temp_dir.path().c_str()), SyscallFailsWithErrno(EPERM)); @@ -97,51 +104,53 @@ TEST(ChrootTest, CreatesNewRoot) { auto file_in_new_root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(new_root.path())); - // chroot into new_root. - ASSERT_THAT(chroot(new_root.path().c_str()), SyscallSucceeds()); - - // getcwd should return "(unreachable)" followed by the initial_cwd. - char cwd[1024]; - ASSERT_THAT(syscall(__NR_getcwd, cwd, sizeof(cwd)), SyscallSucceeds()); - std::string expected_cwd = "(unreachable)"; - expected_cwd += initial_cwd; - EXPECT_STREQ(cwd, expected_cwd.c_str()); - - // Should not be able to stat file by its full path. - struct stat statbuf; - EXPECT_THAT(stat(file_in_new_root.path().c_str(), &statbuf), - SyscallFailsWithErrno(ENOENT)); - - // Should be able to stat file at new rooted path. - auto basename = std::string(Basename(file_in_new_root.path())); - auto rootedFile = "/" + basename; - ASSERT_THAT(stat(rootedFile.c_str(), &statbuf), SyscallSucceeds()); - - // Should be able to stat cwd at '.' even though it's outside root. - ASSERT_THAT(stat(".", &statbuf), SyscallSucceeds()); - - // chdir into new root. - ASSERT_THAT(chdir("/"), SyscallSucceeds()); - - // getcwd should return "/". - EXPECT_THAT(syscall(__NR_getcwd, cwd, sizeof(cwd)), SyscallSucceeds()); - EXPECT_STREQ(cwd, "/"); - - // Statting '.', '..', '/', and '/..' all return the same dev and inode. 
- struct stat statbuf_dot; - ASSERT_THAT(stat(".", &statbuf_dot), SyscallSucceeds()); - struct stat statbuf_dotdot; - ASSERT_THAT(stat("..", &statbuf_dotdot), SyscallSucceeds()); - EXPECT_EQ(statbuf_dot.st_dev, statbuf_dotdot.st_dev); - EXPECT_EQ(statbuf_dot.st_ino, statbuf_dotdot.st_ino); - struct stat statbuf_slash; - ASSERT_THAT(stat("/", &statbuf_slash), SyscallSucceeds()); - EXPECT_EQ(statbuf_dot.st_dev, statbuf_slash.st_dev); - EXPECT_EQ(statbuf_dot.st_ino, statbuf_slash.st_ino); - struct stat statbuf_slashdotdot; - ASSERT_THAT(stat("/..", &statbuf_slashdotdot), SyscallSucceeds()); - EXPECT_EQ(statbuf_dot.st_dev, statbuf_slashdotdot.st_dev); - EXPECT_EQ(statbuf_dot.st_ino, statbuf_slashdotdot.st_ino); + const auto rest = [&] { + // chroot into new_root. + TEST_CHECK_SUCCESS(chroot(new_root.path().c_str())); + + // getcwd should return "(unreachable)" followed by the initial_cwd. + char cwd[1024]; + TEST_CHECK_SUCCESS(syscall(__NR_getcwd, cwd, sizeof(cwd))); + std::string expected_cwd = "(unreachable)"; + expected_cwd += initial_cwd; + TEST_CHECK(strcmp(cwd, expected_cwd.c_str()) == 0); + + // Should not be able to stat file by its full path. + struct stat statbuf; + TEST_CHECK_ERRNO(stat(file_in_new_root.path().c_str(), &statbuf), ENOENT); + + // Should be able to stat file at new rooted path. + auto basename = std::string(Basename(file_in_new_root.path())); + auto rootedFile = "/" + basename; + TEST_CHECK_SUCCESS(stat(rootedFile.c_str(), &statbuf)); + + // Should be able to stat cwd at '.' even though it's outside root. + TEST_CHECK_SUCCESS(stat(".", &statbuf)); + + // chdir into new root. + TEST_CHECK_SUCCESS(chdir("/")); + + // getcwd should return "/". + TEST_CHECK_SUCCESS(syscall(__NR_getcwd, cwd, sizeof(cwd))); + TEST_CHECK_SUCCESS(strcmp(cwd, "/") == 0); + + // Statting '.', '..', '/', and '/..' all return the same dev and inode. 
+ struct stat statbuf_dot; + TEST_CHECK_SUCCESS(stat(".", &statbuf_dot)); + struct stat statbuf_dotdot; + TEST_CHECK_SUCCESS(stat("..", &statbuf_dotdot)); + TEST_CHECK(statbuf_dot.st_dev == statbuf_dotdot.st_dev); + TEST_CHECK(statbuf_dot.st_ino == statbuf_dotdot.st_ino); + struct stat statbuf_slash; + TEST_CHECK_SUCCESS(stat("/", &statbuf_slash)); + TEST_CHECK(statbuf_dot.st_dev == statbuf_slash.st_dev); + TEST_CHECK(statbuf_dot.st_ino == statbuf_slash.st_ino); + struct stat statbuf_slashdotdot; + TEST_CHECK_SUCCESS(stat("/..", &statbuf_slashdotdot)); + TEST_CHECK(statbuf_dot.st_dev == statbuf_slashdotdot.st_dev); + TEST_CHECK(statbuf_dot.st_ino == statbuf_slashdotdot.st_ino); + }; + EXPECT_THAT(InForkedProcess(rest), IsPosixErrorOkAndHolds(0)); } TEST(ChrootTest, DotDotFromOpenFD) { @@ -152,18 +161,20 @@ TEST(ChrootTest, DotDotFromOpenFD) { Open(dir_outside_root.path(), O_RDONLY | O_DIRECTORY)); auto new_root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - // chroot into new_root. - ASSERT_THAT(chroot(new_root.path().c_str()), SyscallSucceeds()); + const auto rest = [&] { + // chroot into new_root. + TEST_CHECK_SUCCESS(chroot(new_root.path().c_str())); - // openat on fd with path .. will succeed. - int other_fd; - ASSERT_THAT(other_fd = openat(fd.get(), "..", O_RDONLY), SyscallSucceeds()); - EXPECT_THAT(close(other_fd), SyscallSucceeds()); + // openat on fd with path .. will succeed. + int other_fd; + TEST_CHECK_SUCCESS(other_fd = openat(fd.get(), "..", O_RDONLY)); + TEST_CHECK_SUCCESS(close(other_fd)); - // getdents on fd should not error. - char buf[1024]; - ASSERT_THAT(syscall(SYS_getdents64, fd.get(), buf, sizeof(buf)), - SyscallSucceeds()); + // getdents on fd should not error. 
+ char buf[1024]; + TEST_CHECK_SUCCESS(syscall(SYS_getdents64, fd.get(), buf, sizeof(buf))); + }; + EXPECT_THAT(InForkedProcess(rest), IsPosixErrorOkAndHolds(0)); } // Test that link resolution in a chroot can escape the root by following an @@ -179,24 +190,27 @@ TEST(ChrootTest, ProcFdLinkResolutionInChroot) { const FileDescriptor proc_fd = ASSERT_NO_ERRNO_AND_VALUE( Open("/proc", O_DIRECTORY | O_RDONLY | O_CLOEXEC)); - auto temp_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - ASSERT_THAT(chroot(temp_dir.path().c_str()), SyscallSucceeds()); - - // Opening relative to an already open fd to a node outside the chroot works. - const FileDescriptor proc_self_fd = ASSERT_NO_ERRNO_AND_VALUE( - OpenAt(proc_fd.get(), "self/fd", O_DIRECTORY | O_RDONLY | O_CLOEXEC)); - - // Proc fd symlinks can escape the chroot if the fd the symlink refers to - // refers to an object outside the chroot. - struct stat s = {}; - EXPECT_THAT( - fstatat(proc_self_fd.get(), absl::StrCat(fd.get()).c_str(), &s, 0), - SyscallSucceeds()); - - // Try to stat the stdin fd. Internally, this is handled differently from a - // proc fd entry pointing to a file, since stdin is backed by a host fd, and - // isn't a walkable path on the filesystem inside the sandbox. - EXPECT_THAT(fstatat(proc_self_fd.get(), "0", &s, 0), SyscallSucceeds()); + const auto rest = [&] { + auto temp_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + TEST_CHECK_SUCCESS(chroot(temp_dir.path().c_str())); + + // Opening relative to an already open fd to a node outside the chroot + // works. + const FileDescriptor proc_self_fd = TEST_CHECK_NO_ERRNO_AND_VALUE( + OpenAt(proc_fd.get(), "self/fd", O_DIRECTORY | O_RDONLY | O_CLOEXEC)); + + // Proc fd symlinks can escape the chroot if the fd the symlink refers to + // refers to an object outside the chroot. + struct stat s = {}; + TEST_CHECK_SUCCESS( + fstatat(proc_self_fd.get(), absl::StrCat(fd.get()).c_str(), &s, 0)); + + // Try to stat the stdin fd. 
Internally, this is handled differently from a + // proc fd entry pointing to a file, since stdin is backed by a host fd, and + // isn't a walkable path on the filesystem inside the sandbox. + TEST_CHECK_SUCCESS(fstatat(proc_self_fd.get(), "0", &s, 0)); + }; + EXPECT_THAT(InForkedProcess(rest), IsPosixErrorOkAndHolds(0)); } // This test will verify that when you hold a fd to proc before entering @@ -209,28 +223,30 @@ TEST(ChrootTest, ProcMemSelfFdsNoEscapeProcOpen) { const FileDescriptor proc = ASSERT_NO_ERRNO_AND_VALUE(Open("/proc", O_RDONLY)); - // Create and enter a chroot directory. - const auto temp_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - ASSERT_THAT(chroot(temp_dir.path().c_str()), SyscallSucceeds()); - - // Open a file inside the chroot at /foo. - const FileDescriptor foo = - ASSERT_NO_ERRNO_AND_VALUE(Open("/foo", O_CREAT | O_RDONLY, 0644)); - - // Examine /proc/self/fd/{foo_fd} to see if it exposes the fact that we're - // inside a chroot, the path should be /foo and NOT {chroot_dir}/foo. - const std::string fd_path = absl::StrCat("self/fd/", foo.get()); - char buf[1024] = {}; - size_t bytes_read = 0; - ASSERT_THAT(bytes_read = - readlinkat(proc.get(), fd_path.c_str(), buf, sizeof(buf) - 1), - SyscallSucceeds()); - - // The link should resolve to something. - ASSERT_GT(bytes_read, 0); - - // Assert that the link doesn't contain the chroot path and is only /foo. - EXPECT_STREQ(buf, "/foo"); + const auto rest = [&] { + // Create and enter a chroot directory. + const auto temp_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + TEST_CHECK_SUCCESS(chroot(temp_dir.path().c_str())); + + // Open a file inside the chroot at /foo. + const FileDescriptor foo = + TEST_CHECK_NO_ERRNO_AND_VALUE(Open("/foo", O_CREAT | O_RDONLY, 0644)); + + // Examine /proc/self/fd/{foo_fd} to see if it exposes the fact that we're + // inside a chroot, the path should be /foo and NOT {chroot_dir}/foo. 
+ const std::string fd_path = absl::StrCat("self/fd/", foo.get()); + char buf[1024] = {}; + size_t bytes_read = 0; + TEST_CHECK_SUCCESS(bytes_read = readlinkat(proc.get(), fd_path.c_str(), buf, + sizeof(buf) - 1)); + + // The link should resolve to something. + TEST_CHECK(bytes_read > 0); + + // Assert that the link doesn't contain the chroot path and is only /foo. + TEST_CHECK(strcmp(buf, "/foo") == 0); + }; + EXPECT_THAT(InForkedProcess(rest), IsPosixErrorOkAndHolds(0)); } // This test will verify that a file inside a chroot when mmapped will not @@ -242,39 +258,41 @@ TEST(ChrootTest, ProcMemSelfMapsNoEscapeProcOpen) { const FileDescriptor proc = ASSERT_NO_ERRNO_AND_VALUE(Open("/proc", O_RDONLY)); - // Create and enter a chroot directory. - const auto temp_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - ASSERT_THAT(chroot(temp_dir.path().c_str()), SyscallSucceeds()); - - // Open a file inside the chroot at /foo. - const FileDescriptor foo = - ASSERT_NO_ERRNO_AND_VALUE(Open("/foo", O_CREAT | O_RDONLY, 0644)); - - // Mmap the newly created file. - void* foo_map = mmap(nullptr, kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE, - foo.get(), 0); - ASSERT_THAT(reinterpret_cast<int64_t>(foo_map), SyscallSucceeds()); - - // Always unmap. - auto cleanup_map = Cleanup( - [&] { EXPECT_THAT(munmap(foo_map, kPageSize), SyscallSucceeds()); }); - - // Examine /proc/self/maps to be sure that /foo doesn't appear to be - // mapped with the full chroot path. - const FileDescriptor maps = - ASSERT_NO_ERRNO_AND_VALUE(OpenAt(proc.get(), "self/maps", O_RDONLY)); - - size_t bytes_read = 0; - char buf[8 * 1024] = {}; - ASSERT_THAT(bytes_read = ReadFd(maps.get(), buf, sizeof(buf)), - SyscallSucceeds()); - - // The maps file should have something. 
- ASSERT_GT(bytes_read, 0); - - // Finally we want to make sure the maps don't contain the chroot path - ASSERT_EQ(std::string(buf, bytes_read).find(temp_dir.path()), - std::string::npos); + const auto rest = [&] { + // Create and enter a chroot directory. + const auto temp_dir = TEST_CHECK_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + TEST_CHECK_SUCCESS(chroot(temp_dir.path().c_str())); + + // Open a file inside the chroot at /foo. + const FileDescriptor foo = + TEST_CHECK_NO_ERRNO_AND_VALUE(Open("/foo", O_CREAT | O_RDONLY, 0644)); + + // Mmap the newly created file. + void* foo_map = mmap(nullptr, kPageSize, PROT_READ | PROT_WRITE, + MAP_PRIVATE, foo.get(), 0); + TEST_CHECK_SUCCESS(reinterpret_cast<int64_t>(foo_map)); + + // Always unmap. + auto cleanup_map = + Cleanup([&] { TEST_CHECK_SUCCESS(munmap(foo_map, kPageSize)); }); + + // Examine /proc/self/maps to be sure that /foo doesn't appear to be + // mapped with the full chroot path. + const FileDescriptor maps = TEST_CHECK_NO_ERRNO_AND_VALUE( + OpenAt(proc.get(), "self/maps", O_RDONLY)); + + size_t bytes_read = 0; + char buf[8 * 1024] = {}; + TEST_CHECK_SUCCESS(bytes_read = ReadFd(maps.get(), buf, sizeof(buf))); + + // The maps file should have something. + TEST_CHECK(bytes_read > 0); + + // Finally we want to make sure the maps don't contain the chroot path + TEST_CHECK(std::string(buf, bytes_read).find(temp_dir.path()) == + std::string::npos); + }; + EXPECT_THAT(InForkedProcess(rest), IsPosixErrorOkAndHolds(0)); } // Test that mounts outside the chroot will not appear in /proc/self/mounts or @@ -283,81 +301,76 @@ TEST(ChrootTest, ProcMountsMountinfoNoEscape) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN))); SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_CHROOT))); - // We are going to create some mounts and then chroot. In order to be able to - // unmount the mounts after the test run, we must chdir to the root and use - // relative paths for all mounts. 
That way, as long as we never chdir into - // the new root, we can access the mounts via relative paths and unmount them. - ASSERT_THAT(chdir("/"), SyscallSucceeds()); - - // Create nested tmpfs mounts. Note the use of relative paths in Mount calls. + // Create nested tmpfs mounts. auto const outer_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - auto const outer_mount = ASSERT_NO_ERRNO_AND_VALUE(Mount( - "none", JoinPath(".", outer_dir.path()), "tmpfs", 0, "mode=0700", 0)); + auto const outer_mount = ASSERT_NO_ERRNO_AND_VALUE( + Mount("none", outer_dir.path(), "tmpfs", 0, "mode=0700", 0)); auto const inner_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(outer_dir.path())); - auto const inner_mount = ASSERT_NO_ERRNO_AND_VALUE(Mount( - "none", JoinPath(".", inner_dir.path()), "tmpfs", 0, "mode=0700", 0)); - - // Filenames that will be checked for mounts, all relative to /proc dir. - std::string paths[3] = {"mounts", "self/mounts", "self/mountinfo"}; - - for (const std::string& path : paths) { - // We should have both inner and outer mounts. - const std::string contents = - ASSERT_NO_ERRNO_AND_VALUE(GetContents(JoinPath("/proc", path))); - EXPECT_THAT(contents, AllOf(HasSubstr(outer_dir.path()), - HasSubstr(inner_dir.path()))); - // We better have at least two mounts: the mounts we created plus the root. - std::vector<absl::string_view> submounts = - absl::StrSplit(contents, '\n', absl::SkipWhitespace()); - EXPECT_GT(submounts.size(), 2); - } - - // Get a FD to /proc before we enter the chroot. - const FileDescriptor proc = - ASSERT_NO_ERRNO_AND_VALUE(Open("/proc", O_RDONLY)); - - // Chroot to outer mount. - ASSERT_THAT(chroot(outer_dir.path().c_str()), SyscallSucceeds()); - - for (const std::string& path : paths) { - const FileDescriptor proc_file = - ASSERT_NO_ERRNO_AND_VALUE(OpenAt(proc.get(), path, O_RDONLY)); - - // Only two mounts visible from this chroot: the inner and outer. Both - // paths should be relative to the new chroot. 
- const std::string contents = - ASSERT_NO_ERRNO_AND_VALUE(GetContentsFD(proc_file.get())); - EXPECT_THAT(contents, - AllOf(HasSubstr(absl::StrCat(Basename(inner_dir.path()))), - Not(HasSubstr(outer_dir.path())), - Not(HasSubstr(inner_dir.path())))); - std::vector<absl::string_view> submounts = - absl::StrSplit(contents, '\n', absl::SkipWhitespace()); - EXPECT_EQ(submounts.size(), 2); - } - - // Chroot to inner mount. We must use an absolute path accessible to our - // chroot. - const std::string inner_dir_basename = - absl::StrCat("/", Basename(inner_dir.path())); - ASSERT_THAT(chroot(inner_dir_basename.c_str()), SyscallSucceeds()); - - for (const std::string& path : paths) { - const FileDescriptor proc_file = - ASSERT_NO_ERRNO_AND_VALUE(OpenAt(proc.get(), path, O_RDONLY)); - const std::string contents = - ASSERT_NO_ERRNO_AND_VALUE(GetContentsFD(proc_file.get())); - - // Only the inner mount visible from this chroot. - std::vector<absl::string_view> submounts = - absl::StrSplit(contents, '\n', absl::SkipWhitespace()); - EXPECT_EQ(submounts.size(), 1); - } - - // Chroot back to ".". - ASSERT_THAT(chroot("."), SyscallSucceeds()); + auto const inner_mount = ASSERT_NO_ERRNO_AND_VALUE( + Mount("none", inner_dir.path(), "tmpfs", 0, "mode=0700", 0)); + + const auto rest = [&outer_dir, &inner_dir] { + // Filenames that will be checked for mounts, all relative to /proc dir. + std::string paths[3] = {"mounts", "self/mounts", "self/mountinfo"}; + + for (const std::string& path : paths) { + // We should have both inner and outer mounts. + const std::string contents = + TEST_CHECK_NO_ERRNO_AND_VALUE(GetContents(JoinPath("/proc", path))); + EXPECT_THAT(contents, AllOf(HasSubstr(outer_dir.path()), + HasSubstr(inner_dir.path()))); + // We better have at least two mounts: the mounts we created plus the + // root. 
+ std::vector<absl::string_view> submounts = + absl::StrSplit(contents, '\n', absl::SkipWhitespace()); + TEST_CHECK(submounts.size() > 2); + } + + // Get a FD to /proc before we enter the chroot. + const FileDescriptor proc = + TEST_CHECK_NO_ERRNO_AND_VALUE(Open("/proc", O_RDONLY)); + + // Chroot to outer mount. + TEST_CHECK_SUCCESS(chroot(outer_dir.path().c_str())); + + for (const std::string& path : paths) { + const FileDescriptor proc_file = + TEST_CHECK_NO_ERRNO_AND_VALUE(OpenAt(proc.get(), path, O_RDONLY)); + + // Only two mounts visible from this chroot: the inner and outer. Both + // paths should be relative to the new chroot. + const std::string contents = + TEST_CHECK_NO_ERRNO_AND_VALUE(GetContentsFD(proc_file.get())); + EXPECT_THAT(contents, + AllOf(HasSubstr(absl::StrCat(Basename(inner_dir.path()))), + Not(HasSubstr(outer_dir.path())), + Not(HasSubstr(inner_dir.path())))); + std::vector<absl::string_view> submounts = + absl::StrSplit(contents, '\n', absl::SkipWhitespace()); + TEST_CHECK(submounts.size() == 2); + } + + // Chroot to inner mount. We must use an absolute path accessible to our + // chroot. + const std::string inner_dir_basename = + absl::StrCat("/", Basename(inner_dir.path())); + TEST_CHECK_SUCCESS(chroot(inner_dir_basename.c_str())); + + for (const std::string& path : paths) { + const FileDescriptor proc_file = + TEST_CHECK_NO_ERRNO_AND_VALUE(OpenAt(proc.get(), path, O_RDONLY)); + const std::string contents = + TEST_CHECK_NO_ERRNO_AND_VALUE(GetContentsFD(proc_file.get())); + + // Only the inner mount visible from this chroot. 
+ std::vector<absl::string_view> submounts = + absl::StrSplit(contents, '\n', absl::SkipWhitespace()); + TEST_CHECK(submounts.size() == 1); + } + }; + EXPECT_THAT(InForkedProcess(rest), IsPosixErrorOkAndHolds(0)); } } // namespace diff --git a/test/syscalls/linux/dup.cc b/test/syscalls/linux/dup.cc index 4f773bc75..ba4e13fb9 100644 --- a/test/syscalls/linux/dup.cc +++ b/test/syscalls/linux/dup.cc @@ -18,6 +18,7 @@ #include "gtest/gtest.h" #include "test/util/eventfd_util.h" #include "test/util/file_descriptor.h" +#include "test/util/fs_util.h" #include "test/util/posix_error.h" #include "test/util/temp_path.h" #include "test/util/test_util.h" @@ -44,14 +45,6 @@ PosixErrorOr<FileDescriptor> Dup3(const FileDescriptor& fd, int target_fd, return FileDescriptor(new_fd); } -void CheckSameFile(const FileDescriptor& fd1, const FileDescriptor& fd2) { - struct stat stat_result1, stat_result2; - ASSERT_THAT(fstat(fd1.get(), &stat_result1), SyscallSucceeds()); - ASSERT_THAT(fstat(fd2.get(), &stat_result2), SyscallSucceeds()); - EXPECT_EQ(stat_result1.st_dev, stat_result2.st_dev); - EXPECT_EQ(stat_result1.st_ino, stat_result2.st_ino); -} - TEST(DupTest, Dup) { auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(f.path(), O_RDONLY)); @@ -59,7 +52,7 @@ TEST(DupTest, Dup) { // Dup the descriptor and make sure it's the same file. FileDescriptor nfd = ASSERT_NO_ERRNO_AND_VALUE(fd.Dup()); ASSERT_NE(fd.get(), nfd.get()); - CheckSameFile(fd, nfd); + ASSERT_NO_ERRNO(CheckSameFile(fd, nfd)); } TEST(DupTest, DupClearsCloExec) { @@ -70,10 +63,24 @@ TEST(DupTest, DupClearsCloExec) { // Duplicate the descriptor. Ensure that it doesn't have FD_CLOEXEC set. 
FileDescriptor nfd = ASSERT_NO_ERRNO_AND_VALUE(fd.Dup()); ASSERT_NE(fd.get(), nfd.get()); - CheckSameFile(fd, nfd); + ASSERT_NO_ERRNO(CheckSameFile(fd, nfd)); EXPECT_THAT(fcntl(nfd.get(), F_GETFD), SyscallSucceedsWithValue(0)); } +TEST(DupTest, DupWithOpath) { + SKIP_IF(IsRunningWithVFS1()); + auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(f.path(), O_PATH)); + int flags; + ASSERT_THAT(flags = fcntl(fd.get(), F_GETFL), SyscallSucceeds()); + + // Dup the descriptor and make sure it's the same file. + FileDescriptor nfd = ASSERT_NO_ERRNO_AND_VALUE(fd.Dup()); + ASSERT_NE(fd.get(), nfd.get()); + ASSERT_NO_ERRNO(CheckSameFile(fd, nfd)); + EXPECT_THAT(fcntl(nfd.get(), F_GETFL), SyscallSucceedsWithValue(flags)); +} + TEST(DupTest, Dup2) { auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(f.path(), O_RDONLY)); @@ -82,13 +89,13 @@ TEST(DupTest, Dup2) { FileDescriptor nfd = ASSERT_NO_ERRNO_AND_VALUE(fd.Dup()); ASSERT_NE(fd.get(), nfd.get()); - CheckSameFile(fd, nfd); + ASSERT_NO_ERRNO(CheckSameFile(fd, nfd)); // Dup over the file above. int target_fd = nfd.release(); FileDescriptor nfd2 = ASSERT_NO_ERRNO_AND_VALUE(Dup2(fd, target_fd)); EXPECT_EQ(target_fd, nfd2.get()); - CheckSameFile(fd, nfd2); + ASSERT_NO_ERRNO(CheckSameFile(fd, nfd2)); } TEST(DupTest, Dup2SameFD) { @@ -99,6 +106,28 @@ TEST(DupTest, Dup2SameFD) { ASSERT_THAT(dup2(fd.get(), fd.get()), SyscallSucceedsWithValue(fd.get())); } +TEST(DupTest, Dup2WithOpath) { + SKIP_IF(IsRunningWithVFS1()); + auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(f.path(), O_PATH)); + int flags; + ASSERT_THAT(flags = fcntl(fd.get(), F_GETFL), SyscallSucceeds()); + + // Regular dup once. 
+ FileDescriptor nfd = ASSERT_NO_ERRNO_AND_VALUE(fd.Dup()); + + ASSERT_NE(fd.get(), nfd.get()); + ASSERT_NO_ERRNO(CheckSameFile(fd, nfd)); + EXPECT_THAT(fcntl(nfd.get(), F_GETFL), SyscallSucceedsWithValue(flags)); + + // Dup over the file above. + int target_fd = nfd.release(); + FileDescriptor nfd2 = ASSERT_NO_ERRNO_AND_VALUE(Dup2(fd, target_fd)); + EXPECT_EQ(target_fd, nfd2.get()); + ASSERT_NO_ERRNO(CheckSameFile(fd, nfd2)); + EXPECT_THAT(fcntl(nfd2.get(), F_GETFL), SyscallSucceedsWithValue(flags)); +} + TEST(DupTest, Dup3) { auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(f.path(), O_RDONLY)); @@ -106,16 +135,16 @@ TEST(DupTest, Dup3) { // Regular dup once. FileDescriptor nfd = ASSERT_NO_ERRNO_AND_VALUE(fd.Dup()); ASSERT_NE(fd.get(), nfd.get()); - CheckSameFile(fd, nfd); + ASSERT_NO_ERRNO(CheckSameFile(fd, nfd)); // Dup over the file above, check that it has no CLOEXEC. nfd = ASSERT_NO_ERRNO_AND_VALUE(Dup3(fd, nfd.release(), 0)); - CheckSameFile(fd, nfd); + ASSERT_NO_ERRNO(CheckSameFile(fd, nfd)); EXPECT_THAT(fcntl(nfd.get(), F_GETFD), SyscallSucceedsWithValue(0)); // Dup over the file again, check that it does not CLOEXEC. nfd = ASSERT_NO_ERRNO_AND_VALUE(Dup3(fd, nfd.release(), O_CLOEXEC)); - CheckSameFile(fd, nfd); + ASSERT_NO_ERRNO(CheckSameFile(fd, nfd)); EXPECT_THAT(fcntl(nfd.get(), F_GETFD), SyscallSucceedsWithValue(FD_CLOEXEC)); } @@ -127,6 +156,32 @@ TEST(DupTest, Dup3FailsSameFD) { ASSERT_THAT(dup3(fd.get(), fd.get(), 0), SyscallFailsWithErrno(EINVAL)); } +TEST(DupTest, Dup3WithOpath) { + SKIP_IF(IsRunningWithVFS1()); + auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(f.path(), O_PATH)); + EXPECT_THAT(fcntl(fd.get(), F_GETFD), SyscallSucceedsWithValue(0)); + int flags; + ASSERT_THAT(flags = fcntl(fd.get(), F_GETFL), SyscallSucceeds()); + + // Regular dup once. 
+ FileDescriptor nfd = ASSERT_NO_ERRNO_AND_VALUE(fd.Dup()); + ASSERT_NE(fd.get(), nfd.get()); + ASSERT_NO_ERRNO(CheckSameFile(fd, nfd)); + + // Dup over the file above, check that it has no CLOEXEC. + nfd = ASSERT_NO_ERRNO_AND_VALUE(Dup3(fd, nfd.release(), 0)); + ASSERT_NO_ERRNO(CheckSameFile(fd, nfd)); + EXPECT_THAT(fcntl(nfd.get(), F_GETFD), SyscallSucceedsWithValue(0)); + EXPECT_THAT(fcntl(nfd.get(), F_GETFL), SyscallSucceedsWithValue(flags)); + + // Dup over the file again, check that it does not CLOEXEC. + nfd = ASSERT_NO_ERRNO_AND_VALUE(Dup3(fd, nfd.release(), O_CLOEXEC)); + ASSERT_NO_ERRNO(CheckSameFile(fd, nfd)); + EXPECT_THAT(fcntl(nfd.get(), F_GETFD), SyscallSucceedsWithValue(FD_CLOEXEC)); + EXPECT_THAT(fcntl(nfd.get(), F_GETFL), SyscallSucceedsWithValue(flags)); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/fadvise64.cc b/test/syscalls/linux/fadvise64.cc index 2af7aa6d9..ac24c4066 100644 --- a/test/syscalls/linux/fadvise64.cc +++ b/test/syscalls/linux/fadvise64.cc @@ -45,6 +45,17 @@ TEST(FAdvise64Test, Basic) { SyscallSucceeds()); } +TEST(FAdvise64Test, FAdvise64WithOpath) { + SKIP_IF(IsRunningWithVFS1()); + auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + const auto fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_PATH)); + + ASSERT_THAT(syscall(__NR_fadvise64, fd.get(), 0, 10, POSIX_FADV_NORMAL), + SyscallFailsWithErrno(EBADF)); + ASSERT_THAT(syscall(__NR_fadvise64, fd.get(), 0, 10, POSIX_FADV_NORMAL), + SyscallFailsWithErrno(EBADF)); +} + TEST(FAdvise64Test, InvalidArgs) { auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); const auto fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDONLY)); diff --git a/test/syscalls/linux/fallocate.cc b/test/syscalls/linux/fallocate.cc index edd23e063..5c839447e 100644 --- a/test/syscalls/linux/fallocate.cc +++ b/test/syscalls/linux/fallocate.cc @@ -108,6 +108,13 @@ TEST_F(AllocateTest, FallocateReadonly) { EXPECT_THAT(fallocate(fd.get(), 0, 0, 10), 
SyscallFailsWithErrno(EBADF)); } +TEST_F(AllocateTest, FallocateWithOpath) { + SKIP_IF(IsRunningWithVFS1()); + auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_PATH)); + EXPECT_THAT(fallocate(fd.get(), 0, 0, 10), SyscallFailsWithErrno(EBADF)); +} + TEST_F(AllocateTest, FallocatePipe) { int pipes[2]; EXPECT_THAT(pipe(pipes), SyscallSucceeds()); diff --git a/test/syscalls/linux/fchdir.cc b/test/syscalls/linux/fchdir.cc index 08bcae1e8..c6675802d 100644 --- a/test/syscalls/linux/fchdir.cc +++ b/test/syscalls/linux/fchdir.cc @@ -71,6 +71,18 @@ TEST(FchdirTest, NotDir) { EXPECT_THAT(close(fd), SyscallSucceeds()); } +TEST(FchdirTest, FchdirWithOpath) { + SKIP_IF(IsRunningWithVFS1()); + auto temp_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(temp_dir.path(), O_PATH)); + ASSERT_THAT(open(temp_dir.path().c_str(), O_DIRECTORY | O_PATH), + SyscallSucceeds()); + + EXPECT_THAT(fchdir(fd.get()), SyscallSucceeds()); + // Change CWD to a permanent location as temp dirs will be cleaned up. + EXPECT_THAT(chdir("/"), SyscallSucceeds()); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/fcntl.cc b/test/syscalls/linux/fcntl.cc index 75a5c9f17..4fa6751ff 100644 --- a/test/syscalls/linux/fcntl.cc +++ b/test/syscalls/linux/fcntl.cc @@ -207,6 +207,41 @@ PosixErrorOr<Cleanup> SubprocessLock(std::string const& path, bool for_write, return std::move(cleanup); } +TEST(FcntlTest, FcntlDupWithOpath) { + SKIP_IF(IsRunningWithVFS1()); + auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(f.path(), O_PATH)); + + int new_fd; + // Dup the descriptor and make sure it's the same file. 
+ EXPECT_THAT(new_fd = fcntl(fd.get(), F_DUPFD, 0), SyscallSucceeds()); + + FileDescriptor nfd = FileDescriptor(new_fd); + ASSERT_NE(fd.get(), nfd.get()); + ASSERT_NO_ERRNO(CheckSameFile(fd, nfd)); + EXPECT_THAT(fcntl(nfd.get(), F_GETFL), SyscallSucceedsWithValue(O_PATH)); +} + +TEST(FcntlTest, SetFileStatusFlagWithOpath) { + SKIP_IF(IsRunningWithVFS1()); + TempPath path = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path.path(), O_PATH)); + + EXPECT_THAT(fcntl(fd.get(), F_SETFL, 0), SyscallFailsWithErrno(EBADF)); +} + +TEST(FcntlTest, BadFcntlsWithOpath) { + SKIP_IF(IsRunningWithVFS1()); + TempPath path = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path.path(), O_PATH)); + + EXPECT_THAT(fcntl(fd.get(), F_SETOWN, 0), SyscallFailsWithErrno(EBADF)); + EXPECT_THAT(fcntl(fd.get(), F_GETOWN, 0), SyscallFailsWithErrno(EBADF)); + + EXPECT_THAT(fcntl(fd.get(), F_SETOWN_EX, 0), SyscallFailsWithErrno(EBADF)); + EXPECT_THAT(fcntl(fd.get(), F_GETOWN_EX, 0), SyscallFailsWithErrno(EBADF)); +} + TEST(FcntlTest, SetCloExecBadFD) { // Open an eventfd file descriptor with FD_CLOEXEC descriptor flag not set. FileDescriptor f = ASSERT_NO_ERRNO_AND_VALUE(NewEventFD(0, 0)); @@ -226,6 +261,32 @@ TEST(FcntlTest, SetCloExec) { ASSERT_THAT(fcntl(fd.get(), F_GETFD), SyscallSucceedsWithValue(FD_CLOEXEC)); } +TEST(FcntlTest, SetCloExecWithOpath) { + SKIP_IF(IsRunningWithVFS1()); + // Open a file descriptor with FD_CLOEXEC descriptor flag not set. + TempPath path = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path.path(), O_PATH)); + ASSERT_THAT(fcntl(fd.get(), F_GETFD), SyscallSucceedsWithValue(0)); + + // Set the FD_CLOEXEC flag. 
+ ASSERT_THAT(fcntl(fd.get(), F_SETFD, FD_CLOEXEC), SyscallSucceeds()); + ASSERT_THAT(fcntl(fd.get(), F_GETFD), SyscallSucceedsWithValue(FD_CLOEXEC)); +} + +TEST(FcntlTest, DupFDCloExecWithOpath) { + SKIP_IF(IsRunningWithVFS1()); + // Open a file descriptor with FD_CLOEXEC descriptor flag not set. + TempPath path = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path.path(), O_PATH)); + int nfd; + ASSERT_THAT(nfd = fcntl(fd.get(), F_DUPFD_CLOEXEC, 0), SyscallSucceeds()); + FileDescriptor dup_fd(nfd); + + // Check for the FD_CLOEXEC flag. + ASSERT_THAT(fcntl(dup_fd.get(), F_GETFD), + SyscallSucceedsWithValue(FD_CLOEXEC)); +} + TEST(FcntlTest, ClearCloExec) { // Open an eventfd file descriptor with FD_CLOEXEC descriptor flag set. FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NewEventFD(0, EFD_CLOEXEC)); @@ -267,6 +328,22 @@ TEST(FcntlTest, GetAllFlags) { EXPECT_EQ(rflags, expected); } +// When O_PATH is specified in flags, flag bits other than O_CLOEXEC, +// O_DIRECTORY, and O_NOFOLLOW are ignored. 
+TEST(FcntlTest, GetOpathFlag) { + SKIP_IF(IsRunningWithVFS1()); + TempPath path = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + int flags = O_RDWR | O_DIRECT | O_SYNC | O_NONBLOCK | O_APPEND | O_PATH | + O_NOFOLLOW | O_DIRECTORY; + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path.path(), flags)); + + int expected = O_PATH | O_NOFOLLOW | O_DIRECTORY; + + int rflags; + EXPECT_THAT(rflags = fcntl(fd.get(), F_GETFL), SyscallSucceeds()); + EXPECT_EQ(rflags, expected); +} + TEST(FcntlTest, SetFlags) { TempPath path = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path.path(), 0)); @@ -395,6 +472,22 @@ TEST_F(FcntlLockTest, SetLockBadOpenFlagsRead) { EXPECT_THAT(fcntl(fd.get(), F_SETLK, &fl1), SyscallFailsWithErrno(EBADF)); } +TEST_F(FcntlLockTest, SetLockWithOpath) { + SKIP_IF(IsRunningWithVFS1()); + auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_PATH)); + + struct flock fl0; + fl0.l_type = F_WRLCK; + fl0.l_whence = SEEK_SET; + fl0.l_start = 0; + fl0.l_len = 0; // Lock all file + + // Expect that setting a write lock using a Opath file descriptor + // won't work. + EXPECT_THAT(fcntl(fd.get(), F_SETLK, &fl0), SyscallFailsWithErrno(EBADF)); +} + TEST_F(FcntlLockTest, SetLockUnlockOnNothing) { auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); FileDescriptor fd = diff --git a/test/syscalls/linux/getdents.cc b/test/syscalls/linux/getdents.cc index 93c692dd6..2f2b14037 100644 --- a/test/syscalls/linux/getdents.cc +++ b/test/syscalls/linux/getdents.cc @@ -429,6 +429,32 @@ TYPED_TEST(GetdentsTest, NotDir) { SyscallFailsWithErrno(ENOTDIR)); } +// Test that getdents returns EBADF when called on an opath file. 
+TYPED_TEST(GetdentsTest, OpathFile) { + SKIP_IF(IsRunningWithVFS1()); + + auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + auto fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_PATH)); + + typename TestFixture::DirentBufferType dirents(256); + EXPECT_THAT(RetryEINTR(syscall)(this->SyscallNum(), fd.get(), dirents.Data(), + dirents.Size()), + SyscallFailsWithErrno(EBADF)); +} + +// Test that getdents returns EBADF when called on an opath directory. +TYPED_TEST(GetdentsTest, OpathDirectory) { + SKIP_IF(IsRunningWithVFS1()); + + auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + auto fd = ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path(), O_PATH | O_DIRECTORY)); + + typename TestFixture::DirentBufferType dirents(256); + ASSERT_THAT(RetryEINTR(syscall)(this->SyscallNum(), fd.get(), dirents.Data(), + dirents.Size()), + SyscallFailsWithErrno(EBADF)); +} + // Test that SEEK_SET to 0 causes getdents to re-read the entries. TYPED_TEST(GetdentsTest, SeekResetsCursor) { // . and .. should be in an otherwise empty directory. diff --git a/test/syscalls/linux/getrusage.cc b/test/syscalls/linux/getrusage.cc index 0e51d42a8..e84cbfdc3 100644 --- a/test/syscalls/linux/getrusage.cc +++ b/test/syscalls/linux/getrusage.cc @@ -23,6 +23,7 @@ #include "absl/time/time.h" #include "test/util/logging.h" #include "test/util/memory_util.h" +#include "test/util/multiprocess_util.h" #include "test/util/signal_util.h" #include "test/util/test_util.h" @@ -93,59 +94,66 @@ TEST(GetrusageTest, Grandchild) { // Verifies that processes ignoring SIGCHLD do not have updated child maxrss // updated. 
TEST(GetrusageTest, IgnoreSIGCHLD) { - struct sigaction sa; - sa.sa_handler = SIG_IGN; - sa.sa_flags = 0; - auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGCHLD, sa)); - pid_t pid = fork(); - if (pid == 0) { + const auto rest = [] { + struct sigaction sa; + sa.sa_handler = SIG_IGN; + sa.sa_flags = 0; + auto cleanup = TEST_CHECK_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGCHLD, sa)); + pid_t pid = fork(); + if (pid == 0) { + struct rusage rusage_self; + TEST_PCHECK(getrusage(RUSAGE_SELF, &rusage_self) == 0); + // The child has consumed some memory. + TEST_CHECK(rusage_self.ru_maxrss != 0); + _exit(0); + } + TEST_CHECK_SUCCESS(pid); + int status; + TEST_CHECK_ERRNO(RetryEINTR(waitpid)(pid, &status, 0), ECHILD); struct rusage rusage_self; - TEST_PCHECK(getrusage(RUSAGE_SELF, &rusage_self) == 0); - // The child has consumed some memory. - TEST_CHECK(rusage_self.ru_maxrss != 0); - _exit(0); - } - ASSERT_THAT(pid, SyscallSucceeds()); - int status; - ASSERT_THAT(RetryEINTR(waitpid)(pid, &status, 0), - SyscallFailsWithErrno(ECHILD)); - struct rusage rusage_self; - ASSERT_THAT(getrusage(RUSAGE_SELF, &rusage_self), SyscallSucceeds()); - struct rusage rusage_children; - ASSERT_THAT(getrusage(RUSAGE_CHILDREN, &rusage_children), SyscallSucceeds()); - // The parent has consumed some memory. - EXPECT_GT(rusage_self.ru_maxrss, 0); - // The child's maxrss should not have propagated up. - EXPECT_EQ(rusage_children.ru_maxrss, 0); + TEST_CHECK_SUCCESS(getrusage(RUSAGE_SELF, &rusage_self)); + struct rusage rusage_children; + TEST_CHECK_SUCCESS(getrusage(RUSAGE_CHILDREN, &rusage_children)); + // The parent has consumed some memory. + TEST_CHECK(rusage_self.ru_maxrss > 0); + // The child's maxrss should not have propagated up. + TEST_CHECK(rusage_children.ru_maxrss == 0); + }; + // Execute inside a forked process so that rusage_children is clean. + EXPECT_THAT(InForkedProcess(rest), IsPosixErrorOkAndHolds(0)); } // Verifies that zombie processes do not update their parent's maxrss. 
Only // reaped processes should do this. TEST(GetrusageTest, IgnoreZombie) { - pid_t pid = fork(); - if (pid == 0) { + const auto rest = [] { + pid_t pid = fork(); + if (pid == 0) { + struct rusage rusage_self; + TEST_PCHECK(getrusage(RUSAGE_SELF, &rusage_self) == 0); + struct rusage rusage_children; + TEST_PCHECK(getrusage(RUSAGE_CHILDREN, &rusage_children) == 0); + // The child has consumed some memory. + TEST_CHECK(rusage_self.ru_maxrss != 0); + // The child has no children of its own. + TEST_CHECK(rusage_children.ru_maxrss == 0); + _exit(0); + } + TEST_CHECK_SUCCESS(pid); + // Give the child time to exit. Because we don't call wait, the child should + // remain a zombie. + absl::SleepFor(absl::Seconds(5)); struct rusage rusage_self; - TEST_PCHECK(getrusage(RUSAGE_SELF, &rusage_self) == 0); + TEST_CHECK_SUCCESS(getrusage(RUSAGE_SELF, &rusage_self)); struct rusage rusage_children; - TEST_PCHECK(getrusage(RUSAGE_CHILDREN, &rusage_children) == 0); - // The child has consumed some memory. - TEST_CHECK(rusage_self.ru_maxrss != 0); - // The child has no children of its own. + TEST_CHECK_SUCCESS(getrusage(RUSAGE_CHILDREN, &rusage_children)); + // The parent has consumed some memory. + TEST_CHECK(rusage_self.ru_maxrss > 0); + // The child has consumed some memory, but hasn't been reaped. TEST_CHECK(rusage_children.ru_maxrss == 0); - _exit(0); - } - ASSERT_THAT(pid, SyscallSucceeds()); - // Give the child time to exit. Because we don't call wait, the child should - // remain a zombie. - absl::SleepFor(absl::Seconds(5)); - struct rusage rusage_self; - ASSERT_THAT(getrusage(RUSAGE_SELF, &rusage_self), SyscallSucceeds()); - struct rusage rusage_children; - ASSERT_THAT(getrusage(RUSAGE_CHILDREN, &rusage_children), SyscallSucceeds()); - // The parent has consumed some memory. - EXPECT_GT(rusage_self.ru_maxrss, 0); - // The child has consumed some memory, but hasn't been reaped. 
- EXPECT_EQ(rusage_children.ru_maxrss, 0); + }; + // Execute inside a forked process so that rusage_children is clean. + EXPECT_THAT(InForkedProcess(rest), IsPosixErrorOkAndHolds(0)); } TEST(GetrusageTest, Wait4) { diff --git a/test/syscalls/linux/ioctl.cc b/test/syscalls/linux/ioctl.cc index b0a07a064..9b16d1558 100644 --- a/test/syscalls/linux/ioctl.cc +++ b/test/syscalls/linux/ioctl.cc @@ -76,6 +76,19 @@ TEST_F(IoctlTest, InvalidControlNumber) { EXPECT_THAT(ioctl(STDOUT_FILENO, 0), SyscallFailsWithErrno(ENOTTY)); } +TEST_F(IoctlTest, IoctlWithOpath) { + SKIP_IF(IsRunningWithVFS1()); + const FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/null", O_PATH)); + + int set = 1; + EXPECT_THAT(ioctl(fd.get(), FIONBIO, &set), SyscallFailsWithErrno(EBADF)); + + EXPECT_THAT(ioctl(fd.get(), FIONCLEX), SyscallFailsWithErrno(EBADF)); + + EXPECT_THAT(ioctl(fd.get(), FIOCLEX), SyscallFailsWithErrno(EBADF)); +} + TEST_F(IoctlTest, FIONBIOSucceeds) { EXPECT_FALSE(CheckNonBlocking(fd())); int set = 1; diff --git a/test/syscalls/linux/link.cc b/test/syscalls/linux/link.cc index 544681168..4f9ca1a65 100644 --- a/test/syscalls/linux/link.cc +++ b/test/syscalls/linux/link.cc @@ -50,6 +50,8 @@ bool IsSameFile(const std::string& f1, const std::string& f2) { return stat1.st_dev == stat2.st_dev && stat1.st_ino == stat2.st_ino; } +// TODO(b/178640646): Add test for linkat with AT_EMPTY_PATH + TEST(LinkTest, CanCreateLinkFile) { auto oldfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); const std::string newname = NewTempAbsPath(); @@ -235,6 +237,59 @@ TEST(LinkTest, AbsPathsWithNonDirFDs) { SyscallSucceeds()); } +TEST(LinkTest, NewDirFDWithOpath) { + SKIP_IF(IsRunningWithVFS1()); + auto oldfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + const std::string newname_parent = NewTempAbsPath(); + const std::string newname_base = "child"; + const std::string newname = JoinPath(newname_parent, newname_base); + + // Create newname_parent directory, and get an FD. 
+ EXPECT_THAT(mkdir(newname_parent.c_str(), 0777), SyscallSucceeds()); + const FileDescriptor newname_parent_fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(newname_parent, O_DIRECTORY | O_PATH)); + + // Link newname to oldfile, using newname_parent_fd. + EXPECT_THAT(linkat(AT_FDCWD, oldfile.path().c_str(), newname_parent_fd.get(), + newname.c_str(), 0), + SyscallSucceeds()); + + EXPECT_TRUE(IsSameFile(oldfile.path(), newname)); +} + +TEST(LinkTest, RelPathsNonDirFDsWithOpath) { + SKIP_IF(IsRunningWithVFS1()); + auto oldfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + + // Create a file that will be passed as the directory fd for old/new names. + TempPath path = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + FileDescriptor file_fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path.path(), O_PATH)); + + // Using file_fd as olddirfd will fail. + EXPECT_THAT(linkat(file_fd.get(), "foo", AT_FDCWD, "bar", 0), + SyscallFailsWithErrno(ENOTDIR)); + + // Using file_fd as newdirfd will fail. + EXPECT_THAT(linkat(AT_FDCWD, oldfile.path().c_str(), file_fd.get(), "bar", 0), + SyscallFailsWithErrno(ENOTDIR)); +} + +TEST(LinkTest, AbsPathsNonDirFDsWithOpath) { + SKIP_IF(IsRunningWithVFS1()); + + auto oldfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + const std::string newname = NewTempAbsPath(); + + // Create a file that will be passed as the directory fd for old/new names. + TempPath path = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + FileDescriptor file_fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path.path(), O_PATH)); + + // Using file_fd as the dirfds is OK as long as paths are absolute. + EXPECT_THAT(linkat(file_fd.get(), oldfile.path().c_str(), file_fd.get(), + newname.c_str(), 0), + SyscallSucceeds()); +} + TEST(LinkTest, LinkDoesNotFollowSymlinks) { // Create oldfile, and oldsymlink which points to it. 
auto oldfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); diff --git a/test/syscalls/linux/madvise.cc b/test/syscalls/linux/madvise.cc index 5a1973f60..6e714b12c 100644 --- a/test/syscalls/linux/madvise.cc +++ b/test/syscalls/linux/madvise.cc @@ -179,9 +179,9 @@ TEST(MadviseDontforkTest, DontforkShared) { // First page is mapped in child and modifications are visible to parent // via the shared mapping. TEST_CHECK(IsMapped(ms1.addr())); - ExpectAllMappingBytes(ms1, 2); + CheckAllMappingBytes(ms1, 2); memset(ms1.ptr(), 1, kPageSize); - ExpectAllMappingBytes(ms1, 1); + CheckAllMappingBytes(ms1, 1); // Second page must not be mapped in child. TEST_CHECK(!IsMapped(ms2.addr())); @@ -222,9 +222,9 @@ TEST(MadviseDontforkTest, DontforkAnonPrivate) { // page. The mapping is private so the modifications are not visible to // the parent. TEST_CHECK(IsMapped(mp1.addr())); - ExpectAllMappingBytes(mp1, 1); + CheckAllMappingBytes(mp1, 1); memset(mp1.ptr(), 11, kPageSize); - ExpectAllMappingBytes(mp1, 11); + CheckAllMappingBytes(mp1, 11); // Verify second page is not mapped. TEST_CHECK(!IsMapped(mp2.addr())); @@ -233,9 +233,9 @@ TEST(MadviseDontforkTest, DontforkAnonPrivate) { // page. The mapping is private so the modifications are not visible to // the parent. TEST_CHECK(IsMapped(mp3.addr())); - ExpectAllMappingBytes(mp3, 3); + CheckAllMappingBytes(mp3, 3); memset(mp3.ptr(), 13, kPageSize); - ExpectAllMappingBytes(mp3, 13); + CheckAllMappingBytes(mp3, 13); }; EXPECT_THAT(InForkedProcess(rest), IsPosixErrorOkAndHolds(0)); diff --git a/test/syscalls/linux/mmap.cc b/test/syscalls/linux/mmap.cc index 83546830d..93a6d9cde 100644 --- a/test/syscalls/linux/mmap.cc +++ b/test/syscalls/linux/mmap.cc @@ -930,6 +930,18 @@ TEST_F(MMapFileTest, WriteSharedOnReadOnlyFd) { SyscallFailsWithErrno(EACCES)); } +// Mmap not allowed on O_PATH FDs. 
+TEST_F(MMapFileTest, MmapFileWithOpath) { + SKIP_IF(IsRunningWithVFS1()); + const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + const FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_PATH)); + + uintptr_t addr; + EXPECT_THAT(addr = Map(0, kPageSize, PROT_READ, MAP_PRIVATE, fd.get(), 0), + SyscallFailsWithErrno(EBADF)); +} + // The FD must be readable. TEST_P(MMapFileParamTest, WriteOnlyFd) { const FileDescriptor fd = diff --git a/test/syscalls/linux/open.cc b/test/syscalls/linux/open.cc index fcd162ca2..e65ffee8f 100644 --- a/test/syscalls/linux/open.cc +++ b/test/syscalls/linux/open.cc @@ -45,7 +45,7 @@ namespace { // * O_CREAT // * O_DIRECTORY // * O_NOFOLLOW -// * O_PATH <- Will we ever support this? +// * O_PATH // // Special operations on open: // * O_EXCL @@ -75,55 +75,52 @@ class OpenTest : public FileTest { }; TEST_F(OpenTest, OTrunc) { - auto dirpath = JoinPath(GetAbsoluteTestTmpdir(), "truncd"); - ASSERT_THAT(mkdir(dirpath.c_str(), 0777), SyscallSucceeds()); - ASSERT_THAT(open(dirpath.c_str(), O_TRUNC, 0666), + auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + ASSERT_THAT(open(dir.path().c_str(), O_TRUNC, 0666), SyscallFailsWithErrno(EISDIR)); } TEST_F(OpenTest, OTruncAndReadOnlyDir) { - auto dirpath = JoinPath(GetAbsoluteTestTmpdir(), "truncd"); - ASSERT_THAT(mkdir(dirpath.c_str(), 0777), SyscallSucceeds()); - ASSERT_THAT(open(dirpath.c_str(), O_TRUNC | O_RDONLY, 0666), + auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + ASSERT_THAT(open(dir.path().c_str(), O_TRUNC | O_RDONLY, 0666), SyscallFailsWithErrno(EISDIR)); } TEST_F(OpenTest, OTruncAndReadOnlyFile) { - auto dirpath = JoinPath(GetAbsoluteTestTmpdir(), "truncfile"); - const FileDescriptor existing = - ASSERT_NO_ERRNO_AND_VALUE(Open(dirpath.c_str(), O_RDWR | O_CREAT, 0666)); - const FileDescriptor otrunc = ASSERT_NO_ERRNO_AND_VALUE( - Open(dirpath.c_str(), O_TRUNC | O_RDONLY, 0666)); + auto dir = 
ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + auto path = JoinPath(dir.path(), "foo"); + EXPECT_NO_ERRNO(Open(path, O_RDWR | O_CREAT, 0666)); + EXPECT_NO_ERRNO(Open(path, O_TRUNC | O_RDONLY, 0666)); } TEST_F(OpenTest, OCreateDirectory) { SKIP_IF(IsRunningWithVFS1()); - auto dirpath = GetAbsoluteTestTmpdir(); + auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); // Normal case: existing directory. - ASSERT_THAT(open(dirpath.c_str(), O_RDWR | O_CREAT, 0666), + ASSERT_THAT(open(dir.path().c_str(), O_RDWR | O_CREAT, 0666), SyscallFailsWithErrno(EISDIR)); // Trailing separator on existing directory. - ASSERT_THAT(open(dirpath.append("/").c_str(), O_RDWR | O_CREAT, 0666), + ASSERT_THAT(open(dir.path().append("/").c_str(), O_RDWR | O_CREAT, 0666), SyscallFailsWithErrno(EISDIR)); // Trailing separator on non-existing directory. - ASSERT_THAT(open(JoinPath(dirpath, "non-existent").append("/").c_str(), + ASSERT_THAT(open(JoinPath(dir.path(), "non-existent").append("/").c_str(), O_RDWR | O_CREAT, 0666), SyscallFailsWithErrno(EISDIR)); // "." special case. - ASSERT_THAT(open(JoinPath(dirpath, ".").c_str(), O_RDWR | O_CREAT, 0666), + ASSERT_THAT(open(JoinPath(dir.path(), ".").c_str(), O_RDWR | O_CREAT, 0666), SyscallFailsWithErrno(EISDIR)); } TEST_F(OpenTest, MustCreateExisting) { - auto dirPath = GetAbsoluteTestTmpdir(); + auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); // Existing directory. - ASSERT_THAT(open(dirPath.c_str(), O_RDWR | O_CREAT | O_EXCL, 0666), + ASSERT_THAT(open(dir.path().c_str(), O_RDWR | O_CREAT | O_EXCL, 0666), SyscallFailsWithErrno(EEXIST)); // Existing file. 
- auto newFile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dirPath)); + auto newFile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir.path())); ASSERT_THAT(open(newFile.path().c_str(), O_RDWR | O_CREAT | O_EXCL, 0666), SyscallFailsWithErrno(EEXIST)); } @@ -206,7 +203,8 @@ TEST_F(OpenTest, AtAbsPath) { } TEST_F(OpenTest, OpenNoFollowSymlink) { - const std::string link_path = JoinPath(GetAbsoluteTestTmpdir(), "link"); + auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + const std::string link_path = JoinPath(dir.path().c_str(), "link"); ASSERT_THAT(symlink(test_file_name_.c_str(), link_path.c_str()), SyscallSucceeds()); auto cleanup = Cleanup([link_path]() { @@ -227,8 +225,7 @@ TEST_F(OpenTest, OpenNoFollowStillFollowsLinksInPath) { // // We will then open tmp_folder/sym_folder/file with O_NOFOLLOW and it // should succeed as O_NOFOLLOW only applies to the final path component. - auto tmp_path = - ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(GetAbsoluteTestTmpdir())); + auto tmp_path = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); auto sym_path = ASSERT_NO_ERRNO_AND_VALUE( TempPath::CreateSymlinkTo(GetAbsoluteTestTmpdir(), tmp_path.path())); auto file_path = @@ -246,8 +243,7 @@ TEST_F(OpenTest, OpenNoFollowStillFollowsLinksInPath) { // // open("root/child/symlink/root/child/file") TEST_F(OpenTest, SymlinkRecurse) { - auto root = - ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(GetAbsoluteTestTmpdir())); + auto root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); auto child = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(root.path())); auto symlink = ASSERT_NO_ERRNO_AND_VALUE( TempPath::CreateSymlinkTo(child.path(), "../..")); @@ -481,12 +477,8 @@ TEST_F(OpenTest, CanTruncateWithStrangePermissions) { ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); const DisableSave ds; // Permissions are dropped. std::string path = NewTempAbsPath(); - int fd; // Create a file without user permissions. 
- EXPECT_THAT( // SAVE_BELOW - fd = open(path.c_str(), O_CREAT | O_TRUNC | O_WRONLY, 055), - SyscallSucceeds()); - EXPECT_THAT(close(fd), SyscallSucceeds()); + EXPECT_NO_ERRNO(Open(path, O_CREAT | O_TRUNC | O_WRONLY, 055)); // Cannot open file because we are owner and have no permissions set. EXPECT_THAT(open(path.c_str(), O_RDONLY), SyscallFailsWithErrno(EACCES)); @@ -495,8 +487,7 @@ TEST_F(OpenTest, CanTruncateWithStrangePermissions) { EXPECT_THAT(chmod(path.c_str(), 0755), SyscallSucceeds()); // Now we can open the file again. - EXPECT_THAT(fd = open(path.c_str(), O_RDWR), SyscallSucceeds()); - EXPECT_THAT(close(fd), SyscallSucceeds()); + EXPECT_NO_ERRNO(Open(path, O_RDWR)); } TEST_F(OpenTest, OpenNonDirectoryWithTrailingSlash) { @@ -517,6 +508,26 @@ TEST_F(OpenTest, OpenWithStrangeFlags) { EXPECT_THAT(read(fd.get(), &c, 1), SyscallFailsWithErrno(EBADF)); } +TEST_F(OpenTest, OpenWithOpath) { + SKIP_IF(IsRunningWithVFS1()); + ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); + ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); + const DisableSave ds; // Permissions are dropped. + std::string path = NewTempAbsPath(); + + // Create a file without user permissions. + const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE( + Open(path.c_str(), O_CREAT | O_TRUNC | O_WRONLY, 055)); + + // Cannot open file as read only because we are owner and have no permissions + // set. + EXPECT_THAT(open(path.c_str(), O_RDONLY), SyscallFailsWithErrno(EACCES)); + + // Can open file with O_PATH because don't need permissions on the object when + // opening with O_PATH. 
+ ASSERT_NO_ERRNO(Open(path, O_PATH)); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/open_create.cc b/test/syscalls/linux/open_create.cc index 9d63782fb..f8fbea79e 100644 --- a/test/syscalls/linux/open_create.cc +++ b/test/syscalls/linux/open_create.cc @@ -22,6 +22,7 @@ #include "test/util/capability_util.h" #include "test/util/file_descriptor.h" #include "test/util/fs_util.h" +#include "test/util/posix_error.h" #include "test/util/temp_path.h" #include "test/util/temp_umask.h" #include "test/util/test_util.h" @@ -31,85 +32,60 @@ namespace testing { namespace { TEST(CreateTest, TmpFile) { - int fd; - EXPECT_THAT(fd = open(JoinPath(GetAbsoluteTestTmpdir(), "a").c_str(), - O_RDWR | O_CREAT, 0666), - SyscallSucceeds()); - EXPECT_THAT(close(fd), SyscallSucceeds()); + auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + EXPECT_NO_ERRNO(Open(JoinPath(dir.path(), "a"), O_RDWR | O_CREAT, 0666)); } TEST(CreateTest, ExistingFile) { - int fd; - EXPECT_THAT( - fd = open(JoinPath(GetAbsoluteTestTmpdir(), "ExistingFile").c_str(), - O_RDWR | O_CREAT, 0666), - SyscallSucceeds()); - EXPECT_THAT(close(fd), SyscallSucceeds()); - - EXPECT_THAT( - fd = open(JoinPath(GetAbsoluteTestTmpdir(), "ExistingFile").c_str(), - O_RDWR | O_CREAT, 0666), - SyscallSucceeds()); - EXPECT_THAT(close(fd), SyscallSucceeds()); + auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + auto path = JoinPath(dir.path(), "ExistingFile"); + EXPECT_NO_ERRNO(Open(path, O_RDWR | O_CREAT, 0666)); + EXPECT_NO_ERRNO(Open(path, O_RDWR | O_CREAT, 0666)); } TEST(CreateTest, CreateAtFile) { - int dirfd; - EXPECT_THAT(dirfd = open(GetAbsoluteTestTmpdir().c_str(), O_DIRECTORY, 0666), - SyscallSucceeds()); - EXPECT_THAT(openat(dirfd, "CreateAtFile", O_RDWR | O_CREAT, 0666), + auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + auto dirfd = ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path(), O_DIRECTORY, 0666)); + EXPECT_THAT(openat(dirfd.get(), "CreateAtFile", O_RDWR | 
O_CREAT, 0666), SyscallSucceeds()); - EXPECT_THAT(close(dirfd), SyscallSucceeds()); } TEST(CreateTest, HonorsUmask_NoRandomSave) { const DisableSave ds; // file cannot be re-opened as writable. + auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); TempUmask mask(0222); - int fd; - ASSERT_THAT( - fd = open(JoinPath(GetAbsoluteTestTmpdir(), "UmaskedFile").c_str(), - O_RDWR | O_CREAT, 0666), - SyscallSucceeds()); + auto fd = ASSERT_NO_ERRNO_AND_VALUE( + Open(JoinPath(dir.path(), "UmaskedFile"), O_RDWR | O_CREAT, 0666)); struct stat statbuf; - ASSERT_THAT(fstat(fd, &statbuf), SyscallSucceeds()); + ASSERT_THAT(fstat(fd.get(), &statbuf), SyscallSucceeds()); EXPECT_EQ(0444, statbuf.st_mode & 0777); - EXPECT_THAT(close(fd), SyscallSucceeds()); } TEST(CreateTest, CreateExclusively) { - std::string filename = NewTempAbsPath(); - - int fd; - ASSERT_THAT(fd = open(filename.c_str(), O_CREAT | O_RDWR, 0644), - SyscallSucceeds()); - EXPECT_THAT(close(fd), SyscallSucceeds()); - - EXPECT_THAT(open(filename.c_str(), O_CREAT | O_EXCL | O_RDWR, 0644), + auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + auto path = JoinPath(dir.path(), "foo"); + EXPECT_NO_ERRNO(Open(path, O_CREAT | O_RDWR, 0644)); + EXPECT_THAT(open(path.c_str(), O_CREAT | O_EXCL | O_RDWR, 0644), SyscallFailsWithErrno(EEXIST)); } TEST(CreateTest, CreatWithOTrunc) { - std::string dirpath = JoinPath(GetAbsoluteTestTmpdir(), "truncd"); - ASSERT_THAT(mkdir(dirpath.c_str(), 0777), SyscallSucceeds()); - ASSERT_THAT(open(dirpath.c_str(), O_CREAT | O_TRUNC, 0666), + auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + ASSERT_THAT(open(dir.path().c_str(), O_CREAT | O_TRUNC, 0666), SyscallFailsWithErrno(EISDIR)); } TEST(CreateTest, CreatDirWithOTruncAndReadOnly) { - std::string dirpath = JoinPath(GetAbsoluteTestTmpdir(), "truncd"); - ASSERT_THAT(mkdir(dirpath.c_str(), 0777), SyscallSucceeds()); - ASSERT_THAT(open(dirpath.c_str(), O_CREAT | O_TRUNC | O_RDONLY, 0666), + auto dir = 
ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + ASSERT_THAT(open(dir.path().c_str(), O_CREAT | O_TRUNC | O_RDONLY, 0666), SyscallFailsWithErrno(EISDIR)); } TEST(CreateTest, CreatFileWithOTruncAndReadOnly) { - std::string dirpath = JoinPath(GetAbsoluteTestTmpdir(), "truncfile"); - int dirfd; - ASSERT_THAT(dirfd = open(dirpath.c_str(), O_RDWR | O_CREAT, 0666), - SyscallSucceeds()); - ASSERT_THAT(open(dirpath.c_str(), O_CREAT | O_TRUNC | O_RDONLY, 0666), - SyscallSucceeds()); - ASSERT_THAT(close(dirfd), SyscallSucceeds()); + auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + auto path = JoinPath(dir.path(), "foo"); + ASSERT_NO_ERRNO(Open(path, O_RDWR | O_CREAT, 0666)); + ASSERT_NO_ERRNO(Open(path, O_CREAT | O_TRUNC | O_RDONLY, 0666)); } TEST(CreateTest, CreateFailsOnDirWithoutWritePerms) { diff --git a/test/syscalls/linux/ping_socket.cc b/test/syscalls/linux/ping_socket.cc index a9bfdb37b..999c8ab6b 100644 --- a/test/syscalls/linux/ping_socket.cc +++ b/test/syscalls/linux/ping_socket.cc @@ -31,51 +31,36 @@ namespace gvisor { namespace testing { namespace { -class PingSocket : public ::testing::Test { - protected: - // Creates a socket to be used in tests. - void SetUp() override; - - // Closes the socket created by SetUp(). - void TearDown() override; - - // The loopback address. - struct sockaddr_in addr_; -}; - -void PingSocket::SetUp() { - // On some hosts ping sockets are restricted to specific groups using the - // sysctl "ping_group_range". - int s = socket(AF_INET, SOCK_RAW, IPPROTO_ICMP); - if (s < 0 && errno == EPERM) { - GTEST_SKIP(); - } - close(s); - - addr_ = {}; - // Just a random port as the destination port number is irrelevant for ping - // sockets. - addr_.sin_port = 12345; - addr_.sin_addr.s_addr = htonl(INADDR_LOOPBACK); - addr_.sin_family = AF_INET; -} - -void PingSocket::TearDown() {} - // Test ICMP port exhaustion returns EAGAIN. // // We disable both random/cooperative S/R for this test as it makes way too many // syscalls. 
-TEST_F(PingSocket, ICMPPortExhaustion_NoRandomSave) { +TEST(PingSocket, ICMPPortExhaustion_NoRandomSave) { DisableSave ds; + + { + auto s = Socket(AF_INET, SOCK_DGRAM, IPPROTO_ICMP); + if (!s.ok()) { + ASSERT_EQ(s.error().errno_value(), EACCES); + GTEST_SKIP(); + } + } + + const struct sockaddr_in addr = { + .sin_family = AF_INET, + .sin_addr = + { + .s_addr = htonl(INADDR_LOOPBACK), + }, + }; + std::vector<FileDescriptor> sockets; constexpr int kSockets = 65536; - addr_.sin_port = 0; for (int i = 0; i < kSockets; i++) { auto s = ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, IPPROTO_ICMP)); - int ret = connect(s.get(), reinterpret_cast<struct sockaddr*>(&addr_), - sizeof(addr_)); + int ret = connect(s.get(), reinterpret_cast<const struct sockaddr*>(&addr), + sizeof(addr)); if (ret == 0) { sockets.push_back(std::move(s)); continue; diff --git a/test/syscalls/linux/pread64.cc b/test/syscalls/linux/pread64.cc index bcdbbb044..c74990ba1 100644 --- a/test/syscalls/linux/pread64.cc +++ b/test/syscalls/linux/pread64.cc @@ -77,6 +77,16 @@ TEST_F(Pread64Test, WriteOnlyNotReadable) { EXPECT_THAT(pread64(fd.get(), buf, 1024, 0), SyscallFailsWithErrno(EBADF)); } +TEST_F(Pread64Test, Pread64WithOpath) { + SKIP_IF(IsRunningWithVFS1()); + const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + const FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_PATH)); + + char buf[1024]; + EXPECT_THAT(pread64(fd.get(), buf, 1024, 0), SyscallFailsWithErrno(EBADF)); +} + TEST_F(Pread64Test, DirNotReadable) { const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(GetAbsoluteTestTmpdir(), O_RDONLY)); diff --git a/test/syscalls/linux/preadv.cc b/test/syscalls/linux/preadv.cc index 5b0743fe9..1c40f0915 100644 --- a/test/syscalls/linux/preadv.cc +++ b/test/syscalls/linux/preadv.cc @@ -89,6 +89,20 @@ TEST(PreadvTest, MMConcurrencyStress) { // The test passes if it neither deadlocks nor crashes the OS. } +// This test calls preadv with an O_PATH fd. 
+TEST(PreadvTest, PreadvWithOpath) { + SKIP_IF(IsRunningWithVFS1()); + const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + const FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_PATH)); + + struct iovec iov; + iov.iov_base = nullptr; + iov.iov_len = 0; + + EXPECT_THAT(preadv(fd.get(), &iov, 1, 0), SyscallFailsWithErrno(EBADF)); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/preadv2.cc b/test/syscalls/linux/preadv2.cc index 4a9acd7ae..cb58719c4 100644 --- a/test/syscalls/linux/preadv2.cc +++ b/test/syscalls/linux/preadv2.cc @@ -226,6 +226,24 @@ TEST(Preadv2Test, TestUnreadableFile) { SyscallFailsWithErrno(EBADF)); } +// This test calls preadv2 with a file opened with O_PATH. +TEST(Preadv2Test, Preadv2WithOpath) { + SKIP_IF(IsRunningWithVFS1()); + SKIP_IF(preadv2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS); + + const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + const FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_PATH)); + + auto iov = absl::make_unique<struct iovec[]>(1); + iov[0].iov_base = nullptr; + iov[0].iov_len = 0; + + EXPECT_THAT(preadv2(fd.get(), iov.get(), /*iovcnt=*/1, /*offset=*/0, + /*flags=*/0), + SyscallFailsWithErrno(EBADF)); +} + // Calling preadv2 with a non-negative offset calls preadv. Calling preadv with // an unseekable file is not allowed. A pipe is used for an unseekable file. 
TEST(Preadv2Test, TestUnseekableFileInvalid) { diff --git a/test/syscalls/linux/proc_net_unix.cc b/test/syscalls/linux/proc_net_unix.cc index 662c6feb2..d61d94309 100644 --- a/test/syscalls/linux/proc_net_unix.cc +++ b/test/syscalls/linux/proc_net_unix.cc @@ -18,6 +18,7 @@ #include "absl/strings/str_join.h" #include "absl/strings/str_split.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" +#include "test/util/cleanup.h" #include "test/util/file_descriptor.h" #include "test/util/fs_util.h" #include "test/util/test_util.h" @@ -341,6 +342,8 @@ TEST(ProcNetUnix, StreamSocketStateStateConnectedOnAccept) { int clientfd; ASSERT_THAT(clientfd = accept(sockets->first_fd(), nullptr, nullptr), SyscallSucceeds()); + auto cleanup = Cleanup( + [clientfd]() { ASSERT_THAT(close(clientfd), SyscallSucceeds()); }); // Find the entry for the accepted socket. UDS proc entries don't have a // remote address, so we distinguish the accepted socket from the listen diff --git a/test/syscalls/linux/pty.cc b/test/syscalls/linux/pty.cc index 0b174e2be..85ff258df 100644 --- a/test/syscalls/linux/pty.cc +++ b/test/syscalls/linux/pty.cc @@ -1338,6 +1338,7 @@ TEST_F(JobControlTest, SetTTYDifferentSession) { TEST_PCHECK(waitpid(grandchild, &gcwstatus, 0) == grandchild); TEST_PCHECK(gcwstatus == 0); }); + ASSERT_NO_ERRNO(res); } TEST_F(JobControlTest, ReleaseTTY) { @@ -1515,7 +1516,8 @@ TEST_F(JobControlTest, GetForegroundProcessGroupNonControlling) { // - creates a child process in a new process group // - sets that child as the foreground process group // - kills its child and sets itself as the foreground process group. -TEST_F(JobControlTest, SetForegroundProcessGroup) { +// TODO(gvisor.dev/issue/5357): Fix and enable. 
+TEST_F(JobControlTest, DISABLED_SetForegroundProcessGroup) { auto res = RunInChild([=]() { TEST_PCHECK(!ioctl(replica_.get(), TIOCSCTTY, 0)); @@ -1557,6 +1559,7 @@ TEST_F(JobControlTest, SetForegroundProcessGroup) { TEST_PCHECK(pgid = getpgid(0) == 0); TEST_PCHECK(!tcsetpgrp(replica_.get(), pgid)); }); + ASSERT_NO_ERRNO(res); } TEST_F(JobControlTest, SetForegroundProcessGroupWrongTTY) { @@ -1576,8 +1579,9 @@ TEST_F(JobControlTest, SetForegroundProcessGroupNegPgid) { ASSERT_NO_ERRNO(ret); } -TEST_F(JobControlTest, SetForegroundProcessGroupEmptyProcessGroup) { - auto ret = RunInChild([=]() { +// TODO(gvisor.dev/issue/5357): Fix and enable. +TEST_F(JobControlTest, DISABLED_SetForegroundProcessGroupEmptyProcessGroup) { + auto res = RunInChild([=]() { TEST_PCHECK(!ioctl(replica_.get(), TIOCSCTTY, 0)); // Create a new process, put it in a new process group, make that group the @@ -1595,6 +1599,7 @@ TEST_F(JobControlTest, SetForegroundProcessGroupEmptyProcessGroup) { TEST_PCHECK(ioctl(replica_.get(), TIOCSPGRP, &grandchild) != 0 && errno == ESRCH); }); + ASSERT_NO_ERRNO(res); } TEST_F(JobControlTest, SetForegroundProcessGroupDifferentSession) { diff --git a/test/syscalls/linux/pwrite64.cc b/test/syscalls/linux/pwrite64.cc index e69794910..1b2f25363 100644 --- a/test/syscalls/linux/pwrite64.cc +++ b/test/syscalls/linux/pwrite64.cc @@ -77,6 +77,17 @@ TEST_F(Pwrite64, Overflow) { EXPECT_THAT(close(fd), SyscallSucceeds()); } +TEST_F(Pwrite64, Pwrite64WithOpath) { + SKIP_IF(IsRunningWithVFS1()); + const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + const FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_PATH)); + + std::vector<char> buf(1); + EXPECT_THAT(PwriteFd(fd.get(), buf.data(), 1, 0), + SyscallFailsWithErrno(EBADF)); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/pwritev2.cc b/test/syscalls/linux/pwritev2.cc index 63b686c62..00aed61b4 100644 --- a/test/syscalls/linux/pwritev2.cc +++ 
b/test/syscalls/linux/pwritev2.cc @@ -283,6 +283,23 @@ TEST(Pwritev2Test, ReadOnlyFile) { SyscallFailsWithErrno(EBADF)); } +TEST(Pwritev2Test, Pwritev2WithOpath) { + SKIP_IF(IsRunningWithVFS1()); + SKIP_IF(pwritev2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS); + + const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + const FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_PATH)); + + char buf[16]; + struct iovec iov; + iov.iov_base = buf; + iov.iov_len = sizeof(buf); + + EXPECT_THAT(pwritev2(fd.get(), &iov, /*iovcnt=*/1, /*offset=*/0, /*flags=*/0), + SyscallFailsWithErrno(EBADF)); +} + // This test calls pwritev2 with an invalid flag. TEST(Pwritev2Test, InvalidFlag) { SKIP_IF(pwritev2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS); diff --git a/test/syscalls/linux/read.cc b/test/syscalls/linux/read.cc index 2633ba31b..98d5e432d 100644 --- a/test/syscalls/linux/read.cc +++ b/test/syscalls/linux/read.cc @@ -112,6 +112,15 @@ TEST_F(ReadTest, ReadDirectoryFails) { EXPECT_THAT(ReadFd(file.get(), buf.data(), 1), SyscallFailsWithErrno(EISDIR)); } +TEST_F(ReadTest, ReadWithOpath) { + SKIP_IF(IsRunningWithVFS1()); + const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + const FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_PATH)); + std::vector<char> buf(1); + EXPECT_THAT(ReadFd(fd.get(), buf.data(), 1), SyscallFailsWithErrno(EBADF)); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/readv.cc b/test/syscalls/linux/readv.cc index baaf9f757..86808d255 100644 --- a/test/syscalls/linux/readv.cc +++ b/test/syscalls/linux/readv.cc @@ -251,6 +251,20 @@ TEST_F(ReadvTest, IovecOutsideTaskAddressRangeInNonemptyArray) { SyscallFailsWithErrno(EFAULT)); } +TEST_F(ReadvTest, ReadvWithOpath) { + SKIP_IF(IsRunningWithVFS1()); + char buffer[1024]; + struct iovec iov[1]; + iov[0].iov_base = buffer; + iov[0].iov_len = 1024; + + TempPath tmpfile = 
ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(tmpfile.path().c_str(), O_PATH)); + + ASSERT_THAT(readv(fd.get(), iov, 1), SyscallFailsWithErrno(EBADF)); +} + // This test depends on the maximum extent of a single readv() syscall, so // we can't tolerate interruption from saving. TEST(ReadvTestNoFixture, TruncatedAtMax_NoRandomSave) { diff --git a/test/syscalls/linux/setgid.cc b/test/syscalls/linux/setgid.cc new file mode 100644 index 000000000..bfd91ba4f --- /dev/null +++ b/test/syscalls/linux/setgid.cc @@ -0,0 +1,370 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <limits.h> +#include <sys/types.h> +#include <unistd.h> + +#include "gtest/gtest.h" +#include "test/util/capability_util.h" +#include "test/util/cleanup.h" +#include "test/util/fs_util.h" +#include "test/util/posix_error.h" +#include "test/util/temp_path.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +constexpr int kDirmodeMask = 07777; +constexpr int kDirmodeSgid = S_ISGID | 0777; +constexpr int kDirmodeNoExec = S_ISGID | 0767; +constexpr int kDirmodeNoSgid = 0777; + +// Sets effective GID and returns a Cleanup that restores the original. 
+PosixErrorOr<Cleanup> Setegid(gid_t egid) { + gid_t old_gid = getegid(); + if (setegid(egid) < 0) { + return PosixError(errno, absl::StrFormat("setegid(%d)", egid)); + } + return Cleanup( + [old_gid]() { EXPECT_THAT(setegid(old_gid), SyscallSucceeds()); }); +} + +// Returns a pair of groups that the user is a member of. +PosixErrorOr<std::pair<gid_t, gid_t>> Groups() { + // See whether the user is a member of at least 2 groups. + std::vector<gid_t> groups(64); + for (; groups.size() <= NGROUPS_MAX; groups.resize(groups.size() * 2)) { + int ngroups = getgroups(groups.size(), groups.data()); + if (ngroups < 0 && errno == EINVAL) { + // Need a larger list. + continue; + } + if (ngroups < 0) { + return PosixError(errno, absl::StrFormat("getgroups(%d, %p)", + groups.size(), groups.data())); + } + if (ngroups >= 2) { + return std::pair<gid_t, gid_t>(groups[0], groups[1]); + } + // There aren't enough groups. + break; + } + + // If we're root in the root user namespace, we can set our GID to whatever we + // want. Try that before giving up. + constexpr gid_t kGID1 = 1111; + constexpr gid_t kGID2 = 2222; + auto cleanup1 = Setegid(kGID1); + if (!cleanup1.ok()) { + return cleanup1.error(); + } + auto cleanup2 = Setegid(kGID2); + if (!cleanup2.ok()) { + return cleanup2.error(); + } + return std::pair<gid_t, gid_t>(kGID1, kGID2); +} + +class SetgidDirTest : public ::testing::Test { + protected: + void SetUp() override { + original_gid_ = getegid(); + + // TODO(b/175325250): Enable when setgid directories are supported. 
+ SKIP_IF(IsRunningOnGvisor()); + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SETGID))); + + temp_dir_ = ASSERT_NO_ERRNO_AND_VALUE( + TempPath::CreateDirWith(GetAbsoluteTestTmpdir(), 0777 /* mode */)); + groups_ = ASSERT_NO_ERRNO_AND_VALUE(Groups()); + } + + void TearDown() override { + ASSERT_THAT(setegid(original_gid_), SyscallSucceeds()); + } + + void MkdirAsGid(gid_t gid, const std::string& path, mode_t mode) { + auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(Setegid(gid)); + ASSERT_THAT(mkdir(path.c_str(), mode), SyscallSucceeds()); + } + + PosixErrorOr<struct stat> Stat(const std::string& path) { + struct stat stats; + if (stat(path.c_str(), &stats) < 0) { + return PosixError(errno, absl::StrFormat("stat(%s, _)", path)); + } + return stats; + } + + PosixErrorOr<struct stat> Stat(const FileDescriptor& fd) { + struct stat stats; + if (fstat(fd.get(), &stats) < 0) { + return PosixError(errno, "fstat(_, _)"); + } + return stats; + } + + TempPath temp_dir_; + std::pair<gid_t, gid_t> groups_; + gid_t original_gid_; +}; + +// The control test. Files created with a given GID are owned by that group. +TEST_F(SetgidDirTest, Control) { + // Set group to G1 and create a directory. + auto g1owned = JoinPath(temp_dir_.path(), "g1owned/"); + ASSERT_NO_FATAL_FAILURE(MkdirAsGid(groups_.first, g1owned, 0777)); + + // Set group to G2, create a file in g1owned, and confirm that G2 owns it. + ASSERT_THAT(setegid(groups_.second), SyscallSucceeds()); + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE( + Open(JoinPath(g1owned, "g2owned").c_str(), O_CREAT | O_RDWR, 0777)); + struct stat stats = ASSERT_NO_ERRNO_AND_VALUE(Stat(fd)); + EXPECT_EQ(stats.st_gid, groups_.second); +} + +// Setgid directories cause created files to inherit GID. +TEST_F(SetgidDirTest, CreateFile) { + // Set group to G1, create a directory, and enable setgid. 
+ auto g1owned = JoinPath(temp_dir_.path(), "g1owned/"); + ASSERT_NO_FATAL_FAILURE(MkdirAsGid(groups_.first, g1owned, kDirmodeSgid)); + ASSERT_THAT(chmod(g1owned.c_str(), kDirmodeSgid), SyscallSucceeds()); + + // Set group to G2, create a file, and confirm that G1 owns it. + ASSERT_THAT(setegid(groups_.second), SyscallSucceeds()); + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE( + Open(JoinPath(g1owned, "g2created").c_str(), O_CREAT | O_RDWR, 0666)); + struct stat stats = ASSERT_NO_ERRNO_AND_VALUE(Stat(fd)); + EXPECT_EQ(stats.st_gid, groups_.first); +} + +// Setgid directories cause created directories to inherit GID. +TEST_F(SetgidDirTest, CreateDir) { + // Set group to G1, create a directory, and enable setgid. + auto g1owned = JoinPath(temp_dir_.path(), "g1owned/"); + ASSERT_NO_FATAL_FAILURE(MkdirAsGid(groups_.first, g1owned, kDirmodeSgid)); + ASSERT_THAT(chmod(g1owned.c_str(), kDirmodeSgid), SyscallSucceeds()); + + // Set group to G2, create a directory, confirm that G1 owns it, and that the + // setgid bit is enabled. + auto g2created = JoinPath(g1owned, "g2created"); + ASSERT_NO_FATAL_FAILURE(MkdirAsGid(groups_.second, g2created, 0666)); + struct stat stats = ASSERT_NO_ERRNO_AND_VALUE(Stat(g2created)); + EXPECT_EQ(stats.st_gid, groups_.first); + EXPECT_EQ(stats.st_mode & S_ISGID, S_ISGID); +} + +// Setgid directories with group execution disabled still cause GID inheritance. +TEST_F(SetgidDirTest, NoGroupExec) { + // Set group to G1, create a directory, and enable setgid. + auto g1owned = JoinPath(temp_dir_.path(), "g1owned/"); + ASSERT_NO_FATAL_FAILURE(MkdirAsGid(groups_.first, g1owned, kDirmodeNoExec)); + ASSERT_THAT(chmod(g1owned.c_str(), kDirmodeNoExec), SyscallSucceeds()); + + // Set group to G2, create a directory, confirm that G1 owns it, and that the + // setgid bit is enabled.
+ auto g2created = JoinPath(g1owned, "g2created"); + ASSERT_NO_FATAL_FAILURE(MkdirAsGid(groups_.second, g2created, 0666)); + struct stat stats = ASSERT_NO_ERRNO_AND_VALUE(Stat(g2created)); + EXPECT_EQ(stats.st_gid, groups_.first); + EXPECT_EQ(stats.st_mode & S_ISGID, S_ISGID); +} + +// Setting the setgid bit on directories with an existing file does not change +// the file's group. +TEST_F(SetgidDirTest, OldFile) { + // Set group to G1 and create a directory. + auto g1owned = JoinPath(temp_dir_.path(), "g1owned/"); + ASSERT_NO_FATAL_FAILURE(MkdirAsGid(groups_.first, g1owned, kDirmodeNoSgid)); + ASSERT_THAT(chmod(g1owned.c_str(), kDirmodeNoSgid), SyscallSucceeds()); + + // Set group to G2, create a file, confirm that G2 owns it. + ASSERT_THAT(setegid(groups_.second), SyscallSucceeds()); + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE( + Open(JoinPath(g1owned, "g2created").c_str(), O_CREAT | O_RDWR, 0666)); + struct stat stats = ASSERT_NO_ERRNO_AND_VALUE(Stat(fd)); + EXPECT_EQ(stats.st_gid, groups_.second); + + // Enable setgid. + ASSERT_THAT(chmod(g1owned.c_str(), kDirmodeSgid), SyscallSucceeds()); + + // Confirm that the file's group is still G2. + stats = ASSERT_NO_ERRNO_AND_VALUE(Stat(fd)); + EXPECT_EQ(stats.st_gid, groups_.second); +} + +// Setting the setgid bit on directories with an existing subdirectory does not +// change the subdirectory's group. +TEST_F(SetgidDirTest, OldDir) { + // Set group to G1 and create a directory. + auto g1owned = JoinPath(temp_dir_.path(), "g1owned/"); + ASSERT_NO_FATAL_FAILURE(MkdirAsGid(groups_.first, g1owned, kDirmodeNoSgid)); + ASSERT_THAT(chmod(g1owned.c_str(), kDirmodeNoSgid), SyscallSucceeds()); + + // Set group to G2, create a directory, confirm that G2 owns it.
+ ASSERT_THAT(setegid(groups_.second), SyscallSucceeds()); + auto g2created = JoinPath(g1owned, "g2created"); + ASSERT_NO_FATAL_FAILURE(MkdirAsGid(groups_.second, g2created, 0666)); + struct stat stats = ASSERT_NO_ERRNO_AND_VALUE(Stat(g2created)); + EXPECT_EQ(stats.st_gid, groups_.second); + + // Enable setgid. + ASSERT_THAT(chmod(g1owned.c_str(), kDirmodeSgid), SyscallSucceeds()); + + // Confirm that the directory's group is still G2. + stats = ASSERT_NO_ERRNO_AND_VALUE(Stat(g2created)); + EXPECT_EQ(stats.st_gid, groups_.second); +} + +// Chowning a file clears the setgid and setuid bits. +TEST_F(SetgidDirTest, ChownFileClears) { + // Set group to G1, create a directory, and enable setgid. + auto g1owned = JoinPath(temp_dir_.path(), "g1owned/"); + ASSERT_NO_FATAL_FAILURE(MkdirAsGid(groups_.first, g1owned, kDirmodeMask)); + ASSERT_THAT(chmod(g1owned.c_str(), kDirmodeMask), SyscallSucceeds()); + + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE( + Open(JoinPath(g1owned, "newfile").c_str(), O_CREAT | O_RDWR, 0666)); + ASSERT_THAT(fchmod(fd.get(), 0777 | S_ISUID | S_ISGID), SyscallSucceeds()); + struct stat stats = ASSERT_NO_ERRNO_AND_VALUE(Stat(fd)); + EXPECT_EQ(stats.st_gid, groups_.first); + EXPECT_EQ(stats.st_mode & (S_ISUID | S_ISGID), S_ISUID | S_ISGID); + + // Change the owning group. + ASSERT_THAT(fchown(fd.get(), -1, groups_.second), SyscallSucceeds()); + + // The setgid and setuid bits should be cleared. + stats = ASSERT_NO_ERRNO_AND_VALUE(Stat(fd)); + EXPECT_EQ(stats.st_gid, groups_.second); + EXPECT_EQ(stats.st_mode & (S_ISUID | S_ISGID), 0); +} + +// Chowning a file with setgid enabled, but not the group exec bit, does not +// clear the setgid bit. Such files are mandatory locked. +TEST_F(SetgidDirTest, ChownNoExecFileDoesNotClear) { + // Set group to G1, create a directory, and enable setgid.
+ auto g1owned = JoinPath(temp_dir_.path(), "g1owned/"); + ASSERT_NO_FATAL_FAILURE(MkdirAsGid(groups_.first, g1owned, kDirmodeNoExec)); + ASSERT_THAT(chmod(g1owned.c_str(), kDirmodeNoExec), SyscallSucceeds()); + + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE( + Open(JoinPath(g1owned, "newdir").c_str(), O_CREAT | O_RDWR, 0666)); + ASSERT_THAT(fchmod(fd.get(), 0766 | S_ISUID | S_ISGID), SyscallSucceeds()); + struct stat stats = ASSERT_NO_ERRNO_AND_VALUE(Stat(fd)); + EXPECT_EQ(stats.st_gid, groups_.first); + EXPECT_EQ(stats.st_mode & (S_ISUID | S_ISGID), S_ISUID | S_ISGID); + + // Change the owning group. + ASSERT_THAT(fchown(fd.get(), -1, groups_.second), SyscallSucceeds()); + + // Only the setuid bit is cleared. + stats = ASSERT_NO_ERRNO_AND_VALUE(Stat(fd)); + EXPECT_EQ(stats.st_gid, groups_.second); + EXPECT_EQ(stats.st_mode & (S_ISUID | S_ISGID), S_ISGID); +} + +// Chowning a directory with setgid enabled does not clear the bit. +TEST_F(SetgidDirTest, ChownDirDoesNotClear) { + // Set group to G1, create a directory, and enable setgid. + auto g1owned = JoinPath(temp_dir_.path(), "g1owned/"); + ASSERT_NO_FATAL_FAILURE(MkdirAsGid(groups_.first, g1owned, kDirmodeMask)); + ASSERT_THAT(chmod(g1owned.c_str(), kDirmodeMask), SyscallSucceeds()); + + // Change the owning group. + ASSERT_THAT(chown(g1owned.c_str(), -1, groups_.second), SyscallSucceeds()); + + struct stat stats = ASSERT_NO_ERRNO_AND_VALUE(Stat(g1owned)); + EXPECT_EQ(stats.st_gid, groups_.second); + EXPECT_EQ(stats.st_mode & kDirmodeMask, kDirmodeMask); +} + +struct FileModeTestcase { + std::string name; + mode_t mode; + mode_t result_mode; + + FileModeTestcase(const std::string& name, mode_t mode, mode_t result_mode) + : name(name), mode(mode), result_mode(result_mode) {} +}; + +class FileModeTest : public ::testing::TestWithParam<FileModeTestcase> {}; + +TEST_P(FileModeTest, WriteToFile) { + // TODO(b/175325250): Enable when setgid directories are supported. 
+ SKIP_IF(IsRunningOnGvisor()); + + auto temp_dir = ASSERT_NO_ERRNO_AND_VALUE( + TempPath::CreateDirWith(GetAbsoluteTestTmpdir(), 0777 /* mode */)); + auto path = JoinPath(temp_dir.path(), GetParam().name); + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(path.c_str(), O_CREAT | O_RDWR, 0666)); + ASSERT_THAT(fchmod(fd.get(), GetParam().mode), SyscallSucceeds()); + struct stat stats; + ASSERT_THAT(fstat(fd.get(), &stats), SyscallSucceeds()); + EXPECT_EQ(stats.st_mode & kDirmodeMask, GetParam().mode); + + // For security reasons, writing to the file clears the SUID bit, and clears + // the SGID bit when the group executable bit is unset (which is not a true + // SGID binary). + constexpr char kInput = 'M'; + ASSERT_THAT(write(fd.get(), &kInput, sizeof(kInput)), + SyscallSucceedsWithValue(sizeof(kInput))); + + ASSERT_THAT(fstat(fd.get(), &stats), SyscallSucceeds()); + EXPECT_EQ(stats.st_mode & kDirmodeMask, GetParam().result_mode); +} + +TEST_P(FileModeTest, TruncateFile) { + // TODO(b/175325250): Enable when setgid directories are supported. + SKIP_IF(IsRunningOnGvisor()); + + auto temp_dir = ASSERT_NO_ERRNO_AND_VALUE( + TempPath::CreateDirWith(GetAbsoluteTestTmpdir(), 0777 /* mode */)); + auto path = JoinPath(temp_dir.path(), GetParam().name); + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(path.c_str(), O_CREAT | O_RDWR, 0666)); + ASSERT_THAT(fchmod(fd.get(), GetParam().mode), SyscallSucceeds()); + struct stat stats; + ASSERT_THAT(fstat(fd.get(), &stats), SyscallSucceeds()); + EXPECT_EQ(stats.st_mode & kDirmodeMask, GetParam().mode); + + // For security reasons, truncating the file clears the SUID bit, and clears + // the SGID bit when the group executable bit is unset (which is not a true + // SGID binary). 
+ ASSERT_THAT(ftruncate(fd.get(), 0), SyscallSucceeds()); + + ASSERT_THAT(fstat(fd.get(), &stats), SyscallSucceeds()); + EXPECT_EQ(stats.st_mode & kDirmodeMask, GetParam().result_mode); +} + +INSTANTIATE_TEST_SUITE_P( + FileModes, FileModeTest, + ::testing::ValuesIn<FileModeTestcase>( + {FileModeTestcase("normal file", 0777, 0777), + FileModeTestcase("setuid", S_ISUID | 0777, 00777), + FileModeTestcase("setgid", S_ISGID | 0777, 00777), + FileModeTestcase("setuid and setgid", S_ISUID | S_ISGID | 0777, 00777), + FileModeTestcase("setgid without exec", S_ISGID | 0767, + S_ISGID | 0767), + FileModeTestcase("setuid and setgid without exec", + S_ISGID | S_ISUID | 0767, S_ISGID | 0767)})); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/shm.cc b/test/syscalls/linux/shm.cc index 6aabd79e7..baf794152 100644 --- a/test/syscalls/linux/shm.cc +++ b/test/syscalls/linux/shm.cc @@ -372,18 +372,18 @@ TEST(ShmDeathTest, SegmentNotAccessibleAfterDetach) { SetupGvisorDeathTest(); const auto rest = [&] { - ShmSegment shm = ASSERT_NO_ERRNO_AND_VALUE( + ShmSegment shm = TEST_CHECK_NO_ERRNO_AND_VALUE( Shmget(IPC_PRIVATE, kAllocSize, IPC_CREAT | 0777)); - char* addr = ASSERT_NO_ERRNO_AND_VALUE(Shmat(shm.id(), nullptr, 0)); + char* addr = TEST_CHECK_NO_ERRNO_AND_VALUE(Shmat(shm.id(), nullptr, 0)); // Mark the segment as destroyed so it's automatically cleaned up when we // crash below. We can't rely on the standard cleanup since the destructor // will not run after the SIGSEGV. Note that this doesn't destroy the // segment immediately since we're still attached to it. - ASSERT_NO_ERRNO(shm.Rmid()); + TEST_CHECK_NO_ERRNO(shm.Rmid()); addr[0] = 'x'; - ASSERT_NO_ERRNO(Shmdt(addr)); + TEST_CHECK_NO_ERRNO(Shmdt(addr)); // This access should cause a SIGSEGV. 
addr[0] = 'x'; diff --git a/test/syscalls/linux/socket_ip_tcp_generic.cc b/test/syscalls/linux/socket_ip_tcp_generic.cc index 831d96262..579e824cd 100644 --- a/test/syscalls/linux/socket_ip_tcp_generic.cc +++ b/test/syscalls/linux/socket_ip_tcp_generic.cc @@ -65,6 +65,37 @@ TEST_P(TCPSocketPairTest, ZeroTcpInfoSucceeds) { SyscallSucceeds()); } +// Copied from include/net/tcp.h. +constexpr int TCP_CA_OPEN = 0; + +TEST_P(TCPSocketPairTest, CheckTcpInfoFields) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char buf[10] = {}; + ASSERT_THAT(RetryEINTR(send)(sockets->first_fd(), buf, sizeof(buf), 0), + SyscallSucceedsWithValue(sizeof(buf))); + + // Wait until second_fd sees the data and then recv it. + struct pollfd poll_fd = {sockets->second_fd(), POLLIN, 0}; + constexpr int kPollTimeoutMs = 2000; // Wait up to 2 seconds for the data. + ASSERT_THAT(RetryEINTR(poll)(&poll_fd, 1, kPollTimeoutMs), + SyscallSucceedsWithValue(1)); + + ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), buf, sizeof(buf), 0), + SyscallSucceedsWithValue(sizeof(buf))); + + struct tcp_info opt = {}; + socklen_t optLen = sizeof(opt); + ASSERT_THAT(getsockopt(sockets->first_fd(), SOL_TCP, TCP_INFO, &opt, &optLen), + SyscallSucceeds()); + ASSERT_EQ(optLen, sizeof(opt)); + + // Validates the received tcp_info fields. + EXPECT_EQ(opt.tcpi_ca_state, TCP_CA_OPEN); + EXPECT_GT(opt.tcpi_snd_cwnd, 0); + EXPECT_GT(opt.tcpi_rto, 0); +} + // This test validates that an RST is sent instead of a FIN when data is // unread on calls to close(2). 
TEST_P(TCPSocketPairTest, RSTSentOnCloseWithUnreadData) { diff --git a/test/syscalls/linux/stat.cc b/test/syscalls/linux/stat.cc index 6e7142a42..72f888659 100644 --- a/test/syscalls/linux/stat.cc +++ b/test/syscalls/linux/stat.cc @@ -221,6 +221,43 @@ TEST_F(StatTest, TrailingSlashNotCleanedReturnsENOTDIR) { EXPECT_THAT(lstat(bad_path.c_str(), &buf), SyscallFailsWithErrno(ENOTDIR)); } +TEST_F(StatTest, FstatFileWithOpath) { + SKIP_IF(IsRunningWithVFS1()); + struct stat st; + TempPath tmpfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(tmpfile.path().c_str(), O_PATH)); + + // Stat the file. + ASSERT_THAT(fstat(fd.get(), &st), SyscallSucceeds()); +} + +TEST_F(StatTest, FstatDirWithOpath) { + SKIP_IF(IsRunningWithVFS1()); + struct stat st; + TempPath tmpdir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + FileDescriptor dirfd = ASSERT_NO_ERRNO_AND_VALUE( + Open(tmpdir.path().c_str(), O_PATH | O_DIRECTORY)); + + // Stat the directory. + ASSERT_THAT(fstat(dirfd.get(), &st), SyscallSucceeds()); +} + +// fstatat with an O_PATH fd +TEST_F(StatTest, FstatatDirWithOpath) { + SKIP_IF(IsRunningWithVFS1()); + TempPath tmpdir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + FileDescriptor dirfd = ASSERT_NO_ERRNO_AND_VALUE( + Open(tmpdir.path().c_str(), O_PATH | O_DIRECTORY)); + TempPath tmpfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + + struct stat st = {}; + EXPECT_THAT(fstatat(dirfd.get(), tmpfile.path().c_str(), &st, 0), + SyscallSucceeds()); + EXPECT_FALSE(S_ISDIR(st.st_mode)); + EXPECT_TRUE(S_ISREG(st.st_mode)); +} + // Test fstatating a symlink directory. TEST_F(StatTest, FstatatSymlinkDir) { // Create a directory and symlink to it.
diff --git a/test/syscalls/linux/statfs.cc b/test/syscalls/linux/statfs.cc index f0fb166bd..d4ea8e026 100644 --- a/test/syscalls/linux/statfs.cc +++ b/test/syscalls/linux/statfs.cc @@ -64,6 +64,16 @@ TEST(FstatfsTest, InternalTmpfs) { EXPECT_THAT(fstatfs(fd.get(), &st), SyscallSucceeds()); } +TEST(FstatfsTest, CanStatFileWithOpath) { + SKIP_IF(IsRunningWithVFS1()); + auto temp_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + const FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(temp_file.path(), O_PATH)); + + struct statfs st; + EXPECT_THAT(fstatfs(fd.get(), &st), SyscallSucceeds()); +} + TEST(FstatfsTest, InternalDevShm) { auto temp_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); const FileDescriptor fd = diff --git a/test/syscalls/linux/symlink.cc b/test/syscalls/linux/symlink.cc index 4d9eba7f0..ea219a091 100644 --- a/test/syscalls/linux/symlink.cc +++ b/test/syscalls/linux/symlink.cc @@ -269,6 +269,36 @@ TEST(SymlinkTest, SymlinkAtDegradedPermissions_NoRandomSave) { EXPECT_THAT(close(dirfd), SyscallSucceeds()); } +TEST(SymlinkTest, SymlinkAtDirWithOpath) { + SKIP_IF(IsRunningWithVFS1()); + auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + const std::string filepath = NewTempAbsPathInDir(dir.path()); + const std::string base = std::string(Basename(filepath)); + FileDescriptor dirfd = + ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path().c_str(), O_DIRECTORY | O_PATH)); + + EXPECT_THAT(symlinkat("/dangling", dirfd.get(), base.c_str()), + SyscallSucceeds()); +} + +TEST(SymlinkTest, ReadlinkAtDirWithOpath) { + SKIP_IF(IsRunningWithVFS1()); + auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + const std::string filepath = NewTempAbsPathInDir(dir.path()); + const std::string base = std::string(Basename(filepath)); + ASSERT_THAT(symlink("/dangling", filepath.c_str()), SyscallSucceeds()); + + FileDescriptor dirfd = + ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path().c_str(), O_DIRECTORY | O_PATH)); + + std::vector<char> buf(1024); + int 
linksize; + EXPECT_THAT( + linksize = readlinkat(dirfd.get(), base.c_str(), buf.data(), 1024), + SyscallSucceeds()); + EXPECT_EQ(0, strncmp("/dangling", buf.data(), linksize)); +} + TEST(SymlinkTest, ReadlinkAtDegradedPermissions_NoRandomSave) { // Drop capabilities that allow us to override file and directory permissions. ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); diff --git a/test/syscalls/linux/sync.cc b/test/syscalls/linux/sync.cc index 8aa2525a9..84a2c4ed7 100644 --- a/test/syscalls/linux/sync.cc +++ b/test/syscalls/linux/sync.cc @@ -49,10 +49,20 @@ TEST(SyncTest, SyncFromPipe) { EXPECT_THAT(close(pipes[1]), SyscallSucceeds()); } -TEST(SyncTest, CannotSyncFileSytemAtBadFd) { +TEST(SyncTest, CannotSyncFileSystemAtBadFd) { EXPECT_THAT(syncfs(-1), SyscallFailsWithErrno(EBADF)); } +TEST(SyncTest, CannotSyncFileSystemAtOpathFD) { + SKIP_IF(IsRunningWithVFS1()); + + const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( + GetAbsoluteTestTmpdir(), "", TempPath::kDefaultFileMode)); + const FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_PATH)); + + EXPECT_THAT(syncfs(fd.get()), SyscallFailsWithErrno(EBADF)); +} } // namespace } // namespace testing diff --git a/test/syscalls/linux/tcp_socket.cc b/test/syscalls/linux/tcp_socket.cc index 9028ab024..f56c50e61 100644 --- a/test/syscalls/linux/tcp_socket.cc +++ b/test/syscalls/linux/tcp_socket.cc @@ -1168,6 +1168,42 @@ TEST_P(SimpleTcpSocketTest, SelfConnectSendRecv_NoRandomSave) { EXPECT_EQ(read_bytes, kBufSz); } +TEST_P(SimpleTcpSocketTest, SelfConnectSend_NoRandomSave) { + // Initialize address to the loopback one. 
+ sockaddr_storage addr = + ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam())); + socklen_t addrlen = sizeof(addr); + + const FileDescriptor s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + + constexpr int max_seg = 256; + ASSERT_THAT( + setsockopt(s.get(), SOL_TCP, TCP_MAXSEG, &max_seg, sizeof(max_seg)), + SyscallSucceeds()); + + ASSERT_THAT(bind(s.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen), + SyscallSucceeds()); + // Get the bound port. + ASSERT_THAT( + getsockname(s.get(), reinterpret_cast<struct sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + ASSERT_THAT(RetryEINTR(connect)( + s.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen), + SyscallSucceeds()); + + std::vector<char> writebuf(512 << 10); // 512 KiB. + + // Try to send the whole thing. + int n; + ASSERT_THAT(n = SendFd(s.get(), writebuf.data(), writebuf.size(), 0), + SyscallSucceeds()); + + // We should have written the whole thing. + EXPECT_EQ(n, writebuf.size()); + EXPECT_THAT(shutdown(s.get(), SHUT_WR), SyscallSucceedsWithValue(0)); +} + TEST_P(SimpleTcpSocketTest, NonBlockingConnect) { const FileDescriptor listener = ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); diff --git a/test/syscalls/linux/truncate.cc b/test/syscalls/linux/truncate.cc index bfc95ed38..17832c47d 100644 --- a/test/syscalls/linux/truncate.cc +++ b/test/syscalls/linux/truncate.cc @@ -196,6 +196,16 @@ TEST(TruncateTest, FtruncateNonWriteable) { EXPECT_THAT(ftruncate(fd.get(), 0), SyscallFailsWithErrno(EINVAL)); } +TEST(TruncateTest, FtruncateWithOpath) { + SKIP_IF(IsRunningWithVFS1()); + auto temp_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( + GetAbsoluteTestTmpdir(), absl::string_view(), 0555 /* mode */)); + const FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(temp_file.path(), O_PATH)); + EXPECT_THAT(ftruncate(fd.get(), 0), AnyOf(SyscallFailsWithErrno(EBADF), + SyscallFailsWithErrno(EINVAL))); +} + // ftruncate(2) should 
succeed as long as the file descriptor is writeable, // regardless of whether the file permissions allow writing. TEST(TruncateTest, FtruncateWithoutWritePermission_NoRandomSave) { diff --git a/test/syscalls/linux/uidgid.cc b/test/syscalls/linux/uidgid.cc index 64d6d0b8f..4139a18d8 100644 --- a/test/syscalls/linux/uidgid.cc +++ b/test/syscalls/linux/uidgid.cc @@ -23,6 +23,8 @@ #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "test/util/capability_util.h" +#include "test/util/cleanup.h" +#include "test/util/multiprocess_util.h" #include "test/util/posix_error.h" #include "test/util/test_util.h" #include "test/util/thread_util.h" @@ -33,6 +35,16 @@ ABSL_FLAG(int32_t, scratch_uid2, 65533, "second scratch UID"); ABSL_FLAG(int32_t, scratch_gid1, 65534, "first scratch GID"); ABSL_FLAG(int32_t, scratch_gid2, 65533, "second scratch GID"); +// Force use of syscall instead of glibc set*id() wrappers because we want to +// apply to the current task only. libc sets all threads in a process because +// "POSIX requires that all threads in a process share the same credentials." 
+#define setuid USE_SYSCALL_INSTEAD +#define setgid USE_SYSCALL_INSTEAD +#define setreuid USE_SYSCALL_INSTEAD +#define setregid USE_SYSCALL_INSTEAD +#define setresuid USE_SYSCALL_INSTEAD +#define setresgid USE_SYSCALL_INSTEAD + using ::testing::UnorderedElementsAreArray; namespace gvisor { @@ -137,21 +149,31 @@ TEST(UidGidRootTest, Setuid) { TEST(UidGidRootTest, Setgid) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(IsRoot())); - EXPECT_THAT(setgid(-1), SyscallFailsWithErrno(EINVAL)); + EXPECT_THAT(syscall(SYS_setgid, -1), SyscallFailsWithErrno(EINVAL)); - const gid_t gid = absl::GetFlag(FLAGS_scratch_gid1); - ASSERT_THAT(setgid(gid), SyscallSucceeds()); - EXPECT_NO_ERRNO(CheckGIDs(gid, gid, gid)); + ScopedThread([&] { + const gid_t gid = absl::GetFlag(FLAGS_scratch_gid1); + EXPECT_THAT(syscall(SYS_setgid, gid), SyscallSucceeds()); + EXPECT_NO_ERRNO(CheckGIDs(gid, gid, gid)); + }); } TEST(UidGidRootTest, SetgidNotFromThreadGroupLeader) { +#pragma push_macro("allow_setgid") +#undef setgid + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(IsRoot())); + int old_gid = getgid(); + auto clean = Cleanup([old_gid] { setgid(old_gid); }); + const gid_t gid = absl::GetFlag(FLAGS_scratch_gid1); // NOTE(b/64676707): Do setgid in a separate thread so that we can test if // info.si_pid is set correctly. ScopedThread([gid] { ASSERT_THAT(setgid(gid), SyscallSucceeds()); }); EXPECT_NO_ERRNO(CheckGIDs(gid, gid, gid)); + +#pragma pop_macro("allow_setgid") } TEST(UidGidRootTest, Setreuid) { @@ -159,27 +181,25 @@ TEST(UidGidRootTest, Setreuid) { // "Supplying a value of -1 for either the real or effective user ID forces // the system to leave that ID unchanged." - setreuid(2) - EXPECT_THAT(setreuid(-1, -1), SyscallSucceeds()); + EXPECT_THAT(syscall(SYS_setreuid, -1, -1), SyscallSucceeds()); + EXPECT_NO_ERRNO(CheckUIDs(0, 0, 0)); // Do setuid in a separate thread so that after finishing this test, the - // process can still open files the test harness created before starting this - // test. 
Otherwise, the files are created by root (UID before the test), but - // cannot be opened by the `uid` set below after the test. After calling - // setuid(non-zero-UID), there is no way to get root privileges back. + // process can still open files the test harness created before starting + // this test. Otherwise, the files are created by root (UID before the + // test), but cannot be opened by the `uid` set below after the test. After + // calling setuid(non-zero-UID), there is no way to get root privileges + // back. ScopedThread([&] { const uid_t ruid = absl::GetFlag(FLAGS_scratch_uid1); const uid_t euid = absl::GetFlag(FLAGS_scratch_uid2); - // Use syscall instead of glibc setuid wrapper because we want this setuid - // call to only apply to this task. posix threads, however, require that all - // threads have the same UIDs, so using the setuid wrapper sets all threads' - // real UID. EXPECT_THAT(syscall(SYS_setreuid, ruid, euid), SyscallSucceeds()); // "If the real user ID is set or the effective user ID is set to a value - // not equal to the previous real user ID, the saved set-user-ID will be set - // to the new effective user ID." - setreuid(2) + // not equal to the previous real user ID, the saved set-user-ID will be + // set to the new effective user ID." 
- setreuid(2) EXPECT_NO_ERRNO(CheckUIDs(ruid, euid, euid)); }); } @@ -187,13 +207,15 @@ TEST(UidGidRootTest, Setreuid) { TEST(UidGidRootTest, Setregid) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(IsRoot())); - EXPECT_THAT(setregid(-1, -1), SyscallSucceeds()); + EXPECT_THAT(syscall(SYS_setregid, -1, -1), SyscallSucceeds()); EXPECT_NO_ERRNO(CheckGIDs(0, 0, 0)); - const gid_t rgid = absl::GetFlag(FLAGS_scratch_gid1); - const gid_t egid = absl::GetFlag(FLAGS_scratch_gid2); - ASSERT_THAT(setregid(rgid, egid), SyscallSucceeds()); - EXPECT_NO_ERRNO(CheckGIDs(rgid, egid, egid)); + ScopedThread([&] { + const gid_t rgid = absl::GetFlag(FLAGS_scratch_gid1); + const gid_t egid = absl::GetFlag(FLAGS_scratch_gid2); + ASSERT_THAT(syscall(SYS_setregid, rgid, egid), SyscallSucceeds()); + EXPECT_NO_ERRNO(CheckGIDs(rgid, egid, egid)); + }); } TEST(UidGidRootTest, Setresuid) { @@ -201,23 +223,24 @@ TEST(UidGidRootTest, Setresuid) { // "If one of the arguments equals -1, the corresponding value is not // changed." - setresuid(2) - EXPECT_THAT(setresuid(-1, -1, -1), SyscallSucceeds()); + EXPECT_THAT(syscall(SYS_setresuid, -1, -1, -1), SyscallSucceeds()); EXPECT_NO_ERRNO(CheckUIDs(0, 0, 0)); // Do setuid in a separate thread so that after finishing this test, the - // process can still open files the test harness created before starting this - // test. Otherwise, the files are created by root (UID before the test), but - // cannot be opened by the `uid` set below after the test. After calling - // setuid(non-zero-UID), there is no way to get root privileges back. + // process can still open files the test harness created before starting + // this test. Otherwise, the files are created by root (UID before the + // test), but cannot be opened by the `uid` set below after the test. After + // calling setuid(non-zero-UID), there is no way to get root privileges + // back. 
ScopedThread([&] { const uid_t ruid = 12345; const uid_t euid = 23456; const uid_t suid = 34567; // Use syscall instead of glibc setuid wrapper because we want this setuid - // call to only apply to this task. posix threads, however, require that all - // threads have the same UIDs, so using the setuid wrapper sets all threads' - // real UID. + // call to only apply to this task. posix threads, however, require that + // all threads have the same UIDs, so using the setuid wrapper sets all + // threads' real UID. EXPECT_THAT(syscall(SYS_setresuid, ruid, euid, suid), SyscallSucceeds()); EXPECT_NO_ERRNO(CheckUIDs(ruid, euid, suid)); }); @@ -226,14 +249,16 @@ TEST(UidGidRootTest, Setresuid) { TEST(UidGidRootTest, Setresgid) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(IsRoot())); - EXPECT_THAT(setresgid(-1, -1, -1), SyscallSucceeds()); + EXPECT_THAT(syscall(SYS_setresgid, -1, -1, -1), SyscallSucceeds()); EXPECT_NO_ERRNO(CheckGIDs(0, 0, 0)); - const gid_t rgid = 12345; - const gid_t egid = 23456; - const gid_t sgid = 34567; - ASSERT_THAT(setresgid(rgid, egid, sgid), SyscallSucceeds()); - EXPECT_NO_ERRNO(CheckGIDs(rgid, egid, sgid)); + ScopedThread([&] { + const gid_t rgid = 12345; + const gid_t egid = 23456; + const gid_t sgid = 34567; + ASSERT_THAT(syscall(SYS_setresgid, rgid, egid, sgid), SyscallSucceeds()); + EXPECT_NO_ERRNO(CheckGIDs(rgid, egid, sgid)); + }); } TEST(UidGidRootTest, Setgroups) { @@ -254,14 +279,14 @@ TEST(UidGidRootTest, Setuid_prlimit) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(IsRoot())); // Do seteuid in a separate thread so that after finishing this test, the - // process can still open files the test harness created before starting this - // test. Otherwise, the files are created by root (UID before the test), but - // cannot be opened by the `uid` set below after the test. + // process can still open files the test harness created before starting + // this test. 
Otherwise, the files are created by root (UID before the + // test), but cannot be opened by the `uid` set below after the test. ScopedThread([&] { - // Use syscall instead of glibc setuid wrapper because we want this seteuid - // call to only apply to this task. POSIX threads, however, require that all - // threads have the same UIDs, so using the seteuid wrapper sets all - // threads' UID. + // Use syscall instead of glibc setuid wrapper because we want this + // seteuid call to only apply to this task. POSIX threads, however, + // require that all threads have the same UIDs, so using the seteuid + // wrapper sets all threads' UID. EXPECT_THAT(syscall(SYS_setreuid, -1, 65534), SyscallSucceeds()); // Despite the UID change, we should be able to get our own limits. diff --git a/test/syscalls/linux/write.cc b/test/syscalls/linux/write.cc index 77bcfbb8a..740992d0a 100644 --- a/test/syscalls/linux/write.cc +++ b/test/syscalls/linux/write.cc @@ -218,6 +218,44 @@ TEST_F(WriteTest, PwriteNoChangeOffset) { EXPECT_THAT(lseek(fd, 0, SEEK_CUR), SyscallSucceedsWithValue(bytes_total)); } +TEST_F(WriteTest, WriteWithOpath) { + SKIP_IF(IsRunningWithVFS1()); + TempPath tmpfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + FileDescriptor f = + ASSERT_NO_ERRNO_AND_VALUE(Open(tmpfile.path().c_str(), O_PATH)); + int fd = f.get(); + + EXPECT_THAT(WriteBytes(fd, 1024), SyscallFailsWithErrno(EBADF)); +} + +TEST_F(WriteTest, WritevWithOpath) { + SKIP_IF(IsRunningWithVFS1()); + TempPath tmpfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + FileDescriptor f = + ASSERT_NO_ERRNO_AND_VALUE(Open(tmpfile.path().c_str(), O_PATH)); + int fd = f.get(); + + char buf[16]; + struct iovec iov; + iov.iov_base = buf; + iov.iov_len = sizeof(buf); + + EXPECT_THAT(writev(fd, &iov, /*__count=*/1), SyscallFailsWithErrno(EBADF)); +} + +TEST_F(WriteTest, PwriteWithOpath) { + SKIP_IF(IsRunningWithVFS1()); + TempPath tmpfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + FileDescriptor 
f = + ASSERT_NO_ERRNO_AND_VALUE(Open(tmpfile.path().c_str(), O_PATH)); + int fd = f.get(); + + const std::string data = "hello world\n"; + + EXPECT_THAT(pwrite(fd, data.data(), data.size(), 0), + SyscallFailsWithErrno(EBADF)); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/xattr.cc b/test/syscalls/linux/xattr.cc index bd3f829c4..a953a55fe 100644 --- a/test/syscalls/linux/xattr.cc +++ b/test/syscalls/linux/xattr.cc @@ -607,6 +607,27 @@ TEST_F(XattrTest, XattrWithFD) { EXPECT_THAT(fremovexattr(fd.get(), name), SyscallSucceeds()); } +TEST_F(XattrTest, XattrWithOPath) { + SKIP_IF(IsRunningWithVFS1()); + const FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_.c_str(), O_PATH)); + const char name[] = "user.test"; + int val = 1234; + size_t size = sizeof(val); + EXPECT_THAT(fsetxattr(fd.get(), name, &val, size, /*flags=*/0), + SyscallFailsWithErrno(EBADF)); + + int buf; + EXPECT_THAT(fgetxattr(fd.get(), name, &buf, size), + SyscallFailsWithErrno(EBADF)); + + char list[sizeof(name)]; + EXPECT_THAT(flistxattr(fd.get(), list, sizeof(list)), + SyscallFailsWithErrno(EBADF)); + + EXPECT_THAT(fremovexattr(fd.get(), name), SyscallFailsWithErrno(EBADF)); +} + TEST_F(XattrTest, TrustedNamespaceWithCapSysAdmin) { // Trusted namespace not supported in VFS1. 
SKIP_IF(IsRunningWithVFS1()); diff --git a/test/util/capability_util.h b/test/util/capability_util.h index bb9ea1fe5..a03bc7e05 100644 --- a/test/util/capability_util.h +++ b/test/util/capability_util.h @@ -96,6 +96,19 @@ inline PosixError DropPermittedCapability(int cap) { PosixErrorOr<bool> CanCreateUserNamespace(); +class AutoCapability { + public: + AutoCapability(int cap, bool set) : cap_(cap), set_(set) { + EXPECT_NO_ERRNO(SetCapability(cap_, set_)); + } + + ~AutoCapability() { EXPECT_NO_ERRNO(SetCapability(cap_, !set_)); } + + private: + int cap_; + bool set_; +}; + } // namespace testing } // namespace gvisor #endif // GVISOR_TEST_UTIL_CAPABILITY_UTIL_H_ diff --git a/test/util/fs_util.cc b/test/util/fs_util.cc index b16055dd8..5f1ce0d8a 100644 --- a/test/util/fs_util.cc +++ b/test/util/fs_util.cc @@ -663,5 +663,21 @@ PosixErrorOr<bool> IsOverlayfs(const std::string& path) { return stat.f_type == OVERLAYFS_SUPER_MAGIC; } +PosixError CheckSameFile(const FileDescriptor& fd1, const FileDescriptor& fd2) { + struct stat stat_result1, stat_result2; + int res = fstat(fd1.get(), &stat_result1); + if (res < 0) { + return PosixError(errno, absl::StrCat("fstat ", fd1.get())); + } + + res = fstat(fd2.get(), &stat_result2); + if (res < 0) { + return PosixError(errno, absl::StrCat("fstat ", fd2.get())); + } + EXPECT_EQ(stat_result1.st_dev, stat_result2.st_dev); + EXPECT_EQ(stat_result1.st_ino, stat_result2.st_ino); + + return NoError(); +} } // namespace testing } // namespace gvisor diff --git a/test/util/fs_util.h b/test/util/fs_util.h index c99cf5eb7..2190c3bca 100644 --- a/test/util/fs_util.h +++ b/test/util/fs_util.h @@ -191,6 +191,8 @@ PosixErrorOr<bool> IsTmpfs(const std::string& path); // IsOverlayfs returns true if the file at path is backed by overlayfs. PosixErrorOr<bool> IsOverlayfs(const std::string& path); +PosixError CheckSameFile(const FileDescriptor& fd1, const FileDescriptor& fd2); + namespace internal { // Not part of the public API. 
std::string JoinPathImpl(std::initializer_list<absl::string_view> paths); diff --git a/test/util/logging.cc b/test/util/logging.cc index 5d5e76c46..5fadb076b 100644 --- a/test/util/logging.cc +++ b/test/util/logging.cc @@ -69,9 +69,7 @@ int WriteNumber(int fd, uint32_t val) { } // namespace void CheckFailure(const char* cond, size_t cond_size, const char* msg, - size_t msg_size, bool include_errno) { - int saved_errno = errno; - + size_t msg_size, int errno_value) { constexpr char kCheckFailure[] = "Check failed: "; Write(2, kCheckFailure, sizeof(kCheckFailure) - 1); Write(2, cond, cond_size); @@ -81,10 +79,10 @@ void CheckFailure(const char* cond, size_t cond_size, const char* msg, Write(2, msg, msg_size); } - if (include_errno) { + if (errno_value != 0) { constexpr char kErrnoMessage[] = " (errno "; Write(2, kErrnoMessage, sizeof(kErrnoMessage) - 1); - WriteNumber(2, saved_errno); + WriteNumber(2, errno_value); Write(2, ")", 1); } diff --git a/test/util/logging.h b/test/util/logging.h index 589166fab..5c17f1233 100644 --- a/test/util/logging.h +++ b/test/util/logging.h @@ -21,7 +21,7 @@ namespace gvisor { namespace testing { void CheckFailure(const char* cond, size_t cond_size, const char* msg, - size_t msg_size, bool include_errno); + size_t msg_size, int errno_value); // If cond is false, aborts the current process. 
// @@ -30,7 +30,7 @@ void CheckFailure(const char* cond, size_t cond_size, const char* msg, do { \ if (!(cond)) { \ ::gvisor::testing::CheckFailure(#cond, sizeof(#cond) - 1, nullptr, \ - 0, false); \ + 0, 0); \ } \ } while (0) @@ -41,7 +41,7 @@ void CheckFailure(const char* cond, size_t cond_size, const char* msg, do { \ if (!(cond)) { \ ::gvisor::testing::CheckFailure(#cond, sizeof(#cond) - 1, msg, \ - sizeof(msg) - 1, false); \ + sizeof(msg) - 1, 0); \ } \ } while (0) @@ -52,7 +52,7 @@ void CheckFailure(const char* cond, size_t cond_size, const char* msg, do { \ if (!(cond)) { \ ::gvisor::testing::CheckFailure(#cond, sizeof(#cond) - 1, nullptr, \ - 0, true); \ + 0, errno); \ } \ } while (0) @@ -63,10 +63,54 @@ void CheckFailure(const char* cond, size_t cond_size, const char* msg, do { \ if (!(cond)) { \ ::gvisor::testing::CheckFailure(#cond, sizeof(#cond) - 1, msg, \ - sizeof(msg) - 1, true); \ + sizeof(msg) - 1, errno); \ } \ } while (0) +// expr must return PosixErrorOr<T>. The current process is aborted if +// !PosixError<T>.ok(). +// +// This macro is async-signal-safe. +#define TEST_CHECK_NO_ERRNO(expr) \ + ({ \ + auto _expr_result = (expr); \ + if (!_expr_result.ok()) { \ + ::gvisor::testing::CheckFailure( \ + #expr, sizeof(#expr) - 1, nullptr, 0, \ + _expr_result.error().errno_value()); \ + } \ + }) + +// expr must return PosixErrorOr<T>. The current process is aborted if +// !PosixError<T>.ok(). Otherwise, PosixErrorOr<T> value is returned. +// +// This macro is async-signal-safe. +#define TEST_CHECK_NO_ERRNO_AND_VALUE(expr) \ + ({ \ + auto _expr_result = (expr); \ + if (!_expr_result.ok()) { \ + ::gvisor::testing::CheckFailure( \ + #expr, sizeof(#expr) - 1, nullptr, 0, \ + _expr_result.error().errno_value()); \ + } \ + std::move(_expr_result).ValueOrDie(); \ + }) + +// cond must be greater or equal than 0. Used to test result of syscalls. +// +// This macro is async-signal-safe. 
+#define TEST_CHECK_SUCCESS(cond) TEST_PCHECK((cond) >= 0) + +// cond must be -1 and errno must match errno_value. Used to test errors from +// syscalls. +// +// This macro is async-signal-safe. +#define TEST_CHECK_ERRNO(cond, errno_value) \ + do { \ + TEST_PCHECK((cond) == -1); \ + TEST_PCHECK_MSG(errno == (errno_value), #cond " expected " #errno_value); \ + } while (0) + } // namespace testing } // namespace gvisor diff --git a/test/util/multiprocess_util.h b/test/util/multiprocess_util.h index 2f3bf4a6f..840fde4ee 100644 --- a/test/util/multiprocess_util.h +++ b/test/util/multiprocess_util.h @@ -123,7 +123,8 @@ inline PosixErrorOr<Cleanup> ForkAndExecveat(int32_t dirfd, // Calls fn in a forked subprocess and returns the exit status of the // subprocess. // -// fn must be async-signal-safe. +// fn must be async-signal-safe. Use of ASSERT/EXPECT functions is prohibited. +// Use TEST_CHECK variants instead. PosixErrorOr<int> InForkedProcess(const std::function<void()>& fn); } // namespace testing diff --git a/test/util/posix_error.cc b/test/util/posix_error.cc index deed0c05b..8522e4c81 100644 --- a/test/util/posix_error.cc +++ b/test/util/posix_error.cc @@ -50,7 +50,7 @@ std::string PosixError::ToString() const { ret = absl::StrCat("PosixError(errno=", errno_, " ", res, ")"); #endif - if (!msg_.empty()) { + if (strnlen(msg_, sizeof(msg_)) > 0) { ret.append(" "); ret.append(msg_); } diff --git a/test/util/posix_error.h b/test/util/posix_error.h index b634a7f78..27557ad44 100644 --- a/test/util/posix_error.h +++ b/test/util/posix_error.h @@ -26,12 +26,18 @@ namespace gvisor { namespace testing { +// PosixError must be async-signal-safe. 
class ABSL_MUST_USE_RESULT PosixError { public: PosixError() {} + explicit PosixError(int errno_value) : errno_(errno_value) {} - PosixError(int errno_value, std::string msg) - : errno_(errno_value), msg_(std::move(msg)) {} + + PosixError(int errno_value, std::string_view msg) : errno_(errno_value) { + // Check that `msg` will fit, leaving room for '\0' at the end. + TEST_CHECK(msg.size() < sizeof(msg_)); + msg.copy(msg_, msg.size()); + } PosixError(PosixError&& other) = default; PosixError& operator=(PosixError&& other) = default; @@ -45,7 +51,7 @@ class ABSL_MUST_USE_RESULT PosixError { const PosixError& error() const { return *this; } int errno_value() const { return errno_; } - std::string message() const { return msg_; } + const char* message() const { return msg_; } // ToString produces a full string representation of this posix error // including the printable representation of the errno and the error message. @@ -58,7 +64,7 @@ class ABSL_MUST_USE_RESULT PosixError { private: int errno_ = 0; - std::string msg_; + char msg_[1024] = {}; }; template <typename T> diff --git a/tools/bazel.mk b/tools/bazel.mk index fb0fc6524..60b50cfb0 100644 --- a/tools/bazel.mk +++ b/tools/bazel.mk @@ -75,6 +75,7 @@ UID := $(shell id -u ${USER}) GID := $(shell id -g ${USER}) USERADD_OPTIONS := DOCKER_RUN_OPTIONS := +DOCKER_RUN_OPTIONS += --rm DOCKER_RUN_OPTIONS += --user $(UID):$(GID) DOCKER_RUN_OPTIONS += --entrypoint "" DOCKER_RUN_OPTIONS += --init @@ -160,15 +161,13 @@ bazel-image: load-default ## Ensures that the local builder exists. @docker commit $(BUILDER_NAME) gvisor.dev/images/builder >&2 .PHONY: bazel-image -# Note: when starting the bazel server, we tie the life of the container to the -# bazel server's life, so that the container disappears naturally. ifneq (true,$(shell $(wrapper echo true))) bazel-server: bazel-image ## Ensures that the server exists. 
@$(call header,DOCKER RUN) @docker rm -f $(DOCKER_NAME) 2>/dev/null || true - @mkdir -p $(GCLOUD_CONFIG) @mkdir -p $(BAZEL_CACHE) - @docker run -d --rm --name $(DOCKER_NAME) \ + @mkdir -p $(GCLOUD_CONFIG) + @docker run -d --name $(DOCKER_NAME) \ -v "$(CURDIR):$(CURDIR)" \ --workdir "$(CURDIR)" \ $(DOCKER_RUN_OPTIONS) \ diff --git a/tools/bazel_gazelle_generate.patch b/tools/bazel_gazelle_generate.patch new file mode 100644 index 000000000..fd1e1bda6 --- /dev/null +++ b/tools/bazel_gazelle_generate.patch @@ -0,0 +1,15 @@ +diff --git a/language/go/generate.go b/language/go/generate.go +index 2892948..feb4ad6 100644 +--- a/language/go/generate.go ++++ b/language/go/generate.go +@@ -691,6 +691,10 @@ func (g *generator) setImportAttrs(r *rule.Rule, importPath string) { + } + + func (g *generator) commonVisibility(importPath string) []string { ++ if importPath == "golang.org/x/tools/go/analysis/internal/facts" { ++ // Imported by nogo main. We add a visibility exception. ++ return []string{"//visibility:public"} ++ } + // If the Bazel package name (rel) contains "internal", add visibility for + // subpackages of the parent. 
+ // If the import path contains "internal" but rel does not, this is diff --git a/tools/bazel_gazelle.patch b/tools/bazel_gazelle_noise.patch index e35f38933..e35f38933 100644 --- a/tools/bazel_gazelle.patch +++ b/tools/bazel_gazelle_noise.patch diff --git a/tools/checkescape/BUILD b/tools/checkescape/BUILD index 8956be621..940538b9e 100644 --- a/tools/checkescape/BUILD +++ b/tools/checkescape/BUILD @@ -8,8 +8,8 @@ go_library( nogo = False, visibility = ["//tools/nogo:__subpackages__"], deps = [ - "@org_golang_x_tools//go/analysis:go_tool_library", - "@org_golang_x_tools//go/analysis/passes/buildssa:go_tool_library", - "@org_golang_x_tools//go/ssa:go_tool_library", + "@org_golang_x_tools//go/analysis:go_default_library", + "@org_golang_x_tools//go/analysis/passes/buildssa:go_default_library", + "@org_golang_x_tools//go/ssa:go_default_library", ], ) diff --git a/tools/checkunsafe/BUILD b/tools/checkunsafe/BUILD index 0c264151b..0bb07b415 100644 --- a/tools/checkunsafe/BUILD +++ b/tools/checkunsafe/BUILD @@ -8,6 +8,6 @@ go_library( nogo = False, visibility = ["//tools/nogo:__subpackages__"], deps = [ - "@org_golang_x_tools//go/analysis:go_tool_library", + "@org_golang_x_tools//go/analysis:go_default_library", ], ) diff --git a/tools/github/BUILD b/tools/github/BUILD index aad088d13..7d0a179f7 100644 --- a/tools/github/BUILD +++ b/tools/github/BUILD @@ -9,7 +9,7 @@ go_binary( deps = [ "//tools/github/nogo", "//tools/github/reviver", - "@com_github_google_go_github_v28//github:go_default_library", + "@com_github_google_go_github_v32//github:go_default_library", "@org_golang_x_oauth2//:go_default_library", ], ) diff --git a/tools/github/nogo/BUILD b/tools/github/nogo/BUILD index 19b7eec4d..4259fe94c 100644 --- a/tools/github/nogo/BUILD +++ b/tools/github/nogo/BUILD @@ -11,6 +11,6 @@ go_library( ], deps = [ "//tools/nogo", - "@com_github_google_go_github_v28//github:go_default_library", + "@com_github_google_go_github_v32//github:go_default_library", ], ) diff --git 
a/tools/github/reviver/BUILD b/tools/github/reviver/BUILD index 7d78480a7..fc54782f5 100644 --- a/tools/github/reviver/BUILD +++ b/tools/github/reviver/BUILD @@ -12,7 +12,7 @@ go_library( visibility = [ "//tools/github:__subpackages__", ], - deps = ["@com_github_google_go_github_v28//github:go_default_library"], + deps = ["@com_github_google_go_github_v32//github:go_default_library"], ) go_test( diff --git a/tools/go_branch.sh b/tools/go_branch.sh index 026733d3c..392e40619 100755 --- a/tools/go_branch.sh +++ b/tools/go_branch.sh @@ -16,14 +16,25 @@ set -xeou pipefail +# Remember our current directory. +declare orig_dir +orig_dir=$(pwd) +readonly orig_dir + +# Record the current working commit. +declare head +head=$(git describe --always) +readonly head + # Create a temporary working directory, and ensure that this directory and all # subdirectories are cleaned up upon exit. declare tmp_dir tmp_dir=$(mktemp -d) readonly tmp_dir finish() { - cd / # Leave tmp_dir. - rm -rf "${tmp_dir}" + cd "${orig_dir}" # Leave tmp_dir. + rm -rf "${tmp_dir}" # Remove all contents. + git checkout -f "${head}" # Restore commit. } trap finish EXIT @@ -37,7 +48,7 @@ readonly module origpwd othersrc # Build an amd64 & arm64 gopath. declare -r go_amd64="${tmp_dir}/amd64" declare -r go_arm64="${tmp_dir}/arm64" -make build BAZEL_OPTIONS="" TARGETS="//:gopath" 2>/dev/null +make build BAZEL_OPTIONS="" TARGETS="//:gopath" rsync --recursive --delete --copy-links bazel-bin/gopath/ "${go_amd64}" make build BAZEL_OPTIONS=--config=cross-aarch64 TARGETS="//:gopath" 2>/dev/null rsync --recursive --delete --copy-links bazel-bin/gopath/ "${go_arm64}" @@ -70,11 +81,6 @@ declare -r go_merged="${tmp_dir}/merged" rsync --recursive "${go_amd64}/" "${go_merged}" rsync --recursive "${go_arm64}/" "${go_merged}" -# Record the current working commit. -declare head -head=$(git describe --always) -readonly head - # We expect to have an existing go branch that we will use as the basis for this # commit. 
That branch may be empty, but it must exist. We search for this branch # using the local branch, the "origin" branch, and other remotes, in order. diff --git a/tools/make_apt.sh b/tools/make_apt.sh index 302ed8aa3..68f6973ec 100755 --- a/tools/make_apt.sh +++ b/tools/make_apt.sh @@ -119,7 +119,11 @@ for dir in "${root}"/pool/*/binary-*; do arches+=("${arch}") repo_packages="${release}"/main/"${name}" mkdir -p "${repo_packages}" - (cd "${root}" && apt-ftparchive --arch "${arch}" packages pool > "${repo_packages}"/Packages) + (cd "${root}" && apt-ftparchive packages "${dir##${root}/}" > "${repo_packages}"/Packages) + if ! [[ -s "${repo_packages}"/Packages ]]; then + echo "Packages file is size zero." >&2 + exit 1 + fi (cd "${repo_packages}" && cat Packages | gzip > Packages.gz) (cd "${repo_packages}" && cat Packages | xz > Packages.xz) done diff --git a/tools/nogo/BUILD b/tools/nogo/BUILD index 566e0889e..7976c7521 100644 --- a/tools/nogo/BUILD +++ b/tools/nogo/BUILD @@ -38,34 +38,34 @@ go_library( "//tools/checkunsafe", "@co_honnef_go_tools//staticcheck:go_default_library", "@co_honnef_go_tools//stylecheck:go_default_library", - "@org_golang_x_tools//go/analysis:go_tool_library", - "@org_golang_x_tools//go/analysis/internal/facts:go_tool_library", - "@org_golang_x_tools//go/analysis/passes/asmdecl:go_tool_library", - "@org_golang_x_tools//go/analysis/passes/assign:go_tool_library", - "@org_golang_x_tools//go/analysis/passes/atomic:go_tool_library", - "@org_golang_x_tools//go/analysis/passes/bools:go_tool_library", - "@org_golang_x_tools//go/analysis/passes/buildtag:go_tool_library", - "@org_golang_x_tools//go/analysis/passes/cgocall:go_tool_library", - "@org_golang_x_tools//go/analysis/passes/composite:go_tool_library", - "@org_golang_x_tools//go/analysis/passes/copylock:go_tool_library", - "@org_golang_x_tools//go/analysis/passes/errorsas:go_tool_library", - "@org_golang_x_tools//go/analysis/passes/httpresponse:go_tool_library", - 
"@org_golang_x_tools//go/analysis/passes/loopclosure:go_tool_library", - "@org_golang_x_tools//go/analysis/passes/lostcancel:go_tool_library", - "@org_golang_x_tools//go/analysis/passes/nilfunc:go_tool_library", - "@org_golang_x_tools//go/analysis/passes/nilness:go_tool_library", - "@org_golang_x_tools//go/analysis/passes/printf:go_tool_library", - "@org_golang_x_tools//go/analysis/passes/shadow:go_tool_library", - "@org_golang_x_tools//go/analysis/passes/shift:go_tool_library", - "@org_golang_x_tools//go/analysis/passes/stdmethods:go_tool_library", - "@org_golang_x_tools//go/analysis/passes/stringintconv:go_tool_library", - "@org_golang_x_tools//go/analysis/passes/structtag:go_tool_library", - "@org_golang_x_tools//go/analysis/passes/tests:go_tool_library", - "@org_golang_x_tools//go/analysis/passes/unmarshal:go_tool_library", - "@org_golang_x_tools//go/analysis/passes/unreachable:go_tool_library", - "@org_golang_x_tools//go/analysis/passes/unsafeptr:go_tool_library", - "@org_golang_x_tools//go/analysis/passes/unusedresult:go_tool_library", - "@org_golang_x_tools//go/gcexportdata:go_tool_library", + "@org_golang_x_tools//go/analysis:go_default_library", + "@org_golang_x_tools//go/analysis/internal/facts:go_default_library", + "@org_golang_x_tools//go/analysis/passes/asmdecl:go_default_library", + "@org_golang_x_tools//go/analysis/passes/assign:go_default_library", + "@org_golang_x_tools//go/analysis/passes/atomic:go_default_library", + "@org_golang_x_tools//go/analysis/passes/bools:go_default_library", + "@org_golang_x_tools//go/analysis/passes/buildtag:go_default_library", + "@org_golang_x_tools//go/analysis/passes/cgocall:go_default_library", + "@org_golang_x_tools//go/analysis/passes/composite:go_default_library", + "@org_golang_x_tools//go/analysis/passes/copylock:go_default_library", + "@org_golang_x_tools//go/analysis/passes/errorsas:go_default_library", + "@org_golang_x_tools//go/analysis/passes/httpresponse:go_default_library", + 
"@org_golang_x_tools//go/analysis/passes/loopclosure:go_default_library", + "@org_golang_x_tools//go/analysis/passes/lostcancel:go_default_library", + "@org_golang_x_tools//go/analysis/passes/nilfunc:go_default_library", + "@org_golang_x_tools//go/analysis/passes/nilness:go_default_library", + "@org_golang_x_tools//go/analysis/passes/printf:go_default_library", + "@org_golang_x_tools//go/analysis/passes/shadow:go_default_library", + "@org_golang_x_tools//go/analysis/passes/shift:go_default_library", + "@org_golang_x_tools//go/analysis/passes/stdmethods:go_default_library", + "@org_golang_x_tools//go/analysis/passes/stringintconv:go_default_library", + "@org_golang_x_tools//go/analysis/passes/structtag:go_default_library", + "@org_golang_x_tools//go/analysis/passes/tests:go_default_library", + "@org_golang_x_tools//go/analysis/passes/unmarshal:go_default_library", + "@org_golang_x_tools//go/analysis/passes/unreachable:go_default_library", + "@org_golang_x_tools//go/analysis/passes/unsafeptr:go_default_library", + "@org_golang_x_tools//go/analysis/passes/unusedresult:go_default_library", + "@org_golang_x_tools//go/gcexportdata:go_default_library", ], ) diff --git a/tools/nogo/defs.bzl b/tools/nogo/defs.bzl index 161ea972e..0c48a7a5a 100644 --- a/tools/nogo/defs.bzl +++ b/tools/nogo/defs.bzl @@ -188,6 +188,14 @@ def _nogo_aspect_impl(target, ctx): # All work is done in the shadow properties for go rules. For a proto # library, we simply skip the analysis portion but still need to return a # valid NogoInfo to reference the generated binary. + # + # Note that we almost exclusively use go_library, not go_tool_library. + # This is because nogo is manually annotated, so the go_tool_library kind + # is not needed to avoid dependency loops. Unfortunately, bazel coverdata + # is exported *only* as a go_tool_library. This does not cause a problem, + # since there is guaranteed to be no conflict. 
However for consistency, + # we should not introduce new go_tool_library dependencies unless strictly + # necessary. if ctx.rule.kind in ("go_library", "go_tool_library", "go_binary", "go_test"): srcs = ctx.rule.files.srcs deps = ctx.rule.attr.deps diff --git a/tools/rules_go.patch b/tools/rules_go_symbols.patch index 5e1e87084..46767f169 100644 --- a/tools/rules_go.patch +++ b/tools/rules_go_symbols.patch @@ -2,13 +2,13 @@ diff --git a/go/private/rules/test.bzl b/go/private/rules/test.bzl index 17516ad7..76b6c68c 100644 --- a/go/private/rules/test.bzl +++ b/go/private/rules/test.bzl -@@ -121,9 +121,6 @@ def _go_test_impl(ctx): +@@ -117,9 +117,6 @@ def _go_test_impl(ctx): ) - + test_gc_linkopts = gc_linkopts(ctx) - if not go.mode.debug: - # Disable symbol table and DWARF generation for test binaries. - test_gc_linkopts.extend(["-s", "-w"]) - - # Now compile the test binary itself - test_library = GoLibrary( + + # Link in the run_dir global for bzltestutil + test_gc_linkopts.extend(["-X", "github.com/bazelbuild/rules_go/go/tools/bzltestutil.RunDir=" + run_dir]) diff --git a/tools/rules_go_visibility.patch b/tools/rules_go_visibility.patch new file mode 100644 index 000000000..e5bb2e3d5 --- /dev/null +++ b/tools/rules_go_visibility.patch @@ -0,0 +1,22 @@ +diff --git a/third_party/org_golang_x_tools-gazelle.patch b/third_party/org_golang_x_tools-gazelle.patch +index 7bdacff5..2fe9ce93 100644 +--- a/third_party/org_golang_x_tools-gazelle.patch ++++ b/third_party/org_golang_x_tools-gazelle.patch +@@ -2054,7 +2054,7 @@ diff -urN b/go/analysis/internal/facts/BUILD.bazel c/go/analysis/internal/facts/ + + "imports.go", + + ], + + importpath = "golang.org/x/tools/go/analysis/internal/facts", +-+ visibility = ["//go/analysis:__subpackages__"], +++ visibility = ["//visibility:public"], + + deps = [ + + "//go/analysis", + + "//go/types/objectpath", +@@ -2078,7 +2078,7 @@ diff -urN b/go/analysis/internal/facts/BUILD.bazel c/go/analysis/internal/facts/ + +alias( + + name = 
"go_default_library", + + actual = ":facts", +-+ visibility = ["//go/analysis:__subpackages__"], +++ visibility = ["//visibility:public"], + +) + + + +go_test( |