summaryrefslogtreecommitdiffhomepage
path: root/tools
diff options
context:
space:
mode:
authorGoogler <noreply@google.com>2018-04-27 10:37:02 -0700
committerAdin Scannell <ascannell@google.com>2018-04-28 01:44:26 -0400
commitd02b74a5dcfed4bfc8f2f8e545bca4d2afabb296 (patch)
tree54f95eef73aee6bacbfc736fffc631be2605ed53 /tools
parentf70210e742919f40aa2f0934a22f1c9ba6dada62 (diff)
Check in gVisor.
PiperOrigin-RevId: 194583126 Change-Id: Ica1d8821a90f74e7e745962d71801c598c652463
Diffstat (limited to 'tools')
-rw-r--r--tools/go_generics/BUILD46
-rw-r--r--tools/go_generics/defs.bzl152
-rw-r--r--tools/go_generics/generics.go274
-rw-r--r--tools/go_generics/generics_tests/all_stmts/input.go290
-rw-r--r--tools/go_generics/generics_tests/all_stmts/opts.txt1
-rw-r--r--tools/go_generics/generics_tests/all_stmts/output/output.go288
-rw-r--r--tools/go_generics/generics_tests/all_types/input.go43
-rw-r--r--tools/go_generics/generics_tests/all_types/lib/lib.go17
-rw-r--r--tools/go_generics/generics_tests/all_types/opts.txt1
-rw-r--r--tools/go_generics/generics_tests/all_types/output/output.go41
-rw-r--r--tools/go_generics/generics_tests/consts/input.go26
-rw-r--r--tools/go_generics/generics_tests/consts/opts.txt1
-rw-r--r--tools/go_generics/generics_tests/consts/output/output.go26
-rw-r--r--tools/go_generics/generics_tests/imports/input.go24
-rw-r--r--tools/go_generics/generics_tests/imports/opts.txt1
-rw-r--r--tools/go_generics/generics_tests/imports/output/output.go27
-rw-r--r--tools/go_generics/generics_tests/remove_typedef/input.go37
-rw-r--r--tools/go_generics/generics_tests/remove_typedef/opts.txt1
-rw-r--r--tools/go_generics/generics_tests/remove_typedef/output/output.go29
-rw-r--r--tools/go_generics/generics_tests/simple/input.go45
-rw-r--r--tools/go_generics/generics_tests/simple/opts.txt1
-rw-r--r--tools/go_generics/generics_tests/simple/output/output.go43
-rw-r--r--tools/go_generics/globals/BUILD13
-rw-r--r--tools/go_generics/globals/globals_visitor.go588
-rw-r--r--tools/go_generics/globals/scope.go80
-rwxr-xr-xtools/go_generics/go_generics_unittest.sh70
-rw-r--r--tools/go_generics/imports.go150
-rw-r--r--tools/go_generics/merge.go139
-rw-r--r--tools/go_generics/remove.go105
-rw-r--r--tools/go_generics/rules_tests/BUILD43
-rw-r--r--tools/go_generics/rules_tests/template.go42
-rw-r--r--tools/go_generics/rules_tests/template_test.go48
-rw-r--r--tools/go_stateify/BUILD9
-rw-r--r--tools/go_stateify/defs.bzl77
-rw-r--r--tools/go_stateify/main.go386
35 files changed, 3164 insertions, 0 deletions
diff --git a/tools/go_generics/BUILD b/tools/go_generics/BUILD
new file mode 100644
index 000000000..1afc58625
--- /dev/null
+++ b/tools/go_generics/BUILD
@@ -0,0 +1,46 @@
+package(licenses = ["notice"]) # Apache 2.0
+
+load("@io_bazel_rules_go//go:def.bzl", "go_binary")
+
+go_binary(
+ name = "go_generics",
+ srcs = [
+ "generics.go",
+ "imports.go",
+ "remove.go",
+ ],
+ visibility = ["//visibility:public"],
+ deps = ["//tools/go_generics/globals"],
+)
+
+go_binary(
+ name = "go_merge",
+ srcs = [
+ "merge.go",
+ ],
+ visibility = ["//visibility:public"],
+)
+
+genrule(
+ name = "go_generics_tests",
+ srcs = glob(["generics_tests/**"]) + [":go_generics"],
+ outs = ["go_generics_tests.tgz"],
+ cmd = "tar -czvhf $@ $(SRCS)",
+)
+
+genrule(
+ name = "go_generics_test_bundle",
+ srcs = [
+ ":go_generics_tests.tgz",
+ ":go_generics_unittest.sh",
+ ],
+ outs = ["go_generics_test.sh"],
+ cmd = "cat $(location :go_generics_unittest.sh) $(location :go_generics_tests.tgz) > $@",
+ executable = True,
+)
+
+sh_test(
+ name = "go_generics_test",
+ size = "small",
+ srcs = ["go_generics_test.sh"],
+)
diff --git a/tools/go_generics/defs.bzl b/tools/go_generics/defs.bzl
new file mode 100644
index 000000000..0b2467805
--- /dev/null
+++ b/tools/go_generics/defs.bzl
@@ -0,0 +1,152 @@
+def _go_template_impl(ctx):
+ input = ctx.files.srcs
+ output = ctx.outputs.out
+
+ args = ["-o=%s" % output.path] + [f.path for f in input]
+
+ ctx.actions.run(
+ inputs = input,
+ outputs = [output],
+ mnemonic = "GoGenericsTemplate",
+ progress_message = "Building Go template %s" % ctx.label,
+ arguments = args,
+ executable = ctx.executable._tool,
+ )
+
+ return struct(
+ types = ctx.attr.types,
+ opt_types = ctx.attr.opt_types,
+ consts = ctx.attr.consts,
+ opt_consts = ctx.attr.opt_consts,
+ deps = ctx.attr.deps,
+ file = output,
+ )
+
+"""
+Generates a Go template from a set of Go files.
+
+A Go template is similar to a go library, except that it has certain types that
+can be replaced before usage. For example, one could define a templatized List
+struct, whose elements are of type T, then instantiate that template for
+T=segment, where "segment" is the concrete type.
+
+Args:
+ name: the name of the template.
+ srcs: the list of source files that comprise the template.
+ types: the list of generic types in the template that are required to be specified.
+ opt_types: the list of generic types in the template that can but aren't required to be specified.
+ consts: the list of constants in the template that are required to be specified.
+ opt_consts: the list of constants in the template that can but aren't required to be specified.
+ deps: the list of dependencies.
+"""
+
+go_template = rule(
+ attrs = {
+ "srcs": attr.label_list(
+ mandatory = True,
+ allow_files = True,
+ ),
+ "deps": attr.label_list(allow_files = True),
+ "types": attr.string_list(),
+ "opt_types": attr.string_list(),
+ "consts": attr.string_list(),
+ "opt_consts": attr.string_list(),
+ "_tool": attr.label(
+ executable = True,
+ cfg = "host",
+ default = Label("//tools/go_generics:go_merge"),
+ ),
+ },
+ outputs = {
+ "out": "%{name}_template.go",
+ },
+ implementation = _go_template_impl,
+)
+
+def _go_template_instance_impl(ctx):
+ template = ctx.attr.template
+ output = ctx.outputs.out
+
+ # Check that all required types are defined.
+ for t in template.types:
+ if t not in ctx.attr.types:
+ fail("Missing value for type %s in %s" % (t, ctx.attr.template.label))
+
+ # Check that all defined types are expected by the template.
+ for t in ctx.attr.types:
+ if (t not in template.types) and (t not in template.opt_types):
+ fail("Type %s it not a parameter to %s" % (t, ctx.attr.template.label))
+
+ # Check that all required consts are defined.
+ for t in template.consts:
+ if t not in ctx.attr.consts:
+ fail("Missing value for constant %s in %s" % (t, ctx.attr.template.label))
+
+ # Check that all defined consts are expected by the template.
+ for t in ctx.attr.consts:
+ if (t not in template.consts) and (t not in template.opt_consts):
+ fail("Const %s it not a parameter to %s" % (t, ctx.attr.template.label))
+
+ # Build the argument list.
+ args = ["-i=%s" % template.file.path, "-o=%s" % output.path]
+ args += ["-p=%s" % ctx.attr.package]
+
+ if len(ctx.attr.prefix) > 0:
+ args += ["-prefix=%s" % ctx.attr.prefix]
+
+ if len(ctx.attr.suffix) > 0:
+ args += ["-suffix=%s" % ctx.attr.suffix]
+
+ args += [("-t=%s=%s" % (p[0], p[1])) for p in ctx.attr.types.items()]
+ args += [("-c=%s=%s" % (p[0], p[1])) for p in ctx.attr.consts.items()]
+ args += [("-import=%s=%s" % (p[0], p[1])) for p in ctx.attr.imports.items()]
+
+ ctx.actions.run(
+ inputs = [template.file],
+ outputs = [output],
+ mnemonic = "GoGenericsInstance",
+ progress_message = "Building Go template instance %s" % ctx.label,
+ arguments = args,
+ executable = ctx.executable._tool,
+ )
+
+ # TODO: How can we get the dependencies out?
+ return struct(
+ files = depset([output])
+ )
+
+"""
+Instantiates a Go template by replacing all generic types with concrete ones.
+
+Args:
+ name: the name of the template instance.
+ template: the label of the template to be instatiated.
+ prefix: a prefix to be added to globals in the template.
+ suffix: a suffix to be added to global in the template.
+ types: the map from generic type names to concrete ones.
+ consts: the map from constant names to their values.
+ imports: the map from imports used in types/consts to their import paths.
+ package: the name of the package the instantiated template will be compiled into.
+"""
+
+go_template_instance = rule(
+ attrs = {
+ "template": attr.label(
+ mandatory = True,
+ providers = ["types"],
+ ),
+ "prefix": attr.string(),
+ "suffix": attr.string(),
+ "types": attr.string_dict(),
+ "consts": attr.string_dict(),
+ "imports": attr.string_dict(),
+ "package": attr.string(mandatory = True),
+ "out": attr.output(mandatory = True),
+ "_tool": attr.label(
+ executable = True,
+ cfg = "host",
+ default = Label("//tools/go_generics"),
+ ),
+ },
+ implementation = _go_template_instance_impl,
+)
diff --git a/tools/go_generics/generics.go b/tools/go_generics/generics.go
new file mode 100644
index 000000000..033923442
--- /dev/null
+++ b/tools/go_generics/generics.go
@@ -0,0 +1,274 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// go_generics reads a Go source file and writes a new version of that file with
+// a few transformations applied to each. Namely:
+//
+// 1. Global types can be explicitly renamed with the -t option. For example,
+// if -t=A=B is passed in, all references to A will be replaced with
+// references to B; a function declaration like:
+//
+// func f(arg *A)
+//
+// would be renamed to:
+//
+// fun f(arg *B)
+//
+// 2. Global type definitions and their method sets will be removed when they're
+// being renamed with -t. For example, if -t=A=B is passed in, the following
+// definition and methods that existed in the input file wouldn't exist at
+// all in the output file:
+//
+// type A struct{}
+//
+// func (*A) f() {}
+//
+// 3. All global types, variables, constants and functions (not methods) are
+// prefixed and suffixed based on the option -prefix and -suffix arguments.
+// For example, if -suffix=A is passed in, the following globals:
+//
+// func f()
+// type t struct{}
+//
+// would be renamed to:
+//
+// func fA()
+// type tA struct{}
+//
+// Some special tags are also modified. For example:
+//
+// "state:.(t)"
+//
+// would become:
+//
+// "state:.(tA)"
+//
+// 4. The package is renamed to the value via the -p argument.
+// 5. Value of constants can be modified with -c argument.
+//
+// Note that not just the top-level declarations are renamed, all references to
+// them are also properly renamed as well, taking into account visibility rules
+// and shadowing. For example, if -suffix=A is passed in, the following:
+//
+// var b = 100
+//
+// func f() {
+// g(b)
+// b := 0
+// g(b)
+// }
+//
+// Would be replaced with:
+//
+// var bA = 100
+//
+// func f() {
+// g(bA)
+// b := 0
+// g(b)
+// }
+//
+// Note that the second call to g() kept "b" as an argument because it refers to
+// the local variable "b".
+//
+// Unfortunately, go_generics does not handle anonymous fields with renamed types.
+package main
+
+import (
+ "bytes"
+ "flag"
+ "fmt"
+ "go/ast"
+ "go/format"
+ "go/parser"
+ "go/token"
+ "io/ioutil"
+ "os"
+ "regexp"
+ "strings"
+
+ "gvisor.googlesource.com/gvisor/tools/go_generics/globals"
+)
+
+var (
+ input = flag.String("i", "", "input `file`")
+ output = flag.String("o", "", "output `file`")
+ suffix = flag.String("suffix", "", "`suffix` to add to each global symbol")
+ prefix = flag.String("prefix", "", "`prefix` to add to each global symbol")
+ packageName = flag.String("p", "main", "output package `name`")
+ printAST = flag.Bool("ast", false, "prints the AST")
+ types = make(mapValue)
+ consts = make(mapValue)
+ imports = make(mapValue)
+)
+
+// mapValue implements flag.Value. We use a mapValue flag instead of a regular
+// string flag when we want to allow more than one instance of the flag. For
+// example, we allow several "-t A=B" arguments, and will rename them all.
+type mapValue map[string]string
+
+func (m mapValue) String() string {
+ var b bytes.Buffer
+ first := true
+ for k, v := range m {
+ if !first {
+ b.WriteRune(',')
+ } else {
+ first = false
+ }
+ b.WriteString(k)
+ b.WriteRune('=')
+ b.WriteString(v)
+ }
+ return b.String()
+}
+
+func (m mapValue) Set(s string) error {
+ sep := strings.Index(s, "=")
+ if sep == -1 {
+ return fmt.Errorf("missing '=' from '%s'", s)
+ }
+
+ m[s[:sep]] = s[sep+1:]
+
+ return nil
+}
+
+// stateTagRegexp matches against the 'typed' state tags.
+var stateTagRegexp = regexp.MustCompile(`^(.*[^a-z0-9_])state:"\.\(([^\)]*)\)"(.*)$`)
+
+var identifierRegexp = regexp.MustCompile(`^(.*[^a-zA-Z_])([a-zA-Z_][a-zA-Z0-9_]*)(.*)$`)
+
+func main() {
+ flag.Usage = func() {
+ fmt.Fprintf(os.Stderr, "Usage: %s [options]\n", os.Args[0])
+ flag.PrintDefaults()
+ }
+
+ flag.Var(types, "t", "rename type A to B when `A=B` is passed in. Multiple such mappings are allowed.")
+ flag.Var(consts, "c", "reassign constant A to value B when `A=B` is passed in. Multiple such mappings are allowed.")
+ flag.Var(imports, "import", "specifies the import libraries to use when types are not local. `name=path` specifies that 'name', used in types as name.type, refers to the package living in 'path'.")
+ flag.Parse()
+
+ if *input == "" || *output == "" {
+ flag.Usage()
+ os.Exit(1)
+ }
+
+ // Parse the input file.
+ fset := token.NewFileSet()
+ f, err := parser.ParseFile(fset, *input, nil, parser.ParseComments|parser.DeclarationErrors|parser.SpuriousErrors)
+ if err != nil {
+ fmt.Fprintf(os.Stderr, "%v\n", err)
+ os.Exit(1)
+ }
+
+ // Print the AST if requested.
+ if *printAST {
+ ast.Print(fset, f)
+ }
+
+ cmap := ast.NewCommentMap(fset, f, f.Comments)
+
+ // Update imports based on what's used in types and consts.
+ maps := []mapValue{types, consts}
+ importDecl, err := updateImports(maps, imports)
+ if err != nil {
+ fmt.Fprintf(os.Stderr, "%v\n", err)
+ os.Exit(1)
+ }
+ types = maps[0]
+ consts = maps[1]
+
+ // Reassign all specified constants.
+ for _, decl := range f.Decls {
+ d, ok := decl.(*ast.GenDecl)
+ if !ok || d.Tok != token.CONST {
+ continue
+ }
+
+ for _, gs := range d.Specs {
+ s := gs.(*ast.ValueSpec)
+ for i, id := range s.Names {
+ if n, ok := consts[id.Name]; ok {
+ s.Values[i] = &ast.BasicLit{Value: n}
+ }
+ }
+ }
+ }
+
+ // Go through all globals and their uses in the AST and rename the types
+ // with explicitly provided names, and rename all types, variables,
+ // consts and functions with the provided prefix and suffix.
+ globals.Visit(fset, f, func(ident *ast.Ident, kind globals.SymKind) {
+ if n, ok := types[ident.Name]; ok && kind == globals.KindType {
+ ident.Name = n
+ } else {
+ switch kind {
+ case globals.KindType, globals.KindVar, globals.KindConst, globals.KindFunction:
+ ident.Name = *prefix + ident.Name + *suffix
+ case globals.KindTag:
+ // Modify the state tag appropriately.
+ if m := stateTagRegexp.FindStringSubmatch(ident.Name); m != nil {
+ if t := identifierRegexp.FindStringSubmatch(m[2]); t != nil {
+ ident.Name = m[1] + `state:".(` + t[1] + *prefix + t[2] + *suffix + t[3] + `)"` + m[3]
+ }
+ }
+ }
+ }
+ })
+
+ // Remove the definition of all types that are being remapped.
+ set := make(typeSet)
+ for _, v := range types {
+ set[v] = struct{}{}
+ }
+ removeTypes(set, f)
+
+ // Add the new imports, if any, to the top.
+ if importDecl != nil {
+ newDecls := make([]ast.Decl, 0, len(f.Decls)+1)
+ newDecls = append(newDecls, importDecl)
+ newDecls = append(newDecls, f.Decls...)
+ f.Decls = newDecls
+ }
+
+ // Update comments to remove the ones potentially associated with the
+ // type T that we removed.
+ f.Comments = cmap.Filter(f).Comments()
+
+ // If there are file (package) comments, delete them.
+ if f.Doc != nil {
+ for i, cg := range f.Comments {
+ if cg == f.Doc {
+ f.Comments = append(f.Comments[:i], f.Comments[i+1:]...)
+ break
+ }
+ }
+ }
+
+ // Write the output file.
+ f.Name.Name = *packageName
+
+ var buf bytes.Buffer
+ if err := format.Node(&buf, fset, f); err != nil {
+ fmt.Fprintf(os.Stderr, "%v\n", err)
+ os.Exit(1)
+ }
+
+ if err := ioutil.WriteFile(*output, buf.Bytes(), 0644); err != nil {
+ fmt.Fprintf(os.Stderr, "%v\n", err)
+ os.Exit(1)
+ }
+}
diff --git a/tools/go_generics/generics_tests/all_stmts/input.go b/tools/go_generics/generics_tests/all_stmts/input.go
new file mode 100644
index 000000000..870af3b6c
--- /dev/null
+++ b/tools/go_generics/generics_tests/all_stmts/input.go
@@ -0,0 +1,290 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+import (
+ "sync"
+)
+
+type T int
+
+func h(T) {
+}
+
+type s struct {
+ a, b int
+ c []int
+}
+
+func g(T) *s {
+ return &s{}
+}
+
+func f() (T, []int) {
+ // Branch.
+ goto T
+ goto R
+
+ // Labeled.
+T:
+ _ = T(0)
+
+ // Empty.
+R:
+ ;
+
+ // Assignment with definition.
+ a, b, c := T(1), T(2), T(3)
+ _, _, _ = a, b, c
+
+ // Assignment without definition.
+ g(T(0)).a, g(T(1)).b, c = int(T(1)), int(T(2)), T(3)
+ _, _, _ = a, b, c
+
+ // Block.
+ {
+ var T T
+ T = 0
+ _ = T
+ }
+
+ // Declarations.
+ type Type T
+ const Const T = 10
+ var g1 func(T, int, ...T) (int, T)
+ var v T
+ var w = T(0)
+ {
+ var T struct {
+ f []T
+ }
+ _ = T
+ }
+
+ // Defer.
+ defer g1(T(0), 1)
+
+ // Expression.
+ h(v + w + T(1))
+
+ // For statements.
+ for i := T(0); i < T(10); i++ {
+ var T func(int) T
+ v := T(0)
+ _ = v
+ }
+
+ for {
+ var T func(int) T
+ v := T(0)
+ _ = v
+ }
+
+ // Go.
+ go g1(T(0), 1)
+
+ // If statements.
+ if a != T(1) {
+ var T func(int) T
+ v := T(0)
+ _ = v
+ }
+
+ if a := T(0); a != T(1) {
+ var T func(int) T
+ v := T(0)
+ _ = v
+ }
+
+ if a := T(0); a != T(1) {
+ var T func(int) T
+ v := T(0)
+ _ = v
+ } else if b := T(0); b != T(1) {
+ var T func(int) T
+ v := T(0)
+ _ = v
+ } else if T := T(0); T != 1 {
+ T++
+ } else {
+ T--
+ }
+
+ if a := T(0); a != T(1) {
+ var T func(int) T
+ v := T(0)
+ _ = v
+ } else {
+ var T func(int) T
+ v := T(0)
+ _ = v
+ }
+
+ // Inc/Dec statements.
+ (*(*T)(nil))++
+ (*(*T)(nil))--
+
+ // Range statements.
+ for g(T(0)).a, g(T(1)).b = range g(T(10)).c {
+ var d T
+ _ = d
+ }
+
+ for T, b := range g(T(10)).c {
+ _ = T
+ _ = b
+ }
+
+ // Select statement.
+ {
+ var fch func(T) chan int
+
+ select {
+ case <-fch(T(30)):
+ var T T
+ T = 0
+ _ = T
+ default:
+ var T T
+ T = 0
+ _ = T
+ case T := <-fch(T(30)):
+ T = 0
+ _ = T
+ case g(T(0)).a = <-fch(T(30)):
+ var T T
+ T = 0
+ _ = T
+ case fch(T(30)) <- int(T(0)):
+ var T T
+ T = 0
+ _ = T
+ }
+ }
+
+ // Send statements.
+ {
+ var ch chan T
+ var fch func(T) chan int
+
+ ch <- T(0)
+ fch(T(1)) <- g(T(10)).a
+ }
+
+ // Switch statements.
+ {
+ var a T
+ var b int
+ switch {
+ case a == T(0):
+ var T T
+ T = 0
+ _ = T
+ case a < T(0), b < g(T(10)).a:
+ var T T
+ T = 0
+ _ = T
+ default:
+ var T T
+ T = 0
+ _ = T
+ }
+ }
+
+ switch T(g(T(10)).a) {
+ case T(0):
+ var T T
+ T = 0
+ _ = T
+ case T(1), T(g(T(10)).a):
+ var T T
+ T = 0
+ _ = T
+ default:
+ var T T
+ T = 0
+ _ = T
+ }
+
+ switch b := g(T(10)); T(b.a) + T(10) {
+ case T(0):
+ var T T
+ T = 0
+ _ = T
+ case T(1), T(g(T(10)).a):
+ var T T
+ T = 0
+ _ = T
+ default:
+ var T T
+ T = 0
+ _ = T
+ }
+
+ // Type switch statements.
+ {
+ var interfaceFunc func(T) interface{}
+
+ switch interfaceFunc(T(0)).(type) {
+ case *T, T, int:
+ var T T
+ T = 0
+ _ = T
+ case sync.Mutex, **T:
+ var T T
+ T = 0
+ _ = T
+ default:
+ var T T
+ T = 0
+ _ = T
+ }
+
+ switch x := interfaceFunc(T(0)).(type) {
+ case *T, T, int:
+ var T T
+ T = 0
+ _ = T
+ _ = x
+ case sync.Mutex, **T:
+ var T T
+ T = 0
+ _ = T
+ default:
+ var T T
+ T = 0
+ _ = T
+ }
+
+ switch t := T(0); x := interfaceFunc(T(0) + t).(type) {
+ case *T, T, int:
+ var T T
+ T = 0
+ _ = T
+ _ = x
+ case sync.Mutex, **T:
+ var T T
+ T = 0
+ _ = T
+ default:
+ var T T
+ T = 0
+ _ = T
+ }
+ }
+
+ // Return statement.
+ return T(10), g(T(11)).c
+}
diff --git a/tools/go_generics/generics_tests/all_stmts/opts.txt b/tools/go_generics/generics_tests/all_stmts/opts.txt
new file mode 100644
index 000000000..c9d0e09bf
--- /dev/null
+++ b/tools/go_generics/generics_tests/all_stmts/opts.txt
@@ -0,0 +1 @@
+-t=T=Q
diff --git a/tools/go_generics/generics_tests/all_stmts/output/output.go b/tools/go_generics/generics_tests/all_stmts/output/output.go
new file mode 100644
index 000000000..e4e670bf1
--- /dev/null
+++ b/tools/go_generics/generics_tests/all_stmts/output/output.go
@@ -0,0 +1,288 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+import (
+ "sync"
+)
+
+func h(Q) {
+}
+
+type s struct {
+ a, b int
+ c []int
+}
+
+func g(Q) *s {
+ return &s{}
+}
+
+func f() (Q, []int) {
+ // Branch.
+ goto T
+ goto R
+
+ // Labeled.
+T:
+ _ = Q(0)
+
+ // Empty.
+R:
+ ;
+
+ // Assignment with definition.
+ a, b, c := Q(1), Q(2), Q(3)
+ _, _, _ = a, b, c
+
+ // Assignment without definition.
+ g(Q(0)).a, g(Q(1)).b, c = int(Q(1)), int(Q(2)), Q(3)
+ _, _, _ = a, b, c
+
+ // Block.
+ {
+ var T Q
+ T = 0
+ _ = T
+ }
+
+ // Declarations.
+ type Type Q
+ const Const Q = 10
+ var g1 func(Q, int, ...Q) (int, Q)
+ var v Q
+ var w = Q(0)
+ {
+ var T struct {
+ f []Q
+ }
+ _ = T
+ }
+
+ // Defer.
+ defer g1(Q(0), 1)
+
+ // Expression.
+ h(v + w + Q(1))
+
+ // For statements.
+ for i := Q(0); i < Q(10); i++ {
+ var T func(int) Q
+ v := T(0)
+ _ = v
+ }
+
+ for {
+ var T func(int) Q
+ v := T(0)
+ _ = v
+ }
+
+ // Go.
+ go g1(Q(0), 1)
+
+ // If statements.
+ if a != Q(1) {
+ var T func(int) Q
+ v := T(0)
+ _ = v
+ }
+
+ if a := Q(0); a != Q(1) {
+ var T func(int) Q
+ v := T(0)
+ _ = v
+ }
+
+ if a := Q(0); a != Q(1) {
+ var T func(int) Q
+ v := T(0)
+ _ = v
+ } else if b := Q(0); b != Q(1) {
+ var T func(int) Q
+ v := T(0)
+ _ = v
+ } else if T := Q(0); T != 1 {
+ T++
+ } else {
+ T--
+ }
+
+ if a := Q(0); a != Q(1) {
+ var T func(int) Q
+ v := T(0)
+ _ = v
+ } else {
+ var T func(int) Q
+ v := T(0)
+ _ = v
+ }
+
+ // Inc/Dec statements.
+ (*(*Q)(nil))++
+ (*(*Q)(nil))--
+
+ // Range statements.
+ for g(Q(0)).a, g(Q(1)).b = range g(Q(10)).c {
+ var d Q
+ _ = d
+ }
+
+ for T, b := range g(Q(10)).c {
+ _ = T
+ _ = b
+ }
+
+ // Select statement.
+ {
+ var fch func(Q) chan int
+
+ select {
+ case <-fch(Q(30)):
+ var T Q
+ T = 0
+ _ = T
+ default:
+ var T Q
+ T = 0
+ _ = T
+ case T := <-fch(Q(30)):
+ T = 0
+ _ = T
+ case g(Q(0)).a = <-fch(Q(30)):
+ var T Q
+ T = 0
+ _ = T
+ case fch(Q(30)) <- int(Q(0)):
+ var T Q
+ T = 0
+ _ = T
+ }
+ }
+
+ // Send statements.
+ {
+ var ch chan Q
+ var fch func(Q) chan int
+
+ ch <- Q(0)
+ fch(Q(1)) <- g(Q(10)).a
+ }
+
+ // Switch statements.
+ {
+ var a Q
+ var b int
+ switch {
+ case a == Q(0):
+ var T Q
+ T = 0
+ _ = T
+ case a < Q(0), b < g(Q(10)).a:
+ var T Q
+ T = 0
+ _ = T
+ default:
+ var T Q
+ T = 0
+ _ = T
+ }
+ }
+
+ switch Q(g(Q(10)).a) {
+ case Q(0):
+ var T Q
+ T = 0
+ _ = T
+ case Q(1), Q(g(Q(10)).a):
+ var T Q
+ T = 0
+ _ = T
+ default:
+ var T Q
+ T = 0
+ _ = T
+ }
+
+ switch b := g(Q(10)); Q(b.a) + Q(10) {
+ case Q(0):
+ var T Q
+ T = 0
+ _ = T
+ case Q(1), Q(g(Q(10)).a):
+ var T Q
+ T = 0
+ _ = T
+ default:
+ var T Q
+ T = 0
+ _ = T
+ }
+
+ // Type switch statements.
+ {
+ var interfaceFunc func(Q) interface{}
+
+ switch interfaceFunc(Q(0)).(type) {
+ case *Q, Q, int:
+ var T Q
+ T = 0
+ _ = T
+ case sync.Mutex, **Q:
+ var T Q
+ T = 0
+ _ = T
+ default:
+ var T Q
+ T = 0
+ _ = T
+ }
+
+ switch x := interfaceFunc(Q(0)).(type) {
+ case *Q, Q, int:
+ var T Q
+ T = 0
+ _ = T
+ _ = x
+ case sync.Mutex, **Q:
+ var T Q
+ T = 0
+ _ = T
+ default:
+ var T Q
+ T = 0
+ _ = T
+ }
+
+ switch t := Q(0); x := interfaceFunc(Q(0) + t).(type) {
+ case *Q, Q, int:
+ var T Q
+ T = 0
+ _ = T
+ _ = x
+ case sync.Mutex, **Q:
+ var T Q
+ T = 0
+ _ = T
+ default:
+ var T Q
+ T = 0
+ _ = T
+ }
+ }
+
+ // Return statement.
+ return Q(10), g(Q(11)).c
+}
diff --git a/tools/go_generics/generics_tests/all_types/input.go b/tools/go_generics/generics_tests/all_types/input.go
new file mode 100644
index 000000000..3a8643e3d
--- /dev/null
+++ b/tools/go_generics/generics_tests/all_types/input.go
@@ -0,0 +1,43 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+import "./lib"
+
+type T int
+
+type newType struct {
+ a T
+ b lib.T
+ c *T
+ d (T)
+ e chan T
+ f <-chan T
+ g chan<- T
+ h []T
+ i [10]T
+ j map[T]T
+ k func(T, T) (T, T)
+ l interface {
+ f(T)
+ }
+ m struct {
+ T
+ a T
+ }
+}
+
+func f(...T) {
+}
diff --git a/tools/go_generics/generics_tests/all_types/lib/lib.go b/tools/go_generics/generics_tests/all_types/lib/lib.go
new file mode 100644
index 000000000..d3911d12d
--- /dev/null
+++ b/tools/go_generics/generics_tests/all_types/lib/lib.go
@@ -0,0 +1,17 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package lib
+
+type T int32
diff --git a/tools/go_generics/generics_tests/all_types/opts.txt b/tools/go_generics/generics_tests/all_types/opts.txt
new file mode 100644
index 000000000..c9d0e09bf
--- /dev/null
+++ b/tools/go_generics/generics_tests/all_types/opts.txt
@@ -0,0 +1 @@
+-t=T=Q
diff --git a/tools/go_generics/generics_tests/all_types/output/output.go b/tools/go_generics/generics_tests/all_types/output/output.go
new file mode 100644
index 000000000..b89840936
--- /dev/null
+++ b/tools/go_generics/generics_tests/all_types/output/output.go
@@ -0,0 +1,41 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+import "./lib"
+
+type newType struct {
+ a Q
+ b lib.T
+ c *Q
+ d (Q)
+ e chan Q
+ f <-chan Q
+ g chan<- Q
+ h []Q
+ i [10]Q
+ j map[Q]Q
+ k func(Q, Q) (Q, Q)
+ l interface {
+ f(Q)
+ }
+ m struct {
+ Q
+ a Q
+ }
+}
+
+func f(...Q) {
+}
diff --git a/tools/go_generics/generics_tests/consts/input.go b/tools/go_generics/generics_tests/consts/input.go
new file mode 100644
index 000000000..dabf76e1e
--- /dev/null
+++ b/tools/go_generics/generics_tests/consts/input.go
@@ -0,0 +1,26 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+const c1 = 10
+const x, y, z = 100, 200, 300
+const v float32 = 1.0 + 2.0
+const s = "abc"
+const (
+ A = 10
+ B, C, D = 10, 20, 30
+ S = "abc"
+ T, U, V string = "abc", "def", "ghi"
+)
diff --git a/tools/go_generics/generics_tests/consts/opts.txt b/tools/go_generics/generics_tests/consts/opts.txt
new file mode 100644
index 000000000..4fb59dce8
--- /dev/null
+++ b/tools/go_generics/generics_tests/consts/opts.txt
@@ -0,0 +1 @@
+-c=c1=20 -c=z=600 -c=v=3.3 -c=s="def" -c=A=20 -c=C=100 -c=S="def" -c=T="ABC"
diff --git a/tools/go_generics/generics_tests/consts/output/output.go b/tools/go_generics/generics_tests/consts/output/output.go
new file mode 100644
index 000000000..72865607e
--- /dev/null
+++ b/tools/go_generics/generics_tests/consts/output/output.go
@@ -0,0 +1,26 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+const c1 = 20
+const x, y, z = 100, 200, 600
+const v float32 = 3.3
+const s = "def"
+const (
+ A = 20
+ B, C, D = 10, 100, 30
+ S = "def"
+ T, U, V string = "ABC", "def", "ghi"
+)
diff --git a/tools/go_generics/generics_tests/imports/input.go b/tools/go_generics/generics_tests/imports/input.go
new file mode 100644
index 000000000..66b43fee5
--- /dev/null
+++ b/tools/go_generics/generics_tests/imports/input.go
@@ -0,0 +1,24 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+type T int
+
+var global T
+
+const (
+ m = 0
+ n = 0
+)
diff --git a/tools/go_generics/generics_tests/imports/opts.txt b/tools/go_generics/generics_tests/imports/opts.txt
new file mode 100644
index 000000000..87324be79
--- /dev/null
+++ b/tools/go_generics/generics_tests/imports/opts.txt
@@ -0,0 +1 @@
+-t=T=sync.Mutex -c=n=math.Uint32 -c=m=math.Uint64 -import=sync=sync -import=math=mymathpath
diff --git a/tools/go_generics/generics_tests/imports/output/output.go b/tools/go_generics/generics_tests/imports/output/output.go
new file mode 100644
index 000000000..5f20d43ce
--- /dev/null
+++ b/tools/go_generics/generics_tests/imports/output/output.go
@@ -0,0 +1,27 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+import (
+ __generics_imported1 "mymathpath"
+ __generics_imported0 "sync"
+)
+
+var global __generics_imported0.Mutex
+
+const (
+ m = __generics_imported1.Uint64
+ n = __generics_imported1.Uint32
+)
diff --git a/tools/go_generics/generics_tests/remove_typedef/input.go b/tools/go_generics/generics_tests/remove_typedef/input.go
new file mode 100644
index 000000000..c02307d32
--- /dev/null
+++ b/tools/go_generics/generics_tests/remove_typedef/input.go
@@ -0,0 +1,37 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+func f(T) Q {
+ return Q{}
+}
+
+type T struct{}
+
+type Q struct{}
+
+func (*T) f() {
+}
+
+func (T) g() {
+}
+
+func (*Q) f(T) T {
+ return T{}
+}
+
+func (*Q) g(T) *T {
+ return nil
+}
diff --git a/tools/go_generics/generics_tests/remove_typedef/opts.txt b/tools/go_generics/generics_tests/remove_typedef/opts.txt
new file mode 100644
index 000000000..9c8ecaada
--- /dev/null
+++ b/tools/go_generics/generics_tests/remove_typedef/opts.txt
@@ -0,0 +1 @@
+-t=T=U
diff --git a/tools/go_generics/generics_tests/remove_typedef/output/output.go b/tools/go_generics/generics_tests/remove_typedef/output/output.go
new file mode 100644
index 000000000..d20a89abd
--- /dev/null
+++ b/tools/go_generics/generics_tests/remove_typedef/output/output.go
@@ -0,0 +1,29 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+func f(U) Q {
+ return Q{}
+}
+
+type Q struct{}
+
+func (*Q) f(U) U {
+ return U{}
+}
+
+func (*Q) g(U) *U {
+ return nil
+}
diff --git a/tools/go_generics/generics_tests/simple/input.go b/tools/go_generics/generics_tests/simple/input.go
new file mode 100644
index 000000000..670161d6e
--- /dev/null
+++ b/tools/go_generics/generics_tests/simple/input.go
@@ -0,0 +1,45 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+type T int
+
+var global T
+
+func f(_ T, a int) {
+}
+
+func g(a T, b int) {
+ var c T
+ _ = c
+
+ d := (*T)(nil)
+ _ = d
+}
+
+type R struct {
+ T
+ a *T
+}
+
+var (
+ Z *T = (*T)(nil)
+)
+
+const (
+ X T = (T)(0)
+)
+
+type Y T
diff --git a/tools/go_generics/generics_tests/simple/opts.txt b/tools/go_generics/generics_tests/simple/opts.txt
new file mode 100644
index 000000000..7832ef66f
--- /dev/null
+++ b/tools/go_generics/generics_tests/simple/opts.txt
@@ -0,0 +1 @@
+-t=T=Q -suffix=New
diff --git a/tools/go_generics/generics_tests/simple/output/output.go b/tools/go_generics/generics_tests/simple/output/output.go
new file mode 100644
index 000000000..75b5467cd
--- /dev/null
+++ b/tools/go_generics/generics_tests/simple/output/output.go
@@ -0,0 +1,43 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+var globalNew Q
+
+func fNew(_ Q, a int) {
+}
+
+func gNew(a Q, b int) {
+ var c Q
+ _ = c
+
+ d := (*Q)(nil)
+ _ = d
+}
+
+type RNew struct {
+ Q
+ a *Q
+}
+
+var (
+ ZNew *Q = (*Q)(nil)
+)
+
+const (
+ XNew Q = (Q)(0)
+)
+
+type YNew Q
diff --git a/tools/go_generics/globals/BUILD b/tools/go_generics/globals/BUILD
new file mode 100644
index 000000000..a238becab
--- /dev/null
+++ b/tools/go_generics/globals/BUILD
@@ -0,0 +1,13 @@
+package(licenses = ["notice"]) # Apache 2.0
+
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+
+go_library(
+ name = "globals",
+ srcs = [
+ "globals_visitor.go",
+ "scope.go",
+ ],
+ importpath = "gvisor.googlesource.com/gvisor/tools/go_generics/globals",
+ visibility = ["//tools/go_generics:__pkg__"],
+)
diff --git a/tools/go_generics/globals/globals_visitor.go b/tools/go_generics/globals/globals_visitor.go
new file mode 100644
index 000000000..fc0de4381
--- /dev/null
+++ b/tools/go_generics/globals/globals_visitor.go
@@ -0,0 +1,588 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package globals provides an AST visitor that calls the visit function for all
+// global identifiers.
+package globals
+
+import (
+ "fmt"
+
+ "go/ast"
+ "go/token"
+ "path/filepath"
+ "strconv"
+)
+
+// globalsVisitor holds the state used while traversing the nodes of a file in
+// search of globals.
+//
+// The visitor does two passes on the global declarations: the first one adds
+// all globals to the global scope (since Go allows references to globals that
+// haven't been declared yet), and the second one calls f() for the definition
+// and uses of globals found in the first pass.
+//
+// The implementation correctly handles cases when globals are aliased by
+// locals; in such cases, f() is not called.
+type globalsVisitor struct {
+ // file is the file whose nodes are being visited.
+ file *ast.File
+
+ // fset is the file set the file being visited belongs to.
+ fset *token.FileSet
+
+ // f is the visit function to be called when a global symbol is reached.
+ f func(*ast.Ident, SymKind)
+
+ // scope is the current scope as nodes are visited.
+ scope *scope
+}
+
+// unexpected is called when an unexpected node appears in the AST. It dumps
+// the location of the associated token and panics because this should only
+// happen when there is a bug in the traversal code.
+func (v *globalsVisitor) unexpected(p token.Pos) {
+ panic(fmt.Sprintf("Unable to parse at %v", v.fset.Position(p)))
+}
+
+// pushScope creates a new scope and pushes it to the top of the scope stack.
+func (v *globalsVisitor) pushScope() {
+ v.scope = newScope(v.scope)
+}
+
+// popScope removes the scope created by the last call to pushScope.
+func (v *globalsVisitor) popScope() {
+ v.scope = v.scope.outer
+}
+
+// visitType is called when an expression is known to be a type, for example,
+// on the first argument of make(). It visits all children nodes and reports
+// any globals.
+func (v *globalsVisitor) visitType(ge ast.Expr) {
+ switch e := ge.(type) {
+ case *ast.Ident:
+ if s := v.scope.deepLookup(e.Name); s != nil && s.scope.isGlobal() {
+ v.f(e, s.kind)
+ }
+
+ case *ast.SelectorExpr:
+ id := GetIdent(e.X)
+ if id == nil {
+ v.unexpected(e.X.Pos())
+ }
+
+ case *ast.StarExpr:
+ v.visitType(e.X)
+ case *ast.ParenExpr:
+ v.visitType(e.X)
+ case *ast.ChanType:
+ v.visitType(e.Value)
+ case *ast.Ellipsis:
+ v.visitType(e.Elt)
+ case *ast.ArrayType:
+ v.visitExpr(e.Len)
+ v.visitType(e.Elt)
+ case *ast.MapType:
+ v.visitType(e.Key)
+ v.visitType(e.Value)
+ case *ast.StructType:
+ v.visitFields(e.Fields, KindUnknown)
+ case *ast.FuncType:
+ v.visitFields(e.Params, KindUnknown)
+ v.visitFields(e.Results, KindUnknown)
+ case *ast.InterfaceType:
+ v.visitFields(e.Methods, KindUnknown)
+ default:
+ v.unexpected(ge.Pos())
+ }
+}
+
+// visitFields visits all fields, and add symbols if kind isn't KindUnknown.
+func (v *globalsVisitor) visitFields(l *ast.FieldList, kind SymKind) {
+ if l == nil {
+ return
+ }
+
+ for _, f := range l.List {
+ if kind != KindUnknown {
+ for _, n := range f.Names {
+ v.scope.add(n.Name, kind, n.Pos())
+ }
+ }
+ v.visitType(f.Type)
+ if f.Tag != nil {
+ tag := ast.NewIdent(f.Tag.Value)
+ v.f(tag, KindTag)
+ // Replace the tag if updated.
+ if tag.Name != f.Tag.Value {
+ f.Tag.Value = tag.Name
+ }
+ }
+ }
+}
+
+// visitGenDecl is called when a generic declation is encountered, for example,
+// on variable, constant and type declarations. It adds all newly defined
+// symbols to the current scope and reports them if the current scope is the
+// global one.
+func (v *globalsVisitor) visitGenDecl(d *ast.GenDecl) {
+ switch d.Tok {
+ case token.IMPORT:
+ case token.TYPE:
+ for _, gs := range d.Specs {
+ s := gs.(*ast.TypeSpec)
+ v.scope.add(s.Name.Name, KindType, s.Name.Pos())
+ if v.scope.isGlobal() {
+ v.f(s.Name, KindType)
+ }
+ v.visitType(s.Type)
+ }
+ case token.CONST, token.VAR:
+ kind := KindConst
+ if d.Tok == token.VAR {
+ kind = KindVar
+ }
+
+ for _, gs := range d.Specs {
+ s := gs.(*ast.ValueSpec)
+ if s.Type != nil {
+ v.visitType(s.Type)
+ }
+
+ for _, e := range s.Values {
+ v.visitExpr(e)
+ }
+
+ for _, n := range s.Names {
+ if v.scope.isGlobal() {
+ v.f(n, kind)
+ }
+ v.scope.add(n.Name, kind, n.Pos())
+ }
+ }
+ default:
+ v.unexpected(d.Pos())
+ }
+}
+
+// isViableType determines if the given expression is a viable type expression,
+// that is, if it could be interpreted as a type, for example, sync.Mutex,
+// myType, func(int)int, as opposed to -1, 2 * 2, a + b, etc.
+func (v *globalsVisitor) isViableType(expr ast.Expr) bool {
+ switch e := expr.(type) {
+ case *ast.Ident:
+ // This covers the plain identifier case. When we see it, we
+ // have to check if it resolves to a type; if the symbol is not
+ // known, we'll claim it's viable as a type.
+ s := v.scope.deepLookup(e.Name)
+ return s == nil || s.kind == KindType
+
+ case *ast.ChanType, *ast.ArrayType, *ast.MapType, *ast.StructType, *ast.FuncType, *ast.InterfaceType, *ast.Ellipsis:
+ // This covers the following cases:
+ // 1. ChanType:
+ // chan T
+ // <-chan T
+ // chan<- T
+ // 2. ArrayType:
+ // [Expr]T
+ // 3. MapType:
+ // map[T]U
+ // 4. StructType:
+ // struct { Fields }
+ // 5. FuncType:
+ // func(Fields)Returns
+ // 6. Interface:
+ // interface { Fields }
+ // 7. Ellipsis:
+ // ...T
+ return true
+
+ case *ast.SelectorExpr:
+ // The only case in which an expression involving a selector can
+ // be a type is if it has the following form X.T, where X is an
+ // import, and T is a type exported by X.
+ //
+ // There's no way to know whether T is a type because we don't
+ // parse imports. So we just claim that this is a viable type;
+ // it doesn't affect the general result because we don't visit
+ // imported symbols.
+ id := GetIdent(e.X)
+ if id == nil {
+ return false
+ }
+
+ s := v.scope.deepLookup(id.Name)
+ return s != nil && s.kind == KindImport
+
+ case *ast.StarExpr:
+ // This covers the *T case. The expression is a viable type if
+ // T is.
+ return v.isViableType(e.X)
+
+ case *ast.ParenExpr:
+ // This covers the (T) case. The expression is a viable type if
+ // T is.
+ return v.isViableType(e.X)
+
+ default:
+ return false
+ }
+}
+
+// visitCallExpr visits a "call expression" which can be either a
+// function/method call (e.g., f(), pkg.f(), obj.f(), etc.) call or a type
+// conversion (e.g., int32(1), (*sync.Mutex)(ptr), etc.).
+func (v *globalsVisitor) visitCallExpr(e *ast.CallExpr) {
+ if v.isViableType(e.Fun) {
+ v.visitType(e.Fun)
+ } else {
+ v.visitExpr(e.Fun)
+ }
+
+ // If the function being called is new or make, the first argument is
+ // a type, so it needs to be visited as such.
+ first := 0
+ if id := GetIdent(e.Fun); id != nil && (id.Name == "make" || id.Name == "new") {
+ if len(e.Args) > 0 {
+ v.visitType(e.Args[0])
+ }
+ first = 1
+ }
+
+ for i := first; i < len(e.Args); i++ {
+ v.visitExpr(e.Args[i])
+ }
+}
+
+// visitExpr visits all nodes of an expression, and reports any globals that it
+// finds.
+func (v *globalsVisitor) visitExpr(ge ast.Expr) {
+ switch e := ge.(type) {
+ case nil:
+ case *ast.Ident:
+ if s := v.scope.deepLookup(e.Name); s != nil && s.scope.isGlobal() {
+ v.f(e, s.kind)
+ }
+
+ case *ast.BasicLit:
+ case *ast.CompositeLit:
+ v.visitType(e.Type)
+ for _, ne := range e.Elts {
+ v.visitExpr(ne)
+ }
+ case *ast.FuncLit:
+ v.pushScope()
+ v.visitFields(e.Type.Params, KindParameter)
+ v.visitFields(e.Type.Results, KindResult)
+ v.visitBlockStmt(e.Body)
+ v.popScope()
+
+ case *ast.BinaryExpr:
+ v.visitExpr(e.X)
+ v.visitExpr(e.Y)
+
+ case *ast.CallExpr:
+ v.visitCallExpr(e)
+
+ case *ast.IndexExpr:
+ v.visitExpr(e.X)
+ v.visitExpr(e.Index)
+
+ case *ast.KeyValueExpr:
+ v.visitExpr(e.Value)
+
+ case *ast.ParenExpr:
+ v.visitExpr(e.X)
+
+ case *ast.SelectorExpr:
+ v.visitExpr(e.X)
+
+ case *ast.SliceExpr:
+ v.visitExpr(e.X)
+ v.visitExpr(e.Low)
+ v.visitExpr(e.High)
+ v.visitExpr(e.Max)
+
+ case *ast.StarExpr:
+ v.visitExpr(e.X)
+
+ case *ast.TypeAssertExpr:
+ v.visitExpr(e.X)
+ if e.Type != nil {
+ v.visitType(e.Type)
+ }
+
+ case *ast.UnaryExpr:
+ v.visitExpr(e.X)
+
+ default:
+ v.unexpected(ge.Pos())
+ }
+}
+
+// GetIdent returns the identifier associated with the given expression by
+// removing parentheses if needed.
+func GetIdent(expr ast.Expr) *ast.Ident {
+ switch e := expr.(type) {
+ case *ast.Ident:
+ return e
+ case *ast.ParenExpr:
+ return GetIdent(e.X)
+ default:
+ return nil
+ }
+}
+
+// visitStmt visits all nodes of a statement, and reports any globals that it
+// finds. It also adds to the current scope new symbols defined/declared.
+func (v *globalsVisitor) visitStmt(gs ast.Stmt) {
+ switch s := gs.(type) {
+ case nil, *ast.BranchStmt, *ast.EmptyStmt:
+ case *ast.AssignStmt:
+ for _, e := range s.Rhs {
+ v.visitExpr(e)
+ }
+
+ // We visit the LHS after the RHS because the symbols we'll
+ // potentially add to the table aren't meant to be visible to
+ // the RHS.
+ for _, e := range s.Lhs {
+ if s.Tok == token.DEFINE {
+ if n := GetIdent(e); n != nil {
+ v.scope.add(n.Name, KindVar, n.Pos())
+ }
+ }
+ v.visitExpr(e)
+ }
+
+ case *ast.BlockStmt:
+ v.visitBlockStmt(s)
+
+ case *ast.DeclStmt:
+ v.visitGenDecl(s.Decl.(*ast.GenDecl))
+
+ case *ast.DeferStmt:
+ v.visitCallExpr(s.Call)
+
+ case *ast.ExprStmt:
+ v.visitExpr(s.X)
+
+ case *ast.ForStmt:
+ v.pushScope()
+ v.visitStmt(s.Init)
+ v.visitExpr(s.Cond)
+ v.visitStmt(s.Post)
+ v.visitBlockStmt(s.Body)
+ v.popScope()
+
+ case *ast.GoStmt:
+ v.visitCallExpr(s.Call)
+
+ case *ast.IfStmt:
+ v.pushScope()
+ v.visitStmt(s.Init)
+ v.visitExpr(s.Cond)
+ v.visitBlockStmt(s.Body)
+ v.visitStmt(s.Else)
+ v.popScope()
+
+ case *ast.IncDecStmt:
+ v.visitExpr(s.X)
+
+ case *ast.LabeledStmt:
+ v.visitStmt(s.Stmt)
+
+ case *ast.RangeStmt:
+ v.pushScope()
+ v.visitExpr(s.X)
+ if s.Tok == token.DEFINE {
+ if n := GetIdent(s.Key); n != nil {
+ v.scope.add(n.Name, KindVar, n.Pos())
+ }
+
+ if n := GetIdent(s.Value); n != nil {
+ v.scope.add(n.Name, KindVar, n.Pos())
+ }
+ }
+ v.visitExpr(s.Key)
+ v.visitExpr(s.Value)
+ v.visitBlockStmt(s.Body)
+ v.popScope()
+
+ case *ast.ReturnStmt:
+ for _, r := range s.Results {
+ v.visitExpr(r)
+ }
+
+ case *ast.SelectStmt:
+ for _, ns := range s.Body.List {
+ c := ns.(*ast.CommClause)
+
+ v.pushScope()
+ v.visitStmt(c.Comm)
+ for _, bs := range c.Body {
+ v.visitStmt(bs)
+ }
+ v.popScope()
+ }
+
+ case *ast.SendStmt:
+ v.visitExpr(s.Chan)
+ v.visitExpr(s.Value)
+
+ case *ast.SwitchStmt:
+ v.pushScope()
+ v.visitStmt(s.Init)
+ v.visitExpr(s.Tag)
+ for _, ns := range s.Body.List {
+ c := ns.(*ast.CaseClause)
+ v.pushScope()
+ for _, ce := range c.List {
+ v.visitExpr(ce)
+ }
+ for _, bs := range c.Body {
+ v.visitStmt(bs)
+ }
+ v.popScope()
+ }
+ v.popScope()
+
+ case *ast.TypeSwitchStmt:
+ v.pushScope()
+ v.visitStmt(s.Init)
+ v.visitStmt(s.Assign)
+ for _, ns := range s.Body.List {
+ c := ns.(*ast.CaseClause)
+ v.pushScope()
+ for _, ce := range c.List {
+ v.visitType(ce)
+ }
+ for _, bs := range c.Body {
+ v.visitStmt(bs)
+ }
+ v.popScope()
+ }
+ v.popScope()
+
+ default:
+ v.unexpected(gs.Pos())
+ }
+}
+
+// visitBlockStmt visits all statements in the block, adding symbols to a newly
+// created scope.
+func (v *globalsVisitor) visitBlockStmt(s *ast.BlockStmt) {
+ v.pushScope()
+ for _, c := range s.List {
+ v.visitStmt(c)
+ }
+ v.popScope()
+}
+
+// visitFuncDecl is called when a function or method declation is encountered.
+// it creates a new scope for the function [optional] receiver, parameters and
+// results, and visits all children nodes.
+func (v *globalsVisitor) visitFuncDecl(d *ast.FuncDecl) {
+ // We don't report methods.
+ if d.Recv == nil {
+ v.f(d.Name, KindFunction)
+ }
+
+ v.pushScope()
+ v.visitFields(d.Recv, KindReceiver)
+ v.visitFields(d.Type.Params, KindParameter)
+ v.visitFields(d.Type.Results, KindResult)
+ if d.Body != nil {
+ v.visitBlockStmt(d.Body)
+ }
+ v.popScope()
+}
+
+// globalsFromDecl is called in the first, and adds symbols to global scope.
+func (v *globalsVisitor) globalsFromGenDecl(d *ast.GenDecl) {
+ switch d.Tok {
+ case token.IMPORT:
+ for _, gs := range d.Specs {
+ s := gs.(*ast.ImportSpec)
+ if s.Name == nil {
+ str, _ := strconv.Unquote(s.Path.Value)
+ v.scope.add(filepath.Base(str), KindImport, s.Path.Pos())
+ } else if s.Name.Name != "_" {
+ v.scope.add(s.Name.Name, KindImport, s.Name.Pos())
+ }
+ }
+ case token.TYPE:
+ for _, gs := range d.Specs {
+ s := gs.(*ast.TypeSpec)
+ v.scope.add(s.Name.Name, KindType, s.Name.Pos())
+ }
+ case token.CONST, token.VAR:
+ kind := KindConst
+ if d.Tok == token.VAR {
+ kind = KindVar
+ }
+
+ for _, s := range d.Specs {
+ for _, n := range s.(*ast.ValueSpec).Names {
+ v.scope.add(n.Name, kind, n.Pos())
+ }
+ }
+ default:
+ v.unexpected(d.Pos())
+ }
+}
+
+// visit implements the visiting of globals. It does performs the two passes
+// described in the description of the globalsVisitor struct.
+func (v *globalsVisitor) visit() {
+ // Gather all symbols in the global scope. This excludes methods.
+ v.pushScope()
+ for _, gd := range v.file.Decls {
+ switch d := gd.(type) {
+ case *ast.GenDecl:
+ v.globalsFromGenDecl(d)
+ case *ast.FuncDecl:
+ if d.Recv == nil {
+ v.scope.add(d.Name.Name, KindFunction, d.Name.Pos())
+ }
+ default:
+ v.unexpected(gd.Pos())
+ }
+ }
+
+ // Go through the contents of the declarations.
+ for _, gd := range v.file.Decls {
+ switch d := gd.(type) {
+ case *ast.GenDecl:
+ v.visitGenDecl(d)
+ case *ast.FuncDecl:
+ v.visitFuncDecl(d)
+ }
+ }
+}
+
+// Visit traverses the provided AST and calls f() for each identifier that
+// refers to global names. The global name must be defined in the file itself.
+//
+// The function f() is allowed to modify the identifier, for example, to rename
+// uses of global references.
+func Visit(fset *token.FileSet, file *ast.File, f func(*ast.Ident, SymKind)) {
+ v := globalsVisitor{
+ fset: fset,
+ file: file,
+ f: f,
+ }
+
+ v.visit()
+}
diff --git a/tools/go_generics/globals/scope.go b/tools/go_generics/globals/scope.go
new file mode 100644
index 000000000..18743bdee
--- /dev/null
+++ b/tools/go_generics/globals/scope.go
@@ -0,0 +1,80 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package globals
+
+import (
+ "go/token"
+)
+
+// SymKind specifies the kind of a global symbol. For example, a variable, const
+// function, etc.
+type SymKind int
+
+// Constants for different kinds of symbols.
+const (
+ KindUnknown SymKind = iota
+ KindImport
+ KindType
+ KindVar
+ KindConst
+ KindFunction
+ KindReceiver
+ KindParameter
+ KindResult
+ KindTag
+)
+
+type symbol struct {
+ kind SymKind
+ pos token.Pos
+ scope *scope
+}
+
+type scope struct {
+ outer *scope
+ syms map[string]*symbol
+}
+
+func newScope(outer *scope) *scope {
+ return &scope{
+ outer: outer,
+ syms: make(map[string]*symbol),
+ }
+}
+
+func (s *scope) isGlobal() bool {
+ return s.outer == nil
+}
+
+func (s *scope) lookup(n string) *symbol {
+ return s.syms[n]
+}
+
+func (s *scope) deepLookup(n string) *symbol {
+ for x := s; x != nil; x = x.outer {
+ if sym := x.lookup(n); sym != nil {
+ return sym
+ }
+ }
+ return nil
+}
+
+func (s *scope) add(name string, kind SymKind, pos token.Pos) {
+ s.syms[name] = &symbol{
+ kind: kind,
+ pos: pos,
+ scope: s,
+ }
+}
diff --git a/tools/go_generics/go_generics_unittest.sh b/tools/go_generics/go_generics_unittest.sh
new file mode 100755
index 000000000..699e1f631
--- /dev/null
+++ b/tools/go_generics/go_generics_unittest.sh
@@ -0,0 +1,70 @@
+#!/bin/bash
+
+# Copyright 2018 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Bash "safe-mode": Treat command failures as fatal (even those that occur in
+# pipes), and treat unset variables as errors.
+set -eu -o pipefail
+
+# This file will be generated as a self-extracting shell script in order to
+# eliminate the need for any runtime dependencies. The tarball at the end will
+# include the go_generics binary, as well as a subdirectory named
+# generics_tests. See the BUILD file for more information.
+declare -r temp=$(mktemp -d)
+function cleanup() {
+ rm -rf "${temp}"
+}
+# trap cleanup EXIT
+
+# Print message in "$1" then exit with status 1.
+function die () {
+ echo "$1" 1>&2
+ exit 1
+}
+
+# This prints the line number of __BUNDLE__ below, that should be the last line
+# of this script. After that point, the concatenated archive will be the
+# contents.
+declare -r tgz=`awk '/^__BUNDLE__/ {print NR + 1; exit 0; }' $0`
+tail -n+"${tgz}" $0 | tar -xzv -C "${temp}"
+
+# The target for the test.
+declare -r binary="$(find ${temp} -type f -a -name go_generics)"
+declare -r input_dirs="$(find ${temp} -type d -a -name generics_tests)/*"
+
+# Go through all test cases.
+for f in ${input_dirs}; do
+ base=$(basename "${f}")
+
+ # Run go_generics on the input file.
+ opts=$(head -n 1 ${f}/opts.txt)
+ out="${f}/output/generated.go"
+ expected="${f}/output/output.go"
+ ${binary} ${opts} "-i=${f}/input.go" "-o=${out}" || die "go_generics failed for test case \"${base}\""
+
+ # Compare the outputs.
+ diff ${expected} ${out}
+ if [ $? -ne 0 ]; then
+ echo "Expected:"
+ cat ${expected}
+ echo "Actual:"
+ cat ${out}
+ die "Actual output is different from expected for test \"${base}\""
+ fi
+done
+
+echo "PASS"
+exit 0
+__BUNDLE__
diff --git a/tools/go_generics/imports.go b/tools/go_generics/imports.go
new file mode 100644
index 000000000..97267098b
--- /dev/null
+++ b/tools/go_generics/imports.go
@@ -0,0 +1,150 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+import (
+ "bytes"
+ "fmt"
+ "go/ast"
+ "go/format"
+ "go/parser"
+ "go/token"
+ "strconv"
+
+ "gvisor.googlesource.com/gvisor/tools/go_generics/globals"
+)
+
+type importedPackage struct {
+ newName string
+ path string
+}
+
+// updateImportIdent modifies the given import identifier with the new name
+// stored in the used map. If the identifier doesn't exist in the used map yet,
+// a new name is generated and inserted into the map.
+func updateImportIdent(orig string, imports mapValue, id *ast.Ident, used map[string]*importedPackage) error {
+ importName := id.Name
+
+ // If the name is already in the table, just use the new name.
+ m := used[importName]
+ if m != nil {
+ id.Name = m.newName
+ return nil
+ }
+
+ // Create a new entry in the used map.
+ path := imports[importName]
+ if path == "" {
+ return fmt.Errorf("Unknown path to package '%s', used in '%s'", importName, orig)
+ }
+
+ m = &importedPackage{
+ newName: fmt.Sprintf("__generics_imported%d", len(used)),
+ path: strconv.Quote(path),
+ }
+ used[importName] = m
+
+ id.Name = m.newName
+
+ return nil
+}
+
+// convertExpression creates a new string that is a copy of the input one with
+// all imports references renamed to the names in the "used" map. If the
+// referenced import isn't in "used" yet, a new one is created based on the path
+// in "imports" and stored in "used". For example, if string s is
+// "math.MaxUint32-math.MaxUint16+10", it would be converted to
+// "x.MaxUint32-x.MathUint16+10", where x is a generated name.
+func convertExpression(s string, imports mapValue, used map[string]*importedPackage) (string, error) {
+ // Parse the expression in the input string.
+ expr, err := parser.ParseExpr(s)
+ if err != nil {
+ return "", fmt.Errorf("Unable to parse \"%s\": %v", s, err)
+ }
+
+ // Go through the AST and update references.
+ var retErr error
+ ast.Inspect(expr, func(n ast.Node) bool {
+ switch x := n.(type) {
+ case *ast.SelectorExpr:
+ if id := globals.GetIdent(x.X); id != nil {
+ if err := updateImportIdent(s, imports, id, used); err != nil {
+ retErr = err
+ }
+ return false
+ }
+ }
+ return true
+ })
+ if retErr != nil {
+ return "", retErr
+ }
+
+ // Convert the modified AST back to a string.
+ fset := token.NewFileSet()
+ var buf bytes.Buffer
+ if err := format.Node(&buf, fset, expr); err != nil {
+ return "", err
+ }
+
+ return string(buf.Bytes()), nil
+}
+
+// updateImports replaces all maps in the input slice with copies where the
+// mapped values have had all references to imported packages renamed to
+// generated names. It also returns an import declaration for all the renamed
+// import packages.
+//
+// For example, if the input maps contains A=math.B and C=math.D, the updated
+// maps will instead contain A=__generics_imported0.B and
+// C=__generics_imported0.C, and the 'import __generics_imported0 "math"' would
+// be returned as the import declaration.
+func updateImports(maps []mapValue, imports mapValue) (ast.Decl, error) {
+ importsUsed := make(map[string]*importedPackage)
+
+ // Update all maps.
+ for i, m := range maps {
+ newMap := make(mapValue)
+ for n, e := range m {
+ updated, err := convertExpression(e, imports, importsUsed)
+ if err != nil {
+ return nil, err
+ }
+
+ newMap[n] = updated
+ }
+ maps[i] = newMap
+ }
+
+ // Nothing else to do if no imports are used in the expressions.
+ if len(importsUsed) == 0 {
+ return nil, nil
+ }
+
+ // Create spec array for each new import.
+ specs := make([]ast.Spec, 0, len(importsUsed))
+ for _, i := range importsUsed {
+ specs = append(specs, &ast.ImportSpec{
+ Name: &ast.Ident{Name: i.newName},
+ Path: &ast.BasicLit{Value: i.path},
+ })
+ }
+
+ return &ast.GenDecl{
+ Tok: token.IMPORT,
+ Specs: specs,
+ Lparen: token.NoPos + 1,
+ }, nil
+}
diff --git a/tools/go_generics/merge.go b/tools/go_generics/merge.go
new file mode 100644
index 000000000..ebe7cf4e4
--- /dev/null
+++ b/tools/go_generics/merge.go
@@ -0,0 +1,139 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+import (
+ "bytes"
+ "flag"
+ "fmt"
+ "go/ast"
+ "go/format"
+ "go/parser"
+ "go/token"
+ "io/ioutil"
+ "os"
+ "path/filepath"
+ "strconv"
+)
+
+var (
+ output = flag.String("o", "", "output `file`")
+)
+
+func fatalf(s string, args ...interface{}) {
+ fmt.Fprintf(os.Stderr, s, args...)
+ os.Exit(1)
+}
+
+func main() {
+ flag.Usage = func() {
+ fmt.Fprintf(os.Stderr, "Usage: %s [options] <input1> [<input2> ...]\n", os.Args[0])
+ flag.PrintDefaults()
+ }
+
+ flag.Parse()
+ if *output == "" || len(flag.Args()) == 0 {
+ flag.Usage()
+ os.Exit(1)
+ }
+
+ // Load all files.
+ files := make(map[string]*ast.File)
+ fset := token.NewFileSet()
+ var name string
+ for _, fname := range flag.Args() {
+ f, err := parser.ParseFile(fset, fname, nil, parser.ParseComments|parser.DeclarationErrors|parser.SpuriousErrors)
+ if err != nil {
+ fatalf("%v\n", err)
+ }
+
+ files[fname] = f
+ if name == "" {
+ name = f.Name.Name
+ } else if name != f.Name.Name {
+ fatalf("Expected '%s' for package name instead of '%s'.\n", name, f.Name.Name)
+ }
+ }
+
+ // Merge all files into one.
+ pkg := &ast.Package{
+ Name: name,
+ Files: files,
+ }
+ f := ast.MergePackageFiles(pkg, ast.FilterUnassociatedComments|ast.FilterFuncDuplicates|ast.FilterImportDuplicates)
+
+ // Create a new declaration slice with all imports at the top, merging any
+ // redundant imports.
+ imports := make(map[string]*ast.ImportSpec)
+ var anonImports []*ast.ImportSpec
+ for _, d := range f.Decls {
+ if g, ok := d.(*ast.GenDecl); ok && g.Tok == token.IMPORT {
+ for _, s := range g.Specs {
+ i := s.(*ast.ImportSpec)
+ p, _ := strconv.Unquote(i.Path.Value)
+ var n string
+ if i.Name == nil {
+ n = filepath.Base(p)
+ } else {
+ n = i.Name.Name
+ }
+ if n == "_" {
+ anonImports = append(anonImports, i)
+ } else {
+ if i2, ok := imports[n]; ok {
+ if first, second := i.Path.Value, i2.Path.Value; first != second {
+ fatalf("Conflicting paths for import name '%s': '%s' vs. '%s'\n", n, first, second)
+ }
+ } else {
+ imports[n] = i
+ }
+ }
+ }
+ }
+ }
+ newDecls := make([]ast.Decl, 0, len(f.Decls))
+ if l := len(imports) + len(anonImports); l > 0 {
+ // Non-NoPos Lparen is needed for Go to recognize more than one spec in
+ // ast.GenDecl.Specs.
+ d := &ast.GenDecl{
+ Tok: token.IMPORT,
+ Lparen: token.NoPos + 1,
+ Specs: make([]ast.Spec, 0, l),
+ }
+ for _, i := range imports {
+ d.Specs = append(d.Specs, i)
+ }
+ for _, i := range anonImports {
+ d.Specs = append(d.Specs, i)
+ }
+ newDecls = append(newDecls, d)
+ }
+ for _, d := range f.Decls {
+ if g, ok := d.(*ast.GenDecl); !ok || g.Tok != token.IMPORT {
+ newDecls = append(newDecls, d)
+ }
+ }
+ f.Decls = newDecls
+
+ // Write the output file.
+ var buf bytes.Buffer
+ if err := format.Node(&buf, fset, f); err != nil {
+ fatalf("%v\n", err)
+ }
+
+ if err := ioutil.WriteFile(*output, buf.Bytes(), 0644); err != nil {
+ fatalf("%v\n", err)
+ }
+}
diff --git a/tools/go_generics/remove.go b/tools/go_generics/remove.go
new file mode 100644
index 000000000..2a66de762
--- /dev/null
+++ b/tools/go_generics/remove.go
@@ -0,0 +1,105 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+import (
+ "go/ast"
+ "go/token"
+)
+
+type typeSet map[string]struct{}
+
+// isTypeOrPointerToType determines if the given AST expression represents a
+// type or a pointer to a type that exists in the provided type set.
+func isTypeOrPointerToType(set typeSet, expr ast.Expr, starCount int) bool {
+ switch e := expr.(type) {
+ case *ast.Ident:
+ _, ok := set[e.Name]
+ return ok
+ case *ast.StarExpr:
+ if starCount > 1 {
+ return false
+ }
+ return isTypeOrPointerToType(set, e.X, starCount+1)
+ case *ast.ParenExpr:
+ return isTypeOrPointerToType(set, e.X, starCount)
+ default:
+ return false
+ }
+}
+
+// isMethodOf determines if the given function declaration is a method of one
+// of the types in the provided type set. To do that, it checks if the function
+// has a receiver and that its type is either T or *T, where T is a type that
+// exists in the set. This is per the spec:
+//
+// That parameter section must declare a single parameter, the receiver. Its
+// type must be of the form T or *T (possibly using parentheses) where T is a
+// type name. The type denoted by T is called the receiver base type; it must
+// not be a pointer or interface type and it must be declared in the same
+// package as the method.
+func isMethodOf(set typeSet, f *ast.FuncDecl) bool {
+ // If the function doesn't have exactly one receiver, then it's
+ // definitely not a method.
+ if f.Recv == nil || len(f.Recv.List) != 1 {
+ return false
+ }
+
+ return isTypeOrPointerToType(set, f.Recv.List[0].Type, 0)
+}
+
+// removeTypeDefinitions removes the definition of all types contained in the
+// provided type set.
+func removeTypeDefinitions(set typeSet, d *ast.GenDecl) {
+ if d.Tok != token.TYPE {
+ return
+ }
+
+ i := 0
+ for _, gs := range d.Specs {
+ s := gs.(*ast.TypeSpec)
+ if _, ok := set[s.Name.Name]; !ok {
+ d.Specs[i] = gs
+ i++
+ }
+ }
+
+ d.Specs = d.Specs[:i]
+}
+
+// removeTypes removes from the AST the definition of all types and their
+// method sets that are contained in the provided type set.
+func removeTypes(set typeSet, f *ast.File) {
+ // Go through the top-level declarations.
+ i := 0
+ for _, decl := range f.Decls {
+ keep := true
+ switch d := decl.(type) {
+ case *ast.GenDecl:
+ countBefore := len(d.Specs)
+ removeTypeDefinitions(set, d)
+ keep = countBefore == 0 || len(d.Specs) > 0
+ case *ast.FuncDecl:
+ keep = !isMethodOf(set, d)
+ }
+
+ if keep {
+ f.Decls[i] = decl
+ i++
+ }
+ }
+
+ f.Decls = f.Decls[:i]
+}
diff --git a/tools/go_generics/rules_tests/BUILD b/tools/go_generics/rules_tests/BUILD
new file mode 100644
index 000000000..2d9a6fa9d
--- /dev/null
+++ b/tools/go_generics/rules_tests/BUILD
@@ -0,0 +1,43 @@
+package(licenses = ["notice"]) # Apache 2.0
+
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance")
+
+go_template_instance(
+ name = "instance",
+ out = "instance_test.go",
+ consts = {
+ "n": "20",
+ "m": "\"test\"",
+ "o": "math.MaxUint64",
+ },
+ imports = {
+ "math": "math",
+ },
+ package = "template_test",
+ template = ":test_template",
+ types = {
+ "t": "int",
+ },
+)
+
+go_template(
+ name = "test_template",
+ srcs = [
+ "template.go",
+ ],
+ opt_consts = [
+ "n",
+ "m",
+ "o",
+ ],
+ opt_types = ["t"],
+)
+
+go_test(
+ name = "template_test",
+ srcs = [
+ "instance_test.go",
+ "template_test.go",
+ ],
+)
diff --git a/tools/go_generics/rules_tests/template.go b/tools/go_generics/rules_tests/template.go
new file mode 100644
index 000000000..73c024f0e
--- /dev/null
+++ b/tools/go_generics/rules_tests/template.go
@@ -0,0 +1,42 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package template
+
+type t float
+
+const (
+ n t = 10.1
+ m = "abc"
+ o = 0
+)
+
+func max(a, b t) t {
+ if a > b {
+ return a
+ }
+ return b
+}
+
+func add(a t) t {
+ return a + n
+}
+
+func getName() string {
+ return m
+}
+
+func getMax() uint64 {
+ return o
+}
diff --git a/tools/go_generics/rules_tests/template_test.go b/tools/go_generics/rules_tests/template_test.go
new file mode 100644
index 000000000..76c4cdb64
--- /dev/null
+++ b/tools/go_generics/rules_tests/template_test.go
@@ -0,0 +1,48 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package template_test
+
+import (
+ "math"
+ "testing"
+)
+
+func TestMax(t *testing.T) {
+ var a int = max(10, 20)
+ if a != 20 {
+ t.Errorf("Bad result of max, got %v, want %v", a, 20)
+ }
+}
+
+func TestIntConst(t *testing.T) {
+ var a int = add(10)
+ if a != 30 {
+ t.Errorf("Bad result of add, got %v, want %v", a, 30)
+ }
+}
+
+func TestStrConst(t *testing.T) {
+ v := getName()
+ if v != "test" {
+ t.Errorf("Bad name, got %v, want %v", v, "test")
+ }
+}
+
+func TestImport(t *testing.T) {
+ v := getMax()
+ if v != math.MaxUint64 {
+ t.Errorf("Bad max value, got %v, want %v", v, uint64(math.MaxUint64))
+ }
+}
diff --git a/tools/go_stateify/BUILD b/tools/go_stateify/BUILD
new file mode 100644
index 000000000..edbeb4e2d
--- /dev/null
+++ b/tools/go_stateify/BUILD
@@ -0,0 +1,9 @@
+package(licenses = ["notice"]) # Apache 2.0
+
+load("@io_bazel_rules_go//go:def.bzl", "go_binary")
+
+go_binary(
+ name = "stateify",
+ srcs = ["main.go"],
+ visibility = ["//visibility:public"],
+)
diff --git a/tools/go_stateify/defs.bzl b/tools/go_stateify/defs.bzl
new file mode 100644
index 000000000..87fdc0d28
--- /dev/null
+++ b/tools/go_stateify/defs.bzl
@@ -0,0 +1,77 @@
+"""Stateify is a tool for generating state wrappers for Go types.
+
+The go_stateify rule is used to generate a file that will appear in a Go
+target; the output file should appear explicitly in a srcs list. For example:
+
+go_stateify(
+ name = "foo_state",
+ srcs = ["foo.go"],
+ out = "foo_state.go",
+ package = "foo",
+)
+
+go_library(
+ name = "foo",
+ srcs = [
+ "foo.go",
+ "foo_state.go",
+ ],
+ deps = [
+ "//pkg/state",
+ ],
+)
+"""
+
+def _go_stateify_impl(ctx):
+ """Implementation for the stateify tool."""
+ output = ctx.outputs.out
+
+ # Run the stateify command.
+ args = ["-output=%s" % output.path]
+ args += ["-pkg=%s" % ctx.attr.package]
+ if ctx.attr._statepkg:
+ args += ["-statepkg=%s" % ctx.attr._statepkg]
+ if ctx.attr.imports:
+ args += ["-imports=%s" % ",".join(ctx.attr.imports)]
+ args += ["--"]
+ for src in ctx.attr.srcs:
+ args += [f.path for f in src.files]
+ ctx.actions.run(
+ inputs = ctx.files.srcs,
+ outputs = [output],
+ mnemonic = "GoStateify",
+ progress_message = "Generating state library %s" % ctx.label,
+ arguments = args,
+ executable = ctx.executable._tool,
+ )
+
+"""
+Generates save and restore logic from a set of Go files.
+
+
+Args:
+ name: the name of the rule.
+ srcs: the input source files. These files should include all structs in the package that need to be saved.
+ imports: an optional list of extra non-aliased, Go-style absolute import paths.
+ out: the name of the generated file output. This must not conflict with any other files and must be added to the srcs of the relevant go_library.
+ package: the package name for the input sources.
+"""
+
+go_stateify = rule(
+ attrs = {
+ "srcs": attr.label_list(
+ mandatory = True,
+ allow_files = True,
+ ),
+ "imports": attr.string_list(mandatory = False),
+ "package": attr.string(mandatory = True),
+ "out": attr.output(mandatory = True),
+ "_tool": attr.label(
+ executable = True,
+ cfg = "host",
+ default = Label("//tools/go_stateify:stateify"),
+ ),
+ "_statepkg": attr.string(default = "gvisor.googlesource.com/gvisor/pkg/state"),
+ },
+ implementation = _go_stateify_impl,
+)
diff --git a/tools/go_stateify/main.go b/tools/go_stateify/main.go
new file mode 100644
index 000000000..5eb4fe51f
--- /dev/null
+++ b/tools/go_stateify/main.go
@@ -0,0 +1,386 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Stateify provides a simple way to generate Load/Save methods based on
+// existing types and struct tags.
+package main
+
+import (
+ "flag"
+ "fmt"
+ "go/ast"
+ "go/parser"
+ "go/token"
+ "os"
+ "reflect"
+ "strings"
+)
+
+var (
+ pkg = flag.String("pkg", "", "output package")
+ imports = flag.String("imports", "", "extra imports for the output file")
+ output = flag.String("output", "", "output file")
+ statePkg = flag.String("statepkg", "", "state import package; defaults to empty")
+)
+
+// resolveTypeName returns a qualified type name.
+func resolveTypeName(name string, typ ast.Expr) (field string, qualified string) {
+ for done := false; !done; {
+ // Resolve star expressions.
+ switch rs := typ.(type) {
+ case *ast.StarExpr:
+ qualified += "*"
+ typ = rs.X
+ case *ast.ArrayType:
+ if rs.Len == nil {
+ // Slice type declaration.
+ qualified += "[]"
+ } else {
+ // Array type declaration.
+ qualified += "[" + rs.Len.(*ast.BasicLit).Value + "]"
+ }
+ typ = rs.Elt
+ default:
+ // No more descent.
+ done = true
+ }
+ }
+
+ // Resolve a package selector.
+ sel, ok := typ.(*ast.SelectorExpr)
+ if ok {
+ qualified = qualified + sel.X.(*ast.Ident).Name + "."
+ typ = sel.Sel
+ }
+
+ // Figure out actual type name.
+ ident, ok := typ.(*ast.Ident)
+ if !ok {
+ panic(fmt.Sprintf("type not supported: %s (involves anonymous types?)", name))
+ }
+ field = ident.Name
+ qualified = qualified + field
+ return
+}
+
+// extractStateTag pulls the relevant state tag.
+func extractStateTag(tag *ast.BasicLit) string {
+ if tag == nil {
+ return ""
+ }
+ if len(tag.Value) < 2 {
+ return ""
+ }
+ return reflect.StructTag(tag.Value[1 : len(tag.Value)-1]).Get("state")
+}
+
+// scanFunctions is a set of functions passed to scanFields.
+type scanFunctions struct {
+ zerovalue func(name string)
+ normal func(name string)
+ wait func(name string)
+ value func(name, typName string)
+}
+
+// scanFields scans the fields of a struct.
+//
+// Each provided function will be applied to appropriately tagged fields, or
+// skipped if nil.
+//
+// Fields tagged nosave are skipped.
+func scanFields(ss *ast.StructType, fn scanFunctions) {
+ if ss.Fields.List == nil {
+ // No fields.
+ return
+ }
+
+ // Scan all fields.
+ for _, field := range ss.Fields.List {
+ // Calculate the name.
+ name := ""
+ if field.Names != nil {
+ // It's a named field; override.
+ name = field.Names[0].Name
+ } else {
+ // Anonymous types can't be embedded, so we don't need
+ // to worry about providing a useful name here.
+ name, _ = resolveTypeName("", field.Type)
+ }
+
+ // Skip _ fields.
+ if name == "_" {
+ continue
+ }
+
+ switch tag := extractStateTag(field.Tag); tag {
+ case "zerovalue":
+ if fn.zerovalue != nil {
+ fn.zerovalue(name)
+ }
+
+ case "":
+ if fn.normal != nil {
+ fn.normal(name)
+ }
+
+ case "wait":
+ if fn.wait != nil {
+ fn.wait(name)
+ }
+
+ case "manual", "nosave", "ignore":
+ // Do nothing.
+
+ default:
+ if strings.HasPrefix(tag, ".(") && strings.HasSuffix(tag, ")") {
+ if fn.value != nil {
+ fn.value(name, tag[2:len(tag)-1])
+ }
+ }
+ }
+ }
+}
+
+func camelCased(name string) string {
+ return strings.ToUpper(name[:1]) + name[1:]
+}
+
+func main() {
+ // Parse flags.
+ flag.Usage = func() {
+ fmt.Fprintf(os.Stderr, "Usage: %s [options]\n", os.Args[0])
+ flag.PrintDefaults()
+ }
+ flag.Parse()
+ if len(flag.Args()) == 0 {
+ flag.Usage()
+ os.Exit(1)
+ }
+ if *pkg == "" {
+ fmt.Fprintf(os.Stderr, "Error: package required.")
+ os.Exit(1)
+ }
+
+ // Open the output file.
+ var (
+ outputFile *os.File
+ err error
+ )
+ if *output == "" || *output == "-" {
+ outputFile = os.Stdout
+ } else {
+ outputFile, err = os.OpenFile(*output, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
+ if err != nil {
+ fmt.Fprintf(os.Stderr, "Error opening output %q: %v", *output, err)
+ }
+ defer outputFile.Close()
+ }
+
+ // Set the statePrefix for below, depending on the import.
+ statePrefix := ""
+ if *statePkg != "" {
+ parts := strings.Split(*statePkg, "/")
+ statePrefix = parts[len(parts)-1] + "."
+ }
+
+ // initCalls is dumped at the end.
+ var initCalls []string
+
+ // Declare our emission closures.
+ emitRegister := func(name string) {
+ initCalls = append(initCalls, fmt.Sprintf("%sRegister(\"%s.%s\", (*%s)(nil), state.Fns{Save: (*%s).save, Load: (*%s).load})", statePrefix, *pkg, name, name, name, name))
+ }
+ emitZeroCheck := func(name string) {
+ fmt.Fprintf(outputFile, " if !%sIsZeroValue(x.%s) { m.Failf(\"%s is %%v, expected zero\", x.%s) }\n", statePrefix, name, name, name)
+ }
+ emitLoadValue := func(name, typName string) {
+ fmt.Fprintf(outputFile, " m.LoadValue(\"%s\", new(%s), func(y interface{}) { x.load%s(y.(%s)) })\n", name, typName, camelCased(name), typName)
+ }
+ emitLoad := func(name string) {
+ fmt.Fprintf(outputFile, " m.Load(\"%s\", &x.%s)\n", name, name)
+ }
+ emitLoadWait := func(name string) {
+ fmt.Fprintf(outputFile, " m.LoadWait(\"%s\", &x.%s)\n", name, name)
+ }
+ emitSaveValue := func(name, typName string) {
+ fmt.Fprintf(outputFile, " var %s %s = x.save%s()\n", name, typName, camelCased(name))
+ fmt.Fprintf(outputFile, " m.SaveValue(\"%s\", %s)\n", name, name)
+ }
+ emitSave := func(name string) {
+ fmt.Fprintf(outputFile, " m.Save(\"%s\", &x.%s)\n", name, name)
+ }
+
+ // Emit the package name.
+ fmt.Fprintf(outputFile, "// automatically generated by stateify.\n\n")
+ fmt.Fprintf(outputFile, "package %s\n\n", *pkg)
+ fmt.Fprintln(outputFile, "import (")
+ if *statePkg != "" {
+ fmt.Fprintf(outputFile, " \"%s\"\n", *statePkg)
+ }
+ if *imports != "" {
+ for _, i := range strings.Split(*imports, ",") {
+ fmt.Fprintf(outputFile, " \"%s\"\n", i)
+ }
+ }
+ fmt.Fprintln(outputFile, ")\n")
+
+ files := make([]*ast.File, 0, len(flag.Args()))
+
+ // Parse the input files.
+ for _, filename := range flag.Args() {
+ // Parse the file.
+ fset := token.NewFileSet()
+ f, err := parser.ParseFile(fset, filename, nil, 0)
+ if err != nil {
+ // Not a valid input file?
+ fmt.Fprintf(os.Stderr, "Input %q can't be parsed: %v\n", filename, err)
+ os.Exit(1)
+ }
+ files = append(files, f)
+ }
+
+ type method struct {
+ receiver string
+ name string
+ }
+
+ // Search for and add all methods with a pointer receiver and no other
+ // arguments to a set. We support auto-detecting the existence of
+ // several different methods with this signature.
+ simpleMethods := map[method]struct{}{}
+ for _, f := range files {
+
+ // Go over all functions.
+ for _, decl := range f.Decls {
+ d, ok := decl.(*ast.FuncDecl)
+ if !ok {
+ continue
+ }
+ if d.Name == nil || d.Recv == nil || d.Type == nil {
+ // Not a named method.
+ continue
+ }
+ if len(d.Recv.List) != 1 {
+ // Wrong number of receivers?
+ continue
+ }
+ if d.Type.Params != nil && len(d.Type.Params.List) != 0 {
+ // Has argument(s).
+ continue
+ }
+ if d.Type.Results != nil && len(d.Type.Results.List) != 0 {
+ // Has return(s).
+ continue
+ }
+
+ pt, ok := d.Recv.List[0].Type.(*ast.StarExpr)
+ if !ok {
+ // Not a pointer receiver.
+ continue
+ }
+
+ t, ok := pt.X.(*ast.Ident)
+ if !ok {
+ // This shouldn't happen with valid Go.
+ continue
+ }
+
+ simpleMethods[method{t.Name, d.Name.Name}] = struct{}{}
+ }
+ }
+
+ for _, f := range files {
+ // Go over all named types.
+ for _, decl := range f.Decls {
+ d, ok := decl.(*ast.GenDecl)
+ if !ok || d.Tok != token.TYPE {
+ continue
+ }
+
+ for _, gs := range d.Specs {
+ ts := gs.(*ast.TypeSpec)
+ switch ts.Type.(type) {
+ case *ast.InterfaceType, *ast.ChanType, *ast.FuncType, *ast.ParenExpr, *ast.StarExpr:
+ // Don't register.
+ break
+ case *ast.StructType:
+ ss := ts.Type.(*ast.StructType)
+
+ // Define beforeSave if a definition was not found. This
+ // prevents the code from compiling if a custom beforeSave
+ // was defined in a file not provided to this binary and
+ // prevents inherited methods from being called multiple times
+ // by overriding them.
+ if _, ok := simpleMethods[method{ts.Name.Name, "beforeSave"}]; !ok {
+ fmt.Fprintf(outputFile, "func (x *%s) beforeSave() {}\n", ts.Name.Name)
+ }
+
+ // Generate the save method.
+ fmt.Fprintf(outputFile, "func (x *%s) save(m %sMap) {\n", ts.Name.Name, statePrefix)
+ fmt.Fprintf(outputFile, " x.beforeSave()\n")
+ scanFields(ss, scanFunctions{zerovalue: emitZeroCheck})
+ scanFields(ss, scanFunctions{value: emitSaveValue})
+ scanFields(ss, scanFunctions{normal: emitSave, wait: emitSave})
+ fmt.Fprintf(outputFile, "}\n\n")
+
+ // Define afterLoad if a definition was not found. We do this
+ // for the same reason that we do it for beforeSave.
+ _, hasAfterLoad := simpleMethods[method{ts.Name.Name, "afterLoad"}]
+ if !hasAfterLoad {
+ fmt.Fprintf(outputFile, "func (x *%s) afterLoad() {}\n", ts.Name.Name)
+ }
+
+ // Generate the load method.
+ //
+ // Note that the manual loads always follow the
+ // automated loads.
+ fmt.Fprintf(outputFile, "func (x *%s) load(m %sMap) {\n", ts.Name.Name, statePrefix)
+ scanFields(ss, scanFunctions{normal: emitLoad, wait: emitLoadWait})
+ scanFields(ss, scanFunctions{value: emitLoadValue})
+ if hasAfterLoad {
+ // The call to afterLoad is made conditionally, because when
+ // AfterLoad is called, the object encodes a dependency on
+ // referred objects (i.e. fields). This means that afterLoad
+ // will not be called until the other afterLoads are called.
+ fmt.Fprintf(outputFile, " m.AfterLoad(x.afterLoad)\n")
+ }
+ fmt.Fprintf(outputFile, "}\n\n")
+
+ // Add to our registration.
+ emitRegister(ts.Name.Name)
+ case *ast.Ident, *ast.SelectorExpr, *ast.ArrayType:
+ _, val := resolveTypeName(ts.Name.Name, ts.Type)
+
+ // Dispatch directly.
+ fmt.Fprintf(outputFile, "func (x *%s) save(m %sMap) {\n", ts.Name.Name, statePrefix)
+ fmt.Fprintf(outputFile, " m.SaveValue(\"\", (%s)(*x))\n", val)
+ fmt.Fprintf(outputFile, "}\n\n")
+ fmt.Fprintf(outputFile, "func (x *%s) load(m %sMap) {\n", ts.Name.Name, statePrefix)
+ fmt.Fprintf(outputFile, " m.LoadValue(\"\", new(%s), func(y interface{}) { *x = (%s)(y.(%s)) })\n", val, ts.Name.Name, val)
+ fmt.Fprintf(outputFile, "}\n\n")
+
+ // See above.
+ emitRegister(ts.Name.Name)
+ }
+ }
+ }
+ }
+
+ // Emit the init() function.
+ fmt.Fprintf(outputFile, "func init() {\n")
+ for _, ic := range initCalls {
+ fmt.Fprintf(outputFile, " %s\n", ic)
+ }
+ fmt.Fprintf(outputFile, "}\n")
+}