diff options
author | Rahat Mahmood <rahat@google.com> | 2019-09-09 13:35:30 -0700 |
---|---|---|
committer | gVisor bot <gvisor-bot@google.com> | 2019-09-09 13:36:39 -0700 |
commit | 3733b9b893ec33877b1b46c56fe07c3856942d3f (patch) | |
tree | 4b9d6c2a46ec11f5de3095980206495c7b53a013 /tools/go_marshal/gomarshal | |
parent | 6af9a9850aff75e15c6f9ab577af5b818531d6ee (diff) |
go_marshal: Implement automatic generation of ABI marshalling code.
This CL implements go_marshal, a code generation utility for
automatically serializing and deserializing ABI structs.
The go_marshal tool automatically generates implementations of the new
marshal interface. Unlike binary.Marshal/Unmarshal, the generated
interface implementations use no runtime reflection and translate to
a single memcpy for most structs. See go_marshal/README.md for
details.
PiperOrigin-RevId: 268065475
Diffstat (limited to 'tools/go_marshal/gomarshal')
-rw-r--r-- | tools/go_marshal/gomarshal/BUILD | 17 | ||||
-rw-r--r-- | tools/go_marshal/gomarshal/generator.go | 382 | ||||
-rw-r--r-- | tools/go_marshal/gomarshal/generator_interfaces.go | 507 | ||||
-rw-r--r-- | tools/go_marshal/gomarshal/generator_tests.go | 154 | ||||
-rw-r--r-- | tools/go_marshal/gomarshal/util.go | 387 |
5 files changed, 1447 insertions, 0 deletions
diff --git a/tools/go_marshal/gomarshal/BUILD b/tools/go_marshal/gomarshal/BUILD new file mode 100644 index 000000000..a0eae6492 --- /dev/null +++ b/tools/go_marshal/gomarshal/BUILD @@ -0,0 +1,17 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "gomarshal", + srcs = [ + "generator.go", + "generator_interfaces.go", + "generator_tests.go", + "util.go", + ], + importpath = "gvisor.dev/gvisor/tools/go_marshal/gomarshal", + visibility = [ + "//:sandbox", + ], +) diff --git a/tools/go_marshal/gomarshal/generator.go b/tools/go_marshal/gomarshal/generator.go new file mode 100644 index 000000000..641ccd938 --- /dev/null +++ b/tools/go_marshal/gomarshal/generator.go @@ -0,0 +1,382 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package gomarshal implements the go_marshal code generator. See README.md. +package gomarshal + +import ( + "bytes" + "fmt" + "go/ast" + "go/parser" + "go/token" + "os" + "sort" +) + +const ( + marshalImport = "gvisor.dev/gvisor/tools/go_marshal/marshal" + usermemImport = "gvisor.dev/gvisor/pkg/sentry/usermem" + safecopyImport = "gvisor.dev/gvisor/pkg/sentry/platform/safecopy" +) + +// List of identifiers we use in generated code, that may conflict a +// similarly-named source identifier. Avoid problems by refusing the generate +// code when we see these. +// +// This only applies to import aliases at the moment. 
All other identifiers +// are qualified by a receiver argument, since they're struct fields. +// +// All recievers are single letters, so we don't allow import aliases to be a +// single letter. +var badIdents = []string{ + "src", "srcs", "dst", "dsts", "blk", "buf", "err", + // All single-letter identifiers. +} + +// Generator drives code generation for a single invocation of the go_marshal +// utility. +// +// The Generator holds arguments passed to the tool, and drives parsing, +// processing and code Generator for all types marked with +marshal declared in +// the input files. +// +// See Generator.run() as the entry point. +type Generator struct { + // Paths to input go source files. + inputs []string + // Output file to write generated go source. + output *os.File + // Output file to write generated tests. + outputTest *os.File + // Package name for the generated file. + pkg string + // Go import path for package we're processing. This package should directly + // declare the type we're generating code for. + declaration string + // Set of extra packages to import in the generated file. + imports *importTable +} + +// NewGenerator creates a new code Generator. +func NewGenerator(srcs []string, out, outTest, pkg, declaration string, imports []string) (*Generator, error) { + f, err := os.OpenFile(out, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) + if err != nil { + return nil, fmt.Errorf("Couldn't open output file %q: %v", out, err) + } + fTest, err := os.OpenFile(outTest, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) + if err != nil { + return nil, fmt.Errorf("Couldn't open test output file %q: %v", out, err) + } + g := Generator{ + inputs: srcs, + output: f, + outputTest: fTest, + pkg: pkg, + declaration: declaration, + imports: newImportTable(), + } + for _, i := range imports { + // All imports on the extra imports list are unconditionally marked as + // used, so they're always added to the generated code. 
+ g.imports.add(i).markUsed() + } + g.imports.add(marshalImport).markUsed() + // The follow imports may or may not be used by the generated + // code, depending what's required for the target types. Don't + // mark these imports as used by default. + g.imports.add(usermemImport) + g.imports.add(safecopyImport) + g.imports.add("unsafe") + + return &g, nil +} + +// writeHeader writes the header for the generated source file. The header +// includes the package name, package level comments and import statements. +func (g *Generator) writeHeader() error { + var b sourceBuffer + b.emit("// Automatically generated marshal implementation. See tools/go_marshal.\n\n") + b.emit("package %s\n\n", g.pkg) + if err := b.write(g.output); err != nil { + return err + } + + return g.imports.write(g.output) +} + +// writeTypeChecks writes a statement to force the compiler to perform a type +// check for all Marshallable types referenced by the generated code. +func (g *Generator) writeTypeChecks(ms map[string]struct{}) error { + if len(ms) == 0 { + return nil + } + + msl := make([]string, 0, len(ms)) + for m, _ := range ms { + msl = append(msl, m) + } + sort.Strings(msl) + + var buf bytes.Buffer + fmt.Fprint(&buf, "// Marshallable types used by this file.\n") + + for _, m := range msl { + fmt.Fprintf(&buf, "var _ marshal.Marshallable = (*%s)(nil)\n", m) + } + fmt.Fprint(&buf, "\n") + + _, err := fmt.Fprint(g.output, buf.String()) + return err +} + +// parse processes all input files passed this generator and produces a set of +// parsed go ASTs. 
+func (g *Generator) parse() ([]*ast.File, []*token.FileSet, error) { + debugf("go_marshal invoked with %d input files:\n", len(g.inputs)) + for _, path := range g.inputs { + debugf(" %s\n", path) + } + + files := make([]*ast.File, 0, len(g.inputs)) + fsets := make([]*token.FileSet, 0, len(g.inputs)) + + for _, path := range g.inputs { + fset := token.NewFileSet() + f, err := parser.ParseFile(fset, path, nil, parser.ParseComments) + if err != nil { + // Not a valid input file? + return nil, nil, fmt.Errorf("Input %q can't be parsed: %v", path, err) + } + + if debugEnabled() { + debugf("AST for %q:\n", path) + ast.Print(fset, f) + } + + files = append(files, f) + fsets = append(fsets, fset) + } + + return files, fsets, nil +} + +// collectMarshallabeTypes walks the parsed AST and collects a list of type +// declarations for which we need to generate the Marshallable interface. +func (g *Generator) collectMarshallabeTypes(a *ast.File, f *token.FileSet) []*ast.TypeSpec { + var types []*ast.TypeSpec + for _, decl := range a.Decls { + gdecl, ok := decl.(*ast.GenDecl) + // Type declaration? + if !ok || gdecl.Tok != token.TYPE { + debugfAt(f.Position(decl.Pos()), "Skipping declaration since it's not a type declaration.\n") + continue + } + // Does it have a comment? + if gdecl.Doc == nil { + debugfAt(f.Position(gdecl.Pos()), "Skipping declaration since it doesn't have a comment.\n") + continue + } + // Does the comment contain a "+marshal" line? + marked := false + for _, c := range gdecl.Doc.List { + if c.Text == "// +marshal" { + marked = true + break + } + } + if !marked { + debugfAt(f.Position(gdecl.Pos()), "Skipping declaration since it doesn't have a comment containing +marshal line.\n") + continue + } + for _, spec := range gdecl.Specs { + // We already confirmed we're in a type declaration earlier. 
+ t := spec.(*ast.TypeSpec) + if _, ok := t.Type.(*ast.StructType); ok { + debugfAt(f.Position(t.Pos()), "Collected marshallable type %s.\n", t.Name.Name) + types = append(types, t) + continue + } + debugf("Skipping declaration %v since it's not a struct declaration.\n", gdecl) + } + } + return types +} + +// collectImports collects all imports from all input source files. Some of +// these imports are copied to the generated output, if they're referenced by +// the generated code. +// +// collectImports de-duplicates imports while building the list, and ensures +// identifiers in the generated code don't conflict with any imported package +// names. +func (g *Generator) collectImports(a *ast.File, f *token.FileSet) map[string]importStmt { + badImportNames := make(map[string]bool) + for _, i := range badIdents { + badImportNames[i] = true + } + + is := make(map[string]importStmt) + for _, decl := range a.Decls { + gdecl, ok := decl.(*ast.GenDecl) + // Import statement? + if !ok || gdecl.Tok != token.IMPORT { + continue + } + for _, spec := range gdecl.Specs { + i := g.imports.addFromSpec(spec.(*ast.ImportSpec), f) + debugf("Collected import '%s' as '%s'\n", i.path, i.name) + + // Make sure we have an import that doesn't use any local names that + // would conflict with identifiers in the generated code. + if len(i.name) == 1 { + abortAt(f.Position(spec.Pos()), fmt.Sprintf("Import has a single character local name '%s'; this may conflict with code generated by go_marshal, use a multi-character import alias", i.name)) + } + if badImportNames[i.name] { + abortAt(f.Position(spec.Pos()), fmt.Sprintf("Import name '%s' is likely to conflict with code generated by go_marshal, use a different import alias", i.name)) + } + } + } + return is + +} + +func (g *Generator) generateOne(t *ast.TypeSpec, fset *token.FileSet) *interfaceGenerator { + // We're guaranteed to have only struct type specs by now. See + // Generator.collectMarshallabeTypes. 
+ i := newInterfaceGenerator(t, fset) + i.validate() + i.emitMarshallable() + return i +} + +// generateOneTestSuite generates a test suite for the automatically generated +// implementations type t. +func (g *Generator) generateOneTestSuite(t *ast.TypeSpec) *testGenerator { + i := newTestGenerator(t, g.declaration) + i.emitTests() + return i +} + +// Run is the entry point to code generation using g. +// +// Run parses all input source files specified in g and emits generated code. +func (g *Generator) Run() error { + // Parse our input source files into ASTs and token sets. + asts, fsets, err := g.parse() + if err != nil { + return err + } + + if len(asts) != len(fsets) { + panic("ASTs and FileSets don't match") + } + + // Map of imports in source files; key = local package name, value = import + // path. + is := make(map[string]importStmt) + for i, a := range asts { + // Collect all imports from the source files. We may need to copy some + // of these to the generated code if they're referenced. This has to be + // done before the loop below because we need to process all ASTs before + // we start requesting imports to be copied one by one as we encounter + // them in each generated source. + for name, i := range g.collectImports(a, fsets[i]) { + is[name] = i + } + } + + var impls []*interfaceGenerator + var ts []*testGenerator + // Set of Marshallable types referenced by generated code. + ms := make(map[string]struct{}) + for i, a := range asts { + // Collect type declarations marked for code generation and generate + // Marshallable interfaces. + for _, t := range g.collectMarshallabeTypes(a, fsets[i]) { + impl := g.generateOne(t, fsets[i]) + // Collect Marshallable types referenced by the generated code. + for ref, _ := range impl.ms { + ms[ref] = struct{}{} + } + impls = append(impls, impl) + // Collect imports referenced by the generated code and add them to + // the list of imports we need to copy to the generated code. 
+ for name, _ := range impl.is { + if !g.imports.markUsed(name) { + panic(fmt.Sprintf("Generated code for '%s' referenced a non-existent import with local name '%s'", impl.typeName(), name)) + } + } + ts = append(ts, g.generateOneTestSuite(t)) + } + } + + // Tool was invoked with input files with no data structures marked for code + // generation. This is probably not what the user intended. + if len(impls) == 0 { + var buf bytes.Buffer + fmt.Fprintf(&buf, "go_marshal invoked on these files, but they don't contain any types requiring code generation. Perhaps mark some with \"// +marshal\"?:\n") + for _, i := range g.inputs { + fmt.Fprintf(&buf, " %s\n", i) + } + abort(buf.String()) + } + + // Write output file header. These include things like package name and + // import statements. + if err := g.writeHeader(); err != nil { + return err + } + + // Write type checks for referenced marshallable types to output file. + if err := g.writeTypeChecks(ms); err != nil { + return err + } + + // Write generated interfaces to output file. + for _, i := range impls { + if err := i.write(g.output); err != nil { + return err + } + } + + // Write generated tests to test file. + return g.writeTests(ts) +} + +// writeTests outputs tests for the generated interface implementations to a go +// source file. 
+func (g *Generator) writeTests(ts []*testGenerator) error { + var b sourceBuffer + b.emit("package %s_test\n\n", g.pkg) + if err := b.write(g.outputTest); err != nil { + return err + } + + imports := newImportTable() + for _, t := range ts { + imports.merge(t.imports) + } + + if err := imports.write(g.outputTest); err != nil { + return err + } + + for _, t := range ts { + if err := t.write(g.outputTest); err != nil { + return err + } + } + return nil +} diff --git a/tools/go_marshal/gomarshal/generator_interfaces.go b/tools/go_marshal/gomarshal/generator_interfaces.go new file mode 100644 index 000000000..a712c14dc --- /dev/null +++ b/tools/go_marshal/gomarshal/generator_interfaces.go @@ -0,0 +1,507 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package gomarshal + +import ( + "fmt" + "go/ast" + "go/token" + "strings" +) + +// interfaceGenerator generates marshalling interfaces for a single type. +// +// getState is not thread-safe. +type interfaceGenerator struct { + sourceBuffer + + // The type we're serializing. + t *ast.TypeSpec + + // Receiver argument for generated methods. + r string + + // FileSet containing the tokens for the type we're processing. + f *token.FileSet + + // is records external packages referenced by the generated implementation. + is map[string]struct{} + + // ms records Marshallable types referenced by the generated implementation + // of t's interfaces. 
+ ms map[string]struct{} + + // as records embedded fields in t that are potentially not packed. The key + // is the accessor for the field. + as map[string]struct{} +} + +// typeName returns the name of the type this g represents. +func (g *interfaceGenerator) typeName() string { + return g.t.Name.Name +} + +// newinterfaceGenerator creates a new interface generator. +func newInterfaceGenerator(t *ast.TypeSpec, fset *token.FileSet) *interfaceGenerator { + if _, ok := t.Type.(*ast.StructType); !ok { + panic(fmt.Sprintf("Attempting to generate code for a not struct type %v", t)) + } + g := &interfaceGenerator{ + t: t, + r: receiverName(t), + f: fset, + is: make(map[string]struct{}), + ms: make(map[string]struct{}), + as: make(map[string]struct{}), + } + g.recordUsedMarshallable(g.typeName()) + return g +} + +func (g *interfaceGenerator) recordUsedMarshallable(m string) { + g.ms[m] = struct{}{} + +} + +func (g *interfaceGenerator) recordUsedImport(i string) { + g.is[i] = struct{}{} + +} + +func (g *interfaceGenerator) recordPotentiallyNonPackedField(fieldName string) { + g.as[fieldName] = struct{}{} +} + +func (g *interfaceGenerator) forEachField(fn func(f *ast.Field)) { + // This is guaranteed to succeed because g.t is always a struct. + st := g.t.Type.(*ast.StructType) + for _, field := range st.Fields.List { + fn(field) + } +} + +func (g *interfaceGenerator) fieldAccessor(n *ast.Ident) string { + return fmt.Sprintf("%s.%s", g.r, n.Name) +} + +// abortAt aborts the go_marshal tool with the given error message, with a +// reference position to the input source. Same as abortAt, but uses g to +// resolve p to position. +func (g *interfaceGenerator) abortAt(p token.Pos, msg string) { + abortAt(g.f.Position(p), msg) +} + +// validate ensures the type we're working with can be marshalled. These checks +// are done ahead of time and in one place so we can make assumptions later. 
+func (g *interfaceGenerator) validate() { + g.forEachField(func(f *ast.Field) { + if len(f.Names) == 0 { + g.abortAt(f.Pos(), "Cannot marshal structs with embedded fields, give the field a name; use '_' for anonymous fields such as padding fields") + } + }) + + g.forEachField(func(f *ast.Field) { + fieldDispatcher{ + primitive: func(_, t *ast.Ident) { + switch t.Name { + case "int8", "uint8", "byte", "int16", "uint16", "int32", "uint32", "int64", "uint64": + // These are the only primitive types we're allow. Below, we + // provide suggestions for some disallowed types and reject + // them, then attempt to marshal any remaining types by + // invoking the marshal.Marshallable interface on them. If + // these types don't actually implement + // marshal.Marshallable, compilation of the generated code + // will fail with an appropriate error message. + return + case "int": + g.abortAt(f.Pos(), "Type 'int' has ambiguous width, use int32 or int64") + case "uint": + g.abortAt(f.Pos(), "Type 'uint' has ambiguous width, use uint32 or uint64") + case "string": + g.abortAt(f.Pos(), "Type 'string' is dynamically-sized and cannot be marshalled, use a fixed size byte array '[...]byte' instead") + default: + debugfAt(g.f.Position(f.Pos()), fmt.Sprintf("Found derived type '%s', will attempt dispatch via marshal.Marshallable.\n", t.Name)) + } + }, + selector: func(_, _, _ *ast.Ident) { + // No validation to perform on selector fields. However this + // callback must still be provided. 
+ }, + array: func(n, _ *ast.Ident, len int) { + a := f.Type.(*ast.ArrayType) + if a.Len == nil { + g.abortAt(f.Pos(), fmt.Sprintf("Dynamically sized slice '%s' cannot be marshalled, arrays must be statically sized", n.Name)) + } + + if _, ok := a.Len.(*ast.BasicLit); !ok { + g.abortAt(a.Len.Pos(), fmt.Sprintf("Array size must be a literal, don's use consts or expressions")) + } + + if _, ok := a.Elt.(*ast.Ident); !ok { + g.abortAt(a.Elt.Pos(), fmt.Sprintf("Marshalling not supported for arrays with %s elements, array elements must be primitive types", kindString(a.Elt))) + } + + if len <= 0 { + g.abortAt(a.Len.Pos(), fmt.Sprintf("Marshalling not supported for zero length arrays, why does an ABI struct have one?")) + } + }, + unhandled: func(_ *ast.Ident) { + g.abortAt(f.Pos(), fmt.Sprintf("Marshalling not supported for %s fields", kindString(f.Type))) + }, + }.dispatch(f) + }) +} + +// scalarSize returns the size of type identified by t. If t isn't a primitive +// type, the size isn't known at code generation time, and must be resolved via +// the marshal.Marshallable interface. 
+func (g *interfaceGenerator) scalarSize(t *ast.Ident) (size int, unknownSize bool) { + switch t.Name { + case "int8", "uint8", "byte": + return 1, false + case "int16", "uint16": + return 2, false + case "int32", "uint32": + return 4, false + case "int64", "uint64": + return 8, false + default: + return 0, true + } +} + +func (g *interfaceGenerator) shift(bufVar string, n int) { + g.emit("%s = %s[%d:]\n", bufVar, bufVar, n) +} + +func (g *interfaceGenerator) shiftDynamic(bufVar, name string) { + g.emit("%s = %s[%s.SizeBytes():]\n", bufVar, bufVar, name) +} + +func (g *interfaceGenerator) marshalScalar(accessor, typ string, bufVar string) { + switch typ { + case "int8", "uint8", "byte": + g.emit("%s[0] = byte(%s)\n", bufVar, accessor) + g.shift(bufVar, 1) + case "int16", "uint16": + g.recordUsedImport("usermem") + g.emit("usermem.ByteOrder.PutUint16(%s[:2], uint16(%s))\n", bufVar, accessor) + g.shift(bufVar, 2) + case "int32", "uint32": + g.recordUsedImport("usermem") + g.emit("usermem.ByteOrder.PutUint32(%s[:4], uint32(%s))\n", bufVar, accessor) + g.shift(bufVar, 4) + case "int64", "uint64": + g.recordUsedImport("usermem") + g.emit("usermem.ByteOrder.PutUint64(%s[:8], uint64(%s))\n", bufVar, accessor) + g.shift(bufVar, 8) + default: + g.emit("%s.MarshalBytes(%s[:%s.SizeBytes()])\n", accessor, bufVar, accessor) + g.shiftDynamic(bufVar, accessor) + } +} + +func (g *interfaceGenerator) unmarshalScalar(accessor, typ string, bufVar string) { + switch typ { + case "int8": + g.emit("%s = int8(%s[0])\n", accessor, bufVar) + g.shift(bufVar, 1) + case "uint8": + g.emit("%s = uint8(%s[0])\n", accessor, bufVar) + g.shift(bufVar, 1) + case "byte": + g.emit("%s = %s[0]\n", accessor, bufVar) + g.shift(bufVar, 1) + + case "int16": + g.recordUsedImport("usermem") + g.emit("%s = int16(usermem.ByteOrder.Uint16(%s[:2]))\n", accessor, bufVar) + g.shift(bufVar, 2) + case "uint16": + g.recordUsedImport("usermem") + g.emit("%s = usermem.ByteOrder.Uint16(%s[:2])\n", accessor, bufVar) + 
g.shift(bufVar, 2) + + case "int32": + g.recordUsedImport("usermem") + g.emit("%s = int32(usermem.ByteOrder.Uint32(%s[:4]))\n", accessor, bufVar) + g.shift(bufVar, 4) + case "uint32": + g.recordUsedImport("usermem") + g.emit("%s = usermem.ByteOrder.Uint32(%s[:4])\n", accessor, bufVar) + g.shift(bufVar, 4) + + case "int64": + g.recordUsedImport("usermem") + g.emit("%s = int64(usermem.ByteOrder.Uint64(%s[:8]))\n", accessor, bufVar) + g.shift(bufVar, 8) + case "uint64": + g.recordUsedImport("usermem") + g.emit("%s = usermem.ByteOrder.Uint64(%s[:8])\n", accessor, bufVar) + g.shift(bufVar, 8) + default: + g.emit("%s.UnmarshalBytes(%s[:%s.SizeBytes()])\n", accessor, bufVar, accessor) + g.shiftDynamic(bufVar, accessor) + g.recordPotentiallyNonPackedField(accessor) + } +} + +// areFieldsPackedExpression returns a go expression checking whether g.t's fields are +// packed. Returns "", false if g.t has no fields that may be potentially +// packed, otherwise returns <clause>, true, where <clause> is an expression +// like "t.a.Packed() && t.b.Packed() && t.c.Packed()". +func (g *interfaceGenerator) areFieldsPackedExpression() (string, bool) { + if len(g.as) == 0 { + return "", false + } + + cs := make([]string, 0, len(g.as)) + for accessor, _ := range g.as { + cs = append(cs, fmt.Sprintf("%s.Packed()", accessor)) + } + return strings.Join(cs, " && "), true +} + +func (g *interfaceGenerator) emitMarshallable() { + // Is g.t a packed struct without consideing field types? 
+ thisPacked := true + g.forEachField(func(f *ast.Field) { + if f.Tag != nil { + if f.Tag.Value == "`marshal:\"unaligned\"`" { + if thisPacked { + debugfAt(g.f.Position(g.t.Pos()), + fmt.Sprintf("Marking type '%s' as not packed due to tag `marshal:\"unaligned\"`.\n", g.t.Name)) + thisPacked = false + } + } + } + }) + + g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n") + g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName()) + g.inIndent(func() { + primitiveSize := 0 + var dynamicSizeTerms []string + + g.forEachField(fieldDispatcher{ + primitive: func(n, t *ast.Ident) { + if size, dynamic := g.scalarSize(t); !dynamic { + primitiveSize += size + } else { + g.recordUsedMarshallable(t.Name) + dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("%s.SizeBytes()", g.fieldAccessor(n))) + } + }, + selector: func(n, tX, tSel *ast.Ident) { + tName := fmt.Sprintf("%s.%s", tX.Name, tSel.Name) + g.recordUsedImport(tX.Name) + g.recordUsedMarshallable(tName) + dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", tName)) + }, + array: func(n, t *ast.Ident, len int) { + if len < 1 { + // Zero-length arrays should've been rejected by validate(). 
+ panic("unreachable") + } + if size, dynamic := g.scalarSize(t); !dynamic { + primitiveSize += size * len + } else { + g.recordUsedMarshallable(t.Name) + dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()*%d", t.Name, len)) + } + }, + }.dispatch) + g.emit("return %d", primitiveSize) + if len(dynamicSizeTerms) > 0 { + g.incIndent() + } + { + for _, d := range dynamicSizeTerms { + g.emitNoIndent(" +\n") + g.emit(d) + } + } + if len(dynamicSizeTerms) > 0 { + g.decIndent() + } + }) + g.emit("\n}\n\n") + + g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n") + g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.forEachField(fieldDispatcher{ + primitive: func(n, t *ast.Ident) { + if n.Name == "_" { + g.emit("// Padding: dst[:sizeof(%s)] ~= %s(0)\n", t.Name, t.Name) + if len, dynamic := g.scalarSize(t); !dynamic { + g.shift("dst", len) + } else { + // We can't use shiftDynamic here because we don't have + // an instance of the dynamic type we can referece here + // (since the version in this struct is anonymous). Use + // a typed nil pointer to call SizeBytes() instead. + g.emit("dst = dst[(*%s)(nil).SizeBytes():]\n", t.Name) + } + return + } + g.marshalScalar(g.fieldAccessor(n), t.Name, "dst") + }, + selector: func(n, tX, tSel *ast.Ident) { + g.marshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "dst") + }, + array: func(n, t *ast.Ident, size int) { + if n.Name == "_" { + g.emit("// Padding: dst[:sizeof(%s)*%d] ~= [%d]%s{0}\n", t.Name, size, size, t.Name) + if len, dynamic := g.scalarSize(t); !dynamic { + g.shift("dst", len*size) + } else { + // We can't use shiftDynamic here because we don't have + // an instance of the dynamic type we can reference here + // (since the version in this struct is anonymous). Use + // a typed nil pointer to call SizeBytes() instead. 
+ g.emit("dst = dst[(*%s)(nil).SizeBytes()*%d:]\n", t.Name, size) + } + return + } + + g.emit("for i := 0; i < %d; i++ {\n", size) + g.inIndent(func() { + g.marshalScalar(fmt.Sprintf("%s[i]", g.fieldAccessor(n)), t.Name, "dst") + }) + g.emit("}\n") + }, + }.dispatch) + }) + g.emit("}\n\n") + + g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n") + g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.forEachField(fieldDispatcher{ + primitive: func(n, t *ast.Ident) { + if n.Name == "_" { + g.emit("// Padding: var _ %s ~= src[:sizeof(%s)]\n", t.Name, t.Name) + if len, dynamic := g.scalarSize(t); !dynamic { + g.shift("src", len) + } else { + // We can't use shiftDynamic here because we don't have + // an instance of the dynamic type we can reference here + // (since the version in this struct is anonymous). Use + // a typed nil pointer to call SizeBytes() instead. + g.emit("src = src[(*%s)(nil).SizeBytes():]\n", t.Name) + g.recordPotentiallyNonPackedField(fmt.Sprintf("(*%s)(nil)", t.Name)) + } + return + } + g.unmarshalScalar(g.fieldAccessor(n), t.Name, "src") + }, + selector: func(n, tX, tSel *ast.Ident) { + g.unmarshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "src") + }, + array: func(n, t *ast.Ident, size int) { + if n.Name == "_" { + g.emit("// Padding: ~ copy([%d]%s(%s), src[:sizeof(%s)*%d])\n", size, t.Name, g.fieldAccessor(n), t.Name, size) + if len, dynamic := g.scalarSize(t); !dynamic { + g.shift("src", len*size) + } else { + // We can't use shiftDynamic here because we don't have + // an instance of the dynamic type we can referece here + // (since the version in this struct is anonymous). Use + // a typed nil pointer to call SizeBytes() instead. 
+ g.emit("src = src[(*%s)(nil).SizeBytes()*%d:]\n", t.Name, size) + } + return + } + + g.emit("for i := 0; i < %d; i++ {\n", size) + g.inIndent(func() { + g.unmarshalScalar(fmt.Sprintf("%s[i]", g.fieldAccessor(n)), t.Name, "src") + }) + g.emit("}\n") + }, + }.dispatch) + }) + g.emit("}\n\n") + + g.emit("// Packed implements marshal.Marshallable.Packed.\n") + g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName()) + g.inIndent(func() { + expr, fieldsMaybePacked := g.areFieldsPackedExpression() + switch { + case !thisPacked: + g.emit("return false\n") + case fieldsMaybePacked: + g.emit("return %s\n", expr) + default: + g.emit("return true\n") + + } + }) + g.emit("}\n\n") + + g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n") + g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + if thisPacked { + g.recordUsedImport("safecopy") + g.recordUsedImport("unsafe") + if cond, ok := g.areFieldsPackedExpression(); ok { + g.emit("if %s {\n", cond) + g.inIndent(func() { + g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r) + }) + g.emit("} else {\n") + g.inIndent(func() { + g.emit("%s.MarshalBytes(dst)\n", g.r) + }) + g.emit("}\n") + } else { + g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r) + } + } else { + g.emit("// Type %s doesn't have a packed layout in memory, fallback to MarshalBytes.\n", g.typeName()) + g.emit("%s.MarshalBytes(dst)\n", g.r) + } + }) + g.emit("}\n\n") + + g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n") + g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + if thisPacked { + g.recordUsedImport("safecopy") + g.recordUsedImport("unsafe") + if cond, ok := g.areFieldsPackedExpression(); ok { + g.emit("if %s {\n", cond) + g.inIndent(func() { + g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r) + }) + g.emit("} else {\n") + g.inIndent(func() { + g.emit("%s.UnmarshalBytes(src)\n", g.r) + }) 
+ g.emit("}\n") + } else { + g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r) + } + } else { + g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName()) + g.emit("%s.UnmarshalBytes(src)\n", g.r) + } + }) + g.emit("}\n\n") + +} diff --git a/tools/go_marshal/gomarshal/generator_tests.go b/tools/go_marshal/gomarshal/generator_tests.go new file mode 100644 index 000000000..df25cb5b2 --- /dev/null +++ b/tools/go_marshal/gomarshal/generator_tests.go @@ -0,0 +1,154 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package gomarshal + +import ( + "fmt" + "go/ast" + "io" + "strings" +) + +var standardImports = []string{ + "fmt", + "reflect", + "testing", + "gvisor.dev/gvisor/tools/go_marshal/analysis", +} + +type testGenerator struct { + sourceBuffer + + // The type we're serializing. + t *ast.TypeSpec + + // Receiver argument for generated methods. + r string + + // Imports used by generated code. + imports *importTable + + // Import statement for the package declaring the type we generated code + // for. We need this to construct test instances for the type, since the + // tests aren't written in the same package. 
+ decl *importStmt +} + +func newTestGenerator(t *ast.TypeSpec, declaration string) *testGenerator { + if _, ok := t.Type.(*ast.StructType); !ok { + panic(fmt.Sprintf("Attempting to generate code for a not struct type %v", t)) + } + g := &testGenerator{ + t: t, + r: receiverName(t), + imports: newImportTable(), + } + + for _, i := range standardImports { + g.imports.add(i).markUsed() + } + g.decl = g.imports.add(declaration) + g.decl.markUsed() + + return g +} + +func (g *testGenerator) typeName() string { + return fmt.Sprintf("%s.%s", g.decl.name, g.t.Name.Name) +} + +func (g *testGenerator) forEachField(fn func(f *ast.Field)) { + // This is guaranteed to succeed because g.t is always a struct. + st := g.t.Type.(*ast.StructType) + for _, field := range st.Fields.List { + fn(field) + } +} + +func (g *testGenerator) testFuncName(base string) string { + return fmt.Sprintf("%s%s", base, strings.Title(g.t.Name.Name)) +} + +func (g *testGenerator) inTestFunction(name string, body func()) { + g.emit("func %s(t *testing.T) {\n", g.testFuncName(name)) + g.inIndent(body) + g.emit("}\n\n") +} + +func (g *testGenerator) emitTestNonZeroSize() { + g.inTestFunction("TestSizeNonZero", func() { + g.emit("x := &%s{}\n", g.typeName()) + g.emit("if x.SizeBytes() == 0 {\n") + g.inIndent(func() { + g.emit("t.Fatal(\"Marshallable.Size() should not return zero\")\n") + }) + g.emit("}\n") + }) +} + +func (g *testGenerator) emitTestSuspectAlignment() { + g.inTestFunction("TestSuspectAlignment", func() { + g.emit("x := %s{}\n", g.typeName()) + g.emit("analysis.AlignmentCheck(t, reflect.TypeOf(x))\n") + }) +} + +func (g *testGenerator) emitTestMarshalUnmarshalPreservesData() { + g.inTestFunction("TestSafeMarshalUnmarshalPreservesData", func() { + g.emit("var x, y, z, yUnsafe, zUnsafe %s\n", g.typeName()) + g.emit("analysis.RandomizeValue(&x)\n\n") + + g.emit("buf := make([]byte, x.SizeBytes())\n") + g.emit("x.MarshalBytes(buf)\n") + g.emit("bufUnsafe := make([]byte, x.SizeBytes())\n") + 
g.emit("x.MarshalUnsafe(bufUnsafe)\n\n") + + g.emit("y.UnmarshalBytes(buf)\n") + g.emit("if !reflect.DeepEqual(x, y) {\n") + g.inIndent(func() { + g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across Marshal/Unmarshal cycle:\\nBefore: %%+v\\nAfter: %%+v\\n\", x, y))\n") + }) + g.emit("}\n") + g.emit("yUnsafe.UnmarshalBytes(bufUnsafe)\n") + g.emit("if !reflect.DeepEqual(x, yUnsafe) {\n") + g.inIndent(func() { + g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalUnsafe/Unmarshal cycle:\\nBefore: %%+v\\nAfter: %%+v\\n\", x, yUnsafe))\n") + }) + g.emit("}\n\n") + + g.emit("z.UnmarshalUnsafe(buf)\n") + g.emit("if !reflect.DeepEqual(x, z) {\n") + g.inIndent(func() { + g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across Marshal/UnmarshalUnsafe cycle:\\nBefore: %%+v\\nAfter: %%+v\\n\", x, z))\n") + }) + g.emit("}\n") + g.emit("zUnsafe.UnmarshalUnsafe(bufUnsafe)\n") + g.emit("if !reflect.DeepEqual(x, zUnsafe) {\n") + g.inIndent(func() { + g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalUnsafe/UnmarshalUnsafe cycle:\\nBefore: %%+v\\nAfter: %%+v\\n\", x, zUnsafe))\n") + }) + g.emit("}\n") + }) +} + +func (g *testGenerator) emitTests() { + g.emitTestNonZeroSize() + g.emitTestSuspectAlignment() + g.emitTestMarshalUnmarshalPreservesData() +} + +func (g *testGenerator) write(out io.Writer) error { + return g.sourceBuffer.write(out) +} diff --git a/tools/go_marshal/gomarshal/util.go b/tools/go_marshal/gomarshal/util.go new file mode 100644 index 000000000..967537abf --- /dev/null +++ b/tools/go_marshal/gomarshal/util.go @@ -0,0 +1,387 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
var debug = flag.Bool("debug", false, "enables debugging output")

