From 14959250feb71df74dea13f3cb15dcbe8ce6b3f3 Mon Sep 17 00:00:00 2001
From: Adin Scannell
Date: Thu, 30 Jan 2020 17:37:17 -0800
Subject: Simplify testing link rules.

PiperOrigin-RevId: 292458933
---
 tools/build/defs.bzl | 1 +
 tools/defs.bzl       | 3 ++-
 tools/images/BUILD   | 4 ++--
 3 files changed, 5 insertions(+), 3 deletions(-)

(limited to 'tools')

diff --git a/tools/build/defs.bzl b/tools/build/defs.bzl
index d0556abd1..967c1f900 100644
--- a/tools/build/defs.bzl
+++ b/tools/build/defs.bzl
@@ -18,6 +18,7 @@ cc_test = _cc_test
 cc_toolchain = "@bazel_tools//tools/cpp:current_cc_toolchain"
 go_image = _go_image
 go_embed_data = _go_embed_data
+gtest = "@com_google_googletest//:gtest"
 loopback = "//tools/build:loopback"
 proto_library = native.proto_library
 pkg_deb = _pkg_deb
diff --git a/tools/defs.bzl b/tools/defs.bzl
index 819f12b0d..ce677cbbf 100644
--- a/tools/defs.bzl
+++ b/tools/defs.bzl
@@ -7,7 +7,7 @@ change for Google-internal and bazel-compatible rules.
 load("//tools/go_stateify:defs.bzl", "go_stateify")
 load("//tools/go_marshal:defs.bzl", "go_marshal", "marshal_deps", "marshal_test_deps")
-load("//tools/build:defs.bzl", _cc_binary = "cc_binary", _cc_flags_supplier = "cc_flags_supplier", _cc_library = "cc_library", _cc_proto_library = "cc_proto_library", _cc_test = "cc_test", _cc_toolchain = "cc_toolchain", _container_image = "container_image", _default_installer = "default_installer", _default_net_util = "default_net_util", _go_binary = "go_binary", _go_embed_data = "go_embed_data", _go_image = "go_image", _go_library = "go_library", _go_proto_library = "go_proto_library", _go_test = "go_test", _go_tool_library = "go_tool_library", _loopback = "loopback", _pkg_deb = "pkg_deb", _pkg_tar = "pkg_tar", _proto_library = "proto_library", _py_binary = "py_binary", _py_library = "py_library", _py_requirement = "py_requirement", _py_test = "py_test", _select_arch = "select_arch", _select_system = "select_system")
+load("//tools/build:defs.bzl", _cc_binary = "cc_binary", _cc_flags_supplier = "cc_flags_supplier", _cc_library = "cc_library", _cc_proto_library = "cc_proto_library", _cc_test = "cc_test", _cc_toolchain = "cc_toolchain", _container_image = "container_image", _default_installer = "default_installer", _default_net_util = "default_net_util", _go_binary = "go_binary", _go_embed_data = "go_embed_data", _go_image = "go_image", _go_library = "go_library", _go_proto_library = "go_proto_library", _go_test = "go_test", _go_tool_library = "go_tool_library", _gtest = "gtest", _loopback = "loopback", _pkg_deb = "pkg_deb", _pkg_tar = "pkg_tar", _proto_library = "proto_library", _py_binary = "py_binary", _py_library = "py_library", _py_requirement = "py_requirement", _py_test = "py_test", _select_arch = "select_arch", _select_system = "select_system")

 # Delegate directly.
 cc_binary = _cc_binary
@@ -20,6 +20,7 @@ go_embed_data = _go_embed_data
 go_image = _go_image
 go_test = _go_test
 go_tool_library = _go_tool_library
+gtest = _gtest
 pkg_deb = _pkg_deb
 pkg_tar = _pkg_tar
 py_library = _py_library
diff --git a/tools/images/BUILD b/tools/images/BUILD
index f1699b184..fe11f08a3 100644
--- a/tools/images/BUILD
+++ b/tools/images/BUILD
@@ -1,4 +1,4 @@
-load("//tools:defs.bzl", "cc_binary")
+load("//tools:defs.bzl", "cc_binary", "gtest")
 load("//tools/images:defs.bzl", "vm_image", "vm_test")

 package(
@@ -32,8 +32,8 @@ cc_binary(
     srcs = ["test.cc"],
     linkstatic = 1,
     deps = [
+        gtest,
         "//test/util:test_main",
-        "@com_google_googletest//:gtest",
     ],
 )
--
cgit v1.2.3


From 95ce8bb4c7ecb23e47e68c60b1de0b99ad8a856d Mon Sep 17 00:00:00 2001
From: Adin Scannell
Date: Tue, 4 Feb 2020 14:36:43 -0800
Subject: Automatically propagate tags for stateify and marshal.

Note that files will need to be appropriately segmented, with suffixes
implying special tags, in order for the mechanism to work. This only needs
to happen for cases where marshal or state structures are defined, which
should be rare and mostly architecture specific.

PiperOrigin-RevId: 293231579
---
 tools/build/defs.bzl                    |   2 +
 tools/build/tags.bzl                    |  36 ++++++++++
 tools/defs.bzl                          | 105 ++++++++++++++++---------
 tools/go_marshal/gomarshal/BUILD        |   1 +
 tools/go_marshal/gomarshal/generator.go |  11 +++
 tools/go_stateify/BUILD                 |   1 +
 tools/go_stateify/defs.bzl              |  10 +--
 tools/go_stateify/main.go               | 122 +++-----------------------------
 tools/tags/BUILD                        |  11 +++
 tools/tags/tags.go                      |  89 +++++++++++++++++++++++
 10 files changed, 235 insertions(+), 153 deletions(-)
 create mode 100644 tools/build/tags.bzl
 create mode 100644 tools/tags/BUILD
 create mode 100644 tools/tags/tags.go

(limited to 'tools')

diff --git a/tools/build/defs.bzl b/tools/build/defs.bzl
index 967c1f900..1a1a0d825 100644
--- a/tools/build/defs.bzl
+++ b/tools/build/defs.bzl
@@ -8,6 +8,7 @@ load("@rules_pkg//:pkg.bzl", _pkg_deb = "pkg_deb", _pkg_tar = "pkg_tar")
 load("@io_bazel_rules_docker//go:image.bzl", _go_image = "go_image")
 load("@io_bazel_rules_docker//container:container.bzl", _container_image = "container_image")
 load("@pydeps//:requirements.bzl", _py_requirement = "requirement")
+load("//tools/build:tags.bzl", _go_suffixes = "go_suffixes")

 container_image = _container_image
 cc_binary = _cc_binary
@@ -18,6 +19,7 @@ cc_test = _cc_test
 cc_toolchain = "@bazel_tools//tools/cpp:current_cc_toolchain"
 go_image = _go_image
 go_embed_data = _go_embed_data
+go_suffixes = _go_suffixes
 gtest = "@com_google_googletest//:gtest"
 loopback = "//tools/build:loopback"
 proto_library = native.proto_library
diff --git a/tools/build/tags.bzl b/tools/build/tags.bzl
new file mode 100644
index 000000000..e99c87f81
--- /dev/null
+++ b/tools/build/tags.bzl
@@ -0,0 +1,36 @@
+"""List of special Go suffixes."""
+
+go_suffixes = [
+    "_386",
+    "_386_unsafe",
+    "_amd64",
+    "_amd64_unsafe",
+    "_aarch64",
+    "_aarch64_unsafe",
+    "_arm",
+    "_arm_unsafe",
+    "_arm64",
+    "_arm64_unsafe",
+    "_mips",
+    "_mips_unsafe",
+    "_mipsle",
+    "_mipsle_unsafe",
+    "_mips64",
+    "_mips64_unsafe",
+    "_mips64le",
+    "_mips64le_unsafe",
+    "_ppc64",
+    "_ppc64_unsafe",
+    "_ppc64le",
+    "_ppc64le_unsafe",
+    "_riscv64",
+    "_riscv64_unsafe",
+    "_s390x",
+    "_s390x_unsafe",
+    "_sparc64",
+    "_sparc64_unsafe",
+    "_wasm",
+    "_wasm_unsafe",
+    "_linux",
+    "_linux_unsafe",
+]
diff --git a/tools/defs.bzl b/tools/defs.bzl
index ce677cbbf..5d5fa134a 100644
--- a/tools/defs.bzl
+++ b/tools/defs.bzl
@@ -7,7 +7,7 @@ change for Google-internal
and bazel-compatible rules. load("//tools/go_stateify:defs.bzl", "go_stateify") load("//tools/go_marshal:defs.bzl", "go_marshal", "marshal_deps", "marshal_test_deps") -load("//tools/build:defs.bzl", _cc_binary = "cc_binary", _cc_flags_supplier = "cc_flags_supplier", _cc_library = "cc_library", _cc_proto_library = "cc_proto_library", _cc_test = "cc_test", _cc_toolchain = "cc_toolchain", _container_image = "container_image", _default_installer = "default_installer", _default_net_util = "default_net_util", _go_binary = "go_binary", _go_embed_data = "go_embed_data", _go_image = "go_image", _go_library = "go_library", _go_proto_library = "go_proto_library", _go_test = "go_test", _go_tool_library = "go_tool_library", _gtest = "gtest", _loopback = "loopback", _pkg_deb = "pkg_deb", _pkg_tar = "pkg_tar", _proto_library = "proto_library", _py_binary = "py_binary", _py_library = "py_library", _py_requirement = "py_requirement", _py_test = "py_test", _select_arch = "select_arch", _select_system = "select_system") +load("//tools/build:defs.bzl", "go_suffixes", _cc_binary = "cc_binary", _cc_flags_supplier = "cc_flags_supplier", _cc_library = "cc_library", _cc_proto_library = "cc_proto_library", _cc_test = "cc_test", _cc_toolchain = "cc_toolchain", _container_image = "container_image", _default_installer = "default_installer", _default_net_util = "default_net_util", _go_binary = "go_binary", _go_embed_data = "go_embed_data", _go_image = "go_image", _go_library = "go_library", _go_proto_library = "go_proto_library", _go_test = "go_test", _go_tool_library = "go_tool_library", _gtest = "gtest", _loopback = "loopback", _pkg_deb = "pkg_deb", _pkg_tar = "pkg_tar", _proto_library = "proto_library", _py_binary = "py_binary", _py_library = "py_library", _py_requirement = "py_requirement", _py_test = "py_test", _select_arch = "select_arch", _select_system = "select_system") # Delegate directly. cc_binary = _cc_binary @@ -45,6 +45,34 @@ def go_binary(name, **kwargs): **kwargs ) +def calculate_sets(srcs): + """Calculates special Go sets for templates. + + Args: + srcs: the full set of Go sources. + + Returns: + A dictionary of the form: + + "": [src1.go, src2.go] + "suffix": [src3suffix.go, src4suffix.go] + + Note that suffix will typically start with '_'. + """ + result = dict() + for file in srcs: + if not file.endswith(".go"): + continue + target = "" + for suffix in go_suffixes: + if file.endswith(suffix + ".go"): + target = suffix + if not target in result: + result[target] = [file] + else: + result[target].append(file) + return result + def go_library(name, srcs, deps = [], imports = [], stateify = True, marshal = False, **kwargs): """Wraps the standard go_library and does stateification and marshalling. @@ -70,39 +98,49 @@ def go_library(name, srcs, deps = [], imports = [], stateify = True, marshal = F marshal: whether marshal is enabled (default: false). **kwargs: standard go_library arguments. """ + all_srcs = srcs + all_deps = deps if stateify: # Only do stateification for non-state packages without manual autogen. 
- go_stateify( - name = name + "_state_autogen", - srcs = [src for src in srcs if src.endswith(".go")], - imports = imports, - package = name, - arch = select_arch(), - out = name + "_state_autogen.go", - ) - all_srcs = srcs + [name + "_state_autogen.go"] - if "//pkg/state" not in deps: - all_deps = deps + ["//pkg/state"] - else: - all_deps = deps - else: - all_deps = deps - all_srcs = srcs + # First, we need to segregate the input files via the special suffixes, + # and calculate the final output set. + state_sets = calculate_sets(srcs) + for (suffix, srcs) in state_sets.items(): + go_stateify( + name = name + suffix + "_state_autogen", + srcs = srcs, + imports = imports, + package = name, + out = name + suffix + "_state_autogen.go", + ) + all_srcs = all_srcs + [ + name + suffix + "_state_autogen.go" + for suffix in state_sets.keys() + ] + if "//pkg/state" not in all_deps: + all_deps = all_deps + ["//pkg/state"] + if marshal: - go_marshal( - name = name + "_abi_autogen", - srcs = [src for src in srcs if src.endswith(".go")], - debug = False, - imports = imports, - package = name, - ) + # See above. + marshal_sets = calculate_sets(srcs) + for (suffix, srcs) in marshal_sets.items(): + go_marshal( + name = name + suffix + "_abi_autogen", + srcs = srcs, + debug = False, + imports = imports, + package = name, + ) extra_deps = [ dep for dep in marshal_deps if not dep in all_deps ] all_deps = all_deps + extra_deps - all_srcs = srcs + [name + "_abi_autogen_unsafe.go"] + all_srcs = all_srcs + [ + name + suffix + "_abi_autogen_unsafe.go" + for suffix in marshal_sets.keys() + ] _go_library( name = name, @@ -115,13 +153,16 @@ def go_library(name, srcs, deps = [], imports = [], stateify = True, marshal = F # Ignore importpath for go_test. kwargs.pop("importpath", None) - _go_test( - name = name + "_abi_autogen_test", - srcs = [name + "_abi_autogen_test.go"], - library = ":" + name, - deps = marshal_test_deps, - **kwargs - ) + # See above. + marshal_sets = calculate_sets(srcs) + for (suffix, srcs) in marshal_sets.items(): + _go_test( + name = name + suffix + "_abi_autogen_test", + srcs = [name + suffix + "_abi_autogen_test.go"], + library = ":" + name + suffix, + deps = marshal_test_deps, + **kwargs + ) def proto_library(name, srcs, **kwargs): """Wraps the standard proto_library. diff --git a/tools/go_marshal/gomarshal/BUILD b/tools/go_marshal/gomarshal/BUILD index c92b59dd6..b5d5a4487 100644 --- a/tools/go_marshal/gomarshal/BUILD +++ b/tools/go_marshal/gomarshal/BUILD @@ -14,4 +14,5 @@ go_library( visibility = [ "//:sandbox", ], + deps = ["//tools/tags"], ) diff --git a/tools/go_marshal/gomarshal/generator.go b/tools/go_marshal/gomarshal/generator.go index af90bdecb..0b3f600fe 100644 --- a/tools/go_marshal/gomarshal/generator.go +++ b/tools/go_marshal/gomarshal/generator.go @@ -23,6 +23,9 @@ import ( "go/token" "os" "sort" + "strings" + + "gvisor.dev/gvisor/tools/tags" ) const ( @@ -104,6 +107,14 @@ func NewGenerator(srcs []string, out, outTest, pkg string, imports []string) (*G func (g *Generator) writeHeader() error { var b sourceBuffer b.emit("// Automatically generated marshal implementation. See tools/go_marshal.\n\n") + + // Emit build tags. + if t := tags.Aggregate(g.inputs); len(t) > 0 { + b.emit(strings.Join(t.Lines(), "\n")) + b.emit("\n") + } + + // Package header. 
b.emit("package %s\n\n", g.pkg) if err := b.write(g.output); err != nil { return err diff --git a/tools/go_stateify/BUILD b/tools/go_stateify/BUILD index a133d6f8b..6036faf7b 100644 --- a/tools/go_stateify/BUILD +++ b/tools/go_stateify/BUILD @@ -6,4 +6,5 @@ go_binary( name = "stateify", srcs = ["main.go"], visibility = ["//visibility:public"], + deps = ["//tools/tags"], ) diff --git a/tools/go_stateify/defs.bzl b/tools/go_stateify/defs.bzl index 0f261d89f..bdb966362 100644 --- a/tools/go_stateify/defs.bzl +++ b/tools/go_stateify/defs.bzl @@ -7,7 +7,6 @@ def _go_stateify_impl(ctx): # Run the stateify command. args = ["-output=%s" % output.path] args.append("-pkg=%s" % ctx.attr.package) - args.append("-arch=%s" % ctx.attr.arch) if ctx.attr._statepkg: args.append("-statepkg=%s" % ctx.attr._statepkg) if ctx.attr.imports: @@ -47,15 +46,8 @@ for statified types. doc = "The package name for the input sources.", mandatory = True, ), - "arch": attr.string( - doc = "Target platform.", - mandatory = True, - ), "out": attr.output( - doc = """ -The name of the generated file output. This must not conflict with any other -files and must be added to the srcs of the relevant go_library. -""", + doc = "Name of the generator output file.", mandatory = True, ), "_tool": attr.label( diff --git a/tools/go_stateify/main.go b/tools/go_stateify/main.go index 7d5d291e6..aa9d4543e 100644 --- a/tools/go_stateify/main.go +++ b/tools/go_stateify/main.go @@ -22,12 +22,12 @@ import ( "go/ast" "go/parser" "go/token" - "io/ioutil" "os" - "path/filepath" "reflect" "strings" "sync" + + "gvisor.dev/gvisor/tools/tags" ) var ( @@ -35,113 +35,8 @@ var ( imports = flag.String("imports", "", "extra imports for the output file") output = flag.String("output", "", "output file") statePkg = flag.String("statepkg", "", "state import package; defaults to empty") - arch = flag.String("arch", "", "specify the target platform") ) -// The known architectures. -var okgoarch = []string{ - "386", - "amd64", - "arm", - "arm64", - "mips", - "mipsle", - "mips64", - "mips64le", - "ppc64", - "ppc64le", - "riscv64", - "s390x", - "sparc64", - "wasm", -} - -// readfile returns the content of the named file. -func readfile(file string) string { - data, err := ioutil.ReadFile(file) - if err != nil { - panic(fmt.Sprintf("readfile err: %v", err)) - } - return string(data) -} - -// matchfield reports whether the field (x,y,z) matches this build. -// all the elements in the field must be satisfied. -func matchfield(f string, goarch string) bool { - for _, tag := range strings.Split(f, ",") { - if !matchtag(tag, goarch) { - return false - } - } - return true -} - -// matchtag reports whether the tag (x or !x) matches this build. -func matchtag(tag string, goarch string) bool { - if tag == "" { - return false - } - if tag[0] == '!' { - if len(tag) == 1 || tag[1] == '!' { - return false - } - return !matchtag(tag[1:], goarch) - } - return tag == goarch -} - -// canBuild reports whether we can build this file for target platform by -// checking file name and build tags. The code is derived from the Go source -// cmd.dist.build.shouldbuild. -func canBuild(file, goTargetArch string) bool { - name := filepath.Base(file) - excluded := func(list []string, ok string) bool { - for _, x := range list { - if x == ok || (ok == "android" && x == "linux") || (ok == "illumos" && x == "solaris") { - continue - } - i := strings.Index(name, x) - if i <= 0 || name[i-1] != '_' { - continue - } - i += len(x) - if i == len(name) || name[i] == '.' 
|| name[i] == '_' {
-				return true
-			}
-		}
-		return false
-	}
-	if excluded(okgoarch, goTargetArch) {
-		return false
-	}
-
-	// Check file contents for // +build lines.
-	for _, p := range strings.Split(readfile(file), "\n") {
-		p = strings.TrimSpace(p)
-		if p == "" {
-			continue
-		}
-		if !strings.HasPrefix(p, "//") {
-			break
-		}
-		if !strings.Contains(p, "+build") {
-			continue
-		}
-		fields := strings.Fields(p[2:])
-		if len(fields) < 1 || fields[0] != "+build" {
-			continue
-		}
-		for _, p := range fields[1:] {
-			if matchfield(p, goTargetArch) {
-				goto fieldmatch
-			}
-		}
-		return false
-	fieldmatch:
-	}
-	return true
-}
-
 // resolveTypeName returns a qualified type name.
 func resolveTypeName(name string, typ ast.Expr) (field string, qualified string) {
 	for done := false; !done; {
@@ -329,8 +224,15 @@ func main() {
 			fmt.Fprintf(outputFile, "	m.Save(\"%s\", &x.%s)\n", name, name)
 		}

-	// Emit the package name.
+	// Automated warning.
 	fmt.Fprint(outputFile, "// automatically generated by stateify.\n\n")
+
+	// Emit build tags.
+	if t := tags.Aggregate(flag.Args()); len(t) > 0 {
+		fmt.Fprintf(outputFile, "%s\n\n", strings.Join(t.Lines(), "\n"))
+	}
+
+	// Emit the package name.
 	fmt.Fprintf(outputFile, "package %s\n\n", *pkg)

 	// Emit the imports lazily.
@@ -364,10 +266,6 @@ func main() {
 			os.Exit(1)
 		}

-		if !canBuild(filename, *arch) {
-			continue
-		}
-
 		files = append(files, f)
 	}

diff --git a/tools/tags/BUILD b/tools/tags/BUILD
new file mode 100644
index 000000000..1c02e2c89
--- /dev/null
+++ b/tools/tags/BUILD
@@ -0,0 +1,11 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+    name = "tags",
+    srcs = ["tags.go"],
+    marshal = False,
+    stateify = False,
+    visibility = ["//tools:__subpackages__"],
+)
diff --git a/tools/tags/tags.go b/tools/tags/tags.go
new file mode 100644
index 000000000..f35904e0a
--- /dev/null
+++ b/tools/tags/tags.go
@@ -0,0 +1,89 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package tags is a utility for parsing build tags.
+package tags
+
+import (
+	"fmt"
+	"io/ioutil"
+	"strings"
+)
+
+// OrSet is a set of tags on a single line.
+//
+// Note that tags may include ",", and we don't distinguish this case in the
+// logic below. Ideally, these constraints can be split into separate top-level
+// build tags in order to resolve any issues.
+type OrSet []string
+
+// Line returns the line for this or.
+func (or OrSet) Line() string {
+	return fmt.Sprintf("// +build %s", strings.Join([]string(or), " "))
+}
+
+// AndSet is the set of all OrSets.
+type AndSet []OrSet
+
+// Lines returns the lines to be printed.
+func (and AndSet) Lines() (ls []string) {
+	for _, or := range and {
+		ls = append(ls, or.Line())
+	}
+	return
+}
+
+// Join joins this AndSet with another.
+func (and AndSet) Join(other AndSet) AndSet {
+	return append(and, other...)
+}
+
+// Tags returns the unique set of +build tags.
+//
+// Derived from the runtime's canBuild.
+func Tags(file string) (tags AndSet) { + data, err := ioutil.ReadFile(file) + if err != nil { + return nil + } + // Check file contents for // +build lines. + for _, p := range strings.Split(string(data), "\n") { + p = strings.TrimSpace(p) + if p == "" { + continue + } + if !strings.HasPrefix(p, "//") { + break + } + if !strings.Contains(p, "+build") { + continue + } + fields := strings.Fields(p[2:]) + if len(fields) < 1 || fields[0] != "+build" { + continue + } + tags = append(tags, OrSet(fields[1:])) + } + return tags +} + +// Aggregate aggregates all tags from a set of files. +// +// Note that these may be in conflict, in which case the build will fail. +func Aggregate(files []string) (tags AndSet) { + for _, file := range files { + tags = tags.Join(Tags(file)) + } + return tags +} -- cgit v1.2.3 From 1b6a12a768216a99a5e0428c42ea4faf79cf3b50 Mon Sep 17 00:00:00 2001 From: Adin Scannell Date: Wed, 5 Feb 2020 22:45:44 -0800 Subject: Add notes to relevant tests. These were out-of-band notes that can help provide additional context and simplify automated imports. PiperOrigin-RevId: 293525915 --- pkg/metric/metric.go | 1 - pkg/sentry/arch/arch_x86.go | 4 ++ pkg/sentry/arch/signal_amd64.go | 2 +- pkg/sentry/fs/file_overlay_test.go | 1 + pkg/sentry/fs/proc/README.md | 4 ++ pkg/sentry/kernel/BUILD | 1 + pkg/sentry/kernel/kernel.go | 3 ++ pkg/sentry/kernel/kernel_opts.go | 20 +++++++ pkg/sentry/socket/hostinet/BUILD | 1 + pkg/sentry/socket/hostinet/socket.go | 5 +- pkg/sentry/socket/hostinet/sockopt_impl.go | 27 ++++++++++ pkg/tcpip/transport/tcp/endpoint.go | 3 ++ runsc/boot/filter/BUILD | 1 + runsc/boot/filter/config.go | 13 ----- runsc/boot/filter/config_profile.go | 34 ++++++++++++ runsc/container/console_test.go | 5 +- runsc/dockerutil/dockerutil.go | 11 ++-- runsc/testutil/BUILD | 5 +- runsc/testutil/testutil.go | 54 ------------------- runsc/testutil/testutil_runfiles.go | 75 +++++++++++++++++++++++++++ test/image/image_test.go | 8 +-- test/syscalls/build_defs.bzl | 35 +++++++++++-- test/syscalls/linux/chroot.cc | 2 +- test/syscalls/linux/concurrency.cc | 3 +- test/syscalls/linux/exec_proc_exe_workload.cc | 6 +++ test/syscalls/linux/fork.cc | 5 +- test/syscalls/linux/mmap.cc | 8 +-- test/syscalls/linux/open_create.cc | 1 + test/syscalls/linux/preadv.cc | 1 + test/syscalls/linux/proc.cc | 46 +++++++++++++--- test/syscalls/linux/readv.cc | 4 +- test/syscalls/linux/rseq.cc | 2 +- test/syscalls/linux/select.cc | 2 +- test/syscalls/linux/shm.cc | 2 +- test/syscalls/linux/sigprocmask.cc | 2 +- test/syscalls/linux/socket_unix_non_stream.cc | 4 +- test/syscalls/linux/symlink.cc | 2 +- test/syscalls/linux/tcp_socket.cc | 3 +- test/syscalls/linux/time.cc | 1 + test/syscalls/linux/tkill.cc | 2 +- test/util/temp_path.cc | 1 + tools/build/tags.bzl | 4 ++ tools/defs.bzl | 17 +++++- 43 files changed, 318 insertions(+), 113 deletions(-) create mode 100644 pkg/sentry/kernel/kernel_opts.go create mode 100644 pkg/sentry/socket/hostinet/sockopt_impl.go create mode 100644 runsc/boot/filter/config_profile.go create mode 100644 runsc/testutil/testutil_runfiles.go (limited to 'tools') diff --git a/pkg/metric/metric.go b/pkg/metric/metric.go index 93d4f2b8c..006fcd9ab 100644 --- a/pkg/metric/metric.go +++ b/pkg/metric/metric.go @@ -46,7 +46,6 @@ var ( // // TODO(b/67298402): Support non-cumulative metrics. // TODO(b/67298427): Support metric fields. -// type Uint64Metric struct { // value is the actual value of the metric. It must be accessed // atomically. 
diff --git a/pkg/sentry/arch/arch_x86.go b/pkg/sentry/arch/arch_x86.go index a18093155..3db8bd34b 100644 --- a/pkg/sentry/arch/arch_x86.go +++ b/pkg/sentry/arch/arch_x86.go @@ -114,6 +114,10 @@ func newX86FPStateSlice() []byte { size, align := cpuid.HostFeatureSet().ExtendedStateSize() capacity := size // Always use at least 4096 bytes. + // + // For the KVM platform, this state is a fixed 4096 bytes, so make sure + // that the underlying array is at _least_ that size otherwise we will + // corrupt random memory. This is not a pleasant thing to debug. if capacity < 4096 { capacity = 4096 } diff --git a/pkg/sentry/arch/signal_amd64.go b/pkg/sentry/arch/signal_amd64.go index 81b92bb43..6fb756f0e 100644 --- a/pkg/sentry/arch/signal_amd64.go +++ b/pkg/sentry/arch/signal_amd64.go @@ -55,7 +55,7 @@ type SignalContext64 struct { Trapno uint64 Oldmask linux.SignalSet Cr2 uint64 - // Pointer to a struct _fpstate. + // Pointer to a struct _fpstate. See b/33003106#comment8. Fpstate uint64 Reserved [8]uint64 } diff --git a/pkg/sentry/fs/file_overlay_test.go b/pkg/sentry/fs/file_overlay_test.go index 02538bb4f..a76d87e3a 100644 --- a/pkg/sentry/fs/file_overlay_test.go +++ b/pkg/sentry/fs/file_overlay_test.go @@ -177,6 +177,7 @@ func TestReaddirRevalidation(t *testing.T) { // TestReaddirOverlayFrozen tests that calling Readdir on an overlay file with // a frozen dirent tree does not make Readdir calls to the underlying files. +// This is a regression test for b/114808269. func TestReaddirOverlayFrozen(t *testing.T) { ctx := contexttest.Context(t) diff --git a/pkg/sentry/fs/proc/README.md b/pkg/sentry/fs/proc/README.md index 5d4ec6c7b..6667a0916 100644 --- a/pkg/sentry/fs/proc/README.md +++ b/pkg/sentry/fs/proc/README.md @@ -11,6 +11,8 @@ inconsistency, please file a bug. The following files are implemented: + + | File /proc/ | Content | | :------------------------ | :---------------------------------------------------- | | [cpuinfo](#cpuinfo) | Info about the CPU | @@ -22,6 +24,8 @@ The following files are implemented: | [uptime](#uptime) | Wall clock since boot, combined idle time of all cpus | | [version](#version) | Kernel version | + + ### cpuinfo ```bash diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD index a27628c0a..2231d6973 100644 --- a/pkg/sentry/kernel/BUILD +++ b/pkg/sentry/kernel/BUILD @@ -91,6 +91,7 @@ go_library( "fs_context.go", "ipc_namespace.go", "kernel.go", + "kernel_opts.go", "kernel_state.go", "pending_signals.go", "pending_signals_list.go", diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go index dcd6e91c4..3ee760ba2 100644 --- a/pkg/sentry/kernel/kernel.go +++ b/pkg/sentry/kernel/kernel.go @@ -235,6 +235,9 @@ type Kernel struct { // events. This is initialized lazily on the first unimplemented // syscall. unimplementedSyscallEmitter eventchannel.Emitter `state:"nosave"` + + // SpecialOpts contains special kernel options. + SpecialOpts } // InitKernelArgs holds arguments to Init. diff --git a/pkg/sentry/kernel/kernel_opts.go b/pkg/sentry/kernel/kernel_opts.go new file mode 100644 index 000000000..2e66ec587 --- /dev/null +++ b/pkg/sentry/kernel/kernel_opts.go @@ -0,0 +1,20 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package kernel + +// SpecialOpts contains non-standard options for the kernel. +// +// +stateify savable +type SpecialOpts struct{} diff --git a/pkg/sentry/socket/hostinet/BUILD b/pkg/sentry/socket/hostinet/BUILD index 5a07d5d0e..023bad156 100644 --- a/pkg/sentry/socket/hostinet/BUILD +++ b/pkg/sentry/socket/hostinet/BUILD @@ -10,6 +10,7 @@ go_library( "save_restore.go", "socket.go", "socket_unsafe.go", + "sockopt_impl.go", "stack.go", ], visibility = ["//pkg/sentry:internal"], diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index 34f63986f..de76388ac 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -285,7 +285,7 @@ func (s *socketOperations) GetSockOpt(t *kernel.Task, level int, name int, outPt } // Whitelist options and constrain option length. - var optlen int + optlen := getSockOptLen(t, level, name) switch level { case linux.SOL_IP: switch name { @@ -330,7 +330,7 @@ func (s *socketOperations) GetSockOpt(t *kernel.Task, level int, name int, outPt // SetSockOpt implements socket.Socket.SetSockOpt. func (s *socketOperations) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *syserr.Error { // Whitelist options and constrain option length. - var optlen int + optlen := setSockOptLen(t, level, name) switch level { case linux.SOL_IP: switch name { @@ -353,6 +353,7 @@ func (s *socketOperations) SetSockOpt(t *kernel.Task, level int, name int, opt [ optlen = sizeofInt32 } } + if optlen == 0 { // Pretend to accept socket options we don't understand. This seems // dangerous, but it's what netstack does... diff --git a/pkg/sentry/socket/hostinet/sockopt_impl.go b/pkg/sentry/socket/hostinet/sockopt_impl.go new file mode 100644 index 000000000..8a783712e --- /dev/null +++ b/pkg/sentry/socket/hostinet/sockopt_impl.go @@ -0,0 +1,27 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package hostinet + +import ( + "gvisor.dev/gvisor/pkg/sentry/kernel" +) + +func getSockOptLen(t *kernel.Task, level, name int) int { + return 0 // No custom options. +} + +func setSockOptLen(t *kernel.Task, level, name int) int { + return 0 // No custom options. 
+} diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index e4a6b1b8b..f2be0e651 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -2166,6 +2166,9 @@ func (e *endpoint) listen(backlog int) *tcpip.Error { e.isRegistered = true e.setEndpointState(StateListen) + // The channel may be non-nil when we're restoring the endpoint, and it + // may be pre-populated with some previously accepted (but not Accepted) + // endpoints. if e.acceptedChan == nil { e.acceptedChan = make(chan *endpoint, backlog) } diff --git a/runsc/boot/filter/BUILD b/runsc/boot/filter/BUILD index ce30f6c53..ed18f0047 100644 --- a/runsc/boot/filter/BUILD +++ b/runsc/boot/filter/BUILD @@ -8,6 +8,7 @@ go_library( "config.go", "config_amd64.go", "config_arm64.go", + "config_profile.go", "extra_filters.go", "extra_filters_msan.go", "extra_filters_race.go", diff --git a/runsc/boot/filter/config.go b/runsc/boot/filter/config.go index f8d351c7b..c69f4c602 100644 --- a/runsc/boot/filter/config.go +++ b/runsc/boot/filter/config.go @@ -536,16 +536,3 @@ func controlServerFilters(fd int) seccomp.SyscallRules { }, } } - -// profileFilters returns extra syscalls made by runtime/pprof package. -func profileFilters() seccomp.SyscallRules { - return seccomp.SyscallRules{ - syscall.SYS_OPENAT: []seccomp.Rule{ - { - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.O_RDONLY | syscall.O_LARGEFILE | syscall.O_CLOEXEC), - }, - }, - } -} diff --git a/runsc/boot/filter/config_profile.go b/runsc/boot/filter/config_profile.go new file mode 100644 index 000000000..194952a7b --- /dev/null +++ b/runsc/boot/filter/config_profile.go @@ -0,0 +1,34 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package filter + +import ( + "syscall" + + "gvisor.dev/gvisor/pkg/seccomp" +) + +// profileFilters returns extra syscalls made by runtime/pprof package. +func profileFilters() seccomp.SyscallRules { + return seccomp.SyscallRules{ + syscall.SYS_OPENAT: []seccomp.Rule{ + { + seccomp.AllowAny{}, + seccomp.AllowAny{}, + seccomp.AllowValue(syscall.O_RDONLY | syscall.O_LARGEFILE | syscall.O_CLOEXEC), + }, + }, + } +} diff --git a/runsc/container/console_test.go b/runsc/container/console_test.go index 060b63bf3..c2518d52b 100644 --- a/runsc/container/console_test.go +++ b/runsc/container/console_test.go @@ -196,7 +196,10 @@ func TestJobControlSignalExec(t *testing.T) { defer ptyMaster.Close() defer ptySlave.Close() - // Exec bash and attach a terminal. + // Exec bash and attach a terminal. Note that occasionally /bin/sh + // may be a different shell or have a different configuration (such + // as disabling interactive mode and job control). Since we want to + // explicitly test interactive mode, use /bin/bash. See b/116981926. 
execArgs := &control.ExecArgs{ Filename: "/bin/bash", // Don't let bash execute from profile or rc files, otherwise diff --git a/runsc/dockerutil/dockerutil.go b/runsc/dockerutil/dockerutil.go index 9b6346ca2..1ff5e8cc3 100644 --- a/runsc/dockerutil/dockerutil.go +++ b/runsc/dockerutil/dockerutil.go @@ -143,8 +143,11 @@ func PrepareFiles(names ...string) (string, error) { return "", fmt.Errorf("os.Chmod(%q, 0777) failed: %v", dir, err) } for _, name := range names { - src := getLocalPath(name) - dst := path.Join(dir, name) + src, err := testutil.FindFile(name) + if err != nil { + return "", fmt.Errorf("testutil.Preparefiles(%q) failed: %v", name, err) + } + dst := path.Join(dir, path.Base(name)) if err := testutil.Copy(src, dst); err != nil { return "", fmt.Errorf("testutil.Copy(%q, %q) failed: %v", src, dst, err) } @@ -152,10 +155,6 @@ func PrepareFiles(names ...string) (string, error) { return dir, nil } -func getLocalPath(file string) string { - return path.Join(".", file) -} - // do executes docker command. func do(args ...string) (string, error) { log.Printf("Running: docker %s\n", args) diff --git a/runsc/testutil/BUILD b/runsc/testutil/BUILD index f845120b0..945405303 100644 --- a/runsc/testutil/BUILD +++ b/runsc/testutil/BUILD @@ -5,7 +5,10 @@ package(licenses = ["notice"]) go_library( name = "testutil", testonly = 1, - srcs = ["testutil.go"], + srcs = [ + "testutil.go", + "testutil_runfiles.go", + ], visibility = ["//:sandbox"], deps = [ "//pkg/log", diff --git a/runsc/testutil/testutil.go b/runsc/testutil/testutil.go index edf2e809a..80c2c9680 100644 --- a/runsc/testutil/testutil.go +++ b/runsc/testutil/testutil.go @@ -79,60 +79,6 @@ func ConfigureExePath() error { return nil } -// FindFile searchs for a file inside the test run environment. It returns the -// full path to the file. It fails if none or more than one file is found. -func FindFile(path string) (string, error) { - wd, err := os.Getwd() - if err != nil { - return "", err - } - - // The test root is demarcated by a path element called "__main__". Search for - // it backwards from the working directory. - root := wd - for { - dir, name := filepath.Split(root) - if name == "__main__" { - break - } - if len(dir) == 0 { - return "", fmt.Errorf("directory __main__ not found in %q", wd) - } - // Remove ending slash to loop around. - root = dir[:len(dir)-1] - } - - // Annoyingly, bazel adds the build type to the directory path for go - // binaries, but not for c++ binaries. We use two different patterns to - // to find our file. - patterns := []string{ - // Try the obvious path first. - filepath.Join(root, path), - // If it was a go binary, use a wildcard to match the build - // type. The pattern is: /test-path/__main__/directories/*/file. - filepath.Join(root, filepath.Dir(path), "*", filepath.Base(path)), - } - - for _, p := range patterns { - matches, err := filepath.Glob(p) - if err != nil { - // "The only possible returned error is ErrBadPattern, - // when pattern is malformed." -godoc - return "", fmt.Errorf("error globbing %q: %v", p, err) - } - switch len(matches) { - case 0: - // Try the next pattern. - case 1: - // We found it. - return matches[0], nil - default: - return "", fmt.Errorf("more than one match found for %q: %s", path, matches) - } - } - return "", fmt.Errorf("file %q not found", path) -} - // TestConfig returns the default configuration to use in tests. Note that // 'RootDir' must be set by caller if required. 
func TestConfig() *boot.Config { diff --git a/runsc/testutil/testutil_runfiles.go b/runsc/testutil/testutil_runfiles.go new file mode 100644 index 000000000..ece9ea9a1 --- /dev/null +++ b/runsc/testutil/testutil_runfiles.go @@ -0,0 +1,75 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package testutil + +import ( + "fmt" + "os" + "path/filepath" +) + +// FindFile searchs for a file inside the test run environment. It returns the +// full path to the file. It fails if none or more than one file is found. +func FindFile(path string) (string, error) { + wd, err := os.Getwd() + if err != nil { + return "", err + } + + // The test root is demarcated by a path element called "__main__". Search for + // it backwards from the working directory. + root := wd + for { + dir, name := filepath.Split(root) + if name == "__main__" { + break + } + if len(dir) == 0 { + return "", fmt.Errorf("directory __main__ not found in %q", wd) + } + // Remove ending slash to loop around. + root = dir[:len(dir)-1] + } + + // Annoyingly, bazel adds the build type to the directory path for go + // binaries, but not for c++ binaries. We use two different patterns to + // to find our file. + patterns := []string{ + // Try the obvious path first. + filepath.Join(root, path), + // If it was a go binary, use a wildcard to match the build + // type. The pattern is: /test-path/__main__/directories/*/file. + filepath.Join(root, filepath.Dir(path), "*", filepath.Base(path)), + } + + for _, p := range patterns { + matches, err := filepath.Glob(p) + if err != nil { + // "The only possible returned error is ErrBadPattern, + // when pattern is malformed." -godoc + return "", fmt.Errorf("error globbing %q: %v", p, err) + } + switch len(matches) { + case 0: + // Try the next pattern. + case 1: + // We found it. 
+ return matches[0], nil + default: + return "", fmt.Errorf("more than one match found for %q: %s", path, matches) + } + } + return "", fmt.Errorf("file %q not found", path) +} diff --git a/test/image/image_test.go b/test/image/image_test.go index d0dcb1861..0a1e19d6f 100644 --- a/test/image/image_test.go +++ b/test/image/image_test.go @@ -107,7 +107,7 @@ func TestHttpd(t *testing.T) { } d := dockerutil.MakeDocker("http-test") - dir, err := dockerutil.PrepareFiles("latin10k.txt") + dir, err := dockerutil.PrepareFiles("test/image/latin10k.txt") if err != nil { t.Fatalf("PrepareFiles() failed: %v", err) } @@ -139,7 +139,7 @@ func TestNginx(t *testing.T) { } d := dockerutil.MakeDocker("net-test") - dir, err := dockerutil.PrepareFiles("latin10k.txt") + dir, err := dockerutil.PrepareFiles("test/image/latin10k.txt") if err != nil { t.Fatalf("PrepareFiles() failed: %v", err) } @@ -183,7 +183,7 @@ func TestMysql(t *testing.T) { } client := dockerutil.MakeDocker("mysql-client-test") - dir, err := dockerutil.PrepareFiles("mysql.sql") + dir, err := dockerutil.PrepareFiles("test/image/mysql.sql") if err != nil { t.Fatalf("PrepareFiles() failed: %v", err) } @@ -283,7 +283,7 @@ func TestRuby(t *testing.T) { } d := dockerutil.MakeDocker("ruby-test") - dir, err := dockerutil.PrepareFiles("ruby.rb", "ruby.sh") + dir, err := dockerutil.PrepareFiles("test/image/ruby.rb", "test/image/ruby.sh") if err != nil { t.Fatalf("PrepareFiles() failed: %v", err) } diff --git a/test/syscalls/build_defs.bzl b/test/syscalls/build_defs.bzl index 1df761dd0..cbab85ef7 100644 --- a/test/syscalls/build_defs.bzl +++ b/test/syscalls/build_defs.bzl @@ -2,8 +2,6 @@ load("//tools:defs.bzl", "loopback") -# syscall_test is a macro that will create targets to run the given test target -# on the host (native) and runsc. def syscall_test( test, shard_count = 5, @@ -13,6 +11,19 @@ def syscall_test( add_uds_tree = False, add_hostinet = False, tags = None): + """syscall_test is a macro that will create targets for all platforms. + + Args: + test: the test target. + shard_count: shards for defined tests. + size: the defined test size. + use_tmpfs: use tmpfs in the defined tests. + add_overlay: add an overlay test. + add_uds_tree: add a UDS test. + add_hostinet: add a hostinet test. + tags: starting test tags. + """ + _syscall_test( test = test, shard_count = shard_count, @@ -111,6 +122,19 @@ def _syscall_test( # all the tests on a specific flavor. Use --test_tag_filters=ptrace,file_shared. tags += [full_platform, "file_" + file_access] + # Hash this target into one of 15 buckets. This can be used to + # randomly split targets between different workflows. + hash15 = hash(native.package_name() + name) % 15 + tags.append("hash15:" + str(hash15)) + + # TODO(b/139838000): Tests using hostinet must be disabled on Guitar until + # we figure out how to request ipv4 sockets on Guitar machines. + if network == "host": + tags.append("noguitar") + + # Disable off-host networking. + tags.append("requires-net:loopback") + # Add tag to prevent the tests from running in a Bazel sandbox. # TODO(b/120560048): Make the tests run without this tag. tags.append("no-sandbox") @@ -118,8 +142,11 @@ def _syscall_test( # TODO(b/112165693): KVM tests are tagged "manual" to until the platform is # more stable. if platform == "kvm": - tags += ["manual"] - tags += ["requires-kvm"] + tags.append("manual") + tags.append("requires-kvm") + + # TODO(b/112165693): Remove when tests pass reliably. 
+ tags.append("notap") args = [ # Arguments are passed directly to syscall_test_runner binary. diff --git a/test/syscalls/linux/chroot.cc b/test/syscalls/linux/chroot.cc index 0a2d44a2c..85ec013d5 100644 --- a/test/syscalls/linux/chroot.cc +++ b/test/syscalls/linux/chroot.cc @@ -167,7 +167,7 @@ TEST(ChrootTest, DotDotFromOpenFD) { } // Test that link resolution in a chroot can escape the root by following an -// open proc fd. +// open proc fd. Regression test for b/32316719. TEST(ChrootTest, ProcFdLinkResolutionInChroot) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_CHROOT))); diff --git a/test/syscalls/linux/concurrency.cc b/test/syscalls/linux/concurrency.cc index f41f99900..7cd6a75bd 100644 --- a/test/syscalls/linux/concurrency.cc +++ b/test/syscalls/linux/concurrency.cc @@ -46,7 +46,8 @@ TEST(ConcurrencyTest, SingleProcessMultithreaded) { } // Test that multiple threads in this process continue to execute in parallel, -// even if an unrelated second process is spawned. +// even if an unrelated second process is spawned. Regression test for +// b/32119508. TEST(ConcurrencyTest, MultiProcessMultithreaded) { // In PID 1, start TIDs 1 and 2, and put both to sleep. // diff --git a/test/syscalls/linux/exec_proc_exe_workload.cc b/test/syscalls/linux/exec_proc_exe_workload.cc index b790fe5be..2989379b7 100644 --- a/test/syscalls/linux/exec_proc_exe_workload.cc +++ b/test/syscalls/linux/exec_proc_exe_workload.cc @@ -21,6 +21,12 @@ #include "test/util/posix_error.h" int main(int argc, char** argv, char** envp) { + // This is annoying. Because remote build systems may put these binaries + // in a content-addressable-store, you may wind up with /proc/self/exe + // pointing to some random path (but with a sensible argv[0]). + // + // Therefore, this test simply checks that the /proc/self/exe + // is absolute and *doesn't* match argv[1]. std::string exe = gvisor::testing::ProcessExePath(getpid()).ValueOrDie(); if (exe[0] != '/') { diff --git a/test/syscalls/linux/fork.cc b/test/syscalls/linux/fork.cc index 906f3358d..ff8bdfeb0 100644 --- a/test/syscalls/linux/fork.cc +++ b/test/syscalls/linux/fork.cc @@ -271,7 +271,7 @@ TEST_F(ForkTest, Alarm) { EXPECT_EQ(0, alarmed); } -// Child cannot affect parent private memory. +// Child cannot affect parent private memory. Regression test for b/24137240. TEST_F(ForkTest, PrivateMemory) { std::atomic local(0); @@ -298,6 +298,9 @@ TEST_F(ForkTest, PrivateMemory) { } // Kernel-accessed buffers should remain coherent across COW. +// +// The buffer must be >= usermem.ZeroCopyMinBytes, as UnsafeAccess operates +// differently. Regression test for b/33811887. TEST_F(ForkTest, COWSegment) { constexpr int kBufSize = 1024; char* read_buf = private_; diff --git a/test/syscalls/linux/mmap.cc b/test/syscalls/linux/mmap.cc index 1c4d9f1c7..11fb1b457 100644 --- a/test/syscalls/linux/mmap.cc +++ b/test/syscalls/linux/mmap.cc @@ -1418,7 +1418,7 @@ TEST_P(MMapFileParamTest, NoSigBusOnPageContainingEOF) { // // On most platforms this is trivial, but when the file is mapped via the sentry // page cache (which does not yet support writing to shared mappings), a bug -// caused reads to fail unnecessarily on such mappings. +// caused reads to fail unnecessarily on such mappings. See b/28913513. 
TEST_F(MMapFileTest, ReadingWritableSharedFilePageSucceeds) { uintptr_t addr; size_t len = strlen(kFileContents); @@ -1435,7 +1435,7 @@ TEST_F(MMapFileTest, ReadingWritableSharedFilePageSucceeds) { // Tests that EFAULT is returned when invoking a syscall that requires the OS to // read past end of file (resulting in a fault in sentry context in the gVisor -// case). +// case). See b/28913513. TEST_F(MMapFileTest, InternalSigBus) { uintptr_t addr; ASSERT_THAT(addr = Map(0, 2 * kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE, @@ -1578,7 +1578,7 @@ TEST_F(MMapFileTest, Bug38498194) { } // Tests that reading from a file to a memory mapping of the same file does not -// deadlock. +// deadlock. See b/34813270. TEST_F(MMapFileTest, SelfRead) { uintptr_t addr; ASSERT_THAT(addr = Map(0, kPageSize, PROT_READ | PROT_WRITE, MAP_SHARED, @@ -1590,7 +1590,7 @@ TEST_F(MMapFileTest, SelfRead) { } // Tests that writing to a file from a memory mapping of the same file does not -// deadlock. +// deadlock. Regression test for b/34813270. TEST_F(MMapFileTest, SelfWrite) { uintptr_t addr; ASSERT_THAT(addr = Map(0, kPageSize, PROT_READ, MAP_SHARED, fd_.get(), 0), diff --git a/test/syscalls/linux/open_create.cc b/test/syscalls/linux/open_create.cc index 431733dbe..902d0a0dc 100644 --- a/test/syscalls/linux/open_create.cc +++ b/test/syscalls/linux/open_create.cc @@ -132,6 +132,7 @@ TEST(CreateTest, CreateFailsOnDirWithoutWritePerms) { } // A file originally created RW, but opened RO can later be opened RW. +// Regression test for b/65385065. TEST(CreateTest, OpenCreateROThenRW) { TempPath file(NewTempAbsPath()); diff --git a/test/syscalls/linux/preadv.cc b/test/syscalls/linux/preadv.cc index f7ea44054..5b0743fe9 100644 --- a/test/syscalls/linux/preadv.cc +++ b/test/syscalls/linux/preadv.cc @@ -37,6 +37,7 @@ namespace testing { namespace { +// Stress copy-on-write. Attempts to reproduce b/38430174. TEST(PreadvTest, MMConcurrencyStress) { // Fill a one-page file with zeroes (the contents don't really matter). const auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( diff --git a/test/syscalls/linux/proc.cc b/test/syscalls/linux/proc.cc index 169b723eb..a23fdb58d 100644 --- a/test/syscalls/linux/proc.cc +++ b/test/syscalls/linux/proc.cc @@ -1352,13 +1352,19 @@ TEST(ProcPidSymlink, SubprocessZombied) { // FIXME(gvisor.dev/issue/164): Inconsistent behavior between gVisor and linux // on proc files. - // 4.17 & gVisor: Syscall succeeds and returns 1 + // + // ~4.3: Syscall fails with EACCES. + // 4.17 & gVisor: Syscall succeeds and returns 1. + // // EXPECT_THAT(ReadlinkWhileZombied("ns/pid", buf, sizeof(buf)), // SyscallFailsWithErrno(EACCES)); // FIXME(gvisor.dev/issue/164): Inconsistent behavior between gVisor and linux // on proc files. - // 4.17 & gVisor: Syscall succeeds and returns 1. + // + // ~4.3: Syscall fails with EACCES. + // 4.17 & gVisor: Syscall succeeds and returns 1. + // // EXPECT_THAT(ReadlinkWhileZombied("ns/user", buf, sizeof(buf)), // SyscallFailsWithErrno(EACCES)); } @@ -1431,8 +1437,12 @@ TEST(ProcPidFile, SubprocessRunning) { TEST(ProcPidFile, SubprocessZombie) { char buf[1]; - // 4.17: Succeeds and returns 1 - // gVisor: Succeeds and returns 0 + // FIXME(gvisor.dev/issue/164): Loosen requirement due to inconsistent + // behavior on different kernels. + // + // ~4.3: Succeds and returns 0. + // 4.17: Succeeds and returns 1. + // gVisor: Succeeds and returns 0. 
EXPECT_THAT(ReadWhileZombied("auxv", buf, sizeof(buf)), SyscallSucceeds()); EXPECT_THAT(ReadWhileZombied("cmdline", buf, sizeof(buf)), @@ -1458,7 +1468,10 @@ TEST(ProcPidFile, SubprocessZombie) { // FIXME(gvisor.dev/issue/164): Inconsistent behavior between gVisor and linux // on proc files. + // + // ~4.3: Fails and returns EACCES. // gVisor & 4.17: Succeeds and returns 1. + // // EXPECT_THAT(ReadWhileZombied("io", buf, sizeof(buf)), // SyscallFailsWithErrno(EACCES)); } @@ -1467,9 +1480,12 @@ TEST(ProcPidFile, SubprocessZombie) { TEST(ProcPidFile, SubprocessExited) { char buf[1]; - // FIXME(gvisor.dev/issue/164): Inconsistent behavior between kernels + // FIXME(gvisor.dev/issue/164): Inconsistent behavior between kernels. + // + // ~4.3: Fails and returns ESRCH. // gVisor: Fails with ESRCH. // 4.17: Succeeds and returns 1. + // // EXPECT_THAT(ReadWhileExited("auxv", buf, sizeof(buf)), // SyscallFailsWithErrno(ESRCH)); @@ -1641,7 +1657,7 @@ TEST(ProcTask, KilledThreadsDisappear) { EXPECT_NO_ERRNO(DirContainsExactly("/proc/self/task", TaskFiles(initial, {child1.Tid()}))); - // Stat child1's task file. + // Stat child1's task file. Regression test for b/32097707. struct stat statbuf; const std::string child1_task_file = absl::StrCat("/proc/self/task/", child1.Tid()); @@ -1669,7 +1685,7 @@ TEST(ProcTask, KilledThreadsDisappear) { EXPECT_NO_ERRNO(EventuallyDirContainsExactly( "/proc/self/task", TaskFiles(initial, {child3.Tid(), child5.Tid()}))); - // Stat child1's task file again. This time it should fail. + // Stat child1's task file again. This time it should fail. See b/32097707. EXPECT_THAT(stat(child1_task_file.c_str(), &statbuf), SyscallFailsWithErrno(ENOENT)); @@ -1824,7 +1840,7 @@ TEST(ProcSysVmOvercommitMemory, HasNumericValue) { } // Check that link for proc fd entries point the target node, not the -// symlink itself. +// symlink itself. Regression test for b/31155070. TEST(ProcTaskFd, FstatatFollowsSymlink) { const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); const FileDescriptor fd = @@ -1883,6 +1899,20 @@ TEST(ProcMounts, IsSymlink) { EXPECT_EQ(link, "self/mounts"); } +TEST(ProcSelfMountinfo, RequiredFieldsArePresent) { + auto mountinfo = + ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/self/mountinfo")); + EXPECT_THAT( + mountinfo, + AllOf( + // Root mount. + ContainsRegex( + R"([0-9]+ [0-9]+ [0-9]+:[0-9]+ / / (rw|ro).*- \S+ \S+ (rw|ro)\S*)"), + // Proc mount - always rw. + ContainsRegex( + R"([0-9]+ [0-9]+ [0-9]+:[0-9]+ / /proc rw.*- \S+ \S+ rw\S*)"))); +} + // Check that /proc/self/mounts looks something like a real mounts file. TEST(ProcSelfMounts, RequiredFieldsArePresent) { auto mounts = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/self/mounts")); diff --git a/test/syscalls/linux/readv.cc b/test/syscalls/linux/readv.cc index 4069cbc7e..baaf9f757 100644 --- a/test/syscalls/linux/readv.cc +++ b/test/syscalls/linux/readv.cc @@ -254,7 +254,9 @@ TEST_F(ReadvTest, IovecOutsideTaskAddressRangeInNonemptyArray) { // This test depends on the maximum extent of a single readv() syscall, so // we can't tolerate interruption from saving. TEST(ReadvTestNoFixture, TruncatedAtMax_NoRandomSave) { - // Ensure that we won't be interrupted by ITIMER_PROF. + // Ensure that we won't be interrupted by ITIMER_PROF. This is particularly + // important in environments where automated profiling tools may start + // ITIMER_PROF automatically. 
struct itimerval itv = {}; auto const cleanup_itimer = ASSERT_NO_ERRNO_AND_VALUE(ScopedItimer(ITIMER_PROF, itv)); diff --git a/test/syscalls/linux/rseq.cc b/test/syscalls/linux/rseq.cc index 106c045e3..4bfb1ff56 100644 --- a/test/syscalls/linux/rseq.cc +++ b/test/syscalls/linux/rseq.cc @@ -36,7 +36,7 @@ namespace { // We must be very careful about how these tests are written. Each thread may // only have one struct rseq registration, which may be done automatically at // thread start (as of 2019-11-13, glibc does *not* support rseq and thus does -// not do so). +// not do so, but other libraries do). // // Testing of rseq is thus done primarily in a child process with no // registration. This means exec'ing a nostdlib binary, as rseq registration can diff --git a/test/syscalls/linux/select.cc b/test/syscalls/linux/select.cc index 424e2a67f..be2364fb8 100644 --- a/test/syscalls/linux/select.cc +++ b/test/syscalls/linux/select.cc @@ -146,7 +146,7 @@ TEST_F(SelectTest, IgnoreBitsAboveNfds) { // This test illustrates Linux's behavior of 'select' calls passing after // setrlimit RLIMIT_NOFILE is called. In particular, versions of sshd rely on -// this behavior. +// this behavior. See b/122318458. TEST_F(SelectTest, SetrlimitCallNOFILE) { fd_set read_set; FD_ZERO(&read_set); diff --git a/test/syscalls/linux/shm.cc b/test/syscalls/linux/shm.cc index 7ba752599..c7fdbb924 100644 --- a/test/syscalls/linux/shm.cc +++ b/test/syscalls/linux/shm.cc @@ -473,7 +473,7 @@ TEST(ShmTest, PartialUnmap) { } // Check that sentry does not panic when asked for a zero-length private shm -// segment. +// segment. Regression test for b/110694797. TEST(ShmTest, GracefullyFailOnZeroLenSegmentCreation) { EXPECT_THAT(Shmget(IPC_PRIVATE, 0, 0), PosixErrorIs(EINVAL, _)); } diff --git a/test/syscalls/linux/sigprocmask.cc b/test/syscalls/linux/sigprocmask.cc index 654c6a47f..a603fc1d1 100644 --- a/test/syscalls/linux/sigprocmask.cc +++ b/test/syscalls/linux/sigprocmask.cc @@ -237,7 +237,7 @@ TEST_F(SigProcMaskTest, SignalHandler) { } // Check that sigprocmask correctly handles aliasing of the set and oldset -// pointers. +// pointers. Regression test for b/30502311. TEST_F(SigProcMaskTest, AliasedSets) { sigset_t mask; diff --git a/test/syscalls/linux/socket_unix_non_stream.cc b/test/syscalls/linux/socket_unix_non_stream.cc index 276a94eb8..884319e1d 100644 --- a/test/syscalls/linux/socket_unix_non_stream.cc +++ b/test/syscalls/linux/socket_unix_non_stream.cc @@ -109,7 +109,7 @@ PosixErrorOr> CreateFragmentedRegion(const int size, } // A contiguous iov that is heavily fragmented in FileMem can still be sent -// successfully. +// successfully. See b/115833655. TEST_P(UnixNonStreamSocketPairTest, FragmentedSendMsg) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); @@ -165,7 +165,7 @@ TEST_P(UnixNonStreamSocketPairTest, FragmentedSendMsg) { } // A contiguous iov that is heavily fragmented in FileMem can still be received -// into successfully. +// into successfully. Regression test for b/115833655. TEST_P(UnixNonStreamSocketPairTest, FragmentedRecvMsg) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); diff --git a/test/syscalls/linux/symlink.cc b/test/syscalls/linux/symlink.cc index b249ff91f..03ee1250d 100644 --- a/test/syscalls/linux/symlink.cc +++ b/test/syscalls/linux/symlink.cc @@ -38,7 +38,7 @@ mode_t FilePermission(const std::string& path) { } // Test that name collisions are checked on the new link path, not the source -// path. +// path. Regression test for b/31782115. 
TEST(SymlinkTest, CanCreateSymlinkWithCachedSourceDirent) { const std::string srcname = NewTempAbsPath(); const std::string newname = NewTempAbsPath(); diff --git a/test/syscalls/linux/tcp_socket.cc b/test/syscalls/linux/tcp_socket.cc index 8a8b68e75..c4591a3b9 100644 --- a/test/syscalls/linux/tcp_socket.cc +++ b/test/syscalls/linux/tcp_socket.cc @@ -244,7 +244,8 @@ TEST_P(TcpSocketTest, ZeroWriteAllowed) { } // Test that a non-blocking write with a buffer that is larger than the send -// buffer size will not actually write the whole thing at once. +// buffer size will not actually write the whole thing at once. Regression test +// for b/64438887. TEST_P(TcpSocketTest, NonblockingLargeWrite) { // Set the FD to O_NONBLOCK. int opts; diff --git a/test/syscalls/linux/time.cc b/test/syscalls/linux/time.cc index c7eead17e..1ccb95733 100644 --- a/test/syscalls/linux/time.cc +++ b/test/syscalls/linux/time.cc @@ -62,6 +62,7 @@ TEST(TimeTest, VsyscallTime_InvalidAddressSIGSEGV) { ::testing::KilledBySignal(SIGSEGV), ""); } +// Mimics the gettimeofday(2) wrapper from the Go runtime <= 1.2. int vsyscall_gettimeofday(struct timeval* tv, struct timezone* tz) { constexpr uint64_t kVsyscallGettimeofdayEntry = 0xffffffffff600000; return reinterpret_cast( diff --git a/test/syscalls/linux/tkill.cc b/test/syscalls/linux/tkill.cc index bae377c69..8d8ebbb24 100644 --- a/test/syscalls/linux/tkill.cc +++ b/test/syscalls/linux/tkill.cc @@ -54,7 +54,7 @@ void SigHandler(int sig, siginfo_t* info, void* context) { TEST_CHECK(info->si_code == SI_TKILL); } -// Test with a real signal. +// Test with a real signal. Regression test for b/24790092. TEST(TkillTest, ValidTIDAndRealSignal) { struct sigaction sa; sa.sa_sigaction = SigHandler; diff --git a/test/util/temp_path.cc b/test/util/temp_path.cc index 35aacb172..9c10b6674 100644 --- a/test/util/temp_path.cc +++ b/test/util/temp_path.cc @@ -77,6 +77,7 @@ std::string NewTempAbsPath() { std::string NewTempRelPath() { return NextTempBasename(); } std::string GetAbsoluteTestTmpdir() { + // Note that TEST_TMPDIR is guaranteed to be set. char* env_tmpdir = getenv("TEST_TMPDIR"); std::string tmp_dir = env_tmpdir != nullptr ? std::string(env_tmpdir) : "/tmp"; diff --git a/tools/build/tags.bzl b/tools/build/tags.bzl index e99c87f81..a6db44e47 100644 --- a/tools/build/tags.bzl +++ b/tools/build/tags.bzl @@ -33,4 +33,8 @@ go_suffixes = [ "_wasm_unsafe", "_linux", "_linux_unsafe", + "_opts", + "_opts_unsafe", + "_impl", + "_impl_unsafe", ] diff --git a/tools/defs.bzl b/tools/defs.bzl index 5d5fa134a..c03b557ae 100644 --- a/tools/defs.bzl +++ b/tools/defs.bzl @@ -73,6 +73,16 @@ def calculate_sets(srcs): result[target].append(file) return result +def go_imports(name, src, out): + """Simplify a single Go source file by eliminating unused imports.""" + native.genrule( + name = name, + srcs = [src], + outs = [out], + tools = ["@org_golang_x_tools//cmd/goimports:goimports"], + cmd = ("$(location @org_golang_x_tools//cmd/goimports:goimports) $(SRCS) > $@"), + ) + def go_library(name, srcs, deps = [], imports = [], stateify = True, marshal = False, **kwargs): """Wraps the standard go_library and does stateification and marshalling. 
@@ -107,10 +117,15 @@ def go_library(name, srcs, deps = [], imports = [], stateify = True, marshal = F state_sets = calculate_sets(srcs) for (suffix, srcs) in state_sets.items(): go_stateify( - name = name + suffix + "_state_autogen", + name = name + suffix + "_state_autogen_with_imports", srcs = srcs, imports = imports, package = name, + out = name + suffix + "_state_autogen_with_imports.go", + ) + go_imports( + name = name + suffix + "_state_autogen", + src = name + suffix + "_state_autogen_with_imports.go", out = name + suffix + "_state_autogen.go", ) all_srcs = all_srcs + [ -- cgit v1.2.3 From 16561e461e82f8d846ef1f3ada990270ef39ccc6 Mon Sep 17 00:00:00 2001 From: Zach Koopmans Date: Thu, 6 Feb 2020 15:59:44 -0800 Subject: Add logic to run from baked images. Change adds the following: - logic to run from "baked images". See [GVISOR_DIR]/tools/images - installers which install modified files from a workspace. This allows users to run benchmarks while modifying runsc. - removes the --preemptible tag from built GCE instances. Preemptible instances are much more likely to be preempted on startup, which manifests for the user as a failed benchmark. I don't currently have a way to detect if a VM has been preempted that will work for this change. https://cloud.google.com/compute/docs/instances/preemptible#preemption_process https://cloud.google.com/compute/docs/instances/preemptible#preemption_selection PiperOrigin-RevId: 293697949 --- benchmarks/BUILD | 8 ++ benchmarks/README.md | 21 +++- benchmarks/harness/BUILD | 21 ++++ benchmarks/harness/__init__.py | 36 ++++++- benchmarks/harness/machine.py | 43 +++++++- benchmarks/harness/machine_producers/BUILD | 1 + .../harness/machine_producers/gcloud_producer.py | 114 ++++++++------------- benchmarks/harness/ssh_connection.py | 25 +++-- benchmarks/runner/__init__.py | 75 ++++---------- benchmarks/runner/commands.py | 70 ++++++------- tools/images/defs.bzl | 5 +- tools/installers/BUILD | 7 +- tools/installers/head.sh | 2 +- 13 files changed, 250 insertions(+), 178 deletions(-) (limited to 'tools') diff --git a/benchmarks/BUILD b/benchmarks/BUILD index 1455c6c5b..43614cf5d 100644 --- a/benchmarks/BUILD +++ b/benchmarks/BUILD @@ -3,8 +3,16 @@ package(licenses = ["notice"]) py_binary( name = "benchmarks", srcs = ["run.py"], + data = [ + "//tools/images:ubuntu1604", + "//tools/images:zone", + ], main = "run.py", python_version = "PY3", srcs_version = "PY3", + tags = [ + "local", + "manual", + ], deps = ["//benchmarks/runner"], ) diff --git a/benchmarks/README.md b/benchmarks/README.md index ff21614c5..975321c99 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -26,6 +26,8 @@ For configuring the environment manually, consult the ## Running benchmarks +### Locally + Run the following from the benchmarks directory: ```bash @@ -44,7 +46,7 @@ runtime, runc. Running on another installed runtime, like say runsc, is as simple as: ```bash -bazel run :benchmakrs -- run-local startup --runtime=runsc +bazel run :benchmarks -- run-local startup --runtime=runsc ``` There is help: ``bash bash bazel run :benchmarks -- --help bazel @@ -104,6 +106,23 @@ Or with different parameters: bazel run :benchmarks -- run-local --max_prime=10 --max_prime=100 sysbench.cpu ``` +### On Google Compute Engine (GCE) + +Benchmarks may be run on GCE in an automated way. The default project configured +for `gcloud` will be used. + +An additional parameter `installers` may be provided to ensure that the latest +runtime is installed from the workspace. 
See the files in `tools/installers` for +supported install targets. + +```bash +bazel run :benchmarks -- run-gcp --installers=head --runtime=runsc sysbench.cpu +``` + +When running on GCE, the scripts generate a per run SSH key, which is added to +your project. The key is set to expire in GCE after 60 minutes and is stored in +a temporary directory on the local machine running the scripts. + ## Writing benchmarks To write new benchmarks, you should familiarize yourself with the structure of diff --git a/benchmarks/harness/BUILD b/benchmarks/harness/BUILD index 52d4e42f8..4d03e3a06 100644 --- a/benchmarks/harness/BUILD +++ b/benchmarks/harness/BUILD @@ -1,3 +1,4 @@ +load("//tools:defs.bzl", "pkg_tar") load("//tools:defs.bzl", "py_library", "py_requirement") package( @@ -5,9 +6,29 @@ package( licenses = ["notice"], ) +pkg_tar( + name = "installers", + srcs = [ + "//tools/installers:head", + "//tools/installers:master", + "//tools/installers:runsc", + ], + mode = "0755", +) + +filegroup( + name = "files", + srcs = [ + ":installers", + ], +) + py_library( name = "harness", srcs = ["__init__.py"], + data = [ + ":files", + ], ) py_library( diff --git a/benchmarks/harness/__init__.py b/benchmarks/harness/__init__.py index 61fd25f73..15aa2a69a 100644 --- a/benchmarks/harness/__init__.py +++ b/benchmarks/harness/__init__.py @@ -1,5 +1,5 @@ # python3 -# Copyright 2019 Google LLC +# Copyright 2019 The gVisor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,18 +15,48 @@ import getpass import os +import subprocess +import tempfile # LOCAL_WORKLOADS_PATH defines the path to use for local workloads. This is a # format string that accepts a single string parameter. -LOCAL_WORKLOADS_PATH = os.path.join( - os.path.dirname(__file__), "../workloads/{}/tar.tar") +LOCAL_WORKLOADS_PATH = os.path.dirname(__file__) + "/../workloads/{}/tar.tar" # REMOTE_WORKLOADS_PATH defines the path to use for storing the workloads on the # remote host. This is a format string that accepts a single string parameter. REMOTE_WORKLOADS_PATH = "workloads/{}" +# INSTALLER_ROOT is the set of files that needs to be copied. +INSTALLER_ARCHIVE = os.readlink(os.path.join( + os.path.dirname(__file__), "installers.tar")) + +# SSH_KEY_DIR holds SSH_PRIVATE_KEY for this run. bm-tools paramiko requires +# keys generated with the '-t rsa -m PEM' options from ssh-keygen. This is +# abstracted away from the user. +SSH_KEY_DIR = tempfile.TemporaryDirectory() +SSH_PRIVATE_KEY = "key" + # DEFAULT_USER is the default user running this script. DEFAULT_USER = getpass.getuser() # DEFAULT_USER_HOME is the home directory of the user running the script. DEFAULT_USER_HOME = os.environ["HOME"] if "HOME" in os.environ else "" + +# Default directory to remotely installer "installer" targets. 
+REMOTE_INSTALLERS_PATH = "installers" + + +def make_key(): + """Wraps a valid ssh key in a temporary directory.""" + path = os.path.join(SSH_KEY_DIR.name, SSH_PRIVATE_KEY) + if not os.path.exists(path): + cmd = "ssh-keygen -t rsa -m PEM -b 4096 -f {key} -q -N".format( + key=path).split(" ") + cmd.append("") + subprocess.run(cmd, check=True) + return path + + +def delete_key(): + """Deletes temporary directory containing private key.""" + SSH_KEY_DIR.cleanup() diff --git a/benchmarks/harness/machine.py b/benchmarks/harness/machine.py index 2df4c9e31..3d32d3dda 100644 --- a/benchmarks/harness/machine.py +++ b/benchmarks/harness/machine.py @@ -1,5 +1,5 @@ # python3 -# Copyright 2019 Google LLC +# Copyright 2019 The gVisor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -29,10 +29,11 @@ to run contianers. """ import logging +import os import re import subprocess import time -from typing import Tuple +from typing import List, Tuple import docker @@ -201,6 +202,7 @@ class RemoteMachine(Machine): self._tunnel = tunnel_dispatcher.Tunnel(name, **kwargs) self._tunnel.connect() self._docker_client = self._tunnel.get_docker_client() + self._has_installers = False def run(self, cmd: str) -> Tuple[str, str]: return self._ssh_connection.run(cmd) @@ -210,14 +212,45 @@ class RemoteMachine(Machine): stdout, stderr = self._ssh_connection.run("cat '{}'".format(path)) return stdout + stderr + def install(self, + installer: str, + results: List[bool] = None, + index: int = -1): + """Method unique to RemoteMachine to handle installation of installers. + + Handles installers, which install things that may change between runs (e.g. + runsc). Usually called from gcloud_producer, which expects this method to + to store results. + + Args: + installer: the installer target to run. + results: Passed by the caller of where to store success. + index: Index for this method to store the result in the passed results + list. + """ + # This generates a tarball of the full installer root (which will generate + # be the full bazel root directory) and sends it over. + if not self._has_installers: + archive = self._ssh_connection.send_installers() + self.run("tar -xvf {archive} -C {dir}".format( + archive=archive, dir=harness.REMOTE_INSTALLERS_PATH)) + self._has_installers = True + + # Execute the remote installer. + self.run("sudo {dir}/{file}".format( + dir=harness.REMOTE_INSTALLERS_PATH, file=installer)) + if results: + results[index] = True + def pull(self, workload: str) -> str: # Push to the remote machine and build. logging.info("Building %s@%s remotely...", workload, self._name) remote_path = self._ssh_connection.send_workload(workload) + remote_dir = os.path.dirname(remote_path) # Workloads are all tarballs. - self.run("tar -xvf {remote_path}/tar.tar -C {remote_path}".format( - remote_path=remote_path)) - self.run("docker build --tag={} {}".format(workload, remote_path)) + self.run("tar -xvf {remote_path} -C {remote_dir}".format( + remote_path=remote_path, remote_dir=remote_dir)) + self.run("docker build --tag={} {}".format(workload, remote_dir)) return workload # Workload is the tag. 
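A hedged sketch of the `install` contract above for a single machine (the installer file name and the `machine` handle are hypothetical; the real caller is `GCloudProducer`, which fans this out across threads as shown further below in `gcloud_producer.py`):

```python
# Assumes "head.sh" is the name of an installer unpacked from installers.tar
# on the remote host; the caller provides a shared results list so that
# parallel installs can be checked in lock-step.
results = [False]
machine.install("head.sh", results=results, index=0)
if not results[0]:
    raise RuntimeError("installer did not complete")
```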
def container(self, image: str, **kwargs) -> container.Container: diff --git a/benchmarks/harness/machine_producers/BUILD b/benchmarks/harness/machine_producers/BUILD index 48ea0ef39..3711a397f 100644 --- a/benchmarks/harness/machine_producers/BUILD +++ b/benchmarks/harness/machine_producers/BUILD @@ -76,5 +76,6 @@ py_test( python_version = "PY3", tags = [ "local", + "manual", ], ) diff --git a/benchmarks/harness/machine_producers/gcloud_producer.py b/benchmarks/harness/machine_producers/gcloud_producer.py index e0b77d52b..513d16e4f 100644 --- a/benchmarks/harness/machine_producers/gcloud_producer.py +++ b/benchmarks/harness/machine_producers/gcloud_producer.py @@ -1,5 +1,5 @@ # python3 -# Copyright 2019 Google LLC +# Copyright 2019 The gVisor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -46,12 +46,11 @@ class GCloudProducer(machine_producer.MachineProducer): Produces Machine objects backed by GCP instances. Attributes: - project: The GCP project name under which to create the machines. - ssh_key_file: path to a valid ssh private key. See README on vaild ssh keys. image: image name as a string. - image_project: image project as a string. - machine_type: type of GCP to create. e.g. n1-standard-4 zone: string to a valid GCP zone. + machine_type: type of GCP to create (e.g. n1-standard-4). + installers: list of installers post-boot. + ssh_key_file: path to a valid ssh private key. See README on vaild ssh keys. ssh_user: string of user name for ssh_key ssh_password: string of password for ssh key mock: a mock printer which will print mock data if required. Mock data is @@ -60,21 +59,19 @@ class GCloudProducer(machine_producer.MachineProducer): """ def __init__(self, - project: str, - ssh_key_file: str, image: str, - image_project: str, - machine_type: str, zone: str, + machine_type: str, + installers: List[str], + ssh_key_file: str, ssh_user: str, ssh_password: str, mock: gcloud_mock_recorder.MockPrinter = None): - self.project = project - self.ssh_key_file = ssh_key_file self.image = image - self.image_project = image_project - self.machine_type = machine_type self.zone = zone + self.machine_type = machine_type + self.installers = installers + self.ssh_key_file = ssh_key_file self.ssh_user = ssh_user self.ssh_password = ssh_password self.mock = mock @@ -87,10 +84,34 @@ class GCloudProducer(machine_producer.MachineProducer): "Cannot ask for {num} machines!".format(num=num_machines)) with self.condition: names = self._get_unique_names(num_machines) - self._build_instances(names) - instances = self._start_command(names) + instances = self._build_instances(names) self._add_ssh_key_to_instances(names) - return self._machines_from_instances(instances) + machines = self._machines_from_instances(instances) + + # Install all bits in lock-step. + # + # This will perform paralell installations for however many machines we + # have, but it's easy to track errors because if installing (a, b, c), we + # won't install "c" until "b" is installed on all machines. 
+ for installer in self.installers: + threads = [None] * len(machines) + results = [False] * len(machines) + for i in range(len(machines)): + threads[i] = threading.Thread( + target=machines[i].install, args=(installer, results, i)) + threads[i].start() + for thread in threads: + thread.join() + for result in results: + if not result: + raise NotImplementedError( + "Installers failed on at least one machine!") + + # Add this user to each machine's docker group. + for m in machines: + m.run("sudo setfacl -m user:$USER:rw /var/run/docker.sock") + + return machines def release_machines(self, machine_list: List[machine.Machine]): """Releases the requested number of machines, deleting the instances.""" @@ -123,15 +144,7 @@ class GCloudProducer(machine_producer.MachineProducer): def _get_unique_names(self, num_names) -> List[str]: """Returns num_names unique names based on data from the GCP project.""" - curr_machines = self._list_machines() - curr_names = set([machine["name"] for machine in curr_machines]) - ret = [] - while len(ret) < num_names: - new_name = "machine-" + str(uuid.uuid4()) - if new_name not in curr_names: - ret.append(new_name) - curr_names.update(new_name) - return ret + return ["machine-" + str(uuid.uuid4()) for _ in range(0, num_names)] def _build_instances(self, names: List[str]) -> List[Dict[str, Any]]: """Creates instances using gcloud command. @@ -151,34 +164,9 @@ class GCloudProducer(machine_producer.MachineProducer): "_build_instances cannot create instances without names.") cmd = "gcloud compute instances create".split(" ") cmd.extend(names) - cmd.extend( - "--preemptible --image={image} --zone={zone} --machine-type={machine_type}" - .format( - image=self.image, zone=self.zone, - machine_type=self.machine_type).split(" ")) - if self.image_project: - cmd.append("--image-project={project}".format(project=self.image_project)) - res = self._run_command(cmd) - return json.loads(res.stdout) - - def _start_command(self, names): - """Starts instances using gcloud command. - - Runs the command `gcloud compute instances start` on list of instances by - name and returns json data on started instances on success. - - Args: - names: list of names of instances to start. - - Returns: - List of json data describing started machines. - """ - if not names: - raise ValueError("_start_command cannot start empty instance list.") - cmd = "gcloud compute instances start".split(" ") - cmd.extend(names) - cmd.append("--zone={zone}".format(zone=self.zone)) - cmd.append("--project={project}".format(project=self.project)) + cmd.append("--image=" + self.image) + cmd.append("--zone=" + self.zone) + cmd.append("--machine-type=" + self.machine_type) res = self._run_command(cmd) return json.loads(res.stdout) @@ -186,7 +174,7 @@ class GCloudProducer(machine_producer.MachineProducer): """Adds ssh key to instances by calling gcloud ssh command. Runs the command `gcloud compute ssh instance_name` on list of images by - name. Tries to ssh into given instance + name. Tries to ssh into given instance. 
Args: names: list of machine names to which to add the ssh-key @@ -202,30 +190,18 @@ class GCloudProducer(machine_producer.MachineProducer): cmd.append("--ssh-key-file={key}".format(key=self.ssh_key_file)) cmd.append("--zone={zone}".format(zone=self.zone)) cmd.append("--command=uname") + cmd.append("--ssh-key-expire-after=60m") timeout = datetime.timedelta(seconds=5 * 60) start = datetime.datetime.now() while datetime.datetime.now() <= timeout + start: try: self._run_command(cmd) break - except subprocess.CalledProcessError as e: + except subprocess.CalledProcessError: if datetime.datetime.now() > timeout + start: raise TimeoutError( "Could not SSH into instance after 5 min: {name}".format( name=name)) - # 255 is the returncode for ssh connection refused. - elif e.returncode == 255: - - continue - else: - raise e - - def _list_machines(self) -> List[Dict[str, Any]]: - """Runs `list` gcloud command and returns list of Machine data.""" - cmd = "gcloud compute instances list --project {project}".format( - project=self.project).split(" ") - res = self._run_command(cmd) - return json.loads(res.stdout) def _run_command(self, cmd: List[str], @@ -261,7 +237,7 @@ class GCloudProducer(machine_producer.MachineProducer): self.mock.record(res) if res.returncode != 0: raise subprocess.CalledProcessError( - cmd=res.args, + cmd=" ".join(res.args), output=res.stdout, stderr=res.stderr, returncode=res.returncode) diff --git a/benchmarks/harness/ssh_connection.py b/benchmarks/harness/ssh_connection.py index e0bf258f1..a50e34293 100644 --- a/benchmarks/harness/ssh_connection.py +++ b/benchmarks/harness/ssh_connection.py @@ -1,5 +1,5 @@ # python3 -# Copyright 2019 Google LLC +# Copyright 2019 The gVisor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,6 +13,7 @@ # limitations under the License. """SSHConnection handles the details of SSH connections.""" + import os import warnings @@ -24,18 +25,24 @@ from benchmarks import harness warnings.filterwarnings(action="ignore", module=".*paramiko.*") -def send_one_file(client: paramiko.SSHClient, path: str, remote_dir: str): +def send_one_file(client: paramiko.SSHClient, path: str, + remote_dir: str) -> str: """Sends a single file via an SSH client. Args: client: The existing SSH client. path: The local path. remote_dir: The remote directory. + + Returns: + :return: The remote path as a string. """ filename = path.split("/").pop() - client.exec_command("mkdir -p " + remote_dir) + if remote_dir != ".": + client.exec_command("mkdir -p " + remote_dir) with client.open_sftp() as ftp_client: ftp_client.put(path, os.path.join(remote_dir, filename)) + return os.path.join(remote_dir, filename) class SSHConnection: @@ -103,6 +110,12 @@ class SSHConnection: The remote path. 
""" with self._client() as client: - send_one_file(client, harness.LOCAL_WORKLOADS_PATH.format(name), - harness.REMOTE_WORKLOADS_PATH.format(name)) - return harness.REMOTE_WORKLOADS_PATH.format(name) + return send_one_file(client, harness.LOCAL_WORKLOADS_PATH.format(name), + harness.REMOTE_WORKLOADS_PATH.format(name)) + + def send_installers(self) -> str: + with self._client() as client: + return send_one_file( + client, + path=harness.INSTALLER_ARCHIVE, + remote_dir=harness.REMOTE_INSTALLERS_PATH) diff --git a/benchmarks/runner/__init__.py b/benchmarks/runner/__init__.py index ba80d83d7..ba27dc69f 100644 --- a/benchmarks/runner/__init__.py +++ b/benchmarks/runner/__init__.py @@ -1,5 +1,5 @@ # python3 -# Copyright 2019 Google LLC +# Copyright 2019 The gVisor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,13 +15,10 @@ import copy import csv -import json import logging -import os import pkgutil import pydoc import re -import subprocess import sys import types from typing import List @@ -123,57 +120,29 @@ def run_mock(ctx, **kwargs): @runner.command("run-gcp", commands.GCPCommand) @click.pass_context -def run_gcp(ctx, project: str, ssh_key_file: str, image: str, - image_project: str, machine_type: str, zone: str, ssh_user: str, - ssh_password: str, **kwargs): +def run_gcp(ctx, image_file: str, zone_file: str, machine_type: str, + installers: List[str], **kwargs): """Runs all benchmarks on GCP instances.""" - if not ssh_user: - ssh_user = harness.DEFAULT_USER - - # Get the default project if one was not provided. - if not project: - sub = subprocess.run( - "gcloud config get-value project".split(" "), stdout=subprocess.PIPE) - if sub.returncode: - raise ValueError( - "Cannot get default project from gcloud. Is it configured>") - project = sub.stdout.decode("utf-8").strip("\n") - - if not image_project: - image_project = project - - # Check that the ssh-key exists and is readable. - if not os.access(ssh_key_file, os.R_OK): - raise ValueError( - "ssh key given `{ssh_key}` is does not exist or is not readable." - .format(ssh_key=ssh_key_file)) - - # Check that the image exists. - sub = subprocess.run( - "gcloud compute images describe {image} --project {image_project} --format=json" - .format(image=image, image_project=image_project).split(" "), - stdout=subprocess.PIPE) - if sub.returncode or "READY" not in json.loads(sub.stdout)["status"]: - raise ValueError( - "given image was not found or is not ready: {image} {image_project}." - .format(image=image, image_project=image_project)) - - # Check and set zone to default. - if not zone: - sub = subprocess.run( - "gcloud config get-value compute/zone".split(" "), - stdout=subprocess.PIPE) - if sub.returncode: - raise ValueError( - "Default zone is not set in gcloud. Set one or pass a zone with the --zone flag." - ) - zone = sub.stdout.decode("utf-8").strip("\n") - - producer = gcloud_producer.GCloudProducer(project, ssh_key_file, image, - image_project, machine_type, zone, - ssh_user, ssh_password) - run(ctx, producer, **kwargs) + # Resolve all files. 
+ image = open(image_file).read().rstrip() + zone = open(zone_file).read().rstrip() + + key_file = harness.make_key() + + producer = gcloud_producer.GCloudProducer( + image, + zone, + machine_type, + installers, + ssh_key_file=key_file, + ssh_user=harness.DEFAULT_USER, + ssh_password="") + + try: + run(ctx, producer, **kwargs) + finally: + harness.delete_key() def run(ctx, producer: machine_producer.MachineProducer, method: str, runs: int, diff --git a/benchmarks/runner/commands.py b/benchmarks/runner/commands.py index 7ab12fac6..0fccb2fad 100644 --- a/benchmarks/runner/commands.py +++ b/benchmarks/runner/commands.py @@ -1,5 +1,5 @@ # python3 -# Copyright 2019 Google LLC +# Copyright 2019 The gVisor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -22,9 +22,9 @@ def run_mock(**kwargs): # mock implementation """ -import click +import os -from benchmarks import harness +import click class RunCommand(click.core.Command): @@ -90,46 +90,40 @@ class GCPCommand(RunCommand): """GCPCommand inherits all flags from RunCommand and adds flags for run_gcp method. Attributes: - project: GCP project - ssh_key_path: path to the ssh-key to use for the run - image: name of the image to build machines from - image_project: GCP project under which to find image - zone: a GCP zone (e.g. us-west1-b) - ssh_user: username to use for the ssh-key - ssh_password: password to use for the ssh-key + image_file: name of the image to build machines from + zone_file: a GCP zone (e.g. us-west1-b) + installers: named installers for post-create + machine_type: type of machine to create (e.g. n1-standard-4) """ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - project = click.core.Option( - ("--project",), - help="Project to run on if not default value given by 'gcloud config get-value project'." + image_file = click.core.Option( + ("--image_file",), + help="The file containing the image for VMs.", + default=os.path.join( + os.path.dirname(__file__), "../../tools/images/ubuntu1604.txt"), + ) + zone_file = click.core.Option( + ("--zone_file",), + help="The file containing the GCP zone.", + default=os.path.join( + os.path.dirname(__file__), "../../tools/images/zone.txt"), + ) + installers = click.core.Option( + ("--installers",), + help="The set of installers to use.", + multiple=True, + ) + machine_type = click.core.Option( + ("--machine_type",), + help="Type to make all machines.", + default="n1-standard-4", ) - ssh_key_path = click.core.Option( - ("--ssh-key-file",), - help="Path to a valid ssh private key to use. See README on generating a valid ssh key. 
Set to ~/.ssh/benchmark-tools by default.", - default=harness.DEFAULT_USER_HOME + "/.ssh/benchmark-tools") - image = click.core.Option(("--image",), - help="The image on which to build VMs.", - default="bm-tools-testing") - image_project = click.core.Option( - ("--image_project",), - help="The project under which the image to be used is listed.", - default="") - machine_type = click.core.Option(("--machine_type",), - help="Type to make all machines.", - default="n1-standard-4") - zone = click.core.Option(("--zone",), - help="The GCP zone to run on.", - default="") - ssh_user = click.core.Option(("--ssh-user",), - help="User for the ssh key.", - default=harness.DEFAULT_USER) - ssh_password = click.core.Option(("--ssh-password",), - help="Password for the ssh key.", - default="") self.params.extend([ - project, ssh_key_path, image, image_project, machine_type, zone, - ssh_user, ssh_password + image_file, + zone_file, + machine_type, + installers, ]) diff --git a/tools/images/defs.bzl b/tools/images/defs.bzl index 32235813a..de365d153 100644 --- a/tools/images/defs.bzl +++ b/tools/images/defs.bzl @@ -57,7 +57,10 @@ def _vm_image_impl(ctx): command = argv, input_manifests = runfiles_manifests, ) - return [DefaultInfo(files = depset([ctx.outputs.out]))] + return [DefaultInfo( + files = depset([ctx.outputs.out]), + runfiles = ctx.runfiles(files = [ctx.outputs.out]), + )] _vm_image = rule( attrs = { diff --git a/tools/installers/BUILD b/tools/installers/BUILD index 01bc4de8c..d78a265ca 100644 --- a/tools/installers/BUILD +++ b/tools/installers/BUILD @@ -5,10 +5,15 @@ package( licenses = ["notice"], ) +filegroup( + name = "runsc", + srcs = ["//runsc"], +) + sh_binary( name = "head", srcs = ["head.sh"], - data = ["//runsc"], + data = [":runsc"], ) sh_binary( diff --git a/tools/installers/head.sh b/tools/installers/head.sh index 4435cb27a..9de8f138c 100755 --- a/tools/installers/head.sh +++ b/tools/installers/head.sh @@ -15,7 +15,7 @@ # limitations under the License. # Install our runtime. -third_party/gvisor/runsc/runsc install +$(dirname $0)/runsc install # Restart docker. service docker restart || true -- cgit v1.2.3 From 0efa8168c7c04ec0a4bd62e2d2eb8718b5d72ea7 Mon Sep 17 00:00:00 2001 From: Adin Scannell Date: Mon, 10 Feb 2020 11:28:57 -0800 Subject: Update visibility. 
PiperOrigin-RevId: 294265019 --- pkg/seccomp/BUILD | 2 +- test/root/testdata/BUILD | 2 +- tools/build/BUILD | 2 +- tools/checkunsafe/BUILD | 2 +- tools/go_generics/BUILD | 2 +- tools/go_generics/go_merge/BUILD | 2 +- tools/go_stateify/BUILD | 2 +- 7 files changed, 7 insertions(+), 7 deletions(-) (limited to 'tools') diff --git a/pkg/seccomp/BUILD b/pkg/seccomp/BUILD index 742c8b79b..c5fca2ba3 100644 --- a/pkg/seccomp/BUILD +++ b/pkg/seccomp/BUILD @@ -26,7 +26,7 @@ go_library( "seccomp_rules.go", "seccomp_unsafe.go", ], - visibility = ["//visibility:public"], + visibility = ["//:sandbox"], deps = [ "//pkg/abi/linux", "//pkg/bpf", diff --git a/test/root/testdata/BUILD b/test/root/testdata/BUILD index bca5f9cab..6859541ad 100644 --- a/test/root/testdata/BUILD +++ b/test/root/testdata/BUILD @@ -13,6 +13,6 @@ go_library( "simple.go", ], visibility = [ - "//visibility:public", + "//:sandbox", ], ) diff --git a/tools/build/BUILD b/tools/build/BUILD index 0c0ce3f4d..00a467473 100644 --- a/tools/build/BUILD +++ b/tools/build/BUILD @@ -6,5 +6,5 @@ genrule( name = "loopback", outs = ["loopback.txt"], cmd = "touch $@", - visibility = ["//visibility:public"], + visibility = ["//:sandbox"], ) diff --git a/tools/checkunsafe/BUILD b/tools/checkunsafe/BUILD index 92ba8ab06..4f1a31a6d 100644 --- a/tools/checkunsafe/BUILD +++ b/tools/checkunsafe/BUILD @@ -5,7 +5,7 @@ package(licenses = ["notice"]) go_tool_library( name = "checkunsafe", srcs = ["check_unsafe.go"], - visibility = ["//visibility:public"], + visibility = ["//:sandbox"], deps = [ "@org_golang_x_tools//go/analysis:go_tool_library", ], diff --git a/tools/go_generics/BUILD b/tools/go_generics/BUILD index 069df3856..32a949c93 100644 --- a/tools/go_generics/BUILD +++ b/tools/go_generics/BUILD @@ -9,7 +9,7 @@ go_binary( "imports.go", "remove.go", ], - visibility = ["//visibility:public"], + visibility = ["//:sandbox"], deps = ["//tools/go_generics/globals"], ) diff --git a/tools/go_generics/go_merge/BUILD b/tools/go_generics/go_merge/BUILD index b7d35e272..2fd5a200d 100644 --- a/tools/go_generics/go_merge/BUILD +++ b/tools/go_generics/go_merge/BUILD @@ -5,5 +5,5 @@ package(licenses = ["notice"]) go_binary( name = "go_merge", srcs = ["main.go"], - visibility = ["//visibility:public"], + visibility = ["//:sandbox"], ) diff --git a/tools/go_stateify/BUILD b/tools/go_stateify/BUILD index 6036faf7b..503cdf2e5 100644 --- a/tools/go_stateify/BUILD +++ b/tools/go_stateify/BUILD @@ -5,6 +5,6 @@ package(licenses = ["notice"]) go_binary( name = "stateify", srcs = ["main.go"], - visibility = ["//visibility:public"], + visibility = ["//:sandbox"], deps = ["//tools/tags"], ) -- cgit v1.2.3 From 20840bfec087d45853e81d1ac34940f3b2fb920a Mon Sep 17 00:00:00 2001 From: Brad Burlage Date: Mon, 10 Feb 2020 11:57:31 -0800 Subject: Move x86 state definition to its own file. 
PiperOrigin-RevId: 294271541 --- pkg/sentry/arch/BUILD | 1 + pkg/sentry/arch/arch_state_x86.go | 4 ++-- pkg/sentry/arch/arch_x86.go | 15 -------------- pkg/sentry/arch/arch_x86_impl.go | 43 +++++++++++++++++++++++++++++++++++++++ tools/build/tags.bzl | 24 +++++++++++----------- 5 files changed, 58 insertions(+), 29 deletions(-) create mode 100644 pkg/sentry/arch/arch_x86_impl.go (limited to 'tools') diff --git a/pkg/sentry/arch/BUILD b/pkg/sentry/arch/BUILD index 34c0a867d..e27f21e5e 100644 --- a/pkg/sentry/arch/BUILD +++ b/pkg/sentry/arch/BUILD @@ -14,6 +14,7 @@ go_library( "arch_state_aarch64.go", "arch_state_x86.go", "arch_x86.go", + "arch_x86_impl.go", "auxv.go", "signal.go", "signal_act.go", diff --git a/pkg/sentry/arch/arch_state_x86.go b/pkg/sentry/arch/arch_state_x86.go index d388ee9cf..e35c9214a 100644 --- a/pkg/sentry/arch/arch_state_x86.go +++ b/pkg/sentry/arch/arch_state_x86.go @@ -43,8 +43,8 @@ func (e ErrFloatingPoint) Error() string { // and SSE state, so this is the equivalent XSTATE_BV value. const fxsaveBV uint64 = cpuid.XSAVEFeatureX87 | cpuid.XSAVEFeatureSSE -// afterLoad is invoked by stateify. -func (s *State) afterLoad() { +// afterLoadFPState is invoked by afterLoad. +func (s *State) afterLoadFPState() { old := s.x86FPState // Recreate the slice. This is done to ensure that it is aligned diff --git a/pkg/sentry/arch/arch_x86.go b/pkg/sentry/arch/arch_x86.go index 3db8bd34b..88b40a9d1 100644 --- a/pkg/sentry/arch/arch_x86.go +++ b/pkg/sentry/arch/arch_x86.go @@ -155,21 +155,6 @@ func NewFloatingPointData() *FloatingPointData { return (*FloatingPointData)(&(newX86FPState()[0])) } -// State contains the common architecture bits for X86 (the build tag of this -// file ensures it's only built on x86). -// -// +stateify savable -type State struct { - // The system registers. - Regs syscall.PtraceRegs `state:".(syscallPtraceRegs)"` - - // Our floating point state. - x86FPState `state:"wait"` - - // FeatureSet is a pointer to the currently active feature set. - FeatureSet *cpuid.FeatureSet -} - // Proto returns a protobuf representation of the system registers in State. func (s State) Proto() *rpb.Registers { regs := &rpb.AMD64Registers{ diff --git a/pkg/sentry/arch/arch_x86_impl.go b/pkg/sentry/arch/arch_x86_impl.go new file mode 100644 index 000000000..04ac283c6 --- /dev/null +++ b/pkg/sentry/arch/arch_x86_impl.go @@ -0,0 +1,43 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build amd64 i386 + +package arch + +import ( + "syscall" + + "gvisor.dev/gvisor/pkg/cpuid" +) + +// State contains the common architecture bits for X86 (the build tag of this +// file ensures it's only built on x86). +// +// +stateify savable +type State struct { + // The system registers. + Regs syscall.PtraceRegs `state:".(syscallPtraceRegs)"` + + // Our floating point state. + x86FPState `state:"wait"` + + // FeatureSet is a pointer to the currently active feature set. + FeatureSet *cpuid.FeatureSet +} + +// afterLoad is invoked by stateify. 
+func (s *State) afterLoad() { + s.afterLoadFPState() +} diff --git a/tools/build/tags.bzl b/tools/build/tags.bzl index a6db44e47..558fb53ae 100644 --- a/tools/build/tags.bzl +++ b/tools/build/tags.bzl @@ -3,22 +3,28 @@ go_suffixes = [ "_386", "_386_unsafe", - "_amd64", - "_amd64_unsafe", "_aarch64", "_aarch64_unsafe", + "_amd64", + "_amd64_unsafe", "_arm", - "_arm_unsafe", "_arm64", "_arm64_unsafe", + "_arm_unsafe", + "_impl", + "_impl_unsafe", + "_linux", + "_linux_unsafe", "_mips", - "_mips_unsafe", - "_mipsle", - "_mipsle_unsafe", "_mips64", "_mips64_unsafe", "_mips64le", "_mips64le_unsafe", + "_mips_unsafe", + "_mipsle", + "_mipsle_unsafe", + "_opts", + "_opts_unsafe", "_ppc64", "_ppc64_unsafe", "_ppc64le", @@ -31,10 +37,4 @@ go_suffixes = [ "_sparc64_unsafe", "_wasm", "_wasm_unsafe", - "_linux", - "_linux_unsafe", - "_opts", - "_opts_unsafe", - "_impl", - "_impl_unsafe", ] -- cgit v1.2.3 From dc5a8e52d7004e3796feaadb0a0b0960f7289884 Mon Sep 17 00:00:00 2001 From: Adin Scannell Date: Mon, 10 Feb 2020 15:43:36 -0800 Subject: Rename build to builddefs and minor build clean-up. The name 'bazel' also doesn't work because bazel will treat it specially. Fixes #1807 PiperOrigin-RevId: 294321221 --- tools/bazeldefs/BUILD | 10 ++++++ tools/bazeldefs/defs.bzl | 94 ++++++++++++++++++++++++++++++++++++++++++++++++ tools/bazeldefs/tags.bzl | 40 +++++++++++++++++++++ tools/build/BUILD | 10 ------ tools/build/defs.bzl | 94 ------------------------------------------------ tools/build/tags.bzl | 40 --------------------- tools/defs.bzl | 2 +- 7 files changed, 145 insertions(+), 145 deletions(-) create mode 100644 tools/bazeldefs/BUILD create mode 100644 tools/bazeldefs/defs.bzl create mode 100644 tools/bazeldefs/tags.bzl delete mode 100644 tools/build/BUILD delete mode 100644 tools/build/defs.bzl delete mode 100644 tools/build/tags.bzl (limited to 'tools') diff --git a/tools/bazeldefs/BUILD b/tools/bazeldefs/BUILD new file mode 100644 index 000000000..00a467473 --- /dev/null +++ b/tools/bazeldefs/BUILD @@ -0,0 +1,10 @@ +package(licenses = ["notice"]) + +# In bazel, no special support is required for loopback networking. This is +# just a dummy data target that does not change the test environment. 
+genrule( + name = "loopback", + outs = ["loopback.txt"], + cmd = "touch $@", + visibility = ["//:sandbox"], +) diff --git a/tools/bazeldefs/defs.bzl b/tools/bazeldefs/defs.bzl new file mode 100644 index 000000000..08c29ff1c --- /dev/null +++ b/tools/bazeldefs/defs.bzl @@ -0,0 +1,94 @@ +"""Bazel implementations of standard rules.""" + +load("@bazel_tools//tools/cpp:cc_flags_supplier.bzl", _cc_flags_supplier = "cc_flags_supplier") +load("@io_bazel_rules_go//go:def.bzl", _go_binary = "go_binary", _go_embed_data = "go_embed_data", _go_library = "go_library", _go_test = "go_test", _go_tool_library = "go_tool_library") +load("@io_bazel_rules_go//proto:def.bzl", _go_proto_library = "go_proto_library") +load("@rules_cc//cc:defs.bzl", _cc_binary = "cc_binary", _cc_library = "cc_library", _cc_proto_library = "cc_proto_library", _cc_test = "cc_test") +load("@rules_pkg//:pkg.bzl", _pkg_deb = "pkg_deb", _pkg_tar = "pkg_tar") +load("@io_bazel_rules_docker//go:image.bzl", _go_image = "go_image") +load("@io_bazel_rules_docker//container:container.bzl", _container_image = "container_image") +load("@pydeps//:requirements.bzl", _py_requirement = "requirement") +load("//tools/bazeldefs:tags.bzl", _go_suffixes = "go_suffixes") + +container_image = _container_image +cc_binary = _cc_binary +cc_library = _cc_library +cc_flags_supplier = _cc_flags_supplier +cc_proto_library = _cc_proto_library +cc_test = _cc_test +cc_toolchain = "@bazel_tools//tools/cpp:current_cc_toolchain" +go_image = _go_image +go_embed_data = _go_embed_data +go_suffixes = _go_suffixes +gtest = "@com_google_googletest//:gtest" +loopback = "//tools/bazeldefs:loopback" +proto_library = native.proto_library +pkg_deb = _pkg_deb +pkg_tar = _pkg_tar +py_library = native.py_library +py_binary = native.py_binary +py_test = native.py_test + +def go_binary(name, static = False, pure = False, **kwargs): + if static: + kwargs["static"] = "on" + if pure: + kwargs["pure"] = "on" + _go_binary( + name = name, + **kwargs + ) + +def go_library(name, **kwargs): + _go_library( + name = name, + importpath = "gvisor.dev/gvisor/" + native.package_name(), + **kwargs + ) + +def go_tool_library(name, **kwargs): + _go_tool_library( + name = name, + importpath = "gvisor.dev/gvisor/" + native.package_name(), + **kwargs + ) + +def go_proto_library(name, proto, **kwargs): + deps = kwargs.pop("deps", []) + _go_proto_library( + name = name, + importpath = "gvisor.dev/gvisor/" + native.package_name() + "/" + name, + proto = proto, + deps = [dep.replace("_proto", "_go_proto") for dep in deps], + **kwargs + ) + +def go_test(name, **kwargs): + library = kwargs.pop("library", None) + if library: + kwargs["embed"] = [library] + _go_test( + name = name, + **kwargs + ) + +def py_requirement(name, direct = False): + return _py_requirement(name) + +def select_arch(amd64 = "amd64", arm64 = "arm64", default = None, **kwargs): + values = { + "@bazel_tools//src/conditions:linux_x86_64": amd64, + "@bazel_tools//src/conditions:linux_aarch64": arm64, + } + if default: + values["//conditions:default"] = default + return select(values, **kwargs) + +def select_system(linux = ["__linux__"], **kwargs): + return linux # Only Linux supported. + +def default_installer(): + return None + +def default_net_util(): + return [] # Nothing needed. 
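A usage sketch for the wrappers above (hypothetical target names, not part of this rename): BUILD files stay free of raw `select()` and rules_go boilerplate, for example when picking per-architecture sources with `select_arch`:

```python
# Hypothetical BUILD snippet; go_library() fills in importpath automatically,
# and select_arch() expands to a select() over the linux_x86_64 and
# linux_aarch64 conditions defined above.
go_library(
    name = "example",
    srcs = ["example.go"] + select_arch(
        amd64 = ["example_amd64.go"],
        arm64 = ["example_arm64.go"],
    ),
)
```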
diff --git a/tools/bazeldefs/tags.bzl b/tools/bazeldefs/tags.bzl new file mode 100644 index 000000000..558fb53ae --- /dev/null +++ b/tools/bazeldefs/tags.bzl @@ -0,0 +1,40 @@ +"""List of special Go suffixes.""" + +go_suffixes = [ + "_386", + "_386_unsafe", + "_aarch64", + "_aarch64_unsafe", + "_amd64", + "_amd64_unsafe", + "_arm", + "_arm64", + "_arm64_unsafe", + "_arm_unsafe", + "_impl", + "_impl_unsafe", + "_linux", + "_linux_unsafe", + "_mips", + "_mips64", + "_mips64_unsafe", + "_mips64le", + "_mips64le_unsafe", + "_mips_unsafe", + "_mipsle", + "_mipsle_unsafe", + "_opts", + "_opts_unsafe", + "_ppc64", + "_ppc64_unsafe", + "_ppc64le", + "_ppc64le_unsafe", + "_riscv64", + "_riscv64_unsafe", + "_s390x", + "_s390x_unsafe", + "_sparc64", + "_sparc64_unsafe", + "_wasm", + "_wasm_unsafe", +] diff --git a/tools/build/BUILD b/tools/build/BUILD deleted file mode 100644 index 00a467473..000000000 --- a/tools/build/BUILD +++ /dev/null @@ -1,10 +0,0 @@ -package(licenses = ["notice"]) - -# In bazel, no special support is required for loopback networking. This is -# just a dummy data target that does not change the test environment. -genrule( - name = "loopback", - outs = ["loopback.txt"], - cmd = "touch $@", - visibility = ["//:sandbox"], -) diff --git a/tools/build/defs.bzl b/tools/build/defs.bzl deleted file mode 100644 index 1a1a0d825..000000000 --- a/tools/build/defs.bzl +++ /dev/null @@ -1,94 +0,0 @@ -"""Bazel implementations of standard rules.""" - -load("@bazel_tools//tools/cpp:cc_flags_supplier.bzl", _cc_flags_supplier = "cc_flags_supplier") -load("@io_bazel_rules_go//go:def.bzl", _go_binary = "go_binary", _go_embed_data = "go_embed_data", _go_library = "go_library", _go_test = "go_test", _go_tool_library = "go_tool_library") -load("@io_bazel_rules_go//proto:def.bzl", _go_proto_library = "go_proto_library") -load("@rules_cc//cc:defs.bzl", _cc_binary = "cc_binary", _cc_library = "cc_library", _cc_proto_library = "cc_proto_library", _cc_test = "cc_test") -load("@rules_pkg//:pkg.bzl", _pkg_deb = "pkg_deb", _pkg_tar = "pkg_tar") -load("@io_bazel_rules_docker//go:image.bzl", _go_image = "go_image") -load("@io_bazel_rules_docker//container:container.bzl", _container_image = "container_image") -load("@pydeps//:requirements.bzl", _py_requirement = "requirement") -load("//tools/build:tags.bzl", _go_suffixes = "go_suffixes") - -container_image = _container_image -cc_binary = _cc_binary -cc_library = _cc_library -cc_flags_supplier = _cc_flags_supplier -cc_proto_library = _cc_proto_library -cc_test = _cc_test -cc_toolchain = "@bazel_tools//tools/cpp:current_cc_toolchain" -go_image = _go_image -go_embed_data = _go_embed_data -go_suffixes = _go_suffixes -gtest = "@com_google_googletest//:gtest" -loopback = "//tools/build:loopback" -proto_library = native.proto_library -pkg_deb = _pkg_deb -pkg_tar = _pkg_tar -py_library = native.py_library -py_binary = native.py_binary -py_test = native.py_test - -def go_binary(name, static = False, pure = False, **kwargs): - if static: - kwargs["static"] = "on" - if pure: - kwargs["pure"] = "on" - _go_binary( - name = name, - **kwargs - ) - -def go_library(name, **kwargs): - _go_library( - name = name, - importpath = "gvisor.dev/gvisor/" + native.package_name(), - **kwargs - ) - -def go_tool_library(name, **kwargs): - _go_tool_library( - name = name, - importpath = "gvisor.dev/gvisor/" + native.package_name(), - **kwargs - ) - -def go_proto_library(name, proto, **kwargs): - deps = kwargs.pop("deps", []) - _go_proto_library( - name = name, - importpath = 
"gvisor.dev/gvisor/" + native.package_name() + "/" + name, - proto = proto, - deps = [dep.replace("_proto", "_go_proto") for dep in deps], - **kwargs - ) - -def go_test(name, **kwargs): - library = kwargs.pop("library", None) - if library: - kwargs["embed"] = [library] - _go_test( - name = name, - **kwargs - ) - -def py_requirement(name, direct = False): - return _py_requirement(name) - -def select_arch(amd64 = "amd64", arm64 = "arm64", default = None, **kwargs): - values = { - "@bazel_tools//src/conditions:linux_x86_64": amd64, - "@bazel_tools//src/conditions:linux_aarch64": arm64, - } - if default: - values["//conditions:default"] = default - return select(values, **kwargs) - -def select_system(linux = ["__linux__"], **kwargs): - return linux # Only Linux supported. - -def default_installer(): - return None - -def default_net_util(): - return [] # Nothing needed. diff --git a/tools/build/tags.bzl b/tools/build/tags.bzl deleted file mode 100644 index 558fb53ae..000000000 --- a/tools/build/tags.bzl +++ /dev/null @@ -1,40 +0,0 @@ -"""List of special Go suffixes.""" - -go_suffixes = [ - "_386", - "_386_unsafe", - "_aarch64", - "_aarch64_unsafe", - "_amd64", - "_amd64_unsafe", - "_arm", - "_arm64", - "_arm64_unsafe", - "_arm_unsafe", - "_impl", - "_impl_unsafe", - "_linux", - "_linux_unsafe", - "_mips", - "_mips64", - "_mips64_unsafe", - "_mips64le", - "_mips64le_unsafe", - "_mips_unsafe", - "_mipsle", - "_mipsle_unsafe", - "_opts", - "_opts_unsafe", - "_ppc64", - "_ppc64_unsafe", - "_ppc64le", - "_ppc64le_unsafe", - "_riscv64", - "_riscv64_unsafe", - "_s390x", - "_s390x_unsafe", - "_sparc64", - "_sparc64_unsafe", - "_wasm", - "_wasm_unsafe", -] diff --git a/tools/defs.bzl b/tools/defs.bzl index c03b557ae..d4690cc1a 100644 --- a/tools/defs.bzl +++ b/tools/defs.bzl @@ -7,7 +7,7 @@ change for Google-internal and bazel-compatible rules. 
load("//tools/go_stateify:defs.bzl", "go_stateify") load("//tools/go_marshal:defs.bzl", "go_marshal", "marshal_deps", "marshal_test_deps") -load("//tools/build:defs.bzl", "go_suffixes", _cc_binary = "cc_binary", _cc_flags_supplier = "cc_flags_supplier", _cc_library = "cc_library", _cc_proto_library = "cc_proto_library", _cc_test = "cc_test", _cc_toolchain = "cc_toolchain", _container_image = "container_image", _default_installer = "default_installer", _default_net_util = "default_net_util", _go_binary = "go_binary", _go_embed_data = "go_embed_data", _go_image = "go_image", _go_library = "go_library", _go_proto_library = "go_proto_library", _go_test = "go_test", _go_tool_library = "go_tool_library", _gtest = "gtest", _loopback = "loopback", _pkg_deb = "pkg_deb", _pkg_tar = "pkg_tar", _proto_library = "proto_library", _py_binary = "py_binary", _py_library = "py_library", _py_requirement = "py_requirement", _py_test = "py_test", _select_arch = "select_arch", _select_system = "select_system") +load("//tools/bazeldefs:defs.bzl", "go_suffixes", _cc_binary = "cc_binary", _cc_flags_supplier = "cc_flags_supplier", _cc_library = "cc_library", _cc_proto_library = "cc_proto_library", _cc_test = "cc_test", _cc_toolchain = "cc_toolchain", _container_image = "container_image", _default_installer = "default_installer", _default_net_util = "default_net_util", _go_binary = "go_binary", _go_embed_data = "go_embed_data", _go_image = "go_image", _go_library = "go_library", _go_proto_library = "go_proto_library", _go_test = "go_test", _go_tool_library = "go_tool_library", _gtest = "gtest", _loopback = "loopback", _pkg_deb = "pkg_deb", _pkg_tar = "pkg_tar", _proto_library = "proto_library", _py_binary = "py_binary", _py_library = "py_library", _py_requirement = "py_requirement", _py_test = "py_test", _select_arch = "select_arch", _select_system = "select_system") # Delegate directly. cc_binary = _cc_binary -- cgit v1.2.3 From 9be46e55c2aadcf40c9abd4b515c3fe899d9fa08 Mon Sep 17 00:00:00 2001 From: gVisor bot Date: Tue, 11 Feb 2020 11:40:51 -0800 Subject: Stateify: register types with full package names This is to avoid conflicts with types that share the same [short] package and type names, e.g. proc.smapsData exist in pkg/sentry/fs/proc and pkg/sentry/fsimpl/proc. Updates #1663 PiperOrigin-RevId: 294485146 --- tools/defs.bzl | 4 +++- tools/go_stateify/defs.bzl | 4 ++-- tools/go_stateify/main.go | 10 ++++++---- 3 files changed, 11 insertions(+), 7 deletions(-) (limited to 'tools') diff --git a/tools/defs.bzl b/tools/defs.bzl index d4690cc1a..46249f9c4 100644 --- a/tools/defs.bzl +++ b/tools/defs.bzl @@ -110,6 +110,8 @@ def go_library(name, srcs, deps = [], imports = [], stateify = True, marshal = F """ all_srcs = srcs all_deps = deps + dirname, _, _ = native.package_name().rpartition("/") + full_pkg = dirname + "/" + name if stateify: # Only do stateification for non-state packages without manual autogen. # First, we need to segregate the input files via the special suffixes, @@ -120,7 +122,7 @@ def go_library(name, srcs, deps = [], imports = [], stateify = True, marshal = F name = name + suffix + "_state_autogen_with_imports", srcs = srcs, imports = imports, - package = name, + package = full_pkg, out = name + suffix + "_state_autogen_with_imports.go", ) go_imports( diff --git a/tools/go_stateify/defs.bzl b/tools/go_stateify/defs.bzl index bdb966362..6a5e666f0 100644 --- a/tools/go_stateify/defs.bzl +++ b/tools/go_stateify/defs.bzl @@ -6,7 +6,7 @@ def _go_stateify_impl(ctx): # Run the stateify command. 
args = ["-output=%s" % output.path] - args.append("-pkg=%s" % ctx.attr.package) + args.append("-fullpkg=%s" % ctx.attr.package) if ctx.attr._statepkg: args.append("-statepkg=%s" % ctx.attr._statepkg) if ctx.attr.imports: @@ -43,7 +43,7 @@ for statified types. mandatory = False, ), "package": attr.string( - doc = "The package name for the input sources.", + doc = "The fully qualified package name for the input sources.", mandatory = True, ), "out": attr.output( diff --git a/tools/go_stateify/main.go b/tools/go_stateify/main.go index aa9d4543e..3437aa476 100644 --- a/tools/go_stateify/main.go +++ b/tools/go_stateify/main.go @@ -23,6 +23,7 @@ import ( "go/parser" "go/token" "os" + "path/filepath" "reflect" "strings" "sync" @@ -31,7 +32,7 @@ import ( ) var ( - pkg = flag.String("pkg", "", "output package") + fullPkg = flag.String("fullpkg", "", "fully qualified output package") imports = flag.String("imports", "", "extra imports for the output file") output = flag.String("output", "", "output file") statePkg = flag.String("statepkg", "", "state import package; defaults to empty") @@ -170,7 +171,7 @@ func main() { flag.Usage() os.Exit(1) } - if *pkg == "" { + if *fullPkg == "" { fmt.Fprintf(os.Stderr, "Error: package required.") os.Exit(1) } @@ -202,7 +203,7 @@ func main() { // Declare our emission closures. emitRegister := func(name string) { - initCalls = append(initCalls, fmt.Sprintf("%sRegister(\"%s.%s\", (*%s)(nil), state.Fns{Save: (*%s).save, Load: (*%s).load})", statePrefix, *pkg, name, name, name, name)) + initCalls = append(initCalls, fmt.Sprintf("%sRegister(\"%s.%s\", (*%s)(nil), state.Fns{Save: (*%s).save, Load: (*%s).load})", statePrefix, *fullPkg, name, name, name, name)) } emitZeroCheck := func(name string) { fmt.Fprintf(outputFile, " if !%sIsZeroValue(x.%s) { m.Failf(\"%s is %%v, expected zero\", x.%s) }\n", statePrefix, name, name, name) @@ -233,7 +234,8 @@ func main() { } // Emit the package name. - fmt.Fprintf(outputFile, "package %s\n\n", *pkg) + _, pkg := filepath.Split(*fullPkg) + fmt.Fprintf(outputFile, "package %s\n\n", pkg) // Emit the imports lazily. var once sync.Once -- cgit v1.2.3 From 3ad6d3056371b031fb0c16c4e365d5c7e60bdaf0 Mon Sep 17 00:00:00 2001 From: gVisor bot Date: Thu, 13 Feb 2020 09:20:30 -0800 Subject: Call py_requirement with named argument for optional kwarg. 
PiperOrigin-RevId: 294930818 --- benchmarks/defs.bzl | 14 +++ benchmarks/harness/BUILD | 165 ++++++++++++++++++++++------- benchmarks/harness/machine_producers/BUILD | 5 +- benchmarks/runner/BUILD | 17 +-- benchmarks/workloads/ab/BUILD | 13 +-- benchmarks/workloads/absl/BUILD | 13 +-- benchmarks/workloads/fio/BUILD | 13 +-- benchmarks/workloads/iperf/BUILD | 13 +-- benchmarks/workloads/redisbenchmark/BUILD | 13 +-- benchmarks/workloads/sysbench/BUILD | 13 +-- benchmarks/workloads/syscall/BUILD | 13 +-- tools/bazeldefs/defs.bzl | 2 +- 12 files changed, 172 insertions(+), 122 deletions(-) create mode 100644 benchmarks/defs.bzl (limited to 'tools') diff --git a/benchmarks/defs.bzl b/benchmarks/defs.bzl new file mode 100644 index 000000000..56d28223e --- /dev/null +++ b/benchmarks/defs.bzl @@ -0,0 +1,14 @@ +"""Provides attributes common to many workload tests.""" + +load("//tools:defs.bzl", "py_requirement") + +test_deps = [ + py_requirement("attrs", direct = False), + py_requirement("atomicwrites", direct = False), + py_requirement("more-itertools", direct = False), + py_requirement("pathlib2", direct = False), + py_requirement("pluggy", direct = False), + py_requirement("py", direct = False), + py_requirement("pytest"), + py_requirement("six", direct = False), +] diff --git a/benchmarks/harness/BUILD b/benchmarks/harness/BUILD index 4d03e3a06..48c548d59 100644 --- a/benchmarks/harness/BUILD +++ b/benchmarks/harness/BUILD @@ -1,5 +1,4 @@ -load("//tools:defs.bzl", "pkg_tar") -load("//tools:defs.bzl", "py_library", "py_requirement") +load("//tools:defs.bzl", "pkg_tar", "py_library", "py_requirement") package( default_visibility = ["//benchmarks:__subpackages__"], @@ -46,16 +45,43 @@ py_library( srcs = ["container.py"], deps = [ "//benchmarks/workloads", - py_requirement("asn1crypto", False), - py_requirement("chardet", False), - py_requirement("certifi", False), - py_requirement("docker", True), - py_requirement("docker-pycreds", False), - py_requirement("idna", False), - py_requirement("ptyprocess", False), - py_requirement("requests", False), - py_requirement("urllib3", False), - py_requirement("websocket-client", False), + py_requirement( + "asn1crypto", + direct = False, + ), + py_requirement( + "chardet", + direct = False, + ), + py_requirement( + "certifi", + direct = False, + ), + py_requirement("docker"), + py_requirement( + "docker-pycreds", + direct = False, + ), + py_requirement( + "idna", + direct = False, + ), + py_requirement( + "ptyprocess", + direct = False, + ), + py_requirement( + "requests", + direct = False, + ), + py_requirement( + "urllib3", + direct = False, + ), + py_requirement( + "websocket-client", + direct = False, + ), ], ) @@ -68,17 +94,47 @@ py_library( "//benchmarks/harness:ssh_connection", "//benchmarks/harness:tunnel_dispatcher", "//benchmarks/harness/machine_mocks", - py_requirement("asn1crypto", False), - py_requirement("chardet", False), - py_requirement("certifi", False), - py_requirement("docker", True), - py_requirement("docker-pycreds", False), - py_requirement("idna", False), - py_requirement("ptyprocess", False), - py_requirement("requests", False), - py_requirement("six", False), - py_requirement("urllib3", False), - py_requirement("websocket-client", False), + py_requirement( + "asn1crypto", + direct = False, + ), + py_requirement( + "chardet", + direct = False, + ), + py_requirement( + "certifi", + direct = False, + ), + py_requirement("docker"), + py_requirement( + "docker-pycreds", + direct = False, + ), + py_requirement( + "idna", + direct = 
False, + ), + py_requirement( + "ptyprocess", + direct = False, + ), + py_requirement( + "requests", + direct = False, + ), + py_requirement( + "six", + direct = False, + ), + py_requirement( + "urllib3", + direct = False, + ), + py_requirement( + "websocket-client", + direct = False, + ), ], ) @@ -87,10 +143,16 @@ py_library( srcs = ["ssh_connection.py"], deps = [ "//benchmarks/harness", - py_requirement("bcrypt", False), - py_requirement("cffi", True), - py_requirement("paramiko", True), - py_requirement("cryptography", False), + py_requirement( + "bcrypt", + direct = False, + ), + py_requirement("cffi"), + py_requirement("paramiko"), + py_requirement( + "cryptography", + direct = False, + ), ], ) @@ -98,16 +160,43 @@ py_library( name = "tunnel_dispatcher", srcs = ["tunnel_dispatcher.py"], deps = [ - py_requirement("asn1crypto", False), - py_requirement("chardet", False), - py_requirement("certifi", False), - py_requirement("docker", True), - py_requirement("docker-pycreds", False), - py_requirement("idna", False), - py_requirement("pexpect", True), - py_requirement("ptyprocess", False), - py_requirement("requests", False), - py_requirement("urllib3", False), - py_requirement("websocket-client", False), + py_requirement( + "asn1crypto", + direct = False, + ), + py_requirement( + "chardet", + direct = False, + ), + py_requirement( + "certifi", + direct = False, + ), + py_requirement("docker"), + py_requirement( + "docker-pycreds", + direct = False, + ), + py_requirement( + "idna", + direct = False, + ), + py_requirement("pexpect"), + py_requirement( + "ptyprocess", + direct = False, + ), + py_requirement( + "requests", + direct = False, + ), + py_requirement( + "urllib3", + direct = False, + ), + py_requirement( + "websocket-client", + direct = False, + ), ], ) diff --git a/benchmarks/harness/machine_producers/BUILD b/benchmarks/harness/machine_producers/BUILD index 3711a397f..81f19bd08 100644 --- a/benchmarks/harness/machine_producers/BUILD +++ b/benchmarks/harness/machine_producers/BUILD @@ -31,7 +31,10 @@ py_library( deps = [ "//benchmarks/harness:machine", "//benchmarks/harness/machine_producers:machine_producer", - py_requirement("PyYAML", False), + py_requirement( + "PyYAML", + direct = False, + ), ], ) diff --git a/benchmarks/runner/BUILD b/benchmarks/runner/BUILD index fae0ca800..471debfdf 100644 --- a/benchmarks/runner/BUILD +++ b/benchmarks/runner/BUILD @@ -1,4 +1,5 @@ load("//tools:defs.bzl", "py_library", "py_requirement", "py_test") +load("//benchmarks:defs.bzl", "test_deps") package(licenses = ["notice"]) @@ -28,7 +29,7 @@ py_library( "//benchmarks/suites:startup", "//benchmarks/suites:sysbench", "//benchmarks/suites:syscall", - py_requirement("click", True), + py_requirement("click"), ], ) @@ -36,7 +37,7 @@ py_library( name = "commands", srcs = ["commands.py"], deps = [ - py_requirement("click", True), + py_requirement("click"), ], ) @@ -48,16 +49,8 @@ py_test( "local", "manual", ], - deps = [ + deps = test_deps + [ ":runner", - py_requirement("click", True), - py_requirement("attrs", False), - py_requirement("atomicwrites", False), - py_requirement("more-itertools", False), - py_requirement("pathlib2", False), - py_requirement("pluggy", False), - py_requirement("py", False), - py_requirement("pytest", True), - py_requirement("six", False), + py_requirement("click"), ], ) diff --git a/benchmarks/workloads/ab/BUILD b/benchmarks/workloads/ab/BUILD index 4dd91ceb3..945ac7026 100644 --- a/benchmarks/workloads/ab/BUILD +++ b/benchmarks/workloads/ab/BUILD @@ -1,4 +1,5 @@ 
-load("//tools:defs.bzl", "pkg_tar", "py_library", "py_requirement", "py_test") +load("//tools:defs.bzl", "pkg_tar", "py_library", "py_test") +load("//benchmarks:defs.bzl", "test_deps") package( default_visibility = ["//benchmarks:__subpackages__"], @@ -14,16 +15,8 @@ py_test( name = "ab_test", srcs = ["ab_test.py"], python_version = "PY3", - deps = [ + deps = test_deps + [ ":ab", - py_requirement("attrs", False), - py_requirement("atomicwrites", False), - py_requirement("more-itertools", False), - py_requirement("pathlib2", False), - py_requirement("pluggy", False), - py_requirement("py", False), - py_requirement("pytest", True), - py_requirement("six", False), ], ) diff --git a/benchmarks/workloads/absl/BUILD b/benchmarks/workloads/absl/BUILD index 55dae3baa..bb1a308bf 100644 --- a/benchmarks/workloads/absl/BUILD +++ b/benchmarks/workloads/absl/BUILD @@ -1,4 +1,5 @@ -load("//tools:defs.bzl", "pkg_tar", "py_library", "py_requirement", "py_test") +load("//tools:defs.bzl", "pkg_tar", "py_library", "py_test") +load("//benchmarks:defs.bzl", "test_deps") package( default_visibility = ["//benchmarks:__subpackages__"], @@ -14,16 +15,8 @@ py_test( name = "absl_test", srcs = ["absl_test.py"], python_version = "PY3", - deps = [ + deps = test_deps + [ ":absl", - py_requirement("attrs", False), - py_requirement("atomicwrites", False), - py_requirement("more-itertools", False), - py_requirement("pathlib2", False), - py_requirement("pluggy", False), - py_requirement("py", False), - py_requirement("pytest", True), - py_requirement("six", False), ], ) diff --git a/benchmarks/workloads/fio/BUILD b/benchmarks/workloads/fio/BUILD index 7b78e8e75..24d909c53 100644 --- a/benchmarks/workloads/fio/BUILD +++ b/benchmarks/workloads/fio/BUILD @@ -1,4 +1,5 @@ -load("//tools:defs.bzl", "pkg_tar", "py_library", "py_requirement", "py_test") +load("//tools:defs.bzl", "pkg_tar", "py_library", "py_test") +load("//benchmarks:defs.bzl", "test_deps") package( default_visibility = ["//benchmarks:__subpackages__"], @@ -14,16 +15,8 @@ py_test( name = "fio_test", srcs = ["fio_test.py"], python_version = "PY3", - deps = [ + deps = test_deps + [ ":fio", - py_requirement("attrs", False), - py_requirement("atomicwrites", False), - py_requirement("more-itertools", False), - py_requirement("pathlib2", False), - py_requirement("pluggy", False), - py_requirement("py", False), - py_requirement("pytest", True), - py_requirement("six", False), ], ) diff --git a/benchmarks/workloads/iperf/BUILD b/benchmarks/workloads/iperf/BUILD index 570f40148..91b953718 100644 --- a/benchmarks/workloads/iperf/BUILD +++ b/benchmarks/workloads/iperf/BUILD @@ -1,4 +1,5 @@ -load("//tools:defs.bzl", "pkg_tar", "py_library", "py_requirement", "py_test") +load("//tools:defs.bzl", "pkg_tar", "py_library", "py_test") +load("//benchmarks:defs.bzl", "test_deps") package( default_visibility = ["//benchmarks:__subpackages__"], @@ -14,16 +15,8 @@ py_test( name = "iperf_test", srcs = ["iperf_test.py"], python_version = "PY3", - deps = [ + deps = test_deps + [ ":iperf", - py_requirement("attrs", False), - py_requirement("atomicwrites", False), - py_requirement("more-itertools", False), - py_requirement("pathlib2", False), - py_requirement("pluggy", False), - py_requirement("py", False), - py_requirement("pytest", True), - py_requirement("six", False), ], ) diff --git a/benchmarks/workloads/redisbenchmark/BUILD b/benchmarks/workloads/redisbenchmark/BUILD index f472a4443..147cfedd2 100644 --- a/benchmarks/workloads/redisbenchmark/BUILD +++ 
b/benchmarks/workloads/redisbenchmark/BUILD @@ -1,4 +1,5 @@ -load("//tools:defs.bzl", "pkg_tar", "py_library", "py_requirement", "py_test") +load("//tools:defs.bzl", "pkg_tar", "py_library", "py_test") +load("//benchmarks:defs.bzl", "test_deps") package( default_visibility = ["//benchmarks:__subpackages__"], @@ -14,16 +15,8 @@ py_test( name = "redisbenchmark_test", srcs = ["redisbenchmark_test.py"], python_version = "PY3", - deps = [ + deps = test_deps + [ ":redisbenchmark", - py_requirement("attrs", False), - py_requirement("atomicwrites", False), - py_requirement("more-itertools", False), - py_requirement("pathlib2", False), - py_requirement("pluggy", False), - py_requirement("py", False), - py_requirement("pytest", True), - py_requirement("six", False), ], ) diff --git a/benchmarks/workloads/sysbench/BUILD b/benchmarks/workloads/sysbench/BUILD index 3834af7ed..ab2556064 100644 --- a/benchmarks/workloads/sysbench/BUILD +++ b/benchmarks/workloads/sysbench/BUILD @@ -1,4 +1,5 @@ -load("//tools:defs.bzl", "pkg_tar", "py_library", "py_requirement", "py_test") +load("//tools:defs.bzl", "pkg_tar", "py_library", "py_test") +load("//benchmarks:defs.bzl", "test_deps") package( default_visibility = ["//benchmarks:__subpackages__"], @@ -14,16 +15,8 @@ py_test( name = "sysbench_test", srcs = ["sysbench_test.py"], python_version = "PY3", - deps = [ + deps = test_deps + [ ":sysbench", - py_requirement("attrs", False), - py_requirement("atomicwrites", False), - py_requirement("more-itertools", False), - py_requirement("pathlib2", False), - py_requirement("pluggy", False), - py_requirement("py", False), - py_requirement("pytest", True), - py_requirement("six", False), ], ) diff --git a/benchmarks/workloads/syscall/BUILD b/benchmarks/workloads/syscall/BUILD index dba4bb1e7..f8c43bca1 100644 --- a/benchmarks/workloads/syscall/BUILD +++ b/benchmarks/workloads/syscall/BUILD @@ -1,4 +1,5 @@ -load("//tools:defs.bzl", "pkg_tar", "py_library", "py_requirement", "py_test") +load("//tools:defs.bzl", "pkg_tar", "py_library", "py_test") +load("//benchmarks:defs.bzl", "test_deps") package( default_visibility = ["//benchmarks:__subpackages__"], @@ -14,16 +15,8 @@ py_test( name = "syscall_test", srcs = ["syscall_test.py"], python_version = "PY3", - deps = [ + deps = test_deps + [ ":syscall", - py_requirement("attrs", False), - py_requirement("atomicwrites", False), - py_requirement("more-itertools", False), - py_requirement("pathlib2", False), - py_requirement("pluggy", False), - py_requirement("py", False), - py_requirement("pytest", True), - py_requirement("six", False), ], ) diff --git a/tools/bazeldefs/defs.bzl b/tools/bazeldefs/defs.bzl index 08c29ff1c..6798362dc 100644 --- a/tools/bazeldefs/defs.bzl +++ b/tools/bazeldefs/defs.bzl @@ -72,7 +72,7 @@ def go_test(name, **kwargs): **kwargs ) -def py_requirement(name, direct = False): +def py_requirement(name, direct = True): return _py_requirement(name) def select_arch(amd64 = "amd64", arm64 = "arm64", default = None, **kwargs): -- cgit v1.2.3 From 336f758d59a8a0411c745d744a1e5c3294eaf78a Mon Sep 17 00:00:00 2001 From: gVisor bot Date: Thu, 13 Feb 2020 16:31:33 -0800 Subject: Ensure the marshalled object doesn't escape. Add new Marshallable interface methods CopyIn/CopyOut, which can be directly called on the marshalled object, avoiding an interface indirection. Such indirections are problematic because they always cause the marshalled object to escape. 
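As a rough standalone illustration of the escape-analysis point above (hypothetical names and types, not gVisor code): calling a method through an interface parameter forces the concrete object behind it onto the heap, while the same call made directly on the concrete type can leave it on the stack. Building the sketch below with `go build -gcflags=-m` on a toolchain of this era should report the value passed through the interface as "moved to heap", while the directly-used one stays on the stack.

package main

// copier stands in for an interface in the spirit of marshal.Marshallable;
// its name and shape are assumptions made purely for this sketch.
type copier interface {
	copyOut(dst []byte) int
}

type stat struct {
	dev, ino uint64
}

// Called directly on *stat, the receiver does not escape.
func (s *stat) copyOut(dst []byte) int {
	dst[0] = byte(s.dev)
	return 1
}

// Called through the interface, escape analysis cannot see the callee, so the
// concrete value stored in c is assumed to be retained and therefore escapes.
//
//go:noinline
func viaInterface(c copier, dst []byte) int {
	return c.copyOut(dst)
}

func main() {
	var buf [8]byte

	var direct stat // expected to stay on the stack
	direct.copyOut(buf[:])

	var indirect stat // expected to be reported as "moved to heap"
	viaInterface(&indirect, buf[:])
}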
PiperOrigin-RevId: 295028010 --- tools/go_marshal/gomarshal/generator.go | 39 +++++---- tools/go_marshal/gomarshal/generator_interfaces.go | 98 ++++++++++++++++++++++ tools/go_marshal/gomarshal/generator_tests.go | 1 + tools/go_marshal/gomarshal/util.go | 5 ++ tools/go_marshal/marshal/BUILD | 3 + tools/go_marshal/marshal/marshal.go | 42 +++++++++- 6 files changed, 169 insertions(+), 19 deletions(-) (limited to 'tools') diff --git a/tools/go_marshal/gomarshal/generator.go b/tools/go_marshal/gomarshal/generator.go index 0b3f600fe..01be7c477 100644 --- a/tools/go_marshal/gomarshal/generator.go +++ b/tools/go_marshal/gomarshal/generator.go @@ -34,9 +34,9 @@ const ( usermemImport = "gvisor.dev/gvisor/pkg/usermem" ) -// List of identifiers we use in generated code, that may conflict a -// similarly-named source identifier. Avoid problems by refusing the generate -// code when we see these. +// List of identifiers we use in generated code that may conflict with a +// similarly-named source identifier. Abort gracefully when we see these to +// avoid potentially confusing compilation failures in generated code. // // This only applies to import aliases at the moment. All other identifiers // are qualified by a receiver argument, since they're struct fields. @@ -44,10 +44,20 @@ const ( // All recievers are single letters, so we don't allow import aliases to be a // single letter. var badIdents = []string{ - "src", "srcs", "dst", "dsts", "blk", "buf", "err", + "addr", "blk", "buf", "dst", "dsts", "err", "hdr", "len", "ptr", "src", "srcs", "task", "val", // All single-letter identifiers. } +// Constructed fromt badIdents in init(). +var badIdentsMap map[string]struct{} + +func init() { + badIdentsMap = make(map[string]struct{}) + for _, ident := range badIdents { + badIdentsMap[ident] = struct{}{} + } +} + // Generator drives code generation for a single invocation of the go_marshal // utility. // @@ -88,16 +98,18 @@ func NewGenerator(srcs []string, out, outTest, pkg string, imports []string) (*G } for _, i := range imports { // All imports on the extra imports list are unconditionally marked as - // used, so they're always added to the generated code. + // used, so that they're always added to the generated code. g.imports.add(i).markUsed() } g.imports.add(marshalImport).markUsed() - // The follow imports may or may not be used by the generated - // code, depending what's required for the target types. Don't - // mark these imports as used by default. - g.imports.add(usermemImport) + // The following imports may or may not be used by the generated code, + // depending on what's required for the target types. Don't mark these as + // used by default. + g.imports.add("reflect") + g.imports.add("runtime") g.imports.add(safecopyImport) g.imports.add("unsafe") + g.imports.add(usermemImport) return &g, nil } @@ -229,11 +241,6 @@ func (g *Generator) collectMarshallabeTypes(a *ast.File, f *token.FileSet) []*as // identifiers in the generated code don't conflict with any imported package // names. 
func (g *Generator) collectImports(a *ast.File, f *token.FileSet) map[string]importStmt { - badImportNames := make(map[string]bool) - for _, i := range badIdents { - badImportNames[i] = true - } - is := make(map[string]importStmt) for _, decl := range a.Decls { gdecl, ok := decl.(*ast.GenDecl) @@ -250,7 +257,7 @@ func (g *Generator) collectImports(a *ast.File, f *token.FileSet) map[string]imp if len(i.name) == 1 { abortAt(f.Position(spec.Pos()), fmt.Sprintf("Import has a single character local name '%s'; this may conflict with code generated by go_marshal, use a multi-character import alias", i.name)) } - if badImportNames[i.name] { + if _, ok := badIdentsMap[i.name]; ok { abortAt(f.Position(spec.Pos()), fmt.Sprintf("Import name '%s' is likely to conflict with code generated by go_marshal, use a different import alias", i.name)) } } @@ -371,6 +378,7 @@ func (g *Generator) writeTests(ts []*testGenerator) error { return err } + // Collect and write test import statements. imports := newImportTable() for _, t := range ts { imports.merge(t.imports) @@ -380,6 +388,7 @@ func (g *Generator) writeTests(ts []*testGenerator) error { return err } + // Write test functions. for _, t := range ts { if err := t.write(g.outputTest); err != nil { return err diff --git a/tools/go_marshal/gomarshal/generator_interfaces.go b/tools/go_marshal/gomarshal/generator_interfaces.go index a712c14dc..f25331ac5 100644 --- a/tools/go_marshal/gomarshal/generator_interfaces.go +++ b/tools/go_marshal/gomarshal/generator_interfaces.go @@ -504,4 +504,102 @@ func (g *interfaceGenerator) emitMarshallable() { }) g.emit("}\n\n") + g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n") + g.recordUsedImport("marshal") + g.recordUsedImport("usermem") + g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName()) + g.inIndent(func() { + fallback := func() { + g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName()) + g.emit("buf := task.CopyScratchBuffer(%s.SizeBytes())\n", g.r) + g.emit("%s.MarshalBytes(buf)\n", g.r) + g.emit("return task.CopyOutBytes(addr, buf)\n") + } + if thisPacked { + g.recordUsedImport("reflect") + g.recordUsedImport("runtime") + g.recordUsedImport("unsafe") + if cond, ok := g.areFieldsPackedExpression(); ok { + g.emit("if !%s {\n", cond) + g.inIndent(fallback) + g.emit("}\n\n") + } + // Fast serialization. + g.emit("// Bypass escape analysis on %s. 
The no-op arithmetic operation on the\n", g.r) + g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) + g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") + g.emit("ptr := unsafe.Pointer(%s)\n", g.r) + g.emit("val := uintptr(ptr)\n") + g.emit("val = val^0\n\n") + + g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) + g.emit("var buf []byte\n") + g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") + g.emit("hdr.Data = val\n") + g.emit("hdr.Len = %s.SizeBytes()\n", g.r) + g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) + + g.emit("len, err := task.CopyOutBytes(addr, buf)\n") + g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) + g.emit("// must live until after the CopyOutBytes.\n") + g.emit("runtime.KeepAlive(%s)\n", g.r) + g.emit("return len, err\n") + } else { + fallback() + } + }) + g.emit("}\n\n") + + g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n") + g.recordUsedImport("marshal") + g.recordUsedImport("usermem") + g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName()) + g.inIndent(func() { + fallback := func() { + g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName()) + g.emit("buf := task.CopyScratchBuffer(%s.SizeBytes())\n", g.r) + g.emit("n, err := task.CopyInBytes(addr, buf)\n") + g.emit("if err != nil {\n") + g.inIndent(func() { + g.emit("return n, err\n") + }) + g.emit("}\n") + + g.emit("%s.UnmarshalBytes(buf)\n", g.r) + g.emit("return n, nil\n") + } + if thisPacked { + g.recordUsedImport("reflect") + g.recordUsedImport("runtime") + g.recordUsedImport("unsafe") + if cond, ok := g.areFieldsPackedExpression(); ok { + g.emit("if !%s {\n", cond) + g.inIndent(fallback) + g.emit("}\n\n") + } + // Fast deserialization. + g.emit("// Bypass escape analysis on %s. 
The no-op arithmetic operation on the\n", g.r) + g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) + g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") + g.emit("ptr := unsafe.Pointer(%s)\n", g.r) + g.emit("val := uintptr(ptr)\n") + g.emit("val = val^0\n\n") + + g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) + g.emit("var buf []byte\n") + g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") + g.emit("hdr.Data = val\n") + g.emit("hdr.Len = %s.SizeBytes()\n", g.r) + g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) + + g.emit("len, err := task.CopyInBytes(addr, buf)\n") + g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) + g.emit("// must live until after the CopyInBytes.\n") + g.emit("runtime.KeepAlive(%s)\n", g.r) + g.emit("return len, err\n") + } else { + fallback() + } + }) + g.emit("}\n\n") } diff --git a/tools/go_marshal/gomarshal/generator_tests.go b/tools/go_marshal/gomarshal/generator_tests.go index bcda17c3b..cc760b6d0 100644 --- a/tools/go_marshal/gomarshal/generator_tests.go +++ b/tools/go_marshal/gomarshal/generator_tests.go @@ -25,6 +25,7 @@ var standardImports = []string{ "fmt", "reflect", "testing", + "gvisor.dev/gvisor/tools/go_marshal/analysis", } diff --git a/tools/go_marshal/gomarshal/util.go b/tools/go_marshal/gomarshal/util.go index 967537abf..3d86935b4 100644 --- a/tools/go_marshal/gomarshal/util.go +++ b/tools/go_marshal/gomarshal/util.go @@ -219,6 +219,11 @@ type sourceBuffer struct { b bytes.Buffer } +func (b *sourceBuffer) reset() { + b.indent = 0 + b.b.Reset() +} + func (b *sourceBuffer) incIndent() { b.indent++ } diff --git a/tools/go_marshal/marshal/BUILD b/tools/go_marshal/marshal/BUILD index ad508c72f..bacfaa5a4 100644 --- a/tools/go_marshal/marshal/BUILD +++ b/tools/go_marshal/marshal/BUILD @@ -10,4 +10,7 @@ go_library( visibility = [ "//:sandbox", ], + deps = [ + "//pkg/usermem", + ], ) diff --git a/tools/go_marshal/marshal/marshal.go b/tools/go_marshal/marshal/marshal.go index a313a27ed..10614ec4d 100644 --- a/tools/go_marshal/marshal/marshal.go +++ b/tools/go_marshal/marshal/marshal.go @@ -20,6 +20,26 @@ // tools/go_marshal. See the go_marshal README for details. package marshal +import ( + "gvisor.dev/gvisor/pkg/usermem" +) + +// Task provides a subset of kernel.Task, used in marshalling. We don't import +// the kernel package directly to avoid circular dependency. +type Task interface { + // CopyScratchBuffer provides a task goroutine-local scratch buffer. See + // kernel.CopyScratchBuffer. + CopyScratchBuffer(size int) []byte + + // CopyOutBytes writes the contents of b to the task's memory. See + // kernel.CopyOutBytes. + CopyOutBytes(addr usermem.Addr, b []byte) (int, error) + + // CopyInBytes reads the contents of the task's memory to b. See + // kernel.CopyInBytes. + CopyInBytes(addr usermem.Addr, b []byte) (int, error) +} + // Marshallable represents a type that can be marshalled to and from memory. type Marshallable interface { // SizeBytes is the size of the memory representation of a type in @@ -48,13 +68,27 @@ type Marshallable interface { // MarshalBytes. MarshalUnsafe(dst []byte) - // UnmarshalUnsafe deserializes a type directly to the underlying memory - // allocated for the object by the runtime. + // UnmarshalUnsafe deserializes a type by directly copying to the underlying + // memory allocated for the object by the runtime. 
// // This allows much faster unmarshalling of types which have no implicit // padding, see Marshallable.Packed. When Packed would return false, // UnmarshalUnsafe should fall back to the safer but slower unmarshal - // mechanism implemented in UnmarshalBytes (usually by calling - // UnmarshalBytes directly). + // mechanism implemented in UnmarshalBytes. UnmarshalUnsafe(src []byte) + + // CopyIn deserializes a Marshallable type from a task's memory. This may + // only be called from a task goroutine. This is more efficient than calling + // UnmarshalUnsafe on Marshallable.Packed types, as the type being + // marshalled does not escape. The implementation should avoid creating + // extra copies in memory by directly deserializing to the object's + // underlying memory. + CopyIn(task Task, addr usermem.Addr) (int, error) + + // CopyOut serializes a Marshallable type to a task's memory. This may only + // be called from a task goroutine. This is more efficient than calling + // MarshalUnsafe on Marshallable.Packed types, as the type being serialized + // does not escape. The implementation should avoid creating extra copies in + // memory by directly serializing from the object's underlying memory. + CopyOut(task Task, addr usermem.Addr) (int, error) } -- cgit v1.2.3 From b2e86906ea4f7bc43b8d2d3a4735a87eca779b33 Mon Sep 17 00:00:00 2001 From: gVisor bot Date: Fri, 14 Feb 2020 03:26:42 -0800 Subject: Fix various issues related to enabling go-marshal. - Add missing build tags to files in the abi package. - Add the marshal package as a sentry dependency, allowed by deps_test. - Fix an issue with our top-level go_library BUILD rule, which incorrectly shadows the variable containing the input set of source files. This caused the expansion for the go_marshal clause to silently omit input files. - Fix formatting when copying build tags to gomarshal-generated files. - Fix a bug with import statement collision detection in go-marshal. PiperOrigin-RevId: 295112284 --- pkg/abi/linux/file_amd64.go | 2 ++ pkg/abi/linux/file_arm64.go | 2 ++ tools/defs.bzl | 12 ++++++------ tools/go_marshal/gomarshal/generator.go | 2 +- tools/go_marshal/gomarshal/util.go | 25 ++++++++++++++++++------- 5 files changed, 29 insertions(+), 14 deletions(-) (limited to 'tools') diff --git a/pkg/abi/linux/file_amd64.go b/pkg/abi/linux/file_amd64.go index 9d307e840..8693d49c8 100644 --- a/pkg/abi/linux/file_amd64.go +++ b/pkg/abi/linux/file_amd64.go @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +// +build amd64 + package linux // Constants for open(2). diff --git a/pkg/abi/linux/file_arm64.go b/pkg/abi/linux/file_arm64.go index 26a54f416..ea3adc5f5 100644 --- a/pkg/abi/linux/file_arm64.go +++ b/pkg/abi/linux/file_arm64.go @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +// +build arm64 + package linux // Constants for open(2). diff --git a/tools/defs.bzl b/tools/defs.bzl index 46249f9c4..39f035f12 100644 --- a/tools/defs.bzl +++ b/tools/defs.bzl @@ -117,10 +117,10 @@ def go_library(name, srcs, deps = [], imports = [], stateify = True, marshal = F # First, we need to segregate the input files via the special suffixes, # and calculate the final output set. 
state_sets = calculate_sets(srcs) - for (suffix, srcs) in state_sets.items(): + for (suffix, src_subset) in state_sets.items(): go_stateify( name = name + suffix + "_state_autogen_with_imports", - srcs = srcs, + srcs = src_subset, imports = imports, package = full_pkg, out = name + suffix + "_state_autogen_with_imports.go", @@ -140,10 +140,10 @@ def go_library(name, srcs, deps = [], imports = [], stateify = True, marshal = F if marshal: # See above. marshal_sets = calculate_sets(srcs) - for (suffix, srcs) in marshal_sets.items(): + for (suffix, src_subset) in marshal_sets.items(): go_marshal( name = name + suffix + "_abi_autogen", - srcs = srcs, + srcs = src_subset, debug = False, imports = imports, package = name, @@ -172,11 +172,11 @@ def go_library(name, srcs, deps = [], imports = [], stateify = True, marshal = F # See above. marshal_sets = calculate_sets(srcs) - for (suffix, srcs) in marshal_sets.items(): + for (suffix, _) in marshal_sets.items(): _go_test( name = name + suffix + "_abi_autogen_test", srcs = [name + suffix + "_abi_autogen_test.go"], - library = ":" + name + suffix, + library = ":" + name, deps = marshal_test_deps, **kwargs ) diff --git a/tools/go_marshal/gomarshal/generator.go b/tools/go_marshal/gomarshal/generator.go index 01be7c477..fbec7bb9a 100644 --- a/tools/go_marshal/gomarshal/generator.go +++ b/tools/go_marshal/gomarshal/generator.go @@ -123,7 +123,7 @@ func (g *Generator) writeHeader() error { // Emit build tags. if t := tags.Aggregate(g.inputs); len(t) > 0 { b.emit(strings.Join(t.Lines(), "\n")) - b.emit("\n") + b.emit("\n\n") } // Package header. diff --git a/tools/go_marshal/gomarshal/util.go b/tools/go_marshal/gomarshal/util.go index 3d86935b4..e2bca4e7c 100644 --- a/tools/go_marshal/gomarshal/util.go +++ b/tools/go_marshal/gomarshal/util.go @@ -310,7 +310,7 @@ func (i *importStmt) markUsed() { } func (i *importStmt) equivalent(other *importStmt) bool { - return i == other + return i.name == other.name && i.path == other.path && i.aliased == other.aliased } // importTable represents a collection of importStmts. @@ -329,7 +329,7 @@ func newImportTable() *importTable { // result in a panic. func (i *importTable) merge(other *importTable) { for name, im := range other.is { - if dup, ok := i.is[name]; ok && dup.equivalent(im) { + if dup, ok := i.is[name]; ok && !dup.equivalent(im) { panic(fmt.Sprintf("Found colliding import statements: ours: %+v, other's: %+v", dup, im)) } @@ -337,16 +337,27 @@ func (i *importTable) merge(other *importTable) { } } +func (i *importTable) addStmt(s *importStmt) *importStmt { + if old, ok := i.is[s.name]; ok && !old.equivalent(s) { + // A collision should always be between an import inserted by the + // go-marshal tool and an import from the original source file (assuming + // the original source file was valid). We could theoretically handle + // the collision by assigning a local name to our import. However, this + // would need to be plumbed throughout the generator. Given that + // collisions should be rare, simply panic on collision. 
+ panic(fmt.Sprintf("Import collision: old: %s as %v; new: %v as %v", old.path, old.name, s.path, s.name)) + } + i.is[s.name] = s + return s +} + func (i *importTable) add(s string) *importStmt { n := newImport(s) - i.is[n.name] = n - return n + return i.addStmt(n) } func (i *importTable) addFromSpec(spec *ast.ImportSpec, f *token.FileSet) *importStmt { - n := newImportFromSpec(spec, f) - i.is[n.name] = n - return n + return i.addStmt(newImportFromSpec(spec, f)) } // Marks the import named n as used. If no such import is in the table, returns -- cgit v1.2.3 From 48d9aa7ab371691d28a44533f67e495173554098 Mon Sep 17 00:00:00 2001 From: gVisor bot Date: Fri, 14 Feb 2020 15:19:53 -0800 Subject: Add a minimal binary target for escape analysis on go-marshal. Note that this is not an automated test. PiperOrigin-RevId: 295238672 --- tools/go_marshal/test/BUILD | 14 +++- tools/go_marshal/test/benchmark_test.go | 2 +- tools/go_marshal/test/escape.go | 114 ++++++++++++++++++++++++++++++++ 3 files changed, 128 insertions(+), 2 deletions(-) create mode 100644 tools/go_marshal/test/escape.go (limited to 'tools') diff --git a/tools/go_marshal/test/BUILD b/tools/go_marshal/test/BUILD index e345e3a8e..f27c5ce52 100644 --- a/tools/go_marshal/test/BUILD +++ b/tools/go_marshal/test/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "go_library", "go_test") +load("//tools:defs.bzl", "go_binary", "go_library", "go_test") licenses(["notice"]) @@ -27,3 +27,15 @@ go_library( marshal = True, deps = ["//tools/go_marshal/test/external"], ) + +go_binary( + name = "escape", + testonly = 1, + srcs = ["escape.go"], + gc_goopts = ["-m"], + deps = [ + ":test", + "//pkg/usermem", + "//tools/go_marshal/marshal", + ], +) diff --git a/tools/go_marshal/test/benchmark_test.go b/tools/go_marshal/test/benchmark_test.go index e12403741..c79defe9e 100644 --- a/tools/go_marshal/test/benchmark_test.go +++ b/tools/go_marshal/test/benchmark_test.go @@ -24,7 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/tools/go_marshal/analysis" - test "gvisor.dev/gvisor/tools/go_marshal/test" + "gvisor.dev/gvisor/tools/go_marshal/test" ) // Marshalling using the standard encoding/binary package. diff --git a/tools/go_marshal/test/escape.go b/tools/go_marshal/test/escape.go new file mode 100644 index 000000000..184f05ea3 --- /dev/null +++ b/tools/go_marshal/test/escape.go @@ -0,0 +1,114 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This binary provides a convienient target for analyzing how the go-marshal +// API causes its various arguments to escape to the heap. To use, build and +// observe the output from the go compiler's escape analysis: +// +// $ bazel build :escape +// ... +// escape.go:67:2: moved to heap: task +// escape.go:77:31: make([]byte, size) escapes to heap +// escape.go:87:31: make([]byte, size) escapes to heap +// escape.go:96:6: moved to heap: stat +// ... 
+// +// This is not an automated test, but simply a minimal binary for easy analysis. +package main + +import ( + "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/tools/go_marshal/marshal" + "gvisor.dev/gvisor/tools/go_marshal/test" +) + +// dummyTask implements marshal.Task. +type dummyTask struct { +} + +func (*dummyTask) CopyScratchBuffer(size int) []byte { + return make([]byte, size) +} + +func (*dummyTask) CopyOutBytes(addr usermem.Addr, b []byte) (int, error) { + return len(b), nil +} + +func (*dummyTask) CopyInBytes(addr usermem.Addr, b []byte) (int, error) { + return len(b), nil +} + +func (task *dummyTask) MarshalBytes(addr usermem.Addr, marshallable marshal.Marshallable) { + buf := task.CopyScratchBuffer(marshallable.SizeBytes()) + marshallable.MarshalBytes(buf) + task.CopyOutBytes(addr, buf) +} + +func (task *dummyTask) MarshalUnsafe(addr usermem.Addr, marshallable marshal.Marshallable) { + buf := task.CopyScratchBuffer(marshallable.SizeBytes()) + marshallable.MarshalUnsafe(buf) + task.CopyOutBytes(addr, buf) +} + +// Expected escapes: +// - task: passed to marshal.Marshallable.CopyOut as the marshal.Task interface. +func doCopyOut() { + task := dummyTask{} + var stat test.Stat + stat.CopyOut(&task, usermem.Addr(0xf000ba12)) +} + +// Expected escapes: +// - buf: make allocates on the heap. +func doMarshalBytesDirect() { + task := dummyTask{} + var stat test.Stat + buf := task.CopyScratchBuffer(stat.SizeBytes()) + stat.MarshalBytes(buf) + task.CopyOutBytes(usermem.Addr(0xf000ba12), buf) +} + +// Expected escapes: +// - buf: make allocates on the heap. +func doMarshalUnsafeDirect() { + task := dummyTask{} + var stat test.Stat + buf := task.CopyScratchBuffer(stat.SizeBytes()) + stat.MarshalUnsafe(buf) + task.CopyOutBytes(usermem.Addr(0xf000ba12), buf) +} + +// Expected escapes: +// - stat: passed to dummyTask.MarshalBytes as the marshal.Marshallable interface. +func doMarshalBytesViaMarshallable() { + task := dummyTask{} + var stat test.Stat + task.MarshalBytes(usermem.Addr(0xf000ba12), &stat) +} + +// Expected escapes: +// - stat: passed to dummyTask.MarshalUnsafe as the marshal.Marshallable interface. +func doMarshalUnsafeViaMarshallable() { + task := dummyTask{} + var stat test.Stat + task.MarshalUnsafe(usermem.Addr(0xf000ba12), &stat) +} + +func main() { + doCopyOut() + doMarshalBytesDirect() + doMarshalUnsafeDirect() + doMarshalBytesViaMarshallable() + doMarshalUnsafeViaMarshallable() +} -- cgit v1.2.3 From 3d32ad1367b4e84a0822808f44bd7b9f9351db71 Mon Sep 17 00:00:00 2001 From: gVisor bot Date: Fri, 14 Feb 2020 18:31:55 -0800 Subject: Generate implementation of io.WriterTo via go-marshal. PiperOrigin-RevId: 295269654 --- tools/go_marshal/gomarshal/generator.go | 6 ++- tools/go_marshal/gomarshal/generator_interfaces.go | 46 ++++++++++++++++++++++ tools/go_marshal/gomarshal/generator_tests.go | 34 ++++++++++++++-- tools/go_marshal/marshal/marshal.go | 4 ++ 4 files changed, 84 insertions(+), 6 deletions(-) (limited to 'tools') diff --git a/tools/go_marshal/gomarshal/generator.go b/tools/go_marshal/gomarshal/generator.go index fbec7bb9a..0294ba5ba 100644 --- a/tools/go_marshal/gomarshal/generator.go +++ b/tools/go_marshal/gomarshal/generator.go @@ -101,14 +101,16 @@ func NewGenerator(srcs []string, out, outTest, pkg string, imports []string) (*G // used, so that they're always added to the generated code. 
g.imports.add(i).markUsed() } - g.imports.add(marshalImport).markUsed() + // The following imports may or may not be used by the generated code, // depending on what's required for the target types. Don't mark these as // used by default. + g.imports.add("io") g.imports.add("reflect") g.imports.add("runtime") - g.imports.add(safecopyImport) g.imports.add("unsafe") + g.imports.add(marshalImport) + g.imports.add(safecopyImport) g.imports.add(usermemImport) return &g, nil diff --git a/tools/go_marshal/gomarshal/generator_interfaces.go b/tools/go_marshal/gomarshal/generator_interfaces.go index f25331ac5..22aae0f6b 100644 --- a/tools/go_marshal/gomarshal/generator_interfaces.go +++ b/tools/go_marshal/gomarshal/generator_interfaces.go @@ -602,4 +602,50 @@ func (g *interfaceGenerator) emitMarshallable() { } }) g.emit("}\n\n") + + g.emit("// WriteTo implements io.WriterTo.WriteTo.\n") + g.recordUsedImport("io") + g.emit("func (%s *%s) WriteTo(w io.Writer) (int64, error) {\n", g.r, g.typeName()) + g.inIndent(func() { + fallback := func() { + g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName()) + g.emit("buf := make([]byte, %s.SizeBytes())\n", g.r) + g.emit("%s.MarshalBytes(buf)\n", g.r) + g.emit("n, err := w.Write(buf)\n") + g.emit("return int64(n), err\n") + } + if thisPacked { + g.recordUsedImport("reflect") + g.recordUsedImport("runtime") + g.recordUsedImport("unsafe") + if cond, ok := g.areFieldsPackedExpression(); ok { + g.emit("if !%s {\n", cond) + g.inIndent(fallback) + g.emit("}\n\n") + } + // Fast serialization. + g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) + g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) + g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") + g.emit("ptr := unsafe.Pointer(%s)\n", g.r) + g.emit("val := uintptr(ptr)\n") + g.emit("val = val^0\n\n") + + g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) + g.emit("var buf []byte\n") + g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") + g.emit("hdr.Data = val\n") + g.emit("hdr.Len = %s.SizeBytes()\n", g.r) + g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) + + g.emit("len, err := w.Write(buf)\n") + g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) + g.emit("// must live until after the Write.\n") + g.emit("runtime.KeepAlive(%s)\n", g.r) + g.emit("return int64(len), err\n") + } else { + fallback() + } + }) + g.emit("}\n\n") } diff --git a/tools/go_marshal/gomarshal/generator_tests.go b/tools/go_marshal/gomarshal/generator_tests.go index cc760b6d0..5ad97af14 100644 --- a/tools/go_marshal/gomarshal/generator_tests.go +++ b/tools/go_marshal/gomarshal/generator_tests.go @@ -22,6 +22,7 @@ import ( ) var standardImports = []string{ + "bytes", "fmt", "reflect", "testing", @@ -117,26 +118,50 @@ func (g *testGenerator) emitTestMarshalUnmarshalPreservesData() { g.emit("y.UnmarshalBytes(buf)\n") g.emit("if !reflect.DeepEqual(x, y) {\n") g.inIndent(func() { - g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across Marshal/Unmarshal cycle:\\nBefore: %%+v\\nAfter: %%+v\\n\", x, y))\n") + g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalBytes/UnmarshalBytes cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, y))\n") }) g.emit("}\n") g.emit("yUnsafe.UnmarshalBytes(bufUnsafe)\n") g.emit("if !reflect.DeepEqual(x, yUnsafe) {\n") g.inIndent(func() { - g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalUnsafe/Unmarshal 
cycle:\\nBefore: %%+v\\nAfter: %%+v\\n\", x, yUnsafe))\n") + g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalUnsafe/UnmarshalBytes cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, yUnsafe))\n") }) g.emit("}\n\n") g.emit("z.UnmarshalUnsafe(buf)\n") g.emit("if !reflect.DeepEqual(x, z) {\n") g.inIndent(func() { - g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across Marshal/UnmarshalUnsafe cycle:\\nBefore: %%+v\\nAfter: %%+v\\n\", x, z))\n") + g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalBytes/UnmarshalUnsafe cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, z))\n") }) g.emit("}\n") g.emit("zUnsafe.UnmarshalUnsafe(bufUnsafe)\n") g.emit("if !reflect.DeepEqual(x, zUnsafe) {\n") g.inIndent(func() { - g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalUnsafe/UnmarshalUnsafe cycle:\\nBefore: %%+v\\nAfter: %%+v\\n\", x, zUnsafe))\n") + g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalUnsafe/UnmarshalUnsafe cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, zUnsafe))\n") + }) + g.emit("}\n") + }) +} + +func (g *testGenerator) emitTestWriteToUnmarshalPreservesData() { + g.inTestFunction("TestWriteToUnmarshalPreservesData", func() { + g.emit("var x, y, yUnsafe %s\n", g.typeName()) + g.emit("analysis.RandomizeValue(&x)\n\n") + + g.emit("var buf bytes.Buffer\n\n") + + g.emit("x.WriteTo(&buf)\n") + g.emit("y.UnmarshalBytes(buf.Bytes())\n\n") + g.emit("yUnsafe.UnmarshalUnsafe(buf.Bytes())\n\n") + + g.emit("if !reflect.DeepEqual(x, y) {\n") + g.inIndent(func() { + g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across WriteTo/UnmarshalBytes cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, y))\n") + }) + g.emit("}\n") + g.emit("if !reflect.DeepEqual(x, yUnsafe) {\n") + g.inIndent(func() { + g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across WriteTo/UnmarshalUnsafe cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, yUnsafe))\n") }) g.emit("}\n") }) @@ -146,6 +171,7 @@ func (g *testGenerator) emitTests() { g.emitTestNonZeroSize() g.emitTestSuspectAlignment() g.emitTestMarshalUnmarshalPreservesData() + g.emitTestWriteToUnmarshalPreservesData() } func (g *testGenerator) write(out io.Writer) error { diff --git a/tools/go_marshal/marshal/marshal.go b/tools/go_marshal/marshal/marshal.go index 10614ec4d..e521b50bd 100644 --- a/tools/go_marshal/marshal/marshal.go +++ b/tools/go_marshal/marshal/marshal.go @@ -21,6 +21,8 @@ package marshal import ( + "io" + "gvisor.dev/gvisor/pkg/usermem" ) @@ -42,6 +44,8 @@ type Task interface { // Marshallable represents a type that can be marshalled to and from memory. type Marshallable interface { + io.WriterTo + // SizeBytes is the size of the memory representation of a type in // marshalled form. SizeBytes() int -- cgit v1.2.3 From 5cc0bbbafb2dc7d248bc3141b4cfa022d420abd1 Mon Sep 17 00:00:00 2001 From: gVisor bot Date: Sat, 15 Feb 2020 00:00:04 -0800 Subject: Ensure Marshallable.SizeBytes() always works on a typed nil pointer. This lets go-marshal replace various calls to binary.Size() throughout the sentry without requiring concrete objects. 
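A minimal standalone sketch of the pattern this enables (the type below is hypothetical, not generated code): because SizeBytes never dereferences its receiver, a typed nil pointer is enough to ask for a type's marshalled size, which is what lets call sites drop binary.Size(concreteValue).

package main

import "fmt"

// timespec is a stand-in for a marshallable struct; its size depends only on
// the field types, never on field values.
type timespec struct {
	sec  int64
	nsec int64
}

// SizeBytes is safe on a nil receiver because it never touches *t.
func (t *timespec) SizeBytes() int {
	return 8 + 8
}

func main() {
	// No timespec value is ever constructed; the typed nil pointer carries
	// enough type information to compute the size.
	fmt.Println((*timespec)(nil).SizeBytes()) // prints 16
}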
PiperOrigin-RevId: 295299965 --- tools/go_marshal/gomarshal/generator_interfaces.go | 2 +- tools/go_marshal/gomarshal/generator_tests.go | 15 +++++++++++++++ tools/go_marshal/marshal/marshal.go | 4 ++++ 3 files changed, 20 insertions(+), 1 deletion(-) (limited to 'tools') diff --git a/tools/go_marshal/gomarshal/generator_interfaces.go b/tools/go_marshal/gomarshal/generator_interfaces.go index 22aae0f6b..3aa299ccd 100644 --- a/tools/go_marshal/gomarshal/generator_interfaces.go +++ b/tools/go_marshal/gomarshal/generator_interfaces.go @@ -301,7 +301,7 @@ func (g *interfaceGenerator) emitMarshallable() { primitiveSize += size } else { g.recordUsedMarshallable(t.Name) - dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("%s.SizeBytes()", g.fieldAccessor(n))) + dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", t.Name)) } }, selector: func(n, tX, tSel *ast.Ident) { diff --git a/tools/go_marshal/gomarshal/generator_tests.go b/tools/go_marshal/gomarshal/generator_tests.go index 5ad97af14..8c28b00d0 100644 --- a/tools/go_marshal/gomarshal/generator_tests.go +++ b/tools/go_marshal/gomarshal/generator_tests.go @@ -167,11 +167,26 @@ func (g *testGenerator) emitTestWriteToUnmarshalPreservesData() { }) } +func (g *testGenerator) emitTestSizeBytesOnTypedNilPtr() { + g.inTestFunction("TestSizeBytesOnTypedNilPtr", func() { + g.emit("var x %s\n", g.typeName()) + g.emit("sizeFromConcrete := x.SizeBytes()\n") + g.emit("sizeFromTypedNilPtr := (*%s)(nil).SizeBytes()\n\n", g.typeName()) + + g.emit("if sizeFromTypedNilPtr != sizeFromConcrete {\n") + g.inIndent(func() { + g.emit("t.Fatalf(\"SizeBytes() on typed nil pointer (%v) doesn't match size returned by a concrete object (%v).\\n\", sizeFromTypedNilPtr, sizeFromConcrete)") + }) + g.emit("}\n") + }) +} + func (g *testGenerator) emitTests() { g.emitTestNonZeroSize() g.emitTestSuspectAlignment() g.emitTestMarshalUnmarshalPreservesData() g.emitTestWriteToUnmarshalPreservesData() + g.emitTestSizeBytesOnTypedNilPtr() } func (g *testGenerator) write(out io.Writer) error { diff --git a/tools/go_marshal/marshal/marshal.go b/tools/go_marshal/marshal/marshal.go index e521b50bd..20353850d 100644 --- a/tools/go_marshal/marshal/marshal.go +++ b/tools/go_marshal/marshal/marshal.go @@ -48,6 +48,10 @@ type Marshallable interface { // SizeBytes is the size of the memory representation of a type in // marshalled form. + // + // SizeBytes must handle a nil receiver. Practically, this means SizeBytes + // cannot deference any fields on the object implementing it (but will + // likely make use of the type of these fields). SizeBytes() int // MarshalBytes serializes a copy of a type to dst. dst must be at least -- cgit v1.2.3 From 737a3d072ef6e3edf5099505e41deed49f9e5b5c Mon Sep 17 00:00:00 2001 From: gVisor bot Date: Tue, 18 Feb 2020 15:08:11 -0800 Subject: go-marshal: Stop complaining about files with no +marshal types. Since we tag entire packages as marshallable, due to conditional compiling for different architectures we can end up with sets of source files that don't contain any marshallable types. It's safe to silently ignore this scenario. 
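For illustration only (the file name, package, and constant below are hypothetical, not actual gVisor sources): with whole packages marked marshallable and sources segregated by architecture suffix, one architecture's slice of the package can legitimately contain nothing to marshal. A file like the following would previously make the go-marshal invocation for its suffix abort; after this change go-marshal simply emits no implementations for it.

// file_arm64.go: this slice of the package declares no "// +marshal" types,
// so go-marshal has nothing to generate for the _arm64 suffix.

// +build arm64

package example

// Only architecture-specific constants live here; the marshallable structs
// are assumed to be defined in a sibling amd64 file.
const archPageSize = 4096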
PiperOrigin-RevId: 295831871 --- tools/go_marshal/gomarshal/generator.go | 11 ----------- 1 file changed, 11 deletions(-) (limited to 'tools') diff --git a/tools/go_marshal/gomarshal/generator.go b/tools/go_marshal/gomarshal/generator.go index 0294ba5ba..d3c2f72f5 100644 --- a/tools/go_marshal/gomarshal/generator.go +++ b/tools/go_marshal/gomarshal/generator.go @@ -338,17 +338,6 @@ func (g *Generator) Run() error { } } - // Tool was invoked with input files with no data structures marked for code - // generation. This is probably not what the user intended. - if len(impls) == 0 { - var buf bytes.Buffer - fmt.Fprintf(&buf, "go_marshal invoked on these files, but they don't contain any types requiring code generation. Perhaps mark some with \"// +marshal\"?:\n") - for _, i := range g.inputs { - fmt.Fprintf(&buf, " %s\n", i) - } - abort(buf.String()) - } - // Write output file header. These include things like package name and // import statements. if err := g.writeHeader(); err != nil { -- cgit v1.2.3 From 660cfdff3f2ac771c6f0f18834921cfc043b2f3a Mon Sep 17 00:00:00 2001 From: gVisor bot Date: Wed, 19 Feb 2020 15:41:22 -0800 Subject: Handle situations where go-marshal generates an empty test file. This can happen due to conditional compilation, where a subset of the source files contain no marshallable types. go-marshal is still required to write an output file in these cases, since bazel defines the output package before calling go-marshal. PiperOrigin-RevId: 296074321 --- tools/go_marshal/gomarshal/generator.go | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) (limited to 'tools') diff --git a/tools/go_marshal/gomarshal/generator.go b/tools/go_marshal/gomarshal/generator.go index d3c2f72f5..0fa868415 100644 --- a/tools/go_marshal/gomarshal/generator.go +++ b/tools/go_marshal/gomarshal/generator.go @@ -380,6 +380,26 @@ func (g *Generator) writeTests(ts []*testGenerator) error { } // Write test functions. + + // If we didn't generate any Marshallable implementations, we can't just + // emit an empty test file, since that causes the build to fail with "no + // tests/benchmarks/examples found". Unfortunately we can't signal bazel to + // omit the entire package since the outputs are already defined before + // go-marshal is called. If we'd otherwise emit an empty test suite, emit an + // empty example instead. + if len(ts) == 0 { + b.reset() + b.emit("func ExampleEmptyTestSuite() {\n") + b.inIndent(func() { + b.emit("// This example is intentionally empty to ensure this file contains at least\n") + b.emit("// one testable entity. go-marshal is forced to emit a test file if a package\n") + b.emit("// is marked marshallable, but emitting a test file with no entities results\n") + b.emit("// in a build failure.\n") + }) + b.emit("}\n") + return b.write(g.outputTest) + } + for _, t := range ts { if err := t.write(g.outputTest); err != nil { return err -- cgit v1.2.3 From 30794512d3977ebb2b185e5e9cfb969d558a07a4 Mon Sep 17 00:00:00 2001 From: Adin Scannell Date: Wed, 19 Feb 2020 18:20:52 -0800 Subject: Add basic microbenchmarks. 
PiperOrigin-RevId: 296104390 --- WORKSPACE | 10 + test/perf/BUILD | 114 +++++++ test/perf/linux/BUILD | 356 +++++++++++++++++++++ test/perf/linux/clock_getres_benchmark.cc | 39 +++ test/perf/linux/clock_gettime_benchmark.cc | 60 ++++ test/perf/linux/death_benchmark.cc | 36 +++ test/perf/linux/epoll_benchmark.cc | 99 ++++++ test/perf/linux/fork_benchmark.cc | 350 +++++++++++++++++++++ test/perf/linux/futex_benchmark.cc | 248 +++++++++++++++ test/perf/linux/getdents_benchmark.cc | 149 +++++++++ test/perf/linux/getpid_benchmark.cc | 37 +++ test/perf/linux/gettid_benchmark.cc | 38 +++ test/perf/linux/mapping_benchmark.cc | 163 ++++++++++ test/perf/linux/open_benchmark.cc | 56 ++++ test/perf/linux/pipe_benchmark.cc | 66 ++++ test/perf/linux/randread_benchmark.cc | 100 ++++++ test/perf/linux/read_benchmark.cc | 53 ++++ test/perf/linux/sched_yield_benchmark.cc | 37 +++ test/perf/linux/send_recv_benchmark.cc | 372 ++++++++++++++++++++++ test/perf/linux/seqwrite_benchmark.cc | 66 ++++ test/perf/linux/signal_benchmark.cc | 59 ++++ test/perf/linux/sleep_benchmark.cc | 60 ++++ test/perf/linux/stat_benchmark.cc | 62 ++++ test/perf/linux/unlink_benchmark.cc | 66 ++++ test/perf/linux/write_benchmark.cc | 52 ++++ test/runner/BUILD | 22 ++ test/runner/defs.bzl | 218 +++++++++++++ test/runner/gtest/BUILD | 9 + test/runner/gtest/gtest.go | 154 +++++++++ test/runner/runner.go | 477 ++++++++++++++++++++++++++++ test/syscalls/BUILD | 21 +- test/syscalls/build_defs.bzl | 180 ----------- test/syscalls/gtest/BUILD | 9 - test/syscalls/gtest/gtest.go | 93 ------ test/syscalls/linux/alarm.cc | 3 +- test/syscalls/linux/exec.cc | 3 +- test/syscalls/linux/fcntl.cc | 2 +- test/syscalls/linux/itimer.cc | 3 +- test/syscalls/linux/prctl.cc | 2 +- test/syscalls/linux/prctl_setuid.cc | 2 +- test/syscalls/linux/proc.cc | 2 +- test/syscalls/linux/ptrace.cc | 2 +- test/syscalls/linux/rtsignal.cc | 3 +- test/syscalls/linux/seccomp.cc | 2 +- test/syscalls/linux/sigiret.cc | 3 +- test/syscalls/linux/signalfd.cc | 2 +- test/syscalls/linux/sigstop.cc | 2 +- test/syscalls/linux/sigtimedwait.cc | 3 +- test/syscalls/linux/timers.cc | 2 +- test/syscalls/linux/vfork.cc | 2 +- test/syscalls/syscall_test_runner.go | 482 ----------------------------- test/syscalls/syscall_test_runner.sh | 34 -- test/util/BUILD | 3 +- test/util/test_main.cc | 2 +- test/util/test_util.h | 1 + test/util/test_util_impl.cc | 14 + tools/bazeldefs/defs.bzl | 1 + tools/defs.bzl | 3 +- 58 files changed, 3666 insertions(+), 843 deletions(-) create mode 100644 test/perf/BUILD create mode 100644 test/perf/linux/BUILD create mode 100644 test/perf/linux/clock_getres_benchmark.cc create mode 100644 test/perf/linux/clock_gettime_benchmark.cc create mode 100644 test/perf/linux/death_benchmark.cc create mode 100644 test/perf/linux/epoll_benchmark.cc create mode 100644 test/perf/linux/fork_benchmark.cc create mode 100644 test/perf/linux/futex_benchmark.cc create mode 100644 test/perf/linux/getdents_benchmark.cc create mode 100644 test/perf/linux/getpid_benchmark.cc create mode 100644 test/perf/linux/gettid_benchmark.cc create mode 100644 test/perf/linux/mapping_benchmark.cc create mode 100644 test/perf/linux/open_benchmark.cc create mode 100644 test/perf/linux/pipe_benchmark.cc create mode 100644 test/perf/linux/randread_benchmark.cc create mode 100644 test/perf/linux/read_benchmark.cc create mode 100644 test/perf/linux/sched_yield_benchmark.cc create mode 100644 test/perf/linux/send_recv_benchmark.cc create mode 100644 test/perf/linux/seqwrite_benchmark.cc create mode 100644 
test/perf/linux/signal_benchmark.cc create mode 100644 test/perf/linux/sleep_benchmark.cc create mode 100644 test/perf/linux/stat_benchmark.cc create mode 100644 test/perf/linux/unlink_benchmark.cc create mode 100644 test/perf/linux/write_benchmark.cc create mode 100644 test/runner/BUILD create mode 100644 test/runner/defs.bzl create mode 100644 test/runner/gtest/BUILD create mode 100644 test/runner/gtest/gtest.go create mode 100644 test/runner/runner.go delete mode 100644 test/syscalls/build_defs.bzl delete mode 100644 test/syscalls/gtest/BUILD delete mode 100644 test/syscalls/gtest/gtest.go delete mode 100644 test/syscalls/syscall_test_runner.go delete mode 100755 test/syscalls/syscall_test_runner.sh (limited to 'tools') diff --git a/WORKSPACE b/WORKSPACE index 2827c3a26..ff0196dc6 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -330,3 +330,13 @@ http_archive( "https://github.com/google/googletest/archive/565f1b848215b77c3732bca345fe76a0431d8b34.tar.gz", ], ) + +http_archive( + name = "com_google_benchmark", + sha256 = "3c6a165b6ecc948967a1ead710d4a181d7b0fbcaa183ef7ea84604994966221a", + strip_prefix = "benchmark-1.5.0", + urls = [ + "https://mirror.bazel.build/github.com/google/benchmark/archive/v1.5.0.tar.gz", + "https://github.com/google/benchmark/archive/v1.5.0.tar.gz", + ], +) diff --git a/test/perf/BUILD b/test/perf/BUILD new file mode 100644 index 000000000..7a2bf10ed --- /dev/null +++ b/test/perf/BUILD @@ -0,0 +1,114 @@ +load("//test/runner:defs.bzl", "syscall_test") + +package(licenses = ["notice"]) + +syscall_test( + test = "//test/perf/linux:clock_getres_benchmark", +) + +syscall_test( + test = "//test/perf/linux:clock_gettime_benchmark", +) + +syscall_test( + test = "//test/perf/linux:death_benchmark", +) + +syscall_test( + test = "//test/perf/linux:epoll_benchmark", +) + +syscall_test( + size = "large", + test = "//test/perf/linux:fork_benchmark", +) + +syscall_test( + size = "large", + test = "//test/perf/linux:futex_benchmark", +) + +syscall_test( + size = "large", + test = "//test/perf/linux:getdents_benchmark", +) + +syscall_test( + size = "large", + test = "//test/perf/linux:getpid_benchmark", +) + +syscall_test( + size = "large", + test = "//test/perf/linux:gettid_benchmark", +) + +syscall_test( + size = "large", + test = "//test/perf/linux:mapping_benchmark", +) + +syscall_test( + size = "large", + add_overlay = True, + test = "//test/perf/linux:open_benchmark", +) + +syscall_test( + test = "//test/perf/linux:pipe_benchmark", +) + +syscall_test( + size = "large", + add_overlay = True, + test = "//test/perf/linux:randread_benchmark", +) + +syscall_test( + size = "large", + add_overlay = True, + test = "//test/perf/linux:read_benchmark", +) + +syscall_test( + size = "large", + test = "//test/perf/linux:sched_yield_benchmark", +) + +syscall_test( + size = "large", + test = "//test/perf/linux:send_recv_benchmark", +) + +syscall_test( + size = "large", + add_overlay = True, + test = "//test/perf/linux:seqwrite_benchmark", +) + +syscall_test( + size = "large", + test = "//test/perf/linux:signal_benchmark", +) + +syscall_test( + test = "//test/perf/linux:sleep_benchmark", +) + +syscall_test( + size = "large", + add_overlay = True, + test = "//test/perf/linux:stat_benchmark", +) + +syscall_test( + size = "large", + add_overlay = True, + test = "//test/perf/linux:unlink_benchmark", +) + +syscall_test( + size = "large", + add_overlay = True, + test = "//test/perf/linux:write_benchmark", +) diff --git a/test/perf/linux/BUILD b/test/perf/linux/BUILD new file mode 100644 index 
000000000..b4e907826 --- /dev/null +++ b/test/perf/linux/BUILD @@ -0,0 +1,356 @@ +load("//tools:defs.bzl", "cc_binary", "gbenchmark", "gtest") + +package( + default_visibility = ["//:sandbox"], + licenses = ["notice"], +) + +cc_binary( + name = "getpid_benchmark", + testonly = 1, + srcs = [ + "getpid_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:test_main", + ], +) + +cc_binary( + name = "send_recv_benchmark", + testonly = 1, + srcs = [ + "send_recv_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/syscalls/linux:socket_test_util", + "//test/util:file_descriptor", + "//test/util:logging", + "//test/util:posix_error", + "//test/util:test_main", + "//test/util:test_util", + "//test/util:thread_util", + "@com_google_absl//absl/synchronization", + ], +) + +cc_binary( + name = "gettid_benchmark", + testonly = 1, + srcs = [ + "gettid_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:test_main", + ], +) + +cc_binary( + name = "sched_yield_benchmark", + testonly = 1, + srcs = [ + "sched_yield_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( + name = "clock_getres_benchmark", + testonly = 1, + srcs = [ + "clock_getres_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:test_main", + ], +) + +cc_binary( + name = "clock_gettime_benchmark", + testonly = 1, + srcs = [ + "clock_gettime_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:test_main", + "@com_google_absl//absl/time", + ], +) + +cc_binary( + name = "open_benchmark", + testonly = 1, + srcs = [ + "open_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:fs_util", + "//test/util:logging", + "//test/util:temp_path", + "//test/util:test_main", + ], +) + +cc_binary( + name = "read_benchmark", + testonly = 1, + srcs = [ + "read_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:fs_util", + "//test/util:logging", + "//test/util:temp_path", + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( + name = "randread_benchmark", + testonly = 1, + srcs = [ + "randread_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:file_descriptor", + "//test/util:logging", + "//test/util:temp_path", + "//test/util:test_main", + "//test/util:test_util", + "@com_google_absl//absl/random", + ], +) + +cc_binary( + name = "write_benchmark", + testonly = 1, + srcs = [ + "write_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:logging", + "//test/util:temp_path", + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( + name = "seqwrite_benchmark", + testonly = 1, + srcs = [ + "seqwrite_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:logging", + "//test/util:temp_path", + "//test/util:test_main", + "//test/util:test_util", + "@com_google_absl//absl/random", + ], +) + +cc_binary( + name = "pipe_benchmark", + testonly = 1, + srcs = [ + "pipe_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:logging", + "//test/util:test_main", + "//test/util:test_util", + "//test/util:thread_util", + ], +) + +cc_binary( + name = "fork_benchmark", + testonly = 1, + srcs = [ + "fork_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:cleanup", + "//test/util:file_descriptor", + "//test/util:logging", + "//test/util:test_main", + "//test/util:test_util", + "//test/util:thread_util", + "@com_google_absl//absl/synchronization", + ], +) + +cc_binary( + name = 
"futex_benchmark", + testonly = 1, + srcs = [ + "futex_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:logging", + "//test/util:test_main", + "//test/util:thread_util", + "@com_google_absl//absl/time", + ], +) + +cc_binary( + name = "epoll_benchmark", + testonly = 1, + srcs = [ + "epoll_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:epoll_util", + "//test/util:file_descriptor", + "//test/util:test_main", + "//test/util:test_util", + "//test/util:thread_util", + "@com_google_absl//absl/time", + ], +) + +cc_binary( + name = "death_benchmark", + testonly = 1, + srcs = [ + "death_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:logging", + "//test/util:test_main", + ], +) + +cc_binary( + name = "mapping_benchmark", + testonly = 1, + srcs = [ + "mapping_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:logging", + "//test/util:memory_util", + "//test/util:posix_error", + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( + name = "signal_benchmark", + testonly = 1, + srcs = [ + "signal_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:logging", + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( + name = "getdents_benchmark", + testonly = 1, + srcs = [ + "getdents_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:file_descriptor", + "//test/util:fs_util", + "//test/util:temp_path", + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( + name = "sleep_benchmark", + testonly = 1, + srcs = [ + "sleep_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:logging", + "//test/util:test_main", + ], +) + +cc_binary( + name = "stat_benchmark", + testonly = 1, + srcs = [ + "stat_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:fs_util", + "//test/util:temp_path", + "//test/util:test_main", + "//test/util:test_util", + "@com_google_absl//absl/strings", + ], +) + +cc_binary( + name = "unlink_benchmark", + testonly = 1, + srcs = [ + "unlink_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:fs_util", + "//test/util:temp_path", + "//test/util:test_main", + "//test/util:test_util", + ], +) diff --git a/test/perf/linux/clock_getres_benchmark.cc b/test/perf/linux/clock_getres_benchmark.cc new file mode 100644 index 000000000..b051293ad --- /dev/null +++ b/test/perf/linux/clock_getres_benchmark.cc @@ -0,0 +1,39 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "gtest/gtest.h" +#include "benchmark/benchmark.h" + +namespace gvisor { +namespace testing { + +namespace { + +// clock_getres(1) is very nearly a no-op syscall, but it does require copying +// out to a userspace struct. It thus provides a nice small copy-out benchmark. 
+void BM_ClockGetRes(benchmark::State& state) { + struct timespec ts; + for (auto _ : state) { + clock_getres(CLOCK_MONOTONIC, &ts); + } +} + +BENCHMARK(BM_ClockGetRes); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/clock_gettime_benchmark.cc b/test/perf/linux/clock_gettime_benchmark.cc new file mode 100644 index 000000000..6691bebd9 --- /dev/null +++ b/test/perf/linux/clock_gettime_benchmark.cc @@ -0,0 +1,60 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "gtest/gtest.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" +#include "benchmark/benchmark.h" + +namespace gvisor { +namespace testing { + +namespace { + +void BM_ClockGettimeThreadCPUTime(benchmark::State& state) { + clockid_t clockid; + ASSERT_EQ(0, pthread_getcpuclockid(pthread_self(), &clockid)); + struct timespec tp; + + for (auto _ : state) { + clock_gettime(clockid, &tp); + } +} + +BENCHMARK(BM_ClockGettimeThreadCPUTime); + +void BM_VDSOClockGettime(benchmark::State& state) { + const clockid_t clock = state.range(0); + struct timespec tp; + absl::Time start = absl::Now(); + + // Don't benchmark the calibration phase. + while (absl::Now() < start + absl::Milliseconds(2100)) { + clock_gettime(clock, &tp); + } + + for (auto _ : state) { + clock_gettime(clock, &tp); + } +} + +BENCHMARK(BM_VDSOClockGettime)->Arg(CLOCK_MONOTONIC)->Arg(CLOCK_REALTIME); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/death_benchmark.cc b/test/perf/linux/death_benchmark.cc new file mode 100644 index 000000000..cb2b6fd07 --- /dev/null +++ b/test/perf/linux/death_benchmark.cc @@ -0,0 +1,36 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "gtest/gtest.h" +#include "benchmark/benchmark.h" +#include "test/util/logging.h" + +namespace gvisor { +namespace testing { + +namespace { + +// DeathTest is not so much a microbenchmark as a macrobenchmark. It is testing +// the ability of gVisor (on whatever platform) to execute all the related +// stack-dumping routines associated with EXPECT_EXIT / EXPECT_DEATH. 
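+//
+// Under the hood, EXPECT_EXIT forks a death-test child that runs the
+// statement; the failing TEST_CHECK aborts that child with SIGABRT, which the
+// parent then matches against KilledBySignal(SIGABRT).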
+TEST(DeathTest, ZeroEqualsOne) {
+  EXPECT_EXIT({ TEST_CHECK(0 == 1); }, ::testing::KilledBySignal(SIGABRT), "");
+}
+
+}  // namespace
+
+}  // namespace testing
+}  // namespace gvisor
diff --git a/test/perf/linux/epoll_benchmark.cc b/test/perf/linux/epoll_benchmark.cc
new file mode 100644
index 000000000..0b121338a
--- /dev/null
+++ b/test/perf/linux/epoll_benchmark.cc
@@ -0,0 +1,99 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <sys/epoll.h>
+#include <sys/eventfd.h>
+
+#include <cstdint>
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "absl/time/time.h"
+#include "benchmark/benchmark.h"
+#include "test/util/epoll_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// Returns a new eventfd.
+PosixErrorOr<FileDescriptor> NewEventFD() {
+  int fd = eventfd(0, /* flags = */ 0);
+  MaybeSave();
+  if (fd < 0) {
+    return PosixError(errno, "eventfd");
+  }
+  return FileDescriptor(fd);
+}
+
+// Also stolen from epoll.cc unit tests.
+void BM_EpollTimeout(benchmark::State& state) {
+  constexpr int kFDsPerEpoll = 3;
+  auto epollfd = ASSERT_NO_ERRNO_AND_VALUE(NewEpollFD());
+
+  std::vector<FileDescriptor> eventfds;
+  for (int i = 0; i < kFDsPerEpoll; i++) {
+    eventfds.push_back(ASSERT_NO_ERRNO_AND_VALUE(NewEventFD()));
+    ASSERT_NO_ERRNO(
+        RegisterEpollFD(epollfd.get(), eventfds[i].get(), EPOLLIN, 0));
+  }
+
+  struct epoll_event result[kFDsPerEpoll];
+  int timeout_ms = state.range(0);
+
+  for (auto _ : state) {
+    EXPECT_EQ(0, epoll_wait(epollfd.get(), result, kFDsPerEpoll, timeout_ms));
+  }
+}
+
+BENCHMARK(BM_EpollTimeout)->Range(0, 8);
+
+// Also stolen from epoll.cc unit tests.
+void BM_EpollAllEvents(benchmark::State& state) {
+  auto epollfd = ASSERT_NO_ERRNO_AND_VALUE(NewEpollFD());
+  const int fds_per_epoll = state.range(0);
+  constexpr uint64_t kEventVal = 5;
+
+  std::vector<FileDescriptor> eventfds;
+  for (int i = 0; i < fds_per_epoll; i++) {
+    eventfds.push_back(ASSERT_NO_ERRNO_AND_VALUE(NewEventFD()));
+    ASSERT_NO_ERRNO(
+        RegisterEpollFD(epollfd.get(), eventfds[i].get(), EPOLLIN, 0));
+
+    ASSERT_THAT(WriteFd(eventfds[i].get(), &kEventVal, sizeof(kEventVal)),
+                SyscallSucceedsWithValue(sizeof(kEventVal)));
+  }
+
+  std::vector<epoll_event> result(fds_per_epoll);
+
+  for (auto _ : state) {
+    EXPECT_EQ(fds_per_epoll,
+              epoll_wait(epollfd.get(), result.data(), fds_per_epoll, 0));
+  }
+}
+
+BENCHMARK(BM_EpollAllEvents)->Range(2, 1024);
+
+}  // namespace
+
+}  // namespace testing
+}  // namespace gvisor
diff --git a/test/perf/linux/fork_benchmark.cc b/test/perf/linux/fork_benchmark.cc
new file mode 100644
index 000000000..84fdbc8a0
--- /dev/null
+++ b/test/perf/linux/fork_benchmark.cc
@@ -0,0 +1,350 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "gtest/gtest.h" +#include "absl/synchronization/barrier.h" +#include "benchmark/benchmark.h" +#include "test/util/cleanup.h" +#include "test/util/file_descriptor.h" +#include "test/util/logging.h" +#include "test/util/test_util.h" +#include "test/util/thread_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +constexpr int kBusyMax = 250; + +// Do some CPU-bound busy-work. +int busy(int max) { + // Prevent the compiler from optimizing this work away, + volatile int count = 0; + + for (int i = 1; i < max; i++) { + for (int j = 2; j < i / 2; j++) { + if (i % j == 0) { + count++; + } + } + } + + return count; +} + +void BM_CPUBoundUniprocess(benchmark::State& state) { + for (auto _ : state) { + busy(kBusyMax); + } +} + +BENCHMARK(BM_CPUBoundUniprocess); + +void BM_CPUBoundAsymmetric(benchmark::State& state) { + const size_t max = state.max_iterations; + pid_t child = fork(); + if (child == 0) { + for (int i = 0; i < max; i++) { + busy(kBusyMax); + } + _exit(0); + } + ASSERT_THAT(child, SyscallSucceeds()); + ASSERT_TRUE(state.KeepRunningBatch(max)); + + int status; + EXPECT_THAT(RetryEINTR(waitpid)(child, &status, 0), SyscallSucceeds()); + EXPECT_TRUE(WIFEXITED(status)); + EXPECT_EQ(0, WEXITSTATUS(status)); + ASSERT_FALSE(state.KeepRunning()); +} + +BENCHMARK(BM_CPUBoundAsymmetric)->UseRealTime(); + +void BM_CPUBoundSymmetric(benchmark::State& state) { + std::vector children; + auto child_cleanup = Cleanup([&] { + for (const pid_t child : children) { + int status; + EXPECT_THAT(RetryEINTR(waitpid)(child, &status, 0), SyscallSucceeds()); + EXPECT_TRUE(WIFEXITED(status)); + EXPECT_EQ(0, WEXITSTATUS(status)); + } + ASSERT_FALSE(state.KeepRunning()); + }); + + const int processes = state.range(0); + for (int i = 0; i < processes; i++) { + size_t cur = (state.max_iterations + (processes - 1)) / processes; + if ((state.iterations() + cur) >= state.max_iterations) { + cur = state.max_iterations - state.iterations(); + } + pid_t child = fork(); + if (child == 0) { + for (int i = 0; i < cur; i++) { + busy(kBusyMax); + } + _exit(0); + } + ASSERT_THAT(child, SyscallSucceeds()); + if (cur > 0) { + // We can have a zero cur here, depending. + ASSERT_TRUE(state.KeepRunningBatch(cur)); + } + children.push_back(child); + } +} + +BENCHMARK(BM_CPUBoundSymmetric)->Range(2, 16)->UseRealTime(); + +// Child routine for ProcessSwitch/ThreadSwitch. +// Reads from readfd and writes the result to writefd. +void SwitchChild(int readfd, int writefd) { + while (1) { + char buf; + int ret = ReadFd(readfd, &buf, 1); + if (ret == 0) { + break; + } + TEST_CHECK_MSG(ret == 1, "read failed"); + + ret = WriteFd(writefd, &buf, 1); + if (ret == -1) { + TEST_CHECK_MSG(errno == EPIPE, "unexpected write failure"); + break; + } + TEST_CHECK_MSG(ret == 1, "write failed"); + } +} + +// Send bytes in a loop through a series of pipes, each passing through a +// different process. 
+// +// Proc 0 Proc 1 +// * ----------> * +// ^ Pipe 1 | +// | | +// | Pipe 0 | Pipe 2 +// | | +// | | +// | Pipe 3 v +// * <---------- * +// Proc 3 Proc 2 +// +// This exercises context switching through multiple processes. +void BM_ProcessSwitch(benchmark::State& state) { + // Code below assumes there are at least two processes. + const int num_processes = state.range(0); + ASSERT_GE(num_processes, 2); + + std::vector children; + auto child_cleanup = Cleanup([&] { + for (const pid_t child : children) { + int status; + EXPECT_THAT(RetryEINTR(waitpid)(child, &status, 0), SyscallSucceeds()); + EXPECT_TRUE(WIFEXITED(status)); + EXPECT_EQ(0, WEXITSTATUS(status)); + } + }); + + // Must come after children, as the FDs must be closed before the children + // will exit. + std::vector read_fds; + std::vector write_fds; + + for (int i = 0; i < num_processes; i++) { + int fds[2]; + ASSERT_THAT(pipe(fds), SyscallSucceeds()); + read_fds.emplace_back(fds[0]); + write_fds.emplace_back(fds[1]); + } + + // This process is one of the processes in the loop. It will be considered + // index 0. + for (int i = 1; i < num_processes; i++) { + // Read from current pipe index, write to next. + const int read_index = i; + const int read_fd = read_fds[read_index].get(); + + const int write_index = (i + 1) % num_processes; + const int write_fd = write_fds[write_index].get(); + + // std::vector isn't safe to use from the fork child. + FileDescriptor* read_array = read_fds.data(); + FileDescriptor* write_array = write_fds.data(); + + pid_t child = fork(); + if (!child) { + // Close all other FDs. + for (int j = 0; j < num_processes; j++) { + if (j != read_index) { + read_array[j].reset(); + } + if (j != write_index) { + write_array[j].reset(); + } + } + + SwitchChild(read_fd, write_fd); + _exit(0); + } + ASSERT_THAT(child, SyscallSucceeds()); + children.push_back(child); + } + + // Read from current pipe index (0), write to next (1). + const int read_index = 0; + const int read_fd = read_fds[read_index].get(); + + const int write_index = 1; + const int write_fd = write_fds[write_index].get(); + + // Kick start the loop. + char buf = 'a'; + ASSERT_THAT(WriteFd(write_fd, &buf, 1), SyscallSucceedsWithValue(1)); + + for (auto _ : state) { + ASSERT_THAT(ReadFd(read_fd, &buf, 1), SyscallSucceedsWithValue(1)); + ASSERT_THAT(WriteFd(write_fd, &buf, 1), SyscallSucceedsWithValue(1)); + } +} + +BENCHMARK(BM_ProcessSwitch)->Range(2, 16)->UseRealTime(); + +// Equivalent to BM_ThreadSwitch using threads instead of processes. +void BM_ThreadSwitch(benchmark::State& state) { + // Code below assumes there are at least two threads. + const int num_threads = state.range(0); + ASSERT_GE(num_threads, 2); + + // Must come after threads, as the FDs must be closed before the children + // will exit. + std::vector> threads; + std::vector read_fds; + std::vector write_fds; + + for (int i = 0; i < num_threads; i++) { + int fds[2]; + ASSERT_THAT(pipe(fds), SyscallSucceeds()); + read_fds.emplace_back(fds[0]); + write_fds.emplace_back(fds[1]); + } + + // This thread is one of the threads in the loop. It will be considered + // index 0. + for (int i = 1; i < num_threads; i++) { + // Read from current pipe index, write to next. + // + // Transfer ownership of the FDs to the thread. 
+ const int read_index = i; + const int read_fd = read_fds[read_index].release(); + + const int write_index = (i + 1) % num_threads; + const int write_fd = write_fds[write_index].release(); + + threads.emplace_back(std::make_unique([read_fd, write_fd] { + FileDescriptor read(read_fd); + FileDescriptor write(write_fd); + SwitchChild(read.get(), write.get()); + })); + } + + // Read from current pipe index (0), write to next (1). + const int read_index = 0; + const int read_fd = read_fds[read_index].get(); + + const int write_index = 1; + const int write_fd = write_fds[write_index].get(); + + // Kick start the loop. + char buf = 'a'; + ASSERT_THAT(WriteFd(write_fd, &buf, 1), SyscallSucceedsWithValue(1)); + + for (auto _ : state) { + ASSERT_THAT(ReadFd(read_fd, &buf, 1), SyscallSucceedsWithValue(1)); + ASSERT_THAT(WriteFd(write_fd, &buf, 1), SyscallSucceedsWithValue(1)); + } + + // The two FDs still owned by this thread are closed, causing the next thread + // to exit its loop and close its FDs, and so on until all threads exit. +} + +BENCHMARK(BM_ThreadSwitch)->Range(2, 16)->UseRealTime(); + +void BM_ThreadStart(benchmark::State& state) { + const int num_threads = state.range(0); + + for (auto _ : state) { + state.PauseTiming(); + + auto barrier = new absl::Barrier(num_threads + 1); + std::vector> threads; + + state.ResumeTiming(); + + for (size_t i = 0; i < num_threads; ++i) { + threads.emplace_back(std::make_unique([barrier] { + if (barrier->Block()) { + delete barrier; + } + })); + } + + if (barrier->Block()) { + delete barrier; + } + + state.PauseTiming(); + + for (const auto& thread : threads) { + thread->Join(); + } + + state.ResumeTiming(); + } +} + +BENCHMARK(BM_ThreadStart)->Range(1, 2048)->UseRealTime(); + +// Benchmark the complete fork + exit + wait. +void BM_ProcessLifecycle(benchmark::State& state) { + const int num_procs = state.range(0); + + std::vector pids(num_procs); + for (auto _ : state) { + for (size_t i = 0; i < num_procs; ++i) { + int pid = fork(); + if (pid == 0) { + _exit(0); + } + ASSERT_THAT(pid, SyscallSucceeds()); + pids[i] = pid; + } + + for (const int pid : pids) { + ASSERT_THAT(RetryEINTR(waitpid)(pid, nullptr, 0), + SyscallSucceedsWithValue(pid)); + } + } +} + +BENCHMARK(BM_ProcessLifecycle)->Range(1, 512)->UseRealTime(); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/futex_benchmark.cc b/test/perf/linux/futex_benchmark.cc new file mode 100644 index 000000000..b349d50bf --- /dev/null +++ b/test/perf/linux/futex_benchmark.cc @@ -0,0 +1,248 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#include <errno.h>
+#include <linux/futex.h>
+#include <sys/syscall.h>
+#include <unistd.h>
+
+#include <atomic>
+#include <cstdint>
+#include <ctime>
+
+#include "gtest/gtest.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+#include "benchmark/benchmark.h"
+#include "test/util/logging.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+inline int FutexWait(std::atomic<int32_t>* v, int32_t val) {
+  return syscall(SYS_futex, v, FUTEX_WAIT_PRIVATE, val, nullptr);
+}
+
+inline int FutexWaitRelativeTimeout(std::atomic<int32_t>* v, int32_t val,
+                                    const struct timespec* reltime) {
+  return syscall(SYS_futex, v, FUTEX_WAIT_PRIVATE, val, reltime);
+}
+
+inline int FutexWaitAbsoluteTimeout(std::atomic<int32_t>* v, int32_t val,
+                                    const struct timespec* abstime) {
+  return syscall(SYS_futex, v,
+                 FUTEX_WAIT_BITSET_PRIVATE | FUTEX_CLOCK_REALTIME, val,
+                 abstime, nullptr, FUTEX_BITSET_MATCH_ANY);
+}
+
+inline int FutexWaitBitsetAbsoluteTimeout(std::atomic<int32_t>* v, int32_t val,
+                                          int32_t bits,
+                                          const struct timespec* abstime) {
+  return syscall(SYS_futex, v,
+                 FUTEX_WAIT_BITSET_PRIVATE | FUTEX_CLOCK_REALTIME, val,
+                 abstime, nullptr, bits);
+}
+
+inline int FutexWake(std::atomic<int32_t>* v, int32_t count) {
+  return syscall(SYS_futex, v, FUTEX_WAKE_PRIVATE, count);
+}
+
+// This just uses FUTEX_WAKE on an address with nothing waiting, very simple.
+void BM_FutexWakeNop(benchmark::State& state) {
+  std::atomic<int32_t> v(0);
+
+  for (auto _ : state) {
+    EXPECT_EQ(0, FutexWake(&v, 1));
+  }
+}
+
+BENCHMARK(BM_FutexWakeNop);
+
+// This just uses FUTEX_WAIT on an address whose value has changed, i.e., the
+// syscall won't wait.
+void BM_FutexWaitNop(benchmark::State& state) {
+  std::atomic<int32_t> v(0);
+
+  for (auto _ : state) {
+    EXPECT_EQ(-1, FutexWait(&v, 1));
+    EXPECT_EQ(EAGAIN, errno);
+  }
+}
+
+BENCHMARK(BM_FutexWaitNop);
+
+// This uses FUTEX_WAIT with a timeout on an address whose value never
+// changes, such that it always times out. Timeout overhead can be estimated by
+// timer overruns for short timeouts.
+void BM_FutexWaitTimeout(benchmark::State& state) {
+  const int timeout_ns = state.range(0);
+  std::atomic<int32_t> v(0);
+  auto ts = absl::ToTimespec(absl::Nanoseconds(timeout_ns));
+
+  for (auto _ : state) {
+    EXPECT_EQ(-1, FutexWaitRelativeTimeout(&v, 0, &ts));
+    EXPECT_EQ(ETIMEDOUT, errno);
+  }
+}
+
+BENCHMARK(BM_FutexWaitTimeout)
+    ->Arg(1)
+    ->Arg(10)
+    ->Arg(100)
+    ->Arg(1000)
+    ->Arg(10000);
+
+// This calls FUTEX_WAIT_BITSET with CLOCK_REALTIME.
+void BM_FutexWaitBitset(benchmark::State& state) {
+  std::atomic<int32_t> v(0);
+  int timeout_ns = state.range(0);
+  auto ts = absl::ToTimespec(absl::Nanoseconds(timeout_ns));
+  for (auto _ : state) {
+    EXPECT_EQ(-1, FutexWaitBitsetAbsoluteTimeout(&v, 0, 1, &ts));
+    EXPECT_EQ(ETIMEDOUT, errno);
+  }
+}
+
+BENCHMARK(BM_FutexWaitBitset)->Range(0, 100000);
+
+int64_t GetCurrentMonotonicTimeNanos() {
+  struct timespec ts;
+  TEST_CHECK(clock_gettime(CLOCK_MONOTONIC, &ts) != -1);
+  return ts.tv_sec * 1000000000ULL + ts.tv_nsec;
+}
+
+void SpinNanos(int64_t delay_ns) {
+  if (delay_ns <= 0) {
+    return;
+  }
+  const int64_t end = GetCurrentMonotonicTimeNanos() + delay_ns;
+  while (GetCurrentMonotonicTimeNanos() < end) {
+    // spin
+  }
+}
+
+// Each iteration of FutexRoundtripDelayed involves a thread sending a futex
+// wakeup to another thread, which spins for delay_us and then sends a futex
+// wakeup back. The time per iteration is 2 * (delay_us + kBeforeWakeDelayNs +
+// futex/scheduling overhead).
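+//
+// Roughly, with v as the shared futex word, each iteration is the handshake:
+//   main:   spin; v = 1; FUTEX_WAKE  ->  helper wakes, sees v == 1
+//   helper: spin; v = 0; FUTEX_WAKE  ->  main wakes, sees v == 0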
+void BM_FutexRoundtripDelayed(benchmark::State& state) { + const int delay_us = state.range(0); + + const int64_t delay_ns = delay_us * 1000; + // Spin for an extra kBeforeWakeDelayNs before invoking FUTEX_WAKE to reduce + // the probability that the wakeup comes before the wait, preventing the wait + // from ever taking effect and causing the benchmark to underestimate the + // actual wakeup time. + constexpr int64_t kBeforeWakeDelayNs = 500; + std::atomic v(0); + ScopedThread t([&] { + for (int i = 0; i < state.max_iterations; i++) { + SpinNanos(delay_ns); + while (v.load(std::memory_order_acquire) == 0) { + FutexWait(&v, 0); + } + SpinNanos(kBeforeWakeDelayNs + delay_ns); + v.store(0, std::memory_order_release); + FutexWake(&v, 1); + } + }); + for (auto _ : state) { + SpinNanos(kBeforeWakeDelayNs + delay_ns); + v.store(1, std::memory_order_release); + FutexWake(&v, 1); + SpinNanos(delay_ns); + while (v.load(std::memory_order_acquire) == 1) { + FutexWait(&v, 1); + } + } +} + +BENCHMARK(BM_FutexRoundtripDelayed) + ->Arg(0) + ->Arg(10) + ->Arg(20) + ->Arg(50) + ->Arg(100); + +// FutexLock is a simple, dumb futex based lock implementation. +// It will try to acquire the lock by atomically incrementing the +// lock word. If it did not increment the lock from 0 to 1, someone +// else has the lock, so it will FUTEX_WAIT until it is woken in +// the unlock path. +class FutexLock { + public: + FutexLock() : lock_word_(0) {} + + void lock(struct timespec* deadline) { + int32_t val; + while ((val = lock_word_.fetch_add(1, std::memory_order_acquire) + 1) != + 1) { + // If we didn't get the lock by incrementing from 0 to 1, + // do a FUTEX_WAIT with the desired current value set to + // val. If val is no longer what the atomic increment returned, + // someone might have set it to 0 so we can try to acquire + // again. + int ret = FutexWaitAbsoluteTimeout(&lock_word_, val, deadline); + if (ret == 0 || ret == -EWOULDBLOCK || ret == -EINTR) { + continue; + } else { + FAIL() << "unexpected FUTEX_WAIT return: " << ret; + } + } + } + + void unlock() { + // Store 0 into the lock word and wake one waiter. We intentionally + // ignore the return value of the FUTEX_WAKE here, since there may be + // no waiters to wake anyway. + lock_word_.store(0, std::memory_order_release); + (void)FutexWake(&lock_word_, 1); + } + + private: + std::atomic lock_word_; +}; + +FutexLock* test_lock; // Used below. + +void FutexContend(benchmark::State& state, int thread_index, + struct timespec* deadline) { + int counter = 0; + if (thread_index == 0) { + test_lock = new FutexLock(); + } + for (auto _ : state) { + test_lock->lock(deadline); + counter++; + test_lock->unlock(); + } + if (thread_index == 0) { + delete test_lock; + } + state.SetItemsProcessed(state.iterations()); +} + +void BM_FutexContend(benchmark::State& state) { + FutexContend(state, state.thread_index, nullptr); +} + +BENCHMARK(BM_FutexContend)->ThreadRange(1, 1024)->UseRealTime(); + +void BM_FutexDeadlineContend(benchmark::State& state) { + auto deadline = absl::ToTimespec(absl::Now() + absl::Minutes(10)); + FutexContend(state, state.thread_index, &deadline); +} + +BENCHMARK(BM_FutexDeadlineContend)->ThreadRange(1, 1024)->UseRealTime(); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/getdents_benchmark.cc b/test/perf/linux/getdents_benchmark.cc new file mode 100644 index 000000000..0e03975b4 --- /dev/null +++ b/test/perf/linux/getdents_benchmark.cc @@ -0,0 +1,149 @@ +// Copyright 2020 The gVisor Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "gtest/gtest.h" +#include "benchmark/benchmark.h" +#include "test/util/file_descriptor.h" +#include "test/util/fs_util.h" +#include "test/util/temp_path.h" +#include "test/util/test_util.h" + +#ifndef SYS_getdents64 +#if defined(__x86_64__) +#define SYS_getdents64 217 +#elif defined(__aarch64__) +#define SYS_getdents64 217 +#else +#error "Unknown architecture" +#endif +#endif // SYS_getdents64 + +namespace gvisor { +namespace testing { + +namespace { + +constexpr int kBufferSize = 16384; + +PosixErrorOr CreateDirectory(int count, + std::vector* files) { + ASSIGN_OR_RETURN_ERRNO(TempPath dir, TempPath::CreateDir()); + + ASSIGN_OR_RETURN_ERRNO(FileDescriptor dfd, + Open(dir.path(), O_RDONLY | O_DIRECTORY)); + + for (int i = 0; i < count; i++) { + auto file = NewTempRelPath(); + auto res = MknodAt(dfd, file, S_IFREG | 0644, 0); + RETURN_IF_ERRNO(res); + files->push_back(file); + } + + return std::move(dir); +} + +PosixError CleanupDirectory(const TempPath& dir, + std::vector* files) { + ASSIGN_OR_RETURN_ERRNO(FileDescriptor dfd, + Open(dir.path(), O_RDONLY | O_DIRECTORY)); + + for (auto it = files->begin(); it != files->end(); ++it) { + auto res = UnlinkAt(dfd, *it, 0); + RETURN_IF_ERRNO(res); + } + return NoError(); +} + +// Creates a directory containing `files` files, and reads all the directory +// entries from the directory using a single FD. +void BM_GetdentsSameFD(benchmark::State& state) { + // Create directory with given files. + const int count = state.range(0); + + // Keep a vector of all of the file TempPaths that is destroyed before dir. + // + // Normally, we'd simply allow dir to recursively clean up the contained + // files, but that recursive cleanup uses getdents, which may be very slow in + // extreme benchmarks. + TempPath dir; + std::vector files; + dir = ASSERT_NO_ERRNO_AND_VALUE(CreateDirectory(count, &files)); + + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path(), O_RDONLY | O_DIRECTORY)); + char buffer[kBufferSize]; + + // We read all directory entries on each iteration, but report this as a + // "batch" iteration so that reported times are per file. + while (state.KeepRunningBatch(count)) { + ASSERT_THAT(lseek(fd.get(), 0, SEEK_SET), SyscallSucceeds()); + + int ret; + do { + ASSERT_THAT(ret = syscall(SYS_getdents64, fd.get(), buffer, kBufferSize), + SyscallSucceeds()); + } while (ret > 0); + } + + ASSERT_NO_ERRNO(CleanupDirectory(dir, &files)); + + state.SetItemsProcessed(state.iterations()); +} + +BENCHMARK(BM_GetdentsSameFD)->Range(1, 1 << 16)->UseRealTime(); + +// Creates a directory containing `files` files, and reads all the directory +// entries from the directory using a new FD each time. +void BM_GetdentsNewFD(benchmark::State& state) { + // Create directory with given files. + const int count = state.range(0); + + // Keep a vector of all of the file TempPaths that is destroyed before dir. 
+ // + // Normally, we'd simply allow dir to recursively clean up the contained + // files, but that recursive cleanup uses getdents, which may be very slow in + // extreme benchmarks. + TempPath dir; + std::vector files; + dir = ASSERT_NO_ERRNO_AND_VALUE(CreateDirectory(count, &files)); + char buffer[kBufferSize]; + + // We read all directory entries on each iteration, but report this as a + // "batch" iteration so that reported times are per file. + while (state.KeepRunningBatch(count)) { + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path(), O_RDONLY | O_DIRECTORY)); + + int ret; + do { + ASSERT_THAT(ret = syscall(SYS_getdents64, fd.get(), buffer, kBufferSize), + SyscallSucceeds()); + } while (ret > 0); + } + + ASSERT_NO_ERRNO(CleanupDirectory(dir, &files)); + + state.SetItemsProcessed(state.iterations()); +} + +BENCHMARK(BM_GetdentsNewFD)->Range(1, 1 << 16)->UseRealTime(); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/getpid_benchmark.cc b/test/perf/linux/getpid_benchmark.cc new file mode 100644 index 000000000..db74cb264 --- /dev/null +++ b/test/perf/linux/getpid_benchmark.cc @@ -0,0 +1,37 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "gtest/gtest.h" +#include "benchmark/benchmark.h" + +namespace gvisor { +namespace testing { + +namespace { + +void BM_Getpid(benchmark::State& state) { + for (auto _ : state) { + syscall(SYS_getpid); + } +} + +BENCHMARK(BM_Getpid); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/gettid_benchmark.cc b/test/perf/linux/gettid_benchmark.cc new file mode 100644 index 000000000..8f4961f5e --- /dev/null +++ b/test/perf/linux/gettid_benchmark.cc @@ -0,0 +1,38 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include + +#include "gtest/gtest.h" +#include "benchmark/benchmark.h" + +namespace gvisor { +namespace testing { + +namespace { + +void BM_Gettid(benchmark::State& state) { + for (auto _ : state) { + syscall(SYS_gettid); + } +} + +BENCHMARK(BM_Gettid)->ThreadRange(1, 4000)->UseRealTime(); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/mapping_benchmark.cc b/test/perf/linux/mapping_benchmark.cc new file mode 100644 index 000000000..39c30fe69 --- /dev/null +++ b/test/perf/linux/mapping_benchmark.cc @@ -0,0 +1,163 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "gtest/gtest.h" +#include "benchmark/benchmark.h" +#include "test/util/logging.h" +#include "test/util/memory_util.h" +#include "test/util/posix_error.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +// Conservative value for /proc/sys/vm/max_map_count, which limits the number of +// VMAs, minus a safety margin for VMAs that already exist for the test binary. +// The default value for max_map_count is +// include/linux/mm.h:DEFAULT_MAX_MAP_COUNT = 65530. +constexpr size_t kMaxVMAs = 64001; + +// Map then unmap pages without touching them. +void BM_MapUnmap(benchmark::State& state) { + // Number of pages to map. + const int pages = state.range(0); + + while (state.KeepRunning()) { + void* addr = mmap(0, pages * kPageSize, PROT_READ | PROT_WRITE, + MAP_PRIVATE | MAP_ANONYMOUS, -1, 0); + TEST_CHECK_MSG(addr != MAP_FAILED, "mmap failed"); + + int ret = munmap(addr, pages * kPageSize); + TEST_CHECK_MSG(ret == 0, "munmap failed"); + } +} + +BENCHMARK(BM_MapUnmap)->Range(1, 1 << 17)->UseRealTime(); + +// Map, touch, then unmap pages. +void BM_MapTouchUnmap(benchmark::State& state) { + // Number of pages to map. + const int pages = state.range(0); + + while (state.KeepRunning()) { + void* addr = mmap(0, pages * kPageSize, PROT_READ | PROT_WRITE, + MAP_PRIVATE | MAP_ANONYMOUS, -1, 0); + TEST_CHECK_MSG(addr != MAP_FAILED, "mmap failed"); + + char* c = reinterpret_cast(addr); + char* end = c + pages * kPageSize; + while (c < end) { + *c = 42; + c += kPageSize; + } + + int ret = munmap(addr, pages * kPageSize); + TEST_CHECK_MSG(ret == 0, "munmap failed"); + } +} + +BENCHMARK(BM_MapTouchUnmap)->Range(1, 1 << 17)->UseRealTime(); + +// Map and touch many pages, unmapping all at once. +// +// NOTE(b/111429208): This is a regression test to ensure performant mapping and +// allocation even with tons of mappings. +void BM_MapTouchMany(benchmark::State& state) { + // Number of pages to map. 
+ const int page_count = state.range(0); + + while (state.KeepRunning()) { + std::vector pages; + + for (int i = 0; i < page_count; i++) { + void* addr = mmap(nullptr, kPageSize, PROT_READ | PROT_WRITE, + MAP_PRIVATE | MAP_ANONYMOUS, -1, 0); + TEST_CHECK_MSG(addr != MAP_FAILED, "mmap failed"); + + char* c = reinterpret_cast(addr); + *c = 42; + + pages.push_back(addr); + } + + for (void* addr : pages) { + int ret = munmap(addr, kPageSize); + TEST_CHECK_MSG(ret == 0, "munmap failed"); + } + } + + state.SetBytesProcessed(kPageSize * page_count * state.iterations()); +} + +BENCHMARK(BM_MapTouchMany)->Range(1, 1 << 12)->UseRealTime(); + +void BM_PageFault(benchmark::State& state) { + // Map the region in which we will take page faults. To ensure that each page + // fault maps only a single page, each page we touch must correspond to a + // distinct VMA. Thus we need a 1-page gap between each 1-page VMA. However, + // each gap consists of a PROT_NONE VMA, instead of an unmapped hole, so that + // if there are background threads running, they can't inadvertently creating + // mappings in our gaps that are unmapped when the test ends. + size_t test_pages = kMaxVMAs; + // Ensure that test_pages is odd, since we want the test region to both + // begin and end with a mapped page. + if (test_pages % 2 == 0) { + test_pages--; + } + const size_t test_region_bytes = test_pages * kPageSize; + // Use MAP_SHARED here because madvise(MADV_DONTNEED) on private mappings on + // gVisor won't force future sentry page faults (by design). Use MAP_POPULATE + // so that Linux pre-allocates the shmem file used to back the mapping. + Mapping m = ASSERT_NO_ERRNO_AND_VALUE( + MmapAnon(test_region_bytes, PROT_READ, MAP_SHARED | MAP_POPULATE)); + for (size_t i = 0; i < test_pages / 2; i++) { + ASSERT_THAT( + mprotect(reinterpret_cast(m.addr() + ((2 * i + 1) * kPageSize)), + kPageSize, PROT_NONE), + SyscallSucceeds()); + } + + const size_t mapped_pages = test_pages / 2 + 1; + // "Start" at the end of the mapped region to force the mapped region to be + // reset, since we mapped it with MAP_POPULATE. + size_t cur_page = mapped_pages; + for (auto _ : state) { + if (cur_page >= mapped_pages) { + // We've reached the end of our mapped region and have to reset it to + // incur page faults again. + state.PauseTiming(); + ASSERT_THAT(madvise(m.ptr(), test_region_bytes, MADV_DONTNEED), + SyscallSucceeds()); + cur_page = 0; + state.ResumeTiming(); + } + const uintptr_t addr = m.addr() + (2 * cur_page * kPageSize); + const char c = *reinterpret_cast(addr); + benchmark::DoNotOptimize(c); + cur_page++; + } +} + +BENCHMARK(BM_PageFault)->UseRealTime(); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/open_benchmark.cc b/test/perf/linux/open_benchmark.cc new file mode 100644 index 000000000..68008f6d5 --- /dev/null +++ b/test/perf/linux/open_benchmark.cc @@ -0,0 +1,56 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include + +#include +#include +#include + +#include "gtest/gtest.h" +#include "benchmark/benchmark.h" +#include "test/util/fs_util.h" +#include "test/util/logging.h" +#include "test/util/temp_path.h" + +namespace gvisor { +namespace testing { + +namespace { + +void BM_Open(benchmark::State& state) { + const int size = state.range(0); + std::vector cache; + for (int i = 0; i < size; i++) { + auto path = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + cache.emplace_back(std::move(path)); + } + + unsigned int seed = 1; + for (auto _ : state) { + const int chosen = rand_r(&seed) % size; + int fd = open(cache[chosen].path().c_str(), O_RDONLY); + TEST_CHECK(fd != -1); + close(fd); + } +} + +BENCHMARK(BM_Open)->Range(1, 128)->UseRealTime(); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/pipe_benchmark.cc b/test/perf/linux/pipe_benchmark.cc new file mode 100644 index 000000000..8f5f6a2a3 --- /dev/null +++ b/test/perf/linux/pipe_benchmark.cc @@ -0,0 +1,66 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include + +#include "gtest/gtest.h" +#include "benchmark/benchmark.h" +#include "test/util/logging.h" +#include "test/util/test_util.h" +#include "test/util/thread_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +void BM_Pipe(benchmark::State& state) { + int fds[2]; + TEST_CHECK(pipe(fds) == 0); + + const int size = state.range(0); + std::vector wbuf(size); + std::vector rbuf(size); + RandomizeBuffer(wbuf.data(), size); + + ScopedThread t([&] { + auto const fd = fds[1]; + for (int i = 0; i < state.max_iterations; i++) { + TEST_CHECK(WriteFd(fd, wbuf.data(), wbuf.size()) == size); + } + }); + + for (auto _ : state) { + TEST_CHECK(ReadFd(fds[0], rbuf.data(), rbuf.size()) == size); + } + + t.Join(); + + close(fds[0]); + close(fds[1]); + + state.SetBytesProcessed(static_cast(size) * + static_cast(state.iterations())); +} + +BENCHMARK(BM_Pipe)->Range(1, 1 << 20)->UseRealTime(); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/randread_benchmark.cc b/test/perf/linux/randread_benchmark.cc new file mode 100644 index 000000000..b0eb8c24e --- /dev/null +++ b/test/perf/linux/randread_benchmark.cc @@ -0,0 +1,100 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include +#include +#include + +#include "gtest/gtest.h" +#include "benchmark/benchmark.h" +#include "test/util/file_descriptor.h" +#include "test/util/logging.h" +#include "test/util/temp_path.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +// Create a 1GB file that will be read from at random positions. This should +// invalid any performance gains from caching. +const uint64_t kFileSize = 1ULL << 30; + +// How many bytes to write at once to initialize the file used to read from. +const uint32_t kWriteSize = 65536; + +// Largest benchmarked read unit. +const uint32_t kMaxRead = 1UL << 26; + +TempPath CreateFile(uint64_t file_size) { + auto path = TempPath::CreateFile().ValueOrDie(); + FileDescriptor fd = Open(path.path(), O_WRONLY).ValueOrDie(); + + // Try to minimize syscalls by using maximum size writev() requests. + std::vector buffer(kWriteSize); + RandomizeBuffer(buffer.data(), buffer.size()); + const std::vector> iovecs_list = + GenerateIovecs(file_size, buffer.data(), buffer.size()); + for (const auto& iovecs : iovecs_list) { + TEST_CHECK(writev(fd.get(), iovecs.data(), iovecs.size()) >= 0); + } + + return path; +} + +// Global test state, initialized once per process lifetime. +struct GlobalState { + const TempPath tmpfile; + explicit GlobalState(TempPath tfile) : tmpfile(std::move(tfile)) {} +}; + +GlobalState& GetGlobalState() { + // This gets created only once throughout the lifetime of the process. + // Use a dynamically allocated object (that is never deleted) to avoid order + // of destruction of static storage variables issues. + static GlobalState* const state = + // The actual file size is the maximum random seek range (kFileSize) + the + // maximum read size so we can read that number of bytes at the end of the + // file. + new GlobalState(CreateFile(kFileSize + kMaxRead)); + return *state; +} + +void BM_RandRead(benchmark::State& state) { + const int size = state.range(0); + + GlobalState& global_state = GetGlobalState(); + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(global_state.tmpfile.path(), O_RDONLY)); + std::vector buf(size); + + unsigned int seed = 1; + for (auto _ : state) { + TEST_CHECK(PreadFd(fd.get(), buf.data(), buf.size(), + rand_r(&seed) % kFileSize) == size); + } + + state.SetBytesProcessed(static_cast(size) * + static_cast(state.iterations())); +} + +BENCHMARK(BM_RandRead)->Range(1, kMaxRead)->UseRealTime(); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/read_benchmark.cc b/test/perf/linux/read_benchmark.cc new file mode 100644 index 000000000..62445867d --- /dev/null +++ b/test/perf/linux/read_benchmark.cc @@ -0,0 +1,53 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include +#include + +#include "gtest/gtest.h" +#include "benchmark/benchmark.h" +#include "test/util/fs_util.h" +#include "test/util/logging.h" +#include "test/util/temp_path.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +void BM_Read(benchmark::State& state) { + const int size = state.range(0); + const std::string contents(size, 0); + auto path = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( + GetAbsoluteTestTmpdir(), contents, TempPath::kDefaultFileMode)); + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path.path(), O_RDONLY)); + + std::vector buf(size); + for (auto _ : state) { + TEST_CHECK(PreadFd(fd.get(), buf.data(), buf.size(), 0) == size); + } + + state.SetBytesProcessed(static_cast(size) * + static_cast(state.iterations())); +} + +BENCHMARK(BM_Read)->Range(1, 1 << 26)->UseRealTime(); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/sched_yield_benchmark.cc b/test/perf/linux/sched_yield_benchmark.cc new file mode 100644 index 000000000..6756b5575 --- /dev/null +++ b/test/perf/linux/sched_yield_benchmark.cc @@ -0,0 +1,37 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "gtest/gtest.h" +#include "benchmark/benchmark.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +void BM_Sched_yield(benchmark::State& state) { + for (auto ignored : state) { + TEST_CHECK(sched_yield() == 0); + } +} + +BENCHMARK(BM_Sched_yield)->ThreadRange(1, 2000)->UseRealTime(); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/send_recv_benchmark.cc b/test/perf/linux/send_recv_benchmark.cc new file mode 100644 index 000000000..d73e49523 --- /dev/null +++ b/test/perf/linux/send_recv_benchmark.cc @@ -0,0 +1,372 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include +#include +#include + +#include + +#include "gtest/gtest.h" +#include "absl/synchronization/notification.h" +#include "benchmark/benchmark.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/util/file_descriptor.h" +#include "test/util/logging.h" +#include "test/util/posix_error.h" +#include "test/util/test_util.h" +#include "test/util/thread_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +constexpr ssize_t kMessageSize = 1024; + +class Message { + public: + explicit Message(int byte = 0) : Message(byte, kMessageSize, 0) {} + + explicit Message(int byte, int sz) : Message(byte, sz, 0) {} + + explicit Message(int byte, int sz, int cmsg_sz) + : buffer_(sz, byte), cmsg_buffer_(cmsg_sz, 0) { + iov_.iov_base = buffer_.data(); + iov_.iov_len = sz; + hdr_.msg_iov = &iov_; + hdr_.msg_iovlen = 1; + hdr_.msg_control = cmsg_buffer_.data(); + hdr_.msg_controllen = cmsg_sz; + } + + struct msghdr* header() { + return &hdr_; + } + + private: + std::vector buffer_; + std::vector cmsg_buffer_; + struct iovec iov_ = {}; + struct msghdr hdr_ = {}; +}; + +void BM_Recvmsg(benchmark::State& state) { + int sockets[2]; + TEST_CHECK(socketpair(AF_UNIX, SOCK_STREAM, 0, sockets) == 0); + FileDescriptor send_socket(sockets[0]), recv_socket(sockets[1]); + absl::Notification notification; + Message send_msg('a'), recv_msg; + + ScopedThread t([&send_msg, &send_socket, ¬ification] { + while (!notification.HasBeenNotified()) { + sendmsg(send_socket.get(), send_msg.header(), 0); + } + }); + + int64_t bytes_received = 0; + for (auto ignored : state) { + int n = recvmsg(recv_socket.get(), recv_msg.header(), 0); + TEST_CHECK(n > 0); + bytes_received += n; + } + + notification.Notify(); + recv_socket.reset(); + + state.SetBytesProcessed(bytes_received); +} + +BENCHMARK(BM_Recvmsg)->UseRealTime(); + +void BM_Sendmsg(benchmark::State& state) { + int sockets[2]; + TEST_CHECK(socketpair(AF_UNIX, SOCK_STREAM, 0, sockets) == 0); + FileDescriptor send_socket(sockets[0]), recv_socket(sockets[1]); + absl::Notification notification; + Message send_msg('a'), recv_msg; + + ScopedThread t([&recv_msg, &recv_socket, ¬ification] { + while (!notification.HasBeenNotified()) { + recvmsg(recv_socket.get(), recv_msg.header(), 0); + } + }); + + int64_t bytes_sent = 0; + for (auto ignored : state) { + int n = sendmsg(send_socket.get(), send_msg.header(), 0); + TEST_CHECK(n > 0); + bytes_sent += n; + } + + notification.Notify(); + send_socket.reset(); + + state.SetBytesProcessed(bytes_sent); +} + +BENCHMARK(BM_Sendmsg)->UseRealTime(); + +void BM_Recvfrom(benchmark::State& state) { + int sockets[2]; + TEST_CHECK(socketpair(AF_UNIX, SOCK_STREAM, 0, sockets) == 0); + FileDescriptor send_socket(sockets[0]), recv_socket(sockets[1]); + absl::Notification notification; + char send_buffer[kMessageSize], recv_buffer[kMessageSize]; + + ScopedThread t([&send_socket, &send_buffer, ¬ification] { + while (!notification.HasBeenNotified()) { + sendto(send_socket.get(), send_buffer, kMessageSize, 0, nullptr, 0); + } + }); + + int bytes_received = 0; + for (auto ignored : state) { + int n = recvfrom(recv_socket.get(), recv_buffer, kMessageSize, 0, nullptr, + nullptr); + TEST_CHECK(n > 0); + bytes_received += n; + } + + notification.Notify(); + recv_socket.reset(); + + state.SetBytesProcessed(bytes_received); +} + +BENCHMARK(BM_Recvfrom)->UseRealTime(); + +void BM_Sendto(benchmark::State& state) { + int sockets[2]; + TEST_CHECK(socketpair(AF_UNIX, SOCK_STREAM, 0, sockets) == 0); + FileDescriptor 
send_socket(sockets[0]), recv_socket(sockets[1]); + absl::Notification notification; + char send_buffer[kMessageSize], recv_buffer[kMessageSize]; + + ScopedThread t([&recv_socket, &recv_buffer, ¬ification] { + while (!notification.HasBeenNotified()) { + recvfrom(recv_socket.get(), recv_buffer, kMessageSize, 0, nullptr, + nullptr); + } + }); + + int64_t bytes_sent = 0; + for (auto ignored : state) { + int n = sendto(send_socket.get(), send_buffer, kMessageSize, 0, nullptr, 0); + TEST_CHECK(n > 0); + bytes_sent += n; + } + + notification.Notify(); + send_socket.reset(); + + state.SetBytesProcessed(bytes_sent); +} + +BENCHMARK(BM_Sendto)->UseRealTime(); + +PosixErrorOr InetLoopbackAddr(int family) { + struct sockaddr_storage addr; + memset(&addr, 0, sizeof(addr)); + addr.ss_family = family; + switch (family) { + case AF_INET: + reinterpret_cast(&addr)->sin_addr.s_addr = + htonl(INADDR_LOOPBACK); + break; + case AF_INET6: + reinterpret_cast(&addr)->sin6_addr = + in6addr_loopback; + break; + default: + return PosixError(EINVAL, + absl::StrCat("unknown socket family: ", family)); + } + return addr; +} + +// BM_RecvmsgWithControlBuf measures the performance of recvmsg when we allocate +// space for control messages. Note that we do not expect to receive any. +void BM_RecvmsgWithControlBuf(benchmark::State& state) { + auto listen_socket = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET6, SOCK_STREAM, IPPROTO_TCP)); + + // Initialize address to the loopback one. + sockaddr_storage addr = ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(AF_INET6)); + socklen_t addrlen = sizeof(addr); + + // Bind to some port then start listening. + ASSERT_THAT(bind(listen_socket.get(), + reinterpret_cast(&addr), addrlen), + SyscallSucceeds()); + + ASSERT_THAT(listen(listen_socket.get(), SOMAXCONN), SyscallSucceeds()); + + // Get the address we're listening on, then connect to it. We need to do this + // because we're allowing the stack to pick a port for us. + ASSERT_THAT(getsockname(listen_socket.get(), + reinterpret_cast(&addr), &addrlen), + SyscallSucceeds()); + + auto send_socket = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET6, SOCK_STREAM, IPPROTO_TCP)); + + ASSERT_THAT( + RetryEINTR(connect)(send_socket.get(), + reinterpret_cast(&addr), addrlen), + SyscallSucceeds()); + + // Accept the connection. + auto recv_socket = + ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_socket.get(), nullptr, nullptr)); + + absl::Notification notification; + Message send_msg('a'); + // Create a msghdr with a buffer allocated for control messages. + Message recv_msg(0, kMessageSize, /*cmsg_sz=*/24); + + ScopedThread t([&send_msg, &send_socket, ¬ification] { + while (!notification.HasBeenNotified()) { + sendmsg(send_socket.get(), send_msg.header(), 0); + } + }); + + int64_t bytes_received = 0; + for (auto ignored : state) { + int n = recvmsg(recv_socket.get(), recv_msg.header(), 0); + TEST_CHECK(n > 0); + bytes_received += n; + } + + notification.Notify(); + recv_socket.reset(); + + state.SetBytesProcessed(bytes_received); +} + +BENCHMARK(BM_RecvmsgWithControlBuf)->UseRealTime(); + +// BM_SendmsgTCP measures the sendmsg throughput with varying payload sizes. +// +// state.Args[0] indicates whether the underlying socket should be blocking or +// non-blocking w/ 0 indicating non-blocking and 1 to indicate blocking. +// state.Args[1] is the size of the payload to be used per sendmsg call. 
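+//
+// For example, Args({0, 1 << 16}) measures non-blocking sendmsg with 64 KiB
+// payloads, and Args({1, 1 << 20}) measures blocking sendmsg with 1 MiB
+// payloads (both pairs are generated by Args() below).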
+void BM_SendmsgTCP(benchmark::State& state) { + auto listen_socket = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_STREAM, IPPROTO_TCP)); + + // Initialize address to the loopback one. + sockaddr_storage addr = ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(AF_INET)); + socklen_t addrlen = sizeof(addr); + + // Bind to some port then start listening. + ASSERT_THAT(bind(listen_socket.get(), + reinterpret_cast(&addr), addrlen), + SyscallSucceeds()); + + ASSERT_THAT(listen(listen_socket.get(), SOMAXCONN), SyscallSucceeds()); + + // Get the address we're listening on, then connect to it. We need to do this + // because we're allowing the stack to pick a port for us. + ASSERT_THAT(getsockname(listen_socket.get(), + reinterpret_cast(&addr), &addrlen), + SyscallSucceeds()); + + auto send_socket = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_STREAM, IPPROTO_TCP)); + + ASSERT_THAT( + RetryEINTR(connect)(send_socket.get(), + reinterpret_cast(&addr), addrlen), + SyscallSucceeds()); + + // Accept the connection. + auto recv_socket = + ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_socket.get(), nullptr, nullptr)); + + // Check if we want to run the test w/ a blocking send socket + // or non-blocking. + const int blocking = state.range(0); + if (!blocking) { + // Set the send FD to O_NONBLOCK. + int opts; + ASSERT_THAT(opts = fcntl(send_socket.get(), F_GETFL), SyscallSucceeds()); + opts |= O_NONBLOCK; + ASSERT_THAT(fcntl(send_socket.get(), F_SETFL, opts), SyscallSucceeds()); + } + + absl::Notification notification; + + // Get the buffer size we should use for this iteration of the test. + const int buf_size = state.range(1); + Message send_msg('a', buf_size), recv_msg(0, buf_size); + + ScopedThread t([&recv_msg, &recv_socket, ¬ification] { + while (!notification.HasBeenNotified()) { + TEST_CHECK(recvmsg(recv_socket.get(), recv_msg.header(), 0) >= 0); + } + }); + + int64_t bytes_sent = 0; + int ncalls = 0; + for (auto ignored : state) { + int sent = 0; + while (true) { + struct msghdr hdr = {}; + struct iovec iov = {}; + struct msghdr* snd_header = send_msg.header(); + iov.iov_base = static_cast(snd_header->msg_iov->iov_base) + sent; + iov.iov_len = snd_header->msg_iov->iov_len - sent; + hdr.msg_iov = &iov; + hdr.msg_iovlen = 1; + int n = RetryEINTR(sendmsg)(send_socket.get(), &hdr, 0); + ncalls++; + if (n > 0) { + sent += n; + if (sent == buf_size) { + break; + } + // n can be > 0 but less than requested size. In which case we don't + // poll. + continue; + } + // Poll the fd for it to become writable. + struct pollfd poll_fd = {send_socket.get(), POLL_OUT, 0}; + EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, 10), + SyscallSucceedsWithValue(0)); + } + bytes_sent += static_cast(sent); + } + + notification.Notify(); + send_socket.reset(); + state.SetBytesProcessed(bytes_sent); +} + +void Args(benchmark::internal::Benchmark* benchmark) { + for (int blocking = 0; blocking < 2; blocking++) { + for (int buf_size = 1024; buf_size <= 256 << 20; buf_size *= 2) { + benchmark->Args({blocking, buf_size}); + } + } +} + +BENCHMARK(BM_SendmsgTCP)->Apply(&Args)->UseRealTime(); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/seqwrite_benchmark.cc b/test/perf/linux/seqwrite_benchmark.cc new file mode 100644 index 000000000..af49e4477 --- /dev/null +++ b/test/perf/linux/seqwrite_benchmark.cc @@ -0,0 +1,66 @@ +// Copyright 2020 The gVisor Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include "gtest/gtest.h" +#include "benchmark/benchmark.h" +#include "test/util/logging.h" +#include "test/util/temp_path.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +// The maximum file size of the test file, when writes get beyond this point +// they wrap around. This should be large enough to blow away caches. +const uint64_t kMaxFile = 1 << 30; + +// Perform writes of various sizes sequentially to one file. Wraps around if it +// goes above a certain maximum file size. +void BM_SeqWrite(benchmark::State& state) { + auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(f.path(), O_WRONLY)); + + const int size = state.range(0); + std::vector buf(size); + RandomizeBuffer(buf.data(), buf.size()); + + // Start writes at offset 0. + uint64_t offset = 0; + for (auto _ : state) { + TEST_CHECK(PwriteFd(fd.get(), buf.data(), buf.size(), offset) == + buf.size()); + offset += buf.size(); + // Wrap around if going above the maximum file size. + if (offset >= kMaxFile) { + offset = 0; + } + } + + state.SetBytesProcessed(static_cast(size) * + static_cast(state.iterations())); +} + +BENCHMARK(BM_SeqWrite)->Range(1, 1 << 26)->UseRealTime(); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/signal_benchmark.cc b/test/perf/linux/signal_benchmark.cc new file mode 100644 index 000000000..a6928df58 --- /dev/null +++ b/test/perf/linux/signal_benchmark.cc @@ -0,0 +1,59 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "gtest/gtest.h" +#include "benchmark/benchmark.h" +#include "test/util/logging.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +void FixupHandler(int sig, siginfo_t* si, void* void_ctx) { + static unsigned int dataval = 0; + + // Skip the offending instruction. + ucontext_t* ctx = reinterpret_cast(void_ctx); + ctx->uc_mcontext.gregs[REG_RAX] = reinterpret_cast(&dataval); +} + +void BM_FaultSignalFixup(benchmark::State& state) { + // Set up the signal handler. + struct sigaction sa = {}; + sigemptyset(&sa.sa_mask); + sa.sa_sigaction = FixupHandler; + sa.sa_flags = SA_SIGINFO; + TEST_CHECK(sigaction(SIGSEGV, &sa, nullptr) == 0); + + // Fault, fault, fault. 
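+  // Each store below faults because %rax holds nullptr; FixupHandler points
+  // %rax at valid memory instead, so the retried store succeeds and each
+  // iteration costs exactly one SIGSEGV round trip.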
+ for (auto _ : state) { + register volatile unsigned int* ptr asm("rax"); + + // Trigger the segfault. + ptr = nullptr; + *ptr = 0; + } +} + +BENCHMARK(BM_FaultSignalFixup)->UseRealTime(); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/sleep_benchmark.cc b/test/perf/linux/sleep_benchmark.cc new file mode 100644 index 000000000..99ef05117 --- /dev/null +++ b/test/perf/linux/sleep_benchmark.cc @@ -0,0 +1,60 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include "gtest/gtest.h" +#include "benchmark/benchmark.h" +#include "test/util/logging.h" + +namespace gvisor { +namespace testing { + +namespace { + +// Sleep for 'param' nanoseconds. +void BM_Sleep(benchmark::State& state) { + const int nanoseconds = state.range(0); + + for (auto _ : state) { + struct timespec ts; + ts.tv_sec = 0; + ts.tv_nsec = nanoseconds; + + int ret; + do { + ret = syscall(SYS_nanosleep, &ts, &ts); + if (ret < 0) { + TEST_CHECK(errno == EINTR); + } + } while (ret < 0); + } +} + +BENCHMARK(BM_Sleep) + ->Arg(0) + ->Arg(1) + ->Arg(1000) // 1us + ->Arg(1000 * 1000) // 1ms + ->Arg(10 * 1000 * 1000) // 10ms + ->Arg(50 * 1000 * 1000) // 50ms + ->UseRealTime(); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/stat_benchmark.cc b/test/perf/linux/stat_benchmark.cc new file mode 100644 index 000000000..f15424482 --- /dev/null +++ b/test/perf/linux/stat_benchmark.cc @@ -0,0 +1,62 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "gtest/gtest.h" +#include "absl/strings/str_cat.h" +#include "benchmark/benchmark.h" +#include "test/util/fs_util.h" +#include "test/util/temp_path.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +// Creates a file in a nested directory hierarchy at least `depth` directories +// deep, and stats that file multiple times. +void BM_Stat(benchmark::State& state) { + // Create nested directories with given depth. + int depth = state.range(0); + const TempPath top_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + std::string dir_path = top_dir.path(); + + while (depth-- > 0) { + // Don't use TempPath because it will make paths too long to use. + // + // The top_dir destructor will clean up this whole tree. 
+ dir_path = JoinPath(dir_path, absl::StrCat(depth)); + ASSERT_NO_ERRNO(Mkdir(dir_path, 0755)); + } + + // Create the file that will be stat'd. + const TempPath file = + ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir_path)); + + struct stat st; + for (auto _ : state) { + ASSERT_THAT(stat(file.path().c_str(), &st), SyscallSucceeds()); + } +} + +BENCHMARK(BM_Stat)->Range(1, 100)->UseRealTime(); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/unlink_benchmark.cc b/test/perf/linux/unlink_benchmark.cc new file mode 100644 index 000000000..92243a042 --- /dev/null +++ b/test/perf/linux/unlink_benchmark.cc @@ -0,0 +1,66 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "gtest/gtest.h" +#include "benchmark/benchmark.h" +#include "test/util/fs_util.h" +#include "test/util/temp_path.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +// Creates a directory containing `files` files, and unlinks all the files. +void BM_Unlink(benchmark::State& state) { + // Create directory with given files. + const int file_count = state.range(0); + + // We unlink all files on each iteration, but report this as a "batch" + // iteration so that reported times are per file. + TempPath dir; + while (state.KeepRunningBatch(file_count)) { + state.PauseTiming(); + // N.B. dir is declared outside the loop so that destruction of the previous + // iteration's directory occurs here, inside of PauseTiming. + dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + + std::vector files; + for (int i = 0; i < file_count; i++) { + TempPath file = + ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir.path())); + files.push_back(std::move(file)); + } + state.ResumeTiming(); + + while (!files.empty()) { + // Destructor unlinks. + files.pop_back(); + } + } + + state.SetItemsProcessed(state.iterations()); +} + +BENCHMARK(BM_Unlink)->Range(1, 100 * 1000)->UseRealTime(); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/write_benchmark.cc b/test/perf/linux/write_benchmark.cc new file mode 100644 index 000000000..7b060c70e --- /dev/null +++ b/test/perf/linux/write_benchmark.cc @@ -0,0 +1,52 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
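The perf benchmarks above are plain google-test/google-benchmark binaries; the Go runner added later in this patch selects them through the --benchmark_list_tests and --benchmark_filter flags. The sketch below is not part of the patch: it only illustrates driving one of these binaries by hand with the same flags, and the binary path and the BM_Unlink filter are assumptions for the example.

package main

import (
	"fmt"
	"os"
	"os/exec"
)

func main() {
	// Hypothetical path to a built benchmark binary such as unlink_benchmark.
	bin := "./unlink_benchmark"

	// List the registered benchmarks, as the runner does.
	out, err := exec.Command(bin, "--benchmark_list_tests").Output()
	if err != nil {
		fmt.Fprintf(os.Stderr, "listing benchmarks failed: %v\n", err)
		os.Exit(1)
	}
	fmt.Printf("benchmarks:\n%s", out)

	// Run a single benchmark while suppressing the regular gtest cases.
	cmd := exec.Command(bin, "--gtest_filter=^$", "--benchmark_filter=BM_Unlink")
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr
	if err := cmd.Run(); err != nil {
		fmt.Fprintf(os.Stderr, "benchmark run failed: %v\n", err)
		os.Exit(1)
	}
}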
+ +#include +#include +#include +#include + +#include "gtest/gtest.h" +#include "benchmark/benchmark.h" +#include "test/util/logging.h" +#include "test/util/temp_path.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +void BM_Write(benchmark::State& state) { + auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(f.path(), O_WRONLY)); + + const int size = state.range(0); + std::vector buf(size); + RandomizeBuffer(buf.data(), size); + + for (auto _ : state) { + TEST_CHECK(PwriteFd(fd.get(), buf.data(), size, 0) == size); + } + + state.SetBytesProcessed(static_cast(size) * + static_cast(state.iterations())); +} + +BENCHMARK(BM_Write)->Range(1, 1 << 26)->UseRealTime(); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/runner/BUILD b/test/runner/BUILD new file mode 100644 index 000000000..9959ef9b0 --- /dev/null +++ b/test/runner/BUILD @@ -0,0 +1,22 @@ +load("//tools:defs.bzl", "go_binary") + +package(licenses = ["notice"]) + +go_binary( + name = "runner", + testonly = 1, + srcs = ["runner.go"], + data = [ + "//runsc", + ], + visibility = ["//:sandbox"], + deps = [ + "//pkg/log", + "//runsc/specutils", + "//runsc/testutil", + "//test/runner/gtest", + "//test/uds", + "@com_github_opencontainers_runtime-spec//specs-go:go_default_library", + "@org_golang_x_sys//unix:go_default_library", + ], +) diff --git a/test/runner/defs.bzl b/test/runner/defs.bzl new file mode 100644 index 000000000..5e97c1867 --- /dev/null +++ b/test/runner/defs.bzl @@ -0,0 +1,218 @@ +"""Defines a rule for syscall test targets.""" + +load("//tools:defs.bzl", "loopback") + +def _runner_test_impl(ctx): + # Generate a runner binary. + runner = ctx.actions.declare_file("%s-runner" % ctx.label.name) + runner_content = "\n".join([ + "#!/bin/bash", + "set -euf -x -o pipefail", + "if [[ -n \"${TEST_UNDECLARED_OUTPUTS_DIR}\" ]]; then", + " mkdir -p \"${TEST_UNDECLARED_OUTPUTS_DIR}\"", + " chmod a+rwx \"${TEST_UNDECLARED_OUTPUTS_DIR}\"", + "fi", + "exec %s %s %s\n" % ( + ctx.files.runner[0].short_path, + " ".join(ctx.attr.runner_args), + ctx.files.test[0].short_path, + ), + ]) + ctx.actions.write(runner, runner_content, is_executable = True) + + # Return with all transitive files. + runfiles = ctx.runfiles( + transitive_files = depset(transitive = [ + depset(target.data_runfiles.files) + for target in (ctx.attr.runner, ctx.attr.test) + if hasattr(target, "data_runfiles") + ]), + files = ctx.files.runner + ctx.files.test, + collect_default = True, + collect_data = True, + ) + return [DefaultInfo(executable = runner, runfiles = runfiles)] + +_runner_test = rule( + attrs = { + "runner": attr.label( + default = "//test/runner:runner", + ), + "test": attr.label( + mandatory = True, + ), + "runner_args": attr.string_list(), + "data": attr.label_list( + allow_files = True, + ), + }, + test = True, + implementation = _runner_test_impl, +) + +def _syscall_test( + test, + shard_count, + size, + platform, + use_tmpfs, + tags, + network = "none", + file_access = "exclusive", + overlay = False, + add_uds_tree = False): + # Prepend "runsc" to non-native platform names. + full_platform = platform if platform == "native" else "runsc_" + platform + + # Name the test appropriately. + name = test.split(":")[1] + "_" + full_platform + if file_access == "shared": + name += "_shared" + if overlay: + name += "_overlay" + if network != "none": + name += "_" + network + "net" + + # Apply all tags. 
+ if tags == None: + tags = [] + + # Add the full_platform and file access in a tag to make it easier to run + # all the tests on a specific flavor. Use --test_tag_filters=ptrace,file_shared. + tags += [full_platform, "file_" + file_access] + + # Hash this target into one of 15 buckets. This can be used to + # randomly split targets between different workflows. + hash15 = hash(native.package_name() + name) % 15 + tags.append("hash15:" + str(hash15)) + + # TODO(b/139838000): Tests using hostinet must be disabled on Guitar until + # we figure out how to request ipv4 sockets on Guitar machines. + if network == "host": + tags.append("noguitar") + + # Disable off-host networking. + tags.append("requires-net:loopback") + + # Add tag to prevent the tests from running in a Bazel sandbox. + # TODO(b/120560048): Make the tests run without this tag. + tags.append("no-sandbox") + + # TODO(b/112165693): KVM tests are tagged "manual" to until the platform is + # more stable. + if platform == "kvm": + tags.append("manual") + tags.append("requires-kvm") + + # TODO(b/112165693): Remove when tests pass reliably. + tags.append("notap") + + runner_args = [ + # Arguments are passed directly to runner binary. + "--platform=" + platform, + "--network=" + network, + "--use-tmpfs=" + str(use_tmpfs), + "--file-access=" + file_access, + "--overlay=" + str(overlay), + "--add-uds-tree=" + str(add_uds_tree), + ] + + # Call the rule above. + _runner_test( + name = name, + test = test, + runner_args = runner_args, + data = [loopback], + size = size, + tags = tags, + shard_count = shard_count, + ) + +def syscall_test( + test, + shard_count = 5, + size = "small", + use_tmpfs = False, + add_overlay = False, + add_uds_tree = False, + add_hostinet = False, + tags = None): + """syscall_test is a macro that will create targets for all platforms. + + Args: + test: the test target. + shard_count: shards for defined tests. + size: the defined test size. + use_tmpfs: use tmpfs in the defined tests. + add_overlay: add an overlay test. + add_uds_tree: add a UDS test. + add_hostinet: add a hostinet test. + tags: starting test tags. + """ + + _syscall_test( + test = test, + shard_count = shard_count, + size = size, + platform = "native", + use_tmpfs = False, + add_uds_tree = add_uds_tree, + tags = tags, + ) + + _syscall_test( + test = test, + shard_count = shard_count, + size = size, + platform = "kvm", + use_tmpfs = use_tmpfs, + add_uds_tree = add_uds_tree, + tags = tags, + ) + + _syscall_test( + test = test, + shard_count = shard_count, + size = size, + platform = "ptrace", + use_tmpfs = use_tmpfs, + add_uds_tree = add_uds_tree, + tags = tags, + ) + + if add_overlay: + _syscall_test( + test = test, + shard_count = shard_count, + size = size, + platform = "ptrace", + use_tmpfs = False, # overlay is adding a writable tmpfs on top of root. + add_uds_tree = add_uds_tree, + tags = tags, + overlay = True, + ) + + if not use_tmpfs: + # Also test shared gofer access. 
+ _syscall_test( + test = test, + shard_count = shard_count, + size = size, + platform = "ptrace", + use_tmpfs = use_tmpfs, + add_uds_tree = add_uds_tree, + tags = tags, + file_access = "shared", + ) + + if add_hostinet: + _syscall_test( + test = test, + shard_count = shard_count, + size = size, + platform = "ptrace", + use_tmpfs = use_tmpfs, + network = "host", + add_uds_tree = add_uds_tree, + tags = tags, + ) diff --git a/test/runner/gtest/BUILD b/test/runner/gtest/BUILD new file mode 100644 index 000000000..de4b2727c --- /dev/null +++ b/test/runner/gtest/BUILD @@ -0,0 +1,9 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "gtest", + srcs = ["gtest.go"], + visibility = ["//:sandbox"], +) diff --git a/test/runner/gtest/gtest.go b/test/runner/gtest/gtest.go new file mode 100644 index 000000000..23bf7b5f6 --- /dev/null +++ b/test/runner/gtest/gtest.go @@ -0,0 +1,154 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package gtest contains helpers for running google-test tests from Go. +package gtest + +import ( + "fmt" + "os/exec" + "strings" +) + +var ( + // listTestFlag is the flag that will list tests in gtest binaries. + listTestFlag = "--gtest_list_tests" + + // filterTestFlag is the flag that will filter tests in gtest binaries. + filterTestFlag = "--gtest_filter" + + // listBechmarkFlag is the flag that will list benchmarks in gtest binaries. + listBenchmarkFlag = "--benchmark_list_tests" + + // filterBenchmarkFlag is the flag that will run specified benchmarks. + filterBenchmarkFlag = "--benchmark_filter" +) + +// TestCase is a single gtest test case. +type TestCase struct { + // Suite is the suite for this test. + Suite string + + // Name is the name of this individual test. + Name string + + // benchmark indicates that this is a benchmark. In this case, the + // suite will be empty, and we will use the appropriate test and + // benchmark flags. + benchmark bool +} + +// FullName returns the name of the test including the suite. It is suitable to +// pass to "-gtest_filter". +func (tc TestCase) FullName() string { + return fmt.Sprintf("%s.%s", tc.Suite, tc.Name) +} + +// Args returns arguments to be passed when invoking the test. +func (tc TestCase) Args() []string { + if tc.benchmark { + return []string{ + fmt.Sprintf("%s=^$", filterTestFlag), + fmt.Sprintf("%s=^%s$", filterBenchmarkFlag, tc.Name), + } + } + return []string{ + fmt.Sprintf("%s=^%s$", filterTestFlag, tc.FullName()), + fmt.Sprintf("%s=^$", filterBenchmarkFlag), + } +} + +// ParseTestCases calls a gtest test binary to list its test and returns a +// slice with the name and suite of each test. +// +// If benchmarks is true, then benchmarks will be included in the list of test +// cases provided. Note that this requires the binary to support the +// benchmarks_list_tests flag. +func ParseTestCases(testBin string, benchmarks bool, extraArgs ...string) ([]TestCase, error) { + // Run to extract test cases. 
+ args := append([]string{listTestFlag}, extraArgs...) + cmd := exec.Command(testBin, args...) + out, err := cmd.Output() + if err != nil { + exitErr, ok := err.(*exec.ExitError) + if !ok { + return nil, fmt.Errorf("could not enumerate gtest tests: %v", err) + } + return nil, fmt.Errorf("could not enumerate gtest tests: %v\nstderr:\n%s", err, exitErr.Stderr) + } + + // Parse test output. + var t []TestCase + var suite string + for _, line := range strings.Split(string(out), "\n") { + // Strip comments. + line = strings.Split(line, "#")[0] + + // New suite? + if !strings.HasPrefix(line, " ") { + suite = strings.TrimSuffix(strings.TrimSpace(line), ".") + continue + } + + // Individual test. + name := strings.TrimSpace(line) + + // Do we have a suite yet? + if suite == "" { + return nil, fmt.Errorf("test without a suite: %v", name) + } + + // Add this individual test. + t = append(t, TestCase{ + Suite: suite, + Name: name, + }) + + } + + // Finished? + if !benchmarks { + return t, nil + } + + // Run again to extract benchmarks. + args = append([]string{listBenchmarkFlag}, extraArgs...) + cmd = exec.Command(testBin, args...) + out, err = cmd.Output() + if err != nil { + exitErr, ok := err.(*exec.ExitError) + if !ok { + return nil, fmt.Errorf("could not enumerate gtest benchmarks: %v", err) + } + return nil, fmt.Errorf("could not enumerate gtest benchmarks: %v\nstderr\n%s", err, exitErr.Stderr) + } + + // Parse benchmark output. + for _, line := range strings.Split(string(out), "\n") { + // Strip comments. + line = strings.Split(line, "#")[0] + + // Single benchmark. + name := strings.TrimSpace(line) + + // Add the single benchmark. + t = append(t, TestCase{ + Suite: "Benchmarks", + Name: name, + benchmark: true, + }) + } + + return t, nil +} diff --git a/test/runner/runner.go b/test/runner/runner.go new file mode 100644 index 000000000..a78ef38e0 --- /dev/null +++ b/test/runner/runner.go @@ -0,0 +1,477 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Binary syscall_test_runner runs the syscall test suites in gVisor +// containers and on the host platform. 
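The gtest package above is what the runner below uses to enumerate test cases and build per-case filter flags. A minimal usage sketch, assuming a hypothetical built test binary; the path and the StatTest.Basic name are illustrative, not taken from this patch.

package main

import (
	"fmt"
	"log"

	"gvisor.dev/gvisor/test/runner/gtest"
)

func main() {
	// Hypothetical path to a built gtest binary.
	testBin := "./stat_test"

	// Enumerate regular tests and, with the second argument true, benchmarks.
	cases, err := gtest.ParseTestCases(testBin, true)
	if err != nil {
		log.Fatalf("ParseTestCases(%q) failed: %v", testBin, err)
	}

	for _, tc := range cases {
		// For a test such as "StatTest.Basic", Args() yields
		// --gtest_filter=^StatTest.Basic$ and --benchmark_filter=^$,
		// while a benchmark case swaps the two filters.
		fmt.Println(tc.FullName(), tc.Args())
	}
}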
+package main + +import ( + "flag" + "fmt" + "io/ioutil" + "os" + "os/exec" + "os/signal" + "path/filepath" + "strings" + "syscall" + "testing" + "time" + + specs "github.com/opencontainers/runtime-spec/specs-go" + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/runsc/specutils" + "gvisor.dev/gvisor/runsc/testutil" + "gvisor.dev/gvisor/test/runner/gtest" + "gvisor.dev/gvisor/test/uds" +) + +var ( + debug = flag.Bool("debug", false, "enable debug logs") + strace = flag.Bool("strace", false, "enable strace logs") + platform = flag.String("platform", "ptrace", "platform to run on") + network = flag.String("network", "none", "network stack to run on (sandbox, host, none)") + useTmpfs = flag.Bool("use-tmpfs", false, "mounts tmpfs for /tmp") + fileAccess = flag.String("file-access", "exclusive", "mounts root in exclusive or shared mode") + overlay = flag.Bool("overlay", false, "wrap filesystem mounts with writable tmpfs overlay") + parallel = flag.Bool("parallel", false, "run tests in parallel") + runscPath = flag.String("runsc", "", "path to runsc binary") + + addUDSTree = flag.Bool("add-uds-tree", false, "expose a tree of UDS utilities for use in tests") +) + +// runTestCaseNative runs the test case directly on the host machine. +func runTestCaseNative(testBin string, tc gtest.TestCase, t *testing.T) { + // These tests might be running in parallel, so make sure they have a + // unique test temp dir. + tmpDir, err := ioutil.TempDir(testutil.TmpDir(), "") + if err != nil { + t.Fatalf("could not create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + // Replace TEST_TMPDIR in the current environment with something + // unique. + env := os.Environ() + newEnvVar := "TEST_TMPDIR=" + tmpDir + var found bool + for i, kv := range env { + if strings.HasPrefix(kv, "TEST_TMPDIR=") { + env[i] = newEnvVar + found = true + break + } + } + if !found { + env = append(env, newEnvVar) + } + // Remove env variables that cause the gunit binary to write output + // files, since they will stomp on eachother, and on the output files + // from this go test. + env = filterEnv(env, []string{"GUNIT_OUTPUT", "TEST_PREMATURE_EXIT_FILE", "XML_OUTPUT_FILE"}) + + // Remove shard env variables so that the gunit binary does not try to + // intepret them. + env = filterEnv(env, []string{"TEST_SHARD_INDEX", "TEST_TOTAL_SHARDS", "GTEST_SHARD_INDEX", "GTEST_TOTAL_SHARDS"}) + + if *addUDSTree { + socketDir, cleanup, err := uds.CreateSocketTree("/tmp") + if err != nil { + t.Fatalf("failed to create socket tree: %v", err) + } + defer cleanup() + + env = append(env, "TEST_UDS_TREE="+socketDir) + // On Linux, the concept of "attach" location doesn't exist. + // Just pass the same path to make these test identical. + env = append(env, "TEST_UDS_ATTACH_TREE="+socketDir) + } + + cmd := exec.Command(testBin, tc.Args()...) + cmd.Env = env + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + if err := cmd.Run(); err != nil { + ws := err.(*exec.ExitError).Sys().(syscall.WaitStatus) + t.Errorf("test %q exited with status %d, want 0", tc.FullName(), ws.ExitStatus()) + } +} + +// runRunsc runs spec in runsc in a standard test configuration. +// +// runsc logs will be saved to a path in TEST_UNDECLARED_OUTPUTS_DIR. +// +// Returns an error if the sandboxed application exits non-zero. 
+func runRunsc(tc gtest.TestCase, spec *specs.Spec) error { + bundleDir, err := testutil.SetupBundleDir(spec) + if err != nil { + return fmt.Errorf("SetupBundleDir failed: %v", err) + } + defer os.RemoveAll(bundleDir) + + rootDir, err := testutil.SetupRootDir() + if err != nil { + return fmt.Errorf("SetupRootDir failed: %v", err) + } + defer os.RemoveAll(rootDir) + + name := tc.FullName() + id := testutil.UniqueContainerID() + log.Infof("Running test %q in container %q", name, id) + specutils.LogSpec(spec) + + args := []string{ + "-root", rootDir, + "-network", *network, + "-log-format=text", + "-TESTONLY-unsafe-nonroot=true", + "-net-raw=true", + fmt.Sprintf("-panic-signal=%d", syscall.SIGTERM), + "-watchdog-action=panic", + "-platform", *platform, + "-file-access", *fileAccess, + } + if *overlay { + args = append(args, "-overlay") + } + if *debug { + args = append(args, "-debug", "-log-packets=true") + } + if *strace { + args = append(args, "-strace") + } + if *addUDSTree { + args = append(args, "-fsgofer-host-uds") + } + + if outDir, ok := syscall.Getenv("TEST_UNDECLARED_OUTPUTS_DIR"); ok { + tdir := filepath.Join(outDir, strings.Replace(name, "/", "_", -1)) + if err := os.MkdirAll(tdir, 0755); err != nil { + return fmt.Errorf("could not create test dir: %v", err) + } + debugLogDir, err := ioutil.TempDir(tdir, "runsc") + if err != nil { + return fmt.Errorf("could not create temp dir: %v", err) + } + debugLogDir += "/" + log.Infof("runsc logs: %s", debugLogDir) + args = append(args, "-debug-log", debugLogDir) + + // Default -log sends messages to stderr which makes reading the test log + // difficult. Instead, drop them when debug log is enabled given it's a + // better place for these messages. + args = append(args, "-log=/dev/null") + } + + // Current process doesn't have CAP_SYS_ADMIN, create user namespace and run + // as root inside that namespace to get it. + rArgs := append(args, "run", "--bundle", bundleDir, id) + cmd := exec.Command(*runscPath, rArgs...) + cmd.SysProcAttr = &syscall.SysProcAttr{ + Cloneflags: syscall.CLONE_NEWUSER | syscall.CLONE_NEWNS, + // Set current user/group as root inside the namespace. + UidMappings: []syscall.SysProcIDMap{ + {ContainerID: 0, HostID: os.Getuid(), Size: 1}, + }, + GidMappings: []syscall.SysProcIDMap{ + {ContainerID: 0, HostID: os.Getgid(), Size: 1}, + }, + GidMappingsEnableSetgroups: false, + Credential: &syscall.Credential{ + Uid: 0, + Gid: 0, + }, + } + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + sig := make(chan os.Signal, 1) + signal.Notify(sig, syscall.SIGTERM) + go func() { + s, ok := <-sig + if !ok { + return + } + log.Warningf("%s: Got signal: %v", name, s) + done := make(chan bool) + dArgs := append([]string{}, args...) + dArgs = append(dArgs, "-alsologtostderr=true", "debug", "--stacks", id) + go func(dArgs []string) { + cmd := exec.Command(*runscPath, dArgs...) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + cmd.Run() + done <- true + }(dArgs) + + timeout := time.After(3 * time.Second) + select { + case <-timeout: + log.Infof("runsc debug --stacks is timeouted") + case <-done: + } + + log.Warningf("Send SIGTERM to the sandbox process") + dArgs = append(args, "debug", + fmt.Sprintf("--signal=%d", syscall.SIGTERM), + id) + cmd := exec.Command(*runscPath, dArgs...) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + cmd.Run() + }() + + err = cmd.Run() + + signal.Stop(sig) + close(sig) + + return err +} + +// setupUDSTree updates the spec to expose a UDS tree for gofer socket testing. 
+func setupUDSTree(spec *specs.Spec) (cleanup func(), err error) {
+	socketDir, cleanup, err := uds.CreateSocketTree("/tmp")
+	if err != nil {
+		return nil, fmt.Errorf("failed to create socket tree: %v", err)
+	}
+
+	// Standard access to entire tree.
+	spec.Mounts = append(spec.Mounts, specs.Mount{
+		Destination: "/tmp/sockets",
+		Source: socketDir,
+		Type: "bind",
+	})
+
+	// Individual attach points for each socket to test mounts that attach
+	// directly to the sockets.
+	spec.Mounts = append(spec.Mounts, specs.Mount{
+		Destination: "/tmp/sockets-attach/stream/echo",
+		Source: filepath.Join(socketDir, "stream/echo"),
+		Type: "bind",
+	})
+	spec.Mounts = append(spec.Mounts, specs.Mount{
+		Destination: "/tmp/sockets-attach/stream/nonlistening",
+		Source: filepath.Join(socketDir, "stream/nonlistening"),
+		Type: "bind",
+	})
+	spec.Mounts = append(spec.Mounts, specs.Mount{
+		Destination: "/tmp/sockets-attach/seqpacket/echo",
+		Source: filepath.Join(socketDir, "seqpacket/echo"),
+		Type: "bind",
+	})
+	spec.Mounts = append(spec.Mounts, specs.Mount{
+		Destination: "/tmp/sockets-attach/seqpacket/nonlistening",
+		Source: filepath.Join(socketDir, "seqpacket/nonlistening"),
+		Type: "bind",
+	})
+	spec.Mounts = append(spec.Mounts, specs.Mount{
+		Destination: "/tmp/sockets-attach/dgram/null",
+		Source: filepath.Join(socketDir, "dgram/null"),
+		Type: "bind",
+	})
+
+	spec.Process.Env = append(spec.Process.Env, "TEST_UDS_TREE=/tmp/sockets")
+	spec.Process.Env = append(spec.Process.Env, "TEST_UDS_ATTACH_TREE=/tmp/sockets-attach")
+
+	return cleanup, nil
+}
+
+// runTestCaseRunsc runs the test case in runsc.
+func runTestCaseRunsc(testBin string, tc gtest.TestCase, t *testing.T) {
+	// Run a new container with the test executable and filter for the
+	// given test suite and name.
+	spec := testutil.NewSpecWithArgs(append([]string{testBin}, tc.Args()...)...)
+
+	// Mark the root as writable, as some tests attempt to
+	// write to the rootfs, and expect EACCES, not EROFS.
+	spec.Root.Readonly = false
+
+	// Test spec comes with pre-defined mounts that we don't want. Reset it.
+	spec.Mounts = nil
+	if *useTmpfs {
+		// Forces '/tmp' to be mounted as tmpfs, otherwise tests that rely on
+		// features only available in gVisor's internal tmpfs may fail.
+		spec.Mounts = append(spec.Mounts, specs.Mount{
+			Destination: "/tmp",
+			Type: "tmpfs",
+		})
+	} else {
+		// Use a gofer-backed directory as '/tmp'.
+		//
+		// Tests might be running in parallel, so make sure each has a
+		// unique test temp dir.
+		//
+		// Some tests (e.g., sticky) access this mount from other
+		// users, so make sure it is world-accessible.
+		tmpDir, err := ioutil.TempDir(testutil.TmpDir(), "")
+		if err != nil {
+			t.Fatalf("could not create temp dir: %v", err)
+		}
+		defer os.RemoveAll(tmpDir)
+
+		if err := os.Chmod(tmpDir, 0777); err != nil {
+			t.Fatalf("could not chmod temp dir: %v", err)
+		}
+
+		spec.Mounts = append(spec.Mounts, specs.Mount{
+			Type: "bind",
+			Destination: "/tmp",
+			Source: tmpDir,
+		})
+	}
+
+	// Set environment variables that indicate we are
+	// running in gVisor with the given platform and network.
+	platformVar := "TEST_ON_GVISOR"
+	networkVar := "GVISOR_NETWORK"
+	env := append(os.Environ(), platformVar+"="+*platform, networkVar+"="+*network)
+
+	// Remove env variables that cause the gunit binary to write output
+	// files, since they will stomp on each other, and on the output files
+	// from this go test.
+ env = filterEnv(env, []string{"GUNIT_OUTPUT", "TEST_PREMATURE_EXIT_FILE", "XML_OUTPUT_FILE"}) + + // Remove shard env variables so that the gunit binary does not try to + // intepret them. + env = filterEnv(env, []string{"TEST_SHARD_INDEX", "TEST_TOTAL_SHARDS", "GTEST_SHARD_INDEX", "GTEST_TOTAL_SHARDS"}) + + // Set TEST_TMPDIR to /tmp, as some of the syscall tests require it to + // be backed by tmpfs. + for i, kv := range env { + if strings.HasPrefix(kv, "TEST_TMPDIR=") { + env[i] = "TEST_TMPDIR=/tmp" + break + } + } + + spec.Process.Env = env + + if *addUDSTree { + cleanup, err := setupUDSTree(spec) + if err != nil { + t.Fatalf("error creating UDS tree: %v", err) + } + defer cleanup() + } + + if err := runRunsc(tc, spec); err != nil { + t.Errorf("test %q failed with error %v, want nil", tc.FullName(), err) + } +} + +// filterEnv returns an environment with the blacklisted variables removed. +func filterEnv(env, blacklist []string) []string { + var out []string + for _, kv := range env { + ok := true + for _, k := range blacklist { + if strings.HasPrefix(kv, k+"=") { + ok = false + break + } + } + if ok { + out = append(out, kv) + } + } + return out +} + +func fatalf(s string, args ...interface{}) { + fmt.Fprintf(os.Stderr, s+"\n", args...) + os.Exit(1) +} + +func matchString(a, b string) (bool, error) { + return a == b, nil +} + +func main() { + flag.Parse() + if flag.NArg() != 1 { + fatalf("test must be provided") + } + testBin := flag.Args()[0] // Only argument. + + log.SetLevel(log.Info) + if *debug { + log.SetLevel(log.Debug) + } + + if *platform != "native" && *runscPath == "" { + if err := testutil.ConfigureExePath(); err != nil { + panic(err.Error()) + } + *runscPath = specutils.ExePath + } + + // Make sure stdout and stderr are opened with O_APPEND, otherwise logs + // from outside the sandbox can (and will) stomp on logs from inside + // the sandbox. + for _, f := range []*os.File{os.Stdout, os.Stderr} { + flags, err := unix.FcntlInt(f.Fd(), unix.F_GETFL, 0) + if err != nil { + fatalf("error getting file flags for %v: %v", f, err) + } + if flags&unix.O_APPEND == 0 { + flags |= unix.O_APPEND + if _, err := unix.FcntlInt(f.Fd(), unix.F_SETFL, flags); err != nil { + fatalf("error setting file flags for %v: %v", f, err) + } + } + } + + // Get all test cases in each binary. + testCases, err := gtest.ParseTestCases(testBin, true) + if err != nil { + fatalf("ParseTestCases(%q) failed: %v", testBin, err) + } + + // Get subset of tests corresponding to shard. + indices, err := testutil.TestIndicesForShard(len(testCases)) + if err != nil { + fatalf("TestsForShard() failed: %v", err) + } + + // Resolve the absolute path for the binary. + testBin, err = filepath.Abs(testBin) + if err != nil { + fatalf("Abs() failed: %v", err) + } + + // Run the tests. + var tests []testing.InternalTest + for _, tci := range indices { + // Capture tc. + tc := testCases[tci] + tests = append(tests, testing.InternalTest{ + Name: fmt.Sprintf("%s_%s", tc.Suite, tc.Name), + F: func(t *testing.T) { + if *parallel { + t.Parallel() + } + if *platform == "native" { + // Run the test case on host. + runTestCaseNative(testBin, tc, t) + } else { + // Run the test case in runsc. 
+ runTestCaseRunsc(testBin, tc, t) + } + }, + }) + } + + testing.Main(matchString, tests, nil, nil) +} diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD index 31d239c0e..d69ac8356 100644 --- a/test/syscalls/BUILD +++ b/test/syscalls/BUILD @@ -1,5 +1,4 @@ -load("//tools:defs.bzl", "go_binary") -load("//test/syscalls:build_defs.bzl", "syscall_test") +load("//test/runner:defs.bzl", "syscall_test") package(licenses = ["notice"]) @@ -726,21 +725,3 @@ syscall_test(test = "//test/syscalls/linux:proc_net_unix_test") syscall_test(test = "//test/syscalls/linux:proc_net_tcp_test") syscall_test(test = "//test/syscalls/linux:proc_net_udp_test") - -go_binary( - name = "syscall_test_runner", - testonly = 1, - srcs = ["syscall_test_runner.go"], - data = [ - "//runsc", - ], - deps = [ - "//pkg/log", - "//runsc/specutils", - "//runsc/testutil", - "//test/syscalls/gtest", - "//test/uds", - "@com_github_opencontainers_runtime-spec//specs-go:go_default_library", - "@org_golang_x_sys//unix:go_default_library", - ], -) diff --git a/test/syscalls/build_defs.bzl b/test/syscalls/build_defs.bzl deleted file mode 100644 index cbab85ef7..000000000 --- a/test/syscalls/build_defs.bzl +++ /dev/null @@ -1,180 +0,0 @@ -"""Defines a rule for syscall test targets.""" - -load("//tools:defs.bzl", "loopback") - -def syscall_test( - test, - shard_count = 5, - size = "small", - use_tmpfs = False, - add_overlay = False, - add_uds_tree = False, - add_hostinet = False, - tags = None): - """syscall_test is a macro that will create targets for all platforms. - - Args: - test: the test target. - shard_count: shards for defined tests. - size: the defined test size. - use_tmpfs: use tmpfs in the defined tests. - add_overlay: add an overlay test. - add_uds_tree: add a UDS test. - add_hostinet: add a hostinet test. - tags: starting test tags. - """ - - _syscall_test( - test = test, - shard_count = shard_count, - size = size, - platform = "native", - use_tmpfs = False, - add_uds_tree = add_uds_tree, - tags = tags, - ) - - _syscall_test( - test = test, - shard_count = shard_count, - size = size, - platform = "kvm", - use_tmpfs = use_tmpfs, - add_uds_tree = add_uds_tree, - tags = tags, - ) - - _syscall_test( - test = test, - shard_count = shard_count, - size = size, - platform = "ptrace", - use_tmpfs = use_tmpfs, - add_uds_tree = add_uds_tree, - tags = tags, - ) - - if add_overlay: - _syscall_test( - test = test, - shard_count = shard_count, - size = size, - platform = "ptrace", - use_tmpfs = False, # overlay is adding a writable tmpfs on top of root. - add_uds_tree = add_uds_tree, - tags = tags, - overlay = True, - ) - - if not use_tmpfs: - # Also test shared gofer access. - _syscall_test( - test = test, - shard_count = shard_count, - size = size, - platform = "ptrace", - use_tmpfs = use_tmpfs, - add_uds_tree = add_uds_tree, - tags = tags, - file_access = "shared", - ) - - if add_hostinet: - _syscall_test( - test = test, - shard_count = shard_count, - size = size, - platform = "ptrace", - use_tmpfs = use_tmpfs, - network = "host", - add_uds_tree = add_uds_tree, - tags = tags, - ) - -def _syscall_test( - test, - shard_count, - size, - platform, - use_tmpfs, - tags, - network = "none", - file_access = "exclusive", - overlay = False, - add_uds_tree = False): - test_name = test.split(":")[1] - - # Prepend "runsc" to non-native platform names. 
- full_platform = platform if platform == "native" else "runsc_" + platform - - name = test_name + "_" + full_platform - if file_access == "shared": - name += "_shared" - if overlay: - name += "_overlay" - if network != "none": - name += "_" + network + "net" - - if tags == None: - tags = [] - - # Add the full_platform and file access in a tag to make it easier to run - # all the tests on a specific flavor. Use --test_tag_filters=ptrace,file_shared. - tags += [full_platform, "file_" + file_access] - - # Hash this target into one of 15 buckets. This can be used to - # randomly split targets between different workflows. - hash15 = hash(native.package_name() + name) % 15 - tags.append("hash15:" + str(hash15)) - - # TODO(b/139838000): Tests using hostinet must be disabled on Guitar until - # we figure out how to request ipv4 sockets on Guitar machines. - if network == "host": - tags.append("noguitar") - - # Disable off-host networking. - tags.append("requires-net:loopback") - - # Add tag to prevent the tests from running in a Bazel sandbox. - # TODO(b/120560048): Make the tests run without this tag. - tags.append("no-sandbox") - - # TODO(b/112165693): KVM tests are tagged "manual" to until the platform is - # more stable. - if platform == "kvm": - tags.append("manual") - tags.append("requires-kvm") - - # TODO(b/112165693): Remove when tests pass reliably. - tags.append("notap") - - args = [ - # Arguments are passed directly to syscall_test_runner binary. - "--test-name=" + test_name, - "--platform=" + platform, - "--network=" + network, - "--use-tmpfs=" + str(use_tmpfs), - "--file-access=" + file_access, - "--overlay=" + str(overlay), - "--add-uds-tree=" + str(add_uds_tree), - ] - - sh_test( - srcs = ["syscall_test_runner.sh"], - name = name, - data = [ - ":syscall_test_runner", - loopback, - test, - ], - args = args, - size = size, - tags = tags, - shard_count = shard_count, - ) - -def sh_test(**kwargs): - """Wraps the standard sh_test.""" - native.sh_test( - **kwargs - ) diff --git a/test/syscalls/gtest/BUILD b/test/syscalls/gtest/BUILD deleted file mode 100644 index de4b2727c..000000000 --- a/test/syscalls/gtest/BUILD +++ /dev/null @@ -1,9 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "gtest", - srcs = ["gtest.go"], - visibility = ["//:sandbox"], -) diff --git a/test/syscalls/gtest/gtest.go b/test/syscalls/gtest/gtest.go deleted file mode 100644 index bdec8eb07..000000000 --- a/test/syscalls/gtest/gtest.go +++ /dev/null @@ -1,93 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package gtest contains helpers for running google-test tests from Go. -package gtest - -import ( - "fmt" - "os/exec" - "strings" -) - -var ( - // ListTestFlag is the flag that will list tests in gtest binaries. - ListTestFlag = "--gtest_list_tests" - - // FilterTestFlag is the flag that will filter tests in gtest binaries. 
- FilterTestFlag = "--gtest_filter" -) - -// TestCase is a single gtest test case. -type TestCase struct { - // Suite is the suite for this test. - Suite string - - // Name is the name of this individual test. - Name string -} - -// FullName returns the name of the test including the suite. It is suitable to -// pass to "-gtest_filter". -func (tc TestCase) FullName() string { - return fmt.Sprintf("%s.%s", tc.Suite, tc.Name) -} - -// ParseTestCases calls a gtest test binary to list its test and returns a -// slice with the name and suite of each test. -func ParseTestCases(testBin string, extraArgs ...string) ([]TestCase, error) { - args := append([]string{ListTestFlag}, extraArgs...) - cmd := exec.Command(testBin, args...) - out, err := cmd.Output() - if err != nil { - exitErr, ok := err.(*exec.ExitError) - if !ok { - return nil, fmt.Errorf("could not enumerate gtest tests: %v", err) - } - return nil, fmt.Errorf("could not enumerate gtest tests: %v\nstderr:\n%s", err, exitErr.Stderr) - } - - var t []TestCase - var suite string - for _, line := range strings.Split(string(out), "\n") { - // Strip comments. - line = strings.Split(line, "#")[0] - - // New suite? - if !strings.HasPrefix(line, " ") { - suite = strings.TrimSuffix(strings.TrimSpace(line), ".") - continue - } - - // Individual test. - name := strings.TrimSpace(line) - - // Do we have a suite yet? - if suite == "" { - return nil, fmt.Errorf("test without a suite: %v", name) - } - - // Add this individual test. - t = append(t, TestCase{ - Suite: suite, - Name: name, - }) - - } - - if len(t) == 0 { - return nil, fmt.Errorf("no tests parsed from %v", testBin) - } - return t, nil -} diff --git a/test/syscalls/linux/alarm.cc b/test/syscalls/linux/alarm.cc index d89269985..940c97285 100644 --- a/test/syscalls/linux/alarm.cc +++ b/test/syscalls/linux/alarm.cc @@ -188,6 +188,5 @@ int main(int argc, char** argv) { TEST_PCHECK(sigprocmask(SIG_BLOCK, &set, nullptr) == 0); gvisor::testing::TestInit(&argc, &argv); - - return RUN_ALL_TESTS(); + return gvisor::testing::RunAllTests(); } diff --git a/test/syscalls/linux/exec.cc b/test/syscalls/linux/exec.cc index b5e0a512b..07bd527e6 100644 --- a/test/syscalls/linux/exec.cc +++ b/test/syscalls/linux/exec.cc @@ -868,6 +868,5 @@ int main(int argc, char** argv) { } gvisor::testing::TestInit(&argc, &argv); - - return RUN_ALL_TESTS(); + return gvisor::testing::RunAllTests(); } diff --git a/test/syscalls/linux/fcntl.cc b/test/syscalls/linux/fcntl.cc index 421c15b87..c7cc5816e 100644 --- a/test/syscalls/linux/fcntl.cc +++ b/test/syscalls/linux/fcntl.cc @@ -1128,5 +1128,5 @@ int main(int argc, char** argv) { exit(err); } - return RUN_ALL_TESTS(); + return gvisor::testing::RunAllTests(); } diff --git a/test/syscalls/linux/itimer.cc b/test/syscalls/linux/itimer.cc index b77e4cbd1..8b48f0804 100644 --- a/test/syscalls/linux/itimer.cc +++ b/test/syscalls/linux/itimer.cc @@ -349,6 +349,5 @@ int main(int argc, char** argv) { } gvisor::testing::TestInit(&argc, &argv); - - return RUN_ALL_TESTS(); + return gvisor::testing::RunAllTests(); } diff --git a/test/syscalls/linux/prctl.cc b/test/syscalls/linux/prctl.cc index d07571a5f..04c5161f5 100644 --- a/test/syscalls/linux/prctl.cc +++ b/test/syscalls/linux/prctl.cc @@ -226,5 +226,5 @@ int main(int argc, char** argv) { prctl(PR_GET_NO_NEW_PRIVS, 0, 0, 0, 0)); } - return RUN_ALL_TESTS(); + return gvisor::testing::RunAllTests(); } diff --git a/test/syscalls/linux/prctl_setuid.cc b/test/syscalls/linux/prctl_setuid.cc index 30f0d75b3..c4e9cf528 100644 --- 
a/test/syscalls/linux/prctl_setuid.cc +++ b/test/syscalls/linux/prctl_setuid.cc @@ -264,5 +264,5 @@ int main(int argc, char** argv) { prctl(PR_GET_KEEPCAPS, 0, 0, 0, 0); } - return RUN_ALL_TESTS(); + return gvisor::testing::RunAllTests(); } diff --git a/test/syscalls/linux/proc.cc b/test/syscalls/linux/proc.cc index a23fdb58d..f91187e75 100644 --- a/test/syscalls/linux/proc.cc +++ b/test/syscalls/linux/proc.cc @@ -2076,5 +2076,5 @@ int main(int argc, char** argv) { } gvisor::testing::TestInit(&argc, &argv); - return RUN_ALL_TESTS(); + return gvisor::testing::RunAllTests(); } diff --git a/test/syscalls/linux/ptrace.cc b/test/syscalls/linux/ptrace.cc index 4dd5cf27b..bfe3e2603 100644 --- a/test/syscalls/linux/ptrace.cc +++ b/test/syscalls/linux/ptrace.cc @@ -1208,5 +1208,5 @@ int main(int argc, char** argv) { gvisor::testing::RunExecveChild(); } - return RUN_ALL_TESTS(); + return gvisor::testing::RunAllTests(); } diff --git a/test/syscalls/linux/rtsignal.cc b/test/syscalls/linux/rtsignal.cc index 81d193ffd..ed27e2566 100644 --- a/test/syscalls/linux/rtsignal.cc +++ b/test/syscalls/linux/rtsignal.cc @@ -167,6 +167,5 @@ int main(int argc, char** argv) { TEST_PCHECK(sigprocmask(SIG_BLOCK, &set, nullptr) == 0); gvisor::testing::TestInit(&argc, &argv); - - return RUN_ALL_TESTS(); + return gvisor::testing::RunAllTests(); } diff --git a/test/syscalls/linux/seccomp.cc b/test/syscalls/linux/seccomp.cc index 2c947feb7..cf6499f8b 100644 --- a/test/syscalls/linux/seccomp.cc +++ b/test/syscalls/linux/seccomp.cc @@ -411,5 +411,5 @@ int main(int argc, char** argv) { } gvisor::testing::TestInit(&argc, &argv); - return RUN_ALL_TESTS(); + return gvisor::testing::RunAllTests(); } diff --git a/test/syscalls/linux/sigiret.cc b/test/syscalls/linux/sigiret.cc index 4deb1ae95..6227774a4 100644 --- a/test/syscalls/linux/sigiret.cc +++ b/test/syscalls/linux/sigiret.cc @@ -132,6 +132,5 @@ int main(int argc, char** argv) { TEST_PCHECK(sigprocmask(SIG_BLOCK, &set, nullptr) == 0); gvisor::testing::TestInit(&argc, &argv); - - return RUN_ALL_TESTS(); + return gvisor::testing::RunAllTests(); } diff --git a/test/syscalls/linux/signalfd.cc b/test/syscalls/linux/signalfd.cc index 95be4b66c..389e5fca2 100644 --- a/test/syscalls/linux/signalfd.cc +++ b/test/syscalls/linux/signalfd.cc @@ -369,5 +369,5 @@ int main(int argc, char** argv) { gvisor::testing::TestInit(&argc, &argv); - return RUN_ALL_TESTS(); + return gvisor::testing::RunAllTests(); } diff --git a/test/syscalls/linux/sigstop.cc b/test/syscalls/linux/sigstop.cc index 7db57d968..b2fcedd62 100644 --- a/test/syscalls/linux/sigstop.cc +++ b/test/syscalls/linux/sigstop.cc @@ -147,5 +147,5 @@ int main(int argc, char** argv) { return 1; } - return RUN_ALL_TESTS(); + return gvisor::testing::RunAllTests(); } diff --git a/test/syscalls/linux/sigtimedwait.cc b/test/syscalls/linux/sigtimedwait.cc index 1e5bf5942..4f8afff15 100644 --- a/test/syscalls/linux/sigtimedwait.cc +++ b/test/syscalls/linux/sigtimedwait.cc @@ -319,6 +319,5 @@ int main(int argc, char** argv) { TEST_PCHECK(sigprocmask(SIG_BLOCK, &set, nullptr) == 0); gvisor::testing::TestInit(&argc, &argv); - - return RUN_ALL_TESTS(); + return gvisor::testing::RunAllTests(); } diff --git a/test/syscalls/linux/timers.cc b/test/syscalls/linux/timers.cc index 2f92c27da..4b3c44527 100644 --- a/test/syscalls/linux/timers.cc +++ b/test/syscalls/linux/timers.cc @@ -658,5 +658,5 @@ int main(int argc, char** argv) { } } - return RUN_ALL_TESTS(); + return gvisor::testing::RunAllTests(); } diff --git a/test/syscalls/linux/vfork.cc 
b/test/syscalls/linux/vfork.cc index 0aaba482d..19d05998e 100644 --- a/test/syscalls/linux/vfork.cc +++ b/test/syscalls/linux/vfork.cc @@ -191,5 +191,5 @@ int main(int argc, char** argv) { return gvisor::testing::RunChild(); } - return RUN_ALL_TESTS(); + return gvisor::testing::RunAllTests(); } diff --git a/test/syscalls/syscall_test_runner.go b/test/syscalls/syscall_test_runner.go deleted file mode 100644 index ae342b68c..000000000 --- a/test/syscalls/syscall_test_runner.go +++ /dev/null @@ -1,482 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Binary syscall_test_runner runs the syscall test suites in gVisor -// containers and on the host platform. -package main - -import ( - "flag" - "fmt" - "io/ioutil" - "os" - "os/exec" - "os/signal" - "path/filepath" - "strings" - "syscall" - "testing" - "time" - - specs "github.com/opencontainers/runtime-spec/specs-go" - "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/runsc/specutils" - "gvisor.dev/gvisor/runsc/testutil" - "gvisor.dev/gvisor/test/syscalls/gtest" - "gvisor.dev/gvisor/test/uds" -) - -// Location of syscall tests, relative to the repo root. -const testDir = "test/syscalls/linux" - -var ( - testName = flag.String("test-name", "", "name of test binary to run") - debug = flag.Bool("debug", false, "enable debug logs") - strace = flag.Bool("strace", false, "enable strace logs") - platform = flag.String("platform", "ptrace", "platform to run on") - network = flag.String("network", "none", "network stack to run on (sandbox, host, none)") - useTmpfs = flag.Bool("use-tmpfs", false, "mounts tmpfs for /tmp") - fileAccess = flag.String("file-access", "exclusive", "mounts root in exclusive or shared mode") - overlay = flag.Bool("overlay", false, "wrap filesystem mounts with writable tmpfs overlay") - parallel = flag.Bool("parallel", false, "run tests in parallel") - runscPath = flag.String("runsc", "", "path to runsc binary") - - addUDSTree = flag.Bool("add-uds-tree", false, "expose a tree of UDS utilities for use in tests") -) - -// runTestCaseNative runs the test case directly on the host machine. -func runTestCaseNative(testBin string, tc gtest.TestCase, t *testing.T) { - // These tests might be running in parallel, so make sure they have a - // unique test temp dir. - tmpDir, err := ioutil.TempDir(testutil.TmpDir(), "") - if err != nil { - t.Fatalf("could not create temp dir: %v", err) - } - defer os.RemoveAll(tmpDir) - - // Replace TEST_TMPDIR in the current environment with something - // unique. - env := os.Environ() - newEnvVar := "TEST_TMPDIR=" + tmpDir - var found bool - for i, kv := range env { - if strings.HasPrefix(kv, "TEST_TMPDIR=") { - env[i] = newEnvVar - found = true - break - } - } - if !found { - env = append(env, newEnvVar) - } - // Remove env variables that cause the gunit binary to write output - // files, since they will stomp on eachother, and on the output files - // from this go test. 
- env = filterEnv(env, []string{"GUNIT_OUTPUT", "TEST_PREMATURE_EXIT_FILE", "XML_OUTPUT_FILE"}) - - // Remove shard env variables so that the gunit binary does not try to - // intepret them. - env = filterEnv(env, []string{"TEST_SHARD_INDEX", "TEST_TOTAL_SHARDS", "GTEST_SHARD_INDEX", "GTEST_TOTAL_SHARDS"}) - - if *addUDSTree { - socketDir, cleanup, err := uds.CreateSocketTree("/tmp") - if err != nil { - t.Fatalf("failed to create socket tree: %v", err) - } - defer cleanup() - - env = append(env, "TEST_UDS_TREE="+socketDir) - // On Linux, the concept of "attach" location doesn't exist. - // Just pass the same path to make these test identical. - env = append(env, "TEST_UDS_ATTACH_TREE="+socketDir) - } - - cmd := exec.Command(testBin, gtest.FilterTestFlag+"="+tc.FullName()) - cmd.Env = env - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - if err := cmd.Run(); err != nil { - ws := err.(*exec.ExitError).Sys().(syscall.WaitStatus) - t.Errorf("test %q exited with status %d, want 0", tc.FullName(), ws.ExitStatus()) - } -} - -// runRunsc runs spec in runsc in a standard test configuration. -// -// runsc logs will be saved to a path in TEST_UNDECLARED_OUTPUTS_DIR. -// -// Returns an error if the sandboxed application exits non-zero. -func runRunsc(tc gtest.TestCase, spec *specs.Spec) error { - bundleDir, err := testutil.SetupBundleDir(spec) - if err != nil { - return fmt.Errorf("SetupBundleDir failed: %v", err) - } - defer os.RemoveAll(bundleDir) - - rootDir, err := testutil.SetupRootDir() - if err != nil { - return fmt.Errorf("SetupRootDir failed: %v", err) - } - defer os.RemoveAll(rootDir) - - name := tc.FullName() - id := testutil.UniqueContainerID() - log.Infof("Running test %q in container %q", name, id) - specutils.LogSpec(spec) - - args := []string{ - "-root", rootDir, - "-network", *network, - "-log-format=text", - "-TESTONLY-unsafe-nonroot=true", - "-net-raw=true", - fmt.Sprintf("-panic-signal=%d", syscall.SIGTERM), - "-watchdog-action=panic", - "-platform", *platform, - "-file-access", *fileAccess, - } - if *overlay { - args = append(args, "-overlay") - } - if *debug { - args = append(args, "-debug", "-log-packets=true") - } - if *strace { - args = append(args, "-strace") - } - if *addUDSTree { - args = append(args, "-fsgofer-host-uds") - } - - if outDir, ok := syscall.Getenv("TEST_UNDECLARED_OUTPUTS_DIR"); ok { - tdir := filepath.Join(outDir, strings.Replace(name, "/", "_", -1)) - if err := os.MkdirAll(tdir, 0755); err != nil { - return fmt.Errorf("could not create test dir: %v", err) - } - debugLogDir, err := ioutil.TempDir(tdir, "runsc") - if err != nil { - return fmt.Errorf("could not create temp dir: %v", err) - } - debugLogDir += "/" - log.Infof("runsc logs: %s", debugLogDir) - args = append(args, "-debug-log", debugLogDir) - - // Default -log sends messages to stderr which makes reading the test log - // difficult. Instead, drop them when debug log is enabled given it's a - // better place for these messages. - args = append(args, "-log=/dev/null") - } - - // Current process doesn't have CAP_SYS_ADMIN, create user namespace and run - // as root inside that namespace to get it. - rArgs := append(args, "run", "--bundle", bundleDir, id) - cmd := exec.Command(*runscPath, rArgs...) - cmd.SysProcAttr = &syscall.SysProcAttr{ - Cloneflags: syscall.CLONE_NEWUSER | syscall.CLONE_NEWNS, - // Set current user/group as root inside the namespace. 
- UidMappings: []syscall.SysProcIDMap{ - {ContainerID: 0, HostID: os.Getuid(), Size: 1}, - }, - GidMappings: []syscall.SysProcIDMap{ - {ContainerID: 0, HostID: os.Getgid(), Size: 1}, - }, - GidMappingsEnableSetgroups: false, - Credential: &syscall.Credential{ - Uid: 0, - Gid: 0, - }, - } - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - sig := make(chan os.Signal, 1) - signal.Notify(sig, syscall.SIGTERM) - go func() { - s, ok := <-sig - if !ok { - return - } - log.Warningf("%s: Got signal: %v", name, s) - done := make(chan bool) - dArgs := append([]string{}, args...) - dArgs = append(dArgs, "-alsologtostderr=true", "debug", "--stacks", id) - go func(dArgs []string) { - cmd := exec.Command(*runscPath, dArgs...) - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - cmd.Run() - done <- true - }(dArgs) - - timeout := time.After(3 * time.Second) - select { - case <-timeout: - log.Infof("runsc debug --stacks is timeouted") - case <-done: - } - - log.Warningf("Send SIGTERM to the sandbox process") - dArgs = append(args, "debug", - fmt.Sprintf("--signal=%d", syscall.SIGTERM), - id) - cmd := exec.Command(*runscPath, dArgs...) - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - cmd.Run() - }() - - err = cmd.Run() - - signal.Stop(sig) - close(sig) - - return err -} - -// setupUDSTree updates the spec to expose a UDS tree for gofer socket testing. -func setupUDSTree(spec *specs.Spec) (cleanup func(), err error) { - socketDir, cleanup, err := uds.CreateSocketTree("/tmp") - if err != nil { - return nil, fmt.Errorf("failed to create socket tree: %v", err) - } - - // Standard access to entire tree. - spec.Mounts = append(spec.Mounts, specs.Mount{ - Destination: "/tmp/sockets", - Source: socketDir, - Type: "bind", - }) - - // Individial attach points for each socket to test mounts that attach - // directly to the sockets. - spec.Mounts = append(spec.Mounts, specs.Mount{ - Destination: "/tmp/sockets-attach/stream/echo", - Source: filepath.Join(socketDir, "stream/echo"), - Type: "bind", - }) - spec.Mounts = append(spec.Mounts, specs.Mount{ - Destination: "/tmp/sockets-attach/stream/nonlistening", - Source: filepath.Join(socketDir, "stream/nonlistening"), - Type: "bind", - }) - spec.Mounts = append(spec.Mounts, specs.Mount{ - Destination: "/tmp/sockets-attach/seqpacket/echo", - Source: filepath.Join(socketDir, "seqpacket/echo"), - Type: "bind", - }) - spec.Mounts = append(spec.Mounts, specs.Mount{ - Destination: "/tmp/sockets-attach/seqpacket/nonlistening", - Source: filepath.Join(socketDir, "seqpacket/nonlistening"), - Type: "bind", - }) - spec.Mounts = append(spec.Mounts, specs.Mount{ - Destination: "/tmp/sockets-attach/dgram/null", - Source: filepath.Join(socketDir, "dgram/null"), - Type: "bind", - }) - - spec.Process.Env = append(spec.Process.Env, "TEST_UDS_TREE=/tmp/sockets") - spec.Process.Env = append(spec.Process.Env, "TEST_UDS_ATTACH_TREE=/tmp/sockets-attach") - - return cleanup, nil -} - -// runsTestCaseRunsc runs the test case in runsc. -func runTestCaseRunsc(testBin string, tc gtest.TestCase, t *testing.T) { - // Run a new container with the test executable and filter for the - // given test suite and name. - spec := testutil.NewSpecWithArgs(testBin, gtest.FilterTestFlag+"="+tc.FullName()) - - // Mark the root as writeable, as some tests attempt to - // write to the rootfs, and expect EACCES, not EROFS. - spec.Root.Readonly = false - - // Test spec comes with pre-defined mounts that we don't want. Reset it. 
- spec.Mounts = nil - if *useTmpfs { - // Forces '/tmp' to be mounted as tmpfs, otherwise test that rely on - // features only available in gVisor's internal tmpfs may fail. - spec.Mounts = append(spec.Mounts, specs.Mount{ - Destination: "/tmp", - Type: "tmpfs", - }) - } else { - // Use a gofer-backed directory as '/tmp'. - // - // Tests might be running in parallel, so make sure each has a - // unique test temp dir. - // - // Some tests (e.g., sticky) access this mount from other - // users, so make sure it is world-accessible. - tmpDir, err := ioutil.TempDir(testutil.TmpDir(), "") - if err != nil { - t.Fatalf("could not create temp dir: %v", err) - } - defer os.RemoveAll(tmpDir) - - if err := os.Chmod(tmpDir, 0777); err != nil { - t.Fatalf("could not chmod temp dir: %v", err) - } - - spec.Mounts = append(spec.Mounts, specs.Mount{ - Type: "bind", - Destination: "/tmp", - Source: tmpDir, - }) - } - - // Set environment variables that indicate we are - // running in gVisor with the given platform and network. - platformVar := "TEST_ON_GVISOR" - networkVar := "GVISOR_NETWORK" - env := append(os.Environ(), platformVar+"="+*platform, networkVar+"="+*network) - - // Remove env variables that cause the gunit binary to write output - // files, since they will stomp on eachother, and on the output files - // from this go test. - env = filterEnv(env, []string{"GUNIT_OUTPUT", "TEST_PREMATURE_EXIT_FILE", "XML_OUTPUT_FILE"}) - - // Remove shard env variables so that the gunit binary does not try to - // intepret them. - env = filterEnv(env, []string{"TEST_SHARD_INDEX", "TEST_TOTAL_SHARDS", "GTEST_SHARD_INDEX", "GTEST_TOTAL_SHARDS"}) - - // Set TEST_TMPDIR to /tmp, as some of the syscall tests require it to - // be backed by tmpfs. - for i, kv := range env { - if strings.HasPrefix(kv, "TEST_TMPDIR=") { - env[i] = "TEST_TMPDIR=/tmp" - break - } - } - - spec.Process.Env = env - - if *addUDSTree { - cleanup, err := setupUDSTree(spec) - if err != nil { - t.Fatalf("error creating UDS tree: %v", err) - } - defer cleanup() - } - - if err := runRunsc(tc, spec); err != nil { - t.Errorf("test %q failed with error %v, want nil", tc.FullName(), err) - } -} - -// filterEnv returns an environment with the blacklisted variables removed. -func filterEnv(env, blacklist []string) []string { - var out []string - for _, kv := range env { - ok := true - for _, k := range blacklist { - if strings.HasPrefix(kv, k+"=") { - ok = false - break - } - } - if ok { - out = append(out, kv) - } - } - return out -} - -func fatalf(s string, args ...interface{}) { - fmt.Fprintf(os.Stderr, s+"\n", args...) - os.Exit(1) -} - -func matchString(a, b string) (bool, error) { - return a == b, nil -} - -func main() { - flag.Parse() - if *testName == "" { - fatalf("test-name flag must be provided") - } - - log.SetLevel(log.Info) - if *debug { - log.SetLevel(log.Debug) - } - - if *platform != "native" && *runscPath == "" { - if err := testutil.ConfigureExePath(); err != nil { - panic(err.Error()) - } - *runscPath = specutils.ExePath - } - - // Make sure stdout and stderr are opened with O_APPEND, otherwise logs - // from outside the sandbox can (and will) stomp on logs from inside - // the sandbox. 
- for _, f := range []*os.File{os.Stdout, os.Stderr} { - flags, err := unix.FcntlInt(f.Fd(), unix.F_GETFL, 0) - if err != nil { - fatalf("error getting file flags for %v: %v", f, err) - } - if flags&unix.O_APPEND == 0 { - flags |= unix.O_APPEND - if _, err := unix.FcntlInt(f.Fd(), unix.F_SETFL, flags); err != nil { - fatalf("error setting file flags for %v: %v", f, err) - } - } - } - - // Get path to test binary. - fullTestName := filepath.Join(testDir, *testName) - testBin, err := testutil.FindFile(fullTestName) - if err != nil { - fatalf("FindFile(%q) failed: %v", fullTestName, err) - } - - // Get all test cases in each binary. - testCases, err := gtest.ParseTestCases(testBin) - if err != nil { - fatalf("ParseTestCases(%q) failed: %v", testBin, err) - } - - // Get subset of tests corresponding to shard. - indices, err := testutil.TestIndicesForShard(len(testCases)) - if err != nil { - fatalf("TestsForShard() failed: %v", err) - } - - // Run the tests. - var tests []testing.InternalTest - for _, tci := range indices { - // Capture tc. - tc := testCases[tci] - testName := fmt.Sprintf("%s_%s", tc.Suite, tc.Name) - tests = append(tests, testing.InternalTest{ - Name: testName, - F: func(t *testing.T) { - if *parallel { - t.Parallel() - } - if *platform == "native" { - // Run the test case on host. - runTestCaseNative(testBin, tc, t) - } else { - // Run the test case in runsc. - runTestCaseRunsc(testBin, tc, t) - } - }, - }) - } - - testing.Main(matchString, tests, nil, nil) -} diff --git a/test/syscalls/syscall_test_runner.sh b/test/syscalls/syscall_test_runner.sh deleted file mode 100755 index 864bb2de4..000000000 --- a/test/syscalls/syscall_test_runner.sh +++ /dev/null @@ -1,34 +0,0 @@ -#!/bin/bash - -# Copyright 2018 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# syscall_test_runner.sh is a simple wrapper around the go syscall test runner. -# It exists so that we can build the syscall test runner once, and use it for -# all syscall tests, rather than build it for each test run. - -set -euf -x -o pipefail - -echo -- "$@" - -if [[ -n "${TEST_UNDECLARED_OUTPUTS_DIR}" ]]; then - mkdir -p "${TEST_UNDECLARED_OUTPUTS_DIR}" - chmod a+rwx "${TEST_UNDECLARED_OUTPUTS_DIR}" -fi - -# Get location of syscall_test_runner binary. -readonly runner=$(find "${TEST_SRCDIR}" -name syscall_test_runner) - -# Pass the arguments of this script directly to the runner. 
-exec "${runner}" "$@" diff --git a/test/util/BUILD b/test/util/BUILD index 1f22ebe29..8b5a0f25c 100644 --- a/test/util/BUILD +++ b/test/util/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "cc_library", "cc_test", "gtest", "select_system") +load("//tools:defs.bzl", "cc_library", "cc_test", "gbenchmark", "gtest", "select_system") package( default_visibility = ["//:sandbox"], @@ -260,6 +260,7 @@ cc_library( "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/time", gtest, + gbenchmark, ], ) diff --git a/test/util/test_main.cc b/test/util/test_main.cc index 5c7ee0064..1f389e58f 100644 --- a/test/util/test_main.cc +++ b/test/util/test_main.cc @@ -16,5 +16,5 @@ int main(int argc, char** argv) { gvisor::testing::TestInit(&argc, &argv); - return RUN_ALL_TESTS(); + return gvisor::testing::RunAllTests(); } diff --git a/test/util/test_util.h b/test/util/test_util.h index 2d22b0eb8..c5cb9d6d6 100644 --- a/test/util/test_util.h +++ b/test/util/test_util.h @@ -771,6 +771,7 @@ std::string RunfilePath(std::string path); #endif void TestInit(int* argc, char*** argv); +int RunAllTests(void); } // namespace testing } // namespace gvisor diff --git a/test/util/test_util_impl.cc b/test/util/test_util_impl.cc index ba7c0a85b..7e1ad9e66 100644 --- a/test/util/test_util_impl.cc +++ b/test/util/test_util_impl.cc @@ -17,8 +17,12 @@ #include "gtest/gtest.h" #include "absl/flags/flag.h" #include "absl/flags/parse.h" +#include "benchmark/benchmark.h" #include "test/util/logging.h" +extern bool FLAGS_benchmark_list_tests; +extern std::string FLAGS_benchmark_filter; + namespace gvisor { namespace testing { @@ -26,6 +30,7 @@ void SetupGvisorDeathTest() {} void TestInit(int* argc, char*** argv) { ::testing::InitGoogleTest(argc, *argv); + benchmark::Initialize(argc, *argv); ::absl::ParseCommandLine(*argc, *argv); // Always mask SIGPIPE as it's common and tests aren't expected to handle it. @@ -34,5 +39,14 @@ void TestInit(int* argc, char*** argv) { TEST_CHECK(sigaction(SIGPIPE, &sa, nullptr) == 0); } +int RunAllTests() { + if (FLAGS_benchmark_list_tests || FLAGS_benchmark_filter != ".") { + benchmark::RunSpecifiedBenchmarks(); + return 0; + } else { + return RUN_ALL_TESTS(); + } +} + } // namespace testing } // namespace gvisor diff --git a/tools/bazeldefs/defs.bzl b/tools/bazeldefs/defs.bzl index 6798362dc..6f091d759 100644 --- a/tools/bazeldefs/defs.bzl +++ b/tools/bazeldefs/defs.bzl @@ -21,6 +21,7 @@ go_image = _go_image go_embed_data = _go_embed_data go_suffixes = _go_suffixes gtest = "@com_google_googletest//:gtest" +gbenchmark = "@com_google_benchmark//:benchmark" loopback = "//tools/bazeldefs:loopback" proto_library = native.proto_library pkg_deb = _pkg_deb diff --git a/tools/defs.bzl b/tools/defs.bzl index 39f035f12..4eece2d83 100644 --- a/tools/defs.bzl +++ b/tools/defs.bzl @@ -7,7 +7,7 @@ change for Google-internal and bazel-compatible rules. 
load("//tools/go_stateify:defs.bzl", "go_stateify") load("//tools/go_marshal:defs.bzl", "go_marshal", "marshal_deps", "marshal_test_deps") -load("//tools/bazeldefs:defs.bzl", "go_suffixes", _cc_binary = "cc_binary", _cc_flags_supplier = "cc_flags_supplier", _cc_library = "cc_library", _cc_proto_library = "cc_proto_library", _cc_test = "cc_test", _cc_toolchain = "cc_toolchain", _container_image = "container_image", _default_installer = "default_installer", _default_net_util = "default_net_util", _go_binary = "go_binary", _go_embed_data = "go_embed_data", _go_image = "go_image", _go_library = "go_library", _go_proto_library = "go_proto_library", _go_test = "go_test", _go_tool_library = "go_tool_library", _gtest = "gtest", _loopback = "loopback", _pkg_deb = "pkg_deb", _pkg_tar = "pkg_tar", _proto_library = "proto_library", _py_binary = "py_binary", _py_library = "py_library", _py_requirement = "py_requirement", _py_test = "py_test", _select_arch = "select_arch", _select_system = "select_system") +load("//tools/bazeldefs:defs.bzl", "go_suffixes", _cc_binary = "cc_binary", _cc_flags_supplier = "cc_flags_supplier", _cc_library = "cc_library", _cc_proto_library = "cc_proto_library", _cc_test = "cc_test", _cc_toolchain = "cc_toolchain", _container_image = "container_image", _default_installer = "default_installer", _default_net_util = "default_net_util", _gbenchmark = "gbenchmark", _go_binary = "go_binary", _go_embed_data = "go_embed_data", _go_image = "go_image", _go_library = "go_library", _go_proto_library = "go_proto_library", _go_test = "go_test", _go_tool_library = "go_tool_library", _gtest = "gtest", _loopback = "loopback", _pkg_deb = "pkg_deb", _pkg_tar = "pkg_tar", _proto_library = "proto_library", _py_binary = "py_binary", _py_library = "py_library", _py_requirement = "py_requirement", _py_test = "py_test", _select_arch = "select_arch", _select_system = "select_system") # Delegate directly. cc_binary = _cc_binary @@ -21,6 +21,7 @@ go_image = _go_image go_test = _go_test go_tool_library = _go_tool_library gtest = _gtest +gbenchmark = _gbenchmark pkg_deb = _pkg_deb pkg_tar = _pkg_tar py_library = _py_library -- cgit v1.2.3 From 72187fa7a9e1f3ee9d021681f4465777f91c13fe Mon Sep 17 00:00:00 2001 From: Adin Scannell Date: Thu, 20 Feb 2020 12:32:31 -0800 Subject: Import tags.bzl directly from tools/defs.bzl. This simplifies the script slightly. 
PiperOrigin-RevId: 296272077 --- tools/bazeldefs/defs.bzl | 2 -- tools/defs.bzl | 3 ++- 2 files changed, 2 insertions(+), 3 deletions(-) (limited to 'tools') diff --git a/tools/bazeldefs/defs.bzl b/tools/bazeldefs/defs.bzl index 6f091d759..905b16d41 100644 --- a/tools/bazeldefs/defs.bzl +++ b/tools/bazeldefs/defs.bzl @@ -8,7 +8,6 @@ load("@rules_pkg//:pkg.bzl", _pkg_deb = "pkg_deb", _pkg_tar = "pkg_tar") load("@io_bazel_rules_docker//go:image.bzl", _go_image = "go_image") load("@io_bazel_rules_docker//container:container.bzl", _container_image = "container_image") load("@pydeps//:requirements.bzl", _py_requirement = "requirement") -load("//tools/bazeldefs:tags.bzl", _go_suffixes = "go_suffixes") container_image = _container_image cc_binary = _cc_binary @@ -19,7 +18,6 @@ cc_test = _cc_test cc_toolchain = "@bazel_tools//tools/cpp:current_cc_toolchain" go_image = _go_image go_embed_data = _go_embed_data -go_suffixes = _go_suffixes gtest = "@com_google_googletest//:gtest" gbenchmark = "@com_google_benchmark//:benchmark" loopback = "//tools/bazeldefs:loopback" diff --git a/tools/defs.bzl b/tools/defs.bzl index 4eece2d83..ddefb72d0 100644 --- a/tools/defs.bzl +++ b/tools/defs.bzl @@ -7,7 +7,8 @@ change for Google-internal and bazel-compatible rules. load("//tools/go_stateify:defs.bzl", "go_stateify") load("//tools/go_marshal:defs.bzl", "go_marshal", "marshal_deps", "marshal_test_deps") -load("//tools/bazeldefs:defs.bzl", "go_suffixes", _cc_binary = "cc_binary", _cc_flags_supplier = "cc_flags_supplier", _cc_library = "cc_library", _cc_proto_library = "cc_proto_library", _cc_test = "cc_test", _cc_toolchain = "cc_toolchain", _container_image = "container_image", _default_installer = "default_installer", _default_net_util = "default_net_util", _gbenchmark = "gbenchmark", _go_binary = "go_binary", _go_embed_data = "go_embed_data", _go_image = "go_image", _go_library = "go_library", _go_proto_library = "go_proto_library", _go_test = "go_test", _go_tool_library = "go_tool_library", _gtest = "gtest", _loopback = "loopback", _pkg_deb = "pkg_deb", _pkg_tar = "pkg_tar", _proto_library = "proto_library", _py_binary = "py_binary", _py_library = "py_library", _py_requirement = "py_requirement", _py_test = "py_test", _select_arch = "select_arch", _select_system = "select_system") +load("//tools/bazeldefs:defs.bzl", _cc_binary = "cc_binary", _cc_flags_supplier = "cc_flags_supplier", _cc_library = "cc_library", _cc_proto_library = "cc_proto_library", _cc_test = "cc_test", _cc_toolchain = "cc_toolchain", _container_image = "container_image", _default_installer = "default_installer", _default_net_util = "default_net_util", _gbenchmark = "gbenchmark", _go_binary = "go_binary", _go_embed_data = "go_embed_data", _go_image = "go_image", _go_library = "go_library", _go_proto_library = "go_proto_library", _go_test = "go_test", _go_tool_library = "go_tool_library", _gtest = "gtest", _loopback = "loopback", _pkg_deb = "pkg_deb", _pkg_tar = "pkg_tar", _proto_library = "proto_library", _py_binary = "py_binary", _py_library = "py_library", _py_requirement = "py_requirement", _py_test = "py_test", _select_arch = "select_arch", _select_system = "select_system") +load("//tools/bazeldefs:tags.bzl", "go_suffixes") # Delegate directly. cc_binary = _cc_binary -- cgit v1.2.3 From d90d71474f4c82f742140fdf026821709845cece Mon Sep 17 00:00:00 2001 From: gVisor bot Date: Thu, 20 Feb 2020 14:28:31 -0800 Subject: Remove bytes read/written from marshal.Marshallable API. 
Users of the API only care about whether the copy in/out succeeds in their entirety, which is already signalled by the returned error. PiperOrigin-RevId: 296297843 --- pkg/sentry/kernel/rseq.go | 2 +- pkg/sentry/syscalls/linux/sys_stat.go | 6 ++---- tools/go_marshal/gomarshal/generator_interfaces.go | 21 +++++++++++---------- tools/go_marshal/gomarshal/generator_tests.go | 2 +- tools/go_marshal/marshal/marshal.go | 4 ++-- 5 files changed, 17 insertions(+), 18 deletions(-) (limited to 'tools') diff --git a/pkg/sentry/kernel/rseq.go b/pkg/sentry/kernel/rseq.go index 18416643b..ded95f532 100644 --- a/pkg/sentry/kernel/rseq.go +++ b/pkg/sentry/kernel/rseq.go @@ -304,7 +304,7 @@ func (t *Task) rseqAddrInterrupt() { } var cs linux.RSeqCriticalSection - if _, err := cs.CopyIn(t, critAddr); err != nil { + if err := cs.CopyIn(t, critAddr); err != nil { t.Debugf("Failed to copy critical section from %#x for rseq: %v", critAddr, err) t.forceSignal(linux.SIGSEGV, false /* unconditional */) t.SendSignal(SignalInfoPriv(linux.SIGSEGV)) diff --git a/pkg/sentry/syscalls/linux/sys_stat.go b/pkg/sentry/syscalls/linux/sys_stat.go index 8b66a9006..11f25e00d 100644 --- a/pkg/sentry/syscalls/linux/sys_stat.go +++ b/pkg/sentry/syscalls/linux/sys_stat.go @@ -131,8 +131,7 @@ func stat(t *kernel.Task, d *fs.Dirent, dirPath bool, statAddr usermem.Addr) err return err } s := statFromAttrs(t, d.Inode.StableAttr, uattr) - _, err = s.CopyOut(t, statAddr) - return err + return s.CopyOut(t, statAddr) } // fstat implements fstat for the given *fs.File. @@ -142,8 +141,7 @@ func fstat(t *kernel.Task, f *fs.File, statAddr usermem.Addr) error { return err } s := statFromAttrs(t, f.Dirent.Inode.StableAttr, uattr) - _, err = s.CopyOut(t, statAddr) - return err + return s.CopyOut(t, statAddr) } // Statx implements linux syscall statx(2). 
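For illustration, a minimal caller-side sketch under the new error-only
signatures. The wrapper function, package name and import paths below are
assumptions made for the sketch, not part of this change; only the CopyOut
signature and the linux.Stat usage come from the patch itself:

    package example

    import (
            "gvisor.dev/gvisor/pkg/abi/linux"
            "gvisor.dev/gvisor/pkg/usermem"
            "gvisor.dev/gvisor/tools/go_marshal/marshal"
    )

    // copyStatOut writes a linux.Stat to user memory. With the byte count
    // removed from the API, callers simply propagate the returned error.
    func copyStatOut(task marshal.Task, s *linux.Stat, addr usermem.Addr) error {
            return s.CopyOut(task, addr)
    }
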
diff --git a/tools/go_marshal/gomarshal/generator_interfaces.go b/tools/go_marshal/gomarshal/generator_interfaces.go index 3aa299ccd..834c58cee 100644 --- a/tools/go_marshal/gomarshal/generator_interfaces.go +++ b/tools/go_marshal/gomarshal/generator_interfaces.go @@ -507,13 +507,14 @@ func (g *interfaceGenerator) emitMarshallable() { g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n") g.recordUsedImport("marshal") g.recordUsedImport("usermem") - g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName()) + g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName()) g.inIndent(func() { fallback := func() { g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName()) g.emit("buf := task.CopyScratchBuffer(%s.SizeBytes())\n", g.r) g.emit("%s.MarshalBytes(buf)\n", g.r) - g.emit("return task.CopyOutBytes(addr, buf)\n") + g.emit("_, err := task.CopyOutBytes(addr, buf)\n") + g.emit("return err\n") } if thisPacked { g.recordUsedImport("reflect") @@ -539,11 +540,11 @@ func (g *interfaceGenerator) emitMarshallable() { g.emit("hdr.Len = %s.SizeBytes()\n", g.r) g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) - g.emit("len, err := task.CopyOutBytes(addr, buf)\n") + g.emit("_, err := task.CopyOutBytes(addr, buf)\n") g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) g.emit("// must live until after the CopyOutBytes.\n") g.emit("runtime.KeepAlive(%s)\n", g.r) - g.emit("return len, err\n") + g.emit("return err\n") } else { fallback() } @@ -553,20 +554,20 @@ func (g *interfaceGenerator) emitMarshallable() { g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n") g.recordUsedImport("marshal") g.recordUsedImport("usermem") - g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName()) + g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName()) g.inIndent(func() { fallback := func() { g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName()) g.emit("buf := task.CopyScratchBuffer(%s.SizeBytes())\n", g.r) - g.emit("n, err := task.CopyInBytes(addr, buf)\n") + g.emit("_, err := task.CopyInBytes(addr, buf)\n") g.emit("if err != nil {\n") g.inIndent(func() { - g.emit("return n, err\n") + g.emit("return err\n") }) g.emit("}\n") g.emit("%s.UnmarshalBytes(buf)\n", g.r) - g.emit("return n, nil\n") + g.emit("return nil\n") } if thisPacked { g.recordUsedImport("reflect") @@ -592,11 +593,11 @@ func (g *interfaceGenerator) emitMarshallable() { g.emit("hdr.Len = %s.SizeBytes()\n", g.r) g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) - g.emit("len, err := task.CopyInBytes(addr, buf)\n") + g.emit("_, err := task.CopyInBytes(addr, buf)\n") g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) g.emit("// must live until after the CopyInBytes.\n") g.emit("runtime.KeepAlive(%s)\n", g.r) - g.emit("return len, err\n") + g.emit("return err\n") } else { fallback() } diff --git a/tools/go_marshal/gomarshal/generator_tests.go b/tools/go_marshal/gomarshal/generator_tests.go index 8c28b00d0..2326e7a07 100644 --- a/tools/go_marshal/gomarshal/generator_tests.go +++ b/tools/go_marshal/gomarshal/generator_tests.go @@ -92,7 +92,7 @@ func (g *testGenerator) emitTestNonZeroSize() { g.emit("x := &%s{}\n", g.typeName()) g.emit("if x.SizeBytes() == 0 {\n") g.inIndent(func() { - g.emit("t.Fatal(\"Marshallable.Size() 
should not return zero\")\n") + g.emit("t.Fatal(\"Marshallable.SizeBytes() should not return zero\")\n") }) g.emit("}\n") }) diff --git a/tools/go_marshal/marshal/marshal.go b/tools/go_marshal/marshal/marshal.go index 20353850d..f129788e0 100644 --- a/tools/go_marshal/marshal/marshal.go +++ b/tools/go_marshal/marshal/marshal.go @@ -91,12 +91,12 @@ type Marshallable interface { // marshalled does not escape. The implementation should avoid creating // extra copies in memory by directly deserializing to the object's // underlying memory. - CopyIn(task Task, addr usermem.Addr) (int, error) + CopyIn(task Task, addr usermem.Addr) error // CopyOut serializes a Marshallable type to a task's memory. This may only // be called from a task goroutine. This is more efficient than calling // MarshalUnsafe on Marshallable.Packed types, as the type being serialized // does not escape. The implementation should avoid creating extra copies in // memory by directly serializing from the object's underlying memory. - CopyOut(task Task, addr usermem.Addr) (int, error) + CopyOut(task Task, addr usermem.Addr) error } -- cgit v1.2.3 From f1b72752e5de2abc3c409a6b7447224620b7c11b Mon Sep 17 00:00:00 2001 From: gVisor bot Date: Thu, 20 Feb 2020 16:22:45 -0800 Subject: Implement automated marshalling for newtypes on primitives. PiperOrigin-RevId: 296322954 --- tools/defs.bzl | 8 +- tools/go_marshal/BUILD | 5 + tools/go_marshal/gomarshal/generator.go | 43 ++- tools/go_marshal/gomarshal/generator_interfaces.go | 296 ++++++++++++++++----- tools/go_marshal/gomarshal/generator_tests.go | 15 +- tools/go_marshal/test/test.go | 10 + 6 files changed, 286 insertions(+), 91 deletions(-) (limited to 'tools') diff --git a/tools/defs.bzl b/tools/defs.bzl index ddefb72d0..45c065459 100644 --- a/tools/defs.bzl +++ b/tools/defs.bzl @@ -85,7 +85,7 @@ def go_imports(name, src, out): cmd = ("$(location @org_golang_x_tools//cmd/goimports:goimports) $(SRCS) > $@"), ) -def go_library(name, srcs, deps = [], imports = [], stateify = True, marshal = False, **kwargs): +def go_library(name, srcs, deps = [], imports = [], stateify = True, marshal = False, marshal_debug = False, **kwargs): """Wraps the standard go_library and does stateification and marshalling. The recommended way is to use this rule with mostly identical configuration as the native @@ -108,6 +108,7 @@ def go_library(name, srcs, deps = [], imports = [], stateify = True, marshal = F imports: imports required for stateify. stateify: whether statify is enabled (default: true). marshal: whether marshal is enabled (default: false). + marshal_debug: whether the gomarshal tools emits debugging output (default: false). **kwargs: standard go_library arguments. 
""" all_srcs = srcs @@ -146,7 +147,10 @@ def go_library(name, srcs, deps = [], imports = [], stateify = True, marshal = F go_marshal( name = name + suffix + "_abi_autogen", srcs = src_subset, - debug = False, + debug = select({ + "//tools/go_marshal:marshal_config_verbose": True, + "//conditions:default": marshal_debug, + }), imports = imports, package = name, ) diff --git a/tools/go_marshal/BUILD b/tools/go_marshal/BUILD index 80d9c0504..be49cf9c8 100644 --- a/tools/go_marshal/BUILD +++ b/tools/go_marshal/BUILD @@ -12,3 +12,8 @@ go_binary( "//tools/go_marshal/gomarshal", ], ) + +config_setting( + name = "marshal_config_verbose", + values = {"define": "gomarshal=verbose"}, +) diff --git a/tools/go_marshal/gomarshal/generator.go b/tools/go_marshal/gomarshal/generator.go index 0fa868415..d365a1f3c 100644 --- a/tools/go_marshal/gomarshal/generator.go +++ b/tools/go_marshal/gomarshal/generator.go @@ -44,7 +44,8 @@ const ( // All recievers are single letters, so we don't allow import aliases to be a // single letter. var badIdents = []string{ - "addr", "blk", "buf", "dst", "dsts", "err", "hdr", "len", "ptr", "src", "srcs", "task", "val", + "addr", "blk", "buf", "dst", "dsts", "err", "hdr", "idx", "inner", "len", + "ptr", "src", "srcs", "task", "val", // All single-letter identifiers. } @@ -193,9 +194,9 @@ func (g *Generator) parse() ([]*ast.File, []*token.FileSet, error) { return files, fsets, nil } -// collectMarshallabeTypes walks the parsed AST and collects a list of type +// collectMarshallableTypes walks the parsed AST and collects a list of type // declarations for which we need to generate the Marshallable interface. -func (g *Generator) collectMarshallabeTypes(a *ast.File, f *token.FileSet) []*ast.TypeSpec { +func (g *Generator) collectMarshallableTypes(a *ast.File, f *token.FileSet) []*ast.TypeSpec { var types []*ast.TypeSpec for _, decl := range a.Decls { gdecl, ok := decl.(*ast.GenDecl) @@ -222,14 +223,22 @@ func (g *Generator) collectMarshallabeTypes(a *ast.File, f *token.FileSet) []*as continue } for _, spec := range gdecl.Specs { - // We already confirmed we're in a type declaration earlier. + // We already confirmed we're in a type declaration earlier, so this + // cast will succeed. t := spec.(*ast.TypeSpec) - if _, ok := t.Type.(*ast.StructType); ok { - debugfAt(f.Position(t.Pos()), "Collected marshallable type %s.\n", t.Name.Name) + switch t.Type.(type) { + case *ast.StructType: + debugfAt(f.Position(t.Pos()), "Collected marshallable struct %s.\n", t.Name.Name) + types = append(types, t) + continue + case *ast.Ident: // Newtype on primitive. + debugfAt(f.Position(t.Pos()), "Collected marshallable newtype on primitive %s.\n", t.Name.Name) types = append(types, t) continue } - debugf("Skipping declaration %v since it's not a struct declaration.\n", gdecl) + // A user specifically requested marshalling on this type, but we + // don't support it. + abortAt(f.Position(t.Pos()), fmt.Sprintf("Marshalling codegen was requested on type '%s', but go-marshal doesn't support this kind of declaration.\n", t.Name)) } } return types @@ -269,12 +278,20 @@ func (g *Generator) collectImports(a *ast.File, f *token.FileSet) map[string]imp } func (g *Generator) generateOne(t *ast.TypeSpec, fset *token.FileSet) *interfaceGenerator { - // We're guaranteed to have only struct type specs by now. See - // Generator.collectMarshallabeTypes. 
i := newInterfaceGenerator(t, fset) - i.validate() - i.emitMarshallable() - return i + switch ty := t.Type.(type) { + case *ast.StructType: + i.validateStruct() + i.emitMarshallableForStruct() + return i + case *ast.Ident: + i.validatePrimitiveNewtype(ty) + i.emitMarshallableForPrimitiveNewtype() + return i + default: + // This should've been filtered out by collectMarshallabeTypes. + panic(fmt.Sprintf("Unexpected type %+v", ty)) + } } // generateOneTestSuite generates a test suite for the automatically generated @@ -320,7 +337,7 @@ func (g *Generator) Run() error { for i, a := range asts { // Collect type declarations marked for code generation and generate // Marshallable interfaces. - for _, t := range g.collectMarshallabeTypes(a, fsets[i]) { + for _, t := range g.collectMarshallableTypes(a, fsets[i]) { impl := g.generateOne(t, fsets[i]) // Collect Marshallable types referenced by the generated code. for ref, _ := range impl.ms { diff --git a/tools/go_marshal/gomarshal/generator_interfaces.go b/tools/go_marshal/gomarshal/generator_interfaces.go index 834c58cee..ea1af998e 100644 --- a/tools/go_marshal/gomarshal/generator_interfaces.go +++ b/tools/go_marshal/gomarshal/generator_interfaces.go @@ -55,9 +55,6 @@ func (g *interfaceGenerator) typeName() string { // newinterfaceGenerator creates a new interface generator. func newInterfaceGenerator(t *ast.TypeSpec, fset *token.FileSet) *interfaceGenerator { - if _, ok := t.Type.(*ast.StructType); !ok { - panic(fmt.Sprintf("Attempting to generate code for a not struct type %v", t)) - } g := &interfaceGenerator{ t: t, r: receiverName(t), @@ -103,9 +100,31 @@ func (g *interfaceGenerator) abortAt(p token.Pos, msg string) { abortAt(g.f.Position(p), msg) } -// validate ensures the type we're working with can be marshalled. These checks -// are done ahead of time and in one place so we can make assumptions later. -func (g *interfaceGenerator) validate() { +func (g *interfaceGenerator) validatePrimitiveNewtype(t *ast.Ident) { + switch t.Name { + case "int8", "uint8", "byte", "int16", "uint16", "int32", "uint32", "int64", "uint64": + // These are the only primitive types we're allow. Below, we provide + // suggestions for some disallowed types and reject them, then attempt + // to marshal any remaining types by invoking the marshal.Marshallable + // interface on them. If these types don't actually implement + // marshal.Marshallable, compilation of the generated code will fail + // with an appropriate error message. + return + case "int": + g.abortAt(t.Pos(), "Type 'int' has ambiguous width, use int32 or int64") + case "uint": + g.abortAt(t.Pos(), "Type 'uint' has ambiguous width, use uint32 or uint64") + case "string": + g.abortAt(t.Pos(), "Type 'string' is dynamically-sized and cannot be marshalled, use a fixed size byte array '[...]byte' instead") + default: + debugfAt(g.f.Position(t.Pos()), fmt.Sprintf("Found derived type '%s', will attempt dispatch via marshal.Marshallable.\n", t.Name)) + } +} + +// validateStruct ensures the type we're working with can be marshalled. These +// checks are done ahead of time and in one place so we can make assumptions +// later. 
+func (g *interfaceGenerator) validateStruct() { g.forEachField(func(f *ast.Field) { if len(f.Names) == 0 { g.abortAt(f.Pos(), "Cannot marshal structs with embedded fields, give the field a name; use '_' for anonymous fields such as padding fields") @@ -115,25 +134,7 @@ func (g *interfaceGenerator) validate() { g.forEachField(func(f *ast.Field) { fieldDispatcher{ primitive: func(_, t *ast.Ident) { - switch t.Name { - case "int8", "uint8", "byte", "int16", "uint16", "int32", "uint32", "int64", "uint64": - // These are the only primitive types we're allow. Below, we - // provide suggestions for some disallowed types and reject - // them, then attempt to marshal any remaining types by - // invoking the marshal.Marshallable interface on them. If - // these types don't actually implement - // marshal.Marshallable, compilation of the generated code - // will fail with an appropriate error message. - return - case "int": - g.abortAt(f.Pos(), "Type 'int' has ambiguous width, use int32 or int64") - case "uint": - g.abortAt(f.Pos(), "Type 'uint' has ambiguous width, use uint32 or uint64") - case "string": - g.abortAt(f.Pos(), "Type 'string' is dynamically-sized and cannot be marshalled, use a fixed size byte array '[...]byte' instead") - default: - debugfAt(g.f.Position(f.Pos()), fmt.Sprintf("Found derived type '%s', will attempt dispatch via marshal.Marshallable.\n", t.Name)) - } + g.validatePrimitiveNewtype(t) }, selector: func(_, _, _ *ast.Ident) { // No validation to perform on selector fields. However this @@ -190,7 +191,8 @@ func (g *interfaceGenerator) shiftDynamic(bufVar, name string) { g.emit("%s = %s[%s.SizeBytes():]\n", bufVar, bufVar, name) } -func (g *interfaceGenerator) marshalScalar(accessor, typ string, bufVar string) { +// marshalStructFieldScalar writes a single scalar field from a struct to a byte slice. +func (g *interfaceGenerator) marshalStructFieldScalar(accessor, typ, bufVar string) { switch typ { case "int8", "uint8", "byte": g.emit("%s[0] = byte(%s)\n", bufVar, accessor) @@ -213,43 +215,27 @@ func (g *interfaceGenerator) marshalScalar(accessor, typ string, bufVar string) } } -func (g *interfaceGenerator) unmarshalScalar(accessor, typ string, bufVar string) { +// unmarshalStructFieldScalar reads a single scalar field from a struct, from a +// byte slice. 
+func (g *interfaceGenerator) unmarshalStructFieldScalar(accessor, typ, bufVar string) { switch typ { - case "int8": - g.emit("%s = int8(%s[0])\n", accessor, bufVar) - g.shift(bufVar, 1) - case "uint8": - g.emit("%s = uint8(%s[0])\n", accessor, bufVar) - g.shift(bufVar, 1) case "byte": g.emit("%s = %s[0]\n", accessor, bufVar) g.shift(bufVar, 1) - - case "int16": - g.recordUsedImport("usermem") - g.emit("%s = int16(usermem.ByteOrder.Uint16(%s[:2]))\n", accessor, bufVar) - g.shift(bufVar, 2) - case "uint16": + case "int8", "uint8": + g.emit("%s = %s(%s[0])\n", accessor, typ, bufVar) + g.shift(bufVar, 1) + case "int16", "uint16": g.recordUsedImport("usermem") - g.emit("%s = usermem.ByteOrder.Uint16(%s[:2])\n", accessor, bufVar) + g.emit("%s = %s(usermem.ByteOrder.Uint16(%s[:2]))\n", accessor, typ, bufVar) g.shift(bufVar, 2) - - case "int32": - g.recordUsedImport("usermem") - g.emit("%s = int32(usermem.ByteOrder.Uint32(%s[:4]))\n", accessor, bufVar) - g.shift(bufVar, 4) - case "uint32": + case "int32", "uint32": g.recordUsedImport("usermem") - g.emit("%s = usermem.ByteOrder.Uint32(%s[:4])\n", accessor, bufVar) + g.emit("%s = %s(usermem.ByteOrder.Uint32(%s[:4]))\n", accessor, typ, bufVar) g.shift(bufVar, 4) - - case "int64": - g.recordUsedImport("usermem") - g.emit("%s = int64(usermem.ByteOrder.Uint64(%s[:8]))\n", accessor, bufVar) - g.shift(bufVar, 8) - case "uint64": + case "int64", "uint64": g.recordUsedImport("usermem") - g.emit("%s = usermem.ByteOrder.Uint64(%s[:8])\n", accessor, bufVar) + g.emit("%s = %s(usermem.ByteOrder.Uint64(%s[:8]))\n", accessor, typ, bufVar) g.shift(bufVar, 8) default: g.emit("%s.UnmarshalBytes(%s[:%s.SizeBytes()])\n", accessor, bufVar, accessor) @@ -258,6 +244,49 @@ func (g *interfaceGenerator) unmarshalScalar(accessor, typ string, bufVar string } } +// marshalPrimitiveScalar writes a single primitive variable to a byte slice. +func (g *interfaceGenerator) marshalPrimitiveScalar(accessor, typ, bufVar string) { + switch typ { + case "int8", "uint8", "byte": + g.emit("%s[0] = byte(*%s)\n", bufVar, accessor) + case "int16", "uint16": + g.recordUsedImport("usermem") + g.emit("usermem.ByteOrder.PutUint16(%s[:2], uint16(*%s))\n", bufVar, accessor) + case "int32", "uint32": + g.recordUsedImport("usermem") + g.emit("usermem.ByteOrder.PutUint32(%s[:4], uint32(*%s))\n", bufVar, accessor) + case "int64", "uint64": + g.recordUsedImport("usermem") + g.emit("usermem.ByteOrder.PutUint64(%s[:8], uint64(*%s))\n", bufVar, accessor) + default: + g.emit("inner := (*%s)(%s)\n", typ, accessor) + g.emit("inner.MarshalBytes(%s[:%s.SizeBytes()])\n", bufVar, accessor) + } +} + +// unmarshalPrimitiveScalar read a single primitive variable from a byte slice. 
+func (g *interfaceGenerator) unmarshalPrimitiveScalar(accessor, typ, bufVar, typeCast string) { + switch typ { + case "byte": + g.emit("*%s = %s(%s[0])\n", accessor, typeCast, bufVar) + case "int8", "uint8": + g.emit("*%s = %s(%s(%s[0]))\n", accessor, typeCast, typ, bufVar) + case "int16", "uint16": + g.recordUsedImport("usermem") + g.emit("*%s = %s(%s(usermem.ByteOrder.Uint16(%s[:2])))\n", accessor, typeCast, typ, bufVar) + case "int32", "uint32": + g.recordUsedImport("usermem") + g.emit("*%s = %s(%s(usermem.ByteOrder.Uint32(%s[:4])))\n", accessor, typeCast, typ, bufVar) + + case "int64", "uint64": + g.recordUsedImport("usermem") + g.emit("*%s = %s(%s(usermem.ByteOrder.Uint64(%s[:8])))\n", accessor, typeCast, typ, bufVar) + default: + g.emit("inner := (*%s)(%s)\n", typ, accessor) + g.emit("inner.UnmarshalBytes(%s[:%s.SizeBytes()])\n", bufVar, accessor) + } +} + // areFieldsPackedExpression returns a go expression checking whether g.t's fields are // packed. Returns "", false if g.t has no fields that may be potentially // packed, otherwise returns , true, where is an expression @@ -274,7 +303,7 @@ func (g *interfaceGenerator) areFieldsPackedExpression() (string, bool) { return strings.Join(cs, " && "), true } -func (g *interfaceGenerator) emitMarshallable() { +func (g *interfaceGenerator) emitMarshallableForStruct() { // Is g.t a packed struct without consideing field types? thisPacked := true g.forEachField(func(f *ast.Field) { @@ -357,10 +386,10 @@ func (g *interfaceGenerator) emitMarshallable() { } return } - g.marshalScalar(g.fieldAccessor(n), t.Name, "dst") + g.marshalStructFieldScalar(g.fieldAccessor(n), t.Name, "dst") }, selector: func(n, tX, tSel *ast.Ident) { - g.marshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "dst") + g.marshalStructFieldScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "dst") }, array: func(n, t *ast.Ident, size int) { if n.Name == "_" { @@ -377,9 +406,9 @@ func (g *interfaceGenerator) emitMarshallable() { return } - g.emit("for i := 0; i < %d; i++ {\n", size) + g.emit("for idx := 0; idx < %d; idx++ {\n", size) g.inIndent(func() { - g.marshalScalar(fmt.Sprintf("%s[i]", g.fieldAccessor(n)), t.Name, "dst") + g.marshalStructFieldScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "dst") }) g.emit("}\n") }, @@ -406,10 +435,10 @@ func (g *interfaceGenerator) emitMarshallable() { } return } - g.unmarshalScalar(g.fieldAccessor(n), t.Name, "src") + g.unmarshalStructFieldScalar(g.fieldAccessor(n), t.Name, "src") }, selector: func(n, tX, tSel *ast.Ident) { - g.unmarshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "src") + g.unmarshalStructFieldScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "src") }, array: func(n, t *ast.Ident, size int) { if n.Name == "_" { @@ -426,9 +455,9 @@ func (g *interfaceGenerator) emitMarshallable() { return } - g.emit("for i := 0; i < %d; i++ {\n", size) + g.emit("for idx := 0; idx < %d; idx++ {\n", size) g.inIndent(func() { - g.unmarshalScalar(fmt.Sprintf("%s[i]", g.fieldAccessor(n)), t.Name, "src") + g.unmarshalStructFieldScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "src") }) g.emit("}\n") }, @@ -650,3 +679,144 @@ func (g *interfaceGenerator) emitMarshallable() { }) g.emit("}\n\n") } + +// emitMarshallableForPrimitiveNewtype outputs code to implement the +// marshal.Marshallable interface for a newtype on a primitive. 
Primitive +// newtypes are always packed, so we can omit the various fallbacks required for +// non-packed structs. +func (g *interfaceGenerator) emitMarshallableForPrimitiveNewtype() { + g.recordUsedImport("io") + g.recordUsedImport("marshal") + g.recordUsedImport("reflect") + g.recordUsedImport("runtime") + g.recordUsedImport("safecopy") + g.recordUsedImport("unsafe") + g.recordUsedImport("usermem") + + nt := g.t.Type.(*ast.Ident) + + g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n") + g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName()) + g.inIndent(func() { + if size, dynamic := g.scalarSize(nt); !dynamic { + g.emit("return %d\n", size) + } else { + g.emit("return (*%s)(nil).SizeBytes()\n", nt.Name) + } + }) + g.emit("}\n\n") + + g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n") + g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.marshalPrimitiveScalar(g.r, nt.Name, "dst") + }) + g.emit("}\n\n") + + g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n") + g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.unmarshalPrimitiveScalar(g.r, nt.Name, "src", g.typeName()) + }) + g.emit("}\n\n") + + g.emit("// Packed implements marshal.Marshallable.Packed.\n") + g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("// Scalar newtypes are always packed.\n") + g.emit("return true\n") + }) + g.emit("}\n\n") + + g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n") + g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r) + }) + g.emit("}\n\n") + + g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n") + g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r) + }) + g.emit("}\n\n") + + g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n") + g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName()) + g.inIndent(func() { + // Fast serialization. + g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) + g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) + g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") + g.emit("ptr := unsafe.Pointer(%s)\n", g.r) + g.emit("val := uintptr(ptr)\n") + g.emit("val = val^0\n\n") + + g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) + g.emit("var buf []byte\n") + g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") + g.emit("hdr.Data = val\n") + g.emit("hdr.Len = %s.SizeBytes()\n", g.r) + g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) + + g.emit("_, err := task.CopyOutBytes(addr, buf)\n") + g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) + g.emit("// must live until after the CopyOutBytes.\n") + g.emit("runtime.KeepAlive(%s)\n", g.r) + g.emit("return err\n") + }) + g.emit("}\n\n") + + g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n") + g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("// Bypass escape analysis on %s. 
The no-op arithmetic operation on the\n", g.r) + g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) + g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") + g.emit("ptr := unsafe.Pointer(%s)\n", g.r) + g.emit("val := uintptr(ptr)\n") + g.emit("val = val^0\n\n") + + g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) + g.emit("var buf []byte\n") + g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") + g.emit("hdr.Data = val\n") + g.emit("hdr.Len = %s.SizeBytes()\n", g.r) + g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) + + g.emit("_, err := task.CopyInBytes(addr, buf)\n") + g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) + g.emit("// must live until after the CopyInBytes.\n") + g.emit("runtime.KeepAlive(%s)\n", g.r) + g.emit("return err\n") + }) + g.emit("}\n\n") + + g.emit("// WriteTo implements io.WriterTo.WriteTo.\n") + g.emit("func (%s *%s) WriteTo(w io.Writer) (int64, error) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) + g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) + g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") + g.emit("ptr := unsafe.Pointer(%s)\n", g.r) + g.emit("val := uintptr(ptr)\n") + g.emit("val = val^0\n\n") + + g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) + g.emit("var buf []byte\n") + g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") + g.emit("hdr.Data = val\n") + g.emit("hdr.Len = %s.SizeBytes()\n", g.r) + g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) + + g.emit("len, err := w.Write(buf)\n") + g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) + g.emit("// must live until after the Write.\n") + g.emit("runtime.KeepAlive(%s)\n", g.r) + g.emit("return int64(len), err\n") + + }) + g.emit("}\n\n") + +} diff --git a/tools/go_marshal/gomarshal/generator_tests.go b/tools/go_marshal/gomarshal/generator_tests.go index 2326e7a07..8ba47eb67 100644 --- a/tools/go_marshal/gomarshal/generator_tests.go +++ b/tools/go_marshal/gomarshal/generator_tests.go @@ -49,9 +49,6 @@ type testGenerator struct { } func newTestGenerator(t *ast.TypeSpec) *testGenerator { - if _, ok := t.Type.(*ast.StructType); !ok { - panic(fmt.Sprintf("Attempting to generate code for a not struct type %v", t)) - } g := &testGenerator{ t: t, r: receiverName(t), @@ -69,14 +66,6 @@ func (g *testGenerator) typeName() string { return g.t.Name.Name } -func (g *testGenerator) forEachField(fn func(f *ast.Field)) { - // This is guaranteed to succeed because g.t is always a struct. 
- st := g.t.Type.(*ast.StructType) - for _, field := range st.Fields.List { - fn(field) - } -} - func (g *testGenerator) testFuncName(base string) string { return fmt.Sprintf("%s%s", base, strings.Title(g.t.Name.Name)) } @@ -89,7 +78,7 @@ func (g *testGenerator) inTestFunction(name string, body func()) { func (g *testGenerator) emitTestNonZeroSize() { g.inTestFunction("TestSizeNonZero", func() { - g.emit("x := &%s{}\n", g.typeName()) + g.emit("var x %v\n", g.typeName()) g.emit("if x.SizeBytes() == 0 {\n") g.inIndent(func() { g.emit("t.Fatal(\"Marshallable.SizeBytes() should not return zero\")\n") @@ -100,7 +89,7 @@ func (g *testGenerator) emitTestNonZeroSize() { func (g *testGenerator) emitTestSuspectAlignment() { g.inTestFunction("TestSuspectAlignment", func() { - g.emit("x := %s{}\n", g.typeName()) + g.emit("var x %v\n", g.typeName()) g.emit("analysis.AlignmentCheck(t, reflect.TypeOf(x))\n") }) } diff --git a/tools/go_marshal/test/test.go b/tools/go_marshal/test/test.go index 8de02d707..93229dedb 100644 --- a/tools/go_marshal/test/test.go +++ b/tools/go_marshal/test/test.go @@ -103,3 +103,13 @@ type Stat struct { CTime Timespec _ [3]int64 } + +// SignalSet is an example marshallable newtype on a primitive. +// +// +marshal +type SignalSet uint64 + +// SignalSetAlias is an example newtype on another marshallable type. +// +// +marshal +type SignalSetAlias SignalSet -- cgit v1.2.3 From 3733499952c056cc8496beb01c72dcf53177048e Mon Sep 17 00:00:00 2001 From: Zach Koopmans Date: Fri, 21 Feb 2020 13:17:44 -0800 Subject: Fix master installer. Sometimes, when we start a new instance, the file lock on "apt" is locked. Add a loop to the master installer. In addition, the "apt-get install" fails to register runsc in docker, so run the appropriate scripts to get that to happen. Also, add some helpful log messages. PiperOrigin-RevId: 296497357 --- benchmarks/harness/machine.py | 9 ++++++--- benchmarks/harness/ssh_connection.py | 9 +++++++-- tools/installers/master.sh | 17 ++++++++++++++++- 3 files changed, 29 insertions(+), 6 deletions(-) (limited to 'tools') diff --git a/benchmarks/harness/machine.py b/benchmarks/harness/machine.py index 3d32d3dda..5bdc4aa85 100644 --- a/benchmarks/harness/machine.py +++ b/benchmarks/harness/machine.py @@ -43,6 +43,8 @@ from benchmarks.harness import machine_mocks from benchmarks.harness import ssh_connection from benchmarks.harness import tunnel_dispatcher +log = logging.getLogger(__name__) + class Machine(object): """The machine object is the primary object for benchmarks. @@ -236,9 +238,10 @@ class RemoteMachine(Machine): archive=archive, dir=harness.REMOTE_INSTALLERS_PATH)) self._has_installers = True - # Execute the remote installer. - self.run("sudo {dir}/{file}".format( - dir=harness.REMOTE_INSTALLERS_PATH, file=installer)) + # Execute the remote installer. + self.run("sudo {dir}/{file}".format( + dir=harness.REMOTE_INSTALLERS_PATH, file=installer)) + if results: results[index] = True diff --git a/benchmarks/harness/ssh_connection.py b/benchmarks/harness/ssh_connection.py index a50e34293..b8c8e42d4 100644 --- a/benchmarks/harness/ssh_connection.py +++ b/benchmarks/harness/ssh_connection.py @@ -13,7 +13,7 @@ # limitations under the License. """SSHConnection handles the details of SSH connections.""" - +import logging import os import warnings @@ -24,6 +24,8 @@ from benchmarks import harness # Get rid of paramiko Cryptography Warnings. 
warnings.filterwarnings(action="ignore", module=".*paramiko.*") +log = logging.getLogger(__name__) + def send_one_file(client: paramiko.SSHClient, path: str, remote_dir: str) -> str: @@ -94,10 +96,13 @@ class SSHConnection: The contents of stdout and stderr. """ with self._client() as client: + log.info("running command: %s", cmd) _, stdout, stderr = client.exec_command(command=cmd) - stdout.channel.recv_exit_status() + log.info("returned status: %d", stdout.channel.recv_exit_status()) stdout = stdout.read().decode("utf-8") stderr = stderr.read().decode("utf-8") + log.info("stdout: %s", stdout) + log.info("stderr: %s", stderr) return stdout, stderr def send_workload(self, name: str) -> str: diff --git a/tools/installers/master.sh b/tools/installers/master.sh index 7b1956454..52f9734a6 100755 --- a/tools/installers/master.sh +++ b/tools/installers/master.sh @@ -15,6 +15,21 @@ # limitations under the License. # Install runsc from the master branch. +set -e + curl -fsSL https://gvisor.dev/archive.key | sudo apt-key add - add-apt-repository "deb https://storage.googleapis.com/gvisor/releases release main" -apt-get update && apt-get install -y runsc +while true; do + if apt-get update; then + apt-get install -y runsc + break + fi + result=$? + # Check if apt update failed to aquire the file lock. + if [[ $result -ne 100 ]]; then + exit $result + fi +done +runsc install +service docker restart + -- cgit v1.2.3 From 10aa4d3b343255db45f5ca4ff7b51f21a309e10b Mon Sep 17 00:00:00 2001 From: Adin Scannell Date: Fri, 21 Feb 2020 15:05:20 -0800 Subject: Factor platform tags. PiperOrigin-RevId: 296519566 --- test/runner/defs.bzl | 58 ++++++++++++++----------------------------- tools/bazeldefs/platforms.bzl | 17 +++++++++++++ tools/defs.bzl | 3 +++ 3 files changed, 39 insertions(+), 39 deletions(-) create mode 100644 tools/bazeldefs/platforms.bzl (limited to 'tools') diff --git a/test/runner/defs.bzl b/test/runner/defs.bzl index 5e97c1867..56743a526 100644 --- a/test/runner/defs.bzl +++ b/test/runner/defs.bzl @@ -1,6 +1,6 @@ """Defines a rule for syscall test targets.""" -load("//tools:defs.bzl", "loopback") +load("//tools:defs.bzl", "default_platform", "loopback", "platforms") def _runner_test_impl(ctx): # Generate a runner binary. @@ -94,19 +94,6 @@ def _syscall_test( # Disable off-host networking. tags.append("requires-net:loopback") - # Add tag to prevent the tests from running in a Bazel sandbox. - # TODO(b/120560048): Make the tests run without this tag. - tags.append("no-sandbox") - - # TODO(b/112165693): KVM tests are tagged "manual" to until the platform is - # more stable. - if platform == "kvm": - tags.append("manual") - tags.append("requires-kvm") - - # TODO(b/112165693): Remove when tests pass reliably. - tags.append("notap") - runner_args = [ # Arguments are passed directly to runner binary. "--platform=" + platform, @@ -149,6 +136,8 @@ def syscall_test( add_hostinet: add a hostinet test. tags: starting test tags. 
""" + if not tags: + tags = [] _syscall_test( test = test, @@ -160,35 +149,26 @@ def syscall_test( tags = tags, ) - _syscall_test( - test = test, - shard_count = shard_count, - size = size, - platform = "kvm", - use_tmpfs = use_tmpfs, - add_uds_tree = add_uds_tree, - tags = tags, - ) - - _syscall_test( - test = test, - shard_count = shard_count, - size = size, - platform = "ptrace", - use_tmpfs = use_tmpfs, - add_uds_tree = add_uds_tree, - tags = tags, - ) + for (platform, platform_tags) in platforms.items(): + _syscall_test( + test = test, + shard_count = shard_count, + size = size, + platform = platform, + use_tmpfs = use_tmpfs, + add_uds_tree = add_uds_tree, + tags = platform_tags + tags, + ) if add_overlay: _syscall_test( test = test, shard_count = shard_count, size = size, - platform = "ptrace", + platform = default_platform, use_tmpfs = False, # overlay is adding a writable tmpfs on top of root. add_uds_tree = add_uds_tree, - tags = tags, + tags = platforms[default_platform] + tags, overlay = True, ) @@ -198,10 +178,10 @@ def syscall_test( test = test, shard_count = shard_count, size = size, - platform = "ptrace", + platform = default_platform, use_tmpfs = use_tmpfs, add_uds_tree = add_uds_tree, - tags = tags, + tags = platforms[default_platform] + tags, file_access = "shared", ) @@ -210,9 +190,9 @@ def syscall_test( test = test, shard_count = shard_count, size = size, - platform = "ptrace", + platform = default_platform, use_tmpfs = use_tmpfs, network = "host", add_uds_tree = add_uds_tree, - tags = tags, + tags = platforms[default_platform] + tags, ) diff --git a/tools/bazeldefs/platforms.bzl b/tools/bazeldefs/platforms.bzl new file mode 100644 index 000000000..92b0b5fc0 --- /dev/null +++ b/tools/bazeldefs/platforms.bzl @@ -0,0 +1,17 @@ +"""List of platforms.""" + +# Platform to associated tags. +platforms = { + "ptrace": [ + # TODO(b/120560048): Make the tests run without this tag. + "no-sandbox", + ], + "kvm": [ + "manual", + "local", + # TODO(b/120560048): Make the tests run without this tag. + "no-sandbox", + ], +} + +default_platform = "ptrace" diff --git a/tools/defs.bzl b/tools/defs.bzl index 45c065459..15a310403 100644 --- a/tools/defs.bzl +++ b/tools/defs.bzl @@ -8,6 +8,7 @@ change for Google-internal and bazel-compatible rules. load("//tools/go_stateify:defs.bzl", "go_stateify") load("//tools/go_marshal:defs.bzl", "go_marshal", "marshal_deps", "marshal_test_deps") load("//tools/bazeldefs:defs.bzl", _cc_binary = "cc_binary", _cc_flags_supplier = "cc_flags_supplier", _cc_library = "cc_library", _cc_proto_library = "cc_proto_library", _cc_test = "cc_test", _cc_toolchain = "cc_toolchain", _container_image = "container_image", _default_installer = "default_installer", _default_net_util = "default_net_util", _gbenchmark = "gbenchmark", _go_binary = "go_binary", _go_embed_data = "go_embed_data", _go_image = "go_image", _go_library = "go_library", _go_proto_library = "go_proto_library", _go_test = "go_test", _go_tool_library = "go_tool_library", _gtest = "gtest", _loopback = "loopback", _pkg_deb = "pkg_deb", _pkg_tar = "pkg_tar", _proto_library = "proto_library", _py_binary = "py_binary", _py_library = "py_library", _py_requirement = "py_requirement", _py_test = "py_test", _select_arch = "select_arch", _select_system = "select_system") +load("//tools/bazeldefs:platforms.bzl", _default_platform = "default_platform", _platforms = "platforms") load("//tools/bazeldefs:tags.bzl", "go_suffixes") # Delegate directly. 
@@ -34,6 +35,8 @@ select_system = _select_system loopback = _loopback default_installer = _default_installer default_net_util = _default_net_util +platforms = _platforms +default_platform = _default_platform def go_binary(name, **kwargs): """Wraps the standard go_binary. -- cgit v1.2.3 From 8e2b14fecf204b35fe258816792bdc03a1ca0912 Mon Sep 17 00:00:00 2001 From: Adin Scannell Date: Thu, 27 Feb 2020 10:21:33 -0800 Subject: Use automated release notes, if available. PiperOrigin-RevId: 297628615 --- scripts/release.sh | 13 ++++++++++++- tools/tag_release.sh | 12 +++++++++--- 2 files changed, 21 insertions(+), 4 deletions(-) (limited to 'tools') diff --git a/scripts/release.sh b/scripts/release.sh index 091abf87f..e14ba04a7 100755 --- a/scripts/release.sh +++ b/scripts/release.sh @@ -25,6 +25,14 @@ if ! [[ -v KOKORO_RELEASE_TAG ]]; then echo "No KOKORO_RELEASE_TAG provided." >&2 exit 1 fi +if ! [[ -v KOKORO_RELNOTES ]]; then + echo "No KOKORO_RELNOTES provided." >&2 + exit 1 +fi +if ! [[ -r "${KOKORO_ARTIFACTS_DIR}/${KOKORO_RELNOTES}" ]]; then + echo "The file '${KOKORO_ARTIFACTS_DIR}/${KOKORO_RELNOTES}' is not readable." >&2 + exit 1 +fi # Unless an explicit releaser is provided, use the bot e-mail. declare -r KOKORO_RELEASE_AUTHOR=${KOKORO_RELEASE_AUTHOR:-gvisor-bot} @@ -46,4 +54,7 @@ EOF fi # Run the release tool, which pushes to the origin repository. -tools/tag_release.sh "${KOKORO_RELEASE_COMMIT}" "${KOKORO_RELEASE_TAG}" +tools/tag_release.sh \ + "${KOKORO_RELEASE_COMMIT}" \ + "${KOKORO_RELEASE_TAG}" \ + "${KOKORO_ARTIFACTS_DIR}/${KOKORO_RELNOTES}" diff --git a/tools/tag_release.sh b/tools/tag_release.sh index f33b902d6..4dbfe420a 100755 --- a/tools/tag_release.sh +++ b/tools/tag_release.sh @@ -21,13 +21,19 @@ set -xeu # Check arguments. -if [ "$#" -ne 2 ]; then - echo "usage: $0 " +if [ "$#" -ne 3 ]; then + echo "usage: $0 " exit 1 fi declare -r target_commit="$1" declare -r release="$2" +declare -r message_file="$3" + +if ! [[ -r "${message_file}" ]]; then + echo "error: message file '${message_file}' is not readable." + exit 1 +fi closest_commit() { while read line; do @@ -64,6 +70,6 @@ fi # Tag the given commit (annotated, to record the committer). declare -r tag="release-${release}" -(git tag -m "Release ${release}" -a "${tag}" "${commit}" && \ +(git tag -F "${message_file}" -a "${tag}" "${commit}" && \ git push origin tag "${tag}") || \ (git tag -d "${tag}" && false) -- cgit v1.2.3 From aa9f8abaef5c6250bdcee8fd88b2420f20791c5d Mon Sep 17 00:00:00 2001 From: Rahat Mahmood Date: Thu, 27 Feb 2020 14:51:29 -0800 Subject: Implement automated marshalling for newtypes on arrays. 
PiperOrigin-RevId: 297693838 --- tools/go_marshal/gomarshal/BUILD | 3 + tools/go_marshal/gomarshal/generator.go | 17 +- tools/go_marshal/gomarshal/generator_interfaces.go | 665 +-------------------- .../generator_interfaces_array_newtype.go | 183 ++++++ .../generator_interfaces_primitive_newtype.go | 229 +++++++ .../gomarshal/generator_interfaces_struct.go | 450 ++++++++++++++ tools/go_marshal/gomarshal/generator_tests.go | 2 +- tools/go_marshal/gomarshal/util.go | 41 +- tools/go_marshal/test/test.go | 5 + 9 files changed, 915 insertions(+), 680 deletions(-) create mode 100644 tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go create mode 100644 tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go create mode 100644 tools/go_marshal/gomarshal/generator_interfaces_struct.go (limited to 'tools') diff --git a/tools/go_marshal/gomarshal/BUILD b/tools/go_marshal/gomarshal/BUILD index b5d5a4487..44cb33ae4 100644 --- a/tools/go_marshal/gomarshal/BUILD +++ b/tools/go_marshal/gomarshal/BUILD @@ -7,6 +7,9 @@ go_library( srcs = [ "generator.go", "generator_interfaces.go", + "generator_interfaces_array_newtype.go", + "generator_interfaces_primitive_newtype.go", + "generator_interfaces_struct.go", "generator_tests.go", "util.go", ], diff --git a/tools/go_marshal/gomarshal/generator.go b/tools/go_marshal/gomarshal/generator.go index d365a1f3c..729489de5 100644 --- a/tools/go_marshal/gomarshal/generator.go +++ b/tools/go_marshal/gomarshal/generator.go @@ -235,6 +235,10 @@ func (g *Generator) collectMarshallableTypes(a *ast.File, f *token.FileSet) []*a debugfAt(f.Position(t.Pos()), "Collected marshallable newtype on primitive %s.\n", t.Name.Name) types = append(types, t) continue + case *ast.ArrayType: // Newtype on array. + debugfAt(f.Position(t.Pos()), "Collected marshallable newtype on array %s.\n", t.Name.Name) + types = append(types, t) + continue } // A user specifically requested marshalling on this type, but we // don't support it. @@ -281,17 +285,20 @@ func (g *Generator) generateOne(t *ast.TypeSpec, fset *token.FileSet) *interface i := newInterfaceGenerator(t, fset) switch ty := t.Type.(type) { case *ast.StructType: - i.validateStruct() - i.emitMarshallableForStruct() - return i + i.validateStruct(t, ty) + i.emitMarshallableForStruct(ty) case *ast.Ident: i.validatePrimitiveNewtype(ty) - i.emitMarshallableForPrimitiveNewtype() - return i + i.emitMarshallableForPrimitiveNewtype(ty) + case *ast.ArrayType: + i.validateArrayNewtype(t.Name, ty) + // After validate, we can safely call arrayLen. + i.emitMarshallableForArrayNewtype(t.Name, ty.Elt.(*ast.Ident), arrayLen(ty)) default: // This should've been filtered out by collectMarshallabeTypes. panic(fmt.Sprintf("Unexpected type %+v", ty)) } + return i } // generateOneTestSuite generates a test suite for the automatically generated diff --git a/tools/go_marshal/gomarshal/generator_interfaces.go b/tools/go_marshal/gomarshal/generator_interfaces.go index ea1af998e..8babf61d2 100644 --- a/tools/go_marshal/gomarshal/generator_interfaces.go +++ b/tools/go_marshal/gomarshal/generator_interfaces.go @@ -15,10 +15,8 @@ package gomarshal import ( - "fmt" "go/ast" "go/token" - "strings" ) // interfaceGenerator generates marshalling interfaces for a single type. @@ -81,18 +79,6 @@ func (g *interfaceGenerator) recordPotentiallyNonPackedField(fieldName string) { g.as[fieldName] = struct{}{} } -func (g *interfaceGenerator) forEachField(fn func(f *ast.Field)) { - // This is guaranteed to succeed because g.t is always a struct. 
- st := g.t.Type.(*ast.StructType) - for _, field := range st.Fields.List { - fn(field) - } -} - -func (g *interfaceGenerator) fieldAccessor(n *ast.Ident) string { - return fmt.Sprintf("%s.%s", g.r, n.Name) -} - // abortAt aborts the go_marshal tool with the given error message, with a // reference position to the input source. Same as abortAt, but uses g to // resolve p to position. @@ -100,71 +86,6 @@ func (g *interfaceGenerator) abortAt(p token.Pos, msg string) { abortAt(g.f.Position(p), msg) } -func (g *interfaceGenerator) validatePrimitiveNewtype(t *ast.Ident) { - switch t.Name { - case "int8", "uint8", "byte", "int16", "uint16", "int32", "uint32", "int64", "uint64": - // These are the only primitive types we're allow. Below, we provide - // suggestions for some disallowed types and reject them, then attempt - // to marshal any remaining types by invoking the marshal.Marshallable - // interface on them. If these types don't actually implement - // marshal.Marshallable, compilation of the generated code will fail - // with an appropriate error message. - return - case "int": - g.abortAt(t.Pos(), "Type 'int' has ambiguous width, use int32 or int64") - case "uint": - g.abortAt(t.Pos(), "Type 'uint' has ambiguous width, use uint32 or uint64") - case "string": - g.abortAt(t.Pos(), "Type 'string' is dynamically-sized and cannot be marshalled, use a fixed size byte array '[...]byte' instead") - default: - debugfAt(g.f.Position(t.Pos()), fmt.Sprintf("Found derived type '%s', will attempt dispatch via marshal.Marshallable.\n", t.Name)) - } -} - -// validateStruct ensures the type we're working with can be marshalled. These -// checks are done ahead of time and in one place so we can make assumptions -// later. -func (g *interfaceGenerator) validateStruct() { - g.forEachField(func(f *ast.Field) { - if len(f.Names) == 0 { - g.abortAt(f.Pos(), "Cannot marshal structs with embedded fields, give the field a name; use '_' for anonymous fields such as padding fields") - } - }) - - g.forEachField(func(f *ast.Field) { - fieldDispatcher{ - primitive: func(_, t *ast.Ident) { - g.validatePrimitiveNewtype(t) - }, - selector: func(_, _, _ *ast.Ident) { - // No validation to perform on selector fields. However this - // callback must still be provided. - }, - array: func(n, _ *ast.Ident, len int) { - a := f.Type.(*ast.ArrayType) - if a.Len == nil { - g.abortAt(f.Pos(), fmt.Sprintf("Dynamically sized slice '%s' cannot be marshalled, arrays must be statically sized", n.Name)) - } - - if _, ok := a.Len.(*ast.BasicLit); !ok { - g.abortAt(a.Len.Pos(), fmt.Sprintf("Array size must be a literal, don's use consts or expressions")) - } - - if _, ok := a.Elt.(*ast.Ident); !ok { - g.abortAt(a.Elt.Pos(), fmt.Sprintf("Marshalling not supported for arrays with %s elements, array elements must be primitive types", kindString(a.Elt))) - } - - if len <= 0 { - g.abortAt(a.Len.Pos(), fmt.Sprintf("Marshalling not supported for zero length arrays, why does an ABI struct have one?")) - } - }, - unhandled: func(_ *ast.Ident) { - g.abortAt(f.Pos(), fmt.Sprintf("Marshalling not supported for %s fields", kindString(f.Type))) - }, - }.dispatch(f) - }) -} - // scalarSize returns the size of type identified by t. If t isn't a primitive // type, the size isn't known at code generation time, and must be resolved via // the marshal.Marshallable interface. 
@@ -191,8 +112,8 @@ func (g *interfaceGenerator) shiftDynamic(bufVar, name string) { g.emit("%s = %s[%s.SizeBytes():]\n", bufVar, bufVar, name) } -// marshalStructFieldScalar writes a single scalar field from a struct to a byte slice. -func (g *interfaceGenerator) marshalStructFieldScalar(accessor, typ, bufVar string) { +// marshalScalar writes a single scalar to a byte slice. +func (g *interfaceGenerator) marshalScalar(accessor, typ, bufVar string) { switch typ { case "int8", "uint8", "byte": g.emit("%s[0] = byte(%s)\n", bufVar, accessor) @@ -215,9 +136,8 @@ func (g *interfaceGenerator) marshalStructFieldScalar(accessor, typ, bufVar stri } } -// unmarshalStructFieldScalar reads a single scalar field from a struct, from a -// byte slice. -func (g *interfaceGenerator) unmarshalStructFieldScalar(accessor, typ, bufVar string) { +// unmarshalScalar reads a single scalar from a byte slice. +func (g *interfaceGenerator) unmarshalScalar(accessor, typ, bufVar string) { switch typ { case "byte": g.emit("%s = %s[0]\n", accessor, bufVar) @@ -243,580 +163,3 @@ func (g *interfaceGenerator) unmarshalStructFieldScalar(accessor, typ, bufVar st g.recordPotentiallyNonPackedField(accessor) } } - -// marshalPrimitiveScalar writes a single primitive variable to a byte slice. -func (g *interfaceGenerator) marshalPrimitiveScalar(accessor, typ, bufVar string) { - switch typ { - case "int8", "uint8", "byte": - g.emit("%s[0] = byte(*%s)\n", bufVar, accessor) - case "int16", "uint16": - g.recordUsedImport("usermem") - g.emit("usermem.ByteOrder.PutUint16(%s[:2], uint16(*%s))\n", bufVar, accessor) - case "int32", "uint32": - g.recordUsedImport("usermem") - g.emit("usermem.ByteOrder.PutUint32(%s[:4], uint32(*%s))\n", bufVar, accessor) - case "int64", "uint64": - g.recordUsedImport("usermem") - g.emit("usermem.ByteOrder.PutUint64(%s[:8], uint64(*%s))\n", bufVar, accessor) - default: - g.emit("inner := (*%s)(%s)\n", typ, accessor) - g.emit("inner.MarshalBytes(%s[:%s.SizeBytes()])\n", bufVar, accessor) - } -} - -// unmarshalPrimitiveScalar read a single primitive variable from a byte slice. -func (g *interfaceGenerator) unmarshalPrimitiveScalar(accessor, typ, bufVar, typeCast string) { - switch typ { - case "byte": - g.emit("*%s = %s(%s[0])\n", accessor, typeCast, bufVar) - case "int8", "uint8": - g.emit("*%s = %s(%s(%s[0]))\n", accessor, typeCast, typ, bufVar) - case "int16", "uint16": - g.recordUsedImport("usermem") - g.emit("*%s = %s(%s(usermem.ByteOrder.Uint16(%s[:2])))\n", accessor, typeCast, typ, bufVar) - case "int32", "uint32": - g.recordUsedImport("usermem") - g.emit("*%s = %s(%s(usermem.ByteOrder.Uint32(%s[:4])))\n", accessor, typeCast, typ, bufVar) - - case "int64", "uint64": - g.recordUsedImport("usermem") - g.emit("*%s = %s(%s(usermem.ByteOrder.Uint64(%s[:8])))\n", accessor, typeCast, typ, bufVar) - default: - g.emit("inner := (*%s)(%s)\n", typ, accessor) - g.emit("inner.UnmarshalBytes(%s[:%s.SizeBytes()])\n", bufVar, accessor) - } -} - -// areFieldsPackedExpression returns a go expression checking whether g.t's fields are -// packed. Returns "", false if g.t has no fields that may be potentially -// packed, otherwise returns , true, where is an expression -// like "t.a.Packed() && t.b.Packed() && t.c.Packed()". 
-func (g *interfaceGenerator) areFieldsPackedExpression() (string, bool) { - if len(g.as) == 0 { - return "", false - } - - cs := make([]string, 0, len(g.as)) - for accessor, _ := range g.as { - cs = append(cs, fmt.Sprintf("%s.Packed()", accessor)) - } - return strings.Join(cs, " && "), true -} - -func (g *interfaceGenerator) emitMarshallableForStruct() { - // Is g.t a packed struct without consideing field types? - thisPacked := true - g.forEachField(func(f *ast.Field) { - if f.Tag != nil { - if f.Tag.Value == "`marshal:\"unaligned\"`" { - if thisPacked { - debugfAt(g.f.Position(g.t.Pos()), - fmt.Sprintf("Marking type '%s' as not packed due to tag `marshal:\"unaligned\"`.\n", g.t.Name)) - thisPacked = false - } - } - } - }) - - g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n") - g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName()) - g.inIndent(func() { - primitiveSize := 0 - var dynamicSizeTerms []string - - g.forEachField(fieldDispatcher{ - primitive: func(n, t *ast.Ident) { - if size, dynamic := g.scalarSize(t); !dynamic { - primitiveSize += size - } else { - g.recordUsedMarshallable(t.Name) - dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", t.Name)) - } - }, - selector: func(n, tX, tSel *ast.Ident) { - tName := fmt.Sprintf("%s.%s", tX.Name, tSel.Name) - g.recordUsedImport(tX.Name) - g.recordUsedMarshallable(tName) - dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", tName)) - }, - array: func(n, t *ast.Ident, len int) { - if len < 1 { - // Zero-length arrays should've been rejected by validate(). - panic("unreachable") - } - if size, dynamic := g.scalarSize(t); !dynamic { - primitiveSize += size * len - } else { - g.recordUsedMarshallable(t.Name) - dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()*%d", t.Name, len)) - } - }, - }.dispatch) - g.emit("return %d", primitiveSize) - if len(dynamicSizeTerms) > 0 { - g.incIndent() - } - { - for _, d := range dynamicSizeTerms { - g.emitNoIndent(" +\n") - g.emit(d) - } - } - if len(dynamicSizeTerms) > 0 { - g.decIndent() - } - }) - g.emit("\n}\n\n") - - g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n") - g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName()) - g.inIndent(func() { - g.forEachField(fieldDispatcher{ - primitive: func(n, t *ast.Ident) { - if n.Name == "_" { - g.emit("// Padding: dst[:sizeof(%s)] ~= %s(0)\n", t.Name, t.Name) - if len, dynamic := g.scalarSize(t); !dynamic { - g.shift("dst", len) - } else { - // We can't use shiftDynamic here because we don't have - // an instance of the dynamic type we can referece here - // (since the version in this struct is anonymous). Use - // a typed nil pointer to call SizeBytes() instead. - g.emit("dst = dst[(*%s)(nil).SizeBytes():]\n", t.Name) - } - return - } - g.marshalStructFieldScalar(g.fieldAccessor(n), t.Name, "dst") - }, - selector: func(n, tX, tSel *ast.Ident) { - g.marshalStructFieldScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "dst") - }, - array: func(n, t *ast.Ident, size int) { - if n.Name == "_" { - g.emit("// Padding: dst[:sizeof(%s)*%d] ~= [%d]%s{0}\n", t.Name, size, size, t.Name) - if len, dynamic := g.scalarSize(t); !dynamic { - g.shift("dst", len*size) - } else { - // We can't use shiftDynamic here because we don't have - // an instance of the dynamic type we can reference here - // (since the version in this struct is anonymous). Use - // a typed nil pointer to call SizeBytes() instead. 
- g.emit("dst = dst[(*%s)(nil).SizeBytes()*%d:]\n", t.Name, size) - } - return - } - - g.emit("for idx := 0; idx < %d; idx++ {\n", size) - g.inIndent(func() { - g.marshalStructFieldScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "dst") - }) - g.emit("}\n") - }, - }.dispatch) - }) - g.emit("}\n\n") - - g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n") - g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName()) - g.inIndent(func() { - g.forEachField(fieldDispatcher{ - primitive: func(n, t *ast.Ident) { - if n.Name == "_" { - g.emit("// Padding: var _ %s ~= src[:sizeof(%s)]\n", t.Name, t.Name) - if len, dynamic := g.scalarSize(t); !dynamic { - g.shift("src", len) - } else { - // We can't use shiftDynamic here because we don't have - // an instance of the dynamic type we can reference here - // (since the version in this struct is anonymous). Use - // a typed nil pointer to call SizeBytes() instead. - g.emit("src = src[(*%s)(nil).SizeBytes():]\n", t.Name) - g.recordPotentiallyNonPackedField(fmt.Sprintf("(*%s)(nil)", t.Name)) - } - return - } - g.unmarshalStructFieldScalar(g.fieldAccessor(n), t.Name, "src") - }, - selector: func(n, tX, tSel *ast.Ident) { - g.unmarshalStructFieldScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "src") - }, - array: func(n, t *ast.Ident, size int) { - if n.Name == "_" { - g.emit("// Padding: ~ copy([%d]%s(%s), src[:sizeof(%s)*%d])\n", size, t.Name, g.fieldAccessor(n), t.Name, size) - if len, dynamic := g.scalarSize(t); !dynamic { - g.shift("src", len*size) - } else { - // We can't use shiftDynamic here because we don't have - // an instance of the dynamic type we can referece here - // (since the version in this struct is anonymous). Use - // a typed nil pointer to call SizeBytes() instead. 
- g.emit("src = src[(*%s)(nil).SizeBytes()*%d:]\n", t.Name, size) - } - return - } - - g.emit("for idx := 0; idx < %d; idx++ {\n", size) - g.inIndent(func() { - g.unmarshalStructFieldScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "src") - }) - g.emit("}\n") - }, - }.dispatch) - }) - g.emit("}\n\n") - - g.emit("// Packed implements marshal.Marshallable.Packed.\n") - g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName()) - g.inIndent(func() { - expr, fieldsMaybePacked := g.areFieldsPackedExpression() - switch { - case !thisPacked: - g.emit("return false\n") - case fieldsMaybePacked: - g.emit("return %s\n", expr) - default: - g.emit("return true\n") - - } - }) - g.emit("}\n\n") - - g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n") - g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName()) - g.inIndent(func() { - if thisPacked { - g.recordUsedImport("safecopy") - g.recordUsedImport("unsafe") - if cond, ok := g.areFieldsPackedExpression(); ok { - g.emit("if %s {\n", cond) - g.inIndent(func() { - g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r) - }) - g.emit("} else {\n") - g.inIndent(func() { - g.emit("%s.MarshalBytes(dst)\n", g.r) - }) - g.emit("}\n") - } else { - g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r) - } - } else { - g.emit("// Type %s doesn't have a packed layout in memory, fallback to MarshalBytes.\n", g.typeName()) - g.emit("%s.MarshalBytes(dst)\n", g.r) - } - }) - g.emit("}\n\n") - - g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n") - g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName()) - g.inIndent(func() { - if thisPacked { - g.recordUsedImport("safecopy") - g.recordUsedImport("unsafe") - if cond, ok := g.areFieldsPackedExpression(); ok { - g.emit("if %s {\n", cond) - g.inIndent(func() { - g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r) - }) - g.emit("} else {\n") - g.inIndent(func() { - g.emit("%s.UnmarshalBytes(src)\n", g.r) - }) - g.emit("}\n") - } else { - g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r) - } - } else { - g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName()) - g.emit("%s.UnmarshalBytes(src)\n", g.r) - } - }) - g.emit("}\n\n") - - g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n") - g.recordUsedImport("marshal") - g.recordUsedImport("usermem") - g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName()) - g.inIndent(func() { - fallback := func() { - g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName()) - g.emit("buf := task.CopyScratchBuffer(%s.SizeBytes())\n", g.r) - g.emit("%s.MarshalBytes(buf)\n", g.r) - g.emit("_, err := task.CopyOutBytes(addr, buf)\n") - g.emit("return err\n") - } - if thisPacked { - g.recordUsedImport("reflect") - g.recordUsedImport("runtime") - g.recordUsedImport("unsafe") - if cond, ok := g.areFieldsPackedExpression(); ok { - g.emit("if !%s {\n", cond) - g.inIndent(fallback) - g.emit("}\n\n") - } - // Fast serialization. - g.emit("// Bypass escape analysis on %s. 
The no-op arithmetic operation on the\n", g.r) - g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) - g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") - g.emit("ptr := unsafe.Pointer(%s)\n", g.r) - g.emit("val := uintptr(ptr)\n") - g.emit("val = val^0\n\n") - - g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) - g.emit("var buf []byte\n") - g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") - g.emit("hdr.Data = val\n") - g.emit("hdr.Len = %s.SizeBytes()\n", g.r) - g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) - - g.emit("_, err := task.CopyOutBytes(addr, buf)\n") - g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) - g.emit("// must live until after the CopyOutBytes.\n") - g.emit("runtime.KeepAlive(%s)\n", g.r) - g.emit("return err\n") - } else { - fallback() - } - }) - g.emit("}\n\n") - - g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n") - g.recordUsedImport("marshal") - g.recordUsedImport("usermem") - g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName()) - g.inIndent(func() { - fallback := func() { - g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName()) - g.emit("buf := task.CopyScratchBuffer(%s.SizeBytes())\n", g.r) - g.emit("_, err := task.CopyInBytes(addr, buf)\n") - g.emit("if err != nil {\n") - g.inIndent(func() { - g.emit("return err\n") - }) - g.emit("}\n") - - g.emit("%s.UnmarshalBytes(buf)\n", g.r) - g.emit("return nil\n") - } - if thisPacked { - g.recordUsedImport("reflect") - g.recordUsedImport("runtime") - g.recordUsedImport("unsafe") - if cond, ok := g.areFieldsPackedExpression(); ok { - g.emit("if !%s {\n", cond) - g.inIndent(fallback) - g.emit("}\n\n") - } - // Fast deserialization. - g.emit("// Bypass escape analysis on %s. 
The no-op arithmetic operation on the\n", g.r) - g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) - g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") - g.emit("ptr := unsafe.Pointer(%s)\n", g.r) - g.emit("val := uintptr(ptr)\n") - g.emit("val = val^0\n\n") - - g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) - g.emit("var buf []byte\n") - g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") - g.emit("hdr.Data = val\n") - g.emit("hdr.Len = %s.SizeBytes()\n", g.r) - g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) - - g.emit("_, err := task.CopyInBytes(addr, buf)\n") - g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) - g.emit("// must live until after the CopyInBytes.\n") - g.emit("runtime.KeepAlive(%s)\n", g.r) - g.emit("return err\n") - } else { - fallback() - } - }) - g.emit("}\n\n") - - g.emit("// WriteTo implements io.WriterTo.WriteTo.\n") - g.recordUsedImport("io") - g.emit("func (%s *%s) WriteTo(w io.Writer) (int64, error) {\n", g.r, g.typeName()) - g.inIndent(func() { - fallback := func() { - g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName()) - g.emit("buf := make([]byte, %s.SizeBytes())\n", g.r) - g.emit("%s.MarshalBytes(buf)\n", g.r) - g.emit("n, err := w.Write(buf)\n") - g.emit("return int64(n), err\n") - } - if thisPacked { - g.recordUsedImport("reflect") - g.recordUsedImport("runtime") - g.recordUsedImport("unsafe") - if cond, ok := g.areFieldsPackedExpression(); ok { - g.emit("if !%s {\n", cond) - g.inIndent(fallback) - g.emit("}\n\n") - } - // Fast serialization. - g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) - g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) - g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") - g.emit("ptr := unsafe.Pointer(%s)\n", g.r) - g.emit("val := uintptr(ptr)\n") - g.emit("val = val^0\n\n") - - g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) - g.emit("var buf []byte\n") - g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") - g.emit("hdr.Data = val\n") - g.emit("hdr.Len = %s.SizeBytes()\n", g.r) - g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) - - g.emit("len, err := w.Write(buf)\n") - g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) - g.emit("// must live until after the Write.\n") - g.emit("runtime.KeepAlive(%s)\n", g.r) - g.emit("return int64(len), err\n") - } else { - fallback() - } - }) - g.emit("}\n\n") -} - -// emitMarshallableForPrimitiveNewtype outputs code to implement the -// marshal.Marshallable interface for a newtype on a primitive. Primitive -// newtypes are always packed, so we can omit the various fallbacks required for -// non-packed structs. 
-func (g *interfaceGenerator) emitMarshallableForPrimitiveNewtype() { - g.recordUsedImport("io") - g.recordUsedImport("marshal") - g.recordUsedImport("reflect") - g.recordUsedImport("runtime") - g.recordUsedImport("safecopy") - g.recordUsedImport("unsafe") - g.recordUsedImport("usermem") - - nt := g.t.Type.(*ast.Ident) - - g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n") - g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName()) - g.inIndent(func() { - if size, dynamic := g.scalarSize(nt); !dynamic { - g.emit("return %d\n", size) - } else { - g.emit("return (*%s)(nil).SizeBytes()\n", nt.Name) - } - }) - g.emit("}\n\n") - - g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n") - g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName()) - g.inIndent(func() { - g.marshalPrimitiveScalar(g.r, nt.Name, "dst") - }) - g.emit("}\n\n") - - g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n") - g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName()) - g.inIndent(func() { - g.unmarshalPrimitiveScalar(g.r, nt.Name, "src", g.typeName()) - }) - g.emit("}\n\n") - - g.emit("// Packed implements marshal.Marshallable.Packed.\n") - g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName()) - g.inIndent(func() { - g.emit("// Scalar newtypes are always packed.\n") - g.emit("return true\n") - }) - g.emit("}\n\n") - - g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n") - g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName()) - g.inIndent(func() { - g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r) - }) - g.emit("}\n\n") - - g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n") - g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName()) - g.inIndent(func() { - g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r) - }) - g.emit("}\n\n") - - g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n") - g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName()) - g.inIndent(func() { - // Fast serialization. - g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) - g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) - g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") - g.emit("ptr := unsafe.Pointer(%s)\n", g.r) - g.emit("val := uintptr(ptr)\n") - g.emit("val = val^0\n\n") - - g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) - g.emit("var buf []byte\n") - g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") - g.emit("hdr.Data = val\n") - g.emit("hdr.Len = %s.SizeBytes()\n", g.r) - g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) - - g.emit("_, err := task.CopyOutBytes(addr, buf)\n") - g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) - g.emit("// must live until after the CopyOutBytes.\n") - g.emit("runtime.KeepAlive(%s)\n", g.r) - g.emit("return err\n") - }) - g.emit("}\n\n") - - g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n") - g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName()) - g.inIndent(func() { - g.emit("// Bypass escape analysis on %s. 
The no-op arithmetic operation on the\n", g.r) - g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) - g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") - g.emit("ptr := unsafe.Pointer(%s)\n", g.r) - g.emit("val := uintptr(ptr)\n") - g.emit("val = val^0\n\n") - - g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) - g.emit("var buf []byte\n") - g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") - g.emit("hdr.Data = val\n") - g.emit("hdr.Len = %s.SizeBytes()\n", g.r) - g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) - - g.emit("_, err := task.CopyInBytes(addr, buf)\n") - g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) - g.emit("// must live until after the CopyInBytes.\n") - g.emit("runtime.KeepAlive(%s)\n", g.r) - g.emit("return err\n") - }) - g.emit("}\n\n") - - g.emit("// WriteTo implements io.WriterTo.WriteTo.\n") - g.emit("func (%s *%s) WriteTo(w io.Writer) (int64, error) {\n", g.r, g.typeName()) - g.inIndent(func() { - g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) - g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) - g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") - g.emit("ptr := unsafe.Pointer(%s)\n", g.r) - g.emit("val := uintptr(ptr)\n") - g.emit("val = val^0\n\n") - - g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) - g.emit("var buf []byte\n") - g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") - g.emit("hdr.Data = val\n") - g.emit("hdr.Len = %s.SizeBytes()\n", g.r) - g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) - - g.emit("len, err := w.Write(buf)\n") - g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) - g.emit("// must live until after the Write.\n") - g.emit("runtime.KeepAlive(%s)\n", g.r) - g.emit("return int64(len), err\n") - - }) - g.emit("}\n\n") - -} diff --git a/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go b/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go new file mode 100644 index 000000000..da36d9305 --- /dev/null +++ b/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go @@ -0,0 +1,183 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This file contains the bits of the code generator specific to marshalling +// newtypes on arrays. 
+ +package gomarshal + +import ( + "fmt" + "go/ast" +) + +func (g *interfaceGenerator) validateArrayNewtype(n *ast.Ident, a *ast.ArrayType) { + if a.Len == nil { + g.abortAt(a.Pos(), fmt.Sprintf("Dynamically sized slice '%s' cannot be marshalled, arrays must be statically sized", n.Name)) + } + + if _, ok := a.Len.(*ast.BasicLit); !ok { + g.abortAt(a.Len.Pos(), fmt.Sprintf("Array size must be a literal, don't use consts or expressions")) + } + + if _, ok := a.Elt.(*ast.Ident); !ok { + g.abortAt(a.Elt.Pos(), fmt.Sprintf("Marshalling not supported for arrays with %s elements, array elements must be primitive types", kindString(a.Elt))) + } + + if arrayLen(a) <= 0 { + g.abortAt(a.Len.Pos(), fmt.Sprintf("Marshalling not supported for zero length arrays, why does an ABI struct have one?")) + } +} + +func (g *interfaceGenerator) emitMarshallableForArrayNewtype(n, elt *ast.Ident, len int) { + g.recordUsedImport("io") + g.recordUsedImport("marshal") + g.recordUsedImport("reflect") + g.recordUsedImport("runtime") + g.recordUsedImport("safecopy") + g.recordUsedImport("unsafe") + g.recordUsedImport("usermem") + + g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n") + g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName()) + g.inIndent(func() { + if size, dynamic := g.scalarSize(elt); !dynamic { + g.emit("return %d\n", size*len) + } else { + g.emit("return (*%s)(nil).SizeBytes() * %d\n", n.Name, len) + } + }) + g.emit("}\n\n") + + g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n") + g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("for idx := 0; idx < %d; idx++ {\n", len) + g.inIndent(func() { + g.marshalScalar(fmt.Sprintf("%s[idx]", g.r), elt.Name, "dst") + }) + g.emit("}\n") + }) + g.emit("}\n\n") + + g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n") + g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("for idx := 0; idx < %d; idx++ {\n", len) + g.inIndent(func() { + g.unmarshalScalar(fmt.Sprintf("%s[idx]", g.r), elt.Name, "src") + }) + g.emit("}\n") + }) + g.emit("}\n\n") + + g.emit("// Packed implements marshal.Marshallable.Packed.\n") + g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("// Array newtypes are always packed.\n") + g.emit("return true\n") + }) + g.emit("}\n\n") + + g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n") + g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r) + }) + g.emit("}\n\n") + + g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n") + g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r) + }) + g.emit("}\n\n") + + g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n") + g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName()) + g.inIndent(func() { + // Fast serialization. + g.emit("// Bypass escape analysis on %s. 
The no-op arithmetic operation on the\n", g.r) + g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) + g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") + g.emit("ptr := unsafe.Pointer(%s)\n", g.r) + g.emit("val := uintptr(ptr)\n") + g.emit("val = val^0\n\n") + + g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) + g.emit("var buf []byte\n") + g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") + g.emit("hdr.Data = val\n") + g.emit("hdr.Len = %s.SizeBytes()\n", g.r) + g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) + + g.emit("_, err := task.CopyOutBytes(addr, buf)\n") + g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) + g.emit("// must live until after the CopyOutBytes.\n") + g.emit("runtime.KeepAlive(%s)\n", g.r) + g.emit("return err\n") + }) + g.emit("}\n\n") + + g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n") + g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) + g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) + g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") + g.emit("ptr := unsafe.Pointer(%s)\n", g.r) + g.emit("val := uintptr(ptr)\n") + g.emit("val = val^0\n\n") + + g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) + g.emit("var buf []byte\n") + g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") + g.emit("hdr.Data = val\n") + g.emit("hdr.Len = %s.SizeBytes()\n", g.r) + g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) + + g.emit("_, err := task.CopyInBytes(addr, buf)\n") + g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) + g.emit("// must live until after the CopyInBytes.\n") + g.emit("runtime.KeepAlive(%s)\n", g.r) + g.emit("return err\n") + }) + g.emit("}\n\n") + + g.emit("// WriteTo implements io.WriterTo.WriteTo.\n") + g.emit("func (%s *%s) WriteTo(w io.Writer) (int64, error) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) + g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) + g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") + g.emit("ptr := unsafe.Pointer(%s)\n", g.r) + g.emit("val := uintptr(ptr)\n") + g.emit("val = val^0\n\n") + + g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) + g.emit("var buf []byte\n") + g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") + g.emit("hdr.Data = val\n") + g.emit("hdr.Len = %s.SizeBytes()\n", g.r) + g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) + + g.emit("len, err := w.Write(buf)\n") + g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) + g.emit("// must live until after the Write.\n") + g.emit("runtime.KeepAlive(%s)\n", g.r) + g.emit("return int64(len), err\n") + + }) + g.emit("}\n\n") +} diff --git a/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go b/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go new file mode 100644 index 000000000..159397825 --- /dev/null +++ b/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go @@ -0,0 +1,229 @@ +// Copyright 2020 The gVisor Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This file contains the bits of the code generator specific to marshalling +// newtypes on primitives. + +package gomarshal + +import ( + "fmt" + "go/ast" +) + +// marshalPrimitiveScalar writes a single primitive variable to a byte +// slice. +func (g *interfaceGenerator) marshalPrimitiveScalar(accessor, typ, bufVar string) { + switch typ { + case "int8", "uint8", "byte": + g.emit("%s[0] = byte(*%s)\n", bufVar, accessor) + case "int16", "uint16": + g.recordUsedImport("usermem") + g.emit("usermem.ByteOrder.PutUint16(%s[:2], uint16(*%s))\n", bufVar, accessor) + case "int32", "uint32": + g.recordUsedImport("usermem") + g.emit("usermem.ByteOrder.PutUint32(%s[:4], uint32(*%s))\n", bufVar, accessor) + case "int64", "uint64": + g.recordUsedImport("usermem") + g.emit("usermem.ByteOrder.PutUint64(%s[:8], uint64(*%s))\n", bufVar, accessor) + default: + g.emit("// Explicilty cast to the underlying type before dispatching to\n") + g.emit("// MarshalBytes, so we don't recursively call %s.MarshalBytes\n", accessor) + g.emit("inner := (*%s)(%s)\n", typ, accessor) + g.emit("inner.MarshalBytes(%s[:%s.SizeBytes()])\n", bufVar, accessor) + } +} + +// unmarshalPrimitiveScalar read a single primitive variable from a byte slice. +func (g *interfaceGenerator) unmarshalPrimitiveScalar(accessor, typ, bufVar, typeCast string) { + switch typ { + case "byte": + g.emit("*%s = %s(%s[0])\n", accessor, typeCast, bufVar) + case "int8", "uint8": + g.emit("*%s = %s(%s(%s[0]))\n", accessor, typeCast, typ, bufVar) + case "int16", "uint16": + g.recordUsedImport("usermem") + g.emit("*%s = %s(%s(usermem.ByteOrder.Uint16(%s[:2])))\n", accessor, typeCast, typ, bufVar) + case "int32", "uint32": + g.recordUsedImport("usermem") + g.emit("*%s = %s(%s(usermem.ByteOrder.Uint32(%s[:4])))\n", accessor, typeCast, typ, bufVar) + case "int64", "uint64": + g.recordUsedImport("usermem") + g.emit("*%s = %s(%s(usermem.ByteOrder.Uint64(%s[:8])))\n", accessor, typeCast, typ, bufVar) + default: + g.emit("// Explicilty cast to the underlying type before dispatching to\n") + g.emit("// UnmarshalBytes, so we don't recursively call %s.UnmarshalBytes\n", accessor) + g.emit("inner := (*%s)(%s)\n", typ, accessor) + g.emit("inner.UnmarshalBytes(%s[:%s.SizeBytes()])\n", bufVar, accessor) + } +} + +func (g *interfaceGenerator) validatePrimitiveNewtype(t *ast.Ident) { + switch t.Name { + case "int8", "uint8", "byte", "int16", "uint16", "int32", "uint32", "int64", "uint64": + // These are the only primitive types we're allow. Below, we provide + // suggestions for some disallowed types and reject them, then attempt + // to marshal any remaining types by invoking the marshal.Marshallable + // interface on them. If these types don't actually implement + // marshal.Marshallable, compilation of the generated code will fail + // with an appropriate error message. 
+ return + case "int": + g.abortAt(t.Pos(), "Type 'int' has ambiguous width, use int32 or int64") + case "uint": + g.abortAt(t.Pos(), "Type 'uint' has ambiguous width, use uint32 or uint64") + case "string": + g.abortAt(t.Pos(), "Type 'string' is dynamically-sized and cannot be marshalled, use a fixed size byte array '[...]byte' instead") + default: + debugfAt(g.f.Position(t.Pos()), fmt.Sprintf("Found derived type '%s', will attempt dispatch via marshal.Marshallable.\n", t.Name)) + } +} + +// emitMarshallableForPrimitiveNewtype outputs code to implement the +// marshal.Marshallable interface for a newtype on a primitive. Primitive +// newtypes are always packed, so we can omit the various fallbacks required for +// non-packed structs. +func (g *interfaceGenerator) emitMarshallableForPrimitiveNewtype(nt *ast.Ident) { + g.recordUsedImport("io") + g.recordUsedImport("marshal") + g.recordUsedImport("reflect") + g.recordUsedImport("runtime") + g.recordUsedImport("safecopy") + g.recordUsedImport("unsafe") + g.recordUsedImport("usermem") + + g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n") + g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName()) + g.inIndent(func() { + if size, dynamic := g.scalarSize(nt); !dynamic { + g.emit("return %d\n", size) + } else { + g.emit("return (*%s)(nil).SizeBytes()\n", nt.Name) + } + }) + g.emit("}\n\n") + + g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n") + g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.marshalPrimitiveScalar(g.r, nt.Name, "dst") + }) + g.emit("}\n\n") + + g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n") + g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.unmarshalPrimitiveScalar(g.r, nt.Name, "src", g.typeName()) + }) + g.emit("}\n\n") + + g.emit("// Packed implements marshal.Marshallable.Packed.\n") + g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("// Scalar newtypes are always packed.\n") + g.emit("return true\n") + }) + g.emit("}\n\n") + + g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n") + g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r) + }) + g.emit("}\n\n") + + g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n") + g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r) + }) + g.emit("}\n\n") + + g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n") + g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName()) + g.inIndent(func() { + // Fast serialization. + g.emit("// Bypass escape analysis on %s. 
The no-op arithmetic operation on the\n", g.r) + g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) + g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") + g.emit("ptr := unsafe.Pointer(%s)\n", g.r) + g.emit("val := uintptr(ptr)\n") + g.emit("val = val^0\n\n") + + g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) + g.emit("var buf []byte\n") + g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") + g.emit("hdr.Data = val\n") + g.emit("hdr.Len = %s.SizeBytes()\n", g.r) + g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) + + g.emit("_, err := task.CopyOutBytes(addr, buf)\n") + g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) + g.emit("// must live until after the CopyOutBytes.\n") + g.emit("runtime.KeepAlive(%s)\n", g.r) + g.emit("return err\n") + }) + g.emit("}\n\n") + + g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n") + g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) + g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) + g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") + g.emit("ptr := unsafe.Pointer(%s)\n", g.r) + g.emit("val := uintptr(ptr)\n") + g.emit("val = val^0\n\n") + + g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) + g.emit("var buf []byte\n") + g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") + g.emit("hdr.Data = val\n") + g.emit("hdr.Len = %s.SizeBytes()\n", g.r) + g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) + + g.emit("_, err := task.CopyInBytes(addr, buf)\n") + g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) + g.emit("// must live until after the CopyInBytes.\n") + g.emit("runtime.KeepAlive(%s)\n", g.r) + g.emit("return err\n") + }) + g.emit("}\n\n") + + g.emit("// WriteTo implements io.WriterTo.WriteTo.\n") + g.emit("func (%s *%s) WriteTo(w io.Writer) (int64, error) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) + g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) + g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") + g.emit("ptr := unsafe.Pointer(%s)\n", g.r) + g.emit("val := uintptr(ptr)\n") + g.emit("val = val^0\n\n") + + g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) + g.emit("var buf []byte\n") + g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") + g.emit("hdr.Data = val\n") + g.emit("hdr.Len = %s.SizeBytes()\n", g.r) + g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) + + g.emit("len, err := w.Write(buf)\n") + g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) + g.emit("// must live until after the Write.\n") + g.emit("runtime.KeepAlive(%s)\n", g.r) + g.emit("return int64(len), err\n") + + }) + g.emit("}\n\n") +} diff --git a/tools/go_marshal/gomarshal/generator_interfaces_struct.go b/tools/go_marshal/gomarshal/generator_interfaces_struct.go new file mode 100644 index 000000000..e66a38b2e --- /dev/null +++ b/tools/go_marshal/gomarshal/generator_interfaces_struct.go @@ -0,0 +1,450 @@ +// Copyright 2020 The gVisor Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This file contains the bits of the code generator specific to marshalling +// structs. + +package gomarshal + +import ( + "fmt" + "go/ast" + "strings" +) + +func (g *interfaceGenerator) fieldAccessor(n *ast.Ident) string { + return fmt.Sprintf("%s.%s", g.r, n.Name) +} + +// areFieldsPackedExpression returns a go expression checking whether g.t's fields are +// packed. Returns "", false if g.t has no fields that may be potentially +// packed, otherwise returns , true, where is an expression +// like "t.a.Packed() && t.b.Packed() && t.c.Packed()". +func (g *interfaceGenerator) areFieldsPackedExpression() (string, bool) { + if len(g.as) == 0 { + return "", false + } + + cs := make([]string, 0, len(g.as)) + for accessor, _ := range g.as { + cs = append(cs, fmt.Sprintf("%s.Packed()", accessor)) + } + return strings.Join(cs, " && "), true +} + +// validateStruct ensures the type we're working with can be marshalled. These +// checks are done ahead of time and in one place so we can make assumptions +// later. +func (g *interfaceGenerator) validateStruct(ts *ast.TypeSpec, st *ast.StructType) { + forEachStructField(st, func(f *ast.Field) { + if len(f.Names) == 0 { + g.abortAt(f.Pos(), "Cannot marshal structs with embedded fields, give the field a name; use '_' for anonymous fields such as padding fields") + } + }) + + forEachStructField(st, func(f *ast.Field) { + fieldDispatcher{ + primitive: func(_, t *ast.Ident) { + g.validatePrimitiveNewtype(t) + }, + selector: func(_, _, _ *ast.Ident) { + // No validation to perform on selector fields. However this + // callback must still be provided. + }, + array: func(n, _ *ast.Ident, len int) { + g.validateArrayNewtype(n, f.Type.(*ast.ArrayType)) + }, + unhandled: func(_ *ast.Ident) { + g.abortAt(f.Pos(), fmt.Sprintf("Marshalling not supported for %s fields", kindString(f.Type))) + }, + }.dispatch(f) + }) +} + +func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) { + // Is g.t a packed struct without consideing field types? 
+ thisPacked := true + forEachStructField(st, func(f *ast.Field) { + if f.Tag != nil { + if f.Tag.Value == "`marshal:\"unaligned\"`" { + if thisPacked { + debugfAt(g.f.Position(g.t.Pos()), + fmt.Sprintf("Marking type '%s' as not packed due to tag `marshal:\"unaligned\"`.\n", g.t.Name)) + thisPacked = false + } + } + } + }) + + g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n") + g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName()) + g.inIndent(func() { + primitiveSize := 0 + var dynamicSizeTerms []string + + forEachStructField(st, fieldDispatcher{ + primitive: func(n, t *ast.Ident) { + if size, dynamic := g.scalarSize(t); !dynamic { + primitiveSize += size + } else { + g.recordUsedMarshallable(t.Name) + dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", t.Name)) + } + }, + selector: func(n, tX, tSel *ast.Ident) { + tName := fmt.Sprintf("%s.%s", tX.Name, tSel.Name) + g.recordUsedImport(tX.Name) + g.recordUsedMarshallable(tName) + dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", tName)) + }, + array: func(n, t *ast.Ident, len int) { + if len < 1 { + // Zero-length arrays should've been rejected by validate(). + panic("unreachable") + } + if size, dynamic := g.scalarSize(t); !dynamic { + primitiveSize += size * len + } else { + g.recordUsedMarshallable(t.Name) + dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()*%d", t.Name, len)) + } + }, + }.dispatch) + g.emit("return %d", primitiveSize) + if len(dynamicSizeTerms) > 0 { + g.incIndent() + } + { + for _, d := range dynamicSizeTerms { + g.emitNoIndent(" +\n") + g.emit(d) + } + } + if len(dynamicSizeTerms) > 0 { + g.decIndent() + } + }) + g.emit("\n}\n\n") + + g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n") + g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + forEachStructField(st, fieldDispatcher{ + primitive: func(n, t *ast.Ident) { + if n.Name == "_" { + g.emit("// Padding: dst[:sizeof(%s)] ~= %s(0)\n", t.Name, t.Name) + if len, dynamic := g.scalarSize(t); !dynamic { + g.shift("dst", len) + } else { + // We can't use shiftDynamic here because we don't have + // an instance of the dynamic type we can referece here + // (since the version in this struct is anonymous). Use + // a typed nil pointer to call SizeBytes() instead. + g.emit("dst = dst[(*%s)(nil).SizeBytes():]\n", t.Name) + } + return + } + g.marshalScalar(g.fieldAccessor(n), t.Name, "dst") + }, + selector: func(n, tX, tSel *ast.Ident) { + g.marshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "dst") + }, + array: func(n, t *ast.Ident, size int) { + if n.Name == "_" { + g.emit("// Padding: dst[:sizeof(%s)*%d] ~= [%d]%s{0}\n", t.Name, size, size, t.Name) + if len, dynamic := g.scalarSize(t); !dynamic { + g.shift("dst", len*size) + } else { + // We can't use shiftDynamic here because we don't have + // an instance of the dynamic type we can reference here + // (since the version in this struct is anonymous). Use + // a typed nil pointer to call SizeBytes() instead. 
+ g.emit("dst = dst[(*%s)(nil).SizeBytes()*%d:]\n", t.Name, size) + } + return + } + + g.emit("for idx := 0; idx < %d; idx++ {\n", size) + g.inIndent(func() { + g.marshalScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "dst") + }) + g.emit("}\n") + }, + }.dispatch) + }) + g.emit("}\n\n") + + g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n") + g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + forEachStructField(st, fieldDispatcher{ + primitive: func(n, t *ast.Ident) { + if n.Name == "_" { + g.emit("// Padding: var _ %s ~= src[:sizeof(%s)]\n", t.Name, t.Name) + if len, dynamic := g.scalarSize(t); !dynamic { + g.shift("src", len) + } else { + // We can't use shiftDynamic here because we don't have + // an instance of the dynamic type we can reference here + // (since the version in this struct is anonymous). Use + // a typed nil pointer to call SizeBytes() instead. + g.emit("src = src[(*%s)(nil).SizeBytes():]\n", t.Name) + g.recordPotentiallyNonPackedField(fmt.Sprintf("(*%s)(nil)", t.Name)) + } + return + } + g.unmarshalScalar(g.fieldAccessor(n), t.Name, "src") + }, + selector: func(n, tX, tSel *ast.Ident) { + g.unmarshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "src") + }, + array: func(n, t *ast.Ident, size int) { + if n.Name == "_" { + g.emit("// Padding: ~ copy([%d]%s(%s), src[:sizeof(%s)*%d])\n", size, t.Name, g.fieldAccessor(n), t.Name, size) + if len, dynamic := g.scalarSize(t); !dynamic { + g.shift("src", len*size) + } else { + // We can't use shiftDynamic here because we don't have + // an instance of the dynamic type we can referece here + // (since the version in this struct is anonymous). Use + // a typed nil pointer to call SizeBytes() instead. 
+ g.emit("src = src[(*%s)(nil).SizeBytes()*%d:]\n", t.Name, size) + } + return + } + + g.emit("for idx := 0; idx < %d; idx++ {\n", size) + g.inIndent(func() { + g.unmarshalScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "src") + }) + g.emit("}\n") + }, + }.dispatch) + }) + g.emit("}\n\n") + + g.emit("// Packed implements marshal.Marshallable.Packed.\n") + g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName()) + g.inIndent(func() { + expr, fieldsMaybePacked := g.areFieldsPackedExpression() + switch { + case !thisPacked: + g.emit("return false\n") + case fieldsMaybePacked: + g.emit("return %s\n", expr) + default: + g.emit("return true\n") + + } + }) + g.emit("}\n\n") + + g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n") + g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + if thisPacked { + g.recordUsedImport("safecopy") + g.recordUsedImport("unsafe") + if cond, ok := g.areFieldsPackedExpression(); ok { + g.emit("if %s {\n", cond) + g.inIndent(func() { + g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r) + }) + g.emit("} else {\n") + g.inIndent(func() { + g.emit("%s.MarshalBytes(dst)\n", g.r) + }) + g.emit("}\n") + } else { + g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r) + } + } else { + g.emit("// Type %s doesn't have a packed layout in memory, fallback to MarshalBytes.\n", g.typeName()) + g.emit("%s.MarshalBytes(dst)\n", g.r) + } + }) + g.emit("}\n\n") + + g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n") + g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + if thisPacked { + g.recordUsedImport("safecopy") + g.recordUsedImport("unsafe") + if cond, ok := g.areFieldsPackedExpression(); ok { + g.emit("if %s {\n", cond) + g.inIndent(func() { + g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r) + }) + g.emit("} else {\n") + g.inIndent(func() { + g.emit("%s.UnmarshalBytes(src)\n", g.r) + }) + g.emit("}\n") + } else { + g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r) + } + } else { + g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName()) + g.emit("%s.UnmarshalBytes(src)\n", g.r) + } + }) + g.emit("}\n\n") + + g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n") + g.recordUsedImport("marshal") + g.recordUsedImport("usermem") + g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName()) + g.inIndent(func() { + fallback := func() { + g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName()) + g.emit("buf := task.CopyScratchBuffer(%s.SizeBytes())\n", g.r) + g.emit("%s.MarshalBytes(buf)\n", g.r) + g.emit("_, err := task.CopyOutBytes(addr, buf)\n") + g.emit("return err\n") + } + if thisPacked { + g.recordUsedImport("reflect") + g.recordUsedImport("runtime") + g.recordUsedImport("unsafe") + if cond, ok := g.areFieldsPackedExpression(); ok { + g.emit("if !%s {\n", cond) + g.inIndent(fallback) + g.emit("}\n\n") + } + // Fast serialization. + g.emit("// Bypass escape analysis on %s. 
The no-op arithmetic operation on the\n", g.r) + g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) + g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") + g.emit("ptr := unsafe.Pointer(%s)\n", g.r) + g.emit("val := uintptr(ptr)\n") + g.emit("val = val^0\n\n") + + g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) + g.emit("var buf []byte\n") + g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") + g.emit("hdr.Data = val\n") + g.emit("hdr.Len = %s.SizeBytes()\n", g.r) + g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) + + g.emit("_, err := task.CopyOutBytes(addr, buf)\n") + g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) + g.emit("// must live until after the CopyOutBytes.\n") + g.emit("runtime.KeepAlive(%s)\n", g.r) + g.emit("return err\n") + } else { + fallback() + } + }) + g.emit("}\n\n") + + g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n") + g.recordUsedImport("marshal") + g.recordUsedImport("usermem") + g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName()) + g.inIndent(func() { + fallback := func() { + g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName()) + g.emit("buf := task.CopyScratchBuffer(%s.SizeBytes())\n", g.r) + g.emit("_, err := task.CopyInBytes(addr, buf)\n") + g.emit("if err != nil {\n") + g.inIndent(func() { + g.emit("return err\n") + }) + g.emit("}\n") + + g.emit("%s.UnmarshalBytes(buf)\n", g.r) + g.emit("return nil\n") + } + if thisPacked { + g.recordUsedImport("reflect") + g.recordUsedImport("runtime") + g.recordUsedImport("unsafe") + if cond, ok := g.areFieldsPackedExpression(); ok { + g.emit("if !%s {\n", cond) + g.inIndent(fallback) + g.emit("}\n\n") + } + // Fast deserialization. + g.emit("// Bypass escape analysis on %s. 
The no-op arithmetic operation on the\n", g.r) + g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) + g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") + g.emit("ptr := unsafe.Pointer(%s)\n", g.r) + g.emit("val := uintptr(ptr)\n") + g.emit("val = val^0\n\n") + + g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) + g.emit("var buf []byte\n") + g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") + g.emit("hdr.Data = val\n") + g.emit("hdr.Len = %s.SizeBytes()\n", g.r) + g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) + + g.emit("_, err := task.CopyInBytes(addr, buf)\n") + g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) + g.emit("// must live until after the CopyInBytes.\n") + g.emit("runtime.KeepAlive(%s)\n", g.r) + g.emit("return err\n") + } else { + fallback() + } + }) + g.emit("}\n\n") + + g.emit("// WriteTo implements io.WriterTo.WriteTo.\n") + g.recordUsedImport("io") + g.emit("func (%s *%s) WriteTo(w io.Writer) (int64, error) {\n", g.r, g.typeName()) + g.inIndent(func() { + fallback := func() { + g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName()) + g.emit("buf := make([]byte, %s.SizeBytes())\n", g.r) + g.emit("%s.MarshalBytes(buf)\n", g.r) + g.emit("n, err := w.Write(buf)\n") + g.emit("return int64(n), err\n") + } + if thisPacked { + g.recordUsedImport("reflect") + g.recordUsedImport("runtime") + g.recordUsedImport("unsafe") + if cond, ok := g.areFieldsPackedExpression(); ok { + g.emit("if !%s {\n", cond) + g.inIndent(fallback) + g.emit("}\n\n") + } + // Fast serialization. + g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) + g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) + g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") + g.emit("ptr := unsafe.Pointer(%s)\n", g.r) + g.emit("val := uintptr(ptr)\n") + g.emit("val = val^0\n\n") + + g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) + g.emit("var buf []byte\n") + g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") + g.emit("hdr.Data = val\n") + g.emit("hdr.Len = %s.SizeBytes()\n", g.r) + g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) + + g.emit("len, err := w.Write(buf)\n") + g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) + g.emit("// must live until after the Write.\n") + g.emit("runtime.KeepAlive(%s)\n", g.r) + g.emit("return int64(len), err\n") + } else { + fallback() + } + }) + g.emit("}\n\n") +} diff --git a/tools/go_marshal/gomarshal/generator_tests.go b/tools/go_marshal/gomarshal/generator_tests.go index 8ba47eb67..fd992e44a 100644 --- a/tools/go_marshal/gomarshal/generator_tests.go +++ b/tools/go_marshal/gomarshal/generator_tests.go @@ -164,7 +164,7 @@ func (g *testGenerator) emitTestSizeBytesOnTypedNilPtr() { g.emit("if sizeFromTypedNilPtr != sizeFromConcrete {\n") g.inIndent(func() { - g.emit("t.Fatalf(\"SizeBytes() on typed nil pointer (%v) doesn't match size returned by a concrete object (%v).\\n\", sizeFromTypedNilPtr, sizeFromConcrete)") + g.emit("t.Fatalf(\"SizeBytes() on typed nil pointer (%v) doesn't match size returned by a concrete object (%v).\\n\", sizeFromTypedNilPtr, sizeFromConcrete)\n") }) g.emit("}\n") }) diff --git a/tools/go_marshal/gomarshal/util.go b/tools/go_marshal/gomarshal/util.go index e2bca4e7c..a0936e013 100644 --- a/tools/go_marshal/gomarshal/util.go +++ 
b/tools/go_marshal/gomarshal/util.go @@ -64,6 +64,12 @@ func kindString(e ast.Expr) string { } } +func forEachStructField(st *ast.StructType, fn func(f *ast.Field)) { + for _, field := range st.Fields.List { + fn(field) + } +} + // fieldDispatcher is a collection of callbacks for handling different types of // fields in a struct declaration. type fieldDispatcher struct { @@ -73,6 +79,25 @@ type fieldDispatcher struct { unhandled func(n *ast.Ident) } +// Precondition: a must have a literal for the array length. Consts and +// expressions are not allowed as array lengths, and should be rejected by the +// caller. +func arrayLen(a *ast.ArrayType) int { + if a.Len == nil { + // Probably a slice? Must be handled by caller. + panic("Nil array length in array type") + } + lenLit, ok := a.Len.(*ast.BasicLit) + if !ok { + panic("Array has non-literal for length") + } + len, err := strconv.Atoi(lenLit.Value) + if err != nil { + panic(fmt.Sprintf("Failed to parse array length '%s' as number: %v", lenLit.Value, err)) + } + return len +} + // Precondition: All dispatch callbacks that will be invoked must be // provided. Embedded fields are not allowed, len(f.Names) >= 1. func (fd fieldDispatcher) dispatch(f *ast.Field) { @@ -96,22 +121,12 @@ func (fd fieldDispatcher) dispatch(f *ast.Field) { case *ast.SelectorExpr: fd.selector(name, v.X.(*ast.Ident), v.Sel) case *ast.ArrayType: - len := 0 - if v.Len != nil { - // Non-literal array length is handled by generatorInterfaces.validate(). - if lenLit, ok := v.Len.(*ast.BasicLit); ok { - var err error - len, err = strconv.Atoi(lenLit.Value) - if err != nil { - panic(err) - } - } - } switch t := v.Elt.(type) { case *ast.Ident: - fd.array(name, t, len) + fd.array(name, t, arrayLen(v)) default: - fd.array(name, nil, len) + // Should be handled with a better error message during validate. + panic(fmt.Sprintf("Array element type is of unsupported kind. Expected *ast.Ident, got %v", t)) } default: fd.unhandled(name) diff --git a/tools/go_marshal/test/test.go b/tools/go_marshal/test/test.go index 93229dedb..c829db6da 100644 --- a/tools/go_marshal/test/test.go +++ b/tools/go_marshal/test/test.go @@ -104,6 +104,11 @@ type Stat struct { _ [3]int64 } +// InetAddr is an example marshallable newtype on an array. +// +// +marshal +type InetAddr [4]byte + // SignalSet is an example marshallable newtype on a primitive. // // +marshal -- cgit v1.2.3 From c96bb4d2ebc6a24b3111d986c5d40574ec8ff660 Mon Sep 17 00:00:00 2001 From: Adin Scannell Date: Thu, 27 Feb 2020 15:35:19 -0800 Subject: Fix apt-get reliability issues. Contention over the apt lock frequently causes the core build scripts to fail: the core Ubuntu distribution performs an auto-update at first start, which may leave the lock file held. All apt-get commands are now run in a retry loop to avoid this issue. We may want to retry other pieces as well, but for now this should avoid the most frequent cause of build flakes.
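For clarity, the retry pattern applied to each apt-get invocation can be sketched as a small shell helper, roughly equivalent to the apt_install function added to scripts/common.sh below; the explicit else-branch capture of the exit status and the one-second pause between attempts are assumptions of this sketch rather than part of the change.

    # Retry apt-get while the package lists or dpkg lock are unavailable
    # (for example while the first-boot auto-update holds the lock).
    # apt-get exits with status 100 on such failures; any other exit
    # status is propagated to the caller.
    apt_install() {
      local result
      while true; do
        if (sudo apt-get update && sudo apt-get install -y "$@"); then
          return 0
        else
          result=$?  # Exit status of the failed apt-get invocation.
        fi
        if [[ ${result} -ne 100 ]]; then
          return ${result}
        fi
        sleep 1  # Assumed back-off; the patch itself retries immediately.
      done
    }

The standalone VM image scripts below inline the same loop and use exit rather than return, since they run outside any function.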
PiperOrigin-RevId: 297704789 --- scripts/build.sh | 2 +- scripts/common.sh | 14 ++++++++++++++ tools/images/ubuntu1604/10_core.sh | 15 ++++++++++++++- tools/images/ubuntu1604/20_bazel.sh | 12 +++++++++++- tools/images/ubuntu1604/25_docker.sh | 33 +++++++++++++++++++++++++------- tools/images/ubuntu1604/30_containerd.sh | 12 +++++++++++- tools/images/ubuntu1604/40_kokoro.sh | 17 +++++++++++++++- tools/installers/master.sh | 7 +++---- 8 files changed, 96 insertions(+), 16 deletions(-) (limited to 'tools') diff --git a/scripts/build.sh b/scripts/build.sh index 4c042af6c..7c9c99800 100755 --- a/scripts/build.sh +++ b/scripts/build.sh @@ -17,7 +17,7 @@ source $(dirname $0)/common.sh # Install required packages for make_repository.sh et al. -sudo apt-get update && sudo apt-get install -y dpkg-sig coreutils apt-utils xz-utils +apt_install dpkg-sig coreutils apt-utils xz-utils # Build runsc. runsc=$(build -c opt //runsc) diff --git a/scripts/common.sh b/scripts/common.sh index 3ca699e4a..735a383de 100755 --- a/scripts/common.sh +++ b/scripts/common.sh @@ -84,3 +84,17 @@ function install_runsc() { # Restart docker to pick up the new runtime configuration. sudo systemctl restart docker } + +# Installs the given packages. Note that the package names should be verified to +# be correct, otherwise this may result in a loop that spins until time out. +function apt_install() { + while true; do + if (sudo apt-get update && sudo apt-get install -y "$@"); then + break + fi + result=$? + if [[ $result -ne 100 ]]; then + return $result + fi + done +} diff --git a/tools/images/ubuntu1604/10_core.sh b/tools/images/ubuntu1604/10_core.sh index 46dda6bb1..cd518d6ac 100755 --- a/tools/images/ubuntu1604/10_core.sh +++ b/tools/images/ubuntu1604/10_core.sh @@ -17,7 +17,20 @@ set -xeo pipefail # Install all essential build tools. -apt-get update && apt-get -y install make git-core build-essential linux-headers-$(uname -r) pkg-config +while true; do + if (apt-get update && apt-get install -y \ + make \ + git-core \ + build-essential \ + linux-headers-$(uname -r) \ + pkg-config); then + break + fi + result=$? + if [[ $result -ne 100 ]]; then + exit $result + fi +done # Install a recent go toolchain. if ! [[ -d /usr/local/go ]]; then diff --git a/tools/images/ubuntu1604/20_bazel.sh b/tools/images/ubuntu1604/20_bazel.sh index b33e1656c..bb7afa676 100755 --- a/tools/images/ubuntu1604/20_bazel.sh +++ b/tools/images/ubuntu1604/20_bazel.sh @@ -19,7 +19,17 @@ set -xeo pipefail declare -r BAZEL_VERSION=2.0.0 # Install bazel dependencies. -apt-get update && apt-get install -y openjdk-8-jdk-headless unzip +while true; do + if (apt-get update && apt-get install -y \ + openjdk-8-jdk-headless \ + unzip); then + break + fi + result=$? + if [[ $result -ne 100 ]]; then + exit $result + fi +done # Use the release installer. curl -L -o bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh diff --git a/tools/images/ubuntu1604/25_docker.sh b/tools/images/ubuntu1604/25_docker.sh index 1d3defcd3..11eea2d72 100755 --- a/tools/images/ubuntu1604/25_docker.sh +++ b/tools/images/ubuntu1604/25_docker.sh @@ -15,12 +15,20 @@ # limitations under the License. # Add dependencies. 
-apt-get update && apt-get -y install \ - apt-transport-https \ - ca-certificates \ - curl \ - gnupg-agent \ - software-properties-common +while true; do + if (apt-get update && apt-get install -y \ + apt-transport-https \ + ca-certificates \ + curl \ + gnupg-agent \ + software-properties-common); then + break + fi + result=$? + if [[ $result -ne 100 ]]; then + exit $result + fi +done # Install the key. curl -fsSL https://download.docker.com/linux/ubuntu/gpg | apt-key add - @@ -32,4 +40,15 @@ add-apt-repository \ stable" # Install docker. -apt-get update && apt-get install -y docker-ce docker-ce-cli containerd.io +while true; do + if (apt-get update && apt-get install -y \ + docker-ce \ + docker-ce-cli \ + containerd.io); then + break + fi + result=$? + if [[ $result -ne 100 ]]; then + exit $result + fi +done diff --git a/tools/images/ubuntu1604/30_containerd.sh b/tools/images/ubuntu1604/30_containerd.sh index a7472bd1c..fb3699c12 100755 --- a/tools/images/ubuntu1604/30_containerd.sh +++ b/tools/images/ubuntu1604/30_containerd.sh @@ -34,7 +34,17 @@ install_helper() { } # Install dependencies for the crictl tests. -apt-get install -y btrfs-tools libseccomp-dev +while true; do + if (apt-get update && apt-get install -y \ + btrfs-tools \ + libseccomp-dev); then + break + fi + result=$? + if [[ $result -ne 100 ]]; then + exit $result + fi +done # Install containerd & cri-tools. GOPATH=$(mktemp -d --tmpdir gopathXXXXX) diff --git a/tools/images/ubuntu1604/40_kokoro.sh b/tools/images/ubuntu1604/40_kokoro.sh index 5f2dfc858..06a1e6c48 100755 --- a/tools/images/ubuntu1604/40_kokoro.sh +++ b/tools/images/ubuntu1604/40_kokoro.sh @@ -23,7 +23,22 @@ declare -r ssh_public_keys=( ) # Install dependencies. -apt-get update && apt-get install -y rsync coreutils python-psutil qemu-kvm python-pip python3-pip zip +while true; do + if (apt-get update && apt-get install -y \ + rsync \ + coreutils \ + python-psutil \ + qemu-kvm \ + python-pip \ + python3-pip \ + zip); then + break + fi + result=$? + if [[ $result -ne 100 ]]; then + exit $result + fi +done # junitparser is used to merge junit xml files. pip install junitparser diff --git a/tools/installers/master.sh b/tools/installers/master.sh index 52f9734a6..2c6001c6c 100755 --- a/tools/installers/master.sh +++ b/tools/installers/master.sh @@ -19,17 +19,16 @@ set -e curl -fsSL https://gvisor.dev/archive.key | sudo apt-key add - add-apt-repository "deb https://storage.googleapis.com/gvisor/releases release main" + while true; do - if apt-get update; then - apt-get install -y runsc + if (apt-get update && apt-get install -y runsc); then break fi result=$? - # Check if apt update failed to aquire the file lock. if [[ $result -ne 100 ]]; then exit $result fi done + runsc install service docker restart - -- cgit v1.2.3 From e5d9a4010bdbea10320348b022ee5b761c1eba07 Mon Sep 17 00:00:00 2001 From: Adin Scannell Date: Tue, 11 Feb 2020 16:01:42 -0800 Subject: Add ability to execute go.mod in gopath context. 
--- CONTRIBUTING.md | 3 +++ WORKSPACE | 41 ++++++++++++++++++++++++++++++----------- go.mod | 31 ++++++++++++++----------------- go.sum | 29 +++++++++++++++++++---------- tools/go_mod.sh | 29 +++++++++++++++++++++++++++++ 5 files changed, 95 insertions(+), 38 deletions(-) create mode 100755 tools/go_mod.sh (limited to 'tools') diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 71650a4b8..ad8e710da 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -32,6 +32,9 @@ will need to be added to the appropriate `BUILD` files, and the `:gopath` target will need to be re-run to generate appropriate symlinks in the `GOPATH` directory tree. +Dependencies can be added by using `go mod get`. In order to keep the +`WORKSPACE` file in sync, run `tools/go_mod.sh` in place of `go mod`. + ### Coding Guidelines All Go code should conform to the [Go style guidelines][gostyle]. C++ code diff --git a/WORKSPACE b/WORKSPACE index a15238a2e..995d2c7f1 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -20,7 +20,7 @@ http_archive( ], ) -load("@io_bazel_rules_go//go:deps.bzl", "go_rules_dependencies", "go_register_toolchains") +load("@io_bazel_rules_go//go:deps.bzl", "go_register_toolchains", "go_rules_dependencies") go_rules_dependencies() @@ -43,8 +43,8 @@ gazelle_dependencies() go_repository( name = "org_golang_x_sys", importpath = "golang.org/x/sys", - sum = "h1:72l8qCJ1nGxMGH26QVBVIxKd/D34cfGt0OvrPtpemyY=", - version = "v0.0.0-20191220220014-0732a990476f", + sum = "h1:1BGLXjeY4akVXGgbC9HugT3Jv3hCI0z56oJR5vAMgBU=", + version = "v0.0.0-20190215142949-d0b11bdaac8a", ) # Load C++ rules. @@ -68,8 +68,11 @@ http_archive( "https://github.com/bazelbuild/rules_proto/archive/97d8af4dc474595af3900dd85cb3a29ad28cc313.tar.gz", ], ) + load("@rules_proto//proto:repositories.bzl", "rules_proto_dependencies", "rules_proto_toolchains") + rules_proto_dependencies() + rules_proto_toolchains() # Load python dependencies. @@ -146,9 +149,9 @@ load( # This container is built from the Dockerfile in test/iptables/runner. 
container_pull( name = "iptables-test", + digest = "sha256:a137d692a2eb9fc7bf95c5f4a568da090e2c31098e93634421ed88f3a3f1db65", registry = "gcr.io", repository = "gvisor-presubmit/iptables-test", - digest = "sha256:a137d692a2eb9fc7bf95c5f4a568da090e2c31098e93634421ed88f3a3f1db65", ) load( @@ -201,6 +204,13 @@ go_repository( version = "v0.0.0-20171129191014-dec09d789f3d", ) +go_repository( + name = "com_github_kr_pretty", + importpath = "github.com/kr/pretty", + sum = "h1:s5hAObm+yFO5uHYt5dYjxi2rXrsnmRpJx4OYvIWUaQs=", + version = "v0.2.0", +) + go_repository( name = "com_github_kr_pty", importpath = "github.com/kr/pty", @@ -208,6 +218,13 @@ go_repository( version = "v1.1.1", ) +go_repository( + name = "com_github_kr_text", + importpath = "github.com/kr/text", + sum = "h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=", + version = "v0.1.0", +) + go_repository( name = "com_github_opencontainers_runtime-spec", importpath = "github.com/opencontainers/runtime-spec", @@ -236,6 +253,13 @@ go_repository( version = "v0.0.0-20171111001504-be1fbeda1936", ) +go_repository( + name = "in_gopkg_check_v1", + importpath = "gopkg.in/check.v1", + sum = "h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=", + version = "v1.0.0-20190902080502-41f04d3bba15", +) + go_repository( name = "org_golang_x_crypto", importpath = "golang.org/x/crypto", @@ -257,12 +281,6 @@ go_repository( version = "v0.3.0", ) -go_repository( - name = "org_golang_x_tools", - commit = "36563e24a262", - importpath = "golang.org/x/tools", -) - go_repository( name = "org_golang_x_sync", importpath = "golang.org/x/sync", @@ -272,8 +290,9 @@ go_repository( go_repository( name = "org_golang_x_time", - commit = "c4c64cad1fd0a1a8dab2523e04e61d35308e131e", importpath = "golang.org/x/time", + sum = "h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs=", + version = "v0.0.0-20191024005414-555d28b269f0", ) go_repository( diff --git a/go.mod b/go.mod index c4687ed02..3a8b9288d 100644 --- a/go.mod +++ b/go.mod @@ -3,21 +3,18 @@ module gvisor.dev/gvisor go 1.13 require ( - github.com/cenkalti/backoff v0.0.0-20190506075156-2146c9339422 - github.com/gofrs/flock v0.6.1-0.20180915234121-886344bea079 - github.com/golang/mock v1.3.1 - github.com/golang/protobuf v1.3.1 - github.com/google/btree v1.0.0 - github.com/google/go-cmp v0.2.0 - github.com/google/go-github/v28 v28.1.1 - github.com/google/subcommands v0.0.0-20190508160503-636abe8753b8 - github.com/google/uuid v0.0.0-20171129191014-dec09d789f3d - github.com/kr/pty v1.1.1 - github.com/opencontainers/runtime-spec v0.1.2-0.20171211145439-b2d941ef6a78 - github.com/syndtr/gocapability v0.0.0-20180916011248-d98352740cb2 - github.com/vishvananda/netlink v1.0.1-0.20190318003149-adb577d4a45e - github.com/vishvananda/netns v0.0.0-20171111001504-be1fbeda1936 - golang.org/x/net v0.0.0-20190311183353-d8887717615a - golang.org/x/oauth2 v0.0.0-20191202225959-858c2ad4c8b6 - golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a + github.com/cenkalti/backoff v0.0.0-20190506075156-2146c9339422 + github.com/gofrs/flock v0.6.1-0.20180915234121-886344bea079 + github.com/golang/protobuf v1.3.1 + github.com/google/btree v1.0.0 + github.com/google/subcommands v0.0.0-20190508160503-636abe8753b8 + github.com/kr/pretty v0.2.0 // indirect + github.com/kr/pty v1.1.1 + github.com/opencontainers/runtime-spec v0.1.2-0.20171211145439-b2d941ef6a78 + github.com/syndtr/gocapability v0.0.0-20180916011248-d98352740cb2 + github.com/vishvananda/netlink v1.0.1-0.20190318003149-adb577d4a45e + github.com/vishvananda/netns v0.0.0-20171111001504-be1fbeda1936 // 
indirect + golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a + golang.org/x/time v0.0.0-20191024005414-555d28b269f0 + gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect ) diff --git a/go.sum b/go.sum index 434770beb..f16a549fd 100644 --- a/go.sum +++ b/go.sum @@ -1,21 +1,30 @@ +github.com/cenkalti/backoff v0.0.0-20190506075156-2146c9339422 h1:+FKjzBIdfBHYDvxCv+djmDJdes/AoDtg8gpcxowBlF8= github.com/cenkalti/backoff v0.0.0-20190506075156-2146c9339422/go.mod h1:b6Nc7NRH5C4aCISLry0tLnTjcuTEvoiqcWDdsU0sOGM= +github.com/gofrs/flock v0.6.1-0.20180915234121-886344bea079 h1:JFTFz3HZTGmgMz4E1TabNBNJljROSYgja1b4l50FNVs= github.com/gofrs/flock v0.6.1-0.20180915234121-886344bea079/go.mod h1:F1TvTiK9OcQqauNUHlbJvyl9Qa1QvF/gOUDKA14jxHU= -github.com/golang/mock v1.3.1/go.mod h1:sBzyDLLjw3U8JLTeZvSv8jJB+tU5PVekmnlKIyFUx0Y= +github.com/golang/protobuf v1.3.1 h1:YF8+flBXS5eO826T4nzqPrxfhQThhXl0YzfuUPu4SBg= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/google/btree v1.0.0 h1:0udJVsspx3VBr5FwtLhQQtuAsVc79tTq0ocGIPAU6qo= github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= -github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= -github.com/google/go-github/v28 v28.1.1/go.mod h1:bsqJWQX05omyWVmc00nEUql9mhQyv38lDZ8kPZcQVoM= +github.com/google/subcommands v0.0.0-20190508160503-636abe8753b8 h1:GZGUPQiZfYrd9uOqyqwbQcHPkz/EZJVkZB1MkaO9UBI= github.com/google/subcommands v0.0.0-20190508160503-636abe8753b8/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk= -github.com/google/uuid v0.0.0-20171129191014-dec09d789f3d/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/kr/pretty v0.2.0 h1:s5hAObm+yFO5uHYt5dYjxi2rXrsnmRpJx4OYvIWUaQs= +github.com/kr/pretty v0.2.0/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/pty v1.1.1 h1:VkoXIwSboBpnk99O/KFauAEILuNHv5DVFKZMBN/gUgw= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/opencontainers/runtime-spec v0.1.2-0.20171211145439-b2d941ef6a78 h1:d9F+LNYwMyi3BDN4GzZdaSiq4otb8duVEWyZjeUtOQI= github.com/opencontainers/runtime-spec v0.1.2-0.20171211145439-b2d941ef6a78/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/C5hL+LL3ko9hs6T5lQ0= +github.com/syndtr/gocapability v0.0.0-20180916011248-d98352740cb2 h1:b6uOv7YOFK0TYG7HtkIgExQo+2RdLuwRft63jn2HWj8= github.com/syndtr/gocapability v0.0.0-20180916011248-d98352740cb2/go.mod h1:hkRG7XYTFWNJGYcbNJQlaLq0fg1yr4J4t/NcTQtrfww= +github.com/vishvananda/netlink v1.0.1-0.20190318003149-adb577d4a45e h1:/Tdc23Arz1OtdIsBY2utWepGRQ9fEAJlhkdoLzWMK8Q= github.com/vishvananda/netlink v1.0.1-0.20190318003149-adb577d4a45e/go.mod h1:+SR5DhBJrl6ZM7CoCKvpw5BKroDKQ+PJqOg65H/2ktk= +github.com/vishvananda/netns v0.0.0-20171111001504-be1fbeda1936 h1:J9gO8RJCAFlln1jsvRba/CWVUnMHwObklfxxjErl1uk= github.com/vishvananda/netns v0.0.0-20171111001504-be1fbeda1936/go.mod h1:ZjcWmFBXmLKZu9Nxj3WKYEafiSqer2rnvPr0en9UNpI= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/oauth2 v0.0.0-20191202225959-858c2ad4c8b6/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= -golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod 
h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a h1:1BGLXjeY4akVXGgbC9HugT3Jv3hCI0z56oJR5vAMgBU= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs= +golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/tools/go_mod.sh b/tools/go_mod.sh new file mode 100755 index 000000000..84b779d6d --- /dev/null +++ b/tools/go_mod.sh @@ -0,0 +1,29 @@ +#!/bin/bash + +# Copyright 2020 The gVisor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set -eo pipefail + +# Build the :gopath target. +bazel build //:gopath +declare -r gopathdir="bazel-bin/gopath/src/gvisor.dev/gvisor/" + +# Copy go.mod and execute the command. +cp -a go.mod go.sum "${gopathdir}" +(cd "${gopathdir}" && go mod "$@") +cp -a "${gopathdir}/go.mod" "${gopathdir}/go.sum" . + +# Cleanup the WORKSPACE file. +bazel run //:gazelle -- update-repos -from_file=go.mod -- cgit v1.2.3
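As a usage sketch (the specific invocation is illustrative, not part of the patch), a typical run forwards a go mod subcommand through the new helper:

    # Tidy module metadata while keeping Bazel in sync: the script builds
    # //:gopath, copies go.mod/go.sum into the generated tree, runs
    # `go mod tidy` there, copies the results back, and finally runs
    # gazelle update-repos to refresh WORKSPACE.
    ./tools/go_mod.sh tidy

Any other `go mod` subcommand can be passed the same way, since the script simply forwards its arguments.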