From 3d32ad1367b4e84a0822808f44bd7b9f9351db71 Mon Sep 17 00:00:00 2001 From: gVisor bot Date: Fri, 14 Feb 2020 18:31:55 -0800 Subject: Generate implementation of io.WriterTo via go-marshal. PiperOrigin-RevId: 295269654 --- tools/go_marshal/gomarshal/generator.go | 6 ++- tools/go_marshal/gomarshal/generator_interfaces.go | 46 ++++++++++++++++++++++ tools/go_marshal/gomarshal/generator_tests.go | 34 ++++++++++++++-- tools/go_marshal/marshal/marshal.go | 4 ++ 4 files changed, 84 insertions(+), 6 deletions(-) diff --git a/tools/go_marshal/gomarshal/generator.go b/tools/go_marshal/gomarshal/generator.go index fbec7bb9a..0294ba5ba 100644 --- a/tools/go_marshal/gomarshal/generator.go +++ b/tools/go_marshal/gomarshal/generator.go @@ -101,14 +101,16 @@ func NewGenerator(srcs []string, out, outTest, pkg string, imports []string) (*G // used, so that they're always added to the generated code. g.imports.add(i).markUsed() } - g.imports.add(marshalImport).markUsed() + // The following imports may or may not be used by the generated code, // depending on what's required for the target types. Don't mark these as // used by default. + g.imports.add("io") g.imports.add("reflect") g.imports.add("runtime") - g.imports.add(safecopyImport) g.imports.add("unsafe") + g.imports.add(marshalImport) + g.imports.add(safecopyImport) g.imports.add(usermemImport) return &g, nil diff --git a/tools/go_marshal/gomarshal/generator_interfaces.go b/tools/go_marshal/gomarshal/generator_interfaces.go index f25331ac5..22aae0f6b 100644 --- a/tools/go_marshal/gomarshal/generator_interfaces.go +++ b/tools/go_marshal/gomarshal/generator_interfaces.go @@ -602,4 +602,50 @@ func (g *interfaceGenerator) emitMarshallable() { } }) g.emit("}\n\n") + + g.emit("// WriteTo implements io.WriterTo.WriteTo.\n") + g.recordUsedImport("io") + g.emit("func (%s *%s) WriteTo(w io.Writer) (int64, error) {\n", g.r, g.typeName()) + g.inIndent(func() { + fallback := func() { + g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName()) + g.emit("buf := make([]byte, %s.SizeBytes())\n", g.r) + g.emit("%s.MarshalBytes(buf)\n", g.r) + g.emit("n, err := w.Write(buf)\n") + g.emit("return int64(n), err\n") + } + if thisPacked { + g.recordUsedImport("reflect") + g.recordUsedImport("runtime") + g.recordUsedImport("unsafe") + if cond, ok := g.areFieldsPackedExpression(); ok { + g.emit("if !%s {\n", cond) + g.inIndent(fallback) + g.emit("}\n\n") + } + // Fast serialization. + g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r) + g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r) + g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n") + g.emit("ptr := unsafe.Pointer(%s)\n", g.r) + g.emit("val := uintptr(ptr)\n") + g.emit("val = val^0\n\n") + + g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r) + g.emit("var buf []byte\n") + g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n") + g.emit("hdr.Data = val\n") + g.emit("hdr.Len = %s.SizeBytes()\n", g.r) + g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r) + + g.emit("len, err := w.Write(buf)\n") + g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r) + g.emit("// must live until after the Write.\n") + g.emit("runtime.KeepAlive(%s)\n", g.r) + g.emit("return int64(len), err\n") + } else { + fallback() + } + }) + g.emit("}\n\n") } diff --git a/tools/go_marshal/gomarshal/generator_tests.go b/tools/go_marshal/gomarshal/generator_tests.go index cc760b6d0..5ad97af14 100644 --- a/tools/go_marshal/gomarshal/generator_tests.go +++ b/tools/go_marshal/gomarshal/generator_tests.go @@ -22,6 +22,7 @@ import ( ) var standardImports = []string{ + "bytes", "fmt", "reflect", "testing", @@ -117,26 +118,50 @@ func (g *testGenerator) emitTestMarshalUnmarshalPreservesData() { g.emit("y.UnmarshalBytes(buf)\n") g.emit("if !reflect.DeepEqual(x, y) {\n") g.inIndent(func() { - g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across Marshal/Unmarshal cycle:\\nBefore: %%+v\\nAfter: %%+v\\n\", x, y))\n") + g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalBytes/UnmarshalBytes cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, y))\n") }) g.emit("}\n") g.emit("yUnsafe.UnmarshalBytes(bufUnsafe)\n") g.emit("if !reflect.DeepEqual(x, yUnsafe) {\n") g.inIndent(func() { - g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalUnsafe/Unmarshal cycle:\\nBefore: %%+v\\nAfter: %%+v\\n\", x, yUnsafe))\n") + g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalUnsafe/UnmarshalBytes cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, yUnsafe))\n") }) g.emit("}\n\n") g.emit("z.UnmarshalUnsafe(buf)\n") g.emit("if !reflect.DeepEqual(x, z) {\n") g.inIndent(func() { - g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across Marshal/UnmarshalUnsafe cycle:\\nBefore: %%+v\\nAfter: %%+v\\n\", x, z))\n") + g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalBytes/UnmarshalUnsafe cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, z))\n") }) g.emit("}\n") g.emit("zUnsafe.UnmarshalUnsafe(bufUnsafe)\n") g.emit("if !reflect.DeepEqual(x, zUnsafe) {\n") g.inIndent(func() { - g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalUnsafe/UnmarshalUnsafe cycle:\\nBefore: %%+v\\nAfter: %%+v\\n\", x, zUnsafe))\n") + g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalUnsafe/UnmarshalUnsafe cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, zUnsafe))\n") + }) + g.emit("}\n") + }) +} + +func (g *testGenerator) emitTestWriteToUnmarshalPreservesData() { + g.inTestFunction("TestWriteToUnmarshalPreservesData", func() { + g.emit("var x, y, yUnsafe %s\n", g.typeName()) + g.emit("analysis.RandomizeValue(&x)\n\n") + + g.emit("var buf bytes.Buffer\n\n") + + g.emit("x.WriteTo(&buf)\n") + g.emit("y.UnmarshalBytes(buf.Bytes())\n\n") + g.emit("yUnsafe.UnmarshalUnsafe(buf.Bytes())\n\n") + + g.emit("if !reflect.DeepEqual(x, y) {\n") + g.inIndent(func() { + g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across WriteTo/UnmarshalBytes cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, y))\n") + }) + g.emit("}\n") + g.emit("if !reflect.DeepEqual(x, yUnsafe) {\n") + g.inIndent(func() { + g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across WriteTo/UnmarshalUnsafe cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, yUnsafe))\n") }) g.emit("}\n") }) @@ -146,6 +171,7 @@ func (g *testGenerator) emitTests() { g.emitTestNonZeroSize() g.emitTestSuspectAlignment() g.emitTestMarshalUnmarshalPreservesData() + g.emitTestWriteToUnmarshalPreservesData() } func (g *testGenerator) write(out io.Writer) error { diff --git a/tools/go_marshal/marshal/marshal.go b/tools/go_marshal/marshal/marshal.go index 10614ec4d..e521b50bd 100644 --- a/tools/go_marshal/marshal/marshal.go +++ b/tools/go_marshal/marshal/marshal.go @@ -21,6 +21,8 @@ package marshal import ( + "io" + "gvisor.dev/gvisor/pkg/usermem" ) @@ -42,6 +44,8 @@ type Task interface { // Marshallable represents a type that can be marshalled to and from memory. type Marshallable interface { + io.WriterTo + // SizeBytes is the size of the memory representation of a type in // marshalled form. SizeBytes() int -- cgit v1.2.3