Skip to content

Commit

Permalink
gopls/internal/golang: generate test name for selected function/method
Browse files Browse the repository at this point in the history
This CL interpret the user's intention by reading the x_test.go's
package name.

If the x_test.go exist and is using package name x, we allow adding
test for unexported function/method.
If the x_test.go does not exist or it exist but using package name
x_test, we will only allow adding test for exported function/methods.

For golang/vscode-go#1594

Change-Id: I20f6c41dc4c53bb816b40982a0ebbbcb1e3a92e7
Reviewed-on: https://go-review.googlesource.com/c/tools/+/621675
Reviewed-by: Robert Findley <[email protected]>
LUCI-TryBot-Result: Go LUCI <[email protected]>
  • Loading branch information
h9jiang committed Nov 4, 2024
1 parent 2998e9a commit e36459f
Show file tree
Hide file tree
Showing 2 changed files with 264 additions and 42 deletions.
156 changes: 124 additions & 32 deletions gopls/internal/golang/addtest.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,19 @@ import (
"context"
"errors"
"fmt"
"go/ast"
"go/token"
"go/types"
"os"
"path/filepath"
"strings"

"golang.org/x/tools/go/ast/astutil"
"golang.org/x/tools/gopls/internal/cache"
"golang.org/x/tools/gopls/internal/cache/parsego"
"golang.org/x/tools/gopls/internal/protocol"
goplsastutil "golang.org/x/tools/gopls/internal/util/astutil"
"golang.org/x/tools/internal/typesinternal"
)

// AddTestForFunc adds a test for the function enclosing the given input range.
Expand All @@ -29,6 +34,13 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
return nil, err
}

if errors := pkg.ParseErrors(); len(errors) > 0 {
return nil, fmt.Errorf("package has parse errors: %v", errors[0])
}
if errors := pkg.TypeErrors(); len(errors) > 0 {
return nil, fmt.Errorf("package has type errors: %v", errors[0])
}

testBase := strings.TrimSuffix(filepath.Base(loc.URI.Path()), ".go") + "_test.go"
goTestFileURI := protocol.URIFromPath(filepath.Join(loc.URI.Dir().Path(), testBase))

Expand All @@ -41,32 +53,37 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
// exist.

var (
eofRange protocol.Range // empty selection at end of new file
// edits contains all the text edits to be applied to the test file.
edits []protocol.TextEdit
// header is the buffer containing the text edit to the beginning of the file.
header bytes.Buffer
// xtest indicates whether the test file use package x or x_test.
// TODO(hxjiang): For now, we try to interpret the user's intention by
// reading the foo_test.go's package name. Instead, we can discuss the option
// to interpret the user's intention by which function they are selecting.
// Have one file for x_test package testing, one file for x package testing.
xtest = true
)

