Skip to content

Commit

Permalink
gopls/internal/golang: add missing imports in foo_test.go
Browse files Browse the repository at this point in the history
- Gopls will honor any renaming of package "testing" if any.
- Gopls will collect all the package that have not been imported
in foo_test.go and modify the foo_test.go imports.

For golang/vscode-go#1594

Change-Id: Id6b87b6417a26f8e925582317e91fb4ebff4a0e7
Reviewed-on: https://go-review.googlesource.com/c/tools/+/620697
Reviewed-by: Robert Findley <[email protected]>
LUCI-TryBot-Result: Go LUCI <[email protected]>
  • Loading branch information
h9jiang committed Nov 13, 2024
1 parent 87ac91f commit 288b9cb
Show file tree
Hide file tree
Showing 2 changed files with 312 additions and 66 deletions.
181 changes: 151 additions & 30 deletions gopls/internal/golang/addtest.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,23 @@ import (
"html/template"
"os"
"path/filepath"
"sort"
"strconv"
"strings"
"unicode"

"golang.org/x/tools/go/ast/astutil"
"golang.org/x/tools/gopls/internal/cache"
"golang.org/x/tools/gopls/internal/cache/metadata"
"golang.org/x/tools/gopls/internal/cache/parsego"
"golang.org/x/tools/gopls/internal/protocol"
goplsastutil "golang.org/x/tools/gopls/internal/util/astutil"
"golang.org/x/tools/internal/imports"
"golang.org/x/tools/internal/typesinternal"
)

