diff options
248 files changed, 8956 insertions, 2173 deletions
@@ -384,3 +384,7 @@ nogo: ## Surfaces all nogo findings. @$(call submake,build OPTIONS="--build_tag_filters nogo" TARGETS="//...") @$(call submake,run TARGETS="//tools/github" ARGS="-path=$(BUILD_ROOT) -dry-run nogo") .PHONY: nogo + +gazelle: ## Runs gazelle to update WORKSPACE. + @$(call submake,run TARGETS="//:gazelle" ARGS="update-repos -from_file=go.mod -prune") +.PHONY: gazelle @@ -43,7 +43,7 @@ http_archive( ) http_archive( - name = "io_bazel_rules_go_bazel3", # To replace the above. + name = "io_bazel_rules_go_bazel3", # To replace the above. patch_args = ["-p1"], patches = [ "//tools/nogo:io_bazel_rules_go-visibility.patch", @@ -56,7 +56,7 @@ http_archive( ) http_archive( - name = "bazel_gazelle_bazel3", # To replace the above. + name = "bazel_gazelle_bazel3", # To replace the above. sha256 = "bfd86b3cbe855d6c16c6fce60d76bd51f5c8dbc9cfcaef7a2bb5c1aafd0710e8", urls = [ "https://mirror.bazel.build/github.com/bazelbuild/bazel-gazelle/releases/download/v0.21.0/bazel-gazelle-v0.21.0.tar.gz", @@ -81,8 +81,8 @@ gazelle_dependencies() go_repository( name = "org_golang_x_sys", importpath = "golang.org/x/sys", - sum = "h1:uYVVQ9WP/Ds2ROhcaGPeIdVq0RIXVLwsHlnvJ+cT1So=", - version = "v0.0.0-20200302150141-5c8b2ff67527", + sum = "h1:xhmwyvizuTgC2qz7ZlMluP20uW+C3Rm0FD/WLDX8884=", + version = "v0.0.0-20200323222414-85ca7c5b95cd", ) # Load C++ rules. @@ -126,7 +126,7 @@ http_archive( ) http_archive( - name = "bazel_toolchains_bazel3", # To replace the above. + name = "bazel_toolchains_bazel3", # To replace the above. sha256 = "144290c4166bd67e76a54f96cd504ed86416ca3ca82030282760f0823c10be48", strip_prefix = "bazel-toolchains-3.1.1", urls = [ @@ -208,8 +208,7 @@ http_archive( go_repository( name = "com_github_sirupsen_logrus", importpath = "github.com/sirupsen/logrus", - replace = "github.com/Sirupsen/logrus", - sum = "h1:cWjBmzJnL1sO88XdqJYmq7aiWClqXIQQMJ3Utgy1f+I=", + sum = "h1:SPIRibHv4MatM3XXNO2BJeFLZwZ2LvZgfQ5+UNI2im4=", version = "v1.4.2", ) @@ -330,8 +329,8 @@ go_repository( go_repository( name = "org_golang_x_net", importpath = "golang.org/x/net", - sum = "h1:vGXIOMxbNfDTk/aXCmfdLgkrSV+Z2tcbze+pEc3v5W4=", - version = "v0.0.0-20200625001655-4c5254603344", + sum = "h1:VvcQYSHwXgi7W+TpUR6A9g6Up98WAHf3f/ulnJ62IyA=", + version = "v0.0.0-20200822124328-c89045814202", ) go_repository( @@ -358,15 +357,15 @@ go_repository( go_repository( name = "org_golang_x_tools", importpath = "golang.org/x/tools", - sum = "h1:YAl/dx/kLsMMIWGqfhFHW9ckqGhmq7Ki0dfoKAgvFTE=", - version = "v0.0.0-20200707200213-416e8f4faf8a", + sum = "h1:k7tVuG0g1JwmD3Jh8oAl1vQ1C3jb4Hi/dUl1wWDBJpQ=", + version = "v0.0.0-20200918232735-d647fc253266", ) go_repository( name = "org_golang_x_xerrors", importpath = "golang.org/x/xerrors", - sum = "h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=", - version = "v0.0.0-20191204190536-9bdfabe68543", + sum = "h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=", + version = "v0.0.0-20200804184101-5ec99f83aff1", ) go_repository( @@ -470,8 +469,8 @@ go_repository( go_repository( name = "com_github_microsoft_go_winio", importpath = "github.com/Microsoft/go-winio", - sum = "h1:ygIc8M6trr62pF5DucadTWGdEB4mEyvzi0e2nbcmcyA=", - version = "v0.4.15-0.20190919025122-fc70bd9a86b5", + sum = "h1:9pygWVFqbY9lPxM0peffumuVDyMuIMzNLyO9uFjJuQo=", + version = "v0.4.15-0.20200908182639-5b44b70ab3ab", ) go_repository( @@ -484,8 +483,8 @@ go_repository( go_repository( name = "org_uber_go_atomic", importpath = "go.uber.org/atomic", - sum = "h1:Ezj3JGmsOnG1MoRWQkPBsKLe9DwWD9QeXzTRzzldNVk=", - version = "v1.6.0", + sum = "h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw=", + version = "v1.7.0", ) go_repository( @@ -660,8 +659,8 @@ go_repository( go_repository( name = "com_github_google_go_cmp", importpath = "github.com/google/go-cmp", - sum = "h1:/QaMHBdZ26BB3SSst0Iwl10Epc+xhTquomWX0oZEB6w=", - version = "v0.5.0", + sum = "h1:JFrFEBb2xKufg6XkJsJr+WbKb4FQlURi5RUcBveYu9k=", + version = "v0.5.1", ) go_repository( @@ -786,15 +785,15 @@ go_repository( go_repository( name = "com_github_urfave_cli", importpath = "github.com/urfave/cli", - sum = "h1:MCfT24H3f//U5+UCrZp1/riVO3B50BovxtDiNn0XKkk=", - version = "v0.0.0-20171014202726-7bc6a0acffa5", + sum = "h1:gsqYFH8bb9ekPA12kRo0hfjngWQjkJPlN9R0N78BoUo=", + version = "v1.22.2", ) go_repository( name = "com_github_yuin_goldmark", importpath = "github.com/yuin/goldmark", - sum = "h1:5tjfNdR2ki3yYQ842+eX2sQHeiwpKJ0RnHO4IYOc4V8=", - version = "v1.1.32", + sum = "h1:ruQGxdhGHe7FWOJPT0mKs5+pD2Xs1Bm/kdGlHO04FmM=", + version = "v1.2.1", ) go_repository( @@ -1024,8 +1023,8 @@ go_repository( go_repository( name = "com_github_vishvananda_netns", importpath = "github.com/vishvananda/netns", - sum = "h1:mjAZxE1nh8yvuwhGHpdDqdhtNu2dgbpk93TwoXuk5so=", - version = "v0.0.0-20200520041808-52d707b772fe", + sum = "h1:4hwBBUfQCFe3Cym0ZtKyq7L16eZUtYKs+BaHDN6mAns=", + version = "v0.0.0-20200728191858-db3c7e526aae", ) go_repository( @@ -1073,6 +1072,48 @@ go_repository( go_repository( name = "com_github_dpjacques_clockwork", importpath = "github.com/dpjacques/clockwork", - sum = "h1:7krODee+eIlZYoLiEDmP1kLFNCvd0bQ0eEXOympdN6U=", - version = "v0.1.1-0.20190114191937-d864eecc357b", + sum = "h1:l+j1wSnHcimOzeeKxtspsl6tCBTyikdYxcWqFZ+Ho2c=", + version = "v0.1.1-0.20200827220843-c1f524b839be", +) + +go_repository( + name = "com_github_cilium_ebpf", + importpath = "github.com/cilium/ebpf", + sum = "h1:i8+1fuPLjSgAYXUyBlHNhFwjcfAsP4ufiuH1+PWkyDU=", + version = "v0.0.0-20200110133405-4032b1d8aae3", +) + +go_repository( + name = "com_github_coreos_go_systemd_v22", + importpath = "github.com/coreos/go-systemd/v22", + sum = "h1:XJIw/+VlJ+87J+doOxznsAWIdmWuViOVhkQamW5YV28=", + version = "v22.0.0", +) + +go_repository( + name = "com_github_cpuguy83_go_md2man_v2", + importpath = "github.com/cpuguy83/go-md2man/v2", + sum = "h1:EoUDS0afbrsXAZ9YQ9jdu/mZ2sXgT1/2yyNng4PGlyM=", + version = "v2.0.0", +) + +go_repository( + name = "com_github_godbus_dbus_v5", + importpath = "github.com/godbus/dbus/v5", + sum = "h1:ZqHaoEF7TBzh4jzPmqVhE/5A1z9of6orkAe5uHoAeME=", + version = "v5.0.3", +) + +go_repository( + name = "com_github_russross_blackfriday_v2", + importpath = "github.com/russross/blackfriday/v2", + sum = "h1:lPqVAte+HuHNfhJ/0LC98ESWRz8afy9tM/0RK8m9o+Q=", + version = "v2.0.1", +) + +go_repository( + name = "com_github_shurcool_sanitized_anchor_name", + importpath = "github.com/shurcooL/sanitized_anchor_name", + sum = "h1:PdmoCO6wvbs+7yrJyMORt4/BmY5IYyJwS/kOiWx8mHo=", + version = "v1.0.0", ) diff --git a/g3doc/user_guide/networking.md b/g3doc/user_guide/networking.md index 4aa394c91..95f675633 100644 --- a/g3doc/user_guide/networking.md +++ b/g3doc/user_guide/networking.md @@ -2,9 +2,9 @@ [TOC] -gVisor implements its own network stack called [netstack][netstack]. All aspects -of the network stack are handled inside the Sentry — including TCP connection -state, control messages, and packet assembly — keeping it isolated from the host +gVisor implements its own network stack called netstack. All aspects of the +network stack are handled inside the Sentry — including TCP connection state, +control messages, and packet assembly — keeping it isolated from the host network stack. Data link layer packets are written directly to the virtual device inside the network namespace setup by Docker or Kubernetes. @@ -82,4 +82,3 @@ Offload (GSO) to run with a kernel that is newer than 3.17. Add the } ``` -[netstack]: https://github.com/google/netstack @@ -1,14 +1,15 @@ module gvisor.dev/gvisor -go 1.14 +go 1.15 replace github.com/Sirupsen/logrus => github.com/sirupsen/logrus v1.6.0 require ( cloud.google.com/go v0.52.1-0.20200122224058-0482b626c726 // indirect - github.com/Microsoft/go-winio v0.4.15-0.20190919025122-fc70bd9a86b5 // indirect + github.com/Microsoft/go-winio v0.4.15-0.20200908182639-5b44b70ab3ab // indirect github.com/Microsoft/hcsshim v0.8.6 // indirect github.com/cenkalti/backoff v1.1.1-0.20190506075156-2146c9339422 // indirect + github.com/cilium/ebpf v0.0.0-20200110133405-4032b1d8aae3 // indirect github.com/containerd/cgroups v0.0.0-20181219155423-39b18af02c41 // indirect github.com/containerd/containerd v1.3.4 // indirect github.com/containerd/continuity v0.0.0-20200710164510-efbc4488d8fe // indirect @@ -17,17 +18,19 @@ require ( github.com/containerd/ttrpc v0.0.0-20200121165050-0be804eadb15 // indirect github.com/containerd/typeurl v0.0.0-20200205145503-b45ef1f1f737 // indirect github.com/coreos/go-systemd v0.0.0-20191104093116-d3cd4ed1dbcf // indirect + github.com/coreos/go-systemd/v22 v22.0.0 // indirect + github.com/cpuguy83/go-md2man/v2 v2.0.0 // indirect github.com/docker/distribution v2.7.1-0.20190205005809-0d3efadf0154+incompatible // indirect github.com/docker/docker v1.4.2-0.20191028175130-9e7d5ac5ea55 // indirect github.com/docker/go-connections v0.3.0 // indirect github.com/docker/go-events v0.0.0-20190806004212-e31b211e4f1c // indirect github.com/docker/go-units v0.4.0 // indirect - github.com/dpjacques/clockwork v0.1.1-0.20190114191937-d864eecc357b // indirect + github.com/dpjacques/clockwork v0.1.1-0.20200827220843-c1f524b839be // indirect github.com/godbus/dbus v0.0.0-20190422162347-ade71ed3457e // indirect github.com/gofrs/flock v0.6.1-0.20180915234121-886344bea079 // indirect github.com/gogo/googleapis v1.4.0 // indirect github.com/golang/protobuf v1.4.2 // indirect - github.com/google/go-cmp v0.5.0 // indirect + github.com/google/go-cmp v0.5.1 // indirect github.com/google/go-github/v28 v28.1.2-0.20191108005307-e555eab49ce8 // indirect github.com/google/subcommands v1.0.2-0.20190508160503-636abe8753b8 // indirect github.com/hashicorp/go-multierror v1.0.0 // indirect @@ -39,13 +42,13 @@ require ( github.com/opencontainers/runtime-spec v1.0.2-0.20181111125026-1722abf79c2f // indirect github.com/pborman/uuid v1.2.0 // indirect github.com/syndtr/gocapability v0.0.0-20180916011248-d98352740cb2 // indirect - github.com/urfave/cli v0.0.0-20171014202726-7bc6a0acffa5 // indirect + github.com/urfave/cli v1.22.2 // indirect github.com/vishvananda/netlink v1.0.1-0.20190930145447-2ec5bdc52b86 // indirect - github.com/vishvananda/netns v0.0.0-20200520041808-52d707b772fe // indirect - go.uber.org/atomic v1.6.0 // indirect + github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae // indirect + go.uber.org/atomic v1.7.0 // indirect go.uber.org/multierr v1.2.0 // indirect golang.org/x/time v0.0.0-20191024005414-555d28b269f0 // indirect - golang.org/x/tools v0.0.0-20200707200213-416e8f4faf8a // indirect + golang.org/x/tools v0.0.0-20200918232735-d647fc253266 // indirect google.golang.org/grpc v1.29.0 // indirect gopkg.in/yaml.v2 v2.2.8 // indirect gotest.tools v2.2.0+incompatible // indirect @@ -18,13 +18,15 @@ github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03 github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/Microsoft/go-winio v0.4.14 h1:+hMXMk01us9KgxGb7ftKQt2Xpf5hH/yky+TDA+qxleU= github.com/Microsoft/go-winio v0.4.14/go.mod h1:qXqCSQ3Xa7+6tgxaGTIe4Kpcdsi+P8jBhyzoq1bpyYA= -github.com/Microsoft/go-winio v0.4.15-0.20190919025122-fc70bd9a86b5 h1:ygIc8M6trr62pF5DucadTWGdEB4mEyvzi0e2nbcmcyA= github.com/Microsoft/go-winio v0.4.15-0.20190919025122-fc70bd9a86b5/go.mod h1:tTuCMEN+UleMWgg9dVx4Hu52b1bJo+59jBh3ajtinzw= +github.com/Microsoft/go-winio v0.4.15-0.20200908182639-5b44b70ab3ab h1:9pygWVFqbY9lPxM0peffumuVDyMuIMzNLyO9uFjJuQo= +github.com/Microsoft/go-winio v0.4.15-0.20200908182639-5b44b70ab3ab/go.mod h1:tTuCMEN+UleMWgg9dVx4Hu52b1bJo+59jBh3ajtinzw= github.com/Microsoft/hcsshim v0.8.6/go.mod h1:Op3hHsoHPAvb6lceZHDtd9OkTew38wNoXnJs8iY7rUg= github.com/Microsoft/hcsshim v0.8.7/go.mod h1:OHd7sQqRFrYd3RmSgbgji+ctCwkbq2wbEYNSzOYtcBQ= github.com/Microsoft/hcsshim v0.8.8/go.mod h1:5692vkUqntj1idxauYlpoINNKeqCiG6Sg38RRsjT5y8= -github.com/Microsoft/hcsshim v0.8.9 h1:VrfodqvztU8YSOvygU+DN1BGaSGxmrNfqOv5oOuX2Bk= github.com/Microsoft/hcsshim v0.8.9/go.mod h1:5692vkUqntj1idxauYlpoINNKeqCiG6Sg38RRsjT5y8= +github.com/Microsoft/hcsshim v0.8.10 h1:k5wTrpnVU2/xv8ZuzGkbXVd3js5zJ8RnumPo5RxiIxU= +github.com/Microsoft/hcsshim v0.8.10/go.mod h1:g5uw8EV2mAlzqe94tfNBNdr89fnbD/n3HV0OhsddkmM= github.com/blang/semver v3.1.0+incompatible/go.mod h1:kRBLl5iJ+tD4TcOOxsy/0fnwebNt5EWlYSAyrTnjyyk= github.com/cenkalti/backoff v1.1.1-0.20190506075156-2146c9339422 h1:8eZxmY1yvxGHzdzTEhI09npjMVGzNAdrqzruTX6jcK4= github.com/cenkalti/backoff v1.1.1-0.20190506075156-2146c9339422/go.mod h1:b6Nc7NRH5C4aCISLry0tLnTjcuTEvoiqcWDdsU0sOGM= @@ -32,12 +34,13 @@ github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= +github.com/cilium/ebpf v0.0.0-20200110133405-4032b1d8aae3/go.mod h1:MA5e5Lr8slmEg9bt0VpxxWqJlO4iwu3FBdHUzV7wQVg= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= github.com/containerd/cgroups v0.0.0-20181219155423-39b18af02c41 h1:5yg0k8gqOssNLsjjCtXIADoPbAtUtQZJfC8hQ4r2oFY= github.com/containerd/cgroups v0.0.0-20181219155423-39b18af02c41/go.mod h1:X9rLEHIqSf/wfK8NsPqxJmeZgW4pcfzdXITDrUSJ6uI= -github.com/containerd/cgroups v0.0.0-20190919134610-bf292b21730f h1:tSNMc+rJDfmYntojat8lljbt1mgKNpTxUZJsSzJ9Y1s= -github.com/containerd/cgroups v0.0.0-20190919134610-bf292b21730f/go.mod h1:OApqhQ4XNSNC13gXIwDjhOQxjWa/NxkwZXJ1EvqT0ko= +github.com/containerd/cgroups v0.0.0-20200531161412-0dbf7f05ba59 h1:qWj4qVYZ95vLWwqyNJCQg7rDsG5wPdze0UaPolH7DUk= +github.com/containerd/cgroups v0.0.0-20200531161412-0dbf7f05ba59/go.mod h1:pA0z1pT8KYB3TCXK/ocprsh7MAkoW8bZVzPdih9snmM= github.com/containerd/console v0.0.0-20180822173158-c12b1e7919c1/go.mod h1:Tj/on1eG8kiEhd0+fhSDzsPAFESxzBBvdyEgyryXffw= github.com/containerd/console v0.0.0-20191206165004-02ecf6a7291e h1:GdiIYd8ZDOrT++e1NjhSD4rGt9zaJukHm4rt5F4mRQc= github.com/containerd/console v0.0.0-20191206165004-02ecf6a7291e/go.mod h1:8Pf4gM6VEbTNRIT26AyyU7hxdQU3MvAvxVI0sc00XBE= @@ -59,9 +62,12 @@ github.com/containerd/ttrpc v0.0.0-20200121165050-0be804eadb15/go.mod h1:UAxOpgT github.com/containerd/typeurl v0.0.0-20180627222232-a93fcdb778cd/go.mod h1:Cm3kwCdlkCfMSHURc+r6fwoGH6/F1hH3S4sg0rLFWPc= github.com/containerd/typeurl v0.0.0-20200205145503-b45ef1f1f737 h1:HovfQDS/K3Mr7eyS0QJLxE1CbVUhjZCl6g3OhFJgP1o= github.com/containerd/typeurl v0.0.0-20200205145503-b45ef1f1f737/go.mod h1:TB1hUtrpaiO88KEK56ijojHS1+NeF0izUACaJW2mdXg= -github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= github.com/coreos/go-systemd v0.0.0-20191104093116-d3cd4ed1dbcf h1:iW4rZ826su+pqaw19uhpSCzhj44qo35pNgKFGqzDKkU= github.com/coreos/go-systemd v0.0.0-20191104093116-d3cd4ed1dbcf/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= +github.com/coreos/go-systemd/v22 v22.0.0 h1:XJIw/+VlJ+87J+doOxznsAWIdmWuViOVhkQamW5YV28= +github.com/coreos/go-systemd/v22 v22.0.0/go.mod h1:xO0FLkIi5MaZafQlIrOotqXZ90ih+1atmu1JpKERPPk= +github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= +github.com/cpuguy83/go-md2man/v2 v2.0.0/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/docker/distribution v2.7.1-0.20190205005809-0d3efadf0154+incompatible h1:dvc1KSkIYTVjZgHf/CTC2diTYC8PzhaA5sFISRfNVrE= @@ -74,8 +80,8 @@ github.com/docker/go-events v0.0.0-20190806004212-e31b211e4f1c h1:+pKlWGMw7gf6bQ github.com/docker/go-events v0.0.0-20190806004212-e31b211e4f1c/go.mod h1:Uw6UezgYA44ePAFQYUehOuCzmy5zmg/+nl2ZfMWGkpA= github.com/docker/go-units v0.4.0 h1:3uh0PgVws3nIA0Q+MwDC8yjEPf9zjRfZZWXZYDct3Tw= github.com/docker/go-units v0.4.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= -github.com/dpjacques/clockwork v0.1.1-0.20190114191937-d864eecc357b h1:7krODee+eIlZYoLiEDmP1kLFNCvd0bQ0eEXOympdN6U= -github.com/dpjacques/clockwork v0.1.1-0.20190114191937-d864eecc357b/go.mod h1:D8mP2A8vVT2GkXqPorSBmhnshhkFBYgzhA90KmJt25Y= +github.com/dpjacques/clockwork v0.1.1-0.20200827220843-c1f524b839be h1:l+j1wSnHcimOzeeKxtspsl6tCBTyikdYxcWqFZ+Ho2c= +github.com/dpjacques/clockwork v0.1.1-0.20200827220843-c1f524b839be/go.mod h1:D8mP2A8vVT2GkXqPorSBmhnshhkFBYgzhA90KmJt25Y= github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= @@ -85,11 +91,12 @@ github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMo github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/godbus/dbus v0.0.0-20190422162347-ade71ed3457e h1:BWhy2j3IXJhjCbC68FptL43tDKIq8FladmaTs3Xs7Z8= github.com/godbus/dbus v0.0.0-20190422162347-ade71ed3457e/go.mod h1:bBOAhwG1umN6/6ZUMtDFBMQR8jRg9O75tm9K00oMsK4= +github.com/godbus/dbus/v5 v5.0.3 h1:ZqHaoEF7TBzh4jzPmqVhE/5A1z9of6orkAe5uHoAeME= +github.com/godbus/dbus/v5 v5.0.3/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/gofrs/flock v0.6.1-0.20180915234121-886344bea079 h1:JFTFz3HZTGmgMz4E1TabNBNJljROSYgja1b4l50FNVs= github.com/gofrs/flock v0.6.1-0.20180915234121-886344bea079/go.mod h1:F1TvTiK9OcQqauNUHlbJvyl9Qa1QvF/gOUDKA14jxHU= github.com/gogo/googleapis v1.4.0 h1:zgVt4UpGxcqVOw97aRGxT4svlcmdK35fynLNctY32zI= github.com/gogo/googleapis v1.4.0/go.mod h1:5YRNX2z1oM5gXdAkurHa942MDgEJyk02w4OecKY87+c= -github.com/gogo/protobuf v1.2.1/go.mod h1:hp+jE20tsWTFYpLwKvXlhS1hjn+gTNwPg2I6zVXpSg4= github.com/gogo/protobuf v1.3.1 h1:DqDEcV5aeaTmdFBePNpYsp3FlcVH/2ISVVM9Qf8PSls= github.com/gogo/protobuf v1.3.1/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXPKa29o= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= @@ -118,8 +125,8 @@ github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5a github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.0 h1:/QaMHBdZ26BB3SSst0Iwl10Epc+xhTquomWX0oZEB6w= -github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.1 h1:JFrFEBb2xKufg6XkJsJr+WbKb4FQlURi5RUcBveYu9k= +github.com/google/go-cmp v0.5.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-github/v28 v28.1.2-0.20191108005307-e555eab49ce8 h1:zOOUQavr8D4AZrcV4ylUpbGa5j3jfeslN6Xculz3tVU= github.com/google/go-github/v28 v28.1.2-0.20191108005307-e555eab49ce8/go.mod h1:g82e6OHbJ0WYrYeOrid1MMfHAtqjxBz+N74tfAt9KrQ= github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck= @@ -146,7 +153,6 @@ github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANyt github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= github.com/jstemmer/go-junit-report v0.9.1 h1:6QPYqodiu3GuPL+7mfx+NwDdp2eTkp9IfEUpgAwUN0o= github.com/jstemmer/go-junit-report v0.9.1/go.mod h1:Brl9GWCQeLvo8nXZwPNNblvFj/XSXhF0NWZEnDohbsk= -github.com/kisielk/errcheck v1.1.0/go.mod h1:EZBBE59ingxPouuu3KfxchcWSUPOHkagtvWXihfKN4Q= github.com/kisielk/errcheck v1.2.0/go.mod h1:/BMXB+zMLi60iA8Vv6Ksmxu/1UDYcXs4uQLJ+jE2L00= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= @@ -171,10 +177,11 @@ github.com/opencontainers/image-spec v1.0.1/go.mod h1:BtxoFyWECRxE4U/7sNtV5W15zM github.com/opencontainers/runc v0.0.0-20190115041553-12f6a991201f/go.mod h1:qT5XzbpPznkRYVz/mWwUaVBUv2rmF59PVA73FjuZG0U= github.com/opencontainers/runc v0.1.1 h1:GlxAyO6x8rfZYN9Tt0Kti5a/cP41iuiO2yYT0IJGY8Y= github.com/opencontainers/runc v0.1.1/go.mod h1:qT5XzbpPznkRYVz/mWwUaVBUv2rmF59PVA73FjuZG0U= -github.com/opencontainers/runtime-spec v0.1.2-0.20190507144316-5b71a03e2700/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/C5hL+LL3ko9hs6T5lQ0= github.com/opencontainers/runtime-spec v1.0.1/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/C5hL+LL3ko9hs6T5lQ0= github.com/opencontainers/runtime-spec v1.0.2-0.20181111125026-1722abf79c2f h1:Pyp2f/uuhJIcUgnIeZaAbwOcyNz8TBlEe6mPpC8kXq8= github.com/opencontainers/runtime-spec v1.0.2-0.20181111125026-1722abf79c2f/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/C5hL+LL3ko9hs6T5lQ0= +github.com/opencontainers/runtime-spec v1.0.2 h1:UfAcuLBJB9Coz72x1hgl8O5RVzTdNiaglX6v2DM6FI0= +github.com/opencontainers/runtime-spec v1.0.2/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/C5hL+LL3ko9hs6T5lQ0= github.com/pborman/uuid v1.2.0 h1:J7Q5mO4ysT1dv8hyrUGHb9+ooztCXu1D8MY8DZYsu3g= github.com/pborman/uuid v1.2.0/go.mod h1:X/NO0urCmaxf9VXbdlT7C2Yzkj2IKimNn4k+gtPdI/k= github.com/pkg/errors v0.8.1-0.20171018195549-f15c970de5b7/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= @@ -186,6 +193,8 @@ github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1: github.com/prometheus/procfs v0.0.0-20180125133057-cb4147076ac7/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= github.com/prometheus/procfs v0.0.0-20190522114515-bc1a522cf7b1/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= +github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= github.com/sirupsen/logrus v1.0.4-0.20170822132746-89742aefa4b2/go.mod h1:pMByvHTf9Beacp5x1UXfOR9xyW/9antXMhjMPG0dEzc= github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q= github.com/sirupsen/logrus v1.4.2 h1:SPIRibHv4MatM3XXNO2BJeFLZwZ2LvZgfQ5+UNI2im4= @@ -199,18 +208,18 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/syndtr/gocapability v0.0.0-20180916011248-d98352740cb2 h1:b6uOv7YOFK0TYG7HtkIgExQo+2RdLuwRft63jn2HWj8= github.com/syndtr/gocapability v0.0.0-20180916011248-d98352740cb2/go.mod h1:hkRG7XYTFWNJGYcbNJQlaLq0fg1yr4J4t/NcTQtrfww= -github.com/urfave/cli v0.0.0-20171014202726-7bc6a0acffa5/go.mod h1:70zkFmudgCuE/ngEzBv17Jvp/497gISqfk5gWijbERA= +github.com/urfave/cli v1.22.2/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0= github.com/vishvananda/netlink v1.0.1-0.20190930145447-2ec5bdc52b86 h1:7SWt9pGCMaw+N1ZhRsaLKaYNviFhxambdoaoYlDqz1w= github.com/vishvananda/netlink v1.0.1-0.20190930145447-2ec5bdc52b86/go.mod h1:+SR5DhBJrl6ZM7CoCKvpw5BKroDKQ+PJqOg65H/2ktk= -github.com/vishvananda/netns v0.0.0-20200520041808-52d707b772fe h1:mjAZxE1nh8yvuwhGHpdDqdhtNu2dgbpk93TwoXuk5so= -github.com/vishvananda/netns v0.0.0-20200520041808-52d707b772fe/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0= -github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae h1:4hwBBUfQCFe3Cym0ZtKyq7L16eZUtYKs+BaHDN6mAns= +github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0= +github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU= go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8= go.opencensus.io v0.22.2 h1:75k/FF0Q2YM8QYo07VPddOLBslDt1MZOdEslOHvmzAs= go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= -go.uber.org/atomic v1.6.0 h1:Ezj3JGmsOnG1MoRWQkPBsKLe9DwWD9QeXzTRzzldNVk= -go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= +go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw= +go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/multierr v1.2.0 h1:6I+W7f5VwC5SV9dNrZ3qXrDB9mD0dyGOi/ZJmYw03T4= go.uber.org/multierr v1.2.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= golang.org/x/crypto v0.0.0-20171113213409-9f005a07e0d3/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= @@ -233,7 +242,6 @@ golang.org/x/lint v0.0.0-20190301231843-5614ed5bae6f/go.mod h1:UVdnD1Gm6xHRNCYTk golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/lint v0.0.0-20190409202823-959b441ac422/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/lint v0.0.0-20190909230951-414d861bb4ac/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= -golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/lint v0.0.0-20191125180803-fdd1cda4f05f h1:J5lckAjkw6qYlOZNj90mLYNTEKDvWeuc1yieZ8qUzUE= golang.org/x/lint v0.0.0-20191125180803-fdd1cda4f05f/go.mod h1:5qLYkcX4OjUUV8bRuDixDT3tpyyb+LUpUlRWLxfhWrs= golang.org/x/mobile v0.0.0-20190312151609-d3739f865fa6/go.mod h1:z+o9i4GpDbdi3rU15maQ/Ox0txvL9dWGYEHz965HBQE= @@ -256,8 +264,8 @@ golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20191004110552-13f9640d40b9/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20200625001655-4c5254603344 h1:vGXIOMxbNfDTk/aXCmfdLgkrSV+Z2tcbze+pEc3v5W4= -golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= +golang.org/x/net v0.0.0-20200822124328-c89045814202 h1:VvcQYSHwXgi7W+TpUR6A9g6Up98WAHf3f/ulnJ62IyA= +golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -280,20 +288,18 @@ golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190502145724-3ef323f4f1fd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190507160741-ecd444e8653b/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190514135907-3a4b5fb9f71f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190606165138-5da285871e9c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190624142023-c5567b49c5d0/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191022100944-742c48ecaeb7/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191210023423-ac6580df4449/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200113162924-86b910548bc1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200120151820-655fe14d7479/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200217220822-9197077df867/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd h1:xhmwyvizuTgC2qz7ZlMluP20uW+C3Rm0FD/WLDX8884= golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200523222454-059865788121 h1:rITEj+UZHYC927n8GT97eC3zrpzXdb/voyeOuVKS46o= -golang.org/x/sys v0.0.0-20200523222454-059865788121/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= @@ -302,7 +308,6 @@ golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxb golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/tools v0.0.0-20180221164845-07fd8470d635/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20181030221726-6c7e314b6563/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -319,16 +324,16 @@ golang.org/x/tools v0.0.0-20190628153133-6cdbf07be9d0/go.mod h1:/rFqwRUd4F7ZHNgw golang.org/x/tools v0.0.0-20190816200558-6889da9d5479/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20190911174233-4f2ddba30aff/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191012152004-8de300cfc20a/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191125144606-a911d9008d1f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200117161641-43d50277825c/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/tools v0.0.0-20200707200213-416e8f4faf8a h1:YAl/dx/kLsMMIWGqfhFHW9ckqGhmq7Ki0dfoKAgvFTE= -golang.org/x/tools v0.0.0-20200707200213-416e8f4faf8a/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= +golang.org/x/tools v0.0.0-20200918232735-d647fc253266 h1:k7tVuG0g1JwmD3Jh8oAl1vQ1C3jb4Hi/dUl1wWDBJpQ= +golang.org/x/tools v0.0.0-20200918232735-d647fc253266/go.mod h1:z6u4i615ZeAfBE4XtMziQW1fSVJXACjjbWkB/mvPzlU= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE= google.golang.org/api v0.7.0/go.mod h1:WtwebWUNSVBH/HAw79HIFXZNqEvBhG+Ra+ax0hx3E3M= google.golang.org/api v0.8.0/go.mod h1:o4eAsZoiT+ibD93RtjEohWalFOjRDx6CVaqeizhEnKg= diff --git a/images/Makefile b/images/Makefile index d4b6524ba..12927c509 100644 --- a/images/Makefile +++ b/images/Makefile @@ -23,7 +23,7 @@ ARCH := $(shell uname -m) # tests are using locally-defined images (that are consistent and idempotent). REMOTE_IMAGE_PREFIX ?= gcr.io/gvisor-presubmit LOCAL_IMAGE_PREFIX ?= gvisor.dev/images -ALL_IMAGES := $(subst /,_,$(subst ./,,$(shell find . -name Dockerfile -exec dirname {} \;))) +ALL_IMAGES := $(subst /,_,$(subst ./,,$(shell find . -name Dockerfile -o -name Dockerfile.$(ARCH) | xargs -n 1 dirname | uniq))) ifneq ($(ARCH),$(shell uname -m)) DOCKER_PLATFORM_ARGS := --platform=$(ARCH) else @@ -51,6 +51,7 @@ load-%-images: # ensuring that images will always be sourced using the local files if there # are changes. path = $(subst _,/,$(1)) +dockerfile = $$(if [ -f "$(call path,$(1))/Dockerfile.$(ARCH)" ]; then echo Dockerfile.$(ARCH); else echo Dockerfile; fi) tag = $(shell find $(call path,$(1)) -type f -print | sort | xargs -n 1 sha256sum | sha256sum - | cut -c 1-16) remote_image = $(REMOTE_IMAGE_PREFIX)/$(subst _,/,$(1))_$(ARCH):$(call tag,$(1)) local_image = $(LOCAL_IMAGE_PREFIX)/$(subst _,/,$(1)) @@ -59,11 +60,17 @@ local_image = $(LOCAL_IMAGE_PREFIX)/$(subst _,/,$(1)) # we need to explicitly repull the base layer in order to ensure that the # architecture is correct. Note that we use the term "rebuild" here to avoid # conflicting with the bazel "build" terminology, which is used elsewhere. -rebuild-%: FROM=$(shell grep FROM $(call path,$*)/Dockerfile | cut -d' ' -f2) +rebuild-%: FROM=$(shell grep FROM "$(call path,$*)/$(call dockerfile,$*)" | cut -d' ' -f2) rebuild-%: register-cross + @if ! [ -f "$(call path,$*)/$(call dockerfile,$*)" ]; then \ + (echo "ERROR: Dockerfile for $* not found (is it available for $(ARCH)?)." >&2 && exit 1); \ + fi $(foreach IMAGE,$(FROM),docker pull $(DOCKER_PLATFORM_ARGS) $(IMAGE) &&) \ T=$$(mktemp -d) && cp -a $(call path,$*)/* $$T && \ - docker build $(DOCKER_PLATFORM_ARGS) -t $(call remote_image,$*) $$T && \ + docker build $(DOCKER_PLATFORM_ARGS) \ + -f "$$T/$(call dockerfile,$*)" \ + -t "$(call remote_image,$*)" \ + $$T && \ rm -rf $$T # pull will check the "remote" image and pull if necessary. If the remote image diff --git a/images/basic/mysql/Dockerfile b/images/basic/mysql/Dockerfile index 95da9c48d..d87bfe55b 100644 --- a/images/basic/mysql/Dockerfile +++ b/images/basic/mysql/Dockerfile @@ -1 +1 @@ -FROM mysql:8.0.19 +FROM mysql/mysql-server:8.0.19 diff --git a/images/basic/tomcat/Dockerfile.aarch64 b/images/basic/tomcat/Dockerfile.aarch64 new file mode 100644 index 000000000..ed4096de9 --- /dev/null +++ b/images/basic/tomcat/Dockerfile.aarch64 @@ -0,0 +1 @@ +FROM arm64v8/tomcat:8.0 diff --git a/images/jekyll/Dockerfile b/images/jekyll/Dockerfile.x86_64 index ae19f3bfc..ae19f3bfc 100644 --- a/images/jekyll/Dockerfile +++ b/images/jekyll/Dockerfile.x86_64 diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD index cdcaa8c73..4a26e28de 100644 --- a/pkg/abi/linux/BUILD +++ b/pkg/abi/linux/BUILD @@ -38,6 +38,7 @@ go_library( "ipc.go", "limits.go", "linux.go", + "membarrier.go", "mm.go", "netdevice.go", "netfilter.go", diff --git a/pkg/abi/linux/ioctl.go b/pkg/abi/linux/ioctl.go index dc9ac7e7c..7df02dd6d 100644 --- a/pkg/abi/linux/ioctl.go +++ b/pkg/abi/linux/ioctl.go @@ -121,9 +121,27 @@ const ( // Constants from uapi/linux/fsverity.h. const ( - FS_IOC_ENABLE_VERITY = 1082156677 + FS_IOC_ENABLE_VERITY = 1082156677 + FS_IOC_MEASURE_VERITY = 3221513862 ) +// DigestMetadata is a helper struct for VerityDigest. +// +// +marshal +type DigestMetadata struct { + DigestAlgorithm uint16 + DigestSize uint16 +} + +// SizeOfDigestMetadata is the size of struct DigestMetadata. +const SizeOfDigestMetadata = 4 + +// VerityDigest is struct from uapi/linux/fsverity.h. +type VerityDigest struct { + Metadata DigestMetadata + Digest []byte +} + // IOC outputs the result of _IOC macro in asm-generic/ioctl.h. func IOC(dir, typ, nr, size uint32) uint32 { return uint32(dir)<<_IOC_DIRSHIFT | typ<<_IOC_TYPESHIFT | nr<<_IOC_NRSHIFT | size<<_IOC_SIZESHIFT diff --git a/pkg/abi/linux/membarrier.go b/pkg/abi/linux/membarrier.go new file mode 100644 index 000000000..4f6021a1d --- /dev/null +++ b/pkg/abi/linux/membarrier.go @@ -0,0 +1,34 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package linux + +// membarrier(2) commands, from include/uapi/linux/membarrier.h. +const ( + MEMBARRIER_CMD_QUERY = 0 + MEMBARRIER_CMD_GLOBAL = (1 << 0) + MEMBARRIER_CMD_GLOBAL_EXPEDITED = (1 << 1) + MEMBARRIER_CMD_REGISTER_GLOBAL_EXPEDITED = (1 << 2) + MEMBARRIER_CMD_PRIVATE_EXPEDITED = (1 << 3) + MEMBARRIER_CMD_REGISTER_PRIVATE_EXPEDITED = (1 << 4) + MEMBARRIER_CMD_PRIVATE_EXPEDITED_SYNC_CORE = (1 << 5) + MEMBARRIER_CMD_REGISTER_PRIVATE_EXPEDITED_SYNC_CORE = (1 << 6) + MEMBARRIER_CMD_PRIVATE_EXPEDITED_RSEQ = (1 << 7) + MEMBARRIER_CMD_REGISTER_PRIVATE_EXPEDITED_RSEQ = (1 << 8) +) + +// membarrier(2) flags, from include/uapi/linux/membarrier.h. +const ( + MEMBARRIER_CMD_FLAG_CPU = (1 << 0) +) diff --git a/pkg/abi/linux/netfilter_ipv6.go b/pkg/abi/linux/netfilter_ipv6.go index a137940b6..6d31eb5e3 100644 --- a/pkg/abi/linux/netfilter_ipv6.go +++ b/pkg/abi/linux/netfilter_ipv6.go @@ -321,3 +321,16 @@ const ( // Enable all flags. IP6T_INV_MASK = 0x7F ) + +// NFNATRange corresponds to struct nf_nat_range in +// include/uapi/linux/netfilter/nf_nat.h. +type NFNATRange struct { + Flags uint32 + MinAddr Inet6Addr + MaxAddr Inet6Addr + MinProto uint16 // Network byte order. + MaxProto uint16 // Network byte order. +} + +// SizeOfNFNATRange is the size of NFNATRange. +const SizeOfNFNATRange = 40 diff --git a/pkg/abi/linux/seccomp.go b/pkg/abi/linux/seccomp.go index b07cafe12..5be3f10f9 100644 --- a/pkg/abi/linux/seccomp.go +++ b/pkg/abi/linux/seccomp.go @@ -83,3 +83,22 @@ type SockFprog struct { pad [6]byte Filter *BPFInstruction } + +// SeccompData is equivalent to struct seccomp_data, which contains the data +// passed to seccomp-bpf filters. +// +// +marshal +type SeccompData struct { + // Nr is the system call number. + Nr int32 + + // Arch is an AUDIT_ARCH_* value indicating the system call convention. + Arch uint32 + + // InstructionPointer is the value of the instruction pointer at the time + // of the system call. + InstructionPointer uint64 + + // Args contains the first 6 system call arguments. + Args [6]uint64 +} diff --git a/pkg/abi/linux/signalfd.go b/pkg/abi/linux/signalfd.go index 85fad9956..468c6a387 100644 --- a/pkg/abi/linux/signalfd.go +++ b/pkg/abi/linux/signalfd.go @@ -23,6 +23,8 @@ const ( ) // SignalfdSiginfo is the siginfo encoding for signalfds. +// +// +marshal type SignalfdSiginfo struct { Signo uint32 Errno int32 @@ -41,5 +43,5 @@ type SignalfdSiginfo struct { STime uint64 Addr uint64 AddrLSB uint16 - _ [48]uint8 + _ [48]uint8 `marshal:"unaligned"` } diff --git a/pkg/merkletree/merkletree.go b/pkg/merkletree/merkletree.go index 4b4f9bd52..a6e698f57 100644 --- a/pkg/merkletree/merkletree.go +++ b/pkg/merkletree/merkletree.go @@ -41,7 +41,7 @@ type Layout struct { blockSize int64 // digestSize is the size of a generated hash. digestSize int64 - // levelOffset contains the offset of the begnning of each level in + // levelOffset contains the offset of the beginning of each level in // bytes. The number of levels in the tree is the length of the slice. // The leaf nodes (level 0) contain hashes of blocks of the input data. // Each level N contains hashes of the blocks in level N-1. The highest @@ -123,45 +123,87 @@ func (layout Layout) blockOffset(level int, index int64) int64 { return layout.levelOffset[level] + index*layout.blockSize } +// VerityDescriptor is a struct that is serialized and hashed to get a file's +// root hash, which contains the root hash of the raw content and the file's +// meatadata. +type VerityDescriptor struct { + Name string + Mode uint32 + UID uint32 + GID uint32 + RootHash []byte +} + +func (d *VerityDescriptor) String() string { + return fmt.Sprintf("Name: %s, Mode: %d, UID: %d, GID: %d, RootHash: %v", d.Name, d.Mode, d.UID, d.GID, d.RootHash) +} + +// GenerateParams contains the parameters used to generate a Merkle tree. +type GenerateParams struct { + // File is a reader of the file to be hashed. + File io.ReadSeeker + // Size is the size of the file. + Size int64 + // Name is the name of the target file. + Name string + // Mode is the mode of the target file. + Mode uint32 + // UID is the user ID of the target file. + UID uint32 + // GID is the group ID of the target file. + GID uint32 + // TreeReader is a reader for the Merkle tree. + TreeReader io.ReadSeeker + // TreeWriter is a writer for the Merkle tree. + TreeWriter io.WriteSeeker + // DataAndTreeInSameFile is true if data and Merkle tree are in the same + // file, or false if Merkle tree is a separate file from data. + DataAndTreeInSameFile bool +} + // Generate constructs a Merkle tree for the contents of data. The output is // written to treeWriter. The treeReader should be able to read the tree after // it has been written. That is, treeWriter and treeReader should point to the // same underlying data but have separate cursors. +// +// Generate returns a hash of descriptor. The descriptor contains the file +// metadata and the hash from file content. +// // Generate will modify the cursor for data, but always restores it to its // original position upon exit. The cursor for tree is modified and not // restored. -func Generate(data io.ReadSeeker, dataSize int64, treeReader io.ReadSeeker, treeWriter io.WriteSeeker, dataAndTreeInSameFile bool) ([]byte, error) { - layout := InitLayout(dataSize, dataAndTreeInSameFile) +func Generate(params *GenerateParams) ([]byte, error) { + layout := InitLayout(params.Size, params.DataAndTreeInSameFile) - numBlocks := (dataSize + layout.blockSize - 1) / layout.blockSize + numBlocks := (params.Size + layout.blockSize - 1) / layout.blockSize // If the data is in the same file as the tree, zero pad the last data // block. - bytesInLastBlock := dataSize % layout.blockSize - if dataAndTreeInSameFile && bytesInLastBlock != 0 { + bytesInLastBlock := params.Size % layout.blockSize + if params.DataAndTreeInSameFile && bytesInLastBlock != 0 { zeroBuf := make([]byte, layout.blockSize-bytesInLastBlock) - if _, err := treeWriter.Seek(0, io.SeekEnd); err != nil && err != io.EOF { + if _, err := params.TreeWriter.Seek(0, io.SeekEnd); err != nil && err != io.EOF { return nil, err } - if _, err := treeWriter.Write(zeroBuf); err != nil { + if _, err := params.TreeWriter.Write(zeroBuf); err != nil { return nil, err } } // Store the current offset, so we can set it back once verification // finishes. - origOffset, err := data.Seek(0, io.SeekCurrent) + origOffset, err := params.File.Seek(0, io.SeekCurrent) if err != nil { return nil, err } - defer data.Seek(origOffset, io.SeekStart) + defer params.File.Seek(origOffset, io.SeekStart) // Read from the beginning of both data and treeReader. - if _, err := data.Seek(0, io.SeekStart); err != nil && err != io.EOF { + if _, err := params.File.Seek(0, io.SeekStart); err != nil && err != io.EOF { return nil, err } - if _, err := treeReader.Seek(0, io.SeekStart); err != nil && err != io.EOF { + if _, err := params.TreeReader.Seek(0, io.SeekStart); err != nil && err != io.EOF { return nil, err } @@ -176,11 +218,11 @@ func Generate(data io.ReadSeeker, dataSize int64, treeReader io.ReadSeeker, tree if level == 0 { // Read data block from the target file since level 0 includes hashes // of blocks in the input data. - n, err = data.Read(buf) + n, err = params.File.Read(buf) } else { // Read data block from the tree file since levels higher than 0 are // hashing the lower level hashes. - n, err = treeReader.Read(buf) + n, err = params.TreeReader.Read(buf) } // err is populated as long as the bytes read is smaller than the buffer @@ -200,7 +242,7 @@ func Generate(data io.ReadSeeker, dataSize int64, treeReader io.ReadSeeker, tree } // Write the generated hash to the end of the tree file. - if _, err = treeWriter.Write(digest[:]); err != nil { + if _, err = params.TreeWriter.Write(digest[:]); err != nil { return nil, err } } @@ -208,45 +250,124 @@ func Generate(data io.ReadSeeker, dataSize int64, treeReader io.ReadSeeker, tree // remaining of the last block. But no need to do so for root. if level != layout.rootLevel() && numBlocks%layout.hashesPerBlock() != 0 { zeroBuf := make([]byte, layout.blockSize-(numBlocks%layout.hashesPerBlock())*layout.digestSize) - if _, err := treeWriter.Write(zeroBuf[:]); err != nil { + if _, err := params.TreeWriter.Write(zeroBuf[:]); err != nil { return nil, err } } numBlocks = (numBlocks + layout.hashesPerBlock() - 1) / layout.hashesPerBlock() } - return root, nil + descriptor := VerityDescriptor{ + Name: params.Name, + Mode: params.Mode, + UID: params.UID, + GID: params.GID, + RootHash: root, + } + ret := sha256.Sum256([]byte(descriptor.String())) + return ret[:], nil +} + +// VerifyParams contains the params used to verify a portion of a file against +// a Merkle tree. +type VerifyParams struct { + // Out will be filled with verified data. + Out io.Writer + // File is a handler on the file to be verified. + File io.ReadSeeker + // tree is a handler on the Merkle tree used to verify file. + Tree io.ReadSeeker + // Size is the size of the file. + Size int64 + // Name is the name of the target file. + Name string + // Mode is the mode of the target file. + Mode uint32 + // UID is the user ID of the target file. + UID uint32 + // GID is the group ID of the target file. + GID uint32 + // ReadOffset is the offset of the data range to be verified. + ReadOffset int64 + // ReadSize is the size of the data range to be verified. + ReadSize int64 + // ExpectedRoot is a trusted hash for the file. It is compared with the + // calculated root hash to verify the content. + ExpectedRoot []byte + // DataAndTreeInSameFile is true if data and Merkle tree are in the same + // file, or false if Merkle tree is a separate file from data. + DataAndTreeInSameFile bool +} + +// verifyDescriptor generates a hash from descriptor, and compares it with +// expectedRoot. +func verifyDescriptor(descriptor VerityDescriptor, expectedRoot []byte) error { + h := sha256.Sum256([]byte(descriptor.String())) + if !bytes.Equal(h[:], expectedRoot) { + return fmt.Errorf("unexpected root hash") + } + return nil +} + +// verifyMetadata verifies the metadata by hashing a descriptor that contains +// the metadata and compare the generated hash with expectedRoot. +// +// For verifyMetadata, params.data is not needed. It only accesses params.tree +// for the raw root hash. +func verifyMetadata(params *VerifyParams, layout Layout) error { + if _, err := params.Tree.Seek(layout.blockOffset(layout.rootLevel(), 0 /* index */), io.SeekStart); err != nil { + return fmt.Errorf("failed to seek to root hash: %w", err) + } + root := make([]byte, layout.digestSize) + if _, err := params.Tree.Read(root); err != nil { + return fmt.Errorf("failed to read root hash: %w", err) + } + descriptor := VerityDescriptor{ + Name: params.Name, + Mode: params.Mode, + UID: params.UID, + GID: params.GID, + RootHash: root, + } + return verifyDescriptor(descriptor, params.ExpectedRoot) } // Verify verifies the content read from data with offset. The content is // verified against tree. If content spans across multiple blocks, each block is // verified. Verification fails if the hash of the data does not match the tree // at any level, or if the final root hash does not match expectedRoot. -// Once the data is verified, it will be written using w. +// Once the data is verified, it will be written using params.Out. +// +// Verify checks for both target file content and metadata. If readSize is 0, +// only metadata is checked. +// // Verify will modify the cursor for data, but always restores it to its // original position upon exit. The cursor for tree is modified and not // restored. -func Verify(w io.Writer, data, tree io.ReadSeeker, dataSize int64, readOffset int64, readSize int64, expectedRoot []byte, dataAndTreeInSameFile bool) (int64, error) { - if readSize <= 0 { - return 0, fmt.Errorf("Unexpected read size: %d", readSize) +func Verify(params *VerifyParams) (int64, error) { + if params.ReadSize < 0 { + return 0, fmt.Errorf("unexpected read size: %d", params.ReadSize) + } + layout := InitLayout(int64(params.Size), params.DataAndTreeInSameFile) + if params.ReadSize == 0 { + return 0, verifyMetadata(params, layout) } - layout := InitLayout(int64(dataSize), dataAndTreeInSameFile) // Calculate the index of blocks that includes the target range in input // data. - firstDataBlock := readOffset / layout.blockSize - lastDataBlock := (readOffset + readSize - 1) / layout.blockSize + firstDataBlock := params.ReadOffset / layout.blockSize + lastDataBlock := (params.ReadOffset + params.ReadSize - 1) / layout.blockSize // Store the current offset, so we can set it back once verification // finishes. - origOffset, err := data.Seek(0, io.SeekCurrent) + origOffset, err := params.File.Seek(0, io.SeekCurrent) if err != nil { - return 0, fmt.Errorf("Find current data offset failed: %v", err) + return 0, fmt.Errorf("find current data offset failed: %w", err) } - defer data.Seek(origOffset, io.SeekStart) + defer params.File.Seek(origOffset, io.SeekStart) // Move to the first block that contains target data. - if _, err := data.Seek(firstDataBlock*layout.blockSize, io.SeekStart); err != nil { - return 0, fmt.Errorf("Seek to datablock start failed: %v", err) + if _, err := params.File.Seek(firstDataBlock*layout.blockSize, io.SeekStart); err != nil { + return 0, fmt.Errorf("seek to datablock start failed: %w", err) } buf := make([]byte, layout.blockSize) @@ -255,7 +376,7 @@ func Verify(w io.Writer, data, tree io.ReadSeeker, dataSize int64, readOffset in for i := firstDataBlock; i <= lastDataBlock; i++ { // Read a block that includes all or part of target range in // input data. - bytesRead, err := data.Read(buf) + bytesRead, err := params.File.Read(buf) readErr = err // If at the end of input data and all previous blocks are // verified, return the verified input data and EOF. @@ -263,7 +384,7 @@ func Verify(w io.Writer, data, tree io.ReadSeeker, dataSize int64, readOffset in break } if readErr != nil && readErr != io.EOF { - return 0, fmt.Errorf("Read from data failed: %v", err) + return 0, fmt.Errorf("read from data failed: %w", err) } // If this is the end of file, zero the remaining bytes in buf, // otherwise they are still from the previous block. @@ -274,22 +395,29 @@ func Verify(w io.Writer, data, tree io.ReadSeeker, dataSize int64, readOffset in buf[j] = 0 } } - if err := verifyBlock(tree, layout, buf, i, expectedRoot); err != nil { + descriptor := VerityDescriptor{ + Name: params.Name, + Mode: params.Mode, + UID: params.UID, + GID: params.GID, + } + if err := verifyBlock(params.Tree, descriptor, layout, buf, i, params.ExpectedRoot); err != nil { return 0, err } + // startOff is the beginning of the read range within the // current data block. Note that for all blocks other than the // first, startOff should be 0. startOff := int64(0) if i == firstDataBlock { - startOff = readOffset % layout.blockSize + startOff = params.ReadOffset % layout.blockSize } // endOff is the end of the read range within the current data // block. Note that for all blocks other than the last, endOff // should be the block size. endOff := layout.blockSize if i == lastDataBlock { - endOff = (readOffset+readSize-1)%layout.blockSize + 1 + endOff = (params.ReadOffset+params.ReadSize-1)%layout.blockSize + 1 } // If the provided size exceeds the end of input data, we should // only copy the parts in buf that's part of input data. @@ -299,7 +427,7 @@ func Verify(w io.Writer, data, tree io.ReadSeeker, dataSize int64, readOffset in if endOff > int64(bytesRead) { endOff = int64(bytesRead) } - n, err := w.Write(buf[startOff:endOff]) + n, err := params.Out.Write(buf[startOff:endOff]) if err != nil { return total, err } @@ -313,9 +441,11 @@ func Verify(w io.Writer, data, tree io.ReadSeeker, dataSize int64, readOffset in // original data. The block is verified through each level of the tree. It // fails if the calculated hash from block is different from any level of // hashes stored in tree. And the final root hash is compared with -// expectedRoot. verifyBlock modifies the cursor for tree. Users needs to -// maintain the cursor if intended. -func verifyBlock(tree io.ReadSeeker, layout Layout, dataBlock []byte, blockIndex int64, expectedRoot []byte) error { +// expectedRoot. +// +// verifyBlock modifies the cursor for tree. Users needs to maintain the cursor +// if intended. +func verifyBlock(tree io.ReadSeeker, descriptor VerityDescriptor, layout Layout, dataBlock []byte, blockIndex int64, expectedRoot []byte) error { if len(dataBlock) != int(layout.blockSize) { return fmt.Errorf("incorrect block size") } @@ -352,21 +482,13 @@ func verifyBlock(tree io.ReadSeeker, layout Layout, dataBlock []byte, blockIndex } if !bytes.Equal(digest, expectedDigest) { - return fmt.Errorf("Verification failed") - } - - // If this is the root layer, no need to generate next level - // hash. - if level == layout.rootLevel() { - break + return fmt.Errorf("verification failed") } blockIndex = blockIndex / layout.hashesPerBlock() } - // Verification for the tree succeeded. Now compare the root hash in the - // tree with expectedRoot. - if !bytes.Equal(digest[:], expectedRoot) { - return fmt.Errorf("Verification failed") - } - return nil + // Verification for the tree succeeded. Now hash the descriptor with + // the root hash and compare it with expectedRoot. + descriptor.RootHash = digest + return verifyDescriptor(descriptor, expectedRoot) } diff --git a/pkg/merkletree/merkletree_test.go b/pkg/merkletree/merkletree_test.go index daaca759a..bb11ec844 100644 --- a/pkg/merkletree/merkletree_test.go +++ b/pkg/merkletree/merkletree_test.go @@ -84,6 +84,13 @@ func TestLayout(t *testing.T) { } } +const ( + defaultName = "merkle_test" + defaultMode = 0644 + defaultUID = 0 + defaultGID = 0 +) + // bytesReadWriter is used to read from/write to/seek in a byte array. Unlike // bytes.Buffer, it keeps the whole buffer during read so that it can be reused. type bytesReadWriter struct { @@ -138,19 +145,19 @@ func TestGenerate(t *testing.T) { }{ { data: bytes.Repeat([]byte{0}, usermem.PageSize), - expectedRoot: []byte{173, 127, 172, 178, 88, 111, 198, 233, 102, 192, 4, 215, 209, 209, 107, 2, 79, 88, 5, 255, 124, 180, 124, 122, 133, 218, 189, 139, 72, 137, 44, 167}, + expectedRoot: []byte{64, 253, 58, 72, 192, 131, 82, 184, 193, 33, 108, 142, 43, 46, 179, 134, 244, 21, 29, 190, 14, 39, 66, 129, 6, 46, 200, 211, 30, 247, 191, 252}, }, { data: bytes.Repeat([]byte{0}, 128*usermem.PageSize+1), - expectedRoot: []byte{62, 93, 40, 92, 161, 241, 30, 223, 202, 99, 39, 2, 132, 113, 240, 139, 117, 99, 79, 243, 54, 18, 100, 184, 141, 121, 238, 46, 149, 202, 203, 132}, + expectedRoot: []byte{182, 223, 218, 62, 65, 185, 160, 219, 93, 119, 186, 88, 205, 32, 122, 231, 173, 72, 78, 76, 65, 57, 177, 146, 159, 39, 44, 123, 230, 156, 97, 26}, }, { data: []byte{'a'}, - expectedRoot: []byte{52, 75, 204, 142, 172, 129, 37, 14, 145, 137, 103, 203, 11, 162, 209, 205, 30, 169, 213, 72, 20, 28, 243, 24, 242, 2, 92, 43, 169, 59, 110, 210}, + expectedRoot: []byte{28, 201, 8, 36, 150, 178, 111, 5, 193, 212, 129, 205, 206, 124, 211, 90, 224, 142, 81, 183, 72, 165, 243, 240, 242, 241, 76, 127, 101, 61, 63, 11}, }, { data: bytes.Repeat([]byte{'a'}, usermem.PageSize), - expectedRoot: []byte{201, 62, 238, 45, 13, 176, 47, 16, 172, 199, 70, 13, 149, 118, 225, 34, 220, 248, 205, 83, 196, 191, 141, 252, 174, 27, 62, 116, 235, 207, 255, 90}, + expectedRoot: []byte{106, 58, 160, 152, 41, 68, 38, 108, 245, 74, 177, 84, 64, 193, 19, 176, 249, 86, 27, 193, 85, 164, 99, 240, 79, 104, 148, 222, 76, 46, 191, 79}, }, } @@ -158,16 +165,25 @@ func TestGenerate(t *testing.T) { t.Run(fmt.Sprintf("%d:%v", len(tc.data), tc.data[0]), func(t *testing.T) { for _, dataAndTreeInSameFile := range []bool{false, true} { var tree bytesReadWriter - var root []byte - var err error + params := GenerateParams{ + Size: int64(len(tc.data)), + Name: defaultName, + Mode: defaultMode, + UID: defaultUID, + GID: defaultGID, + TreeReader: &tree, + TreeWriter: &tree, + DataAndTreeInSameFile: dataAndTreeInSameFile, + } if dataAndTreeInSameFile { tree.Write(tc.data) - root, err = Generate(&tree, int64(len(tc.data)), &tree, &tree, dataAndTreeInSameFile) + params.File = &tree } else { - root, err = Generate(&bytesReadWriter{ + params.File = &bytesReadWriter{ bytes: tc.data, - }, int64(len(tc.data)), &tree, &tree, dataAndTreeInSameFile) + } } + root, err := Generate(¶ms) if err != nil { t.Fatalf("Got err: %v, want nil", err) } @@ -194,6 +210,10 @@ func TestVerify(t *testing.T) { // modified byte falls in verification range, Verify should // fail, otherwise Verify should still succeed. modifyByte int64 + modifyName bool + modifyMode bool + modifyUID bool + modifyGID bool shouldSucceed bool }{ // Verify range start outside the data range should fail. @@ -222,12 +242,48 @@ func TestVerify(t *testing.T) { modifyByte: 0, shouldSucceed: false, }, - // Invalid verify range (0 size) should fail. + // 0 verify size should only verify metadata. + { + dataSize: usermem.PageSize, + verifyStart: 0, + verifySize: 0, + modifyByte: 0, + shouldSucceed: true, + }, + // Modified name should fail verification. + { + dataSize: usermem.PageSize, + verifyStart: 0, + verifySize: 0, + modifyByte: 0, + modifyName: true, + shouldSucceed: false, + }, + // Modified mode should fail verification. { dataSize: usermem.PageSize, verifyStart: 0, verifySize: 0, modifyByte: 0, + modifyMode: true, + shouldSucceed: false, + }, + // Modified UID should fail verification. + { + dataSize: usermem.PageSize, + verifyStart: 0, + verifySize: 0, + modifyByte: 0, + modifyUID: true, + shouldSucceed: false, + }, + // Modified GID should fail verification. + { + dataSize: usermem.PageSize, + verifyStart: 0, + verifySize: 0, + modifyByte: 0, + modifyGID: true, shouldSucceed: false, }, // The test cases below use a block-aligned verify range. @@ -316,16 +372,25 @@ func TestVerify(t *testing.T) { for _, dataAndTreeInSameFile := range []bool{false, true} { var tree bytesReadWriter - var root []byte - var err error + genParams := GenerateParams{ + Size: int64(len(data)), + Name: defaultName, + Mode: defaultMode, + UID: defaultUID, + GID: defaultGID, + TreeReader: &tree, + TreeWriter: &tree, + DataAndTreeInSameFile: dataAndTreeInSameFile, + } if dataAndTreeInSameFile { tree.Write(data) - root, err = Generate(&tree, int64(len(data)), &tree, &tree, dataAndTreeInSameFile) + genParams.File = &tree } else { - root, err = Generate(&bytesReadWriter{ + genParams.File = &bytesReadWriter{ bytes: data, - }, int64(tc.dataSize), &tree, &tree, false /* dataAndTreeInSameFile */) + } } + root, err := Generate(&genParams) if err != nil { t.Fatalf("Generate failed: %v", err) } @@ -333,8 +398,34 @@ func TestVerify(t *testing.T) { // Flip a bit in data and checks Verify results. var buf bytes.Buffer data[tc.modifyByte] ^= 1 + verifyParams := VerifyParams{ + Out: &buf, + File: bytes.NewReader(data), + Tree: &tree, + Size: tc.dataSize, + Name: defaultName, + Mode: defaultMode, + UID: defaultUID, + GID: defaultGID, + ReadOffset: tc.verifyStart, + ReadSize: tc.verifySize, + ExpectedRoot: root, + DataAndTreeInSameFile: dataAndTreeInSameFile, + } + if tc.modifyName { + verifyParams.Name = defaultName + "abc" + } + if tc.modifyMode { + verifyParams.Mode = defaultMode + 1 + } + if tc.modifyUID { + verifyParams.UID = defaultUID + 1 + } + if tc.modifyGID { + verifyParams.GID = defaultGID + 1 + } if tc.shouldSucceed { - n, err := Verify(&buf, bytes.NewReader(data), &tree, tc.dataSize, tc.verifyStart, tc.verifySize, root, dataAndTreeInSameFile) + n, err := Verify(&verifyParams) if err != nil && err != io.EOF { t.Errorf("Verification failed when expected to succeed: %v", err) } @@ -348,7 +439,7 @@ func TestVerify(t *testing.T) { t.Errorf("Incorrect output buf from Verify") } } else { - if _, err := Verify(&buf, bytes.NewReader(data), &tree, tc.dataSize, tc.verifyStart, tc.verifySize, root, dataAndTreeInSameFile); err == nil { + if _, err := Verify(&verifyParams); err == nil { t.Errorf("Verification succeeded when expected to fail") } } @@ -368,16 +459,26 @@ func TestVerifyRandom(t *testing.T) { for _, dataAndTreeInSameFile := range []bool{false, true} { var tree bytesReadWriter - var root []byte - var err error + genParams := GenerateParams{ + Size: int64(len(data)), + Name: defaultName, + Mode: defaultMode, + UID: defaultUID, + GID: defaultGID, + TreeReader: &tree, + TreeWriter: &tree, + DataAndTreeInSameFile: dataAndTreeInSameFile, + } + if dataAndTreeInSameFile { tree.Write(data) - root, err = Generate(&tree, int64(len(data)), &tree, &tree, dataAndTreeInSameFile) + genParams.File = &tree } else { - root, err = Generate(&bytesReadWriter{ + genParams.File = &bytesReadWriter{ bytes: data, - }, int64(dataSize), &tree, &tree, dataAndTreeInSameFile) + } } + root, err := Generate(&genParams) if err != nil { t.Fatalf("Generate failed: %v", err) } @@ -387,9 +488,24 @@ func TestVerifyRandom(t *testing.T) { size := rand.Int63n(dataSize) + 1 var buf bytes.Buffer + verifyParams := VerifyParams{ + Out: &buf, + File: bytes.NewReader(data), + Tree: &tree, + Size: dataSize, + Name: defaultName, + Mode: defaultMode, + UID: defaultUID, + GID: defaultGID, + ReadOffset: start, + ReadSize: size, + ExpectedRoot: root, + DataAndTreeInSameFile: dataAndTreeInSameFile, + } + // Checks that the random portion of data from the original data is // verified successfully. - n, err := Verify(&buf, bytes.NewReader(data), &tree, dataSize, start, size, root, dataAndTreeInSameFile) + n, err := Verify(&verifyParams) if err != nil && err != io.EOF { t.Errorf("Verification failed for correct data: %v", err) } @@ -406,13 +522,22 @@ func TestVerifyRandom(t *testing.T) { t.Errorf("Incorrect output buf from Verify") } + // Verify that modified metadata should fail verification. buf.Reset() + verifyParams.Name = defaultName + "abc" + if _, err := Verify(&verifyParams); err == nil { + t.Error("Verify succeeded for modified metadata, expect failure") + } + // Flip a random bit in randPortion, and check that verification fails. + buf.Reset() randBytePos := rand.Int63n(size) data[start+randBytePos] ^= 1 + verifyParams.File = bytes.NewReader(data) + verifyParams.Name = defaultName - if _, err := Verify(&buf, bytes.NewReader(data), &tree, dataSize, start, size, root, dataAndTreeInSameFile); err == nil { - t.Errorf("Verification succeeded for modified data") + if _, err := Verify(&verifyParams); err == nil { + t.Error("Verification succeeded for modified data, expect failure") } } } diff --git a/pkg/seccomp/BUILD b/pkg/seccomp/BUILD index bdef7762c..e828894b0 100644 --- a/pkg/seccomp/BUILD +++ b/pkg/seccomp/BUILD @@ -49,7 +49,7 @@ go_test( library = ":seccomp", deps = [ "//pkg/abi/linux", - "//pkg/binary", "//pkg/bpf", + "//pkg/usermem", ], ) diff --git a/pkg/seccomp/seccomp_test.go b/pkg/seccomp/seccomp_test.go index 23f30678d..e1444d18b 100644 --- a/pkg/seccomp/seccomp_test.go +++ b/pkg/seccomp/seccomp_test.go @@ -28,17 +28,10 @@ import ( "time" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/bpf" + "gvisor.dev/gvisor/pkg/usermem" ) -type seccompData struct { - nr uint32 - arch uint32 - instructionPointer uint64 - args [6]uint64 -} - // newVictim makes a victim binary. func newVictim() (string, error) { f, err := ioutil.TempFile("", "victim") @@ -58,9 +51,14 @@ func newVictim() (string, error) { return path, nil } -// asInput converts a seccompData to a bpf.Input. -func (d *seccompData) asInput() bpf.Input { - return bpf.InputBytes{binary.Marshal(nil, binary.LittleEndian, d), binary.LittleEndian} +// dataAsInput converts a linux.SeccompData to a bpf.Input. +func dataAsInput(d *linux.SeccompData) bpf.Input { + buf := make([]byte, d.SizeBytes()) + d.MarshalUnsafe(buf) + return bpf.InputBytes{ + Data: buf, + Order: usermem.ByteOrder, + } } func TestBasic(t *testing.T) { @@ -69,7 +67,7 @@ func TestBasic(t *testing.T) { desc string // data is the input data. - data seccompData + data linux.SeccompData // want is the expected return value of the BPF program. want linux.BPFAction @@ -95,12 +93,12 @@ func TestBasic(t *testing.T) { specs: []spec{ { desc: "syscall allowed", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_ALLOW, }, { desc: "syscall disallowed", - data: seccompData{nr: 2, arch: LINUX_AUDIT_ARCH}, + data: linux.SeccompData{Nr: 2, Arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_TRAP, }, }, @@ -131,22 +129,22 @@ func TestBasic(t *testing.T) { specs: []spec{ { desc: "allowed (1a)", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x1}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x1}}, want: linux.SECCOMP_RET_ALLOW, }, { desc: "allowed (1b)", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_TRAP, }, { desc: "syscall 1 matched 2nd rule", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_TRAP, }, { desc: "no match", - data: seccompData{nr: 0, arch: LINUX_AUDIT_ARCH}, + data: linux.SeccompData{Nr: 0, Arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_KILL_THREAD, }, }, @@ -168,42 +166,42 @@ func TestBasic(t *testing.T) { specs: []spec{ { desc: "allowed (1)", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_ALLOW, }, { desc: "allowed (3)", - data: seccompData{nr: 3, arch: LINUX_AUDIT_ARCH}, + data: linux.SeccompData{Nr: 3, Arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_ALLOW, }, { desc: "allowed (5)", - data: seccompData{nr: 5, arch: LINUX_AUDIT_ARCH}, + data: linux.SeccompData{Nr: 5, Arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_ALLOW, }, { desc: "disallowed (0)", - data: seccompData{nr: 0, arch: LINUX_AUDIT_ARCH}, + data: linux.SeccompData{Nr: 0, Arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_TRAP, }, { desc: "disallowed (2)", - data: seccompData{nr: 2, arch: LINUX_AUDIT_ARCH}, + data: linux.SeccompData{Nr: 2, Arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_TRAP, }, { desc: "disallowed (4)", - data: seccompData{nr: 4, arch: LINUX_AUDIT_ARCH}, + data: linux.SeccompData{Nr: 4, Arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_TRAP, }, { desc: "disallowed (6)", - data: seccompData{nr: 6, arch: LINUX_AUDIT_ARCH}, + data: linux.SeccompData{Nr: 6, Arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_TRAP, }, { desc: "disallowed (100)", - data: seccompData{nr: 100, arch: LINUX_AUDIT_ARCH}, + data: linux.SeccompData{Nr: 100, Arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_TRAP, }, }, @@ -223,7 +221,7 @@ func TestBasic(t *testing.T) { specs: []spec{ { desc: "arch (123)", - data: seccompData{nr: 1, arch: 123}, + data: linux.SeccompData{Nr: 1, Arch: 123}, want: linux.SECCOMP_RET_KILL_THREAD, }, }, @@ -243,7 +241,7 @@ func TestBasic(t *testing.T) { specs: []spec{ { desc: "action trap", - data: seccompData{nr: 2, arch: LINUX_AUDIT_ARCH}, + data: linux.SeccompData{Nr: 2, Arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_TRAP, }, }, @@ -268,12 +266,12 @@ func TestBasic(t *testing.T) { specs: []spec{ { desc: "allowed", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0xf, 0xf}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0xf, 0xf}}, want: linux.SECCOMP_RET_ALLOW, }, { desc: "disallowed", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0xf, 0xe}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0xf, 0xe}}, want: linux.SECCOMP_RET_TRAP, }, }, @@ -300,12 +298,12 @@ func TestBasic(t *testing.T) { specs: []spec{ { desc: "match first rule", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0xf}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0xf}}, want: linux.SECCOMP_RET_ALLOW, }, { desc: "match 2nd rule", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0xe}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0xe}}, want: linux.SECCOMP_RET_ALLOW, }, }, @@ -331,28 +329,28 @@ func TestBasic(t *testing.T) { specs: []spec{ { desc: "argument allowed (all match)", - data: seccompData{ - nr: 1, - arch: LINUX_AUDIT_ARCH, - args: [6]uint64{0, math.MaxUint64 - 1, math.MaxUint32}, + data: linux.SeccompData{ + Nr: 1, + Arch: LINUX_AUDIT_ARCH, + Args: [6]uint64{0, math.MaxUint64 - 1, math.MaxUint32}, }, want: linux.SECCOMP_RET_ALLOW, }, { desc: "argument disallowed (one mismatch)", - data: seccompData{ - nr: 1, - arch: LINUX_AUDIT_ARCH, - args: [6]uint64{0, math.MaxUint64, math.MaxUint32}, + data: linux.SeccompData{ + Nr: 1, + Arch: LINUX_AUDIT_ARCH, + Args: [6]uint64{0, math.MaxUint64, math.MaxUint32}, }, want: linux.SECCOMP_RET_TRAP, }, { desc: "argument disallowed (multiple mismatch)", - data: seccompData{ - nr: 1, - arch: LINUX_AUDIT_ARCH, - args: [6]uint64{0, math.MaxUint64, math.MaxUint32 - 1}, + data: linux.SeccompData{ + Nr: 1, + Arch: LINUX_AUDIT_ARCH, + Args: [6]uint64{0, math.MaxUint64, math.MaxUint32 - 1}, }, want: linux.SECCOMP_RET_TRAP, }, @@ -379,28 +377,28 @@ func TestBasic(t *testing.T) { specs: []spec{ { desc: "arg allowed", - data: seccompData{ - nr: 1, - arch: LINUX_AUDIT_ARCH, - args: [6]uint64{0, math.MaxUint64, math.MaxUint32 - 1}, + data: linux.SeccompData{ + Nr: 1, + Arch: LINUX_AUDIT_ARCH, + Args: [6]uint64{0, math.MaxUint64, math.MaxUint32 - 1}, }, want: linux.SECCOMP_RET_ALLOW, }, { desc: "arg disallowed (one equal)", - data: seccompData{ - nr: 1, - arch: LINUX_AUDIT_ARCH, - args: [6]uint64{0x7aabbccdd, math.MaxUint64, math.MaxUint32 - 1}, + data: linux.SeccompData{ + Nr: 1, + Arch: LINUX_AUDIT_ARCH, + Args: [6]uint64{0x7aabbccdd, math.MaxUint64, math.MaxUint32 - 1}, }, want: linux.SECCOMP_RET_TRAP, }, { desc: "arg disallowed (all equal)", - data: seccompData{ - nr: 1, - arch: LINUX_AUDIT_ARCH, - args: [6]uint64{0x7aabbccdd, math.MaxUint64 - 1, math.MaxUint32}, + data: linux.SeccompData{ + Nr: 1, + Arch: LINUX_AUDIT_ARCH, + Args: [6]uint64{0x7aabbccdd, math.MaxUint64 - 1, math.MaxUint32}, }, want: linux.SECCOMP_RET_TRAP, }, @@ -429,27 +427,27 @@ func TestBasic(t *testing.T) { specs: []spec{ { desc: "high 32bits greater", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000003_00000002}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000003_00000002}}, want: linux.SECCOMP_RET_ALLOW, }, { desc: "high 32bits equal, low 32bits greater", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000003}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000003}}, want: linux.SECCOMP_RET_ALLOW, }, { desc: "high 32bits equal, low 32bits equal", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000002}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000002}}, want: linux.SECCOMP_RET_TRAP, }, { desc: "high 32bits equal, low 32bits less", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000001}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000001}}, want: linux.SECCOMP_RET_TRAP, }, { desc: "high 32bits less", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000001_00000003}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000001_00000003}}, want: linux.SECCOMP_RET_TRAP, }, }, @@ -474,27 +472,27 @@ func TestBasic(t *testing.T) { specs: []spec{ { desc: "arg allowed", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x10, 0xffffffff}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x10, 0xffffffff}}, want: linux.SECCOMP_RET_ALLOW, }, { desc: "arg disallowed (first arg equal)", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0xf, 0xffffffff}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0xf, 0xffffffff}}, want: linux.SECCOMP_RET_TRAP, }, { desc: "arg disallowed (first arg smaller)", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0xffffffff}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x0, 0xffffffff}}, want: linux.SECCOMP_RET_TRAP, }, { desc: "arg disallowed (second arg equal)", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x10, 0xabcd000d}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x10, 0xabcd000d}}, want: linux.SECCOMP_RET_TRAP, }, { desc: "arg disallowed (second arg smaller)", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x10, 0xa000ffff}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x10, 0xa000ffff}}, want: linux.SECCOMP_RET_TRAP, }, }, @@ -522,27 +520,27 @@ func TestBasic(t *testing.T) { specs: []spec{ { desc: "high 32bits greater", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000003_00000002}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000003_00000002}}, want: linux.SECCOMP_RET_ALLOW, }, { desc: "high 32bits equal, low 32bits greater", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000003}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000003}}, want: linux.SECCOMP_RET_ALLOW, }, { desc: "high 32bits equal, low 32bits equal", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000002}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000002}}, want: linux.SECCOMP_RET_ALLOW, }, { desc: "high 32bits equal, low 32bits less", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000001}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000001}}, want: linux.SECCOMP_RET_TRAP, }, { desc: "high 32bits less", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000001_00000002}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000001_00000002}}, want: linux.SECCOMP_RET_TRAP, }, }, @@ -567,32 +565,32 @@ func TestBasic(t *testing.T) { specs: []spec{ { desc: "arg allowed (both greater)", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x10, 0xffffffff}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x10, 0xffffffff}}, want: linux.SECCOMP_RET_ALLOW, }, { desc: "arg allowed (first arg equal)", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0xf, 0xffffffff}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0xf, 0xffffffff}}, want: linux.SECCOMP_RET_ALLOW, }, { desc: "arg disallowed (first arg smaller)", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0xffffffff}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x0, 0xffffffff}}, want: linux.SECCOMP_RET_TRAP, }, { desc: "arg allowed (second arg equal)", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x10, 0xabcd000d}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x10, 0xabcd000d}}, want: linux.SECCOMP_RET_ALLOW, }, { desc: "arg disallowed (second arg smaller)", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x10, 0xa000ffff}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x10, 0xa000ffff}}, want: linux.SECCOMP_RET_TRAP, }, { desc: "arg disallowed (both arg smaller)", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0xa000ffff}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x0, 0xa000ffff}}, want: linux.SECCOMP_RET_TRAP, }, }, @@ -620,27 +618,27 @@ func TestBasic(t *testing.T) { specs: []spec{ { desc: "high 32bits greater", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000003_00000002}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000003_00000002}}, want: linux.SECCOMP_RET_TRAP, }, { desc: "high 32bits equal, low 32bits greater", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000003}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000003}}, want: linux.SECCOMP_RET_TRAP, }, { desc: "high 32bits equal, low 32bits equal", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000002}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000002}}, want: linux.SECCOMP_RET_TRAP, }, { desc: "high 32bits equal, low 32bits less", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000001}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000001}}, want: linux.SECCOMP_RET_ALLOW, }, { desc: "high 32bits less", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000001_00000002}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000001_00000002}}, want: linux.SECCOMP_RET_ALLOW, }, }, @@ -665,32 +663,32 @@ func TestBasic(t *testing.T) { specs: []spec{ { desc: "arg allowed", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0x0}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x0, 0x0}}, want: linux.SECCOMP_RET_ALLOW, }, { desc: "arg disallowed (first arg equal)", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x1, 0x0}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x1, 0x0}}, want: linux.SECCOMP_RET_TRAP, }, { desc: "arg disallowed (first arg greater)", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x2, 0x0}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x2, 0x0}}, want: linux.SECCOMP_RET_TRAP, }, { desc: "arg disallowed (second arg equal)", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0xabcd000d}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x0, 0xabcd000d}}, want: linux.SECCOMP_RET_TRAP, }, { desc: "arg disallowed (second arg greater)", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0xffffffff}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x0, 0xffffffff}}, want: linux.SECCOMP_RET_TRAP, }, { desc: "arg disallowed (both arg greater)", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x2, 0xffffffff}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x2, 0xffffffff}}, want: linux.SECCOMP_RET_TRAP, }, }, @@ -718,27 +716,27 @@ func TestBasic(t *testing.T) { specs: []spec{ { desc: "high 32bits greater", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000003_00000002}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000003_00000002}}, want: linux.SECCOMP_RET_TRAP, }, { desc: "high 32bits equal, low 32bits greater", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000003}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000003}}, want: linux.SECCOMP_RET_TRAP, }, { desc: "high 32bits equal, low 32bits equal", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000002}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000002}}, want: linux.SECCOMP_RET_ALLOW, }, { desc: "high 32bits equal, low 32bits less", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000001}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000002_00000001}}, want: linux.SECCOMP_RET_ALLOW, }, { desc: "high 32bits less", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000001_00000002}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x00000001_00000002}}, want: linux.SECCOMP_RET_ALLOW, }, }, @@ -764,32 +762,32 @@ func TestBasic(t *testing.T) { specs: []spec{ { desc: "arg allowed", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0x0}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x0, 0x0}}, want: linux.SECCOMP_RET_ALLOW, }, { desc: "arg allowed (first arg equal)", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x1, 0x0}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x1, 0x0}}, want: linux.SECCOMP_RET_ALLOW, }, { desc: "arg disallowed (first arg greater)", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x2, 0x0}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x2, 0x0}}, want: linux.SECCOMP_RET_TRAP, }, { desc: "arg allowed (second arg equal)", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0xabcd000d}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x0, 0xabcd000d}}, want: linux.SECCOMP_RET_ALLOW, }, { desc: "arg disallowed (second arg greater)", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0xffffffff}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x0, 0xffffffff}}, want: linux.SECCOMP_RET_TRAP, }, { desc: "arg disallowed (both arg greater)", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x2, 0xffffffff}}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{0x2, 0xffffffff}}, want: linux.SECCOMP_RET_TRAP, }, }, @@ -816,51 +814,51 @@ func TestBasic(t *testing.T) { specs: []spec{ { desc: "arg allowed (low order mandatory bit)", - data: seccompData{ - nr: 1, - arch: LINUX_AUDIT_ARCH, + data: linux.SeccompData{ + Nr: 1, + Arch: LINUX_AUDIT_ARCH, // 00000000 00000000 00000000 00000001 - args: [6]uint64{0x1}, + Args: [6]uint64{0x1}, }, want: linux.SECCOMP_RET_ALLOW, }, { desc: "arg allowed (low order optional bit)", - data: seccompData{ - nr: 1, - arch: LINUX_AUDIT_ARCH, + data: linux.SeccompData{ + Nr: 1, + Arch: LINUX_AUDIT_ARCH, // 00000000 00000000 00000000 00000101 - args: [6]uint64{0x5}, + Args: [6]uint64{0x5}, }, want: linux.SECCOMP_RET_ALLOW, }, { desc: "arg disallowed (lowest order bit not set)", - data: seccompData{ - nr: 1, - arch: LINUX_AUDIT_ARCH, + data: linux.SeccompData{ + Nr: 1, + Arch: LINUX_AUDIT_ARCH, // 00000000 00000000 00000000 00000010 - args: [6]uint64{0x2}, + Args: [6]uint64{0x2}, }, want: linux.SECCOMP_RET_TRAP, }, { desc: "arg disallowed (second lowest order bit set)", - data: seccompData{ - nr: 1, - arch: LINUX_AUDIT_ARCH, + data: linux.SeccompData{ + Nr: 1, + Arch: LINUX_AUDIT_ARCH, // 00000000 00000000 00000000 00000011 - args: [6]uint64{0x3}, + Args: [6]uint64{0x3}, }, want: linux.SECCOMP_RET_TRAP, }, { desc: "arg disallowed (8th bit set)", - data: seccompData{ - nr: 1, - arch: LINUX_AUDIT_ARCH, + data: linux.SeccompData{ + Nr: 1, + Arch: LINUX_AUDIT_ARCH, // 00000000 00000000 00000001 00000000 - args: [6]uint64{0x100}, + Args: [6]uint64{0x100}, }, want: linux.SECCOMP_RET_TRAP, }, @@ -885,12 +883,12 @@ func TestBasic(t *testing.T) { specs: []spec{ { desc: "allowed", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{}, instructionPointer: 0x7aabbccdd}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{}, InstructionPointer: 0x7aabbccdd}, want: linux.SECCOMP_RET_ALLOW, }, { desc: "disallowed", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{}, instructionPointer: 0x711223344}, + data: linux.SeccompData{Nr: 1, Arch: LINUX_AUDIT_ARCH, Args: [6]uint64{}, InstructionPointer: 0x711223344}, want: linux.SECCOMP_RET_TRAP, }, }, @@ -906,7 +904,7 @@ func TestBasic(t *testing.T) { t.Fatalf("bpf.Compile() got error: %v", err) } for _, spec := range test.specs { - got, err := bpf.Exec(p, spec.data.asInput()) + got, err := bpf.Exec(p, dataAsInput(&spec.data)) if err != nil { t.Fatalf("%s: bpf.Exec() got error: %v", spec.desc, err) } @@ -947,8 +945,8 @@ func TestRandom(t *testing.T) { t.Fatalf("bpf.Compile() got error: %v", err) } for i := uint32(0); i < 200; i++ { - data := seccompData{nr: i, arch: LINUX_AUDIT_ARCH} - got, err := bpf.Exec(p, data.asInput()) + data := linux.SeccompData{Nr: int32(i), Arch: LINUX_AUDIT_ARCH} + got, err := bpf.Exec(p, dataAsInput(&data)) if err != nil { t.Errorf("bpf.Exec() got error: %v, for syscall %d", err, i) continue diff --git a/pkg/sentry/arch/arch_aarch64.go b/pkg/sentry/arch/arch_aarch64.go index 0f433ee79..fd73751e7 100644 --- a/pkg/sentry/arch/arch_aarch64.go +++ b/pkg/sentry/arch/arch_aarch64.go @@ -154,6 +154,7 @@ func (s State) Proto() *rpb.Registers { Sp: s.Regs.Sp, Pc: s.Regs.Pc, Pstate: s.Regs.Pstate, + Tls: s.Regs.TPIDR_EL0, } return &rpb.Registers{Arch: &rpb.Registers_Arm64{Arm64: regs}} } @@ -232,6 +233,7 @@ func (s *State) RegisterMap() (map[string]uintptr, error) { "Sp": uintptr(s.Regs.Sp), "Pc": uintptr(s.Regs.Pc), "Pstate": uintptr(s.Regs.Pstate), + "Tls": uintptr(s.Regs.TPIDR_EL0), }, nil } diff --git a/pkg/sentry/arch/registers.proto b/pkg/sentry/arch/registers.proto index 60c027aab..2727ba08a 100644 --- a/pkg/sentry/arch/registers.proto +++ b/pkg/sentry/arch/registers.proto @@ -83,6 +83,7 @@ message ARM64Registers { uint64 sp = 32; uint64 pc = 33; uint64 pstate = 34; + uint64 tls = 35; } message Registers { oneof arch { diff --git a/pkg/sentry/devices/tundev/BUILD b/pkg/sentry/devices/tundev/BUILD index 71c59287c..14a8bf9cd 100644 --- a/pkg/sentry/devices/tundev/BUILD +++ b/pkg/sentry/devices/tundev/BUILD @@ -17,6 +17,7 @@ go_library( "//pkg/sentry/vfs", "//pkg/syserror", "//pkg/tcpip/link/tun", + "//pkg/tcpip/network/arp", "//pkg/usermem", "//pkg/waiter", ], diff --git a/pkg/sentry/devices/tundev/tundev.go b/pkg/sentry/devices/tundev/tundev.go index 0b701a289..655ea549b 100644 --- a/pkg/sentry/devices/tundev/tundev.go +++ b/pkg/sentry/devices/tundev/tundev.go @@ -16,6 +16,8 @@ package tundev import ( + "fmt" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/arch" @@ -26,6 +28,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip/link/tun" + "gvisor.dev/gvisor/pkg/tcpip/network/arp" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) @@ -84,7 +87,16 @@ func (fd *tunFD) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArg return 0, err } flags := usermem.ByteOrder.Uint16(req.Data[:]) - return 0, fd.device.SetIff(stack.Stack, req.Name(), flags) + created, err := fd.device.SetIff(stack.Stack, req.Name(), flags) + if err == nil && created { + // Always start with an ARP address for interfaces so they can handle ARP + // packets. + nicID := fd.device.NICID() + if err := stack.Stack.AddAddress(nicID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { + panic(fmt.Sprintf("failed to add ARP address after creating new TUN/TAP interface with ID = %d", nicID)) + } + } + return 0, err case linux.TUNGETIFF: var req linux.IFReq diff --git a/pkg/sentry/fs/dev/BUILD b/pkg/sentry/fs/dev/BUILD index 9379a4d7b..6b7b451b8 100644 --- a/pkg/sentry/fs/dev/BUILD +++ b/pkg/sentry/fs/dev/BUILD @@ -34,6 +34,7 @@ go_library( "//pkg/sentry/socket/netstack", "//pkg/syserror", "//pkg/tcpip/link/tun", + "//pkg/tcpip/network/arp", "//pkg/usermem", "//pkg/waiter", ], diff --git a/pkg/sentry/fs/dev/net_tun.go b/pkg/sentry/fs/dev/net_tun.go index 5f8c9b5a2..19ffdec47 100644 --- a/pkg/sentry/fs/dev/net_tun.go +++ b/pkg/sentry/fs/dev/net_tun.go @@ -15,6 +15,8 @@ package dev import ( + "fmt" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/arch" @@ -25,6 +27,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket/netstack" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip/link/tun" + "gvisor.dev/gvisor/pkg/tcpip/network/arp" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) @@ -60,7 +63,7 @@ func newNetTunDevice(ctx context.Context, owner fs.FileOwner, mode linux.FileMod } // GetFile implements fs.InodeOperations.GetFile. -func (iops *netTunInodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) { +func (*netTunInodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) { return fs.NewFile(ctx, d, flags, &netTunFileOperations{}), nil } @@ -80,12 +83,12 @@ type netTunFileOperations struct { var _ fs.FileOperations = (*netTunFileOperations)(nil) // Release implements fs.FileOperations.Release. -func (fops *netTunFileOperations) Release(ctx context.Context) { - fops.device.Release(ctx) +func (n *netTunFileOperations) Release(ctx context.Context) { + n.device.Release(ctx) } // Ioctl implements fs.FileOperations.Ioctl. -func (fops *netTunFileOperations) Ioctl(ctx context.Context, file *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { +func (n *netTunFileOperations) Ioctl(ctx context.Context, file *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { request := args[1].Uint() data := args[2].Pointer() @@ -109,16 +112,25 @@ func (fops *netTunFileOperations) Ioctl(ctx context.Context, file *fs.File, io u return 0, err } flags := usermem.ByteOrder.Uint16(req.Data[:]) - return 0, fops.device.SetIff(stack.Stack, req.Name(), flags) + created, err := n.device.SetIff(stack.Stack, req.Name(), flags) + if err == nil && created { + // Always start with an ARP address for interfaces so they can handle ARP + // packets. + nicID := n.device.NICID() + if err := stack.Stack.AddAddress(nicID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { + panic(fmt.Sprintf("failed to add ARP address after creating new TUN/TAP interface with ID = %d", nicID)) + } + } + return 0, err case linux.TUNGETIFF: var req linux.IFReq - copy(req.IFName[:], fops.device.Name()) + copy(req.IFName[:], n.device.Name()) // Linux adds IFF_NOFILTER (the same value as IFF_NO_PI unfortunately) when // there is no sk_filter. See __tun_chr_ioctl() in net/drivers/tun.c. - flags := fops.device.Flags() | linux.IFF_NOFILTER + flags := n.device.Flags() | linux.IFF_NOFILTER usermem.ByteOrder.PutUint16(req.Data[:], flags) _, err := req.CopyOut(t, data) @@ -130,41 +142,41 @@ func (fops *netTunFileOperations) Ioctl(ctx context.Context, file *fs.File, io u } // Write implements fs.FileOperations.Write. -func (fops *netTunFileOperations) Write(ctx context.Context, file *fs.File, src usermem.IOSequence, offset int64) (int64, error) { +func (n *netTunFileOperations) Write(ctx context.Context, file *fs.File, src usermem.IOSequence, offset int64) (int64, error) { data := make([]byte, src.NumBytes()) if _, err := src.CopyIn(ctx, data); err != nil { return 0, err } - return fops.device.Write(data) + return n.device.Write(data) } // Read implements fs.FileOperations.Read. -func (fops *netTunFileOperations) Read(ctx context.Context, file *fs.File, dst usermem.IOSequence, offset int64) (int64, error) { - data, err := fops.device.Read() +func (n *netTunFileOperations) Read(ctx context.Context, file *fs.File, dst usermem.IOSequence, offset int64) (int64, error) { + data, err := n.device.Read() if err != nil { return 0, err } - n, err := dst.CopyOut(ctx, data) - if n > 0 && n < len(data) { + bytesCopied, err := dst.CopyOut(ctx, data) + if bytesCopied > 0 && bytesCopied < len(data) { // Not an error for partial copying. Packet truncated. err = nil } - return int64(n), err + return int64(bytesCopied), err } // Readiness implements watier.Waitable.Readiness. -func (fops *netTunFileOperations) Readiness(mask waiter.EventMask) waiter.EventMask { - return fops.device.Readiness(mask) +func (n *netTunFileOperations) Readiness(mask waiter.EventMask) waiter.EventMask { + return n.device.Readiness(mask) } // EventRegister implements watier.Waitable.EventRegister. -func (fops *netTunFileOperations) EventRegister(e *waiter.Entry, mask waiter.EventMask) { - fops.device.EventRegister(e, mask) +func (n *netTunFileOperations) EventRegister(e *waiter.Entry, mask waiter.EventMask) { + n.device.EventRegister(e, mask) } // EventUnregister implements watier.Waitable.EventUnregister. -func (fops *netTunFileOperations) EventUnregister(e *waiter.Entry) { - fops.device.EventUnregister(e) +func (n *netTunFileOperations) EventUnregister(e *waiter.Entry) { + n.device.EventUnregister(e) } // isNetTunSupported returns whether /dev/net/tun device is supported for s. diff --git a/pkg/sentry/fs/fsutil/inode_cached.go b/pkg/sentry/fs/fsutil/inode_cached.go index 9eb6f522e..85a23432b 100644 --- a/pkg/sentry/fs/fsutil/inode_cached.go +++ b/pkg/sentry/fs/fsutil/inode_cached.go @@ -22,7 +22,6 @@ import ( "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/fs" - "gvisor.dev/gvisor/pkg/sentry/kernel/time" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/pgalloc" @@ -444,7 +443,7 @@ func (c *CachingInodeOperations) TouchAccessTime(ctx context.Context, inode *fs. // time. // // Preconditions: c.attrMu is locked for writing. -func (c *CachingInodeOperations) touchAccessTimeLocked(now time.Time) { +func (c *CachingInodeOperations) touchAccessTimeLocked(now ktime.Time) { c.attr.AccessTime = now c.dirtyAttr.AccessTime = true } @@ -461,7 +460,7 @@ func (c *CachingInodeOperations) TouchModificationAndStatusChangeTime(ctx contex // and status change times in-place to the current time. // // Preconditions: c.attrMu is locked for writing. -func (c *CachingInodeOperations) touchModificationAndStatusChangeTimeLocked(now time.Time) { +func (c *CachingInodeOperations) touchModificationAndStatusChangeTimeLocked(now ktime.Time) { c.attr.ModificationTime = now c.dirtyAttr.ModificationTime = true c.attr.StatusChangeTime = now @@ -480,7 +479,7 @@ func (c *CachingInodeOperations) TouchStatusChangeTime(ctx context.Context) { // in-place to the current time. // // Preconditions: c.attrMu is locked for writing. -func (c *CachingInodeOperations) touchStatusChangeTimeLocked(now time.Time) { +func (c *CachingInodeOperations) touchStatusChangeTimeLocked(now ktime.Time) { c.attr.StatusChangeTime = now c.dirtyAttr.StatusChangeTime = true } @@ -672,9 +671,6 @@ func (rw *inodeReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) { // Continue. seg, gap = gap.NextSegment(), FileRangeGapIterator{} } - - default: - break } } unlock() @@ -768,9 +764,6 @@ func (rw *inodeReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error // Continue. seg, gap = gap.NextSegment(), FileRangeGapIterator{} - - default: - break } } rw.maybeGrowFile() diff --git a/pkg/sentry/fsbridge/vfs.go b/pkg/sentry/fsbridge/vfs.go index 323506d33..be0900030 100644 --- a/pkg/sentry/fsbridge/vfs.go +++ b/pkg/sentry/fsbridge/vfs.go @@ -122,7 +122,7 @@ func NewVFSLookup(mntns *vfs.MountNamespace, root, workingDir vfs.VirtualDentry) // remainingTraversals is not configurable in VFS2, all callers are using the // default anyways. func (l *vfsLookup) OpenPath(ctx context.Context, pathname string, opts vfs.OpenOptions, _ *uint, resolveFinal bool) (File, error) { - vfsObj := l.mntns.Root().Mount().Filesystem().VirtualFilesystem() + vfsObj := l.root.Mount().Filesystem().VirtualFilesystem() creds := auth.CredentialsFromContext(ctx) path := fspath.Parse(pathname) pop := &vfs.PathOperation{ diff --git a/pkg/sentry/fsimpl/ext/BUILD b/pkg/sentry/fsimpl/ext/BUILD index abc610ef3..7b1eec3da 100644 --- a/pkg/sentry/fsimpl/ext/BUILD +++ b/pkg/sentry/fsimpl/ext/BUILD @@ -51,6 +51,8 @@ go_library( "//pkg/fd", "//pkg/fspath", "//pkg/log", + "//pkg/marshal", + "//pkg/marshal/primitive", "//pkg/safemem", "//pkg/sentry/arch", "//pkg/sentry/fs", @@ -86,9 +88,9 @@ go_test( library = ":ext", deps = [ "//pkg/abi/linux", - "//pkg/binary", "//pkg/context", "//pkg/fspath", + "//pkg/marshal/primitive", "//pkg/sentry/contexttest", "//pkg/sentry/fsimpl/ext/disklayout", "//pkg/sentry/kernel/auth", diff --git a/pkg/sentry/fsimpl/ext/block_map_file.go b/pkg/sentry/fsimpl/ext/block_map_file.go index 8bb104ff0..1165234f9 100644 --- a/pkg/sentry/fsimpl/ext/block_map_file.go +++ b/pkg/sentry/fsimpl/ext/block_map_file.go @@ -18,7 +18,7 @@ import ( "io" "math" - "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/syserror" ) @@ -34,19 +34,19 @@ type blockMapFile struct { // directBlks are the direct blocks numbers. The physical blocks pointed by // these holds file data. Contains file blocks 0 to 11. - directBlks [numDirectBlks]uint32 + directBlks [numDirectBlks]primitive.Uint32 // indirectBlk is the physical block which contains (blkSize/4) direct block // numbers (as uint32 integers). - indirectBlk uint32 + indirectBlk primitive.Uint32 // doubleIndirectBlk is the physical block which contains (blkSize/4) indirect // block numbers (as uint32 integers). - doubleIndirectBlk uint32 + doubleIndirectBlk primitive.Uint32 // tripleIndirectBlk is the physical block which contains (blkSize/4) doubly // indirect block numbers (as uint32 integers). - tripleIndirectBlk uint32 + tripleIndirectBlk primitive.Uint32 // coverage at (i)th index indicates the amount of file data a node at // height (i) covers. Height 0 is the direct block. @@ -68,10 +68,12 @@ func newBlockMapFile(args inodeArgs) (*blockMapFile, error) { } blkMap := file.regFile.inode.diskInode.Data() - binary.Unmarshal(blkMap[:numDirectBlks*4], binary.LittleEndian, &file.directBlks) - binary.Unmarshal(blkMap[numDirectBlks*4:(numDirectBlks+1)*4], binary.LittleEndian, &file.indirectBlk) - binary.Unmarshal(blkMap[(numDirectBlks+1)*4:(numDirectBlks+2)*4], binary.LittleEndian, &file.doubleIndirectBlk) - binary.Unmarshal(blkMap[(numDirectBlks+2)*4:(numDirectBlks+3)*4], binary.LittleEndian, &file.tripleIndirectBlk) + for i := 0; i < numDirectBlks; i++ { + file.directBlks[i].UnmarshalBytes(blkMap[i*4 : (i+1)*4]) + } + file.indirectBlk.UnmarshalBytes(blkMap[numDirectBlks*4 : (numDirectBlks+1)*4]) + file.doubleIndirectBlk.UnmarshalBytes(blkMap[(numDirectBlks+1)*4 : (numDirectBlks+2)*4]) + file.tripleIndirectBlk.UnmarshalBytes(blkMap[(numDirectBlks+2)*4 : (numDirectBlks+3)*4]) return file, nil } @@ -117,16 +119,16 @@ func (f *blockMapFile) ReadAt(dst []byte, off int64) (int, error) { switch { case offset < dirBlksEnd: // Direct block. - curR, err = f.read(f.directBlks[offset/f.regFile.inode.blkSize], offset%f.regFile.inode.blkSize, 0, dst[read:]) + curR, err = f.read(uint32(f.directBlks[offset/f.regFile.inode.blkSize]), offset%f.regFile.inode.blkSize, 0, dst[read:]) case offset < indirBlkEnd: // Indirect block. - curR, err = f.read(f.indirectBlk, offset-dirBlksEnd, 1, dst[read:]) + curR, err = f.read(uint32(f.indirectBlk), offset-dirBlksEnd, 1, dst[read:]) case offset < doubIndirBlkEnd: // Doubly indirect block. - curR, err = f.read(f.doubleIndirectBlk, offset-indirBlkEnd, 2, dst[read:]) + curR, err = f.read(uint32(f.doubleIndirectBlk), offset-indirBlkEnd, 2, dst[read:]) default: // Triply indirect block. - curR, err = f.read(f.tripleIndirectBlk, offset-doubIndirBlkEnd, 3, dst[read:]) + curR, err = f.read(uint32(f.tripleIndirectBlk), offset-doubIndirBlkEnd, 3, dst[read:]) } read += curR @@ -174,13 +176,13 @@ func (f *blockMapFile) read(curPhyBlk uint32, relFileOff uint64, height uint, ds read := 0 curChildOff := relFileOff % childCov for i := startIdx; i < endIdx; i++ { - var childPhyBlk uint32 + var childPhyBlk primitive.Uint32 err := readFromDisk(f.regFile.inode.fs.dev, curPhyBlkOff+int64(i*4), &childPhyBlk) if err != nil { return read, err } - n, err := f.read(childPhyBlk, curChildOff, height-1, dst[read:]) + n, err := f.read(uint32(childPhyBlk), curChildOff, height-1, dst[read:]) read += n if err != nil { return read, err diff --git a/pkg/sentry/fsimpl/ext/block_map_test.go b/pkg/sentry/fsimpl/ext/block_map_test.go index 6fa84e7aa..ed98b482e 100644 --- a/pkg/sentry/fsimpl/ext/block_map_test.go +++ b/pkg/sentry/fsimpl/ext/block_map_test.go @@ -20,7 +20,7 @@ import ( "testing" "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout" ) @@ -87,29 +87,33 @@ func blockMapSetUp(t *testing.T) (*blockMapFile, []byte) { mockDisk := make([]byte, mockBMDiskSize) var fileData []byte blkNums := newBlkNumGen() - var data []byte + off := 0 + data := make([]byte, (numDirectBlks+3)*(*primitive.Uint32)(nil).SizeBytes()) // Write the direct blocks. for i := 0; i < numDirectBlks; i++ { - curBlkNum := blkNums.next() - data = binary.Marshal(data, binary.LittleEndian, curBlkNum) - fileData = append(fileData, writeFileDataToBlock(mockDisk, curBlkNum, 0, blkNums)...) + curBlkNum := primitive.Uint32(blkNums.next()) + curBlkNum.MarshalBytes(data[off:]) + off += curBlkNum.SizeBytes() + fileData = append(fileData, writeFileDataToBlock(mockDisk, uint32(curBlkNum), 0, blkNums)...) } // Write to indirect block. - indirectBlk := blkNums.next() - data = binary.Marshal(data, binary.LittleEndian, indirectBlk) - fileData = append(fileData, writeFileDataToBlock(mockDisk, indirectBlk, 1, blkNums)...) - - // Write to indirect block. - doublyIndirectBlk := blkNums.next() - data = binary.Marshal(data, binary.LittleEndian, doublyIndirectBlk) - fileData = append(fileData, writeFileDataToBlock(mockDisk, doublyIndirectBlk, 2, blkNums)...) - - // Write to indirect block. - triplyIndirectBlk := blkNums.next() - data = binary.Marshal(data, binary.LittleEndian, triplyIndirectBlk) - fileData = append(fileData, writeFileDataToBlock(mockDisk, triplyIndirectBlk, 3, blkNums)...) + indirectBlk := primitive.Uint32(blkNums.next()) + indirectBlk.MarshalBytes(data[off:]) + off += indirectBlk.SizeBytes() + fileData = append(fileData, writeFileDataToBlock(mockDisk, uint32(indirectBlk), 1, blkNums)...) + + // Write to double indirect block. + doublyIndirectBlk := primitive.Uint32(blkNums.next()) + doublyIndirectBlk.MarshalBytes(data[off:]) + off += doublyIndirectBlk.SizeBytes() + fileData = append(fileData, writeFileDataToBlock(mockDisk, uint32(doublyIndirectBlk), 2, blkNums)...) + + // Write to triple indirect block. + triplyIndirectBlk := primitive.Uint32(blkNums.next()) + triplyIndirectBlk.MarshalBytes(data[off:]) + fileData = append(fileData, writeFileDataToBlock(mockDisk, uint32(triplyIndirectBlk), 3, blkNums)...) args := inodeArgs{ fs: &filesystem{ @@ -142,9 +146,9 @@ func writeFileDataToBlock(disk []byte, blkNum uint32, height uint, blkNums *blkN var fileData []byte for off := blkNum * mockBMBlkSize; off < (blkNum+1)*mockBMBlkSize; off += 4 { - curBlkNum := blkNums.next() - copy(disk[off:off+4], binary.Marshal(nil, binary.LittleEndian, curBlkNum)) - fileData = append(fileData, writeFileDataToBlock(disk, curBlkNum, height-1, blkNums)...) + curBlkNum := primitive.Uint32(blkNums.next()) + curBlkNum.MarshalBytes(disk[off : off+4]) + fileData = append(fileData, writeFileDataToBlock(disk, uint32(curBlkNum), height-1, blkNums)...) } return fileData } diff --git a/pkg/sentry/fsimpl/ext/directory.go b/pkg/sentry/fsimpl/ext/directory.go index 452450d82..0ad79b381 100644 --- a/pkg/sentry/fsimpl/ext/directory.go +++ b/pkg/sentry/fsimpl/ext/directory.go @@ -16,7 +16,6 @@ package ext import ( "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/fs" @@ -100,7 +99,7 @@ func newDirectory(args inodeArgs, newDirent bool) (*directory, error) { } else { curDirent.diskDirent = &disklayout.DirentOld{} } - binary.Unmarshal(buf, binary.LittleEndian, curDirent.diskDirent) + curDirent.diskDirent.UnmarshalBytes(buf) if curDirent.diskDirent.Inode() != 0 && len(curDirent.diskDirent.FileName()) != 0 { // Inode number and name length fields being set to 0 is used to indicate diff --git a/pkg/sentry/fsimpl/ext/disklayout/BUILD b/pkg/sentry/fsimpl/ext/disklayout/BUILD index 9bd9c76c0..d98a05dd8 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/BUILD +++ b/pkg/sentry/fsimpl/ext/disklayout/BUILD @@ -22,10 +22,11 @@ go_library( "superblock_old.go", "test_utils.go", ], + marshal = True, visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", - "//pkg/binary", + "//pkg/marshal", "//pkg/sentry/fs", "//pkg/sentry/kernel/auth", "//pkg/sentry/kernel/time", diff --git a/pkg/sentry/fsimpl/ext/disklayout/block_group.go b/pkg/sentry/fsimpl/ext/disklayout/block_group.go index ad6f4fef8..0d56ae9da 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/block_group.go +++ b/pkg/sentry/fsimpl/ext/disklayout/block_group.go @@ -14,6 +14,10 @@ package disklayout +import ( + "gvisor.dev/gvisor/pkg/marshal" +) + // BlockGroup represents a Linux ext block group descriptor. An ext file system // is split into a series of block groups. This provides an access layer to // information needed to access and use a block group. @@ -30,6 +34,8 @@ package disklayout // // See https://www.kernel.org/doc/html/latest/filesystems/ext4/globals.html#block-group-descriptors. type BlockGroup interface { + marshal.Marshallable + // InodeTable returns the absolute block number of the block containing the // inode table. This points to an array of Inode structs. Inode tables are // statically allocated at mkfs time. The superblock records the number of diff --git a/pkg/sentry/fsimpl/ext/disklayout/block_group_32.go b/pkg/sentry/fsimpl/ext/disklayout/block_group_32.go index 3e16c76db..a35fa22a0 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/block_group_32.go +++ b/pkg/sentry/fsimpl/ext/disklayout/block_group_32.go @@ -17,6 +17,8 @@ package disklayout // BlockGroup32Bit emulates the first half of struct ext4_group_desc in // fs/ext4/ext4.h. It is the block group descriptor struct for ext2, ext3 and // 32-bit ext4 filesystems. It implements BlockGroup interface. +// +// +marshal type BlockGroup32Bit struct { BlockBitmapLo uint32 InodeBitmapLo uint32 diff --git a/pkg/sentry/fsimpl/ext/disklayout/block_group_64.go b/pkg/sentry/fsimpl/ext/disklayout/block_group_64.go index 9a809197a..d54d1d345 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/block_group_64.go +++ b/pkg/sentry/fsimpl/ext/disklayout/block_group_64.go @@ -18,6 +18,8 @@ package disklayout // It is the block group descriptor struct for 64-bit ext4 filesystems. // It implements BlockGroup interface. It is an extension of the 32-bit // version of BlockGroup. +// +// +marshal type BlockGroup64Bit struct { // We embed the 32-bit struct here because 64-bit version is just an extension // of the 32-bit version. diff --git a/pkg/sentry/fsimpl/ext/disklayout/block_group_test.go b/pkg/sentry/fsimpl/ext/disklayout/block_group_test.go index 0ef4294c0..e4ce484e4 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/block_group_test.go +++ b/pkg/sentry/fsimpl/ext/disklayout/block_group_test.go @@ -21,6 +21,8 @@ import ( // TestBlockGroupSize tests that the block group descriptor structs are of the // correct size. func TestBlockGroupSize(t *testing.T) { - assertSize(t, BlockGroup32Bit{}, 32) - assertSize(t, BlockGroup64Bit{}, 64) + var bgSmall BlockGroup32Bit + assertSize(t, &bgSmall, 32) + var bgBig BlockGroup64Bit + assertSize(t, &bgBig, 64) } diff --git a/pkg/sentry/fsimpl/ext/disklayout/dirent.go b/pkg/sentry/fsimpl/ext/disklayout/dirent.go index 417b6cf65..568c8cb4c 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/dirent.go +++ b/pkg/sentry/fsimpl/ext/disklayout/dirent.go @@ -15,6 +15,7 @@ package disklayout import ( + "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/sentry/fs" ) @@ -51,6 +52,8 @@ var ( // // See https://www.kernel.org/doc/html/latest/filesystems/ext4/dynamic.html#linear-classic-directories. type Dirent interface { + marshal.Marshallable + // Inode returns the absolute inode number of the underlying inode. // Inode number 0 signifies an unused dirent. Inode() uint32 diff --git a/pkg/sentry/fsimpl/ext/disklayout/dirent_new.go b/pkg/sentry/fsimpl/ext/disklayout/dirent_new.go index 29ae4a5c2..51f9c2946 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/dirent_new.go +++ b/pkg/sentry/fsimpl/ext/disklayout/dirent_new.go @@ -29,12 +29,14 @@ import ( // Note: This struct can be of variable size on disk. The one described below // is of maximum size and the FileName beyond NameLength bytes might contain // garbage. +// +// +marshal type DirentNew struct { InodeNumber uint32 RecordLength uint16 NameLength uint8 FileTypeRaw uint8 - FileNameRaw [MaxFileName]byte + FileNameRaw [MaxFileName]byte `marshal:"unaligned"` } // Compiles only if DirentNew implements Dirent. diff --git a/pkg/sentry/fsimpl/ext/disklayout/dirent_old.go b/pkg/sentry/fsimpl/ext/disklayout/dirent_old.go index 6fff12a6e..d4b19e086 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/dirent_old.go +++ b/pkg/sentry/fsimpl/ext/disklayout/dirent_old.go @@ -22,11 +22,13 @@ import "gvisor.dev/gvisor/pkg/sentry/fs" // Note: This struct can be of variable size on disk. The one described below // is of maximum size and the FileName beyond NameLength bytes might contain // garbage. +// +// +marshal type DirentOld struct { InodeNumber uint32 RecordLength uint16 NameLength uint16 - FileNameRaw [MaxFileName]byte + FileNameRaw [MaxFileName]byte `marshal:"unaligned"` } // Compiles only if DirentOld implements Dirent. diff --git a/pkg/sentry/fsimpl/ext/disklayout/dirent_test.go b/pkg/sentry/fsimpl/ext/disklayout/dirent_test.go index 934919f8a..3486864dc 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/dirent_test.go +++ b/pkg/sentry/fsimpl/ext/disklayout/dirent_test.go @@ -21,6 +21,8 @@ import ( // TestDirentSize tests that the dirent structs are of the correct // size. func TestDirentSize(t *testing.T) { - assertSize(t, DirentOld{}, uintptr(DirentSize)) - assertSize(t, DirentNew{}, uintptr(DirentSize)) + var dOld DirentOld + assertSize(t, &dOld, DirentSize) + var dNew DirentNew + assertSize(t, &dNew, DirentSize) } diff --git a/pkg/sentry/fsimpl/ext/disklayout/disklayout.go b/pkg/sentry/fsimpl/ext/disklayout/disklayout.go index bdf4e2132..0834e9ba8 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/disklayout.go +++ b/pkg/sentry/fsimpl/ext/disklayout/disklayout.go @@ -36,8 +36,6 @@ // escape analysis on an unknown implementation at compile time. // // Notes: -// - All fields in these structs are exported because binary.Read would -// panic otherwise. // - All structures on disk are in little-endian order. Only jbd2 (journal) // structures are in big-endian order. // - All OS dependent fields in these structures will be interpretted using diff --git a/pkg/sentry/fsimpl/ext/disklayout/extent.go b/pkg/sentry/fsimpl/ext/disklayout/extent.go index 4110649ab..b13999bfc 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/extent.go +++ b/pkg/sentry/fsimpl/ext/disklayout/extent.go @@ -14,6 +14,10 @@ package disklayout +import ( + "gvisor.dev/gvisor/pkg/marshal" +) + // Extents were introduced in ext4 and provide huge performance gains in terms // data locality and reduced metadata block usage. Extents are organized in // extent trees. The root node is contained in inode.BlocksRaw. @@ -64,6 +68,8 @@ type ExtentNode struct { // ExtentEntry represents an extent tree node entry. The entry can either be // an ExtentIdx or Extent itself. This exists to simplify navigation logic. type ExtentEntry interface { + marshal.Marshallable + // FileBlock returns the first file block number covered by this entry. FileBlock() uint32 @@ -75,6 +81,8 @@ type ExtentEntry interface { // tree node begins with this and is followed by `NumEntries` number of: // - Extent if `Depth` == 0 // - ExtentIdx otherwise +// +// +marshal type ExtentHeader struct { // Magic in the extent magic number, must be 0xf30a. Magic uint16 @@ -96,6 +104,8 @@ type ExtentHeader struct { // internal nodes. Sorted in ascending order based on FirstFileBlock since // Linux does a binary search on this. This points to a block containing the // child node. +// +// +marshal type ExtentIdx struct { FirstFileBlock uint32 ChildBlockLo uint32 @@ -121,6 +131,8 @@ func (ei *ExtentIdx) PhysicalBlock() uint64 { // nodes. Sorted in ascending order based on FirstFileBlock since Linux does a // binary search on this. This points to an array of data blocks containing the // file data. It covers `Length` data blocks starting from `StartBlock`. +// +// +marshal type Extent struct { FirstFileBlock uint32 Length uint16 diff --git a/pkg/sentry/fsimpl/ext/disklayout/extent_test.go b/pkg/sentry/fsimpl/ext/disklayout/extent_test.go index 8762b90db..c96002e19 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/extent_test.go +++ b/pkg/sentry/fsimpl/ext/disklayout/extent_test.go @@ -21,7 +21,10 @@ import ( // TestExtentSize tests that the extent structs are of the correct // size. func TestExtentSize(t *testing.T) { - assertSize(t, ExtentHeader{}, ExtentHeaderSize) - assertSize(t, ExtentIdx{}, ExtentEntrySize) - assertSize(t, Extent{}, ExtentEntrySize) + var h ExtentHeader + assertSize(t, &h, ExtentHeaderSize) + var i ExtentIdx + assertSize(t, &i, ExtentEntrySize) + var e Extent + assertSize(t, &e, ExtentEntrySize) } diff --git a/pkg/sentry/fsimpl/ext/disklayout/inode.go b/pkg/sentry/fsimpl/ext/disklayout/inode.go index 88ae913f5..ef25040a9 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/inode.go +++ b/pkg/sentry/fsimpl/ext/disklayout/inode.go @@ -16,6 +16,7 @@ package disklayout import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/kernel/time" ) @@ -38,6 +39,8 @@ const ( // // See https://www.kernel.org/doc/html/latest/filesystems/ext4/dynamic.html#index-nodes. type Inode interface { + marshal.Marshallable + // Mode returns the linux file mode which is majorly used to extract // information like: // - File permissions (read/write/execute by user/group/others). diff --git a/pkg/sentry/fsimpl/ext/disklayout/inode_new.go b/pkg/sentry/fsimpl/ext/disklayout/inode_new.go index 8f9f574ce..a4503f5cf 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/inode_new.go +++ b/pkg/sentry/fsimpl/ext/disklayout/inode_new.go @@ -27,6 +27,8 @@ import "gvisor.dev/gvisor/pkg/sentry/kernel/time" // are used to provide nanoscond precision. Hence, these timestamps will now // overflow in May 2446. // See https://www.kernel.org/doc/html/latest/filesystems/ext4/dynamic.html#inode-timestamps. +// +// +marshal type InodeNew struct { InodeOld diff --git a/pkg/sentry/fsimpl/ext/disklayout/inode_old.go b/pkg/sentry/fsimpl/ext/disklayout/inode_old.go index db25b11b6..e6b28babf 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/inode_old.go +++ b/pkg/sentry/fsimpl/ext/disklayout/inode_old.go @@ -30,6 +30,8 @@ const ( // // All fields representing time are in seconds since the epoch. Which means that // they will overflow in January 2038. +// +// +marshal type InodeOld struct { ModeRaw uint16 UIDLo uint16 diff --git a/pkg/sentry/fsimpl/ext/disklayout/inode_test.go b/pkg/sentry/fsimpl/ext/disklayout/inode_test.go index dd03ee50e..90744e956 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/inode_test.go +++ b/pkg/sentry/fsimpl/ext/disklayout/inode_test.go @@ -24,10 +24,12 @@ import ( // TestInodeSize tests that the inode structs are of the correct size. func TestInodeSize(t *testing.T) { - assertSize(t, InodeOld{}, OldInodeSize) + var iOld InodeOld + assertSize(t, &iOld, OldInodeSize) // This was updated from 156 bytes to 160 bytes in Oct 2015. - assertSize(t, InodeNew{}, 160) + var iNew InodeNew + assertSize(t, &iNew, 160) } // TestTimestampSeconds tests that the seconds part of [a/c/m] timestamps in diff --git a/pkg/sentry/fsimpl/ext/disklayout/superblock.go b/pkg/sentry/fsimpl/ext/disklayout/superblock.go index 8bb327006..70948ebe9 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/superblock.go +++ b/pkg/sentry/fsimpl/ext/disklayout/superblock.go @@ -14,6 +14,10 @@ package disklayout +import ( + "gvisor.dev/gvisor/pkg/marshal" +) + const ( // SbOffset is the absolute offset at which the superblock is placed. SbOffset = 1024 @@ -38,6 +42,8 @@ const ( // // See https://www.kernel.org/doc/html/latest/filesystems/ext4/globals.html#super-block. type SuperBlock interface { + marshal.Marshallable + // InodesCount returns the total number of inodes in this filesystem. InodesCount() uint32 diff --git a/pkg/sentry/fsimpl/ext/disklayout/superblock_32.go b/pkg/sentry/fsimpl/ext/disklayout/superblock_32.go index 53e515fd3..4dc6080fb 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/superblock_32.go +++ b/pkg/sentry/fsimpl/ext/disklayout/superblock_32.go @@ -17,6 +17,8 @@ package disklayout // SuperBlock32Bit implements SuperBlock and represents the 32-bit version of // the ext4_super_block struct in fs/ext4/ext4.h. Should be used only if // RevLevel = DynamicRev and 64-bit feature is disabled. +// +// +marshal type SuperBlock32Bit struct { // We embed the old superblock struct here because the 32-bit version is just // an extension of the old version. diff --git a/pkg/sentry/fsimpl/ext/disklayout/superblock_64.go b/pkg/sentry/fsimpl/ext/disklayout/superblock_64.go index 7c1053fb4..2c9039327 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/superblock_64.go +++ b/pkg/sentry/fsimpl/ext/disklayout/superblock_64.go @@ -19,6 +19,8 @@ package disklayout // 1024 bytes (smallest possible block size) and hence the superblock always // fits in no more than one data block. Should only be used when the 64-bit // feature is set. +// +// +marshal type SuperBlock64Bit struct { // We embed the 32-bit struct here because 64-bit version is just an extension // of the 32-bit version. diff --git a/pkg/sentry/fsimpl/ext/disklayout/superblock_old.go b/pkg/sentry/fsimpl/ext/disklayout/superblock_old.go index 9221e0251..e4709f23c 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/superblock_old.go +++ b/pkg/sentry/fsimpl/ext/disklayout/superblock_old.go @@ -16,6 +16,8 @@ package disklayout // SuperBlockOld implements SuperBlock and represents the old version of the // superblock struct. Should be used only if RevLevel = OldRev. +// +// +marshal type SuperBlockOld struct { InodesCountRaw uint32 BlocksCountLo uint32 diff --git a/pkg/sentry/fsimpl/ext/disklayout/superblock_test.go b/pkg/sentry/fsimpl/ext/disklayout/superblock_test.go index 463b5ba21..b734b6987 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/superblock_test.go +++ b/pkg/sentry/fsimpl/ext/disklayout/superblock_test.go @@ -21,7 +21,10 @@ import ( // TestSuperBlockSize tests that the superblock structs are of the correct // size. func TestSuperBlockSize(t *testing.T) { - assertSize(t, SuperBlockOld{}, 84) - assertSize(t, SuperBlock32Bit{}, 336) - assertSize(t, SuperBlock64Bit{}, 1024) + var sbOld SuperBlockOld + assertSize(t, &sbOld, 84) + var sb32 SuperBlock32Bit + assertSize(t, &sb32, 336) + var sb64 SuperBlock64Bit + assertSize(t, &sb64, 1024) } diff --git a/pkg/sentry/fsimpl/ext/disklayout/test_utils.go b/pkg/sentry/fsimpl/ext/disklayout/test_utils.go index 9c63f04c0..a4bc08411 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/test_utils.go +++ b/pkg/sentry/fsimpl/ext/disklayout/test_utils.go @@ -18,13 +18,13 @@ import ( "reflect" "testing" - "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/marshal" ) -func assertSize(t *testing.T, v interface{}, want uintptr) { +func assertSize(t *testing.T, v marshal.Marshallable, want int) { t.Helper() - if got := binary.Size(v); got != want { + if got := v.SizeBytes(); got != want { t.Errorf("struct %s should be exactly %d bytes but is %d bytes", reflect.TypeOf(v).Name(), want, got) } } diff --git a/pkg/sentry/fsimpl/ext/extent_file.go b/pkg/sentry/fsimpl/ext/extent_file.go index 04917d762..778460107 100644 --- a/pkg/sentry/fsimpl/ext/extent_file.go +++ b/pkg/sentry/fsimpl/ext/extent_file.go @@ -18,7 +18,6 @@ import ( "io" "sort" - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout" "gvisor.dev/gvisor/pkg/syserror" ) @@ -60,7 +59,7 @@ func newExtentFile(args inodeArgs) (*extentFile, error) { func (f *extentFile) buildExtTree() error { rootNodeData := f.regFile.inode.diskInode.Data() - binary.Unmarshal(rootNodeData[:disklayout.ExtentHeaderSize], binary.LittleEndian, &f.root.Header) + f.root.Header.UnmarshalBytes(rootNodeData[:disklayout.ExtentHeaderSize]) // Root node can not have more than 4 entries: 60 bytes = 1 header + 4 entries. if f.root.Header.NumEntries > 4 { @@ -79,7 +78,7 @@ func (f *extentFile) buildExtTree() error { // Internal node. curEntry = &disklayout.ExtentIdx{} } - binary.Unmarshal(rootNodeData[off:off+disklayout.ExtentEntrySize], binary.LittleEndian, curEntry) + curEntry.UnmarshalBytes(rootNodeData[off : off+disklayout.ExtentEntrySize]) f.root.Entries[i].Entry = curEntry } diff --git a/pkg/sentry/fsimpl/ext/extent_test.go b/pkg/sentry/fsimpl/ext/extent_test.go index cd10d46ee..985f76ac0 100644 --- a/pkg/sentry/fsimpl/ext/extent_test.go +++ b/pkg/sentry/fsimpl/ext/extent_test.go @@ -21,7 +21,6 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout" ) @@ -202,13 +201,14 @@ func extentTreeSetUp(t *testing.T, root *disklayout.ExtentNode) (*extentFile, [] // writeTree writes the tree represented by `root` to the inode and disk. It // also writes random file data on disk. func writeTree(in *inode, disk []byte, root *disklayout.ExtentNode, mockExtentBlkSize uint64) []byte { - rootData := binary.Marshal(nil, binary.LittleEndian, root.Header) + rootData := in.diskInode.Data() + root.Header.MarshalBytes(rootData) + off := root.Header.SizeBytes() for _, ep := range root.Entries { - rootData = binary.Marshal(rootData, binary.LittleEndian, ep.Entry) + ep.Entry.MarshalBytes(rootData[off:]) + off += ep.Entry.SizeBytes() } - copy(in.diskInode.Data(), rootData) - var fileData []byte for _, ep := range root.Entries { if root.Header.Height == 0 { @@ -223,13 +223,14 @@ func writeTree(in *inode, disk []byte, root *disklayout.ExtentNode, mockExtentBl // writeTreeToDisk is the recursive step for writeTree which writes the tree // on the disk only. Also writes random file data on disk. func writeTreeToDisk(disk []byte, curNode disklayout.ExtentEntryPair) []byte { - nodeData := binary.Marshal(nil, binary.LittleEndian, curNode.Node.Header) + nodeData := disk[curNode.Entry.PhysicalBlock()*mockExtentBlkSize:] + curNode.Node.Header.MarshalBytes(nodeData) + off := curNode.Node.Header.SizeBytes() for _, ep := range curNode.Node.Entries { - nodeData = binary.Marshal(nodeData, binary.LittleEndian, ep.Entry) + ep.Entry.MarshalBytes(nodeData[off:]) + off += ep.Entry.SizeBytes() } - copy(disk[curNode.Entry.PhysicalBlock()*mockExtentBlkSize:], nodeData) - var fileData []byte for _, ep := range curNode.Node.Entries { if curNode.Node.Header.Height == 0 { diff --git a/pkg/sentry/fsimpl/ext/utils.go b/pkg/sentry/fsimpl/ext/utils.go index d8b728f8c..58ef7b9b8 100644 --- a/pkg/sentry/fsimpl/ext/utils.go +++ b/pkg/sentry/fsimpl/ext/utils.go @@ -17,21 +17,21 @@ package ext import ( "io" - "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout" "gvisor.dev/gvisor/pkg/syserror" ) // readFromDisk performs a binary read from disk into the given struct from // the absolute offset provided. -func readFromDisk(dev io.ReaderAt, abOff int64, v interface{}) error { - n := binary.Size(v) +func readFromDisk(dev io.ReaderAt, abOff int64, v marshal.Marshallable) error { + n := v.SizeBytes() buf := make([]byte, n) if read, _ := dev.ReadAt(buf, abOff); read < int(n) { return syserror.EIO } - binary.Unmarshal(buf, binary.LittleEndian, v) + v.UnmarshalBytes(buf) return nil } diff --git a/pkg/sentry/fsimpl/host/socket.go b/pkg/sentry/fsimpl/host/socket.go index 131145b85..8a447e29f 100644 --- a/pkg/sentry/fsimpl/host/socket.go +++ b/pkg/sentry/fsimpl/host/socket.go @@ -348,10 +348,10 @@ func (e *SCMConnectedEndpoint) Init() error { func (e *SCMConnectedEndpoint) Release(ctx context.Context) { e.DecRef(func() { e.mu.Lock() + fdnotifier.RemoveFD(int32(e.fd)) if err := syscall.Close(e.fd); err != nil { log.Warningf("Failed to close host fd %d: %v", err) } - fdnotifier.RemoveFD(int32(e.fd)) e.destroyLocked() e.mu.Unlock() }) diff --git a/pkg/sentry/fsimpl/signalfd/BUILD b/pkg/sentry/fsimpl/signalfd/BUILD index 067c1657f..adb610213 100644 --- a/pkg/sentry/fsimpl/signalfd/BUILD +++ b/pkg/sentry/fsimpl/signalfd/BUILD @@ -8,7 +8,6 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", - "//pkg/binary", "//pkg/context", "//pkg/sentry/kernel", "//pkg/sentry/vfs", diff --git a/pkg/sentry/fsimpl/signalfd/signalfd.go b/pkg/sentry/fsimpl/signalfd/signalfd.go index bf11b425a..10f1452ef 100644 --- a/pkg/sentry/fsimpl/signalfd/signalfd.go +++ b/pkg/sentry/fsimpl/signalfd/signalfd.go @@ -16,7 +16,6 @@ package signalfd import ( "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/vfs" @@ -95,8 +94,7 @@ func (sfd *SignalFileDescription) Read(ctx context.Context, dst usermem.IOSequen } // Copy out the signal info using the specified format. - var buf [128]byte - binary.Marshal(buf[:0], usermem.ByteOrder, &linux.SignalfdSiginfo{ + infoNative := linux.SignalfdSiginfo{ Signo: uint32(info.Signo), Errno: info.Errno, Code: info.Code, @@ -105,9 +103,13 @@ func (sfd *SignalFileDescription) Read(ctx context.Context, dst usermem.IOSequen Status: info.Status(), Overrun: uint32(info.Overrun()), Addr: info.Addr(), - }) - n, err := dst.CopyOut(ctx, buf[:]) - return int64(n), err + } + n, err := infoNative.WriteTo(dst.Writer(ctx)) + if err == usermem.ErrEndOfIOSequence { + // Partial copy-out ok. + err = nil + } + return n, err } // Readiness implements waiter.Waitable.Readiness. diff --git a/pkg/sentry/fsimpl/sys/kcov.go b/pkg/sentry/fsimpl/sys/kcov.go index b75d70ae6..1a6749e53 100644 --- a/pkg/sentry/fsimpl/sys/kcov.go +++ b/pkg/sentry/fsimpl/sys/kcov.go @@ -104,7 +104,7 @@ func (fd *kcovFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) erro func (fd *kcovFD) Release(ctx context.Context) { // kcov instances have reference counts in Linux, but this seems sufficient // for our purposes. - fd.kcov.Reset() + fd.kcov.Clear() } // SetStat implements vfs.FileDescriptionImpl.SetStat. diff --git a/pkg/sentry/fsimpl/verity/BUILD b/pkg/sentry/fsimpl/verity/BUILD index bc8e38431..0ca750281 100644 --- a/pkg/sentry/fsimpl/verity/BUILD +++ b/pkg/sentry/fsimpl/verity/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") licenses(["notice"]) @@ -26,3 +26,22 @@ go_library( "//pkg/usermem", ], ) + +go_test( + name = "verity_test", + srcs = [ + "verity_test.go", + ], + library = ":verity", + deps = [ + "//pkg/abi/linux", + "//pkg/context", + "//pkg/fspath", + "//pkg/sentry/arch", + "//pkg/sentry/fsimpl/tmpfs", + "//pkg/sentry/kernel/auth", + "//pkg/sentry/kernel/contexttest", + "//pkg/sentry/vfs", + "//pkg/usermem", + ], +) diff --git a/pkg/sentry/fsimpl/verity/filesystem.go b/pkg/sentry/fsimpl/verity/filesystem.go index 26b117ca4..7779271a9 100644 --- a/pkg/sentry/fsimpl/verity/filesystem.go +++ b/pkg/sentry/fsimpl/verity/filesystem.go @@ -20,6 +20,7 @@ import ( "io" "strconv" "strings" + "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" @@ -251,11 +252,35 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de Ctx: ctx, } + parentStat, err := vfsObj.StatAt(ctx, fs.creds, &vfs.PathOperation{ + Root: parent.lowerVD, + Start: parent.lowerVD, + }, &vfs.StatOptions{}) + if err == syserror.ENOENT { + return nil, alertIntegrityViolation(err, fmt.Sprintf("Failed to get parent stat for %s: %v", childPath, err)) + } + if err != nil { + return nil, err + } + // Since we are verifying against a directory Merkle tree, buf should // contain the root hash of the children in the parent Merkle tree when // Verify returns with success. var buf bytes.Buffer - if _, err := merkletree.Verify(&buf, &fdReader, &fdReader, int64(parentSize), int64(offset), int64(merkletree.DigestSize()), parent.rootHash, true /* dataAndTreeInSameFile */); err != nil && err != io.EOF { + if _, err := merkletree.Verify(&merkletree.VerifyParams{ + Out: &buf, + File: &fdReader, + Tree: &fdReader, + Size: int64(parentSize), + Name: parent.name, + Mode: uint32(parentStat.Mode), + UID: parentStat.UID, + GID: parentStat.GID, + ReadOffset: int64(offset), + ReadSize: int64(merkletree.DigestSize()), + ExpectedRoot: parent.rootHash, + DataAndTreeInSameFile: true, + }); err != nil && err != io.EOF { return nil, alertIntegrityViolation(syserror.EIO, fmt.Sprintf("Verification for %s failed: %v", childPath, err)) } @@ -266,6 +291,84 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de return child, nil } +// verifyStat verifies the stat against the verified root hash. The mode/uid/gid +// of the file is cached after verified. +func (fs *filesystem) verifyStat(ctx context.Context, d *dentry, stat linux.Statx) error { + vfsObj := fs.vfsfs.VirtualFilesystem() + + // Get the path to the child dentry. This is only used to provide path + // information in failure case. + childPath, err := vfsObj.PathnameWithDeleted(ctx, d.fs.rootDentry.lowerVD, d.lowerVD) + if err != nil { + return err + } + + verityMu.RLock() + defer verityMu.RUnlock() + + fd, err := vfsObj.OpenAt(ctx, fs.creds, &vfs.PathOperation{ + Root: d.lowerMerkleVD, + Start: d.lowerMerkleVD, + }, &vfs.OpenOptions{ + Flags: linux.O_RDONLY, + }) + if err == syserror.ENOENT { + return alertIntegrityViolation(err, fmt.Sprintf("Failed to open merkle file for %s: %v", childPath, err)) + } + if err != nil { + return err + } + + merkleSize, err := fd.GetXattr(ctx, &vfs.GetXattrOptions{ + Name: merkleSizeXattr, + Size: sizeOfStringInt32, + }) + + if err == syserror.ENODATA { + return alertIntegrityViolation(err, fmt.Sprintf("Failed to get xattr %s for merkle file of %s: %v", merkleSizeXattr, childPath, err)) + } + if err != nil { + return err + } + + size, err := strconv.Atoi(merkleSize) + if err != nil { + return alertIntegrityViolation(syserror.EINVAL, fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleSizeXattr, childPath, err)) + } + + fdReader := vfs.FileReadWriteSeeker{ + FD: fd, + Ctx: ctx, + } + + var buf bytes.Buffer + params := &merkletree.VerifyParams{ + Out: &buf, + Tree: &fdReader, + Size: int64(size), + Name: d.name, + Mode: uint32(stat.Mode), + UID: stat.UID, + GID: stat.GID, + ReadOffset: 0, + // Set read size to 0 so only the metadata is verified. + ReadSize: 0, + ExpectedRoot: d.rootHash, + DataAndTreeInSameFile: false, + } + if atomic.LoadUint32(&d.mode)&linux.S_IFMT == linux.S_IFDIR { + params.DataAndTreeInSameFile = true + } + + if _, err := merkletree.Verify(params); err != nil && err != io.EOF { + return alertIntegrityViolation(err, fmt.Sprintf("Verification stat for %s failed: %v", childPath, err)) + } + d.mode = uint32(stat.Mode) + d.uid = stat.UID + d.gid = stat.GID + return nil +} + // Preconditions: fs.renameMu must be locked. d.dirMu must be locked. func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name string, ds **[]*dentry) (*dentry, error) { if child, ok := parent.children[name]; ok { @@ -274,9 +377,27 @@ func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name s // runtime enable is allowed and the parent directory is // enabled, we should verify the child root hash here because // it may be cached before enabled. - if fs.allowRuntimeEnable && len(parent.rootHash) != 0 { - if _, err := fs.verifyChild(ctx, parent, child); err != nil { - return nil, err + if fs.allowRuntimeEnable { + if isEnabled(parent) { + if _, err := fs.verifyChild(ctx, parent, child); err != nil { + return nil, err + } + } + if isEnabled(child) { + vfsObj := fs.vfsfs.VirtualFilesystem() + mask := uint32(linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID) + stat, err := vfsObj.StatAt(ctx, fs.creds, &vfs.PathOperation{ + Root: child.lowerVD, + Start: child.lowerVD, + }, &vfs.StatOptions{ + Mask: mask, + }) + if err != nil { + return nil, err + } + if err := fs.verifyStat(ctx, child, stat); err != nil { + return nil, err + } } } return child, nil @@ -426,7 +547,6 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry, child.parent = parent child.name = name - // TODO(b/162788573): Verify child metadata. child.mode = uint32(stat.Mode) child.uid = stat.UID child.gid = stat.GID @@ -434,12 +554,18 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry, // Verify child root hash. This should always be performed unless in // allowRuntimeEnable mode and the parent directory hasn't been enabled // yet. - if !(fs.allowRuntimeEnable && len(parent.rootHash) == 0) { + if isEnabled(parent) { if _, err := fs.verifyChild(ctx, parent, child); err != nil { child.destroyLocked(ctx) return nil, err } } + if isEnabled(child) { + if err := fs.verifyStat(ctx, child, stat); err != nil { + child.destroyLocked(ctx) + return nil, err + } + } return child, nil } @@ -693,22 +819,24 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf // be called if a verity FD is created successfully. defer merkleWriter.DecRef(ctx) - parentMerkleWriter, err = rp.VirtualFilesystem().OpenAt(ctx, d.fs.creds, &vfs.PathOperation{ - Root: d.parent.lowerMerkleVD, - Start: d.parent.lowerMerkleVD, - }, &vfs.OpenOptions{ - Flags: linux.O_WRONLY | linux.O_APPEND, - }) - if err != nil { - if err == syserror.ENOENT { - parentPath, _ := d.fs.vfsfs.VirtualFilesystem().PathnameWithDeleted(ctx, d.fs.rootDentry.lowerVD, d.parent.lowerVD) - return nil, alertIntegrityViolation(err, fmt.Sprintf("Merkle file for %s expected but not found", parentPath)) + if d.parent != nil { + parentMerkleWriter, err = rp.VirtualFilesystem().OpenAt(ctx, d.fs.creds, &vfs.PathOperation{ + Root: d.parent.lowerMerkleVD, + Start: d.parent.lowerMerkleVD, + }, &vfs.OpenOptions{ + Flags: linux.O_WRONLY | linux.O_APPEND, + }) + if err != nil { + if err == syserror.ENOENT { + parentPath, _ := d.fs.vfsfs.VirtualFilesystem().PathnameWithDeleted(ctx, d.fs.rootDentry.lowerVD, d.parent.lowerVD) + return nil, alertIntegrityViolation(err, fmt.Sprintf("Merkle file for %s expected but not found", parentPath)) + } + return nil, err } - return nil, err + // parentMerkleWriter is cleaned up if any error occurs. IncRef + // will be called if a verity FD is created successfully. + defer parentMerkleWriter.DecRef(ctx) } - // parentMerkleWriter is cleaned up if any error occurs. IncRef - // will be called if a verity FD is created successfully. - defer parentMerkleWriter.DecRef(ctx) } fd := &fileDescription{ @@ -769,6 +897,8 @@ func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts } // StatAt implements vfs.FilesystemImpl.StatAt. +// TODO(b/170157489): Investigate whether stats other than Mode/UID/GID should +// be verified. func (fs *filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.StatOptions) (linux.Statx, error) { var ds *[]*dentry fs.renameMu.RLock() @@ -786,6 +916,11 @@ func (fs *filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf if err != nil { return linux.Statx{}, err } + if isEnabled(d) { + if err := fs.verifyStat(ctx, d, stat); err != nil { + return linux.Statx{}, err + } + } return stat, nil } diff --git a/pkg/sentry/fsimpl/verity/verity.go b/pkg/sentry/fsimpl/verity/verity.go index 3129f290d..3eb972237 100644 --- a/pkg/sentry/fsimpl/verity/verity.go +++ b/pkg/sentry/fsimpl/verity/verity.go @@ -142,6 +142,14 @@ func (FilesystemType) Name() string { return Name } +// isEnabled checks whether the target is enabled with verity features. It +// should always be true if runtime enable is not allowed. In runtime enable +// mode, it returns true if the target has been enabled with +// ioctl(FS_IOC_ENABLE_VERITY). +func isEnabled(d *dentry) bool { + return !d.fs.allowRuntimeEnable || len(d.rootHash) != 0 +} + // alertIntegrityViolation alerts a violation of integrity, which usually means // unexpected modification to the file system is detected. In // noCrashOnVerificationFailure mode, it returns an error, otherwise it panic. @@ -245,12 +253,17 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt return nil, nil, err } - // TODO(b/162788573): Verify Metadata. d.mode = uint32(stat.Mode) d.uid = stat.UID d.gid = stat.GID - d.rootHash = make([]byte, len(iopts.RootHash)) + + if !fs.allowRuntimeEnable { + if err := fs.verifyStat(ctx, d, stat); err != nil { + return nil, nil, err + } + } + copy(d.rootHash, iopts.RootHash) d.vfsd.Init(d) @@ -488,6 +501,11 @@ func (fd *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linu if err != nil { return linux.Statx{}, err } + if isEnabled(fd.d) { + if err := fd.d.fs.verifyStat(ctx, fd.d, stat); err != nil { + return linux.Statx{}, err + } + } return stat, nil } @@ -516,8 +534,10 @@ func (fd *fileDescription) generateMerkle(ctx context.Context) ([]byte, uint64, FD: fd.merkleWriter, Ctx: ctx, } - var rootHash []byte - var dataSize uint64 + params := &merkletree.GenerateParams{ + TreeReader: &merkleReader, + TreeWriter: &merkleWriter, + } switch atomic.LoadUint32(&fd.d.mode) & linux.S_IFMT { case linux.S_IFREG: @@ -528,12 +548,14 @@ func (fd *fileDescription) generateMerkle(ctx context.Context) ([]byte, uint64, if err != nil { return nil, 0, err } - dataSize = stat.Size - rootHash, err = merkletree.Generate(&fdReader, int64(dataSize), &merkleReader, &merkleWriter, false /* dataAndTreeInSameFile */) - if err != nil { - return nil, 0, err - } + params.File = &fdReader + params.Size = int64(stat.Size) + params.Name = fd.d.name + params.Mode = uint32(stat.Mode) + params.UID = stat.UID + params.GID = stat.GID + params.DataAndTreeInSameFile = false case linux.S_IFDIR: // For a directory, generate a Merkle tree based on the root // hashes of its children that has already been written to the @@ -542,23 +564,32 @@ func (fd *fileDescription) generateMerkle(ctx context.Context) ([]byte, uint64, if err != nil { return nil, 0, err } - dataSize = merkleStat.Size - rootHash, err = merkletree.Generate(&merkleReader, int64(dataSize), &merkleReader, &merkleWriter, true /* dataAndTreeInSameFile */) + params.Size = int64(merkleStat.Size) + + stat, err := fd.lowerFD.Stat(ctx, vfs.StatOptions{}) if err != nil { return nil, 0, err } + + params.File = &merkleReader + params.Name = fd.d.name + params.Mode = uint32(stat.Mode) + params.UID = stat.UID + params.GID = stat.GID + params.DataAndTreeInSameFile = true default: // TODO(b/167728857): Investigate whether and how we should // enable other types of file. return nil, 0, syserror.EINVAL } - return rootHash, dataSize, nil + rootHash, err := merkletree.Generate(params) + return rootHash, uint64(params.Size), err } // enableVerity enables verity features on fd by generating a Merkle tree file // and stores its root hash in its parent directory's Merkle tree. -func (fd *fileDescription) enableVerity(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) { +func (fd *fileDescription) enableVerity(ctx context.Context, uio usermem.IO) (uintptr, error) { if !fd.d.fs.allowRuntimeEnable { return 0, syserror.EPERM } @@ -568,7 +599,11 @@ func (fd *fileDescription) enableVerity(ctx context.Context, uio usermem.IO, arg verityMu.Lock() defer verityMu.Unlock() - if fd.lowerFD == nil || fd.merkleReader == nil || fd.merkleWriter == nil || fd.parentMerkleWriter == nil { + // In allowRuntimeEnable mode, the underlying fd and read/write fd for + // the Merkle tree file should have all been initialized. For any file + // or directory other than the root, the parent Merkle tree file should + // have also been initialized. + if fd.lowerFD == nil || fd.merkleReader == nil || fd.merkleWriter == nil || (fd.parentMerkleWriter == nil && fd.d != fd.d.fs.rootDentry) { return 0, alertIntegrityViolation(syserror.EIO, "Unexpected verity fd: missing expected underlying fds") } @@ -577,26 +612,28 @@ func (fd *fileDescription) enableVerity(ctx context.Context, uio usermem.IO, arg return 0, err } - stat, err := fd.parentMerkleWriter.Stat(ctx, vfs.StatOptions{}) - if err != nil { - return 0, err - } + if fd.parentMerkleWriter != nil { + stat, err := fd.parentMerkleWriter.Stat(ctx, vfs.StatOptions{}) + if err != nil { + return 0, err + } - // Write the root hash of fd to the parent directory's Merkle tree - // file, as it should be part of the parent Merkle tree data. - // parentMerkleWriter is open with O_APPEND, so it should write - // directly to the end of the file. - if _, err = fd.parentMerkleWriter.Write(ctx, usermem.BytesIOSequence(rootHash), vfs.WriteOptions{}); err != nil { - return 0, err - } + // Write the root hash of fd to the parent directory's Merkle + // tree file, as it should be part of the parent Merkle tree + // data. parentMerkleWriter is open with O_APPEND, so it + // should write directly to the end of the file. + if _, err = fd.parentMerkleWriter.Write(ctx, usermem.BytesIOSequence(rootHash), vfs.WriteOptions{}); err != nil { + return 0, err + } - // Record the offset of the root hash of fd in parent directory's - // Merkle tree file. - if err := fd.merkleWriter.SetXattr(ctx, &vfs.SetXattrOptions{ - Name: merkleOffsetInParentXattr, - Value: strconv.Itoa(int(stat.Size)), - }); err != nil { - return 0, err + // Record the offset of the root hash of fd in parent directory's + // Merkle tree file. + if err := fd.merkleWriter.SetXattr(ctx, &vfs.SetXattrOptions{ + Name: merkleOffsetInParentXattr, + Value: strconv.Itoa(int(stat.Size)), + }); err != nil { + return 0, err + } } // Record the size of the data being hashed for fd. @@ -610,7 +647,45 @@ func (fd *fileDescription) enableVerity(ctx context.Context, uio usermem.IO, arg return 0, nil } -func (fd *fileDescription) getFlags(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) { +// measureVerity returns the root hash of fd, saved in args[2]. +func (fd *fileDescription) measureVerity(ctx context.Context, uio usermem.IO, verityDigest usermem.Addr) (uintptr, error) { + t := kernel.TaskFromContext(ctx) + var metadata linux.DigestMetadata + + // If allowRuntimeEnable is true, an empty fd.d.rootHash indicates that + // verity is not enabled for the file. If allowRuntimeEnable is false, + // this is an integrity violation because all files should have verity + // enabled, in which case fd.d.rootHash should be set. + if len(fd.d.rootHash) == 0 { + if fd.d.fs.allowRuntimeEnable { + return 0, syserror.ENODATA + } + return 0, alertIntegrityViolation(syserror.ENODATA, "Ioctl measureVerity: no root hash found") + } + + // The first part of VerityDigest is the metadata. + if _, err := metadata.CopyIn(t, verityDigest); err != nil { + return 0, err + } + if metadata.DigestSize < uint16(len(fd.d.rootHash)) { + return 0, syserror.EOVERFLOW + } + + // Populate the output digest size, since DigestSize is both input and + // output. + metadata.DigestSize = uint16(len(fd.d.rootHash)) + + // First copy the metadata. + if _, err := metadata.CopyOut(t, verityDigest); err != nil { + return 0, err + } + + // Now copy the root hash bytes to the memory after metadata. + _, err := t.CopyOutBytes(usermem.Addr(uintptr(verityDigest)+linux.SizeOfDigestMetadata), fd.d.rootHash) + return 0, err +} + +func (fd *fileDescription) verityFlags(ctx context.Context, uio usermem.IO, flags usermem.Addr) (uintptr, error) { f := int32(0) // All enabled files should store a root hash. This flag is not settable @@ -620,8 +695,7 @@ func (fd *fileDescription) getFlags(ctx context.Context, uio usermem.IO, args ar } t := kernel.TaskFromContext(ctx) - addr := args[2].Pointer() - _, err := primitive.CopyInt32Out(t, addr, f) + _, err := primitive.CopyInt32Out(t, flags, f) return 0, err } @@ -629,11 +703,15 @@ func (fd *fileDescription) getFlags(ctx context.Context, uio usermem.IO, args ar func (fd *fileDescription) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) { switch cmd := args[1].Uint(); cmd { case linux.FS_IOC_ENABLE_VERITY: - return fd.enableVerity(ctx, uio, args) + return fd.enableVerity(ctx, uio) + case linux.FS_IOC_MEASURE_VERITY: + return fd.measureVerity(ctx, uio, args[2].Pointer()) case linux.FS_IOC_GETFLAGS: - return fd.getFlags(ctx, uio, args) + return fd.verityFlags(ctx, uio, args[2].Pointer()) default: - return fd.lowerFD.Ioctl(ctx, uio, args) + // TODO(b/169682228): Investigate which ioctl commands should + // be allowed. + return 0, syserror.ENOSYS } } @@ -641,7 +719,7 @@ func (fd *fileDescription) Ioctl(ctx context.Context, uio usermem.IO, args arch. func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { // No need to verify if the file is not enabled yet in // allowRuntimeEnable mode. - if fd.d.fs.allowRuntimeEnable && len(fd.d.rootHash) == 0 { + if !isEnabled(fd.d) { return fd.lowerFD.PRead(ctx, dst, offset, opts) } @@ -678,9 +756,22 @@ func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, of Ctx: ctx, } - n, err := merkletree.Verify(dst.Writer(ctx), &dataReader, &merkleReader, int64(size), offset, dst.NumBytes(), fd.d.rootHash, false /* dataAndTreeInSameFile */) + n, err := merkletree.Verify(&merkletree.VerifyParams{ + Out: dst.Writer(ctx), + File: &dataReader, + Tree: &merkleReader, + Size: int64(size), + Name: fd.d.name, + Mode: fd.d.mode, + UID: fd.d.uid, + GID: fd.d.gid, + ReadOffset: offset, + ReadSize: dst.NumBytes(), + ExpectedRoot: fd.d.rootHash, + DataAndTreeInSameFile: false, + }) if err != nil { - return 0, alertIntegrityViolation(syserror.EINVAL, fmt.Sprintf("Verification failed: %v", err)) + return 0, alertIntegrityViolation(syserror.EIO, fmt.Sprintf("Verification failed: %v", err)) } return n, err } diff --git a/pkg/sentry/fsimpl/verity/verity_test.go b/pkg/sentry/fsimpl/verity/verity_test.go new file mode 100644 index 000000000..8d0926bc4 --- /dev/null +++ b/pkg/sentry/fsimpl/verity/verity_test.go @@ -0,0 +1,490 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package verity + +import ( + "fmt" + "io" + "math/rand" + "testing" + "time" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/kernel/contexttest" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/usermem" +) + +// rootMerkleFilename is the name of the root Merkle tree file. +const rootMerkleFilename = "root.verity" + +// maxDataSize is the maximum data size written to the file for test. +const maxDataSize = 100000 + +// newVerityRoot creates a new verity mount, and returns the root. The +// underlying file system is tmpfs. If the error is not nil, then cleanup +// should be called when the root is no longer needed. +func newVerityRoot(ctx context.Context, t *testing.T) (*vfs.VirtualFilesystem, vfs.VirtualDentry, error) { + rand.Seed(time.Now().UnixNano()) + vfsObj := &vfs.VirtualFilesystem{} + if err := vfsObj.Init(ctx); err != nil { + return nil, vfs.VirtualDentry{}, fmt.Errorf("VFS init: %v", err) + } + + vfsObj.MustRegisterFilesystemType("verity", FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ + AllowUserMount: true, + }) + + vfsObj.MustRegisterFilesystemType("tmpfs", tmpfs.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ + AllowUserMount: true, + }) + + mntns, err := vfsObj.NewMountNamespace(ctx, auth.CredentialsFromContext(ctx), "", "verity", &vfs.MountOptions{ + GetFilesystemOptions: vfs.GetFilesystemOptions{ + InternalData: InternalFilesystemOptions{ + RootMerkleFileName: rootMerkleFilename, + LowerName: "tmpfs", + AllowRuntimeEnable: true, + NoCrashOnVerificationFailure: true, + }, + }, + }) + if err != nil { + return nil, vfs.VirtualDentry{}, fmt.Errorf("NewMountNamespace: %v", err) + } + root := mntns.Root() + t.Helper() + t.Cleanup(func() { + root.DecRef(ctx) + mntns.DecRef(ctx) + }) + return vfsObj, root, nil +} + +// newFileFD creates a new file in the verity mount, and returns the FD. The FD +// points to a file that has random data generated. +func newFileFD(ctx context.Context, vfsObj *vfs.VirtualFilesystem, root vfs.VirtualDentry, filePath string, mode linux.FileMode) (*vfs.FileDescription, int, error) { + creds := auth.CredentialsFromContext(ctx) + lowerRoot := root.Dentry().Impl().(*dentry).lowerVD + + // Create the file in the underlying file system. + lowerFD, err := vfsObj.OpenAt(ctx, creds, &vfs.PathOperation{ + Root: lowerRoot, + Start: lowerRoot, + Path: fspath.Parse(filePath), + }, &vfs.OpenOptions{ + Flags: linux.O_RDWR | linux.O_CREAT | linux.O_EXCL, + Mode: linux.ModeRegular | mode, + }) + if err != nil { + return nil, 0, err + } + + // Generate random data to be written to the file. + dataSize := rand.Intn(maxDataSize) + 1 + data := make([]byte, dataSize) + rand.Read(data) + + // Write directly to the underlying FD, since verity FD is read-only. + n, err := lowerFD.Write(ctx, usermem.BytesIOSequence(data), vfs.WriteOptions{}) + if err != nil { + return nil, 0, err + } + + if n != int64(len(data)) { + return nil, 0, fmt.Errorf("lowerFD.Write got write length %d, want %d", n, len(data)) + } + + lowerFD.DecRef(ctx) + + // Now open the verity file descriptor. + fd, err := vfsObj.OpenAt(ctx, creds, &vfs.PathOperation{ + Root: root, + Start: root, + Path: fspath.Parse(filePath), + }, &vfs.OpenOptions{ + Flags: linux.O_RDONLY, + Mode: linux.ModeRegular | mode, + }) + return fd, dataSize, err +} + +// corruptRandomBit randomly flips a bit in the file represented by fd. +func corruptRandomBit(ctx context.Context, fd *vfs.FileDescription, size int) error { + // Flip a random bit in the underlying file. + randomPos := int64(rand.Intn(size)) + byteToModify := make([]byte, 1) + if _, err := fd.PRead(ctx, usermem.BytesIOSequence(byteToModify), randomPos, vfs.ReadOptions{}); err != nil { + return fmt.Errorf("lowerFD.PRead: %v", err) + } + byteToModify[0] ^= 1 + if _, err := fd.PWrite(ctx, usermem.BytesIOSequence(byteToModify), randomPos, vfs.WriteOptions{}); err != nil { + return fmt.Errorf("lowerFD.PWrite: %v", err) + } + return nil +} + +// TestOpen ensures that when a file is created, the corresponding Merkle tree +// file and the root Merkle tree file exist. +func TestOpen(t *testing.T) { + ctx := contexttest.Context(t) + vfsObj, root, err := newVerityRoot(ctx, t) + if err != nil { + t.Fatalf("newVerityRoot: %v", err) + } + + filename := "verity-test-file" + if _, _, err := newFileFD(ctx, vfsObj, root, filename, 0644); err != nil { + t.Fatalf("newFileFD: %v", err) + } + + // Ensure that the corresponding Merkle tree file is created. + lowerRoot := root.Dentry().Impl().(*dentry).lowerVD + if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: lowerRoot, + Start: lowerRoot, + Path: fspath.Parse(merklePrefix + filename), + }, &vfs.OpenOptions{ + Flags: linux.O_RDONLY, + }); err != nil { + t.Errorf("OpenAt Merkle tree file %s: %v", merklePrefix+filename, err) + } + + // Ensure the root merkle tree file is created. + if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: lowerRoot, + Start: lowerRoot, + Path: fspath.Parse(merklePrefix + rootMerkleFilename), + }, &vfs.OpenOptions{ + Flags: linux.O_RDONLY, + }); err != nil { + t.Errorf("OpenAt root Merkle tree file %s: %v", merklePrefix+rootMerkleFilename, err) + } +} + +// TestUnmodifiedFileSucceeds ensures that read from an untouched verity file +// succeeds after enabling verity for it. +func TestReadUnmodifiedFileSucceeds(t *testing.T) { + ctx := contexttest.Context(t) + vfsObj, root, err := newVerityRoot(ctx, t) + if err != nil { + t.Fatalf("newVerityRoot: %v", err) + } + + filename := "verity-test-file" + fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644) + if err != nil { + t.Fatalf("newFileFD: %v", err) + } + + // Enable verity on the file and confirm a normal read succeeds. + var args arch.SyscallArguments + args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} + if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { + t.Fatalf("Ioctl: %v", err) + } + + buf := make([]byte, size) + n, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{}) + if err != nil && err != io.EOF { + t.Fatalf("fd.PRead: %v", err) + } + + if n != int64(size) { + t.Errorf("fd.PRead got read length %d, want %d", n, size) + } +} + +// TestReopenUnmodifiedFileSucceeds ensures that reopen an untouched verity file +// succeeds after enabling verity for it. +func TestReopenUnmodifiedFileSucceeds(t *testing.T) { + ctx := contexttest.Context(t) + vfsObj, root, err := newVerityRoot(ctx, t) + if err != nil { + t.Fatalf("newVerityRoot: %v", err) + } + + filename := "verity-test-file" + fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644) + if err != nil { + t.Fatalf("newFileFD: %v", err) + } + + // Enable verity on the file and confirms a normal read succeeds. + var args arch.SyscallArguments + args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} + if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { + t.Fatalf("Ioctl: %v", err) + } + + // Ensure reopening the verity enabled file succeeds. + if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: root, + Start: root, + Path: fspath.Parse(filename), + }, &vfs.OpenOptions{ + Flags: linux.O_RDONLY, + Mode: linux.ModeRegular, + }); err != nil { + t.Errorf("reopen enabled file failed: %v", err) + } +} + +// TestModifiedFileFails ensures that read from a modified verity file fails. +func TestModifiedFileFails(t *testing.T) { + ctx := contexttest.Context(t) + vfsObj, root, err := newVerityRoot(ctx, t) + if err != nil { + t.Fatalf("newVerityRoot: %v", err) + } + + filename := "verity-test-file" + fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644) + if err != nil { + t.Fatalf("newFileFD: %v", err) + } + + // Enable verity on the file. + var args arch.SyscallArguments + args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} + if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { + t.Fatalf("Ioctl: %v", err) + } + + // Open a new lowerFD that's read/writable. + lowerVD := fd.Impl().(*fileDescription).d.lowerVD + + lowerFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: lowerVD, + Start: lowerVD, + }, &vfs.OpenOptions{ + Flags: linux.O_RDWR, + }) + if err != nil { + t.Fatalf("OpenAt: %v", err) + } + + if err := corruptRandomBit(ctx, lowerFD, size); err != nil { + t.Fatalf("corruptRandomBit: %v", err) + } + + // Confirm that read from the modified file fails. + buf := make([]byte, size) + if _, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{}); err == nil { + t.Fatalf("fd.PRead succeeded with modified file") + } +} + +// TestModifiedMerkleFails ensures that read from a verity file fails if the +// corresponding Merkle tree file is modified. +func TestModifiedMerkleFails(t *testing.T) { + ctx := contexttest.Context(t) + vfsObj, root, err := newVerityRoot(ctx, t) + if err != nil { + t.Fatalf("newVerityRoot: %v", err) + } + + filename := "verity-test-file" + fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644) + if err != nil { + t.Fatalf("newFileFD: %v", err) + } + + // Enable verity on the file. + var args arch.SyscallArguments + args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} + if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { + t.Fatalf("Ioctl: %v", err) + } + + // Open a new lowerMerkleFD that's read/writable. + lowerMerkleVD := fd.Impl().(*fileDescription).d.lowerMerkleVD + + lowerMerkleFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: lowerMerkleVD, + Start: lowerMerkleVD, + }, &vfs.OpenOptions{ + Flags: linux.O_RDWR, + }) + if err != nil { + t.Fatalf("OpenAt: %v", err) + } + + // Flip a random bit in the Merkle tree file. + stat, err := lowerMerkleFD.Stat(ctx, vfs.StatOptions{}) + if err != nil { + t.Fatalf("stat: %v", err) + } + merkleSize := int(stat.Size) + if err := corruptRandomBit(ctx, lowerMerkleFD, merkleSize); err != nil { + t.Fatalf("corruptRandomBit: %v", err) + } + + // Confirm that read from a file with modified Merkle tree fails. + buf := make([]byte, size) + if _, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{}); err == nil { + fmt.Println(buf) + t.Fatalf("fd.PRead succeeded with modified Merkle file") + } +} + +// TestModifiedParentMerkleFails ensures that open a verity enabled file in a +// verity enabled directory fails if the hashes related to the target file in +// the parent Merkle tree file is modified. +func TestModifiedParentMerkleFails(t *testing.T) { + ctx := contexttest.Context(t) + vfsObj, root, err := newVerityRoot(ctx, t) + if err != nil { + t.Fatalf("newVerityRoot: %v", err) + } + + filename := "verity-test-file" + fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644) + if err != nil { + t.Fatalf("newFileFD: %v", err) + } + + // Enable verity on the file. + var args arch.SyscallArguments + args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} + if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { + t.Fatalf("Ioctl: %v", err) + } + + // Enable verity on the parent directory. + parentFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: root, + Start: root, + }, &vfs.OpenOptions{ + Flags: linux.O_RDONLY, + }) + if err != nil { + t.Fatalf("OpenAt: %v", err) + } + + if _, err := parentFD.Ioctl(ctx, nil /* uio */, args); err != nil { + t.Fatalf("Ioctl: %v", err) + } + + // Open a new lowerMerkleFD that's read/writable. + parentLowerMerkleVD := fd.Impl().(*fileDescription).d.parent.lowerMerkleVD + + parentLowerMerkleFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: parentLowerMerkleVD, + Start: parentLowerMerkleVD, + }, &vfs.OpenOptions{ + Flags: linux.O_RDWR, + }) + if err != nil { + t.Fatalf("OpenAt: %v", err) + } + + // Flip a random bit in the parent Merkle tree file. + // This parent directory contains only one child, so any random + // modification in the parent Merkle tree should cause verification + // failure when opening the child file. + stat, err := parentLowerMerkleFD.Stat(ctx, vfs.StatOptions{}) + if err != nil { + t.Fatalf("stat: %v", err) + } + parentMerkleSize := int(stat.Size) + if err := corruptRandomBit(ctx, parentLowerMerkleFD, parentMerkleSize); err != nil { + t.Fatalf("corruptRandomBit: %v", err) + } + + parentLowerMerkleFD.DecRef(ctx) + + // Ensure reopening the verity enabled file fails. + if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: root, + Start: root, + Path: fspath.Parse(filename), + }, &vfs.OpenOptions{ + Flags: linux.O_RDONLY, + Mode: linux.ModeRegular, + }); err == nil { + t.Errorf("OpenAt file with modified parent Merkle succeeded") + } +} + +// TestUnmodifiedStatSucceeds ensures that stat of an untouched verity file +// succeeds after enabling verity for it. +func TestUnmodifiedStatSucceeds(t *testing.T) { + ctx := contexttest.Context(t) + vfsObj, root, err := newVerityRoot(ctx, t) + if err != nil { + t.Fatalf("newVerityRoot: %v", err) + } + + filename := "verity-test-file" + fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644) + if err != nil { + t.Fatalf("newFileFD: %v", err) + } + + // Enable verity on the file and confirms stat succeeds. + var args arch.SyscallArguments + args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} + if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { + t.Fatalf("fd.Ioctl: %v", err) + } + + if _, err := fd.Stat(ctx, vfs.StatOptions{}); err != nil { + t.Errorf("fd.Stat: %v", err) + } +} + +// TestModifiedStatFails checks that getting stat for a file with modified stat +// should fail. +func TestModifiedStatFails(t *testing.T) { + ctx := contexttest.Context(t) + vfsObj, root, err := newVerityRoot(ctx, t) + if err != nil { + t.Fatalf("newVerityRoot: %v", err) + } + + filename := "verity-test-file" + fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644) + if err != nil { + t.Fatalf("newFileFD: %v", err) + } + + // Enable verity on the file. + var args arch.SyscallArguments + args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} + if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { + t.Fatalf("fd.Ioctl: %v", err) + } + + lowerFD := fd.Impl().(*fileDescription).lowerFD + // Change the stat of the underlying file, and check that stat fails. + if err := lowerFD.SetStat(ctx, vfs.SetStatOptions{ + Stat: linux.Statx{ + Mask: uint32(linux.STATX_MODE), + Mode: 0777, + }, + }); err != nil { + t.Fatalf("lowerFD.SetStat: %v", err) + } + + if _, err := fd.Stat(ctx, vfs.StatOptions{}); err == nil { + t.Errorf("fd.Stat succeeded when it should fail") + } +} diff --git a/pkg/sentry/hostmm/BUILD b/pkg/sentry/hostmm/BUILD index 61c78569d..300b7ccce 100644 --- a/pkg/sentry/hostmm/BUILD +++ b/pkg/sentry/hostmm/BUILD @@ -7,11 +7,14 @@ go_library( srcs = [ "cgroup.go", "hostmm.go", + "membarrier.go", ], visibility = ["//pkg/sentry:internal"], deps = [ + "//pkg/abi/linux", "//pkg/fd", "//pkg/log", "//pkg/usermem", + "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/pkg/sentry/hostmm/membarrier.go b/pkg/sentry/hostmm/membarrier.go new file mode 100644 index 000000000..4468d75f1 --- /dev/null +++ b/pkg/sentry/hostmm/membarrier.go @@ -0,0 +1,90 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package hostmm + +import ( + "syscall" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/log" +) + +var ( + haveMembarrierGlobal = false + haveMembarrierPrivateExpedited = false +) + +func init() { + supported, _, e := syscall.RawSyscall(unix.SYS_MEMBARRIER, linux.MEMBARRIER_CMD_QUERY, 0 /* flags */, 0 /* unused */) + if e != 0 { + if e != syscall.ENOSYS { + log.Warningf("membarrier(MEMBARRIER_CMD_QUERY) failed: %s", e.Error()) + } + return + } + // We don't use MEMBARRIER_CMD_GLOBAL_EXPEDITED because this sends IPIs to + // all CPUs running tasks that have previously invoked + // MEMBARRIER_CMD_REGISTER_GLOBAL_EXPEDITED, which presents a DOS risk. + // (MEMBARRIER_CMD_GLOBAL is synchronize_rcu(), i.e. it waits for an RCU + // grace period to elapse without bothering other CPUs. + // MEMBARRIER_CMD_PRIVATE_EXPEDITED sends IPIs only to CPUs running tasks + // sharing the caller's MM.) + if supported&linux.MEMBARRIER_CMD_GLOBAL != 0 { + haveMembarrierGlobal = true + } + if req := uintptr(linux.MEMBARRIER_CMD_PRIVATE_EXPEDITED | linux.MEMBARRIER_CMD_REGISTER_PRIVATE_EXPEDITED); supported&req == req { + if _, _, e := syscall.RawSyscall(unix.SYS_MEMBARRIER, linux.MEMBARRIER_CMD_REGISTER_PRIVATE_EXPEDITED, 0 /* flags */, 0 /* unused */); e != 0 { + log.Warningf("membarrier(MEMBARRIER_CMD_REGISTER_PRIVATE_EXPEDITED) failed: %s", e.Error()) + } else { + haveMembarrierPrivateExpedited = true + } + } +} + +// HaveGlobalMemoryBarrier returns true if GlobalMemoryBarrier is supported. +func HaveGlobalMemoryBarrier() bool { + return haveMembarrierGlobal +} + +// GlobalMemoryBarrier blocks until "all running threads [in the host OS] have +// passed through a state where all memory accesses to user-space addresses +// match program order between entry to and return from [GlobalMemoryBarrier]", +// as for membarrier(2). +// +// Preconditions: HaveGlobalMemoryBarrier() == true. +func GlobalMemoryBarrier() error { + if _, _, e := syscall.Syscall(unix.SYS_MEMBARRIER, linux.MEMBARRIER_CMD_GLOBAL, 0 /* flags */, 0 /* unused */); e != 0 { + return e + } + return nil +} + +// HaveProcessMemoryBarrier returns true if ProcessMemoryBarrier is supported. +func HaveProcessMemoryBarrier() bool { + return haveMembarrierPrivateExpedited +} + +// ProcessMemoryBarrier is equivalent to GlobalMemoryBarrier, but only +// synchronizes with threads sharing a virtual address space (from the host OS' +// perspective) with the calling thread. +// +// Preconditions: HaveProcessMemoryBarrier() == true. +func ProcessMemoryBarrier() error { + if _, _, e := syscall.RawSyscall(unix.SYS_MEMBARRIER, linux.MEMBARRIER_CMD_PRIVATE_EXPEDITED, 0 /* flags */, 0 /* unused */); e != 0 { + return e + } + return nil +} diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD index 083071b5e..5de70aecb 100644 --- a/pkg/sentry/kernel/BUILD +++ b/pkg/sentry/kernel/BUILD @@ -204,7 +204,6 @@ go_library( "//pkg/abi", "//pkg/abi/linux", "//pkg/amutex", - "//pkg/binary", "//pkg/bits", "//pkg/bpf", "//pkg/context", diff --git a/pkg/sentry/kernel/kcov.go b/pkg/sentry/kernel/kcov.go index d3e76ca7b..060c056df 100644 --- a/pkg/sentry/kernel/kcov.go +++ b/pkg/sentry/kernel/kcov.go @@ -89,7 +89,7 @@ func (kcov *Kcov) TaskWork(t *Task) { kcov.mu.Lock() defer kcov.mu.Unlock() - if kcov.mode != linux.KCOV_TRACE_PC { + if kcov.mode != linux.KCOV_MODE_TRACE_PC { return } @@ -146,7 +146,7 @@ func (kcov *Kcov) InitTrace(size uint64) error { } // EnableTrace performs the KCOV_ENABLE_TRACE ioctl. -func (kcov *Kcov) EnableTrace(ctx context.Context, traceMode uint8) error { +func (kcov *Kcov) EnableTrace(ctx context.Context, traceKind uint8) error { t := TaskFromContext(ctx) if t == nil { panic("kcovInode.EnableTrace() cannot be used outside of a task goroutine") @@ -160,9 +160,9 @@ func (kcov *Kcov) EnableTrace(ctx context.Context, traceMode uint8) error { return syserror.EINVAL } - switch traceMode { + switch traceKind { case linux.KCOV_TRACE_PC: - kcov.mode = traceMode + kcov.mode = linux.KCOV_MODE_TRACE_PC case linux.KCOV_TRACE_CMP: // We do not support KCOV_MODE_TRACE_CMP. return syserror.ENOTSUP @@ -175,6 +175,7 @@ func (kcov *Kcov) EnableTrace(ctx context.Context, traceMode uint8) error { } kcov.owningTask = t + t.SetKcov(kcov) t.RegisterWork(kcov) // Clear existing coverage data; the task expects to read only coverage data @@ -196,26 +197,35 @@ func (kcov *Kcov) DisableTrace(ctx context.Context) error { if t != kcov.owningTask { return syserror.EINVAL } - kcov.owningTask = nil kcov.mode = linux.KCOV_MODE_INIT - kcov.resetLocked() + kcov.owningTask = nil + kcov.mappable = nil return nil } -// Reset is called when the owning task exits. -func (kcov *Kcov) Reset() { +// Clear resets the mode and clears the owning task and memory mapping for kcov. +// It is called when the fd corresponding to kcov is closed. Note that the mode +// needs to be set so that the next call to kcov.TaskWork() will exit early. +func (kcov *Kcov) Clear() { kcov.mu.Lock() - kcov.resetLocked() + kcov.clearLocked() kcov.mu.Unlock() } -// The kcov instance is reset when the owning task exits or when tracing is -// disabled. -func (kcov *Kcov) resetLocked() { +func (kcov *Kcov) clearLocked() { + kcov.mode = linux.KCOV_MODE_INIT kcov.owningTask = nil - if kcov.mappable != nil { - kcov.mappable = nil - } + kcov.mappable = nil +} + +// OnTaskExit is called when the owning task exits. It is similar to +// kcov.Clear(), except the memory mapping is not cleared, so that the same +// mapping can be used in the future if kcov is enabled again by another task. +func (kcov *Kcov) OnTaskExit() { + kcov.mu.Lock() + kcov.mode = linux.KCOV_MODE_INIT + kcov.owningTask = nil + kcov.mu.Unlock() } // ConfigureMMap is called by the vfs.FileDescription for this kcov instance to diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go index d6c21adb7..16c427fc8 100644 --- a/pkg/sentry/kernel/kernel.go +++ b/pkg/sentry/kernel/kernel.go @@ -1738,3 +1738,18 @@ func (k *Kernel) ShmMount() *vfs.Mount { func (k *Kernel) SocketMount() *vfs.Mount { return k.socketMount } + +// Release releases resources owned by k. +// +// Precondition: This should only be called after the kernel is fully +// initialized, e.g. after k.Start() has been called. +func (k *Kernel) Release() { + if VFS2Enabled { + ctx := k.SupervisorContext() + k.hostMount.DecRef(ctx) + k.pipeMount.DecRef(ctx) + k.shmMount.DecRef(ctx) + k.socketMount.DecRef(ctx) + k.vfs.Release(ctx) + } +} diff --git a/pkg/sentry/kernel/seccomp.go b/pkg/sentry/kernel/seccomp.go index c38c5a40c..387edfa91 100644 --- a/pkg/sentry/kernel/seccomp.go +++ b/pkg/sentry/kernel/seccomp.go @@ -18,7 +18,6 @@ import ( "syscall" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/bpf" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/syserror" @@ -27,25 +26,18 @@ import ( const maxSyscallFilterInstructions = 1 << 15 -// seccompData is equivalent to struct seccomp_data, which contains the data -// passed to seccomp-bpf filters. -type seccompData struct { - // nr is the system call number. - nr int32 - - // arch is an AUDIT_ARCH_* value indicating the system call convention. - arch uint32 - - // instructionPointer is the value of the instruction pointer at the time - // of the system call. - instructionPointer uint64 - - // args contains the first 6 system call arguments. - args [6]uint64 -} - -func (d *seccompData) asBPFInput() bpf.Input { - return bpf.InputBytes{binary.Marshal(nil, usermem.ByteOrder, d), usermem.ByteOrder} +// dataAsBPFInput returns a serialized BPF program, only valid on the current task +// goroutine. +// +// Note: this is called for every syscall, which is a very hot path. +func dataAsBPFInput(t *Task, d *linux.SeccompData) bpf.Input { + buf := t.CopyScratchBuffer(d.SizeBytes()) + d.MarshalUnsafe(buf) + return bpf.InputBytes{ + Data: buf, + // Go-marshal always uses the native byte order. + Order: usermem.ByteOrder, + } } func seccompSiginfo(t *Task, errno, sysno int32, ip usermem.Addr) *arch.SignalInfo { @@ -112,20 +104,20 @@ func (t *Task) checkSeccompSyscall(sysno int32, args arch.SyscallArguments, ip u } func (t *Task) evaluateSyscallFilters(sysno int32, args arch.SyscallArguments, ip usermem.Addr) uint32 { - data := seccompData{ - nr: sysno, - arch: t.tc.st.AuditNumber, - instructionPointer: uint64(ip), + data := linux.SeccompData{ + Nr: sysno, + Arch: t.tc.st.AuditNumber, + InstructionPointer: uint64(ip), } // data.args is []uint64 and args is []arch.SyscallArgument (uintptr), so // we can't do any slicing tricks or even use copy/append here. for i, arg := range args { - if i >= len(data.args) { + if i >= len(data.Args) { break } - data.args[i] = arg.Uint64() + data.Args[i] = arg.Uint64() } - input := data.asBPFInput() + input := dataAsBPFInput(t, &data) ret := uint32(linux.SECCOMP_RET_ALLOW) f := t.syscallFilters.Load() diff --git a/pkg/sentry/kernel/signalfd/BUILD b/pkg/sentry/kernel/signalfd/BUILD index 3eb78e91b..76d472292 100644 --- a/pkg/sentry/kernel/signalfd/BUILD +++ b/pkg/sentry/kernel/signalfd/BUILD @@ -8,7 +8,6 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", - "//pkg/binary", "//pkg/context", "//pkg/sentry/fs", "//pkg/sentry/fs/anon", diff --git a/pkg/sentry/kernel/signalfd/signalfd.go b/pkg/sentry/kernel/signalfd/signalfd.go index b07e1c1bd..78f718cfe 100644 --- a/pkg/sentry/kernel/signalfd/signalfd.go +++ b/pkg/sentry/kernel/signalfd/signalfd.go @@ -17,7 +17,6 @@ package signalfd import ( "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/anon" @@ -103,8 +102,7 @@ func (s *SignalOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS } // Copy out the signal info using the specified format. - var buf [128]byte - binary.Marshal(buf[:0], usermem.ByteOrder, &linux.SignalfdSiginfo{ + infoNative := linux.SignalfdSiginfo{ Signo: uint32(info.Signo), Errno: info.Errno, Code: info.Code, @@ -113,9 +111,13 @@ func (s *SignalOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS Status: info.Status(), Overrun: uint32(info.Overrun()), Addr: info.Addr(), - }) - n, err := dst.CopyOut(ctx, buf[:]) - return int64(n), err + } + n, err := infoNative.WriteTo(dst.Writer(ctx)) + if err == usermem.ErrEndOfIOSequence { + // Partial copy-out ok. + err = nil + } + return n, err } // Readiness implements waiter.Waitable.Readiness. diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go index a436610c9..f796e0fa3 100644 --- a/pkg/sentry/kernel/task.go +++ b/pkg/sentry/kernel/task.go @@ -917,7 +917,7 @@ func (t *Task) SetKcov(k *Kcov) { // ResetKcov clears the kcov instance associated with t. func (t *Task) ResetKcov() { if t.kcov != nil { - t.kcov.Reset() + t.kcov.OnTaskExit() t.kcov = nil } } diff --git a/pkg/sentry/kernel/threads.go b/pkg/sentry/kernel/threads.go index 5ae5906e8..fdadb52c0 100644 --- a/pkg/sentry/kernel/threads.go +++ b/pkg/sentry/kernel/threads.go @@ -265,6 +265,13 @@ func (ns *PIDNamespace) Tasks() []*Task { return tasks } +// NumTasks returns the number of tasks in ns. +func (ns *PIDNamespace) NumTasks() int { + ns.owner.mu.RLock() + defer ns.owner.mu.RUnlock() + return len(ns.tids) +} + // ThreadGroups returns a snapshot of the thread groups in ns. func (ns *PIDNamespace) ThreadGroups() []*ThreadGroup { return ns.ThreadGroupsAppend(nil) diff --git a/pkg/sentry/kernel/vdso.go b/pkg/sentry/kernel/vdso.go index e44a139b3..9bc452e67 100644 --- a/pkg/sentry/kernel/vdso.go +++ b/pkg/sentry/kernel/vdso.go @@ -17,7 +17,6 @@ package kernel import ( "fmt" - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/pgalloc" @@ -28,6 +27,8 @@ import ( // // They are exposed to the VDSO via a parameter page managed by VDSOParamPage, // which also includes a sequence counter. +// +// +marshal type vdsoParams struct { monotonicReady uint64 monotonicBaseCycles int64 @@ -68,6 +69,13 @@ type VDSOParamPage struct { // checked in state_test_util tests, causing this field to change across // save / restore. seq uint64 + + // copyScratchBuffer is a temporary buffer used to marshal the params before + // copying it to the real parameter page. The parameter page is typically + // updated at a moderate frequency of ~O(seconds) throughout the lifetime of + // the sentry, so reusing this buffer is a good tradeoff between memory + // usage and the cost of allocation. + copyScratchBuffer []byte } // NewVDSOParamPage returns a VDSOParamPage. @@ -79,7 +87,11 @@ type VDSOParamPage struct { // * VDSOParamPage must be the only writer to fr. // * mfp.MemoryFile().MapInternal(fr) must return a single safemem.Block. func NewVDSOParamPage(mfp pgalloc.MemoryFileProvider, fr memmap.FileRange) *VDSOParamPage { - return &VDSOParamPage{mfp: mfp, fr: fr} + return &VDSOParamPage{ + mfp: mfp, + fr: fr, + copyScratchBuffer: make([]byte, (*vdsoParams)(nil).SizeBytes()), + } } // access returns a mapping of the param page. @@ -133,7 +145,8 @@ func (v *VDSOParamPage) Write(f func() vdsoParams) error { // Get the new params. p := f() - buf := binary.Marshal(nil, usermem.ByteOrder, p) + buf := v.copyScratchBuffer[:p.SizeBytes()] + p.MarshalUnsafe(buf) // Skip the sequence counter. if _, err := safemem.Copy(paramPage.DropFirst(8), safemem.BlockFromSafeSlice(buf)); err != nil { diff --git a/pkg/sentry/memmap/memmap.go b/pkg/sentry/memmap/memmap.go index a44fa2b95..7fd77925f 100644 --- a/pkg/sentry/memmap/memmap.go +++ b/pkg/sentry/memmap/memmap.go @@ -127,7 +127,7 @@ func (t Translation) FileRange() FileRange { // Preconditions: Same as Mappable.Translate. func CheckTranslateResult(required, optional MappableRange, at usermem.AccessType, ts []Translation, terr error) error { // Verify that the inputs to Mappable.Translate were valid. - if !required.WellFormed() || required.Length() <= 0 { + if !required.WellFormed() || required.Length() == 0 { panic(fmt.Sprintf("invalid required range: %v", required)) } if !usermem.Addr(required.Start).IsPageAligned() || !usermem.Addr(required.End).IsPageAligned() { @@ -145,7 +145,7 @@ func CheckTranslateResult(required, optional MappableRange, at usermem.AccessTyp return fmt.Errorf("first Translation %+v does not cover start of required range %v", ts[0], required) } for i, t := range ts { - if !t.Source.WellFormed() || t.Source.Length() <= 0 { + if !t.Source.WellFormed() || t.Source.Length() == 0 { return fmt.Errorf("Translation %+v has invalid Source", t) } if !usermem.Addr(t.Source.Start).IsPageAligned() || !usermem.Addr(t.Source.End).IsPageAligned() { diff --git a/pkg/sentry/mm/mm.go b/pkg/sentry/mm/mm.go index 8c9f11cce..92cc87d84 100644 --- a/pkg/sentry/mm/mm.go +++ b/pkg/sentry/mm/mm.go @@ -235,6 +235,20 @@ type MemoryManager struct { // vdsoSigReturnAddr is the address of 'vdso_sigreturn'. vdsoSigReturnAddr uint64 + + // membarrierPrivateEnabled is non-zero if EnableMembarrierPrivate has + // previously been called. Since, as of this writing, + // MEMBARRIER_CMD_PRIVATE_EXPEDITED is implemented as a global memory + // barrier, membarrierPrivateEnabled has no other effect. + // + // membarrierPrivateEnabled is accessed using atomic memory operations. + membarrierPrivateEnabled uint32 + + // membarrierRSeqEnabled is non-zero if EnableMembarrierRSeq has previously + // been called. + // + // membarrierRSeqEnabled is accessed using atomic memory operations. + membarrierRSeqEnabled uint32 } // vma represents a virtual memory area. diff --git a/pkg/sentry/mm/pma.go b/pkg/sentry/mm/pma.go index 30facebf7..7e5f7de64 100644 --- a/pkg/sentry/mm/pma.go +++ b/pkg/sentry/mm/pma.go @@ -36,7 +36,7 @@ import ( // * ar.Length() != 0. func (mm *MemoryManager) existingPMAsLocked(ar usermem.AddrRange, at usermem.AccessType, ignorePermissions bool, needInternalMappings bool) pmaIterator { if checkInvariants { - if !ar.WellFormed() || ar.Length() <= 0 { + if !ar.WellFormed() || ar.Length() == 0 { panic(fmt.Sprintf("invalid ar: %v", ar)) } } @@ -100,7 +100,7 @@ func (mm *MemoryManager) existingVecPMAsLocked(ars usermem.AddrRangeSeq, at user // (i.e. permission checks must have been performed against vmas). func (mm *MemoryManager) getPMAsLocked(ctx context.Context, vseg vmaIterator, ar usermem.AddrRange, at usermem.AccessType) (pmaIterator, pmaGapIterator, error) { if checkInvariants { - if !ar.WellFormed() || ar.Length() <= 0 { + if !ar.WellFormed() || ar.Length() == 0 { panic(fmt.Sprintf("invalid ar: %v", ar)) } if !vseg.Ok() { @@ -193,7 +193,7 @@ func (mm *MemoryManager) getVecPMAsLocked(ctx context.Context, ars usermem.AddrR // getVecPMAsLocked; other clients should call one of those instead. func (mm *MemoryManager) getPMAsInternalLocked(ctx context.Context, vseg vmaIterator, ar usermem.AddrRange, at usermem.AccessType) (pmaIterator, pmaGapIterator, error) { if checkInvariants { - if !ar.WellFormed() || ar.Length() <= 0 || !ar.IsPageAligned() { + if !ar.WellFormed() || ar.Length() == 0 || !ar.IsPageAligned() { panic(fmt.Sprintf("invalid ar: %v", ar)) } if !vseg.Ok() { @@ -223,7 +223,7 @@ func (mm *MemoryManager) getPMAsInternalLocked(ctx context.Context, vseg vmaIter // Need a pma here. optAR := vseg.Range().Intersect(pgap.Range()) if checkInvariants { - if optAR.Length() <= 0 { + if optAR.Length() == 0 { panic(fmt.Sprintf("vseg %v and pgap %v do not overlap", vseg, pgap)) } } @@ -560,7 +560,7 @@ func (mm *MemoryManager) isPMACopyOnWriteLocked(vseg vmaIterator, pseg pmaIterat // Invalidate implements memmap.MappingSpace.Invalidate. func (mm *MemoryManager) Invalidate(ar usermem.AddrRange, opts memmap.InvalidateOpts) { if checkInvariants { - if !ar.WellFormed() || ar.Length() <= 0 || !ar.IsPageAligned() { + if !ar.WellFormed() || ar.Length() == 0 || !ar.IsPageAligned() { panic(fmt.Sprintf("invalid ar: %v", ar)) } } @@ -583,7 +583,7 @@ func (mm *MemoryManager) Invalidate(ar usermem.AddrRange, opts memmap.Invalidate // * ar must be page-aligned. func (mm *MemoryManager) invalidateLocked(ar usermem.AddrRange, invalidatePrivate, invalidateShared bool) { if checkInvariants { - if !ar.WellFormed() || ar.Length() <= 0 || !ar.IsPageAligned() { + if !ar.WellFormed() || ar.Length() == 0 || !ar.IsPageAligned() { panic(fmt.Sprintf("invalid ar: %v", ar)) } } @@ -629,7 +629,7 @@ func (mm *MemoryManager) invalidateLocked(ar usermem.AddrRange, invalidatePrivat // * ar must be page-aligned. func (mm *MemoryManager) Pin(ctx context.Context, ar usermem.AddrRange, at usermem.AccessType, ignorePermissions bool) ([]PinnedRange, error) { if checkInvariants { - if !ar.WellFormed() || ar.Length() <= 0 || !ar.IsPageAligned() { + if !ar.WellFormed() || ar.Length() == 0 || !ar.IsPageAligned() { panic(fmt.Sprintf("invalid ar: %v", ar)) } } @@ -715,10 +715,10 @@ func Unpin(prs []PinnedRange) { // * oldAR and newAR must be page-aligned. func (mm *MemoryManager) movePMAsLocked(oldAR, newAR usermem.AddrRange) { if checkInvariants { - if !oldAR.WellFormed() || oldAR.Length() <= 0 || !oldAR.IsPageAligned() { + if !oldAR.WellFormed() || oldAR.Length() == 0 || !oldAR.IsPageAligned() { panic(fmt.Sprintf("invalid oldAR: %v", oldAR)) } - if !newAR.WellFormed() || newAR.Length() <= 0 || !newAR.IsPageAligned() { + if !newAR.WellFormed() || newAR.Length() == 0 || !newAR.IsPageAligned() { panic(fmt.Sprintf("invalid newAR: %v", newAR)) } if oldAR.Length() > newAR.Length() { @@ -778,7 +778,7 @@ func (mm *MemoryManager) movePMAsLocked(oldAR, newAR usermem.AddrRange) { // into mm.pmas. func (mm *MemoryManager) getPMAInternalMappingsLocked(pseg pmaIterator, ar usermem.AddrRange) (pmaGapIterator, error) { if checkInvariants { - if !ar.WellFormed() || ar.Length() <= 0 { + if !ar.WellFormed() || ar.Length() == 0 { panic(fmt.Sprintf("invalid ar: %v", ar)) } if !pseg.Range().Contains(ar.Start) { @@ -831,7 +831,7 @@ func (mm *MemoryManager) getVecPMAInternalMappingsLocked(ars usermem.AddrRangeSe // * pseg.Range().Contains(ar.Start). func (mm *MemoryManager) internalMappingsLocked(pseg pmaIterator, ar usermem.AddrRange) safemem.BlockSeq { if checkInvariants { - if !ar.WellFormed() || ar.Length() <= 0 { + if !ar.WellFormed() || ar.Length() == 0 { panic(fmt.Sprintf("invalid ar: %v", ar)) } if !pseg.Range().Contains(ar.Start) { @@ -1050,7 +1050,7 @@ func (pseg pmaIterator) fileRangeOf(ar usermem.AddrRange) memmap.FileRange { if !pseg.Ok() { panic("terminal pma iterator") } - if !ar.WellFormed() || ar.Length() <= 0 { + if !ar.WellFormed() || ar.Length() == 0 { panic(fmt.Sprintf("invalid ar: %v", ar)) } if !pseg.Range().IsSupersetOf(ar) { diff --git a/pkg/sentry/mm/syscalls.go b/pkg/sentry/mm/syscalls.go index a2555ba1a..675efdc7c 100644 --- a/pkg/sentry/mm/syscalls.go +++ b/pkg/sentry/mm/syscalls.go @@ -17,6 +17,7 @@ package mm import ( "fmt" mrand "math/rand" + "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" @@ -1274,3 +1275,27 @@ func (mm *MemoryManager) VirtualDataSize() uint64 { defer mm.mappingMu.RUnlock() return mm.dataAS } + +// EnableMembarrierPrivate causes future calls to IsMembarrierPrivateEnabled to +// return true. +func (mm *MemoryManager) EnableMembarrierPrivate() { + atomic.StoreUint32(&mm.membarrierPrivateEnabled, 1) +} + +// IsMembarrierPrivateEnabled returns true if mm.EnableMembarrierPrivate() has +// previously been called. +func (mm *MemoryManager) IsMembarrierPrivateEnabled() bool { + return atomic.LoadUint32(&mm.membarrierPrivateEnabled) != 0 +} + +// EnableMembarrierRSeq causes future calls to IsMembarrierRSeqEnabled to +// return true. +func (mm *MemoryManager) EnableMembarrierRSeq() { + atomic.StoreUint32(&mm.membarrierRSeqEnabled, 1) +} + +// IsMembarrierRSeqEnabled returns true if mm.EnableMembarrierRSeq() has +// previously been called. +func (mm *MemoryManager) IsMembarrierRSeqEnabled() bool { + return atomic.LoadUint32(&mm.membarrierRSeqEnabled) != 0 +} diff --git a/pkg/sentry/mm/vma.go b/pkg/sentry/mm/vma.go index f769d8294..b8df72813 100644 --- a/pkg/sentry/mm/vma.go +++ b/pkg/sentry/mm/vma.go @@ -266,7 +266,7 @@ func (mm *MemoryManager) mlockedBytesRangeLocked(ar usermem.AddrRange) uint64 { // * ar.Length() != 0. func (mm *MemoryManager) getVMAsLocked(ctx context.Context, ar usermem.AddrRange, at usermem.AccessType, ignorePermissions bool) (vmaIterator, vmaGapIterator, error) { if checkInvariants { - if !ar.WellFormed() || ar.Length() <= 0 { + if !ar.WellFormed() || ar.Length() == 0 { panic(fmt.Sprintf("invalid ar: %v", ar)) } } @@ -350,7 +350,7 @@ const guardBytes = 256 * usermem.PageSize // * ar must be page-aligned. func (mm *MemoryManager) unmapLocked(ctx context.Context, ar usermem.AddrRange) vmaGapIterator { if checkInvariants { - if !ar.WellFormed() || ar.Length() <= 0 || !ar.IsPageAligned() { + if !ar.WellFormed() || ar.Length() == 0 || !ar.IsPageAligned() { panic(fmt.Sprintf("invalid ar: %v", ar)) } } @@ -371,7 +371,7 @@ func (mm *MemoryManager) unmapLocked(ctx context.Context, ar usermem.AddrRange) // * ar must be page-aligned. func (mm *MemoryManager) removeVMAsLocked(ctx context.Context, ar usermem.AddrRange) vmaGapIterator { if checkInvariants { - if !ar.WellFormed() || ar.Length() <= 0 || !ar.IsPageAligned() { + if !ar.WellFormed() || ar.Length() == 0 || !ar.IsPageAligned() { panic(fmt.Sprintf("invalid ar: %v", ar)) } } @@ -511,7 +511,7 @@ func (vseg vmaIterator) mappableRangeOf(ar usermem.AddrRange) memmap.MappableRan if vseg.ValuePtr().mappable == nil { panic("MappableRange is meaningless for anonymous vma") } - if !ar.WellFormed() || ar.Length() <= 0 { + if !ar.WellFormed() || ar.Length() == 0 { panic(fmt.Sprintf("invalid ar: %v", ar)) } if !vseg.Range().IsSupersetOf(ar) { @@ -536,7 +536,7 @@ func (vseg vmaIterator) addrRangeOf(mr memmap.MappableRange) usermem.AddrRange { if vseg.ValuePtr().mappable == nil { panic("MappableRange is meaningless for anonymous vma") } - if !mr.WellFormed() || mr.Length() <= 0 { + if !mr.WellFormed() || mr.Length() == 0 { panic(fmt.Sprintf("invalid mr: %v", mr)) } if !vseg.mappableRange().IsSupersetOf(mr) { diff --git a/pkg/sentry/platform/BUILD b/pkg/sentry/platform/BUILD index 209b28053..db7d55ef2 100644 --- a/pkg/sentry/platform/BUILD +++ b/pkg/sentry/platform/BUILD @@ -15,6 +15,7 @@ go_library( "//pkg/context", "//pkg/seccomp", "//pkg/sentry/arch", + "//pkg/sentry/hostmm", "//pkg/sentry/memmap", "//pkg/usermem", ], diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD index 3970dd81d..dd2bbeb12 100644 --- a/pkg/sentry/platform/kvm/BUILD +++ b/pkg/sentry/platform/kvm/BUILD @@ -9,12 +9,12 @@ go_library( "bluepill.go", "bluepill_allocator.go", "bluepill_amd64.go", - "bluepill_amd64.s", "bluepill_amd64_unsafe.go", "bluepill_arm64.go", "bluepill_arm64.s", "bluepill_arm64_unsafe.go", "bluepill_fault.go", + "bluepill_impl_amd64.s", "bluepill_unsafe.go", "context.go", "filters_amd64.go", @@ -56,6 +56,7 @@ go_library( "//pkg/sentry/time", "//pkg/sync", "//pkg/usermem", + "@org_golang_x_sys//unix:go_default_library", ], ) @@ -78,6 +79,15 @@ go_test( "//pkg/sentry/platform/kvm/testutil", "//pkg/sentry/platform/ring0", "//pkg/sentry/platform/ring0/pagetables", + "//pkg/sentry/time", "//pkg/usermem", ], ) + +genrule( + name = "bluepill_impl_amd64", + srcs = ["bluepill_amd64.s"], + outs = ["bluepill_impl_amd64.s"], + cmd = "(echo -e '// build +amd64\\n' && $(location //pkg/sentry/platform/ring0/gen_offsets) && cat $(SRCS)) > $@", + tools = ["//pkg/sentry/platform/ring0/gen_offsets"], +) diff --git a/pkg/sentry/platform/kvm/bluepill_amd64.s b/pkg/sentry/platform/kvm/bluepill_amd64.s index 2bc34a435..025ea93b5 100644 --- a/pkg/sentry/platform/kvm/bluepill_amd64.s +++ b/pkg/sentry/platform/kvm/bluepill_amd64.s @@ -19,11 +19,6 @@ // This is guaranteed to be zero. #define VCPU_CPU 0x0 -// CPU_SELF is the self reference in ring0's percpu. -// -// This is guaranteed to be zero. -#define CPU_SELF 0x0 - // Context offsets. // // Only limited use of the context is done in the assembly stub below, most is @@ -44,7 +39,7 @@ begin: LEAQ VCPU_CPU(AX), BX BYTE CLI; check_vcpu: - MOVQ CPU_SELF(GS), CX + MOVQ ENTRY_CPU_SELF(GS), CX CMPQ BX, CX JE right_vCPU wrong_vcpu: diff --git a/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go b/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go index 03a98512e..0a54dd30d 100644 --- a/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go +++ b/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go @@ -83,5 +83,34 @@ func bluepillStopGuest(c *vCPU) { // //go:nosplit func bluepillReadyStopGuest(c *vCPU) bool { - return c.runData.readyForInterruptInjection != 0 + if c.runData.readyForInterruptInjection == 0 { + return false + } + + if c.runData.ifFlag == 0 { + // This is impossible if readyForInterruptInjection is 1. + throw("interrupts are disabled") + } + + // Disable interrupts if we are in the kernel space. + // + // When the Sentry switches into the kernel mode, it disables + // interrupts. But when goruntime switches on a goroutine which has + // been saved in the host mode, it restores flags and this enables + // interrupts. See the comment of UserFlagsSet for more details. + uregs := userRegs{} + err := c.getUserRegisters(&uregs) + if err != 0 { + throw("failed to get user registers") + } + + if ring0.IsKernelFlags(uregs.RFLAGS) { + uregs.RFLAGS &^= ring0.KernelFlagsClear + err = c.setUserRegisters(&uregs) + if err != 0 { + throw("failed to set user registers") + } + return false + } + return true } diff --git a/pkg/sentry/platform/kvm/filters_amd64.go b/pkg/sentry/platform/kvm/filters_amd64.go index 7d949f1dd..d3d216aa5 100644 --- a/pkg/sentry/platform/kvm/filters_amd64.go +++ b/pkg/sentry/platform/kvm/filters_amd64.go @@ -17,14 +17,23 @@ package kvm import ( "syscall" + "golang.org/x/sys/unix" + + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/seccomp" ) // SyscallFilters returns syscalls made exclusively by the KVM platform. func (*KVM) SyscallFilters() seccomp.SyscallRules { return seccomp.SyscallRules{ - syscall.SYS_ARCH_PRCTL: {}, - syscall.SYS_IOCTL: {}, + syscall.SYS_ARCH_PRCTL: {}, + syscall.SYS_IOCTL: {}, + unix.SYS_MEMBARRIER: []seccomp.Rule{ + { + seccomp.EqualTo(linux.MEMBARRIER_CMD_PRIVATE_EXPEDITED), + seccomp.EqualTo(0), + }, + }, syscall.SYS_MMAP: {}, syscall.SYS_RT_SIGSUSPEND: {}, syscall.SYS_RT_SIGTIMEDWAIT: {}, diff --git a/pkg/sentry/platform/kvm/filters_arm64.go b/pkg/sentry/platform/kvm/filters_arm64.go index 9245d07c2..21abc2a3d 100644 --- a/pkg/sentry/platform/kvm/filters_arm64.go +++ b/pkg/sentry/platform/kvm/filters_arm64.go @@ -17,13 +17,22 @@ package kvm import ( "syscall" + "golang.org/x/sys/unix" + + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/seccomp" ) // SyscallFilters returns syscalls made exclusively by the KVM platform. func (*KVM) SyscallFilters() seccomp.SyscallRules { return seccomp.SyscallRules{ - syscall.SYS_IOCTL: {}, + syscall.SYS_IOCTL: {}, + unix.SYS_MEMBARRIER: []seccomp.Rule{ + { + seccomp.EqualTo(linux.MEMBARRIER_CMD_PRIVATE_EXPEDITED), + seccomp.EqualTo(0), + }, + }, syscall.SYS_MMAP: {}, syscall.SYS_RT_SIGSUSPEND: {}, syscall.SYS_RT_SIGTIMEDWAIT: {}, diff --git a/pkg/sentry/platform/kvm/kvm.go b/pkg/sentry/platform/kvm/kvm.go index ae813e24e..dd45ad10b 100644 --- a/pkg/sentry/platform/kvm/kvm.go +++ b/pkg/sentry/platform/kvm/kvm.go @@ -63,6 +63,9 @@ type runData struct { type KVM struct { platform.NoCPUPreemptionDetection + // KVM never changes mm_structs. + platform.UseHostProcessMemoryBarrier + // machine is the backing VM. machine *machine } @@ -156,15 +159,7 @@ func (*KVM) MaxUserAddress() usermem.Addr { func (k *KVM) NewAddressSpace(_ interface{}) (platform.AddressSpace, <-chan struct{}, error) { // Allocate page tables and install system mappings. pageTables := pagetables.New(newAllocator()) - applyPhysicalRegions(func(pr physicalRegion) bool { - // Map the kernel in the upper half. - pageTables.Map( - usermem.Addr(ring0.KernelStartAddress|pr.virtual), - pr.length, - pagetables.MapOpts{AccessType: usermem.AnyAccess}, - pr.physical) - return true // Keep iterating. - }) + k.machine.mapUpperHalf(pageTables) // Return the new address space. return &addressSpace{ diff --git a/pkg/sentry/platform/kvm/kvm_const.go b/pkg/sentry/platform/kvm/kvm_const.go index 5c4b18899..6abaa21c4 100644 --- a/pkg/sentry/platform/kvm/kvm_const.go +++ b/pkg/sentry/platform/kvm/kvm_const.go @@ -26,12 +26,16 @@ const ( _KVM_RUN = 0xae80 _KVM_NMI = 0xae9a _KVM_CHECK_EXTENSION = 0xae03 + _KVM_GET_TSC_KHZ = 0xaea3 + _KVM_SET_TSC_KHZ = 0xaea2 _KVM_INTERRUPT = 0x4004ae86 _KVM_SET_MSRS = 0x4008ae89 _KVM_SET_USER_MEMORY_REGION = 0x4020ae46 _KVM_SET_REGS = 0x4090ae82 _KVM_SET_SREGS = 0x4138ae84 + _KVM_GET_MSRS = 0xc008ae88 _KVM_GET_REGS = 0x8090ae81 + _KVM_GET_SREGS = 0x8138ae83 _KVM_GET_SUPPORTED_CPUID = 0xc008ae05 _KVM_SET_CPUID2 = 0x4008ae90 _KVM_SET_SIGNAL_MASK = 0x4004ae8b @@ -79,11 +83,14 @@ const ( ) // KVM hypercall list. +// // Canonical list of hypercalls supported. const ( // On amd64, it uses 'HLT' to leave the guest. + // // Unlike amd64, arm64 can only uses mmio_exit/psci to leave the guest. - // _KVM_HYPERCALL_VMEXIT is only used on Arm64 for now. + // + // _KVM_HYPERCALL_VMEXIT is only used on arm64 for now. _KVM_HYPERCALL_VMEXIT int = iota _KVM_HYPERCALL_MAX ) diff --git a/pkg/sentry/platform/kvm/kvm_const_arm64.go b/pkg/sentry/platform/kvm/kvm_const_arm64.go index 9a7be3655..84df0f878 100644 --- a/pkg/sentry/platform/kvm/kvm_const_arm64.go +++ b/pkg/sentry/platform/kvm/kvm_const_arm64.go @@ -101,13 +101,20 @@ const ( // Arm64: Memory Attribute Indirection Register EL1. const ( - _MT_DEVICE_nGnRnE = 0 - _MT_DEVICE_nGnRE = 1 - _MT_DEVICE_GRE = 2 - _MT_NORMAL_NC = 3 - _MT_NORMAL = 4 - _MT_NORMAL_WT = 5 - _MT_EL1_INIT = (0 << _MT_DEVICE_nGnRnE) | (0x4 << _MT_DEVICE_nGnRE * 8) | (0xc << _MT_DEVICE_GRE * 8) | (0x44 << _MT_NORMAL_NC * 8) | (0xff << _MT_NORMAL * 8) | (0xbb << _MT_NORMAL_WT * 8) + _MT_DEVICE_nGnRnE = 0 + _MT_DEVICE_nGnRE = 1 + _MT_DEVICE_GRE = 2 + _MT_NORMAL_NC = 3 + _MT_NORMAL = 4 + _MT_NORMAL_WT = 5 + _MT_ATTR_DEVICE_nGnRnE = 0x00 + _MT_ATTR_DEVICE_nGnRE = 0x04 + _MT_ATTR_DEVICE_GRE = 0x0c + _MT_ATTR_NORMAL_NC = 0x44 + _MT_ATTR_NORMAL_WT = 0xbb + _MT_ATTR_NORMAL = 0xff + _MT_ATTR_MASK = 0xff + _MT_EL1_INIT = (_MT_ATTR_DEVICE_nGnRnE << (_MT_DEVICE_nGnRnE * 8)) | (_MT_ATTR_DEVICE_nGnRE << (_MT_DEVICE_nGnRE * 8)) | (_MT_ATTR_DEVICE_GRE << (_MT_DEVICE_GRE * 8)) | (_MT_ATTR_NORMAL_NC << (_MT_NORMAL_NC * 8)) | (_MT_ATTR_NORMAL << (_MT_NORMAL * 8)) | (_MT_ATTR_NORMAL_WT << (_MT_NORMAL_WT * 8)) ) const ( diff --git a/pkg/sentry/platform/kvm/kvm_test.go b/pkg/sentry/platform/kvm/kvm_test.go index 45b3180f1..2e12470aa 100644 --- a/pkg/sentry/platform/kvm/kvm_test.go +++ b/pkg/sentry/platform/kvm/kvm_test.go @@ -27,6 +27,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/platform/kvm/testutil" "gvisor.dev/gvisor/pkg/sentry/platform/ring0" "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" + ktime "gvisor.dev/gvisor/pkg/sentry/time" "gvisor.dev/gvisor/pkg/usermem" ) @@ -442,6 +443,22 @@ func TestWrongVCPU(t *testing.T) { }) } +func TestRdtsc(t *testing.T) { + var i int // Iteration count. + kvmTest(t, nil, func(c *vCPU) bool { + start := ktime.Rdtsc() + bluepill(c) + guest := ktime.Rdtsc() + redpill() + end := ktime.Rdtsc() + if start > guest || guest > end { + t.Errorf("inconsistent time: start=%d, guest=%d, end=%d", start, guest, end) + } + i++ + return i < 100 + }) +} + func BenchmarkApplicationSyscall(b *testing.B) { var ( i int // Iteration includes machine.Get() / machine.Put(). diff --git a/pkg/sentry/platform/kvm/machine.go b/pkg/sentry/platform/kvm/machine.go index a74930423..455e2bd20 100644 --- a/pkg/sentry/platform/kvm/machine.go +++ b/pkg/sentry/platform/kvm/machine.go @@ -155,7 +155,7 @@ func (m *machine) newVCPU() *vCPU { fd: int(fd), machine: m, } - c.CPU.Init(&m.kernel, c) + c.CPU.Init(&m.kernel, c.id, c) m.vCPUsByID[c.id] = c // Ensure the signal mask is correct. @@ -183,9 +183,6 @@ func newMachine(vm int) (*machine, error) { // Create the machine. m := &machine{fd: vm} m.available.L = &m.mu - m.kernel.Init(ring0.KernelOpts{ - PageTables: pagetables.New(newAllocator()), - }) // Pull the maximum vCPUs. maxVCPUs, _, errno := syscall.RawSyscall(syscall.SYS_IOCTL, uintptr(m.fd), _KVM_CHECK_EXTENSION, _KVM_CAP_MAX_VCPUS) @@ -197,6 +194,9 @@ func newMachine(vm int) (*machine, error) { log.Debugf("The maximum number of vCPUs is %d.", m.maxVCPUs) m.vCPUsByTID = make(map[uint64]*vCPU) m.vCPUsByID = make([]*vCPU, m.maxVCPUs) + m.kernel.Init(ring0.KernelOpts{ + PageTables: pagetables.New(newAllocator()), + }, m.maxVCPUs) // Pull the maximum slots. maxSlots, _, errno := syscall.RawSyscall(syscall.SYS_IOCTL, uintptr(m.fd), _KVM_CHECK_EXTENSION, _KVM_CAP_MAX_MEMSLOTS) @@ -219,15 +219,9 @@ func newMachine(vm int) (*machine, error) { pagetables.MapOpts{AccessType: usermem.AnyAccess}, pr.physical) - // And keep everything in the upper half. - m.kernel.PageTables.Map( - usermem.Addr(ring0.KernelStartAddress|pr.virtual), - pr.length, - pagetables.MapOpts{AccessType: usermem.AnyAccess}, - pr.physical) - return true // Keep iterating. }) + m.mapUpperHalf(m.kernel.PageTables) var physicalRegionsReadOnly []physicalRegion var physicalRegionsAvailable []physicalRegion @@ -365,6 +359,11 @@ func (m *machine) Destroy() { // Get gets an available vCPU. // // This will return with the OS thread locked. +// +// It is guaranteed that if any OS thread TID is in guest, m.vCPUs[TID] points +// to the vCPU in which the OS thread TID is running. So if Get() returns with +// the corrent context in guest, the vCPU of it must be the same as what +// Get() returns. func (m *machine) Get() *vCPU { m.mu.RLock() runtime.LockOSThread() diff --git a/pkg/sentry/platform/kvm/machine_amd64.go b/pkg/sentry/platform/kvm/machine_amd64.go index ccaf3a028..c67127d95 100644 --- a/pkg/sentry/platform/kvm/machine_amd64.go +++ b/pkg/sentry/platform/kvm/machine_amd64.go @@ -18,14 +18,17 @@ package kvm import ( "fmt" + "math/big" "reflect" "runtime/debug" "syscall" + "gvisor.dev/gvisor/pkg/cpuid" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/platform/ring0" "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" + ktime "gvisor.dev/gvisor/pkg/sentry/time" "gvisor.dev/gvisor/pkg/usermem" ) @@ -131,6 +134,7 @@ func (c *vCPU) initArchState() error { // Set the entrypoint for the kernel. kernelUserRegs.RIP = uint64(reflect.ValueOf(ring0.Start).Pointer()) kernelUserRegs.RAX = uint64(reflect.ValueOf(&c.CPU).Pointer()) + kernelUserRegs.RSP = c.StackTop() kernelUserRegs.RFLAGS = ring0.KernelFlagsSet // Set the system registers. @@ -139,8 +143,8 @@ func (c *vCPU) initArchState() error { } // Set the user registers. - if err := c.setUserRegisters(&kernelUserRegs); err != nil { - return err + if errno := c.setUserRegisters(&kernelUserRegs); errno != 0 { + return fmt.Errorf("error setting user registers: %v", errno) } // Allocate some floating point state save area for the local vCPU. @@ -153,6 +157,133 @@ func (c *vCPU) initArchState() error { return c.setSystemTime() } +// bitsForScaling returns the bits available for storing the fraction component +// of the TSC scaling ratio. This allows us to replicate the (bad) math done by +// the kernel below in scaledTSC, and ensure we can compute an exact zero +// offset in setSystemTime. +// +// These constants correspond to kvm_tsc_scaling_ratio_frac_bits. +var bitsForScaling = func() int64 { + fs := cpuid.HostFeatureSet() + if fs.Intel() { + return 48 // See vmx.c (kvm sources). + } else if fs.AMD() { + return 32 // See svm.c (svm sources). + } else { + return 63 // Unknown: theoretical maximum. + } +}() + +// scaledTSC returns the host TSC scaled by the given frequency. +// +// This assumes a current frequency of 1. We require only the unitless ratio of +// rawFreq to some current frequency. See setSystemTime for context. +// +// The kernel math guarantees that all bits of the multiplication and division +// will be correctly preserved and applied. However, it is not possible to +// actually store the ratio correctly. So we need to use the same schema in +// order to calculate the scaled frequency and get the same result. +// +// We can assume that the current frequency is (1), so we are calculating a +// strict inverse of this value. This simplifies this function considerably. +// +// Roughly, the returned value "scaledTSC" will have: +// scaledTSC/hostTSC == 1/rawFreq +// +//go:nosplit +func scaledTSC(rawFreq uintptr) int64 { + scale := int64(1 << bitsForScaling) + ratio := big.NewInt(scale / int64(rawFreq)) + ratio.Mul(ratio, big.NewInt(int64(ktime.Rdtsc()))) + ratio.Div(ratio, big.NewInt(scale)) + return ratio.Int64() +} + +// setSystemTime sets the vCPU to the system time. +func (c *vCPU) setSystemTime() error { + // First, scale down the clock frequency to the lowest value allowed by + // the API itself. How low we can go depends on the underlying + // hardware, but it is typically ~1/2^48 for Intel, ~1/2^32 for AMD. + // Even the lower bound here will take a 4GHz frequency down to 1Hz, + // meaning that everything should be able to handle a Khz setting of 1 + // with bits to spare. + // + // Note that reducing the clock does not typically require special + // capabilities as it is emulated in KVM. We don't actually use this + // capability, but it means that this method should be robust to + // different hardware configurations. + rawFreq, err := c.getTSCFreq() + if err != nil { + return c.setSystemTimeLegacy() + } + if err := c.setTSCFreq(1); err != nil { + return c.setSystemTimeLegacy() + } + + // Always restore the original frequency. + defer func() { + if err := c.setTSCFreq(rawFreq); err != nil { + panic(err.Error()) + } + }() + + // Attempt to set the system time in this compressed world. The + // calculation for offset normally looks like: + // + // offset = target_tsc - kvm_scale_tsc(vcpu, rdtsc()); + // + // So as long as the kvm_scale_tsc component is constant before and + // after the call to set the TSC value (and it is passes as the + // target_tsc), we will compute an offset value of zero. + // + // This is effectively cheating to make our "setSystemTime" call so + // unbelievably, incredibly fast that we do it "instantly" and all the + // calculations result in an offset of zero. + lastTSC := scaledTSC(rawFreq) + for { + if err := c.setTSC(uint64(lastTSC)); err != nil { + return err + } + nextTSC := scaledTSC(rawFreq) + if lastTSC == nextTSC { + return nil + } + lastTSC = nextTSC // Try again. + } +} + +// setSystemTimeLegacy calibrates and sets an approximate system time. +func (c *vCPU) setSystemTimeLegacy() error { + const minIterations = 10 + minimum := uint64(0) + for iter := 0; ; iter++ { + // Try to set the TSC to an estimate of where it will be + // on the host during a "fast" system call iteration. + start := uint64(ktime.Rdtsc()) + if err := c.setTSC(start + (minimum / 2)); err != nil { + return err + } + // See if this is our new minimum call time. Note that this + // serves two functions: one, we make sure that we are + // accurately predicting the offset we need to set. Second, we + // don't want to do the final set on a slow call, which could + // produce a really bad result. + end := uint64(ktime.Rdtsc()) + if end < start { + continue // Totally bogus: unstable TSC? + } + current := end - start + if current < minimum || iter == 0 { + minimum = current // Set our new minimum. + } + // Is this past minIterations and within ~10% of minimum? + upperThreshold := (((minimum << 3) + minimum) >> 3) + if iter >= minIterations && current <= upperThreshold { + return nil + } + } +} + // nonCanonical generates a canonical address return. // //go:nosplit @@ -332,3 +463,41 @@ func rdonlyRegionsForSetMem() (phyRegions []physicalRegion) { func availableRegionsForSetMem() (phyRegions []physicalRegion) { return physicalRegions } + +var execRegions = func() (regions []region) { + applyVirtualRegions(func(vr virtualRegion) { + if excludeVirtualRegion(vr) || vr.filename == "[vsyscall]" { + return + } + if vr.accessType.Execute { + regions = append(regions, vr.region) + } + }) + return +}() + +func (m *machine) mapUpperHalf(pageTable *pagetables.PageTables) { + for _, r := range execRegions { + physical, length, ok := translateToPhysical(r.virtual) + if !ok || length < r.length { + panic("impossilbe translation") + } + pageTable.Map( + usermem.Addr(ring0.KernelStartAddress|r.virtual), + r.length, + pagetables.MapOpts{AccessType: usermem.Execute}, + physical) + } + for start, end := range m.kernel.EntryRegions() { + regionLen := end - start + physical, length, ok := translateToPhysical(start) + if !ok || length < regionLen { + panic("impossible translation") + } + pageTable.Map( + usermem.Addr(ring0.KernelStartAddress|start), + regionLen, + pagetables.MapOpts{AccessType: usermem.ReadWrite}, + physical) + } +} diff --git a/pkg/sentry/platform/kvm/machine_amd64_unsafe.go b/pkg/sentry/platform/kvm/machine_amd64_unsafe.go index 290f035dd..b430f92c6 100644 --- a/pkg/sentry/platform/kvm/machine_amd64_unsafe.go +++ b/pkg/sentry/platform/kvm/machine_amd64_unsafe.go @@ -23,7 +23,6 @@ import ( "unsafe" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/time" ) // loadSegments copies the current segments. @@ -61,91 +60,63 @@ func (c *vCPU) setCPUID() error { return nil } -// setSystemTime sets the TSC for the vCPU. +// getTSCFreq gets the TSC frequency. // -// This has to make the call many times in order to minimize the intrinsic -// error in the offset. Unfortunately KVM does not expose a relative offset via -// the API, so this is an approximation. We do this via an iterative algorithm. -// This has the advantage that it can generally deal with highly variable -// system call times and should converge on the correct offset. -func (c *vCPU) setSystemTime() error { - const ( - _MSR_IA32_TSC = 0x00000010 - calibrateTries = 10 - ) - registers := modelControlRegisters{ - nmsrs: 1, - } - registers.entries[0] = modelControlRegister{ - index: _MSR_IA32_TSC, +// If mustSucceed is true, then this function panics on error. +func (c *vCPU) getTSCFreq() (uintptr, error) { + rawFreq, _, errno := syscall.RawSyscall( + syscall.SYS_IOCTL, + uintptr(c.fd), + _KVM_GET_TSC_KHZ, + 0 /* ignored */) + if errno != 0 { + return 0, errno } - target := uint64(^uint32(0)) - for done := 0; done < calibrateTries; { - start := uint64(time.Rdtsc()) - registers.entries[0].data = start + target - if _, _, errno := syscall.RawSyscall( - syscall.SYS_IOCTL, - uintptr(c.fd), - _KVM_SET_MSRS, - uintptr(unsafe.Pointer(®isters))); errno != 0 { - return fmt.Errorf("error setting system time: %v", errno) - } - // See if this is our new minimum call time. Note that this - // serves two functions: one, we make sure that we are - // accurately predicting the offset we need to set. Second, we - // don't want to do the final set on a slow call, which could - // produce a really bad result. So we only count attempts - // within +/- 6.25% of our minimum as an attempt. - end := uint64(time.Rdtsc()) - if end < start { - continue // Totally bogus. - } - half := (end - start) / 2 - if half < target { - target = half - } - if (half - target) < target/8 { - done++ - } + return rawFreq, nil +} + +// setTSCFreq sets the TSC frequency. +func (c *vCPU) setTSCFreq(freq uintptr) error { + if _, _, errno := syscall.RawSyscall( + syscall.SYS_IOCTL, + uintptr(c.fd), + _KVM_SET_TSC_KHZ, + freq /* khz */); errno != 0 { + return fmt.Errorf("error setting TSC frequency: %v", errno) } return nil } -// setSignalMask sets the vCPU signal mask. -// -// This must be called prior to running the vCPU. -func (c *vCPU) setSignalMask() error { - // The layout of this structure implies that it will not necessarily be - // the same layout chosen by the Go compiler. It gets fudged here. - var data struct { - length uint32 - mask1 uint32 - mask2 uint32 - _ uint32 +// setTSC sets the TSC value. +func (c *vCPU) setTSC(value uint64) error { + const _MSR_IA32_TSC = 0x00000010 + registers := modelControlRegisters{ + nmsrs: 1, } - data.length = 8 // Fixed sigset size. - data.mask1 = ^uint32(bounceSignalMask & 0xffffffff) - data.mask2 = ^uint32(bounceSignalMask >> 32) + registers.entries[0].index = _MSR_IA32_TSC + registers.entries[0].data = value if _, _, errno := syscall.RawSyscall( syscall.SYS_IOCTL, uintptr(c.fd), - _KVM_SET_SIGNAL_MASK, - uintptr(unsafe.Pointer(&data))); errno != 0 { - return fmt.Errorf("error setting signal mask: %v", errno) + _KVM_SET_MSRS, + uintptr(unsafe.Pointer(®isters))); errno != 0 { + return fmt.Errorf("error setting tsc: %v", errno) } return nil } // setUserRegisters sets user registers in the vCPU. -func (c *vCPU) setUserRegisters(uregs *userRegs) error { +// +//go:nosplit +func (c *vCPU) setUserRegisters(uregs *userRegs) syscall.Errno { if _, _, errno := syscall.RawSyscall( syscall.SYS_IOCTL, uintptr(c.fd), _KVM_SET_REGS, uintptr(unsafe.Pointer(uregs))); errno != 0 { - return fmt.Errorf("error setting user registers: %v", errno) + return errno } - return nil + return 0 } // getUserRegisters reloads user registers in the vCPU. @@ -175,3 +146,17 @@ func (c *vCPU) setSystemRegisters(sregs *systemRegs) error { } return nil } + +// getSystemRegisters sets system registers. +// +//go:nosplit +func (c *vCPU) getSystemRegisters(sregs *systemRegs) syscall.Errno { + if _, _, errno := syscall.RawSyscall( + syscall.SYS_IOCTL, + uintptr(c.fd), + _KVM_GET_SREGS, + uintptr(unsafe.Pointer(sregs))); errno != 0 { + return errno + } + return 0 +} diff --git a/pkg/sentry/platform/kvm/machine_arm64.go b/pkg/sentry/platform/kvm/machine_arm64.go index 1d247f0dd..54837f20c 100644 --- a/pkg/sentry/platform/kvm/machine_arm64.go +++ b/pkg/sentry/platform/kvm/machine_arm64.go @@ -19,6 +19,7 @@ package kvm import ( "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/platform" + "gvisor.dev/gvisor/pkg/sentry/platform/ring0" "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" "gvisor.dev/gvisor/pkg/usermem" ) @@ -48,6 +49,18 @@ const ( poolPCIDs = 8 ) +func (m *machine) mapUpperHalf(pageTable *pagetables.PageTables) { + applyPhysicalRegions(func(pr physicalRegion) bool { + pageTable.Map( + usermem.Addr(ring0.KernelStartAddress|pr.virtual), + pr.length, + pagetables.MapOpts{AccessType: usermem.AnyAccess}, + pr.physical) + + return true // Keep iterating. + }) +} + // Get all read-only physicalRegions. func rdonlyRegionsForSetMem() (phyRegions []physicalRegion) { var rdonlyRegions []region diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go index 537419657..a163f956d 100644 --- a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go +++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go @@ -191,42 +191,6 @@ func (c *vCPU) getOneRegister(reg *kvmOneReg) error { return nil } -// setCPUID sets the CPUID to be used by the guest. -func (c *vCPU) setCPUID() error { - return nil -} - -// setSystemTime sets the TSC for the vCPU. -func (c *vCPU) setSystemTime() error { - return nil -} - -// setSignalMask sets the vCPU signal mask. -// -// This must be called prior to running the vCPU. -func (c *vCPU) setSignalMask() error { - // The layout of this structure implies that it will not necessarily be - // the same layout chosen by the Go compiler. It gets fudged here. - var data struct { - length uint32 - mask1 uint32 - mask2 uint32 - _ uint32 - } - data.length = 8 // Fixed sigset size. - data.mask1 = ^uint32(bounceSignalMask & 0xffffffff) - data.mask2 = ^uint32(bounceSignalMask >> 32) - if _, _, errno := syscall.RawSyscall( - syscall.SYS_IOCTL, - uintptr(c.fd), - _KVM_SET_SIGNAL_MASK, - uintptr(unsafe.Pointer(&data))); errno != 0 { - return fmt.Errorf("error setting signal mask: %v", errno) - } - - return nil -} - // SwitchToUser unpacks architectural-details. func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo) (usermem.AccessType, error) { // Check for canonical addresses. diff --git a/pkg/sentry/platform/kvm/machine_unsafe.go b/pkg/sentry/platform/kvm/machine_unsafe.go index 607c82156..1d6ca245a 100644 --- a/pkg/sentry/platform/kvm/machine_unsafe.go +++ b/pkg/sentry/platform/kvm/machine_unsafe.go @@ -143,3 +143,29 @@ func (c *vCPU) waitUntilNot(state uint32) { panic("futex wait error") } } + +// setSignalMask sets the vCPU signal mask. +// +// This must be called prior to running the vCPU. +func (c *vCPU) setSignalMask() error { + // The layout of this structure implies that it will not necessarily be + // the same layout chosen by the Go compiler. It gets fudged here. + var data struct { + length uint32 + mask1 uint32 + mask2 uint32 + _ uint32 + } + data.length = 8 // Fixed sigset size. + data.mask1 = ^uint32(bounceSignalMask & 0xffffffff) + data.mask2 = ^uint32(bounceSignalMask >> 32) + if _, _, errno := syscall.RawSyscall( + syscall.SYS_IOCTL, + uintptr(c.fd), + _KVM_SET_SIGNAL_MASK, + uintptr(unsafe.Pointer(&data))); errno != 0 { + return fmt.Errorf("error setting signal mask: %v", errno) + } + + return nil +} diff --git a/pkg/sentry/platform/platform.go b/pkg/sentry/platform/platform.go index 530e779b0..dcfe839a7 100644 --- a/pkg/sentry/platform/platform.go +++ b/pkg/sentry/platform/platform.go @@ -25,6 +25,7 @@ import ( "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/seccomp" "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/hostmm" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/usermem" ) @@ -52,6 +53,10 @@ type Platform interface { // can reliably return ErrContextCPUPreempted. DetectsCPUPreemption() bool + // HaveGlobalMemoryBarrier returns true if the GlobalMemoryBarrier method + // is supported. + HaveGlobalMemoryBarrier() bool + // MapUnit returns the alignment used for optional mappings into this // platform's AddressSpaces. Higher values indicate lower per-page costs // for AddressSpace.MapFile. As a special case, a MapUnit of 0 indicates @@ -97,6 +102,15 @@ type Platform interface { // called. PreemptAllCPUs() error + // GlobalMemoryBarrier blocks until all threads running application code + // (via Context.Switch) and all task goroutines "have passed through a + // state where all memory accesses to user-space addresses match program + // order between entry to and return from [GlobalMemoryBarrier]", as for + // membarrier(2). + // + // Preconditions: HaveGlobalMemoryBarrier() == true. + GlobalMemoryBarrier() error + // SyscallFilters returns syscalls made exclusively by this platform. SyscallFilters() seccomp.SyscallRules } @@ -115,6 +129,43 @@ func (NoCPUPreemptionDetection) PreemptAllCPUs() error { panic("This platform does not support CPU preemption detection") } +// UseHostGlobalMemoryBarrier implements Platform.HaveGlobalMemoryBarrier and +// Platform.GlobalMemoryBarrier by invoking equivalent functionality on the +// host. +type UseHostGlobalMemoryBarrier struct{} + +// HaveGlobalMemoryBarrier implements Platform.HaveGlobalMemoryBarrier. +func (UseHostGlobalMemoryBarrier) HaveGlobalMemoryBarrier() bool { + return hostmm.HaveGlobalMemoryBarrier() +} + +// GlobalMemoryBarrier implements Platform.GlobalMemoryBarrier. +func (UseHostGlobalMemoryBarrier) GlobalMemoryBarrier() error { + return hostmm.GlobalMemoryBarrier() +} + +// UseHostProcessMemoryBarrier implements Platform.HaveGlobalMemoryBarrier and +// Platform.GlobalMemoryBarrier by invoking a process-local memory barrier. +// This is faster than UseHostGlobalMemoryBarrier, but is only appropriate for +// platforms for which application code executes while using the sentry's +// mm_struct. +type UseHostProcessMemoryBarrier struct{} + +// HaveGlobalMemoryBarrier implements Platform.HaveGlobalMemoryBarrier. +func (UseHostProcessMemoryBarrier) HaveGlobalMemoryBarrier() bool { + // Fall back to a global memory barrier if a process-local one isn't + // available. + return hostmm.HaveProcessMemoryBarrier() || hostmm.HaveGlobalMemoryBarrier() +} + +// GlobalMemoryBarrier implements Platform.GlobalMemoryBarrier. +func (UseHostProcessMemoryBarrier) GlobalMemoryBarrier() error { + if hostmm.HaveProcessMemoryBarrier() { + return hostmm.ProcessMemoryBarrier() + } + return hostmm.GlobalMemoryBarrier() +} + // MemoryManager represents an abstraction above the platform address space // which manages memory mappings and their contents. type MemoryManager interface { diff --git a/pkg/sentry/platform/ptrace/ptrace.go b/pkg/sentry/platform/ptrace/ptrace.go index b52d0fbd8..f56aa3b79 100644 --- a/pkg/sentry/platform/ptrace/ptrace.go +++ b/pkg/sentry/platform/ptrace/ptrace.go @@ -192,6 +192,7 @@ func (c *context) PullFullState(as platform.AddressSpace, ac arch.Context) {} type PTrace struct { platform.MMapMinAddr platform.NoCPUPreemptionDetection + platform.UseHostGlobalMemoryBarrier } // New returns a new ptrace-based implementation of the platform interface. diff --git a/pkg/sentry/platform/ring0/defs_amd64.go b/pkg/sentry/platform/ring0/defs_amd64.go index 9c6c2cf5c..00899273e 100644 --- a/pkg/sentry/platform/ring0/defs_amd64.go +++ b/pkg/sentry/platform/ring0/defs_amd64.go @@ -76,15 +76,41 @@ type KernelOpts struct { type KernelArchState struct { KernelOpts + // cpuEntries is array of kernelEntry for all cpus + cpuEntries []kernelEntry + // globalIDT is our set of interrupt gates. - globalIDT idt64 + globalIDT *idt64 } -// CPUArchState contains CPU-specific arch state. -type CPUArchState struct { +// kernelEntry contains minimal CPU-specific arch state +// that can be mapped at the upper of the address space. +// Malicious APP might steal info from it via CPU bugs. +type kernelEntry struct { // stack is the stack used for interrupts on this CPU. stack [256]byte + // scratch space for temporary usage. + scratch0 uint64 + + // stackTop is the top of the stack. + stackTop uint64 + + // cpuSelf is back reference to CPU. + cpuSelf *CPU + + // kernelCR3 is the cr3 used for sentry kernel. + kernelCR3 uintptr + + // gdt is the CPU's descriptor table. + gdt descriptorTable + + // tss is the CPU's task state. + tss TaskState64 +} + +// CPUArchState contains CPU-specific arch state. +type CPUArchState struct { // errorCode is the error code from the last exception. errorCode uintptr @@ -97,11 +123,7 @@ type CPUArchState struct { // exception. errorType uintptr - // gdt is the CPU's descriptor table. - gdt descriptorTable - - // tss is the CPU's task state. - tss TaskState64 + *kernelEntry } // ErrorCode returns the last error code. diff --git a/pkg/sentry/platform/ring0/entry_amd64.go b/pkg/sentry/platform/ring0/entry_amd64.go index 7fa43c2f5..d87b1fd00 100644 --- a/pkg/sentry/platform/ring0/entry_amd64.go +++ b/pkg/sentry/platform/ring0/entry_amd64.go @@ -36,12 +36,15 @@ func sysenter() // This must be called prior to sysret/iret. func swapgs() +// jumpToKernel jumps to the kernel version of the current RIP. +func jumpToKernel() + // sysret returns to userspace from a system call. // // The return code is the vector that interrupted execution. // // See stubs.go for a note regarding the frame size of this function. -func sysret(*CPU, *arch.Registers) Vector +func sysret(cpu *CPU, regs *arch.Registers, userCR3 uintptr) Vector // "iret is the cadillac of CPL switching." // @@ -50,7 +53,7 @@ func sysret(*CPU, *arch.Registers) Vector // iret is nearly identical to sysret, except an iret is used to fully restore // all user state. This must be called in cases where all registers need to be // restored. -func iret(*CPU, *arch.Registers) Vector +func iret(cpu *CPU, regs *arch.Registers, userCR3 uintptr) Vector // exception is the generic exception entry. // diff --git a/pkg/sentry/platform/ring0/entry_amd64.s b/pkg/sentry/platform/ring0/entry_amd64.s index 02df38331..f59747df3 100644 --- a/pkg/sentry/platform/ring0/entry_amd64.s +++ b/pkg/sentry/platform/ring0/entry_amd64.s @@ -63,6 +63,15 @@ MOVQ offset+PTRACE_RSI(reg), SI; \ MOVQ offset+PTRACE_RDI(reg), DI; +// WRITE_CR3() writes the given CR3 value. +// +// The code corresponds to: +// +// mov %rax, %cr3 +// +#define WRITE_CR3() \ + BYTE $0x0f; BYTE $0x22; BYTE $0xd8; + // SWAP_GS swaps the kernel GS (CPU). #define SWAP_GS() \ BYTE $0x0F; BYTE $0x01; BYTE $0xf8; @@ -75,15 +84,9 @@ #define SYSRET64() \ BYTE $0x48; BYTE $0x0f; BYTE $0x07; -// LOAD_KERNEL_ADDRESS loads a kernel address. -#define LOAD_KERNEL_ADDRESS(from, to) \ - MOVQ from, to; \ - ORQ ·KernelStartAddress(SB), to; - // LOAD_KERNEL_STACK loads the kernel stack. -#define LOAD_KERNEL_STACK(from) \ - LOAD_KERNEL_ADDRESS(CPU_SELF(from), SP); \ - LEAQ CPU_STACK_TOP(SP), SP; +#define LOAD_KERNEL_STACK(entry) \ + MOVQ ENTRY_STACK_TOP(entry), SP; // See kernel.go. TEXT ·Halt(SB),NOSPLIT,$0 @@ -95,58 +98,93 @@ TEXT ·swapgs(SB),NOSPLIT,$0 SWAP_GS() RET +// jumpToKernel changes execution to the kernel address space. +// +// This works by changing the return value to the kernel version. +TEXT ·jumpToKernel(SB),NOSPLIT,$0 + MOVQ 0(SP), AX + ORQ ·KernelStartAddress(SB), AX // Future return value. + MOVQ AX, 0(SP) + RET + // See entry_amd64.go. TEXT ·sysret(SB),NOSPLIT,$0-24 - // Save original state. - LOAD_KERNEL_ADDRESS(cpu+0(FP), BX) - LOAD_KERNEL_ADDRESS(regs+8(FP), AX) + CALL ·jumpToKernel(SB) + // Save original state and stack. sysenter() or exception() + // from APP(gr3) will switch to this stack, set the return + // value (vector: 32(SP)) and then do RET, which will also + // automatically return to the lower half. + MOVQ cpu+0(FP), BX + MOVQ regs+8(FP), AX + MOVQ userCR3+16(FP), CX MOVQ SP, CPU_REGISTERS+PTRACE_RSP(BX) MOVQ BP, CPU_REGISTERS+PTRACE_RBP(BX) MOVQ AX, CPU_REGISTERS+PTRACE_RAX(BX) + // save SP AX userCR3 on the kernel stack. + MOVQ CPU_ENTRY(BX), BX + LOAD_KERNEL_STACK(BX) + PUSHQ PTRACE_RSP(AX) + PUSHQ PTRACE_RAX(AX) + PUSHQ CX + // Restore user register state. REGISTERS_LOAD(AX, 0) MOVQ PTRACE_RIP(AX), CX // Needed for SYSRET. MOVQ PTRACE_FLAGS(AX), R11 // Needed for SYSRET. - MOVQ PTRACE_RSP(AX), SP // Restore the stack directly. - MOVQ PTRACE_RAX(AX), AX // Restore AX (scratch). + + // restore userCR3, AX, SP. + POPQ AX // Get userCR3. + WRITE_CR3() // Switch to userCR3. + POPQ AX // Restore AX. + POPQ SP // Restore SP. SYSRET64() // See entry_amd64.go. TEXT ·iret(SB),NOSPLIT,$0-24 - // Save original state. - LOAD_KERNEL_ADDRESS(cpu+0(FP), BX) - LOAD_KERNEL_ADDRESS(regs+8(FP), AX) + CALL ·jumpToKernel(SB) + // Save original state and stack. sysenter() or exception() + // from APP(gr3) will switch to this stack, set the return + // value (vector: 32(SP)) and then do RET, which will also + // automatically return to the lower half. + MOVQ cpu+0(FP), BX + MOVQ regs+8(FP), AX + MOVQ userCR3+16(FP), CX MOVQ SP, CPU_REGISTERS+PTRACE_RSP(BX) MOVQ BP, CPU_REGISTERS+PTRACE_RBP(BX) MOVQ AX, CPU_REGISTERS+PTRACE_RAX(BX) // Build an IRET frame & restore state. + MOVQ CPU_ENTRY(BX), BX LOAD_KERNEL_STACK(BX) - MOVQ PTRACE_SS(AX), BX; PUSHQ BX - MOVQ PTRACE_RSP(AX), CX; PUSHQ CX - MOVQ PTRACE_FLAGS(AX), DX; PUSHQ DX - MOVQ PTRACE_CS(AX), DI; PUSHQ DI - MOVQ PTRACE_RIP(AX), SI; PUSHQ SI - REGISTERS_LOAD(AX, 0) // Restore most registers. - MOVQ PTRACE_RAX(AX), AX // Restore AX (scratch). + PUSHQ PTRACE_SS(AX) + PUSHQ PTRACE_RSP(AX) + PUSHQ PTRACE_FLAGS(AX) + PUSHQ PTRACE_CS(AX) + PUSHQ PTRACE_RIP(AX) + PUSHQ PTRACE_RAX(AX) // Save AX on kernel stack. + PUSHQ CX // Save userCR3 on kernel stack. + REGISTERS_LOAD(AX, 0) // Restore most registers. + POPQ AX // Get userCR3. + WRITE_CR3() // Switch to userCR3. + POPQ AX // Restore AX. IRET() // See entry_amd64.go. TEXT ·resume(SB),NOSPLIT,$0 // See iret, above. - MOVQ CPU_REGISTERS+PTRACE_SS(GS), BX; PUSHQ BX - MOVQ CPU_REGISTERS+PTRACE_RSP(GS), CX; PUSHQ CX - MOVQ CPU_REGISTERS+PTRACE_FLAGS(GS), DX; PUSHQ DX - MOVQ CPU_REGISTERS+PTRACE_CS(GS), DI; PUSHQ DI - MOVQ CPU_REGISTERS+PTRACE_RIP(GS), SI; PUSHQ SI - REGISTERS_LOAD(GS, CPU_REGISTERS) - MOVQ CPU_REGISTERS+PTRACE_RAX(GS), AX + MOVQ ENTRY_CPU_SELF(GS), AX // Load vCPU. + PUSHQ CPU_REGISTERS+PTRACE_SS(AX) + PUSHQ CPU_REGISTERS+PTRACE_RSP(AX) + PUSHQ CPU_REGISTERS+PTRACE_FLAGS(AX) + PUSHQ CPU_REGISTERS+PTRACE_CS(AX) + PUSHQ CPU_REGISTERS+PTRACE_RIP(AX) + REGISTERS_LOAD(AX, CPU_REGISTERS) + MOVQ CPU_REGISTERS+PTRACE_RAX(AX), AX IRET() // See entry_amd64.go. TEXT ·Start(SB),NOSPLIT,$0 - LOAD_KERNEL_STACK(AX) // Set the stack. PUSHQ $0x0 // Previous frame pointer. MOVQ SP, BP // Set frame pointer. PUSHQ AX // First argument (CPU). @@ -155,53 +193,60 @@ TEXT ·Start(SB),NOSPLIT,$0 // See entry_amd64.go. TEXT ·sysenter(SB),NOSPLIT,$0 - // Interrupts are always disabled while we're executing in kernel mode - // and always enabled while executing in user mode. Therefore, we can - // reliably look at the flags in R11 to determine where this syscall - // was from. - TESTL $_RFLAGS_IF, R11 + // _RFLAGS_IOPL0 is always set in the user mode and it is never set in + // the kernel mode. See the comment of UserFlagsSet for more details. + TESTL $_RFLAGS_IOPL0, R11 JZ kernel - user: SWAP_GS() - XCHGQ CPU_REGISTERS+PTRACE_RSP(GS), SP // Swap stacks. - XCHGQ CPU_REGISTERS+PTRACE_RAX(GS), AX // Swap for AX (regs). + MOVQ AX, ENTRY_SCRATCH0(GS) // Save user AX on scratch. + MOVQ ENTRY_KERNEL_CR3(GS), AX // Get kernel cr3 on AX. + WRITE_CR3() // Switch to kernel cr3. + + MOVQ ENTRY_CPU_SELF(GS), AX // Load vCPU. + MOVQ CPU_REGISTERS+PTRACE_RAX(AX), AX // Get user regs. REGISTERS_SAVE(AX, 0) // Save all except IP, FLAGS, SP, AX. - MOVQ CPU_REGISTERS+PTRACE_RAX(GS), BX // Load saved AX value. - MOVQ BX, PTRACE_RAX(AX) // Save everything else. - MOVQ BX, PTRACE_ORIGRAX(AX) MOVQ CX, PTRACE_RIP(AX) MOVQ R11, PTRACE_FLAGS(AX) - MOVQ CPU_REGISTERS+PTRACE_RSP(GS), BX; MOVQ BX, PTRACE_RSP(AX) - MOVQ $0, CPU_ERROR_CODE(GS) // Clear error code. - MOVQ $1, CPU_ERROR_TYPE(GS) // Set error type to user. + MOVQ SP, PTRACE_RSP(AX) + MOVQ ENTRY_SCRATCH0(GS), CX // Load saved user AX value. + MOVQ CX, PTRACE_RAX(AX) // Save everything else. + MOVQ CX, PTRACE_ORIGRAX(AX) + + MOVQ ENTRY_CPU_SELF(GS), AX // Load vCPU. + MOVQ CPU_REGISTERS+PTRACE_RSP(AX), SP // Get stacks. + MOVQ $0, CPU_ERROR_CODE(AX) // Clear error code. + MOVQ $1, CPU_ERROR_TYPE(AX) // Set error type to user. // Return to the kernel, where the frame is: // - // vector (sp+24) + // vector (sp+32) + // userCR3 (sp+24) // regs (sp+16) // cpu (sp+8) // vcpu.Switch (sp+0) // - MOVQ CPU_REGISTERS+PTRACE_RBP(GS), BP // Original base pointer. - MOVQ $Syscall, 24(SP) // Output vector. + MOVQ CPU_REGISTERS+PTRACE_RBP(AX), BP // Original base pointer. + MOVQ $Syscall, 32(SP) // Output vector. RET kernel: // We can't restore the original stack, but we can access the registers // in the CPU state directly. No need for temporary juggling. - MOVQ AX, CPU_REGISTERS+PTRACE_ORIGRAX(GS) - MOVQ AX, CPU_REGISTERS+PTRACE_RAX(GS) - REGISTERS_SAVE(GS, CPU_REGISTERS) - MOVQ CX, CPU_REGISTERS+PTRACE_RIP(GS) - MOVQ R11, CPU_REGISTERS+PTRACE_FLAGS(GS) - MOVQ SP, CPU_REGISTERS+PTRACE_RSP(GS) - MOVQ $0, CPU_ERROR_CODE(GS) // Clear error code. - MOVQ $0, CPU_ERROR_TYPE(GS) // Set error type to kernel. + MOVQ AX, ENTRY_SCRATCH0(GS) + MOVQ ENTRY_CPU_SELF(GS), AX // Load vCPU. + REGISTERS_SAVE(AX, CPU_REGISTERS) + MOVQ CX, CPU_REGISTERS+PTRACE_RIP(AX) + MOVQ R11, CPU_REGISTERS+PTRACE_FLAGS(AX) + MOVQ SP, CPU_REGISTERS+PTRACE_RSP(AX) + MOVQ ENTRY_SCRATCH0(GS), BX + MOVQ BX, CPU_REGISTERS+PTRACE_ORIGRAX(AX) + MOVQ BX, CPU_REGISTERS+PTRACE_RAX(AX) + MOVQ $0, CPU_ERROR_CODE(AX) // Clear error code. + MOVQ $0, CPU_ERROR_TYPE(AX) // Set error type to kernel. // Call the syscall trampoline. LOAD_KERNEL_STACK(GS) - MOVQ CPU_SELF(GS), AX // Load vCPU. PUSHQ AX // First argument (vCPU). CALL ·kernelSyscall(SB) // Call the trampoline. POPQ AX // Pop vCPU. @@ -230,16 +275,21 @@ TEXT ·exception(SB),NOSPLIT,$0 // ERROR_CODE (sp+8) // VECTOR (sp+0) // - TESTL $_RFLAGS_IF, 32(SP) + TESTL $_RFLAGS_IOPL0, 32(SP) JZ kernel user: SWAP_GS() ADDQ $-8, SP // Adjust for flags. MOVQ $_KERNEL_FLAGS, 0(SP); BYTE $0x9d; // Reset flags (POPFQ). - XCHGQ CPU_REGISTERS+PTRACE_RAX(GS), AX // Swap for user regs. + PUSHQ AX // Save user AX on stack. + MOVQ ENTRY_KERNEL_CR3(GS), AX // Get kernel cr3 on AX. + WRITE_CR3() // Switch to kernel cr3. + + MOVQ ENTRY_CPU_SELF(GS), AX // Load vCPU. + MOVQ CPU_REGISTERS+PTRACE_RAX(AX), AX // Get user regs. REGISTERS_SAVE(AX, 0) // Save all except IP, FLAGS, SP, AX. - MOVQ CPU_REGISTERS+PTRACE_RAX(GS), BX // Restore original AX. + POPQ BX // Restore original AX. MOVQ BX, PTRACE_RAX(AX) // Save it. MOVQ BX, PTRACE_ORIGRAX(AX) MOVQ 16(SP), BX; MOVQ BX, PTRACE_RIP(AX) @@ -249,34 +299,36 @@ user: MOVQ 48(SP), SI; MOVQ SI, PTRACE_SS(AX) // Copy out and return. + MOVQ ENTRY_CPU_SELF(GS), AX // Load vCPU. MOVQ 0(SP), BX // Load vector. MOVQ 8(SP), CX // Load error code. - MOVQ CPU_REGISTERS+PTRACE_RSP(GS), SP // Original stack (kernel version). - MOVQ CPU_REGISTERS+PTRACE_RBP(GS), BP // Original base pointer. - MOVQ CX, CPU_ERROR_CODE(GS) // Set error code. - MOVQ $1, CPU_ERROR_TYPE(GS) // Set error type to user. - MOVQ BX, 24(SP) // Output vector. + MOVQ CPU_REGISTERS+PTRACE_RSP(AX), SP // Original stack (kernel version). + MOVQ CPU_REGISTERS+PTRACE_RBP(AX), BP // Original base pointer. + MOVQ CX, CPU_ERROR_CODE(AX) // Set error code. + MOVQ $1, CPU_ERROR_TYPE(AX) // Set error type to user. + MOVQ BX, 32(SP) // Output vector. RET kernel: // As per above, we can save directly. - MOVQ AX, CPU_REGISTERS+PTRACE_RAX(GS) - MOVQ AX, CPU_REGISTERS+PTRACE_ORIGRAX(GS) - REGISTERS_SAVE(GS, CPU_REGISTERS) - MOVQ 16(SP), AX; MOVQ AX, CPU_REGISTERS+PTRACE_RIP(GS) - MOVQ 32(SP), BX; MOVQ BX, CPU_REGISTERS+PTRACE_FLAGS(GS) - MOVQ 40(SP), CX; MOVQ CX, CPU_REGISTERS+PTRACE_RSP(GS) + PUSHQ AX + MOVQ ENTRY_CPU_SELF(GS), AX // Load vCPU. + REGISTERS_SAVE(AX, CPU_REGISTERS) + POPQ BX + MOVQ BX, CPU_REGISTERS+PTRACE_RAX(AX) + MOVQ BX, CPU_REGISTERS+PTRACE_ORIGRAX(AX) + MOVQ 16(SP), BX; MOVQ BX, CPU_REGISTERS+PTRACE_RIP(AX) + MOVQ 32(SP), BX; MOVQ BX, CPU_REGISTERS+PTRACE_FLAGS(AX) + MOVQ 40(SP), BX; MOVQ BX, CPU_REGISTERS+PTRACE_RSP(AX) // Set the error code and adjust the stack. - MOVQ 8(SP), AX // Load the error code. - MOVQ AX, CPU_ERROR_CODE(GS) // Copy out to the CPU. - MOVQ $0, CPU_ERROR_TYPE(GS) // Set error type to kernel. + MOVQ 8(SP), BX // Load the error code. + MOVQ BX, CPU_ERROR_CODE(AX) // Copy out to the CPU. + MOVQ $0, CPU_ERROR_TYPE(AX) // Set error type to kernel. MOVQ 0(SP), BX // BX contains the vector. - ADDQ $48, SP // Drop the exception frame. // Call the exception trampoline. LOAD_KERNEL_STACK(GS) - MOVQ CPU_SELF(GS), AX // Load vCPU. PUSHQ BX // Second argument (vector). PUSHQ AX // First argument (vCPU). CALL ·kernelException(SB) // Call the trampoline. diff --git a/pkg/sentry/platform/ring0/entry_arm64.s b/pkg/sentry/platform/ring0/entry_arm64.s index 5f63cbd45..494baaa4d 100644 --- a/pkg/sentry/platform/ring0/entry_arm64.s +++ b/pkg/sentry/platform/ring0/entry_arm64.s @@ -47,8 +47,9 @@ #define SCTLR_C 1 << 2 #define SCTLR_I 1 << 12 #define SCTLR_UCT 1 << 15 +#define SCTLR_UCI 1 << 26 -#define SCTLR_EL1_DEFAULT (SCTLR_M | SCTLR_C | SCTLR_I | SCTLR_UCT) +#define SCTLR_EL1_DEFAULT (SCTLR_M | SCTLR_C | SCTLR_I | SCTLR_UCT | SCTLR_UCI) // cntkctl_el1: counter-timer kernel control register el1. #define CNTKCTL_EL0PCTEN 1 << 0 @@ -342,6 +343,8 @@ ADD $16, RSP, RSP; \ MOVD RSV_REG, PTRACE_R18(R20); \ MOVD RSV_REG_APP, PTRACE_R9(R20); \ + MRS TPIDR_EL0, R3; \ + MOVD R3, PTRACE_TLS(R20); \ WORD $0xd5384003; \ // MRS SPSR_EL1, R3 MOVD R3, PTRACE_PSTATE(R20); \ MRS ELR_EL1, R3; \ @@ -354,6 +357,8 @@ WORD $0xd538d092; \ //MRS TPIDR_EL1, R18 REGISTERS_SAVE(RSV_REG, CPU_REGISTERS); \ // Save sentry context. MOVD RSV_REG_APP, CPU_REGISTERS+PTRACE_R9(RSV_REG); \ + MRS TPIDR_EL0, R4; \ + MOVD R4, CPU_REGISTERS+PTRACE_TLS(RSV_REG); \ WORD $0xd5384004; \ // MRS SPSR_EL1, R4 MOVD R4, CPU_REGISTERS+PTRACE_PSTATE(RSV_REG); \ MRS ELR_EL1, R4; \ @@ -435,6 +440,8 @@ TEXT ·kernelExitToEl0(SB),NOSPLIT,$0 MRS TPIDR_EL1, RSV_REG REGISTERS_SAVE(RSV_REG, CPU_REGISTERS) MOVD RSV_REG_APP, CPU_REGISTERS+PTRACE_R9(RSV_REG) + MRS TPIDR_EL0, R3 + MOVD R3, CPU_REGISTERS+PTRACE_TLS(RSV_REG) WORD $0xd5384003 // MRS SPSR_EL1, R3 MOVD R3, CPU_REGISTERS+PTRACE_PSTATE(RSV_REG) @@ -461,8 +468,18 @@ TEXT ·kernelExitToEl0(SB),NOSPLIT,$0 MOVD PTRACE_PSTATE(RSV_REG_APP), R1 WORD $0xd5184001 //MSR R1, SPSR_EL1 + // need use kernel space address to excute below code, since + // after SWITCH_TO_APP_PAGETABLE the ASID is changed to app's + // ASID. + WORD $0x10000061 // ADR R1, do_exit_to_el0 + ORR $0xffff000000000000, R1, R1 + JMP (R1) + +do_exit_to_el0: // RSV_REG & RSV_REG_APP will be loaded at the end. REGISTERS_LOAD(RSV_REG_APP, 0) + MOVD PTRACE_TLS(RSV_REG_APP), RSV_REG + MSR RSV_REG, TPIDR_EL0 // switch to user pagetable. MOVD PTRACE_R18(RSV_REG_APP), RSV_REG diff --git a/pkg/sentry/platform/ring0/gen_offsets/BUILD b/pkg/sentry/platform/ring0/gen_offsets/BUILD index 549f3d228..9742308d8 100644 --- a/pkg/sentry/platform/ring0/gen_offsets/BUILD +++ b/pkg/sentry/platform/ring0/gen_offsets/BUILD @@ -24,7 +24,10 @@ go_binary( "defs_impl_arm64.go", "main.go", ], - visibility = ["//pkg/sentry/platform/ring0:__pkg__"], + visibility = [ + "//pkg/sentry/platform/kvm:__pkg__", + "//pkg/sentry/platform/ring0:__pkg__", + ], deps = [ "//pkg/cpuid", "//pkg/sentry/arch", diff --git a/pkg/sentry/platform/ring0/kernel.go b/pkg/sentry/platform/ring0/kernel.go index 021693791..264be23d3 100644 --- a/pkg/sentry/platform/ring0/kernel.go +++ b/pkg/sentry/platform/ring0/kernel.go @@ -19,8 +19,8 @@ package ring0 // N.B. that constraints on KernelOpts must be satisfied. // //go:nosplit -func (k *Kernel) Init(opts KernelOpts) { - k.init(opts) +func (k *Kernel) Init(opts KernelOpts, maxCPUs int) { + k.init(opts, maxCPUs) } // Halt halts execution. @@ -49,6 +49,11 @@ func (defaultHooks) KernelException(Vector) { // kernelSyscall is a trampoline. // +// When in amd64, it is called with %rip on the upper half, so it can +// NOT access to any global data which is not mapped on upper and must +// call to function pointers or interfaces to switch to the lower half +// so that callee can access to global data. +// // +checkescape:hard,stack // //go:nosplit @@ -58,6 +63,11 @@ func kernelSyscall(c *CPU) { // kernelException is a trampoline. // +// When in amd64, it is called with %rip on the upper half, so it can +// NOT access to any global data which is not mapped on upper and must +// call to function pointers or interfaces to switch to the lower half +// so that callee can access to global data. +// // +checkescape:hard,stack // //go:nosplit @@ -68,10 +78,10 @@ func kernelException(c *CPU, vector Vector) { // Init initializes a new CPU. // // Init allows embedding in other objects. -func (c *CPU) Init(k *Kernel, hooks Hooks) { - c.self = c // Set self reference. - c.kernel = k // Set kernel reference. - c.init() // Perform architectural init. +func (c *CPU) Init(k *Kernel, cpuID int, hooks Hooks) { + c.self = c // Set self reference. + c.kernel = k // Set kernel reference. + c.init(cpuID) // Perform architectural init. // Require hooks. if hooks != nil { diff --git a/pkg/sentry/platform/ring0/kernel_amd64.go b/pkg/sentry/platform/ring0/kernel_amd64.go index d37981dbf..3a9dff4cc 100644 --- a/pkg/sentry/platform/ring0/kernel_amd64.go +++ b/pkg/sentry/platform/ring0/kernel_amd64.go @@ -18,13 +18,42 @@ package ring0 import ( "encoding/binary" + "reflect" + + "gvisor.dev/gvisor/pkg/usermem" ) // init initializes architecture-specific state. -func (k *Kernel) init(opts KernelOpts) { +func (k *Kernel) init(opts KernelOpts, maxCPUs int) { // Save the root page tables. k.PageTables = opts.PageTables + entrySize := reflect.TypeOf(kernelEntry{}).Size() + var ( + entries []kernelEntry + padding = 1 + ) + for { + entries = make([]kernelEntry, maxCPUs+padding-1) + totalSize := entrySize * uintptr(maxCPUs+padding-1) + addr := reflect.ValueOf(&entries[0]).Pointer() + if addr&(usermem.PageSize-1) == 0 && totalSize >= usermem.PageSize { + // The runtime forces power-of-2 alignment for allocations, and we are therefore + // safe once the first address is aligned and the chunk is at least a full page. + break + } + padding = padding << 1 + } + k.cpuEntries = entries + + k.globalIDT = &idt64{} + if reflect.TypeOf(idt64{}).Size() != usermem.PageSize { + panic("Size of globalIDT should be PageSize") + } + if reflect.ValueOf(k.globalIDT).Pointer()&(usermem.PageSize-1) != 0 { + panic("Allocated globalIDT should be page aligned") + } + // Setup the IDT, which is uniform. for v, handler := range handlers { // Allow Breakpoint and Overflow to be called from all @@ -39,8 +68,26 @@ func (k *Kernel) init(opts KernelOpts) { } } +func (k *Kernel) EntryRegions() map[uintptr]uintptr { + regions := make(map[uintptr]uintptr) + + addr := reflect.ValueOf(&k.cpuEntries[0]).Pointer() + size := reflect.TypeOf(kernelEntry{}).Size() * uintptr(len(k.cpuEntries)) + end, _ := usermem.Addr(addr + size).RoundUp() + regions[uintptr(usermem.Addr(addr).RoundDown())] = uintptr(end) + + addr = reflect.ValueOf(k.globalIDT).Pointer() + size = reflect.TypeOf(idt64{}).Size() + end, _ = usermem.Addr(addr + size).RoundUp() + regions[uintptr(usermem.Addr(addr).RoundDown())] = uintptr(end) + + return regions +} + // init initializes architecture-specific state. -func (c *CPU) init() { +func (c *CPU) init(cpuID int) { + c.kernelEntry = &c.kernel.cpuEntries[cpuID] + c.cpuSelf = c // Null segment. c.gdt[0].setNull() @@ -65,6 +112,7 @@ func (c *CPU) init() { // Set the kernel stack pointer in the TSS (virtual address). stackAddr := c.StackTop() + c.stackTop = stackAddr c.tss.rsp0Lo = uint32(stackAddr) c.tss.rsp0Hi = uint32(stackAddr >> 32) c.tss.ist1Lo = uint32(stackAddr) @@ -183,7 +231,7 @@ func IsCanonical(addr uint64) bool { //go:nosplit func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) { userCR3 := switchOpts.PageTables.CR3(!switchOpts.Flush, switchOpts.UserPCID) - kernelCR3 := c.kernel.PageTables.CR3(true, switchOpts.KernelPCID) + c.kernelCR3 = uintptr(c.kernel.PageTables.CR3(true, switchOpts.KernelPCID)) // Sanitize registers. regs := switchOpts.Registers @@ -197,15 +245,11 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) { WriteFS(uintptr(regs.Fs_base)) // escapes: no. Set application FS. WriteGS(uintptr(regs.Gs_base)) // escapes: no. Set application GS. LoadFloatingPoint(switchOpts.FloatingPointState) // escapes: no. Copy in floating point. - jumpToKernel() // Switch to upper half. - writeCR3(uintptr(userCR3)) // Change to user address space. if switchOpts.FullRestore { - vector = iret(c, regs) + vector = iret(c, regs, uintptr(userCR3)) } else { - vector = sysret(c, regs) + vector = sysret(c, regs, uintptr(userCR3)) } - writeCR3(uintptr(kernelCR3)) // Return to kernel address space. - jumpToUser() // Return to lower half. SaveFloatingPoint(switchOpts.FloatingPointState) // escapes: no. Copy out floating point. WriteFS(uintptr(c.registers.Fs_base)) // escapes: no. Restore kernel FS. return @@ -219,7 +263,7 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) { //go:nosplit func start(c *CPU) { // Save per-cpu & FS segment. - WriteGS(kernelAddr(c)) + WriteGS(kernelAddr(c.kernelEntry)) WriteFS(uintptr(c.registers.Fs_base)) // Initialize floating point. diff --git a/pkg/sentry/platform/ring0/kernel_arm64.go b/pkg/sentry/platform/ring0/kernel_arm64.go index 14774c5db..b294ccc7c 100644 --- a/pkg/sentry/platform/ring0/kernel_arm64.go +++ b/pkg/sentry/platform/ring0/kernel_arm64.go @@ -25,13 +25,13 @@ func HaltAndResume() func HaltEl1SvcAndResume() // init initializes architecture-specific state. -func (k *Kernel) init(opts KernelOpts) { +func (k *Kernel) init(opts KernelOpts, maxCPUs int) { // Save the root page tables. k.PageTables = opts.PageTables } // init initializes architecture-specific state. -func (c *CPU) init() { +func (c *CPU) init(cpuID int) { // Set the kernel stack pointer(virtual address). c.registers.Sp = uint64(c.StackTop()) @@ -64,11 +64,9 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) { regs.Pstate |= UserFlagsSet LoadFloatingPoint(switchOpts.FloatingPointState) - SetTLS(regs.TPIDR_EL0) kernelExitToEl0() - regs.TPIDR_EL0 = GetTLS() SaveFloatingPoint(switchOpts.FloatingPointState) vector = c.vecCode diff --git a/pkg/sentry/platform/ring0/lib_amd64.go b/pkg/sentry/platform/ring0/lib_amd64.go index ca968a036..0ec5c3bc5 100644 --- a/pkg/sentry/platform/ring0/lib_amd64.go +++ b/pkg/sentry/platform/ring0/lib_amd64.go @@ -61,21 +61,9 @@ func wrgsbase(addr uintptr) // wrgsmsr writes to the GS_BASE MSR. func wrgsmsr(addr uintptr) -// writeCR3 writes the CR3 value. -func writeCR3(phys uintptr) - -// readCR3 reads the current CR3 value. -func readCR3() uintptr - // readCR2 reads the current CR2 value. func readCR2() uintptr -// jumpToKernel jumps to the kernel version of the current RIP. -func jumpToKernel() - -// jumpToUser jumps to the user version of the current RIP. -func jumpToUser() - // fninit initializes the floating point unit. func fninit() diff --git a/pkg/sentry/platform/ring0/lib_amd64.s b/pkg/sentry/platform/ring0/lib_amd64.s index 75d742750..2fe83568a 100644 --- a/pkg/sentry/platform/ring0/lib_amd64.s +++ b/pkg/sentry/platform/ring0/lib_amd64.s @@ -127,53 +127,6 @@ TEXT ·wrgsmsr(SB),NOSPLIT,$0-8 BYTE $0x0f; BYTE $0x30; // WRMSR RET -// jumpToUser changes execution to the user address. -// -// This works by changing the return value to the user version. -TEXT ·jumpToUser(SB),NOSPLIT,$0 - MOVQ 0(SP), AX - MOVQ ·KernelStartAddress(SB), BX - NOTQ BX - ANDQ BX, SP // Switch the stack. - ANDQ BX, BP // Switch the frame pointer. - ANDQ BX, AX // Future return value. - MOVQ AX, 0(SP) - RET - -// jumpToKernel changes execution to the kernel address space. -// -// This works by changing the return value to the kernel version. -TEXT ·jumpToKernel(SB),NOSPLIT,$0 - MOVQ 0(SP), AX - MOVQ ·KernelStartAddress(SB), BX - ORQ BX, SP // Switch the stack. - ORQ BX, BP // Switch the frame pointer. - ORQ BX, AX // Future return value. - MOVQ AX, 0(SP) - RET - -// writeCR3 writes the given CR3 value. -// -// The code corresponds to: -// -// mov %rax, %cr3 -// -TEXT ·writeCR3(SB),NOSPLIT,$0-8 - MOVQ cr3+0(FP), AX - BYTE $0x0f; BYTE $0x22; BYTE $0xd8; - RET - -// readCR3 reads the current CR3 value. -// -// The code corresponds to: -// -// mov %cr3, %rax -// -TEXT ·readCR3(SB),NOSPLIT,$0-8 - BYTE $0x0f; BYTE $0x20; BYTE $0xd8; - MOVQ AX, ret+0(FP) - RET - // readCR2 reads the current CR2 value. // // The code corresponds to: diff --git a/pkg/sentry/platform/ring0/offsets_amd64.go b/pkg/sentry/platform/ring0/offsets_amd64.go index b8ab120a0..ca4075b09 100644 --- a/pkg/sentry/platform/ring0/offsets_amd64.go +++ b/pkg/sentry/platform/ring0/offsets_amd64.go @@ -30,14 +30,21 @@ func Emit(w io.Writer) { c := &CPU{} fmt.Fprintf(w, "\n// CPU offsets.\n") - fmt.Fprintf(w, "#define CPU_SELF 0x%02x\n", reflect.ValueOf(&c.self).Pointer()-reflect.ValueOf(c).Pointer()) fmt.Fprintf(w, "#define CPU_REGISTERS 0x%02x\n", reflect.ValueOf(&c.registers).Pointer()-reflect.ValueOf(c).Pointer()) - fmt.Fprintf(w, "#define CPU_STACK_TOP 0x%02x\n", reflect.ValueOf(&c.stack[0]).Pointer()-reflect.ValueOf(c).Pointer()+uintptr(len(c.stack))) fmt.Fprintf(w, "#define CPU_ERROR_CODE 0x%02x\n", reflect.ValueOf(&c.errorCode).Pointer()-reflect.ValueOf(c).Pointer()) fmt.Fprintf(w, "#define CPU_ERROR_TYPE 0x%02x\n", reflect.ValueOf(&c.errorType).Pointer()-reflect.ValueOf(c).Pointer()) + fmt.Fprintf(w, "#define CPU_ENTRY 0x%02x\n", reflect.ValueOf(&c.kernelEntry).Pointer()-reflect.ValueOf(c).Pointer()) + + e := &kernelEntry{} + fmt.Fprintf(w, "\n// CPU entry offsets.\n") + fmt.Fprintf(w, "#define ENTRY_SCRATCH0 0x%02x\n", reflect.ValueOf(&e.scratch0).Pointer()-reflect.ValueOf(e).Pointer()) + fmt.Fprintf(w, "#define ENTRY_STACK_TOP 0x%02x\n", reflect.ValueOf(&e.stackTop).Pointer()-reflect.ValueOf(e).Pointer()) + fmt.Fprintf(w, "#define ENTRY_CPU_SELF 0x%02x\n", reflect.ValueOf(&e.cpuSelf).Pointer()-reflect.ValueOf(e).Pointer()) + fmt.Fprintf(w, "#define ENTRY_KERNEL_CR3 0x%02x\n", reflect.ValueOf(&e.kernelCR3).Pointer()-reflect.ValueOf(e).Pointer()) fmt.Fprintf(w, "\n// Bits.\n") fmt.Fprintf(w, "#define _RFLAGS_IF 0x%02x\n", _RFLAGS_IF) + fmt.Fprintf(w, "#define _RFLAGS_IOPL0 0x%02x\n", _RFLAGS_IOPL0) fmt.Fprintf(w, "#define _KERNEL_FLAGS 0x%02x\n", KernelFlagsSet) fmt.Fprintf(w, "\n// Vectors.\n") diff --git a/pkg/sentry/platform/ring0/offsets_arm64.go b/pkg/sentry/platform/ring0/offsets_arm64.go index 1d86b4bcf..45eba960d 100644 --- a/pkg/sentry/platform/ring0/offsets_arm64.go +++ b/pkg/sentry/platform/ring0/offsets_arm64.go @@ -125,4 +125,5 @@ func Emit(w io.Writer) { fmt.Fprintf(w, "#define PTRACE_SP 0x%02x\n", reflect.ValueOf(&p.Sp).Pointer()-reflect.ValueOf(p).Pointer()) fmt.Fprintf(w, "#define PTRACE_PC 0x%02x\n", reflect.ValueOf(&p.Pc).Pointer()-reflect.ValueOf(p).Pointer()) fmt.Fprintf(w, "#define PTRACE_PSTATE 0x%02x\n", reflect.ValueOf(&p.Pstate).Pointer()-reflect.ValueOf(p).Pointer()) + fmt.Fprintf(w, "#define PTRACE_TLS 0x%02x\n", reflect.ValueOf(&p.TPIDR_EL0).Pointer()-reflect.ValueOf(p).Pointer()) } diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go b/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go index 6409d1d91..520161755 100644 --- a/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go +++ b/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go @@ -78,7 +78,7 @@ const ( const ( executeDisable = xn - optionMask = 0xfff | 0xfff<<48 + optionMask = 0xfff | 0xffff<<48 protDefault = accessed | shared ) @@ -188,7 +188,7 @@ func (p *PTE) Set(addr uintptr, opts MapOpts) { v |= mtNormal } else { v = v &^ user - v |= mtDevicenGnRE // Strong order for the addresses with ring0.KernelStartAddress. + v |= mtNormal } atomic.StoreUintptr((*uintptr)(p), v) } diff --git a/pkg/sentry/platform/ring0/x86.go b/pkg/sentry/platform/ring0/x86.go index 9da0ea685..34fbc1c35 100644 --- a/pkg/sentry/platform/ring0/x86.go +++ b/pkg/sentry/platform/ring0/x86.go @@ -39,7 +39,9 @@ const ( _RFLAGS_AC = 1 << 18 _RFLAGS_NT = 1 << 14 - _RFLAGS_IOPL = 3 << 12 + _RFLAGS_IOPL0 = 1 << 12 + _RFLAGS_IOPL1 = 1 << 13 + _RFLAGS_IOPL = _RFLAGS_IOPL0 | _RFLAGS_IOPL1 _RFLAGS_DF = 1 << 10 _RFLAGS_IF = 1 << 9 _RFLAGS_STEP = 1 << 8 @@ -67,15 +69,45 @@ const ( KernelFlagsSet = _RFLAGS_RESERVED // UserFlagsSet are always set in userspace. - UserFlagsSet = _RFLAGS_RESERVED | _RFLAGS_IF + // + // _RFLAGS_IOPL is a set of two bits and it shows the I/O privilege + // level. The Current Privilege Level (CPL) of the task must be less + // than or equal to the IOPL in order for the task or program to access + // I/O ports. + // + // Here, _RFLAGS_IOPL0 is used only to determine whether the task is + // running in the kernel or userspace mode. In the user mode, the CPL is + // always 3 and it doesn't matter what IOPL is set if it is bellow CPL. + // + // We need to have one bit which will be always different in user and + // kernel modes. And we have to remember that even though we have + // KernelFlagsClear, we still can see some of these flags in the kernel + // mode. This can happen when the goruntime switches on a goroutine + // which has been saved in the host mode. On restore, the popf + // instruction is used to restore flags and this means that all flags + // what the goroutine has in the host mode will be restored in the + // kernel mode. + // + // _RFLAGS_IOPL0 is never set in host and kernel modes and we always set + // it in the user mode. So if this flag is set, the task is running in + // the user mode and if it isn't set, the task is running in the kernel + // mode. + UserFlagsSet = _RFLAGS_RESERVED | _RFLAGS_IF | _RFLAGS_IOPL0 // KernelFlagsClear should always be clear in the kernel. KernelFlagsClear = _RFLAGS_STEP | _RFLAGS_IF | _RFLAGS_IOPL | _RFLAGS_AC | _RFLAGS_NT // UserFlagsClear are always cleared in userspace. - UserFlagsClear = _RFLAGS_NT | _RFLAGS_IOPL + UserFlagsClear = _RFLAGS_NT | _RFLAGS_IOPL1 ) +// IsKernelFlags returns true if rflags coresponds to the kernel mode. +// +// go:nosplit +func IsKernelFlags(rflags uint64) bool { + return rflags&_RFLAGS_IOPL0 == 0 +} + // Vector is an exception vector. type Vector uintptr @@ -104,7 +136,7 @@ const ( VirtualizationException SecurityException = 0x1e SyscallInt80 = 0x80 - _NR_INTERRUPTS = SyscallInt80 + 1 + _NR_INTERRUPTS = 0x100 ) // System call vectors. diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go index 94cb80437..904a12e38 100644 --- a/pkg/sentry/socket/netfilter/netfilter.go +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -147,10 +147,6 @@ func SetEntries(stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error { case stack.FilterTable: table = stack.EmptyFilterTable() case stack.NATTable: - if ipv6 { - nflog("IPv6 redirection not yet supported (gvisor.dev/issue/3549)") - return syserr.ErrInvalidArgument - } table = stack.EmptyNATTable() default: nflog("we don't yet support writing to the %q table (gvisor.dev/issue/170)", replace.Name.String()) diff --git a/pkg/sentry/socket/netfilter/targets.go b/pkg/sentry/socket/netfilter/targets.go index 19b18b2d6..0e14447fe 100644 --- a/pkg/sentry/socket/netfilter/targets.go +++ b/pkg/sentry/socket/netfilter/targets.go @@ -47,6 +47,9 @@ func init() { registerTargetMaker(&redirectTargetMaker{ NetworkProtocol: header.IPv4ProtocolNumber, }) + registerTargetMaker(&nfNATTargetMaker{ + NetworkProtocol: header.IPv6ProtocolNumber, + }) } type standardTargetMaker struct { @@ -250,6 +253,86 @@ func (*redirectTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) ( return &target, nil } +type nfNATTarget struct { + Target linux.XTEntryTarget + Range linux.NFNATRange +} + +const nfNATMarhsalledSize = linux.SizeOfXTEntryTarget + linux.SizeOfNFNATRange + +type nfNATTargetMaker struct { + NetworkProtocol tcpip.NetworkProtocolNumber +} + +func (rm *nfNATTargetMaker) id() stack.TargetID { + return stack.TargetID{ + Name: stack.RedirectTargetName, + NetworkProtocol: rm.NetworkProtocol, + } +} + +func (*nfNATTargetMaker) marshal(target stack.Target) []byte { + rt := target.(*stack.RedirectTarget) + nt := nfNATTarget{ + Target: linux.XTEntryTarget{ + TargetSize: nfNATMarhsalledSize, + }, + Range: linux.NFNATRange{ + Flags: linux.NF_NAT_RANGE_PROTO_SPECIFIED, + }, + } + copy(nt.Target.Name[:], stack.RedirectTargetName) + copy(nt.Range.MinAddr[:], rt.Addr) + copy(nt.Range.MaxAddr[:], rt.Addr) + + nt.Range.MinProto = htons(rt.Port) + nt.Range.MaxProto = nt.Range.MinProto + + ret := make([]byte, 0, nfNATMarhsalledSize) + return binary.Marshal(ret, usermem.ByteOrder, nt) +} + +func (*nfNATTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Target, *syserr.Error) { + if size := nfNATMarhsalledSize; len(buf) < size { + nflog("nfNATTargetMaker: buf has insufficient size (%d) for nfNAT target (%d)", len(buf), size) + return nil, syserr.ErrInvalidArgument + } + + if p := filter.Protocol; p != header.TCPProtocolNumber && p != header.UDPProtocolNumber { + nflog("nfNATTargetMaker: bad proto %d", p) + return nil, syserr.ErrInvalidArgument + } + + var natRange linux.NFNATRange + buf = buf[linux.SizeOfXTEntryTarget:nfNATMarhsalledSize] + binary.Unmarshal(buf, usermem.ByteOrder, &natRange) + + // We don't support port or address ranges. + if natRange.MinAddr != natRange.MaxAddr { + nflog("nfNATTargetMaker: MinAddr and MaxAddr are different") + return nil, syserr.ErrInvalidArgument + } + if natRange.MinProto != natRange.MaxProto { + nflog("nfNATTargetMaker: MinProto and MaxProto are different") + return nil, syserr.ErrInvalidArgument + } + + // TODO(gvisor.dev/issue/3549): Check for other flags. + // For now, redirect target only supports destination change. + if natRange.Flags != linux.NF_NAT_RANGE_PROTO_SPECIFIED { + nflog("nfNATTargetMaker: invalid range flags %d", natRange.Flags) + return nil, syserr.ErrInvalidArgument + } + + target := stack.RedirectTarget{ + NetworkProtocol: filter.NetworkProtocol(), + Addr: tcpip.Address(natRange.MinAddr[:]), + Port: ntohs(natRange.MinProto), + } + + return &target, nil +} + // translateToStandardTarget translates from the value in a // linux.XTStandardTarget to an stack.Verdict. func translateToStandardTarget(val int32, netProto tcpip.NetworkProtocolNumber) (stack.Target, *syserr.Error) { @@ -306,7 +389,7 @@ func (jt *JumpTarget) ID() stack.TargetID { } // Action implements stack.Target.Action. -func (jt JumpTarget) Action(*stack.PacketBuffer, *stack.ConnTrack, stack.Hook, *stack.GSO, *stack.Route, tcpip.Address) (stack.RuleVerdict, int) { +func (jt *JumpTarget) Action(*stack.PacketBuffer, *stack.ConnTrack, stack.Hook, *stack.GSO, *stack.Route, tcpip.Address) (stack.RuleVerdict, int) { return stack.RuleJump, jt.RuleNum } diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 335822c0e..87e30d742 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -1512,8 +1512,17 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name return &vP, nil case linux.IP6T_ORIGINAL_DST: - // TODO(gvisor.dev/issue/170): ip6tables. - return nil, syserr.ErrInvalidArgument + if outLen < int(binary.Size(linux.SockAddrInet6{})) { + return nil, syserr.ErrInvalidArgument + } + + var v tcpip.OriginalDestinationOption + if err := ep.GetSockOpt(&v); err != nil { + return nil, syserr.TranslateNetstackError(err) + } + + a, _ := ConvertAddress(linux.AF_INET6, tcpip.FullAddress(v)) + return a.(*linux.SockAddrInet6), nil case linux.IP6T_SO_GET_INFO: if outLen < linux.SizeOfIPTGetinfo { @@ -1555,6 +1564,26 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name } return &entries, nil + case linux.IP6T_SO_GET_REVISION_TARGET: + if outLen < linux.SizeOfXTGetRevision { + return nil, syserr.ErrInvalidArgument + } + + // Only valid for raw IPv6 sockets. + if family, skType, _ := s.Type(); family != linux.AF_INET6 || skType != linux.SOCK_RAW { + return nil, syserr.ErrProtocolNotAvailable + } + + stack := inet.StackFromContext(t) + if stack == nil { + return nil, syserr.ErrNoDevice + } + ret, err := netfilter.TargetRevision(t, outPtr, header.IPv6ProtocolNumber) + if err != nil { + return nil, err + } + return &ret, nil + default: emitUnimplementedEventIPv6(t, name) } diff --git a/pkg/sentry/syscalls/linux/BUILD b/pkg/sentry/syscalls/linux/BUILD index 75752b2e6..a2e441448 100644 --- a/pkg/sentry/syscalls/linux/BUILD +++ b/pkg/sentry/syscalls/linux/BUILD @@ -21,6 +21,7 @@ go_library( "sys_identity.go", "sys_inotify.go", "sys_lseek.go", + "sys_membarrier.go", "sys_mempolicy.go", "sys_mmap.go", "sys_mount.go", diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go index 5f26697d2..9c9def7cd 100644 --- a/pkg/sentry/syscalls/linux/linux64.go +++ b/pkg/sentry/syscalls/linux/linux64.go @@ -376,7 +376,7 @@ var AMD64 = &kernel.SyscallTable{ 321: syscalls.CapError("bpf", linux.CAP_SYS_ADMIN, "", nil), 322: syscalls.Supported("execveat", Execveat), 323: syscalls.ErrorWithEvent("userfaultfd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/266"}), // TODO(b/118906345) - 324: syscalls.ErrorWithEvent("membarrier", syserror.ENOSYS, "", []string{"gvisor.dev/issue/267"}), // TODO(gvisor.dev/issue/267) + 324: syscalls.PartiallySupported("membarrier", Membarrier, "Not supported on all platforms.", nil), 325: syscalls.PartiallySupported("mlock2", Mlock2, "Stub implementation. The sandbox lacks appropriate permissions.", nil), // Syscalls implemented after 325 are "backports" from versions @@ -527,8 +527,8 @@ var ARM64 = &kernel.SyscallTable{ 96: syscalls.Supported("set_tid_address", SetTidAddress), 97: syscalls.PartiallySupported("unshare", Unshare, "Mount, cgroup namespaces not supported. Network namespaces supported but must be empty.", nil), 98: syscalls.PartiallySupported("futex", Futex, "Robust futexes not supported.", nil), - 99: syscalls.Error("set_robust_list", syserror.ENOSYS, "Obsolete.", nil), - 100: syscalls.Error("get_robust_list", syserror.ENOSYS, "Obsolete.", nil), + 99: syscalls.Supported("set_robust_list", SetRobustList), + 100: syscalls.Supported("get_robust_list", GetRobustList), 101: syscalls.Supported("nanosleep", Nanosleep), 102: syscalls.Supported("getitimer", Getitimer), 103: syscalls.Supported("setitimer", Setitimer), @@ -695,7 +695,7 @@ var ARM64 = &kernel.SyscallTable{ 280: syscalls.CapError("bpf", linux.CAP_SYS_ADMIN, "", nil), 281: syscalls.Supported("execveat", Execveat), 282: syscalls.ErrorWithEvent("userfaultfd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/266"}), // TODO(b/118906345) - 283: syscalls.ErrorWithEvent("membarrier", syserror.ENOSYS, "", []string{"gvisor.dev/issue/267"}), // TODO(gvisor.dev/issue/267) + 283: syscalls.PartiallySupported("membarrier", Membarrier, "Not supported on all platforms.", nil), 284: syscalls.PartiallySupported("mlock2", Mlock2, "Stub implementation. The sandbox lacks appropriate permissions.", nil), // Syscalls after 284 are "backports" from versions of Linux after 4.4. diff --git a/pkg/sentry/syscalls/linux/sys_file.go b/pkg/sentry/syscalls/linux/sys_file.go index 98331eb3c..519066a47 100644 --- a/pkg/sentry/syscalls/linux/sys_file.go +++ b/pkg/sentry/syscalls/linux/sys_file.go @@ -84,6 +84,7 @@ func fileOpOn(t *kernel.Task, dirFD int32, path string, resolve bool, fn func(ro } rel = f.Dirent if !fs.IsDir(rel.Inode.StableAttr) { + f.DecRef(t) return syserror.ENOTDIR } } diff --git a/pkg/sentry/syscalls/linux/sys_membarrier.go b/pkg/sentry/syscalls/linux/sys_membarrier.go new file mode 100644 index 000000000..63ee5d435 --- /dev/null +++ b/pkg/sentry/syscalls/linux/sys_membarrier.go @@ -0,0 +1,103 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package linux + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/syserror" +) + +// Membarrier implements syscall membarrier(2). +func Membarrier(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + cmd := args[0].Int() + flags := args[1].Uint() + + switch cmd { + case linux.MEMBARRIER_CMD_QUERY: + if flags != 0 { + return 0, nil, syserror.EINVAL + } + var supportedCommands uintptr + if t.Kernel().Platform.HaveGlobalMemoryBarrier() { + supportedCommands |= linux.MEMBARRIER_CMD_GLOBAL | + linux.MEMBARRIER_CMD_GLOBAL_EXPEDITED | + linux.MEMBARRIER_CMD_REGISTER_GLOBAL_EXPEDITED | + linux.MEMBARRIER_CMD_PRIVATE_EXPEDITED | + linux.MEMBARRIER_CMD_REGISTER_PRIVATE_EXPEDITED + } + if t.RSeqAvailable() { + supportedCommands |= linux.MEMBARRIER_CMD_PRIVATE_EXPEDITED_RSEQ | + linux.MEMBARRIER_CMD_REGISTER_PRIVATE_EXPEDITED_RSEQ + } + return supportedCommands, nil, nil + case linux.MEMBARRIER_CMD_GLOBAL, linux.MEMBARRIER_CMD_GLOBAL_EXPEDITED, linux.MEMBARRIER_CMD_PRIVATE_EXPEDITED: + if flags != 0 { + return 0, nil, syserror.EINVAL + } + if !t.Kernel().Platform.HaveGlobalMemoryBarrier() { + return 0, nil, syserror.EINVAL + } + if cmd == linux.MEMBARRIER_CMD_PRIVATE_EXPEDITED && !t.MemoryManager().IsMembarrierPrivateEnabled() { + return 0, nil, syserror.EPERM + } + return 0, nil, t.Kernel().Platform.GlobalMemoryBarrier() + case linux.MEMBARRIER_CMD_REGISTER_GLOBAL_EXPEDITED: + if flags != 0 { + return 0, nil, syserror.EINVAL + } + if !t.Kernel().Platform.HaveGlobalMemoryBarrier() { + return 0, nil, syserror.EINVAL + } + // no-op + return 0, nil, nil + case linux.MEMBARRIER_CMD_REGISTER_PRIVATE_EXPEDITED: + if flags != 0 { + return 0, nil, syserror.EINVAL + } + if !t.Kernel().Platform.HaveGlobalMemoryBarrier() { + return 0, nil, syserror.EINVAL + } + t.MemoryManager().EnableMembarrierPrivate() + return 0, nil, nil + case linux.MEMBARRIER_CMD_PRIVATE_EXPEDITED_RSEQ: + if flags&^linux.MEMBARRIER_CMD_FLAG_CPU != 0 { + return 0, nil, syserror.EINVAL + } + if !t.RSeqAvailable() { + return 0, nil, syserror.EINVAL + } + if !t.MemoryManager().IsMembarrierRSeqEnabled() { + return 0, nil, syserror.EPERM + } + // MEMBARRIER_CMD_FLAG_CPU and cpu_id are ignored since we don't have + // the ability to preempt specific CPUs. + return 0, nil, t.Kernel().Platform.PreemptAllCPUs() + case linux.MEMBARRIER_CMD_REGISTER_PRIVATE_EXPEDITED_RSEQ: + if flags != 0 { + return 0, nil, syserror.EINVAL + } + if !t.RSeqAvailable() { + return 0, nil, syserror.EINVAL + } + t.MemoryManager().EnableMembarrierRSeq() + return 0, nil, nil + default: + // Probably a command we don't implement. + t.Kernel().EmitUnimplementedEvent(t) + return 0, nil, syserror.EINVAL + } +} diff --git a/pkg/sentry/syscalls/linux/sys_sysinfo.go b/pkg/sentry/syscalls/linux/sys_sysinfo.go index 674d341b6..6320593f0 100644 --- a/pkg/sentry/syscalls/linux/sys_sysinfo.go +++ b/pkg/sentry/syscalls/linux/sys_sysinfo.go @@ -26,8 +26,12 @@ func Sysinfo(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca addr := args[0].Pointer() mf := t.Kernel().MemoryFile() - mf.UpdateUsage() - _, totalUsage := usage.MemoryAccounting.Copy() + mfUsage, err := mf.TotalUsage() + if err != nil { + return 0, nil, err + } + memStats, _ := usage.MemoryAccounting.Copy() + totalUsage := mfUsage + memStats.Mapped totalSize := usage.TotalMemory(mf.TotalSize(), totalUsage) memFree := totalSize - totalUsage if memFree > totalSize { @@ -37,12 +41,12 @@ func Sysinfo(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca // Only a subset of the fields in sysinfo_t make sense to return. si := linux.Sysinfo{ - Procs: uint16(len(t.PIDNamespace().Tasks())), + Procs: uint16(t.Kernel().TaskSet().Root.NumTasks()), Uptime: t.Kernel().MonotonicClock().Now().Seconds(), TotalRAM: totalSize, FreeRAM: memFree, Unit: 1, } - _, err := si.CopyOut(t, addr) + _, err = si.CopyOut(t, addr) return 0, nil, err } diff --git a/pkg/sentry/syscalls/linux/vfs2/vfs2.go b/pkg/sentry/syscalls/linux/vfs2/vfs2.go index 0df3bd449..c50fd97eb 100644 --- a/pkg/sentry/syscalls/linux/vfs2/vfs2.go +++ b/pkg/sentry/syscalls/linux/vfs2/vfs2.go @@ -163,6 +163,7 @@ func Override() { // Override ARM64. s = linux.ARM64 + s.Table[2] = syscalls.PartiallySupported("io_submit", IoSubmit, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}) s.Table[5] = syscalls.Supported("setxattr", SetXattr) s.Table[6] = syscalls.Supported("lsetxattr", Lsetxattr) s.Table[7] = syscalls.Supported("fsetxattr", Fsetxattr) @@ -200,6 +201,7 @@ func Override() { s.Table[44] = syscalls.Supported("fstatfs", Fstatfs) s.Table[45] = syscalls.Supported("truncate", Truncate) s.Table[46] = syscalls.Supported("ftruncate", Ftruncate) + s.Table[47] = syscalls.PartiallySupported("fallocate", Fallocate, "Not all options are supported.", nil) s.Table[48] = syscalls.Supported("faccessat", Faccessat) s.Table[49] = syscalls.Supported("chdir", Chdir) s.Table[50] = syscalls.Supported("fchdir", Fchdir) @@ -221,12 +223,14 @@ func Override() { s.Table[68] = syscalls.Supported("pwrite64", Pwrite64) s.Table[69] = syscalls.Supported("preadv", Preadv) s.Table[70] = syscalls.Supported("pwritev", Pwritev) + s.Table[71] = syscalls.Supported("sendfile", Sendfile) s.Table[72] = syscalls.Supported("pselect", Pselect) s.Table[73] = syscalls.Supported("ppoll", Ppoll) s.Table[74] = syscalls.Supported("signalfd4", Signalfd4) s.Table[76] = syscalls.Supported("splice", Splice) s.Table[77] = syscalls.Supported("tee", Tee) s.Table[78] = syscalls.Supported("readlinkat", Readlinkat) + s.Table[79] = syscalls.Supported("newfstatat", Newfstatat) s.Table[80] = syscalls.Supported("fstat", Fstat) s.Table[81] = syscalls.Supported("sync", Sync) s.Table[82] = syscalls.Supported("fsync", Fsync) @@ -251,8 +255,10 @@ func Override() { s.Table[210] = syscalls.Supported("shutdown", Shutdown) s.Table[211] = syscalls.Supported("sendmsg", SendMsg) s.Table[212] = syscalls.Supported("recvmsg", RecvMsg) + s.Table[213] = syscalls.Supported("readahead", Readahead) s.Table[221] = syscalls.Supported("execve", Execve) s.Table[222] = syscalls.Supported("mmap", Mmap) + s.Table[223] = syscalls.PartiallySupported("fadvise64", Fadvise64, "Not all options are supported.", nil) s.Table[242] = syscalls.Supported("accept4", Accept4) s.Table[243] = syscalls.Supported("recvmmsg", RecvMMsg) s.Table[267] = syscalls.Supported("syncfs", Syncfs) diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD index 8093ca55c..c855608db 100644 --- a/pkg/sentry/vfs/BUILD +++ b/pkg/sentry/vfs/BUILD @@ -92,7 +92,6 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", - "//pkg/binary", "//pkg/context", "//pkg/fd", "//pkg/fdnotifier", diff --git a/pkg/sentry/vfs/mount.go b/pkg/sentry/vfs/mount.go index dfc3ae6c0..79a2d8c41 100644 --- a/pkg/sentry/vfs/mount.go +++ b/pkg/sentry/vfs/mount.go @@ -46,8 +46,9 @@ import ( // +stateify savable type Mount struct { // vfs, fs, root are immutable. References are held on fs and root. + // Note that for a disconnected mount, root may be nil. // - // Invariant: root belongs to fs. + // Invariant: if not nil, root belongs to fs. vfs *VirtualFilesystem fs *Filesystem root *Dentry @@ -498,7 +499,9 @@ func (mnt *Mount) DecRef(ctx context.Context) { mnt.vfs.mounts.seq.EndWrite() mnt.vfs.mountMu.Unlock() } - mnt.root.DecRef(ctx) + if mnt.root != nil { + mnt.root.DecRef(ctx) + } mnt.fs.DecRef(ctx) if vd.Ok() { vd.DecRef(ctx) diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go index 5bd756ea5..31ea3139c 100644 --- a/pkg/sentry/vfs/vfs.go +++ b/pkg/sentry/vfs/vfs.go @@ -122,6 +122,13 @@ type VirtualFilesystem struct { filesystems map[*Filesystem]struct{} } +// Release drops references on filesystem objects held by vfs. +// +// Precondition: This must be called after VFS.Init() has succeeded. +func (vfs *VirtualFilesystem) Release(ctx context.Context) { + vfs.anonMount.DecRef(ctx) +} + // Init initializes a new VirtualFilesystem with no mounts or FilesystemTypes. func (vfs *VirtualFilesystem) Init(ctx context.Context) error { if vfs.mountpoints != nil { diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go index ea0c5413d..8db70a700 100644 --- a/pkg/tcpip/buffer/view.go +++ b/pkg/tcpip/buffer/view.go @@ -84,8 +84,8 @@ type VectorisedView struct { size int } -// NewVectorisedView creates a new vectorised view from an already-allocated slice -// of View and sets its size. +// NewVectorisedView creates a new vectorised view from an already-allocated +// slice of View and sets its size. func NewVectorisedView(size int, views []View) VectorisedView { return VectorisedView{views: views, size: size} } @@ -170,8 +170,9 @@ func (vv *VectorisedView) CapLength(length int) { } // Clone returns a clone of this VectorisedView. -// If the buffer argument is large enough to contain all the Views of this VectorisedView, -// the method will avoid allocations and use the buffer to store the Views of the clone. +// If the buffer argument is large enough to contain all the Views of this +// VectorisedView, the method will avoid allocations and use the buffer to +// store the Views of the clone. func (vv *VectorisedView) Clone(buffer []View) VectorisedView { return VectorisedView{views: append(buffer[:0], vv.views...), size: vv.size} } @@ -209,7 +210,8 @@ func (vv *VectorisedView) PullUp(count int) (View, bool) { return newFirst, true } -// Size returns the size in bytes of the entire content stored in the vectorised view. +// Size returns the size in bytes of the entire content stored in the +// vectorised view. func (vv *VectorisedView) Size() int { return vv.size } @@ -222,6 +224,12 @@ func (vv *VectorisedView) ToView() View { if len(vv.views) == 1 { return vv.views[0] } + return vv.ToOwnedView() +} + +// ToOwnedView returns a single view containing the content of the vectorised +// view that vv does not own. +func (vv *VectorisedView) ToOwnedView() View { u := make([]byte, 0, vv.size) for _, v := range vv.views { u = append(u, v...) diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index 19627fa9b..d4d785cca 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -118,18 +118,82 @@ func TTL(ttl uint8) NetworkChecker { v = ip.HopLimit() } if v != ttl { - t.Fatalf("Bad TTL, got %v, want %v", v, ttl) + t.Fatalf("Bad TTL, got = %d, want = %d", v, ttl) + } + } +} + +// IPFullLength creates a checker for the full IP packet length. The +// expected size is checked against both the Total Length in the +// header and the number of bytes received. +func IPFullLength(packetLength uint16) NetworkChecker { + return func(t *testing.T, h []header.Network) { + t.Helper() + + var v uint16 + var l uint16 + switch ip := h[0].(type) { + case header.IPv4: + v = ip.TotalLength() + l = uint16(len(ip)) + case header.IPv6: + v = ip.PayloadLength() + header.IPv6FixedHeaderSize + l = uint16(len(ip)) + default: + t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4 or header.IPv6", ip) + } + if l != packetLength { + t.Errorf("bad packet length, got = %d, want = %d", l, packetLength) + } + if v != packetLength { + t.Errorf("unexpected packet length in header, got = %d, want = %d", v, packetLength) + } + } +} + +// IPv4HeaderLength creates a checker that checks the IPv4 Header length. +func IPv4HeaderLength(headerLength int) NetworkChecker { + return func(t *testing.T, h []header.Network) { + t.Helper() + + switch ip := h[0].(type) { + case header.IPv4: + if hl := ip.HeaderLength(); hl != uint8(headerLength) { + t.Errorf("Bad header length, got = %d, want = %d", hl, headerLength) + } + default: + t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4", ip) } } } // PayloadLen creates a checker that checks the payload length. -func PayloadLen(plen int) NetworkChecker { +func PayloadLen(payloadLength int) NetworkChecker { return func(t *testing.T, h []header.Network) { t.Helper() - if l := len(h[0].Payload()); l != plen { - t.Errorf("Bad payload length, got %v, want %v", l, plen) + if l := len(h[0].Payload()); l != payloadLength { + t.Errorf("Bad payload length, got = %d, want = %d", l, payloadLength) + } + } +} + +// IPv4Options returns a checker that checks the options in an IPv4 packet. +func IPv4Options(want []byte) NetworkChecker { + return func(t *testing.T, h []header.Network) { + t.Helper() + + ip, ok := h[0].(header.IPv4) + if !ok { + t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4", h[0]) + } + options := ip.Options() + // cmp.Diff does not consider nil slices equal to empty slices, but we do. + if len(want) == 0 && len(options) == 0 { + return + } + if diff := cmp.Diff(want, options); diff != "" { + t.Errorf("options mismatch (-want +got):\n%s", diff) } } } @@ -139,11 +203,11 @@ func FragmentOffset(offset uint16) NetworkChecker { return func(t *testing.T, h []header.Network) { t.Helper() - // We only do this of IPv4 for now. + // We only do this for IPv4 for now. switch ip := h[0].(type) { case header.IPv4: if v := ip.FragmentOffset(); v != offset { - t.Errorf("Bad fragment offset, got %v, want %v", v, offset) + t.Errorf("Bad fragment offset, got = %d, want = %d", v, offset) } } } @@ -154,11 +218,11 @@ func FragmentFlags(flags uint8) NetworkChecker { return func(t *testing.T, h []header.Network) { t.Helper() - // We only do this of IPv4 for now. + // We only do this for IPv4 for now. switch ip := h[0].(type) { case header.IPv4: if v := ip.Flags(); v != flags { - t.Errorf("Bad fragment offset, got %v, want %v", v, flags) + t.Errorf("Bad fragment offset, got = %d, want = %d", v, flags) } } } @@ -208,7 +272,7 @@ func TOS(tos uint8, label uint32) NetworkChecker { t.Helper() if v, l := h[0].TOS(); v != tos || l != label { - t.Errorf("Bad TOS, got (%v, %v), want (%v,%v)", v, l, tos, label) + t.Errorf("Bad TOS, got = (%d, %d), want = (%d,%d)", v, l, tos, label) } } } @@ -234,7 +298,7 @@ func IPv6Fragment(checkers ...NetworkChecker) NetworkChecker { t.Helper() if p := h[0].TransportProtocol(); p != header.IPv6FragmentHeader { - t.Errorf("Bad protocol, got %v, want %v", p, header.UDPProtocolNumber) + t.Errorf("Bad protocol, got = %d, want = %d", p, header.UDPProtocolNumber) } ipv6Frag := header.IPv6Fragment(h[0].Payload()) @@ -261,7 +325,7 @@ func TCP(checkers ...TransportChecker) NetworkChecker { last := h[len(h)-1] if p := last.TransportProtocol(); p != header.TCPProtocolNumber { - t.Errorf("Bad protocol, got %v, want %v", p, header.TCPProtocolNumber) + t.Errorf("Bad protocol, got = %d, want = %d", p, header.TCPProtocolNumber) } // Verify the checksum. @@ -297,7 +361,7 @@ func UDP(checkers ...TransportChecker) NetworkChecker { last := h[len(h)-1] if p := last.TransportProtocol(); p != header.UDPProtocolNumber { - t.Errorf("Bad protocol, got %v, want %v", p, header.UDPProtocolNumber) + t.Errorf("Bad protocol, got = %d, want = %d", p, header.UDPProtocolNumber) } udp := header.UDP(last.Payload()) @@ -316,7 +380,7 @@ func SrcPort(port uint16) TransportChecker { t.Helper() if p := h.SourcePort(); p != port { - t.Errorf("Bad source port, got %v, want %v", p, port) + t.Errorf("Bad source port, got = %d, want = %d", p, port) } } } @@ -327,7 +391,7 @@ func DstPort(port uint16) TransportChecker { t.Helper() if p := h.DestinationPort(); p != port { - t.Errorf("Bad destination port, got %v, want %v", p, port) + t.Errorf("Bad destination port, got = %d, want = %d", p, port) } } } @@ -359,7 +423,7 @@ func TCPSeqNum(seq uint32) TransportChecker { } if s := tcp.SequenceNumber(); s != seq { - t.Errorf("Bad sequence number, got %v, want %v", s, seq) + t.Errorf("Bad sequence number, got = %d, want = %d", s, seq) } } } @@ -375,7 +439,7 @@ func TCPAckNum(seq uint32) TransportChecker { } if s := tcp.AckNumber(); s != seq { - t.Errorf("Bad ack number, got %v, want %v", s, seq) + t.Errorf("Bad ack number, got = %d, want = %d", s, seq) } } } @@ -492,7 +556,7 @@ func TCPSynOptions(wantOpts header.TCPSynOptions) TransportChecker { case header.TCPOptionMSS: v := uint16(opts[i+2])<<8 | uint16(opts[i+3]) if wantOpts.MSS != v { - t.Errorf("Bad MSS: got %v, want %v", v, wantOpts.MSS) + t.Errorf("Bad MSS, got = %d, want = %d", v, wantOpts.MSS) } foundMSS = true i += 4 @@ -502,7 +566,7 @@ func TCPSynOptions(wantOpts header.TCPSynOptions) TransportChecker { } v := int(opts[i+2]) if v != wantOpts.WS { - t.Errorf("Bad WS: got %v, want %v", v, wantOpts.WS) + t.Errorf("Bad WS, got = %d, want = %d", v, wantOpts.WS) } foundWS = true i += 3 @@ -551,7 +615,7 @@ func TCPSynOptions(wantOpts header.TCPSynOptions) TransportChecker { t.Error("TS option specified but the timestamp value is zero") } if foundTS && tsEcr == 0 && wantOpts.TSEcr != 0 { - t.Errorf("TS option specified but TSEcr is incorrect: got %d, want: %d", tsEcr, wantOpts.TSEcr) + t.Errorf("TS option specified but TSEcr is incorrect, got = %d, want = %d", tsEcr, wantOpts.TSEcr) } if wantOpts.SACKPermitted && !foundSACKPermitted { t.Errorf("SACKPermitted option not found. Options: %x", opts) @@ -589,7 +653,7 @@ func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) Transp t.Errorf("TS option found, but option is truncated, option length: %d, want 10 bytes", limit-i) } if opts[i+1] != 10 { - t.Errorf("TS option found, but bad length specified: %d, want: 10", opts[i+1]) + t.Errorf("TS option found, but bad length specified: got = %d, want = 10", opts[i+1]) } tsVal = binary.BigEndian.Uint32(opts[i+2:]) tsEcr = binary.BigEndian.Uint32(opts[i+6:]) @@ -609,19 +673,19 @@ func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) Transp } if wantTS != foundTS { - t.Errorf("TS Option mismatch: got TS= %v, want TS= %v", foundTS, wantTS) + t.Errorf("TS Option mismatch, got TS= %t, want TS= %t", foundTS, wantTS) } if wantTS && wantTSVal != 0 && wantTSVal != tsVal { - t.Errorf("Timestamp value is incorrect: got: %d, want: %d", tsVal, wantTSVal) + t.Errorf("Timestamp value is incorrect, got = %d, want = %d", tsVal, wantTSVal) } if wantTS && wantTSEcr != 0 && tsEcr != wantTSEcr { - t.Errorf("Timestamp Echo Reply is incorrect: got: %d, want: %d", tsEcr, wantTSEcr) + t.Errorf("Timestamp Echo Reply is incorrect, got = %d, want = %d", tsEcr, wantTSEcr) } } } -// TCPNoSACKBlockChecker creates a checker that verifies that the segment does not -// contain any SACK blocks in the TCP options. +// TCPNoSACKBlockChecker creates a checker that verifies that the segment does +// not contain any SACK blocks in the TCP options. func TCPNoSACKBlockChecker() TransportChecker { return TCPSACKBlockChecker(nil) } @@ -679,7 +743,7 @@ func TCPSACKBlockChecker(sackBlocks []header.SACKBlock) TransportChecker { } if !reflect.DeepEqual(gotSACKBlocks, sackBlocks) { - t.Errorf("SACKBlocks are not equal, got: %v, want: %v", gotSACKBlocks, sackBlocks) + t.Errorf("SACKBlocks are not equal, got = %v, want = %v", gotSACKBlocks, sackBlocks) } } } @@ -695,8 +759,8 @@ func Payload(want []byte) TransportChecker { } } -// ICMPv4 creates a checker that checks that the transport protocol is ICMPv4 and -// potentially additional ICMPv4 header fields. +// ICMPv4 creates a checker that checks that the transport protocol is ICMPv4 +// and potentially additional ICMPv4 header fields. func ICMPv4(checkers ...TransportChecker) NetworkChecker { return func(t *testing.T, h []header.Network) { t.Helper() @@ -724,10 +788,10 @@ func ICMPv4Type(want header.ICMPv4Type) TransportChecker { icmpv4, ok := h.(header.ICMPv4) if !ok { - t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv4", h) + t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h) } if got := icmpv4.Type(); got != want { - t.Fatalf("unexpected icmp type got: %d, want: %d", got, want) + t.Fatalf("unexpected icmp type, got = %d, want = %d", got, want) } } } @@ -739,10 +803,76 @@ func ICMPv4Code(want header.ICMPv4Code) TransportChecker { icmpv4, ok := h.(header.ICMPv4) if !ok { - t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv4", h) + t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h) } if got := icmpv4.Code(); got != want { - t.Fatalf("unexpected ICMP code got: %d, want: %d", got, want) + t.Fatalf("unexpected ICMP code, got = %d, want = %d", got, want) + } + } +} + +// ICMPv4Ident creates a checker that checks the ICMPv4 echo Ident. +func ICMPv4Ident(want uint16) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + icmpv4, ok := h.(header.ICMPv4) + if !ok { + t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h) + } + if got := icmpv4.Ident(); got != want { + t.Fatalf("unexpected ICMP ident, got = %d, want = %d", got, want) + } + } +} + +// ICMPv4Seq creates a checker that checks the ICMPv4 echo Sequence. +func ICMPv4Seq(want uint16) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + icmpv4, ok := h.(header.ICMPv4) + if !ok { + t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h) + } + if got := icmpv4.Sequence(); got != want { + t.Fatalf("unexpected ICMP sequence, got = %d, want = %d", got, want) + } + } +} + +// ICMPv4Checksum creates a checker that checks the ICMPv4 Checksum. +// This assumes that the payload exactly makes up the rest of the slice. +func ICMPv4Checksum() TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + icmpv4, ok := h.(header.ICMPv4) + if !ok { + t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h) + } + heldChecksum := icmpv4.Checksum() + icmpv4.SetChecksum(0) + newChecksum := ^header.Checksum(icmpv4, 0) + icmpv4.SetChecksum(heldChecksum) + if heldChecksum != newChecksum { + t.Errorf("unexpected ICMP checksum, got = %d, want = %d", heldChecksum, newChecksum) + } + } +} + +// ICMPv4Payload creates a checker that checks the payload in an ICMPv4 packet. +func ICMPv4Payload(want []byte) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + icmpv4, ok := h.(header.ICMPv4) + if !ok { + t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h) + } + payload := icmpv4.Payload() + if diff := cmp.Diff(want, payload); diff != "" { + t.Errorf("ICMP payload mismatch (-want +got):\n%s", diff) } } } @@ -782,10 +912,10 @@ func ICMPv6Type(want header.ICMPv6Type) TransportChecker { icmpv6, ok := h.(header.ICMPv6) if !ok { - t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv6", h) + t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h) } if got := icmpv6.Type(); got != want { - t.Fatalf("unexpected icmp type got: %d, want: %d", got, want) + t.Fatalf("unexpected icmp type, got = %d, want = %d", got, want) } } } @@ -797,10 +927,10 @@ func ICMPv6Code(want header.ICMPv6Code) TransportChecker { icmpv6, ok := h.(header.ICMPv6) if !ok { - t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv6", h) + t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h) } if got := icmpv6.Code(); got != want { - t.Fatalf("unexpected ICMP code got: %d, want: %d", got, want) + t.Fatalf("unexpected ICMP code, got = %d, want = %d", got, want) } } } diff --git a/pkg/tcpip/header/eth.go b/pkg/tcpip/header/eth.go index eaface8cb..95ade0e5c 100644 --- a/pkg/tcpip/header/eth.go +++ b/pkg/tcpip/header/eth.go @@ -117,25 +117,31 @@ func (b Ethernet) Encode(e *EthernetFields) { copy(b[dstMAC:][:EthernetAddressSize], e.DstAddr) } -// IsValidUnicastEthernetAddress returns true if addr is a valid unicast +// IsMulticastEthernetAddress returns true if the address is a multicast +// ethernet address. +func IsMulticastEthernetAddress(addr tcpip.LinkAddress) bool { + if len(addr) != EthernetAddressSize { + return false + } + + return addr[unicastMulticastFlagByteIdx]&unicastMulticastFlagMask != 0 +} + +// IsValidUnicastEthernetAddress returns true if the address is a unicast // ethernet address. func IsValidUnicastEthernetAddress(addr tcpip.LinkAddress) bool { - // Must be of the right length. if len(addr) != EthernetAddressSize { return false } - // Must not be unspecified. if addr == unspecifiedEthernetAddress { return false } - // Must not be a multicast. if addr[unicastMulticastFlagByteIdx]&unicastMulticastFlagMask != 0 { return false } - // addr is a valid unicast ethernet address. return true } diff --git a/pkg/tcpip/header/eth_test.go b/pkg/tcpip/header/eth_test.go index 14413f2ce..3bc8b2b21 100644 --- a/pkg/tcpip/header/eth_test.go +++ b/pkg/tcpip/header/eth_test.go @@ -67,6 +67,53 @@ func TestIsValidUnicastEthernetAddress(t *testing.T) { } } +func TestIsMulticastEthernetAddress(t *testing.T) { + tests := []struct { + name string + addr tcpip.LinkAddress + expected bool + }{ + { + "Nil", + tcpip.LinkAddress([]byte(nil)), + false, + }, + { + "Empty", + tcpip.LinkAddress(""), + false, + }, + { + "InvalidLength", + tcpip.LinkAddress("\x01\x02\x03"), + false, + }, + { + "Unspecified", + unspecifiedEthernetAddress, + false, + }, + { + "Multicast", + tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"), + true, + }, + { + "Unicast", + tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06"), + false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if got := IsMulticastEthernetAddress(test.addr); got != test.expected { + t.Fatalf("got IsMulticastEthernetAddress = %t, want = %t", got, test.expected) + } + }) + } +} + func TestEthernetAddressFromMulticastIPv4Address(t *testing.T) { tests := []struct { name string diff --git a/pkg/tcpip/header/icmpv4.go b/pkg/tcpip/header/icmpv4.go index c00bcadfb..504408878 100644 --- a/pkg/tcpip/header/icmpv4.go +++ b/pkg/tcpip/header/icmpv4.go @@ -126,15 +126,6 @@ func (b ICMPv4) Code() ICMPv4Code { return ICMPv4Code(b[1]) } // SetCode sets the ICMP code field. func (b ICMPv4) SetCode(c ICMPv4Code) { b[1] = byte(c) } -// SetPointer sets the pointer field in a Parameter error packet. -// This is the first byte of the type specific data field. -func (b ICMPv4) SetPointer(c byte) { b[icmpv4PointerOffset] = c } - -// SetTypeSpecific sets the full 32 bit type specific data field. -func (b ICMPv4) SetTypeSpecific(val uint32) { - binary.BigEndian.PutUint32(b[icmpv4PointerOffset:], val) -} - // Checksum is the ICMP checksum field. func (b ICMPv4) Checksum() uint16 { return binary.BigEndian.Uint16(b[icmpv4ChecksumOffset:]) diff --git a/pkg/tcpip/header/icmpv6.go b/pkg/tcpip/header/icmpv6.go index 4eb5abd79..6be31beeb 100644 --- a/pkg/tcpip/header/icmpv6.go +++ b/pkg/tcpip/header/icmpv6.go @@ -156,9 +156,14 @@ const ( // ICMP codes used with Parameter Problem (Type 4). As per RFC 4443 section 3.4. const ( + // ICMPv6ErroneousHeader indicates an erroneous header field was encountered. ICMPv6ErroneousHeader ICMPv6Code = 0 - ICMPv6UnknownHeader ICMPv6Code = 1 - ICMPv6UnknownOption ICMPv6Code = 2 + + // ICMPv6UnknownHeader indicates an unrecognized Next Header type encountered. + ICMPv6UnknownHeader ICMPv6Code = 1 + + // ICMPv6UnknownOption indicates an unrecognized IPv6 option was encountered. + ICMPv6UnknownOption ICMPv6Code = 2 ) // ICMPv6UnusedCode is the code value used with ICMPv6 messages which don't use @@ -177,7 +182,12 @@ func (b ICMPv6) Code() ICMPv6Code { return ICMPv6Code(b[1]) } // SetCode sets the ICMP code field. func (b ICMPv6) SetCode(c ICMPv6Code) { b[1] = byte(c) } -// SetTypeSpecific sets the full 32 bit type specific data field. +// TypeSpecific returns the type specific data field. +func (b ICMPv6) TypeSpecific() uint32 { + return binary.BigEndian.Uint32(b[icmpv6PointerOffset:]) +} + +// SetTypeSpecific sets the type specific data field. func (b ICMPv6) SetTypeSpecific(val uint32) { binary.BigEndian.PutUint32(b[icmpv6PointerOffset:], val) } diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go index b07d9991d..4c6e4be64 100644 --- a/pkg/tcpip/header/ipv4.go +++ b/pkg/tcpip/header/ipv4.go @@ -16,10 +16,29 @@ package header import ( "encoding/binary" + "fmt" "gvisor.dev/gvisor/pkg/tcpip" ) +// RFC 971 defines the fields of the IPv4 header on page 11 using the following +// diagram: ("Figure 4") +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// |Version| IHL |Type of Service| Total Length | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Identification |Flags| Fragment Offset | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Time to Live | Protocol | Header Checksum | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Source Address | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Destination Address | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Options | Padding | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// const ( versIHL = 0 tos = 1 @@ -33,6 +52,7 @@ const ( checksum = 10 srcAddr = 12 dstAddr = 16 + options = 20 ) // IPv4Fields contains the fields of an IPv4 packet. It is used to describe the @@ -76,7 +96,8 @@ type IPv4Fields struct { // IPv4 represents an ipv4 header stored in a byte array. // Most of the methods of IPv4 access to the underlying slice without // checking the boundaries and could panic because of 'index out of range'. -// Always call IsValid() to validate an instance of IPv4 before using other methods. +// Always call IsValid() to validate an instance of IPv4 before using other +// methods. type IPv4 []byte const ( @@ -151,13 +172,44 @@ func IPVersion(b []byte) int { if len(b) < versIHL+1 { return -1 } - return int(b[versIHL] >> 4) + return int(b[versIHL] >> ipVersionShift) } +// RFC 791 page 11 shows the header length (IHL) is in the lower 4 bits +// of the first byte, and is counted in multiples of 4 bytes. +// +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// |Version| IHL |Type of Service| Total Length | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// (...) +// Version: 4 bits +// The Version field indicates the format of the internet header. This +// document describes version 4. +// +// IHL: 4 bits +// Internet Header Length is the length of the internet header in 32 +// bit words, and thus points to the beginning of the data. Note that +// the minimum value for a correct header is 5. +// +const ( + ipVersionShift = 4 + ipIHLMask = 0x0f + IPv4IHLStride = 4 +) + // HeaderLength returns the value of the "header length" field of the ipv4 // header. The length returned is in bytes. func (b IPv4) HeaderLength() uint8 { - return (b[versIHL] & 0xf) * 4 + return (b[versIHL] & ipIHLMask) * IPv4IHLStride +} + +// SetHeaderLength sets the value of the "Internet Header Length" field. +func (b IPv4) SetHeaderLength(hdrLen uint8) { + if hdrLen > IPv4MaximumHeaderSize { + panic(fmt.Sprintf("got IPv4 Header size = %d, want <= %d", hdrLen, IPv4MaximumHeaderSize)) + } + b[versIHL] = (IPv4Version << ipVersionShift) | ((hdrLen / IPv4IHLStride) & ipIHLMask) } // ID returns the value of the identifier field of the ipv4 header. @@ -211,6 +263,12 @@ func (b IPv4) DestinationAddress() tcpip.Address { return tcpip.Address(b[dstAddr : dstAddr+IPv4AddressSize]) } +// Options returns a a buffer holding the options. +func (b IPv4) Options() []byte { + hdrLen := b.HeaderLength() + return b[options:hdrLen:hdrLen] +} + // TransportProtocol implements Network.TransportProtocol. func (b IPv4) TransportProtocol() tcpip.TransportProtocolNumber { return tcpip.TransportProtocolNumber(b.Protocol()) @@ -236,6 +294,11 @@ func (b IPv4) SetTOS(v uint8, _ uint32) { b[tos] = v } +// SetTTL sets the "Time to Live" field of the IPv4 header. +func (b IPv4) SetTTL(v byte) { + b[ttl] = v +} + // SetTotalLength sets the "total length" field of the ipv4 header. func (b IPv4) SetTotalLength(totalLength uint16) { binary.BigEndian.PutUint16(b[IPv4TotalLenOffset:], totalLength) @@ -276,7 +339,7 @@ func (b IPv4) CalculateChecksum() uint16 { // Encode encodes all the fields of the ipv4 header. func (b IPv4) Encode(i *IPv4Fields) { - b[versIHL] = (4 << 4) | ((i.IHL / 4) & 0xf) + b.SetHeaderLength(i.IHL) b[tos] = i.TOS b.SetTotalLength(i.TotalLength) binary.BigEndian.PutUint16(b[id:], i.ID) diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go index 0761a1807..c5d8a3456 100644 --- a/pkg/tcpip/header/ipv6.go +++ b/pkg/tcpip/header/ipv6.go @@ -34,6 +34,9 @@ const ( hopLimit = 7 v6SrcAddr = 8 v6DstAddr = v6SrcAddr + IPv6AddressSize + + // IPv6FixedHeaderSize is the size of the fixed header. + IPv6FixedHeaderSize = v6DstAddr + IPv6AddressSize ) // IPv6Fields contains the fields of an IPv6 packet. It is used to describe the @@ -69,7 +72,7 @@ type IPv6 []byte const ( // IPv6MinimumSize is the minimum size of a valid IPv6 packet. - IPv6MinimumSize = 40 + IPv6MinimumSize = IPv6FixedHeaderSize // IPv6AddressSize is the size, in bytes, of an IPv6 address. IPv6AddressSize = 16 @@ -306,14 +309,21 @@ func IsV6UnicastAddress(addr tcpip.Address) bool { return addr[0] != 0xff } +const solicitedNodeMulticastPrefix = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\xff" + // SolicitedNodeAddr computes the solicited-node multicast address. This is // used for NDP. Described in RFC 4291. The argument must be a full-length IPv6 // address. func SolicitedNodeAddr(addr tcpip.Address) tcpip.Address { - const solicitedNodeMulticastPrefix = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\xff" return solicitedNodeMulticastPrefix + addr[len(addr)-3:] } +// IsSolicitedNodeAddr determines whether the address is a solicited-node +// multicast address. +func IsSolicitedNodeAddr(addr tcpip.Address) bool { + return solicitedNodeMulticastPrefix == addr[:len(addr)-3] +} + // EthernetAdddressToModifiedEUI64IntoBuf populates buf with a modified EUI-64 // from a 48-bit Ethernet/MAC address, as per RFC 4291 section 2.5.1. // diff --git a/pkg/tcpip/header/ipv6_extension_headers.go b/pkg/tcpip/header/ipv6_extension_headers.go index 3499d8399..583c2c5d3 100644 --- a/pkg/tcpip/header/ipv6_extension_headers.go +++ b/pkg/tcpip/header/ipv6_extension_headers.go @@ -149,6 +149,19 @@ func (b ipv6OptionsExtHdr) Iter() IPv6OptionsExtHdrOptionsIterator { // obtained before modification is no longer used. type IPv6OptionsExtHdrOptionsIterator struct { reader bytes.Reader + + // optionOffset is the number of bytes from the first byte of the + // options field to the beginning of the current option. + optionOffset uint32 + + // nextOptionOffset is the offset of the next option. + nextOptionOffset uint32 +} + +// OptionOffset returns the number of bytes parsed while processing the +// option field of the current Extension Header. +func (i *IPv6OptionsExtHdrOptionsIterator) OptionOffset() uint32 { + return i.optionOffset } // IPv6OptionUnknownAction is the action that must be taken if the processing @@ -226,6 +239,7 @@ func (*IPv6UnknownExtHdrOption) isIPv6ExtHdrOption() {} // the options data, or an error occured. func (i *IPv6OptionsExtHdrOptionsIterator) Next() (IPv6ExtHdrOption, bool, error) { for { + i.optionOffset = i.nextOptionOffset temp, err := i.reader.ReadByte() if err != nil { // If we can't read the first byte of a new option, then we know the @@ -238,6 +252,7 @@ func (i *IPv6OptionsExtHdrOptionsIterator) Next() (IPv6ExtHdrOption, bool, error // know the option does not have Length and Data fields. End processing of // the Pad1 option and continue processing the buffer as a new option. if id == ipv6Pad1ExtHdrOptionIdentifier { + i.nextOptionOffset = i.optionOffset + 1 continue } @@ -254,41 +269,40 @@ func (i *IPv6OptionsExtHdrOptionsIterator) Next() (IPv6ExtHdrOption, bool, error return nil, true, fmt.Errorf("error when reading the option's Length field for option with id = %d: %w", id, io.ErrUnexpectedEOF) } - // Special-case the variable length padding option to avoid a copy. - if id == ipv6PadNExtHdrOptionIdentifier { - // Do we have enough bytes in the reader for the PadN option? - if n := i.reader.Len(); n < int(length) { - // Reset the reader to effectively consume the remaining buffer. - i.reader.Reset(nil) - - // We return the same error as if we failed to read a non-padding option - // so consumers of this iterator don't need to differentiate between - // padding and non-padding options. - return nil, true, fmt.Errorf("read %d out of %d option data bytes for option with id = %d: %w", n, length, id, io.ErrUnexpectedEOF) - } + // Do we have enough bytes in the reader for the next option? + if n := i.reader.Len(); n < int(length) { + // Reset the reader to effectively consume the remaining buffer. + i.reader.Reset(nil) + + // We return the same error as if we failed to read a non-padding option + // so consumers of this iterator don't need to differentiate between + // padding and non-padding options. + return nil, true, fmt.Errorf("read %d out of %d option data bytes for option with id = %d: %w", n, length, id, io.ErrUnexpectedEOF) + } + + i.nextOptionOffset = i.optionOffset + uint32(length) + 1 /* option ID */ + 1 /* length byte */ + switch id { + case ipv6PadNExtHdrOptionIdentifier: + // Special-case the variable length padding option to avoid a copy. if _, err := i.reader.Seek(int64(length), io.SeekCurrent); err != nil { panic(fmt.Sprintf("error when skipping PadN (N = %d) option's data bytes: %s", length, err)) } - - // End processing of the PadN option and continue processing the buffer as - // a new option. continue - } - - bytes := make([]byte, length) - if n, err := io.ReadFull(&i.reader, bytes); err != nil { - // io.ReadFull may return io.EOF if i.reader has been exhausted. We use - // io.ErrUnexpectedEOF instead as the io.EOF is unexpected given the - // Length field found in the option. - if err == io.EOF { - err = io.ErrUnexpectedEOF + default: + bytes := make([]byte, length) + if n, err := io.ReadFull(&i.reader, bytes); err != nil { + // io.ReadFull may return io.EOF if i.reader has been exhausted. We use + // io.ErrUnexpectedEOF instead as the io.EOF is unexpected given the + // Length field found in the option. + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + + return nil, true, fmt.Errorf("read %d out of %d option data bytes for option with id = %d: %w", n, length, id, err) } - - return nil, true, fmt.Errorf("read %d out of %d option data bytes for option with id = %d: %w", n, length, id, err) + return &IPv6UnknownExtHdrOption{Identifier: id, Data: bytes}, false, nil } - - return &IPv6UnknownExtHdrOption{Identifier: id, Data: bytes}, false, nil } } @@ -382,6 +396,29 @@ type IPv6PayloadIterator struct { // Indicates to the iterator that it should return the remaining payload as a // raw payload on the next call to Next. forceRaw bool + + // headerOffset is the offset of the beginning of the current extension + // header starting from the beginning of the fixed header. + headerOffset uint32 + + // parseOffset is the byte offset into the current extension header of the + // field we are currently examining. It can be added to the header offset + // if the absolute offset within the packet is required. + parseOffset uint32 + + // nextOffset is the offset of the next header. + nextOffset uint32 +} + +// HeaderOffset returns the offset to the start of the extension +// header most recently processed. +func (i IPv6PayloadIterator) HeaderOffset() uint32 { + return i.headerOffset +} + +// ParseOffset returns the number of bytes successfully parsed. +func (i IPv6PayloadIterator) ParseOffset() uint32 { + return i.headerOffset + i.parseOffset } // MakeIPv6PayloadIterator returns an iterator over the IPv6 payload containing @@ -397,7 +434,8 @@ func MakeIPv6PayloadIterator(nextHdrIdentifier IPv6ExtensionHeaderIdentifier, pa nextHdrIdentifier: nextHdrIdentifier, payload: payload.Clone(nil), // We need a buffer of size 1 for calls to bufio.Reader.ReadByte. - reader: *bufio.NewReaderSize(io.MultiReader(readerPs...), 1), + reader: *bufio.NewReaderSize(io.MultiReader(readerPs...), 1), + nextOffset: IPv6FixedHeaderSize, } } @@ -434,6 +472,8 @@ func (i *IPv6PayloadIterator) AsRawHeader(consume bool) IPv6RawPayloadHeader { // Next is unable to return anything because the iterator has reached the end of // the payload, or an error occured. func (i *IPv6PayloadIterator) Next() (IPv6PayloadHeader, bool, error) { + i.headerOffset = i.nextOffset + i.parseOffset = 0 // We could be forced to return i as a raw header when the previous header was // a fragment extension header as the data following the fragment extension // header may not be complete. @@ -461,7 +501,7 @@ func (i *IPv6PayloadIterator) Next() (IPv6PayloadHeader, bool, error) { return IPv6RoutingExtHdr(bytes), false, nil case IPv6FragmentExtHdrIdentifier: var data [6]byte - // We ignore the returned bytes becauase we know the fragment extension + // We ignore the returned bytes because we know the fragment extension // header specific data will fit in data. nextHdrIdentifier, _, err := i.nextHeaderData(true /* fragmentHdr */, data[:]) if err != nil { @@ -519,10 +559,12 @@ func (i *IPv6PayloadIterator) nextHeaderData(fragmentHdr bool, bytes []byte) (IP if err != nil { return 0, nil, fmt.Errorf("error when reading the Next Header field for extension header with id = %d: %w", i.nextHdrIdentifier, err) } + i.parseOffset++ var length uint8 length, err = i.reader.ReadByte() i.payload.TrimFront(1) + if err != nil { if fragmentHdr { return 0, nil, fmt.Errorf("error when reading the Length field for extension header with id = %d: %w", i.nextHdrIdentifier, err) @@ -534,6 +576,17 @@ func (i *IPv6PayloadIterator) nextHeaderData(fragmentHdr bool, bytes []byte) (IP length = 0 } + // Make parseOffset point to the first byte of the Extension Header + // specific data. + i.parseOffset++ + + // length is in 8 byte chunks but doesn't include the first one. + // See RFC 8200 for each header type, sections 4.3-4.6 and the requirement + // in section 4.8 for new extension headers at the top of page 24. + // [ Hdr Ext Len ] ... Length of the Destination Options header in 8-octet + // units, not including the first 8 octets. + i.nextOffset += uint32((length + 1) * ipv6ExtHdrLenBytesPerUnit) + bytesLen := int(length)*ipv6ExtHdrLenBytesPerUnit + ipv6ExtHdrLenBytesExcluded if bytes == nil { bytes = make([]byte, bytesLen) diff --git a/pkg/tcpip/header/ipversion_test.go b/pkg/tcpip/header/ipversion_test.go index b5540bf66..17a49d4fa 100644 --- a/pkg/tcpip/header/ipversion_test.go +++ b/pkg/tcpip/header/ipversion_test.go @@ -22,7 +22,7 @@ import ( func TestIPv4(t *testing.T) { b := header.IPv4(make([]byte, header.IPv4MinimumSize)) - b.Encode(&header.IPv4Fields{}) + b.Encode(&header.IPv4Fields{IHL: header.IPv4MinimumSize}) const want = header.IPv4Version if v := header.IPVersion(b); v != want { diff --git a/pkg/tcpip/link/pipe/BUILD b/pkg/tcpip/link/pipe/BUILD new file mode 100644 index 000000000..9f31c1ffc --- /dev/null +++ b/pkg/tcpip/link/pipe/BUILD @@ -0,0 +1,15 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "pipe", + srcs = ["pipe.go"], + visibility = ["//visibility:public"], + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/header", + "//pkg/tcpip/stack", + ], +) diff --git a/pkg/tcpip/link/pipe/pipe.go b/pkg/tcpip/link/pipe/pipe.go new file mode 100644 index 000000000..76f563811 --- /dev/null +++ b/pkg/tcpip/link/pipe/pipe.go @@ -0,0 +1,124 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package pipe provides the implementation of pipe-like data-link layer +// endpoints. Such endpoints allow packets to be sent between two interfaces. +package pipe + +import ( + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +var _ stack.LinkEndpoint = (*Endpoint)(nil) + +// New returns both ends of a new pipe. +func New(linkAddr1, linkAddr2 tcpip.LinkAddress, capabilities stack.LinkEndpointCapabilities) (*Endpoint, *Endpoint) { + ep1 := &Endpoint{ + linkAddr: linkAddr1, + capabilities: capabilities, + } + ep2 := &Endpoint{ + linkAddr: linkAddr2, + linked: ep1, + capabilities: capabilities, + } + ep1.linked = ep2 + return ep1, ep2 +} + +// Endpoint is one end of a pipe. +type Endpoint struct { + capabilities stack.LinkEndpointCapabilities + linkAddr tcpip.LinkAddress + dispatcher stack.NetworkDispatcher + linked *Endpoint + onWritePacket func(*stack.PacketBuffer) +} + +// WritePacket implements stack.LinkEndpoint. +func (e *Endpoint) WritePacket(r *stack.Route, _ *stack.GSO, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { + if !e.linked.IsAttached() { + return nil + } + + // The pipe endpoint will accept all multicast/broadcast link traffic and only + // unicast traffic destined to itself. + if len(e.linked.linkAddr) != 0 && + r.RemoteLinkAddress != e.linked.linkAddr && + r.RemoteLinkAddress != header.EthernetBroadcastAddress && + !header.IsMulticastEthernetAddress(r.RemoteLinkAddress) { + return nil + } + + e.linked.dispatcher.DeliverNetworkPacket(e.linkAddr, r.RemoteLinkAddress, proto, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()), + })) + + return nil +} + +// WritePackets implements stack.LinkEndpoint. +func (*Endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + panic("not implemented") +} + +// WriteRawPacket implements stack.LinkEndpoint. +func (*Endpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error { + panic("not implemented") +} + +// Attach implements stack.LinkEndpoint. +func (e *Endpoint) Attach(dispatcher stack.NetworkDispatcher) { + e.dispatcher = dispatcher +} + +// IsAttached implements stack.LinkEndpoint. +func (e *Endpoint) IsAttached() bool { + return e.dispatcher != nil +} + +// Wait implements stack.LinkEndpoint. +func (*Endpoint) Wait() {} + +// MTU implements stack.LinkEndpoint. +func (*Endpoint) MTU() uint32 { + return header.IPv6MinimumMTU +} + +// Capabilities implements stack.LinkEndpoint. +func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities { + return e.capabilities +} + +// MaxHeaderLength implements stack.LinkEndpoint. +func (*Endpoint) MaxHeaderLength() uint16 { + return 0 +} + +// LinkAddress implements stack.LinkEndpoint. +func (e *Endpoint) LinkAddress() tcpip.LinkAddress { + return e.linkAddr +} + +// ARPHardwareType implements stack.LinkEndpoint. +func (*Endpoint) ARPHardwareType() header.ARPHardwareType { + return header.ARPHardwareEther +} + +// AddHeader implements stack.LinkEndpoint. +func (*Endpoint) AddHeader(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) { +} diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go index b6ddbe81e..f94491026 100644 --- a/pkg/tcpip/link/tun/device.go +++ b/pkg/tcpip/link/tun/device.go @@ -76,13 +76,29 @@ func (d *Device) Release(ctx context.Context) { } } +// NICID returns the NIC ID of the device. +// +// Must only be called after the device has been attached to an endpoint. +func (d *Device) NICID() tcpip.NICID { + d.mu.RLock() + defer d.mu.RUnlock() + + if d.endpoint == nil { + panic("called NICID on a device that has not been attached") + } + + return d.endpoint.nicID +} + // SetIff services TUNSETIFF ioctl(2) request. -func (d *Device) SetIff(s *stack.Stack, name string, flags uint16) error { +// +// Returns true if a new NIC was created; false if an existing one was attached. +func (d *Device) SetIff(s *stack.Stack, name string, flags uint16) (bool, error) { d.mu.Lock() defer d.mu.Unlock() if d.endpoint != nil { - return syserror.EINVAL + return false, syserror.EINVAL } // Input validations. @@ -90,7 +106,7 @@ func (d *Device) SetIff(s *stack.Stack, name string, flags uint16) error { isTap := flags&linux.IFF_TAP != 0 supportedFlags := uint16(linux.IFF_TUN | linux.IFF_TAP | linux.IFF_NO_PI) if isTap && isTun || !isTap && !isTun || flags&^supportedFlags != 0 { - return syserror.EINVAL + return false, syserror.EINVAL } prefix := "tun" @@ -103,32 +119,32 @@ func (d *Device) SetIff(s *stack.Stack, name string, flags uint16) error { linkCaps |= stack.CapabilityResolutionRequired } - endpoint, err := attachOrCreateNIC(s, name, prefix, linkCaps) + endpoint, created, err := attachOrCreateNIC(s, name, prefix, linkCaps) if err != nil { - return syserror.EINVAL + return false, syserror.EINVAL } d.endpoint = endpoint d.notifyHandle = d.endpoint.AddNotify(d) d.flags = flags - return nil + return created, nil } -func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkEndpointCapabilities) (*tunEndpoint, error) { +func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkEndpointCapabilities) (*tunEndpoint, bool, error) { for { // 1. Try to attach to an existing NIC. if name != "" { - if nic, found := s.GetNICByName(name); found { - endpoint, ok := nic.LinkEndpoint().(*tunEndpoint) + if linkEP := s.GetLinkEndpointByName(name); linkEP != nil { + endpoint, ok := linkEP.(*tunEndpoint) if !ok { // Not a NIC created by tun device. - return nil, syserror.EOPNOTSUPP + return nil, false, syserror.EOPNOTSUPP } if !endpoint.TryIncRef() { // Race detected: NIC got deleted in between. continue } - return endpoint, nil + return endpoint, false, nil } } @@ -151,12 +167,12 @@ func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkE }) switch err { case nil: - return endpoint, nil + return endpoint, true, nil case tcpip.ErrDuplicateNICID: // Race detected: A NIC has been created in between. continue default: - return nil, syserror.EINVAL + return nil, false, syserror.EINVAL } } } diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index b47a7be51..7df77c66e 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -49,7 +49,6 @@ type endpoint struct { enabled uint32 nic stack.NetworkInterface - linkEP stack.LinkEndpoint linkAddrCache stack.LinkAddressCache nud stack.NUDHandler } @@ -92,12 +91,12 @@ func (e *endpoint) DefaultTTL() uint8 { } func (e *endpoint) MTU() uint32 { - lmtu := e.linkEP.MTU() + lmtu := e.nic.MTU() return lmtu - uint32(e.MaxHeaderLength()) } func (e *endpoint) MaxHeaderLength() uint16 { - return e.linkEP.MaxHeaderLength() + header.ARPSize + return e.nic.MaxHeaderLength() + header.ARPSize } func (e *endpoint) Close() { @@ -154,17 +153,25 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { e.nud.HandleProbe(remoteAddr, localAddr, ProtocolNumber, remoteLinkAddr, e.protocol) } - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(e.linkEP.MaxHeaderLength()) + header.ARPSize, + // As per RFC 826, under Packet Reception: + // Swap hardware and protocol fields, putting the local hardware and + // protocol addresses in the sender fields. + // + // Send the packet to the (new) target hardware address on the same + // hardware on which the request was received. + origSender := h.HardwareAddressSender() + r.RemoteLinkAddress = tcpip.LinkAddress(origSender) + respPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(e.nic.MaxHeaderLength()) + header.ARPSize, }) - packet := header.ARP(pkt.NetworkHeader().Push(header.ARPSize)) + packet := header.ARP(respPkt.NetworkHeader().Push(header.ARPSize)) packet.SetIPv4OverEthernet() packet.SetOp(header.ARPReply) copy(packet.HardwareAddressSender(), r.LocalLinkAddress[:]) copy(packet.ProtocolAddressSender(), h.ProtocolAddressTarget()) - copy(packet.HardwareAddressTarget(), h.HardwareAddressSender()) + copy(packet.HardwareAddressTarget(), origSender) copy(packet.ProtocolAddressTarget(), h.ProtocolAddressSender()) - _ = e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt) + _ = e.nic.WritePacket(r, nil /* gso */, ProtocolNumber, respPkt) case header.ARPReply: addr := tcpip.Address(h.ProtocolAddressSender()) @@ -207,7 +214,6 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.L e := &endpoint{ protocol: p, nic: nic, - linkEP: nic.LinkEndpoint(), linkAddrCache: linkAddrCache, nud: nud, } @@ -223,6 +229,7 @@ func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { // LinkAddressRequest implements stack.LinkAddressResolver.LinkAddressRequest. func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP stack.LinkEndpoint) *tcpip.Error { r := &stack.Route{ + NetProto: ProtocolNumber, RemoteLinkAddress: remoteLinkAddr, } if len(r.RemoteLinkAddress) == 0 { diff --git a/pkg/tcpip/network/fragmentation/BUILD b/pkg/tcpip/network/fragmentation/BUILD index e247f06a4..47fb63290 100644 --- a/pkg/tcpip/network/fragmentation/BUILD +++ b/pkg/tcpip/network/fragmentation/BUILD @@ -29,6 +29,8 @@ go_library( "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/header", + "//pkg/tcpip/stack", ], ) @@ -44,5 +46,7 @@ go_test( deps = [ "//pkg/tcpip/buffer", "//pkg/tcpip/faketime", + "//pkg/tcpip/network/testutil", + "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go index e1909fab0..ed502a473 100644 --- a/pkg/tcpip/network/fragmentation/fragmentation.go +++ b/pkg/tcpip/network/fragmentation/fragmentation.go @@ -13,7 +13,7 @@ // limitations under the License. // Package fragmentation contains the implementation of IP fragmentation. -// It is based on RFC 791 and RFC 815. +// It is based on RFC 791, RFC 815 and RFC 8200. package fragmentation import ( @@ -25,12 +25,10 @@ import ( "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/stack" ) const ( - // DefaultReassembleTimeout is based on the linux stack: net.ipv4.ipfrag_time. - DefaultReassembleTimeout = 30 * time.Second - // HighFragThreshold is the threshold at which we start trimming old // fragmented packets. Linux uses a default value of 4 MB. See // net.ipv4.ipfrag_high_thresh for more information. @@ -243,3 +241,78 @@ func (f *Fragmentation) releaseReassemblersLocked() { f.release(r) } } + +// PacketFragmenter is the book-keeping struct for packet fragmentation. +type PacketFragmenter struct { + transportHeader buffer.View + data buffer.VectorisedView + reserve int + innerMTU int + fragmentCount int + currentFragment int + fragmentOffset int +} + +// MakePacketFragmenter prepares the struct needed for packet fragmentation. +// +// pkt is the packet to be fragmented. +// +// innerMTU is the maximum number of bytes of fragmentable data a fragment can +// have. +// +// reserve is the number of bytes that should be reserved for the headers in +// each generated fragment. +func MakePacketFragmenter(pkt *stack.PacketBuffer, innerMTU int, reserve int) PacketFragmenter { + // As per RFC 8200 Section 4.5, some IPv6 extension headers should not be + // repeated in each fragment. However we do not currently support any header + // of that kind yet, so the following computation is valid for both IPv4 and + // IPv6. + // TODO(gvisor.dev/issue/3912): Once Authentication or ESP Headers are + // supported for outbound packets, the fragmentable data should not include + // these headers. + var fragmentableData buffer.VectorisedView + fragmentableData.AppendView(pkt.TransportHeader().View()) + fragmentableData.Append(pkt.Data) + fragmentCount := (fragmentableData.Size() + innerMTU - 1) / innerMTU + + return PacketFragmenter{ + data: fragmentableData, + reserve: reserve, + innerMTU: innerMTU, + fragmentCount: fragmentCount, + } +} + +// BuildNextFragment returns a packet with the payload of the next fragment, +// along with the fragment's offset, the number of bytes copied and a boolean +// indicating if there are more fragments left or not. If this function is +// called again after it indicated that no more fragments were left, it will +// panic. +// +// Note that the returned packet will not have its network and link headers +// populated, but space for them will be reserved. The transport header will be +// stored in the packet's data. +func (pf *PacketFragmenter) BuildNextFragment() (*stack.PacketBuffer, int, int, bool) { + if pf.currentFragment >= pf.fragmentCount { + panic("BuildNextFragment should not be called again after the last fragment was returned") + } + + fragPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: pf.reserve, + }) + + // Copy data for the fragment. + copied := pf.data.ReadToVV(&fragPkt.Data, pf.innerMTU) + + offset := pf.fragmentOffset + pf.fragmentOffset += copied + pf.currentFragment++ + more := pf.currentFragment != pf.fragmentCount + + return fragPkt, offset, copied, more +} + +// RemainingFragmentCount returns the number of fragments left to be built. +func (pf *PacketFragmenter) RemainingFragmentCount() int { + return pf.fragmentCount - pf.currentFragment +} diff --git a/pkg/tcpip/network/fragmentation/fragmentation_test.go b/pkg/tcpip/network/fragmentation/fragmentation_test.go index 189b223c5..d3c7d7f92 100644 --- a/pkg/tcpip/network/fragmentation/fragmentation_test.go +++ b/pkg/tcpip/network/fragmentation/fragmentation_test.go @@ -20,10 +20,16 @@ import ( "testing" "time" + "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/faketime" + "gvisor.dev/gvisor/pkg/tcpip/network/testutil" ) +// reassembleTimeout is dummy timeout used for testing, where the clock never +// advances. +const reassembleTimeout = 1 + // vv is a helper to build VectorisedView from different strings. func vv(size int, pieces ...string) buffer.VectorisedView { views := make([]buffer.View, len(pieces)) @@ -96,7 +102,7 @@ var processTestCases = []struct { func TestFragmentationProcess(t *testing.T) { for _, c := range processTestCases { t.Run(c.comment, func(t *testing.T) { - f := NewFragmentation(minBlockSize, 1024, 512, DefaultReassembleTimeout, &faketime.NullClock{}) + f := NewFragmentation(minBlockSize, 1024, 512, reassembleTimeout, &faketime.NullClock{}) firstFragmentProto := c.in[0].proto for i, in := range c.in { vv, proto, done, err := f.Process(in.id, in.first, in.last, in.more, in.proto, in.vv) @@ -251,7 +257,7 @@ func TestReassemblingTimeout(t *testing.T) { } func TestMemoryLimits(t *testing.T) { - f := NewFragmentation(minBlockSize, 3, 1, DefaultReassembleTimeout, &faketime.NullClock{}) + f := NewFragmentation(minBlockSize, 3, 1, reassembleTimeout, &faketime.NullClock{}) // Send first fragment with id = 0. f.Process(FragmentID{ID: 0}, 0, 0, true, 0xFF, vv(1, "0")) // Send first fragment with id = 1. @@ -275,7 +281,7 @@ func TestMemoryLimits(t *testing.T) { } func TestMemoryLimitsIgnoresDuplicates(t *testing.T) { - f := NewFragmentation(minBlockSize, 1, 0, DefaultReassembleTimeout, &faketime.NullClock{}) + f := NewFragmentation(minBlockSize, 1, 0, reassembleTimeout, &faketime.NullClock{}) // Send first fragment with id = 0. f.Process(FragmentID{}, 0, 0, true, 0xFF, vv(1, "0")) // Send the same packet again. @@ -370,7 +376,7 @@ func TestErrors(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - f := NewFragmentation(test.blockSize, HighFragThreshold, LowFragThreshold, DefaultReassembleTimeout, &faketime.NullClock{}) + f := NewFragmentation(test.blockSize, HighFragThreshold, LowFragThreshold, reassembleTimeout, &faketime.NullClock{}) _, _, done, err := f.Process(FragmentID{}, test.first, test.last, test.more, 0, vv(len(test.data), test.data)) if !errors.Is(err, test.err) { t.Errorf("got Process(_, %d, %d, %t, _, %q) = (_, _, _, %v), want = (_, _, _, %v)", test.first, test.last, test.more, test.data, err, test.err) @@ -381,3 +387,113 @@ func TestErrors(t *testing.T) { }) } } + +type fragmentInfo struct { + remaining int + copied int + offset int + more bool +} + +func TestPacketFragmenter(t *testing.T) { + const ( + reserve = 60 + proto = 0 + ) + + tests := []struct { + name string + innerMTU int + transportHeaderLen int + payloadSize int + wantFragments []fragmentInfo + }{ + { + name: "Packet exactly fits in MTU", + innerMTU: 1280, + transportHeaderLen: 0, + payloadSize: 1280, + wantFragments: []fragmentInfo{ + {remaining: 0, copied: 1280, offset: 0, more: false}, + }, + }, + { + name: "Packet exactly does not fit in MTU", + innerMTU: 1000, + transportHeaderLen: 0, + payloadSize: 1001, + wantFragments: []fragmentInfo{ + {remaining: 1, copied: 1000, offset: 0, more: true}, + {remaining: 0, copied: 1, offset: 1000, more: false}, + }, + }, + { + name: "Packet has a transport header", + innerMTU: 560, + transportHeaderLen: 40, + payloadSize: 560, + wantFragments: []fragmentInfo{ + {remaining: 1, copied: 560, offset: 0, more: true}, + {remaining: 0, copied: 40, offset: 560, more: false}, + }, + }, + { + name: "Packet has a huge transport header", + innerMTU: 500, + transportHeaderLen: 1300, + payloadSize: 500, + wantFragments: []fragmentInfo{ + {remaining: 3, copied: 500, offset: 0, more: true}, + {remaining: 2, copied: 500, offset: 500, more: true}, + {remaining: 1, copied: 500, offset: 1000, more: true}, + {remaining: 0, copied: 300, offset: 1500, more: false}, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + pkt := testutil.MakeRandPkt(test.transportHeaderLen, reserve, []int{test.payloadSize}, proto) + var originalPayload buffer.VectorisedView + originalPayload.AppendView(pkt.TransportHeader().View()) + originalPayload.Append(pkt.Data) + var reassembledPayload buffer.VectorisedView + pf := MakePacketFragmenter(pkt, test.innerMTU, reserve) + for i := 0; ; i++ { + fragPkt, offset, copied, more := pf.BuildNextFragment() + wantFragment := test.wantFragments[i] + if got := pf.RemainingFragmentCount(); got != wantFragment.remaining { + t.Errorf("(fragment #%d) got pf.RemainingFragmentCount() = %d, want = %d", i, got, wantFragment.remaining) + } + if copied != wantFragment.copied { + t.Errorf("(fragment #%d) got copied = %d, want = %d", i, copied, wantFragment.copied) + } + if offset != wantFragment.offset { + t.Errorf("(fragment #%d) got offset = %d, want = %d", i, offset, wantFragment.offset) + } + if more != wantFragment.more { + t.Errorf("(fragment #%d) got more = %t, want = %t", i, more, wantFragment.more) + } + if got := fragPkt.Size(); got > test.innerMTU { + t.Errorf("(fragment #%d) got fragPkt.Size() = %d, want <= %d", i, got, test.innerMTU) + } + if got := fragPkt.AvailableHeaderBytes(); got != reserve { + t.Errorf("(fragment #%d) got fragPkt.AvailableHeaderBytes() = %d, want = %d", i, got, reserve) + } + if got := fragPkt.TransportHeader().View().Size(); got != 0 { + t.Errorf("(fragment #%d) got fragPkt.TransportHeader().View().Size() = %d, want = 0", i, got) + } + reassembledPayload.Append(fragPkt.Data) + if !more { + if i != len(test.wantFragments)-1 { + t.Errorf("got fragment count = %d, want = %d", i, len(test.wantFragments)-1) + } + break + } + } + if diff := cmp.Diff(reassembledPayload.ToView(), originalPayload.ToView()); diff != "" { + t.Errorf("reassembledPayload mismatch (-want +got):\n%s", diff) + } + }) + } +} diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 6861cfdaf..d436873b6 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -270,7 +270,7 @@ func buildDummyStack(t *testing.T) *stack.Stack { var _ stack.NetworkInterface = (*testInterface)(nil) type testInterface struct { - tester testObject + testObject mu struct { sync.RWMutex @@ -302,10 +302,6 @@ func (t *testInterface) setEnabled(v bool) { t.mu.disabled = !v } -func (t *testInterface) LinkEndpoint() stack.LinkEndpoint { - return &t.tester -} - func TestSourceAddressValidation(t *testing.T) { rxIPv4ICMP := func(e *channel.Endpoint, src tcpip.Address) { totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize @@ -517,7 +513,7 @@ func TestIPv4Send(t *testing.T) { s := buildDummyStack(t) proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber) nic := testInterface{ - tester: testObject{ + testObject: testObject{ t: t, v4: true, }, @@ -538,10 +534,10 @@ func TestIPv4Send(t *testing.T) { }) // Issue the write. - nic.tester.protocol = 123 - nic.tester.srcAddr = localIPv4Addr - nic.tester.dstAddr = remoteIPv4Addr - nic.tester.contents = payload + nic.testObject.protocol = 123 + nic.testObject.srcAddr = localIPv4Addr + nic.testObject.dstAddr = remoteIPv4Addr + nic.testObject.contents = payload r, err := buildIPv4Route(localIPv4Addr, remoteIPv4Addr) if err != nil { @@ -560,12 +556,12 @@ func TestIPv4Receive(t *testing.T) { s := buildDummyStack(t) proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber) nic := testInterface{ - tester: testObject{ + testObject: testObject{ t: t, v4: true, }, } - ep := proto.NewEndpoint(&nic, nil, nil, &nic.tester) + ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject) defer ep.Close() if err := ep.Enable(); err != nil { @@ -590,10 +586,10 @@ func TestIPv4Receive(t *testing.T) { } // Give packet to ipv4 endpoint, dispatcher will validate that it's ok. - nic.tester.protocol = 10 - nic.tester.srcAddr = remoteIPv4Addr - nic.tester.dstAddr = localIPv4Addr - nic.tester.contents = view[header.IPv4MinimumSize:totalLen] + nic.testObject.protocol = 10 + nic.testObject.srcAddr = remoteIPv4Addr + nic.testObject.dstAddr = localIPv4Addr + nic.testObject.contents = view[header.IPv4MinimumSize:totalLen] r, err := buildIPv4Route(localIPv4Addr, remoteIPv4Addr) if err != nil { @@ -606,8 +602,8 @@ func TestIPv4Receive(t *testing.T) { t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) } ep.HandlePacket(&r, pkt) - if nic.tester.dataCalls != 1 { - t.Fatalf("Bad number of data calls: got %x, want 1", nic.tester.dataCalls) + if nic.testObject.dataCalls != 1 { + t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls) } } @@ -640,11 +636,11 @@ func TestIPv4ReceiveControl(t *testing.T) { s := buildDummyStack(t) proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber) nic := testInterface{ - tester: testObject{ + testObject: testObject{ t: t, }, } - ep := proto.NewEndpoint(&nic, nil, nil, &nic.tester) + ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject) defer ep.Close() if err := ep.Enable(); err != nil { @@ -691,16 +687,16 @@ func TestIPv4ReceiveControl(t *testing.T) { // Give packet to IPv4 endpoint, dispatcher will validate that // it's ok. - nic.tester.protocol = 10 - nic.tester.srcAddr = remoteIPv4Addr - nic.tester.dstAddr = localIPv4Addr - nic.tester.contents = view[dataOffset:] - nic.tester.typ = c.expectedTyp - nic.tester.extra = c.expectedExtra + nic.testObject.protocol = 10 + nic.testObject.srcAddr = remoteIPv4Addr + nic.testObject.dstAddr = localIPv4Addr + nic.testObject.contents = view[dataOffset:] + nic.testObject.typ = c.expectedTyp + nic.testObject.extra = c.expectedExtra ep.HandlePacket(&r, truncatedPacket(view, c.trunc, header.IPv4MinimumSize)) - if want := c.expectedCount; nic.tester.controlCalls != want { - t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.tester.controlCalls, want) + if want := c.expectedCount; nic.testObject.controlCalls != want { + t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.testObject.controlCalls, want) } }) } @@ -710,12 +706,12 @@ func TestIPv4FragmentationReceive(t *testing.T) { s := buildDummyStack(t) proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber) nic := testInterface{ - tester: testObject{ + testObject: testObject{ t: t, v4: true, }, } - ep := proto.NewEndpoint(&nic, nil, nil, &nic.tester) + ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject) defer ep.Close() if err := ep.Enable(); err != nil { @@ -758,10 +754,10 @@ func TestIPv4FragmentationReceive(t *testing.T) { } // Give packet to ipv4 endpoint, dispatcher will validate that it's ok. - nic.tester.protocol = 10 - nic.tester.srcAddr = remoteIPv4Addr - nic.tester.dstAddr = localIPv4Addr - nic.tester.contents = append(frag1[header.IPv4MinimumSize:totalLen], frag2[header.IPv4MinimumSize:totalLen]...) + nic.testObject.protocol = 10 + nic.testObject.srcAddr = remoteIPv4Addr + nic.testObject.dstAddr = localIPv4Addr + nic.testObject.contents = append(frag1[header.IPv4MinimumSize:totalLen], frag2[header.IPv4MinimumSize:totalLen]...) r, err := buildIPv4Route(localIPv4Addr, remoteIPv4Addr) if err != nil { @@ -776,8 +772,8 @@ func TestIPv4FragmentationReceive(t *testing.T) { t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) } ep.HandlePacket(&r, pkt) - if nic.tester.dataCalls != 0 { - t.Fatalf("Bad number of data calls: got %x, want 0", nic.tester.dataCalls) + if nic.testObject.dataCalls != 0 { + t.Fatalf("Bad number of data calls: got %x, want 0", nic.testObject.dataCalls) } // Send second segment. @@ -788,8 +784,8 @@ func TestIPv4FragmentationReceive(t *testing.T) { t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) } ep.HandlePacket(&r, pkt) - if nic.tester.dataCalls != 1 { - t.Fatalf("Bad number of data calls: got %x, want 1", nic.tester.dataCalls) + if nic.testObject.dataCalls != 1 { + t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls) } } @@ -797,7 +793,7 @@ func TestIPv6Send(t *testing.T) { s := buildDummyStack(t) proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber) nic := testInterface{ - tester: testObject{ + testObject: testObject{ t: t, }, } @@ -821,10 +817,10 @@ func TestIPv6Send(t *testing.T) { }) // Issue the write. - nic.tester.protocol = 123 - nic.tester.srcAddr = localIPv6Addr - nic.tester.dstAddr = remoteIPv6Addr - nic.tester.contents = payload + nic.testObject.protocol = 123 + nic.testObject.srcAddr = localIPv6Addr + nic.testObject.dstAddr = remoteIPv6Addr + nic.testObject.contents = payload r, err := buildIPv6Route(localIPv6Addr, remoteIPv6Addr) if err != nil { @@ -843,11 +839,11 @@ func TestIPv6Receive(t *testing.T) { s := buildDummyStack(t) proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber) nic := testInterface{ - tester: testObject{ + testObject: testObject{ t: t, }, } - ep := proto.NewEndpoint(&nic, nil, nil, &nic.tester) + ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject) defer ep.Close() if err := ep.Enable(); err != nil { @@ -871,10 +867,10 @@ func TestIPv6Receive(t *testing.T) { } // Give packet to ipv6 endpoint, dispatcher will validate that it's ok. - nic.tester.protocol = 10 - nic.tester.srcAddr = remoteIPv6Addr - nic.tester.dstAddr = localIPv6Addr - nic.tester.contents = view[header.IPv6MinimumSize:totalLen] + nic.testObject.protocol = 10 + nic.testObject.srcAddr = remoteIPv6Addr + nic.testObject.dstAddr = localIPv6Addr + nic.testObject.contents = view[header.IPv6MinimumSize:totalLen] r, err := buildIPv6Route(localIPv6Addr, remoteIPv6Addr) if err != nil { @@ -888,8 +884,8 @@ func TestIPv6Receive(t *testing.T) { t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) } ep.HandlePacket(&r, pkt) - if nic.tester.dataCalls != 1 { - t.Fatalf("Bad number of data calls: got %x, want 1", nic.tester.dataCalls) + if nic.testObject.dataCalls != 1 { + t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls) } } @@ -931,11 +927,11 @@ func TestIPv6ReceiveControl(t *testing.T) { s := buildDummyStack(t) proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber) nic := testInterface{ - tester: testObject{ + testObject: testObject{ t: t, }, } - ep := proto.NewEndpoint(&nic, nil, nil, &nic.tester) + ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject) defer ep.Close() if err := ep.Enable(); err != nil { @@ -994,19 +990,19 @@ func TestIPv6ReceiveControl(t *testing.T) { // Give packet to IPv6 endpoint, dispatcher will validate that // it's ok. - nic.tester.protocol = 10 - nic.tester.srcAddr = remoteIPv6Addr - nic.tester.dstAddr = localIPv6Addr - nic.tester.contents = view[dataOffset:] - nic.tester.typ = c.expectedTyp - nic.tester.extra = c.expectedExtra + nic.testObject.protocol = 10 + nic.testObject.srcAddr = remoteIPv6Addr + nic.testObject.dstAddr = localIPv6Addr + nic.testObject.contents = view[dataOffset:] + nic.testObject.typ = c.expectedTyp + nic.testObject.extra = c.expectedExtra // Set ICMPv6 checksum. icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIPv6Addr, buffer.VectorisedView{})) ep.HandlePacket(&r, truncatedPacket(view, c.trunc, header.IPv6MinimumSize)) - if want := c.expectedCount; nic.tester.controlCalls != want { - t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.tester.controlCalls, want) + if want := c.expectedCount; nic.testObject.controlCalls != want { + t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.testObject.controlCalls, want) } }) } diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD index 0a7e98ed1..7fc12e229 100644 --- a/pkg/tcpip/network/ipv4/BUILD +++ b/pkg/tcpip/network/ipv4/BUILD @@ -28,12 +28,15 @@ go_test( deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/checker", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/sniffer", + "//pkg/tcpip/network/arp", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/testutil", "//pkg/tcpip/stack", + "//pkg/tcpip/transport/icmp", "//pkg/tcpip/transport/tcp", "//pkg/tcpip/transport/udp", "//pkg/waiter", diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index 5c4f715d7..3407755ed 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -15,6 +15,8 @@ package ipv4 import ( + "fmt" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -77,31 +79,29 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { received.Echo.Increment() // Only send a reply if the checksum is valid. - wantChecksum := h.Checksum() - // Reset the checksum field to 0 to can calculate the proper - // checksum. We'll have to reset this before we hand the packet - // off. + headerChecksum := h.Checksum() h.SetChecksum(0) - gotChecksum := ^header.ChecksumVV(pkt.Data, 0 /* initial */) - if gotChecksum != wantChecksum { - // It's possible that a raw socket expects to receive this. - h.SetChecksum(wantChecksum) + calculatedChecksum := ^header.ChecksumVV(pkt.Data, 0 /* initial */) + h.SetChecksum(headerChecksum) + if calculatedChecksum != headerChecksum { + // It's possible that a raw socket still expects to receive this. e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt) received.Invalid.Increment() return } - // Make a copy of data before pkt gets sent to raw socket. - // DeliverTransportPacket will take ownership of pkt. - replyData := pkt.Data.Clone(nil) - replyData.TrimFront(header.ICMPv4MinimumSize) + // DeliverTransportPacket will take ownership of pkt so don't use it beyond + // this point. Make a deep copy of the data before pkt gets sent as we will + // be modifying fields. + // + // TODO(gvisor.dev/issue/4399): The copy may not be needed if there are no + // waiting endpoints. Consider moving responsibility for doing the copy to + // DeliverTransportPacket so that is is only done when needed. + replyData := pkt.Data.ToOwnedView() + replyIPHdr := header.IPv4(append(buffer.View(nil), pkt.NetworkHeader().View()...)) - // It's possible that a raw socket expects to receive this. - h.SetChecksum(wantChecksum) e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt) - remoteLinkAddr := r.RemoteLinkAddress - // As per RFC 1122 section 3.2.1.3, when a host sends any datagram, the IP // source address MUST be one of its own IP addresses (but not a broadcast // or multicast address). @@ -117,32 +117,49 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { } defer r.Release() - // Use the remote link address from the incoming packet. - r.ResolveWith(remoteLinkAddr) - - // Prepare a reply packet. - icmpHdr := make(header.ICMPv4, header.ICMPv4MinimumSize) - copy(icmpHdr, h) - icmpHdr.SetType(header.ICMPv4EchoReply) - icmpHdr.SetChecksum(0) - icmpHdr.SetChecksum(^header.Checksum(icmpHdr, header.ChecksumVV(replyData, 0))) - dataVV := buffer.View(icmpHdr).ToVectorisedView() - dataVV.Append(replyData) + // TODO(gvisor.dev/issue/3810:) When adding protocol numbers into the + // header information, we may have to change this code to handle the + // ICMP header no longer being in the data buffer. + + // Because IP and ICMP are so closely intertwined, we need to handcraft our + // IP header to be able to follow RFC 792. The wording on page 13 is as + // follows: + // IP Fields: + // Addresses + // The address of the source in an echo message will be the + // destination of the echo reply message. To form an echo reply + // message, the source and destination addresses are simply reversed, + // the type code changed to 0, and the checksum recomputed. + // + // This was interpreted by early implementors to mean that all options must + // be copied from the echo request IP header to the echo reply IP header + // and this behaviour is still relied upon by some applications. + // + // Create a copy of the IP header we received, options and all, and change + // The fields we need to alter. + // + // We need to produce the entire packet in the data segment in order to + // use WriteHeaderIncludedPacket(). + replyIPHdr.SetSourceAddress(r.LocalAddress) + replyIPHdr.SetDestinationAddress(r.RemoteAddress) + replyIPHdr.SetTTL(r.DefaultTTL()) + + replyICMPHdr := header.ICMPv4(replyData) + replyICMPHdr.SetType(header.ICMPv4EchoReply) + replyICMPHdr.SetChecksum(0) + replyICMPHdr.SetChecksum(^header.Checksum(replyData, 0)) + + replyVV := buffer.View(replyIPHdr).ToVectorisedView() + replyVV.AppendView(replyData) replyPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: int(r.MaxHeaderLength()), - Data: dataVV, + Data: replyVV, }) - // TODO(gvisor.dev/issue/3810): When adding protocol numbers into the header - // information we will have to change this code to handle the ICMP header - // no longer being in the data buffer. replyPkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber - // Send out the reply packet. + + // The checksum will be calculated so we don't need to do it here. sent := stats.ICMP.V4PacketsSent - if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{ - Protocol: header.ICMPv4ProtocolNumber, - TTL: r.DefaultTTL(), - TOS: stack.DefaultTOS, - }, replyPkt); err != nil { + if err := r.WriteHeaderIncludedPacket(replyPkt); err != nil { sent.Dropped.Increment() return } @@ -211,18 +228,18 @@ type icmpReasonPortUnreachable struct{} func (*icmpReasonPortUnreachable) isICMPReason() {} +// icmpReasonProtoUnreachable is an error where the transport protocol is +// not supported. +type icmpReasonProtoUnreachable struct{} + +func (*icmpReasonProtoUnreachable) isICMPReason() {} + // returnError takes an error descriptor and generates the appropriate ICMP // error packet for IPv4 and sends it back to the remote device that sent // the problematic packet. It incorporates as much of that packet as // possible as well as any error metadata as is available. returnError // expects pkt to hold a valid IPv4 packet as per the wire format. -func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error { - sent := r.Stats().ICMP.V4PacketsSent - if !r.Stack().AllowICMPMessage() { - sent.RateLimited.Increment() - return nil - } - +func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error { // We check we are responding only when we are allowed to. // See RFC 1812 section 4.3.2.7 (shown below). // @@ -251,6 +268,25 @@ func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tc return nil } + // Even if we were able to receive a packet from some remote, we may not have + // a route to it - the remote may be blocked via routing rules. We must always + // consult our routing table and find a route to the remote before sending any + // packet. + route, err := p.stack.FindRoute(r.NICID(), r.LocalAddress, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */) + if err != nil { + return err + } + defer route.Release() + // From this point on, the incoming route should no longer be used; route + // must be used to send the ICMP error. + r = nil + + sent := p.stack.Stats().ICMP.V4PacketsSent + if !p.stack.AllowICMPMessage() { + sent.RateLimited.Increment() + return nil + } + networkHeader := pkt.NetworkHeader().View() transportHeader := pkt.TransportHeader().View() @@ -287,8 +323,6 @@ func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tc // Assume any type we don't know about may be an error type. return nil } - } else if transportHeader.IsEmpty() { - return nil } // Now work out how much of the triggering packet we should return. @@ -303,11 +337,11 @@ func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tc // least 8 bytes of the payload must be included. Today linux and other // systems implement the RFC 1812 definition and not the original // requirement. We treat 8 bytes as the minimum but will try send more. - mtu := int(r.MTU()) + mtu := int(route.MTU()) if mtu > header.IPv4MinimumProcessableDatagramSize { mtu = header.IPv4MinimumProcessableDatagramSize } - headerLen := int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize + headerLen := int(route.MaxHeaderLength()) + header.ICMPv4MinimumSize available := int(mtu) - headerLen if available < header.IPv4MinimumSize+header.ICMPv4MinimumErrorPayloadSize { @@ -336,19 +370,27 @@ func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tc ReserveHeaderBytes: headerLen, Data: payload, }) + icmpPkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber icmpHdr := header.ICMPv4(icmpPkt.TransportHeader().Push(header.ICMPv4MinimumSize)) + switch reason.(type) { + case *icmpReasonPortUnreachable: + icmpHdr.SetCode(header.ICMPv4PortUnreachable) + case *icmpReasonProtoUnreachable: + icmpHdr.SetCode(header.ICMPv4ProtoUnreachable) + default: + panic(fmt.Sprintf("unsupported ICMP type %T", reason)) + } icmpHdr.SetType(header.ICMPv4DstUnreachable) - icmpHdr.SetCode(header.ICMPv4PortUnreachable) - counter := sent.DstUnreachable icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, icmpPkt.Data)) + counter := sent.DstUnreachable - if err := r.WritePacket( + if err := route.WritePacket( nil, /* gso */ stack.NetworkHeaderParams{ Protocol: header.ICMPv4ProtocolNumber, - TTL: r.DefaultTTL(), + TTL: route.DefaultTTL(), TOS: stack.DefaultTOS, }, icmpPkt, diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index ad7a767a4..c5ac7b8b5 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -18,6 +18,7 @@ package ipv4 import ( "fmt" "sync/atomic" + "time" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -30,6 +31,15 @@ import ( ) const ( + // As per RFC 791 section 3.2: + // The current recommendation for the initial timer setting is 15 seconds. + // This may be changed as experience with this protocol accumulates. + // + // Considering that it is an old recommendation, we use the same reassembly + // timeout that linux defines, which is 30 seconds: + // https://github.com/torvalds/linux/blob/47ec5303d73ea344e84f46660fff693c57641386/include/net/ip.h#L138 + reassembleTimeout = 30 * time.Second + // ProtocolNumber is the ipv4 protocol number. ProtocolNumber = header.IPv4ProtocolNumber @@ -56,7 +66,6 @@ var _ stack.NetworkEndpoint = (*endpoint)(nil) type endpoint struct { nic stack.NetworkInterface - linkEP stack.LinkEndpoint dispatcher stack.TransportDispatcher protocol *protocol @@ -77,7 +86,6 @@ type endpoint struct { func (p *protocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCache, _ stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint { e := &endpoint{ nic: nic, - linkEP: nic.LinkEndpoint(), dispatcher: dispatcher, protocol: p, } @@ -168,21 +176,13 @@ func (e *endpoint) DefaultTTL() uint8 { // MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus // the network layer max header length. func (e *endpoint) MTU() uint32 { - return calculateMTU(e.linkEP.MTU()) + return calculateMTU(e.nic.MTU()) } // MaxHeaderLength returns the maximum length needed by ipv4 headers (and // underlying protocols). func (e *endpoint) MaxHeaderLength() uint16 { - return e.linkEP.MaxHeaderLength() + header.IPv4MinimumSize -} - -// GSOMaxSize returns the maximum GSO packet size. -func (e *endpoint) GSOMaxSize() uint32 { - if gso, ok := e.linkEP.(stack.GSOEndpoint); ok { - return gso.GSOMaxSize() - } - return 0 + return e.nic.MaxHeaderLength() + header.IPv4MaximumHeaderSize } // NetworkProtocolNumber implements stack.NetworkEndpoint.NetworkProtocolNumber. @@ -190,98 +190,26 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { return e.protocol.Number() } -// writePacketFragments calls e.linkEP.WritePacket with each packet fragment to -// write. It assumes that the IP header is already present in pkt.NetworkHeader. -// pkt.TransportHeader may be set. mtu includes the IP header and options. This -// does not support the DontFragment IP flag. -func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, pkt *stack.PacketBuffer) *tcpip.Error { - // This packet is too big, it needs to be fragmented. - ip := header.IPv4(pkt.NetworkHeader().View()) - flags := ip.Flags() - - // Update mtu to take into account the header, which will exist in all - // fragments anyway. - innerMTU := mtu - int(ip.HeaderLength()) - - // Round the MTU down to align to 8 bytes. Then calculate the number of - // fragments. Calculate fragment sizes as in RFC791. - innerMTU &^= 7 - n := (int(ip.PayloadLength()) + innerMTU - 1) / innerMTU - - outerMTU := innerMTU + int(ip.HeaderLength()) - offset := ip.FragmentOffset() - - // Keep the length reserved for link-layer, we need to create fragments with - // the same reserved length. - reservedForLink := pkt.AvailableHeaderBytes() - - // Destroy the packet, pull all payloads out for fragmentation. - transHeader, data := pkt.TransportHeader().View(), pkt.Data - - // Where possible, the first fragment that is sent has the same - // number of bytes reserved for header as the input packet. The link-layer - // endpoint may depend on this for looking at, eg, L4 headers. - transFitsFirst := len(transHeader) <= innerMTU - - for i := 0; i < n; i++ { - reserve := reservedForLink + int(ip.HeaderLength()) - if i == 0 && transFitsFirst { - // Reserve for transport header if it's going to be put in the first - // fragment. - reserve += len(transHeader) - } - fragPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: reserve, - }) - fragPkt.NetworkProtocolNumber = header.IPv4ProtocolNumber - - // Copy data for the fragment. - avail := innerMTU - - if n := len(transHeader); n > 0 { - if n > avail { - n = avail - } - if i == 0 && transFitsFirst { - copy(fragPkt.TransportHeader().Push(n), transHeader) - } else { - fragPkt.Data.AppendView(transHeader[:n:n]) - } - transHeader = transHeader[n:] - avail -= n - } - - if avail > 0 { - n := data.Size() - if n > avail { - n = avail - } - data.ReadToVV(&fragPkt.Data, n) - avail -= n - } - - copied := uint16(innerMTU - avail) - - // Set lengths in header and calculate checksum. - h := header.IPv4(fragPkt.NetworkHeader().Push(len(ip))) - copy(h, ip) - if i != n-1 { - h.SetTotalLength(uint16(outerMTU)) - h.SetFlagsFragmentOffset(flags|header.IPv4FlagMoreFragments, offset) - } else { - h.SetTotalLength(uint16(h.HeaderLength()) + copied) - h.SetFlagsFragmentOffset(flags, offset) - } - h.SetChecksum(0) - h.SetChecksum(^h.CalculateChecksum()) - offset += copied +// writePacketFragments fragments pkt and writes the results on the link +// endpoint. The IP header must already present in the original packet. The mtu +// is the maximum size of the packets. +func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu uint32, pkt *stack.PacketBuffer) *tcpip.Error { + networkHeader := header.IPv4(pkt.NetworkHeader().View()) + fragMTU := int(calculateFragmentInnerMTU(mtu, pkt)) + pf := fragmentation.MakePacketFragmenter(pkt, fragMTU, pkt.AvailableHeaderBytes()+len(networkHeader)) - // Send out the fragment. - if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, fragPkt); err != nil { + for { + fragPkt, more := buildNextFragment(&pf, networkHeader) + if err := e.nic.WritePacket(r, gso, ProtocolNumber, fragPkt); err != nil { + r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pf.RemainingFragmentCount() + 1)) return err } r.Stats().IP.PacketsSent.Increment() + if !more { + break + } } + return nil } @@ -303,7 +231,7 @@ func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params s DstAddr: r.RemoteAddress, }) ip.SetChecksum(^ip.CalculateChecksum()) - pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber + pkt.NetworkProtocolNumber = ProtocolNumber } // WritePacket writes a packet to the given destination address and protocol. @@ -329,7 +257,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw // short circuits broadcasts before they are sent out to other hosts. if pkt.NatDone { netHeader := header.IPv4(pkt.NetworkHeader().View()) - ep, err := e.protocol.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress()) + ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()) if err == nil { route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress()) ep.HandlePacket(&route, pkt) @@ -345,10 +273,11 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw if r.Loop&stack.PacketOut == 0 { return nil } - if pkt.Size() > int(e.linkEP.MTU()) && (gso == nil || gso.Type == stack.GSONone) { - return e.writePacketFragments(r, gso, int(e.linkEP.MTU()), pkt) + if pkt.Size() > int(e.nic.MTU()) && (gso == nil || gso.Type == stack.GSONone) { + return e.writePacketFragments(r, gso, e.nic.MTU(), pkt) } - if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { + if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { + r.Stats().IP.OutgoingPacketErrors.Increment() return err } r.Stats().IP.PacketsSent.Increment() @@ -377,8 +306,11 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe if len(dropped) == 0 && len(natPkts) == 0 { // Fast path: If no packets are to be dropped then we can just invoke the // faster WritePackets API directly. - n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber) + n, err := e.nic.WritePackets(r, gso, pkts, ProtocolNumber) r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) + if err != nil { + r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n)) + } return n, err } r.Stats().IP.IPTablesOutputDropped.IncrementBy(uint64(len(dropped))) @@ -392,7 +324,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe } if _, ok := natPkts[pkt]; ok { netHeader := header.IPv4(pkt.NetworkHeader().View()) - if ep, err := e.protocol.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress()); err == nil { + if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil { src := netHeader.SourceAddress() dst := netHeader.DestinationAddress() route := r.ReverseRoute(src, dst) @@ -401,8 +333,9 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe continue } } - if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { + if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) + r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n - len(dropped))) // Dropped packets aren't errors, so include them in // the return value. return n + len(dropped), err @@ -461,9 +394,12 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu return nil } + if err := e.nic.WritePacket(r, nil /* gso */, ProtocolNumber, pkt); err != nil { + r.Stats().IP.OutgoingPacketErrors.Increment() + return err + } r.Stats().IP.PacketsSent.Increment() - - return e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt) + return nil } // HandlePacket is called by the link layer when new ipv4 packets arrive for @@ -566,7 +502,13 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // 3 (Port Unreachable), when the designated transport protocol // (e.g., UDP) is unable to demultiplex the datagram but has no // protocol mechanism to inform the sender. - _ = returnError(r, &icmpReasonPortUnreachable{}, pkt) + _ = e.protocol.returnError(r, &icmpReasonPortUnreachable{}, pkt) + case stack.TransportPacketProtocolUnreachable: + // As per RFC: 1122 Section 3.2.2.1 + // A host SHOULD generate Destination Unreachable messages with code: + // 2 (Protocol Unreachable), when the designated transport protocol + // is not supported + _ = e.protocol.returnError(r, &icmpReasonProtoUnreachable{}, pkt) default: panic(fmt.Sprintf("unrecognized result from DeliverTransportPacket = %d", res)) } @@ -794,14 +736,36 @@ func calculateMTU(mtu uint32) uint32 { return mtu - header.IPv4MinimumSize } +// calculateFragmentInnerMTU calculates the maximum number of bytes of +// fragmentable data a fragment can have, based on the link layer mtu and pkt's +// network header size. +func calculateFragmentInnerMTU(mtu uint32, pkt *stack.PacketBuffer) uint32 { + if mtu > MaxTotalSize { + mtu = MaxTotalSize + } + mtu -= uint32(pkt.NetworkHeader().View().Size()) + // Round the MTU down to align to 8 bytes. + mtu &^= 7 + return mtu +} + +// addressToUint32 translates an IPv4 address into its little endian uint32 +// representation. +// +// This function does the same thing as binary.LittleEndian.Uint32 but operates +// on a tcpip.Address (a string) without the need to convert it to a byte slice, +// which would cause an allocation. +func addressToUint32(addr tcpip.Address) uint32 { + _ = addr[3] // bounds check hint to compiler + return uint32(addr[0]) | uint32(addr[1])<<8 | uint32(addr[2])<<16 | uint32(addr[3])<<24 +} + // hashRoute calculates a hash value for the given route. It uses the source & -// destination address, the transport protocol number, and a random initial -// value (generated once on initialization) to generate the hash. +// destination address, the transport protocol number and a 32-bit number to +// generate the hash. func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber, hashIV uint32) uint32 { - t := r.LocalAddress - a := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24 - t = r.RemoteAddress - b := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24 + a := addressToUint32(r.LocalAddress) + b := addressToUint32(r.RemoteAddress) return hash.Hash3Words(a, b, uint32(protocol), hashIV) } @@ -821,6 +785,29 @@ func NewProtocol(s *stack.Stack) stack.NetworkProtocol { ids: ids, hashIV: hashIV, defaultTTL: DefaultTTL, - fragmentation: fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout, s.Clock()), + fragmentation: fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, reassembleTimeout, s.Clock()), + } +} + +func buildNextFragment(pf *fragmentation.PacketFragmenter, originalIPHeader header.IPv4) (*stack.PacketBuffer, bool) { + fragPkt, offset, copied, more := pf.BuildNextFragment() + fragPkt.NetworkProtocolNumber = ProtocolNumber + + originalIPHeaderLength := len(originalIPHeader) + nextFragIPHeader := header.IPv4(fragPkt.NetworkHeader().Push(originalIPHeaderLength)) + + if copied := copy(nextFragIPHeader, originalIPHeader); copied != len(originalIPHeader) { + panic(fmt.Sprintf("wrong number of bytes copied into fragmentIPHeaders: got = %d, want = %d", copied, originalIPHeaderLength)) } + + flags := originalIPHeader.Flags() + if more { + flags |= header.IPv4FlagMoreFragments + } + nextFragIPHeader.SetFlagsFragmentOffset(flags, uint16(offset)) + nextFragIPHeader.SetTotalLength(uint16(nextFragIPHeader.HeaderLength()) + uint16(copied)) + nextFragIPHeader.SetChecksum(0) + nextFragIPHeader.SetChecksum(^nextFragIPHeader.CalculateChecksum()) + + return fragPkt, more } diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index 277560e35..9916d783f 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -16,19 +16,24 @@ package ipv4_test import ( "bytes" + "context" "encoding/hex" "math" + "net" "testing" "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" + "gvisor.dev/gvisor/pkg/tcpip/network/arp" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/testutil" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/waiter" @@ -92,6 +97,276 @@ func TestExcludeBroadcast(t *testing.T) { }) } +// TestIPv4Sanity sends IP/ICMP packets with various problems to the stack and +// checks the response. +func TestIPv4Sanity(t *testing.T) { + const ( + defaultMTU = header.IPv6MinimumMTU + ttl = 255 + nicID = 1 + randomSequence = 123 + randomIdent = 42 + ) + var ( + ipv4Addr = tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("192.168.1.58").To4()), + PrefixLen: 24, + } + remoteIPv4Addr = tcpip.Address(net.ParseIP("10.0.0.1").To4()) + ) + + tests := []struct { + name string + headerLength uint8 // value of 0 means "use correct size" + maxTotalLength uint16 + transportProtocol uint8 + TTL uint8 + shouldFail bool + expectICMP bool + ICMPType header.ICMPv4Type + ICMPCode header.ICMPv4Code + options []byte + }{ + { + name: "valid", + maxTotalLength: defaultMTU, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + }, + // The TTL tests check that we are not rejecting an incoming packet + // with a zero or one TTL, which has been a point of confusion in the + // past as RFC 791 says: "If this field contains the value zero, then the + // datagram must be destroyed". However RFC 1122 section 3.2.1.7 clarifies + // for the case of the destination host, stating as follows. + // + // A host MUST NOT send a datagram with a Time-to-Live (TTL) + // value of zero. + // + // A host MUST NOT discard a datagram just because it was + // received with TTL less than 2. + { + name: "zero TTL", + maxTotalLength: defaultMTU, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: 0, + shouldFail: false, + }, + { + name: "one TTL", + maxTotalLength: defaultMTU, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: 1, + shouldFail: false, + }, + { + name: "End options", + maxTotalLength: defaultMTU, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{0, 0, 0, 0}, + }, + { + name: "NOP options", + maxTotalLength: defaultMTU, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{1, 1, 1, 1}, + }, + { + name: "NOP and End options", + maxTotalLength: defaultMTU, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{1, 1, 0, 0}, + }, + { + name: "bad header length", + headerLength: header.IPv4MinimumSize - 1, + maxTotalLength: defaultMTU, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + shouldFail: true, + expectICMP: false, + }, + { + name: "bad total length (0)", + maxTotalLength: 0, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + shouldFail: true, + expectICMP: false, + }, + { + name: "bad total length (ip - 1)", + maxTotalLength: uint16(header.IPv4MinimumSize - 1), + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + shouldFail: true, + expectICMP: false, + }, + { + name: "bad total length (ip + icmp - 1)", + maxTotalLength: uint16(header.IPv4MinimumSize + header.ICMPv4MinimumSize - 1), + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + shouldFail: true, + expectICMP: false, + }, + { + name: "bad protocol", + maxTotalLength: defaultMTU, + transportProtocol: 99, + TTL: ttl, + shouldFail: true, + expectICMP: true, + ICMPType: header.ICMPv4DstUnreachable, + ICMPCode: header.ICMPv4ProtoUnreachable, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4}, + }) + // We expect at most a single packet in response to our ICMP Echo Request. + e := channel.New(1, defaultMTU, "") + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + ipv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr} + if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, ipv4ProtoAddr, err) + } + + // Default routes for IPv4 so ICMP can find a route to the remote + // node when attempting to send the ICMP Echo Reply. + s.SetRouteTable([]tcpip.Route{ + tcpip.Route{ + Destination: header.IPv4EmptySubnet, + NIC: nicID, + }, + }) + + // Round up the header size to the next multiple of 4 as RFC 791, page 11 + // says: "Internet Header Length is the length of the internet header + // in 32 bit words..." and on page 23: "The internet header padding is + // used to ensure that the internet header ends on a 32 bit boundary." + ipHeaderLength := ((header.IPv4MinimumSize + len(test.options)) + header.IPv4IHLStride - 1) & ^(header.IPv4IHLStride - 1) + + if ipHeaderLength > header.IPv4MaximumHeaderSize { + t.Fatalf("too many bytes in options: got = %d, want <= %d ", ipHeaderLength, header.IPv4MaximumHeaderSize) + } + totalLen := uint16(ipHeaderLength + header.ICMPv4MinimumSize) + hdr := buffer.NewPrependable(int(totalLen)) + icmp := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) + + // Specify ident/seq to make sure we get the same in the response. + icmp.SetIdent(randomIdent) + icmp.SetSequence(randomSequence) + icmp.SetType(header.ICMPv4Echo) + icmp.SetCode(header.ICMPv4UnusedCode) + icmp.SetChecksum(0) + icmp.SetChecksum(^header.Checksum(icmp, 0)) + ip := header.IPv4(hdr.Prepend(ipHeaderLength)) + if test.maxTotalLength < totalLen { + totalLen = test.maxTotalLength + } + ip.Encode(&header.IPv4Fields{ + IHL: uint8(ipHeaderLength), + TotalLength: totalLen, + Protocol: test.transportProtocol, + TTL: test.TTL, + SrcAddr: remoteIPv4Addr, + DstAddr: ipv4Addr.Address, + }) + if n := copy(ip.Options(), test.options); n != len(test.options) { + t.Fatalf("options larger than available space: copied %d/%d bytes", n, len(test.options)) + } + // Override the correct value if the test case specified one. + if test.headerLength != 0 { + ip.SetHeaderLength(test.headerLength) + } + requestPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + }) + e.InjectInbound(header.IPv4ProtocolNumber, requestPkt) + reply, ok := e.Read() + if !ok { + if test.shouldFail { + if test.expectICMP { + t.Fatal("expected ICMP error response missing") + } + return // Expected silent failure. + } + t.Fatal("expected ICMP echo reply missing") + } + + // Check the route that brought the packet to us. + if reply.Route.LocalAddress != ipv4Addr.Address { + t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", reply.Route.LocalAddress, ipv4Addr.Address) + } + if reply.Route.RemoteAddress != remoteIPv4Addr { + t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", reply.Route.RemoteAddress, remoteIPv4Addr) + } + + // Make sure it's all in one buffer. + vv := buffer.NewVectorisedView(reply.Pkt.Size(), reply.Pkt.Views()) + replyIPHeader := header.IPv4(vv.ToView()) + + // At this stage we only know it's an IP header so verify that much. + checker.IPv4(t, replyIPHeader, + checker.SrcAddr(ipv4Addr.Address), + checker.DstAddr(remoteIPv4Addr), + ) + + // All expected responses are ICMP packets. + if got, want := replyIPHeader.Protocol(), uint8(header.ICMPv4ProtocolNumber); got != want { + t.Fatalf("not ICMP response, got protocol %d, want = %d", got, want) + } + replyICMPHeader := header.ICMPv4(replyIPHeader.Payload()) + + // Sanity check the response. + switch replyICMPHeader.Type() { + case header.ICMPv4DstUnreachable: + checker.IPv4(t, replyIPHeader, + checker.IPFullLength(uint16(header.IPv4MinimumSize+header.ICMPv4MinimumSize+requestPkt.Size())), + checker.IPv4HeaderLength(header.IPv4MinimumSize), + checker.ICMPv4( + checker.ICMPv4Code(test.ICMPCode), + checker.ICMPv4Checksum(), + checker.ICMPv4Payload([]byte(hdr.View())), + ), + ) + if !test.shouldFail || !test.expectICMP { + t.Fatalf("unexpected packet rejection, got ICMP error packet type %d, code %d", + header.ICMPv4DstUnreachable, replyICMPHeader.Code()) + } + return + case header.ICMPv4EchoReply: + checker.IPv4(t, replyIPHeader, + checker.IPv4HeaderLength(ipHeaderLength), + checker.IPv4Options(test.options), + checker.IPFullLength(uint16(requestPkt.Size())), + checker.ICMPv4( + checker.ICMPv4Code(header.ICMPv4UnusedCode), + checker.ICMPv4Seq(randomSequence), + checker.ICMPv4Ident(randomIdent), + checker.ICMPv4Checksum(), + ), + ) + if test.shouldFail { + t.Fatalf("unexpected Echo Reply packet\n") + } + default: + t.Fatalf("unexpected ICMP response, got type %d, want = %d or %d", + replyICMPHeader.Type(), header.ICMPv4EchoReply, header.ICMPv4DstUnreachable) + } + }) + } +} + // comparePayloads compared the contents of all the packets against the contents // of the source packet. func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketInfo *stack.PacketBuffer, mtu uint32) { @@ -123,16 +398,6 @@ func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketI if got, want := len(ip), int(mtu); got > want { t.Errorf("fragment is too large, got %d want %d", got, want) } - if i == 0 { - got := packet.NetworkHeader().View().Size() + packet.TransportHeader().View().Size() - // sourcePacketInfo does not have NetworkHeader added, simulate one. - want := header.IPv4MinimumSize + sourcePacketInfo.TransportHeader().View().Size() - // Check that it kept the transport header in packet.TransportHeader if - // it fits in the first fragment. - if want < int(mtu) && got != want { - t.Errorf("first fragment hdr parts should have unmodified length if possible: got %d, want %d", got, want) - } - } if got, want := packet.AvailableHeaderBytes(), sourcePacketInfo.AvailableHeaderBytes()-header.IPv4MinimumSize; got != want { t.Errorf("fragment #%d should have the same available space for prepending as source: got %d, want %d", i, got, want) } @@ -162,6 +427,8 @@ func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketI } func TestFragmentation(t *testing.T) { + const ttl = 42 + var manyPayloadViewsSizes [1000]int for i := range manyPayloadViewsSizes { manyPayloadViewsSizes[i] = 7 @@ -175,15 +442,15 @@ func TestFragmentation(t *testing.T) { payloadViewsSizes []int expectedFrags int }{ - {"NoFragmentation", 2000, &stack.GSO{}, 0, header.IPv4MinimumSize, []int{1000}, 1}, - {"NoFragmentationWithBigHeader", 2000, &stack.GSO{}, 16, header.IPv4MinimumSize, []int{1000}, 1}, + {"No fragmentation", 2000, &stack.GSO{}, 0, header.IPv4MinimumSize, []int{1000}, 1}, + {"No fragmentation with big header", 2000, &stack.GSO{}, 16, header.IPv4MinimumSize, []int{1000}, 1}, {"Fragmented", 800, &stack.GSO{}, 0, header.IPv4MinimumSize, []int{1000}, 2}, - {"FragmentedWithGsoNil", 800, nil, 0, header.IPv4MinimumSize, []int{1000}, 2}, - {"FragmentedWithManyViews", 300, &stack.GSO{}, 0, header.IPv4MinimumSize, manyPayloadViewsSizes[:], 25}, - {"FragmentedWithManyViewsAndPrependableBytes", 300, &stack.GSO{}, 0, header.IPv4MinimumSize + 55, manyPayloadViewsSizes[:], 25}, - {"FragmentedWithBigHeader", 800, &stack.GSO{}, 20, header.IPv4MinimumSize, []int{1000}, 2}, - {"FragmentedWithBigHeaderAndPrependableBytes", 800, &stack.GSO{}, 20, header.IPv4MinimumSize + 66, []int{1000}, 2}, - {"FragmentedWithMTUSmallerThanHeaderAndPrependableBytes", 300, &stack.GSO{}, 1000, header.IPv4MinimumSize + 77, []int{500}, 6}, + {"Fragmented with gso nil", 800, nil, 0, header.IPv4MinimumSize, []int{1000}, 2}, + {"Fragmented with many views", 300, &stack.GSO{}, 0, header.IPv4MinimumSize, manyPayloadViewsSizes[:], 25}, + {"Fragmented with many views and prependable bytes", 300, &stack.GSO{}, 0, header.IPv4MinimumSize + 55, manyPayloadViewsSizes[:], 25}, + {"Fragmented with big header", 800, &stack.GSO{}, 20, header.IPv4MinimumSize, []int{1000}, 2}, + {"Fragmented with big header and prependable bytes", 800, &stack.GSO{}, 20, header.IPv4MinimumSize + 66, []int{1000}, 2}, + {"Fragmented with MTU smaller than header and prependable bytes", 300, &stack.GSO{}, 1000, header.IPv4MinimumSize + 77, []int{500}, 6}, } for _, ft := range fragTests { @@ -194,11 +461,11 @@ func TestFragmentation(t *testing.T) { source := pkt.Clone() err := r.WritePacket(ft.gso, stack.NetworkHeaderParams{ Protocol: tcp.ProtocolNumber, - TTL: 42, + TTL: ttl, TOS: stack.DefaultTOS, }, pkt) if err != nil { - t.Errorf("got err = %s, want = nil", err) + t.Fatalf("r.WritePacket(_, _, _) = %s", err) } if got := len(ep.WrittenPackets); got != ft.expectedFrags { @@ -207,6 +474,9 @@ func TestFragmentation(t *testing.T) { if got, want := len(ep.WrittenPackets), int(r.Stats().IP.PacketsSent.Value()); got != want { t.Errorf("no errors yet got len(ep.WrittenPackets) = %d, want = %d", got, want) } + if got := r.Stats().IP.OutgoingPacketErrors.Value(); got != 0 { + t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got) + } compareFragments(t, ep.WrittenPackets, source, ft.mtu) }) } @@ -215,36 +485,70 @@ func TestFragmentation(t *testing.T) { // TestFragmentationErrors checks that errors are returned from write packet // correctly. func TestFragmentationErrors(t *testing.T) { + const ttl = 42 + + expectedError := tcpip.ErrAborted fragTests := []struct { description string mtu uint32 transportHeaderLength int - payloadViewsSizes []int - err *tcpip.Error + payloadSize int allowPackets int + fragmentCount int }{ - {"NoFrag", 2000, 0, []int{1000}, tcpip.ErrAborted, 0}, - {"ErrorOnFirstFrag", 500, 0, []int{1000}, tcpip.ErrAborted, 0}, - {"ErrorOnSecondFrag", 500, 0, []int{1000}, tcpip.ErrAborted, 1}, - {"ErrorOnFirstFragMTUSmallerThanHeader", 500, 1000, []int{500}, tcpip.ErrAborted, 0}, + { + description: "No frag", + mtu: 2000, + transportHeaderLength: 0, + payloadSize: 1000, + allowPackets: 0, + fragmentCount: 1, + }, + { + description: "Error on first frag", + mtu: 500, + transportHeaderLength: 0, + payloadSize: 1000, + allowPackets: 0, + fragmentCount: 3, + }, + { + description: "Error on second frag", + mtu: 500, + transportHeaderLength: 0, + payloadSize: 1000, + allowPackets: 1, + fragmentCount: 3, + }, + { + description: "Error on first frag MTU smaller than header", + mtu: 500, + transportHeaderLength: 1000, + payloadSize: 500, + allowPackets: 0, + fragmentCount: 4, + }, } for _, ft := range fragTests { t.Run(ft.description, func(t *testing.T) { - ep := testutil.NewMockLinkEndpoint(ft.mtu, ft.err, ft.allowPackets) + ep := testutil.NewMockLinkEndpoint(ft.mtu, expectedError, ft.allowPackets) r := buildRoute(t, ep) - pkt := testutil.MakeRandPkt(ft.transportHeaderLength, header.IPv4MinimumSize, ft.payloadViewsSizes, header.IPv4ProtocolNumber) + pkt := testutil.MakeRandPkt(ft.transportHeaderLength, header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber) err := r.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{ Protocol: tcp.ProtocolNumber, - TTL: 42, + TTL: ttl, TOS: stack.DefaultTOS, }, pkt) - if err != ft.err { - t.Errorf("got WritePacket() = %s, want = %s", err, ft.err) + if err != expectedError { + t.Errorf("got WritePacket() = %s, want = %s", err, expectedError) } if got, want := len(ep.WrittenPackets), int(r.Stats().IP.PacketsSent.Value()); err != nil && got != want { t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, want) } + if got, want := int(r.Stats().IP.OutgoingPacketErrors.Value()), ft.fragmentCount-ft.allowPackets; got != want { + t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = %d", got, want) + } }) } } @@ -1005,6 +1309,7 @@ func TestReceiveFragments(t *testing.T) { func TestWriteStats(t *testing.T) { const nPackets = 3 + tests := []struct { name string setup func(*testing.T, *stack.Stack) @@ -1150,12 +1455,13 @@ func buildRoute(t *testing.T, ep stack.LinkEndpoint) stack.Route { dst = "\x10\x00\x00\x02" ) if err := s.AddAddress(1, ipv4.ProtocolNumber, src); err != nil { - t.Fatalf("AddAddress(1, %d, _) failed: %s", ipv4.ProtocolNumber, err) + t.Fatalf("AddAddress(1, %d, %s) failed: %s", ipv4.ProtocolNumber, src, err) } { - subnet, err := tcpip.NewSubnet(dst, tcpip.AddressMask(header.IPv4Broadcast)) + mask := tcpip.AddressMask(header.IPv4Broadcast) + subnet, err := tcpip.NewSubnet(dst, mask) if err != nil { - t.Fatalf("NewSubnet(_, _) failed: %v", err) + t.Fatalf("NewSubnet(%s, %s) failed: %v", dst, mask, err) } s.SetRouteTable([]tcpip.Route{{ Destination: subnet, @@ -1164,7 +1470,7 @@ func buildRoute(t *testing.T, ep stack.LinkEndpoint) stack.Route { } rt, err := s.FindRoute(1, src, dst, ipv4.ProtocolNumber, false /* multicastLoop */) if err != nil { - t.Fatalf("got FindRoute(1, _, _, %d, false) = %s, want = nil", ipv4.ProtocolNumber, err) + t.Fatalf("FindRoute(1, %s, %s, %d, false) = %s", src, dst, ipv4.ProtocolNumber, err) } return rt } @@ -1188,3 +1494,204 @@ func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string) (bool, lm.limit-- return false, false } + +func TestPacketQueing(t *testing.T) { + const nicID = 1 + + var ( + host1NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06") + host2NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09") + + host1IPv4Addr = tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("192.168.0.1").To4()), + PrefixLen: 24, + }, + } + host2IPv4Addr = tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("192.168.0.2").To4()), + PrefixLen: 8, + }, + } + ) + + tests := []struct { + name string + rxPkt func(*channel.Endpoint) + checkResp func(*testing.T, *channel.Endpoint) + }{ + { + name: "ICMP Error", + rxPkt: func(e *channel.Endpoint) { + hdr := buffer.NewPrependable(header.IPv4MinimumSize + header.UDPMinimumSize) + u := header.UDP(hdr.Prepend(header.UDPMinimumSize)) + u.Encode(&header.UDPFields{ + SrcPort: 5555, + DstPort: 80, + Length: header.UDPMinimumSize, + }) + sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, host2IPv4Addr.AddressWithPrefix.Address, host1IPv4Addr.AddressWithPrefix.Address, header.UDPMinimumSize) + sum = header.Checksum(header.UDP([]byte{}), sum) + u.SetChecksum(^u.CalculateChecksum(sum)) + ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) + ip.Encode(&header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TotalLength: header.IPv4MinimumSize + header.UDPMinimumSize, + TTL: ipv4.DefaultTTL, + Protocol: uint8(udp.ProtocolNumber), + SrcAddr: host2IPv4Addr.AddressWithPrefix.Address, + DstAddr: host1IPv4Addr.AddressWithPrefix.Address, + }) + e.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + })) + }, + checkResp: func(t *testing.T, e *channel.Endpoint) { + p, ok := e.ReadContext(context.Background()) + if !ok { + t.Fatalf("timed out waiting for packet") + } + if p.Proto != header.IPv4ProtocolNumber { + t.Errorf("got p.Proto = %d, want = %d", p.Proto, header.IPv4ProtocolNumber) + } + if p.Route.RemoteLinkAddress != host2NICLinkAddr { + t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr) + } + checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address), + checker.DstAddr(host2IPv4Addr.AddressWithPrefix.Address), + checker.ICMPv4( + checker.ICMPv4Type(header.ICMPv4DstUnreachable), + checker.ICMPv4Code(header.ICMPv4PortUnreachable))) + }, + }, + + { + name: "Ping", + rxPkt: func(e *channel.Endpoint) { + totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize + hdr := buffer.NewPrependable(totalLen) + pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) + pkt.SetType(header.ICMPv4Echo) + pkt.SetCode(0) + pkt.SetChecksum(0) + pkt.SetChecksum(^header.Checksum(pkt, 0)) + ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) + ip.Encode(&header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TotalLength: uint16(totalLen), + Protocol: uint8(icmp.ProtocolNumber4), + TTL: ipv4.DefaultTTL, + SrcAddr: host2IPv4Addr.AddressWithPrefix.Address, + DstAddr: host1IPv4Addr.AddressWithPrefix.Address, + }) + e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + })) + }, + checkResp: func(t *testing.T, e *channel.Endpoint) { + p, ok := e.ReadContext(context.Background()) + if !ok { + t.Fatalf("timed out waiting for packet") + } + if p.Proto != header.IPv4ProtocolNumber { + t.Errorf("got p.Proto = %d, want = %d", p.Proto, header.IPv4ProtocolNumber) + } + if p.Route.RemoteLinkAddress != host2NICLinkAddr { + t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr) + } + checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address), + checker.DstAddr(host2IPv4Addr.AddressWithPrefix.Address), + checker.ICMPv4( + checker.ICMPv4Type(header.ICMPv4EchoReply), + checker.ICMPv4Code(header.ICMPv4UnusedCode))) + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + e := channel.New(1, header.IPv6MinimumMTU, host1NICLinkAddr) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + }) + + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } + if err := s.AddAddress(nicID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { + t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, arp.ProtocolNumber, arp.ProtocolAddress, err) + } + if err := s.AddProtocolAddress(nicID, host1IPv4Addr); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID, host1IPv4Addr, err) + } + + s.SetRouteTable([]tcpip.Route{ + tcpip.Route{ + Destination: host1IPv4Addr.AddressWithPrefix.Subnet(), + NIC: nicID, + }, + }) + + // Receive a packet to trigger link resolution before a response is sent. + test.rxPkt(e) + + // Wait for a ARP request since link address resolution should be + // performed. + { + p, ok := e.ReadContext(context.Background()) + if !ok { + t.Fatalf("timed out waiting for packet") + } + if p.Proto != arp.ProtocolNumber { + t.Errorf("got p.Proto = %d, want = %d", p.Proto, arp.ProtocolNumber) + } + if p.Route.RemoteLinkAddress != header.EthernetBroadcastAddress { + t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, header.EthernetBroadcastAddress) + } + rep := header.ARP(p.Pkt.NetworkHeader().View()) + if got := rep.Op(); got != header.ARPRequest { + t.Errorf("got Op() = %d, want = %d", got, header.ARPRequest) + } + if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != host1NICLinkAddr { + t.Errorf("got HardwareAddressSender = %s, want = %s", got, host1NICLinkAddr) + } + if got := tcpip.Address(rep.ProtocolAddressSender()); got != host1IPv4Addr.AddressWithPrefix.Address { + t.Errorf("got ProtocolAddressSender = %s, want = %s", got, host1IPv4Addr.AddressWithPrefix.Address) + } + if got := tcpip.Address(rep.ProtocolAddressTarget()); got != host2IPv4Addr.AddressWithPrefix.Address { + t.Errorf("got ProtocolAddressTarget = %s, want = %s", got, host2IPv4Addr.AddressWithPrefix.Address) + } + } + + // Send an ARP reply to complete link address resolution. + { + hdr := buffer.View(make([]byte, header.ARPSize)) + packet := header.ARP(hdr) + packet.SetIPv4OverEthernet() + packet.SetOp(header.ARPReply) + copy(packet.HardwareAddressSender(), host2NICLinkAddr) + copy(packet.ProtocolAddressSender(), host2IPv4Addr.AddressWithPrefix.Address) + copy(packet.HardwareAddressTarget(), host1NICLinkAddr) + copy(packet.ProtocolAddressTarget(), host1IPv4Addr.AddressWithPrefix.Address) + e.InjectInbound(arp.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.ToVectorisedView(), + })) + } + + // Expect the response now that the link address has resolved. + test.checkResp(t, e) + + // Since link resolution was already performed, it shouldn't be performed + // again. + test.rxPkt(e) + test.checkResp(t, e) + }) + } +} diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD index 97adbcbd4..a30437f02 100644 --- a/pkg/tcpip/network/ipv6/BUILD +++ b/pkg/tcpip/network/ipv6/BUILD @@ -18,6 +18,7 @@ go_library( "//pkg/tcpip/header", "//pkg/tcpip/header/parse", "//pkg/tcpip/network/fragmentation", + "//pkg/tcpip/network/hash", "//pkg/tcpip/stack", ], ) @@ -41,6 +42,7 @@ go_test( "//pkg/tcpip/network/testutil", "//pkg/tcpip/stack", "//pkg/tcpip/transport/icmp", + "//pkg/tcpip/transport/tcp", "//pkg/tcpip/transport/udp", "//pkg/waiter", "@com_github_google_go_cmp//cmp:go_default_library", diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 4b4b483cc..7be35c78b 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -286,6 +286,17 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme e.linkAddrCache.AddLinkAddress(e.nic.ID(), r.RemoteAddress, sourceLinkAddr) } + // As per RFC 4861 section 7.1.1: + // A node MUST silently discard any received Neighbor Solicitation + // messages that do not satisfy all of the following validity checks: + // ... + // - If the IP source address is the unspecified address, the IP + // destination address is a solicited-node multicast address. + if unspecifiedSource && !header.IsSolicitedNodeAddr(r.LocalAddress) { + received.Invalid.Increment() + return + } + // ICMPv6 Neighbor Solicit messages are always sent to // specially crafted IPv6 multicast addresses. As a result, the // route we end up with here has as its LocalAddress such a @@ -429,8 +440,6 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme return } - remoteLinkAddr := r.RemoteLinkAddress - // As per RFC 4291 section 2.7, multicast addresses must not be used as // source addresses in IPv6 packets. localAddr := r.LocalAddress @@ -445,9 +454,6 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme } defer r.Release() - // Use the link address from the source of the original packet. - r.ResolveWith(remoteLinkAddr) - replyPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize, Data: pkt.Data, @@ -635,18 +641,21 @@ func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { // LinkAddressRequest implements stack.LinkAddressResolver. func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP stack.LinkEndpoint) *tcpip.Error { - snaddr := header.SolicitedNodeAddr(addr) - // TODO(b/148672031): Use stack.FindRoute instead of manually creating the // route here. Note, we would need the nicID to do this properly so the right // NIC (associated to linkEP) is used to send the NDP NS message. - r := &stack.Route{ + r := stack.Route{ LocalAddress: localAddr, - RemoteAddress: snaddr, + RemoteAddress: addr, RemoteLinkAddress: remoteLinkAddr, } + + // If a remote address is not already known, then send a multicast + // solicitation since multicast addresses have a static mapping to link + // addresses. if len(r.RemoteLinkAddress) == 0 { - r.RemoteLinkAddress = header.EthernetAddressFromMulticastIPv6Address(snaddr) + r.RemoteAddress = header.SolicitedNodeAddr(addr) + r.RemoteLinkAddress = header.EthernetAddressFromMulticastIPv6Address(r.RemoteAddress) } pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -672,7 +681,7 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAdd }) // TODO(stijlist): count this in ICMP stats. - return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt) + return linkEP.WritePacket(&r, nil /* gso */, ProtocolNumber, pkt) } // ResolveStaticAddress implements stack.LinkAddressResolver. @@ -690,6 +699,36 @@ type icmpReason interface { isICMPReason() } +// icmpReasonParameterProblem is an error during processing of extension headers +// or the fixed header defined in RFC 4443 section 3.4. +type icmpReasonParameterProblem struct { + code header.ICMPv6Code + + // respondToMulticast indicates that we are sending a packet that falls under + // the exception outlined by RFC 4443 section 2.4 point e.3 exception 2: + // + // (e.3) A packet destined to an IPv6 multicast address. (There are + // two exceptions to this rule: (1) the Packet Too Big Message + // (Section 3.2) to allow Path MTU discovery to work for IPv6 + // multicast, and (2) the Parameter Problem Message, Code 2 + // (Section 3.4) reporting an unrecognized IPv6 option (see + // Section 4.2 of [IPv6]) that has the Option Type highest- + // order two bits set to 10). + respondToMulticast bool + + // pointer is defined in the RFC 4443 setion 3.4 which reads: + // + // Pointer Identifies the octet offset within the invoking packet + // where the error was detected. + // + // The pointer will point beyond the end of the ICMPv6 + // packet if the field in error is beyond what can fit + // in the maximum size of an ICMPv6 error message. + pointer uint32 +} + +func (*icmpReasonParameterProblem) isICMPReason() {} + // icmpReasonPortUnreachable is an error where the transport protocol has no // listener and no alternative means to inform the sender. type icmpReasonPortUnreachable struct{} @@ -698,18 +737,11 @@ func (*icmpReasonPortUnreachable) isICMPReason() {} // returnError takes an error descriptor and generates the appropriate ICMP // error packet for IPv6 and sends it. -func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error { - stats := r.Stats().ICMP - sent := stats.V6PacketsSent - if !r.Stack().AllowICMPMessage() { - sent.RateLimited.Increment() - return nil - } - +func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error { // Only send ICMP error if the address is not a multicast v6 // address and the source is not the unspecified address. // - // TODO(b/164522993) There are exceptions to this rule. + // There are exceptions to this rule. // See: point e.3) RFC 4443 section-2.4 // // (e) An ICMPv6 error message MUST NOT be originated as a result of @@ -727,7 +759,32 @@ func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tc // Section 4.2 of [IPv6]) that has the Option Type highest- // order two bits set to 10). // - if header.IsV6MulticastAddress(r.LocalAddress) || r.RemoteAddress == header.IPv6Any { + var allowResponseToMulticast bool + if reason, ok := reason.(*icmpReasonParameterProblem); ok { + allowResponseToMulticast = reason.respondToMulticast + } + + if (!allowResponseToMulticast && header.IsV6MulticastAddress(r.LocalAddress)) || r.RemoteAddress == header.IPv6Any { + return nil + } + + // Even if we were able to receive a packet from some remote, we may not have + // a route to it - the remote may be blocked via routing rules. We must always + // consult our routing table and find a route to the remote before sending any + // packet. + route, err := p.stack.FindRoute(r.NICID(), r.LocalAddress, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */) + if err != nil { + return err + } + defer route.Release() + // From this point on, the incoming route should no longer be used; route + // must be used to send the ICMP error. + r = nil + + stats := p.stack.Stats().ICMP + sent := stats.V6PacketsSent + if !p.stack.AllowICMPMessage() { + sent.RateLimited.Increment() return nil } @@ -757,11 +814,11 @@ func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tc // packet that caused the error) as possible without making // the error message packet exceed the minimum IPv6 MTU // [IPv6]. - mtu := int(r.MTU()) + mtu := int(route.MTU()) if mtu > header.IPv6MinimumMTU { mtu = header.IPv6MinimumMTU } - headerLen := int(r.MaxHeaderLength()) + header.ICMPv6ErrorHeaderSize + headerLen := int(route.MaxHeaderLength()) + header.ICMPv6ErrorHeaderSize available := int(mtu) - headerLen if available < header.IPv6MinimumSize { return nil @@ -780,12 +837,30 @@ func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tc newPkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber icmpHdr := header.ICMPv6(newPkt.TransportHeader().Push(header.ICMPv6DstUnreachableMinimumSize)) - icmpHdr.SetCode(header.ICMPv6PortUnreachable) - icmpHdr.SetType(header.ICMPv6DstUnreachable) - icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, r.LocalAddress, r.RemoteAddress, newPkt.Data)) - counter := sent.DstUnreachable - err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, newPkt) - if err != nil { + var counter *tcpip.StatCounter + switch reason := reason.(type) { + case *icmpReasonParameterProblem: + icmpHdr.SetType(header.ICMPv6ParamProblem) + icmpHdr.SetCode(reason.code) + icmpHdr.SetTypeSpecific(reason.pointer) + counter = sent.ParamProblem + case *icmpReasonPortUnreachable: + icmpHdr.SetType(header.ICMPv6DstUnreachable) + icmpHdr.SetCode(header.ICMPv6PortUnreachable) + counter = sent.DstUnreachable + default: + panic(fmt.Sprintf("unsupported ICMP type %T", reason)) + } + icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, route.LocalAddress, route.RemoteAddress, newPkt.Data)) + if err := route.WritePacket( + nil, /* gso */ + stack.NetworkHeaderParams{ + Protocol: header.ICMPv6ProtocolNumber, + TTL: route.DefaultTTL(), + TOS: stack.DefaultTOS, + }, + newPkt, + ); err != nil { sent.Dropped.Increment() return err } diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index 31370c1d4..3affcc4e4 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -16,17 +16,21 @@ package ipv6 import ( "context" + "net" "reflect" "strings" "testing" + "time" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/waiter" ) @@ -39,6 +43,9 @@ const ( defaultChannelSize = 1 defaultMTU = 65536 + + // Extra time to use when waiting for an async event to occur. + defaultAsyncPositiveEventTimeout = 30 * time.Second ) var ( @@ -50,6 +57,10 @@ type stubLinkEndpoint struct { stack.LinkEndpoint } +func (*stubLinkEndpoint) MTU() uint32 { + return defaultMTU +} + func (*stubLinkEndpoint) Capabilities() stack.LinkEndpointCapabilities { // Indicate that resolution for link layer addresses is required to send // packets over this link. This is needed so the NIC knows to allocate a @@ -105,7 +116,9 @@ func (*stubNUDHandler) HandleUpperLevelConfirmation(addr tcpip.Address) { var _ stack.NetworkInterface = (*testInterface)(nil) -type testInterface struct{} +type testInterface struct { + stack.NetworkLinkEndpoint +} func (*testInterface) ID() tcpip.NICID { return 0 @@ -123,10 +136,6 @@ func (*testInterface) Enabled() bool { return true } -func (*testInterface) LinkEndpoint() stack.LinkEndpoint { - return nil -} - func TestICMPCounts(t *testing.T) { tests := []struct { name string @@ -1219,19 +1228,22 @@ func TestLinkAddressRequest(t *testing.T) { mcaddr := header.EthernetAddressFromMulticastIPv6Address(snaddr) tests := []struct { - name string - remoteLinkAddr tcpip.LinkAddress - expectLinkAddr tcpip.LinkAddress + name string + remoteLinkAddr tcpip.LinkAddress + expectedLinkAddr tcpip.LinkAddress + expectedAddr tcpip.Address }{ { - name: "Unicast", - remoteLinkAddr: linkAddr1, - expectLinkAddr: linkAddr1, + name: "Unicast", + remoteLinkAddr: linkAddr1, + expectedLinkAddr: linkAddr1, + expectedAddr: lladdr0, }, { - name: "Multicast", - remoteLinkAddr: "", - expectLinkAddr: mcaddr, + name: "Multicast", + remoteLinkAddr: "", + expectedLinkAddr: mcaddr, + expectedAddr: snaddr, }, } @@ -1254,9 +1266,229 @@ func TestLinkAddressRequest(t *testing.T) { if !ok { t.Fatal("expected to send a link address request") } + if pkt.Route.RemoteLinkAddress != test.expectedLinkAddr { + t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", pkt.Route.RemoteLinkAddress, test.expectedLinkAddr) + } + if pkt.Route.RemoteAddress != test.expectedAddr { + t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, test.expectedAddr) + } + if pkt.Route.LocalAddress != lladdr1 { + t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, lladdr1) + } + checker.IPv6(t, stack.PayloadSince(pkt.Pkt.NetworkHeader()), + checker.SrcAddr(lladdr1), + checker.DstAddr(test.expectedAddr), + checker.TTL(header.NDPHopLimit), + checker.NDPNS( + checker.NDPNSTargetAddress(lladdr0), + checker.NDPNSOptions([]header.NDPOption{header.NDPSourceLinkLayerAddressOption(linkAddr0)}), + )) + } +} + +func TestPacketQueing(t *testing.T) { + const nicID = 1 + + var ( + host1NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06") + host2NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09") - if got, want := pkt.Route.RemoteLinkAddress, test.expectLinkAddr; got != want { - t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", got, want) + host1IPv6Addr = tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("a::1").To16()), + PrefixLen: 64, + }, + } + host2IPv6Addr = tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("a::2").To16()), + PrefixLen: 64, + }, } + ) + + tests := []struct { + name string + rxPkt func(*channel.Endpoint) + checkResp func(*testing.T, *channel.Endpoint) + }{ + { + name: "ICMP Error", + rxPkt: func(e *channel.Endpoint) { + hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.UDPMinimumSize) + u := header.UDP(hdr.Prepend(header.UDPMinimumSize)) + u.Encode(&header.UDPFields{ + SrcPort: 5555, + DstPort: 80, + Length: header.UDPMinimumSize, + }) + sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, host2IPv6Addr.AddressWithPrefix.Address, host1IPv6Addr.AddressWithPrefix.Address, header.UDPMinimumSize) + sum = header.Checksum(header.UDP([]byte{}), sum) + u.SetChecksum(^u.CalculateChecksum(sum)) + payloadLength := hdr.UsedLength() + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: uint8(udp.ProtocolNumber), + HopLimit: DefaultTTL, + SrcAddr: host2IPv6Addr.AddressWithPrefix.Address, + DstAddr: host1IPv6Addr.AddressWithPrefix.Address, + }) + e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + })) + }, + checkResp: func(t *testing.T, e *channel.Endpoint) { + p, ok := e.ReadContext(context.Background()) + if !ok { + t.Fatalf("timed out waiting for packet") + } + if p.Proto != ProtocolNumber { + t.Errorf("got p.Proto = %d, want = %d", p.Proto, ProtocolNumber) + } + if p.Route.RemoteLinkAddress != host2NICLinkAddr { + t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr) + } + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address), + checker.DstAddr(host2IPv6Addr.AddressWithPrefix.Address), + checker.ICMPv6( + checker.ICMPv6Type(header.ICMPv6DstUnreachable), + checker.ICMPv6Code(header.ICMPv6PortUnreachable))) + }, + }, + + { + name: "Ping", + rxPkt: func(e *channel.Endpoint) { + totalLen := header.IPv6MinimumSize + header.ICMPv6MinimumSize + hdr := buffer.NewPrependable(totalLen) + pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize)) + pkt.SetType(header.ICMPv6EchoRequest) + pkt.SetCode(0) + pkt.SetChecksum(0) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, host2IPv6Addr.AddressWithPrefix.Address, host1IPv6Addr.AddressWithPrefix.Address, buffer.VectorisedView{})) + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: header.ICMPv6MinimumSize, + NextHeader: uint8(icmp.ProtocolNumber6), + HopLimit: DefaultTTL, + SrcAddr: host2IPv6Addr.AddressWithPrefix.Address, + DstAddr: host1IPv6Addr.AddressWithPrefix.Address, + }) + e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + })) + }, + checkResp: func(t *testing.T, e *channel.Endpoint) { + p, ok := e.ReadContext(context.Background()) + if !ok { + t.Fatalf("timed out waiting for packet") + } + if p.Proto != ProtocolNumber { + t.Errorf("got p.Proto = %d, want = %d", p.Proto, ProtocolNumber) + } + if p.Route.RemoteLinkAddress != host2NICLinkAddr { + t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr) + } + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address), + checker.DstAddr(host2IPv6Addr.AddressWithPrefix.Address), + checker.ICMPv6( + checker.ICMPv6Type(header.ICMPv6EchoReply), + checker.ICMPv6Code(header.ICMPv6UnusedCode))) + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + + e := channel.New(1, header.IPv6MinimumMTU, host1NICLinkAddr) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + }) + + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } + if err := s.AddProtocolAddress(nicID, host1IPv6Addr); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID, host1IPv6Addr, err) + } + + s.SetRouteTable([]tcpip.Route{ + tcpip.Route{ + Destination: host1IPv6Addr.AddressWithPrefix.Subnet(), + NIC: nicID, + }, + }) + + // Receive a packet to trigger link resolution before a response is sent. + test.rxPkt(e) + + // Wait for a neighbor solicitation since link address resolution should + // be performed. + { + p, ok := e.ReadContext(context.Background()) + if !ok { + t.Fatalf("timed out waiting for packet") + } + if p.Proto != ProtocolNumber { + t.Errorf("got Proto = %d, want = %d", p.Proto, ProtocolNumber) + } + snmc := header.SolicitedNodeAddr(host2IPv6Addr.AddressWithPrefix.Address) + if want := header.EthernetAddressFromMulticastIPv6Address(snmc); p.Route.RemoteLinkAddress != want { + t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, want) + } + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address), + checker.DstAddr(snmc), + checker.TTL(header.NDPHopLimit), + checker.NDPNS( + checker.NDPNSTargetAddress(host2IPv6Addr.AddressWithPrefix.Address), + checker.NDPNSOptions([]header.NDPOption{header.NDPSourceLinkLayerAddressOption(host1NICLinkAddr)}), + )) + } + + // Send a neighbor advertisement to complete link address resolution. + { + naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize + hdr := buffer.NewPrependable(header.IPv6MinimumSize + naSize) + pkt := header.ICMPv6(hdr.Prepend(naSize)) + pkt.SetType(header.ICMPv6NeighborAdvert) + na := header.NDPNeighborAdvert(pkt.NDPPayload()) + na.SetSolicitedFlag(true) + na.SetOverrideFlag(true) + na.SetTargetAddress(host2IPv6Addr.AddressWithPrefix.Address) + na.Options().Serialize(header.NDPOptionsSerializer{ + header.NDPTargetLinkLayerAddressOption(host2NICLinkAddr), + }) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, host2IPv6Addr.AddressWithPrefix.Address, host1IPv6Addr.AddressWithPrefix.Address, buffer.VectorisedView{})) + payloadLength := hdr.UsedLength() + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: uint8(icmp.ProtocolNumber6), + HopLimit: header.NDPHopLimit, + SrcAddr: host2IPv6Addr.AddressWithPrefix.Address, + DstAddr: host1IPv6Addr.AddressWithPrefix.Address, + }) + e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + })) + } + + // Expect the response now that the link address has resolved. + test.checkResp(t, e) + + // Since link resolution was already performed, it shouldn't be performed + // again. + test.rxPkt(e) + test.checkResp(t, e) + }) } } diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index aff4e1425..2bd8f4ece 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -1,4 +1,4 @@ -// Copyright 2018 The gVisor Authors. +// Copyright 2020 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -16,9 +16,12 @@ package ipv6 import ( + "encoding/binary" "fmt" + "hash/fnv" "sort" "sync/atomic" + "time" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -26,10 +29,20 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/header/parse" "gvisor.dev/gvisor/pkg/tcpip/network/fragmentation" + "gvisor.dev/gvisor/pkg/tcpip/network/hash" "gvisor.dev/gvisor/pkg/tcpip/stack" ) const ( + // As per RFC 8200 section 4.5: + // If insufficient fragments are received to complete reassembly of a packet + // within 60 seconds of the reception of the first-arriving fragment of that + // packet, reassembly of that packet must be abandoned. + // + // Linux also uses 60 seconds for reassembly timeout: + // https://github.com/torvalds/linux/blob/47ec5303d73ea344e84f46660fff693c57641386/include/net/ipv6.h#L456 + reassembleTimeout = 60 * time.Second + // ProtocolNumber is the ipv6 protocol number. ProtocolNumber = header.IPv6ProtocolNumber @@ -40,6 +53,9 @@ const ( // DefaultTTL is the default hop limit for IPv6 Packets egressed by // Netstack. DefaultTTL = 64 + + // buckets for fragment identifiers + buckets = 2048 ) var _ stack.GroupAddressableEndpoint = (*endpoint)(nil) @@ -50,7 +66,6 @@ var _ NDPEndpoint = (*endpoint)(nil) type endpoint struct { nic stack.NetworkInterface - linkEP stack.LinkEndpoint linkAddrCache stack.LinkAddressCache nud stack.NUDHandler dispatcher stack.TransportDispatcher @@ -348,21 +363,13 @@ func (e *endpoint) DefaultTTL() uint8 { // MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus // the network layer max header length. func (e *endpoint) MTU() uint32 { - return calculateMTU(e.linkEP.MTU()) + return calculateMTU(e.nic.MTU()) } // MaxHeaderLength returns the maximum length needed by ipv6 headers (and // underlying protocols). func (e *endpoint) MaxHeaderLength() uint16 { - return e.linkEP.MaxHeaderLength() + header.IPv6MinimumSize -} - -// GSOMaxSize returns the maximum GSO packet size. -func (e *endpoint) GSOMaxSize() uint32 { - if gso, ok := e.linkEP.(stack.GSOEndpoint); ok { - return gso.GSOMaxSize() - } - return 0 + return e.nic.MaxHeaderLength() + header.IPv6MinimumSize } func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams) { @@ -376,7 +383,44 @@ func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params s SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) - pkt.NetworkProtocolNumber = header.IPv6ProtocolNumber + pkt.NetworkProtocolNumber = ProtocolNumber +} + +func (e *endpoint) packetMustBeFragmented(pkt *stack.PacketBuffer, gso *stack.GSO) bool { + return pkt.Size() > int(e.nic.MTU()) && (gso == nil || gso.Type == stack.GSONone) +} + +// handleFragments fragments pkt and calls the handler function on each +// fragment. It returns the number of fragments handled and the number of +// fragments left to be processed. The IP header must already be present in the +// original packet. The mtu is the maximum size of the packets. The transport +// header protocol number is required to avoid parsing the IPv6 extension +// headers. +func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, mtu uint32, pkt *stack.PacketBuffer, transProto tcpip.TransportProtocolNumber, handler func(*stack.PacketBuffer) *tcpip.Error) (int, int, *tcpip.Error) { + fragMTU := int(calculateFragmentInnerMTU(mtu, pkt)) + if fragMTU < pkt.TransportHeader().View().Size() { + // As per RFC 8200 Section 4.5, the Transport Header is expected to be small + // enough to fit in the first fragment. + return 0, 1, tcpip.ErrMessageTooLong + } + + pf := fragmentation.MakePacketFragmenter(pkt, fragMTU, calculateFragmentReserve(pkt)) + id := atomic.AddUint32(&e.protocol.ids[hashRoute(r, e.protocol.hashIV)%buckets], 1) + networkHeader := header.IPv6(pkt.NetworkHeader().View()) + + var n int + for { + fragPkt, more := buildNextFragment(&pf, networkHeader, transProto, id) + if err := handler(fragPkt); err != nil { + return n, pf.RemainingFragmentCount() + 1, err + } + n++ + if !more { + break + } + } + + return n, 0, nil } // WritePacket writes a packet to the given destination address and protocol. @@ -402,7 +446,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw // short circuits broadcasts before they are sent out to other hosts. if pkt.NatDone { netHeader := header.IPv6(pkt.NetworkHeader().View()) - if ep, err := e.protocol.stack.FindNetworkEndpoint(header.IPv6ProtocolNumber, netHeader.DestinationAddress()); err == nil { + if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil { route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress()) ep.HandlePacket(&route, pkt) return nil @@ -423,14 +467,29 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw return nil } - if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { + if e.packetMustBeFragmented(pkt, gso) { + sent, remain, err := e.handleFragments(r, gso, e.nic.MTU(), pkt, params.Protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error { + // TODO(gvisor.dev/issue/3884): Evaluate whether we want to send each + // fragment one by one using WritePacket() (current strategy) or if we + // want to create a PacketBufferList from the fragments and feed it to + // WritePackets(). It'll be faster but cost more memory. + return e.nic.WritePacket(r, gso, ProtocolNumber, fragPkt) + }) + r.Stats().IP.PacketsSent.IncrementBy(uint64(sent)) + r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(remain)) return err } + + if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { + r.Stats().IP.OutgoingPacketErrors.Increment() + return err + } + r.Stats().IP.PacketsSent.Increment() return nil } -// WritePackets implements stack.LinkEndpoint.WritePackets. +// WritePackets implements stack.NetworkEndpoint.WritePackets. func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, *tcpip.Error) { if r.Loop&stack.PacketLoop != 0 { panic("not implemented") @@ -441,6 +500,23 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe for pb := pkts.Front(); pb != nil; pb = pb.Next() { e.addIPHeader(r, pb, params) + if e.packetMustBeFragmented(pb, gso) { + current := pb + _, _, err := e.handleFragments(r, gso, e.nic.MTU(), pb, params.Protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error { + // Modify the packet list in place with the new fragments. + pkts.InsertAfter(current, fragPkt) + current = current.Next() + return nil + }) + if err != nil { + r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len())) + return 0, err + } + // The fragmented packet can be released. The rest of the packets can be + // processed. + pkts.Remove(pb) + pb = current + } } // iptables filtering. All packets that reach here are locally @@ -451,8 +527,11 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe if len(dropped) == 0 && len(natPkts) == 0 { // Fast path: If no packets are to be dropped then we can just invoke the // faster WritePackets API directly. - n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber) + n, err := e.nic.WritePackets(r, gso, pkts, ProtocolNumber) r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) + if err != nil { + r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n)) + } return n, err } r.Stats().IP.IPTablesOutputDropped.IncrementBy(uint64(len(dropped))) @@ -466,7 +545,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe } if _, ok := natPkts[pkt]; ok { netHeader := header.IPv6(pkt.NetworkHeader().View()) - if ep, err := e.protocol.stack.FindNetworkEndpoint(header.IPv6ProtocolNumber, netHeader.DestinationAddress()); err == nil { + if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil { src := netHeader.SourceAddress() dst := netHeader.DestinationAddress() route := r.ReverseRoute(src, dst) @@ -475,8 +554,9 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe continue } } - if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { + if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) + r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n + len(dropped))) // Dropped packets aren't errors, so include them in // the return value. return n + len(dropped), err @@ -536,7 +616,10 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { return } - for firstHeader := true; ; firstHeader = false { + for { + // Keep track of the start of the previous header so we can report the + // special case of a Hop by Hop at a location other than at the start. + previousHeaderStart := it.HeaderOffset() extHdr, done, err := it.Next() if err != nil { r.Stats().IP.MalformedPacketsReceived.Increment() @@ -550,11 +633,11 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { case header.IPv6HopByHopOptionsExtHdr: // As per RFC 8200 section 4.1, the Hop By Hop extension header is // restricted to appear immediately after an IPv6 fixed header. - // - // TODO(b/152019344): Send an ICMPv6 Parameter Problem, Code 1 - // (unrecognized next header) error in response to an extension header's - // Next Header field with the Hop By Hop extension header identifier. - if !firstHeader { + if previousHeaderStart != 0 { + _ = e.protocol.returnError(r, &icmpReasonParameterProblem{ + code: header.ICMPv6UnknownHeader, + pointer: previousHeaderStart, + }, pkt) return } @@ -576,13 +659,25 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { case header.IPv6OptionUnknownActionSkip: case header.IPv6OptionUnknownActionDiscard: return - case header.IPv6OptionUnknownActionDiscardSendICMP: - // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for - // unrecognized IPv6 extension header options. - return case header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest: - // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for - // unrecognized IPv6 extension header options. + if header.IsV6MulticastAddress(r.LocalAddress) { + return + } + fallthrough + case header.IPv6OptionUnknownActionDiscardSendICMP: + // This case satisfies a requirement of RFC 8200 section 4.2 + // which states that an unknown option starting with bits [10] should: + // + // discard the packet and, regardless of whether or not the + // packet's Destination Address was a multicast address, send an + // ICMP Parameter Problem, Code 2, message to the packet's + // Source Address, pointing to the unrecognized Option Type. + // + _ = e.protocol.returnError(r, &icmpReasonParameterProblem{ + code: header.ICMPv6UnknownOption, + pointer: it.ParseOffset() + optsIt.OptionOffset(), + respondToMulticast: true, + }, pkt) return default: panic(fmt.Sprintf("unrecognized action for an unrecognized Hop By Hop extension header option = %d", opt)) @@ -593,16 +688,20 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // As per RFC 8200 section 4.4, if a node encounters a routing header with // an unrecognized routing type value, with a non-zero Segments Left // value, the node must discard the packet and send an ICMP Parameter - // Problem, Code 0. If the Segments Left is 0, the node must ignore the - // Routing extension header and process the next header in the packet. + // Problem, Code 0 to the packet's Source Address, pointing to the + // unrecognized Routing Type. + // + // If the Segments Left is 0, the node must ignore the Routing extension + // header and process the next header in the packet. // // Note, the stack does not yet handle any type of routing extension // header, so we just make sure Segments Left is zero before processing // the next extension header. - // - // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 0 for - // unrecognized routing types with a non-zero Segments Left value. if extHdr.SegmentsLeft() != 0 { + _ = e.protocol.returnError(r, &icmpReasonParameterProblem{ + code: header.ICMPv6ErroneousHeader, + pointer: it.ParseOffset(), + }, pkt) return } @@ -737,13 +836,25 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { case header.IPv6OptionUnknownActionSkip: case header.IPv6OptionUnknownActionDiscard: return - case header.IPv6OptionUnknownActionDiscardSendICMP: - // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for - // unrecognized IPv6 extension header options. - return case header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest: - // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for - // unrecognized IPv6 extension header options. + if header.IsV6MulticastAddress(r.LocalAddress) { + return + } + fallthrough + case header.IPv6OptionUnknownActionDiscardSendICMP: + // This case satisfies a requirement of RFC 8200 section 4.2 + // which states that an unknown option starting with bits [10] should: + // + // discard the packet and, regardless of whether or not the + // packet's Destination Address was a multicast address, send an + // ICMP Parameter Problem, Code 2, message to the packet's + // Source Address, pointing to the unrecognized Option Type. + // + _ = e.protocol.returnError(r, &icmpReasonParameterProblem{ + code: header.ICMPv6UnknownOption, + pointer: it.ParseOffset() + optsIt.OptionOffset(), + respondToMulticast: true, + }, pkt) return default: panic(fmt.Sprintf("unrecognized action for an unrecognized Destination extension header option = %d", opt)) @@ -767,8 +878,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { pkt.TransportProtocolNumber = p e.handleICMP(r, pkt, hasFragmentHeader) } else { - // TODO(b/152019344): Send an ICMPv6 Parameter Problem, Code 1 error - // in response to unrecognized next header values. + r.Stats().IP.PacketsDelivered.Increment() switch res := e.dispatcher.DeliverTransportPacket(r, p, pkt); res { case stack.TransportPacketHandled: case stack.TransportPacketDestinationPortUnreachable: @@ -777,18 +887,41 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // message with Code 4 in response to a packet for which the // transport protocol (e.g., UDP) has no listener, if that transport // protocol has no alternative means to inform the sender. - _ = returnError(r, &icmpReasonPortUnreachable{}, pkt) + _ = e.protocol.returnError(r, &icmpReasonPortUnreachable{}, pkt) + case stack.TransportPacketProtocolUnreachable: + // As per RFC 8200 section 4. (page 7): + // Extension headers are numbered from IANA IP Protocol Numbers + // [IANA-PN], the same values used for IPv4 and IPv6. When + // processing a sequence of Next Header values in a packet, the + // first one that is not an extension header [IANA-EH] indicates + // that the next item in the packet is the corresponding upper-layer + // header. + // With more related information on page 8: + // If, as a result of processing a header, the destination node is + // required to proceed to the next header but the Next Header value + // in the current header is unrecognized by the node, it should + // discard the packet and send an ICMP Parameter Problem message to + // the source of the packet, with an ICMP Code value of 1 + // ("unrecognized Next Header type encountered") and the ICMP + // Pointer field containing the offset of the unrecognized value + // within the original packet. + // + // Which when taken together indicate that an unknown protocol should + // be treated as an unrecognized next header value. + _ = e.protocol.returnError(r, &icmpReasonParameterProblem{ + code: header.ICMPv6UnknownHeader, + pointer: it.ParseOffset(), + }, pkt) default: panic(fmt.Sprintf("unrecognized result from DeliverTransportPacket = %d", res)) } } default: - // If we receive a packet for an extension header we do not yet handle, - // drop the packet for now. - // - // TODO(b/152019344): Send an ICMPv6 Parameter Problem, Code 1 error - // in response to unrecognized next header values. + _ = e.protocol.returnError(r, &icmpReasonParameterProblem{ + code: header.ICMPv6UnknownHeader, + pointer: it.ParseOffset(), + }, pkt) r.Stats().UnknownProtocolRcvdPackets.Increment() return } @@ -1097,6 +1230,9 @@ type protocol struct { eps map[*endpoint]struct{} } + ids []uint32 + hashIV uint32 + // defaultTTL is the current default TTL for the protocol. Only the // uint8 portion of it is meaningful. // @@ -1157,7 +1293,6 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.LinkAddressCache, nud stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint { e := &endpoint{ nic: nic, - linkEP: nic.LinkEndpoint(), linkAddrCache: linkAddrCache, nud: nud, dispatcher: dispatcher, @@ -1318,10 +1453,15 @@ type Options struct { func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory { opts.NDPConfigs.validate() + ids := hash.RandN32(buckets) + hashIV := hash.RandN32(1)[0] + return func(s *stack.Stack) stack.NetworkProtocol { p := &protocol{ stack: s, - fragmentation: fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout, s.Clock()), + fragmentation: fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, reassembleTimeout, s.Clock()), + ids: ids, + hashIV: hashIV, ndpDisp: opts.NDPDisp, ndpConfigs: opts.NDPConfigs, @@ -1339,3 +1479,73 @@ func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory { func NewProtocol(s *stack.Stack) stack.NetworkProtocol { return NewProtocolWithOptions(Options{})(s) } + +// calculateFragmentInnerMTU calculates the maximum number of bytes of +// fragmentable data a fragment can have, based on the link layer mtu and pkt's +// network header size. +func calculateFragmentInnerMTU(mtu uint32, pkt *stack.PacketBuffer) uint32 { + // TODO(gvisor.dev/issue/3912): Once the Authentication or ESP Headers are + // supported for outbound packets, their length should not affect the fragment + // MTU because they should only be transmitted once. + mtu -= uint32(pkt.NetworkHeader().View().Size()) + mtu -= header.IPv6FragmentHeaderSize + // Round the MTU down to align to 8 bytes. + mtu &^= 7 + if mtu <= maxPayloadSize { + return mtu + } + return maxPayloadSize +} + +func calculateFragmentReserve(pkt *stack.PacketBuffer) int { + return pkt.AvailableHeaderBytes() + pkt.NetworkHeader().View().Size() + header.IPv6FragmentHeaderSize +} + +// hashRoute calculates a hash value for the given route. It uses the source & +// destination address and 32-bit number to generate the hash. +func hashRoute(r *stack.Route, hashIV uint32) uint32 { + // The FNV-1a was chosen because it is a fast hashing algorithm, and + // cryptographic properties are not needed here. + h := fnv.New32a() + if _, err := h.Write([]byte(r.LocalAddress)); err != nil { + panic(fmt.Sprintf("Hash.Write: %s, but Hash' implementation of Write is not expected to ever return an error", err)) + } + + if _, err := h.Write([]byte(r.RemoteAddress)); err != nil { + panic(fmt.Sprintf("Hash.Write: %s, but Hash' implementation of Write is not expected to ever return an error", err)) + } + + s := make([]byte, 4) + binary.LittleEndian.PutUint32(s, hashIV) + if _, err := h.Write(s); err != nil { + panic(fmt.Sprintf("Hash.Write: %s, but Hash' implementation of Write is not expected ever to return an error", err)) + } + + return h.Sum32() +} + +func buildNextFragment(pf *fragmentation.PacketFragmenter, originalIPHeaders header.IPv6, transportProto tcpip.TransportProtocolNumber, id uint32) (*stack.PacketBuffer, bool) { + fragPkt, offset, copied, more := pf.BuildNextFragment() + fragPkt.NetworkProtocolNumber = ProtocolNumber + + originalIPHeadersLength := len(originalIPHeaders) + fragmentIPHeadersLength := originalIPHeadersLength + header.IPv6FragmentHeaderSize + fragmentIPHeaders := header.IPv6(fragPkt.NetworkHeader().Push(fragmentIPHeadersLength)) + + // Copy the IPv6 header and any extension headers already populated. + if copied := copy(fragmentIPHeaders, originalIPHeaders); copied != originalIPHeadersLength { + panic(fmt.Sprintf("wrong number of bytes copied into fragmentIPHeaders: got %d, want %d", copied, originalIPHeadersLength)) + } + fragmentIPHeaders.SetNextHeader(header.IPv6FragmentHeader) + fragmentIPHeaders.SetPayloadLength(uint16(copied + fragmentIPHeadersLength - header.IPv6MinimumSize)) + + fragmentHeader := header.IPv6Fragment(fragmentIPHeaders[originalIPHeadersLength:]) + fragmentHeader.Encode(&header.IPv6FragmentFields{ + M: more, + FragmentOffset: uint16(offset / header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit), + Identification: id, + NextHeader: uint8(transportProto), + }) + + return fragPkt, more +} diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index 94344057e..e792ca9e2 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -15,17 +15,21 @@ package ipv6 import ( + "encoding/hex" + "fmt" "math" "testing" "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/network/testutil" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/waiter" ) @@ -136,6 +140,82 @@ func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst } } +func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketBuffer, mtu uint32, wantFragments []fragmentInfo, proto tcpip.TransportProtocolNumber) error { + // sourcePacket does not have its IP Header populated. Let's copy the one + // from the first fragment. + source := header.IPv6(packets[0].NetworkHeader().View()) + sourceIPHeadersLen := len(source) + vv := buffer.NewVectorisedView(sourcePacket.Size(), sourcePacket.Views()) + source = append(source, vv.ToView()...) + + var reassembledPayload buffer.VectorisedView + for i, fragment := range packets { + // Confirm that the packet is valid. + allBytes := buffer.NewVectorisedView(fragment.Size(), fragment.Views()) + fragmentIPHeaders := header.IPv6(allBytes.ToView()) + if !fragmentIPHeaders.IsValid(len(fragmentIPHeaders)) { + return fmt.Errorf("fragment #%d: IP packet is invalid:\n%s", i, hex.Dump(fragmentIPHeaders)) + } + + fragmentIPHeadersLength := fragment.NetworkHeader().View().Size() + if fragmentIPHeadersLength != sourceIPHeadersLen { + return fmt.Errorf("fragment #%d: got fragmentIPHeadersLength = %d, want = %d", i, fragmentIPHeadersLength, sourceIPHeadersLen) + } + + if got := len(fragmentIPHeaders); got > int(mtu) { + return fmt.Errorf("fragment #%d: got len(fragmentIPHeaders) = %d, want <= %d", i, got, mtu) + } + + sourceIPHeader := source[:header.IPv6MinimumSize] + fragmentIPHeader := fragmentIPHeaders[:header.IPv6MinimumSize] + + if got := fragmentIPHeaders.PayloadLength(); got != wantFragments[i].payloadSize { + return fmt.Errorf("fragment #%d: got fragmentIPHeaders.PayloadLength() = %d, want = %d", i, got, wantFragments[i].payloadSize) + } + + // We expect the IPv6 Header to be similar across each fragment, besides the + // payload length. + sourceIPHeader.SetPayloadLength(0) + fragmentIPHeader.SetPayloadLength(0) + if diff := cmp.Diff(fragmentIPHeader, sourceIPHeader); diff != "" { + return fmt.Errorf("fragment #%d: fragmentIPHeader mismatch (-want +got):\n%s", i, diff) + } + + if fragment.NetworkProtocolNumber != sourcePacket.NetworkProtocolNumber { + return fmt.Errorf("fragment #%d: got fragment.NetworkProtocolNumber = %d, want = %d", i, fragment.NetworkProtocolNumber, sourcePacket.NetworkProtocolNumber) + } + + if len(packets) > 1 { + // If the source packet was big enough that it needed fragmentation, let's + // inspect the fragment header. Because no other extension headers are + // supported, it will always be the last extension header. + fragmentHeader := header.IPv6Fragment(fragmentIPHeaders[fragmentIPHeadersLength-header.IPv6FragmentHeaderSize : fragmentIPHeadersLength]) + + if got := fragmentHeader.More(); got != wantFragments[i].more { + return fmt.Errorf("fragment #%d: got fragmentHeader.More() = %t, want = %t", i, got, wantFragments[i].more) + } + if got := fragmentHeader.FragmentOffset(); got != wantFragments[i].offset { + return fmt.Errorf("fragment #%d: got fragmentHeader.FragmentOffset() = %d, want = %d", i, got, wantFragments[i].offset) + } + if got := fragmentHeader.NextHeader(); got != uint8(proto) { + return fmt.Errorf("fragment #%d: got fragmentHeader.NextHeader() = %d, want = %d", i, got, uint8(proto)) + } + } + + // Store the reassembled payload as we parse each fragment. The payload + // includes the Transport header and everything after. + reassembledPayload.AppendView(fragment.TransportHeader().View()) + reassembledPayload.Append(fragment.Data) + } + + result := reassembledPayload.ToView() + if diff := cmp.Diff(result, buffer.View(source[sourceIPHeadersLen:])); diff != "" { + return fmt.Errorf("reassembledPayload mismatch (-want +got):\n%s", diff) + } + + return nil +} + // TestReceiveOnAllNodesMulticastAddr tests that IPv6 endpoints receive ICMP and // UDP packets destined to the IPv6 link-local all-nodes multicast address. func TestReceiveOnAllNodesMulticastAddr(t *testing.T) { @@ -170,8 +250,6 @@ func TestReceiveOnAllNodesMulticastAddr(t *testing.T) { // packets destined to the IPv6 solicited-node address of an assigned IPv6 // address. func TestReceiveOnSolicitedNodeAddr(t *testing.T) { - const nicID = 1 - tests := []struct { name string protocolFactory stack.TransportProtocolFactory @@ -195,7 +273,7 @@ func TestReceiveOnSolicitedNodeAddr(t *testing.T) { } s.SetRouteTable([]tcpip.Route{ - tcpip.Route{ + { Destination: header.IPv6EmptySubnet, NIC: nicID, }, @@ -295,17 +373,22 @@ func TestAddIpv6Address(t *testing.T) { } func TestReceiveIPv6ExtHdrs(t *testing.T) { - const nicID = 1 - tests := []struct { name string extHdr func(nextHdr uint8) ([]byte, uint8) shouldAccept bool + // Should we expect an ICMP response and if so, with what contents? + expectICMP bool + ICMPType header.ICMPv6Type + ICMPCode header.ICMPv6Code + pointer uint32 + multicast bool }{ { name: "None", extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{}, nextHdr }, shouldAccept: true, + expectICMP: false, }, { name: "hopbyhop with unknown option skippable action", @@ -336,9 +419,30 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { }, hopByHopExtHdrID }, shouldAccept: false, + expectICMP: false, + }, + { + name: "hopbyhop with unknown option discard and send icmp action (unicast)", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + nextHdr, 1, + + // Skippable unknown. + 63, 4, 1, 2, 3, 4, + + // Discard & send ICMP if option is unknown. + 191, 6, 1, 2, 3, 4, 5, 6, + //^ Unknown option. + }, hopByHopExtHdrID + }, + shouldAccept: false, + expectICMP: true, + ICMPType: header.ICMPv6ParamProblem, + ICMPCode: header.ICMPv6UnknownOption, + pointer: header.IPv6FixedHeaderSize + 8, }, { - name: "hopbyhop with unknown option discard and send icmp action", + name: "hopbyhop with unknown option discard and send icmp action (multicast)", extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{ nextHdr, 1, @@ -348,12 +452,18 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { // Discard & send ICMP if option is unknown. 191, 6, 1, 2, 3, 4, 5, 6, + //^ Unknown option. }, hopByHopExtHdrID }, + multicast: true, shouldAccept: false, + expectICMP: true, + ICMPType: header.ICMPv6ParamProblem, + ICMPCode: header.ICMPv6UnknownOption, + pointer: header.IPv6FixedHeaderSize + 8, }, { - name: "hopbyhop with unknown option discard and send icmp action unless multicast dest", + name: "hopbyhop with unknown option discard and send icmp action unless multicast dest (unicast)", extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{ nextHdr, 1, @@ -364,39 +474,97 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { // Discard & send ICMP unless packet is for multicast destination if // option is unknown. 255, 6, 1, 2, 3, 4, 5, 6, + //^ Unknown option. }, hopByHopExtHdrID }, + expectICMP: true, + ICMPType: header.ICMPv6ParamProblem, + ICMPCode: header.ICMPv6UnknownOption, + pointer: header.IPv6FixedHeaderSize + 8, + }, + { + name: "hopbyhop with unknown option discard and send icmp action unless multicast dest (multicast)", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + nextHdr, 1, + + // Skippable unknown. + 63, 4, 1, 2, 3, 4, + + // Discard & send ICMP unless packet is for multicast destination if + // option is unknown. + 255, 6, 1, 2, 3, 4, 5, 6, + //^ Unknown option. + }, hopByHopExtHdrID + }, + multicast: true, shouldAccept: false, + expectICMP: false, }, { - name: "routing with zero segments left", - extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 1, 0, 2, 3, 4, 5}, routingExtHdrID }, + name: "routing with zero segments left", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + nextHdr, 0, + 1, 0, 2, 3, 4, 5, + }, routingExtHdrID + }, shouldAccept: true, }, { - name: "routing with non-zero segments left", - extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 1, 1, 2, 3, 4, 5}, routingExtHdrID }, + name: "routing with non-zero segments left", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + nextHdr, 0, + 1, 1, 2, 3, 4, 5, + }, routingExtHdrID + }, shouldAccept: false, + expectICMP: true, + ICMPType: header.ICMPv6ParamProblem, + ICMPCode: header.ICMPv6ErroneousHeader, + pointer: header.IPv6FixedHeaderSize + 2, }, { - name: "atomic fragment with zero ID", - extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 0, 0, 0, 0, 0, 0}, fragmentExtHdrID }, + name: "atomic fragment with zero ID", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + nextHdr, 0, + 0, 0, 0, 0, 0, 0, + }, fragmentExtHdrID + }, shouldAccept: true, }, { - name: "atomic fragment with non-zero ID", - extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 0, 0, 1, 2, 3, 4}, fragmentExtHdrID }, + name: "atomic fragment with non-zero ID", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + nextHdr, 0, + 0, 0, 1, 2, 3, 4, + }, fragmentExtHdrID + }, shouldAccept: true, + expectICMP: false, }, { - name: "fragment", - extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 1, 0, 1, 2, 3, 4}, fragmentExtHdrID }, + name: "fragment", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + nextHdr, 0, + 1, 0, 1, 2, 3, 4, + }, fragmentExtHdrID + }, shouldAccept: false, + expectICMP: false, }, { - name: "No next header", - extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{}, noNextHdrID }, + name: "No next header", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{}, + noNextHdrID + }, shouldAccept: false, + expectICMP: false, }, { name: "destination with unknown option skippable action", @@ -412,6 +580,7 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { }, destinationExtHdrID }, shouldAccept: true, + expectICMP: false, }, { name: "destination with unknown option discard action", @@ -427,9 +596,10 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { }, destinationExtHdrID }, shouldAccept: false, + expectICMP: false, }, { - name: "destination with unknown option discard and send icmp action", + name: "destination with unknown option discard and send icmp action (unicast)", extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{ nextHdr, 1, @@ -439,12 +609,38 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { // Discard & send ICMP if option is unknown. 191, 6, 1, 2, 3, 4, 5, 6, + //^ 191 is an unknown option. }, destinationExtHdrID }, shouldAccept: false, + expectICMP: true, + ICMPType: header.ICMPv6ParamProblem, + ICMPCode: header.ICMPv6UnknownOption, + pointer: header.IPv6FixedHeaderSize + 8, }, { - name: "destination with unknown option discard and send icmp action unless multicast dest", + name: "destination with unknown option discard and send icmp action (muilticast)", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + nextHdr, 1, + + // Skippable unknown. + 63, 4, 1, 2, 3, 4, + + // Discard & send ICMP if option is unknown. + 191, 6, 1, 2, 3, 4, 5, 6, + //^ 191 is an unknown option. + }, destinationExtHdrID + }, + multicast: true, + shouldAccept: false, + expectICMP: true, + ICMPType: header.ICMPv6ParamProblem, + ICMPCode: header.ICMPv6UnknownOption, + pointer: header.IPv6FixedHeaderSize + 8, + }, + { + name: "destination with unknown option discard and send icmp action unless multicast dest (unicast)", extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{ nextHdr, 1, @@ -455,22 +651,33 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { // Discard & send ICMP unless packet is for multicast destination if // option is unknown. 255, 6, 1, 2, 3, 4, 5, 6, + //^ 255 is unknown. }, destinationExtHdrID }, shouldAccept: false, + expectICMP: true, + ICMPType: header.ICMPv6ParamProblem, + ICMPCode: header.ICMPv6UnknownOption, + pointer: header.IPv6FixedHeaderSize + 8, }, { - name: "routing - atomic fragment", + name: "destination with unknown option discard and send icmp action unless multicast dest (multicast)", extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{ - // Routing extension header. - fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5, + nextHdr, 1, - // Fragment extension header. - nextHdr, 0, 0, 0, 1, 2, 3, 4, - }, routingExtHdrID + // Skippable unknown. + 63, 4, 1, 2, 3, 4, + + // Discard & send ICMP unless packet is for multicast destination if + // option is unknown. + 255, 6, 1, 2, 3, 4, 5, 6, + //^ 255 is unknown. + }, destinationExtHdrID }, - shouldAccept: true, + shouldAccept: false, + expectICMP: false, + multicast: true, }, { name: "atomic fragment - routing", @@ -504,12 +711,42 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { return []byte{ // Routing extension header. hopByHopExtHdrID, 0, 1, 0, 2, 3, 4, 5, + // ^^^ The HopByHop extension header may not appear after the first + // extension header. // Hop By Hop extension header with skippable unknown option. nextHdr, 0, 62, 4, 1, 2, 3, 4, }, routingExtHdrID }, shouldAccept: false, + expectICMP: true, + ICMPType: header.ICMPv6ParamProblem, + ICMPCode: header.ICMPv6UnknownHeader, + pointer: header.IPv6FixedHeaderSize, + }, + { + name: "routing - hop by hop (with send icmp unknown)", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + // Routing extension header. + hopByHopExtHdrID, 0, 1, 0, 2, 3, 4, 5, + // ^^^ The HopByHop extension header may not appear after the first + // extension header. + + nextHdr, 1, + + // Skippable unknown. + 63, 4, 1, 2, 3, 4, + + // Skippable unknown. + 191, 6, 1, 2, 3, 4, 5, 6, + }, routingExtHdrID + }, + shouldAccept: false, + expectICMP: true, + ICMPType: header.ICMPv6ParamProblem, + ICMPCode: header.ICMPv6UnknownHeader, + pointer: header.IPv6FixedHeaderSize, }, { name: "No next header", @@ -553,6 +790,7 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { }, hopByHopExtHdrID }, shouldAccept: false, + expectICMP: false, }, { name: "hopbyhop (with skippable unknown) - routing - atomic fragment - destination (with discard unknown)", @@ -573,6 +811,7 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { }, hopByHopExtHdrID }, shouldAccept: false, + expectICMP: false, }, } @@ -582,7 +821,7 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) - e := channel.New(0, 1280, linkAddr1) + e := channel.New(1, 1280, linkAddr1) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } @@ -590,6 +829,14 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err) } + // Add a default route so that a return packet knows where to go. + s.SetRouteTable([]tcpip.Route{ + { + Destination: header.IPv6EmptySubnet, + NIC: nicID, + }, + }) + wq := waiter.Queue{} we, ch := waiter.NewChannelEntry(nil) wq.EventRegister(&we, waiter.EventIn) @@ -631,12 +878,16 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { // Serialize IPv6 fixed header. payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + dstAddr := tcpip.Address(addr2) + if test.multicast { + dstAddr = header.IPv6AllNodesMulticastAddress + } ip.Encode(&header.IPv6Fields{ PayloadLength: uint16(payloadLength), NextHeader: ipv6NextHdr, HopLimit: 255, SrcAddr: addr1, - DstAddr: addr2, + DstAddr: dstAddr, }) e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -650,6 +901,44 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { t.Errorf("got UDP Rx Packets = %d, want = 0", got) } + if !test.expectICMP { + if p, ok := e.Read(); ok { + t.Fatalf("unexpected packet received: %#v", p) + } + return + } + + // ICMP required. + p, ok := e.Read() + if !ok { + t.Fatalf("expected packet wasn't written out") + } + + // Pack the output packet into a single buffer.View as the checkers + // assume that. + vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) + pkt := vv.ToView() + if got, want := len(pkt), header.IPv6FixedHeaderSize+header.ICMPv6MinimumSize+hdr.UsedLength(); got != want { + t.Fatalf("got an ICMP packet of size = %d, want = %d", got, want) + } + + ipHdr := header.IPv6(pkt) + checker.IPv6(t, ipHdr, checker.ICMPv6( + checker.ICMPv6Type(test.ICMPType), + checker.ICMPv6Code(test.ICMPCode))) + + // We know we are looking at no extension headers in the error ICMP + // packets. + icmpPkt := header.ICMPv6(ipHdr.Payload()) + // We know we sent small packets that won't be truncated when reflected + // back to us. + originalPacket := icmpPkt.Payload() + if got, want := icmpPkt.TypeSpecific(), test.pointer; got != want { + t.Errorf("unexpected ICMPv6 pointer, got = %d, want = %d\n", got, want) + } + if diff := cmp.Diff(hdr.View(), buffer.View(originalPacket)); diff != "" { + t.Errorf("ICMPv6 payload mismatch (-want +got):\n%s", diff) + } return } @@ -683,7 +972,6 @@ type fragmentData struct { func TestReceiveIPv6Fragments(t *testing.T) { const ( - nicID = 1 udpPayload1Length = 256 udpPayload2Length = 128 // Used to test cases where the fragment blocks are not a multiple of @@ -1815,7 +2103,6 @@ func TestWriteStats(t *testing.T) { t.Run(test.name, func(t *testing.T) { ep := testutil.NewMockLinkEndpoint(header.IPv6MinimumMTU, tcpip.ErrInvalidEndpointState, test.allowPackets) rt := buildRoute(t, ep) - var pkts stack.PacketBufferList for i := 0; i < nPackets; i++ { pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -1857,12 +2144,13 @@ func buildRoute(t *testing.T, ep stack.LinkEndpoint) stack.Route { dst = "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" ) if err := s.AddAddress(1, ProtocolNumber, src); err != nil { - t.Fatalf("AddAddress(1, %d, _) failed: %s", ProtocolNumber, err) + t.Fatalf("AddAddress(1, %d, %s) failed: %s", ProtocolNumber, src, err) } { - subnet, err := tcpip.NewSubnet(dst, tcpip.AddressMask("\xfc\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff")) + mask := tcpip.AddressMask("\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff") + subnet, err := tcpip.NewSubnet(dst, mask) if err != nil { - t.Fatalf("NewSubnet(_, _) failed: %v", err) + t.Fatalf("NewSubnet(%s, %s) failed: %v", dst, mask, err) } s.SetRouteTable([]tcpip.Route{{ Destination: subnet, @@ -1871,7 +2159,7 @@ func buildRoute(t *testing.T, ep stack.LinkEndpoint) stack.Route { } rt, err := s.FindRoute(1, src, dst, ProtocolNumber, false /* multicastLoop */) if err != nil { - t.Fatalf("got FindRoute(1, _, _, %d, false) = %s, want = nil", ProtocolNumber, err) + t.Fatalf("FindRoute(1, %s, %s, %d, false) = %s, want = nil", src, dst, ProtocolNumber, err) } return rt } @@ -1922,3 +2210,293 @@ func TestClearEndpointFromProtocolOnClose(t *testing.T) { } } } + +type fragmentInfo struct { + offset uint16 + more bool + payloadSize uint16 +} + +type fragmentationTestCase struct { + description string + mtu uint32 + gso *stack.GSO + transHdrLen int + extraHdrLen int + payloadSize int + wantFragments []fragmentInfo + expectedFrags int +} + +var fragmentationTests = []fragmentationTestCase{ + { + description: "No Fragmentation", + mtu: 1280, + gso: &stack.GSO{}, + transHdrLen: 0, + extraHdrLen: header.IPv6MinimumSize, + payloadSize: 1000, + wantFragments: []fragmentInfo{ + {offset: 0, payloadSize: 1000, more: false}, + }, + }, + { + description: "Fragmented", + mtu: 1280, + gso: &stack.GSO{}, + transHdrLen: 0, + extraHdrLen: header.IPv6MinimumSize, + payloadSize: 2000, + wantFragments: []fragmentInfo{ + {offset: 0, payloadSize: 1240, more: true}, + {offset: 154, payloadSize: 776, more: false}, + }, + }, + { + description: "No fragmentation with big header", + mtu: 2000, + gso: &stack.GSO{}, + transHdrLen: 100, + extraHdrLen: header.IPv6MinimumSize, + payloadSize: 1000, + wantFragments: []fragmentInfo{ + {offset: 0, payloadSize: 1100, more: false}, + }, + }, + { + description: "Fragmented with gso nil", + mtu: 1280, + gso: nil, + transHdrLen: 0, + extraHdrLen: header.IPv6MinimumSize, + payloadSize: 1400, + wantFragments: []fragmentInfo{ + {offset: 0, payloadSize: 1240, more: true}, + {offset: 154, payloadSize: 176, more: false}, + }, + }, + { + description: "Fragmented with big header", + mtu: 1280, + gso: &stack.GSO{}, + transHdrLen: 100, + extraHdrLen: header.IPv6MinimumSize, + payloadSize: 1200, + wantFragments: []fragmentInfo{ + {offset: 0, payloadSize: 1240, more: true}, + {offset: 154, payloadSize: 76, more: false}, + }, + }, + { + description: "Fragmented with big header and prependable bytes", + mtu: 1280, + gso: &stack.GSO{}, + transHdrLen: 20, + extraHdrLen: header.IPv6MinimumSize + 66, + payloadSize: 1500, + wantFragments: []fragmentInfo{ + {offset: 0, payloadSize: 1240, more: true}, + {offset: 154, payloadSize: 296, more: false}, + }, + }, +} + +func TestFragmentation(t *testing.T) { + const ( + ttl = 42 + tos = stack.DefaultTOS + transportProto = tcp.ProtocolNumber + ) + + for _, ft := range fragmentationTests { + t.Run(ft.description, func(t *testing.T) { + pkt := testutil.MakeRandPkt(ft.transHdrLen, ft.extraHdrLen, []int{ft.payloadSize}, header.IPv6ProtocolNumber) + source := pkt.Clone() + ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) + r := buildRoute(t, ep) + err := r.WritePacket(ft.gso, stack.NetworkHeaderParams{ + Protocol: tcp.ProtocolNumber, + TTL: ttl, + TOS: stack.DefaultTOS, + }, pkt) + if err != nil { + t.Fatalf("WritePacket(_, _, _): = %s", err) + } + if got := len(ep.WrittenPackets); got != len(ft.wantFragments) { + t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, len(ft.wantFragments)) + } + if got := int(r.Stats().IP.PacketsSent.Value()); got != len(ft.wantFragments) { + t.Errorf("got c.Route.Stats().IP.PacketsSent.Value() = %d, want = %d", got, len(ft.wantFragments)) + } + if got := r.Stats().IP.OutgoingPacketErrors.Value(); got != 0 { + t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got) + } + if len(ep.WrittenPackets) > 0 { + if err := compareFragments(ep.WrittenPackets, source, ft.mtu, ft.wantFragments, tcp.ProtocolNumber); err != nil { + t.Error(err) + } + } + }) + } +} + +func TestFragmentationWritePackets(t *testing.T) { + const ttl = 42 + tests := []struct { + description string + insertBefore int + insertAfter int + }{ + { + description: "Single packet", + insertBefore: 0, + insertAfter: 0, + }, + { + description: "With packet before", + insertBefore: 1, + insertAfter: 0, + }, + { + description: "With packet after", + insertBefore: 0, + insertAfter: 1, + }, + { + description: "With packet before and after", + insertBefore: 1, + insertAfter: 1, + }, + } + tinyPacket := testutil.MakeRandPkt(header.TCPMinimumSize, header.IPv6MinimumSize, []int{1}, header.IPv6ProtocolNumber) + + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { + for _, ft := range fragmentationTests { + t.Run(ft.description, func(t *testing.T) { + var pkts stack.PacketBufferList + for i := 0; i < test.insertBefore; i++ { + pkts.PushBack(tinyPacket.Clone()) + } + pkt := testutil.MakeRandPkt(ft.transHdrLen, ft.extraHdrLen, []int{ft.payloadSize}, header.IPv6ProtocolNumber) + source := pkt + pkts.PushBack(pkt.Clone()) + for i := 0; i < test.insertAfter; i++ { + pkts.PushBack(tinyPacket.Clone()) + } + + ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) + r := buildRoute(t, ep) + + wantTotalPackets := len(ft.wantFragments) + test.insertBefore + test.insertAfter + n, err := r.WritePackets(ft.gso, pkts, stack.NetworkHeaderParams{ + Protocol: tcp.ProtocolNumber, + TTL: ttl, + TOS: stack.DefaultTOS, + }) + if n != wantTotalPackets || err != nil { + t.Errorf("got WritePackets(_, _, _) = (%d, %s), want = (%d, nil)", n, err, wantTotalPackets) + } + if got := len(ep.WrittenPackets); got != wantTotalPackets { + t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, wantTotalPackets) + } + if got := int(r.Stats().IP.PacketsSent.Value()); got != wantTotalPackets { + t.Errorf("got c.Route.Stats().IP.PacketsSent.Value() = %d, want = %d", got, wantTotalPackets) + } + if got := r.Stats().IP.OutgoingPacketErrors.Value(); got != 0 { + t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got) + } + + if wantTotalPackets == 0 { + return + } + + fragments := ep.WrittenPackets[test.insertBefore : len(ft.wantFragments)+test.insertBefore] + if err := compareFragments(fragments, source, ft.mtu, ft.wantFragments, tcp.ProtocolNumber); err != nil { + t.Error(err) + } + }) + } + }) + } +} + +// TestFragmentationErrors checks that errors are returned from WritePacket +// correctly. +func TestFragmentationErrors(t *testing.T) { + const ttl = 42 + + tests := []struct { + description string + mtu uint32 + transHdrLen int + payloadSize int + allowPackets int + outgoingErrors int + mockError *tcpip.Error + wantError *tcpip.Error + }{ + { + description: "No frag", + mtu: 2000, + payloadSize: 1000, + transHdrLen: 0, + allowPackets: 0, + outgoingErrors: 1, + mockError: tcpip.ErrAborted, + wantError: tcpip.ErrAborted, + }, + { + description: "Error on first frag", + mtu: 1300, + payloadSize: 3000, + transHdrLen: 0, + allowPackets: 0, + outgoingErrors: 3, + mockError: tcpip.ErrAborted, + wantError: tcpip.ErrAborted, + }, + { + description: "Error on second frag", + mtu: 1500, + payloadSize: 4000, + transHdrLen: 0, + allowPackets: 1, + outgoingErrors: 2, + mockError: tcpip.ErrAborted, + wantError: tcpip.ErrAborted, + }, + { + description: "Error on packet with MTU smaller than transport header", + mtu: 1280, + transHdrLen: 1500, + payloadSize: 500, + allowPackets: 0, + outgoingErrors: 1, + mockError: nil, + wantError: tcpip.ErrMessageTooLong, + }, + } + + for _, ft := range tests { + t.Run(ft.description, func(t *testing.T) { + pkt := testutil.MakeRandPkt(ft.transHdrLen, header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber) + ep := testutil.NewMockLinkEndpoint(ft.mtu, ft.mockError, ft.allowPackets) + r := buildRoute(t, ep) + err := r.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{ + Protocol: tcp.ProtocolNumber, + TTL: ttl, + TOS: stack.DefaultTOS, + }, pkt) + if err != ft.wantError { + t.Errorf("got WritePacket(_, _, _) = %s, want = %s", err, ft.wantError) + } + if got := int(r.Stats().IP.PacketsSent.Value()); got != ft.allowPackets { + t.Errorf("got r.Stats().IP.PacketsSent.Value() = %d, want = %d", got, ft.allowPackets) + } + if got := int(r.Stats().IP.OutgoingPacketErrors.Value()); got != ft.outgoingErrors { + t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = %d", got, ft.outgoingErrors) + } + }) + } +} diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go index 48a4c65e3..40da011f8 100644 --- a/pkg/tcpip/network/ipv6/ndp.go +++ b/pkg/tcpip/network/ipv6/ndp.go @@ -1289,7 +1289,7 @@ func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixSt // // TODO(b/141011931): Validate a LinkEndpoint's link address (provided by // LinkEndpoint.LinkAddress) before reaching this point. - linkAddr := ndp.ep.linkEP.LinkAddress() + linkAddr := ndp.ep.nic.LinkAddress() if !header.IsValidUnicastEthernetAddress(linkAddr) { return false } diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index 25464a03a..9033a9ed5 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -410,7 +410,7 @@ func TestNeighorSolicitationResponse(t *testing.T) { naDst tcpip.Address }{ { - name: "Unspecified source to multicast destination", + name: "Unspecified source to solicited-node multicast destination", nsOpts: nil, nsSrcLinkAddr: remoteLinkAddr0, nsSrc: header.IPv6Any, @@ -437,11 +437,7 @@ func TestNeighorSolicitationResponse(t *testing.T) { nsSrcLinkAddr: remoteLinkAddr0, nsSrc: header.IPv6Any, nsDst: nicAddr, - nsInvalid: false, - naDstLinkAddr: remoteLinkAddr0, - naSolicited: false, - naSrc: nicAddr, - naDst: header.IPv6AllNodesMulticastAddress, + nsInvalid: true, }, { name: "Unspecified source with source ll option to unicast destination", diff --git a/pkg/tcpip/network/testutil/BUILD b/pkg/tcpip/network/testutil/BUILD index c9e57dc0d..d0ffc299a 100644 --- a/pkg/tcpip/network/testutil/BUILD +++ b/pkg/tcpip/network/testutil/BUILD @@ -8,6 +8,7 @@ go_library( "testutil.go", ], visibility = [ + "//pkg/tcpip/network/fragmentation:__pkg__", "//pkg/tcpip/network/ipv4:__pkg__", "//pkg/tcpip/network/ipv6:__pkg__", ], diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 2eaeab779..eba97334e 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -56,7 +56,6 @@ go_library( srcs = [ "addressable_endpoint_state.go", "conntrack.go", - "forwarder.go", "headertype_string.go", "icmp_rate_limit.go", "iptables.go", @@ -73,6 +72,7 @@ go_library( "nud.go", "packet_buffer.go", "packet_buffer_list.go", + "pending_packets.go", "rand.go", "registration.go", "route.go", @@ -123,7 +123,6 @@ go_test( "//pkg/tcpip/header", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/loopback", - "//pkg/tcpip/network/arp", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", "//pkg/tcpip/ports", @@ -139,7 +138,7 @@ go_test( name = "stack_test", size = "small", srcs = [ - "forwarder_test.go", + "forwarding_test.go", "linkaddrcache_test.go", "neighbor_cache_test.go", "neighbor_entry_test.go", diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go index db8ac1c2b..4d3acab96 100644 --- a/pkg/tcpip/stack/addressable_endpoint_state.go +++ b/pkg/tcpip/stack/addressable_endpoint_state.go @@ -679,11 +679,6 @@ type addressState struct { } } -// NetworkEndpoint implements AddressEndpoint. -func (a *addressState) NetworkEndpoint() NetworkEndpoint { - return a.addressableEndpointState.networkEndpoint -} - // AddressWithPrefix implements AddressEndpoint. func (a *addressState) AddressWithPrefix() tcpip.AddressWithPrefix { return a.addr diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index 91ae7dafc..0cd1da11f 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -196,13 +196,14 @@ type bucket struct { // packetToTupleID converts packet to a tuple ID. It fails when pkt lacks a valid // TCP header. +// +// Preconditions: pkt.NetworkHeader() is valid. func packetToTupleID(pkt *PacketBuffer) (tupleID, *tcpip.Error) { - // TODO(gvisor.dev/issue/170): Need to support for other - // protocols as well. - netHeader := header.IPv4(pkt.NetworkHeader().View()) - if len(netHeader) < header.IPv4MinimumSize || netHeader.TransportProtocol() != header.TCPProtocolNumber { + netHeader := pkt.Network() + if netHeader.TransportProtocol() != header.TCPProtocolNumber { return tupleID{}, tcpip.ErrUnknownProtocol } + tcpHeader := header.TCP(pkt.TransportHeader().View()) if len(tcpHeader) < header.TCPMinimumSize { return tupleID{}, tcpip.ErrUnknownProtocol @@ -214,7 +215,7 @@ func packetToTupleID(pkt *PacketBuffer) (tupleID, *tcpip.Error) { dstAddr: netHeader.DestinationAddress(), dstPort: tcpHeader.DestinationPort(), transProto: netHeader.TransportProtocol(), - netProto: header.IPv4ProtocolNumber, + netProto: pkt.NetworkProtocolNumber, }, nil } @@ -268,7 +269,7 @@ func (ct *ConnTrack) connForTID(tid tupleID) (*conn, direction) { return nil, dirOriginal } -func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, rt RedirectTarget) *conn { +func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, rt *RedirectTarget) *conn { tid, err := packetToTupleID(pkt) if err != nil { return nil @@ -344,7 +345,7 @@ func handlePacketPrerouting(pkt *PacketBuffer, conn *conn, dir direction) { return } - netHeader := header.IPv4(pkt.NetworkHeader().View()) + netHeader := pkt.Network() tcpHeader := header.TCP(pkt.TransportHeader().View()) // For prerouting redirection, packets going in the original direction @@ -366,8 +367,12 @@ func handlePacketPrerouting(pkt *PacketBuffer, conn *conn, dir direction) { // support cases when they are validated, e.g. when we can't offload // receive checksumming. - netHeader.SetChecksum(0) - netHeader.SetChecksum(^netHeader.CalculateChecksum()) + // After modification, IPv4 packets need a valid checksum. + if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber { + netHeader := header.IPv4(pkt.NetworkHeader().View()) + netHeader.SetChecksum(0) + netHeader.SetChecksum(^netHeader.CalculateChecksum()) + } } // handlePacketOutput manipulates ports for packets in Output hook. @@ -377,7 +382,7 @@ func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir d return } - netHeader := header.IPv4(pkt.NetworkHeader().View()) + netHeader := pkt.Network() tcpHeader := header.TCP(pkt.TransportHeader().View()) // For output redirection, packets going in the original direction @@ -396,7 +401,7 @@ func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir d // Calculate the TCP checksum and set it. tcpHeader.SetChecksum(0) - length := uint16(pkt.Size()) - uint16(netHeader.HeaderLength()) + length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View())) xsum := r.PseudoHeaderChecksum(header.TCPProtocolNumber, length) if gso != nil && gso.NeedsCsum { tcpHeader.SetChecksum(xsum) @@ -405,8 +410,11 @@ func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir d tcpHeader.SetChecksum(^tcpHeader.CalculateChecksum(xsum)) } - netHeader.SetChecksum(0) - netHeader.SetChecksum(^netHeader.CalculateChecksum()) + if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber { + netHeader := header.IPv4(pkt.NetworkHeader().View()) + netHeader.SetChecksum(0) + netHeader.SetChecksum(^netHeader.CalculateChecksum()) + } } // handlePacket will manipulate the port and address of the packet if the @@ -422,7 +430,7 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Rou } // TODO(gvisor.dev/issue/170): Support other transport protocols. - if nh := pkt.NetworkHeader().View(); nh.IsEmpty() || header.IPv4(nh).TransportProtocol() != header.TCPProtocolNumber { + if pkt.Network().TransportProtocol() != header.TCPProtocolNumber { return false } @@ -473,7 +481,7 @@ func (ct *ConnTrack) maybeInsertNoop(pkt *PacketBuffer, hook Hook) { } // We only track TCP connections. - if nh := pkt.NetworkHeader().View(); nh.IsEmpty() || header.IPv4(nh).TransportProtocol() != header.TCPProtocolNumber { + if pkt.Network().TransportProtocol() != header.TCPProtocolNumber { return } @@ -609,7 +617,7 @@ func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bucket int, now time.Time) bo return true } -func (ct *ConnTrack) originalDst(epID TransportEndpointID) (tcpip.Address, uint16, *tcpip.Error) { +func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, *tcpip.Error) { // Lookup the connection. The reply's original destination // describes the original address. tid := tupleID{ @@ -618,7 +626,7 @@ func (ct *ConnTrack) originalDst(epID TransportEndpointID) (tcpip.Address, uint1 dstAddr: epID.RemoteAddress, dstPort: epID.RemotePort, transProto: header.TCPProtocolNumber, - netProto: header.IPv4ProtocolNumber, + netProto: netProto, } conn, _ := ct.connForTID(tid) if conn == nil { diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarding_test.go index 4e4b00a92..cf042309e 100644 --- a/pkg/tcpip/stack/forwarder_test.go +++ b/pkg/tcpip/stack/forwarding_test.go @@ -48,10 +48,9 @@ const ( type fwdTestNetworkEndpoint struct { AddressableEndpointState - nicID tcpip.NICID + nic NetworkInterface proto *fwdTestNetworkProtocol dispatcher TransportDispatcher - ep LinkEndpoint } var _ NetworkEndpoint = (*fwdTestNetworkEndpoint)(nil) @@ -67,7 +66,7 @@ func (*fwdTestNetworkEndpoint) Enabled() bool { func (*fwdTestNetworkEndpoint) Disable() {} func (f *fwdTestNetworkEndpoint) MTU() uint32 { - return f.ep.MTU() - uint32(f.MaxHeaderLength()) + return f.nic.MTU() - uint32(f.MaxHeaderLength()) } func (*fwdTestNetworkEndpoint) DefaultTTL() uint8 { @@ -80,7 +79,7 @@ func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt *PacketBuffer) { } func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 { - return f.ep.MaxHeaderLength() + fwdTestNetHeaderLen + return f.nic.MaxHeaderLength() + fwdTestNetHeaderLen } func (f *fwdTestNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 { @@ -99,7 +98,7 @@ func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkH b[srcAddrOffset] = r.LocalAddress[0] b[protocolNumberOffset] = byte(params.Protocol) - return f.ep.WritePacket(r, gso, fwdTestNetNumber, pkt) + return f.nic.WritePacket(r, gso, fwdTestNetNumber, pkt) } // WritePackets implements LinkEndpoint.WritePackets. @@ -159,10 +158,9 @@ func (*fwdTestNetworkProtocol) Parse(pkt *PacketBuffer) (tcpip.TransportProtocol func (f *fwdTestNetworkProtocol) NewEndpoint(nic NetworkInterface, _ LinkAddressCache, _ NUDHandler, dispatcher TransportDispatcher) NetworkEndpoint { e := &fwdTestNetworkEndpoint{ - nicID: nic.ID(), + nic: nic, proto: f, dispatcher: dispatcher, - ep: nic.LinkEndpoint(), } e.AddressableEndpointState.Init(e) return e diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index faa503b00..8d6d9a7f1 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -502,11 +502,11 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx // OriginalDst returns the original destination of redirected connections. It // returns an error if the connection doesn't exist or isn't redirected. -func (it *IPTables) OriginalDst(epID TransportEndpointID) (tcpip.Address, uint16, *tcpip.Error) { +func (it *IPTables) OriginalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, *tcpip.Error) { it.mu.RLock() defer it.mu.RUnlock() if !it.modified { return "", 0, tcpip.ErrNotConnected } - return it.connections.originalDst(epID) + return it.connections.originalDst(epID, netProto) } diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go index 08063f6ff..538c4625d 100644 --- a/pkg/tcpip/stack/iptables_targets.go +++ b/pkg/tcpip/stack/iptables_targets.go @@ -34,7 +34,7 @@ func (at *AcceptTarget) ID() TargetID { } // Action implements Target.Action. -func (AcceptTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { +func (*AcceptTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { return RuleAccept, 0 } @@ -52,7 +52,7 @@ func (dt *DropTarget) ID() TargetID { } // Action implements Target.Action. -func (DropTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { +func (*DropTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { return RuleDrop, 0 } @@ -76,7 +76,7 @@ func (et *ErrorTarget) ID() TargetID { } // Action implements Target.Action. -func (ErrorTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { +func (*ErrorTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { log.Debugf("ErrorTarget triggered.") return RuleDrop, 0 } @@ -99,7 +99,7 @@ func (uc *UserChainTarget) ID() TargetID { } // Action implements Target.Action. -func (UserChainTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { +func (*UserChainTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { panic("UserChainTarget should never be called.") } @@ -118,7 +118,7 @@ func (rt *ReturnTarget) ID() TargetID { } // Action implements Target.Action. -func (ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { +func (*ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { return RuleReturn, 0 } @@ -153,7 +153,7 @@ func (rt *RedirectTarget) ID() TargetID { // TODO(gvisor.dev/issue/170): Parse headers without copying. The current // implementation only works for PREROUTING and calls pkt.Clone(), neither // of which should be the case. -func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) { +func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) { // Packet is already manipulated. if pkt.NatDone { return RuleAccept, 0 @@ -164,11 +164,15 @@ func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso return RuleDrop, 0 } - // Change the address to localhost (127.0.0.1) in Output and - // to primary address of the incoming interface in Prerouting. + // Change the address to localhost (127.0.0.1 or ::1) in Output and to + // the primary address of the incoming interface in Prerouting. switch hook { case Output: - rt.Addr = tcpip.Address([]byte{127, 0, 0, 1}) + if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber { + rt.Addr = tcpip.Address([]byte{127, 0, 0, 1}) + } else { + rt.Addr = header.IPv6Loopback + } case Prerouting: rt.Addr = address default: @@ -177,8 +181,7 @@ func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso // TODO(gvisor.dev/issue/170): Check Flags in RedirectTarget if // we need to change dest address (for OUTPUT chain) or ports. - netHeader := header.IPv4(pkt.NetworkHeader().View()) - switch protocol := netHeader.TransportProtocol(); protocol { + switch protocol := pkt.TransportProtocolNumber; protocol { case header.UDPProtocolNumber: udpHeader := header.UDP(pkt.TransportHeader().View()) udpHeader.SetDestinationPort(rt.Port) @@ -186,10 +189,10 @@ func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso // Calculate UDP checksum and set it. if hook == Output { udpHeader.SetChecksum(0) - length := uint16(pkt.Size()) - uint16(netHeader.HeaderLength()) // Only calculate the checksum if offloading isn't supported. if r.Capabilities()&CapabilityTXChecksumOffload == 0 { + length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View())) xsum := r.PseudoHeaderChecksum(protocol, length) for _, v := range pkt.Data.Views() { xsum = header.Checksum(v, xsum) @@ -198,10 +201,15 @@ func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum)) } } - // Change destination address. - netHeader.SetDestinationAddress(rt.Addr) - netHeader.SetChecksum(0) - netHeader.SetChecksum(^netHeader.CalculateChecksum()) + + pkt.Network().SetDestinationAddress(rt.Addr) + + // After modification, IPv4 packets need a valid checksum. + if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber { + netHeader := header.IPv4(pkt.NetworkHeader().View()) + netHeader.SetChecksum(0) + netHeader.SetChecksum(^netHeader.CalculateChecksum()) + } pkt.NatDone = true case header.TCPProtocolNumber: if ct == nil { diff --git a/pkg/tcpip/stack/neighbor_cache.go b/pkg/tcpip/stack/neighbor_cache.go index 27e1feec0..4df288798 100644 --- a/pkg/tcpip/stack/neighbor_cache.go +++ b/pkg/tcpip/stack/neighbor_cache.go @@ -131,10 +131,17 @@ func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkA defer entry.mu.Unlock() switch s := entry.neigh.State; s { - case Reachable, Static: + case Stale: + entry.handlePacketQueuedLocked() + fallthrough + case Reachable, Static, Delay, Probe: + // As per RFC 4861 section 7.3.3: + // "Neighbor Unreachability Detection operates in parallel with the sending + // of packets to a neighbor. While reasserting a neighbor's reachability, + // a node continues sending packets to that neighbor using the cached + // link-layer address." return entry.neigh, nil, nil - - case Unknown, Incomplete, Stale, Delay, Probe: + case Unknown, Incomplete: entry.addWakerLocked(w) if entry.done == nil { @@ -147,10 +154,8 @@ func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkA entry.handlePacketQueuedLocked() return entry.neigh, entry.done, tcpip.ErrWouldBlock - case Failed: return entry.neigh, nil, tcpip.ErrNoLinkAddress - default: panic(fmt.Sprintf("Invalid cache entry state: %s", s)) } diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go index a0b7da5cd..fcd54ed83 100644 --- a/pkg/tcpip/stack/neighbor_cache_test.go +++ b/pkg/tcpip/stack/neighbor_cache_test.go @@ -1500,24 +1500,26 @@ func TestNeighborCacheReplace(t *testing.T) { } // Verify the entry exists - e, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) - if err != nil { - t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err) - } - if doneCh != nil { - t.Errorf("unexpected done channel from neigh.entry(%s, %s, _, nil): %v", entry.Addr, entry.LocalAddr, doneCh) - } - if t.Failed() { - t.FailNow() - } - want := NeighborEntry{ - Addr: entry.Addr, - LocalAddr: entry.LocalAddr, - LinkAddr: entry.LinkAddr, - State: Reachable, - } - if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LinkAddr, diff) + { + e, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + if err != nil { + t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err) + } + if doneCh != nil { + t.Errorf("unexpected done channel from neigh.entry(%s, %s, _, nil): %v", entry.Addr, entry.LocalAddr, doneCh) + } + if t.Failed() { + t.FailNow() + } + want := NeighborEntry{ + Addr: entry.Addr, + LocalAddr: entry.LocalAddr, + LinkAddr: entry.LinkAddr, + State: Reachable, + } + if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { + t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LinkAddr, diff) + } } // Notify of a link address change @@ -1536,28 +1538,34 @@ func TestNeighborCacheReplace(t *testing.T) { IsRouter: false, }) - // Requesting the entry again should start address resolution + // Requesting the entry again should start neighbor reachability confirmation. + // + // Verify the entry's new link address and the new state. { - _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) - if err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + e, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + if err != nil { + t.Fatalf("neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err) } - clock.Advance(config.DelayFirstProbeTime + typicalLatency) - select { - case <-doneCh: - default: - t.Fatalf("expected notification from done channel returned by neigh.entry(%s, %s, _, nil)", entry.Addr, entry.LocalAddr) + want := NeighborEntry{ + Addr: entry.Addr, + LocalAddr: entry.LocalAddr, + LinkAddr: updatedLinkAddr, + State: Delay, } + if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { + t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LocalAddr, diff) + } + clock.Advance(config.DelayFirstProbeTime + typicalLatency) } - // Verify the entry's new link address + // Verify that the neighbor is now reachable. { e, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) clock.Advance(typicalLatency) if err != nil { t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err) } - want = NeighborEntry{ + want := NeighborEntry{ Addr: entry.Addr, LocalAddr: entry.LocalAddr, LinkAddr: updatedLinkAddr, diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go index 9a72bec79..4d69a4de1 100644 --- a/pkg/tcpip/stack/neighbor_entry.go +++ b/pkg/tcpip/stack/neighbor_entry.go @@ -236,7 +236,7 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { return } - if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, "", e.nic.linkEP); err != nil { + if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, "", e.nic.LinkEndpoint); err != nil { // There is no need to log the error here; the NUD implementation may // assume a working link. A valid link should be the responsibility of // the NIC/stack.LinkEndpoint. @@ -277,7 +277,7 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { return } - if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, e.neigh.LinkAddr, e.nic.linkEP); err != nil { + if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, e.neigh.LinkAddr, e.nic.LinkEndpoint); err != nil { e.dispatchRemoveEventLocked() e.setStateLocked(Failed) return diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go index a265fff0a..e79abebca 100644 --- a/pkg/tcpip/stack/neighbor_entry_test.go +++ b/pkg/tcpip/stack/neighbor_entry_test.go @@ -227,8 +227,9 @@ func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *e clock := faketime.NewManualClock() disp := testNUDDispatcher{} nic := NIC{ - id: entryTestNICID, - linkEP: nil, // entryTestLinkResolver doesn't use a LinkEndpoint + LinkEndpoint: nil, // entryTestLinkResolver doesn't use a LinkEndpoint + + id: entryTestNICID, stack: &Stack{ clock: clock, nudDisp: &disp, diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 23022292c..8828cc5fe 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -32,14 +32,18 @@ var _ NetworkInterface = (*NIC)(nil) // NIC represents a "network interface card" to which the networking stack is // attached. type NIC struct { + LinkEndpoint + stack *Stack id tcpip.NICID name string - linkEP LinkEndpoint context NICContext - stats NICStats - neigh *neighborCache + stats NICStats + neigh *neighborCache + + // The network endpoints themselves may be modified by calling the interface's + // methods, but the map reference and entries must be constant. networkEndpoints map[tcpip.NetworkProtocolNumber]NetworkEndpoint // enabled is set to 1 when the NIC is enabled and 0 when it is disabled. @@ -88,10 +92,11 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC // of IPv6 is supported on this endpoint's LinkEndpoint. nic := &NIC{ + LinkEndpoint: ep, + stack: stack, id: id, name: name, - linkEP: ep, context: ctx, stats: makeNICStats(), networkEndpoints: make(map[tcpip.NetworkProtocolNumber]NetworkEndpoint), @@ -127,11 +132,15 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC nic.networkEndpoints[netNum] = netProto.NewEndpoint(nic, stack, nud, nic) } - nic.linkEP.Attach(nic) + nic.LinkEndpoint.Attach(nic) return nic } +func (n *NIC) getNetworkEndpoint(proto tcpip.NetworkProtocolNumber) NetworkEndpoint { + return n.networkEndpoints[proto] +} + // Enabled implements NetworkInterface. func (n *NIC) Enabled() bool { return atomic.LoadUint32(&n.enabled) == 1 @@ -211,10 +220,9 @@ func (n *NIC) remove() *tcpip.Error { for _, ep := range n.networkEndpoints { ep.Close() } - n.networkEndpoints = nil // Detach from link endpoint, so no packet comes in. - n.linkEP.Attach(nil) + n.LinkEndpoint.Attach(nil) return nil } @@ -234,7 +242,64 @@ func (n *NIC) isPromiscuousMode() bool { // IsLoopback implements NetworkInterface. func (n *NIC) IsLoopback() bool { - return n.linkEP.Capabilities()&CapabilityLoopback != 0 + return n.LinkEndpoint.Capabilities()&CapabilityLoopback != 0 +} + +// WritePacket implements NetworkLinkEndpoint. +func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error { + // As per relevant RFCs, we should queue packets while we wait for link + // resolution to complete. + // + // RFC 1122 section 2.3.2.2 (for IPv4): + // The link layer SHOULD save (rather than discard) at least + // one (the latest) packet of each set of packets destined to + // the same unresolved IP address, and transmit the saved + // packet when the address has been resolved. + // + // RFC 4861 section 5.2 (for IPv6): + // Once the IP address of the next-hop node is known, the sender + // examines the Neighbor Cache for link-layer information about that + // neighbor. If no entry exists, the sender creates one, sets its state + // to INCOMPLETE, initiates Address Resolution, and then queues the data + // packet pending completion of address resolution. + if ch, err := r.Resolve(nil); err != nil { + if err == tcpip.ErrWouldBlock { + r := r.Clone() + n.stack.linkResQueue.enqueue(ch, &r, protocol, pkt) + return nil + } + return err + } + + return n.writePacket(r, gso, protocol, pkt) +} + +func (n *NIC) writePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error { + // WritePacket takes ownership of pkt, calculate numBytes first. + numBytes := pkt.Size() + + if err := n.LinkEndpoint.WritePacket(r, gso, protocol, pkt); err != nil { + return err + } + + n.stats.Tx.Packets.Increment() + n.stats.Tx.Bytes.IncrementBy(uint64(numBytes)) + return nil +} + +// WritePackets implements NetworkLinkEndpoint. +func (n *NIC) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + // TODO(gvisor.dev/issue/4458): Queue packets whie link address resolution + // is being peformed like WritePacket. + writtenPackets, err := n.LinkEndpoint.WritePackets(r, gso, pkts, protocol) + n.stats.Tx.Packets.IncrementBy(uint64(writtenPackets)) + writtenBytes := 0 + for i, pb := 0, pkts.Front(); i < writtenPackets && pb != nil; i, pb = i+1, pb.Next() { + writtenBytes += pb.Size() + } + + n.stats.Tx.Bytes.IncrementBy(uint64(writtenBytes)) + return writtenPackets, err } // setSpoofing enables or disables address spoofing. @@ -483,9 +548,9 @@ func (n *NIC) isInGroup(addr tcpip.Address) bool { func (n *NIC) handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, remotelinkAddr tcpip.LinkAddress, addressEndpoint AssignableAddressEndpoint, pkt *PacketBuffer) { r := makeRoute(protocol, dst, src, n, addressEndpoint, false /* handleLocal */, false /* multicastLoop */) + defer r.Release() r.RemoteLinkAddress = remotelinkAddr - addressEndpoint.NetworkEndpoint().HandlePacket(&r, pkt) - addressEndpoint.DecRef() + n.getNetworkEndpoint(protocol).HandlePacket(&r, pkt) } // DeliverNetworkPacket finds the appropriate network protocol endpoint and @@ -519,7 +584,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp // If no local link layer address is provided, assume it was sent // directly to this NIC. if local == "" { - local = n.linkEP.LinkAddress() + local = n.LinkEndpoint.LinkAddress() } // Are any packet type sockets listening for this network protocol? @@ -599,11 +664,11 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp n := r.nic if addressEndpoint := n.getAddressOrCreateTempInner(protocol, dst, false, NeverPrimaryEndpoint); addressEndpoint != nil { if n.isValidForOutgoing(addressEndpoint) { - r.LocalLinkAddress = n.linkEP.LinkAddress() + r.LocalLinkAddress = n.LinkEndpoint.LinkAddress() r.RemoteLinkAddress = remote r.RemoteAddress = src // TODO(b/123449044): Update the source NIC as well. - addressEndpoint.NetworkEndpoint().HandlePacket(&r, pkt) + n.getNetworkEndpoint(protocol).HandlePacket(&r, pkt) addressEndpoint.DecRef() r.Release() return @@ -614,21 +679,21 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp // n doesn't have a destination endpoint. // Send the packet out of n. - // TODO(b/128629022): move this logic to route.WritePacket. // TODO(gvisor.dev/issue/1085): According to the RFC, we must decrease the TTL field for ipv4/ipv6. - if ch, err := r.Resolve(nil); err != nil { - if err == tcpip.ErrWouldBlock { - n.stack.forwarder.enqueue(ch, n, &r, protocol, pkt) - // forwarder will release route. - return - } + + // pkt may have set its header and may not have enough headroom for + // link-layer header for the other link to prepend. Here we create a new + // packet to forward. + fwdPkt := NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: int(n.LinkEndpoint.MaxHeaderLength()), + Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()), + }) + + // TODO(b/143425874) Decrease the TTL field in forwarded packets. + if err := n.WritePacket(&r, nil, protocol, fwdPkt); err != nil { n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment() - r.Release() - return } - // The link-address resolution finished immediately. - n.forwardPacket(&r, protocol, pkt) r.Release() return } @@ -652,43 +717,18 @@ func (n *NIC) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tc p.PktType = tcpip.PacketOutgoing // Add the link layer header as outgoing packets are intercepted // before the link layer header is created. - n.linkEP.AddHeader(local, remote, protocol, p) + n.LinkEndpoint.AddHeader(local, remote, protocol, p) ep.HandlePacket(n.id, local, protocol, p) } } -func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { - // TODO(b/143425874) Decrease the TTL field in forwarded packets. - - // pkt may have set its header and may not have enough headroom for link-layer - // header for the other link to prepend. Here we create a new packet to - // forward. - fwdPkt := NewPacketBuffer(PacketBufferOptions{ - ReserveHeaderBytes: int(n.linkEP.MaxHeaderLength()), - Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()), - }) - - // WritePacket takes ownership of fwdPkt, calculate numBytes first. - numBytes := fwdPkt.Size() - - if err := n.linkEP.WritePacket(r, nil /* gso */, protocol, fwdPkt); err != nil { - r.Stats().IP.OutgoingPacketErrors.Increment() - return - } - - n.stats.Tx.Packets.Increment() - n.stats.Tx.Bytes.IncrementBy(uint64(numBytes)) -} - // DeliverTransportPacket delivers the packets to the appropriate transport // protocol endpoint. func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) TransportPacketDisposition { state, ok := n.stack.transportProtocols[protocol] if !ok { - // TODO(gvisor.dev/issue/4365): Let the caller know that the transport - // protocol is unrecognized. n.stack.stats.UnknownProtocolRcvdPackets.Increment() - return TransportPacketHandled + return TransportPacketProtocolUnreachable } transProto := state.proto @@ -792,11 +832,6 @@ func (n *NIC) Name() string { return n.name } -// LinkEndpoint implements NetworkInterface. -func (n *NIC) LinkEndpoint() LinkEndpoint { - return n.linkEP -} - // nudConfigs gets the NUD configurations for n. func (n *NIC) nudConfigs() (NUDConfigurations, *tcpip.Error) { if n.neigh == nil { diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go index fdd49b77f..97a96af62 100644 --- a/pkg/tcpip/stack/nic_test.go +++ b/pkg/tcpip/stack/nic_test.go @@ -33,8 +33,7 @@ var _ NDPEndpoint = (*testIPv6Endpoint)(nil) type testIPv6Endpoint struct { AddressableEndpointState - nicID tcpip.NICID - linkEP LinkEndpoint + nic NetworkInterface protocol *testIPv6Protocol invalidatedRtr tcpip.Address @@ -57,12 +56,12 @@ func (*testIPv6Endpoint) DefaultTTL() uint8 { // MTU implements NetworkEndpoint.MTU. func (e *testIPv6Endpoint) MTU() uint32 { - return e.linkEP.MTU() - header.IPv6MinimumSize + return e.nic.MTU() - header.IPv6MinimumSize } // MaxHeaderLength implements NetworkEndpoint.MaxHeaderLength. func (e *testIPv6Endpoint) MaxHeaderLength() uint16 { - return e.linkEP.MaxHeaderLength() + header.IPv6MinimumSize + return e.nic.MaxHeaderLength() + header.IPv6MinimumSize } // WritePacket implements NetworkEndpoint.WritePacket. @@ -134,8 +133,7 @@ func (*testIPv6Protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) // NewEndpoint implements NetworkProtocol.NewEndpoint. func (p *testIPv6Protocol) NewEndpoint(nic NetworkInterface, _ LinkAddressCache, _ NUDHandler, _ TransportDispatcher) NetworkEndpoint { e := &testIPv6Endpoint{ - nicID: nic.ID(), - linkEP: nic.LinkEndpoint(), + nic: nic, protocol: p, } e.AddressableEndpointState.Init(e) diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index a7d9d59fa..105583c49 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -19,6 +19,7 @@ import ( "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" ) type headerType int @@ -255,6 +256,20 @@ func (pk *PacketBuffer) Clone() *PacketBuffer { return newPk } +// Network returns the network header as a header.Network. +// +// Network should only be called when NetworkHeader has been set. +func (pk *PacketBuffer) Network() header.Network { + switch netProto := pk.NetworkProtocolNumber; netProto { + case header.IPv4ProtocolNumber: + return header.IPv4(pk.NetworkHeader().View()) + case header.IPv6ProtocolNumber: + return header.IPv6(pk.NetworkHeader().View()) + default: + panic(fmt.Sprintf("unknown network protocol number %d", netProto)) + } +} + // headerInfo stores metadata about a header in a packet. type headerInfo struct { // buf is the memorized slice for both prepended and consumed header. diff --git a/pkg/tcpip/stack/forwarder.go b/pkg/tcpip/stack/pending_packets.go index 3eff141e6..f838eda8d 100644 --- a/pkg/tcpip/stack/forwarder.go +++ b/pkg/tcpip/stack/pending_packets.go @@ -29,60 +29,60 @@ const ( ) type pendingPacket struct { - nic *NIC route *Route proto tcpip.NetworkProtocolNumber pkt *PacketBuffer } -type forwardQueue struct { +// packetsPendingLinkResolution is a queue of packets pending link resolution. +// +// Once link resolution completes successfully, the packets will be written. +type packetsPendingLinkResolution struct { sync.Mutex // The packets to send once the resolver completes. - packets map[<-chan struct{}][]*pendingPacket + packets map[<-chan struct{}][]pendingPacket // FIFO of channels used to cancel the oldest goroutine waiting for // link-address resolution. cancelChans []chan struct{} } -func newForwardQueue() *forwardQueue { - return &forwardQueue{packets: make(map[<-chan struct{}][]*pendingPacket)} +func (f *packetsPendingLinkResolution) init() { + f.Lock() + defer f.Unlock() + f.packets = make(map[<-chan struct{}][]pendingPacket) } -func (f *forwardQueue) enqueue(ch <-chan struct{}, n *NIC, r *Route, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { - shouldWait := false - +func (f *packetsPendingLinkResolution) enqueue(ch <-chan struct{}, r *Route, proto tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { f.Lock() + defer f.Unlock() + packets, ok := f.packets[ch] - if !ok { - shouldWait = true - } - for len(packets) == maxPendingPacketsPerResolution { + if len(packets) == maxPendingPacketsPerResolution { p := packets[0] + packets[0] = pendingPacket{} packets = packets[1:] - p.nic.stack.stats.IP.OutgoingPacketErrors.Increment() + p.route.Stats().IP.OutgoingPacketErrors.Increment() p.route.Release() } + if l := len(packets); l >= maxPendingPacketsPerResolution { panic(fmt.Sprintf("max pending packets for resolution reached; got %d packets, max = %d", l, maxPendingPacketsPerResolution)) } - f.packets[ch] = append(packets, &pendingPacket{ - nic: n, + + f.packets[ch] = append(packets, pendingPacket{ route: r, - proto: protocol, + proto: proto, pkt: pkt, }) - f.Unlock() - if !shouldWait { + if ok { return } // Wait for the link-address resolution to complete. - // Start a goroutine with a forwarding-cancel channel so that we can - // limit the maximum number of goroutines running concurrently. - cancel := f.newCancelChannel() + cancel := f.newCancelChannelLocked() go func() { cancelled := false select { @@ -92,17 +92,21 @@ func (f *forwardQueue) enqueue(ch <-chan struct{}, n *NIC, r *Route, protocol tc } f.Lock() - packets := f.packets[ch] + packets, ok := f.packets[ch] delete(f.packets, ch) f.Unlock() + if !ok { + panic(fmt.Sprintf("link-resolution goroutine woke up but no entry exists in the queue of packets")) + } + for _, p := range packets { if cancelled { - p.nic.stack.stats.IP.OutgoingPacketErrors.Increment() + p.route.Stats().IP.OutgoingPacketErrors.Increment() } else if _, err := p.route.Resolve(nil); err != nil { - p.nic.stack.stats.IP.OutgoingPacketErrors.Increment() + p.route.Stats().IP.OutgoingPacketErrors.Increment() } else { - p.nic.forwardPacket(p.route, p.proto, p.pkt) + p.route.nic.writePacket(p.route, nil /* gso */, p.proto, p.pkt) } p.route.Release() } @@ -112,12 +116,10 @@ func (f *forwardQueue) enqueue(ch <-chan struct{}, n *NIC, r *Route, protocol tc // newCancelChannel creates a channel that can cancel a pending forwarding // activity. The oldest channel is closed if the number of open channels would // exceed maxPendingResolutions. -func (f *forwardQueue) newCancelChannel() chan struct{} { - f.Lock() - defer f.Unlock() - +func (f *packetsPendingLinkResolution) newCancelChannelLocked() chan struct{} { if len(f.cancelChans) == maxPendingResolutions { ch := f.cancelChans[0] + f.cancelChans[0] = nil f.cancelChans = f.cancelChans[1:] close(ch) } diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index b6f823b54..defb9129b 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -208,6 +208,10 @@ const ( // transport layer and callers need not take any further action. TransportPacketHandled TransportPacketDisposition = iota + // TransportPacketProtocolUnreachable indicates that the transport + // protocol requested in the packet is not supported. + TransportPacketProtocolUnreachable + // TransportPacketDestinationPortUnreachable indicates that there weren't any // listeners interested in the packet and the transport protocol has no means // to notify the sender. @@ -322,10 +326,6 @@ const ( // AssignableAddressEndpoint is a reference counted address endpoint that may be // assigned to a NetworkEndpoint. type AssignableAddressEndpoint interface { - // NetworkEndpoint returns the NetworkEndpoint the receiver is associated - // with. - NetworkEndpoint() NetworkEndpoint - // AddressWithPrefix returns the endpoint's address. AddressWithPrefix() tcpip.AddressWithPrefix @@ -475,6 +475,8 @@ type NDPEndpoint interface { // NetworkInterface is a network interface. type NetworkInterface interface { + NetworkLinkEndpoint + // ID returns the interface's ID. ID() tcpip.NICID @@ -488,9 +490,6 @@ type NetworkInterface interface { // Enabled returns true if the interface is enabled. Enabled() bool - - // LinkEndpoint returns the link endpoint backing the interface. - LinkEndpoint() LinkEndpoint } // NetworkEndpoint is the interface that needs to be implemented by endpoints @@ -663,22 +662,15 @@ const ( CapabilitySoftwareGSO ) -// LinkEndpoint is the interface implemented by data link layer protocols (e.g., -// ethernet, loopback, raw) and used by network layer protocols to send packets -// out through the implementer's data link endpoint. When a link header exists, -// it sets each PacketBuffer's LinkHeader field before passing it up the -// stack. -type LinkEndpoint interface { +// NetworkLinkEndpoint is a data-link layer that supports sending network +// layer packets. +type NetworkLinkEndpoint interface { // MTU is the maximum transmission unit for this endpoint. This is // usually dictated by the backing physical network; when such a // physical network doesn't exist, the limit is generally 64k, which // includes the maximum size of an IP packet. MTU() uint32 - // Capabilities returns the set of capabilities supported by the - // endpoint. - Capabilities() LinkEndpointCapabilities - // MaxHeaderLength returns the maximum size the data link (and // lower level layers combined) headers can have. Higher levels use this // information to reserve space in the front of the packets they're @@ -686,7 +678,7 @@ type LinkEndpoint interface { MaxHeaderLength() uint16 // LinkAddress returns the link address (typically a MAC) of the - // link endpoint. + // endpoint. LinkAddress() tcpip.LinkAddress // WritePacket writes a packet with the given protocol through the @@ -706,6 +698,19 @@ type LinkEndpoint interface { // offload is enabled. If it will be used for something else, it may // require to change syscall filters. WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) +} + +// LinkEndpoint is the interface implemented by data link layer protocols (e.g., +// ethernet, loopback, raw) and used by network layer protocols to send packets +// out through the implementer's data link endpoint. When a link header exists, +// it sets each PacketBuffer's LinkHeader field before passing it up the +// stack. +type LinkEndpoint interface { + NetworkLinkEndpoint + + // Capabilities returns the set of capabilities supported by the + // endpoint. + Capabilities() LinkEndpointCapabilities // WriteRawPacket writes a packet directly to the link. The packet // should already have an ethernet header. It takes ownership of vv. diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 5ade3c832..25f80c1f8 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -72,21 +72,20 @@ func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip loop |= PacketLoop } - linkEP := nic.LinkEndpoint() r := Route{ NetProto: netProto, LocalAddress: localAddr, - LocalLinkAddress: linkEP.LinkAddress(), + LocalLinkAddress: nic.LinkEndpoint.LinkAddress(), RemoteAddress: remoteAddr, addressEndpoint: addressEndpoint, nic: nic, Loop: loop, } - if nic := r.nic; linkEP.Capabilities()&CapabilityResolutionRequired != 0 { - if linkRes, ok := nic.stack.linkAddrResolvers[r.NetProto]; ok { + if r.nic.LinkEndpoint.Capabilities()&CapabilityResolutionRequired != 0 { + if linkRes, ok := r.nic.stack.linkAddrResolvers[r.NetProto]; ok { r.linkRes = linkRes - r.linkCache = nic.stack + r.linkCache = r.nic.stack } } @@ -100,7 +99,7 @@ func (r *Route) NICID() tcpip.NICID { // MaxHeaderLength forwards the call to the network endpoint's implementation. func (r *Route) MaxHeaderLength() uint16 { - return r.addressEndpoint.NetworkEndpoint().MaxHeaderLength() + return r.nic.getNetworkEndpoint(r.NetProto).MaxHeaderLength() } // Stats returns a mutable copy of current stats. @@ -116,23 +115,17 @@ func (r *Route) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, tot // Capabilities returns the link-layer capabilities of the route. func (r *Route) Capabilities() LinkEndpointCapabilities { - return r.nic.LinkEndpoint().Capabilities() + return r.nic.LinkEndpoint.Capabilities() } // GSOMaxSize returns the maximum GSO packet size. func (r *Route) GSOMaxSize() uint32 { - if gso, ok := r.addressEndpoint.NetworkEndpoint().(GSOEndpoint); ok { + if gso, ok := r.nic.LinkEndpoint.(GSOEndpoint); ok { return gso.GSOMaxSize() } return 0 } -// ResolveWith immediately resolves a route with the specified remote link -// address. -func (r *Route) ResolveWith(addr tcpip.LinkAddress) { - r.RemoteLinkAddress = addr -} - // Resolve attempts to resolve the link address if necessary. Returns ErrWouldBlock in // case address resolution requires blocking, e.g. wait for ARP reply. Waker is // notified when address resolution is complete (success or not). @@ -208,17 +201,7 @@ func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt *PacketBuf return tcpip.ErrInvalidEndpointState } - // WritePacket takes ownership of pkt, calculate numBytes first. - numBytes := pkt.Size() - - err := r.addressEndpoint.NetworkEndpoint().WritePacket(r, gso, params, pkt) - if err != nil { - r.Stats().IP.OutgoingPacketErrors.Increment() - } else { - r.nic.stats.Tx.Packets.Increment() - r.nic.stats.Tx.Bytes.IncrementBy(uint64(numBytes)) - } - return err + return r.nic.getNetworkEndpoint(r.NetProto).WritePacket(r, gso, params, pkt) } // WritePackets writes a list of n packets through the given route and returns @@ -228,22 +211,7 @@ func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHead return 0, tcpip.ErrInvalidEndpointState } - // WritePackets takes ownership of pkt, calculate length first. - numPkts := pkts.Len() - - n, err := r.addressEndpoint.NetworkEndpoint().WritePackets(r, gso, pkts, params) - if err != nil { - r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(numPkts - n)) - } - r.nic.stats.Tx.Packets.IncrementBy(uint64(n)) - - writtenBytes := 0 - for i, pb := 0, pkts.Front(); i < n && pb != nil; i, pb = i+1, pb.Next() { - writtenBytes += pb.Size() - } - - r.nic.stats.Tx.Bytes.IncrementBy(uint64(writtenBytes)) - return n, err + return r.nic.getNetworkEndpoint(r.NetProto).WritePackets(r, gso, pkts, params) } // WriteHeaderIncludedPacket writes a packet already containing a network @@ -253,32 +221,17 @@ func (r *Route) WriteHeaderIncludedPacket(pkt *PacketBuffer) *tcpip.Error { return tcpip.ErrInvalidEndpointState } - // WriteHeaderIncludedPacket takes ownership of pkt, calculate numBytes first. - numBytes := pkt.Data.Size() - - if err := r.addressEndpoint.NetworkEndpoint().WriteHeaderIncludedPacket(r, pkt); err != nil { - r.Stats().IP.OutgoingPacketErrors.Increment() - return err - } - r.nic.stats.Tx.Packets.Increment() - r.nic.stats.Tx.Bytes.IncrementBy(uint64(numBytes)) - return nil + return r.nic.getNetworkEndpoint(r.NetProto).WriteHeaderIncludedPacket(r, pkt) } // DefaultTTL returns the default TTL of the underlying network endpoint. func (r *Route) DefaultTTL() uint8 { - return r.addressEndpoint.NetworkEndpoint().DefaultTTL() + return r.nic.getNetworkEndpoint(r.NetProto).DefaultTTL() } // MTU returns the MTU of the underlying network endpoint. func (r *Route) MTU() uint32 { - return r.addressEndpoint.NetworkEndpoint().MTU() -} - -// NetworkProtocolNumber returns the NetworkProtocolNumber of the underlying -// network endpoint. -func (r *Route) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { - return r.addressEndpoint.NetworkEndpoint().NetworkProtocolNumber() + return r.nic.getNetworkEndpoint(r.NetProto).MTU() } // Release frees all resources associated with the route. diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 57d8e79e0..3a07577c8 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -436,9 +436,9 @@ type Stack struct { // uniqueIDGenerator is a generator of unique identifiers. uniqueIDGenerator UniqueID - // forwarder holds the packets that wait for their link-address resolutions - // to complete, and forwards them when each resolution is done. - forwarder *forwardQueue + // linkResQueue holds packets that are waiting for link resolution to + // complete. + linkResQueue packetsPendingLinkResolution // randomGenerator is an injectable pseudo random generator that can be // used when a random number is required. @@ -550,8 +550,8 @@ type TransportEndpointInfo struct { // incompatible with the receiver. // // Preconditon: the parent endpoint mu must be held while calling this method. -func (e *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, v6only bool) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) { - netProto := e.NetProto +func (t *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, v6only bool) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) { + netProto := t.NetProto switch len(addr.Addr) { case header.IPv4AddressSize: netProto = header.IPv4ProtocolNumber @@ -565,7 +565,7 @@ func (e *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, v6onl } } - switch len(e.ID.LocalAddress) { + switch len(t.ID.LocalAddress) { case header.IPv4AddressSize: if len(addr.Addr) == header.IPv6AddressSize { return tcpip.FullAddress{}, 0, tcpip.ErrInvalidEndpointState @@ -577,8 +577,8 @@ func (e *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, v6onl } switch { - case netProto == e.NetProto: - case netProto == header.IPv4ProtocolNumber && e.NetProto == header.IPv6ProtocolNumber: + case netProto == t.NetProto: + case netProto == header.IPv4ProtocolNumber && t.NetProto == header.IPv6ProtocolNumber: if v6only { return tcpip.FullAddress{}, 0, tcpip.ErrNoRoute } @@ -640,7 +640,6 @@ func New(opts Options) *Stack { useNeighborCache: opts.UseNeighborCache, uniqueIDGenerator: opts.UniqueID, nudDisp: opts.NUDDisp, - forwarder: newForwardQueue(), randomGenerator: mathrand.New(randSrc), sendBufferSize: SendBufferSizeOption{ Min: MinBufferSize, @@ -653,6 +652,7 @@ func New(opts Options) *Stack { Max: DefaultMaxBufferSize, }, } + s.linkResQueue.init() // Add specified network protocols. for _, netProtoFactory := range opts.NetworkProtocols { @@ -928,16 +928,16 @@ func (s *Stack) CreateNIC(id tcpip.NICID, ep LinkEndpoint) *tcpip.Error { return s.CreateNICWithOptions(id, ep, NICOptions{}) } -// GetNICByName gets the NIC specified by name. -func (s *Stack) GetNICByName(name string) (*NIC, bool) { +// GetLinkEndpointByName gets the link endpoint specified by name. +func (s *Stack) GetLinkEndpointByName(name string) LinkEndpoint { s.mu.RLock() defer s.mu.RUnlock() for _, nic := range s.nics { if nic.Name() == name { - return nic, true + return nic.LinkEndpoint } } - return nil, false + return nil } // EnableNIC enables the given NIC so that the link-layer endpoint can start @@ -1062,13 +1062,13 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo { } nics[id] = NICInfo{ Name: nic.name, - LinkAddress: nic.linkEP.LinkAddress(), + LinkAddress: nic.LinkEndpoint.LinkAddress(), ProtocolAddresses: nic.primaryAddresses(), Flags: flags, - MTU: nic.linkEP.MTU(), + MTU: nic.LinkEndpoint.MTU(), Stats: nic.stats, Context: nic.context, - ARPHardwareType: nic.linkEP.ARPHardwareType(), + ARPHardwareType: nic.LinkEndpoint.ARPHardwareType(), } } return nics @@ -1323,7 +1323,7 @@ func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, fullAddr := tcpip.FullAddress{NIC: nicID, Addr: addr} linkRes := s.linkAddrResolvers[protocol] - return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic.linkEP, waker) + return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic.LinkEndpoint, waker) } // Neighbors returns all IP to MAC address associations. @@ -1539,7 +1539,7 @@ func (s *Stack) Wait() { s.mu.RLock() defer s.mu.RUnlock() for _, n := range s.nics { - n.linkEP.Wait() + n.LinkEndpoint.Wait() } } @@ -1627,7 +1627,7 @@ func (s *Stack) WritePacket(nicID tcpip.NICID, dst tcpip.LinkAddress, netProto t // Add our own fake ethernet header. ethFields := header.EthernetFields{ - SrcAddr: nic.linkEP.LinkAddress(), + SrcAddr: nic.LinkEndpoint.LinkAddress(), DstAddr: dst, Type: netProto, } @@ -1636,7 +1636,7 @@ func (s *Stack) WritePacket(nicID tcpip.NICID, dst tcpip.LinkAddress, netProto t vv := buffer.View(fakeHeader).ToVectorisedView() vv.Append(payload) - if err := nic.linkEP.WriteRawPacket(vv); err != nil { + if err := nic.LinkEndpoint.WriteRawPacket(vv); err != nil { return err } @@ -1653,7 +1653,7 @@ func (s *Stack) WriteRawPacket(nicID tcpip.NICID, payload buffer.VectorisedView) return tcpip.ErrUnknownDevice } - if err := nic.linkEP.WriteRawPacket(payload); err != nil { + if err := nic.LinkEndpoint.WriteRawPacket(payload); err != nil { return err } @@ -1796,7 +1796,7 @@ func (s *Stack) GetNetworkEndpoint(nicID tcpip.NICID, proto tcpip.NetworkProtoco return nil, tcpip.ErrUnknownNICID } - return nic.networkEndpoints[proto], nil + return nic.getNetworkEndpoint(proto), nil } // NUDConfigurations gets the per-interface NUD configurations. @@ -1873,10 +1873,8 @@ func (s *Stack) FindNetworkEndpoint(netProto tcpip.NetworkProtocolNumber, addres if addressEndpoint == nil { continue } - - ep := addressEndpoint.NetworkEndpoint() addressEndpoint.DecRef() - return ep, nil + return nic.getNetworkEndpoint(netProto), nil } return nil, tcpip.ErrBadAddress } diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index aa20f750b..38994cca1 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -21,7 +21,6 @@ import ( "bytes" "fmt" "math" - "net" "sort" "testing" "time" @@ -35,7 +34,6 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/loopback" - "gvisor.dev/gvisor/pkg/tcpip/network/arp" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -77,10 +75,9 @@ type fakeNetworkEndpoint struct { enabled bool } - nicID tcpip.NICID + nic stack.NetworkInterface proto *fakeNetworkProtocol dispatcher stack.TransportDispatcher - ep stack.LinkEndpoint } func (f *fakeNetworkEndpoint) Enable() *tcpip.Error { @@ -103,7 +100,7 @@ func (f *fakeNetworkEndpoint) Disable() { } func (f *fakeNetworkEndpoint) MTU() uint32 { - return f.ep.MTU() - uint32(f.MaxHeaderLength()) + return f.nic.MTU() - uint32(f.MaxHeaderLength()) } func (*fakeNetworkEndpoint) DefaultTTL() uint8 { @@ -135,7 +132,7 @@ func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuff } func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 { - return f.ep.MaxHeaderLength() + fakeNetHeaderLen + return f.nic.MaxHeaderLength() + fakeNetHeaderLen } func (f *fakeNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 { @@ -164,7 +161,7 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params return nil } - return f.ep.WritePacket(r, gso, fakeNetNumber, pkt) + return f.nic.WritePacket(r, gso, fakeNetNumber, pkt) } // WritePackets implements stack.LinkEndpoint.WritePackets. @@ -216,10 +213,9 @@ func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Addres func (f *fakeNetworkProtocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCache, _ stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint { e := &fakeNetworkEndpoint{ - nicID: nic.ID(), + nic: nic, proto: f, dispatcher: dispatcher, - ep: nic.LinkEndpoint(), } e.AddressableEndpointState.Init(e) return e @@ -2106,7 +2102,7 @@ func TestNICStats(t *testing.T) { t.Errorf("got Tx.Packets.Value() = %d, ep1.Drain() = %d", got, want) } - if got, want := s.NICInfo()[1].Stats.Tx.Bytes.Value(), uint64(len(payload)); got != want { + if got, want := s.NICInfo()[1].Stats.Tx.Bytes.Value(), uint64(len(payload)+fakeNetHeaderLen); got != want { t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want) } } @@ -3502,52 +3498,6 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { } } -func TestResolveWith(t *testing.T) { - const ( - unspecifiedNICID = 0 - nicID = 1 - ) - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, arp.NewProtocol}, - }) - ep := channel.New(0, defaultMTU, "") - ep.LinkEPCapabilities |= stack.CapabilityResolutionRequired - if err := s.CreateNIC(nicID, ep); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID, err) - } - addr := tcpip.ProtocolAddress{ - Protocol: header.IPv4ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: tcpip.Address(net.ParseIP("192.168.1.58").To4()), - PrefixLen: 24, - }, - } - if err := s.AddProtocolAddress(nicID, addr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, addr, err) - } - - s.SetRouteTable([]tcpip.Route{{Destination: header.IPv4EmptySubnet, NIC: nicID}}) - - remoteAddr := tcpip.Address(net.ParseIP("192.168.1.59").To4()) - r, err := s.FindRoute(unspecifiedNICID, "" /* localAddr */, remoteAddr, header.IPv4ProtocolNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("FindRoute(%d, '', %s, %d): %s", unspecifiedNICID, remoteAddr, header.IPv4ProtocolNumber, err) - } - defer r.Release() - - // Should initially require resolution. - if !r.IsResolutionRequired() { - t.Fatal("got r.IsResolutionRequired() = false, want = true") - } - - // Manually resolving the route should no longer require resolution. - r.ResolveWith("\x01") - if r.IsResolutionRequired() { - t.Fatal("got r.IsResolutionRequired() = true, want = false") - } -} - // TestRouteReleaseAfterAddrRemoval tests that releasing a Route after its // associated address is removed should not cause a panic. func TestRouteReleaseAfterAddrRemoval(t *testing.T) { diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD index 06c7a3cd3..a4f141253 100644 --- a/pkg/tcpip/tests/integration/BUILD +++ b/pkg/tcpip/tests/integration/BUILD @@ -6,6 +6,8 @@ go_test( name = "integration_test", size = "small", srcs = [ + "forward_test.go", + "link_resolution_test.go", "loopback_test.go", "multicast_broadcast_test.go", ], @@ -15,6 +17,8 @@ go_test( "//pkg/tcpip/header", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/loopback", + "//pkg/tcpip/link/pipe", + "//pkg/tcpip/network/arp", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", "//pkg/tcpip/stack", diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go new file mode 100644 index 000000000..ffd38ee1a --- /dev/null +++ b/pkg/tcpip/tests/integration/forward_test.go @@ -0,0 +1,378 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package integration_test + +import ( + "net" + "testing" + + "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/link/pipe" + "gvisor.dev/gvisor/pkg/tcpip/network/arp" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" + "gvisor.dev/gvisor/pkg/waiter" +) + +func TestForwarding(t *testing.T) { + const ( + host1NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06") + routerNIC1LinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x07") + routerNIC2LinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x08") + host2NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09") + + host1NICID = 1 + routerNICID1 = 2 + routerNICID2 = 3 + host2NICID = 4 + + listenPort = 8080 + ) + + host1IPv4Addr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("192.168.0.2").To4()), + PrefixLen: 24, + }, + } + routerNIC1IPv4Addr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("192.168.0.1").To4()), + PrefixLen: 24, + }, + } + routerNIC2IPv4Addr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("10.0.0.1").To4()), + PrefixLen: 8, + }, + } + host2IPv4Addr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("10.0.0.2").To4()), + PrefixLen: 8, + }, + } + host1IPv6Addr := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("a::2").To16()), + PrefixLen: 64, + }, + } + routerNIC1IPv6Addr := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("a::1").To16()), + PrefixLen: 64, + }, + } + routerNIC2IPv6Addr := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("b::1").To16()), + PrefixLen: 64, + }, + } + host2IPv6Addr := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("b::2").To16()), + PrefixLen: 64, + }, + } + + type endpointAndAddresses struct { + serverEP tcpip.Endpoint + serverAddr tcpip.Address + serverReadableCH chan struct{} + + clientEP tcpip.Endpoint + clientAddr tcpip.Address + clientReadableCH chan struct{} + } + + newEP := func(t *testing.T, s *stack.Stack, transProto tcpip.TransportProtocolNumber, netProto tcpip.NetworkProtocolNumber) (tcpip.Endpoint, chan struct{}) { + t.Helper() + var wq waiter.Queue + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + ep, err := s.NewEndpoint(transProto, netProto, &wq) + if err != nil { + t.Fatalf("s.NewEndpoint(%d, %d, _): %s", transProto, netProto, err) + } + + t.Cleanup(func() { + wq.EventUnregister(&we) + }) + + return ep, ch + } + + tests := []struct { + name string + epAndAddrs func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack) endpointAndAddresses + }{ + { + name: "IPv4 host1 server with host2 client", + epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack) endpointAndAddresses { + ep1, ep1WECH := newEP(t, host1Stack, udp.ProtocolNumber, ipv4.ProtocolNumber) + ep2, ep2WECH := newEP(t, host2Stack, udp.ProtocolNumber, ipv4.ProtocolNumber) + return endpointAndAddresses{ + serverEP: ep1, + serverAddr: host1IPv4Addr.AddressWithPrefix.Address, + serverReadableCH: ep1WECH, + + clientEP: ep2, + clientAddr: host2IPv4Addr.AddressWithPrefix.Address, + clientReadableCH: ep2WECH, + } + }, + }, + { + name: "IPv6 host2 server with host1 client", + epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack) endpointAndAddresses { + ep1, ep1WECH := newEP(t, host2Stack, udp.ProtocolNumber, ipv6.ProtocolNumber) + ep2, ep2WECH := newEP(t, host1Stack, udp.ProtocolNumber, ipv6.ProtocolNumber) + return endpointAndAddresses{ + serverEP: ep1, + serverAddr: host2IPv6Addr.AddressWithPrefix.Address, + serverReadableCH: ep1WECH, + + clientEP: ep2, + clientAddr: host1IPv6Addr.AddressWithPrefix.Address, + clientReadableCH: ep2WECH, + } + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + stackOpts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + } + + host1Stack := stack.New(stackOpts) + routerStack := stack.New(stackOpts) + host2Stack := stack.New(stackOpts) + + host1NIC, routerNIC1 := pipe.New(host1NICLinkAddr, routerNIC1LinkAddr, stack.CapabilityResolutionRequired) + routerNIC2, host2NIC := pipe.New(routerNIC2LinkAddr, host2NICLinkAddr, stack.CapabilityResolutionRequired) + + if err := host1Stack.CreateNIC(host1NICID, host1NIC); err != nil { + t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err) + } + if err := routerStack.CreateNIC(routerNICID1, routerNIC1); err != nil { + t.Fatalf("routerStack.CreateNIC(%d, _): %s", routerNICID1, err) + } + if err := routerStack.CreateNIC(routerNICID2, routerNIC2); err != nil { + t.Fatalf("routerStack.CreateNIC(%d, _): %s", routerNICID2, err) + } + if err := host2Stack.CreateNIC(host2NICID, host2NIC); err != nil { + t.Fatalf("host2Stack.CreateNIC(%d, _): %s", host2NICID, err) + } + + if err := routerStack.SetForwarding(ipv4.ProtocolNumber, true); err != nil { + t.Fatalf("routerStack.SetForwarding(%d): %s", ipv4.ProtocolNumber, err) + } + if err := routerStack.SetForwarding(ipv6.ProtocolNumber, true); err != nil { + t.Fatalf("routerStack.SetForwarding(%d): %s", ipv6.ProtocolNumber, err) + } + + if err := host1Stack.AddAddress(host1NICID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { + t.Fatalf("host1Stack.AddAddress(%d, %d, %s): %s", host1NICID, arp.ProtocolNumber, arp.ProtocolAddress, err) + } + if err := routerStack.AddAddress(routerNICID1, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { + t.Fatalf("routerStack.AddAddress(%d, %d, %s): %s", routerNICID1, arp.ProtocolNumber, arp.ProtocolAddress, err) + } + if err := routerStack.AddAddress(routerNICID2, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { + t.Fatalf("routerStack.AddAddress(%d, %d, %s): %s", routerNICID2, arp.ProtocolNumber, arp.ProtocolAddress, err) + } + if err := host2Stack.AddAddress(host2NICID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { + t.Fatalf("host2Stack.AddAddress(%d, %d, %s): %s", host2NICID, arp.ProtocolNumber, arp.ProtocolAddress, err) + } + + if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv4Addr); err != nil { + t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv4Addr, err) + } + if err := routerStack.AddProtocolAddress(routerNICID1, routerNIC1IPv4Addr); err != nil { + t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID1, routerNIC1IPv4Addr, err) + } + if err := routerStack.AddProtocolAddress(routerNICID2, routerNIC2IPv4Addr); err != nil { + t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID2, routerNIC2IPv4Addr, err) + } + if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv4Addr); err != nil { + t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv4Addr, err) + } + if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv6Addr); err != nil { + t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv6Addr, err) + } + if err := routerStack.AddProtocolAddress(routerNICID1, routerNIC1IPv6Addr); err != nil { + t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID1, routerNIC1IPv6Addr, err) + } + if err := routerStack.AddProtocolAddress(routerNICID2, routerNIC2IPv6Addr); err != nil { + t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID2, routerNIC2IPv6Addr, err) + } + if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv6Addr); err != nil { + t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv6Addr, err) + } + + host1Stack.SetRouteTable([]tcpip.Route{ + tcpip.Route{ + Destination: host1IPv4Addr.AddressWithPrefix.Subnet(), + NIC: host1NICID, + }, + tcpip.Route{ + Destination: host1IPv6Addr.AddressWithPrefix.Subnet(), + NIC: host1NICID, + }, + tcpip.Route{ + Destination: host2IPv4Addr.AddressWithPrefix.Subnet(), + Gateway: routerNIC1IPv4Addr.AddressWithPrefix.Address, + NIC: host1NICID, + }, + tcpip.Route{ + Destination: host2IPv6Addr.AddressWithPrefix.Subnet(), + Gateway: routerNIC1IPv6Addr.AddressWithPrefix.Address, + NIC: host1NICID, + }, + }) + routerStack.SetRouteTable([]tcpip.Route{ + tcpip.Route{ + Destination: routerNIC1IPv4Addr.AddressWithPrefix.Subnet(), + NIC: routerNICID1, + }, + tcpip.Route{ + Destination: routerNIC1IPv6Addr.AddressWithPrefix.Subnet(), + NIC: routerNICID1, + }, + tcpip.Route{ + Destination: routerNIC2IPv4Addr.AddressWithPrefix.Subnet(), + NIC: routerNICID2, + }, + tcpip.Route{ + Destination: routerNIC2IPv6Addr.AddressWithPrefix.Subnet(), + NIC: routerNICID2, + }, + }) + host2Stack.SetRouteTable([]tcpip.Route{ + tcpip.Route{ + Destination: host2IPv4Addr.AddressWithPrefix.Subnet(), + NIC: host2NICID, + }, + tcpip.Route{ + Destination: host2IPv6Addr.AddressWithPrefix.Subnet(), + NIC: host2NICID, + }, + tcpip.Route{ + Destination: host1IPv4Addr.AddressWithPrefix.Subnet(), + Gateway: routerNIC2IPv4Addr.AddressWithPrefix.Address, + NIC: host2NICID, + }, + tcpip.Route{ + Destination: host1IPv6Addr.AddressWithPrefix.Subnet(), + Gateway: routerNIC2IPv6Addr.AddressWithPrefix.Address, + NIC: host2NICID, + }, + }) + + epsAndAddrs := test.epAndAddrs(t, host1Stack, routerStack, host2Stack) + defer epsAndAddrs.serverEP.Close() + defer epsAndAddrs.clientEP.Close() + + serverAddr := tcpip.FullAddress{Addr: epsAndAddrs.serverAddr, Port: listenPort} + if err := epsAndAddrs.serverEP.Bind(serverAddr); err != nil { + t.Fatalf("epsAndAddrs.serverEP.Bind(%#v): %s", serverAddr, err) + } + clientAddr := tcpip.FullAddress{Addr: epsAndAddrs.clientAddr} + if err := epsAndAddrs.clientEP.Bind(clientAddr); err != nil { + t.Fatalf("epsAndAddrs.clientEP.Bind(%#v): %s", clientAddr, err) + } + + write := func(ep tcpip.Endpoint, data []byte, to *tcpip.FullAddress) { + t.Helper() + + dataPayload := tcpip.SlicePayload(data) + wOpts := tcpip.WriteOptions{To: to} + n, ch, err := ep.Write(dataPayload, wOpts) + if err == tcpip.ErrNoLinkAddress { + // Wait for link resolution to complete. + <-ch + + n, _, err = ep.Write(dataPayload, wOpts) + } else if err != nil { + t.Fatalf("ep.Write(_, _): %s", err) + } + + if err != nil { + t.Fatalf("ep.Write(_, _): %s", err) + } + if want := int64(len(data)); n != want { + t.Fatalf("got ep.Write(_, _) = (%d, _, _), want = (%d, _, _)", n, want) + } + } + + data := []byte{1, 2, 3, 4} + write(epsAndAddrs.clientEP, data, &serverAddr) + + read := func(ch chan struct{}, ep tcpip.Endpoint, data []byte, expectedFrom tcpip.Address) tcpip.FullAddress { + t.Helper() + + // Wait for the endpoint to be readable. + <-ch + + var addr tcpip.FullAddress + v, _, err := ep.Read(&addr) + if err != nil { + t.Fatalf("ep.Read(_): %s", err) + } + + if diff := cmp.Diff(v, buffer.View(data)); diff != "" { + t.Errorf("received data mismatch (-want +got):\n%s", diff) + } + if addr.Addr != expectedFrom { + t.Errorf("got addr.Addr = %s, want = %s", addr.Addr, expectedFrom) + } + + if t.Failed() { + t.FailNow() + } + + return addr + } + + addr := read(epsAndAddrs.serverReadableCH, epsAndAddrs.serverEP, data, epsAndAddrs.clientAddr) + // Unspecify the NIC since NIC IDs are meaningless across stacks. + addr.NIC = 0 + + data = tcpip.SlicePayload([]byte{5, 6, 7, 8, 9, 10, 11, 12}) + write(epsAndAddrs.serverEP, data, &addr) + addr = read(epsAndAddrs.clientReadableCH, epsAndAddrs.clientEP, data, epsAndAddrs.serverAddr) + if addr.Port != listenPort { + t.Errorf("got addr.Port = %d, want = %d", addr.Port, listenPort) + } + }) + } +} diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go new file mode 100644 index 000000000..bf3a6f6ee --- /dev/null +++ b/pkg/tcpip/tests/integration/link_resolution_test.go @@ -0,0 +1,219 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package integration_test + +import ( + "net" + "testing" + + "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/pipe" + "gvisor.dev/gvisor/pkg/tcpip/network/arp" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" + "gvisor.dev/gvisor/pkg/waiter" +) + +var ( + host1NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06") + host2NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09") + + host1IPv4Addr = tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("192.168.0.1").To4()), + PrefixLen: 24, + }, + } + host2IPv4Addr = tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("192.168.0.2").To4()), + PrefixLen: 8, + }, + } + host1IPv6Addr = tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("a::1").To16()), + PrefixLen: 64, + }, + } + host2IPv6Addr = tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("a::2").To16()), + PrefixLen: 64, + }, + } +) + +// TestPing tests that two hosts can ping eachother when link resolution is +// enabled. +func TestPing(t *testing.T) { + const ( + host1NICID = 1 + host2NICID = 4 + + // icmpDataOffset is the offset to the data in both ICMPv4 and ICMPv6 echo + // request/reply packets. + icmpDataOffset = 8 + ) + + tests := []struct { + name string + transProto tcpip.TransportProtocolNumber + netProto tcpip.NetworkProtocolNumber + remoteAddr tcpip.Address + icmpBuf func(*testing.T) buffer.View + }{ + { + name: "IPv4 Ping", + transProto: icmp.ProtocolNumber4, + netProto: ipv4.ProtocolNumber, + remoteAddr: host2IPv4Addr.AddressWithPrefix.Address, + icmpBuf: func(t *testing.T) buffer.View { + data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8} + hdr := header.ICMPv4(make([]byte, header.ICMPv4MinimumSize+len(data))) + hdr.SetType(header.ICMPv4Echo) + if n := copy(hdr.Payload(), data[:]); n != len(data) { + t.Fatalf("copied %d bytes but expected to copy %d bytes", n, len(data)) + } + return buffer.View(hdr) + }, + }, + { + name: "IPv6 Ping", + transProto: icmp.ProtocolNumber6, + netProto: ipv6.ProtocolNumber, + remoteAddr: host2IPv6Addr.AddressWithPrefix.Address, + icmpBuf: func(t *testing.T) buffer.View { + data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8} + hdr := header.ICMPv6(make([]byte, header.ICMPv6MinimumSize+len(data))) + hdr.SetType(header.ICMPv6EchoRequest) + if n := copy(hdr.Payload(), data[:]); n != len(data) { + t.Fatalf("copied %d bytes but expected to copy %d bytes", n, len(data)) + } + return buffer.View(hdr) + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + stackOpts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4, icmp.NewProtocol6}, + } + + host1Stack := stack.New(stackOpts) + host2Stack := stack.New(stackOpts) + + host1NIC, host2NIC := pipe.New(host1NICLinkAddr, host2NICLinkAddr, stack.CapabilityResolutionRequired) + + if err := host1Stack.CreateNIC(host1NICID, host1NIC); err != nil { + t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err) + } + if err := host2Stack.CreateNIC(host2NICID, host2NIC); err != nil { + t.Fatalf("host2Stack.CreateNIC(%d, _): %s", host2NICID, err) + } + + if err := host1Stack.AddAddress(host1NICID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { + t.Fatalf("host1Stack.AddAddress(%d, %d, %s): %s", host1NICID, arp.ProtocolNumber, arp.ProtocolAddress, err) + } + if err := host2Stack.AddAddress(host2NICID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { + t.Fatalf("host2Stack.AddAddress(%d, %d, %s): %s", host2NICID, arp.ProtocolNumber, arp.ProtocolAddress, err) + } + + if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv4Addr); err != nil { + t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv4Addr, err) + } + if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv4Addr); err != nil { + t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv4Addr, err) + } + if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv6Addr); err != nil { + t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv6Addr, err) + } + if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv6Addr); err != nil { + t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv6Addr, err) + } + + host1Stack.SetRouteTable([]tcpip.Route{ + tcpip.Route{ + Destination: host1IPv4Addr.AddressWithPrefix.Subnet(), + NIC: host1NICID, + }, + tcpip.Route{ + Destination: host1IPv6Addr.AddressWithPrefix.Subnet(), + NIC: host1NICID, + }, + }) + host2Stack.SetRouteTable([]tcpip.Route{ + tcpip.Route{ + Destination: host2IPv4Addr.AddressWithPrefix.Subnet(), + NIC: host2NICID, + }, + tcpip.Route{ + Destination: host2IPv6Addr.AddressWithPrefix.Subnet(), + NIC: host2NICID, + }, + }) + + var wq waiter.Queue + we, waiterCH := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + ep, err := host1Stack.NewEndpoint(test.transProto, test.netProto, &wq) + if err != nil { + t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", test.transProto, test.netProto, err) + } + defer ep.Close() + + // The first write should trigger link resolution. + icmpBuf := test.icmpBuf(t) + wOpts := tcpip.WriteOptions{To: &tcpip.FullAddress{Addr: test.remoteAddr}} + if _, ch, err := ep.Write(tcpip.SlicePayload(icmpBuf), wOpts); err != tcpip.ErrNoLinkAddress { + t.Fatalf("got ep.Write(_, _) = %s, want = %s", err, tcpip.ErrNoLinkAddress) + } else { + // Wait for link resolution to complete. + <-ch + } + if n, _, err := ep.Write(tcpip.SlicePayload(icmpBuf), wOpts); err != nil { + t.Fatalf("ep.Write(_, _): %s", err) + } else if want := int64(len(icmpBuf)); n != want { + t.Fatalf("got ep.Write(_, _) = (%d, _, _), want = (%d, _, _)", n, want) + } + + // Wait for the endpoint to be readable. + <-waiterCH + + var addr tcpip.FullAddress + v, _, err := ep.Read(&addr) + if err != nil { + t.Fatalf("ep.Read(_): %s", err) + } + if diff := cmp.Diff(v[icmpDataOffset:], icmpBuf[icmpDataOffset:]); diff != "" { + t.Errorf("received data mismatch (-want +got):\n%s", diff) + } + if addr.Addr != test.remoteAddr { + t.Errorf("got addr.Addr = %s, want = %s", addr.Addr, test.remoteAddr) + } + }) + } +} diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go index 72d86b5ab..4f2ca7f54 100644 --- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go +++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go @@ -203,7 +203,7 @@ func TestPingMulticastBroadcast(t *testing.T) { t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, expectedDst) } - src, dst := s.NetworkProtocolInstance(protoNum).ParseAddresses(pkt.Pkt.NetworkHeader().View()) + src, dst := s.NetworkProtocolInstance(protoNum).ParseAddresses(stack.PayloadSince(pkt.Pkt.NetworkHeader())) if src != expectedSrc { t.Errorf("got pkt source = %s, want = %s", src, expectedSrc) } diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 6891fd245..0aaef495d 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -804,7 +804,7 @@ func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso pkt.Owner = owner pkt.EgressRoute = r pkt.GSOOptions = gso - pkt.NetworkProtocolNumber = r.NetworkProtocolNumber() + pkt.NetworkProtocolNumber = r.NetProto data.ReadToVV(&pkt.Data, packetSize) buildTCPHdr(r, tf, pkt, gso) tf.seq = tf.seq.Add(seqnum.Size(packetSize)) @@ -1219,12 +1219,6 @@ func (e *endpoint) handleSegment(s *segment) (cont bool, err *tcpip.Error) { return true, nil } - // Increase counter if after processing the segment we would potentially - // advertise a zero window. - if crossed, above := e.windowCrossedACKThresholdLocked(-s.segMemSize()); crossed && !above { - e.stats.ReceiveErrors.ZeroRcvWindowState.Increment() - } - // Now check if the received segment has caused us to transition // to a CLOSED state, if yes then terminate processing and do // not invoke the sender. diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 7ad894840..3bcd3923a 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -248,6 +248,11 @@ type ReceiveErrors struct { // ZeroRcvWindowState is the number of times we advertised // a zero receive window when rcvList is full. ZeroRcvWindowState tcpip.StatCounter + + // WantZeroWindow is the number of times we wanted to advertise a + // zero receive window but couldn't because it would have caused + // the receive window's right edge to shrink. + WantZeroRcvWindow tcpip.StatCounter } // SendErrors collect segment send errors within the transport layer. @@ -1162,7 +1167,7 @@ func (e *endpoint) cleanupLocked() { // wndFromSpace returns the window that we can advertise based on the available // receive buffer space. func wndFromSpace(space int) int { - return space / (1 << rcvAdvWndScale) + return space >> rcvAdvWndScale } // initialReceiveWindow returns the initial receive window to advertise in the @@ -1518,6 +1523,38 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro return num, tcpip.ControlMessages{}, nil } +// selectWindowLocked returns the new window without checking for shrinking or scaling +// applied. +// Precondition: e.mu and e.rcvListMu must be held. +func (e *endpoint) selectWindowLocked() (wnd seqnum.Size) { + wndFromAvailable := wndFromSpace(e.receiveBufferAvailableLocked()) + maxWindow := wndFromSpace(e.rcvBufSize) + wndFromUsedBytes := maxWindow - e.rcvBufUsed + + // We take the lesser of the wndFromAvailable and wndFromUsedBytes because in + // cases where we receive a lot of small segments the segment overhead is a + // lot higher and we can run out socket buffer space before we can fill the + // previous window we advertised. In cases where we receive MSS sized or close + // MSS sized segments we will probably run out of window space before we + // exhaust receive buffer. + newWnd := wndFromAvailable + if newWnd > wndFromUsedBytes { + newWnd = wndFromUsedBytes + } + if newWnd < 0 { + newWnd = 0 + } + return seqnum.Size(newWnd) +} + +// selectWindow invokes selectWindowLocked after acquiring e.rcvListMu. +func (e *endpoint) selectWindow() (wnd seqnum.Size) { + e.rcvListMu.Lock() + wnd = e.selectWindowLocked() + e.rcvListMu.Unlock() + return wnd +} + // windowCrossedACKThresholdLocked checks if the receive window to be announced // would be under aMSS or under the window derived from half receive buffer, // whichever smaller. This is useful as a receive side silly window syndrome @@ -1534,7 +1571,7 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro // // Precondition: e.mu and e.rcvListMu must be held. func (e *endpoint) windowCrossedACKThresholdLocked(deltaBefore int) (crossed bool, above bool) { - newAvail := wndFromSpace(e.receiveBufferAvailableLocked()) + newAvail := int(e.selectWindowLocked()) oldAvail := newAvail - deltaBefore if oldAvail < 0 { oldAvail = 0 @@ -2099,7 +2136,7 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { case *tcpip.OriginalDestinationOption: e.LockUser() ipt := e.stack.IPTables() - addr, port, err := ipt.OriginalDst(e.ID) + addr, port, err := ipt.OriginalDst(e.ID, e.NetProto) e.UnlockUser() if err != nil { return err @@ -3013,6 +3050,7 @@ func (e *endpoint) completeState() stack.TCPEndpointState { EndSequence: rc.endSequence, FACK: rc.fack, RTT: rc.rtt, + Reord: rc.reorderSeen, } return s } diff --git a/pkg/tcpip/transport/tcp/rack.go b/pkg/tcpip/transport/tcp/rack.go index d969ca23a..d312b1b8b 100644 --- a/pkg/tcpip/transport/tcp/rack.go +++ b/pkg/tcpip/transport/tcp/rack.go @@ -29,26 +29,36 @@ import ( // // +stateify savable type rackControl struct { - // xmitTime is the latest transmission timestamp of rackControl.seg. - xmitTime time.Time `state:".(unixTime)"` - // endSequence is the ending TCP sequence number of rackControl.seg. endSequence seqnum.Value + // dsack indicates if the connection has seen a DSACK. + dsack bool + // fack is the highest selectively or cumulatively acknowledged // sequence. fack seqnum.Value + // minRTT is the estimated minimum RTT of the connection. + minRTT time.Duration + // rtt is the RTT of the most recently delivered packet on the // connection (either cumulatively acknowledged or selectively // acknowledged) that was not marked invalid as a possible spurious // retransmission. rtt time.Duration + + // reorderSeen indicates if reordering has been detected on this + // connection. + reorderSeen bool + + // xmitTime is the latest transmission timestamp of rackControl.seg. + xmitTime time.Time `state:".(unixTime)"` } -// Update will update the RACK related fields when an ACK has been received. +// update will update the RACK related fields when an ACK has been received. // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2 -func (rc *rackControl) Update(seg *segment, ackSeg *segment, srtt time.Duration, offset uint32) { +func (rc *rackControl) update(seg *segment, ackSeg *segment, offset uint32) { rtt := time.Now().Sub(seg.xmitTime) // If the ACK is for a retransmitted packet, do not update if it is a @@ -65,12 +75,21 @@ func (rc *rackControl) Update(seg *segment, ackSeg *segment, srtt time.Duration, return } } - if rtt < srtt { + if rtt < rc.minRTT { return } } rc.rtt = rtt + + // The sender can either track a simple global minimum of all RTT + // measurements from the connection, or a windowed min-filtered value + // of recent RTT measurements. This implementation keeps track of the + // simple global minimum of all RTTs for the connection. + if rtt < rc.minRTT || rc.minRTT == 0 { + rc.minRTT = rtt + } + // Update rc.xmitTime and rc.endSequence to the transmit time and // ending sequence number of the packet which has been acknowledged // most recently. @@ -80,3 +99,26 @@ func (rc *rackControl) Update(seg *segment, ackSeg *segment, srtt time.Duration, rc.endSequence = endSeq } } + +// detectReorder detects if packet reordering has been observed. +// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2 +// * Step 3: Detect data segment reordering. +// To detect reordering, the sender looks for original data segments being +// delivered out of order. To detect such cases, the sender tracks the +// highest sequence selectively or cumulatively acknowledged in the RACK.fack +// variable. The name "fack" stands for the most "Forward ACK" (this term is +// adopted from [FACK]). If a never retransmitted segment that's below +// RACK.fack is (selectively or cumulatively) acknowledged, it has been +// delivered out of order. The sender sets RACK.reord to TRUE if such segment +// is identified. +func (rc *rackControl) detectReorder(seg *segment) { + endSeq := seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) + if rc.fack.LessThan(endSeq) { + rc.fack = endSeq + return + } + + if endSeq.LessThan(rc.fack) && seg.xmitCount == 1 { + rc.reorderSeen = true + } +} diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index 48bf196d8..8e0b7c843 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -43,6 +43,9 @@ type receiver struct { // rcvWnd is the non-scaled receive window last advertised to the peer. rcvWnd seqnum.Size + // rcvWUP is the rcvNxt value at the last window update sent. + rcvWUP seqnum.Value + rcvWndScale uint8 closed bool @@ -64,6 +67,7 @@ func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale rcvNxt: irs + 1, rcvAcc: irs.Add(rcvWnd + 1), rcvWnd: rcvWnd, + rcvWUP: irs + 1, rcvWndScale: rcvWndScale, lastRcvdAckTime: time.Now(), } @@ -84,34 +88,54 @@ func (r *receiver) acceptable(segSeq seqnum.Value, segLen seqnum.Size) bool { return header.Acceptable(segSeq, segLen, r.rcvNxt, r.rcvNxt.Add(advertisedWindowSize)) } +// currentWindow returns the available space in the window that was advertised +// last to our peer. +func (r *receiver) currentWindow() (curWnd seqnum.Size) { + endOfWnd := r.rcvWUP.Add(r.rcvWnd) + if endOfWnd.LessThan(r.rcvNxt) { + // return 0 if r.rcvNxt is past the end of the previously advertised window. + // This can happen because we accept a large segment completely even if + // accepting it causes it to partially exceed the advertised window. + return 0 + } + return r.rcvNxt.Size(endOfWnd) +} + // getSendParams returns the parameters needed by the sender when building // segments to send. func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) { - avail := wndFromSpace(r.ep.receiveBufferAvailable()) - if avail == 0 { - // We have no space available to accept any data, move to zero window - // state. - r.rcvWnd = 0 - return r.rcvNxt, 0 - } - - acc := r.rcvNxt.Add(seqnum.Size(avail)) - newWnd := r.rcvNxt.Size(acc) - curWnd := r.rcvNxt.Size(r.rcvAcc) - + newWnd := r.ep.selectWindow() + curWnd := r.currentWindow() // Update rcvAcc only if new window is > previously advertised window. We // should never shrink the acceptable sequence space once it has been // advertised the peer. If we shrink the acceptable sequence space then we // would end up dropping bytes that might already be in flight. - if newWnd > curWnd { - r.rcvAcc = r.rcvNxt.Add(newWnd) + // ==================================================== sequence space. + // ^ ^ ^ ^ + // rcvWUP rcvNxt rcvAcc new rcvAcc + // <=====curWnd ===> + // <========= newWnd > curWnd ========= > + if r.rcvNxt.Add(seqnum.Size(curWnd)).LessThan(r.rcvNxt.Add(seqnum.Size(newWnd))) { + // If the new window moves the right edge, then update rcvAcc. + r.rcvAcc = r.rcvNxt.Add(seqnum.Size(newWnd)) } else { + if newWnd == 0 { + // newWnd is zero but we can't advertise a zero as it would cause window + // to shrink so just increment a metric to record this event. + r.ep.stats.ReceiveErrors.WantZeroRcvWindow.Increment() + } newWnd = curWnd } // Stash away the non-scaled receive window as we use it for measuring // receiver's estimated RTT. r.rcvWnd = newWnd - return r.rcvNxt, r.rcvWnd >> r.rcvWndScale + r.rcvWUP = r.rcvNxt + scaledWnd := r.rcvWnd >> r.rcvWndScale + if scaledWnd == 0 { + // Increment a metric if we are advertising an actual zero window. + r.ep.stats.ReceiveErrors.ZeroRcvWindowState.Increment() + } + return r.rcvNxt, scaledWnd } // nonZeroWindow is called when the receive window grows from zero to nonzero; diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index 13acaf753..1f9c5cf50 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -71,6 +71,9 @@ type segment struct { // xmitTime is the last transmit time of this segment. xmitTime time.Time `state:".(unixTime)"` xmitCount uint32 + + // acked indicates if the segment has already been SACKed. + acked bool } func newSegment(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) *segment { diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index c55589c45..6fa8d63cd 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -17,6 +17,7 @@ package tcp import ( "fmt" "math" + "sort" "sync/atomic" "time" @@ -263,6 +264,9 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint highRxt: iss, rescueRxt: iss, }, + rc: rackControl{ + fack: iss, + }, gso: ep.gso != nil, } @@ -1274,6 +1278,39 @@ func (s *sender) checkDuplicateAck(seg *segment) (rtx bool) { return true } +// Iterate the writeList and update RACK for each segment which is newly acked +// either cumulatively or selectively. Loop through the segments which are +// sacked, and update the RACK related variables and check for reordering. +// +// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2 +// steps 2 and 3. +func (s *sender) walkSACK(rcvdSeg *segment) { + // Sort the SACK blocks. The first block is the most recent unacked + // block. The following blocks can be in arbitrary order. + sackBlocks := make([]header.SACKBlock, len(rcvdSeg.parsedOptions.SACKBlocks)) + copy(sackBlocks, rcvdSeg.parsedOptions.SACKBlocks) + sort.Slice(sackBlocks, func(i, j int) bool { + return sackBlocks[j].Start.LessThan(sackBlocks[i].Start) + }) + + seg := s.writeList.Front() + for _, sb := range sackBlocks { + // This check excludes DSACK blocks. + if sb.Start.LessThanEq(rcvdSeg.ackNumber) || sb.Start.LessThanEq(s.sndUna) || s.sndNxt.LessThan(sb.End) { + continue + } + + for seg != nil && seg.sequenceNumber.LessThan(sb.End) && seg.xmitCount != 0 { + if sb.Start.LessThanEq(seg.sequenceNumber) && !seg.acked { + s.rc.update(seg, rcvdSeg, s.ep.tsOffset) + s.rc.detectReorder(seg) + seg.acked = true + } + seg = seg.Next() + } + } +} + // handleRcvdSegment is called when a segment is received; it is responsible for // updating the send-related state. func (s *sender) handleRcvdSegment(rcvdSeg *segment) { @@ -1308,6 +1345,21 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { rcvdSeg.hasNewSACKInfo = true } } + + // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08 + // section-7.2 + // * Step 2: Update RACK stats. + // If the ACK is not ignored as invalid, update the RACK.rtt + // to be the RTT sample calculated using this ACK, and + // continue. If this ACK or SACK was for the most recently + // sent packet, then record the RACK.xmit_ts timestamp and + // RACK.end_seq sequence implied by this ACK. + // * Step 3: Detect packet reordering. + // If the ACK selectively or cumulatively acknowledges an + // unacknowledged and also never retransmitted sequence below + // RACK.fack, then the corresponding packet has been + // reordered and RACK.reord is set to TRUE. + s.walkSACK(rcvdSeg) s.SetPipe() } @@ -1365,9 +1417,6 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { ackLeft := acked originalOutstanding := s.outstanding - s.rtt.Lock() - srtt := s.rtt.srtt - s.rtt.Unlock() for ackLeft > 0 { // We use logicalLen here because we can have FIN // segments (which are always at the end of list) that @@ -1388,13 +1437,14 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { } // Update the RACK fields if SACK is enabled. - if s.ep.sackPermitted { - s.rc.Update(seg, rcvdSeg, srtt, s.ep.tsOffset) + if s.ep.sackPermitted && !seg.acked { + s.rc.update(seg, rcvdSeg, s.ep.tsOffset) + s.rc.detectReorder(seg) } s.writeList.Remove(seg) - // if SACK is enabled then Only reduce outstanding if + // If SACK is enabled then Only reduce outstanding if // the segment was not previously SACKED as these have // already been accounted for in SetPipe(). if !s.ep.sackPermitted || !s.ep.scoreboard.IsSACKED(seg.sackBlock()) { diff --git a/pkg/tcpip/transport/tcp/tcp_rack_test.go b/pkg/tcpip/transport/tcp/tcp_rack_test.go index e03f101e8..d3f92b48c 100644 --- a/pkg/tcpip/transport/tcp/tcp_rack_test.go +++ b/pkg/tcpip/transport/tcp/tcp_rack_test.go @@ -21,17 +21,20 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context" ) +const ( + maxPayload = 10 + tsOptionSize = 12 + maxTCPOptionSize = 40 +) + // TestRACKUpdate tests the RACK related fields are updated when an ACK is // received on a SACK enabled connection. func TestRACKUpdate(t *testing.T) { - const maxPayload = 10 - const tsOptionSize = 12 - const maxTCPOptionSize = 40 - c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxTCPOptionSize+maxPayload)) defer c.Cleanup() @@ -49,7 +52,7 @@ func TestRACKUpdate(t *testing.T) { } if state.Sender.RACKState.RTT == 0 { - t.Fatalf("RACK RTT failed to update when an ACK is received") + t.Fatalf("RACK RTT failed to update when an ACK is received, got RACKState.RTT == 0 want != 0") } }) setStackSACKPermitted(t, c, true) @@ -69,6 +72,66 @@ func TestRACKUpdate(t *testing.T) { bytesRead := 0 c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) bytesRead += maxPayload - c.SendAck(790, bytesRead) + c.SendAck(seqnum.Value(context.TestInitialSequenceNumber).Add(1), bytesRead) time.Sleep(200 * time.Millisecond) } + +// TestRACKDetectReorder tests that RACK detects packet reordering. +func TestRACKDetectReorder(t *testing.T) { + c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxTCPOptionSize+maxPayload)) + defer c.Cleanup() + + const ackNum = 2 + + var n int + ch := make(chan struct{}) + c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) { + gotSeq := state.Sender.RACKState.FACK + wantSeq := state.Sender.SndNxt + // FACK should be updated to the highest ending sequence number of the + // segment acknowledged most recently. + if !gotSeq.LessThanEq(wantSeq) || gotSeq.LessThan(wantSeq) { + t.Fatalf("RACK FACK failed to update, got: %v, but want: %v", gotSeq, wantSeq) + } + + n++ + if n < ackNum { + if state.Sender.RACKState.Reord { + t.Fatalf("RACK reorder detected when there is no reordering") + } + return + } + + if state.Sender.RACKState.Reord == false { + t.Fatalf("RACK reorder detection failed") + } + close(ch) + }) + setStackSACKPermitted(t, c, true) + createConnectedWithSACKAndTS(c) + data := buffer.NewView(ackNum * maxPayload) + for i := range data { + data[i] = byte(i) + } + + // Write the data. + if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } + + bytesRead := 0 + for i := 0; i < ackNum; i++ { + c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) + bytesRead += maxPayload + } + + start := c.IRS.Add(maxPayload + 1) + end := start.Add(maxPayload) + seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + c.SendAckWithSACK(seq, 0, []header.SACKBlock{{start, end}}) + c.SendAck(seq, bytesRead) + + // Wait for the probe function to finish processing the ACK before the + // test completes. + <-ch +} diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 5b504d0d1..a7149efd0 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -6264,14 +6264,27 @@ func TestReceiveBufferAutoTuning(t *testing.T) { rawEP.NextSeqNum-- rawEP.SendPacketWithTS(nil, tsVal) rawEP.NextSeqNum++ + if i == 0 { // In the first iteration the receiver based RTT is not // yet known as a result the moderation code should not // increase the advertised window. rawEP.VerifyACKRcvWnd(scaleRcvWnd(curRcvWnd)) } else { - pkt := c.GetPacket() - curRcvWnd = int(header.TCP(header.IPv4(pkt).Payload()).WindowSize()) << c.WindowScale + // Read loop above could generate an ACK if the window had dropped to + // zero and then read had opened it up. + lastACK := c.GetPacket() + // Discard any intermediate ACKs and only check the last ACK we get in a + // short time period of few ms. + for { + time.Sleep(1 * time.Millisecond) + pkt := c.GetPacketNonBlocking() + if pkt == nil { + break + } + lastACK = pkt + } + curRcvWnd = int(header.TCP(header.IPv4(lastACK).Payload()).WindowSize()) << c.WindowScale // If thew new current window is close maxReceiveBufferSize then terminate // the loop. This can happen before all iterations are done due to timing // differences when running the test. @@ -7328,7 +7341,7 @@ func TestIncreaseWindowOnBufferResize(t *testing.T) { // Write chunks of ~30000 bytes. It's important that two // payloads make it equal or longer than MSS. - remain := rcvBuf * 2 + remain := rcvBuf sent := 0 data := make([]byte, defaultMTU/2) @@ -7343,7 +7356,6 @@ func TestIncreaseWindowOnBufferResize(t *testing.T) { }) sent += len(data) remain -= len(data) - checker.IPv4(t, c.GetPacket(), checker.PayloadLen(header.TCPMinimumSize), checker.TCP( diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index faf51ef95..4d7847142 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -68,9 +68,9 @@ const ( // V4MappedWildcardAddr is the mapped v6 representation of 0.0.0.0. V4MappedWildcardAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00" - // testInitialSequenceNumber is the initial sequence number sent in packets that + // TestInitialSequenceNumber is the initial sequence number sent in packets that // are sent in response to a SYN or in the initial SYN sent to the stack. - testInitialSequenceNumber = 789 + TestInitialSequenceNumber = 789 ) // StackAddrWithPrefix is StackAddr with its associated prefix length. @@ -505,7 +505,7 @@ func (c *Context) ReceiveAndCheckPacketWithOptions(data []byte, offset, size, op checker.TCP( checker.DstPort(TestPort), checker.TCPSeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))), - checker.TCPAckNum(uint32(seqnum.Value(testInitialSequenceNumber).Add(1))), + checker.TCPAckNum(uint32(seqnum.Value(TestInitialSequenceNumber).Add(1))), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -532,7 +532,7 @@ func (c *Context) ReceiveNonBlockingAndCheckPacket(data []byte, offset, size int checker.TCP( checker.DstPort(TestPort), checker.TCPSeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))), - checker.TCPAckNum(uint32(seqnum.Value(testInitialSequenceNumber).Add(1))), + checker.TCPAckNum(uint32(seqnum.Value(TestInitialSequenceNumber).Add(1))), checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) @@ -912,7 +912,7 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) * // Build SYN-ACK. c.IRS = seqnum.Value(tcpSeg.SequenceNumber()) - iss := seqnum.Value(testInitialSequenceNumber) + iss := seqnum.Value(TestInitialSequenceNumber) c.SendPacket(nil, &Headers{ SrcPort: tcpSeg.DestinationPort(), DstPort: tcpSeg.SourcePort(), @@ -1084,7 +1084,7 @@ func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions offset += paddingToAdd // Send a SYN request. - iss := seqnum.Value(testInitialSequenceNumber) + iss := seqnum.Value(TestInitialSequenceNumber) c.SendPacket(nil, &Headers{ SrcPort: TestPort, DstPort: StackPort, diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 4e14a5fc5..b4604ba35 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -1432,7 +1432,9 @@ func TestNoChecksum(t *testing.T) { var _ stack.NetworkInterface = (*testInterface)(nil) -type testInterface struct{} +type testInterface struct { + stack.NetworkLinkEndpoint +} func (*testInterface) ID() tcpip.NICID { return 0 @@ -1450,10 +1452,6 @@ func (*testInterface) Enabled() bool { return true } -func (*testInterface) LinkEndpoint() stack.LinkEndpoint { - return nil -} - func TestTTL(t *testing.T) { for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6} { t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { @@ -1780,16 +1778,26 @@ func TestV4UnknownDestination(t *testing.T) { checker.ICMPv4Type(header.ICMPv4DstUnreachable), checker.ICMPv4Code(header.ICMPv4PortUnreachable))) + // We need to compare the included data part of the UDP packet that is in + // the ICMP packet with the matching original data. icmpPkt := header.ICMPv4(hdr.Payload()) payloadIPHeader := header.IPv4(icmpPkt.Payload()) + incomingHeaderLength := header.IPv4MinimumSize + header.UDPMinimumSize wantLen := len(payload) if tc.largePayload { - wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MinimumSize*2 - header.ICMPv4MinimumSize - header.UDPMinimumSize + // To work out the data size we need to simulate what the sender would + // have done. The wanted size is the total available minus the sum of + // the headers in the UDP AND ICMP packets, given that we know the test + // had only a minimal IP header but the ICMP sender will have allowed + // for a maximally sized packet header. + wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MaximumHeaderSize - header.ICMPv4MinimumSize - incomingHeaderLength + } - // In case of large payloads the IP packet may be truncated. Update + // In the case of large payloads the IP packet may be truncated. Update // the length field before retrieving the udp datagram payload. - payloadIPHeader.SetTotalLength(uint16(wantLen + header.UDPMinimumSize + header.IPv4MinimumSize)) + // Add back the two headers within the payload. + payloadIPHeader.SetTotalLength(uint16(wantLen + incomingHeaderLength)) origDgram := header.UDP(payloadIPHeader.Payload()) if got, want := len(origDgram.Payload()), wantLen; got != want { @@ -2015,7 +2023,8 @@ func TestPayloadModifiedV4(t *testing.T) { payload := newPayload() h := unicastV4.header4Tuple(incoming) buf := c.buildV4Packet(payload, &h) - // Modify the payload so that the checksum value in the UDP header will be incorrect. + // Modify the payload so that the checksum value in the UDP header will be + // incorrect. buf[len(buf)-1]++ c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), @@ -2045,7 +2054,8 @@ func TestPayloadModifiedV6(t *testing.T) { payload := newPayload() h := unicastV6.header4Tuple(incoming) buf := c.buildV6Packet(payload, &h) - // Modify the payload so that the checksum value in the UDP header will be incorrect. + // Modify the payload so that the checksum value in the UDP header will be + // incorrect. buf[len(buf)-1]++ c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), diff --git a/pkg/test/testutil/testutil.go b/pkg/test/testutil/testutil.go index 06fb823f6..49ab87c58 100644 --- a/pkg/test/testutil/testutil.go +++ b/pkg/test/testutil/testutil.go @@ -270,7 +270,7 @@ func RandomID(prefix string) string { // same name, sometimes between test runs the socket does not get cleaned up // quickly enough, causing container creation to fail. func RandomContainerID() string { - return RandomID("test-container-") + return RandomID("test-container") } // Copy copies file from src to dst. diff --git a/runsc/boot/filter/config.go b/runsc/boot/filter/config.go index 6ac19668f..a7c4ebb0c 100644 --- a/runsc/boot/filter/config.go +++ b/runsc/boot/filter/config.go @@ -162,6 +162,12 @@ var allowedSyscalls = seccomp.SyscallRules{ }, syscall.SYS_LSEEK: {}, syscall.SYS_MADVISE: {}, + unix.SYS_MEMBARRIER: []seccomp.Rule{ + { + seccomp.EqualTo(linux.MEMBARRIER_CMD_GLOBAL), + seccomp.EqualTo(0), + }, + }, syscall.SYS_MINCORE: {}, // Used by the Go runtime as a temporarily workaround for a Linux // 5.2-5.4 bug. diff --git a/runsc/boot/loader.go b/runsc/boot/loader.go index 2e652ddad..9a08ebc60 100644 --- a/runsc/boot/loader.go +++ b/runsc/boot/loader.go @@ -282,6 +282,7 @@ func New(args Args) (*Loader, error) { args.NumCPU = runtime.NumCPU() } log.Infof("CPUs: %d", args.NumCPU) + runtime.GOMAXPROCS(args.NumCPU) if args.TotalMem > 0 { // Adjust the total memory returned by the Sentry so that applications that @@ -471,9 +472,13 @@ func (l *Loader) Destroy() { } l.watchdog.Stop() + // Release all kernel resources. This is only safe after we can no longer + // save/restore. + l.k.Release() + // In the success case, stdioFDs and goferFDs will only contain // released/closed FDs that ownership has been passed over to host FDs and - // gofer sessions. Close them here in case on failure. + // gofer sessions. Close them here in case of failure. for _, fd := range l.root.stdioFDs { _ = fd.Close() } @@ -797,7 +802,7 @@ func (l *Loader) createContainerProcess(root bool, cid string, info *containerIn } // startGoferMonitor runs a goroutine to monitor gofer's health. It polls on -// the gofer FDs looking for disconnects, and destroys the container if a +// the gofer FDs looking for disconnects, and kills the container processes if a // disconnect occurs in any of the gofer FDs. func (l *Loader) startGoferMonitor(cid string, goferFDs []*fd.FD) { go func() { @@ -818,18 +823,15 @@ func (l *Loader) startGoferMonitor(cid string, goferFDs []*fd.FD) { panic(fmt.Sprintf("Error monitoring gofer FDs: %v", err)) } - // Check if the gofer has stopped as part of normal container destruction. - // This is done just to avoid sending an annoying error message to the log. - // Note that there is a small race window in between mu.Unlock() and the - // lock being reacquired in destroyContainer(), but it's harmless to call - // destroyContainer() multiple times. l.mu.Lock() - _, ok := l.processes[execID{cid: cid}] - l.mu.Unlock() - if ok { - log.Infof("Gofer socket disconnected, destroying container %q", cid) - if err := l.destroyContainer(cid); err != nil { - log.Warningf("Error destroying container %q after gofer stopped: %v", cid, err) + defer l.mu.Unlock() + + // The gofer could have been stopped due to a normal container shutdown. + // Check if the container has not stopped yet. + if tg, _ := l.tryThreadGroupFromIDLocked(execID{cid: cid}); tg != nil { + log.Infof("Gofer socket disconnected, killing container %q", cid) + if err := l.signalAllProcesses(cid, int32(linux.SIGKILL)); err != nil { + log.Warningf("Error killing container %q after gofer stopped: %v", cid, err) } } }() @@ -898,17 +900,24 @@ func (l *Loader) executeAsync(args *control.ExecArgs) (kernel.ThreadID, error) { return 0, fmt.Errorf("container %q not started", args.ContainerID) } - // Get the container MountNamespace from the Task. + // Get the container MountNamespace from the Task. Try to acquire ref may fail + // in case it raced with task exit. if kernel.VFS2Enabled { // task.MountNamespace() does not take a ref, so we must do so ourselves. args.MountNamespaceVFS2 = tg.Leader().MountNamespaceVFS2() - args.MountNamespaceVFS2.IncRef() + if !args.MountNamespaceVFS2.TryIncRef() { + return 0, fmt.Errorf("container %q has stopped", args.ContainerID) + } } else { + var reffed bool tg.Leader().WithMuLocked(func(t *kernel.Task) { // task.MountNamespace() does not take a ref, so we must do so ourselves. args.MountNamespace = t.MountNamespace() - args.MountNamespace.IncRef() + reffed = args.MountNamespace.TryIncRef() }) + if !reffed { + return 0, fmt.Errorf("container %q has stopped", args.ContainerID) + } } // Add the HOME environment variable if it is not already set. diff --git a/runsc/boot/loader_test.go b/runsc/boot/loader_test.go index bf9ec5d38..1f49431a2 100644 --- a/runsc/boot/loader_test.go +++ b/runsc/boot/loader_test.go @@ -264,7 +264,7 @@ type CreateMountTestcase struct { expectedPaths []string } -func createMountTestcases(vfs2 bool) []*CreateMountTestcase { +func createMountTestcases() []*CreateMountTestcase { testCases := []*CreateMountTestcase{ &CreateMountTestcase{ // Only proc. @@ -409,32 +409,26 @@ func createMountTestcases(vfs2 bool) []*CreateMountTestcase { Destination: "/proc", Type: "tmpfs", }, - // TODO (gvisor.dev/issue/1487): Re-add this case when sysfs supports - // MkDirAt in VFS2 (and remove the reduntant append). - // { - // Destination: "/sys/bar", - // Type: "tmpfs", - // }, - // + { + Destination: "/sys/bar", + Type: "tmpfs", + }, + { Destination: "/tmp/baz", Type: "tmpfs", }, }, }, - expectedPaths: []string{"/proc", "/sys" /* "/sys/bar" ,*/, "/tmp", "/tmp/baz"}, + expectedPaths: []string{"/proc", "/sys", "/sys/bar", "/tmp", "/tmp/baz"}, } - if !vfs2 { - vfsCase.spec.Mounts = append(vfsCase.spec.Mounts, specs.Mount{Destination: "/sys/bar", Type: "tmpfs"}) - vfsCase.expectedPaths = append(vfsCase.expectedPaths, "/sys/bar") - } return append(testCases, vfsCase) } // Test that MountNamespace can be created with various specs. func TestCreateMountNamespace(t *testing.T) { - for _, tc := range createMountTestcases(false /* vfs2 */) { + for _, tc := range createMountTestcases() { t.Run(tc.name, func(t *testing.T) { conf := testConfig() ctx := contexttest.Context(t) @@ -471,7 +465,7 @@ func TestCreateMountNamespace(t *testing.T) { // Test that MountNamespace can be created with various specs. func TestCreateMountNamespaceVFS2(t *testing.T) { - for _, tc := range createMountTestcases(true /* vfs2 */) { + for _, tc := range createMountTestcases() { t.Run(tc.name, func(t *testing.T) { spec := testSpec() spec.Mounts = tc.spec.Mounts diff --git a/runsc/container/container_test.go b/runsc/container/container_test.go index 548c68087..1f8e277cc 100644 --- a/runsc/container/container_test.go +++ b/runsc/container/container_test.go @@ -316,6 +316,7 @@ func configs(t *testing.T, opts ...configOption) map[string]*config.Config { return cs } +// TODO(gvisor.dev/issue/1624): Merge with configs when VFS2 is the default. func configsWithVFS2(t *testing.T, opts ...configOption) map[string]*config.Config { all := configs(t, opts...) for key, value := range configs(t, opts...) { @@ -894,13 +895,15 @@ func TestKillPid(t *testing.T) { } } -// TestCheckpointRestore creates a container that continuously writes successive integers -// to a file. To test checkpoint and restore functionality, the container is -// checkpointed and the last number printed to the file is recorded. Then, it is restored in two -// new containers and the first number printed from these containers is checked. Both should -// be the next consecutive number after the last number from the checkpointed container. +// TestCheckpointRestore creates a container that continuously writes successive +// integers to a file. To test checkpoint and restore functionality, the +// container is checkpointed and the last number printed to the file is +// recorded. Then, it is restored in two new containers and the first number +// printed from these containers is checked. Both should be the next consecutive +// number after the last number from the checkpointed container. func TestCheckpointRestore(t *testing.T) { // Skip overlay because test requires writing to host file. + // TODO(gvisor.dev/issue/1663): Add VFS when S/R support is added. for name, conf := range configs(t, noOverlay...) { t.Run(name, func(t *testing.T) { dir, err := ioutil.TempDir(testutil.TmpDir(), "checkpoint-test") @@ -1062,6 +1065,7 @@ func TestCheckpointRestore(t *testing.T) { // with filesystem Unix Domain Socket use. func TestUnixDomainSockets(t *testing.T) { // Skip overlay because test requires writing to host file. + // TODO(gvisor.dev/issue/1663): Add VFS when S/R support is added. for name, conf := range configs(t, noOverlay...) { t.Run(name, func(t *testing.T) { // UDS path is limited to 108 chars for compatibility with older systems. @@ -1199,7 +1203,7 @@ func TestUnixDomainSockets(t *testing.T) { // recreated. Then it resumes the container, verify that the file gets created // again. func TestPauseResume(t *testing.T) { - for name, conf := range configs(t, noOverlay...) { + for name, conf := range configsWithVFS2(t, noOverlay...) { t.Run(name, func(t *testing.T) { tmpDir, err := ioutil.TempDir(testutil.TmpDir(), "lock") if err != nil { diff --git a/runsc/container/multi_container_test.go b/runsc/container/multi_container_test.go index 952215ec1..850e80290 100644 --- a/runsc/container/multi_container_test.go +++ b/runsc/container/multi_container_test.go @@ -480,7 +480,7 @@ func TestMultiContainerMount(t *testing.T) { // TestMultiContainerSignal checks that it is possible to signal individual // containers without killing the entire sandbox. func TestMultiContainerSignal(t *testing.T) { - for name, conf := range configs(t, all...) { + for name, conf := range configsWithVFS2(t, all...) { t.Run(name, func(t *testing.T) { rootDir, cleanup, err := testutil.SetupRootDir() if err != nil { @@ -1691,12 +1691,11 @@ func TestMultiContainerRunNonRoot(t *testing.T) { } // TestMultiContainerHomeEnvDir tests that the HOME environment variable is set -// for root containers, sub-containers, and execed processes. +// for root containers, sub-containers, and exec'ed processes. func TestMultiContainerHomeEnvDir(t *testing.T) { - // TODO(gvisor.dev/issue/1487): VFSv2 configs failing. // NOTE: Don't use overlay since we need changes to persist to the temp dir // outside the sandbox. - for testName, conf := range configs(t, noOverlay...) { + for testName, conf := range configsWithVFS2(t, noOverlay...) { t.Run(testName, func(t *testing.T) { rootDir, cleanup, err := testutil.SetupRootDir() @@ -1718,9 +1717,9 @@ func TestMultiContainerHomeEnvDir(t *testing.T) { // We will sleep in the root container in order to ensure that the root //container doesn't terminate before sub containers can be created. - rootCmd := []string{"/bin/sh", "-c", fmt.Sprintf("printf \"$HOME\" > %s; sleep 1000", homeDirs["root"].Name())} - subCmd := []string{"/bin/sh", "-c", fmt.Sprintf("printf \"$HOME\" > %s", homeDirs["sub"].Name())} - execCmd := fmt.Sprintf("printf \"$HOME\" > %s", homeDirs["exec"].Name()) + rootCmd := []string{"/bin/sh", "-c", fmt.Sprintf(`printf "$HOME" > %s; sleep 1000`, homeDirs["root"].Name())} + subCmd := []string{"/bin/sh", "-c", fmt.Sprintf(`printf "$HOME" > %s`, homeDirs["sub"].Name())} + execCmd := fmt.Sprintf(`printf "$HOME" > %s`, homeDirs["exec"].Name()) // Setup the containers, a root container and sub container. specConfig, ids := createSpecs(rootCmd, subCmd) diff --git a/runsc/flag/flag.go b/runsc/flag/flag.go index ba1ff833f..775325c06 100644 --- a/runsc/flag/flag.go +++ b/runsc/flag/flag.go @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +// Package flag wraps flag primitives. package flag import ( diff --git a/runsc/sandbox/network.go b/runsc/sandbox/network.go index 0b9f39466..8f66dd1f8 100644 --- a/runsc/sandbox/network.go +++ b/runsc/sandbox/network.go @@ -309,11 +309,20 @@ func createSocket(iface net.Interface, ifaceLink netlink.Link, enableGSO bool) ( const bufSize = 4 << 20 // 4MB. if err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUFFORCE, bufSize); err != nil { - return nil, fmt.Errorf("failed to increase socket rcv buffer to %d: %v", bufSize, err) + syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUF, bufSize) + sz, _ := syscall.GetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUF) + + if sz < bufSize { + log.Warningf("Failed to increase rcv buffer to %d on SOCK_RAW on %s. Current buffer %d: %v", bufSize, iface.Name, sz, err) + } } if err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_SNDBUFFORCE, bufSize); err != nil { - return nil, fmt.Errorf("failed to increase socket snd buffer to %d: %v", bufSize, err) + syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_SNDBUF, bufSize) + sz, _ := syscall.GetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_SNDBUF) + if sz < bufSize { + log.Warningf("Failed to increase snd buffer to %d on SOCK_RAW on %s. Curent buffer %d: %v", bufSize, iface.Name, sz, err) + } } return &socketEntry{deviceFile, gsoMaxSize}, nil diff --git a/runsc/sandbox/sandbox.go b/runsc/sandbox/sandbox.go index a8f4f64a5..c4309feb3 100644 --- a/runsc/sandbox/sandbox.go +++ b/runsc/sandbox/sandbox.go @@ -72,11 +72,14 @@ type Sandbox struct { // will have it as a child process. child bool - // status is an exit status of a sandbox process. - status syscall.WaitStatus - // statusMu protects status. statusMu sync.Mutex + + // status is the exit status of a sandbox process. It's only set if the + // child==true and the sandbox was waited on. This field allows for multiple + // threads to wait on sandbox and get the exit code, since Linux will return + // WaitStatus to one of the waiters only. + status syscall.WaitStatus } // Args is used to configure a new sandbox. @@ -746,35 +749,47 @@ func (s *Sandbox) createSandboxProcess(conf *config.Config, args *Args, startSyn // Wait waits for the containerized process to exit, and returns its WaitStatus. func (s *Sandbox) Wait(cid string) (syscall.WaitStatus, error) { log.Debugf("Waiting for container %q in sandbox %q", cid, s.ID) - var ws syscall.WaitStatus if conn, err := s.sandboxConnect(); err != nil { - // The sandbox may have exited while before we had a chance to - // wait on it. + // The sandbox may have exited while before we had a chance to wait on it. + // There is nothing we can do for subcontainers. For the init container, we + // can try to get the sandbox exit code. + if !s.IsRootContainer(cid) { + return syscall.WaitStatus(0), err + } log.Warningf("Wait on container %q failed: %v. Will try waiting on the sandbox process instead.", cid, err) } else { defer conn.Close() + // Try the Wait RPC to the sandbox. + var ws syscall.WaitStatus err = conn.Call(boot.ContainerWait, &cid, &ws) if err == nil { // It worked! return ws, nil } + // See comment above. + if !s.IsRootContainer(cid) { + return syscall.WaitStatus(0), err + } + // The sandbox may have exited after we connected, but before // or during the Wait RPC. log.Warningf("Wait RPC to container %q failed: %v. Will try waiting on the sandbox process instead.", cid, err) } - // The sandbox may have already exited, or exited while handling the - // Wait RPC. The best we can do is ask Linux what the sandbox exit - // status was, since in most cases that will be the same as the - // container exit status. + // The sandbox may have already exited, or exited while handling the Wait RPC. + // The best we can do is ask Linux what the sandbox exit status was, since in + // most cases that will be the same as the container exit status. if err := s.waitForStopped(); err != nil { - return ws, err + return syscall.WaitStatus(0), err } if !s.child { - return ws, fmt.Errorf("sandbox no longer running and its exit status is unavailable") + return syscall.WaitStatus(0), fmt.Errorf("sandbox no longer running and its exit status is unavailable") } + + s.statusMu.Lock() + defer s.statusMu.Unlock() return s.status, nil } diff --git a/test/benchmarks/base/sysbench_test.go b/test/benchmarks/base/sysbench_test.go index 6fb813640..39ced3dab 100644 --- a/test/benchmarks/base/sysbench_test.go +++ b/test/benchmarks/base/sysbench_test.go @@ -71,8 +71,15 @@ func BenchmarkSysbench(b *testing.B) { defer machine.CleanUp() for _, tc := range testCases { - b.Run(tc.name, func(b *testing.B) { - + param := tools.Parameter{ + Name: "testname", + Value: tc.name, + } + name, err := tools.ParametersToName(param) + if err != nil { + b.Fatalf("Failed to parse params: %v", err) + } + b.Run(name, func(b *testing.B) { ctx := context.Background() sysbench := machine.GetContainer(ctx, b) defer sysbench.CleanUp(ctx) diff --git a/test/benchmarks/database/redis_test.go b/test/benchmarks/database/redis_test.go index 6671a4969..02e67154e 100644 --- a/test/benchmarks/database/redis_test.go +++ b/test/benchmarks/database/redis_test.go @@ -66,7 +66,15 @@ func BenchmarkRedis(b *testing.B) { ctx := context.Background() for _, operation := range operations { - b.Run(operation, func(b *testing.B) { + param := tools.Parameter{ + Name: "operation", + Value: operation, + } + name, err := tools.ParametersToName(param) + if err != nil { + b.Fatalf("Failed to parse paramaters: %v", err) + } + b.Run(name, func(b *testing.B) { server := serverMachine.GetContainer(ctx, b) defer server.CleanUp(ctx) diff --git a/test/benchmarks/fs/bazel_test.go b/test/benchmarks/fs/bazel_test.go index ef1b8e4ea..56103639d 100644 --- a/test/benchmarks/fs/bazel_test.go +++ b/test/benchmarks/fs/bazel_test.go @@ -21,6 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/test/dockerutil" "gvisor.dev/gvisor/test/benchmarks/harness" + "gvisor.dev/gvisor/test/benchmarks/tools" ) // Note: CleanCache versions of this test require running with root permissions. @@ -46,17 +47,36 @@ func runBuildBenchmark(b *testing.B, image, workdir, target string) { // Dimensions here are clean/dirty cache (do or don't drop caches) // and if the mount on which we are compiling is a tmpfs/bind mount. benchmarks := []struct { - name string clearCache bool // clearCache drops caches before running. tmpfs bool // tmpfs will run compilation on a tmpfs. }{ - {name: "CleanCache", clearCache: true, tmpfs: false}, - {name: "DirtyCache", clearCache: false, tmpfs: false}, - {name: "CleanCacheTmpfs", clearCache: true, tmpfs: true}, - {name: "DirtyCacheTmpfs", clearCache: false, tmpfs: true}, + {clearCache: true, tmpfs: false}, + {clearCache: false, tmpfs: false}, + {clearCache: true, tmpfs: true}, + {clearCache: false, tmpfs: true}, } for _, bm := range benchmarks { - b.Run(bm.name, func(b *testing.B) { + pageCache := tools.Parameter{ + Name: "page_cache", + Value: "clean", + } + if bm.clearCache { + pageCache.Value = "dirty" + } + + filesystem := tools.Parameter{ + Name: "filesystem", + Value: "bind", + } + if bm.tmpfs { + filesystem.Value = "tmpfs" + } + name, err := tools.ParametersToName(pageCache, filesystem) + if err != nil { + b.Fatalf("Failed to parse parameters: %v", err) + } + + b.Run(name, func(b *testing.B) { // Grab a container. ctx := context.Background() container := machine.GetContainer(ctx, b) diff --git a/test/benchmarks/fs/fio_test.go b/test/benchmarks/fs/fio_test.go index 65874ed8b..5ca191404 100644 --- a/test/benchmarks/fs/fio_test.go +++ b/test/benchmarks/fs/fio_test.go @@ -67,8 +67,19 @@ func BenchmarkFio(b *testing.B) { for _, fsType := range []mount.Type{mount.TypeBind, mount.TypeTmpfs} { for _, tc := range testCases { - testName := strings.Title(tc.Test) + strings.Title(string(fsType)) - b.Run(testName, func(b *testing.B) { + operation := tools.Parameter{ + Name: "operation", + Value: tc.Test, + } + filesystem := tools.Parameter{ + Name: "filesystem", + Value: string(fsType), + } + name, err := tools.ParametersToName(operation, filesystem) + if err != nil { + b.Fatalf("Failed to parser paramters: %v", err) + } + b.Run(name, func(b *testing.B) { ctx := context.Background() container := machine.GetContainer(ctx, b) defer container.CleanUp(ctx) diff --git a/test/benchmarks/network/httpd_test.go b/test/benchmarks/network/httpd_test.go index 369ab326e..8d7d5f750 100644 --- a/test/benchmarks/network/httpd_test.go +++ b/test/benchmarks/network/httpd_test.go @@ -14,7 +14,7 @@ package network import ( - "fmt" + "strconv" "testing" "gvisor.dev/gvisor/pkg/test/dockerutil" @@ -31,33 +31,15 @@ var httpdDocs = map[string]string{ "10Mb": "latin10240k.txt", } -// BenchmarkHttpdConcurrency iterates the concurrency argument and tests -// how well the runtime under test handles requests in parallel. -func BenchmarkHttpdConcurrency(b *testing.B) { - // The test iterates over client concurrency, so set other parameters. - concurrency := []int{1, 25, 50, 100, 1000} - - for _, c := range concurrency { - b.Run(fmt.Sprintf("%d", c), func(b *testing.B) { - hey := &tools.Hey{ - Requests: c * b.N, - Concurrency: c, - Doc: httpdDocs["10Kb"], - } - runHttpd(b, hey, false /* reverse */) - }) - } -} - -// BenchmarkHttpdDocSize iterates over different sized payloads, testing how -// well the runtime handles sending different payload sizes. -func BenchmarkHttpdDocSize(b *testing.B) { +// BenchmarkHttpd iterates over different sized payloads and concurrency, testing +// how well the runtime handles sending different payload sizes. +func BenchmarkHttpd(b *testing.B) { benchmarkHttpdDocSize(b, false /* reverse */) } -// BenchmarkReverseHttpdDocSize iterates over different sized payloads, testing +// BenchmarkReverseHttpd iterates over different sized payloads, testing // how well the runtime handles receiving different payload sizes. -func BenchmarkReverseHttpdDocSize(b *testing.B) { +func BenchmarkReverseHttpd(b *testing.B) { benchmarkHttpdDocSize(b, true /* reverse */) } @@ -65,10 +47,22 @@ func BenchmarkReverseHttpdDocSize(b *testing.B) { // for each size. func benchmarkHttpdDocSize(b *testing.B, reverse bool) { b.Helper() - for name, filename := range httpdDocs { + for size, filename := range httpdDocs { concurrency := []int{1, 25, 50, 100, 1000} for _, c := range concurrency { - b.Run(fmt.Sprintf("%s_%d", name, c), func(b *testing.B) { + fsize := tools.Parameter{ + Name: "filesize", + Value: size, + } + concurrency := tools.Parameter{ + Name: "concurrency", + Value: strconv.Itoa(c), + } + name, err := tools.ParametersToName(fsize, concurrency) + if err != nil { + b.Fatalf("Failed to parse parameters: %v", err) + } + b.Run(name, func(b *testing.B) { hey := &tools.Hey{ Requests: c * b.N, Concurrency: c, diff --git a/test/benchmarks/network/nginx_test.go b/test/benchmarks/network/nginx_test.go index 9ec70369b..08565d0b2 100644 --- a/test/benchmarks/network/nginx_test.go +++ b/test/benchmarks/network/nginx_test.go @@ -14,7 +14,7 @@ package network import ( - "fmt" + "strconv" "testing" "gvisor.dev/gvisor/pkg/test/dockerutil" @@ -31,30 +31,6 @@ var nginxDocs = map[string]string{ "10Mb": "latin10240k.txt", } -// BenchmarkNginxConcurrency iterates the concurrency argument and tests -// how well the runtime under test handles requests in parallel. -func BenchmarkNginxConcurrency(b *testing.B) { - concurrency := []int{1, 25, 100, 1000} - for _, c := range concurrency { - for _, tmpfs := range []bool{true, false} { - fs := "Gofer" - if tmpfs { - fs = "Tmpfs" - } - name := fmt.Sprintf("%d_%s", c, fs) - b.Run(name, func(b *testing.B) { - hey := &tools.Hey{ - Requests: c * b.N, - Concurrency: c, - Doc: nginxDocs["10kb"], // see Dockerfile '//images/benchmarks/nginx' and httpd_test. - } - runNginx(b, hey, false /* reverse */, tmpfs /* tmpfs */) - }) - } - - } -} - // BenchmarkNginxDocSize iterates over different sized payloads, testing how // well the runtime handles sending different payload sizes. func BenchmarkNginxDocSize(b *testing.B) { @@ -71,15 +47,32 @@ func BenchmarkReverseNginxDocSize(b *testing.B) { // benchmarkNginxDocSize iterates through all doc sizes, running subbenchmarks // for each size. func benchmarkNginxDocSize(b *testing.B, reverse, tmpfs bool) { - for name, filename := range nginxDocs { + for size, filename := range nginxDocs { concurrency := []int{1, 25, 50, 100, 1000} for _, c := range concurrency { - fs := "Gofer" + fsize := tools.Parameter{ + Name: "filesize", + Value: size, + } + + threads := tools.Parameter{ + Name: "concurrency", + Value: strconv.Itoa(c), + } + + fs := tools.Parameter{ + Name: "filesystem", + Value: "bind", + } if tmpfs { - fs = "Tmpfs" + fs.Value = "tmpfs" + } + name, err := tools.ParametersToName(fsize, threads, fs) + if err != nil { + b.Fatalf("Failed to parse parameters: %v", err) } - benchName := fmt.Sprintf("%s_%d_%s", name, c, fs) - b.Run(benchName, func(b *testing.B) { + + b.Run(name, func(b *testing.B) { hey := &tools.Hey{ Requests: c * b.N, Concurrency: c, diff --git a/test/benchmarks/network/node_test.go b/test/benchmarks/network/node_test.go index 0f4a205b6..254538899 100644 --- a/test/benchmarks/network/node_test.go +++ b/test/benchmarks/network/node_test.go @@ -15,7 +15,7 @@ package network import ( "context" - "fmt" + "strconv" "testing" "time" @@ -31,7 +31,15 @@ import ( func BenchmarkNode(b *testing.B) { concurrency := []int{1, 5, 10, 25} for _, c := range concurrency { - b.Run(fmt.Sprintf("Concurrency%d", c), func(b *testing.B) { + param := tools.Parameter{ + Name: "concurrency", + Value: strconv.Itoa(c), + } + name, err := tools.ParametersToName(param) + if err != nil { + b.Fatalf("Failed to parse parameters: %v", err) + } + b.Run(name, func(b *testing.B) { hey := &tools.Hey{ Requests: b.N * c, // Requests b.N requests per thread. Concurrency: c, diff --git a/test/benchmarks/network/ruby_test.go b/test/benchmarks/network/ruby_test.go index 67f63f76a..0174ff3f3 100644 --- a/test/benchmarks/network/ruby_test.go +++ b/test/benchmarks/network/ruby_test.go @@ -16,6 +16,7 @@ package network import ( "context" "fmt" + "strconv" "testing" "time" @@ -31,7 +32,15 @@ import ( func BenchmarkRuby(b *testing.B) { concurrency := []int{1, 5, 10, 25} for _, c := range concurrency { - b.Run(fmt.Sprintf("Concurrency%d", c), func(b *testing.B) { + param := tools.Parameter{ + Name: "concurrency", + Value: strconv.Itoa(c), + } + name, err := tools.ParametersToName(param) + if err != nil { + b.Fatalf("Failed to parse parameters: %v", err) + } + b.Run(name, func(b *testing.B) { hey := &tools.Hey{ Requests: b.N * c, // b.N requests per thread. Concurrency: c, diff --git a/test/benchmarks/tools/BUILD b/test/benchmarks/tools/BUILD index e5734d85c..9290830d7 100644 --- a/test/benchmarks/tools/BUILD +++ b/test/benchmarks/tools/BUILD @@ -4,12 +4,14 @@ package(licenses = ["notice"]) go_library( name = "tools", + testonly = 1, srcs = [ "ab.go", "fio.go", "hey.go", "iperf.go", "meminfo.go", + "parser_util.go", "redis.go", "sysbench.go", "tools.go", diff --git a/test/benchmarks/tools/ab.go b/test/benchmarks/tools/ab.go index 4cc9c3bce..d9abf0763 100644 --- a/test/benchmarks/tools/ab.go +++ b/test/benchmarks/tools/ab.go @@ -46,18 +46,21 @@ func (a *ApacheBench) Report(b *testing.B, output string) { b.Logf("failed to parse transferrate: %v", err) } b.ReportMetric(transferRate*1024, "transfer_rate_b/s") // Convert from Kb/s to b/s. + ReportCustomMetric(b, transferRate*1024, "transfer_rate" /*metric name*/, "bytes_per_second" /*unit*/) latency, err := a.parseLatency(output) if err != nil { b.Logf("failed to parse latency: %v", err) } b.ReportMetric(latency/1000, "mean_latency_secs") // Convert from ms to s. + ReportCustomMetric(b, latency/1000, "mean_latency" /*metric name*/, "s" /*unit*/) reqPerSecond, err := a.parseRequestsPerSecond(output) if err != nil { b.Logf("failed to parse requests per second: %v", err) } b.ReportMetric(reqPerSecond, "requests_per_second") + ReportCustomMetric(b, reqPerSecond, "requests_per_second" /*metric name*/, "QPS" /*unit*/) } var transferRateRE = regexp.MustCompile(`Transfer rate:\s+(\d+\.?\d+?)\s+\[Kbytes/sec\]\s+received`) diff --git a/test/benchmarks/tools/fio.go b/test/benchmarks/tools/fio.go index 20000db16..f5f60fa84 100644 --- a/test/benchmarks/tools/fio.go +++ b/test/benchmarks/tools/fio.go @@ -56,13 +56,13 @@ func (f *Fio) Report(b *testing.B, output string) { if err != nil { b.Fatalf("failed to parse bandwidth from %s with: %v", output, err) } - b.ReportMetric(bw, "bandwidth_b/s") // in b/s. + ReportCustomMetric(b, bw, "bandwidth" /*metric name*/, "bytes_per_second" /*unit*/) iops, err := f.parseIOps(output, isRead) if err != nil { b.Fatalf("failed to parse iops from %s with: %v", output, err) } - b.ReportMetric(iops, "iops") + ReportCustomMetric(b, iops, "io_ops" /*metric name*/, "ops_per_second" /*unit*/) } // parseBandwidth reports the bandwidth in b/s. diff --git a/test/benchmarks/tools/hey.go b/test/benchmarks/tools/hey.go index b1e20e356..b8cb938fe 100644 --- a/test/benchmarks/tools/hey.go +++ b/test/benchmarks/tools/hey.go @@ -43,13 +43,13 @@ func (h *Hey) Report(b *testing.B, output string) { if err != nil { b.Fatalf("failed to parse requests per second: %v", err) } - b.ReportMetric(requests, "requests_per_second") + ReportCustomMetric(b, requests, "requests_per_second" /*metric name*/, "QPS" /*unit*/) ave, err := h.parseAverageLatency(output) if err != nil { b.Fatalf("failed to parse average latency: %v", err) } - b.ReportMetric(ave, "average_latency_secs") + ReportCustomMetric(b, ave, "average_latency" /*metric name*/, "s" /*unit*/) } var heyReqPerSecondRE = regexp.MustCompile(`Requests/sec:\s*(\d+\.?\d+?)\s+`) diff --git a/test/benchmarks/tools/iperf.go b/test/benchmarks/tools/iperf.go index df3d9349b..5c4e7125b 100644 --- a/test/benchmarks/tools/iperf.go +++ b/test/benchmarks/tools/iperf.go @@ -42,7 +42,7 @@ func (i *Iperf) Report(b *testing.B, output string) { if err != nil { b.Fatalf("failed to parse bandwitdth from %s: %v", output, err) } - b.ReportMetric(bW*1024, "bandwidth_b/s") // Convert from Kb/s to b/s. + ReportCustomMetric(b, bW*1024, "bandwidth" /*metric name*/, "bytes_per_second" /*unit*/) } // bandwidth parses the Bandwidth number from an iperf report. A sample is below. diff --git a/test/benchmarks/tools/meminfo.go b/test/benchmarks/tools/meminfo.go index 2414a96a7..b5786fe11 100644 --- a/test/benchmarks/tools/meminfo.go +++ b/test/benchmarks/tools/meminfo.go @@ -45,7 +45,7 @@ func (*Meminfo) Report(b *testing.B, before, after string) { b.Fatalf("could not parse before value %s: %v", before, err) } val := 1024 * ((beforeVal - afterVal) / float64(b.N)) - b.ReportMetric(val, "average_container_size_bytes") + ReportCustomMetric(b, val, "average_container_size" /*metric name*/, "bytes" /*units*/) } var memInfoRE = regexp.MustCompile(`MemAvailable:\s*(\d+)\skB\n`) diff --git a/test/benchmarks/tools/parser_util.go b/test/benchmarks/tools/parser_util.go new file mode 100644 index 000000000..a4555c7dd --- /dev/null +++ b/test/benchmarks/tools/parser_util.go @@ -0,0 +1,101 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tools + +import ( + "fmt" + "regexp" + "strconv" + "strings" + "testing" +) + +// Parameter is a test parameter. +type Parameter struct { + Name string + Value string +} + +// Output is parsed and split by these values. Make them illegal in input methods. +// We are constrained on what characters these can be by 1) docker's allowable +// container names, 2) golang allowable benchmark names, and 3) golangs allowable +// charecters in b.ReportMetric calls. +var illegalChars = regexp.MustCompile(`[/\.]`) + +// ParametersToName joins parameters into a string format for parsing. +// It is meant to be used for t.Run() calls in benchmark tools. +func ParametersToName(params ...Parameter) (string, error) { + var strs []string + for _, param := range params { + if illegalChars.MatchString(param.Name) || illegalChars.MatchString(param.Value) { + return "", fmt.Errorf("params Name: %q and Value: %q cannot container '.' or '/'", param.Name, param.Value) + } + strs = append(strs, strings.Join([]string{param.Name, param.Value}, ".")) + } + return strings.Join(strs, "/"), nil +} + +// NameToParameters parses the string created by ParametersToName and returns +// it as a set of Parameters. +// Example: BenchmarkRuby/server_threads.1/doc_size.16KB-6 +// The parameter part of this benchmark is: +// "server_threads.1/doc_size.16KB" (BenchmarkRuby is the name, and 6 is GOMAXPROCS) +// This function will return a slice with two parameters -> +// {Name: server_threads, Value: 1}, {Name: doc_size, Value: 16KB} +func NameToParameters(name string) ([]*Parameter, error) { + var params []*Parameter + for _, cond := range strings.Split(name, "/") { + cs := strings.Split(cond, ".") + switch len(cs) { + case 1: + params = append(params, &Parameter{Name: cond, Value: cond}) + case 2: + params = append(params, &Parameter{Name: cs[0], Value: cs[1]}) + default: + return nil, fmt.Errorf("failed to parse param: %s", cond) + } + } + return params, nil +} + +// ReportCustomMetric reports a metric in a set format for parsing. +func ReportCustomMetric(b *testing.B, value float64, name, unit string) { + if illegalChars.MatchString(name) || illegalChars.MatchString(unit) { + b.Fatalf("name: %q and unit: %q cannot contain '/' or '.'", name, unit) + } + nameUnit := strings.Join([]string{name, unit}, ".") + b.ReportMetric(value, nameUnit) +} + +// Metric holds metric data parsed from a string based on the format +// ReportMetric. +type Metric struct { + Name string + Unit string + Sample float64 +} + +// ParseCustomMetric parses a metric reported with ReportCustomMetric. +func ParseCustomMetric(value, metric string) (*Metric, error) { + sample, err := strconv.ParseFloat(value, 64) + if err != nil { + return nil, fmt.Errorf("failed to parse value: %v", err) + } + nameUnit := strings.Split(metric, ".") + if len(nameUnit) != 2 { + return nil, fmt.Errorf("failed to parse metric: %s", metric) + } + return &Metric{Name: nameUnit[0], Unit: nameUnit[1], Sample: sample}, nil +} diff --git a/test/benchmarks/tools/redis.go b/test/benchmarks/tools/redis.go index c899ae0d4..e35886437 100644 --- a/test/benchmarks/tools/redis.go +++ b/test/benchmarks/tools/redis.go @@ -49,7 +49,7 @@ func (r *Redis) Report(b *testing.B, output string) { if err != nil { b.Fatalf("parsing result %s failed with err: %v", output, err) } - b.ReportMetric(result, r.Operation) // operations per second + ReportCustomMetric(b, result, r.Operation /*metric_name*/, "QPS" /*unit*/) } // parseOperation grabs the metric operations per second from redis-benchmark output. diff --git a/test/benchmarks/tools/sysbench.go b/test/benchmarks/tools/sysbench.go index 6b2f75ca2..7ccacd8ff 100644 --- a/test/benchmarks/tools/sysbench.go +++ b/test/benchmarks/tools/sysbench.go @@ -80,7 +80,7 @@ func (s *SysbenchCPU) Report(b *testing.B, output string) { if err != nil { b.Fatalf("parsing CPU events from %s failed: %v", output, err) } - b.ReportMetric(result, "cpu_events_per_second") + ReportCustomMetric(b, result, "cpu_events" /*metric name*/, "events_per_second" /*unit*/) } var cpuEventsPerSecondRE = regexp.MustCompile(`events per second:\s*(\d*.?\d*)\n`) @@ -144,7 +144,7 @@ func (s *SysbenchMemory) Report(b *testing.B, output string) { if err != nil { b.Fatalf("parsing result %s failed with err: %v", output, err) } - b.ReportMetric(result, "operations_per_second") + ReportCustomMetric(b, result, "memory_operations" /*metric name*/, "ops_per_second" /*unit*/) } var memoryOperationsRE = regexp.MustCompile(`Total\soperations:\s+\d*\s*\((\d*\.\d*)\sper\ssecond\)`) @@ -198,19 +198,19 @@ func (s *SysbenchMutex) Report(b *testing.B, output string) { if err != nil { b.Fatalf("parsing result %s failed with err: %v", output, err) } - b.ReportMetric(result, "average_execution_time_secs") + ReportCustomMetric(b, result, "average_execution_time" /*metric name*/, "s" /*unit*/) result, err = s.parseDeviation(output) if err != nil { b.Fatalf("parsing result %s failed with err: %v", output, err) } - b.ReportMetric(result, "stdev_execution_time_secs") + ReportCustomMetric(b, result, "stddev_execution_time" /*metric name*/, "s" /*unit*/) result, err = s.parseLatency(output) if err != nil { b.Fatalf("parsing result %s failed with err: %v", output, err) } - b.ReportMetric(result/1000, "average_latency_secs") + ReportCustomMetric(b, result/1000, "average_latency" /*metric name*/, "s" /*unit*/) } var executionTimeRE = regexp.MustCompile(`execution time \(avg/stddev\):\s*(\d*.?\d*)/(\d*.?\d*)`) diff --git a/test/iptables/README.md b/test/iptables/README.md index b9f44bd40..28ab195ca 100644 --- a/test/iptables/README.md +++ b/test/iptables/README.md @@ -1,6 +1,6 @@ # iptables Tests -iptables tests are run via `scripts/iptables_test.sh`. +iptables tests are run via `make iptables-tests`. iptables requires raw socket support, so you must add the `--net-raw=true` flag to `/etc/docker/daemon.json` in order to use it. diff --git a/test/iptables/iptables_test.go b/test/iptables/iptables_test.go index 398f70ecd..834f7615f 100644 --- a/test/iptables/iptables_test.go +++ b/test/iptables/iptables_test.go @@ -48,13 +48,6 @@ func singleTest(t *testing.T, test TestCase) { } } -// TODO(gvisor.dev/issue/3549): IPv6 NAT support. -func ipv4Test(t *testing.T, test TestCase) { - t.Run("IPv4", func(t *testing.T) { - iptablesTest(t, test, false) - }) -} - func iptablesTest(t *testing.T, test TestCase, ipv6 bool) { if _, ok := Tests[test.Name()]; !ok { t.Fatalf("no test found with name %q. Has it been registered?", test.Name()) @@ -325,66 +318,66 @@ func TestFilterOutputInvertDestination(t *testing.T) { } func TestNATPreRedirectUDPPort(t *testing.T) { - ipv4Test(t, NATPreRedirectUDPPort{}) + singleTest(t, NATPreRedirectUDPPort{}) } func TestNATPreRedirectTCPPort(t *testing.T) { - ipv4Test(t, NATPreRedirectTCPPort{}) + singleTest(t, NATPreRedirectTCPPort{}) } func TestNATPreRedirectTCPOutgoing(t *testing.T) { - ipv4Test(t, NATPreRedirectTCPOutgoing{}) + singleTest(t, NATPreRedirectTCPOutgoing{}) } func TestNATOutRedirectTCPIncoming(t *testing.T) { - ipv4Test(t, NATOutRedirectTCPIncoming{}) + singleTest(t, NATOutRedirectTCPIncoming{}) } func TestNATOutRedirectUDPPort(t *testing.T) { - ipv4Test(t, NATOutRedirectUDPPort{}) + singleTest(t, NATOutRedirectUDPPort{}) } func TestNATOutRedirectTCPPort(t *testing.T) { - ipv4Test(t, NATOutRedirectTCPPort{}) + singleTest(t, NATOutRedirectTCPPort{}) } func TestNATDropUDP(t *testing.T) { - ipv4Test(t, NATDropUDP{}) + singleTest(t, NATDropUDP{}) } func TestNATAcceptAll(t *testing.T) { - ipv4Test(t, NATAcceptAll{}) + singleTest(t, NATAcceptAll{}) } func TestNATOutRedirectIP(t *testing.T) { - ipv4Test(t, NATOutRedirectIP{}) + singleTest(t, NATOutRedirectIP{}) } func TestNATOutDontRedirectIP(t *testing.T) { - ipv4Test(t, NATOutDontRedirectIP{}) + singleTest(t, NATOutDontRedirectIP{}) } func TestNATOutRedirectInvert(t *testing.T) { - ipv4Test(t, NATOutRedirectInvert{}) + singleTest(t, NATOutRedirectInvert{}) } func TestNATPreRedirectIP(t *testing.T) { - ipv4Test(t, NATPreRedirectIP{}) + singleTest(t, NATPreRedirectIP{}) } func TestNATPreDontRedirectIP(t *testing.T) { - ipv4Test(t, NATPreDontRedirectIP{}) + singleTest(t, NATPreDontRedirectIP{}) } func TestNATPreRedirectInvert(t *testing.T) { - ipv4Test(t, NATPreRedirectInvert{}) + singleTest(t, NATPreRedirectInvert{}) } func TestNATRedirectRequiresProtocol(t *testing.T) { - ipv4Test(t, NATRedirectRequiresProtocol{}) + singleTest(t, NATRedirectRequiresProtocol{}) } func TestNATLoopbackSkipsPrerouting(t *testing.T) { - ipv4Test(t, NATLoopbackSkipsPrerouting{}) + singleTest(t, NATLoopbackSkipsPrerouting{}) } func TestInputSource(t *testing.T) { @@ -421,9 +414,9 @@ func TestFilterAddrs(t *testing.T) { } func TestNATPreOriginalDst(t *testing.T) { - ipv4Test(t, NATPreOriginalDst{}) + singleTest(t, NATPreOriginalDst{}) } func TestNATOutOriginalDst(t *testing.T) { - ipv4Test(t, NATOutOriginalDst{}) + singleTest(t, NATOutOriginalDst{}) } diff --git a/test/kubernetes/gvisor-injection-admission-webhook.yaml b/test/kubernetes/gvisor-injection-admission-webhook.yaml new file mode 100644 index 000000000..691f02dda --- /dev/null +++ b/test/kubernetes/gvisor-injection-admission-webhook.yaml @@ -0,0 +1,89 @@ +# Copyright 2020 The gVisor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +--- +apiVersion: v1 +kind: Namespace +metadata: + name: e2e + labels: + name: e2e +--- +apiVersion: v1 +kind: ServiceAccount +metadata: + name: gvisor-injection-admission-webhook + namespace: e2e +--- +apiVersion: rbac.authorization.k8s.io/v1 +kind: ClusterRole +metadata: + name: gvisor-injection-admission-webhook +rules: +- apiGroups: [ admissionregistration.k8s.io ] + resources: [ mutatingwebhookconfigurations ] + verbs: [ create ] +--- +apiVersion: rbac.authorization.k8s.io/v1 +kind: ClusterRoleBinding +metadata: + name: gvisor-injection-admission-webhook + namespace: e2e +roleRef: + apiGroup: rbac.authorization.k8s.io + kind: ClusterRole + name: gvisor-injection-admission-webhook +subjects: +- kind: ServiceAccount + name: gvisor-injection-admission-webhook + namespace: e2e +--- +apiVersion: apps/v1 +kind: Deployment +metadata: + name: gvisor-injection-admission-webhook + namespace: e2e + labels: + app: gvisor-injection-admission-webhook +spec: + replicas: 1 + selector: + matchLabels: + app: gvisor-injection-admission-webhook + template: + metadata: + labels: + app: gvisor-injection-admission-webhook + spec: + containers: + - name: webhook + image: gcr.io/gke-gvisor/gvisor-injection-admission-webhook:54ce9bd + args: + - --log-level=debug + ports: + - containerPort: 8443 + serviceAccountName: gvisor-injection-admission-webhook +--- +kind: Service +apiVersion: v1 +metadata: + name: gvisor-injection-admission-webhook + namespace: e2e +spec: + selector: + app: gvisor-injection-admission-webhook + ports: + - protocol: TCP + port: 443 + targetPort: 8443 diff --git a/test/packetimpact/testbench/testbench.go b/test/packetimpact/testbench/testbench.go index 0073a1361..3c85ebbee 100644 --- a/test/packetimpact/testbench/testbench.go +++ b/test/packetimpact/testbench/testbench.go @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +// Package testbench is the packetimpact test API. package testbench import ( diff --git a/test/packetimpact/tests/BUILD b/test/packetimpact/tests/BUILD index 94731c64b..11db49e39 100644 --- a/test/packetimpact/tests/BUILD +++ b/test/packetimpact/tests/BUILD @@ -310,8 +310,6 @@ packetimpact_go_test( packetimpact_go_test( name = "ipv6_fragment_reassembly", srcs = ["ipv6_fragment_reassembly_test.go"], - # TODO(b/160919104): Fix netstack then remove the line below. - expect_netstack_failure = True, deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/test/packetimpact/tests/ipv4_id_uniqueness_test.go b/test/packetimpact/tests/ipv4_id_uniqueness_test.go index cf881418c..7f7a768d3 100644 --- a/test/packetimpact/tests/ipv4_id_uniqueness_test.go +++ b/test/packetimpact/tests/ipv4_id_uniqueness_test.go @@ -88,7 +88,8 @@ func TestIPv4RetransmitIdentificationUniqueness(t *testing.T) { // this test. Once the socket option is supported, the following call // can be changed to simply assert success. ret, errno := dut.SetSockOptIntWithErrno(context.Background(), t, remoteFD, unix.IPPROTO_IP, linux.IP_MTU_DISCOVER, linux.IP_PMTUDISC_DONT) - if ret == -1 && errno != unix.ENOTSUP { + // Fuchsia will return ENOPROTOPT errno. + if ret == -1 && errno != unix.ENOPROTOOPT { t.Fatalf("failed to set IP_MTU_DISCOVER socket option to IP_PMTUDISC_DONT: %s", errno) } diff --git a/test/runner/defs.bzl b/test/runner/defs.bzl index 032ebd04e..248053dc3 100644 --- a/test/runner/defs.bzl +++ b/test/runner/defs.bzl @@ -203,7 +203,6 @@ def syscall_test( tags = platform_tags + tags, ) - # TODO(gvisor.dev/issue/1487): Enable VFS2 overlay tests. if add_overlay: _syscall_test( test = test, @@ -216,6 +215,23 @@ def syscall_test( overlay = True, ) + # TODO(gvisor.dev/issue/4407): Remove tags to enable VFS2 overlay tests. + overlay_vfs2_tags = list(vfs2_tags) + overlay_vfs2_tags.append("manual") + overlay_vfs2_tags.append("noguitar") + overlay_vfs2_tags.append("notap") + _syscall_test( + test = test, + shard_count = shard_count, + size = size, + platform = default_platform, + use_tmpfs = use_tmpfs, + add_uds_tree = add_uds_tree, + tags = platforms[default_platform] + overlay_vfs2_tags, + overlay = True, + vfs2 = True, + ) + if add_hostinet: _syscall_test( test = test, diff --git a/test/runtimes/exclude/java11.csv b/test/runtimes/exclude/java11.csv index f779df8d5..d978baca7 100644 --- a/test/runtimes/exclude/java11.csv +++ b/test/runtimes/exclude/java11.csv @@ -57,6 +57,7 @@ java/nio/channels/SocketChannel/SocketOptionTests.java,b/77965901, java/nio/channels/spi/SelectorProvider/inheritedChannel/InheritedChannelTest.java,,Fails in Docker java/rmi/activation/Activatable/extLoadedImpl/ext.sh,, java/rmi/transport/checkLeaseInfoLeak/CheckLeaseLeak.java,, +java/security/cert/PolicyNode/GetPolicyQualifiers.java,b/170263154,Kokoro executor cert expired java/text/Format/NumberFormat/CurrencyFormat.java,,Fails in Docker java/text/Format/NumberFormat/CurrencyFormat.java,,Fails in Docker java/util/Calendar/JapaneseEraNameTest.java,, diff --git a/test/runtimes/exclude/nodejs12.4.0.csv b/test/runtimes/exclude/nodejs12.4.0.csv index 749fb9482..ba993814f 100644 --- a/test/runtimes/exclude/nodejs12.4.0.csv +++ b/test/runtimes/exclude/nodejs12.4.0.csv @@ -13,7 +13,7 @@ parallel/test-dns-channel-timeout.js,b/161893056, parallel/test-fs-access.js,, parallel/test-fs-watchfile.js,,Flaky - File already exists error parallel/test-fs-write-stream.js,b/166819807,Flaky -parallel/test-fs-write-stream-double-close,b/166819807,Flaky +parallel/test-fs-write-stream-double-close.js,b/166819807,Flaky parallel/test-fs-write-stream-throw-type-error.js,b/166819807,Flaky parallel/test-http-writable-true-after-close.js,,Flaky - Mismatched <anonymous> function calls. Expected exactly 1 actual 2 parallel/test-os.js,b/63997097, diff --git a/test/runtimes/proctor/BUILD b/test/runtimes/proctor/BUILD index f66371265..fdc6d3173 100644 --- a/test/runtimes/proctor/BUILD +++ b/test/runtimes/proctor/BUILD @@ -5,6 +5,7 @@ package(licenses = ["notice"]) go_binary( name = "proctor", srcs = ["main.go"], + pure = True, visibility = ["//test/runtimes:__pkg__"], deps = ["//test/runtimes/proctor/lib"], ) diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD index 96a775456..f66a9ceb4 100644 --- a/test/syscalls/BUILD +++ b/test/syscalls/BUILD @@ -286,6 +286,10 @@ syscall_test( ) syscall_test( + test = "//test/syscalls/linux:membarrier_test", +) + +syscall_test( test = "//test/syscalls/linux:memory_accounting_test", ) diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index d9dbe2267..36b7f1b97 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -1079,6 +1079,7 @@ cc_binary( gtest, "//test/util:test_main", "//test/util:test_util", + "//test/util:thread_util", ], ) @@ -1155,6 +1156,24 @@ cc_binary( ) cc_binary( + name = "membarrier_test", + testonly = 1, + srcs = ["membarrier.cc"], + linkstatic = 1, + deps = [ + "@com_google_absl//absl/time", + gtest, + "//test/util:cleanup", + "//test/util:logging", + "//test/util:memory_util", + "//test/util:posix_error", + "//test/util:test_main", + "//test/util:test_util", + "//test/util:thread_util", + ], +) + +cc_binary( name = "mempolicy_test", testonly = 1, srcs = ["mempolicy.cc"], diff --git a/test/syscalls/linux/ip6tables.cc b/test/syscalls/linux/ip6tables.cc index f08f2dc55..e0e146067 100644 --- a/test/syscalls/linux/ip6tables.cc +++ b/test/syscalls/linux/ip6tables.cc @@ -89,22 +89,16 @@ TEST(IP6TablesBasic, GetRevision) { ASSERT_THAT(sock = socket(AF_INET6, SOCK_RAW, IPPROTO_RAW), SyscallSucceeds()); - struct xt_get_revision rev = { - .name = "REDIRECT", - .revision = 0, - }; + struct xt_get_revision rev = {}; socklen_t rev_len = sizeof(rev); - // TODO(gvisor.dev/issue/3549): IPv6 redirect support. - const int retval = - getsockopt(sock, SOL_IPV6, IP6T_SO_GET_REVISION_TARGET, &rev, &rev_len); - if (IsRunningOnGvisor()) { - EXPECT_THAT(retval, SyscallFailsWithErrno(ENOPROTOOPT)); - return; - } + snprintf(rev.name, sizeof(rev.name), "REDIRECT"); + rev.revision = 0; // Revision 0 exists. - EXPECT_THAT(retval, SyscallSucceeds()); + EXPECT_THAT( + getsockopt(sock, SOL_IPV6, IP6T_SO_GET_REVISION_TARGET, &rev, &rev_len), + SyscallSucceeds()); EXPECT_EQ(rev.revision, 0); // Revisions > 0 don't exist. diff --git a/test/syscalls/linux/iptables.cc b/test/syscalls/linux/iptables.cc index 7ee10bbde..22550b800 100644 --- a/test/syscalls/linux/iptables.cc +++ b/test/syscalls/linux/iptables.cc @@ -124,12 +124,12 @@ TEST(IPTablesBasic, GetRevision) { ASSERT_THAT(sock = socket(AF_INET, SOCK_RAW, IPPROTO_ICMP), SyscallSucceeds()); - struct xt_get_revision rev = { - .name = "REDIRECT", - .revision = 0, - }; + struct xt_get_revision rev = {}; socklen_t rev_len = sizeof(rev); + snprintf(rev.name, sizeof(rev.name), "REDIRECT"); + rev.revision = 0; + // Revision 0 exists. EXPECT_THAT( getsockopt(sock, SOL_IP, IPT_SO_GET_REVISION_TARGET, &rev, &rev_len), diff --git a/test/syscalls/linux/kcov.cc b/test/syscalls/linux/kcov.cc index 6afcb4e75..6816c1fd0 100644 --- a/test/syscalls/linux/kcov.cc +++ b/test/syscalls/linux/kcov.cc @@ -16,39 +16,47 @@ #include <sys/ioctl.h> #include <sys/mman.h> +#include <atomic> + #include "gtest/gtest.h" #include "test/util/capability_util.h" #include "test/util/file_descriptor.h" #include "test/util/test_util.h" +#include "test/util/thread_util.h" namespace gvisor { namespace testing { namespace { -// For this test to work properly, it must be run with coverage enabled. On +// For this set of tests to run, they must be run with coverage enabled. On // native Linux, this involves compiling the kernel with kcov enabled. For -// gVisor, we need to enable the Go coverage tool, e.g. -// bazel test --collect_coverage_data --instrumentation_filter=//pkg/... <test>. +// gVisor, we need to enable the Go coverage tool, e.g. bazel test -- +// collect_coverage_data --instrumentation_filter=//pkg/... <test>. + +constexpr char kcovPath[] = "/sys/kernel/debug/kcov"; +constexpr int kSize = 4096; +constexpr int KCOV_INIT_TRACE = 0x80086301; +constexpr int KCOV_ENABLE = 0x6364; +constexpr int KCOV_DISABLE = 0x6365; + +uint64_t* KcovMmap(int fd) { + return (uint64_t*)mmap(nullptr, kSize * sizeof(uint64_t), + PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); +} + TEST(KcovTest, Kcov) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability((CAP_DAC_OVERRIDE)))); - constexpr int kSize = 4096; - constexpr int KCOV_INIT_TRACE = 0x80086301; - constexpr int KCOV_ENABLE = 0x6364; - constexpr int KCOV_DISABLE = 0x6365; - int fd; - ASSERT_THAT(fd = open("/sys/kernel/debug/kcov", O_RDWR), + ASSERT_THAT(fd = open(kcovPath, O_RDWR), AnyOf(SyscallSucceeds(), SyscallFailsWithErrno(ENOENT))); - // Kcov not available. SKIP_IF(errno == ENOENT); + auto fd_closer = Cleanup([fd]() { close(fd); }); ASSERT_THAT(ioctl(fd, KCOV_INIT_TRACE, kSize), SyscallSucceeds()); - uint64_t* area = (uint64_t*)mmap(nullptr, kSize * sizeof(uint64_t), - PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); - + uint64_t* area = KcovMmap(fd); ASSERT_TRUE(area != MAP_FAILED); ASSERT_THAT(ioctl(fd, KCOV_ENABLE, 0), SyscallSucceeds()); @@ -67,6 +75,109 @@ TEST(KcovTest, Kcov) { ASSERT_THAT(ioctl(fd, KCOV_DISABLE, 0), SyscallSucceeds()); } +TEST(KcovTest, PrematureMmap) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability((CAP_DAC_OVERRIDE)))); + + int fd; + ASSERT_THAT(fd = open(kcovPath, O_RDWR), + AnyOf(SyscallSucceeds(), SyscallFailsWithErrno(ENOENT))); + // Kcov not available. + SKIP_IF(errno == ENOENT); + auto fd_closer = Cleanup([fd]() { close(fd); }); + + // Cannot mmap before KCOV_INIT_TRACE. + uint64_t* area = KcovMmap(fd); + ASSERT_TRUE(area == MAP_FAILED); +} + +// Tests that multiple kcov fds can be used simultaneously. +TEST(KcovTest, MultipleFds) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability((CAP_DAC_OVERRIDE)))); + + int fd1; + ASSERT_THAT(fd1 = open(kcovPath, O_RDWR), + AnyOf(SyscallSucceeds(), SyscallFailsWithErrno(ENOENT))); + // Kcov not available. + SKIP_IF(errno == ENOENT); + + int fd2; + ASSERT_THAT(fd2 = open(kcovPath, O_RDWR), SyscallSucceeds()); + auto fd_closer = Cleanup([fd1, fd2]() { + close(fd1); + close(fd2); + }); + + auto t1 = ScopedThread([&] { + ASSERT_THAT(ioctl(fd1, KCOV_INIT_TRACE, kSize), SyscallSucceeds()); + uint64_t* area = KcovMmap(fd1); + ASSERT_TRUE(area != MAP_FAILED); + ASSERT_THAT(ioctl(fd1, KCOV_ENABLE, 0), SyscallSucceeds()); + }); + + ASSERT_THAT(ioctl(fd2, KCOV_INIT_TRACE, kSize), SyscallSucceeds()); + uint64_t* area = KcovMmap(fd2); + ASSERT_TRUE(area != MAP_FAILED); + ASSERT_THAT(ioctl(fd2, KCOV_ENABLE, 0), SyscallSucceeds()); +} + +// Tests behavior for two threads trying to use the same kcov fd. +TEST(KcovTest, MultipleThreads) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability((CAP_DAC_OVERRIDE)))); + + int fd; + ASSERT_THAT(fd = open(kcovPath, O_RDWR), + AnyOf(SyscallSucceeds(), SyscallFailsWithErrno(ENOENT))); + // Kcov not available. + SKIP_IF(errno == ENOENT); + auto fd_closer = Cleanup([fd]() { close(fd); }); + + // Test the behavior of multiple threads trying to use the same kcov fd + // simultaneously. + std::atomic<bool> t1_enabled(false), t1_disabled(false), t2_failed(false), + t2_exited(false); + auto t1 = ScopedThread([&] { + ASSERT_THAT(ioctl(fd, KCOV_INIT_TRACE, kSize), SyscallSucceeds()); + uint64_t* area = KcovMmap(fd); + ASSERT_TRUE(area != MAP_FAILED); + ASSERT_THAT(ioctl(fd, KCOV_ENABLE, 0), SyscallSucceeds()); + t1_enabled = true; + + // After t2 has made sure that enabling kcov again fails, disable it. + while (!t2_failed) { + sched_yield(); + } + ASSERT_THAT(ioctl(fd, KCOV_DISABLE, 0), SyscallSucceeds()); + t1_disabled = true; + + // Wait for t2 to enable kcov and then exit, after which we should be able + // to enable kcov again, without needing to set up a new memory mapping. + while (!t2_exited) { + sched_yield(); + } + ASSERT_THAT(ioctl(fd, KCOV_ENABLE, 0), SyscallSucceeds()); + }); + + auto t2 = ScopedThread([&] { + // Wait for t1 to enable kcov, and make sure that enabling kcov again fails. + while (!t1_enabled) { + sched_yield(); + } + ASSERT_THAT(ioctl(fd, KCOV_ENABLE, 0), SyscallFailsWithErrno(EINVAL)); + t2_failed = true; + + // Wait for t1 to disable kcov, after which using fd should now succeed. + while (!t1_disabled) { + sched_yield(); + } + uint64_t* area = KcovMmap(fd); + ASSERT_TRUE(area != MAP_FAILED); + ASSERT_THAT(ioctl(fd, KCOV_ENABLE, 0), SyscallSucceeds()); + }); + + t2.Join(); + t2_exited = true; +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/membarrier.cc b/test/syscalls/linux/membarrier.cc new file mode 100644 index 000000000..516956a25 --- /dev/null +++ b/test/syscalls/linux/membarrier.cc @@ -0,0 +1,268 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <errno.h> +#include <signal.h> +#include <sys/syscall.h> +#include <sys/types.h> +#include <unistd.h> + +#include <atomic> + +#include "absl/time/clock.h" +#include "absl/time/time.h" +#include "test/util/cleanup.h" +#include "test/util/logging.h" +#include "test/util/memory_util.h" +#include "test/util/posix_error.h" +#include "test/util/test_util.h" +#include "test/util/thread_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +// This is the classic test case for memory fences on architectures with total +// store ordering; see e.g. Intel SDM Vol. 3A Sec. 8.2.3.4 "Loads May Be +// Reordered with Earlier Stores to Different Locations". In each iteration of +// the test, given two variables X and Y initially set to 0 +// (MembarrierTestSharedState::local_var and remote_var in the code), two +// threads execute as follows: +// +// T1 T2 +// -- -- +// +// X = 1 Y = 1 +// T1fence() T2fence() +// read Y read X +// +// On architectures where memory writes may be locally buffered by each CPU +// (essentially all architectures), if T1fence() and T2fence() are omitted or +// ineffective, it is possible for both T1 and T2 to read 0 because the memory +// write from the other CPU is not yet visible outside that CPU. T1fence() and +// T2fence() are expected to perform the necessary synchronization to restore +// sequential consistency: both threads agree on a order of memory accesses that +// is consistent with program order in each thread, such that at least one +// thread reads 1. +// +// In the NoMembarrier test, T1fence() and T2fence() are both ordinary memory +// fences establishing ordering between memory accesses before and after the +// fence (std::atomic_thread_fence). In all other test cases, T1fence() is not a +// memory fence at all, but only prevents compiler reordering of memory accesses +// (std::atomic_signal_fence); T2fence() is an invocation of the membarrier() +// syscall, which establishes ordering of memory accesses before and after the +// syscall on both threads. + +template <typename F> +int DoMembarrierTestSide(std::atomic<int>* our_var, + std::atomic<int> const& their_var, + F const& test_fence) { + our_var->store(1, std::memory_order_relaxed); + test_fence(); + return their_var.load(std::memory_order_relaxed); +} + +struct MembarrierTestSharedState { + std::atomic<int64_t> remote_iter_cur; + std::atomic<int64_t> remote_iter_done; + std::atomic<int> local_var; + std::atomic<int> remote_var; + int remote_obs_of_local_var; + + void Init() { + remote_iter_cur.store(-1, std::memory_order_relaxed); + remote_iter_done.store(-1, std::memory_order_relaxed); + } +}; + +// Special value for MembarrierTestSharedState::remote_iter_cur indicating that +// the remote thread should terminate. +constexpr int64_t kRemoteIterStop = -2; + +// Must be async-signal-safe. +template <typename F> +void RunMembarrierTestRemoteSide(MembarrierTestSharedState* state, + F const& test_fence) { + int64_t i = 0; + int64_t cur; + while (true) { + while ((cur = state->remote_iter_cur.load(std::memory_order_acquire)) < i) { + if (cur == kRemoteIterStop) { + return; + } + // spin + } + state->remote_obs_of_local_var = + DoMembarrierTestSide(&state->remote_var, state->local_var, test_fence); + state->remote_iter_done.store(i, std::memory_order_release); + i++; + } +} + +template <typename F> +void RunMembarrierTestLocalSide(MembarrierTestSharedState* state, + F const& test_fence) { + // On test completion, instruct the remote thread to terminate. + Cleanup cleanup_remote([&] { + state->remote_iter_cur.store(kRemoteIterStop, std::memory_order_relaxed); + }); + + int64_t i = 0; + absl::Time end = absl::Now() + absl::Seconds(5); // arbitrary test duration + while (absl::Now() < end) { + // Reset both vars to 0. + state->local_var.store(0, std::memory_order_relaxed); + state->remote_var.store(0, std::memory_order_relaxed); + // Instruct the remote thread to begin this iteration. + state->remote_iter_cur.store(i, std::memory_order_release); + // Perform our side of the test. + auto local_obs_of_remote_var = + DoMembarrierTestSide(&state->local_var, state->remote_var, test_fence); + // Wait for the remote thread to finish this iteration. + while (state->remote_iter_done.load(std::memory_order_acquire) < i) { + // spin + } + ASSERT_TRUE(local_obs_of_remote_var != 0 || + state->remote_obs_of_local_var != 0); + i++; + } +} + +TEST(MembarrierTest, NoMembarrier) { + MembarrierTestSharedState state; + state.Init(); + + ScopedThread remote_thread([&] { + RunMembarrierTestRemoteSide( + &state, [] { std::atomic_thread_fence(std::memory_order_seq_cst); }); + }); + RunMembarrierTestLocalSide( + &state, [] { std::atomic_thread_fence(std::memory_order_seq_cst); }); +} + +enum membarrier_cmd { + MEMBARRIER_CMD_QUERY = 0, + MEMBARRIER_CMD_GLOBAL = (1 << 0), + MEMBARRIER_CMD_GLOBAL_EXPEDITED = (1 << 1), + MEMBARRIER_CMD_REGISTER_GLOBAL_EXPEDITED = (1 << 2), + MEMBARRIER_CMD_PRIVATE_EXPEDITED = (1 << 3), + MEMBARRIER_CMD_REGISTER_PRIVATE_EXPEDITED = (1 << 4), +}; + +int membarrier(membarrier_cmd cmd, int flags) { + return syscall(SYS_membarrier, cmd, flags); +} + +PosixErrorOr<int> SupportedMembarrierCommands() { + int cmds = membarrier(MEMBARRIER_CMD_QUERY, 0); + if (cmds < 0) { + if (errno == ENOSYS) { + // No commands are supported. + return 0; + } + return PosixError(errno, "membarrier(MEMBARRIER_CMD_QUERY) failed"); + } + return cmds; +} + +TEST(MembarrierTest, Global) { + SKIP_IF((ASSERT_NO_ERRNO_AND_VALUE(SupportedMembarrierCommands()) & + MEMBARRIER_CMD_GLOBAL) == 0); + + Mapping m = ASSERT_NO_ERRNO_AND_VALUE( + MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_SHARED)); + auto state = static_cast<MembarrierTestSharedState*>(m.ptr()); + state->Init(); + + pid_t const child_pid = fork(); + if (child_pid == 0) { + // In child process. + RunMembarrierTestRemoteSide( + state, [] { TEST_PCHECK(membarrier(MEMBARRIER_CMD_GLOBAL, 0) == 0); }); + _exit(0); + } + // In parent process. + ASSERT_THAT(child_pid, SyscallSucceeds()); + Cleanup cleanup_child([&] { + int status; + ASSERT_THAT(waitpid(child_pid, &status, 0), + SyscallSucceedsWithValue(child_pid)); + EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0) + << " status " << status; + }); + RunMembarrierTestLocalSide( + state, [] { std::atomic_signal_fence(std::memory_order_seq_cst); }); +} + +TEST(MembarrierTest, GlobalExpedited) { + constexpr int kRequiredCommands = MEMBARRIER_CMD_GLOBAL_EXPEDITED | + MEMBARRIER_CMD_REGISTER_GLOBAL_EXPEDITED; + SKIP_IF((ASSERT_NO_ERRNO_AND_VALUE(SupportedMembarrierCommands()) & + kRequiredCommands) != kRequiredCommands); + + ASSERT_THAT(membarrier(MEMBARRIER_CMD_REGISTER_GLOBAL_EXPEDITED, 0), + SyscallSucceeds()); + + Mapping m = ASSERT_NO_ERRNO_AND_VALUE( + MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_SHARED)); + auto state = static_cast<MembarrierTestSharedState*>(m.ptr()); + state->Init(); + + pid_t const child_pid = fork(); + if (child_pid == 0) { + // In child process. + RunMembarrierTestRemoteSide(state, [] { + TEST_PCHECK(membarrier(MEMBARRIER_CMD_GLOBAL_EXPEDITED, 0) == 0); + }); + _exit(0); + } + // In parent process. + ASSERT_THAT(child_pid, SyscallSucceeds()); + Cleanup cleanup_child([&] { + int status; + ASSERT_THAT(waitpid(child_pid, &status, 0), + SyscallSucceedsWithValue(child_pid)); + EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0) + << " status " << status; + }); + RunMembarrierTestLocalSide( + state, [] { std::atomic_signal_fence(std::memory_order_seq_cst); }); +} + +TEST(MembarrierTest, PrivateExpedited) { + constexpr int kRequiredCommands = MEMBARRIER_CMD_PRIVATE_EXPEDITED | + MEMBARRIER_CMD_REGISTER_PRIVATE_EXPEDITED; + SKIP_IF((ASSERT_NO_ERRNO_AND_VALUE(SupportedMembarrierCommands()) & + kRequiredCommands) != kRequiredCommands); + + ASSERT_THAT(membarrier(MEMBARRIER_CMD_REGISTER_PRIVATE_EXPEDITED, 0), + SyscallSucceeds()); + + MembarrierTestSharedState state; + state.Init(); + + ScopedThread remote_thread([&] { + RunMembarrierTestRemoteSide(&state, [] { + TEST_PCHECK(membarrier(MEMBARRIER_CMD_PRIVATE_EXPEDITED, 0) == 0); + }); + }); + RunMembarrierTestLocalSide( + &state, [] { std::atomic_signal_fence(std::memory_order_seq_cst); }); +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/tcp_socket.cc b/test/syscalls/linux/tcp_socket.cc index e0981e28a..9f522f833 100644 --- a/test/syscalls/linux/tcp_socket.cc +++ b/test/syscalls/linux/tcp_socket.cc @@ -903,6 +903,58 @@ TEST_P(SimpleTcpSocketTest, NonBlockingConnectNoListener) { EXPECT_EQ(err, ECONNREFUSED); } +TEST_P(SimpleTcpSocketTest, SelfConnectSendRecv_NoRandomSave) { + // Initialize address to the loopback one. + sockaddr_storage addr = + ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam())); + socklen_t addrlen = sizeof(addr); + + const FileDescriptor s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + + ASSERT_THAT( + (bind)(s.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen), + SyscallSucceeds()); + // Get the bound port. + ASSERT_THAT( + getsockname(s.get(), reinterpret_cast<struct sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + ASSERT_THAT(RetryEINTR(connect)( + s.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen), + SyscallSucceeds()); + + constexpr int kBufSz = 1 << 20; // 1 MiB + std::vector<char> writebuf(kBufSz); + + // Start reading the response in a loop. + int read_bytes = 0; + ScopedThread t([&s, &read_bytes]() { + // Too many syscalls. + const DisableSave ds; + + char readbuf[2500] = {}; + int n = -1; + while (n != 0) { + ASSERT_THAT(n = RetryEINTR(read)(s.get(), &readbuf, sizeof(readbuf)), + SyscallSucceeds()); + read_bytes += n; + } + }); + + // Try to send the whole thing. + int n; + ASSERT_THAT(n = SendFd(s.get(), writebuf.data(), kBufSz, 0), + SyscallSucceeds()); + + // We should have written the whole thing. + EXPECT_EQ(n, kBufSz); + EXPECT_THAT(shutdown(s.get(), SHUT_WR), SyscallSucceedsWithValue(0)); + t.Join(); + + // We should have read the whole thing. + EXPECT_EQ(read_bytes, kBufSz); +} + TEST_P(SimpleTcpSocketTest, NonBlockingConnect) { const FileDescriptor listener = ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); diff --git a/test/util/BUILD b/test/util/BUILD index fc5fb3a8d..26c2b6a2f 100644 --- a/test/util/BUILD +++ b/test/util/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "cc_library", "cc_test", "gbenchmark", "gtest", "select_system") +load("//tools:defs.bzl", "cc_library", "cc_test", "coreutil", "gbenchmark", "gtest", "select_system") package( default_visibility = ["//:sandbox"], @@ -254,7 +254,7 @@ cc_library( ], hdrs = ["test_util.h"], defines = select_system(), - deps = [ + deps = coreutil() + [ ":fs_util", ":logging", ":posix_error", diff --git a/tools/BUILD b/tools/BUILD index da83877b1..faf310676 100644 --- a/tools/BUILD +++ b/tools/BUILD @@ -5,5 +5,7 @@ package(licenses = ["notice"]) bzl_library( name = "defs_bzl", srcs = ["defs.bzl"], - visibility = ["//visibility:private"], + visibility = [ + "//:sandbox", + ], ) diff --git a/tools/bazel.mk b/tools/bazel.mk index 5e129b2ed..25575c02c 100644 --- a/tools/bazel.mk +++ b/tools/bazel.mk @@ -152,8 +152,9 @@ build_cmd = docker exec $(FULL_DOCKER_EXEC_OPTIONS) $(DOCKER_NAME) sh -o pipefai build_paths = $(build_cmd) 2>&1 \ | tee /proc/self/fd/2 \ - | grep -E "^ bazel-bin/" \ - | tr -d '\r' \ + | grep " bazel-bin/" \ + | sed "s/ /\n/g" \ + | strings -n 10 \ | awk '{$$1=$$1};1' \ | xargs -n 1 -I {} sh -c "$(1)" diff --git a/tools/bazeldefs/defs.bzl b/tools/bazeldefs/defs.bzl index dad5fc3b2..cf5b1dc0d 100644 --- a/tools/bazeldefs/defs.bzl +++ b/tools/bazeldefs/defs.bzl @@ -191,3 +191,6 @@ def default_installer(): def default_net_util(): return [] # Nothing needed. + +def coreutil(): + return [] # Nothing needed. diff --git a/tools/bigquery/BUILD b/tools/bigquery/BUILD index 5748fb390..2b0062a63 100644 --- a/tools/bigquery/BUILD +++ b/tools/bigquery/BUILD @@ -6,5 +6,8 @@ go_library( name = "bigquery", testonly = 1, srcs = ["bigquery.go"], + visibility = [ + "//:sandbox", + ], deps = ["@com_google_cloud_go_bigquery//:go_default_library"], ) diff --git a/tools/bigquery/bigquery.go b/tools/bigquery/bigquery.go index 56f0dc5c9..5f1a882de 100644 --- a/tools/bigquery/bigquery.go +++ b/tools/bigquery/bigquery.go @@ -30,11 +30,20 @@ import ( // Benchmark is the top level structure of recorded benchmark data. BigQuery // will infer the schema from this. type Benchmark struct { - Name string `bq:"name"` - Timestamp time.Time `bq:"timestamp"` - Official bool `bq:"official"` - Metric []*Metric `bq:"metric"` - Metadata *Metadata `bq:"metadata"` + Name string `bq:"name"` + Condition []*Condition `bq:"condition"` + Timestamp time.Time `bq:"timestamp"` + Official bool `bq:"official"` + Metric []*Metric `bq:"metric"` + Metadata *Metadata `bq:"metadata"` +} + +// Condition represents qualifiers for the benchmark. For example: +// Get_Pid/1/real_time would have Benchmark Name "Get_Pid" with "1" +// and "real_time" parameters as conditions. +type Condition struct { + Name string `bq:"name"` + Value string `bq:"value"` } // Metric holds the actual metric data and unit information for this benchmark. @@ -79,6 +88,14 @@ func InitBigQuery(ctx context.Context, projectID, datasetID, tableID string) err return nil } +// AddCondition adds a condition to an existing Benchmark. +func (bm *Benchmark) AddCondition(name, value string) { + bm.Condition = append(bm.Condition, &Condition{ + Name: name, + Value: value, + }) +} + // AddMetric adds a metric to an existing Benchmark. func (bm *Benchmark) AddMetric(metricName, unit string, sample float64) { m := &Metric{ @@ -90,7 +107,7 @@ func (bm *Benchmark) AddMetric(metricName, unit string, sample float64) { } // NewBenchmark initializes a new benchmark. -func NewBenchmark(name string, official bool) *Benchmark { +func NewBenchmark(name string, iters int, official bool) *Benchmark { return &Benchmark{ Name: name, Timestamp: time.Now().UTC(), @@ -103,7 +120,7 @@ func NewBenchmark(name string, official bool) *Benchmark { func SendBenchmarks(ctx context.Context, benchmarks []*Benchmark, projectID, datasetID, tableID string) error { client, err := bq.NewClient(ctx, projectID) if err != nil { - return fmt.Errorf("Failed to initialize client on project: %s: %v", projectID, err) + return fmt.Errorf("failed to initialize client on project: %s: %v", projectID, err) } defer client.Close() diff --git a/tools/checkescape/checkescape.go b/tools/checkescape/checkescape.go index d98f5c3a1..523a42692 100644 --- a/tools/checkescape/checkescape.go +++ b/tools/checkescape/checkescape.go @@ -102,8 +102,8 @@ var ( // This may be set instead of Binary. Reader io.Reader - // Tool is the tool used to dump a binary. - tool = flag.String("dump_tool", "", "tool used to dump a binary") + // objdumpTool is the tool used to dump a binary. + objdumpTool = flag.String("objdump_tool", "", "tool used to dump a binary") ) // EscapeReason is an escape reason. @@ -387,7 +387,7 @@ func loadObjdump() (map[string][]string, error) { } // Construct our command. - cmd := exec.Command(*tool, args...) + cmd := exec.Command(*objdumpTool, args...) cmd.Stdin = stdin cmd.Stderr = os.Stderr out, err := cmd.StdoutPipe() @@ -398,7 +398,37 @@ func loadObjdump() (map[string][]string, error) { return nil, err } + // Identify calls by address or name. Note that this is also + // constructed dynamically below, as we encounted the addresses. + // This is because some of the functions (duffzero) may have + // jump targets in the middle of the function itself. + funcsAllowed := map[string]struct{}{ + "runtime.duffzero": struct{}{}, + "runtime.duffcopy": struct{}{}, + "runtime.racefuncenter": struct{}{}, + "runtime.gcWriteBarrier": struct{}{}, + "runtime.retpolineAX": struct{}{}, + "runtime.retpolineBP": struct{}{}, + "runtime.retpolineBX": struct{}{}, + "runtime.retpolineCX": struct{}{}, + "runtime.retpolineDI": struct{}{}, + "runtime.retpolineDX": struct{}{}, + "runtime.retpolineR10": struct{}{}, + "runtime.retpolineR11": struct{}{}, + "runtime.retpolineR12": struct{}{}, + "runtime.retpolineR13": struct{}{}, + "runtime.retpolineR14": struct{}{}, + "runtime.retpolineR15": struct{}{}, + "runtime.retpolineR8": struct{}{}, + "runtime.retpolineR9": struct{}{}, + "runtime.retpolineSI": struct{}{}, + "runtime.stackcheck": struct{}{}, + "runtime.settls": struct{}{}, + } + addrsAllowed := make(map[string]struct{}) + // Build the map. + nextFunc := "" // For funcsAllowed. m := make(map[string][]string) r := bufio.NewReader(out) NextLine: @@ -407,6 +437,19 @@ NextLine: if err != nil && err != io.EOF { return nil, err } + fields := strings.Fields(line) + + // Is this an "allowed" function definition? + if len(fields) >= 2 && fields[0] == "TEXT" { + nextFunc = strings.TrimSuffix(fields[1], "(SB)") + if _, ok := funcsAllowed[nextFunc]; !ok { + nextFunc = "" // Don't record addresses. + } + } + if nextFunc != "" && len(fields) > 2 { + // Save the given address (in hex form, as it appears). + addrsAllowed[fields[1]] = struct{}{} + } // We recognize lines corresponding to actual code (not the // symbol name or other metadata) and annotate them if they @@ -416,53 +459,31 @@ NextLine: // // Lines look like this (including the first space): // gohacks_unsafe.go:33 0xa39 488b442408 MOVQ 0x8(SP), AX - if len(line) > 0 && line[0] == ' ' { - fields := strings.Fields(line) + if len(fields) >= 5 && line[0] == ' ' { if !strings.Contains(fields[3], "CALL") { continue } - site := strings.TrimSpace(fields[0]) - var callStr string // Friendly string. - if len(fields) > 5 { - callStr = strings.Join(fields[5:], " ") - } - if len(callStr) == 0 { - // Just a raw call? is this asm? - callStr = strings.Join(fields[3:], " ") - } - - // Ignore strings containing duffzero, which is just - // used by stack allocations for types that are large - // enough to warrant Duff's device. - if strings.Contains(callStr, "runtime.duffzero") || - strings.Contains(callStr, "runtime.duffcopy") { - continue - } - - // Ignore the racefuncenter call, which is used for - // race builds. This does not escape. - if strings.Contains(callStr, "runtime.racefuncenter") { - continue - } + site := fields[0] + target := strings.TrimSuffix(fields[4], "(SB)") - // Ignore the write barriers. - if strings.Contains(callStr, "runtime.gcWriteBarrier") { + // Ignore strings containing allowed functions. + if _, ok := funcsAllowed[target]; ok { continue } - - // Ignore retpolines. - if strings.Contains(callStr, "runtime.retpoline") { + if _, ok := addrsAllowed[target]; ok { continue } - - // Ignore stack sanity check (does not split). - if strings.Contains(callStr, "runtime.stackcheck") { - continue - } - - // Ignore tls functions. - if strings.Contains(callStr, "runtime.settls") { - continue + if len(fields) > 5 { + // This may be a future relocation. Some + // objdump versions describe this differently. + // If it contains any of the functions allowed + // above as a string, we let it go. + softTarget := strings.Join(fields[5:], " ") + for name := range funcsAllowed { + if strings.Contains(softTarget, name) { + continue NextLine + } + } } // Does this exist already? @@ -471,11 +492,11 @@ NextLine: existing = make([]string, 0, 1) } for _, other := range existing { - if callStr == other { + if target == other { continue NextLine } } - existing = append(existing, callStr) + existing = append(existing, target) m[site] = existing // Update. } if err == io.EOF { @@ -483,12 +504,25 @@ NextLine: } } + // Zap any accidental false positives. + final := make(map[string][]string) + for site, calls := range m { + filteredCalls := make([]string, 0, len(calls)) + for _, call := range calls { + if _, ok := addrsAllowed[call]; ok { + continue // Omit this call. + } + filteredCalls = append(filteredCalls, call) + } + final[site] = filteredCalls + } + // Wait for the dump to finish. if err := cmd.Wait(); err != nil { return nil, err } - return m, nil + return final, nil } // poser is a type that implements Pos. diff --git a/tools/defs.bzl b/tools/defs.bzl index 60c9c9d8c..e2c72a1f6 100644 --- a/tools/defs.bzl +++ b/tools/defs.bzl @@ -7,7 +7,7 @@ change for Google-internal and bazel-compatible rules. load("//tools/go_stateify:defs.bzl", "go_stateify") load("//tools/go_marshal:defs.bzl", "go_marshal", "marshal_deps", "marshal_test_deps") -load("//tools/bazeldefs:defs.bzl", _build_test = "build_test", _bzl_library = "bzl_library", _cc_binary = "cc_binary", _cc_flags_supplier = "cc_flags_supplier", _cc_grpc_library = "cc_grpc_library", _cc_library = "cc_library", _cc_proto_library = "cc_proto_library", _cc_test = "cc_test", _cc_toolchain = "cc_toolchain", _default_installer = "default_installer", _default_net_util = "default_net_util", _gazelle = "gazelle", _gbenchmark = "gbenchmark", _go_binary = "go_binary", _go_embed_data = "go_embed_data", _go_grpc_and_proto_libraries = "go_grpc_and_proto_libraries", _go_library = "go_library", _go_path = "go_path", _go_proto_library = "go_proto_library", _go_test = "go_test", _grpcpp = "grpcpp", _gtest = "gtest", _loopback = "loopback", _pkg_deb = "pkg_deb", _pkg_tar = "pkg_tar", _proto_library = "proto_library", _py_binary = "py_binary", _rbe_platform = "rbe_platform", _rbe_toolchain = "rbe_toolchain", _select_arch = "select_arch", _select_system = "select_system", _short_path = "short_path", _vdso_linker_option = "vdso_linker_option") +load("//tools/bazeldefs:defs.bzl", _build_test = "build_test", _bzl_library = "bzl_library", _cc_binary = "cc_binary", _cc_flags_supplier = "cc_flags_supplier", _cc_grpc_library = "cc_grpc_library", _cc_library = "cc_library", _cc_proto_library = "cc_proto_library", _cc_test = "cc_test", _cc_toolchain = "cc_toolchain", _coreutil = "coreutil", _default_installer = "default_installer", _default_net_util = "default_net_util", _gazelle = "gazelle", _gbenchmark = "gbenchmark", _go_binary = "go_binary", _go_embed_data = "go_embed_data", _go_grpc_and_proto_libraries = "go_grpc_and_proto_libraries", _go_library = "go_library", _go_path = "go_path", _go_proto_library = "go_proto_library", _go_test = "go_test", _grpcpp = "grpcpp", _gtest = "gtest", _loopback = "loopback", _pkg_deb = "pkg_deb", _pkg_tar = "pkg_tar", _proto_library = "proto_library", _py_binary = "py_binary", _rbe_platform = "rbe_platform", _rbe_toolchain = "rbe_toolchain", _select_arch = "select_arch", _select_system = "select_system", _short_path = "short_path", _vdso_linker_option = "vdso_linker_option") load("//tools/bazeldefs:platforms.bzl", _default_platform = "default_platform", _platforms = "platforms") load("//tools/bazeldefs:tags.bzl", "go_suffixes") load("//tools/nogo:defs.bzl", "nogo_test") @@ -39,6 +39,7 @@ short_path = _short_path rbe_platform = _rbe_platform rbe_toolchain = _rbe_toolchain vdso_linker_option = _vdso_linker_option +coreutil = _coreutil # Platform options. default_platform = _default_platform @@ -71,7 +72,8 @@ def go_binary(name, nogo = True, pure = False, static = False, x_defs = None, ** ) nogo_test( name = name + "_nogo", - deps = [":" + name + "_nogo_library"], + srcs = kwargs.get("srcs", []), + library = ":" + name + "_nogo_library", ) def calculate_sets(srcs): @@ -202,7 +204,8 @@ def go_library(name, srcs, deps = [], imports = [], stateify = True, marshal = F if nogo: nogo_test( name = name + "_nogo", - deps = [":" + name], + srcs = all_srcs, + library = ":" + name, ) if marshal: @@ -238,7 +241,8 @@ def go_test(name, nogo = True, **kwargs): if nogo: nogo_test( name = name + "_nogo", - deps = [":" + name], + srcs = kwargs.get("srcs", []), + library = ":" + name, ) def proto_library(name, srcs, deps = None, has_services = 0, **kwargs): diff --git a/tools/go_marshal/gomarshal/generator.go b/tools/go_marshal/gomarshal/generator.go index 56fbcb5d2..4a53d25be 100644 --- a/tools/go_marshal/gomarshal/generator.go +++ b/tools/go_marshal/gomarshal/generator.go @@ -80,15 +80,15 @@ type Generator struct { func NewGenerator(srcs []string, out, outTest, outTestUnconditional, pkg string, imports []string) (*Generator, error) { f, err := os.OpenFile(out, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) if err != nil { - return nil, fmt.Errorf("Couldn't open output file %q: %v", out, err) + return nil, fmt.Errorf("couldn't open output file %q: %w", out, err) } fTest, err := os.OpenFile(outTest, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) if err != nil { - return nil, fmt.Errorf("Couldn't open test output file %q: %v", out, err) + return nil, fmt.Errorf("couldn't open test output file %q: %w", out, err) } fTestUC, err := os.OpenFile(outTestUnconditional, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) if err != nil { - return nil, fmt.Errorf("Couldn't open unconditional test output file %q: %v", out, err) + return nil, fmt.Errorf("couldn't open unconditional test output file %q: %w", out, err) } g := Generator{ inputs: srcs, @@ -181,7 +181,7 @@ func (g *Generator) parse() ([]*ast.File, []*token.FileSet, error) { f, err := parser.ParseFile(fset, path, nil, parser.ParseComments) if err != nil { // Not a valid input file? - return nil, nil, fmt.Errorf("Input %q can't be parsed: %v", path, err) + return nil, nil, fmt.Errorf("input %q can't be parsed: %w", path, err) } if debugEnabled() { diff --git a/tools/go_stateify/main.go b/tools/go_stateify/main.go index 4f6ed208a..e1de12e25 100644 --- a/tools/go_stateify/main.go +++ b/tools/go_stateify/main.go @@ -39,7 +39,7 @@ var ( ) // resolveTypeName returns a qualified type name. -func resolveTypeName(name string, typ ast.Expr) (field string, qualified string) { +func resolveTypeName(typ ast.Expr) (field string, qualified string) { for done := false; !done; { // Resolve star expressions. switch rs := typ.(type) { @@ -69,11 +69,7 @@ func resolveTypeName(name string, typ ast.Expr) (field string, qualified string) } // Figure out actual type name. - ident, ok := typ.(*ast.Ident) - if !ok { - panic(fmt.Sprintf("type not supported: %s (involves anonymous types?)", name)) - } - field = ident.Name + field = typ.(*ast.Ident).Name qualified = qualified + field return } @@ -119,7 +115,7 @@ func scanFields(ss *ast.StructType, prefix string, fn scanFunctions) { } else { // Anonymous types can't be embedded, so we don't need // to worry about providing a useful name here. - name, _ = resolveTypeName("", field.Type) + name, _ = resolveTypeName(field.Type) } // Skip _ fields. @@ -214,9 +210,6 @@ func main() { emitRegister := func(name string) { initCalls = append(initCalls, fmt.Sprintf("%sRegister((*%s)(nil))", statePrefix, name)) } - emitZeroCheck := func(name string) { - fmt.Fprintf(outputFile, " if !%sIsZeroValue(&x.%s) { %sFailf(\"%s is %%#v, expected zero\", &x.%s) }\n", statePrefix, name, statePrefix, name, name) - } // Automated warning. fmt.Fprint(outputFile, "// automatically generated by stateify.\n\n") @@ -265,52 +258,39 @@ func main() { } type method struct { - receiver string - name string + typeName string + methodName string } - // Search for and add all methods with a pointer receiver and no other - // arguments to a set. We support auto-detecting the existence of - // several different methods with this signature. - simpleMethods := map[method]struct{}{} + // Search for and add all method to a set. We auto-detecting several + // different methods (and insert them if we don't find them, in order + // to ensure that expectations match reality). + // + // While we do this, figure out the right receiver name. If there are + // multiple distinct receivers, then we will just pick the last one. + simpleMethods := make(map[method]struct{}) + receiverNames := make(map[string]string) for _, f := range files { - // Go over all functions. for _, decl := range f.Decls { d, ok := decl.(*ast.FuncDecl) if !ok { continue } - if d.Name == nil || d.Recv == nil || d.Type == nil { + if d.Recv == nil || len(d.Recv.List) != 1 { // Not a named method. continue } - if len(d.Recv.List) != 1 { - // Wrong number of receivers? - continue - } - if d.Type.Params != nil && len(d.Type.Params.List) != 0 { - // Has argument(s). - continue - } - if d.Type.Results != nil && len(d.Type.Results.List) != 0 { - // Has return(s). - continue - } - pt, ok := d.Recv.List[0].Type.(*ast.StarExpr) - if !ok { - // Not a pointer receiver. - continue + // Save the method and the receiver. + name, _ := resolveTypeName(d.Recv.List[0].Type) + simpleMethods[method{ + typeName: name, + methodName: d.Name.Name, + }] = struct{}{} + if len(d.Recv.List[0].Names) > 0 { + receiverNames[name] = d.Recv.List[0].Names[0].Name } - - t, ok := pt.X.(*ast.Ident) - if !ok { - // This shouldn't happen with valid Go. - continue - } - - simpleMethods[method{t.Name, d.Name.Name}] = struct{}{} } } @@ -349,6 +329,11 @@ func main() { for _, gs := range d.Specs { ts := gs.(*ast.TypeSpec) + recv, ok := receiverNames[ts.Name.Name] + if !ok { + // Maybe no methods were defined? + recv = strings.ToLower(ts.Name.Name[:1]) + } switch x := ts.Type.(type) { case *ast.StructType: maybeEmitImports() @@ -365,29 +350,32 @@ func main() { emitField(name) } emitLoadValue := func(name, typName string) { - fmt.Fprintf(outputFile, " m.LoadValue(%d, new(%s), func(y interface{}) { x.load%s(y.(%s)) })\n", fields[name], typName, camelCased(name), typName) + fmt.Fprintf(outputFile, " stateSourceObject.LoadValue(%d, new(%s), func(y interface{}) { %s.load%s(y.(%s)) })\n", fields[name], typName, recv, camelCased(name), typName) } emitLoad := func(name string) { - fmt.Fprintf(outputFile, " m.Load(%d, &x.%s)\n", fields[name], name) + fmt.Fprintf(outputFile, " stateSourceObject.Load(%d, &%s.%s)\n", fields[name], recv, name) } emitLoadWait := func(name string) { - fmt.Fprintf(outputFile, " m.LoadWait(%d, &x.%s)\n", fields[name], name) + fmt.Fprintf(outputFile, " stateSourceObject.LoadWait(%d, &%s.%s)\n", fields[name], recv, name) } emitSaveValue := func(name, typName string) { - fmt.Fprintf(outputFile, " var %s %s = x.save%s()\n", name, typName, camelCased(name)) - fmt.Fprintf(outputFile, " m.SaveValue(%d, %s)\n", fields[name], name) + fmt.Fprintf(outputFile, " var %sValue %s = %s.save%s()\n", name, typName, recv, camelCased(name)) + fmt.Fprintf(outputFile, " stateSinkObject.SaveValue(%d, %sValue)\n", fields[name], name) } emitSave := func(name string) { - fmt.Fprintf(outputFile, " m.Save(%d, &x.%s)\n", fields[name], name) + fmt.Fprintf(outputFile, " stateSinkObject.Save(%d, &%s.%s)\n", fields[name], recv, name) + } + emitZeroCheck := func(name string) { + fmt.Fprintf(outputFile, " if !%sIsZeroValue(&%s.%s) { %sFailf(\"%s is %%#v, expected zero\", &%s.%s) }\n", statePrefix, recv, name, statePrefix, name, recv, name) } // Generate the type name method. - fmt.Fprintf(outputFile, "func (x *%s) StateTypeName() string {\n", ts.Name.Name) + fmt.Fprintf(outputFile, "func (%s *%s) StateTypeName() string {\n", recv, ts.Name.Name) fmt.Fprintf(outputFile, " return \"%s.%s\"\n", *fullPkg, ts.Name.Name) fmt.Fprintf(outputFile, "}\n\n") // Generate the fields method. - fmt.Fprintf(outputFile, "func (x *%s) StateFields() []string {\n", ts.Name.Name) + fmt.Fprintf(outputFile, "func (%s *%s) StateFields() []string {\n", recv, ts.Name.Name) fmt.Fprintf(outputFile, " return []string{\n") scanFields(x, "", scanFunctions{ normal: emitField, @@ -401,8 +389,11 @@ func main() { // the code from compiling if a custom beforeSave was defined in a // file not provided to this binary and prevents inherited methods // from being called multiple times by overriding them. - if _, ok := simpleMethods[method{ts.Name.Name, "beforeSave"}]; !ok && generateSaverLoader { - fmt.Fprintf(outputFile, "func (x *%s) beforeSave() {}\n\n", ts.Name.Name) + if _, ok := simpleMethods[method{ + typeName: ts.Name.Name, + methodName: "beforeSave", + }]; !ok && generateSaverLoader { + fmt.Fprintf(outputFile, "func (%s *%s) beforeSave() {}\n\n", recv, ts.Name.Name) } // Generate the save method. @@ -412,8 +403,8 @@ func main() { // on this specific behavior, but the ability to specify slots // allows a manual implementation to be order-dependent. if generateSaverLoader { - fmt.Fprintf(outputFile, "func (x *%s) StateSave(m %sSink) {\n", ts.Name.Name, statePrefix) - fmt.Fprintf(outputFile, " x.beforeSave()\n") + fmt.Fprintf(outputFile, "func (%s *%s) StateSave(stateSinkObject %sSink) {\n", recv, ts.Name.Name, statePrefix) + fmt.Fprintf(outputFile, " %s.beforeSave()\n", recv) scanFields(x, "", scanFunctions{zerovalue: emitZeroCheck}) scanFields(x, "", scanFunctions{value: emitSaveValue}) scanFields(x, "", scanFunctions{normal: emitSave, wait: emitSave}) @@ -422,16 +413,19 @@ func main() { // Define afterLoad if a definition was not found. We do this for // the same reason that we do it for beforeSave. - _, hasAfterLoad := simpleMethods[method{ts.Name.Name, "afterLoad"}] + _, hasAfterLoad := simpleMethods[method{ + typeName: ts.Name.Name, + methodName: "afterLoad", + }] if !hasAfterLoad && generateSaverLoader { - fmt.Fprintf(outputFile, "func (x *%s) afterLoad() {}\n\n", ts.Name.Name) + fmt.Fprintf(outputFile, "func (%s *%s) afterLoad() {}\n\n", recv, ts.Name.Name) } // Generate the load method. // // N.B. See the comment above for the save method. if generateSaverLoader { - fmt.Fprintf(outputFile, "func (x *%s) StateLoad(m %sSource) {\n", ts.Name.Name, statePrefix) + fmt.Fprintf(outputFile, "func (%s *%s) StateLoad(stateSourceObject %sSource) {\n", recv, ts.Name.Name, statePrefix) scanFields(x, "", scanFunctions{normal: emitLoad, wait: emitLoadWait}) scanFields(x, "", scanFunctions{value: emitLoadValue}) if hasAfterLoad { @@ -439,7 +433,7 @@ func main() { // AfterLoad is called, the object encodes a dependency on // referred objects (i.e. fields). This means that afterLoad // will not be called until the other afterLoads are called. - fmt.Fprintf(outputFile, " m.AfterLoad(x.afterLoad)\n") + fmt.Fprintf(outputFile, " stateSourceObject.AfterLoad(%s.afterLoad)\n", recv) } fmt.Fprintf(outputFile, "}\n\n") } @@ -451,10 +445,10 @@ func main() { maybeEmitImports() // Generate the info methods. - fmt.Fprintf(outputFile, "func (x *%s) StateTypeName() string {\n", ts.Name.Name) + fmt.Fprintf(outputFile, "func (%s *%s) StateTypeName() string {\n", recv, ts.Name.Name) fmt.Fprintf(outputFile, " return \"%s.%s\"\n", *fullPkg, ts.Name.Name) fmt.Fprintf(outputFile, "}\n\n") - fmt.Fprintf(outputFile, "func (x *%s) StateFields() []string {\n", ts.Name.Name) + fmt.Fprintf(outputFile, "func (%s *%s) StateFields() []string {\n", recv, ts.Name.Name) fmt.Fprintf(outputFile, " return nil\n") fmt.Fprintf(outputFile, "}\n\n") diff --git a/tools/nogo/BUILD b/tools/nogo/BUILD index 9f1fcd9c7..dd4b46f58 100644 --- a/tools/nogo/BUILD +++ b/tools/nogo/BUILD @@ -1,10 +1,10 @@ load("//tools:defs.bzl", "bzl_library", "go_library") -load("//tools/nogo:defs.bzl", "nogo_dump_tool", "nogo_stdlib") +load("//tools/nogo:defs.bzl", "nogo_objdump_tool", "nogo_stdlib") package(licenses = ["notice"]) -nogo_dump_tool( - name = "dump_tool", +nogo_objdump_tool( + name = "objdump_tool", visibility = ["//visibility:public"], ) @@ -13,6 +13,12 @@ nogo_stdlib( visibility = ["//visibility:public"], ) +sh_binary( + name = "gentest", + srcs = ["gentest.sh"], + visibility = ["//visibility:public"], +) + go_library( name = "nogo", srcs = [ @@ -27,6 +33,8 @@ go_library( deps = [ "//tools/checkescape", "//tools/checkunsafe", + "@co_honnef_go_tools//staticcheck:go_default_library", + "@co_honnef_go_tools//stylecheck:go_default_library", "@org_golang_x_tools//go/analysis:go_tool_library", "@org_golang_x_tools//go/analysis/internal/facts:go_tool_library", "@org_golang_x_tools//go/analysis/passes/asmdecl:go_tool_library", diff --git a/tools/nogo/build.go b/tools/nogo/build.go index 39c2ae418..55d34760f 100644 --- a/tools/nogo/build.go +++ b/tools/nogo/build.go @@ -29,6 +29,9 @@ var ( // internalDefault is applied when no paths are provided. internalDefault = fmt.Sprintf("%s/.*", notPath("external")) + // generatedPrefix is a regex for generated files. + generatedPrefix = "^(.*/)?(bazel-genfiles|bazel-out|bazel-bin)/" + // externalPrefix is external workspace packages. externalPrefix = "^external/" ) diff --git a/tools/nogo/config.go b/tools/nogo/config.go index cfe7b4aa4..8079618ab 100644 --- a/tools/nogo/config.go +++ b/tools/nogo/config.go @@ -41,6 +41,8 @@ import ( "golang.org/x/tools/go/analysis/passes/unreachable" "golang.org/x/tools/go/analysis/passes/unsafeptr" "golang.org/x/tools/go/analysis/passes/unusedresult" + "honnef.co/go/tools/staticcheck" + "honnef.co/go/tools/stylecheck" "gvisor.dev/gvisor/tools/checkescape" "gvisor.dev/gvisor/tools/checkunsafe" @@ -123,6 +125,432 @@ var analyzerConfig = map[*analysis.Analyzer]matcher{ checkunsafe.Analyzer: internalMatches(), } +func init() { + staticMatcher := and( + // Only match internal, non-generated files. + internalMatches(), + generatedExcluded(), + + // We use ALL_CAPS for system definitions, + // which are common enough in the code base + // that we shouldn't annotate exceptions. + // + // Same story for underscores. + resultExcluded([]string{ + "should not use ALL_CAPS in Go names", + "should not use underscores in Go names", + }), + + // Exclude existing matches. + internalExcluded( + "pkg/abi/linux/fuse.go:22", + "pkg/abi/linux/fuse.go:25", + "pkg/abi/linux/socket.go:113", + "pkg/abi/linux/tty.go:73", + "pkg/bpf/decoder.go:112", + "pkg/cpuid/cpuid_x86.go:675", + "pkg/eventchannel/event.go:193", + "pkg/eventchannel/event.go:27", + "pkg/eventchannel/event_test.go:22", + "pkg/eventchannel/rate.go:19", + "pkg/gohacks/gohacks_unsafe.go:33", + "pkg/log/json.go:30", + "pkg/log/log.go:359", + "pkg/merkletree/merkletree.go:230", + "pkg/merkletree/merkletree.go:243", + "pkg/merkletree/merkletree.go:249", + "pkg/merkletree/merkletree.go:266", + "pkg/merkletree/merkletree.go:355", + "pkg/merkletree/merkletree.go:369", + "pkg/metric/metric_test.go:20", + "pkg/p9/p9test/client_test.go:687", + "pkg/p9/transport_test.go:196", + "pkg/pool/pool.go:15", + "pkg/refs/refcounter.go:510", + "pkg/refs/refcounter_test.go:169", + "pkg/refs_vfs2/refs.go:16", + "pkg/safemem/block_unsafe.go:89", + "pkg/seccomp/seccomp.go:82", + "pkg/segment/test/set_functions.go:15", + "pkg/sentry/arch/signal.go:166", + "pkg/sentry/arch/signal.go:171", + "pkg/sentry/control/pprof.go:196", + "pkg/sentry/devices/memdev/full.go:58", + "pkg/sentry/devices/memdev/null.go:59", + "pkg/sentry/devices/memdev/random.go:68", + "pkg/sentry/devices/memdev/zero.go:86", + "pkg/sentry/fdimport/fdimport.go:15", + "pkg/sentry/fs/attr.go:257", + "pkg/sentry/fsbridge/fs.go:116", + "pkg/sentry/fsbridge/vfs.go:124", + "pkg/sentry/fsbridge/vfs.go:70", + "pkg/sentry/fs/copy_up.go:365", + "pkg/sentry/fs/copy_up_test.go:65", + "pkg/sentry/fs/dev/net_tun.go:161", + "pkg/sentry/fs/dev/net_tun.go:63", + "pkg/sentry/fs/dev/null.go:97", + "pkg/sentry/fs/dirent_cache.go:64", + "pkg/sentry/fs/file_overlay.go:327", + "pkg/sentry/fs/file_overlay.go:524", + "pkg/sentry/fs/filetest/filetest.go:55", + "pkg/sentry/fs/filetest/filetest.go:60", + "pkg/sentry/fs/fs.go:77", + "pkg/sentry/fs/fsutil/file.go:290", + "pkg/sentry/fs/fsutil/file.go:346", + "pkg/sentry/fs/fsutil/host_file_mapper.go:105", + "pkg/sentry/fs/fsutil/inode_cached.go:676", + "pkg/sentry/fs/fsutil/inode_cached.go:772", + "pkg/sentry/fs/gofer/attr.go:120", + "pkg/sentry/fs/gofer/fifo.go:33", + "pkg/sentry/fs/gofer/inode.go:410", + "pkg/sentry/fsimpl/devpts/devpts.go:110", + "pkg/sentry/fsimpl/devpts/devpts.go:246", + "pkg/sentry/fsimpl/devpts/devpts.go:50", + "pkg/sentry/fsimpl/devpts/master.go:110", + "pkg/sentry/fsimpl/devpts/master.go:55", + "pkg/sentry/fsimpl/devpts/replica.go:113", + "pkg/sentry/fsimpl/devpts/replica.go:57", + "pkg/sentry/fsimpl/devtmpfs/devtmpfs.go:54", + "pkg/sentry/fsimpl/ext/disklayout/superblock_64.go:97", + "pkg/sentry/fsimpl/ext/disklayout/superblock_old.go:92", + "pkg/sentry/fsimpl/ext/disklayout/block_group_32.go:44", + "pkg/sentry/fsimpl/ext/disklayout/inode_new.go:91", + "pkg/sentry/fsimpl/ext/disklayout/inode_old.go:93", + "pkg/sentry/fsimpl/ext/disklayout/superblock_32.go:66", + "pkg/sentry/fsimpl/ext/disklayout/block_group_64.go:53", + "pkg/sentry/fsimpl/eventfd/eventfd.go:268", + "pkg/sentry/fsimpl/ext/directory.go:163", + "pkg/sentry/fsimpl/ext/directory.go:164", + "pkg/sentry/fsimpl/ext/extent_file.go:142", + "pkg/sentry/fsimpl/ext/extent_file.go:143", + "pkg/sentry/fsimpl/ext/ext.go:105", + "pkg/sentry/fsimpl/ext/filesystem.go:287", + "pkg/sentry/fsimpl/ext/regular_file.go:153", + "pkg/sentry/fsimpl/ext/symlink.go:113", + "pkg/sentry/fsimpl/fuse/connection_control.go:194", + "pkg/sentry/fsimpl/fuse/dev.go:387", + "pkg/sentry/fsimpl/fuse/dev_test.go:318", + "pkg/sentry/fsimpl/fuse/fusefs.go:102", + "pkg/sentry/fsimpl/fuse/read_write.go:129", + "pkg/sentry/fsimpl/fuse/request_response.go:71", + "pkg/sentry/fsimpl/gofer/directory.go:135", + "pkg/sentry/fsimpl/gofer/filesystem.go:679", + "pkg/sentry/fsimpl/gofer/gofer.go:1694", + "pkg/sentry/fsimpl/gofer/gofer.go:276", + "pkg/sentry/fsimpl/gofer/regular_file.go:81", + "pkg/sentry/fsimpl/gofer/special_file.go:141", + "pkg/sentry/fsimpl/host/host.go:184", + "pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go:50", + "pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go:90", + "pkg/sentry/fsimpl/kernfs/fd_impl_util.go:273", + "pkg/sentry/fsimpl/kernfs/filesystem.go:247", + "pkg/sentry/fsimpl/kernfs/inode_impl_util.go:320", + "pkg/sentry/fsimpl/kernfs/inode_impl_util.go:497", + "pkg/sentry/fsimpl/kernfs/synthetic_directory.go:52", + "pkg/sentry/fsimpl/overlay/directory.go:119", + "pkg/sentry/fsimpl/overlay/filesystem.go:527", + "pkg/sentry/fsimpl/overlay/non_directory.go:152", + "pkg/sentry/fsimpl/overlay/overlay.go:115", + "pkg/sentry/fsimpl/overlay/overlay.go:719", + "pkg/sentry/fsimpl/pipefs/pipefs.go:74", + "pkg/sentry/fsimpl/proc/filesystem.go:52", + "pkg/sentry/fsimpl/proc/filesystem.go:81", + "pkg/sentry/fsimpl/proc/subtasks.go:126", + "pkg/sentry/fsimpl/proc/subtasks.go:189", + "pkg/sentry/fsimpl/proc/task_fds.go:168", + "pkg/sentry/fsimpl/proc/task_fds.go:228", + "pkg/sentry/fsimpl/proc/task_fds.go:301", + "pkg/sentry/fsimpl/proc/task_fds.go:318", + "pkg/sentry/fsimpl/proc/task_fds.go:67", + "pkg/sentry/fsimpl/proc/task_files.go:112", + "pkg/sentry/fsimpl/proc/task_files.go:158", + "pkg/sentry/fsimpl/proc/task_files.go:259", + "pkg/sentry/fsimpl/proc/task_files.go:285", + "pkg/sentry/fsimpl/proc/task_files.go:305", + "pkg/sentry/fsimpl/proc/task_files.go:384", + "pkg/sentry/fsimpl/proc/task_files.go:403", + "pkg/sentry/fsimpl/proc/task_files.go:428", + "pkg/sentry/fsimpl/proc/task_files.go:691", + "pkg/sentry/fsimpl/proc/task_files.go:770", + "pkg/sentry/fsimpl/proc/task_files.go:797", + "pkg/sentry/fsimpl/proc/task_files.go:828", + "pkg/sentry/fsimpl/proc/task_files.go:879", + "pkg/sentry/fsimpl/proc/task_files.go:910", + "pkg/sentry/fsimpl/proc/task_files.go:961", + "pkg/sentry/fsimpl/proc/task.go:127", + "pkg/sentry/fsimpl/proc/task.go:193", + "pkg/sentry/fsimpl/proc/task_net.go:134", + "pkg/sentry/fsimpl/proc/task_net.go:475", + "pkg/sentry/fsimpl/proc/task_net.go:491", + "pkg/sentry/fsimpl/proc/task_net.go:508", + "pkg/sentry/fsimpl/proc/task_net.go:665", + "pkg/sentry/fsimpl/proc/task_net.go:715", + "pkg/sentry/fsimpl/proc/task_net.go:779", + "pkg/sentry/fsimpl/proc/tasks_files.go:113", + "pkg/sentry/fsimpl/proc/tasks_files.go:388", + "pkg/sentry/fsimpl/proc/tasks.go:232", + "pkg/sentry/fsimpl/proc/tasks_sys.go:145", + "pkg/sentry/fsimpl/proc/tasks_sys.go:181", + "pkg/sentry/fsimpl/proc/tasks_sys.go:239", + "pkg/sentry/fsimpl/proc/tasks_sys.go:291", + "pkg/sentry/fsimpl/proc/tasks_sys.go:375", + "pkg/sentry/fsimpl/signalfd/signalfd.go:124", + "pkg/sentry/fsimpl/signalfd/signalfd.go:15", + "pkg/sentry/fsimpl/signalfd/signalfd.go:126", + "pkg/sentry/fsimpl/sockfs/sockfs.go:36", + "pkg/sentry/fsimpl/sockfs/sockfs.go:79", + "pkg/sentry/fsimpl/sys/kcov.go:49", + "pkg/sentry/fsimpl/sys/kcov.go:99", + "pkg/sentry/fsimpl/sys/sys.go:118", + "pkg/sentry/fsimpl/sys/sys.go:56", + "pkg/sentry/fsimpl/testutil/testutil.go:257", + "pkg/sentry/fsimpl/testutil/testutil.go:260", + "pkg/sentry/fsimpl/timerfd/timerfd.go:87", + "pkg/sentry/fsimpl/tmpfs/directory.go:112", + "pkg/sentry/fsimpl/tmpfs/filesystem.go:195", + "pkg/sentry/fsimpl/tmpfs/regular_file.go:226", + "pkg/sentry/fsimpl/tmpfs/regular_file.go:346", + "pkg/sentry/fsimpl/tmpfs/tmpfs.go:103", + "pkg/sentry/fsimpl/tmpfs/tmpfs.go:733", + "pkg/sentry/fsimpl/verity/filesystem.go:490", + "pkg/sentry/fsimpl/verity/verity.go:156", + "pkg/sentry/fsimpl/verity/verity.go:629", + "pkg/sentry/fsimpl/verity/verity.go:672", + "pkg/sentry/fs/mount.go:162", + "pkg/sentry/fs/mount.go:256", + "pkg/sentry/fs/mount_overlay.go:144", + "pkg/sentry/fs/mounts.go:432", + "pkg/sentry/fs/proc/exec_args.go:104", + "pkg/sentry/fs/proc/exec_args.go:73", + "pkg/sentry/fs/proc/fds.go:269", + "pkg/sentry/fs/proc/loadavg.go:33", + "pkg/sentry/fs/proc/meminfo.go:39", + "pkg/sentry/fs/proc/mounts.go:193", + "pkg/sentry/fs/proc/mounts.go:84", + "pkg/sentry/fs/proc/net.go:125", + "pkg/sentry/fs/proc/proc.go:146", + "pkg/sentry/fs/proc/proc.go:204", + "pkg/sentry/fs/proc/seqfile/seqfile.go:210", + "pkg/sentry/fs/proc/sys.go:146", + "pkg/sentry/fs/proc/sys.go:43", + "pkg/sentry/fs/proc/sys_net.go:113", + "pkg/sentry/fs/proc/sys_net.go:205", + "pkg/sentry/fs/proc/sys_net.go:233", + "pkg/sentry/fs/proc/sys_net.go:307", + "pkg/sentry/fs/proc/sys_net.go:335", + "pkg/sentry/fs/proc/sys_net.go:446", + "pkg/sentry/fs/proc/sys_net.go:456", + "pkg/sentry/fs/proc/sys_net.go:89", + "pkg/sentry/fs/proc/task.go:170", + "pkg/sentry/fs/proc/task.go:322", + "pkg/sentry/fs/proc/task.go:427", + "pkg/sentry/fs/proc/task.go:467", + "pkg/sentry/fs/proc/task.go:500", + "pkg/sentry/fs/proc/task.go:784", + "pkg/sentry/fs/proc/task.go:839", + "pkg/sentry/fs/proc/task.go:920", + "pkg/sentry/fs/proc/uid_gid_map.go:108", + "pkg/sentry/fs/proc/uid_gid_map.go:79", + "pkg/sentry/fs/proc/uptime.go:75", + "pkg/sentry/fs/ramfs/dir.go:447", + "pkg/sentry/fs/tmpfs/inode_file.go:436", + "pkg/sentry/fs/tmpfs/inode_file.go:537", + "pkg/sentry/fs/tty/dir.go:313", + "pkg/sentry/fs/tty/master.go:131", + "pkg/sentry/fs/tty/master.go:91", + "pkg/sentry/fs/tty/replica.go:116", + "pkg/sentry/fs/tty/replica.go:88", + "pkg/sentry/kernel/auth/id_map.go:269", + "pkg/sentry/kernel/fasync/fasync.go:67", + "pkg/sentry/kernel/kcov.go:209", + "pkg/sentry/kernel/kcov.go:223", + "pkg/sentry/kernel/kernel.go:343", + "pkg/sentry/kernel/kernel.go:368", + "pkg/sentry/kernel/pipe/node_test.go:112", + "pkg/sentry/kernel/pipe/node_test.go:119", + "pkg/sentry/kernel/pipe/node_test.go:130", + "pkg/sentry/kernel/pipe/node_test.go:137", + "pkg/sentry/kernel/pipe/node_test.go:149", + "pkg/sentry/kernel/pipe/node_test.go:150", + "pkg/sentry/kernel/pipe/node_test.go:158", + "pkg/sentry/kernel/pipe/node_test.go:174", + "pkg/sentry/kernel/pipe/node_test.go:180", + "pkg/sentry/kernel/pipe/node_test.go:193", + "pkg/sentry/kernel/pipe/node_test.go:202", + "pkg/sentry/kernel/pipe/node_test.go:205", + "pkg/sentry/kernel/pipe/node_test.go:216", + "pkg/sentry/kernel/pipe/node_test.go:219", + "pkg/sentry/kernel/pipe/node_test.go:271", + "pkg/sentry/kernel/pipe/node_test.go:290", + "pkg/sentry/kernel/pipe/pipe_test.go:93", + "pkg/sentry/kernel/pipe/reader_writer.go:65", + "pkg/sentry/kernel/posixtimer.go:157", + "pkg/sentry/kernel/ptrace.go:218", + "pkg/sentry/kernel/semaphore/semaphore.go:323", + "pkg/sentry/kernel/sessions.go:123", + "pkg/sentry/kernel/sessions.go:508", + "pkg/sentry/kernel/signal_handlers.go:57", + "pkg/sentry/kernel/task_context.go:72", + "pkg/sentry/kernel/task_exit.go:67", + "pkg/sentry/kernel/task_sched.go:255", + "pkg/sentry/kernel/task_sched.go:280", + "pkg/sentry/kernel/task_sched.go:323", + "pkg/sentry/kernel/task_stop.go:192", + "pkg/sentry/kernel/thread_group.go:530", + "pkg/sentry/kernel/timekeeper.go:316", + "pkg/sentry/kernel/vdso.go:106", + "pkg/sentry/kernel/vdso.go:118", + "pkg/sentry/memmap/memmap.go:103", + "pkg/sentry/memmap/memmap.go:163", + "pkg/sentry/mm/address_space.go:42", + "pkg/sentry/mm/address_space.go:42", + "pkg/sentry/mm/aio_context.go:208", + "pkg/sentry/mm/aio_context.go:288", + "pkg/sentry/mm/pma.go:683", + "pkg/sentry/mm/special_mappable.go:80", + "pkg/sentry/platform/systrap/subprocess.go:370", + "pkg/sentry/platform/systrap/usertrap/usertrap_amd64.go:124", + "pkg/sentry/socket/control/control.go:260", + "pkg/sentry/socket/control/control.go:94", + "pkg/sentry/socket/control/control_vfs2.go:37", + "pkg/sentry/socket/hostinet/stack.go:433", + "pkg/sentry/socket/hostinet/stack.go:438", + "pkg/sentry/socket/hostinet/stack.go:444", + "pkg/sentry/socket/hostinet/stack.go:460", + "pkg/sentry/socket/netfilter/tcp_matcher.go:74", + "pkg/sentry/socket/netfilter/udp_matcher.go:71", + "pkg/sentry/socket/netlink/route/protocol.go:38", + "pkg/sentry/socket/socket.go:332", + "pkg/sentry/socket/unix/transport/connectioned.go:394", + "pkg/sentry/socket/unix/transport/connectionless.go:152", + "pkg/sentry/socket/unix/transport/unix.go:436", + "pkg/sentry/socket/unix/transport/unix.go:490", + "pkg/sentry/socket/unix/transport/unix.go:685", + "pkg/sentry/socket/unix/transport/unix.go:795", + "pkg/sentry/syscalls/linux/sys_sem.go:62", + "pkg/sentry/syscalls/linux/sys_time.go:189", + "pkg/sentry/usage/cpu.go:42", + "pkg/sentry/vfs/anonfs.go:302", + "pkg/sentry/vfs/anonfs.go:99", + "pkg/sentry/vfs/dentry.go:214", + "pkg/sentry/vfs/epoll.go:168", + "pkg/sentry/vfs/epoll.go:314", + "pkg/sentry/vfs/file_description.go:549", + "pkg/sentry/vfs/file_description_impl_util.go:304", + "pkg/sentry/vfs/file_description_impl_util.go:412", + "pkg/sentry/vfs/filesystem.go:76", + "pkg/sentry/vfs/lock.go:15", + "pkg/sentry/vfs/lock.go:47", + "pkg/sentry/vfs/memxattr/xattr.go:37", + "pkg/sentry/vfs/mount.go:510", + "pkg/sentry/vfs/mount.go:667", + "pkg/sentry/vfs/mount_test.go:106", + "pkg/sentry/vfs/mount_test.go:160", + "pkg/sentry/vfs/mount_test.go:215", + "pkg/sentry/vfs/mount_unsafe.go:153", + "pkg/sentry/vfs/resolving_path.go:228", + "pkg/sentry/vfs/vfs.go:897", + "pkg/shim/runsc/runsc.go:16", + "pkg/shim/runsc/utils.go:16", + "pkg/shim/v1/proc/deleted_state.go:16", + "pkg/shim/v1/proc/exec.go:16", + "pkg/shim/v1/proc/exec_state.go:16", + "pkg/shim/v1/proc/init.go:16", + "pkg/shim/v1/proc/init_state.go:16", + "pkg/shim/v1/proc/io.go:16", + "pkg/shim/v1/proc/process.go:16", + "pkg/shim/v1/proc/types.go:16", + "pkg/shim/v1/proc/utils.go:16", + "pkg/shim/v1/shim/api.go:16", + "pkg/shim/v1/shim/platform.go:16", + "pkg/shim/v1/shim/service.go:16", + "pkg/shim/v1/utils/annotations.go:15", + "pkg/shim/v1/utils/utils.go:15", + "pkg/shim/v1/utils/volumes.go:15", + "pkg/shim/v2/api.go:16", + "pkg/shim/v2/epoll.go:18", + "pkg/shim/v2/options/options.go:15", + "pkg/shim/v2/options/options.go:24", + "pkg/shim/v2/options/options.go:26", + "pkg/shim/v2/runtimeoptions/runtimeoptions.go:16", + "pkg/shim/v2/runtimeoptions/runtimeoptions_test.go:22", + "pkg/shim/v2/service.go:15", + "pkg/shim/v2/service_linux.go:18", + "pkg/state/tests/integer_test.go:23", + "pkg/state/tests/integer_test.go:28", + "pkg/sync/rwmutex_test.go:105", + "pkg/syserr/host_linux.go:35", + "pkg/tcpip/adapters/gonet/gonet_test.go:144", + "pkg/tcpip/adapters/gonet/gonet_test.go:415", + "pkg/tcpip/adapters/gonet/gonet_test.go:99", + "pkg/tcpip/buffer/view.go:238", + "pkg/tcpip/buffer/view.go:238", + "pkg/tcpip/buffer/view.go:246", + "pkg/tcpip/header/tcp.go:151", + "pkg/tcpip/link/sharedmem/pipe/pipe_test.go:493", + "pkg/tcpip/stack/iptables.go:293", + "pkg/tcpip/stack/iptables_types.go:277", + "pkg/tcpip/stack/stack.go:553", + "pkg/tcpip/stack/transport_test.go:30", + "pkg/tcpip/transport/packet/endpoint.go:126", + "pkg/tcpip/transport/raw/endpoint.go:145", + "pkg/tcpip/transport/tcp/sack_scoreboard.go:167", + "pkg/unet/unet_test.go:634", + "pkg/unet/unet_test.go:662", + "pkg/unet/unet_test.go:703", + "pkg/unet/unet_test.go:98", + "pkg/usermem/addr.go:34", + "pkg/usermem/usermem.go:171", + "pkg/usermem/usermem.go:170", + "runsc/boot/compat.go:22", + "runsc/boot/compat.go:56", + "runsc/boot/loader.go:1115", + "runsc/boot/loader.go:1120", + "runsc/cmd/checkpoint.go:151", + "runsc/config/flags.go:32", + "runsc/container/container.go:641", + "runsc/container/container.go:988", + "runsc/specutils/specutils.go:172", + "runsc/specutils/specutils.go:428", + "runsc/specutils/specutils.go:436", + "runsc/specutils/specutils.go:442", + "runsc/specutils/specutils.go:447", + "runsc/specutils/specutils.go:454", + "test/cmd/test_app/fds.go:171", + "test/iptables/filter_output.go:251", + "test/packetimpact/testbench/connections.go:77", + "tools/bigquery/bigquery.go:106", + "tools/checkescape/test1/test1.go:108", + "tools/checkescape/test1/test1.go:122", + "tools/checkescape/test1/test1.go:137", + "tools/checkescape/test1/test1.go:151", + "tools/checkescape/test1/test1.go:170", + "tools/checkescape/test1/test1.go:39", + "tools/checkescape/test1/test1.go:45", + "tools/checkescape/test1/test1.go:50", + "tools/checkescape/test1/test1.go:64", + "tools/checkescape/test1/test1.go:80", + "tools/checkescape/test1/test1.go:94", + "tools/go_generics/imports.go:51", + "tools/go_generics/imports.go:75", + "tools/go_marshal/gomarshal/generator.go:177", + "tools/go_marshal/gomarshal/generator.go:81", + "tools/go_marshal/gomarshal/generator.go:85", + "tools/go_marshal/test/escape/escape.go:15", + "tools/go_marshal/test/test.go:164", + ), + ) + + // Add all staticcheck analyzers; internal only. + for _, a := range staticcheck.Analyzers { + analyzerConfig[a] = staticMatcher + } + // Add all stylecheck analyzers; internal only. + for _, a := range stylecheck.Analyzers { + analyzerConfig[a] = staticMatcher + } +} + var escapesConfig = map[*analysis.Analyzer]matcher{ // Informational only: include all packages. checkescape.EscapeAnalyzer: alwaysMatches(), diff --git a/tools/nogo/defs.bzl b/tools/nogo/defs.bzl index 480438047..c6fcfd402 100644 --- a/tools/nogo/defs.bzl +++ b/tools/nogo/defs.bzl @@ -2,8 +2,7 @@ load("//tools/bazeldefs:defs.bzl", "go_context", "go_importpath", "go_rule", "go_test_library") -def _nogo_dump_tool_impl(ctx): - # Extract the Go context. +def _nogo_objdump_tool_impl(ctx): go_ctx = go_context(ctx) # Construct the magic dump command. @@ -40,9 +39,9 @@ def _nogo_dump_tool_impl(ctx): executable = dumper, )] -nogo_dump_tool = go_rule( +nogo_objdump_tool = go_rule( rule, - implementation = _nogo_dump_tool_impl, + implementation = _nogo_objdump_tool_impl, ) # NogoStdlibInfo is the set of standard library facts. @@ -55,7 +54,6 @@ NogoStdlibInfo = provider( ) def _nogo_stdlib_impl(ctx): - # Extract the Go context. go_ctx = go_context(ctx) # Build the standard library facts. @@ -72,12 +70,12 @@ def _nogo_stdlib_impl(ctx): ctx.actions.run( inputs = [config_file] + go_ctx.stdlib_srcs, outputs = [facts, findings], - tools = depset(go_ctx.runfiles.to_list() + ctx.files._dump_tool), + tools = depset(go_ctx.runfiles.to_list() + ctx.files._objdump_tool), executable = ctx.files._nogo[0], mnemonic = "GoStandardLibraryAnalysis", progress_message = "Analyzing Go Standard Library", arguments = go_ctx.nogo_args + [ - "-dump_tool=%s" % ctx.files._dump_tool[0].path, + "-objdump_tool=%s" % ctx.files._objdump_tool[0].path, "-stdlib=%s" % config_file.path, "-findings=%s" % findings.path, "-facts=%s" % facts.path, @@ -97,8 +95,8 @@ nogo_stdlib = go_rule( "_nogo": attr.label( default = "//tools/nogo/check:check", ), - "_dump_tool": attr.label( - default = "//tools/nogo:dump_tool", + "_objdump_tool": attr.label( + default = "//tools/nogo:objdump_tool", ), }, ) @@ -121,6 +119,8 @@ NogoInfo = provider( ) def _nogo_aspect_impl(target, ctx): + go_ctx = go_context(ctx) + # If this is a nogo rule itself (and not the shadow of a go_library or # go_binary rule created by such a rule), then we simply return nothing. # All work is done in the shadow properties for go rules. For a proto @@ -135,9 +135,6 @@ def _nogo_aspect_impl(target, ctx): else: return [NogoInfo()] - # Extract the Go context. - go_ctx = go_context(ctx) - # If we're using the "library" attribute, then we need to aggregate the # original library sources and dependencies into this target to perform # proper type analysis. @@ -227,13 +224,13 @@ def _nogo_aspect_impl(target, ctx): ctx.actions.run( inputs = inputs, outputs = [facts, findings, escapes], - tools = depset(go_ctx.runfiles.to_list() + ctx.files._dump_tool), + tools = depset(go_ctx.runfiles.to_list() + ctx.files._objdump_tool), executable = ctx.files._nogo[0], mnemonic = "GoStaticAnalysis", progress_message = "Analyzing %s" % target.label, arguments = go_ctx.nogo_args + [ "-binary=%s" % target_objfile.path, - "-dump_tool=%s" % ctx.files._dump_tool[0].path, + "-objdump_tool=%s" % ctx.files._objdump_tool[0].path, "-package=%s" % config_file.path, "-findings=%s" % findings.path, "-facts=%s" % facts.path, @@ -271,58 +268,48 @@ nogo_aspect = go_rule( attrs = { "_nogo": attr.label(default = "//tools/nogo/check:check"), "_nogo_stdlib": attr.label(default = "//tools/nogo:stdlib"), - "_dump_tool": attr.label(default = "//tools/nogo:dump_tool"), + "_objdump_tool": attr.label(default = "//tools/nogo:objdump_tool"), }, ) def _nogo_test_impl(ctx): """Check nogo findings.""" - # Build a runner that checks for the existence of the facts file. Note that - # the actual build will fail in the case of a broken analysis. We things - # this way so that any test applied is effectively pushed down to all - # upstream dependencies through the aspect. - inputs = [] - findings = [] - runner = ctx.actions.declare_file("%s-executer" % ctx.label.name) - runner_content = ["#!/bin/bash"] - for dep in ctx.attr.deps: - # Extract the findings. - info = dep[NogoInfo] - inputs.append(info.findings) - findings.append(info.findings) - - # Include all source files, transitively. This will make this target - # "directly affected" for the purpose of build analysis. - inputs += info.srcs - - # If there are findings, dump them and fail. - runner_content.append("if [[ -s \"%s\" ]]; then cat \"%s\" && exit 1; fi" % ( - info.findings.short_path, - info.findings.short_path, - )) - - # Otherwise, draw a sweet unicode checkmark with the package name (in green). - runner_content.append("echo -e \"\\033[0;32m\\xE2\\x9C\\x94\\033[0;31m\\033[0m %s\"" % info.importpath) - runner_content.append("exit 0\n") - ctx.actions.write(runner, "\n".join(runner_content), is_executable = True) + # Build a runner that checks the facts files. + findings = [dep[NogoInfo].findings for dep in ctx.attr.deps] + runner = ctx.actions.declare_file(ctx.label.name) + ctx.actions.run( + inputs = findings + ctx.files.srcs, + outputs = [runner], + tools = depset(ctx.files._gentest), + executable = ctx.files._gentest[0], + mnemonic = "Gentest", + progress_message = "Generating %s" % ctx.label, + arguments = [runner.path] + [f.path for f in findings], + ) return [DefaultInfo( - runfiles = ctx.runfiles(files = inputs), executable = runner, )] _nogo_test = rule( implementation = _nogo_test_impl, attrs = { + # deps should have only a single element. "deps": attr.label_list(aspects = [nogo_aspect]), + # srcs exist here only to ensure that this target is + # directly affected by changes to the source files. + "srcs": attr.label_list(allow_files = True), + "_gentest": attr.label(default = "//tools/nogo:gentest"), }, test = True, ) -def nogo_test(name, **kwargs): +def nogo_test(name, srcs, library, **kwargs): tags = kwargs.pop("tags", []) + ["nogo"] _nogo_test( name = name, + srcs = srcs, + deps = [library], tags = tags, **kwargs ) diff --git a/tools/nogo/gentest.sh b/tools/nogo/gentest.sh new file mode 100755 index 000000000..033da11ad --- /dev/null +++ b/tools/nogo/gentest.sh @@ -0,0 +1,47 @@ +#!/bin/bash +# Copyright 2019 The gVisor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set -euo pipefail + +if [[ "$#" -lt 2 ]]; then + echo "usage: $0 <output> <findings...>" + exit 2 +fi +declare violations=0 +declare output=$1 +shift + +# Start the script. +echo "#!/bin/sh" > "${output}" + +# Read a list of findings files. +declare filename +declare line +for filename in "$@"; do + if [[ -z "${filename}" ]]; then + continue + fi + while read -r line; do + violations=$((${violations}+1)); + echo "echo -e '\\033[0;31m${line}\\033[0;31m\\033[0m'" >> "${output}" + done < "${filename}" +done + +# Show violations. +if [[ "${violations}" -eq 0 ]]; then + echo "echo -e '\\033[0;32mPASS\\033[0;31m\\033[0m'" >> "${output}" +else + echo "exit 1" >> "${output}" +fi diff --git a/tools/nogo/matchers.go b/tools/nogo/matchers.go index 5c39be630..b7b73fa27 100644 --- a/tools/nogo/matchers.go +++ b/tools/nogo/matchers.go @@ -102,6 +102,14 @@ func internalMatches() *pathRegexps { } } +// generatedExcluded excludes all generated code. +func generatedExcluded() *pathRegexps { + return &pathRegexps{ + expr: buildRegexps(generatedPrefix, ".*"), + include: false, + } +} + // resultExcluded excludes explicit message contents. type resultExcluded []string @@ -117,20 +125,23 @@ func (r resultExcluded) ShouldReport(d analysis.Diagnostic, _ *token.FileSet) bo // andMatcher is a composite matcher. type andMatcher struct { - first matcher - second matcher + all []matcher } // ShouldReport implements matcher.ShouldReport. func (a *andMatcher) ShouldReport(d analysis.Diagnostic, fs *token.FileSet) bool { - return a.first.ShouldReport(d, fs) && a.second.ShouldReport(d, fs) + for _, m := range a.all { + if !m.ShouldReport(d, fs) { + return false + } + } + return true } // and is a syntactic convension for andMatcher. -func and(first matcher, second matcher) *andMatcher { +func and(ms ...matcher) *andMatcher { return &andMatcher{ - first: first, - second: second, + all: ms, } } diff --git a/tools/parsers/BUILD b/tools/parsers/BUILD new file mode 100644 index 000000000..7d9c9a3fb --- /dev/null +++ b/tools/parsers/BUILD @@ -0,0 +1,27 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_test( + name = "parsers_test", + size = "small", + srcs = ["go_parser_test.go"], + library = ":parsers", + deps = [ + "//tools/bigquery", + "@com_github_google_go_cmp//cmp:go_default_library", + ], +) + +go_library( + name = "parsers", + testonly = 1, + srcs = [ + "go_parser.go", + ], + visibility = ["//:sandbox"], + deps = [ + "//test/benchmarks/tools", + "//tools/bigquery", + ], +) diff --git a/tools/parsers/go_parser.go b/tools/parsers/go_parser.go new file mode 100644 index 000000000..2cf74c883 --- /dev/null +++ b/tools/parsers/go_parser.go @@ -0,0 +1,151 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package parsers holds parsers to parse Benchmark test output. +// +// Parsers parse Benchmark test output and place it in BigQuery +// structs for sending to BigQuery databases. +package parsers + +import ( + "fmt" + "strconv" + "strings" + + "gvisor.dev/gvisor/test/benchmarks/tools" + "gvisor.dev/gvisor/tools/bigquery" +) + +// parseOutput expects golang benchmark output returns a Benchmark struct formatted for BigQuery. +func parseOutput(output string, metadata *bigquery.Metadata, official bool) ([]*bigquery.Benchmark, error) { + var benchmarks []*bigquery.Benchmark + lines := strings.Split(output, "\n") + for _, line := range lines { + bm, err := parseLine(line, metadata, official) + if err != nil { + return nil, fmt.Errorf("failed to parse line '%s': %v", line, err) + } + if bm != nil { + benchmarks = append(benchmarks, bm) + } + } + return benchmarks, nil +} + +// parseLine handles parsing a benchmark line into a bigquery.Benchmark. +// +// Example: "BenchmarkRuby/server_threads.1-6 1 1397875880 ns/op 140 requests_per_second.QPS" +// +// This function will return the following benchmark: +// *bigquery.Benchmark{ +// Name: BenchmarkRuby +// []*bigquery.Condition{ +// {Name: GOMAXPROCS, 6} +// {Name: server_threads, 1} +// } +// []*bigquery.Metric{ +// {Name: ns/op, Unit: ns/op, Sample: 1397875880} +// {Name: requests_per_second, Unit: QPS, Sample: 140 } +// } +// Metadata: metadata +//} +func parseLine(line string, metadata *bigquery.Metadata, official bool) (*bigquery.Benchmark, error) { + fields := strings.Fields(line) + + // Check if this line is a Benchmark line. Otherwise ignore the line. + if len(fields) < 2 || !strings.HasPrefix(fields[0], "Benchmark") { + return nil, nil + } + + iters, err := strconv.Atoi(fields[1]) + if err != nil { + return nil, fmt.Errorf("expecting number of runs, got %s: %v", fields[1], err) + } + + name, params, err := parseNameParams(fields[0]) + if err != nil { + return nil, fmt.Errorf("parse name/params: %v", err) + } + + bm := bigquery.NewBenchmark(name, iters, official) + bm.Metadata = metadata + for _, p := range params { + bm.AddCondition(p.Name, p.Value) + } + + for i := 1; i < len(fields)/2; i++ { + value := fields[2*i] + metric := fields[2*i+1] + if err := makeMetric(bm, value, metric); err != nil { + return nil, fmt.Errorf("makeMetric on metric %q value: %s: %v", metric, value, err) + } + } + return bm, nil +} + +// parseNameParams parses the Name, GOMAXPROCS, and Params from the test. +// Field here should be of the format TESTNAME/PARAMS-GOMAXPROCS. +// Parameters will be separated by a "/" with individual params being +// "name.value". +func parseNameParams(field string) (string, []*tools.Parameter, error) { + var params []*tools.Parameter + // Remove GOMAXPROCS from end. + maxIndex := strings.LastIndex(field, "-") + if maxIndex < 0 { + return "", nil, fmt.Errorf("GOMAXPROCS not found: %s", field) + } + maxProcs := field[maxIndex+1:] + params = append(params, &tools.Parameter{ + Name: "GOMAXPROCS", + Value: maxProcs, + }) + + remainder := field[0:maxIndex] + index := strings.Index(remainder, "/") + if index == -1 { + return remainder, params, nil + } + + name := remainder[0:index] + p := remainder[index+1:] + + ps, err := tools.NameToParameters(p) + if err != nil { + return "", nil, fmt.Errorf("NameToParameters %s: %v", field, err) + } + params = append(params, ps...) + return name, params, nil +} + +// makeMetric parses metrics and adds them to the passed Benchmark. +func makeMetric(bm *bigquery.Benchmark, value, metric string) error { + switch metric { + // Ignore most output from golang benchmarks. + case "MB/s", "B/op", "allocs/op": + return nil + case "ns/op": + val, err := strconv.ParseFloat(value, 64) + if err != nil { + return fmt.Errorf("ParseFloat %s: %v", value, err) + } + bm.AddMetric(metric /*metric name*/, metric /*unit*/, val /*sample*/) + default: + m, err := tools.ParseCustomMetric(value, metric) + if err != nil { + return fmt.Errorf("ParseCustomMetric %s: %v ", metric, err) + } + bm.AddMetric(m.Name, m.Unit, m.Sample) + } + return nil +} diff --git a/tools/parsers/go_parser_test.go b/tools/parsers/go_parser_test.go new file mode 100644 index 000000000..36996b7c8 --- /dev/null +++ b/tools/parsers/go_parser_test.go @@ -0,0 +1,171 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package parsers + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/tools/bigquery" +) + +func TestParseLine(t *testing.T) { + testCases := []struct { + name string + data string + want *bigquery.Benchmark + }{ + { + name: "Iperf", + data: "BenchmarkIperf/Upload-6 1 11094914892 ns/op 4751711232 bandwidth.bytes_per_second", + want: &bigquery.Benchmark{ + Name: "BenchmarkIperf", + Condition: []*bigquery.Condition{ + { + Name: "GOMAXPROCS", + Value: "6", + }, + { + Name: "Upload", + Value: "Upload", + }, + }, + Metric: []*bigquery.Metric{ + { + Name: "ns/op", + Unit: "ns/op", + Sample: 11094914892.0, + }, + { + Name: "bandwidth", + Unit: "bytes_per_second", + Sample: 4751711232.0, + }, + }, + }, + }, + { + name: "Ruby", + data: "BenchmarkRuby/server_threads.1-6 1 1397875880 ns/op 0.00710 average_latency.s 140 requests_per_second.QPS", + want: &bigquery.Benchmark{ + Name: "BenchmarkRuby", + Condition: []*bigquery.Condition{ + { + Name: "GOMAXPROCS", + Value: "6", + }, + { + Name: "server_threads", + Value: "1", + }, + }, + Metric: []*bigquery.Metric{ + { + Name: "ns/op", + Unit: "ns/op", + Sample: 1397875880.0, + }, + { + Name: "average_latency", + Unit: "s", + Sample: 0.00710, + }, + { + Name: "requests_per_second", + Unit: "QPS", + Sample: 140.0, + }, + }, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + got, err := parseLine(tc.data, nil, false) + if err != nil { + t.Fatalf("parseLine failed with: %v", err) + } + + tc.want.Timestamp = got.Timestamp + + if !cmp.Equal(tc.want, got, nil) { + for _, c := range got.Condition { + t.Logf("Cond: %+v", c) + } + for _, m := range got.Metric { + t.Logf("Metric: %+v", m) + } + t.Fatalf("Compare failed want: %+v got: %+v", tc.want, got) + } + }) + + } +} + +func TestParseOutput(t *testing.T) { + testCases := []struct { + name string + data string + numBenchmarks int + numMetrics int + numConditions int + }{ + { + name: "Startup", + data: ` + BenchmarkStartupEmpty + BenchmarkStartupEmpty-6 2 766377884 ns/op 1 allocs/op + BenchmarkStartupNode + BenchmarkStartupNode-6 1 1752158409 ns/op 1 allocs/op + `, + numBenchmarks: 2, + numMetrics: 1, + numConditions: 1, + }, + { + name: "Ruby", + data: `BenchmarkRuby +BenchmarkRuby/server_threads.1 +BenchmarkRuby/server_threads.1-6 1 1397875880 ns/op 0.00710 average_latency.s 140 requests_per_second.QPS +BenchmarkRuby/server_threads.5 +BenchmarkRuby/server_threads.5-6 1 1416003331 ns/op 0.00950 average_latency.s 465 requests_per_second.QPS`, + numBenchmarks: 2, + numMetrics: 3, + numConditions: 2, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + bms, err := parseOutput(tc.data, nil, false) + if err != nil { + t.Fatalf("parseOutput failed: %v", err) + } else if len(bms) != tc.numBenchmarks { + t.Fatalf("NumBenchmarks failed want: %d got: %d %+v", tc.numBenchmarks, len(bms), bms) + } + + for _, bm := range bms { + if len(bm.Metric) != tc.numMetrics { + t.Fatalf("NumMetrics failed want: %d got: %d %+v", tc.numMetrics, len(bm.Metric), bm.Metric) + } + + if len(bm.Condition) != tc.numConditions { + t.Fatalf("NumConditions failed want: %d got: %d %+v", tc.numConditions, len(bm.Condition), bm.Condition) + } + } + }) + } +} |