From 386503de7a57cc643eab3dcd04cb78e40b669061 Mon Sep 17 00:00:00 2001 From: Hongxiang Jiang Date: Wed, 16 Oct 2024 14:44:19 -0400 Subject: [PATCH] gopls/internal/golang: add source code action for add test This CL is some glue code which build the connection between the LSP "code action request" with second call which compute the actual DocumentChange. AddTest source code action will create a test file if not already exist and insert a random function at the end of the test file. For testing, an internal boolean option "addTestSourceCodeAction" is created and only effective if set explicitly in marker test. For golang/vscode-go#1594 Change-Id: Ie3d9279ea2858805254181608a0d5103afd3a4c6 Reviewed-on: https://go-review.googlesource.com/c/tools/+/621056 Reviewed-by: Robert Findley Reviewed-by: Alan Donovan LUCI-TryBot-Result: Go LUCI --- gopls/internal/cache/parsego/file.go | 5 + gopls/internal/golang/addtest.go | 117 ++++++++++++++++++ gopls/internal/golang/codeaction.go | 36 ++++++ gopls/internal/golang/extracttofile.go | 6 +- .../internal/protocol/command/command_gen.go | 16 +++ gopls/internal/protocol/command/interface.go | 3 + gopls/internal/server/command.go | 72 ++++------- gopls/internal/settings/codeactionkind.go | 1 + gopls/internal/settings/default.go | 1 + gopls/internal/settings/settings.go | 7 ++ .../marker/testdata/codeaction/addtest.txt | 53 ++++++++ 11 files changed, 266 insertions(+), 51 deletions(-) create mode 100644 gopls/internal/golang/addtest.go create mode 100644 gopls/internal/test/marker/testdata/codeaction/addtest.txt diff --git a/gopls/internal/cache/parsego/file.go b/gopls/internal/cache/parsego/file.go index 687c8e39392..e7bf544a6ad 100644 --- a/gopls/internal/cache/parsego/file.go +++ b/gopls/internal/cache/parsego/file.go @@ -89,6 +89,11 @@ func (pgf *File) NodeRange(node ast.Node) (protocol.Range, error) { return pgf.Mapper.NodeRange(pgf.Tok, node) } +// NodeOffsets returns offsets for the ast.Node. +func (pgf *File) NodeOffsets(node ast.Node) (start int, end int, _ error) { + return safetoken.Offsets(pgf.Tok, node.Pos(), node.End()) +} + // NodeMappedRange returns a MappedRange for the ast.Node interval in this file. // A MappedRange can be converted to any other form. func (pgf *File) NodeMappedRange(node ast.Node) (protocol.MappedRange, error) { diff --git a/gopls/internal/golang/addtest.go b/gopls/internal/golang/addtest.go new file mode 100644 index 00000000000..725e705e923 --- /dev/null +++ b/gopls/internal/golang/addtest.go @@ -0,0 +1,117 @@ +// Copyright 2019 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package golang + +// This file defines the behavior of the "Add test for FUNC" command. + +import ( + "bytes" + "context" + "errors" + "fmt" + "go/token" + "os" + "path/filepath" + "strings" + + "golang.org/x/tools/gopls/internal/cache" + "golang.org/x/tools/gopls/internal/cache/parsego" + "golang.org/x/tools/gopls/internal/protocol" +) + +// AddTestForFunc adds a test for the function enclosing the given input range. +// It creates a _test.go file if one does not already exist. +func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.Location) (changes []protocol.DocumentChange, _ error) { + pkg, pgf, err := NarrowestPackageForFile(ctx, snapshot, loc.URI) + if err != nil { + return nil, err + } + + testBase := strings.TrimSuffix(filepath.Base(loc.URI.Path()), ".go") + "_test.go" + goTestFileURI := protocol.URIFromPath(filepath.Join(loc.URI.Dir().Path(), testBase)) + + testFH, err := snapshot.ReadFile(ctx, goTestFileURI) + if err != nil { + return nil, err + } + + // TODO(hxjiang): use a fresh name if the same test function name already + // exist. + + var ( + // edits contains all the text edits to be applied to the test file. + edits []protocol.TextEdit + // header is the buffer containing the text edit to the beginning of the file. + header bytes.Buffer + ) + + testPgf, err := snapshot.ParseGo(ctx, testFH, parsego.Header) + if err != nil { + if !errors.Is(err, os.ErrNotExist) { + return nil, err + } + + changes = append(changes, protocol.DocumentChangeCreate(goTestFileURI)) + + // If this test file was created by the gopls, add a copyright header based + // on the originating file. + // Search for something that looks like a copyright header, to replicate + // in the new file. + // TODO(hxjiang): should we refine this heuristic, for example by checking for + // the word 'copyright'? + if groups := pgf.File.Comments; len(groups) > 0 { + // Copyright should appear before package decl and must be the first + // comment group. + // Avoid copying any other comment like package doc or directive comment. + if c := groups[0]; c.Pos() < pgf.File.Package && c != pgf.File.Doc && + !isDirective(groups[0].List[0].Text) { + start, end, err := pgf.NodeOffsets(c) + if err != nil { + return nil, err + } + header.Write(pgf.Src[start:end]) + // One empty line between copyright header and package decl. + header.WriteString("\n\n") + } + } + } + + // If the test file does not have package decl, use the originating file to + // determine a package decl for the new file. Prefer xtest package.s + if testPgf == nil || testPgf.File == nil || testPgf.File.Package == token.NoPos { + // One empty line between package decl and rest of the file. + fmt.Fprintf(&header, "package %s_test\n\n", pkg.Types().Name()) + } + + // Write the copyright and package decl to the beginning of the file. + if text := header.String(); len(text) != 0 { + edits = append(edits, protocol.TextEdit{ + Range: protocol.Range{}, + NewText: text, + }) + } + + // TODO(hxjiang): reject if the function/method is unexported. + // TODO(hxjiang): modify existing imports or add new imports. + + // If the parse go file is missing, the fileEnd is the file start (zero value). + fileEnd := protocol.Range{} + if testPgf != nil { + fileEnd, err = testPgf.PosRange(testPgf.File.FileEnd, testPgf.File.FileEnd) + if err != nil { + return nil, err + } + } + + // test is the buffer containing the text edit to the test function. + var test bytes.Buffer + // TODO(hxjiang): replace test foo function with table-driven test. + test.WriteString("\nfunc TestFoo(*testing.T) {}") + edits = append(edits, protocol.TextEdit{ + Range: fileEnd, + NewText: test.String(), + }) + return append(changes, protocol.DocumentChangeEdit(testFH, edits)), nil +} diff --git a/gopls/internal/golang/codeaction.go b/gopls/internal/golang/codeaction.go index 844f5dfe46c..3e4f3113f9e 100644 --- a/gopls/internal/golang/codeaction.go +++ b/gopls/internal/golang/codeaction.go @@ -227,6 +227,7 @@ type codeActionProducer struct { var codeActionProducers = [...]codeActionProducer{ {kind: protocol.QuickFix, fn: quickFix, needPkg: true}, {kind: protocol.SourceOrganizeImports, fn: sourceOrganizeImports}, + {kind: settings.AddTest, fn: addTest, needPkg: true}, {kind: settings.GoAssembly, fn: goAssembly, needPkg: true}, {kind: settings.GoDoc, fn: goDoc, needPkg: true}, {kind: settings.GoFreeSymbols, fn: goFreeSymbols}, @@ -467,6 +468,41 @@ func refactorExtractToNewFile(ctx context.Context, req *codeActionsRequest) erro return nil } +// addTest produces "Add a test for FUNC" code actions. +// See [server.commandHandler.AddTest] for command implementation. +func addTest(ctx context.Context, req *codeActionsRequest) error { + // Reject if the feature is turned off. + if !req.snapshot.Options().AddTestSourceCodeAction { + return nil + } + + // Reject test package. + if req.pkg.Metadata().ForTest != "" { + return nil + } + + path, _ := astutil.PathEnclosingInterval(req.pgf.File, req.start, req.end) + if len(path) < 2 { + return nil + } + + decl, ok := path[len(path)-2].(*ast.FuncDecl) + if !ok { + return nil + } + + // Don't offer to create tests of "init" or "_". + if decl.Name.Name == "_" || decl.Name.Name == "init" { + return nil + } + + cmd := command.NewAddTestCommand("Add a test for "+decl.Name.String(), req.loc) + req.addCommandAction(cmd, true) + + // TODO(hxjiang): add code action for generate test for package/file. + return nil +} + // refactorRewriteRemoveUnusedParam produces "Remove unused parameter" code actions. // See [server.commandHandler.ChangeSignature] for command implementation. func refactorRewriteRemoveUnusedParam(ctx context.Context, req *codeActionsRequest) error { diff --git a/gopls/internal/golang/extracttofile.go b/gopls/internal/golang/extracttofile.go index 0a1d74408d7..ae26738a5c3 100644 --- a/gopls/internal/golang/extracttofile.go +++ b/gopls/internal/golang/extracttofile.go @@ -80,7 +80,7 @@ func findImportEdits(file *ast.File, info *types.Info, start, end token.Pos) (ad } // ExtractToNewFile moves selected declarations into a new file. -func ExtractToNewFile(ctx context.Context, snapshot *cache.Snapshot, fh file.Handle, rng protocol.Range) (*protocol.WorkspaceEdit, error) { +func ExtractToNewFile(ctx context.Context, snapshot *cache.Snapshot, fh file.Handle, rng protocol.Range) ([]protocol.DocumentChange, error) { errorPrefix := "ExtractToNewFile" pkg, pgf, err := NarrowestPackageForFile(ctx, snapshot, fh.URI()) @@ -160,7 +160,7 @@ func ExtractToNewFile(ctx context.Context, snapshot *cache.Snapshot, fh file.Han return nil, err } - return protocol.NewWorkspaceEdit( + return []protocol.DocumentChange{ // edit the original file protocol.DocumentChangeEdit(fh, append(importDeletes, protocol.TextEdit{Range: replaceRange, NewText: ""})), // create a new file @@ -168,7 +168,7 @@ func ExtractToNewFile(ctx context.Context, snapshot *cache.Snapshot, fh file.Han // edit the created file protocol.DocumentChangeEdit(newFile, []protocol.TextEdit{ {Range: protocol.Range{}, NewText: string(newFileContent)}, - })), nil + })}, nil } // chooseNewFile chooses a new filename in dir, based on the name of the diff --git a/gopls/internal/protocol/command/command_gen.go b/gopls/internal/protocol/command/command_gen.go index 10c6c043a09..829a3824bc0 100644 --- a/gopls/internal/protocol/command/command_gen.go +++ b/gopls/internal/protocol/command/command_gen.go @@ -27,6 +27,7 @@ const ( AddDependency Command = "gopls.add_dependency" AddImport Command = "gopls.add_import" AddTelemetryCounters Command = "gopls.add_telemetry_counters" + AddTest Command = "gopls.add_test" ApplyFix Command = "gopls.apply_fix" Assembly Command = "gopls.assembly" ChangeSignature Command = "gopls.change_signature" @@ -71,6 +72,7 @@ var Commands = []Command{ AddDependency, AddImport, AddTelemetryCounters, + AddTest, ApplyFix, Assembly, ChangeSignature, @@ -131,6 +133,12 @@ func Dispatch(ctx context.Context, params *protocol.ExecuteCommandParams, s Inte return nil, err } return nil, s.AddTelemetryCounters(ctx, a0) + case AddTest: + var a0 protocol.Location + if err := UnmarshalArgs(params.Arguments, &a0); err != nil { + return nil, err + } + return s.AddTest(ctx, a0) case ApplyFix: var a0 ApplyFixArgs if err := UnmarshalArgs(params.Arguments, &a0); err != nil { @@ -372,6 +380,14 @@ func NewAddTelemetryCountersCommand(title string, a0 AddTelemetryCountersArgs) * } } +func NewAddTestCommand(title string, a0 protocol.Location) *protocol.Command { + return &protocol.Command{ + Title: title, + Command: AddTest.String(), + Arguments: MustMarshalArgs(a0), + } +} + func NewApplyFixCommand(title string, a0 ApplyFixArgs) *protocol.Command { return &protocol.Command{ Title: title, diff --git a/gopls/internal/protocol/command/interface.go b/gopls/internal/protocol/command/interface.go index eda608a8425..258e1008395 100644 --- a/gopls/internal/protocol/command/interface.go +++ b/gopls/internal/protocol/command/interface.go @@ -224,6 +224,9 @@ type Interface interface { // to avoid conflicts with other counters gopls collects. AddTelemetryCounters(context.Context, AddTelemetryCountersArgs) error + // AddTest: add a test for the selected function + AddTest(context.Context, protocol.Location) (*protocol.WorkspaceEdit, error) + // MaybePromptForTelemetry: Prompt user to enable telemetry // // Checks for the right conditions, and then prompts the user diff --git a/gopls/internal/server/command.go b/gopls/internal/server/command.go index bfc8f9c5565..403eadf0d2c 100644 --- a/gopls/internal/server/command.go +++ b/gopls/internal/server/command.go @@ -275,6 +275,24 @@ func (*commandHandler) AddTelemetryCounters(_ context.Context, args command.AddT return nil } +func (c *commandHandler) AddTest(ctx context.Context, loc protocol.Location) (*protocol.WorkspaceEdit, error) { + var result *protocol.WorkspaceEdit + err := c.run(ctx, commandConfig{ + forURI: loc.URI, + }, func(ctx context.Context, deps commandDeps) error { + if deps.snapshot.FileKind(deps.fh) != file.Go { + return fmt.Errorf("can't add test for non-Go file") + } + docedits, err := golang.AddTestForFunc(ctx, deps.snapshot, loc) + if err != nil { + return err + } + return applyChanges(ctx, c.s.client, docedits) + }) + // TODO(hxjiang): move the cursor to the new test once edits applied. + return result, err +} + // commandConfig configures common command set-up and execution. type commandConfig struct { requireSave bool // whether all files must be saved for the command to work @@ -388,16 +406,7 @@ func (c *commandHandler) ApplyFix(ctx context.Context, args command.ApplyFixArgs result = wsedit return nil } - resp, err := c.s.client.ApplyEdit(ctx, &protocol.ApplyWorkspaceEditParams{ - Edit: *wsedit, - }) - if err != nil { - return err - } - if !resp.Applied { - return errors.New(resp.FailureReason) - } - return nil + return applyChanges(ctx, c.s.client, changes) }) return result, err } @@ -622,17 +631,7 @@ func (c *commandHandler) RemoveDependency(ctx context.Context, args command.Remo if err != nil { return err } - response, err := c.s.client.ApplyEdit(ctx, &protocol.ApplyWorkspaceEditParams{ - Edit: *protocol.NewWorkspaceEdit( - protocol.DocumentChangeEdit(deps.fh, edits)), - }) - if err != nil { - return err - } - if !response.Applied { - return fmt.Errorf("edits not applied because of %s", response.FailureReason) - } - return nil + return applyChanges(ctx, c.s.client, []protocol.DocumentChange{protocol.DocumentChangeEdit(deps.fh, edits)}) }) } @@ -1107,17 +1106,7 @@ func (c *commandHandler) AddImport(ctx context.Context, args command.AddImportAr if err != nil { return fmt.Errorf("could not add import: %v", err) } - r, err := c.s.client.ApplyEdit(ctx, &protocol.ApplyWorkspaceEditParams{ - Edit: *protocol.NewWorkspaceEdit( - protocol.DocumentChangeEdit(deps.fh, edits)), - }) - if err != nil { - return fmt.Errorf("could not apply import edits: %v", err) - } - if !r.Applied { - return fmt.Errorf("failed to apply edits: %v", r.FailureReason) - } - return nil + return applyChanges(ctx, c.s.client, []protocol.DocumentChange{protocol.DocumentChangeEdit(deps.fh, edits)}) }) } @@ -1126,18 +1115,11 @@ func (c *commandHandler) ExtractToNewFile(ctx context.Context, args protocol.Loc progress: "Extract to a new file", forURI: args.URI, }, func(ctx context.Context, deps commandDeps) error { - edit, err := golang.ExtractToNewFile(ctx, deps.snapshot, deps.fh, args.Range) + changes, err := golang.ExtractToNewFile(ctx, deps.snapshot, deps.fh, args.Range) if err != nil { return err } - resp, err := c.s.client.ApplyEdit(ctx, &protocol.ApplyWorkspaceEditParams{Edit: *edit}) - if err != nil { - return fmt.Errorf("could not apply edits: %v", err) - } - if !resp.Applied { - return fmt.Errorf("edits not applied: %s", resp.FailureReason) - } - return nil + return applyChanges(ctx, c.s.client, changes) }) } @@ -1543,13 +1525,7 @@ func (c *commandHandler) ChangeSignature(ctx context.Context, args command.Chang result = wsedit return nil } - r, err := c.s.client.ApplyEdit(ctx, &protocol.ApplyWorkspaceEditParams{ - Edit: *wsedit, - }) - if !r.Applied { - return fmt.Errorf("failed to apply edits: %v", r.FailureReason) - } - return nil + return applyChanges(ctx, c.s.client, docedits) }) return result, err } diff --git a/gopls/internal/settings/codeactionkind.go b/gopls/internal/settings/codeactionkind.go index 177431b5f06..16a2eecb2cb 100644 --- a/gopls/internal/settings/codeactionkind.go +++ b/gopls/internal/settings/codeactionkind.go @@ -79,6 +79,7 @@ const ( GoDoc protocol.CodeActionKind = "source.doc" GoFreeSymbols protocol.CodeActionKind = "source.freesymbols" GoTest protocol.CodeActionKind = "source.test" + AddTest protocol.CodeActionKind = "source.addTest" // gopls GoplsDocFeatures protocol.CodeActionKind = "gopls.doc.features" diff --git a/gopls/internal/settings/default.go b/gopls/internal/settings/default.go index 25f3eae80f5..2f637f3d16d 100644 --- a/gopls/internal/settings/default.go +++ b/gopls/internal/settings/default.go @@ -136,6 +136,7 @@ func DefaultOptions(overrides ...func(*Options)) *Options { LinkifyShowMessage: false, IncludeReplaceInWorkspace: false, ZeroConfig: true, + AddTestSourceCodeAction: false, }, } }) diff --git a/gopls/internal/settings/settings.go b/gopls/internal/settings/settings.go index 3d97a22cafe..02c59163609 100644 --- a/gopls/internal/settings/settings.go +++ b/gopls/internal/settings/settings.go @@ -700,6 +700,11 @@ type InternalOptions struct { // TODO(rfindley): make pull diagnostics robust, and remove this option, // allowing pull diagnostics by default. PullDiagnostics bool + + // AddTestSourceCodeAction enables support for adding test as a source code + // action. + // TODO(hxjiang): remove this option once the feature is implemented. + AddTestSourceCodeAction bool } type SubdirWatchPatterns string @@ -980,6 +985,8 @@ func (o *Options) setOne(name string, value any) error { return setBool(&o.DeepCompletion, value) case "completeUnimported": return setBool(&o.CompleteUnimported, value) + case "addTestSourceCodeAction": + return setBool(&o.AddTestSourceCodeAction, value) case "completionBudget": return setDuration(&o.CompletionBudget, value) case "matcher": diff --git a/gopls/internal/test/marker/testdata/codeaction/addtest.txt b/gopls/internal/test/marker/testdata/codeaction/addtest.txt new file mode 100644 index 00000000000..cc597acdbcf --- /dev/null +++ b/gopls/internal/test/marker/testdata/codeaction/addtest.txt @@ -0,0 +1,53 @@ +This test checks the behavior of the 'add test for FUNC' code action. + +-- flags -- +-ignore_extra_diags + +-- go.mod -- +module golang.org/lsptests/copyright + +go 1.18 + +-- settings.json -- +{ + "addTestSourceCodeAction": true +} + +-- withcopyright/copyright.go -- +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.23 + +// Package copyright is for lsp test. +package copyright + +func Foo(in string) string {return in} //@codeactionedit("Foo", "source.addTest", with_copyright) + +-- @with_copyright/withcopyright/copyright_test.go -- +@@ -0,0 +1,8 @@ ++// Copyright 2020 The Go Authors. All rights reserved. ++// Use of this source code is governed by a BSD-style ++// license that can be found in the LICENSE file. ++ ++package copyright_test ++ ++ ++func TestFoo(*testing.T) {} +\ No newline at end of file +-- withoutcopyright/copyright.go -- +//go:build go1.23 + +// Package copyright is for lsp test. +package copyright + +func Foo(in string) string {return in} //@codeactionedit("Foo", "source.addTest", without_copyright) + +-- @without_copyright/withoutcopyright/copyright_test.go -- +@@ -0,0 +1,4 @@ ++package copyright_test ++ ++ ++func TestFoo(*testing.T) {} +\ No newline at end of file