const testTmplString = `func {{.TestFuncName}}(t *testing.T) {
const testTmplString = `
func {{.TestFuncName}}(t *{{.TestingPackageName}}.T) {
{{- /* Constructor input parameters struct declaration. */}}
{{- if and .Receiver .Receiver.Constructor}}
{{- if gt (len .Receiver.Constructor.Args) 1}}
Expand Down Expand Up @@ -83,7 +87,7 @@ const testTmplString = `func {{.TestFuncName}}(t *testing.T) {
{{- /* Loop over all the test cases. */}}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Run(tt.name, func(t *{{.TestingPackageName}}.T) {
{{- /* Constructor or empty initialization. */}}
{{- if .Receiver}}
{{- if .Receiver.Constructor}}
Expand Down Expand Up @@ -170,6 +174,10 @@ type receiver struct {
}

type testInfo struct {
// TestingPackageName is the package name should be used when referencing
// package "testing"
TestingPackageName string
// PackageName is the package name the target function/method is delcared from.
PackageName string
TestFuncName string
// Func holds information about the function or method being tested.
Expand Down Expand Up @@ -202,37 +210,79 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
return nil, err
}

if metadata.IsCommandLineArguments(pkg.Metadata().ID) {
return nil, fmt.Errorf("current file in command-line-arguments package")
}

if errors := pkg.ParseErrors(); len(errors) > 0 {
return nil, fmt.Errorf("package has parse errors: %v", errors[0])
}
if errors := pkg.TypeErrors(); len(errors) > 0 {
return nil, fmt.Errorf("package has type errors: %v", errors[0])
}

// imports is a map from package path to local package name.
var imports = make(map[string]string)
type packageInfo struct {
name string
renamed bool
}

var (
// fileImports is a map contains all the path imported in the original
// file foo.go.
fileImports map[string]packageInfo
// testImports is a map contains all the path already imported in test
// file foo_test.go.
testImports map[string]packageInfo
// extraImportsis a map from package path to local package name that
// need to be imported for the test function.
extraImports = make(map[string]packageInfo)
)

var collectImports = func(file *ast.File) error {
var collectImports = func(file *ast.File) (map[string]packageInfo, error) {
imps := make(map[string]packageInfo)
for _, spec := range file.Imports {
// TODO(hxjiang): support dot imports.
if spec.Name != nil && spec.Name.Name == "." {
return fmt.Errorf("\"add a test for FUNC\" does not support files containing dot imports")
return nil, fmt.Errorf("\"add a test for func\" does not support files containing dot imports")
}
path, err := strconv.Unquote(spec.Path.Value)
if err != nil {
return err
return nil, err
}
if spec.Name != nil && spec.Name.Name != "_" {
imports[path] = spec.Name.Name
if spec.Name != nil {
if spec.Name.Name == "_" {
continue
}
imps[path] = packageInfo{spec.Name.Name, true}
} else {
imports[path] = filepath.Base(path)
// The package name might differ from the base of its import
// path. For example, "/path/to/package/foo" could declare a
// package named "bar". Look up the target package ensures the
// accurate package name reference.
//
// While it's best practice to rename imported packages when
// their name differs from the base path (e.g.,
// "import bar \"path/to/package/foo\""), this is not mandatory.
id := pkg.Metadata().DepsByImpPath[metadata.ImportPath(path)]
if metadata.IsCommandLineArguments(id) {
return nil, fmt.Errorf("can not import command-line-arguments package")
}
if id == "" { // guess upon missing.
imps[path] = packageInfo{imports.ImportPathToAssumedName(path), false}
} else {
fromPkg, ok := snapshot.MetadataGraph().Packages[id]
if !ok {
return nil, fmt.Errorf("package id %v does not exist", id)
}
imps[path] = packageInfo{string(fromPkg.Name), false}
}
}
}
return nil
return imps, nil
}

// Collect all the imports from the x.go, keep track of the local package name.
if err := collectImports(pgf.File); err != nil {
if fileImports, err = collectImports(pgf.File); err != nil {
return nil, err
}

Expand All @@ -259,7 +309,8 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
xtest = true
)

if testPGF, err := snapshot.ParseGo(ctx, testFH, parsego.Header); err != nil {
testPGF, err := snapshot.ParseGo(ctx, testFH, parsego.Header)
if err != nil {
if !errors.Is(err, os.ErrNotExist) {
return nil, err
}
Expand Down Expand Up @@ -288,8 +339,7 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
header.WriteString("\n\n")
}
}
// One empty line between package decl and rest of the file.
fmt.Fprintf(&header, "package %s_test\n\n", pkg.Types().Name())
fmt.Fprintf(&header, "package %s_test\n", pkg.Types().Name())

// Write the copyright and package decl to the beginning of the file.
edits = append(edits, protocol.TextEdit{
Expand All @@ -314,29 +364,41 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
return nil, err
}

// Collect all the imports from the x_test.go, overwrite the local pakcage
// name collected from x.go.
if err := collectImports(testPGF.File); err != nil {
// Collect all the imports from the foo_test.go.
if testImports, err = collectImports(testPGF.File); err != nil {
return nil, err
}
}

// qf qualifier returns the local package name need to use in x_test.go by
// consulting the consolidated imports map.
// qf qualifier determines the correct package name to use for a type in
// foo_test.go. It does this by:
// - Consult imports map from test file foo_test.go.
// - If not found, consult imports map from original file foo.go.
// If the package is not imported in test file foo_test.go, it is added to
// extraImports map.
qf := func(p *types.Package) string {
// When generating test in x packages, any type/function defined in the same
// x package can emit package name.
if !xtest && p == pkg.Types() {
return ""
}
if local, ok := imports[p.Path()]; ok {
return local
// Prefer using the package name if already defined in foo_test.go
if local, ok := testImports[p.Path()]; ok {
return local.name
}
// TODO(hxjiang): we should consult the scope of the test package to
// ensure these new imports do not shadow any package-level names.
// If not already imported by foo_test.go, consult the foo.go import map.
if local, ok := fileImports[p.Path()]; ok {
// The package that contains this type need to be added to the import
// list in foo_test.go.
extraImports[p.Path()] = local
return local.name
}
extraImports[p.Path()] = packageInfo{name: p.Name()}
return p.Name()
}

// TODO(hxjiang): modify existing imports or add new imports.

start, end, err := pgf.RangePos(loc.Range)
if err != nil {
return nil, err
Expand Down Expand Up @@ -378,8 +440,9 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
}

data := testInfo{
PackageName: qf(pkg.Types()),
TestFuncName: testName,
TestingPackageName: qf(types.NewPackage("testing", "testing")),
PackageName: qf(pkg.Types()),
TestFuncName: testName,
Func: function{
Name: fn.Name(),
},
Expand Down Expand Up @@ -557,15 +620,73 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
}
}

// Compute edits to update imports.
//
// If we're adding to an existing test file, we need to adjust existing
// imports. Otherwise, we can simply write out the imports to the new file.
if testPGF != nil {
var importFixes []*imports.ImportFix
for path, info := range extraImports {
name := ""
if info.renamed {
name = info.name
}
importFixes = append(importFixes, &imports.ImportFix{
StmtInfo: imports.ImportInfo{
ImportPath: path,
Name: name,
},
FixType: imports.AddImport,
})
}
importEdits, err := ComputeImportFixEdits(snapshot.Options().Local, testPGF.Src, importFixes...)
if err != nil {
return nil, fmt.Errorf("could not compute the import fix edits: %w", err)
}
edits = append(edits, importEdits...)
} else {
var importsBuffer bytes.Buffer
if len(extraImports) == 1 {
importsBuffer.WriteString("\nimport ")
for path, info := range extraImports {
if info.renamed {
importsBuffer.WriteString(info.name + " ")
}
importsBuffer.WriteString(fmt.Sprintf("\"%s\"\n", path))
}
} else {
importsBuffer.WriteString("\nimport(")
// Loop over the map in sorted order ensures deterministic outcome.
paths := make([]string, 0, len(extraImports))
for key := range extraImports {
paths = append(paths, key)
}
sort.Strings(paths)
for _, path := range paths {
importsBuffer.WriteString("\n\t")
if extraImports[path].renamed {
importsBuffer.WriteString(extraImports[path].name + " ")
}
importsBuffer.WriteString(fmt.Sprintf("\"%s\"", path))
}
importsBuffer.WriteString("\n)\n")
}
edits = append(edits, protocol.TextEdit{
Range: protocol.Range{},
NewText: importsBuffer.String(),
})
}

var test bytes.Buffer
if err := testTmpl.Execute(&test, data); err != nil {
return nil, err
}

edits = append(edits, protocol.TextEdit{
Range: eofRange,
NewText: test.String(),
})
edits = append(edits,
protocol.TextEdit{
Range: eofRange,
NewText: test.String(),
})

return append(changes, protocol.DocumentChangeEdit(testFH, edits)), nil
}
Expand Down
Loading

0 comments on commit 288b9cb

Please sign in to comment.