// Copyright 2019 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package gomarshal import ( "bytes" "flag" "fmt" "go/ast" "go/token" "io" "os" "path" "reflect" "sort" "strconv" "strings" ) var debug = flag.Bool("debug", false, "enables debugging output") // receiverName returns an appropriate receiver name given a type spec. func receiverName(t *ast.TypeSpec) string { if len(t.Name.Name) < 1 { // Zero length type name? panic("unreachable") } return strings.ToLower(t.Name.Name[:1]) } // kindString returns a user-friendly representation of an AST expr type. func kindString(e ast.Expr) string { switch e.(type) { case *ast.Ident: return "scalar" case *ast.ArrayType: return "array" case *ast.StructType: return "struct" case *ast.StarExpr: return "pointer" case *ast.FuncType: return "function" case *ast.InterfaceType: return "interface" case *ast.MapType: return "map" case *ast.ChanType: return "channel" default: return reflect.TypeOf(e).String() } } func forEachStructField(st *ast.StructType, fn func(f *ast.Field)) { for _, field := range st.Fields.List { fn(field) } } // fieldDispatcher is a collection of callbacks for handling different types of // fields in a struct declaration. type fieldDispatcher struct { primitive func(n, t *ast.Ident) selector func(n, tX, tSel *ast.Ident) array func(n, t *ast.Ident, size int) unhandled func(n *ast.Ident) } // Precondition: a must have a literal for the array length. Consts and // expressions are not allowed as array lengths, and should be rejected by the // caller. func arrayLen(a *ast.ArrayType) int { if a.Len == nil { // Probably a slice? Must be handled by caller. panic("Nil array length in array type") } lenLit, ok := a.Len.(*ast.BasicLit) if !ok { panic("Array has non-literal for length") } len, err := strconv.Atoi(lenLit.Value) if err != nil { panic(fmt.Sprintf("Failed to parse array length '%s' as number: %v", lenLit.Value, err)) } return len } // Precondition: All dispatch callbacks that will be invoked must be // provided. Embedded fields are not allowed, len(f.Names) >= 1. func (fd fieldDispatcher) dispatch(f *ast.Field) { // Each field declaration may actually be multiple declarations of the same // type. For example, consider: // // type Point struct { // x, y, z int // } // // We invoke the call-backs once per such instance. Embedded fields are not // allowed, and results in a panic. if len(f.Names) < 1 { panic("Precondition not met: attempted to dispatch on embedded field") } for _, name := range f.Names { switch v := f.Type.(type) { case *ast.Ident: fd.primitive(name, v) case *ast.SelectorExpr: fd.selector(name, v.X.(*ast.Ident), v.Sel) case *ast.ArrayType: switch t := v.Elt.(type) { case *ast.Ident: fd.array(name, t, arrayLen(v)) default: // Should be handled with a better error message during validate. panic(fmt.Sprintf("Array element type is of unsupported kind. Expected *ast.Ident, got %v", t)) } default: fd.unhandled(name) } } } // debugEnabled indicates whether debugging is enabled for gomarshal. func debugEnabled() bool { return *debug } // abort aborts the go_marshal tool with the given error message. func abort(msg string) { if !strings.HasSuffix(msg, "\n") { msg += "\n" } fmt.Print(msg) os.Exit(1) } // abortAt aborts the go_marshal tool with the given error message, with // a reference position to the input source. func abortAt(p token.Position, msg string) { abort(fmt.Sprintf("%v:\n %s\n", p, msg)) } // debugf conditionally prints a debug message. func debugf(f string, a ...interface{}) { if debugEnabled() { fmt.Printf(f, a...) } } // debugfAt conditionally prints a debug message with a reference to a position // in the input source. func debugfAt(p token.Position, f string, a ...interface{}) { if debugEnabled() { fmt.Printf("%s:\n %s", p, fmt.Sprintf(f, a...)) } } // emit generates a line of code in the output file. // // emit is a wrapper around writing a formatted string to the output // buffer. emit can be invoked in one of two ways: // // (1) emit("some string") // When emit is called with a single string argument, it is simply copied to // the output buffer without any further formatting. // (2) emit(fmtString, args...) // emit can also be invoked in a similar fashion to *Printf() functions, // where the first argument is a format string. // // Calling emit with a single argument that is not a string will result in a // panic, as the caller's intent is ambiguous. func emit(out io.Writer, indent int, a ...interface{}) { const spacesPerIndentLevel = 4 if len(a) < 1 { panic("emit() called with no arguments") } if indent > 0 { if _, err := fmt.Fprint(out, strings.Repeat(" ", indent*spacesPerIndentLevel)); err != nil { // Writing to the emit output should not fail. Typically the output // is a byte.Buffer; writes to these never fail. panic(err) } } first, ok := a[0].(string) if !ok { // First argument must be either the string to emit (case 1 from // function-level comment), or a format string (case 2). panic(fmt.Sprintf("First argument to emit() is not a string: %+v", a[0])) } if len(a) == 1 { // Single string argument. Assume no formatting requested. if _, err := fmt.Fprint(out, first); err != nil { // Writing to out should not fail. panic(err) } return } // Formatting requested. if _, err := fmt.Fprintf(out, first, a[1:]...); err != nil { // Writing to out should not fail. panic(err) } } // sourceBuffer represents fragments of generated go source code. // // sourceBuffer provides a convenient way to build up go souce fragments in // memory. May be safely zero-value initialized. Not thread-safe. type sourceBuffer struct { // Current indentation level. indent int // Memory buffer containing contents while they're being generated. b bytes.Buffer } func (b *sourceBuffer) reset() { b.indent = 0 b.b.Reset() } func (b *sourceBuffer) incIndent() { b.indent++ } func (b *sourceBuffer) decIndent() { if b.indent <= 0 { panic("decIndent() without matching incIndent()") } b.indent-- } func (b *sourceBuffer) emit(a ...interface{}) { emit(&b.b, b.indent, a...) } func (b *sourceBuffer) emitNoIndent(a ...interface{}) { emit(&b.b, 0 /*indent*/, a...) } func (b *sourceBuffer) inIndent(body func()) { b.incIndent() body() b.decIndent() } func (b *sourceBuffer) write(out io.Writer) error { _, err := fmt.Fprint(out, b.b.String()) return err } // Write implements io.Writer.Write. func (b *sourceBuffer) Write(buf []byte) (int, error) { return (b.b.Write(buf)) } // importStmt represents a single import statement. type importStmt struct { // Local name of the imported package. name string // Import path. path string // Indicates whether the local name is an alias, or simply the final // component of the path. aliased bool // Indicates whether this import was referenced by generated code. used bool } func newImport(p string) *importStmt { name := path.Base(p) return &importStmt{ name: name, path: p, aliased: false, } } func newImportFromSpec(spec *ast.ImportSpec, f *token.FileSet) *importStmt { p := spec.Path.Value[1 : len(spec.Path.Value)-1] // Strip the " quotes around path. name := path.Base(p) if name == "" || name == "/" || name == "." { panic(fmt.Sprintf("Couldn't process local package name for import at %s, (processed as %s)", f.Position(spec.Path.Pos()), name)) } if spec.Name != nil { name = spec.Name.Name } return &importStmt{ name: name, path: p, aliased: spec.Name != nil, } } func (i *importStmt) String() string { if i.aliased { return fmt.Sprintf("%s \"%s\"", i.name, i.path) } return fmt.Sprintf("\"%s\"", i.path) } func (i *importStmt) markUsed() { i.used = true } func (i *importStmt) equivalent(other *importStmt) bool { return i.name == other.name && i.path == other.path && i.aliased == other.aliased } // importTable represents a collection of importStmts. type importTable struct { // Map of imports and whether they should be copied to the output. is map[string]*importStmt } func newImportTable() *importTable { return &importTable{ is: make(map[string]*importStmt), } } // Merges import statements from other into i. Collisions in import statements // result in a panic. func (i *importTable) merge(other *importTable) { for name, im := range other.is { dup, ok := i.is[name] if ok { // When merging two imports, if either are marked used, the merged entry // should also be marked used. im.used = im.used || dup.used if !dup.equivalent(im) { panic(fmt.Sprintf("Found colliding import statements: ours: %+v, other's: %+v", dup, im)) } } i.is[name] = im } } func (i *importTable) addStmt(s *importStmt) *importStmt { if old, ok := i.is[s.name]; ok && !old.equivalent(s) { // We could theoretically handle the collision by assigning a local name // to one of the imports. However, this is a non-trivial transformation. // Given that collisions should be rare, simply panic on collision. panic(fmt.Sprintf("Import collision: old: %s as %v; new: %v as %v", old.path, old.name, s.path, s.name)) } i.is[s.name] = s return s } func (i *importTable) add(s string) *importStmt { n := newImport(s) return i.addStmt(n) } func (i *importTable) addFromSpec(spec *ast.ImportSpec, f *token.FileSet) *importStmt { return i.addStmt(newImportFromSpec(spec, f)) } // Marks the import named n as used. If no such import is in the table, returns // false. func (i *importTable) markUsed(n string) bool { if n, ok := i.is[n]; ok { n.markUsed() return true } return false } func (i *importTable) clear() { for _, i := range i.is { i.used = false } } func (i *importTable) write(out io.Writer) error { if len(i.is) == 0 { // Nothing to import, we're done. return nil } imports := make([]string, 0, len(i.is)) for _, i := range i.is { if i.used { imports = append(imports, i.String()) } } sort.Strings(imports) var b sourceBuffer b.emit("import (\n") b.incIndent() for _, i := range imports { b.emit("%s\n", i) } b.decIndent() b.emit(")\n\n") return b.write(out) }