// receiverName returns an appropriate receiver name given a type spec: the
// lower-cased first character of the type name.
//
// Panics if the type name is empty.
func receiverName(t *ast.TypeSpec) string {
	if len(t.Name.Name) < 1 {
		// Zero length type name?
		panic("unreachable")
	}
	// Take the first rune, not the first byte: Go identifiers may start with
	// a multi-byte letter, and slicing a single byte out of one would produce
	// an invalid identifier.
	return strings.ToLower(string([]rune(t.Name.Name)[:1]))
}

// kindString returns a user-friendly representation of an AST expr type. For
// node types without a friendly name, the Go type name of e is returned.
func kindString(e ast.Expr) string {
	switch e.(type) {
	case *ast.Ident:
		return "scalar"
	case *ast.ArrayType:
		return "array"
	case *ast.StructType:
		return "struct"
	case *ast.StarExpr:
		return "pointer"
	case *ast.FuncType:
		return "function"
	case *ast.InterfaceType:
		return "interface"
	case *ast.MapType:
		return "map"
	case *ast.ChanType:
		return "channel"
	default:
		return reflect.TypeOf(e).String()
	}
}

// fieldDispatcher is a collection of callbacks for handling different types of
// fields in a struct declaration.
type fieldDispatcher struct {
	// primitive handles fields whose type is a plain identifier; n is the
	// field name and t the type identifier.
	primitive func(n, t *ast.Ident)
	// selector handles fields whose type is a qualified identifier tX.tSel.
	selector func(n, tX, tSel *ast.Ident)
	// array handles array and slice fields. t is the element type, or nil if
	// the element type is not a plain identifier. size is the literal array
	// length, or 0 if the length is absent or not an integer literal.
	array func(n, t *ast.Ident, size int)
	// unhandled handles fields of any other type.
	unhandled func(n *ast.Ident)
}

