// Copyright 2019 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package gomarshal import ( "fmt" "go/ast" "io" "strings" ) var standardImports = []string{ "bytes", "fmt", "reflect", "testing", "gvisor.dev/gvisor/tools/go_marshal/analysis", } type testGenerator struct { sourceBuffer // The type we're serializing. t *ast.TypeSpec // Receiver argument for generated methods. r string // Imports used by generated code. imports *importTable // Import statement for the package declaring the type we generated code // for. We need this to construct test instances for the type, since the // tests aren't written in the same package. decl *importStmt } func newTestGenerator(t *ast.TypeSpec) *testGenerator { if _, ok := t.Type.(*ast.StructType); !ok { panic(fmt.Sprintf("Attempting to generate code for a not struct type %v", t)) } g := &testGenerator{ t: t, r: receiverName(t), imports: newImportTable(), } for _, i := range standardImports { g.imports.add(i).markUsed() } return g } func (g *testGenerator) typeName() string { return g.t.Name.Name } func (g *testGenerator) forEachField(fn func(f *ast.Field)) { // This is guaranteed to succeed because g.t is always a struct. st := g.t.Type.(*ast.StructType) for _, field := range st.Fields.List { fn(field) } } func (g *testGenerator) testFuncName(base string) string { return fmt.Sprintf("%s%s", base, strings.Title(g.t.Name.Name)) } func (g *testGenerator) inTestFunction(name string, body func()) { g.emit("func %s(t *testing.T) {\n", g.testFuncName(name)) g.inIndent(body) g.emit("}\n\n") } func (g *testGenerator) emitTestNonZeroSize() { g.inTestFunction("TestSizeNonZero", func() { g.emit("x := &%s{}\n", g.typeName()) g.emit("if x.SizeBytes() == 0 {\n") g.inIndent(func() { g.emit("t.Fatal(\"Marshallable.Size() should not return zero\")\n") }) g.emit("}\n") }) } func (g *testGenerator) emitTestSuspectAlignment() { g.inTestFunction("TestSuspectAlignment", func() { g.emit("x := %s{}\n", g.typeName()) g.emit("analysis.AlignmentCheck(t, reflect.TypeOf(x))\n") }) } func (g *testGenerator) emitTestMarshalUnmarshalPreservesData() { g.inTestFunction("TestSafeMarshalUnmarshalPreservesData", func() { g.emit("var x, y, z, yUnsafe, zUnsafe %s\n", g.typeName()) g.emit("analysis.RandomizeValue(&x)\n\n") g.emit("buf := make([]byte, x.SizeBytes())\n") g.emit("x.MarshalBytes(buf)\n") g.emit("bufUnsafe := make([]byte, x.SizeBytes())\n") g.emit("x.MarshalUnsafe(bufUnsafe)\n\n") g.emit("y.UnmarshalBytes(buf)\n") g.emit("if !reflect.DeepEqual(x, y) {\n") g.inIndent(func() { g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalBytes/UnmarshalBytes cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, y))\n") }) g.emit("}\n") g.emit("yUnsafe.UnmarshalBytes(bufUnsafe)\n") g.emit("if !reflect.DeepEqual(x, yUnsafe) {\n") g.inIndent(func() { g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalUnsafe/UnmarshalBytes cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, yUnsafe))\n") }) g.emit("}\n\n") g.emit("z.UnmarshalUnsafe(buf)\n") g.emit("if !reflect.DeepEqual(x, z) {\n") g.inIndent(func() { g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalBytes/UnmarshalUnsafe cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, z))\n") }) g.emit("}\n") g.emit("zUnsafe.UnmarshalUnsafe(bufUnsafe)\n") g.emit("if !reflect.DeepEqual(x, zUnsafe) {\n") g.inIndent(func() { g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalUnsafe/UnmarshalUnsafe cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, zUnsafe))\n") }) g.emit("}\n") }) } func (g *testGenerator) emitTestWriteToUnmarshalPreservesData() { g.inTestFunction("TestWriteToUnmarshalPreservesData", func() { g.emit("var x, y, yUnsafe %s\n", g.typeName()) g.emit("analysis.RandomizeValue(&x)\n\n") g.emit("var buf bytes.Buffer\n\n") g.emit("x.WriteTo(&buf)\n") g.emit("y.UnmarshalBytes(buf.Bytes())\n\n") g.emit("yUnsafe.UnmarshalUnsafe(buf.Bytes())\n\n") g.emit("if !reflect.DeepEqual(x, y) {\n") g.inIndent(func() { g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across WriteTo/UnmarshalBytes cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, y))\n") }) g.emit("}\n") g.emit("if !reflect.DeepEqual(x, yUnsafe) {\n") g.inIndent(func() { g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across WriteTo/UnmarshalUnsafe cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, yUnsafe))\n") }) g.emit("}\n") }) } func (g *testGenerator) emitTestSizeBytesOnTypedNilPtr() { g.inTestFunction("TestSizeBytesOnTypedNilPtr", func() { g.emit("var x %s\n", g.typeName()) g.emit("sizeFromConcrete := x.SizeBytes()\n") g.emit("sizeFromTypedNilPtr := (*%s)(nil).SizeBytes()\n\n", g.typeName()) g.emit("if sizeFromTypedNilPtr != sizeFromConcrete {\n") g.inIndent(func() { g.emit("t.Fatalf(\"SizeBytes() on typed nil pointer (%v) doesn't match size returned by a concrete object (%v).\\n\", sizeFromTypedNilPtr, sizeFromConcrete)") }) g.emit("}\n") }) } func (g *testGenerator) emitTests() { g.emitTestNonZeroSize() g.emitTestSuspectAlignment() g.emitTestMarshalUnmarshalPreservesData() g.emitTestWriteToUnmarshalPreservesData() g.emitTestSizeBytesOnTypedNilPtr() } func (g *testGenerator) write(out io.Writer) error { return g.sourceBuffer.write(out) }