diff options
Diffstat (limited to 'tools/go_marshal/gomarshal')
-rw-r--r-- | tools/go_marshal/gomarshal/generator.go | 39 | ||||
-rw-r--r-- | tools/go_marshal/gomarshal/generator_interfaces.go | 98 | ||||
-rw-r--r-- | tools/go_marshal/gomarshal/generator_tests.go | 1 | ||||
-rw-r--r-- | tools/go_marshal/gomarshal/util.go | 5 |
4 files changed, 128 insertions, 15 deletions
diff --git a/tools/go_marshal/gomarshal/generator.go b/tools/go_marshal/gomarshal/generator.go index 0b3f600fe..01be7c477 100644 --- a/tools/go_marshal/gomarshal/generator.go +++ b/tools/go_marshal/gomarshal/generator.go @@ -34,9 +34,9 @@ const ( usermemImport = "gvisor.dev/gvisor/pkg/usermem" ) -// List of identifiers we use in generated code, that may conflict a -// similarly-named source identifier. Avoid problems by refusing the generate -// code when we see these. +// List of identifiers we use in generated code that may conflict with a +// similarly-named source identifier. Abort gracefully when we see these to +// avoid potentially confusing compilation failures in generated code. // // This only applies to import aliases at the moment. All other identifiers // are qualified by a receiver argument, since they're struct fields. @@ -44,10 +44,20 @@ const ( // All recievers are single letters, so we don't allow import aliases to be a // single letter. var badIdents = []string{ - "src", "srcs", "dst", "dsts", "blk", "buf", "err", + "addr", "blk", "buf", "dst", "dsts", "err", "hdr", "len", "ptr", "src", "srcs", "task", "val", // All single-letter identifiers. } +// Constructed fromt badIdents in init(). +var badIdentsMap map[string]struct{} + +func init() { + badIdentsMap = make(map[string]struct{}) + for _, ident := range badIdents { + badIdentsMap[ident] = struct{}{} + } +} + // Generator drives code generation for a single invocation of the go_marshal // utility. // @@ -88,16 +98,18 @@ func NewGenerator(srcs []string, out, outTest, pkg string, imports []string) (*G } for _, i := range imports { // All imports on the extra imports list are unconditionally marked as - // used, so they're always added to the generated code. + // used, so that they're always added to the generated code. g.imports.add(i).markUsed() } g.imports.add(marshalImport).markUsed() - // The follow imports may or may not be used by the generated - // code, depending what's required for the target types. Don't - // mark these imports as used by default. - g.imports.add(usermemImport) + // The following imports may or may not be used by the generated code, + // depending on what's required for the target types. Don't mark these as + // used by default. + g.imports.add("reflect") + g.imports.add("runtime") g.imports.add(safecopyImport) g.imports.add("unsafe") + g.imports.add(usermemImport) return &g, nil } @@ -229,11 +241,6 @@ func (g *Generator) collectMarshallabeTypes(a *ast.File, f *token.FileSet) []*as // identifiers in the generated code don't conflict with any imported package // names. func (g *Generator) collectImports(a *ast.File, f *token.FileSet) map[string]importStmt { - badImportNames := make(map[string]bool) - for _, i := range badIdents { - badImportNames[i] = true - } - is := make(map[string]importStmt) for _, decl := range a.Decls { gdecl, ok := decl.(*ast.GenDecl) @@ -250,7 +257,7 @@ func (g *Generator) collectImports(a *ast.File, f *token.FileSet) map[string]imp if len(i.name) == 1 { abortAt(f.Position(spec.Pos()), fmt.Sprintf("Import has a single character local name '%s'; this may conflict with code generated by go_marshal, use a multi-character import alias", i.name)) } - if badImportNames[i.name] { + if _, ok := badIdentsMap[i.name]; ok { abortAt(f.Position(spec.Pos()), fmt.Sprintf("Import name '%s' is likely to conflict with code generated by go_marshal, use a different import alias", i.name)) } } @@ -371,6 +378,7 @@ func (g *Generator) writeTests(ts []*testGenerator) error { return err } + // Collect and write test import statements. imports := newImportTable() for _, t := range ts { imports.merge(t.imports) @@ -380,6 +388,7 @@ func (g *Generator) writeTests(ts []*testGenerator) error { return err } + // Write test functions. for _, t := range ts { if err := t.write(g.outputTest); err != nil { return err diff --git a/tools/go_marshal/gomarshal/generator_interfaces.go b/tools/go_marshal/gomarshal/generator_interfaces.go index a712c14dc..f25331ac5 100644 --- a/tools/go_marshal/gomarshal/generator_interfaces.go +++ b/tools/go_marshal/gomarshal/generator_interfaces.go @@ -504,4 +504,102 @@ func (g *interfaceGenerator) emitMarshallable() { }) g.emit("}\n\n") + g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n") + g.recordUsedImport("marshal") + g.recordUsedImport("usermem") + g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName()) + g.inIndent(func() { + fallback := func() { + g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName()) + g.emit("buf := task.CopyScratchBuffer(%s.SizeBytes())\n", g.r) + g.emit("%s.MarshalBytes(buf)\n", g.r) + g.emit("return task.CopyOutBytes(addr, buf)\n") + } + if thisPacked { + g.recordUsedImport("reflect") + g.recordUsedImport("runtime") + g.recordUsedImport("unsafe") + if cond, ok := g.areFieldsPackedExpression(); ok { + g.emit("if !%s {\n", cond) + g.inIndent(fallback) + g.emit("}\n\n") + } + // Fast serialization. + g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) + g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) + g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") + g.emit("ptr := unsafe.Pointer(%s)\n", g.r) + g.emit("val := uintptr(ptr)\n") + g.emit("val = val^0\n\n") + + g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) + g.emit("var buf []byte\n") + g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") + g.emit("hdr.Data = val\n") + g.emit("hdr.Len = %s.SizeBytes()\n", g.r) + g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) + + g.emit("len, err := task.CopyOutBytes(addr, buf)\n") + g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) + g.emit("// must live until after the CopyOutBytes.\n") + g.emit("runtime.KeepAlive(%s)\n", g.r) + g.emit("return len, err\n") + } else { + fallback() + } + }) + g.emit("}\n\n") + + g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n") + g.recordUsedImport("marshal") + g.recordUsedImport("usermem") + g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName()) + g.inIndent(func() { + fallback := func() { + g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName()) + g.emit("buf := task.CopyScratchBuffer(%s.SizeBytes())\n", g.r) + g.emit("n, err := task.CopyInBytes(addr, buf)\n") + g.emit("if err != nil {\n") + g.inIndent(func() { + g.emit("return n, err\n") + }) + g.emit("}\n") + + g.emit("%s.UnmarshalBytes(buf)\n", g.r) + g.emit("return n, nil\n") + } + if thisPacked { + g.recordUsedImport("reflect") + g.recordUsedImport("runtime") + g.recordUsedImport("unsafe") + if cond, ok := g.areFieldsPackedExpression(); ok { + g.emit("if !%s {\n", cond) + g.inIndent(fallback) + g.emit("}\n\n") + } + // Fast deserialization. + g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) + g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) + g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") + g.emit("ptr := unsafe.Pointer(%s)\n", g.r) + g.emit("val := uintptr(ptr)\n") + g.emit("val = val^0\n\n") + + g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) + g.emit("var buf []byte\n") + g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") + g.emit("hdr.Data = val\n") + g.emit("hdr.Len = %s.SizeBytes()\n", g.r) + g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) + + g.emit("len, err := task.CopyInBytes(addr, buf)\n") + g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) + g.emit("// must live until after the CopyInBytes.\n") + g.emit("runtime.KeepAlive(%s)\n", g.r) + g.emit("return len, err\n") + } else { + fallback() + } + }) + g.emit("}\n\n") } diff --git a/tools/go_marshal/gomarshal/generator_tests.go b/tools/go_marshal/gomarshal/generator_tests.go index bcda17c3b..cc760b6d0 100644 --- a/tools/go_marshal/gomarshal/generator_tests.go +++ b/tools/go_marshal/gomarshal/generator_tests.go @@ -25,6 +25,7 @@ var standardImports = []string{ "fmt", "reflect", "testing", + "gvisor.dev/gvisor/tools/go_marshal/analysis", } diff --git a/tools/go_marshal/gomarshal/util.go b/tools/go_marshal/gomarshal/util.go index 967537abf..3d86935b4 100644 --- a/tools/go_marshal/gomarshal/util.go +++ b/tools/go_marshal/gomarshal/util.go @@ -219,6 +219,11 @@ type sourceBuffer struct { b bytes.Buffer } +func (b *sourceBuffer) reset() { + b.indent = 0 + b.b.Reset() +} + func (b *sourceBuffer) incIndent() { b.indent++ } |