summaryrefslogtreecommitdiffhomepage
path: root/tools/go_stateify/main.go
diff options
context:
space:
mode:
Diffstat (limited to 'tools/go_stateify/main.go')
-rw-r--r--tools/go_stateify/main.go470
1 files changed, 0 insertions, 470 deletions
diff --git a/tools/go_stateify/main.go b/tools/go_stateify/main.go
deleted file mode 100644
index e1de12e25..000000000
--- a/tools/go_stateify/main.go
+++ /dev/null
@@ -1,470 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Stateify provides a simple way to generate Load/Save methods based on
-// existing types and struct tags.
-package main
-
-import (
- "flag"
- "fmt"
- "go/ast"
- "go/parser"
- "go/token"
- "os"
- "path/filepath"
- "reflect"
- "strings"
- "sync"
-
- "gvisor.dev/gvisor/tools/tags"
-)
-
-var (
- fullPkg = flag.String("fullpkg", "", "fully qualified output package")
- imports = flag.String("imports", "", "extra imports for the output file")
- output = flag.String("output", "", "output file")
- statePkg = flag.String("statepkg", "", "state import package; defaults to empty")
-)
-
-// resolveTypeName returns a qualified type name.
-func resolveTypeName(typ ast.Expr) (field string, qualified string) {
- for done := false; !done; {
- // Resolve star expressions.
- switch rs := typ.(type) {
- case *ast.StarExpr:
- qualified += "*"
- typ = rs.X
- case *ast.ArrayType:
- if rs.Len == nil {
- // Slice type declaration.
- qualified += "[]"
- } else {
- // Array type declaration.
- qualified += "[" + rs.Len.(*ast.BasicLit).Value + "]"
- }
- typ = rs.Elt
- default:
- // No more descent.
- done = true
- }
- }
-
- // Resolve a package selector.
- sel, ok := typ.(*ast.SelectorExpr)
- if ok {
- qualified = qualified + sel.X.(*ast.Ident).Name + "."
- typ = sel.Sel
- }
-
- // Figure out actual type name.
- field = typ.(*ast.Ident).Name
- qualified = qualified + field
- return
-}
-
-// extractStateTag pulls the relevant state tag.
-func extractStateTag(tag *ast.BasicLit) string {
- if tag == nil {
- return ""
- }
- if len(tag.Value) < 2 {
- return ""
- }
- return reflect.StructTag(tag.Value[1 : len(tag.Value)-1]).Get("state")
-}
-
-// scanFunctions is a set of functions passed to scanFields.
-type scanFunctions struct {
- zerovalue func(name string)
- normal func(name string)
- wait func(name string)
- value func(name, typName string)
-}
-
-// scanFields scans the fields of a struct.
-//
-// Each provided function will be applied to appropriately tagged fields, or
-// skipped if nil.
-//
-// Fields tagged nosave are skipped.
-func scanFields(ss *ast.StructType, prefix string, fn scanFunctions) {
- if ss.Fields.List == nil {
- // No fields.
- return
- }
-
- // Scan all fields.
- for _, field := range ss.Fields.List {
- // Calculate the name.
- name := ""
- if field.Names != nil {
- // It's a named field; override.
- name = field.Names[0].Name
- } else {
- // Anonymous types can't be embedded, so we don't need
- // to worry about providing a useful name here.
- name, _ = resolveTypeName(field.Type)
- }
-
- // Skip _ fields.
- if name == "_" {
- continue
- }
-
- // Is this a anonymous struct? If yes, then continue the
- // recursion with the given prefix. We don't pay attention to
- // any tags on the top-level struct field.
- tag := extractStateTag(field.Tag)
- if anon, ok := field.Type.(*ast.StructType); ok && tag == "" {
- scanFields(anon, name+".", fn)
- continue
- }
-
- switch tag {
- case "zerovalue":
- if fn.zerovalue != nil {
- fn.zerovalue(name)
- }
-
- case "":
- if fn.normal != nil {
- fn.normal(name)
- }
-
- case "wait":
- if fn.wait != nil {
- fn.wait(name)
- }
-
- case "manual", "nosave", "ignore":
- // Do nothing.
-
- default:
- if strings.HasPrefix(tag, ".(") && strings.HasSuffix(tag, ")") {
- if fn.value != nil {
- fn.value(name, tag[2:len(tag)-1])
- }
- }
- }
- }
-}
-
-func camelCased(name string) string {
- return strings.ToUpper(name[:1]) + name[1:]
-}
-
-func main() {
- // Parse flags.
- flag.Usage = func() {
- fmt.Fprintf(os.Stderr, "Usage: %s [options]\n", os.Args[0])
- flag.PrintDefaults()
- }
- flag.Parse()
- if len(flag.Args()) == 0 {
- flag.Usage()
- os.Exit(1)
- }
- if *fullPkg == "" {
- fmt.Fprintf(os.Stderr, "Error: package required.")
- os.Exit(1)
- }
-
- // Open the output file.
- var (
- outputFile *os.File
- err error
- )
- if *output == "" || *output == "-" {
- outputFile = os.Stdout
- } else {
- outputFile, err = os.OpenFile(*output, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
- if err != nil {
- fmt.Fprintf(os.Stderr, "Error opening output %q: %v", *output, err)
- }
- defer outputFile.Close()
- }
-
- // Set the statePrefix for below, depending on the import.
- statePrefix := ""
- if *statePkg != "" {
- parts := strings.Split(*statePkg, "/")
- statePrefix = parts[len(parts)-1] + "."
- }
-
- // initCalls is dumped at the end.
- var initCalls []string
-
- // Common closures.
- emitRegister := func(name string) {
- initCalls = append(initCalls, fmt.Sprintf("%sRegister((*%s)(nil))", statePrefix, name))
- }
-
- // Automated warning.
- fmt.Fprint(outputFile, "// automatically generated by stateify.\n\n")
-
- // Emit build tags.
- if t := tags.Aggregate(flag.Args()); len(t) > 0 {
- fmt.Fprintf(outputFile, "%s\n\n", strings.Join(t.Lines(), "\n"))
- }
-
- // Emit the package name.
- _, pkg := filepath.Split(*fullPkg)
- fmt.Fprintf(outputFile, "package %s\n\n", pkg)
-
- // Emit the imports lazily.
- var once sync.Once
- maybeEmitImports := func() {
- once.Do(func() {
- // Emit the imports.
- fmt.Fprint(outputFile, "import (\n")
- if *statePkg != "" {
- fmt.Fprintf(outputFile, " \"%s\"\n", *statePkg)
- }
- if *imports != "" {
- for _, i := range strings.Split(*imports, ",") {
- fmt.Fprintf(outputFile, " \"%s\"\n", i)
- }
- }
- fmt.Fprint(outputFile, ")\n\n")
- })
- }
-
- files := make([]*ast.File, 0, len(flag.Args()))
-
- // Parse the input files.
- for _, filename := range flag.Args() {
- // Parse the file.
- fset := token.NewFileSet()
- f, err := parser.ParseFile(fset, filename, nil, parser.ParseComments)
- if err != nil {
- // Not a valid input file?
- fmt.Fprintf(os.Stderr, "Input %q can't be parsed: %v\n", filename, err)
- os.Exit(1)
- }
-
- files = append(files, f)
- }
-
- type method struct {
- typeName string
- methodName string
- }
-
- // Search for and add all method to a set. We auto-detecting several
- // different methods (and insert them if we don't find them, in order
- // to ensure that expectations match reality).
- //
- // While we do this, figure out the right receiver name. If there are
- // multiple distinct receivers, then we will just pick the last one.
- simpleMethods := make(map[method]struct{})
- receiverNames := make(map[string]string)
- for _, f := range files {
- // Go over all functions.
- for _, decl := range f.Decls {
- d, ok := decl.(*ast.FuncDecl)
- if !ok {
- continue
- }
- if d.Recv == nil || len(d.Recv.List) != 1 {
- // Not a named method.
- continue
- }
-
- // Save the method and the receiver.
- name, _ := resolveTypeName(d.Recv.List[0].Type)
- simpleMethods[method{
- typeName: name,
- methodName: d.Name.Name,
- }] = struct{}{}
- if len(d.Recv.List[0].Names) > 0 {
- receiverNames[name] = d.Recv.List[0].Names[0].Name
- }
- }
- }
-
- for _, f := range files {
- // Go over all named types.
- for _, decl := range f.Decls {
- d, ok := decl.(*ast.GenDecl)
- if !ok || d.Tok != token.TYPE {
- continue
- }
-
- // Only generate code for types marked "// +stateify
- // savable" in one of the proceeding comment lines. If
- // the line is marked "// +stateify type" then only
- // generate type information and register the type.
- if d.Doc == nil {
- continue
- }
- var (
- generateTypeInfo = false
- generateSaverLoader = false
- )
- for _, l := range d.Doc.List {
- if l.Text == "// +stateify savable" {
- generateTypeInfo = true
- generateSaverLoader = true
- break
- }
- if l.Text == "// +stateify type" {
- generateTypeInfo = true
- }
- }
- if !generateTypeInfo && !generateSaverLoader {
- continue
- }
-
- for _, gs := range d.Specs {
- ts := gs.(*ast.TypeSpec)
- recv, ok := receiverNames[ts.Name.Name]
- if !ok {
- // Maybe no methods were defined?
- recv = strings.ToLower(ts.Name.Name[:1])
- }
- switch x := ts.Type.(type) {
- case *ast.StructType:
- maybeEmitImports()
-
- // Record the slot for each field.
- fieldCount := 0
- fields := make(map[string]int)
- emitField := func(name string) {
- fmt.Fprintf(outputFile, " \"%s\",\n", name)
- fields[name] = fieldCount
- fieldCount++
- }
- emitFieldValue := func(name string, _ string) {
- emitField(name)
- }
- emitLoadValue := func(name, typName string) {
- fmt.Fprintf(outputFile, " stateSourceObject.LoadValue(%d, new(%s), func(y interface{}) { %s.load%s(y.(%s)) })\n", fields[name], typName, recv, camelCased(name), typName)
- }
- emitLoad := func(name string) {
- fmt.Fprintf(outputFile, " stateSourceObject.Load(%d, &%s.%s)\n", fields[name], recv, name)
- }
- emitLoadWait := func(name string) {
- fmt.Fprintf(outputFile, " stateSourceObject.LoadWait(%d, &%s.%s)\n", fields[name], recv, name)
- }
- emitSaveValue := func(name, typName string) {
- fmt.Fprintf(outputFile, " var %sValue %s = %s.save%s()\n", name, typName, recv, camelCased(name))
- fmt.Fprintf(outputFile, " stateSinkObject.SaveValue(%d, %sValue)\n", fields[name], name)
- }
- emitSave := func(name string) {
- fmt.Fprintf(outputFile, " stateSinkObject.Save(%d, &%s.%s)\n", fields[name], recv, name)
- }
- emitZeroCheck := func(name string) {
- fmt.Fprintf(outputFile, " if !%sIsZeroValue(&%s.%s) { %sFailf(\"%s is %%#v, expected zero\", &%s.%s) }\n", statePrefix, recv, name, statePrefix, name, recv, name)
- }
-
- // Generate the type name method.
- fmt.Fprintf(outputFile, "func (%s *%s) StateTypeName() string {\n", recv, ts.Name.Name)
- fmt.Fprintf(outputFile, " return \"%s.%s\"\n", *fullPkg, ts.Name.Name)
- fmt.Fprintf(outputFile, "}\n\n")
-
- // Generate the fields method.
- fmt.Fprintf(outputFile, "func (%s *%s) StateFields() []string {\n", recv, ts.Name.Name)
- fmt.Fprintf(outputFile, " return []string{\n")
- scanFields(x, "", scanFunctions{
- normal: emitField,
- wait: emitField,
- value: emitFieldValue,
- })
- fmt.Fprintf(outputFile, " }\n")
- fmt.Fprintf(outputFile, "}\n\n")
-
- // Define beforeSave if a definition was not found. This prevents
- // the code from compiling if a custom beforeSave was defined in a
- // file not provided to this binary and prevents inherited methods
- // from being called multiple times by overriding them.
- if _, ok := simpleMethods[method{
- typeName: ts.Name.Name,
- methodName: "beforeSave",
- }]; !ok && generateSaverLoader {
- fmt.Fprintf(outputFile, "func (%s *%s) beforeSave() {}\n\n", recv, ts.Name.Name)
- }
-
- // Generate the save method.
- //
- // N.B. For historical reasons, we perform the value saves first,
- // and perform the value loads last. There should be no dependency
- // on this specific behavior, but the ability to specify slots
- // allows a manual implementation to be order-dependent.
- if generateSaverLoader {
- fmt.Fprintf(outputFile, "func (%s *%s) StateSave(stateSinkObject %sSink) {\n", recv, ts.Name.Name, statePrefix)
- fmt.Fprintf(outputFile, " %s.beforeSave()\n", recv)
- scanFields(x, "", scanFunctions{zerovalue: emitZeroCheck})
- scanFields(x, "", scanFunctions{value: emitSaveValue})
- scanFields(x, "", scanFunctions{normal: emitSave, wait: emitSave})
- fmt.Fprintf(outputFile, "}\n\n")
- }
-
- // Define afterLoad if a definition was not found. We do this for
- // the same reason that we do it for beforeSave.
- _, hasAfterLoad := simpleMethods[method{
- typeName: ts.Name.Name,
- methodName: "afterLoad",
- }]
- if !hasAfterLoad && generateSaverLoader {
- fmt.Fprintf(outputFile, "func (%s *%s) afterLoad() {}\n\n", recv, ts.Name.Name)
- }
-
- // Generate the load method.
- //
- // N.B. See the comment above for the save method.
- if generateSaverLoader {
- fmt.Fprintf(outputFile, "func (%s *%s) StateLoad(stateSourceObject %sSource) {\n", recv, ts.Name.Name, statePrefix)
- scanFields(x, "", scanFunctions{normal: emitLoad, wait: emitLoadWait})
- scanFields(x, "", scanFunctions{value: emitLoadValue})
- if hasAfterLoad {
- // The call to afterLoad is made conditionally, because when
- // AfterLoad is called, the object encodes a dependency on
- // referred objects (i.e. fields). This means that afterLoad
- // will not be called until the other afterLoads are called.
- fmt.Fprintf(outputFile, " stateSourceObject.AfterLoad(%s.afterLoad)\n", recv)
- }
- fmt.Fprintf(outputFile, "}\n\n")
- }
-
- // Add to our registration.
- emitRegister(ts.Name.Name)
-
- case *ast.Ident, *ast.SelectorExpr, *ast.ArrayType:
- maybeEmitImports()
-
- // Generate the info methods.
- fmt.Fprintf(outputFile, "func (%s *%s) StateTypeName() string {\n", recv, ts.Name.Name)
- fmt.Fprintf(outputFile, " return \"%s.%s\"\n", *fullPkg, ts.Name.Name)
- fmt.Fprintf(outputFile, "}\n\n")
- fmt.Fprintf(outputFile, "func (%s *%s) StateFields() []string {\n", recv, ts.Name.Name)
- fmt.Fprintf(outputFile, " return nil\n")
- fmt.Fprintf(outputFile, "}\n\n")
-
- // See above.
- emitRegister(ts.Name.Name)
- }
- }
- }
- }
-
- if len(initCalls) > 0 {
- // Emit the init() function.
- fmt.Fprintf(outputFile, "func init() {\n")
- for _, ic := range initCalls {
- fmt.Fprintf(outputFile, " %s\n", ic)
- }
- fmt.Fprintf(outputFile, "}\n")
- }
-}