diff options
145 files changed, 4189 insertions, 1521 deletions
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 71650a4b8..ad8e710da 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -32,6 +32,9 @@ will need to be added to the appropriate `BUILD` files, and the `:gopath` target will need to be re-run to generate appropriate symlinks in the `GOPATH` directory tree. +Dependencies can be added by using `go mod get`. In order to keep the +`WORKSPACE` file in sync, run `tools/go_mod.sh` in place of `go mod`. + ### Coding Guidelines All Go code should conform to the [Go style guidelines][gostyle]. C++ code @@ -4,10 +4,10 @@ load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository") # Load go bazel rules and gazelle. http_archive( name = "io_bazel_rules_go", - sha256 = "f99a9d76e972e0c8f935b2fe6d0d9d778f67c760c6d2400e23fc2e469016e2bd", + sha256 = "94f90feaa65c9cdc840cd21f67d967870b5943d684966a47569da8073e42063d", urls = [ - "https://storage.googleapis.com/bazel-mirror/github.com/bazelbuild/rules_go/releases/download/v0.21.2/rules_go-v0.21.2.tar.gz", - "https://github.com/bazelbuild/rules_go/releases/download/v0.21.2/rules_go-v0.21.2.tar.gz", + "https://mirror.bazel.build/github.com/bazelbuild/rules_go/releases/download/v0.22.0/rules_go-v0.22.0.tar.gz", + "https://github.com/bazelbuild/rules_go/releases/download/v0.22.0/rules_go-v0.22.0.tar.gz", ], ) @@ -20,12 +20,12 @@ http_archive( ], ) -load("@io_bazel_rules_go//go:deps.bzl", "go_rules_dependencies", "go_register_toolchains") +load("@io_bazel_rules_go//go:deps.bzl", "go_register_toolchains", "go_rules_dependencies") go_rules_dependencies() go_register_toolchains( - go_version = "1.13.7", + go_version = "1.14", nogo = "@//:nogo", ) @@ -43,8 +43,8 @@ gazelle_dependencies() go_repository( name = "org_golang_x_sys", importpath = "golang.org/x/sys", - sum = "h1:72l8qCJ1nGxMGH26QVBVIxKd/D34cfGt0OvrPtpemyY=", - version = "v0.0.0-20191220220014-0732a990476f", + sum = "h1:uYVVQ9WP/Ds2ROhcaGPeIdVq0RIXVLwsHlnvJ+cT1So=", + version = "v0.0.0-20200302150141-5c8b2ff67527", ) # Load C++ rules. @@ -68,16 +68,19 @@ http_archive( "https://github.com/bazelbuild/rules_proto/archive/97d8af4dc474595af3900dd85cb3a29ad28cc313.tar.gz", ], ) + load("@rules_proto//proto:repositories.bzl", "rules_proto_dependencies", "rules_proto_toolchains") + rules_proto_dependencies() + rules_proto_toolchains() # Load python dependencies. git_repository( name = "rules_python", - commit = "94677401bc56ed5d756f50b441a6a5c7f735a6d4", + commit = "abc4869e02fe9b3866942e89f07b7341f830e805", remote = "https://github.com/bazelbuild/rules_python.git", - shallow_since = "1573842889 -0500", + shallow_since = "1583341286 -0500", ) load("@rules_python//python:pip.bzl", "pip_import") @@ -96,11 +99,11 @@ pip_install() # See releases at https://releases.bazel.build/bazel-toolchains.html http_archive( name = "bazel_toolchains", - sha256 = "a653c9d318e42b14c0ccd7ac50c4a2a276c0db1e39743ab88b5aa2f0bc9cf607", - strip_prefix = "bazel-toolchains-2.0.2", + sha256 = "b5a8039df7119d618402472f3adff8a1bd0ae9d5e253f53fcc4c47122e91a3d2", + strip_prefix = "bazel-toolchains-2.1.1", urls = [ - "https://github.com/bazelbuild/bazel-toolchains/releases/download/2.0.2/bazel-toolchains-2.0.2.tar.gz", - "https://mirror.bazel.build/github.com/bazelbuild/bazel-toolchains/archive/2.0.2.tar.gz", + "https://github.com/bazelbuild/bazel-toolchains/releases/download/2.1.1/bazel-toolchains-2.1.1.tar.gz", + "https://mirror.bazel.build/github.com/bazelbuild/bazel-toolchains/archive/2.1.1.tar.gz", ], ) @@ -146,9 +149,9 @@ load( # This container is built from the Dockerfile in test/iptables/runner. container_pull( name = "iptables-test", + digest = "sha256:a137d692a2eb9fc7bf95c5f4a568da090e2c31098e93634421ed88f3a3f1db65", registry = "gcr.io", repository = "gvisor-presubmit/iptables-test", - digest = "sha256:a137d692a2eb9fc7bf95c5f4a568da090e2c31098e93634421ed88f3a3f1db65", ) load( @@ -202,6 +205,13 @@ go_repository( ) go_repository( + name = "com_github_kr_pretty", + importpath = "github.com/kr/pretty", + sum = "h1:s5hAObm+yFO5uHYt5dYjxi2rXrsnmRpJx4OYvIWUaQs=", + version = "v0.2.0", +) + +go_repository( name = "com_github_kr_pty", importpath = "github.com/kr/pty", sum = "h1:VkoXIwSboBpnk99O/KFauAEILuNHv5DVFKZMBN/gUgw=", @@ -209,6 +219,13 @@ go_repository( ) go_repository( + name = "com_github_kr_text", + importpath = "github.com/kr/text", + sum = "h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=", + version = "v0.1.0", +) + +go_repository( name = "com_github_opencontainers_runtime-spec", importpath = "github.com/opencontainers/runtime-spec", sum = "h1:d9F+LNYwMyi3BDN4GzZdaSiq4otb8duVEWyZjeUtOQI=", @@ -237,30 +254,31 @@ go_repository( ) go_repository( - name = "org_golang_x_crypto", - importpath = "golang.org/x/crypto", - sum = "h1:VklqNMn3ovrHsnt90PveolxSbWFaJdECFbxSq0Mqo2M=", - version = "v0.0.0-20190308221718-c2843e01d9a2", + name = "in_gopkg_check_v1", + importpath = "gopkg.in/check.v1", + sum = "h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=", + version = "v1.0.0-20190902080502-41f04d3bba15", ) go_repository( - name = "org_golang_x_net", - importpath = "golang.org/x/net", - sum = "h1:oWX7TPOiFAMXLq8o0ikBYfCJVlRHBcsciT5bXOrH628=", - version = "v0.0.0-20190311183353-d8887717615a", + name = "org_golang_x_crypto", + importpath = "golang.org/x/crypto", + sum = "h1:ObdrDkeb4kJdCP557AjRjq69pTHfNouLtWZG7j9rPN8=", + version = "v0.0.0-20191011191535-87dc89f01550", ) go_repository( - name = "org_golang_x_text", - importpath = "golang.org/x/text", - sum = "h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg=", - version = "v0.3.0", + name = "org_golang_x_mod", + importpath = "golang.org/x/mod", + sum = "h1:p1YOIz9H/mGN8k1XkaV5VFAq9+zhN9Obefv439UwRhI=", + version = "v0.2.1-0.20200224194123-e5e73c1b9c72", ) go_repository( - name = "org_golang_x_tools", - commit = "36563e24a262", - importpath = "golang.org/x/tools", + name = "org_golang_x_net", + importpath = "golang.org/x/net", + sum = "h1:R/3boaszxrf1GEUWTVDzSKVwLmSJpwZ1yqXm8j0v2QI=", + version = "v0.0.0-20190620200207-3b0461eec859", ) go_repository( @@ -271,15 +289,31 @@ go_repository( ) go_repository( + name = "org_golang_x_text", + importpath = "golang.org/x/text", + sum = "h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg=", + version = "v0.3.0", +) + +go_repository( name = "org_golang_x_time", - commit = "c4c64cad1fd0a1a8dab2523e04e61d35308e131e", importpath = "golang.org/x/time", + sum = "h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs=", + version = "v0.0.0-20191024005414-555d28b269f0", ) go_repository( name = "org_golang_x_tools", - commit = "aa82965741a9fecd12b026fbb3d3c6ed3231b8f8", importpath = "golang.org/x/tools", + sum = "h1:aZzprAO9/8oim3qStq3wc1Xuxx4QmAGriC4VU4ojemQ=", + version = "v0.0.0-20191119224855-298f0cb1881e", +) + +go_repository( + name = "org_golang_x_xerrors", + importpath = "golang.org/x/xerrors", + sum = "h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=", + version = "v0.0.0-20191204190536-9bdfabe68543", ) go_repository( diff --git a/benchmarks/workloads/ruby/Gemfile.lock b/benchmarks/workloads/ruby/Gemfile.lock index b44817bd3..17ebcbec3 100644 --- a/benchmarks/workloads/ruby/Gemfile.lock +++ b/benchmarks/workloads/ruby/Gemfile.lock @@ -1,28 +1,41 @@ GEM remote: https://rubygems.org/ specs: + activemerchant (1.105.0) + activesupport (>= 4.2) + builder (>= 2.1.2, < 4.0.0) + i18n (>= 0.6.9) + nokogiri (~> 1.4) activesupport (5.2.3) concurrent-ruby (~> 1.0, >= 1.0.2) i18n (>= 0.7, < 2) minitest (~> 5.1) tzinfo (~> 1.1) + bcrypt (3.1.13) + builder (3.2.4) cassandra-driver (3.2.3) ione (~> 1.2) concurrent-ruby (1.1.5) + ffi (1.12.2) i18n (1.6.0) concurrent-ruby (~> 1.0) ione (1.2.4) + mini_portile2 (2.4.0) minitest (5.11.3) mustermann (1.0.3) + nokogiri (1.10.8) + mini_portile2 (~> 2.4.0) pdf-core (0.7.0) prawn (2.2.2) pdf-core (~> 0.7.0) ttfunk (~> 1.5) - puma (3.12.1) - rack (2.0.7) + puma (3.12.2) + rack (2.2.2) rack-protection (2.0.5) rack - rake (12.3.2) + rake (12.3.3) + rbnacl (7.1.1) + ffi redis (4.1.1) ruby-fann (1.2.6) sinatra (2.0.5) @@ -43,9 +56,12 @@ PLATFORMS ruby DEPENDENCIES + activemerchant + bcrypt cassandra-driver puma rake + rbnacl redis ruby-fann sinatra diff --git a/benchmarks/workloads/ruby_template/Gemfile.lock b/benchmarks/workloads/ruby_template/Gemfile.lock index dd8d56fb7..af03ed7fd 100644 --- a/benchmarks/workloads/ruby_template/Gemfile.lock +++ b/benchmarks/workloads/ruby_template/Gemfile.lock @@ -2,25 +2,25 @@ GEM remote: https://rubygems.org/ specs: mustermann (1.0.3) - puma (3.12.0) + puma (3.12.2) rack (2.0.6) rack-protection (2.0.5) rack + redis (4.1.0) sinatra (2.0.5) mustermann (~> 1.0) rack (~> 2.0) rack-protection (= 2.0.5) tilt (~> 2.0) tilt (2.0.9) - redis (4.1.0) PLATFORMS ruby DEPENDENCIES puma - sinatra redis + sinatra BUNDLED WITH 1.17.1
\ No newline at end of file @@ -1,23 +1,20 @@ module gvisor.dev/gvisor -go 1.13 +go 1.14 require ( - github.com/cenkalti/backoff v0.0.0-20190506075156-2146c9339422 - github.com/gofrs/flock v0.6.1-0.20180915234121-886344bea079 - github.com/golang/mock v1.3.1 - github.com/golang/protobuf v1.3.1 - github.com/google/btree v1.0.0 - github.com/google/go-cmp v0.2.0 - github.com/google/go-github/v28 v28.1.1 - github.com/google/subcommands v0.0.0-20190508160503-636abe8753b8 - github.com/google/uuid v0.0.0-20171129191014-dec09d789f3d - github.com/kr/pty v1.1.1 - github.com/opencontainers/runtime-spec v0.1.2-0.20171211145439-b2d941ef6a78 - github.com/syndtr/gocapability v0.0.0-20180916011248-d98352740cb2 - github.com/vishvananda/netlink v1.0.1-0.20190318003149-adb577d4a45e - github.com/vishvananda/netns v0.0.0-20171111001504-be1fbeda1936 - golang.org/x/net v0.0.0-20190311183353-d8887717615a - golang.org/x/oauth2 v0.0.0-20191202225959-858c2ad4c8b6 - golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a + github.com/cenkalti/backoff v0.0.0-20190506075156-2146c9339422 + github.com/gofrs/flock v0.6.1-0.20180915234121-886344bea079 + github.com/golang/protobuf v1.3.1 + github.com/google/btree v1.0.0 + github.com/google/subcommands v0.0.0-20190508160503-636abe8753b8 + github.com/kr/pretty v0.2.0 // indirect + github.com/kr/pty v1.1.1 + github.com/opencontainers/runtime-spec v0.1.2-0.20171211145439-b2d941ef6a78 + github.com/syndtr/gocapability v0.0.0-20180916011248-d98352740cb2 + github.com/vishvananda/netlink v1.0.1-0.20190318003149-adb577d4a45e + github.com/vishvananda/netns v0.0.0-20171111001504-be1fbeda1936 // indirect + golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527 + golang.org/x/time v0.0.0-20191024005414-555d28b269f0 + gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect ) @@ -1,21 +1,32 @@ +github.com/cenkalti/backoff v0.0.0-20190506075156-2146c9339422 h1:+FKjzBIdfBHYDvxCv+djmDJdes/AoDtg8gpcxowBlF8= github.com/cenkalti/backoff v0.0.0-20190506075156-2146c9339422/go.mod h1:b6Nc7NRH5C4aCISLry0tLnTjcuTEvoiqcWDdsU0sOGM= +github.com/gofrs/flock v0.6.1-0.20180915234121-886344bea079 h1:JFTFz3HZTGmgMz4E1TabNBNJljROSYgja1b4l50FNVs= github.com/gofrs/flock v0.6.1-0.20180915234121-886344bea079/go.mod h1:F1TvTiK9OcQqauNUHlbJvyl9Qa1QvF/gOUDKA14jxHU= -github.com/golang/mock v1.3.1/go.mod h1:sBzyDLLjw3U8JLTeZvSv8jJB+tU5PVekmnlKIyFUx0Y= +github.com/golang/protobuf v1.3.1 h1:YF8+flBXS5eO826T4nzqPrxfhQThhXl0YzfuUPu4SBg= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/google/btree v1.0.0 h1:0udJVsspx3VBr5FwtLhQQtuAsVc79tTq0ocGIPAU6qo= github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= -github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= -github.com/google/go-github/v28 v28.1.1/go.mod h1:bsqJWQX05omyWVmc00nEUql9mhQyv38lDZ8kPZcQVoM= +github.com/google/subcommands v0.0.0-20190508160503-636abe8753b8 h1:GZGUPQiZfYrd9uOqyqwbQcHPkz/EZJVkZB1MkaO9UBI= github.com/google/subcommands v0.0.0-20190508160503-636abe8753b8/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk= -github.com/google/uuid v0.0.0-20171129191014-dec09d789f3d/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/kr/pretty v0.2.0 h1:s5hAObm+yFO5uHYt5dYjxi2rXrsnmRpJx4OYvIWUaQs= +github.com/kr/pretty v0.2.0/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/pty v1.1.1 h1:VkoXIwSboBpnk99O/KFauAEILuNHv5DVFKZMBN/gUgw= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/opencontainers/runtime-spec v0.1.2-0.20171211145439-b2d941ef6a78 h1:d9F+LNYwMyi3BDN4GzZdaSiq4otb8duVEWyZjeUtOQI= github.com/opencontainers/runtime-spec v0.1.2-0.20171211145439-b2d941ef6a78/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/C5hL+LL3ko9hs6T5lQ0= +github.com/syndtr/gocapability v0.0.0-20180916011248-d98352740cb2 h1:b6uOv7YOFK0TYG7HtkIgExQo+2RdLuwRft63jn2HWj8= github.com/syndtr/gocapability v0.0.0-20180916011248-d98352740cb2/go.mod h1:hkRG7XYTFWNJGYcbNJQlaLq0fg1yr4J4t/NcTQtrfww= +github.com/vishvananda/netlink v1.0.1-0.20190318003149-adb577d4a45e h1:/Tdc23Arz1OtdIsBY2utWepGRQ9fEAJlhkdoLzWMK8Q= github.com/vishvananda/netlink v1.0.1-0.20190318003149-adb577d4a45e/go.mod h1:+SR5DhBJrl6ZM7CoCKvpw5BKroDKQ+PJqOg65H/2ktk= +github.com/vishvananda/netns v0.0.0-20171111001504-be1fbeda1936 h1:J9gO8RJCAFlln1jsvRba/CWVUnMHwObklfxxjErl1uk= github.com/vishvananda/netns v0.0.0-20171111001504-be1fbeda1936/go.mod h1:ZjcWmFBXmLKZu9Nxj3WKYEafiSqer2rnvPr0en9UNpI= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/oauth2 v0.0.0-20191202225959-858c2ad4c8b6/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= -golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/sys v0.0.0-20190412213103-97732733099d h1:+R4KGOnez64A81RvjARKc4UT5/tI9ujCIVX+P5KiHuI= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527 h1:uYVVQ9WP/Ds2ROhcaGPeIdVq0RIXVLwsHlnvJ+cT1So= +golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs= +golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/pkg/buffer/BUILD b/pkg/buffer/BUILD new file mode 100644 index 000000000..a77a3beea --- /dev/null +++ b/pkg/buffer/BUILD @@ -0,0 +1,39 @@ +load("//tools:defs.bzl", "go_library", "go_test") +load("//tools/go_generics:defs.bzl", "go_template_instance") + +package(licenses = ["notice"]) + +go_template_instance( + name = "buffer_list", + out = "buffer_list.go", + package = "buffer", + prefix = "buffer", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*Buffer", + "Linker": "*Buffer", + }, +) + +go_library( + name = "buffer", + srcs = [ + "buffer.go", + "buffer_list.go", + "safemem.go", + "view.go", + "view_unsafe.go", + ], + visibility = ["//visibility:public"], + deps = [ + "//pkg/log", + "//pkg/safemem", + ], +) + +go_test( + name = "buffer_test", + size = "small", + srcs = ["view_test.go"], + library = ":buffer", +) diff --git a/pkg/buffer/buffer.go b/pkg/buffer/buffer.go new file mode 100644 index 000000000..d5f64609b --- /dev/null +++ b/pkg/buffer/buffer.go @@ -0,0 +1,67 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package buffer provides the implementation of a buffer view. +package buffer + +import ( + "sync" +) + +const bufferSize = 8144 // See below. + +// Buffer encapsulates a queueable byte buffer. +// +// Note that the total size is slightly less than two pages. This is done +// intentionally to ensure that the buffer object aligns with runtime +// internals. We have no hard size or alignment requirements. This two page +// size will effectively minimize internal fragmentation, but still have a +// large enough chunk to limit excessive segmentation. +// +// +stateify savable +type Buffer struct { + data [bufferSize]byte + read int + write int + bufferEntry +} + +// Reset resets internal data. +// +// This must be called before use. +func (b *Buffer) Reset() { + b.read = 0 + b.write = 0 +} + +// Empty indicates the buffer is empty. +// +// This indicates there is no data left to read. +func (b *Buffer) Empty() bool { + return b.read == b.write +} + +// Full indicates the buffer is full. +// +// This indicates there is no capacity left to write. +func (b *Buffer) Full() bool { + return b.write == len(b.data) +} + +// bufferPool is a pool for buffers. +var bufferPool = sync.Pool{ + New: func() interface{} { + return new(Buffer) + }, +} diff --git a/pkg/buffer/safemem.go b/pkg/buffer/safemem.go new file mode 100644 index 000000000..071aaa488 --- /dev/null +++ b/pkg/buffer/safemem.go @@ -0,0 +1,131 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package buffer + +import ( + "io" + + "gvisor.dev/gvisor/pkg/safemem" +) + +// WriteBlock returns this buffer as a write Block. +func (b *Buffer) WriteBlock() safemem.Block { + return safemem.BlockFromSafeSlice(b.data[b.write:]) +} + +// ReadBlock returns this buffer as a read Block. +func (b *Buffer) ReadBlock() safemem.Block { + return safemem.BlockFromSafeSlice(b.data[b.read:b.write]) +} + +// WriteFromBlocks implements safemem.Writer.WriteFromBlocks. +// +// This will advance the write index. +func (v *View) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) { + need := int(srcs.NumBytes()) + if need == 0 { + return 0, nil + } + + var ( + dst safemem.BlockSeq + blocks []safemem.Block + ) + + // Need at least one buffer. + firstBuf := v.data.Back() + if firstBuf == nil { + firstBuf = bufferPool.Get().(*Buffer) + v.data.PushBack(firstBuf) + } + + // Does the last block have sufficient capacity alone? + if l := len(firstBuf.data) - firstBuf.write; l >= need { + dst = safemem.BlockSeqOf(firstBuf.WriteBlock()) + } else { + // Append blocks until sufficient. + need -= l + blocks = append(blocks, firstBuf.WriteBlock()) + for need > 0 { + emptyBuf := bufferPool.Get().(*Buffer) + v.data.PushBack(emptyBuf) + need -= len(emptyBuf.data) // Full block. + blocks = append(blocks, emptyBuf.WriteBlock()) + } + dst = safemem.BlockSeqFromSlice(blocks) + } + + // Perform the copy. + n, err := safemem.CopySeq(dst, srcs) + v.size += int64(n) + + // Update all indices. + for left := int(n); left > 0; firstBuf = firstBuf.Next() { + if l := len(firstBuf.data) - firstBuf.write; left >= l { + firstBuf.write += l // Whole block. + left -= l + } else { + firstBuf.write += left // Partial block. + left = 0 + } + } + + return n, err +} + +// ReadToBlocks implements safemem.Reader.ReadToBlocks. +// +// This will not advance the read index; the caller should follow +// this call with a call to TrimFront in order to remove the read +// data from the buffer. This is done to support pipe sematics. +func (v *View) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) { + need := int(dsts.NumBytes()) + if need == 0 { + return 0, nil + } + + var ( + src safemem.BlockSeq + blocks []safemem.Block + ) + + firstBuf := v.data.Front() + if firstBuf == nil { + return 0, io.EOF + } + + // Is all the data in a single block? + if l := firstBuf.write - firstBuf.read; l >= need { + src = safemem.BlockSeqOf(firstBuf.ReadBlock()) + } else { + // Build a list of all the buffers. + need -= l + blocks = append(blocks, firstBuf.ReadBlock()) + for buf := firstBuf.Next(); buf != nil && need > 0; buf = buf.Next() { + need -= buf.write - buf.read + blocks = append(blocks, buf.ReadBlock()) + } + src = safemem.BlockSeqFromSlice(blocks) + } + + // Perform the copy. + n, err := safemem.CopySeq(dsts, src) + + // See above: we would normally advance the read index here, but we + // don't do that in order to support pipe semantics. We rely on a + // separate call to TrimFront() in this case. + + return n, err +} diff --git a/pkg/buffer/view.go b/pkg/buffer/view.go new file mode 100644 index 000000000..00fc11e9c --- /dev/null +++ b/pkg/buffer/view.go @@ -0,0 +1,382 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package buffer + +import ( + "fmt" + "io" +) + +// View is a non-linear buffer. +// +// All methods are thread compatible. +// +// +stateify savable +type View struct { + data bufferList + size int64 +} + +// TrimFront removes the first count bytes from the buffer. +func (v *View) TrimFront(count int64) { + if count >= v.size { + v.advanceRead(v.size) + } else { + v.advanceRead(count) + } +} + +// Read implements io.Reader.Read. +// +// Note that reading does not advance the read index. This must be done +// manually using TrimFront or other methods. +func (v *View) Read(p []byte) (int, error) { + return v.ReadAt(p, 0) +} + +// ReadAt implements io.ReaderAt.ReadAt. +func (v *View) ReadAt(p []byte, offset int64) (int, error) { + var ( + skipped int64 + done int64 + ) + for buf := v.data.Front(); buf != nil && done < int64(len(p)); buf = buf.Next() { + needToSkip := int(offset - skipped) + if l := buf.write - buf.read; l <= needToSkip { + skipped += int64(l) + continue + } + + // Actually read data. + n := copy(p[done:], buf.data[buf.read+needToSkip:buf.write]) + skipped += int64(needToSkip) + done += int64(n) + } + if int(done) < len(p) { + return int(done), io.EOF + } + return int(done), nil +} + +// Write implements io.Writer.Write. +func (v *View) Write(p []byte) (int, error) { + v.Append(p) // Does not fail. + return len(p), nil +} + +// advanceRead advances the view's read index. +// +// Precondition: there must be sufficient bytes in the buffer. +func (v *View) advanceRead(count int64) { + for buf := v.data.Front(); buf != nil && count > 0; { + l := int64(buf.write - buf.read) + if l > count { + // There is still data for reading. + buf.read += int(count) + v.size -= count + count = 0 + break + } + + // Read from this buffer. + buf.read += int(l) + count -= l + v.size -= l + + // When all data has been read from a buffer, we push + // it into the empty buffer pool for reuse. + oldBuf := buf + buf = buf.Next() // Iterate. + v.data.Remove(oldBuf) + oldBuf.Reset() + bufferPool.Put(oldBuf) + } + if count > 0 { + panic(fmt.Sprintf("advanceRead still has %d bytes remaining", count)) + } +} + +// Truncate truncates the view to the given bytes. +func (v *View) Truncate(length int64) { + if length < 0 || length >= v.size { + return // Nothing to do. + } + for buf := v.data.Back(); buf != nil && v.size > length; buf = v.data.Back() { + l := int64(buf.write - buf.read) // Local bytes. + switch { + case v.size-l >= length: + // Drop the buffer completely; see above. + v.data.Remove(buf) + v.size -= l + buf.Reset() + bufferPool.Put(buf) + + case v.size > length && v.size-l < length: + // Just truncate the buffer locally. + delta := (length - (v.size - l)) + buf.write = buf.read + int(delta) + v.size = length + + default: + // Should never happen. + panic("invalid buffer during truncation") + } + } + v.size = length // Save the new size. +} + +// Grow grows the given view to the number of bytes. If zero +// is true, all these bytes will be zero. If zero is false, +// then this is the caller's responsibility. +// +// Precondition: length must be >= 0. +func (v *View) Grow(length int64, zero bool) { + if length < 0 { + panic("negative length provided") + } + for v.size < length { + buf := v.data.Back() + + // Is there at least one buffer? + if buf == nil || buf.Full() { + buf = bufferPool.Get().(*Buffer) + v.data.PushBack(buf) + } + + // Write up to length bytes. + l := len(buf.data) - buf.write + if int64(l) > length-v.size { + l = int(length - v.size) + } + + // Zero the written section; note that this pattern is + // specifically recognized and optimized by the compiler. + if zero { + for i := buf.write; i < buf.write+l; i++ { + buf.data[i] = 0 + } + } + + // Advance the index. + buf.write += l + v.size += int64(l) + } +} + +// Prepend prepends the given data. +func (v *View) Prepend(data []byte) { + // Is there any space in the first buffer? + if buf := v.data.Front(); buf != nil && buf.read > 0 { + // Fill up before the first write. + avail := buf.read + copy(buf.data[0:], data[len(data)-avail:]) + data = data[:len(data)-avail] + v.size += int64(avail) + } + + for len(data) > 0 { + // Do we need an empty buffer? + buf := bufferPool.Get().(*Buffer) + v.data.PushFront(buf) + + // The buffer is empty; copy last chunk. + start := len(data) - len(buf.data) + if start < 0 { + start = 0 // Everything. + } + + // We have to put the data at the end of the current + // buffer in order to ensure that the next prepend will + // correctly fill up the beginning of this buffer. + bStart := len(buf.data) - len(data[start:]) + n := copy(buf.data[bStart:], data[start:]) + buf.read = bStart + buf.write = len(buf.data) + data = data[:start] + v.size += int64(n) + } +} + +// Append appends the given data. +func (v *View) Append(data []byte) { + for done := 0; done < len(data); { + buf := v.data.Back() + + // Find the first empty buffer. + if buf == nil || buf.Full() { + buf = bufferPool.Get().(*Buffer) + v.data.PushBack(buf) + } + + // Copy in to the given buffer. + n := copy(buf.data[buf.write:], data[done:]) + done += n + buf.write += n + v.size += int64(n) + } +} + +// Flatten returns a flattened copy of this data. +// +// This method should not be used in any performance-sensitive paths. It may +// allocate a fresh byte slice sufficiently large to contain all the data in +// the buffer. +// +// N.B. Tee data still belongs to this view, as if there is a single buffer +// present, then it will be returned directly. This should be used for +// temporary use only, and a reference to the given slice should not be held. +func (v *View) Flatten() []byte { + if buf := v.data.Front(); buf.Next() == nil { + return buf.data[buf.read:buf.write] // Only one buffer. + } + data := make([]byte, 0, v.size) // Need to flatten. + for buf := v.data.Front(); buf != nil; buf = buf.Next() { + // Copy to the allocated slice. + data = append(data, buf.data[buf.read:buf.write]...) + } + return data +} + +// Size indicates the total amount of data available in this view. +func (v *View) Size() (sz int64) { + sz = v.size // Pre-calculated. + return sz +} + +// Copy makes a strict copy of this view. +func (v *View) Copy() (other View) { + for buf := v.data.Front(); buf != nil; buf = buf.Next() { + other.Append(buf.data[buf.read:buf.write]) + } + return other +} + +// Apply applies the given function across all valid data. +func (v *View) Apply(fn func([]byte)) { + for buf := v.data.Front(); buf != nil; buf = buf.Next() { + if l := int64(buf.write - buf.read); l > 0 { + fn(buf.data[buf.read:buf.write]) + } + } +} + +// Merge merges the provided View with this one. +// +// The other view will be empty after this operation. +func (v *View) Merge(other *View) { + // Copy over all buffers. + for buf := other.data.Front(); buf != nil && !buf.Empty(); buf = other.data.Front() { + other.data.Remove(buf) + v.data.PushBack(buf) + } + + // Adjust sizes. + v.size += other.size + other.size = 0 +} + +// WriteFromReader writes to the buffer from an io.Reader. +func (v *View) WriteFromReader(r io.Reader, count int64) (int64, error) { + var ( + done int64 + n int + err error + ) + for done < count { + buf := v.data.Back() + + // Find the first empty buffer. + if buf == nil || buf.Full() { + buf = bufferPool.Get().(*Buffer) + v.data.PushBack(buf) + } + + // Is this less than the minimum batch? + if len(buf.data[buf.write:]) < minBatch && (count-done) >= int64(minBatch) { + tmp := make([]byte, minBatch) + n, err = r.Read(tmp) + v.Write(tmp[:n]) + done += int64(n) + if err != nil { + break + } + continue + } + + // Limit the read, if necessary. + end := len(buf.data) + if int64(end-buf.write) > (count - done) { + end = buf.write + int(count-done) + } + + // Pass the relevant portion of the buffer. + n, err = r.Read(buf.data[buf.write:end]) + buf.write += n + done += int64(n) + v.size += int64(n) + if err == io.EOF { + err = nil // Short write allowed. + break + } else if err != nil { + break + } + } + return done, err +} + +// ReadToWriter reads from the buffer into an io.Writer. +// +// N.B. This does not consume the bytes read. TrimFront should +// be called appropriately after this call in order to do so. +func (v *View) ReadToWriter(w io.Writer, count int64) (int64, error) { + var ( + done int64 + n int + err error + ) + offset := 0 // Spill-over for batching. + for buf := v.data.Front(); buf != nil && done < count; buf = buf.Next() { + l := buf.write - buf.read - offset + + // Is this less than the minimum batch? + if l < minBatch && (count-done) >= int64(minBatch) && (v.size-done) >= int64(minBatch) { + tmp := make([]byte, minBatch) + n, err = v.ReadAt(tmp, done) + w.Write(tmp[:n]) + done += int64(n) + offset = n - l // Reset below. + if err != nil { + break + } + continue + } + + // Limit the write if necessary. + if int64(l) >= (count - done) { + l = int(count - done) + } + + // Perform the actual write. + n, err = w.Write(buf.data[buf.read+offset : buf.read+offset+l]) + done += int64(n) + if err != nil { + break + } + + // Reset spill-over. + offset = 0 + } + return done, err +} diff --git a/pkg/buffer/view_test.go b/pkg/buffer/view_test.go new file mode 100644 index 000000000..37e652f16 --- /dev/null +++ b/pkg/buffer/view_test.go @@ -0,0 +1,233 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package buffer + +import ( + "bytes" + "strings" + "testing" +) + +func TestView(t *testing.T) { + testCases := []struct { + name string + input string + output string + ops []func(*View) + }{ + // Prepend. + { + name: "prepend", + input: "world", + ops: []func(*View){ + func(v *View) { + v.Prepend([]byte("hello ")) + }, + }, + output: "hello world", + }, + { + name: "prepend fill", + input: strings.Repeat("1", bufferSize-1), + ops: []func(*View){ + func(v *View) { + v.Prepend([]byte("0")) + }, + }, + output: "0" + strings.Repeat("1", bufferSize-1), + }, + { + name: "prepend overflow", + input: strings.Repeat("1", bufferSize), + ops: []func(*View){ + func(v *View) { + v.Prepend([]byte("0")) + }, + }, + output: "0" + strings.Repeat("1", bufferSize), + }, + { + name: "prepend multiple buffers", + input: strings.Repeat("1", bufferSize-1), + ops: []func(*View){ + func(v *View) { + v.Prepend([]byte(strings.Repeat("0", bufferSize*3))) + }, + }, + output: strings.Repeat("0", bufferSize*3) + strings.Repeat("1", bufferSize-1), + }, + + // Append. + { + name: "append", + input: "hello", + ops: []func(*View){ + func(v *View) { + v.Append([]byte(" world")) + }, + }, + output: "hello world", + }, + { + name: "append fill", + input: strings.Repeat("1", bufferSize-1), + ops: []func(*View){ + func(v *View) { + v.Append([]byte("0")) + }, + }, + output: strings.Repeat("1", bufferSize-1) + "0", + }, + { + name: "append overflow", + input: strings.Repeat("1", bufferSize), + ops: []func(*View){ + func(v *View) { + v.Append([]byte("0")) + }, + }, + output: strings.Repeat("1", bufferSize) + "0", + }, + { + name: "append multiple buffers", + input: strings.Repeat("1", bufferSize-1), + ops: []func(*View){ + func(v *View) { + v.Append([]byte(strings.Repeat("0", bufferSize*3))) + }, + }, + output: strings.Repeat("1", bufferSize-1) + strings.Repeat("0", bufferSize*3), + }, + + // Truncate. + { + name: "truncate", + input: "hello world", + ops: []func(*View){ + func(v *View) { + v.Truncate(5) + }, + }, + output: "hello", + }, + { + name: "truncate multiple buffers", + input: strings.Repeat("1", bufferSize*2), + ops: []func(*View){ + func(v *View) { + v.Truncate(bufferSize*2 - 1) + }, + }, + output: strings.Repeat("1", bufferSize*2-1), + }, + { + name: "truncate multiple buffers to one buffer", + input: strings.Repeat("1", bufferSize*2), + ops: []func(*View){ + func(v *View) { + v.Truncate(5) + }, + }, + output: "11111", + }, + + // TrimFront. + { + name: "trim", + input: "hello world", + ops: []func(*View){ + func(v *View) { + v.TrimFront(6) + }, + }, + output: "world", + }, + { + name: "trim multiple buffers", + input: strings.Repeat("1", bufferSize*2), + ops: []func(*View){ + func(v *View) { + v.TrimFront(1) + }, + }, + output: strings.Repeat("1", bufferSize*2-1), + }, + { + name: "trim multiple buffers to one buffer", + input: strings.Repeat("1", bufferSize*2), + ops: []func(*View){ + func(v *View) { + v.TrimFront(bufferSize*2 - 1) + }, + }, + output: "1", + }, + + // Grow. + { + name: "grow", + input: "hello world", + ops: []func(*View){ + func(v *View) { + v.Grow(1, true) + }, + }, + output: "hello world", + }, + { + name: "grow from zero", + ops: []func(*View){ + func(v *View) { + v.Grow(1024, true) + }, + }, + output: strings.Repeat("\x00", 1024), + }, + { + name: "grow from non-zero", + input: strings.Repeat("1", bufferSize), + ops: []func(*View){ + func(v *View) { + v.Grow(bufferSize*2, true) + }, + }, + output: strings.Repeat("1", bufferSize) + strings.Repeat("\x00", bufferSize), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Construct the new view. + var view View + view.Append([]byte(tc.input)) + + // Run all operations. + for _, op := range tc.ops { + op(&view) + } + + // Flatten and validate. + out := view.Flatten() + if !bytes.Equal([]byte(tc.output), out) { + t.Errorf("expected %q, got %q", tc.output, string(out)) + } + + // Ensure the size is correct. + if len(out) != int(view.Size()) { + t.Errorf("size is wrong: expected %d, got %d", len(out), view.Size()) + } + }) + } +} diff --git a/pkg/sentry/kernel/pipe/buffer_test.go b/pkg/buffer/view_unsafe.go index 4d54b8b8f..d1ef39b26 100644 --- a/pkg/sentry/kernel/pipe/buffer_test.go +++ b/pkg/buffer/view_unsafe.go @@ -1,4 +1,4 @@ -// Copyright 2019 The gVisor Authors. +// Copyright 2020 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,21 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -package pipe +package buffer import ( - "testing" "unsafe" - - "gvisor.dev/gvisor/pkg/usermem" ) -func TestBufferSize(t *testing.T) { - bufferSize := unsafe.Sizeof(buffer{}) - if bufferSize < usermem.PageSize { - t.Errorf("buffer is less than a page") - } - if bufferSize > (2 * usermem.PageSize) { - t.Errorf("buffer is greater than two pages") - } -} +// minBatch is the smallest Read or Write operation that the +// WriteFromReader and ReadToWriter functions will use. +// +// This is defined as the size of a native pointer. +const minBatch = int(unsafe.Sizeof(uintptr(0))) diff --git a/pkg/ilist/list.go b/pkg/ilist/list.go index 019caadca..f3a609b57 100644 --- a/pkg/ilist/list.go +++ b/pkg/ilist/list.go @@ -88,8 +88,9 @@ func (l *List) Back() Element { // PushFront inserts the element e at the front of list l. func (l *List) PushFront(e Element) { - ElementMapper{}.linkerFor(e).SetNext(l.head) - ElementMapper{}.linkerFor(e).SetPrev(nil) + linker := ElementMapper{}.linkerFor(e) + linker.SetNext(l.head) + linker.SetPrev(nil) if l.head != nil { ElementMapper{}.linkerFor(l.head).SetPrev(e) @@ -102,8 +103,9 @@ func (l *List) PushFront(e Element) { // PushBack inserts the element e at the back of list l. func (l *List) PushBack(e Element) { - ElementMapper{}.linkerFor(e).SetNext(nil) - ElementMapper{}.linkerFor(e).SetPrev(l.tail) + linker := ElementMapper{}.linkerFor(e) + linker.SetNext(nil) + linker.SetPrev(l.tail) if l.tail != nil { ElementMapper{}.linkerFor(l.tail).SetNext(e) @@ -132,10 +134,14 @@ func (l *List) PushBackList(m *List) { // InsertAfter inserts e after b. func (l *List) InsertAfter(b, e Element) { - a := ElementMapper{}.linkerFor(b).Next() - ElementMapper{}.linkerFor(e).SetNext(a) - ElementMapper{}.linkerFor(e).SetPrev(b) - ElementMapper{}.linkerFor(b).SetNext(e) + bLinker := ElementMapper{}.linkerFor(b) + eLinker := ElementMapper{}.linkerFor(e) + + a := bLinker.Next() + + eLinker.SetNext(a) + eLinker.SetPrev(b) + bLinker.SetNext(e) if a != nil { ElementMapper{}.linkerFor(a).SetPrev(e) @@ -146,10 +152,13 @@ func (l *List) InsertAfter(b, e Element) { // InsertBefore inserts e before a. func (l *List) InsertBefore(a, e Element) { - b := ElementMapper{}.linkerFor(a).Prev() - ElementMapper{}.linkerFor(e).SetNext(a) - ElementMapper{}.linkerFor(e).SetPrev(b) - ElementMapper{}.linkerFor(a).SetPrev(e) + aLinker := ElementMapper{}.linkerFor(a) + eLinker := ElementMapper{}.linkerFor(e) + + b := aLinker.Prev() + eLinker.SetNext(a) + eLinker.SetPrev(b) + aLinker.SetPrev(e) if b != nil { ElementMapper{}.linkerFor(b).SetNext(e) diff --git a/pkg/safemem/seq_test.go b/pkg/safemem/seq_test.go index eba4bb535..de34005e9 100644 --- a/pkg/safemem/seq_test.go +++ b/pkg/safemem/seq_test.go @@ -20,6 +20,27 @@ import ( "testing" ) +func TestBlockSeqOfEmptyBlock(t *testing.T) { + bs := BlockSeqOf(Block{}) + if !bs.IsEmpty() { + t.Errorf("BlockSeqOf(Block{}).IsEmpty(): got false, wanted true; BlockSeq is %v", bs) + } +} + +func TestBlockSeqOfNonemptyBlock(t *testing.T) { + b := BlockFromSafeSlice(make([]byte, 1)) + bs := BlockSeqOf(b) + if bs.IsEmpty() { + t.Fatalf("BlockSeqOf(non-empty Block).IsEmpty(): got true, wanted false; BlockSeq is %v", bs) + } + if head := bs.Head(); head != b { + t.Fatalf("BlockSeqOf(non-empty Block).Head(): got %v, wanted %v", head, b) + } + if tail := bs.Tail(); !tail.IsEmpty() { + t.Fatalf("BlockSeqOf(non-empty Block).Tail().IsEmpty(): got false, wanted true: tail is %v", tail) + } +} + type blockSeqTest struct { desc string diff --git a/pkg/safemem/seq_unsafe.go b/pkg/safemem/seq_unsafe.go index dcdfc9600..f5f0574f8 100644 --- a/pkg/safemem/seq_unsafe.go +++ b/pkg/safemem/seq_unsafe.go @@ -56,6 +56,9 @@ type BlockSeq struct { // BlockSeqOf returns a BlockSeq representing the single Block b. func BlockSeqOf(b Block) BlockSeq { + if b.length == 0 { + return BlockSeq{} + } bs := BlockSeq{ data: b.start, length: -1, diff --git a/pkg/seccomp/seccomp_test.go b/pkg/seccomp/seccomp_test.go index da5a5e4b2..88766f33b 100644 --- a/pkg/seccomp/seccomp_test.go +++ b/pkg/seccomp/seccomp_test.go @@ -451,7 +451,7 @@ func TestRandom(t *testing.T) { } } - fmt.Printf("Testing filters: %v", syscallRules) + t.Logf("Testing filters: %v", syscallRules) instrs, err := BuildProgram([]RuleSet{ RuleSet{ Rules: syscallRules, diff --git a/pkg/sentry/arch/arch_aarch64.go b/pkg/sentry/arch/arch_aarch64.go index d7794db63..5053393c1 100644 --- a/pkg/sentry/arch/arch_aarch64.go +++ b/pkg/sentry/arch/arch_aarch64.go @@ -32,6 +32,12 @@ import ( const ( // SyscallWidth is the width of insturctions. SyscallWidth = 4 + + // fpsimdMagic is the magic number which is used in fpsimd_context. + fpsimdMagic = 0x46508001 + + // fpsimdContextSize is the size of fpsimd_context. + fpsimdContextSize = 0x210 ) // ARMTrapFlag is the mask for the trap flag. @@ -40,24 +46,24 @@ const ARMTrapFlag = uint64(1) << 21 // aarch64FPState is aarch64 floating point state. type aarch64FPState []byte -// initAarch64FPState (defined in asm files) sets up initial state. -func initAarch64FPState(data *FloatingPointData) { - // TODO(gvisor.dev/issue/1238): floating-point is not supported. +// initAarch64FPState sets up initial state. +func initAarch64FPState(data aarch64FPState) { + binary.LittleEndian.PutUint32(data, fpsimdMagic) + binary.LittleEndian.PutUint32(data[4:], fpsimdContextSize) } func newAarch64FPStateSlice() []byte { - return alignedBytes(4096, 32)[:4096] + return alignedBytes(4096, 16)[:fpsimdContextSize] } // newAarch64FPState returns an initialized floating point state. // // The returned state is large enough to store all floating point state // supported by host, even if the app won't use much of it due to a restricted -// FeatureSet. Since they may still be able to see state not advertised by -// CPUID we must ensure it does not contain any sentry state. +// FeatureSet. func newAarch64FPState() aarch64FPState { f := aarch64FPState(newAarch64FPStateSlice()) - initAarch64FPState(f.FloatingPointData()) + initAarch64FPState(f) return f } @@ -136,10 +142,10 @@ func (s State) Proto() *rpb.Registers { // Fork creates and returns an identical copy of the state. func (s *State) Fork() State { - // TODO(gvisor.dev/issue/1238): floating-point is not supported. return State{ - Regs: s.Regs, - FeatureSet: s.FeatureSet, + Regs: s.Regs, + aarch64FPState: s.aarch64FPState.fork(), + FeatureSet: s.FeatureSet, } } @@ -288,8 +294,10 @@ func New(arch Arch, fs *cpuid.FeatureSet) Context { case ARM64: return &context64{ State{ - FeatureSet: fs, + aarch64FPState: newAarch64FPState(), + FeatureSet: fs, }, + []aarch64FPState(nil), } } panic(fmt.Sprintf("unknown architecture %v", arch)) diff --git a/pkg/sentry/arch/arch_arm64.go b/pkg/sentry/arch/arch_arm64.go index ac98897b5..885115ae2 100644 --- a/pkg/sentry/arch/arch_arm64.go +++ b/pkg/sentry/arch/arch_arm64.go @@ -53,6 +53,11 @@ const ( preferredPIELoadAddr usermem.Addr = maxAddr64 / 6 * 5 ) +var ( + // CPUIDInstruction doesn't exist on ARM64. + CPUIDInstruction = []byte{} +) + // These constants are selected as heuristics to help make the Platform's // potentially limited address space conform as closely to Linux as possible. const ( @@ -68,6 +73,7 @@ const ( // context64 represents an ARM64 context. type context64 struct { State + sigFPState []aarch64FPState // fpstate to be restored on sigreturn. } // Arch implements Context.Arch. @@ -75,10 +81,19 @@ func (c *context64) Arch() Arch { return ARM64 } +func (c *context64) copySigFPState() []aarch64FPState { + var sigfps []aarch64FPState + for _, s := range c.sigFPState { + sigfps = append(sigfps, s.fork()) + } + return sigfps +} + // Fork returns an exact copy of this context. func (c *context64) Fork() Context { return &context64{ - State: c.State.Fork(), + State: c.State.Fork(), + sigFPState: c.copySigFPState(), } } diff --git a/pkg/sentry/arch/signal_arm64.go b/pkg/sentry/arch/signal_arm64.go index 4f4cc46a8..0c1db4b13 100644 --- a/pkg/sentry/arch/signal_arm64.go +++ b/pkg/sentry/arch/signal_arm64.go @@ -30,14 +30,29 @@ type SignalContext64 struct { Sp uint64 Pc uint64 Pstate uint64 - _pad [8]byte // __attribute__((__aligned__(16))) - Reserved [4096]uint8 + _pad [8]byte // __attribute__((__aligned__(16))) + Fpsimd64 FpsimdContext // size = 528 + Reserved [3568]uint8 +} + +type aarch64Ctx struct { + Magic uint32 + Size uint32 +} + +// FpsimdContext is equivalent to struct fpsimd_context on arm64 +// (arch/arm64/include/uapi/asm/sigcontext.h). +type FpsimdContext struct { + Head aarch64Ctx + Fpsr uint32 + Fpcr uint32 + Vregs [64]uint64 // actually [32]uint128 } // UContext64 is equivalent to ucontext on arm64(arch/arm64/include/uapi/asm/ucontext.h). type UContext64 struct { Flags uint64 - Link *UContext64 + Link uint64 Stack SignalStack Sigset linux.SignalSet // glibc uses a 1024-bit sigset_t diff --git a/pkg/sentry/control/pprof.go b/pkg/sentry/control/pprof.go index 151808911..5d1907c0e 100644 --- a/pkg/sentry/control/pprof.go +++ b/pkg/sentry/control/pprof.go @@ -117,15 +117,43 @@ func (p *Profile) HeapProfile(o *ProfileOpts, _ *struct{}) error { return nil } -// Goroutine is an RPC stub which dumps out the stack trace for all running +// GoroutineProfile is an RPC stub which dumps out the stack trace for all running // goroutines. -func (p *Profile) Goroutine(o *ProfileOpts, _ *struct{}) error { +func (p *Profile) GoroutineProfile(o *ProfileOpts, _ *struct{}) error { if len(o.FilePayload.Files) < 1 { return errNoOutput } output := o.FilePayload.Files[0] defer output.Close() - if err := pprof.Lookup("goroutine").WriteTo(output, 2); err != nil { + if err := pprof.Lookup("goroutine").WriteTo(output, 0); err != nil { + return err + } + return nil +} + +// BlockProfile is an RPC stub which dumps out the stack trace that led to +// blocking on synchronization primitives. +func (p *Profile) BlockProfile(o *ProfileOpts, _ *struct{}) error { + if len(o.FilePayload.Files) < 1 { + return errNoOutput + } + output := o.FilePayload.Files[0] + defer output.Close() + if err := pprof.Lookup("block").WriteTo(output, 0); err != nil { + return err + } + return nil +} + +// MutexProfile is an RPC stub which dumps out the stack trace of holders of +// contended mutexes. +func (p *Profile) MutexProfile(o *ProfileOpts, _ *struct{}) error { + if len(o.FilePayload.Files) < 1 { + return errNoOutput + } + output := o.FilePayload.Files[0] + defer output.Close() + if err := pprof.Lookup("mutex").WriteTo(output, 0); err != nil { return err } return nil diff --git a/pkg/sentry/fs/dev/BUILD b/pkg/sentry/fs/dev/BUILD index 9b6bb26d0..9379a4d7b 100644 --- a/pkg/sentry/fs/dev/BUILD +++ b/pkg/sentry/fs/dev/BUILD @@ -26,6 +26,7 @@ go_library( "//pkg/sentry/fs/fsutil", "//pkg/sentry/fs/ramfs", "//pkg/sentry/fs/tmpfs", + "//pkg/sentry/inet", "//pkg/sentry/kernel", "//pkg/sentry/memmap", "//pkg/sentry/mm", diff --git a/pkg/sentry/fs/dev/dev.go b/pkg/sentry/fs/dev/dev.go index 7e66c29b0..acbd401a0 100644 --- a/pkg/sentry/fs/dev/dev.go +++ b/pkg/sentry/fs/dev/dev.go @@ -22,6 +22,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/ramfs" "gvisor.dev/gvisor/pkg/sentry/fs/tmpfs" + "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/usermem" ) @@ -124,10 +125,12 @@ func New(ctx context.Context, msrc *fs.MountSource) *fs.Inode { "ptmx": newSymlink(ctx, "pts/ptmx", msrc), "tty": newCharacterDevice(ctx, newTTYDevice(ctx, fs.RootOwner, 0666), msrc, ttyDevMajor, ttyDevMinor), + } - "net": newDirectory(ctx, map[string]*fs.Inode{ + if isNetTunSupported(inet.StackFromContext(ctx)) { + contents["net"] = newDirectory(ctx, map[string]*fs.Inode{ "tun": newCharacterDevice(ctx, newNetTunDevice(ctx, fs.RootOwner, 0666), msrc, netTunDevMajor, netTunDevMinor), - }, msrc), + }, msrc) } iops := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0555)) diff --git a/pkg/sentry/fs/dev/net_tun.go b/pkg/sentry/fs/dev/net_tun.go index 755644488..dc7ad075a 100644 --- a/pkg/sentry/fs/dev/net_tun.go +++ b/pkg/sentry/fs/dev/net_tun.go @@ -20,6 +20,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" + "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/socket/netstack" "gvisor.dev/gvisor/pkg/syserror" @@ -168,3 +169,9 @@ func (fops *netTunFileOperations) EventRegister(e *waiter.Entry, mask waiter.Eve func (fops *netTunFileOperations) EventUnregister(e *waiter.Entry) { fops.device.EventUnregister(e) } + +// isNetTunSupported returns whether /dev/net/tun device is supported for s. +func isNetTunSupported(s inet.Stack) bool { + _, ok := s.(*netstack.Stack) + return ok +} diff --git a/pkg/sentry/fs/dirent.go b/pkg/sentry/fs/dirent.go index acab0411a..e0b32e1c1 100644 --- a/pkg/sentry/fs/dirent.go +++ b/pkg/sentry/fs/dirent.go @@ -1438,8 +1438,8 @@ func lockForRename(oldParent *Dirent, oldName string, newParent *Dirent, newName }, nil } -func checkSticky(ctx context.Context, dir *Dirent, victim *Dirent) error { - uattr, err := dir.Inode.UnstableAttr(ctx) +func (d *Dirent) checkSticky(ctx context.Context, victim *Dirent) error { + uattr, err := d.Inode.UnstableAttr(ctx) if err != nil { return syserror.EPERM } @@ -1465,30 +1465,33 @@ func checkSticky(ctx context.Context, dir *Dirent, victim *Dirent) error { return syserror.EPERM } -// MayDelete determines whether `name`, a child of `dir`, can be deleted or +// MayDelete determines whether `name`, a child of `d`, can be deleted or // renamed by `ctx`. // // Compare Linux kernel fs/namei.c:may_delete. -func MayDelete(ctx context.Context, root, dir *Dirent, name string) error { - if err := dir.Inode.CheckPermission(ctx, PermMask{Write: true, Execute: true}); err != nil { +func (d *Dirent) MayDelete(ctx context.Context, root *Dirent, name string) error { + if err := d.Inode.CheckPermission(ctx, PermMask{Write: true, Execute: true}); err != nil { return err } - victim, err := dir.Walk(ctx, root, name) + unlock := d.lockDirectory() + defer unlock() + + victim, err := d.walk(ctx, root, name, true /* may unlock */) if err != nil { return err } defer victim.DecRef() - return mayDelete(ctx, dir, victim) + return d.mayDelete(ctx, victim) } // mayDelete determines whether `victim`, a child of `dir`, can be deleted or // renamed by `ctx`. // // Preconditions: `dir` is writable and executable by `ctx`. -func mayDelete(ctx context.Context, dir, victim *Dirent) error { - if err := checkSticky(ctx, dir, victim); err != nil { +func (d *Dirent) mayDelete(ctx context.Context, victim *Dirent) error { + if err := d.checkSticky(ctx, victim); err != nil { return err } @@ -1542,7 +1545,7 @@ func Rename(ctx context.Context, root *Dirent, oldParent *Dirent, oldName string defer renamed.DecRef() // Check that the renamed dirent is deletable. - if err := mayDelete(ctx, oldParent, renamed); err != nil { + if err := oldParent.mayDelete(ctx, renamed); err != nil { return err } @@ -1580,7 +1583,7 @@ func Rename(ctx context.Context, root *Dirent, oldParent *Dirent, oldName string // across the Rename, so must call DecRef manually (no defer). // Check that we can delete replaced. - if err := mayDelete(ctx, newParent, replaced); err != nil { + if err := newParent.mayDelete(ctx, replaced); err != nil { replaced.DecRef() return err } diff --git a/pkg/sentry/fs/fsutil/inode.go b/pkg/sentry/fs/fsutil/inode.go index daecc4ffe..1922ff08c 100644 --- a/pkg/sentry/fs/fsutil/inode.go +++ b/pkg/sentry/fs/fsutil/inode.go @@ -259,8 +259,8 @@ func (i *InodeSimpleExtendedAttributes) ListXattr(context.Context, *fs.Inode, ui // RemoveXattr implements fs.InodeOperations.RemoveXattr. func (i *InodeSimpleExtendedAttributes) RemoveXattr(_ context.Context, _ *fs.Inode, name string) error { - i.mu.RLock() - defer i.mu.RUnlock() + i.mu.Lock() + defer i.mu.Unlock() if _, ok := i.xattrs[name]; ok { delete(i.xattrs, name) return nil diff --git a/pkg/sentry/fs/proc/task.go b/pkg/sentry/fs/proc/task.go index 8ab8d8a02..4e9b0fc00 100644 --- a/pkg/sentry/fs/proc/task.go +++ b/pkg/sentry/fs/proc/task.go @@ -72,24 +72,26 @@ var _ fs.InodeOperations = (*taskDir)(nil) // newTaskDir creates a new proc task entry. func (p *proc) newTaskDir(t *kernel.Task, msrc *fs.MountSource, isThreadGroup bool) *fs.Inode { contents := map[string]*fs.Inode{ - "auxv": newAuxvec(t, msrc), - "cmdline": newExecArgInode(t, msrc, cmdlineExecArg), - "comm": newComm(t, msrc), - "environ": newExecArgInode(t, msrc, environExecArg), - "exe": newExe(t, msrc), - "fd": newFdDir(t, msrc), - "fdinfo": newFdInfoDir(t, msrc), - "gid_map": newGIDMap(t, msrc), - "io": newIO(t, msrc, isThreadGroup), - "maps": newMaps(t, msrc), - "mountinfo": seqfile.NewSeqFileInode(t, &mountInfoFile{t: t}, msrc), - "mounts": seqfile.NewSeqFileInode(t, &mountsFile{t: t}, msrc), - "ns": newNamespaceDir(t, msrc), - "smaps": newSmaps(t, msrc), - "stat": newTaskStat(t, msrc, isThreadGroup, p.pidns), - "statm": newStatm(t, msrc), - "status": newStatus(t, msrc, p.pidns), - "uid_map": newUIDMap(t, msrc), + "auxv": newAuxvec(t, msrc), + "cmdline": newExecArgInode(t, msrc, cmdlineExecArg), + "comm": newComm(t, msrc), + "environ": newExecArgInode(t, msrc, environExecArg), + "exe": newExe(t, msrc), + "fd": newFdDir(t, msrc), + "fdinfo": newFdInfoDir(t, msrc), + "gid_map": newGIDMap(t, msrc), + "io": newIO(t, msrc, isThreadGroup), + "maps": newMaps(t, msrc), + "mountinfo": seqfile.NewSeqFileInode(t, &mountInfoFile{t: t}, msrc), + "mounts": seqfile.NewSeqFileInode(t, &mountsFile{t: t}, msrc), + "ns": newNamespaceDir(t, msrc), + "oom_score": newOOMScore(t, msrc), + "oom_score_adj": newOOMScoreAdj(t, msrc), + "smaps": newSmaps(t, msrc), + "stat": newTaskStat(t, msrc, isThreadGroup, p.pidns), + "statm": newStatm(t, msrc), + "status": newStatus(t, msrc, p.pidns), + "uid_map": newUIDMap(t, msrc), } if isThreadGroup { contents["task"] = p.newSubtasks(t, msrc) @@ -796,4 +798,92 @@ func (f *auxvecFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequenc return int64(n), err } +// newOOMScore returns a oom_score file. It is a stub that always returns 0. +// TODO(gvisor.dev/issue/1967) +func newOOMScore(t *kernel.Task, msrc *fs.MountSource) *fs.Inode { + return newStaticProcInode(t, msrc, []byte("0\n")) +} + +// oomScoreAdj is a file containing the oom_score adjustment for a task. +// +// +stateify savable +type oomScoreAdj struct { + fsutil.SimpleFileInode + + t *kernel.Task +} + +// +stateify savable +type oomScoreAdjFile struct { + fsutil.FileGenericSeek `state:"nosave"` + fsutil.FileNoIoctl `state:"nosave"` + fsutil.FileNoMMap `state:"nosave"` + fsutil.FileNoSplice `state:"nosave"` + fsutil.FileNoopFlush `state:"nosave"` + fsutil.FileNoopFsync `state:"nosave"` + fsutil.FileNoopRelease `state:"nosave"` + fsutil.FileNotDirReaddir `state:"nosave"` + fsutil.FileUseInodeUnstableAttr `state:"nosave"` + waiter.AlwaysReady `state:"nosave"` + + t *kernel.Task +} + +// newOOMScoreAdj returns a oom_score_adj file. +func newOOMScoreAdj(t *kernel.Task, msrc *fs.MountSource) *fs.Inode { + i := &oomScoreAdj{ + SimpleFileInode: *fsutil.NewSimpleFileInode(t, fs.RootOwner, fs.FilePermsFromMode(0644), linux.PROC_SUPER_MAGIC), + t: t, + } + return newProcInode(t, i, msrc, fs.SpecialFile, t) +} + +// Truncate implements fs.InodeOperations.Truncate. Truncate is called when +// O_TRUNC is specified for any kind of existing Dirent but is not called via +// (f)truncate for proc files. +func (*oomScoreAdj) Truncate(context.Context, *fs.Inode, int64) error { + return nil +} + +// GetFile implements fs.InodeOperations.GetFile. +func (o *oomScoreAdj) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) { + return fs.NewFile(ctx, dirent, flags, &oomScoreAdjFile{t: o.t}), nil +} + +// Read implements fs.FileOperations.Read. +func (f *oomScoreAdjFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, offset int64) (int64, error) { + if offset != 0 { + return 0, io.EOF + } + adj, err := f.t.OOMScoreAdj() + if err != nil { + return 0, err + } + adjBytes := []byte(strconv.FormatInt(int64(adj), 10) + "\n") + n, err := dst.CopyOut(ctx, adjBytes) + return int64(n), err +} + +// Write implements fs.FileOperations.Write. +func (f *oomScoreAdjFile) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, offset int64) (int64, error) { + if src.NumBytes() == 0 { + return 0, nil + } + + // Limit input size so as not to impact performance if input size is large. + src = src.TakeFirst(usermem.PageSize - 1) + + var v int32 + n, err := usermem.CopyInt32StringInVec(ctx, src.IO, src.Addrs, &v, src.Opts) + if err != nil { + return 0, err + } + + if err := f.t.SetOOMScoreAdj(v); err != nil { + return 0, err + } + + return n, nil +} + // LINT.ThenChange(../../fsimpl/proc/task.go|../../fsimpl/proc/task_files.go) diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go index d00850e25..c4a8f0b38 100644 --- a/pkg/sentry/fsimpl/gofer/gofer.go +++ b/pkg/sentry/fsimpl/gofer/gofer.go @@ -1045,13 +1045,13 @@ func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool // using the old file descriptor, preventing us from safely // closing it. We could handle this by invalidating existing // memmap.Translations, but this is expensive. Instead, use - // dup2() to make the old file descriptor refer to the new file + // dup3 to make the old file descriptor refer to the new file // description, then close the new file descriptor (which is no // longer needed). Racing callers may use the old or new file // description, but this doesn't matter since they refer to the // same file (unless d.fs.opts.overlayfsStaleRead is true, // which we handle separately). - if err := syscall.Dup2(int(h.fd), int(d.handle.fd)); err != nil { + if err := syscall.Dup3(int(h.fd), int(d.handle.fd), 0); err != nil { d.handleMu.Unlock() ctx.Warningf("gofer.dentry.ensureSharedHandle: failed to dup fd %d to fd %d: %v", h.fd, d.handle.fd, err) h.close(ctx) diff --git a/pkg/sentry/fsimpl/gofer/regular_file.go b/pkg/sentry/fsimpl/gofer/regular_file.go index 54c1031a7..e95209661 100644 --- a/pkg/sentry/fsimpl/gofer/regular_file.go +++ b/pkg/sentry/fsimpl/gofer/regular_file.go @@ -361,8 +361,15 @@ func (rw *dentryReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, erro rw.d.handleMu.RLock() if (rw.d.handle.fd >= 0 && !rw.d.fs.opts.forcePageCache) || rw.d.fs.opts.interop == InteropModeShared || rw.direct { n, err := rw.d.handle.writeFromBlocksAt(rw.ctx, srcs, rw.off) - rw.d.handleMu.RUnlock() rw.off += n + rw.d.dataMu.Lock() + if rw.off > rw.d.size { + atomic.StoreUint64(&rw.d.size, rw.off) + // The remote file's size will implicitly be extended to the correct + // value when we write back to it. + } + rw.d.dataMu.Unlock() + rw.d.handleMu.RUnlock() return n, err } diff --git a/pkg/sentry/fsimpl/proc/task.go b/pkg/sentry/fsimpl/proc/task.go index 2d814668a..18e5cd6f6 100644 --- a/pkg/sentry/fsimpl/proc/task.go +++ b/pkg/sentry/fsimpl/proc/task.go @@ -62,11 +62,13 @@ func newTaskInode(inoGen InoGenerator, task *kernel.Task, pidns *kernel.PIDNames "pid": newNamespaceSymlink(task, inoGen.NextIno(), "pid"), "user": newNamespaceSymlink(task, inoGen.NextIno(), "user"), }), - "smaps": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &smapsData{task: task}), - "stat": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &taskStatData{task: task, pidns: pidns, tgstats: isThreadGroup}), - "statm": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &statmData{task: task}), - "status": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &statusData{task: task, pidns: pidns}), - "uid_map": newTaskOwnedFile(task, inoGen.NextIno(), 0644, &idMapData{task: task, gids: false}), + "oom_score": newTaskOwnedFile(task, inoGen.NextIno(), 0444, newStaticFile("0\n")), + "oom_score_adj": newTaskOwnedFile(task, inoGen.NextIno(), 0644, &oomScoreAdj{task: task}), + "smaps": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &smapsData{task: task}), + "stat": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &taskStatData{task: task, pidns: pidns, tgstats: isThreadGroup}), + "statm": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &statmData{task: task}), + "status": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &statusData{task: task, pidns: pidns}), + "uid_map": newTaskOwnedFile(task, inoGen.NextIno(), 0644, &idMapData{task: task, gids: false}), } if isThreadGroup { contents["task"] = newSubtasks(task, pidns, inoGen, cgroupControllers) diff --git a/pkg/sentry/fsimpl/proc/task_files.go b/pkg/sentry/fsimpl/proc/task_files.go index efd3b3453..5a231ac86 100644 --- a/pkg/sentry/fsimpl/proc/task_files.go +++ b/pkg/sentry/fsimpl/proc/task_files.go @@ -525,3 +525,46 @@ func (i *ioData) Generate(ctx context.Context, buf *bytes.Buffer) error { fmt.Fprintf(buf, "cancelled_write_bytes: %d\n", io.BytesWriteCancelled) return nil } + +// oomScoreAdj is a stub of the /proc/<pid>/oom_score_adj file. +// +// +stateify savable +type oomScoreAdj struct { + kernfs.DynamicBytesFile + + task *kernel.Task +} + +var _ vfs.WritableDynamicBytesSource = (*oomScoreAdj)(nil) + +// Generate implements vfs.DynamicBytesSource.Generate. +func (o *oomScoreAdj) Generate(ctx context.Context, buf *bytes.Buffer) error { + adj, err := o.task.OOMScoreAdj() + if err != nil { + return err + } + fmt.Fprintf(buf, "%d\n", adj) + return nil +} + +// Write implements vfs.WritableDynamicBytesSource.Write. +func (o *oomScoreAdj) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) { + if src.NumBytes() == 0 { + return 0, nil + } + + // Limit input size so as not to impact performance if input size is large. + src = src.TakeFirst(usermem.PageSize - 1) + + var v int32 + n, err := usermem.CopyInt32StringInVec(ctx, src.IO, src.Addrs, &v, src.Opts) + if err != nil { + return 0, err + } + + if err := o.task.SetOOMScoreAdj(v); err != nil { + return 0, err + } + + return n, nil +} diff --git a/pkg/sentry/fsimpl/proc/tasks_test.go b/pkg/sentry/fsimpl/proc/tasks_test.go index c5d531fe0..0eb401619 100644 --- a/pkg/sentry/fsimpl/proc/tasks_test.go +++ b/pkg/sentry/fsimpl/proc/tasks_test.go @@ -63,21 +63,23 @@ var ( "thread-self": threadSelfLink.NextOff, } taskStaticFiles = map[string]testutil.DirentType{ - "auxv": linux.DT_REG, - "cgroup": linux.DT_REG, - "cmdline": linux.DT_REG, - "comm": linux.DT_REG, - "environ": linux.DT_REG, - "gid_map": linux.DT_REG, - "io": linux.DT_REG, - "maps": linux.DT_REG, - "ns": linux.DT_DIR, - "smaps": linux.DT_REG, - "stat": linux.DT_REG, - "statm": linux.DT_REG, - "status": linux.DT_REG, - "task": linux.DT_DIR, - "uid_map": linux.DT_REG, + "auxv": linux.DT_REG, + "cgroup": linux.DT_REG, + "cmdline": linux.DT_REG, + "comm": linux.DT_REG, + "environ": linux.DT_REG, + "gid_map": linux.DT_REG, + "io": linux.DT_REG, + "maps": linux.DT_REG, + "ns": linux.DT_DIR, + "oom_score": linux.DT_REG, + "oom_score_adj": linux.DT_REG, + "smaps": linux.DT_REG, + "stat": linux.DT_REG, + "statm": linux.DT_REG, + "status": linux.DT_REG, + "task": linux.DT_DIR, + "uid_map": linux.DT_REG, } ) diff --git a/pkg/sentry/inet/namespace.go b/pkg/sentry/inet/namespace.go index c16667e7f..029af3025 100644 --- a/pkg/sentry/inet/namespace.go +++ b/pkg/sentry/inet/namespace.go @@ -23,7 +23,10 @@ type Namespace struct { // creator allows kernel to create new network stack for network namespaces. // If nil, no networking will function if network is namespaced. - creator NetworkStackCreator + // + // At afterLoad(), creator will be used to create network stack. Stateify + // needs to wait for this field to be loaded before calling afterLoad(). + creator NetworkStackCreator `state:"wait"` // isRoot indicates whether this is the root network namespace. isRoot bool diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go index 8b76750e9..1d627564f 100644 --- a/pkg/sentry/kernel/kernel.go +++ b/pkg/sentry/kernel/kernel.go @@ -755,6 +755,8 @@ func (ctx *createProcessContext) Value(key interface{}) interface{} { return ctx.k.GlobalInit().Leader().MountNamespaceVFS2() case fs.CtxDirentCacheLimiter: return ctx.k.DirentCacheLimiter + case inet.CtxStack: + return ctx.k.RootNetworkNamespace().Stack() case ktime.CtxRealtimeClock: return ctx.k.RealtimeClock() case limits.CtxLimits: @@ -1481,6 +1483,8 @@ func (ctx supervisorContext) Value(key interface{}) interface{} { return ctx.k.GlobalInit().Leader().MountNamespaceVFS2() case fs.CtxDirentCacheLimiter: return ctx.k.DirentCacheLimiter + case inet.CtxStack: + return ctx.k.RootNetworkNamespace().Stack() case ktime.CtxRealtimeClock: return ctx.k.RealtimeClock() case limits.CtxLimits: diff --git a/pkg/sentry/kernel/pipe/BUILD b/pkg/sentry/kernel/pipe/BUILD index 4c049d5b4..f29dc0472 100644 --- a/pkg/sentry/kernel/pipe/BUILD +++ b/pkg/sentry/kernel/pipe/BUILD @@ -1,25 +1,10 @@ load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) -go_template_instance( - name = "buffer_list", - out = "buffer_list.go", - package = "pipe", - prefix = "buffer", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*buffer", - "Linker": "*buffer", - }, -) - go_library( name = "pipe", srcs = [ - "buffer.go", - "buffer_list.go", "device.go", "node.go", "pipe.go", @@ -33,8 +18,8 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/amutex", + "//pkg/buffer", "//pkg/context", - "//pkg/safemem", "//pkg/sentry/arch", "//pkg/sentry/device", "//pkg/sentry/fs", @@ -51,7 +36,6 @@ go_test( name = "pipe_test", size = "small", srcs = [ - "buffer_test.go", "node_test.go", "pipe_test.go", ], diff --git a/pkg/sentry/kernel/pipe/buffer.go b/pkg/sentry/kernel/pipe/buffer.go deleted file mode 100644 index fe3be5dbd..000000000 --- a/pkg/sentry/kernel/pipe/buffer.go +++ /dev/null @@ -1,115 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package pipe - -import ( - "io" - - "gvisor.dev/gvisor/pkg/safemem" - "gvisor.dev/gvisor/pkg/sync" -) - -// buffer encapsulates a queueable byte buffer. -// -// Note that the total size is slightly less than two pages. This -// is done intentionally to ensure that the buffer object aligns -// with runtime internals. We have no hard size or alignment -// requirements. This two page size will effectively minimize -// internal fragmentation, but still have a large enough chunk -// to limit excessive segmentation. -// -// +stateify savable -type buffer struct { - data [8144]byte - read int - write int - bufferEntry -} - -// Reset resets internal data. -// -// This must be called before use. -func (b *buffer) Reset() { - b.read = 0 - b.write = 0 -} - -// Empty indicates the buffer is empty. -// -// This indicates there is no data left to read. -func (b *buffer) Empty() bool { - return b.read == b.write -} - -// Full indicates the buffer is full. -// -// This indicates there is no capacity left to write. -func (b *buffer) Full() bool { - return b.write == len(b.data) -} - -// WriteFromBlocks implements safemem.Writer.WriteFromBlocks. -func (b *buffer) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) { - dst := safemem.BlockSeqOf(safemem.BlockFromSafeSlice(b.data[b.write:])) - n, err := safemem.CopySeq(dst, srcs) - b.write += int(n) - return n, err -} - -// WriteFromReader writes to the buffer from an io.Reader. -func (b *buffer) WriteFromReader(r io.Reader, count int64) (int64, error) { - dst := b.data[b.write:] - if count < int64(len(dst)) { - dst = b.data[b.write:][:count] - } - n, err := r.Read(dst) - b.write += n - return int64(n), err -} - -// ReadToBlocks implements safemem.Reader.ReadToBlocks. -func (b *buffer) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) { - src := safemem.BlockSeqOf(safemem.BlockFromSafeSlice(b.data[b.read:b.write])) - n, err := safemem.CopySeq(dsts, src) - b.read += int(n) - return n, err -} - -// ReadToWriter reads from the buffer into an io.Writer. -func (b *buffer) ReadToWriter(w io.Writer, count int64, dup bool) (int64, error) { - src := b.data[b.read:b.write] - if count < int64(len(src)) { - src = b.data[b.read:][:count] - } - n, err := w.Write(src) - if !dup { - b.read += n - } - return int64(n), err -} - -// bufferPool is a pool for buffers. -var bufferPool = sync.Pool{ - New: func() interface{} { - return new(buffer) - }, -} - -// newBuffer grabs a new buffer from the pool. -func newBuffer() *buffer { - b := bufferPool.Get().(*buffer) - b.Reset() - return b -} diff --git a/pkg/sentry/kernel/pipe/pipe.go b/pkg/sentry/kernel/pipe/pipe.go index 08410283f..725e9db7d 100644 --- a/pkg/sentry/kernel/pipe/pipe.go +++ b/pkg/sentry/kernel/pipe/pipe.go @@ -20,6 +20,7 @@ import ( "sync/atomic" "syscall" + "gvisor.dev/gvisor/pkg/buffer" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sync" @@ -70,10 +71,10 @@ type Pipe struct { // mu protects all pipe internal state below. mu sync.Mutex `state:"nosave"` - // data is the buffer queue of pipe contents. + // view is the underlying set of buffers. // // This is protected by mu. - data bufferList + view buffer.View // max is the maximum size of the pipe in bytes. When this max has been // reached, writers will get EWOULDBLOCK. @@ -81,11 +82,6 @@ type Pipe struct { // This is protected by mu. max int64 - // size is the current size of the pipe in bytes. - // - // This is protected by mu. - size int64 - // hadWriter indicates if this pipe ever had a writer. Note that this // does not necessarily indicate there is *currently* a writer, just // that there has been a writer at some point since the pipe was @@ -196,7 +192,7 @@ type readOps struct { limit func(int64) // read performs the actual read operation. - read func(*buffer) (int64, error) + read func(*buffer.View) (int64, error) } // read reads data from the pipe into dst and returns the number of bytes @@ -213,7 +209,7 @@ func (p *Pipe) read(ctx context.Context, ops readOps) (int64, error) { defer p.mu.Unlock() // Is the pipe empty? - if p.size == 0 { + if p.view.Size() == 0 { if !p.HasWriters() { // There are no writers, return EOF. return 0, nil @@ -222,71 +218,13 @@ func (p *Pipe) read(ctx context.Context, ops readOps) (int64, error) { } // Limit how much we consume. - if ops.left() > p.size { - ops.limit(p.size) + if ops.left() > p.view.Size() { + ops.limit(p.view.Size()) } - done := int64(0) - for ops.left() > 0 { - // Pop the first buffer. - first := p.data.Front() - if first == nil { - break - } - - // Copy user data. - n, err := ops.read(first) - done += int64(n) - p.size -= n - - // Empty buffer? - if first.Empty() { - // Push to the free list. - p.data.Remove(first) - bufferPool.Put(first) - } - - // Handle errors. - if err != nil { - return done, err - } - } - - return done, nil -} - -// dup duplicates all data from this pipe into the given writer. -// -// There is no blocking behavior implemented here. The writer may propagate -// some blocking error. All the writes must be complete writes. -func (p *Pipe) dup(ctx context.Context, ops readOps) (int64, error) { - p.mu.Lock() - defer p.mu.Unlock() - - // Is the pipe empty? - if p.size == 0 { - if !p.HasWriters() { - // See above. - return 0, nil - } - return 0, syserror.ErrWouldBlock - } - - // Limit how much we consume. - if ops.left() > p.size { - ops.limit(p.size) - } - - done := int64(0) - for buf := p.data.Front(); buf != nil; buf = buf.Next() { - n, err := ops.read(buf) - done += n - if err != nil { - return done, err - } - } - - return done, nil + // Copy user data; the read op is responsible for trimming. + done, err := ops.read(&p.view) + return done, err } type writeOps struct { @@ -297,7 +235,7 @@ type writeOps struct { limit func(int64) // write should write to the provided buffer. - write func(*buffer) (int64, error) + write func(*buffer.View) (int64, error) } // write writes data from sv into the pipe and returns the number of bytes @@ -317,33 +255,19 @@ func (p *Pipe) write(ctx context.Context, ops writeOps) (int64, error) { // POSIX requires that a write smaller than atomicIOBytes (PIPE_BUF) be // atomic, but requires no atomicity for writes larger than this. wanted := ops.left() - if avail := p.max - p.size; wanted > avail { + if avail := p.max - p.view.Size(); wanted > avail { if wanted <= p.atomicIOBytes { return 0, syserror.ErrWouldBlock } ops.limit(avail) } - done := int64(0) - for ops.left() > 0 { - // Need a new buffer? - last := p.data.Back() - if last == nil || last.Full() { - // Add a new buffer to the data list. - last = newBuffer() - p.data.PushBack(last) - } - - // Copy user data. - n, err := ops.write(last) - done += int64(n) - p.size += n - - // Handle errors. - if err != nil { - return done, err - } + // Copy user data. + done, err := ops.write(&p.view) + if err != nil { + return done, err } + if wanted > done { // Partial write due to full pipe. return done, syserror.ErrWouldBlock @@ -396,7 +320,7 @@ func (p *Pipe) HasWriters() bool { // Precondition: mu must be held. func (p *Pipe) rReadinessLocked() waiter.EventMask { ready := waiter.EventMask(0) - if p.HasReaders() && p.data.Front() != nil { + if p.HasReaders() && p.view.Size() != 0 { ready |= waiter.EventIn } if !p.HasWriters() && p.hadWriter { @@ -422,7 +346,7 @@ func (p *Pipe) rReadiness() waiter.EventMask { // Precondition: mu must be held. func (p *Pipe) wReadinessLocked() waiter.EventMask { ready := waiter.EventMask(0) - if p.HasWriters() && p.size < p.max { + if p.HasWriters() && p.view.Size() < p.max { ready |= waiter.EventOut } if !p.HasReaders() { @@ -451,7 +375,7 @@ func (p *Pipe) rwReadiness() waiter.EventMask { func (p *Pipe) queued() int64 { p.mu.Lock() defer p.mu.Unlock() - return p.size + return p.view.Size() } // FifoSize implements fs.FifoSizer.FifoSize. @@ -474,7 +398,7 @@ func (p *Pipe) SetFifoSize(size int64) (int64, error) { } p.mu.Lock() defer p.mu.Unlock() - if size < p.size { + if size < p.view.Size() { return 0, syserror.EBUSY } p.max = size diff --git a/pkg/sentry/kernel/pipe/pipe_util.go b/pkg/sentry/kernel/pipe/pipe_util.go index 80158239e..5a1d4fd57 100644 --- a/pkg/sentry/kernel/pipe/pipe_util.go +++ b/pkg/sentry/kernel/pipe/pipe_util.go @@ -21,6 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/amutex" + "gvisor.dev/gvisor/pkg/buffer" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sync" @@ -49,9 +50,10 @@ func (p *Pipe) Read(ctx context.Context, dst usermem.IOSequence) (int64, error) limit: func(l int64) { dst = dst.TakeFirst64(l) }, - read: func(buf *buffer) (int64, error) { - n, err := dst.CopyOutFrom(ctx, buf) + read: func(view *buffer.View) (int64, error) { + n, err := dst.CopyOutFrom(ctx, view) dst = dst.DropFirst64(n) + view.TrimFront(n) return n, err }, }) @@ -70,16 +72,15 @@ func (p *Pipe) WriteTo(ctx context.Context, w io.Writer, count int64, dup bool) limit: func(l int64) { count = l }, - read: func(buf *buffer) (int64, error) { - n, err := buf.ReadToWriter(w, count, dup) + read: func(view *buffer.View) (int64, error) { + n, err := view.ReadToWriter(w, count) + if !dup { + view.TrimFront(n) + } count -= n return n, err }, } - if dup { - // There is no notification for dup operations. - return p.dup(ctx, ops) - } n, err := p.read(ctx, ops) if n > 0 { p.Notify(waiter.EventOut) @@ -96,8 +97,8 @@ func (p *Pipe) Write(ctx context.Context, src usermem.IOSequence) (int64, error) limit: func(l int64) { src = src.TakeFirst64(l) }, - write: func(buf *buffer) (int64, error) { - n, err := src.CopyInTo(ctx, buf) + write: func(view *buffer.View) (int64, error) { + n, err := src.CopyInTo(ctx, view) src = src.DropFirst64(n) return n, err }, @@ -117,8 +118,8 @@ func (p *Pipe) ReadFrom(ctx context.Context, r io.Reader, count int64) (int64, e limit: func(l int64) { count = l }, - write: func(buf *buffer) (int64, error) { - n, err := buf.WriteFromReader(r, count) + write: func(view *buffer.View) (int64, error) { + n, err := view.WriteFromReader(r, count) count -= n return n, err }, diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go index 2cee2e6ed..c0dbbe890 100644 --- a/pkg/sentry/kernel/task.go +++ b/pkg/sentry/kernel/task.go @@ -37,6 +37,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) @@ -554,6 +555,13 @@ type Task struct { // // startTime is protected by mu. startTime ktime.Time + + // oomScoreAdj is the task's OOM score adjustment. This is currently not + // used but is maintained for consistency. + // TODO(gvisor.dev/issue/1967) + // + // oomScoreAdj is protected by mu, and is owned by the task goroutine. + oomScoreAdj int32 } func (t *Task) savePtraceTracer() *Task { @@ -847,3 +855,28 @@ func (t *Task) AbstractSockets() *AbstractSocketNamespace { func (t *Task) ContainerID() string { return t.containerID } + +// OOMScoreAdj gets the task's OOM score adjustment. +func (t *Task) OOMScoreAdj() (int32, error) { + t.mu.Lock() + defer t.mu.Unlock() + if t.ExitState() == TaskExitDead { + return 0, syserror.ESRCH + } + return t.oomScoreAdj, nil +} + +// SetOOMScoreAdj sets the task's OOM score adjustment. The value should be +// between -1000 and 1000 inclusive. +func (t *Task) SetOOMScoreAdj(adj int32) error { + t.mu.Lock() + defer t.mu.Unlock() + if t.ExitState() == TaskExitDead { + return syserror.ESRCH + } + if adj > 1000 || adj < -1000 { + return syserror.EINVAL + } + t.oomScoreAdj = adj + return nil +} diff --git a/pkg/sentry/kernel/task_clone.go b/pkg/sentry/kernel/task_clone.go index 78866f280..dda502bb8 100644 --- a/pkg/sentry/kernel/task_clone.go +++ b/pkg/sentry/kernel/task_clone.go @@ -264,6 +264,11 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) { rseqSignature = t.rseqSignature } + adj, err := t.OOMScoreAdj() + if err != nil { + return 0, nil, err + } + cfg := &TaskConfig{ Kernel: t.k, ThreadGroup: tg, @@ -282,6 +287,7 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) { RSeqAddr: rseqAddr, RSeqSignature: rseqSignature, ContainerID: t.ContainerID(), + OOMScoreAdj: adj, } if opts.NewThreadGroup { cfg.Parent = t diff --git a/pkg/sentry/kernel/task_run.go b/pkg/sentry/kernel/task_run.go index 5568c91bc..799cbcd93 100644 --- a/pkg/sentry/kernel/task_run.go +++ b/pkg/sentry/kernel/task_run.go @@ -126,13 +126,39 @@ func (t *Task) doStop() { } } +func (*runApp) handleCPUIDInstruction(t *Task) error { + if len(arch.CPUIDInstruction) == 0 { + // CPUID emulation isn't supported, but this code can be + // executed, because the ptrace platform returns + // ErrContextSignalCPUID on page faults too. Look at + // pkg/sentry/platform/ptrace/ptrace.go:context.Switch for more + // details. + return platform.ErrContextSignal + } + // Is this a CPUID instruction? + region := trace.StartRegion(t.traceContext, cpuidRegion) + expected := arch.CPUIDInstruction[:] + found := make([]byte, len(expected)) + _, err := t.CopyIn(usermem.Addr(t.Arch().IP()), &found) + if err == nil && bytes.Equal(expected, found) { + // Skip the cpuid instruction. + t.Arch().CPUIDEmulate(t) + t.Arch().SetIP(t.Arch().IP() + uintptr(len(expected))) + region.End() + + return nil + } + region.End() // Not an actual CPUID, but required copy-in. + return platform.ErrContextSignal +} + // The runApp state checks for interrupts before executing untrusted // application code. // // +stateify savable type runApp struct{} -func (*runApp) execute(t *Task) taskRunState { +func (app *runApp) execute(t *Task) taskRunState { if t.interrupted() { // Checkpointing instructs tasks to stop by sending an interrupt, so we // must check for stops before entering runInterrupt (instead of @@ -237,21 +263,10 @@ func (*runApp) execute(t *Task) taskRunState { return (*runApp)(nil) case platform.ErrContextSignalCPUID: - // Is this a CPUID instruction? - region := trace.StartRegion(t.traceContext, cpuidRegion) - expected := arch.CPUIDInstruction[:] - found := make([]byte, len(expected)) - _, err := t.CopyIn(usermem.Addr(t.Arch().IP()), &found) - if err == nil && bytes.Equal(expected, found) { - // Skip the cpuid instruction. - t.Arch().CPUIDEmulate(t) - t.Arch().SetIP(t.Arch().IP() + uintptr(len(expected))) - region.End() - + if err := app.handleCPUIDInstruction(t); err == nil { // Resume execution. return (*runApp)(nil) } - region.End() // Not an actual CPUID, but required copy-in. // The instruction at the given RIP was not a CPUID, and we // fallthrough to the default signal deliver behavior below. diff --git a/pkg/sentry/kernel/task_start.go b/pkg/sentry/kernel/task_start.go index a5035bb7f..2bbf48bb8 100644 --- a/pkg/sentry/kernel/task_start.go +++ b/pkg/sentry/kernel/task_start.go @@ -93,6 +93,9 @@ type TaskConfig struct { // ContainerID is the container the new task belongs to. ContainerID string + + // oomScoreAdj is the task's OOM score adjustment. + OOMScoreAdj int32 } // NewTask creates a new task defined by cfg. @@ -143,6 +146,7 @@ func (ts *TaskSet) newTask(cfg *TaskConfig) (*Task, error) { rseqSignature: cfg.RSeqSignature, futexWaiter: futex.NewWaiter(), containerID: cfg.ContainerID, + oomScoreAdj: cfg.OOMScoreAdj, } t.creds.Store(cfg.Credentials) t.endStopCond.L = &t.tg.signalHandlers.mu diff --git a/pkg/sentry/platform/kvm/bluepill.go b/pkg/sentry/platform/kvm/bluepill.go index 35cd55fef..4b23f7803 100644 --- a/pkg/sentry/platform/kvm/bluepill.go +++ b/pkg/sentry/platform/kvm/bluepill.go @@ -81,12 +81,6 @@ func (c *vCPU) die(context *arch.SignalContext64, msg string) { // Save the death message, which will be thrown. c.dieState.message = msg - // Reload all registers to have an accurate stack trace when we return - // to host mode. This means that the stack should be unwound correctly. - if errno := c.getUserRegisters(&c.dieState.guestRegs); errno != 0 { - throw(msg) - } - // Setup the trampoline. dieArchSetup(c, context, &c.dieState.guestRegs) } diff --git a/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go b/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go index a63a6a071..99cac665d 100644 --- a/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go +++ b/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go @@ -31,6 +31,12 @@ import ( // //go:nosplit func dieArchSetup(c *vCPU, context *arch.SignalContext64, guestRegs *userRegs) { + // Reload all registers to have an accurate stack trace when we return + // to host mode. This means that the stack should be unwound correctly. + if errno := c.getUserRegisters(&c.dieState.guestRegs); errno != 0 { + throw(c.dieState.message) + } + // If the vCPU is in user mode, we set the stack to the stored stack // value in the vCPU itself. We don't want to unwind the user stack. if guestRegs.RFLAGS&ring0.UserFlagsSet == ring0.UserFlagsSet { diff --git a/pkg/sentry/platform/kvm/bluepill_arm64.s b/pkg/sentry/platform/kvm/bluepill_arm64.s index c61700892..04efa0147 100644 --- a/pkg/sentry/platform/kvm/bluepill_arm64.s +++ b/pkg/sentry/platform/kvm/bluepill_arm64.s @@ -82,6 +82,8 @@ fallback: // dieTrampoline: see bluepill.go, bluepill_arm64_unsafe.go for documentation. TEXT ·dieTrampoline(SB),NOSPLIT,$0 - // TODO(gvisor.dev/issue/1249): dieTrampoline supporting for Arm64. - MOVD R9, 8(RSP) - BL ·dieHandler(SB) + // R0: Fake the old PC as caller + // R1: First argument (vCPU) + MOVD.P R1, 8(RSP) // R1: First argument (vCPU) + MOVD.P R0, 8(RSP) // R0: Fake the old PC as caller + B ·dieHandler(SB) diff --git a/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go index 2f02c03cf..195331383 100644 --- a/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go +++ b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go @@ -18,9 +18,30 @@ package kvm import ( "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/platform/ring0" ) +// dieArchSetup initialies the state for dieTrampoline. +// +// The arm64 dieTrampoline requires the vCPU to be set in R1, and the last PC +// to be in R0. The trampoline then simulates a call to dieHandler from the +// provided PC. +// //go:nosplit func dieArchSetup(c *vCPU, context *arch.SignalContext64, guestRegs *userRegs) { - // TODO(gvisor.dev/issue/1249): dieTrampoline supporting for Arm64. + // If the vCPU is in user mode, we set the stack to the stored stack + // value in the vCPU itself. We don't want to unwind the user stack. + if guestRegs.Regs.Pstate&ring0.PSR_MODE_MASK == ring0.PSR_MODE_EL0t { + regs := c.CPU.Registers() + context.Regs[0] = regs.Regs[0] + context.Sp = regs.Sp + context.Regs[29] = regs.Regs[29] // stack base address + } else { + context.Regs[0] = guestRegs.Regs.Pc + context.Sp = guestRegs.Regs.Sp + context.Regs[29] = guestRegs.Regs.Regs[29] + context.Pstate = guestRegs.Regs.Pstate + } + context.Regs[1] = uint64(uintptr(unsafe.Pointer(c))) + context.Pc = uint64(dieTrampolineAddr) } diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go index 1c8384e6b..b531f2f85 100644 --- a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go +++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go @@ -29,30 +29,6 @@ import ( "gvisor.dev/gvisor/pkg/usermem" ) -// setMemoryRegion initializes a region. -// -// This may be called from bluepillHandler, and therefore returns an errno -// directly (instead of wrapping in an error) to avoid allocations. -// -//go:nosplit -func (m *machine) setMemoryRegion(slot int, physical, length, virtual uintptr) syscall.Errno { - userRegion := userMemoryRegion{ - slot: uint32(slot), - flags: 0, - guestPhysAddr: uint64(physical), - memorySize: uint64(length), - userspaceAddr: uint64(virtual), - } - - // Set the region. - _, _, errno := syscall.RawSyscall( - syscall.SYS_IOCTL, - uintptr(m.fd), - _KVM_SET_USER_MEMORY_REGION, - uintptr(unsafe.Pointer(&userRegion))) - return errno -} - type kvmVcpuInit struct { target uint32 features [7]uint32 diff --git a/pkg/sentry/platform/ptrace/subprocess_amd64.go b/pkg/sentry/platform/ptrace/subprocess_amd64.go index e99798c56..cd74945e7 100644 --- a/pkg/sentry/platform/ptrace/subprocess_amd64.go +++ b/pkg/sentry/platform/ptrace/subprocess_amd64.go @@ -21,6 +21,7 @@ import ( "strings" "syscall" + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/seccomp" "gvisor.dev/gvisor/pkg/sentry/arch" @@ -183,13 +184,76 @@ func enableCpuidFault() { // appendArchSeccompRules append architecture specific seccomp rules when creating BPF program. // Ref attachedThread() for more detail. -func appendArchSeccompRules(rules []seccomp.RuleSet) []seccomp.RuleSet { - return append(rules, seccomp.RuleSet{ - Rules: seccomp.SyscallRules{ - syscall.SYS_ARCH_PRCTL: []seccomp.Rule{ - {seccomp.AllowValue(linux.ARCH_SET_CPUID), seccomp.AllowValue(0)}, +func appendArchSeccompRules(rules []seccomp.RuleSet, defaultAction linux.BPFAction) []seccomp.RuleSet { + rules = append(rules, + // Rules for trapping vsyscall access. + seccomp.RuleSet{ + Rules: seccomp.SyscallRules{ + syscall.SYS_GETTIMEOFDAY: {}, + syscall.SYS_TIME: {}, + unix.SYS_GETCPU: {}, // SYS_GETCPU was not defined in package syscall on amd64. }, - }, - Action: linux.SECCOMP_RET_ALLOW, - }) + Action: linux.SECCOMP_RET_TRAP, + Vsyscall: true, + }) + if defaultAction != linux.SECCOMP_RET_ALLOW { + rules = append(rules, + seccomp.RuleSet{ + Rules: seccomp.SyscallRules{ + syscall.SYS_ARCH_PRCTL: []seccomp.Rule{ + {seccomp.AllowValue(linux.ARCH_SET_CPUID), seccomp.AllowValue(0)}, + }, + }, + Action: linux.SECCOMP_RET_ALLOW, + }) + } + return rules +} + +// probeSeccomp returns true iff seccomp is run after ptrace notifications, +// which is generally the case for kernel version >= 4.8. This check is dynamic +// because kernels have be backported behavior. +// +// See createStub for more information. +// +// Precondition: the runtime OS thread must be locked. +func probeSeccomp() bool { + // Create a completely new, destroyable process. + t, err := attachedThread(0, linux.SECCOMP_RET_ERRNO) + if err != nil { + panic(fmt.Sprintf("seccomp probe failed: %v", err)) + } + defer t.destroy() + + // Set registers to the yield system call. This call is not allowed + // by the filters specified in the attachThread function. + regs := createSyscallRegs(&t.initRegs, syscall.SYS_SCHED_YIELD) + if err := t.setRegs(®s); err != nil { + panic(fmt.Sprintf("ptrace set regs failed: %v", err)) + } + + for { + // Attempt an emulation. + if _, _, errno := syscall.RawSyscall6(syscall.SYS_PTRACE, unix.PTRACE_SYSEMU, uintptr(t.tid), 0, 0, 0, 0); errno != 0 { + panic(fmt.Sprintf("ptrace syscall-enter failed: %v", errno)) + } + + sig := t.wait(stopped) + if sig == (syscallEvent | syscall.SIGTRAP) { + // Did the seccomp errno hook already run? This would + // indicate that seccomp is first in line and we're + // less than 4.8. + if err := t.getRegs(®s); err != nil { + panic(fmt.Sprintf("ptrace get-regs failed: %v", err)) + } + if _, err := syscallReturnValue(®s); err == nil { + // The seccomp errno mode ran first, and reset + // the error in the registers. + return false + } + // The seccomp hook did not run yet, and therefore it + // is safe to use RET_KILL mode for dispatched calls. + return true + } + } } diff --git a/pkg/sentry/platform/ptrace/subprocess_arm64.go b/pkg/sentry/platform/ptrace/subprocess_arm64.go index 7b975137f..7f5c393f0 100644 --- a/pkg/sentry/platform/ptrace/subprocess_arm64.go +++ b/pkg/sentry/platform/ptrace/subprocess_arm64.go @@ -160,6 +160,15 @@ func enableCpuidFault() { // appendArchSeccompRules append architecture specific seccomp rules when creating BPF program. // Ref attachedThread() for more detail. -func appendArchSeccompRules(rules []seccomp.RuleSet) []seccomp.RuleSet { +func appendArchSeccompRules(rules []seccomp.RuleSet, defaultAction linux.BPFAction) []seccomp.RuleSet { return rules } + +// probeSeccomp returns true if seccomp is run after ptrace notifications, +// which is generally the case for kernel version >= 4.8. +// +// On arm64, the support of PTRACE_SYSEMU was added in the 5.3 kernel, so +// probeSeccomp can always return true. +func probeSeccomp() bool { + return true +} diff --git a/pkg/sentry/platform/ptrace/subprocess_linux.go b/pkg/sentry/platform/ptrace/subprocess_linux.go index 74968dfdf..2ce528601 100644 --- a/pkg/sentry/platform/ptrace/subprocess_linux.go +++ b/pkg/sentry/platform/ptrace/subprocess_linux.go @@ -20,7 +20,6 @@ import ( "fmt" "syscall" - "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/procid" @@ -30,54 +29,6 @@ import ( const syscallEvent syscall.Signal = 0x80 -// probeSeccomp returns true iff seccomp is run after ptrace notifications, -// which is generally the case for kernel version >= 4.8. This check is dynamic -// because kernels have be backported behavior. -// -// See createStub for more information. -// -// Precondition: the runtime OS thread must be locked. -func probeSeccomp() bool { - // Create a completely new, destroyable process. - t, err := attachedThread(0, linux.SECCOMP_RET_ERRNO) - if err != nil { - panic(fmt.Sprintf("seccomp probe failed: %v", err)) - } - defer t.destroy() - - // Set registers to the yield system call. This call is not allowed - // by the filters specified in the attachThread function. - regs := createSyscallRegs(&t.initRegs, syscall.SYS_SCHED_YIELD) - if err := t.setRegs(®s); err != nil { - panic(fmt.Sprintf("ptrace set regs failed: %v", err)) - } - - for { - // Attempt an emulation. - if _, _, errno := syscall.RawSyscall6(syscall.SYS_PTRACE, unix.PTRACE_SYSEMU, uintptr(t.tid), 0, 0, 0, 0); errno != 0 { - panic(fmt.Sprintf("ptrace syscall-enter failed: %v", errno)) - } - - sig := t.wait(stopped) - if sig == (syscallEvent | syscall.SIGTRAP) { - // Did the seccomp errno hook already run? This would - // indicate that seccomp is first in line and we're - // less than 4.8. - if err := t.getRegs(®s); err != nil { - panic(fmt.Sprintf("ptrace get-regs failed: %v", err)) - } - if _, err := syscallReturnValue(®s); err == nil { - // The seccomp errno mode ran first, and reset - // the error in the registers. - return false - } - // The seccomp hook did not run yet, and therefore it - // is safe to use RET_KILL mode for dispatched calls. - return true - } - } -} - // createStub creates a fresh stub processes. // // Precondition: the runtime OS thread must be locked. @@ -123,18 +74,7 @@ func attachedThread(flags uintptr, defaultAction linux.BPFAction) (*thread, erro // stub and all its children. This is used to create child stubs // (below), so we must include the ability to fork, but otherwise lock // down available calls only to what is needed. - rules := []seccomp.RuleSet{ - // Rules for trapping vsyscall access. - { - Rules: seccomp.SyscallRules{ - syscall.SYS_GETTIMEOFDAY: {}, - syscall.SYS_TIME: {}, - unix.SYS_GETCPU: {}, // SYS_GETCPU was not defined in package syscall on amd64. - }, - Action: linux.SECCOMP_RET_TRAP, - Vsyscall: true, - }, - } + rules := []seccomp.RuleSet{} if defaultAction != linux.SECCOMP_RET_ALLOW { rules = append(rules, seccomp.RuleSet{ Rules: seccomp.SyscallRules{ @@ -173,9 +113,8 @@ func attachedThread(flags uintptr, defaultAction linux.BPFAction) (*thread, erro }, Action: linux.SECCOMP_RET_ALLOW, }) - - rules = appendArchSeccompRules(rules) } + rules = appendArchSeccompRules(rules, defaultAction) instrs, err := seccomp.BuildProgram(rules, defaultAction) if err != nil { return nil, err diff --git a/pkg/sentry/platform/ring0/aarch64.go b/pkg/sentry/platform/ring0/aarch64.go index f6da41c27..8122ac6e2 100644 --- a/pkg/sentry/platform/ring0/aarch64.go +++ b/pkg/sentry/platform/ring0/aarch64.go @@ -27,26 +27,27 @@ const ( _PTE_PGT_BASE = 0x7000 _PTE_PGT_SIZE = 0x1000 - _PSR_MODE_EL0t = 0x0 - _PSR_MODE_EL1t = 0x4 - _PSR_MODE_EL1h = 0x5 - _PSR_EL_MASK = 0xf - - _PSR_D_BIT = 0x200 - _PSR_A_BIT = 0x100 - _PSR_I_BIT = 0x80 - _PSR_F_BIT = 0x40 + _PSR_D_BIT = 0x00000200 + _PSR_A_BIT = 0x00000100 + _PSR_I_BIT = 0x00000080 + _PSR_F_BIT = 0x00000040 ) const ( + // PSR bits + PSR_MODE_EL0t = 0x00000000 + PSR_MODE_EL1t = 0x00000004 + PSR_MODE_EL1h = 0x00000005 + PSR_MODE_MASK = 0x0000000f + // KernelFlagsSet should always be set in the kernel. - KernelFlagsSet = _PSR_MODE_EL1h + KernelFlagsSet = PSR_MODE_EL1h // UserFlagsSet are always set in userspace. - UserFlagsSet = _PSR_MODE_EL0t + UserFlagsSet = PSR_MODE_EL0t - KernelFlagsClear = _PSR_EL_MASK - UserFlagsClear = _PSR_EL_MASK + KernelFlagsClear = PSR_MODE_MASK + UserFlagsClear = PSR_MODE_MASK PsrDefaultSet = _PSR_D_BIT | _PSR_A_BIT | _PSR_I_BIT | _PSR_F_BIT ) diff --git a/pkg/sentry/platform/ring0/pagetables/BUILD b/pkg/sentry/platform/ring0/pagetables/BUILD index 4f2406ce3..581841555 100644 --- a/pkg/sentry/platform/ring0/pagetables/BUILD +++ b/pkg/sentry/platform/ring0/pagetables/BUILD @@ -80,7 +80,7 @@ go_library( "pagetables_amd64.go", "pagetables_arm64.go", "pagetables_x86.go", - "pcids_x86.go", + "pcids.go", "walker_amd64.go", "walker_arm64.go", "walker_empty.go", diff --git a/pkg/sentry/platform/ring0/pagetables/pcids_x86.go b/pkg/sentry/platform/ring0/pagetables/pcids.go index e199bae18..9206030bf 100644 --- a/pkg/sentry/platform/ring0/pagetables/pcids_x86.go +++ b/pkg/sentry/platform/ring0/pagetables/pcids.go @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -// +build i386 amd64 - package pagetables import ( diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index e187276c5..13a9a60b4 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -712,14 +712,44 @@ func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking boo // Bind implements the linux syscall bind(2) for sockets backed by // tcpip.Endpoint. func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { - addr, family, err := AddressAndFamily(sockaddr) - if err != nil { - return err + if len(sockaddr) < 2 { + return syserr.ErrInvalidArgument } - if err := s.checkFamily(family, true /* exact */); err != nil { - return err + + family := usermem.ByteOrder.Uint16(sockaddr) + var addr tcpip.FullAddress + + // Bind for AF_PACKET requires only family, protocol and ifindex. + // In function AddressAndFamily, we check the address length which is + // not needed for AF_PACKET bind. + if family == linux.AF_PACKET { + var a linux.SockAddrLink + if len(sockaddr) < sockAddrLinkSize { + return syserr.ErrInvalidArgument + } + binary.Unmarshal(sockaddr[:sockAddrLinkSize], usermem.ByteOrder, &a) + + if a.Protocol != uint16(s.protocol) { + return syserr.ErrInvalidArgument + } + + addr = tcpip.FullAddress{ + NIC: tcpip.NICID(a.InterfaceIndex), + Addr: tcpip.Address(a.HardwareAddr[:header.EthernetAddressSize]), + } + } else { + var err *syserr.Error + addr, family, err = AddressAndFamily(sockaddr) + if err != nil { + return err + } + + if err = s.checkFamily(family, true /* exact */); err != nil { + return err + } + + addr = s.mapFamily(addr, family) } - addr = s.mapFamily(addr, family) // Issue the bind request to the endpoint. return syserr.TranslateNetstackError(s.Endpoint.Bind(addr)) @@ -2637,7 +2667,9 @@ func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, } // Add bytes removed from the endpoint but not yet sent to the caller. + s.readMu.Lock() v += len(s.readView) + s.readMu.Unlock() if v > math.MaxInt32 { v = math.MaxInt32 diff --git a/pkg/sentry/syscalls/linux/BUILD b/pkg/sentry/syscalls/linux/BUILD index c7883e68e..0d24fd3c4 100644 --- a/pkg/sentry/syscalls/linux/BUILD +++ b/pkg/sentry/syscalls/linux/BUILD @@ -42,6 +42,8 @@ go_library( "sys_socket.go", "sys_splice.go", "sys_stat.go", + "sys_stat_amd64.go", + "sys_stat_arm64.go", "sys_sync.go", "sys_sysinfo.go", "sys_syslog.go", diff --git a/pkg/sentry/syscalls/linux/sys_file.go b/pkg/sentry/syscalls/linux/sys_file.go index c21f14dc0..d10a9bed8 100644 --- a/pkg/sentry/syscalls/linux/sys_file.go +++ b/pkg/sentry/syscalls/linux/sys_file.go @@ -1236,7 +1236,7 @@ func rmdirAt(t *kernel.Task, dirFD int32, addr usermem.Addr) error { return syserror.ENOTEMPTY } - if err := fs.MayDelete(t, root, d, name); err != nil { + if err := d.MayDelete(t, root, name); err != nil { return err } @@ -1517,7 +1517,7 @@ func unlinkAt(t *kernel.Task, dirFD int32, addr usermem.Addr) error { return syserror.ENOTDIR } - if err := fs.MayDelete(t, root, d, name); err != nil { + if err := d.MayDelete(t, root, name); err != nil { return err } diff --git a/pkg/sentry/syscalls/linux/sys_stat.go b/pkg/sentry/syscalls/linux/sys_stat.go index 701b27b4a..9bd2df104 100644 --- a/pkg/sentry/syscalls/linux/sys_stat.go +++ b/pkg/sentry/syscalls/linux/sys_stat.go @@ -25,24 +25,6 @@ import ( // LINT.IfChange -func statFromAttrs(t *kernel.Task, sattr fs.StableAttr, uattr fs.UnstableAttr) linux.Stat { - return linux.Stat{ - Dev: sattr.DeviceID, - Ino: sattr.InodeID, - Nlink: uattr.Links, - Mode: sattr.Type.LinuxType() | uint32(uattr.Perms.LinuxMode()), - UID: uint32(uattr.Owner.UID.In(t.UserNamespace()).OrOverflow()), - GID: uint32(uattr.Owner.GID.In(t.UserNamespace()).OrOverflow()), - Rdev: uint64(linux.MakeDeviceID(sattr.DeviceFileMajor, sattr.DeviceFileMinor)), - Size: uattr.Size, - Blksize: sattr.BlockSize, - Blocks: uattr.Usage / 512, - ATime: uattr.AccessTime.Timespec(), - MTime: uattr.ModificationTime.Timespec(), - CTime: uattr.StatusChangeTime.Timespec(), - } -} - // Stat implements linux syscall stat(2). func Stat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { addr := args[0].Pointer() diff --git a/pkg/sentry/syscalls/linux/sys_stat_amd64.go b/pkg/sentry/syscalls/linux/sys_stat_amd64.go new file mode 100644 index 000000000..0a04a6113 --- /dev/null +++ b/pkg/sentry/syscalls/linux/sys_stat_amd64.go @@ -0,0 +1,45 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build amd64 + +package linux + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/fs" + "gvisor.dev/gvisor/pkg/sentry/kernel" +) + +// LINT.IfChange + +func statFromAttrs(t *kernel.Task, sattr fs.StableAttr, uattr fs.UnstableAttr) linux.Stat { + return linux.Stat{ + Dev: sattr.DeviceID, + Ino: sattr.InodeID, + Nlink: uattr.Links, + Mode: sattr.Type.LinuxType() | uint32(uattr.Perms.LinuxMode()), + UID: uint32(uattr.Owner.UID.In(t.UserNamespace()).OrOverflow()), + GID: uint32(uattr.Owner.GID.In(t.UserNamespace()).OrOverflow()), + Rdev: uint64(linux.MakeDeviceID(sattr.DeviceFileMajor, sattr.DeviceFileMinor)), + Size: uattr.Size, + Blksize: sattr.BlockSize, + Blocks: uattr.Usage / 512, + ATime: uattr.AccessTime.Timespec(), + MTime: uattr.ModificationTime.Timespec(), + CTime: uattr.StatusChangeTime.Timespec(), + } +} + +// LINT.ThenChange(vfs2/stat_amd64.go) diff --git a/pkg/sentry/syscalls/linux/sys_stat_arm64.go b/pkg/sentry/syscalls/linux/sys_stat_arm64.go new file mode 100644 index 000000000..5a3b1bfad --- /dev/null +++ b/pkg/sentry/syscalls/linux/sys_stat_arm64.go @@ -0,0 +1,45 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +package linux + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/fs" + "gvisor.dev/gvisor/pkg/sentry/kernel" +) + +// LINT.IfChange + +func statFromAttrs(t *kernel.Task, sattr fs.StableAttr, uattr fs.UnstableAttr) linux.Stat { + return linux.Stat{ + Dev: sattr.DeviceID, + Ino: sattr.InodeID, + Nlink: uint32(uattr.Links), + Mode: sattr.Type.LinuxType() | uint32(uattr.Perms.LinuxMode()), + UID: uint32(uattr.Owner.UID.In(t.UserNamespace()).OrOverflow()), + GID: uint32(uattr.Owner.GID.In(t.UserNamespace()).OrOverflow()), + Rdev: uint64(linux.MakeDeviceID(sattr.DeviceFileMajor, sattr.DeviceFileMinor)), + Size: uattr.Size, + Blksize: int32(sattr.BlockSize), + Blocks: uattr.Usage / 512, + ATime: uattr.AccessTime.Timespec(), + MTime: uattr.ModificationTime.Timespec(), + CTime: uattr.StatusChangeTime.Timespec(), + } +} + +// LINT.ThenChange(vfs2/stat_arm64.go) diff --git a/pkg/sentry/syscalls/linux/vfs2/BUILD b/pkg/sentry/syscalls/linux/vfs2/BUILD index f51761e81..e7695e995 100644 --- a/pkg/sentry/syscalls/linux/vfs2/BUILD +++ b/pkg/sentry/syscalls/linux/vfs2/BUILD @@ -22,6 +22,8 @@ go_library( "read_write.go", "setstat.go", "stat.go", + "stat_amd64.go", + "stat_arm64.go", "sync.go", "xattr.go", ], diff --git a/pkg/sentry/syscalls/linux/vfs2/stat.go b/pkg/sentry/syscalls/linux/vfs2/stat.go index dca8d7011..12c532310 100644 --- a/pkg/sentry/syscalls/linux/vfs2/stat.go +++ b/pkg/sentry/syscalls/linux/vfs2/stat.go @@ -113,29 +113,6 @@ func fstatat(t *kernel.Task, dirfd int32, pathAddr, statAddr usermem.Addr, flags return stat.CopyOut(t, statAddr) } -// This takes both input and output as pointer arguments to avoid copying large -// structs. -func convertStatxToUserStat(t *kernel.Task, statx *linux.Statx, stat *linux.Stat) { - // Linux just copies fields from struct kstat without regard to struct - // kstat::result_mask (fs/stat.c:cp_new_stat()), so we do too. - userns := t.UserNamespace() - *stat = linux.Stat{ - Dev: uint64(linux.MakeDeviceID(uint16(statx.DevMajor), statx.DevMinor)), - Ino: statx.Ino, - Nlink: uint64(statx.Nlink), - Mode: uint32(statx.Mode), - UID: uint32(auth.KUID(statx.UID).In(userns).OrOverflow()), - GID: uint32(auth.KGID(statx.GID).In(userns).OrOverflow()), - Rdev: uint64(linux.MakeDeviceID(uint16(statx.RdevMajor), statx.RdevMinor)), - Size: int64(statx.Size), - Blksize: int64(statx.Blksize), - Blocks: int64(statx.Blocks), - ATime: timespecFromStatxTimestamp(statx.Atime), - MTime: timespecFromStatxTimestamp(statx.Mtime), - CTime: timespecFromStatxTimestamp(statx.Ctime), - } -} - func timespecFromStatxTimestamp(sxts linux.StatxTimestamp) linux.Timespec { return linux.Timespec{ Sec: sxts.Sec, diff --git a/pkg/sentry/syscalls/linux/vfs2/stat_amd64.go b/pkg/sentry/syscalls/linux/vfs2/stat_amd64.go new file mode 100644 index 000000000..2da538fc6 --- /dev/null +++ b/pkg/sentry/syscalls/linux/vfs2/stat_amd64.go @@ -0,0 +1,46 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build amd64 + +package vfs2 + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" +) + +// This takes both input and output as pointer arguments to avoid copying large +// structs. +func convertStatxToUserStat(t *kernel.Task, statx *linux.Statx, stat *linux.Stat) { + // Linux just copies fields from struct kstat without regard to struct + // kstat::result_mask (fs/stat.c:cp_new_stat()), so we do too. + userns := t.UserNamespace() + *stat = linux.Stat{ + Dev: uint64(linux.MakeDeviceID(uint16(statx.DevMajor), statx.DevMinor)), + Ino: statx.Ino, + Nlink: uint64(statx.Nlink), + Mode: uint32(statx.Mode), + UID: uint32(auth.KUID(statx.UID).In(userns).OrOverflow()), + GID: uint32(auth.KGID(statx.GID).In(userns).OrOverflow()), + Rdev: uint64(linux.MakeDeviceID(uint16(statx.RdevMajor), statx.RdevMinor)), + Size: int64(statx.Size), + Blksize: int64(statx.Blksize), + Blocks: int64(statx.Blocks), + ATime: timespecFromStatxTimestamp(statx.Atime), + MTime: timespecFromStatxTimestamp(statx.Mtime), + CTime: timespecFromStatxTimestamp(statx.Ctime), + } +} diff --git a/pkg/sentry/syscalls/linux/vfs2/stat_arm64.go b/pkg/sentry/syscalls/linux/vfs2/stat_arm64.go new file mode 100644 index 000000000..88b9c7627 --- /dev/null +++ b/pkg/sentry/syscalls/linux/vfs2/stat_arm64.go @@ -0,0 +1,46 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +package vfs2 + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" +) + +// This takes both input and output as pointer arguments to avoid copying large +// structs. +func convertStatxToUserStat(t *kernel.Task, statx *linux.Statx, stat *linux.Stat) { + // Linux just copies fields from struct kstat without regard to struct + // kstat::result_mask (fs/stat.c:cp_new_stat()), so we do too. + userns := t.UserNamespace() + *stat = linux.Stat{ + Dev: uint64(linux.MakeDeviceID(uint16(statx.DevMajor), statx.DevMinor)), + Ino: statx.Ino, + Nlink: uint32(statx.Nlink), + Mode: uint32(statx.Mode), + UID: uint32(auth.KUID(statx.UID).In(userns).OrOverflow()), + GID: uint32(auth.KGID(statx.GID).In(userns).OrOverflow()), + Rdev: uint64(linux.MakeDeviceID(uint16(statx.RdevMajor), statx.RdevMinor)), + Size: int64(statx.Size), + Blksize: int32(statx.Blksize), + Blocks: int64(statx.Blocks), + ATime: timespecFromStatxTimestamp(statx.Atime), + MTime: timespecFromStatxTimestamp(statx.Mtime), + CTime: timespecFromStatxTimestamp(statx.Ctime), + } +} diff --git a/pkg/sentry/vfs/mount.go b/pkg/sentry/vfs/mount.go index 9912df799..31a4e5480 100644 --- a/pkg/sentry/vfs/mount.go +++ b/pkg/sentry/vfs/mount.go @@ -139,6 +139,23 @@ func (vfs *VirtualFilesystem) NewMountNamespace(ctx context.Context, creds *auth return mntns, nil } +// NewDisconnectedMount returns a Mount representing fs with the given root +// (which may be nil). The new Mount is not associated with any MountNamespace +// and is not connected to any other Mounts. References are taken on fs and +// root. +func (vfs *VirtualFilesystem) NewDisconnectedMount(fs *Filesystem, root *Dentry, opts *MountOptions) (*Mount, error) { + fs.IncRef() + if root != nil { + root.IncRef() + } + return &Mount{ + vfs: vfs, + fs: fs, + root: root, + refs: 1, + }, nil +} + // MountAt creates and mounts a Filesystem configured by the given arguments. func (vfs *VirtualFilesystem) MountAt(ctx context.Context, creds *auth.Credentials, source string, target *PathOperation, fsTypeName string, opts *MountOptions) error { rft := vfs.getFilesystemType(fsTypeName) diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go index 73f8043be..bde81e1ef 100644 --- a/pkg/sentry/vfs/vfs.go +++ b/pkg/sentry/vfs/vfs.go @@ -126,17 +126,23 @@ func (vfs *VirtualFilesystem) Init() error { // Construct vfs.anonMount. anonfsDevMinor, err := vfs.GetAnonBlockDevMinor() if err != nil { - return err + // This shouldn't be possible since anonBlockDevMinorNext was + // initialized to 1 above (no device numbers have been allocated yet). + panic(fmt.Sprintf("VirtualFilesystem.Init: device number allocation for anonfs failed: %v", err)) } anonfs := anonFilesystem{ devMinor: anonfsDevMinor, } anonfs.vfsfs.Init(vfs, &anonfs) - vfs.anonMount = &Mount{ - vfs: vfs, - fs: &anonfs.vfsfs, - refs: 1, + defer anonfs.vfsfs.DecRef() + anonMount, err := vfs.NewDisconnectedMount(&anonfs.vfsfs, nil, &MountOptions{}) + if err != nil { + // We should not be passing any MountOptions that would cause + // construction of this mount to fail. + panic(fmt.Sprintf("VirtualFilesystem.Init: anonfs mount failed: %v", err)) } + vfs.anonMount = anonMount + return nil } diff --git a/pkg/sentry/watchdog/watchdog.go b/pkg/sentry/watchdog/watchdog.go index bfb2fac26..f7d6009a0 100644 --- a/pkg/sentry/watchdog/watchdog.go +++ b/pkg/sentry/watchdog/watchdog.go @@ -221,7 +221,7 @@ func (w *Watchdog) waitForStart() { return } var buf bytes.Buffer - buf.WriteString("Watchdog.Start() not called within %s:\n") + buf.WriteString(fmt.Sprintf("Watchdog.Start() not called within %s", w.StartupTimeout)) w.doAction(w.StartupTimeoutAction, false, &buf) } @@ -325,7 +325,7 @@ func (w *Watchdog) report(offenders map[*kernel.Task]*offender, newTaskFound boo func (w *Watchdog) reportStuckWatchdog() { var buf bytes.Buffer - buf.WriteString("Watchdog goroutine is stuck:\n") + buf.WriteString("Watchdog goroutine is stuck:") w.doAction(w.TaskTimeoutAction, false, &buf) } @@ -359,7 +359,7 @@ func (w *Watchdog) doAction(action Action, skipStack bool, msg *bytes.Buffer) { case <-metricsEmitted: case <-time.After(1 * time.Second): } - panic(fmt.Sprintf("Stack for running G's are skipped while panicking.\n%s", msg.String())) + panic(fmt.Sprintf("%s\nStack for running G's are skipped while panicking.", msg.String())) default: panic(fmt.Sprintf("Unknown watchdog action %v", action)) diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index c6c160dfc..8dc0f7c0e 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -785,6 +785,52 @@ func NDPNSTargetAddress(want tcpip.Address) TransportChecker { } } +// ndpOptions checks that optsBuf only contains opts. +func ndpOptions(t *testing.T, optsBuf header.NDPOptions, opts []header.NDPOption) { + t.Helper() + + it, err := optsBuf.Iter(true) + if err != nil { + t.Errorf("optsBuf.Iter(true): %s", err) + return + } + + i := 0 + for { + opt, done, err := it.Next() + if err != nil { + // This should never happen as Iter(true) above did not return an error. + t.Fatalf("unexpected error when iterating over NDP options: %s", err) + } + if done { + break + } + + if i >= len(opts) { + t.Errorf("got unexpected option: %s", opt) + continue + } + + switch wantOpt := opts[i].(type) { + case header.NDPSourceLinkLayerAddressOption: + gotOpt, ok := opt.(header.NDPSourceLinkLayerAddressOption) + if !ok { + t.Errorf("got type = %T at index = %d; want = %T", opt, i, wantOpt) + } else if got, want := gotOpt.EthernetAddress(), wantOpt.EthernetAddress(); got != want { + t.Errorf("got EthernetAddress() = %s at index %d, want = %s", got, i, want) + } + default: + t.Fatalf("checker not implemented for expected NDP option: %T", wantOpt) + } + + i++ + } + + if missing := opts[i:]; len(missing) > 0 { + t.Errorf("missing options: %s", missing) + } +} + // NDPNSOptions creates a checker that checks that the packet contains the // provided NDP options within an NDP Neighbor Solicitation message. // @@ -796,47 +842,31 @@ func NDPNSOptions(opts []header.NDPOption) TransportChecker { icmp := h.(header.ICMPv6) ns := header.NDPNeighborSolicit(icmp.NDPPayload()) - it, err := ns.Options().Iter(true) - if err != nil { - t.Errorf("opts.Iter(true): %s", err) - return - } - - i := 0 - for { - opt, done, _ := it.Next() - if done { - break - } - - if i >= len(opts) { - t.Errorf("got unexpected option: %s", opt) - continue - } - - switch wantOpt := opts[i].(type) { - case header.NDPSourceLinkLayerAddressOption: - gotOpt, ok := opt.(header.NDPSourceLinkLayerAddressOption) - if !ok { - t.Errorf("got type = %T at index = %d; want = %T", opt, i, wantOpt) - } else if got, want := gotOpt.EthernetAddress(), wantOpt.EthernetAddress(); got != want { - t.Errorf("got EthernetAddress() = %s at index %d, want = %s", got, i, want) - } - default: - panic("not implemented") - } - - i++ - } - - if missing := opts[i:]; len(missing) > 0 { - t.Errorf("missing options: %s", missing) - } + ndpOptions(t, ns.Options(), opts) } } // NDPRS creates a checker that checks that the packet contains a valid NDP // Router Solicitation message (as per the raw wire format). -func NDPRS() NetworkChecker { - return NDP(header.ICMPv6RouterSolicit, header.NDPRSMinimumSize) +// +// checkers may assume that a valid ICMPv6 is passed to it containing a valid +// NDPRS as far as the size of the message is concerned. The values within the +// message are up to checkers to validate. +func NDPRS(checkers ...TransportChecker) NetworkChecker { + return NDP(header.ICMPv6RouterSolicit, header.NDPRSMinimumSize, checkers...) +} + +// NDPRSOptions creates a checker that checks that the packet contains the +// provided NDP options within an NDP Router Solicitation message. +// +// The returned TransportChecker assumes that a valid ICMPv6 is passed to it +// containing a valid NDPRS message as far as the size is concerned. +func NDPRSOptions(opts []header.NDPOption) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + icmp := h.(header.ICMPv6) + rs := header.NDPRouterSolicit(icmp.NDPPayload()) + ndpOptions(t, rs.Options(), opts) + } } diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 705cf01ee..8febd54c8 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -18,6 +18,7 @@ go_template_instance( go_library( name = "stack", srcs = [ + "dhcpv6configurationfromndpra_string.go", "icmp_rate_limit.go", "linkaddrcache.go", "linkaddrentry_list.go", diff --git a/pkg/tcpip/stack/dhcpv6configurationfromndpra_string.go b/pkg/tcpip/stack/dhcpv6configurationfromndpra_string.go new file mode 100644 index 000000000..8b4213eec --- /dev/null +++ b/pkg/tcpip/stack/dhcpv6configurationfromndpra_string.go @@ -0,0 +1,39 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by "stringer -type=DHCPv6ConfigurationFromNDPRA"; DO NOT EDIT. + +package stack + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[DHCPv6NoConfiguration-0] + _ = x[DHCPv6ManagedAddress-1] + _ = x[DHCPv6OtherConfigurations-2] +} + +const _DHCPv6ConfigurationFromNDPRA_name = "DHCPv6NoConfigurationDHCPv6ManagedAddressDHCPv6OtherConfigurations" + +var _DHCPv6ConfigurationFromNDPRA_index = [...]uint8{0, 21, 41, 66} + +func (i DHCPv6ConfigurationFromNDPRA) String() string { + if i < 0 || i >= DHCPv6ConfigurationFromNDPRA(len(_DHCPv6ConfigurationFromNDPRA_index)-1) { + return "DHCPv6ConfigurationFromNDPRA(" + strconv.FormatInt(int64(i), 10) + ")" + } + return _DHCPv6ConfigurationFromNDPRA_name[_DHCPv6ConfigurationFromNDPRA_index[i]:_DHCPv6ConfigurationFromNDPRA_index[i+1]] +} diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index f651871ce..a9f4d5dad 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -1220,9 +1220,15 @@ func (ndp *ndpState) startSolicitingRouters() { } ndp.rtrSolicitTimer = time.AfterFunc(delay, func() { - // Send an RS message with the unspecified source address. - ref := ndp.nic.getRefOrCreateTemp(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint, forceSpoofing) - r := makeRoute(header.IPv6ProtocolNumber, header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.linkEP.LinkAddress(), ref, false, false) + // As per RFC 4861 section 4.1, the source of the RS is an address assigned + // to the sending interface, or the unspecified address if no address is + // assigned to the sending interface. + ref := ndp.nic.primaryIPv6Endpoint(header.IPv6AllRoutersMulticastAddress) + if ref == nil { + ref = ndp.nic.getRefOrCreateTemp(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint, forceSpoofing) + } + localAddr := ref.ep.ID().LocalAddress + r := makeRoute(header.IPv6ProtocolNumber, localAddr, header.IPv6AllRoutersMulticastAddress, ndp.nic.linkEP.LinkAddress(), ref, false, false) defer r.Release() // Route should resolve immediately since @@ -1234,10 +1240,25 @@ func (ndp *ndpState) startSolicitingRouters() { log.Fatalf("ndp: route resolution not immediate for route to send NDP RS (%s -> %s on NIC(%d))", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.ID()) } - payloadSize := header.ICMPv6HeaderSize + header.NDPRSMinimumSize + // As per RFC 4861 section 4.1, an NDP RS SHOULD include the source + // link-layer address option if the source address of the NDP RS is + // specified. This option MUST NOT be included if the source address is + // unspecified. + // + // TODO(b/141011931): Validate a LinkEndpoint's link address (provided by + // LinkEndpoint.LinkAddress) before reaching this point. + var optsSerializer header.NDPOptionsSerializer + if localAddr != header.IPv6Any && header.IsValidUnicastEthernetAddress(r.LocalLinkAddress) { + optsSerializer = header.NDPOptionsSerializer{ + header.NDPSourceLinkLayerAddressOption(r.LocalLinkAddress), + } + } + payloadSize := header.ICMPv6HeaderSize + header.NDPRSMinimumSize + int(optsSerializer.Length()) hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + payloadSize) pkt := header.ICMPv6(hdr.Prepend(payloadSize)) pkt.SetType(header.ICMPv6RouterSolicit) + rs := header.NDPRouterSolicit(pkt.NDPPayload()) + rs.Options().Serialize(optsSerializer) pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) sent := r.Stats().ICMP.V6PacketsSent diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 6e9306d09..98b1c807c 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -3384,6 +3384,10 @@ func TestRouterSolicitation(t *testing.T) { tests := []struct { name string linkHeaderLen uint16 + linkAddr tcpip.LinkAddress + nicAddr tcpip.Address + expectedSrcAddr tcpip.Address + expectedNDPOpts []header.NDPOption maxRtrSolicit uint8 rtrSolicitInt time.Duration effectiveRtrSolicitInt time.Duration @@ -3392,6 +3396,7 @@ func TestRouterSolicitation(t *testing.T) { }{ { name: "Single RS with delay", + expectedSrcAddr: header.IPv6Any, maxRtrSolicit: 1, rtrSolicitInt: time.Second, effectiveRtrSolicitInt: time.Second, @@ -3401,6 +3406,8 @@ func TestRouterSolicitation(t *testing.T) { { name: "Two RS with delay", linkHeaderLen: 1, + nicAddr: llAddr1, + expectedSrcAddr: llAddr1, maxRtrSolicit: 2, rtrSolicitInt: time.Second, effectiveRtrSolicitInt: time.Second, @@ -3408,8 +3415,14 @@ func TestRouterSolicitation(t *testing.T) { effectiveMaxRtrSolicitDelay: 500 * time.Millisecond, }, { - name: "Single RS without delay", - linkHeaderLen: 2, + name: "Single RS without delay", + linkHeaderLen: 2, + linkAddr: linkAddr1, + nicAddr: llAddr1, + expectedSrcAddr: llAddr1, + expectedNDPOpts: []header.NDPOption{ + header.NDPSourceLinkLayerAddressOption(linkAddr1), + }, maxRtrSolicit: 1, rtrSolicitInt: time.Second, effectiveRtrSolicitInt: time.Second, @@ -3419,6 +3432,8 @@ func TestRouterSolicitation(t *testing.T) { { name: "Two RS without delay and invalid zero interval", linkHeaderLen: 3, + linkAddr: linkAddr1, + expectedSrcAddr: header.IPv6Any, maxRtrSolicit: 2, rtrSolicitInt: 0, effectiveRtrSolicitInt: 4 * time.Second, @@ -3427,6 +3442,8 @@ func TestRouterSolicitation(t *testing.T) { }, { name: "Three RS without delay", + linkAddr: linkAddr1, + expectedSrcAddr: header.IPv6Any, maxRtrSolicit: 3, rtrSolicitInt: 500 * time.Millisecond, effectiveRtrSolicitInt: 500 * time.Millisecond, @@ -3435,6 +3452,8 @@ func TestRouterSolicitation(t *testing.T) { }, { name: "Two RS with invalid negative delay", + linkAddr: linkAddr1, + expectedSrcAddr: header.IPv6Any, maxRtrSolicit: 2, rtrSolicitInt: time.Second, effectiveRtrSolicitInt: time.Second, @@ -3457,7 +3476,7 @@ func TestRouterSolicitation(t *testing.T) { t.Run(test.name, func(t *testing.T) { t.Parallel() e := channelLinkWithHeaderLength{ - Endpoint: channel.New(int(test.maxRtrSolicit), 1280, linkAddr1), + Endpoint: channel.New(int(test.maxRtrSolicit), 1280, test.linkAddr), headerLength: test.linkHeaderLen, } e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired @@ -3481,10 +3500,10 @@ func TestRouterSolicitation(t *testing.T) { checker.IPv6(t, p.Pkt.Header.View(), - checker.SrcAddr(header.IPv6Any), + checker.SrcAddr(test.expectedSrcAddr), checker.DstAddr(header.IPv6AllRoutersMulticastAddress), checker.TTL(header.NDPHopLimit), - checker.NDPRS(), + checker.NDPRS(checker.NDPRSOptions(test.expectedNDPOpts)), ) if l, want := p.Pkt.Header.AvailableLength(), int(test.linkHeaderLen); l != want { @@ -3510,13 +3529,19 @@ func TestRouterSolicitation(t *testing.T) { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - // Make sure each RS got sent at the right - // times. + if addr := test.nicAddr; addr != "" { + if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr, err) + } + } + + // Make sure each RS is sent at the right time. remaining := test.maxRtrSolicit if remaining > 0 { waitForPkt(test.effectiveMaxRtrSolicitDelay + defaultAsyncEventTimeout) remaining-- } + for ; remaining > 0; remaining-- { waitForNothing(test.effectiveRtrSolicitInt - defaultTimeout) waitForPkt(defaultAsyncEventTimeout) diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 46d3a6646..3e6196aee 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -451,7 +451,7 @@ func (n *NIC) primaryIPv6Endpoint(remoteAddr tcpip.Address) *referencedNetworkEn cs := make([]ipv6AddrCandidate, 0, len(primaryAddrs)) for _, r := range primaryAddrs { // If r is not valid for outgoing connections, it is not a valid endpoint. - if !r.isValidForOutgoing() { + if !r.isValidForOutgoingRLocked() { continue } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index ebb6c5e3b..13354d884 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -551,11 +551,13 @@ type TransportEndpointInfo struct { RegisterNICID tcpip.NICID } -// AddrNetProto unwraps the specified address if it is a V4-mapped V6 address -// and returns the network protocol number to be used to communicate with the -// specified address. It returns an error if the passed address is incompatible -// with the receiver. -func (e *TransportEndpointInfo) AddrNetProto(addr tcpip.FullAddress, v6only bool) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) { +// AddrNetProtoLocked unwraps the specified address if it is a V4-mapped V6 +// address and returns the network protocol number to be used to communicate +// with the specified address. It returns an error if the passed address is +// incompatible with the receiver. +// +// Preconditon: the parent endpoint mu must be held while calling this method. +func (e *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, v6only bool) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) { netProto := e.NetProto switch len(addr.Addr) { case header.IPv4AddressSize: diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 426da1ee6..2a396e9bc 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -291,15 +291,13 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c nicID = e.BindNICID } - toCopy := *to - to = &toCopy - netProto, err := e.checkV4Mapped(to) + dst, netProto, err := e.checkV4MappedLocked(*to) if err != nil { return 0, nil, err } - // Find the enpoint. - r, err := e.stack.FindRoute(nicID, e.BindAddr, to.Addr, netProto, false /* multicastLoop */) + // Find the endpoint. + r, err := e.stack.FindRoute(nicID, e.BindAddr, dst.Addr, netProto, false /* multicastLoop */) if err != nil { return 0, nil, err } @@ -480,13 +478,14 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err }) } -func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocolNumber, *tcpip.Error) { - unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProto(*addr, false /* v6only */) +// checkV4MappedLocked determines the effective network protocol and converts +// addr to its canonical form. +func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) { + unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, false /* v6only */) if err != nil { - return 0, err + return tcpip.FullAddress{}, 0, err } - *addr = unwrapped - return netProto, nil + return unwrapped, netProto, nil } // Disconnect implements tcpip.Endpoint.Disconnect. @@ -517,7 +516,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { return tcpip.ErrInvalidEndpointState } - netProto, err := e.checkV4Mapped(&addr) + addr, netProto, err := e.checkV4MappedLocked(addr) if err != nil { return err } @@ -630,7 +629,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { return tcpip.ErrInvalidEndpointState } - netProto, err := e.checkV4Mapped(&addr) + addr, netProto, err := e.checkV4MappedLocked(addr) if err != nil { return err } diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 5722815e9..09a1cd436 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -76,6 +76,7 @@ type endpoint struct { sndBufSize int closed bool stats tcpip.TransportEndpointStats `state:"nosave"` + bound bool } // NewEndpoint returns a new packet endpoint. @@ -125,6 +126,7 @@ func (ep *endpoint) Close() { } ep.closed = true + ep.bound = false ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) } @@ -216,7 +218,24 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { // sll_family (should be AF_PACKET), sll_protocol, and sll_ifindex." // - packet(7). - return tcpip.ErrNotSupported + ep.mu.Lock() + defer ep.mu.Unlock() + + if ep.bound { + return tcpip.ErrAlreadyBound + } + + // Unregister endpoint with all the nics. + ep.stack.UnregisterPacketEndpoint(0, ep.netProto, ep) + + // Bind endpoint to receive packets from specific interface. + if err := ep.stack.RegisterPacketEndpoint(addr.NIC, ep.netProto, ep); err != nil { + return err + } + + ep.bound = true + + return nil } // GetLocalAddress implements tcpip.Endpoint.GetLocalAddress. diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 272e8f570..a32f9eacf 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -32,6 +32,7 @@ go_library( srcs = [ "accept.go", "connect.go", + "connect_unsafe.go", "cubic.go", "cubic_state.go", "dispatcher.go", diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 13e383ffc..85049e54e 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -236,6 +236,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i n.effectiveNetProtos = []tcpip.NetworkProtocolNumber{s.route.NetProto} n.rcvBufSize = int(l.rcvWnd) n.amss = mssForRoute(&n.route) + n.setEndpointState(StateConnecting) n.maybeEnableTimestamp(rcvdSynOpts) n.maybeEnableSACKPermitted(rcvdSynOpts) diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 7730e6445..c0f73ef16 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -295,6 +295,7 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error { h.state = handshakeSynRcvd h.ep.mu.Lock() ttl := h.ep.ttl + amss := h.ep.amss h.ep.setEndpointState(StateSynRecv) h.ep.mu.Unlock() synOpts := header.TCPSynOptions{ @@ -307,7 +308,7 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error { // permits SACK. This is not explicitly defined in the RFC but // this is the behaviour implemented by Linux. SACKPermitted: rcvSynOpts.SACKPermitted, - MSS: h.ep.amss, + MSS: amss, } if ttl == 0 { ttl = s.route.DefaultTTL() @@ -356,6 +357,10 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error { return tcpip.ErrInvalidEndpointState } + h.ep.mu.RLock() + amss := h.ep.amss + h.ep.mu.RUnlock() + h.resetState() synOpts := header.TCPSynOptions{ WS: h.rcvWndScale, @@ -363,7 +368,7 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error { TSVal: h.ep.timestamp(), TSEcr: h.ep.recentTimestamp(), SACKPermitted: h.ep.sackPermitted, - MSS: h.ep.amss, + MSS: amss, } h.ep.sendSynTCP(&s.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) return nil @@ -530,6 +535,7 @@ func (h *handshake) execute() *tcpip.Error { // Send the initial SYN segment and loop until the handshake is // completed. + h.ep.mu.Lock() h.ep.amss = calculateAdvertisedMSS(h.ep.userMSS, h.ep.route) synOpts := header.TCPSynOptions{ @@ -540,6 +546,7 @@ func (h *handshake) execute() *tcpip.Error { SACKPermitted: bool(sackEnabled), MSS: h.ep.amss, } + h.ep.mu.Unlock() // Execute is also called in a listen context so we want to make sure we // only send the TS/SACK option when we received the TS/SACK in the @@ -577,7 +584,7 @@ func (h *handshake) execute() *tcpip.Error { case wakerForNotification: n := h.ep.fetchNotifications() - if n¬ifyClose != 0 { + if (n¬ifyClose)|(n¬ifyAbort) != 0 { return tcpip.ErrAborted } if n¬ifyDrain != 0 { @@ -617,17 +624,17 @@ func parseSynSegmentOptions(s *segment) header.TCPSynOptions { var optionPool = sync.Pool{ New: func() interface{} { - return make([]byte, maxOptionSize) + return &[maxOptionSize]byte{} }, } func getOptions() []byte { - return optionPool.Get().([]byte) + return (*optionPool.Get().(*[maxOptionSize]byte))[:] } func putOptions(options []byte) { // Reslice to full capacity. - optionPool.Put(options[0:cap(options)]) + optionPool.Put(optionsToArray(options)) } func makeSynOptions(opts header.TCPSynOptions) []byte { diff --git a/pkg/tcpip/transport/tcp/connect_unsafe.go b/pkg/tcpip/transport/tcp/connect_unsafe.go new file mode 100644 index 000000000..cfc304616 --- /dev/null +++ b/pkg/tcpip/transport/tcp/connect_unsafe.go @@ -0,0 +1,30 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp + +import ( + "reflect" + "unsafe" +) + +// optionsToArray converts a slice of capacity >-= maxOptionSize to an array. +// +// optionsToArray panics if the capacity of options is smaller than +// maxOptionSize. +func optionsToArray(options []byte) *[maxOptionSize]byte { + // Reslice to full capacity. + options = options[0:maxOptionSize] + return (*[maxOptionSize]byte)(unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&options)).Data)) +} diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index f1ad19dac..dc9c18b6f 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -798,7 +798,21 @@ func (e *endpoint) Abort() { // If the endpoint disconnected after the check, nothing needs to be // done, so sending a notification which will potentially be ignored is // fine. - if e.EndpointState().connected() { + // + // If the endpoint connecting finishes after the check, the endpoint + // is either in a connected state (where we would notifyAbort anyway), + // SYN-RECV (where we would also notifyAbort anyway), or in an error + // state where nothing is required and the notification can be safely + // ignored. + // + // Endpoints where a Close during connecting or SYN-RECV state would be + // problematic are set to state connecting before being registered (and + // thus possible to be Aborted). They are never available in initial + // state. + // + // Endpoints transitioning from initial to connecting state may be + // safely either closed or sent notifyAbort. + if s := e.EndpointState(); s == StateConnecting || s == StateSynRecv || s.connected() { e.notifyProtocolGoroutine(notifyAbort) return } @@ -945,15 +959,18 @@ func (e *endpoint) initialReceiveWindow() int { // ModerateRecvBuf adjusts the receive buffer and the advertised window // based on the number of bytes copied to user space. func (e *endpoint) ModerateRecvBuf(copied int) { + e.mu.RLock() e.rcvListMu.Lock() if e.rcvAutoParams.disabled { e.rcvListMu.Unlock() + e.mu.RUnlock() return } now := time.Now() if rtt := e.rcvAutoParams.rtt; rtt == 0 || now.Sub(e.rcvAutoParams.measureTime) < rtt { e.rcvAutoParams.copied += copied e.rcvListMu.Unlock() + e.mu.RUnlock() return } prevRTTCopied := e.rcvAutoParams.copied + copied @@ -994,7 +1011,7 @@ func (e *endpoint) ModerateRecvBuf(copied int) { e.rcvBufSize = rcvWnd availAfter := e.receiveBufferAvailableLocked() mask := uint32(notifyReceiveWindowChanged) - if crossed, above := e.windowCrossedACKThreshold(availAfter - availBefore); crossed && above { + if crossed, above := e.windowCrossedACKThresholdLocked(availAfter - availBefore); crossed && above { mask |= notifyNonZeroReceiveWindow } e.notifyProtocolGoroutine(mask) @@ -1009,6 +1026,7 @@ func (e *endpoint) ModerateRecvBuf(copied int) { e.rcvAutoParams.measureTime = now e.rcvAutoParams.copied = 0 e.rcvListMu.Unlock() + e.mu.RUnlock() } // IPTables implements tcpip.Endpoint.IPTables. @@ -1038,7 +1056,6 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, v, err := e.readLocked() e.rcvListMu.Unlock() - e.mu.RUnlock() if err == tcpip.ErrClosedForReceive { @@ -1071,7 +1088,7 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) { // enough buffer space, to either fit an aMSS or half a receive buffer // (whichever smaller), then notify the protocol goroutine to send a // window update. - if crossed, above := e.windowCrossedACKThreshold(len(v)); crossed && above { + if crossed, above := e.windowCrossedACKThresholdLocked(len(v)); crossed && above { e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow) } @@ -1289,9 +1306,9 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro return num, tcpip.ControlMessages{}, nil } -// windowCrossedACKThreshold checks if the receive window to be announced now -// would be under aMSS or under half receive buffer, whichever smaller. This is -// useful as a receive side silly window syndrome prevention mechanism. If +// windowCrossedACKThresholdLocked checks if the receive window to be announced +// now would be under aMSS or under half receive buffer, whichever smaller. This +// is useful as a receive side silly window syndrome prevention mechanism. If // window grows to reasonable value, we should send ACK to the sender to inform // the rx space is now large. We also want ensure a series of small read()'s // won't trigger a flood of spurious tiny ACK's. @@ -1302,7 +1319,9 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro // crossed will be true if the window size crossed the ACK threshold. // above will be true if the new window is >= ACK threshold and false // otherwise. -func (e *endpoint) windowCrossedACKThreshold(deltaBefore int) (crossed bool, above bool) { +// +// Precondition: e.mu and e.rcvListMu must be held. +func (e *endpoint) windowCrossedACKThresholdLocked(deltaBefore int) (crossed bool, above bool) { newAvail := e.receiveBufferAvailableLocked() oldAvail := newAvail - deltaBefore if oldAvail < 0 { @@ -1365,6 +1384,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { mask := uint32(notifyReceiveWindowChanged) + e.mu.RLock() e.rcvListMu.Lock() // Make sure the receive buffer size allows us to send a @@ -1391,11 +1411,11 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { // Immediately send an ACK to uncork the sender silly window // syndrome prevetion, when our available space grows above aMSS // or half receive buffer, whichever smaller. - if crossed, above := e.windowCrossedACKThreshold(availAfter - availBefore); crossed && above { + if crossed, above := e.windowCrossedACKThresholdLocked(availAfter - availBefore); crossed && above { mask |= notifyNonZeroReceiveWindow } e.rcvListMu.Unlock() - + e.mu.RUnlock() e.notifyProtocolGoroutine(mask) return nil @@ -1854,13 +1874,14 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { } } -func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocolNumber, *tcpip.Error) { - unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProto(*addr, e.v6only) +// checkV4MappedLocked determines the effective network protocol and converts +// addr to its canonical form. +func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) { + unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.v6only) if err != nil { - return 0, err + return tcpip.FullAddress{}, 0, err } - *addr = unwrapped - return netProto, nil + return unwrapped, netProto, nil } // Disconnect implements tcpip.Endpoint.Disconnect. @@ -1890,7 +1911,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc connectingAddr := addr.Addr - netProto, err := e.checkV4Mapped(&addr) + addr, netProto, err := e.checkV4MappedLocked(addr) if err != nil { return err } @@ -2096,10 +2117,13 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { // Close for write. if (e.shutdownFlags & tcpip.ShutdownWrite) != 0 { e.sndBufMu.Lock() - if e.sndClosed { // Already closed. e.sndBufMu.Unlock() + if e.EndpointState() == StateTimeWait { + e.mu.Unlock() + return tcpip.ErrNotConnected + } break } @@ -2256,7 +2280,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { } e.BindAddr = addr.Addr - netProto, err := e.checkV4Mapped(&addr) + addr, netProto, err := e.checkV4MappedLocked(addr) if err != nil { return err } @@ -2400,13 +2424,14 @@ func (e *endpoint) updateSndBufferUsage(v int) { // to be read, or when the connection is closed for receiving (in which case // s will be nil). func (e *endpoint) readyToRead(s *segment) { + e.mu.RLock() e.rcvListMu.Lock() if s != nil { s.incRef() e.rcvBufUsed += s.data.Size() // Increase counter if the receive window falls down below MSS // or half receive buffer size, whichever smaller. - if crossed, above := e.windowCrossedACKThreshold(-s.data.Size()); crossed && !above { + if crossed, above := e.windowCrossedACKThresholdLocked(-s.data.Size()); crossed && !above { e.stats.ReceiveErrors.ZeroRcvWindowState.Increment() } e.rcvList.PushBack(s) @@ -2414,7 +2439,7 @@ func (e *endpoint) readyToRead(s *segment) { e.rcvClosed = true } e.rcvListMu.Unlock() - + e.mu.RUnlock() e.waiterQueue.Notify(waiter.EventIn) } diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index 958f03ac1..d80aff1b6 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -195,6 +195,10 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum for i := first; i < len(r.pendingRcvdSegments); i++ { r.pendingRcvdSegments[i].decRef() + // Note that slice truncation does not allow garbage collection of + // truncated items, thus truncated items must be set to nil to avoid + // memory leaks. + r.pendingRcvdSegments[i] = nil } r.pendingRcvdSegments = r.pendingRcvdSegments[:first] diff --git a/pkg/tcpip/transport/tcp/segment_heap.go b/pkg/tcpip/transport/tcp/segment_heap.go index 9fd061d7d..e28f213ba 100644 --- a/pkg/tcpip/transport/tcp/segment_heap.go +++ b/pkg/tcpip/transport/tcp/segment_heap.go @@ -41,6 +41,7 @@ func (h *segmentHeap) Pop() interface{} { old := *h n := len(old) x := old[n-1] + old[n-1] = nil *h = old[:n-1] return x } diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 1e9a0dea3..8cea20fb5 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -204,6 +204,7 @@ func (c *Context) Cleanup() { if c.EP != nil { c.EP.Close() } + c.Stack().Close() } // Stack returns a reference to the stack in the Context. diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 1c6a600b8..0af4514e1 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -443,19 +443,19 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c return 0, nil, tcpip.ErrBroadcastDisabled } - netProto, err := e.checkV4Mapped(to) + dst, netProto, err := e.checkV4MappedLocked(*to) if err != nil { return 0, nil, err } - r, _, err := e.connectRoute(nicID, *to, netProto) + r, _, err := e.connectRoute(nicID, dst, netProto) if err != nil { return 0, nil, err } defer r.Release() route = &r - dstPort = to.Port + dstPort = dst.Port } if route.IsResolutionRequired() { @@ -566,7 +566,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { defer e.mu.Unlock() fa := tcpip.FullAddress{Addr: v.InterfaceAddr} - netProto, err := e.checkV4Mapped(&fa) + fa, netProto, err := e.checkV4MappedLocked(fa) if err != nil { return err } @@ -927,13 +927,14 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u return nil } -func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocolNumber, *tcpip.Error) { - unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProto(*addr, e.v6only) +// checkV4MappedLocked determines the effective network protocol and converts +// addr to its canonical form. +func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) { + unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.v6only) if err != nil { - return 0, err + return tcpip.FullAddress{}, 0, err } - *addr = unwrapped - return netProto, nil + return unwrapped, netProto, nil } // Disconnect implements tcpip.Endpoint.Disconnect. @@ -981,10 +982,6 @@ func (e *endpoint) Disconnect() *tcpip.Error { // Connect connects the endpoint to its peer. Specifying a NIC is optional. func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { - netProto, err := e.checkV4Mapped(&addr) - if err != nil { - return err - } if addr.Port == 0 { // We don't support connecting to port zero. return tcpip.ErrInvalidEndpointState @@ -1012,6 +1009,11 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { return tcpip.ErrInvalidEndpointState } + addr, netProto, err := e.checkV4MappedLocked(addr) + if err != nil { + return err + } + r, nicID, err := e.connectRoute(nicID, addr, netProto) if err != nil { return err @@ -1139,7 +1141,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { return tcpip.ErrInvalidEndpointState } - netProto, err := e.checkV4Mapped(&addr) + addr, netProto, err := e.checkV4MappedLocked(addr) if err != nil { return err } diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go index 43fb047ed..466bd9381 100644 --- a/pkg/tcpip/transport/udp/endpoint_state.go +++ b/pkg/tcpip/transport/udp/endpoint_state.go @@ -69,6 +69,9 @@ func (e *endpoint) afterLoad() { // Resume implements tcpip.ResumableEndpoint.Resume. func (e *endpoint) Resume(s *stack.Stack) { + e.mu.Lock() + defer e.mu.Unlock() + e.stack = s for _, m := range e.multicastMemberships { diff --git a/runsc/boot/config.go b/runsc/boot/config.go index 35391030f..7ea5bfade 100644 --- a/runsc/boot/config.go +++ b/runsc/boot/config.go @@ -158,6 +158,9 @@ type Config struct { // DebugLog is the path to log debug information to, if not empty. DebugLog string + // PanicLog is the path to log GO's runtime messages, if not empty. + PanicLog string + // DebugLogFormat is the log format for debug. DebugLogFormat string @@ -269,6 +272,7 @@ func (c *Config) ToFlags() []string { "--log=" + c.LogFilename, "--log-format=" + c.LogFormat, "--debug-log=" + c.DebugLog, + "--panic-log=" + c.PanicLog, "--debug-log-format=" + c.DebugLogFormat, "--file-access=" + c.FileAccess.String(), "--overlay=" + strconv.FormatBool(c.Overlay), diff --git a/runsc/boot/controller.go b/runsc/boot/controller.go index 17e774e0c..8125d5061 100644 --- a/runsc/boot/controller.go +++ b/runsc/boot/controller.go @@ -101,11 +101,14 @@ const ( // Profiling related commands (see pprof.go for more details). const ( - StartCPUProfile = "Profile.StartCPUProfile" - StopCPUProfile = "Profile.StopCPUProfile" - HeapProfile = "Profile.HeapProfile" - StartTrace = "Profile.StartTrace" - StopTrace = "Profile.StopTrace" + StartCPUProfile = "Profile.StartCPUProfile" + StopCPUProfile = "Profile.StopCPUProfile" + HeapProfile = "Profile.HeapProfile" + GoroutineProfile = "Profile.GoroutineProfile" + BlockProfile = "Profile.BlockProfile" + MutexProfile = "Profile.MutexProfile" + StartTrace = "Profile.StartTrace" + StopTrace = "Profile.StopTrace" ) // Logging related commands (see logging.go for more details). diff --git a/runsc/cmd/debug.go b/runsc/cmd/debug.go index 79965460e..b5de2588b 100644 --- a/runsc/cmd/debug.go +++ b/runsc/cmd/debug.go @@ -32,17 +32,20 @@ import ( // Debug implements subcommands.Command for the "debug" command. type Debug struct { - pid int - stacks bool - signal int - profileHeap string - profileCPU string - trace string - strace string - logLevel string - logPackets string - duration time.Duration - ps bool + pid int + stacks bool + signal int + profileHeap string + profileCPU string + profileGoroutine string + profileBlock string + profileMutex string + trace string + strace string + logLevel string + logPackets string + duration time.Duration + ps bool } // Name implements subcommands.Command. @@ -66,6 +69,9 @@ func (d *Debug) SetFlags(f *flag.FlagSet) { f.BoolVar(&d.stacks, "stacks", false, "if true, dumps all sandbox stacks to the log") f.StringVar(&d.profileHeap, "profile-heap", "", "writes heap profile to the given file.") f.StringVar(&d.profileCPU, "profile-cpu", "", "writes CPU profile to the given file.") + f.StringVar(&d.profileGoroutine, "profile-goroutine", "", "writes goroutine profile to the given file.") + f.StringVar(&d.profileBlock, "profile-block", "", "writes block profile to the given file.") + f.StringVar(&d.profileMutex, "profile-mutex", "", "writes mutex profile to the given file.") f.DurationVar(&d.duration, "duration", time.Second, "amount of time to wait for CPU and trace profiles") f.StringVar(&d.trace, "trace", "", "writes an execution trace to the given file.") f.IntVar(&d.signal, "signal", -1, "sends signal to the sandbox") @@ -147,6 +153,42 @@ func (d *Debug) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) } log.Infof("Heap profile written to %q", d.profileHeap) } + if d.profileGoroutine != "" { + f, err := os.Create(d.profileGoroutine) + if err != nil { + return Errorf(err.Error()) + } + defer f.Close() + + if err := c.Sandbox.GoroutineProfile(f); err != nil { + return Errorf(err.Error()) + } + log.Infof("Goroutine profile written to %q", d.profileGoroutine) + } + if d.profileBlock != "" { + f, err := os.Create(d.profileBlock) + if err != nil { + return Errorf(err.Error()) + } + defer f.Close() + + if err := c.Sandbox.BlockProfile(f); err != nil { + return Errorf(err.Error()) + } + log.Infof("Block profile written to %q", d.profileBlock) + } + if d.profileMutex != "" { + f, err := os.Create(d.profileMutex) + if err != nil { + return Errorf(err.Error()) + } + defer f.Close() + + if err := c.Sandbox.MutexProfile(f); err != nil { + return Errorf(err.Error()) + } + log.Infof("Mutex profile written to %q", d.profileMutex) + } delay := false if d.profileCPU != "" { diff --git a/runsc/container/console_test.go b/runsc/container/console_test.go index c2518d52b..651615d4c 100644 --- a/runsc/container/console_test.go +++ b/runsc/container/console_test.go @@ -333,13 +333,13 @@ func TestJobControlSignalRootContainer(t *testing.T) { // file. Writes after a certain point will block unless we drain the // PTY, so we must continually copy from it. // - // We log the output to stdout for debugabilitly, and also to a buffer, + // We log the output to stderr for debugabilitly, and also to a buffer, // since we wait on particular output from bash below. We use a custom // blockingBuffer which is thread-safe and also blocks on Read calls, // which makes this a suitable Reader for WaitUntilRead. ptyBuf := newBlockingBuffer() tee := io.TeeReader(ptyMaster, ptyBuf) - go io.Copy(os.Stdout, tee) + go io.Copy(os.Stderr, tee) // Start the container. if err := c.Start(conf); err != nil { diff --git a/runsc/container/container.go b/runsc/container/container.go index 68782c4be..c9839044c 100644 --- a/runsc/container/container.go +++ b/runsc/container/container.go @@ -17,6 +17,7 @@ package container import ( "context" + "errors" "fmt" "io/ioutil" "os" @@ -1066,18 +1067,10 @@ func runInCgroup(cg *cgroup.Cgroup, fn func() error) error { // adjustGoferOOMScoreAdj sets the oom_store_adj for the container's gofer. func (c *Container) adjustGoferOOMScoreAdj() error { - if c.GoferPid != 0 && c.Spec.Process.OOMScoreAdj != nil { - if err := setOOMScoreAdj(c.GoferPid, *c.Spec.Process.OOMScoreAdj); err != nil { - // Ignore NotExist error because it can be returned when the sandbox - // exited while OOM score was being adjusted. - if !os.IsNotExist(err) { - return fmt.Errorf("setting gofer oom_score_adj for container %q: %v", c.ID, err) - } - log.Warningf("Gofer process (%d) not found setting oom_score_adj", c.GoferPid) - } + if c.GoferPid == 0 || c.Spec.Process.OOMScoreAdj == nil { + return nil } - - return nil + return setOOMScoreAdj(c.GoferPid, *c.Spec.Process.OOMScoreAdj) } // adjustSandboxOOMScoreAdj sets the oom_score_adj for the sandbox. @@ -1154,29 +1147,29 @@ func adjustSandboxOOMScoreAdj(s *sandbox.Sandbox, rootDir string, destroy bool) } // Set the lowest of all containers oom_score_adj to the sandbox. - if err := setOOMScoreAdj(s.Pid, lowScore); err != nil { - // Ignore NotExist error because it can be returned when the sandbox - // exited while OOM score was being adjusted. - if !os.IsNotExist(err) { - return fmt.Errorf("setting oom_score_adj for sandbox %q: %v", s.ID, err) - } - log.Warningf("Sandbox process (%d) not found setting oom_score_adj", s.Pid) - } - - return nil + return setOOMScoreAdj(s.Pid, lowScore) } // setOOMScoreAdj sets oom_score_adj to the given value for the given PID. // /proc must be available and mounted read-write. scoreAdj should be between -// -1000 and 1000. +// -1000 and 1000. It's a noop if the process has already exited. func setOOMScoreAdj(pid int, scoreAdj int) error { f, err := os.OpenFile(fmt.Sprintf("/proc/%d/oom_score_adj", pid), os.O_WRONLY, 0644) if err != nil { + // Ignore NotExist errors because it can race with process exit. + if os.IsNotExist(err) { + log.Warningf("Process (%d) not found setting oom_score_adj", pid) + return nil + } return err } defer f.Close() if _, err := f.WriteString(strconv.Itoa(scoreAdj)); err != nil { - return err + if errors.Is(err, syscall.ESRCH) { + log.Warningf("Process (%d) exited while setting oom_score_adj", pid) + return nil + } + return fmt.Errorf("setting oom_score_adj to %q: %v", scoreAdj, err) } return nil } diff --git a/runsc/container/container_test.go b/runsc/container/container_test.go index bdd65b498..c7eea85b3 100644 --- a/runsc/container/container_test.go +++ b/runsc/container/container_test.go @@ -2092,7 +2092,7 @@ func TestOverlayfsStaleRead(t *testing.T) { defer out.Close() const want = "foobar" - cmd := fmt.Sprintf("cat %q && echo %q> %q && cp %q %q", in.Name(), want, in.Name(), in.Name(), out.Name()) + cmd := fmt.Sprintf("cat %q >&2 && echo %q> %q && cp %q %q", in.Name(), want, in.Name(), in.Name(), out.Name()) spec := testutil.NewSpecWithArgs("/bin/bash", "-c", cmd) if err := run(spec, conf); err != nil { t.Fatalf("Error running container: %v", err) diff --git a/runsc/main.go b/runsc/main.go index af73bed97..62e184ec9 100644 --- a/runsc/main.go +++ b/runsc/main.go @@ -54,9 +54,11 @@ var ( // Debugging flags. debugLog = flag.String("debug-log", "", "additional location for logs. If it ends with '/', log files are created inside the directory with default names. The following variables are available: %TIMESTAMP%, %COMMAND%.") + panicLog = flag.String("panic-log", "", "file path were panic reports and other Go's runtime messages are written.") logPackets = flag.Bool("log-packets", false, "enable network packet logging.") logFD = flag.Int("log-fd", -1, "file descriptor to log to. If set, the 'log' flag is ignored.") debugLogFD = flag.Int("debug-log-fd", -1, "file descriptor to write debug logs to. If set, the 'debug-log-dir' flag is ignored.") + panicLogFD = flag.Int("panic-log-fd", -1, "file descriptor to write Go's runtime messages.") debugLogFormat = flag.String("debug-log-format", "text", "log format: text (default), json, or json-k8s.") alsoLogToStderr = flag.Bool("alsologtostderr", false, "send log messages to stderr.") @@ -206,6 +208,7 @@ func main() { LogFilename: *logFilename, LogFormat: *logFormat, DebugLog: *debugLog, + PanicLog: *panicLog, DebugLogFormat: *debugLogFormat, FileAccess: fsAccess, FSGoferHostUDS: *fsGoferHostUDS, @@ -258,20 +261,6 @@ func main() { if *debugLogFD > -1 { f := os.NewFile(uintptr(*debugLogFD), "debug log file") - // Quick sanity check to make sure no other commands get passed - // a log fd (they should use log dir instead). - if subcommand != "boot" && subcommand != "gofer" { - cmd.Fatalf("flag --debug-log-fd should only be passed to 'boot' and 'gofer' command, but was passed to %q", subcommand) - } - - // If we are the boot process, then we own our stdio FDs and can do what we - // want with them. Since Docker and Containerd both eat boot's stderr, we - // dup our stderr to the provided log FD so that panics will appear in the - // logs, rather than just disappear. - if err := syscall.Dup3(int(f.Fd()), int(os.Stderr.Fd()), 0); err != nil { - cmd.Fatalf("error dup'ing fd %d to stderr: %v", f.Fd(), err) - } - e = newEmitter(*debugLogFormat, f) } else if *debugLog != "" { @@ -287,6 +276,26 @@ func main() { e = newEmitter("text", ioutil.Discard) } + if *panicLogFD > -1 || *debugLogFD > -1 { + fd := *panicLogFD + if fd < 0 { + fd = *debugLogFD + } + // Quick sanity check to make sure no other commands get passed + // a log fd (they should use log dir instead). + if subcommand != "boot" && subcommand != "gofer" { + cmd.Fatalf("flags --debug-log-fd and --panic-log-fd should only be passed to 'boot' and 'gofer' command, but was passed to %q", subcommand) + } + + // If we are the boot process, then we own our stdio FDs and can do what we + // want with them. Since Docker and Containerd both eat boot's stderr, we + // dup our stderr to the provided log FD so that panics will appear in the + // logs, rather than just disappear. + if err := syscall.Dup3(fd, int(os.Stderr.Fd()), 0); err != nil { + cmd.Fatalf("error dup'ing fd %d to stderr: %v", fd, err) + } + } + if *alsoLogToStderr { e = &log.MultiEmitter{e, newEmitter(*debugLogFormat, os.Stderr)} } diff --git a/runsc/sandbox/sandbox.go b/runsc/sandbox/sandbox.go index ec72bdbfd..192bde40c 100644 --- a/runsc/sandbox/sandbox.go +++ b/runsc/sandbox/sandbox.go @@ -369,6 +369,24 @@ func (s *Sandbox) createSandboxProcess(conf *boot.Config, args *Args, startSyncF cmd.Args = append(cmd.Args, "--debug-log-fd="+strconv.Itoa(nextFD)) nextFD++ } + if conf.PanicLog != "" { + test := "" + if len(conf.TestOnlyTestNameEnv) != 0 { + // Fetch test name if one is provided and the test only flag was set. + if t, ok := specutils.EnvVar(args.Spec.Process.Env, conf.TestOnlyTestNameEnv); ok { + test = t + } + } + + panicLogFile, err := specutils.DebugLogFile(conf.PanicLog, "panic", test) + if err != nil { + return fmt.Errorf("opening debug log file in %q: %v", conf.PanicLog, err) + } + defer panicLogFile.Close() + cmd.ExtraFiles = append(cmd.ExtraFiles, panicLogFile) + cmd.Args = append(cmd.Args, "--panic-log-fd="+strconv.Itoa(nextFD)) + nextFD++ + } cmd.Args = append(cmd.Args, "--panic-signal="+strconv.Itoa(int(syscall.SIGTERM))) @@ -972,6 +990,66 @@ func (s *Sandbox) StopCPUProfile() error { return nil } +// GoroutineProfile writes a goroutine profile to the given file. +func (s *Sandbox) GoroutineProfile(f *os.File) error { + log.Debugf("Goroutine profile %q", s.ID) + conn, err := s.sandboxConnect() + if err != nil { + return err + } + defer conn.Close() + + opts := control.ProfileOpts{ + FilePayload: urpc.FilePayload{ + Files: []*os.File{f}, + }, + } + if err := conn.Call(boot.GoroutineProfile, &opts, nil); err != nil { + return fmt.Errorf("getting sandbox %q goroutine profile: %v", s.ID, err) + } + return nil +} + +// BlockProfile writes a block profile to the given file. +func (s *Sandbox) BlockProfile(f *os.File) error { + log.Debugf("Block profile %q", s.ID) + conn, err := s.sandboxConnect() + if err != nil { + return err + } + defer conn.Close() + + opts := control.ProfileOpts{ + FilePayload: urpc.FilePayload{ + Files: []*os.File{f}, + }, + } + if err := conn.Call(boot.BlockProfile, &opts, nil); err != nil { + return fmt.Errorf("getting sandbox %q block profile: %v", s.ID, err) + } + return nil +} + +// MutexProfile writes a mutex profile to the given file. +func (s *Sandbox) MutexProfile(f *os.File) error { + log.Debugf("Mutex profile %q", s.ID) + conn, err := s.sandboxConnect() + if err != nil { + return err + } + defer conn.Close() + + opts := control.ProfileOpts{ + FilePayload: urpc.FilePayload{ + Files: []*os.File{f}, + }, + } + if err := conn.Call(boot.MutexProfile, &opts, nil); err != nil { + return fmt.Errorf("getting sandbox %q mutex profile: %v", s.ID, err) + } + return nil +} + // StartTrace start trace writing to the given file. func (s *Sandbox) StartTrace(f *os.File) error { log.Debugf("Trace start %q", s.ID) diff --git a/scripts/build.sh b/scripts/build.sh index 4c042af6c..7c9c99800 100755 --- a/scripts/build.sh +++ b/scripts/build.sh @@ -17,7 +17,7 @@ source $(dirname $0)/common.sh # Install required packages for make_repository.sh et al. -sudo apt-get update && sudo apt-get install -y dpkg-sig coreutils apt-utils xz-utils +apt_install dpkg-sig coreutils apt-utils xz-utils # Build runsc. runsc=$(build -c opt //runsc) diff --git a/scripts/common.sh b/scripts/common.sh index 3ca699e4a..735a383de 100755 --- a/scripts/common.sh +++ b/scripts/common.sh @@ -84,3 +84,17 @@ function install_runsc() { # Restart docker to pick up the new runtime configuration. sudo systemctl restart docker } + +# Installs the given packages. Note that the package names should be verified to +# be correct, otherwise this may result in a loop that spins until time out. +function apt_install() { + while true; do + if (sudo apt-get update && sudo apt-get install -y "$@"); then + break + fi + result=$? + if [[ $result -ne 100 ]]; then + return $result + fi + done +} diff --git a/scripts/release.sh b/scripts/release.sh index 091abf87f..e14ba04a7 100755 --- a/scripts/release.sh +++ b/scripts/release.sh @@ -25,6 +25,14 @@ if ! [[ -v KOKORO_RELEASE_TAG ]]; then echo "No KOKORO_RELEASE_TAG provided." >&2 exit 1 fi +if ! [[ -v KOKORO_RELNOTES ]]; then + echo "No KOKORO_RELNOTES provided." >&2 + exit 1 +fi +if ! [[ -r "${KOKORO_ARTIFACTS_DIR}/${KOKORO_RELNOTES}" ]]; then + echo "The file '${KOKORO_ARTIFACTS_DIR}/${KOKORO_RELNOTES}' is not readable." >&2 + exit 1 +fi # Unless an explicit releaser is provided, use the bot e-mail. declare -r KOKORO_RELEASE_AUTHOR=${KOKORO_RELEASE_AUTHOR:-gvisor-bot} @@ -46,4 +54,7 @@ EOF fi # Run the release tool, which pushes to the origin repository. -tools/tag_release.sh "${KOKORO_RELEASE_COMMIT}" "${KOKORO_RELEASE_TAG}" +tools/tag_release.sh \ + "${KOKORO_RELEASE_COMMIT}" \ + "${KOKORO_RELEASE_TAG}" \ + "${KOKORO_ARTIFACTS_DIR}/${KOKORO_RELNOTES}" diff --git a/test/packetdrill/BUILD b/test/packetdrill/BUILD index d113555b1..fb0b2db41 100644 --- a/test/packetdrill/BUILD +++ b/test/packetdrill/BUILD @@ -1,8 +1,48 @@ -load("defs.bzl", "packetdrill_test") +load("defs.bzl", "packetdrill_linux_test", "packetdrill_netstack_test", "packetdrill_test") package(licenses = ["notice"]) packetdrill_test( - name = "fin_wait2_timeout", + name = "packetdrill_sanity_test", + scripts = ["sanity_test.pkt"], +) + +packetdrill_test( + name = "accept_ack_drop_test", + scripts = ["accept_ack_drop.pkt"], +) + +packetdrill_test( + name = "fin_wait2_timeout_test", scripts = ["fin_wait2_timeout.pkt"], ) + +packetdrill_linux_test( + name = "tcp_user_timeout_test_linux_test", + scripts = ["linux/tcp_user_timeout.pkt"], +) + +packetdrill_netstack_test( + name = "tcp_user_timeout_test_netstack_test", + scripts = ["netstack/tcp_user_timeout.pkt"], +) + +packetdrill_test( + name = "listen_close_before_handshake_complete_test", + scripts = ["listen_close_before_handshake_complete.pkt"], +) + +packetdrill_test( + name = "no_rst_to_rst_test", + scripts = ["no_rst_to_rst.pkt"], +) + +packetdrill_test( + name = "tcp_defer_accept_test", + scripts = ["tcp_defer_accept.pkt"], +) + +packetdrill_test( + name = "tcp_defer_accept_timeout_test", + scripts = ["tcp_defer_accept_timeout.pkt"], +) diff --git a/test/packetdrill/accept_ack_drop.pkt b/test/packetdrill/accept_ack_drop.pkt new file mode 100644 index 000000000..76e638fd4 --- /dev/null +++ b/test/packetdrill/accept_ack_drop.pkt @@ -0,0 +1,27 @@ +// Test that the accept works if the final ACK is dropped and an ack with data +// follows the dropped ack. + +0 socket(..., SOCK_STREAM, IPPROTO_TCP) = 3 ++0 bind(3, ..., ...) = 0 + +// Set backlog to 1 so that we can easily test. ++0 listen(3, 1) = 0 + +// Establish a connection without timestamps. ++0.0 < S 0:0(0) win 32792 <mss 1460,sackOK,nop,nop,nop,wscale 7> ++0.0 > S. 0:0(0) ack 1 <...> + ++0.0 < . 1:5(4) ack 1 win 257 ++0.0 > . 1:1(0) ack 5 <...> + +// This should cause connection to transition to connected state. ++0.000 accept(3, ..., ...) = 4 ++0.000 fcntl(4, F_SETFL, O_RDWR|O_NONBLOCK) = 0 + +// Now read the data and we should get 4 bytes. ++0.000 read(4,..., 4) = 4 ++0.000 close(4) = 0 + ++0.0 > F. 1:1(0) ack 5 <...> ++0.0 < F. 5:5(0) ack 2 win 257 ++0.01 > . 2:2(0) ack 6 <...>
\ No newline at end of file diff --git a/test/packetdrill/defs.bzl b/test/packetdrill/defs.bzl index 8623ce7b1..f499c177b 100644 --- a/test/packetdrill/defs.bzl +++ b/test/packetdrill/defs.bzl @@ -66,7 +66,7 @@ def packetdrill_linux_test(name, **kwargs): if "tags" not in kwargs: kwargs["tags"] = _PACKETDRILL_TAGS _packetdrill_test( - name = name + "_linux_test", + name = name, flags = ["--dut_platform", "linux"], **kwargs ) @@ -75,13 +75,13 @@ def packetdrill_netstack_test(name, **kwargs): if "tags" not in kwargs: kwargs["tags"] = _PACKETDRILL_TAGS _packetdrill_test( - name = name + "_netstack_test", + name = name, # This is the default runtime unless # "--test_arg=--runtime=OTHER_RUNTIME" is used to override the value. flags = ["--dut_platform", "netstack", "--runtime", "runsc-d"], **kwargs ) -def packetdrill_test(**kwargs): - packetdrill_linux_test(**kwargs) - packetdrill_netstack_test(**kwargs) +def packetdrill_test(name, **kwargs): + packetdrill_linux_test(name + "_linux_test", **kwargs) + packetdrill_netstack_test(name + "_netstack_test", **kwargs) diff --git a/test/packetdrill/linux/tcp_user_timeout.pkt b/test/packetdrill/linux/tcp_user_timeout.pkt new file mode 100644 index 000000000..38018cb42 --- /dev/null +++ b/test/packetdrill/linux/tcp_user_timeout.pkt @@ -0,0 +1,39 @@ +// Test that a socket w/ TCP_USER_TIMEOUT set aborts the connection +// if there is pending unacked data after the user specified timeout. + +0 socket(..., SOCK_STREAM, IPPROTO_TCP) = 3 ++0 bind(3, ..., ...) = 0 + ++0 listen(3, 1) = 0 + +// Establish a connection without timestamps. ++0 < S 0:0(0) win 32792 <mss 1460,sackOK,nop,nop,nop,wscale 7> ++0 > S. 0:0(0) ack 1 <...> ++0.1 < . 1:1(0) ack 1 win 32792 + ++0.100 accept(3, ..., ...) = 4 + +// Okay, we received nothing, and decide to close this idle socket. +// We set TCP_USER_TIMEOUT to 3 seconds because really it is not worth +// trying hard to cleanly close this flow, at the price of keeping +// a TCP structure in kernel for about 1 minute! ++2 setsockopt(4, SOL_TCP, TCP_USER_TIMEOUT, [3000], 4) = 0 + +// The write/ack is required mainly for netstack as netstack does +// not update its RTO during the handshake. ++0 write(4, ..., 100) = 100 ++0 > P. 1:101(100) ack 1 <...> ++0 < . 1:1(0) ack 101 win 32792 + ++0 close(4) = 0 + ++0 > F. 101:101(0) ack 1 <...> ++.3~+.400 > F. 101:101(0) ack 1 <...> ++.3~+.400 > F. 101:101(0) ack 1 <...> ++.6~+.800 > F. 101:101(0) ack 1 <...> ++1.2~+1.300 > F. 101:101(0) ack 1 <...> + +// We finally receive something from the peer, but it is way too late +// Our socket vanished because TCP_USER_TIMEOUT was really small. ++.1 < . 1:2(1) ack 102 win 32792 ++0 > R 102:102(0) win 0 diff --git a/test/packetdrill/listen_close_before_handshake_complete.pkt b/test/packetdrill/listen_close_before_handshake_complete.pkt new file mode 100644 index 000000000..51c3f1a32 --- /dev/null +++ b/test/packetdrill/listen_close_before_handshake_complete.pkt @@ -0,0 +1,31 @@ +// Test that closing a listening socket closes any connections in SYN-RCVD +// state and any packets bound for these connections generate a RESET. + +0 socket(..., SOCK_STREAM, IPPROTO_TCP) = 3 ++0 bind(3, ..., ...) = 0 + +// Set backlog to 1 so that we can easily test. ++0 listen(3, 1) = 0 + +// Establish a connection without timestamps. ++0 < S 0:0(0) win 32792 <mss 1460,sackOK,nop,nop,nop,wscale 7> ++0 > S. 0:0(0) ack 1 <...> + ++0.100 close(3) = 0 ++0.1 < P. 1:1(0) ack 1 win 257 + +// Linux generates a reset with no ack number/bit set. This is contradictory to +// what is specified in Rule 1 under Reset Generation in +// https://tools.ietf.org/html/rfc793#section-3.4. +// "1. If the connection does not exist (CLOSED) then a reset is sent +// in response to any incoming segment except another reset. In +// particular, SYNs addressed to a non-existent connection are rejected +// by this means. +// +// If the incoming segment has an ACK field, the reset takes its +// sequence number from the ACK field of the segment, otherwise the +// reset has sequence number zero and the ACK field is set to the sum +// of the sequence number and segment length of the incoming segment. +// The connection remains in the CLOSED state." + ++0.0 > R 1:1(0) win 0
\ No newline at end of file diff --git a/test/packetdrill/netstack/tcp_user_timeout.pkt b/test/packetdrill/netstack/tcp_user_timeout.pkt new file mode 100644 index 000000000..60103adba --- /dev/null +++ b/test/packetdrill/netstack/tcp_user_timeout.pkt @@ -0,0 +1,38 @@ +// Test that a socket w/ TCP_USER_TIMEOUT set aborts the connection +// if there is pending unacked data after the user specified timeout. + +0 socket(..., SOCK_STREAM, IPPROTO_TCP) = 3 ++0 bind(3, ..., ...) = 0 + ++0 listen(3, 1) = 0 + +// Establish a connection without timestamps. ++0 < S 0:0(0) win 32792 <mss 1460,sackOK,nop,nop,nop,wscale 7> ++0 > S. 0:0(0) ack 1 <...> ++0.1 < . 1:1(0) ack 1 win 32792 + ++0.100 accept(3, ..., ...) = 4 + +// Okay, we received nothing, and decide to close this idle socket. +// We set TCP_USER_TIMEOUT to 3 seconds because really it is not worth +// trying hard to cleanly close this flow, at the price of keeping +// a TCP structure in kernel for about 1 minute! ++2 setsockopt(4, SOL_TCP, TCP_USER_TIMEOUT, [3000], 4) = 0 + +// The write/ack is required mainly for netstack as netstack does +// not update its RTO during the handshake. ++0 write(4, ..., 100) = 100 ++0 > P. 1:101(100) ack 1 <...> ++0 < . 1:1(0) ack 101 win 32792 + ++0 close(4) = 0 + ++0 > F. 101:101(0) ack 1 <...> ++.2~+.300 > F. 101:101(0) ack 1 <...> ++.4~+.500 > F. 101:101(0) ack 1 <...> ++.8~+.900 > F. 101:101(0) ack 1 <...> + +// We finally receive something from the peer, but it is way too late +// Our socket vanished because TCP_USER_TIMEOUT was really small. ++1.61 < . 1:2(1) ack 102 win 32792 ++0 > R 102:102(0) win 0 diff --git a/test/packetdrill/no_rst_to_rst.pkt b/test/packetdrill/no_rst_to_rst.pkt new file mode 100644 index 000000000..612747827 --- /dev/null +++ b/test/packetdrill/no_rst_to_rst.pkt @@ -0,0 +1,36 @@ +// Test a RST is not generated in response to a RST and a RST is correctly +// generated when an accepted endpoint is RST due to an incoming RST. + +0 socket(..., SOCK_STREAM, IPPROTO_TCP) = 3 ++0 bind(3, ..., ...) = 0 + ++0 listen(3, 1) = 0 + +// Establish a connection without timestamps. ++0 < S 0:0(0) win 32792 <mss 1460,sackOK,nop,nop,nop,wscale 7> ++0 > S. 0:0(0) ack 1 <...> ++0 < P. 1:1(0) ack 1 win 257 + ++0.100 accept(3, ..., ...) = 4 + ++0.200 < R 1:1(0) win 0 + ++0.300 read(4,..., 4) = -1 ECONNRESET (Connection Reset by Peer) + ++0.00 < . 1:1(0) ack 1 win 257 + +// Linux generates a reset with no ack number/bit set. This is contradictory to +// what is specified in Rule 1 under Reset Generation in +// https://tools.ietf.org/html/rfc793#section-3.4. +// "1. If the connection does not exist (CLOSED) then a reset is sent +// in response to any incoming segment except another reset. In +// particular, SYNs addressed to a non-existent connection are rejected +// by this means. +// +// If the incoming segment has an ACK field, the reset takes its +// sequence number from the ACK field of the segment, otherwise the +// reset has sequence number zero and the ACK field is set to the sum +// of the sequence number and segment length of the incoming segment. +// The connection remains in the CLOSED state." + ++0.00 > R 1:1(0) win 0
\ No newline at end of file diff --git a/test/packetdrill/packetdrill_test.sh b/test/packetdrill/packetdrill_test.sh index 0b22dfd5c..c8268170f 100755 --- a/test/packetdrill/packetdrill_test.sh +++ b/test/packetdrill/packetdrill_test.sh @@ -91,8 +91,8 @@ fi # Variables specific to the test runner start with TEST_RUNNER_. declare -r PACKETDRILL="/packetdrill/gtests/net/packetdrill/packetdrill" # Use random numbers so that test networks don't collide. -declare -r CTRL_NET="ctrl_net-${RANDOM}${RANDOM}" -declare -r TEST_NET="test_net-${RANDOM}${RANDOM}" +declare -r CTRL_NET="ctrl_net-$(shuf -i 0-99999999 -n 1)" +declare -r TEST_NET="test_net-$(shuf -i 0-99999999 -n 1)" declare -r tolerance_usecs=100000 # On both DUT and test runner, testing packets are on the eth2 interface. declare -r TEST_DEVICE="eth2" diff --git a/test/packetdrill/reset_for_ack_when_no_syn_cookies_in_use.pkt b/test/packetdrill/reset_for_ack_when_no_syn_cookies_in_use.pkt new file mode 100644 index 000000000..a86b90ce6 --- /dev/null +++ b/test/packetdrill/reset_for_ack_when_no_syn_cookies_in_use.pkt @@ -0,0 +1,9 @@ +// Test that a listening socket generates a RST when it receives an +// ACK and syn cookies are not in use. + +0 socket(..., SOCK_STREAM, IPPROTO_TCP) = 3 ++0 bind(3, ..., ...) = 0 + ++0 listen(3, 1) = 0 ++0.1 < . 1:1(0) ack 1 win 32792 ++0 > R 1:1(0) ack 0 win 0
\ No newline at end of file diff --git a/test/packetdrill/sanity_test.pkt b/test/packetdrill/sanity_test.pkt new file mode 100644 index 000000000..b3b58c366 --- /dev/null +++ b/test/packetdrill/sanity_test.pkt @@ -0,0 +1,7 @@ +// Basic sanity test. One system call. +// +// All of the plumbing has to be working however, and the packetdrill wire +// client needs to be able to connect to the wire server and send the script, +// probe local interfaces, run through the test w/ timings, etc. + +0.000 socket(..., SOCK_STREAM, IPPROTO_TCP) = 3 diff --git a/test/packetdrill/tcp_defer_accept.pkt b/test/packetdrill/tcp_defer_accept.pkt new file mode 100644 index 000000000..a17f946db --- /dev/null +++ b/test/packetdrill/tcp_defer_accept.pkt @@ -0,0 +1,48 @@ +// Test that a bare ACK does not complete a connection when TCP_DEFER_ACCEPT +// timeout is not hit but an ACK w/ data does complete and deliver the +// connection to the accept queue. + +0 socket(..., SOCK_STREAM, IPPROTO_TCP) = 3 ++0 setsockopt(3, SOL_TCP, TCP_DEFER_ACCEPT, [5], 4) = 0 ++0.000 fcntl(3, F_SETFL, O_RDWR|O_NONBLOCK) = 0 ++0 bind(3, ..., ...) = 0 + +// Set backlog to 1 so that we can easily test. ++0 listen(3, 1) = 0 + +// Establish a connection without timestamps. ++0.0 < S 0:0(0) win 32792 <mss 1460,sackOK,nop,nop,nop,wscale 7> ++0.0 > S. 0:0(0) ack 1 <...> + +// Send a bare ACK this should not complete the connection as we +// set the TCP_DEFER_ACCEPT above. ++0.0 < . 1:1(0) ack 1 win 257 + +// The bare ACK should be dropped and no connection should be delivered +// to the accept queue. ++0.100 accept(3, ..., ...) = -1 EWOULDBLOCK (operation would block) + +// Send another bare ACK and it should still fail we set TCP_DEFER_ACCEPT +// to 5 seconds above. ++2.5 < . 1:1(0) ack 1 win 257 ++0.100 accept(3, ..., ...) = -1 EWOULDBLOCK (operation would block) + +// set accept socket back to blocking. ++0.000 fcntl(3, F_SETFL, O_RDWR) = 0 + +// Now send an ACK w/ data. This should complete the connection +// and deliver the socket to the accept queue. ++0.1 < . 1:5(4) ack 1 win 257 ++0.0 > . 1:1(0) ack 5 <...> + +// This should cause connection to transition to connected state. ++0.000 accept(3, ..., ...) = 4 ++0.000 fcntl(4, F_SETFL, O_RDWR|O_NONBLOCK) = 0 + +// Now read the data and we should get 4 bytes. ++0.000 read(4,..., 4) = 4 ++0.000 close(4) = 0 + ++0.0 > F. 1:1(0) ack 5 <...> ++0.0 < F. 5:5(0) ack 2 win 257 ++0.01 > . 2:2(0) ack 6 <...>
\ No newline at end of file diff --git a/test/packetdrill/tcp_defer_accept_timeout.pkt b/test/packetdrill/tcp_defer_accept_timeout.pkt new file mode 100644 index 000000000..201fdeb14 --- /dev/null +++ b/test/packetdrill/tcp_defer_accept_timeout.pkt @@ -0,0 +1,48 @@ +// Test that a bare ACK is accepted after TCP_DEFER_ACCEPT timeout +// is hit and a connection is delivered. + +0 socket(..., SOCK_STREAM, IPPROTO_TCP) = 3 ++0 setsockopt(3, SOL_TCP, TCP_DEFER_ACCEPT, [3], 4) = 0 ++0.000 fcntl(3, F_SETFL, O_RDWR|O_NONBLOCK) = 0 ++0 bind(3, ..., ...) = 0 + +// Set backlog to 1 so that we can easily test. ++0 listen(3, 1) = 0 + +// Establish a connection without timestamps. ++0.0 < S 0:0(0) win 32792 <mss 1460,sackOK,nop,nop,nop,wscale 7> ++0.0 > S. 0:0(0) ack 1 <...> + +// Send a bare ACK this should not complete the connection as we +// set the TCP_DEFER_ACCEPT above. ++0.0 < . 1:1(0) ack 1 win 257 + +// The bare ACK should be dropped and no connection should be delivered +// to the accept queue. ++0.100 accept(3, ..., ...) = -1 EWOULDBLOCK (operation would block) + +// Send another bare ACK and it should still fail we set TCP_DEFER_ACCEPT +// to 5 seconds above. ++2.5 < . 1:1(0) ack 1 win 257 ++0.100 accept(3, ..., ...) = -1 EWOULDBLOCK (operation would block) + +// set accept socket back to blocking. ++0.000 fcntl(3, F_SETFL, O_RDWR) = 0 + +// We should see one more retransmit of the SYN-ACK as a last ditch +// attempt when TCP_DEFER_ACCEPT timeout is hit to trigger another +// ACK or a packet with data. ++.35~+2.35 > S. 0:0(0) ack 1 <...> + +// Now send another bare ACK after TCP_DEFER_ACCEPT time has been passed. ++0.0 < . 1:1(0) ack 1 win 257 + +// The ACK above should cause connection to transition to connected state. ++0.000 accept(3, ..., ...) = 4 ++0.000 fcntl(4, F_SETFL, O_RDWR|O_NONBLOCK) = 0 + ++0.000 close(4) = 0 + ++0.0 > F. 1:1(0) ack 1 <...> ++0.0 < F. 1:1(0) ack 2 win 257 ++0.01 > . 2:2(0) ack 2 <...> diff --git a/test/perf/BUILD b/test/perf/BUILD index 346a28e16..0a0def6a3 100644 --- a/test/perf/BUILD +++ b/test/perf/BUILD @@ -30,6 +30,7 @@ syscall_test( syscall_test( size = "enormous", + tags = ["nogotsan"], test = "//test/perf/linux:getdents_benchmark", ) @@ -40,6 +41,7 @@ syscall_test( syscall_test( size = "enormous", + tags = ["nogotsan"], test = "//test/perf/linux:gettid_benchmark", ) diff --git a/test/runner/gtest/gtest.go b/test/runner/gtest/gtest.go index f96e2415e..869169ad5 100644 --- a/test/runner/gtest/gtest.go +++ b/test/runner/gtest/gtest.go @@ -66,13 +66,12 @@ func (tc TestCase) Args() []string { } if tc.benchmark { return []string{ - fmt.Sprintf("%s=^$", filterTestFlag), fmt.Sprintf("%s=^%s$", filterBenchmarkFlag, tc.Name), + fmt.Sprintf("%s=", filterTestFlag), } } return []string{ - fmt.Sprintf("%s=^%s$", filterTestFlag, tc.FullName()), - fmt.Sprintf("%s=^$", filterBenchmarkFlag), + fmt.Sprintf("%s=%s", filterTestFlag, tc.FullName()), } } @@ -147,6 +146,8 @@ func ParseTestCases(testBin string, benchmarks bool, extraArgs ...string) ([]Tes return nil, fmt.Errorf("could not enumerate gtest benchmarks: %v\nstderr\n%s", err, exitErr.Stderr) } + out = []byte(strings.Trim(string(out), "\n")) + // Parse benchmark output. for _, line := range strings.Split(string(out), "\n") { // Strip comments. diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD index 3518e862d..9800a0cdf 100644 --- a/test/syscalls/BUILD +++ b/test/syscalls/BUILD @@ -318,10 +318,14 @@ syscall_test( test = "//test/syscalls/linux:proc_test", ) -syscall_test(test = "//test/syscalls/linux:proc_pid_uid_gid_map_test") - syscall_test(test = "//test/syscalls/linux:proc_net_test") +syscall_test(test = "//test/syscalls/linux:proc_pid_oomscore_test") + +syscall_test(test = "//test/syscalls/linux:proc_pid_smaps_test") + +syscall_test(test = "//test/syscalls/linux:proc_pid_uid_gid_map_test") + syscall_test( size = "medium", test = "//test/syscalls/linux:pselect_test", @@ -680,6 +684,11 @@ syscall_test( syscall_test(test = "//test/syscalls/linux:tuntap_test") +syscall_test( + add_hostinet = True, + test = "//test/syscalls/linux:tuntap_hostinet_test", +) + syscall_test(test = "//test/syscalls/linux:udp_bind_test") syscall_test( diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index 704bae17b..43455f1a3 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -167,11 +167,6 @@ cc_library( ) cc_library( - name = "temp_umask", - hdrs = ["temp_umask.h"], -) - -cc_library( name = "unix_domain_socket_test_util", testonly = 1, srcs = ["unix_domain_socket_test_util.cc"], @@ -608,7 +603,10 @@ cc_binary( cc_binary( name = "exceptions_test", testonly = 1, - srcs = ["exceptions.cc"], + srcs = select_arch( + amd64 = ["exceptions.cc"], + arm64 = [], + ), linkstatic = 1, deps = [ gtest, @@ -1140,11 +1138,11 @@ cc_binary( srcs = ["mkdir.cc"], linkstatic = 1, deps = [ - ":temp_umask", "//test/util:capability_util", "//test/util:fs_util", gtest, "//test/util:temp_path", + "//test/util:temp_umask", "//test/util:test_main", "//test/util:test_util", ], @@ -1299,12 +1297,12 @@ cc_binary( srcs = ["open_create.cc"], linkstatic = 1, deps = [ - ":temp_umask", "//test/util:capability_util", "//test/util:file_descriptor", "//test/util:fs_util", gtest, "//test/util:temp_path", + "//test/util:temp_umask", "//test/util:test_main", "//test/util:test_util", ], @@ -1475,7 +1473,10 @@ cc_binary( cc_binary( name = "arch_prctl_test", testonly = 1, - srcs = ["arch_prctl.cc"], + srcs = select_arch( + amd64 = ["arch_prctl.cc"], + arm64 = [], + ), linkstatic = 1, deps = [ "//test/util:file_descriptor", @@ -1631,6 +1632,19 @@ cc_binary( ) cc_binary( + name = "proc_pid_oomscore_test", + testonly = 1, + srcs = ["proc_pid_oomscore.cc"], + linkstatic = 1, + deps = [ + "//test/util:fs_util", + "//test/util:test_main", + "//test/util:test_util", + "@com_google_absl//absl/strings", + ], +) + +cc_binary( name = "proc_pid_smaps_test", testonly = 1, srcs = ["proc_pid_smaps.cc"], @@ -3322,7 +3336,10 @@ cc_binary( cc_binary( name = "sysret_test", testonly = 1, - srcs = ["sysret.cc"], + srcs = select_arch( + amd64 = ["sysret.cc"], + arm64 = [], + ), linkstatic = 1, deps = [ gtest, @@ -3460,6 +3477,18 @@ cc_binary( ], ) +cc_binary( + name = "tuntap_hostinet_test", + testonly = 1, + srcs = ["tuntap_hostinet.cc"], + linkstatic = 1, + deps = [ + gtest, + "//test/util:test_main", + "//test/util:test_util", + ], +) + cc_library( name = "udp_socket_test_cases", testonly = 1, diff --git a/test/syscalls/linux/bad.cc b/test/syscalls/linux/bad.cc index adfb149df..a26fc6af3 100644 --- a/test/syscalls/linux/bad.cc +++ b/test/syscalls/linux/bad.cc @@ -28,7 +28,7 @@ namespace { constexpr uint32_t kNotImplementedSyscall = SYS_get_kernel_syms; #elif __aarch64__ // Use the last of arch_specific_syscalls which are not implemented on arm64. -constexpr uint32_t kNotImplementedSyscall = SYS_arch_specific_syscall + 15; +constexpr uint32_t kNotImplementedSyscall = __NR_arch_specific_syscall + 15; #endif TEST(BadSyscallTest, NotImplemented) { diff --git a/test/syscalls/linux/dev.cc b/test/syscalls/linux/dev.cc index 4e473268c..4dd302eed 100644 --- a/test/syscalls/linux/dev.cc +++ b/test/syscalls/linux/dev.cc @@ -153,13 +153,6 @@ TEST(DevTest, TTYExists) { EXPECT_EQ(statbuf.st_mode, S_IFCHR | 0666); } -TEST(DevTest, NetTunExists) { - struct stat statbuf = {}; - ASSERT_THAT(stat("/dev/net/tun", &statbuf), SyscallSucceeds()); - // Check that it's a character device with rw-rw-rw- permissions. - EXPECT_EQ(statbuf.st_mode, S_IFCHR | 0666); -} - } // namespace } // namespace testing diff --git a/test/syscalls/linux/mkdir.cc b/test/syscalls/linux/mkdir.cc index cf138d328..def4c50a4 100644 --- a/test/syscalls/linux/mkdir.cc +++ b/test/syscalls/linux/mkdir.cc @@ -18,10 +18,10 @@ #include <unistd.h> #include "gtest/gtest.h" -#include "test/syscalls/linux/temp_umask.h" #include "test/util/capability_util.h" #include "test/util/fs_util.h" #include "test/util/temp_path.h" +#include "test/util/temp_umask.h" #include "test/util/test_util.h" namespace gvisor { diff --git a/test/syscalls/linux/open_create.cc b/test/syscalls/linux/open_create.cc index 902d0a0dc..51eacf3f2 100644 --- a/test/syscalls/linux/open_create.cc +++ b/test/syscalls/linux/open_create.cc @@ -19,11 +19,11 @@ #include <unistd.h> #include "gtest/gtest.h" -#include "test/syscalls/linux/temp_umask.h" #include "test/util/capability_util.h" #include "test/util/file_descriptor.h" #include "test/util/fs_util.h" #include "test/util/temp_path.h" +#include "test/util/temp_umask.h" #include "test/util/test_util.h" namespace gvisor { diff --git a/test/syscalls/linux/packet_socket.cc b/test/syscalls/linux/packet_socket.cc index 92ae55eec..248762ca9 100644 --- a/test/syscalls/linux/packet_socket.cc +++ b/test/syscalls/linux/packet_socket.cc @@ -13,6 +13,7 @@ // limitations under the License. #include <arpa/inet.h> +#include <ifaddrs.h> #include <linux/capability.h> #include <linux/if_arp.h> #include <linux/if_packet.h> @@ -163,16 +164,11 @@ int CookedPacketTest::GetLoopbackIndex() { return ifr.ifr_ifindex; } -// Receive via a packet socket. -TEST_P(CookedPacketTest, Receive) { - // Let's use a simple IP payload: a UDP datagram. - FileDescriptor udp_sock = - ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); - SendUDPMessage(udp_sock.get()); - +// Receive and verify the message via packet socket on interface. +void ReceiveMessage(int sock, int ifindex) { // Wait for the socket to become readable. struct pollfd pfd = {}; - pfd.fd = socket_; + pfd.fd = sock; pfd.events = POLLIN; EXPECT_THAT(RetryEINTR(poll)(&pfd, 1, 2000), SyscallSucceedsWithValue(1)); @@ -182,9 +178,10 @@ TEST_P(CookedPacketTest, Receive) { char buf[64]; struct sockaddr_ll src = {}; socklen_t src_len = sizeof(src); - ASSERT_THAT(recvfrom(socket_, buf, sizeof(buf), 0, + ASSERT_THAT(recvfrom(sock, buf, sizeof(buf), 0, reinterpret_cast<struct sockaddr*>(&src), &src_len), SyscallSucceedsWithValue(packet_size)); + // sockaddr_ll ends with an 8 byte physical address field, but ethernet // addresses only use 6 bytes. Linux used to return sizeof(sockaddr_ll)-2 // here, but since commit b2cf86e1563e33a14a1c69b3e508d15dc12f804c returns @@ -194,7 +191,7 @@ TEST_P(CookedPacketTest, Receive) { // TODO(b/129292371): Verify protocol once we return it. // Verify the source address. EXPECT_EQ(src.sll_family, AF_PACKET); - EXPECT_EQ(src.sll_ifindex, GetLoopbackIndex()); + EXPECT_EQ(src.sll_ifindex, ifindex); EXPECT_EQ(src.sll_halen, ETH_ALEN); // This came from the loopback device, so the address is all 0s. for (int i = 0; i < src.sll_halen; i++) { @@ -222,6 +219,18 @@ TEST_P(CookedPacketTest, Receive) { EXPECT_EQ(strncmp(payload, kMessage, sizeof(kMessage)), 0); } +// Receive via a packet socket. +TEST_P(CookedPacketTest, Receive) { + // Let's use a simple IP payload: a UDP datagram. + FileDescriptor udp_sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); + SendUDPMessage(udp_sock.get()); + + // Receive and verify the data. + int loopback_index = GetLoopbackIndex(); + ReceiveMessage(socket_, loopback_index); +} + // Send via a packet socket. TEST_P(CookedPacketTest, Send) { // TODO(b/129292371): Remove once we support packet socket writing. @@ -313,6 +322,114 @@ TEST_P(CookedPacketTest, Send) { EXPECT_EQ(src.sin_addr.s_addr, htonl(INADDR_LOOPBACK)); } +// Bind and receive via packet socket. +TEST_P(CookedPacketTest, BindReceive) { + struct sockaddr_ll bind_addr = {}; + bind_addr.sll_family = AF_PACKET; + bind_addr.sll_protocol = htons(GetParam()); + bind_addr.sll_ifindex = GetLoopbackIndex(); + + ASSERT_THAT(bind(socket_, reinterpret_cast<struct sockaddr*>(&bind_addr), + sizeof(bind_addr)), + SyscallSucceeds()); + + // Let's use a simple IP payload: a UDP datagram. + FileDescriptor udp_sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); + SendUDPMessage(udp_sock.get()); + + // Receive and verify the data. + ReceiveMessage(socket_, bind_addr.sll_ifindex); +} + +// Double Bind socket. +TEST_P(CookedPacketTest, DoubleBind) { + struct sockaddr_ll bind_addr = {}; + bind_addr.sll_family = AF_PACKET; + bind_addr.sll_protocol = htons(GetParam()); + bind_addr.sll_ifindex = GetLoopbackIndex(); + + ASSERT_THAT(bind(socket_, reinterpret_cast<struct sockaddr*>(&bind_addr), + sizeof(bind_addr)), + SyscallSucceeds()); + + // Binding socket again should fail. + ASSERT_THAT( + bind(socket_, reinterpret_cast<struct sockaddr*>(&bind_addr), + sizeof(bind_addr)), + // Linux 4.09 returns EINVAL here, but some time before 4.19 it switched + // to EADDRINUSE. + AnyOf(SyscallFailsWithErrno(EADDRINUSE), SyscallFailsWithErrno(EINVAL))); +} + +// Bind and verify we do not receive data on interface which is not bound +TEST_P(CookedPacketTest, BindDrop) { + // Let's use a simple IP payload: a UDP datagram. + FileDescriptor udp_sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); + + struct ifaddrs* if_addr_list = nullptr; + auto cleanup = Cleanup([&if_addr_list]() { freeifaddrs(if_addr_list); }); + + ASSERT_THAT(getifaddrs(&if_addr_list), SyscallSucceeds()); + + // Get interface other than loopback. + struct ifreq ifr = {}; + for (struct ifaddrs* i = if_addr_list; i; i = i->ifa_next) { + if (strcmp(i->ifa_name, "lo") != 0) { + strncpy(ifr.ifr_name, i->ifa_name, sizeof(ifr.ifr_name)); + break; + } + } + + // Skip if no interface is available other than loopback. + if (strlen(ifr.ifr_name) == 0) { + GTEST_SKIP(); + } + + // Get interface index. + EXPECT_THAT(ioctl(socket_, SIOCGIFINDEX, &ifr), SyscallSucceeds()); + EXPECT_NE(ifr.ifr_ifindex, 0); + + // Bind to packet socket requires only family, protocol and ifindex. + struct sockaddr_ll bind_addr = {}; + bind_addr.sll_family = AF_PACKET; + bind_addr.sll_protocol = htons(GetParam()); + bind_addr.sll_ifindex = ifr.ifr_ifindex; + + ASSERT_THAT(bind(socket_, reinterpret_cast<struct sockaddr*>(&bind_addr), + sizeof(bind_addr)), + SyscallSucceeds()); + + // Send to loopback interface. + struct sockaddr_in dest = {}; + dest.sin_addr.s_addr = htonl(INADDR_LOOPBACK); + dest.sin_family = AF_INET; + dest.sin_port = kPort; + EXPECT_THAT(sendto(udp_sock.get(), kMessage, sizeof(kMessage), 0, + reinterpret_cast<struct sockaddr*>(&dest), sizeof(dest)), + SyscallSucceedsWithValue(sizeof(kMessage))); + + // Wait and make sure the socket never receives any data. + struct pollfd pfd = {}; + pfd.fd = socket_; + pfd.events = POLLIN; + EXPECT_THAT(RetryEINTR(poll)(&pfd, 1, 1000), SyscallSucceedsWithValue(0)); +} + +// Bind with invalid address. +TEST_P(CookedPacketTest, BindFail) { + // Null address. + ASSERT_THAT(bind(socket_, nullptr, sizeof(struct sockaddr)), + SyscallFailsWithErrno(EFAULT)); + + // Address of size 1. + uint8_t addr = 0; + ASSERT_THAT( + bind(socket_, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)), + SyscallFailsWithErrno(EINVAL)); +} + INSTANTIATE_TEST_SUITE_P(AllInetTests, CookedPacketTest, ::testing::Values(ETH_P_IP, ETH_P_ALL)); diff --git a/test/syscalls/linux/proc.cc b/test/syscalls/linux/proc.cc index f91187e75..5a70f6c3b 100644 --- a/test/syscalls/linux/proc.cc +++ b/test/syscalls/linux/proc.cc @@ -1431,6 +1431,12 @@ TEST(ProcPidFile, SubprocessRunning) { EXPECT_THAT(ReadWhileRunning("uid_map", buf, sizeof(buf)), SyscallSucceedsWithValue(sizeof(buf))); + + EXPECT_THAT(ReadWhileRunning("oom_score", buf, sizeof(buf)), + SyscallSucceedsWithValue(sizeof(buf))); + + EXPECT_THAT(ReadWhileRunning("oom_score_adj", buf, sizeof(buf)), + SyscallSucceedsWithValue(sizeof(buf))); } // Test whether /proc/PID/ files can be read for a zombie process. @@ -1466,6 +1472,12 @@ TEST(ProcPidFile, SubprocessZombie) { EXPECT_THAT(ReadWhileZombied("uid_map", buf, sizeof(buf)), SyscallSucceedsWithValue(sizeof(buf))); + EXPECT_THAT(ReadWhileZombied("oom_score", buf, sizeof(buf)), + SyscallSucceedsWithValue(sizeof(buf))); + + EXPECT_THAT(ReadWhileZombied("oom_score_adj", buf, sizeof(buf)), + SyscallSucceedsWithValue(sizeof(buf))); + // FIXME(gvisor.dev/issue/164): Inconsistent behavior between gVisor and linux // on proc files. // @@ -1527,6 +1539,15 @@ TEST(ProcPidFile, SubprocessExited) { EXPECT_THAT(ReadWhileExited("uid_map", buf, sizeof(buf)), SyscallSucceedsWithValue(sizeof(buf))); + + if (!IsRunningOnGvisor()) { + // FIXME(gvisor.dev/issue/164): Succeeds on gVisor. + EXPECT_THAT(ReadWhileExited("oom_score", buf, sizeof(buf)), + SyscallFailsWithErrno(ESRCH)); + } + + EXPECT_THAT(ReadWhileExited("oom_score_adj", buf, sizeof(buf)), + SyscallFailsWithErrno(ESRCH)); } PosixError DirContainsImpl(absl::string_view path, diff --git a/test/syscalls/linux/proc_pid_oomscore.cc b/test/syscalls/linux/proc_pid_oomscore.cc new file mode 100644 index 000000000..707821a3f --- /dev/null +++ b/test/syscalls/linux/proc_pid_oomscore.cc @@ -0,0 +1,72 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <errno.h> + +#include <exception> +#include <iostream> +#include <string> + +#include "test/util/fs_util.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +PosixErrorOr<int> ReadProcNumber(std::string path) { + ASSIGN_OR_RETURN_ERRNO(std::string contents, GetContents(path)); + EXPECT_EQ(contents[contents.length() - 1], '\n'); + + int num; + if (!absl::SimpleAtoi(contents, &num)) { + return PosixError(EINVAL, "invalid value: " + contents); + } + + return num; +} + +TEST(ProcPidOomscoreTest, BasicRead) { + auto const oom_score = + ASSERT_NO_ERRNO_AND_VALUE(ReadProcNumber("/proc/self/oom_score")); + EXPECT_LE(oom_score, 1000); + EXPECT_GE(oom_score, -1000); +} + +TEST(ProcPidOomscoreAdjTest, BasicRead) { + auto const oom_score = + ASSERT_NO_ERRNO_AND_VALUE(ReadProcNumber("/proc/self/oom_score_adj")); + + // oom_score_adj defaults to 0. + EXPECT_EQ(oom_score, 0); +} + +TEST(ProcPidOomscoreAdjTest, BasicWrite) { + constexpr int test_value = 7; + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open("/proc/self/oom_score_adj", O_WRONLY)); + ASSERT_THAT( + RetryEINTR(write)(fd.get(), std::to_string(test_value).c_str(), 1), + SyscallSucceeds()); + + auto const oom_score = + ASSERT_NO_ERRNO_AND_VALUE(ReadProcNumber("/proc/self/oom_score_adj")); + EXPECT_EQ(oom_score, test_value); +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/seccomp.cc b/test/syscalls/linux/seccomp.cc index cf6499f8b..8e0fc9acc 100644 --- a/test/syscalls/linux/seccomp.cc +++ b/test/syscalls/linux/seccomp.cc @@ -53,7 +53,7 @@ namespace { constexpr uint32_t kFilteredSyscall = SYS_vserver; #elif __aarch64__ // Use the last of arch_specific_syscalls which are not implemented on arm64. -constexpr uint32_t kFilteredSyscall = SYS_arch_specific_syscall + 15; +constexpr uint32_t kFilteredSyscall = __NR_arch_specific_syscall + 15; #endif // Applies a seccomp-bpf filter that returns `filtered_result` for diff --git a/test/syscalls/linux/socket_generic.cc b/test/syscalls/linux/socket_generic.cc index e8f24a59e..f7d6139f1 100644 --- a/test/syscalls/linux/socket_generic.cc +++ b/test/syscalls/linux/socket_generic.cc @@ -447,6 +447,60 @@ TEST_P(AllSocketPairTest, RecvTimeoutRecvmsgSucceeds) { SyscallFailsWithErrno(EAGAIN)); } +TEST_P(AllSocketPairTest, SendTimeoutDefault) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + timeval actual_tv = {.tv_sec = -1, .tv_usec = -1}; + socklen_t len = sizeof(actual_tv); + EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO, + &actual_tv, &len), + SyscallSucceeds()); + EXPECT_EQ(actual_tv.tv_sec, 0); + EXPECT_EQ(actual_tv.tv_usec, 0); +} + +TEST_P(AllSocketPairTest, SetGetSendTimeout) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + timeval tv = {.tv_sec = 89, .tv_usec = 42000}; + EXPECT_THAT( + setsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)), + SyscallSucceeds()); + + timeval actual_tv = {}; + socklen_t len = sizeof(actual_tv); + EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO, + &actual_tv, &len), + SyscallSucceeds()); + EXPECT_EQ(actual_tv.tv_sec, 89); + EXPECT_EQ(actual_tv.tv_usec, 42000); +} + +TEST_P(AllSocketPairTest, SetGetSendTimeoutLargerArg) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + struct timeval_with_extra { + struct timeval tv; + int64_t extra_data; + } ABSL_ATTRIBUTE_PACKED; + + timeval_with_extra tv_extra = { + .tv = {.tv_sec = 0, .tv_usec = 123000}, + }; + + EXPECT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO, + &tv_extra, sizeof(tv_extra)), + SyscallSucceeds()); + + timeval_with_extra actual_tv = {}; + socklen_t len = sizeof(actual_tv); + EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO, + &actual_tv, &len), + SyscallSucceeds()); + EXPECT_EQ(actual_tv.tv.tv_sec, 0); + EXPECT_EQ(actual_tv.tv.tv_usec, 123000); +} + TEST_P(AllSocketPairTest, SendTimeoutAllowsWrite) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); @@ -491,18 +545,36 @@ TEST_P(AllSocketPairTest, SendTimeoutAllowsSendmsg) { ASSERT_NO_FATAL_FAILURE(SendNullCmsg(sockets->first_fd(), buf, sizeof(buf))); } -TEST_P(AllSocketPairTest, SoRcvTimeoIsSet) { +TEST_P(AllSocketPairTest, RecvTimeoutDefault) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - struct timeval tv { - .tv_sec = 0, .tv_usec = 35 - }; + timeval actual_tv = {.tv_sec = -1, .tv_usec = -1}; + socklen_t len = sizeof(actual_tv); + EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVTIMEO, + &actual_tv, &len), + SyscallSucceeds()); + EXPECT_EQ(actual_tv.tv_sec, 0); + EXPECT_EQ(actual_tv.tv_usec, 0); +} + +TEST_P(AllSocketPairTest, SetGetRecvTimeout) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + timeval tv = {.tv_sec = 123, .tv_usec = 456000}; EXPECT_THAT( setsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)), SyscallSucceeds()); + + timeval actual_tv = {}; + socklen_t len = sizeof(actual_tv); + EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVTIMEO, + &actual_tv, &len), + SyscallSucceeds()); + EXPECT_EQ(actual_tv.tv_sec, 123); + EXPECT_EQ(actual_tv.tv_usec, 456000); } -TEST_P(AllSocketPairTest, SoRcvTimeoIsSetLargerArg) { +TEST_P(AllSocketPairTest, SetGetRecvTimeoutLargerArg) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); struct timeval_with_extra { @@ -510,13 +582,21 @@ TEST_P(AllSocketPairTest, SoRcvTimeoIsSetLargerArg) { int64_t extra_data; } ABSL_ATTRIBUTE_PACKED; - timeval_with_extra tv_extra; - tv_extra.tv.tv_sec = 0; - tv_extra.tv.tv_usec = 25; + timeval_with_extra tv_extra = { + .tv = {.tv_sec = 0, .tv_usec = 432000}, + }; EXPECT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVTIMEO, &tv_extra, sizeof(tv_extra)), SyscallSucceeds()); + + timeval_with_extra actual_tv = {}; + socklen_t len = sizeof(actual_tv); + EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVTIMEO, + &actual_tv, &len), + SyscallSucceeds()); + EXPECT_EQ(actual_tv.tv.tv_sec, 0); + EXPECT_EQ(actual_tv.tv.tv_usec, 432000); } TEST_P(AllSocketPairTest, RecvTimeoutRecvmsgOneSecondSucceeds) { diff --git a/test/syscalls/linux/tcp_socket.cc b/test/syscalls/linux/tcp_socket.cc index c4591a3b9..d9c1ac0e1 100644 --- a/test/syscalls/linux/tcp_socket.cc +++ b/test/syscalls/linux/tcp_socket.cc @@ -143,6 +143,20 @@ TEST_P(TcpSocketTest, ConnectOnEstablishedConnection) { SyscallFailsWithErrno(EISCONN)); } +TEST_P(TcpSocketTest, ShutdownWriteInTimeWait) { + EXPECT_THAT(shutdown(t_, SHUT_WR), SyscallSucceeds()); + EXPECT_THAT(shutdown(s_, SHUT_RDWR), SyscallSucceeds()); + absl::SleepFor(absl::Seconds(1)); // Wait to enter TIME_WAIT. + EXPECT_THAT(shutdown(t_, SHUT_WR), SyscallFailsWithErrno(ENOTCONN)); +} + +TEST_P(TcpSocketTest, ShutdownWriteInFinWait1) { + EXPECT_THAT(shutdown(t_, SHUT_WR), SyscallSucceeds()); + EXPECT_THAT(shutdown(t_, SHUT_WR), SyscallSucceeds()); + absl::SleepFor(absl::Seconds(1)); // Wait to enter FIN-WAIT2. + EXPECT_THAT(shutdown(t_, SHUT_WR), SyscallSucceeds()); +} + TEST_P(TcpSocketTest, DataCoalesced) { char buf[10]; @@ -1349,6 +1363,21 @@ TEST_P(SimpleTcpSocketTest, RecvOnClosedSocket) { SyscallFailsWithErrno(ENOTCONN)); } +TEST_P(SimpleTcpSocketTest, TCPConnectSoRcvBufRace) { + auto s = ASSERT_NO_ERRNO_AND_VALUE( + Socket(GetParam(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP)); + sockaddr_storage addr = + ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam())); + socklen_t addrlen = sizeof(addr); + + RetryEINTR(connect)(s.get(), reinterpret_cast<struct sockaddr*>(&addr), + addrlen); + int buf_sz = 1 << 18; + EXPECT_THAT( + setsockopt(s.get(), SOL_SOCKET, SO_RCVBUF, &buf_sz, sizeof(buf_sz)), + SyscallSucceedsWithValue(0)); +} + INSTANTIATE_TEST_SUITE_P(AllInetTests, SimpleTcpSocketTest, ::testing::Values(AF_INET, AF_INET6)); diff --git a/test/syscalls/linux/time.cc b/test/syscalls/linux/time.cc index 1ccb95733..e75bba669 100644 --- a/test/syscalls/linux/time.cc +++ b/test/syscalls/linux/time.cc @@ -26,6 +26,7 @@ namespace { constexpr long kFudgeSeconds = 5; +#if defined(__x86_64__) || defined(__i386__) // Mimics the time(2) wrapper from glibc prior to 2.15. time_t vsyscall_time(time_t* t) { constexpr uint64_t kVsyscallTimeEntry = 0xffffffffff600400; @@ -98,6 +99,7 @@ TEST(TimeTest, VsyscallGettimeofday_InvalidAddressSIGSEGV) { reinterpret_cast<struct timezone*>(0x1)), ::testing::KilledBySignal(SIGSEGV), ""); } +#endif } // namespace diff --git a/test/syscalls/linux/tuntap.cc b/test/syscalls/linux/tuntap.cc index f6ac9d7b8..f734511d6 100644 --- a/test/syscalls/linux/tuntap.cc +++ b/test/syscalls/linux/tuntap.cc @@ -153,6 +153,13 @@ std::string CreateArpPacket(const uint8_t srcmac[ETH_ALEN], const char* srcip, } // namespace +TEST(TuntapStaticTest, NetTunExists) { + struct stat statbuf; + ASSERT_THAT(stat(kDevNetTun, &statbuf), SyscallSucceeds()); + // Check that it's a character device with rw-rw-rw- permissions. + EXPECT_EQ(statbuf.st_mode, S_IFCHR | 0666); +} + class TuntapTest : public ::testing::Test { protected: void TearDown() override { diff --git a/test/syscalls/linux/tuntap_hostinet.cc b/test/syscalls/linux/tuntap_hostinet.cc new file mode 100644 index 000000000..1513fb9d5 --- /dev/null +++ b/test/syscalls/linux/tuntap_hostinet.cc @@ -0,0 +1,38 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <sys/stat.h> +#include <sys/types.h> +#include <unistd.h> + +#include "gtest/gtest.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +TEST(TuntapHostInetTest, NoNetTun) { + SKIP_IF(!IsRunningOnGvisor()); + SKIP_IF(!IsRunningWithHostinet()); + + struct stat statbuf; + ASSERT_THAT(stat("/dev/net/tun", &statbuf), SyscallFailsWithErrno(ENOENT)); +} + +} // namespace +} // namespace testing + +} // namespace gvisor diff --git a/test/syscalls/linux/vsyscall.cc b/test/syscalls/linux/vsyscall.cc index 2c2303358..ae4377108 100644 --- a/test/syscalls/linux/vsyscall.cc +++ b/test/syscalls/linux/vsyscall.cc @@ -24,6 +24,7 @@ namespace testing { namespace { +#if defined(__x86_64__) || defined(__i386__) time_t vsyscall_time(time_t* t) { constexpr uint64_t kVsyscallTimeEntry = 0xffffffffff600400; return reinterpret_cast<time_t (*)(time_t*)>(kVsyscallTimeEntry)(t); @@ -37,6 +38,7 @@ TEST(VsyscallTest, VsyscallAlwaysAvailableOnGvisor) { time_t t; EXPECT_THAT(vsyscall_time(&t), SyscallSucceeds()); } +#endif } // namespace diff --git a/test/util/BUILD b/test/util/BUILD index 8b5a0f25c..2a17c33ee 100644 --- a/test/util/BUILD +++ b/test/util/BUILD @@ -350,3 +350,9 @@ cc_library( ":save_util", ], ) + +cc_library( + name = "temp_umask", + testonly = 1, + hdrs = ["temp_umask.h"], +) diff --git a/test/syscalls/linux/temp_umask.h b/test/util/temp_umask.h index 81a25440c..e7de84a54 100644 --- a/test/syscalls/linux/temp_umask.h +++ b/test/util/temp_umask.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef GVISOR_TEST_SYSCALLS_TEMP_UMASK_H_ -#define GVISOR_TEST_SYSCALLS_TEMP_UMASK_H_ +#ifndef GVISOR_TEST_UTIL_TEMP_UMASK_H_ +#define GVISOR_TEST_UTIL_TEMP_UMASK_H_ #include <sys/stat.h> #include <sys/types.h> @@ -36,4 +36,4 @@ class TempUmask { } // namespace testing } // namespace gvisor -#endif // GVISOR_TEST_SYSCALLS_TEMP_UMASK_H_ +#endif // GVISOR_TEST_UTIL_TEMP_UMASK_H_ diff --git a/tools/go_marshal/gomarshal/BUILD b/tools/go_marshal/gomarshal/BUILD index b5d5a4487..44cb33ae4 100644 --- a/tools/go_marshal/gomarshal/BUILD +++ b/tools/go_marshal/gomarshal/BUILD @@ -7,6 +7,9 @@ go_library( srcs = [ "generator.go", "generator_interfaces.go", + "generator_interfaces_array_newtype.go", + "generator_interfaces_primitive_newtype.go", + "generator_interfaces_struct.go", "generator_tests.go", "util.go", ], diff --git a/tools/go_marshal/gomarshal/generator.go b/tools/go_marshal/gomarshal/generator.go index d365a1f3c..729489de5 100644 --- a/tools/go_marshal/gomarshal/generator.go +++ b/tools/go_marshal/gomarshal/generator.go @@ -235,6 +235,10 @@ func (g *Generator) collectMarshallableTypes(a *ast.File, f *token.FileSet) []*a debugfAt(f.Position(t.Pos()), "Collected marshallable newtype on primitive %s.\n", t.Name.Name) types = append(types, t) continue + case *ast.ArrayType: // Newtype on array. + debugfAt(f.Position(t.Pos()), "Collected marshallable newtype on array %s.\n", t.Name.Name) + types = append(types, t) + continue } // A user specifically requested marshalling on this type, but we // don't support it. @@ -281,17 +285,20 @@ func (g *Generator) generateOne(t *ast.TypeSpec, fset *token.FileSet) *interface i := newInterfaceGenerator(t, fset) switch ty := t.Type.(type) { case *ast.StructType: - i.validateStruct() - i.emitMarshallableForStruct() - return i + i.validateStruct(t, ty) + i.emitMarshallableForStruct(ty) case *ast.Ident: i.validatePrimitiveNewtype(ty) - i.emitMarshallableForPrimitiveNewtype() - return i + i.emitMarshallableForPrimitiveNewtype(ty) + case *ast.ArrayType: + i.validateArrayNewtype(t.Name, ty) + // After validate, we can safely call arrayLen. + i.emitMarshallableForArrayNewtype(t.Name, ty.Elt.(*ast.Ident), arrayLen(ty)) default: // This should've been filtered out by collectMarshallabeTypes. panic(fmt.Sprintf("Unexpected type %+v", ty)) } + return i } // generateOneTestSuite generates a test suite for the automatically generated diff --git a/tools/go_marshal/gomarshal/generator_interfaces.go b/tools/go_marshal/gomarshal/generator_interfaces.go index ea1af998e..8babf61d2 100644 --- a/tools/go_marshal/gomarshal/generator_interfaces.go +++ b/tools/go_marshal/gomarshal/generator_interfaces.go @@ -15,10 +15,8 @@ package gomarshal import ( - "fmt" "go/ast" "go/token" - "strings" ) // interfaceGenerator generates marshalling interfaces for a single type. @@ -81,18 +79,6 @@ func (g *interfaceGenerator) recordPotentiallyNonPackedField(fieldName string) { g.as[fieldName] = struct{}{} } -func (g *interfaceGenerator) forEachField(fn func(f *ast.Field)) { - // This is guaranteed to succeed because g.t is always a struct. - st := g.t.Type.(*ast.StructType) - for _, field := range st.Fields.List { - fn(field) - } -} - -func (g *interfaceGenerator) fieldAccessor(n *ast.Ident) string { - return fmt.Sprintf("%s.%s", g.r, n.Name) -} - // abortAt aborts the go_marshal tool with the given error message, with a // reference position to the input source. Same as abortAt, but uses g to // resolve p to position. @@ -100,71 +86,6 @@ func (g *interfaceGenerator) abortAt(p token.Pos, msg string) { abortAt(g.f.Position(p), msg) } -func (g *interfaceGenerator) validatePrimitiveNewtype(t *ast.Ident) { - switch t.Name { - case "int8", "uint8", "byte", "int16", "uint16", "int32", "uint32", "int64", "uint64": - // These are the only primitive types we're allow. Below, we provide - // suggestions for some disallowed types and reject them, then attempt - // to marshal any remaining types by invoking the marshal.Marshallable - // interface on them. If these types don't actually implement - // marshal.Marshallable, compilation of the generated code will fail - // with an appropriate error message. - return - case "int": - g.abortAt(t.Pos(), "Type 'int' has ambiguous width, use int32 or int64") - case "uint": - g.abortAt(t.Pos(), "Type 'uint' has ambiguous width, use uint32 or uint64") - case "string": - g.abortAt(t.Pos(), "Type 'string' is dynamically-sized and cannot be marshalled, use a fixed size byte array '[...]byte' instead") - default: - debugfAt(g.f.Position(t.Pos()), fmt.Sprintf("Found derived type '%s', will attempt dispatch via marshal.Marshallable.\n", t.Name)) - } -} - -// validateStruct ensures the type we're working with can be marshalled. These -// checks are done ahead of time and in one place so we can make assumptions -// later. -func (g *interfaceGenerator) validateStruct() { - g.forEachField(func(f *ast.Field) { - if len(f.Names) == 0 { - g.abortAt(f.Pos(), "Cannot marshal structs with embedded fields, give the field a name; use '_' for anonymous fields such as padding fields") - } - }) - - g.forEachField(func(f *ast.Field) { - fieldDispatcher{ - primitive: func(_, t *ast.Ident) { - g.validatePrimitiveNewtype(t) - }, - selector: func(_, _, _ *ast.Ident) { - // No validation to perform on selector fields. However this - // callback must still be provided. - }, - array: func(n, _ *ast.Ident, len int) { - a := f.Type.(*ast.ArrayType) - if a.Len == nil { - g.abortAt(f.Pos(), fmt.Sprintf("Dynamically sized slice '%s' cannot be marshalled, arrays must be statically sized", n.Name)) - } - - if _, ok := a.Len.(*ast.BasicLit); !ok { - g.abortAt(a.Len.Pos(), fmt.Sprintf("Array size must be a literal, don's use consts or expressions")) - } - - if _, ok := a.Elt.(*ast.Ident); !ok { - g.abortAt(a.Elt.Pos(), fmt.Sprintf("Marshalling not supported for arrays with %s elements, array elements must be primitive types", kindString(a.Elt))) - } - - if len <= 0 { - g.abortAt(a.Len.Pos(), fmt.Sprintf("Marshalling not supported for zero length arrays, why does an ABI struct have one?")) - } - }, - unhandled: func(_ *ast.Ident) { - g.abortAt(f.Pos(), fmt.Sprintf("Marshalling not supported for %s fields", kindString(f.Type))) - }, - }.dispatch(f) - }) -} - // scalarSize returns the size of type identified by t. If t isn't a primitive // type, the size isn't known at code generation time, and must be resolved via // the marshal.Marshallable interface. @@ -191,8 +112,8 @@ func (g *interfaceGenerator) shiftDynamic(bufVar, name string) { g.emit("%s = %s[%s.SizeBytes():]\n", bufVar, bufVar, name) } -// marshalStructFieldScalar writes a single scalar field from a struct to a byte slice. -func (g *interfaceGenerator) marshalStructFieldScalar(accessor, typ, bufVar string) { +// marshalScalar writes a single scalar to a byte slice. +func (g *interfaceGenerator) marshalScalar(accessor, typ, bufVar string) { switch typ { case "int8", "uint8", "byte": g.emit("%s[0] = byte(%s)\n", bufVar, accessor) @@ -215,9 +136,8 @@ func (g *interfaceGenerator) marshalStructFieldScalar(accessor, typ, bufVar stri } } -// unmarshalStructFieldScalar reads a single scalar field from a struct, from a -// byte slice. -func (g *interfaceGenerator) unmarshalStructFieldScalar(accessor, typ, bufVar string) { +// unmarshalScalar reads a single scalar from a byte slice. +func (g *interfaceGenerator) unmarshalScalar(accessor, typ, bufVar string) { switch typ { case "byte": g.emit("%s = %s[0]\n", accessor, bufVar) @@ -243,580 +163,3 @@ func (g *interfaceGenerator) unmarshalStructFieldScalar(accessor, typ, bufVar st g.recordPotentiallyNonPackedField(accessor) } } - -// marshalPrimitiveScalar writes a single primitive variable to a byte slice. -func (g *interfaceGenerator) marshalPrimitiveScalar(accessor, typ, bufVar string) { - switch typ { - case "int8", "uint8", "byte": - g.emit("%s[0] = byte(*%s)\n", bufVar, accessor) - case "int16", "uint16": - g.recordUsedImport("usermem") - g.emit("usermem.ByteOrder.PutUint16(%s[:2], uint16(*%s))\n", bufVar, accessor) - case "int32", "uint32": - g.recordUsedImport("usermem") - g.emit("usermem.ByteOrder.PutUint32(%s[:4], uint32(*%s))\n", bufVar, accessor) - case "int64", "uint64": - g.recordUsedImport("usermem") - g.emit("usermem.ByteOrder.PutUint64(%s[:8], uint64(*%s))\n", bufVar, accessor) - default: - g.emit("inner := (*%s)(%s)\n", typ, accessor) - g.emit("inner.MarshalBytes(%s[:%s.SizeBytes()])\n", bufVar, accessor) - } -} - -// unmarshalPrimitiveScalar read a single primitive variable from a byte slice. -func (g *interfaceGenerator) unmarshalPrimitiveScalar(accessor, typ, bufVar, typeCast string) { - switch typ { - case "byte": - g.emit("*%s = %s(%s[0])\n", accessor, typeCast, bufVar) - case "int8", "uint8": - g.emit("*%s = %s(%s(%s[0]))\n", accessor, typeCast, typ, bufVar) - case "int16", "uint16": - g.recordUsedImport("usermem") - g.emit("*%s = %s(%s(usermem.ByteOrder.Uint16(%s[:2])))\n", accessor, typeCast, typ, bufVar) - case "int32", "uint32": - g.recordUsedImport("usermem") - g.emit("*%s = %s(%s(usermem.ByteOrder.Uint32(%s[:4])))\n", accessor, typeCast, typ, bufVar) - - case "int64", "uint64": - g.recordUsedImport("usermem") - g.emit("*%s = %s(%s(usermem.ByteOrder.Uint64(%s[:8])))\n", accessor, typeCast, typ, bufVar) - default: - g.emit("inner := (*%s)(%s)\n", typ, accessor) - g.emit("inner.UnmarshalBytes(%s[:%s.SizeBytes()])\n", bufVar, accessor) - } -} - -// areFieldsPackedExpression returns a go expression checking whether g.t's fields are -// packed. Returns "", false if g.t has no fields that may be potentially -// packed, otherwise returns <clause>, true, where <clause> is an expression -// like "t.a.Packed() && t.b.Packed() && t.c.Packed()". -func (g *interfaceGenerator) areFieldsPackedExpression() (string, bool) { - if len(g.as) == 0 { - return "", false - } - - cs := make([]string, 0, len(g.as)) - for accessor, _ := range g.as { - cs = append(cs, fmt.Sprintf("%s.Packed()", accessor)) - } - return strings.Join(cs, " && "), true -} - -func (g *interfaceGenerator) emitMarshallableForStruct() { - // Is g.t a packed struct without consideing field types? - thisPacked := true - g.forEachField(func(f *ast.Field) { - if f.Tag != nil { - if f.Tag.Value == "`marshal:\"unaligned\"`" { - if thisPacked { - debugfAt(g.f.Position(g.t.Pos()), - fmt.Sprintf("Marking type '%s' as not packed due to tag `marshal:\"unaligned\"`.\n", g.t.Name)) - thisPacked = false - } - } - } - }) - - g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n") - g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName()) - g.inIndent(func() { - primitiveSize := 0 - var dynamicSizeTerms []string - - g.forEachField(fieldDispatcher{ - primitive: func(n, t *ast.Ident) { - if size, dynamic := g.scalarSize(t); !dynamic { - primitiveSize += size - } else { - g.recordUsedMarshallable(t.Name) - dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", t.Name)) - } - }, - selector: func(n, tX, tSel *ast.Ident) { - tName := fmt.Sprintf("%s.%s", tX.Name, tSel.Name) - g.recordUsedImport(tX.Name) - g.recordUsedMarshallable(tName) - dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", tName)) - }, - array: func(n, t *ast.Ident, len int) { - if len < 1 { - // Zero-length arrays should've been rejected by validate(). - panic("unreachable") - } - if size, dynamic := g.scalarSize(t); !dynamic { - primitiveSize += size * len - } else { - g.recordUsedMarshallable(t.Name) - dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()*%d", t.Name, len)) - } - }, - }.dispatch) - g.emit("return %d", primitiveSize) - if len(dynamicSizeTerms) > 0 { - g.incIndent() - } - { - for _, d := range dynamicSizeTerms { - g.emitNoIndent(" +\n") - g.emit(d) - } - } - if len(dynamicSizeTerms) > 0 { - g.decIndent() - } - }) - g.emit("\n}\n\n") - - g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n") - g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName()) - g.inIndent(func() { - g.forEachField(fieldDispatcher{ - primitive: func(n, t *ast.Ident) { - if n.Name == "_" { - g.emit("// Padding: dst[:sizeof(%s)] ~= %s(0)\n", t.Name, t.Name) - if len, dynamic := g.scalarSize(t); !dynamic { - g.shift("dst", len) - } else { - // We can't use shiftDynamic here because we don't have - // an instance of the dynamic type we can referece here - // (since the version in this struct is anonymous). Use - // a typed nil pointer to call SizeBytes() instead. - g.emit("dst = dst[(*%s)(nil).SizeBytes():]\n", t.Name) - } - return - } - g.marshalStructFieldScalar(g.fieldAccessor(n), t.Name, "dst") - }, - selector: func(n, tX, tSel *ast.Ident) { - g.marshalStructFieldScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "dst") - }, - array: func(n, t *ast.Ident, size int) { - if n.Name == "_" { - g.emit("// Padding: dst[:sizeof(%s)*%d] ~= [%d]%s{0}\n", t.Name, size, size, t.Name) - if len, dynamic := g.scalarSize(t); !dynamic { - g.shift("dst", len*size) - } else { - // We can't use shiftDynamic here because we don't have - // an instance of the dynamic type we can reference here - // (since the version in this struct is anonymous). Use - // a typed nil pointer to call SizeBytes() instead. - g.emit("dst = dst[(*%s)(nil).SizeBytes()*%d:]\n", t.Name, size) - } - return - } - - g.emit("for idx := 0; idx < %d; idx++ {\n", size) - g.inIndent(func() { - g.marshalStructFieldScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "dst") - }) - g.emit("}\n") - }, - }.dispatch) - }) - g.emit("}\n\n") - - g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n") - g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName()) - g.inIndent(func() { - g.forEachField(fieldDispatcher{ - primitive: func(n, t *ast.Ident) { - if n.Name == "_" { - g.emit("// Padding: var _ %s ~= src[:sizeof(%s)]\n", t.Name, t.Name) - if len, dynamic := g.scalarSize(t); !dynamic { - g.shift("src", len) - } else { - // We can't use shiftDynamic here because we don't have - // an instance of the dynamic type we can reference here - // (since the version in this struct is anonymous). Use - // a typed nil pointer to call SizeBytes() instead. - g.emit("src = src[(*%s)(nil).SizeBytes():]\n", t.Name) - g.recordPotentiallyNonPackedField(fmt.Sprintf("(*%s)(nil)", t.Name)) - } - return - } - g.unmarshalStructFieldScalar(g.fieldAccessor(n), t.Name, "src") - }, - selector: func(n, tX, tSel *ast.Ident) { - g.unmarshalStructFieldScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "src") - }, - array: func(n, t *ast.Ident, size int) { - if n.Name == "_" { - g.emit("// Padding: ~ copy([%d]%s(%s), src[:sizeof(%s)*%d])\n", size, t.Name, g.fieldAccessor(n), t.Name, size) - if len, dynamic := g.scalarSize(t); !dynamic { - g.shift("src", len*size) - } else { - // We can't use shiftDynamic here because we don't have - // an instance of the dynamic type we can referece here - // (since the version in this struct is anonymous). Use - // a typed nil pointer to call SizeBytes() instead. - g.emit("src = src[(*%s)(nil).SizeBytes()*%d:]\n", t.Name, size) - } - return - } - - g.emit("for idx := 0; idx < %d; idx++ {\n", size) - g.inIndent(func() { - g.unmarshalStructFieldScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "src") - }) - g.emit("}\n") - }, - }.dispatch) - }) - g.emit("}\n\n") - - g.emit("// Packed implements marshal.Marshallable.Packed.\n") - g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName()) - g.inIndent(func() { - expr, fieldsMaybePacked := g.areFieldsPackedExpression() - switch { - case !thisPacked: - g.emit("return false\n") - case fieldsMaybePacked: - g.emit("return %s\n", expr) - default: - g.emit("return true\n") - - } - }) - g.emit("}\n\n") - - g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n") - g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName()) - g.inIndent(func() { - if thisPacked { - g.recordUsedImport("safecopy") - g.recordUsedImport("unsafe") - if cond, ok := g.areFieldsPackedExpression(); ok { - g.emit("if %s {\n", cond) - g.inIndent(func() { - g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r) - }) - g.emit("} else {\n") - g.inIndent(func() { - g.emit("%s.MarshalBytes(dst)\n", g.r) - }) - g.emit("}\n") - } else { - g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r) - } - } else { - g.emit("// Type %s doesn't have a packed layout in memory, fallback to MarshalBytes.\n", g.typeName()) - g.emit("%s.MarshalBytes(dst)\n", g.r) - } - }) - g.emit("}\n\n") - - g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n") - g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName()) - g.inIndent(func() { - if thisPacked { - g.recordUsedImport("safecopy") - g.recordUsedImport("unsafe") - if cond, ok := g.areFieldsPackedExpression(); ok { - g.emit("if %s {\n", cond) - g.inIndent(func() { - g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r) - }) - g.emit("} else {\n") - g.inIndent(func() { - g.emit("%s.UnmarshalBytes(src)\n", g.r) - }) - g.emit("}\n") - } else { - g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r) - } - } else { - g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName()) - g.emit("%s.UnmarshalBytes(src)\n", g.r) - } - }) - g.emit("}\n\n") - - g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n") - g.recordUsedImport("marshal") - g.recordUsedImport("usermem") - g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName()) - g.inIndent(func() { - fallback := func() { - g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName()) - g.emit("buf := task.CopyScratchBuffer(%s.SizeBytes())\n", g.r) - g.emit("%s.MarshalBytes(buf)\n", g.r) - g.emit("_, err := task.CopyOutBytes(addr, buf)\n") - g.emit("return err\n") - } - if thisPacked { - g.recordUsedImport("reflect") - g.recordUsedImport("runtime") - g.recordUsedImport("unsafe") - if cond, ok := g.areFieldsPackedExpression(); ok { - g.emit("if !%s {\n", cond) - g.inIndent(fallback) - g.emit("}\n\n") - } - // Fast serialization. - g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) - g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) - g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") - g.emit("ptr := unsafe.Pointer(%s)\n", g.r) - g.emit("val := uintptr(ptr)\n") - g.emit("val = val^0\n\n") - - g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) - g.emit("var buf []byte\n") - g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") - g.emit("hdr.Data = val\n") - g.emit("hdr.Len = %s.SizeBytes()\n", g.r) - g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) - - g.emit("_, err := task.CopyOutBytes(addr, buf)\n") - g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) - g.emit("// must live until after the CopyOutBytes.\n") - g.emit("runtime.KeepAlive(%s)\n", g.r) - g.emit("return err\n") - } else { - fallback() - } - }) - g.emit("}\n\n") - - g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n") - g.recordUsedImport("marshal") - g.recordUsedImport("usermem") - g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName()) - g.inIndent(func() { - fallback := func() { - g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName()) - g.emit("buf := task.CopyScratchBuffer(%s.SizeBytes())\n", g.r) - g.emit("_, err := task.CopyInBytes(addr, buf)\n") - g.emit("if err != nil {\n") - g.inIndent(func() { - g.emit("return err\n") - }) - g.emit("}\n") - - g.emit("%s.UnmarshalBytes(buf)\n", g.r) - g.emit("return nil\n") - } - if thisPacked { - g.recordUsedImport("reflect") - g.recordUsedImport("runtime") - g.recordUsedImport("unsafe") - if cond, ok := g.areFieldsPackedExpression(); ok { - g.emit("if !%s {\n", cond) - g.inIndent(fallback) - g.emit("}\n\n") - } - // Fast deserialization. - g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) - g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) - g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") - g.emit("ptr := unsafe.Pointer(%s)\n", g.r) - g.emit("val := uintptr(ptr)\n") - g.emit("val = val^0\n\n") - - g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) - g.emit("var buf []byte\n") - g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") - g.emit("hdr.Data = val\n") - g.emit("hdr.Len = %s.SizeBytes()\n", g.r) - g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) - - g.emit("_, err := task.CopyInBytes(addr, buf)\n") - g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) - g.emit("// must live until after the CopyInBytes.\n") - g.emit("runtime.KeepAlive(%s)\n", g.r) - g.emit("return err\n") - } else { - fallback() - } - }) - g.emit("}\n\n") - - g.emit("// WriteTo implements io.WriterTo.WriteTo.\n") - g.recordUsedImport("io") - g.emit("func (%s *%s) WriteTo(w io.Writer) (int64, error) {\n", g.r, g.typeName()) - g.inIndent(func() { - fallback := func() { - g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName()) - g.emit("buf := make([]byte, %s.SizeBytes())\n", g.r) - g.emit("%s.MarshalBytes(buf)\n", g.r) - g.emit("n, err := w.Write(buf)\n") - g.emit("return int64(n), err\n") - } - if thisPacked { - g.recordUsedImport("reflect") - g.recordUsedImport("runtime") - g.recordUsedImport("unsafe") - if cond, ok := g.areFieldsPackedExpression(); ok { - g.emit("if !%s {\n", cond) - g.inIndent(fallback) - g.emit("}\n\n") - } - // Fast serialization. - g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) - g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) - g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") - g.emit("ptr := unsafe.Pointer(%s)\n", g.r) - g.emit("val := uintptr(ptr)\n") - g.emit("val = val^0\n\n") - - g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) - g.emit("var buf []byte\n") - g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") - g.emit("hdr.Data = val\n") - g.emit("hdr.Len = %s.SizeBytes()\n", g.r) - g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) - - g.emit("len, err := w.Write(buf)\n") - g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) - g.emit("// must live until after the Write.\n") - g.emit("runtime.KeepAlive(%s)\n", g.r) - g.emit("return int64(len), err\n") - } else { - fallback() - } - }) - g.emit("}\n\n") -} - -// emitMarshallableForPrimitiveNewtype outputs code to implement the -// marshal.Marshallable interface for a newtype on a primitive. Primitive -// newtypes are always packed, so we can omit the various fallbacks required for -// non-packed structs. -func (g *interfaceGenerator) emitMarshallableForPrimitiveNewtype() { - g.recordUsedImport("io") - g.recordUsedImport("marshal") - g.recordUsedImport("reflect") - g.recordUsedImport("runtime") - g.recordUsedImport("safecopy") - g.recordUsedImport("unsafe") - g.recordUsedImport("usermem") - - nt := g.t.Type.(*ast.Ident) - - g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n") - g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName()) - g.inIndent(func() { - if size, dynamic := g.scalarSize(nt); !dynamic { - g.emit("return %d\n", size) - } else { - g.emit("return (*%s)(nil).SizeBytes()\n", nt.Name) - } - }) - g.emit("}\n\n") - - g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n") - g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName()) - g.inIndent(func() { - g.marshalPrimitiveScalar(g.r, nt.Name, "dst") - }) - g.emit("}\n\n") - - g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n") - g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName()) - g.inIndent(func() { - g.unmarshalPrimitiveScalar(g.r, nt.Name, "src", g.typeName()) - }) - g.emit("}\n\n") - - g.emit("// Packed implements marshal.Marshallable.Packed.\n") - g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName()) - g.inIndent(func() { - g.emit("// Scalar newtypes are always packed.\n") - g.emit("return true\n") - }) - g.emit("}\n\n") - - g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n") - g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName()) - g.inIndent(func() { - g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r) - }) - g.emit("}\n\n") - - g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n") - g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName()) - g.inIndent(func() { - g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r) - }) - g.emit("}\n\n") - - g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n") - g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName()) - g.inIndent(func() { - // Fast serialization. - g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) - g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) - g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") - g.emit("ptr := unsafe.Pointer(%s)\n", g.r) - g.emit("val := uintptr(ptr)\n") - g.emit("val = val^0\n\n") - - g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) - g.emit("var buf []byte\n") - g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") - g.emit("hdr.Data = val\n") - g.emit("hdr.Len = %s.SizeBytes()\n", g.r) - g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) - - g.emit("_, err := task.CopyOutBytes(addr, buf)\n") - g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) - g.emit("// must live until after the CopyOutBytes.\n") - g.emit("runtime.KeepAlive(%s)\n", g.r) - g.emit("return err\n") - }) - g.emit("}\n\n") - - g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n") - g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName()) - g.inIndent(func() { - g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) - g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) - g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") - g.emit("ptr := unsafe.Pointer(%s)\n", g.r) - g.emit("val := uintptr(ptr)\n") - g.emit("val = val^0\n\n") - - g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) - g.emit("var buf []byte\n") - g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") - g.emit("hdr.Data = val\n") - g.emit("hdr.Len = %s.SizeBytes()\n", g.r) - g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) - - g.emit("_, err := task.CopyInBytes(addr, buf)\n") - g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) - g.emit("// must live until after the CopyInBytes.\n") - g.emit("runtime.KeepAlive(%s)\n", g.r) - g.emit("return err\n") - }) - g.emit("}\n\n") - - g.emit("// WriteTo implements io.WriterTo.WriteTo.\n") - g.emit("func (%s *%s) WriteTo(w io.Writer) (int64, error) {\n", g.r, g.typeName()) - g.inIndent(func() { - g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) - g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) - g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") - g.emit("ptr := unsafe.Pointer(%s)\n", g.r) - g.emit("val := uintptr(ptr)\n") - g.emit("val = val^0\n\n") - - g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) - g.emit("var buf []byte\n") - g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") - g.emit("hdr.Data = val\n") - g.emit("hdr.Len = %s.SizeBytes()\n", g.r) - g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) - - g.emit("len, err := w.Write(buf)\n") - g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) - g.emit("// must live until after the Write.\n") - g.emit("runtime.KeepAlive(%s)\n", g.r) - g.emit("return int64(len), err\n") - - }) - g.emit("}\n\n") - -} diff --git a/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go b/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go new file mode 100644 index 000000000..da36d9305 --- /dev/null +++ b/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go @@ -0,0 +1,183 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This file contains the bits of the code generator specific to marshalling +// newtypes on arrays. + +package gomarshal + +import ( + "fmt" + "go/ast" +) + +func (g *interfaceGenerator) validateArrayNewtype(n *ast.Ident, a *ast.ArrayType) { + if a.Len == nil { + g.abortAt(a.Pos(), fmt.Sprintf("Dynamically sized slice '%s' cannot be marshalled, arrays must be statically sized", n.Name)) + } + + if _, ok := a.Len.(*ast.BasicLit); !ok { + g.abortAt(a.Len.Pos(), fmt.Sprintf("Array size must be a literal, don't use consts or expressions")) + } + + if _, ok := a.Elt.(*ast.Ident); !ok { + g.abortAt(a.Elt.Pos(), fmt.Sprintf("Marshalling not supported for arrays with %s elements, array elements must be primitive types", kindString(a.Elt))) + } + + if arrayLen(a) <= 0 { + g.abortAt(a.Len.Pos(), fmt.Sprintf("Marshalling not supported for zero length arrays, why does an ABI struct have one?")) + } +} + +func (g *interfaceGenerator) emitMarshallableForArrayNewtype(n, elt *ast.Ident, len int) { + g.recordUsedImport("io") + g.recordUsedImport("marshal") + g.recordUsedImport("reflect") + g.recordUsedImport("runtime") + g.recordUsedImport("safecopy") + g.recordUsedImport("unsafe") + g.recordUsedImport("usermem") + + g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n") + g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName()) + g.inIndent(func() { + if size, dynamic := g.scalarSize(elt); !dynamic { + g.emit("return %d\n", size*len) + } else { + g.emit("return (*%s)(nil).SizeBytes() * %d\n", n.Name, len) + } + }) + g.emit("}\n\n") + + g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n") + g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("for idx := 0; idx < %d; idx++ {\n", len) + g.inIndent(func() { + g.marshalScalar(fmt.Sprintf("%s[idx]", g.r), elt.Name, "dst") + }) + g.emit("}\n") + }) + g.emit("}\n\n") + + g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n") + g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("for idx := 0; idx < %d; idx++ {\n", len) + g.inIndent(func() { + g.unmarshalScalar(fmt.Sprintf("%s[idx]", g.r), elt.Name, "src") + }) + g.emit("}\n") + }) + g.emit("}\n\n") + + g.emit("// Packed implements marshal.Marshallable.Packed.\n") + g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("// Array newtypes are always packed.\n") + g.emit("return true\n") + }) + g.emit("}\n\n") + + g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n") + g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r) + }) + g.emit("}\n\n") + + g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n") + g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r) + }) + g.emit("}\n\n") + + g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n") + g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName()) + g.inIndent(func() { + // Fast serialization. + g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) + g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) + g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") + g.emit("ptr := unsafe.Pointer(%s)\n", g.r) + g.emit("val := uintptr(ptr)\n") + g.emit("val = val^0\n\n") + + g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) + g.emit("var buf []byte\n") + g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") + g.emit("hdr.Data = val\n") + g.emit("hdr.Len = %s.SizeBytes()\n", g.r) + g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) + + g.emit("_, err := task.CopyOutBytes(addr, buf)\n") + g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) + g.emit("// must live until after the CopyOutBytes.\n") + g.emit("runtime.KeepAlive(%s)\n", g.r) + g.emit("return err\n") + }) + g.emit("}\n\n") + + g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n") + g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) + g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) + g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") + g.emit("ptr := unsafe.Pointer(%s)\n", g.r) + g.emit("val := uintptr(ptr)\n") + g.emit("val = val^0\n\n") + + g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) + g.emit("var buf []byte\n") + g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") + g.emit("hdr.Data = val\n") + g.emit("hdr.Len = %s.SizeBytes()\n", g.r) + g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) + + g.emit("_, err := task.CopyInBytes(addr, buf)\n") + g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) + g.emit("// must live until after the CopyInBytes.\n") + g.emit("runtime.KeepAlive(%s)\n", g.r) + g.emit("return err\n") + }) + g.emit("}\n\n") + + g.emit("// WriteTo implements io.WriterTo.WriteTo.\n") + g.emit("func (%s *%s) WriteTo(w io.Writer) (int64, error) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) + g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) + g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") + g.emit("ptr := unsafe.Pointer(%s)\n", g.r) + g.emit("val := uintptr(ptr)\n") + g.emit("val = val^0\n\n") + + g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) + g.emit("var buf []byte\n") + g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") + g.emit("hdr.Data = val\n") + g.emit("hdr.Len = %s.SizeBytes()\n", g.r) + g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) + + g.emit("len, err := w.Write(buf)\n") + g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) + g.emit("// must live until after the Write.\n") + g.emit("runtime.KeepAlive(%s)\n", g.r) + g.emit("return int64(len), err\n") + + }) + g.emit("}\n\n") +} diff --git a/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go b/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go new file mode 100644 index 000000000..159397825 --- /dev/null +++ b/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go @@ -0,0 +1,229 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This file contains the bits of the code generator specific to marshalling +// newtypes on primitives. + +package gomarshal + +import ( + "fmt" + "go/ast" +) + +// marshalPrimitiveScalar writes a single primitive variable to a byte +// slice. +func (g *interfaceGenerator) marshalPrimitiveScalar(accessor, typ, bufVar string) { + switch typ { + case "int8", "uint8", "byte": + g.emit("%s[0] = byte(*%s)\n", bufVar, accessor) + case "int16", "uint16": + g.recordUsedImport("usermem") + g.emit("usermem.ByteOrder.PutUint16(%s[:2], uint16(*%s))\n", bufVar, accessor) + case "int32", "uint32": + g.recordUsedImport("usermem") + g.emit("usermem.ByteOrder.PutUint32(%s[:4], uint32(*%s))\n", bufVar, accessor) + case "int64", "uint64": + g.recordUsedImport("usermem") + g.emit("usermem.ByteOrder.PutUint64(%s[:8], uint64(*%s))\n", bufVar, accessor) + default: + g.emit("// Explicilty cast to the underlying type before dispatching to\n") + g.emit("// MarshalBytes, so we don't recursively call %s.MarshalBytes\n", accessor) + g.emit("inner := (*%s)(%s)\n", typ, accessor) + g.emit("inner.MarshalBytes(%s[:%s.SizeBytes()])\n", bufVar, accessor) + } +} + +// unmarshalPrimitiveScalar read a single primitive variable from a byte slice. +func (g *interfaceGenerator) unmarshalPrimitiveScalar(accessor, typ, bufVar, typeCast string) { + switch typ { + case "byte": + g.emit("*%s = %s(%s[0])\n", accessor, typeCast, bufVar) + case "int8", "uint8": + g.emit("*%s = %s(%s(%s[0]))\n", accessor, typeCast, typ, bufVar) + case "int16", "uint16": + g.recordUsedImport("usermem") + g.emit("*%s = %s(%s(usermem.ByteOrder.Uint16(%s[:2])))\n", accessor, typeCast, typ, bufVar) + case "int32", "uint32": + g.recordUsedImport("usermem") + g.emit("*%s = %s(%s(usermem.ByteOrder.Uint32(%s[:4])))\n", accessor, typeCast, typ, bufVar) + case "int64", "uint64": + g.recordUsedImport("usermem") + g.emit("*%s = %s(%s(usermem.ByteOrder.Uint64(%s[:8])))\n", accessor, typeCast, typ, bufVar) + default: + g.emit("// Explicilty cast to the underlying type before dispatching to\n") + g.emit("// UnmarshalBytes, so we don't recursively call %s.UnmarshalBytes\n", accessor) + g.emit("inner := (*%s)(%s)\n", typ, accessor) + g.emit("inner.UnmarshalBytes(%s[:%s.SizeBytes()])\n", bufVar, accessor) + } +} + +func (g *interfaceGenerator) validatePrimitiveNewtype(t *ast.Ident) { + switch t.Name { + case "int8", "uint8", "byte", "int16", "uint16", "int32", "uint32", "int64", "uint64": + // These are the only primitive types we're allow. Below, we provide + // suggestions for some disallowed types and reject them, then attempt + // to marshal any remaining types by invoking the marshal.Marshallable + // interface on them. If these types don't actually implement + // marshal.Marshallable, compilation of the generated code will fail + // with an appropriate error message. + return + case "int": + g.abortAt(t.Pos(), "Type 'int' has ambiguous width, use int32 or int64") + case "uint": + g.abortAt(t.Pos(), "Type 'uint' has ambiguous width, use uint32 or uint64") + case "string": + g.abortAt(t.Pos(), "Type 'string' is dynamically-sized and cannot be marshalled, use a fixed size byte array '[...]byte' instead") + default: + debugfAt(g.f.Position(t.Pos()), fmt.Sprintf("Found derived type '%s', will attempt dispatch via marshal.Marshallable.\n", t.Name)) + } +} + +// emitMarshallableForPrimitiveNewtype outputs code to implement the +// marshal.Marshallable interface for a newtype on a primitive. Primitive +// newtypes are always packed, so we can omit the various fallbacks required for +// non-packed structs. +func (g *interfaceGenerator) emitMarshallableForPrimitiveNewtype(nt *ast.Ident) { + g.recordUsedImport("io") + g.recordUsedImport("marshal") + g.recordUsedImport("reflect") + g.recordUsedImport("runtime") + g.recordUsedImport("safecopy") + g.recordUsedImport("unsafe") + g.recordUsedImport("usermem") + + g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n") + g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName()) + g.inIndent(func() { + if size, dynamic := g.scalarSize(nt); !dynamic { + g.emit("return %d\n", size) + } else { + g.emit("return (*%s)(nil).SizeBytes()\n", nt.Name) + } + }) + g.emit("}\n\n") + + g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n") + g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.marshalPrimitiveScalar(g.r, nt.Name, "dst") + }) + g.emit("}\n\n") + + g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n") + g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.unmarshalPrimitiveScalar(g.r, nt.Name, "src", g.typeName()) + }) + g.emit("}\n\n") + + g.emit("// Packed implements marshal.Marshallable.Packed.\n") + g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("// Scalar newtypes are always packed.\n") + g.emit("return true\n") + }) + g.emit("}\n\n") + + g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n") + g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r) + }) + g.emit("}\n\n") + + g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n") + g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r) + }) + g.emit("}\n\n") + + g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n") + g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName()) + g.inIndent(func() { + // Fast serialization. + g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) + g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) + g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") + g.emit("ptr := unsafe.Pointer(%s)\n", g.r) + g.emit("val := uintptr(ptr)\n") + g.emit("val = val^0\n\n") + + g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) + g.emit("var buf []byte\n") + g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") + g.emit("hdr.Data = val\n") + g.emit("hdr.Len = %s.SizeBytes()\n", g.r) + g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) + + g.emit("_, err := task.CopyOutBytes(addr, buf)\n") + g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) + g.emit("// must live until after the CopyOutBytes.\n") + g.emit("runtime.KeepAlive(%s)\n", g.r) + g.emit("return err\n") + }) + g.emit("}\n\n") + + g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n") + g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) + g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) + g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") + g.emit("ptr := unsafe.Pointer(%s)\n", g.r) + g.emit("val := uintptr(ptr)\n") + g.emit("val = val^0\n\n") + + g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) + g.emit("var buf []byte\n") + g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") + g.emit("hdr.Data = val\n") + g.emit("hdr.Len = %s.SizeBytes()\n", g.r) + g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) + + g.emit("_, err := task.CopyInBytes(addr, buf)\n") + g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) + g.emit("// must live until after the CopyInBytes.\n") + g.emit("runtime.KeepAlive(%s)\n", g.r) + g.emit("return err\n") + }) + g.emit("}\n\n") + + g.emit("// WriteTo implements io.WriterTo.WriteTo.\n") + g.emit("func (%s *%s) WriteTo(w io.Writer) (int64, error) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) + g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) + g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") + g.emit("ptr := unsafe.Pointer(%s)\n", g.r) + g.emit("val := uintptr(ptr)\n") + g.emit("val = val^0\n\n") + + g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) + g.emit("var buf []byte\n") + g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") + g.emit("hdr.Data = val\n") + g.emit("hdr.Len = %s.SizeBytes()\n", g.r) + g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) + + g.emit("len, err := w.Write(buf)\n") + g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) + g.emit("// must live until after the Write.\n") + g.emit("runtime.KeepAlive(%s)\n", g.r) + g.emit("return int64(len), err\n") + + }) + g.emit("}\n\n") +} diff --git a/tools/go_marshal/gomarshal/generator_interfaces_struct.go b/tools/go_marshal/gomarshal/generator_interfaces_struct.go new file mode 100644 index 000000000..e66a38b2e --- /dev/null +++ b/tools/go_marshal/gomarshal/generator_interfaces_struct.go @@ -0,0 +1,450 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This file contains the bits of the code generator specific to marshalling +// structs. + +package gomarshal + +import ( + "fmt" + "go/ast" + "strings" +) + +func (g *interfaceGenerator) fieldAccessor(n *ast.Ident) string { + return fmt.Sprintf("%s.%s", g.r, n.Name) +} + +// areFieldsPackedExpression returns a go expression checking whether g.t's fields are +// packed. Returns "", false if g.t has no fields that may be potentially +// packed, otherwise returns <clause>, true, where <clause> is an expression +// like "t.a.Packed() && t.b.Packed() && t.c.Packed()". +func (g *interfaceGenerator) areFieldsPackedExpression() (string, bool) { + if len(g.as) == 0 { + return "", false + } + + cs := make([]string, 0, len(g.as)) + for accessor, _ := range g.as { + cs = append(cs, fmt.Sprintf("%s.Packed()", accessor)) + } + return strings.Join(cs, " && "), true +} + +// validateStruct ensures the type we're working with can be marshalled. These +// checks are done ahead of time and in one place so we can make assumptions +// later. +func (g *interfaceGenerator) validateStruct(ts *ast.TypeSpec, st *ast.StructType) { + forEachStructField(st, func(f *ast.Field) { + if len(f.Names) == 0 { + g.abortAt(f.Pos(), "Cannot marshal structs with embedded fields, give the field a name; use '_' for anonymous fields such as padding fields") + } + }) + + forEachStructField(st, func(f *ast.Field) { + fieldDispatcher{ + primitive: func(_, t *ast.Ident) { + g.validatePrimitiveNewtype(t) + }, + selector: func(_, _, _ *ast.Ident) { + // No validation to perform on selector fields. However this + // callback must still be provided. + }, + array: func(n, _ *ast.Ident, len int) { + g.validateArrayNewtype(n, f.Type.(*ast.ArrayType)) + }, + unhandled: func(_ *ast.Ident) { + g.abortAt(f.Pos(), fmt.Sprintf("Marshalling not supported for %s fields", kindString(f.Type))) + }, + }.dispatch(f) + }) +} + +func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) { + // Is g.t a packed struct without consideing field types? + thisPacked := true + forEachStructField(st, func(f *ast.Field) { + if f.Tag != nil { + if f.Tag.Value == "`marshal:\"unaligned\"`" { + if thisPacked { + debugfAt(g.f.Position(g.t.Pos()), + fmt.Sprintf("Marking type '%s' as not packed due to tag `marshal:\"unaligned\"`.\n", g.t.Name)) + thisPacked = false + } + } + } + }) + + g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n") + g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName()) + g.inIndent(func() { + primitiveSize := 0 + var dynamicSizeTerms []string + + forEachStructField(st, fieldDispatcher{ + primitive: func(n, t *ast.Ident) { + if size, dynamic := g.scalarSize(t); !dynamic { + primitiveSize += size + } else { + g.recordUsedMarshallable(t.Name) + dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", t.Name)) + } + }, + selector: func(n, tX, tSel *ast.Ident) { + tName := fmt.Sprintf("%s.%s", tX.Name, tSel.Name) + g.recordUsedImport(tX.Name) + g.recordUsedMarshallable(tName) + dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", tName)) + }, + array: func(n, t *ast.Ident, len int) { + if len < 1 { + // Zero-length arrays should've been rejected by validate(). + panic("unreachable") + } + if size, dynamic := g.scalarSize(t); !dynamic { + primitiveSize += size * len + } else { + g.recordUsedMarshallable(t.Name) + dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()*%d", t.Name, len)) + } + }, + }.dispatch) + g.emit("return %d", primitiveSize) + if len(dynamicSizeTerms) > 0 { + g.incIndent() + } + { + for _, d := range dynamicSizeTerms { + g.emitNoIndent(" +\n") + g.emit(d) + } + } + if len(dynamicSizeTerms) > 0 { + g.decIndent() + } + }) + g.emit("\n}\n\n") + + g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n") + g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + forEachStructField(st, fieldDispatcher{ + primitive: func(n, t *ast.Ident) { + if n.Name == "_" { + g.emit("// Padding: dst[:sizeof(%s)] ~= %s(0)\n", t.Name, t.Name) + if len, dynamic := g.scalarSize(t); !dynamic { + g.shift("dst", len) + } else { + // We can't use shiftDynamic here because we don't have + // an instance of the dynamic type we can referece here + // (since the version in this struct is anonymous). Use + // a typed nil pointer to call SizeBytes() instead. + g.emit("dst = dst[(*%s)(nil).SizeBytes():]\n", t.Name) + } + return + } + g.marshalScalar(g.fieldAccessor(n), t.Name, "dst") + }, + selector: func(n, tX, tSel *ast.Ident) { + g.marshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "dst") + }, + array: func(n, t *ast.Ident, size int) { + if n.Name == "_" { + g.emit("// Padding: dst[:sizeof(%s)*%d] ~= [%d]%s{0}\n", t.Name, size, size, t.Name) + if len, dynamic := g.scalarSize(t); !dynamic { + g.shift("dst", len*size) + } else { + // We can't use shiftDynamic here because we don't have + // an instance of the dynamic type we can reference here + // (since the version in this struct is anonymous). Use + // a typed nil pointer to call SizeBytes() instead. + g.emit("dst = dst[(*%s)(nil).SizeBytes()*%d:]\n", t.Name, size) + } + return + } + + g.emit("for idx := 0; idx < %d; idx++ {\n", size) + g.inIndent(func() { + g.marshalScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "dst") + }) + g.emit("}\n") + }, + }.dispatch) + }) + g.emit("}\n\n") + + g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n") + g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + forEachStructField(st, fieldDispatcher{ + primitive: func(n, t *ast.Ident) { + if n.Name == "_" { + g.emit("// Padding: var _ %s ~= src[:sizeof(%s)]\n", t.Name, t.Name) + if len, dynamic := g.scalarSize(t); !dynamic { + g.shift("src", len) + } else { + // We can't use shiftDynamic here because we don't have + // an instance of the dynamic type we can reference here + // (since the version in this struct is anonymous). Use + // a typed nil pointer to call SizeBytes() instead. + g.emit("src = src[(*%s)(nil).SizeBytes():]\n", t.Name) + g.recordPotentiallyNonPackedField(fmt.Sprintf("(*%s)(nil)", t.Name)) + } + return + } + g.unmarshalScalar(g.fieldAccessor(n), t.Name, "src") + }, + selector: func(n, tX, tSel *ast.Ident) { + g.unmarshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "src") + }, + array: func(n, t *ast.Ident, size int) { + if n.Name == "_" { + g.emit("// Padding: ~ copy([%d]%s(%s), src[:sizeof(%s)*%d])\n", size, t.Name, g.fieldAccessor(n), t.Name, size) + if len, dynamic := g.scalarSize(t); !dynamic { + g.shift("src", len*size) + } else { + // We can't use shiftDynamic here because we don't have + // an instance of the dynamic type we can referece here + // (since the version in this struct is anonymous). Use + // a typed nil pointer to call SizeBytes() instead. + g.emit("src = src[(*%s)(nil).SizeBytes()*%d:]\n", t.Name, size) + } + return + } + + g.emit("for idx := 0; idx < %d; idx++ {\n", size) + g.inIndent(func() { + g.unmarshalScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "src") + }) + g.emit("}\n") + }, + }.dispatch) + }) + g.emit("}\n\n") + + g.emit("// Packed implements marshal.Marshallable.Packed.\n") + g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName()) + g.inIndent(func() { + expr, fieldsMaybePacked := g.areFieldsPackedExpression() + switch { + case !thisPacked: + g.emit("return false\n") + case fieldsMaybePacked: + g.emit("return %s\n", expr) + default: + g.emit("return true\n") + + } + }) + g.emit("}\n\n") + + g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n") + g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + if thisPacked { + g.recordUsedImport("safecopy") + g.recordUsedImport("unsafe") + if cond, ok := g.areFieldsPackedExpression(); ok { + g.emit("if %s {\n", cond) + g.inIndent(func() { + g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r) + }) + g.emit("} else {\n") + g.inIndent(func() { + g.emit("%s.MarshalBytes(dst)\n", g.r) + }) + g.emit("}\n") + } else { + g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r) + } + } else { + g.emit("// Type %s doesn't have a packed layout in memory, fallback to MarshalBytes.\n", g.typeName()) + g.emit("%s.MarshalBytes(dst)\n", g.r) + } + }) + g.emit("}\n\n") + + g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n") + g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + if thisPacked { + g.recordUsedImport("safecopy") + g.recordUsedImport("unsafe") + if cond, ok := g.areFieldsPackedExpression(); ok { + g.emit("if %s {\n", cond) + g.inIndent(func() { + g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r) + }) + g.emit("} else {\n") + g.inIndent(func() { + g.emit("%s.UnmarshalBytes(src)\n", g.r) + }) + g.emit("}\n") + } else { + g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r) + } + } else { + g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName()) + g.emit("%s.UnmarshalBytes(src)\n", g.r) + } + }) + g.emit("}\n\n") + + g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n") + g.recordUsedImport("marshal") + g.recordUsedImport("usermem") + g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName()) + g.inIndent(func() { + fallback := func() { + g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName()) + g.emit("buf := task.CopyScratchBuffer(%s.SizeBytes())\n", g.r) + g.emit("%s.MarshalBytes(buf)\n", g.r) + g.emit("_, err := task.CopyOutBytes(addr, buf)\n") + g.emit("return err\n") + } + if thisPacked { + g.recordUsedImport("reflect") + g.recordUsedImport("runtime") + g.recordUsedImport("unsafe") + if cond, ok := g.areFieldsPackedExpression(); ok { + g.emit("if !%s {\n", cond) + g.inIndent(fallback) + g.emit("}\n\n") + } + // Fast serialization. + g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) + g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) + g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") + g.emit("ptr := unsafe.Pointer(%s)\n", g.r) + g.emit("val := uintptr(ptr)\n") + g.emit("val = val^0\n\n") + + g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) + g.emit("var buf []byte\n") + g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") + g.emit("hdr.Data = val\n") + g.emit("hdr.Len = %s.SizeBytes()\n", g.r) + g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) + + g.emit("_, err := task.CopyOutBytes(addr, buf)\n") + g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) + g.emit("// must live until after the CopyOutBytes.\n") + g.emit("runtime.KeepAlive(%s)\n", g.r) + g.emit("return err\n") + } else { + fallback() + } + }) + g.emit("}\n\n") + + g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n") + g.recordUsedImport("marshal") + g.recordUsedImport("usermem") + g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName()) + g.inIndent(func() { + fallback := func() { + g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName()) + g.emit("buf := task.CopyScratchBuffer(%s.SizeBytes())\n", g.r) + g.emit("_, err := task.CopyInBytes(addr, buf)\n") + g.emit("if err != nil {\n") + g.inIndent(func() { + g.emit("return err\n") + }) + g.emit("}\n") + + g.emit("%s.UnmarshalBytes(buf)\n", g.r) + g.emit("return nil\n") + } + if thisPacked { + g.recordUsedImport("reflect") + g.recordUsedImport("runtime") + g.recordUsedImport("unsafe") + if cond, ok := g.areFieldsPackedExpression(); ok { + g.emit("if !%s {\n", cond) + g.inIndent(fallback) + g.emit("}\n\n") + } + // Fast deserialization. + g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) + g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) + g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") + g.emit("ptr := unsafe.Pointer(%s)\n", g.r) + g.emit("val := uintptr(ptr)\n") + g.emit("val = val^0\n\n") + + g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) + g.emit("var buf []byte\n") + g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") + g.emit("hdr.Data = val\n") + g.emit("hdr.Len = %s.SizeBytes()\n", g.r) + g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) + + g.emit("_, err := task.CopyInBytes(addr, buf)\n") + g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) + g.emit("// must live until after the CopyInBytes.\n") + g.emit("runtime.KeepAlive(%s)\n", g.r) + g.emit("return err\n") + } else { + fallback() + } + }) + g.emit("}\n\n") + + g.emit("// WriteTo implements io.WriterTo.WriteTo.\n") + g.recordUsedImport("io") + g.emit("func (%s *%s) WriteTo(w io.Writer) (int64, error) {\n", g.r, g.typeName()) + g.inIndent(func() { + fallback := func() { + g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName()) + g.emit("buf := make([]byte, %s.SizeBytes())\n", g.r) + g.emit("%s.MarshalBytes(buf)\n", g.r) + g.emit("n, err := w.Write(buf)\n") + g.emit("return int64(n), err\n") + } + if thisPacked { + g.recordUsedImport("reflect") + g.recordUsedImport("runtime") + g.recordUsedImport("unsafe") + if cond, ok := g.areFieldsPackedExpression(); ok { + g.emit("if !%s {\n", cond) + g.inIndent(fallback) + g.emit("}\n\n") + } + // Fast serialization. + g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) + g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) + g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") + g.emit("ptr := unsafe.Pointer(%s)\n", g.r) + g.emit("val := uintptr(ptr)\n") + g.emit("val = val^0\n\n") + + g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) + g.emit("var buf []byte\n") + g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") + g.emit("hdr.Data = val\n") + g.emit("hdr.Len = %s.SizeBytes()\n", g.r) + g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) + + g.emit("len, err := w.Write(buf)\n") + g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) + g.emit("// must live until after the Write.\n") + g.emit("runtime.KeepAlive(%s)\n", g.r) + g.emit("return int64(len), err\n") + } else { + fallback() + } + }) + g.emit("}\n\n") +} diff --git a/tools/go_marshal/gomarshal/generator_tests.go b/tools/go_marshal/gomarshal/generator_tests.go index 8ba47eb67..fd992e44a 100644 --- a/tools/go_marshal/gomarshal/generator_tests.go +++ b/tools/go_marshal/gomarshal/generator_tests.go @@ -164,7 +164,7 @@ func (g *testGenerator) emitTestSizeBytesOnTypedNilPtr() { g.emit("if sizeFromTypedNilPtr != sizeFromConcrete {\n") g.inIndent(func() { - g.emit("t.Fatalf(\"SizeBytes() on typed nil pointer (%v) doesn't match size returned by a concrete object (%v).\\n\", sizeFromTypedNilPtr, sizeFromConcrete)") + g.emit("t.Fatalf(\"SizeBytes() on typed nil pointer (%v) doesn't match size returned by a concrete object (%v).\\n\", sizeFromTypedNilPtr, sizeFromConcrete)\n") }) g.emit("}\n") }) diff --git a/tools/go_marshal/gomarshal/util.go b/tools/go_marshal/gomarshal/util.go index e2bca4e7c..a0936e013 100644 --- a/tools/go_marshal/gomarshal/util.go +++ b/tools/go_marshal/gomarshal/util.go @@ -64,6 +64,12 @@ func kindString(e ast.Expr) string { } } +func forEachStructField(st *ast.StructType, fn func(f *ast.Field)) { + for _, field := range st.Fields.List { + fn(field) + } +} + // fieldDispatcher is a collection of callbacks for handling different types of // fields in a struct declaration. type fieldDispatcher struct { @@ -73,6 +79,25 @@ type fieldDispatcher struct { unhandled func(n *ast.Ident) } +// Precondition: a must have a literal for the array length. Consts and +// expressions are not allowed as array lengths, and should be rejected by the +// caller. +func arrayLen(a *ast.ArrayType) int { + if a.Len == nil { + // Probably a slice? Must be handled by caller. + panic("Nil array length in array type") + } + lenLit, ok := a.Len.(*ast.BasicLit) + if !ok { + panic("Array has non-literal for length") + } + len, err := strconv.Atoi(lenLit.Value) + if err != nil { + panic(fmt.Sprintf("Failed to parse array length '%s' as number: %v", lenLit.Value, err)) + } + return len +} + // Precondition: All dispatch callbacks that will be invoked must be // provided. Embedded fields are not allowed, len(f.Names) >= 1. func (fd fieldDispatcher) dispatch(f *ast.Field) { @@ -96,22 +121,12 @@ func (fd fieldDispatcher) dispatch(f *ast.Field) { case *ast.SelectorExpr: fd.selector(name, v.X.(*ast.Ident), v.Sel) case *ast.ArrayType: - len := 0 - if v.Len != nil { - // Non-literal array length is handled by generatorInterfaces.validate(). - if lenLit, ok := v.Len.(*ast.BasicLit); ok { - var err error - len, err = strconv.Atoi(lenLit.Value) - if err != nil { - panic(err) - } - } - } switch t := v.Elt.(type) { case *ast.Ident: - fd.array(name, t, len) + fd.array(name, t, arrayLen(v)) default: - fd.array(name, nil, len) + // Should be handled with a better error message during validate. + panic(fmt.Sprintf("Array element type is of unsupported kind. Expected *ast.Ident, got %v", t)) } default: fd.unhandled(name) diff --git a/tools/go_marshal/test/test.go b/tools/go_marshal/test/test.go index 93229dedb..c829db6da 100644 --- a/tools/go_marshal/test/test.go +++ b/tools/go_marshal/test/test.go @@ -104,6 +104,11 @@ type Stat struct { _ [3]int64 } +// InetAddr is an example marshallable newtype on an array. +// +// +marshal +type InetAddr [4]byte + // SignalSet is an example marshallable newtype on a primitive. // // +marshal diff --git a/tools/go_mod.sh b/tools/go_mod.sh new file mode 100755 index 000000000..84b779d6d --- /dev/null +++ b/tools/go_mod.sh @@ -0,0 +1,29 @@ +#!/bin/bash + +# Copyright 2020 The gVisor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set -eo pipefail + +# Build the :gopath target. +bazel build //:gopath +declare -r gopathdir="bazel-bin/gopath/src/gvisor.dev/gvisor/" + +# Copy go.mod and execute the command. +cp -a go.mod go.sum "${gopathdir}" +(cd "${gopathdir}" && go mod "$@") +cp -a "${gopathdir}/go.mod" "${gopathdir}/go.sum" . + +# Cleanup the WORKSPACE file. +bazel run //:gazelle -- update-repos -from_file=go.mod diff --git a/tools/images/ubuntu1604/10_core.sh b/tools/images/ubuntu1604/10_core.sh index 46dda6bb1..cd518d6ac 100755 --- a/tools/images/ubuntu1604/10_core.sh +++ b/tools/images/ubuntu1604/10_core.sh @@ -17,7 +17,20 @@ set -xeo pipefail # Install all essential build tools. -apt-get update && apt-get -y install make git-core build-essential linux-headers-$(uname -r) pkg-config +while true; do + if (apt-get update && apt-get install -y \ + make \ + git-core \ + build-essential \ + linux-headers-$(uname -r) \ + pkg-config); then + break + fi + result=$? + if [[ $result -ne 100 ]]; then + exit $result + fi +done # Install a recent go toolchain. if ! [[ -d /usr/local/go ]]; then diff --git a/tools/images/ubuntu1604/20_bazel.sh b/tools/images/ubuntu1604/20_bazel.sh index b33e1656c..bb7afa676 100755 --- a/tools/images/ubuntu1604/20_bazel.sh +++ b/tools/images/ubuntu1604/20_bazel.sh @@ -19,7 +19,17 @@ set -xeo pipefail declare -r BAZEL_VERSION=2.0.0 # Install bazel dependencies. -apt-get update && apt-get install -y openjdk-8-jdk-headless unzip +while true; do + if (apt-get update && apt-get install -y \ + openjdk-8-jdk-headless \ + unzip); then + break + fi + result=$? + if [[ $result -ne 100 ]]; then + exit $result + fi +done # Use the release installer. curl -L -o bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh diff --git a/tools/images/ubuntu1604/25_docker.sh b/tools/images/ubuntu1604/25_docker.sh index 1d3defcd3..11eea2d72 100755 --- a/tools/images/ubuntu1604/25_docker.sh +++ b/tools/images/ubuntu1604/25_docker.sh @@ -15,12 +15,20 @@ # limitations under the License. # Add dependencies. -apt-get update && apt-get -y install \ - apt-transport-https \ - ca-certificates \ - curl \ - gnupg-agent \ - software-properties-common +while true; do + if (apt-get update && apt-get install -y \ + apt-transport-https \ + ca-certificates \ + curl \ + gnupg-agent \ + software-properties-common); then + break + fi + result=$? + if [[ $result -ne 100 ]]; then + exit $result + fi +done # Install the key. curl -fsSL https://download.docker.com/linux/ubuntu/gpg | apt-key add - @@ -32,4 +40,15 @@ add-apt-repository \ stable" # Install docker. -apt-get update && apt-get install -y docker-ce docker-ce-cli containerd.io +while true; do + if (apt-get update && apt-get install -y \ + docker-ce \ + docker-ce-cli \ + containerd.io); then + break + fi + result=$? + if [[ $result -ne 100 ]]; then + exit $result + fi +done diff --git a/tools/images/ubuntu1604/30_containerd.sh b/tools/images/ubuntu1604/30_containerd.sh index a7472bd1c..fb3699c12 100755 --- a/tools/images/ubuntu1604/30_containerd.sh +++ b/tools/images/ubuntu1604/30_containerd.sh @@ -34,7 +34,17 @@ install_helper() { } # Install dependencies for the crictl tests. -apt-get install -y btrfs-tools libseccomp-dev +while true; do + if (apt-get update && apt-get install -y \ + btrfs-tools \ + libseccomp-dev); then + break + fi + result=$? + if [[ $result -ne 100 ]]; then + exit $result + fi +done # Install containerd & cri-tools. GOPATH=$(mktemp -d --tmpdir gopathXXXXX) diff --git a/tools/images/ubuntu1604/40_kokoro.sh b/tools/images/ubuntu1604/40_kokoro.sh index 5f2dfc858..06a1e6c48 100755 --- a/tools/images/ubuntu1604/40_kokoro.sh +++ b/tools/images/ubuntu1604/40_kokoro.sh @@ -23,7 +23,22 @@ declare -r ssh_public_keys=( ) # Install dependencies. -apt-get update && apt-get install -y rsync coreutils python-psutil qemu-kvm python-pip python3-pip zip +while true; do + if (apt-get update && apt-get install -y \ + rsync \ + coreutils \ + python-psutil \ + qemu-kvm \ + python-pip \ + python3-pip \ + zip); then + break + fi + result=$? + if [[ $result -ne 100 ]]; then + exit $result + fi +done # junitparser is used to merge junit xml files. pip install junitparser diff --git a/tools/installers/master.sh b/tools/installers/master.sh index 52f9734a6..2c6001c6c 100755 --- a/tools/installers/master.sh +++ b/tools/installers/master.sh @@ -19,17 +19,16 @@ set -e curl -fsSL https://gvisor.dev/archive.key | sudo apt-key add - add-apt-repository "deb https://storage.googleapis.com/gvisor/releases release main" + while true; do - if apt-get update; then - apt-get install -y runsc + if (apt-get update && apt-get install -y runsc); then break fi result=$? - # Check if apt update failed to aquire the file lock. if [[ $result -ne 100 ]]; then exit $result fi done + runsc install service docker restart - diff --git a/tools/tag_release.sh b/tools/tag_release.sh index f33b902d6..4dbfe420a 100755 --- a/tools/tag_release.sh +++ b/tools/tag_release.sh @@ -21,13 +21,19 @@ set -xeu # Check arguments. -if [ "$#" -ne 2 ]; then - echo "usage: $0 <commit|revid> <release.rc>" +if [ "$#" -ne 3 ]; then + echo "usage: $0 <commit|revid> <release.rc> <message-file>" exit 1 fi declare -r target_commit="$1" declare -r release="$2" +declare -r message_file="$3" + +if ! [[ -r "${message_file}" ]]; then + echo "error: message file '${message_file}' is not readable." + exit 1 +fi closest_commit() { while read line; do @@ -64,6 +70,6 @@ fi # Tag the given commit (annotated, to record the committer). declare -r tag="release-${release}" -(git tag -m "Release ${release}" -a "${tag}" "${commit}" && \ +(git tag -F "${message_file}" -a "${tag}" "${commit}" && \ git push origin tag "${tag}") || \ (git tag -d "${tag}" && false) diff --git a/vdso/vdso.cc b/vdso/vdso.cc index 8bb80a7a4..c2585d592 100644 --- a/vdso/vdso.cc +++ b/vdso/vdso.cc @@ -126,6 +126,10 @@ extern "C" int __kernel_clock_getres(clockid_t clock, struct timespec* res) { case CLOCK_REALTIME: case CLOCK_MONOTONIC: case CLOCK_BOOTTIME: { + if (res == nullptr) { + return 0; + } + res->tv_sec = 0; res->tv_nsec = 1; break; |