diff options
Diffstat (limited to 'tools')
-rw-r--r-- | tools/bazeldefs/defs.bzl | 5 | ||||
-rw-r--r-- | tools/bazeldefs/platforms.bzl | 17 | ||||
-rw-r--r-- | tools/defs.bzl | 31 | ||||
-rw-r--r-- | tools/go_marshal/BUILD | 5 | ||||
-rw-r--r-- | tools/go_marshal/gomarshal/generator.go | 119 | ||||
-rw-r--r-- | tools/go_marshal/gomarshal/generator_interfaces.go | 443 | ||||
-rw-r--r-- | tools/go_marshal/gomarshal/generator_tests.go | 67 | ||||
-rw-r--r-- | tools/go_marshal/gomarshal/util.go | 30 | ||||
-rw-r--r-- | tools/go_marshal/marshal/BUILD | 3 | ||||
-rw-r--r-- | tools/go_marshal/marshal/marshal.go | 50 | ||||
-rw-r--r-- | tools/go_marshal/test/BUILD | 14 | ||||
-rw-r--r-- | tools/go_marshal/test/benchmark_test.go | 2 | ||||
-rw-r--r-- | tools/go_marshal/test/escape.go | 114 | ||||
-rw-r--r-- | tools/go_marshal/test/test.go | 10 | ||||
-rw-r--r-- | tools/go_stateify/defs.bzl | 4 | ||||
-rw-r--r-- | tools/go_stateify/main.go | 10 | ||||
-rwxr-xr-x | tools/installers/master.sh | 17 |
17 files changed, 785 insertions, 156 deletions
diff --git a/tools/bazeldefs/defs.bzl b/tools/bazeldefs/defs.bzl index 08c29ff1c..905b16d41 100644 --- a/tools/bazeldefs/defs.bzl +++ b/tools/bazeldefs/defs.bzl @@ -8,7 +8,6 @@ load("@rules_pkg//:pkg.bzl", _pkg_deb = "pkg_deb", _pkg_tar = "pkg_tar") load("@io_bazel_rules_docker//go:image.bzl", _go_image = "go_image") load("@io_bazel_rules_docker//container:container.bzl", _container_image = "container_image") load("@pydeps//:requirements.bzl", _py_requirement = "requirement") -load("//tools/bazeldefs:tags.bzl", _go_suffixes = "go_suffixes") container_image = _container_image cc_binary = _cc_binary @@ -19,8 +18,8 @@ cc_test = _cc_test cc_toolchain = "@bazel_tools//tools/cpp:current_cc_toolchain" go_image = _go_image go_embed_data = _go_embed_data -go_suffixes = _go_suffixes gtest = "@com_google_googletest//:gtest" +gbenchmark = "@com_google_benchmark//:benchmark" loopback = "//tools/bazeldefs:loopback" proto_library = native.proto_library pkg_deb = _pkg_deb @@ -72,7 +71,7 @@ def go_test(name, **kwargs): **kwargs ) -def py_requirement(name, direct = False): +def py_requirement(name, direct = True): return _py_requirement(name) def select_arch(amd64 = "amd64", arm64 = "arm64", default = None, **kwargs): diff --git a/tools/bazeldefs/platforms.bzl b/tools/bazeldefs/platforms.bzl new file mode 100644 index 000000000..92b0b5fc0 --- /dev/null +++ b/tools/bazeldefs/platforms.bzl @@ -0,0 +1,17 @@ +"""List of platforms.""" + +# Platform to associated tags. +platforms = { + "ptrace": [ + # TODO(b/120560048): Make the tests run without this tag. + "no-sandbox", + ], + "kvm": [ + "manual", + "local", + # TODO(b/120560048): Make the tests run without this tag. + "no-sandbox", + ], +} + +default_platform = "ptrace" diff --git a/tools/defs.bzl b/tools/defs.bzl index d4690cc1a..15a310403 100644 --- a/tools/defs.bzl +++ b/tools/defs.bzl @@ -7,7 +7,9 @@ change for Google-internal and bazel-compatible rules. load("//tools/go_stateify:defs.bzl", "go_stateify") load("//tools/go_marshal:defs.bzl", "go_marshal", "marshal_deps", "marshal_test_deps") -load("//tools/bazeldefs:defs.bzl", "go_suffixes", _cc_binary = "cc_binary", _cc_flags_supplier = "cc_flags_supplier", _cc_library = "cc_library", _cc_proto_library = "cc_proto_library", _cc_test = "cc_test", _cc_toolchain = "cc_toolchain", _container_image = "container_image", _default_installer = "default_installer", _default_net_util = "default_net_util", _go_binary = "go_binary", _go_embed_data = "go_embed_data", _go_image = "go_image", _go_library = "go_library", _go_proto_library = "go_proto_library", _go_test = "go_test", _go_tool_library = "go_tool_library", _gtest = "gtest", _loopback = "loopback", _pkg_deb = "pkg_deb", _pkg_tar = "pkg_tar", _proto_library = "proto_library", _py_binary = "py_binary", _py_library = "py_library", _py_requirement = "py_requirement", _py_test = "py_test", _select_arch = "select_arch", _select_system = "select_system") +load("//tools/bazeldefs:defs.bzl", _cc_binary = "cc_binary", _cc_flags_supplier = "cc_flags_supplier", _cc_library = "cc_library", _cc_proto_library = "cc_proto_library", _cc_test = "cc_test", _cc_toolchain = "cc_toolchain", _container_image = "container_image", _default_installer = "default_installer", _default_net_util = "default_net_util", _gbenchmark = "gbenchmark", _go_binary = "go_binary", _go_embed_data = "go_embed_data", _go_image = "go_image", _go_library = "go_library", _go_proto_library = "go_proto_library", _go_test = "go_test", _go_tool_library = "go_tool_library", _gtest = "gtest", _loopback = "loopback", _pkg_deb = "pkg_deb", _pkg_tar = "pkg_tar", _proto_library = "proto_library", _py_binary = "py_binary", _py_library = "py_library", _py_requirement = "py_requirement", _py_test = "py_test", _select_arch = "select_arch", _select_system = "select_system") +load("//tools/bazeldefs:platforms.bzl", _default_platform = "default_platform", _platforms = "platforms") +load("//tools/bazeldefs:tags.bzl", "go_suffixes") # Delegate directly. cc_binary = _cc_binary @@ -21,6 +23,7 @@ go_image = _go_image go_test = _go_test go_tool_library = _go_tool_library gtest = _gtest +gbenchmark = _gbenchmark pkg_deb = _pkg_deb pkg_tar = _pkg_tar py_library = _py_library @@ -32,6 +35,8 @@ select_system = _select_system loopback = _loopback default_installer = _default_installer default_net_util = _default_net_util +platforms = _platforms +default_platform = _default_platform def go_binary(name, **kwargs): """Wraps the standard go_binary. @@ -83,7 +88,7 @@ def go_imports(name, src, out): cmd = ("$(location @org_golang_x_tools//cmd/goimports:goimports) $(SRCS) > $@"), ) -def go_library(name, srcs, deps = [], imports = [], stateify = True, marshal = False, **kwargs): +def go_library(name, srcs, deps = [], imports = [], stateify = True, marshal = False, marshal_debug = False, **kwargs): """Wraps the standard go_library and does stateification and marshalling. The recommended way is to use this rule with mostly identical configuration as the native @@ -106,21 +111,24 @@ def go_library(name, srcs, deps = [], imports = [], stateify = True, marshal = F imports: imports required for stateify. stateify: whether statify is enabled (default: true). marshal: whether marshal is enabled (default: false). + marshal_debug: whether the gomarshal tools emits debugging output (default: false). **kwargs: standard go_library arguments. """ all_srcs = srcs all_deps = deps + dirname, _, _ = native.package_name().rpartition("/") + full_pkg = dirname + "/" + name if stateify: # Only do stateification for non-state packages without manual autogen. # First, we need to segregate the input files via the special suffixes, # and calculate the final output set. state_sets = calculate_sets(srcs) - for (suffix, srcs) in state_sets.items(): + for (suffix, src_subset) in state_sets.items(): go_stateify( name = name + suffix + "_state_autogen_with_imports", - srcs = srcs, + srcs = src_subset, imports = imports, - package = name, + package = full_pkg, out = name + suffix + "_state_autogen_with_imports.go", ) go_imports( @@ -138,11 +146,14 @@ def go_library(name, srcs, deps = [], imports = [], stateify = True, marshal = F if marshal: # See above. marshal_sets = calculate_sets(srcs) - for (suffix, srcs) in marshal_sets.items(): + for (suffix, src_subset) in marshal_sets.items(): go_marshal( name = name + suffix + "_abi_autogen", - srcs = srcs, - debug = False, + srcs = src_subset, + debug = select({ + "//tools/go_marshal:marshal_config_verbose": True, + "//conditions:default": marshal_debug, + }), imports = imports, package = name, ) @@ -170,11 +181,11 @@ def go_library(name, srcs, deps = [], imports = [], stateify = True, marshal = F # See above. marshal_sets = calculate_sets(srcs) - for (suffix, srcs) in marshal_sets.items(): + for (suffix, _) in marshal_sets.items(): _go_test( name = name + suffix + "_abi_autogen_test", srcs = [name + suffix + "_abi_autogen_test.go"], - library = ":" + name + suffix, + library = ":" + name, deps = marshal_test_deps, **kwargs ) diff --git a/tools/go_marshal/BUILD b/tools/go_marshal/BUILD index 80d9c0504..be49cf9c8 100644 --- a/tools/go_marshal/BUILD +++ b/tools/go_marshal/BUILD @@ -12,3 +12,8 @@ go_binary( "//tools/go_marshal/gomarshal", ], ) + +config_setting( + name = "marshal_config_verbose", + values = {"define": "gomarshal=verbose"}, +) diff --git a/tools/go_marshal/gomarshal/generator.go b/tools/go_marshal/gomarshal/generator.go index 0b3f600fe..d365a1f3c 100644 --- a/tools/go_marshal/gomarshal/generator.go +++ b/tools/go_marshal/gomarshal/generator.go @@ -34,9 +34,9 @@ const ( usermemImport = "gvisor.dev/gvisor/pkg/usermem" ) -// List of identifiers we use in generated code, that may conflict a -// similarly-named source identifier. Avoid problems by refusing the generate -// code when we see these. +// List of identifiers we use in generated code that may conflict with a +// similarly-named source identifier. Abort gracefully when we see these to +// avoid potentially confusing compilation failures in generated code. // // This only applies to import aliases at the moment. All other identifiers // are qualified by a receiver argument, since they're struct fields. @@ -44,10 +44,21 @@ const ( // All recievers are single letters, so we don't allow import aliases to be a // single letter. var badIdents = []string{ - "src", "srcs", "dst", "dsts", "blk", "buf", "err", + "addr", "blk", "buf", "dst", "dsts", "err", "hdr", "idx", "inner", "len", + "ptr", "src", "srcs", "task", "val", // All single-letter identifiers. } +// Constructed fromt badIdents in init(). +var badIdentsMap map[string]struct{} + +func init() { + badIdentsMap = make(map[string]struct{}) + for _, ident := range badIdents { + badIdentsMap[ident] = struct{}{} + } +} + // Generator drives code generation for a single invocation of the go_marshal // utility. // @@ -88,16 +99,20 @@ func NewGenerator(srcs []string, out, outTest, pkg string, imports []string) (*G } for _, i := range imports { // All imports on the extra imports list are unconditionally marked as - // used, so they're always added to the generated code. + // used, so that they're always added to the generated code. g.imports.add(i).markUsed() } - g.imports.add(marshalImport).markUsed() - // The follow imports may or may not be used by the generated - // code, depending what's required for the target types. Don't - // mark these imports as used by default. - g.imports.add(usermemImport) - g.imports.add(safecopyImport) + + // The following imports may or may not be used by the generated code, + // depending on what's required for the target types. Don't mark these as + // used by default. + g.imports.add("io") + g.imports.add("reflect") + g.imports.add("runtime") g.imports.add("unsafe") + g.imports.add(marshalImport) + g.imports.add(safecopyImport) + g.imports.add(usermemImport) return &g, nil } @@ -111,7 +126,7 @@ func (g *Generator) writeHeader() error { // Emit build tags. if t := tags.Aggregate(g.inputs); len(t) > 0 { b.emit(strings.Join(t.Lines(), "\n")) - b.emit("\n") + b.emit("\n\n") } // Package header. @@ -179,9 +194,9 @@ func (g *Generator) parse() ([]*ast.File, []*token.FileSet, error) { return files, fsets, nil } -// collectMarshallabeTypes walks the parsed AST and collects a list of type +// collectMarshallableTypes walks the parsed AST and collects a list of type // declarations for which we need to generate the Marshallable interface. -func (g *Generator) collectMarshallabeTypes(a *ast.File, f *token.FileSet) []*ast.TypeSpec { +func (g *Generator) collectMarshallableTypes(a *ast.File, f *token.FileSet) []*ast.TypeSpec { var types []*ast.TypeSpec for _, decl := range a.Decls { gdecl, ok := decl.(*ast.GenDecl) @@ -208,14 +223,22 @@ func (g *Generator) collectMarshallabeTypes(a *ast.File, f *token.FileSet) []*as continue } for _, spec := range gdecl.Specs { - // We already confirmed we're in a type declaration earlier. + // We already confirmed we're in a type declaration earlier, so this + // cast will succeed. t := spec.(*ast.TypeSpec) - if _, ok := t.Type.(*ast.StructType); ok { - debugfAt(f.Position(t.Pos()), "Collected marshallable type %s.\n", t.Name.Name) + switch t.Type.(type) { + case *ast.StructType: + debugfAt(f.Position(t.Pos()), "Collected marshallable struct %s.\n", t.Name.Name) + types = append(types, t) + continue + case *ast.Ident: // Newtype on primitive. + debugfAt(f.Position(t.Pos()), "Collected marshallable newtype on primitive %s.\n", t.Name.Name) types = append(types, t) continue } - debugf("Skipping declaration %v since it's not a struct declaration.\n", gdecl) + // A user specifically requested marshalling on this type, but we + // don't support it. + abortAt(f.Position(t.Pos()), fmt.Sprintf("Marshalling codegen was requested on type '%s', but go-marshal doesn't support this kind of declaration.\n", t.Name)) } } return types @@ -229,11 +252,6 @@ func (g *Generator) collectMarshallabeTypes(a *ast.File, f *token.FileSet) []*as // identifiers in the generated code don't conflict with any imported package // names. func (g *Generator) collectImports(a *ast.File, f *token.FileSet) map[string]importStmt { - badImportNames := make(map[string]bool) - for _, i := range badIdents { - badImportNames[i] = true - } - is := make(map[string]importStmt) for _, decl := range a.Decls { gdecl, ok := decl.(*ast.GenDecl) @@ -250,7 +268,7 @@ func (g *Generator) collectImports(a *ast.File, f *token.FileSet) map[string]imp if len(i.name) == 1 { abortAt(f.Position(spec.Pos()), fmt.Sprintf("Import has a single character local name '%s'; this may conflict with code generated by go_marshal, use a multi-character import alias", i.name)) } - if badImportNames[i.name] { + if _, ok := badIdentsMap[i.name]; ok { abortAt(f.Position(spec.Pos()), fmt.Sprintf("Import name '%s' is likely to conflict with code generated by go_marshal, use a different import alias", i.name)) } } @@ -260,12 +278,20 @@ func (g *Generator) collectImports(a *ast.File, f *token.FileSet) map[string]imp } func (g *Generator) generateOne(t *ast.TypeSpec, fset *token.FileSet) *interfaceGenerator { - // We're guaranteed to have only struct type specs by now. See - // Generator.collectMarshallabeTypes. i := newInterfaceGenerator(t, fset) - i.validate() - i.emitMarshallable() - return i + switch ty := t.Type.(type) { + case *ast.StructType: + i.validateStruct() + i.emitMarshallableForStruct() + return i + case *ast.Ident: + i.validatePrimitiveNewtype(ty) + i.emitMarshallableForPrimitiveNewtype() + return i + default: + // This should've been filtered out by collectMarshallabeTypes. + panic(fmt.Sprintf("Unexpected type %+v", ty)) + } } // generateOneTestSuite generates a test suite for the automatically generated @@ -311,7 +337,7 @@ func (g *Generator) Run() error { for i, a := range asts { // Collect type declarations marked for code generation and generate // Marshallable interfaces. - for _, t := range g.collectMarshallabeTypes(a, fsets[i]) { + for _, t := range g.collectMarshallableTypes(a, fsets[i]) { impl := g.generateOne(t, fsets[i]) // Collect Marshallable types referenced by the generated code. for ref, _ := range impl.ms { @@ -329,17 +355,6 @@ func (g *Generator) Run() error { } } - // Tool was invoked with input files with no data structures marked for code - // generation. This is probably not what the user intended. - if len(impls) == 0 { - var buf bytes.Buffer - fmt.Fprintf(&buf, "go_marshal invoked on these files, but they don't contain any types requiring code generation. Perhaps mark some with \"// +marshal\"?:\n") - for _, i := range g.inputs { - fmt.Fprintf(&buf, " %s\n", i) - } - abort(buf.String()) - } - // Write output file header. These include things like package name and // import statements. if err := g.writeHeader(); err != nil { @@ -371,6 +386,7 @@ func (g *Generator) writeTests(ts []*testGenerator) error { return err } + // Collect and write test import statements. imports := newImportTable() for _, t := range ts { imports.merge(t.imports) @@ -380,6 +396,27 @@ func (g *Generator) writeTests(ts []*testGenerator) error { return err } + // Write test functions. + + // If we didn't generate any Marshallable implementations, we can't just + // emit an empty test file, since that causes the build to fail with "no + // tests/benchmarks/examples found". Unfortunately we can't signal bazel to + // omit the entire package since the outputs are already defined before + // go-marshal is called. If we'd otherwise emit an empty test suite, emit an + // empty example instead. + if len(ts) == 0 { + b.reset() + b.emit("func ExampleEmptyTestSuite() {\n") + b.inIndent(func() { + b.emit("// This example is intentionally empty to ensure this file contains at least\n") + b.emit("// one testable entity. go-marshal is forced to emit a test file if a package\n") + b.emit("// is marked marshallable, but emitting a test file with no entities results\n") + b.emit("// in a build failure.\n") + }) + b.emit("}\n") + return b.write(g.outputTest) + } + for _, t := range ts { if err := t.write(g.outputTest); err != nil { return err diff --git a/tools/go_marshal/gomarshal/generator_interfaces.go b/tools/go_marshal/gomarshal/generator_interfaces.go index a712c14dc..ea1af998e 100644 --- a/tools/go_marshal/gomarshal/generator_interfaces.go +++ b/tools/go_marshal/gomarshal/generator_interfaces.go @@ -55,9 +55,6 @@ func (g *interfaceGenerator) typeName() string { // newinterfaceGenerator creates a new interface generator. func newInterfaceGenerator(t *ast.TypeSpec, fset *token.FileSet) *interfaceGenerator { - if _, ok := t.Type.(*ast.StructType); !ok { - panic(fmt.Sprintf("Attempting to generate code for a not struct type %v", t)) - } g := &interfaceGenerator{ t: t, r: receiverName(t), @@ -103,9 +100,31 @@ func (g *interfaceGenerator) abortAt(p token.Pos, msg string) { abortAt(g.f.Position(p), msg) } -// validate ensures the type we're working with can be marshalled. These checks -// are done ahead of time and in one place so we can make assumptions later. -func (g *interfaceGenerator) validate() { +func (g *interfaceGenerator) validatePrimitiveNewtype(t *ast.Ident) { + switch t.Name { + case "int8", "uint8", "byte", "int16", "uint16", "int32", "uint32", "int64", "uint64": + // These are the only primitive types we're allow. Below, we provide + // suggestions for some disallowed types and reject them, then attempt + // to marshal any remaining types by invoking the marshal.Marshallable + // interface on them. If these types don't actually implement + // marshal.Marshallable, compilation of the generated code will fail + // with an appropriate error message. + return + case "int": + g.abortAt(t.Pos(), "Type 'int' has ambiguous width, use int32 or int64") + case "uint": + g.abortAt(t.Pos(), "Type 'uint' has ambiguous width, use uint32 or uint64") + case "string": + g.abortAt(t.Pos(), "Type 'string' is dynamically-sized and cannot be marshalled, use a fixed size byte array '[...]byte' instead") + default: + debugfAt(g.f.Position(t.Pos()), fmt.Sprintf("Found derived type '%s', will attempt dispatch via marshal.Marshallable.\n", t.Name)) + } +} + +// validateStruct ensures the type we're working with can be marshalled. These +// checks are done ahead of time and in one place so we can make assumptions +// later. +func (g *interfaceGenerator) validateStruct() { g.forEachField(func(f *ast.Field) { if len(f.Names) == 0 { g.abortAt(f.Pos(), "Cannot marshal structs with embedded fields, give the field a name; use '_' for anonymous fields such as padding fields") @@ -115,25 +134,7 @@ func (g *interfaceGenerator) validate() { g.forEachField(func(f *ast.Field) { fieldDispatcher{ primitive: func(_, t *ast.Ident) { - switch t.Name { - case "int8", "uint8", "byte", "int16", "uint16", "int32", "uint32", "int64", "uint64": - // These are the only primitive types we're allow. Below, we - // provide suggestions for some disallowed types and reject - // them, then attempt to marshal any remaining types by - // invoking the marshal.Marshallable interface on them. If - // these types don't actually implement - // marshal.Marshallable, compilation of the generated code - // will fail with an appropriate error message. - return - case "int": - g.abortAt(f.Pos(), "Type 'int' has ambiguous width, use int32 or int64") - case "uint": - g.abortAt(f.Pos(), "Type 'uint' has ambiguous width, use uint32 or uint64") - case "string": - g.abortAt(f.Pos(), "Type 'string' is dynamically-sized and cannot be marshalled, use a fixed size byte array '[...]byte' instead") - default: - debugfAt(g.f.Position(f.Pos()), fmt.Sprintf("Found derived type '%s', will attempt dispatch via marshal.Marshallable.\n", t.Name)) - } + g.validatePrimitiveNewtype(t) }, selector: func(_, _, _ *ast.Ident) { // No validation to perform on selector fields. However this @@ -190,7 +191,8 @@ func (g *interfaceGenerator) shiftDynamic(bufVar, name string) { g.emit("%s = %s[%s.SizeBytes():]\n", bufVar, bufVar, name) } -func (g *interfaceGenerator) marshalScalar(accessor, typ string, bufVar string) { +// marshalStructFieldScalar writes a single scalar field from a struct to a byte slice. +func (g *interfaceGenerator) marshalStructFieldScalar(accessor, typ, bufVar string) { switch typ { case "int8", "uint8", "byte": g.emit("%s[0] = byte(%s)\n", bufVar, accessor) @@ -213,43 +215,27 @@ func (g *interfaceGenerator) marshalScalar(accessor, typ string, bufVar string) } } -func (g *interfaceGenerator) unmarshalScalar(accessor, typ string, bufVar string) { +// unmarshalStructFieldScalar reads a single scalar field from a struct, from a +// byte slice. +func (g *interfaceGenerator) unmarshalStructFieldScalar(accessor, typ, bufVar string) { switch typ { - case "int8": - g.emit("%s = int8(%s[0])\n", accessor, bufVar) - g.shift(bufVar, 1) - case "uint8": - g.emit("%s = uint8(%s[0])\n", accessor, bufVar) - g.shift(bufVar, 1) case "byte": g.emit("%s = %s[0]\n", accessor, bufVar) g.shift(bufVar, 1) - - case "int16": - g.recordUsedImport("usermem") - g.emit("%s = int16(usermem.ByteOrder.Uint16(%s[:2]))\n", accessor, bufVar) - g.shift(bufVar, 2) - case "uint16": + case "int8", "uint8": + g.emit("%s = %s(%s[0])\n", accessor, typ, bufVar) + g.shift(bufVar, 1) + case "int16", "uint16": g.recordUsedImport("usermem") - g.emit("%s = usermem.ByteOrder.Uint16(%s[:2])\n", accessor, bufVar) + g.emit("%s = %s(usermem.ByteOrder.Uint16(%s[:2]))\n", accessor, typ, bufVar) g.shift(bufVar, 2) - - case "int32": - g.recordUsedImport("usermem") - g.emit("%s = int32(usermem.ByteOrder.Uint32(%s[:4]))\n", accessor, bufVar) - g.shift(bufVar, 4) - case "uint32": + case "int32", "uint32": g.recordUsedImport("usermem") - g.emit("%s = usermem.ByteOrder.Uint32(%s[:4])\n", accessor, bufVar) + g.emit("%s = %s(usermem.ByteOrder.Uint32(%s[:4]))\n", accessor, typ, bufVar) g.shift(bufVar, 4) - - case "int64": - g.recordUsedImport("usermem") - g.emit("%s = int64(usermem.ByteOrder.Uint64(%s[:8]))\n", accessor, bufVar) - g.shift(bufVar, 8) - case "uint64": + case "int64", "uint64": g.recordUsedImport("usermem") - g.emit("%s = usermem.ByteOrder.Uint64(%s[:8])\n", accessor, bufVar) + g.emit("%s = %s(usermem.ByteOrder.Uint64(%s[:8]))\n", accessor, typ, bufVar) g.shift(bufVar, 8) default: g.emit("%s.UnmarshalBytes(%s[:%s.SizeBytes()])\n", accessor, bufVar, accessor) @@ -258,6 +244,49 @@ func (g *interfaceGenerator) unmarshalScalar(accessor, typ string, bufVar string } } +// marshalPrimitiveScalar writes a single primitive variable to a byte slice. +func (g *interfaceGenerator) marshalPrimitiveScalar(accessor, typ, bufVar string) { + switch typ { + case "int8", "uint8", "byte": + g.emit("%s[0] = byte(*%s)\n", bufVar, accessor) + case "int16", "uint16": + g.recordUsedImport("usermem") + g.emit("usermem.ByteOrder.PutUint16(%s[:2], uint16(*%s))\n", bufVar, accessor) + case "int32", "uint32": + g.recordUsedImport("usermem") + g.emit("usermem.ByteOrder.PutUint32(%s[:4], uint32(*%s))\n", bufVar, accessor) + case "int64", "uint64": + g.recordUsedImport("usermem") + g.emit("usermem.ByteOrder.PutUint64(%s[:8], uint64(*%s))\n", bufVar, accessor) + default: + g.emit("inner := (*%s)(%s)\n", typ, accessor) + g.emit("inner.MarshalBytes(%s[:%s.SizeBytes()])\n", bufVar, accessor) + } +} + +// unmarshalPrimitiveScalar read a single primitive variable from a byte slice. +func (g *interfaceGenerator) unmarshalPrimitiveScalar(accessor, typ, bufVar, typeCast string) { + switch typ { + case "byte": + g.emit("*%s = %s(%s[0])\n", accessor, typeCast, bufVar) + case "int8", "uint8": + g.emit("*%s = %s(%s(%s[0]))\n", accessor, typeCast, typ, bufVar) + case "int16", "uint16": + g.recordUsedImport("usermem") + g.emit("*%s = %s(%s(usermem.ByteOrder.Uint16(%s[:2])))\n", accessor, typeCast, typ, bufVar) + case "int32", "uint32": + g.recordUsedImport("usermem") + g.emit("*%s = %s(%s(usermem.ByteOrder.Uint32(%s[:4])))\n", accessor, typeCast, typ, bufVar) + + case "int64", "uint64": + g.recordUsedImport("usermem") + g.emit("*%s = %s(%s(usermem.ByteOrder.Uint64(%s[:8])))\n", accessor, typeCast, typ, bufVar) + default: + g.emit("inner := (*%s)(%s)\n", typ, accessor) + g.emit("inner.UnmarshalBytes(%s[:%s.SizeBytes()])\n", bufVar, accessor) + } +} + // areFieldsPackedExpression returns a go expression checking whether g.t's fields are // packed. Returns "", false if g.t has no fields that may be potentially // packed, otherwise returns <clause>, true, where <clause> is an expression @@ -274,7 +303,7 @@ func (g *interfaceGenerator) areFieldsPackedExpression() (string, bool) { return strings.Join(cs, " && "), true } -func (g *interfaceGenerator) emitMarshallable() { +func (g *interfaceGenerator) emitMarshallableForStruct() { // Is g.t a packed struct without consideing field types? thisPacked := true g.forEachField(func(f *ast.Field) { @@ -301,7 +330,7 @@ func (g *interfaceGenerator) emitMarshallable() { primitiveSize += size } else { g.recordUsedMarshallable(t.Name) - dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("%s.SizeBytes()", g.fieldAccessor(n))) + dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", t.Name)) } }, selector: func(n, tX, tSel *ast.Ident) { @@ -357,10 +386,10 @@ func (g *interfaceGenerator) emitMarshallable() { } return } - g.marshalScalar(g.fieldAccessor(n), t.Name, "dst") + g.marshalStructFieldScalar(g.fieldAccessor(n), t.Name, "dst") }, selector: func(n, tX, tSel *ast.Ident) { - g.marshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "dst") + g.marshalStructFieldScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "dst") }, array: func(n, t *ast.Ident, size int) { if n.Name == "_" { @@ -377,9 +406,9 @@ func (g *interfaceGenerator) emitMarshallable() { return } - g.emit("for i := 0; i < %d; i++ {\n", size) + g.emit("for idx := 0; idx < %d; idx++ {\n", size) g.inIndent(func() { - g.marshalScalar(fmt.Sprintf("%s[i]", g.fieldAccessor(n)), t.Name, "dst") + g.marshalStructFieldScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "dst") }) g.emit("}\n") }, @@ -406,10 +435,10 @@ func (g *interfaceGenerator) emitMarshallable() { } return } - g.unmarshalScalar(g.fieldAccessor(n), t.Name, "src") + g.unmarshalStructFieldScalar(g.fieldAccessor(n), t.Name, "src") }, selector: func(n, tX, tSel *ast.Ident) { - g.unmarshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "src") + g.unmarshalStructFieldScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "src") }, array: func(n, t *ast.Ident, size int) { if n.Name == "_" { @@ -426,9 +455,9 @@ func (g *interfaceGenerator) emitMarshallable() { return } - g.emit("for i := 0; i < %d; i++ {\n", size) + g.emit("for idx := 0; idx < %d; idx++ {\n", size) g.inIndent(func() { - g.unmarshalScalar(fmt.Sprintf("%s[i]", g.fieldAccessor(n)), t.Name, "src") + g.unmarshalStructFieldScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "src") }) g.emit("}\n") }, @@ -504,4 +533,290 @@ func (g *interfaceGenerator) emitMarshallable() { }) g.emit("}\n\n") + g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n") + g.recordUsedImport("marshal") + g.recordUsedImport("usermem") + g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName()) + g.inIndent(func() { + fallback := func() { + g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName()) + g.emit("buf := task.CopyScratchBuffer(%s.SizeBytes())\n", g.r) + g.emit("%s.MarshalBytes(buf)\n", g.r) + g.emit("_, err := task.CopyOutBytes(addr, buf)\n") + g.emit("return err\n") + } + if thisPacked { + g.recordUsedImport("reflect") + g.recordUsedImport("runtime") + g.recordUsedImport("unsafe") + if cond, ok := g.areFieldsPackedExpression(); ok { + g.emit("if !%s {\n", cond) + g.inIndent(fallback) + g.emit("}\n\n") + } + // Fast serialization. + g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) + g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) + g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") + g.emit("ptr := unsafe.Pointer(%s)\n", g.r) + g.emit("val := uintptr(ptr)\n") + g.emit("val = val^0\n\n") + + g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) + g.emit("var buf []byte\n") + g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") + g.emit("hdr.Data = val\n") + g.emit("hdr.Len = %s.SizeBytes()\n", g.r) + g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) + + g.emit("_, err := task.CopyOutBytes(addr, buf)\n") + g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) + g.emit("// must live until after the CopyOutBytes.\n") + g.emit("runtime.KeepAlive(%s)\n", g.r) + g.emit("return err\n") + } else { + fallback() + } + }) + g.emit("}\n\n") + + g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n") + g.recordUsedImport("marshal") + g.recordUsedImport("usermem") + g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName()) + g.inIndent(func() { + fallback := func() { + g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName()) + g.emit("buf := task.CopyScratchBuffer(%s.SizeBytes())\n", g.r) + g.emit("_, err := task.CopyInBytes(addr, buf)\n") + g.emit("if err != nil {\n") + g.inIndent(func() { + g.emit("return err\n") + }) + g.emit("}\n") + + g.emit("%s.UnmarshalBytes(buf)\n", g.r) + g.emit("return nil\n") + } + if thisPacked { + g.recordUsedImport("reflect") + g.recordUsedImport("runtime") + g.recordUsedImport("unsafe") + if cond, ok := g.areFieldsPackedExpression(); ok { + g.emit("if !%s {\n", cond) + g.inIndent(fallback) + g.emit("}\n\n") + } + // Fast deserialization. + g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) + g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) + g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") + g.emit("ptr := unsafe.Pointer(%s)\n", g.r) + g.emit("val := uintptr(ptr)\n") + g.emit("val = val^0\n\n") + + g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) + g.emit("var buf []byte\n") + g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") + g.emit("hdr.Data = val\n") + g.emit("hdr.Len = %s.SizeBytes()\n", g.r) + g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) + + g.emit("_, err := task.CopyInBytes(addr, buf)\n") + g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) + g.emit("// must live until after the CopyInBytes.\n") + g.emit("runtime.KeepAlive(%s)\n", g.r) + g.emit("return err\n") + } else { + fallback() + } + }) + g.emit("}\n\n") + + g.emit("// WriteTo implements io.WriterTo.WriteTo.\n") + g.recordUsedImport("io") + g.emit("func (%s *%s) WriteTo(w io.Writer) (int64, error) {\n", g.r, g.typeName()) + g.inIndent(func() { + fallback := func() { + g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName()) + g.emit("buf := make([]byte, %s.SizeBytes())\n", g.r) + g.emit("%s.MarshalBytes(buf)\n", g.r) + g.emit("n, err := w.Write(buf)\n") + g.emit("return int64(n), err\n") + } + if thisPacked { + g.recordUsedImport("reflect") + g.recordUsedImport("runtime") + g.recordUsedImport("unsafe") + if cond, ok := g.areFieldsPackedExpression(); ok { + g.emit("if !%s {\n", cond) + g.inIndent(fallback) + g.emit("}\n\n") + } + // Fast serialization. + g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) + g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) + g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") + g.emit("ptr := unsafe.Pointer(%s)\n", g.r) + g.emit("val := uintptr(ptr)\n") + g.emit("val = val^0\n\n") + + g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) + g.emit("var buf []byte\n") + g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") + g.emit("hdr.Data = val\n") + g.emit("hdr.Len = %s.SizeBytes()\n", g.r) + g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) + + g.emit("len, err := w.Write(buf)\n") + g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) + g.emit("// must live until after the Write.\n") + g.emit("runtime.KeepAlive(%s)\n", g.r) + g.emit("return int64(len), err\n") + } else { + fallback() + } + }) + g.emit("}\n\n") +} + +// emitMarshallableForPrimitiveNewtype outputs code to implement the +// marshal.Marshallable interface for a newtype on a primitive. Primitive +// newtypes are always packed, so we can omit the various fallbacks required for +// non-packed structs. +func (g *interfaceGenerator) emitMarshallableForPrimitiveNewtype() { + g.recordUsedImport("io") + g.recordUsedImport("marshal") + g.recordUsedImport("reflect") + g.recordUsedImport("runtime") + g.recordUsedImport("safecopy") + g.recordUsedImport("unsafe") + g.recordUsedImport("usermem") + + nt := g.t.Type.(*ast.Ident) + + g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n") + g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName()) + g.inIndent(func() { + if size, dynamic := g.scalarSize(nt); !dynamic { + g.emit("return %d\n", size) + } else { + g.emit("return (*%s)(nil).SizeBytes()\n", nt.Name) + } + }) + g.emit("}\n\n") + + g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n") + g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.marshalPrimitiveScalar(g.r, nt.Name, "dst") + }) + g.emit("}\n\n") + + g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n") + g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.unmarshalPrimitiveScalar(g.r, nt.Name, "src", g.typeName()) + }) + g.emit("}\n\n") + + g.emit("// Packed implements marshal.Marshallable.Packed.\n") + g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("// Scalar newtypes are always packed.\n") + g.emit("return true\n") + }) + g.emit("}\n\n") + + g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n") + g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r) + }) + g.emit("}\n\n") + + g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n") + g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r) + }) + g.emit("}\n\n") + + g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n") + g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName()) + g.inIndent(func() { + // Fast serialization. + g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) + g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) + g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") + g.emit("ptr := unsafe.Pointer(%s)\n", g.r) + g.emit("val := uintptr(ptr)\n") + g.emit("val = val^0\n\n") + + g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) + g.emit("var buf []byte\n") + g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") + g.emit("hdr.Data = val\n") + g.emit("hdr.Len = %s.SizeBytes()\n", g.r) + g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) + + g.emit("_, err := task.CopyOutBytes(addr, buf)\n") + g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) + g.emit("// must live until after the CopyOutBytes.\n") + g.emit("runtime.KeepAlive(%s)\n", g.r) + g.emit("return err\n") + }) + g.emit("}\n\n") + + g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n") + g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) + g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) + g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") + g.emit("ptr := unsafe.Pointer(%s)\n", g.r) + g.emit("val := uintptr(ptr)\n") + g.emit("val = val^0\n\n") + + g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) + g.emit("var buf []byte\n") + g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") + g.emit("hdr.Data = val\n") + g.emit("hdr.Len = %s.SizeBytes()\n", g.r) + g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) + + g.emit("_, err := task.CopyInBytes(addr, buf)\n") + g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) + g.emit("// must live until after the CopyInBytes.\n") + g.emit("runtime.KeepAlive(%s)\n", g.r) + g.emit("return err\n") + }) + g.emit("}\n\n") + + g.emit("// WriteTo implements io.WriterTo.WriteTo.\n") + g.emit("func (%s *%s) WriteTo(w io.Writer) (int64, error) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) + g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) + g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") + g.emit("ptr := unsafe.Pointer(%s)\n", g.r) + g.emit("val := uintptr(ptr)\n") + g.emit("val = val^0\n\n") + + g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) + g.emit("var buf []byte\n") + g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") + g.emit("hdr.Data = val\n") + g.emit("hdr.Len = %s.SizeBytes()\n", g.r) + g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) + + g.emit("len, err := w.Write(buf)\n") + g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) + g.emit("// must live until after the Write.\n") + g.emit("runtime.KeepAlive(%s)\n", g.r) + g.emit("return int64(len), err\n") + + }) + g.emit("}\n\n") + } diff --git a/tools/go_marshal/gomarshal/generator_tests.go b/tools/go_marshal/gomarshal/generator_tests.go index bcda17c3b..8ba47eb67 100644 --- a/tools/go_marshal/gomarshal/generator_tests.go +++ b/tools/go_marshal/gomarshal/generator_tests.go @@ -22,9 +22,11 @@ import ( ) var standardImports = []string{ + "bytes", "fmt", "reflect", "testing", + "gvisor.dev/gvisor/tools/go_marshal/analysis", } @@ -47,9 +49,6 @@ type testGenerator struct { } func newTestGenerator(t *ast.TypeSpec) *testGenerator { - if _, ok := t.Type.(*ast.StructType); !ok { - panic(fmt.Sprintf("Attempting to generate code for a not struct type %v", t)) - } g := &testGenerator{ t: t, r: receiverName(t), @@ -67,14 +66,6 @@ func (g *testGenerator) typeName() string { return g.t.Name.Name } -func (g *testGenerator) forEachField(fn func(f *ast.Field)) { - // This is guaranteed to succeed because g.t is always a struct. - st := g.t.Type.(*ast.StructType) - for _, field := range st.Fields.List { - fn(field) - } -} - func (g *testGenerator) testFuncName(base string) string { return fmt.Sprintf("%s%s", base, strings.Title(g.t.Name.Name)) } @@ -87,10 +78,10 @@ func (g *testGenerator) inTestFunction(name string, body func()) { func (g *testGenerator) emitTestNonZeroSize() { g.inTestFunction("TestSizeNonZero", func() { - g.emit("x := &%s{}\n", g.typeName()) + g.emit("var x %v\n", g.typeName()) g.emit("if x.SizeBytes() == 0 {\n") g.inIndent(func() { - g.emit("t.Fatal(\"Marshallable.Size() should not return zero\")\n") + g.emit("t.Fatal(\"Marshallable.SizeBytes() should not return zero\")\n") }) g.emit("}\n") }) @@ -98,7 +89,7 @@ func (g *testGenerator) emitTestNonZeroSize() { func (g *testGenerator) emitTestSuspectAlignment() { g.inTestFunction("TestSuspectAlignment", func() { - g.emit("x := %s{}\n", g.typeName()) + g.emit("var x %v\n", g.typeName()) g.emit("analysis.AlignmentCheck(t, reflect.TypeOf(x))\n") }) } @@ -116,26 +107,64 @@ func (g *testGenerator) emitTestMarshalUnmarshalPreservesData() { g.emit("y.UnmarshalBytes(buf)\n") g.emit("if !reflect.DeepEqual(x, y) {\n") g.inIndent(func() { - g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across Marshal/Unmarshal cycle:\\nBefore: %%+v\\nAfter: %%+v\\n\", x, y))\n") + g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalBytes/UnmarshalBytes cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, y))\n") }) g.emit("}\n") g.emit("yUnsafe.UnmarshalBytes(bufUnsafe)\n") g.emit("if !reflect.DeepEqual(x, yUnsafe) {\n") g.inIndent(func() { - g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalUnsafe/Unmarshal cycle:\\nBefore: %%+v\\nAfter: %%+v\\n\", x, yUnsafe))\n") + g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalUnsafe/UnmarshalBytes cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, yUnsafe))\n") }) g.emit("}\n\n") g.emit("z.UnmarshalUnsafe(buf)\n") g.emit("if !reflect.DeepEqual(x, z) {\n") g.inIndent(func() { - g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across Marshal/UnmarshalUnsafe cycle:\\nBefore: %%+v\\nAfter: %%+v\\n\", x, z))\n") + g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalBytes/UnmarshalUnsafe cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, z))\n") }) g.emit("}\n") g.emit("zUnsafe.UnmarshalUnsafe(bufUnsafe)\n") g.emit("if !reflect.DeepEqual(x, zUnsafe) {\n") g.inIndent(func() { - g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalUnsafe/UnmarshalUnsafe cycle:\\nBefore: %%+v\\nAfter: %%+v\\n\", x, zUnsafe))\n") + g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalUnsafe/UnmarshalUnsafe cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, zUnsafe))\n") + }) + g.emit("}\n") + }) +} + +func (g *testGenerator) emitTestWriteToUnmarshalPreservesData() { + g.inTestFunction("TestWriteToUnmarshalPreservesData", func() { + g.emit("var x, y, yUnsafe %s\n", g.typeName()) + g.emit("analysis.RandomizeValue(&x)\n\n") + + g.emit("var buf bytes.Buffer\n\n") + + g.emit("x.WriteTo(&buf)\n") + g.emit("y.UnmarshalBytes(buf.Bytes())\n\n") + g.emit("yUnsafe.UnmarshalUnsafe(buf.Bytes())\n\n") + + g.emit("if !reflect.DeepEqual(x, y) {\n") + g.inIndent(func() { + g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across WriteTo/UnmarshalBytes cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, y))\n") + }) + g.emit("}\n") + g.emit("if !reflect.DeepEqual(x, yUnsafe) {\n") + g.inIndent(func() { + g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across WriteTo/UnmarshalUnsafe cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, yUnsafe))\n") + }) + g.emit("}\n") + }) +} + +func (g *testGenerator) emitTestSizeBytesOnTypedNilPtr() { + g.inTestFunction("TestSizeBytesOnTypedNilPtr", func() { + g.emit("var x %s\n", g.typeName()) + g.emit("sizeFromConcrete := x.SizeBytes()\n") + g.emit("sizeFromTypedNilPtr := (*%s)(nil).SizeBytes()\n\n", g.typeName()) + + g.emit("if sizeFromTypedNilPtr != sizeFromConcrete {\n") + g.inIndent(func() { + g.emit("t.Fatalf(\"SizeBytes() on typed nil pointer (%v) doesn't match size returned by a concrete object (%v).\\n\", sizeFromTypedNilPtr, sizeFromConcrete)") }) g.emit("}\n") }) @@ -145,6 +174,8 @@ func (g *testGenerator) emitTests() { g.emitTestNonZeroSize() g.emitTestSuspectAlignment() g.emitTestMarshalUnmarshalPreservesData() + g.emitTestWriteToUnmarshalPreservesData() + g.emitTestSizeBytesOnTypedNilPtr() } func (g *testGenerator) write(out io.Writer) error { diff --git a/tools/go_marshal/gomarshal/util.go b/tools/go_marshal/gomarshal/util.go index 967537abf..e2bca4e7c 100644 --- a/tools/go_marshal/gomarshal/util.go +++ b/tools/go_marshal/gomarshal/util.go @@ -219,6 +219,11 @@ type sourceBuffer struct { b bytes.Buffer } +func (b *sourceBuffer) reset() { + b.indent = 0 + b.b.Reset() +} + func (b *sourceBuffer) incIndent() { b.indent++ } @@ -305,7 +310,7 @@ func (i *importStmt) markUsed() { } func (i *importStmt) equivalent(other *importStmt) bool { - return i == other + return i.name == other.name && i.path == other.path && i.aliased == other.aliased } // importTable represents a collection of importStmts. @@ -324,7 +329,7 @@ func newImportTable() *importTable { // result in a panic. func (i *importTable) merge(other *importTable) { for name, im := range other.is { - if dup, ok := i.is[name]; ok && dup.equivalent(im) { + if dup, ok := i.is[name]; ok && !dup.equivalent(im) { panic(fmt.Sprintf("Found colliding import statements: ours: %+v, other's: %+v", dup, im)) } @@ -332,16 +337,27 @@ func (i *importTable) merge(other *importTable) { } } +func (i *importTable) addStmt(s *importStmt) *importStmt { + if old, ok := i.is[s.name]; ok && !old.equivalent(s) { + // A collision should always be between an import inserted by the + // go-marshal tool and an import from the original source file (assuming + // the original source file was valid). We could theoretically handle + // the collision by assigning a local name to our import. However, this + // would need to be plumbed throughout the generator. Given that + // collisions should be rare, simply panic on collision. + panic(fmt.Sprintf("Import collision: old: %s as %v; new: %v as %v", old.path, old.name, s.path, s.name)) + } + i.is[s.name] = s + return s +} + func (i *importTable) add(s string) *importStmt { n := newImport(s) - i.is[n.name] = n - return n + return i.addStmt(n) } func (i *importTable) addFromSpec(spec *ast.ImportSpec, f *token.FileSet) *importStmt { - n := newImportFromSpec(spec, f) - i.is[n.name] = n - return n + return i.addStmt(newImportFromSpec(spec, f)) } // Marks the import named n as used. If no such import is in the table, returns diff --git a/tools/go_marshal/marshal/BUILD b/tools/go_marshal/marshal/BUILD index ad508c72f..bacfaa5a4 100644 --- a/tools/go_marshal/marshal/BUILD +++ b/tools/go_marshal/marshal/BUILD @@ -10,4 +10,7 @@ go_library( visibility = [ "//:sandbox", ], + deps = [ + "//pkg/usermem", + ], ) diff --git a/tools/go_marshal/marshal/marshal.go b/tools/go_marshal/marshal/marshal.go index a313a27ed..f129788e0 100644 --- a/tools/go_marshal/marshal/marshal.go +++ b/tools/go_marshal/marshal/marshal.go @@ -20,10 +20,38 @@ // tools/go_marshal. See the go_marshal README for details. package marshal +import ( + "io" + + "gvisor.dev/gvisor/pkg/usermem" +) + +// Task provides a subset of kernel.Task, used in marshalling. We don't import +// the kernel package directly to avoid circular dependency. +type Task interface { + // CopyScratchBuffer provides a task goroutine-local scratch buffer. See + // kernel.CopyScratchBuffer. + CopyScratchBuffer(size int) []byte + + // CopyOutBytes writes the contents of b to the task's memory. See + // kernel.CopyOutBytes. + CopyOutBytes(addr usermem.Addr, b []byte) (int, error) + + // CopyInBytes reads the contents of the task's memory to b. See + // kernel.CopyInBytes. + CopyInBytes(addr usermem.Addr, b []byte) (int, error) +} + // Marshallable represents a type that can be marshalled to and from memory. type Marshallable interface { + io.WriterTo + // SizeBytes is the size of the memory representation of a type in // marshalled form. + // + // SizeBytes must handle a nil receiver. Practically, this means SizeBytes + // cannot deference any fields on the object implementing it (but will + // likely make use of the type of these fields). SizeBytes() int // MarshalBytes serializes a copy of a type to dst. dst must be at least @@ -48,13 +76,27 @@ type Marshallable interface { // MarshalBytes. MarshalUnsafe(dst []byte) - // UnmarshalUnsafe deserializes a type directly to the underlying memory - // allocated for the object by the runtime. + // UnmarshalUnsafe deserializes a type by directly copying to the underlying + // memory allocated for the object by the runtime. // // This allows much faster unmarshalling of types which have no implicit // padding, see Marshallable.Packed. When Packed would return false, // UnmarshalUnsafe should fall back to the safer but slower unmarshal - // mechanism implemented in UnmarshalBytes (usually by calling - // UnmarshalBytes directly). + // mechanism implemented in UnmarshalBytes. UnmarshalUnsafe(src []byte) + + // CopyIn deserializes a Marshallable type from a task's memory. This may + // only be called from a task goroutine. This is more efficient than calling + // UnmarshalUnsafe on Marshallable.Packed types, as the type being + // marshalled does not escape. The implementation should avoid creating + // extra copies in memory by directly deserializing to the object's + // underlying memory. + CopyIn(task Task, addr usermem.Addr) error + + // CopyOut serializes a Marshallable type to a task's memory. This may only + // be called from a task goroutine. This is more efficient than calling + // MarshalUnsafe on Marshallable.Packed types, as the type being serialized + // does not escape. The implementation should avoid creating extra copies in + // memory by directly serializing from the object's underlying memory. + CopyOut(task Task, addr usermem.Addr) error } diff --git a/tools/go_marshal/test/BUILD b/tools/go_marshal/test/BUILD index e345e3a8e..f27c5ce52 100644 --- a/tools/go_marshal/test/BUILD +++ b/tools/go_marshal/test/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "go_library", "go_test") +load("//tools:defs.bzl", "go_binary", "go_library", "go_test") licenses(["notice"]) @@ -27,3 +27,15 @@ go_library( marshal = True, deps = ["//tools/go_marshal/test/external"], ) + +go_binary( + name = "escape", + testonly = 1, + srcs = ["escape.go"], + gc_goopts = ["-m"], + deps = [ + ":test", + "//pkg/usermem", + "//tools/go_marshal/marshal", + ], +) diff --git a/tools/go_marshal/test/benchmark_test.go b/tools/go_marshal/test/benchmark_test.go index e12403741..c79defe9e 100644 --- a/tools/go_marshal/test/benchmark_test.go +++ b/tools/go_marshal/test/benchmark_test.go @@ -24,7 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/tools/go_marshal/analysis" - test "gvisor.dev/gvisor/tools/go_marshal/test" + "gvisor.dev/gvisor/tools/go_marshal/test" ) // Marshalling using the standard encoding/binary package. diff --git a/tools/go_marshal/test/escape.go b/tools/go_marshal/test/escape.go new file mode 100644 index 000000000..184f05ea3 --- /dev/null +++ b/tools/go_marshal/test/escape.go @@ -0,0 +1,114 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This binary provides a convienient target for analyzing how the go-marshal +// API causes its various arguments to escape to the heap. To use, build and +// observe the output from the go compiler's escape analysis: +// +// $ bazel build :escape +// ... +// escape.go:67:2: moved to heap: task +// escape.go:77:31: make([]byte, size) escapes to heap +// escape.go:87:31: make([]byte, size) escapes to heap +// escape.go:96:6: moved to heap: stat +// ... +// +// This is not an automated test, but simply a minimal binary for easy analysis. +package main + +import ( + "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/tools/go_marshal/marshal" + "gvisor.dev/gvisor/tools/go_marshal/test" +) + +// dummyTask implements marshal.Task. +type dummyTask struct { +} + +func (*dummyTask) CopyScratchBuffer(size int) []byte { + return make([]byte, size) +} + +func (*dummyTask) CopyOutBytes(addr usermem.Addr, b []byte) (int, error) { + return len(b), nil +} + +func (*dummyTask) CopyInBytes(addr usermem.Addr, b []byte) (int, error) { + return len(b), nil +} + +func (task *dummyTask) MarshalBytes(addr usermem.Addr, marshallable marshal.Marshallable) { + buf := task.CopyScratchBuffer(marshallable.SizeBytes()) + marshallable.MarshalBytes(buf) + task.CopyOutBytes(addr, buf) +} + +func (task *dummyTask) MarshalUnsafe(addr usermem.Addr, marshallable marshal.Marshallable) { + buf := task.CopyScratchBuffer(marshallable.SizeBytes()) + marshallable.MarshalUnsafe(buf) + task.CopyOutBytes(addr, buf) +} + +// Expected escapes: +// - task: passed to marshal.Marshallable.CopyOut as the marshal.Task interface. +func doCopyOut() { + task := dummyTask{} + var stat test.Stat + stat.CopyOut(&task, usermem.Addr(0xf000ba12)) +} + +// Expected escapes: +// - buf: make allocates on the heap. +func doMarshalBytesDirect() { + task := dummyTask{} + var stat test.Stat + buf := task.CopyScratchBuffer(stat.SizeBytes()) + stat.MarshalBytes(buf) + task.CopyOutBytes(usermem.Addr(0xf000ba12), buf) +} + +// Expected escapes: +// - buf: make allocates on the heap. +func doMarshalUnsafeDirect() { + task := dummyTask{} + var stat test.Stat + buf := task.CopyScratchBuffer(stat.SizeBytes()) + stat.MarshalUnsafe(buf) + task.CopyOutBytes(usermem.Addr(0xf000ba12), buf) +} + +// Expected escapes: +// - stat: passed to dummyTask.MarshalBytes as the marshal.Marshallable interface. +func doMarshalBytesViaMarshallable() { + task := dummyTask{} + var stat test.Stat + task.MarshalBytes(usermem.Addr(0xf000ba12), &stat) +} + +// Expected escapes: +// - stat: passed to dummyTask.MarshalUnsafe as the marshal.Marshallable interface. +func doMarshalUnsafeViaMarshallable() { + task := dummyTask{} + var stat test.Stat + task.MarshalUnsafe(usermem.Addr(0xf000ba12), &stat) +} + +func main() { + doCopyOut() + doMarshalBytesDirect() + doMarshalUnsafeDirect() + doMarshalBytesViaMarshallable() + doMarshalUnsafeViaMarshallable() +} diff --git a/tools/go_marshal/test/test.go b/tools/go_marshal/test/test.go index 8de02d707..93229dedb 100644 --- a/tools/go_marshal/test/test.go +++ b/tools/go_marshal/test/test.go @@ -103,3 +103,13 @@ type Stat struct { CTime Timespec _ [3]int64 } + +// SignalSet is an example marshallable newtype on a primitive. +// +// +marshal +type SignalSet uint64 + +// SignalSetAlias is an example newtype on another marshallable type. +// +// +marshal +type SignalSetAlias SignalSet diff --git a/tools/go_stateify/defs.bzl b/tools/go_stateify/defs.bzl index bdb966362..6a5e666f0 100644 --- a/tools/go_stateify/defs.bzl +++ b/tools/go_stateify/defs.bzl @@ -6,7 +6,7 @@ def _go_stateify_impl(ctx): # Run the stateify command. args = ["-output=%s" % output.path] - args.append("-pkg=%s" % ctx.attr.package) + args.append("-fullpkg=%s" % ctx.attr.package) if ctx.attr._statepkg: args.append("-statepkg=%s" % ctx.attr._statepkg) if ctx.attr.imports: @@ -43,7 +43,7 @@ for statified types. mandatory = False, ), "package": attr.string( - doc = "The package name for the input sources.", + doc = "The fully qualified package name for the input sources.", mandatory = True, ), "out": attr.output( diff --git a/tools/go_stateify/main.go b/tools/go_stateify/main.go index aa9d4543e..3437aa476 100644 --- a/tools/go_stateify/main.go +++ b/tools/go_stateify/main.go @@ -23,6 +23,7 @@ import ( "go/parser" "go/token" "os" + "path/filepath" "reflect" "strings" "sync" @@ -31,7 +32,7 @@ import ( ) var ( - pkg = flag.String("pkg", "", "output package") + fullPkg = flag.String("fullpkg", "", "fully qualified output package") imports = flag.String("imports", "", "extra imports for the output file") output = flag.String("output", "", "output file") statePkg = flag.String("statepkg", "", "state import package; defaults to empty") @@ -170,7 +171,7 @@ func main() { flag.Usage() os.Exit(1) } - if *pkg == "" { + if *fullPkg == "" { fmt.Fprintf(os.Stderr, "Error: package required.") os.Exit(1) } @@ -202,7 +203,7 @@ func main() { // Declare our emission closures. emitRegister := func(name string) { - initCalls = append(initCalls, fmt.Sprintf("%sRegister(\"%s.%s\", (*%s)(nil), state.Fns{Save: (*%s).save, Load: (*%s).load})", statePrefix, *pkg, name, name, name, name)) + initCalls = append(initCalls, fmt.Sprintf("%sRegister(\"%s.%s\", (*%s)(nil), state.Fns{Save: (*%s).save, Load: (*%s).load})", statePrefix, *fullPkg, name, name, name, name)) } emitZeroCheck := func(name string) { fmt.Fprintf(outputFile, " if !%sIsZeroValue(x.%s) { m.Failf(\"%s is %%v, expected zero\", x.%s) }\n", statePrefix, name, name, name) @@ -233,7 +234,8 @@ func main() { } // Emit the package name. - fmt.Fprintf(outputFile, "package %s\n\n", *pkg) + _, pkg := filepath.Split(*fullPkg) + fmt.Fprintf(outputFile, "package %s\n\n", pkg) // Emit the imports lazily. var once sync.Once diff --git a/tools/installers/master.sh b/tools/installers/master.sh index 7b1956454..52f9734a6 100755 --- a/tools/installers/master.sh +++ b/tools/installers/master.sh @@ -15,6 +15,21 @@ # limitations under the License. # Install runsc from the master branch. +set -e + curl -fsSL https://gvisor.dev/archive.key | sudo apt-key add - add-apt-repository "deb https://storage.googleapis.com/gvisor/releases release main" -apt-get update && apt-get install -y runsc +while true; do + if apt-get update; then + apt-get install -y runsc + break + fi + result=$? + # Check if apt update failed to aquire the file lock. + if [[ $result -ne 100 ]]; then + exit $result + fi +done +runsc install +service docker restart + |