summaryrefslogtreecommitdiffhomepage
path: root/tools/go_stateify/main.go
diff options
context:
space:
mode:
Diffstat (limited to 'tools/go_stateify/main.go')
-rw-r--r--tools/go_stateify/main.go39
1 files changed, 20 insertions, 19 deletions
diff --git a/tools/go_stateify/main.go b/tools/go_stateify/main.go
index 4f6ed208a..4ec9fbf89 100644
--- a/tools/go_stateify/main.go
+++ b/tools/go_stateify/main.go
@@ -214,9 +214,6 @@ func main() {
emitRegister := func(name string) {
initCalls = append(initCalls, fmt.Sprintf("%sRegister((*%s)(nil))", statePrefix, name))
}
- emitZeroCheck := func(name string) {
- fmt.Fprintf(outputFile, " if !%sIsZeroValue(&x.%s) { %sFailf(\"%s is %%#v, expected zero\", &x.%s) }\n", statePrefix, name, statePrefix, name, name)
- }
// Automated warning.
fmt.Fprint(outputFile, "// automatically generated by stateify.\n\n")
@@ -349,6 +346,7 @@ func main() {
for _, gs := range d.Specs {
ts := gs.(*ast.TypeSpec)
+ letter := strings.ToLower(ts.Name.Name[:1])
switch x := ts.Type.(type) {
case *ast.StructType:
maybeEmitImports()
@@ -365,29 +363,32 @@ func main() {
emitField(name)
}
emitLoadValue := func(name, typName string) {
- fmt.Fprintf(outputFile, " m.LoadValue(%d, new(%s), func(y interface{}) { x.load%s(y.(%s)) })\n", fields[name], typName, camelCased(name), typName)
+ fmt.Fprintf(outputFile, " stateSourceObject.LoadValue(%d, new(%s), func(y interface{}) { %s.load%s(y.(%s)) })\n", fields[name], typName, letter, camelCased(name), typName)
}
emitLoad := func(name string) {
- fmt.Fprintf(outputFile, " m.Load(%d, &x.%s)\n", fields[name], name)
+ fmt.Fprintf(outputFile, " stateSourceObject.Load(%d, &%s.%s)\n", fields[name], letter, name)
}
emitLoadWait := func(name string) {
- fmt.Fprintf(outputFile, " m.LoadWait(%d, &x.%s)\n", fields[name], name)
+ fmt.Fprintf(outputFile, " stateSourceObject.LoadWait(%d, &%s.%s)\n", fields[name], letter, name)
}
emitSaveValue := func(name, typName string) {
- fmt.Fprintf(outputFile, " var %s %s = x.save%s()\n", name, typName, camelCased(name))
- fmt.Fprintf(outputFile, " m.SaveValue(%d, %s)\n", fields[name], name)
+ fmt.Fprintf(outputFile, " var %sValue %s = %s.save%s()\n", name, typName, letter, camelCased(name))
+ fmt.Fprintf(outputFile, " stateSinkObject.SaveValue(%d, %sValue)\n", fields[name], name)
}
emitSave := func(name string) {
- fmt.Fprintf(outputFile, " m.Save(%d, &x.%s)\n", fields[name], name)
+ fmt.Fprintf(outputFile, " stateSinkObject.Save(%d, &%s.%s)\n", fields[name], letter, name)
+ }
+ emitZeroCheck := func(name string) {
+ fmt.Fprintf(outputFile, " if !%sIsZeroValue(&%s.%s) { %sFailf(\"%s is %%#v, expected zero\", &%s.%s) }\n", statePrefix, letter, name, statePrefix, name, letter, name)
}
// Generate the type name method.
- fmt.Fprintf(outputFile, "func (x *%s) StateTypeName() string {\n", ts.Name.Name)
+ fmt.Fprintf(outputFile, "func (%s *%s) StateTypeName() string {\n", letter, ts.Name.Name)
fmt.Fprintf(outputFile, " return \"%s.%s\"\n", *fullPkg, ts.Name.Name)
fmt.Fprintf(outputFile, "}\n\n")
// Generate the fields method.
- fmt.Fprintf(outputFile, "func (x *%s) StateFields() []string {\n", ts.Name.Name)
+ fmt.Fprintf(outputFile, "func (%s *%s) StateFields() []string {\n", letter, ts.Name.Name)
fmt.Fprintf(outputFile, " return []string{\n")
scanFields(x, "", scanFunctions{
normal: emitField,
@@ -402,7 +403,7 @@ func main() {
// file not provided to this binary and prevents inherited methods
// from being called multiple times by overriding them.
if _, ok := simpleMethods[method{ts.Name.Name, "beforeSave"}]; !ok && generateSaverLoader {
- fmt.Fprintf(outputFile, "func (x *%s) beforeSave() {}\n\n", ts.Name.Name)
+ fmt.Fprintf(outputFile, "func (%s *%s) beforeSave() {}\n\n", letter, ts.Name.Name)
}
// Generate the save method.
@@ -412,8 +413,8 @@ func main() {
// on this specific behavior, but the ability to specify slots
// allows a manual implementation to be order-dependent.
if generateSaverLoader {
- fmt.Fprintf(outputFile, "func (x *%s) StateSave(m %sSink) {\n", ts.Name.Name, statePrefix)
- fmt.Fprintf(outputFile, " x.beforeSave()\n")
+ fmt.Fprintf(outputFile, "func (%s *%s) StateSave(stateSinkObject %sSink) {\n", letter, ts.Name.Name, statePrefix)
+ fmt.Fprintf(outputFile, " %s.beforeSave()\n", letter)
scanFields(x, "", scanFunctions{zerovalue: emitZeroCheck})
scanFields(x, "", scanFunctions{value: emitSaveValue})
scanFields(x, "", scanFunctions{normal: emitSave, wait: emitSave})
@@ -424,14 +425,14 @@ func main() {
// the same reason that we do it for beforeSave.
_, hasAfterLoad := simpleMethods[method{ts.Name.Name, "afterLoad"}]
if !hasAfterLoad && generateSaverLoader {
- fmt.Fprintf(outputFile, "func (x *%s) afterLoad() {}\n\n", ts.Name.Name)
+ fmt.Fprintf(outputFile, "func (%s *%s) afterLoad() {}\n\n", letter, ts.Name.Name)
}
// Generate the load method.
//
// N.B. See the comment above for the save method.
if generateSaverLoader {
- fmt.Fprintf(outputFile, "func (x *%s) StateLoad(m %sSource) {\n", ts.Name.Name, statePrefix)
+ fmt.Fprintf(outputFile, "func (%s *%s) StateLoad(stateSourceObject %sSource) {\n", letter, ts.Name.Name, statePrefix)
scanFields(x, "", scanFunctions{normal: emitLoad, wait: emitLoadWait})
scanFields(x, "", scanFunctions{value: emitLoadValue})
if hasAfterLoad {
@@ -439,7 +440,7 @@ func main() {
// AfterLoad is called, the object encodes a dependency on
// referred objects (i.e. fields). This means that afterLoad
// will not be called until the other afterLoads are called.
- fmt.Fprintf(outputFile, " m.AfterLoad(x.afterLoad)\n")
+ fmt.Fprintf(outputFile, " stateSourceObject.AfterLoad(%s.afterLoad)\n", letter)
}
fmt.Fprintf(outputFile, "}\n\n")
}
@@ -451,10 +452,10 @@ func main() {
maybeEmitImports()
// Generate the info methods.
- fmt.Fprintf(outputFile, "func (x *%s) StateTypeName() string {\n", ts.Name.Name)
+ fmt.Fprintf(outputFile, "func (%s *%s) StateTypeName() string {\n", letter, ts.Name.Name)
fmt.Fprintf(outputFile, " return \"%s.%s\"\n", *fullPkg, ts.Name.Name)
fmt.Fprintf(outputFile, "}\n\n")
- fmt.Fprintf(outputFile, "func (x *%s) StateFields() []string {\n", ts.Name.Name)
+ fmt.Fprintf(outputFile, "func (%s *%s) StateFields() []string {\n", letter, ts.Name.Name)
fmt.Fprintf(outputFile, " return nil\n")
fmt.Fprintf(outputFile, "}\n\n")