diff options
Diffstat (limited to 'tools/go_stateify')
-rw-r--r-- | tools/go_stateify/BUILD | 10 | ||||
-rw-r--r-- | tools/go_stateify/defs.bzl | 60 | ||||
-rw-r--r-- | tools/go_stateify/main.go | 476 |
3 files changed, 0 insertions, 546 deletions
diff --git a/tools/go_stateify/BUILD b/tools/go_stateify/BUILD deleted file mode 100644 index 503cdf2e5..000000000 --- a/tools/go_stateify/BUILD +++ /dev/null @@ -1,10 +0,0 @@ -load("//tools:defs.bzl", "go_binary") - -package(licenses = ["notice"]) - -go_binary( - name = "stateify", - srcs = ["main.go"], - visibility = ["//:sandbox"], - deps = ["//tools/tags"], -) diff --git a/tools/go_stateify/defs.bzl b/tools/go_stateify/defs.bzl deleted file mode 100644 index 6a5e666f0..000000000 --- a/tools/go_stateify/defs.bzl +++ /dev/null @@ -1,60 +0,0 @@ -"""Stateify is a tool for generating state wrappers for Go types.""" - -def _go_stateify_impl(ctx): - """Implementation for the stateify tool.""" - output = ctx.outputs.out - - # Run the stateify command. - args = ["-output=%s" % output.path] - args.append("-fullpkg=%s" % ctx.attr.package) - if ctx.attr._statepkg: - args.append("-statepkg=%s" % ctx.attr._statepkg) - if ctx.attr.imports: - args.append("-imports=%s" % ",".join(ctx.attr.imports)) - args.append("--") - for src in ctx.attr.srcs: - args += [f.path for f in src.files.to_list()] - ctx.actions.run( - inputs = ctx.files.srcs, - outputs = [output], - mnemonic = "GoStateify", - progress_message = "Generating state library %s" % ctx.label, - arguments = args, - executable = ctx.executable._tool, - ) - -go_stateify = rule( - implementation = _go_stateify_impl, - doc = "Generates save and restore logic from a set of Go files.", - attrs = { - "srcs": attr.label_list( - doc = """ -The input source files. These files should include all structs in the package -that need to be saved. -""", - mandatory = True, - allow_files = True, - ), - "imports": attr.string_list( - doc = """ -An optional list of extra non-aliased, Go-style absolute import paths required -for statified types. -""", - mandatory = False, - ), - "package": attr.string( - doc = "The fully qualified package name for the input sources.", - mandatory = True, - ), - "out": attr.output( - doc = "Name of the generator output file.", - mandatory = True, - ), - "_tool": attr.label( - executable = True, - cfg = "host", - default = Label("//tools/go_stateify:stateify"), - ), - "_statepkg": attr.string(default = "gvisor.dev/gvisor/pkg/state"), - }, -) diff --git a/tools/go_stateify/main.go b/tools/go_stateify/main.go deleted file mode 100644 index 4f6ed208a..000000000 --- a/tools/go_stateify/main.go +++ /dev/null @@ -1,476 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Stateify provides a simple way to generate Load/Save methods based on -// existing types and struct tags. -package main - -import ( - "flag" - "fmt" - "go/ast" - "go/parser" - "go/token" - "os" - "path/filepath" - "reflect" - "strings" - "sync" - - "gvisor.dev/gvisor/tools/tags" -) - -var ( - fullPkg = flag.String("fullpkg", "", "fully qualified output package") - imports = flag.String("imports", "", "extra imports for the output file") - output = flag.String("output", "", "output file") - statePkg = flag.String("statepkg", "", "state import package; defaults to empty") -) - -// resolveTypeName returns a qualified type name. -func resolveTypeName(name string, typ ast.Expr) (field string, qualified string) { - for done := false; !done; { - // Resolve star expressions. - switch rs := typ.(type) { - case *ast.StarExpr: - qualified += "*" - typ = rs.X - case *ast.ArrayType: - if rs.Len == nil { - // Slice type declaration. - qualified += "[]" - } else { - // Array type declaration. - qualified += "[" + rs.Len.(*ast.BasicLit).Value + "]" - } - typ = rs.Elt - default: - // No more descent. - done = true - } - } - - // Resolve a package selector. - sel, ok := typ.(*ast.SelectorExpr) - if ok { - qualified = qualified + sel.X.(*ast.Ident).Name + "." - typ = sel.Sel - } - - // Figure out actual type name. - ident, ok := typ.(*ast.Ident) - if !ok { - panic(fmt.Sprintf("type not supported: %s (involves anonymous types?)", name)) - } - field = ident.Name - qualified = qualified + field - return -} - -// extractStateTag pulls the relevant state tag. -func extractStateTag(tag *ast.BasicLit) string { - if tag == nil { - return "" - } - if len(tag.Value) < 2 { - return "" - } - return reflect.StructTag(tag.Value[1 : len(tag.Value)-1]).Get("state") -} - -// scanFunctions is a set of functions passed to scanFields. -type scanFunctions struct { - zerovalue func(name string) - normal func(name string) - wait func(name string) - value func(name, typName string) -} - -// scanFields scans the fields of a struct. -// -// Each provided function will be applied to appropriately tagged fields, or -// skipped if nil. -// -// Fields tagged nosave are skipped. -func scanFields(ss *ast.StructType, prefix string, fn scanFunctions) { - if ss.Fields.List == nil { - // No fields. - return - } - - // Scan all fields. - for _, field := range ss.Fields.List { - // Calculate the name. - name := "" - if field.Names != nil { - // It's a named field; override. - name = field.Names[0].Name - } else { - // Anonymous types can't be embedded, so we don't need - // to worry about providing a useful name here. - name, _ = resolveTypeName("", field.Type) - } - - // Skip _ fields. - if name == "_" { - continue - } - - // Is this a anonymous struct? If yes, then continue the - // recursion with the given prefix. We don't pay attention to - // any tags on the top-level struct field. - tag := extractStateTag(field.Tag) - if anon, ok := field.Type.(*ast.StructType); ok && tag == "" { - scanFields(anon, name+".", fn) - continue - } - - switch tag { - case "zerovalue": - if fn.zerovalue != nil { - fn.zerovalue(name) - } - - case "": - if fn.normal != nil { - fn.normal(name) - } - - case "wait": - if fn.wait != nil { - fn.wait(name) - } - - case "manual", "nosave", "ignore": - // Do nothing. - - default: - if strings.HasPrefix(tag, ".(") && strings.HasSuffix(tag, ")") { - if fn.value != nil { - fn.value(name, tag[2:len(tag)-1]) - } - } - } - } -} - -func camelCased(name string) string { - return strings.ToUpper(name[:1]) + name[1:] -} - -func main() { - // Parse flags. - flag.Usage = func() { - fmt.Fprintf(os.Stderr, "Usage: %s [options]\n", os.Args[0]) - flag.PrintDefaults() - } - flag.Parse() - if len(flag.Args()) == 0 { - flag.Usage() - os.Exit(1) - } - if *fullPkg == "" { - fmt.Fprintf(os.Stderr, "Error: package required.") - os.Exit(1) - } - - // Open the output file. - var ( - outputFile *os.File - err error - ) - if *output == "" || *output == "-" { - outputFile = os.Stdout - } else { - outputFile, err = os.OpenFile(*output, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) - if err != nil { - fmt.Fprintf(os.Stderr, "Error opening output %q: %v", *output, err) - } - defer outputFile.Close() - } - - // Set the statePrefix for below, depending on the import. - statePrefix := "" - if *statePkg != "" { - parts := strings.Split(*statePkg, "/") - statePrefix = parts[len(parts)-1] + "." - } - - // initCalls is dumped at the end. - var initCalls []string - - // Common closures. - emitRegister := func(name string) { - initCalls = append(initCalls, fmt.Sprintf("%sRegister((*%s)(nil))", statePrefix, name)) - } - emitZeroCheck := func(name string) { - fmt.Fprintf(outputFile, " if !%sIsZeroValue(&x.%s) { %sFailf(\"%s is %%#v, expected zero\", &x.%s) }\n", statePrefix, name, statePrefix, name, name) - } - - // Automated warning. - fmt.Fprint(outputFile, "// automatically generated by stateify.\n\n") - - // Emit build tags. - if t := tags.Aggregate(flag.Args()); len(t) > 0 { - fmt.Fprintf(outputFile, "%s\n\n", strings.Join(t.Lines(), "\n")) - } - - // Emit the package name. - _, pkg := filepath.Split(*fullPkg) - fmt.Fprintf(outputFile, "package %s\n\n", pkg) - - // Emit the imports lazily. - var once sync.Once - maybeEmitImports := func() { - once.Do(func() { - // Emit the imports. - fmt.Fprint(outputFile, "import (\n") - if *statePkg != "" { - fmt.Fprintf(outputFile, " \"%s\"\n", *statePkg) - } - if *imports != "" { - for _, i := range strings.Split(*imports, ",") { - fmt.Fprintf(outputFile, " \"%s\"\n", i) - } - } - fmt.Fprint(outputFile, ")\n\n") - }) - } - - files := make([]*ast.File, 0, len(flag.Args())) - - // Parse the input files. - for _, filename := range flag.Args() { - // Parse the file. - fset := token.NewFileSet() - f, err := parser.ParseFile(fset, filename, nil, parser.ParseComments) - if err != nil { - // Not a valid input file? - fmt.Fprintf(os.Stderr, "Input %q can't be parsed: %v\n", filename, err) - os.Exit(1) - } - - files = append(files, f) - } - - type method struct { - receiver string - name string - } - - // Search for and add all methods with a pointer receiver and no other - // arguments to a set. We support auto-detecting the existence of - // several different methods with this signature. - simpleMethods := map[method]struct{}{} - for _, f := range files { - - // Go over all functions. - for _, decl := range f.Decls { - d, ok := decl.(*ast.FuncDecl) - if !ok { - continue - } - if d.Name == nil || d.Recv == nil || d.Type == nil { - // Not a named method. - continue - } - if len(d.Recv.List) != 1 { - // Wrong number of receivers? - continue - } - if d.Type.Params != nil && len(d.Type.Params.List) != 0 { - // Has argument(s). - continue - } - if d.Type.Results != nil && len(d.Type.Results.List) != 0 { - // Has return(s). - continue - } - - pt, ok := d.Recv.List[0].Type.(*ast.StarExpr) - if !ok { - // Not a pointer receiver. - continue - } - - t, ok := pt.X.(*ast.Ident) - if !ok { - // This shouldn't happen with valid Go. - continue - } - - simpleMethods[method{t.Name, d.Name.Name}] = struct{}{} - } - } - - for _, f := range files { - // Go over all named types. - for _, decl := range f.Decls { - d, ok := decl.(*ast.GenDecl) - if !ok || d.Tok != token.TYPE { - continue - } - - // Only generate code for types marked "// +stateify - // savable" in one of the proceeding comment lines. If - // the line is marked "// +stateify type" then only - // generate type information and register the type. - if d.Doc == nil { - continue - } - var ( - generateTypeInfo = false - generateSaverLoader = false - ) - for _, l := range d.Doc.List { - if l.Text == "// +stateify savable" { - generateTypeInfo = true - generateSaverLoader = true - break - } - if l.Text == "// +stateify type" { - generateTypeInfo = true - } - } - if !generateTypeInfo && !generateSaverLoader { - continue - } - - for _, gs := range d.Specs { - ts := gs.(*ast.TypeSpec) - switch x := ts.Type.(type) { - case *ast.StructType: - maybeEmitImports() - - // Record the slot for each field. - fieldCount := 0 - fields := make(map[string]int) - emitField := func(name string) { - fmt.Fprintf(outputFile, " \"%s\",\n", name) - fields[name] = fieldCount - fieldCount++ - } - emitFieldValue := func(name string, _ string) { - emitField(name) - } - emitLoadValue := func(name, typName string) { - fmt.Fprintf(outputFile, " m.LoadValue(%d, new(%s), func(y interface{}) { x.load%s(y.(%s)) })\n", fields[name], typName, camelCased(name), typName) - } - emitLoad := func(name string) { - fmt.Fprintf(outputFile, " m.Load(%d, &x.%s)\n", fields[name], name) - } - emitLoadWait := func(name string) { - fmt.Fprintf(outputFile, " m.LoadWait(%d, &x.%s)\n", fields[name], name) - } - emitSaveValue := func(name, typName string) { - fmt.Fprintf(outputFile, " var %s %s = x.save%s()\n", name, typName, camelCased(name)) - fmt.Fprintf(outputFile, " m.SaveValue(%d, %s)\n", fields[name], name) - } - emitSave := func(name string) { - fmt.Fprintf(outputFile, " m.Save(%d, &x.%s)\n", fields[name], name) - } - - // Generate the type name method. - fmt.Fprintf(outputFile, "func (x *%s) StateTypeName() string {\n", ts.Name.Name) - fmt.Fprintf(outputFile, " return \"%s.%s\"\n", *fullPkg, ts.Name.Name) - fmt.Fprintf(outputFile, "}\n\n") - - // Generate the fields method. - fmt.Fprintf(outputFile, "func (x *%s) StateFields() []string {\n", ts.Name.Name) - fmt.Fprintf(outputFile, " return []string{\n") - scanFields(x, "", scanFunctions{ - normal: emitField, - wait: emitField, - value: emitFieldValue, - }) - fmt.Fprintf(outputFile, " }\n") - fmt.Fprintf(outputFile, "}\n\n") - - // Define beforeSave if a definition was not found. This prevents - // the code from compiling if a custom beforeSave was defined in a - // file not provided to this binary and prevents inherited methods - // from being called multiple times by overriding them. - if _, ok := simpleMethods[method{ts.Name.Name, "beforeSave"}]; !ok && generateSaverLoader { - fmt.Fprintf(outputFile, "func (x *%s) beforeSave() {}\n\n", ts.Name.Name) - } - - // Generate the save method. - // - // N.B. For historical reasons, we perform the value saves first, - // and perform the value loads last. There should be no dependency - // on this specific behavior, but the ability to specify slots - // allows a manual implementation to be order-dependent. - if generateSaverLoader { - fmt.Fprintf(outputFile, "func (x *%s) StateSave(m %sSink) {\n", ts.Name.Name, statePrefix) - fmt.Fprintf(outputFile, " x.beforeSave()\n") - scanFields(x, "", scanFunctions{zerovalue: emitZeroCheck}) - scanFields(x, "", scanFunctions{value: emitSaveValue}) - scanFields(x, "", scanFunctions{normal: emitSave, wait: emitSave}) - fmt.Fprintf(outputFile, "}\n\n") - } - - // Define afterLoad if a definition was not found. We do this for - // the same reason that we do it for beforeSave. - _, hasAfterLoad := simpleMethods[method{ts.Name.Name, "afterLoad"}] - if !hasAfterLoad && generateSaverLoader { - fmt.Fprintf(outputFile, "func (x *%s) afterLoad() {}\n\n", ts.Name.Name) - } - - // Generate the load method. - // - // N.B. See the comment above for the save method. - if generateSaverLoader { - fmt.Fprintf(outputFile, "func (x *%s) StateLoad(m %sSource) {\n", ts.Name.Name, statePrefix) - scanFields(x, "", scanFunctions{normal: emitLoad, wait: emitLoadWait}) - scanFields(x, "", scanFunctions{value: emitLoadValue}) - if hasAfterLoad { - // The call to afterLoad is made conditionally, because when - // AfterLoad is called, the object encodes a dependency on - // referred objects (i.e. fields). This means that afterLoad - // will not be called until the other afterLoads are called. - fmt.Fprintf(outputFile, " m.AfterLoad(x.afterLoad)\n") - } - fmt.Fprintf(outputFile, "}\n\n") - } - - // Add to our registration. - emitRegister(ts.Name.Name) - - case *ast.Ident, *ast.SelectorExpr, *ast.ArrayType: - maybeEmitImports() - - // Generate the info methods. - fmt.Fprintf(outputFile, "func (x *%s) StateTypeName() string {\n", ts.Name.Name) - fmt.Fprintf(outputFile, " return \"%s.%s\"\n", *fullPkg, ts.Name.Name) - fmt.Fprintf(outputFile, "}\n\n") - fmt.Fprintf(outputFile, "func (x *%s) StateFields() []string {\n", ts.Name.Name) - fmt.Fprintf(outputFile, " return nil\n") - fmt.Fprintf(outputFile, "}\n\n") - - // See above. - emitRegister(ts.Name.Name) - } - } - } - } - - if len(initCalls) > 0 { - // Emit the init() function. - fmt.Fprintf(outputFile, "func init() {\n") - for _, ic := range initCalls { - fmt.Fprintf(outputFile, " %s\n", ic) - } - fmt.Fprintf(outputFile, "}\n") - } -} |