// Copyright 2019 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package gomarshal import ( "bytes" "flag" "fmt" "go/ast" "go/token" "io" "os" "path" "reflect" "sort" "strings" ) var debug = flag.Bool("debug", false, "enables debugging output") // receiverName returns an appropriate receiver name given a type spec. func receiverName(t *ast.TypeSpec) string { if len(t.Name.Name) < 1 { // Zero length type name? panic("unreachable") } return strings.ToLower(t.Name.Name[:1]) } // kindString returns a user-friendly representation of an AST expr type. func kindString(e ast.Expr) string { switch e.(type) { case *ast.Ident: return "scalar" case *ast.ArrayType: return "array" case *ast.StructType: return "struct" case *ast.StarExpr: return "pointer" case *ast.FuncType: return "function" case *ast.InterfaceType: return "interface" case *ast.MapType: return "map" case *ast.ChanType: return "channel" default: return reflect.TypeOf(e).String() } } func forEachStructField(st *ast.StructType, fn func(f *ast.Field)) { for _, field := range st.Fields.List { fn(field) } } // fieldDispatcher is a collection of callbacks for handling different types of // fields in a struct declaration. type fieldDispatcher struct { primitive func(n, t *ast.Ident) selector func(n, tX, tSel *ast.Ident) array func(n *ast.Ident, a *ast.ArrayType, t *ast.Ident) unhandled func(n *ast.Ident) } // Precondition: All dispatch callbacks that will be invoked must be // provided. 
Embedded fields are not allowed, len(f.Names) >= 1. func (fd fieldDispatcher) dispatch(f *ast.Field) { // Each field declaration may actually be multiple declarations of the same // type. For example, consider: // // type Point struct { // x, y, z int // } // // We invoke the call-backs once per such instance. Embedded fields are not // allowed, and results in a panic. if len(f.Names) < 1 { panic("Precondition not met: attempted to dispatch on embedded field") } for _, name := range f.Names { switch v := f.Type.(type) { case *ast.Ident: fd.primitive(name, v) case *ast.SelectorExpr: fd.selector(name, v.X.(*ast.Ident), v.Sel) case *ast.ArrayType: switch t := v.Elt.(type) { case *ast.Ident: fd.array(name, v, t) default: // Should be handled with a better error message during validate. panic(fmt.Sprintf("Array element type is of unsupported kind. Expected *ast.Ident, got %v", t)) } default: fd.unhandled(name) } } } // debugEnabled indicates whether debugging is enabled for gomarshal. func debugEnabled() bool { return *debug } // abort aborts the go_marshal tool with the given error message. func abort(msg string) { if !strings.HasSuffix(msg, "\n") { msg += "\n" } fmt.Print(msg) os.Exit(1) } // abortAt aborts the go_marshal tool with the given error message, with // a reference position to the input source. func abortAt(p token.Position, msg string) { abort(fmt.Sprintf("%v:\n %s\n", p, msg)) } // debugf conditionally prints a debug message. func debugf(f string, a ...interface{}) { if debugEnabled() { fmt.Printf(f, a...) } } // debugfAt conditionally prints a debug message with a reference to a position // in the input source. func debugfAt(p token.Position, f string, a ...interface{}) { if debugEnabled() { fmt.Printf("%s:\n %s", p, fmt.Sprintf(f, a...)) } } // emit generates a line of code in the output file. // // emit is a wrapper around writing a formatted string to the output // buffer. 
// emit can be invoked in one of two ways:
//
// (1) emit("some string")
//     When emit is called with a single string argument, it is simply copied to
//     the output buffer without any further formatting.
// (2) emit(fmtString, args...)
//     emit can also be invoked in a similar fashion to *Printf() functions,
//     where the first argument is a format string.
//
// Calling emit with a single argument that is not a string will result in a
// panic, as the caller's intent is ambiguous.
//
// The output is prefixed with indent levels of indentation, four spaces each.
// emit panics if called with no arguments or if a write to out fails.
func emit(out io.Writer, indent int, a ...interface{}) {
	const spacesPerIndentLevel = 4

	if len(a) < 1 {
		panic("emit() called with no arguments")
	}

	if indent > 0 {
		if _, err := fmt.Fprint(out, strings.Repeat(" ", indent*spacesPerIndentLevel)); err != nil {
			// Writing to the emit output should not fail. Typically the output
			// is a bytes.Buffer; writes to these never fail.
			panic(err)
		}
	}

	first, ok := a[0].(string)
	if !ok {
		// First argument must be either the string to emit (case 1 from
		// function-level comment), or a format string (case 2).
		panic(fmt.Sprintf("First argument to emit() is not a string: %+v", a[0]))
	}

	if len(a) == 1 {
		// Single string argument. Assume no formatting requested.
		if _, err := fmt.Fprint(out, first); err != nil {
			// Writing to out should not fail.
			panic(err)
		}
		return
	}

	// Formatting requested.
	if _, err := fmt.Fprintf(out, first, a[1:]...); err != nil {
		// Writing to out should not fail.
		panic(err)
	}
}

// sourceBuffer represents fragments of generated go source code.
//
// sourceBuffer provides a convenient way to build up go source fragments in
// memory. May be safely zero-value initialized. Not thread-safe.
type sourceBuffer struct {
	// Current indentation level.
	indent int

	// Memory buffer containing contents while they're being generated.
	b bytes.Buffer
}

// reset discards the buffered contents and returns to indentation level zero.
func (b *sourceBuffer) reset() {
	b.indent = 0
	b.b.Reset()
}

// incIndent increases the indentation level by one.
func (b *sourceBuffer) incIndent() {
	b.indent++
}

// decIndent decreases the indentation level by one. It panics if the level is
// already zero, i.e. there is no matching incIndent.
func (b *sourceBuffer) decIndent() {
	if b.indent <= 0 {
		panic("decIndent() without matching incIndent()")
	}
	b.indent--
}

// emit appends to the buffer at the current indentation level. See the
// package-level emit for how the arguments are interpreted.
func (b *sourceBuffer) emit(a ...interface{}) {
	emit(&b.b, b.indent, a...)
}

// emitNoIndent appends to the buffer without any indentation, regardless of
// the current indentation level.
func (b *sourceBuffer) emitNoIndent(a ...interface{}) {
	emit(&b.b, 0 /*indent*/, a...)
}

// inIndent runs body with the indentation level raised by one, restoring the
// previous level once body returns.
func (b *sourceBuffer) inIndent(body func()) {
	b.incIndent()
	body()
	b.decIndent()
}

// write copies the buffered contents to out. The buffer itself is left
// untouched, so write may be called more than once.
func (b *sourceBuffer) write(out io.Writer) error {
	// Hand the bytes to out directly; converting to a string first would copy
	// the whole buffer for no benefit.
	_, err := out.Write(b.b.Bytes())
	return err
}

// Write implements io.Writer.Write.
func (b *sourceBuffer) Write(buf []byte) (int, error) {
	return b.b.Write(buf)
}

// importStmt represents a single import statement.
type importStmt struct {
	// Local name of the imported package.
	name string
	// Import path.
	path string
	// Indicates whether the local name is an alias, or simply the final
	// component of the path.
	aliased bool
	// Indicates whether this import was referenced by generated code.
	used bool

	// AST node and file set representing the import statement, if any. These
	// are only non-nil if the import statement originates from an input source
	// file.
	spec *ast.ImportSpec
	fset *token.FileSet
}

// newImport creates an unaliased import statement for import path p. The
// local name is the final component of p.
func newImport(p string) *importStmt {
	name := path.Base(p)
	return &importStmt{
		name:    name,
		path:    p,
		aliased: false,
	}
}

// newImportFromSpec creates an import statement from an import spec found in
// an input source file. f is the file set spec belongs to and is used for
// reporting positions.
//
// Panics if no usable local name can be derived from the import path.
func newImportFromSpec(spec *ast.ImportSpec, f *token.FileSet) *importStmt {
	// Strip the " quotes around path.
	p := spec.Path.Value[1 : len(spec.Path.Value)-1]
	name := path.Base(p)
	if name == "" || name == "/" || name == "." {
		panic(fmt.Sprintf("Couldn't process local package name for import at %s, (processed as %s)",
			f.Position(spec.Path.Pos()), name))
	}
	// An explicit local name in the source overrides the path-derived one.
	if spec.Name != nil {
		name = spec.Name.Name
	}
	return &importStmt{
		name:    name,
		path:    p,
		aliased: spec.Name != nil,
		spec:    spec,
		fset:    f,
	}
}

// String implements fmt.Stringer.String. This generates a string for the import
// statement appropriate for writing directly to generated code.
func (i *importStmt) String() string {
	if i.aliased {
		return fmt.Sprintf("%s %q", i.name, i.path)
	}
	return fmt.Sprintf("%q", i.path)
}

// debugString returns a debug string representing an import statement. This
// representation is not valid golang code and is used for debugging output.
func (i *importStmt) debugString() string {
	if i.spec != nil && i.fset != nil {
		return fmt.Sprintf("%s: %s", i.fset.Position(i.spec.Path.Pos()), i)
	}
	return fmt.Sprintf("(go-marshal import): %s", i)
}

// markUsed records that generated code references this import.
func (i *importStmt) markUsed() {
	i.used = true
}

// equivalent reports whether i and other would produce the same import
// statement: same local name, same path and same aliasing.
func (i *importStmt) equivalent(other *importStmt) bool {
	return i.name == other.name && i.path == other.path && i.aliased == other.aliased
}

// importTable represents a collection of importStmts.
//
// An importTable may contain multiple import statements referencing the same
// local name. All import statements aliasing to the same local name are
// technically ambiguous, as if such an import name is used in the generated
// code, it's not clear which import statement it refers to. We ignore any
// potential collisions until actually writing the import table to the generated
// source file. See importTable.write.
//
// Given the following import statements across all the files comprising a
// package marshalled:
//
//	"sync"
//	"pkg/sync"
//	"pkg/sentry/kernel"
//	ktime "pkg/sentry/kernel/time"
//
// An importTable representing them would look like this:
//
//	importTable {
//	    is: map[string][]*importStmt {
//	        "sync": []*importStmt{
//	            importStmt{name:"sync", path:"sync", aliased:false}
//	            importStmt{name:"sync", path:"pkg/sync", aliased:false}
//	        },
//	        "kernel": []*importStmt{importStmt{
//	            name: "kernel",
//	            path: "pkg/sentry/kernel",
//	            aliased: false
//	        }},
//	        "ktime": []*importStmt{importStmt{
//	            name: "ktime",
//	            path: "pkg/sentry/kernel/time",
//	            aliased: true,
//	        }},
//	    }
//	}
//
// Note that the local name "sync" is assigned to two different import
// statements. This is possible if the import statements are from different
// source files in the same package.
//
// Since go-marshal generates a single output file per package regardless of the
// number of input files, if "sync" is referenced by any generated code, it's
// unclear which import statement "sync" refers to. While it's theoretically
// possible to resolve this by assigning a unique local alias to each instance
// of the sync package, go-marshal currently aborts when it encounters such an
// ambiguity.
//
// TODO(b/151478251): importTable considers the final component of an import
// path to be the package name, but this is only a convention. The actual
// package name is determined by the package statement in the source files for
// the package.
type importTable struct {
	// Map of imports and whether they should be copied to the output.
	is map[string][]*importStmt
}

// newImportTable creates an empty importTable.
func newImportTable() *importTable {
	return &importTable{
		is: make(map[string][]*importStmt),
	}
}

// Merges import statements from other into i.
func (i *importTable) merge(other *importTable) {
	for name, ims := range other.is {
		i.is[name] = append(i.is[name], ims...)
	}
}

// addStmt registers s under its local name and returns s.
func (i *importTable) addStmt(s *importStmt) *importStmt {
	i.is[s.name] = append(i.is[s.name], s)
	return s
}

// add registers an unaliased import of path s and returns the new statement.
func (i *importTable) add(s string) *importStmt {
	n := newImport(s)
	return i.addStmt(n)
}

// addFromSpec registers the import described by spec, which belongs to file
// set f, and returns the new statement.
func (i *importTable) addFromSpec(spec *ast.ImportSpec, f *token.FileSet) *importStmt {
	return i.addStmt(newImportFromSpec(spec, f))
}

// Marks the import named n as used. If no such import is in the table, returns
// false.
//
// Every statement sharing the local name n is marked; collisions between them
// are only diagnosed by write.
func (i *importTable) markUsed(n string) bool {
	stmts, ok := i.is[n]
	if !ok {
		return false
	}
	for _, s := range stmts {
		s.markUsed()
	}
	return true
}

// clear marks every import in the table as unused.
func (i *importTable) clear() {
	for _, stmts := range i.is {
		for _, s := range stmts {
			s.used = false
		}
	}
}

// write emits an import block containing every used import to out, sorted by
// statement text. Nothing is written if the table is empty.
//
// write panics if two or more used imports share a local name without being
// equivalent, since the generated code's reference would be ambiguous. The
// returned error is the result of writing to out.
func (i *importTable) write(out io.Writer) error {
	if len(i.is) == 0 {
		// Nothing to import, we're done.
		return nil
	}

	// Visit local names in sorted order. Map iteration order is random, so
	// without this the ambiguity reported below would differ from run to run
	// whenever more than one local name collides.
	names := make([]string, 0, len(i.is))
	for name := range i.is {
		names = append(names, name)
	}
	sort.Strings(names)

	imports := make([]string, 0, len(names))
	for _, name := range names {
		stmts := i.is[name]

		var lastUsed *importStmt
		var ambiguous bool
		for _, s := range stmts {
			if !s.used {
				continue
			}
			if lastUsed != nil && !s.equivalent(lastUsed) {
				ambiguous = true
			}
			lastUsed = s
		}

		if ambiguous {
			// We have two or more import statements across the different source
			// files that share a local name, and at least one of these imports
			// are used by the generated code. This ambiguity can't be resolved
			// by go-marshal and requires the user intervention. Dump a list of
			// the colliding import statements and let the user modify the input
			// files as appropriate.
			var b strings.Builder
			fmt.Fprintf(&b, "The imported name %q is used by one of the types marked for marshalling, and which import statement the code refers to is ambiguous. Perhaps give the imports unique local names?\n\n", name)
			fmt.Fprintf(&b, "The following %d import statements are ambiguous for the local name %q:\n", len(stmts), name)
			// Note: ambiguous can only be true if at least two statements were
			// seen, so stmts is non-empty and the slicing below is safe.
			for _, s := range stmts[:len(stmts)-1] {
				fmt.Fprintf(&b, " %v\n", s.debugString())
			}
			fmt.Fprintf(&b, " %v", stmts[len(stmts)-1].debugString())
			panic(b.String())
		}

		if lastUsed != nil {
			imports = append(imports, lastUsed.String())
		}
	}
	sort.Strings(imports)

	var b sourceBuffer
	b.emit("import (\n")
	b.incIndent()
	for _, stmt := range imports {
		b.emit("%s\n", stmt)
	}
	b.decIndent()
	b.emit(")\n\n")

	return b.write(out)
}