diff options
Diffstat (limited to 'tools/go_marshal/gomarshal')
-rw-r--r-- | tools/go_marshal/gomarshal/BUILD | 22 | ||||
-rw-r--r-- | tools/go_marshal/gomarshal/generator.go | 587 | ||||
-rw-r--r-- | tools/go_marshal/gomarshal/generator_interfaces.go | 276 | ||||
-rw-r--r-- | tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go | 146 | ||||
-rw-r--r-- | tools/go_marshal/gomarshal/generator_interfaces_dynamic.go | 96 | ||||
-rw-r--r-- | tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go | 285 | ||||
-rw-r--r-- | tools/go_marshal/gomarshal/generator_interfaces_struct.go | 616 | ||||
-rw-r--r-- | tools/go_marshal/gomarshal/generator_tests.go | 233 | ||||
-rw-r--r-- | tools/go_marshal/gomarshal/util.go | 503 |
9 files changed, 0 insertions, 2764 deletions
diff --git a/tools/go_marshal/gomarshal/BUILD b/tools/go_marshal/gomarshal/BUILD deleted file mode 100644 index aaa203115..000000000 --- a/tools/go_marshal/gomarshal/BUILD +++ /dev/null @@ -1,22 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -licenses(["notice"]) - -go_library( - name = "gomarshal", - srcs = [ - "generator.go", - "generator_interfaces.go", - "generator_interfaces_array_newtype.go", - "generator_interfaces_dynamic.go", - "generator_interfaces_primitive_newtype.go", - "generator_interfaces_struct.go", - "generator_tests.go", - "util.go", - ], - stateify = False, - visibility = [ - "//:sandbox", - ], - deps = ["//tools/constraintutil"], -) diff --git a/tools/go_marshal/gomarshal/generator.go b/tools/go_marshal/gomarshal/generator.go deleted file mode 100644 index 4c23637c0..000000000 --- a/tools/go_marshal/gomarshal/generator.go +++ /dev/null @@ -1,587 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package gomarshal implements the go_marshal code generator. See README.md. -package gomarshal - -import ( - "bytes" - "fmt" - "go/ast" - "go/parser" - "go/token" - "os" - "sort" - "strings" - - "gvisor.dev/gvisor/tools/constraintutil" -) - -// List of identifiers we use in generated code that may conflict with a -// similarly-named source identifier. Abort gracefully when we see these to -// avoid potentially confusing compilation failures in generated code. -// -// This only applies to import aliases at the moment. All other identifiers -// are qualified by a receiver argument, since they're struct fields. -// -// All recievers are single letters, so we don't allow import aliases to be a -// single letter. -var badIdents = []string{ - "addr", "blk", "buf", "cc", "dst", "dsts", "count", "err", "hdr", "idx", - "inner", "length", "limit", "ptr", "size", "src", "srcs", "val", - // All single-letter identifiers. -} - -// Constructed fromt badIdents in init(). -var badIdentsMap map[string]struct{} - -func init() { - badIdentsMap = make(map[string]struct{}) - for _, ident := range badIdents { - badIdentsMap[ident] = struct{}{} - } -} - -// Generator drives code generation for a single invocation of the go_marshal -// utility. -// -// The Generator holds arguments passed to the tool, and drives parsing, -// processing and code Generator for all types marked with +marshal declared in -// the input files. -// -// See Generator.run() as the entry point. -type Generator struct { - // Paths to input go source files. - inputs []string - // Output file to write generated go source. - output *os.File - // Output file to write generated tests. - outputTest *os.File - // Output file to write unconditionally generated tests. - outputTestUC *os.File - // Package name for the generated file. - pkg string - // Set of extra packages to import in the generated file. - imports *importTable -} - -// NewGenerator creates a new code Generator. -func NewGenerator(srcs []string, out, outTest, outTestUnconditional, pkg string, imports []string) (*Generator, error) { - f, err := os.OpenFile(out, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) - if err != nil { - return nil, fmt.Errorf("couldn't open output file %q: %w", out, err) - } - fTest, err := os.OpenFile(outTest, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) - if err != nil { - return nil, fmt.Errorf("couldn't open test output file %q: %w", out, err) - } - fTestUC, err := os.OpenFile(outTestUnconditional, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) - if err != nil { - return nil, fmt.Errorf("couldn't open unconditional test output file %q: %w", out, err) - } - g := Generator{ - inputs: srcs, - output: f, - outputTest: fTest, - outputTestUC: fTestUC, - pkg: pkg, - imports: newImportTable(), - } - for _, i := range imports { - // All imports on the extra imports list are unconditionally marked as - // used, so that they're always added to the generated code. - g.imports.add(i).markUsed() - } - - // The following imports may or may not be used by the generated code, - // depending on what's required for the target types. Don't mark these as - // used by default. - g.imports.add("io") - g.imports.add("reflect") - g.imports.add("runtime") - g.imports.add("unsafe") - g.imports.add("gvisor.dev/gvisor/pkg/gohacks") - g.imports.add("gvisor.dev/gvisor/pkg/hostarch") - g.imports.add("gvisor.dev/gvisor/pkg/marshal") - return &g, nil -} - -// writeHeader writes the header for the generated source file. The header -// includes the package name, package level comments and import statements. -func (g *Generator) writeHeader() error { - var b sourceBuffer - b.emit("// Automatically generated marshal implementation. See tools/go_marshal.\n\n") - - bcexpr, err := constraintutil.CombineFromFiles(g.inputs) - if err != nil { - return err - } - if bcexpr != nil { - // Emit build constraints. - b.emit("// If there are issues with build constraint aggregation, see\n") - b.emit("// tools/go_marshal/gomarshal/generator.go:writeHeader(). The constraints here\n") - b.emit("// come from the input set of files used to generate this file. This input set\n") - b.emit("// is filtered based on pre-defined file suffixes related to build constraints,\n") - b.emit("// see tools/defs.bzl:calculate_sets().\n\n") - b.emit(constraintutil.Lines(bcexpr)) - } - - // Package header. - b.emit("package %s\n\n", g.pkg) - if err := b.write(g.output); err != nil { - return err - } - - return g.imports.write(g.output) -} - -// writeTypeChecks writes a statement to force the compiler to perform a type -// check for all Marshallable types referenced by the generated code. -func (g *Generator) writeTypeChecks(ms map[string]struct{}) error { - if len(ms) == 0 { - return nil - } - - msl := make([]string, 0, len(ms)) - for m := range ms { - msl = append(msl, m) - } - sort.Strings(msl) - - var buf bytes.Buffer - fmt.Fprint(&buf, "// Marshallable types used by this file.\n") - - for _, m := range msl { - fmt.Fprintf(&buf, "var _ marshal.Marshallable = (*%s)(nil)\n", m) - } - fmt.Fprint(&buf, "\n") - - _, err := fmt.Fprint(g.output, buf.String()) - return err -} - -// parse processes all input files passed this generator and produces a set of -// parsed go ASTs. -func (g *Generator) parse() ([]*ast.File, []*token.FileSet, error) { - debugf("go_marshal invoked with %d input files:\n", len(g.inputs)) - for _, path := range g.inputs { - debugf(" %s\n", path) - } - - files := make([]*ast.File, 0, len(g.inputs)) - fsets := make([]*token.FileSet, 0, len(g.inputs)) - - for _, path := range g.inputs { - fset := token.NewFileSet() - f, err := parser.ParseFile(fset, path, nil, parser.ParseComments) - if err != nil { - // Not a valid input file? - return nil, nil, fmt.Errorf("input %q can't be parsed: %w", path, err) - } - - if debugEnabled() { - debugf("AST for %q:\n", path) - ast.Print(fset, f) - } - - files = append(files, f) - fsets = append(fsets, fset) - } - - return files, fsets, nil -} - -// sliceAPI carries information about the '+marshal slice' directive. -type sliceAPI struct { - // Comment node in the AST containing the +marshal tag. - comment *ast.Comment - // Identifier fragment to use when naming generated functions for the slice - // API. - ident string - // Whether the generated functions should reference the newtype name, or the - // inner type name. Only meaningful on newtype declarations on primitives. - inner bool -} - -// marshallableType carries information about a type marked with the '+marshal' -// directive. -type marshallableType struct { - spec *ast.TypeSpec - slice *sliceAPI - recv string - dynamic bool -} - -func newMarshallableType(fset *token.FileSet, tagLine *ast.Comment, spec *ast.TypeSpec) *marshallableType { - mt := &marshallableType{ - spec: spec, - slice: nil, - } - - var unhandledTags []string - - for _, tag := range strings.Fields(strings.TrimPrefix(tagLine.Text, "// +marshal")) { - if strings.HasPrefix(tag, "slice:") { - tokens := strings.Split(tag, ":") - if len(tokens) < 2 || len(tokens) > 3 { - abortAt(fset.Position(tagLine.Slash), fmt.Sprintf("+marshal directive has invalid 'slice' clause. Expecting format 'slice:<IDENTIFIER>[:inner]', got '%v'", tag)) - } - if len(tokens[1]) == 0 { - abortAt(fset.Position(tagLine.Slash), "+marshal slice directive has empty identifier argument. Expecting '+marshal slice:identifier'") - } - - sa := &sliceAPI{ - comment: tagLine, - ident: tokens[1], - } - mt.slice = sa - - if len(tokens) == 3 { - if tokens[2] != "inner" { - abortAt(fset.Position(tagLine.Slash), "+marshal slice directive has an invalid argument. Expecting '+marshal slice:<IDENTIFIER>[:inner]'") - } - sa.inner = true - } - - continue - } else if tag == "dynamic" { - mt.dynamic = true - continue - } - - unhandledTags = append(unhandledTags, tag) - } - - if len(unhandledTags) > 0 { - abortAt(fset.Position(tagLine.Slash), fmt.Sprintf("+marshal directive contained the following unknown clauses: %v", strings.Join(unhandledTags, " "))) - } - - return mt -} - -// collectMarshallableTypes walks the parsed AST and collects a list of type -// declarations for which we need to generate the Marshallable interface. -func (g *Generator) collectMarshallableTypes(a *ast.File, f *token.FileSet) map[*ast.TypeSpec]*marshallableType { - recv := make(map[string]string) // Type name to recevier name. - types := make(map[*ast.TypeSpec]*marshallableType) - for _, decl := range a.Decls { - gdecl, ok := decl.(*ast.GenDecl) - // Type declaration? - if !ok || gdecl.Tok != token.TYPE { - // Is this a function declaration? We remember receiver names. - d, ok := decl.(*ast.FuncDecl) - if ok && d.Recv != nil && len(d.Recv.List) == 1 { - // Accept concrete methods & pointer methods. - ident, ok := d.Recv.List[0].Type.(*ast.Ident) - if !ok { - var st *ast.StarExpr - st, ok = d.Recv.List[0].Type.(*ast.StarExpr) - if ok { - ident, ok = st.X.(*ast.Ident) - } - } - // The receiver name may be not present. - if ok && len(d.Recv.List[0].Names) == 1 { - // Recover the type receiver name in this case. - recv[ident.Name] = d.Recv.List[0].Names[0].Name - } - } - debugfAt(f.Position(decl.Pos()), "Skipping declaration since it's not a type declaration.\n") - continue - } - // Does it have a comment? - if gdecl.Doc == nil { - debugfAt(f.Position(gdecl.Pos()), "Skipping declaration since it doesn't have a comment.\n") - continue - } - // Does the comment contain a "+marshal" line? - marked := false - var tagLine *ast.Comment - for _, c := range gdecl.Doc.List { - if strings.HasPrefix(c.Text, "// +marshal") { - marked = true - tagLine = c - break - } - } - if !marked { - debugfAt(f.Position(gdecl.Pos()), "Skipping declaration since it doesn't have a comment containing +marshal line.\n") - continue - } - for _, spec := range gdecl.Specs { - // We already confirmed we're in a type declaration earlier, so this - // cast will succeed. - t := spec.(*ast.TypeSpec) - switch t.Type.(type) { - case *ast.StructType: - debugfAt(f.Position(t.Pos()), "Collected marshallable struct %s.\n", t.Name.Name) - case *ast.Ident: // Newtype on primitive. - debugfAt(f.Position(t.Pos()), "Collected marshallable newtype on primitive %s.\n", t.Name.Name) - case *ast.ArrayType: // Newtype on array. - debugfAt(f.Position(t.Pos()), "Collected marshallable newtype on array %s.\n", t.Name.Name) - default: - // A user specifically requested marshalling on this type, but we - // don't support it. - abortAt(f.Position(t.Pos()), fmt.Sprintf("Marshalling codegen was requested on type '%s', but go-marshal doesn't support this kind of declaration.\n", t.Name)) - } - types[t] = newMarshallableType(f, tagLine, t) - } - } - // Update the types with the last seen receiver. As long as the - // receiver name is consistent for the type, then we will generate - // code that is still consistent with itself. - for t, mt := range types { - r, ok := recv[t.Name.Name] - if !ok { - mt.recv = receiverName(t) // Default. - continue - } - mt.recv = r // Last seen. - } - return types -} - -// collectImports collects all imports from all input source files. Some of -// these imports are copied to the generated output, if they're referenced by -// the generated code. -// -// collectImports de-duplicates imports while building the list, and ensures -// identifiers in the generated code don't conflict with any imported package -// names. -func (g *Generator) collectImports(a *ast.File, f *token.FileSet) map[string]importStmt { - is := make(map[string]importStmt) - for _, decl := range a.Decls { - gdecl, ok := decl.(*ast.GenDecl) - // Import statement? - if !ok || gdecl.Tok != token.IMPORT { - continue - } - for _, spec := range gdecl.Specs { - i := g.imports.addFromSpec(spec.(*ast.ImportSpec), f) - debugf("Collected import '%s' as '%s'\n", i.path, i.name) - - // Make sure we have an import that doesn't use any local names that - // would conflict with identifiers in the generated code. - if len(i.name) == 1 && i.name != "_" { - abortAt(f.Position(spec.Pos()), fmt.Sprintf("Import has a single character local name '%s'; this may conflict with code generated by go_marshal, use a multi-character import alias", i.name)) - } - if _, ok := badIdentsMap[i.name]; ok { - abortAt(f.Position(spec.Pos()), fmt.Sprintf("Import name '%s' is likely to conflict with code generated by go_marshal, use a different import alias", i.name)) - } - } - } - return is - -} - -func (g *Generator) generateOne(t *marshallableType, fset *token.FileSet) *interfaceGenerator { - i := newInterfaceGenerator(t.spec, t.recv, fset) - if t.dynamic { - if t.slice != nil { - abortAt(fset.Position(t.slice.comment.Slash), "Slice API is not supported for dynamic types because it assumes that each slice element is statically sized.") - } - // No validation needed, assume the user knows what they are doing. - i.emitMarshallableForDynamicType() - return i - } - switch ty := t.spec.Type.(type) { - case *ast.StructType: - i.validateStruct(t.spec, ty) - i.emitMarshallableForStruct(ty) - if t.slice != nil { - i.emitMarshallableSliceForStruct(ty, t.slice) - } - case *ast.Ident: - i.validatePrimitiveNewtype(ty) - i.emitMarshallableForPrimitiveNewtype(ty) - if t.slice != nil { - i.emitMarshallableSliceForPrimitiveNewtype(ty, t.slice) - } - case *ast.ArrayType: - i.validateArrayNewtype(t.spec.Name, ty) - // After validate, we can safely call arrayLen. - i.emitMarshallableForArrayNewtype(t.spec.Name, ty, ty.Elt.(*ast.Ident)) - if t.slice != nil { - abortAt(fset.Position(t.slice.comment.Slash), "Array type marked as '+marshal slice:...', but this is not supported. Perhaps fold one of the dimensions?") - } - default: - // This should've been filtered out by collectMarshallabeTypes. - panic(fmt.Sprintf("Unexpected type %+v", ty)) - } - return i -} - -// generateOneTestSuite generates a test suite for the automatically generated -// implementations type t. -func (g *Generator) generateOneTestSuite(t *marshallableType) *testGenerator { - i := newTestGenerator(t.spec, t.recv) - i.emitTests(t.slice) - return i -} - -// Run is the entry point to code generation using g. -// -// Run parses all input source files specified in g and emits generated code. -func (g *Generator) Run() error { - // Parse our input source files into ASTs and token sets. - asts, fsets, err := g.parse() - if err != nil { - return err - } - - if len(asts) != len(fsets) { - panic("ASTs and FileSets don't match") - } - - // Map of imports in source files; key = local package name, value = import - // path. - is := make(map[string]importStmt) - for i, a := range asts { - // Collect all imports from the source files. We may need to copy some - // of these to the generated code if they're referenced. This has to be - // done before the loop below because we need to process all ASTs before - // we start requesting imports to be copied one by one as we encounter - // them in each generated source. - for name, i := range g.collectImports(a, fsets[i]) { - is[name] = i - } - } - - var impls []*interfaceGenerator - var ts []*testGenerator - // Set of Marshallable types referenced by generated code. - ms := make(map[string]struct{}) - for i, a := range asts { - // Collect type declarations marked for code generation and generate - // Marshallable interfaces. - var sortedTypes []*marshallableType - for _, t := range g.collectMarshallableTypes(a, fsets[i]) { - sortedTypes = append(sortedTypes, t) - } - sort.Slice(sortedTypes, func(x, y int) bool { - // Sort by type name, which should be unique within a package. - return sortedTypes[x].spec.Name.String() < sortedTypes[y].spec.Name.String() - }) - for _, t := range sortedTypes { - impl := g.generateOne(t, fsets[i]) - // Collect Marshallable types referenced by the generated code. - for ref := range impl.ms { - ms[ref] = struct{}{} - } - impls = append(impls, impl) - // Collect imports referenced by the generated code and add them to - // the list of imports we need to copy to the generated code. - for name := range impl.is { - if !g.imports.markUsed(name) { - panic(fmt.Sprintf("Generated code for '%s' referenced a non-existent import with local name '%s'. Either go-marshal needs to add an import to the generated file, or a package in an input source file has a package name differ from the final component of its path, which go-marshal doesn't know how to detect; use an import alias to work around this limitation.", impl.typeName(), name)) - } - } - // Do not generate tests for dynamic types because they inherently - // violate some go_marshal requirements. - if !t.dynamic { - ts = append(ts, g.generateOneTestSuite(t)) - } - } - } - - // Write output file header. These include things like package name and - // import statements. - if err := g.writeHeader(); err != nil { - return err - } - - // Write type checks for referenced marshallable types to output file. - if err := g.writeTypeChecks(ms); err != nil { - return err - } - - // Write generated interfaces to output file. - for _, i := range impls { - if err := i.write(g.output); err != nil { - return err - } - } - - // Write generated tests to test file. - return g.writeTests(ts) -} - -// writeTests outputs tests for the generated interface implementations to a go -// source file. -func (g *Generator) writeTests(ts []*testGenerator) error { - var b sourceBuffer - - // Write the unconditional test file. This file is always compiled, - // regardless of what build tags were specified on the original input - // files. We use this file to guarantee we never end up with an empty test - // file, as that causes the build to fail with "no tests/benchmarks/examples - // found". - // - // There's no easy way to determine ahead of time if we'll end up with an - // empty build file since build constraints can arbitrarily cause some of - // the original types to be not defined. We also have no way to tell bazel - // to omit the entire test suite since the output files are already defined - // before go-marshal is called. - b.emit("// Automatically generated marshal tests. See tools/go_marshal.\n\n") - b.emit("package %s\n\n", g.pkg) - b.emit("func Example() {\n") - b.inIndent(func() { - b.emit("// This example is intentionally empty, and ensures this package contains at\n") - b.emit("// least one testable entity. go-marshal is forced to emit a test package if the\n") - b.emit("// input package is marked marshallable, but emitting no testable entities \n") - b.emit("// results in a build failure.\n") - }) - b.emit("}\n") - if err := b.write(g.outputTestUC); err != nil { - return err - } - - // Now generate the real test file that contains the real types we - // processed. These need to be conditionally compiled according to the build - // tags, as the original types may not be defined under all build - // configurations. - - b.reset() - b.emit("// Automatically generated marshal tests. See tools/go_marshal.\n\n") - - // Emit build constraints. - bcexpr, err := constraintutil.CombineFromFiles(g.inputs) - if err != nil { - return err - } - b.emit(constraintutil.Lines(bcexpr)) - - b.emit("package %s\n\n", g.pkg) - if err := b.write(g.outputTest); err != nil { - return err - } - - // Collect and write test import statements. - imports := newImportTable() - for _, t := range ts { - imports.merge(t.imports) - } - - if err := imports.write(g.outputTest); err != nil { - return err - } - - // Write test functions. - for _, t := range ts { - if err := t.write(g.outputTest); err != nil { - return err - } - } - return nil -} diff --git a/tools/go_marshal/gomarshal/generator_interfaces.go b/tools/go_marshal/gomarshal/generator_interfaces.go deleted file mode 100644 index 3e643e77f..000000000 --- a/tools/go_marshal/gomarshal/generator_interfaces.go +++ /dev/null @@ -1,276 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package gomarshal - -import ( - "fmt" - "go/ast" - "go/token" - "strings" -) - -// interfaceGenerator generates marshalling interfaces for a single type. -// -// getState is not thread-safe. -type interfaceGenerator struct { - sourceBuffer - - // The type we're serializing. - t *ast.TypeSpec - - // Receiver argument for generated methods. - r string - - // FileSet containing the tokens for the type we're processing. - f *token.FileSet - - // is records external packages referenced by the generated implementation. - is map[string]struct{} - - // ms records Marshallable types referenced by the generated implementation - // of t's interfaces. - ms map[string]struct{} - - // as records fields in t that are potentially not packed. The key is the - // accessor for the field. - as map[string]struct{} -} - -// typeName returns the name of the type this g represents. -func (g *interfaceGenerator) typeName() string { - return g.t.Name.Name -} - -// newinterfaceGenerator creates a new interface generator. -func newInterfaceGenerator(t *ast.TypeSpec, r string, fset *token.FileSet) *interfaceGenerator { - g := &interfaceGenerator{ - t: t, - r: r, - f: fset, - is: make(map[string]struct{}), - ms: make(map[string]struct{}), - as: make(map[string]struct{}), - } - g.recordUsedMarshallable(g.typeName()) - return g -} - -func (g *interfaceGenerator) recordUsedMarshallable(m string) { - g.ms[m] = struct{}{} - -} - -func (g *interfaceGenerator) recordUsedImport(i string) { - g.is[i] = struct{}{} -} - -func (g *interfaceGenerator) recordPotentiallyNonPackedField(fieldName string) { - g.as[fieldName] = struct{}{} -} - -// abortAt aborts the go_marshal tool with the given error message, with a -// reference position to the input source. Same as abortAt, but uses g to -// resolve p to position. -func (g *interfaceGenerator) abortAt(p token.Pos, msg string) { - abortAt(g.f.Position(p), msg) -} - -// scalarSize returns the size of type identified by t. If t isn't a primitive -// type, the size isn't known at code generation time, and must be resolved via -// the marshal.Marshallable interface. -func (g *interfaceGenerator) scalarSize(t *ast.Ident) (size int, unknownSize bool) { - switch t.Name { - case "int8", "uint8", "byte": - return 1, false - case "int16", "uint16": - return 2, false - case "int32", "uint32": - return 4, false - case "int64", "uint64": - return 8, false - default: - return 0, true - } -} - -func (g *interfaceGenerator) shift(bufVar string, n int) { - g.emit("%s = %s[%d:]\n", bufVar, bufVar, n) -} - -func (g *interfaceGenerator) shiftDynamic(bufVar, name string) { - g.emit("%s = %s[%s.SizeBytes():]\n", bufVar, bufVar, name) -} - -// marshalScalar writes a single scalar to a byte slice. -func (g *interfaceGenerator) marshalScalar(accessor, typ, bufVar string) { - switch typ { - case "int8", "uint8", "byte": - g.emit("%s[0] = byte(%s)\n", bufVar, accessor) - g.shift(bufVar, 1) - case "int16", "uint16": - g.recordUsedImport("hostarch") - g.emit("hostarch.ByteOrder.PutUint16(%s[:2], uint16(%s))\n", bufVar, accessor) - g.shift(bufVar, 2) - case "int32", "uint32": - g.recordUsedImport("hostarch") - g.emit("hostarch.ByteOrder.PutUint32(%s[:4], uint32(%s))\n", bufVar, accessor) - g.shift(bufVar, 4) - case "int64", "uint64": - g.recordUsedImport("hostarch") - g.emit("hostarch.ByteOrder.PutUint64(%s[:8], uint64(%s))\n", bufVar, accessor) - g.shift(bufVar, 8) - default: - g.emit("%s.MarshalBytes(%s[:%s.SizeBytes()])\n", accessor, bufVar, accessor) - g.shiftDynamic(bufVar, accessor) - } -} - -// unmarshalScalar reads a single scalar from a byte slice. -func (g *interfaceGenerator) unmarshalScalar(accessor, typ, bufVar string) { - switch typ { - case "byte": - g.emit("%s = %s[0]\n", accessor, bufVar) - g.shift(bufVar, 1) - case "int8", "uint8": - g.emit("%s = %s(%s[0])\n", accessor, typ, bufVar) - g.shift(bufVar, 1) - case "int16", "uint16": - g.recordUsedImport("hostarch") - g.emit("%s = %s(hostarch.ByteOrder.Uint16(%s[:2]))\n", accessor, typ, bufVar) - g.shift(bufVar, 2) - case "int32", "uint32": - g.recordUsedImport("hostarch") - g.emit("%s = %s(hostarch.ByteOrder.Uint32(%s[:4]))\n", accessor, typ, bufVar) - g.shift(bufVar, 4) - case "int64", "uint64": - g.recordUsedImport("hostarch") - g.emit("%s = %s(hostarch.ByteOrder.Uint64(%s[:8]))\n", accessor, typ, bufVar) - g.shift(bufVar, 8) - default: - g.emit("%s.UnmarshalBytes(%s[:%s.SizeBytes()])\n", accessor, bufVar, accessor) - g.shiftDynamic(bufVar, accessor) - g.recordPotentiallyNonPackedField(accessor) - } -} - -// emitCastToByteSlice unsafely casts an arbitrary type's underlying memory to a -// byte slice, bypassing escape analysis. The caller is responsible for ensuring -// srcPtr lives until they're done with dstVar, the runtime does not consider -// dstVar dependent on srcPtr due to the escape analysis bypass. -// -// srcPtr must be a pointer. -// -// This function uses internally uses the identifier "hdr", and cannot be used -// in a context where it is already bound. -func (g *interfaceGenerator) emitCastToByteSlice(srcPtr, dstVar, lenExpr string) { - g.recordUsedImport("gohacks") - g.emit("// Construct a slice backed by dst's underlying memory.\n") - g.emit("var %s []byte\n", dstVar) - g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&%s))\n", dstVar) - g.emit("hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(%s)))\n", srcPtr) - g.emit("hdr.Len = %s\n", lenExpr) - g.emit("hdr.Cap = %s\n\n", lenExpr) -} - -// emitCastToByteSlice unsafely casts a slice with elements of an abitrary type -// to a byte slice. As part of the cast, the byte slice is made to look -// independent of the src slice by bypassing escape analysis. This means the -// byte slice can be used without causing the source to escape. The caller is -// responsible for ensuring srcPtr lives until they're done with dstVar, as the -// runtime no longer considers dstVar dependent on srcPtr and is free to GC it. -// -// srcPtr must be a pointer. -// -// This function uses internally uses the identifiers "ptr", "val" and "hdr", -// and cannot be used in a context where these identifiers are already bound. -func (g *interfaceGenerator) emitCastSliceToByteSlice(srcPtr, dstVar, lenExpr string) { - g.emitNoEscapeSliceDataPointer(srcPtr, "val") - - g.emit("// Construct a slice backed by dst's underlying memory.\n") - g.emit("var %s []byte\n", dstVar) - g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&%s))\n", dstVar) - g.emit("hdr.Data = uintptr(val)\n") - g.emit("hdr.Len = %s\n", lenExpr) - g.emit("hdr.Cap = %s\n\n", lenExpr) -} - -// emitNoEscapeSliceDataPointer unsafely casts a slice's data pointer to an -// unsafe.Pointer, bypassing escape analysis. The caller is responsible for -// ensuring srcPtr lives until they're done with dstVar, as the runtime no -// longer considers dstVar dependent on srcPtr and is free to GC it. -// -// srcPtr must be a pointer. -// -// This function uses internally uses the identifier "ptr" cannot be used in a -// context where this identifier is already bound. -func (g *interfaceGenerator) emitNoEscapeSliceDataPointer(srcPtr, dstVar string) { - g.recordUsedImport("gohacks") - g.emit("ptr := unsafe.Pointer(%s)\n", srcPtr) - g.emit("%s := gohacks.Noescape(unsafe.Pointer((*reflect.SliceHeader)(ptr).Data))\n\n", dstVar) -} - -func (g *interfaceGenerator) emitKeepAlive(ptrVar string) { - g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", ptrVar) - g.emit("// must live until the use above.\n") - g.emit("runtime.KeepAlive(%s) // escapes: replaced by intrinsic.\n", ptrVar) -} - -func (g *interfaceGenerator) expandBinaryExpr(b *strings.Builder, e *ast.BinaryExpr) { - switch x := e.X.(type) { - case *ast.BinaryExpr: - // Recursively expand sub-expression. - g.expandBinaryExpr(b, x) - case *ast.Ident: - fmt.Fprintf(b, "%s", x.Name) - case *ast.BasicLit: - fmt.Fprintf(b, "%s", x.Value) - default: - g.abortAt(e.Pos(), "Cannot convert binary expression to output code. Go-marshal currently only handles simple expressions of literals, constants and basic identifiers") - } - - fmt.Fprintf(b, "%s", e.Op) - - switch y := e.Y.(type) { - case *ast.BinaryExpr: - // Recursively expand sub-expression. - g.expandBinaryExpr(b, y) - case *ast.Ident: - fmt.Fprintf(b, "%s", y.Name) - case *ast.BasicLit: - fmt.Fprintf(b, "%s", y.Value) - default: - g.abortAt(e.Pos(), "Cannot convert binary expression to output code. Go-marshal currently only handles simple expressions of literals, constants and basic identifiers") - } -} - -// arrayLenExpr returns a string containing a valid golang expression -// representing the length of array a. The returned expression should be treated -// as a single value, and will be already parenthesized as required. -func (g *interfaceGenerator) arrayLenExpr(a *ast.ArrayType) string { - var b strings.Builder - - switch l := a.Len.(type) { - case *ast.Ident: - fmt.Fprintf(&b, "%s", l.Name) - case *ast.BasicLit: - fmt.Fprintf(&b, "%s", l.Value) - case *ast.BinaryExpr: - g.expandBinaryExpr(&b, l) - return fmt.Sprintf("(%s)", b.String()) - default: - g.abortAt(l.Pos(), "Cannot convert this array len expression to output code. Go-marshal currently only handles simple expressions of literals, constants and basic identifiers") - } - return b.String() -} diff --git a/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go b/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go deleted file mode 100644 index bd7741ae5..000000000 --- a/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go +++ /dev/null @@ -1,146 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// This file contains the bits of the code generator specific to marshalling -// newtypes on arrays. - -package gomarshal - -import ( - "fmt" - "go/ast" -) - -func (g *interfaceGenerator) validateArrayNewtype(n *ast.Ident, a *ast.ArrayType) { - if a.Len == nil { - g.abortAt(a.Pos(), fmt.Sprintf("Dynamically sized slice '%s' cannot be marshalled, arrays must be statically sized", n.Name)) - } - - if _, ok := a.Elt.(*ast.Ident); !ok { - g.abortAt(a.Elt.Pos(), fmt.Sprintf("Marshalling not supported for arrays with %s elements, array elements must be primitive types", kindString(a.Elt))) - } -} - -func (g *interfaceGenerator) emitMarshallableForArrayNewtype(n *ast.Ident, a *ast.ArrayType, elt *ast.Ident) { - g.recordUsedImport("gohacks") - g.recordUsedImport("hostarch") - g.recordUsedImport("io") - g.recordUsedImport("marshal") - g.recordUsedImport("reflect") - g.recordUsedImport("runtime") - g.recordUsedImport("unsafe") - - lenExpr := g.arrayLenExpr(a) - - g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n") - g.emit("//go:nosplit\n") - g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName()) - g.inIndent(func() { - if size, dynamic := g.scalarSize(elt); !dynamic { - g.emit("return %d * %s\n", size, lenExpr) - } else { - g.emit("return (*%s)(nil).SizeBytes() * %s\n", n.Name, lenExpr) - } - }) - g.emit("}\n\n") - - g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n") - g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName()) - g.inIndent(func() { - g.emit("for idx := 0; idx < %s; idx++ {\n", lenExpr) - g.inIndent(func() { - g.marshalScalar(fmt.Sprintf("%s[idx]", g.r), elt.Name, "dst") - }) - g.emit("}\n") - }) - g.emit("}\n\n") - - g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n") - g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName()) - g.inIndent(func() { - g.emit("for idx := 0; idx < %s; idx++ {\n", lenExpr) - g.inIndent(func() { - g.unmarshalScalar(fmt.Sprintf("%s[idx]", g.r), elt.Name, "src") - }) - g.emit("}\n") - }) - g.emit("}\n\n") - - g.emit("// Packed implements marshal.Marshallable.Packed.\n") - g.emit("//go:nosplit\n") - g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName()) - g.inIndent(func() { - g.emit("// Array newtypes are always packed.\n") - g.emit("return true\n") - }) - g.emit("}\n\n") - - g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n") - g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName()) - g.inIndent(func() { - g.emit("gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(&%s[0]), uintptr(%s.SizeBytes()))\n", g.r, g.r) - }) - g.emit("}\n\n") - - g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n") - g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName()) - g.inIndent(func() { - g.emit("gohacks.Memmove(unsafe.Pointer(%s), unsafe.Pointer(&src[0]), uintptr(%s.SizeBytes()))\n", g.r, g.r) - }) - g.emit("}\n\n") - - g.emit("// CopyOutN implements marshal.Marshallable.CopyOutN.\n") - g.emit("//go:nosplit\n") - g.emit("func (%s *%s) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) {\n", g.r, g.typeName()) - g.inIndent(func() { - g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r)) - - g.emit("length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay.\n") - g.emitKeepAlive(g.r) - g.emit("return length, err\n") - }) - g.emit("}\n\n") - - g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n") - g.emit("//go:nosplit\n") - g.emit("func (%s *%s) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) {\n", g.r, g.typeName()) - g.inIndent(func() { - g.emit("return %s.CopyOutN(cc, addr, %s.SizeBytes())\n", g.r, g.r) - }) - g.emit("}\n\n") - - g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n") - g.emit("//go:nosplit\n") - g.emit("func (%s *%s) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) {\n", g.r, g.typeName()) - g.inIndent(func() { - g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r)) - - g.emit("length, err := cc.CopyInBytes(addr, buf) // escapes: okay.\n") - g.emitKeepAlive(g.r) - g.emit("return length, err\n") - }) - g.emit("}\n\n") - - g.emit("// WriteTo implements io.WriterTo.WriteTo.\n") - g.emit("func (%s *%s) WriteTo(w io.Writer) (int64, error) {\n", g.r, g.typeName()) - g.inIndent(func() { - g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r)) - - g.emit("length, err := w.Write(buf)\n") - g.emitKeepAlive(g.r) - g.emit("return int64(length), err\n") - - }) - g.emit("}\n\n") -} diff --git a/tools/go_marshal/gomarshal/generator_interfaces_dynamic.go b/tools/go_marshal/gomarshal/generator_interfaces_dynamic.go deleted file mode 100644 index 345020ddc..000000000 --- a/tools/go_marshal/gomarshal/generator_interfaces_dynamic.go +++ /dev/null @@ -1,96 +0,0 @@ -// Copyright 2021 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package gomarshal - -func (g *interfaceGenerator) emitMarshallableForDynamicType() { - // The user writes their own MarshalBytes, UnmarshalBytes and SizeBytes for - // dynamic types. Generate the rest using these definitions. - - g.emit("// Packed implements marshal.Marshallable.Packed.\n") - g.emit("//go:nosplit\n") - g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName()) - g.inIndent(func() { - g.emit("// Type %s is dynamic so it might have slice/string headers. Hence, it is not packed.\n", g.typeName()) - g.emit("return false\n") - }) - g.emit("}\n\n") - - g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n") - g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName()) - g.inIndent(func() { - g.emit("// Type %s doesn't have a packed layout in memory, fallback to MarshalBytes.\n", g.typeName()) - g.emit("%s.MarshalBytes(dst)\n", g.r) - }) - g.emit("}\n\n") - - g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n") - g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName()) - g.inIndent(func() { - g.emit("// Type %s doesn't have a packed layout in memory, fallback to UnmarshalBytes.\n", g.typeName()) - g.emit("%s.UnmarshalBytes(src)\n", g.r) - }) - g.emit("}\n\n") - - g.emit("// CopyOutN implements marshal.Marshallable.CopyOutN.\n") - g.emit("//go:nosplit\n") - g.recordUsedImport("marshal") - g.recordUsedImport("hostarch") - g.emit("func (%s *%s) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) {\n", g.r, g.typeName()) - g.inIndent(func() { - g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName()) - g.emit("buf := cc.CopyScratchBuffer(%s.SizeBytes()) // escapes: okay.\n", g.r) - g.emit("%s.MarshalBytes(buf) // escapes: fallback.\n", g.r) - g.emit("return cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay.\n") - }) - g.emit("}\n\n") - - g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n") - g.emit("//go:nosplit\n") - g.recordUsedImport("marshal") - g.recordUsedImport("hostarch") - g.emit("func (%s *%s) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) {\n", g.r, g.typeName()) - g.inIndent(func() { - g.emit("return %s.CopyOutN(cc, addr, %s.SizeBytes())\n", g.r, g.r) - }) - g.emit("}\n\n") - - g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n") - g.emit("//go:nosplit\n") - g.recordUsedImport("marshal") - g.recordUsedImport("hostarch") - g.emit("func (%s *%s) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) {\n", g.r, g.typeName()) - g.inIndent(func() { - g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName()) - g.emit("buf := cc.CopyScratchBuffer(%s.SizeBytes()) // escapes: okay.\n", g.r) - g.emit("length, err := cc.CopyInBytes(addr, buf) // escapes: okay.\n") - g.emit("// Unmarshal unconditionally. If we had a short copy-in, this results in a\n") - g.emit("// partially unmarshalled struct.\n") - g.emit("%s.UnmarshalBytes(buf) // escapes: fallback.\n", g.r) - g.emit("return length, err\n") - }) - g.emit("}\n\n") - - g.emit("// WriteTo implements io.WriterTo.WriteTo.\n") - g.recordUsedImport("io") - g.emit("func (%s *%s) WriteTo(writer io.Writer) (int64, error) {\n", g.r, g.typeName()) - g.inIndent(func() { - g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName()) - g.emit("buf := make([]byte, %s.SizeBytes())\n", g.r) - g.emit("%s.MarshalBytes(buf)\n", g.r) - g.emit("length, err := writer.Write(buf)\n") - g.emit("return int64(length), err\n") - }) - g.emit("}\n\n") -} diff --git a/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go b/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go deleted file mode 100644 index ba4b7324e..000000000 --- a/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go +++ /dev/null @@ -1,285 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// This file contains the bits of the code generator specific to marshalling -// newtypes on primitives. - -package gomarshal - -import ( - "fmt" - "go/ast" -) - -// marshalPrimitiveScalar writes a single primitive variable to a byte -// slice. -func (g *interfaceGenerator) marshalPrimitiveScalar(accessor, typ, bufVar string) { - switch typ { - case "int8", "uint8", "byte": - g.emit("%s[0] = byte(*%s)\n", bufVar, accessor) - case "int16", "uint16": - g.recordUsedImport("hostarch") - g.emit("hostarch.ByteOrder.PutUint16(%s[:2], uint16(*%s))\n", bufVar, accessor) - case "int32", "uint32": - g.recordUsedImport("hostarch") - g.emit("hostarch.ByteOrder.PutUint32(%s[:4], uint32(*%s))\n", bufVar, accessor) - case "int64", "uint64": - g.recordUsedImport("hostarch") - g.emit("hostarch.ByteOrder.PutUint64(%s[:8], uint64(*%s))\n", bufVar, accessor) - default: - g.emit("// Explicilty cast to the underlying type before dispatching to\n") - g.emit("// MarshalBytes, so we don't recursively call %s.MarshalBytes\n", accessor) - g.emit("inner := (*%s)(%s)\n", typ, accessor) - g.emit("inner.MarshalBytes(%s[:%s.SizeBytes()])\n", bufVar, accessor) - } -} - -// unmarshalPrimitiveScalar read a single primitive variable from a byte slice. -func (g *interfaceGenerator) unmarshalPrimitiveScalar(accessor, typ, bufVar, typeCast string) { - switch typ { - case "byte": - g.emit("*%s = %s(%s[0])\n", accessor, typeCast, bufVar) - case "int8", "uint8": - g.emit("*%s = %s(%s(%s[0]))\n", accessor, typeCast, typ, bufVar) - case "int16", "uint16": - g.recordUsedImport("hostarch") - g.emit("*%s = %s(%s(hostarch.ByteOrder.Uint16(%s[:2])))\n", accessor, typeCast, typ, bufVar) - case "int32", "uint32": - g.recordUsedImport("hostarch") - g.emit("*%s = %s(%s(hostarch.ByteOrder.Uint32(%s[:4])))\n", accessor, typeCast, typ, bufVar) - case "int64", "uint64": - g.recordUsedImport("hostarch") - g.emit("*%s = %s(%s(hostarch.ByteOrder.Uint64(%s[:8])))\n", accessor, typeCast, typ, bufVar) - default: - g.emit("// Explicilty cast to the underlying type before dispatching to\n") - g.emit("// UnmarshalBytes, so we don't recursively call %s.UnmarshalBytes\n", accessor) - g.emit("inner := (*%s)(%s)\n", typ, accessor) - g.emit("inner.UnmarshalBytes(%s[:%s.SizeBytes()])\n", bufVar, accessor) - } -} - -func (g *interfaceGenerator) validatePrimitiveNewtype(t *ast.Ident) { - switch t.Name { - case "int8", "uint8", "byte", "int16", "uint16", "int32", "uint32", "int64", "uint64": - // These are the only primitive types we're allow. Below, we provide - // suggestions for some disallowed types and reject them, then attempt - // to marshal any remaining types by invoking the marshal.Marshallable - // interface on them. If these types don't actually implement - // marshal.Marshallable, compilation of the generated code will fail - // with an appropriate error message. - return - case "int": - g.abortAt(t.Pos(), "Type 'int' has ambiguous width, use int32 or int64") - case "uint": - g.abortAt(t.Pos(), "Type 'uint' has ambiguous width, use uint32 or uint64") - case "string": - g.abortAt(t.Pos(), "Type 'string' is dynamically-sized and cannot be marshalled, use a fixed size byte array '[...]byte' instead") - default: - debugfAt(g.f.Position(t.Pos()), fmt.Sprintf("Found derived type '%s', will attempt dispatch via marshal.Marshallable.\n", t.Name)) - } -} - -// emitMarshallableForPrimitiveNewtype outputs code to implement the -// marshal.Marshallable interface for a newtype on a primitive. Primitive -// newtypes are always packed, so we can omit the various fallbacks required for -// non-packed structs. -func (g *interfaceGenerator) emitMarshallableForPrimitiveNewtype(nt *ast.Ident) { - g.recordUsedImport("gohacks") - g.recordUsedImport("hostarch") - g.recordUsedImport("io") - g.recordUsedImport("marshal") - g.recordUsedImport("reflect") - g.recordUsedImport("runtime") - g.recordUsedImport("unsafe") - - g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n") - g.emit("//go:nosplit\n") - g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName()) - g.inIndent(func() { - if size, dynamic := g.scalarSize(nt); !dynamic { - g.emit("return %d\n", size) - } else { - g.emit("return (*%s)(nil).SizeBytes()\n", nt.Name) - } - }) - g.emit("}\n\n") - - g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n") - g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName()) - g.inIndent(func() { - g.marshalPrimitiveScalar(g.r, nt.Name, "dst") - }) - g.emit("}\n\n") - - g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n") - g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName()) - g.inIndent(func() { - g.unmarshalPrimitiveScalar(g.r, nt.Name, "src", g.typeName()) - }) - g.emit("}\n\n") - - g.emit("// Packed implements marshal.Marshallable.Packed.\n") - g.emit("//go:nosplit\n") - g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName()) - g.inIndent(func() { - g.emit("// Scalar newtypes are always packed.\n") - g.emit("return true\n") - }) - g.emit("}\n\n") - - g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n") - g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName()) - g.inIndent(func() { - g.emit("gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(%s), uintptr(%s.SizeBytes()))\n", g.r, g.r) - }) - g.emit("}\n\n") - - g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n") - g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName()) - g.inIndent(func() { - g.emit("gohacks.Memmove(unsafe.Pointer(%s), unsafe.Pointer(&src[0]), uintptr(%s.SizeBytes()))\n", g.r, g.r) - }) - g.emit("}\n\n") - - g.emit("// CopyOutN implements marshal.Marshallable.CopyOutN.\n") - g.emit("//go:nosplit\n") - g.emit("func (%s *%s) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) {\n", g.r, g.typeName()) - g.inIndent(func() { - g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r)) - - g.emit("length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay.\n") - g.emitKeepAlive(g.r) - g.emit("return length, err\n") - }) - g.emit("}\n\n") - - g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n") - g.emit("//go:nosplit\n") - g.emit("func (%s *%s) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) {\n", g.r, g.typeName()) - g.inIndent(func() { - g.emit("return %s.CopyOutN(cc, addr, %s.SizeBytes())\n", g.r, g.r) - }) - g.emit("}\n\n") - - g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n") - g.emit("//go:nosplit\n") - g.emit("func (%s *%s) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) {\n", g.r, g.typeName()) - g.inIndent(func() { - g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r)) - - g.emit("length, err := cc.CopyInBytes(addr, buf) // escapes: okay.\n") - g.emitKeepAlive(g.r) - g.emit("return length, err\n") - }) - g.emit("}\n\n") - - g.emit("// WriteTo implements io.WriterTo.WriteTo.\n") - g.emit("func (%s *%s) WriteTo(w io.Writer) (int64, error) {\n", g.r, g.typeName()) - g.inIndent(func() { - g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r)) - - g.emit("length, err := w.Write(buf)\n") - g.emitKeepAlive(g.r) - g.emit("return int64(length), err\n") - - }) - g.emit("}\n\n") -} - -func (g *interfaceGenerator) emitMarshallableSliceForPrimitiveNewtype(nt *ast.Ident, slice *sliceAPI) { - g.recordUsedImport("marshal") - g.recordUsedImport("hostarch") - g.recordUsedImport("reflect") - g.recordUsedImport("runtime") - g.recordUsedImport("unsafe") - - eltType := g.typeName() - if slice.inner { - eltType = nt.Name - } - - g.emit("// Copy%sIn copies in a slice of %s objects from the task's memory.\n", slice.ident, eltType) - g.emit("//go:nosplit\n") - g.emit("func Copy%sIn(cc marshal.CopyContext, addr hostarch.Addr, dst []%s) (int, error) {\n", slice.ident, eltType) - g.inIndent(func() { - g.emit("count := len(dst)\n") - g.emit("if count == 0 {\n") - g.inIndent(func() { - g.emit("return 0, nil\n") - }) - g.emit("}\n") - g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName()) - - g.emitCastSliceToByteSlice("&dst", "buf", "size * count") - - g.emit("length, err := cc.CopyInBytes(addr, buf) // escapes: okay.\n") - g.emitKeepAlive("dst") - g.emit("return length, err\n") - }) - g.emit("}\n\n") - - g.emit("// Copy%sOut copies a slice of %s objects to the task's memory.\n", slice.ident, eltType) - g.emit("//go:nosplit\n") - g.emit("func Copy%sOut(cc marshal.CopyContext, addr hostarch.Addr, src []%s) (int, error) {\n", slice.ident, eltType) - g.inIndent(func() { - g.emit("count := len(src)\n") - g.emit("if count == 0 {\n") - g.inIndent(func() { - g.emit("return 0, nil\n") - }) - g.emit("}\n") - g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName()) - - g.emitCastSliceToByteSlice("&src", "buf", "size * count") - - g.emit("length, err := cc.CopyOutBytes(addr, buf) // escapes: okay.\n") - g.emitKeepAlive("src") - g.emit("return length, err\n") - }) - g.emit("}\n\n") - - g.emit("// MarshalUnsafe%s is like %s.MarshalUnsafe, but for a []%s.\n", slice.ident, g.typeName(), g.typeName()) - g.emit("func MarshalUnsafe%s(src []%s, dst []byte) (int, error) {\n", slice.ident, g.typeName()) - g.inIndent(func() { - g.emit("count := len(src)\n") - g.emit("if count == 0 {\n") - g.inIndent(func() { - g.emit("return 0, nil\n") - }) - g.emit("}\n") - g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName()) - - g.emit("dst = dst[:size*count]\n") - g.emit("gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(&src[0]), uintptr(len(dst)))\n") - g.emit("return size*count, nil\n") - }) - g.emit("}\n\n") - - g.emit("// UnmarshalUnsafe%s is like %s.UnmarshalUnsafe, but for a []%s.\n", slice.ident, g.typeName(), g.typeName()) - g.emit("func UnmarshalUnsafe%s(dst []%s, src []byte) (int, error) {\n", slice.ident, g.typeName()) - g.inIndent(func() { - g.emit("count := len(dst)\n") - g.emit("if count == 0 {\n") - g.inIndent(func() { - g.emit("return 0, nil\n") - }) - g.emit("}\n") - g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName()) - - g.emit("src = src[:(size*count)]\n") - g.emit("gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(&src[0]), uintptr(len(src)))\n") - g.emit("return size*count, nil\n") - }) - g.emit("}\n\n") -} diff --git a/tools/go_marshal/gomarshal/generator_interfaces_struct.go b/tools/go_marshal/gomarshal/generator_interfaces_struct.go deleted file mode 100644 index 4c47218f1..000000000 --- a/tools/go_marshal/gomarshal/generator_interfaces_struct.go +++ /dev/null @@ -1,616 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// This file contains the bits of the code generator specific to marshalling -// structs. - -package gomarshal - -import ( - "fmt" - "go/ast" - "sort" - "strings" -) - -func (g *interfaceGenerator) fieldAccessor(n *ast.Ident) string { - return fmt.Sprintf("%s.%s", g.r, n.Name) -} - -// areFieldsPackedExpression returns a go expression checking whether g.t's fields are -// packed. Returns "", false if g.t has no fields that may be potentially -// packed, otherwise returns <clause>, true, where <clause> is an expression -// like "t.a.Packed() && t.b.Packed() && t.c.Packed()". -func (g *interfaceGenerator) areFieldsPackedExpression() (string, bool) { - if len(g.as) == 0 { - return "", false - } - - cs := make([]string, 0, len(g.as)) - for accessor := range g.as { - cs = append(cs, fmt.Sprintf("%s.Packed()", accessor)) - } - // Sort expressions for determinstic build outputs. - sort.Strings(cs) - return strings.Join(cs, " && "), true -} - -// validateStruct ensures the type we're working with can be marshalled. These -// checks are done ahead of time and in one place so we can make assumptions -// later. -func (g *interfaceGenerator) validateStruct(ts *ast.TypeSpec, st *ast.StructType) { - forEachStructField(st, func(f *ast.Field) { - fieldDispatcher{ - primitive: func(_, t *ast.Ident) { - g.validatePrimitiveNewtype(t) - }, - selector: func(_, _, _ *ast.Ident) { - // No validation to perform on selector fields. However this - // callback must still be provided. - }, - array: func(n *ast.Ident, a *ast.ArrayType, _ *ast.Ident) { - g.validateArrayNewtype(n, a) - }, - unhandled: func(_ *ast.Ident) { - g.abortAt(f.Pos(), fmt.Sprintf("Marshalling not supported for %s fields", kindString(f.Type))) - }, - }.dispatch(f) - }) -} - -func (g *interfaceGenerator) isStructPacked(st *ast.StructType) bool { - packed := true - forEachStructField(st, func(f *ast.Field) { - if f.Tag != nil { - if f.Tag.Value == "`marshal:\"unaligned\"`" { - if packed { - debugfAt(g.f.Position(g.t.Pos()), - fmt.Sprintf("Marking type '%s' as not packed due to tag `marshal:\"unaligned\"`.\n", g.t.Name)) - packed = false - } - } - } - }) - return packed -} - -func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) { - thisPacked := g.isStructPacked(st) - - g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n") - g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName()) - g.inIndent(func() { - primitiveSize := 0 - var dynamicSizeTerms []string - - forEachStructField(st, fieldDispatcher{ - primitive: func(_, t *ast.Ident) { - if size, dynamic := g.scalarSize(t); !dynamic { - primitiveSize += size - } else { - g.recordUsedMarshallable(t.Name) - dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", t.Name)) - } - }, - selector: func(_, tX, tSel *ast.Ident) { - tName := fmt.Sprintf("%s.%s", tX.Name, tSel.Name) - g.recordUsedImport(tX.Name) - g.recordUsedMarshallable(tName) - dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", tName)) - }, - array: func(_ *ast.Ident, a *ast.ArrayType, t *ast.Ident) { - lenExpr := g.arrayLenExpr(a) - if size, dynamic := g.scalarSize(t); !dynamic { - dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("%d*%s", size, lenExpr)) - } else { - g.recordUsedMarshallable(t.Name) - dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()*%s", t.Name, lenExpr)) - } - }, - }.dispatch) - g.emit("return %d", primitiveSize) - if len(dynamicSizeTerms) > 0 { - g.incIndent() - } - { - for _, d := range dynamicSizeTerms { - g.emitNoIndent(" +\n") - g.emit(d) - } - } - if len(dynamicSizeTerms) > 0 { - g.decIndent() - } - }) - g.emit("\n}\n\n") - - g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n") - g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName()) - g.inIndent(func() { - forEachStructField(st, fieldDispatcher{ - primitive: func(n, t *ast.Ident) { - if n.Name == "_" { - g.emit("// Padding: dst[:sizeof(%s)] ~= %s(0)\n", t.Name, t.Name) - if len, dynamic := g.scalarSize(t); !dynamic { - g.shift("dst", len) - } else { - // We can't use shiftDynamic here because we don't have - // an instance of the dynamic type we can reference here - // (since the version in this struct is anonymous). Use - // a typed nil pointer to call SizeBytes() instead. - g.emit("dst = dst[(*%s)(nil).SizeBytes():]\n", t.Name) - } - return - } - g.marshalScalar(g.fieldAccessor(n), t.Name, "dst") - }, - selector: func(n, tX, tSel *ast.Ident) { - if n.Name == "_" { - g.emit("// Padding: dst[:sizeof(%s)] ~= %s(0)\n", tX.Name, tSel.Name) - g.emit("dst = dst[(*%s.%s)(nil).SizeBytes():]\n", tX.Name, tSel.Name) - return - } - g.marshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "dst") - }, - array: func(n *ast.Ident, a *ast.ArrayType, t *ast.Ident) { - lenExpr := g.arrayLenExpr(a) - if n.Name == "_" { - g.emit("// Padding: dst[:sizeof(%s)*%s] ~= [%s]%s{0}\n", t.Name, lenExpr, lenExpr, t.Name) - if size, dynamic := g.scalarSize(t); !dynamic { - g.emit("dst = dst[%d*(%s):]\n", size, lenExpr) - } else { - // We can't use shiftDynamic here because we don't have - // an instance of the dynamic type we can reference here - // (since the version in this struct is anonymous). Use - // a typed nil pointer to call SizeBytes() instead. - g.emit("dst = dst[(*%s)(nil).SizeBytes()*(%s):]\n", t.Name, lenExpr) - } - return - } - - g.emit("for idx := 0; idx < %s; idx++ {\n", lenExpr) - g.inIndent(func() { - g.marshalScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "dst") - }) - g.emit("}\n") - }, - }.dispatch) - }) - g.emit("}\n\n") - - g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n") - g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName()) - g.inIndent(func() { - forEachStructField(st, fieldDispatcher{ - primitive: func(n, t *ast.Ident) { - if n.Name == "_" { - g.emit("// Padding: var _ %s ~= src[:sizeof(%s)]\n", t.Name, t.Name) - if len, dynamic := g.scalarSize(t); !dynamic { - g.shift("src", len) - } else { - // We don't have an instance of the dynamic type we can - // reference here (since the version in this struct is - // anonymous). Use a typed nil pointer to call - // SizeBytes() instead. - g.shiftDynamic("src", fmt.Sprintf("(*%s)(nil)", t.Name)) - g.recordPotentiallyNonPackedField(fmt.Sprintf("(*%s)(nil)", t.Name)) - } - return - } - g.unmarshalScalar(g.fieldAccessor(n), t.Name, "src") - }, - selector: func(n, tX, tSel *ast.Ident) { - if n.Name == "_" { - g.emit("// Padding: %s ~= src[:sizeof(%s.%s)]\n", g.fieldAccessor(n), tX.Name, tSel.Name) - g.emit("src = src[(*%s.%s)(nil).SizeBytes():]\n", tX.Name, tSel.Name) - g.recordPotentiallyNonPackedField(fmt.Sprintf("(*%s.%s)(nil)", tX.Name, tSel.Name)) - return - } - g.unmarshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "src") - }, - array: func(n *ast.Ident, a *ast.ArrayType, t *ast.Ident) { - lenExpr := g.arrayLenExpr(a) - if n.Name == "_" { - g.emit("// Padding: ~ copy([%s]%s(%s), src[:sizeof(%s)*%s])\n", lenExpr, t.Name, g.fieldAccessor(n), t.Name, lenExpr) - if size, dynamic := g.scalarSize(t); !dynamic { - g.emit("src = src[%d*(%s):]\n", size, lenExpr) - } else { - // We can't use shiftDynamic here because we don't have - // an instance of the dynamic type we can referece here - // (since the version in this struct is anonymous). Use - // a typed nil pointer to call SizeBytes() instead. - g.emit("src = src[(*%s)(nil).SizeBytes()*(%s):]\n", t.Name, lenExpr) - } - return - } - - g.emit("for idx := 0; idx < %s; idx++ {\n", lenExpr) - g.inIndent(func() { - g.unmarshalScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "src") - }) - g.emit("}\n") - }, - }.dispatch) - }) - g.emit("}\n\n") - - g.emit("// Packed implements marshal.Marshallable.Packed.\n") - g.emit("//go:nosplit\n") - g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName()) - g.inIndent(func() { - expr, fieldsMaybePacked := g.areFieldsPackedExpression() - switch { - case !thisPacked: - g.emit("return false\n") - case fieldsMaybePacked: - g.emit("return %s\n", expr) - default: - g.emit("return true\n") - - } - }) - g.emit("}\n\n") - - g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n") - g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName()) - g.inIndent(func() { - fallback := func() { - g.emit("// Type %s doesn't have a packed layout in memory, fallback to MarshalBytes.\n", g.typeName()) - g.emit("%s.MarshalBytes(dst)\n", g.r) - } - if thisPacked { - g.recordUsedImport("gohacks") - g.recordUsedImport("unsafe") - if cond, ok := g.areFieldsPackedExpression(); ok { - g.emit("if %s {\n", cond) - g.inIndent(func() { - g.emit("gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(%s), uintptr(%s.SizeBytes()))\n", g.r, g.r) - }) - g.emit("} else {\n") - g.inIndent(fallback) - g.emit("}\n") - } else { - g.emit("gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(%s), uintptr(%s.SizeBytes()))\n", g.r, g.r) - } - } else { - fallback() - } - }) - g.emit("}\n\n") - - g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n") - g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName()) - g.inIndent(func() { - fallback := func() { - g.emit("// Type %s doesn't have a packed layout in memory, fallback to UnmarshalBytes.\n", g.typeName()) - g.emit("%s.UnmarshalBytes(src)\n", g.r) - } - if thisPacked { - g.recordUsedImport("gohacks") - if cond, ok := g.areFieldsPackedExpression(); ok { - g.emit("if %s {\n", cond) - g.inIndent(func() { - g.emit("gohacks.Memmove(unsafe.Pointer(%s), unsafe.Pointer(&src[0]), uintptr(%s.SizeBytes()))\n", g.r, g.r) - }) - g.emit("} else {\n") - g.inIndent(fallback) - g.emit("}\n") - } else { - g.emit("gohacks.Memmove(unsafe.Pointer(%s), unsafe.Pointer(&src[0]), uintptr(%s.SizeBytes()))\n", g.r, g.r) - } - } else { - fallback() - } - }) - g.emit("}\n\n") - g.emit("// CopyOutN implements marshal.Marshallable.CopyOutN.\n") - g.emit("//go:nosplit\n") - g.recordUsedImport("marshal") - g.recordUsedImport("hostarch") - g.emit("func (%s *%s) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) {\n", g.r, g.typeName()) - g.inIndent(func() { - fallback := func() { - g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName()) - g.emit("buf := cc.CopyScratchBuffer(%s.SizeBytes()) // escapes: okay.\n", g.r) - g.emit("%s.MarshalBytes(buf) // escapes: fallback.\n", g.r) - g.emit("return cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay.\n") - } - if thisPacked { - g.recordUsedImport("reflect") - g.recordUsedImport("runtime") - g.recordUsedImport("unsafe") - if cond, ok := g.areFieldsPackedExpression(); ok { - g.emit("if !%s {\n", cond) - g.inIndent(fallback) - g.emit("}\n\n") - } - // Fast serialization. - g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r)) - - g.emit("length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay.\n") - g.emitKeepAlive(g.r) - g.emit("return length, err\n") - } else { - fallback() - } - }) - g.emit("}\n\n") - - g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n") - g.emit("//go:nosplit\n") - g.recordUsedImport("marshal") - g.recordUsedImport("hostarch") - g.emit("func (%s *%s) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) {\n", g.r, g.typeName()) - g.inIndent(func() { - g.emit("return %s.CopyOutN(cc, addr, %s.SizeBytes())\n", g.r, g.r) - }) - g.emit("}\n\n") - - g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n") - g.emit("//go:nosplit\n") - g.recordUsedImport("marshal") - g.recordUsedImport("hostarch") - g.emit("func (%s *%s) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) {\n", g.r, g.typeName()) - g.inIndent(func() { - fallback := func() { - g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName()) - g.emit("buf := cc.CopyScratchBuffer(%s.SizeBytes()) // escapes: okay.\n", g.r) - g.emit("length, err := cc.CopyInBytes(addr, buf) // escapes: okay.\n") - g.emit("// Unmarshal unconditionally. If we had a short copy-in, this results in a\n") - g.emit("// partially unmarshalled struct.\n") - g.emit("%s.UnmarshalBytes(buf) // escapes: fallback.\n", g.r) - g.emit("return length, err\n") - } - if thisPacked { - g.recordUsedImport("reflect") - g.recordUsedImport("runtime") - g.recordUsedImport("unsafe") - if cond, ok := g.areFieldsPackedExpression(); ok { - g.emit("if !%s {\n", cond) - g.inIndent(fallback) - g.emit("}\n\n") - } - // Fast deserialization. - g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r)) - - g.emit("length, err := cc.CopyInBytes(addr, buf) // escapes: okay.\n") - g.emitKeepAlive(g.r) - g.emit("return length, err\n") - } else { - fallback() - } - }) - g.emit("}\n\n") - - g.emit("// WriteTo implements io.WriterTo.WriteTo.\n") - g.recordUsedImport("io") - g.emit("func (%s *%s) WriteTo(writer io.Writer) (int64, error) {\n", g.r, g.typeName()) - g.inIndent(func() { - fallback := func() { - g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName()) - g.emit("buf := make([]byte, %s.SizeBytes())\n", g.r) - g.emit("%s.MarshalBytes(buf)\n", g.r) - g.emit("length, err := writer.Write(buf)\n") - g.emit("return int64(length), err\n") - } - if thisPacked { - g.recordUsedImport("reflect") - g.recordUsedImport("runtime") - g.recordUsedImport("unsafe") - if cond, ok := g.areFieldsPackedExpression(); ok { - g.emit("if !%s {\n", cond) - g.inIndent(fallback) - g.emit("}\n\n") - } - // Fast serialization. - g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r)) - - g.emit("length, err := writer.Write(buf)\n") - g.emitKeepAlive(g.r) - g.emit("return int64(length), err\n") - } else { - fallback() - } - }) - g.emit("}\n\n") -} - -func (g *interfaceGenerator) emitMarshallableSliceForStruct(st *ast.StructType, slice *sliceAPI) { - thisPacked := g.isStructPacked(st) - - if slice.inner { - abortAt(g.f.Position(slice.comment.Slash), fmt.Sprintf("The ':inner' argument to '+marshal slice:%s:inner' is only applicable to newtypes on primitives. Remove it from this struct declaration.", slice.ident)) - } - - g.recordUsedImport("marshal") - g.recordUsedImport("hostarch") - - g.emit("// Copy%sIn copies in a slice of %s objects from the task's memory.\n", slice.ident, g.typeName()) - g.emit("func Copy%sIn(cc marshal.CopyContext, addr hostarch.Addr, dst []%s) (int, error) {\n", slice.ident, g.typeName()) - g.inIndent(func() { - g.emit("count := len(dst)\n") - g.emit("if count == 0 {\n") - g.inIndent(func() { - g.emit("return 0, nil\n") - }) - g.emit("}\n") - g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName()) - - fallback := func() { - g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName()) - g.emit("buf := cc.CopyScratchBuffer(size * count)\n") - g.emit("length, err := cc.CopyInBytes(addr, buf)\n\n") - - g.emit("// Unmarshal as much as possible, even on error. First handle full objects.\n") - g.emit("limit := length/size\n") - g.emit("for idx := 0; idx < limit; idx++ {\n") - g.inIndent(func() { - g.emit("dst[idx].UnmarshalBytes(buf[size*idx:size*(idx+1)])\n") - }) - g.emit("}\n\n") - - g.emit("// Handle any final partial object. buf is guaranteed to be long enough for the\n") - g.emit("// final element, but may not contain valid data for the entire range. This may\n") - g.emit("// result in unmarshalling zero values for some parts of the object.\n") - g.emit("if length%size != 0 {\n") - g.inIndent(func() { - g.emit("idx := limit\n") - g.emit("dst[idx].UnmarshalBytes(buf[size*idx:size*(idx+1)])\n") - }) - g.emit("}\n\n") - - g.emit("return length, err\n") - } - if thisPacked { - g.recordUsedImport("reflect") - g.recordUsedImport("runtime") - g.recordUsedImport("unsafe") - if _, ok := g.areFieldsPackedExpression(); ok { - g.emit("if !dst[0].Packed() {\n") - g.inIndent(fallback) - g.emit("}\n\n") - } - // Fast deserialization. - g.emitCastSliceToByteSlice("&dst", "buf", "size * count") - - g.emit("length, err := cc.CopyInBytes(addr, buf)\n") - g.emitKeepAlive("dst") - g.emit("return length, err\n") - } else { - fallback() - } - }) - g.emit("}\n\n") - - g.emit("// Copy%sOut copies a slice of %s objects to the task's memory.\n", slice.ident, g.typeName()) - g.emit("func Copy%sOut(cc marshal.CopyContext, addr hostarch.Addr, src []%s) (int, error) {\n", slice.ident, g.typeName()) - g.inIndent(func() { - g.emit("count := len(src)\n") - g.emit("if count == 0 {\n") - g.inIndent(func() { - g.emit("return 0, nil\n") - }) - g.emit("}\n") - g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName()) - - fallback := func() { - g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName()) - g.emit("buf := cc.CopyScratchBuffer(size * count)\n") - g.emit("for idx := 0; idx < count; idx++ {\n") - g.inIndent(func() { - g.emit("src[idx].MarshalBytes(buf[size*idx:size*(idx+1)])\n") - }) - g.emit("}\n") - g.emit("return cc.CopyOutBytes(addr, buf)\n") - } - if thisPacked { - g.recordUsedImport("reflect") - g.recordUsedImport("runtime") - g.recordUsedImport("unsafe") - if _, ok := g.areFieldsPackedExpression(); ok { - g.emit("if !src[0].Packed() {\n") - g.inIndent(fallback) - g.emit("}\n\n") - } - // Fast serialization. - g.emitCastSliceToByteSlice("&src", "buf", "size * count") - - g.emit("length, err := cc.CopyOutBytes(addr, buf)\n") - g.emitKeepAlive("src") - g.emit("return length, err\n") - } else { - fallback() - } - }) - g.emit("}\n\n") - - g.emit("// MarshalUnsafe%s is like %s.MarshalUnsafe, but for a []%s.\n", slice.ident, g.typeName(), g.typeName()) - g.emit("func MarshalUnsafe%s(src []%s, dst []byte) (int, error) {\n", slice.ident, g.typeName()) - g.inIndent(func() { - g.emit("count := len(src)\n") - g.emit("if count == 0 {\n") - g.inIndent(func() { - g.emit("return 0, nil\n") - }) - g.emit("}\n") - g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName()) - - fallback := func() { - g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName()) - g.emit("for idx := 0; idx < count; idx++ {\n") - g.inIndent(func() { - g.emit("src[idx].MarshalBytes(dst[size*idx:(size)*(idx+1)])\n") - }) - g.emit("}\n") - g.emit("return size * count, nil\n") - } - if thisPacked { - g.recordUsedImport("reflect") - g.recordUsedImport("runtime") - g.recordUsedImport("unsafe") - g.recordUsedImport("gohacks") - if _, ok := g.areFieldsPackedExpression(); ok { - g.emit("if !src[0].Packed() {\n") - g.inIndent(fallback) - g.emit("}\n\n") - } - g.emit("dst = dst[:size*count]\n") - g.emit("gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(&src[0]), uintptr(len(dst)))\n") - g.emit("return size * count, nil\n") - } else { - fallback() - } - }) - g.emit("}\n\n") - - g.emit("// UnmarshalUnsafe%s is like %s.UnmarshalUnsafe, but for a []%s.\n", slice.ident, g.typeName(), g.typeName()) - g.emit("func UnmarshalUnsafe%s(dst []%s, src []byte) (int, error) {\n", slice.ident, g.typeName()) - g.inIndent(func() { - g.emit("count := len(dst)\n") - g.emit("if count == 0 {\n") - g.inIndent(func() { - g.emit("return 0, nil\n") - }) - g.emit("}\n") - g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName()) - - fallback := func() { - g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName()) - g.emit("for idx := 0; idx < count; idx++ {\n") - g.inIndent(func() { - g.emit("dst[idx].UnmarshalBytes(src[size*idx:size*(idx+1)])\n") - }) - g.emit("}\n") - g.emit("return size * count, nil\n") - } - if thisPacked { - g.recordUsedImport("gohacks") - g.recordUsedImport("reflect") - g.recordUsedImport("runtime") - if _, ok := g.areFieldsPackedExpression(); ok { - g.emit("if !dst[0].Packed() {\n") - g.inIndent(fallback) - g.emit("}\n\n") - } - - g.emit("src = src[:(size*count)]\n") - g.emit("gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(&src[0]), uintptr(len(src)))\n") - - g.emit("return count*size, nil\n") - } else { - fallback() - } - }) - g.emit("}\n\n") -} diff --git a/tools/go_marshal/gomarshal/generator_tests.go b/tools/go_marshal/gomarshal/generator_tests.go deleted file mode 100644 index 8f93a1de5..000000000 --- a/tools/go_marshal/gomarshal/generator_tests.go +++ /dev/null @@ -1,233 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package gomarshal - -import ( - "fmt" - "go/ast" - "io" - "strings" -) - -var standardImports = []string{ - "bytes", - "fmt", - "reflect", - "testing", - - "gvisor.dev/gvisor/tools/go_marshal/analysis", -} - -var sliceAPIImports = []string{ - "encoding/binary", - "gvisor.dev/gvisor/pkg/hostarch", -} - -type testGenerator struct { - sourceBuffer - - // The type we're serializing. - t *ast.TypeSpec - - // Receiver argument for generated methods. - r string - - // Imports used by generated code. - imports *importTable - - // Import statement for the package declaring the type we generated code - // for. We need this to construct test instances for the type, since the - // tests aren't written in the same package. - decl *importStmt -} - -func newTestGenerator(t *ast.TypeSpec, r string) *testGenerator { - g := &testGenerator{ - t: t, - r: r, - imports: newImportTable(), - } - - for _, i := range standardImports { - g.imports.add(i).markUsed() - } - // These imports are used if a type requests the slice API. Don't - // mark them as used by default. - for _, i := range sliceAPIImports { - g.imports.add(i) - } - - return g -} - -func (g *testGenerator) typeName() string { - return g.t.Name.Name -} - -func (g *testGenerator) testFuncName(base string) string { - return fmt.Sprintf("%s%s", base, strings.Title(g.t.Name.Name)) -} - -func (g *testGenerator) inTestFunction(name string, body func()) { - g.emit("func %s(t *testing.T) {\n", g.testFuncName(name)) - g.inIndent(body) - g.emit("}\n\n") -} - -func (g *testGenerator) emitTestNonZeroSize() { - g.inTestFunction("TestSizeNonZero", func() { - g.emit("var x %v\n", g.typeName()) - g.emit("if x.SizeBytes() == 0 {\n") - g.inIndent(func() { - g.emit("t.Fatal(\"Marshallable.SizeBytes() should not return zero\")\n") - }) - g.emit("}\n") - }) -} - -func (g *testGenerator) emitTestSuspectAlignment() { - g.inTestFunction("TestSuspectAlignment", func() { - g.emit("var x %v\n", g.typeName()) - g.emit("analysis.AlignmentCheck(t, reflect.TypeOf(x))\n") - }) -} - -func (g *testGenerator) emitTestMarshalUnmarshalPreservesData() { - g.inTestFunction("TestSafeMarshalUnmarshalPreservesData", func() { - g.emit("var x, y, z, yUnsafe, zUnsafe %s\n", g.typeName()) - g.emit("analysis.RandomizeValue(&x)\n\n") - - g.emit("buf := make([]byte, x.SizeBytes())\n") - g.emit("x.MarshalBytes(buf)\n") - g.emit("bufUnsafe := make([]byte, x.SizeBytes())\n") - g.emit("x.MarshalUnsafe(bufUnsafe)\n\n") - - g.emit("y.UnmarshalBytes(buf)\n") - g.emit("if !reflect.DeepEqual(x, y) {\n") - g.inIndent(func() { - g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalBytes/UnmarshalBytes cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, y))\n") - }) - g.emit("}\n") - g.emit("yUnsafe.UnmarshalBytes(bufUnsafe)\n") - g.emit("if !reflect.DeepEqual(x, yUnsafe) {\n") - g.inIndent(func() { - g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalUnsafe/UnmarshalBytes cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, yUnsafe))\n") - }) - g.emit("}\n\n") - - g.emit("z.UnmarshalUnsafe(buf)\n") - g.emit("if !reflect.DeepEqual(x, z) {\n") - g.inIndent(func() { - g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalBytes/UnmarshalUnsafe cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, z))\n") - }) - g.emit("}\n") - g.emit("zUnsafe.UnmarshalUnsafe(bufUnsafe)\n") - g.emit("if !reflect.DeepEqual(x, zUnsafe) {\n") - g.inIndent(func() { - g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalUnsafe/UnmarshalUnsafe cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, zUnsafe))\n") - }) - g.emit("}\n") - }) -} - -func (g *testGenerator) emitTestMarshalUnmarshalSlicePreservesData(slice *sliceAPI) { - for _, name := range []string{"binary", "hostarch"} { - if !g.imports.markUsed(name) { - panic(fmt.Sprintf("Generated test for '%s' referenced a non-existent import with local name '%s'", g.typeName(), name)) - } - } - - g.inTestFunction("TestSafeMarshalUnmarshalSlicePreservesData", func() { - g.emit("var x, y, yUnsafe [8]%s\n", g.typeName()) - g.emit("analysis.RandomizeValue(&x)\n\n") - g.emit("size := (*%s)(nil).SizeBytes() * len(x)\n", g.typeName()) - g.emit("buf := bytes.NewBuffer(make([]byte, size))\n") - g.emit("buf.Reset()\n") - g.emit("if err := binary.Write(buf, hostarch.ByteOrder, x[:]); err != nil {\n") - g.inIndent(func() { - g.emit("t.Fatal(fmt.Sprintf(\"binary.Write failed: %v\", err))\n") - }) - g.emit("}\n") - g.emit("bufUnsafe := make([]byte, size)\n") - g.emit("MarshalUnsafe%s(x[:], bufUnsafe)\n\n", slice.ident) - - g.emit("UnmarshalUnsafe%s(y[:], buf.Bytes())\n", slice.ident) - g.emit("if !reflect.DeepEqual(x, y) {\n") - g.inIndent(func() { - g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across binary.Write/UnmarshalUnsafeSlice cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, y))\n") - }) - g.emit("}\n") - g.emit("UnmarshalUnsafe%s(yUnsafe[:], bufUnsafe)\n", slice.ident) - g.emit("if !reflect.DeepEqual(x, yUnsafe) {\n") - g.inIndent(func() { - g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalUnsafeSlice/UnmarshalUnsafeSlice cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, yUnsafe))\n") - }) - g.emit("}\n\n") - }) -} - -func (g *testGenerator) emitTestWriteToUnmarshalPreservesData() { - g.inTestFunction("TestWriteToUnmarshalPreservesData", func() { - g.emit("var x, y, yUnsafe %s\n", g.typeName()) - g.emit("analysis.RandomizeValue(&x)\n\n") - - g.emit("var buf bytes.Buffer\n\n") - - g.emit("x.WriteTo(&buf)\n") - g.emit("y.UnmarshalBytes(buf.Bytes())\n\n") - g.emit("yUnsafe.UnmarshalUnsafe(buf.Bytes())\n\n") - - g.emit("if !reflect.DeepEqual(x, y) {\n") - g.inIndent(func() { - g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across WriteTo/UnmarshalBytes cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, y))\n") - }) - g.emit("}\n") - g.emit("if !reflect.DeepEqual(x, yUnsafe) {\n") - g.inIndent(func() { - g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across WriteTo/UnmarshalUnsafe cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, yUnsafe))\n") - }) - g.emit("}\n") - }) -} - -func (g *testGenerator) emitTestSizeBytesOnTypedNilPtr() { - g.inTestFunction("TestSizeBytesOnTypedNilPtr", func() { - g.emit("var x %s\n", g.typeName()) - g.emit("sizeFromConcrete := x.SizeBytes()\n") - g.emit("sizeFromTypedNilPtr := (*%s)(nil).SizeBytes()\n\n", g.typeName()) - - g.emit("if sizeFromTypedNilPtr != sizeFromConcrete {\n") - g.inIndent(func() { - g.emit("t.Fatalf(\"SizeBytes() on typed nil pointer (%v) doesn't match size returned by a concrete object (%v).\\n\", sizeFromTypedNilPtr, sizeFromConcrete)\n") - }) - g.emit("}\n") - }) -} - -func (g *testGenerator) emitTests(slice *sliceAPI) { - g.emitTestNonZeroSize() - g.emitTestSuspectAlignment() - g.emitTestMarshalUnmarshalPreservesData() - g.emitTestWriteToUnmarshalPreservesData() - g.emitTestSizeBytesOnTypedNilPtr() - - if slice != nil { - g.emitTestMarshalUnmarshalSlicePreservesData(slice) - } -} - -func (g *testGenerator) write(out io.Writer) error { - return g.sourceBuffer.write(out) -} diff --git a/tools/go_marshal/gomarshal/util.go b/tools/go_marshal/gomarshal/util.go deleted file mode 100644 index 6a42691cd..000000000 --- a/tools/go_marshal/gomarshal/util.go +++ /dev/null @@ -1,503 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package gomarshal - -import ( - "bytes" - "flag" - "fmt" - "go/ast" - "go/token" - "io" - "os" - "path" - "reflect" - "sort" - "strings" -) - -var debug = flag.Bool("debug", false, "enables debugging output") - -// receiverName returns an appropriate receiver name given a type spec. -func receiverName(t *ast.TypeSpec) string { - if len(t.Name.Name) < 1 { - // Zero length type name? - panic("unreachable") - } - return strings.ToLower(t.Name.Name[:1]) -} - -// kindString returns a user-friendly representation of an AST expr type. -func kindString(e ast.Expr) string { - switch e.(type) { - case *ast.Ident: - return "scalar" - case *ast.ArrayType: - return "array" - case *ast.StructType: - return "struct" - case *ast.StarExpr: - return "pointer" - case *ast.FuncType: - return "function" - case *ast.InterfaceType: - return "interface" - case *ast.MapType: - return "map" - case *ast.ChanType: - return "channel" - default: - return reflect.TypeOf(e).String() - } -} - -func forEachStructField(st *ast.StructType, fn func(f *ast.Field)) { - for _, field := range st.Fields.List { - fn(field) - } -} - -// fieldDispatcher is a collection of callbacks for handling different types of -// fields in a struct declaration. -type fieldDispatcher struct { - primitive func(n, t *ast.Ident) - selector func(n, tX, tSel *ast.Ident) - array func(n *ast.Ident, a *ast.ArrayType, t *ast.Ident) - unhandled func(n *ast.Ident) -} - -// Precondition: All dispatch callbacks that will be invoked must be -// provided. -func (fd fieldDispatcher) dispatch(f *ast.Field) { - // Each field declaration may actually be multiple declarations of the same - // type. For example, consider: - // - // type Point struct { - // x, y, z int - // } - // - // We invoke the call-backs once per such instance. - - // Handle embedded fields. Embedded fields have no names, but can be - // referenced by the type name. - if len(f.Names) < 1 { - switch v := f.Type.(type) { - case *ast.Ident: - fd.primitive(v, v) - case *ast.SelectorExpr: - fd.selector(v.Sel, v.X.(*ast.Ident), v.Sel) - default: - // Note: Arrays can't be embedded, which is handled here. - panic(fmt.Sprintf("Attempted to dispatch on embedded field of unsupported kind: %#v", f.Type)) - } - return - } - - // Non-embedded field. - for _, name := range f.Names { - switch v := f.Type.(type) { - case *ast.Ident: - fd.primitive(name, v) - case *ast.SelectorExpr: - fd.selector(name, v.X.(*ast.Ident), v.Sel) - case *ast.ArrayType: - switch t := v.Elt.(type) { - case *ast.Ident: - fd.array(name, v, t) - default: - // Should be handled with a better error message during validate. - panic(fmt.Sprintf("Array element type is of unsupported kind. Expected *ast.Ident, got %v", t)) - } - default: - fd.unhandled(name) - } - } -} - -// debugEnabled indicates whether debugging is enabled for gomarshal. -func debugEnabled() bool { - return *debug -} - -// abort aborts the go_marshal tool with the given error message. -func abort(msg string) { - if !strings.HasSuffix(msg, "\n") { - msg += "\n" - } - fmt.Print(msg) - os.Exit(1) -} - -// abortAt aborts the go_marshal tool with the given error message, with -// a reference position to the input source. -func abortAt(p token.Position, msg string) { - abort(fmt.Sprintf("%v:\n %s\n", p, msg)) -} - -// debugf conditionally prints a debug message. -func debugf(f string, a ...interface{}) { - if debugEnabled() { - fmt.Printf(f, a...) - } -} - -// debugfAt conditionally prints a debug message with a reference to a position -// in the input source. -func debugfAt(p token.Position, f string, a ...interface{}) { - if debugEnabled() { - fmt.Printf("%s:\n %s", p, fmt.Sprintf(f, a...)) - } -} - -// emit generates a line of code in the output file. -// -// emit is a wrapper around writing a formatted string to the output -// buffer. emit can be invoked in one of two ways: -// -// (1) emit("some string") -// When emit is called with a single string argument, it is simply copied to -// the output buffer without any further formatting. -// (2) emit(fmtString, args...) -// emit can also be invoked in a similar fashion to *Printf() functions, -// where the first argument is a format string. -// -// Calling emit with a single argument that is not a string will result in a -// panic, as the caller's intent is ambiguous. -func emit(out io.Writer, indent int, a ...interface{}) { - const spacesPerIndentLevel = 4 - - if len(a) < 1 { - panic("emit() called with no arguments") - } - - if indent > 0 { - if _, err := fmt.Fprint(out, strings.Repeat(" ", indent*spacesPerIndentLevel)); err != nil { - // Writing to the emit output should not fail. Typically the output - // is a byte.Buffer; writes to these never fail. - panic(err) - } - } - - first, ok := a[0].(string) - if !ok { - // First argument must be either the string to emit (case 1 from - // function-level comment), or a format string (case 2). - panic(fmt.Sprintf("First argument to emit() is not a string: %+v", a[0])) - } - - if len(a) == 1 { - // Single string argument. Assume no formatting requested. - if _, err := fmt.Fprint(out, first); err != nil { - // Writing to out should not fail. - panic(err) - } - return - - } - - // Formatting requested. - if _, err := fmt.Fprintf(out, first, a[1:]...); err != nil { - // Writing to out should not fail. - panic(err) - } -} - -// sourceBuffer represents fragments of generated go source code. -// -// sourceBuffer provides a convenient way to build up go souce fragments in -// memory. May be safely zero-value initialized. Not thread-safe. -type sourceBuffer struct { - // Current indentation level. - indent int - - // Memory buffer containing contents while they're being generated. - b bytes.Buffer -} - -func (b *sourceBuffer) reset() { - b.indent = 0 - b.b.Reset() -} - -func (b *sourceBuffer) incIndent() { - b.indent++ -} - -func (b *sourceBuffer) decIndent() { - if b.indent <= 0 { - panic("decIndent() without matching incIndent()") - } - b.indent-- -} - -func (b *sourceBuffer) emit(a ...interface{}) { - emit(&b.b, b.indent, a...) -} - -func (b *sourceBuffer) emitNoIndent(a ...interface{}) { - emit(&b.b, 0 /*indent*/, a...) -} - -func (b *sourceBuffer) inIndent(body func()) { - b.incIndent() - body() - b.decIndent() -} - -func (b *sourceBuffer) write(out io.Writer) error { - _, err := fmt.Fprint(out, b.b.String()) - return err -} - -// Write implements io.Writer.Write. -func (b *sourceBuffer) Write(buf []byte) (int, error) { - return (b.b.Write(buf)) -} - -// importStmt represents a single import statement. -type importStmt struct { - // Local name of the imported package. - name string - // Import path. - path string - // Indicates whether the local name is an alias, or simply the final - // component of the path. - aliased bool - // Indicates whether this import was referenced by generated code. - used bool - // AST node and file set representing the import statement, if any. These - // are only non-nil if the import statement originates from an input source - // file. - spec *ast.ImportSpec - fset *token.FileSet -} - -func newImport(p string) *importStmt { - name := path.Base(p) - return &importStmt{ - name: name, - path: p, - aliased: false, - } -} - -func newImportFromSpec(spec *ast.ImportSpec, f *token.FileSet) *importStmt { - p := spec.Path.Value[1 : len(spec.Path.Value)-1] // Strip the " quotes around path. - name := path.Base(p) - if name == "" || name == "/" || name == "." { - panic(fmt.Sprintf("Couldn't process local package name for import at %s, (processed as %s)", - f.Position(spec.Path.Pos()), name)) - } - if spec.Name != nil { - name = spec.Name.Name - } - return &importStmt{ - name: name, - path: p, - aliased: spec.Name != nil, - spec: spec, - fset: f, - } -} - -// String implements fmt.Stringer.String. This generates a string for the import -// statement appropriate for writing directly to generated code. -func (i *importStmt) String() string { - if i.aliased { - return fmt.Sprintf("%s %q", i.name, i.path) - } - return fmt.Sprintf("%q", i.path) -} - -// debugString returns a debug string representing an import statement. This -// representation is not valid golang code and is used for debugging output. -func (i *importStmt) debugString() string { - if i.spec != nil && i.fset != nil { - return fmt.Sprintf("%s: %s", i.fset.Position(i.spec.Path.Pos()), i) - } - return fmt.Sprintf("(go-marshal import): %s", i) -} - -func (i *importStmt) markUsed() { - i.used = true -} - -func (i *importStmt) equivalent(other *importStmt) bool { - return i.name == other.name && i.path == other.path && i.aliased == other.aliased -} - -// importTable represents a collection of importStmts. -// -// An importTable may contain multiple import statements referencing the same -// local name. All import statements aliasing to the same local name are -// technically ambiguous, as if such an import name is used in the generated -// code, it's not clear which import statement it refers to. We ignore any -// potential collisions until actually writing the import table to the generated -// source file. See importTable.write. -// -// Given the following import statements across all the files comprising a -// package marshalled: -// -// "sync" -// "pkg/sync" -// "pkg/sentry/kernel" -// ktime "pkg/sentry/kernel/time" -// -// An importTable representing them would look like this: -// -// importTable { -// is: map[string][]*importStmt { -// "sync": []*importStmt{ -// importStmt{name:"sync", path:"sync", aliased:false} -// importStmt{name:"sync", path:"pkg/sync", aliased:false} -// }, -// "kernel": []*importStmt{importStmt{ -// name: "kernel", -// path: "pkg/sentry/kernel", -// aliased: false -// }}, -// "ktime": []*importStmt{importStmt{ -// name: "ktime", -// path: "pkg/sentry/kernel/time", -// aliased: true, -// }}, -// } -// } -// -// Note that the local name "sync" is assigned to two different import -// statements. This is possible if the import statements are from different -// source files in the same package. -// -// Since go-marshal generates a single output file per package regardless of the -// number of input files, if "sync" is referenced by any generated code, it's -// unclear which import statement "sync" refers to. While it's theoretically -// possible to resolve this by assigning a unique local alias to each instance -// of the sync package, go-marshal currently aborts when it encounters such an -// ambiguity. -// -// TODO(b/151478251): importTable considers the final component of an import -// path to be the package name, but this is only a convention. The actual -// package name is determined by the package statement in the source files for -// the package. -type importTable struct { - // Map of imports and whether they should be copied to the output. - is map[string][]*importStmt -} - -func newImportTable() *importTable { - return &importTable{ - is: make(map[string][]*importStmt), - } -} - -// Merges import statements from other into i. -func (i *importTable) merge(other *importTable) { - for name, ims := range other.is { - i.is[name] = append(i.is[name], ims...) - } -} - -func (i *importTable) addStmt(s *importStmt) *importStmt { - i.is[s.name] = append(i.is[s.name], s) - return s -} - -func (i *importTable) add(s string) *importStmt { - n := newImport(s) - return i.addStmt(n) -} - -func (i *importTable) addFromSpec(spec *ast.ImportSpec, f *token.FileSet) *importStmt { - return i.addStmt(newImportFromSpec(spec, f)) -} - -// Marks the import named n as used. If no such import is in the table, returns -// false. -func (i *importTable) markUsed(n string) bool { - if ns, ok := i.is[n]; ok { - for _, n := range ns { - n.markUsed() - } - return true - } - return false -} - -func (i *importTable) clear() { - for _, is := range i.is { - for _, i := range is { - i.used = false - } - } -} - -func (i *importTable) write(out io.Writer) error { - if len(i.is) == 0 { - // Nothing to import, we're done. - return nil - } - - imports := make([]string, 0, len(i.is)) - for name, is := range i.is { - var lastUsed *importStmt - var ambiguous bool - - for _, i := range is { - if i.used { - if lastUsed != nil { - if !i.equivalent(lastUsed) { - ambiguous = true - } - } - lastUsed = i - } - } - - if ambiguous { - // We have two or more import statements across the different source - // files that share a local name, and at least one of these imports - // are used by the generated code. This ambiguity can't be resolved - // by go-marshal and requires the user intervention. Dump a list of - // the colliding import statements and let the user modify the input - // files as appropriate. - var b strings.Builder - fmt.Fprintf(&b, "The imported name %q is used by one of the types marked for marshalling, and which import statement the code refers to is ambiguous. Perhaps give the imports unique local names?\n\n", name) - fmt.Fprintf(&b, "The following %d import statements are ambiguous for the local name %q:\n", len(is), name) - // Note: len(is) is guaranteed to be 1 or greater or ambiguous can't - // be true. Therefore the slicing below is safe. - for _, i := range is[:len(is)-1] { - fmt.Fprintf(&b, " %v\n", i.debugString()) - } - fmt.Fprintf(&b, " %v", is[len(is)-1].debugString()) - panic(b.String()) - } - - if lastUsed != nil { - imports = append(imports, lastUsed.String()) - } - } - sort.Strings(imports) - - var b sourceBuffer - b.emit("import (\n") - b.incIndent() - for _, i := range imports { - b.emit("%s\n", i) - } - b.decIndent() - b.emit(")\n\n") - - return b.write(out) -} |