diff options
Diffstat (limited to 'tools/go_marshal/gomarshal')
-rw-r--r-- | tools/go_marshal/gomarshal/BUILD | 4 | ||||
-rw-r--r-- | tools/go_marshal/gomarshal/generator.go | 133 | ||||
-rw-r--r-- | tools/go_marshal/gomarshal/generator_interfaces.go | 368 | ||||
-rw-r--r-- | tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go | 183 | ||||
-rw-r--r-- | tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go | 229 | ||||
-rw-r--r-- | tools/go_marshal/gomarshal/generator_interfaces_struct.go | 450 | ||||
-rw-r--r-- | tools/go_marshal/gomarshal/generator_tests.go | 67 | ||||
-rw-r--r-- | tools/go_marshal/gomarshal/util.go | 71 |
8 files changed, 1073 insertions, 432 deletions
diff --git a/tools/go_marshal/gomarshal/BUILD b/tools/go_marshal/gomarshal/BUILD index c92b59dd6..44cb33ae4 100644 --- a/tools/go_marshal/gomarshal/BUILD +++ b/tools/go_marshal/gomarshal/BUILD @@ -7,6 +7,9 @@ go_library( srcs = [ "generator.go", "generator_interfaces.go", + "generator_interfaces_array_newtype.go", + "generator_interfaces_primitive_newtype.go", + "generator_interfaces_struct.go", "generator_tests.go", "util.go", ], @@ -14,4 +17,5 @@ go_library( visibility = [ "//:sandbox", ], + deps = ["//tools/tags"], ) diff --git a/tools/go_marshal/gomarshal/generator.go b/tools/go_marshal/gomarshal/generator.go index af90bdecb..729489de5 100644 --- a/tools/go_marshal/gomarshal/generator.go +++ b/tools/go_marshal/gomarshal/generator.go @@ -23,6 +23,9 @@ import ( "go/token" "os" "sort" + "strings" + + "gvisor.dev/gvisor/tools/tags" ) const ( @@ -31,9 +34,9 @@ const ( usermemImport = "gvisor.dev/gvisor/pkg/usermem" ) -// List of identifiers we use in generated code, that may conflict a -// similarly-named source identifier. Avoid problems by refusing the generate -// code when we see these. +// List of identifiers we use in generated code that may conflict with a +// similarly-named source identifier. Abort gracefully when we see these to +// avoid potentially confusing compilation failures in generated code. // // This only applies to import aliases at the moment. All other identifiers // are qualified by a receiver argument, since they're struct fields. @@ -41,10 +44,21 @@ const ( // All recievers are single letters, so we don't allow import aliases to be a // single letter. var badIdents = []string{ - "src", "srcs", "dst", "dsts", "blk", "buf", "err", + "addr", "blk", "buf", "dst", "dsts", "err", "hdr", "idx", "inner", "len", + "ptr", "src", "srcs", "task", "val", // All single-letter identifiers. } +// Constructed fromt badIdents in init(). +var badIdentsMap map[string]struct{} + +func init() { + badIdentsMap = make(map[string]struct{}) + for _, ident := range badIdents { + badIdentsMap[ident] = struct{}{} + } +} + // Generator drives code generation for a single invocation of the go_marshal // utility. // @@ -85,16 +99,20 @@ func NewGenerator(srcs []string, out, outTest, pkg string, imports []string) (*G } for _, i := range imports { // All imports on the extra imports list are unconditionally marked as - // used, so they're always added to the generated code. + // used, so that they're always added to the generated code. g.imports.add(i).markUsed() } - g.imports.add(marshalImport).markUsed() - // The follow imports may or may not be used by the generated - // code, depending what's required for the target types. Don't - // mark these imports as used by default. - g.imports.add(usermemImport) - g.imports.add(safecopyImport) + + // The following imports may or may not be used by the generated code, + // depending on what's required for the target types. Don't mark these as + // used by default. + g.imports.add("io") + g.imports.add("reflect") + g.imports.add("runtime") g.imports.add("unsafe") + g.imports.add(marshalImport) + g.imports.add(safecopyImport) + g.imports.add(usermemImport) return &g, nil } @@ -104,6 +122,14 @@ func NewGenerator(srcs []string, out, outTest, pkg string, imports []string) (*G func (g *Generator) writeHeader() error { var b sourceBuffer b.emit("// Automatically generated marshal implementation. See tools/go_marshal.\n\n") + + // Emit build tags. + if t := tags.Aggregate(g.inputs); len(t) > 0 { + b.emit(strings.Join(t.Lines(), "\n")) + b.emit("\n\n") + } + + // Package header. b.emit("package %s\n\n", g.pkg) if err := b.write(g.output); err != nil { return err @@ -168,9 +194,9 @@ func (g *Generator) parse() ([]*ast.File, []*token.FileSet, error) { return files, fsets, nil } -// collectMarshallabeTypes walks the parsed AST and collects a list of type +// collectMarshallableTypes walks the parsed AST and collects a list of type // declarations for which we need to generate the Marshallable interface. -func (g *Generator) collectMarshallabeTypes(a *ast.File, f *token.FileSet) []*ast.TypeSpec { +func (g *Generator) collectMarshallableTypes(a *ast.File, f *token.FileSet) []*ast.TypeSpec { var types []*ast.TypeSpec for _, decl := range a.Decls { gdecl, ok := decl.(*ast.GenDecl) @@ -197,14 +223,26 @@ func (g *Generator) collectMarshallabeTypes(a *ast.File, f *token.FileSet) []*as continue } for _, spec := range gdecl.Specs { - // We already confirmed we're in a type declaration earlier. + // We already confirmed we're in a type declaration earlier, so this + // cast will succeed. t := spec.(*ast.TypeSpec) - if _, ok := t.Type.(*ast.StructType); ok { - debugfAt(f.Position(t.Pos()), "Collected marshallable type %s.\n", t.Name.Name) + switch t.Type.(type) { + case *ast.StructType: + debugfAt(f.Position(t.Pos()), "Collected marshallable struct %s.\n", t.Name.Name) + types = append(types, t) + continue + case *ast.Ident: // Newtype on primitive. + debugfAt(f.Position(t.Pos()), "Collected marshallable newtype on primitive %s.\n", t.Name.Name) + types = append(types, t) + continue + case *ast.ArrayType: // Newtype on array. + debugfAt(f.Position(t.Pos()), "Collected marshallable newtype on array %s.\n", t.Name.Name) types = append(types, t) continue } - debugf("Skipping declaration %v since it's not a struct declaration.\n", gdecl) + // A user specifically requested marshalling on this type, but we + // don't support it. + abortAt(f.Position(t.Pos()), fmt.Sprintf("Marshalling codegen was requested on type '%s', but go-marshal doesn't support this kind of declaration.\n", t.Name)) } } return types @@ -218,11 +256,6 @@ func (g *Generator) collectMarshallabeTypes(a *ast.File, f *token.FileSet) []*as // identifiers in the generated code don't conflict with any imported package // names. func (g *Generator) collectImports(a *ast.File, f *token.FileSet) map[string]importStmt { - badImportNames := make(map[string]bool) - for _, i := range badIdents { - badImportNames[i] = true - } - is := make(map[string]importStmt) for _, decl := range a.Decls { gdecl, ok := decl.(*ast.GenDecl) @@ -239,7 +272,7 @@ func (g *Generator) collectImports(a *ast.File, f *token.FileSet) map[string]imp if len(i.name) == 1 { abortAt(f.Position(spec.Pos()), fmt.Sprintf("Import has a single character local name '%s'; this may conflict with code generated by go_marshal, use a multi-character import alias", i.name)) } - if badImportNames[i.name] { + if _, ok := badIdentsMap[i.name]; ok { abortAt(f.Position(spec.Pos()), fmt.Sprintf("Import name '%s' is likely to conflict with code generated by go_marshal, use a different import alias", i.name)) } } @@ -249,11 +282,22 @@ func (g *Generator) collectImports(a *ast.File, f *token.FileSet) map[string]imp } func (g *Generator) generateOne(t *ast.TypeSpec, fset *token.FileSet) *interfaceGenerator { - // We're guaranteed to have only struct type specs by now. See - // Generator.collectMarshallabeTypes. i := newInterfaceGenerator(t, fset) - i.validate() - i.emitMarshallable() + switch ty := t.Type.(type) { + case *ast.StructType: + i.validateStruct(t, ty) + i.emitMarshallableForStruct(ty) + case *ast.Ident: + i.validatePrimitiveNewtype(ty) + i.emitMarshallableForPrimitiveNewtype(ty) + case *ast.ArrayType: + i.validateArrayNewtype(t.Name, ty) + // After validate, we can safely call arrayLen. + i.emitMarshallableForArrayNewtype(t.Name, ty.Elt.(*ast.Ident), arrayLen(ty)) + default: + // This should've been filtered out by collectMarshallabeTypes. + panic(fmt.Sprintf("Unexpected type %+v", ty)) + } return i } @@ -300,7 +344,7 @@ func (g *Generator) Run() error { for i, a := range asts { // Collect type declarations marked for code generation and generate // Marshallable interfaces. - for _, t := range g.collectMarshallabeTypes(a, fsets[i]) { + for _, t := range g.collectMarshallableTypes(a, fsets[i]) { impl := g.generateOne(t, fsets[i]) // Collect Marshallable types referenced by the generated code. for ref, _ := range impl.ms { @@ -318,17 +362,6 @@ func (g *Generator) Run() error { } } - // Tool was invoked with input files with no data structures marked for code - // generation. This is probably not what the user intended. - if len(impls) == 0 { - var buf bytes.Buffer - fmt.Fprintf(&buf, "go_marshal invoked on these files, but they don't contain any types requiring code generation. Perhaps mark some with \"// +marshal\"?:\n") - for _, i := range g.inputs { - fmt.Fprintf(&buf, " %s\n", i) - } - abort(buf.String()) - } - // Write output file header. These include things like package name and // import statements. if err := g.writeHeader(); err != nil { @@ -360,6 +393,7 @@ func (g *Generator) writeTests(ts []*testGenerator) error { return err } + // Collect and write test import statements. imports := newImportTable() for _, t := range ts { imports.merge(t.imports) @@ -369,6 +403,27 @@ func (g *Generator) writeTests(ts []*testGenerator) error { return err } + // Write test functions. + + // If we didn't generate any Marshallable implementations, we can't just + // emit an empty test file, since that causes the build to fail with "no + // tests/benchmarks/examples found". Unfortunately we can't signal bazel to + // omit the entire package since the outputs are already defined before + // go-marshal is called. If we'd otherwise emit an empty test suite, emit an + // empty example instead. + if len(ts) == 0 { + b.reset() + b.emit("func ExampleEmptyTestSuite() {\n") + b.inIndent(func() { + b.emit("// This example is intentionally empty to ensure this file contains at least\n") + b.emit("// one testable entity. go-marshal is forced to emit a test file if a package\n") + b.emit("// is marked marshallable, but emitting a test file with no entities results\n") + b.emit("// in a build failure.\n") + }) + b.emit("}\n") + return b.write(g.outputTest) + } + for _, t := range ts { if err := t.write(g.outputTest); err != nil { return err diff --git a/tools/go_marshal/gomarshal/generator_interfaces.go b/tools/go_marshal/gomarshal/generator_interfaces.go index a712c14dc..8babf61d2 100644 --- a/tools/go_marshal/gomarshal/generator_interfaces.go +++ b/tools/go_marshal/gomarshal/generator_interfaces.go @@ -15,10 +15,8 @@ package gomarshal import ( - "fmt" "go/ast" "go/token" - "strings" ) // interfaceGenerator generates marshalling interfaces for a single type. @@ -55,9 +53,6 @@ func (g *interfaceGenerator) typeName() string { // newinterfaceGenerator creates a new interface generator. func newInterfaceGenerator(t *ast.TypeSpec, fset *token.FileSet) *interfaceGenerator { - if _, ok := t.Type.(*ast.StructType); !ok { - panic(fmt.Sprintf("Attempting to generate code for a not struct type %v", t)) - } g := &interfaceGenerator{ t: t, r: receiverName(t), @@ -84,18 +79,6 @@ func (g *interfaceGenerator) recordPotentiallyNonPackedField(fieldName string) { g.as[fieldName] = struct{}{} } -func (g *interfaceGenerator) forEachField(fn func(f *ast.Field)) { - // This is guaranteed to succeed because g.t is always a struct. - st := g.t.Type.(*ast.StructType) - for _, field := range st.Fields.List { - fn(field) - } -} - -func (g *interfaceGenerator) fieldAccessor(n *ast.Ident) string { - return fmt.Sprintf("%s.%s", g.r, n.Name) -} - // abortAt aborts the go_marshal tool with the given error message, with a // reference position to the input source. Same as abortAt, but uses g to // resolve p to position. @@ -103,67 +86,6 @@ func (g *interfaceGenerator) abortAt(p token.Pos, msg string) { abortAt(g.f.Position(p), msg) } -// validate ensures the type we're working with can be marshalled. These checks -// are done ahead of time and in one place so we can make assumptions later. -func (g *interfaceGenerator) validate() { - g.forEachField(func(f *ast.Field) { - if len(f.Names) == 0 { - g.abortAt(f.Pos(), "Cannot marshal structs with embedded fields, give the field a name; use '_' for anonymous fields such as padding fields") - } - }) - - g.forEachField(func(f *ast.Field) { - fieldDispatcher{ - primitive: func(_, t *ast.Ident) { - switch t.Name { - case "int8", "uint8", "byte", "int16", "uint16", "int32", "uint32", "int64", "uint64": - // These are the only primitive types we're allow. Below, we - // provide suggestions for some disallowed types and reject - // them, then attempt to marshal any remaining types by - // invoking the marshal.Marshallable interface on them. If - // these types don't actually implement - // marshal.Marshallable, compilation of the generated code - // will fail with an appropriate error message. - return - case "int": - g.abortAt(f.Pos(), "Type 'int' has ambiguous width, use int32 or int64") - case "uint": - g.abortAt(f.Pos(), "Type 'uint' has ambiguous width, use uint32 or uint64") - case "string": - g.abortAt(f.Pos(), "Type 'string' is dynamically-sized and cannot be marshalled, use a fixed size byte array '[...]byte' instead") - default: - debugfAt(g.f.Position(f.Pos()), fmt.Sprintf("Found derived type '%s', will attempt dispatch via marshal.Marshallable.\n", t.Name)) - } - }, - selector: func(_, _, _ *ast.Ident) { - // No validation to perform on selector fields. However this - // callback must still be provided. - }, - array: func(n, _ *ast.Ident, len int) { - a := f.Type.(*ast.ArrayType) - if a.Len == nil { - g.abortAt(f.Pos(), fmt.Sprintf("Dynamically sized slice '%s' cannot be marshalled, arrays must be statically sized", n.Name)) - } - - if _, ok := a.Len.(*ast.BasicLit); !ok { - g.abortAt(a.Len.Pos(), fmt.Sprintf("Array size must be a literal, don's use consts or expressions")) - } - - if _, ok := a.Elt.(*ast.Ident); !ok { - g.abortAt(a.Elt.Pos(), fmt.Sprintf("Marshalling not supported for arrays with %s elements, array elements must be primitive types", kindString(a.Elt))) - } - - if len <= 0 { - g.abortAt(a.Len.Pos(), fmt.Sprintf("Marshalling not supported for zero length arrays, why does an ABI struct have one?")) - } - }, - unhandled: func(_ *ast.Ident) { - g.abortAt(f.Pos(), fmt.Sprintf("Marshalling not supported for %s fields", kindString(f.Type))) - }, - }.dispatch(f) - }) -} - // scalarSize returns the size of type identified by t. If t isn't a primitive // type, the size isn't known at code generation time, and must be resolved via // the marshal.Marshallable interface. @@ -190,7 +112,8 @@ func (g *interfaceGenerator) shiftDynamic(bufVar, name string) { g.emit("%s = %s[%s.SizeBytes():]\n", bufVar, bufVar, name) } -func (g *interfaceGenerator) marshalScalar(accessor, typ string, bufVar string) { +// marshalScalar writes a single scalar to a byte slice. +func (g *interfaceGenerator) marshalScalar(accessor, typ, bufVar string) { switch typ { case "int8", "uint8", "byte": g.emit("%s[0] = byte(%s)\n", bufVar, accessor) @@ -213,43 +136,26 @@ func (g *interfaceGenerator) marshalScalar(accessor, typ string, bufVar string) } } -func (g *interfaceGenerator) unmarshalScalar(accessor, typ string, bufVar string) { +// unmarshalScalar reads a single scalar from a byte slice. +func (g *interfaceGenerator) unmarshalScalar(accessor, typ, bufVar string) { switch typ { - case "int8": - g.emit("%s = int8(%s[0])\n", accessor, bufVar) - g.shift(bufVar, 1) - case "uint8": - g.emit("%s = uint8(%s[0])\n", accessor, bufVar) - g.shift(bufVar, 1) case "byte": g.emit("%s = %s[0]\n", accessor, bufVar) g.shift(bufVar, 1) - - case "int16": - g.recordUsedImport("usermem") - g.emit("%s = int16(usermem.ByteOrder.Uint16(%s[:2]))\n", accessor, bufVar) - g.shift(bufVar, 2) - case "uint16": + case "int8", "uint8": + g.emit("%s = %s(%s[0])\n", accessor, typ, bufVar) + g.shift(bufVar, 1) + case "int16", "uint16": g.recordUsedImport("usermem") - g.emit("%s = usermem.ByteOrder.Uint16(%s[:2])\n", accessor, bufVar) + g.emit("%s = %s(usermem.ByteOrder.Uint16(%s[:2]))\n", accessor, typ, bufVar) g.shift(bufVar, 2) - - case "int32": - g.recordUsedImport("usermem") - g.emit("%s = int32(usermem.ByteOrder.Uint32(%s[:4]))\n", accessor, bufVar) - g.shift(bufVar, 4) - case "uint32": + case "int32", "uint32": g.recordUsedImport("usermem") - g.emit("%s = usermem.ByteOrder.Uint32(%s[:4])\n", accessor, bufVar) + g.emit("%s = %s(usermem.ByteOrder.Uint32(%s[:4]))\n", accessor, typ, bufVar) g.shift(bufVar, 4) - - case "int64": - g.recordUsedImport("usermem") - g.emit("%s = int64(usermem.ByteOrder.Uint64(%s[:8]))\n", accessor, bufVar) - g.shift(bufVar, 8) - case "uint64": + case "int64", "uint64": g.recordUsedImport("usermem") - g.emit("%s = usermem.ByteOrder.Uint64(%s[:8])\n", accessor, bufVar) + g.emit("%s = %s(usermem.ByteOrder.Uint64(%s[:8]))\n", accessor, typ, bufVar) g.shift(bufVar, 8) default: g.emit("%s.UnmarshalBytes(%s[:%s.SizeBytes()])\n", accessor, bufVar, accessor) @@ -257,251 +163,3 @@ func (g *interfaceGenerator) unmarshalScalar(accessor, typ string, bufVar string g.recordPotentiallyNonPackedField(accessor) } } - -// areFieldsPackedExpression returns a go expression checking whether g.t's fields are -// packed. Returns "", false if g.t has no fields that may be potentially -// packed, otherwise returns <clause>, true, where <clause> is an expression -// like "t.a.Packed() && t.b.Packed() && t.c.Packed()". -func (g *interfaceGenerator) areFieldsPackedExpression() (string, bool) { - if len(g.as) == 0 { - return "", false - } - - cs := make([]string, 0, len(g.as)) - for accessor, _ := range g.as { - cs = append(cs, fmt.Sprintf("%s.Packed()", accessor)) - } - return strings.Join(cs, " && "), true -} - -func (g *interfaceGenerator) emitMarshallable() { - // Is g.t a packed struct without consideing field types? - thisPacked := true - g.forEachField(func(f *ast.Field) { - if f.Tag != nil { - if f.Tag.Value == "`marshal:\"unaligned\"`" { - if thisPacked { - debugfAt(g.f.Position(g.t.Pos()), - fmt.Sprintf("Marking type '%s' as not packed due to tag `marshal:\"unaligned\"`.\n", g.t.Name)) - thisPacked = false - } - } - } - }) - - g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n") - g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName()) - g.inIndent(func() { - primitiveSize := 0 - var dynamicSizeTerms []string - - g.forEachField(fieldDispatcher{ - primitive: func(n, t *ast.Ident) { - if size, dynamic := g.scalarSize(t); !dynamic { - primitiveSize += size - } else { - g.recordUsedMarshallable(t.Name) - dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("%s.SizeBytes()", g.fieldAccessor(n))) - } - }, - selector: func(n, tX, tSel *ast.Ident) { - tName := fmt.Sprintf("%s.%s", tX.Name, tSel.Name) - g.recordUsedImport(tX.Name) - g.recordUsedMarshallable(tName) - dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", tName)) - }, - array: func(n, t *ast.Ident, len int) { - if len < 1 { - // Zero-length arrays should've been rejected by validate(). - panic("unreachable") - } - if size, dynamic := g.scalarSize(t); !dynamic { - primitiveSize += size * len - } else { - g.recordUsedMarshallable(t.Name) - dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()*%d", t.Name, len)) - } - }, - }.dispatch) - g.emit("return %d", primitiveSize) - if len(dynamicSizeTerms) > 0 { - g.incIndent() - } - { - for _, d := range dynamicSizeTerms { - g.emitNoIndent(" +\n") - g.emit(d) - } - } - if len(dynamicSizeTerms) > 0 { - g.decIndent() - } - }) - g.emit("\n}\n\n") - - g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n") - g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName()) - g.inIndent(func() { - g.forEachField(fieldDispatcher{ - primitive: func(n, t *ast.Ident) { - if n.Name == "_" { - g.emit("// Padding: dst[:sizeof(%s)] ~= %s(0)\n", t.Name, t.Name) - if len, dynamic := g.scalarSize(t); !dynamic { - g.shift("dst", len) - } else { - // We can't use shiftDynamic here because we don't have - // an instance of the dynamic type we can referece here - // (since the version in this struct is anonymous). Use - // a typed nil pointer to call SizeBytes() instead. - g.emit("dst = dst[(*%s)(nil).SizeBytes():]\n", t.Name) - } - return - } - g.marshalScalar(g.fieldAccessor(n), t.Name, "dst") - }, - selector: func(n, tX, tSel *ast.Ident) { - g.marshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "dst") - }, - array: func(n, t *ast.Ident, size int) { - if n.Name == "_" { - g.emit("// Padding: dst[:sizeof(%s)*%d] ~= [%d]%s{0}\n", t.Name, size, size, t.Name) - if len, dynamic := g.scalarSize(t); !dynamic { - g.shift("dst", len*size) - } else { - // We can't use shiftDynamic here because we don't have - // an instance of the dynamic type we can reference here - // (since the version in this struct is anonymous). Use - // a typed nil pointer to call SizeBytes() instead. - g.emit("dst = dst[(*%s)(nil).SizeBytes()*%d:]\n", t.Name, size) - } - return - } - - g.emit("for i := 0; i < %d; i++ {\n", size) - g.inIndent(func() { - g.marshalScalar(fmt.Sprintf("%s[i]", g.fieldAccessor(n)), t.Name, "dst") - }) - g.emit("}\n") - }, - }.dispatch) - }) - g.emit("}\n\n") - - g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n") - g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName()) - g.inIndent(func() { - g.forEachField(fieldDispatcher{ - primitive: func(n, t *ast.Ident) { - if n.Name == "_" { - g.emit("// Padding: var _ %s ~= src[:sizeof(%s)]\n", t.Name, t.Name) - if len, dynamic := g.scalarSize(t); !dynamic { - g.shift("src", len) - } else { - // We can't use shiftDynamic here because we don't have - // an instance of the dynamic type we can reference here - // (since the version in this struct is anonymous). Use - // a typed nil pointer to call SizeBytes() instead. - g.emit("src = src[(*%s)(nil).SizeBytes():]\n", t.Name) - g.recordPotentiallyNonPackedField(fmt.Sprintf("(*%s)(nil)", t.Name)) - } - return - } - g.unmarshalScalar(g.fieldAccessor(n), t.Name, "src") - }, - selector: func(n, tX, tSel *ast.Ident) { - g.unmarshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "src") - }, - array: func(n, t *ast.Ident, size int) { - if n.Name == "_" { - g.emit("// Padding: ~ copy([%d]%s(%s), src[:sizeof(%s)*%d])\n", size, t.Name, g.fieldAccessor(n), t.Name, size) - if len, dynamic := g.scalarSize(t); !dynamic { - g.shift("src", len*size) - } else { - // We can't use shiftDynamic here because we don't have - // an instance of the dynamic type we can referece here - // (since the version in this struct is anonymous). Use - // a typed nil pointer to call SizeBytes() instead. - g.emit("src = src[(*%s)(nil).SizeBytes()*%d:]\n", t.Name, size) - } - return - } - - g.emit("for i := 0; i < %d; i++ {\n", size) - g.inIndent(func() { - g.unmarshalScalar(fmt.Sprintf("%s[i]", g.fieldAccessor(n)), t.Name, "src") - }) - g.emit("}\n") - }, - }.dispatch) - }) - g.emit("}\n\n") - - g.emit("// Packed implements marshal.Marshallable.Packed.\n") - g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName()) - g.inIndent(func() { - expr, fieldsMaybePacked := g.areFieldsPackedExpression() - switch { - case !thisPacked: - g.emit("return false\n") - case fieldsMaybePacked: - g.emit("return %s\n", expr) - default: - g.emit("return true\n") - - } - }) - g.emit("}\n\n") - - g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n") - g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName()) - g.inIndent(func() { - if thisPacked { - g.recordUsedImport("safecopy") - g.recordUsedImport("unsafe") - if cond, ok := g.areFieldsPackedExpression(); ok { - g.emit("if %s {\n", cond) - g.inIndent(func() { - g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r) - }) - g.emit("} else {\n") - g.inIndent(func() { - g.emit("%s.MarshalBytes(dst)\n", g.r) - }) - g.emit("}\n") - } else { - g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r) - } - } else { - g.emit("// Type %s doesn't have a packed layout in memory, fallback to MarshalBytes.\n", g.typeName()) - g.emit("%s.MarshalBytes(dst)\n", g.r) - } - }) - g.emit("}\n\n") - - g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n") - g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName()) - g.inIndent(func() { - if thisPacked { - g.recordUsedImport("safecopy") - g.recordUsedImport("unsafe") - if cond, ok := g.areFieldsPackedExpression(); ok { - g.emit("if %s {\n", cond) - g.inIndent(func() { - g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r) - }) - g.emit("} else {\n") - g.inIndent(func() { - g.emit("%s.UnmarshalBytes(src)\n", g.r) - }) - g.emit("}\n") - } else { - g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r) - } - } else { - g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName()) - g.emit("%s.UnmarshalBytes(src)\n", g.r) - } - }) - g.emit("}\n\n") - -} diff --git a/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go b/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go new file mode 100644 index 000000000..da36d9305 --- /dev/null +++ b/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go @@ -0,0 +1,183 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This file contains the bits of the code generator specific to marshalling +// newtypes on arrays. + +package gomarshal + +import ( + "fmt" + "go/ast" +) + +func (g *interfaceGenerator) validateArrayNewtype(n *ast.Ident, a *ast.ArrayType) { + if a.Len == nil { + g.abortAt(a.Pos(), fmt.Sprintf("Dynamically sized slice '%s' cannot be marshalled, arrays must be statically sized", n.Name)) + } + + if _, ok := a.Len.(*ast.BasicLit); !ok { + g.abortAt(a.Len.Pos(), fmt.Sprintf("Array size must be a literal, don't use consts or expressions")) + } + + if _, ok := a.Elt.(*ast.Ident); !ok { + g.abortAt(a.Elt.Pos(), fmt.Sprintf("Marshalling not supported for arrays with %s elements, array elements must be primitive types", kindString(a.Elt))) + } + + if arrayLen(a) <= 0 { + g.abortAt(a.Len.Pos(), fmt.Sprintf("Marshalling not supported for zero length arrays, why does an ABI struct have one?")) + } +} + +func (g *interfaceGenerator) emitMarshallableForArrayNewtype(n, elt *ast.Ident, len int) { + g.recordUsedImport("io") + g.recordUsedImport("marshal") + g.recordUsedImport("reflect") + g.recordUsedImport("runtime") + g.recordUsedImport("safecopy") + g.recordUsedImport("unsafe") + g.recordUsedImport("usermem") + + g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n") + g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName()) + g.inIndent(func() { + if size, dynamic := g.scalarSize(elt); !dynamic { + g.emit("return %d\n", size*len) + } else { + g.emit("return (*%s)(nil).SizeBytes() * %d\n", n.Name, len) + } + }) + g.emit("}\n\n") + + g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n") + g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("for idx := 0; idx < %d; idx++ {\n", len) + g.inIndent(func() { + g.marshalScalar(fmt.Sprintf("%s[idx]", g.r), elt.Name, "dst") + }) + g.emit("}\n") + }) + g.emit("}\n\n") + + g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n") + g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("for idx := 0; idx < %d; idx++ {\n", len) + g.inIndent(func() { + g.unmarshalScalar(fmt.Sprintf("%s[idx]", g.r), elt.Name, "src") + }) + g.emit("}\n") + }) + g.emit("}\n\n") + + g.emit("// Packed implements marshal.Marshallable.Packed.\n") + g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("// Array newtypes are always packed.\n") + g.emit("return true\n") + }) + g.emit("}\n\n") + + g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n") + g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r) + }) + g.emit("}\n\n") + + g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n") + g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r) + }) + g.emit("}\n\n") + + g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n") + g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName()) + g.inIndent(func() { + // Fast serialization. + g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) + g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) + g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") + g.emit("ptr := unsafe.Pointer(%s)\n", g.r) + g.emit("val := uintptr(ptr)\n") + g.emit("val = val^0\n\n") + + g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) + g.emit("var buf []byte\n") + g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") + g.emit("hdr.Data = val\n") + g.emit("hdr.Len = %s.SizeBytes()\n", g.r) + g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) + + g.emit("_, err := task.CopyOutBytes(addr, buf)\n") + g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) + g.emit("// must live until after the CopyOutBytes.\n") + g.emit("runtime.KeepAlive(%s)\n", g.r) + g.emit("return err\n") + }) + g.emit("}\n\n") + + g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n") + g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) + g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) + g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") + g.emit("ptr := unsafe.Pointer(%s)\n", g.r) + g.emit("val := uintptr(ptr)\n") + g.emit("val = val^0\n\n") + + g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) + g.emit("var buf []byte\n") + g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") + g.emit("hdr.Data = val\n") + g.emit("hdr.Len = %s.SizeBytes()\n", g.r) + g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) + + g.emit("_, err := task.CopyInBytes(addr, buf)\n") + g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) + g.emit("// must live until after the CopyInBytes.\n") + g.emit("runtime.KeepAlive(%s)\n", g.r) + g.emit("return err\n") + }) + g.emit("}\n\n") + + g.emit("// WriteTo implements io.WriterTo.WriteTo.\n") + g.emit("func (%s *%s) WriteTo(w io.Writer) (int64, error) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) + g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) + g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") + g.emit("ptr := unsafe.Pointer(%s)\n", g.r) + g.emit("val := uintptr(ptr)\n") + g.emit("val = val^0\n\n") + + g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) + g.emit("var buf []byte\n") + g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") + g.emit("hdr.Data = val\n") + g.emit("hdr.Len = %s.SizeBytes()\n", g.r) + g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) + + g.emit("len, err := w.Write(buf)\n") + g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) + g.emit("// must live until after the Write.\n") + g.emit("runtime.KeepAlive(%s)\n", g.r) + g.emit("return int64(len), err\n") + + }) + g.emit("}\n\n") +} diff --git a/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go b/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go new file mode 100644 index 000000000..159397825 --- /dev/null +++ b/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go @@ -0,0 +1,229 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This file contains the bits of the code generator specific to marshalling +// newtypes on primitives. + +package gomarshal + +import ( + "fmt" + "go/ast" +) + +// marshalPrimitiveScalar writes a single primitive variable to a byte +// slice. +func (g *interfaceGenerator) marshalPrimitiveScalar(accessor, typ, bufVar string) { + switch typ { + case "int8", "uint8", "byte": + g.emit("%s[0] = byte(*%s)\n", bufVar, accessor) + case "int16", "uint16": + g.recordUsedImport("usermem") + g.emit("usermem.ByteOrder.PutUint16(%s[:2], uint16(*%s))\n", bufVar, accessor) + case "int32", "uint32": + g.recordUsedImport("usermem") + g.emit("usermem.ByteOrder.PutUint32(%s[:4], uint32(*%s))\n", bufVar, accessor) + case "int64", "uint64": + g.recordUsedImport("usermem") + g.emit("usermem.ByteOrder.PutUint64(%s[:8], uint64(*%s))\n", bufVar, accessor) + default: + g.emit("// Explicilty cast to the underlying type before dispatching to\n") + g.emit("// MarshalBytes, so we don't recursively call %s.MarshalBytes\n", accessor) + g.emit("inner := (*%s)(%s)\n", typ, accessor) + g.emit("inner.MarshalBytes(%s[:%s.SizeBytes()])\n", bufVar, accessor) + } +} + +// unmarshalPrimitiveScalar read a single primitive variable from a byte slice. +func (g *interfaceGenerator) unmarshalPrimitiveScalar(accessor, typ, bufVar, typeCast string) { + switch typ { + case "byte": + g.emit("*%s = %s(%s[0])\n", accessor, typeCast, bufVar) + case "int8", "uint8": + g.emit("*%s = %s(%s(%s[0]))\n", accessor, typeCast, typ, bufVar) + case "int16", "uint16": + g.recordUsedImport("usermem") + g.emit("*%s = %s(%s(usermem.ByteOrder.Uint16(%s[:2])))\n", accessor, typeCast, typ, bufVar) + case "int32", "uint32": + g.recordUsedImport("usermem") + g.emit("*%s = %s(%s(usermem.ByteOrder.Uint32(%s[:4])))\n", accessor, typeCast, typ, bufVar) + case "int64", "uint64": + g.recordUsedImport("usermem") + g.emit("*%s = %s(%s(usermem.ByteOrder.Uint64(%s[:8])))\n", accessor, typeCast, typ, bufVar) + default: + g.emit("// Explicilty cast to the underlying type before dispatching to\n") + g.emit("// UnmarshalBytes, so we don't recursively call %s.UnmarshalBytes\n", accessor) + g.emit("inner := (*%s)(%s)\n", typ, accessor) + g.emit("inner.UnmarshalBytes(%s[:%s.SizeBytes()])\n", bufVar, accessor) + } +} + +func (g *interfaceGenerator) validatePrimitiveNewtype(t *ast.Ident) { + switch t.Name { + case "int8", "uint8", "byte", "int16", "uint16", "int32", "uint32", "int64", "uint64": + // These are the only primitive types we're allow. Below, we provide + // suggestions for some disallowed types and reject them, then attempt + // to marshal any remaining types by invoking the marshal.Marshallable + // interface on them. If these types don't actually implement + // marshal.Marshallable, compilation of the generated code will fail + // with an appropriate error message. + return + case "int": + g.abortAt(t.Pos(), "Type 'int' has ambiguous width, use int32 or int64") + case "uint": + g.abortAt(t.Pos(), "Type 'uint' has ambiguous width, use uint32 or uint64") + case "string": + g.abortAt(t.Pos(), "Type 'string' is dynamically-sized and cannot be marshalled, use a fixed size byte array '[...]byte' instead") + default: + debugfAt(g.f.Position(t.Pos()), fmt.Sprintf("Found derived type '%s', will attempt dispatch via marshal.Marshallable.\n", t.Name)) + } +} + +// emitMarshallableForPrimitiveNewtype outputs code to implement the +// marshal.Marshallable interface for a newtype on a primitive. Primitive +// newtypes are always packed, so we can omit the various fallbacks required for +// non-packed structs. +func (g *interfaceGenerator) emitMarshallableForPrimitiveNewtype(nt *ast.Ident) { + g.recordUsedImport("io") + g.recordUsedImport("marshal") + g.recordUsedImport("reflect") + g.recordUsedImport("runtime") + g.recordUsedImport("safecopy") + g.recordUsedImport("unsafe") + g.recordUsedImport("usermem") + + g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n") + g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName()) + g.inIndent(func() { + if size, dynamic := g.scalarSize(nt); !dynamic { + g.emit("return %d\n", size) + } else { + g.emit("return (*%s)(nil).SizeBytes()\n", nt.Name) + } + }) + g.emit("}\n\n") + + g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n") + g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.marshalPrimitiveScalar(g.r, nt.Name, "dst") + }) + g.emit("}\n\n") + + g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n") + g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.unmarshalPrimitiveScalar(g.r, nt.Name, "src", g.typeName()) + }) + g.emit("}\n\n") + + g.emit("// Packed implements marshal.Marshallable.Packed.\n") + g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("// Scalar newtypes are always packed.\n") + g.emit("return true\n") + }) + g.emit("}\n\n") + + g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n") + g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r) + }) + g.emit("}\n\n") + + g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n") + g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r) + }) + g.emit("}\n\n") + + g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n") + g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName()) + g.inIndent(func() { + // Fast serialization. + g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) + g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) + g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") + g.emit("ptr := unsafe.Pointer(%s)\n", g.r) + g.emit("val := uintptr(ptr)\n") + g.emit("val = val^0\n\n") + + g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) + g.emit("var buf []byte\n") + g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") + g.emit("hdr.Data = val\n") + g.emit("hdr.Len = %s.SizeBytes()\n", g.r) + g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) + + g.emit("_, err := task.CopyOutBytes(addr, buf)\n") + g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) + g.emit("// must live until after the CopyOutBytes.\n") + g.emit("runtime.KeepAlive(%s)\n", g.r) + g.emit("return err\n") + }) + g.emit("}\n\n") + + g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n") + g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) + g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) + g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") + g.emit("ptr := unsafe.Pointer(%s)\n", g.r) + g.emit("val := uintptr(ptr)\n") + g.emit("val = val^0\n\n") + + g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) + g.emit("var buf []byte\n") + g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") + g.emit("hdr.Data = val\n") + g.emit("hdr.Len = %s.SizeBytes()\n", g.r) + g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) + + g.emit("_, err := task.CopyInBytes(addr, buf)\n") + g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) + g.emit("// must live until after the CopyInBytes.\n") + g.emit("runtime.KeepAlive(%s)\n", g.r) + g.emit("return err\n") + }) + g.emit("}\n\n") + + g.emit("// WriteTo implements io.WriterTo.WriteTo.\n") + g.emit("func (%s *%s) WriteTo(w io.Writer) (int64, error) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) + g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) + g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") + g.emit("ptr := unsafe.Pointer(%s)\n", g.r) + g.emit("val := uintptr(ptr)\n") + g.emit("val = val^0\n\n") + + g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) + g.emit("var buf []byte\n") + g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") + g.emit("hdr.Data = val\n") + g.emit("hdr.Len = %s.SizeBytes()\n", g.r) + g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) + + g.emit("len, err := w.Write(buf)\n") + g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) + g.emit("// must live until after the Write.\n") + g.emit("runtime.KeepAlive(%s)\n", g.r) + g.emit("return int64(len), err\n") + + }) + g.emit("}\n\n") +} diff --git a/tools/go_marshal/gomarshal/generator_interfaces_struct.go b/tools/go_marshal/gomarshal/generator_interfaces_struct.go new file mode 100644 index 000000000..e66a38b2e --- /dev/null +++ b/tools/go_marshal/gomarshal/generator_interfaces_struct.go @@ -0,0 +1,450 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This file contains the bits of the code generator specific to marshalling +// structs. + +package gomarshal + +import ( + "fmt" + "go/ast" + "strings" +) + +func (g *interfaceGenerator) fieldAccessor(n *ast.Ident) string { + return fmt.Sprintf("%s.%s", g.r, n.Name) +} + +// areFieldsPackedExpression returns a go expression checking whether g.t's fields are +// packed. Returns "", false if g.t has no fields that may be potentially +// packed, otherwise returns <clause>, true, where <clause> is an expression +// like "t.a.Packed() && t.b.Packed() && t.c.Packed()". +func (g *interfaceGenerator) areFieldsPackedExpression() (string, bool) { + if len(g.as) == 0 { + return "", false + } + + cs := make([]string, 0, len(g.as)) + for accessor, _ := range g.as { + cs = append(cs, fmt.Sprintf("%s.Packed()", accessor)) + } + return strings.Join(cs, " && "), true +} + +// validateStruct ensures the type we're working with can be marshalled. These +// checks are done ahead of time and in one place so we can make assumptions +// later. +func (g *interfaceGenerator) validateStruct(ts *ast.TypeSpec, st *ast.StructType) { + forEachStructField(st, func(f *ast.Field) { + if len(f.Names) == 0 { + g.abortAt(f.Pos(), "Cannot marshal structs with embedded fields, give the field a name; use '_' for anonymous fields such as padding fields") + } + }) + + forEachStructField(st, func(f *ast.Field) { + fieldDispatcher{ + primitive: func(_, t *ast.Ident) { + g.validatePrimitiveNewtype(t) + }, + selector: func(_, _, _ *ast.Ident) { + // No validation to perform on selector fields. However this + // callback must still be provided. + }, + array: func(n, _ *ast.Ident, len int) { + g.validateArrayNewtype(n, f.Type.(*ast.ArrayType)) + }, + unhandled: func(_ *ast.Ident) { + g.abortAt(f.Pos(), fmt.Sprintf("Marshalling not supported for %s fields", kindString(f.Type))) + }, + }.dispatch(f) + }) +} + +func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) { + // Is g.t a packed struct without consideing field types? + thisPacked := true + forEachStructField(st, func(f *ast.Field) { + if f.Tag != nil { + if f.Tag.Value == "`marshal:\"unaligned\"`" { + if thisPacked { + debugfAt(g.f.Position(g.t.Pos()), + fmt.Sprintf("Marking type '%s' as not packed due to tag `marshal:\"unaligned\"`.\n", g.t.Name)) + thisPacked = false + } + } + } + }) + + g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n") + g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName()) + g.inIndent(func() { + primitiveSize := 0 + var dynamicSizeTerms []string + + forEachStructField(st, fieldDispatcher{ + primitive: func(n, t *ast.Ident) { + if size, dynamic := g.scalarSize(t); !dynamic { + primitiveSize += size + } else { + g.recordUsedMarshallable(t.Name) + dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", t.Name)) + } + }, + selector: func(n, tX, tSel *ast.Ident) { + tName := fmt.Sprintf("%s.%s", tX.Name, tSel.Name) + g.recordUsedImport(tX.Name) + g.recordUsedMarshallable(tName) + dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", tName)) + }, + array: func(n, t *ast.Ident, len int) { + if len < 1 { + // Zero-length arrays should've been rejected by validate(). + panic("unreachable") + } + if size, dynamic := g.scalarSize(t); !dynamic { + primitiveSize += size * len + } else { + g.recordUsedMarshallable(t.Name) + dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()*%d", t.Name, len)) + } + }, + }.dispatch) + g.emit("return %d", primitiveSize) + if len(dynamicSizeTerms) > 0 { + g.incIndent() + } + { + for _, d := range dynamicSizeTerms { + g.emitNoIndent(" +\n") + g.emit(d) + } + } + if len(dynamicSizeTerms) > 0 { + g.decIndent() + } + }) + g.emit("\n}\n\n") + + g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n") + g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + forEachStructField(st, fieldDispatcher{ + primitive: func(n, t *ast.Ident) { + if n.Name == "_" { + g.emit("// Padding: dst[:sizeof(%s)] ~= %s(0)\n", t.Name, t.Name) + if len, dynamic := g.scalarSize(t); !dynamic { + g.shift("dst", len) + } else { + // We can't use shiftDynamic here because we don't have + // an instance of the dynamic type we can referece here + // (since the version in this struct is anonymous). Use + // a typed nil pointer to call SizeBytes() instead. + g.emit("dst = dst[(*%s)(nil).SizeBytes():]\n", t.Name) + } + return + } + g.marshalScalar(g.fieldAccessor(n), t.Name, "dst") + }, + selector: func(n, tX, tSel *ast.Ident) { + g.marshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "dst") + }, + array: func(n, t *ast.Ident, size int) { + if n.Name == "_" { + g.emit("// Padding: dst[:sizeof(%s)*%d] ~= [%d]%s{0}\n", t.Name, size, size, t.Name) + if len, dynamic := g.scalarSize(t); !dynamic { + g.shift("dst", len*size) + } else { + // We can't use shiftDynamic here because we don't have + // an instance of the dynamic type we can reference here + // (since the version in this struct is anonymous). Use + // a typed nil pointer to call SizeBytes() instead. + g.emit("dst = dst[(*%s)(nil).SizeBytes()*%d:]\n", t.Name, size) + } + return + } + + g.emit("for idx := 0; idx < %d; idx++ {\n", size) + g.inIndent(func() { + g.marshalScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "dst") + }) + g.emit("}\n") + }, + }.dispatch) + }) + g.emit("}\n\n") + + g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n") + g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + forEachStructField(st, fieldDispatcher{ + primitive: func(n, t *ast.Ident) { + if n.Name == "_" { + g.emit("// Padding: var _ %s ~= src[:sizeof(%s)]\n", t.Name, t.Name) + if len, dynamic := g.scalarSize(t); !dynamic { + g.shift("src", len) + } else { + // We can't use shiftDynamic here because we don't have + // an instance of the dynamic type we can reference here + // (since the version in this struct is anonymous). Use + // a typed nil pointer to call SizeBytes() instead. + g.emit("src = src[(*%s)(nil).SizeBytes():]\n", t.Name) + g.recordPotentiallyNonPackedField(fmt.Sprintf("(*%s)(nil)", t.Name)) + } + return + } + g.unmarshalScalar(g.fieldAccessor(n), t.Name, "src") + }, + selector: func(n, tX, tSel *ast.Ident) { + g.unmarshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "src") + }, + array: func(n, t *ast.Ident, size int) { + if n.Name == "_" { + g.emit("// Padding: ~ copy([%d]%s(%s), src[:sizeof(%s)*%d])\n", size, t.Name, g.fieldAccessor(n), t.Name, size) + if len, dynamic := g.scalarSize(t); !dynamic { + g.shift("src", len*size) + } else { + // We can't use shiftDynamic here because we don't have + // an instance of the dynamic type we can referece here + // (since the version in this struct is anonymous). Use + // a typed nil pointer to call SizeBytes() instead. + g.emit("src = src[(*%s)(nil).SizeBytes()*%d:]\n", t.Name, size) + } + return + } + + g.emit("for idx := 0; idx < %d; idx++ {\n", size) + g.inIndent(func() { + g.unmarshalScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "src") + }) + g.emit("}\n") + }, + }.dispatch) + }) + g.emit("}\n\n") + + g.emit("// Packed implements marshal.Marshallable.Packed.\n") + g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName()) + g.inIndent(func() { + expr, fieldsMaybePacked := g.areFieldsPackedExpression() + switch { + case !thisPacked: + g.emit("return false\n") + case fieldsMaybePacked: + g.emit("return %s\n", expr) + default: + g.emit("return true\n") + + } + }) + g.emit("}\n\n") + + g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n") + g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + if thisPacked { + g.recordUsedImport("safecopy") + g.recordUsedImport("unsafe") + if cond, ok := g.areFieldsPackedExpression(); ok { + g.emit("if %s {\n", cond) + g.inIndent(func() { + g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r) + }) + g.emit("} else {\n") + g.inIndent(func() { + g.emit("%s.MarshalBytes(dst)\n", g.r) + }) + g.emit("}\n") + } else { + g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r) + } + } else { + g.emit("// Type %s doesn't have a packed layout in memory, fallback to MarshalBytes.\n", g.typeName()) + g.emit("%s.MarshalBytes(dst)\n", g.r) + } + }) + g.emit("}\n\n") + + g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n") + g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + if thisPacked { + g.recordUsedImport("safecopy") + g.recordUsedImport("unsafe") + if cond, ok := g.areFieldsPackedExpression(); ok { + g.emit("if %s {\n", cond) + g.inIndent(func() { + g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r) + }) + g.emit("} else {\n") + g.inIndent(func() { + g.emit("%s.UnmarshalBytes(src)\n", g.r) + }) + g.emit("}\n") + } else { + g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r) + } + } else { + g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName()) + g.emit("%s.UnmarshalBytes(src)\n", g.r) + } + }) + g.emit("}\n\n") + + g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n") + g.recordUsedImport("marshal") + g.recordUsedImport("usermem") + g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName()) + g.inIndent(func() { + fallback := func() { + g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName()) + g.emit("buf := task.CopyScratchBuffer(%s.SizeBytes())\n", g.r) + g.emit("%s.MarshalBytes(buf)\n", g.r) + g.emit("_, err := task.CopyOutBytes(addr, buf)\n") + g.emit("return err\n") + } + if thisPacked { + g.recordUsedImport("reflect") + g.recordUsedImport("runtime") + g.recordUsedImport("unsafe") + if cond, ok := g.areFieldsPackedExpression(); ok { + g.emit("if !%s {\n", cond) + g.inIndent(fallback) + g.emit("}\n\n") + } + // Fast serialization. + g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) + g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) + g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") + g.emit("ptr := unsafe.Pointer(%s)\n", g.r) + g.emit("val := uintptr(ptr)\n") + g.emit("val = val^0\n\n") + + g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) + g.emit("var buf []byte\n") + g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") + g.emit("hdr.Data = val\n") + g.emit("hdr.Len = %s.SizeBytes()\n", g.r) + g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) + + g.emit("_, err := task.CopyOutBytes(addr, buf)\n") + g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) + g.emit("// must live until after the CopyOutBytes.\n") + g.emit("runtime.KeepAlive(%s)\n", g.r) + g.emit("return err\n") + } else { + fallback() + } + }) + g.emit("}\n\n") + + g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n") + g.recordUsedImport("marshal") + g.recordUsedImport("usermem") + g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName()) + g.inIndent(func() { + fallback := func() { + g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName()) + g.emit("buf := task.CopyScratchBuffer(%s.SizeBytes())\n", g.r) + g.emit("_, err := task.CopyInBytes(addr, buf)\n") + g.emit("if err != nil {\n") + g.inIndent(func() { + g.emit("return err\n") + }) + g.emit("}\n") + + g.emit("%s.UnmarshalBytes(buf)\n", g.r) + g.emit("return nil\n") + } + if thisPacked { + g.recordUsedImport("reflect") + g.recordUsedImport("runtime") + g.recordUsedImport("unsafe") + if cond, ok := g.areFieldsPackedExpression(); ok { + g.emit("if !%s {\n", cond) + g.inIndent(fallback) + g.emit("}\n\n") + } + // Fast deserialization. + g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) + g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) + g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") + g.emit("ptr := unsafe.Pointer(%s)\n", g.r) + g.emit("val := uintptr(ptr)\n") + g.emit("val = val^0\n\n") + + g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) + g.emit("var buf []byte\n") + g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") + g.emit("hdr.Data = val\n") + g.emit("hdr.Len = %s.SizeBytes()\n", g.r) + g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) + + g.emit("_, err := task.CopyInBytes(addr, buf)\n") + g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) + g.emit("// must live until after the CopyInBytes.\n") + g.emit("runtime.KeepAlive(%s)\n", g.r) + g.emit("return err\n") + } else { + fallback() + } + }) + g.emit("}\n\n") + + g.emit("// WriteTo implements io.WriterTo.WriteTo.\n") + g.recordUsedImport("io") + g.emit("func (%s *%s) WriteTo(w io.Writer) (int64, error) {\n", g.r, g.typeName()) + g.inIndent(func() { + fallback := func() { + g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName()) + g.emit("buf := make([]byte, %s.SizeBytes())\n", g.r) + g.emit("%s.MarshalBytes(buf)\n", g.r) + g.emit("n, err := w.Write(buf)\n") + g.emit("return int64(n), err\n") + } + if thisPacked { + g.recordUsedImport("reflect") + g.recordUsedImport("runtime") + g.recordUsedImport("unsafe") + if cond, ok := g.areFieldsPackedExpression(); ok { + g.emit("if !%s {\n", cond) + g.inIndent(fallback) + g.emit("}\n\n") + } + // Fast serialization. + g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) + g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) + g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") + g.emit("ptr := unsafe.Pointer(%s)\n", g.r) + g.emit("val := uintptr(ptr)\n") + g.emit("val = val^0\n\n") + + g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) + g.emit("var buf []byte\n") + g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") + g.emit("hdr.Data = val\n") + g.emit("hdr.Len = %s.SizeBytes()\n", g.r) + g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) + + g.emit("len, err := w.Write(buf)\n") + g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) + g.emit("// must live until after the Write.\n") + g.emit("runtime.KeepAlive(%s)\n", g.r) + g.emit("return int64(len), err\n") + } else { + fallback() + } + }) + g.emit("}\n\n") +} diff --git a/tools/go_marshal/gomarshal/generator_tests.go b/tools/go_marshal/gomarshal/generator_tests.go index bcda17c3b..fd992e44a 100644 --- a/tools/go_marshal/gomarshal/generator_tests.go +++ b/tools/go_marshal/gomarshal/generator_tests.go @@ -22,9 +22,11 @@ import ( ) var standardImports = []string{ + "bytes", "fmt", "reflect", "testing", + "gvisor.dev/gvisor/tools/go_marshal/analysis", } @@ -47,9 +49,6 @@ type testGenerator struct { } func newTestGenerator(t *ast.TypeSpec) *testGenerator { - if _, ok := t.Type.(*ast.StructType); !ok { - panic(fmt.Sprintf("Attempting to generate code for a not struct type %v", t)) - } g := &testGenerator{ t: t, r: receiverName(t), @@ -67,14 +66,6 @@ func (g *testGenerator) typeName() string { return g.t.Name.Name } -func (g *testGenerator) forEachField(fn func(f *ast.Field)) { - // This is guaranteed to succeed because g.t is always a struct. - st := g.t.Type.(*ast.StructType) - for _, field := range st.Fields.List { - fn(field) - } -} - func (g *testGenerator) testFuncName(base string) string { return fmt.Sprintf("%s%s", base, strings.Title(g.t.Name.Name)) } @@ -87,10 +78,10 @@ func (g *testGenerator) inTestFunction(name string, body func()) { func (g *testGenerator) emitTestNonZeroSize() { g.inTestFunction("TestSizeNonZero", func() { - g.emit("x := &%s{}\n", g.typeName()) + g.emit("var x %v\n", g.typeName()) g.emit("if x.SizeBytes() == 0 {\n") g.inIndent(func() { - g.emit("t.Fatal(\"Marshallable.Size() should not return zero\")\n") + g.emit("t.Fatal(\"Marshallable.SizeBytes() should not return zero\")\n") }) g.emit("}\n") }) @@ -98,7 +89,7 @@ func (g *testGenerator) emitTestNonZeroSize() { func (g *testGenerator) emitTestSuspectAlignment() { g.inTestFunction("TestSuspectAlignment", func() { - g.emit("x := %s{}\n", g.typeName()) + g.emit("var x %v\n", g.typeName()) g.emit("analysis.AlignmentCheck(t, reflect.TypeOf(x))\n") }) } @@ -116,26 +107,64 @@ func (g *testGenerator) emitTestMarshalUnmarshalPreservesData() { g.emit("y.UnmarshalBytes(buf)\n") g.emit("if !reflect.DeepEqual(x, y) {\n") g.inIndent(func() { - g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across Marshal/Unmarshal cycle:\\nBefore: %%+v\\nAfter: %%+v\\n\", x, y))\n") + g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalBytes/UnmarshalBytes cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, y))\n") }) g.emit("}\n") g.emit("yUnsafe.UnmarshalBytes(bufUnsafe)\n") g.emit("if !reflect.DeepEqual(x, yUnsafe) {\n") g.inIndent(func() { - g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalUnsafe/Unmarshal cycle:\\nBefore: %%+v\\nAfter: %%+v\\n\", x, yUnsafe))\n") + g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalUnsafe/UnmarshalBytes cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, yUnsafe))\n") }) g.emit("}\n\n") g.emit("z.UnmarshalUnsafe(buf)\n") g.emit("if !reflect.DeepEqual(x, z) {\n") g.inIndent(func() { - g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across Marshal/UnmarshalUnsafe cycle:\\nBefore: %%+v\\nAfter: %%+v\\n\", x, z))\n") + g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalBytes/UnmarshalUnsafe cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, z))\n") }) g.emit("}\n") g.emit("zUnsafe.UnmarshalUnsafe(bufUnsafe)\n") g.emit("if !reflect.DeepEqual(x, zUnsafe) {\n") g.inIndent(func() { - g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalUnsafe/UnmarshalUnsafe cycle:\\nBefore: %%+v\\nAfter: %%+v\\n\", x, zUnsafe))\n") + g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalUnsafe/UnmarshalUnsafe cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, zUnsafe))\n") + }) + g.emit("}\n") + }) +} + +func (g *testGenerator) emitTestWriteToUnmarshalPreservesData() { + g.inTestFunction("TestWriteToUnmarshalPreservesData", func() { + g.emit("var x, y, yUnsafe %s\n", g.typeName()) + g.emit("analysis.RandomizeValue(&x)\n\n") + + g.emit("var buf bytes.Buffer\n\n") + + g.emit("x.WriteTo(&buf)\n") + g.emit("y.UnmarshalBytes(buf.Bytes())\n\n") + g.emit("yUnsafe.UnmarshalUnsafe(buf.Bytes())\n\n") + + g.emit("if !reflect.DeepEqual(x, y) {\n") + g.inIndent(func() { + g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across WriteTo/UnmarshalBytes cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, y))\n") + }) + g.emit("}\n") + g.emit("if !reflect.DeepEqual(x, yUnsafe) {\n") + g.inIndent(func() { + g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across WriteTo/UnmarshalUnsafe cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, yUnsafe))\n") + }) + g.emit("}\n") + }) +} + +func (g *testGenerator) emitTestSizeBytesOnTypedNilPtr() { + g.inTestFunction("TestSizeBytesOnTypedNilPtr", func() { + g.emit("var x %s\n", g.typeName()) + g.emit("sizeFromConcrete := x.SizeBytes()\n") + g.emit("sizeFromTypedNilPtr := (*%s)(nil).SizeBytes()\n\n", g.typeName()) + + g.emit("if sizeFromTypedNilPtr != sizeFromConcrete {\n") + g.inIndent(func() { + g.emit("t.Fatalf(\"SizeBytes() on typed nil pointer (%v) doesn't match size returned by a concrete object (%v).\\n\", sizeFromTypedNilPtr, sizeFromConcrete)\n") }) g.emit("}\n") }) @@ -145,6 +174,8 @@ func (g *testGenerator) emitTests() { g.emitTestNonZeroSize() g.emitTestSuspectAlignment() g.emitTestMarshalUnmarshalPreservesData() + g.emitTestWriteToUnmarshalPreservesData() + g.emitTestSizeBytesOnTypedNilPtr() } func (g *testGenerator) write(out io.Writer) error { diff --git a/tools/go_marshal/gomarshal/util.go b/tools/go_marshal/gomarshal/util.go index 967537abf..a0936e013 100644 --- a/tools/go_marshal/gomarshal/util.go +++ b/tools/go_marshal/gomarshal/util.go @@ -64,6 +64,12 @@ func kindString(e ast.Expr) string { } } +func forEachStructField(st *ast.StructType, fn func(f *ast.Field)) { + for _, field := range st.Fields.List { + fn(field) + } +} + // fieldDispatcher is a collection of callbacks for handling different types of // fields in a struct declaration. type fieldDispatcher struct { @@ -73,6 +79,25 @@ type fieldDispatcher struct { unhandled func(n *ast.Ident) } +// Precondition: a must have a literal for the array length. Consts and +// expressions are not allowed as array lengths, and should be rejected by the +// caller. +func arrayLen(a *ast.ArrayType) int { + if a.Len == nil { + // Probably a slice? Must be handled by caller. + panic("Nil array length in array type") + } + lenLit, ok := a.Len.(*ast.BasicLit) + if !ok { + panic("Array has non-literal for length") + } + len, err := strconv.Atoi(lenLit.Value) + if err != nil { + panic(fmt.Sprintf("Failed to parse array length '%s' as number: %v", lenLit.Value, err)) + } + return len +} + // Precondition: All dispatch callbacks that will be invoked must be // provided. Embedded fields are not allowed, len(f.Names) >= 1. func (fd fieldDispatcher) dispatch(f *ast.Field) { @@ -96,22 +121,12 @@ func (fd fieldDispatcher) dispatch(f *ast.Field) { case *ast.SelectorExpr: fd.selector(name, v.X.(*ast.Ident), v.Sel) case *ast.ArrayType: - len := 0 - if v.Len != nil { - // Non-literal array length is handled by generatorInterfaces.validate(). - if lenLit, ok := v.Len.(*ast.BasicLit); ok { - var err error - len, err = strconv.Atoi(lenLit.Value) - if err != nil { - panic(err) - } - } - } switch t := v.Elt.(type) { case *ast.Ident: - fd.array(name, t, len) + fd.array(name, t, arrayLen(v)) default: - fd.array(name, nil, len) + // Should be handled with a better error message during validate. + panic(fmt.Sprintf("Array element type is of unsupported kind. Expected *ast.Ident, got %v", t)) } default: fd.unhandled(name) @@ -219,6 +234,11 @@ type sourceBuffer struct { b bytes.Buffer } +func (b *sourceBuffer) reset() { + b.indent = 0 + b.b.Reset() +} + func (b *sourceBuffer) incIndent() { b.indent++ } @@ -305,7 +325,7 @@ func (i *importStmt) markUsed() { } func (i *importStmt) equivalent(other *importStmt) bool { - return i == other + return i.name == other.name && i.path == other.path && i.aliased == other.aliased } // importTable represents a collection of importStmts. @@ -324,7 +344,7 @@ func newImportTable() *importTable { // result in a panic. func (i *importTable) merge(other *importTable) { for name, im := range other.is { - if dup, ok := i.is[name]; ok && dup.equivalent(im) { + if dup, ok := i.is[name]; ok && !dup.equivalent(im) { panic(fmt.Sprintf("Found colliding import statements: ours: %+v, other's: %+v", dup, im)) } @@ -332,16 +352,27 @@ func (i *importTable) merge(other *importTable) { } } +func (i *importTable) addStmt(s *importStmt) *importStmt { + if old, ok := i.is[s.name]; ok && !old.equivalent(s) { + // A collision should always be between an import inserted by the + // go-marshal tool and an import from the original source file (assuming + // the original source file was valid). We could theoretically handle + // the collision by assigning a local name to our import. However, this + // would need to be plumbed throughout the generator. Given that + // collisions should be rare, simply panic on collision. + panic(fmt.Sprintf("Import collision: old: %s as %v; new: %v as %v", old.path, old.name, s.path, s.name)) + } + i.is[s.name] = s + return s +} + func (i *importTable) add(s string) *importStmt { n := newImport(s) - i.is[n.name] = n - return n + return i.addStmt(n) } func (i *importTable) addFromSpec(spec *ast.ImportSpec, f *token.FileSet) *importStmt { - n := newImportFromSpec(spec, f) - i.is[n.name] = n - return n + return i.addStmt(newImportFromSpec(spec, f)) } // Marks the import named n as used. If no such import is in the table, returns |