summaryrefslogtreecommitdiffhomepage
path: root/tools/go_stateify
diff options
context:
space:
mode:
authorAdin Scannell <ascannell@google.com>2020-06-23 23:32:23 -0700
committergVisor bot <gvisor-bot@google.com>2020-06-23 23:34:06 -0700
commit364ac92baf83f2352f78b718090472639bd92a76 (patch)
tree306c99770deb6872c04fa0a6c29c5c6f322b9b55 /tools/go_stateify
parent399c52888db609296fd1341ed0daa994ad2d02b0 (diff)
Support for saving pointers to fields in the state package.
Previously, it was not possible to encode/decode an object graph which contained a pointer to a field within another type. This was because the encoder was previously unable to disambiguate a pointer to an object and a pointer within the object. This CL remedies this by constructing an address map tracking the full memory range object occupy. The encoded Refvalue message has been extended to allow references to children objects within another object. Because the encoding process may learn about object structure over time, we cannot encode any objects under the entire graph has been generated. This CL also updates the state package to use standard interfaces intead of reflection-based dispatch in order to improve performance overall. This includes a custom wire protocol to significantly reduce the number of allocations and take advantage of structure packing. As part of these changes, there are a small number of minor changes in other places of the code base: * The lists used during encoding are changed to use intrusive lists with the objectEncodeState directly, which required that the ilist Len() method is updated to work properly with the ElementMapper mechanism. * A bug is fixed in the list code wherein Remove() called on an element that is already removed can corrupt the list (removing the element if there's only a single element). Now the behavior is correct. * Standard error wrapping is introduced. * Compressio was updated to implement the new wire.Reader and wire.Writer inteface methods directly. The lack of a ReadByte and WriteByte caused issues not due to interface dispatch, but because underlying slices for a Read or Write call through an interface would always escape to the heap! * Statify has been updated to support the new APIs. See README.md for a description of how the new mechanism works. PiperOrigin-RevId: 318010298
Diffstat (limited to 'tools/go_stateify')
-rw-r--r--tools/go_stateify/main.go182
1 files changed, 114 insertions, 68 deletions
diff --git a/tools/go_stateify/main.go b/tools/go_stateify/main.go
index 309ee9c21..4f6ed208a 100644
--- a/tools/go_stateify/main.go
+++ b/tools/go_stateify/main.go
@@ -103,7 +103,7 @@ type scanFunctions struct {
// skipped if nil.
//
// Fields tagged nosave are skipped.
-func scanFields(ss *ast.StructType, fn scanFunctions) {
+func scanFields(ss *ast.StructType, prefix string, fn scanFunctions) {
if ss.Fields.List == nil {
// No fields.
return
@@ -127,7 +127,16 @@ func scanFields(ss *ast.StructType, fn scanFunctions) {
continue
}
- switch tag := extractStateTag(field.Tag); tag {
+ // Is this a anonymous struct? If yes, then continue the
+ // recursion with the given prefix. We don't pay attention to
+ // any tags on the top-level struct field.
+ tag := extractStateTag(field.Tag)
+ if anon, ok := field.Type.(*ast.StructType); ok && tag == "" {
+ scanFields(anon, name+".", fn)
+ continue
+ }
+
+ switch tag {
case "zerovalue":
if fn.zerovalue != nil {
fn.zerovalue(name)
@@ -201,28 +210,12 @@ func main() {
// initCalls is dumped at the end.
var initCalls []string
- // Declare our emission closures.
+ // Common closures.
emitRegister := func(name string) {
- initCalls = append(initCalls, fmt.Sprintf("%sRegister(\"%s.%s\", (*%s)(nil), state.Fns{Save: (*%s).save, Load: (*%s).load})", statePrefix, *fullPkg, name, name, name, name))
+ initCalls = append(initCalls, fmt.Sprintf("%sRegister((*%s)(nil))", statePrefix, name))
}
emitZeroCheck := func(name string) {
- fmt.Fprintf(outputFile, " if !%sIsZeroValue(&x.%s) { m.Failf(\"%s is %%#v, expected zero\", &x.%s) }\n", statePrefix, name, name, name)
- }
- emitLoadValue := func(name, typName string) {
- fmt.Fprintf(outputFile, " m.LoadValue(\"%s\", new(%s), func(y interface{}) { x.load%s(y.(%s)) })\n", name, typName, camelCased(name), typName)
- }
- emitLoad := func(name string) {
- fmt.Fprintf(outputFile, " m.Load(\"%s\", &x.%s)\n", name, name)
- }
- emitLoadWait := func(name string) {
- fmt.Fprintf(outputFile, " m.LoadWait(\"%s\", &x.%s)\n", name, name)
- }
- emitSaveValue := func(name, typName string) {
- fmt.Fprintf(outputFile, " var %s %s = x.save%s()\n", name, typName, camelCased(name))
- fmt.Fprintf(outputFile, " m.SaveValue(\"%s\", %s)\n", name, name)
- }
- emitSave := func(name string) {
- fmt.Fprintf(outputFile, " m.Save(\"%s\", &x.%s)\n", name, name)
+ fmt.Fprintf(outputFile, " if !%sIsZeroValue(&x.%s) { %sFailf(\"%s is %%#v, expected zero\", &x.%s) }\n", statePrefix, name, statePrefix, name, name)
}
// Automated warning.
@@ -329,87 +322,140 @@ func main() {
continue
}
- // Only generate code for types marked
- // "// +stateify savable" in one of the proceeding
- // comment lines.
+ // Only generate code for types marked "// +stateify
+ // savable" in one of the proceeding comment lines. If
+ // the line is marked "// +stateify type" then only
+ // generate type information and register the type.
if d.Doc == nil {
continue
}
- savable := false
+ var (
+ generateTypeInfo = false
+ generateSaverLoader = false
+ )
for _, l := range d.Doc.List {
if l.Text == "// +stateify savable" {
- savable = true
+ generateTypeInfo = true
+ generateSaverLoader = true
break
}
+ if l.Text == "// +stateify type" {
+ generateTypeInfo = true
+ }
}
- if !savable {
+ if !generateTypeInfo && !generateSaverLoader {
continue
}
for _, gs := range d.Specs {
ts := gs.(*ast.TypeSpec)
- switch ts.Type.(type) {
- case *ast.InterfaceType, *ast.ChanType, *ast.FuncType, *ast.ParenExpr, *ast.StarExpr:
- // Don't register.
- break
+ switch x := ts.Type.(type) {
case *ast.StructType:
maybeEmitImports()
- ss := ts.Type.(*ast.StructType)
+ // Record the slot for each field.
+ fieldCount := 0
+ fields := make(map[string]int)
+ emitField := func(name string) {
+ fmt.Fprintf(outputFile, " \"%s\",\n", name)
+ fields[name] = fieldCount
+ fieldCount++
+ }
+ emitFieldValue := func(name string, _ string) {
+ emitField(name)
+ }
+ emitLoadValue := func(name, typName string) {
+ fmt.Fprintf(outputFile, " m.LoadValue(%d, new(%s), func(y interface{}) { x.load%s(y.(%s)) })\n", fields[name], typName, camelCased(name), typName)
+ }
+ emitLoad := func(name string) {
+ fmt.Fprintf(outputFile, " m.Load(%d, &x.%s)\n", fields[name], name)
+ }
+ emitLoadWait := func(name string) {
+ fmt.Fprintf(outputFile, " m.LoadWait(%d, &x.%s)\n", fields[name], name)
+ }
+ emitSaveValue := func(name, typName string) {
+ fmt.Fprintf(outputFile, " var %s %s = x.save%s()\n", name, typName, camelCased(name))
+ fmt.Fprintf(outputFile, " m.SaveValue(%d, %s)\n", fields[name], name)
+ }
+ emitSave := func(name string) {
+ fmt.Fprintf(outputFile, " m.Save(%d, &x.%s)\n", fields[name], name)
+ }
+
+ // Generate the type name method.
+ fmt.Fprintf(outputFile, "func (x *%s) StateTypeName() string {\n", ts.Name.Name)
+ fmt.Fprintf(outputFile, " return \"%s.%s\"\n", *fullPkg, ts.Name.Name)
+ fmt.Fprintf(outputFile, "}\n\n")
+
+ // Generate the fields method.
+ fmt.Fprintf(outputFile, "func (x *%s) StateFields() []string {\n", ts.Name.Name)
+ fmt.Fprintf(outputFile, " return []string{\n")
+ scanFields(x, "", scanFunctions{
+ normal: emitField,
+ wait: emitField,
+ value: emitFieldValue,
+ })
+ fmt.Fprintf(outputFile, " }\n")
+ fmt.Fprintf(outputFile, "}\n\n")
- // Define beforeSave if a definition was not found. This
- // prevents the code from compiling if a custom beforeSave
- // was defined in a file not provided to this binary and
- // prevents inherited methods from being called multiple times
- // by overriding them.
- if _, ok := simpleMethods[method{ts.Name.Name, "beforeSave"}]; !ok {
- fmt.Fprintf(outputFile, "func (x *%s) beforeSave() {}\n", ts.Name.Name)
+ // Define beforeSave if a definition was not found. This prevents
+ // the code from compiling if a custom beforeSave was defined in a
+ // file not provided to this binary and prevents inherited methods
+ // from being called multiple times by overriding them.
+ if _, ok := simpleMethods[method{ts.Name.Name, "beforeSave"}]; !ok && generateSaverLoader {
+ fmt.Fprintf(outputFile, "func (x *%s) beforeSave() {}\n\n", ts.Name.Name)
}
// Generate the save method.
- fmt.Fprintf(outputFile, "func (x *%s) save(m %sMap) {\n", ts.Name.Name, statePrefix)
- fmt.Fprintf(outputFile, " x.beforeSave()\n")
- scanFields(ss, scanFunctions{zerovalue: emitZeroCheck})
- scanFields(ss, scanFunctions{value: emitSaveValue})
- scanFields(ss, scanFunctions{normal: emitSave, wait: emitSave})
- fmt.Fprintf(outputFile, "}\n\n")
+ //
+ // N.B. For historical reasons, we perform the value saves first,
+ // and perform the value loads last. There should be no dependency
+ // on this specific behavior, but the ability to specify slots
+ // allows a manual implementation to be order-dependent.
+ if generateSaverLoader {
+ fmt.Fprintf(outputFile, "func (x *%s) StateSave(m %sSink) {\n", ts.Name.Name, statePrefix)
+ fmt.Fprintf(outputFile, " x.beforeSave()\n")
+ scanFields(x, "", scanFunctions{zerovalue: emitZeroCheck})
+ scanFields(x, "", scanFunctions{value: emitSaveValue})
+ scanFields(x, "", scanFunctions{normal: emitSave, wait: emitSave})
+ fmt.Fprintf(outputFile, "}\n\n")
+ }
- // Define afterLoad if a definition was not found. We do this
- // for the same reason that we do it for beforeSave.
+ // Define afterLoad if a definition was not found. We do this for
+ // the same reason that we do it for beforeSave.
_, hasAfterLoad := simpleMethods[method{ts.Name.Name, "afterLoad"}]
- if !hasAfterLoad {
- fmt.Fprintf(outputFile, "func (x *%s) afterLoad() {}\n", ts.Name.Name)
+ if !hasAfterLoad && generateSaverLoader {
+ fmt.Fprintf(outputFile, "func (x *%s) afterLoad() {}\n\n", ts.Name.Name)
}
// Generate the load method.
//
- // Note that the manual loads always follow the
- // automated loads.
- fmt.Fprintf(outputFile, "func (x *%s) load(m %sMap) {\n", ts.Name.Name, statePrefix)
- scanFields(ss, scanFunctions{normal: emitLoad, wait: emitLoadWait})
- scanFields(ss, scanFunctions{value: emitLoadValue})
- if hasAfterLoad {
- // The call to afterLoad is made conditionally, because when
- // AfterLoad is called, the object encodes a dependency on
- // referred objects (i.e. fields). This means that afterLoad
- // will not be called until the other afterLoads are called.
- fmt.Fprintf(outputFile, " m.AfterLoad(x.afterLoad)\n")
+ // N.B. See the comment above for the save method.
+ if generateSaverLoader {
+ fmt.Fprintf(outputFile, "func (x *%s) StateLoad(m %sSource) {\n", ts.Name.Name, statePrefix)
+ scanFields(x, "", scanFunctions{normal: emitLoad, wait: emitLoadWait})
+ scanFields(x, "", scanFunctions{value: emitLoadValue})
+ if hasAfterLoad {
+ // The call to afterLoad is made conditionally, because when
+ // AfterLoad is called, the object encodes a dependency on
+ // referred objects (i.e. fields). This means that afterLoad
+ // will not be called until the other afterLoads are called.
+ fmt.Fprintf(outputFile, " m.AfterLoad(x.afterLoad)\n")
+ }
+ fmt.Fprintf(outputFile, "}\n\n")
}
- fmt.Fprintf(outputFile, "}\n\n")
// Add to our registration.
emitRegister(ts.Name.Name)
+
case *ast.Ident, *ast.SelectorExpr, *ast.ArrayType:
maybeEmitImports()
- _, val := resolveTypeName(ts.Name.Name, ts.Type)
-
- // Dispatch directly.
- fmt.Fprintf(outputFile, "func (x *%s) save(m %sMap) {\n", ts.Name.Name, statePrefix)
- fmt.Fprintf(outputFile, " m.SaveValue(\"\", (%s)(*x))\n", val)
+ // Generate the info methods.
+ fmt.Fprintf(outputFile, "func (x *%s) StateTypeName() string {\n", ts.Name.Name)
+ fmt.Fprintf(outputFile, " return \"%s.%s\"\n", *fullPkg, ts.Name.Name)
fmt.Fprintf(outputFile, "}\n\n")
- fmt.Fprintf(outputFile, "func (x *%s) load(m %sMap) {\n", ts.Name.Name, statePrefix)
- fmt.Fprintf(outputFile, " m.LoadValue(\"\", new(%s), func(y interface{}) { *x = (%s)(y.(%s)) })\n", val, ts.Name.Name, val)
+ fmt.Fprintf(outputFile, "func (x *%s) StateFields() []string {\n", ts.Name.Name)
+ fmt.Fprintf(outputFile, " return nil\n")
fmt.Fprintf(outputFile, "}\n\n")
// See above.