diff options
-rw-r--r-- | tools/go_stateify/main.go | 109 |
1 files changed, 51 insertions, 58 deletions
diff --git a/tools/go_stateify/main.go b/tools/go_stateify/main.go index 4ec9fbf89..e1de12e25 100644 --- a/tools/go_stateify/main.go +++ b/tools/go_stateify/main.go @@ -39,7 +39,7 @@ var ( ) // resolveTypeName returns a qualified type name. -func resolveTypeName(name string, typ ast.Expr) (field string, qualified string) { +func resolveTypeName(typ ast.Expr) (field string, qualified string) { for done := false; !done; { // Resolve star expressions. switch rs := typ.(type) { @@ -69,11 +69,7 @@ func resolveTypeName(name string, typ ast.Expr) (field string, qualified string) } // Figure out actual type name. - ident, ok := typ.(*ast.Ident) - if !ok { - panic(fmt.Sprintf("type not supported: %s (involves anonymous types?)", name)) - } - field = ident.Name + field = typ.(*ast.Ident).Name qualified = qualified + field return } @@ -119,7 +115,7 @@ func scanFields(ss *ast.StructType, prefix string, fn scanFunctions) { } else { // Anonymous types can't be embedded, so we don't need // to worry about providing a useful name here. - name, _ = resolveTypeName("", field.Type) + name, _ = resolveTypeName(field.Type) } // Skip _ fields. @@ -262,52 +258,39 @@ func main() { } type method struct { - receiver string - name string + typeName string + methodName string } - // Search for and add all methods with a pointer receiver and no other - // arguments to a set. We support auto-detecting the existence of - // several different methods with this signature. - simpleMethods := map[method]struct{}{} + // Search for and add all method to a set. We auto-detecting several + // different methods (and insert them if we don't find them, in order + // to ensure that expectations match reality). + // + // While we do this, figure out the right receiver name. If there are + // multiple distinct receivers, then we will just pick the last one. + simpleMethods := make(map[method]struct{}) + receiverNames := make(map[string]string) for _, f := range files { - // Go over all functions. for _, decl := range f.Decls { d, ok := decl.(*ast.FuncDecl) if !ok { continue } - if d.Name == nil || d.Recv == nil || d.Type == nil { + if d.Recv == nil || len(d.Recv.List) != 1 { // Not a named method. continue } - if len(d.Recv.List) != 1 { - // Wrong number of receivers? - continue - } - if d.Type.Params != nil && len(d.Type.Params.List) != 0 { - // Has argument(s). - continue - } - if d.Type.Results != nil && len(d.Type.Results.List) != 0 { - // Has return(s). - continue - } - pt, ok := d.Recv.List[0].Type.(*ast.StarExpr) - if !ok { - // Not a pointer receiver. - continue - } - - t, ok := pt.X.(*ast.Ident) - if !ok { - // This shouldn't happen with valid Go. - continue + // Save the method and the receiver. + name, _ := resolveTypeName(d.Recv.List[0].Type) + simpleMethods[method{ + typeName: name, + methodName: d.Name.Name, + }] = struct{}{} + if len(d.Recv.List[0].Names) > 0 { + receiverNames[name] = d.Recv.List[0].Names[0].Name } - - simpleMethods[method{t.Name, d.Name.Name}] = struct{}{} } } @@ -346,7 +329,11 @@ func main() { for _, gs := range d.Specs { ts := gs.(*ast.TypeSpec) - letter := strings.ToLower(ts.Name.Name[:1]) + recv, ok := receiverNames[ts.Name.Name] + if !ok { + // Maybe no methods were defined? + recv = strings.ToLower(ts.Name.Name[:1]) + } switch x := ts.Type.(type) { case *ast.StructType: maybeEmitImports() @@ -363,32 +350,32 @@ func main() { emitField(name) } emitLoadValue := func(name, typName string) { - fmt.Fprintf(outputFile, " stateSourceObject.LoadValue(%d, new(%s), func(y interface{}) { %s.load%s(y.(%s)) })\n", fields[name], typName, letter, camelCased(name), typName) + fmt.Fprintf(outputFile, " stateSourceObject.LoadValue(%d, new(%s), func(y interface{}) { %s.load%s(y.(%s)) })\n", fields[name], typName, recv, camelCased(name), typName) } emitLoad := func(name string) { - fmt.Fprintf(outputFile, " stateSourceObject.Load(%d, &%s.%s)\n", fields[name], letter, name) + fmt.Fprintf(outputFile, " stateSourceObject.Load(%d, &%s.%s)\n", fields[name], recv, name) } emitLoadWait := func(name string) { - fmt.Fprintf(outputFile, " stateSourceObject.LoadWait(%d, &%s.%s)\n", fields[name], letter, name) + fmt.Fprintf(outputFile, " stateSourceObject.LoadWait(%d, &%s.%s)\n", fields[name], recv, name) } emitSaveValue := func(name, typName string) { - fmt.Fprintf(outputFile, " var %sValue %s = %s.save%s()\n", name, typName, letter, camelCased(name)) + fmt.Fprintf(outputFile, " var %sValue %s = %s.save%s()\n", name, typName, recv, camelCased(name)) fmt.Fprintf(outputFile, " stateSinkObject.SaveValue(%d, %sValue)\n", fields[name], name) } emitSave := func(name string) { - fmt.Fprintf(outputFile, " stateSinkObject.Save(%d, &%s.%s)\n", fields[name], letter, name) + fmt.Fprintf(outputFile, " stateSinkObject.Save(%d, &%s.%s)\n", fields[name], recv, name) } emitZeroCheck := func(name string) { - fmt.Fprintf(outputFile, " if !%sIsZeroValue(&%s.%s) { %sFailf(\"%s is %%#v, expected zero\", &%s.%s) }\n", statePrefix, letter, name, statePrefix, name, letter, name) + fmt.Fprintf(outputFile, " if !%sIsZeroValue(&%s.%s) { %sFailf(\"%s is %%#v, expected zero\", &%s.%s) }\n", statePrefix, recv, name, statePrefix, name, recv, name) } // Generate the type name method. - fmt.Fprintf(outputFile, "func (%s *%s) StateTypeName() string {\n", letter, ts.Name.Name) + fmt.Fprintf(outputFile, "func (%s *%s) StateTypeName() string {\n", recv, ts.Name.Name) fmt.Fprintf(outputFile, " return \"%s.%s\"\n", *fullPkg, ts.Name.Name) fmt.Fprintf(outputFile, "}\n\n") // Generate the fields method. - fmt.Fprintf(outputFile, "func (%s *%s) StateFields() []string {\n", letter, ts.Name.Name) + fmt.Fprintf(outputFile, "func (%s *%s) StateFields() []string {\n", recv, ts.Name.Name) fmt.Fprintf(outputFile, " return []string{\n") scanFields(x, "", scanFunctions{ normal: emitField, @@ -402,8 +389,11 @@ func main() { // the code from compiling if a custom beforeSave was defined in a // file not provided to this binary and prevents inherited methods // from being called multiple times by overriding them. - if _, ok := simpleMethods[method{ts.Name.Name, "beforeSave"}]; !ok && generateSaverLoader { - fmt.Fprintf(outputFile, "func (%s *%s) beforeSave() {}\n\n", letter, ts.Name.Name) + if _, ok := simpleMethods[method{ + typeName: ts.Name.Name, + methodName: "beforeSave", + }]; !ok && generateSaverLoader { + fmt.Fprintf(outputFile, "func (%s *%s) beforeSave() {}\n\n", recv, ts.Name.Name) } // Generate the save method. @@ -413,8 +403,8 @@ func main() { // on this specific behavior, but the ability to specify slots // allows a manual implementation to be order-dependent. if generateSaverLoader { - fmt.Fprintf(outputFile, "func (%s *%s) StateSave(stateSinkObject %sSink) {\n", letter, ts.Name.Name, statePrefix) - fmt.Fprintf(outputFile, " %s.beforeSave()\n", letter) + fmt.Fprintf(outputFile, "func (%s *%s) StateSave(stateSinkObject %sSink) {\n", recv, ts.Name.Name, statePrefix) + fmt.Fprintf(outputFile, " %s.beforeSave()\n", recv) scanFields(x, "", scanFunctions{zerovalue: emitZeroCheck}) scanFields(x, "", scanFunctions{value: emitSaveValue}) scanFields(x, "", scanFunctions{normal: emitSave, wait: emitSave}) @@ -423,16 +413,19 @@ func main() { // Define afterLoad if a definition was not found. We do this for // the same reason that we do it for beforeSave. - _, hasAfterLoad := simpleMethods[method{ts.Name.Name, "afterLoad"}] + _, hasAfterLoad := simpleMethods[method{ + typeName: ts.Name.Name, + methodName: "afterLoad", + }] if !hasAfterLoad && generateSaverLoader { - fmt.Fprintf(outputFile, "func (%s *%s) afterLoad() {}\n\n", letter, ts.Name.Name) + fmt.Fprintf(outputFile, "func (%s *%s) afterLoad() {}\n\n", recv, ts.Name.Name) } // Generate the load method. // // N.B. See the comment above for the save method. if generateSaverLoader { - fmt.Fprintf(outputFile, "func (%s *%s) StateLoad(stateSourceObject %sSource) {\n", letter, ts.Name.Name, statePrefix) + fmt.Fprintf(outputFile, "func (%s *%s) StateLoad(stateSourceObject %sSource) {\n", recv, ts.Name.Name, statePrefix) scanFields(x, "", scanFunctions{normal: emitLoad, wait: emitLoadWait}) scanFields(x, "", scanFunctions{value: emitLoadValue}) if hasAfterLoad { @@ -440,7 +433,7 @@ func main() { // AfterLoad is called, the object encodes a dependency on // referred objects (i.e. fields). This means that afterLoad // will not be called until the other afterLoads are called. - fmt.Fprintf(outputFile, " stateSourceObject.AfterLoad(%s.afterLoad)\n", letter) + fmt.Fprintf(outputFile, " stateSourceObject.AfterLoad(%s.afterLoad)\n", recv) } fmt.Fprintf(outputFile, "}\n\n") } @@ -452,10 +445,10 @@ func main() { maybeEmitImports() // Generate the info methods. - fmt.Fprintf(outputFile, "func (%s *%s) StateTypeName() string {\n", letter, ts.Name.Name) + fmt.Fprintf(outputFile, "func (%s *%s) StateTypeName() string {\n", recv, ts.Name.Name) fmt.Fprintf(outputFile, " return \"%s.%s\"\n", *fullPkg, ts.Name.Name) fmt.Fprintf(outputFile, "}\n\n") - fmt.Fprintf(outputFile, "func (%s *%s) StateFields() []string {\n", letter, ts.Name.Name) + fmt.Fprintf(outputFile, "func (%s *%s) StateFields() []string {\n", recv, ts.Name.Name) fmt.Fprintf(outputFile, " return nil\n") fmt.Fprintf(outputFile, "}\n\n") |