// Copyright 2019 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package gomarshal

import (
	"fmt"
	"go/ast"
	"go/token"
	"strings"
)

// interfaceGenerator generates marshalling interfaces for a single type.
//
// interfaceGenerator is not thread-safe.
type interfaceGenerator struct {
	sourceBuffer

	// The type we're serializing.
	t *ast.TypeSpec

	// Receiver argument for generated methods.
	r string

	// FileSet containing the tokens for the type we're processing.
	f *token.FileSet

	// is records external packages referenced by the generated implementation.
	is map[string]struct{}

	// ms records Marshallable types referenced by the generated implementation
	// of t's interfaces.
	ms map[string]struct{}

	// as records fields in t whose types are potentially not packed. The key
	// is the accessor for the field.
	as map[string]struct{}
}

// typeName returns the name of the type g represents.
func (g *interfaceGenerator) typeName() string {
	return g.t.Name.Name
}

// newInterfaceGenerator creates a new interface generator.
func newInterfaceGenerator(t *ast.TypeSpec, fset *token.FileSet) *interfaceGenerator {
	g := &interfaceGenerator{
		t:  t,
		r:  receiverName(t),
		f:  fset,
		is: make(map[string]struct{}),
		ms: make(map[string]struct{}),
		as: make(map[string]struct{}),
	}
	g.recordUsedMarshallable(g.typeName())
	return g
}

func (g *interfaceGenerator) recordUsedMarshallable(m string) {
	g.ms[m] = struct{}{}
}

func (g *interfaceGenerator) recordUsedImport(i string) {
	g.is[i] = struct{}{}
}

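// recordPotentiallyNonPackedField notes that the field with the given
// accessor has a type that may not be packed; the generated fast paths must
// check Packed() on it at runtime.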
func (g *interfaceGenerator) recordPotentiallyNonPackedField(fieldName string) {
	g.as[fieldName] = struct{}{}
}

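// forEachField invokes fn once for each field of the struct type g.t, in
// declaration order.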
func (g *interfaceGenerator) forEachField(fn func(f *ast.Field)) {
	// This type assertion is guaranteed to succeed because forEachField is
	// only called for struct types.
	st := g.t.Type.(*ast.StructType)
	for _, field := range st.Fields.List {
		fn(field)
	}
}

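// fieldAccessor returns the accessor expression for a struct field, e.g.
// "t.Count" for a hypothetical field "Count" and receiver "t".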
func (g *interfaceGenerator) fieldAccessor(n *ast.Ident) string {
	return fmt.Sprintf("%s.%s", g.r, n.Name)
}

// abortAt aborts the go_marshal tool with the given error message, with a
// reference position to the input source. Equivalent to the package-level
// abortAt, but uses g to resolve p to a position.
func (g *interfaceGenerator) abortAt(p token.Pos, msg string) {
	abortAt(g.f.Position(p), msg)
}

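// validatePrimitiveNewtype rejects underlying types that can never be
// marshalled (ambiguous-width integers and strings). Any other non-primitive
// underlying type is deferred to its own marshal.Marshallable implementation.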
func (g *interfaceGenerator) validatePrimitiveNewtype(t *ast.Ident) {
	switch t.Name {
	case "int8", "uint8", "byte", "int16", "uint16", "int32", "uint32", "int64", "uint64":
		// These are the only primitive types we allow. Below, we provide
		// suggestions for some disallowed types and reject them, then attempt
		// to marshal any remaining types by invoking the marshal.Marshallable
		// interface on them. If these types don't actually implement
		// marshal.Marshallable, compilation of the generated code will fail
		// with an appropriate error message.
		return
	case "int":
		g.abortAt(t.Pos(), "Type 'int' has ambiguous width, use int32 or int64")
	case "uint":
		g.abortAt(t.Pos(), "Type 'uint' has ambiguous width, use uint32 or uint64")
	case "string":
		g.abortAt(t.Pos(), "Type 'string' is dynamically-sized and cannot be marshalled, use a fixed size byte array '[...]byte' instead")
	default:
		debugfAt(g.f.Position(t.Pos()), fmt.Sprintf("Found derived type '%s', will attempt dispatch via marshal.Marshallable.\n", t.Name))
	}
}

// validateStruct ensures the type we're working with can be marshalled. These
// checks are done ahead of time and in one place so we can make assumptions
// later.
func (g *interfaceGenerator) validateStruct() {
	g.forEachField(func(f *ast.Field) {
		if len(f.Names) == 0 {
			g.abortAt(f.Pos(), "Cannot marshal structs with embedded fields, give the field a name; use '_' for anonymous fields such as padding fields")
		}
	})

	g.forEachField(func(f *ast.Field) {
		fieldDispatcher{
			primitive: func(_, t *ast.Ident) {
				g.validatePrimitiveNewtype(t)
			},
			selector: func(_, _, _ *ast.Ident) {
				// No validation to perform on selector fields. However, this
				// callback must still be provided.
			},
			array: func(n, _ *ast.Ident, len int) {
				a := f.Type.(*ast.ArrayType)
				if a.Len == nil {
					g.abortAt(f.Pos(), fmt.Sprintf("Dynamically sized slice '%s' cannot be marshalled, arrays must be statically sized", n.Name))
				}

				if _, ok := a.Len.(*ast.BasicLit); !ok {
					g.abortAt(a.Len.Pos(), fmt.Sprintf("Array size must be a literal, don's use consts or expressions"))
				}

				if _, ok := a.Elt.(*ast.Ident); !ok {
					g.abortAt(a.Elt.Pos(), fmt.Sprintf("Marshalling not supported for arrays with %s elements, array elements must be primitive types", kindString(a.Elt)))
				}

				if len <= 0 {
					g.abortAt(a.Len.Pos(), fmt.Sprintf("Marshalling not supported for zero length arrays, why does an ABI struct have one?"))
				}
			},
			unhandled: func(_ *ast.Ident) {
				g.abortAt(f.Pos(), fmt.Sprintf("Marshalling not supported for %s fields", kindString(f.Type)))
			},
		}.dispatch(f)
	})
}

// scalarSize returns the size of the type identified by t. If t isn't a primitive
// type, the size isn't known at code generation time, and must be resolved via
// the marshal.Marshallable interface.
func (g *interfaceGenerator) scalarSize(t *ast.Ident) (size int, unknownSize bool) {
	switch t.Name {
	case "int8", "uint8", "byte":
		return 1, false
	case "int16", "uint16":
		return 2, false
	case "int32", "uint32":
		return 4, false
	case "int64", "uint64":
		return 8, false
	default:
		return 0, true
	}
}

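// shift emits code advancing bufVar past a fixed number of bytes, e.g.
// "dst = dst[4:]" for bufVar == "dst" and n == 4.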
func (g *interfaceGenerator) shift(bufVar string, n int) {
	g.emit("%s = %s[%d:]\n", bufVar, bufVar, n)
}

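// shiftDynamic emits code advancing bufVar past a dynamically-sized value,
// e.g. "dst = dst[t.x.SizeBytes():]" for bufVar == "dst" and a hypothetical
// accessor "t.x".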
func (g *interfaceGenerator) shiftDynamic(bufVar, name string) {
	g.emit("%s = %s[%s.SizeBytes():]\n", bufVar, bufVar, name)
}

// marshalStructFieldScalar writes a single scalar field from a struct to a byte slice.
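//
// For example, for a hypothetical uint32 field with accessor "t.x" and
// bufVar == "dst", this emits:
//
//	usermem.ByteOrder.PutUint32(dst[:4], uint32(t.x))
//	dst = dst[4:]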
func (g *interfaceGenerator) marshalStructFieldScalar(accessor, typ, bufVar string) {
	switch typ {
	case "int8", "uint8", "byte":
		g.emit("%s[0] = byte(%s)\n", bufVar, accessor)
		g.shift(bufVar, 1)
	case "int16", "uint16":
		g.recordUsedImport("usermem")
		g.emit("usermem.ByteOrder.PutUint16(%s[:2], uint16(%s))\n", bufVar, accessor)
		g.shift(bufVar, 2)
	case "int32", "uint32":
		g.recordUsedImport("usermem")
		g.emit("usermem.ByteOrder.PutUint32(%s[:4], uint32(%s))\n", bufVar, accessor)
		g.shift(bufVar, 4)
	case "int64", "uint64":
		g.recordUsedImport("usermem")
		g.emit("usermem.ByteOrder.PutUint64(%s[:8], uint64(%s))\n", bufVar, accessor)
		g.shift(bufVar, 8)
	default:
		g.emit("%s.MarshalBytes(%s[:%s.SizeBytes()])\n", accessor, bufVar, accessor)
		g.shiftDynamic(bufVar, accessor)
	}
}

// unmarshalStructFieldScalar reads a single scalar field from a struct, from a
// byte slice.
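//
// For example, for a hypothetical uint32 field with accessor "t.x" and
// bufVar == "src", this emits:
//
//	t.x = uint32(usermem.ByteOrder.Uint32(src[:4]))
//	src = src[4:]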
func (g *interfaceGenerator) unmarshalStructFieldScalar(accessor, typ, bufVar string) {
	switch typ {
	case "byte":
		g.emit("%s = %s[0]\n", accessor, bufVar)
		g.shift(bufVar, 1)
	case "int8", "uint8":
		g.emit("%s = %s(%s[0])\n", accessor, typ, bufVar)
		g.shift(bufVar, 1)
	case "int16", "uint16":
		g.recordUsedImport("usermem")
		g.emit("%s = %s(usermem.ByteOrder.Uint16(%s[:2]))\n", accessor, typ, bufVar)
		g.shift(bufVar, 2)
	case "int32", "uint32":
		g.recordUsedImport("usermem")
		g.emit("%s = %s(usermem.ByteOrder.Uint32(%s[:4]))\n", accessor, typ, bufVar)
		g.shift(bufVar, 4)
	case "int64", "uint64":
		g.recordUsedImport("usermem")
		g.emit("%s = %s(usermem.ByteOrder.Uint64(%s[:8]))\n", accessor, typ, bufVar)
		g.shift(bufVar, 8)
	default:
		g.emit("%s.UnmarshalBytes(%s[:%s.SizeBytes()])\n", accessor, bufVar, accessor)
		g.shiftDynamic(bufVar, accessor)
		g.recordPotentiallyNonPackedField(accessor)
	}
}

// marshalPrimitiveScalar writes a single primitive variable to a byte slice.
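//
// The variable is accessed through a pointer; for a hypothetical uint32
// newtype with receiver "t" and bufVar == "dst", this emits:
//
//	usermem.ByteOrder.PutUint32(dst[:4], uint32(*t))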
func (g *interfaceGenerator) marshalPrimitiveScalar(accessor, typ, bufVar string) {
	switch typ {
	case "int8", "uint8", "byte":
		g.emit("%s[0] = byte(*%s)\n", bufVar, accessor)
	case "int16", "uint16":
		g.recordUsedImport("usermem")
		g.emit("usermem.ByteOrder.PutUint16(%s[:2], uint16(*%s))\n", bufVar, accessor)
	case "int32", "uint32":
		g.recordUsedImport("usermem")
		g.emit("usermem.ByteOrder.PutUint32(%s[:4], uint32(*%s))\n", bufVar, accessor)
	case "int64", "uint64":
		g.recordUsedImport("usermem")
		g.emit("usermem.ByteOrder.PutUint64(%s[:8], uint64(*%s))\n", bufVar, accessor)
	default:
		g.emit("inner := (*%s)(%s)\n", typ, accessor)
		g.emit("inner.MarshalBytes(%s[:%s.SizeBytes()])\n", bufVar, accessor)
	}
}

// unmarshalPrimitiveScalar reads a single primitive variable from a byte slice.
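//
// typeCast is the name of the newtype itself; for a hypothetical
// "type Foo uint32" with receiver "t" and bufVar == "src", this emits:
//
//	*t = Foo(uint32(usermem.ByteOrder.Uint32(src[:4])))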
func (g *interfaceGenerator) unmarshalPrimitiveScalar(accessor, typ, bufVar, typeCast string) {
	switch typ {
	case "byte":
		g.emit("*%s = %s(%s[0])\n", accessor, typeCast, bufVar)
	case "int8", "uint8":
		g.emit("*%s = %s(%s(%s[0]))\n", accessor, typeCast, typ, bufVar)
	case "int16", "uint16":
		g.recordUsedImport("usermem")
		g.emit("*%s = %s(%s(usermem.ByteOrder.Uint16(%s[:2])))\n", accessor, typeCast, typ, bufVar)
	case "int32", "uint32":
		g.recordUsedImport("usermem")
		g.emit("*%s = %s(%s(usermem.ByteOrder.Uint32(%s[:4])))\n", accessor, typeCast, typ, bufVar)

	case "int64", "uint64":
		g.recordUsedImport("usermem")
		g.emit("*%s = %s(%s(usermem.ByteOrder.Uint64(%s[:8])))\n", accessor, typeCast, typ, bufVar)
	default:
		g.emit("inner := (*%s)(%s)\n", typ, accessor)
		g.emit("inner.UnmarshalBytes(%s[:%s.SizeBytes()])\n", bufVar, accessor)
	}
}

// areFieldsPackedExpression returns a Go expression checking whether g.t's fields are
// packed. Returns "", false if g.t has no fields that may be potentially
// packed, otherwise returns <clause>, true, where <clause> is an expression
// like "t.a.Packed() && t.b.Packed() && t.c.Packed()".
func (g *interfaceGenerator) areFieldsPackedExpression() (string, bool) {
	if len(g.as) == 0 {
		return "", false
	}

	cs := make([]string, 0, len(g.as))
	for accessor := range g.as {
		cs = append(cs, fmt.Sprintf("%s.Packed()", accessor))
	}
	return strings.Join(cs, " && "), true
}

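// emitMarshallableForStruct outputs code to implement the
// marshal.Marshallable interface (and io.WriterTo) for a struct type:
// SizeBytes, MarshalBytes, UnmarshalBytes, Packed, MarshalUnsafe,
// UnmarshalUnsafe, CopyOut, CopyIn and WriteTo.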
func (g *interfaceGenerator) emitMarshallableForStruct() {
	// Is g.t a packed struct without considering field types?
	thisPacked := true
	g.forEachField(func(f *ast.Field) {
		if f.Tag != nil {
			if f.Tag.Value == "`marshal:\"unaligned\"`" {
				if thisPacked {
					debugfAt(g.f.Position(g.t.Pos()),
						fmt.Sprintf("Marking type '%s' as not packed due to tag `marshal:\"unaligned\"`.\n", g.t.Name))
					thisPacked = false
				}
			}
		}
	})

	g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n")
	g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName())
	g.inIndent(func() {
		primitiveSize := 0
		var dynamicSizeTerms []string

		g.forEachField(fieldDispatcher{
			primitive: func(n, t *ast.Ident) {
				if size, dynamic := g.scalarSize(t); !dynamic {
					primitiveSize += size
				} else {
					g.recordUsedMarshallable(t.Name)
					dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", t.Name))
				}
			},
			selector: func(n, tX, tSel *ast.Ident) {
				tName := fmt.Sprintf("%s.%s", tX.Name, tSel.Name)
				g.recordUsedImport(tX.Name)
				g.recordUsedMarshallable(tName)
				dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", tName))
			},
			array: func(n, t *ast.Ident, len int) {
				if len < 1 {
					// Zero-length arrays should've been rejected by validateStruct().
					panic("unreachable")
				}
				if size, dynamic := g.scalarSize(t); !dynamic {
					primitiveSize += size * len
				} else {
					g.recordUsedMarshallable(t.Name)
					dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()*%d", t.Name, len))
				}
			},
		}.dispatch)
		g.emit("return %d", primitiveSize)
		if len(dynamicSizeTerms) > 0 {
			g.incIndent()
		}
		{
			for _, d := range dynamicSizeTerms {
				g.emitNoIndent(" +\n")
				g.emit(d)
			}
		}
		if len(dynamicSizeTerms) > 0 {
			g.decIndent()
		}
	})
	g.emit("\n}\n\n")

	g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n")
	g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName())
	g.inIndent(func() {
		g.forEachField(fieldDispatcher{
			primitive: func(n, t *ast.Ident) {
				if n.Name == "_" {
					g.emit("// Padding: dst[:sizeof(%s)] ~= %s(0)\n", t.Name, t.Name)
					if len, dynamic := g.scalarSize(t); !dynamic {
						g.shift("dst", len)
					} else {
						// We can't use shiftDynamic here because we don't have
						// an instance of the dynamic type we can reference here
						// (since the version in this struct is anonymous). Use
						// a typed nil pointer to call SizeBytes() instead.
						g.emit("dst = dst[(*%s)(nil).SizeBytes():]\n", t.Name)
					}
					return
				}
				g.marshalStructFieldScalar(g.fieldAccessor(n), t.Name, "dst")
			},
			selector: func(n, tX, tSel *ast.Ident) {
				g.marshalStructFieldScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "dst")
			},
			array: func(n, t *ast.Ident, size int) {
				if n.Name == "_" {
					g.emit("// Padding: dst[:sizeof(%s)*%d] ~= [%d]%s{0}\n", t.Name, size, size, t.Name)
					if len, dynamic := g.scalarSize(t); !dynamic {
						g.shift("dst", len*size)
					} else {
						// We can't use shiftDynamic here because we don't have
						// an instance of the dynamic type we can reference here
						// (since the version in this struct is anonymous). Use
						// a typed nil pointer to call SizeBytes() instead.
						g.emit("dst = dst[(*%s)(nil).SizeBytes()*%d:]\n", t.Name, size)
					}
					return
				}

				g.emit("for idx := 0; idx < %d; idx++ {\n", size)
				g.inIndent(func() {
					g.marshalStructFieldScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "dst")
				})
				g.emit("}\n")
			},
		}.dispatch)
	})
	g.emit("}\n\n")

	g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n")
	g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName())
	g.inIndent(func() {
		g.forEachField(fieldDispatcher{
			primitive: func(n, t *ast.Ident) {
				if n.Name == "_" {
					g.emit("// Padding: var _ %s ~= src[:sizeof(%s)]\n", t.Name, t.Name)
					if len, dynamic := g.scalarSize(t); !dynamic {
						g.shift("src", len)
					} else {
						// We can't use shiftDynamic here because we don't have
						// an instance of the dynamic type we can reference here
						// (since the version in this struct is anonymous). Use
						// a typed nil pointer to call SizeBytes() instead.
						g.emit("src = src[(*%s)(nil).SizeBytes():]\n", t.Name)
						g.recordPotentiallyNonPackedField(fmt.Sprintf("(*%s)(nil)", t.Name))
					}
					return
				}
				g.unmarshalStructFieldScalar(g.fieldAccessor(n), t.Name, "src")
			},
			selector: func(n, tX, tSel *ast.Ident) {
				g.unmarshalStructFieldScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "src")
			},
			array: func(n, t *ast.Ident, size int) {
				if n.Name == "_" {
					g.emit("// Padding: ~ copy([%d]%s(%s), src[:sizeof(%s)*%d])\n", size, t.Name, g.fieldAccessor(n), t.Name, size)
					if len, dynamic := g.scalarSize(t); !dynamic {
						g.shift("src", len*size)
					} else {
						// We can't use shiftDynamic here because we don't have
						// an instance of the dynamic type we can reference here
						// (since the version in this struct is anonymous). Use
						// a typed nil pointer to call SizeBytes() instead.
						g.emit("src = src[(*%s)(nil).SizeBytes()*%d:]\n", t.Name, size)
					}
					return
				}

				g.emit("for idx := 0; idx < %d; idx++ {\n", size)
				g.inIndent(func() {
					g.unmarshalStructFieldScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "src")
				})
				g.emit("}\n")
			},
		}.dispatch)
	})
	g.emit("}\n\n")

	g.emit("// Packed implements marshal.Marshallable.Packed.\n")
	g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName())
	g.inIndent(func() {
		expr, fieldsMaybePacked := g.areFieldsPackedExpression()
		switch {
		case !thisPacked:
			g.emit("return false\n")
		case fieldsMaybePacked:
			g.emit("return %s\n", expr)
		default:
			g.emit("return true\n")

		}
	})
	g.emit("}\n\n")

	g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n")
	g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName())
	g.inIndent(func() {
		if thisPacked {
			g.recordUsedImport("safecopy")
			g.recordUsedImport("unsafe")
			if cond, ok := g.areFieldsPackedExpression(); ok {
				g.emit("if %s {\n", cond)
				g.inIndent(func() {
					g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r)
				})
				g.emit("} else {\n")
				g.inIndent(func() {
					g.emit("%s.MarshalBytes(dst)\n", g.r)
				})
				g.emit("}\n")
			} else {
				g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r)
			}
		} else {
			g.emit("// Type %s doesn't have a packed layout in memory, fallback to MarshalBytes.\n", g.typeName())
			g.emit("%s.MarshalBytes(dst)\n", g.r)
		}
	})
	g.emit("}\n\n")

	g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n")
	g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName())
	g.inIndent(func() {
		if thisPacked {
			g.recordUsedImport("safecopy")
			g.recordUsedImport("unsafe")
			if cond, ok := g.areFieldsPackedExpression(); ok {
				g.emit("if %s {\n", cond)
				g.inIndent(func() {
					g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r)
				})
				g.emit("} else {\n")
				g.inIndent(func() {
					g.emit("%s.UnmarshalBytes(src)\n", g.r)
				})
				g.emit("}\n")
			} else {
				g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r)
			}
		} else {
			g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName())
			g.emit("%s.UnmarshalBytes(src)\n", g.r)
		}
	})
	g.emit("}\n\n")

	g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n")
	g.recordUsedImport("marshal")
	g.recordUsedImport("usermem")
	g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName())
	g.inIndent(func() {
		fallback := func() {
			g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName())
			g.emit("buf := task.CopyScratchBuffer(%s.SizeBytes())\n", g.r)
			g.emit("%s.MarshalBytes(buf)\n", g.r)
			g.emit("_, err := task.CopyOutBytes(addr, buf)\n")
			g.emit("return err\n")
		}
		if thisPacked {
			g.recordUsedImport("reflect")
			g.recordUsedImport("runtime")
			g.recordUsedImport("unsafe")
			if cond, ok := g.areFieldsPackedExpression(); ok {
				g.emit("if !%s {\n", cond)
				g.inIndent(fallback)
				g.emit("}\n\n")
			}
			// Fast serialization.
			g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r)
			g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r)
			g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n")
			g.emit("ptr := unsafe.Pointer(%s)\n", g.r)
			g.emit("val := uintptr(ptr)\n")
			g.emit("val = val^0\n\n")

			g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r)
			g.emit("var buf []byte\n")
			g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n")
			g.emit("hdr.Data = val\n")
			g.emit("hdr.Len = %s.SizeBytes()\n", g.r)
			g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r)

			g.emit("_, err := task.CopyOutBytes(addr, buf)\n")
			g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r)
			g.emit("// must live until after the CopyOutBytes.\n")
			g.emit("runtime.KeepAlive(%s)\n", g.r)
			g.emit("return err\n")
		} else {
			fallback()
		}
	})
	g.emit("}\n\n")

	g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n")
	g.recordUsedImport("marshal")
	g.recordUsedImport("usermem")
	g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName())
	g.inIndent(func() {
		fallback := func() {
			g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName())
			g.emit("buf := task.CopyScratchBuffer(%s.SizeBytes())\n", g.r)
			g.emit("_, err := task.CopyInBytes(addr, buf)\n")
			g.emit("if err != nil {\n")
			g.inIndent(func() {
				g.emit("return err\n")
			})
			g.emit("}\n")

			g.emit("%s.UnmarshalBytes(buf)\n", g.r)
			g.emit("return nil\n")
		}
		if thisPacked {
			g.recordUsedImport("reflect")
			g.recordUsedImport("runtime")
			g.recordUsedImport("unsafe")
			if cond, ok := g.areFieldsPackedExpression(); ok {
				g.emit("if !%s {\n", cond)
				g.inIndent(fallback)
				g.emit("}\n\n")
			}
			// Fast deserialization.
			g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r)
			g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r)
			g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n")
			g.emit("ptr := unsafe.Pointer(%s)\n", g.r)
			g.emit("val := uintptr(ptr)\n")
			g.emit("val = val^0\n\n")

			g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r)
			g.emit("var buf []byte\n")
			g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n")
			g.emit("hdr.Data = val\n")
			g.emit("hdr.Len = %s.SizeBytes()\n", g.r)
			g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r)

			g.emit("_, err := task.CopyInBytes(addr, buf)\n")
			g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r)
			g.emit("// must live until after the CopyInBytes.\n")
			g.emit("runtime.KeepAlive(%s)\n", g.r)
			g.emit("return err\n")
		} else {
			fallback()
		}
	})
	g.emit("}\n\n")

	g.emit("// WriteTo implements io.WriterTo.WriteTo.\n")
	g.recordUsedImport("io")
	g.emit("func (%s *%s) WriteTo(w io.Writer) (int64, error) {\n", g.r, g.typeName())
	g.inIndent(func() {
		fallback := func() {
			g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName())
			g.emit("buf := make([]byte, %s.SizeBytes())\n", g.r)
			g.emit("%s.MarshalBytes(buf)\n", g.r)
			g.emit("n, err := w.Write(buf)\n")
			g.emit("return int64(n), err\n")
		}
		if thisPacked {
			g.recordUsedImport("reflect")
			g.recordUsedImport("runtime")
			g.recordUsedImport("unsafe")
			if cond, ok := g.areFieldsPackedExpression(); ok {
				g.emit("if !%s {\n", cond)
				g.inIndent(fallback)
				g.emit("}\n\n")
			}
			// Fast serialization.
			g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r)
			g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r)
			g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n")
			g.emit("ptr := unsafe.Pointer(%s)\n", g.r)
			g.emit("val := uintptr(ptr)\n")
			g.emit("val = val^0\n\n")

			g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r)
			g.emit("var buf []byte\n")
			g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n")
			g.emit("hdr.Data = val\n")
			g.emit("hdr.Len = %s.SizeBytes()\n", g.r)
			g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r)

			g.emit("len, err := w.Write(buf)\n")
			g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r)
			g.emit("// must live until after the Write.\n")
			g.emit("runtime.KeepAlive(%s)\n", g.r)
			g.emit("return int64(len), err\n")
		} else {
			fallback()
		}
	})
	g.emit("}\n\n")
}

// emitMarshallableForPrimitiveNewtype outputs code to implement the
// marshal.Marshallable interface for a newtype on a primitive. Primitive
// newtypes are always packed, so we can omit the various fallbacks required for
// non-packed structs.
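//
// For example, a hypothetical declaration "type Signo int32" yields a
// constant SizeBytes of 4, a trivially-true Packed, and direct
// safecopy-based MarshalUnsafe/UnmarshalUnsafe with no packing checks.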
func (g *interfaceGenerator) emitMarshallableForPrimitiveNewtype() {
	g.recordUsedImport("io")
	g.recordUsedImport("marshal")
	g.recordUsedImport("reflect")
	g.recordUsedImport("runtime")
	g.recordUsedImport("safecopy")
	g.recordUsedImport("unsafe")
	g.recordUsedImport("usermem")

	nt := g.t.Type.(*ast.Ident)

	g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n")
	g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName())
	g.inIndent(func() {
		if size, dynamic := g.scalarSize(nt); !dynamic {
			g.emit("return %d\n", size)
		} else {
			g.emit("return (*%s)(nil).SizeBytes()\n", nt.Name)
		}
	})
	g.emit("}\n\n")

	g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n")
	g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName())
	g.inIndent(func() {
		g.marshalPrimitiveScalar(g.r, nt.Name, "dst")
	})
	g.emit("}\n\n")

	g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n")
	g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName())
	g.inIndent(func() {
		g.unmarshalPrimitiveScalar(g.r, nt.Name, "src", g.typeName())
	})
	g.emit("}\n\n")

	g.emit("// Packed implements marshal.Marshallable.Packed.\n")
	g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName())
	g.inIndent(func() {
		g.emit("// Scalar newtypes are always packed.\n")
		g.emit("return true\n")
	})
	g.emit("}\n\n")

	g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n")
	g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName())
	g.inIndent(func() {
		g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r)
	})
	g.emit("}\n\n")

	g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n")
	g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName())
	g.inIndent(func() {
		g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r)
	})
	g.emit("}\n\n")

	g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n")
	g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName())
	g.inIndent(func() {
		// Fast serialization.
		g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r)
		g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r)
		g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n")
		g.emit("ptr := unsafe.Pointer(%s)\n", g.r)
		g.emit("val := uintptr(ptr)\n")
		g.emit("val = val^0\n\n")

		g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r)
		g.emit("var buf []byte\n")
		g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n")
		g.emit("hdr.Data = val\n")
		g.emit("hdr.Len = %s.SizeBytes()\n", g.r)
		g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r)

		g.emit("_, err := task.CopyOutBytes(addr, buf)\n")
		g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r)
		g.emit("// must live until after the CopyOutBytes.\n")
		g.emit("runtime.KeepAlive(%s)\n", g.r)
		g.emit("return err\n")
	})
	g.emit("}\n\n")

	g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n")
	g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName())
	g.inIndent(func() {
		g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r)
		g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r)
		g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n")
		g.emit("ptr := unsafe.Pointer(%s)\n", g.r)
		g.emit("val := uintptr(ptr)\n")
		g.emit("val = val^0\n\n")

		g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r)
		g.emit("var buf []byte\n")
		g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n")
		g.emit("hdr.Data = val\n")
		g.emit("hdr.Len = %s.SizeBytes()\n", g.r)
		g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r)

		g.emit("_, err := task.CopyInBytes(addr, buf)\n")
		g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r)
		g.emit("// must live until after the CopyInBytes.\n")
		g.emit("runtime.KeepAlive(%s)\n", g.r)
		g.emit("return err\n")
	})
	g.emit("}\n\n")

	g.emit("// WriteTo implements io.WriterTo.WriteTo.\n")
	g.emit("func (%s *%s) WriteTo(w io.Writer) (int64, error) {\n", g.r, g.typeName())
	g.inIndent(func() {
		g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r)
		g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r)
		g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n")
		g.emit("ptr := unsafe.Pointer(%s)\n", g.r)
		g.emit("val := uintptr(ptr)\n")
		g.emit("val = val^0\n\n")

		g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r)
		g.emit("var buf []byte\n")
		g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n")
		g.emit("hdr.Data = val\n")
		g.emit("hdr.Len = %s.SizeBytes()\n", g.r)
		g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r)

		g.emit("len, err := w.Write(buf)\n")
		g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r)
		g.emit("// must live until after the Write.\n")
		g.emit("runtime.KeepAlive(%s)\n", g.r)
		g.emit("return int64(len), err\n")

	})
	g.emit("}\n\n")

}