diff options
Diffstat (limited to 'tools')
-rw-r--r-- | tools/go_stateify/defs.bzl | 58 | ||||
-rw-r--r-- | tools/go_stateify/main.go | 66 |
2 files changed, 98 insertions, 26 deletions
diff --git a/tools/go_stateify/defs.bzl b/tools/go_stateify/defs.bzl index 60a9895ff..2b2582b7a 100644 --- a/tools/go_stateify/defs.bzl +++ b/tools/go_stateify/defs.bzl @@ -22,6 +22,8 @@ go_library( ) """ +load("@io_bazel_rules_go//go:def.bzl", _go_library = "go_library", _go_test = "go_test") + def _go_stateify_impl(ctx): """Implementation for the stateify tool.""" output = ctx.outputs.out @@ -33,6 +35,8 @@ def _go_stateify_impl(ctx): args += ["-statepkg=%s" % ctx.attr._statepkg] if ctx.attr.imports: args += ["-imports=%s" % ",".join(ctx.attr.imports)] + if ctx.attr.explicit: + args += ["-explicit=true"] args += ["--"] for src in ctx.attr.srcs: args += [f.path for f in src.files] @@ -45,17 +49,15 @@ def _go_stateify_impl(ctx): executable = ctx.executable._tool, ) -""" -Generates save and restore logic from a set of Go files. - - -Args: - name: the name of the rule. - srcs: the input source files. These files should include all structs in the package that need to be saved. - imports: an optional list of extra non-aliased, Go-style absolute import paths. - out: the name of the generated file output. This must not conflict with any other files and must be added to the srcs of the relevant go_library. - package: the package name for the input sources. -""" +# Generates save and restore logic from a set of Go files. +# +# Args: +# name: the name of the rule. +# srcs: the input source files. These files should include all structs in the package that need to be saved. +# imports: an optional list of extra non-aliased, Go-style absolute import paths. +# out: the name of the generated file output. This must not conflict with any other files and must be added to the srcs of the relevant go_library. +# package: the package name for the input sources. +# explicit: only generate for types explicitly annotated as savable. go_stateify = rule( implementation = _go_stateify_impl, attrs = { @@ -63,7 +65,41 @@ go_stateify = rule( "imports": attr.string_list(mandatory = False), "package": attr.string(mandatory = True), "out": attr.output(mandatory = True), + "explicit": attr.bool(default = False), "_tool": attr.label(executable = True, cfg = "host", default = Label("//tools/go_stateify:stateify")), "_statepkg": attr.string(default = "gvisor.googlesource.com/gvisor/pkg/state"), }, ) + +def go_library(name, srcs, deps = [], imports = [], **kwargs): + """wraps the standard go_library and does stateification.""" + if "encode_unsafe.go" not in srcs and (name + "_state_autogen.go") not in srcs: + # Only do stateification for non-state packages without manual autogen. + go_stateify( + name = name + "_state_autogen", + srcs = [src for src in srcs if src.endswith(".go")], + imports = imports, + package = name, + out = name + "_state_autogen.go", + explicit = True, + ) + all_srcs = srcs + [name + "_state_autogen.go"] + if "//pkg/state" not in deps: + all_deps = deps + ["//pkg/state"] + else: + all_deps = deps + else: + all_deps = deps + all_srcs = srcs + _go_library( + name = name, + srcs = all_srcs, + deps = all_deps, + **kwargs + ) + +def go_test(**kwargs): + """Wraps the standard go_test.""" + _go_test( + **kwargs + ) diff --git a/tools/go_stateify/main.go b/tools/go_stateify/main.go index 6c3583c62..231c6d80b 100644 --- a/tools/go_stateify/main.go +++ b/tools/go_stateify/main.go @@ -25,6 +25,7 @@ import ( "os" "reflect" "strings" + "sync" ) var ( @@ -32,6 +33,7 @@ var ( imports = flag.String("imports", "", "extra imports for the output file") output = flag.String("output", "", "output file") statePkg = flag.String("statepkg", "", "state import package; defaults to empty") + explicit = flag.Bool("explicit", false, "only generate for types explicitly tagged '// +stateify savable'") ) // resolveTypeName returns a qualified type name. @@ -224,16 +226,24 @@ func main() { // Emit the package name. fmt.Fprint(outputFile, "// automatically generated by stateify.\n\n") fmt.Fprintf(outputFile, "package %s\n\n", *pkg) - fmt.Fprint(outputFile, "import (\n") - if *statePkg != "" { - fmt.Fprintf(outputFile, " \"%s\"\n", *statePkg) - } - if *imports != "" { - for _, i := range strings.Split(*imports, ",") { - fmt.Fprintf(outputFile, " \"%s\"\n", i) - } + + // Emit the imports lazily. + var once sync.Once + maybeEmitImports := func() { + once.Do(func() { + // Emit the imports. + fmt.Fprint(outputFile, "import (\n") + if *statePkg != "" { + fmt.Fprintf(outputFile, " \"%s\"\n", *statePkg) + } + if *imports != "" { + for _, i := range strings.Split(*imports, ",") { + fmt.Fprintf(outputFile, " \"%s\"\n", i) + } + } + fmt.Fprint(outputFile, ")\n\n") + }) } - fmt.Fprint(outputFile, ")\n\n") files := make([]*ast.File, 0, len(flag.Args())) @@ -241,7 +251,7 @@ func main() { for _, filename := range flag.Args() { // Parse the file. fset := token.NewFileSet() - f, err := parser.ParseFile(fset, filename, nil, 0) + f, err := parser.ParseFile(fset, filename, nil, parser.ParseComments) if err != nil { // Not a valid input file? fmt.Fprintf(os.Stderr, "Input %q can't be parsed: %v\n", filename, err) @@ -308,6 +318,26 @@ func main() { continue } + if *explicit { + // In explicit mode, only generate code for + // types explicitly marked + // "// +stateify savable" in one of the + // proceeding comment lines. + if d.Doc == nil { + continue + } + savable := false + for _, l := range d.Doc.List { + if l.Text == "// +stateify savable" { + savable = true + break + } + } + if !savable { + continue + } + } + for _, gs := range d.Specs { ts := gs.(*ast.TypeSpec) switch ts.Type.(type) { @@ -315,6 +345,8 @@ func main() { // Don't register. break case *ast.StructType: + maybeEmitImports() + ss := ts.Type.(*ast.StructType) // Define beforeSave if a definition was not found. This @@ -360,6 +392,8 @@ func main() { // Add to our registration. emitRegister(ts.Name.Name) case *ast.Ident, *ast.SelectorExpr, *ast.ArrayType: + maybeEmitImports() + _, val := resolveTypeName(ts.Name.Name, ts.Type) // Dispatch directly. @@ -377,10 +411,12 @@ func main() { } } - // Emit the init() function. - fmt.Fprintf(outputFile, "func init() {\n") - for _, ic := range initCalls { - fmt.Fprintf(outputFile, " %s\n", ic) + if len(initCalls) > 0 { + // Emit the init() function. + fmt.Fprintf(outputFile, "func init() {\n") + for _, ic := range initCalls { + fmt.Fprintf(outputFile, " %s\n", ic) + } + fmt.Fprintf(outputFile, "}\n") } - fmt.Fprintf(outputFile, "}\n") } |