Skip to content

Commit

Permalink
gopls/internal/golang: consolidate imports from both file in qualifier
Browse files Browse the repository at this point in the history
This commit improves qualifier by consolidating imports from both the
main file (x.go) and its corresponding test file (x_test.go).

An imports map is used to track all import paths and their local
renames.  Imports from x_test.go are prioritized over x.go as gopls
is generating test in x_test.go.

This ensures that the generated qualifier correctly reflects any
necessary renames, improving accuracy and consistency.

For golang/vscode-go#1594

Change-Id: I457d5f22f7de4fe86006b57487f243494c8e7f6f
Reviewed-on: https://go-review.googlesource.com/c/tools/+/622320
LUCI-TryBot-Result: Go LUCI <[email protected]>
Reviewed-by: Robert Findley <[email protected]>
Auto-Submit: Hongxiang Jiang <[email protected]>
  • Loading branch information
h9jiang authored and gopherbot committed Nov 5, 2024
1 parent 0c792f1 commit 691997a
Show file tree
Hide file tree
Showing 2 changed files with 134 additions and 11 deletions.
60 changes: 49 additions & 11 deletions gopls/internal/golang/addtest.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"html/template"
"os"
"path/filepath"
"strconv"
"strings"

"golang.org/x/tools/go/ast/astutil"
Expand Down Expand Up @@ -115,6 +116,33 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
return nil, fmt.Errorf("package has type errors: %v", errors[0])
}

// imports is a map from package path to local package name.
var imports = make(map[string]string)

var collectImports = func(file *ast.File) error {
for _, spec := range file.Imports {
// TODO(hxjiang): support dot imports.
if spec.Name != nil && spec.Name.Name == "." {
return fmt.Errorf("\"add a test for FUNC\" does not support files containing dot imports")
}
path, err := strconv.Unquote(spec.Path.Value)
if err != nil {
return err
}
if spec.Name != nil && spec.Name.Name != "_" {
imports[path] = spec.Name.Name
} else {
imports[path] = filepath.Base(path)
}
}
return nil
}

// Collect all the imports from the x.go, keep track of the local package name.
if err := collectImports(pgf.File); err != nil {
return nil, err
}

testBase := strings.TrimSuffix(filepath.Base(loc.URI.Path()), ".go") + "_test.go"
goTestFileURI := protocol.URIFromPath(filepath.Join(loc.URI.Dir().Path(), testBase))

Expand Down Expand Up @@ -192,6 +220,26 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
if err != nil {
return nil, err
}

// Collect all the imports from the x_test.go, overwrite the local pakcage
// name collected from x.go.
if err := collectImports(testPGF.File); err != nil {
return nil, err
}
}

// qf qualifier returns the local package name need to use in x_test.go by
// consulting the consolidated imports map.
qf := func(p *types.Package) string {
// When generating test in x packages, any type/function defined in the same
// x package can emit package name.
if !xtest && p == pkg.Types() {
return ""
}
if local, ok := imports[p.Path()]; ok {
return local
}
return p.Name()
}

// TODO(hxjiang): modify existing imports or add new imports.
Expand Down Expand Up @@ -231,16 +279,6 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
// the option to drop the return value if the type is unexported.
}

// TODO(hxjiang): qualifier should consolidate existing imports from x
// package and existing x_test package. The existing x_test package imports
// should overwrite x package imports.
var qf types.Qualifier
if xtest {
qf = (*types.Package).Name
} else {
qf = typesinternal.NameRelativeTo(pkg.Types())
}

testName, err := testName(fn)
if err != nil {
return nil, err
Expand All @@ -251,7 +289,7 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
}

if sig.Recv() == nil && xtest {
data.PackageName = pkg.Types().Name()
data.PackageName = qf(pkg.Types())
}

for i := range sig.Params().Len() {
Expand Down
85 changes: 85 additions & 0 deletions gopls/internal/test/marker/testdata/codeaction/addtest.txt
Original file line number Diff line number Diff line change
Expand Up @@ -388,3 +388,88 @@ func Foo(in, in1, in2, in3 string) (out, out1, out2 string) {return in, in, in}
+ }
+ }
+}
-- xpackagerename/xpackagerename.go --
package main

import (
mytime "time"
myast "go/ast"
)

func Foo(t mytime.Time, a *myast.Node) (mytime.Time, *myast.Node) {return t, a} //@codeactionedit("Foo", "source.addTest", xpackage_rename)

-- @xpackage_rename/xpackagerename/xpackagerename_test.go --
@@ -0,0 +1,26 @@
+package main_test
+
+func TestFoo(t *testing.T) {
+ type args struct {
+ in mytime.Time
+ in2 *myast.Node
+ }
+ tests := []struct {
+ name string // description of this test case
+ args args
+ want mytime.Time
+ want2 *myast.Node
+ }{
+ // TODO: Add test cases.
+ }
+ for _, tt := range tests {
+ got, got2 := main.Foo(tt.args.in, tt.args.in2)
+ // TODO: update the condition below to compare got with tt.want.
+ if true {
+ t.Errorf("%s: Foo() = %v, want %v", tt.name, got, tt.want)
+ }
+ if true {
+ t.Errorf("%s: Foo() = %v, want %v", tt.name, got2, tt.want2)
+ }
+ }
+}
-- xtestpackagerename/xtestpackagerename.go --
package main

import (
mytime "time"
myast "go/ast"
)

func Foo(t mytime.Time, a *myast.Node) (mytime.Time, *myast.Node) {return t, a} //@codeactionedit("Foo", "source.addTest", xtest_package_rename)

-- xtestpackagerename/xtestpackagerename_test.go --
package main_test

import (
yourtime "time"
yourast "go/ast"
)

var fooTime = yourtime.Time{}
var fooNode = yourast.Node{}

-- @xtest_package_rename/xtestpackagerename/xtestpackagerename_test.go --
@@ -11 +11,24 @@
+func TestFoo(t *testing.T) {
+ type args struct {
+ in yourtime.Time
+ in2 *yourast.Node
+ }
+ tests := []struct {
+ name string // description of this test case
+ args args
+ want yourtime.Time
+ want2 *yourast.Node
+ }{
+ // TODO: Add test cases.
+ }
+ for _, tt := range tests {
+ got, got2 := main.Foo(tt.args.in, tt.args.in2)
+ // TODO: update the condition below to compare got with tt.want.
+ if true {
+ t.Errorf("%s: Foo() = %v, want %v", tt.name, got, tt.want)
+ }
+ if true {
+ t.Errorf("%s: Foo() = %v, want %v", tt.name, got2, tt.want2)
+ }
+ }
+}

0 comments on commit 691997a

Please sign in to comment.