summaryrefslogtreecommitdiffhomepage
path: root/tools/go_stateify
diff options
context:
space:
mode:
Diffstat (limited to 'tools/go_stateify')
-rw-r--r--tools/go_stateify/defs.bzl58
-rw-r--r--tools/go_stateify/main.go66
2 files changed, 98 insertions, 26 deletions
diff --git a/tools/go_stateify/defs.bzl b/tools/go_stateify/defs.bzl
index 60a9895ff..2b2582b7a 100644
--- a/tools/go_stateify/defs.bzl
+++ b/tools/go_stateify/defs.bzl
@@ -22,6 +22,8 @@ go_library(
)
"""
+load("@io_bazel_rules_go//go:def.bzl", _go_library = "go_library", _go_test = "go_test")
+
def _go_stateify_impl(ctx):
"""Implementation for the stateify tool."""
output = ctx.outputs.out
@@ -33,6 +35,8 @@ def _go_stateify_impl(ctx):
args += ["-statepkg=%s" % ctx.attr._statepkg]
if ctx.attr.imports:
args += ["-imports=%s" % ",".join(ctx.attr.imports)]
+ if ctx.attr.explicit:
+ args += ["-explicit=true"]
args += ["--"]
for src in ctx.attr.srcs:
args += [f.path for f in src.files]
@@ -45,17 +49,15 @@ def _go_stateify_impl(ctx):
executable = ctx.executable._tool,
)
-"""
-Generates save and restore logic from a set of Go files.
-
-
-Args:
- name: the name of the rule.
- srcs: the input source files. These files should include all structs in the package that need to be saved.
- imports: an optional list of extra non-aliased, Go-style absolute import paths.
- out: the name of the generated file output. This must not conflict with any other files and must be added to the srcs of the relevant go_library.
- package: the package name for the input sources.
-"""
+# Generates save and restore logic from a set of Go files.
+#
+# Args:
+# name: the name of the rule.
+# srcs: the input source files. These files should include all structs in the package that need to be saved.
+# imports: an optional list of extra non-aliased, Go-style absolute import paths.
+# out: the name of the generated file output. This must not conflict with any other files and must be added to the srcs of the relevant go_library.
+# package: the package name for the input sources.
+# explicit: only generate for types explicitly annotated as savable.
go_stateify = rule(
implementation = _go_stateify_impl,
attrs = {
@@ -63,7 +65,41 @@ go_stateify = rule(
"imports": attr.string_list(mandatory = False),
"package": attr.string(mandatory = True),
"out": attr.output(mandatory = True),
+ "explicit": attr.bool(default = False),
"_tool": attr.label(executable = True, cfg = "host", default = Label("//tools/go_stateify:stateify")),
"_statepkg": attr.string(default = "gvisor.googlesource.com/gvisor/pkg/state"),
},
)
+
+def go_library(name, srcs, deps = [], imports = [], **kwargs):
+ """wraps the standard go_library and does stateification."""
+ if "encode_unsafe.go" not in srcs and (name + "_state_autogen.go") not in srcs:
+ # Only do stateification for non-state packages without manual autogen.
+ go_stateify(
+ name = name + "_state_autogen",
+ srcs = [src for src in srcs if src.endswith(".go")],
+ imports = imports,
+ package = name,
+ out = name + "_state_autogen.go",
+ explicit = True,
+ )
+ all_srcs = srcs + [name + "_state_autogen.go"]
+ if "//pkg/state" not in deps:
+ all_deps = deps + ["//pkg/state"]
+ else:
+ all_deps = deps
+ else:
+ all_deps = deps
+ all_srcs = srcs
+ _go_library(
+ name = name,
+ srcs = all_srcs,
+ deps = all_deps,
+ **kwargs
+ )
+
+def go_test(**kwargs):
+ """Wraps the standard go_test."""
+ _go_test(
+ **kwargs
+ )
diff --git a/tools/go_stateify/main.go b/tools/go_stateify/main.go
index 6c3583c62..231c6d80b 100644
--- a/tools/go_stateify/main.go
+++ b/tools/go_stateify/main.go
@@ -25,6 +25,7 @@ import (
"os"
"reflect"
"strings"
+ "sync"
)
var (
@@ -32,6 +33,7 @@ var (
imports = flag.String("imports", "", "extra imports for the output file")
output = flag.String("output", "", "output file")
statePkg = flag.String("statepkg", "", "state import package; defaults to empty")
+ explicit = flag.Bool("explicit", false, "only generate for types explicitly tagged '// +stateify savable'")
)
// resolveTypeName returns a qualified type name.
@@ -224,16 +226,24 @@ func main() {
// Emit the package name.
fmt.Fprint(outputFile, "// automatically generated by stateify.\n\n")
fmt.Fprintf(outputFile, "package %s\n\n", *pkg)
- fmt.Fprint(outputFile, "import (\n")
- if *statePkg != "" {
- fmt.Fprintf(outputFile, " \"%s\"\n", *statePkg)
- }
- if *imports != "" {
- for _, i := range strings.Split(*imports, ",") {
- fmt.Fprintf(outputFile, " \"%s\"\n", i)
- }
+
+ // Emit the imports lazily.
+ var once sync.Once
+ maybeEmitImports := func() {
+ once.Do(func() {
+ // Emit the imports.
+ fmt.Fprint(outputFile, "import (\n")
+ if *statePkg != "" {
+ fmt.Fprintf(outputFile, " \"%s\"\n", *statePkg)
+ }
+ if *imports != "" {
+ for _, i := range strings.Split(*imports, ",") {
+ fmt.Fprintf(outputFile, " \"%s\"\n", i)
+ }
+ }
+ fmt.Fprint(outputFile, ")\n\n")
+ })
}
- fmt.Fprint(outputFile, ")\n\n")
files := make([]*ast.File, 0, len(flag.Args()))
@@ -241,7 +251,7 @@ func main() {
for _, filename := range flag.Args() {
// Parse the file.
fset := token.NewFileSet()
- f, err := parser.ParseFile(fset, filename, nil, 0)
+ f, err := parser.ParseFile(fset, filename, nil, parser.ParseComments)
if err != nil {
// Not a valid input file?
fmt.Fprintf(os.Stderr, "Input %q can't be parsed: %v\n", filename, err)
@@ -308,6 +318,26 @@ func main() {
continue
}
+ if *explicit {
+ // In explicit mode, only generate code for
+ // types explicitly marked
+ // "// +stateify savable" in one of the
+ // proceeding comment lines.
+ if d.Doc == nil {
+ continue
+ }
+ savable := false
+ for _, l := range d.Doc.List {
+ if l.Text == "// +stateify savable" {
+ savable = true
+ break
+ }
+ }
+ if !savable {
+ continue
+ }
+ }
+
for _, gs := range d.Specs {
ts := gs.(*ast.TypeSpec)
switch ts.Type.(type) {
@@ -315,6 +345,8 @@ func main() {
// Don't register.
break
case *ast.StructType:
+ maybeEmitImports()
+
ss := ts.Type.(*ast.StructType)
// Define beforeSave if a definition was not found. This
@@ -360,6 +392,8 @@ func main() {
// Add to our registration.
emitRegister(ts.Name.Name)
case *ast.Ident, *ast.SelectorExpr, *ast.ArrayType:
+ maybeEmitImports()
+
_, val := resolveTypeName(ts.Name.Name, ts.Type)
// Dispatch directly.
@@ -377,10 +411,12 @@ func main() {
}
}
- // Emit the init() function.
- fmt.Fprintf(outputFile, "func init() {\n")
- for _, ic := range initCalls {
- fmt.Fprintf(outputFile, " %s\n", ic)
+ if len(initCalls) > 0 {
+ // Emit the init() function.
+ fmt.Fprintf(outputFile, "func init() {\n")
+ for _, ic := range initCalls {
+ fmt.Fprintf(outputFile, " %s\n", ic)
+ }
+ fmt.Fprintf(outputFile, "}\n")
}
- fmt.Fprintf(outputFile, "}\n")
}