summaryrefslogtreecommitdiffhomepage
path: root/tools/go_stateify
diff options
context:
space:
mode:
authorAdin Scannell <ascannell@google.com>2020-10-09 12:11:21 -0700
committergVisor bot <gvisor-bot@google.com>2020-10-09 12:14:28 -0700
commit743327817faa1aa46ff3b31f74a0c5c2d047d65a (patch)
treea48f2f440fa3d22c7a124c64dd3d4b8a2f07d9cb /tools/go_stateify
parent257703c050e5901aeb3734f200f5a6b41856b4d9 (diff)
Infer receiver name for stateify.
PiperOrigin-RevId: 336340035
Diffstat (limited to 'tools/go_stateify')
-rw-r--r--tools/go_stateify/main.go109
1 files changed, 51 insertions, 58 deletions
diff --git a/tools/go_stateify/main.go b/tools/go_stateify/main.go
index 4ec9fbf89..e1de12e25 100644
--- a/tools/go_stateify/main.go
+++ b/tools/go_stateify/main.go
@@ -39,7 +39,7 @@ var (
)
// resolveTypeName returns a qualified type name.
-func resolveTypeName(name string, typ ast.Expr) (field string, qualified string) {
+func resolveTypeName(typ ast.Expr) (field string, qualified string) {
for done := false; !done; {
// Resolve star expressions.
switch rs := typ.(type) {
@@ -69,11 +69,7 @@ func resolveTypeName(name string, typ ast.Expr) (field string, qualified string)
}
// Figure out actual type name.
- ident, ok := typ.(*ast.Ident)
- if !ok {
- panic(fmt.Sprintf("type not supported: %s (involves anonymous types?)", name))
- }
- field = ident.Name
+ field = typ.(*ast.Ident).Name
qualified = qualified + field
return
}
@@ -119,7 +115,7 @@ func scanFields(ss *ast.StructType, prefix string, fn scanFunctions) {
} else {
// Anonymous types can't be embedded, so we don't need
// to worry about providing a useful name here.
- name, _ = resolveTypeName("", field.Type)
+ name, _ = resolveTypeName(field.Type)
}
// Skip _ fields.
@@ -262,52 +258,39 @@ func main() {
}
type method struct {
- receiver string
- name string
+ typeName string
+ methodName string
}
- // Search for and add all methods with a pointer receiver and no other
- // arguments to a set. We support auto-detecting the existence of
- // several different methods with this signature.
- simpleMethods := map[method]struct{}{}
+ // Search for and add all method to a set. We auto-detecting several
+ // different methods (and insert them if we don't find them, in order
+ // to ensure that expectations match reality).
+ //
+ // While we do this, figure out the right receiver name. If there are
+ // multiple distinct receivers, then we will just pick the last one.
+ simpleMethods := make(map[method]struct{})
+ receiverNames := make(map[string]string)
for _, f := range files {
-
// Go over all functions.
for _, decl := range f.Decls {
d, ok := decl.(*ast.FuncDecl)
if !ok {
continue
}
- if d.Name == nil || d.Recv == nil || d.Type == nil {
+ if d.Recv == nil || len(d.Recv.List) != 1 {
// Not a named method.
continue
}
- if len(d.Recv.List) != 1 {
- // Wrong number of receivers?
- continue
- }
- if d.Type.Params != nil && len(d.Type.Params.List) != 0 {
- // Has argument(s).
- continue
- }
- if d.Type.Results != nil && len(d.Type.Results.List) != 0 {
- // Has return(s).
- continue
- }
- pt, ok := d.Recv.List[0].Type.(*ast.StarExpr)
- if !ok {
- // Not a pointer receiver.
- continue
- }
-
- t, ok := pt.X.(*ast.Ident)
- if !ok {
- // This shouldn't happen with valid Go.
- continue
+ // Save the method and the receiver.
+ name, _ := resolveTypeName(d.Recv.List[0].Type)
+ simpleMethods[method{
+ typeName: name,
+ methodName: d.Name.Name,
+ }] = struct{}{}
+ if len(d.Recv.List[0].Names) > 0 {
+ receiverNames[name] = d.Recv.List[0].Names[0].Name
}
-
- simpleMethods[method{t.Name, d.Name.Name}] = struct{}{}
}
}
@@ -346,7 +329,11 @@ func main() {
for _, gs := range d.Specs {
ts := gs.(*ast.TypeSpec)
- letter := strings.ToLower(ts.Name.Name[:1])
+ recv, ok := receiverNames[ts.Name.Name]
+ if !ok {
+ // Maybe no methods were defined?
+ recv = strings.ToLower(ts.Name.Name[:1])
+ }
switch x := ts.Type.(type) {
case *ast.StructType:
maybeEmitImports()
@@ -363,32 +350,32 @@ func main() {
emitField(name)
}
emitLoadValue := func(name, typName string) {
- fmt.Fprintf(outputFile, " stateSourceObject.LoadValue(%d, new(%s), func(y interface{}) { %s.load%s(y.(%s)) })\n", fields[name], typName, letter, camelCased(name), typName)
+ fmt.Fprintf(outputFile, " stateSourceObject.LoadValue(%d, new(%s), func(y interface{}) { %s.load%s(y.(%s)) })\n", fields[name], typName, recv, camelCased(name), typName)
}
emitLoad := func(name string) {
- fmt.Fprintf(outputFile, " stateSourceObject.Load(%d, &%s.%s)\n", fields[name], letter, name)
+ fmt.Fprintf(outputFile, " stateSourceObject.Load(%d, &%s.%s)\n", fields[name], recv, name)
}
emitLoadWait := func(name string) {
- fmt.Fprintf(outputFile, " stateSourceObject.LoadWait(%d, &%s.%s)\n", fields[name], letter, name)
+ fmt.Fprintf(outputFile, " stateSourceObject.LoadWait(%d, &%s.%s)\n", fields[name], recv, name)
}
emitSaveValue := func(name, typName string) {
- fmt.Fprintf(outputFile, " var %sValue %s = %s.save%s()\n", name, typName, letter, camelCased(name))
+ fmt.Fprintf(outputFile, " var %sValue %s = %s.save%s()\n", name, typName, recv, camelCased(name))
fmt.Fprintf(outputFile, " stateSinkObject.SaveValue(%d, %sValue)\n", fields[name], name)
}
emitSave := func(name string) {
- fmt.Fprintf(outputFile, " stateSinkObject.Save(%d, &%s.%s)\n", fields[name], letter, name)
+ fmt.Fprintf(outputFile, " stateSinkObject.Save(%d, &%s.%s)\n", fields[name], recv, name)
}
emitZeroCheck := func(name string) {
- fmt.Fprintf(outputFile, " if !%sIsZeroValue(&%s.%s) { %sFailf(\"%s is %%#v, expected zero\", &%s.%s) }\n", statePrefix, letter, name, statePrefix, name, letter, name)
+ fmt.Fprintf(outputFile, " if !%sIsZeroValue(&%s.%s) { %sFailf(\"%s is %%#v, expected zero\", &%s.%s) }\n", statePrefix, recv, name, statePrefix, name, recv, name)
}
// Generate the type name method.
- fmt.Fprintf(outputFile, "func (%s *%s) StateTypeName() string {\n", letter, ts.Name.Name)
+ fmt.Fprintf(outputFile, "func (%s *%s) StateTypeName() string {\n", recv, ts.Name.Name)
fmt.Fprintf(outputFile, " return \"%s.%s\"\n", *fullPkg, ts.Name.Name)
fmt.Fprintf(outputFile, "}\n\n")
// Generate the fields method.
- fmt.Fprintf(outputFile, "func (%s *%s) StateFields() []string {\n", letter, ts.Name.Name)
+ fmt.Fprintf(outputFile, "func (%s *%s) StateFields() []string {\n", recv, ts.Name.Name)
fmt.Fprintf(outputFile, " return []string{\n")
scanFields(x, "", scanFunctions{
normal: emitField,
@@ -402,8 +389,11 @@ func main() {
// the code from compiling if a custom beforeSave was defined in a
// file not provided to this binary and prevents inherited methods
// from being called multiple times by overriding them.
- if _, ok := simpleMethods[method{ts.Name.Name, "beforeSave"}]; !ok && generateSaverLoader {
- fmt.Fprintf(outputFile, "func (%s *%s) beforeSave() {}\n\n", letter, ts.Name.Name)
+ if _, ok := simpleMethods[method{
+ typeName: ts.Name.Name,
+ methodName: "beforeSave",
+ }]; !ok && generateSaverLoader {
+ fmt.Fprintf(outputFile, "func (%s *%s) beforeSave() {}\n\n", recv, ts.Name.Name)
}
// Generate the save method.
@@ -413,8 +403,8 @@ func main() {
// on this specific behavior, but the ability to specify slots
// allows a manual implementation to be order-dependent.
if generateSaverLoader {
- fmt.Fprintf(outputFile, "func (%s *%s) StateSave(stateSinkObject %sSink) {\n", letter, ts.Name.Name, statePrefix)
- fmt.Fprintf(outputFile, " %s.beforeSave()\n", letter)
+ fmt.Fprintf(outputFile, "func (%s *%s) StateSave(stateSinkObject %sSink) {\n", recv, ts.Name.Name, statePrefix)
+ fmt.Fprintf(outputFile, " %s.beforeSave()\n", recv)
scanFields(x, "", scanFunctions{zerovalue: emitZeroCheck})
scanFields(x, "", scanFunctions{value: emitSaveValue})
scanFields(x, "", scanFunctions{normal: emitSave, wait: emitSave})
@@ -423,16 +413,19 @@ func main() {
// Define afterLoad if a definition was not found. We do this for
// the same reason that we do it for beforeSave.
- _, hasAfterLoad := simpleMethods[method{ts.Name.Name, "afterLoad"}]
+ _, hasAfterLoad := simpleMethods[method{
+ typeName: ts.Name.Name,
+ methodName: "afterLoad",
+ }]
if !hasAfterLoad && generateSaverLoader {
- fmt.Fprintf(outputFile, "func (%s *%s) afterLoad() {}\n\n", letter, ts.Name.Name)
+ fmt.Fprintf(outputFile, "func (%s *%s) afterLoad() {}\n\n", recv, ts.Name.Name)
}
// Generate the load method.
//
// N.B. See the comment above for the save method.
if generateSaverLoader {
- fmt.Fprintf(outputFile, "func (%s *%s) StateLoad(stateSourceObject %sSource) {\n", letter, ts.Name.Name, statePrefix)
+ fmt.Fprintf(outputFile, "func (%s *%s) StateLoad(stateSourceObject %sSource) {\n", recv, ts.Name.Name, statePrefix)
scanFields(x, "", scanFunctions{normal: emitLoad, wait: emitLoadWait})
scanFields(x, "", scanFunctions{value: emitLoadValue})
if hasAfterLoad {
@@ -440,7 +433,7 @@ func main() {
// AfterLoad is called, the object encodes a dependency on
// referred objects (i.e. fields). This means that afterLoad
// will not be called until the other afterLoads are called.
- fmt.Fprintf(outputFile, " stateSourceObject.AfterLoad(%s.afterLoad)\n", letter)
+ fmt.Fprintf(outputFile, " stateSourceObject.AfterLoad(%s.afterLoad)\n", recv)
}
fmt.Fprintf(outputFile, "}\n\n")
}
@@ -452,10 +445,10 @@ func main() {
maybeEmitImports()
// Generate the info methods.
- fmt.Fprintf(outputFile, "func (%s *%s) StateTypeName() string {\n", letter, ts.Name.Name)
+ fmt.Fprintf(outputFile, "func (%s *%s) StateTypeName() string {\n", recv, ts.Name.Name)
fmt.Fprintf(outputFile, " return \"%s.%s\"\n", *fullPkg, ts.Name.Name)
fmt.Fprintf(outputFile, "}\n\n")
- fmt.Fprintf(outputFile, "func (%s *%s) StateFields() []string {\n", letter, ts.Name.Name)
+ fmt.Fprintf(outputFile, "func (%s *%s) StateFields() []string {\n", recv, ts.Name.Name)
fmt.Fprintf(outputFile, " return nil\n")
fmt.Fprintf(outputFile, "}\n\n")