summaryrefslogtreecommitdiffhomepage
path: root/tools/go_marshal/gomarshal
diff options
context:
space:
mode:
Diffstat (limited to 'tools/go_marshal/gomarshal')
-rw-r--r--tools/go_marshal/gomarshal/generator.go39
-rw-r--r--tools/go_marshal/gomarshal/generator_interfaces.go98
-rw-r--r--tools/go_marshal/gomarshal/generator_tests.go1
-rw-r--r--tools/go_marshal/gomarshal/util.go5
4 files changed, 128 insertions, 15 deletions
diff --git a/tools/go_marshal/gomarshal/generator.go b/tools/go_marshal/gomarshal/generator.go
index 0b3f600fe..01be7c477 100644
--- a/tools/go_marshal/gomarshal/generator.go
+++ b/tools/go_marshal/gomarshal/generator.go
@@ -34,9 +34,9 @@ const (
usermemImport = "gvisor.dev/gvisor/pkg/usermem"
)
-// List of identifiers we use in generated code, that may conflict a
-// similarly-named source identifier. Avoid problems by refusing the generate
-// code when we see these.
+// List of identifiers we use in generated code that may conflict with a
+// similarly-named source identifier. Abort gracefully when we see these to
+// avoid potentially confusing compilation failures in generated code.
//
// This only applies to import aliases at the moment. All other identifiers
// are qualified by a receiver argument, since they're struct fields.
@@ -44,10 +44,20 @@ const (
// All recievers are single letters, so we don't allow import aliases to be a
// single letter.
var badIdents = []string{
- "src", "srcs", "dst", "dsts", "blk", "buf", "err",
+ "addr", "blk", "buf", "dst", "dsts", "err", "hdr", "len", "ptr", "src", "srcs", "task", "val",
// All single-letter identifiers.
}
+// Constructed fromt badIdents in init().
+var badIdentsMap map[string]struct{}
+
+func init() {
+ badIdentsMap = make(map[string]struct{})
+ for _, ident := range badIdents {
+ badIdentsMap[ident] = struct{}{}
+ }
+}
+
// Generator drives code generation for a single invocation of the go_marshal
// utility.
//
@@ -88,16 +98,18 @@ func NewGenerator(srcs []string, out, outTest, pkg string, imports []string) (*G
}
for _, i := range imports {
// All imports on the extra imports list are unconditionally marked as
- // used, so they're always added to the generated code.
+ // used, so that they're always added to the generated code.
g.imports.add(i).markUsed()
}
g.imports.add(marshalImport).markUsed()
- // The follow imports may or may not be used by the generated
- // code, depending what's required for the target types. Don't
- // mark these imports as used by default.
- g.imports.add(usermemImport)
+ // The following imports may or may not be used by the generated code,
+ // depending on what's required for the target types. Don't mark these as
+ // used by default.
+ g.imports.add("reflect")
+ g.imports.add("runtime")
g.imports.add(safecopyImport)
g.imports.add("unsafe")
+ g.imports.add(usermemImport)
return &g, nil
}
@@ -229,11 +241,6 @@ func (g *Generator) collectMarshallabeTypes(a *ast.File, f *token.FileSet) []*as
// identifiers in the generated code don't conflict with any imported package
// names.
func (g *Generator) collectImports(a *ast.File, f *token.FileSet) map[string]importStmt {
- badImportNames := make(map[string]bool)
- for _, i := range badIdents {
- badImportNames[i] = true
- }
-
is := make(map[string]importStmt)
for _, decl := range a.Decls {
gdecl, ok := decl.(*ast.GenDecl)
@@ -250,7 +257,7 @@ func (g *Generator) collectImports(a *ast.File, f *token.FileSet) map[string]imp
if len(i.name) == 1 {
abortAt(f.Position(spec.Pos()), fmt.Sprintf("Import has a single character local name '%s'; this may conflict with code generated by go_marshal, use a multi-character import alias", i.name))
}
- if badImportNames[i.name] {
+ if _, ok := badIdentsMap[i.name]; ok {
abortAt(f.Position(spec.Pos()), fmt.Sprintf("Import name '%s' is likely to conflict with code generated by go_marshal, use a different import alias", i.name))
}
}
@@ -371,6 +378,7 @@ func (g *Generator) writeTests(ts []*testGenerator) error {
return err
}
+ // Collect and write test import statements.
imports := newImportTable()
for _, t := range ts {
imports.merge(t.imports)
@@ -380,6 +388,7 @@ func (g *Generator) writeTests(ts []*testGenerator) error {
return err
}
+ // Write test functions.
for _, t := range ts {
if err := t.write(g.outputTest); err != nil {
return err
diff --git a/tools/go_marshal/gomarshal/generator_interfaces.go b/tools/go_marshal/gomarshal/generator_interfaces.go
index a712c14dc..f25331ac5 100644
--- a/tools/go_marshal/gomarshal/generator_interfaces.go
+++ b/tools/go_marshal/gomarshal/generator_interfaces.go
@@ -504,4 +504,102 @@ func (g *interfaceGenerator) emitMarshallable() {
})
g.emit("}\n\n")
+ g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n")
+ g.recordUsedImport("marshal")
+ g.recordUsedImport("usermem")
+ g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ fallback := func() {
+ g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName())
+ g.emit("buf := task.CopyScratchBuffer(%s.SizeBytes())\n", g.r)
+ g.emit("%s.MarshalBytes(buf)\n", g.r)
+ g.emit("return task.CopyOutBytes(addr, buf)\n")
+ }
+ if thisPacked {
+ g.recordUsedImport("reflect")
+ g.recordUsedImport("runtime")
+ g.recordUsedImport("unsafe")
+ if cond, ok := g.areFieldsPackedExpression(); ok {
+ g.emit("if !%s {\n", cond)
+ g.inIndent(fallback)
+ g.emit("}\n\n")
+ }
+ // Fast serialization.
+ g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r)
+ g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r)
+ g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n")
+ g.emit("ptr := unsafe.Pointer(%s)\n", g.r)
+ g.emit("val := uintptr(ptr)\n")
+ g.emit("val = val^0\n\n")
+
+ g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r)
+ g.emit("var buf []byte\n")
+ g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n")
+ g.emit("hdr.Data = val\n")
+ g.emit("hdr.Len = %s.SizeBytes()\n", g.r)
+ g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r)
+
+ g.emit("len, err := task.CopyOutBytes(addr, buf)\n")
+ g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r)
+ g.emit("// must live until after the CopyOutBytes.\n")
+ g.emit("runtime.KeepAlive(%s)\n", g.r)
+ g.emit("return len, err\n")
+ } else {
+ fallback()
+ }
+ })
+ g.emit("}\n\n")
+
+ g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n")
+ g.recordUsedImport("marshal")
+ g.recordUsedImport("usermem")
+ g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ fallback := func() {
+ g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName())
+ g.emit("buf := task.CopyScratchBuffer(%s.SizeBytes())\n", g.r)
+ g.emit("n, err := task.CopyInBytes(addr, buf)\n")
+ g.emit("if err != nil {\n")
+ g.inIndent(func() {
+ g.emit("return n, err\n")
+ })
+ g.emit("}\n")
+
+ g.emit("%s.UnmarshalBytes(buf)\n", g.r)
+ g.emit("return n, nil\n")
+ }
+ if thisPacked {
+ g.recordUsedImport("reflect")
+ g.recordUsedImport("runtime")
+ g.recordUsedImport("unsafe")
+ if cond, ok := g.areFieldsPackedExpression(); ok {
+ g.emit("if !%s {\n", cond)
+ g.inIndent(fallback)
+ g.emit("}\n\n")
+ }
+ // Fast deserialization.
+ g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r)
+ g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r)
+ g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n")
+ g.emit("ptr := unsafe.Pointer(%s)\n", g.r)
+ g.emit("val := uintptr(ptr)\n")
+ g.emit("val = val^0\n\n")
+
+ g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r)
+ g.emit("var buf []byte\n")
+ g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n")
+ g.emit("hdr.Data = val\n")
+ g.emit("hdr.Len = %s.SizeBytes()\n", g.r)
+ g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r)
+
+ g.emit("len, err := task.CopyInBytes(addr, buf)\n")
+ g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r)
+ g.emit("// must live until after the CopyInBytes.\n")
+ g.emit("runtime.KeepAlive(%s)\n", g.r)
+ g.emit("return len, err\n")
+ } else {
+ fallback()
+ }
+ })
+ g.emit("}\n\n")
}
diff --git a/tools/go_marshal/gomarshal/generator_tests.go b/tools/go_marshal/gomarshal/generator_tests.go
index bcda17c3b..cc760b6d0 100644
--- a/tools/go_marshal/gomarshal/generator_tests.go
+++ b/tools/go_marshal/gomarshal/generator_tests.go
@@ -25,6 +25,7 @@ var standardImports = []string{
"fmt",
"reflect",
"testing",
+
"gvisor.dev/gvisor/tools/go_marshal/analysis",
}
diff --git a/tools/go_marshal/gomarshal/util.go b/tools/go_marshal/gomarshal/util.go
index 967537abf..3d86935b4 100644
--- a/tools/go_marshal/gomarshal/util.go
+++ b/tools/go_marshal/gomarshal/util.go
@@ -219,6 +219,11 @@ type sourceBuffer struct {
b bytes.Buffer
}
+func (b *sourceBuffer) reset() {
+ b.indent = 0
+ b.b.Reset()
+}
+
func (b *sourceBuffer) incIndent() {
b.indent++
}