diff options
author | Ian Lewis <ianmlewis@gmail.com> | 2020-08-17 21:44:31 -0400 |
---|---|---|
committer | Ian Lewis <ianmlewis@gmail.com> | 2020-08-17 21:44:31 -0400 |
commit | ac324f646ee3cb7955b0b45a7453aeb9671cbdf1 (patch) | |
tree | 0cbc5018e8807421d701d190dc20525726c7ca76 /tools/go_stateify/main.go | |
parent | 352ae1022ce19de28fc72e034cc469872ad79d06 (diff) | |
parent | 6d0c5803d557d453f15ac6f683697eeb46dab680 (diff) |
Merge branch 'master' into ip-forwarding
- Merges aleksej-paschenko's with HEAD
- Adds vfs2 support for ip_forward
Diffstat (limited to 'tools/go_stateify/main.go')
-rw-r--r-- | tools/go_stateify/main.go | 200 |
1 files changed, 129 insertions, 71 deletions
diff --git a/tools/go_stateify/main.go b/tools/go_stateify/main.go index db7a7107b..4f6ed208a 100644 --- a/tools/go_stateify/main.go +++ b/tools/go_stateify/main.go @@ -23,13 +23,16 @@ import ( "go/parser" "go/token" "os" + "path/filepath" "reflect" "strings" "sync" + + "gvisor.dev/gvisor/tools/tags" ) var ( - pkg = flag.String("pkg", "", "output package") + fullPkg = flag.String("fullpkg", "", "fully qualified output package") imports = flag.String("imports", "", "extra imports for the output file") output = flag.String("output", "", "output file") statePkg = flag.String("statepkg", "", "state import package; defaults to empty") @@ -100,7 +103,7 @@ type scanFunctions struct { // skipped if nil. // // Fields tagged nosave are skipped. -func scanFields(ss *ast.StructType, fn scanFunctions) { +func scanFields(ss *ast.StructType, prefix string, fn scanFunctions) { if ss.Fields.List == nil { // No fields. return @@ -124,7 +127,16 @@ func scanFields(ss *ast.StructType, fn scanFunctions) { continue } - switch tag := extractStateTag(field.Tag); tag { + // Is this a anonymous struct? If yes, then continue the + // recursion with the given prefix. We don't pay attention to + // any tags on the top-level struct field. + tag := extractStateTag(field.Tag) + if anon, ok := field.Type.(*ast.StructType); ok && tag == "" { + scanFields(anon, name+".", fn) + continue + } + + switch tag { case "zerovalue": if fn.zerovalue != nil { fn.zerovalue(name) @@ -168,7 +180,7 @@ func main() { flag.Usage() os.Exit(1) } - if *pkg == "" { + if *fullPkg == "" { fmt.Fprintf(os.Stderr, "Error: package required.") os.Exit(1) } @@ -198,33 +210,25 @@ func main() { // initCalls is dumped at the end. var initCalls []string - // Declare our emission closures. + // Common closures. emitRegister := func(name string) { - initCalls = append(initCalls, fmt.Sprintf("%sRegister(\"%s.%s\", (*%s)(nil), state.Fns{Save: (*%s).save, Load: (*%s).load})", statePrefix, *pkg, name, name, name, name)) + initCalls = append(initCalls, fmt.Sprintf("%sRegister((*%s)(nil))", statePrefix, name)) } emitZeroCheck := func(name string) { - fmt.Fprintf(outputFile, " if !%sIsZeroValue(x.%s) { m.Failf(\"%s is %%v, expected zero\", x.%s) }\n", statePrefix, name, name, name) - } - emitLoadValue := func(name, typName string) { - fmt.Fprintf(outputFile, " m.LoadValue(\"%s\", new(%s), func(y interface{}) { x.load%s(y.(%s)) })\n", name, typName, camelCased(name), typName) - } - emitLoad := func(name string) { - fmt.Fprintf(outputFile, " m.Load(\"%s\", &x.%s)\n", name, name) - } - emitLoadWait := func(name string) { - fmt.Fprintf(outputFile, " m.LoadWait(\"%s\", &x.%s)\n", name, name) - } - emitSaveValue := func(name, typName string) { - fmt.Fprintf(outputFile, " var %s %s = x.save%s()\n", name, typName, camelCased(name)) - fmt.Fprintf(outputFile, " m.SaveValue(\"%s\", %s)\n", name, name) + fmt.Fprintf(outputFile, " if !%sIsZeroValue(&x.%s) { %sFailf(\"%s is %%#v, expected zero\", &x.%s) }\n", statePrefix, name, statePrefix, name, name) } - emitSave := func(name string) { - fmt.Fprintf(outputFile, " m.Save(\"%s\", &x.%s)\n", name, name) + + // Automated warning. + fmt.Fprint(outputFile, "// automatically generated by stateify.\n\n") + + // Emit build tags. + if t := tags.Aggregate(flag.Args()); len(t) > 0 { + fmt.Fprintf(outputFile, "%s\n\n", strings.Join(t.Lines(), "\n")) } // Emit the package name. - fmt.Fprint(outputFile, "// automatically generated by stateify.\n\n") - fmt.Fprintf(outputFile, "package %s\n\n", *pkg) + _, pkg := filepath.Split(*fullPkg) + fmt.Fprintf(outputFile, "package %s\n\n", pkg) // Emit the imports lazily. var once sync.Once @@ -256,6 +260,7 @@ func main() { fmt.Fprintf(os.Stderr, "Input %q can't be parsed: %v\n", filename, err) os.Exit(1) } + files = append(files, f) } @@ -317,87 +322,140 @@ func main() { continue } - // Only generate code for types marked - // "// +stateify savable" in one of the proceeding - // comment lines. + // Only generate code for types marked "// +stateify + // savable" in one of the proceeding comment lines. If + // the line is marked "// +stateify type" then only + // generate type information and register the type. if d.Doc == nil { continue } - savable := false + var ( + generateTypeInfo = false + generateSaverLoader = false + ) for _, l := range d.Doc.List { if l.Text == "// +stateify savable" { - savable = true + generateTypeInfo = true + generateSaverLoader = true break } + if l.Text == "// +stateify type" { + generateTypeInfo = true + } } - if !savable { + if !generateTypeInfo && !generateSaverLoader { continue } for _, gs := range d.Specs { ts := gs.(*ast.TypeSpec) - switch ts.Type.(type) { - case *ast.InterfaceType, *ast.ChanType, *ast.FuncType, *ast.ParenExpr, *ast.StarExpr: - // Don't register. - break + switch x := ts.Type.(type) { case *ast.StructType: maybeEmitImports() - ss := ts.Type.(*ast.StructType) + // Record the slot for each field. + fieldCount := 0 + fields := make(map[string]int) + emitField := func(name string) { + fmt.Fprintf(outputFile, " \"%s\",\n", name) + fields[name] = fieldCount + fieldCount++ + } + emitFieldValue := func(name string, _ string) { + emitField(name) + } + emitLoadValue := func(name, typName string) { + fmt.Fprintf(outputFile, " m.LoadValue(%d, new(%s), func(y interface{}) { x.load%s(y.(%s)) })\n", fields[name], typName, camelCased(name), typName) + } + emitLoad := func(name string) { + fmt.Fprintf(outputFile, " m.Load(%d, &x.%s)\n", fields[name], name) + } + emitLoadWait := func(name string) { + fmt.Fprintf(outputFile, " m.LoadWait(%d, &x.%s)\n", fields[name], name) + } + emitSaveValue := func(name, typName string) { + fmt.Fprintf(outputFile, " var %s %s = x.save%s()\n", name, typName, camelCased(name)) + fmt.Fprintf(outputFile, " m.SaveValue(%d, %s)\n", fields[name], name) + } + emitSave := func(name string) { + fmt.Fprintf(outputFile, " m.Save(%d, &x.%s)\n", fields[name], name) + } + + // Generate the type name method. + fmt.Fprintf(outputFile, "func (x *%s) StateTypeName() string {\n", ts.Name.Name) + fmt.Fprintf(outputFile, " return \"%s.%s\"\n", *fullPkg, ts.Name.Name) + fmt.Fprintf(outputFile, "}\n\n") + + // Generate the fields method. + fmt.Fprintf(outputFile, "func (x *%s) StateFields() []string {\n", ts.Name.Name) + fmt.Fprintf(outputFile, " return []string{\n") + scanFields(x, "", scanFunctions{ + normal: emitField, + wait: emitField, + value: emitFieldValue, + }) + fmt.Fprintf(outputFile, " }\n") + fmt.Fprintf(outputFile, "}\n\n") - // Define beforeSave if a definition was not found. This - // prevents the code from compiling if a custom beforeSave - // was defined in a file not provided to this binary and - // prevents inherited methods from being called multiple times - // by overriding them. - if _, ok := simpleMethods[method{ts.Name.Name, "beforeSave"}]; !ok { - fmt.Fprintf(outputFile, "func (x *%s) beforeSave() {}\n", ts.Name.Name) + // Define beforeSave if a definition was not found. This prevents + // the code from compiling if a custom beforeSave was defined in a + // file not provided to this binary and prevents inherited methods + // from being called multiple times by overriding them. + if _, ok := simpleMethods[method{ts.Name.Name, "beforeSave"}]; !ok && generateSaverLoader { + fmt.Fprintf(outputFile, "func (x *%s) beforeSave() {}\n\n", ts.Name.Name) } // Generate the save method. - fmt.Fprintf(outputFile, "func (x *%s) save(m %sMap) {\n", ts.Name.Name, statePrefix) - fmt.Fprintf(outputFile, " x.beforeSave()\n") - scanFields(ss, scanFunctions{zerovalue: emitZeroCheck}) - scanFields(ss, scanFunctions{value: emitSaveValue}) - scanFields(ss, scanFunctions{normal: emitSave, wait: emitSave}) - fmt.Fprintf(outputFile, "}\n\n") + // + // N.B. For historical reasons, we perform the value saves first, + // and perform the value loads last. There should be no dependency + // on this specific behavior, but the ability to specify slots + // allows a manual implementation to be order-dependent. + if generateSaverLoader { + fmt.Fprintf(outputFile, "func (x *%s) StateSave(m %sSink) {\n", ts.Name.Name, statePrefix) + fmt.Fprintf(outputFile, " x.beforeSave()\n") + scanFields(x, "", scanFunctions{zerovalue: emitZeroCheck}) + scanFields(x, "", scanFunctions{value: emitSaveValue}) + scanFields(x, "", scanFunctions{normal: emitSave, wait: emitSave}) + fmt.Fprintf(outputFile, "}\n\n") + } - // Define afterLoad if a definition was not found. We do this - // for the same reason that we do it for beforeSave. + // Define afterLoad if a definition was not found. We do this for + // the same reason that we do it for beforeSave. _, hasAfterLoad := simpleMethods[method{ts.Name.Name, "afterLoad"}] - if !hasAfterLoad { - fmt.Fprintf(outputFile, "func (x *%s) afterLoad() {}\n", ts.Name.Name) + if !hasAfterLoad && generateSaverLoader { + fmt.Fprintf(outputFile, "func (x *%s) afterLoad() {}\n\n", ts.Name.Name) } // Generate the load method. // - // Note that the manual loads always follow the - // automated loads. - fmt.Fprintf(outputFile, "func (x *%s) load(m %sMap) {\n", ts.Name.Name, statePrefix) - scanFields(ss, scanFunctions{normal: emitLoad, wait: emitLoadWait}) - scanFields(ss, scanFunctions{value: emitLoadValue}) - if hasAfterLoad { - // The call to afterLoad is made conditionally, because when - // AfterLoad is called, the object encodes a dependency on - // referred objects (i.e. fields). This means that afterLoad - // will not be called until the other afterLoads are called. - fmt.Fprintf(outputFile, " m.AfterLoad(x.afterLoad)\n") + // N.B. See the comment above for the save method. + if generateSaverLoader { + fmt.Fprintf(outputFile, "func (x *%s) StateLoad(m %sSource) {\n", ts.Name.Name, statePrefix) + scanFields(x, "", scanFunctions{normal: emitLoad, wait: emitLoadWait}) + scanFields(x, "", scanFunctions{value: emitLoadValue}) + if hasAfterLoad { + // The call to afterLoad is made conditionally, because when + // AfterLoad is called, the object encodes a dependency on + // referred objects (i.e. fields). This means that afterLoad + // will not be called until the other afterLoads are called. + fmt.Fprintf(outputFile, " m.AfterLoad(x.afterLoad)\n") + } + fmt.Fprintf(outputFile, "}\n\n") } - fmt.Fprintf(outputFile, "}\n\n") // Add to our registration. emitRegister(ts.Name.Name) + case *ast.Ident, *ast.SelectorExpr, *ast.ArrayType: maybeEmitImports() - _, val := resolveTypeName(ts.Name.Name, ts.Type) - - // Dispatch directly. - fmt.Fprintf(outputFile, "func (x *%s) save(m %sMap) {\n", ts.Name.Name, statePrefix) - fmt.Fprintf(outputFile, " m.SaveValue(\"\", (%s)(*x))\n", val) + // Generate the info methods. + fmt.Fprintf(outputFile, "func (x *%s) StateTypeName() string {\n", ts.Name.Name) + fmt.Fprintf(outputFile, " return \"%s.%s\"\n", *fullPkg, ts.Name.Name) fmt.Fprintf(outputFile, "}\n\n") - fmt.Fprintf(outputFile, "func (x *%s) load(m %sMap) {\n", ts.Name.Name, statePrefix) - fmt.Fprintf(outputFile, " m.LoadValue(\"\", new(%s), func(y interface{}) { *x = (%s)(y.(%s)) })\n", val, ts.Name.Name, val) + fmt.Fprintf(outputFile, "func (x *%s) StateFields() []string {\n", ts.Name.Name) + fmt.Fprintf(outputFile, " return nil\n") fmt.Fprintf(outputFile, "}\n\n") // See above. |