summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorRahat Mahmood <rahat@google.com>2020-02-27 14:51:29 -0800
committergVisor bot <gvisor-bot@google.com>2020-02-27 14:52:26 -0800
commitaa9f8abaef5c6250bdcee8fd88b2420f20791c5d (patch)
treeb91caf2f1a0f00ecff72916b00216df13249e986
parent2cccf3d27b138b677ef50a663304b1ba83d62051 (diff)
Implement automated marshalling for newtypes on arrays.
PiperOrigin-RevId: 297693838
-rw-r--r--tools/go_marshal/gomarshal/BUILD3
-rw-r--r--tools/go_marshal/gomarshal/generator.go17
-rw-r--r--tools/go_marshal/gomarshal/generator_interfaces.go665
-rw-r--r--tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go183
-rw-r--r--tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go229
-rw-r--r--tools/go_marshal/gomarshal/generator_interfaces_struct.go450
-rw-r--r--tools/go_marshal/gomarshal/generator_tests.go2
-rw-r--r--tools/go_marshal/gomarshal/util.go41
-rw-r--r--tools/go_marshal/test/test.go5
9 files changed, 915 insertions, 680 deletions
diff --git a/tools/go_marshal/gomarshal/BUILD b/tools/go_marshal/gomarshal/BUILD
index b5d5a4487..44cb33ae4 100644
--- a/tools/go_marshal/gomarshal/BUILD
+++ b/tools/go_marshal/gomarshal/BUILD
@@ -7,6 +7,9 @@ go_library(
srcs = [
"generator.go",
"generator_interfaces.go",
+ "generator_interfaces_array_newtype.go",
+ "generator_interfaces_primitive_newtype.go",
+ "generator_interfaces_struct.go",
"generator_tests.go",
"util.go",
],
diff --git a/tools/go_marshal/gomarshal/generator.go b/tools/go_marshal/gomarshal/generator.go
index d365a1f3c..729489de5 100644
--- a/tools/go_marshal/gomarshal/generator.go
+++ b/tools/go_marshal/gomarshal/generator.go
@@ -235,6 +235,10 @@ func (g *Generator) collectMarshallableTypes(a *ast.File, f *token.FileSet) []*a
debugfAt(f.Position(t.Pos()), "Collected marshallable newtype on primitive %s.\n", t.Name.Name)
types = append(types, t)
continue
+ case *ast.ArrayType: // Newtype on array.
+ debugfAt(f.Position(t.Pos()), "Collected marshallable newtype on array %s.\n", t.Name.Name)
+ types = append(types, t)
+ continue
}
// A user specifically requested marshalling on this type, but we
// don't support it.
@@ -281,17 +285,20 @@ func (g *Generator) generateOne(t *ast.TypeSpec, fset *token.FileSet) *interface
i := newInterfaceGenerator(t, fset)
switch ty := t.Type.(type) {
case *ast.StructType:
- i.validateStruct()
- i.emitMarshallableForStruct()
- return i
+ i.validateStruct(t, ty)
+ i.emitMarshallableForStruct(ty)
case *ast.Ident:
i.validatePrimitiveNewtype(ty)
- i.emitMarshallableForPrimitiveNewtype()
- return i
+ i.emitMarshallableForPrimitiveNewtype(ty)
+ case *ast.ArrayType:
+ i.validateArrayNewtype(t.Name, ty)
+ // After validate, we can safely call arrayLen.
+ i.emitMarshallableForArrayNewtype(t.Name, ty.Elt.(*ast.Ident), arrayLen(ty))
default:
// This should've been filtered out by collectMarshallabeTypes.
panic(fmt.Sprintf("Unexpected type %+v", ty))
}
+ return i
}
// generateOneTestSuite generates a test suite for the automatically generated
diff --git a/tools/go_marshal/gomarshal/generator_interfaces.go b/tools/go_marshal/gomarshal/generator_interfaces.go
index ea1af998e..8babf61d2 100644
--- a/tools/go_marshal/gomarshal/generator_interfaces.go
+++ b/tools/go_marshal/gomarshal/generator_interfaces.go
@@ -15,10 +15,8 @@
package gomarshal
import (
- "fmt"
"go/ast"
"go/token"
- "strings"
)
// interfaceGenerator generates marshalling interfaces for a single type.
@@ -81,18 +79,6 @@ func (g *interfaceGenerator) recordPotentiallyNonPackedField(fieldName string) {
g.as[fieldName] = struct{}{}
}
-func (g *interfaceGenerator) forEachField(fn func(f *ast.Field)) {
- // This is guaranteed to succeed because g.t is always a struct.
- st := g.t.Type.(*ast.StructType)
- for _, field := range st.Fields.List {
- fn(field)
- }
-}
-
-func (g *interfaceGenerator) fieldAccessor(n *ast.Ident) string {
- return fmt.Sprintf("%s.%s", g.r, n.Name)
-}
-
// abortAt aborts the go_marshal tool with the given error message, with a
// reference position to the input source. Same as abortAt, but uses g to
// resolve p to position.
@@ -100,71 +86,6 @@ func (g *interfaceGenerator) abortAt(p token.Pos, msg string) {
abortAt(g.f.Position(p), msg)
}
-func (g *interfaceGenerator) validatePrimitiveNewtype(t *ast.Ident) {
- switch t.Name {
- case "int8", "uint8", "byte", "int16", "uint16", "int32", "uint32", "int64", "uint64":
- // These are the only primitive types we're allow. Below, we provide
- // suggestions for some disallowed types and reject them, then attempt
- // to marshal any remaining types by invoking the marshal.Marshallable
- // interface on them. If these types don't actually implement
- // marshal.Marshallable, compilation of the generated code will fail
- // with an appropriate error message.
- return
- case "int":
- g.abortAt(t.Pos(), "Type 'int' has ambiguous width, use int32 or int64")
- case "uint":
- g.abortAt(t.Pos(), "Type 'uint' has ambiguous width, use uint32 or uint64")
- case "string":
- g.abortAt(t.Pos(), "Type 'string' is dynamically-sized and cannot be marshalled, use a fixed size byte array '[...]byte' instead")
- default:
- debugfAt(g.f.Position(t.Pos()), fmt.Sprintf("Found derived type '%s', will attempt dispatch via marshal.Marshallable.\n", t.Name))
- }
-}
-
-// validateStruct ensures the type we're working with can be marshalled. These
-// checks are done ahead of time and in one place so we can make assumptions
-// later.
-func (g *interfaceGenerator) validateStruct() {
- g.forEachField(func(f *ast.Field) {
- if len(f.Names) == 0 {
- g.abortAt(f.Pos(), "Cannot marshal structs with embedded fields, give the field a name; use '_' for anonymous fields such as padding fields")
- }
- })
-
- g.forEachField(func(f *ast.Field) {
- fieldDispatcher{
- primitive: func(_, t *ast.Ident) {
- g.validatePrimitiveNewtype(t)
- },
- selector: func(_, _, _ *ast.Ident) {
- // No validation to perform on selector fields. However this
- // callback must still be provided.
- },
- array: func(n, _ *ast.Ident, len int) {
- a := f.Type.(*ast.ArrayType)
- if a.Len == nil {
- g.abortAt(f.Pos(), fmt.Sprintf("Dynamically sized slice '%s' cannot be marshalled, arrays must be statically sized", n.Name))
- }
-
- if _, ok := a.Len.(*ast.BasicLit); !ok {
- g.abortAt(a.Len.Pos(), fmt.Sprintf("Array size must be a literal, don's use consts or expressions"))
- }
-
- if _, ok := a.Elt.(*ast.Ident); !ok {
- g.abortAt(a.Elt.Pos(), fmt.Sprintf("Marshalling not supported for arrays with %s elements, array elements must be primitive types", kindString(a.Elt)))
- }
-
- if len <= 0 {
- g.abortAt(a.Len.Pos(), fmt.Sprintf("Marshalling not supported for zero length arrays, why does an ABI struct have one?"))
- }
- },
- unhandled: func(_ *ast.Ident) {
- g.abortAt(f.Pos(), fmt.Sprintf("Marshalling not supported for %s fields", kindString(f.Type)))
- },
- }.dispatch(f)
- })
-}
-
// scalarSize returns the size of type identified by t. If t isn't a primitive
// type, the size isn't known at code generation time, and must be resolved via
// the marshal.Marshallable interface.
@@ -191,8 +112,8 @@ func (g *interfaceGenerator) shiftDynamic(bufVar, name string) {
g.emit("%s = %s[%s.SizeBytes():]\n", bufVar, bufVar, name)
}
-// marshalStructFieldScalar writes a single scalar field from a struct to a byte slice.
-func (g *interfaceGenerator) marshalStructFieldScalar(accessor, typ, bufVar string) {
+// marshalScalar writes a single scalar to a byte slice.
+func (g *interfaceGenerator) marshalScalar(accessor, typ, bufVar string) {
switch typ {
case "int8", "uint8", "byte":
g.emit("%s[0] = byte(%s)\n", bufVar, accessor)
@@ -215,9 +136,8 @@ func (g *interfaceGenerator) marshalStructFieldScalar(accessor, typ, bufVar stri
}
}
-// unmarshalStructFieldScalar reads a single scalar field from a struct, from a
-// byte slice.
-func (g *interfaceGenerator) unmarshalStructFieldScalar(accessor, typ, bufVar string) {
+// unmarshalScalar reads a single scalar from a byte slice.
+func (g *interfaceGenerator) unmarshalScalar(accessor, typ, bufVar string) {
switch typ {
case "byte":
g.emit("%s = %s[0]\n", accessor, bufVar)
@@ -243,580 +163,3 @@ func (g *interfaceGenerator) unmarshalStructFieldScalar(accessor, typ, bufVar st
g.recordPotentiallyNonPackedField(accessor)
}
}
-
-// marshalPrimitiveScalar writes a single primitive variable to a byte slice.
-func (g *interfaceGenerator) marshalPrimitiveScalar(accessor, typ, bufVar string) {
- switch typ {
- case "int8", "uint8", "byte":
- g.emit("%s[0] = byte(*%s)\n", bufVar, accessor)
- case "int16", "uint16":
- g.recordUsedImport("usermem")
- g.emit("usermem.ByteOrder.PutUint16(%s[:2], uint16(*%s))\n", bufVar, accessor)
- case "int32", "uint32":
- g.recordUsedImport("usermem")
- g.emit("usermem.ByteOrder.PutUint32(%s[:4], uint32(*%s))\n", bufVar, accessor)
- case "int64", "uint64":
- g.recordUsedImport("usermem")
- g.emit("usermem.ByteOrder.PutUint64(%s[:8], uint64(*%s))\n", bufVar, accessor)
- default:
- g.emit("inner := (*%s)(%s)\n", typ, accessor)
- g.emit("inner.MarshalBytes(%s[:%s.SizeBytes()])\n", bufVar, accessor)
- }
-}
-
-// unmarshalPrimitiveScalar read a single primitive variable from a byte slice.
-func (g *interfaceGenerator) unmarshalPrimitiveScalar(accessor, typ, bufVar, typeCast string) {
- switch typ {
- case "byte":
- g.emit("*%s = %s(%s[0])\n", accessor, typeCast, bufVar)
- case "int8", "uint8":
- g.emit("*%s = %s(%s(%s[0]))\n", accessor, typeCast, typ, bufVar)
- case "int16", "uint16":
- g.recordUsedImport("usermem")
- g.emit("*%s = %s(%s(usermem.ByteOrder.Uint16(%s[:2])))\n", accessor, typeCast, typ, bufVar)
- case "int32", "uint32":
- g.recordUsedImport("usermem")
- g.emit("*%s = %s(%s(usermem.ByteOrder.Uint32(%s[:4])))\n", accessor, typeCast, typ, bufVar)
-
- case "int64", "uint64":
- g.recordUsedImport("usermem")
- g.emit("*%s = %s(%s(usermem.ByteOrder.Uint64(%s[:8])))\n", accessor, typeCast, typ, bufVar)
- default:
- g.emit("inner := (*%s)(%s)\n", typ, accessor)
- g.emit("inner.UnmarshalBytes(%s[:%s.SizeBytes()])\n", bufVar, accessor)
- }
-}
-
-// areFieldsPackedExpression returns a go expression checking whether g.t's fields are
-// packed. Returns "", false if g.t has no fields that may be potentially
-// packed, otherwise returns <clause>, true, where <clause> is an expression
-// like "t.a.Packed() && t.b.Packed() && t.c.Packed()".
-func (g *interfaceGenerator) areFieldsPackedExpression() (string, bool) {
- if len(g.as) == 0 {
- return "", false
- }
-
- cs := make([]string, 0, len(g.as))
- for accessor, _ := range g.as {
- cs = append(cs, fmt.Sprintf("%s.Packed()", accessor))
- }
- return strings.Join(cs, " && "), true
-}
-
-func (g *interfaceGenerator) emitMarshallableForStruct() {
- // Is g.t a packed struct without consideing field types?
- thisPacked := true
- g.forEachField(func(f *ast.Field) {
- if f.Tag != nil {
- if f.Tag.Value == "`marshal:\"unaligned\"`" {
- if thisPacked {
- debugfAt(g.f.Position(g.t.Pos()),
- fmt.Sprintf("Marking type '%s' as not packed due to tag `marshal:\"unaligned\"`.\n", g.t.Name))
- thisPacked = false
- }
- }
- }
- })
-
- g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n")
- g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName())
- g.inIndent(func() {
- primitiveSize := 0
- var dynamicSizeTerms []string
-
- g.forEachField(fieldDispatcher{
- primitive: func(n, t *ast.Ident) {
- if size, dynamic := g.scalarSize(t); !dynamic {
- primitiveSize += size
- } else {
- g.recordUsedMarshallable(t.Name)
- dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", t.Name))
- }
- },
- selector: func(n, tX, tSel *ast.Ident) {
- tName := fmt.Sprintf("%s.%s", tX.Name, tSel.Name)
- g.recordUsedImport(tX.Name)
- g.recordUsedMarshallable(tName)
- dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", tName))
- },
- array: func(n, t *ast.Ident, len int) {
- if len < 1 {
- // Zero-length arrays should've been rejected by validate().
- panic("unreachable")
- }
- if size, dynamic := g.scalarSize(t); !dynamic {
- primitiveSize += size * len
- } else {
- g.recordUsedMarshallable(t.Name)
- dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()*%d", t.Name, len))
- }
- },
- }.dispatch)
- g.emit("return %d", primitiveSize)
- if len(dynamicSizeTerms) > 0 {
- g.incIndent()
- }
- {
- for _, d := range dynamicSizeTerms {
- g.emitNoIndent(" +\n")
- g.emit(d)
- }
- }
- if len(dynamicSizeTerms) > 0 {
- g.decIndent()
- }
- })
- g.emit("\n}\n\n")
-
- g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n")
- g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName())
- g.inIndent(func() {
- g.forEachField(fieldDispatcher{
- primitive: func(n, t *ast.Ident) {
- if n.Name == "_" {
- g.emit("// Padding: dst[:sizeof(%s)] ~= %s(0)\n", t.Name, t.Name)
- if len, dynamic := g.scalarSize(t); !dynamic {
- g.shift("dst", len)
- } else {
- // We can't use shiftDynamic here because we don't have
- // an instance of the dynamic type we can referece here
- // (since the version in this struct is anonymous). Use
- // a typed nil pointer to call SizeBytes() instead.
- g.emit("dst = dst[(*%s)(nil).SizeBytes():]\n", t.Name)
- }
- return
- }
- g.marshalStructFieldScalar(g.fieldAccessor(n), t.Name, "dst")
- },
- selector: func(n, tX, tSel *ast.Ident) {
- g.marshalStructFieldScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "dst")
- },
- array: func(n, t *ast.Ident, size int) {
- if n.Name == "_" {
- g.emit("// Padding: dst[:sizeof(%s)*%d] ~= [%d]%s{0}\n", t.Name, size, size, t.Name)
- if len, dynamic := g.scalarSize(t); !dynamic {
- g.shift("dst", len*size)
- } else {
- // We can't use shiftDynamic here because we don't have
- // an instance of the dynamic type we can reference here
- // (since the version in this struct is anonymous). Use
- // a typed nil pointer to call SizeBytes() instead.
- g.emit("dst = dst[(*%s)(nil).SizeBytes()*%d:]\n", t.Name, size)
- }
- return
- }
-
- g.emit("for idx := 0; idx < %d; idx++ {\n", size)
- g.inIndent(func() {
- g.marshalStructFieldScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "dst")
- })
- g.emit("}\n")
- },
- }.dispatch)
- })
- g.emit("}\n\n")
-
- g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n")
- g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName())
- g.inIndent(func() {
- g.forEachField(fieldDispatcher{
- primitive: func(n, t *ast.Ident) {
- if n.Name == "_" {
- g.emit("// Padding: var _ %s ~= src[:sizeof(%s)]\n", t.Name, t.Name)
- if len, dynamic := g.scalarSize(t); !dynamic {
- g.shift("src", len)
- } else {
- // We can't use shiftDynamic here because we don't have
- // an instance of the dynamic type we can reference here
- // (since the version in this struct is anonymous). Use
- // a typed nil pointer to call SizeBytes() instead.
- g.emit("src = src[(*%s)(nil).SizeBytes():]\n", t.Name)
- g.recordPotentiallyNonPackedField(fmt.Sprintf("(*%s)(nil)", t.Name))
- }
- return
- }
- g.unmarshalStructFieldScalar(g.fieldAccessor(n), t.Name, "src")
- },
- selector: func(n, tX, tSel *ast.Ident) {
- g.unmarshalStructFieldScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "src")
- },
- array: func(n, t *ast.Ident, size int) {
- if n.Name == "_" {
- g.emit("// Padding: ~ copy([%d]%s(%s), src[:sizeof(%s)*%d])\n", size, t.Name, g.fieldAccessor(n), t.Name, size)
- if len, dynamic := g.scalarSize(t); !dynamic {
- g.shift("src", len*size)
- } else {
- // We can't use shiftDynamic here because we don't have
- // an instance of the dynamic type we can referece here
- // (since the version in this struct is anonymous). Use
- // a typed nil pointer to call SizeBytes() instead.
- g.emit("src = src[(*%s)(nil).SizeBytes()*%d:]\n", t.Name, size)
- }
- return
- }
-
- g.emit("for idx := 0; idx < %d; idx++ {\n", size)
- g.inIndent(func() {
- g.unmarshalStructFieldScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "src")
- })
- g.emit("}\n")
- },
- }.dispatch)
- })
- g.emit("}\n\n")
-
- g.emit("// Packed implements marshal.Marshallable.Packed.\n")
- g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName())
- g.inIndent(func() {
- expr, fieldsMaybePacked := g.areFieldsPackedExpression()
- switch {
- case !thisPacked:
- g.emit("return false\n")
- case fieldsMaybePacked:
- g.emit("return %s\n", expr)
- default:
- g.emit("return true\n")
-
- }
- })
- g.emit("}\n\n")
-
- g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n")
- g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName())
- g.inIndent(func() {
- if thisPacked {
- g.recordUsedImport("safecopy")
- g.recordUsedImport("unsafe")
- if cond, ok := g.areFieldsPackedExpression(); ok {
- g.emit("if %s {\n", cond)
- g.inIndent(func() {
- g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r)
- })
- g.emit("} else {\n")
- g.inIndent(func() {
- g.emit("%s.MarshalBytes(dst)\n", g.r)
- })
- g.emit("}\n")
- } else {
- g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r)
- }
- } else {
- g.emit("// Type %s doesn't have a packed layout in memory, fallback to MarshalBytes.\n", g.typeName())
- g.emit("%s.MarshalBytes(dst)\n", g.r)
- }
- })
- g.emit("}\n\n")
-
- g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n")
- g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName())
- g.inIndent(func() {
- if thisPacked {
- g.recordUsedImport("safecopy")
- g.recordUsedImport("unsafe")
- if cond, ok := g.areFieldsPackedExpression(); ok {
- g.emit("if %s {\n", cond)
- g.inIndent(func() {
- g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r)
- })
- g.emit("} else {\n")
- g.inIndent(func() {
- g.emit("%s.UnmarshalBytes(src)\n", g.r)
- })
- g.emit("}\n")
- } else {
- g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r)
- }
- } else {
- g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName())
- g.emit("%s.UnmarshalBytes(src)\n", g.r)
- }
- })
- g.emit("}\n\n")
-
- g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n")
- g.recordUsedImport("marshal")
- g.recordUsedImport("usermem")
- g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName())
- g.inIndent(func() {
- fallback := func() {
- g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName())
- g.emit("buf := task.CopyScratchBuffer(%s.SizeBytes())\n", g.r)
- g.emit("%s.MarshalBytes(buf)\n", g.r)
- g.emit("_, err := task.CopyOutBytes(addr, buf)\n")
- g.emit("return err\n")
- }
- if thisPacked {
- g.recordUsedImport("reflect")
- g.recordUsedImport("runtime")
- g.recordUsedImport("unsafe")
- if cond, ok := g.areFieldsPackedExpression(); ok {
- g.emit("if !%s {\n", cond)
- g.inIndent(fallback)
- g.emit("}\n\n")
- }
- // Fast serialization.
- g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r)
- g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r)
- g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n")
- g.emit("ptr := unsafe.Pointer(%s)\n", g.r)
- g.emit("val := uintptr(ptr)\n")
- g.emit("val = val^0\n\n")
-
- g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r)
- g.emit("var buf []byte\n")
- g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n")
- g.emit("hdr.Data = val\n")
- g.emit("hdr.Len = %s.SizeBytes()\n", g.r)
- g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r)
-
- g.emit("_, err := task.CopyOutBytes(addr, buf)\n")
- g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r)
- g.emit("// must live until after the CopyOutBytes.\n")
- g.emit("runtime.KeepAlive(%s)\n", g.r)
- g.emit("return err\n")
- } else {
- fallback()
- }
- })
- g.emit("}\n\n")
-
- g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n")
- g.recordUsedImport("marshal")
- g.recordUsedImport("usermem")
- g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName())
- g.inIndent(func() {
- fallback := func() {
- g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName())
- g.emit("buf := task.CopyScratchBuffer(%s.SizeBytes())\n", g.r)
- g.emit("_, err := task.CopyInBytes(addr, buf)\n")
- g.emit("if err != nil {\n")
- g.inIndent(func() {
- g.emit("return err\n")
- })
- g.emit("}\n")
-
- g.emit("%s.UnmarshalBytes(buf)\n", g.r)
- g.emit("return nil\n")
- }
- if thisPacked {
- g.recordUsedImport("reflect")
- g.recordUsedImport("runtime")
- g.recordUsedImport("unsafe")
- if cond, ok := g.areFieldsPackedExpression(); ok {
- g.emit("if !%s {\n", cond)
- g.inIndent(fallback)
- g.emit("}\n\n")
- }
- // Fast deserialization.
- g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r)
- g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r)
- g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n")
- g.emit("ptr := unsafe.Pointer(%s)\n", g.r)
- g.emit("val := uintptr(ptr)\n")
- g.emit("val = val^0\n\n")
-
- g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r)
- g.emit("var buf []byte\n")
- g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n")
- g.emit("hdr.Data = val\n")
- g.emit("hdr.Len = %s.SizeBytes()\n", g.r)
- g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r)
-
- g.emit("_, err := task.CopyInBytes(addr, buf)\n")
- g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r)
- g.emit("// must live until after the CopyInBytes.\n")
- g.emit("runtime.KeepAlive(%s)\n", g.r)
- g.emit("return err\n")
- } else {
- fallback()
- }
- })
- g.emit("}\n\n")
-
- g.emit("// WriteTo implements io.WriterTo.WriteTo.\n")
- g.recordUsedImport("io")
- g.emit("func (%s *%s) WriteTo(w io.Writer) (int64, error) {\n", g.r, g.typeName())
- g.inIndent(func() {
- fallback := func() {
- g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName())
- g.emit("buf := make([]byte, %s.SizeBytes())\n", g.r)
- g.emit("%s.MarshalBytes(buf)\n", g.r)
- g.emit("n, err := w.Write(buf)\n")
- g.emit("return int64(n), err\n")
- }
- if thisPacked {
- g.recordUsedImport("reflect")
- g.recordUsedImport("runtime")
- g.recordUsedImport("unsafe")
- if cond, ok := g.areFieldsPackedExpression(); ok {
- g.emit("if !%s {\n", cond)
- g.inIndent(fallback)
- g.emit("}\n\n")
- }
- // Fast serialization.
- g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r)
- g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r)
- g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n")
- g.emit("ptr := unsafe.Pointer(%s)\n", g.r)
- g.emit("val := uintptr(ptr)\n")
- g.emit("val = val^0\n\n")
-
- g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r)
- g.emit("var buf []byte\n")
- g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n")
- g.emit("hdr.Data = val\n")
- g.emit("hdr.Len = %s.SizeBytes()\n", g.r)
- g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r)
-
- g.emit("len, err := w.Write(buf)\n")
- g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r)
- g.emit("// must live until after the Write.\n")
- g.emit("runtime.KeepAlive(%s)\n", g.r)
- g.emit("return int64(len), err\n")
- } else {
- fallback()
- }
- })
- g.emit("}\n\n")
-}
-
-// emitMarshallableForPrimitiveNewtype outputs code to implement the
-// marshal.Marshallable interface for a newtype on a primitive. Primitive
-// newtypes are always packed, so we can omit the various fallbacks required for
-// non-packed structs.
-func (g *interfaceGenerator) emitMarshallableForPrimitiveNewtype() {
- g.recordUsedImport("io")
- g.recordUsedImport("marshal")
- g.recordUsedImport("reflect")
- g.recordUsedImport("runtime")
- g.recordUsedImport("safecopy")
- g.recordUsedImport("unsafe")
- g.recordUsedImport("usermem")
-
- nt := g.t.Type.(*ast.Ident)
-
- g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n")
- g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName())
- g.inIndent(func() {
- if size, dynamic := g.scalarSize(nt); !dynamic {
- g.emit("return %d\n", size)
- } else {
- g.emit("return (*%s)(nil).SizeBytes()\n", nt.Name)
- }
- })
- g.emit("}\n\n")
-
- g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n")
- g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName())
- g.inIndent(func() {
- g.marshalPrimitiveScalar(g.r, nt.Name, "dst")
- })
- g.emit("}\n\n")
-
- g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n")
- g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName())
- g.inIndent(func() {
- g.unmarshalPrimitiveScalar(g.r, nt.Name, "src", g.typeName())
- })
- g.emit("}\n\n")
-
- g.emit("// Packed implements marshal.Marshallable.Packed.\n")
- g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName())
- g.inIndent(func() {
- g.emit("// Scalar newtypes are always packed.\n")
- g.emit("return true\n")
- })
- g.emit("}\n\n")
-
- g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n")
- g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName())
- g.inIndent(func() {
- g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r)
- })
- g.emit("}\n\n")
-
- g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n")
- g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName())
- g.inIndent(func() {
- g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r)
- })
- g.emit("}\n\n")
-
- g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n")
- g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName())
- g.inIndent(func() {
- // Fast serialization.
- g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r)
- g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r)
- g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n")
- g.emit("ptr := unsafe.Pointer(%s)\n", g.r)
- g.emit("val := uintptr(ptr)\n")
- g.emit("val = val^0\n\n")
-
- g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r)
- g.emit("var buf []byte\n")
- g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n")
- g.emit("hdr.Data = val\n")
- g.emit("hdr.Len = %s.SizeBytes()\n", g.r)
- g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r)
-
- g.emit("_, err := task.CopyOutBytes(addr, buf)\n")
- g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r)
- g.emit("// must live until after the CopyOutBytes.\n")
- g.emit("runtime.KeepAlive(%s)\n", g.r)
- g.emit("return err\n")
- })
- g.emit("}\n\n")
-
- g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n")
- g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName())
- g.inIndent(func() {
- g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r)
- g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r)
- g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n")
- g.emit("ptr := unsafe.Pointer(%s)\n", g.r)
- g.emit("val := uintptr(ptr)\n")
- g.emit("val = val^0\n\n")
-
- g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r)
- g.emit("var buf []byte\n")
- g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n")
- g.emit("hdr.Data = val\n")
- g.emit("hdr.Len = %s.SizeBytes()\n", g.r)
- g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r)
-
- g.emit("_, err := task.CopyInBytes(addr, buf)\n")
- g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r)
- g.emit("// must live until after the CopyInBytes.\n")
- g.emit("runtime.KeepAlive(%s)\n", g.r)
- g.emit("return err\n")
- })
- g.emit("}\n\n")
-
- g.emit("// WriteTo implements io.WriterTo.WriteTo.\n")
- g.emit("func (%s *%s) WriteTo(w io.Writer) (int64, error) {\n", g.r, g.typeName())
- g.inIndent(func() {
- g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r)
- g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r)
- g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n")
- g.emit("ptr := unsafe.Pointer(%s)\n", g.r)
- g.emit("val := uintptr(ptr)\n")
- g.emit("val = val^0\n\n")
-
- g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r)
- g.emit("var buf []byte\n")
- g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n")
- g.emit("hdr.Data = val\n")
- g.emit("hdr.Len = %s.SizeBytes()\n", g.r)
- g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r)
-
- g.emit("len, err := w.Write(buf)\n")
- g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r)
- g.emit("// must live until after the Write.\n")
- g.emit("runtime.KeepAlive(%s)\n", g.r)
- g.emit("return int64(len), err\n")
-
- })
- g.emit("}\n\n")
-
-}
diff --git a/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go b/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go
new file mode 100644
index 000000000..da36d9305
--- /dev/null
+++ b/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go
@@ -0,0 +1,183 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// This file contains the bits of the code generator specific to marshalling
+// newtypes on arrays.
+
+package gomarshal
+
+import (
+ "fmt"
+ "go/ast"
+)
+
+func (g *interfaceGenerator) validateArrayNewtype(n *ast.Ident, a *ast.ArrayType) {
+ if a.Len == nil {
+ g.abortAt(a.Pos(), fmt.Sprintf("Dynamically sized slice '%s' cannot be marshalled, arrays must be statically sized", n.Name))
+ }
+
+ if _, ok := a.Len.(*ast.BasicLit); !ok {
+ g.abortAt(a.Len.Pos(), fmt.Sprintf("Array size must be a literal, don't use consts or expressions"))
+ }
+
+ if _, ok := a.Elt.(*ast.Ident); !ok {
+ g.abortAt(a.Elt.Pos(), fmt.Sprintf("Marshalling not supported for arrays with %s elements, array elements must be primitive types", kindString(a.Elt)))
+ }
+
+ if arrayLen(a) <= 0 {
+ g.abortAt(a.Len.Pos(), fmt.Sprintf("Marshalling not supported for zero length arrays, why does an ABI struct have one?"))
+ }
+}
+
+func (g *interfaceGenerator) emitMarshallableForArrayNewtype(n, elt *ast.Ident, len int) {
+ g.recordUsedImport("io")
+ g.recordUsedImport("marshal")
+ g.recordUsedImport("reflect")
+ g.recordUsedImport("runtime")
+ g.recordUsedImport("safecopy")
+ g.recordUsedImport("unsafe")
+ g.recordUsedImport("usermem")
+
+ g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n")
+ g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ if size, dynamic := g.scalarSize(elt); !dynamic {
+ g.emit("return %d\n", size*len)
+ } else {
+ g.emit("return (*%s)(nil).SizeBytes() * %d\n", n.Name, len)
+ }
+ })
+ g.emit("}\n\n")
+
+ g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n")
+ g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("for idx := 0; idx < %d; idx++ {\n", len)
+ g.inIndent(func() {
+ g.marshalScalar(fmt.Sprintf("%s[idx]", g.r), elt.Name, "dst")
+ })
+ g.emit("}\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n")
+ g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("for idx := 0; idx < %d; idx++ {\n", len)
+ g.inIndent(func() {
+ g.unmarshalScalar(fmt.Sprintf("%s[idx]", g.r), elt.Name, "src")
+ })
+ g.emit("}\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// Packed implements marshal.Marshallable.Packed.\n")
+ g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("// Array newtypes are always packed.\n")
+ g.emit("return true\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n")
+ g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r)
+ })
+ g.emit("}\n\n")
+
+ g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n")
+ g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r)
+ })
+ g.emit("}\n\n")
+
+ g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n")
+ g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ // Fast serialization.
+ g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r)
+ g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r)
+ g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n")
+ g.emit("ptr := unsafe.Pointer(%s)\n", g.r)
+ g.emit("val := uintptr(ptr)\n")
+ g.emit("val = val^0\n\n")
+
+ g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r)
+ g.emit("var buf []byte\n")
+ g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n")
+ g.emit("hdr.Data = val\n")
+ g.emit("hdr.Len = %s.SizeBytes()\n", g.r)
+ g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r)
+
+ g.emit("_, err := task.CopyOutBytes(addr, buf)\n")
+ g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r)
+ g.emit("// must live until after the CopyOutBytes.\n")
+ g.emit("runtime.KeepAlive(%s)\n", g.r)
+ g.emit("return err\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n")
+ g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r)
+ g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r)
+ g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n")
+ g.emit("ptr := unsafe.Pointer(%s)\n", g.r)
+ g.emit("val := uintptr(ptr)\n")
+ g.emit("val = val^0\n\n")
+
+ g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r)
+ g.emit("var buf []byte\n")
+ g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n")
+ g.emit("hdr.Data = val\n")
+ g.emit("hdr.Len = %s.SizeBytes()\n", g.r)
+ g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r)
+
+ g.emit("_, err := task.CopyInBytes(addr, buf)\n")
+ g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r)
+ g.emit("// must live until after the CopyInBytes.\n")
+ g.emit("runtime.KeepAlive(%s)\n", g.r)
+ g.emit("return err\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// WriteTo implements io.WriterTo.WriteTo.\n")
+ g.emit("func (%s *%s) WriteTo(w io.Writer) (int64, error) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r)
+ g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r)
+ g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n")
+ g.emit("ptr := unsafe.Pointer(%s)\n", g.r)
+ g.emit("val := uintptr(ptr)\n")
+ g.emit("val = val^0\n\n")
+
+ g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r)
+ g.emit("var buf []byte\n")
+ g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n")
+ g.emit("hdr.Data = val\n")
+ g.emit("hdr.Len = %s.SizeBytes()\n", g.r)
+ g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r)
+
+ g.emit("len, err := w.Write(buf)\n")
+ g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r)
+ g.emit("// must live until after the Write.\n")
+ g.emit("runtime.KeepAlive(%s)\n", g.r)
+ g.emit("return int64(len), err\n")
+
+ })
+ g.emit("}\n\n")
+}
diff --git a/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go b/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go
new file mode 100644
index 000000000..159397825
--- /dev/null
+++ b/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go
@@ -0,0 +1,229 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// This file contains the bits of the code generator specific to marshalling
+// newtypes on primitives.
+
+package gomarshal
+
+import (
+ "fmt"
+ "go/ast"
+)
+
+// marshalPrimitiveScalar writes a single primitive variable to a byte
+// slice.
+func (g *interfaceGenerator) marshalPrimitiveScalar(accessor, typ, bufVar string) {
+ switch typ {
+ case "int8", "uint8", "byte":
+ g.emit("%s[0] = byte(*%s)\n", bufVar, accessor)
+ case "int16", "uint16":
+ g.recordUsedImport("usermem")
+ g.emit("usermem.ByteOrder.PutUint16(%s[:2], uint16(*%s))\n", bufVar, accessor)
+ case "int32", "uint32":
+ g.recordUsedImport("usermem")
+ g.emit("usermem.ByteOrder.PutUint32(%s[:4], uint32(*%s))\n", bufVar, accessor)
+ case "int64", "uint64":
+ g.recordUsedImport("usermem")
+ g.emit("usermem.ByteOrder.PutUint64(%s[:8], uint64(*%s))\n", bufVar, accessor)
+ default:
+ g.emit("// Explicilty cast to the underlying type before dispatching to\n")
+ g.emit("// MarshalBytes, so we don't recursively call %s.MarshalBytes\n", accessor)
+ g.emit("inner := (*%s)(%s)\n", typ, accessor)
+ g.emit("inner.MarshalBytes(%s[:%s.SizeBytes()])\n", bufVar, accessor)
+ }
+}
+
+// unmarshalPrimitiveScalar read a single primitive variable from a byte slice.
+func (g *interfaceGenerator) unmarshalPrimitiveScalar(accessor, typ, bufVar, typeCast string) {
+ switch typ {
+ case "byte":
+ g.emit("*%s = %s(%s[0])\n", accessor, typeCast, bufVar)
+ case "int8", "uint8":
+ g.emit("*%s = %s(%s(%s[0]))\n", accessor, typeCast, typ, bufVar)
+ case "int16", "uint16":
+ g.recordUsedImport("usermem")
+ g.emit("*%s = %s(%s(usermem.ByteOrder.Uint16(%s[:2])))\n", accessor, typeCast, typ, bufVar)
+ case "int32", "uint32":
+ g.recordUsedImport("usermem")
+ g.emit("*%s = %s(%s(usermem.ByteOrder.Uint32(%s[:4])))\n", accessor, typeCast, typ, bufVar)
+ case "int64", "uint64":
+ g.recordUsedImport("usermem")
+ g.emit("*%s = %s(%s(usermem.ByteOrder.Uint64(%s[:8])))\n", accessor, typeCast, typ, bufVar)
+ default:
+ g.emit("// Explicilty cast to the underlying type before dispatching to\n")
+ g.emit("// UnmarshalBytes, so we don't recursively call %s.UnmarshalBytes\n", accessor)
+ g.emit("inner := (*%s)(%s)\n", typ, accessor)
+ g.emit("inner.UnmarshalBytes(%s[:%s.SizeBytes()])\n", bufVar, accessor)
+ }
+}
+
+func (g *interfaceGenerator) validatePrimitiveNewtype(t *ast.Ident) {
+ switch t.Name {
+ case "int8", "uint8", "byte", "int16", "uint16", "int32", "uint32", "int64", "uint64":
+ // These are the only primitive types we're allow. Below, we provide
+ // suggestions for some disallowed types and reject them, then attempt
+ // to marshal any remaining types by invoking the marshal.Marshallable
+ // interface on them. If these types don't actually implement
+ // marshal.Marshallable, compilation of the generated code will fail
+ // with an appropriate error message.
+ return
+ case "int":
+ g.abortAt(t.Pos(), "Type 'int' has ambiguous width, use int32 or int64")
+ case "uint":
+ g.abortAt(t.Pos(), "Type 'uint' has ambiguous width, use uint32 or uint64")
+ case "string":
+ g.abortAt(t.Pos(), "Type 'string' is dynamically-sized and cannot be marshalled, use a fixed size byte array '[...]byte' instead")
+ default:
+ debugfAt(g.f.Position(t.Pos()), fmt.Sprintf("Found derived type '%s', will attempt dispatch via marshal.Marshallable.\n", t.Name))
+ }
+}
+
+// emitMarshallableForPrimitiveNewtype outputs code to implement the
+// marshal.Marshallable interface for a newtype on a primitive. Primitive
+// newtypes are always packed, so we can omit the various fallbacks required for
+// non-packed structs.
+func (g *interfaceGenerator) emitMarshallableForPrimitiveNewtype(nt *ast.Ident) {
+ g.recordUsedImport("io")
+ g.recordUsedImport("marshal")
+ g.recordUsedImport("reflect")
+ g.recordUsedImport("runtime")
+ g.recordUsedImport("safecopy")
+ g.recordUsedImport("unsafe")
+ g.recordUsedImport("usermem")
+
+ g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n")
+ g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ if size, dynamic := g.scalarSize(nt); !dynamic {
+ g.emit("return %d\n", size)
+ } else {
+ g.emit("return (*%s)(nil).SizeBytes()\n", nt.Name)
+ }
+ })
+ g.emit("}\n\n")
+
+ g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n")
+ g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.marshalPrimitiveScalar(g.r, nt.Name, "dst")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n")
+ g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.unmarshalPrimitiveScalar(g.r, nt.Name, "src", g.typeName())
+ })
+ g.emit("}\n\n")
+
+ g.emit("// Packed implements marshal.Marshallable.Packed.\n")
+ g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("// Scalar newtypes are always packed.\n")
+ g.emit("return true\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n")
+ g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r)
+ })
+ g.emit("}\n\n")
+
+ g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n")
+ g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r)
+ })
+ g.emit("}\n\n")
+
+ g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n")
+ g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ // Fast serialization.
+ g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r)
+ g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r)
+ g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n")
+ g.emit("ptr := unsafe.Pointer(%s)\n", g.r)
+ g.emit("val := uintptr(ptr)\n")
+ g.emit("val = val^0\n\n")
+
+ g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r)
+ g.emit("var buf []byte\n")
+ g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n")
+ g.emit("hdr.Data = val\n")
+ g.emit("hdr.Len = %s.SizeBytes()\n", g.r)
+ g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r)
+
+ g.emit("_, err := task.CopyOutBytes(addr, buf)\n")
+ g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r)
+ g.emit("// must live until after the CopyOutBytes.\n")
+ g.emit("runtime.KeepAlive(%s)\n", g.r)
+ g.emit("return err\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n")
+ g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r)
+ g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r)
+ g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n")
+ g.emit("ptr := unsafe.Pointer(%s)\n", g.r)
+ g.emit("val := uintptr(ptr)\n")
+ g.emit("val = val^0\n\n")
+
+ g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r)
+ g.emit("var buf []byte\n")
+ g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n")
+ g.emit("hdr.Data = val\n")
+ g.emit("hdr.Len = %s.SizeBytes()\n", g.r)
+ g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r)
+
+ g.emit("_, err := task.CopyInBytes(addr, buf)\n")
+ g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r)
+ g.emit("// must live until after the CopyInBytes.\n")
+ g.emit("runtime.KeepAlive(%s)\n", g.r)
+ g.emit("return err\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// WriteTo implements io.WriterTo.WriteTo.\n")
+ g.emit("func (%s *%s) WriteTo(w io.Writer) (int64, error) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r)
+ g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r)
+ g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n")
+ g.emit("ptr := unsafe.Pointer(%s)\n", g.r)
+ g.emit("val := uintptr(ptr)\n")
+ g.emit("val = val^0\n\n")
+
+ g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r)
+ g.emit("var buf []byte\n")
+ g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n")
+ g.emit("hdr.Data = val\n")
+ g.emit("hdr.Len = %s.SizeBytes()\n", g.r)
+ g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r)
+
+ g.emit("len, err := w.Write(buf)\n")
+ g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r)
+ g.emit("// must live until after the Write.\n")
+ g.emit("runtime.KeepAlive(%s)\n", g.r)
+ g.emit("return int64(len), err\n")
+
+ })
+ g.emit("}\n\n")
+}
diff --git a/tools/go_marshal/gomarshal/generator_interfaces_struct.go b/tools/go_marshal/gomarshal/generator_interfaces_struct.go
new file mode 100644
index 000000000..e66a38b2e
--- /dev/null
+++ b/tools/go_marshal/gomarshal/generator_interfaces_struct.go
@@ -0,0 +1,450 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// This file contains the bits of the code generator specific to marshalling
+// structs.
+
+package gomarshal
+
+import (
+ "fmt"
+ "go/ast"
+ "strings"
+)
+
+func (g *interfaceGenerator) fieldAccessor(n *ast.Ident) string {
+ return fmt.Sprintf("%s.%s", g.r, n.Name)
+}
+
+// areFieldsPackedExpression returns a go expression checking whether g.t's fields are
+// packed. Returns "", false if g.t has no fields that may be potentially
+// packed, otherwise returns <clause>, true, where <clause> is an expression
+// like "t.a.Packed() && t.b.Packed() && t.c.Packed()".
+func (g *interfaceGenerator) areFieldsPackedExpression() (string, bool) {
+ if len(g.as) == 0 {
+ return "", false
+ }
+
+ cs := make([]string, 0, len(g.as))
+ for accessor, _ := range g.as {
+ cs = append(cs, fmt.Sprintf("%s.Packed()", accessor))
+ }
+ return strings.Join(cs, " && "), true
+}
+
+// validateStruct ensures the type we're working with can be marshalled. These
+// checks are done ahead of time and in one place so we can make assumptions
+// later.
+func (g *interfaceGenerator) validateStruct(ts *ast.TypeSpec, st *ast.StructType) {
+ forEachStructField(st, func(f *ast.Field) {
+ if len(f.Names) == 0 {
+ g.abortAt(f.Pos(), "Cannot marshal structs with embedded fields, give the field a name; use '_' for anonymous fields such as padding fields")
+ }
+ })
+
+ forEachStructField(st, func(f *ast.Field) {
+ fieldDispatcher{
+ primitive: func(_, t *ast.Ident) {
+ g.validatePrimitiveNewtype(t)
+ },
+ selector: func(_, _, _ *ast.Ident) {
+ // No validation to perform on selector fields. However this
+ // callback must still be provided.
+ },
+ array: func(n, _ *ast.Ident, len int) {
+ g.validateArrayNewtype(n, f.Type.(*ast.ArrayType))
+ },
+ unhandled: func(_ *ast.Ident) {
+ g.abortAt(f.Pos(), fmt.Sprintf("Marshalling not supported for %s fields", kindString(f.Type)))
+ },
+ }.dispatch(f)
+ })
+}
+
+func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) {
+ // Is g.t a packed struct without consideing field types?
+ thisPacked := true
+ forEachStructField(st, func(f *ast.Field) {
+ if f.Tag != nil {
+ if f.Tag.Value == "`marshal:\"unaligned\"`" {
+ if thisPacked {
+ debugfAt(g.f.Position(g.t.Pos()),
+ fmt.Sprintf("Marking type '%s' as not packed due to tag `marshal:\"unaligned\"`.\n", g.t.Name))
+ thisPacked = false
+ }
+ }
+ }
+ })
+
+ g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n")
+ g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ primitiveSize := 0
+ var dynamicSizeTerms []string
+
+ forEachStructField(st, fieldDispatcher{
+ primitive: func(n, t *ast.Ident) {
+ if size, dynamic := g.scalarSize(t); !dynamic {
+ primitiveSize += size
+ } else {
+ g.recordUsedMarshallable(t.Name)
+ dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", t.Name))
+ }
+ },
+ selector: func(n, tX, tSel *ast.Ident) {
+ tName := fmt.Sprintf("%s.%s", tX.Name, tSel.Name)
+ g.recordUsedImport(tX.Name)
+ g.recordUsedMarshallable(tName)
+ dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", tName))
+ },
+ array: func(n, t *ast.Ident, len int) {
+ if len < 1 {
+ // Zero-length arrays should've been rejected by validate().
+ panic("unreachable")
+ }
+ if size, dynamic := g.scalarSize(t); !dynamic {
+ primitiveSize += size * len
+ } else {
+ g.recordUsedMarshallable(t.Name)
+ dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()*%d", t.Name, len))
+ }
+ },
+ }.dispatch)
+ g.emit("return %d", primitiveSize)
+ if len(dynamicSizeTerms) > 0 {
+ g.incIndent()
+ }
+ {
+ for _, d := range dynamicSizeTerms {
+ g.emitNoIndent(" +\n")
+ g.emit(d)
+ }
+ }
+ if len(dynamicSizeTerms) > 0 {
+ g.decIndent()
+ }
+ })
+ g.emit("\n}\n\n")
+
+ g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n")
+ g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ forEachStructField(st, fieldDispatcher{
+ primitive: func(n, t *ast.Ident) {
+ if n.Name == "_" {
+ g.emit("// Padding: dst[:sizeof(%s)] ~= %s(0)\n", t.Name, t.Name)
+ if len, dynamic := g.scalarSize(t); !dynamic {
+ g.shift("dst", len)
+ } else {
+ // We can't use shiftDynamic here because we don't have
+ // an instance of the dynamic type we can referece here
+ // (since the version in this struct is anonymous). Use
+ // a typed nil pointer to call SizeBytes() instead.
+ g.emit("dst = dst[(*%s)(nil).SizeBytes():]\n", t.Name)
+ }
+ return
+ }
+ g.marshalScalar(g.fieldAccessor(n), t.Name, "dst")
+ },
+ selector: func(n, tX, tSel *ast.Ident) {
+ g.marshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "dst")
+ },
+ array: func(n, t *ast.Ident, size int) {
+ if n.Name == "_" {
+ g.emit("// Padding: dst[:sizeof(%s)*%d] ~= [%d]%s{0}\n", t.Name, size, size, t.Name)
+ if len, dynamic := g.scalarSize(t); !dynamic {
+ g.shift("dst", len*size)
+ } else {
+ // We can't use shiftDynamic here because we don't have
+ // an instance of the dynamic type we can reference here
+ // (since the version in this struct is anonymous). Use
+ // a typed nil pointer to call SizeBytes() instead.
+ g.emit("dst = dst[(*%s)(nil).SizeBytes()*%d:]\n", t.Name, size)
+ }
+ return
+ }
+
+ g.emit("for idx := 0; idx < %d; idx++ {\n", size)
+ g.inIndent(func() {
+ g.marshalScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "dst")
+ })
+ g.emit("}\n")
+ },
+ }.dispatch)
+ })
+ g.emit("}\n\n")
+
+ g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n")
+ g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ forEachStructField(st, fieldDispatcher{
+ primitive: func(n, t *ast.Ident) {
+ if n.Name == "_" {
+ g.emit("// Padding: var _ %s ~= src[:sizeof(%s)]\n", t.Name, t.Name)
+ if len, dynamic := g.scalarSize(t); !dynamic {
+ g.shift("src", len)
+ } else {
+ // We can't use shiftDynamic here because we don't have
+ // an instance of the dynamic type we can reference here
+ // (since the version in this struct is anonymous). Use
+ // a typed nil pointer to call SizeBytes() instead.
+ g.emit("src = src[(*%s)(nil).SizeBytes():]\n", t.Name)
+ g.recordPotentiallyNonPackedField(fmt.Sprintf("(*%s)(nil)", t.Name))
+ }
+ return
+ }
+ g.unmarshalScalar(g.fieldAccessor(n), t.Name, "src")
+ },
+ selector: func(n, tX, tSel *ast.Ident) {
+ g.unmarshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "src")
+ },
+ array: func(n, t *ast.Ident, size int) {
+ if n.Name == "_" {
+ g.emit("// Padding: ~ copy([%d]%s(%s), src[:sizeof(%s)*%d])\n", size, t.Name, g.fieldAccessor(n), t.Name, size)
+ if len, dynamic := g.scalarSize(t); !dynamic {
+ g.shift("src", len*size)
+ } else {
+ // We can't use shiftDynamic here because we don't have
+ // an instance of the dynamic type we can referece here
+ // (since the version in this struct is anonymous). Use
+ // a typed nil pointer to call SizeBytes() instead.
+ g.emit("src = src[(*%s)(nil).SizeBytes()*%d:]\n", t.Name, size)
+ }
+ return
+ }
+
+ g.emit("for idx := 0; idx < %d; idx++ {\n", size)
+ g.inIndent(func() {
+ g.unmarshalScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "src")
+ })
+ g.emit("}\n")
+ },
+ }.dispatch)
+ })
+ g.emit("}\n\n")
+
+ g.emit("// Packed implements marshal.Marshallable.Packed.\n")
+ g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ expr, fieldsMaybePacked := g.areFieldsPackedExpression()
+ switch {
+ case !thisPacked:
+ g.emit("return false\n")
+ case fieldsMaybePacked:
+ g.emit("return %s\n", expr)
+ default:
+ g.emit("return true\n")
+
+ }
+ })
+ g.emit("}\n\n")
+
+ g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n")
+ g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ if thisPacked {
+ g.recordUsedImport("safecopy")
+ g.recordUsedImport("unsafe")
+ if cond, ok := g.areFieldsPackedExpression(); ok {
+ g.emit("if %s {\n", cond)
+ g.inIndent(func() {
+ g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r)
+ })
+ g.emit("} else {\n")
+ g.inIndent(func() {
+ g.emit("%s.MarshalBytes(dst)\n", g.r)
+ })
+ g.emit("}\n")
+ } else {
+ g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r)
+ }
+ } else {
+ g.emit("// Type %s doesn't have a packed layout in memory, fallback to MarshalBytes.\n", g.typeName())
+ g.emit("%s.MarshalBytes(dst)\n", g.r)
+ }
+ })
+ g.emit("}\n\n")
+
+ g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n")
+ g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ if thisPacked {
+ g.recordUsedImport("safecopy")
+ g.recordUsedImport("unsafe")
+ if cond, ok := g.areFieldsPackedExpression(); ok {
+ g.emit("if %s {\n", cond)
+ g.inIndent(func() {
+ g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r)
+ })
+ g.emit("} else {\n")
+ g.inIndent(func() {
+ g.emit("%s.UnmarshalBytes(src)\n", g.r)
+ })
+ g.emit("}\n")
+ } else {
+ g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r)
+ }
+ } else {
+ g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName())
+ g.emit("%s.UnmarshalBytes(src)\n", g.r)
+ }
+ })
+ g.emit("}\n\n")
+
+ g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n")
+ g.recordUsedImport("marshal")
+ g.recordUsedImport("usermem")
+ g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ fallback := func() {
+ g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName())
+ g.emit("buf := task.CopyScratchBuffer(%s.SizeBytes())\n", g.r)
+ g.emit("%s.MarshalBytes(buf)\n", g.r)
+ g.emit("_, err := task.CopyOutBytes(addr, buf)\n")
+ g.emit("return err\n")
+ }
+ if thisPacked {
+ g.recordUsedImport("reflect")
+ g.recordUsedImport("runtime")
+ g.recordUsedImport("unsafe")
+ if cond, ok := g.areFieldsPackedExpression(); ok {
+ g.emit("if !%s {\n", cond)
+ g.inIndent(fallback)
+ g.emit("}\n\n")
+ }
+ // Fast serialization.
+ g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r)
+ g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r)
+ g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n")
+ g.emit("ptr := unsafe.Pointer(%s)\n", g.r)
+ g.emit("val := uintptr(ptr)\n")
+ g.emit("val = val^0\n\n")
+
+ g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r)
+ g.emit("var buf []byte\n")
+ g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n")
+ g.emit("hdr.Data = val\n")
+ g.emit("hdr.Len = %s.SizeBytes()\n", g.r)
+ g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r)
+
+ g.emit("_, err := task.CopyOutBytes(addr, buf)\n")
+ g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r)
+ g.emit("// must live until after the CopyOutBytes.\n")
+ g.emit("runtime.KeepAlive(%s)\n", g.r)
+ g.emit("return err\n")
+ } else {
+ fallback()
+ }
+ })
+ g.emit("}\n\n")
+
+ g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n")
+ g.recordUsedImport("marshal")
+ g.recordUsedImport("usermem")
+ g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ fallback := func() {
+ g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName())
+ g.emit("buf := task.CopyScratchBuffer(%s.SizeBytes())\n", g.r)
+ g.emit("_, err := task.CopyInBytes(addr, buf)\n")
+ g.emit("if err != nil {\n")
+ g.inIndent(func() {
+ g.emit("return err\n")
+ })
+ g.emit("}\n")
+
+ g.emit("%s.UnmarshalBytes(buf)\n", g.r)
+ g.emit("return nil\n")
+ }
+ if thisPacked {
+ g.recordUsedImport("reflect")
+ g.recordUsedImport("runtime")
+ g.recordUsedImport("unsafe")
+ if cond, ok := g.areFieldsPackedExpression(); ok {
+ g.emit("if !%s {\n", cond)
+ g.inIndent(fallback)
+ g.emit("}\n\n")
+ }
+ // Fast deserialization.
+ g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r)
+ g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r)
+ g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n")
+ g.emit("ptr := unsafe.Pointer(%s)\n", g.r)
+ g.emit("val := uintptr(ptr)\n")
+ g.emit("val = val^0\n\n")
+
+ g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r)
+ g.emit("var buf []byte\n")
+ g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n")
+ g.emit("hdr.Data = val\n")
+ g.emit("hdr.Len = %s.SizeBytes()\n", g.r)
+ g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r)
+
+ g.emit("_, err := task.CopyInBytes(addr, buf)\n")
+ g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r)
+ g.emit("// must live until after the CopyInBytes.\n")
+ g.emit("runtime.KeepAlive(%s)\n", g.r)
+ g.emit("return err\n")
+ } else {
+ fallback()
+ }
+ })
+ g.emit("}\n\n")
+
+ g.emit("// WriteTo implements io.WriterTo.WriteTo.\n")
+ g.recordUsedImport("io")
+ g.emit("func (%s *%s) WriteTo(w io.Writer) (int64, error) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ fallback := func() {
+ g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName())
+ g.emit("buf := make([]byte, %s.SizeBytes())\n", g.r)
+ g.emit("%s.MarshalBytes(buf)\n", g.r)
+ g.emit("n, err := w.Write(buf)\n")
+ g.emit("return int64(n), err\n")
+ }
+ if thisPacked {
+ g.recordUsedImport("reflect")
+ g.recordUsedImport("runtime")
+ g.recordUsedImport("unsafe")
+ if cond, ok := g.areFieldsPackedExpression(); ok {
+ g.emit("if !%s {\n", cond)
+ g.inIndent(fallback)
+ g.emit("}\n\n")
+ }
+ // Fast serialization.
+ g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r)
+ g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r)
+ g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n")
+ g.emit("ptr := unsafe.Pointer(%s)\n", g.r)
+ g.emit("val := uintptr(ptr)\n")
+ g.emit("val = val^0\n\n")
+
+ g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r)
+ g.emit("var buf []byte\n")
+ g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n")
+ g.emit("hdr.Data = val\n")
+ g.emit("hdr.Len = %s.SizeBytes()\n", g.r)
+ g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r)
+
+ g.emit("len, err := w.Write(buf)\n")
+ g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r)
+ g.emit("// must live until after the Write.\n")
+ g.emit("runtime.KeepAlive(%s)\n", g.r)
+ g.emit("return int64(len), err\n")
+ } else {
+ fallback()
+ }
+ })
+ g.emit("}\n\n")
+}
diff --git a/tools/go_marshal/gomarshal/generator_tests.go b/tools/go_marshal/gomarshal/generator_tests.go
index 8ba47eb67..fd992e44a 100644
--- a/tools/go_marshal/gomarshal/generator_tests.go
+++ b/tools/go_marshal/gomarshal/generator_tests.go
@@ -164,7 +164,7 @@ func (g *testGenerator) emitTestSizeBytesOnTypedNilPtr() {
g.emit("if sizeFromTypedNilPtr != sizeFromConcrete {\n")
g.inIndent(func() {
- g.emit("t.Fatalf(\"SizeBytes() on typed nil pointer (%v) doesn't match size returned by a concrete object (%v).\\n\", sizeFromTypedNilPtr, sizeFromConcrete)")
+ g.emit("t.Fatalf(\"SizeBytes() on typed nil pointer (%v) doesn't match size returned by a concrete object (%v).\\n\", sizeFromTypedNilPtr, sizeFromConcrete)\n")
})
g.emit("}\n")
})
diff --git a/tools/go_marshal/gomarshal/util.go b/tools/go_marshal/gomarshal/util.go
index e2bca4e7c..a0936e013 100644
--- a/tools/go_marshal/gomarshal/util.go
+++ b/tools/go_marshal/gomarshal/util.go
@@ -64,6 +64,12 @@ func kindString(e ast.Expr) string {
}
}
+func forEachStructField(st *ast.StructType, fn func(f *ast.Field)) {
+ for _, field := range st.Fields.List {
+ fn(field)
+ }
+}
+
// fieldDispatcher is a collection of callbacks for handling different types of
// fields in a struct declaration.
type fieldDispatcher struct {
@@ -73,6 +79,25 @@ type fieldDispatcher struct {
unhandled func(n *ast.Ident)
}
+// Precondition: a must have a literal for the array length. Consts and
+// expressions are not allowed as array lengths, and should be rejected by the
+// caller.
+func arrayLen(a *ast.ArrayType) int {
+ if a.Len == nil {
+ // Probably a slice? Must be handled by caller.
+ panic("Nil array length in array type")
+ }
+ lenLit, ok := a.Len.(*ast.BasicLit)
+ if !ok {
+ panic("Array has non-literal for length")
+ }
+ len, err := strconv.Atoi(lenLit.Value)
+ if err != nil {
+ panic(fmt.Sprintf("Failed to parse array length '%s' as number: %v", lenLit.Value, err))
+ }
+ return len
+}
+
// Precondition: All dispatch callbacks that will be invoked must be
// provided. Embedded fields are not allowed, len(f.Names) >= 1.
func (fd fieldDispatcher) dispatch(f *ast.Field) {
@@ -96,22 +121,12 @@ func (fd fieldDispatcher) dispatch(f *ast.Field) {
case *ast.SelectorExpr:
fd.selector(name, v.X.(*ast.Ident), v.Sel)
case *ast.ArrayType:
- len := 0
- if v.Len != nil {
- // Non-literal array length is handled by generatorInterfaces.validate().
- if lenLit, ok := v.Len.(*ast.BasicLit); ok {
- var err error
- len, err = strconv.Atoi(lenLit.Value)
- if err != nil {
- panic(err)
- }
- }
- }
switch t := v.Elt.(type) {
case *ast.Ident:
- fd.array(name, t, len)
+ fd.array(name, t, arrayLen(v))
default:
- fd.array(name, nil, len)
+ // Should be handled with a better error message during validate.
+ panic(fmt.Sprintf("Array element type is of unsupported kind. Expected *ast.Ident, got %v", t))
}
default:
fd.unhandled(name)
diff --git a/tools/go_marshal/test/test.go b/tools/go_marshal/test/test.go
index 93229dedb..c829db6da 100644
--- a/tools/go_marshal/test/test.go
+++ b/tools/go_marshal/test/test.go
@@ -104,6 +104,11 @@ type Stat struct {
_ [3]int64
}
+// InetAddr is an example marshallable newtype on an array.
+//
+// +marshal
+type InetAddr [4]byte
+
// SignalSet is an example marshallable newtype on a primitive.
//
// +marshal