diff options
Diffstat (limited to 'tools/go_stateify')
-rw-r--r-- | tools/go_stateify/BUILD | 10 | ||||
-rw-r--r-- | tools/go_stateify/defs.bzl | 60 | ||||
-rw-r--r-- | tools/go_stateify/main.go | 476 |
3 files changed, 546 insertions, 0 deletions
diff --git a/tools/go_stateify/BUILD b/tools/go_stateify/BUILD new file mode 100644 index 000000000..503cdf2e5 --- /dev/null +++ b/tools/go_stateify/BUILD @@ -0,0 +1,10 @@ +load("//tools:defs.bzl", "go_binary") + +package(licenses = ["notice"]) + +go_binary( + name = "stateify", + srcs = ["main.go"], + visibility = ["//:sandbox"], + deps = ["//tools/tags"], +) diff --git a/tools/go_stateify/defs.bzl b/tools/go_stateify/defs.bzl new file mode 100644 index 000000000..6a5e666f0 --- /dev/null +++ b/tools/go_stateify/defs.bzl @@ -0,0 +1,60 @@ +"""Stateify is a tool for generating state wrappers for Go types.""" + +def _go_stateify_impl(ctx): + """Implementation for the stateify tool.""" + output = ctx.outputs.out + + # Run the stateify command. + args = ["-output=%s" % output.path] + args.append("-fullpkg=%s" % ctx.attr.package) + if ctx.attr._statepkg: + args.append("-statepkg=%s" % ctx.attr._statepkg) + if ctx.attr.imports: + args.append("-imports=%s" % ",".join(ctx.attr.imports)) + args.append("--") + for src in ctx.attr.srcs: + args += [f.path for f in src.files.to_list()] + ctx.actions.run( + inputs = ctx.files.srcs, + outputs = [output], + mnemonic = "GoStateify", + progress_message = "Generating state library %s" % ctx.label, + arguments = args, + executable = ctx.executable._tool, + ) + +go_stateify = rule( + implementation = _go_stateify_impl, + doc = "Generates save and restore logic from a set of Go files.", + attrs = { + "srcs": attr.label_list( + doc = """ +The input source files. These files should include all structs in the package +that need to be saved. +""", + mandatory = True, + allow_files = True, + ), + "imports": attr.string_list( + doc = """ +An optional list of extra non-aliased, Go-style absolute import paths required +for statified types. +""", + mandatory = False, + ), + "package": attr.string( + doc = "The fully qualified package name for the input sources.", + mandatory = True, + ), + "out": attr.output( + doc = "Name of the generator output file.", + mandatory = True, + ), + "_tool": attr.label( + executable = True, + cfg = "host", + default = Label("//tools/go_stateify:stateify"), + ), + "_statepkg": attr.string(default = "gvisor.dev/gvisor/pkg/state"), + }, +) diff --git a/tools/go_stateify/main.go b/tools/go_stateify/main.go new file mode 100644 index 000000000..4f6ed208a --- /dev/null +++ b/tools/go_stateify/main.go @@ -0,0 +1,476 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Stateify provides a simple way to generate Load/Save methods based on +// existing types and struct tags. +package main + +import ( + "flag" + "fmt" + "go/ast" + "go/parser" + "go/token" + "os" + "path/filepath" + "reflect" + "strings" + "sync" + + "gvisor.dev/gvisor/tools/tags" +) + +var ( + fullPkg = flag.String("fullpkg", "", "fully qualified output package") + imports = flag.String("imports", "", "extra imports for the output file") + output = flag.String("output", "", "output file") + statePkg = flag.String("statepkg", "", "state import package; defaults to empty") +) + +// resolveTypeName returns a qualified type name. +func resolveTypeName(name string, typ ast.Expr) (field string, qualified string) { + for done := false; !done; { + // Resolve star expressions. + switch rs := typ.(type) { + case *ast.StarExpr: + qualified += "*" + typ = rs.X + case *ast.ArrayType: + if rs.Len == nil { + // Slice type declaration. + qualified += "[]" + } else { + // Array type declaration. + qualified += "[" + rs.Len.(*ast.BasicLit).Value + "]" + } + typ = rs.Elt + default: + // No more descent. + done = true + } + } + + // Resolve a package selector. + sel, ok := typ.(*ast.SelectorExpr) + if ok { + qualified = qualified + sel.X.(*ast.Ident).Name + "." + typ = sel.Sel + } + + // Figure out actual type name. + ident, ok := typ.(*ast.Ident) + if !ok { + panic(fmt.Sprintf("type not supported: %s (involves anonymous types?)", name)) + } + field = ident.Name + qualified = qualified + field + return +} + +// extractStateTag pulls the relevant state tag. +func extractStateTag(tag *ast.BasicLit) string { + if tag == nil { + return "" + } + if len(tag.Value) < 2 { + return "" + } + return reflect.StructTag(tag.Value[1 : len(tag.Value)-1]).Get("state") +} + +// scanFunctions is a set of functions passed to scanFields. +type scanFunctions struct { + zerovalue func(name string) + normal func(name string) + wait func(name string) + value func(name, typName string) +} + +// scanFields scans the fields of a struct. +// +// Each provided function will be applied to appropriately tagged fields, or +// skipped if nil. +// +// Fields tagged nosave are skipped. +func scanFields(ss *ast.StructType, prefix string, fn scanFunctions) { + if ss.Fields.List == nil { + // No fields. + return + } + + // Scan all fields. + for _, field := range ss.Fields.List { + // Calculate the name. + name := "" + if field.Names != nil { + // It's a named field; override. + name = field.Names[0].Name + } else { + // Anonymous types can't be embedded, so we don't need + // to worry about providing a useful name here. + name, _ = resolveTypeName("", field.Type) + } + + // Skip _ fields. + if name == "_" { + continue + } + + // Is this a anonymous struct? If yes, then continue the + // recursion with the given prefix. We don't pay attention to + // any tags on the top-level struct field. + tag := extractStateTag(field.Tag) + if anon, ok := field.Type.(*ast.StructType); ok && tag == "" { + scanFields(anon, name+".", fn) + continue + } + + switch tag { + case "zerovalue": + if fn.zerovalue != nil { + fn.zerovalue(name) + } + + case "": + if fn.normal != nil { + fn.normal(name) + } + + case "wait": + if fn.wait != nil { + fn.wait(name) + } + + case "manual", "nosave", "ignore": + // Do nothing. + + default: + if strings.HasPrefix(tag, ".(") && strings.HasSuffix(tag, ")") { + if fn.value != nil { + fn.value(name, tag[2:len(tag)-1]) + } + } + } + } +} + +func camelCased(name string) string { + return strings.ToUpper(name[:1]) + name[1:] +} + +func main() { + // Parse flags. + flag.Usage = func() { + fmt.Fprintf(os.Stderr, "Usage: %s [options]\n", os.Args[0]) + flag.PrintDefaults() + } + flag.Parse() + if len(flag.Args()) == 0 { + flag.Usage() + os.Exit(1) + } + if *fullPkg == "" { + fmt.Fprintf(os.Stderr, "Error: package required.") + os.Exit(1) + } + + // Open the output file. + var ( + outputFile *os.File + err error + ) + if *output == "" || *output == "-" { + outputFile = os.Stdout + } else { + outputFile, err = os.OpenFile(*output, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) + if err != nil { + fmt.Fprintf(os.Stderr, "Error opening output %q: %v", *output, err) + } + defer outputFile.Close() + } + + // Set the statePrefix for below, depending on the import. + statePrefix := "" + if *statePkg != "" { + parts := strings.Split(*statePkg, "/") + statePrefix = parts[len(parts)-1] + "." + } + + // initCalls is dumped at the end. + var initCalls []string + + // Common closures. + emitRegister := func(name string) { + initCalls = append(initCalls, fmt.Sprintf("%sRegister((*%s)(nil))", statePrefix, name)) + } + emitZeroCheck := func(name string) { + fmt.Fprintf(outputFile, " if !%sIsZeroValue(&x.%s) { %sFailf(\"%s is %%#v, expected zero\", &x.%s) }\n", statePrefix, name, statePrefix, name, name) + } + + // Automated warning. + fmt.Fprint(outputFile, "// automatically generated by stateify.\n\n") + + // Emit build tags. + if t := tags.Aggregate(flag.Args()); len(t) > 0 { + fmt.Fprintf(outputFile, "%s\n\n", strings.Join(t.Lines(), "\n")) + } + + // Emit the package name. + _, pkg := filepath.Split(*fullPkg) + fmt.Fprintf(outputFile, "package %s\n\n", pkg) + + // Emit the imports lazily. + var once sync.Once + maybeEmitImports := func() { + once.Do(func() { + // Emit the imports. + fmt.Fprint(outputFile, "import (\n") + if *statePkg != "" { + fmt.Fprintf(outputFile, " \"%s\"\n", *statePkg) + } + if *imports != "" { + for _, i := range strings.Split(*imports, ",") { + fmt.Fprintf(outputFile, " \"%s\"\n", i) + } + } + fmt.Fprint(outputFile, ")\n\n") + }) + } + + files := make([]*ast.File, 0, len(flag.Args())) + + // Parse the input files. + for _, filename := range flag.Args() { + // Parse the file. + fset := token.NewFileSet() + f, err := parser.ParseFile(fset, filename, nil, parser.ParseComments) + if err != nil { + // Not a valid input file? + fmt.Fprintf(os.Stderr, "Input %q can't be parsed: %v\n", filename, err) + os.Exit(1) + } + + files = append(files, f) + } + + type method struct { + receiver string + name string + } + + // Search for and add all methods with a pointer receiver and no other + // arguments to a set. We support auto-detecting the existence of + // several different methods with this signature. + simpleMethods := map[method]struct{}{} + for _, f := range files { + + // Go over all functions. + for _, decl := range f.Decls { + d, ok := decl.(*ast.FuncDecl) + if !ok { + continue + } + if d.Name == nil || d.Recv == nil || d.Type == nil { + // Not a named method. + continue + } + if len(d.Recv.List) != 1 { + // Wrong number of receivers? + continue + } + if d.Type.Params != nil && len(d.Type.Params.List) != 0 { + // Has argument(s). + continue + } + if d.Type.Results != nil && len(d.Type.Results.List) != 0 { + // Has return(s). + continue + } + + pt, ok := d.Recv.List[0].Type.(*ast.StarExpr) + if !ok { + // Not a pointer receiver. + continue + } + + t, ok := pt.X.(*ast.Ident) + if !ok { + // This shouldn't happen with valid Go. + continue + } + + simpleMethods[method{t.Name, d.Name.Name}] = struct{}{} + } + } + + for _, f := range files { + // Go over all named types. + for _, decl := range f.Decls { + d, ok := decl.(*ast.GenDecl) + if !ok || d.Tok != token.TYPE { + continue + } + + // Only generate code for types marked "// +stateify + // savable" in one of the proceeding comment lines. If + // the line is marked "// +stateify type" then only + // generate type information and register the type. + if d.Doc == nil { + continue + } + var ( + generateTypeInfo = false + generateSaverLoader = false + ) + for _, l := range d.Doc.List { + if l.Text == "// +stateify savable" { + generateTypeInfo = true + generateSaverLoader = true + break + } + if l.Text == "// +stateify type" { + generateTypeInfo = true + } + } + if !generateTypeInfo && !generateSaverLoader { + continue + } + + for _, gs := range d.Specs { + ts := gs.(*ast.TypeSpec) + switch x := ts.Type.(type) { + case *ast.StructType: + maybeEmitImports() + + // Record the slot for each field. + fieldCount := 0 + fields := make(map[string]int) + emitField := func(name string) { + fmt.Fprintf(outputFile, " \"%s\",\n", name) + fields[name] = fieldCount + fieldCount++ + } + emitFieldValue := func(name string, _ string) { + emitField(name) + } + emitLoadValue := func(name, typName string) { + fmt.Fprintf(outputFile, " m.LoadValue(%d, new(%s), func(y interface{}) { x.load%s(y.(%s)) })\n", fields[name], typName, camelCased(name), typName) + } + emitLoad := func(name string) { + fmt.Fprintf(outputFile, " m.Load(%d, &x.%s)\n", fields[name], name) + } + emitLoadWait := func(name string) { + fmt.Fprintf(outputFile, " m.LoadWait(%d, &x.%s)\n", fields[name], name) + } + emitSaveValue := func(name, typName string) { + fmt.Fprintf(outputFile, " var %s %s = x.save%s()\n", name, typName, camelCased(name)) + fmt.Fprintf(outputFile, " m.SaveValue(%d, %s)\n", fields[name], name) + } + emitSave := func(name string) { + fmt.Fprintf(outputFile, " m.Save(%d, &x.%s)\n", fields[name], name) + } + + // Generate the type name method. + fmt.Fprintf(outputFile, "func (x *%s) StateTypeName() string {\n", ts.Name.Name) + fmt.Fprintf(outputFile, " return \"%s.%s\"\n", *fullPkg, ts.Name.Name) + fmt.Fprintf(outputFile, "}\n\n") + + // Generate the fields method. + fmt.Fprintf(outputFile, "func (x *%s) StateFields() []string {\n", ts.Name.Name) + fmt.Fprintf(outputFile, " return []string{\n") + scanFields(x, "", scanFunctions{ + normal: emitField, + wait: emitField, + value: emitFieldValue, + }) + fmt.Fprintf(outputFile, " }\n") + fmt.Fprintf(outputFile, "}\n\n") + + // Define beforeSave if a definition was not found. This prevents + // the code from compiling if a custom beforeSave was defined in a + // file not provided to this binary and prevents inherited methods + // from being called multiple times by overriding them. + if _, ok := simpleMethods[method{ts.Name.Name, "beforeSave"}]; !ok && generateSaverLoader { + fmt.Fprintf(outputFile, "func (x *%s) beforeSave() {}\n\n", ts.Name.Name) + } + + // Generate the save method. + // + // N.B. For historical reasons, we perform the value saves first, + // and perform the value loads last. There should be no dependency + // on this specific behavior, but the ability to specify slots + // allows a manual implementation to be order-dependent. + if generateSaverLoader { + fmt.Fprintf(outputFile, "func (x *%s) StateSave(m %sSink) {\n", ts.Name.Name, statePrefix) + fmt.Fprintf(outputFile, " x.beforeSave()\n") + scanFields(x, "", scanFunctions{zerovalue: emitZeroCheck}) + scanFields(x, "", scanFunctions{value: emitSaveValue}) + scanFields(x, "", scanFunctions{normal: emitSave, wait: emitSave}) + fmt.Fprintf(outputFile, "}\n\n") + } + + // Define afterLoad if a definition was not found. We do this for + // the same reason that we do it for beforeSave. + _, hasAfterLoad := simpleMethods[method{ts.Name.Name, "afterLoad"}] + if !hasAfterLoad && generateSaverLoader { + fmt.Fprintf(outputFile, "func (x *%s) afterLoad() {}\n\n", ts.Name.Name) + } + + // Generate the load method. + // + // N.B. See the comment above for the save method. + if generateSaverLoader { + fmt.Fprintf(outputFile, "func (x *%s) StateLoad(m %sSource) {\n", ts.Name.Name, statePrefix) + scanFields(x, "", scanFunctions{normal: emitLoad, wait: emitLoadWait}) + scanFields(x, "", scanFunctions{value: emitLoadValue}) + if hasAfterLoad { + // The call to afterLoad is made conditionally, because when + // AfterLoad is called, the object encodes a dependency on + // referred objects (i.e. fields). This means that afterLoad + // will not be called until the other afterLoads are called. + fmt.Fprintf(outputFile, " m.AfterLoad(x.afterLoad)\n") + } + fmt.Fprintf(outputFile, "}\n\n") + } + + // Add to our registration. + emitRegister(ts.Name.Name) + + case *ast.Ident, *ast.SelectorExpr, *ast.ArrayType: + maybeEmitImports() + + // Generate the info methods. + fmt.Fprintf(outputFile, "func (x *%s) StateTypeName() string {\n", ts.Name.Name) + fmt.Fprintf(outputFile, " return \"%s.%s\"\n", *fullPkg, ts.Name.Name) + fmt.Fprintf(outputFile, "}\n\n") + fmt.Fprintf(outputFile, "func (x *%s) StateFields() []string {\n", ts.Name.Name) + fmt.Fprintf(outputFile, " return nil\n") + fmt.Fprintf(outputFile, "}\n\n") + + // See above. + emitRegister(ts.Name.Name) + } + } + } + } + + if len(initCalls) > 0 { + // Emit the init() function. + fmt.Fprintf(outputFile, "func init() {\n") + for _, ic := range initCalls { + fmt.Fprintf(outputFile, " %s\n", ic) + } + fmt.Fprintf(outputFile, "}\n") + } +} |