summaryrefslogtreecommitdiffhomepage
path: root/tools/go_marshal/gomarshal/generator_interfaces.go
diff options
context:
space:
mode:
authorgVisor bot <gvisor-bot@google.com>2020-02-20 16:22:45 -0800
committerCopybara-Service <copybara-worker@google.com>2020-02-20 16:23:56 -0800
commitf1b72752e5de2abc3c409a6b7447224620b7c11b (patch)
tree12f20bad0a23969311401ddcf707588c6b048424 /tools/go_marshal/gomarshal/generator_interfaces.go
parent4a73bae269ae9f52a962ae3b08a17ccaacf7ba80 (diff)
Implement automated marshalling for newtypes on primitives.
PiperOrigin-RevId: 296322954
Diffstat (limited to 'tools/go_marshal/gomarshal/generator_interfaces.go')
-rw-r--r--tools/go_marshal/gomarshal/generator_interfaces.go296
1 files changed, 233 insertions, 63 deletions
diff --git a/tools/go_marshal/gomarshal/generator_interfaces.go b/tools/go_marshal/gomarshal/generator_interfaces.go
index 834c58cee..ea1af998e 100644
--- a/tools/go_marshal/gomarshal/generator_interfaces.go
+++ b/tools/go_marshal/gomarshal/generator_interfaces.go
@@ -55,9 +55,6 @@ func (g *interfaceGenerator) typeName() string {
// newinterfaceGenerator creates a new interface generator.
func newInterfaceGenerator(t *ast.TypeSpec, fset *token.FileSet) *interfaceGenerator {
- if _, ok := t.Type.(*ast.StructType); !ok {
- panic(fmt.Sprintf("Attempting to generate code for a not struct type %v", t))
- }
g := &interfaceGenerator{
t: t,
r: receiverName(t),
@@ -103,9 +100,31 @@ func (g *interfaceGenerator) abortAt(p token.Pos, msg string) {
abortAt(g.f.Position(p), msg)
}
-// validate ensures the type we're working with can be marshalled. These checks
-// are done ahead of time and in one place so we can make assumptions later.
-func (g *interfaceGenerator) validate() {
+func (g *interfaceGenerator) validatePrimitiveNewtype(t *ast.Ident) {
+ switch t.Name {
+ case "int8", "uint8", "byte", "int16", "uint16", "int32", "uint32", "int64", "uint64":
+ // These are the only primitive types we're allow. Below, we provide
+ // suggestions for some disallowed types and reject them, then attempt
+ // to marshal any remaining types by invoking the marshal.Marshallable
+ // interface on them. If these types don't actually implement
+ // marshal.Marshallable, compilation of the generated code will fail
+ // with an appropriate error message.
+ return
+ case "int":
+ g.abortAt(t.Pos(), "Type 'int' has ambiguous width, use int32 or int64")
+ case "uint":
+ g.abortAt(t.Pos(), "Type 'uint' has ambiguous width, use uint32 or uint64")
+ case "string":
+ g.abortAt(t.Pos(), "Type 'string' is dynamically-sized and cannot be marshalled, use a fixed size byte array '[...]byte' instead")
+ default:
+ debugfAt(g.f.Position(t.Pos()), fmt.Sprintf("Found derived type '%s', will attempt dispatch via marshal.Marshallable.\n", t.Name))
+ }
+}
+
+// validateStruct ensures the type we're working with can be marshalled. These
+// checks are done ahead of time and in one place so we can make assumptions
+// later.
+func (g *interfaceGenerator) validateStruct() {
g.forEachField(func(f *ast.Field) {
if len(f.Names) == 0 {
g.abortAt(f.Pos(), "Cannot marshal structs with embedded fields, give the field a name; use '_' for anonymous fields such as padding fields")
@@ -115,25 +134,7 @@ func (g *interfaceGenerator) validate() {
g.forEachField(func(f *ast.Field) {
fieldDispatcher{
primitive: func(_, t *ast.Ident) {
- switch t.Name {
- case "int8", "uint8", "byte", "int16", "uint16", "int32", "uint32", "int64", "uint64":
- // These are the only primitive types we're allow. Below, we
- // provide suggestions for some disallowed types and reject
- // them, then attempt to marshal any remaining types by
- // invoking the marshal.Marshallable interface on them. If
- // these types don't actually implement
- // marshal.Marshallable, compilation of the generated code
- // will fail with an appropriate error message.
- return
- case "int":
- g.abortAt(f.Pos(), "Type 'int' has ambiguous width, use int32 or int64")
- case "uint":
- g.abortAt(f.Pos(), "Type 'uint' has ambiguous width, use uint32 or uint64")
- case "string":
- g.abortAt(f.Pos(), "Type 'string' is dynamically-sized and cannot be marshalled, use a fixed size byte array '[...]byte' instead")
- default:
- debugfAt(g.f.Position(f.Pos()), fmt.Sprintf("Found derived type '%s', will attempt dispatch via marshal.Marshallable.\n", t.Name))
- }
+ g.validatePrimitiveNewtype(t)
},
selector: func(_, _, _ *ast.Ident) {
// No validation to perform on selector fields. However this
@@ -190,7 +191,8 @@ func (g *interfaceGenerator) shiftDynamic(bufVar, name string) {
g.emit("%s = %s[%s.SizeBytes():]\n", bufVar, bufVar, name)
}
-func (g *interfaceGenerator) marshalScalar(accessor, typ string, bufVar string) {
+// marshalStructFieldScalar writes a single scalar field from a struct to a byte slice.
+func (g *interfaceGenerator) marshalStructFieldScalar(accessor, typ, bufVar string) {
switch typ {
case "int8", "uint8", "byte":
g.emit("%s[0] = byte(%s)\n", bufVar, accessor)
@@ -213,43 +215,27 @@ func (g *interfaceGenerator) marshalScalar(accessor, typ string, bufVar string)
}
}
-func (g *interfaceGenerator) unmarshalScalar(accessor, typ string, bufVar string) {
+// unmarshalStructFieldScalar reads a single scalar field from a struct, from a
+// byte slice.
+func (g *interfaceGenerator) unmarshalStructFieldScalar(accessor, typ, bufVar string) {
switch typ {
- case "int8":
- g.emit("%s = int8(%s[0])\n", accessor, bufVar)
- g.shift(bufVar, 1)
- case "uint8":
- g.emit("%s = uint8(%s[0])\n", accessor, bufVar)
- g.shift(bufVar, 1)
case "byte":
g.emit("%s = %s[0]\n", accessor, bufVar)
g.shift(bufVar, 1)
-
- case "int16":
- g.recordUsedImport("usermem")
- g.emit("%s = int16(usermem.ByteOrder.Uint16(%s[:2]))\n", accessor, bufVar)
- g.shift(bufVar, 2)
- case "uint16":
+ case "int8", "uint8":
+ g.emit("%s = %s(%s[0])\n", accessor, typ, bufVar)
+ g.shift(bufVar, 1)
+ case "int16", "uint16":
g.recordUsedImport("usermem")
- g.emit("%s = usermem.ByteOrder.Uint16(%s[:2])\n", accessor, bufVar)
+ g.emit("%s = %s(usermem.ByteOrder.Uint16(%s[:2]))\n", accessor, typ, bufVar)
g.shift(bufVar, 2)
-
- case "int32":
- g.recordUsedImport("usermem")
- g.emit("%s = int32(usermem.ByteOrder.Uint32(%s[:4]))\n", accessor, bufVar)
- g.shift(bufVar, 4)
- case "uint32":
+ case "int32", "uint32":
g.recordUsedImport("usermem")
- g.emit("%s = usermem.ByteOrder.Uint32(%s[:4])\n", accessor, bufVar)
+ g.emit("%s = %s(usermem.ByteOrder.Uint32(%s[:4]))\n", accessor, typ, bufVar)
g.shift(bufVar, 4)
-
- case "int64":
- g.recordUsedImport("usermem")
- g.emit("%s = int64(usermem.ByteOrder.Uint64(%s[:8]))\n", accessor, bufVar)
- g.shift(bufVar, 8)
- case "uint64":
+ case "int64", "uint64":
g.recordUsedImport("usermem")
- g.emit("%s = usermem.ByteOrder.Uint64(%s[:8])\n", accessor, bufVar)
+ g.emit("%s = %s(usermem.ByteOrder.Uint64(%s[:8]))\n", accessor, typ, bufVar)
g.shift(bufVar, 8)
default:
g.emit("%s.UnmarshalBytes(%s[:%s.SizeBytes()])\n", accessor, bufVar, accessor)
@@ -258,6 +244,49 @@ func (g *interfaceGenerator) unmarshalScalar(accessor, typ string, bufVar string
}
}
+// marshalPrimitiveScalar writes a single primitive variable to a byte slice.
+func (g *interfaceGenerator) marshalPrimitiveScalar(accessor, typ, bufVar string) {
+ switch typ {
+ case "int8", "uint8", "byte":
+ g.emit("%s[0] = byte(*%s)\n", bufVar, accessor)
+ case "int16", "uint16":
+ g.recordUsedImport("usermem")
+ g.emit("usermem.ByteOrder.PutUint16(%s[:2], uint16(*%s))\n", bufVar, accessor)
+ case "int32", "uint32":
+ g.recordUsedImport("usermem")
+ g.emit("usermem.ByteOrder.PutUint32(%s[:4], uint32(*%s))\n", bufVar, accessor)
+ case "int64", "uint64":
+ g.recordUsedImport("usermem")
+ g.emit("usermem.ByteOrder.PutUint64(%s[:8], uint64(*%s))\n", bufVar, accessor)
+ default:
+ g.emit("inner := (*%s)(%s)\n", typ, accessor)
+ g.emit("inner.MarshalBytes(%s[:%s.SizeBytes()])\n", bufVar, accessor)
+ }
+}
+
+// unmarshalPrimitiveScalar read a single primitive variable from a byte slice.
+func (g *interfaceGenerator) unmarshalPrimitiveScalar(accessor, typ, bufVar, typeCast string) {
+ switch typ {
+ case "byte":
+ g.emit("*%s = %s(%s[0])\n", accessor, typeCast, bufVar)
+ case "int8", "uint8":
+ g.emit("*%s = %s(%s(%s[0]))\n", accessor, typeCast, typ, bufVar)
+ case "int16", "uint16":
+ g.recordUsedImport("usermem")
+ g.emit("*%s = %s(%s(usermem.ByteOrder.Uint16(%s[:2])))\n", accessor, typeCast, typ, bufVar)
+ case "int32", "uint32":
+ g.recordUsedImport("usermem")
+ g.emit("*%s = %s(%s(usermem.ByteOrder.Uint32(%s[:4])))\n", accessor, typeCast, typ, bufVar)
+
+ case "int64", "uint64":
+ g.recordUsedImport("usermem")
+ g.emit("*%s = %s(%s(usermem.ByteOrder.Uint64(%s[:8])))\n", accessor, typeCast, typ, bufVar)
+ default:
+ g.emit("inner := (*%s)(%s)\n", typ, accessor)
+ g.emit("inner.UnmarshalBytes(%s[:%s.SizeBytes()])\n", bufVar, accessor)
+ }
+}
+
// areFieldsPackedExpression returns a go expression checking whether g.t's fields are
// packed. Returns "", false if g.t has no fields that may be potentially
// packed, otherwise returns <clause>, true, where <clause> is an expression
@@ -274,7 +303,7 @@ func (g *interfaceGenerator) areFieldsPackedExpression() (string, bool) {
return strings.Join(cs, " && "), true
}
-func (g *interfaceGenerator) emitMarshallable() {
+func (g *interfaceGenerator) emitMarshallableForStruct() {
// Is g.t a packed struct without consideing field types?
thisPacked := true
g.forEachField(func(f *ast.Field) {
@@ -357,10 +386,10 @@ func (g *interfaceGenerator) emitMarshallable() {
}
return
}
- g.marshalScalar(g.fieldAccessor(n), t.Name, "dst")
+ g.marshalStructFieldScalar(g.fieldAccessor(n), t.Name, "dst")
},
selector: func(n, tX, tSel *ast.Ident) {
- g.marshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "dst")
+ g.marshalStructFieldScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "dst")
},
array: func(n, t *ast.Ident, size int) {
if n.Name == "_" {
@@ -377,9 +406,9 @@ func (g *interfaceGenerator) emitMarshallable() {
return
}
- g.emit("for i := 0; i < %d; i++ {\n", size)
+ g.emit("for idx := 0; idx < %d; idx++ {\n", size)
g.inIndent(func() {
- g.marshalScalar(fmt.Sprintf("%s[i]", g.fieldAccessor(n)), t.Name, "dst")
+ g.marshalStructFieldScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "dst")
})
g.emit("}\n")
},
@@ -406,10 +435,10 @@ func (g *interfaceGenerator) emitMarshallable() {
}
return
}
- g.unmarshalScalar(g.fieldAccessor(n), t.Name, "src")
+ g.unmarshalStructFieldScalar(g.fieldAccessor(n), t.Name, "src")
},
selector: func(n, tX, tSel *ast.Ident) {
- g.unmarshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "src")
+ g.unmarshalStructFieldScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "src")
},
array: func(n, t *ast.Ident, size int) {
if n.Name == "_" {
@@ -426,9 +455,9 @@ func (g *interfaceGenerator) emitMarshallable() {
return
}
- g.emit("for i := 0; i < %d; i++ {\n", size)
+ g.emit("for idx := 0; idx < %d; idx++ {\n", size)
g.inIndent(func() {
- g.unmarshalScalar(fmt.Sprintf("%s[i]", g.fieldAccessor(n)), t.Name, "src")
+ g.unmarshalStructFieldScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "src")
})
g.emit("}\n")
},
@@ -650,3 +679,144 @@ func (g *interfaceGenerator) emitMarshallable() {
})
g.emit("}\n\n")
}
+
+// emitMarshallableForPrimitiveNewtype outputs code to implement the
+// marshal.Marshallable interface for a newtype on a primitive. Primitive
+// newtypes are always packed, so we can omit the various fallbacks required for
+// non-packed structs.
+func (g *interfaceGenerator) emitMarshallableForPrimitiveNewtype() {
+ g.recordUsedImport("io")
+ g.recordUsedImport("marshal")
+ g.recordUsedImport("reflect")
+ g.recordUsedImport("runtime")
+ g.recordUsedImport("safecopy")
+ g.recordUsedImport("unsafe")
+ g.recordUsedImport("usermem")
+
+ nt := g.t.Type.(*ast.Ident)
+
+ g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n")
+ g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ if size, dynamic := g.scalarSize(nt); !dynamic {
+ g.emit("return %d\n", size)
+ } else {
+ g.emit("return (*%s)(nil).SizeBytes()\n", nt.Name)
+ }
+ })
+ g.emit("}\n\n")
+
+ g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n")
+ g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.marshalPrimitiveScalar(g.r, nt.Name, "dst")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n")
+ g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.unmarshalPrimitiveScalar(g.r, nt.Name, "src", g.typeName())
+ })
+ g.emit("}\n\n")
+
+ g.emit("// Packed implements marshal.Marshallable.Packed.\n")
+ g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("// Scalar newtypes are always packed.\n")
+ g.emit("return true\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n")
+ g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r)
+ })
+ g.emit("}\n\n")
+
+ g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n")
+ g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r)
+ })
+ g.emit("}\n\n")
+
+ g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n")
+ g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ // Fast serialization.
+ g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r)
+ g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r)
+ g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n")
+ g.emit("ptr := unsafe.Pointer(%s)\n", g.r)
+ g.emit("val := uintptr(ptr)\n")
+ g.emit("val = val^0\n\n")
+
+ g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r)
+ g.emit("var buf []byte\n")
+ g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n")
+ g.emit("hdr.Data = val\n")
+ g.emit("hdr.Len = %s.SizeBytes()\n", g.r)
+ g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r)
+
+ g.emit("_, err := task.CopyOutBytes(addr, buf)\n")
+ g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r)
+ g.emit("// must live until after the CopyOutBytes.\n")
+ g.emit("runtime.KeepAlive(%s)\n", g.r)
+ g.emit("return err\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n")
+ g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r)
+ g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r)
+ g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n")
+ g.emit("ptr := unsafe.Pointer(%s)\n", g.r)
+ g.emit("val := uintptr(ptr)\n")
+ g.emit("val = val^0\n\n")
+
+ g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r)
+ g.emit("var buf []byte\n")
+ g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n")
+ g.emit("hdr.Data = val\n")
+ g.emit("hdr.Len = %s.SizeBytes()\n", g.r)
+ g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r)
+
+ g.emit("_, err := task.CopyInBytes(addr, buf)\n")
+ g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r)
+ g.emit("// must live until after the CopyInBytes.\n")
+ g.emit("runtime.KeepAlive(%s)\n", g.r)
+ g.emit("return err\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// WriteTo implements io.WriterTo.WriteTo.\n")
+ g.emit("func (%s *%s) WriteTo(w io.Writer) (int64, error) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r)
+ g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r)
+ g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n")
+ g.emit("ptr := unsafe.Pointer(%s)\n", g.r)
+ g.emit("val := uintptr(ptr)\n")
+ g.emit("val = val^0\n\n")
+
+ g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r)
+ g.emit("var buf []byte\n")
+ g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n")
+ g.emit("hdr.Data = val\n")
+ g.emit("hdr.Len = %s.SizeBytes()\n", g.r)
+ g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r)
+
+ g.emit("len, err := w.Write(buf)\n")
+ g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r)
+ g.emit("// must live until after the Write.\n")
+ g.emit("runtime.KeepAlive(%s)\n", g.r)
+ g.emit("return int64(len), err\n")
+
+ })
+ g.emit("}\n\n")
+
+}