diff --git a/pkg/gui/controllers/branches_controller.go b/pkg/gui/controllers/branches_controller.go index a97168fc15f..f63d4d9f191 100644 --- a/pkg/gui/controllers/branches_controller.go +++ b/pkg/gui/controllers/branches_controller.go @@ -728,13 +728,26 @@ func (self *BranchesController) createPullRequestMenu(selectedBranch *models.Bra }, }, { + // TODO: Replace with "Select remote and branch"? LabelColumns: fromToLabelColumns(branch.Name, self.c.Tr.SelectBranch), OnPress: func() error { + if !branch.IsTrackingRemote() { + return errors.New(self.c.Tr.PullRequestNoUpstream) + } + + if len(self.c.Model().Remotes) == 1 { + toRemote := self.c.Model().Remotes[0].Name + self.c.Log.Debugf("PR will target the only existing remote '%s'", toRemote) + return self.promptForTargetBranchNameAndCreatePullRequest(branch, toRemote) + } + self.c.Prompt(types.PromptOpts{ - Title: branch.Name + " →", - FindSuggestionsFunc: self.c.Helpers().Suggestions.GetRemoteBranchesSuggestionsFunc("/"), - HandleConfirm: func(targetBranchName string) error { - return self.createPullRequest(branch.Name, targetBranchName) + Title: "Select Target Remote", + FindSuggestionsFunc: self.c.Helpers().Suggestions.GetRemoteSuggestionsFunc(), + HandleConfirm: func(toRemote string) error { + self.c.Log.Debugf("PR will target remote '%s'", toRemote) + + return self.promptForTargetBranchNameAndCreatePullRequest(branch, toRemote) }, }) @@ -764,6 +777,19 @@ func (self *BranchesController) createPullRequestMenu(selectedBranch *models.Bra return self.c.Menu(types.CreateMenuOptions{Title: fmt.Sprint(self.c.Tr.CreatePullRequestOptions), Items: menuItems}) } +func (self *BranchesController) promptForTargetBranchNameAndCreatePullRequest(fromBranch *models.Branch, toRemote string) error { + self.c.Prompt(types.PromptOpts{ + Title: fmt.Sprintf("%s → %s/", fromBranch.UpstreamBranch, toRemote), + FindSuggestionsFunc: self.c.Helpers().Suggestions.GetRemoteBranchesForRemoteSuggestionsFunc(toRemote), + HandleConfirm: func(toBranch string) error { + self.c.Log.Debugf("PR will target branch '%s' on remote '%s'", toBranch, toRemote) + return self.createPullRequest(fromBranch.UpstreamBranch, toBranch) + }, + }) + + return nil +} + func (self *BranchesController) createPullRequest(from string, to string) error { url, err := self.c.Helpers().Host.GetPullRequestURL(from, to) if err != nil { diff --git a/pkg/gui/controllers/helpers/suggestions_helper.go b/pkg/gui/controllers/helpers/suggestions_helper.go index 441a488b52e..0e0fe3521f3 100644 --- a/pkg/gui/controllers/helpers/suggestions_helper.go +++ b/pkg/gui/controllers/helpers/suggestions_helper.go @@ -162,10 +162,38 @@ func (self *SuggestionsHelper) getRemoteBranchNames(separator string) []string { }) } +func (self *SuggestionsHelper) getRemoteBranchNamesForRemote(remoteName string) []string { + for _, remote := range self.c.Model().Remotes { + if remote.Name == remoteName { + return lo.Map(remote.Branches, func(branch *models.RemoteBranch, _ int) string { + return branch.Name + }) + } + } + + return nil +} + +func (self *SuggestionsHelper) getRemoteBranchNamesWithoutRemotePrefix() []string { + return lo.FlatMap(self.c.Model().Remotes, func(remote *models.Remote, _ int) []string { + return lo.Map(remote.Branches, func(branch *models.RemoteBranch, _ int) string { + return branch.Name + }) + }) +} + func (self *SuggestionsHelper) GetRemoteBranchesSuggestionsFunc(separator string) func(string) []*types.Suggestion { return FilterFunc(self.getRemoteBranchNames(separator), self.c.UserConfig().Gui.UseFuzzySearch()) } +func (self *SuggestionsHelper) GetRemoteBranchesForRemoteSuggestionsFunc(remoteName string) func(string) []*types.Suggestion { + return FilterFunc(self.getRemoteBranchNamesForRemote(remoteName), self.c.UserConfig().Gui.UseFuzzySearch()) +} + +func (self *SuggestionsHelper) GetRemoteBranchesWithoutRemotePrefixSuggestionsFunc() func(string) []*types.Suggestion { + return FilterFunc(self.getRemoteBranchNamesWithoutRemotePrefix(), self.c.UserConfig().Gui.UseFuzzySearch()) +} + func (self *SuggestionsHelper) getTagNames() []string { return lo.Map(self.c.Model().Tags, func(tag *models.Tag, _ int) string { return tag.Name