// dispatch invokes the callback matching the type of f, once per name
// declared by f.
//
// Precondition: All dispatch callbacks that will be invoked must be
// provided. Embedded fields are not allowed, len(f.Names) >= 1.
func (fd fieldDispatcher) dispatch(f *ast.Field) {
	// Each field declaration may actually be multiple declarations of the same
	// type. For example, consider:
	//
	// type Point struct {
	//     x, y, z int
	// }
	//
	// We invoke the call-backs once per such instance. Embedded fields are not
	// allowed, and results in a panic.
	if len(f.Names) < 1 {
		panic("Precondition not met: attempted to dispatch on embedded field")
	}

	for _, name := range f.Names {
		switch v := f.Type.(type) {
		case *ast.Ident:
			fd.primitive(name, v)
		case *ast.SelectorExpr:
			fd.selector(name, v.X.(*ast.Ident), v.Sel)
		case *ast.ArrayType:
			size := 0
			// Non-literal array length is handled by generatorInterfaces.validate().
			// The comma-ok assertion also covers a nil Len (slices).
			if lenLit, ok := v.Len.(*ast.BasicLit); ok {
				// Base 0 accepts every Go integer literal form: decimal, hex
				// (0x10), octal (010, 0o10), binary (0b1) and digit
				// separators (1_000).
				n, err := strconv.ParseInt(lenLit.Value, 0, 0)
				if err != nil {
					panic(err)
				}
				size = int(n)
			}
			switch t := v.Elt.(type) {
			case *ast.Ident:
				fd.array(name, t, size)
			default:
				fd.array(name, nil, size)
			}
		default:
			fd.unhandled(name)
		}
	}
}

