diff options
author | Adin Scannell <ascannell@google.com> | 2020-06-23 23:32:23 -0700 |
---|---|---|
committer | gVisor bot <gvisor-bot@google.com> | 2020-06-23 23:34:06 -0700 |
commit | 364ac92baf83f2352f78b718090472639bd92a76 (patch) | |
tree | 306c99770deb6872c04fa0a6c29c5c6f322b9b55 /tools/go_stateify | |
parent | 399c52888db609296fd1341ed0daa994ad2d02b0 (diff) |
Support for saving pointers to fields in the state package.
Previously, it was not possible to encode/decode an object graph which
contained a pointer to a field within another type. This was because the
encoder was previously unable to disambiguate a pointer to an object and a
pointer within the object.
This CL remedies this by constructing an address map tracking the full memory
range object occupy. The encoded Refvalue message has been extended to allow
references to children objects within another object. Because the encoding
process may learn about object structure over time, we cannot encode any
objects under the entire graph has been generated.
This CL also updates the state package to use standard interfaces intead of
reflection-based dispatch in order to improve performance overall. This
includes a custom wire protocol to significantly reduce the number of
allocations and take advantage of structure packing.
As part of these changes, there are a small number of minor changes in other
places of the code base:
* The lists used during encoding are changed to use intrusive lists with the
objectEncodeState directly, which required that the ilist Len() method is
updated to work properly with the ElementMapper mechanism.
* A bug is fixed in the list code wherein Remove() called on an element that is
already removed can corrupt the list (removing the element if there's only a
single element). Now the behavior is correct.
* Standard error wrapping is introduced.
* Compressio was updated to implement the new wire.Reader and wire.Writer
inteface methods directly. The lack of a ReadByte and WriteByte caused issues
not due to interface dispatch, but because underlying slices for a Read or
Write call through an interface would always escape to the heap!
* Statify has been updated to support the new APIs.
See README.md for a description of how the new mechanism works.
PiperOrigin-RevId: 318010298
Diffstat (limited to 'tools/go_stateify')
-rw-r--r-- | tools/go_stateify/main.go | 182 |
1 files changed, 114 insertions, 68 deletions
diff --git a/tools/go_stateify/main.go b/tools/go_stateify/main.go index 309ee9c21..4f6ed208a 100644 --- a/tools/go_stateify/main.go +++ b/tools/go_stateify/main.go @@ -103,7 +103,7 @@ type scanFunctions struct { // skipped if nil. // // Fields tagged nosave are skipped. -func scanFields(ss *ast.StructType, fn scanFunctions) { +func scanFields(ss *ast.StructType, prefix string, fn scanFunctions) { if ss.Fields.List == nil { // No fields. return @@ -127,7 +127,16 @@ func scanFields(ss *ast.StructType, fn scanFunctions) { continue } - switch tag := extractStateTag(field.Tag); tag { + // Is this a anonymous struct? If yes, then continue the + // recursion with the given prefix. We don't pay attention to + // any tags on the top-level struct field. + tag := extractStateTag(field.Tag) + if anon, ok := field.Type.(*ast.StructType); ok && tag == "" { + scanFields(anon, name+".", fn) + continue + } + + switch tag { case "zerovalue": if fn.zerovalue != nil { fn.zerovalue(name) @@ -201,28 +210,12 @@ func main() { // initCalls is dumped at the end. var initCalls []string - // Declare our emission closures. + // Common closures. emitRegister := func(name string) { - initCalls = append(initCalls, fmt.Sprintf("%sRegister(\"%s.%s\", (*%s)(nil), state.Fns{Save: (*%s).save, Load: (*%s).load})", statePrefix, *fullPkg, name, name, name, name)) + initCalls = append(initCalls, fmt.Sprintf("%sRegister((*%s)(nil))", statePrefix, name)) } emitZeroCheck := func(name string) { - fmt.Fprintf(outputFile, " if !%sIsZeroValue(&x.%s) { m.Failf(\"%s is %%#v, expected zero\", &x.%s) }\n", statePrefix, name, name, name) - } - emitLoadValue := func(name, typName string) { - fmt.Fprintf(outputFile, " m.LoadValue(\"%s\", new(%s), func(y interface{}) { x.load%s(y.(%s)) })\n", name, typName, camelCased(name), typName) - } - emitLoad := func(name string) { - fmt.Fprintf(outputFile, " m.Load(\"%s\", &x.%s)\n", name, name) - } - emitLoadWait := func(name string) { - fmt.Fprintf(outputFile, " m.LoadWait(\"%s\", &x.%s)\n", name, name) - } - emitSaveValue := func(name, typName string) { - fmt.Fprintf(outputFile, " var %s %s = x.save%s()\n", name, typName, camelCased(name)) - fmt.Fprintf(outputFile, " m.SaveValue(\"%s\", %s)\n", name, name) - } - emitSave := func(name string) { - fmt.Fprintf(outputFile, " m.Save(\"%s\", &x.%s)\n", name, name) + fmt.Fprintf(outputFile, " if !%sIsZeroValue(&x.%s) { %sFailf(\"%s is %%#v, expected zero\", &x.%s) }\n", statePrefix, name, statePrefix, name, name) } // Automated warning. @@ -329,87 +322,140 @@ func main() { continue } - // Only generate code for types marked - // "// +stateify savable" in one of the proceeding - // comment lines. + // Only generate code for types marked "// +stateify + // savable" in one of the proceeding comment lines. If + // the line is marked "// +stateify type" then only + // generate type information and register the type. if d.Doc == nil { continue } - savable := false + var ( + generateTypeInfo = false + generateSaverLoader = false + ) for _, l := range d.Doc.List { if l.Text == "// +stateify savable" { - savable = true + generateTypeInfo = true + generateSaverLoader = true break } + if l.Text == "// +stateify type" { + generateTypeInfo = true + } } - if !savable { + if !generateTypeInfo && !generateSaverLoader { continue } for _, gs := range d.Specs { ts := gs.(*ast.TypeSpec) - switch ts.Type.(type) { - case *ast.InterfaceType, *ast.ChanType, *ast.FuncType, *ast.ParenExpr, *ast.StarExpr: - // Don't register. - break + switch x := ts.Type.(type) { case *ast.StructType: maybeEmitImports() - ss := ts.Type.(*ast.StructType) + // Record the slot for each field. + fieldCount := 0 + fields := make(map[string]int) + emitField := func(name string) { + fmt.Fprintf(outputFile, " \"%s\",\n", name) + fields[name] = fieldCount + fieldCount++ + } + emitFieldValue := func(name string, _ string) { + emitField(name) + } + emitLoadValue := func(name, typName string) { + fmt.Fprintf(outputFile, " m.LoadValue(%d, new(%s), func(y interface{}) { x.load%s(y.(%s)) })\n", fields[name], typName, camelCased(name), typName) + } + emitLoad := func(name string) { + fmt.Fprintf(outputFile, " m.Load(%d, &x.%s)\n", fields[name], name) + } + emitLoadWait := func(name string) { + fmt.Fprintf(outputFile, " m.LoadWait(%d, &x.%s)\n", fields[name], name) + } + emitSaveValue := func(name, typName string) { + fmt.Fprintf(outputFile, " var %s %s = x.save%s()\n", name, typName, camelCased(name)) + fmt.Fprintf(outputFile, " m.SaveValue(%d, %s)\n", fields[name], name) + } + emitSave := func(name string) { + fmt.Fprintf(outputFile, " m.Save(%d, &x.%s)\n", fields[name], name) + } + + // Generate the type name method. + fmt.Fprintf(outputFile, "func (x *%s) StateTypeName() string {\n", ts.Name.Name) + fmt.Fprintf(outputFile, " return \"%s.%s\"\n", *fullPkg, ts.Name.Name) + fmt.Fprintf(outputFile, "}\n\n") + + // Generate the fields method. + fmt.Fprintf(outputFile, "func (x *%s) StateFields() []string {\n", ts.Name.Name) + fmt.Fprintf(outputFile, " return []string{\n") + scanFields(x, "", scanFunctions{ + normal: emitField, + wait: emitField, + value: emitFieldValue, + }) + fmt.Fprintf(outputFile, " }\n") + fmt.Fprintf(outputFile, "}\n\n") - // Define beforeSave if a definition was not found. This - // prevents the code from compiling if a custom beforeSave - // was defined in a file not provided to this binary and - // prevents inherited methods from being called multiple times - // by overriding them. - if _, ok := simpleMethods[method{ts.Name.Name, "beforeSave"}]; !ok { - fmt.Fprintf(outputFile, "func (x *%s) beforeSave() {}\n", ts.Name.Name) + // Define beforeSave if a definition was not found. This prevents + // the code from compiling if a custom beforeSave was defined in a + // file not provided to this binary and prevents inherited methods + // from being called multiple times by overriding them. + if _, ok := simpleMethods[method{ts.Name.Name, "beforeSave"}]; !ok && generateSaverLoader { + fmt.Fprintf(outputFile, "func (x *%s) beforeSave() {}\n\n", ts.Name.Name) } // Generate the save method. - fmt.Fprintf(outputFile, "func (x *%s) save(m %sMap) {\n", ts.Name.Name, statePrefix) - fmt.Fprintf(outputFile, " x.beforeSave()\n") - scanFields(ss, scanFunctions{zerovalue: emitZeroCheck}) - scanFields(ss, scanFunctions{value: emitSaveValue}) - scanFields(ss, scanFunctions{normal: emitSave, wait: emitSave}) - fmt.Fprintf(outputFile, "}\n\n") + // + // N.B. For historical reasons, we perform the value saves first, + // and perform the value loads last. There should be no dependency + // on this specific behavior, but the ability to specify slots + // allows a manual implementation to be order-dependent. + if generateSaverLoader { + fmt.Fprintf(outputFile, "func (x *%s) StateSave(m %sSink) {\n", ts.Name.Name, statePrefix) + fmt.Fprintf(outputFile, " x.beforeSave()\n") + scanFields(x, "", scanFunctions{zerovalue: emitZeroCheck}) + scanFields(x, "", scanFunctions{value: emitSaveValue}) + scanFields(x, "", scanFunctions{normal: emitSave, wait: emitSave}) + fmt.Fprintf(outputFile, "}\n\n") + } - // Define afterLoad if a definition was not found. We do this - // for the same reason that we do it for beforeSave. + // Define afterLoad if a definition was not found. We do this for + // the same reason that we do it for beforeSave. _, hasAfterLoad := simpleMethods[method{ts.Name.Name, "afterLoad"}] - if !hasAfterLoad { - fmt.Fprintf(outputFile, "func (x *%s) afterLoad() {}\n", ts.Name.Name) + if !hasAfterLoad && generateSaverLoader { + fmt.Fprintf(outputFile, "func (x *%s) afterLoad() {}\n\n", ts.Name.Name) } // Generate the load method. // - // Note that the manual loads always follow the - // automated loads. - fmt.Fprintf(outputFile, "func (x *%s) load(m %sMap) {\n", ts.Name.Name, statePrefix) - scanFields(ss, scanFunctions{normal: emitLoad, wait: emitLoadWait}) - scanFields(ss, scanFunctions{value: emitLoadValue}) - if hasAfterLoad { - // The call to afterLoad is made conditionally, because when - // AfterLoad is called, the object encodes a dependency on - // referred objects (i.e. fields). This means that afterLoad - // will not be called until the other afterLoads are called. - fmt.Fprintf(outputFile, " m.AfterLoad(x.afterLoad)\n") + // N.B. See the comment above for the save method. + if generateSaverLoader { + fmt.Fprintf(outputFile, "func (x *%s) StateLoad(m %sSource) {\n", ts.Name.Name, statePrefix) + scanFields(x, "", scanFunctions{normal: emitLoad, wait: emitLoadWait}) + scanFields(x, "", scanFunctions{value: emitLoadValue}) + if hasAfterLoad { + // The call to afterLoad is made conditionally, because when + // AfterLoad is called, the object encodes a dependency on + // referred objects (i.e. fields). This means that afterLoad + // will not be called until the other afterLoads are called. + fmt.Fprintf(outputFile, " m.AfterLoad(x.afterLoad)\n") + } + fmt.Fprintf(outputFile, "}\n\n") } - fmt.Fprintf(outputFile, "}\n\n") // Add to our registration. emitRegister(ts.Name.Name) + case *ast.Ident, *ast.SelectorExpr, *ast.ArrayType: maybeEmitImports() - _, val := resolveTypeName(ts.Name.Name, ts.Type) - - // Dispatch directly. - fmt.Fprintf(outputFile, "func (x *%s) save(m %sMap) {\n", ts.Name.Name, statePrefix) - fmt.Fprintf(outputFile, " m.SaveValue(\"\", (%s)(*x))\n", val) + // Generate the info methods. + fmt.Fprintf(outputFile, "func (x *%s) StateTypeName() string {\n", ts.Name.Name) + fmt.Fprintf(outputFile, " return \"%s.%s\"\n", *fullPkg, ts.Name.Name) fmt.Fprintf(outputFile, "}\n\n") - fmt.Fprintf(outputFile, "func (x *%s) load(m %sMap) {\n", ts.Name.Name, statePrefix) - fmt.Fprintf(outputFile, " m.LoadValue(\"\", new(%s), func(y interface{}) { *x = (%s)(y.(%s)) })\n", val, ts.Name.Name, val) + fmt.Fprintf(outputFile, "func (x *%s) StateFields() []string {\n", ts.Name.Name) + fmt.Fprintf(outputFile, " return nil\n") fmt.Fprintf(outputFile, "}\n\n") // See above. |