// Copyright 2019 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // Package gomarshal implements the go_marshal code generator. See README.md. package gomarshal import ( "bytes" "fmt" "go/ast" "go/parser" "go/token" "os" "sort" "strings" "gvisor.dev/gvisor/tools/tags" ) // List of identifiers we use in generated code that may conflict with a // similarly-named source identifier. Abort gracefully when we see these to // avoid potentially confusing compilation failures in generated code. // // This only applies to import aliases at the moment. All other identifiers // are qualified by a receiver argument, since they're struct fields. // // All recievers are single letters, so we don't allow import aliases to be a // single letter. var badIdents = []string{ "addr", "blk", "buf", "cc", "dst", "dsts", "count", "err", "hdr", "idx", "inner", "length", "limit", "ptr", "size", "src", "srcs", "val", // All single-letter identifiers. } // Constructed fromt badIdents in init(). var badIdentsMap map[string]struct{} func init() { badIdentsMap = make(map[string]struct{}) for _, ident := range badIdents { badIdentsMap[ident] = struct{}{} } } // Generator drives code generation for a single invocation of the go_marshal // utility. // // The Generator holds arguments passed to the tool, and drives parsing, // processing and code Generator for all types marked with +marshal declared in // the input files. // // See Generator.run() as the entry point. type Generator struct { // Paths to input go source files. inputs []string // Output file to write generated go source. output *os.File // Output file to write generated tests. outputTest *os.File // Output file to write unconditionally generated tests. outputTestUC *os.File // Package name for the generated file. pkg string // Set of extra packages to import in the generated file. imports *importTable } // NewGenerator creates a new code Generator. func NewGenerator(srcs []string, out, outTest, outTestUnconditional, pkg string, imports []string) (*Generator, error) { f, err := os.OpenFile(out, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) if err != nil { return nil, fmt.Errorf("couldn't open output file %q: %w", out, err) } fTest, err := os.OpenFile(outTest, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) if err != nil { return nil, fmt.Errorf("couldn't open test output file %q: %w", out, err) } fTestUC, err := os.OpenFile(outTestUnconditional, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) if err != nil { return nil, fmt.Errorf("couldn't open unconditional test output file %q: %w", out, err) } g := Generator{ inputs: srcs, output: f, outputTest: fTest, outputTestUC: fTestUC, pkg: pkg, imports: newImportTable(), } for _, i := range imports { // All imports on the extra imports list are unconditionally marked as // used, so that they're always added to the generated code. g.imports.add(i).markUsed() } // The following imports may or may not be used by the generated code, // depending on what's required for the target types. Don't mark these as // used by default. g.imports.add("io") g.imports.add("reflect") g.imports.add("runtime") g.imports.add("unsafe") g.imports.add("gvisor.dev/gvisor/pkg/gohacks") g.imports.add("gvisor.dev/gvisor/pkg/safecopy") g.imports.add("gvisor.dev/gvisor/pkg/usermem") g.imports.add("gvisor.dev/gvisor/pkg/marshal") return &g, nil } // writeHeader writes the header for the generated source file. The header // includes the package name, package level comments and import statements. func (g *Generator) writeHeader() error { var b sourceBuffer b.emit("// Automatically generated marshal implementation. See tools/go_marshal.\n\n") // Emit build tags. if t := tags.Aggregate(g.inputs); len(t) > 0 { b.emit(strings.Join(t.Lines(), "\n")) b.emit("\n\n") } // Package header. b.emit("package %s\n\n", g.pkg) if err := b.write(g.output); err != nil { return err } return g.imports.write(g.output) } // writeTypeChecks writes a statement to force the compiler to perform a type // check for all Marshallable types referenced by the generated code. func (g *Generator) writeTypeChecks(ms map[string]struct{}) error { if len(ms) == 0 { return nil } msl := make([]string, 0, len(ms)) for m := range ms { msl = append(msl, m) } sort.Strings(msl) var buf bytes.Buffer fmt.Fprint(&buf, "// Marshallable types used by this file.\n") for _, m := range msl { fmt.Fprintf(&buf, "var _ marshal.Marshallable = (*%s)(nil)\n", m) } fmt.Fprint(&buf, "\n") _, err := fmt.Fprint(g.output, buf.String()) return err } // parse processes all input files passed this generator and produces a set of // parsed go ASTs. func (g *Generator) parse() ([]*ast.File, []*token.FileSet, error) { debugf("go_marshal invoked with %d input files:\n", len(g.inputs)) for _, path := range g.inputs { debugf(" %s\n", path) } files := make([]*ast.File, 0, len(g.inputs)) fsets := make([]*token.FileSet, 0, len(g.inputs)) for _, path := range g.inputs { fset := token.NewFileSet() f, err := parser.ParseFile(fset, path, nil, parser.ParseComments) if err != nil { // Not a valid input file? return nil, nil, fmt.Errorf("input %q can't be parsed: %w", path, err) } if debugEnabled() { debugf("AST for %q:\n", path) ast.Print(fset, f) } files = append(files, f) fsets = append(fsets, fset) } return files, fsets, nil } // sliceAPI carries information about the '+marshal slice' directive. type sliceAPI struct { // Comment node in the AST containing the +marshal tag. comment *ast.Comment // Identifier fragment to use when naming generated functions for the slice // API. ident string // Whether the generated functions should reference the newtype name, or the // inner type name. Only meaningful on newtype declarations on primitives. inner bool } // marshallableType carries information about a type marked with the '+marshal' // directive. type marshallableType struct { spec *ast.TypeSpec slice *sliceAPI recv string dynamic bool } func newMarshallableType(fset *token.FileSet, tagLine *ast.Comment, spec *ast.TypeSpec) *marshallableType { mt := &marshallableType{ spec: spec, slice: nil, } var unhandledTags []string for _, tag := range strings.Fields(strings.TrimPrefix(tagLine.Text, "// +marshal")) { if strings.HasPrefix(tag, "slice:") { tokens := strings.Split(tag, ":") if len(tokens) < 2 || len(tokens) > 3 { abortAt(fset.Position(tagLine.Slash), fmt.Sprintf("+marshal directive has invalid 'slice' clause. Expecting format 'slice:[:inner]', got '%v'", tag)) } if len(tokens[1]) == 0 { abortAt(fset.Position(tagLine.Slash), "+marshal slice directive has empty identifier argument. Expecting '+marshal slice:identifier'") } sa := &sliceAPI{ comment: tagLine, ident: tokens[1], } mt.slice = sa if len(tokens) == 3 { if tokens[2] != "inner" { abortAt(fset.Position(tagLine.Slash), "+marshal slice directive has an invalid argument. Expecting '+marshal slice:[:inner]'") } sa.inner = true } continue } else if tag == "dynamic" { mt.dynamic = true continue } unhandledTags = append(unhandledTags, tag) } if len(unhandledTags) > 0 { abortAt(fset.Position(tagLine.Slash), fmt.Sprintf("+marshal directive contained the following unknown clauses: %v", strings.Join(unhandledTags, " "))) } return mt } // collectMarshallableTypes walks the parsed AST and collects a list of type // declarations for which we need to generate the Marshallable interface. func (g *Generator) collectMarshallableTypes(a *ast.File, f *token.FileSet) map[*ast.TypeSpec]*marshallableType { recv := make(map[string]string) // Type name to recevier name. types := make(map[*ast.TypeSpec]*marshallableType) for _, decl := range a.Decls { gdecl, ok := decl.(*ast.GenDecl) // Type declaration? if !ok || gdecl.Tok != token.TYPE { // Is this a function declaration? We remember receiver names. d, ok := decl.(*ast.FuncDecl) if ok && d.Recv != nil && len(d.Recv.List) == 1 { // Accept concrete methods & pointer methods. ident, ok := d.Recv.List[0].Type.(*ast.Ident) if !ok { var st *ast.StarExpr st, ok = d.Recv.List[0].Type.(*ast.StarExpr) if ok { ident, ok = st.X.(*ast.Ident) } } // The receiver name may be not present. if ok && len(d.Recv.List[0].Names) == 1 { // Recover the type receiver name in this case. recv[ident.Name] = d.Recv.List[0].Names[0].Name } } debugfAt(f.Position(decl.Pos()), "Skipping declaration since it's not a type declaration.\n") continue } // Does it have a comment? if gdecl.Doc == nil { debugfAt(f.Position(gdecl.Pos()), "Skipping declaration since it doesn't have a comment.\n") continue } // Does the comment contain a "+marshal" line? marked := false var tagLine *ast.Comment for _, c := range gdecl.Doc.List { if strings.HasPrefix(c.Text, "// +marshal") { marked = true tagLine = c break } } if !marked { debugfAt(f.Position(gdecl.Pos()), "Skipping declaration since it doesn't have a comment containing +marshal line.\n") continue } for _, spec := range gdecl.Specs { // We already confirmed we're in a type declaration earlier, so this // cast will succeed. t := spec.(*ast.TypeSpec) switch t.Type.(type) { case *ast.StructType: debugfAt(f.Position(t.Pos()), "Collected marshallable struct %s.\n", t.Name.Name) case *ast.Ident: // Newtype on primitive. debugfAt(f.Position(t.Pos()), "Collected marshallable newtype on primitive %s.\n", t.Name.Name) case *ast.ArrayType: // Newtype on array. debugfAt(f.Position(t.Pos()), "Collected marshallable newtype on array %s.\n", t.Name.Name) default: // A user specifically requested marshalling on this type, but we // don't support it. abortAt(f.Position(t.Pos()), fmt.Sprintf("Marshalling codegen was requested on type '%s', but go-marshal doesn't support this kind of declaration.\n", t.Name)) } types[t] = newMarshallableType(f, tagLine, t) } } // Update the types with the last seen receiver. As long as the // receiver name is consistent for the type, then we will generate // code that is still consistent with itself. for t, mt := range types { r, ok := recv[t.Name.Name] if !ok { mt.recv = receiverName(t) // Default. continue } mt.recv = r // Last seen. } return types } // collectImports collects all imports from all input source files. Some of // these imports are copied to the generated output, if they're referenced by // the generated code. // // collectImports de-duplicates imports while building the list, and ensures // identifiers in the generated code don't conflict with any imported package // names. func (g *Generator) collectImports(a *ast.File, f *token.FileSet) map[string]importStmt { is := make(map[string]importStmt) for _, decl := range a.Decls { gdecl, ok := decl.(*ast.GenDecl) // Import statement? if !ok || gdecl.Tok != token.IMPORT { continue } for _, spec := range gdecl.Specs { i := g.imports.addFromSpec(spec.(*ast.ImportSpec), f) debugf("Collected import '%s' as '%s'\n", i.path, i.name) // Make sure we have an import that doesn't use any local names that // would conflict with identifiers in the generated code. if len(i.name) == 1 && i.name != "_" { abortAt(f.Position(spec.Pos()), fmt.Sprintf("Import has a single character local name '%s'; this may conflict with code generated by go_marshal, use a multi-character import alias", i.name)) } if _, ok := badIdentsMap[i.name]; ok { abortAt(f.Position(spec.Pos()), fmt.Sprintf("Import name '%s' is likely to conflict with code generated by go_marshal, use a different import alias", i.name)) } } } return is } func (g *Generator) generateOne(t *marshallableType, fset *token.FileSet) *interfaceGenerator { i := newInterfaceGenerator(t.spec, t.recv, fset) switch ty := t.spec.Type.(type) { case *ast.StructType: if t.dynamic { // Don't validate because this type is dynamically sized and probably // contains some funky slices which the validation does not allow. i.emitMarshallableForStruct(ty, t.dynamic) if t.slice != nil { abortAt(fset.Position(t.slice.comment.Slash), "Slice API is not supported for dynamic types because it assumes that each slice element is statically sized.") } break } i.validateStruct(t.spec, ty) i.emitMarshallableForStruct(ty, t.dynamic) if t.slice != nil { i.emitMarshallableSliceForStruct(ty, t.slice) } case *ast.Ident: i.validatePrimitiveNewtype(ty) if t.dynamic { abortAt(fset.Position(t.slice.comment.Slash), "Primitive type marked as '+marshal dynamic', but primitive types can not be dynamic.") } i.emitMarshallableForPrimitiveNewtype(ty) if t.slice != nil { i.emitMarshallableSliceForPrimitiveNewtype(ty, t.slice) } case *ast.ArrayType: i.validateArrayNewtype(t.spec.Name, ty) if t.dynamic { abortAt(fset.Position(t.slice.comment.Slash), "Marking array types as `dynamic` is currently not supported.") } // After validate, we can safely call arrayLen. i.emitMarshallableForArrayNewtype(t.spec.Name, ty, ty.Elt.(*ast.Ident)) if t.slice != nil { abortAt(fset.Position(t.slice.comment.Slash), "Array type marked as '+marshal slice:...', but this is not supported. Perhaps fold one of the dimensions?") } default: // This should've been filtered out by collectMarshallabeTypes. panic(fmt.Sprintf("Unexpected type %+v", ty)) } return i } // generateOneTestSuite generates a test suite for the automatically generated // implementations type t. func (g *Generator) generateOneTestSuite(t *marshallableType) *testGenerator { i := newTestGenerator(t.spec, t.recv) i.emitTests(t.slice, t.dynamic) return i } // Run is the entry point to code generation using g. // // Run parses all input source files specified in g and emits generated code. func (g *Generator) Run() error { // Parse our input source files into ASTs and token sets. asts, fsets, err := g.parse() if err != nil { return err } if len(asts) != len(fsets) { panic("ASTs and FileSets don't match") } // Map of imports in source files; key = local package name, value = import // path. is := make(map[string]importStmt) for i, a := range asts { // Collect all imports from the source files. We may need to copy some // of these to the generated code if they're referenced. This has to be // done before the loop below because we need to process all ASTs before // we start requesting imports to be copied one by one as we encounter // them in each generated source. for name, i := range g.collectImports(a, fsets[i]) { is[name] = i } } var impls []*interfaceGenerator var ts []*testGenerator // Set of Marshallable types referenced by generated code. ms := make(map[string]struct{}) for i, a := range asts { // Collect type declarations marked for code generation and generate // Marshallable interfaces. var sortedTypes []*marshallableType for _, t := range g.collectMarshallableTypes(a, fsets[i]) { sortedTypes = append(sortedTypes, t) } sort.Slice(sortedTypes, func(x, y int) bool { // Sort by type name, which should be unique within a package. return sortedTypes[x].spec.Name.String() < sortedTypes[y].spec.Name.String() }) for _, t := range sortedTypes { impl := g.generateOne(t, fsets[i]) // Collect Marshallable types referenced by the generated code. for ref := range impl.ms { ms[ref] = struct{}{} } impls = append(impls, impl) // Collect imports referenced by the generated code and add them to // the list of imports we need to copy to the generated code. for name := range impl.is { if !g.imports.markUsed(name) { panic(fmt.Sprintf("Generated code for '%s' referenced a non-existent import with local name '%s'. Either go-marshal needs to add an import to the generated file, or a package in an input source file has a package name differ from the final component of its path, which go-marshal doesn't know how to detect; use an import alias to work around this limitation.", impl.typeName(), name)) } } ts = append(ts, g.generateOneTestSuite(t)) } } // Write output file header. These include things like package name and // import statements. if err := g.writeHeader(); err != nil { return err } // Write type checks for referenced marshallable types to output file. if err := g.writeTypeChecks(ms); err != nil { return err } // Write generated interfaces to output file. for _, i := range impls { if err := i.write(g.output); err != nil { return err } } // Write generated tests to test file. return g.writeTests(ts) } // writeTests outputs tests for the generated interface implementations to a go // source file. func (g *Generator) writeTests(ts []*testGenerator) error { var b sourceBuffer // Write the unconditional test file. This file is always compiled, // regardless of what build tags were specified on the original input // files. We use this file to guarantee we never end up with an empty test // file, as that causes the build to fail with "no tests/benchmarks/examples // found". // // There's no easy way to determine ahead of time if we'll end up with an // empty build file since build constraints can arbitrarily cause some of // the original types to be not defined. We also have no way to tell bazel // to omit the entire test suite since the output files are already defined // before go-marshal is called. b.emit("// Automatically generated marshal tests. See tools/go_marshal.\n\n") b.emit("package %s\n\n", g.pkg) b.emit("func Example() {\n") b.inIndent(func() { b.emit("// This example is intentionally empty, and ensures this package contains at\n") b.emit("// least one testable entity. go-marshal is forced to emit a test package if the\n") b.emit("// input package is marked marshallable, but emitting no testable entities \n") b.emit("// results in a build failure.\n") }) b.emit("}\n") if err := b.write(g.outputTestUC); err != nil { return err } // Now generate the real test file that contains the real types we // processed. These need to be conditionally compiled according to the build // tags, as the original types may not be defined under all build // configurations. b.reset() b.emit("// Automatically generated marshal tests. See tools/go_marshal.\n\n") // Emit build tags. if t := tags.Aggregate(g.inputs); len(t) > 0 { b.emit(strings.Join(t.Lines(), "\n")) b.emit("\n\n") } b.emit("package %s\n\n", g.pkg) if err := b.write(g.outputTest); err != nil { return err } // Collect and write test import statements. imports := newImportTable() for _, t := range ts { imports.merge(t.imports) } if err := imports.write(g.outputTest); err != nil { return err } // Write test functions. for _, t := range ts { if err := t.write(g.outputTest); err != nil { return err } } return nil }