summaryrefslogtreecommitdiffhomepage
path: root/tools/go_stateify
diff options
context:
space:
mode:
Diffstat (limited to 'tools/go_stateify')
-rw-r--r--tools/go_stateify/BUILD8
-rw-r--r--tools/go_stateify/main.go182
2 files changed, 121 insertions, 69 deletions
diff --git a/tools/go_stateify/BUILD b/tools/go_stateify/BUILD
index 503cdf2e5..913558b4e 100644
--- a/tools/go_stateify/BUILD
+++ b/tools/go_stateify/BUILD
@@ -1,4 +1,4 @@
-load("//tools:defs.bzl", "go_binary")
+load("//tools:defs.bzl", "bzl_library", "go_binary")
package(licenses = ["notice"])
@@ -8,3 +8,9 @@ go_binary(
visibility = ["//:sandbox"],
deps = ["//tools/tags"],
)
+
+bzl_library(
+ name = "defs_bzl",
+ srcs = ["defs.bzl"],
+ visibility = ["//visibility:private"],
+)
diff --git a/tools/go_stateify/main.go b/tools/go_stateify/main.go
index 309ee9c21..4f6ed208a 100644
--- a/tools/go_stateify/main.go
+++ b/tools/go_stateify/main.go
@@ -103,7 +103,7 @@ type scanFunctions struct {
// skipped if nil.
//
// Fields tagged nosave are skipped.
-func scanFields(ss *ast.StructType, fn scanFunctions) {
+func scanFields(ss *ast.StructType, prefix string, fn scanFunctions) {
if ss.Fields.List == nil {
// No fields.
return
@@ -127,7 +127,16 @@ func scanFields(ss *ast.StructType, fn scanFunctions) {
continue
}
- switch tag := extractStateTag(field.Tag); tag {
+ // Is this a anonymous struct? If yes, then continue the
+ // recursion with the given prefix. We don't pay attention to
+ // any tags on the top-level struct field.
+ tag := extractStateTag(field.Tag)
+ if anon, ok := field.Type.(*ast.StructType); ok && tag == "" {
+ scanFields(anon, name+".", fn)
+ continue
+ }
+
+ switch tag {
case "zerovalue":
if fn.zerovalue != nil {
fn.zerovalue(name)
@@ -201,28 +210,12 @@ func main() {
// initCalls is dumped at the end.
var initCalls []string
- // Declare our emission closures.
+ // Common closures.
emitRegister := func(name string) {
- initCalls = append(initCalls, fmt.Sprintf("%sRegister(\"%s.%s\", (*%s)(nil), state.Fns{Save: (*%s).save, Load: (*%s).load})", statePrefix, *fullPkg, name, name, name, name))
+ initCalls = append(initCalls, fmt.Sprintf("%sRegister((*%s)(nil))", statePrefix, name))
}
emitZeroCheck := func(name string) {
- fmt.Fprintf(outputFile, " if !%sIsZeroValue(&x.%s) { m.Failf(\"%s is %%#v, expected zero\", &x.%s) }\n", statePrefix, name, name, name)
- }
- emitLoadValue := func(name, typName string) {
- fmt.Fprintf(outputFile, " m.LoadValue(\"%s\", new(%s), func(y interface{}) { x.load%s(y.(%s)) })\n", name, typName, camelCased(name), typName)
- }
- emitLoad := func(name string) {
- fmt.Fprintf(outputFile, " m.Load(\"%s\", &x.%s)\n", name, name)
- }
- emitLoadWait := func(name string) {
- fmt.Fprintf(outputFile, " m.LoadWait(\"%s\", &x.%s)\n", name, name)
- }
- emitSaveValue := func(name, typName string) {
- fmt.Fprintf(outputFile, " var %s %s = x.save%s()\n", name, typName, camelCased(name))
- fmt.Fprintf(outputFile, " m.SaveValue(\"%s\", %s)\n", name, name)
- }
- emitSave := func(name string) {
- fmt.Fprintf(outputFile, " m.Save(\"%s\", &x.%s)\n", name, name)
+ fmt.Fprintf(outputFile, " if !%sIsZeroValue(&x.%s) { %sFailf(\"%s is %%#v, expected zero\", &x.%s) }\n", statePrefix, name, statePrefix, name, name)
}
// Automated warning.
@@ -329,87 +322,140 @@ func main() {
continue
}
- // Only generate code for types marked
- // "// +stateify savable" in one of the proceeding
- // comment lines.
+ // Only generate code for types marked "// +stateify
+ // savable" in one of the proceeding comment lines. If
+ // the line is marked "// +stateify type" then only
+ // generate type information and register the type.
if d.Doc == nil {
continue
}
- savable := false
+ var (
+ generateTypeInfo = false
+ generateSaverLoader = false
+ )
for _, l := range d.Doc.List {
if l.Text == "// +stateify savable" {
- savable = true
+ generateTypeInfo = true
+ generateSaverLoader = true
break
}
+ if l.Text == "// +stateify type" {
+ generateTypeInfo = true
+ }
}
- if !savable {
+ if !generateTypeInfo && !generateSaverLoader {
continue
}
for _, gs := range d.Specs {
ts := gs.(*ast.TypeSpec)
- switch ts.Type.(type) {
- case *ast.InterfaceType, *ast.ChanType, *ast.FuncType, *ast.ParenExpr, *ast.StarExpr:
- // Don't register.
- break
+ switch x := ts.Type.(type) {
case *ast.StructType:
maybeEmitImports()
- ss := ts.Type.(*ast.StructType)
+ // Record the slot for each field.
+ fieldCount := 0
+ fields := make(map[string]int)
+ emitField := func(name string) {
+ fmt.Fprintf(outputFile, " \"%s\",\n", name)
+ fields[name] = fieldCount
+ fieldCount++
+ }
+ emitFieldValue := func(name string, _ string) {
+ emitField(name)
+ }
+ emitLoadValue := func(name, typName string) {
+ fmt.Fprintf(outputFile, " m.LoadValue(%d, new(%s), func(y interface{}) { x.load%s(y.(%s)) })\n", fields[name], typName, camelCased(name), typName)
+ }
+ emitLoad := func(name string) {
+ fmt.Fprintf(outputFile, " m.Load(%d, &x.%s)\n", fields[name], name)
+ }
+ emitLoadWait := func(name string) {
+ fmt.Fprintf(outputFile, " m.LoadWait(%d, &x.%s)\n", fields[name], name)
+ }
+ emitSaveValue := func(name, typName string) {
+ fmt.Fprintf(outputFile, " var %s %s = x.save%s()\n", name, typName, camelCased(name))
+ fmt.Fprintf(outputFile, " m.SaveValue(%d, %s)\n", fields[name], name)
+ }
+ emitSave := func(name string) {
+ fmt.Fprintf(outputFile, " m.Save(%d, &x.%s)\n", fields[name], name)
+ }
+
+ // Generate the type name method.
+ fmt.Fprintf(outputFile, "func (x *%s) StateTypeName() string {\n", ts.Name.Name)
+ fmt.Fprintf(outputFile, " return \"%s.%s\"\n", *fullPkg, ts.Name.Name)
+ fmt.Fprintf(outputFile, "}\n\n")
+
+ // Generate the fields method.
+ fmt.Fprintf(outputFile, "func (x *%s) StateFields() []string {\n", ts.Name.Name)
+ fmt.Fprintf(outputFile, " return []string{\n")
+ scanFields(x, "", scanFunctions{
+ normal: emitField,
+ wait: emitField,
+ value: emitFieldValue,
+ })
+ fmt.Fprintf(outputFile, " }\n")
+ fmt.Fprintf(outputFile, "}\n\n")
- // Define beforeSave if a definition was not found. This
- // prevents the code from compiling if a custom beforeSave
- // was defined in a file not provided to this binary and
- // prevents inherited methods from being called multiple times
- // by overriding them.
- if _, ok := simpleMethods[method{ts.Name.Name, "beforeSave"}]; !ok {
- fmt.Fprintf(outputFile, "func (x *%s) beforeSave() {}\n", ts.Name.Name)
+ // Define beforeSave if a definition was not found. This prevents
+ // the code from compiling if a custom beforeSave was defined in a
+ // file not provided to this binary and prevents inherited methods
+ // from being called multiple times by overriding them.
+ if _, ok := simpleMethods[method{ts.Name.Name, "beforeSave"}]; !ok && generateSaverLoader {
+ fmt.Fprintf(outputFile, "func (x *%s) beforeSave() {}\n\n", ts.Name.Name)
}
// Generate the save method.
- fmt.Fprintf(outputFile, "func (x *%s) save(m %sMap) {\n", ts.Name.Name, statePrefix)
- fmt.Fprintf(outputFile, " x.beforeSave()\n")
- scanFields(ss, scanFunctions{zerovalue: emitZeroCheck})
- scanFields(ss, scanFunctions{value: emitSaveValue})
- scanFields(ss, scanFunctions{normal: emitSave, wait: emitSave})
- fmt.Fprintf(outputFile, "}\n\n")
+ //
+ // N.B. For historical reasons, we perform the value saves first,
+ // and perform the value loads last. There should be no dependency
+ // on this specific behavior, but the ability to specify slots
+ // allows a manual implementation to be order-dependent.
+ if generateSaverLoader {
+ fmt.Fprintf(outputFile, "func (x *%s) StateSave(m %sSink) {\n", ts.Name.Name, statePrefix)
+ fmt.Fprintf(outputFile, " x.beforeSave()\n")
+ scanFields(x, "", scanFunctions{zerovalue: emitZeroCheck})
+ scanFields(x, "", scanFunctions{value: emitSaveValue})
+ scanFields(x, "", scanFunctions{normal: emitSave, wait: emitSave})
+ fmt.Fprintf(outputFile, "}\n\n")
+ }
- // Define afterLoad if a definition was not found. We do this
- // for the same reason that we do it for beforeSave.
+ // Define afterLoad if a definition was not found. We do this for
+ // the same reason that we do it for beforeSave.
_, hasAfterLoad := simpleMethods[method{ts.Name.Name, "afterLoad"}]
- if !hasAfterLoad {
- fmt.Fprintf(outputFile, "func (x *%s) afterLoad() {}\n", ts.Name.Name)
+ if !hasAfterLoad && generateSaverLoader {
+ fmt.Fprintf(outputFile, "func (x *%s) afterLoad() {}\n\n", ts.Name.Name)
}
// Generate the load method.
//
- // Note that the manual loads always follow the
- // automated loads.
- fmt.Fprintf(outputFile, "func (x *%s) load(m %sMap) {\n", ts.Name.Name, statePrefix)
- scanFields(ss, scanFunctions{normal: emitLoad, wait: emitLoadWait})
- scanFields(ss, scanFunctions{value: emitLoadValue})
- if hasAfterLoad {
- // The call to afterLoad is made conditionally, because when
- // AfterLoad is called, the object encodes a dependency on
- // referred objects (i.e. fields). This means that afterLoad
- // will not be called until the other afterLoads are called.
- fmt.Fprintf(outputFile, " m.AfterLoad(x.afterLoad)\n")
+ // N.B. See the comment above for the save method.
+ if generateSaverLoader {
+ fmt.Fprintf(outputFile, "func (x *%s) StateLoad(m %sSource) {\n", ts.Name.Name, statePrefix)
+ scanFields(x, "", scanFunctions{normal: emitLoad, wait: emitLoadWait})
+ scanFields(x, "", scanFunctions{value: emitLoadValue})
+ if hasAfterLoad {
+ // The call to afterLoad is made conditionally, because when
+ // AfterLoad is called, the object encodes a dependency on
+ // referred objects (i.e. fields). This means that afterLoad
+ // will not be called until the other afterLoads are called.
+ fmt.Fprintf(outputFile, " m.AfterLoad(x.afterLoad)\n")
+ }
+ fmt.Fprintf(outputFile, "}\n\n")
}
- fmt.Fprintf(outputFile, "}\n\n")
// Add to our registration.
emitRegister(ts.Name.Name)
+
case *ast.Ident, *ast.SelectorExpr, *ast.ArrayType:
maybeEmitImports()
- _, val := resolveTypeName(ts.Name.Name, ts.Type)
-
- // Dispatch directly.
- fmt.Fprintf(outputFile, "func (x *%s) save(m %sMap) {\n", ts.Name.Name, statePrefix)
- fmt.Fprintf(outputFile, " m.SaveValue(\"\", (%s)(*x))\n", val)
+ // Generate the info methods.
+ fmt.Fprintf(outputFile, "func (x *%s) StateTypeName() string {\n", ts.Name.Name)
+ fmt.Fprintf(outputFile, " return \"%s.%s\"\n", *fullPkg, ts.Name.Name)
fmt.Fprintf(outputFile, "}\n\n")
- fmt.Fprintf(outputFile, "func (x *%s) load(m %sMap) {\n", ts.Name.Name, statePrefix)
- fmt.Fprintf(outputFile, " m.LoadValue(\"\", new(%s), func(y interface{}) { *x = (%s)(y.(%s)) })\n", val, ts.Name.Name, val)
+ fmt.Fprintf(outputFile, "func (x *%s) StateFields() []string {\n", ts.Name.Name)
+ fmt.Fprintf(outputFile, " return nil\n")
fmt.Fprintf(outputFile, "}\n\n")
// See above.