// debugEnabled indicates whether debugging is enabled for gomarshal.
func debugEnabled() bool {
	return *debug
}

// abort aborts the go_marshal tool with the given error message. A trailing
// newline is appended to msg if it doesn't already end with one. The process
// exits with status 1.
func abort(msg string) {
	if !strings.HasSuffix(msg, "\n") {
		msg += "\n"
	}
	fmt.Print(msg)
	os.Exit(1)
}

// abortAt aborts the go_marshal tool with the given error message, with
// a reference position to the input source.
func abortAt(p token.Position, msg string) {
	abort(fmt.Sprintf("%v:\n %s\n", p, msg))
}

// debugf conditionally prints a debug message.
func debugf(f string, a ...interface{}) {
	if debugEnabled() {
		fmt.Printf(f, a...)
	}
}

// debugfAt conditionally prints a debug message with a reference to a position
// in the input source.
func debugfAt(p token.Position, f string, a ...interface{}) {
	if debugEnabled() {
		fmt.Printf("%s:\n %s", p, fmt.Sprintf(f, a...))
	}
}
// emit generates a line of code in the output file.
//
// emit is a wrapper around writing a formatted string to the output
// buffer. emit can be invoked in one of two ways:
//
// (1) emit("some string")
//     When emit is called with a single string argument, it is simply copied to
//     the output buffer without any further formatting.
// (2) emit(fmtString, args...)
//     emit can also be invoked in a similar fashion to *Printf() functions,
//     where the first argument is a format string.
//
// In both cases the output is preceded by indent levels of indentation.
//
// Calling emit with a single argument that is not a string will result in a
// panic, as the caller's intent is ambiguous. emit also panics if called with
// no arguments, or if writing to out fails.
func emit(out io.Writer, indent int, a ...interface{}) {
	const spacesPerIndentLevel = 4

	if len(a) < 1 {
		panic("emit() called with no arguments")
	}

	if indent > 0 {
		if _, err := fmt.Fprint(out, strings.Repeat(" ", indent*spacesPerIndentLevel)); err != nil {
			// Writing to the emit output should not fail. Typically the output
			// is a bytes.Buffer; writes to these never fail.
			panic(err)
		}
	}

	first, ok := a[0].(string)
	if !ok {
		// First argument must be either the string to emit (case 1 from
		// function-level comment), or a format string (case 2).
		panic(fmt.Sprintf("First argument to emit() is not a string: %+v", a[0]))
	}

	if len(a) == 1 {
		// Single string argument. Assume no formatting requested.
		if _, err := fmt.Fprint(out, first); err != nil {
			// Writing to out should not fail.
			panic(err)
		}
		return
	}

	// Formatting requested.
	if _, err := fmt.Fprintf(out, first, a[1:]...); err != nil {
		// Writing to out should not fail.
		panic(err)
	}
}

