summaryrefslogtreecommitdiffhomepage
path: root/tools/go_marshal/gomarshal/util.go
diff options
context:
space:
mode:
Diffstat (limited to 'tools/go_marshal/gomarshal/util.go')
-rw-r--r--tools/go_marshal/gomarshal/util.go71
1 files changed, 51 insertions, 20 deletions
diff --git a/tools/go_marshal/gomarshal/util.go b/tools/go_marshal/gomarshal/util.go
index 967537abf..a0936e013 100644
--- a/tools/go_marshal/gomarshal/util.go
+++ b/tools/go_marshal/gomarshal/util.go
@@ -64,6 +64,12 @@ func kindString(e ast.Expr) string {
}
}
+func forEachStructField(st *ast.StructType, fn func(f *ast.Field)) {
+ for _, field := range st.Fields.List {
+ fn(field)
+ }
+}
+
// fieldDispatcher is a collection of callbacks for handling different types of
// fields in a struct declaration.
type fieldDispatcher struct {
@@ -73,6 +79,25 @@ type fieldDispatcher struct {
unhandled func(n *ast.Ident)
}
+// Precondition: a must have a literal for the array length. Consts and
+// expressions are not allowed as array lengths, and should be rejected by the
+// caller.
+func arrayLen(a *ast.ArrayType) int {
+ if a.Len == nil {
+ // Probably a slice? Must be handled by caller.
+ panic("Nil array length in array type")
+ }
+ lenLit, ok := a.Len.(*ast.BasicLit)
+ if !ok {
+ panic("Array has non-literal for length")
+ }
+ len, err := strconv.Atoi(lenLit.Value)
+ if err != nil {
+ panic(fmt.Sprintf("Failed to parse array length '%s' as number: %v", lenLit.Value, err))
+ }
+ return len
+}
+
// Precondition: All dispatch callbacks that will be invoked must be
// provided. Embedded fields are not allowed, len(f.Names) >= 1.
func (fd fieldDispatcher) dispatch(f *ast.Field) {
@@ -96,22 +121,12 @@ func (fd fieldDispatcher) dispatch(f *ast.Field) {
case *ast.SelectorExpr:
fd.selector(name, v.X.(*ast.Ident), v.Sel)
case *ast.ArrayType:
- len := 0
- if v.Len != nil {
- // Non-literal array length is handled by generatorInterfaces.validate().
- if lenLit, ok := v.Len.(*ast.BasicLit); ok {
- var err error
- len, err = strconv.Atoi(lenLit.Value)
- if err != nil {
- panic(err)
- }
- }
- }
switch t := v.Elt.(type) {
case *ast.Ident:
- fd.array(name, t, len)
+ fd.array(name, t, arrayLen(v))
default:
- fd.array(name, nil, len)
+ // Should be handled with a better error message during validate.
+ panic(fmt.Sprintf("Array element type is of unsupported kind. Expected *ast.Ident, got %v", t))
}
default:
fd.unhandled(name)
@@ -219,6 +234,11 @@ type sourceBuffer struct {
b bytes.Buffer
}
+func (b *sourceBuffer) reset() {
+ b.indent = 0
+ b.b.Reset()
+}
+
func (b *sourceBuffer) incIndent() {
b.indent++
}
@@ -305,7 +325,7 @@ func (i *importStmt) markUsed() {
}
func (i *importStmt) equivalent(other *importStmt) bool {
- return i == other
+ return i.name == other.name && i.path == other.path && i.aliased == other.aliased
}
// importTable represents a collection of importStmts.
@@ -324,7 +344,7 @@ func newImportTable() *importTable {
// result in a panic.
func (i *importTable) merge(other *importTable) {
for name, im := range other.is {
- if dup, ok := i.is[name]; ok && dup.equivalent(im) {
+ if dup, ok := i.is[name]; ok && !dup.equivalent(im) {
panic(fmt.Sprintf("Found colliding import statements: ours: %+v, other's: %+v", dup, im))
}
@@ -332,16 +352,27 @@ func (i *importTable) merge(other *importTable) {
}
}
+func (i *importTable) addStmt(s *importStmt) *importStmt {
+ if old, ok := i.is[s.name]; ok && !old.equivalent(s) {
+ // A collision should always be between an import inserted by the
+ // go-marshal tool and an import from the original source file (assuming
+ // the original source file was valid). We could theoretically handle
+ // the collision by assigning a local name to our import. However, this
+ // would need to be plumbed throughout the generator. Given that
+ // collisions should be rare, simply panic on collision.
+ panic(fmt.Sprintf("Import collision: old: %s as %v; new: %v as %v", old.path, old.name, s.path, s.name))
+ }
+ i.is[s.name] = s
+ return s
+}
+
func (i *importTable) add(s string) *importStmt {
n := newImport(s)
- i.is[n.name] = n
- return n
+ return i.addStmt(n)
}
func (i *importTable) addFromSpec(spec *ast.ImportSpec, f *token.FileSet) *importStmt {
- n := newImportFromSpec(spec, f)
- i.is[n.name] = n
- return n
+ return i.addStmt(newImportFromSpec(spec, f))
}
// Marks the import named n as used. If no such import is in the table, returns