// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package main

import (
	"bytes"
	"fmt"
	"go/ast"
	"go/format"
	"go/parser"
	"go/token"
	"sort"
	"strconv"

	"gvisor.dev/gvisor/tools/go_generics/globals"
)

// importedPackage records how one referenced package is renamed: the generated
// identifier that replaces the original package name, and the import path that
// the generated identifier must be bound to.
type importedPackage struct {
	newName string // generated identifier substituted for the original package name
	path    string // import path, stored already quoted (see strconv.Quote in updateImportIdent)
}

// updateImportIdent rewrites id in place so that it refers to the generated
// name associated with the package it currently names.
//
// The association is looked up in used, keyed by the identifier's original
// name. When no entry exists yet, the package path is resolved through imports,
// a fresh generated name (numbered by the current size of used) is allocated,
// and the new entry is stored in used before id is renamed.
//
// An error is returned, and id is left untouched, when imports has no
// (non-empty) path for the package; orig is only used to give that error some
// context about the expression being converted.
func updateImportIdent(orig string, imports mapValue, id *ast.Ident, used map[string]*importedPackage) error {
	name := id.Name

	pkg := used[name]
	if pkg == nil {
		// First time this package is seen: resolve its path and allocate a
		// generated name for it.
		pkgPath := imports[name]
		if pkgPath == "" {
			return fmt.Errorf("unknown path to package '%s', used in '%s'", name, orig)
		}

		pkg = &importedPackage{
			newName: fmt.Sprintf("__generics_imported%d", len(used)),
			path:    strconv.Quote(pkgPath),
		}
		used[name] = pkg
	}

	// Known (or newly registered) package: point the identifier at its
	// generated name.
	id.Name = pkg.newName
	return nil
}

// convertExpression creates a new string that is a copy of the input one with
// all import references renamed to the names in the "used" map. If the
// referenced import isn't in "used" yet, a new one is created based on the path
// in "imports" and stored in "used". For example, if string s is
// "math.MaxUint32-math.MaxUint16+10", it would be converted to
// "x.MaxUint32-x.MaxUint16+10", where x is a generated name.
//
// An error is returned if s cannot be parsed as a Go expression, if a
// referenced package cannot be resolved by updateImportIdent, or if the
// rewritten expression cannot be formatted. If several references fail to
// resolve, the error for the first one encountered is returned.
func convertExpression(s string, imports mapValue, used map[string]*importedPackage) (string, error) {
	// Parse the expression in the input string.
	expr, err := parser.ParseExpr(s)
	if err != nil {
		return "", fmt.Errorf("unable to parse \"%s\": %v", s, err)
	}

	// Go through the AST and update references.
	var retErr error
	ast.Inspect(expr, func(n ast.Node) bool {
		if retErr != nil {
			// A reference already failed to resolve: keep that first error
			// and don't rename anything else.
			return false
		}

		sel, ok := n.(*ast.SelectorExpr)
		if !ok {
			return true
		}

		// Only selectors for which globals.GetIdent yields an identifier are
		// treated as package references; anything else (e.g. a nested
		// selector) is descended into instead.
		id := globals.GetIdent(sel.X)
		if id == nil {
			return true
		}

		if err := updateImportIdent(s, imports, id, used); err != nil {
			retErr = err
		}
		// The operand has been handled; there is nothing left to visit below
		// this selector.
		return false
	})
	if retErr != nil {
		return "", retErr
	}

	// Convert the modified AST back to a string.
	fset := token.NewFileSet()
	var buf bytes.Buffer
	if err := format.Node(&buf, fset, expr); err != nil {
		return "", err
	}

	return buf.String(), nil
}

// updateImports replaces all maps in the input slice with copies where the
// mapped values have had all references to imported packages renamed to
// generated names. It also returns an import declaration for all the renamed
// import packages.
//
// For example, if the input maps contains A=math.B and C=math.D, the updated
// maps will instead contain A=__generics_imported0.B and
// C=__generics_imported0.D, and the 'import __generics_imported0 "math"' would
// be returned as the import declaration.
//
// The returned declaration is nil when none of the expressions reference an
// imported package. An error from converting any expression aborts the
// update and is returned as is; maps already processed at that point have
// been replaced in the input slice.
//
// Generated names are numbered in the order packages are first encountered,
// so the entries of every map are visited in sorted key order to make the
// numbering (and therefore the output) reproducible across runs.
func updateImports(maps []mapValue, imports mapValue) (ast.Decl, error) {
	importsUsed := make(map[string]*importedPackage)

	// Update all maps.
	for i, m := range maps {
		// Go randomizes map iteration order; walk the keys sorted so that the
		// generated import names don't depend on it.
		keys := make([]string, 0, len(m))
		for k := range m {
			keys = append(keys, k)
		}
		sort.Strings(keys)

		newMap := make(mapValue, len(m))
		for _, k := range keys {
			updated, err := convertExpression(m[k], imports, importsUsed)
			if err != nil {
				return nil, err
			}

			newMap[k] = updated
		}
		maps[i] = newMap
	}

	// Nothing else to do if no imports are used in the expressions.
	if len(importsUsed) == 0 {
		return nil, nil
	}

	// Sort the new imports for deterministic build outputs.
	names := make([]string, 0, len(importsUsed))
	for n := range importsUsed {
		names = append(names, n)
	}
	sort.Strings(names)

	// Create spec array for each new import.
	specs := make([]ast.Spec, 0, len(names))
	for _, n := range names {
		pkg := importsUsed[n]
		specs = append(specs, &ast.ImportSpec{
			Name: &ast.Ident{Name: pkg.newName},
			// pkg.path is already a quoted string literal.
			Path: &ast.BasicLit{Kind: token.STRING, Value: pkg.path},
		})
	}

	return &ast.GenDecl{
		Tok:   token.IMPORT,
		Specs: specs,
		// A valid (non-NoPos) Lparen marks the declaration as parenthesized,
		// i.e. 'import ( ... )'.
		Lparen: token.NoPos + 1,
	}, nil
}