// sourceBuffer represents fragments of generated go source code.
//
// sourceBuffer provides a convenient way to build up go source fragments in
// memory. May be safely zero-value initialized. Not thread-safe.
type sourceBuffer struct {
	// Current indentation level.
	indent int

	// Memory buffer containing contents while they're being generated.
	b bytes.Buffer
}

// incIndent increases the indentation of subsequently emitted lines by one
// level.
func (b *sourceBuffer) incIndent() {
	b.indent++
}

// decIndent undoes one incIndent. Panics if there is no matching incIndent.
func (b *sourceBuffer) decIndent() {
	if b.indent <= 0 {
		panic("decIndent() without matching incIndent()")
	}
	b.indent--
}

// emit appends to the buffer at the current indentation level. See the
// package-level emit for the meaning of the arguments.
func (b *sourceBuffer) emit(a ...interface{}) {
	emit(&b.b, b.indent, a...)
}

// emitNoIndent appends to the buffer without any indentation, regardless of
// the current indentation level.
func (b *sourceBuffer) emitNoIndent(a ...interface{}) {
	emit(&b.b, 0 /*indent*/, a...)
}

// inIndent runs body with the indentation level increased by one.
func (b *sourceBuffer) inIndent(body func()) {
	b.incIndent()
	body()
	b.decIndent()
}