testPgf, err := snapshot.ParseGo(ctx, testFH, parsego.Header)
if err != nil {
if testPGF, err := snapshot.ParseGo(ctx, testFH, parsego.Header); err != nil {
if !errors.Is(err, os.ErrNotExist) {
return nil, err
}

changes = append(changes, protocol.DocumentChangeCreate(goTestFileURI))

// If this test file was created by the gopls, add a copyright header based
// on the originating file.
// header is the buffer containing the text to add to the beginning of the file.
var header bytes.Buffer

// If this test file was created by the gopls, add a copyright header and
// package decl based on the originating file.
// Search for something that looks like a copyright header, to replicate
// in the new file.
// TODO(hxjiang): should we refine this heuristic, for example by checking for
// the word 'copyright'?
if groups := pgf.File.Comments; len(groups) > 0 {
// Copyright should appear before package decl and must be the first
// comment group.
// Avoid copying any other comment like package doc or directive comment.
if c := groups[0]; c.Pos() < pgf.File.Package && c != pgf.File.Doc &&
!isDirective(groups[0].List[0].Text) {
!isDirective(c.List[0].Text) &&
strings.Contains(strings.ToLower(c.List[0].Text), "copyright") {
start, end, err := pgf.NodeOffsets(c)
if err != nil {
return nil, err
Expand All @@ -76,42 +93,117 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
header.WriteString("\n\n")
}
}
}

// If the test file does not have package decl, use the originating file to
// determine a package decl for the new file. Prefer xtest package.s
if testPgf == nil || testPgf.File == nil || testPgf.File.Package == token.NoPos {
// One empty line between package decl and rest of the file.
fmt.Fprintf(&header, "package %s_test\n\n", pkg.Types().Name())
}

// Write the copyright and package decl to the beginning of the file.
if text := header.String(); len(text) != 0 {
// Write the copyright and package decl to the beginning of the file.
edits = append(edits, protocol.TextEdit{
Range: protocol.Range{},
NewText: text,
NewText: header.String(),
})
} else { // existing _test.go file.
if testPGF.File.Name == nil || testPGF.File.Name.NamePos == token.NoPos {
return nil, fmt.Errorf("missing package declaration")
}
switch testPGF.File.Name.Name {
case pgf.File.Name.Name:
xtest = false
case pgf.File.Name.Name + "_test":
xtest = true
default:
return nil, fmt.Errorf("invalid package declaration %q in test file %q", testPGF.File.Name, testPGF)
}

eofRange, err = testPGF.PosRange(testPGF.File.FileEnd, testPGF.File.FileEnd)
if err != nil {
return nil, err
}
}

// TODO(hxjiang): reject if the function/method is unexported.
// TODO(hxjiang): modify existing imports or add new imports.

// If the parse go file is missing, the fileEnd is the file start (zero value).
fileEnd := protocol.Range{}
if testPgf != nil {
fileEnd, err = testPgf.PosRange(testPgf.File.FileEnd, testPgf.File.FileEnd)
if err != nil {
return nil, err
start, end, err := pgf.RangePos(loc.Range)
if err != nil {
return nil, err
}

path, _ := astutil.PathEnclosingInterval(pgf.File, start, end)
if len(path) < 2 {
return nil, fmt.Errorf("no enclosing function")
}

decl, ok := path[len(path)-2].(*ast.FuncDecl)
if !ok {
return nil, fmt.Errorf("no enclosing function")
}

fn := pkg.TypesInfo().Defs[decl.Name].(*types.Func)

if xtest {
// Reject if function/method is unexported.
if !fn.Exported() {
return nil, fmt.Errorf("cannot add test of unexported function %s to external test package %s_test", decl.Name, pgf.File.Name)
}

// Reject if receiver is unexported.
if fn.Signature().Recv() != nil {
if _, ident, _ := goplsastutil.UnpackRecv(decl.Recv.List[0].Type); !ident.IsExported() {
return nil, fmt.Errorf("cannot add external test for method %s.%s as receiver type is not exported", ident.Name, decl.Name)
}
}

// TODO(hxjiang): reject if the any input parameter type is unexported.
// TODO(hxjiang): reject if any return value type is unexported. Explore
// the option to drop the return value if the type is unexported.
}

// test is the buffer containing the text edit to the test function.
var test bytes.Buffer
// TODO(hxjiang): replace test foo function with table-driven test.
test.WriteString("\nfunc TestFoo(*testing.T) {}")
testName, err := testName(fn)
if err != nil {
return nil, err
}
// TODO(hxjiang): replace test function with table-driven test.
edits = append(edits, protocol.TextEdit{
Range: fileEnd,
NewText: test.String(),
Range: eofRange,
NewText: fmt.Sprintf(`
func %s(*testing.T) {
// TODO: implement test
}
`, testName),
})
return append(changes, protocol.DocumentChangeEdit(testFH, edits)), nil
}

// testName returns the name of the function to use for the new function that
// tests fn.
// Returns empty string if the fn is ill typed or nil.
func testName(fn *types.Func) (string, error) {
if fn == nil {
return "", fmt.Errorf("input nil function")
}
testName := "Test"
if recv := fn.Signature().Recv(); recv != nil { // method declaration.
// Retrieve the unpointered receiver type to ensure the test name is based
// on the topmost alias or named type, not the alias' RHS type (potentially
// unexported) type.
// For example:
// type Foo = foo // Foo is an exported alias for the unexported type foo
recvType := recv.Type()
if ptr, ok := recv.Type().(*types.Pointer); ok {
recvType = ptr.Elem()
}

t, ok := recvType.(typesinternal.NamedOrAlias)
if !ok {
return "", fmt.Errorf("receiver type is not named type or alias type")
}

if !t.Obj().Exported() {
testName += "_"
}

testName += t.Obj().Name() + "_"
} else if !fn.Exported() { // unexported function declaration.
testName += "_"
}
return testName + fn.Name(), nil
}
Loading

0 comments on commit e36459f

Please sign in to comment.