summaryrefslogtreecommitdiffhomepage
path: root/tools/go_marshal
diff options
context:
space:
mode:
authorRahat Mahmood <rahat@google.com>2020-04-01 16:50:16 -0700
committergVisor bot <gvisor-bot@google.com>2020-04-01 16:51:28 -0700
commit1561ae3037e5a3efdd26320f229fc4c602258dba (patch)
tree23749691318ccd454c6799bc8bd5fe1ff7a8edcb /tools/go_marshal
parentaecd3a25a99bc04fe2c032ea3422f10b2dba3256 (diff)
go-marshal: Allow array lens to be consts and simple expressions.
Previously, go-marshal only allowed literals for array lengths. However, it's very common for ABI structs to have a fix-sized array whose length is defined by a constant; for example PATH_MAX. Having to convert all such arrays to have literal lengths is too awkward. PiperOrigin-RevId: 304289345
Diffstat (limited to 'tools/go_marshal')
-rw-r--r--tools/go_marshal/gomarshal/generator.go2
-rw-r--r--tools/go_marshal/gomarshal/generator_interfaces.go50
-rw-r--r--tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go20
-rw-r--r--tools/go_marshal/gomarshal/generator_interfaces_struct.go41
-rw-r--r--tools/go_marshal/gomarshal/util.go24
-rw-r--r--tools/go_marshal/test/test.go28
6 files changed, 108 insertions, 57 deletions
diff --git a/tools/go_marshal/gomarshal/generator.go b/tools/go_marshal/gomarshal/generator.go
index 43e668b63..177013dbb 100644
--- a/tools/go_marshal/gomarshal/generator.go
+++ b/tools/go_marshal/gomarshal/generator.go
@@ -356,7 +356,7 @@ func (g *Generator) generateOne(t marshallableType, fset *token.FileSet) *interf
case *ast.ArrayType:
i.validateArrayNewtype(t.spec.Name, ty)
// After validate, we can safely call arrayLen.
- i.emitMarshallableForArrayNewtype(t.spec.Name, ty.Elt.(*ast.Ident), arrayLen(ty))
+ i.emitMarshallableForArrayNewtype(t.spec.Name, ty, ty.Elt.(*ast.Ident))
if t.slice != nil {
abortAt(fset.Position(t.slice.comment.Slash), fmt.Sprintf("Array type marked as '+marshal slice:...', but this is not supported. Perhaps fold one of the dimensions?"))
}
diff --git a/tools/go_marshal/gomarshal/generator_interfaces.go b/tools/go_marshal/gomarshal/generator_interfaces.go
index 8f1c27145..e3c3dac63 100644
--- a/tools/go_marshal/gomarshal/generator_interfaces.go
+++ b/tools/go_marshal/gomarshal/generator_interfaces.go
@@ -15,8 +15,10 @@
package gomarshal
import (
+ "fmt"
"go/ast"
"go/token"
+ "strings"
)
// interfaceGenerator generates marshalling interfaces for a single type.
@@ -224,3 +226,51 @@ func (g *interfaceGenerator) emitKeepAlive(ptrVar string) {
g.emit("// must live until the use above.\n")
g.emit("runtime.KeepAlive(%s)\n", ptrVar)
}
+
+func (g *interfaceGenerator) expandBinaryExpr(b *strings.Builder, e *ast.BinaryExpr) {
+ switch x := e.X.(type) {
+ case *ast.BinaryExpr:
+ // Recursively expand sub-expression.
+ g.expandBinaryExpr(b, x)
+ case *ast.Ident:
+ fmt.Fprintf(b, "%s", x.Name)
+ case *ast.BasicLit:
+ fmt.Fprintf(b, "%s", x.Value)
+ default:
+ g.abortAt(e.Pos(), "Cannot convert binary expression to output code. Go-marshal currently only handles simple expressions of literals, constants and basic identifiers")
+ }
+
+ fmt.Fprintf(b, "%s", e.Op)
+
+ switch y := e.Y.(type) {
+ case *ast.BinaryExpr:
+ // Recursively expand sub-expression.
+ g.expandBinaryExpr(b, y)
+ case *ast.Ident:
+ fmt.Fprintf(b, "%s", y.Name)
+ case *ast.BasicLit:
+ fmt.Fprintf(b, "%s", y.Value)
+ default:
+ g.abortAt(e.Pos(), "Cannot convert binary expression to output code. Go-marshal currently only handles simple expressions of literals, constants and basic identifiers")
+ }
+}
+
+// arrayLenExpr returns a string containing a valid golang expression
+// representing the length of array a. The returned expression should be treated
+// as a single value, and will be already parenthesized as required.
+func (g *interfaceGenerator) arrayLenExpr(a *ast.ArrayType) string {
+ var b strings.Builder
+
+ switch l := a.Len.(type) {
+ case *ast.Ident:
+ fmt.Fprintf(&b, "%s", l.Name)
+ case *ast.BasicLit:
+ fmt.Fprintf(&b, "%s", l.Value)
+ case *ast.BinaryExpr:
+ g.expandBinaryExpr(&b, l)
+ return fmt.Sprintf("(%s)", b.String())
+ default:
+ g.abortAt(l.Pos(), "Cannot convert this array len expression to output code. Go-marshal currently only handles simple expressions of literals, constants and basic identifiers")
+ }
+ return b.String()
+}
diff --git a/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go b/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go
index 5ba74a606..8d6f102d5 100644
--- a/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go
+++ b/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go
@@ -27,20 +27,12 @@ func (g *interfaceGenerator) validateArrayNewtype(n *ast.Ident, a *ast.ArrayType
g.abortAt(a.Pos(), fmt.Sprintf("Dynamically sized slice '%s' cannot be marshalled, arrays must be statically sized", n.Name))
}
- if _, ok := a.Len.(*ast.BasicLit); !ok {
- g.abortAt(a.Len.Pos(), fmt.Sprintf("Array size must be a literal, don't use consts or expressions"))
- }
-
if _, ok := a.Elt.(*ast.Ident); !ok {
g.abortAt(a.Elt.Pos(), fmt.Sprintf("Marshalling not supported for arrays with %s elements, array elements must be primitive types", kindString(a.Elt)))
}
-
- if arrayLen(a) <= 0 {
- g.abortAt(a.Len.Pos(), fmt.Sprintf("Marshalling not supported for zero length arrays, why does an ABI struct have one?"))
- }
}
-func (g *interfaceGenerator) emitMarshallableForArrayNewtype(n, elt *ast.Ident, len int) {
+func (g *interfaceGenerator) emitMarshallableForArrayNewtype(n *ast.Ident, a *ast.ArrayType, elt *ast.Ident) {
g.recordUsedImport("io")
g.recordUsedImport("marshal")
g.recordUsedImport("reflect")
@@ -49,13 +41,15 @@ func (g *interfaceGenerator) emitMarshallableForArrayNewtype(n, elt *ast.Ident,
g.recordUsedImport("unsafe")
g.recordUsedImport("usermem")
+ lenExpr := g.arrayLenExpr(a)
+
g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n")
g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName())
g.inIndent(func() {
if size, dynamic := g.scalarSize(elt); !dynamic {
- g.emit("return %d\n", size*len)
+ g.emit("return %d * %s\n", size, lenExpr)
} else {
- g.emit("return (*%s)(nil).SizeBytes() * %d\n", n.Name, len)
+ g.emit("return (*%s)(nil).SizeBytes() * %s\n", n.Name, lenExpr)
}
})
g.emit("}\n\n")
@@ -63,7 +57,7 @@ func (g *interfaceGenerator) emitMarshallableForArrayNewtype(n, elt *ast.Ident,
g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n")
g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName())
g.inIndent(func() {
- g.emit("for idx := 0; idx < %d; idx++ {\n", len)
+ g.emit("for idx := 0; idx < %s; idx++ {\n", lenExpr)
g.inIndent(func() {
g.marshalScalar(fmt.Sprintf("%s[idx]", g.r), elt.Name, "dst")
})
@@ -74,7 +68,7 @@ func (g *interfaceGenerator) emitMarshallableForArrayNewtype(n, elt *ast.Ident,
g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n")
g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName())
g.inIndent(func() {
- g.emit("for idx := 0; idx < %d; idx++ {\n", len)
+ g.emit("for idx := 0; idx < %s; idx++ {\n", lenExpr)
g.inIndent(func() {
g.unmarshalScalar(fmt.Sprintf("%s[idx]", g.r), elt.Name, "src")
})
diff --git a/tools/go_marshal/gomarshal/generator_interfaces_struct.go b/tools/go_marshal/gomarshal/generator_interfaces_struct.go
index e837f58db..4236e978e 100644
--- a/tools/go_marshal/gomarshal/generator_interfaces_struct.go
+++ b/tools/go_marshal/gomarshal/generator_interfaces_struct.go
@@ -62,8 +62,8 @@ func (g *interfaceGenerator) validateStruct(ts *ast.TypeSpec, st *ast.StructType
// No validation to perform on selector fields. However this
// callback must still be provided.
},
- array: func(n, _ *ast.Ident, len int) {
- g.validateArrayNewtype(n, f.Type.(*ast.ArrayType))
+ array: func(n *ast.Ident, a *ast.ArrayType, _ *ast.Ident) {
+ g.validateArrayNewtype(n, a)
},
unhandled: func(_ *ast.Ident) {
g.abortAt(f.Pos(), fmt.Sprintf("Marshalling not supported for %s fields", kindString(f.Type)))
@@ -112,16 +112,13 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) {
g.recordUsedMarshallable(tName)
dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", tName))
},
- array: func(n, t *ast.Ident, len int) {
- if len < 1 {
- // Zero-length arrays should've been rejected by validate().
- panic("unreachable")
- }
+ array: func(n *ast.Ident, a *ast.ArrayType, t *ast.Ident) {
+ lenExpr := g.arrayLenExpr(a)
if size, dynamic := g.scalarSize(t); !dynamic {
- primitiveSize += size * len
+ dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("%d*%s", size, lenExpr))
} else {
g.recordUsedMarshallable(t.Name)
- dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()*%d", t.Name, len))
+ dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()*%s", t.Name, lenExpr))
}
},
}.dispatch)
@@ -169,22 +166,23 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) {
}
g.marshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "dst")
},
- array: func(n, t *ast.Ident, size int) {
+ array: func(n *ast.Ident, a *ast.ArrayType, t *ast.Ident) {
+ lenExpr := g.arrayLenExpr(a)
if n.Name == "_" {
- g.emit("// Padding: dst[:sizeof(%s)*%d] ~= [%d]%s{0}\n", t.Name, size, size, t.Name)
- if len, dynamic := g.scalarSize(t); !dynamic {
- g.shift("dst", len*size)
+ g.emit("// Padding: dst[:sizeof(%s)*%s] ~= [%s]%s{0}\n", t.Name, lenExpr, lenExpr, t.Name)
+ if size, dynamic := g.scalarSize(t); !dynamic {
+ g.emit("dst = dst[%d*(%s):]\n", size, lenExpr)
} else {
// We can't use shiftDynamic here because we don't have
// an instance of the dynamic type we can reference here
// (since the version in this struct is anonymous). Use
// a typed nil pointer to call SizeBytes() instead.
- g.emit("dst = dst[(*%s)(nil).SizeBytes()*%d:]\n", t.Name, size)
+ g.emit("dst = dst[(*%s)(nil).SizeBytes()*(%s):]\n", t.Name, lenExpr)
}
return
}
- g.emit("for idx := 0; idx < %d; idx++ {\n", size)
+ g.emit("for idx := 0; idx < %s; idx++ {\n", lenExpr)
g.inIndent(func() {
g.marshalScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "dst")
})
@@ -224,22 +222,23 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) {
}
g.unmarshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "src")
},
- array: func(n, t *ast.Ident, size int) {
+ array: func(n *ast.Ident, a *ast.ArrayType, t *ast.Ident) {
+ lenExpr := g.arrayLenExpr(a)
if n.Name == "_" {
- g.emit("// Padding: ~ copy([%d]%s(%s), src[:sizeof(%s)*%d])\n", size, t.Name, g.fieldAccessor(n), t.Name, size)
- if len, dynamic := g.scalarSize(t); !dynamic {
- g.shift("src", len*size)
+ g.emit("// Padding: ~ copy([%s]%s(%s), src[:sizeof(%s)*%s])\n", lenExpr, t.Name, g.fieldAccessor(n), t.Name, lenExpr)
+ if size, dynamic := g.scalarSize(t); !dynamic {
+ g.emit("src = src[%d*(%s):]\n", size, lenExpr)
} else {
// We can't use shiftDynamic here because we don't have
// an instance of the dynamic type we can referece here
// (since the version in this struct is anonymous). Use
// a typed nil pointer to call SizeBytes() instead.
- g.emit("src = src[(*%s)(nil).SizeBytes()*%d:]\n", t.Name, size)
+ g.emit("src = src[(*%s)(nil).SizeBytes()*(%s):]\n", t.Name, lenExpr)
}
return
}
- g.emit("for idx := 0; idx < %d; idx++ {\n", size)
+ g.emit("for idx := 0; idx < %s; idx++ {\n", lenExpr)
g.inIndent(func() {
g.unmarshalScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "src")
})
diff --git a/tools/go_marshal/gomarshal/util.go b/tools/go_marshal/gomarshal/util.go
index 96025ff39..d94314302 100644
--- a/tools/go_marshal/gomarshal/util.go
+++ b/tools/go_marshal/gomarshal/util.go
@@ -25,7 +25,6 @@ import (
"path"
"reflect"
"sort"
- "strconv"
"strings"
)
@@ -75,29 +74,10 @@ func forEachStructField(st *ast.StructType, fn func(f *ast.Field)) {
type fieldDispatcher struct {
primitive func(n, t *ast.Ident)
selector func(n, tX, tSel *ast.Ident)
- array func(n, t *ast.Ident, size int)
+ array func(n *ast.Ident, a *ast.ArrayType, t *ast.Ident)
unhandled func(n *ast.Ident)
}
-// Precondition: a must have a literal for the array length. Consts and
-// expressions are not allowed as array lengths, and should be rejected by the
-// caller.
-func arrayLen(a *ast.ArrayType) int {
- if a.Len == nil {
- // Probably a slice? Must be handled by caller.
- panic("Nil array length in array type")
- }
- lenLit, ok := a.Len.(*ast.BasicLit)
- if !ok {
- panic("Array has non-literal for length")
- }
- len, err := strconv.Atoi(lenLit.Value)
- if err != nil {
- panic(fmt.Sprintf("Failed to parse array length '%s' as number: %v", lenLit.Value, err))
- }
- return len
-}
-
// Precondition: All dispatch callbacks that will be invoked must be
// provided. Embedded fields are not allowed, len(f.Names) >= 1.
func (fd fieldDispatcher) dispatch(f *ast.Field) {
@@ -123,7 +103,7 @@ func (fd fieldDispatcher) dispatch(f *ast.Field) {
case *ast.ArrayType:
switch t := v.Elt.(type) {
case *ast.Ident:
- fd.array(name, t, arrayLen(v))
+ fd.array(name, v, t)
default:
// Should be handled with a better error message during validate.
panic(fmt.Sprintf("Array element type is of unsupported kind. Expected *ast.Ident, got %v", t))
diff --git a/tools/go_marshal/test/test.go b/tools/go_marshal/test/test.go
index 43df73545..f75ca1b7f 100644
--- a/tools/go_marshal/test/test.go
+++ b/tools/go_marshal/test/test.go
@@ -146,3 +146,31 @@ type SignalSet uint64
//
// +marshal slice:SignalSetAliasSlice
type SignalSetAlias SignalSet
+
+const sizeA = 64
+const sizeB = 8
+
+// TestArray is a test data structure on an array with a constant length.
+//
+// +marshal
+type TestArray [sizeA]int32
+
+// TestArray2 is a newtype on an array with a simple arithmetic expression of
+// constants for the array length.
+//
+// +marshal
+type TestArray2 [sizeA * sizeB]int32
+
+// TestArray2 is a newtype on an array with a simple arithmetic expression of
+// mixed constants and literals for the array length.
+//
+// +marshal
+type TestArray3 [sizeA*sizeB + 12]int32
+
+// Type9 is a test data type containing an array with a non-literal length.
+//
+// +marshal
+type Type9 struct {
+ x int64
+ y [sizeA]int32
+}