// write copies the buffered source to out. The buffer is left intact.
func (b *sourceBuffer) write(out io.Writer) error {
	_, err := out.Write(b.b.Bytes())
	return err
}

// Write implements io.Writer.Write.
func (b *sourceBuffer) Write(buf []byte) (int, error) {
	return b.b.Write(buf)
}

// importStmt represents a single import statement.
type importStmt struct {
	// Local name of the imported package.
	name string
	// Import path.
	path string
	// Indicates whether the local name is an alias, or simply the final
	// component of the path.
	aliased bool
	// Indicates whether this import was referenced by generated code.
	used bool
}

// newImport returns an unaliased, unused import statement for import path p.
// The local name is the final component of p.
func newImport(p string) *importStmt {
	name := path.Base(p)
	return &importStmt{
		name:    name,
		path:    p,
		aliased: false,
	}
}

// newImportFromSpec returns an unused import statement equivalent to spec. f
// is only used to report the position of spec if its path can't be processed,
// in which case newImportFromSpec panics.
func newImportFromSpec(spec *ast.ImportSpec, f *token.FileSet) *importStmt {
	p := spec.Path.Value[1 : len(spec.Path.Value)-1] // Strip the " quotes around path.
	name := path.Base(p)
	if name == "" || name == "/" || name == "." {
		panic(fmt.Sprintf("Couldn't process local package name for import at %s, (processed as %s)",
			f.Position(spec.Path.Pos()), name))
	}
	if spec.Name != nil {
		name = spec.Name.Name
	}
	return &importStmt{
		name:    name,
		path:    p,
		aliased: spec.Name != nil,
	}
}

