diff options
Diffstat (limited to 'tools')
40 files changed, 1700 insertions, 624 deletions
diff --git a/tools/build/BUILD b/tools/bazeldefs/BUILD index 0c0ce3f4d..00a467473 100644 --- a/tools/build/BUILD +++ b/tools/bazeldefs/BUILD @@ -6,5 +6,5 @@ genrule( name = "loopback", outs = ["loopback.txt"], cmd = "touch $@", - visibility = ["//visibility:public"], + visibility = ["//:sandbox"], ) diff --git a/tools/build/defs.bzl b/tools/bazeldefs/defs.bzl index d0556abd1..905b16d41 100644 --- a/tools/build/defs.bzl +++ b/tools/bazeldefs/defs.bzl @@ -18,7 +18,9 @@ cc_test = _cc_test cc_toolchain = "@bazel_tools//tools/cpp:current_cc_toolchain" go_image = _go_image go_embed_data = _go_embed_data -loopback = "//tools/build:loopback" +gtest = "@com_google_googletest//:gtest" +gbenchmark = "@com_google_benchmark//:benchmark" +loopback = "//tools/bazeldefs:loopback" proto_library = native.proto_library pkg_deb = _pkg_deb pkg_tar = _pkg_tar @@ -69,7 +71,7 @@ def go_test(name, **kwargs): **kwargs ) -def py_requirement(name, direct = False): +def py_requirement(name, direct = True): return _py_requirement(name) def select_arch(amd64 = "amd64", arm64 = "arm64", default = None, **kwargs): diff --git a/tools/bazeldefs/platforms.bzl b/tools/bazeldefs/platforms.bzl new file mode 100644 index 000000000..92b0b5fc0 --- /dev/null +++ b/tools/bazeldefs/platforms.bzl @@ -0,0 +1,17 @@ +"""List of platforms.""" + +# Platform to associated tags. +platforms = { + "ptrace": [ + # TODO(b/120560048): Make the tests run without this tag. + "no-sandbox", + ], + "kvm": [ + "manual", + "local", + # TODO(b/120560048): Make the tests run without this tag. + "no-sandbox", + ], +} + +default_platform = "ptrace" diff --git a/tools/bazeldefs/tags.bzl b/tools/bazeldefs/tags.bzl new file mode 100644 index 000000000..558fb53ae --- /dev/null +++ b/tools/bazeldefs/tags.bzl @@ -0,0 +1,40 @@ +"""List of special Go suffixes.""" + +go_suffixes = [ + "_386", + "_386_unsafe", + "_aarch64", + "_aarch64_unsafe", + "_amd64", + "_amd64_unsafe", + "_arm", + "_arm64", + "_arm64_unsafe", + "_arm_unsafe", + "_impl", + "_impl_unsafe", + "_linux", + "_linux_unsafe", + "_mips", + "_mips64", + "_mips64_unsafe", + "_mips64le", + "_mips64le_unsafe", + "_mips_unsafe", + "_mipsle", + "_mipsle_unsafe", + "_opts", + "_opts_unsafe", + "_ppc64", + "_ppc64_unsafe", + "_ppc64le", + "_ppc64le_unsafe", + "_riscv64", + "_riscv64_unsafe", + "_s390x", + "_s390x_unsafe", + "_sparc64", + "_sparc64_unsafe", + "_wasm", + "_wasm_unsafe", +] diff --git a/tools/checkunsafe/BUILD b/tools/checkunsafe/BUILD index 92ba8ab06..4f1a31a6d 100644 --- a/tools/checkunsafe/BUILD +++ b/tools/checkunsafe/BUILD @@ -5,7 +5,7 @@ package(licenses = ["notice"]) go_tool_library( name = "checkunsafe", srcs = ["check_unsafe.go"], - visibility = ["//visibility:public"], + visibility = ["//:sandbox"], deps = [ "@org_golang_x_tools//go/analysis:go_tool_library", ], diff --git a/tools/defs.bzl b/tools/defs.bzl index 819f12b0d..15a310403 100644 --- a/tools/defs.bzl +++ b/tools/defs.bzl @@ -7,7 +7,9 @@ change for Google-internal and bazel-compatible rules. load("//tools/go_stateify:defs.bzl", "go_stateify") load("//tools/go_marshal:defs.bzl", "go_marshal", "marshal_deps", "marshal_test_deps") -load("//tools/build:defs.bzl", _cc_binary = "cc_binary", _cc_flags_supplier = "cc_flags_supplier", _cc_library = "cc_library", _cc_proto_library = "cc_proto_library", _cc_test = "cc_test", _cc_toolchain = "cc_toolchain", _container_image = "container_image", _default_installer = "default_installer", _default_net_util = "default_net_util", _go_binary = "go_binary", _go_embed_data = "go_embed_data", _go_image = "go_image", _go_library = "go_library", _go_proto_library = "go_proto_library", _go_test = "go_test", _go_tool_library = "go_tool_library", _loopback = "loopback", _pkg_deb = "pkg_deb", _pkg_tar = "pkg_tar", _proto_library = "proto_library", _py_binary = "py_binary", _py_library = "py_library", _py_requirement = "py_requirement", _py_test = "py_test", _select_arch = "select_arch", _select_system = "select_system") +load("//tools/bazeldefs:defs.bzl", _cc_binary = "cc_binary", _cc_flags_supplier = "cc_flags_supplier", _cc_library = "cc_library", _cc_proto_library = "cc_proto_library", _cc_test = "cc_test", _cc_toolchain = "cc_toolchain", _container_image = "container_image", _default_installer = "default_installer", _default_net_util = "default_net_util", _gbenchmark = "gbenchmark", _go_binary = "go_binary", _go_embed_data = "go_embed_data", _go_image = "go_image", _go_library = "go_library", _go_proto_library = "go_proto_library", _go_test = "go_test", _go_tool_library = "go_tool_library", _gtest = "gtest", _loopback = "loopback", _pkg_deb = "pkg_deb", _pkg_tar = "pkg_tar", _proto_library = "proto_library", _py_binary = "py_binary", _py_library = "py_library", _py_requirement = "py_requirement", _py_test = "py_test", _select_arch = "select_arch", _select_system = "select_system") +load("//tools/bazeldefs:platforms.bzl", _default_platform = "default_platform", _platforms = "platforms") +load("//tools/bazeldefs:tags.bzl", "go_suffixes") # Delegate directly. cc_binary = _cc_binary @@ -20,6 +22,8 @@ go_embed_data = _go_embed_data go_image = _go_image go_test = _go_test go_tool_library = _go_tool_library +gtest = _gtest +gbenchmark = _gbenchmark pkg_deb = _pkg_deb pkg_tar = _pkg_tar py_library = _py_library @@ -31,6 +35,8 @@ select_system = _select_system loopback = _loopback default_installer = _default_installer default_net_util = _default_net_util +platforms = _platforms +default_platform = _default_platform def go_binary(name, **kwargs): """Wraps the standard go_binary. @@ -44,7 +50,45 @@ def go_binary(name, **kwargs): **kwargs ) -def go_library(name, srcs, deps = [], imports = [], stateify = True, marshal = False, **kwargs): +def calculate_sets(srcs): + """Calculates special Go sets for templates. + + Args: + srcs: the full set of Go sources. + + Returns: + A dictionary of the form: + + "": [src1.go, src2.go] + "suffix": [src3suffix.go, src4suffix.go] + + Note that suffix will typically start with '_'. + """ + result = dict() + for file in srcs: + if not file.endswith(".go"): + continue + target = "" + for suffix in go_suffixes: + if file.endswith(suffix + ".go"): + target = suffix + if not target in result: + result[target] = [file] + else: + result[target].append(file) + return result + +def go_imports(name, src, out): + """Simplify a single Go source file by eliminating unused imports.""" + native.genrule( + name = name, + srcs = [src], + outs = [out], + tools = ["@org_golang_x_tools//cmd/goimports:goimports"], + cmd = ("$(location @org_golang_x_tools//cmd/goimports:goimports) $(SRCS) > $@"), + ) + +def go_library(name, srcs, deps = [], imports = [], stateify = True, marshal = False, marshal_debug = False, **kwargs): """Wraps the standard go_library and does stateification and marshalling. The recommended way is to use this rule with mostly identical configuration as the native @@ -67,41 +111,62 @@ def go_library(name, srcs, deps = [], imports = [], stateify = True, marshal = F imports: imports required for stateify. stateify: whether statify is enabled (default: true). marshal: whether marshal is enabled (default: false). + marshal_debug: whether the gomarshal tools emits debugging output (default: false). **kwargs: standard go_library arguments. """ + all_srcs = srcs + all_deps = deps + dirname, _, _ = native.package_name().rpartition("/") + full_pkg = dirname + "/" + name if stateify: # Only do stateification for non-state packages without manual autogen. - go_stateify( - name = name + "_state_autogen", - srcs = [src for src in srcs if src.endswith(".go")], - imports = imports, - package = name, - arch = select_arch(), - out = name + "_state_autogen.go", - ) - all_srcs = srcs + [name + "_state_autogen.go"] - if "//pkg/state" not in deps: - all_deps = deps + ["//pkg/state"] - else: - all_deps = deps - else: - all_deps = deps - all_srcs = srcs + # First, we need to segregate the input files via the special suffixes, + # and calculate the final output set. + state_sets = calculate_sets(srcs) + for (suffix, src_subset) in state_sets.items(): + go_stateify( + name = name + suffix + "_state_autogen_with_imports", + srcs = src_subset, + imports = imports, + package = full_pkg, + out = name + suffix + "_state_autogen_with_imports.go", + ) + go_imports( + name = name + suffix + "_state_autogen", + src = name + suffix + "_state_autogen_with_imports.go", + out = name + suffix + "_state_autogen.go", + ) + all_srcs = all_srcs + [ + name + suffix + "_state_autogen.go" + for suffix in state_sets.keys() + ] + if "//pkg/state" not in all_deps: + all_deps = all_deps + ["//pkg/state"] + if marshal: - go_marshal( - name = name + "_abi_autogen", - srcs = [src for src in srcs if src.endswith(".go")], - debug = False, - imports = imports, - package = name, - ) + # See above. + marshal_sets = calculate_sets(srcs) + for (suffix, src_subset) in marshal_sets.items(): + go_marshal( + name = name + suffix + "_abi_autogen", + srcs = src_subset, + debug = select({ + "//tools/go_marshal:marshal_config_verbose": True, + "//conditions:default": marshal_debug, + }), + imports = imports, + package = name, + ) extra_deps = [ dep for dep in marshal_deps if not dep in all_deps ] all_deps = all_deps + extra_deps - all_srcs = srcs + [name + "_abi_autogen_unsafe.go"] + all_srcs = all_srcs + [ + name + suffix + "_abi_autogen_unsafe.go" + for suffix in marshal_sets.keys() + ] _go_library( name = name, @@ -114,13 +179,16 @@ def go_library(name, srcs, deps = [], imports = [], stateify = True, marshal = F # Ignore importpath for go_test. kwargs.pop("importpath", None) - _go_test( - name = name + "_abi_autogen_test", - srcs = [name + "_abi_autogen_test.go"], - library = ":" + name, - deps = marshal_test_deps, - **kwargs - ) + # See above. + marshal_sets = calculate_sets(srcs) + for (suffix, _) in marshal_sets.items(): + _go_test( + name = name + suffix + "_abi_autogen_test", + srcs = [name + suffix + "_abi_autogen_test.go"], + library = ":" + name, + deps = marshal_test_deps, + **kwargs + ) def proto_library(name, srcs, **kwargs): """Wraps the standard proto_library. diff --git a/tools/go_generics/BUILD b/tools/go_generics/BUILD index 069df3856..32a949c93 100644 --- a/tools/go_generics/BUILD +++ b/tools/go_generics/BUILD @@ -9,7 +9,7 @@ go_binary( "imports.go", "remove.go", ], - visibility = ["//visibility:public"], + visibility = ["//:sandbox"], deps = ["//tools/go_generics/globals"], ) diff --git a/tools/go_generics/go_merge/BUILD b/tools/go_generics/go_merge/BUILD index b7d35e272..2fd5a200d 100644 --- a/tools/go_generics/go_merge/BUILD +++ b/tools/go_generics/go_merge/BUILD @@ -5,5 +5,5 @@ package(licenses = ["notice"]) go_binary( name = "go_merge", srcs = ["main.go"], - visibility = ["//visibility:public"], + visibility = ["//:sandbox"], ) diff --git a/tools/go_marshal/BUILD b/tools/go_marshal/BUILD index 80d9c0504..be49cf9c8 100644 --- a/tools/go_marshal/BUILD +++ b/tools/go_marshal/BUILD @@ -12,3 +12,8 @@ go_binary( "//tools/go_marshal/gomarshal", ], ) + +config_setting( + name = "marshal_config_verbose", + values = {"define": "gomarshal=verbose"}, +) diff --git a/tools/go_marshal/gomarshal/BUILD b/tools/go_marshal/gomarshal/BUILD index c92b59dd6..44cb33ae4 100644 --- a/tools/go_marshal/gomarshal/BUILD +++ b/tools/go_marshal/gomarshal/BUILD @@ -7,6 +7,9 @@ go_library( srcs = [ "generator.go", "generator_interfaces.go", + "generator_interfaces_array_newtype.go", + "generator_interfaces_primitive_newtype.go", + "generator_interfaces_struct.go", "generator_tests.go", "util.go", ], @@ -14,4 +17,5 @@ go_library( visibility = [ "//:sandbox", ], + deps = ["//tools/tags"], ) diff --git a/tools/go_marshal/gomarshal/generator.go b/tools/go_marshal/gomarshal/generator.go index af90bdecb..729489de5 100644 --- a/tools/go_marshal/gomarshal/generator.go +++ b/tools/go_marshal/gomarshal/generator.go @@ -23,6 +23,9 @@ import ( "go/token" "os" "sort" + "strings" + + "gvisor.dev/gvisor/tools/tags" ) const ( @@ -31,9 +34,9 @@ const ( usermemImport = "gvisor.dev/gvisor/pkg/usermem" ) -// List of identifiers we use in generated code, that may conflict a -// similarly-named source identifier. Avoid problems by refusing the generate -// code when we see these. +// List of identifiers we use in generated code that may conflict with a +// similarly-named source identifier. Abort gracefully when we see these to +// avoid potentially confusing compilation failures in generated code. // // This only applies to import aliases at the moment. All other identifiers // are qualified by a receiver argument, since they're struct fields. @@ -41,10 +44,21 @@ const ( // All recievers are single letters, so we don't allow import aliases to be a // single letter. var badIdents = []string{ - "src", "srcs", "dst", "dsts", "blk", "buf", "err", + "addr", "blk", "buf", "dst", "dsts", "err", "hdr", "idx", "inner", "len", + "ptr", "src", "srcs", "task", "val", // All single-letter identifiers. } +// Constructed fromt badIdents in init(). +var badIdentsMap map[string]struct{} + +func init() { + badIdentsMap = make(map[string]struct{}) + for _, ident := range badIdents { + badIdentsMap[ident] = struct{}{} + } +} + // Generator drives code generation for a single invocation of the go_marshal // utility. // @@ -85,16 +99,20 @@ func NewGenerator(srcs []string, out, outTest, pkg string, imports []string) (*G } for _, i := range imports { // All imports on the extra imports list are unconditionally marked as - // used, so they're always added to the generated code. + // used, so that they're always added to the generated code. g.imports.add(i).markUsed() } - g.imports.add(marshalImport).markUsed() - // The follow imports may or may not be used by the generated - // code, depending what's required for the target types. Don't - // mark these imports as used by default. - g.imports.add(usermemImport) - g.imports.add(safecopyImport) + + // The following imports may or may not be used by the generated code, + // depending on what's required for the target types. Don't mark these as + // used by default. + g.imports.add("io") + g.imports.add("reflect") + g.imports.add("runtime") g.imports.add("unsafe") + g.imports.add(marshalImport) + g.imports.add(safecopyImport) + g.imports.add(usermemImport) return &g, nil } @@ -104,6 +122,14 @@ func NewGenerator(srcs []string, out, outTest, pkg string, imports []string) (*G func (g *Generator) writeHeader() error { var b sourceBuffer b.emit("// Automatically generated marshal implementation. See tools/go_marshal.\n\n") + + // Emit build tags. + if t := tags.Aggregate(g.inputs); len(t) > 0 { + b.emit(strings.Join(t.Lines(), "\n")) + b.emit("\n\n") + } + + // Package header. b.emit("package %s\n\n", g.pkg) if err := b.write(g.output); err != nil { return err @@ -168,9 +194,9 @@ func (g *Generator) parse() ([]*ast.File, []*token.FileSet, error) { return files, fsets, nil } -// collectMarshallabeTypes walks the parsed AST and collects a list of type +// collectMarshallableTypes walks the parsed AST and collects a list of type // declarations for which we need to generate the Marshallable interface. -func (g *Generator) collectMarshallabeTypes(a *ast.File, f *token.FileSet) []*ast.TypeSpec { +func (g *Generator) collectMarshallableTypes(a *ast.File, f *token.FileSet) []*ast.TypeSpec { var types []*ast.TypeSpec for _, decl := range a.Decls { gdecl, ok := decl.(*ast.GenDecl) @@ -197,14 +223,26 @@ func (g *Generator) collectMarshallabeTypes(a *ast.File, f *token.FileSet) []*as continue } for _, spec := range gdecl.Specs { - // We already confirmed we're in a type declaration earlier. + // We already confirmed we're in a type declaration earlier, so this + // cast will succeed. t := spec.(*ast.TypeSpec) - if _, ok := t.Type.(*ast.StructType); ok { - debugfAt(f.Position(t.Pos()), "Collected marshallable type %s.\n", t.Name.Name) + switch t.Type.(type) { + case *ast.StructType: + debugfAt(f.Position(t.Pos()), "Collected marshallable struct %s.\n", t.Name.Name) + types = append(types, t) + continue + case *ast.Ident: // Newtype on primitive. + debugfAt(f.Position(t.Pos()), "Collected marshallable newtype on primitive %s.\n", t.Name.Name) + types = append(types, t) + continue + case *ast.ArrayType: // Newtype on array. + debugfAt(f.Position(t.Pos()), "Collected marshallable newtype on array %s.\n", t.Name.Name) types = append(types, t) continue } - debugf("Skipping declaration %v since it's not a struct declaration.\n", gdecl) + // A user specifically requested marshalling on this type, but we + // don't support it. + abortAt(f.Position(t.Pos()), fmt.Sprintf("Marshalling codegen was requested on type '%s', but go-marshal doesn't support this kind of declaration.\n", t.Name)) } } return types @@ -218,11 +256,6 @@ func (g *Generator) collectMarshallabeTypes(a *ast.File, f *token.FileSet) []*as // identifiers in the generated code don't conflict with any imported package // names. func (g *Generator) collectImports(a *ast.File, f *token.FileSet) map[string]importStmt { - badImportNames := make(map[string]bool) - for _, i := range badIdents { - badImportNames[i] = true - } - is := make(map[string]importStmt) for _, decl := range a.Decls { gdecl, ok := decl.(*ast.GenDecl) @@ -239,7 +272,7 @@ func (g *Generator) collectImports(a *ast.File, f *token.FileSet) map[string]imp if len(i.name) == 1 { abortAt(f.Position(spec.Pos()), fmt.Sprintf("Import has a single character local name '%s'; this may conflict with code generated by go_marshal, use a multi-character import alias", i.name)) } - if badImportNames[i.name] { + if _, ok := badIdentsMap[i.name]; ok { abortAt(f.Position(spec.Pos()), fmt.Sprintf("Import name '%s' is likely to conflict with code generated by go_marshal, use a different import alias", i.name)) } } @@ -249,11 +282,22 @@ func (g *Generator) collectImports(a *ast.File, f *token.FileSet) map[string]imp } func (g *Generator) generateOne(t *ast.TypeSpec, fset *token.FileSet) *interfaceGenerator { - // We're guaranteed to have only struct type specs by now. See - // Generator.collectMarshallabeTypes. i := newInterfaceGenerator(t, fset) - i.validate() - i.emitMarshallable() + switch ty := t.Type.(type) { + case *ast.StructType: + i.validateStruct(t, ty) + i.emitMarshallableForStruct(ty) + case *ast.Ident: + i.validatePrimitiveNewtype(ty) + i.emitMarshallableForPrimitiveNewtype(ty) + case *ast.ArrayType: + i.validateArrayNewtype(t.Name, ty) + // After validate, we can safely call arrayLen. + i.emitMarshallableForArrayNewtype(t.Name, ty.Elt.(*ast.Ident), arrayLen(ty)) + default: + // This should've been filtered out by collectMarshallabeTypes. + panic(fmt.Sprintf("Unexpected type %+v", ty)) + } return i } @@ -300,7 +344,7 @@ func (g *Generator) Run() error { for i, a := range asts { // Collect type declarations marked for code generation and generate // Marshallable interfaces. - for _, t := range g.collectMarshallabeTypes(a, fsets[i]) { + for _, t := range g.collectMarshallableTypes(a, fsets[i]) { impl := g.generateOne(t, fsets[i]) // Collect Marshallable types referenced by the generated code. for ref, _ := range impl.ms { @@ -318,17 +362,6 @@ func (g *Generator) Run() error { } } - // Tool was invoked with input files with no data structures marked for code - // generation. This is probably not what the user intended. - if len(impls) == 0 { - var buf bytes.Buffer - fmt.Fprintf(&buf, "go_marshal invoked on these files, but they don't contain any types requiring code generation. Perhaps mark some with \"// +marshal\"?:\n") - for _, i := range g.inputs { - fmt.Fprintf(&buf, " %s\n", i) - } - abort(buf.String()) - } - // Write output file header. These include things like package name and // import statements. if err := g.writeHeader(); err != nil { @@ -360,6 +393,7 @@ func (g *Generator) writeTests(ts []*testGenerator) error { return err } + // Collect and write test import statements. imports := newImportTable() for _, t := range ts { imports.merge(t.imports) @@ -369,6 +403,27 @@ func (g *Generator) writeTests(ts []*testGenerator) error { return err } + // Write test functions. + + // If we didn't generate any Marshallable implementations, we can't just + // emit an empty test file, since that causes the build to fail with "no + // tests/benchmarks/examples found". Unfortunately we can't signal bazel to + // omit the entire package since the outputs are already defined before + // go-marshal is called. If we'd otherwise emit an empty test suite, emit an + // empty example instead. + if len(ts) == 0 { + b.reset() + b.emit("func ExampleEmptyTestSuite() {\n") + b.inIndent(func() { + b.emit("// This example is intentionally empty to ensure this file contains at least\n") + b.emit("// one testable entity. go-marshal is forced to emit a test file if a package\n") + b.emit("// is marked marshallable, but emitting a test file with no entities results\n") + b.emit("// in a build failure.\n") + }) + b.emit("}\n") + return b.write(g.outputTest) + } + for _, t := range ts { if err := t.write(g.outputTest); err != nil { return err diff --git a/tools/go_marshal/gomarshal/generator_interfaces.go b/tools/go_marshal/gomarshal/generator_interfaces.go index a712c14dc..8babf61d2 100644 --- a/tools/go_marshal/gomarshal/generator_interfaces.go +++ b/tools/go_marshal/gomarshal/generator_interfaces.go @@ -15,10 +15,8 @@ package gomarshal import ( - "fmt" "go/ast" "go/token" - "strings" ) // interfaceGenerator generates marshalling interfaces for a single type. @@ -55,9 +53,6 @@ func (g *interfaceGenerator) typeName() string { // newinterfaceGenerator creates a new interface generator. func newInterfaceGenerator(t *ast.TypeSpec, fset *token.FileSet) *interfaceGenerator { - if _, ok := t.Type.(*ast.StructType); !ok { - panic(fmt.Sprintf("Attempting to generate code for a not struct type %v", t)) - } g := &interfaceGenerator{ t: t, r: receiverName(t), @@ -84,18 +79,6 @@ func (g *interfaceGenerator) recordPotentiallyNonPackedField(fieldName string) { g.as[fieldName] = struct{}{} } -func (g *interfaceGenerator) forEachField(fn func(f *ast.Field)) { - // This is guaranteed to succeed because g.t is always a struct. - st := g.t.Type.(*ast.StructType) - for _, field := range st.Fields.List { - fn(field) - } -} - -func (g *interfaceGenerator) fieldAccessor(n *ast.Ident) string { - return fmt.Sprintf("%s.%s", g.r, n.Name) -} - // abortAt aborts the go_marshal tool with the given error message, with a // reference position to the input source. Same as abortAt, but uses g to // resolve p to position. @@ -103,67 +86,6 @@ func (g *interfaceGenerator) abortAt(p token.Pos, msg string) { abortAt(g.f.Position(p), msg) } -// validate ensures the type we're working with can be marshalled. These checks -// are done ahead of time and in one place so we can make assumptions later. -func (g *interfaceGenerator) validate() { - g.forEachField(func(f *ast.Field) { - if len(f.Names) == 0 { - g.abortAt(f.Pos(), "Cannot marshal structs with embedded fields, give the field a name; use '_' for anonymous fields such as padding fields") - } - }) - - g.forEachField(func(f *ast.Field) { - fieldDispatcher{ - primitive: func(_, t *ast.Ident) { - switch t.Name { - case "int8", "uint8", "byte", "int16", "uint16", "int32", "uint32", "int64", "uint64": - // These are the only primitive types we're allow. Below, we - // provide suggestions for some disallowed types and reject - // them, then attempt to marshal any remaining types by - // invoking the marshal.Marshallable interface on them. If - // these types don't actually implement - // marshal.Marshallable, compilation of the generated code - // will fail with an appropriate error message. - return - case "int": - g.abortAt(f.Pos(), "Type 'int' has ambiguous width, use int32 or int64") - case "uint": - g.abortAt(f.Pos(), "Type 'uint' has ambiguous width, use uint32 or uint64") - case "string": - g.abortAt(f.Pos(), "Type 'string' is dynamically-sized and cannot be marshalled, use a fixed size byte array '[...]byte' instead") - default: - debugfAt(g.f.Position(f.Pos()), fmt.Sprintf("Found derived type '%s', will attempt dispatch via marshal.Marshallable.\n", t.Name)) - } - }, - selector: func(_, _, _ *ast.Ident) { - // No validation to perform on selector fields. However this - // callback must still be provided. - }, - array: func(n, _ *ast.Ident, len int) { - a := f.Type.(*ast.ArrayType) - if a.Len == nil { - g.abortAt(f.Pos(), fmt.Sprintf("Dynamically sized slice '%s' cannot be marshalled, arrays must be statically sized", n.Name)) - } - - if _, ok := a.Len.(*ast.BasicLit); !ok { - g.abortAt(a.Len.Pos(), fmt.Sprintf("Array size must be a literal, don's use consts or expressions")) - } - - if _, ok := a.Elt.(*ast.Ident); !ok { - g.abortAt(a.Elt.Pos(), fmt.Sprintf("Marshalling not supported for arrays with %s elements, array elements must be primitive types", kindString(a.Elt))) - } - - if len <= 0 { - g.abortAt(a.Len.Pos(), fmt.Sprintf("Marshalling not supported for zero length arrays, why does an ABI struct have one?")) - } - }, - unhandled: func(_ *ast.Ident) { - g.abortAt(f.Pos(), fmt.Sprintf("Marshalling not supported for %s fields", kindString(f.Type))) - }, - }.dispatch(f) - }) -} - // scalarSize returns the size of type identified by t. If t isn't a primitive // type, the size isn't known at code generation time, and must be resolved via // the marshal.Marshallable interface. @@ -190,7 +112,8 @@ func (g *interfaceGenerator) shiftDynamic(bufVar, name string) { g.emit("%s = %s[%s.SizeBytes():]\n", bufVar, bufVar, name) } -func (g *interfaceGenerator) marshalScalar(accessor, typ string, bufVar string) { +// marshalScalar writes a single scalar to a byte slice. +func (g *interfaceGenerator) marshalScalar(accessor, typ, bufVar string) { switch typ { case "int8", "uint8", "byte": g.emit("%s[0] = byte(%s)\n", bufVar, accessor) @@ -213,43 +136,26 @@ func (g *interfaceGenerator) marshalScalar(accessor, typ string, bufVar string) } } -func (g *interfaceGenerator) unmarshalScalar(accessor, typ string, bufVar string) { +// unmarshalScalar reads a single scalar from a byte slice. +func (g *interfaceGenerator) unmarshalScalar(accessor, typ, bufVar string) { switch typ { - case "int8": - g.emit("%s = int8(%s[0])\n", accessor, bufVar) - g.shift(bufVar, 1) - case "uint8": - g.emit("%s = uint8(%s[0])\n", accessor, bufVar) - g.shift(bufVar, 1) case "byte": g.emit("%s = %s[0]\n", accessor, bufVar) g.shift(bufVar, 1) - - case "int16": - g.recordUsedImport("usermem") - g.emit("%s = int16(usermem.ByteOrder.Uint16(%s[:2]))\n", accessor, bufVar) - g.shift(bufVar, 2) - case "uint16": + case "int8", "uint8": + g.emit("%s = %s(%s[0])\n", accessor, typ, bufVar) + g.shift(bufVar, 1) + case "int16", "uint16": g.recordUsedImport("usermem") - g.emit("%s = usermem.ByteOrder.Uint16(%s[:2])\n", accessor, bufVar) + g.emit("%s = %s(usermem.ByteOrder.Uint16(%s[:2]))\n", accessor, typ, bufVar) g.shift(bufVar, 2) - - case "int32": - g.recordUsedImport("usermem") - g.emit("%s = int32(usermem.ByteOrder.Uint32(%s[:4]))\n", accessor, bufVar) - g.shift(bufVar, 4) - case "uint32": + case "int32", "uint32": g.recordUsedImport("usermem") - g.emit("%s = usermem.ByteOrder.Uint32(%s[:4])\n", accessor, bufVar) + g.emit("%s = %s(usermem.ByteOrder.Uint32(%s[:4]))\n", accessor, typ, bufVar) g.shift(bufVar, 4) - - case "int64": - g.recordUsedImport("usermem") - g.emit("%s = int64(usermem.ByteOrder.Uint64(%s[:8]))\n", accessor, bufVar) - g.shift(bufVar, 8) - case "uint64": + case "int64", "uint64": g.recordUsedImport("usermem") - g.emit("%s = usermem.ByteOrder.Uint64(%s[:8])\n", accessor, bufVar) + g.emit("%s = %s(usermem.ByteOrder.Uint64(%s[:8]))\n", accessor, typ, bufVar) g.shift(bufVar, 8) default: g.emit("%s.UnmarshalBytes(%s[:%s.SizeBytes()])\n", accessor, bufVar, accessor) @@ -257,251 +163,3 @@ func (g *interfaceGenerator) unmarshalScalar(accessor, typ string, bufVar string g.recordPotentiallyNonPackedField(accessor) } } - -// areFieldsPackedExpression returns a go expression checking whether g.t's fields are -// packed. Returns "", false if g.t has no fields that may be potentially -// packed, otherwise returns <clause>, true, where <clause> is an expression -// like "t.a.Packed() && t.b.Packed() && t.c.Packed()". -func (g *interfaceGenerator) areFieldsPackedExpression() (string, bool) { - if len(g.as) == 0 { - return "", false - } - - cs := make([]string, 0, len(g.as)) - for accessor, _ := range g.as { - cs = append(cs, fmt.Sprintf("%s.Packed()", accessor)) - } - return strings.Join(cs, " && "), true -} - -func (g *interfaceGenerator) emitMarshallable() { - // Is g.t a packed struct without consideing field types? - thisPacked := true - g.forEachField(func(f *ast.Field) { - if f.Tag != nil { - if f.Tag.Value == "`marshal:\"unaligned\"`" { - if thisPacked { - debugfAt(g.f.Position(g.t.Pos()), - fmt.Sprintf("Marking type '%s' as not packed due to tag `marshal:\"unaligned\"`.\n", g.t.Name)) - thisPacked = false - } - } - } - }) - - g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n") - g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName()) - g.inIndent(func() { - primitiveSize := 0 - var dynamicSizeTerms []string - - g.forEachField(fieldDispatcher{ - primitive: func(n, t *ast.Ident) { - if size, dynamic := g.scalarSize(t); !dynamic { - primitiveSize += size - } else { - g.recordUsedMarshallable(t.Name) - dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("%s.SizeBytes()", g.fieldAccessor(n))) - } - }, - selector: func(n, tX, tSel *ast.Ident) { - tName := fmt.Sprintf("%s.%s", tX.Name, tSel.Name) - g.recordUsedImport(tX.Name) - g.recordUsedMarshallable(tName) - dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", tName)) - }, - array: func(n, t *ast.Ident, len int) { - if len < 1 { - // Zero-length arrays should've been rejected by validate(). - panic("unreachable") - } - if size, dynamic := g.scalarSize(t); !dynamic { - primitiveSize += size * len - } else { - g.recordUsedMarshallable(t.Name) - dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()*%d", t.Name, len)) - } - }, - }.dispatch) - g.emit("return %d", primitiveSize) - if len(dynamicSizeTerms) > 0 { - g.incIndent() - } - { - for _, d := range dynamicSizeTerms { - g.emitNoIndent(" +\n") - g.emit(d) - } - } - if len(dynamicSizeTerms) > 0 { - g.decIndent() - } - }) - g.emit("\n}\n\n") - - g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n") - g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName()) - g.inIndent(func() { - g.forEachField(fieldDispatcher{ - primitive: func(n, t *ast.Ident) { - if n.Name == "_" { - g.emit("// Padding: dst[:sizeof(%s)] ~= %s(0)\n", t.Name, t.Name) - if len, dynamic := g.scalarSize(t); !dynamic { - g.shift("dst", len) - } else { - // We can't use shiftDynamic here because we don't have - // an instance of the dynamic type we can referece here - // (since the version in this struct is anonymous). Use - // a typed nil pointer to call SizeBytes() instead. - g.emit("dst = dst[(*%s)(nil).SizeBytes():]\n", t.Name) - } - return - } - g.marshalScalar(g.fieldAccessor(n), t.Name, "dst") - }, - selector: func(n, tX, tSel *ast.Ident) { - g.marshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "dst") - }, - array: func(n, t *ast.Ident, size int) { - if n.Name == "_" { - g.emit("// Padding: dst[:sizeof(%s)*%d] ~= [%d]%s{0}\n", t.Name, size, size, t.Name) - if len, dynamic := g.scalarSize(t); !dynamic { - g.shift("dst", len*size) - } else { - // We can't use shiftDynamic here because we don't have - // an instance of the dynamic type we can reference here - // (since the version in this struct is anonymous). Use - // a typed nil pointer to call SizeBytes() instead. - g.emit("dst = dst[(*%s)(nil).SizeBytes()*%d:]\n", t.Name, size) - } - return - } - - g.emit("for i := 0; i < %d; i++ {\n", size) - g.inIndent(func() { - g.marshalScalar(fmt.Sprintf("%s[i]", g.fieldAccessor(n)), t.Name, "dst") - }) - g.emit("}\n") - }, - }.dispatch) - }) - g.emit("}\n\n") - - g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n") - g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName()) - g.inIndent(func() { - g.forEachField(fieldDispatcher{ - primitive: func(n, t *ast.Ident) { - if n.Name == "_" { - g.emit("// Padding: var _ %s ~= src[:sizeof(%s)]\n", t.Name, t.Name) - if len, dynamic := g.scalarSize(t); !dynamic { - g.shift("src", len) - } else { - // We can't use shiftDynamic here because we don't have - // an instance of the dynamic type we can reference here - // (since the version in this struct is anonymous). Use - // a typed nil pointer to call SizeBytes() instead. - g.emit("src = src[(*%s)(nil).SizeBytes():]\n", t.Name) - g.recordPotentiallyNonPackedField(fmt.Sprintf("(*%s)(nil)", t.Name)) - } - return - } - g.unmarshalScalar(g.fieldAccessor(n), t.Name, "src") - }, - selector: func(n, tX, tSel *ast.Ident) { - g.unmarshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "src") - }, - array: func(n, t *ast.Ident, size int) { - if n.Name == "_" { - g.emit("// Padding: ~ copy([%d]%s(%s), src[:sizeof(%s)*%d])\n", size, t.Name, g.fieldAccessor(n), t.Name, size) - if len, dynamic := g.scalarSize(t); !dynamic { - g.shift("src", len*size) - } else { - // We can't use shiftDynamic here because we don't have - // an instance of the dynamic type we can referece here - // (since the version in this struct is anonymous). Use - // a typed nil pointer to call SizeBytes() instead. - g.emit("src = src[(*%s)(nil).SizeBytes()*%d:]\n", t.Name, size) - } - return - } - - g.emit("for i := 0; i < %d; i++ {\n", size) - g.inIndent(func() { - g.unmarshalScalar(fmt.Sprintf("%s[i]", g.fieldAccessor(n)), t.Name, "src") - }) - g.emit("}\n") - }, - }.dispatch) - }) - g.emit("}\n\n") - - g.emit("// Packed implements marshal.Marshallable.Packed.\n") - g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName()) - g.inIndent(func() { - expr, fieldsMaybePacked := g.areFieldsPackedExpression() - switch { - case !thisPacked: - g.emit("return false\n") - case fieldsMaybePacked: - g.emit("return %s\n", expr) - default: - g.emit("return true\n") - - } - }) - g.emit("}\n\n") - - g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n") - g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName()) - g.inIndent(func() { - if thisPacked { - g.recordUsedImport("safecopy") - g.recordUsedImport("unsafe") - if cond, ok := g.areFieldsPackedExpression(); ok { - g.emit("if %s {\n", cond) - g.inIndent(func() { - g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r) - }) - g.emit("} else {\n") - g.inIndent(func() { - g.emit("%s.MarshalBytes(dst)\n", g.r) - }) - g.emit("}\n") - } else { - g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r) - } - } else { - g.emit("// Type %s doesn't have a packed layout in memory, fallback to MarshalBytes.\n", g.typeName()) - g.emit("%s.MarshalBytes(dst)\n", g.r) - } - }) - g.emit("}\n\n") - - g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n") - g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName()) - g.inIndent(func() { - if thisPacked { - g.recordUsedImport("safecopy") - g.recordUsedImport("unsafe") - if cond, ok := g.areFieldsPackedExpression(); ok { - g.emit("if %s {\n", cond) - g.inIndent(func() { - g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r) - }) - g.emit("} else {\n") - g.inIndent(func() { - g.emit("%s.UnmarshalBytes(src)\n", g.r) - }) - g.emit("}\n") - } else { - g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r) - } - } else { - g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName()) - g.emit("%s.UnmarshalBytes(src)\n", g.r) - } - }) - g.emit("}\n\n") - -} diff --git a/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go b/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go new file mode 100644 index 000000000..da36d9305 --- /dev/null +++ b/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go @@ -0,0 +1,183 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This file contains the bits of the code generator specific to marshalling +// newtypes on arrays. + +package gomarshal + +import ( + "fmt" + "go/ast" +) + +func (g *interfaceGenerator) validateArrayNewtype(n *ast.Ident, a *ast.ArrayType) { + if a.Len == nil { + g.abortAt(a.Pos(), fmt.Sprintf("Dynamically sized slice '%s' cannot be marshalled, arrays must be statically sized", n.Name)) + } + + if _, ok := a.Len.(*ast.BasicLit); !ok { + g.abortAt(a.Len.Pos(), fmt.Sprintf("Array size must be a literal, don't use consts or expressions")) + } + + if _, ok := a.Elt.(*ast.Ident); !ok { + g.abortAt(a.Elt.Pos(), fmt.Sprintf("Marshalling not supported for arrays with %s elements, array elements must be primitive types", kindString(a.Elt))) + } + + if arrayLen(a) <= 0 { + g.abortAt(a.Len.Pos(), fmt.Sprintf("Marshalling not supported for zero length arrays, why does an ABI struct have one?")) + } +} + +func (g *interfaceGenerator) emitMarshallableForArrayNewtype(n, elt *ast.Ident, len int) { + g.recordUsedImport("io") + g.recordUsedImport("marshal") + g.recordUsedImport("reflect") + g.recordUsedImport("runtime") + g.recordUsedImport("safecopy") + g.recordUsedImport("unsafe") + g.recordUsedImport("usermem") + + g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n") + g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName()) + g.inIndent(func() { + if size, dynamic := g.scalarSize(elt); !dynamic { + g.emit("return %d\n", size*len) + } else { + g.emit("return (*%s)(nil).SizeBytes() * %d\n", n.Name, len) + } + }) + g.emit("}\n\n") + + g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n") + g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("for idx := 0; idx < %d; idx++ {\n", len) + g.inIndent(func() { + g.marshalScalar(fmt.Sprintf("%s[idx]", g.r), elt.Name, "dst") + }) + g.emit("}\n") + }) + g.emit("}\n\n") + + g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n") + g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("for idx := 0; idx < %d; idx++ {\n", len) + g.inIndent(func() { + g.unmarshalScalar(fmt.Sprintf("%s[idx]", g.r), elt.Name, "src") + }) + g.emit("}\n") + }) + g.emit("}\n\n") + + g.emit("// Packed implements marshal.Marshallable.Packed.\n") + g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("// Array newtypes are always packed.\n") + g.emit("return true\n") + }) + g.emit("}\n\n") + + g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n") + g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r) + }) + g.emit("}\n\n") + + g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n") + g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r) + }) + g.emit("}\n\n") + + g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n") + g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName()) + g.inIndent(func() { + // Fast serialization. + g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) + g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) + g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") + g.emit("ptr := unsafe.Pointer(%s)\n", g.r) + g.emit("val := uintptr(ptr)\n") + g.emit("val = val^0\n\n") + + g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) + g.emit("var buf []byte\n") + g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") + g.emit("hdr.Data = val\n") + g.emit("hdr.Len = %s.SizeBytes()\n", g.r) + g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) + + g.emit("_, err := task.CopyOutBytes(addr, buf)\n") + g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) + g.emit("// must live until after the CopyOutBytes.\n") + g.emit("runtime.KeepAlive(%s)\n", g.r) + g.emit("return err\n") + }) + g.emit("}\n\n") + + g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n") + g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) + g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) + g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") + g.emit("ptr := unsafe.Pointer(%s)\n", g.r) + g.emit("val := uintptr(ptr)\n") + g.emit("val = val^0\n\n") + + g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) + g.emit("var buf []byte\n") + g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") + g.emit("hdr.Data = val\n") + g.emit("hdr.Len = %s.SizeBytes()\n", g.r) + g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) + + g.emit("_, err := task.CopyInBytes(addr, buf)\n") + g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) + g.emit("// must live until after the CopyInBytes.\n") + g.emit("runtime.KeepAlive(%s)\n", g.r) + g.emit("return err\n") + }) + g.emit("}\n\n") + + g.emit("// WriteTo implements io.WriterTo.WriteTo.\n") + g.emit("func (%s *%s) WriteTo(w io.Writer) (int64, error) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) + g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) + g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") + g.emit("ptr := unsafe.Pointer(%s)\n", g.r) + g.emit("val := uintptr(ptr)\n") + g.emit("val = val^0\n\n") + + g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) + g.emit("var buf []byte\n") + g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") + g.emit("hdr.Data = val\n") + g.emit("hdr.Len = %s.SizeBytes()\n", g.r) + g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) + + g.emit("len, err := w.Write(buf)\n") + g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) + g.emit("// must live until after the Write.\n") + g.emit("runtime.KeepAlive(%s)\n", g.r) + g.emit("return int64(len), err\n") + + }) + g.emit("}\n\n") +} diff --git a/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go b/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go new file mode 100644 index 000000000..159397825 --- /dev/null +++ b/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go @@ -0,0 +1,229 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This file contains the bits of the code generator specific to marshalling +// newtypes on primitives. + +package gomarshal + +import ( + "fmt" + "go/ast" +) + +// marshalPrimitiveScalar writes a single primitive variable to a byte +// slice. +func (g *interfaceGenerator) marshalPrimitiveScalar(accessor, typ, bufVar string) { + switch typ { + case "int8", "uint8", "byte": + g.emit("%s[0] = byte(*%s)\n", bufVar, accessor) + case "int16", "uint16": + g.recordUsedImport("usermem") + g.emit("usermem.ByteOrder.PutUint16(%s[:2], uint16(*%s))\n", bufVar, accessor) + case "int32", "uint32": + g.recordUsedImport("usermem") + g.emit("usermem.ByteOrder.PutUint32(%s[:4], uint32(*%s))\n", bufVar, accessor) + case "int64", "uint64": + g.recordUsedImport("usermem") + g.emit("usermem.ByteOrder.PutUint64(%s[:8], uint64(*%s))\n", bufVar, accessor) + default: + g.emit("// Explicilty cast to the underlying type before dispatching to\n") + g.emit("// MarshalBytes, so we don't recursively call %s.MarshalBytes\n", accessor) + g.emit("inner := (*%s)(%s)\n", typ, accessor) + g.emit("inner.MarshalBytes(%s[:%s.SizeBytes()])\n", bufVar, accessor) + } +} + +// unmarshalPrimitiveScalar read a single primitive variable from a byte slice. +func (g *interfaceGenerator) unmarshalPrimitiveScalar(accessor, typ, bufVar, typeCast string) { + switch typ { + case "byte": + g.emit("*%s = %s(%s[0])\n", accessor, typeCast, bufVar) + case "int8", "uint8": + g.emit("*%s = %s(%s(%s[0]))\n", accessor, typeCast, typ, bufVar) + case "int16", "uint16": + g.recordUsedImport("usermem") + g.emit("*%s = %s(%s(usermem.ByteOrder.Uint16(%s[:2])))\n", accessor, typeCast, typ, bufVar) + case "int32", "uint32": + g.recordUsedImport("usermem") + g.emit("*%s = %s(%s(usermem.ByteOrder.Uint32(%s[:4])))\n", accessor, typeCast, typ, bufVar) + case "int64", "uint64": + g.recordUsedImport("usermem") + g.emit("*%s = %s(%s(usermem.ByteOrder.Uint64(%s[:8])))\n", accessor, typeCast, typ, bufVar) + default: + g.emit("// Explicilty cast to the underlying type before dispatching to\n") + g.emit("// UnmarshalBytes, so we don't recursively call %s.UnmarshalBytes\n", accessor) + g.emit("inner := (*%s)(%s)\n", typ, accessor) + g.emit("inner.UnmarshalBytes(%s[:%s.SizeBytes()])\n", bufVar, accessor) + } +} + +func (g *interfaceGenerator) validatePrimitiveNewtype(t *ast.Ident) { + switch t.Name { + case "int8", "uint8", "byte", "int16", "uint16", "int32", "uint32", "int64", "uint64": + // These are the only primitive types we're allow. Below, we provide + // suggestions for some disallowed types and reject them, then attempt + // to marshal any remaining types by invoking the marshal.Marshallable + // interface on them. If these types don't actually implement + // marshal.Marshallable, compilation of the generated code will fail + // with an appropriate error message. + return + case "int": + g.abortAt(t.Pos(), "Type 'int' has ambiguous width, use int32 or int64") + case "uint": + g.abortAt(t.Pos(), "Type 'uint' has ambiguous width, use uint32 or uint64") + case "string": + g.abortAt(t.Pos(), "Type 'string' is dynamically-sized and cannot be marshalled, use a fixed size byte array '[...]byte' instead") + default: + debugfAt(g.f.Position(t.Pos()), fmt.Sprintf("Found derived type '%s', will attempt dispatch via marshal.Marshallable.\n", t.Name)) + } +} + +// emitMarshallableForPrimitiveNewtype outputs code to implement the +// marshal.Marshallable interface for a newtype on a primitive. Primitive +// newtypes are always packed, so we can omit the various fallbacks required for +// non-packed structs. +func (g *interfaceGenerator) emitMarshallableForPrimitiveNewtype(nt *ast.Ident) { + g.recordUsedImport("io") + g.recordUsedImport("marshal") + g.recordUsedImport("reflect") + g.recordUsedImport("runtime") + g.recordUsedImport("safecopy") + g.recordUsedImport("unsafe") + g.recordUsedImport("usermem") + + g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n") + g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName()) + g.inIndent(func() { + if size, dynamic := g.scalarSize(nt); !dynamic { + g.emit("return %d\n", size) + } else { + g.emit("return (*%s)(nil).SizeBytes()\n", nt.Name) + } + }) + g.emit("}\n\n") + + g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n") + g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.marshalPrimitiveScalar(g.r, nt.Name, "dst") + }) + g.emit("}\n\n") + + g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n") + g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.unmarshalPrimitiveScalar(g.r, nt.Name, "src", g.typeName()) + }) + g.emit("}\n\n") + + g.emit("// Packed implements marshal.Marshallable.Packed.\n") + g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("// Scalar newtypes are always packed.\n") + g.emit("return true\n") + }) + g.emit("}\n\n") + + g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n") + g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r) + }) + g.emit("}\n\n") + + g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n") + g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r) + }) + g.emit("}\n\n") + + g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n") + g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName()) + g.inIndent(func() { + // Fast serialization. + g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) + g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) + g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") + g.emit("ptr := unsafe.Pointer(%s)\n", g.r) + g.emit("val := uintptr(ptr)\n") + g.emit("val = val^0\n\n") + + g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) + g.emit("var buf []byte\n") + g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") + g.emit("hdr.Data = val\n") + g.emit("hdr.Len = %s.SizeBytes()\n", g.r) + g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) + + g.emit("_, err := task.CopyOutBytes(addr, buf)\n") + g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) + g.emit("// must live until after the CopyOutBytes.\n") + g.emit("runtime.KeepAlive(%s)\n", g.r) + g.emit("return err\n") + }) + g.emit("}\n\n") + + g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n") + g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) + g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) + g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") + g.emit("ptr := unsafe.Pointer(%s)\n", g.r) + g.emit("val := uintptr(ptr)\n") + g.emit("val = val^0\n\n") + + g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) + g.emit("var buf []byte\n") + g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") + g.emit("hdr.Data = val\n") + g.emit("hdr.Len = %s.SizeBytes()\n", g.r) + g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) + + g.emit("_, err := task.CopyInBytes(addr, buf)\n") + g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) + g.emit("// must live until after the CopyInBytes.\n") + g.emit("runtime.KeepAlive(%s)\n", g.r) + g.emit("return err\n") + }) + g.emit("}\n\n") + + g.emit("// WriteTo implements io.WriterTo.WriteTo.\n") + g.emit("func (%s *%s) WriteTo(w io.Writer) (int64, error) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) + g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) + g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") + g.emit("ptr := unsafe.Pointer(%s)\n", g.r) + g.emit("val := uintptr(ptr)\n") + g.emit("val = val^0\n\n") + + g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) + g.emit("var buf []byte\n") + g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") + g.emit("hdr.Data = val\n") + g.emit("hdr.Len = %s.SizeBytes()\n", g.r) + g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) + + g.emit("len, err := w.Write(buf)\n") + g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) + g.emit("// must live until after the Write.\n") + g.emit("runtime.KeepAlive(%s)\n", g.r) + g.emit("return int64(len), err\n") + + }) + g.emit("}\n\n") +} diff --git a/tools/go_marshal/gomarshal/generator_interfaces_struct.go b/tools/go_marshal/gomarshal/generator_interfaces_struct.go new file mode 100644 index 000000000..e66a38b2e --- /dev/null +++ b/tools/go_marshal/gomarshal/generator_interfaces_struct.go @@ -0,0 +1,450 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This file contains the bits of the code generator specific to marshalling +// structs. + +package gomarshal + +import ( + "fmt" + "go/ast" + "strings" +) + +func (g *interfaceGenerator) fieldAccessor(n *ast.Ident) string { + return fmt.Sprintf("%s.%s", g.r, n.Name) +} + +// areFieldsPackedExpression returns a go expression checking whether g.t's fields are +// packed. Returns "", false if g.t has no fields that may be potentially +// packed, otherwise returns <clause>, true, where <clause> is an expression +// like "t.a.Packed() && t.b.Packed() && t.c.Packed()". +func (g *interfaceGenerator) areFieldsPackedExpression() (string, bool) { + if len(g.as) == 0 { + return "", false + } + + cs := make([]string, 0, len(g.as)) + for accessor, _ := range g.as { + cs = append(cs, fmt.Sprintf("%s.Packed()", accessor)) + } + return strings.Join(cs, " && "), true +} + +// validateStruct ensures the type we're working with can be marshalled. These +// checks are done ahead of time and in one place so we can make assumptions +// later. +func (g *interfaceGenerator) validateStruct(ts *ast.TypeSpec, st *ast.StructType) { + forEachStructField(st, func(f *ast.Field) { + if len(f.Names) == 0 { + g.abortAt(f.Pos(), "Cannot marshal structs with embedded fields, give the field a name; use '_' for anonymous fields such as padding fields") + } + }) + + forEachStructField(st, func(f *ast.Field) { + fieldDispatcher{ + primitive: func(_, t *ast.Ident) { + g.validatePrimitiveNewtype(t) + }, + selector: func(_, _, _ *ast.Ident) { + // No validation to perform on selector fields. However this + // callback must still be provided. + }, + array: func(n, _ *ast.Ident, len int) { + g.validateArrayNewtype(n, f.Type.(*ast.ArrayType)) + }, + unhandled: func(_ *ast.Ident) { + g.abortAt(f.Pos(), fmt.Sprintf("Marshalling not supported for %s fields", kindString(f.Type))) + }, + }.dispatch(f) + }) +} + +func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) { + // Is g.t a packed struct without consideing field types? + thisPacked := true + forEachStructField(st, func(f *ast.Field) { + if f.Tag != nil { + if f.Tag.Value == "`marshal:\"unaligned\"`" { + if thisPacked { + debugfAt(g.f.Position(g.t.Pos()), + fmt.Sprintf("Marking type '%s' as not packed due to tag `marshal:\"unaligned\"`.\n", g.t.Name)) + thisPacked = false + } + } + } + }) + + g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n") + g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName()) + g.inIndent(func() { + primitiveSize := 0 + var dynamicSizeTerms []string + + forEachStructField(st, fieldDispatcher{ + primitive: func(n, t *ast.Ident) { + if size, dynamic := g.scalarSize(t); !dynamic { + primitiveSize += size + } else { + g.recordUsedMarshallable(t.Name) + dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", t.Name)) + } + }, + selector: func(n, tX, tSel *ast.Ident) { + tName := fmt.Sprintf("%s.%s", tX.Name, tSel.Name) + g.recordUsedImport(tX.Name) + g.recordUsedMarshallable(tName) + dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", tName)) + }, + array: func(n, t *ast.Ident, len int) { + if len < 1 { + // Zero-length arrays should've been rejected by validate(). + panic("unreachable") + } + if size, dynamic := g.scalarSize(t); !dynamic { + primitiveSize += size * len + } else { + g.recordUsedMarshallable(t.Name) + dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()*%d", t.Name, len)) + } + }, + }.dispatch) + g.emit("return %d", primitiveSize) + if len(dynamicSizeTerms) > 0 { + g.incIndent() + } + { + for _, d := range dynamicSizeTerms { + g.emitNoIndent(" +\n") + g.emit(d) + } + } + if len(dynamicSizeTerms) > 0 { + g.decIndent() + } + }) + g.emit("\n}\n\n") + + g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n") + g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + forEachStructField(st, fieldDispatcher{ + primitive: func(n, t *ast.Ident) { + if n.Name == "_" { + g.emit("// Padding: dst[:sizeof(%s)] ~= %s(0)\n", t.Name, t.Name) + if len, dynamic := g.scalarSize(t); !dynamic { + g.shift("dst", len) + } else { + // We can't use shiftDynamic here because we don't have + // an instance of the dynamic type we can referece here + // (since the version in this struct is anonymous). Use + // a typed nil pointer to call SizeBytes() instead. + g.emit("dst = dst[(*%s)(nil).SizeBytes():]\n", t.Name) + } + return + } + g.marshalScalar(g.fieldAccessor(n), t.Name, "dst") + }, + selector: func(n, tX, tSel *ast.Ident) { + g.marshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "dst") + }, + array: func(n, t *ast.Ident, size int) { + if n.Name == "_" { + g.emit("// Padding: dst[:sizeof(%s)*%d] ~= [%d]%s{0}\n", t.Name, size, size, t.Name) + if len, dynamic := g.scalarSize(t); !dynamic { + g.shift("dst", len*size) + } else { + // We can't use shiftDynamic here because we don't have + // an instance of the dynamic type we can reference here + // (since the version in this struct is anonymous). Use + // a typed nil pointer to call SizeBytes() instead. + g.emit("dst = dst[(*%s)(nil).SizeBytes()*%d:]\n", t.Name, size) + } + return + } + + g.emit("for idx := 0; idx < %d; idx++ {\n", size) + g.inIndent(func() { + g.marshalScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "dst") + }) + g.emit("}\n") + }, + }.dispatch) + }) + g.emit("}\n\n") + + g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n") + g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + forEachStructField(st, fieldDispatcher{ + primitive: func(n, t *ast.Ident) { + if n.Name == "_" { + g.emit("// Padding: var _ %s ~= src[:sizeof(%s)]\n", t.Name, t.Name) + if len, dynamic := g.scalarSize(t); !dynamic { + g.shift("src", len) + } else { + // We can't use shiftDynamic here because we don't have + // an instance of the dynamic type we can reference here + // (since the version in this struct is anonymous). Use + // a typed nil pointer to call SizeBytes() instead. + g.emit("src = src[(*%s)(nil).SizeBytes():]\n", t.Name) + g.recordPotentiallyNonPackedField(fmt.Sprintf("(*%s)(nil)", t.Name)) + } + return + } + g.unmarshalScalar(g.fieldAccessor(n), t.Name, "src") + }, + selector: func(n, tX, tSel *ast.Ident) { + g.unmarshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "src") + }, + array: func(n, t *ast.Ident, size int) { + if n.Name == "_" { + g.emit("// Padding: ~ copy([%d]%s(%s), src[:sizeof(%s)*%d])\n", size, t.Name, g.fieldAccessor(n), t.Name, size) + if len, dynamic := g.scalarSize(t); !dynamic { + g.shift("src", len*size) + } else { + // We can't use shiftDynamic here because we don't have + // an instance of the dynamic type we can referece here + // (since the version in this struct is anonymous). Use + // a typed nil pointer to call SizeBytes() instead. + g.emit("src = src[(*%s)(nil).SizeBytes()*%d:]\n", t.Name, size) + } + return + } + + g.emit("for idx := 0; idx < %d; idx++ {\n", size) + g.inIndent(func() { + g.unmarshalScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "src") + }) + g.emit("}\n") + }, + }.dispatch) + }) + g.emit("}\n\n") + + g.emit("// Packed implements marshal.Marshallable.Packed.\n") + g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName()) + g.inIndent(func() { + expr, fieldsMaybePacked := g.areFieldsPackedExpression() + switch { + case !thisPacked: + g.emit("return false\n") + case fieldsMaybePacked: + g.emit("return %s\n", expr) + default: + g.emit("return true\n") + + } + }) + g.emit("}\n\n") + + g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n") + g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + if thisPacked { + g.recordUsedImport("safecopy") + g.recordUsedImport("unsafe") + if cond, ok := g.areFieldsPackedExpression(); ok { + g.emit("if %s {\n", cond) + g.inIndent(func() { + g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r) + }) + g.emit("} else {\n") + g.inIndent(func() { + g.emit("%s.MarshalBytes(dst)\n", g.r) + }) + g.emit("}\n") + } else { + g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r) + } + } else { + g.emit("// Type %s doesn't have a packed layout in memory, fallback to MarshalBytes.\n", g.typeName()) + g.emit("%s.MarshalBytes(dst)\n", g.r) + } + }) + g.emit("}\n\n") + + g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n") + g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + if thisPacked { + g.recordUsedImport("safecopy") + g.recordUsedImport("unsafe") + if cond, ok := g.areFieldsPackedExpression(); ok { + g.emit("if %s {\n", cond) + g.inIndent(func() { + g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r) + }) + g.emit("} else {\n") + g.inIndent(func() { + g.emit("%s.UnmarshalBytes(src)\n", g.r) + }) + g.emit("}\n") + } else { + g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r) + } + } else { + g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName()) + g.emit("%s.UnmarshalBytes(src)\n", g.r) + } + }) + g.emit("}\n\n") + + g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n") + g.recordUsedImport("marshal") + g.recordUsedImport("usermem") + g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName()) + g.inIndent(func() { + fallback := func() { + g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName()) + g.emit("buf := task.CopyScratchBuffer(%s.SizeBytes())\n", g.r) + g.emit("%s.MarshalBytes(buf)\n", g.r) + g.emit("_, err := task.CopyOutBytes(addr, buf)\n") + g.emit("return err\n") + } + if thisPacked { + g.recordUsedImport("reflect") + g.recordUsedImport("runtime") + g.recordUsedImport("unsafe") + if cond, ok := g.areFieldsPackedExpression(); ok { + g.emit("if !%s {\n", cond) + g.inIndent(fallback) + g.emit("}\n\n") + } + // Fast serialization. + g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) + g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) + g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") + g.emit("ptr := unsafe.Pointer(%s)\n", g.r) + g.emit("val := uintptr(ptr)\n") + g.emit("val = val^0\n\n") + + g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) + g.emit("var buf []byte\n") + g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") + g.emit("hdr.Data = val\n") + g.emit("hdr.Len = %s.SizeBytes()\n", g.r) + g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) + + g.emit("_, err := task.CopyOutBytes(addr, buf)\n") + g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) + g.emit("// must live until after the CopyOutBytes.\n") + g.emit("runtime.KeepAlive(%s)\n", g.r) + g.emit("return err\n") + } else { + fallback() + } + }) + g.emit("}\n\n") + + g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n") + g.recordUsedImport("marshal") + g.recordUsedImport("usermem") + g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName()) + g.inIndent(func() { + fallback := func() { + g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName()) + g.emit("buf := task.CopyScratchBuffer(%s.SizeBytes())\n", g.r) + g.emit("_, err := task.CopyInBytes(addr, buf)\n") + g.emit("if err != nil {\n") + g.inIndent(func() { + g.emit("return err\n") + }) + g.emit("}\n") + + g.emit("%s.UnmarshalBytes(buf)\n", g.r) + g.emit("return nil\n") + } + if thisPacked { + g.recordUsedImport("reflect") + g.recordUsedImport("runtime") + g.recordUsedImport("unsafe") + if cond, ok := g.areFieldsPackedExpression(); ok { + g.emit("if !%s {\n", cond) + g.inIndent(fallback) + g.emit("}\n\n") + } + // Fast deserialization. + g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) + g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) + g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") + g.emit("ptr := unsafe.Pointer(%s)\n", g.r) + g.emit("val := uintptr(ptr)\n") + g.emit("val = val^0\n\n") + + g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) + g.emit("var buf []byte\n") + g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") + g.emit("hdr.Data = val\n") + g.emit("hdr.Len = %s.SizeBytes()\n", g.r) + g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) + + g.emit("_, err := task.CopyInBytes(addr, buf)\n") + g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) + g.emit("// must live until after the CopyInBytes.\n") + g.emit("runtime.KeepAlive(%s)\n", g.r) + g.emit("return err\n") + } else { + fallback() + } + }) + g.emit("}\n\n") + + g.emit("// WriteTo implements io.WriterTo.WriteTo.\n") + g.recordUsedImport("io") + g.emit("func (%s *%s) WriteTo(w io.Writer) (int64, error) {\n", g.r, g.typeName()) + g.inIndent(func() { + fallback := func() { + g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName()) + g.emit("buf := make([]byte, %s.SizeBytes())\n", g.r) + g.emit("%s.MarshalBytes(buf)\n", g.r) + g.emit("n, err := w.Write(buf)\n") + g.emit("return int64(n), err\n") + } + if thisPacked { + g.recordUsedImport("reflect") + g.recordUsedImport("runtime") + g.recordUsedImport("unsafe") + if cond, ok := g.areFieldsPackedExpression(); ok { + g.emit("if !%s {\n", cond) + g.inIndent(fallback) + g.emit("}\n\n") + } + // Fast serialization. + g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) + g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) + g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") + g.emit("ptr := unsafe.Pointer(%s)\n", g.r) + g.emit("val := uintptr(ptr)\n") + g.emit("val = val^0\n\n") + + g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) + g.emit("var buf []byte\n") + g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") + g.emit("hdr.Data = val\n") + g.emit("hdr.Len = %s.SizeBytes()\n", g.r) + g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) + + g.emit("len, err := w.Write(buf)\n") + g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) + g.emit("// must live until after the Write.\n") + g.emit("runtime.KeepAlive(%s)\n", g.r) + g.emit("return int64(len), err\n") + } else { + fallback() + } + }) + g.emit("}\n\n") +} diff --git a/tools/go_marshal/gomarshal/generator_tests.go b/tools/go_marshal/gomarshal/generator_tests.go index bcda17c3b..fd992e44a 100644 --- a/tools/go_marshal/gomarshal/generator_tests.go +++ b/tools/go_marshal/gomarshal/generator_tests.go @@ -22,9 +22,11 @@ import ( ) var standardImports = []string{ + "bytes", "fmt", "reflect", "testing", + "gvisor.dev/gvisor/tools/go_marshal/analysis", } @@ -47,9 +49,6 @@ type testGenerator struct { } func newTestGenerator(t *ast.TypeSpec) *testGenerator { - if _, ok := t.Type.(*ast.StructType); !ok { - panic(fmt.Sprintf("Attempting to generate code for a not struct type %v", t)) - } g := &testGenerator{ t: t, r: receiverName(t), @@ -67,14 +66,6 @@ func (g *testGenerator) typeName() string { return g.t.Name.Name } -func (g *testGenerator) forEachField(fn func(f *ast.Field)) { - // This is guaranteed to succeed because g.t is always a struct. - st := g.t.Type.(*ast.StructType) - for _, field := range st.Fields.List { - fn(field) - } -} - func (g *testGenerator) testFuncName(base string) string { return fmt.Sprintf("%s%s", base, strings.Title(g.t.Name.Name)) } @@ -87,10 +78,10 @@ func (g *testGenerator) inTestFunction(name string, body func()) { func (g *testGenerator) emitTestNonZeroSize() { g.inTestFunction("TestSizeNonZero", func() { - g.emit("x := &%s{}\n", g.typeName()) + g.emit("var x %v\n", g.typeName()) g.emit("if x.SizeBytes() == 0 {\n") g.inIndent(func() { - g.emit("t.Fatal(\"Marshallable.Size() should not return zero\")\n") + g.emit("t.Fatal(\"Marshallable.SizeBytes() should not return zero\")\n") }) g.emit("}\n") }) @@ -98,7 +89,7 @@ func (g *testGenerator) emitTestNonZeroSize() { func (g *testGenerator) emitTestSuspectAlignment() { g.inTestFunction("TestSuspectAlignment", func() { - g.emit("x := %s{}\n", g.typeName()) + g.emit("var x %v\n", g.typeName()) g.emit("analysis.AlignmentCheck(t, reflect.TypeOf(x))\n") }) } @@ -116,26 +107,64 @@ func (g *testGenerator) emitTestMarshalUnmarshalPreservesData() { g.emit("y.UnmarshalBytes(buf)\n") g.emit("if !reflect.DeepEqual(x, y) {\n") g.inIndent(func() { - g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across Marshal/Unmarshal cycle:\\nBefore: %%+v\\nAfter: %%+v\\n\", x, y))\n") + g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalBytes/UnmarshalBytes cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, y))\n") }) g.emit("}\n") g.emit("yUnsafe.UnmarshalBytes(bufUnsafe)\n") g.emit("if !reflect.DeepEqual(x, yUnsafe) {\n") g.inIndent(func() { - g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalUnsafe/Unmarshal cycle:\\nBefore: %%+v\\nAfter: %%+v\\n\", x, yUnsafe))\n") + g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalUnsafe/UnmarshalBytes cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, yUnsafe))\n") }) g.emit("}\n\n") g.emit("z.UnmarshalUnsafe(buf)\n") g.emit("if !reflect.DeepEqual(x, z) {\n") g.inIndent(func() { - g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across Marshal/UnmarshalUnsafe cycle:\\nBefore: %%+v\\nAfter: %%+v\\n\", x, z))\n") + g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalBytes/UnmarshalUnsafe cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, z))\n") }) g.emit("}\n") g.emit("zUnsafe.UnmarshalUnsafe(bufUnsafe)\n") g.emit("if !reflect.DeepEqual(x, zUnsafe) {\n") g.inIndent(func() { - g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalUnsafe/UnmarshalUnsafe cycle:\\nBefore: %%+v\\nAfter: %%+v\\n\", x, zUnsafe))\n") + g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalUnsafe/UnmarshalUnsafe cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, zUnsafe))\n") + }) + g.emit("}\n") + }) +} + +func (g *testGenerator) emitTestWriteToUnmarshalPreservesData() { + g.inTestFunction("TestWriteToUnmarshalPreservesData", func() { + g.emit("var x, y, yUnsafe %s\n", g.typeName()) + g.emit("analysis.RandomizeValue(&x)\n\n") + + g.emit("var buf bytes.Buffer\n\n") + + g.emit("x.WriteTo(&buf)\n") + g.emit("y.UnmarshalBytes(buf.Bytes())\n\n") + g.emit("yUnsafe.UnmarshalUnsafe(buf.Bytes())\n\n") + + g.emit("if !reflect.DeepEqual(x, y) {\n") + g.inIndent(func() { + g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across WriteTo/UnmarshalBytes cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, y))\n") + }) + g.emit("}\n") + g.emit("if !reflect.DeepEqual(x, yUnsafe) {\n") + g.inIndent(func() { + g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across WriteTo/UnmarshalUnsafe cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, yUnsafe))\n") + }) + g.emit("}\n") + }) +} + +func (g *testGenerator) emitTestSizeBytesOnTypedNilPtr() { + g.inTestFunction("TestSizeBytesOnTypedNilPtr", func() { + g.emit("var x %s\n", g.typeName()) + g.emit("sizeFromConcrete := x.SizeBytes()\n") + g.emit("sizeFromTypedNilPtr := (*%s)(nil).SizeBytes()\n\n", g.typeName()) + + g.emit("if sizeFromTypedNilPtr != sizeFromConcrete {\n") + g.inIndent(func() { + g.emit("t.Fatalf(\"SizeBytes() on typed nil pointer (%v) doesn't match size returned by a concrete object (%v).\\n\", sizeFromTypedNilPtr, sizeFromConcrete)\n") }) g.emit("}\n") }) @@ -145,6 +174,8 @@ func (g *testGenerator) emitTests() { g.emitTestNonZeroSize() g.emitTestSuspectAlignment() g.emitTestMarshalUnmarshalPreservesData() + g.emitTestWriteToUnmarshalPreservesData() + g.emitTestSizeBytesOnTypedNilPtr() } func (g *testGenerator) write(out io.Writer) error { diff --git a/tools/go_marshal/gomarshal/util.go b/tools/go_marshal/gomarshal/util.go index 967537abf..a0936e013 100644 --- a/tools/go_marshal/gomarshal/util.go +++ b/tools/go_marshal/gomarshal/util.go @@ -64,6 +64,12 @@ func kindString(e ast.Expr) string { } } +func forEachStructField(st *ast.StructType, fn func(f *ast.Field)) { + for _, field := range st.Fields.List { + fn(field) + } +} + // fieldDispatcher is a collection of callbacks for handling different types of // fields in a struct declaration. type fieldDispatcher struct { @@ -73,6 +79,25 @@ type fieldDispatcher struct { unhandled func(n *ast.Ident) } +// Precondition: a must have a literal for the array length. Consts and +// expressions are not allowed as array lengths, and should be rejected by the +// caller. +func arrayLen(a *ast.ArrayType) int { + if a.Len == nil { + // Probably a slice? Must be handled by caller. + panic("Nil array length in array type") + } + lenLit, ok := a.Len.(*ast.BasicLit) + if !ok { + panic("Array has non-literal for length") + } + len, err := strconv.Atoi(lenLit.Value) + if err != nil { + panic(fmt.Sprintf("Failed to parse array length '%s' as number: %v", lenLit.Value, err)) + } + return len +} + // Precondition: All dispatch callbacks that will be invoked must be // provided. Embedded fields are not allowed, len(f.Names) >= 1. func (fd fieldDispatcher) dispatch(f *ast.Field) { @@ -96,22 +121,12 @@ func (fd fieldDispatcher) dispatch(f *ast.Field) { case *ast.SelectorExpr: fd.selector(name, v.X.(*ast.Ident), v.Sel) case *ast.ArrayType: - len := 0 - if v.Len != nil { - // Non-literal array length is handled by generatorInterfaces.validate(). - if lenLit, ok := v.Len.(*ast.BasicLit); ok { - var err error - len, err = strconv.Atoi(lenLit.Value) - if err != nil { - panic(err) - } - } - } switch t := v.Elt.(type) { case *ast.Ident: - fd.array(name, t, len) + fd.array(name, t, arrayLen(v)) default: - fd.array(name, nil, len) + // Should be handled with a better error message during validate. + panic(fmt.Sprintf("Array element type is of unsupported kind. Expected *ast.Ident, got %v", t)) } default: fd.unhandled(name) @@ -219,6 +234,11 @@ type sourceBuffer struct { b bytes.Buffer } +func (b *sourceBuffer) reset() { + b.indent = 0 + b.b.Reset() +} + func (b *sourceBuffer) incIndent() { b.indent++ } @@ -305,7 +325,7 @@ func (i *importStmt) markUsed() { } func (i *importStmt) equivalent(other *importStmt) bool { - return i == other + return i.name == other.name && i.path == other.path && i.aliased == other.aliased } // importTable represents a collection of importStmts. @@ -324,7 +344,7 @@ func newImportTable() *importTable { // result in a panic. func (i *importTable) merge(other *importTable) { for name, im := range other.is { - if dup, ok := i.is[name]; ok && dup.equivalent(im) { + if dup, ok := i.is[name]; ok && !dup.equivalent(im) { panic(fmt.Sprintf("Found colliding import statements: ours: %+v, other's: %+v", dup, im)) } @@ -332,16 +352,27 @@ func (i *importTable) merge(other *importTable) { } } +func (i *importTable) addStmt(s *importStmt) *importStmt { + if old, ok := i.is[s.name]; ok && !old.equivalent(s) { + // A collision should always be between an import inserted by the + // go-marshal tool and an import from the original source file (assuming + // the original source file was valid). We could theoretically handle + // the collision by assigning a local name to our import. However, this + // would need to be plumbed throughout the generator. Given that + // collisions should be rare, simply panic on collision. + panic(fmt.Sprintf("Import collision: old: %s as %v; new: %v as %v", old.path, old.name, s.path, s.name)) + } + i.is[s.name] = s + return s +} + func (i *importTable) add(s string) *importStmt { n := newImport(s) - i.is[n.name] = n - return n + return i.addStmt(n) } func (i *importTable) addFromSpec(spec *ast.ImportSpec, f *token.FileSet) *importStmt { - n := newImportFromSpec(spec, f) - i.is[n.name] = n - return n + return i.addStmt(newImportFromSpec(spec, f)) } // Marks the import named n as used. If no such import is in the table, returns diff --git a/tools/go_marshal/marshal/BUILD b/tools/go_marshal/marshal/BUILD index ad508c72f..bacfaa5a4 100644 --- a/tools/go_marshal/marshal/BUILD +++ b/tools/go_marshal/marshal/BUILD @@ -10,4 +10,7 @@ go_library( visibility = [ "//:sandbox", ], + deps = [ + "//pkg/usermem", + ], ) diff --git a/tools/go_marshal/marshal/marshal.go b/tools/go_marshal/marshal/marshal.go index a313a27ed..f129788e0 100644 --- a/tools/go_marshal/marshal/marshal.go +++ b/tools/go_marshal/marshal/marshal.go @@ -20,10 +20,38 @@ // tools/go_marshal. See the go_marshal README for details. package marshal +import ( + "io" + + "gvisor.dev/gvisor/pkg/usermem" +) + +// Task provides a subset of kernel.Task, used in marshalling. We don't import +// the kernel package directly to avoid circular dependency. +type Task interface { + // CopyScratchBuffer provides a task goroutine-local scratch buffer. See + // kernel.CopyScratchBuffer. + CopyScratchBuffer(size int) []byte + + // CopyOutBytes writes the contents of b to the task's memory. See + // kernel.CopyOutBytes. + CopyOutBytes(addr usermem.Addr, b []byte) (int, error) + + // CopyInBytes reads the contents of the task's memory to b. See + // kernel.CopyInBytes. + CopyInBytes(addr usermem.Addr, b []byte) (int, error) +} + // Marshallable represents a type that can be marshalled to and from memory. type Marshallable interface { + io.WriterTo + // SizeBytes is the size of the memory representation of a type in // marshalled form. + // + // SizeBytes must handle a nil receiver. Practically, this means SizeBytes + // cannot deference any fields on the object implementing it (but will + // likely make use of the type of these fields). SizeBytes() int // MarshalBytes serializes a copy of a type to dst. dst must be at least @@ -48,13 +76,27 @@ type Marshallable interface { // MarshalBytes. MarshalUnsafe(dst []byte) - // UnmarshalUnsafe deserializes a type directly to the underlying memory - // allocated for the object by the runtime. + // UnmarshalUnsafe deserializes a type by directly copying to the underlying + // memory allocated for the object by the runtime. // // This allows much faster unmarshalling of types which have no implicit // padding, see Marshallable.Packed. When Packed would return false, // UnmarshalUnsafe should fall back to the safer but slower unmarshal - // mechanism implemented in UnmarshalBytes (usually by calling - // UnmarshalBytes directly). + // mechanism implemented in UnmarshalBytes. UnmarshalUnsafe(src []byte) + + // CopyIn deserializes a Marshallable type from a task's memory. This may + // only be called from a task goroutine. This is more efficient than calling + // UnmarshalUnsafe on Marshallable.Packed types, as the type being + // marshalled does not escape. The implementation should avoid creating + // extra copies in memory by directly deserializing to the object's + // underlying memory. + CopyIn(task Task, addr usermem.Addr) error + + // CopyOut serializes a Marshallable type to a task's memory. This may only + // be called from a task goroutine. This is more efficient than calling + // MarshalUnsafe on Marshallable.Packed types, as the type being serialized + // does not escape. The implementation should avoid creating extra copies in + // memory by directly serializing from the object's underlying memory. + CopyOut(task Task, addr usermem.Addr) error } diff --git a/tools/go_marshal/test/BUILD b/tools/go_marshal/test/BUILD index e345e3a8e..f27c5ce52 100644 --- a/tools/go_marshal/test/BUILD +++ b/tools/go_marshal/test/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "go_library", "go_test") +load("//tools:defs.bzl", "go_binary", "go_library", "go_test") licenses(["notice"]) @@ -27,3 +27,15 @@ go_library( marshal = True, deps = ["//tools/go_marshal/test/external"], ) + +go_binary( + name = "escape", + testonly = 1, + srcs = ["escape.go"], + gc_goopts = ["-m"], + deps = [ + ":test", + "//pkg/usermem", + "//tools/go_marshal/marshal", + ], +) diff --git a/tools/go_marshal/test/benchmark_test.go b/tools/go_marshal/test/benchmark_test.go index e12403741..c79defe9e 100644 --- a/tools/go_marshal/test/benchmark_test.go +++ b/tools/go_marshal/test/benchmark_test.go @@ -24,7 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/tools/go_marshal/analysis" - test "gvisor.dev/gvisor/tools/go_marshal/test" + "gvisor.dev/gvisor/tools/go_marshal/test" ) // Marshalling using the standard encoding/binary package. diff --git a/tools/go_marshal/test/escape.go b/tools/go_marshal/test/escape.go new file mode 100644 index 000000000..184f05ea3 --- /dev/null +++ b/tools/go_marshal/test/escape.go @@ -0,0 +1,114 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This binary provides a convienient target for analyzing how the go-marshal +// API causes its various arguments to escape to the heap. To use, build and +// observe the output from the go compiler's escape analysis: +// +// $ bazel build :escape +// ... +// escape.go:67:2: moved to heap: task +// escape.go:77:31: make([]byte, size) escapes to heap +// escape.go:87:31: make([]byte, size) escapes to heap +// escape.go:96:6: moved to heap: stat +// ... +// +// This is not an automated test, but simply a minimal binary for easy analysis. +package main + +import ( + "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/tools/go_marshal/marshal" + "gvisor.dev/gvisor/tools/go_marshal/test" +) + +// dummyTask implements marshal.Task. +type dummyTask struct { +} + +func (*dummyTask) CopyScratchBuffer(size int) []byte { + return make([]byte, size) +} + +func (*dummyTask) CopyOutBytes(addr usermem.Addr, b []byte) (int, error) { + return len(b), nil +} + +func (*dummyTask) CopyInBytes(addr usermem.Addr, b []byte) (int, error) { + return len(b), nil +} + +func (task *dummyTask) MarshalBytes(addr usermem.Addr, marshallable marshal.Marshallable) { + buf := task.CopyScratchBuffer(marshallable.SizeBytes()) + marshallable.MarshalBytes(buf) + task.CopyOutBytes(addr, buf) +} + +func (task *dummyTask) MarshalUnsafe(addr usermem.Addr, marshallable marshal.Marshallable) { + buf := task.CopyScratchBuffer(marshallable.SizeBytes()) + marshallable.MarshalUnsafe(buf) + task.CopyOutBytes(addr, buf) +} + +// Expected escapes: +// - task: passed to marshal.Marshallable.CopyOut as the marshal.Task interface. +func doCopyOut() { + task := dummyTask{} + var stat test.Stat + stat.CopyOut(&task, usermem.Addr(0xf000ba12)) +} + +// Expected escapes: +// - buf: make allocates on the heap. +func doMarshalBytesDirect() { + task := dummyTask{} + var stat test.Stat + buf := task.CopyScratchBuffer(stat.SizeBytes()) + stat.MarshalBytes(buf) + task.CopyOutBytes(usermem.Addr(0xf000ba12), buf) +} + +// Expected escapes: +// - buf: make allocates on the heap. +func doMarshalUnsafeDirect() { + task := dummyTask{} + var stat test.Stat + buf := task.CopyScratchBuffer(stat.SizeBytes()) + stat.MarshalUnsafe(buf) + task.CopyOutBytes(usermem.Addr(0xf000ba12), buf) +} + +// Expected escapes: +// - stat: passed to dummyTask.MarshalBytes as the marshal.Marshallable interface. +func doMarshalBytesViaMarshallable() { + task := dummyTask{} + var stat test.Stat + task.MarshalBytes(usermem.Addr(0xf000ba12), &stat) +} + +// Expected escapes: +// - stat: passed to dummyTask.MarshalUnsafe as the marshal.Marshallable interface. +func doMarshalUnsafeViaMarshallable() { + task := dummyTask{} + var stat test.Stat + task.MarshalUnsafe(usermem.Addr(0xf000ba12), &stat) +} + +func main() { + doCopyOut() + doMarshalBytesDirect() + doMarshalUnsafeDirect() + doMarshalBytesViaMarshallable() + doMarshalUnsafeViaMarshallable() +} diff --git a/tools/go_marshal/test/test.go b/tools/go_marshal/test/test.go index 8de02d707..c829db6da 100644 --- a/tools/go_marshal/test/test.go +++ b/tools/go_marshal/test/test.go @@ -103,3 +103,18 @@ type Stat struct { CTime Timespec _ [3]int64 } + +// InetAddr is an example marshallable newtype on an array. +// +// +marshal +type InetAddr [4]byte + +// SignalSet is an example marshallable newtype on a primitive. +// +// +marshal +type SignalSet uint64 + +// SignalSetAlias is an example newtype on another marshallable type. +// +// +marshal +type SignalSetAlias SignalSet diff --git a/tools/go_mod.sh b/tools/go_mod.sh new file mode 100755 index 000000000..84b779d6d --- /dev/null +++ b/tools/go_mod.sh @@ -0,0 +1,29 @@ +#!/bin/bash + +# Copyright 2020 The gVisor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set -eo pipefail + +# Build the :gopath target. +bazel build //:gopath +declare -r gopathdir="bazel-bin/gopath/src/gvisor.dev/gvisor/" + +# Copy go.mod and execute the command. +cp -a go.mod go.sum "${gopathdir}" +(cd "${gopathdir}" && go mod "$@") +cp -a "${gopathdir}/go.mod" "${gopathdir}/go.sum" . + +# Cleanup the WORKSPACE file. +bazel run //:gazelle -- update-repos -from_file=go.mod diff --git a/tools/go_stateify/BUILD b/tools/go_stateify/BUILD index a133d6f8b..503cdf2e5 100644 --- a/tools/go_stateify/BUILD +++ b/tools/go_stateify/BUILD @@ -5,5 +5,6 @@ package(licenses = ["notice"]) go_binary( name = "stateify", srcs = ["main.go"], - visibility = ["//visibility:public"], + visibility = ["//:sandbox"], + deps = ["//tools/tags"], ) diff --git a/tools/go_stateify/defs.bzl b/tools/go_stateify/defs.bzl index 0f261d89f..6a5e666f0 100644 --- a/tools/go_stateify/defs.bzl +++ b/tools/go_stateify/defs.bzl @@ -6,8 +6,7 @@ def _go_stateify_impl(ctx): # Run the stateify command. args = ["-output=%s" % output.path] - args.append("-pkg=%s" % ctx.attr.package) - args.append("-arch=%s" % ctx.attr.arch) + args.append("-fullpkg=%s" % ctx.attr.package) if ctx.attr._statepkg: args.append("-statepkg=%s" % ctx.attr._statepkg) if ctx.attr.imports: @@ -44,18 +43,11 @@ for statified types. mandatory = False, ), "package": attr.string( - doc = "The package name for the input sources.", - mandatory = True, - ), - "arch": attr.string( - doc = "Target platform.", + doc = "The fully qualified package name for the input sources.", mandatory = True, ), "out": attr.output( - doc = """ -The name of the generated file output. This must not conflict with any other -files and must be added to the srcs of the relevant go_library. -""", + doc = "Name of the generator output file.", mandatory = True, ), "_tool": attr.label( diff --git a/tools/go_stateify/main.go b/tools/go_stateify/main.go index 7d5d291e6..3437aa476 100644 --- a/tools/go_stateify/main.go +++ b/tools/go_stateify/main.go @@ -22,126 +22,22 @@ import ( "go/ast" "go/parser" "go/token" - "io/ioutil" "os" "path/filepath" "reflect" "strings" "sync" + + "gvisor.dev/gvisor/tools/tags" ) var ( - pkg = flag.String("pkg", "", "output package") + fullPkg = flag.String("fullpkg", "", "fully qualified output package") imports = flag.String("imports", "", "extra imports for the output file") output = flag.String("output", "", "output file") statePkg = flag.String("statepkg", "", "state import package; defaults to empty") - arch = flag.String("arch", "", "specify the target platform") ) -// The known architectures. -var okgoarch = []string{ - "386", - "amd64", - "arm", - "arm64", - "mips", - "mipsle", - "mips64", - "mips64le", - "ppc64", - "ppc64le", - "riscv64", - "s390x", - "sparc64", - "wasm", -} - -// readfile returns the content of the named file. -func readfile(file string) string { - data, err := ioutil.ReadFile(file) - if err != nil { - panic(fmt.Sprintf("readfile err: %v", err)) - } - return string(data) -} - -// matchfield reports whether the field (x,y,z) matches this build. -// all the elements in the field must be satisfied. -func matchfield(f string, goarch string) bool { - for _, tag := range strings.Split(f, ",") { - if !matchtag(tag, goarch) { - return false - } - } - return true -} - -// matchtag reports whether the tag (x or !x) matches this build. -func matchtag(tag string, goarch string) bool { - if tag == "" { - return false - } - if tag[0] == '!' { - if len(tag) == 1 || tag[1] == '!' { - return false - } - return !matchtag(tag[1:], goarch) - } - return tag == goarch -} - -// canBuild reports whether we can build this file for target platform by -// checking file name and build tags. The code is derived from the Go source -// cmd.dist.build.shouldbuild. -func canBuild(file, goTargetArch string) bool { - name := filepath.Base(file) - excluded := func(list []string, ok string) bool { - for _, x := range list { - if x == ok || (ok == "android" && x == "linux") || (ok == "illumos" && x == "solaris") { - continue - } - i := strings.Index(name, x) - if i <= 0 || name[i-1] != '_' { - continue - } - i += len(x) - if i == len(name) || name[i] == '.' || name[i] == '_' { - return true - } - } - return false - } - if excluded(okgoarch, goTargetArch) { - return false - } - - // Check file contents for // +build lines. - for _, p := range strings.Split(readfile(file), "\n") { - p = strings.TrimSpace(p) - if p == "" { - continue - } - if !strings.HasPrefix(p, "//") { - break - } - if !strings.Contains(p, "+build") { - continue - } - fields := strings.Fields(p[2:]) - if len(fields) < 1 || fields[0] != "+build" { - continue - } - for _, p := range fields[1:] { - if matchfield(p, goTargetArch) { - goto fieldmatch - } - } - return false - fieldmatch: - } - return true -} - // resolveTypeName returns a qualified type name. func resolveTypeName(name string, typ ast.Expr) (field string, qualified string) { for done := false; !done; { @@ -275,7 +171,7 @@ func main() { flag.Usage() os.Exit(1) } - if *pkg == "" { + if *fullPkg == "" { fmt.Fprintf(os.Stderr, "Error: package required.") os.Exit(1) } @@ -307,7 +203,7 @@ func main() { // Declare our emission closures. emitRegister := func(name string) { - initCalls = append(initCalls, fmt.Sprintf("%sRegister(\"%s.%s\", (*%s)(nil), state.Fns{Save: (*%s).save, Load: (*%s).load})", statePrefix, *pkg, name, name, name, name)) + initCalls = append(initCalls, fmt.Sprintf("%sRegister(\"%s.%s\", (*%s)(nil), state.Fns{Save: (*%s).save, Load: (*%s).load})", statePrefix, *fullPkg, name, name, name, name)) } emitZeroCheck := func(name string) { fmt.Fprintf(outputFile, " if !%sIsZeroValue(x.%s) { m.Failf(\"%s is %%v, expected zero\", x.%s) }\n", statePrefix, name, name, name) @@ -329,9 +225,17 @@ func main() { fmt.Fprintf(outputFile, " m.Save(\"%s\", &x.%s)\n", name, name) } - // Emit the package name. + // Automated warning. fmt.Fprint(outputFile, "// automatically generated by stateify.\n\n") - fmt.Fprintf(outputFile, "package %s\n\n", *pkg) + + // Emit build tags. + if t := tags.Aggregate(flag.Args()); len(t) > 0 { + fmt.Fprintf(outputFile, "%s\n\n", strings.Join(t.Lines(), "\n")) + } + + // Emit the package name. + _, pkg := filepath.Split(*fullPkg) + fmt.Fprintf(outputFile, "package %s\n\n", pkg) // Emit the imports lazily. var once sync.Once @@ -364,10 +268,6 @@ func main() { os.Exit(1) } - if !canBuild(filename, *arch) { - continue - } - files = append(files, f) } diff --git a/tools/images/BUILD b/tools/images/BUILD index f1699b184..fe11f08a3 100644 --- a/tools/images/BUILD +++ b/tools/images/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "cc_binary") +load("//tools:defs.bzl", "cc_binary", "gtest") load("//tools/images:defs.bzl", "vm_image", "vm_test") package( @@ -32,8 +32,8 @@ cc_binary( srcs = ["test.cc"], linkstatic = 1, deps = [ + gtest, "//test/util:test_main", - "@com_google_googletest//:gtest", ], ) diff --git a/tools/images/defs.bzl b/tools/images/defs.bzl index 32235813a..de365d153 100644 --- a/tools/images/defs.bzl +++ b/tools/images/defs.bzl @@ -57,7 +57,10 @@ def _vm_image_impl(ctx): command = argv, input_manifests = runfiles_manifests, ) - return [DefaultInfo(files = depset([ctx.outputs.out]))] + return [DefaultInfo( + files = depset([ctx.outputs.out]), + runfiles = ctx.runfiles(files = [ctx.outputs.out]), + )] _vm_image = rule( attrs = { diff --git a/tools/images/ubuntu1604/10_core.sh b/tools/images/ubuntu1604/10_core.sh index 46dda6bb1..cd518d6ac 100755 --- a/tools/images/ubuntu1604/10_core.sh +++ b/tools/images/ubuntu1604/10_core.sh @@ -17,7 +17,20 @@ set -xeo pipefail # Install all essential build tools. -apt-get update && apt-get -y install make git-core build-essential linux-headers-$(uname -r) pkg-config +while true; do + if (apt-get update && apt-get install -y \ + make \ + git-core \ + build-essential \ + linux-headers-$(uname -r) \ + pkg-config); then + break + fi + result=$? + if [[ $result -ne 100 ]]; then + exit $result + fi +done # Install a recent go toolchain. if ! [[ -d /usr/local/go ]]; then diff --git a/tools/images/ubuntu1604/20_bazel.sh b/tools/images/ubuntu1604/20_bazel.sh index b33e1656c..bb7afa676 100755 --- a/tools/images/ubuntu1604/20_bazel.sh +++ b/tools/images/ubuntu1604/20_bazel.sh @@ -19,7 +19,17 @@ set -xeo pipefail declare -r BAZEL_VERSION=2.0.0 # Install bazel dependencies. -apt-get update && apt-get install -y openjdk-8-jdk-headless unzip +while true; do + if (apt-get update && apt-get install -y \ + openjdk-8-jdk-headless \ + unzip); then + break + fi + result=$? + if [[ $result -ne 100 ]]; then + exit $result + fi +done # Use the release installer. curl -L -o bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh diff --git a/tools/images/ubuntu1604/25_docker.sh b/tools/images/ubuntu1604/25_docker.sh index 1d3defcd3..11eea2d72 100755 --- a/tools/images/ubuntu1604/25_docker.sh +++ b/tools/images/ubuntu1604/25_docker.sh @@ -15,12 +15,20 @@ # limitations under the License. # Add dependencies. -apt-get update && apt-get -y install \ - apt-transport-https \ - ca-certificates \ - curl \ - gnupg-agent \ - software-properties-common +while true; do + if (apt-get update && apt-get install -y \ + apt-transport-https \ + ca-certificates \ + curl \ + gnupg-agent \ + software-properties-common); then + break + fi + result=$? + if [[ $result -ne 100 ]]; then + exit $result + fi +done # Install the key. curl -fsSL https://download.docker.com/linux/ubuntu/gpg | apt-key add - @@ -32,4 +40,15 @@ add-apt-repository \ stable" # Install docker. -apt-get update && apt-get install -y docker-ce docker-ce-cli containerd.io +while true; do + if (apt-get update && apt-get install -y \ + docker-ce \ + docker-ce-cli \ + containerd.io); then + break + fi + result=$? + if [[ $result -ne 100 ]]; then + exit $result + fi +done diff --git a/tools/images/ubuntu1604/30_containerd.sh b/tools/images/ubuntu1604/30_containerd.sh index a7472bd1c..fb3699c12 100755 --- a/tools/images/ubuntu1604/30_containerd.sh +++ b/tools/images/ubuntu1604/30_containerd.sh @@ -34,7 +34,17 @@ install_helper() { } # Install dependencies for the crictl tests. -apt-get install -y btrfs-tools libseccomp-dev +while true; do + if (apt-get update && apt-get install -y \ + btrfs-tools \ + libseccomp-dev); then + break + fi + result=$? + if [[ $result -ne 100 ]]; then + exit $result + fi +done # Install containerd & cri-tools. GOPATH=$(mktemp -d --tmpdir gopathXXXXX) diff --git a/tools/images/ubuntu1604/40_kokoro.sh b/tools/images/ubuntu1604/40_kokoro.sh index 5f2dfc858..06a1e6c48 100755 --- a/tools/images/ubuntu1604/40_kokoro.sh +++ b/tools/images/ubuntu1604/40_kokoro.sh @@ -23,7 +23,22 @@ declare -r ssh_public_keys=( ) # Install dependencies. -apt-get update && apt-get install -y rsync coreutils python-psutil qemu-kvm python-pip python3-pip zip +while true; do + if (apt-get update && apt-get install -y \ + rsync \ + coreutils \ + python-psutil \ + qemu-kvm \ + python-pip \ + python3-pip \ + zip); then + break + fi + result=$? + if [[ $result -ne 100 ]]; then + exit $result + fi +done # junitparser is used to merge junit xml files. pip install junitparser diff --git a/tools/installers/BUILD b/tools/installers/BUILD index 01bc4de8c..d78a265ca 100644 --- a/tools/installers/BUILD +++ b/tools/installers/BUILD @@ -5,10 +5,15 @@ package( licenses = ["notice"], ) +filegroup( + name = "runsc", + srcs = ["//runsc"], +) + sh_binary( name = "head", srcs = ["head.sh"], - data = ["//runsc"], + data = [":runsc"], ) sh_binary( diff --git a/tools/installers/head.sh b/tools/installers/head.sh index 4435cb27a..9de8f138c 100755 --- a/tools/installers/head.sh +++ b/tools/installers/head.sh @@ -15,7 +15,7 @@ # limitations under the License. # Install our runtime. -third_party/gvisor/runsc/runsc install +$(dirname $0)/runsc install # Restart docker. service docker restart || true diff --git a/tools/installers/master.sh b/tools/installers/master.sh index 7b1956454..2c6001c6c 100755 --- a/tools/installers/master.sh +++ b/tools/installers/master.sh @@ -15,6 +15,20 @@ # limitations under the License. # Install runsc from the master branch. +set -e + curl -fsSL https://gvisor.dev/archive.key | sudo apt-key add - add-apt-repository "deb https://storage.googleapis.com/gvisor/releases release main" -apt-get update && apt-get install -y runsc + +while true; do + if (apt-get update && apt-get install -y runsc); then + break + fi + result=$? + if [[ $result -ne 100 ]]; then + exit $result + fi +done + +runsc install +service docker restart diff --git a/tools/tag_release.sh b/tools/tag_release.sh index f33b902d6..4dbfe420a 100755 --- a/tools/tag_release.sh +++ b/tools/tag_release.sh @@ -21,13 +21,19 @@ set -xeu # Check arguments. -if [ "$#" -ne 2 ]; then - echo "usage: $0 <commit|revid> <release.rc>" +if [ "$#" -ne 3 ]; then + echo "usage: $0 <commit|revid> <release.rc> <message-file>" exit 1 fi declare -r target_commit="$1" declare -r release="$2" +declare -r message_file="$3" + +if ! [[ -r "${message_file}" ]]; then + echo "error: message file '${message_file}' is not readable." + exit 1 +fi closest_commit() { while read line; do @@ -64,6 +70,6 @@ fi # Tag the given commit (annotated, to record the committer). declare -r tag="release-${release}" -(git tag -m "Release ${release}" -a "${tag}" "${commit}" && \ +(git tag -F "${message_file}" -a "${tag}" "${commit}" && \ git push origin tag "${tag}") || \ (git tag -d "${tag}" && false) diff --git a/tools/tags/BUILD b/tools/tags/BUILD new file mode 100644 index 000000000..1c02e2c89 --- /dev/null +++ b/tools/tags/BUILD @@ -0,0 +1,11 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "tags", + srcs = ["tags.go"], + marshal = False, + stateify = False, + visibility = ["//tools:__subpackages__"], +) diff --git a/tools/tags/tags.go b/tools/tags/tags.go new file mode 100644 index 000000000..f35904e0a --- /dev/null +++ b/tools/tags/tags.go @@ -0,0 +1,89 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package tags is a utility for parsing build tags. +package tags + +import ( + "fmt" + "io/ioutil" + "strings" +) + +// OrSet is a set of tags on a single line. +// +// Note that tags may include ",", and we don't distinguish this case in the +// logic below. Ideally, this constraints can be split into separate top-level +// build tags in order to resolve any issues. +type OrSet []string + +// Line returns the line for this or. +func (or OrSet) Line() string { + return fmt.Sprintf("// +build %s", strings.Join([]string(or), " ")) +} + +// AndSet is the set of all OrSets. +type AndSet []OrSet + +// Lines returns the lines to be printed. +func (and AndSet) Lines() (ls []string) { + for _, or := range and { + ls = append(ls, or.Line()) + } + return +} + +// Join joins this AndSet with another. +func (and AndSet) Join(other AndSet) AndSet { + return append(and, other...) +} + +// Tags returns the unique set of +build tags. +// +// Derived form the runtime's canBuild. +func Tags(file string) (tags AndSet) { + data, err := ioutil.ReadFile(file) + if err != nil { + return nil + } + // Check file contents for // +build lines. + for _, p := range strings.Split(string(data), "\n") { + p = strings.TrimSpace(p) + if p == "" { + continue + } + if !strings.HasPrefix(p, "//") { + break + } + if !strings.Contains(p, "+build") { + continue + } + fields := strings.Fields(p[2:]) + if len(fields) < 1 || fields[0] != "+build" { + continue + } + tags = append(tags, OrSet(fields[1:])) + } + return tags +} + +// Aggregate aggregates all tags from a set of files. +// +// Note that these may be in conflict, in which case the build will fail. +func Aggregate(files []string) (tags AndSet) { + for _, file := range files { + tags = tags.Join(Tags(file)) + } + return tags +} |