diff options
Diffstat (limited to 'tools/go_marshal/gomarshal')
-rw-r--r-- | tools/go_marshal/gomarshal/generator.go | 136 | ||||
-rw-r--r-- | tools/go_marshal/gomarshal/generator_interfaces.go | 113 | ||||
-rw-r--r-- | tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go | 109 | ||||
-rw-r--r-- | tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go | 180 | ||||
-rw-r--r-- | tools/go_marshal/gomarshal/generator_interfaces_struct.go | 384 | ||||
-rw-r--r-- | tools/go_marshal/gomarshal/generator_tests.go | 52 | ||||
-rw-r--r-- | tools/go_marshal/gomarshal/util.go | 175 |
7 files changed, 820 insertions, 329 deletions
diff --git a/tools/go_marshal/gomarshal/generator.go b/tools/go_marshal/gomarshal/generator.go index 729489de5..177013dbb 100644 --- a/tools/go_marshal/gomarshal/generator.go +++ b/tools/go_marshal/gomarshal/generator.go @@ -28,12 +28,6 @@ import ( "gvisor.dev/gvisor/tools/tags" ) -const ( - marshalImport = "gvisor.dev/gvisor/tools/go_marshal/marshal" - safecopyImport = "gvisor.dev/gvisor/pkg/safecopy" - usermemImport = "gvisor.dev/gvisor/pkg/usermem" -) - // List of identifiers we use in generated code that may conflict with a // similarly-named source identifier. Abort gracefully when we see these to // avoid potentially confusing compilation failures in generated code. @@ -44,8 +38,8 @@ const ( // All recievers are single letters, so we don't allow import aliases to be a // single letter. var badIdents = []string{ - "addr", "blk", "buf", "dst", "dsts", "err", "hdr", "idx", "inner", "len", - "ptr", "src", "srcs", "task", "val", + "addr", "blk", "buf", "dst", "dsts", "count", "err", "hdr", "idx", "inner", + "length", "limit", "ptr", "size", "src", "srcs", "task", "val", // All single-letter identifiers. } @@ -110,9 +104,10 @@ func NewGenerator(srcs []string, out, outTest, pkg string, imports []string) (*G g.imports.add("reflect") g.imports.add("runtime") g.imports.add("unsafe") - g.imports.add(marshalImport) - g.imports.add(safecopyImport) - g.imports.add(usermemImport) + g.imports.add("gvisor.dev/gvisor/pkg/gohacks") + g.imports.add("gvisor.dev/gvisor/pkg/safecopy") + g.imports.add("gvisor.dev/gvisor/pkg/usermem") + g.imports.add("gvisor.dev/gvisor/tools/go_marshal/marshal") return &g, nil } @@ -194,10 +189,73 @@ func (g *Generator) parse() ([]*ast.File, []*token.FileSet, error) { return files, fsets, nil } +// sliceAPI carries information about the '+marshal slice' directive. +type sliceAPI struct { + // Comment node in the AST containing the +marshal tag. + comment *ast.Comment + // Identifier fragment to use when naming generated functions for the slice + // API. + ident string + // Whether the generated functions should reference the newtype name, or the + // inner type name. Only meaningful on newtype declarations on primitives. + inner bool +} + +// marshallableType carries information about a type marked with the '+marshal' +// directive. +type marshallableType struct { + spec *ast.TypeSpec + slice *sliceAPI +} + +func newMarshallableType(fset *token.FileSet, tagLine *ast.Comment, spec *ast.TypeSpec) marshallableType { + mt := marshallableType{ + spec: spec, + slice: nil, + } + + var unhandledTags []string + + for _, tag := range strings.Fields(strings.TrimPrefix(tagLine.Text, "// +marshal")) { + if strings.HasPrefix(tag, "slice:") { + tokens := strings.Split(tag, ":") + if len(tokens) < 2 || len(tokens) > 3 { + abortAt(fset.Position(tagLine.Slash), fmt.Sprintf("+marshal directive has invalid 'slice' clause. Expecting format 'slice:<IDENTIFIER>[:inner]', got '%v'", tag)) + } + if len(tokens[1]) == 0 { + abortAt(fset.Position(tagLine.Slash), "+marshal slice directive has empty identifier argument. Expecting '+marshal slice:identifier'") + } + + sa := &sliceAPI{ + comment: tagLine, + ident: tokens[1], + } + mt.slice = sa + + if len(tokens) == 3 { + if tokens[2] != "inner" { + abortAt(fset.Position(tagLine.Slash), "+marshal slice directive has an invalid argument. Expecting '+marshal slice:<IDENTIFIER>[:inner]'") + } + sa.inner = true + } + + continue + } + + unhandledTags = append(unhandledTags, tag) + } + + if len(unhandledTags) > 0 { + abortAt(fset.Position(tagLine.Slash), fmt.Sprintf("+marshal directive contained the following unknown clauses: %v", strings.Join(unhandledTags, " "))) + } + + return mt +} + // collectMarshallableTypes walks the parsed AST and collects a list of type // declarations for which we need to generate the Marshallable interface. -func (g *Generator) collectMarshallableTypes(a *ast.File, f *token.FileSet) []*ast.TypeSpec { - var types []*ast.TypeSpec +func (g *Generator) collectMarshallableTypes(a *ast.File, f *token.FileSet) []marshallableType { + var types []marshallableType for _, decl := range a.Decls { gdecl, ok := decl.(*ast.GenDecl) // Type declaration? @@ -212,9 +270,11 @@ func (g *Generator) collectMarshallableTypes(a *ast.File, f *token.FileSet) []*a } // Does the comment contain a "+marshal" line? marked := false + var tagLine *ast.Comment for _, c := range gdecl.Doc.List { - if c.Text == "// +marshal" { + if strings.HasPrefix(c.Text, "// +marshal") { marked = true + tagLine = c break } } @@ -229,20 +289,17 @@ func (g *Generator) collectMarshallableTypes(a *ast.File, f *token.FileSet) []*a switch t.Type.(type) { case *ast.StructType: debugfAt(f.Position(t.Pos()), "Collected marshallable struct %s.\n", t.Name.Name) - types = append(types, t) - continue case *ast.Ident: // Newtype on primitive. debugfAt(f.Position(t.Pos()), "Collected marshallable newtype on primitive %s.\n", t.Name.Name) - types = append(types, t) - continue case *ast.ArrayType: // Newtype on array. debugfAt(f.Position(t.Pos()), "Collected marshallable newtype on array %s.\n", t.Name.Name) - types = append(types, t) - continue + default: + // A user specifically requested marshalling on this type, but we + // don't support it. + abortAt(f.Position(t.Pos()), fmt.Sprintf("Marshalling codegen was requested on type '%s', but go-marshal doesn't support this kind of declaration.\n", t.Name)) } - // A user specifically requested marshalling on this type, but we - // don't support it. - abortAt(f.Position(t.Pos()), fmt.Sprintf("Marshalling codegen was requested on type '%s', but go-marshal doesn't support this kind of declaration.\n", t.Name)) + types = append(types, newMarshallableType(f, tagLine, t)) + } } return types @@ -269,7 +326,7 @@ func (g *Generator) collectImports(a *ast.File, f *token.FileSet) map[string]imp // Make sure we have an import that doesn't use any local names that // would conflict with identifiers in the generated code. - if len(i.name) == 1 { + if len(i.name) == 1 && i.name != "_" { abortAt(f.Position(spec.Pos()), fmt.Sprintf("Import has a single character local name '%s'; this may conflict with code generated by go_marshal, use a multi-character import alias", i.name)) } if _, ok := badIdentsMap[i.name]; ok { @@ -281,19 +338,28 @@ func (g *Generator) collectImports(a *ast.File, f *token.FileSet) map[string]imp } -func (g *Generator) generateOne(t *ast.TypeSpec, fset *token.FileSet) *interfaceGenerator { - i := newInterfaceGenerator(t, fset) - switch ty := t.Type.(type) { +func (g *Generator) generateOne(t marshallableType, fset *token.FileSet) *interfaceGenerator { + i := newInterfaceGenerator(t.spec, fset) + switch ty := t.spec.Type.(type) { case *ast.StructType: - i.validateStruct(t, ty) + i.validateStruct(t.spec, ty) i.emitMarshallableForStruct(ty) + if t.slice != nil { + i.emitMarshallableSliceForStruct(ty, t.slice) + } case *ast.Ident: i.validatePrimitiveNewtype(ty) i.emitMarshallableForPrimitiveNewtype(ty) + if t.slice != nil { + i.emitMarshallableSliceForPrimitiveNewtype(ty, t.slice) + } case *ast.ArrayType: - i.validateArrayNewtype(t.Name, ty) + i.validateArrayNewtype(t.spec.Name, ty) // After validate, we can safely call arrayLen. - i.emitMarshallableForArrayNewtype(t.Name, ty.Elt.(*ast.Ident), arrayLen(ty)) + i.emitMarshallableForArrayNewtype(t.spec.Name, ty, ty.Elt.(*ast.Ident)) + if t.slice != nil { + abortAt(fset.Position(t.slice.comment.Slash), fmt.Sprintf("Array type marked as '+marshal slice:...', but this is not supported. Perhaps fold one of the dimensions?")) + } default: // This should've been filtered out by collectMarshallabeTypes. panic(fmt.Sprintf("Unexpected type %+v", ty)) @@ -303,9 +369,9 @@ func (g *Generator) generateOne(t *ast.TypeSpec, fset *token.FileSet) *interface // generateOneTestSuite generates a test suite for the automatically generated // implementations type t. -func (g *Generator) generateOneTestSuite(t *ast.TypeSpec) *testGenerator { - i := newTestGenerator(t) - i.emitTests() +func (g *Generator) generateOneTestSuite(t marshallableType) *testGenerator { + i := newTestGenerator(t.spec) + i.emitTests(t.slice) return i } @@ -355,7 +421,7 @@ func (g *Generator) Run() error { // the list of imports we need to copy to the generated code. for name, _ := range impl.is { if !g.imports.markUsed(name) { - panic(fmt.Sprintf("Generated code for '%s' referenced a non-existent import with local name '%s'", impl.typeName(), name)) + panic(fmt.Sprintf("Generated code for '%s' referenced a non-existent import with local name '%s'. Either go-marshal needs to add an import to the generated file, or a package in an input source file has a package name differ from the final component of its path, which go-marshal doesn't know how to detect; use an import alias to work around this limitation.", impl.typeName(), name)) } } ts = append(ts, g.generateOneTestSuite(t)) @@ -413,7 +479,7 @@ func (g *Generator) writeTests(ts []*testGenerator) error { // empty example instead. if len(ts) == 0 { b.reset() - b.emit("func ExampleEmptyTestSuite() {\n") + b.emit("func Example() {\n") b.inIndent(func() { b.emit("// This example is intentionally empty to ensure this file contains at least\n") b.emit("// one testable entity. go-marshal is forced to emit a test file if a package\n") diff --git a/tools/go_marshal/gomarshal/generator_interfaces.go b/tools/go_marshal/gomarshal/generator_interfaces.go index 8babf61d2..e3c3dac63 100644 --- a/tools/go_marshal/gomarshal/generator_interfaces.go +++ b/tools/go_marshal/gomarshal/generator_interfaces.go @@ -15,8 +15,10 @@ package gomarshal import ( + "fmt" "go/ast" "go/token" + "strings" ) // interfaceGenerator generates marshalling interfaces for a single type. @@ -72,7 +74,6 @@ func (g *interfaceGenerator) recordUsedMarshallable(m string) { func (g *interfaceGenerator) recordUsedImport(i string) { g.is[i] = struct{}{} - } func (g *interfaceGenerator) recordPotentiallyNonPackedField(fieldName string) { @@ -163,3 +164,113 @@ func (g *interfaceGenerator) unmarshalScalar(accessor, typ, bufVar string) { g.recordPotentiallyNonPackedField(accessor) } } + +// emitCastToByteSlice unsafely casts an arbitrary type's underlying memory to a +// byte slice, bypassing escape analysis. The caller is responsible for ensuring +// srcPtr lives until they're done with dstVar, the runtime does not consider +// dstVar dependent on srcPtr due to the escape analysis bypass. +// +// srcPtr must be a pointer. +// +// This function uses internally uses the identifier "hdr", and cannot be used +// in a context where it is already bound. +func (g *interfaceGenerator) emitCastToByteSlice(srcPtr, dstVar, lenExpr string) { + g.recordUsedImport("gohacks") + g.emit("// Construct a slice backed by dst's underlying memory.\n") + g.emit("var %s []byte\n", dstVar) + g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&%s))\n", dstVar) + g.emit("hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(%s)))\n", srcPtr) + g.emit("hdr.Len = %s\n", lenExpr) + g.emit("hdr.Cap = %s\n\n", lenExpr) +} + +// emitCastToByteSlice unsafely casts a slice with elements of an abitrary type +// to a byte slice. As part of the cast, the byte slice is made to look +// independent of the src slice by bypassing escape analysis. This means the +// byte slice can be used without causing the source to escape. The caller is +// responsible for ensuring srcPtr lives until they're done with dstVar, as the +// runtime no longer considers dstVar dependent on srcPtr and is free to GC it. +// +// srcPtr must be a pointer. +// +// This function uses internally uses the identifiers "ptr", "val" and "hdr", +// and cannot be used in a context where these identifiers are already bound. +func (g *interfaceGenerator) emitCastSliceToByteSlice(srcPtr, dstVar, lenExpr string) { + g.emitNoEscapeSliceDataPointer(srcPtr, "val") + + g.emit("// Construct a slice backed by dst's underlying memory.\n") + g.emit("var %s []byte\n", dstVar) + g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&%s))\n", dstVar) + g.emit("hdr.Data = uintptr(val)\n") + g.emit("hdr.Len = %s\n", lenExpr) + g.emit("hdr.Cap = %s\n\n", lenExpr) +} + +// emitNoEscapeSliceDataPointer unsafely casts a slice's data pointer to an +// unsafe.Pointer, bypassing escape analysis. The caller is responsible for +// ensuring srcPtr lives until they're done with dstVar, as the runtime no +// longer considers dstVar dependent on srcPtr and is free to GC it. +// +// srcPtr must be a pointer. +// +// This function uses internally uses the identifier "ptr" cannot be used in a +// context where this identifier is already bound. +func (g *interfaceGenerator) emitNoEscapeSliceDataPointer(srcPtr, dstVar string) { + g.recordUsedImport("gohacks") + g.emit("ptr := unsafe.Pointer(%s)\n", srcPtr) + g.emit("%s := gohacks.Noescape(unsafe.Pointer((*reflect.SliceHeader)(ptr).Data))\n\n", dstVar) +} + +func (g *interfaceGenerator) emitKeepAlive(ptrVar string) { + g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", ptrVar) + g.emit("// must live until the use above.\n") + g.emit("runtime.KeepAlive(%s)\n", ptrVar) +} + +func (g *interfaceGenerator) expandBinaryExpr(b *strings.Builder, e *ast.BinaryExpr) { + switch x := e.X.(type) { + case *ast.BinaryExpr: + // Recursively expand sub-expression. + g.expandBinaryExpr(b, x) + case *ast.Ident: + fmt.Fprintf(b, "%s", x.Name) + case *ast.BasicLit: + fmt.Fprintf(b, "%s", x.Value) + default: + g.abortAt(e.Pos(), "Cannot convert binary expression to output code. Go-marshal currently only handles simple expressions of literals, constants and basic identifiers") + } + + fmt.Fprintf(b, "%s", e.Op) + + switch y := e.Y.(type) { + case *ast.BinaryExpr: + // Recursively expand sub-expression. + g.expandBinaryExpr(b, y) + case *ast.Ident: + fmt.Fprintf(b, "%s", y.Name) + case *ast.BasicLit: + fmt.Fprintf(b, "%s", y.Value) + default: + g.abortAt(e.Pos(), "Cannot convert binary expression to output code. Go-marshal currently only handles simple expressions of literals, constants and basic identifiers") + } +} + +// arrayLenExpr returns a string containing a valid golang expression +// representing the length of array a. The returned expression should be treated +// as a single value, and will be already parenthesized as required. +func (g *interfaceGenerator) arrayLenExpr(a *ast.ArrayType) string { + var b strings.Builder + + switch l := a.Len.(type) { + case *ast.Ident: + fmt.Fprintf(&b, "%s", l.Name) + case *ast.BasicLit: + fmt.Fprintf(&b, "%s", l.Value) + case *ast.BinaryExpr: + g.expandBinaryExpr(&b, l) + return fmt.Sprintf("(%s)", b.String()) + default: + g.abortAt(l.Pos(), "Cannot convert this array len expression to output code. Go-marshal currently only handles simple expressions of literals, constants and basic identifiers") + } + return b.String() +} diff --git a/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go b/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go index da36d9305..72ef03a22 100644 --- a/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go +++ b/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go @@ -27,20 +27,12 @@ func (g *interfaceGenerator) validateArrayNewtype(n *ast.Ident, a *ast.ArrayType g.abortAt(a.Pos(), fmt.Sprintf("Dynamically sized slice '%s' cannot be marshalled, arrays must be statically sized", n.Name)) } - if _, ok := a.Len.(*ast.BasicLit); !ok { - g.abortAt(a.Len.Pos(), fmt.Sprintf("Array size must be a literal, don't use consts or expressions")) - } - if _, ok := a.Elt.(*ast.Ident); !ok { g.abortAt(a.Elt.Pos(), fmt.Sprintf("Marshalling not supported for arrays with %s elements, array elements must be primitive types", kindString(a.Elt))) } - - if arrayLen(a) <= 0 { - g.abortAt(a.Len.Pos(), fmt.Sprintf("Marshalling not supported for zero length arrays, why does an ABI struct have one?")) - } } -func (g *interfaceGenerator) emitMarshallableForArrayNewtype(n, elt *ast.Ident, len int) { +func (g *interfaceGenerator) emitMarshallableForArrayNewtype(n *ast.Ident, a *ast.ArrayType, elt *ast.Ident) { g.recordUsedImport("io") g.recordUsedImport("marshal") g.recordUsedImport("reflect") @@ -49,13 +41,16 @@ func (g *interfaceGenerator) emitMarshallableForArrayNewtype(n, elt *ast.Ident, g.recordUsedImport("unsafe") g.recordUsedImport("usermem") + lenExpr := g.arrayLenExpr(a) + g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n") + g.emit("//go:nosplit\n") g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName()) g.inIndent(func() { if size, dynamic := g.scalarSize(elt); !dynamic { - g.emit("return %d\n", size*len) + g.emit("return %d * %s\n", size, lenExpr) } else { - g.emit("return (*%s)(nil).SizeBytes() * %d\n", n.Name, len) + g.emit("return (*%s)(nil).SizeBytes() * %s\n", n.Name, lenExpr) } }) g.emit("}\n\n") @@ -63,7 +58,7 @@ func (g *interfaceGenerator) emitMarshallableForArrayNewtype(n, elt *ast.Ident, g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n") g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName()) g.inIndent(func() { - g.emit("for idx := 0; idx < %d; idx++ {\n", len) + g.emit("for idx := 0; idx < %s; idx++ {\n", lenExpr) g.inIndent(func() { g.marshalScalar(fmt.Sprintf("%s[idx]", g.r), elt.Name, "dst") }) @@ -74,7 +69,7 @@ func (g *interfaceGenerator) emitMarshallableForArrayNewtype(n, elt *ast.Ident, g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n") g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName()) g.inIndent(func() { - g.emit("for idx := 0; idx < %d; idx++ {\n", len) + g.emit("for idx := 0; idx < %s; idx++ {\n", lenExpr) g.inIndent(func() { g.unmarshalScalar(fmt.Sprintf("%s[idx]", g.r), elt.Name, "src") }) @@ -83,6 +78,7 @@ func (g *interfaceGenerator) emitMarshallableForArrayNewtype(n, elt *ast.Ident, g.emit("}\n\n") g.emit("// Packed implements marshal.Marshallable.Packed.\n") + g.emit("//go:nosplit\n") g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName()) g.inIndent(func() { g.emit("// Array newtypes are always packed.\n") @@ -104,79 +100,46 @@ func (g *interfaceGenerator) emitMarshallableForArrayNewtype(n, elt *ast.Ident, }) g.emit("}\n\n") + g.emit("// CopyOutN implements marshal.Marshallable.CopyOutN.\n") + g.emit("//go:nosplit\n") + g.emit("func (%s *%s) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r)) + + g.emit("length, err := task.CopyOutBytes(addr, buf[:limit]) // escapes: okay.\n") + g.emitKeepAlive(g.r) + g.emit("return length, err\n") + }) + g.emit("}\n\n") + g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n") - g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName()) + g.emit("//go:nosplit\n") + g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName()) g.inIndent(func() { - // Fast serialization. - g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) - g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) - g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") - g.emit("ptr := unsafe.Pointer(%s)\n", g.r) - g.emit("val := uintptr(ptr)\n") - g.emit("val = val^0\n\n") - - g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) - g.emit("var buf []byte\n") - g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") - g.emit("hdr.Data = val\n") - g.emit("hdr.Len = %s.SizeBytes()\n", g.r) - g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) - - g.emit("_, err := task.CopyOutBytes(addr, buf)\n") - g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) - g.emit("// must live until after the CopyOutBytes.\n") - g.emit("runtime.KeepAlive(%s)\n", g.r) - g.emit("return err\n") + g.emit("return %s.CopyOutN(task, addr, %s.SizeBytes())\n", g.r, g.r) }) g.emit("}\n\n") g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n") - g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName()) + g.emit("//go:nosplit\n") + g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName()) g.inIndent(func() { - g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) - g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) - g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") - g.emit("ptr := unsafe.Pointer(%s)\n", g.r) - g.emit("val := uintptr(ptr)\n") - g.emit("val = val^0\n\n") - - g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) - g.emit("var buf []byte\n") - g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") - g.emit("hdr.Data = val\n") - g.emit("hdr.Len = %s.SizeBytes()\n", g.r) - g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) - - g.emit("_, err := task.CopyInBytes(addr, buf)\n") - g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) - g.emit("// must live until after the CopyInBytes.\n") - g.emit("runtime.KeepAlive(%s)\n", g.r) - g.emit("return err\n") + g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r)) + + g.emit("length, err := task.CopyInBytes(addr, buf) // escapes: okay.\n") + g.emitKeepAlive(g.r) + g.emit("return length, err\n") }) g.emit("}\n\n") g.emit("// WriteTo implements io.WriterTo.WriteTo.\n") g.emit("func (%s *%s) WriteTo(w io.Writer) (int64, error) {\n", g.r, g.typeName()) g.inIndent(func() { - g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) - g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) - g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") - g.emit("ptr := unsafe.Pointer(%s)\n", g.r) - g.emit("val := uintptr(ptr)\n") - g.emit("val = val^0\n\n") - - g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) - g.emit("var buf []byte\n") - g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") - g.emit("hdr.Data = val\n") - g.emit("hdr.Len = %s.SizeBytes()\n", g.r) - g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) - - g.emit("len, err := w.Write(buf)\n") - g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) - g.emit("// must live until after the Write.\n") - g.emit("runtime.KeepAlive(%s)\n", g.r) - g.emit("return int64(len), err\n") + g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r)) + + g.emit("length, err := w.Write(buf)\n") + g.emitKeepAlive(g.r) + g.emit("return int64(length), err\n") }) g.emit("}\n\n") diff --git a/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go b/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go index 159397825..39f654ea8 100644 --- a/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go +++ b/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go @@ -104,6 +104,7 @@ func (g *interfaceGenerator) emitMarshallableForPrimitiveNewtype(nt *ast.Ident) g.recordUsedImport("usermem") g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n") + g.emit("//go:nosplit\n") g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName()) g.inIndent(func() { if size, dynamic := g.scalarSize(nt); !dynamic { @@ -129,6 +130,7 @@ func (g *interfaceGenerator) emitMarshallableForPrimitiveNewtype(nt *ast.Ident) g.emit("}\n\n") g.emit("// Packed implements marshal.Marshallable.Packed.\n") + g.emit("//go:nosplit\n") g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName()) g.inIndent(func() { g.emit("// Scalar newtypes are always packed.\n") @@ -150,80 +152,138 @@ func (g *interfaceGenerator) emitMarshallableForPrimitiveNewtype(nt *ast.Ident) }) g.emit("}\n\n") + g.emit("// CopyOutN implements marshal.Marshallable.CopyOutN.\n") + g.emit("//go:nosplit\n") + g.emit("func (%s *%s) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r)) + + g.emit("length, err := task.CopyOutBytes(addr, buf[:limit]) // escapes: okay.\n") + g.emitKeepAlive(g.r) + g.emit("return length, err\n") + }) + g.emit("}\n\n") + g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n") - g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName()) + g.emit("//go:nosplit\n") + g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName()) g.inIndent(func() { - // Fast serialization. - g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) - g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) - g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") - g.emit("ptr := unsafe.Pointer(%s)\n", g.r) - g.emit("val := uintptr(ptr)\n") - g.emit("val = val^0\n\n") - - g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) - g.emit("var buf []byte\n") - g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") - g.emit("hdr.Data = val\n") - g.emit("hdr.Len = %s.SizeBytes()\n", g.r) - g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) - - g.emit("_, err := task.CopyOutBytes(addr, buf)\n") - g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) - g.emit("// must live until after the CopyOutBytes.\n") - g.emit("runtime.KeepAlive(%s)\n", g.r) - g.emit("return err\n") + g.emit("return %s.CopyOutN(task, addr, %s.SizeBytes())\n", g.r, g.r) }) g.emit("}\n\n") g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n") - g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName()) + g.emit("//go:nosplit\n") + g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName()) g.inIndent(func() { - g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) - g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) - g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") - g.emit("ptr := unsafe.Pointer(%s)\n", g.r) - g.emit("val := uintptr(ptr)\n") - g.emit("val = val^0\n\n") - - g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) - g.emit("var buf []byte\n") - g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") - g.emit("hdr.Data = val\n") - g.emit("hdr.Len = %s.SizeBytes()\n", g.r) - g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) - - g.emit("_, err := task.CopyInBytes(addr, buf)\n") - g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) - g.emit("// must live until after the CopyInBytes.\n") - g.emit("runtime.KeepAlive(%s)\n", g.r) - g.emit("return err\n") + g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r)) + + g.emit("length, err := task.CopyInBytes(addr, buf) // escapes: okay.\n") + g.emitKeepAlive(g.r) + g.emit("return length, err\n") }) g.emit("}\n\n") g.emit("// WriteTo implements io.WriterTo.WriteTo.\n") g.emit("func (%s *%s) WriteTo(w io.Writer) (int64, error) {\n", g.r, g.typeName()) g.inIndent(func() { - g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) - g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) - g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") - g.emit("ptr := unsafe.Pointer(%s)\n", g.r) - g.emit("val := uintptr(ptr)\n") - g.emit("val = val^0\n\n") - - g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) - g.emit("var buf []byte\n") - g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") - g.emit("hdr.Data = val\n") - g.emit("hdr.Len = %s.SizeBytes()\n", g.r) - g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) - - g.emit("len, err := w.Write(buf)\n") - g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) - g.emit("// must live until after the Write.\n") - g.emit("runtime.KeepAlive(%s)\n", g.r) - g.emit("return int64(len), err\n") + g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r)) + + g.emit("length, err := w.Write(buf)\n") + g.emitKeepAlive(g.r) + g.emit("return int64(length), err\n") + + }) + g.emit("}\n\n") +} + +func (g *interfaceGenerator) emitMarshallableSliceForPrimitiveNewtype(nt *ast.Ident, slice *sliceAPI) { + g.recordUsedImport("marshal") + g.recordUsedImport("usermem") + g.recordUsedImport("reflect") + g.recordUsedImport("runtime") + g.recordUsedImport("unsafe") + + eltType := g.typeName() + if slice.inner { + eltType = nt.Name + } + + g.emit("// Copy%sIn copies in a slice of %s objects from the task's memory.\n", slice.ident, eltType) + g.emit("//go:nosplit\n") + g.emit("func Copy%sIn(task marshal.Task, addr usermem.Addr, dst []%s) (int, error) {\n", slice.ident, eltType) + g.inIndent(func() { + g.emit("count := len(dst)\n") + g.emit("if count == 0 {\n") + g.inIndent(func() { + g.emit("return 0, nil\n") + }) + g.emit("}\n") + g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName()) + + g.emitCastSliceToByteSlice("&dst", "buf", "size * count") + + g.emit("length, err := task.CopyInBytes(addr, buf) // escapes: okay.\n") + g.emitKeepAlive("dst") + g.emit("return length, err\n") + }) + g.emit("}\n\n") + + g.emit("// Copy%sOut copies a slice of %s objects to the task's memory.\n", slice.ident, eltType) + g.emit("//go:nosplit\n") + g.emit("func Copy%sOut(task marshal.Task, addr usermem.Addr, src []%s) (int, error) {\n", slice.ident, eltType) + g.inIndent(func() { + g.emit("count := len(src)\n") + g.emit("if count == 0 {\n") + g.inIndent(func() { + g.emit("return 0, nil\n") + }) + g.emit("}\n") + g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName()) + + g.emitCastSliceToByteSlice("&src", "buf", "size * count") + + g.emit("length, err := task.CopyOutBytes(addr, buf) // escapes: okay.\n") + g.emitKeepAlive("src") + g.emit("return length, err\n") + }) + g.emit("}\n\n") + + g.emit("// MarshalUnsafe%s is like %s.MarshalUnsafe, but for a []%s.\n", slice.ident, g.typeName(), g.typeName()) + g.emit("func MarshalUnsafe%s(src []%s, dst []byte) (int, error) {\n", slice.ident, g.typeName()) + g.inIndent(func() { + g.emit("count := len(src)\n") + g.emit("if count == 0 {\n") + g.inIndent(func() { + g.emit("return 0, nil\n") + }) + g.emit("}\n") + g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName()) + + g.emitNoEscapeSliceDataPointer("&src", "val") + + g.emit("length, err := safecopy.CopyIn(dst[:(size*count)], val)\n") + g.emitKeepAlive("src") + g.emit("return length, err\n") + }) + g.emit("}\n\n") + + g.emit("// UnmarshalUnsafe%s is like %s.UnmarshalUnsafe, but for a []%s.\n", slice.ident, g.typeName(), g.typeName()) + g.emit("func UnmarshalUnsafe%s(dst []%s, src []byte) (int, error) {\n", slice.ident, g.typeName()) + g.inIndent(func() { + g.emit("count := len(dst)\n") + g.emit("if count == 0 {\n") + g.inIndent(func() { + g.emit("return 0, nil\n") + }) + g.emit("}\n") + g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName()) + + g.emitNoEscapeSliceDataPointer("&dst", "val") + g.emit("length, err := safecopy.CopyOut(val, src[:(size*count)])\n") + g.emitKeepAlive("dst") + g.emit("return length, err\n") }) g.emit("}\n\n") } diff --git a/tools/go_marshal/gomarshal/generator_interfaces_struct.go b/tools/go_marshal/gomarshal/generator_interfaces_struct.go index e66a38b2e..9cd3c9579 100644 --- a/tools/go_marshal/gomarshal/generator_interfaces_struct.go +++ b/tools/go_marshal/gomarshal/generator_interfaces_struct.go @@ -62,8 +62,8 @@ func (g *interfaceGenerator) validateStruct(ts *ast.TypeSpec, st *ast.StructType // No validation to perform on selector fields. However this // callback must still be provided. }, - array: func(n, _ *ast.Ident, len int) { - g.validateArrayNewtype(n, f.Type.(*ast.ArrayType)) + array: func(n *ast.Ident, a *ast.ArrayType, _ *ast.Ident) { + g.validateArrayNewtype(n, a) }, unhandled: func(_ *ast.Ident) { g.abortAt(f.Pos(), fmt.Sprintf("Marshalling not supported for %s fields", kindString(f.Type))) @@ -72,20 +72,24 @@ func (g *interfaceGenerator) validateStruct(ts *ast.TypeSpec, st *ast.StructType }) } -func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) { - // Is g.t a packed struct without consideing field types? - thisPacked := true +func (g *interfaceGenerator) isStructPacked(st *ast.StructType) bool { + packed := true forEachStructField(st, func(f *ast.Field) { if f.Tag != nil { if f.Tag.Value == "`marshal:\"unaligned\"`" { - if thisPacked { + if packed { debugfAt(g.f.Position(g.t.Pos()), fmt.Sprintf("Marking type '%s' as not packed due to tag `marshal:\"unaligned\"`.\n", g.t.Name)) - thisPacked = false + packed = false } } } }) + return packed +} + +func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) { + thisPacked := g.isStructPacked(st) g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n") g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName()) @@ -108,16 +112,13 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) { g.recordUsedMarshallable(tName) dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", tName)) }, - array: func(n, t *ast.Ident, len int) { - if len < 1 { - // Zero-length arrays should've been rejected by validate(). - panic("unreachable") - } + array: func(n *ast.Ident, a *ast.ArrayType, t *ast.Ident) { + lenExpr := g.arrayLenExpr(a) if size, dynamic := g.scalarSize(t); !dynamic { - primitiveSize += size * len + dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("%d*%s", size, lenExpr)) } else { g.recordUsedMarshallable(t.Name) - dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()*%d", t.Name, len)) + dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()*%s", t.Name, lenExpr)) } }, }.dispatch) @@ -148,7 +149,7 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) { g.shift("dst", len) } else { // We can't use shiftDynamic here because we don't have - // an instance of the dynamic type we can referece here + // an instance of the dynamic type we can reference here // (since the version in this struct is anonymous). Use // a typed nil pointer to call SizeBytes() instead. g.emit("dst = dst[(*%s)(nil).SizeBytes():]\n", t.Name) @@ -158,24 +159,30 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) { g.marshalScalar(g.fieldAccessor(n), t.Name, "dst") }, selector: func(n, tX, tSel *ast.Ident) { + if n.Name == "_" { + g.emit("// Padding: dst[:sizeof(%s)] ~= %s(0)\n", tX.Name, tSel.Name) + g.emit("dst = dst[(*%s.%s)(nil).SizeBytes():]\n", tX.Name, tSel.Name) + return + } g.marshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "dst") }, - array: func(n, t *ast.Ident, size int) { + array: func(n *ast.Ident, a *ast.ArrayType, t *ast.Ident) { + lenExpr := g.arrayLenExpr(a) if n.Name == "_" { - g.emit("// Padding: dst[:sizeof(%s)*%d] ~= [%d]%s{0}\n", t.Name, size, size, t.Name) - if len, dynamic := g.scalarSize(t); !dynamic { - g.shift("dst", len*size) + g.emit("// Padding: dst[:sizeof(%s)*%s] ~= [%s]%s{0}\n", t.Name, lenExpr, lenExpr, t.Name) + if size, dynamic := g.scalarSize(t); !dynamic { + g.emit("dst = dst[%d*(%s):]\n", size, lenExpr) } else { // We can't use shiftDynamic here because we don't have // an instance of the dynamic type we can reference here // (since the version in this struct is anonymous). Use // a typed nil pointer to call SizeBytes() instead. - g.emit("dst = dst[(*%s)(nil).SizeBytes()*%d:]\n", t.Name, size) + g.emit("dst = dst[(*%s)(nil).SizeBytes()*(%s):]\n", t.Name, lenExpr) } return } - g.emit("for idx := 0; idx < %d; idx++ {\n", size) + g.emit("for idx := 0; idx < %s; idx++ {\n", lenExpr) g.inIndent(func() { g.marshalScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "dst") }) @@ -195,11 +202,11 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) { if len, dynamic := g.scalarSize(t); !dynamic { g.shift("src", len) } else { - // We can't use shiftDynamic here because we don't have - // an instance of the dynamic type we can reference here - // (since the version in this struct is anonymous). Use - // a typed nil pointer to call SizeBytes() instead. - g.emit("src = src[(*%s)(nil).SizeBytes():]\n", t.Name) + // We don't have an instance of the dynamic type we can + // reference here (since the version in this struct is + // anonymous). Use a typed nil pointer to call + // SizeBytes() instead. + g.shiftDynamic("src", fmt.Sprintf("(*%s)(nil)", t.Name)) g.recordPotentiallyNonPackedField(fmt.Sprintf("(*%s)(nil)", t.Name)) } return @@ -207,24 +214,31 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) { g.unmarshalScalar(g.fieldAccessor(n), t.Name, "src") }, selector: func(n, tX, tSel *ast.Ident) { + if n.Name == "_" { + g.emit("// Padding: %s ~= src[:sizeof(%s.%s)]\n", g.fieldAccessor(n), tX.Name, tSel.Name) + g.emit("src = src[(*%s.%s)(nil).SizeBytes():]\n", tX.Name, tSel.Name) + g.recordPotentiallyNonPackedField(fmt.Sprintf("(*%s.%s)(nil)", tX.Name, tSel.Name)) + return + } g.unmarshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "src") }, - array: func(n, t *ast.Ident, size int) { + array: func(n *ast.Ident, a *ast.ArrayType, t *ast.Ident) { + lenExpr := g.arrayLenExpr(a) if n.Name == "_" { - g.emit("// Padding: ~ copy([%d]%s(%s), src[:sizeof(%s)*%d])\n", size, t.Name, g.fieldAccessor(n), t.Name, size) - if len, dynamic := g.scalarSize(t); !dynamic { - g.shift("src", len*size) + g.emit("// Padding: ~ copy([%s]%s(%s), src[:sizeof(%s)*%s])\n", lenExpr, t.Name, g.fieldAccessor(n), t.Name, lenExpr) + if size, dynamic := g.scalarSize(t); !dynamic { + g.emit("src = src[%d*(%s):]\n", size, lenExpr) } else { // We can't use shiftDynamic here because we don't have // an instance of the dynamic type we can referece here // (since the version in this struct is anonymous). Use // a typed nil pointer to call SizeBytes() instead. - g.emit("src = src[(*%s)(nil).SizeBytes()*%d:]\n", t.Name, size) + g.emit("src = src[(*%s)(nil).SizeBytes()*(%s):]\n", t.Name, lenExpr) } return } - g.emit("for idx := 0; idx < %d; idx++ {\n", size) + g.emit("for idx := 0; idx < %s; idx++ {\n", lenExpr) g.inIndent(func() { g.unmarshalScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "src") }) @@ -235,6 +249,7 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) { g.emit("}\n\n") g.emit("// Packed implements marshal.Marshallable.Packed.\n") + g.emit("//go:nosplit\n") g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName()) g.inIndent(func() { expr, fieldsMaybePacked := g.areFieldsPackedExpression() @@ -302,17 +317,17 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) { }) g.emit("}\n\n") - g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n") + g.emit("// CopyOutN implements marshal.Marshallable.CopyOutN.\n") + g.emit("//go:nosplit\n") g.recordUsedImport("marshal") g.recordUsedImport("usermem") - g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName()) + g.emit("func (%s *%s) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) {\n", g.r, g.typeName()) g.inIndent(func() { fallback := func() { g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName()) - g.emit("buf := task.CopyScratchBuffer(%s.SizeBytes())\n", g.r) - g.emit("%s.MarshalBytes(buf)\n", g.r) - g.emit("_, err := task.CopyOutBytes(addr, buf)\n") - g.emit("return err\n") + g.emit("buf := task.CopyScratchBuffer(%s.SizeBytes()) // escapes: okay.\n", g.r) + g.emit("%s.MarshalBytes(buf) // escapes: fallback.\n", g.r) + g.emit("return task.CopyOutBytes(addr, buf[:limit]) // escapes: okay.\n") } if thisPacked { g.recordUsedImport("reflect") @@ -324,48 +339,41 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) { g.emit("}\n\n") } // Fast serialization. - g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) - g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) - g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") - g.emit("ptr := unsafe.Pointer(%s)\n", g.r) - g.emit("val := uintptr(ptr)\n") - g.emit("val = val^0\n\n") - - g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) - g.emit("var buf []byte\n") - g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") - g.emit("hdr.Data = val\n") - g.emit("hdr.Len = %s.SizeBytes()\n", g.r) - g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) - - g.emit("_, err := task.CopyOutBytes(addr, buf)\n") - g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) - g.emit("// must live until after the CopyOutBytes.\n") - g.emit("runtime.KeepAlive(%s)\n", g.r) - g.emit("return err\n") + g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r)) + + g.emit("length, err := task.CopyOutBytes(addr, buf[:limit]) // escapes: okay.\n") + g.emitKeepAlive(g.r) + g.emit("return length, err\n") } else { fallback() } }) g.emit("}\n\n") + g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n") + g.emit("//go:nosplit\n") + g.recordUsedImport("marshal") + g.recordUsedImport("usermem") + g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("return %s.CopyOutN(task, addr, %s.SizeBytes())\n", g.r, g.r) + }) + g.emit("}\n\n") + g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n") + g.emit("//go:nosplit\n") g.recordUsedImport("marshal") g.recordUsedImport("usermem") - g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName()) + g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName()) g.inIndent(func() { fallback := func() { g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName()) - g.emit("buf := task.CopyScratchBuffer(%s.SizeBytes())\n", g.r) - g.emit("_, err := task.CopyInBytes(addr, buf)\n") - g.emit("if err != nil {\n") - g.inIndent(func() { - g.emit("return err\n") - }) - g.emit("}\n") - - g.emit("%s.UnmarshalBytes(buf)\n", g.r) - g.emit("return nil\n") + g.emit("buf := task.CopyScratchBuffer(%s.SizeBytes()) // escapes: okay.\n", g.r) + g.emit("length, err := task.CopyInBytes(addr, buf) // escapes: okay.\n") + g.emit("// Unmarshal unconditionally. If we had a short copy-in, this results in a\n") + g.emit("// partially unmarshalled struct.\n") + g.emit("%s.UnmarshalBytes(buf) // escapes: fallback.\n", g.r) + g.emit("return length, err\n") } if thisPacked { g.recordUsedImport("reflect") @@ -377,25 +385,11 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) { g.emit("}\n\n") } // Fast deserialization. - g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) - g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) - g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") - g.emit("ptr := unsafe.Pointer(%s)\n", g.r) - g.emit("val := uintptr(ptr)\n") - g.emit("val = val^0\n\n") - - g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) - g.emit("var buf []byte\n") - g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") - g.emit("hdr.Data = val\n") - g.emit("hdr.Len = %s.SizeBytes()\n", g.r) - g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) - - g.emit("_, err := task.CopyInBytes(addr, buf)\n") - g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) - g.emit("// must live until after the CopyInBytes.\n") - g.emit("runtime.KeepAlive(%s)\n", g.r) - g.emit("return err\n") + g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r)) + + g.emit("length, err := task.CopyInBytes(addr, buf) // escapes: okay.\n") + g.emitKeepAlive(g.r) + g.emit("return length, err\n") } else { fallback() } @@ -410,8 +404,8 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) { g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName()) g.emit("buf := make([]byte, %s.SizeBytes())\n", g.r) g.emit("%s.MarshalBytes(buf)\n", g.r) - g.emit("n, err := w.Write(buf)\n") - g.emit("return int64(n), err\n") + g.emit("length, err := w.Write(buf)\n") + g.emit("return int64(length), err\n") } if thisPacked { g.recordUsedImport("reflect") @@ -423,25 +417,199 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) { g.emit("}\n\n") } // Fast serialization. - g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) - g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) - g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") - g.emit("ptr := unsafe.Pointer(%s)\n", g.r) - g.emit("val := uintptr(ptr)\n") - g.emit("val = val^0\n\n") - - g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) - g.emit("var buf []byte\n") - g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") - g.emit("hdr.Data = val\n") - g.emit("hdr.Len = %s.SizeBytes()\n", g.r) - g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) - - g.emit("len, err := w.Write(buf)\n") - g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) - g.emit("// must live until after the Write.\n") - g.emit("runtime.KeepAlive(%s)\n", g.r) - g.emit("return int64(len), err\n") + g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r)) + + g.emit("length, err := w.Write(buf)\n") + g.emitKeepAlive(g.r) + g.emit("return int64(length), err\n") + } else { + fallback() + } + }) + g.emit("}\n\n") +} + +func (g *interfaceGenerator) emitMarshallableSliceForStruct(st *ast.StructType, slice *sliceAPI) { + thisPacked := g.isStructPacked(st) + + if slice.inner { + abortAt(g.f.Position(slice.comment.Slash), fmt.Sprintf("The ':inner' argument to '+marshal slice:%s:inner' is only applicable to newtypes on primitives. Remove it from this struct declaration.", slice.ident)) + } + + g.recordUsedImport("marshal") + g.recordUsedImport("usermem") + + g.emit("// Copy%sIn copies in a slice of %s objects from the task's memory.\n", slice.ident, g.typeName()) + g.emit("func Copy%sIn(task marshal.Task, addr usermem.Addr, dst []%s) (int, error) {\n", slice.ident, g.typeName()) + g.inIndent(func() { + g.emit("count := len(dst)\n") + g.emit("if count == 0 {\n") + g.inIndent(func() { + g.emit("return 0, nil\n") + }) + g.emit("}\n") + g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName()) + + fallback := func() { + g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName()) + g.emit("buf := task.CopyScratchBuffer(size * count)\n") + g.emit("length, err := task.CopyInBytes(addr, buf)\n\n") + + g.emit("// Unmarshal as much as possible, even on error. First handle full objects.\n") + g.emit("limit := length/size\n") + g.emit("for idx := 0; idx < limit; idx++ {\n") + g.inIndent(func() { + g.emit("dst[idx].UnmarshalBytes(buf[size*idx:size*(idx+1)])\n") + }) + g.emit("}\n\n") + + g.emit("// Handle any final partial object.\n") + g.emit("if length < size*count && length%size != 0 {\n") + g.inIndent(func() { + g.emit("idx := limit\n") + g.emit("dst[idx].UnmarshalBytes(buf[size*idx:size*(idx+1)])\n") + }) + g.emit("}\n\n") + + g.emit("return length, err\n") + } + if thisPacked { + g.recordUsedImport("reflect") + g.recordUsedImport("runtime") + g.recordUsedImport("unsafe") + if _, ok := g.areFieldsPackedExpression(); ok { + g.emit("if !dst[0].Packed() {\n") + g.inIndent(fallback) + g.emit("}\n\n") + } + // Fast deserialization. + g.emitCastSliceToByteSlice("&dst", "buf", "size * count") + + g.emit("length, err := task.CopyInBytes(addr, buf)\n") + g.emitKeepAlive("dst") + g.emit("return length, err\n") + } else { + fallback() + } + }) + g.emit("}\n\n") + + g.emit("// Copy%sOut copies a slice of %s objects to the task's memory.\n", slice.ident, g.typeName()) + g.emit("func Copy%sOut(task marshal.Task, addr usermem.Addr, src []%s) (int, error) {\n", slice.ident, g.typeName()) + g.inIndent(func() { + g.emit("count := len(src)\n") + g.emit("if count == 0 {\n") + g.inIndent(func() { + g.emit("return 0, nil\n") + }) + g.emit("}\n") + g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName()) + + fallback := func() { + g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName()) + g.emit("buf := task.CopyScratchBuffer(size * count)\n") + g.emit("for idx := 0; idx < count; idx++ {\n") + g.inIndent(func() { + g.emit("src[idx].MarshalBytes(buf[size*idx:size*(idx+1)])\n") + }) + g.emit("}\n") + g.emit("return task.CopyOutBytes(addr, buf)\n") + } + if thisPacked { + g.recordUsedImport("reflect") + g.recordUsedImport("runtime") + g.recordUsedImport("unsafe") + if _, ok := g.areFieldsPackedExpression(); ok { + g.emit("if !src[0].Packed() {\n") + g.inIndent(fallback) + g.emit("}\n\n") + } + // Fast serialization. + g.emitCastSliceToByteSlice("&src", "buf", "size * count") + + g.emit("length, err := task.CopyOutBytes(addr, buf)\n") + g.emitKeepAlive("src") + g.emit("return length, err\n") + } else { + fallback() + } + }) + g.emit("}\n\n") + + g.emit("// MarshalUnsafe%s is like %s.MarshalUnsafe, but for a []%s.\n", slice.ident, g.typeName(), g.typeName()) + g.emit("func MarshalUnsafe%s(src []%s, dst []byte) (int, error) {\n", slice.ident, g.typeName()) + g.inIndent(func() { + g.emit("count := len(src)\n") + g.emit("if count == 0 {\n") + g.inIndent(func() { + g.emit("return 0, nil\n") + }) + g.emit("}\n") + g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName()) + + fallback := func() { + g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName()) + g.emit("for idx := 0; idx < count; idx++ {\n") + g.inIndent(func() { + g.emit("src[idx].MarshalBytes(dst[size*idx:(size)*(idx+1)])\n") + }) + g.emit("}\n") + g.emit("return size * count, nil\n") + } + if thisPacked { + g.recordUsedImport("reflect") + g.recordUsedImport("runtime") + g.recordUsedImport("unsafe") + if _, ok := g.areFieldsPackedExpression(); ok { + g.emit("if !src[0].Packed() {\n") + g.inIndent(fallback) + g.emit("}\n\n") + } + g.emitNoEscapeSliceDataPointer("&src", "val") + + g.emit("length, err := safecopy.CopyIn(dst[:(size*count)], val)\n") + g.emitKeepAlive("src") + g.emit("return length, err\n") + } else { + fallback() + } + }) + g.emit("}\n\n") + + g.emit("// UnmarshalUnsafe%s is like %s.UnmarshalUnsafe, but for a []%s.\n", slice.ident, g.typeName(), g.typeName()) + g.emit("func UnmarshalUnsafe%s(dst []%s, src []byte) (int, error) {\n", slice.ident, g.typeName()) + g.inIndent(func() { + g.emit("count := len(dst)\n") + g.emit("if count == 0 {\n") + g.inIndent(func() { + g.emit("return 0, nil\n") + }) + g.emit("}\n") + g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName()) + + fallback := func() { + g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName()) + g.emit("for idx := 0; idx < count; idx++ {\n") + g.inIndent(func() { + g.emit("dst[idx].UnmarshalBytes(src[size*idx:size*(idx+1)])\n") + }) + g.emit("}\n") + g.emit("return size * count, nil\n") + } + if thisPacked { + g.recordUsedImport("reflect") + g.recordUsedImport("runtime") + g.recordUsedImport("unsafe") + if _, ok := g.areFieldsPackedExpression(); ok { + g.emit("if !dst[0].Packed() {\n") + g.inIndent(fallback) + g.emit("}\n\n") + } + g.emitNoEscapeSliceDataPointer("&dst", "val") + + g.emit("length, err := safecopy.CopyOut(val, src[:(size*count)])\n") + g.emitKeepAlive("dst") + g.emit("return length, err\n") } else { fallback() } diff --git a/tools/go_marshal/gomarshal/generator_tests.go b/tools/go_marshal/gomarshal/generator_tests.go index fd992e44a..631295373 100644 --- a/tools/go_marshal/gomarshal/generator_tests.go +++ b/tools/go_marshal/gomarshal/generator_tests.go @@ -30,6 +30,11 @@ var standardImports = []string{ "gvisor.dev/gvisor/tools/go_marshal/analysis", } +var sliceAPIImports = []string{ + "encoding/binary", + "gvisor.dev/gvisor/pkg/usermem", +} + type testGenerator struct { sourceBuffer @@ -58,6 +63,11 @@ func newTestGenerator(t *ast.TypeSpec) *testGenerator { for _, i := range standardImports { g.imports.add(i).markUsed() } + // These imports are used if a type requests the slice API. Don't + // mark them as used by default. + for _, i := range sliceAPIImports { + g.imports.add(i) + } return g } @@ -132,6 +142,42 @@ func (g *testGenerator) emitTestMarshalUnmarshalPreservesData() { }) } +func (g *testGenerator) emitTestMarshalUnmarshalSlicePreservesData(slice *sliceAPI) { + for _, name := range []string{"binary", "usermem"} { + if !g.imports.markUsed(name) { + panic(fmt.Sprintf("Generated test for '%s' referenced a non-existent import with local name '%s'", g.typeName(), name)) + } + } + + g.inTestFunction("TestSafeMarshalUnmarshalSlicePreservesData", func() { + g.emit("var x, y, yUnsafe [8]%s\n", g.typeName()) + g.emit("analysis.RandomizeValue(&x)\n\n") + g.emit("size := (*%s)(nil).SizeBytes() * len(x)\n", g.typeName()) + g.emit("buf := bytes.NewBuffer(make([]byte, size))\n") + g.emit("buf.Reset()\n") + g.emit("if err := binary.Write(buf, usermem.ByteOrder, x[:]); err != nil {\n") + g.inIndent(func() { + g.emit("t.Fatal(fmt.Sprintf(\"binary.Write failed: %v\", err))\n") + }) + g.emit("}\n") + g.emit("bufUnsafe := make([]byte, size)\n") + g.emit("MarshalUnsafe%s(x[:], bufUnsafe)\n\n", slice.ident) + + g.emit("UnmarshalUnsafe%s(y[:], buf.Bytes())\n", slice.ident) + g.emit("if !reflect.DeepEqual(x, y) {\n") + g.inIndent(func() { + g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across binary.Write/UnmarshalUnsafeSlice cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, y))\n") + }) + g.emit("}\n") + g.emit("UnmarshalUnsafe%s(yUnsafe[:], bufUnsafe)\n", slice.ident) + g.emit("if !reflect.DeepEqual(x, yUnsafe) {\n") + g.inIndent(func() { + g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalUnsafeSlice/UnmarshalUnsafeSlice cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, yUnsafe))\n") + }) + g.emit("}\n\n") + }) +} + func (g *testGenerator) emitTestWriteToUnmarshalPreservesData() { g.inTestFunction("TestWriteToUnmarshalPreservesData", func() { g.emit("var x, y, yUnsafe %s\n", g.typeName()) @@ -170,12 +216,16 @@ func (g *testGenerator) emitTestSizeBytesOnTypedNilPtr() { }) } -func (g *testGenerator) emitTests() { +func (g *testGenerator) emitTests(slice *sliceAPI) { g.emitTestNonZeroSize() g.emitTestSuspectAlignment() g.emitTestMarshalUnmarshalPreservesData() g.emitTestWriteToUnmarshalPreservesData() g.emitTestSizeBytesOnTypedNilPtr() + + if slice != nil { + g.emitTestMarshalUnmarshalSlicePreservesData(slice) + } } func (g *testGenerator) write(out io.Writer) error { diff --git a/tools/go_marshal/gomarshal/util.go b/tools/go_marshal/gomarshal/util.go index a0936e013..d94314302 100644 --- a/tools/go_marshal/gomarshal/util.go +++ b/tools/go_marshal/gomarshal/util.go @@ -25,7 +25,6 @@ import ( "path" "reflect" "sort" - "strconv" "strings" ) @@ -75,29 +74,10 @@ func forEachStructField(st *ast.StructType, fn func(f *ast.Field)) { type fieldDispatcher struct { primitive func(n, t *ast.Ident) selector func(n, tX, tSel *ast.Ident) - array func(n, t *ast.Ident, size int) + array func(n *ast.Ident, a *ast.ArrayType, t *ast.Ident) unhandled func(n *ast.Ident) } -// Precondition: a must have a literal for the array length. Consts and -// expressions are not allowed as array lengths, and should be rejected by the -// caller. -func arrayLen(a *ast.ArrayType) int { - if a.Len == nil { - // Probably a slice? Must be handled by caller. - panic("Nil array length in array type") - } - lenLit, ok := a.Len.(*ast.BasicLit) - if !ok { - panic("Array has non-literal for length") - } - len, err := strconv.Atoi(lenLit.Value) - if err != nil { - panic(fmt.Sprintf("Failed to parse array length '%s' as number: %v", lenLit.Value, err)) - } - return len -} - // Precondition: All dispatch callbacks that will be invoked must be // provided. Embedded fields are not allowed, len(f.Names) >= 1. func (fd fieldDispatcher) dispatch(f *ast.Field) { @@ -123,7 +103,7 @@ func (fd fieldDispatcher) dispatch(f *ast.Field) { case *ast.ArrayType: switch t := v.Elt.(type) { case *ast.Ident: - fd.array(name, t, arrayLen(v)) + fd.array(name, v, t) default: // Should be handled with a better error message during validate. panic(fmt.Sprintf("Array element type is of unsupported kind. Expected *ast.Ident, got %v", t)) @@ -285,6 +265,11 @@ type importStmt struct { aliased bool // Indicates whether this import was referenced by generated code. used bool + // AST node and file set representing the import statement, if any. These + // are only non-nil if the import statement originates from an input source + // file. + spec *ast.ImportSpec + fset *token.FileSet } func newImport(p string) *importStmt { @@ -310,14 +295,27 @@ func newImportFromSpec(spec *ast.ImportSpec, f *token.FileSet) *importStmt { name: name, path: p, aliased: spec.Name != nil, + spec: spec, + fset: f, } } +// String implements fmt.Stringer.String. This generates a string for the import +// statement appropriate for writing directly to generated code. func (i *importStmt) String() string { if i.aliased { - return fmt.Sprintf("%s \"%s\"", i.name, i.path) + return fmt.Sprintf("%s %q", i.name, i.path) + } + return fmt.Sprintf("%q", i.path) +} + +// debugString returns a debug string representing an import statement. This +// representation is not valid golang code and is used for debugging output. +func (i *importStmt) debugString() string { + if i.spec != nil && i.fset != nil { + return fmt.Sprintf("%s: %s", i.fset.Position(i.spec.Path.Pos()), i) } - return fmt.Sprintf("\"%s\"", i.path) + return fmt.Sprintf("(go-marshal import): %s", i) } func (i *importStmt) markUsed() { @@ -329,40 +327,78 @@ func (i *importStmt) equivalent(other *importStmt) bool { } // importTable represents a collection of importStmts. +// +// An importTable may contain multiple import statements referencing the same +// local name. All import statements aliasing to the same local name are +// technically ambiguous, as if such an import name is used in the generated +// code, it's not clear which import statement it refers to. We ignore any +// potential collisions until actually writing the import table to the generated +// source file. See importTable.write. +// +// Given the following import statements across all the files comprising a +// package marshalled: +// +// "sync" +// "pkg/sync" +// "pkg/sentry/kernel" +// ktime "pkg/sentry/kernel/time" +// +// An importTable representing them would look like this: +// +// importTable { +// is: map[string][]*importStmt { +// "sync": []*importStmt{ +// importStmt{name:"sync", path:"sync", aliased:false} +// importStmt{name:"sync", path:"pkg/sync", aliased:false} +// }, +// "kernel": []*importStmt{importStmt{ +// name: "kernel", +// path: "pkg/sentry/kernel", +// aliased: false +// }}, +// "ktime": []*importStmt{importStmt{ +// name: "ktime", +// path: "pkg/sentry/kernel/time", +// aliased: true, +// }}, +// } +// } +// +// Note that the local name "sync" is assigned to two different import +// statements. This is possible if the import statements are from different +// source files in the same package. +// +// Since go-marshal generates a single output file per package regardless of the +// number of input files, if "sync" is referenced by any generated code, it's +// unclear which import statement "sync" refers to. While it's theoretically +// possible to resolve this by assigning a unique local alias to each instance +// of the sync package, go-marshal currently aborts when it encounters such an +// ambiguity. +// +// TODO(b/151478251): importTable considers the final component of an import +// path to be the package name, but this is only a convention. The actual +// package name is determined by the package statement in the source files for +// the package. type importTable struct { // Map of imports and whether they should be copied to the output. - is map[string]*importStmt + is map[string][]*importStmt } func newImportTable() *importTable { return &importTable{ - is: make(map[string]*importStmt), + is: make(map[string][]*importStmt), } } -// Merges import statements from other into i. Collisions in import statements -// result in a panic. +// Merges import statements from other into i. func (i *importTable) merge(other *importTable) { - for name, im := range other.is { - if dup, ok := i.is[name]; ok && !dup.equivalent(im) { - panic(fmt.Sprintf("Found colliding import statements: ours: %+v, other's: %+v", dup, im)) - } - - i.is[name] = im + for name, ims := range other.is { + i.is[name] = append(i.is[name], ims...) } } func (i *importTable) addStmt(s *importStmt) *importStmt { - if old, ok := i.is[s.name]; ok && !old.equivalent(s) { - // A collision should always be between an import inserted by the - // go-marshal tool and an import from the original source file (assuming - // the original source file was valid). We could theoretically handle - // the collision by assigning a local name to our import. However, this - // would need to be plumbed throughout the generator. Given that - // collisions should be rare, simply panic on collision. - panic(fmt.Sprintf("Import collision: old: %s as %v; new: %v as %v", old.path, old.name, s.path, s.name)) - } - i.is[s.name] = s + i.is[s.name] = append(i.is[s.name], s) return s } @@ -378,16 +414,20 @@ func (i *importTable) addFromSpec(spec *ast.ImportSpec, f *token.FileSet) *impor // Marks the import named n as used. If no such import is in the table, returns // false. func (i *importTable) markUsed(n string) bool { - if n, ok := i.is[n]; ok { - n.markUsed() + if ns, ok := i.is[n]; ok { + for _, n := range ns { + n.markUsed() + } return true } return false } func (i *importTable) clear() { - for _, i := range i.is { - i.used = false + for _, is := range i.is { + for _, i := range is { + i.used = false + } } } @@ -398,9 +438,42 @@ func (i *importTable) write(out io.Writer) error { } imports := make([]string, 0, len(i.is)) - for _, i := range i.is { - if i.used { - imports = append(imports, i.String()) + for name, is := range i.is { + var lastUsed *importStmt + var ambiguous bool + + for _, i := range is { + if i.used { + if lastUsed != nil { + if !i.equivalent(lastUsed) { + ambiguous = true + } + } + lastUsed = i + } + } + + if ambiguous { + // We have two or more import statements across the different source + // files that share a local name, and at least one of these imports + // are used by the generated code. This ambiguity can't be resolved + // by go-marshal and requires the user intervention. Dump a list of + // the colliding import statements and let the user modify the input + // files as appropriate. + var b strings.Builder + fmt.Fprintf(&b, "The imported name %q is used by one of the types marked for marshalling, and which import statement the code refers to is ambiguous. Perhaps give the imports unique local names?\n\n", name) + fmt.Fprintf(&b, "The following %d import statements are ambiguous for the local name %q:\n", len(is), name) + // Note: len(is) is guaranteed to be 1 or greater or ambiguous can't + // be true. Therefore the slicing below is safe. + for _, i := range is[:len(is)-1] { + fmt.Fprintf(&b, " %v\n", i.debugString()) + } + fmt.Fprintf(&b, " %v", is[len(is)-1].debugString()) + panic(b.String()) + } + + if lastUsed != nil { + imports = append(imports, lastUsed.String()) } } sort.Strings(imports) |