diff options
Diffstat (limited to 'tools')
35 files changed, 3164 insertions, 0 deletions
diff --git a/tools/go_generics/BUILD b/tools/go_generics/BUILD new file mode 100644 index 000000000..1afc58625 --- /dev/null +++ b/tools/go_generics/BUILD @@ -0,0 +1,46 @@ +package(licenses = ["notice"]) # Apache 2.0 + +load("@io_bazel_rules_go//go:def.bzl", "go_binary") + +go_binary( + name = "go_generics", + srcs = [ + "generics.go", + "imports.go", + "remove.go", + ], + visibility = ["//visibility:public"], + deps = ["//tools/go_generics/globals"], +) + +go_binary( + name = "go_merge", + srcs = [ + "merge.go", + ], + visibility = ["//visibility:public"], +) + +genrule( + name = "go_generics_tests", + srcs = glob(["generics_tests/**"]) + [":go_generics"], + outs = ["go_generics_tests.tgz"], + cmd = "tar -czvhf $@ $(SRCS)", +) + +genrule( + name = "go_generics_test_bundle", + srcs = [ + ":go_generics_tests.tgz", + ":go_generics_unittest.sh", + ], + outs = ["go_generics_test.sh"], + cmd = "cat $(location :go_generics_unittest.sh) $(location :go_generics_tests.tgz) > $@", + executable = True, +) + +sh_test( + name = "go_generics_test", + size = "small", + srcs = ["go_generics_test.sh"], +) diff --git a/tools/go_generics/defs.bzl b/tools/go_generics/defs.bzl new file mode 100644 index 000000000..0b2467805 --- /dev/null +++ b/tools/go_generics/defs.bzl @@ -0,0 +1,152 @@ +def _go_template_impl(ctx): + input = ctx.files.srcs + output = ctx.outputs.out + + args = ["-o=%s" % output.path] + [f.path for f in input] + + ctx.actions.run( + inputs = input, + outputs = [output], + mnemonic = "GoGenericsTemplate", + progress_message = "Building Go template %s" % ctx.label, + arguments = args, + executable = ctx.executable._tool, + ) + + return struct( + types = ctx.attr.types, + opt_types = ctx.attr.opt_types, + consts = ctx.attr.consts, + opt_consts = ctx.attr.opt_consts, + deps = ctx.attr.deps, + file = output, + ) + +""" +Generates a Go template from a set of Go files. + +A Go template is similar to a go library, except that it has certain types that +can be replaced before usage. For example, one could define a templatized List +struct, whose elements are of type T, then instantiate that template for +T=segment, where "segment" is the concrete type. + +Args: + name: the name of the template. + srcs: the list of source files that comprise the template. + types: the list of generic types in the template that are required to be specified. + opt_types: the list of generic types in the template that can but aren't required to be specified. + consts: the list of constants in the template that are required to be specified. + opt_consts: the list of constants in the template that can but aren't required to be specified. + deps: the list of dependencies. +""" + +go_template = rule( + attrs = { + "srcs": attr.label_list( + mandatory = True, + allow_files = True, + ), + "deps": attr.label_list(allow_files = True), + "types": attr.string_list(), + "opt_types": attr.string_list(), + "consts": attr.string_list(), + "opt_consts": attr.string_list(), + "_tool": attr.label( + executable = True, + cfg = "host", + default = Label("//tools/go_generics:go_merge"), + ), + }, + outputs = { + "out": "%{name}_template.go", + }, + implementation = _go_template_impl, +) + +def _go_template_instance_impl(ctx): + template = ctx.attr.template + output = ctx.outputs.out + + # Check that all required types are defined. + for t in template.types: + if t not in ctx.attr.types: + fail("Missing value for type %s in %s" % (t, ctx.attr.template.label)) + + # Check that all defined types are expected by the template. + for t in ctx.attr.types: + if (t not in template.types) and (t not in template.opt_types): + fail("Type %s it not a parameter to %s" % (t, ctx.attr.template.label)) + + # Check that all required consts are defined. + for t in template.consts: + if t not in ctx.attr.consts: + fail("Missing value for constant %s in %s" % (t, ctx.attr.template.label)) + + # Check that all defined consts are expected by the template. + for t in ctx.attr.consts: + if (t not in template.consts) and (t not in template.opt_consts): + fail("Const %s it not a parameter to %s" % (t, ctx.attr.template.label)) + + # Build the argument list. + args = ["-i=%s" % template.file.path, "-o=%s" % output.path] + args += ["-p=%s" % ctx.attr.package] + + if len(ctx.attr.prefix) > 0: + args += ["-prefix=%s" % ctx.attr.prefix] + + if len(ctx.attr.suffix) > 0: + args += ["-suffix=%s" % ctx.attr.suffix] + + args += [("-t=%s=%s" % (p[0], p[1])) for p in ctx.attr.types.items()] + args += [("-c=%s=%s" % (p[0], p[1])) for p in ctx.attr.consts.items()] + args += [("-import=%s=%s" % (p[0], p[1])) for p in ctx.attr.imports.items()] + + ctx.actions.run( + inputs = [template.file], + outputs = [output], + mnemonic = "GoGenericsInstance", + progress_message = "Building Go template instance %s" % ctx.label, + arguments = args, + executable = ctx.executable._tool, + ) + + # TODO: How can we get the dependencies out? + return struct( + files = depset([output]) + ) + +""" +Instantiates a Go template by replacing all generic types with concrete ones. + +Args: + name: the name of the template instance. + template: the label of the template to be instatiated. + prefix: a prefix to be added to globals in the template. + suffix: a suffix to be added to global in the template. + types: the map from generic type names to concrete ones. + consts: the map from constant names to their values. + imports: the map from imports used in types/consts to their import paths. + package: the name of the package the instantiated template will be compiled into. +""" + +go_template_instance = rule( + attrs = { + "template": attr.label( + mandatory = True, + providers = ["types"], + ), + "prefix": attr.string(), + "suffix": attr.string(), + "types": attr.string_dict(), + "consts": attr.string_dict(), + "imports": attr.string_dict(), + "package": attr.string(mandatory = True), + "out": attr.output(mandatory = True), + "_tool": attr.label( + executable = True, + cfg = "host", + default = Label("//tools/go_generics"), + ), + }, + implementation = _go_template_instance_impl, +) diff --git a/tools/go_generics/generics.go b/tools/go_generics/generics.go new file mode 100644 index 000000000..033923442 --- /dev/null +++ b/tools/go_generics/generics.go @@ -0,0 +1,274 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// go_generics reads a Go source file and writes a new version of that file with +// a few transformations applied to each. Namely: +// +// 1. Global types can be explicitly renamed with the -t option. For example, +// if -t=A=B is passed in, all references to A will be replaced with +// references to B; a function declaration like: +// +// func f(arg *A) +// +// would be renamed to: +// +// fun f(arg *B) +// +// 2. Global type definitions and their method sets will be removed when they're +// being renamed with -t. For example, if -t=A=B is passed in, the following +// definition and methods that existed in the input file wouldn't exist at +// all in the output file: +// +// type A struct{} +// +// func (*A) f() {} +// +// 3. All global types, variables, constants and functions (not methods) are +// prefixed and suffixed based on the option -prefix and -suffix arguments. +// For example, if -suffix=A is passed in, the following globals: +// +// func f() +// type t struct{} +// +// would be renamed to: +// +// func fA() +// type tA struct{} +// +// Some special tags are also modified. For example: +// +// "state:.(t)" +// +// would become: +// +// "state:.(tA)" +// +// 4. The package is renamed to the value via the -p argument. +// 5. Value of constants can be modified with -c argument. +// +// Note that not just the top-level declarations are renamed, all references to +// them are also properly renamed as well, taking into account visibility rules +// and shadowing. For example, if -suffix=A is passed in, the following: +// +// var b = 100 +// +// func f() { +// g(b) +// b := 0 +// g(b) +// } +// +// Would be replaced with: +// +// var bA = 100 +// +// func f() { +// g(bA) +// b := 0 +// g(b) +// } +// +// Note that the second call to g() kept "b" as an argument because it refers to +// the local variable "b". +// +// Unfortunately, go_generics does not handle anonymous fields with renamed types. +package main + +import ( + "bytes" + "flag" + "fmt" + "go/ast" + "go/format" + "go/parser" + "go/token" + "io/ioutil" + "os" + "regexp" + "strings" + + "gvisor.googlesource.com/gvisor/tools/go_generics/globals" +) + +var ( + input = flag.String("i", "", "input `file`") + output = flag.String("o", "", "output `file`") + suffix = flag.String("suffix", "", "`suffix` to add to each global symbol") + prefix = flag.String("prefix", "", "`prefix` to add to each global symbol") + packageName = flag.String("p", "main", "output package `name`") + printAST = flag.Bool("ast", false, "prints the AST") + types = make(mapValue) + consts = make(mapValue) + imports = make(mapValue) +) + +// mapValue implements flag.Value. We use a mapValue flag instead of a regular +// string flag when we want to allow more than one instance of the flag. For +// example, we allow several "-t A=B" arguments, and will rename them all. +type mapValue map[string]string + +func (m mapValue) String() string { + var b bytes.Buffer + first := true + for k, v := range m { + if !first { + b.WriteRune(',') + } else { + first = false + } + b.WriteString(k) + b.WriteRune('=') + b.WriteString(v) + } + return b.String() +} + +func (m mapValue) Set(s string) error { + sep := strings.Index(s, "=") + if sep == -1 { + return fmt.Errorf("missing '=' from '%s'", s) + } + + m[s[:sep]] = s[sep+1:] + + return nil +} + +// stateTagRegexp matches against the 'typed' state tags. +var stateTagRegexp = regexp.MustCompile(`^(.*[^a-z0-9_])state:"\.\(([^\)]*)\)"(.*)$`) + +var identifierRegexp = regexp.MustCompile(`^(.*[^a-zA-Z_])([a-zA-Z_][a-zA-Z0-9_]*)(.*)$`) + +func main() { + flag.Usage = func() { + fmt.Fprintf(os.Stderr, "Usage: %s [options]\n", os.Args[0]) + flag.PrintDefaults() + } + + flag.Var(types, "t", "rename type A to B when `A=B` is passed in. Multiple such mappings are allowed.") + flag.Var(consts, "c", "reassign constant A to value B when `A=B` is passed in. Multiple such mappings are allowed.") + flag.Var(imports, "import", "specifies the import libraries to use when types are not local. `name=path` specifies that 'name', used in types as name.type, refers to the package living in 'path'.") + flag.Parse() + + if *input == "" || *output == "" { + flag.Usage() + os.Exit(1) + } + + // Parse the input file. + fset := token.NewFileSet() + f, err := parser.ParseFile(fset, *input, nil, parser.ParseComments|parser.DeclarationErrors|parser.SpuriousErrors) + if err != nil { + fmt.Fprintf(os.Stderr, "%v\n", err) + os.Exit(1) + } + + // Print the AST if requested. + if *printAST { + ast.Print(fset, f) + } + + cmap := ast.NewCommentMap(fset, f, f.Comments) + + // Update imports based on what's used in types and consts. + maps := []mapValue{types, consts} + importDecl, err := updateImports(maps, imports) + if err != nil { + fmt.Fprintf(os.Stderr, "%v\n", err) + os.Exit(1) + } + types = maps[0] + consts = maps[1] + + // Reassign all specified constants. + for _, decl := range f.Decls { + d, ok := decl.(*ast.GenDecl) + if !ok || d.Tok != token.CONST { + continue + } + + for _, gs := range d.Specs { + s := gs.(*ast.ValueSpec) + for i, id := range s.Names { + if n, ok := consts[id.Name]; ok { + s.Values[i] = &ast.BasicLit{Value: n} + } + } + } + } + + // Go through all globals and their uses in the AST and rename the types + // with explicitly provided names, and rename all types, variables, + // consts and functions with the provided prefix and suffix. + globals.Visit(fset, f, func(ident *ast.Ident, kind globals.SymKind) { + if n, ok := types[ident.Name]; ok && kind == globals.KindType { + ident.Name = n + } else { + switch kind { + case globals.KindType, globals.KindVar, globals.KindConst, globals.KindFunction: + ident.Name = *prefix + ident.Name + *suffix + case globals.KindTag: + // Modify the state tag appropriately. + if m := stateTagRegexp.FindStringSubmatch(ident.Name); m != nil { + if t := identifierRegexp.FindStringSubmatch(m[2]); t != nil { + ident.Name = m[1] + `state:".(` + t[1] + *prefix + t[2] + *suffix + t[3] + `)"` + m[3] + } + } + } + } + }) + + // Remove the definition of all types that are being remapped. + set := make(typeSet) + for _, v := range types { + set[v] = struct{}{} + } + removeTypes(set, f) + + // Add the new imports, if any, to the top. + if importDecl != nil { + newDecls := make([]ast.Decl, 0, len(f.Decls)+1) + newDecls = append(newDecls, importDecl) + newDecls = append(newDecls, f.Decls...) + f.Decls = newDecls + } + + // Update comments to remove the ones potentially associated with the + // type T that we removed. + f.Comments = cmap.Filter(f).Comments() + + // If there are file (package) comments, delete them. + if f.Doc != nil { + for i, cg := range f.Comments { + if cg == f.Doc { + f.Comments = append(f.Comments[:i], f.Comments[i+1:]...) + break + } + } + } + + // Write the output file. + f.Name.Name = *packageName + + var buf bytes.Buffer + if err := format.Node(&buf, fset, f); err != nil { + fmt.Fprintf(os.Stderr, "%v\n", err) + os.Exit(1) + } + + if err := ioutil.WriteFile(*output, buf.Bytes(), 0644); err != nil { + fmt.Fprintf(os.Stderr, "%v\n", err) + os.Exit(1) + } +} diff --git a/tools/go_generics/generics_tests/all_stmts/input.go b/tools/go_generics/generics_tests/all_stmts/input.go new file mode 100644 index 000000000..870af3b6c --- /dev/null +++ b/tools/go_generics/generics_tests/all_stmts/input.go @@ -0,0 +1,290 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tests + +import ( + "sync" +) + +type T int + +func h(T) { +} + +type s struct { + a, b int + c []int +} + +func g(T) *s { + return &s{} +} + +func f() (T, []int) { + // Branch. + goto T + goto R + + // Labeled. +T: + _ = T(0) + + // Empty. +R: + ; + + // Assignment with definition. + a, b, c := T(1), T(2), T(3) + _, _, _ = a, b, c + + // Assignment without definition. + g(T(0)).a, g(T(1)).b, c = int(T(1)), int(T(2)), T(3) + _, _, _ = a, b, c + + // Block. + { + var T T + T = 0 + _ = T + } + + // Declarations. + type Type T + const Const T = 10 + var g1 func(T, int, ...T) (int, T) + var v T + var w = T(0) + { + var T struct { + f []T + } + _ = T + } + + // Defer. + defer g1(T(0), 1) + + // Expression. + h(v + w + T(1)) + + // For statements. + for i := T(0); i < T(10); i++ { + var T func(int) T + v := T(0) + _ = v + } + + for { + var T func(int) T + v := T(0) + _ = v + } + + // Go. + go g1(T(0), 1) + + // If statements. + if a != T(1) { + var T func(int) T + v := T(0) + _ = v + } + + if a := T(0); a != T(1) { + var T func(int) T + v := T(0) + _ = v + } + + if a := T(0); a != T(1) { + var T func(int) T + v := T(0) + _ = v + } else if b := T(0); b != T(1) { + var T func(int) T + v := T(0) + _ = v + } else if T := T(0); T != 1 { + T++ + } else { + T-- + } + + if a := T(0); a != T(1) { + var T func(int) T + v := T(0) + _ = v + } else { + var T func(int) T + v := T(0) + _ = v + } + + // Inc/Dec statements. + (*(*T)(nil))++ + (*(*T)(nil))-- + + // Range statements. + for g(T(0)).a, g(T(1)).b = range g(T(10)).c { + var d T + _ = d + } + + for T, b := range g(T(10)).c { + _ = T + _ = b + } + + // Select statement. + { + var fch func(T) chan int + + select { + case <-fch(T(30)): + var T T + T = 0 + _ = T + default: + var T T + T = 0 + _ = T + case T := <-fch(T(30)): + T = 0 + _ = T + case g(T(0)).a = <-fch(T(30)): + var T T + T = 0 + _ = T + case fch(T(30)) <- int(T(0)): + var T T + T = 0 + _ = T + } + } + + // Send statements. + { + var ch chan T + var fch func(T) chan int + + ch <- T(0) + fch(T(1)) <- g(T(10)).a + } + + // Switch statements. + { + var a T + var b int + switch { + case a == T(0): + var T T + T = 0 + _ = T + case a < T(0), b < g(T(10)).a: + var T T + T = 0 + _ = T + default: + var T T + T = 0 + _ = T + } + } + + switch T(g(T(10)).a) { + case T(0): + var T T + T = 0 + _ = T + case T(1), T(g(T(10)).a): + var T T + T = 0 + _ = T + default: + var T T + T = 0 + _ = T + } + + switch b := g(T(10)); T(b.a) + T(10) { + case T(0): + var T T + T = 0 + _ = T + case T(1), T(g(T(10)).a): + var T T + T = 0 + _ = T + default: + var T T + T = 0 + _ = T + } + + // Type switch statements. + { + var interfaceFunc func(T) interface{} + + switch interfaceFunc(T(0)).(type) { + case *T, T, int: + var T T + T = 0 + _ = T + case sync.Mutex, **T: + var T T + T = 0 + _ = T + default: + var T T + T = 0 + _ = T + } + + switch x := interfaceFunc(T(0)).(type) { + case *T, T, int: + var T T + T = 0 + _ = T + _ = x + case sync.Mutex, **T: + var T T + T = 0 + _ = T + default: + var T T + T = 0 + _ = T + } + + switch t := T(0); x := interfaceFunc(T(0) + t).(type) { + case *T, T, int: + var T T + T = 0 + _ = T + _ = x + case sync.Mutex, **T: + var T T + T = 0 + _ = T + default: + var T T + T = 0 + _ = T + } + } + + // Return statement. + return T(10), g(T(11)).c +} diff --git a/tools/go_generics/generics_tests/all_stmts/opts.txt b/tools/go_generics/generics_tests/all_stmts/opts.txt new file mode 100644 index 000000000..c9d0e09bf --- /dev/null +++ b/tools/go_generics/generics_tests/all_stmts/opts.txt @@ -0,0 +1 @@ +-t=T=Q diff --git a/tools/go_generics/generics_tests/all_stmts/output/output.go b/tools/go_generics/generics_tests/all_stmts/output/output.go new file mode 100644 index 000000000..e4e670bf1 --- /dev/null +++ b/tools/go_generics/generics_tests/all_stmts/output/output.go @@ -0,0 +1,288 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "sync" +) + +func h(Q) { +} + +type s struct { + a, b int + c []int +} + +func g(Q) *s { + return &s{} +} + +func f() (Q, []int) { + // Branch. + goto T + goto R + + // Labeled. +T: + _ = Q(0) + + // Empty. +R: + ; + + // Assignment with definition. + a, b, c := Q(1), Q(2), Q(3) + _, _, _ = a, b, c + + // Assignment without definition. + g(Q(0)).a, g(Q(1)).b, c = int(Q(1)), int(Q(2)), Q(3) + _, _, _ = a, b, c + + // Block. + { + var T Q + T = 0 + _ = T + } + + // Declarations. + type Type Q + const Const Q = 10 + var g1 func(Q, int, ...Q) (int, Q) + var v Q + var w = Q(0) + { + var T struct { + f []Q + } + _ = T + } + + // Defer. + defer g1(Q(0), 1) + + // Expression. + h(v + w + Q(1)) + + // For statements. + for i := Q(0); i < Q(10); i++ { + var T func(int) Q + v := T(0) + _ = v + } + + for { + var T func(int) Q + v := T(0) + _ = v + } + + // Go. + go g1(Q(0), 1) + + // If statements. + if a != Q(1) { + var T func(int) Q + v := T(0) + _ = v + } + + if a := Q(0); a != Q(1) { + var T func(int) Q + v := T(0) + _ = v + } + + if a := Q(0); a != Q(1) { + var T func(int) Q + v := T(0) + _ = v + } else if b := Q(0); b != Q(1) { + var T func(int) Q + v := T(0) + _ = v + } else if T := Q(0); T != 1 { + T++ + } else { + T-- + } + + if a := Q(0); a != Q(1) { + var T func(int) Q + v := T(0) + _ = v + } else { + var T func(int) Q + v := T(0) + _ = v + } + + // Inc/Dec statements. + (*(*Q)(nil))++ + (*(*Q)(nil))-- + + // Range statements. + for g(Q(0)).a, g(Q(1)).b = range g(Q(10)).c { + var d Q + _ = d + } + + for T, b := range g(Q(10)).c { + _ = T + _ = b + } + + // Select statement. + { + var fch func(Q) chan int + + select { + case <-fch(Q(30)): + var T Q + T = 0 + _ = T + default: + var T Q + T = 0 + _ = T + case T := <-fch(Q(30)): + T = 0 + _ = T + case g(Q(0)).a = <-fch(Q(30)): + var T Q + T = 0 + _ = T + case fch(Q(30)) <- int(Q(0)): + var T Q + T = 0 + _ = T + } + } + + // Send statements. + { + var ch chan Q + var fch func(Q) chan int + + ch <- Q(0) + fch(Q(1)) <- g(Q(10)).a + } + + // Switch statements. + { + var a Q + var b int + switch { + case a == Q(0): + var T Q + T = 0 + _ = T + case a < Q(0), b < g(Q(10)).a: + var T Q + T = 0 + _ = T + default: + var T Q + T = 0 + _ = T + } + } + + switch Q(g(Q(10)).a) { + case Q(0): + var T Q + T = 0 + _ = T + case Q(1), Q(g(Q(10)).a): + var T Q + T = 0 + _ = T + default: + var T Q + T = 0 + _ = T + } + + switch b := g(Q(10)); Q(b.a) + Q(10) { + case Q(0): + var T Q + T = 0 + _ = T + case Q(1), Q(g(Q(10)).a): + var T Q + T = 0 + _ = T + default: + var T Q + T = 0 + _ = T + } + + // Type switch statements. + { + var interfaceFunc func(Q) interface{} + + switch interfaceFunc(Q(0)).(type) { + case *Q, Q, int: + var T Q + T = 0 + _ = T + case sync.Mutex, **Q: + var T Q + T = 0 + _ = T + default: + var T Q + T = 0 + _ = T + } + + switch x := interfaceFunc(Q(0)).(type) { + case *Q, Q, int: + var T Q + T = 0 + _ = T + _ = x + case sync.Mutex, **Q: + var T Q + T = 0 + _ = T + default: + var T Q + T = 0 + _ = T + } + + switch t := Q(0); x := interfaceFunc(Q(0) + t).(type) { + case *Q, Q, int: + var T Q + T = 0 + _ = T + _ = x + case sync.Mutex, **Q: + var T Q + T = 0 + _ = T + default: + var T Q + T = 0 + _ = T + } + } + + // Return statement. + return Q(10), g(Q(11)).c +} diff --git a/tools/go_generics/generics_tests/all_types/input.go b/tools/go_generics/generics_tests/all_types/input.go new file mode 100644 index 000000000..3a8643e3d --- /dev/null +++ b/tools/go_generics/generics_tests/all_types/input.go @@ -0,0 +1,43 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tests + +import "./lib" + +type T int + +type newType struct { + a T + b lib.T + c *T + d (T) + e chan T + f <-chan T + g chan<- T + h []T + i [10]T + j map[T]T + k func(T, T) (T, T) + l interface { + f(T) + } + m struct { + T + a T + } +} + +func f(...T) { +} diff --git a/tools/go_generics/generics_tests/all_types/lib/lib.go b/tools/go_generics/generics_tests/all_types/lib/lib.go new file mode 100644 index 000000000..d3911d12d --- /dev/null +++ b/tools/go_generics/generics_tests/all_types/lib/lib.go @@ -0,0 +1,17 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package lib + +type T int32 diff --git a/tools/go_generics/generics_tests/all_types/opts.txt b/tools/go_generics/generics_tests/all_types/opts.txt new file mode 100644 index 000000000..c9d0e09bf --- /dev/null +++ b/tools/go_generics/generics_tests/all_types/opts.txt @@ -0,0 +1 @@ +-t=T=Q diff --git a/tools/go_generics/generics_tests/all_types/output/output.go b/tools/go_generics/generics_tests/all_types/output/output.go new file mode 100644 index 000000000..b89840936 --- /dev/null +++ b/tools/go_generics/generics_tests/all_types/output/output.go @@ -0,0 +1,41 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import "./lib" + +type newType struct { + a Q + b lib.T + c *Q + d (Q) + e chan Q + f <-chan Q + g chan<- Q + h []Q + i [10]Q + j map[Q]Q + k func(Q, Q) (Q, Q) + l interface { + f(Q) + } + m struct { + Q + a Q + } +} + +func f(...Q) { +} diff --git a/tools/go_generics/generics_tests/consts/input.go b/tools/go_generics/generics_tests/consts/input.go new file mode 100644 index 000000000..dabf76e1e --- /dev/null +++ b/tools/go_generics/generics_tests/consts/input.go @@ -0,0 +1,26 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tests + +const c1 = 10 +const x, y, z = 100, 200, 300 +const v float32 = 1.0 + 2.0 +const s = "abc" +const ( + A = 10 + B, C, D = 10, 20, 30 + S = "abc" + T, U, V string = "abc", "def", "ghi" +) diff --git a/tools/go_generics/generics_tests/consts/opts.txt b/tools/go_generics/generics_tests/consts/opts.txt new file mode 100644 index 000000000..4fb59dce8 --- /dev/null +++ b/tools/go_generics/generics_tests/consts/opts.txt @@ -0,0 +1 @@ +-c=c1=20 -c=z=600 -c=v=3.3 -c=s="def" -c=A=20 -c=C=100 -c=S="def" -c=T="ABC" diff --git a/tools/go_generics/generics_tests/consts/output/output.go b/tools/go_generics/generics_tests/consts/output/output.go new file mode 100644 index 000000000..72865607e --- /dev/null +++ b/tools/go_generics/generics_tests/consts/output/output.go @@ -0,0 +1,26 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +const c1 = 20 +const x, y, z = 100, 200, 600 +const v float32 = 3.3 +const s = "def" +const ( + A = 20 + B, C, D = 10, 100, 30 + S = "def" + T, U, V string = "ABC", "def", "ghi" +) diff --git a/tools/go_generics/generics_tests/imports/input.go b/tools/go_generics/generics_tests/imports/input.go new file mode 100644 index 000000000..66b43fee5 --- /dev/null +++ b/tools/go_generics/generics_tests/imports/input.go @@ -0,0 +1,24 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tests + +type T int + +var global T + +const ( + m = 0 + n = 0 +) diff --git a/tools/go_generics/generics_tests/imports/opts.txt b/tools/go_generics/generics_tests/imports/opts.txt new file mode 100644 index 000000000..87324be79 --- /dev/null +++ b/tools/go_generics/generics_tests/imports/opts.txt @@ -0,0 +1 @@ +-t=T=sync.Mutex -c=n=math.Uint32 -c=m=math.Uint64 -import=sync=sync -import=math=mymathpath diff --git a/tools/go_generics/generics_tests/imports/output/output.go b/tools/go_generics/generics_tests/imports/output/output.go new file mode 100644 index 000000000..5f20d43ce --- /dev/null +++ b/tools/go_generics/generics_tests/imports/output/output.go @@ -0,0 +1,27 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + __generics_imported1 "mymathpath" + __generics_imported0 "sync" +) + +var global __generics_imported0.Mutex + +const ( + m = __generics_imported1.Uint64 + n = __generics_imported1.Uint32 +) diff --git a/tools/go_generics/generics_tests/remove_typedef/input.go b/tools/go_generics/generics_tests/remove_typedef/input.go new file mode 100644 index 000000000..c02307d32 --- /dev/null +++ b/tools/go_generics/generics_tests/remove_typedef/input.go @@ -0,0 +1,37 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tests + +func f(T) Q { + return Q{} +} + +type T struct{} + +type Q struct{} + +func (*T) f() { +} + +func (T) g() { +} + +func (*Q) f(T) T { + return T{} +} + +func (*Q) g(T) *T { + return nil +} diff --git a/tools/go_generics/generics_tests/remove_typedef/opts.txt b/tools/go_generics/generics_tests/remove_typedef/opts.txt new file mode 100644 index 000000000..9c8ecaada --- /dev/null +++ b/tools/go_generics/generics_tests/remove_typedef/opts.txt @@ -0,0 +1 @@ +-t=T=U diff --git a/tools/go_generics/generics_tests/remove_typedef/output/output.go b/tools/go_generics/generics_tests/remove_typedef/output/output.go new file mode 100644 index 000000000..d20a89abd --- /dev/null +++ b/tools/go_generics/generics_tests/remove_typedef/output/output.go @@ -0,0 +1,29 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +func f(U) Q { + return Q{} +} + +type Q struct{} + +func (*Q) f(U) U { + return U{} +} + +func (*Q) g(U) *U { + return nil +} diff --git a/tools/go_generics/generics_tests/simple/input.go b/tools/go_generics/generics_tests/simple/input.go new file mode 100644 index 000000000..670161d6e --- /dev/null +++ b/tools/go_generics/generics_tests/simple/input.go @@ -0,0 +1,45 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tests + +type T int + +var global T + +func f(_ T, a int) { +} + +func g(a T, b int) { + var c T + _ = c + + d := (*T)(nil) + _ = d +} + +type R struct { + T + a *T +} + +var ( + Z *T = (*T)(nil) +) + +const ( + X T = (T)(0) +) + +type Y T diff --git a/tools/go_generics/generics_tests/simple/opts.txt b/tools/go_generics/generics_tests/simple/opts.txt new file mode 100644 index 000000000..7832ef66f --- /dev/null +++ b/tools/go_generics/generics_tests/simple/opts.txt @@ -0,0 +1 @@ +-t=T=Q -suffix=New diff --git a/tools/go_generics/generics_tests/simple/output/output.go b/tools/go_generics/generics_tests/simple/output/output.go new file mode 100644 index 000000000..75b5467cd --- /dev/null +++ b/tools/go_generics/generics_tests/simple/output/output.go @@ -0,0 +1,43 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +var globalNew Q + +func fNew(_ Q, a int) { +} + +func gNew(a Q, b int) { + var c Q + _ = c + + d := (*Q)(nil) + _ = d +} + +type RNew struct { + Q + a *Q +} + +var ( + ZNew *Q = (*Q)(nil) +) + +const ( + XNew Q = (Q)(0) +) + +type YNew Q diff --git a/tools/go_generics/globals/BUILD b/tools/go_generics/globals/BUILD new file mode 100644 index 000000000..a238becab --- /dev/null +++ b/tools/go_generics/globals/BUILD @@ -0,0 +1,13 @@ +package(licenses = ["notice"]) # Apache 2.0 + +load("@io_bazel_rules_go//go:def.bzl", "go_library") + +go_library( + name = "globals", + srcs = [ + "globals_visitor.go", + "scope.go", + ], + importpath = "gvisor.googlesource.com/gvisor/tools/go_generics/globals", + visibility = ["//tools/go_generics:__pkg__"], +) diff --git a/tools/go_generics/globals/globals_visitor.go b/tools/go_generics/globals/globals_visitor.go new file mode 100644 index 000000000..fc0de4381 --- /dev/null +++ b/tools/go_generics/globals/globals_visitor.go @@ -0,0 +1,588 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package globals provides an AST visitor that calls the visit function for all +// global identifiers. +package globals + +import ( + "fmt" + + "go/ast" + "go/token" + "path/filepath" + "strconv" +) + +// globalsVisitor holds the state used while traversing the nodes of a file in +// search of globals. +// +// The visitor does two passes on the global declarations: the first one adds +// all globals to the global scope (since Go allows references to globals that +// haven't been declared yet), and the second one calls f() for the definition +// and uses of globals found in the first pass. +// +// The implementation correctly handles cases when globals are aliased by +// locals; in such cases, f() is not called. +type globalsVisitor struct { + // file is the file whose nodes are being visited. + file *ast.File + + // fset is the file set the file being visited belongs to. + fset *token.FileSet + + // f is the visit function to be called when a global symbol is reached. + f func(*ast.Ident, SymKind) + + // scope is the current scope as nodes are visited. + scope *scope +} + +// unexpected is called when an unexpected node appears in the AST. It dumps +// the location of the associated token and panics because this should only +// happen when there is a bug in the traversal code. +func (v *globalsVisitor) unexpected(p token.Pos) { + panic(fmt.Sprintf("Unable to parse at %v", v.fset.Position(p))) +} + +// pushScope creates a new scope and pushes it to the top of the scope stack. +func (v *globalsVisitor) pushScope() { + v.scope = newScope(v.scope) +} + +// popScope removes the scope created by the last call to pushScope. +func (v *globalsVisitor) popScope() { + v.scope = v.scope.outer +} + +// visitType is called when an expression is known to be a type, for example, +// on the first argument of make(). It visits all children nodes and reports +// any globals. +func (v *globalsVisitor) visitType(ge ast.Expr) { + switch e := ge.(type) { + case *ast.Ident: + if s := v.scope.deepLookup(e.Name); s != nil && s.scope.isGlobal() { + v.f(e, s.kind) + } + + case *ast.SelectorExpr: + id := GetIdent(e.X) + if id == nil { + v.unexpected(e.X.Pos()) + } + + case *ast.StarExpr: + v.visitType(e.X) + case *ast.ParenExpr: + v.visitType(e.X) + case *ast.ChanType: + v.visitType(e.Value) + case *ast.Ellipsis: + v.visitType(e.Elt) + case *ast.ArrayType: + v.visitExpr(e.Len) + v.visitType(e.Elt) + case *ast.MapType: + v.visitType(e.Key) + v.visitType(e.Value) + case *ast.StructType: + v.visitFields(e.Fields, KindUnknown) + case *ast.FuncType: + v.visitFields(e.Params, KindUnknown) + v.visitFields(e.Results, KindUnknown) + case *ast.InterfaceType: + v.visitFields(e.Methods, KindUnknown) + default: + v.unexpected(ge.Pos()) + } +} + +// visitFields visits all fields, and add symbols if kind isn't KindUnknown. +func (v *globalsVisitor) visitFields(l *ast.FieldList, kind SymKind) { + if l == nil { + return + } + + for _, f := range l.List { + if kind != KindUnknown { + for _, n := range f.Names { + v.scope.add(n.Name, kind, n.Pos()) + } + } + v.visitType(f.Type) + if f.Tag != nil { + tag := ast.NewIdent(f.Tag.Value) + v.f(tag, KindTag) + // Replace the tag if updated. + if tag.Name != f.Tag.Value { + f.Tag.Value = tag.Name + } + } + } +} + +// visitGenDecl is called when a generic declation is encountered, for example, +// on variable, constant and type declarations. It adds all newly defined +// symbols to the current scope and reports them if the current scope is the +// global one. +func (v *globalsVisitor) visitGenDecl(d *ast.GenDecl) { + switch d.Tok { + case token.IMPORT: + case token.TYPE: + for _, gs := range d.Specs { + s := gs.(*ast.TypeSpec) + v.scope.add(s.Name.Name, KindType, s.Name.Pos()) + if v.scope.isGlobal() { + v.f(s.Name, KindType) + } + v.visitType(s.Type) + } + case token.CONST, token.VAR: + kind := KindConst + if d.Tok == token.VAR { + kind = KindVar + } + + for _, gs := range d.Specs { + s := gs.(*ast.ValueSpec) + if s.Type != nil { + v.visitType(s.Type) + } + + for _, e := range s.Values { + v.visitExpr(e) + } + + for _, n := range s.Names { + if v.scope.isGlobal() { + v.f(n, kind) + } + v.scope.add(n.Name, kind, n.Pos()) + } + } + default: + v.unexpected(d.Pos()) + } +} + +// isViableType determines if the given expression is a viable type expression, +// that is, if it could be interpreted as a type, for example, sync.Mutex, +// myType, func(int)int, as opposed to -1, 2 * 2, a + b, etc. +func (v *globalsVisitor) isViableType(expr ast.Expr) bool { + switch e := expr.(type) { + case *ast.Ident: + // This covers the plain identifier case. When we see it, we + // have to check if it resolves to a type; if the symbol is not + // known, we'll claim it's viable as a type. + s := v.scope.deepLookup(e.Name) + return s == nil || s.kind == KindType + + case *ast.ChanType, *ast.ArrayType, *ast.MapType, *ast.StructType, *ast.FuncType, *ast.InterfaceType, *ast.Ellipsis: + // This covers the following cases: + // 1. ChanType: + // chan T + // <-chan T + // chan<- T + // 2. ArrayType: + // [Expr]T + // 3. MapType: + // map[T]U + // 4. StructType: + // struct { Fields } + // 5. FuncType: + // func(Fields)Returns + // 6. Interface: + // interface { Fields } + // 7. Ellipsis: + // ...T + return true + + case *ast.SelectorExpr: + // The only case in which an expression involving a selector can + // be a type is if it has the following form X.T, where X is an + // import, and T is a type exported by X. + // + // There's no way to know whether T is a type because we don't + // parse imports. So we just claim that this is a viable type; + // it doesn't affect the general result because we don't visit + // imported symbols. + id := GetIdent(e.X) + if id == nil { + return false + } + + s := v.scope.deepLookup(id.Name) + return s != nil && s.kind == KindImport + + case *ast.StarExpr: + // This covers the *T case. The expression is a viable type if + // T is. + return v.isViableType(e.X) + + case *ast.ParenExpr: + // This covers the (T) case. The expression is a viable type if + // T is. + return v.isViableType(e.X) + + default: + return false + } +} + +// visitCallExpr visits a "call expression" which can be either a +// function/method call (e.g., f(), pkg.f(), obj.f(), etc.) call or a type +// conversion (e.g., int32(1), (*sync.Mutex)(ptr), etc.). +func (v *globalsVisitor) visitCallExpr(e *ast.CallExpr) { + if v.isViableType(e.Fun) { + v.visitType(e.Fun) + } else { + v.visitExpr(e.Fun) + } + + // If the function being called is new or make, the first argument is + // a type, so it needs to be visited as such. + first := 0 + if id := GetIdent(e.Fun); id != nil && (id.Name == "make" || id.Name == "new") { + if len(e.Args) > 0 { + v.visitType(e.Args[0]) + } + first = 1 + } + + for i := first; i < len(e.Args); i++ { + v.visitExpr(e.Args[i]) + } +} + +// visitExpr visits all nodes of an expression, and reports any globals that it +// finds. +func (v *globalsVisitor) visitExpr(ge ast.Expr) { + switch e := ge.(type) { + case nil: + case *ast.Ident: + if s := v.scope.deepLookup(e.Name); s != nil && s.scope.isGlobal() { + v.f(e, s.kind) + } + + case *ast.BasicLit: + case *ast.CompositeLit: + v.visitType(e.Type) + for _, ne := range e.Elts { + v.visitExpr(ne) + } + case *ast.FuncLit: + v.pushScope() + v.visitFields(e.Type.Params, KindParameter) + v.visitFields(e.Type.Results, KindResult) + v.visitBlockStmt(e.Body) + v.popScope() + + case *ast.BinaryExpr: + v.visitExpr(e.X) + v.visitExpr(e.Y) + + case *ast.CallExpr: + v.visitCallExpr(e) + + case *ast.IndexExpr: + v.visitExpr(e.X) + v.visitExpr(e.Index) + + case *ast.KeyValueExpr: + v.visitExpr(e.Value) + + case *ast.ParenExpr: + v.visitExpr(e.X) + + case *ast.SelectorExpr: + v.visitExpr(e.X) + + case *ast.SliceExpr: + v.visitExpr(e.X) + v.visitExpr(e.Low) + v.visitExpr(e.High) + v.visitExpr(e.Max) + + case *ast.StarExpr: + v.visitExpr(e.X) + + case *ast.TypeAssertExpr: + v.visitExpr(e.X) + if e.Type != nil { + v.visitType(e.Type) + } + + case *ast.UnaryExpr: + v.visitExpr(e.X) + + default: + v.unexpected(ge.Pos()) + } +} + +// GetIdent returns the identifier associated with the given expression by +// removing parentheses if needed. +func GetIdent(expr ast.Expr) *ast.Ident { + switch e := expr.(type) { + case *ast.Ident: + return e + case *ast.ParenExpr: + return GetIdent(e.X) + default: + return nil + } +} + +// visitStmt visits all nodes of a statement, and reports any globals that it +// finds. It also adds to the current scope new symbols defined/declared. +func (v *globalsVisitor) visitStmt(gs ast.Stmt) { + switch s := gs.(type) { + case nil, *ast.BranchStmt, *ast.EmptyStmt: + case *ast.AssignStmt: + for _, e := range s.Rhs { + v.visitExpr(e) + } + + // We visit the LHS after the RHS because the symbols we'll + // potentially add to the table aren't meant to be visible to + // the RHS. + for _, e := range s.Lhs { + if s.Tok == token.DEFINE { + if n := GetIdent(e); n != nil { + v.scope.add(n.Name, KindVar, n.Pos()) + } + } + v.visitExpr(e) + } + + case *ast.BlockStmt: + v.visitBlockStmt(s) + + case *ast.DeclStmt: + v.visitGenDecl(s.Decl.(*ast.GenDecl)) + + case *ast.DeferStmt: + v.visitCallExpr(s.Call) + + case *ast.ExprStmt: + v.visitExpr(s.X) + + case *ast.ForStmt: + v.pushScope() + v.visitStmt(s.Init) + v.visitExpr(s.Cond) + v.visitStmt(s.Post) + v.visitBlockStmt(s.Body) + v.popScope() + + case *ast.GoStmt: + v.visitCallExpr(s.Call) + + case *ast.IfStmt: + v.pushScope() + v.visitStmt(s.Init) + v.visitExpr(s.Cond) + v.visitBlockStmt(s.Body) + v.visitStmt(s.Else) + v.popScope() + + case *ast.IncDecStmt: + v.visitExpr(s.X) + + case *ast.LabeledStmt: + v.visitStmt(s.Stmt) + + case *ast.RangeStmt: + v.pushScope() + v.visitExpr(s.X) + if s.Tok == token.DEFINE { + if n := GetIdent(s.Key); n != nil { + v.scope.add(n.Name, KindVar, n.Pos()) + } + + if n := GetIdent(s.Value); n != nil { + v.scope.add(n.Name, KindVar, n.Pos()) + } + } + v.visitExpr(s.Key) + v.visitExpr(s.Value) + v.visitBlockStmt(s.Body) + v.popScope() + + case *ast.ReturnStmt: + for _, r := range s.Results { + v.visitExpr(r) + } + + case *ast.SelectStmt: + for _, ns := range s.Body.List { + c := ns.(*ast.CommClause) + + v.pushScope() + v.visitStmt(c.Comm) + for _, bs := range c.Body { + v.visitStmt(bs) + } + v.popScope() + } + + case *ast.SendStmt: + v.visitExpr(s.Chan) + v.visitExpr(s.Value) + + case *ast.SwitchStmt: + v.pushScope() + v.visitStmt(s.Init) + v.visitExpr(s.Tag) + for _, ns := range s.Body.List { + c := ns.(*ast.CaseClause) + v.pushScope() + for _, ce := range c.List { + v.visitExpr(ce) + } + for _, bs := range c.Body { + v.visitStmt(bs) + } + v.popScope() + } + v.popScope() + + case *ast.TypeSwitchStmt: + v.pushScope() + v.visitStmt(s.Init) + v.visitStmt(s.Assign) + for _, ns := range s.Body.List { + c := ns.(*ast.CaseClause) + v.pushScope() + for _, ce := range c.List { + v.visitType(ce) + } + for _, bs := range c.Body { + v.visitStmt(bs) + } + v.popScope() + } + v.popScope() + + default: + v.unexpected(gs.Pos()) + } +} + +// visitBlockStmt visits all statements in the block, adding symbols to a newly +// created scope. +func (v *globalsVisitor) visitBlockStmt(s *ast.BlockStmt) { + v.pushScope() + for _, c := range s.List { + v.visitStmt(c) + } + v.popScope() +} + +// visitFuncDecl is called when a function or method declation is encountered. +// it creates a new scope for the function [optional] receiver, parameters and +// results, and visits all children nodes. +func (v *globalsVisitor) visitFuncDecl(d *ast.FuncDecl) { + // We don't report methods. + if d.Recv == nil { + v.f(d.Name, KindFunction) + } + + v.pushScope() + v.visitFields(d.Recv, KindReceiver) + v.visitFields(d.Type.Params, KindParameter) + v.visitFields(d.Type.Results, KindResult) + if d.Body != nil { + v.visitBlockStmt(d.Body) + } + v.popScope() +} + +// globalsFromDecl is called in the first, and adds symbols to global scope. +func (v *globalsVisitor) globalsFromGenDecl(d *ast.GenDecl) { + switch d.Tok { + case token.IMPORT: + for _, gs := range d.Specs { + s := gs.(*ast.ImportSpec) + if s.Name == nil { + str, _ := strconv.Unquote(s.Path.Value) + v.scope.add(filepath.Base(str), KindImport, s.Path.Pos()) + } else if s.Name.Name != "_" { + v.scope.add(s.Name.Name, KindImport, s.Name.Pos()) + } + } + case token.TYPE: + for _, gs := range d.Specs { + s := gs.(*ast.TypeSpec) + v.scope.add(s.Name.Name, KindType, s.Name.Pos()) + } + case token.CONST, token.VAR: + kind := KindConst + if d.Tok == token.VAR { + kind = KindVar + } + + for _, s := range d.Specs { + for _, n := range s.(*ast.ValueSpec).Names { + v.scope.add(n.Name, kind, n.Pos()) + } + } + default: + v.unexpected(d.Pos()) + } +} + +// visit implements the visiting of globals. It does performs the two passes +// described in the description of the globalsVisitor struct. +func (v *globalsVisitor) visit() { + // Gather all symbols in the global scope. This excludes methods. + v.pushScope() + for _, gd := range v.file.Decls { + switch d := gd.(type) { + case *ast.GenDecl: + v.globalsFromGenDecl(d) + case *ast.FuncDecl: + if d.Recv == nil { + v.scope.add(d.Name.Name, KindFunction, d.Name.Pos()) + } + default: + v.unexpected(gd.Pos()) + } + } + + // Go through the contents of the declarations. + for _, gd := range v.file.Decls { + switch d := gd.(type) { + case *ast.GenDecl: + v.visitGenDecl(d) + case *ast.FuncDecl: + v.visitFuncDecl(d) + } + } +} + +// Visit traverses the provided AST and calls f() for each identifier that +// refers to global names. The global name must be defined in the file itself. +// +// The function f() is allowed to modify the identifier, for example, to rename +// uses of global references. +func Visit(fset *token.FileSet, file *ast.File, f func(*ast.Ident, SymKind)) { + v := globalsVisitor{ + fset: fset, + file: file, + f: f, + } + + v.visit() +} diff --git a/tools/go_generics/globals/scope.go b/tools/go_generics/globals/scope.go new file mode 100644 index 000000000..18743bdee --- /dev/null +++ b/tools/go_generics/globals/scope.go @@ -0,0 +1,80 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package globals + +import ( + "go/token" +) + +// SymKind specifies the kind of a global symbol. For example, a variable, const +// function, etc. +type SymKind int + +// Constants for different kinds of symbols. +const ( + KindUnknown SymKind = iota + KindImport + KindType + KindVar + KindConst + KindFunction + KindReceiver + KindParameter + KindResult + KindTag +) + +type symbol struct { + kind SymKind + pos token.Pos + scope *scope +} + +type scope struct { + outer *scope + syms map[string]*symbol +} + +func newScope(outer *scope) *scope { + return &scope{ + outer: outer, + syms: make(map[string]*symbol), + } +} + +func (s *scope) isGlobal() bool { + return s.outer == nil +} + +func (s *scope) lookup(n string) *symbol { + return s.syms[n] +} + +func (s *scope) deepLookup(n string) *symbol { + for x := s; x != nil; x = x.outer { + if sym := x.lookup(n); sym != nil { + return sym + } + } + return nil +} + +func (s *scope) add(name string, kind SymKind, pos token.Pos) { + s.syms[name] = &symbol{ + kind: kind, + pos: pos, + scope: s, + } +} diff --git a/tools/go_generics/go_generics_unittest.sh b/tools/go_generics/go_generics_unittest.sh new file mode 100755 index 000000000..699e1f631 --- /dev/null +++ b/tools/go_generics/go_generics_unittest.sh @@ -0,0 +1,70 @@ +#!/bin/bash + +# Copyright 2018 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Bash "safe-mode": Treat command failures as fatal (even those that occur in +# pipes), and treat unset variables as errors. +set -eu -o pipefail + +# This file will be generated as a self-extracting shell script in order to +# eliminate the need for any runtime dependencies. The tarball at the end will +# include the go_generics binary, as well as a subdirectory named +# generics_tests. See the BUILD file for more information. +declare -r temp=$(mktemp -d) +function cleanup() { + rm -rf "${temp}" +} +# trap cleanup EXIT + +# Print message in "$1" then exit with status 1. +function die () { + echo "$1" 1>&2 + exit 1 +} + +# This prints the line number of __BUNDLE__ below, that should be the last line +# of this script. After that point, the concatenated archive will be the +# contents. +declare -r tgz=`awk '/^__BUNDLE__/ {print NR + 1; exit 0; }' $0` +tail -n+"${tgz}" $0 | tar -xzv -C "${temp}" + +# The target for the test. +declare -r binary="$(find ${temp} -type f -a -name go_generics)" +declare -r input_dirs="$(find ${temp} -type d -a -name generics_tests)/*" + +# Go through all test cases. +for f in ${input_dirs}; do + base=$(basename "${f}") + + # Run go_generics on the input file. + opts=$(head -n 1 ${f}/opts.txt) + out="${f}/output/generated.go" + expected="${f}/output/output.go" + ${binary} ${opts} "-i=${f}/input.go" "-o=${out}" || die "go_generics failed for test case \"${base}\"" + + # Compare the outputs. + diff ${expected} ${out} + if [ $? -ne 0 ]; then + echo "Expected:" + cat ${expected} + echo "Actual:" + cat ${out} + die "Actual output is different from expected for test \"${base}\"" + fi +done + +echo "PASS" +exit 0 +__BUNDLE__ diff --git a/tools/go_generics/imports.go b/tools/go_generics/imports.go new file mode 100644 index 000000000..97267098b --- /dev/null +++ b/tools/go_generics/imports.go @@ -0,0 +1,150 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "bytes" + "fmt" + "go/ast" + "go/format" + "go/parser" + "go/token" + "strconv" + + "gvisor.googlesource.com/gvisor/tools/go_generics/globals" +) + +type importedPackage struct { + newName string + path string +} + +// updateImportIdent modifies the given import identifier with the new name +// stored in the used map. If the identifier doesn't exist in the used map yet, +// a new name is generated and inserted into the map. +func updateImportIdent(orig string, imports mapValue, id *ast.Ident, used map[string]*importedPackage) error { + importName := id.Name + + // If the name is already in the table, just use the new name. + m := used[importName] + if m != nil { + id.Name = m.newName + return nil + } + + // Create a new entry in the used map. + path := imports[importName] + if path == "" { + return fmt.Errorf("Unknown path to package '%s', used in '%s'", importName, orig) + } + + m = &importedPackage{ + newName: fmt.Sprintf("__generics_imported%d", len(used)), + path: strconv.Quote(path), + } + used[importName] = m + + id.Name = m.newName + + return nil +} + +// convertExpression creates a new string that is a copy of the input one with +// all imports references renamed to the names in the "used" map. If the +// referenced import isn't in "used" yet, a new one is created based on the path +// in "imports" and stored in "used". For example, if string s is +// "math.MaxUint32-math.MaxUint16+10", it would be converted to +// "x.MaxUint32-x.MathUint16+10", where x is a generated name. +func convertExpression(s string, imports mapValue, used map[string]*importedPackage) (string, error) { + // Parse the expression in the input string. + expr, err := parser.ParseExpr(s) + if err != nil { + return "", fmt.Errorf("Unable to parse \"%s\": %v", s, err) + } + + // Go through the AST and update references. + var retErr error + ast.Inspect(expr, func(n ast.Node) bool { + switch x := n.(type) { + case *ast.SelectorExpr: + if id := globals.GetIdent(x.X); id != nil { + if err := updateImportIdent(s, imports, id, used); err != nil { + retErr = err + } + return false + } + } + return true + }) + if retErr != nil { + return "", retErr + } + + // Convert the modified AST back to a string. + fset := token.NewFileSet() + var buf bytes.Buffer + if err := format.Node(&buf, fset, expr); err != nil { + return "", err + } + + return string(buf.Bytes()), nil +} + +// updateImports replaces all maps in the input slice with copies where the +// mapped values have had all references to imported packages renamed to +// generated names. It also returns an import declaration for all the renamed +// import packages. +// +// For example, if the input maps contains A=math.B and C=math.D, the updated +// maps will instead contain A=__generics_imported0.B and +// C=__generics_imported0.C, and the 'import __generics_imported0 "math"' would +// be returned as the import declaration. +func updateImports(maps []mapValue, imports mapValue) (ast.Decl, error) { + importsUsed := make(map[string]*importedPackage) + + // Update all maps. + for i, m := range maps { + newMap := make(mapValue) + for n, e := range m { + updated, err := convertExpression(e, imports, importsUsed) + if err != nil { + return nil, err + } + + newMap[n] = updated + } + maps[i] = newMap + } + + // Nothing else to do if no imports are used in the expressions. + if len(importsUsed) == 0 { + return nil, nil + } + + // Create spec array for each new import. + specs := make([]ast.Spec, 0, len(importsUsed)) + for _, i := range importsUsed { + specs = append(specs, &ast.ImportSpec{ + Name: &ast.Ident{Name: i.newName}, + Path: &ast.BasicLit{Value: i.path}, + }) + } + + return &ast.GenDecl{ + Tok: token.IMPORT, + Specs: specs, + Lparen: token.NoPos + 1, + }, nil +} diff --git a/tools/go_generics/merge.go b/tools/go_generics/merge.go new file mode 100644 index 000000000..ebe7cf4e4 --- /dev/null +++ b/tools/go_generics/merge.go @@ -0,0 +1,139 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "bytes" + "flag" + "fmt" + "go/ast" + "go/format" + "go/parser" + "go/token" + "io/ioutil" + "os" + "path/filepath" + "strconv" +) + +var ( + output = flag.String("o", "", "output `file`") +) + +func fatalf(s string, args ...interface{}) { + fmt.Fprintf(os.Stderr, s, args...) + os.Exit(1) +} + +func main() { + flag.Usage = func() { + fmt.Fprintf(os.Stderr, "Usage: %s [options] <input1> [<input2> ...]\n", os.Args[0]) + flag.PrintDefaults() + } + + flag.Parse() + if *output == "" || len(flag.Args()) == 0 { + flag.Usage() + os.Exit(1) + } + + // Load all files. + files := make(map[string]*ast.File) + fset := token.NewFileSet() + var name string + for _, fname := range flag.Args() { + f, err := parser.ParseFile(fset, fname, nil, parser.ParseComments|parser.DeclarationErrors|parser.SpuriousErrors) + if err != nil { + fatalf("%v\n", err) + } + + files[fname] = f + if name == "" { + name = f.Name.Name + } else if name != f.Name.Name { + fatalf("Expected '%s' for package name instead of '%s'.\n", name, f.Name.Name) + } + } + + // Merge all files into one. + pkg := &ast.Package{ + Name: name, + Files: files, + } + f := ast.MergePackageFiles(pkg, ast.FilterUnassociatedComments|ast.FilterFuncDuplicates|ast.FilterImportDuplicates) + + // Create a new declaration slice with all imports at the top, merging any + // redundant imports. + imports := make(map[string]*ast.ImportSpec) + var anonImports []*ast.ImportSpec + for _, d := range f.Decls { + if g, ok := d.(*ast.GenDecl); ok && g.Tok == token.IMPORT { + for _, s := range g.Specs { + i := s.(*ast.ImportSpec) + p, _ := strconv.Unquote(i.Path.Value) + var n string + if i.Name == nil { + n = filepath.Base(p) + } else { + n = i.Name.Name + } + if n == "_" { + anonImports = append(anonImports, i) + } else { + if i2, ok := imports[n]; ok { + if first, second := i.Path.Value, i2.Path.Value; first != second { + fatalf("Conflicting paths for import name '%s': '%s' vs. '%s'\n", n, first, second) + } + } else { + imports[n] = i + } + } + } + } + } + newDecls := make([]ast.Decl, 0, len(f.Decls)) + if l := len(imports) + len(anonImports); l > 0 { + // Non-NoPos Lparen is needed for Go to recognize more than one spec in + // ast.GenDecl.Specs. + d := &ast.GenDecl{ + Tok: token.IMPORT, + Lparen: token.NoPos + 1, + Specs: make([]ast.Spec, 0, l), + } + for _, i := range imports { + d.Specs = append(d.Specs, i) + } + for _, i := range anonImports { + d.Specs = append(d.Specs, i) + } + newDecls = append(newDecls, d) + } + for _, d := range f.Decls { + if g, ok := d.(*ast.GenDecl); !ok || g.Tok != token.IMPORT { + newDecls = append(newDecls, d) + } + } + f.Decls = newDecls + + // Write the output file. + var buf bytes.Buffer + if err := format.Node(&buf, fset, f); err != nil { + fatalf("%v\n", err) + } + + if err := ioutil.WriteFile(*output, buf.Bytes(), 0644); err != nil { + fatalf("%v\n", err) + } +} diff --git a/tools/go_generics/remove.go b/tools/go_generics/remove.go new file mode 100644 index 000000000..2a66de762 --- /dev/null +++ b/tools/go_generics/remove.go @@ -0,0 +1,105 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "go/ast" + "go/token" +) + +type typeSet map[string]struct{} + +// isTypeOrPointerToType determines if the given AST expression represents a +// type or a pointer to a type that exists in the provided type set. +func isTypeOrPointerToType(set typeSet, expr ast.Expr, starCount int) bool { + switch e := expr.(type) { + case *ast.Ident: + _, ok := set[e.Name] + return ok + case *ast.StarExpr: + if starCount > 1 { + return false + } + return isTypeOrPointerToType(set, e.X, starCount+1) + case *ast.ParenExpr: + return isTypeOrPointerToType(set, e.X, starCount) + default: + return false + } +} + +// isMethodOf determines if the given function declaration is a method of one +// of the types in the provided type set. To do that, it checks if the function +// has a receiver and that its type is either T or *T, where T is a type that +// exists in the set. This is per the spec: +// +// That parameter section must declare a single parameter, the receiver. Its +// type must be of the form T or *T (possibly using parentheses) where T is a +// type name. The type denoted by T is called the receiver base type; it must +// not be a pointer or interface type and it must be declared in the same +// package as the method. +func isMethodOf(set typeSet, f *ast.FuncDecl) bool { + // If the function doesn't have exactly one receiver, then it's + // definitely not a method. + if f.Recv == nil || len(f.Recv.List) != 1 { + return false + } + + return isTypeOrPointerToType(set, f.Recv.List[0].Type, 0) +} + +// removeTypeDefinitions removes the definition of all types contained in the +// provided type set. +func removeTypeDefinitions(set typeSet, d *ast.GenDecl) { + if d.Tok != token.TYPE { + return + } + + i := 0 + for _, gs := range d.Specs { + s := gs.(*ast.TypeSpec) + if _, ok := set[s.Name.Name]; !ok { + d.Specs[i] = gs + i++ + } + } + + d.Specs = d.Specs[:i] +} + +// removeTypes removes from the AST the definition of all types and their +// method sets that are contained in the provided type set. +func removeTypes(set typeSet, f *ast.File) { + // Go through the top-level declarations. + i := 0 + for _, decl := range f.Decls { + keep := true + switch d := decl.(type) { + case *ast.GenDecl: + countBefore := len(d.Specs) + removeTypeDefinitions(set, d) + keep = countBefore == 0 || len(d.Specs) > 0 + case *ast.FuncDecl: + keep = !isMethodOf(set, d) + } + + if keep { + f.Decls[i] = decl + i++ + } + } + + f.Decls = f.Decls[:i] +} diff --git a/tools/go_generics/rules_tests/BUILD b/tools/go_generics/rules_tests/BUILD new file mode 100644 index 000000000..2d9a6fa9d --- /dev/null +++ b/tools/go_generics/rules_tests/BUILD @@ -0,0 +1,43 @@ +package(licenses = ["notice"]) # Apache 2.0 + +load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance") + +go_template_instance( + name = "instance", + out = "instance_test.go", + consts = { + "n": "20", + "m": "\"test\"", + "o": "math.MaxUint64", + }, + imports = { + "math": "math", + }, + package = "template_test", + template = ":test_template", + types = { + "t": "int", + }, +) + +go_template( + name = "test_template", + srcs = [ + "template.go", + ], + opt_consts = [ + "n", + "m", + "o", + ], + opt_types = ["t"], +) + +go_test( + name = "template_test", + srcs = [ + "instance_test.go", + "template_test.go", + ], +) diff --git a/tools/go_generics/rules_tests/template.go b/tools/go_generics/rules_tests/template.go new file mode 100644 index 000000000..73c024f0e --- /dev/null +++ b/tools/go_generics/rules_tests/template.go @@ -0,0 +1,42 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package template + +type t float + +const ( + n t = 10.1 + m = "abc" + o = 0 +) + +func max(a, b t) t { + if a > b { + return a + } + return b +} + +func add(a t) t { + return a + n +} + +func getName() string { + return m +} + +func getMax() uint64 { + return o +} diff --git a/tools/go_generics/rules_tests/template_test.go b/tools/go_generics/rules_tests/template_test.go new file mode 100644 index 000000000..76c4cdb64 --- /dev/null +++ b/tools/go_generics/rules_tests/template_test.go @@ -0,0 +1,48 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package template_test + +import ( + "math" + "testing" +) + +func TestMax(t *testing.T) { + var a int = max(10, 20) + if a != 20 { + t.Errorf("Bad result of max, got %v, want %v", a, 20) + } +} + +func TestIntConst(t *testing.T) { + var a int = add(10) + if a != 30 { + t.Errorf("Bad result of add, got %v, want %v", a, 30) + } +} + +func TestStrConst(t *testing.T) { + v := getName() + if v != "test" { + t.Errorf("Bad name, got %v, want %v", v, "test") + } +} + +func TestImport(t *testing.T) { + v := getMax() + if v != math.MaxUint64 { + t.Errorf("Bad max value, got %v, want %v", v, uint64(math.MaxUint64)) + } +} diff --git a/tools/go_stateify/BUILD b/tools/go_stateify/BUILD new file mode 100644 index 000000000..edbeb4e2d --- /dev/null +++ b/tools/go_stateify/BUILD @@ -0,0 +1,9 @@ +package(licenses = ["notice"]) # Apache 2.0 + +load("@io_bazel_rules_go//go:def.bzl", "go_binary") + +go_binary( + name = "stateify", + srcs = ["main.go"], + visibility = ["//visibility:public"], +) diff --git a/tools/go_stateify/defs.bzl b/tools/go_stateify/defs.bzl new file mode 100644 index 000000000..87fdc0d28 --- /dev/null +++ b/tools/go_stateify/defs.bzl @@ -0,0 +1,77 @@ +"""Stateify is a tool for generating state wrappers for Go types. + +The go_stateify rule is used to generate a file that will appear in a Go +target; the output file should appear explicitly in a srcs list. For example: + +go_stateify( + name = "foo_state", + srcs = ["foo.go"], + out = "foo_state.go", + package = "foo", +) + +go_library( + name = "foo", + srcs = [ + "foo.go", + "foo_state.go", + ], + deps = [ + "//pkg/state", + ], +) +""" + +def _go_stateify_impl(ctx): + """Implementation for the stateify tool.""" + output = ctx.outputs.out + + # Run the stateify command. + args = ["-output=%s" % output.path] + args += ["-pkg=%s" % ctx.attr.package] + if ctx.attr._statepkg: + args += ["-statepkg=%s" % ctx.attr._statepkg] + if ctx.attr.imports: + args += ["-imports=%s" % ",".join(ctx.attr.imports)] + args += ["--"] + for src in ctx.attr.srcs: + args += [f.path for f in src.files] + ctx.actions.run( + inputs = ctx.files.srcs, + outputs = [output], + mnemonic = "GoStateify", + progress_message = "Generating state library %s" % ctx.label, + arguments = args, + executable = ctx.executable._tool, + ) + +""" +Generates save and restore logic from a set of Go files. + + +Args: + name: the name of the rule. + srcs: the input source files. These files should include all structs in the package that need to be saved. + imports: an optional list of extra non-aliased, Go-style absolute import paths. + out: the name of the generated file output. This must not conflict with any other files and must be added to the srcs of the relevant go_library. + package: the package name for the input sources. +""" + +go_stateify = rule( + attrs = { + "srcs": attr.label_list( + mandatory = True, + allow_files = True, + ), + "imports": attr.string_list(mandatory = False), + "package": attr.string(mandatory = True), + "out": attr.output(mandatory = True), + "_tool": attr.label( + executable = True, + cfg = "host", + default = Label("//tools/go_stateify:stateify"), + ), + "_statepkg": attr.string(default = "gvisor.googlesource.com/gvisor/pkg/state"), + }, + implementation = _go_stateify_impl, +) diff --git a/tools/go_stateify/main.go b/tools/go_stateify/main.go new file mode 100644 index 000000000..5eb4fe51f --- /dev/null +++ b/tools/go_stateify/main.go @@ -0,0 +1,386 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Stateify provides a simple way to generate Load/Save methods based on +// existing types and struct tags. +package main + +import ( + "flag" + "fmt" + "go/ast" + "go/parser" + "go/token" + "os" + "reflect" + "strings" +) + +var ( + pkg = flag.String("pkg", "", "output package") + imports = flag.String("imports", "", "extra imports for the output file") + output = flag.String("output", "", "output file") + statePkg = flag.String("statepkg", "", "state import package; defaults to empty") +) + +// resolveTypeName returns a qualified type name. +func resolveTypeName(name string, typ ast.Expr) (field string, qualified string) { + for done := false; !done; { + // Resolve star expressions. + switch rs := typ.(type) { + case *ast.StarExpr: + qualified += "*" + typ = rs.X + case *ast.ArrayType: + if rs.Len == nil { + // Slice type declaration. + qualified += "[]" + } else { + // Array type declaration. + qualified += "[" + rs.Len.(*ast.BasicLit).Value + "]" + } + typ = rs.Elt + default: + // No more descent. + done = true + } + } + + // Resolve a package selector. + sel, ok := typ.(*ast.SelectorExpr) + if ok { + qualified = qualified + sel.X.(*ast.Ident).Name + "." + typ = sel.Sel + } + + // Figure out actual type name. + ident, ok := typ.(*ast.Ident) + if !ok { + panic(fmt.Sprintf("type not supported: %s (involves anonymous types?)", name)) + } + field = ident.Name + qualified = qualified + field + return +} + +// extractStateTag pulls the relevant state tag. +func extractStateTag(tag *ast.BasicLit) string { + if tag == nil { + return "" + } + if len(tag.Value) < 2 { + return "" + } + return reflect.StructTag(tag.Value[1 : len(tag.Value)-1]).Get("state") +} + +// scanFunctions is a set of functions passed to scanFields. +type scanFunctions struct { + zerovalue func(name string) + normal func(name string) + wait func(name string) + value func(name, typName string) +} + +// scanFields scans the fields of a struct. +// +// Each provided function will be applied to appropriately tagged fields, or +// skipped if nil. +// +// Fields tagged nosave are skipped. +func scanFields(ss *ast.StructType, fn scanFunctions) { + if ss.Fields.List == nil { + // No fields. + return + } + + // Scan all fields. + for _, field := range ss.Fields.List { + // Calculate the name. + name := "" + if field.Names != nil { + // It's a named field; override. + name = field.Names[0].Name + } else { + // Anonymous types can't be embedded, so we don't need + // to worry about providing a useful name here. + name, _ = resolveTypeName("", field.Type) + } + + // Skip _ fields. + if name == "_" { + continue + } + + switch tag := extractStateTag(field.Tag); tag { + case "zerovalue": + if fn.zerovalue != nil { + fn.zerovalue(name) + } + + case "": + if fn.normal != nil { + fn.normal(name) + } + + case "wait": + if fn.wait != nil { + fn.wait(name) + } + + case "manual", "nosave", "ignore": + // Do nothing. + + default: + if strings.HasPrefix(tag, ".(") && strings.HasSuffix(tag, ")") { + if fn.value != nil { + fn.value(name, tag[2:len(tag)-1]) + } + } + } + } +} + +func camelCased(name string) string { + return strings.ToUpper(name[:1]) + name[1:] +} + +func main() { + // Parse flags. + flag.Usage = func() { + fmt.Fprintf(os.Stderr, "Usage: %s [options]\n", os.Args[0]) + flag.PrintDefaults() + } + flag.Parse() + if len(flag.Args()) == 0 { + flag.Usage() + os.Exit(1) + } + if *pkg == "" { + fmt.Fprintf(os.Stderr, "Error: package required.") + os.Exit(1) + } + + // Open the output file. + var ( + outputFile *os.File + err error + ) + if *output == "" || *output == "-" { + outputFile = os.Stdout + } else { + outputFile, err = os.OpenFile(*output, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) + if err != nil { + fmt.Fprintf(os.Stderr, "Error opening output %q: %v", *output, err) + } + defer outputFile.Close() + } + + // Set the statePrefix for below, depending on the import. + statePrefix := "" + if *statePkg != "" { + parts := strings.Split(*statePkg, "/") + statePrefix = parts[len(parts)-1] + "." + } + + // initCalls is dumped at the end. + var initCalls []string + + // Declare our emission closures. + emitRegister := func(name string) { + initCalls = append(initCalls, fmt.Sprintf("%sRegister(\"%s.%s\", (*%s)(nil), state.Fns{Save: (*%s).save, Load: (*%s).load})", statePrefix, *pkg, name, name, name, name)) + } + emitZeroCheck := func(name string) { + fmt.Fprintf(outputFile, " if !%sIsZeroValue(x.%s) { m.Failf(\"%s is %%v, expected zero\", x.%s) }\n", statePrefix, name, name, name) + } + emitLoadValue := func(name, typName string) { + fmt.Fprintf(outputFile, " m.LoadValue(\"%s\", new(%s), func(y interface{}) { x.load%s(y.(%s)) })\n", name, typName, camelCased(name), typName) + } + emitLoad := func(name string) { + fmt.Fprintf(outputFile, " m.Load(\"%s\", &x.%s)\n", name, name) + } + emitLoadWait := func(name string) { + fmt.Fprintf(outputFile, " m.LoadWait(\"%s\", &x.%s)\n", name, name) + } + emitSaveValue := func(name, typName string) { + fmt.Fprintf(outputFile, " var %s %s = x.save%s()\n", name, typName, camelCased(name)) + fmt.Fprintf(outputFile, " m.SaveValue(\"%s\", %s)\n", name, name) + } + emitSave := func(name string) { + fmt.Fprintf(outputFile, " m.Save(\"%s\", &x.%s)\n", name, name) + } + + // Emit the package name. + fmt.Fprintf(outputFile, "// automatically generated by stateify.\n\n") + fmt.Fprintf(outputFile, "package %s\n\n", *pkg) + fmt.Fprintln(outputFile, "import (") + if *statePkg != "" { + fmt.Fprintf(outputFile, " \"%s\"\n", *statePkg) + } + if *imports != "" { + for _, i := range strings.Split(*imports, ",") { + fmt.Fprintf(outputFile, " \"%s\"\n", i) + } + } + fmt.Fprintln(outputFile, ")\n") + + files := make([]*ast.File, 0, len(flag.Args())) + + // Parse the input files. + for _, filename := range flag.Args() { + // Parse the file. + fset := token.NewFileSet() + f, err := parser.ParseFile(fset, filename, nil, 0) + if err != nil { + // Not a valid input file? + fmt.Fprintf(os.Stderr, "Input %q can't be parsed: %v\n", filename, err) + os.Exit(1) + } + files = append(files, f) + } + + type method struct { + receiver string + name string + } + + // Search for and add all methods with a pointer receiver and no other + // arguments to a set. We support auto-detecting the existence of + // several different methods with this signature. + simpleMethods := map[method]struct{}{} + for _, f := range files { + + // Go over all functions. + for _, decl := range f.Decls { + d, ok := decl.(*ast.FuncDecl) + if !ok { + continue + } + if d.Name == nil || d.Recv == nil || d.Type == nil { + // Not a named method. + continue + } + if len(d.Recv.List) != 1 { + // Wrong number of receivers? + continue + } + if d.Type.Params != nil && len(d.Type.Params.List) != 0 { + // Has argument(s). + continue + } + if d.Type.Results != nil && len(d.Type.Results.List) != 0 { + // Has return(s). + continue + } + + pt, ok := d.Recv.List[0].Type.(*ast.StarExpr) + if !ok { + // Not a pointer receiver. + continue + } + + t, ok := pt.X.(*ast.Ident) + if !ok { + // This shouldn't happen with valid Go. + continue + } + + simpleMethods[method{t.Name, d.Name.Name}] = struct{}{} + } + } + + for _, f := range files { + // Go over all named types. + for _, decl := range f.Decls { + d, ok := decl.(*ast.GenDecl) + if !ok || d.Tok != token.TYPE { + continue + } + + for _, gs := range d.Specs { + ts := gs.(*ast.TypeSpec) + switch ts.Type.(type) { + case *ast.InterfaceType, *ast.ChanType, *ast.FuncType, *ast.ParenExpr, *ast.StarExpr: + // Don't register. + break + case *ast.StructType: + ss := ts.Type.(*ast.StructType) + + // Define beforeSave if a definition was not found. This + // prevents the code from compiling if a custom beforeSave + // was defined in a file not provided to this binary and + // prevents inherited methods from being called multiple times + // by overriding them. + if _, ok := simpleMethods[method{ts.Name.Name, "beforeSave"}]; !ok { + fmt.Fprintf(outputFile, "func (x *%s) beforeSave() {}\n", ts.Name.Name) + } + + // Generate the save method. + fmt.Fprintf(outputFile, "func (x *%s) save(m %sMap) {\n", ts.Name.Name, statePrefix) + fmt.Fprintf(outputFile, " x.beforeSave()\n") + scanFields(ss, scanFunctions{zerovalue: emitZeroCheck}) + scanFields(ss, scanFunctions{value: emitSaveValue}) + scanFields(ss, scanFunctions{normal: emitSave, wait: emitSave}) + fmt.Fprintf(outputFile, "}\n\n") + + // Define afterLoad if a definition was not found. We do this + // for the same reason that we do it for beforeSave. + _, hasAfterLoad := simpleMethods[method{ts.Name.Name, "afterLoad"}] + if !hasAfterLoad { + fmt.Fprintf(outputFile, "func (x *%s) afterLoad() {}\n", ts.Name.Name) + } + + // Generate the load method. + // + // Note that the manual loads always follow the + // automated loads. + fmt.Fprintf(outputFile, "func (x *%s) load(m %sMap) {\n", ts.Name.Name, statePrefix) + scanFields(ss, scanFunctions{normal: emitLoad, wait: emitLoadWait}) + scanFields(ss, scanFunctions{value: emitLoadValue}) + if hasAfterLoad { + // The call to afterLoad is made conditionally, because when + // AfterLoad is called, the object encodes a dependency on + // referred objects (i.e. fields). This means that afterLoad + // will not be called until the other afterLoads are called. + fmt.Fprintf(outputFile, " m.AfterLoad(x.afterLoad)\n") + } + fmt.Fprintf(outputFile, "}\n\n") + + // Add to our registration. + emitRegister(ts.Name.Name) + case *ast.Ident, *ast.SelectorExpr, *ast.ArrayType: + _, val := resolveTypeName(ts.Name.Name, ts.Type) + + // Dispatch directly. + fmt.Fprintf(outputFile, "func (x *%s) save(m %sMap) {\n", ts.Name.Name, statePrefix) + fmt.Fprintf(outputFile, " m.SaveValue(\"\", (%s)(*x))\n", val) + fmt.Fprintf(outputFile, "}\n\n") + fmt.Fprintf(outputFile, "func (x *%s) load(m %sMap) {\n", ts.Name.Name, statePrefix) + fmt.Fprintf(outputFile, " m.LoadValue(\"\", new(%s), func(y interface{}) { *x = (%s)(y.(%s)) })\n", val, ts.Name.Name, val) + fmt.Fprintf(outputFile, "}\n\n") + + // See above. + emitRegister(ts.Name.Name) + } + } + } + } + + // Emit the init() function. + fmt.Fprintf(outputFile, "func init() {\n") + for _, ic := range initCalls { + fmt.Fprintf(outputFile, " %s\n", ic) + } + fmt.Fprintf(outputFile, "}\n") +} |