From f1b72752e5de2abc3c409a6b7447224620b7c11b Mon Sep 17 00:00:00 2001 From: gVisor bot Date: Thu, 20 Feb 2020 16:22:45 -0800 Subject: Implement automated marshalling for newtypes on primitives. PiperOrigin-RevId: 296322954 --- tools/go_marshal/gomarshal/generator_interfaces.go | 296 ++++++++++++++++----- 1 file changed, 233 insertions(+), 63 deletions(-) (limited to 'tools/go_marshal/gomarshal/generator_interfaces.go') diff --git a/tools/go_marshal/gomarshal/generator_interfaces.go b/tools/go_marshal/gomarshal/generator_interfaces.go index 834c58cee..ea1af998e 100644 --- a/tools/go_marshal/gomarshal/generator_interfaces.go +++ b/tools/go_marshal/gomarshal/generator_interfaces.go @@ -55,9 +55,6 @@ func (g *interfaceGenerator) typeName() string { // newinterfaceGenerator creates a new interface generator. func newInterfaceGenerator(t *ast.TypeSpec, fset *token.FileSet) *interfaceGenerator { - if _, ok := t.Type.(*ast.StructType); !ok { - panic(fmt.Sprintf("Attempting to generate code for a not struct type %v", t)) - } g := &interfaceGenerator{ t: t, r: receiverName(t), @@ -103,9 +100,31 @@ func (g *interfaceGenerator) abortAt(p token.Pos, msg string) { abortAt(g.f.Position(p), msg) } -// validate ensures the type we're working with can be marshalled. These checks -// are done ahead of time and in one place so we can make assumptions later. -func (g *interfaceGenerator) validate() { +func (g *interfaceGenerator) validatePrimitiveNewtype(t *ast.Ident) { + switch t.Name { + case "int8", "uint8", "byte", "int16", "uint16", "int32", "uint32", "int64", "uint64": + // These are the only primitive types we're allow. Below, we provide + // suggestions for some disallowed types and reject them, then attempt + // to marshal any remaining types by invoking the marshal.Marshallable + // interface on them. If these types don't actually implement + // marshal.Marshallable, compilation of the generated code will fail + // with an appropriate error message. + return + case "int": + g.abortAt(t.Pos(), "Type 'int' has ambiguous width, use int32 or int64") + case "uint": + g.abortAt(t.Pos(), "Type 'uint' has ambiguous width, use uint32 or uint64") + case "string": + g.abortAt(t.Pos(), "Type 'string' is dynamically-sized and cannot be marshalled, use a fixed size byte array '[...]byte' instead") + default: + debugfAt(g.f.Position(t.Pos()), fmt.Sprintf("Found derived type '%s', will attempt dispatch via marshal.Marshallable.\n", t.Name)) + } +} + +// validateStruct ensures the type we're working with can be marshalled. These +// checks are done ahead of time and in one place so we can make assumptions +// later. +func (g *interfaceGenerator) validateStruct() { g.forEachField(func(f *ast.Field) { if len(f.Names) == 0 { g.abortAt(f.Pos(), "Cannot marshal structs with embedded fields, give the field a name; use '_' for anonymous fields such as padding fields") @@ -115,25 +134,7 @@ func (g *interfaceGenerator) validate() { g.forEachField(func(f *ast.Field) { fieldDispatcher{ primitive: func(_, t *ast.Ident) { - switch t.Name { - case "int8", "uint8", "byte", "int16", "uint16", "int32", "uint32", "int64", "uint64": - // These are the only primitive types we're allow. Below, we - // provide suggestions for some disallowed types and reject - // them, then attempt to marshal any remaining types by - // invoking the marshal.Marshallable interface on them. If - // these types don't actually implement - // marshal.Marshallable, compilation of the generated code - // will fail with an appropriate error message. - return - case "int": - g.abortAt(f.Pos(), "Type 'int' has ambiguous width, use int32 or int64") - case "uint": - g.abortAt(f.Pos(), "Type 'uint' has ambiguous width, use uint32 or uint64") - case "string": - g.abortAt(f.Pos(), "Type 'string' is dynamically-sized and cannot be marshalled, use a fixed size byte array '[...]byte' instead") - default: - debugfAt(g.f.Position(f.Pos()), fmt.Sprintf("Found derived type '%s', will attempt dispatch via marshal.Marshallable.\n", t.Name)) - } + g.validatePrimitiveNewtype(t) }, selector: func(_, _, _ *ast.Ident) { // No validation to perform on selector fields. However this @@ -190,7 +191,8 @@ func (g *interfaceGenerator) shiftDynamic(bufVar, name string) { g.emit("%s = %s[%s.SizeBytes():]\n", bufVar, bufVar, name) } -func (g *interfaceGenerator) marshalScalar(accessor, typ string, bufVar string) { +// marshalStructFieldScalar writes a single scalar field from a struct to a byte slice. +func (g *interfaceGenerator) marshalStructFieldScalar(accessor, typ, bufVar string) { switch typ { case "int8", "uint8", "byte": g.emit("%s[0] = byte(%s)\n", bufVar, accessor) @@ -213,43 +215,27 @@ func (g *interfaceGenerator) marshalScalar(accessor, typ string, bufVar string) } } -func (g *interfaceGenerator) unmarshalScalar(accessor, typ string, bufVar string) { +// unmarshalStructFieldScalar reads a single scalar field from a struct, from a +// byte slice. +func (g *interfaceGenerator) unmarshalStructFieldScalar(accessor, typ, bufVar string) { switch typ { - case "int8": - g.emit("%s = int8(%s[0])\n", accessor, bufVar) - g.shift(bufVar, 1) - case "uint8": - g.emit("%s = uint8(%s[0])\n", accessor, bufVar) - g.shift(bufVar, 1) case "byte": g.emit("%s = %s[0]\n", accessor, bufVar) g.shift(bufVar, 1) - - case "int16": - g.recordUsedImport("usermem") - g.emit("%s = int16(usermem.ByteOrder.Uint16(%s[:2]))\n", accessor, bufVar) - g.shift(bufVar, 2) - case "uint16": + case "int8", "uint8": + g.emit("%s = %s(%s[0])\n", accessor, typ, bufVar) + g.shift(bufVar, 1) + case "int16", "uint16": g.recordUsedImport("usermem") - g.emit("%s = usermem.ByteOrder.Uint16(%s[:2])\n", accessor, bufVar) + g.emit("%s = %s(usermem.ByteOrder.Uint16(%s[:2]))\n", accessor, typ, bufVar) g.shift(bufVar, 2) - - case "int32": - g.recordUsedImport("usermem") - g.emit("%s = int32(usermem.ByteOrder.Uint32(%s[:4]))\n", accessor, bufVar) - g.shift(bufVar, 4) - case "uint32": + case "int32", "uint32": g.recordUsedImport("usermem") - g.emit("%s = usermem.ByteOrder.Uint32(%s[:4])\n", accessor, bufVar) + g.emit("%s = %s(usermem.ByteOrder.Uint32(%s[:4]))\n", accessor, typ, bufVar) g.shift(bufVar, 4) - - case "int64": - g.recordUsedImport("usermem") - g.emit("%s = int64(usermem.ByteOrder.Uint64(%s[:8]))\n", accessor, bufVar) - g.shift(bufVar, 8) - case "uint64": + case "int64", "uint64": g.recordUsedImport("usermem") - g.emit("%s = usermem.ByteOrder.Uint64(%s[:8])\n", accessor, bufVar) + g.emit("%s = %s(usermem.ByteOrder.Uint64(%s[:8]))\n", accessor, typ, bufVar) g.shift(bufVar, 8) default: g.emit("%s.UnmarshalBytes(%s[:%s.SizeBytes()])\n", accessor, bufVar, accessor) @@ -258,6 +244,49 @@ func (g *interfaceGenerator) unmarshalScalar(accessor, typ string, bufVar string } } +// marshalPrimitiveScalar writes a single primitive variable to a byte slice. +func (g *interfaceGenerator) marshalPrimitiveScalar(accessor, typ, bufVar string) { + switch typ { + case "int8", "uint8", "byte": + g.emit("%s[0] = byte(*%s)\n", bufVar, accessor) + case "int16", "uint16": + g.recordUsedImport("usermem") + g.emit("usermem.ByteOrder.PutUint16(%s[:2], uint16(*%s))\n", bufVar, accessor) + case "int32", "uint32": + g.recordUsedImport("usermem") + g.emit("usermem.ByteOrder.PutUint32(%s[:4], uint32(*%s))\n", bufVar, accessor) + case "int64", "uint64": + g.recordUsedImport("usermem") + g.emit("usermem.ByteOrder.PutUint64(%s[:8], uint64(*%s))\n", bufVar, accessor) + default: + g.emit("inner := (*%s)(%s)\n", typ, accessor) + g.emit("inner.MarshalBytes(%s[:%s.SizeBytes()])\n", bufVar, accessor) + } +} + +// unmarshalPrimitiveScalar read a single primitive variable from a byte slice. +func (g *interfaceGenerator) unmarshalPrimitiveScalar(accessor, typ, bufVar, typeCast string) { + switch typ { + case "byte": + g.emit("*%s = %s(%s[0])\n", accessor, typeCast, bufVar) + case "int8", "uint8": + g.emit("*%s = %s(%s(%s[0]))\n", accessor, typeCast, typ, bufVar) + case "int16", "uint16": + g.recordUsedImport("usermem") + g.emit("*%s = %s(%s(usermem.ByteOrder.Uint16(%s[:2])))\n", accessor, typeCast, typ, bufVar) + case "int32", "uint32": + g.recordUsedImport("usermem") + g.emit("*%s = %s(%s(usermem.ByteOrder.Uint32(%s[:4])))\n", accessor, typeCast, typ, bufVar) + + case "int64", "uint64": + g.recordUsedImport("usermem") + g.emit("*%s = %s(%s(usermem.ByteOrder.Uint64(%s[:8])))\n", accessor, typeCast, typ, bufVar) + default: + g.emit("inner := (*%s)(%s)\n", typ, accessor) + g.emit("inner.UnmarshalBytes(%s[:%s.SizeBytes()])\n", bufVar, accessor) + } +} + // areFieldsPackedExpression returns a go expression checking whether g.t's fields are // packed. Returns "", false if g.t has no fields that may be potentially // packed, otherwise returns , true, where is an expression @@ -274,7 +303,7 @@ func (g *interfaceGenerator) areFieldsPackedExpression() (string, bool) { return strings.Join(cs, " && "), true } -func (g *interfaceGenerator) emitMarshallable() { +func (g *interfaceGenerator) emitMarshallableForStruct() { // Is g.t a packed struct without consideing field types? thisPacked := true g.forEachField(func(f *ast.Field) { @@ -357,10 +386,10 @@ func (g *interfaceGenerator) emitMarshallable() { } return } - g.marshalScalar(g.fieldAccessor(n), t.Name, "dst") + g.marshalStructFieldScalar(g.fieldAccessor(n), t.Name, "dst") }, selector: func(n, tX, tSel *ast.Ident) { - g.marshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "dst") + g.marshalStructFieldScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "dst") }, array: func(n, t *ast.Ident, size int) { if n.Name == "_" { @@ -377,9 +406,9 @@ func (g *interfaceGenerator) emitMarshallable() { return } - g.emit("for i := 0; i < %d; i++ {\n", size) + g.emit("for idx := 0; idx < %d; idx++ {\n", size) g.inIndent(func() { - g.marshalScalar(fmt.Sprintf("%s[i]", g.fieldAccessor(n)), t.Name, "dst") + g.marshalStructFieldScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "dst") }) g.emit("}\n") }, @@ -406,10 +435,10 @@ func (g *interfaceGenerator) emitMarshallable() { } return } - g.unmarshalScalar(g.fieldAccessor(n), t.Name, "src") + g.unmarshalStructFieldScalar(g.fieldAccessor(n), t.Name, "src") }, selector: func(n, tX, tSel *ast.Ident) { - g.unmarshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "src") + g.unmarshalStructFieldScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "src") }, array: func(n, t *ast.Ident, size int) { if n.Name == "_" { @@ -426,9 +455,9 @@ func (g *interfaceGenerator) emitMarshallable() { return } - g.emit("for i := 0; i < %d; i++ {\n", size) + g.emit("for idx := 0; idx < %d; idx++ {\n", size) g.inIndent(func() { - g.unmarshalScalar(fmt.Sprintf("%s[i]", g.fieldAccessor(n)), t.Name, "src") + g.unmarshalStructFieldScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "src") }) g.emit("}\n") }, @@ -650,3 +679,144 @@ func (g *interfaceGenerator) emitMarshallable() { }) g.emit("}\n\n") } + +// emitMarshallableForPrimitiveNewtype outputs code to implement the +// marshal.Marshallable interface for a newtype on a primitive. Primitive +// newtypes are always packed, so we can omit the various fallbacks required for +// non-packed structs. +func (g *interfaceGenerator) emitMarshallableForPrimitiveNewtype() { + g.recordUsedImport("io") + g.recordUsedImport("marshal") + g.recordUsedImport("reflect") + g.recordUsedImport("runtime") + g.recordUsedImport("safecopy") + g.recordUsedImport("unsafe") + g.recordUsedImport("usermem") + + nt := g.t.Type.(*ast.Ident) + + g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n") + g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName()) + g.inIndent(func() { + if size, dynamic := g.scalarSize(nt); !dynamic { + g.emit("return %d\n", size) + } else { + g.emit("return (*%s)(nil).SizeBytes()\n", nt.Name) + } + }) + g.emit("}\n\n") + + g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n") + g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.marshalPrimitiveScalar(g.r, nt.Name, "dst") + }) + g.emit("}\n\n") + + g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n") + g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.unmarshalPrimitiveScalar(g.r, nt.Name, "src", g.typeName()) + }) + g.emit("}\n\n") + + g.emit("// Packed implements marshal.Marshallable.Packed.\n") + g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("// Scalar newtypes are always packed.\n") + g.emit("return true\n") + }) + g.emit("}\n\n") + + g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n") + g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r) + }) + g.emit("}\n\n") + + g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n") + g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r) + }) + g.emit("}\n\n") + + g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n") + g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName()) + g.inIndent(func() { + // Fast serialization. + g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) + g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) + g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") + g.emit("ptr := unsafe.Pointer(%s)\n", g.r) + g.emit("val := uintptr(ptr)\n") + g.emit("val = val^0\n\n") + + g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) + g.emit("var buf []byte\n") + g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") + g.emit("hdr.Data = val\n") + g.emit("hdr.Len = %s.SizeBytes()\n", g.r) + g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) + + g.emit("_, err := task.CopyOutBytes(addr, buf)\n") + g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) + g.emit("// must live until after the CopyOutBytes.\n") + g.emit("runtime.KeepAlive(%s)\n", g.r) + g.emit("return err\n") + }) + g.emit("}\n\n") + + g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n") + g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) + g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) + g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") + g.emit("ptr := unsafe.Pointer(%s)\n", g.r) + g.emit("val := uintptr(ptr)\n") + g.emit("val = val^0\n\n") + + g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) + g.emit("var buf []byte\n") + g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") + g.emit("hdr.Data = val\n") + g.emit("hdr.Len = %s.SizeBytes()\n", g.r) + g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) + + g.emit("_, err := task.CopyInBytes(addr, buf)\n") + g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) + g.emit("// must live until after the CopyInBytes.\n") + g.emit("runtime.KeepAlive(%s)\n", g.r) + g.emit("return err\n") + }) + g.emit("}\n\n") + + g.emit("// WriteTo implements io.WriterTo.WriteTo.\n") + g.emit("func (%s *%s) WriteTo(w io.Writer) (int64, error) {\n", g.r, g.typeName()) + g.inIndent(func() { + g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) + g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) + g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") + g.emit("ptr := unsafe.Pointer(%s)\n", g.r) + g.emit("val := uintptr(ptr)\n") + g.emit("val = val^0\n\n") + + g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) + g.emit("var buf []byte\n") + g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") + g.emit("hdr.Data = val\n") + g.emit("hdr.Len = %s.SizeBytes()\n", g.r) + g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) + + g.emit("len, err := w.Write(buf)\n") + g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) + g.emit("// must live until after the Write.\n") + g.emit("runtime.KeepAlive(%s)\n", g.r) + g.emit("return int64(len), err\n") + + }) + g.emit("}\n\n") + +} -- cgit v1.2.3