summaryrefslogtreecommitdiffhomepage
path: root/tools/go_generics/imports.go
diff options
context:
space:
mode:
Diffstat (limited to 'tools/go_generics/imports.go')
-rw-r--r--tools/go_generics/imports.go150
1 files changed, 150 insertions, 0 deletions
diff --git a/tools/go_generics/imports.go b/tools/go_generics/imports.go
new file mode 100644
index 000000000..97267098b
--- /dev/null
+++ b/tools/go_generics/imports.go
@@ -0,0 +1,150 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+import (
+ "bytes"
+ "fmt"
+ "go/ast"
+ "go/format"
+ "go/parser"
+ "go/token"
+ "strconv"
+
+ "gvisor.googlesource.com/gvisor/tools/go_generics/globals"
+)
+
+type importedPackage struct {
+ newName string
+ path string
+}
+
+// updateImportIdent modifies the given import identifier with the new name
+// stored in the used map. If the identifier doesn't exist in the used map yet,
+// a new name is generated and inserted into the map.
+func updateImportIdent(orig string, imports mapValue, id *ast.Ident, used map[string]*importedPackage) error {
+ importName := id.Name
+
+ // If the name is already in the table, just use the new name.
+ m := used[importName]
+ if m != nil {
+ id.Name = m.newName
+ return nil
+ }
+
+ // Create a new entry in the used map.
+ path := imports[importName]
+ if path == "" {
+ return fmt.Errorf("Unknown path to package '%s', used in '%s'", importName, orig)
+ }
+
+ m = &importedPackage{
+ newName: fmt.Sprintf("__generics_imported%d", len(used)),
+ path: strconv.Quote(path),
+ }
+ used[importName] = m
+
+ id.Name = m.newName
+
+ return nil
+}
+
+// convertExpression creates a new string that is a copy of the input one with
+// all imports references renamed to the names in the "used" map. If the
+// referenced import isn't in "used" yet, a new one is created based on the path
+// in "imports" and stored in "used". For example, if string s is
+// "math.MaxUint32-math.MaxUint16+10", it would be converted to
+// "x.MaxUint32-x.MathUint16+10", where x is a generated name.
+func convertExpression(s string, imports mapValue, used map[string]*importedPackage) (string, error) {
+ // Parse the expression in the input string.
+ expr, err := parser.ParseExpr(s)
+ if err != nil {
+ return "", fmt.Errorf("Unable to parse \"%s\": %v", s, err)
+ }
+
+ // Go through the AST and update references.
+ var retErr error
+ ast.Inspect(expr, func(n ast.Node) bool {
+ switch x := n.(type) {
+ case *ast.SelectorExpr:
+ if id := globals.GetIdent(x.X); id != nil {
+ if err := updateImportIdent(s, imports, id, used); err != nil {
+ retErr = err
+ }
+ return false
+ }
+ }
+ return true
+ })
+ if retErr != nil {
+ return "", retErr
+ }
+
+ // Convert the modified AST back to a string.
+ fset := token.NewFileSet()
+ var buf bytes.Buffer
+ if err := format.Node(&buf, fset, expr); err != nil {
+ return "", err
+ }
+
+ return string(buf.Bytes()), nil
+}
+
+// updateImports replaces all maps in the input slice with copies where the
+// mapped values have had all references to imported packages renamed to
+// generated names. It also returns an import declaration for all the renamed
+// import packages.
+//
+// For example, if the input maps contains A=math.B and C=math.D, the updated
+// maps will instead contain A=__generics_imported0.B and
+// C=__generics_imported0.C, and the 'import __generics_imported0 "math"' would
+// be returned as the import declaration.
+func updateImports(maps []mapValue, imports mapValue) (ast.Decl, error) {
+ importsUsed := make(map[string]*importedPackage)
+
+ // Update all maps.
+ for i, m := range maps {
+ newMap := make(mapValue)
+ for n, e := range m {
+ updated, err := convertExpression(e, imports, importsUsed)
+ if err != nil {
+ return nil, err
+ }
+
+ newMap[n] = updated
+ }
+ maps[i] = newMap
+ }
+
+ // Nothing else to do if no imports are used in the expressions.
+ if len(importsUsed) == 0 {
+ return nil, nil
+ }
+
+ // Create spec array for each new import.
+ specs := make([]ast.Spec, 0, len(importsUsed))
+ for _, i := range importsUsed {
+ specs = append(specs, &ast.ImportSpec{
+ Name: &ast.Ident{Name: i.newName},
+ Path: &ast.BasicLit{Value: i.path},
+ })
+ }
+
+ return &ast.GenDecl{
+ Tok: token.IMPORT,
+ Specs: specs,
+ Lparen: token.NoPos + 1,
+ }, nil
+}