summaryrefslogtreecommitdiffhomepage
path: root/tools/go_marshal/gomarshal
diff options
context:
space:
mode:
authorIan Lewis <ianmlewis@gmail.com>2020-08-17 21:44:31 -0400
committerIan Lewis <ianmlewis@gmail.com>2020-08-17 21:44:31 -0400
commitac324f646ee3cb7955b0b45a7453aeb9671cbdf1 (patch)
tree0cbc5018e8807421d701d190dc20525726c7ca76 /tools/go_marshal/gomarshal
parent352ae1022ce19de28fc72e034cc469872ad79d06 (diff)
parent6d0c5803d557d453f15ac6f683697eeb46dab680 (diff)
Merge branch 'master' into ip-forwarding
- Merges aleksej-paschenko's with HEAD - Adds vfs2 support for ip_forward
Diffstat (limited to 'tools/go_marshal/gomarshal')
-rw-r--r--tools/go_marshal/gomarshal/BUILD10
-rw-r--r--tools/go_marshal/gomarshal/generator.go253
-rw-r--r--tools/go_marshal/gomarshal/generator_interfaces.go457
-rw-r--r--tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go146
-rw-r--r--tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go289
-rw-r--r--tools/go_marshal/gomarshal/generator_interfaces_struct.go622
-rw-r--r--tools/go_marshal/gomarshal/generator_tests.go125
-rw-r--r--tools/go_marshal/gomarshal/util.go184
8 files changed, 1608 insertions, 478 deletions
diff --git a/tools/go_marshal/gomarshal/BUILD b/tools/go_marshal/gomarshal/BUILD
index a0eae6492..44cb33ae4 100644
--- a/tools/go_marshal/gomarshal/BUILD
+++ b/tools/go_marshal/gomarshal/BUILD
@@ -1,17 +1,21 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
-package(licenses = ["notice"])
+licenses(["notice"])
go_library(
name = "gomarshal",
srcs = [
"generator.go",
"generator_interfaces.go",
+ "generator_interfaces_array_newtype.go",
+ "generator_interfaces_primitive_newtype.go",
+ "generator_interfaces_struct.go",
"generator_tests.go",
"util.go",
],
- importpath = "gvisor.dev/gvisor/tools/go_marshal/gomarshal",
+ stateify = False,
visibility = [
"//:sandbox",
],
+ deps = ["//tools/tags"],
)
diff --git a/tools/go_marshal/gomarshal/generator.go b/tools/go_marshal/gomarshal/generator.go
index 641ccd938..19bcd4e6a 100644
--- a/tools/go_marshal/gomarshal/generator.go
+++ b/tools/go_marshal/gomarshal/generator.go
@@ -23,17 +23,14 @@ import (
"go/token"
"os"
"sort"
-)
+ "strings"
-const (
- marshalImport = "gvisor.dev/gvisor/tools/go_marshal/marshal"
- usermemImport = "gvisor.dev/gvisor/pkg/sentry/usermem"
- safecopyImport = "gvisor.dev/gvisor/pkg/sentry/platform/safecopy"
+ "gvisor.dev/gvisor/tools/tags"
)
-// List of identifiers we use in generated code, that may conflict a
-// similarly-named source identifier. Avoid problems by refusing the generate
-// code when we see these.
+// List of identifiers we use in generated code that may conflict with a
+// similarly-named source identifier. Abort gracefully when we see these to
+// avoid potentially confusing compilation failures in generated code.
//
// This only applies to import aliases at the moment. All other identifiers
// are qualified by a receiver argument, since they're struct fields.
@@ -41,10 +38,21 @@ const (
// All recievers are single letters, so we don't allow import aliases to be a
// single letter.
var badIdents = []string{
- "src", "srcs", "dst", "dsts", "blk", "buf", "err",
+ "addr", "blk", "buf", "dst", "dsts", "count", "err", "hdr", "idx", "inner",
+ "length", "limit", "ptr", "size", "src", "srcs", "task", "val",
// All single-letter identifiers.
}
+// Constructed fromt badIdents in init().
+var badIdentsMap map[string]struct{}
+
+func init() {
+ badIdentsMap = make(map[string]struct{})
+ for _, ident := range badIdents {
+ badIdentsMap[ident] = struct{}{}
+ }
+}
+
// Generator drives code generation for a single invocation of the go_marshal
// utility.
//
@@ -62,15 +70,12 @@ type Generator struct {
outputTest *os.File
// Package name for the generated file.
pkg string
- // Go import path for package we're processing. This package should directly
- // declare the type we're generating code for.
- declaration string
// Set of extra packages to import in the generated file.
imports *importTable
}
// NewGenerator creates a new code Generator.
-func NewGenerator(srcs []string, out, outTest, pkg, declaration string, imports []string) (*Generator, error) {
+func NewGenerator(srcs []string, out, outTest, pkg string, imports []string) (*Generator, error) {
f, err := os.OpenFile(out, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
if err != nil {
return nil, fmt.Errorf("Couldn't open output file %q: %v", out, err)
@@ -80,25 +85,29 @@ func NewGenerator(srcs []string, out, outTest, pkg, declaration string, imports
return nil, fmt.Errorf("Couldn't open test output file %q: %v", out, err)
}
g := Generator{
- inputs: srcs,
- output: f,
- outputTest: fTest,
- pkg: pkg,
- declaration: declaration,
- imports: newImportTable(),
+ inputs: srcs,
+ output: f,
+ outputTest: fTest,
+ pkg: pkg,
+ imports: newImportTable(),
}
for _, i := range imports {
// All imports on the extra imports list are unconditionally marked as
- // used, so they're always added to the generated code.
+ // used, so that they're always added to the generated code.
g.imports.add(i).markUsed()
}
- g.imports.add(marshalImport).markUsed()
- // The follow imports may or may not be used by the generated
- // code, depending what's required for the target types. Don't
- // mark these imports as used by default.
- g.imports.add(usermemImport)
- g.imports.add(safecopyImport)
+
+ // The following imports may or may not be used by the generated code,
+ // depending on what's required for the target types. Don't mark these as
+ // used by default.
+ g.imports.add("io")
+ g.imports.add("reflect")
+ g.imports.add("runtime")
g.imports.add("unsafe")
+ g.imports.add("gvisor.dev/gvisor/pkg/gohacks")
+ g.imports.add("gvisor.dev/gvisor/pkg/safecopy")
+ g.imports.add("gvisor.dev/gvisor/pkg/usermem")
+ g.imports.add("gvisor.dev/gvisor/tools/go_marshal/marshal")
return &g, nil
}
@@ -108,6 +117,14 @@ func NewGenerator(srcs []string, out, outTest, pkg, declaration string, imports
func (g *Generator) writeHeader() error {
var b sourceBuffer
b.emit("// Automatically generated marshal implementation. See tools/go_marshal.\n\n")
+
+ // Emit build tags.
+ if t := tags.Aggregate(g.inputs); len(t) > 0 {
+ b.emit(strings.Join(t.Lines(), "\n"))
+ b.emit("\n\n")
+ }
+
+ // Package header.
b.emit("package %s\n\n", g.pkg)
if err := b.write(g.output); err != nil {
return err
@@ -172,10 +189,73 @@ func (g *Generator) parse() ([]*ast.File, []*token.FileSet, error) {
return files, fsets, nil
}
-// collectMarshallabeTypes walks the parsed AST and collects a list of type
+// sliceAPI carries information about the '+marshal slice' directive.
+type sliceAPI struct {
+ // Comment node in the AST containing the +marshal tag.
+ comment *ast.Comment
+ // Identifier fragment to use when naming generated functions for the slice
+ // API.
+ ident string
+ // Whether the generated functions should reference the newtype name, or the
+ // inner type name. Only meaningful on newtype declarations on primitives.
+ inner bool
+}
+
+// marshallableType carries information about a type marked with the '+marshal'
+// directive.
+type marshallableType struct {
+ spec *ast.TypeSpec
+ slice *sliceAPI
+}
+
+func newMarshallableType(fset *token.FileSet, tagLine *ast.Comment, spec *ast.TypeSpec) marshallableType {
+ mt := marshallableType{
+ spec: spec,
+ slice: nil,
+ }
+
+ var unhandledTags []string
+
+ for _, tag := range strings.Fields(strings.TrimPrefix(tagLine.Text, "// +marshal")) {
+ if strings.HasPrefix(tag, "slice:") {
+ tokens := strings.Split(tag, ":")
+ if len(tokens) < 2 || len(tokens) > 3 {
+ abortAt(fset.Position(tagLine.Slash), fmt.Sprintf("+marshal directive has invalid 'slice' clause. Expecting format 'slice:<IDENTIFIER>[:inner]', got '%v'", tag))
+ }
+ if len(tokens[1]) == 0 {
+ abortAt(fset.Position(tagLine.Slash), "+marshal slice directive has empty identifier argument. Expecting '+marshal slice:identifier'")
+ }
+
+ sa := &sliceAPI{
+ comment: tagLine,
+ ident: tokens[1],
+ }
+ mt.slice = sa
+
+ if len(tokens) == 3 {
+ if tokens[2] != "inner" {
+ abortAt(fset.Position(tagLine.Slash), "+marshal slice directive has an invalid argument. Expecting '+marshal slice:<IDENTIFIER>[:inner]'")
+ }
+ sa.inner = true
+ }
+
+ continue
+ }
+
+ unhandledTags = append(unhandledTags, tag)
+ }
+
+ if len(unhandledTags) > 0 {
+ abortAt(fset.Position(tagLine.Slash), fmt.Sprintf("+marshal directive contained the following unknown clauses: %v", strings.Join(unhandledTags, " ")))
+ }
+
+ return mt
+}
+
+// collectMarshallableTypes walks the parsed AST and collects a list of type
// declarations for which we need to generate the Marshallable interface.
-func (g *Generator) collectMarshallabeTypes(a *ast.File, f *token.FileSet) []*ast.TypeSpec {
- var types []*ast.TypeSpec
+func (g *Generator) collectMarshallableTypes(a *ast.File, f *token.FileSet) []marshallableType {
+ var types []marshallableType
for _, decl := range a.Decls {
gdecl, ok := decl.(*ast.GenDecl)
// Type declaration?
@@ -190,9 +270,11 @@ func (g *Generator) collectMarshallabeTypes(a *ast.File, f *token.FileSet) []*as
}
// Does the comment contain a "+marshal" line?
marked := false
+ var tagLine *ast.Comment
for _, c := range gdecl.Doc.List {
- if c.Text == "// +marshal" {
+ if strings.HasPrefix(c.Text, "// +marshal") {
marked = true
+ tagLine = c
break
}
}
@@ -201,14 +283,23 @@ func (g *Generator) collectMarshallabeTypes(a *ast.File, f *token.FileSet) []*as
continue
}
for _, spec := range gdecl.Specs {
- // We already confirmed we're in a type declaration earlier.
+ // We already confirmed we're in a type declaration earlier, so this
+ // cast will succeed.
t := spec.(*ast.TypeSpec)
- if _, ok := t.Type.(*ast.StructType); ok {
- debugfAt(f.Position(t.Pos()), "Collected marshallable type %s.\n", t.Name.Name)
- types = append(types, t)
- continue
+ switch t.Type.(type) {
+ case *ast.StructType:
+ debugfAt(f.Position(t.Pos()), "Collected marshallable struct %s.\n", t.Name.Name)
+ case *ast.Ident: // Newtype on primitive.
+ debugfAt(f.Position(t.Pos()), "Collected marshallable newtype on primitive %s.\n", t.Name.Name)
+ case *ast.ArrayType: // Newtype on array.
+ debugfAt(f.Position(t.Pos()), "Collected marshallable newtype on array %s.\n", t.Name.Name)
+ default:
+ // A user specifically requested marshalling on this type, but we
+ // don't support it.
+ abortAt(f.Position(t.Pos()), fmt.Sprintf("Marshalling codegen was requested on type '%s', but go-marshal doesn't support this kind of declaration.\n", t.Name))
}
- debugf("Skipping declaration %v since it's not a struct declaration.\n", gdecl)
+ types = append(types, newMarshallableType(f, tagLine, t))
+
}
}
return types
@@ -222,11 +313,6 @@ func (g *Generator) collectMarshallabeTypes(a *ast.File, f *token.FileSet) []*as
// identifiers in the generated code don't conflict with any imported package
// names.
func (g *Generator) collectImports(a *ast.File, f *token.FileSet) map[string]importStmt {
- badImportNames := make(map[string]bool)
- for _, i := range badIdents {
- badImportNames[i] = true
- }
-
is := make(map[string]importStmt)
for _, decl := range a.Decls {
gdecl, ok := decl.(*ast.GenDecl)
@@ -240,10 +326,10 @@ func (g *Generator) collectImports(a *ast.File, f *token.FileSet) map[string]imp
// Make sure we have an import that doesn't use any local names that
// would conflict with identifiers in the generated code.
- if len(i.name) == 1 {
+ if len(i.name) == 1 && i.name != "_" {
abortAt(f.Position(spec.Pos()), fmt.Sprintf("Import has a single character local name '%s'; this may conflict with code generated by go_marshal, use a multi-character import alias", i.name))
}
- if badImportNames[i.name] {
+ if _, ok := badIdentsMap[i.name]; ok {
abortAt(f.Position(spec.Pos()), fmt.Sprintf("Import name '%s' is likely to conflict with code generated by go_marshal, use a different import alias", i.name))
}
}
@@ -252,20 +338,40 @@ func (g *Generator) collectImports(a *ast.File, f *token.FileSet) map[string]imp
}
-func (g *Generator) generateOne(t *ast.TypeSpec, fset *token.FileSet) *interfaceGenerator {
- // We're guaranteed to have only struct type specs by now. See
- // Generator.collectMarshallabeTypes.
- i := newInterfaceGenerator(t, fset)
- i.validate()
- i.emitMarshallable()
+func (g *Generator) generateOne(t marshallableType, fset *token.FileSet) *interfaceGenerator {
+ i := newInterfaceGenerator(t.spec, fset)
+ switch ty := t.spec.Type.(type) {
+ case *ast.StructType:
+ i.validateStruct(t.spec, ty)
+ i.emitMarshallableForStruct(ty)
+ if t.slice != nil {
+ i.emitMarshallableSliceForStruct(ty, t.slice)
+ }
+ case *ast.Ident:
+ i.validatePrimitiveNewtype(ty)
+ i.emitMarshallableForPrimitiveNewtype(ty)
+ if t.slice != nil {
+ i.emitMarshallableSliceForPrimitiveNewtype(ty, t.slice)
+ }
+ case *ast.ArrayType:
+ i.validateArrayNewtype(t.spec.Name, ty)
+ // After validate, we can safely call arrayLen.
+ i.emitMarshallableForArrayNewtype(t.spec.Name, ty, ty.Elt.(*ast.Ident))
+ if t.slice != nil {
+ abortAt(fset.Position(t.slice.comment.Slash), fmt.Sprintf("Array type marked as '+marshal slice:...', but this is not supported. Perhaps fold one of the dimensions?"))
+ }
+ default:
+ // This should've been filtered out by collectMarshallabeTypes.
+ panic(fmt.Sprintf("Unexpected type %+v", ty))
+ }
return i
}
// generateOneTestSuite generates a test suite for the automatically generated
// implementations type t.
-func (g *Generator) generateOneTestSuite(t *ast.TypeSpec) *testGenerator {
- i := newTestGenerator(t, g.declaration)
- i.emitTests()
+func (g *Generator) generateOneTestSuite(t marshallableType) *testGenerator {
+ i := newTestGenerator(t.spec)
+ i.emitTests(t.slice)
return i
}
@@ -304,35 +410,24 @@ func (g *Generator) Run() error {
for i, a := range asts {
// Collect type declarations marked for code generation and generate
// Marshallable interfaces.
- for _, t := range g.collectMarshallabeTypes(a, fsets[i]) {
+ for _, t := range g.collectMarshallableTypes(a, fsets[i]) {
impl := g.generateOne(t, fsets[i])
// Collect Marshallable types referenced by the generated code.
- for ref, _ := range impl.ms {
+ for ref := range impl.ms {
ms[ref] = struct{}{}
}
impls = append(impls, impl)
// Collect imports referenced by the generated code and add them to
// the list of imports we need to copy to the generated code.
- for name, _ := range impl.is {
+ for name := range impl.is {
if !g.imports.markUsed(name) {
- panic(fmt.Sprintf("Generated code for '%s' referenced a non-existent import with local name '%s'", impl.typeName(), name))
+ panic(fmt.Sprintf("Generated code for '%s' referenced a non-existent import with local name '%s'. Either go-marshal needs to add an import to the generated file, or a package in an input source file has a package name differ from the final component of its path, which go-marshal doesn't know how to detect; use an import alias to work around this limitation.", impl.typeName(), name))
}
}
ts = append(ts, g.generateOneTestSuite(t))
}
}
- // Tool was invoked with input files with no data structures marked for code
- // generation. This is probably not what the user intended.
- if len(impls) == 0 {
- var buf bytes.Buffer
- fmt.Fprintf(&buf, "go_marshal invoked on these files, but they don't contain any types requiring code generation. Perhaps mark some with \"// +marshal\"?:\n")
- for _, i := range g.inputs {
- fmt.Fprintf(&buf, " %s\n", i)
- }
- abort(buf.String())
- }
-
// Write output file header. These include things like package name and
// import statements.
if err := g.writeHeader(); err != nil {
@@ -359,11 +454,12 @@ func (g *Generator) Run() error {
// source file.
func (g *Generator) writeTests(ts []*testGenerator) error {
var b sourceBuffer
- b.emit("package %s_test\n\n", g.pkg)
+ b.emit("package %s\n\n", g.pkg)
if err := b.write(g.outputTest); err != nil {
return err
}
+ // Collect and write test import statements.
imports := newImportTable()
for _, t := range ts {
imports.merge(t.imports)
@@ -373,6 +469,27 @@ func (g *Generator) writeTests(ts []*testGenerator) error {
return err
}
+ // Write test functions.
+
+ // If we didn't generate any Marshallable implementations, we can't just
+ // emit an empty test file, since that causes the build to fail with "no
+ // tests/benchmarks/examples found". Unfortunately we can't signal bazel to
+ // omit the entire package since the outputs are already defined before
+ // go-marshal is called. If we'd otherwise emit an empty test suite, emit an
+ // empty example instead.
+ if len(ts) == 0 {
+ b.reset()
+ b.emit("func Example() {\n")
+ b.inIndent(func() {
+ b.emit("// This example is intentionally empty to ensure this file contains at least\n")
+ b.emit("// one testable entity. go-marshal is forced to emit a test file if a package\n")
+ b.emit("// is marked marshallable, but emitting a test file with no entities results\n")
+ b.emit("// in a build failure.\n")
+ })
+ b.emit("}\n")
+ return b.write(g.outputTest)
+ }
+
for _, t := range ts {
if err := t.write(g.outputTest); err != nil {
return err
diff --git a/tools/go_marshal/gomarshal/generator_interfaces.go b/tools/go_marshal/gomarshal/generator_interfaces.go
index a712c14dc..e3c3dac63 100644
--- a/tools/go_marshal/gomarshal/generator_interfaces.go
+++ b/tools/go_marshal/gomarshal/generator_interfaces.go
@@ -55,9 +55,6 @@ func (g *interfaceGenerator) typeName() string {
// newinterfaceGenerator creates a new interface generator.
func newInterfaceGenerator(t *ast.TypeSpec, fset *token.FileSet) *interfaceGenerator {
- if _, ok := t.Type.(*ast.StructType); !ok {
- panic(fmt.Sprintf("Attempting to generate code for a not struct type %v", t))
- }
g := &interfaceGenerator{
t: t,
r: receiverName(t),
@@ -77,25 +74,12 @@ func (g *interfaceGenerator) recordUsedMarshallable(m string) {
func (g *interfaceGenerator) recordUsedImport(i string) {
g.is[i] = struct{}{}
-
}
func (g *interfaceGenerator) recordPotentiallyNonPackedField(fieldName string) {
g.as[fieldName] = struct{}{}
}
-func (g *interfaceGenerator) forEachField(fn func(f *ast.Field)) {
- // This is guaranteed to succeed because g.t is always a struct.
- st := g.t.Type.(*ast.StructType)
- for _, field := range st.Fields.List {
- fn(field)
- }
-}
-
-func (g *interfaceGenerator) fieldAccessor(n *ast.Ident) string {
- return fmt.Sprintf("%s.%s", g.r, n.Name)
-}
-
// abortAt aborts the go_marshal tool with the given error message, with a
// reference position to the input source. Same as abortAt, but uses g to
// resolve p to position.
@@ -103,67 +87,6 @@ func (g *interfaceGenerator) abortAt(p token.Pos, msg string) {
abortAt(g.f.Position(p), msg)
}
-// validate ensures the type we're working with can be marshalled. These checks
-// are done ahead of time and in one place so we can make assumptions later.
-func (g *interfaceGenerator) validate() {
- g.forEachField(func(f *ast.Field) {
- if len(f.Names) == 0 {
- g.abortAt(f.Pos(), "Cannot marshal structs with embedded fields, give the field a name; use '_' for anonymous fields such as padding fields")
- }
- })
-
- g.forEachField(func(f *ast.Field) {
- fieldDispatcher{
- primitive: func(_, t *ast.Ident) {
- switch t.Name {
- case "int8", "uint8", "byte", "int16", "uint16", "int32", "uint32", "int64", "uint64":
- // These are the only primitive types we're allow. Below, we
- // provide suggestions for some disallowed types and reject
- // them, then attempt to marshal any remaining types by
- // invoking the marshal.Marshallable interface on them. If
- // these types don't actually implement
- // marshal.Marshallable, compilation of the generated code
- // will fail with an appropriate error message.
- return
- case "int":
- g.abortAt(f.Pos(), "Type 'int' has ambiguous width, use int32 or int64")
- case "uint":
- g.abortAt(f.Pos(), "Type 'uint' has ambiguous width, use uint32 or uint64")
- case "string":
- g.abortAt(f.Pos(), "Type 'string' is dynamically-sized and cannot be marshalled, use a fixed size byte array '[...]byte' instead")
- default:
- debugfAt(g.f.Position(f.Pos()), fmt.Sprintf("Found derived type '%s', will attempt dispatch via marshal.Marshallable.\n", t.Name))
- }
- },
- selector: func(_, _, _ *ast.Ident) {
- // No validation to perform on selector fields. However this
- // callback must still be provided.
- },
- array: func(n, _ *ast.Ident, len int) {
- a := f.Type.(*ast.ArrayType)
- if a.Len == nil {
- g.abortAt(f.Pos(), fmt.Sprintf("Dynamically sized slice '%s' cannot be marshalled, arrays must be statically sized", n.Name))
- }
-
- if _, ok := a.Len.(*ast.BasicLit); !ok {
- g.abortAt(a.Len.Pos(), fmt.Sprintf("Array size must be a literal, don's use consts or expressions"))
- }
-
- if _, ok := a.Elt.(*ast.Ident); !ok {
- g.abortAt(a.Elt.Pos(), fmt.Sprintf("Marshalling not supported for arrays with %s elements, array elements must be primitive types", kindString(a.Elt)))
- }
-
- if len <= 0 {
- g.abortAt(a.Len.Pos(), fmt.Sprintf("Marshalling not supported for zero length arrays, why does an ABI struct have one?"))
- }
- },
- unhandled: func(_ *ast.Ident) {
- g.abortAt(f.Pos(), fmt.Sprintf("Marshalling not supported for %s fields", kindString(f.Type)))
- },
- }.dispatch(f)
- })
-}
-
// scalarSize returns the size of type identified by t. If t isn't a primitive
// type, the size isn't known at code generation time, and must be resolved via
// the marshal.Marshallable interface.
@@ -190,7 +113,8 @@ func (g *interfaceGenerator) shiftDynamic(bufVar, name string) {
g.emit("%s = %s[%s.SizeBytes():]\n", bufVar, bufVar, name)
}
-func (g *interfaceGenerator) marshalScalar(accessor, typ string, bufVar string) {
+// marshalScalar writes a single scalar to a byte slice.
+func (g *interfaceGenerator) marshalScalar(accessor, typ, bufVar string) {
switch typ {
case "int8", "uint8", "byte":
g.emit("%s[0] = byte(%s)\n", bufVar, accessor)
@@ -213,43 +137,26 @@ func (g *interfaceGenerator) marshalScalar(accessor, typ string, bufVar string)
}
}
-func (g *interfaceGenerator) unmarshalScalar(accessor, typ string, bufVar string) {
+// unmarshalScalar reads a single scalar from a byte slice.
+func (g *interfaceGenerator) unmarshalScalar(accessor, typ, bufVar string) {
switch typ {
- case "int8":
- g.emit("%s = int8(%s[0])\n", accessor, bufVar)
- g.shift(bufVar, 1)
- case "uint8":
- g.emit("%s = uint8(%s[0])\n", accessor, bufVar)
- g.shift(bufVar, 1)
case "byte":
g.emit("%s = %s[0]\n", accessor, bufVar)
g.shift(bufVar, 1)
-
- case "int16":
- g.recordUsedImport("usermem")
- g.emit("%s = int16(usermem.ByteOrder.Uint16(%s[:2]))\n", accessor, bufVar)
- g.shift(bufVar, 2)
- case "uint16":
+ case "int8", "uint8":
+ g.emit("%s = %s(%s[0])\n", accessor, typ, bufVar)
+ g.shift(bufVar, 1)
+ case "int16", "uint16":
g.recordUsedImport("usermem")
- g.emit("%s = usermem.ByteOrder.Uint16(%s[:2])\n", accessor, bufVar)
+ g.emit("%s = %s(usermem.ByteOrder.Uint16(%s[:2]))\n", accessor, typ, bufVar)
g.shift(bufVar, 2)
-
- case "int32":
- g.recordUsedImport("usermem")
- g.emit("%s = int32(usermem.ByteOrder.Uint32(%s[:4]))\n", accessor, bufVar)
- g.shift(bufVar, 4)
- case "uint32":
+ case "int32", "uint32":
g.recordUsedImport("usermem")
- g.emit("%s = usermem.ByteOrder.Uint32(%s[:4])\n", accessor, bufVar)
+ g.emit("%s = %s(usermem.ByteOrder.Uint32(%s[:4]))\n", accessor, typ, bufVar)
g.shift(bufVar, 4)
-
- case "int64":
- g.recordUsedImport("usermem")
- g.emit("%s = int64(usermem.ByteOrder.Uint64(%s[:8]))\n", accessor, bufVar)
- g.shift(bufVar, 8)
- case "uint64":
+ case "int64", "uint64":
g.recordUsedImport("usermem")
- g.emit("%s = usermem.ByteOrder.Uint64(%s[:8])\n", accessor, bufVar)
+ g.emit("%s = %s(usermem.ByteOrder.Uint64(%s[:8]))\n", accessor, typ, bufVar)
g.shift(bufVar, 8)
default:
g.emit("%s.UnmarshalBytes(%s[:%s.SizeBytes()])\n", accessor, bufVar, accessor)
@@ -258,250 +165,112 @@ func (g *interfaceGenerator) unmarshalScalar(accessor, typ string, bufVar string
}
}
-// areFieldsPackedExpression returns a go expression checking whether g.t's fields are
-// packed. Returns "", false if g.t has no fields that may be potentially
-// packed, otherwise returns <clause>, true, where <clause> is an expression
-// like "t.a.Packed() && t.b.Packed() && t.c.Packed()".
-func (g *interfaceGenerator) areFieldsPackedExpression() (string, bool) {
- if len(g.as) == 0 {
- return "", false
- }
-
- cs := make([]string, 0, len(g.as))
- for accessor, _ := range g.as {
- cs = append(cs, fmt.Sprintf("%s.Packed()", accessor))
- }
- return strings.Join(cs, " && "), true
+// emitCastToByteSlice unsafely casts an arbitrary type's underlying memory to a
+// byte slice, bypassing escape analysis. The caller is responsible for ensuring
+// srcPtr lives until they're done with dstVar, the runtime does not consider
+// dstVar dependent on srcPtr due to the escape analysis bypass.
+//
+// srcPtr must be a pointer.
+//
+// This function uses internally uses the identifier "hdr", and cannot be used
+// in a context where it is already bound.
+func (g *interfaceGenerator) emitCastToByteSlice(srcPtr, dstVar, lenExpr string) {
+ g.recordUsedImport("gohacks")
+ g.emit("// Construct a slice backed by dst's underlying memory.\n")
+ g.emit("var %s []byte\n", dstVar)
+ g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&%s))\n", dstVar)
+ g.emit("hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(%s)))\n", srcPtr)
+ g.emit("hdr.Len = %s\n", lenExpr)
+ g.emit("hdr.Cap = %s\n\n", lenExpr)
}
-func (g *interfaceGenerator) emitMarshallable() {
- // Is g.t a packed struct without consideing field types?
- thisPacked := true
- g.forEachField(func(f *ast.Field) {
- if f.Tag != nil {
- if f.Tag.Value == "`marshal:\"unaligned\"`" {
- if thisPacked {
- debugfAt(g.f.Position(g.t.Pos()),
- fmt.Sprintf("Marking type '%s' as not packed due to tag `marshal:\"unaligned\"`.\n", g.t.Name))
- thisPacked = false
- }
- }
- }
- })
-
- g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n")
- g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName())
- g.inIndent(func() {
- primitiveSize := 0
- var dynamicSizeTerms []string
-
- g.forEachField(fieldDispatcher{
- primitive: func(n, t *ast.Ident) {
- if size, dynamic := g.scalarSize(t); !dynamic {
- primitiveSize += size
- } else {
- g.recordUsedMarshallable(t.Name)
- dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("%s.SizeBytes()", g.fieldAccessor(n)))
- }
- },
- selector: func(n, tX, tSel *ast.Ident) {
- tName := fmt.Sprintf("%s.%s", tX.Name, tSel.Name)
- g.recordUsedImport(tX.Name)
- g.recordUsedMarshallable(tName)
- dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", tName))
- },
- array: func(n, t *ast.Ident, len int) {
- if len < 1 {
- // Zero-length arrays should've been rejected by validate().
- panic("unreachable")
- }
- if size, dynamic := g.scalarSize(t); !dynamic {
- primitiveSize += size * len
- } else {
- g.recordUsedMarshallable(t.Name)
- dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()*%d", t.Name, len))
- }
- },
- }.dispatch)
- g.emit("return %d", primitiveSize)
- if len(dynamicSizeTerms) > 0 {
- g.incIndent()
- }
- {
- for _, d := range dynamicSizeTerms {
- g.emitNoIndent(" +\n")
- g.emit(d)
- }
- }
- if len(dynamicSizeTerms) > 0 {
- g.decIndent()
- }
- })
- g.emit("\n}\n\n")
-
- g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n")
- g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName())
- g.inIndent(func() {
- g.forEachField(fieldDispatcher{
- primitive: func(n, t *ast.Ident) {
- if n.Name == "_" {
- g.emit("// Padding: dst[:sizeof(%s)] ~= %s(0)\n", t.Name, t.Name)
- if len, dynamic := g.scalarSize(t); !dynamic {
- g.shift("dst", len)
- } else {
- // We can't use shiftDynamic here because we don't have
- // an instance of the dynamic type we can referece here
- // (since the version in this struct is anonymous). Use
- // a typed nil pointer to call SizeBytes() instead.
- g.emit("dst = dst[(*%s)(nil).SizeBytes():]\n", t.Name)
- }
- return
- }
- g.marshalScalar(g.fieldAccessor(n), t.Name, "dst")
- },
- selector: func(n, tX, tSel *ast.Ident) {
- g.marshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "dst")
- },
- array: func(n, t *ast.Ident, size int) {
- if n.Name == "_" {
- g.emit("// Padding: dst[:sizeof(%s)*%d] ~= [%d]%s{0}\n", t.Name, size, size, t.Name)
- if len, dynamic := g.scalarSize(t); !dynamic {
- g.shift("dst", len*size)
- } else {
- // We can't use shiftDynamic here because we don't have
- // an instance of the dynamic type we can reference here
- // (since the version in this struct is anonymous). Use
- // a typed nil pointer to call SizeBytes() instead.
- g.emit("dst = dst[(*%s)(nil).SizeBytes()*%d:]\n", t.Name, size)
- }
- return
- }
-
- g.emit("for i := 0; i < %d; i++ {\n", size)
- g.inIndent(func() {
- g.marshalScalar(fmt.Sprintf("%s[i]", g.fieldAccessor(n)), t.Name, "dst")
- })
- g.emit("}\n")
- },
- }.dispatch)
- })
- g.emit("}\n\n")
-
- g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n")
- g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName())
- g.inIndent(func() {
- g.forEachField(fieldDispatcher{
- primitive: func(n, t *ast.Ident) {
- if n.Name == "_" {
- g.emit("// Padding: var _ %s ~= src[:sizeof(%s)]\n", t.Name, t.Name)
- if len, dynamic := g.scalarSize(t); !dynamic {
- g.shift("src", len)
- } else {
- // We can't use shiftDynamic here because we don't have
- // an instance of the dynamic type we can reference here
- // (since the version in this struct is anonymous). Use
- // a typed nil pointer to call SizeBytes() instead.
- g.emit("src = src[(*%s)(nil).SizeBytes():]\n", t.Name)
- g.recordPotentiallyNonPackedField(fmt.Sprintf("(*%s)(nil)", t.Name))
- }
- return
- }
- g.unmarshalScalar(g.fieldAccessor(n), t.Name, "src")
- },
- selector: func(n, tX, tSel *ast.Ident) {
- g.unmarshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "src")
- },
- array: func(n, t *ast.Ident, size int) {
- if n.Name == "_" {
- g.emit("// Padding: ~ copy([%d]%s(%s), src[:sizeof(%s)*%d])\n", size, t.Name, g.fieldAccessor(n), t.Name, size)
- if len, dynamic := g.scalarSize(t); !dynamic {
- g.shift("src", len*size)
- } else {
- // We can't use shiftDynamic here because we don't have
- // an instance of the dynamic type we can referece here
- // (since the version in this struct is anonymous). Use
- // a typed nil pointer to call SizeBytes() instead.
- g.emit("src = src[(*%s)(nil).SizeBytes()*%d:]\n", t.Name, size)
- }
- return
- }
+// emitCastToByteSlice unsafely casts a slice with elements of an abitrary type
+// to a byte slice. As part of the cast, the byte slice is made to look
+// independent of the src slice by bypassing escape analysis. This means the
+// byte slice can be used without causing the source to escape. The caller is
+// responsible for ensuring srcPtr lives until they're done with dstVar, as the
+// runtime no longer considers dstVar dependent on srcPtr and is free to GC it.
+//
+// srcPtr must be a pointer.
+//
+// This function uses internally uses the identifiers "ptr", "val" and "hdr",
+// and cannot be used in a context where these identifiers are already bound.
+func (g *interfaceGenerator) emitCastSliceToByteSlice(srcPtr, dstVar, lenExpr string) {
+ g.emitNoEscapeSliceDataPointer(srcPtr, "val")
+
+ g.emit("// Construct a slice backed by dst's underlying memory.\n")
+ g.emit("var %s []byte\n", dstVar)
+ g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&%s))\n", dstVar)
+ g.emit("hdr.Data = uintptr(val)\n")
+ g.emit("hdr.Len = %s\n", lenExpr)
+ g.emit("hdr.Cap = %s\n\n", lenExpr)
+}
- g.emit("for i := 0; i < %d; i++ {\n", size)
- g.inIndent(func() {
- g.unmarshalScalar(fmt.Sprintf("%s[i]", g.fieldAccessor(n)), t.Name, "src")
- })
- g.emit("}\n")
- },
- }.dispatch)
- })
- g.emit("}\n\n")
+// emitNoEscapeSliceDataPointer unsafely casts a slice's data pointer to an
+// unsafe.Pointer, bypassing escape analysis. The caller is responsible for
+// ensuring srcPtr lives until they're done with dstVar, as the runtime no
+// longer considers dstVar dependent on srcPtr and is free to GC it.
+//
+// srcPtr must be a pointer.
+//
+// This function uses internally uses the identifier "ptr" cannot be used in a
+// context where this identifier is already bound.
+func (g *interfaceGenerator) emitNoEscapeSliceDataPointer(srcPtr, dstVar string) {
+ g.recordUsedImport("gohacks")
+ g.emit("ptr := unsafe.Pointer(%s)\n", srcPtr)
+ g.emit("%s := gohacks.Noescape(unsafe.Pointer((*reflect.SliceHeader)(ptr).Data))\n\n", dstVar)
+}
- g.emit("// Packed implements marshal.Marshallable.Packed.\n")
- g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName())
- g.inIndent(func() {
- expr, fieldsMaybePacked := g.areFieldsPackedExpression()
- switch {
- case !thisPacked:
- g.emit("return false\n")
- case fieldsMaybePacked:
- g.emit("return %s\n", expr)
- default:
- g.emit("return true\n")
+func (g *interfaceGenerator) emitKeepAlive(ptrVar string) {
+ g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", ptrVar)
+ g.emit("// must live until the use above.\n")
+ g.emit("runtime.KeepAlive(%s)\n", ptrVar)
+}
- }
- })
- g.emit("}\n\n")
+func (g *interfaceGenerator) expandBinaryExpr(b *strings.Builder, e *ast.BinaryExpr) {
+ switch x := e.X.(type) {
+ case *ast.BinaryExpr:
+ // Recursively expand sub-expression.
+ g.expandBinaryExpr(b, x)
+ case *ast.Ident:
+ fmt.Fprintf(b, "%s", x.Name)
+ case *ast.BasicLit:
+ fmt.Fprintf(b, "%s", x.Value)
+ default:
+ g.abortAt(e.Pos(), "Cannot convert binary expression to output code. Go-marshal currently only handles simple expressions of literals, constants and basic identifiers")
+ }
- g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n")
- g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName())
- g.inIndent(func() {
- if thisPacked {
- g.recordUsedImport("safecopy")
- g.recordUsedImport("unsafe")
- if cond, ok := g.areFieldsPackedExpression(); ok {
- g.emit("if %s {\n", cond)
- g.inIndent(func() {
- g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r)
- })
- g.emit("} else {\n")
- g.inIndent(func() {
- g.emit("%s.MarshalBytes(dst)\n", g.r)
- })
- g.emit("}\n")
- } else {
- g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r)
- }
- } else {
- g.emit("// Type %s doesn't have a packed layout in memory, fallback to MarshalBytes.\n", g.typeName())
- g.emit("%s.MarshalBytes(dst)\n", g.r)
- }
- })
- g.emit("}\n\n")
+ fmt.Fprintf(b, "%s", e.Op)
- g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n")
- g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName())
- g.inIndent(func() {
- if thisPacked {
- g.recordUsedImport("safecopy")
- g.recordUsedImport("unsafe")
- if cond, ok := g.areFieldsPackedExpression(); ok {
- g.emit("if %s {\n", cond)
- g.inIndent(func() {
- g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r)
- })
- g.emit("} else {\n")
- g.inIndent(func() {
- g.emit("%s.UnmarshalBytes(src)\n", g.r)
- })
- g.emit("}\n")
- } else {
- g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r)
- }
- } else {
- g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName())
- g.emit("%s.UnmarshalBytes(src)\n", g.r)
- }
- })
- g.emit("}\n\n")
+ switch y := e.Y.(type) {
+ case *ast.BinaryExpr:
+ // Recursively expand sub-expression.
+ g.expandBinaryExpr(b, y)
+ case *ast.Ident:
+ fmt.Fprintf(b, "%s", y.Name)
+ case *ast.BasicLit:
+ fmt.Fprintf(b, "%s", y.Value)
+ default:
+ g.abortAt(e.Pos(), "Cannot convert binary expression to output code. Go-marshal currently only handles simple expressions of literals, constants and basic identifiers")
+ }
+}
+// arrayLenExpr returns a string containing a valid golang expression
+// representing the length of array a. The returned expression should be treated
+// as a single value, and will be already parenthesized as required.
+func (g *interfaceGenerator) arrayLenExpr(a *ast.ArrayType) string {
+ var b strings.Builder
+
+ switch l := a.Len.(type) {
+ case *ast.Ident:
+ fmt.Fprintf(&b, "%s", l.Name)
+ case *ast.BasicLit:
+ fmt.Fprintf(&b, "%s", l.Value)
+ case *ast.BinaryExpr:
+ g.expandBinaryExpr(&b, l)
+ return fmt.Sprintf("(%s)", b.String())
+ default:
+ g.abortAt(l.Pos(), "Cannot convert this array len expression to output code. Go-marshal currently only handles simple expressions of literals, constants and basic identifiers")
+ }
+ return b.String()
}
diff --git a/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go b/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go
new file mode 100644
index 000000000..72ef03a22
--- /dev/null
+++ b/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go
@@ -0,0 +1,146 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// This file contains the bits of the code generator specific to marshalling
+// newtypes on arrays.
+
+package gomarshal
+
+import (
+ "fmt"
+ "go/ast"
+)
+
+func (g *interfaceGenerator) validateArrayNewtype(n *ast.Ident, a *ast.ArrayType) {
+ if a.Len == nil {
+ g.abortAt(a.Pos(), fmt.Sprintf("Dynamically sized slice '%s' cannot be marshalled, arrays must be statically sized", n.Name))
+ }
+
+ if _, ok := a.Elt.(*ast.Ident); !ok {
+ g.abortAt(a.Elt.Pos(), fmt.Sprintf("Marshalling not supported for arrays with %s elements, array elements must be primitive types", kindString(a.Elt)))
+ }
+}
+
+func (g *interfaceGenerator) emitMarshallableForArrayNewtype(n *ast.Ident, a *ast.ArrayType, elt *ast.Ident) {
+ g.recordUsedImport("io")
+ g.recordUsedImport("marshal")
+ g.recordUsedImport("reflect")
+ g.recordUsedImport("runtime")
+ g.recordUsedImport("safecopy")
+ g.recordUsedImport("unsafe")
+ g.recordUsedImport("usermem")
+
+ lenExpr := g.arrayLenExpr(a)
+
+ g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n")
+ g.emit("//go:nosplit\n")
+ g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ if size, dynamic := g.scalarSize(elt); !dynamic {
+ g.emit("return %d * %s\n", size, lenExpr)
+ } else {
+ g.emit("return (*%s)(nil).SizeBytes() * %s\n", n.Name, lenExpr)
+ }
+ })
+ g.emit("}\n\n")
+
+ g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n")
+ g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("for idx := 0; idx < %s; idx++ {\n", lenExpr)
+ g.inIndent(func() {
+ g.marshalScalar(fmt.Sprintf("%s[idx]", g.r), elt.Name, "dst")
+ })
+ g.emit("}\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n")
+ g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("for idx := 0; idx < %s; idx++ {\n", lenExpr)
+ g.inIndent(func() {
+ g.unmarshalScalar(fmt.Sprintf("%s[idx]", g.r), elt.Name, "src")
+ })
+ g.emit("}\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// Packed implements marshal.Marshallable.Packed.\n")
+ g.emit("//go:nosplit\n")
+ g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("// Array newtypes are always packed.\n")
+ g.emit("return true\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n")
+ g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r)
+ })
+ g.emit("}\n\n")
+
+ g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n")
+ g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r)
+ })
+ g.emit("}\n\n")
+
+ g.emit("// CopyOutN implements marshal.Marshallable.CopyOutN.\n")
+ g.emit("//go:nosplit\n")
+ g.emit("func (%s *%s) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r))
+
+ g.emit("length, err := task.CopyOutBytes(addr, buf[:limit]) // escapes: okay.\n")
+ g.emitKeepAlive(g.r)
+ g.emit("return length, err\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n")
+ g.emit("//go:nosplit\n")
+ g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("return %s.CopyOutN(task, addr, %s.SizeBytes())\n", g.r, g.r)
+ })
+ g.emit("}\n\n")
+
+ g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n")
+ g.emit("//go:nosplit\n")
+ g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r))
+
+ g.emit("length, err := task.CopyInBytes(addr, buf) // escapes: okay.\n")
+ g.emitKeepAlive(g.r)
+ g.emit("return length, err\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// WriteTo implements io.WriterTo.WriteTo.\n")
+ g.emit("func (%s *%s) WriteTo(w io.Writer) (int64, error) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r))
+
+ g.emit("length, err := w.Write(buf)\n")
+ g.emitKeepAlive(g.r)
+ g.emit("return int64(length), err\n")
+
+ })
+ g.emit("}\n\n")
+}
diff --git a/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go b/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go
new file mode 100644
index 000000000..39f654ea8
--- /dev/null
+++ b/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go
@@ -0,0 +1,289 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// This file contains the bits of the code generator specific to marshalling
+// newtypes on primitives.
+
+package gomarshal
+
+import (
+ "fmt"
+ "go/ast"
+)
+
+// marshalPrimitiveScalar writes a single primitive variable to a byte
+// slice.
+func (g *interfaceGenerator) marshalPrimitiveScalar(accessor, typ, bufVar string) {
+ switch typ {
+ case "int8", "uint8", "byte":
+ g.emit("%s[0] = byte(*%s)\n", bufVar, accessor)
+ case "int16", "uint16":
+ g.recordUsedImport("usermem")
+ g.emit("usermem.ByteOrder.PutUint16(%s[:2], uint16(*%s))\n", bufVar, accessor)
+ case "int32", "uint32":
+ g.recordUsedImport("usermem")
+ g.emit("usermem.ByteOrder.PutUint32(%s[:4], uint32(*%s))\n", bufVar, accessor)
+ case "int64", "uint64":
+ g.recordUsedImport("usermem")
+ g.emit("usermem.ByteOrder.PutUint64(%s[:8], uint64(*%s))\n", bufVar, accessor)
+ default:
+ g.emit("// Explicilty cast to the underlying type before dispatching to\n")
+ g.emit("// MarshalBytes, so we don't recursively call %s.MarshalBytes\n", accessor)
+ g.emit("inner := (*%s)(%s)\n", typ, accessor)
+ g.emit("inner.MarshalBytes(%s[:%s.SizeBytes()])\n", bufVar, accessor)
+ }
+}
+
+// unmarshalPrimitiveScalar read a single primitive variable from a byte slice.
+func (g *interfaceGenerator) unmarshalPrimitiveScalar(accessor, typ, bufVar, typeCast string) {
+ switch typ {
+ case "byte":
+ g.emit("*%s = %s(%s[0])\n", accessor, typeCast, bufVar)
+ case "int8", "uint8":
+ g.emit("*%s = %s(%s(%s[0]))\n", accessor, typeCast, typ, bufVar)
+ case "int16", "uint16":
+ g.recordUsedImport("usermem")
+ g.emit("*%s = %s(%s(usermem.ByteOrder.Uint16(%s[:2])))\n", accessor, typeCast, typ, bufVar)
+ case "int32", "uint32":
+ g.recordUsedImport("usermem")
+ g.emit("*%s = %s(%s(usermem.ByteOrder.Uint32(%s[:4])))\n", accessor, typeCast, typ, bufVar)
+ case "int64", "uint64":
+ g.recordUsedImport("usermem")
+ g.emit("*%s = %s(%s(usermem.ByteOrder.Uint64(%s[:8])))\n", accessor, typeCast, typ, bufVar)
+ default:
+ g.emit("// Explicilty cast to the underlying type before dispatching to\n")
+ g.emit("// UnmarshalBytes, so we don't recursively call %s.UnmarshalBytes\n", accessor)
+ g.emit("inner := (*%s)(%s)\n", typ, accessor)
+ g.emit("inner.UnmarshalBytes(%s[:%s.SizeBytes()])\n", bufVar, accessor)
+ }
+}
+
+func (g *interfaceGenerator) validatePrimitiveNewtype(t *ast.Ident) {
+ switch t.Name {
+ case "int8", "uint8", "byte", "int16", "uint16", "int32", "uint32", "int64", "uint64":
+ // These are the only primitive types we're allow. Below, we provide
+ // suggestions for some disallowed types and reject them, then attempt
+ // to marshal any remaining types by invoking the marshal.Marshallable
+ // interface on them. If these types don't actually implement
+ // marshal.Marshallable, compilation of the generated code will fail
+ // with an appropriate error message.
+ return
+ case "int":
+ g.abortAt(t.Pos(), "Type 'int' has ambiguous width, use int32 or int64")
+ case "uint":
+ g.abortAt(t.Pos(), "Type 'uint' has ambiguous width, use uint32 or uint64")
+ case "string":
+ g.abortAt(t.Pos(), "Type 'string' is dynamically-sized and cannot be marshalled, use a fixed size byte array '[...]byte' instead")
+ default:
+ debugfAt(g.f.Position(t.Pos()), fmt.Sprintf("Found derived type '%s', will attempt dispatch via marshal.Marshallable.\n", t.Name))
+ }
+}
+
+// emitMarshallableForPrimitiveNewtype outputs code to implement the
+// marshal.Marshallable interface for a newtype on a primitive. Primitive
+// newtypes are always packed, so we can omit the various fallbacks required for
+// non-packed structs.
+func (g *interfaceGenerator) emitMarshallableForPrimitiveNewtype(nt *ast.Ident) {
+ g.recordUsedImport("io")
+ g.recordUsedImport("marshal")
+ g.recordUsedImport("reflect")
+ g.recordUsedImport("runtime")
+ g.recordUsedImport("safecopy")
+ g.recordUsedImport("unsafe")
+ g.recordUsedImport("usermem")
+
+ g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n")
+ g.emit("//go:nosplit\n")
+ g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ if size, dynamic := g.scalarSize(nt); !dynamic {
+ g.emit("return %d\n", size)
+ } else {
+ g.emit("return (*%s)(nil).SizeBytes()\n", nt.Name)
+ }
+ })
+ g.emit("}\n\n")
+
+ g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n")
+ g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.marshalPrimitiveScalar(g.r, nt.Name, "dst")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n")
+ g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.unmarshalPrimitiveScalar(g.r, nt.Name, "src", g.typeName())
+ })
+ g.emit("}\n\n")
+
+ g.emit("// Packed implements marshal.Marshallable.Packed.\n")
+ g.emit("//go:nosplit\n")
+ g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("// Scalar newtypes are always packed.\n")
+ g.emit("return true\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n")
+ g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r)
+ })
+ g.emit("}\n\n")
+
+ g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n")
+ g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r)
+ })
+ g.emit("}\n\n")
+
+ g.emit("// CopyOutN implements marshal.Marshallable.CopyOutN.\n")
+ g.emit("//go:nosplit\n")
+ g.emit("func (%s *%s) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r))
+
+ g.emit("length, err := task.CopyOutBytes(addr, buf[:limit]) // escapes: okay.\n")
+ g.emitKeepAlive(g.r)
+ g.emit("return length, err\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n")
+ g.emit("//go:nosplit\n")
+ g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("return %s.CopyOutN(task, addr, %s.SizeBytes())\n", g.r, g.r)
+ })
+ g.emit("}\n\n")
+
+ g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n")
+ g.emit("//go:nosplit\n")
+ g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r))
+
+ g.emit("length, err := task.CopyInBytes(addr, buf) // escapes: okay.\n")
+ g.emitKeepAlive(g.r)
+ g.emit("return length, err\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// WriteTo implements io.WriterTo.WriteTo.\n")
+ g.emit("func (%s *%s) WriteTo(w io.Writer) (int64, error) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r))
+
+ g.emit("length, err := w.Write(buf)\n")
+ g.emitKeepAlive(g.r)
+ g.emit("return int64(length), err\n")
+
+ })
+ g.emit("}\n\n")
+}
+
+func (g *interfaceGenerator) emitMarshallableSliceForPrimitiveNewtype(nt *ast.Ident, slice *sliceAPI) {
+ g.recordUsedImport("marshal")
+ g.recordUsedImport("usermem")
+ g.recordUsedImport("reflect")
+ g.recordUsedImport("runtime")
+ g.recordUsedImport("unsafe")
+
+ eltType := g.typeName()
+ if slice.inner {
+ eltType = nt.Name
+ }
+
+ g.emit("// Copy%sIn copies in a slice of %s objects from the task's memory.\n", slice.ident, eltType)
+ g.emit("//go:nosplit\n")
+ g.emit("func Copy%sIn(task marshal.Task, addr usermem.Addr, dst []%s) (int, error) {\n", slice.ident, eltType)
+ g.inIndent(func() {
+ g.emit("count := len(dst)\n")
+ g.emit("if count == 0 {\n")
+ g.inIndent(func() {
+ g.emit("return 0, nil\n")
+ })
+ g.emit("}\n")
+ g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName())
+
+ g.emitCastSliceToByteSlice("&dst", "buf", "size * count")
+
+ g.emit("length, err := task.CopyInBytes(addr, buf) // escapes: okay.\n")
+ g.emitKeepAlive("dst")
+ g.emit("return length, err\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// Copy%sOut copies a slice of %s objects to the task's memory.\n", slice.ident, eltType)
+ g.emit("//go:nosplit\n")
+ g.emit("func Copy%sOut(task marshal.Task, addr usermem.Addr, src []%s) (int, error) {\n", slice.ident, eltType)
+ g.inIndent(func() {
+ g.emit("count := len(src)\n")
+ g.emit("if count == 0 {\n")
+ g.inIndent(func() {
+ g.emit("return 0, nil\n")
+ })
+ g.emit("}\n")
+ g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName())
+
+ g.emitCastSliceToByteSlice("&src", "buf", "size * count")
+
+ g.emit("length, err := task.CopyOutBytes(addr, buf) // escapes: okay.\n")
+ g.emitKeepAlive("src")
+ g.emit("return length, err\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// MarshalUnsafe%s is like %s.MarshalUnsafe, but for a []%s.\n", slice.ident, g.typeName(), g.typeName())
+ g.emit("func MarshalUnsafe%s(src []%s, dst []byte) (int, error) {\n", slice.ident, g.typeName())
+ g.inIndent(func() {
+ g.emit("count := len(src)\n")
+ g.emit("if count == 0 {\n")
+ g.inIndent(func() {
+ g.emit("return 0, nil\n")
+ })
+ g.emit("}\n")
+ g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName())
+
+ g.emitNoEscapeSliceDataPointer("&src", "val")
+
+ g.emit("length, err := safecopy.CopyIn(dst[:(size*count)], val)\n")
+ g.emitKeepAlive("src")
+ g.emit("return length, err\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// UnmarshalUnsafe%s is like %s.UnmarshalUnsafe, but for a []%s.\n", slice.ident, g.typeName(), g.typeName())
+ g.emit("func UnmarshalUnsafe%s(dst []%s, src []byte) (int, error) {\n", slice.ident, g.typeName())
+ g.inIndent(func() {
+ g.emit("count := len(dst)\n")
+ g.emit("if count == 0 {\n")
+ g.inIndent(func() {
+ g.emit("return 0, nil\n")
+ })
+ g.emit("}\n")
+ g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName())
+
+ g.emitNoEscapeSliceDataPointer("&dst", "val")
+
+ g.emit("length, err := safecopy.CopyOut(val, src[:(size*count)])\n")
+ g.emitKeepAlive("dst")
+ g.emit("return length, err\n")
+ })
+ g.emit("}\n\n")
+}
diff --git a/tools/go_marshal/gomarshal/generator_interfaces_struct.go b/tools/go_marshal/gomarshal/generator_interfaces_struct.go
new file mode 100644
index 000000000..4b9cea08a
--- /dev/null
+++ b/tools/go_marshal/gomarshal/generator_interfaces_struct.go
@@ -0,0 +1,622 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// This file contains the bits of the code generator specific to marshalling
+// structs.
+
+package gomarshal
+
+import (
+ "fmt"
+ "go/ast"
+ "strings"
+)
+
+func (g *interfaceGenerator) fieldAccessor(n *ast.Ident) string {
+ return fmt.Sprintf("%s.%s", g.r, n.Name)
+}
+
+// areFieldsPackedExpression returns a go expression checking whether g.t's fields are
+// packed. Returns "", false if g.t has no fields that may be potentially
+// packed, otherwise returns <clause>, true, where <clause> is an expression
+// like "t.a.Packed() && t.b.Packed() && t.c.Packed()".
+func (g *interfaceGenerator) areFieldsPackedExpression() (string, bool) {
+ if len(g.as) == 0 {
+ return "", false
+ }
+
+ cs := make([]string, 0, len(g.as))
+ for accessor, _ := range g.as {
+ cs = append(cs, fmt.Sprintf("%s.Packed()", accessor))
+ }
+ return strings.Join(cs, " && "), true
+}
+
+// validateStruct ensures the type we're working with can be marshalled. These
+// checks are done ahead of time and in one place so we can make assumptions
+// later.
+func (g *interfaceGenerator) validateStruct(ts *ast.TypeSpec, st *ast.StructType) {
+ forEachStructField(st, func(f *ast.Field) {
+ if len(f.Names) == 0 {
+ g.abortAt(f.Pos(), "Cannot marshal structs with embedded fields, give the field a name; use '_' for anonymous fields such as padding fields")
+ }
+ })
+
+ forEachStructField(st, func(f *ast.Field) {
+ fieldDispatcher{
+ primitive: func(_, t *ast.Ident) {
+ g.validatePrimitiveNewtype(t)
+ },
+ selector: func(_, _, _ *ast.Ident) {
+ // No validation to perform on selector fields. However this
+ // callback must still be provided.
+ },
+ array: func(n *ast.Ident, a *ast.ArrayType, _ *ast.Ident) {
+ g.validateArrayNewtype(n, a)
+ },
+ unhandled: func(_ *ast.Ident) {
+ g.abortAt(f.Pos(), fmt.Sprintf("Marshalling not supported for %s fields", kindString(f.Type)))
+ },
+ }.dispatch(f)
+ })
+}
+
+func (g *interfaceGenerator) isStructPacked(st *ast.StructType) bool {
+ packed := true
+ forEachStructField(st, func(f *ast.Field) {
+ if f.Tag != nil {
+ if f.Tag.Value == "`marshal:\"unaligned\"`" {
+ if packed {
+ debugfAt(g.f.Position(g.t.Pos()),
+ fmt.Sprintf("Marking type '%s' as not packed due to tag `marshal:\"unaligned\"`.\n", g.t.Name))
+ packed = false
+ }
+ }
+ }
+ })
+ return packed
+}
+
+func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) {
+ thisPacked := g.isStructPacked(st)
+
+ g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n")
+ g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ primitiveSize := 0
+ var dynamicSizeTerms []string
+
+ forEachStructField(st, fieldDispatcher{
+ primitive: func(n, t *ast.Ident) {
+ if size, dynamic := g.scalarSize(t); !dynamic {
+ primitiveSize += size
+ } else {
+ g.recordUsedMarshallable(t.Name)
+ dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", t.Name))
+ }
+ },
+ selector: func(n, tX, tSel *ast.Ident) {
+ tName := fmt.Sprintf("%s.%s", tX.Name, tSel.Name)
+ g.recordUsedImport(tX.Name)
+ g.recordUsedMarshallable(tName)
+ dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", tName))
+ },
+ array: func(n *ast.Ident, a *ast.ArrayType, t *ast.Ident) {
+ lenExpr := g.arrayLenExpr(a)
+ if size, dynamic := g.scalarSize(t); !dynamic {
+ dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("%d*%s", size, lenExpr))
+ } else {
+ g.recordUsedMarshallable(t.Name)
+ dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()*%s", t.Name, lenExpr))
+ }
+ },
+ }.dispatch)
+ g.emit("return %d", primitiveSize)
+ if len(dynamicSizeTerms) > 0 {
+ g.incIndent()
+ }
+ {
+ for _, d := range dynamicSizeTerms {
+ g.emitNoIndent(" +\n")
+ g.emit(d)
+ }
+ }
+ if len(dynamicSizeTerms) > 0 {
+ g.decIndent()
+ }
+ })
+ g.emit("\n}\n\n")
+
+ g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n")
+ g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ forEachStructField(st, fieldDispatcher{
+ primitive: func(n, t *ast.Ident) {
+ if n.Name == "_" {
+ g.emit("// Padding: dst[:sizeof(%s)] ~= %s(0)\n", t.Name, t.Name)
+ if len, dynamic := g.scalarSize(t); !dynamic {
+ g.shift("dst", len)
+ } else {
+ // We can't use shiftDynamic here because we don't have
+ // an instance of the dynamic type we can reference here
+ // (since the version in this struct is anonymous). Use
+ // a typed nil pointer to call SizeBytes() instead.
+ g.emit("dst = dst[(*%s)(nil).SizeBytes():]\n", t.Name)
+ }
+ return
+ }
+ g.marshalScalar(g.fieldAccessor(n), t.Name, "dst")
+ },
+ selector: func(n, tX, tSel *ast.Ident) {
+ if n.Name == "_" {
+ g.emit("// Padding: dst[:sizeof(%s)] ~= %s(0)\n", tX.Name, tSel.Name)
+ g.emit("dst = dst[(*%s.%s)(nil).SizeBytes():]\n", tX.Name, tSel.Name)
+ return
+ }
+ g.marshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "dst")
+ },
+ array: func(n *ast.Ident, a *ast.ArrayType, t *ast.Ident) {
+ lenExpr := g.arrayLenExpr(a)
+ if n.Name == "_" {
+ g.emit("// Padding: dst[:sizeof(%s)*%s] ~= [%s]%s{0}\n", t.Name, lenExpr, lenExpr, t.Name)
+ if size, dynamic := g.scalarSize(t); !dynamic {
+ g.emit("dst = dst[%d*(%s):]\n", size, lenExpr)
+ } else {
+ // We can't use shiftDynamic here because we don't have
+ // an instance of the dynamic type we can reference here
+ // (since the version in this struct is anonymous). Use
+ // a typed nil pointer to call SizeBytes() instead.
+ g.emit("dst = dst[(*%s)(nil).SizeBytes()*(%s):]\n", t.Name, lenExpr)
+ }
+ return
+ }
+
+ g.emit("for idx := 0; idx < %s; idx++ {\n", lenExpr)
+ g.inIndent(func() {
+ g.marshalScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "dst")
+ })
+ g.emit("}\n")
+ },
+ }.dispatch)
+ })
+ g.emit("}\n\n")
+
+ g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n")
+ g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ forEachStructField(st, fieldDispatcher{
+ primitive: func(n, t *ast.Ident) {
+ if n.Name == "_" {
+ g.emit("// Padding: var _ %s ~= src[:sizeof(%s)]\n", t.Name, t.Name)
+ if len, dynamic := g.scalarSize(t); !dynamic {
+ g.shift("src", len)
+ } else {
+ // We don't have an instance of the dynamic type we can
+ // reference here (since the version in this struct is
+ // anonymous). Use a typed nil pointer to call
+ // SizeBytes() instead.
+ g.shiftDynamic("src", fmt.Sprintf("(*%s)(nil)", t.Name))
+ g.recordPotentiallyNonPackedField(fmt.Sprintf("(*%s)(nil)", t.Name))
+ }
+ return
+ }
+ g.unmarshalScalar(g.fieldAccessor(n), t.Name, "src")
+ },
+ selector: func(n, tX, tSel *ast.Ident) {
+ if n.Name == "_" {
+ g.emit("// Padding: %s ~= src[:sizeof(%s.%s)]\n", g.fieldAccessor(n), tX.Name, tSel.Name)
+ g.emit("src = src[(*%s.%s)(nil).SizeBytes():]\n", tX.Name, tSel.Name)
+ g.recordPotentiallyNonPackedField(fmt.Sprintf("(*%s.%s)(nil)", tX.Name, tSel.Name))
+ return
+ }
+ g.unmarshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "src")
+ },
+ array: func(n *ast.Ident, a *ast.ArrayType, t *ast.Ident) {
+ lenExpr := g.arrayLenExpr(a)
+ if n.Name == "_" {
+ g.emit("// Padding: ~ copy([%s]%s(%s), src[:sizeof(%s)*%s])\n", lenExpr, t.Name, g.fieldAccessor(n), t.Name, lenExpr)
+ if size, dynamic := g.scalarSize(t); !dynamic {
+ g.emit("src = src[%d*(%s):]\n", size, lenExpr)
+ } else {
+ // We can't use shiftDynamic here because we don't have
+ // an instance of the dynamic type we can referece here
+ // (since the version in this struct is anonymous). Use
+ // a typed nil pointer to call SizeBytes() instead.
+ g.emit("src = src[(*%s)(nil).SizeBytes()*(%s):]\n", t.Name, lenExpr)
+ }
+ return
+ }
+
+ g.emit("for idx := 0; idx < %s; idx++ {\n", lenExpr)
+ g.inIndent(func() {
+ g.unmarshalScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "src")
+ })
+ g.emit("}\n")
+ },
+ }.dispatch)
+ })
+ g.emit("}\n\n")
+
+ g.emit("// Packed implements marshal.Marshallable.Packed.\n")
+ g.emit("//go:nosplit\n")
+ g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ expr, fieldsMaybePacked := g.areFieldsPackedExpression()
+ switch {
+ case !thisPacked:
+ g.emit("return false\n")
+ case fieldsMaybePacked:
+ g.emit("return %s\n", expr)
+ default:
+ g.emit("return true\n")
+
+ }
+ })
+ g.emit("}\n\n")
+
+ g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n")
+ g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ fallback := func() {
+ g.emit("// Type %s doesn't have a packed layout in memory, fallback to MarshalBytes.\n", g.typeName())
+ g.emit("%s.MarshalBytes(dst)\n", g.r)
+ }
+ if thisPacked {
+ g.recordUsedImport("safecopy")
+ g.recordUsedImport("unsafe")
+ if cond, ok := g.areFieldsPackedExpression(); ok {
+ g.emit("if %s {\n", cond)
+ g.inIndent(func() {
+ g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r)
+ })
+ g.emit("} else {\n")
+ g.inIndent(fallback)
+ g.emit("}\n")
+ } else {
+ g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r)
+ }
+ } else {
+ fallback()
+ }
+ })
+ g.emit("}\n\n")
+
+ g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n")
+ g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ fallback := func() {
+ g.emit("// Type %s doesn't have a packed layout in memory, fallback to UnmarshalBytes.\n", g.typeName())
+ g.emit("%s.UnmarshalBytes(src)\n", g.r)
+ }
+ if thisPacked {
+ g.recordUsedImport("safecopy")
+ g.recordUsedImport("unsafe")
+ if cond, ok := g.areFieldsPackedExpression(); ok {
+ g.emit("if %s {\n", cond)
+ g.inIndent(func() {
+ g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r)
+ })
+ g.emit("} else {\n")
+ g.inIndent(fallback)
+ g.emit("}\n")
+ } else {
+ g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r)
+ }
+ } else {
+ fallback()
+ }
+ })
+ g.emit("}\n\n")
+
+ g.emit("// CopyOutN implements marshal.Marshallable.CopyOutN.\n")
+ g.emit("//go:nosplit\n")
+ g.recordUsedImport("marshal")
+ g.recordUsedImport("usermem")
+ g.emit("func (%s *%s) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ fallback := func() {
+ g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName())
+ g.emit("buf := task.CopyScratchBuffer(%s.SizeBytes()) // escapes: okay.\n", g.r)
+ g.emit("%s.MarshalBytes(buf) // escapes: fallback.\n", g.r)
+ g.emit("return task.CopyOutBytes(addr, buf[:limit]) // escapes: okay.\n")
+ }
+ if thisPacked {
+ g.recordUsedImport("reflect")
+ g.recordUsedImport("runtime")
+ g.recordUsedImport("unsafe")
+ if cond, ok := g.areFieldsPackedExpression(); ok {
+ g.emit("if !%s {\n", cond)
+ g.inIndent(fallback)
+ g.emit("}\n\n")
+ }
+ // Fast serialization.
+ g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r))
+
+ g.emit("length, err := task.CopyOutBytes(addr, buf[:limit]) // escapes: okay.\n")
+ g.emitKeepAlive(g.r)
+ g.emit("return length, err\n")
+ } else {
+ fallback()
+ }
+ })
+ g.emit("}\n\n")
+
+ g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n")
+ g.emit("//go:nosplit\n")
+ g.recordUsedImport("marshal")
+ g.recordUsedImport("usermem")
+ g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("return %s.CopyOutN(task, addr, %s.SizeBytes())\n", g.r, g.r)
+ })
+ g.emit("}\n\n")
+
+ g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n")
+ g.emit("//go:nosplit\n")
+ g.recordUsedImport("marshal")
+ g.recordUsedImport("usermem")
+ g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ fallback := func() {
+ g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName())
+ g.emit("buf := task.CopyScratchBuffer(%s.SizeBytes()) // escapes: okay.\n", g.r)
+ g.emit("length, err := task.CopyInBytes(addr, buf) // escapes: okay.\n")
+ g.emit("// Unmarshal unconditionally. If we had a short copy-in, this results in a\n")
+ g.emit("// partially unmarshalled struct.\n")
+ g.emit("%s.UnmarshalBytes(buf) // escapes: fallback.\n", g.r)
+ g.emit("return length, err\n")
+ }
+ if thisPacked {
+ g.recordUsedImport("reflect")
+ g.recordUsedImport("runtime")
+ g.recordUsedImport("unsafe")
+ if cond, ok := g.areFieldsPackedExpression(); ok {
+ g.emit("if !%s {\n", cond)
+ g.inIndent(fallback)
+ g.emit("}\n\n")
+ }
+ // Fast deserialization.
+ g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r))
+
+ g.emit("length, err := task.CopyInBytes(addr, buf) // escapes: okay.\n")
+ g.emitKeepAlive(g.r)
+ g.emit("return length, err\n")
+ } else {
+ fallback()
+ }
+ })
+ g.emit("}\n\n")
+
+ g.emit("// WriteTo implements io.WriterTo.WriteTo.\n")
+ g.recordUsedImport("io")
+ g.emit("func (%s *%s) WriteTo(w io.Writer) (int64, error) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ fallback := func() {
+ g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName())
+ g.emit("buf := make([]byte, %s.SizeBytes())\n", g.r)
+ g.emit("%s.MarshalBytes(buf)\n", g.r)
+ g.emit("length, err := w.Write(buf)\n")
+ g.emit("return int64(length), err\n")
+ }
+ if thisPacked {
+ g.recordUsedImport("reflect")
+ g.recordUsedImport("runtime")
+ g.recordUsedImport("unsafe")
+ if cond, ok := g.areFieldsPackedExpression(); ok {
+ g.emit("if !%s {\n", cond)
+ g.inIndent(fallback)
+ g.emit("}\n\n")
+ }
+ // Fast serialization.
+ g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r))
+
+ g.emit("length, err := w.Write(buf)\n")
+ g.emitKeepAlive(g.r)
+ g.emit("return int64(length), err\n")
+ } else {
+ fallback()
+ }
+ })
+ g.emit("}\n\n")
+}
+
+func (g *interfaceGenerator) emitMarshallableSliceForStruct(st *ast.StructType, slice *sliceAPI) {
+ thisPacked := g.isStructPacked(st)
+
+ if slice.inner {
+ abortAt(g.f.Position(slice.comment.Slash), fmt.Sprintf("The ':inner' argument to '+marshal slice:%s:inner' is only applicable to newtypes on primitives. Remove it from this struct declaration.", slice.ident))
+ }
+
+ g.recordUsedImport("marshal")
+ g.recordUsedImport("usermem")
+
+ g.emit("// Copy%sIn copies in a slice of %s objects from the task's memory.\n", slice.ident, g.typeName())
+ g.emit("func Copy%sIn(task marshal.Task, addr usermem.Addr, dst []%s) (int, error) {\n", slice.ident, g.typeName())
+ g.inIndent(func() {
+ g.emit("count := len(dst)\n")
+ g.emit("if count == 0 {\n")
+ g.inIndent(func() {
+ g.emit("return 0, nil\n")
+ })
+ g.emit("}\n")
+ g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName())
+
+ fallback := func() {
+ g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName())
+ g.emit("buf := task.CopyScratchBuffer(size * count)\n")
+ g.emit("length, err := task.CopyInBytes(addr, buf)\n\n")
+
+ g.emit("// Unmarshal as much as possible, even on error. First handle full objects.\n")
+ g.emit("limit := length/size\n")
+ g.emit("for idx := 0; idx < limit; idx++ {\n")
+ g.inIndent(func() {
+ g.emit("dst[idx].UnmarshalBytes(buf[size*idx:size*(idx+1)])\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// Handle any final partial object. buf is guaranteed to be long enough for the\n")
+ g.emit("// final element, but may not contain valid data for the entire range. This may\n")
+ g.emit("// result in unmarshalling zero values for some parts of the object.\n")
+ g.emit("if length%size != 0 {\n")
+ g.inIndent(func() {
+ g.emit("idx := limit\n")
+ g.emit("dst[idx].UnmarshalBytes(buf[size*idx:size*(idx+1)])\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("return length, err\n")
+ }
+ if thisPacked {
+ g.recordUsedImport("reflect")
+ g.recordUsedImport("runtime")
+ g.recordUsedImport("unsafe")
+ if _, ok := g.areFieldsPackedExpression(); ok {
+ g.emit("if !dst[0].Packed() {\n")
+ g.inIndent(fallback)
+ g.emit("}\n\n")
+ }
+ // Fast deserialization.
+ g.emitCastSliceToByteSlice("&dst", "buf", "size * count")
+
+ g.emit("length, err := task.CopyInBytes(addr, buf)\n")
+ g.emitKeepAlive("dst")
+ g.emit("return length, err\n")
+ } else {
+ fallback()
+ }
+ })
+ g.emit("}\n\n")
+
+ g.emit("// Copy%sOut copies a slice of %s objects to the task's memory.\n", slice.ident, g.typeName())
+ g.emit("func Copy%sOut(task marshal.Task, addr usermem.Addr, src []%s) (int, error) {\n", slice.ident, g.typeName())
+ g.inIndent(func() {
+ g.emit("count := len(src)\n")
+ g.emit("if count == 0 {\n")
+ g.inIndent(func() {
+ g.emit("return 0, nil\n")
+ })
+ g.emit("}\n")
+ g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName())
+
+ fallback := func() {
+ g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName())
+ g.emit("buf := task.CopyScratchBuffer(size * count)\n")
+ g.emit("for idx := 0; idx < count; idx++ {\n")
+ g.inIndent(func() {
+ g.emit("src[idx].MarshalBytes(buf[size*idx:size*(idx+1)])\n")
+ })
+ g.emit("}\n")
+ g.emit("return task.CopyOutBytes(addr, buf)\n")
+ }
+ if thisPacked {
+ g.recordUsedImport("reflect")
+ g.recordUsedImport("runtime")
+ g.recordUsedImport("unsafe")
+ if _, ok := g.areFieldsPackedExpression(); ok {
+ g.emit("if !src[0].Packed() {\n")
+ g.inIndent(fallback)
+ g.emit("}\n\n")
+ }
+ // Fast serialization.
+ g.emitCastSliceToByteSlice("&src", "buf", "size * count")
+
+ g.emit("length, err := task.CopyOutBytes(addr, buf)\n")
+ g.emitKeepAlive("src")
+ g.emit("return length, err\n")
+ } else {
+ fallback()
+ }
+ })
+ g.emit("}\n\n")
+
+ g.emit("// MarshalUnsafe%s is like %s.MarshalUnsafe, but for a []%s.\n", slice.ident, g.typeName(), g.typeName())
+ g.emit("func MarshalUnsafe%s(src []%s, dst []byte) (int, error) {\n", slice.ident, g.typeName())
+ g.inIndent(func() {
+ g.emit("count := len(src)\n")
+ g.emit("if count == 0 {\n")
+ g.inIndent(func() {
+ g.emit("return 0, nil\n")
+ })
+ g.emit("}\n")
+ g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName())
+
+ fallback := func() {
+ g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName())
+ g.emit("for idx := 0; idx < count; idx++ {\n")
+ g.inIndent(func() {
+ g.emit("src[idx].MarshalBytes(dst[size*idx:(size)*(idx+1)])\n")
+ })
+ g.emit("}\n")
+ g.emit("return size * count, nil\n")
+ }
+ if thisPacked {
+ g.recordUsedImport("reflect")
+ g.recordUsedImport("runtime")
+ g.recordUsedImport("unsafe")
+ if _, ok := g.areFieldsPackedExpression(); ok {
+ g.emit("if !src[0].Packed() {\n")
+ g.inIndent(fallback)
+ g.emit("}\n\n")
+ }
+ g.emitNoEscapeSliceDataPointer("&src", "val")
+
+ g.emit("length, err := safecopy.CopyIn(dst[:(size*count)], val)\n")
+ g.emitKeepAlive("src")
+ g.emit("return length, err\n")
+ } else {
+ fallback()
+ }
+ })
+ g.emit("}\n\n")
+
+ g.emit("// UnmarshalUnsafe%s is like %s.UnmarshalUnsafe, but for a []%s.\n", slice.ident, g.typeName(), g.typeName())
+ g.emit("func UnmarshalUnsafe%s(dst []%s, src []byte) (int, error) {\n", slice.ident, g.typeName())
+ g.inIndent(func() {
+ g.emit("count := len(dst)\n")
+ g.emit("if count == 0 {\n")
+ g.inIndent(func() {
+ g.emit("return 0, nil\n")
+ })
+ g.emit("}\n")
+ g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName())
+
+ fallback := func() {
+ g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName())
+ g.emit("for idx := 0; idx < count; idx++ {\n")
+ g.inIndent(func() {
+ g.emit("dst[idx].UnmarshalBytes(src[size*idx:size*(idx+1)])\n")
+ })
+ g.emit("}\n")
+ g.emit("return size * count, nil\n")
+ }
+ if thisPacked {
+ g.recordUsedImport("reflect")
+ g.recordUsedImport("runtime")
+ g.recordUsedImport("unsafe")
+ if _, ok := g.areFieldsPackedExpression(); ok {
+ g.emit("if !dst[0].Packed() {\n")
+ g.inIndent(fallback)
+ g.emit("}\n\n")
+ }
+ g.emitNoEscapeSliceDataPointer("&dst", "val")
+
+ g.emit("length, err := safecopy.CopyOut(val, src[:(size*count)])\n")
+ g.emitKeepAlive("dst")
+ g.emit("return length, err\n")
+ } else {
+ fallback()
+ }
+ })
+ g.emit("}\n\n")
+}
diff --git a/tools/go_marshal/gomarshal/generator_tests.go b/tools/go_marshal/gomarshal/generator_tests.go
index df25cb5b2..631295373 100644
--- a/tools/go_marshal/gomarshal/generator_tests.go
+++ b/tools/go_marshal/gomarshal/generator_tests.go
@@ -22,12 +22,19 @@ import (
)
var standardImports = []string{
+ "bytes",
"fmt",
"reflect",
"testing",
+
"gvisor.dev/gvisor/tools/go_marshal/analysis",
}
+var sliceAPIImports = []string{
+ "encoding/binary",
+ "gvisor.dev/gvisor/pkg/usermem",
+}
+
type testGenerator struct {
sourceBuffer
@@ -46,10 +53,7 @@ type testGenerator struct {
decl *importStmt
}
-func newTestGenerator(t *ast.TypeSpec, declaration string) *testGenerator {
- if _, ok := t.Type.(*ast.StructType); !ok {
- panic(fmt.Sprintf("Attempting to generate code for a not struct type %v", t))
- }
+func newTestGenerator(t *ast.TypeSpec) *testGenerator {
g := &testGenerator{
t: t,
r: receiverName(t),
@@ -59,22 +63,17 @@ func newTestGenerator(t *ast.TypeSpec, declaration string) *testGenerator {
for _, i := range standardImports {
g.imports.add(i).markUsed()
}
- g.decl = g.imports.add(declaration)
- g.decl.markUsed()
+ // These imports are used if a type requests the slice API. Don't
+ // mark them as used by default.
+ for _, i := range sliceAPIImports {
+ g.imports.add(i)
+ }
return g
}
func (g *testGenerator) typeName() string {
- return fmt.Sprintf("%s.%s", g.decl.name, g.t.Name.Name)
-}
-
-func (g *testGenerator) forEachField(fn func(f *ast.Field)) {
- // This is guaranteed to succeed because g.t is always a struct.
- st := g.t.Type.(*ast.StructType)
- for _, field := range st.Fields.List {
- fn(field)
- }
+ return g.t.Name.Name
}
func (g *testGenerator) testFuncName(base string) string {
@@ -89,10 +88,10 @@ func (g *testGenerator) inTestFunction(name string, body func()) {
func (g *testGenerator) emitTestNonZeroSize() {
g.inTestFunction("TestSizeNonZero", func() {
- g.emit("x := &%s{}\n", g.typeName())
+ g.emit("var x %v\n", g.typeName())
g.emit("if x.SizeBytes() == 0 {\n")
g.inIndent(func() {
- g.emit("t.Fatal(\"Marshallable.Size() should not return zero\")\n")
+ g.emit("t.Fatal(\"Marshallable.SizeBytes() should not return zero\")\n")
})
g.emit("}\n")
})
@@ -100,7 +99,7 @@ func (g *testGenerator) emitTestNonZeroSize() {
func (g *testGenerator) emitTestSuspectAlignment() {
g.inTestFunction("TestSuspectAlignment", func() {
- g.emit("x := %s{}\n", g.typeName())
+ g.emit("var x %v\n", g.typeName())
g.emit("analysis.AlignmentCheck(t, reflect.TypeOf(x))\n")
})
}
@@ -118,35 +117,115 @@ func (g *testGenerator) emitTestMarshalUnmarshalPreservesData() {
g.emit("y.UnmarshalBytes(buf)\n")
g.emit("if !reflect.DeepEqual(x, y) {\n")
g.inIndent(func() {
- g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across Marshal/Unmarshal cycle:\\nBefore: %%+v\\nAfter: %%+v\\n\", x, y))\n")
+ g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalBytes/UnmarshalBytes cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, y))\n")
})
g.emit("}\n")
g.emit("yUnsafe.UnmarshalBytes(bufUnsafe)\n")
g.emit("if !reflect.DeepEqual(x, yUnsafe) {\n")
g.inIndent(func() {
- g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalUnsafe/Unmarshal cycle:\\nBefore: %%+v\\nAfter: %%+v\\n\", x, yUnsafe))\n")
+ g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalUnsafe/UnmarshalBytes cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, yUnsafe))\n")
})
g.emit("}\n\n")
g.emit("z.UnmarshalUnsafe(buf)\n")
g.emit("if !reflect.DeepEqual(x, z) {\n")
g.inIndent(func() {
- g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across Marshal/UnmarshalUnsafe cycle:\\nBefore: %%+v\\nAfter: %%+v\\n\", x, z))\n")
+ g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalBytes/UnmarshalUnsafe cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, z))\n")
})
g.emit("}\n")
g.emit("zUnsafe.UnmarshalUnsafe(bufUnsafe)\n")
g.emit("if !reflect.DeepEqual(x, zUnsafe) {\n")
g.inIndent(func() {
- g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalUnsafe/UnmarshalUnsafe cycle:\\nBefore: %%+v\\nAfter: %%+v\\n\", x, zUnsafe))\n")
+ g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalUnsafe/UnmarshalUnsafe cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, zUnsafe))\n")
+ })
+ g.emit("}\n")
+ })
+}
+
+func (g *testGenerator) emitTestMarshalUnmarshalSlicePreservesData(slice *sliceAPI) {
+ for _, name := range []string{"binary", "usermem"} {
+ if !g.imports.markUsed(name) {
+ panic(fmt.Sprintf("Generated test for '%s' referenced a non-existent import with local name '%s'", g.typeName(), name))
+ }
+ }
+
+ g.inTestFunction("TestSafeMarshalUnmarshalSlicePreservesData", func() {
+ g.emit("var x, y, yUnsafe [8]%s\n", g.typeName())
+ g.emit("analysis.RandomizeValue(&x)\n\n")
+ g.emit("size := (*%s)(nil).SizeBytes() * len(x)\n", g.typeName())
+ g.emit("buf := bytes.NewBuffer(make([]byte, size))\n")
+ g.emit("buf.Reset()\n")
+ g.emit("if err := binary.Write(buf, usermem.ByteOrder, x[:]); err != nil {\n")
+ g.inIndent(func() {
+ g.emit("t.Fatal(fmt.Sprintf(\"binary.Write failed: %v\", err))\n")
+ })
+ g.emit("}\n")
+ g.emit("bufUnsafe := make([]byte, size)\n")
+ g.emit("MarshalUnsafe%s(x[:], bufUnsafe)\n\n", slice.ident)
+
+ g.emit("UnmarshalUnsafe%s(y[:], buf.Bytes())\n", slice.ident)
+ g.emit("if !reflect.DeepEqual(x, y) {\n")
+ g.inIndent(func() {
+ g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across binary.Write/UnmarshalUnsafeSlice cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, y))\n")
+ })
+ g.emit("}\n")
+ g.emit("UnmarshalUnsafe%s(yUnsafe[:], bufUnsafe)\n", slice.ident)
+ g.emit("if !reflect.DeepEqual(x, yUnsafe) {\n")
+ g.inIndent(func() {
+ g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalUnsafeSlice/UnmarshalUnsafeSlice cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, yUnsafe))\n")
+ })
+ g.emit("}\n\n")
+ })
+}
+
+func (g *testGenerator) emitTestWriteToUnmarshalPreservesData() {
+ g.inTestFunction("TestWriteToUnmarshalPreservesData", func() {
+ g.emit("var x, y, yUnsafe %s\n", g.typeName())
+ g.emit("analysis.RandomizeValue(&x)\n\n")
+
+ g.emit("var buf bytes.Buffer\n\n")
+
+ g.emit("x.WriteTo(&buf)\n")
+ g.emit("y.UnmarshalBytes(buf.Bytes())\n\n")
+ g.emit("yUnsafe.UnmarshalUnsafe(buf.Bytes())\n\n")
+
+ g.emit("if !reflect.DeepEqual(x, y) {\n")
+ g.inIndent(func() {
+ g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across WriteTo/UnmarshalBytes cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, y))\n")
+ })
+ g.emit("}\n")
+ g.emit("if !reflect.DeepEqual(x, yUnsafe) {\n")
+ g.inIndent(func() {
+ g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across WriteTo/UnmarshalUnsafe cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, yUnsafe))\n")
+ })
+ g.emit("}\n")
+ })
+}
+
+func (g *testGenerator) emitTestSizeBytesOnTypedNilPtr() {
+ g.inTestFunction("TestSizeBytesOnTypedNilPtr", func() {
+ g.emit("var x %s\n", g.typeName())
+ g.emit("sizeFromConcrete := x.SizeBytes()\n")
+ g.emit("sizeFromTypedNilPtr := (*%s)(nil).SizeBytes()\n\n", g.typeName())
+
+ g.emit("if sizeFromTypedNilPtr != sizeFromConcrete {\n")
+ g.inIndent(func() {
+ g.emit("t.Fatalf(\"SizeBytes() on typed nil pointer (%v) doesn't match size returned by a concrete object (%v).\\n\", sizeFromTypedNilPtr, sizeFromConcrete)\n")
})
g.emit("}\n")
})
}
-func (g *testGenerator) emitTests() {
+func (g *testGenerator) emitTests(slice *sliceAPI) {
g.emitTestNonZeroSize()
g.emitTestSuspectAlignment()
g.emitTestMarshalUnmarshalPreservesData()
+ g.emitTestWriteToUnmarshalPreservesData()
+ g.emitTestSizeBytesOnTypedNilPtr()
+
+ if slice != nil {
+ g.emitTestMarshalUnmarshalSlicePreservesData(slice)
+ }
}
func (g *testGenerator) write(out io.Writer) error {
diff --git a/tools/go_marshal/gomarshal/util.go b/tools/go_marshal/gomarshal/util.go
index 967537abf..d94314302 100644
--- a/tools/go_marshal/gomarshal/util.go
+++ b/tools/go_marshal/gomarshal/util.go
@@ -25,7 +25,6 @@ import (
"path"
"reflect"
"sort"
- "strconv"
"strings"
)
@@ -64,12 +63,18 @@ func kindString(e ast.Expr) string {
}
}
+func forEachStructField(st *ast.StructType, fn func(f *ast.Field)) {
+ for _, field := range st.Fields.List {
+ fn(field)
+ }
+}
+
// fieldDispatcher is a collection of callbacks for handling different types of
// fields in a struct declaration.
type fieldDispatcher struct {
primitive func(n, t *ast.Ident)
selector func(n, tX, tSel *ast.Ident)
- array func(n, t *ast.Ident, size int)
+ array func(n *ast.Ident, a *ast.ArrayType, t *ast.Ident)
unhandled func(n *ast.Ident)
}
@@ -96,22 +101,12 @@ func (fd fieldDispatcher) dispatch(f *ast.Field) {
case *ast.SelectorExpr:
fd.selector(name, v.X.(*ast.Ident), v.Sel)
case *ast.ArrayType:
- len := 0
- if v.Len != nil {
- // Non-literal array length is handled by generatorInterfaces.validate().
- if lenLit, ok := v.Len.(*ast.BasicLit); ok {
- var err error
- len, err = strconv.Atoi(lenLit.Value)
- if err != nil {
- panic(err)
- }
- }
- }
switch t := v.Elt.(type) {
case *ast.Ident:
- fd.array(name, t, len)
+ fd.array(name, v, t)
default:
- fd.array(name, nil, len)
+ // Should be handled with a better error message during validate.
+ panic(fmt.Sprintf("Array element type is of unsupported kind. Expected *ast.Ident, got %v", t))
}
default:
fd.unhandled(name)
@@ -219,6 +214,11 @@ type sourceBuffer struct {
b bytes.Buffer
}
+func (b *sourceBuffer) reset() {
+ b.indent = 0
+ b.b.Reset()
+}
+
func (b *sourceBuffer) incIndent() {
b.indent++
}
@@ -265,6 +265,11 @@ type importStmt struct {
aliased bool
// Indicates whether this import was referenced by generated code.
used bool
+ // AST node and file set representing the import statement, if any. These
+ // are only non-nil if the import statement originates from an input source
+ // file.
+ spec *ast.ImportSpec
+ fset *token.FileSet
}
func newImport(p string) *importStmt {
@@ -290,14 +295,27 @@ func newImportFromSpec(spec *ast.ImportSpec, f *token.FileSet) *importStmt {
name: name,
path: p,
aliased: spec.Name != nil,
+ spec: spec,
+ fset: f,
}
}
+// String implements fmt.Stringer.String. This generates a string for the import
+// statement appropriate for writing directly to generated code.
func (i *importStmt) String() string {
if i.aliased {
- return fmt.Sprintf("%s \"%s\"", i.name, i.path)
+ return fmt.Sprintf("%s %q", i.name, i.path)
}
- return fmt.Sprintf("\"%s\"", i.path)
+ return fmt.Sprintf("%q", i.path)
+}
+
+// debugString returns a debug string representing an import statement. This
+// representation is not valid golang code and is used for debugging output.
+func (i *importStmt) debugString() string {
+ if i.spec != nil && i.fset != nil {
+ return fmt.Sprintf("%s: %s", i.fset.Position(i.spec.Path.Pos()), i)
+ }
+ return fmt.Sprintf("(go-marshal import): %s", i)
}
func (i *importStmt) markUsed() {
@@ -305,58 +323,111 @@ func (i *importStmt) markUsed() {
}
func (i *importStmt) equivalent(other *importStmt) bool {
- return i == other
+ return i.name == other.name && i.path == other.path && i.aliased == other.aliased
}
// importTable represents a collection of importStmts.
+//
+// An importTable may contain multiple import statements referencing the same
+// local name. All import statements aliasing to the same local name are
+// technically ambiguous, as if such an import name is used in the generated
+// code, it's not clear which import statement it refers to. We ignore any
+// potential collisions until actually writing the import table to the generated
+// source file. See importTable.write.
+//
+// Given the following import statements across all the files comprising a
+// package marshalled:
+//
+// "sync"
+// "pkg/sync"
+// "pkg/sentry/kernel"
+// ktime "pkg/sentry/kernel/time"
+//
+// An importTable representing them would look like this:
+//
+// importTable {
+// is: map[string][]*importStmt {
+// "sync": []*importStmt{
+// importStmt{name:"sync", path:"sync", aliased:false}
+// importStmt{name:"sync", path:"pkg/sync", aliased:false}
+// },
+// "kernel": []*importStmt{importStmt{
+// name: "kernel",
+// path: "pkg/sentry/kernel",
+// aliased: false
+// }},
+// "ktime": []*importStmt{importStmt{
+// name: "ktime",
+// path: "pkg/sentry/kernel/time",
+// aliased: true,
+// }},
+// }
+// }
+//
+// Note that the local name "sync" is assigned to two different import
+// statements. This is possible if the import statements are from different
+// source files in the same package.
+//
+// Since go-marshal generates a single output file per package regardless of the
+// number of input files, if "sync" is referenced by any generated code, it's
+// unclear which import statement "sync" refers to. While it's theoretically
+// possible to resolve this by assigning a unique local alias to each instance
+// of the sync package, go-marshal currently aborts when it encounters such an
+// ambiguity.
+//
+// TODO(b/151478251): importTable considers the final component of an import
+// path to be the package name, but this is only a convention. The actual
+// package name is determined by the package statement in the source files for
+// the package.
type importTable struct {
// Map of imports and whether they should be copied to the output.
- is map[string]*importStmt
+ is map[string][]*importStmt
}
func newImportTable() *importTable {
return &importTable{
- is: make(map[string]*importStmt),
+ is: make(map[string][]*importStmt),
}
}
-// Merges import statements from other into i. Collisions in import statements
-// result in a panic.
+// Merges import statements from other into i.
func (i *importTable) merge(other *importTable) {
- for name, im := range other.is {
- if dup, ok := i.is[name]; ok && dup.equivalent(im) {
- panic(fmt.Sprintf("Found colliding import statements: ours: %+v, other's: %+v", dup, im))
- }
-
- i.is[name] = im
+ for name, ims := range other.is {
+ i.is[name] = append(i.is[name], ims...)
}
}
+func (i *importTable) addStmt(s *importStmt) *importStmt {
+ i.is[s.name] = append(i.is[s.name], s)
+ return s
+}
+
func (i *importTable) add(s string) *importStmt {
n := newImport(s)
- i.is[n.name] = n
- return n
+ return i.addStmt(n)
}
func (i *importTable) addFromSpec(spec *ast.ImportSpec, f *token.FileSet) *importStmt {
- n := newImportFromSpec(spec, f)
- i.is[n.name] = n
- return n
+ return i.addStmt(newImportFromSpec(spec, f))
}
// Marks the import named n as used. If no such import is in the table, returns
// false.
func (i *importTable) markUsed(n string) bool {
- if n, ok := i.is[n]; ok {
- n.markUsed()
+ if ns, ok := i.is[n]; ok {
+ for _, n := range ns {
+ n.markUsed()
+ }
return true
}
return false
}
func (i *importTable) clear() {
- for _, i := range i.is {
- i.used = false
+ for _, is := range i.is {
+ for _, i := range is {
+ i.used = false
+ }
}
}
@@ -367,9 +438,42 @@ func (i *importTable) write(out io.Writer) error {
}
imports := make([]string, 0, len(i.is))
- for _, i := range i.is {
- if i.used {
- imports = append(imports, i.String())
+ for name, is := range i.is {
+ var lastUsed *importStmt
+ var ambiguous bool
+
+ for _, i := range is {
+ if i.used {
+ if lastUsed != nil {
+ if !i.equivalent(lastUsed) {
+ ambiguous = true
+ }
+ }
+ lastUsed = i
+ }
+ }
+
+ if ambiguous {
+ // We have two or more import statements across the different source
+ // files that share a local name, and at least one of these imports
+ // are used by the generated code. This ambiguity can't be resolved
+ // by go-marshal and requires the user intervention. Dump a list of
+ // the colliding import statements and let the user modify the input
+ // files as appropriate.
+ var b strings.Builder
+ fmt.Fprintf(&b, "The imported name %q is used by one of the types marked for marshalling, and which import statement the code refers to is ambiguous. Perhaps give the imports unique local names?\n\n", name)
+ fmt.Fprintf(&b, "The following %d import statements are ambiguous for the local name %q:\n", len(is), name)
+ // Note: len(is) is guaranteed to be 1 or greater or ambiguous can't
+ // be true. Therefore the slicing below is safe.
+ for _, i := range is[:len(is)-1] {
+ fmt.Fprintf(&b, " %v\n", i.debugString())
+ }
+ fmt.Fprintf(&b, " %v", is[len(is)-1].debugString())
+ panic(b.String())
+ }
+
+ if lastUsed != nil {
+ imports = append(imports, lastUsed.String())
}
}
sort.Strings(imports)