// String returns the import as it appears inside an import block: the quoted
// path, preceded by the local name if the import is aliased.
func (i *importStmt) String() string {
	if i.aliased {
		return fmt.Sprintf("%s \"%s\"", i.name, i.path)
	}
	return fmt.Sprintf("\"%s\"", i.path)
}

// markUsed records that generated code references this import.
func (i *importStmt) markUsed() {
	i.used = true
}

// equivalent returns true if i and other describe the same import: same local
// name, same path and same aliasing. Whether either has been marked used is
// deliberately ignored, since usage is bookkeeping rather than part of the
// import itself.
func (i *importStmt) equivalent(other *importStmt) bool {
	if i == other {
		return true
	}
	if i == nil || other == nil {
		return false
	}
	return i.name == other.name && i.path == other.path && i.aliased == other.aliased
}

// importTable represents a collection of importStmts.
type importTable struct {
	// Map of imports and whether they should be copied to the output. Keyed
	// by the local name of the import.
	is map[string]*importStmt
}

// newImportTable returns an empty import table.
func newImportTable() *importTable {
	return &importTable{
		is: make(map[string]*importStmt),
	}
}

// Merges import statements from other into i. Collisions in import statements
// result in a panic. Two statements collide if they share a local name but
// aren't equivalent. An equivalent statement already in i is replaced by the
// one from other.
func (i *importTable) merge(other *importTable) {
	for name, im := range other.is {
		if dup, ok := i.is[name]; ok && !dup.equivalent(im) {
			panic(fmt.Sprintf("Found colliding import statements: ours: %+v, other's: %+v", dup, im))
		}

		i.is[name] = im
	}
}

