summaryrefslogtreecommitdiffhomepage
path: root/tools/go_marshal/gomarshal
diff options
context:
space:
mode:
Diffstat (limited to 'tools/go_marshal/gomarshal')
-rw-r--r--tools/go_marshal/gomarshal/BUILD22
-rw-r--r--tools/go_marshal/gomarshal/generator.go587
-rw-r--r--tools/go_marshal/gomarshal/generator_interfaces.go276
-rw-r--r--tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go146
-rw-r--r--tools/go_marshal/gomarshal/generator_interfaces_dynamic.go96
-rw-r--r--tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go285
-rw-r--r--tools/go_marshal/gomarshal/generator_interfaces_struct.go616
-rw-r--r--tools/go_marshal/gomarshal/generator_tests.go233
-rw-r--r--tools/go_marshal/gomarshal/util.go503
9 files changed, 0 insertions, 2764 deletions
diff --git a/tools/go_marshal/gomarshal/BUILD b/tools/go_marshal/gomarshal/BUILD
deleted file mode 100644
index aaa203115..000000000
--- a/tools/go_marshal/gomarshal/BUILD
+++ /dev/null
@@ -1,22 +0,0 @@
-load("//tools:defs.bzl", "go_library")
-
-licenses(["notice"])
-
-go_library(
- name = "gomarshal",
- srcs = [
- "generator.go",
- "generator_interfaces.go",
- "generator_interfaces_array_newtype.go",
- "generator_interfaces_dynamic.go",
- "generator_interfaces_primitive_newtype.go",
- "generator_interfaces_struct.go",
- "generator_tests.go",
- "util.go",
- ],
- stateify = False,
- visibility = [
- "//:sandbox",
- ],
- deps = ["//tools/constraintutil"],
-)
diff --git a/tools/go_marshal/gomarshal/generator.go b/tools/go_marshal/gomarshal/generator.go
deleted file mode 100644
index 4c23637c0..000000000
--- a/tools/go_marshal/gomarshal/generator.go
+++ /dev/null
@@ -1,587 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Package gomarshal implements the go_marshal code generator. See README.md.
-package gomarshal
-
-import (
- "bytes"
- "fmt"
- "go/ast"
- "go/parser"
- "go/token"
- "os"
- "sort"
- "strings"
-
- "gvisor.dev/gvisor/tools/constraintutil"
-)
-
-// List of identifiers we use in generated code that may conflict with a
-// similarly-named source identifier. Abort gracefully when we see these to
-// avoid potentially confusing compilation failures in generated code.
-//
-// This only applies to import aliases at the moment. All other identifiers
-// are qualified by a receiver argument, since they're struct fields.
-//
-// All recievers are single letters, so we don't allow import aliases to be a
-// single letter.
-var badIdents = []string{
- "addr", "blk", "buf", "cc", "dst", "dsts", "count", "err", "hdr", "idx",
- "inner", "length", "limit", "ptr", "size", "src", "srcs", "val",
- // All single-letter identifiers.
-}
-
-// Constructed fromt badIdents in init().
-var badIdentsMap map[string]struct{}
-
-func init() {
- badIdentsMap = make(map[string]struct{})
- for _, ident := range badIdents {
- badIdentsMap[ident] = struct{}{}
- }
-}
-
-// Generator drives code generation for a single invocation of the go_marshal
-// utility.
-//
-// The Generator holds arguments passed to the tool, and drives parsing,
-// processing and code Generator for all types marked with +marshal declared in
-// the input files.
-//
-// See Generator.run() as the entry point.
-type Generator struct {
- // Paths to input go source files.
- inputs []string
- // Output file to write generated go source.
- output *os.File
- // Output file to write generated tests.
- outputTest *os.File
- // Output file to write unconditionally generated tests.
- outputTestUC *os.File
- // Package name for the generated file.
- pkg string
- // Set of extra packages to import in the generated file.
- imports *importTable
-}
-
-// NewGenerator creates a new code Generator.
-func NewGenerator(srcs []string, out, outTest, outTestUnconditional, pkg string, imports []string) (*Generator, error) {
- f, err := os.OpenFile(out, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
- if err != nil {
- return nil, fmt.Errorf("couldn't open output file %q: %w", out, err)
- }
- fTest, err := os.OpenFile(outTest, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
- if err != nil {
- return nil, fmt.Errorf("couldn't open test output file %q: %w", out, err)
- }
- fTestUC, err := os.OpenFile(outTestUnconditional, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
- if err != nil {
- return nil, fmt.Errorf("couldn't open unconditional test output file %q: %w", out, err)
- }
- g := Generator{
- inputs: srcs,
- output: f,
- outputTest: fTest,
- outputTestUC: fTestUC,
- pkg: pkg,
- imports: newImportTable(),
- }
- for _, i := range imports {
- // All imports on the extra imports list are unconditionally marked as
- // used, so that they're always added to the generated code.
- g.imports.add(i).markUsed()
- }
-
- // The following imports may or may not be used by the generated code,
- // depending on what's required for the target types. Don't mark these as
- // used by default.
- g.imports.add("io")
- g.imports.add("reflect")
- g.imports.add("runtime")
- g.imports.add("unsafe")
- g.imports.add("gvisor.dev/gvisor/pkg/gohacks")
- g.imports.add("gvisor.dev/gvisor/pkg/hostarch")
- g.imports.add("gvisor.dev/gvisor/pkg/marshal")
- return &g, nil
-}
-
-// writeHeader writes the header for the generated source file. The header
-// includes the package name, package level comments and import statements.
-func (g *Generator) writeHeader() error {
- var b sourceBuffer
- b.emit("// Automatically generated marshal implementation. See tools/go_marshal.\n\n")
-
- bcexpr, err := constraintutil.CombineFromFiles(g.inputs)
- if err != nil {
- return err
- }
- if bcexpr != nil {
- // Emit build constraints.
- b.emit("// If there are issues with build constraint aggregation, see\n")
- b.emit("// tools/go_marshal/gomarshal/generator.go:writeHeader(). The constraints here\n")
- b.emit("// come from the input set of files used to generate this file. This input set\n")
- b.emit("// is filtered based on pre-defined file suffixes related to build constraints,\n")
- b.emit("// see tools/defs.bzl:calculate_sets().\n\n")
- b.emit(constraintutil.Lines(bcexpr))
- }
-
- // Package header.
- b.emit("package %s\n\n", g.pkg)
- if err := b.write(g.output); err != nil {
- return err
- }
-
- return g.imports.write(g.output)
-}
-
-// writeTypeChecks writes a statement to force the compiler to perform a type
-// check for all Marshallable types referenced by the generated code.
-func (g *Generator) writeTypeChecks(ms map[string]struct{}) error {
- if len(ms) == 0 {
- return nil
- }
-
- msl := make([]string, 0, len(ms))
- for m := range ms {
- msl = append(msl, m)
- }
- sort.Strings(msl)
-
- var buf bytes.Buffer
- fmt.Fprint(&buf, "// Marshallable types used by this file.\n")
-
- for _, m := range msl {
- fmt.Fprintf(&buf, "var _ marshal.Marshallable = (*%s)(nil)\n", m)
- }
- fmt.Fprint(&buf, "\n")
-
- _, err := fmt.Fprint(g.output, buf.String())
- return err
-}
-
-// parse processes all input files passed this generator and produces a set of
-// parsed go ASTs.
-func (g *Generator) parse() ([]*ast.File, []*token.FileSet, error) {
- debugf("go_marshal invoked with %d input files:\n", len(g.inputs))
- for _, path := range g.inputs {
- debugf(" %s\n", path)
- }
-
- files := make([]*ast.File, 0, len(g.inputs))
- fsets := make([]*token.FileSet, 0, len(g.inputs))
-
- for _, path := range g.inputs {
- fset := token.NewFileSet()
- f, err := parser.ParseFile(fset, path, nil, parser.ParseComments)
- if err != nil {
- // Not a valid input file?
- return nil, nil, fmt.Errorf("input %q can't be parsed: %w", path, err)
- }
-
- if debugEnabled() {
- debugf("AST for %q:\n", path)
- ast.Print(fset, f)
- }
-
- files = append(files, f)
- fsets = append(fsets, fset)
- }
-
- return files, fsets, nil
-}
-
-// sliceAPI carries information about the '+marshal slice' directive.
-type sliceAPI struct {
- // Comment node in the AST containing the +marshal tag.
- comment *ast.Comment
- // Identifier fragment to use when naming generated functions for the slice
- // API.
- ident string
- // Whether the generated functions should reference the newtype name, or the
- // inner type name. Only meaningful on newtype declarations on primitives.
- inner bool
-}
-
-// marshallableType carries information about a type marked with the '+marshal'
-// directive.
-type marshallableType struct {
- spec *ast.TypeSpec
- slice *sliceAPI
- recv string
- dynamic bool
-}
-
-func newMarshallableType(fset *token.FileSet, tagLine *ast.Comment, spec *ast.TypeSpec) *marshallableType {
- mt := &marshallableType{
- spec: spec,
- slice: nil,
- }
-
- var unhandledTags []string
-
- for _, tag := range strings.Fields(strings.TrimPrefix(tagLine.Text, "// +marshal")) {
- if strings.HasPrefix(tag, "slice:") {
- tokens := strings.Split(tag, ":")
- if len(tokens) < 2 || len(tokens) > 3 {
- abortAt(fset.Position(tagLine.Slash), fmt.Sprintf("+marshal directive has invalid 'slice' clause. Expecting format 'slice:<IDENTIFIER>[:inner]', got '%v'", tag))
- }
- if len(tokens[1]) == 0 {
- abortAt(fset.Position(tagLine.Slash), "+marshal slice directive has empty identifier argument. Expecting '+marshal slice:identifier'")
- }
-
- sa := &sliceAPI{
- comment: tagLine,
- ident: tokens[1],
- }
- mt.slice = sa
-
- if len(tokens) == 3 {
- if tokens[2] != "inner" {
- abortAt(fset.Position(tagLine.Slash), "+marshal slice directive has an invalid argument. Expecting '+marshal slice:<IDENTIFIER>[:inner]'")
- }
- sa.inner = true
- }
-
- continue
- } else if tag == "dynamic" {
- mt.dynamic = true
- continue
- }
-
- unhandledTags = append(unhandledTags, tag)
- }
-
- if len(unhandledTags) > 0 {
- abortAt(fset.Position(tagLine.Slash), fmt.Sprintf("+marshal directive contained the following unknown clauses: %v", strings.Join(unhandledTags, " ")))
- }
-
- return mt
-}
-
-// collectMarshallableTypes walks the parsed AST and collects a list of type
-// declarations for which we need to generate the Marshallable interface.
-func (g *Generator) collectMarshallableTypes(a *ast.File, f *token.FileSet) map[*ast.TypeSpec]*marshallableType {
- recv := make(map[string]string) // Type name to recevier name.
- types := make(map[*ast.TypeSpec]*marshallableType)
- for _, decl := range a.Decls {
- gdecl, ok := decl.(*ast.GenDecl)
- // Type declaration?
- if !ok || gdecl.Tok != token.TYPE {
- // Is this a function declaration? We remember receiver names.
- d, ok := decl.(*ast.FuncDecl)
- if ok && d.Recv != nil && len(d.Recv.List) == 1 {
- // Accept concrete methods & pointer methods.
- ident, ok := d.Recv.List[0].Type.(*ast.Ident)
- if !ok {
- var st *ast.StarExpr
- st, ok = d.Recv.List[0].Type.(*ast.StarExpr)
- if ok {
- ident, ok = st.X.(*ast.Ident)
- }
- }
- // The receiver name may be not present.
- if ok && len(d.Recv.List[0].Names) == 1 {
- // Recover the type receiver name in this case.
- recv[ident.Name] = d.Recv.List[0].Names[0].Name
- }
- }
- debugfAt(f.Position(decl.Pos()), "Skipping declaration since it's not a type declaration.\n")
- continue
- }
- // Does it have a comment?
- if gdecl.Doc == nil {
- debugfAt(f.Position(gdecl.Pos()), "Skipping declaration since it doesn't have a comment.\n")
- continue
- }
- // Does the comment contain a "+marshal" line?
- marked := false
- var tagLine *ast.Comment
- for _, c := range gdecl.Doc.List {
- if strings.HasPrefix(c.Text, "// +marshal") {
- marked = true
- tagLine = c
- break
- }
- }
- if !marked {
- debugfAt(f.Position(gdecl.Pos()), "Skipping declaration since it doesn't have a comment containing +marshal line.\n")
- continue
- }
- for _, spec := range gdecl.Specs {
- // We already confirmed we're in a type declaration earlier, so this
- // cast will succeed.
- t := spec.(*ast.TypeSpec)
- switch t.Type.(type) {
- case *ast.StructType:
- debugfAt(f.Position(t.Pos()), "Collected marshallable struct %s.\n", t.Name.Name)
- case *ast.Ident: // Newtype on primitive.
- debugfAt(f.Position(t.Pos()), "Collected marshallable newtype on primitive %s.\n", t.Name.Name)
- case *ast.ArrayType: // Newtype on array.
- debugfAt(f.Position(t.Pos()), "Collected marshallable newtype on array %s.\n", t.Name.Name)
- default:
- // A user specifically requested marshalling on this type, but we
- // don't support it.
- abortAt(f.Position(t.Pos()), fmt.Sprintf("Marshalling codegen was requested on type '%s', but go-marshal doesn't support this kind of declaration.\n", t.Name))
- }
- types[t] = newMarshallableType(f, tagLine, t)
- }
- }
- // Update the types with the last seen receiver. As long as the
- // receiver name is consistent for the type, then we will generate
- // code that is still consistent with itself.
- for t, mt := range types {
- r, ok := recv[t.Name.Name]
- if !ok {
- mt.recv = receiverName(t) // Default.
- continue
- }
- mt.recv = r // Last seen.
- }
- return types
-}
-
-// collectImports collects all imports from all input source files. Some of
-// these imports are copied to the generated output, if they're referenced by
-// the generated code.
-//
-// collectImports de-duplicates imports while building the list, and ensures
-// identifiers in the generated code don't conflict with any imported package
-// names.
-func (g *Generator) collectImports(a *ast.File, f *token.FileSet) map[string]importStmt {
- is := make(map[string]importStmt)
- for _, decl := range a.Decls {
- gdecl, ok := decl.(*ast.GenDecl)
- // Import statement?
- if !ok || gdecl.Tok != token.IMPORT {
- continue
- }
- for _, spec := range gdecl.Specs {
- i := g.imports.addFromSpec(spec.(*ast.ImportSpec), f)
- debugf("Collected import '%s' as '%s'\n", i.path, i.name)
-
- // Make sure we have an import that doesn't use any local names that
- // would conflict with identifiers in the generated code.
- if len(i.name) == 1 && i.name != "_" {
- abortAt(f.Position(spec.Pos()), fmt.Sprintf("Import has a single character local name '%s'; this may conflict with code generated by go_marshal, use a multi-character import alias", i.name))
- }
- if _, ok := badIdentsMap[i.name]; ok {
- abortAt(f.Position(spec.Pos()), fmt.Sprintf("Import name '%s' is likely to conflict with code generated by go_marshal, use a different import alias", i.name))
- }
- }
- }
- return is
-
-}
-
-func (g *Generator) generateOne(t *marshallableType, fset *token.FileSet) *interfaceGenerator {
- i := newInterfaceGenerator(t.spec, t.recv, fset)
- if t.dynamic {
- if t.slice != nil {
- abortAt(fset.Position(t.slice.comment.Slash), "Slice API is not supported for dynamic types because it assumes that each slice element is statically sized.")
- }
- // No validation needed, assume the user knows what they are doing.
- i.emitMarshallableForDynamicType()
- return i
- }
- switch ty := t.spec.Type.(type) {
- case *ast.StructType:
- i.validateStruct(t.spec, ty)
- i.emitMarshallableForStruct(ty)
- if t.slice != nil {
- i.emitMarshallableSliceForStruct(ty, t.slice)
- }
- case *ast.Ident:
- i.validatePrimitiveNewtype(ty)
- i.emitMarshallableForPrimitiveNewtype(ty)
- if t.slice != nil {
- i.emitMarshallableSliceForPrimitiveNewtype(ty, t.slice)
- }
- case *ast.ArrayType:
- i.validateArrayNewtype(t.spec.Name, ty)
- // After validate, we can safely call arrayLen.
- i.emitMarshallableForArrayNewtype(t.spec.Name, ty, ty.Elt.(*ast.Ident))
- if t.slice != nil {
- abortAt(fset.Position(t.slice.comment.Slash), "Array type marked as '+marshal slice:...', but this is not supported. Perhaps fold one of the dimensions?")
- }
- default:
- // This should've been filtered out by collectMarshallabeTypes.
- panic(fmt.Sprintf("Unexpected type %+v", ty))
- }
- return i
-}
-
-// generateOneTestSuite generates a test suite for the automatically generated
-// implementations type t.
-func (g *Generator) generateOneTestSuite(t *marshallableType) *testGenerator {
- i := newTestGenerator(t.spec, t.recv)
- i.emitTests(t.slice)
- return i
-}
-
-// Run is the entry point to code generation using g.
-//
-// Run parses all input source files specified in g and emits generated code.
-func (g *Generator) Run() error {
- // Parse our input source files into ASTs and token sets.
- asts, fsets, err := g.parse()
- if err != nil {
- return err
- }
-
- if len(asts) != len(fsets) {
- panic("ASTs and FileSets don't match")
- }
-
- // Map of imports in source files; key = local package name, value = import
- // path.
- is := make(map[string]importStmt)
- for i, a := range asts {
- // Collect all imports from the source files. We may need to copy some
- // of these to the generated code if they're referenced. This has to be
- // done before the loop below because we need to process all ASTs before
- // we start requesting imports to be copied one by one as we encounter
- // them in each generated source.
- for name, i := range g.collectImports(a, fsets[i]) {
- is[name] = i
- }
- }
-
- var impls []*interfaceGenerator
- var ts []*testGenerator
- // Set of Marshallable types referenced by generated code.
- ms := make(map[string]struct{})
- for i, a := range asts {
- // Collect type declarations marked for code generation and generate
- // Marshallable interfaces.
- var sortedTypes []*marshallableType
- for _, t := range g.collectMarshallableTypes(a, fsets[i]) {
- sortedTypes = append(sortedTypes, t)
- }
- sort.Slice(sortedTypes, func(x, y int) bool {
- // Sort by type name, which should be unique within a package.
- return sortedTypes[x].spec.Name.String() < sortedTypes[y].spec.Name.String()
- })
- for _, t := range sortedTypes {
- impl := g.generateOne(t, fsets[i])
- // Collect Marshallable types referenced by the generated code.
- for ref := range impl.ms {
- ms[ref] = struct{}{}
- }
- impls = append(impls, impl)
- // Collect imports referenced by the generated code and add them to
- // the list of imports we need to copy to the generated code.
- for name := range impl.is {
- if !g.imports.markUsed(name) {
- panic(fmt.Sprintf("Generated code for '%s' referenced a non-existent import with local name '%s'. Either go-marshal needs to add an import to the generated file, or a package in an input source file has a package name differ from the final component of its path, which go-marshal doesn't know how to detect; use an import alias to work around this limitation.", impl.typeName(), name))
- }
- }
- // Do not generate tests for dynamic types because they inherently
- // violate some go_marshal requirements.
- if !t.dynamic {
- ts = append(ts, g.generateOneTestSuite(t))
- }
- }
- }
-
- // Write output file header. These include things like package name and
- // import statements.
- if err := g.writeHeader(); err != nil {
- return err
- }
-
- // Write type checks for referenced marshallable types to output file.
- if err := g.writeTypeChecks(ms); err != nil {
- return err
- }
-
- // Write generated interfaces to output file.
- for _, i := range impls {
- if err := i.write(g.output); err != nil {
- return err
- }
- }
-
- // Write generated tests to test file.
- return g.writeTests(ts)
-}
-
-// writeTests outputs tests for the generated interface implementations to a go
-// source file.
-func (g *Generator) writeTests(ts []*testGenerator) error {
- var b sourceBuffer
-
- // Write the unconditional test file. This file is always compiled,
- // regardless of what build tags were specified on the original input
- // files. We use this file to guarantee we never end up with an empty test
- // file, as that causes the build to fail with "no tests/benchmarks/examples
- // found".
- //
- // There's no easy way to determine ahead of time if we'll end up with an
- // empty build file since build constraints can arbitrarily cause some of
- // the original types to be not defined. We also have no way to tell bazel
- // to omit the entire test suite since the output files are already defined
- // before go-marshal is called.
- b.emit("// Automatically generated marshal tests. See tools/go_marshal.\n\n")
- b.emit("package %s\n\n", g.pkg)
- b.emit("func Example() {\n")
- b.inIndent(func() {
- b.emit("// This example is intentionally empty, and ensures this package contains at\n")
- b.emit("// least one testable entity. go-marshal is forced to emit a test package if the\n")
- b.emit("// input package is marked marshallable, but emitting no testable entities \n")
- b.emit("// results in a build failure.\n")
- })
- b.emit("}\n")
- if err := b.write(g.outputTestUC); err != nil {
- return err
- }
-
- // Now generate the real test file that contains the real types we
- // processed. These need to be conditionally compiled according to the build
- // tags, as the original types may not be defined under all build
- // configurations.
-
- b.reset()
- b.emit("// Automatically generated marshal tests. See tools/go_marshal.\n\n")
-
- // Emit build constraints.
- bcexpr, err := constraintutil.CombineFromFiles(g.inputs)
- if err != nil {
- return err
- }
- b.emit(constraintutil.Lines(bcexpr))
-
- b.emit("package %s\n\n", g.pkg)
- if err := b.write(g.outputTest); err != nil {
- return err
- }
-
- // Collect and write test import statements.
- imports := newImportTable()
- for _, t := range ts {
- imports.merge(t.imports)
- }
-
- if err := imports.write(g.outputTest); err != nil {
- return err
- }
-
- // Write test functions.
- for _, t := range ts {
- if err := t.write(g.outputTest); err != nil {
- return err
- }
- }
- return nil
-}
diff --git a/tools/go_marshal/gomarshal/generator_interfaces.go b/tools/go_marshal/gomarshal/generator_interfaces.go
deleted file mode 100644
index 3e643e77f..000000000
--- a/tools/go_marshal/gomarshal/generator_interfaces.go
+++ /dev/null
@@ -1,276 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package gomarshal
-
-import (
- "fmt"
- "go/ast"
- "go/token"
- "strings"
-)
-
-// interfaceGenerator generates marshalling interfaces for a single type.
-//
-// getState is not thread-safe.
-type interfaceGenerator struct {
- sourceBuffer
-
- // The type we're serializing.
- t *ast.TypeSpec
-
- // Receiver argument for generated methods.
- r string
-
- // FileSet containing the tokens for the type we're processing.
- f *token.FileSet
-
- // is records external packages referenced by the generated implementation.
- is map[string]struct{}
-
- // ms records Marshallable types referenced by the generated implementation
- // of t's interfaces.
- ms map[string]struct{}
-
- // as records fields in t that are potentially not packed. The key is the
- // accessor for the field.
- as map[string]struct{}
-}
-
-// typeName returns the name of the type this g represents.
-func (g *interfaceGenerator) typeName() string {
- return g.t.Name.Name
-}
-
-// newinterfaceGenerator creates a new interface generator.
-func newInterfaceGenerator(t *ast.TypeSpec, r string, fset *token.FileSet) *interfaceGenerator {
- g := &interfaceGenerator{
- t: t,
- r: r,
- f: fset,
- is: make(map[string]struct{}),
- ms: make(map[string]struct{}),
- as: make(map[string]struct{}),
- }
- g.recordUsedMarshallable(g.typeName())
- return g
-}
-
-func (g *interfaceGenerator) recordUsedMarshallable(m string) {
- g.ms[m] = struct{}{}
-
-}
-
-func (g *interfaceGenerator) recordUsedImport(i string) {
- g.is[i] = struct{}{}
-}
-
-func (g *interfaceGenerator) recordPotentiallyNonPackedField(fieldName string) {
- g.as[fieldName] = struct{}{}
-}
-
-// abortAt aborts the go_marshal tool with the given error message, with a
-// reference position to the input source. Same as abortAt, but uses g to
-// resolve p to position.
-func (g *interfaceGenerator) abortAt(p token.Pos, msg string) {
- abortAt(g.f.Position(p), msg)
-}
-
-// scalarSize returns the size of type identified by t. If t isn't a primitive
-// type, the size isn't known at code generation time, and must be resolved via
-// the marshal.Marshallable interface.
-func (g *interfaceGenerator) scalarSize(t *ast.Ident) (size int, unknownSize bool) {
- switch t.Name {
- case "int8", "uint8", "byte":
- return 1, false
- case "int16", "uint16":
- return 2, false
- case "int32", "uint32":
- return 4, false
- case "int64", "uint64":
- return 8, false
- default:
- return 0, true
- }
-}
-
-func (g *interfaceGenerator) shift(bufVar string, n int) {
- g.emit("%s = %s[%d:]\n", bufVar, bufVar, n)
-}
-
-func (g *interfaceGenerator) shiftDynamic(bufVar, name string) {
- g.emit("%s = %s[%s.SizeBytes():]\n", bufVar, bufVar, name)
-}
-
-// marshalScalar writes a single scalar to a byte slice.
-func (g *interfaceGenerator) marshalScalar(accessor, typ, bufVar string) {
- switch typ {
- case "int8", "uint8", "byte":
- g.emit("%s[0] = byte(%s)\n", bufVar, accessor)
- g.shift(bufVar, 1)
- case "int16", "uint16":
- g.recordUsedImport("hostarch")
- g.emit("hostarch.ByteOrder.PutUint16(%s[:2], uint16(%s))\n", bufVar, accessor)
- g.shift(bufVar, 2)
- case "int32", "uint32":
- g.recordUsedImport("hostarch")
- g.emit("hostarch.ByteOrder.PutUint32(%s[:4], uint32(%s))\n", bufVar, accessor)
- g.shift(bufVar, 4)
- case "int64", "uint64":
- g.recordUsedImport("hostarch")
- g.emit("hostarch.ByteOrder.PutUint64(%s[:8], uint64(%s))\n", bufVar, accessor)
- g.shift(bufVar, 8)
- default:
- g.emit("%s.MarshalBytes(%s[:%s.SizeBytes()])\n", accessor, bufVar, accessor)
- g.shiftDynamic(bufVar, accessor)
- }
-}
-
-// unmarshalScalar reads a single scalar from a byte slice.
-func (g *interfaceGenerator) unmarshalScalar(accessor, typ, bufVar string) {
- switch typ {
- case "byte":
- g.emit("%s = %s[0]\n", accessor, bufVar)
- g.shift(bufVar, 1)
- case "int8", "uint8":
- g.emit("%s = %s(%s[0])\n", accessor, typ, bufVar)
- g.shift(bufVar, 1)
- case "int16", "uint16":
- g.recordUsedImport("hostarch")
- g.emit("%s = %s(hostarch.ByteOrder.Uint16(%s[:2]))\n", accessor, typ, bufVar)
- g.shift(bufVar, 2)
- case "int32", "uint32":
- g.recordUsedImport("hostarch")
- g.emit("%s = %s(hostarch.ByteOrder.Uint32(%s[:4]))\n", accessor, typ, bufVar)
- g.shift(bufVar, 4)
- case "int64", "uint64":
- g.recordUsedImport("hostarch")
- g.emit("%s = %s(hostarch.ByteOrder.Uint64(%s[:8]))\n", accessor, typ, bufVar)
- g.shift(bufVar, 8)
- default:
- g.emit("%s.UnmarshalBytes(%s[:%s.SizeBytes()])\n", accessor, bufVar, accessor)
- g.shiftDynamic(bufVar, accessor)
- g.recordPotentiallyNonPackedField(accessor)
- }
-}
-
-// emitCastToByteSlice unsafely casts an arbitrary type's underlying memory to a
-// byte slice, bypassing escape analysis. The caller is responsible for ensuring
-// srcPtr lives until they're done with dstVar, the runtime does not consider
-// dstVar dependent on srcPtr due to the escape analysis bypass.
-//
-// srcPtr must be a pointer.
-//
-// This function uses internally uses the identifier "hdr", and cannot be used
-// in a context where it is already bound.
-func (g *interfaceGenerator) emitCastToByteSlice(srcPtr, dstVar, lenExpr string) {
- g.recordUsedImport("gohacks")
- g.emit("// Construct a slice backed by dst's underlying memory.\n")
- g.emit("var %s []byte\n", dstVar)
- g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&%s))\n", dstVar)
- g.emit("hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(%s)))\n", srcPtr)
- g.emit("hdr.Len = %s\n", lenExpr)
- g.emit("hdr.Cap = %s\n\n", lenExpr)
-}
-
-// emitCastToByteSlice unsafely casts a slice with elements of an abitrary type
-// to a byte slice. As part of the cast, the byte slice is made to look
-// independent of the src slice by bypassing escape analysis. This means the
-// byte slice can be used without causing the source to escape. The caller is
-// responsible for ensuring srcPtr lives until they're done with dstVar, as the
-// runtime no longer considers dstVar dependent on srcPtr and is free to GC it.
-//
-// srcPtr must be a pointer.
-//
-// This function uses internally uses the identifiers "ptr", "val" and "hdr",
-// and cannot be used in a context where these identifiers are already bound.
-func (g *interfaceGenerator) emitCastSliceToByteSlice(srcPtr, dstVar, lenExpr string) {
- g.emitNoEscapeSliceDataPointer(srcPtr, "val")
-
- g.emit("// Construct a slice backed by dst's underlying memory.\n")
- g.emit("var %s []byte\n", dstVar)
- g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&%s))\n", dstVar)
- g.emit("hdr.Data = uintptr(val)\n")
- g.emit("hdr.Len = %s\n", lenExpr)
- g.emit("hdr.Cap = %s\n\n", lenExpr)
-}
-
-// emitNoEscapeSliceDataPointer unsafely casts a slice's data pointer to an
-// unsafe.Pointer, bypassing escape analysis. The caller is responsible for
-// ensuring srcPtr lives until they're done with dstVar, as the runtime no
-// longer considers dstVar dependent on srcPtr and is free to GC it.
-//
-// srcPtr must be a pointer.
-//
-// This function uses internally uses the identifier "ptr" cannot be used in a
-// context where this identifier is already bound.
-func (g *interfaceGenerator) emitNoEscapeSliceDataPointer(srcPtr, dstVar string) {
- g.recordUsedImport("gohacks")
- g.emit("ptr := unsafe.Pointer(%s)\n", srcPtr)
- g.emit("%s := gohacks.Noescape(unsafe.Pointer((*reflect.SliceHeader)(ptr).Data))\n\n", dstVar)
-}
-
-func (g *interfaceGenerator) emitKeepAlive(ptrVar string) {
- g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", ptrVar)
- g.emit("// must live until the use above.\n")
- g.emit("runtime.KeepAlive(%s) // escapes: replaced by intrinsic.\n", ptrVar)
-}
-
-func (g *interfaceGenerator) expandBinaryExpr(b *strings.Builder, e *ast.BinaryExpr) {
- switch x := e.X.(type) {
- case *ast.BinaryExpr:
- // Recursively expand sub-expression.
- g.expandBinaryExpr(b, x)
- case *ast.Ident:
- fmt.Fprintf(b, "%s", x.Name)
- case *ast.BasicLit:
- fmt.Fprintf(b, "%s", x.Value)
- default:
- g.abortAt(e.Pos(), "Cannot convert binary expression to output code. Go-marshal currently only handles simple expressions of literals, constants and basic identifiers")
- }
-
- fmt.Fprintf(b, "%s", e.Op)
-
- switch y := e.Y.(type) {
- case *ast.BinaryExpr:
- // Recursively expand sub-expression.
- g.expandBinaryExpr(b, y)
- case *ast.Ident:
- fmt.Fprintf(b, "%s", y.Name)
- case *ast.BasicLit:
- fmt.Fprintf(b, "%s", y.Value)
- default:
- g.abortAt(e.Pos(), "Cannot convert binary expression to output code. Go-marshal currently only handles simple expressions of literals, constants and basic identifiers")
- }
-}
-
-// arrayLenExpr returns a string containing a valid golang expression
-// representing the length of array a. The returned expression should be treated
-// as a single value, and will be already parenthesized as required.
-func (g *interfaceGenerator) arrayLenExpr(a *ast.ArrayType) string {
- var b strings.Builder
-
- switch l := a.Len.(type) {
- case *ast.Ident:
- fmt.Fprintf(&b, "%s", l.Name)
- case *ast.BasicLit:
- fmt.Fprintf(&b, "%s", l.Value)
- case *ast.BinaryExpr:
- g.expandBinaryExpr(&b, l)
- return fmt.Sprintf("(%s)", b.String())
- default:
- g.abortAt(l.Pos(), "Cannot convert this array len expression to output code. Go-marshal currently only handles simple expressions of literals, constants and basic identifiers")
- }
- return b.String()
-}
diff --git a/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go b/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go
deleted file mode 100644
index bd7741ae5..000000000
--- a/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go
+++ /dev/null
@@ -1,146 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// This file contains the bits of the code generator specific to marshalling
-// newtypes on arrays.
-
-package gomarshal
-
-import (
- "fmt"
- "go/ast"
-)
-
-func (g *interfaceGenerator) validateArrayNewtype(n *ast.Ident, a *ast.ArrayType) {
- if a.Len == nil {
- g.abortAt(a.Pos(), fmt.Sprintf("Dynamically sized slice '%s' cannot be marshalled, arrays must be statically sized", n.Name))
- }
-
- if _, ok := a.Elt.(*ast.Ident); !ok {
- g.abortAt(a.Elt.Pos(), fmt.Sprintf("Marshalling not supported for arrays with %s elements, array elements must be primitive types", kindString(a.Elt)))
- }
-}
-
-func (g *interfaceGenerator) emitMarshallableForArrayNewtype(n *ast.Ident, a *ast.ArrayType, elt *ast.Ident) {
- g.recordUsedImport("gohacks")
- g.recordUsedImport("hostarch")
- g.recordUsedImport("io")
- g.recordUsedImport("marshal")
- g.recordUsedImport("reflect")
- g.recordUsedImport("runtime")
- g.recordUsedImport("unsafe")
-
- lenExpr := g.arrayLenExpr(a)
-
- g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n")
- g.emit("//go:nosplit\n")
- g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName())
- g.inIndent(func() {
- if size, dynamic := g.scalarSize(elt); !dynamic {
- g.emit("return %d * %s\n", size, lenExpr)
- } else {
- g.emit("return (*%s)(nil).SizeBytes() * %s\n", n.Name, lenExpr)
- }
- })
- g.emit("}\n\n")
-
- g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n")
- g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName())
- g.inIndent(func() {
- g.emit("for idx := 0; idx < %s; idx++ {\n", lenExpr)
- g.inIndent(func() {
- g.marshalScalar(fmt.Sprintf("%s[idx]", g.r), elt.Name, "dst")
- })
- g.emit("}\n")
- })
- g.emit("}\n\n")
-
- g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n")
- g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName())
- g.inIndent(func() {
- g.emit("for idx := 0; idx < %s; idx++ {\n", lenExpr)
- g.inIndent(func() {
- g.unmarshalScalar(fmt.Sprintf("%s[idx]", g.r), elt.Name, "src")
- })
- g.emit("}\n")
- })
- g.emit("}\n\n")
-
- g.emit("// Packed implements marshal.Marshallable.Packed.\n")
- g.emit("//go:nosplit\n")
- g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName())
- g.inIndent(func() {
- g.emit("// Array newtypes are always packed.\n")
- g.emit("return true\n")
- })
- g.emit("}\n\n")
-
- g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n")
- g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName())
- g.inIndent(func() {
- g.emit("gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(&%s[0]), uintptr(%s.SizeBytes()))\n", g.r, g.r)
- })
- g.emit("}\n\n")
-
- g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n")
- g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName())
- g.inIndent(func() {
- g.emit("gohacks.Memmove(unsafe.Pointer(%s), unsafe.Pointer(&src[0]), uintptr(%s.SizeBytes()))\n", g.r, g.r)
- })
- g.emit("}\n\n")
-
- g.emit("// CopyOutN implements marshal.Marshallable.CopyOutN.\n")
- g.emit("//go:nosplit\n")
- g.emit("func (%s *%s) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) {\n", g.r, g.typeName())
- g.inIndent(func() {
- g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r))
-
- g.emit("length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay.\n")
- g.emitKeepAlive(g.r)
- g.emit("return length, err\n")
- })
- g.emit("}\n\n")
-
- g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n")
- g.emit("//go:nosplit\n")
- g.emit("func (%s *%s) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) {\n", g.r, g.typeName())
- g.inIndent(func() {
- g.emit("return %s.CopyOutN(cc, addr, %s.SizeBytes())\n", g.r, g.r)
- })
- g.emit("}\n\n")
-
- g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n")
- g.emit("//go:nosplit\n")
- g.emit("func (%s *%s) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) {\n", g.r, g.typeName())
- g.inIndent(func() {
- g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r))
-
- g.emit("length, err := cc.CopyInBytes(addr, buf) // escapes: okay.\n")
- g.emitKeepAlive(g.r)
- g.emit("return length, err\n")
- })
- g.emit("}\n\n")
-
- g.emit("// WriteTo implements io.WriterTo.WriteTo.\n")
- g.emit("func (%s *%s) WriteTo(w io.Writer) (int64, error) {\n", g.r, g.typeName())
- g.inIndent(func() {
- g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r))
-
- g.emit("length, err := w.Write(buf)\n")
- g.emitKeepAlive(g.r)
- g.emit("return int64(length), err\n")
-
- })
- g.emit("}\n\n")
-}
diff --git a/tools/go_marshal/gomarshal/generator_interfaces_dynamic.go b/tools/go_marshal/gomarshal/generator_interfaces_dynamic.go
deleted file mode 100644
index 345020ddc..000000000
--- a/tools/go_marshal/gomarshal/generator_interfaces_dynamic.go
+++ /dev/null
@@ -1,96 +0,0 @@
-// Copyright 2021 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package gomarshal
-
-func (g *interfaceGenerator) emitMarshallableForDynamicType() {
- // The user writes their own MarshalBytes, UnmarshalBytes and SizeBytes for
- // dynamic types. Generate the rest using these definitions.
-
- g.emit("// Packed implements marshal.Marshallable.Packed.\n")
- g.emit("//go:nosplit\n")
- g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName())
- g.inIndent(func() {
- g.emit("// Type %s is dynamic so it might have slice/string headers. Hence, it is not packed.\n", g.typeName())
- g.emit("return false\n")
- })
- g.emit("}\n\n")
-
- g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n")
- g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName())
- g.inIndent(func() {
- g.emit("// Type %s doesn't have a packed layout in memory, fallback to MarshalBytes.\n", g.typeName())
- g.emit("%s.MarshalBytes(dst)\n", g.r)
- })
- g.emit("}\n\n")
-
- g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n")
- g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName())
- g.inIndent(func() {
- g.emit("// Type %s doesn't have a packed layout in memory, fallback to UnmarshalBytes.\n", g.typeName())
- g.emit("%s.UnmarshalBytes(src)\n", g.r)
- })
- g.emit("}\n\n")
-
- g.emit("// CopyOutN implements marshal.Marshallable.CopyOutN.\n")
- g.emit("//go:nosplit\n")
- g.recordUsedImport("marshal")
- g.recordUsedImport("hostarch")
- g.emit("func (%s *%s) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) {\n", g.r, g.typeName())
- g.inIndent(func() {
- g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName())
- g.emit("buf := cc.CopyScratchBuffer(%s.SizeBytes()) // escapes: okay.\n", g.r)
- g.emit("%s.MarshalBytes(buf) // escapes: fallback.\n", g.r)
- g.emit("return cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay.\n")
- })
- g.emit("}\n\n")
-
- g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n")
- g.emit("//go:nosplit\n")
- g.recordUsedImport("marshal")
- g.recordUsedImport("hostarch")
- g.emit("func (%s *%s) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) {\n", g.r, g.typeName())
- g.inIndent(func() {
- g.emit("return %s.CopyOutN(cc, addr, %s.SizeBytes())\n", g.r, g.r)
- })
- g.emit("}\n\n")
-
- g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n")
- g.emit("//go:nosplit\n")
- g.recordUsedImport("marshal")
- g.recordUsedImport("hostarch")
- g.emit("func (%s *%s) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) {\n", g.r, g.typeName())
- g.inIndent(func() {
- g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName())
- g.emit("buf := cc.CopyScratchBuffer(%s.SizeBytes()) // escapes: okay.\n", g.r)
- g.emit("length, err := cc.CopyInBytes(addr, buf) // escapes: okay.\n")
- g.emit("// Unmarshal unconditionally. If we had a short copy-in, this results in a\n")
- g.emit("// partially unmarshalled struct.\n")
- g.emit("%s.UnmarshalBytes(buf) // escapes: fallback.\n", g.r)
- g.emit("return length, err\n")
- })
- g.emit("}\n\n")
-
- g.emit("// WriteTo implements io.WriterTo.WriteTo.\n")
- g.recordUsedImport("io")
- g.emit("func (%s *%s) WriteTo(writer io.Writer) (int64, error) {\n", g.r, g.typeName())
- g.inIndent(func() {
- g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName())
- g.emit("buf := make([]byte, %s.SizeBytes())\n", g.r)
- g.emit("%s.MarshalBytes(buf)\n", g.r)
- g.emit("length, err := writer.Write(buf)\n")
- g.emit("return int64(length), err\n")
- })
- g.emit("}\n\n")
-}
diff --git a/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go b/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go
deleted file mode 100644
index ba4b7324e..000000000
--- a/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go
+++ /dev/null
@@ -1,285 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// This file contains the bits of the code generator specific to marshalling
-// newtypes on primitives.
-
-package gomarshal
-
-import (
- "fmt"
- "go/ast"
-)
-
-// marshalPrimitiveScalar writes a single primitive variable to a byte
-// slice.
-func (g *interfaceGenerator) marshalPrimitiveScalar(accessor, typ, bufVar string) {
- switch typ {
- case "int8", "uint8", "byte":
- g.emit("%s[0] = byte(*%s)\n", bufVar, accessor)
- case "int16", "uint16":
- g.recordUsedImport("hostarch")
- g.emit("hostarch.ByteOrder.PutUint16(%s[:2], uint16(*%s))\n", bufVar, accessor)
- case "int32", "uint32":
- g.recordUsedImport("hostarch")
- g.emit("hostarch.ByteOrder.PutUint32(%s[:4], uint32(*%s))\n", bufVar, accessor)
- case "int64", "uint64":
- g.recordUsedImport("hostarch")
- g.emit("hostarch.ByteOrder.PutUint64(%s[:8], uint64(*%s))\n", bufVar, accessor)
- default:
- g.emit("// Explicilty cast to the underlying type before dispatching to\n")
- g.emit("// MarshalBytes, so we don't recursively call %s.MarshalBytes\n", accessor)
- g.emit("inner := (*%s)(%s)\n", typ, accessor)
- g.emit("inner.MarshalBytes(%s[:%s.SizeBytes()])\n", bufVar, accessor)
- }
-}
-
-// unmarshalPrimitiveScalar read a single primitive variable from a byte slice.
-func (g *interfaceGenerator) unmarshalPrimitiveScalar(accessor, typ, bufVar, typeCast string) {
- switch typ {
- case "byte":
- g.emit("*%s = %s(%s[0])\n", accessor, typeCast, bufVar)
- case "int8", "uint8":
- g.emit("*%s = %s(%s(%s[0]))\n", accessor, typeCast, typ, bufVar)
- case "int16", "uint16":
- g.recordUsedImport("hostarch")
- g.emit("*%s = %s(%s(hostarch.ByteOrder.Uint16(%s[:2])))\n", accessor, typeCast, typ, bufVar)
- case "int32", "uint32":
- g.recordUsedImport("hostarch")
- g.emit("*%s = %s(%s(hostarch.ByteOrder.Uint32(%s[:4])))\n", accessor, typeCast, typ, bufVar)
- case "int64", "uint64":
- g.recordUsedImport("hostarch")
- g.emit("*%s = %s(%s(hostarch.ByteOrder.Uint64(%s[:8])))\n", accessor, typeCast, typ, bufVar)
- default:
- g.emit("// Explicilty cast to the underlying type before dispatching to\n")
- g.emit("// UnmarshalBytes, so we don't recursively call %s.UnmarshalBytes\n", accessor)
- g.emit("inner := (*%s)(%s)\n", typ, accessor)
- g.emit("inner.UnmarshalBytes(%s[:%s.SizeBytes()])\n", bufVar, accessor)
- }
-}
-
-func (g *interfaceGenerator) validatePrimitiveNewtype(t *ast.Ident) {
- switch t.Name {
- case "int8", "uint8", "byte", "int16", "uint16", "int32", "uint32", "int64", "uint64":
- // These are the only primitive types we're allow. Below, we provide
- // suggestions for some disallowed types and reject them, then attempt
- // to marshal any remaining types by invoking the marshal.Marshallable
- // interface on them. If these types don't actually implement
- // marshal.Marshallable, compilation of the generated code will fail
- // with an appropriate error message.
- return
- case "int":
- g.abortAt(t.Pos(), "Type 'int' has ambiguous width, use int32 or int64")
- case "uint":
- g.abortAt(t.Pos(), "Type 'uint' has ambiguous width, use uint32 or uint64")
- case "string":
- g.abortAt(t.Pos(), "Type 'string' is dynamically-sized and cannot be marshalled, use a fixed size byte array '[...]byte' instead")
- default:
- debugfAt(g.f.Position(t.Pos()), fmt.Sprintf("Found derived type '%s', will attempt dispatch via marshal.Marshallable.\n", t.Name))
- }
-}
-
-// emitMarshallableForPrimitiveNewtype outputs code to implement the
-// marshal.Marshallable interface for a newtype on a primitive. Primitive
-// newtypes are always packed, so we can omit the various fallbacks required for
-// non-packed structs.
-func (g *interfaceGenerator) emitMarshallableForPrimitiveNewtype(nt *ast.Ident) {
- g.recordUsedImport("gohacks")
- g.recordUsedImport("hostarch")
- g.recordUsedImport("io")
- g.recordUsedImport("marshal")
- g.recordUsedImport("reflect")
- g.recordUsedImport("runtime")
- g.recordUsedImport("unsafe")
-
- g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n")
- g.emit("//go:nosplit\n")
- g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName())
- g.inIndent(func() {
- if size, dynamic := g.scalarSize(nt); !dynamic {
- g.emit("return %d\n", size)
- } else {
- g.emit("return (*%s)(nil).SizeBytes()\n", nt.Name)
- }
- })
- g.emit("}\n\n")
-
- g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n")
- g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName())
- g.inIndent(func() {
- g.marshalPrimitiveScalar(g.r, nt.Name, "dst")
- })
- g.emit("}\n\n")
-
- g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n")
- g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName())
- g.inIndent(func() {
- g.unmarshalPrimitiveScalar(g.r, nt.Name, "src", g.typeName())
- })
- g.emit("}\n\n")
-
- g.emit("// Packed implements marshal.Marshallable.Packed.\n")
- g.emit("//go:nosplit\n")
- g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName())
- g.inIndent(func() {
- g.emit("// Scalar newtypes are always packed.\n")
- g.emit("return true\n")
- })
- g.emit("}\n\n")
-
- g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n")
- g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName())
- g.inIndent(func() {
- g.emit("gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(%s), uintptr(%s.SizeBytes()))\n", g.r, g.r)
- })
- g.emit("}\n\n")
-
- g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n")
- g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName())
- g.inIndent(func() {
- g.emit("gohacks.Memmove(unsafe.Pointer(%s), unsafe.Pointer(&src[0]), uintptr(%s.SizeBytes()))\n", g.r, g.r)
- })
- g.emit("}\n\n")
-
- g.emit("// CopyOutN implements marshal.Marshallable.CopyOutN.\n")
- g.emit("//go:nosplit\n")
- g.emit("func (%s *%s) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) {\n", g.r, g.typeName())
- g.inIndent(func() {
- g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r))
-
- g.emit("length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay.\n")
- g.emitKeepAlive(g.r)
- g.emit("return length, err\n")
- })
- g.emit("}\n\n")
-
- g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n")
- g.emit("//go:nosplit\n")
- g.emit("func (%s *%s) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) {\n", g.r, g.typeName())
- g.inIndent(func() {
- g.emit("return %s.CopyOutN(cc, addr, %s.SizeBytes())\n", g.r, g.r)
- })
- g.emit("}\n\n")
-
- g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n")
- g.emit("//go:nosplit\n")
- g.emit("func (%s *%s) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) {\n", g.r, g.typeName())
- g.inIndent(func() {
- g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r))
-
- g.emit("length, err := cc.CopyInBytes(addr, buf) // escapes: okay.\n")
- g.emitKeepAlive(g.r)
- g.emit("return length, err\n")
- })
- g.emit("}\n\n")
-
- g.emit("// WriteTo implements io.WriterTo.WriteTo.\n")
- g.emit("func (%s *%s) WriteTo(w io.Writer) (int64, error) {\n", g.r, g.typeName())
- g.inIndent(func() {
- g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r))
-
- g.emit("length, err := w.Write(buf)\n")
- g.emitKeepAlive(g.r)
- g.emit("return int64(length), err\n")
-
- })
- g.emit("}\n\n")
-}
-
-func (g *interfaceGenerator) emitMarshallableSliceForPrimitiveNewtype(nt *ast.Ident, slice *sliceAPI) {
- g.recordUsedImport("marshal")
- g.recordUsedImport("hostarch")
- g.recordUsedImport("reflect")
- g.recordUsedImport("runtime")
- g.recordUsedImport("unsafe")
-
- eltType := g.typeName()
- if slice.inner {
- eltType = nt.Name
- }
-
- g.emit("// Copy%sIn copies in a slice of %s objects from the task's memory.\n", slice.ident, eltType)
- g.emit("//go:nosplit\n")
- g.emit("func Copy%sIn(cc marshal.CopyContext, addr hostarch.Addr, dst []%s) (int, error) {\n", slice.ident, eltType)
- g.inIndent(func() {
- g.emit("count := len(dst)\n")
- g.emit("if count == 0 {\n")
- g.inIndent(func() {
- g.emit("return 0, nil\n")
- })
- g.emit("}\n")
- g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName())
-
- g.emitCastSliceToByteSlice("&dst", "buf", "size * count")
-
- g.emit("length, err := cc.CopyInBytes(addr, buf) // escapes: okay.\n")
- g.emitKeepAlive("dst")
- g.emit("return length, err\n")
- })
- g.emit("}\n\n")
-
- g.emit("// Copy%sOut copies a slice of %s objects to the task's memory.\n", slice.ident, eltType)
- g.emit("//go:nosplit\n")
- g.emit("func Copy%sOut(cc marshal.CopyContext, addr hostarch.Addr, src []%s) (int, error) {\n", slice.ident, eltType)
- g.inIndent(func() {
- g.emit("count := len(src)\n")
- g.emit("if count == 0 {\n")
- g.inIndent(func() {
- g.emit("return 0, nil\n")
- })
- g.emit("}\n")
- g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName())
-
- g.emitCastSliceToByteSlice("&src", "buf", "size * count")
-
- g.emit("length, err := cc.CopyOutBytes(addr, buf) // escapes: okay.\n")
- g.emitKeepAlive("src")
- g.emit("return length, err\n")
- })
- g.emit("}\n\n")
-
- g.emit("// MarshalUnsafe%s is like %s.MarshalUnsafe, but for a []%s.\n", slice.ident, g.typeName(), g.typeName())
- g.emit("func MarshalUnsafe%s(src []%s, dst []byte) (int, error) {\n", slice.ident, g.typeName())
- g.inIndent(func() {
- g.emit("count := len(src)\n")
- g.emit("if count == 0 {\n")
- g.inIndent(func() {
- g.emit("return 0, nil\n")
- })
- g.emit("}\n")
- g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName())
-
- g.emit("dst = dst[:size*count]\n")
- g.emit("gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(&src[0]), uintptr(len(dst)))\n")
- g.emit("return size*count, nil\n")
- })
- g.emit("}\n\n")
-
- g.emit("// UnmarshalUnsafe%s is like %s.UnmarshalUnsafe, but for a []%s.\n", slice.ident, g.typeName(), g.typeName())
- g.emit("func UnmarshalUnsafe%s(dst []%s, src []byte) (int, error) {\n", slice.ident, g.typeName())
- g.inIndent(func() {
- g.emit("count := len(dst)\n")
- g.emit("if count == 0 {\n")
- g.inIndent(func() {
- g.emit("return 0, nil\n")
- })
- g.emit("}\n")
- g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName())
-
- g.emit("src = src[:(size*count)]\n")
- g.emit("gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(&src[0]), uintptr(len(src)))\n")
- g.emit("return size*count, nil\n")
- })
- g.emit("}\n\n")
-}
diff --git a/tools/go_marshal/gomarshal/generator_interfaces_struct.go b/tools/go_marshal/gomarshal/generator_interfaces_struct.go
deleted file mode 100644
index 4c47218f1..000000000
--- a/tools/go_marshal/gomarshal/generator_interfaces_struct.go
+++ /dev/null
@@ -1,616 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// This file contains the bits of the code generator specific to marshalling
-// structs.
-
-package gomarshal
-
-import (
- "fmt"
- "go/ast"
- "sort"
- "strings"
-)
-
-func (g *interfaceGenerator) fieldAccessor(n *ast.Ident) string {
- return fmt.Sprintf("%s.%s", g.r, n.Name)
-}
-
-// areFieldsPackedExpression returns a go expression checking whether g.t's fields are
-// packed. Returns "", false if g.t has no fields that may be potentially
-// packed, otherwise returns <clause>, true, where <clause> is an expression
-// like "t.a.Packed() && t.b.Packed() && t.c.Packed()".
-func (g *interfaceGenerator) areFieldsPackedExpression() (string, bool) {
- if len(g.as) == 0 {
- return "", false
- }
-
- cs := make([]string, 0, len(g.as))
- for accessor := range g.as {
- cs = append(cs, fmt.Sprintf("%s.Packed()", accessor))
- }
- // Sort expressions for determinstic build outputs.
- sort.Strings(cs)
- return strings.Join(cs, " && "), true
-}
-
-// validateStruct ensures the type we're working with can be marshalled. These
-// checks are done ahead of time and in one place so we can make assumptions
-// later.
-func (g *interfaceGenerator) validateStruct(ts *ast.TypeSpec, st *ast.StructType) {
- forEachStructField(st, func(f *ast.Field) {
- fieldDispatcher{
- primitive: func(_, t *ast.Ident) {
- g.validatePrimitiveNewtype(t)
- },
- selector: func(_, _, _ *ast.Ident) {
- // No validation to perform on selector fields. However this
- // callback must still be provided.
- },
- array: func(n *ast.Ident, a *ast.ArrayType, _ *ast.Ident) {
- g.validateArrayNewtype(n, a)
- },
- unhandled: func(_ *ast.Ident) {
- g.abortAt(f.Pos(), fmt.Sprintf("Marshalling not supported for %s fields", kindString(f.Type)))
- },
- }.dispatch(f)
- })
-}
-
-func (g *interfaceGenerator) isStructPacked(st *ast.StructType) bool {
- packed := true
- forEachStructField(st, func(f *ast.Field) {
- if f.Tag != nil {
- if f.Tag.Value == "`marshal:\"unaligned\"`" {
- if packed {
- debugfAt(g.f.Position(g.t.Pos()),
- fmt.Sprintf("Marking type '%s' as not packed due to tag `marshal:\"unaligned\"`.\n", g.t.Name))
- packed = false
- }
- }
- }
- })
- return packed
-}
-
-func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) {
- thisPacked := g.isStructPacked(st)
-
- g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n")
- g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName())
- g.inIndent(func() {
- primitiveSize := 0
- var dynamicSizeTerms []string
-
- forEachStructField(st, fieldDispatcher{
- primitive: func(_, t *ast.Ident) {
- if size, dynamic := g.scalarSize(t); !dynamic {
- primitiveSize += size
- } else {
- g.recordUsedMarshallable(t.Name)
- dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", t.Name))
- }
- },
- selector: func(_, tX, tSel *ast.Ident) {
- tName := fmt.Sprintf("%s.%s", tX.Name, tSel.Name)
- g.recordUsedImport(tX.Name)
- g.recordUsedMarshallable(tName)
- dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", tName))
- },
- array: func(_ *ast.Ident, a *ast.ArrayType, t *ast.Ident) {
- lenExpr := g.arrayLenExpr(a)
- if size, dynamic := g.scalarSize(t); !dynamic {
- dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("%d*%s", size, lenExpr))
- } else {
- g.recordUsedMarshallable(t.Name)
- dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()*%s", t.Name, lenExpr))
- }
- },
- }.dispatch)
- g.emit("return %d", primitiveSize)
- if len(dynamicSizeTerms) > 0 {
- g.incIndent()
- }
- {
- for _, d := range dynamicSizeTerms {
- g.emitNoIndent(" +\n")
- g.emit(d)
- }
- }
- if len(dynamicSizeTerms) > 0 {
- g.decIndent()
- }
- })
- g.emit("\n}\n\n")
-
- g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n")
- g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName())
- g.inIndent(func() {
- forEachStructField(st, fieldDispatcher{
- primitive: func(n, t *ast.Ident) {
- if n.Name == "_" {
- g.emit("// Padding: dst[:sizeof(%s)] ~= %s(0)\n", t.Name, t.Name)
- if len, dynamic := g.scalarSize(t); !dynamic {
- g.shift("dst", len)
- } else {
- // We can't use shiftDynamic here because we don't have
- // an instance of the dynamic type we can reference here
- // (since the version in this struct is anonymous). Use
- // a typed nil pointer to call SizeBytes() instead.
- g.emit("dst = dst[(*%s)(nil).SizeBytes():]\n", t.Name)
- }
- return
- }
- g.marshalScalar(g.fieldAccessor(n), t.Name, "dst")
- },
- selector: func(n, tX, tSel *ast.Ident) {
- if n.Name == "_" {
- g.emit("// Padding: dst[:sizeof(%s)] ~= %s(0)\n", tX.Name, tSel.Name)
- g.emit("dst = dst[(*%s.%s)(nil).SizeBytes():]\n", tX.Name, tSel.Name)
- return
- }
- g.marshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "dst")
- },
- array: func(n *ast.Ident, a *ast.ArrayType, t *ast.Ident) {
- lenExpr := g.arrayLenExpr(a)
- if n.Name == "_" {
- g.emit("// Padding: dst[:sizeof(%s)*%s] ~= [%s]%s{0}\n", t.Name, lenExpr, lenExpr, t.Name)
- if size, dynamic := g.scalarSize(t); !dynamic {
- g.emit("dst = dst[%d*(%s):]\n", size, lenExpr)
- } else {
- // We can't use shiftDynamic here because we don't have
- // an instance of the dynamic type we can reference here
- // (since the version in this struct is anonymous). Use
- // a typed nil pointer to call SizeBytes() instead.
- g.emit("dst = dst[(*%s)(nil).SizeBytes()*(%s):]\n", t.Name, lenExpr)
- }
- return
- }
-
- g.emit("for idx := 0; idx < %s; idx++ {\n", lenExpr)
- g.inIndent(func() {
- g.marshalScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "dst")
- })
- g.emit("}\n")
- },
- }.dispatch)
- })
- g.emit("}\n\n")
-
- g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n")
- g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName())
- g.inIndent(func() {
- forEachStructField(st, fieldDispatcher{
- primitive: func(n, t *ast.Ident) {
- if n.Name == "_" {
- g.emit("// Padding: var _ %s ~= src[:sizeof(%s)]\n", t.Name, t.Name)
- if len, dynamic := g.scalarSize(t); !dynamic {
- g.shift("src", len)
- } else {
- // We don't have an instance of the dynamic type we can
- // reference here (since the version in this struct is
- // anonymous). Use a typed nil pointer to call
- // SizeBytes() instead.
- g.shiftDynamic("src", fmt.Sprintf("(*%s)(nil)", t.Name))
- g.recordPotentiallyNonPackedField(fmt.Sprintf("(*%s)(nil)", t.Name))
- }
- return
- }
- g.unmarshalScalar(g.fieldAccessor(n), t.Name, "src")
- },
- selector: func(n, tX, tSel *ast.Ident) {
- if n.Name == "_" {
- g.emit("// Padding: %s ~= src[:sizeof(%s.%s)]\n", g.fieldAccessor(n), tX.Name, tSel.Name)
- g.emit("src = src[(*%s.%s)(nil).SizeBytes():]\n", tX.Name, tSel.Name)
- g.recordPotentiallyNonPackedField(fmt.Sprintf("(*%s.%s)(nil)", tX.Name, tSel.Name))
- return
- }
- g.unmarshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "src")
- },
- array: func(n *ast.Ident, a *ast.ArrayType, t *ast.Ident) {
- lenExpr := g.arrayLenExpr(a)
- if n.Name == "_" {
- g.emit("// Padding: ~ copy([%s]%s(%s), src[:sizeof(%s)*%s])\n", lenExpr, t.Name, g.fieldAccessor(n), t.Name, lenExpr)
- if size, dynamic := g.scalarSize(t); !dynamic {
- g.emit("src = src[%d*(%s):]\n", size, lenExpr)
- } else {
- // We can't use shiftDynamic here because we don't have
- // an instance of the dynamic type we can referece here
- // (since the version in this struct is anonymous). Use
- // a typed nil pointer to call SizeBytes() instead.
- g.emit("src = src[(*%s)(nil).SizeBytes()*(%s):]\n", t.Name, lenExpr)
- }
- return
- }
-
- g.emit("for idx := 0; idx < %s; idx++ {\n", lenExpr)
- g.inIndent(func() {
- g.unmarshalScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "src")
- })
- g.emit("}\n")
- },
- }.dispatch)
- })
- g.emit("}\n\n")
-
- g.emit("// Packed implements marshal.Marshallable.Packed.\n")
- g.emit("//go:nosplit\n")
- g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName())
- g.inIndent(func() {
- expr, fieldsMaybePacked := g.areFieldsPackedExpression()
- switch {
- case !thisPacked:
- g.emit("return false\n")
- case fieldsMaybePacked:
- g.emit("return %s\n", expr)
- default:
- g.emit("return true\n")
-
- }
- })
- g.emit("}\n\n")
-
- g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n")
- g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName())
- g.inIndent(func() {
- fallback := func() {
- g.emit("// Type %s doesn't have a packed layout in memory, fallback to MarshalBytes.\n", g.typeName())
- g.emit("%s.MarshalBytes(dst)\n", g.r)
- }
- if thisPacked {
- g.recordUsedImport("gohacks")
- g.recordUsedImport("unsafe")
- if cond, ok := g.areFieldsPackedExpression(); ok {
- g.emit("if %s {\n", cond)
- g.inIndent(func() {
- g.emit("gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(%s), uintptr(%s.SizeBytes()))\n", g.r, g.r)
- })
- g.emit("} else {\n")
- g.inIndent(fallback)
- g.emit("}\n")
- } else {
- g.emit("gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(%s), uintptr(%s.SizeBytes()))\n", g.r, g.r)
- }
- } else {
- fallback()
- }
- })
- g.emit("}\n\n")
-
- g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n")
- g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName())
- g.inIndent(func() {
- fallback := func() {
- g.emit("// Type %s doesn't have a packed layout in memory, fallback to UnmarshalBytes.\n", g.typeName())
- g.emit("%s.UnmarshalBytes(src)\n", g.r)
- }
- if thisPacked {
- g.recordUsedImport("gohacks")
- if cond, ok := g.areFieldsPackedExpression(); ok {
- g.emit("if %s {\n", cond)
- g.inIndent(func() {
- g.emit("gohacks.Memmove(unsafe.Pointer(%s), unsafe.Pointer(&src[0]), uintptr(%s.SizeBytes()))\n", g.r, g.r)
- })
- g.emit("} else {\n")
- g.inIndent(fallback)
- g.emit("}\n")
- } else {
- g.emit("gohacks.Memmove(unsafe.Pointer(%s), unsafe.Pointer(&src[0]), uintptr(%s.SizeBytes()))\n", g.r, g.r)
- }
- } else {
- fallback()
- }
- })
- g.emit("}\n\n")
- g.emit("// CopyOutN implements marshal.Marshallable.CopyOutN.\n")
- g.emit("//go:nosplit\n")
- g.recordUsedImport("marshal")
- g.recordUsedImport("hostarch")
- g.emit("func (%s *%s) CopyOutN(cc marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) {\n", g.r, g.typeName())
- g.inIndent(func() {
- fallback := func() {
- g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName())
- g.emit("buf := cc.CopyScratchBuffer(%s.SizeBytes()) // escapes: okay.\n", g.r)
- g.emit("%s.MarshalBytes(buf) // escapes: fallback.\n", g.r)
- g.emit("return cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay.\n")
- }
- if thisPacked {
- g.recordUsedImport("reflect")
- g.recordUsedImport("runtime")
- g.recordUsedImport("unsafe")
- if cond, ok := g.areFieldsPackedExpression(); ok {
- g.emit("if !%s {\n", cond)
- g.inIndent(fallback)
- g.emit("}\n\n")
- }
- // Fast serialization.
- g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r))
-
- g.emit("length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay.\n")
- g.emitKeepAlive(g.r)
- g.emit("return length, err\n")
- } else {
- fallback()
- }
- })
- g.emit("}\n\n")
-
- g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n")
- g.emit("//go:nosplit\n")
- g.recordUsedImport("marshal")
- g.recordUsedImport("hostarch")
- g.emit("func (%s *%s) CopyOut(cc marshal.CopyContext, addr hostarch.Addr) (int, error) {\n", g.r, g.typeName())
- g.inIndent(func() {
- g.emit("return %s.CopyOutN(cc, addr, %s.SizeBytes())\n", g.r, g.r)
- })
- g.emit("}\n\n")
-
- g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n")
- g.emit("//go:nosplit\n")
- g.recordUsedImport("marshal")
- g.recordUsedImport("hostarch")
- g.emit("func (%s *%s) CopyIn(cc marshal.CopyContext, addr hostarch.Addr) (int, error) {\n", g.r, g.typeName())
- g.inIndent(func() {
- fallback := func() {
- g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName())
- g.emit("buf := cc.CopyScratchBuffer(%s.SizeBytes()) // escapes: okay.\n", g.r)
- g.emit("length, err := cc.CopyInBytes(addr, buf) // escapes: okay.\n")
- g.emit("// Unmarshal unconditionally. If we had a short copy-in, this results in a\n")
- g.emit("// partially unmarshalled struct.\n")
- g.emit("%s.UnmarshalBytes(buf) // escapes: fallback.\n", g.r)
- g.emit("return length, err\n")
- }
- if thisPacked {
- g.recordUsedImport("reflect")
- g.recordUsedImport("runtime")
- g.recordUsedImport("unsafe")
- if cond, ok := g.areFieldsPackedExpression(); ok {
- g.emit("if !%s {\n", cond)
- g.inIndent(fallback)
- g.emit("}\n\n")
- }
- // Fast deserialization.
- g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r))
-
- g.emit("length, err := cc.CopyInBytes(addr, buf) // escapes: okay.\n")
- g.emitKeepAlive(g.r)
- g.emit("return length, err\n")
- } else {
- fallback()
- }
- })
- g.emit("}\n\n")
-
- g.emit("// WriteTo implements io.WriterTo.WriteTo.\n")
- g.recordUsedImport("io")
- g.emit("func (%s *%s) WriteTo(writer io.Writer) (int64, error) {\n", g.r, g.typeName())
- g.inIndent(func() {
- fallback := func() {
- g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName())
- g.emit("buf := make([]byte, %s.SizeBytes())\n", g.r)
- g.emit("%s.MarshalBytes(buf)\n", g.r)
- g.emit("length, err := writer.Write(buf)\n")
- g.emit("return int64(length), err\n")
- }
- if thisPacked {
- g.recordUsedImport("reflect")
- g.recordUsedImport("runtime")
- g.recordUsedImport("unsafe")
- if cond, ok := g.areFieldsPackedExpression(); ok {
- g.emit("if !%s {\n", cond)
- g.inIndent(fallback)
- g.emit("}\n\n")
- }
- // Fast serialization.
- g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r))
-
- g.emit("length, err := writer.Write(buf)\n")
- g.emitKeepAlive(g.r)
- g.emit("return int64(length), err\n")
- } else {
- fallback()
- }
- })
- g.emit("}\n\n")
-}
-
-func (g *interfaceGenerator) emitMarshallableSliceForStruct(st *ast.StructType, slice *sliceAPI) {
- thisPacked := g.isStructPacked(st)
-
- if slice.inner {
- abortAt(g.f.Position(slice.comment.Slash), fmt.Sprintf("The ':inner' argument to '+marshal slice:%s:inner' is only applicable to newtypes on primitives. Remove it from this struct declaration.", slice.ident))
- }
-
- g.recordUsedImport("marshal")
- g.recordUsedImport("hostarch")
-
- g.emit("// Copy%sIn copies in a slice of %s objects from the task's memory.\n", slice.ident, g.typeName())
- g.emit("func Copy%sIn(cc marshal.CopyContext, addr hostarch.Addr, dst []%s) (int, error) {\n", slice.ident, g.typeName())
- g.inIndent(func() {
- g.emit("count := len(dst)\n")
- g.emit("if count == 0 {\n")
- g.inIndent(func() {
- g.emit("return 0, nil\n")
- })
- g.emit("}\n")
- g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName())
-
- fallback := func() {
- g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName())
- g.emit("buf := cc.CopyScratchBuffer(size * count)\n")
- g.emit("length, err := cc.CopyInBytes(addr, buf)\n\n")
-
- g.emit("// Unmarshal as much as possible, even on error. First handle full objects.\n")
- g.emit("limit := length/size\n")
- g.emit("for idx := 0; idx < limit; idx++ {\n")
- g.inIndent(func() {
- g.emit("dst[idx].UnmarshalBytes(buf[size*idx:size*(idx+1)])\n")
- })
- g.emit("}\n\n")
-
- g.emit("// Handle any final partial object. buf is guaranteed to be long enough for the\n")
- g.emit("// final element, but may not contain valid data for the entire range. This may\n")
- g.emit("// result in unmarshalling zero values for some parts of the object.\n")
- g.emit("if length%size != 0 {\n")
- g.inIndent(func() {
- g.emit("idx := limit\n")
- g.emit("dst[idx].UnmarshalBytes(buf[size*idx:size*(idx+1)])\n")
- })
- g.emit("}\n\n")
-
- g.emit("return length, err\n")
- }
- if thisPacked {
- g.recordUsedImport("reflect")
- g.recordUsedImport("runtime")
- g.recordUsedImport("unsafe")
- if _, ok := g.areFieldsPackedExpression(); ok {
- g.emit("if !dst[0].Packed() {\n")
- g.inIndent(fallback)
- g.emit("}\n\n")
- }
- // Fast deserialization.
- g.emitCastSliceToByteSlice("&dst", "buf", "size * count")
-
- g.emit("length, err := cc.CopyInBytes(addr, buf)\n")
- g.emitKeepAlive("dst")
- g.emit("return length, err\n")
- } else {
- fallback()
- }
- })
- g.emit("}\n\n")
-
- g.emit("// Copy%sOut copies a slice of %s objects to the task's memory.\n", slice.ident, g.typeName())
- g.emit("func Copy%sOut(cc marshal.CopyContext, addr hostarch.Addr, src []%s) (int, error) {\n", slice.ident, g.typeName())
- g.inIndent(func() {
- g.emit("count := len(src)\n")
- g.emit("if count == 0 {\n")
- g.inIndent(func() {
- g.emit("return 0, nil\n")
- })
- g.emit("}\n")
- g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName())
-
- fallback := func() {
- g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName())
- g.emit("buf := cc.CopyScratchBuffer(size * count)\n")
- g.emit("for idx := 0; idx < count; idx++ {\n")
- g.inIndent(func() {
- g.emit("src[idx].MarshalBytes(buf[size*idx:size*(idx+1)])\n")
- })
- g.emit("}\n")
- g.emit("return cc.CopyOutBytes(addr, buf)\n")
- }
- if thisPacked {
- g.recordUsedImport("reflect")
- g.recordUsedImport("runtime")
- g.recordUsedImport("unsafe")
- if _, ok := g.areFieldsPackedExpression(); ok {
- g.emit("if !src[0].Packed() {\n")
- g.inIndent(fallback)
- g.emit("}\n\n")
- }
- // Fast serialization.
- g.emitCastSliceToByteSlice("&src", "buf", "size * count")
-
- g.emit("length, err := cc.CopyOutBytes(addr, buf)\n")
- g.emitKeepAlive("src")
- g.emit("return length, err\n")
- } else {
- fallback()
- }
- })
- g.emit("}\n\n")
-
- g.emit("// MarshalUnsafe%s is like %s.MarshalUnsafe, but for a []%s.\n", slice.ident, g.typeName(), g.typeName())
- g.emit("func MarshalUnsafe%s(src []%s, dst []byte) (int, error) {\n", slice.ident, g.typeName())
- g.inIndent(func() {
- g.emit("count := len(src)\n")
- g.emit("if count == 0 {\n")
- g.inIndent(func() {
- g.emit("return 0, nil\n")
- })
- g.emit("}\n")
- g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName())
-
- fallback := func() {
- g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName())
- g.emit("for idx := 0; idx < count; idx++ {\n")
- g.inIndent(func() {
- g.emit("src[idx].MarshalBytes(dst[size*idx:(size)*(idx+1)])\n")
- })
- g.emit("}\n")
- g.emit("return size * count, nil\n")
- }
- if thisPacked {
- g.recordUsedImport("reflect")
- g.recordUsedImport("runtime")
- g.recordUsedImport("unsafe")
- g.recordUsedImport("gohacks")
- if _, ok := g.areFieldsPackedExpression(); ok {
- g.emit("if !src[0].Packed() {\n")
- g.inIndent(fallback)
- g.emit("}\n\n")
- }
- g.emit("dst = dst[:size*count]\n")
- g.emit("gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(&src[0]), uintptr(len(dst)))\n")
- g.emit("return size * count, nil\n")
- } else {
- fallback()
- }
- })
- g.emit("}\n\n")
-
- g.emit("// UnmarshalUnsafe%s is like %s.UnmarshalUnsafe, but for a []%s.\n", slice.ident, g.typeName(), g.typeName())
- g.emit("func UnmarshalUnsafe%s(dst []%s, src []byte) (int, error) {\n", slice.ident, g.typeName())
- g.inIndent(func() {
- g.emit("count := len(dst)\n")
- g.emit("if count == 0 {\n")
- g.inIndent(func() {
- g.emit("return 0, nil\n")
- })
- g.emit("}\n")
- g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName())
-
- fallback := func() {
- g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName())
- g.emit("for idx := 0; idx < count; idx++ {\n")
- g.inIndent(func() {
- g.emit("dst[idx].UnmarshalBytes(src[size*idx:size*(idx+1)])\n")
- })
- g.emit("}\n")
- g.emit("return size * count, nil\n")
- }
- if thisPacked {
- g.recordUsedImport("gohacks")
- g.recordUsedImport("reflect")
- g.recordUsedImport("runtime")
- if _, ok := g.areFieldsPackedExpression(); ok {
- g.emit("if !dst[0].Packed() {\n")
- g.inIndent(fallback)
- g.emit("}\n\n")
- }
-
- g.emit("src = src[:(size*count)]\n")
- g.emit("gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(&src[0]), uintptr(len(src)))\n")
-
- g.emit("return count*size, nil\n")
- } else {
- fallback()
- }
- })
- g.emit("}\n\n")
-}
diff --git a/tools/go_marshal/gomarshal/generator_tests.go b/tools/go_marshal/gomarshal/generator_tests.go
deleted file mode 100644
index 8f93a1de5..000000000
--- a/tools/go_marshal/gomarshal/generator_tests.go
+++ /dev/null
@@ -1,233 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package gomarshal
-
-import (
- "fmt"
- "go/ast"
- "io"
- "strings"
-)
-
-var standardImports = []string{
- "bytes",
- "fmt",
- "reflect",
- "testing",
-
- "gvisor.dev/gvisor/tools/go_marshal/analysis",
-}
-
-var sliceAPIImports = []string{
- "encoding/binary",
- "gvisor.dev/gvisor/pkg/hostarch",
-}
-
-type testGenerator struct {
- sourceBuffer
-
- // The type we're serializing.
- t *ast.TypeSpec
-
- // Receiver argument for generated methods.
- r string
-
- // Imports used by generated code.
- imports *importTable
-
- // Import statement for the package declaring the type we generated code
- // for. We need this to construct test instances for the type, since the
- // tests aren't written in the same package.
- decl *importStmt
-}
-
-func newTestGenerator(t *ast.TypeSpec, r string) *testGenerator {
- g := &testGenerator{
- t: t,
- r: r,
- imports: newImportTable(),
- }
-
- for _, i := range standardImports {
- g.imports.add(i).markUsed()
- }
- // These imports are used if a type requests the slice API. Don't
- // mark them as used by default.
- for _, i := range sliceAPIImports {
- g.imports.add(i)
- }
-
- return g
-}
-
-func (g *testGenerator) typeName() string {
- return g.t.Name.Name
-}
-
-func (g *testGenerator) testFuncName(base string) string {
- return fmt.Sprintf("%s%s", base, strings.Title(g.t.Name.Name))
-}
-
-func (g *testGenerator) inTestFunction(name string, body func()) {
- g.emit("func %s(t *testing.T) {\n", g.testFuncName(name))
- g.inIndent(body)
- g.emit("}\n\n")
-}
-
-func (g *testGenerator) emitTestNonZeroSize() {
- g.inTestFunction("TestSizeNonZero", func() {
- g.emit("var x %v\n", g.typeName())
- g.emit("if x.SizeBytes() == 0 {\n")
- g.inIndent(func() {
- g.emit("t.Fatal(\"Marshallable.SizeBytes() should not return zero\")\n")
- })
- g.emit("}\n")
- })
-}
-
-func (g *testGenerator) emitTestSuspectAlignment() {
- g.inTestFunction("TestSuspectAlignment", func() {
- g.emit("var x %v\n", g.typeName())
- g.emit("analysis.AlignmentCheck(t, reflect.TypeOf(x))\n")
- })
-}
-
-func (g *testGenerator) emitTestMarshalUnmarshalPreservesData() {
- g.inTestFunction("TestSafeMarshalUnmarshalPreservesData", func() {
- g.emit("var x, y, z, yUnsafe, zUnsafe %s\n", g.typeName())
- g.emit("analysis.RandomizeValue(&x)\n\n")
-
- g.emit("buf := make([]byte, x.SizeBytes())\n")
- g.emit("x.MarshalBytes(buf)\n")
- g.emit("bufUnsafe := make([]byte, x.SizeBytes())\n")
- g.emit("x.MarshalUnsafe(bufUnsafe)\n\n")
-
- g.emit("y.UnmarshalBytes(buf)\n")
- g.emit("if !reflect.DeepEqual(x, y) {\n")
- g.inIndent(func() {
- g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalBytes/UnmarshalBytes cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, y))\n")
- })
- g.emit("}\n")
- g.emit("yUnsafe.UnmarshalBytes(bufUnsafe)\n")
- g.emit("if !reflect.DeepEqual(x, yUnsafe) {\n")
- g.inIndent(func() {
- g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalUnsafe/UnmarshalBytes cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, yUnsafe))\n")
- })
- g.emit("}\n\n")
-
- g.emit("z.UnmarshalUnsafe(buf)\n")
- g.emit("if !reflect.DeepEqual(x, z) {\n")
- g.inIndent(func() {
- g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalBytes/UnmarshalUnsafe cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, z))\n")
- })
- g.emit("}\n")
- g.emit("zUnsafe.UnmarshalUnsafe(bufUnsafe)\n")
- g.emit("if !reflect.DeepEqual(x, zUnsafe) {\n")
- g.inIndent(func() {
- g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalUnsafe/UnmarshalUnsafe cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, zUnsafe))\n")
- })
- g.emit("}\n")
- })
-}
-
-func (g *testGenerator) emitTestMarshalUnmarshalSlicePreservesData(slice *sliceAPI) {
- for _, name := range []string{"binary", "hostarch"} {
- if !g.imports.markUsed(name) {
- panic(fmt.Sprintf("Generated test for '%s' referenced a non-existent import with local name '%s'", g.typeName(), name))
- }
- }
-
- g.inTestFunction("TestSafeMarshalUnmarshalSlicePreservesData", func() {
- g.emit("var x, y, yUnsafe [8]%s\n", g.typeName())
- g.emit("analysis.RandomizeValue(&x)\n\n")
- g.emit("size := (*%s)(nil).SizeBytes() * len(x)\n", g.typeName())
- g.emit("buf := bytes.NewBuffer(make([]byte, size))\n")
- g.emit("buf.Reset()\n")
- g.emit("if err := binary.Write(buf, hostarch.ByteOrder, x[:]); err != nil {\n")
- g.inIndent(func() {
- g.emit("t.Fatal(fmt.Sprintf(\"binary.Write failed: %v\", err))\n")
- })
- g.emit("}\n")
- g.emit("bufUnsafe := make([]byte, size)\n")
- g.emit("MarshalUnsafe%s(x[:], bufUnsafe)\n\n", slice.ident)
-
- g.emit("UnmarshalUnsafe%s(y[:], buf.Bytes())\n", slice.ident)
- g.emit("if !reflect.DeepEqual(x, y) {\n")
- g.inIndent(func() {
- g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across binary.Write/UnmarshalUnsafeSlice cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, y))\n")
- })
- g.emit("}\n")
- g.emit("UnmarshalUnsafe%s(yUnsafe[:], bufUnsafe)\n", slice.ident)
- g.emit("if !reflect.DeepEqual(x, yUnsafe) {\n")
- g.inIndent(func() {
- g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalUnsafeSlice/UnmarshalUnsafeSlice cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, yUnsafe))\n")
- })
- g.emit("}\n\n")
- })
-}
-
-func (g *testGenerator) emitTestWriteToUnmarshalPreservesData() {
- g.inTestFunction("TestWriteToUnmarshalPreservesData", func() {
- g.emit("var x, y, yUnsafe %s\n", g.typeName())
- g.emit("analysis.RandomizeValue(&x)\n\n")
-
- g.emit("var buf bytes.Buffer\n\n")
-
- g.emit("x.WriteTo(&buf)\n")
- g.emit("y.UnmarshalBytes(buf.Bytes())\n\n")
- g.emit("yUnsafe.UnmarshalUnsafe(buf.Bytes())\n\n")
-
- g.emit("if !reflect.DeepEqual(x, y) {\n")
- g.inIndent(func() {
- g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across WriteTo/UnmarshalBytes cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, y))\n")
- })
- g.emit("}\n")
- g.emit("if !reflect.DeepEqual(x, yUnsafe) {\n")
- g.inIndent(func() {
- g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across WriteTo/UnmarshalUnsafe cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, yUnsafe))\n")
- })
- g.emit("}\n")
- })
-}
-
-func (g *testGenerator) emitTestSizeBytesOnTypedNilPtr() {
- g.inTestFunction("TestSizeBytesOnTypedNilPtr", func() {
- g.emit("var x %s\n", g.typeName())
- g.emit("sizeFromConcrete := x.SizeBytes()\n")
- g.emit("sizeFromTypedNilPtr := (*%s)(nil).SizeBytes()\n\n", g.typeName())
-
- g.emit("if sizeFromTypedNilPtr != sizeFromConcrete {\n")
- g.inIndent(func() {
- g.emit("t.Fatalf(\"SizeBytes() on typed nil pointer (%v) doesn't match size returned by a concrete object (%v).\\n\", sizeFromTypedNilPtr, sizeFromConcrete)\n")
- })
- g.emit("}\n")
- })
-}
-
-func (g *testGenerator) emitTests(slice *sliceAPI) {
- g.emitTestNonZeroSize()
- g.emitTestSuspectAlignment()
- g.emitTestMarshalUnmarshalPreservesData()
- g.emitTestWriteToUnmarshalPreservesData()
- g.emitTestSizeBytesOnTypedNilPtr()
-
- if slice != nil {
- g.emitTestMarshalUnmarshalSlicePreservesData(slice)
- }
-}
-
-func (g *testGenerator) write(out io.Writer) error {
- return g.sourceBuffer.write(out)
-}
diff --git a/tools/go_marshal/gomarshal/util.go b/tools/go_marshal/gomarshal/util.go
deleted file mode 100644
index 6a42691cd..000000000
--- a/tools/go_marshal/gomarshal/util.go
+++ /dev/null
@@ -1,503 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package gomarshal
-
-import (
- "bytes"
- "flag"
- "fmt"
- "go/ast"
- "go/token"
- "io"
- "os"
- "path"
- "reflect"
- "sort"
- "strings"
-)
-
-var debug = flag.Bool("debug", false, "enables debugging output")
-
-// receiverName returns an appropriate receiver name given a type spec.
-func receiverName(t *ast.TypeSpec) string {
- if len(t.Name.Name) < 1 {
- // Zero length type name?
- panic("unreachable")
- }
- return strings.ToLower(t.Name.Name[:1])
-}
-
-// kindString returns a user-friendly representation of an AST expr type.
-func kindString(e ast.Expr) string {
- switch e.(type) {
- case *ast.Ident:
- return "scalar"
- case *ast.ArrayType:
- return "array"
- case *ast.StructType:
- return "struct"
- case *ast.StarExpr:
- return "pointer"
- case *ast.FuncType:
- return "function"
- case *ast.InterfaceType:
- return "interface"
- case *ast.MapType:
- return "map"
- case *ast.ChanType:
- return "channel"
- default:
- return reflect.TypeOf(e).String()
- }
-}
-
-func forEachStructField(st *ast.StructType, fn func(f *ast.Field)) {
- for _, field := range st.Fields.List {
- fn(field)
- }
-}
-
-// fieldDispatcher is a collection of callbacks for handling different types of
-// fields in a struct declaration.
-type fieldDispatcher struct {
- primitive func(n, t *ast.Ident)
- selector func(n, tX, tSel *ast.Ident)
- array func(n *ast.Ident, a *ast.ArrayType, t *ast.Ident)
- unhandled func(n *ast.Ident)
-}
-
-// Precondition: All dispatch callbacks that will be invoked must be
-// provided.
-func (fd fieldDispatcher) dispatch(f *ast.Field) {
- // Each field declaration may actually be multiple declarations of the same
- // type. For example, consider:
- //
- // type Point struct {
- // x, y, z int
- // }
- //
- // We invoke the call-backs once per such instance.
-
- // Handle embedded fields. Embedded fields have no names, but can be
- // referenced by the type name.
- if len(f.Names) < 1 {
- switch v := f.Type.(type) {
- case *ast.Ident:
- fd.primitive(v, v)
- case *ast.SelectorExpr:
- fd.selector(v.Sel, v.X.(*ast.Ident), v.Sel)
- default:
- // Note: Arrays can't be embedded, which is handled here.
- panic(fmt.Sprintf("Attempted to dispatch on embedded field of unsupported kind: %#v", f.Type))
- }
- return
- }
-
- // Non-embedded field.
- for _, name := range f.Names {
- switch v := f.Type.(type) {
- case *ast.Ident:
- fd.primitive(name, v)
- case *ast.SelectorExpr:
- fd.selector(name, v.X.(*ast.Ident), v.Sel)
- case *ast.ArrayType:
- switch t := v.Elt.(type) {
- case *ast.Ident:
- fd.array(name, v, t)
- default:
- // Should be handled with a better error message during validate.
- panic(fmt.Sprintf("Array element type is of unsupported kind. Expected *ast.Ident, got %v", t))
- }
- default:
- fd.unhandled(name)
- }
- }
-}
-
-// debugEnabled indicates whether debugging is enabled for gomarshal.
-func debugEnabled() bool {
- return *debug
-}
-
-// abort aborts the go_marshal tool with the given error message.
-func abort(msg string) {
- if !strings.HasSuffix(msg, "\n") {
- msg += "\n"
- }
- fmt.Print(msg)
- os.Exit(1)
-}
-
-// abortAt aborts the go_marshal tool with the given error message, with
-// a reference position to the input source.
-func abortAt(p token.Position, msg string) {
- abort(fmt.Sprintf("%v:\n %s\n", p, msg))
-}
-
-// debugf conditionally prints a debug message.
-func debugf(f string, a ...interface{}) {
- if debugEnabled() {
- fmt.Printf(f, a...)
- }
-}
-
-// debugfAt conditionally prints a debug message with a reference to a position
-// in the input source.
-func debugfAt(p token.Position, f string, a ...interface{}) {
- if debugEnabled() {
- fmt.Printf("%s:\n %s", p, fmt.Sprintf(f, a...))
- }
-}
-
-// emit generates a line of code in the output file.
-//
-// emit is a wrapper around writing a formatted string to the output
-// buffer. emit can be invoked in one of two ways:
-//
-// (1) emit("some string")
-// When emit is called with a single string argument, it is simply copied to
-// the output buffer without any further formatting.
-// (2) emit(fmtString, args...)
-// emit can also be invoked in a similar fashion to *Printf() functions,
-// where the first argument is a format string.
-//
-// Calling emit with a single argument that is not a string will result in a
-// panic, as the caller's intent is ambiguous.
-func emit(out io.Writer, indent int, a ...interface{}) {
- const spacesPerIndentLevel = 4
-
- if len(a) < 1 {
- panic("emit() called with no arguments")
- }
-
- if indent > 0 {
- if _, err := fmt.Fprint(out, strings.Repeat(" ", indent*spacesPerIndentLevel)); err != nil {
- // Writing to the emit output should not fail. Typically the output
- // is a byte.Buffer; writes to these never fail.
- panic(err)
- }
- }
-
- first, ok := a[0].(string)
- if !ok {
- // First argument must be either the string to emit (case 1 from
- // function-level comment), or a format string (case 2).
- panic(fmt.Sprintf("First argument to emit() is not a string: %+v", a[0]))
- }
-
- if len(a) == 1 {
- // Single string argument. Assume no formatting requested.
- if _, err := fmt.Fprint(out, first); err != nil {
- // Writing to out should not fail.
- panic(err)
- }
- return
-
- }
-
- // Formatting requested.
- if _, err := fmt.Fprintf(out, first, a[1:]...); err != nil {
- // Writing to out should not fail.
- panic(err)
- }
-}
-
-// sourceBuffer represents fragments of generated go source code.
-//
-// sourceBuffer provides a convenient way to build up go souce fragments in
-// memory. May be safely zero-value initialized. Not thread-safe.
-type sourceBuffer struct {
- // Current indentation level.
- indent int
-
- // Memory buffer containing contents while they're being generated.
- b bytes.Buffer
-}
-
-func (b *sourceBuffer) reset() {
- b.indent = 0
- b.b.Reset()
-}
-
-func (b *sourceBuffer) incIndent() {
- b.indent++
-}
-
-func (b *sourceBuffer) decIndent() {
- if b.indent <= 0 {
- panic("decIndent() without matching incIndent()")
- }
- b.indent--
-}
-
-func (b *sourceBuffer) emit(a ...interface{}) {
- emit(&b.b, b.indent, a...)
-}
-
-func (b *sourceBuffer) emitNoIndent(a ...interface{}) {
- emit(&b.b, 0 /*indent*/, a...)
-}
-
-func (b *sourceBuffer) inIndent(body func()) {
- b.incIndent()
- body()
- b.decIndent()
-}
-
-func (b *sourceBuffer) write(out io.Writer) error {
- _, err := fmt.Fprint(out, b.b.String())
- return err
-}
-
-// Write implements io.Writer.Write.
-func (b *sourceBuffer) Write(buf []byte) (int, error) {
- return (b.b.Write(buf))
-}
-
-// importStmt represents a single import statement.
-type importStmt struct {
- // Local name of the imported package.
- name string
- // Import path.
- path string
- // Indicates whether the local name is an alias, or simply the final
- // component of the path.
- aliased bool
- // Indicates whether this import was referenced by generated code.
- used bool
- // AST node and file set representing the import statement, if any. These
- // are only non-nil if the import statement originates from an input source
- // file.
- spec *ast.ImportSpec
- fset *token.FileSet
-}
-
-func newImport(p string) *importStmt {
- name := path.Base(p)
- return &importStmt{
- name: name,
- path: p,
- aliased: false,
- }
-}
-
-func newImportFromSpec(spec *ast.ImportSpec, f *token.FileSet) *importStmt {
- p := spec.Path.Value[1 : len(spec.Path.Value)-1] // Strip the " quotes around path.
- name := path.Base(p)
- if name == "" || name == "/" || name == "." {
- panic(fmt.Sprintf("Couldn't process local package name for import at %s, (processed as %s)",
- f.Position(spec.Path.Pos()), name))
- }
- if spec.Name != nil {
- name = spec.Name.Name
- }
- return &importStmt{
- name: name,
- path: p,
- aliased: spec.Name != nil,
- spec: spec,
- fset: f,
- }
-}
-
-// String implements fmt.Stringer.String. This generates a string for the import
-// statement appropriate for writing directly to generated code.
-func (i *importStmt) String() string {
- if i.aliased {
- return fmt.Sprintf("%s %q", i.name, i.path)
- }
- return fmt.Sprintf("%q", i.path)
-}
-
-// debugString returns a debug string representing an import statement. This
-// representation is not valid golang code and is used for debugging output.
-func (i *importStmt) debugString() string {
- if i.spec != nil && i.fset != nil {
- return fmt.Sprintf("%s: %s", i.fset.Position(i.spec.Path.Pos()), i)
- }
- return fmt.Sprintf("(go-marshal import): %s", i)
-}
-
-func (i *importStmt) markUsed() {
- i.used = true
-}
-
-func (i *importStmt) equivalent(other *importStmt) bool {
- return i.name == other.name && i.path == other.path && i.aliased == other.aliased
-}
-
-// importTable represents a collection of importStmts.
-//
-// An importTable may contain multiple import statements referencing the same
-// local name. All import statements aliasing to the same local name are
-// technically ambiguous, as if such an import name is used in the generated
-// code, it's not clear which import statement it refers to. We ignore any
-// potential collisions until actually writing the import table to the generated
-// source file. See importTable.write.
-//
-// Given the following import statements across all the files comprising a
-// package marshalled:
-//
-// "sync"
-// "pkg/sync"
-// "pkg/sentry/kernel"
-// ktime "pkg/sentry/kernel/time"
-//
-// An importTable representing them would look like this:
-//
-// importTable {
-// is: map[string][]*importStmt {
-// "sync": []*importStmt{
-// importStmt{name:"sync", path:"sync", aliased:false}
-// importStmt{name:"sync", path:"pkg/sync", aliased:false}
-// },
-// "kernel": []*importStmt{importStmt{
-// name: "kernel",
-// path: "pkg/sentry/kernel",
-// aliased: false
-// }},
-// "ktime": []*importStmt{importStmt{
-// name: "ktime",
-// path: "pkg/sentry/kernel/time",
-// aliased: true,
-// }},
-// }
-// }
-//
-// Note that the local name "sync" is assigned to two different import
-// statements. This is possible if the import statements are from different
-// source files in the same package.
-//
-// Since go-marshal generates a single output file per package regardless of the
-// number of input files, if "sync" is referenced by any generated code, it's
-// unclear which import statement "sync" refers to. While it's theoretically
-// possible to resolve this by assigning a unique local alias to each instance
-// of the sync package, go-marshal currently aborts when it encounters such an
-// ambiguity.
-//
-// TODO(b/151478251): importTable considers the final component of an import
-// path to be the package name, but this is only a convention. The actual
-// package name is determined by the package statement in the source files for
-// the package.
-type importTable struct {
- // Map of imports and whether they should be copied to the output.
- is map[string][]*importStmt
-}
-
-func newImportTable() *importTable {
- return &importTable{
- is: make(map[string][]*importStmt),
- }
-}
-
-// Merges import statements from other into i.
-func (i *importTable) merge(other *importTable) {
- for name, ims := range other.is {
- i.is[name] = append(i.is[name], ims...)
- }
-}
-
-func (i *importTable) addStmt(s *importStmt) *importStmt {
- i.is[s.name] = append(i.is[s.name], s)
- return s
-}
-
-func (i *importTable) add(s string) *importStmt {
- n := newImport(s)
- return i.addStmt(n)
-}
-
-func (i *importTable) addFromSpec(spec *ast.ImportSpec, f *token.FileSet) *importStmt {
- return i.addStmt(newImportFromSpec(spec, f))
-}
-
-// Marks the import named n as used. If no such import is in the table, returns
-// false.
-func (i *importTable) markUsed(n string) bool {
- if ns, ok := i.is[n]; ok {
- for _, n := range ns {
- n.markUsed()
- }
- return true
- }
- return false
-}
-
-func (i *importTable) clear() {
- for _, is := range i.is {
- for _, i := range is {
- i.used = false
- }
- }
-}
-
-func (i *importTable) write(out io.Writer) error {
- if len(i.is) == 0 {
- // Nothing to import, we're done.
- return nil
- }
-
- imports := make([]string, 0, len(i.is))
- for name, is := range i.is {
- var lastUsed *importStmt
- var ambiguous bool
-
- for _, i := range is {
- if i.used {
- if lastUsed != nil {
- if !i.equivalent(lastUsed) {
- ambiguous = true
- }
- }
- lastUsed = i
- }
- }
-
- if ambiguous {
- // We have two or more import statements across the different source
- // files that share a local name, and at least one of these imports
- // are used by the generated code. This ambiguity can't be resolved
- // by go-marshal and requires the user intervention. Dump a list of
- // the colliding import statements and let the user modify the input
- // files as appropriate.
- var b strings.Builder
- fmt.Fprintf(&b, "The imported name %q is used by one of the types marked for marshalling, and which import statement the code refers to is ambiguous. Perhaps give the imports unique local names?\n\n", name)
- fmt.Fprintf(&b, "The following %d import statements are ambiguous for the local name %q:\n", len(is), name)
- // Note: len(is) is guaranteed to be 1 or greater or ambiguous can't
- // be true. Therefore the slicing below is safe.
- for _, i := range is[:len(is)-1] {
- fmt.Fprintf(&b, " %v\n", i.debugString())
- }
- fmt.Fprintf(&b, " %v", is[len(is)-1].debugString())
- panic(b.String())
- }
-
- if lastUsed != nil {
- imports = append(imports, lastUsed.String())
- }
- }
- sort.Strings(imports)
-
- var b sourceBuffer
- b.emit("import (\n")
- b.incIndent()
- for _, i := range imports {
- b.emit("%s\n", i)
- }
- b.decIndent()
- b.emit(")\n\n")
-
- return b.write(out)
-}