diff options
Diffstat (limited to 'tools/go_marshal')
26 files changed, 4538 insertions, 0 deletions
diff --git a/tools/go_marshal/BUILD b/tools/go_marshal/BUILD new file mode 100644 index 000000000..be49cf9c8 --- /dev/null +++ b/tools/go_marshal/BUILD @@ -0,0 +1,19 @@ +load("//tools:defs.bzl", "go_binary") + +licenses(["notice"]) + +go_binary( + name = "go_marshal", + srcs = ["main.go"], + visibility = [ + "//:sandbox", + ], + deps = [ + "//tools/go_marshal/gomarshal", + ], +) + +config_setting( + name = "marshal_config_verbose", + values = {"define": "gomarshal=verbose"}, +) diff --git a/tools/go_marshal/README.md b/tools/go_marshal/README.md new file mode 100644 index 000000000..4886efddf --- /dev/null +++ b/tools/go_marshal/README.md @@ -0,0 +1,116 @@ +This package implements the go_marshal utility. + +# Overview + +`go_marshal` is a code generation utility similar to `go_stateify` for +automatically generating code to marshal go data structures to memory. + +`go_marshal` attempts to improve on `binary.Write` and the sentry's +`binary.Marshal` by moving the go runtime reflection necessary to marshal a +struct to compile-time. + +`go_marshal` automatically generates implementations for `abi.Marshallable` and +`safemem.{Reader,Writer}`. Call-sites for serialization (typically syscall +implementations) can directly invoke `safemem.Reader.ReadToBlocks` and +`safemem.Writer.WriteFromBlocks`. Data structures that require custom +serialization will have manual implementations for these interfaces. + +Data structures can be flagged for code generation by adding a struct-level +comment `// +marshal`. + +# Usage + +See `defs.bzl`: a new rule is provided, `go_marshal`. + +Under the hood, the `go_marshal` rule is used to generate a file that will +appear in a Go target; the output file should appear explicitly in a srcs list. 
+For example (note that the above is the preferred method): + +``` +load("<PKGPATH>/gvisor/tools/go_marshal:defs.bzl", "go_marshal") + +go_marshal( + name = "foo_abi", + srcs = ["foo.go"], + out = "foo_abi.go", + package = "foo", +) + +go_library( + name = "foo", + srcs = [ + "foo.go", + "foo_abi.go", + ], + ... +) +``` + +As part of the interface generation, `go_marshal` also generates some tests for +sanity checking the struct definitions for potential alignment issues, and a +simple round-trip test through Marshal/Unmarshal to verify the implementation. +These tests use reflection to verify properties of the ABI struct, and should be +considered part of the generated interfaces (but are too expensive to execute at +runtime). Ensure these tests run at some point. + +# Restrictions + +Not all valid go type definitions can be used with `go_marshal`. `go_marshal` is +intended for ABI structs, which have these additional restrictions: + +- At the moment, `go_marshal` only supports struct declarations. + +- Structs are marshalled as packed types. This means no implicit padding is + inserted between fields shorter than the platform register size. For + alignment, manually insert padding fields. + +- Structs used with `go_marshal` must have a compile-time static size. This + means no dynamically sized fields like slices or strings. Use statically + sized arrays (byte arrays for strings) instead. + +- No pointer, channel, map or function pointer fields, and no fields that are + arrays of these types. These don't make sense in an ABI data structure. + +- We could support opaque pointers as `uintptr`, but this is currently not + implemented. Implementing this would require handling the architecture + dependent native pointer size. + +- Fields must either be a primitive integer type (`byte`, + `[u]int{8,16,32,64}`), or of a type that implements abi.Marshallable. + +- `int` and `uint` fields are not allowed. Use an explicitly-sized numeric + type. 
+ +- `float*` fields are currently not supported, but could be if necessary. + +# Appendix + +## Working with Non-Packed Structs + +ABI structs must generally be packed types, meaning they should have no implicit +padding between short fields. However, if a field is tagged +`marshal:"unaligned"`, `go_marshal` will fall back to a safer but slower +mechanism to deal with potentially unaligned fields. + +Note that the non-packed property is inherited by any other struct that embeds +this struct, since the `go_marshal` tool currently can't reason about alignments +for embedded structs that are not aligned. + +Because of this, it's generally best to avoid using `marshal:"unaligned"` and +insert explicit padding fields instead. + +## Modifying the `go_marshal` Tool + +The following are some guidelines for modifying the `go_marshal` tool: + +- The `go_marshal` tool currently does a single pass over all types requesting + code generation, in arbitrary order. This means the generated code can't + directly obtain information about embedded marshallable types at + compile-time. One way to work around this restriction is to add a new + Marshallable interface method providing this piece of information, and + call it from the generated code. Use this sparingly, as we want to rely + on compile-time information as much as possible for performance. + +- No runtime reflection in the code generated for the marshallable interface. + The entire point of the tool is to avoid runtime reflection. The generated + tests may use reflection. 
diff --git a/tools/go_marshal/analysis/BUILD b/tools/go_marshal/analysis/BUILD new file mode 100644 index 000000000..c2a4d45c4 --- /dev/null +++ b/tools/go_marshal/analysis/BUILD @@ -0,0 +1,12 @@ +load("//tools:defs.bzl", "go_library") + +licenses(["notice"]) + +go_library( + name = "analysis", + testonly = 1, + srcs = ["analysis_unsafe.go"], + visibility = [ + "//:sandbox", + ], +) diff --git a/tools/go_marshal/analysis/analysis_unsafe.go b/tools/go_marshal/analysis/analysis_unsafe.go new file mode 100644 index 000000000..cd55cf5cb --- /dev/null +++ b/tools/go_marshal/analysis/analysis_unsafe.go @@ -0,0 +1,179 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package analysis implements common functionality used by generated +// go_marshal tests. +package analysis + +// All functions in this package are unsafe and are not intended for general +// consumption. They contain sharp edge cases and the caller is responsible for +// ensuring none of them are hit. Callers must be carefully to pass in only sane +// arguments. Failure to do so may cause panics at best and arbitrary memory +// corruption at worst. +// +// Never use outside of tests. + +import ( + "fmt" + "math/rand" + "reflect" + "testing" + "unsafe" +) + +// RandomizeValue assigns random value(s) to an abitrary type. 
This is intended +// for used with ABI structs from go_marshal, meaning the typical restrictions +// apply (fixed-size types, no pointers, maps, channels, etc), and should only +// be used on zeroed values to avoid overwriting pointers to active go objects. +// +// Internally, we populate the type with random data by doing an unsafe cast to +// access the underlying memory of the type and filling it as if it were a byte +// slice. This almost gets us what we want, but padding fields named "_" are +// normally not accessible, so we walk the type and recursively zero all "_" +// fields. +// +// Precondition: x must be a pointer. x must not contain any valid +// pointers to active go objects (pointer fields aren't allowed in ABI +// structs anyways), or we'd be violating the go runtime contract and +// the GC may malfunction. +func RandomizeValue(x interface{}) { + v := reflect.Indirect(reflect.ValueOf(x)) + if !v.CanSet() { + panic("RandomizeType() called with an unaddressable value. You probably need to pass a pointer to the argument") + } + + // Cast the underlying memory for the type into a byte slice. + var b []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&b)) + // Note: v.UnsafeAddr panics if x is passed by value. x should be a pointer. + hdr.Data = v.UnsafeAddr() + hdr.Len = int(v.Type().Size()) + hdr.Cap = hdr.Len + + // Fill the byte slice with random data, which in effect fills the type with + // random values. + n, err := rand.Read(b) + if err != nil || n != len(b) { + panic("unreachable") + } + + // Normally, padding fields are not accessible, so zero them out. + reflectZeroPaddingFields(v.Type(), b, false) +} + +// reflectZeroPaddingFields assigns zero values to padding fields for the value +// of type r, represented by the memory in data. Padding fields are defined as +// fields with the name "_". If zero is true, the immediate value itself is +// zeroed. In addition, the type is recursively scanned for padding fields in +// inner types. 
+// +// This is used for zeroing padding fields after calling RandomizeValue. +func reflectZeroPaddingFields(r reflect.Type, data []byte, zero bool) { + if zero { + for i, _ := range data { + data[i] = 0 + } + } + switch r.Kind() { + case reflect.Int8, reflect.Uint8, reflect.Int16, reflect.Uint16, reflect.Int32, reflect.Uint32, reflect.Int64, reflect.Uint64: + // These types are explicitly allowed in an ABI type, but we don't need + // to recurse further as they're scalar types. + case reflect.Struct: + for i, numFields := 0, r.NumField(); i < numFields; i++ { + f := r.Field(i) + off := f.Offset + len := f.Type.Size() + window := data[off : off+len] + reflectZeroPaddingFields(f.Type, window, f.Name == "_") + } + case reflect.Array: + eLen := int(r.Elem().Size()) + if int(r.Size()) != eLen*r.Len() { + panic("Array has unexpected size?") + } + for i, n := 0, r.Len(); i < n; i++ { + reflectZeroPaddingFields(r.Elem(), data[i*eLen:(i+1)*eLen], false) + } + default: + panic(fmt.Sprintf("Type %v not allowed in ABI struct", r.Kind())) + + } +} + +// AlignmentCheck ensures the definition of the type represented by typ doesn't +// cause the go compiler to emit implicit padding between elements of the type +// (i.e. fields in a struct). +// +// AlignmentCheck doesn't explicitly recurse for embedded structs because any +// struct present in an ABI struct must also be Marshallable, and therefore +// they're aligned by definition (or their alignment check would have failed). +func AlignmentCheck(t *testing.T, typ reflect.Type) (ok bool, delta uint64) { + switch typ.Kind() { + case reflect.Int8, reflect.Uint8, reflect.Int16, reflect.Uint16, reflect.Int32, reflect.Uint32, reflect.Int64, reflect.Uint64: + // Primitive types are always considered well aligned. Primitive types + // that are fields in structs are checked independently, this branch + // exists to handle recursive calls to alignmentCheck. 
+ case reflect.Struct: + xOff := 0 + nextXOff := 0 + skipNext := false + for i, numFields := 0, typ.NumField(); i < numFields; i++ { + xOff = nextXOff + f := typ.Field(i) + fmt.Printf("Checking alignment of %s.%s @ %d [+%d]...\n", typ.Name(), f.Name, f.Offset, f.Type.Size()) + nextXOff = int(f.Offset + f.Type.Size()) + + if f.Name == "_" { + // Padding fields need not be aligned. + fmt.Printf("Padding field of type %v\n", f.Type) + continue + } + + if tag, ok := f.Tag.Lookup("marshal"); ok && tag == "unaligned" { + skipNext = true + continue + } + + if skipNext { + skipNext = false + fmt.Printf("Skipping alignment check for field %s.%s explicitly marked as unaligned.\n", typ.Name(), f.Name) + continue + } + + if xOff != int(f.Offset) { + implicitPad := int(f.Offset) - xOff + t.Fatalf("Suspect offset for field %s.%s, detected an implicit %d byte padding from offset %d to %d; either add %d bytes of explicit padding before this field or tag it as `marshal:\"unaligned\"`.", typ.Name(), f.Name, implicitPad, xOff, f.Offset, implicitPad) + } + } + + // Ensure structs end on a byte explicitly defined by the type. + if typ.NumField() > 0 && nextXOff != int(typ.Size()) { + implicitPad := int(typ.Size()) - nextXOff + f := typ.Field(typ.NumField() - 1) // Final field + if tag, ok := f.Tag.Lookup("marshal"); ok && tag == "unaligned" { + // Final field explicitly marked unaligned. + break + } + t.Fatalf("Suspect offset for field %s.%s at the end of %s, detected an implicit %d byte padding from offset %d to %d at the end of the struct; either add %d bytes of explict padding at end of the struct or tag the final field %s as `marshal:\"unaligned\"`.", + typ.Name(), f.Name, typ.Name(), implicitPad, nextXOff, typ.Size(), implicitPad, f.Name) + } + case reflect.Array: + // Independent arrays are also always considered well aligned. We only + // need to worry about their alignment when they're embedded in structs, + // which we handle above. 
+ default: + t.Fatalf("Unsupported type in ABI struct while checking for field alignment for type: %v", typ.Kind()) + } + return true, uint64(typ.Size()) +} diff --git a/tools/go_marshal/defs.bzl b/tools/go_marshal/defs.bzl new file mode 100644 index 000000000..323e33882 --- /dev/null +++ b/tools/go_marshal/defs.bzl @@ -0,0 +1,65 @@ +"""Marshal is a tool for generating marshalling interfaces for Go types.""" + +def _go_marshal_impl(ctx): + """Execute the go_marshal tool.""" + output = ctx.outputs.lib + output_test = ctx.outputs.test + + # Run the marshal command. + args = ["-output=%s" % output.path] + args += ["-pkg=%s" % ctx.attr.package] + args += ["-output_test=%s" % output_test.path] + + if ctx.attr.debug: + args += ["-debug"] + + args += ["--"] + for src in ctx.attr.srcs: + args += [f.path for f in src.files.to_list()] + ctx.actions.run( + inputs = ctx.files.srcs, + outputs = [output, output_test], + mnemonic = "GoMarshal", + progress_message = "go_marshal: %s" % ctx.label, + arguments = args, + executable = ctx.executable._tool, + ) + +# Generates save and restore logic from a set of Go files. +# +# Args: +# name: the name of the rule. +# srcs: the input source files. These files should include all structs in the +# package that need to be saved. +# imports: an optional list of extra, non-aliased, Go-style absolute import +# paths. +# out: the name of the generated file output. This must not conflict with any +# other files and must be added to the srcs of the relevant go_library. +# package: the package name for the input sources. 
+go_marshal = rule( + implementation = _go_marshal_impl, + attrs = { + "srcs": attr.label_list(mandatory = True, allow_files = True), + "imports": attr.string_list(mandatory = False), + "package": attr.string(mandatory = True), + "debug": attr.bool(doc = "enable debugging output from the go_marshal tool"), + "_tool": attr.label(executable = True, cfg = "host", default = Label("//tools/go_marshal:go_marshal")), + }, + outputs = { + "lib": "%{name}_unsafe.go", + "test": "%{name}_test.go", + }, +) + +# marshal_deps are the dependencies requied by generated code. +marshal_deps = [ + "//pkg/gohacks", + "//pkg/safecopy", + "//pkg/usermem", + "//tools/go_marshal/marshal", +] + +# marshal_test_deps are required by test targets. +marshal_test_deps = [ + "//tools/go_marshal/analysis", +] diff --git a/tools/go_marshal/gomarshal/BUILD b/tools/go_marshal/gomarshal/BUILD new file mode 100644 index 000000000..44cb33ae4 --- /dev/null +++ b/tools/go_marshal/gomarshal/BUILD @@ -0,0 +1,21 @@ +load("//tools:defs.bzl", "go_library") + +licenses(["notice"]) + +go_library( + name = "gomarshal", + srcs = [ + "generator.go", + "generator_interfaces.go", + "generator_interfaces_array_newtype.go", + "generator_interfaces_primitive_newtype.go", + "generator_interfaces_struct.go", + "generator_tests.go", + "util.go", + ], + stateify = False, + visibility = [ + "//:sandbox", + ], + deps = ["//tools/tags"], +) diff --git a/tools/go_marshal/gomarshal/generator.go b/tools/go_marshal/gomarshal/generator.go new file mode 100644 index 000000000..177013dbb --- /dev/null +++ b/tools/go_marshal/gomarshal/generator.go @@ -0,0 +1,499 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package gomarshal implements the go_marshal code generator. See README.md. +package gomarshal + +import ( + "bytes" + "fmt" + "go/ast" + "go/parser" + "go/token" + "os" + "sort" + "strings" + + "gvisor.dev/gvisor/tools/tags" +) + +// List of identifiers we use in generated code that may conflict with a +// similarly-named source identifier. Abort gracefully when we see these to +// avoid potentially confusing compilation failures in generated code. +// +// This only applies to import aliases at the moment. All other identifiers +// are qualified by a receiver argument, since they're struct fields. +// +// All recievers are single letters, so we don't allow import aliases to be a +// single letter. +var badIdents = []string{ + "addr", "blk", "buf", "dst", "dsts", "count", "err", "hdr", "idx", "inner", + "length", "limit", "ptr", "size", "src", "srcs", "task", "val", + // All single-letter identifiers. +} + +// Constructed fromt badIdents in init(). +var badIdentsMap map[string]struct{} + +func init() { + badIdentsMap = make(map[string]struct{}) + for _, ident := range badIdents { + badIdentsMap[ident] = struct{}{} + } +} + +// Generator drives code generation for a single invocation of the go_marshal +// utility. +// +// The Generator holds arguments passed to the tool, and drives parsing, +// processing and code Generator for all types marked with +marshal declared in +// the input files. +// +// See Generator.run() as the entry point. +type Generator struct { + // Paths to input go source files. 
+ inputs []string + // Output file to write generated go source. + output *os.File + // Output file to write generated tests. + outputTest *os.File + // Package name for the generated file. + pkg string + // Set of extra packages to import in the generated file. + imports *importTable +} + +// NewGenerator creates a new code Generator. +func NewGenerator(srcs []string, out, outTest, pkg string, imports []string) (*Generator, error) { + f, err := os.OpenFile(out, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) + if err != nil { + return nil, fmt.Errorf("Couldn't open output file %q: %v", out, err) + } + fTest, err := os.OpenFile(outTest, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) + if err != nil { + return nil, fmt.Errorf("Couldn't open test output file %q: %v", out, err) + } + g := Generator{ + inputs: srcs, + output: f, + outputTest: fTest, + pkg: pkg, + imports: newImportTable(), + } + for _, i := range imports { + // All imports on the extra imports list are unconditionally marked as + // used, so that they're always added to the generated code. + g.imports.add(i).markUsed() + } + + // The following imports may or may not be used by the generated code, + // depending on what's required for the target types. Don't mark these as + // used by default. + g.imports.add("io") + g.imports.add("reflect") + g.imports.add("runtime") + g.imports.add("unsafe") + g.imports.add("gvisor.dev/gvisor/pkg/gohacks") + g.imports.add("gvisor.dev/gvisor/pkg/safecopy") + g.imports.add("gvisor.dev/gvisor/pkg/usermem") + g.imports.add("gvisor.dev/gvisor/tools/go_marshal/marshal") + + return &g, nil +} + +// writeHeader writes the header for the generated source file. The header +// includes the package name, package level comments and import statements. +func (g *Generator) writeHeader() error { + var b sourceBuffer + b.emit("// Automatically generated marshal implementation. See tools/go_marshal.\n\n") + + // Emit build tags. 
+ if t := tags.Aggregate(g.inputs); len(t) > 0 { + b.emit(strings.Join(t.Lines(), "\n")) + b.emit("\n\n") + } + + // Package header. + b.emit("package %s\n\n", g.pkg) + if err := b.write(g.output); err != nil { + return err + } + + return g.imports.write(g.output) +} + +// writeTypeChecks writes a statement to force the compiler to perform a type +// check for all Marshallable types referenced by the generated code. +func (g *Generator) writeTypeChecks(ms map[string]struct{}) error { + if len(ms) == 0 { + return nil + } + + msl := make([]string, 0, len(ms)) + for m, _ := range ms { + msl = append(msl, m) + } + sort.Strings(msl) + + var buf bytes.Buffer + fmt.Fprint(&buf, "// Marshallable types used by this file.\n") + + for _, m := range msl { + fmt.Fprintf(&buf, "var _ marshal.Marshallable = (*%s)(nil)\n", m) + } + fmt.Fprint(&buf, "\n") + + _, err := fmt.Fprint(g.output, buf.String()) + return err +} + +// parse processes all input files passed this generator and produces a set of +// parsed go ASTs. +func (g *Generator) parse() ([]*ast.File, []*token.FileSet, error) { + debugf("go_marshal invoked with %d input files:\n", len(g.inputs)) + for _, path := range g.inputs { + debugf(" %s\n", path) + } + + files := make([]*ast.File, 0, len(g.inputs)) + fsets := make([]*token.FileSet, 0, len(g.inputs)) + + for _, path := range g.inputs { + fset := token.NewFileSet() + f, err := parser.ParseFile(fset, path, nil, parser.ParseComments) + if err != nil { + // Not a valid input file? + return nil, nil, fmt.Errorf("Input %q can't be parsed: %v", path, err) + } + + if debugEnabled() { + debugf("AST for %q:\n", path) + ast.Print(fset, f) + } + + files = append(files, f) + fsets = append(fsets, fset) + } + + return files, fsets, nil +} + +// sliceAPI carries information about the '+marshal slice' directive. +type sliceAPI struct { + // Comment node in the AST containing the +marshal tag. 
+ comment *ast.Comment + // Identifier fragment to use when naming generated functions for the slice + // API. + ident string + // Whether the generated functions should reference the newtype name, or the + // inner type name. Only meaningful on newtype declarations on primitives. + inner bool +} + +// marshallableType carries information about a type marked with the '+marshal' +// directive. +type marshallableType struct { + spec *ast.TypeSpec + slice *sliceAPI +} + +func newMarshallableType(fset *token.FileSet, tagLine *ast.Comment, spec *ast.TypeSpec) marshallableType { + mt := marshallableType{ + spec: spec, + slice: nil, + } + + var unhandledTags []string + + for _, tag := range strings.Fields(strings.TrimPrefix(tagLine.Text, "// +marshal")) { + if strings.HasPrefix(tag, "slice:") { + tokens := strings.Split(tag, ":") + if len(tokens) < 2 || len(tokens) > 3 { + abortAt(fset.Position(tagLine.Slash), fmt.Sprintf("+marshal directive has invalid 'slice' clause. Expecting format 'slice:<IDENTIFIER>[:inner]', got '%v'", tag)) + } + if len(tokens[1]) == 0 { + abortAt(fset.Position(tagLine.Slash), "+marshal slice directive has empty identifier argument. Expecting '+marshal slice:identifier'") + } + + sa := &sliceAPI{ + comment: tagLine, + ident: tokens[1], + } + mt.slice = sa + + if len(tokens) == 3 { + if tokens[2] != "inner" { + abortAt(fset.Position(tagLine.Slash), "+marshal slice directive has an invalid argument. Expecting '+marshal slice:<IDENTIFIER>[:inner]'") + } + sa.inner = true + } + + continue + } + + unhandledTags = append(unhandledTags, tag) + } + + if len(unhandledTags) > 0 { + abortAt(fset.Position(tagLine.Slash), fmt.Sprintf("+marshal directive contained the following unknown clauses: %v", strings.Join(unhandledTags, " "))) + } + + return mt +} + +// collectMarshallableTypes walks the parsed AST and collects a list of type +// declarations for which we need to generate the Marshallable interface. 
+func (g *Generator) collectMarshallableTypes(a *ast.File, f *token.FileSet) []marshallableType { + var types []marshallableType + for _, decl := range a.Decls { + gdecl, ok := decl.(*ast.GenDecl) + // Type declaration? + if !ok || gdecl.Tok != token.TYPE { + debugfAt(f.Position(decl.Pos()), "Skipping declaration since it's not a type declaration.\n") + continue + } + // Does it have a comment? + if gdecl.Doc == nil { + debugfAt(f.Position(gdecl.Pos()), "Skipping declaration since it doesn't have a comment.\n") + continue + } + // Does the comment contain a "+marshal" line? + marked := false + var tagLine *ast.Comment + for _, c := range gdecl.Doc.List { + if strings.HasPrefix(c.Text, "// +marshal") { + marked = true + tagLine = c + break + } + } + if !marked { + debugfAt(f.Position(gdecl.Pos()), "Skipping declaration since it doesn't have a comment containing +marshal line.\n") + continue + } + for _, spec := range gdecl.Specs { + // We already confirmed we're in a type declaration earlier, so this + // cast will succeed. + t := spec.(*ast.TypeSpec) + switch t.Type.(type) { + case *ast.StructType: + debugfAt(f.Position(t.Pos()), "Collected marshallable struct %s.\n", t.Name.Name) + case *ast.Ident: // Newtype on primitive. + debugfAt(f.Position(t.Pos()), "Collected marshallable newtype on primitive %s.\n", t.Name.Name) + case *ast.ArrayType: // Newtype on array. + debugfAt(f.Position(t.Pos()), "Collected marshallable newtype on array %s.\n", t.Name.Name) + default: + // A user specifically requested marshalling on this type, but we + // don't support it. + abortAt(f.Position(t.Pos()), fmt.Sprintf("Marshalling codegen was requested on type '%s', but go-marshal doesn't support this kind of declaration.\n", t.Name)) + } + types = append(types, newMarshallableType(f, tagLine, t)) + + } + } + return types +} + +// collectImports collects all imports from all input source files. 
Some of +// these imports are copied to the generated output, if they're referenced by +// the generated code. +// +// collectImports de-duplicates imports while building the list, and ensures +// identifiers in the generated code don't conflict with any imported package +// names. +func (g *Generator) collectImports(a *ast.File, f *token.FileSet) map[string]importStmt { + is := make(map[string]importStmt) + for _, decl := range a.Decls { + gdecl, ok := decl.(*ast.GenDecl) + // Import statement? + if !ok || gdecl.Tok != token.IMPORT { + continue + } + for _, spec := range gdecl.Specs { + i := g.imports.addFromSpec(spec.(*ast.ImportSpec), f) + debugf("Collected import '%s' as '%s'\n", i.path, i.name) + + // Make sure we have an import that doesn't use any local names that + // would conflict with identifiers in the generated code. + if len(i.name) == 1 && i.name != "_" { + abortAt(f.Position(spec.Pos()), fmt.Sprintf("Import has a single character local name '%s'; this may conflict with code generated by go_marshal, use a multi-character import alias", i.name)) + } + if _, ok := badIdentsMap[i.name]; ok { + abortAt(f.Position(spec.Pos()), fmt.Sprintf("Import name '%s' is likely to conflict with code generated by go_marshal, use a different import alias", i.name)) + } + } + } + return is + +} + +func (g *Generator) generateOne(t marshallableType, fset *token.FileSet) *interfaceGenerator { + i := newInterfaceGenerator(t.spec, fset) + switch ty := t.spec.Type.(type) { + case *ast.StructType: + i.validateStruct(t.spec, ty) + i.emitMarshallableForStruct(ty) + if t.slice != nil { + i.emitMarshallableSliceForStruct(ty, t.slice) + } + case *ast.Ident: + i.validatePrimitiveNewtype(ty) + i.emitMarshallableForPrimitiveNewtype(ty) + if t.slice != nil { + i.emitMarshallableSliceForPrimitiveNewtype(ty, t.slice) + } + case *ast.ArrayType: + i.validateArrayNewtype(t.spec.Name, ty) + // After validate, we can safely call arrayLen. 
+ i.emitMarshallableForArrayNewtype(t.spec.Name, ty, ty.Elt.(*ast.Ident)) + if t.slice != nil { + abortAt(fset.Position(t.slice.comment.Slash), fmt.Sprintf("Array type marked as '+marshal slice:...', but this is not supported. Perhaps fold one of the dimensions?")) + } + default: + // This should've been filtered out by collectMarshallabeTypes. + panic(fmt.Sprintf("Unexpected type %+v", ty)) + } + return i +} + +// generateOneTestSuite generates a test suite for the automatically generated +// implementations type t. +func (g *Generator) generateOneTestSuite(t marshallableType) *testGenerator { + i := newTestGenerator(t.spec) + i.emitTests(t.slice) + return i +} + +// Run is the entry point to code generation using g. +// +// Run parses all input source files specified in g and emits generated code. +func (g *Generator) Run() error { + // Parse our input source files into ASTs and token sets. + asts, fsets, err := g.parse() + if err != nil { + return err + } + + if len(asts) != len(fsets) { + panic("ASTs and FileSets don't match") + } + + // Map of imports in source files; key = local package name, value = import + // path. + is := make(map[string]importStmt) + for i, a := range asts { + // Collect all imports from the source files. We may need to copy some + // of these to the generated code if they're referenced. This has to be + // done before the loop below because we need to process all ASTs before + // we start requesting imports to be copied one by one as we encounter + // them in each generated source. + for name, i := range g.collectImports(a, fsets[i]) { + is[name] = i + } + } + + var impls []*interfaceGenerator + var ts []*testGenerator + // Set of Marshallable types referenced by generated code. + ms := make(map[string]struct{}) + for i, a := range asts { + // Collect type declarations marked for code generation and generate + // Marshallable interfaces. 
+ for _, t := range g.collectMarshallableTypes(a, fsets[i]) { + impl := g.generateOne(t, fsets[i]) + // Collect Marshallable types referenced by the generated code. + for ref, _ := range impl.ms { + ms[ref] = struct{}{} + } + impls = append(impls, impl) + // Collect imports referenced by the generated code and add them to + // the list of imports we need to copy to the generated code. + for name, _ := range impl.is { + if !g.imports.markUsed(name) { + panic(fmt.Sprintf("Generated code for '%s' referenced a non-existent import with local name '%s'. Either go-marshal needs to add an import to the generated file, or a package in an input source file has a package name differ from the final component of its path, which go-marshal doesn't know how to detect; use an import alias to work around this limitation.", impl.typeName(), name)) + } + } + ts = append(ts, g.generateOneTestSuite(t)) + } + } + + // Write output file header. These include things like package name and + // import statements. + if err := g.writeHeader(); err != nil { + return err + } + + // Write type checks for referenced marshallable types to output file. + if err := g.writeTypeChecks(ms); err != nil { + return err + } + + // Write generated interfaces to output file. + for _, i := range impls { + if err := i.write(g.output); err != nil { + return err + } + } + + // Write generated tests to test file. + return g.writeTests(ts) +} + +// writeTests outputs tests for the generated interface implementations to a go +// source file. +func (g *Generator) writeTests(ts []*testGenerator) error { + var b sourceBuffer + b.emit("package %s\n\n", g.pkg) + if err := b.write(g.outputTest); err != nil { + return err + } + + // Collect and write test import statements. + imports := newImportTable() + for _, t := range ts { + imports.merge(t.imports) + } + + if err := imports.write(g.outputTest); err != nil { + return err + } + + // Write test functions. 
+ + // If we didn't generate any Marshallable implementations, we can't just + // emit an empty test file, since that causes the build to fail with "no + // tests/benchmarks/examples found". Unfortunately we can't signal bazel to + // omit the entire package since the outputs are already defined before + // go-marshal is called. If we'd otherwise emit an empty test suite, emit an + // empty example instead. + if len(ts) == 0 { + b.reset() + b.emit("func Example() {\n") + b.inIndent(func() { + b.emit("// This example is intentionally empty to ensure this file contains at least\n") + b.emit("// one testable entity. go-marshal is forced to emit a test file if a package\n") + b.emit("// is marked marshallable, but emitting a test file with no entities results\n") + b.emit("// in a build failure.\n") + }) + b.emit("}\n") + return b.write(g.outputTest) + } + + for _, t := range ts { + if err := t.write(g.outputTest); err != nil { + return err + } + } + return nil +} diff --git a/tools/go_marshal/gomarshal/generator_interfaces.go b/tools/go_marshal/gomarshal/generator_interfaces.go new file mode 100644 index 000000000..e3c3dac63 --- /dev/null +++ b/tools/go_marshal/gomarshal/generator_interfaces.go @@ -0,0 +1,276 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package gomarshal + +import ( + "fmt" + "go/ast" + "go/token" + "strings" +) + +// interfaceGenerator generates marshalling interfaces for a single type. 
+// +// getState is not thread-safe. +type interfaceGenerator struct { + sourceBuffer + + // The type we're serializing. + t *ast.TypeSpec + + // Receiver argument for generated methods. + r string + + // FileSet containing the tokens for the type we're processing. + f *token.FileSet + + // is records external packages referenced by the generated implementation. + is map[string]struct{} + + // ms records Marshallable types referenced by the generated implementation + // of t's interfaces. + ms map[string]struct{} + + // as records embedded fields in t that are potentially not packed. The key + // is the accessor for the field. + as map[string]struct{} +} + +// typeName returns the name of the type this g represents. +func (g *interfaceGenerator) typeName() string { + return g.t.Name.Name +} + +// newinterfaceGenerator creates a new interface generator. +func newInterfaceGenerator(t *ast.TypeSpec, fset *token.FileSet) *interfaceGenerator { + g := &interfaceGenerator{ + t: t, + r: receiverName(t), + f: fset, + is: make(map[string]struct{}), + ms: make(map[string]struct{}), + as: make(map[string]struct{}), + } + g.recordUsedMarshallable(g.typeName()) + return g +} + +func (g *interfaceGenerator) recordUsedMarshallable(m string) { + g.ms[m] = struct{}{} + +} + +func (g *interfaceGenerator) recordUsedImport(i string) { + g.is[i] = struct{}{} +} + +func (g *interfaceGenerator) recordPotentiallyNonPackedField(fieldName string) { + g.as[fieldName] = struct{}{} +} + +// abortAt aborts the go_marshal tool with the given error message, with a +// reference position to the input source. Same as abortAt, but uses g to +// resolve p to position. +func (g *interfaceGenerator) abortAt(p token.Pos, msg string) { + abortAt(g.f.Position(p), msg) +} + +// scalarSize returns the size of type identified by t. If t isn't a primitive +// type, the size isn't known at code generation time, and must be resolved via +// the marshal.Marshallable interface. 
+func (g *interfaceGenerator) scalarSize(t *ast.Ident) (size int, unknownSize bool) { + switch t.Name { + case "int8", "uint8", "byte": + return 1, false + case "int16", "uint16": + return 2, false + case "int32", "uint32": + return 4, false + case "int64", "uint64": + return 8, false + default: + return 0, true + } +} + +func (g *interfaceGenerator) shift(bufVar string, n int) { + g.emit("%s = %s[%d:]\n", bufVar, bufVar, n) +} + +func (g *interfaceGenerator) shiftDynamic(bufVar, name string) { + g.emit("%s = %s[%s.SizeBytes():]\n", bufVar, bufVar, name) +} + +// marshalScalar writes a single scalar to a byte slice. +func (g *interfaceGenerator) marshalScalar(accessor, typ, bufVar string) { + switch typ { + case "int8", "uint8", "byte": + g.emit("%s[0] = byte(%s)\n", bufVar, accessor) + g.shift(bufVar, 1) + case "int16", "uint16": + g.recordUsedImport("usermem") + g.emit("usermem.ByteOrder.PutUint16(%s[:2], uint16(%s))\n", bufVar, accessor) + g.shift(bufVar, 2) + case "int32", "uint32": + g.recordUsedImport("usermem") + g.emit("usermem.ByteOrder.PutUint32(%s[:4], uint32(%s))\n", bufVar, accessor) + g.shift(bufVar, 4) + case "int64", "uint64": + g.recordUsedImport("usermem") + g.emit("usermem.ByteOrder.PutUint64(%s[:8], uint64(%s))\n", bufVar, accessor) + g.shift(bufVar, 8) + default: + g.emit("%s.MarshalBytes(%s[:%s.SizeBytes()])\n", accessor, bufVar, accessor) + g.shiftDynamic(bufVar, accessor) + } +} + +// unmarshalScalar reads a single scalar from a byte slice. 
+func (g *interfaceGenerator) unmarshalScalar(accessor, typ, bufVar string) { + switch typ { + case "byte": + g.emit("%s = %s[0]\n", accessor, bufVar) + g.shift(bufVar, 1) + case "int8", "uint8": + g.emit("%s = %s(%s[0])\n", accessor, typ, bufVar) + g.shift(bufVar, 1) + case "int16", "uint16": + g.recordUsedImport("usermem") + g.emit("%s = %s(usermem.ByteOrder.Uint16(%s[:2]))\n", accessor, typ, bufVar) + g.shift(bufVar, 2) + case "int32", "uint32": + g.recordUsedImport("usermem") + g.emit("%s = %s(usermem.ByteOrder.Uint32(%s[:4]))\n", accessor, typ, bufVar) + g.shift(bufVar, 4) + case "int64", "uint64": + g.recordUsedImport("usermem") + g.emit("%s = %s(usermem.ByteOrder.Uint64(%s[:8]))\n", accessor, typ, bufVar) + g.shift(bufVar, 8) + default: + g.emit("%s.UnmarshalBytes(%s[:%s.SizeBytes()])\n", accessor, bufVar, accessor) + g.shiftDynamic(bufVar, accessor) + g.recordPotentiallyNonPackedField(accessor) + } +} + +// emitCastToByteSlice unsafely casts an arbitrary type's underlying memory to a +// byte slice, bypassing escape analysis. The caller is responsible for ensuring +// srcPtr lives until they're done with dstVar, the runtime does not consider +// dstVar dependent on srcPtr due to the escape analysis bypass. +// +// srcPtr must be a pointer. +// +// This function uses internally uses the identifier "hdr", and cannot be used +// in a context where it is already bound. +func (g *interfaceGenerator) emitCastToByteSlice(srcPtr, dstVar, lenExpr string) { + g.recordUsedImport("gohacks") + g.emit("// Construct a slice backed by dst's underlying memory.\n") + g.emit("var %s []byte\n", dstVar) + g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&%s))\n", dstVar) + g.emit("hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(%s)))\n", srcPtr) + g.emit("hdr.Len = %s\n", lenExpr) + g.emit("hdr.Cap = %s\n\n", lenExpr) +} + +// emitCastToByteSlice unsafely casts a slice with elements of an abitrary type +// to a byte slice. 
As part of the cast, the byte slice is made to look +// independent of the src slice by bypassing escape analysis. This means the +// byte slice can be used without causing the source to escape. The caller is +// responsible for ensuring srcPtr lives until they're done with dstVar, as the +// runtime no longer considers dstVar dependent on srcPtr and is free to GC it. +// +// srcPtr must be a pointer. +// +// This function uses internally uses the identifiers "ptr", "val" and "hdr", +// and cannot be used in a context where these identifiers are already bound. +func (g *interfaceGenerator) emitCastSliceToByteSlice(srcPtr, dstVar, lenExpr string) { + g.emitNoEscapeSliceDataPointer(srcPtr, "val") + + g.emit("// Construct a slice backed by dst's underlying memory.\n") + g.emit("var %s []byte\n", dstVar) + g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&%s))\n", dstVar) + g.emit("hdr.Data = uintptr(val)\n") + g.emit("hdr.Len = %s\n", lenExpr) + g.emit("hdr.Cap = %s\n\n", lenExpr) +} + +// emitNoEscapeSliceDataPointer unsafely casts a slice's data pointer to an +// unsafe.Pointer, bypassing escape analysis. The caller is responsible for +// ensuring srcPtr lives until they're done with dstVar, as the runtime no +// longer considers dstVar dependent on srcPtr and is free to GC it. +// +// srcPtr must be a pointer. +// +// This function uses internally uses the identifier "ptr" cannot be used in a +// context where this identifier is already bound. 
+func (g *interfaceGenerator) emitNoEscapeSliceDataPointer(srcPtr, dstVar string) { + g.recordUsedImport("gohacks") + g.emit("ptr := unsafe.Pointer(%s)\n", srcPtr) + g.emit("%s := gohacks.Noescape(unsafe.Pointer((*reflect.SliceHeader)(ptr).Data))\n\n", dstVar) +} + +func (g *interfaceGenerator) emitKeepAlive(ptrVar string) { + g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", ptrVar) + g.emit("// must live until the use above.\n") + g.emit("runtime.KeepAlive(%s)\n", ptrVar) +} + +func (g *interfaceGenerator) expandBinaryExpr(b *strings.Builder, e *ast.BinaryExpr) { + switch x := e.X.(type) { + case *ast.BinaryExpr: + // Recursively expand sub-expression. + g.expandBinaryExpr(b, x) + case *ast.Ident: + fmt.Fprintf(b, "%s", x.Name) + case *ast.BasicLit: + fmt.Fprintf(b, "%s", x.Value) + default: + g.abortAt(e.Pos(), "Cannot convert binary expression to output code. Go-marshal currently only handles simple expressions of literals, constants and basic identifiers") + } + + fmt.Fprintf(b, "%s", e.Op) + + switch y := e.Y.(type) { + case *ast.BinaryExpr: + // Recursively expand sub-expression. + g.expandBinaryExpr(b, y) + case *ast.Ident: + fmt.Fprintf(b, "%s", y.Name) + case *ast.BasicLit: + fmt.Fprintf(b, "%s", y.Value) + default: + g.abortAt(e.Pos(), "Cannot convert binary expression to output code. Go-marshal currently only handles simple expressions of literals, constants and basic identifiers") + } +} + +// arrayLenExpr returns a string containing a valid golang expression +// representing the length of array a. The returned expression should be treated +// as a single value, and will be already parenthesized as required. 
+func (g *interfaceGenerator) arrayLenExpr(a *ast.ArrayType) string { + var b strings.Builder + + switch l := a.Len.(type) { + case *ast.Ident: + fmt.Fprintf(&b, "%s", l.Name) + case *ast.BasicLit: + fmt.Fprintf(&b, "%s", l.Value) + case *ast.BinaryExpr: + g.expandBinaryExpr(&b, l) + return fmt.Sprintf("(%s)", b.String()) + default: + g.abortAt(l.Pos(), "Cannot convert this array len expression to output code. Go-marshal currently only handles simple expressions of literals, constants and basic identifiers") + } + return b.String() +} diff --git a/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go b/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go new file mode 100644 index 000000000..72ef03a22 --- /dev/null +++ b/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go @@ -0,0 +1,146 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This file contains the bits of the code generator specific to marshalling +// newtypes on arrays. 
+ +package gomarshal + +import ( + "fmt" + "go/ast" +) + +func (g *interfaceGenerator) validateArrayNewtype(n *ast.Ident, a *ast.ArrayType) { + if a.Len == nil { + g.abortAt(a.Pos(), fmt.Sprintf("Dynamically sized slice '%s' cannot be marshalled, arrays must be statically sized", n.Name)) + } + + if _, ok := a.Elt.(*ast.Ident); !ok { + g.abortAt(a.Elt.Pos(), fmt.Sprintf("Marshalling not supported for arrays with %s elements, array elements must be primitive types", kindString(a.Elt))) + } +} + +func (g *interfaceGenerator) emitMarshallableForArrayNewtype(n *ast.Ident, a *ast.ArrayType, elt *ast.Ident) { + g.recordUsedImport("io") + g.recordUsedImport("marshal") + g.recordUsedImport("reflect") + g.recordUsedImport("runtime") + g.recordUsedImport("safecopy") + g.recordUsedImport("unsafe") + g.recordUsedImport("usermem") + + lenExpr := g.arrayLenExpr(a) + + g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n") + g.emit("//go:nosplit\n") + g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName()) + g.inIndent(func() { + if size, dynamic := g.scalarSize(elt); !dynamic { + g.emit("return %d * %s\n", size, lenExpr) + } else { + g.emit("return (*%s)(nil).SizeBytes() * %s\n", n.Name, lenExpr) + } + }) + g.emit("}\n\n") + + g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n") + g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("for idx := 0; idx < %s; idx++ {\n", lenExpr) + g.inIndent(func() { + g.marshalScalar(fmt.Sprintf("%s[idx]", g.r), elt.Name, "dst") + }) + g.emit("}\n") + }) + g.emit("}\n\n") + + g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n") + g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("for idx := 0; idx < %s; idx++ {\n", lenExpr) + g.inIndent(func() { + g.unmarshalScalar(fmt.Sprintf("%s[idx]", g.r), elt.Name, "src") + }) + g.emit("}\n") + }) + g.emit("}\n\n") + + g.emit("// Packed 
implements marshal.Marshallable.Packed.\n") + g.emit("//go:nosplit\n") + g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("// Array newtypes are always packed.\n") + g.emit("return true\n") + }) + g.emit("}\n\n") + + g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n") + g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r) + }) + g.emit("}\n\n") + + g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n") + g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r) + }) + g.emit("}\n\n") + + g.emit("// CopyOutN implements marshal.Marshallable.CopyOutN.\n") + g.emit("//go:nosplit\n") + g.emit("func (%s *%s) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r)) + + g.emit("length, err := task.CopyOutBytes(addr, buf[:limit]) // escapes: okay.\n") + g.emitKeepAlive(g.r) + g.emit("return length, err\n") + }) + g.emit("}\n\n") + + g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n") + g.emit("//go:nosplit\n") + g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("return %s.CopyOutN(task, addr, %s.SizeBytes())\n", g.r, g.r) + }) + g.emit("}\n\n") + + g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n") + g.emit("//go:nosplit\n") + g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r)) + + g.emit("length, err := task.CopyInBytes(addr, buf) // escapes: okay.\n") + g.emitKeepAlive(g.r) + g.emit("return length, 
err\n") + }) + g.emit("}\n\n") + + g.emit("// WriteTo implements io.WriterTo.WriteTo.\n") + g.emit("func (%s *%s) WriteTo(w io.Writer) (int64, error) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r)) + + g.emit("length, err := w.Write(buf)\n") + g.emitKeepAlive(g.r) + g.emit("return int64(length), err\n") + + }) + g.emit("}\n\n") +} diff --git a/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go b/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go new file mode 100644 index 000000000..39f654ea8 --- /dev/null +++ b/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go @@ -0,0 +1,289 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This file contains the bits of the code generator specific to marshalling +// newtypes on primitives. + +package gomarshal + +import ( + "fmt" + "go/ast" +) + +// marshalPrimitiveScalar writes a single primitive variable to a byte +// slice. 
+func (g *interfaceGenerator) marshalPrimitiveScalar(accessor, typ, bufVar string) { + switch typ { + case "int8", "uint8", "byte": + g.emit("%s[0] = byte(*%s)\n", bufVar, accessor) + case "int16", "uint16": + g.recordUsedImport("usermem") + g.emit("usermem.ByteOrder.PutUint16(%s[:2], uint16(*%s))\n", bufVar, accessor) + case "int32", "uint32": + g.recordUsedImport("usermem") + g.emit("usermem.ByteOrder.PutUint32(%s[:4], uint32(*%s))\n", bufVar, accessor) + case "int64", "uint64": + g.recordUsedImport("usermem") + g.emit("usermem.ByteOrder.PutUint64(%s[:8], uint64(*%s))\n", bufVar, accessor) + default: + g.emit("// Explicilty cast to the underlying type before dispatching to\n") + g.emit("// MarshalBytes, so we don't recursively call %s.MarshalBytes\n", accessor) + g.emit("inner := (*%s)(%s)\n", typ, accessor) + g.emit("inner.MarshalBytes(%s[:%s.SizeBytes()])\n", bufVar, accessor) + } +} + +// unmarshalPrimitiveScalar read a single primitive variable from a byte slice. +func (g *interfaceGenerator) unmarshalPrimitiveScalar(accessor, typ, bufVar, typeCast string) { + switch typ { + case "byte": + g.emit("*%s = %s(%s[0])\n", accessor, typeCast, bufVar) + case "int8", "uint8": + g.emit("*%s = %s(%s(%s[0]))\n", accessor, typeCast, typ, bufVar) + case "int16", "uint16": + g.recordUsedImport("usermem") + g.emit("*%s = %s(%s(usermem.ByteOrder.Uint16(%s[:2])))\n", accessor, typeCast, typ, bufVar) + case "int32", "uint32": + g.recordUsedImport("usermem") + g.emit("*%s = %s(%s(usermem.ByteOrder.Uint32(%s[:4])))\n", accessor, typeCast, typ, bufVar) + case "int64", "uint64": + g.recordUsedImport("usermem") + g.emit("*%s = %s(%s(usermem.ByteOrder.Uint64(%s[:8])))\n", accessor, typeCast, typ, bufVar) + default: + g.emit("// Explicilty cast to the underlying type before dispatching to\n") + g.emit("// UnmarshalBytes, so we don't recursively call %s.UnmarshalBytes\n", accessor) + g.emit("inner := (*%s)(%s)\n", typ, accessor) + g.emit("inner.UnmarshalBytes(%s[:%s.SizeBytes()])\n", 
bufVar, accessor) + } +} + +func (g *interfaceGenerator) validatePrimitiveNewtype(t *ast.Ident) { + switch t.Name { + case "int8", "uint8", "byte", "int16", "uint16", "int32", "uint32", "int64", "uint64": + // These are the only primitive types we're allow. Below, we provide + // suggestions for some disallowed types and reject them, then attempt + // to marshal any remaining types by invoking the marshal.Marshallable + // interface on them. If these types don't actually implement + // marshal.Marshallable, compilation of the generated code will fail + // with an appropriate error message. + return + case "int": + g.abortAt(t.Pos(), "Type 'int' has ambiguous width, use int32 or int64") + case "uint": + g.abortAt(t.Pos(), "Type 'uint' has ambiguous width, use uint32 or uint64") + case "string": + g.abortAt(t.Pos(), "Type 'string' is dynamically-sized and cannot be marshalled, use a fixed size byte array '[...]byte' instead") + default: + debugfAt(g.f.Position(t.Pos()), fmt.Sprintf("Found derived type '%s', will attempt dispatch via marshal.Marshallable.\n", t.Name)) + } +} + +// emitMarshallableForPrimitiveNewtype outputs code to implement the +// marshal.Marshallable interface for a newtype on a primitive. Primitive +// newtypes are always packed, so we can omit the various fallbacks required for +// non-packed structs. 
+func (g *interfaceGenerator) emitMarshallableForPrimitiveNewtype(nt *ast.Ident) { + g.recordUsedImport("io") + g.recordUsedImport("marshal") + g.recordUsedImport("reflect") + g.recordUsedImport("runtime") + g.recordUsedImport("safecopy") + g.recordUsedImport("unsafe") + g.recordUsedImport("usermem") + + g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n") + g.emit("//go:nosplit\n") + g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName()) + g.inIndent(func() { + if size, dynamic := g.scalarSize(nt); !dynamic { + g.emit("return %d\n", size) + } else { + g.emit("return (*%s)(nil).SizeBytes()\n", nt.Name) + } + }) + g.emit("}\n\n") + + g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n") + g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.marshalPrimitiveScalar(g.r, nt.Name, "dst") + }) + g.emit("}\n\n") + + g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n") + g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.unmarshalPrimitiveScalar(g.r, nt.Name, "src", g.typeName()) + }) + g.emit("}\n\n") + + g.emit("// Packed implements marshal.Marshallable.Packed.\n") + g.emit("//go:nosplit\n") + g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("// Scalar newtypes are always packed.\n") + g.emit("return true\n") + }) + g.emit("}\n\n") + + g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n") + g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r) + }) + g.emit("}\n\n") + + g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n") + g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r) + }) + g.emit("}\n\n") + + g.emit("// 
CopyOutN implements marshal.Marshallable.CopyOutN.\n") + g.emit("//go:nosplit\n") + g.emit("func (%s *%s) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r)) + + g.emit("length, err := task.CopyOutBytes(addr, buf[:limit]) // escapes: okay.\n") + g.emitKeepAlive(g.r) + g.emit("return length, err\n") + }) + g.emit("}\n\n") + + g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n") + g.emit("//go:nosplit\n") + g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("return %s.CopyOutN(task, addr, %s.SizeBytes())\n", g.r, g.r) + }) + g.emit("}\n\n") + + g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n") + g.emit("//go:nosplit\n") + g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r)) + + g.emit("length, err := task.CopyInBytes(addr, buf) // escapes: okay.\n") + g.emitKeepAlive(g.r) + g.emit("return length, err\n") + }) + g.emit("}\n\n") + + g.emit("// WriteTo implements io.WriterTo.WriteTo.\n") + g.emit("func (%s *%s) WriteTo(w io.Writer) (int64, error) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r)) + + g.emit("length, err := w.Write(buf)\n") + g.emitKeepAlive(g.r) + g.emit("return int64(length), err\n") + + }) + g.emit("}\n\n") +} + +func (g *interfaceGenerator) emitMarshallableSliceForPrimitiveNewtype(nt *ast.Ident, slice *sliceAPI) { + g.recordUsedImport("marshal") + g.recordUsedImport("usermem") + g.recordUsedImport("reflect") + g.recordUsedImport("runtime") + g.recordUsedImport("unsafe") + + eltType := g.typeName() + if slice.inner { + eltType = nt.Name + } + + g.emit("// Copy%sIn copies in a slice of %s objects 
from the task's memory.\n", slice.ident, eltType) + g.emit("//go:nosplit\n") + g.emit("func Copy%sIn(task marshal.Task, addr usermem.Addr, dst []%s) (int, error) {\n", slice.ident, eltType) + g.inIndent(func() { + g.emit("count := len(dst)\n") + g.emit("if count == 0 {\n") + g.inIndent(func() { + g.emit("return 0, nil\n") + }) + g.emit("}\n") + g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName()) + + g.emitCastSliceToByteSlice("&dst", "buf", "size * count") + + g.emit("length, err := task.CopyInBytes(addr, buf) // escapes: okay.\n") + g.emitKeepAlive("dst") + g.emit("return length, err\n") + }) + g.emit("}\n\n") + + g.emit("// Copy%sOut copies a slice of %s objects to the task's memory.\n", slice.ident, eltType) + g.emit("//go:nosplit\n") + g.emit("func Copy%sOut(task marshal.Task, addr usermem.Addr, src []%s) (int, error) {\n", slice.ident, eltType) + g.inIndent(func() { + g.emit("count := len(src)\n") + g.emit("if count == 0 {\n") + g.inIndent(func() { + g.emit("return 0, nil\n") + }) + g.emit("}\n") + g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName()) + + g.emitCastSliceToByteSlice("&src", "buf", "size * count") + + g.emit("length, err := task.CopyOutBytes(addr, buf) // escapes: okay.\n") + g.emitKeepAlive("src") + g.emit("return length, err\n") + }) + g.emit("}\n\n") + + g.emit("// MarshalUnsafe%s is like %s.MarshalUnsafe, but for a []%s.\n", slice.ident, g.typeName(), g.typeName()) + g.emit("func MarshalUnsafe%s(src []%s, dst []byte) (int, error) {\n", slice.ident, g.typeName()) + g.inIndent(func() { + g.emit("count := len(src)\n") + g.emit("if count == 0 {\n") + g.inIndent(func() { + g.emit("return 0, nil\n") + }) + g.emit("}\n") + g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName()) + + g.emitNoEscapeSliceDataPointer("&src", "val") + + g.emit("length, err := safecopy.CopyIn(dst[:(size*count)], val)\n") + g.emitKeepAlive("src") + g.emit("return length, err\n") + }) + g.emit("}\n\n") + + g.emit("// UnmarshalUnsafe%s is like 
%s.UnmarshalUnsafe, but for a []%s.\n", slice.ident, g.typeName(), g.typeName()) + g.emit("func UnmarshalUnsafe%s(dst []%s, src []byte) (int, error) {\n", slice.ident, g.typeName()) + g.inIndent(func() { + g.emit("count := len(dst)\n") + g.emit("if count == 0 {\n") + g.inIndent(func() { + g.emit("return 0, nil\n") + }) + g.emit("}\n") + g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName()) + + g.emitNoEscapeSliceDataPointer("&dst", "val") + + g.emit("length, err := safecopy.CopyOut(val, src[:(size*count)])\n") + g.emitKeepAlive("dst") + g.emit("return length, err\n") + }) + g.emit("}\n\n") +} diff --git a/tools/go_marshal/gomarshal/generator_interfaces_struct.go b/tools/go_marshal/gomarshal/generator_interfaces_struct.go new file mode 100644 index 000000000..9cd3c9579 --- /dev/null +++ b/tools/go_marshal/gomarshal/generator_interfaces_struct.go @@ -0,0 +1,618 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This file contains the bits of the code generator specific to marshalling +// structs. + +package gomarshal + +import ( + "fmt" + "go/ast" + "strings" +) + +func (g *interfaceGenerator) fieldAccessor(n *ast.Ident) string { + return fmt.Sprintf("%s.%s", g.r, n.Name) +} + +// areFieldsPackedExpression returns a go expression checking whether g.t's fields are +// packed. 
Returns "", false if g.t has no fields that may be potentially +// packed, otherwise returns <clause>, true, where <clause> is an expression +// like "t.a.Packed() && t.b.Packed() && t.c.Packed()". +func (g *interfaceGenerator) areFieldsPackedExpression() (string, bool) { + if len(g.as) == 0 { + return "", false + } + + cs := make([]string, 0, len(g.as)) + for accessor, _ := range g.as { + cs = append(cs, fmt.Sprintf("%s.Packed()", accessor)) + } + return strings.Join(cs, " && "), true +} + +// validateStruct ensures the type we're working with can be marshalled. These +// checks are done ahead of time and in one place so we can make assumptions +// later. +func (g *interfaceGenerator) validateStruct(ts *ast.TypeSpec, st *ast.StructType) { + forEachStructField(st, func(f *ast.Field) { + if len(f.Names) == 0 { + g.abortAt(f.Pos(), "Cannot marshal structs with embedded fields, give the field a name; use '_' for anonymous fields such as padding fields") + } + }) + + forEachStructField(st, func(f *ast.Field) { + fieldDispatcher{ + primitive: func(_, t *ast.Ident) { + g.validatePrimitiveNewtype(t) + }, + selector: func(_, _, _ *ast.Ident) { + // No validation to perform on selector fields. However this + // callback must still be provided. 
+ }, + array: func(n *ast.Ident, a *ast.ArrayType, _ *ast.Ident) { + g.validateArrayNewtype(n, a) + }, + unhandled: func(_ *ast.Ident) { + g.abortAt(f.Pos(), fmt.Sprintf("Marshalling not supported for %s fields", kindString(f.Type))) + }, + }.dispatch(f) + }) +} + +func (g *interfaceGenerator) isStructPacked(st *ast.StructType) bool { + packed := true + forEachStructField(st, func(f *ast.Field) { + if f.Tag != nil { + if f.Tag.Value == "`marshal:\"unaligned\"`" { + if packed { + debugfAt(g.f.Position(g.t.Pos()), + fmt.Sprintf("Marking type '%s' as not packed due to tag `marshal:\"unaligned\"`.\n", g.t.Name)) + packed = false + } + } + } + }) + return packed +} + +func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) { + thisPacked := g.isStructPacked(st) + + g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n") + g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName()) + g.inIndent(func() { + primitiveSize := 0 + var dynamicSizeTerms []string + + forEachStructField(st, fieldDispatcher{ + primitive: func(n, t *ast.Ident) { + if size, dynamic := g.scalarSize(t); !dynamic { + primitiveSize += size + } else { + g.recordUsedMarshallable(t.Name) + dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", t.Name)) + } + }, + selector: func(n, tX, tSel *ast.Ident) { + tName := fmt.Sprintf("%s.%s", tX.Name, tSel.Name) + g.recordUsedImport(tX.Name) + g.recordUsedMarshallable(tName) + dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", tName)) + }, + array: func(n *ast.Ident, a *ast.ArrayType, t *ast.Ident) { + lenExpr := g.arrayLenExpr(a) + if size, dynamic := g.scalarSize(t); !dynamic { + dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("%d*%s", size, lenExpr)) + } else { + g.recordUsedMarshallable(t.Name) + dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()*%s", t.Name, lenExpr)) + } + }, + }.dispatch) + g.emit("return %d", primitiveSize) + if 
len(dynamicSizeTerms) > 0 { + g.incIndent() + } + { + for _, d := range dynamicSizeTerms { + g.emitNoIndent(" +\n") + g.emit(d) + } + } + if len(dynamicSizeTerms) > 0 { + g.decIndent() + } + }) + g.emit("\n}\n\n") + + g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n") + g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + forEachStructField(st, fieldDispatcher{ + primitive: func(n, t *ast.Ident) { + if n.Name == "_" { + g.emit("// Padding: dst[:sizeof(%s)] ~= %s(0)\n", t.Name, t.Name) + if len, dynamic := g.scalarSize(t); !dynamic { + g.shift("dst", len) + } else { + // We can't use shiftDynamic here because we don't have + // an instance of the dynamic type we can reference here + // (since the version in this struct is anonymous). Use + // a typed nil pointer to call SizeBytes() instead. + g.emit("dst = dst[(*%s)(nil).SizeBytes():]\n", t.Name) + } + return + } + g.marshalScalar(g.fieldAccessor(n), t.Name, "dst") + }, + selector: func(n, tX, tSel *ast.Ident) { + if n.Name == "_" { + g.emit("// Padding: dst[:sizeof(%s)] ~= %s(0)\n", tX.Name, tSel.Name) + g.emit("dst = dst[(*%s.%s)(nil).SizeBytes():]\n", tX.Name, tSel.Name) + return + } + g.marshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "dst") + }, + array: func(n *ast.Ident, a *ast.ArrayType, t *ast.Ident) { + lenExpr := g.arrayLenExpr(a) + if n.Name == "_" { + g.emit("// Padding: dst[:sizeof(%s)*%s] ~= [%s]%s{0}\n", t.Name, lenExpr, lenExpr, t.Name) + if size, dynamic := g.scalarSize(t); !dynamic { + g.emit("dst = dst[%d*(%s):]\n", size, lenExpr) + } else { + // We can't use shiftDynamic here because we don't have + // an instance of the dynamic type we can reference here + // (since the version in this struct is anonymous). Use + // a typed nil pointer to call SizeBytes() instead. 
+ g.emit("dst = dst[(*%s)(nil).SizeBytes()*(%s):]\n", t.Name, lenExpr) + } + return + } + + g.emit("for idx := 0; idx < %s; idx++ {\n", lenExpr) + g.inIndent(func() { + g.marshalScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "dst") + }) + g.emit("}\n") + }, + }.dispatch) + }) + g.emit("}\n\n") + + g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n") + g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + forEachStructField(st, fieldDispatcher{ + primitive: func(n, t *ast.Ident) { + if n.Name == "_" { + g.emit("// Padding: var _ %s ~= src[:sizeof(%s)]\n", t.Name, t.Name) + if len, dynamic := g.scalarSize(t); !dynamic { + g.shift("src", len) + } else { + // We don't have an instance of the dynamic type we can + // reference here (since the version in this struct is + // anonymous). Use a typed nil pointer to call + // SizeBytes() instead. + g.shiftDynamic("src", fmt.Sprintf("(*%s)(nil)", t.Name)) + g.recordPotentiallyNonPackedField(fmt.Sprintf("(*%s)(nil)", t.Name)) + } + return + } + g.unmarshalScalar(g.fieldAccessor(n), t.Name, "src") + }, + selector: func(n, tX, tSel *ast.Ident) { + if n.Name == "_" { + g.emit("// Padding: %s ~= src[:sizeof(%s.%s)]\n", g.fieldAccessor(n), tX.Name, tSel.Name) + g.emit("src = src[(*%s.%s)(nil).SizeBytes():]\n", tX.Name, tSel.Name) + g.recordPotentiallyNonPackedField(fmt.Sprintf("(*%s.%s)(nil)", tX.Name, tSel.Name)) + return + } + g.unmarshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "src") + }, + array: func(n *ast.Ident, a *ast.ArrayType, t *ast.Ident) { + lenExpr := g.arrayLenExpr(a) + if n.Name == "_" { + g.emit("// Padding: ~ copy([%s]%s(%s), src[:sizeof(%s)*%s])\n", lenExpr, t.Name, g.fieldAccessor(n), t.Name, lenExpr) + if size, dynamic := g.scalarSize(t); !dynamic { + g.emit("src = src[%d*(%s):]\n", size, lenExpr) + } else { + // We can't use shiftDynamic here because we don't have + // an instance of the dynamic type we 
can reference here + // (since the version in this struct is anonymous). Use + // a typed nil pointer to call SizeBytes() instead. + g.emit("src = src[(*%s)(nil).SizeBytes()*(%s):]\n", t.Name, lenExpr) + } + return + } + + g.emit("for idx := 0; idx < %s; idx++ {\n", lenExpr) + g.inIndent(func() { + g.unmarshalScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "src") + }) + g.emit("}\n") + }, + }.dispatch) + }) + g.emit("}\n\n") + + g.emit("// Packed implements marshal.Marshallable.Packed.\n") + g.emit("//go:nosplit\n") + g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName()) + g.inIndent(func() { + expr, fieldsMaybePacked := g.areFieldsPackedExpression() + switch { + case !thisPacked: + g.emit("return false\n") + case fieldsMaybePacked: + g.emit("return %s\n", expr) + default: + g.emit("return true\n") + + } + }) + g.emit("}\n\n") + + g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n") + g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + if thisPacked { + g.recordUsedImport("safecopy") + g.recordUsedImport("unsafe") + if cond, ok := g.areFieldsPackedExpression(); ok { + g.emit("if %s {\n", cond) + g.inIndent(func() { + g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r) + }) + g.emit("} else {\n") + g.inIndent(func() { + g.emit("%s.MarshalBytes(dst)\n", g.r) + }) + g.emit("}\n") + } else { + g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r) + } + } else { + g.emit("// Type %s doesn't have a packed layout in memory, fallback to MarshalBytes.\n", g.typeName()) + g.emit("%s.MarshalBytes(dst)\n", g.r) + } + }) + g.emit("}\n\n") + + g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n") + g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + if thisPacked { + g.recordUsedImport("safecopy") + g.recordUsedImport("unsafe") + if cond, ok := g.areFieldsPackedExpression(); ok { + g.emit("if %s {\n", cond) + 
g.inIndent(func() { + g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r) + }) + g.emit("} else {\n") + g.inIndent(func() { + g.emit("%s.UnmarshalBytes(src)\n", g.r) + }) + g.emit("}\n") + } else { + g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r) + } + } else { + g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName()) + g.emit("%s.UnmarshalBytes(src)\n", g.r) + } + }) + g.emit("}\n\n") + + g.emit("// CopyOutN implements marshal.Marshallable.CopyOutN.\n") + g.emit("//go:nosplit\n") + g.recordUsedImport("marshal") + g.recordUsedImport("usermem") + g.emit("func (%s *%s) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) {\n", g.r, g.typeName()) + g.inIndent(func() { + fallback := func() { + g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName()) + g.emit("buf := task.CopyScratchBuffer(%s.SizeBytes()) // escapes: okay.\n", g.r) + g.emit("%s.MarshalBytes(buf) // escapes: fallback.\n", g.r) + g.emit("return task.CopyOutBytes(addr, buf[:limit]) // escapes: okay.\n") + } + if thisPacked { + g.recordUsedImport("reflect") + g.recordUsedImport("runtime") + g.recordUsedImport("unsafe") + if cond, ok := g.areFieldsPackedExpression(); ok { + g.emit("if !%s {\n", cond) + g.inIndent(fallback) + g.emit("}\n\n") + } + // Fast serialization. 
+ g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r)) + + g.emit("length, err := task.CopyOutBytes(addr, buf[:limit]) // escapes: okay.\n") + g.emitKeepAlive(g.r) + g.emit("return length, err\n") + } else { + fallback() + } + }) + g.emit("}\n\n") + + g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n") + g.emit("//go:nosplit\n") + g.recordUsedImport("marshal") + g.recordUsedImport("usermem") + g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("return %s.CopyOutN(task, addr, %s.SizeBytes())\n", g.r, g.r) + }) + g.emit("}\n\n") + + g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n") + g.emit("//go:nosplit\n") + g.recordUsedImport("marshal") + g.recordUsedImport("usermem") + g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName()) + g.inIndent(func() { + fallback := func() { + g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName()) + g.emit("buf := task.CopyScratchBuffer(%s.SizeBytes()) // escapes: okay.\n", g.r) + g.emit("length, err := task.CopyInBytes(addr, buf) // escapes: okay.\n") + g.emit("// Unmarshal unconditionally. If we had a short copy-in, this results in a\n") + g.emit("// partially unmarshalled struct.\n") + g.emit("%s.UnmarshalBytes(buf) // escapes: fallback.\n", g.r) + g.emit("return length, err\n") + } + if thisPacked { + g.recordUsedImport("reflect") + g.recordUsedImport("runtime") + g.recordUsedImport("unsafe") + if cond, ok := g.areFieldsPackedExpression(); ok { + g.emit("if !%s {\n", cond) + g.inIndent(fallback) + g.emit("}\n\n") + } + // Fast deserialization. 
+ g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r)) + + g.emit("length, err := task.CopyInBytes(addr, buf) // escapes: okay.\n") + g.emitKeepAlive(g.r) + g.emit("return length, err\n") + } else { + fallback() + } + }) + g.emit("}\n\n") + + g.emit("// WriteTo implements io.WriterTo.WriteTo.\n") + g.recordUsedImport("io") + g.emit("func (%s *%s) WriteTo(w io.Writer) (int64, error) {\n", g.r, g.typeName()) + g.inIndent(func() { + fallback := func() { + g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName()) + g.emit("buf := make([]byte, %s.SizeBytes())\n", g.r) + g.emit("%s.MarshalBytes(buf)\n", g.r) + g.emit("length, err := w.Write(buf)\n") + g.emit("return int64(length), err\n") + } + if thisPacked { + g.recordUsedImport("reflect") + g.recordUsedImport("runtime") + g.recordUsedImport("unsafe") + if cond, ok := g.areFieldsPackedExpression(); ok { + g.emit("if !%s {\n", cond) + g.inIndent(fallback) + g.emit("}\n\n") + } + // Fast serialization. + g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r)) + + g.emit("length, err := w.Write(buf)\n") + g.emitKeepAlive(g.r) + g.emit("return int64(length), err\n") + } else { + fallback() + } + }) + g.emit("}\n\n") +} + +func (g *interfaceGenerator) emitMarshallableSliceForStruct(st *ast.StructType, slice *sliceAPI) { + thisPacked := g.isStructPacked(st) + + if slice.inner { + abortAt(g.f.Position(slice.comment.Slash), fmt.Sprintf("The ':inner' argument to '+marshal slice:%s:inner' is only applicable to newtypes on primitives. 
Remove it from this struct declaration.", slice.ident)) + } + + g.recordUsedImport("marshal") + g.recordUsedImport("usermem") + + g.emit("// Copy%sIn copies in a slice of %s objects from the task's memory.\n", slice.ident, g.typeName()) + g.emit("func Copy%sIn(task marshal.Task, addr usermem.Addr, dst []%s) (int, error) {\n", slice.ident, g.typeName()) + g.inIndent(func() { + g.emit("count := len(dst)\n") + g.emit("if count == 0 {\n") + g.inIndent(func() { + g.emit("return 0, nil\n") + }) + g.emit("}\n") + g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName()) + + fallback := func() { + g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName()) + g.emit("buf := task.CopyScratchBuffer(size * count)\n") + g.emit("length, err := task.CopyInBytes(addr, buf)\n\n") + + g.emit("// Unmarshal as much as possible, even on error. First handle full objects.\n") + g.emit("limit := length/size\n") + g.emit("for idx := 0; idx < limit; idx++ {\n") + g.inIndent(func() { + g.emit("dst[idx].UnmarshalBytes(buf[size*idx:size*(idx+1)])\n") + }) + g.emit("}\n\n") + + g.emit("// Handle any final partial object.\n") + g.emit("if length < size*count && length%size != 0 {\n") + g.inIndent(func() { + g.emit("idx := limit\n") + g.emit("dst[idx].UnmarshalBytes(buf[size*idx:size*(idx+1)])\n") + }) + g.emit("}\n\n") + + g.emit("return length, err\n") + } + if thisPacked { + g.recordUsedImport("reflect") + g.recordUsedImport("runtime") + g.recordUsedImport("unsafe") + if _, ok := g.areFieldsPackedExpression(); ok { + g.emit("if !dst[0].Packed() {\n") + g.inIndent(fallback) + g.emit("}\n\n") + } + // Fast deserialization. 
+ g.emitCastSliceToByteSlice("&dst", "buf", "size * count") + + g.emit("length, err := task.CopyInBytes(addr, buf)\n") + g.emitKeepAlive("dst") + g.emit("return length, err\n") + } else { + fallback() + } + }) + g.emit("}\n\n") + + g.emit("// Copy%sOut copies a slice of %s objects to the task's memory.\n", slice.ident, g.typeName()) + g.emit("func Copy%sOut(task marshal.Task, addr usermem.Addr, src []%s) (int, error) {\n", slice.ident, g.typeName()) + g.inIndent(func() { + g.emit("count := len(src)\n") + g.emit("if count == 0 {\n") + g.inIndent(func() { + g.emit("return 0, nil\n") + }) + g.emit("}\n") + g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName()) + + fallback := func() { + g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName()) + g.emit("buf := task.CopyScratchBuffer(size * count)\n") + g.emit("for idx := 0; idx < count; idx++ {\n") + g.inIndent(func() { + g.emit("src[idx].MarshalBytes(buf[size*idx:size*(idx+1)])\n") + }) + g.emit("}\n") + g.emit("return task.CopyOutBytes(addr, buf)\n") + } + if thisPacked { + g.recordUsedImport("reflect") + g.recordUsedImport("runtime") + g.recordUsedImport("unsafe") + if _, ok := g.areFieldsPackedExpression(); ok { + g.emit("if !src[0].Packed() {\n") + g.inIndent(fallback) + g.emit("}\n\n") + } + // Fast serialization. 
+ g.emitCastSliceToByteSlice("&src", "buf", "size * count") + + g.emit("length, err := task.CopyOutBytes(addr, buf)\n") + g.emitKeepAlive("src") + g.emit("return length, err\n") + } else { + fallback() + } + }) + g.emit("}\n\n") + + g.emit("// MarshalUnsafe%s is like %s.MarshalUnsafe, but for a []%s.\n", slice.ident, g.typeName(), g.typeName()) + g.emit("func MarshalUnsafe%s(src []%s, dst []byte) (int, error) {\n", slice.ident, g.typeName()) + g.inIndent(func() { + g.emit("count := len(src)\n") + g.emit("if count == 0 {\n") + g.inIndent(func() { + g.emit("return 0, nil\n") + }) + g.emit("}\n") + g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName()) + + fallback := func() { + g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName()) + g.emit("for idx := 0; idx < count; idx++ {\n") + g.inIndent(func() { + g.emit("src[idx].MarshalBytes(dst[size*idx:(size)*(idx+1)])\n") + }) + g.emit("}\n") + g.emit("return size * count, nil\n") + } + if thisPacked { + g.recordUsedImport("reflect") + g.recordUsedImport("runtime") + g.recordUsedImport("unsafe") + if _, ok := g.areFieldsPackedExpression(); ok { + g.emit("if !src[0].Packed() {\n") + g.inIndent(fallback) + g.emit("}\n\n") + } + g.emitNoEscapeSliceDataPointer("&src", "val") + + g.emit("length, err := safecopy.CopyIn(dst[:(size*count)], val)\n") + g.emitKeepAlive("src") + g.emit("return length, err\n") + } else { + fallback() + } + }) + g.emit("}\n\n") + + g.emit("// UnmarshalUnsafe%s is like %s.UnmarshalUnsafe, but for a []%s.\n", slice.ident, g.typeName(), g.typeName()) + g.emit("func UnmarshalUnsafe%s(dst []%s, src []byte) (int, error) {\n", slice.ident, g.typeName()) + g.inIndent(func() { + g.emit("count := len(dst)\n") + g.emit("if count == 0 {\n") + g.inIndent(func() { + g.emit("return 0, nil\n") + }) + g.emit("}\n") + g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName()) + + fallback := func() { + g.emit("// Type %s doesn't have a packed layout in memory, fall 
back to UnmarshalBytes.\n", g.typeName()) + g.emit("for idx := 0; idx < count; idx++ {\n") + g.inIndent(func() { + g.emit("dst[idx].UnmarshalBytes(src[size*idx:size*(idx+1)])\n") + }) + g.emit("}\n") + g.emit("return size * count, nil\n") + } + if thisPacked { + g.recordUsedImport("reflect") + g.recordUsedImport("runtime") + g.recordUsedImport("unsafe") + if _, ok := g.areFieldsPackedExpression(); ok { + g.emit("if !dst[0].Packed() {\n") + g.inIndent(fallback) + g.emit("}\n\n") + } + g.emitNoEscapeSliceDataPointer("&dst", "val") + + g.emit("length, err := safecopy.CopyOut(val, src[:(size*count)])\n") + g.emitKeepAlive("dst") + g.emit("return length, err\n") + } else { + fallback() + } + }) + g.emit("}\n\n") +} diff --git a/tools/go_marshal/gomarshal/generator_tests.go b/tools/go_marshal/gomarshal/generator_tests.go new file mode 100644 index 000000000..631295373 --- /dev/null +++ b/tools/go_marshal/gomarshal/generator_tests.go @@ -0,0 +1,233 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package gomarshal + +import ( + "fmt" + "go/ast" + "io" + "strings" +) + +var standardImports = []string{ + "bytes", + "fmt", + "reflect", + "testing", + + "gvisor.dev/gvisor/tools/go_marshal/analysis", +} + +var sliceAPIImports = []string{ + "encoding/binary", + "gvisor.dev/gvisor/pkg/usermem", +} + +type testGenerator struct { + sourceBuffer + + // The type we're serializing. 
+ t *ast.TypeSpec + + // Receiver argument for generated methods. + r string + + // Imports used by generated code. + imports *importTable + + // Import statement for the package declaring the type we generated code + // for. We need this to construct test instances for the type, since the + // tests aren't written in the same package. + decl *importStmt +} + +func newTestGenerator(t *ast.TypeSpec) *testGenerator { + g := &testGenerator{ + t: t, + r: receiverName(t), + imports: newImportTable(), + } + + for _, i := range standardImports { + g.imports.add(i).markUsed() + } + // These imports are used if a type requests the slice API. Don't + // mark them as used by default. + for _, i := range sliceAPIImports { + g.imports.add(i) + } + + return g +} + +func (g *testGenerator) typeName() string { + return g.t.Name.Name +} + +func (g *testGenerator) testFuncName(base string) string { + return fmt.Sprintf("%s%s", base, strings.Title(g.t.Name.Name)) +} + +func (g *testGenerator) inTestFunction(name string, body func()) { + g.emit("func %s(t *testing.T) {\n", g.testFuncName(name)) + g.inIndent(body) + g.emit("}\n\n") +} + +func (g *testGenerator) emitTestNonZeroSize() { + g.inTestFunction("TestSizeNonZero", func() { + g.emit("var x %v\n", g.typeName()) + g.emit("if x.SizeBytes() == 0 {\n") + g.inIndent(func() { + g.emit("t.Fatal(\"Marshallable.SizeBytes() should not return zero\")\n") + }) + g.emit("}\n") + }) +} + +func (g *testGenerator) emitTestSuspectAlignment() { + g.inTestFunction("TestSuspectAlignment", func() { + g.emit("var x %v\n", g.typeName()) + g.emit("analysis.AlignmentCheck(t, reflect.TypeOf(x))\n") + }) +} + +func (g *testGenerator) emitTestMarshalUnmarshalPreservesData() { + g.inTestFunction("TestSafeMarshalUnmarshalPreservesData", func() { + g.emit("var x, y, z, yUnsafe, zUnsafe %s\n", g.typeName()) + g.emit("analysis.RandomizeValue(&x)\n\n") + + g.emit("buf := make([]byte, x.SizeBytes())\n") + g.emit("x.MarshalBytes(buf)\n") + g.emit("bufUnsafe := 
make([]byte, x.SizeBytes())\n") + g.emit("x.MarshalUnsafe(bufUnsafe)\n\n") + + g.emit("y.UnmarshalBytes(buf)\n") + g.emit("if !reflect.DeepEqual(x, y) {\n") + g.inIndent(func() { + g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalBytes/UnmarshalBytes cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, y))\n") + }) + g.emit("}\n") + g.emit("yUnsafe.UnmarshalBytes(bufUnsafe)\n") + g.emit("if !reflect.DeepEqual(x, yUnsafe) {\n") + g.inIndent(func() { + g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalUnsafe/UnmarshalBytes cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, yUnsafe))\n") + }) + g.emit("}\n\n") + + g.emit("z.UnmarshalUnsafe(buf)\n") + g.emit("if !reflect.DeepEqual(x, z) {\n") + g.inIndent(func() { + g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalBytes/UnmarshalUnsafe cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, z))\n") + }) + g.emit("}\n") + g.emit("zUnsafe.UnmarshalUnsafe(bufUnsafe)\n") + g.emit("if !reflect.DeepEqual(x, zUnsafe) {\n") + g.inIndent(func() { + g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalUnsafe/UnmarshalUnsafe cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, zUnsafe))\n") + }) + g.emit("}\n") + }) +} + +func (g *testGenerator) emitTestMarshalUnmarshalSlicePreservesData(slice *sliceAPI) { + for _, name := range []string{"binary", "usermem"} { + if !g.imports.markUsed(name) { + panic(fmt.Sprintf("Generated test for '%s' referenced a non-existent import with local name '%s'", g.typeName(), name)) + } + } + + g.inTestFunction("TestSafeMarshalUnmarshalSlicePreservesData", func() { + g.emit("var x, y, yUnsafe [8]%s\n", g.typeName()) + g.emit("analysis.RandomizeValue(&x)\n\n") + g.emit("size := (*%s)(nil).SizeBytes() * len(x)\n", g.typeName()) + g.emit("buf := bytes.NewBuffer(make([]byte, size))\n") + g.emit("buf.Reset()\n") + g.emit("if err := binary.Write(buf, usermem.ByteOrder, x[:]); err != nil {\n") + g.inIndent(func() { + g.emit("t.Fatal(fmt.Sprintf(\"binary.Write failed: %v\", err))\n") + }) + g.emit("}\n") + 
g.emit("bufUnsafe := make([]byte, size)\n") + g.emit("MarshalUnsafe%s(x[:], bufUnsafe)\n\n", slice.ident) + + g.emit("UnmarshalUnsafe%s(y[:], buf.Bytes())\n", slice.ident) + g.emit("if !reflect.DeepEqual(x, y) {\n") + g.inIndent(func() { + g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across binary.Write/UnmarshalUnsafeSlice cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, y))\n") + }) + g.emit("}\n") + g.emit("UnmarshalUnsafe%s(yUnsafe[:], bufUnsafe)\n", slice.ident) + g.emit("if !reflect.DeepEqual(x, yUnsafe) {\n") + g.inIndent(func() { + g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalUnsafeSlice/UnmarshalUnsafeSlice cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, yUnsafe))\n") + }) + g.emit("}\n\n") + }) +} + +func (g *testGenerator) emitTestWriteToUnmarshalPreservesData() { + g.inTestFunction("TestWriteToUnmarshalPreservesData", func() { + g.emit("var x, y, yUnsafe %s\n", g.typeName()) + g.emit("analysis.RandomizeValue(&x)\n\n") + + g.emit("var buf bytes.Buffer\n\n") + + g.emit("x.WriteTo(&buf)\n") + g.emit("y.UnmarshalBytes(buf.Bytes())\n\n") + g.emit("yUnsafe.UnmarshalUnsafe(buf.Bytes())\n\n") + + g.emit("if !reflect.DeepEqual(x, y) {\n") + g.inIndent(func() { + g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across WriteTo/UnmarshalBytes cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, y))\n") + }) + g.emit("}\n") + g.emit("if !reflect.DeepEqual(x, yUnsafe) {\n") + g.inIndent(func() { + g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across WriteTo/UnmarshalUnsafe cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, yUnsafe))\n") + }) + g.emit("}\n") + }) +} + +func (g *testGenerator) emitTestSizeBytesOnTypedNilPtr() { + g.inTestFunction("TestSizeBytesOnTypedNilPtr", func() { + g.emit("var x %s\n", g.typeName()) + g.emit("sizeFromConcrete := x.SizeBytes()\n") + g.emit("sizeFromTypedNilPtr := (*%s)(nil).SizeBytes()\n\n", g.typeName()) + + g.emit("if sizeFromTypedNilPtr != sizeFromConcrete {\n") + g.inIndent(func() { + g.emit("t.Fatalf(\"SizeBytes() on typed nil pointer (%v) 
doesn't match size returned by a concrete object (%v).\\n\", sizeFromTypedNilPtr, sizeFromConcrete)\n") + }) + g.emit("}\n") + }) +} + +func (g *testGenerator) emitTests(slice *sliceAPI) { + g.emitTestNonZeroSize() + g.emitTestSuspectAlignment() + g.emitTestMarshalUnmarshalPreservesData() + g.emitTestWriteToUnmarshalPreservesData() + g.emitTestSizeBytesOnTypedNilPtr() + + if slice != nil { + g.emitTestMarshalUnmarshalSlicePreservesData(slice) + } +} + +func (g *testGenerator) write(out io.Writer) error { + return g.sourceBuffer.write(out) +} diff --git a/tools/go_marshal/gomarshal/util.go b/tools/go_marshal/gomarshal/util.go new file mode 100644 index 000000000..d94314302 --- /dev/null +++ b/tools/go_marshal/gomarshal/util.go @@ -0,0 +1,491 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package gomarshal + +import ( + "bytes" + "flag" + "fmt" + "go/ast" + "go/token" + "io" + "os" + "path" + "reflect" + "sort" + "strings" +) + +var debug = flag.Bool("debug", false, "enables debugging output") + +// receiverName returns an appropriate receiver name given a type spec. +func receiverName(t *ast.TypeSpec) string { + if len(t.Name.Name) < 1 { + // Zero length type name? + panic("unreachable") + } + return strings.ToLower(t.Name.Name[:1]) +} + +// kindString returns a user-friendly representation of an AST expr type. 
+func kindString(e ast.Expr) string { + switch e.(type) { + case *ast.Ident: + return "scalar" + case *ast.ArrayType: + return "array" + case *ast.StructType: + return "struct" + case *ast.StarExpr: + return "pointer" + case *ast.FuncType: + return "function" + case *ast.InterfaceType: + return "interface" + case *ast.MapType: + return "map" + case *ast.ChanType: + return "channel" + default: + return reflect.TypeOf(e).String() + } +} + +func forEachStructField(st *ast.StructType, fn func(f *ast.Field)) { + for _, field := range st.Fields.List { + fn(field) + } +} + +// fieldDispatcher is a collection of callbacks for handling different types of +// fields in a struct declaration. +type fieldDispatcher struct { + primitive func(n, t *ast.Ident) + selector func(n, tX, tSel *ast.Ident) + array func(n *ast.Ident, a *ast.ArrayType, t *ast.Ident) + unhandled func(n *ast.Ident) +} + +// Precondition: All dispatch callbacks that will be invoked must be +// provided. Embedded fields are not allowed, len(f.Names) >= 1. +func (fd fieldDispatcher) dispatch(f *ast.Field) { + // Each field declaration may actually be multiple declarations of the same + // type. For example, consider: + // + // type Point struct { + // x, y, z int + // } + // + // We invoke the call-backs once per such instance. Embedded fields are not + // allowed, and result in a panic. + if len(f.Names) < 1 { + panic("Precondition not met: attempted to dispatch on embedded field") + } + + for _, name := range f.Names { + switch v := f.Type.(type) { + case *ast.Ident: + fd.primitive(name, v) + case *ast.SelectorExpr: + fd.selector(name, v.X.(*ast.Ident), v.Sel) + case *ast.ArrayType: + switch t := v.Elt.(type) { + case *ast.Ident: + fd.array(name, v, t) + default: + // Should be handled with a better error message during validate. + panic(fmt.Sprintf("Array element type is of unsupported kind. 
Expected *ast.Ident, got %v", t)) + } + default: + fd.unhandled(name) + } + } +} + +// debugEnabled indicates whether debugging is enabled for gomarshal. +func debugEnabled() bool { + return *debug +} + +// abort aborts the go_marshal tool with the given error message. +func abort(msg string) { + if !strings.HasSuffix(msg, "\n") { + msg += "\n" + } + fmt.Print(msg) + os.Exit(1) +} + +// abortAt aborts the go_marshal tool with the given error message, with +// a reference position to the input source. +func abortAt(p token.Position, msg string) { + abort(fmt.Sprintf("%v:\n %s\n", p, msg)) +} + +// debugf conditionally prints a debug message. +func debugf(f string, a ...interface{}) { + if debugEnabled() { + fmt.Printf(f, a...) + } +} + +// debugfAt conditionally prints a debug message with a reference to a position +// in the input source. +func debugfAt(p token.Position, f string, a ...interface{}) { + if debugEnabled() { + fmt.Printf("%s:\n %s", p, fmt.Sprintf(f, a...)) + } +} + +// emit generates a line of code in the output file. +// +// emit is a wrapper around writing a formatted string to the output +// buffer. emit can be invoked in one of two ways: +// +// (1) emit("some string") +// When emit is called with a single string argument, it is simply copied to +// the output buffer without any further formatting. +// (2) emit(fmtString, args...) +// emit can also be invoked in a similar fashion to *Printf() functions, +// where the first argument is a format string. +// +// Calling emit with a single argument that is not a string will result in a +// panic, as the caller's intent is ambiguous. +func emit(out io.Writer, indent int, a ...interface{}) { + const spacesPerIndentLevel = 4 + + if len(a) < 1 { + panic("emit() called with no arguments") + } + + if indent > 0 { + if _, err := fmt.Fprint(out, strings.Repeat(" ", indent*spacesPerIndentLevel)); err != nil { + // Writing to the emit output should not fail. 
Typically the output + // is a bytes.Buffer; writes to these never fail. + panic(err) + } + } + + first, ok := a[0].(string) + if !ok { + // First argument must be either the string to emit (case 1 from + // function-level comment), or a format string (case 2). + panic(fmt.Sprintf("First argument to emit() is not a string: %+v", a[0])) + } + + if len(a) == 1 { + // Single string argument. Assume no formatting requested. + if _, err := fmt.Fprint(out, first); err != nil { + // Writing to out should not fail. + panic(err) + } + return + + } + + // Formatting requested. + if _, err := fmt.Fprintf(out, first, a[1:]...); err != nil { + // Writing to out should not fail. + panic(err) + } +} + +// sourceBuffer represents fragments of generated go source code. +// +// sourceBuffer provides a convenient way to build up go source fragments in +// memory. May be safely zero-value initialized. Not thread-safe. +type sourceBuffer struct { + // Current indentation level. + indent int + + // Memory buffer containing contents while they're being generated. + b bytes.Buffer +} + +func (b *sourceBuffer) reset() { + b.indent = 0 + b.b.Reset() +} + +func (b *sourceBuffer) incIndent() { + b.indent++ +} + +func (b *sourceBuffer) decIndent() { + if b.indent <= 0 { + panic("decIndent() without matching incIndent()") + } + b.indent-- +} + +func (b *sourceBuffer) emit(a ...interface{}) { + emit(&b.b, b.indent, a...) +} + +func (b *sourceBuffer) emitNoIndent(a ...interface{}) { + emit(&b.b, 0 /*indent*/, a...) +} + +func (b *sourceBuffer) inIndent(body func()) { + b.incIndent() + body() + b.decIndent() +} + +func (b *sourceBuffer) write(out io.Writer) error { + _, err := fmt.Fprint(out, b.b.String()) + return err +} + +// Write implements io.Writer.Write. +func (b *sourceBuffer) Write(buf []byte) (int, error) { + return (b.b.Write(buf)) +} + +// importStmt represents a single import statement. +type importStmt struct { + // Local name of the imported package. + name string + // Import path. 
+ path string + // Indicates whether the local name is an alias, or simply the final + // component of the path. + aliased bool + // Indicates whether this import was referenced by generated code. + used bool + // AST node and file set representing the import statement, if any. These + // are only non-nil if the import statement originates from an input source + // file. + spec *ast.ImportSpec + fset *token.FileSet +} + +func newImport(p string) *importStmt { + name := path.Base(p) + return &importStmt{ + name: name, + path: p, + aliased: false, + } +} + +func newImportFromSpec(spec *ast.ImportSpec, f *token.FileSet) *importStmt { + p := spec.Path.Value[1 : len(spec.Path.Value)-1] // Strip the " quotes around path. + name := path.Base(p) + if name == "" || name == "/" || name == "." { + panic(fmt.Sprintf("Couldn't process local package name for import at %s, (processed as %s)", + f.Position(spec.Path.Pos()), name)) + } + if spec.Name != nil { + name = spec.Name.Name + } + return &importStmt{ + name: name, + path: p, + aliased: spec.Name != nil, + spec: spec, + fset: f, + } +} + +// String implements fmt.Stringer.String. This generates a string for the import +// statement appropriate for writing directly to generated code. +func (i *importStmt) String() string { + if i.aliased { + return fmt.Sprintf("%s %q", i.name, i.path) + } + return fmt.Sprintf("%q", i.path) +} + +// debugString returns a debug string representing an import statement. This +// representation is not valid golang code and is used for debugging output. 
+func (i *importStmt) debugString() string { + if i.spec != nil && i.fset != nil { + return fmt.Sprintf("%s: %s", i.fset.Position(i.spec.Path.Pos()), i) + } + return fmt.Sprintf("(go-marshal import): %s", i) +} + +func (i *importStmt) markUsed() { + i.used = true +} + +func (i *importStmt) equivalent(other *importStmt) bool { + return i.name == other.name && i.path == other.path && i.aliased == other.aliased +} + +// importTable represents a collection of importStmts. +// +// An importTable may contain multiple import statements referencing the same +// local name. All import statements aliasing to the same local name are +// technically ambiguous, as if such an import name is used in the generated +// code, it's not clear which import statement it refers to. We ignore any +// potential collisions until actually writing the import table to the generated +// source file. See importTable.write. +// +// Given the following import statements across all the files comprising a +// package marshalled: +// +// "sync" +// "pkg/sync" +// "pkg/sentry/kernel" +// ktime "pkg/sentry/kernel/time" +// +// An importTable representing them would look like this: +// +// importTable { +// is: map[string][]*importStmt { +// "sync": []*importStmt{ +// importStmt{name:"sync", path:"sync", aliased:false} +// importStmt{name:"sync", path:"pkg/sync", aliased:false} +// }, +// "kernel": []*importStmt{importStmt{ +// name: "kernel", +// path: "pkg/sentry/kernel", +// aliased: false +// }}, +// "ktime": []*importStmt{importStmt{ +// name: "ktime", +// path: "pkg/sentry/kernel/time", +// aliased: true, +// }}, +// } +// } +// +// Note that the local name "sync" is assigned to two different import +// statements. This is possible if the import statements are from different +// source files in the same package. 
+// +// Since go-marshal generates a single output file per package regardless of the +// number of input files, if "sync" is referenced by any generated code, it's +// unclear which import statement "sync" refers to. While it's theoretically +// possible to resolve this by assigning a unique local alias to each instance +// of the sync package, go-marshal currently aborts when it encounters such an +// ambiguity. +// +// TODO(b/151478251): importTable considers the final component of an import +// path to be the package name, but this is only a convention. The actual +// package name is determined by the package statement in the source files for +// the package. +type importTable struct { + // Map of imports and whether they should be copied to the output. + is map[string][]*importStmt +} + +func newImportTable() *importTable { + return &importTable{ + is: make(map[string][]*importStmt), + } +} + +// Merges import statements from other into i. +func (i *importTable) merge(other *importTable) { + for name, ims := range other.is { + i.is[name] = append(i.is[name], ims...) + } +} + +func (i *importTable) addStmt(s *importStmt) *importStmt { + i.is[s.name] = append(i.is[s.name], s) + return s +} + +func (i *importTable) add(s string) *importStmt { + n := newImport(s) + return i.addStmt(n) +} + +func (i *importTable) addFromSpec(spec *ast.ImportSpec, f *token.FileSet) *importStmt { + return i.addStmt(newImportFromSpec(spec, f)) +} + +// Marks the import named n as used. If no such import is in the table, returns +// false. +func (i *importTable) markUsed(n string) bool { + if ns, ok := i.is[n]; ok { + for _, n := range ns { + n.markUsed() + } + return true + } + return false +} + +func (i *importTable) clear() { + for _, is := range i.is { + for _, i := range is { + i.used = false + } + } +} + +func (i *importTable) write(out io.Writer) error { + if len(i.is) == 0 { + // Nothing to import, we're done. 
+ return nil + } + + imports := make([]string, 0, len(i.is)) + for name, is := range i.is { + var lastUsed *importStmt + var ambiguous bool + + for _, i := range is { + if i.used { + if lastUsed != nil { + if !i.equivalent(lastUsed) { + ambiguous = true + } + } + lastUsed = i + } + } + + if ambiguous { + // We have two or more import statements across the different source + // files that share a local name, and at least one of these imports + // are used by the generated code. This ambiguity can't be resolved + // by go-marshal and requires the user intervention. Dump a list of + // the colliding import statements and let the user modify the input + // files as appropriate. + var b strings.Builder + fmt.Fprintf(&b, "The imported name %q is used by one of the types marked for marshalling, and which import statement the code refers to is ambiguous. Perhaps give the imports unique local names?\n\n", name) + fmt.Fprintf(&b, "The following %d import statements are ambiguous for the local name %q:\n", len(is), name) + // Note: len(is) is guaranteed to be 1 or greater or ambiguous can't + // be true. Therefore the slicing below is safe. + for _, i := range is[:len(is)-1] { + fmt.Fprintf(&b, " %v\n", i.debugString()) + } + fmt.Fprintf(&b, " %v", is[len(is)-1].debugString()) + panic(b.String()) + } + + if lastUsed != nil { + imports = append(imports, lastUsed.String()) + } + } + sort.Strings(imports) + + var b sourceBuffer + b.emit("import (\n") + b.incIndent() + for _, i := range imports { + b.emit("%s\n", i) + } + b.decIndent() + b.emit(")\n\n") + + return b.write(out) +} diff --git a/tools/go_marshal/main.go b/tools/go_marshal/main.go new file mode 100644 index 000000000..f74be5c29 --- /dev/null +++ b/tools/go_marshal/main.go @@ -0,0 +1,72 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// go_marshal is a code generation utility for automatically generating code to +// marshal go data structures to memory. +// +// This binary is typically run as part of the build process, and is invoked by +// the go_marshal bazel rule defined in defs.bzl. +// +// See README.md. +package main + +import ( + "flag" + "fmt" + "os" + "strings" + + "gvisor.dev/gvisor/tools/go_marshal/gomarshal" +) + +var ( + pkg = flag.String("pkg", "", "output package") + output = flag.String("output", "", "output file") + outputTest = flag.String("output_test", "", "output file for tests") + imports = flag.String("imports", "", "comma-separated list of extra packages to import in generated code") +) + +func main() { + flag.Usage = func() { + fmt.Fprintf(os.Stderr, "Usage: %s <input go src files>\n", os.Args[0]) + flag.PrintDefaults() + } + flag.Parse() + if len(flag.Args()) == 0 { + flag.Usage() + os.Exit(1) + } + + if *pkg == "" { + flag.Usage() + fmt.Fprint(os.Stderr, "Flag -pkg must be provided.\n") + os.Exit(1) + } + + var extraImports []string + if len(*imports) > 0 { + // Note: strings.Split(s, sep) returns s if sep doesn't exist in s. Thus + // we check for an empty imports list to avoid emitting an empty string + // as an import. 
+ extraImports = strings.Split(*imports, ",") + } + g, err := gomarshal.NewGenerator(flag.Args(), *output, *outputTest, *pkg, extraImports) + if err != nil { + panic(err) + } + + if err := g.Run(); err != nil { + panic(err) + } +} diff --git a/tools/go_marshal/marshal/BUILD b/tools/go_marshal/marshal/BUILD new file mode 100644 index 000000000..bacfaa5a4 --- /dev/null +++ b/tools/go_marshal/marshal/BUILD @@ -0,0 +1,16 @@ +load("//tools:defs.bzl", "go_library") + +licenses(["notice"]) + +go_library( + name = "marshal", + srcs = [ + "marshal.go", + ], + visibility = [ + "//:sandbox", + ], + deps = [ + "//pkg/usermem", + ], +) diff --git a/tools/go_marshal/marshal/marshal.go b/tools/go_marshal/marshal/marshal.go new file mode 100644 index 000000000..cb2166252 --- /dev/null +++ b/tools/go_marshal/marshal/marshal.go @@ -0,0 +1,187 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package marshal defines the Marshallable interface for +// serialize/deserializing go data structures to/from memory, according to the +// Linux ABI. +// +// Implementations of this interface are typically automatically generated by +// tools/go_marshal. See the go_marshal README for details. +package marshal + +import ( + "io" + + "gvisor.dev/gvisor/pkg/usermem" +) + +// Task provides a subset of kernel.Task, used in marshalling. We don't import +// the kernel package directly to avoid circular dependency. 
+type Task interface { + // CopyScratchBuffer provides a task goroutine-local scratch buffer. See + // kernel.CopyScratchBuffer. + CopyScratchBuffer(size int) []byte + + // CopyOutBytes writes the contents of b to the task's memory. See + // kernel.CopyOutBytes. + CopyOutBytes(addr usermem.Addr, b []byte) (int, error) + + // CopyInBytes reads the contents of the task's memory to b. See + // kernel.CopyInBytes. + CopyInBytes(addr usermem.Addr, b []byte) (int, error) +} + +// Marshallable represents operations on a type that can be marshalled to and +// from memory. +// +// go-marshal automatically generates implementations for this interface for +// types marked as '+marshal'. +type Marshallable interface { + io.WriterTo + + // SizeBytes is the size of the memory representation of a type in + // marshalled form. + // + // SizeBytes must handle a nil receiver. Practically, this means SizeBytes + // cannot dereference any fields on the object implementing it (but will + // likely make use of the type of these fields). + SizeBytes() int + + // MarshalBytes serializes a copy of a type to dst. dst may be smaller than + // SizeBytes(), which results in a part of the struct being marshalled. Note + // that this may have unexpected results for non-packed types, as implicit + // padding needs to be taken into account when reasoning about how much of + // the type is serialized. + MarshalBytes(dst []byte) + + // UnmarshalBytes deserializes a type from src. src may be smaller than + // SizeBytes(), which results in a partially deserialized struct. Note that + // this may have unexpected results for non-packed types, as implicit + // padding needs to be taken into account when reasoning about how much of + // the type is deserialized. + UnmarshalBytes(src []byte) + + // Packed returns true if the marshalled size of the type is the same as the + // size it occupies in memory.
This happens when the type has no fields + // starting at unaligned addresses (should always be true by default for ABI + // structs, verified by automatically generated tests when using + // go_marshal), and has no fields marked `marshal:"unaligned"`. + // + // Packed must return the same result for all possible values of the type + // implementing it. Violating this constraint implies the type doesn't have + // a static memory layout, and will lead to memory corruption. + // Go-marshal-generated code reuses the result of Packed for multiple values + // of the same type. + Packed() bool + + // MarshalUnsafe serializes a type by bulk copying its in-memory + // representation to the dst buffer. This is only safe to do when the type + // has no implicit padding, see Marshallable.Packed. When Packed would + // return false, MarshalUnsafe should fall back to the safer but slower + // MarshalBytes. dst may be smaller than SizeBytes(), see comment for + // MarshalBytes for implications. + MarshalUnsafe(dst []byte) + + // UnmarshalUnsafe deserializes a type by directly copying to the underlying + // memory allocated for the object by the runtime. + // + // This allows much faster unmarshalling of types which have no implicit + // padding, see Marshallable.Packed. When Packed would return false, + // UnmarshalUnsafe should fall back to the safer but slower unmarshal + // mechanism implemented in UnmarshalBytes. src may be smaller than + // SizeBytes(), see comment for UnmarshalBytes for implications. + UnmarshalUnsafe(src []byte) + + // CopyIn deserializes a Marshallable type from a task's memory. This may + // only be called from a task goroutine. This is more efficient than calling + // UnmarshalUnsafe on Marshallable.Packed types, as the type being + // marshalled does not escape. The implementation should avoid creating + // extra copies in memory by directly deserializing to the object's + // underlying memory. 
+ // + // If the copy-in from the task memory is only partially successful, CopyIn + // should still attempt to deserialize as much data as possible. See comment + // for UnmarshalBytes. + CopyIn(task Task, addr usermem.Addr) (int, error) + + // CopyOut serializes a Marshallable type to a task's memory. This may only + // be called from a task goroutine. This is more efficient than calling + // MarshalUnsafe on Marshallable.Packed types, as the type being serialized + // does not escape. The implementation should avoid creating extra copies in + // memory by directly serializing from the object's underlying memory. + // + // The copy-out to the task memory may be partially successful, in which + // case CopyOut returns how much data was serialized. See comment for + // MarshalBytes for implications. + CopyOut(task Task, addr usermem.Addr) (int, error) + + // CopyOutN is like CopyOut, but explicitly requests a partial + // copy-out. Note that this may yield unexpected results for non-packed + // types and the caller may only want to allow this for packed types. See + // comment on MarshalBytes. + // + // The limit must be less than or equal to SizeBytes(). + CopyOutN(task Task, addr usermem.Addr, limit int) (int, error) +} + +// go-marshal generates additional functions for a type based on additional +// clauses to the +marshal directive. They are documented below. +// +// Slice API +// ========= +// +// Adding a "slice" clause to the +marshal directive for structs or newtypes on +// primitives like this: +// +// // +marshal slice:FooSlice +// type Foo struct { ... } +// +// Generates four additional functions for marshalling slices of Foos like this: +// +// // MarshalUnsafeFooSlice is like Foo.MarshalUnsafe, but for a []Foo. It's +// // more efficient than repeatedly calling Foo.MarshalUnsafe over a +// // []Foo in a loop. +// func MarshalUnsafeFooSlice(src []Foo, dst []byte) (int, error) { ...
} +// +// // UnmarshalUnsafeFooSlice is like Foo.UnmarshalUnsafe, but for a []Foo. It's +// // more efficient than repeatedly calling Foo.UnmarshalUnsafe over a +// // []Foo in a loop. +// func UnmarshalUnsafeFooSlice(dst []Foo, src []byte) (int, error) { ... } +// +// // CopyFooSliceIn copies in a slice of Foo objects from the task's memory. +// func CopyFooSliceIn(task marshal.Task, addr usermem.Addr, dst []Foo) (int, error) { ... } +// +// // CopyFooSliceOut copies out a slice of Foo objects to the task's memory. +// func CopyFooSliceOut(task marshal.Task, addr usermem.Addr, src []Foo) (int, error) { ... } +// +// The names of the functions are of the format "Copy%sIn" and "Copy%sOut", where +// %s is the first argument to the slice clause. This directive is not supported +// for newtypes on arrays. +// +// The slice clause also takes an optional second argument, which must be the +// value "inner": +// +// // +marshal slice:Int32Slice:inner +// type Int32 int32 +// +// This is only valid on newtypes on primitives, and causes the generated +// functions to accept slices of the inner type instead: +// +// func CopyInt32SliceIn(task marshal.Task, addr usermem.Addr, dst []int32) (int, error) { ... } +// +// Without "inner", they would instead be: +// +// func CopyInt32SliceIn(task marshal.Task, addr usermem.Addr, dst []Int32) (int, error) { ... } +// +// This may help avoid a cast depending on how the generated functions are used.
diff --git a/tools/go_marshal/primitive/BUILD b/tools/go_marshal/primitive/BUILD new file mode 100644 index 000000000..cc08ba63a --- /dev/null +++ b/tools/go_marshal/primitive/BUILD @@ -0,0 +1,18 @@ +load("//tools:defs.bzl", "go_library") + +licenses(["notice"]) + +go_library( + name = "primitive", + srcs = [ + "primitive.go", + ], + marshal = True, + visibility = [ + "//:sandbox", + ], + deps = [ + "//pkg/usermem", + "//tools/go_marshal/marshal", + ], +) diff --git a/tools/go_marshal/primitive/primitive.go b/tools/go_marshal/primitive/primitive.go new file mode 100644 index 000000000..ebcf130ae --- /dev/null +++ b/tools/go_marshal/primitive/primitive.go @@ -0,0 +1,175 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package primitive defines marshal.Marshallable implementations for primitive +// types. +package primitive + +import ( + "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/tools/go_marshal/marshal" +) + +// Int16 is a marshal.Marshallable implementation for int16. +// +// +marshal slice:Int16Slice:inner +type Int16 int16 + +// Uint16 is a marshal.Marshallable implementation for uint16. +// +// +marshal slice:Uint16Slice:inner +type Uint16 uint16 + +// Int32 is a marshal.Marshallable implementation for int32. +// +// +marshal slice:Int32Slice:inner +type Int32 int32 + +// Uint32 is a marshal.Marshallable implementation for uint32. 
+// +// +marshal slice:Uint32Slice:inner +type Uint32 uint32 + +// Int64 is a marshal.Marshallable implementation for int64. +// +// +marshal slice:Int64Slice:inner +type Int64 int64 + +// Uint64 is a marshal.Marshallable implementation for uint64. +// +// +marshal slice:Uint64Slice:inner +type Uint64 uint64 + +// Below, we define some convenience functions for marshalling primitive types +// using the newtypes above, without requiring superfluous casts. + +// 16-bit integers + +// CopyInt16In is a convenient wrapper for copying in an int16 from the task's +// memory. +func CopyInt16In(task marshal.Task, addr usermem.Addr, dst *int16) (int, error) { + var buf Int16 + n, err := buf.CopyIn(task, addr) + if err != nil { + return n, err + } + *dst = int16(buf) + return n, nil +} + +// CopyInt16Out is a convenient wrapper for copying out an int16 to the task's +// memory. +func CopyInt16Out(task marshal.Task, addr usermem.Addr, src int16) (int, error) { + srcP := Int16(src) + return srcP.CopyOut(task, addr) +} + +// CopyUint16In is a convenient wrapper for copying in a uint16 from the task's +// memory. +func CopyUint16In(task marshal.Task, addr usermem.Addr, dst *uint16) (int, error) { + var buf Uint16 + n, err := buf.CopyIn(task, addr) + if err != nil { + return n, err + } + *dst = uint16(buf) + return n, nil +} + +// CopyUint16Out is a convenient wrapper for copying out a uint16 to the task's +// memory. +func CopyUint16Out(task marshal.Task, addr usermem.Addr, src uint16) (int, error) { + srcP := Uint16(src) + return srcP.CopyOut(task, addr) +} + +// 32-bit integers + +// CopyInt32In is a convenient wrapper for copying in an int32 from the task's +// memory. +func CopyInt32In(task marshal.Task, addr usermem.Addr, dst *int32) (int, error) { + var buf Int32 + n, err := buf.CopyIn(task, addr) + if err != nil { + return n, err + } + *dst = int32(buf) + return n, nil +} + +// CopyInt32Out is a convenient wrapper for copying out an int32 to the task's +// memory. 
+func CopyInt32Out(task marshal.Task, addr usermem.Addr, src int32) (int, error) { + srcP := Int32(src) + return srcP.CopyOut(task, addr) +} + +// CopyUint32In is a convenient wrapper for copying in a uint32 from the task's +// memory. +func CopyUint32In(task marshal.Task, addr usermem.Addr, dst *uint32) (int, error) { + var buf Uint32 + n, err := buf.CopyIn(task, addr) + if err != nil { + return n, err + } + *dst = uint32(buf) + return n, nil +} + +// CopyUint32Out is a convenient wrapper for copying out a uint32 to the task's +// memory. +func CopyUint32Out(task marshal.Task, addr usermem.Addr, src uint32) (int, error) { + srcP := Uint32(src) + return srcP.CopyOut(task, addr) +} + +// 64-bit integers + +// CopyInt64In is a convenient wrapper for copying in an int64 from the task's +// memory. +func CopyInt64In(task marshal.Task, addr usermem.Addr, dst *int64) (int, error) { + var buf Int64 + n, err := buf.CopyIn(task, addr) + if err != nil { + return n, err + } + *dst = int64(buf) + return n, nil +} + +// CopyInt64Out is a convenient wrapper for copying out an int64 to the task's +// memory. +func CopyInt64Out(task marshal.Task, addr usermem.Addr, src int64) (int, error) { + srcP := Int64(src) + return srcP.CopyOut(task, addr) +} + +// CopyUint64In is a convenient wrapper for copying in a uint64 from the task's +// memory. +func CopyUint64In(task marshal.Task, addr usermem.Addr, dst *uint64) (int, error) { + var buf Uint64 + n, err := buf.CopyIn(task, addr) + if err != nil { + return n, err + } + *dst = uint64(buf) + return n, nil +} + +// CopyUint64Out is a convenient wrapper for copying out a uint64 to the task's +// memory. 
+func CopyUint64Out(task marshal.Task, addr usermem.Addr, src uint64) (int, error) { + srcP := Uint64(src) + return srcP.CopyOut(task, addr) +} diff --git a/tools/go_marshal/test/BUILD b/tools/go_marshal/test/BUILD new file mode 100644 index 000000000..2fbcc8a03 --- /dev/null +++ b/tools/go_marshal/test/BUILD @@ -0,0 +1,44 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +licenses(["notice"]) + +package_group( + name = "gomarshal_test", + packages = [ + "//tools/go_marshal/test/...", + ], +) + +go_test( + name = "benchmark_test", + srcs = ["benchmark_test.go"], + deps = [ + ":test", + "//pkg/binary", + "//pkg/usermem", + "//tools/go_marshal/analysis", + ], +) + +go_library( + name = "test", + testonly = 1, + srcs = ["test.go"], + marshal = True, + visibility = ["//tools/go_marshal/test:__subpackages__"], + deps = ["//tools/go_marshal/test/external"], +) + +go_test( + name = "marshal_test", + size = "small", + srcs = ["marshal_test.go"], + deps = [ + ":test", + "//pkg/syserror", + "//pkg/usermem", + "//tools/go_marshal/analysis", + "//tools/go_marshal/marshal", + "@com_github_google_go-cmp//cmp:go_default_library", + ], +) diff --git a/tools/go_marshal/test/benchmark_test.go b/tools/go_marshal/test/benchmark_test.go new file mode 100644 index 000000000..224d308c7 --- /dev/null +++ b/tools/go_marshal/test/benchmark_test.go @@ -0,0 +1,220 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package benchmark_test + +import ( + "bytes" + encbin "encoding/binary" + "fmt" + "reflect" + "testing" + + "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/tools/go_marshal/analysis" + "gvisor.dev/gvisor/tools/go_marshal/test" +) + +// Marshalling using the standard encoding/binary package. +func BenchmarkEncodingBinary(b *testing.B) { + var s1, s2 test.Stat + analysis.RandomizeValue(&s1) + + size := encbin.Size(&s1) + + b.ResetTimer() + + for n := 0; n < b.N; n++ { + buf := bytes.NewBuffer(make([]byte, size)) + buf.Reset() + if err := encbin.Write(buf, usermem.ByteOrder, &s1); err != nil { + b.Error("Write:", err) + } + if err := encbin.Read(buf, usermem.ByteOrder, &s2); err != nil { + b.Error("Read:", err) + } + } + + b.StopTimer() + + // Sanity check, make sure the values were preserved. + if !reflect.DeepEqual(s1, s2) { + panic(fmt.Sprintf("Data corruption across marshal/unmarshal cycle:\nBefore: %+v\nAfter: %+v\n", s1, s2)) + } +} + +// Marshalling using the sentry's binary.Marshal. +func BenchmarkBinary(b *testing.B) { + var s1, s2 test.Stat + analysis.RandomizeValue(&s1) + + size := binary.Size(s1) + + b.ResetTimer() + + for n := 0; n < b.N; n++ { + buf := make([]byte, 0, size) + buf = binary.Marshal(buf, usermem.ByteOrder, &s1) + binary.Unmarshal(buf, usermem.ByteOrder, &s2) + } + + b.StopTimer() + + // Sanity check, make sure the values were preserved. + if !reflect.DeepEqual(s1, s2) { + panic(fmt.Sprintf("Data corruption across marshal/unmarshal cycle:\nBefore: %+v\nAfter: %+v\n", s1, s2)) + } +} + +// Marshalling field-by-field with manually-written code. 
+func BenchmarkMarshalManual(b *testing.B) { + var s1, s2 test.Stat + analysis.RandomizeValue(&s1) + + b.ResetTimer() + + for n := 0; n < b.N; n++ { + buf := make([]byte, 0, s1.SizeBytes()) + + // Marshal + buf = binary.AppendUint64(buf, usermem.ByteOrder, s1.Dev) + buf = binary.AppendUint64(buf, usermem.ByteOrder, s1.Ino) + buf = binary.AppendUint64(buf, usermem.ByteOrder, s1.Nlink) + buf = binary.AppendUint32(buf, usermem.ByteOrder, s1.Mode) + buf = binary.AppendUint32(buf, usermem.ByteOrder, s1.UID) + buf = binary.AppendUint32(buf, usermem.ByteOrder, s1.GID) + buf = binary.AppendUint32(buf, usermem.ByteOrder, 0) + buf = binary.AppendUint64(buf, usermem.ByteOrder, s1.Rdev) + buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.Size)) + buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.Blksize)) + buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.Blocks)) + buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.ATime.Sec)) + buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.ATime.Nsec)) + buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.MTime.Sec)) + buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.MTime.Nsec)) + buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.CTime.Sec)) + buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.CTime.Nsec)) + + // Unmarshal + s2.Dev = usermem.ByteOrder.Uint64(buf[0:8]) + s2.Ino = usermem.ByteOrder.Uint64(buf[8:16]) + s2.Nlink = usermem.ByteOrder.Uint64(buf[16:24]) + s2.Mode = usermem.ByteOrder.Uint32(buf[24:28]) + s2.UID = usermem.ByteOrder.Uint32(buf[28:32]) + s2.GID = usermem.ByteOrder.Uint32(buf[32:36]) + // Padding: buf[36:40] + s2.Rdev = usermem.ByteOrder.Uint64(buf[40:48]) + s2.Size = int64(usermem.ByteOrder.Uint64(buf[48:56])) + s2.Blksize = int64(usermem.ByteOrder.Uint64(buf[56:64])) + s2.Blocks = int64(usermem.ByteOrder.Uint64(buf[64:72])) + s2.ATime.Sec = int64(usermem.ByteOrder.Uint64(buf[72:80])) + s2.ATime.Nsec = 
int64(usermem.ByteOrder.Uint64(buf[80:88])) + s2.MTime.Sec = int64(usermem.ByteOrder.Uint64(buf[88:96])) + s2.MTime.Nsec = int64(usermem.ByteOrder.Uint64(buf[96:104])) + s2.CTime.Sec = int64(usermem.ByteOrder.Uint64(buf[104:112])) + s2.CTime.Nsec = int64(usermem.ByteOrder.Uint64(buf[112:120])) + } + + b.StopTimer() + + // Sanity check, make sure the values were preserved. + if !reflect.DeepEqual(s1, s2) { + panic(fmt.Sprintf("Data corruption across marshal/unmarshal cycle:\nBefore: %+v\nAfter: %+v\n", s1, s2)) + } +} + +// Marshalling with the go_marshal safe API. +func BenchmarkGoMarshalSafe(b *testing.B) { + var s1, s2 test.Stat + analysis.RandomizeValue(&s1) + + b.ResetTimer() + + for n := 0; n < b.N; n++ { + buf := make([]byte, s1.SizeBytes()) + s1.MarshalBytes(buf) + s2.UnmarshalBytes(buf) + } + + b.StopTimer() + + // Sanity check, make sure the values were preserved. + if !reflect.DeepEqual(s1, s2) { + panic(fmt.Sprintf("Data corruption across marshal/unmarshal cycle:\nBefore: %+v\nAfter: %+v\n", s1, s2)) + } +} + +// Marshalling with the go_marshal unsafe API. +func BenchmarkGoMarshalUnsafe(b *testing.B) { + var s1, s2 test.Stat + analysis.RandomizeValue(&s1) + + b.ResetTimer() + + for n := 0; n < b.N; n++ { + buf := make([]byte, s1.SizeBytes()) + s1.MarshalUnsafe(buf) + s2.UnmarshalUnsafe(buf) + } + + b.StopTimer() + + // Sanity check, make sure the values were preserved. + if !reflect.DeepEqual(s1, s2) { + panic(fmt.Sprintf("Data corruption across marshal/unmarshal cycle:\nBefore: %+v\nAfter: %+v\n", s1, s2)) + } +} + +func BenchmarkBinarySlice(b *testing.B) { + var s1, s2 [64]test.Stat + analysis.RandomizeValue(&s1) + + size := binary.Size(s1) + + b.ResetTimer() + + for n := 0; n < b.N; n++ { + buf := make([]byte, 0, size) + buf = binary.Marshal(buf, usermem.ByteOrder, &s1) + binary.Unmarshal(buf, usermem.ByteOrder, &s2) + } + + b.StopTimer() + + // Sanity check, make sure the values were preserved. 
+ if !reflect.DeepEqual(s1, s2) { + panic(fmt.Sprintf("Data corruption across marshal/unmarshal cycle:\nBefore: %+v\nAfter: %+v\n", s1, s2)) + } +} + +func BenchmarkGoMarshalUnsafeSlice(b *testing.B) { + var s1, s2 [64]test.Stat + analysis.RandomizeValue(&s1) + + b.ResetTimer() + + for n := 0; n < b.N; n++ { + buf := make([]byte, (*test.Stat)(nil).SizeBytes()*len(s1)) + test.MarshalUnsafeStatSlice(s1[:], buf) + test.UnmarshalUnsafeStatSlice(s2[:], buf) + } + + b.StopTimer() + + // Sanity check, make sure the values were preserved. + if !reflect.DeepEqual(s1, s2) { + panic(fmt.Sprintf("Data corruption across marshal/unmarshal cycle:\nBefore: %+v\nAfter: %+v\n", s1, s2)) + } +} diff --git a/tools/go_marshal/test/escape/BUILD b/tools/go_marshal/test/escape/BUILD new file mode 100644 index 000000000..f74e6ffae --- /dev/null +++ b/tools/go_marshal/test/escape/BUILD @@ -0,0 +1,14 @@ +load("//tools:defs.bzl", "go_library") + +licenses(["notice"]) + +go_library( + name = "escape", + testonly = 1, + srcs = ["escape.go"], + deps = [ + "//pkg/usermem", + "//tools/go_marshal/marshal", + "//tools/go_marshal/test", + ], +) diff --git a/tools/go_marshal/test/escape/escape.go b/tools/go_marshal/test/escape/escape.go new file mode 100644 index 000000000..6a46ddbf8 --- /dev/null +++ b/tools/go_marshal/test/escape/escape.go @@ -0,0 +1,95 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package escape + +import ( + "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/tools/go_marshal/marshal" + "gvisor.dev/gvisor/tools/go_marshal/test" +) + +// dummyTask implements marshal.Task. +type dummyTask struct { +} + +func (*dummyTask) CopyScratchBuffer(size int) []byte { + return make([]byte, size) +} + +func (*dummyTask) CopyOutBytes(addr usermem.Addr, b []byte) (int, error) { + return len(b), nil +} + +func (*dummyTask) CopyInBytes(addr usermem.Addr, b []byte) (int, error) { + return len(b), nil +} + +func (t *dummyTask) MarshalBytes(addr usermem.Addr, marshallable marshal.Marshallable) { + buf := t.CopyScratchBuffer(marshallable.SizeBytes()) + marshallable.MarshalBytes(buf) + t.CopyOutBytes(addr, buf) +} + +func (t *dummyTask) MarshalUnsafe(addr usermem.Addr, marshallable marshal.Marshallable) { + buf := t.CopyScratchBuffer(marshallable.SizeBytes()) + marshallable.MarshalUnsafe(buf) + t.CopyOutBytes(addr, buf) +} + +// +checkescape:all +//go:nosplit +func doCopyIn(t *dummyTask) { + var stat test.Stat + stat.CopyIn(t, usermem.Addr(0xf000ba12)) +} + +// +checkescape:all +//go:nosplit +func doCopyOut(t *dummyTask) { + var stat test.Stat + stat.CopyOut(t, usermem.Addr(0xf000ba12)) +} + +// +mustescape:builtin +// +mustescape:stack +func doMarshalBytesDirect(t *dummyTask) { + var stat test.Stat + buf := t.CopyScratchBuffer(stat.SizeBytes()) + stat.MarshalBytes(buf) + t.CopyOutBytes(usermem.Addr(0xf000ba12), buf) +} + +// +mustescape:builtin +// +mustescape:stack +func doMarshalUnsafeDirect(t *dummyTask) { + var stat test.Stat + buf := t.CopyScratchBuffer(stat.SizeBytes()) + stat.MarshalUnsafe(buf) + t.CopyOutBytes(usermem.Addr(0xf000ba12), buf) +} + +// +mustescape:local,heap +// +mustescape:stack +func doMarshalBytesViaMarshallable(t *dummyTask) { + var stat test.Stat + t.MarshalBytes(usermem.Addr(0xf000ba12), &stat) +} + +// +mustescape:local,heap +// +mustescape:stack +func doMarshalUnsafeViaMarshallable(t *dummyTask) { + var stat test.Stat + 
t.MarshalUnsafe(usermem.Addr(0xf000ba12), &stat) +} diff --git a/tools/go_marshal/test/external/BUILD b/tools/go_marshal/test/external/BUILD new file mode 100644 index 000000000..0cf6da603 --- /dev/null +++ b/tools/go_marshal/test/external/BUILD @@ -0,0 +1,11 @@ +load("//tools:defs.bzl", "go_library") + +licenses(["notice"]) + +go_library( + name = "external", + testonly = 1, + srcs = ["external.go"], + marshal = True, + visibility = ["//tools/go_marshal/test:gomarshal_test"], +) diff --git a/tools/go_marshal/test/external/external.go b/tools/go_marshal/test/external/external.go new file mode 100644 index 000000000..26fe8e0c8 --- /dev/null +++ b/tools/go_marshal/test/external/external.go @@ -0,0 +1,31 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package external defines types we can import for testing. +package external + +// External is a public Marshallable type for use in testing. +// +// +marshal +type External struct { + j int64 +} + +// NotPacked is an unaligned Marshallable type for use in testing. +// +// +marshal +type NotPacked struct { + a int32 + b byte `marshal:"unaligned"` +} diff --git a/tools/go_marshal/test/marshal_test.go b/tools/go_marshal/test/marshal_test.go new file mode 100644 index 000000000..16829ee45 --- /dev/null +++ b/tools/go_marshal/test/marshal_test.go @@ -0,0 +1,515 @@ +// Copyright 2020 The gVisor Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package marshal_test contains manual tests for the marshal interface. These +// are intended to test behaviour not covered by the automatically generated +// tests. +package marshal_test + +import ( + "bytes" + "encoding/binary" + "fmt" + "reflect" + "runtime" + "testing" + "unsafe" + + "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/tools/go_marshal/analysis" + "gvisor.dev/gvisor/tools/go_marshal/marshal" + "gvisor.dev/gvisor/tools/go_marshal/test" +) + +var simulatedErr error = syserror.EFAULT + +// mockTask implements marshal.Task. +type mockTask struct { + taskMem usermem.BytesIO +} + +// populate fills the task memory with the contents of val. +func (t *mockTask) populate(val interface{}) { + var buf bytes.Buffer + // Use binary.Write so we aren't testing go-marshal against its own + // potentially buggy implementation. + if err := binary.Write(&buf, usermem.ByteOrder, val); err != nil { + panic(err) + } + t.taskMem.Bytes = buf.Bytes() +} + +func (t *mockTask) setLimit(n int) { + if len(t.taskMem.Bytes) < n { + grown := make([]byte, n) + copy(grown, t.taskMem.Bytes) + t.taskMem.Bytes = grown + return + } + t.taskMem.Bytes = t.taskMem.Bytes[:n] +} + +// CopyScratchBuffer implements marshal.Task.CopyScratchBuffer. 
+func (t *mockTask) CopyScratchBuffer(size int) []byte { + return make([]byte, size) +} + +// CopyOutBytes implements marshal.Task.CopyOutBytes. The implementation +// completely ignores the target address and stores a copy of b in its +// internal buffer, overriding any previous contents. +func (t *mockTask) CopyOutBytes(_ usermem.Addr, b []byte) (int, error) { + return t.taskMem.CopyOut(nil, 0, b, usermem.IOOpts{}) +} + +// CopyInBytes implements marshal.Task.CopyInBytes. The implementation +// completely ignores the source address and always fills b from the beginning of +// its internal buffer. +func (t *mockTask) CopyInBytes(_ usermem.Addr, b []byte) (int, error) { + return t.taskMem.CopyIn(nil, 0, b, usermem.IOOpts{}) +} + +// unsafeMemory returns the underlying memory for m. The returned slice is only +// valid for the lifetime of m. The garbage collector isn't aware that the +// returned slice is related to m, the caller must ensure m lives long enough. +func unsafeMemory(m marshal.Marshallable) []byte { + if !m.Packed() { + // We can't return a slice pointing to the underlying memory + // since the layout isn't packed. Allocate a temporary buffer + // and marshal instead. + var buf bytes.Buffer + if err := binary.Write(&buf, usermem.ByteOrder, m); err != nil { + panic(err) + } + return buf.Bytes() + } + + // reflect.ValueOf(m) + // .Elem() // Unwrap interface to inner concrete object + // .Addr() // Pointer value to object + // .Pointer() // Actual address from the pointer value + ptr := reflect.ValueOf(m).Elem().Addr().Pointer() + + size := m.SizeBytes() + + var mem []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&mem)) + hdr.Data = ptr + hdr.Len = size + hdr.Cap = size + + return mem +} + +// unsafeMemorySlice returns the underlying memory for m. The returned slice is +// only valid for the lifetime of m. The garbage collector isn't aware that the +// returned slice is related to m, the caller must ensure m lives long enough. 
+// +// Precondition: m must be a slice. +func unsafeMemorySlice(m interface{}, elt marshal.Marshallable) []byte { + kind := reflect.TypeOf(m).Kind() + if kind != reflect.Slice { + panic("unsafeMemorySlice called on non-slice") + } + + if !elt.Packed() { + // We can't return a slice pointing to the underlying memory + // since the layout isn't packed. Allocate a temporary buffer + // and marshal instead. + var buf bytes.Buffer + if err := binary.Write(&buf, usermem.ByteOrder, m); err != nil { + panic(err) + } + return buf.Bytes() + } + + v := reflect.ValueOf(m) + length := v.Len() * elt.SizeBytes() + + var mem []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&mem)) + hdr.Data = v.Pointer() // This is a pointer to the first elem for slices. + hdr.Len = length + hdr.Cap = length + + return mem +} + +func isZeroes(buf []byte) bool { + for _, b := range buf { + if b != 0 { + return false + } + } + return true +} + +// compareMemory compares the first n bytes of two chunks of memory represented +// by expected and actual. +func compareMemory(t *testing.T, expected, actual []byte, n int) { + t.Logf("Expected (%d): %v (%d) + (%d) %v\n", len(expected), expected[:n], n, len(expected)-n, expected[n:]) + t.Logf("Actual (%d): %v (%d) + (%d) %v\n", len(actual), actual[:n], n, len(actual)-n, actual[n:]) + + if diff := cmp.Diff(expected[:n], actual[:n]); diff != "" { + t.Errorf("Memory buffers don't match:\n--- expected only\n+++ actual only\n%v", diff) + } +} + +// limitedCopyIn populates task memory with src, then unmarshals task memory to +// dst. The task signals an error at limit bytes during copy-in, which should +// result in a truncated unmarshalling. 
+func limitedCopyIn(t *testing.T, src, dst marshal.Marshallable, limit int) { + var task mockTask + task.populate(src) + task.setLimit(limit) + + n, err := dst.CopyIn(&task, usermem.Addr(0)) + if n != limit { + t.Errorf("CopyIn copied unexpected number of bytes, expected %d, got %d", limit, n) + } + if err != simulatedErr { + t.Errorf("CopyIn returned unexpected error, expected %v, got %v", simulatedErr, err) + } + + expectedMem := unsafeMemory(src) + defer runtime.KeepAlive(src) + actualMem := unsafeMemory(dst) + defer runtime.KeepAlive(dst) + + compareMemory(t, expectedMem, actualMem, n) + + // The last n bytes should be zero for actual, since actual was + // zero-initialized, and CopyIn shouldn't have touched those bytes. However + // we can only guarantee we didn't touch anything in the last n bytes if the + // layout is packed. + if dst.Packed() && !isZeroes(actualMem[n:]) { + t.Errorf("Expected the last %d bytes of copied in object to be zeroes, got %v\n", dst.SizeBytes()-n, actualMem) + } +} + +// limitedCopyOut marshals src to task memory. The task signals an error at +// limit bytes during copy-out, which should result in a truncated marshalling. +func limitedCopyOut(t *testing.T, src marshal.Marshallable, limit int) { + var task mockTask + task.setLimit(limit) + + n, err := src.CopyOut(&task, usermem.Addr(0)) + if n != limit { + t.Errorf("CopyOut copied unexpected number of bytes, expected %d, got %d", limit, n) + } + if err != simulatedErr { + t.Errorf("CopyOut returned unexpected error, expected %v, got %v", simulatedErr, err) + } + + expectedMem := unsafeMemory(src) + defer runtime.KeepAlive(src) + actualMem := task.taskMem.Bytes + + compareMemory(t, expectedMem, actualMem, n) +} + +// copyOutN marshals src to task memory, requesting the marshalling to be +// limited to limit bytes. 
+func copyOutN(t *testing.T, src marshal.Marshallable, limit int) { + var task mockTask + task.setLimit(limit) + + n, err := src.CopyOutN(&task, usermem.Addr(0), limit) + if err != nil { + t.Errorf("CopyOut returned unexpected error: %v", err) + } + if n != limit { + t.Errorf("CopyOut copied unexpected number of bytes, expected %d, got %d", limit, n) + } + + expectedMem := unsafeMemory(src) + defer runtime.KeepAlive(src) + actualMem := task.taskMem.Bytes + + t.Logf("Expected: %v + %v\n", expectedMem[:n], expectedMem[n:]) + t.Logf("Actual : %v + %v\n", actualMem[:n], actualMem[n:]) + + compareMemory(t, expectedMem, actualMem, n) +} + +// TestLimitedMarshalling verifies marshalling/unmarshalling succeeds when the +// underlying copy in/out operations partially succeed. +func TestLimitedMarshalling(t *testing.T) { + types := []reflect.Type{ + // Packed types. + reflect.TypeOf((*test.Type2)(nil)), + reflect.TypeOf((*test.Type3)(nil)), + reflect.TypeOf((*test.Timespec)(nil)), + reflect.TypeOf((*test.Stat)(nil)), + reflect.TypeOf((*test.InetAddr)(nil)), + reflect.TypeOf((*test.SignalSet)(nil)), + reflect.TypeOf((*test.SignalSetAlias)(nil)), + // Non-packed types. + reflect.TypeOf((*test.Type1)(nil)), + reflect.TypeOf((*test.Type4)(nil)), + reflect.TypeOf((*test.Type5)(nil)), + reflect.TypeOf((*test.Type6)(nil)), + reflect.TypeOf((*test.Type7)(nil)), + reflect.TypeOf((*test.Type8)(nil)), + } + + for _, tyPtr := range types { + // Remove one level of pointer-indirection from the type. We get this + // back when we pass the type to reflect.New. + ty := tyPtr.Elem() + + // Partial copy-in. + t.Run(fmt.Sprintf("PartialCopyIn_%v", ty), func(t *testing.T) { + expected := reflect.New(ty).Interface().(marshal.Marshallable) + actual := reflect.New(ty).Interface().(marshal.Marshallable) + analysis.RandomizeValue(expected) + + limitedCopyIn(t, expected, actual, expected.SizeBytes()/2) + }) + + // Partial copy-out. 
+ t.Run(fmt.Sprintf("PartialCopyOut_%v", ty), func(t *testing.T) { + expected := reflect.New(ty).Interface().(marshal.Marshallable) + analysis.RandomizeValue(expected) + + limitedCopyOut(t, expected, expected.SizeBytes()/2) + }) + + // Explicitly request partial copy-out. + t.Run(fmt.Sprintf("PartialCopyOutN_%v", ty), func(t *testing.T) { + expected := reflect.New(ty).Interface().(marshal.Marshallable) + analysis.RandomizeValue(expected) + + copyOutN(t, expected, expected.SizeBytes()/2) + }) + } +} + +// TestLimitedSliceMarshalling verifies marshalling/unmarshalling of slices of +// marshallable types succeed when the underlying copy in/out operations +// partially succeed. +func TestLimitedSliceMarshalling(t *testing.T) { + types := []struct { + arrayPtrType reflect.Type + copySliceIn func(task marshal.Task, addr usermem.Addr, dstSlice interface{}) (int, error) + copySliceOut func(task marshal.Task, addr usermem.Addr, srcSlice interface{}) (int, error) + unsafeMemory func(arrPtr interface{}) []byte + }{ + // Packed types. 
+ { + reflect.TypeOf((*[20]test.Stat)(nil)), + func(task marshal.Task, addr usermem.Addr, dst interface{}) (int, error) { + slice := dst.(*[20]test.Stat)[:] + return test.CopyStatSliceIn(task, addr, slice) + }, + func(task marshal.Task, addr usermem.Addr, src interface{}) (int, error) { + slice := src.(*[20]test.Stat)[:] + return test.CopyStatSliceOut(task, addr, slice) + }, + func(a interface{}) []byte { + slice := a.(*[20]test.Stat)[:] + return unsafeMemorySlice(slice, &slice[0]) + }, + }, + { + reflect.TypeOf((*[1]test.Stat)(nil)), + func(task marshal.Task, addr usermem.Addr, dst interface{}) (int, error) { + slice := dst.(*[1]test.Stat)[:] + return test.CopyStatSliceIn(task, addr, slice) + }, + func(task marshal.Task, addr usermem.Addr, src interface{}) (int, error) { + slice := src.(*[1]test.Stat)[:] + return test.CopyStatSliceOut(task, addr, slice) + }, + func(a interface{}) []byte { + slice := a.(*[1]test.Stat)[:] + return unsafeMemorySlice(slice, &slice[0]) + }, + }, + { + reflect.TypeOf((*[5]test.SignalSetAlias)(nil)), + func(task marshal.Task, addr usermem.Addr, dst interface{}) (int, error) { + slice := dst.(*[5]test.SignalSetAlias)[:] + return test.CopySignalSetAliasSliceIn(task, addr, slice) + }, + func(task marshal.Task, addr usermem.Addr, src interface{}) (int, error) { + slice := src.(*[5]test.SignalSetAlias)[:] + return test.CopySignalSetAliasSliceOut(task, addr, slice) + }, + func(a interface{}) []byte { + slice := a.(*[5]test.SignalSetAlias)[:] + return unsafeMemorySlice(slice, &slice[0]) + }, + }, + // Non-packed types. 
+ { + reflect.TypeOf((*[20]test.Type1)(nil)), + func(task marshal.Task, addr usermem.Addr, dst interface{}) (int, error) { + slice := dst.(*[20]test.Type1)[:] + return test.CopyType1SliceIn(task, addr, slice) + }, + func(task marshal.Task, addr usermem.Addr, src interface{}) (int, error) { + slice := src.(*[20]test.Type1)[:] + return test.CopyType1SliceOut(task, addr, slice) + }, + func(a interface{}) []byte { + slice := a.(*[20]test.Type1)[:] + return unsafeMemorySlice(slice, &slice[0]) + }, + }, + { + reflect.TypeOf((*[1]test.Type1)(nil)), + func(task marshal.Task, addr usermem.Addr, dst interface{}) (int, error) { + slice := dst.(*[1]test.Type1)[:] + return test.CopyType1SliceIn(task, addr, slice) + }, + func(task marshal.Task, addr usermem.Addr, src interface{}) (int, error) { + slice := src.(*[1]test.Type1)[:] + return test.CopyType1SliceOut(task, addr, slice) + }, + func(a interface{}) []byte { + slice := a.(*[1]test.Type1)[:] + return unsafeMemorySlice(slice, &slice[0]) + }, + }, + { + reflect.TypeOf((*[7]test.Type8)(nil)), + func(task marshal.Task, addr usermem.Addr, dst interface{}) (int, error) { + slice := dst.(*[7]test.Type8)[:] + return test.CopyType8SliceIn(task, addr, slice) + }, + func(task marshal.Task, addr usermem.Addr, src interface{}) (int, error) { + slice := src.(*[7]test.Type8)[:] + return test.CopyType8SliceOut(task, addr, slice) + }, + func(a interface{}) []byte { + slice := a.(*[7]test.Type8)[:] + return unsafeMemorySlice(slice, &slice[0]) + }, + }, + } + + for _, tt := range types { + // The body of this loop is generic over the type tt.arrayPtrType, with + // the help of reflection. To aid in readability, comments below show + // the equivalent go code assuming + // tt.arrayPtrType = typeof(*[20]test.Stat). + + // Equivalent: + // var x *[20]test.Stat + // arrayTy := reflect.TypeOf(*x) + arrayTy := tt.arrayPtrType.Elem() + + // Partial copy-in of slices. 
+ t.Run(fmt.Sprintf("PartialCopySliceIn_%v", arrayTy), func(t *testing.T) { + // Equivalent: + // var x [20]test.Stat + // length := len(x) + length := arrayTy.Len() + if length < 1 { + panic("Test type can't be zero-length array") + } + // Equivalent: + // elem := new(test.Stat).(marshal.Marshallable) + elem := reflect.New(arrayTy.Elem()).Interface().(marshal.Marshallable) + + // Equivalent: + // var expected, actual interface{} + // expected = new([20]test.Stat) + // actual = new([20]test.Stat) + expected := reflect.New(arrayTy).Interface() + actual := reflect.New(arrayTy).Interface() + + analysis.RandomizeValue(expected) + + limit := (length * elem.SizeBytes()) / 2 + // Also make sure the limit is partially inside one of the elements. + limit += elem.SizeBytes() / 2 + analysis.RandomizeValue(expected) + + var task mockTask + task.populate(expected) + task.setLimit(limit) + + n, err := tt.copySliceIn(&task, usermem.Addr(0), actual) + if n != limit { + t.Errorf("CopyIn copied unexpected number of bytes, expected %d, got %d", limit, n) + } + if n < length*elem.SizeBytes() && err != simulatedErr { + t.Errorf("CopyIn returned unexpected error, expected %v, got %v", simulatedErr, err) + } + + expectedMem := tt.unsafeMemory(expected) + defer runtime.KeepAlive(expected) + actualMem := tt.unsafeMemory(actual) + defer runtime.KeepAlive(actual) + + compareMemory(t, expectedMem, actualMem, n) + + // The last n bytes should be zero for actual, since actual was + // zero-initialized, and CopyIn shouldn't have touched those bytes. However + // we can only guarantee we didn't touch anything in the last n bytes if the + // layout is packed. + if elem.Packed() && !isZeroes(actualMem[n:]) { + t.Errorf("Expected the last %d bytes of copied in object to be zeroes, got %v\n", (elem.SizeBytes()*length)-n, actualMem) + } + }) + + // Partial copy-out of slices. 
+ t.Run(fmt.Sprintf("PartialCopySliceOut_%v", arrayTy), func(t *testing.T) { + // Equivalent: + // var x [20]test.Stat + // length := len(x) + length := arrayTy.Len() + if length < 1 { + panic("Test type can't be zero-length array") + } + // Equivalent: + // elem := new(test.Stat).(marshal.Marshallable) + elem := reflect.New(arrayTy.Elem()).Interface().(marshal.Marshallable) + + // Equivalent: + // var expected, actual interface{} + // expected = new([20]test.Stat) + // actual = new([20]test.Stat) + expected := reflect.New(arrayTy).Interface() + + analysis.RandomizeValue(expected) + + limit := (length * elem.SizeBytes()) / 2 + // Also make sure the limit is partially inside one of the elements. + limit += elem.SizeBytes() / 2 + analysis.RandomizeValue(expected) + + var task mockTask + task.populate(expected) + task.setLimit(limit) + + n, err := tt.copySliceOut(&task, usermem.Addr(0), expected) + if n != limit { + t.Errorf("CopyIn copied unexpected number of bytes, expected %d, got %d", limit, n) + } + if n < length*elem.SizeBytes() && err != simulatedErr { + t.Errorf("CopyIn returned unexpected error, expected %v, got %v", simulatedErr, err) + } + + expectedMem := tt.unsafeMemory(expected) + defer runtime.KeepAlive(expected) + actualMem := task.taskMem.Bytes + + compareMemory(t, expectedMem, actualMem, n) + }) + } +} diff --git a/tools/go_marshal/test/test.go b/tools/go_marshal/test/test.go new file mode 100644 index 000000000..f75ca1b7f --- /dev/null +++ b/tools/go_marshal/test/test.go @@ -0,0 +1,176 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package test contains data structures for testing the go_marshal tool. +package test + +import ( + // We're intentionally using a package name alias here even though it's not + // necessary to test the code generator's ability to handle package aliases. + ex "gvisor.dev/gvisor/tools/go_marshal/test/external" +) + +// Type1 is a test data type. +// +// +marshal slice:Type1Slice +type Type1 struct { + a Type2 + x, y int64 // Multiple field names. + b byte `marshal:"unaligned"` // Short field. + c uint64 + _ uint32 // Unnamed scalar field. + _ [6]byte // Unnamed vector field, typical padding. + _ [2]byte + xs [8]int32 + as [10]Type2 `marshal:"unaligned"` // Array of Marshallable objects. + ss Type3 +} + +// Type2 is a test data type. +// +// +marshal +type Type2 struct { + n int64 + c byte + _ [7]byte + m int64 + a int64 +} + +// Type3 is a test data type. +// +// +marshal +type Type3 struct { + s int64 + x ex.External // Type defined in another package. +} + +// Type4 is a test data type. +// +// +marshal +type Type4 struct { + c byte + x int64 `marshal:"unaligned"` + d byte + _ [7]byte +} + +// Type5 is a test data type. +// +// +marshal +type Type5 struct { + n int64 + t Type4 + m int64 +} + +// Type6 is a test data type that ends mid-word. +// +// +marshal +type Type6 struct { + a int64 + b int64 + // If c isn't marked unaligned, analysis fails (as it should, since + // the unsafe API corrupts Type7). + c byte `marshal:"unaligned"` +} + +// Type7 is a test data type that contains a child struct that ends +// mid-word. 
+// +marshal +type Type7 struct { + x Type6 + y int64 +} + +// Type8 is a test data type which contains an external non-packed field. +// +// +marshal slice:Type8Slice +type Type8 struct { + a int64 + np ex.NotPacked + b int64 +} + +// Timespec represents struct timespec in <time.h>. +// +// +marshal +type Timespec struct { + Sec int64 + Nsec int64 +} + +// Stat represents struct stat. +// +// +marshal slice:StatSlice +type Stat struct { + Dev uint64 + Ino uint64 + Nlink uint64 + Mode uint32 + UID uint32 + GID uint32 + _ int32 + Rdev uint64 + Size int64 + Blksize int64 + Blocks int64 + ATime Timespec + MTime Timespec + CTime Timespec + _ [3]int64 +} + +// InetAddr is an example marshallable newtype on an array. +// +// +marshal +type InetAddr [4]byte + +// SignalSet is an example marshallable newtype on a primitive. +// +// +marshal slice:SignalSetSlice:inner +type SignalSet uint64 + +// SignalSetAlias is an example newtype on another marshallable type. +// +// +marshal slice:SignalSetAliasSlice +type SignalSetAlias SignalSet + +const sizeA = 64 +const sizeB = 8 + +// TestArray is a test data structure on an array with a constant length. +// +// +marshal +type TestArray [sizeA]int32 + +// TestArray2 is a newtype on an array with a simple arithmetic expression of +// constants for the array length. +// +// +marshal +type TestArray2 [sizeA * sizeB]int32 + +// TestArray3 is a newtype on an array with a simple arithmetic expression of +// mixed constants and literals for the array length. +// +// +marshal +type TestArray3 [sizeA*sizeB + 12]int32 + +// Type9 is a test data type containing an array with a non-literal length. +// +// +marshal +type Type9 struct { + x int64 + y [sizeA]int32 +} |