summaryrefslogtreecommitdiffhomepage
path: root/tools/go_marshal
diff options
context:
space:
mode:
authorAndrei Vagin <avagin@google.com>2019-10-02 13:00:07 -0700
committerGitHub <noreply@github.com>2019-10-02 13:00:07 -0700
commit9a875306dbabcf335a2abccc08119a1b67d0e51a (patch)
tree0f72c12e951a5eee7156df7a5d63351bc89befa6 /tools/go_marshal
parent38bc0b6b6addd25ceec4f66ef1af41c1e61e2985 (diff)
parent03ce4dd86c9acd6b6148f68d5d2cf025d8c254bb (diff)
Merge branch 'master' into pr_syscall_linux
Diffstat (limited to 'tools/go_marshal')
-rw-r--r--tools/go_marshal/BUILD14
-rw-r--r--tools/go_marshal/README.md164
-rw-r--r--tools/go_marshal/analysis/BUILD13
-rw-r--r--tools/go_marshal/analysis/analysis_unsafe.go175
-rw-r--r--tools/go_marshal/defs.bzl152
-rw-r--r--tools/go_marshal/gomarshal/BUILD17
-rw-r--r--tools/go_marshal/gomarshal/generator.go382
-rw-r--r--tools/go_marshal/gomarshal/generator_interfaces.go507
-rw-r--r--tools/go_marshal/gomarshal/generator_tests.go154
-rw-r--r--tools/go_marshal/gomarshal/util.go387
-rw-r--r--tools/go_marshal/main.go73
-rw-r--r--tools/go_marshal/marshal/BUILD14
-rw-r--r--tools/go_marshal/marshal/marshal.go60
-rw-r--r--tools/go_marshal/test/BUILD31
-rw-r--r--tools/go_marshal/test/benchmark_test.go178
-rw-r--r--tools/go_marshal/test/external/BUILD11
-rw-r--r--tools/go_marshal/test/external/external.go23
-rw-r--r--tools/go_marshal/test/test.go105
18 files changed, 2460 insertions, 0 deletions
diff --git a/tools/go_marshal/BUILD b/tools/go_marshal/BUILD
new file mode 100644
index 000000000..c862b277c
--- /dev/null
+++ b/tools/go_marshal/BUILD
@@ -0,0 +1,14 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_binary")
+
+package(licenses = ["notice"])
+
+go_binary(
+ name = "go_marshal",
+ srcs = ["main.go"],
+ visibility = [
+ "//:sandbox",
+ ],
+ deps = [
+ "//tools/go_marshal/gomarshal",
+ ],
+)
diff --git a/tools/go_marshal/README.md b/tools/go_marshal/README.md
new file mode 100644
index 000000000..481575bd3
--- /dev/null
+++ b/tools/go_marshal/README.md
@@ -0,0 +1,164 @@
+This package implements the go_marshal utility.
+
+# Overview
+
+`go_marshal` is a code generation utility similar to `go_stateify` for
+automatically generating code to marshal go data structures to memory.
+
+`go_marshal` attempts to improve on `binary.Write` and the sentry's
+`binary.Marshal` by moving the go runtime reflection necessary to marshal a
+struct to compile-time.
+
+`go_marshal` automatically generates implementations for `abi.Marshallable` and
+`safemem.{Reader,Writer}`. Call-sites for serialization (typically syscall
+implementations) can directly invoke `safemem.Reader.ReadToBlocks` and
+`safemem.Writer.WriteFromBlocks`. Data structures that require custom
+serialization will have manual implementations for these interfaces.
+
+Data structures can be flagged for code generation by adding a struct-level
+comment `// +marshal`.
+
+# Usage
+
+See `defs.bzl`: two new rules are provided, `go_marshal` and `go_library`.
+
+The recommended way to generate a go library with marshalling is to use the
+`go_library` with mostly identical configuration as the native go_library rule.
+
+```
+load("<PKGPATH>/gvisor/tools/go_marshal:defs.bzl", "go_library")
+
+go_library(
+ name = "foo",
+ srcs = ["foo.go"],
+)
+```
+
+Under the hood, the `go_marshal` rule is used to generate a file that will
+appear in a Go target; the output file should appear explicitly in a srcs list.
+For example (note that the above is the preferred method):
+
+```
+load("<PKGPATH>/gvisor/tools/go_marshal:defs.bzl", "go_marshal")
+
+go_marshal(
+ name = "foo_abi",
+ srcs = ["foo.go"],
+ out = "foo_abi.go",
+ package = "foo",
+)
+
+go_library(
+ name = "foo",
+ srcs = [
+ "foo.go",
+ "foo_abi.go",
+ ],
+ deps = [
+ "<PKGPATH>/gvisor/pkg/abi",
+ "<PKGPATH>/gvisor/pkg/sentry/safemem/safemem",
+ "<PKGPATH>/gvisor/pkg/sentry/usermem/usermem",
+ ],
+)
+```
+
+As part of the interface generation, `go_marshal` also generates some tests for
+sanity checking the struct definitions for potential alignment issues, and a
+simple round-trip test through Marshal/Unmarshal to verify the implementation.
+These tests use reflection to verify properties of the ABI struct, and should be
+considered part of the generated interfaces (but are too expensive to execute at
+runtime). Ensure these tests run at some point.
+
+```
+$ cat BUILD
+load("<PKGPATH>/gvisor/tools/go_marshal:defs.bzl", "go_library")
+
+go_library(
+ name = "foo",
+ srcs = ["foo.go"],
+)
+$ blaze build :foo
+$ blaze query ...
+<path-to-dir>:foo_abi_autogen
+<path-to-dir>:foo_abi_autogen_test
+$ blaze test :foo_abi_autogen_test
+<test-output>
+```
+
+# Restrictions
+
+Not all valid go type definitions can be used with `go_marshal`. `go_marshal` is
+intended for ABI structs, which have these additional restrictions:
+
+- At the moment, `go_marshal` only supports struct declarations.
+
+- Structs are marshalled as packed types. This means no implicit padding is
+ inserted between fields shorter than the platform register size. For
+ alignment, manually insert padding fields.
+
+- Structs used with `go_marshal` must have a compile-time static size. This
+ means no dynamically sizes fields like slices or strings. Use statically
+ sized array (byte arrays for strings) instead.
+
+- No pointers, channel, map or function pointer fields, and no fields that are
+ arrays of these types. These don't make sense in an ABI data structure.
+
+- We could support opaque pointers as `uintptr`, but this is currently not
+ implemented. Implementing this would require handling the architecture
+ dependent native pointer size.
+
+- Fields must either be a primitive integer type (`byte`,
+ `[u]int{8,16,32,64}`), or of a type that implements abi.Marshallable.
+
+- `int` and `uint` fields are not allowed. Use an explicitly-sized numeric
+ type.
+
+- `float*` fields are currently not supported, but could be if necessary.
+
+# Appendix
+
+## Working with Non-Packed Structs
+
+ABI structs must generally be packed types, meaning they should have no implicit
+padding between short fields. However, if a field is tagged
+`marshal:"unaligned"`, `go_marshal` will fall back to a safer but slower
+mechanism to deal with potentially unaligned fields.
+
+Note that the non-packed property is inheritted by any other struct that embeds
+this struct, since the `go_marshal` tool currently can't reason about alignments
+for embedded structs that are not aligned.
+
+Because of this, it's generally best to avoid using `marshal:"unaligned"` and
+insert explicit padding fields instead.
+
+## Debugging go_marshal
+
+To enable debugging output from the go marshal tool, pass the `-debug` flag to
+the tool. When using the build rules from above, add a `debug = True` field to
+the build rule like this:
+
+```
+load("<PKGPATH>/gvisor/tools/go_marshal:defs.bzl", "go_library")
+
+go_library(
+ name = "foo",
+ srcs = ["foo.go"],
+ debug = True,
+)
+```
+
+## Modifying the `go_marshal` Tool
+
+The following are some guidelines for modifying the `go_marshal` tool:
+
+- The `go_marshal` tool currently does a single pass over all types requesting
+ code generation, in arbitrary order. This means the generated code can't
+ directly obtain information about embedded marshallable types at
+ compile-time. One way to work around this restriction is to add a new
+ Marshallable interface method providing this piece of information, and
+ calling it from the generated code. Use this sparingly, as we want to rely
+ on compile-time information as much as possible for performance.
+
+- No runtime reflection in the code generated for the marshallable interface.
+ The entire point of the tool is to avoid runtime reflection. The generated
+ tests may use reflection.
diff --git a/tools/go_marshal/analysis/BUILD b/tools/go_marshal/analysis/BUILD
new file mode 100644
index 000000000..c859ced77
--- /dev/null
+++ b/tools/go_marshal/analysis/BUILD
@@ -0,0 +1,13 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "analysis",
+ testonly = 1,
+ srcs = ["analysis_unsafe.go"],
+ importpath = "gvisor.dev/gvisor/tools/go_marshal/analysis",
+ visibility = [
+ "//:sandbox",
+ ],
+)
diff --git a/tools/go_marshal/analysis/analysis_unsafe.go b/tools/go_marshal/analysis/analysis_unsafe.go
new file mode 100644
index 000000000..9a9a4f298
--- /dev/null
+++ b/tools/go_marshal/analysis/analysis_unsafe.go
@@ -0,0 +1,175 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package analysis implements common functionality used by generated
+// go_marshal tests.
+package analysis
+
+// All functions in this package are unsafe and are not intended for general
+// consumption. They contain sharp edge cases and the caller is responsible for
+// ensuring none of them are hit. Callers must be carefully to pass in only sane
+// arguments. Failure to do so may cause panics at best and arbitrary memory
+// corruption at worst.
+//
+// Never use outside of tests.
+
+import (
+ "fmt"
+ "math/rand"
+ "reflect"
+ "testing"
+ "unsafe"
+)
+
+// RandomizeValue assigns random value(s) to an abitrary type. This is intended
+// for used with ABI structs from go_marshal, meaning the typical restrictions
+// apply (fixed-size types, no pointers, maps, channels, etc), and should only
+// be used on zeroed values to avoid overwriting pointers to active go objects.
+//
+// Internally, we populate the type with random data by doing an unsafe cast to
+// access the underlying memory of the type and filling it as if it were a byte
+// slice. This almost gets us what we want, but padding fields named "_" are
+// normally not accessible, so we walk the type and recursively zero all "_"
+// fields.
+//
+// Precondition: x must be a pointer. x must not contain any valid
+// pointers to active go objects (pointer fields aren't allowed in ABI
+// structs anyways), or we'd be violating the go runtime contract and
+// the GC may malfunction.
+func RandomizeValue(x interface{}) {
+ v := reflect.Indirect(reflect.ValueOf(x))
+ if !v.CanSet() {
+ panic("RandomizeType() called with an unaddressable value. You probably need to pass a pointer to the argument")
+ }
+
+ // Cast the underlying memory for the type into a byte slice.
+ var b []byte
+ hdr := (*reflect.SliceHeader)(unsafe.Pointer(&b))
+ // Note: v.UnsafeAddr panics if x is passed by value. x should be a pointer.
+ hdr.Data = v.UnsafeAddr()
+ hdr.Len = int(v.Type().Size())
+ hdr.Cap = hdr.Len
+
+ // Fill the byte slice with random data, which in effect fills the type with
+ // random values.
+ n, err := rand.Read(b)
+ if err != nil || n != len(b) {
+ panic("unreachable")
+ }
+
+ // Normally, padding fields are not accessible, so zero them out.
+ reflectZeroPaddingFields(v.Type(), b, false)
+}
+
+// reflectZeroPaddingFields assigns zero values to padding fields for the value
+// of type r, represented by the memory in data. Padding fields are defined as
+// fields with the name "_". If zero is true, the immediate value itself is
+// zeroed. In addition, the type is recursively scanned for padding fields in
+// inner types.
+//
+// This is used for zeroing padding fields after calling RandomizeValue.
+func reflectZeroPaddingFields(r reflect.Type, data []byte, zero bool) {
+ if zero {
+ for i, _ := range data {
+ data[i] = 0
+ }
+ }
+ switch r.Kind() {
+ case reflect.Int8, reflect.Uint8, reflect.Int16, reflect.Uint16, reflect.Int32, reflect.Uint32, reflect.Int64, reflect.Uint64:
+ // These types are explicitly allowed in an ABI type, but we don't need
+ // to recurse further as they're scalar types.
+ case reflect.Struct:
+ for i, numFields := 0, r.NumField(); i < numFields; i++ {
+ f := r.Field(i)
+ off := f.Offset
+ len := f.Type.Size()
+ window := data[off : off+len]
+ reflectZeroPaddingFields(f.Type, window, f.Name == "_")
+ }
+ case reflect.Array:
+ eLen := int(r.Elem().Size())
+ if int(r.Size()) != eLen*r.Len() {
+ panic("Array has unexpected size?")
+ }
+ for i, n := 0, r.Len(); i < n; i++ {
+ reflectZeroPaddingFields(r.Elem(), data[i*eLen:(i+1)*eLen], false)
+ }
+ default:
+ panic(fmt.Sprintf("Type %v not allowed in ABI struct", r.Kind()))
+
+ }
+}
+
+// AlignmentCheck ensures the definition of the type represented by typ doesn't
+// cause the go compiler to emit implicit padding between elements of the type
+// (i.e. fields in a struct).
+//
+// AlignmentCheck doesn't explicitly recurse for embedded structs because any
+// struct present in an ABI struct must also be Marshallable, and therefore
+// they're aligned by definition (or their alignment check would have failed).
+func AlignmentCheck(t *testing.T, typ reflect.Type) (ok bool, delta uint64) {
+ switch typ.Kind() {
+ case reflect.Int8, reflect.Uint8, reflect.Int16, reflect.Uint16, reflect.Int32, reflect.Uint32, reflect.Int64, reflect.Uint64:
+ // Primitive types are always considered well aligned. Primitive types
+ // that are fields in structs are checked independently, this branch
+ // exists to handle recursive calls to alignmentCheck.
+ case reflect.Struct:
+ xOff := 0
+ nextXOff := 0
+ skipNext := false
+ for i, numFields := 0, typ.NumField(); i < numFields; i++ {
+ xOff = nextXOff
+ f := typ.Field(i)
+ fmt.Printf("Checking alignment of %s.%s @ %d [+%d]...\n", typ.Name(), f.Name, f.Offset, f.Type.Size())
+ nextXOff = int(f.Offset + f.Type.Size())
+
+ if f.Name == "_" {
+ // Padding fields need not be aligned.
+ fmt.Printf("Padding field of type %v\n", f.Type)
+ continue
+ }
+
+ if tag, ok := f.Tag.Lookup("marshal"); ok && tag == "unaligned" {
+ skipNext = true
+ continue
+ }
+
+ if skipNext {
+ skipNext = false
+ fmt.Printf("Skipping alignment check for field %s.%s explicitly marked as unaligned.\n", typ.Name(), f.Name)
+ continue
+ }
+
+ if xOff != int(f.Offset) {
+ implicitPad := int(f.Offset) - xOff
+ t.Fatalf("Suspect offset for field %s.%s, detected an implicit %d byte padding from offset %d to %d; either add %d bytes of explicit padding before this field or tag it as `marshal:\"unaligned\"`.", typ.Name(), f.Name, implicitPad, xOff, f.Offset, implicitPad)
+ }
+ }
+
+ // Ensure structs end on a byte explicitly defined by the type.
+ if typ.NumField() > 0 && nextXOff != int(typ.Size()) {
+ implicitPad := int(typ.Size()) - nextXOff
+ f := typ.Field(typ.NumField() - 1) // Final field
+ t.Fatalf("Suspect offset for field %s.%s at the end of %s, detected an implicit %d byte padding from offset %d to %d at the end of the struct; either add %d bytes of explict padding at end of the struct or tag the final field %s as `marshal:\"unaligned\"`.",
+ typ.Name(), f.Name, typ.Name(), implicitPad, nextXOff, typ.Size(), implicitPad, f.Name)
+ }
+ case reflect.Array:
+ // Independent arrays are also always considered well aligned. We only
+ // need to worry about their alignment when they're embedded in structs,
+ // which we handle above.
+ default:
+ t.Fatalf("Unsupported type in ABI struct while checking for field alignment for type: %v", typ.Kind())
+ }
+ return true, uint64(typ.Size())
+}
diff --git a/tools/go_marshal/defs.bzl b/tools/go_marshal/defs.bzl
new file mode 100644
index 000000000..c32eb559f
--- /dev/null
+++ b/tools/go_marshal/defs.bzl
@@ -0,0 +1,152 @@
+"""Marshal is a tool for generating marshalling interfaces for Go types.
+
+The recommended way is to use the go_library rule defined below with mostly
+identical configuration as the native go_library rule.
+
+load("//tools/go_marshal:defs.bzl", "go_library")
+
+go_library(
+ name = "foo",
+ srcs = ["foo.go"],
+)
+
+Under the hood, the go_marshal rule is used to generate a file that will
+appear in a Go target; the output file should appear explicitly in a srcs list.
+For example (the above is still the preferred way):
+
+load("//tools/go_marshal:defs.bzl", "go_marshal")
+
+go_marshal(
+ name = "foo_abi",
+ srcs = ["foo.go"],
+ out = "foo_abi.go",
+ package = "foo",
+)
+
+go_library(
+ name = "foo",
+ srcs = [
+ "foo.go",
+ "foo_abi.go",
+ ],
+ deps = [
+ "//tools/go_marshal:marshal",
+ "//pkg/sentry/platform/safecopy",
+ "//pkg/sentry/usermem",
+ ],
+)
+"""
+
+load("@io_bazel_rules_go//go:def.bzl", _go_library = "go_library", _go_test = "go_test")
+
+def _go_marshal_impl(ctx):
+ """Execute the go_marshal tool."""
+ output = ctx.outputs.lib
+ output_test = ctx.outputs.test
+ (build_dir, _, _) = ctx.build_file_path.rpartition("/BUILD")
+
+ decl = "/".join(["gvisor.dev/gvisor", build_dir])
+
+ # Run the marshal command.
+ args = ["-output=%s" % output.path]
+ args += ["-pkg=%s" % ctx.attr.package]
+ args += ["-output_test=%s" % output_test.path]
+ args += ["-declarationPkg=%s" % decl]
+
+ if ctx.attr.debug:
+ args += ["-debug"]
+
+ args += ["--"]
+ for src in ctx.attr.srcs:
+ args += [f.path for f in src.files.to_list()]
+ ctx.actions.run(
+ inputs = ctx.files.srcs,
+ outputs = [output, output_test],
+ mnemonic = "GoMarshal",
+ progress_message = "go_marshal: %s" % ctx.label,
+ arguments = args,
+ executable = ctx.executable._tool,
+ )
+
+# Generates save and restore logic from a set of Go files.
+#
+# Args:
+# name: the name of the rule.
+# srcs: the input source files. These files should include all structs in the
+# package that need to be saved.
+# imports: an optional list of extra, non-aliased, Go-style absolute import
+# paths.
+# out: the name of the generated file output. This must not conflict with any
+# other files and must be added to the srcs of the relevant go_library.
+# package: the package name for the input sources.
+go_marshal = rule(
+ implementation = _go_marshal_impl,
+ attrs = {
+ "srcs": attr.label_list(mandatory = True, allow_files = True),
+ "libname": attr.string(mandatory = True),
+ "imports": attr.string_list(mandatory = False),
+ "package": attr.string(mandatory = True),
+ "debug": attr.bool(doc = "enable debugging output from the go_marshal tool"),
+ "_tool": attr.label(executable = True, cfg = "host", default = Label("//tools/go_marshal:go_marshal")),
+ },
+ outputs = {
+ "lib": "%{name}_unsafe.go",
+ "test": "%{name}_test.go",
+ },
+)
+
+def go_library(name, srcs, deps = [], imports = [], debug = False, **kwargs):
+ """wraps the standard go_library and does mashalling interface generation.
+
+ Args:
+ name: Same as native go_library.
+ srcs: Same as native go_library.
+ deps: Same as native go_library.
+ imports: Extra import paths to pass to the go_marshal tool.
+ debug: Enables debugging output from the go_marshal tool.
+ **kwargs: Remaining args to pass to the native go_library rule unmodified.
+ """
+ go_marshal(
+ name = name + "_abi_autogen",
+ libname = name,
+ srcs = [src for src in srcs if src.endswith(".go")],
+ debug = debug,
+ imports = imports,
+ package = name,
+ )
+
+ extra_deps = [
+ "//tools/go_marshal/marshal",
+ "//pkg/sentry/platform/safecopy",
+ "//pkg/sentry/usermem",
+ ]
+
+ all_srcs = srcs + [name + "_abi_autogen_unsafe.go"]
+ all_deps = deps + [] # + extra_deps
+
+ for extra in extra_deps:
+ if extra not in deps:
+ all_deps.append(extra)
+
+ _go_library(
+ name = name,
+ srcs = all_srcs,
+ deps = all_deps,
+ **kwargs
+ )
+
+ # Don't pass importpath arg to go_test.
+ kwargs.pop("importpath", "")
+
+ _go_test(
+ name = name + "_abi_autogen_test",
+ srcs = [name + "_abi_autogen_test.go"],
+ # Generated test has a fixed set of dependencies since we generate these
+ # tests. They should only depend on the library generated above, and the
+ # Marshallable interface.
+ deps = [
+ ":" + name,
+ "//tools/go_marshal/analysis",
+ ],
+ **kwargs
+ )
diff --git a/tools/go_marshal/gomarshal/BUILD b/tools/go_marshal/gomarshal/BUILD
new file mode 100644
index 000000000..a0eae6492
--- /dev/null
+++ b/tools/go_marshal/gomarshal/BUILD
@@ -0,0 +1,17 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "gomarshal",
+ srcs = [
+ "generator.go",
+ "generator_interfaces.go",
+ "generator_tests.go",
+ "util.go",
+ ],
+ importpath = "gvisor.dev/gvisor/tools/go_marshal/gomarshal",
+ visibility = [
+ "//:sandbox",
+ ],
+)
diff --git a/tools/go_marshal/gomarshal/generator.go b/tools/go_marshal/gomarshal/generator.go
new file mode 100644
index 000000000..641ccd938
--- /dev/null
+++ b/tools/go_marshal/gomarshal/generator.go
@@ -0,0 +1,382 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package gomarshal implements the go_marshal code generator. See README.md.
+package gomarshal
+
+import (
+ "bytes"
+ "fmt"
+ "go/ast"
+ "go/parser"
+ "go/token"
+ "os"
+ "sort"
+)
+
+const (
+ marshalImport = "gvisor.dev/gvisor/tools/go_marshal/marshal"
+ usermemImport = "gvisor.dev/gvisor/pkg/sentry/usermem"
+ safecopyImport = "gvisor.dev/gvisor/pkg/sentry/platform/safecopy"
+)
+
+// List of identifiers we use in generated code, that may conflict a
+// similarly-named source identifier. Avoid problems by refusing the generate
+// code when we see these.
+//
+// This only applies to import aliases at the moment. All other identifiers
+// are qualified by a receiver argument, since they're struct fields.
+//
+// All recievers are single letters, so we don't allow import aliases to be a
+// single letter.
+var badIdents = []string{
+ "src", "srcs", "dst", "dsts", "blk", "buf", "err",
+ // All single-letter identifiers.
+}
+
+// Generator drives code generation for a single invocation of the go_marshal
+// utility.
+//
+// The Generator holds arguments passed to the tool, and drives parsing,
+// processing and code Generator for all types marked with +marshal declared in
+// the input files.
+//
+// See Generator.run() as the entry point.
+type Generator struct {
+ // Paths to input go source files.
+ inputs []string
+ // Output file to write generated go source.
+ output *os.File
+ // Output file to write generated tests.
+ outputTest *os.File
+ // Package name for the generated file.
+ pkg string
+ // Go import path for package we're processing. This package should directly
+ // declare the type we're generating code for.
+ declaration string
+ // Set of extra packages to import in the generated file.
+ imports *importTable
+}
+
+// NewGenerator creates a new code Generator.
+func NewGenerator(srcs []string, out, outTest, pkg, declaration string, imports []string) (*Generator, error) {
+ f, err := os.OpenFile(out, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
+ if err != nil {
+ return nil, fmt.Errorf("Couldn't open output file %q: %v", out, err)
+ }
+ fTest, err := os.OpenFile(outTest, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
+ if err != nil {
+ return nil, fmt.Errorf("Couldn't open test output file %q: %v", out, err)
+ }
+ g := Generator{
+ inputs: srcs,
+ output: f,
+ outputTest: fTest,
+ pkg: pkg,
+ declaration: declaration,
+ imports: newImportTable(),
+ }
+ for _, i := range imports {
+ // All imports on the extra imports list are unconditionally marked as
+ // used, so they're always added to the generated code.
+ g.imports.add(i).markUsed()
+ }
+ g.imports.add(marshalImport).markUsed()
+ // The follow imports may or may not be used by the generated
+ // code, depending what's required for the target types. Don't
+ // mark these imports as used by default.
+ g.imports.add(usermemImport)
+ g.imports.add(safecopyImport)
+ g.imports.add("unsafe")
+
+ return &g, nil
+}
+
+// writeHeader writes the header for the generated source file. The header
+// includes the package name, package level comments and import statements.
+func (g *Generator) writeHeader() error {
+ var b sourceBuffer
+ b.emit("// Automatically generated marshal implementation. See tools/go_marshal.\n\n")
+ b.emit("package %s\n\n", g.pkg)
+ if err := b.write(g.output); err != nil {
+ return err
+ }
+
+ return g.imports.write(g.output)
+}
+
+// writeTypeChecks writes a statement to force the compiler to perform a type
+// check for all Marshallable types referenced by the generated code.
+func (g *Generator) writeTypeChecks(ms map[string]struct{}) error {
+ if len(ms) == 0 {
+ return nil
+ }
+
+ msl := make([]string, 0, len(ms))
+ for m, _ := range ms {
+ msl = append(msl, m)
+ }
+ sort.Strings(msl)
+
+ var buf bytes.Buffer
+ fmt.Fprint(&buf, "// Marshallable types used by this file.\n")
+
+ for _, m := range msl {
+ fmt.Fprintf(&buf, "var _ marshal.Marshallable = (*%s)(nil)\n", m)
+ }
+ fmt.Fprint(&buf, "\n")
+
+ _, err := fmt.Fprint(g.output, buf.String())
+ return err
+}
+
+// parse processes all input files passed this generator and produces a set of
+// parsed go ASTs.
+func (g *Generator) parse() ([]*ast.File, []*token.FileSet, error) {
+ debugf("go_marshal invoked with %d input files:\n", len(g.inputs))
+ for _, path := range g.inputs {
+ debugf(" %s\n", path)
+ }
+
+ files := make([]*ast.File, 0, len(g.inputs))
+ fsets := make([]*token.FileSet, 0, len(g.inputs))
+
+ for _, path := range g.inputs {
+ fset := token.NewFileSet()
+ f, err := parser.ParseFile(fset, path, nil, parser.ParseComments)
+ if err != nil {
+ // Not a valid input file?
+ return nil, nil, fmt.Errorf("Input %q can't be parsed: %v", path, err)
+ }
+
+ if debugEnabled() {
+ debugf("AST for %q:\n", path)
+ ast.Print(fset, f)
+ }
+
+ files = append(files, f)
+ fsets = append(fsets, fset)
+ }
+
+ return files, fsets, nil
+}
+
+// collectMarshallabeTypes walks the parsed AST and collects a list of type
+// declarations for which we need to generate the Marshallable interface.
+func (g *Generator) collectMarshallabeTypes(a *ast.File, f *token.FileSet) []*ast.TypeSpec {
+ var types []*ast.TypeSpec
+ for _, decl := range a.Decls {
+ gdecl, ok := decl.(*ast.GenDecl)
+ // Type declaration?
+ if !ok || gdecl.Tok != token.TYPE {
+ debugfAt(f.Position(decl.Pos()), "Skipping declaration since it's not a type declaration.\n")
+ continue
+ }
+ // Does it have a comment?
+ if gdecl.Doc == nil {
+ debugfAt(f.Position(gdecl.Pos()), "Skipping declaration since it doesn't have a comment.\n")
+ continue
+ }
+ // Does the comment contain a "+marshal" line?
+ marked := false
+ for _, c := range gdecl.Doc.List {
+ if c.Text == "// +marshal" {
+ marked = true
+ break
+ }
+ }
+ if !marked {
+ debugfAt(f.Position(gdecl.Pos()), "Skipping declaration since it doesn't have a comment containing +marshal line.\n")
+ continue
+ }
+ for _, spec := range gdecl.Specs {
+ // We already confirmed we're in a type declaration earlier.
+ t := spec.(*ast.TypeSpec)
+ if _, ok := t.Type.(*ast.StructType); ok {
+ debugfAt(f.Position(t.Pos()), "Collected marshallable type %s.\n", t.Name.Name)
+ types = append(types, t)
+ continue
+ }
+ debugf("Skipping declaration %v since it's not a struct declaration.\n", gdecl)
+ }
+ }
+ return types
+}
+
+// collectImports collects all imports from all input source files. Some of
+// these imports are copied to the generated output, if they're referenced by
+// the generated code.
+//
+// collectImports de-duplicates imports while building the list, and ensures
+// identifiers in the generated code don't conflict with any imported package
+// names.
+func (g *Generator) collectImports(a *ast.File, f *token.FileSet) map[string]importStmt {
+ badImportNames := make(map[string]bool)
+ for _, i := range badIdents {
+ badImportNames[i] = true
+ }
+
+ is := make(map[string]importStmt)
+ for _, decl := range a.Decls {
+ gdecl, ok := decl.(*ast.GenDecl)
+ // Import statement?
+ if !ok || gdecl.Tok != token.IMPORT {
+ continue
+ }
+ for _, spec := range gdecl.Specs {
+ i := g.imports.addFromSpec(spec.(*ast.ImportSpec), f)
+ debugf("Collected import '%s' as '%s'\n", i.path, i.name)
+
+ // Make sure we have an import that doesn't use any local names that
+ // would conflict with identifiers in the generated code.
+ if len(i.name) == 1 {
+ abortAt(f.Position(spec.Pos()), fmt.Sprintf("Import has a single character local name '%s'; this may conflict with code generated by go_marshal, use a multi-character import alias", i.name))
+ }
+ if badImportNames[i.name] {
+ abortAt(f.Position(spec.Pos()), fmt.Sprintf("Import name '%s' is likely to conflict with code generated by go_marshal, use a different import alias", i.name))
+ }
+ }
+ }
+ return is
+
+}
+
+func (g *Generator) generateOne(t *ast.TypeSpec, fset *token.FileSet) *interfaceGenerator {
+ // We're guaranteed to have only struct type specs by now. See
+ // Generator.collectMarshallabeTypes.
+ i := newInterfaceGenerator(t, fset)
+ i.validate()
+ i.emitMarshallable()
+ return i
+}
+
+// generateOneTestSuite generates a test suite for the automatically generated
+// implementations type t.
+func (g *Generator) generateOneTestSuite(t *ast.TypeSpec) *testGenerator {
+ i := newTestGenerator(t, g.declaration)
+ i.emitTests()
+ return i
+}
+
+// Run is the entry point to code generation using g.
+//
+// Run parses all input source files specified in g and emits generated code.
+func (g *Generator) Run() error {
+ // Parse our input source files into ASTs and token sets.
+ asts, fsets, err := g.parse()
+ if err != nil {
+ return err
+ }
+
+ if len(asts) != len(fsets) {
+ panic("ASTs and FileSets don't match")
+ }
+
+ // Map of imports in source files; key = local package name, value = import
+ // path.
+ is := make(map[string]importStmt)
+ for i, a := range asts {
+ // Collect all imports from the source files. We may need to copy some
+ // of these to the generated code if they're referenced. This has to be
+ // done before the loop below because we need to process all ASTs before
+ // we start requesting imports to be copied one by one as we encounter
+ // them in each generated source.
+ for name, i := range g.collectImports(a, fsets[i]) {
+ is[name] = i
+ }
+ }
+
+ var impls []*interfaceGenerator
+ var ts []*testGenerator
+ // Set of Marshallable types referenced by generated code.
+ ms := make(map[string]struct{})
+ for i, a := range asts {
+ // Collect type declarations marked for code generation and generate
+ // Marshallable interfaces.
+ for _, t := range g.collectMarshallabeTypes(a, fsets[i]) {
+ impl := g.generateOne(t, fsets[i])
+ // Collect Marshallable types referenced by the generated code.
+ for ref, _ := range impl.ms {
+ ms[ref] = struct{}{}
+ }
+ impls = append(impls, impl)
+ // Collect imports referenced by the generated code and add them to
+ // the list of imports we need to copy to the generated code.
+ for name, _ := range impl.is {
+ if !g.imports.markUsed(name) {
+ panic(fmt.Sprintf("Generated code for '%s' referenced a non-existent import with local name '%s'", impl.typeName(), name))
+ }
+ }
+ ts = append(ts, g.generateOneTestSuite(t))
+ }
+ }
+
+ // Tool was invoked with input files with no data structures marked for code
+ // generation. This is probably not what the user intended.
+ if len(impls) == 0 {
+ var buf bytes.Buffer
+ fmt.Fprintf(&buf, "go_marshal invoked on these files, but they don't contain any types requiring code generation. Perhaps mark some with \"// +marshal\"?:\n")
+ for _, i := range g.inputs {
+ fmt.Fprintf(&buf, " %s\n", i)
+ }
+ abort(buf.String())
+ }
+
+ // Write output file header. These include things like package name and
+ // import statements.
+ if err := g.writeHeader(); err != nil {
+ return err
+ }
+
+ // Write type checks for referenced marshallable types to output file.
+ if err := g.writeTypeChecks(ms); err != nil {
+ return err
+ }
+
+ // Write generated interfaces to output file.
+ for _, i := range impls {
+ if err := i.write(g.output); err != nil {
+ return err
+ }
+ }
+
+ // Write generated tests to test file.
+ return g.writeTests(ts)
+}
+
+// writeTests outputs tests for the generated interface implementations to a go
+// source file.
+func (g *Generator) writeTests(ts []*testGenerator) error {
+ var b sourceBuffer
+ b.emit("package %s_test\n\n", g.pkg)
+ if err := b.write(g.outputTest); err != nil {
+ return err
+ }
+
+ imports := newImportTable()
+ for _, t := range ts {
+ imports.merge(t.imports)
+ }
+
+ if err := imports.write(g.outputTest); err != nil {
+ return err
+ }
+
+ for _, t := range ts {
+ if err := t.write(g.outputTest); err != nil {
+ return err
+ }
+ }
+ return nil
+}
diff --git a/tools/go_marshal/gomarshal/generator_interfaces.go b/tools/go_marshal/gomarshal/generator_interfaces.go
new file mode 100644
index 000000000..a712c14dc
--- /dev/null
+++ b/tools/go_marshal/gomarshal/generator_interfaces.go
@@ -0,0 +1,507 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gomarshal
+
+import (
+ "fmt"
+ "go/ast"
+ "go/token"
+ "strings"
+)
+
+// interfaceGenerator generates marshalling interfaces for a single type.
+//
+// getState is not thread-safe.
+type interfaceGenerator struct {
+ sourceBuffer
+
+ // The type we're serializing.
+ t *ast.TypeSpec
+
+ // Receiver argument for generated methods.
+ r string
+
+ // FileSet containing the tokens for the type we're processing.
+ f *token.FileSet
+
+ // is records external packages referenced by the generated implementation.
+ is map[string]struct{}
+
+ // ms records Marshallable types referenced by the generated implementation
+ // of t's interfaces.
+ ms map[string]struct{}
+
+ // as records embedded fields in t that are potentially not packed. The key
+ // is the accessor for the field.
+ as map[string]struct{}
+}
+
+// typeName returns the name of the type this g represents.
+func (g *interfaceGenerator) typeName() string {
+ return g.t.Name.Name
+}
+
+// newinterfaceGenerator creates a new interface generator.
+func newInterfaceGenerator(t *ast.TypeSpec, fset *token.FileSet) *interfaceGenerator {
+ if _, ok := t.Type.(*ast.StructType); !ok {
+ panic(fmt.Sprintf("Attempting to generate code for a not struct type %v", t))
+ }
+ g := &interfaceGenerator{
+ t: t,
+ r: receiverName(t),
+ f: fset,
+ is: make(map[string]struct{}),
+ ms: make(map[string]struct{}),
+ as: make(map[string]struct{}),
+ }
+ g.recordUsedMarshallable(g.typeName())
+ return g
+}
+
+func (g *interfaceGenerator) recordUsedMarshallable(m string) {
+ g.ms[m] = struct{}{}
+
+}
+
+func (g *interfaceGenerator) recordUsedImport(i string) {
+ g.is[i] = struct{}{}
+
+}
+
+func (g *interfaceGenerator) recordPotentiallyNonPackedField(fieldName string) {
+ g.as[fieldName] = struct{}{}
+}
+
+func (g *interfaceGenerator) forEachField(fn func(f *ast.Field)) {
+ // This is guaranteed to succeed because g.t is always a struct.
+ st := g.t.Type.(*ast.StructType)
+ for _, field := range st.Fields.List {
+ fn(field)
+ }
+}
+
+func (g *interfaceGenerator) fieldAccessor(n *ast.Ident) string {
+ return fmt.Sprintf("%s.%s", g.r, n.Name)
+}
+
+// abortAt aborts the go_marshal tool with the given error message, with a
+// reference position to the input source. Same as abortAt, but uses g to
+// resolve p to position.
+func (g *interfaceGenerator) abortAt(p token.Pos, msg string) {
+ abortAt(g.f.Position(p), msg)
+}
+
+// validate ensures the type we're working with can be marshalled. These checks
+// are done ahead of time and in one place so we can make assumptions later.
+func (g *interfaceGenerator) validate() {
+ g.forEachField(func(f *ast.Field) {
+ if len(f.Names) == 0 {
+ g.abortAt(f.Pos(), "Cannot marshal structs with embedded fields, give the field a name; use '_' for anonymous fields such as padding fields")
+ }
+ })
+
+ g.forEachField(func(f *ast.Field) {
+ fieldDispatcher{
+ primitive: func(_, t *ast.Ident) {
+ switch t.Name {
+ case "int8", "uint8", "byte", "int16", "uint16", "int32", "uint32", "int64", "uint64":
+ // These are the only primitive types we're allow. Below, we
+ // provide suggestions for some disallowed types and reject
+ // them, then attempt to marshal any remaining types by
+ // invoking the marshal.Marshallable interface on them. If
+ // these types don't actually implement
+ // marshal.Marshallable, compilation of the generated code
+ // will fail with an appropriate error message.
+ return
+ case "int":
+ g.abortAt(f.Pos(), "Type 'int' has ambiguous width, use int32 or int64")
+ case "uint":
+ g.abortAt(f.Pos(), "Type 'uint' has ambiguous width, use uint32 or uint64")
+ case "string":
+ g.abortAt(f.Pos(), "Type 'string' is dynamically-sized and cannot be marshalled, use a fixed size byte array '[...]byte' instead")
+ default:
+ debugfAt(g.f.Position(f.Pos()), fmt.Sprintf("Found derived type '%s', will attempt dispatch via marshal.Marshallable.\n", t.Name))
+ }
+ },
+ selector: func(_, _, _ *ast.Ident) {
+ // No validation to perform on selector fields. However this
+ // callback must still be provided.
+ },
+ array: func(n, _ *ast.Ident, len int) {
+ a := f.Type.(*ast.ArrayType)
+ if a.Len == nil {
+ g.abortAt(f.Pos(), fmt.Sprintf("Dynamically sized slice '%s' cannot be marshalled, arrays must be statically sized", n.Name))
+ }
+
+ if _, ok := a.Len.(*ast.BasicLit); !ok {
+ g.abortAt(a.Len.Pos(), fmt.Sprintf("Array size must be a literal, don's use consts or expressions"))
+ }
+
+ if _, ok := a.Elt.(*ast.Ident); !ok {
+ g.abortAt(a.Elt.Pos(), fmt.Sprintf("Marshalling not supported for arrays with %s elements, array elements must be primitive types", kindString(a.Elt)))
+ }
+
+ if len <= 0 {
+ g.abortAt(a.Len.Pos(), fmt.Sprintf("Marshalling not supported for zero length arrays, why does an ABI struct have one?"))
+ }
+ },
+ unhandled: func(_ *ast.Ident) {
+ g.abortAt(f.Pos(), fmt.Sprintf("Marshalling not supported for %s fields", kindString(f.Type)))
+ },
+ }.dispatch(f)
+ })
+}
+
+// scalarSize returns the size of type identified by t. If t isn't a primitive
+// type, the size isn't known at code generation time, and must be resolved via
+// the marshal.Marshallable interface.
+func (g *interfaceGenerator) scalarSize(t *ast.Ident) (size int, unknownSize bool) {
+ switch t.Name {
+ case "int8", "uint8", "byte":
+ return 1, false
+ case "int16", "uint16":
+ return 2, false
+ case "int32", "uint32":
+ return 4, false
+ case "int64", "uint64":
+ return 8, false
+ default:
+ return 0, true
+ }
+}
+
+func (g *interfaceGenerator) shift(bufVar string, n int) {
+ g.emit("%s = %s[%d:]\n", bufVar, bufVar, n)
+}
+
+func (g *interfaceGenerator) shiftDynamic(bufVar, name string) {
+ g.emit("%s = %s[%s.SizeBytes():]\n", bufVar, bufVar, name)
+}
+
+func (g *interfaceGenerator) marshalScalar(accessor, typ string, bufVar string) {
+ switch typ {
+ case "int8", "uint8", "byte":
+ g.emit("%s[0] = byte(%s)\n", bufVar, accessor)
+ g.shift(bufVar, 1)
+ case "int16", "uint16":
+ g.recordUsedImport("usermem")
+ g.emit("usermem.ByteOrder.PutUint16(%s[:2], uint16(%s))\n", bufVar, accessor)
+ g.shift(bufVar, 2)
+ case "int32", "uint32":
+ g.recordUsedImport("usermem")
+ g.emit("usermem.ByteOrder.PutUint32(%s[:4], uint32(%s))\n", bufVar, accessor)
+ g.shift(bufVar, 4)
+ case "int64", "uint64":
+ g.recordUsedImport("usermem")
+ g.emit("usermem.ByteOrder.PutUint64(%s[:8], uint64(%s))\n", bufVar, accessor)
+ g.shift(bufVar, 8)
+ default:
+ g.emit("%s.MarshalBytes(%s[:%s.SizeBytes()])\n", accessor, bufVar, accessor)
+ g.shiftDynamic(bufVar, accessor)
+ }
+}
+
+func (g *interfaceGenerator) unmarshalScalar(accessor, typ string, bufVar string) {
+ switch typ {
+ case "int8":
+ g.emit("%s = int8(%s[0])\n", accessor, bufVar)
+ g.shift(bufVar, 1)
+ case "uint8":
+ g.emit("%s = uint8(%s[0])\n", accessor, bufVar)
+ g.shift(bufVar, 1)
+ case "byte":
+ g.emit("%s = %s[0]\n", accessor, bufVar)
+ g.shift(bufVar, 1)
+
+ case "int16":
+ g.recordUsedImport("usermem")
+ g.emit("%s = int16(usermem.ByteOrder.Uint16(%s[:2]))\n", accessor, bufVar)
+ g.shift(bufVar, 2)
+ case "uint16":
+ g.recordUsedImport("usermem")
+ g.emit("%s = usermem.ByteOrder.Uint16(%s[:2])\n", accessor, bufVar)
+ g.shift(bufVar, 2)
+
+ case "int32":
+ g.recordUsedImport("usermem")
+ g.emit("%s = int32(usermem.ByteOrder.Uint32(%s[:4]))\n", accessor, bufVar)
+ g.shift(bufVar, 4)
+ case "uint32":
+ g.recordUsedImport("usermem")
+ g.emit("%s = usermem.ByteOrder.Uint32(%s[:4])\n", accessor, bufVar)
+ g.shift(bufVar, 4)
+
+ case "int64":
+ g.recordUsedImport("usermem")
+ g.emit("%s = int64(usermem.ByteOrder.Uint64(%s[:8]))\n", accessor, bufVar)
+ g.shift(bufVar, 8)
+ case "uint64":
+ g.recordUsedImport("usermem")
+ g.emit("%s = usermem.ByteOrder.Uint64(%s[:8])\n", accessor, bufVar)
+ g.shift(bufVar, 8)
+ default:
+ g.emit("%s.UnmarshalBytes(%s[:%s.SizeBytes()])\n", accessor, bufVar, accessor)
+ g.shiftDynamic(bufVar, accessor)
+ g.recordPotentiallyNonPackedField(accessor)
+ }
+}
+
+// areFieldsPackedExpression returns a go expression checking whether g.t's fields are
+// packed. Returns "", false if g.t has no fields that may be potentially
+// packed, otherwise returns <clause>, true, where <clause> is an expression
+// like "t.a.Packed() && t.b.Packed() && t.c.Packed()".
+func (g *interfaceGenerator) areFieldsPackedExpression() (string, bool) {
+ if len(g.as) == 0 {
+ return "", false
+ }
+
+ cs := make([]string, 0, len(g.as))
+ for accessor, _ := range g.as {
+ cs = append(cs, fmt.Sprintf("%s.Packed()", accessor))
+ }
+ return strings.Join(cs, " && "), true
+}
+
+func (g *interfaceGenerator) emitMarshallable() {
+ // Is g.t a packed struct without consideing field types?
+ thisPacked := true
+ g.forEachField(func(f *ast.Field) {
+ if f.Tag != nil {
+ if f.Tag.Value == "`marshal:\"unaligned\"`" {
+ if thisPacked {
+ debugfAt(g.f.Position(g.t.Pos()),
+ fmt.Sprintf("Marking type '%s' as not packed due to tag `marshal:\"unaligned\"`.\n", g.t.Name))
+ thisPacked = false
+ }
+ }
+ }
+ })
+
+ g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n")
+ g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ primitiveSize := 0
+ var dynamicSizeTerms []string
+
+ g.forEachField(fieldDispatcher{
+ primitive: func(n, t *ast.Ident) {
+ if size, dynamic := g.scalarSize(t); !dynamic {
+ primitiveSize += size
+ } else {
+ g.recordUsedMarshallable(t.Name)
+ dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("%s.SizeBytes()", g.fieldAccessor(n)))
+ }
+ },
+ selector: func(n, tX, tSel *ast.Ident) {
+ tName := fmt.Sprintf("%s.%s", tX.Name, tSel.Name)
+ g.recordUsedImport(tX.Name)
+ g.recordUsedMarshallable(tName)
+ dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", tName))
+ },
+ array: func(n, t *ast.Ident, len int) {
+ if len < 1 {
+ // Zero-length arrays should've been rejected by validate().
+ panic("unreachable")
+ }
+ if size, dynamic := g.scalarSize(t); !dynamic {
+ primitiveSize += size * len
+ } else {
+ g.recordUsedMarshallable(t.Name)
+ dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()*%d", t.Name, len))
+ }
+ },
+ }.dispatch)
+ g.emit("return %d", primitiveSize)
+ if len(dynamicSizeTerms) > 0 {
+ g.incIndent()
+ }
+ {
+ for _, d := range dynamicSizeTerms {
+ g.emitNoIndent(" +\n")
+ g.emit(d)
+ }
+ }
+ if len(dynamicSizeTerms) > 0 {
+ g.decIndent()
+ }
+ })
+ g.emit("\n}\n\n")
+
+ g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n")
+ g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.forEachField(fieldDispatcher{
+ primitive: func(n, t *ast.Ident) {
+ if n.Name == "_" {
+ g.emit("// Padding: dst[:sizeof(%s)] ~= %s(0)\n", t.Name, t.Name)
+ if len, dynamic := g.scalarSize(t); !dynamic {
+ g.shift("dst", len)
+ } else {
+ // We can't use shiftDynamic here because we don't have
+ // an instance of the dynamic type we can referece here
+ // (since the version in this struct is anonymous). Use
+ // a typed nil pointer to call SizeBytes() instead.
+ g.emit("dst = dst[(*%s)(nil).SizeBytes():]\n", t.Name)
+ }
+ return
+ }
+ g.marshalScalar(g.fieldAccessor(n), t.Name, "dst")
+ },
+ selector: func(n, tX, tSel *ast.Ident) {
+ g.marshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "dst")
+ },
+ array: func(n, t *ast.Ident, size int) {
+ if n.Name == "_" {
+ g.emit("// Padding: dst[:sizeof(%s)*%d] ~= [%d]%s{0}\n", t.Name, size, size, t.Name)
+ if len, dynamic := g.scalarSize(t); !dynamic {
+ g.shift("dst", len*size)
+ } else {
+ // We can't use shiftDynamic here because we don't have
+ // an instance of the dynamic type we can reference here
+ // (since the version in this struct is anonymous). Use
+ // a typed nil pointer to call SizeBytes() instead.
+ g.emit("dst = dst[(*%s)(nil).SizeBytes()*%d:]\n", t.Name, size)
+ }
+ return
+ }
+
+ g.emit("for i := 0; i < %d; i++ {\n", size)
+ g.inIndent(func() {
+ g.marshalScalar(fmt.Sprintf("%s[i]", g.fieldAccessor(n)), t.Name, "dst")
+ })
+ g.emit("}\n")
+ },
+ }.dispatch)
+ })
+ g.emit("}\n\n")
+
+ g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n")
+ g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.forEachField(fieldDispatcher{
+ primitive: func(n, t *ast.Ident) {
+ if n.Name == "_" {
+ g.emit("// Padding: var _ %s ~= src[:sizeof(%s)]\n", t.Name, t.Name)
+ if len, dynamic := g.scalarSize(t); !dynamic {
+ g.shift("src", len)
+ } else {
+ // We can't use shiftDynamic here because we don't have
+ // an instance of the dynamic type we can reference here
+ // (since the version in this struct is anonymous). Use
+ // a typed nil pointer to call SizeBytes() instead.
+ g.emit("src = src[(*%s)(nil).SizeBytes():]\n", t.Name)
+ g.recordPotentiallyNonPackedField(fmt.Sprintf("(*%s)(nil)", t.Name))
+ }
+ return
+ }
+ g.unmarshalScalar(g.fieldAccessor(n), t.Name, "src")
+ },
+ selector: func(n, tX, tSel *ast.Ident) {
+ g.unmarshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "src")
+ },
+ array: func(n, t *ast.Ident, size int) {
+ if n.Name == "_" {
+ g.emit("// Padding: ~ copy([%d]%s(%s), src[:sizeof(%s)*%d])\n", size, t.Name, g.fieldAccessor(n), t.Name, size)
+ if len, dynamic := g.scalarSize(t); !dynamic {
+ g.shift("src", len*size)
+ } else {
+ // We can't use shiftDynamic here because we don't have
+ // an instance of the dynamic type we can referece here
+ // (since the version in this struct is anonymous). Use
+ // a typed nil pointer to call SizeBytes() instead.
+ g.emit("src = src[(*%s)(nil).SizeBytes()*%d:]\n", t.Name, size)
+ }
+ return
+ }
+
+ g.emit("for i := 0; i < %d; i++ {\n", size)
+ g.inIndent(func() {
+ g.unmarshalScalar(fmt.Sprintf("%s[i]", g.fieldAccessor(n)), t.Name, "src")
+ })
+ g.emit("}\n")
+ },
+ }.dispatch)
+ })
+ g.emit("}\n\n")
+
+ g.emit("// Packed implements marshal.Marshallable.Packed.\n")
+ g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ expr, fieldsMaybePacked := g.areFieldsPackedExpression()
+ switch {
+ case !thisPacked:
+ g.emit("return false\n")
+ case fieldsMaybePacked:
+ g.emit("return %s\n", expr)
+ default:
+ g.emit("return true\n")
+
+ }
+ })
+ g.emit("}\n\n")
+
+ g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n")
+ g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ if thisPacked {
+ g.recordUsedImport("safecopy")
+ g.recordUsedImport("unsafe")
+ if cond, ok := g.areFieldsPackedExpression(); ok {
+ g.emit("if %s {\n", cond)
+ g.inIndent(func() {
+ g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r)
+ })
+ g.emit("} else {\n")
+ g.inIndent(func() {
+ g.emit("%s.MarshalBytes(dst)\n", g.r)
+ })
+ g.emit("}\n")
+ } else {
+ g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r)
+ }
+ } else {
+ g.emit("// Type %s doesn't have a packed layout in memory, fallback to MarshalBytes.\n", g.typeName())
+ g.emit("%s.MarshalBytes(dst)\n", g.r)
+ }
+ })
+ g.emit("}\n\n")
+
+ g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n")
+ g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ if thisPacked {
+ g.recordUsedImport("safecopy")
+ g.recordUsedImport("unsafe")
+ if cond, ok := g.areFieldsPackedExpression(); ok {
+ g.emit("if %s {\n", cond)
+ g.inIndent(func() {
+ g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r)
+ })
+ g.emit("} else {\n")
+ g.inIndent(func() {
+ g.emit("%s.UnmarshalBytes(src)\n", g.r)
+ })
+ g.emit("}\n")
+ } else {
+ g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r)
+ }
+ } else {
+ g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName())
+ g.emit("%s.UnmarshalBytes(src)\n", g.r)
+ }
+ })
+ g.emit("}\n\n")
+
+}
diff --git a/tools/go_marshal/gomarshal/generator_tests.go b/tools/go_marshal/gomarshal/generator_tests.go
new file mode 100644
index 000000000..df25cb5b2
--- /dev/null
+++ b/tools/go_marshal/gomarshal/generator_tests.go
@@ -0,0 +1,154 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gomarshal
+
+import (
+ "fmt"
+ "go/ast"
+ "io"
+ "strings"
+)
+
+var standardImports = []string{
+ "fmt",
+ "reflect",
+ "testing",
+ "gvisor.dev/gvisor/tools/go_marshal/analysis",
+}
+
+type testGenerator struct {
+ sourceBuffer
+
+ // The type we're serializing.
+ t *ast.TypeSpec
+
+ // Receiver argument for generated methods.
+ r string
+
+ // Imports used by generated code.
+ imports *importTable
+
+ // Import statement for the package declaring the type we generated code
+ // for. We need this to construct test instances for the type, since the
+ // tests aren't written in the same package.
+ decl *importStmt
+}
+
+func newTestGenerator(t *ast.TypeSpec, declaration string) *testGenerator {
+ if _, ok := t.Type.(*ast.StructType); !ok {
+ panic(fmt.Sprintf("Attempting to generate code for a not struct type %v", t))
+ }
+ g := &testGenerator{
+ t: t,
+ r: receiverName(t),
+ imports: newImportTable(),
+ }
+
+ for _, i := range standardImports {
+ g.imports.add(i).markUsed()
+ }
+ g.decl = g.imports.add(declaration)
+ g.decl.markUsed()
+
+ return g
+}
+
+func (g *testGenerator) typeName() string {
+ return fmt.Sprintf("%s.%s", g.decl.name, g.t.Name.Name)
+}
+
+func (g *testGenerator) forEachField(fn func(f *ast.Field)) {
+ // This is guaranteed to succeed because g.t is always a struct.
+ st := g.t.Type.(*ast.StructType)
+ for _, field := range st.Fields.List {
+ fn(field)
+ }
+}
+
+func (g *testGenerator) testFuncName(base string) string {
+ return fmt.Sprintf("%s%s", base, strings.Title(g.t.Name.Name))
+}
+
+func (g *testGenerator) inTestFunction(name string, body func()) {
+ g.emit("func %s(t *testing.T) {\n", g.testFuncName(name))
+ g.inIndent(body)
+ g.emit("}\n\n")
+}
+
+func (g *testGenerator) emitTestNonZeroSize() {
+ g.inTestFunction("TestSizeNonZero", func() {
+ g.emit("x := &%s{}\n", g.typeName())
+ g.emit("if x.SizeBytes() == 0 {\n")
+ g.inIndent(func() {
+ g.emit("t.Fatal(\"Marshallable.Size() should not return zero\")\n")
+ })
+ g.emit("}\n")
+ })
+}
+
+func (g *testGenerator) emitTestSuspectAlignment() {
+ g.inTestFunction("TestSuspectAlignment", func() {
+ g.emit("x := %s{}\n", g.typeName())
+ g.emit("analysis.AlignmentCheck(t, reflect.TypeOf(x))\n")
+ })
+}
+
+func (g *testGenerator) emitTestMarshalUnmarshalPreservesData() {
+ g.inTestFunction("TestSafeMarshalUnmarshalPreservesData", func() {
+ g.emit("var x, y, z, yUnsafe, zUnsafe %s\n", g.typeName())
+ g.emit("analysis.RandomizeValue(&x)\n\n")
+
+ g.emit("buf := make([]byte, x.SizeBytes())\n")
+ g.emit("x.MarshalBytes(buf)\n")
+ g.emit("bufUnsafe := make([]byte, x.SizeBytes())\n")
+ g.emit("x.MarshalUnsafe(bufUnsafe)\n\n")
+
+ g.emit("y.UnmarshalBytes(buf)\n")
+ g.emit("if !reflect.DeepEqual(x, y) {\n")
+ g.inIndent(func() {
+ g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across Marshal/Unmarshal cycle:\\nBefore: %%+v\\nAfter: %%+v\\n\", x, y))\n")
+ })
+ g.emit("}\n")
+ g.emit("yUnsafe.UnmarshalBytes(bufUnsafe)\n")
+ g.emit("if !reflect.DeepEqual(x, yUnsafe) {\n")
+ g.inIndent(func() {
+ g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalUnsafe/Unmarshal cycle:\\nBefore: %%+v\\nAfter: %%+v\\n\", x, yUnsafe))\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("z.UnmarshalUnsafe(buf)\n")
+ g.emit("if !reflect.DeepEqual(x, z) {\n")
+ g.inIndent(func() {
+ g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across Marshal/UnmarshalUnsafe cycle:\\nBefore: %%+v\\nAfter: %%+v\\n\", x, z))\n")
+ })
+ g.emit("}\n")
+ g.emit("zUnsafe.UnmarshalUnsafe(bufUnsafe)\n")
+ g.emit("if !reflect.DeepEqual(x, zUnsafe) {\n")
+ g.inIndent(func() {
+ g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalUnsafe/UnmarshalUnsafe cycle:\\nBefore: %%+v\\nAfter: %%+v\\n\", x, zUnsafe))\n")
+ })
+ g.emit("}\n")
+ })
+}
+
+func (g *testGenerator) emitTests() {
+ g.emitTestNonZeroSize()
+ g.emitTestSuspectAlignment()
+ g.emitTestMarshalUnmarshalPreservesData()
+}
+
+func (g *testGenerator) write(out io.Writer) error {
+ return g.sourceBuffer.write(out)
+}
diff --git a/tools/go_marshal/gomarshal/util.go b/tools/go_marshal/gomarshal/util.go
new file mode 100644
index 000000000..967537abf
--- /dev/null
+++ b/tools/go_marshal/gomarshal/util.go
@@ -0,0 +1,387 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gomarshal
+
+import (
+ "bytes"
+ "flag"
+ "fmt"
+ "go/ast"
+ "go/token"
+ "io"
+ "os"
+ "path"
+ "reflect"
+ "sort"
+ "strconv"
+ "strings"
+)
+
+var debug = flag.Bool("debug", false, "enables debugging output")
+
+// receiverName returns an appropriate receiver name given a type spec.
+func receiverName(t *ast.TypeSpec) string {
+ if len(t.Name.Name) < 1 {
+ // Zero length type name?
+ panic("unreachable")
+ }
+ return strings.ToLower(t.Name.Name[:1])
+}
+
+// kindString returns a user-friendly representation of an AST expr type.
+func kindString(e ast.Expr) string {
+ switch e.(type) {
+ case *ast.Ident:
+ return "scalar"
+ case *ast.ArrayType:
+ return "array"
+ case *ast.StructType:
+ return "struct"
+ case *ast.StarExpr:
+ return "pointer"
+ case *ast.FuncType:
+ return "function"
+ case *ast.InterfaceType:
+ return "interface"
+ case *ast.MapType:
+ return "map"
+ case *ast.ChanType:
+ return "channel"
+ default:
+ return reflect.TypeOf(e).String()
+ }
+}
+
+// fieldDispatcher is a collection of callbacks for handling different types of
+// fields in a struct declaration.
+type fieldDispatcher struct {
+ primitive func(n, t *ast.Ident)
+ selector func(n, tX, tSel *ast.Ident)
+ array func(n, t *ast.Ident, size int)
+ unhandled func(n *ast.Ident)
+}
+
+// Precondition: All dispatch callbacks that will be invoked must be
+// provided. Embedded fields are not allowed, len(f.Names) >= 1.
+func (fd fieldDispatcher) dispatch(f *ast.Field) {
+ // Each field declaration may actually be multiple declarations of the same
+ // type. For example, consider:
+ //
+ // type Point struct {
+ // x, y, z int
+ // }
+ //
+ // We invoke the call-backs once per such instance. Embedded fields are not
+ // allowed, and results in a panic.
+ if len(f.Names) < 1 {
+ panic("Precondition not met: attempted to dispatch on embedded field")
+ }
+
+ for _, name := range f.Names {
+ switch v := f.Type.(type) {
+ case *ast.Ident:
+ fd.primitive(name, v)
+ case *ast.SelectorExpr:
+ fd.selector(name, v.X.(*ast.Ident), v.Sel)
+ case *ast.ArrayType:
+ len := 0
+ if v.Len != nil {
+ // Non-literal array length is handled by generatorInterfaces.validate().
+ if lenLit, ok := v.Len.(*ast.BasicLit); ok {
+ var err error
+ len, err = strconv.Atoi(lenLit.Value)
+ if err != nil {
+ panic(err)
+ }
+ }
+ }
+ switch t := v.Elt.(type) {
+ case *ast.Ident:
+ fd.array(name, t, len)
+ default:
+ fd.array(name, nil, len)
+ }
+ default:
+ fd.unhandled(name)
+ }
+ }
+}
+
+// debugEnabled indicates whether debugging is enabled for gomarshal.
+func debugEnabled() bool {
+ return *debug
+}
+
+// abort aborts the go_marshal tool with the given error message.
+func abort(msg string) {
+ if !strings.HasSuffix(msg, "\n") {
+ msg += "\n"
+ }
+ fmt.Print(msg)
+ os.Exit(1)
+}
+
+// abortAt aborts the go_marshal tool with the given error message, with
+// a reference position to the input source.
+func abortAt(p token.Position, msg string) {
+ abort(fmt.Sprintf("%v:\n %s\n", p, msg))
+}
+
+// debugf conditionally prints a debug message.
+func debugf(f string, a ...interface{}) {
+ if debugEnabled() {
+ fmt.Printf(f, a...)
+ }
+}
+
+// debugfAt conditionally prints a debug message with a reference to a position
+// in the input source.
+func debugfAt(p token.Position, f string, a ...interface{}) {
+ if debugEnabled() {
+ fmt.Printf("%s:\n %s", p, fmt.Sprintf(f, a...))
+ }
+}
+
+// emit generates a line of code in the output file.
+//
+// emit is a wrapper around writing a formatted string to the output
+// buffer. emit can be invoked in one of two ways:
+//
+// (1) emit("some string")
+// When emit is called with a single string argument, it is simply copied to
+// the output buffer without any further formatting.
+// (2) emit(fmtString, args...)
+// emit can also be invoked in a similar fashion to *Printf() functions,
+// where the first argument is a format string.
+//
+// Calling emit with a single argument that is not a string will result in a
+// panic, as the caller's intent is ambiguous.
+func emit(out io.Writer, indent int, a ...interface{}) {
+ const spacesPerIndentLevel = 4
+
+ if len(a) < 1 {
+ panic("emit() called with no arguments")
+ }
+
+ if indent > 0 {
+ if _, err := fmt.Fprint(out, strings.Repeat(" ", indent*spacesPerIndentLevel)); err != nil {
+ // Writing to the emit output should not fail. Typically the output
+ // is a byte.Buffer; writes to these never fail.
+ panic(err)
+ }
+ }
+
+ first, ok := a[0].(string)
+ if !ok {
+ // First argument must be either the string to emit (case 1 from
+ // function-level comment), or a format string (case 2).
+ panic(fmt.Sprintf("First argument to emit() is not a string: %+v", a[0]))
+ }
+
+ if len(a) == 1 {
+ // Single string argument. Assume no formatting requested.
+ if _, err := fmt.Fprint(out, first); err != nil {
+ // Writing to out should not fail.
+ panic(err)
+ }
+ return
+
+ }
+
+ // Formatting requested.
+ if _, err := fmt.Fprintf(out, first, a[1:]...); err != nil {
+ // Writing to out should not fail.
+ panic(err)
+ }
+}
+
+// sourceBuffer represents fragments of generated go source code.
+//
+// sourceBuffer provides a convenient way to build up go souce fragments in
+// memory. May be safely zero-value initialized. Not thread-safe.
+type sourceBuffer struct {
+ // Current indentation level.
+ indent int
+
+ // Memory buffer containing contents while they're being generated.
+ b bytes.Buffer
+}
+
+func (b *sourceBuffer) incIndent() {
+ b.indent++
+}
+
+func (b *sourceBuffer) decIndent() {
+ if b.indent <= 0 {
+ panic("decIndent() without matching incIndent()")
+ }
+ b.indent--
+}
+
+func (b *sourceBuffer) emit(a ...interface{}) {
+ emit(&b.b, b.indent, a...)
+}
+
+func (b *sourceBuffer) emitNoIndent(a ...interface{}) {
+ emit(&b.b, 0 /*indent*/, a...)
+}
+
+func (b *sourceBuffer) inIndent(body func()) {
+ b.incIndent()
+ body()
+ b.decIndent()
+}
+
+func (b *sourceBuffer) write(out io.Writer) error {
+ _, err := fmt.Fprint(out, b.b.String())
+ return err
+}
+
+// Write implements io.Writer.Write.
+func (b *sourceBuffer) Write(buf []byte) (int, error) {
+ return (b.b.Write(buf))
+}
+
+// importStmt represents a single import statement.
+type importStmt struct {
+ // Local name of the imported package.
+ name string
+ // Import path.
+ path string
+ // Indicates whether the local name is an alias, or simply the final
+ // component of the path.
+ aliased bool
+ // Indicates whether this import was referenced by generated code.
+ used bool
+}
+
+func newImport(p string) *importStmt {
+ name := path.Base(p)
+ return &importStmt{
+ name: name,
+ path: p,
+ aliased: false,
+ }
+}
+
+func newImportFromSpec(spec *ast.ImportSpec, f *token.FileSet) *importStmt {
+ p := spec.Path.Value[1 : len(spec.Path.Value)-1] // Strip the " quotes around path.
+ name := path.Base(p)
+ if name == "" || name == "/" || name == "." {
+ panic(fmt.Sprintf("Couldn't process local package name for import at %s, (processed as %s)",
+ f.Position(spec.Path.Pos()), name))
+ }
+ if spec.Name != nil {
+ name = spec.Name.Name
+ }
+ return &importStmt{
+ name: name,
+ path: p,
+ aliased: spec.Name != nil,
+ }
+}
+
+func (i *importStmt) String() string {
+ if i.aliased {
+ return fmt.Sprintf("%s \"%s\"", i.name, i.path)
+ }
+ return fmt.Sprintf("\"%s\"", i.path)
+}
+
+func (i *importStmt) markUsed() {
+ i.used = true
+}
+
+func (i *importStmt) equivalent(other *importStmt) bool {
+ return i == other
+}
+
+// importTable represents a collection of importStmts.
+type importTable struct {
+ // Map of imports and whether they should be copied to the output.
+ is map[string]*importStmt
+}
+
+func newImportTable() *importTable {
+ return &importTable{
+ is: make(map[string]*importStmt),
+ }
+}
+
+// Merges import statements from other into i. Collisions in import statements
+// result in a panic.
+func (i *importTable) merge(other *importTable) {
+ for name, im := range other.is {
+ if dup, ok := i.is[name]; ok && dup.equivalent(im) {
+ panic(fmt.Sprintf("Found colliding import statements: ours: %+v, other's: %+v", dup, im))
+ }
+
+ i.is[name] = im
+ }
+}
+
+func (i *importTable) add(s string) *importStmt {
+ n := newImport(s)
+ i.is[n.name] = n
+ return n
+}
+
+func (i *importTable) addFromSpec(spec *ast.ImportSpec, f *token.FileSet) *importStmt {
+ n := newImportFromSpec(spec, f)
+ i.is[n.name] = n
+ return n
+}
+
+// Marks the import named n as used. If no such import is in the table, returns
+// false.
+func (i *importTable) markUsed(n string) bool {
+ if n, ok := i.is[n]; ok {
+ n.markUsed()
+ return true
+ }
+ return false
+}
+
+func (i *importTable) clear() {
+ for _, i := range i.is {
+ i.used = false
+ }
+}
+
+func (i *importTable) write(out io.Writer) error {
+ if len(i.is) == 0 {
+ // Nothing to import, we're done.
+ return nil
+ }
+
+ imports := make([]string, 0, len(i.is))
+ for _, i := range i.is {
+ if i.used {
+ imports = append(imports, i.String())
+ }
+ }
+ sort.Strings(imports)
+
+ var b sourceBuffer
+ b.emit("import (\n")
+ b.incIndent()
+ for _, i := range imports {
+ b.emit("%s\n", i)
+ }
+ b.decIndent()
+ b.emit(")\n\n")
+
+ return b.write(out)
+}
diff --git a/tools/go_marshal/main.go b/tools/go_marshal/main.go
new file mode 100644
index 000000000..3d12eb93c
--- /dev/null
+++ b/tools/go_marshal/main.go
@@ -0,0 +1,73 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// go_marshal is a code generation utility for automatically generating code to
+// marshal go data structures to memory.
+//
+// This binary is typically run as part of the build process, and is invoked by
+// the go_marshal bazel rule defined in defs.bzl.
+//
+// See README.md.
+package main
+
+import (
+ "flag"
+ "fmt"
+ "os"
+ "strings"
+
+ "gvisor.dev/gvisor/tools/go_marshal/gomarshal"
+)
+
+var (
+ pkg = flag.String("pkg", "", "output package")
+ output = flag.String("output", "", "output file")
+ outputTest = flag.String("output_test", "", "output file for tests")
+ imports = flag.String("imports", "", "comma-separated list of extra packages to import in generated code")
+ declarationPkg = flag.String("declarationPkg", "", "import path of target declaring the types we're generating on")
+)
+
+func main() {
+ flag.Usage = func() {
+ fmt.Fprintf(os.Stderr, "Usage: %s <input go src files>\n", os.Args[0])
+ flag.PrintDefaults()
+ }
+ flag.Parse()
+ if len(flag.Args()) == 0 {
+ flag.Usage()
+ os.Exit(1)
+ }
+
+ if *pkg == "" {
+ flag.Usage()
+ fmt.Fprint(os.Stderr, "Flag -pkg must be provided.\n")
+ os.Exit(1)
+ }
+
+ var extraImports []string
+ if len(*imports) > 0 {
+ // Note: strings.Split(s, sep) returns s if sep doesn't exist in s. Thus
+ // we check for an empty imports list to avoid emitting an empty string
+ // as an import.
+ extraImports = strings.Split(*imports, ",")
+ }
+ g, err := gomarshal.NewGenerator(flag.Args(), *output, *outputTest, *pkg, *declarationPkg, extraImports)
+ if err != nil {
+ panic(err)
+ }
+
+ if err := g.Run(); err != nil {
+ panic(err)
+ }
+}
diff --git a/tools/go_marshal/marshal/BUILD b/tools/go_marshal/marshal/BUILD
new file mode 100644
index 000000000..47dda97a1
--- /dev/null
+++ b/tools/go_marshal/marshal/BUILD
@@ -0,0 +1,14 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "marshal",
+ srcs = [
+ "marshal.go",
+ ],
+ importpath = "gvisor.dev/gvisor/tools/go_marshal/marshal",
+ visibility = [
+ "//:sandbox",
+ ],
+)
diff --git a/tools/go_marshal/marshal/marshal.go b/tools/go_marshal/marshal/marshal.go
new file mode 100644
index 000000000..a313a27ed
--- /dev/null
+++ b/tools/go_marshal/marshal/marshal.go
@@ -0,0 +1,60 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package marshal defines the Marshallable interface for
+// serialize/deserializing go data structures to/from memory, according to the
+// Linux ABI.
+//
+// Implementations of this interface are typically automatically generated by
+// tools/go_marshal. See the go_marshal README for details.
+package marshal
+
+// Marshallable represents a type that can be marshalled to and from memory.
+type Marshallable interface {
+ // SizeBytes is the size of the memory representation of a type in
+ // marshalled form.
+ SizeBytes() int
+
+ // MarshalBytes serializes a copy of a type to dst. dst must be at least
+ // SizeBytes() long.
+ MarshalBytes(dst []byte)
+
+ // UnmarshalBytes deserializes a type from src. src must be at least
+ // SizeBytes() long.
+ UnmarshalBytes(src []byte)
+
+ // Packed returns true if the marshalled size of the type is the same as the
+ // size it occupies in memory. This happens when the type has no fields
+ // starting at unaligned addresses (should always be true by default for ABI
+ // structs, verified by automatically generated tests when using
+ // go_marshal), and has no fields marked `marshal:"unaligned"`.
+ Packed() bool
+
+ // MarshalUnsafe serializes a type by bulk copying its in-memory
+ // representation to the dst buffer. This is only safe to do when the type
+ // has no implicit padding, see Marshallable.Packed. When Packed would
+ // return false, MarshalUnsafe should fall back to the safer but slower
+ // MarshalBytes.
+ MarshalUnsafe(dst []byte)
+
+ // UnmarshalUnsafe deserializes a type directly to the underlying memory
+ // allocated for the object by the runtime.
+ //
+ // This allows much faster unmarshalling of types which have no implicit
+ // padding, see Marshallable.Packed. When Packed would return false,
+ // UnmarshalUnsafe should fall back to the safer but slower unmarshal
+ // mechanism implemented in UnmarshalBytes (usually by calling
+ // UnmarshalBytes directly).
+ UnmarshalUnsafe(src []byte)
+}
diff --git a/tools/go_marshal/test/BUILD b/tools/go_marshal/test/BUILD
new file mode 100644
index 000000000..fa82f8e9b
--- /dev/null
+++ b/tools/go_marshal/test/BUILD
@@ -0,0 +1,31 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
+package(licenses = ["notice"])
+
+load("//tools/go_marshal:defs.bzl", "go_library")
+
+package_group(
+ name = "gomarshal_test",
+ packages = [
+ "//tools/go_marshal/test/...",
+ ],
+)
+
+go_test(
+ name = "benchmark_test",
+ srcs = ["benchmark_test.go"],
+ deps = [
+ ":test",
+ "//pkg/binary",
+ "//pkg/sentry/usermem",
+ "//tools/go_marshal/analysis",
+ ],
+)
+
+go_library(
+ name = "test",
+ testonly = 1,
+ srcs = ["test.go"],
+ importpath = "gvisor.dev/gvisor/tools/go_marshal/test",
+ deps = ["//tools/go_marshal/test/external"],
+)
diff --git a/tools/go_marshal/test/benchmark_test.go b/tools/go_marshal/test/benchmark_test.go
new file mode 100644
index 000000000..e70db06d8
--- /dev/null
+++ b/tools/go_marshal/test/benchmark_test.go
@@ -0,0 +1,178 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package benchmark_test
+
+import (
+ "bytes"
+ encbin "encoding/binary"
+ "fmt"
+ "reflect"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/tools/go_marshal/analysis"
+ test "gvisor.dev/gvisor/tools/go_marshal/test"
+)
+
+// Marshalling using the standard encoding/binary package.
+func BenchmarkEncodingBinary(b *testing.B) {
+ var s1, s2 test.Stat
+ analysis.RandomizeValue(&s1)
+
+ size := encbin.Size(&s1)
+
+ b.ResetTimer()
+
+ for n := 0; n < b.N; n++ {
+ buf := bytes.NewBuffer(make([]byte, size))
+ buf.Reset()
+ if err := encbin.Write(buf, usermem.ByteOrder, &s1); err != nil {
+ b.Error("Write:", err)
+ }
+ if err := encbin.Read(buf, usermem.ByteOrder, &s2); err != nil {
+ b.Error("Read:", err)
+ }
+ }
+
+ b.StopTimer()
+
+ // Sanity check, make sure the values were preserved.
+ if !reflect.DeepEqual(s1, s2) {
+ panic(fmt.Sprintf("Data corruption across marshal/unmarshal cycle:\nBefore: %+v\nAfter: %+v\n", s1, s2))
+ }
+}
+
+// Marshalling using the sentry's binary.Marshal.
+func BenchmarkBinary(b *testing.B) {
+ var s1, s2 test.Stat
+ analysis.RandomizeValue(&s1)
+
+ size := binary.Size(s1)
+
+ b.ResetTimer()
+
+ for n := 0; n < b.N; n++ {
+ buf := make([]byte, 0, size)
+ buf = binary.Marshal(buf, usermem.ByteOrder, &s1)
+ binary.Unmarshal(buf, usermem.ByteOrder, &s2)
+ }
+
+ b.StopTimer()
+
+ // Sanity check, make sure the values were preserved.
+ if !reflect.DeepEqual(s1, s2) {
+ panic(fmt.Sprintf("Data corruption across marshal/unmarshal cycle:\nBefore: %+v\nAfter: %+v\n", s1, s2))
+ }
+}
+
+// Marshalling field-by-field with manually-written code.
+func BenchmarkMarshalManual(b *testing.B) {
+ var s1, s2 test.Stat
+ analysis.RandomizeValue(&s1)
+
+ b.ResetTimer()
+
+ for n := 0; n < b.N; n++ {
+ buf := make([]byte, 0, s1.SizeBytes())
+
+ // Marshal
+ buf = binary.AppendUint64(buf, usermem.ByteOrder, s1.Dev)
+ buf = binary.AppendUint64(buf, usermem.ByteOrder, s1.Ino)
+ buf = binary.AppendUint64(buf, usermem.ByteOrder, s1.Nlink)
+ buf = binary.AppendUint32(buf, usermem.ByteOrder, s1.Mode)
+ buf = binary.AppendUint32(buf, usermem.ByteOrder, s1.UID)
+ buf = binary.AppendUint32(buf, usermem.ByteOrder, s1.GID)
+ buf = binary.AppendUint32(buf, usermem.ByteOrder, 0)
+ buf = binary.AppendUint64(buf, usermem.ByteOrder, s1.Rdev)
+ buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.Size))
+ buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.Blksize))
+ buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.Blocks))
+ buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.ATime.Sec))
+ buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.ATime.Nsec))
+ buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.MTime.Sec))
+ buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.MTime.Nsec))
+ buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.CTime.Sec))
+ buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.CTime.Nsec))
+
+ // Unmarshal
+ s2.Dev = usermem.ByteOrder.Uint64(buf[0:8])
+ s2.Ino = usermem.ByteOrder.Uint64(buf[8:16])
+ s2.Nlink = usermem.ByteOrder.Uint64(buf[16:24])
+ s2.Mode = usermem.ByteOrder.Uint32(buf[24:28])
+ s2.UID = usermem.ByteOrder.Uint32(buf[28:32])
+ s2.GID = usermem.ByteOrder.Uint32(buf[32:36])
+ // Padding: buf[36:40]
+ s2.Rdev = usermem.ByteOrder.Uint64(buf[40:48])
+ s2.Size = int64(usermem.ByteOrder.Uint64(buf[48:56]))
+ s2.Blksize = int64(usermem.ByteOrder.Uint64(buf[56:64]))
+ s2.Blocks = int64(usermem.ByteOrder.Uint64(buf[64:72]))
+ s2.ATime.Sec = int64(usermem.ByteOrder.Uint64(buf[72:80]))
+ s2.ATime.Nsec = int64(usermem.ByteOrder.Uint64(buf[80:88]))
+ s2.MTime.Sec = int64(usermem.ByteOrder.Uint64(buf[88:96]))
+ s2.MTime.Nsec = int64(usermem.ByteOrder.Uint64(buf[96:104]))
+ s2.CTime.Sec = int64(usermem.ByteOrder.Uint64(buf[104:112]))
+ s2.CTime.Nsec = int64(usermem.ByteOrder.Uint64(buf[112:120]))
+ }
+
+ b.StopTimer()
+
+ // Sanity check, make sure the values were preserved.
+ if !reflect.DeepEqual(s1, s2) {
+ panic(fmt.Sprintf("Data corruption across marshal/unmarshal cycle:\nBefore: %+v\nAfter: %+v\n", s1, s2))
+ }
+}
+
+// Marshalling with the go_marshal safe API.
+func BenchmarkGoMarshalSafe(b *testing.B) {
+ var s1, s2 test.Stat
+ analysis.RandomizeValue(&s1)
+
+ b.ResetTimer()
+
+ for n := 0; n < b.N; n++ {
+ buf := make([]byte, s1.SizeBytes())
+ s1.MarshalBytes(buf)
+ s2.UnmarshalBytes(buf)
+ }
+
+ b.StopTimer()
+
+ // Sanity check, make sure the values were preserved.
+ if !reflect.DeepEqual(s1, s2) {
+ panic(fmt.Sprintf("Data corruption across marshal/unmarshal cycle:\nBefore: %+v\nAfter: %+v\n", s1, s2))
+ }
+}
+
+// Marshalling with the go_marshal unsafe API.
+func BenchmarkGoMarshalUnsafe(b *testing.B) {
+ var s1, s2 test.Stat
+ analysis.RandomizeValue(&s1)
+
+ b.ResetTimer()
+
+ for n := 0; n < b.N; n++ {
+ buf := make([]byte, s1.SizeBytes())
+ s1.MarshalUnsafe(buf)
+ s2.UnmarshalUnsafe(buf)
+ }
+
+ b.StopTimer()
+
+ // Sanity check, make sure the values were preserved.
+ if !reflect.DeepEqual(s1, s2) {
+ panic(fmt.Sprintf("Data corruption across marshal/unmarshal cycle:\nBefore: %+v\nAfter: %+v\n", s1, s2))
+ }
+}
diff --git a/tools/go_marshal/test/external/BUILD b/tools/go_marshal/test/external/BUILD
new file mode 100644
index 000000000..8fb43179b
--- /dev/null
+++ b/tools/go_marshal/test/external/BUILD
@@ -0,0 +1,11 @@
+package(licenses = ["notice"])
+
+load("//tools/go_marshal:defs.bzl", "go_library")
+
+go_library(
+ name = "external",
+ testonly = 1,
+ srcs = ["external.go"],
+ importpath = "gvisor.dev/gvisor/tools/go_marshal/test/external",
+ visibility = ["//tools/go_marshal/test:gomarshal_test"],
+)
diff --git a/tools/go_marshal/test/external/external.go b/tools/go_marshal/test/external/external.go
new file mode 100644
index 000000000..4be3722f3
--- /dev/null
+++ b/tools/go_marshal/test/external/external.go
@@ -0,0 +1,23 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package external defines types we can import for testing.
+package external
+
+// External is a public Marshallable type for use in testing.
+//
+// +marshal
+type External struct {
+ j int64
+}
diff --git a/tools/go_marshal/test/test.go b/tools/go_marshal/test/test.go
new file mode 100644
index 000000000..8de02d707
--- /dev/null
+++ b/tools/go_marshal/test/test.go
@@ -0,0 +1,105 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package test contains data structures for testing the go_marshal tool.
+package test
+
+import (
+ // We're intentionally using a package name alias here even though it's not
+ // necessary to test the code generator's ability to handle package aliases.
+ ex "gvisor.dev/gvisor/tools/go_marshal/test/external"
+)
+
+// Type1 is a test data type.
+//
+// +marshal
+type Type1 struct {
+ a Type2
+ x, y int64 // Multiple field names.
+ b byte `marshal:"unaligned"` // Short field.
+ c uint64
+ _ uint32 // Unnamed scalar field.
+ _ [6]byte // Unnamed vector field, typical padding.
+ _ [2]byte
+ xs [8]int32
+ as [10]Type2 `marshal:"unaligned"` // Array of Marshallable objects.
+ ss Type3
+}
+
+// Type2 is a test data type.
+//
+// +marshal
+type Type2 struct {
+ n int64
+ c byte
+ _ [7]byte
+ m int64
+ a int64
+}
+
+// Type3 is a test data type.
+//
+// +marshal
+type Type3 struct {
+ s int64
+ x ex.External // Type defined in another package.
+}
+
+// Type4 is a test data type.
+//
+// +marshal
+type Type4 struct {
+ c byte
+ x int64 `marshal:"unaligned"`
+ d byte
+ _ [7]byte
+}
+
+// Type5 is a test data type.
+//
+// +marshal
+type Type5 struct {
+ n int64
+ t Type4
+ m int64
+}
+
+// Timespec represents struct timespec in <time.h>.
+//
+// +marshal
+type Timespec struct {
+ Sec int64
+ Nsec int64
+}
+
+// Stat represents struct stat.
+//
+// +marshal
+type Stat struct {
+ Dev uint64
+ Ino uint64
+ Nlink uint64
+ Mode uint32
+ UID uint32
+ GID uint32
+ _ int32
+ Rdev uint64
+ Size int64
+ Blksize int64
+ Blocks int64
+ ATime Timespec
+ MTime Timespec
+ CTime Timespec
+ _ [3]int64
+}