// add inserts an unaliased import of path s into the table, replacing any
// existing import with the same local name, and returns the new statement.
func (i *importTable) add(s string) *importStmt {
	n := newImport(s)
	i.is[n.name] = n
	return n
}

// addFromSpec inserts the import described by spec into the table, replacing
// any existing import with the same local name, and returns the new statement.
func (i *importTable) addFromSpec(spec *ast.ImportSpec, f *token.FileSet) *importStmt {
	n := newImportFromSpec(spec, f)
	i.is[n.name] = n
	return n
}

// Marks the import named n as used. If no such import is in the table, returns
// false.
func (i *importTable) markUsed(n string) bool {
	if im, ok := i.is[n]; ok {
		im.markUsed()
		return true
	}
	return false
}

// clear marks every import in the table as unused.
func (i *importTable) clear() {
	for _, im := range i.is {
		im.used = false
	}
}

// write emits an import block containing the used imports to out, sorted by
// their textual form. Nothing is written if the table is empty.
func (i *importTable) write(out io.Writer) error {
	if len(i.is) == 0 {
		// Nothing to import, we're done.
		return nil
	}

	imports := make([]string, 0, len(i.is))
	for _, im := range i.is {
		if im.used {
			imports = append(imports, im.String())
		}
	}
	// Map iteration order is random; sort for deterministic output.
	sort.Strings(imports)

	var b sourceBuffer
	b.emit("import (\n")
	b.incIndent()
	for _, line := range imports {
		b.emit("%s\n", line)
	}
	b.decIndent()
	b.emit(")\n\n")

	return b.write(out)
}