diff --git a/go/cmd/dolt/cli/arg_parser_helpers.go b/go/cmd/dolt/cli/arg_parser_helpers.go index 3e532c11fd..33a578a72d 100644 --- a/go/cmd/dolt/cli/arg_parser_helpers.go +++ b/go/cmd/dolt/cli/arg_parser_helpers.go @@ -217,7 +217,7 @@ func CreatePullArgParser() *argparser.ArgParser { ap.ArgListHelp = append(ap.ArgListHelp, [2]string{"remoteBranch", "The name of a branch on the specified remote to be merged into the current working set."}) ap.SupportsFlag(SquashParam, "", "Merge changes to the working set without updating the commit history") ap.SupportsFlag(NoFFParam, "", "Create a merge commit even when the merge resolves as a fast-forward.") - ap.SupportsFlag(ForceFlag, "f", "Ignore any foreign key warnings and proceed with the commit.") + ap.SupportsFlag(ForceFlag, "f", "Update from the remote HEAD even if there are errors.") ap.SupportsFlag(CommitFlag, "", "Perform the merge and commit the result. This is the default option, but can be overridden with the --no-commit flag. Note that this option does not affect fast-forward merges, which don't create a new merge commit, and if any merge conflicts or constraint violations are detected, no commit will be attempted.") ap.SupportsFlag(NoCommitFlag, "", "Perform the merge and stop just before creating a merge commit. Note this will not prevent a fast-forward merge; use the --no-ff arg together with the --no-commit arg to prevent both fast-forwards and merge commits.") ap.SupportsFlag(NoEditFlag, "", "Use an auto-generated commit message when creating a merge commit. The default for interactive CLI sessions is to open an editor.") diff --git a/go/cmd/dolt/commands/merge.go b/go/cmd/dolt/commands/merge.go index 8f5e59b94b..46a08580a8 100644 --- a/go/cmd/dolt/commands/merge.go +++ b/go/cmd/dolt/commands/merge.go @@ -497,7 +497,7 @@ func validateMergeSpec(ctx context.Context, spec *merge.MergeSpec) errhand.Verbo if _, err := merge.MayHaveConstraintViolations(ctx, ancRoot, mergedRoot); err != nil { return errhand.VerboseErrorFromError(err) } - if !spec.Noff { + if !spec.NoFF { cli.Println("Fast-forward") } } else if err == doltdb.ErrUpToDate || err == doltdb.ErrIsAhead { @@ -728,7 +728,7 @@ func performMerge(ctx context.Context, sqlCtx *sql.Context, queryist cli.Queryis if ok, err := spec.HeadC.CanFastForwardTo(ctx, spec.MergeC); err != nil && !errors.Is(err, doltdb.ErrUpToDate) { return nil, err } else if ok { - if spec.Noff { + if spec.NoFF { return executeNoFFMergeAndCommit(ctx, sqlCtx, queryist, dEnv, spec, suggestedMsg, cliCtx) } return nil, merge.ExecuteFFMerge(ctx, dEnv, spec) @@ -763,18 +763,17 @@ func executeNoFFMergeAndCommit(ctx context.Context, sqlCtx *sql.Context, queryis mergeParentCommits = []*doltdb.Commit{ws.MergeState().Commit()} } - msg, err := getCommitMsgForMerge(ctx, sqlCtx, queryist, spec.Msg, suggestedMsg, spec.NoEdit, cliCtx) + msg, err := getCommitMsgForMerge(sqlCtx, queryist, suggestedMsg, spec.NoEdit, cliCtx) if err != nil { return tblToStats, err } pendingCommit, err := actions.GetCommitStaged(ctx, roots, ws, mergeParentCommits, dEnv.DbData().Ddb, actions.CommitStagedProps{ - Message: msg, - Date: spec.Date, - AllowEmpty: spec.AllowEmpty, - Force: spec.Force, - Name: spec.Name, - Email: spec.Email, + Message: msg, + Date: spec.Date, + Force: spec.Force, + Name: spec.Name, + Email: spec.Email, }) headRef, err := dEnv.RepoStateReader().CWBHeadRef() @@ -816,7 +815,7 @@ func executeMergeAndCommit(ctx context.Context, sqlCtx *sql.Context, queryist cl return tblToStats, nil } - msg, err := getCommitMsgForMerge(ctx, sqlCtx, queryist, spec.Msg, suggestedMsg, spec.NoEdit, cliCtx) + msg, err := getCommitMsgForMerge(sqlCtx, queryist, suggestedMsg, spec.NoEdit, cliCtx) if err != nil { return tblToStats, err } @@ -832,11 +831,13 @@ func executeMergeAndCommit(ctx context.Context, sqlCtx *sql.Context, queryist cl } // getCommitMsgForMerge returns user defined message if exists; otherwise, get the commit message from editor. -func getCommitMsgForMerge(ctx context.Context, sqlCtx *sql.Context, queryist cli.Queryist, userDefinedMsg, suggestedMsg string, noEdit bool, cliCtx cli.CliContext) (string, error) { - if userDefinedMsg != "" { - return userDefinedMsg, nil - } - +func getCommitMsgForMerge( + sqlCtx *sql.Context, + queryist cli.Queryist, + suggestedMsg string, + noEdit bool, + cliCtx cli.CliContext, +) (string, error) { msg, err := getCommitMessageFromEditor(sqlCtx, queryist, suggestedMsg, "", noEdit, cliCtx) if err != nil { return msg, err diff --git a/go/cmd/dolt/commands/pull.go b/go/cmd/dolt/commands/pull.go index 612465df16..d228e78783 100644 --- a/go/cmd/dolt/commands/pull.go +++ b/go/cmd/dolt/commands/pull.go @@ -16,6 +16,7 @@ package commands import ( "context" + "errors" "fmt" "github.com/dolthub/go-mysql-server/sql" @@ -107,7 +108,20 @@ func (cmd PullCmd) Exec(ctx context.Context, commandStr string, args []string, d return HandleVErrAndExitCode(verr, usage) } - pullSpec, err := env.NewPullSpec(ctx, dEnv.RepoStateReader(), remoteName, remoteRefName, apr.Contains(cli.SquashParam), apr.Contains(cli.NoFFParam), apr.Contains(cli.NoCommitFlag), apr.Contains(cli.NoEditFlag), apr.Contains(cli.ForceFlag), apr.NArg() == 1) + remoteOnly := apr.NArg() == 1 + pullSpec, err := env.NewPullSpec( + ctx, + dEnv.RepoStateReader(), + remoteName, + remoteRefName, + remoteOnly, + env.WithSquash(apr.Contains(cli.SquashParam)), + env.WithNoFF(apr.Contains(cli.NoFFParam)), + env.WithNoCommit(apr.Contains(cli.NoCommitFlag)), + env.WithNoEdit(apr.Contains(cli.NoEditFlag)), + env.WithForce(apr.Contains(cli.ForceFlag)), + ) + if err != nil { return HandleVErrAndExitCode(errhand.VerboseErrorFromError(err), usage) } @@ -120,7 +134,14 @@ func (cmd PullCmd) Exec(ctx context.Context, commandStr string, args []string, d } // pullHelper splits pull into fetch, prepare merge, and merge to interleave printing -func pullHelper(ctx context.Context, sqlCtx *sql.Context, queryist cli.Queryist, dEnv *env.DoltEnv, pullSpec *env.PullSpec, cliCtx cli.CliContext) error { +func pullHelper( + ctx context.Context, + sqlCtx *sql.Context, + queryist cli.Queryist, + dEnv *env.DoltEnv, + pullSpec *env.PullSpec, + cliCtx cli.CliContext, +) error { srcDB, err := pullSpec.Remote.GetRemoteDBWithoutCaching(ctx, dEnv.DoltDB.ValueReadWriter().Format(), dEnv) if err != nil { return fmt.Errorf("failed to get remote db; %w", err) @@ -160,8 +181,18 @@ func pullHelper(ctx context.Context, sqlCtx *sql.Context, queryist cli.Queryist, } err = dEnv.DoltDB.FastForward(ctx, remoteTrackRef, srcDBCommit) - if err != nil { - return fmt.Errorf("fetch failed; %w", err) + if errors.Is(err, datas.ErrMergeNeeded) { + // If the remote tracking branch has diverged from the local copy, we just overwrite it + h, err := srcDBCommit.HashOf() + if err != nil { + return err + } + err = dEnv.DoltDB.SetHead(ctx, remoteTrackRef, h) + if err != nil { + return err + } + } else if err != nil { + return fmt.Errorf("fetch failed: %w", err) } // Merge iff branch is current branch and there is an upstream set (pullSpec.Branch is set to nil if there is no upstream) @@ -178,16 +209,16 @@ func pullHelper(ctx context.Context, sqlCtx *sql.Context, queryist cli.Queryist, name, email, configErr := env.GetNameAndEmail(dEnv.Config) // If the name and email aren't set we can set them to empty values for now. This is only valid for ff - // merges which detect for later. + // merges which we detect later. if configErr != nil { - if pullSpec.Noff { + if pullSpec.NoFF { return configErr } name, email = "", "" } - // Begin merge - mergeSpec, err := merge.NewMergeSpec(ctx, dEnv.RepoStateReader(), dEnv.DoltDB, roots, name, email, pullSpec.Msg, remoteTrackRef.String(), pullSpec.Squash, pullSpec.Noff, pullSpec.Force, pullSpec.NoCommit, pullSpec.NoEdit, t) + // Begin merge of working and head with the remote head + mergeSpec, err := merge.NewMergeSpec(ctx, dEnv.RepoStateReader(), dEnv.DoltDB, roots, name, email, remoteTrackRef.String(), t, merge.WithPullSpecOpts(pullSpec)) if err != nil { return err } @@ -217,7 +248,12 @@ func pullHelper(ctx context.Context, sqlCtx *sql.Context, queryist cli.Queryist, return err } - suggestedMsg := fmt.Sprintf("Merge branch '%s' of %s into %s", pullSpec.Branch.GetPath(), pullSpec.Remote.Url, headRef.GetPath()) + suggestedMsg := fmt.Sprintf( + "Merge branch '%s' of %s into %s", + pullSpec.Branch.GetPath(), + pullSpec.Remote.Url, + headRef.GetPath(), + ) tblStats, err := performMerge(ctx, sqlCtx, queryist, dEnv, mergeSpec, suggestedMsg, cliCtx) printSuccessStats(tblStats) if err != nil { diff --git a/go/libraries/doltcore/doltdb/doltdb.go b/go/libraries/doltcore/doltdb/doltdb.go index 0838816b67..3cf36a3a13 100644 --- a/go/libraries/doltcore/doltdb/doltdb.go +++ b/go/libraries/doltcore/doltdb/doltdb.go @@ -579,10 +579,6 @@ func (ddb *DoltDB) ReadCommit(ctx context.Context, h hash.Hash) (*Commit, error) // Commit will update a branch's head value to be that of a previously committed root value hash func (ddb *DoltDB) Commit(ctx context.Context, valHash hash.Hash, dref ref.DoltRef, cm *datas.CommitMeta) (*Commit, error) { - if dref.GetType() != ref.BranchRefType { - panic("can't commit to ref that isn't branch atm. will probably remove this.") - } - return ddb.CommitWithParentSpecs(ctx, valHash, dref, nil, cm) } diff --git a/go/libraries/doltcore/env/remotes.go b/go/libraries/doltcore/env/remotes.go index f7e85cd3cc..f5045addff 100644 --- a/go/libraries/doltcore/env/remotes.go +++ b/go/libraries/doltcore/env/remotes.go @@ -391,9 +391,8 @@ func GetTrackingRef(branchRef ref.DoltRef, remote Remote) (ref.DoltRef, error) { } type PullSpec struct { - Msg string Squash bool - Noff bool + NoFF bool NoCommit bool NoEdit bool Force bool @@ -403,10 +402,47 @@ type PullSpec struct { Branch ref.DoltRef } -// NewPullSpec returns PullSpec object using arguments passed into this function, which are remoteName, remoteRefName, -// squash, noff, noCommit, noEdit, refSpecs, force and remoteOnly. This function validates remote and gets remoteRef +type PullSpecOpt func(*PullSpec) + +func WithSquash(squash bool) PullSpecOpt { + return func(ps *PullSpec) { + ps.Squash = squash + } +} + +func WithNoFF(noff bool) PullSpecOpt { + return func(ps *PullSpec) { + ps.NoFF = noff + } +} + +func WithNoCommit(nocommit bool) PullSpecOpt { + return func(ps *PullSpec) { + ps.NoCommit = nocommit + } +} + +func WithNoEdit(noedit bool) PullSpecOpt { + return func(ps *PullSpec) { + ps.NoEdit = noedit + } +} + +func WithForce(force bool) PullSpecOpt { + return func(ps *PullSpec) { + ps.Force = force + } +} + +// NewPullSpec returns a PullSpec for the arguments given. This function validates remote and gets remoteRef // for given remoteRefName; if it's not defined, it uses current branch to get its upstream branch if it exists. -func NewPullSpec(_ context.Context, rsr RepoStateReader, remoteName, remoteRefName string, squash, noff, noCommit, noEdit, force, remoteOnly bool) (*PullSpec, error) { +func NewPullSpec( + _ context.Context, + rsr RepoStateReader, + remoteName, remoteRefName string, + remoteOnly bool, + opts ...PullSpecOpt, +) (*PullSpec, error) { refSpecs, err := GetRefSpecs(rsr, remoteName) if err != nil { return nil, err @@ -446,17 +482,18 @@ func NewPullSpec(_ context.Context, rsr RepoStateReader, remoteName, remoteRefNa remoteRef = ref.NewBranchRef(remoteRefName) } - return &PullSpec{ - Squash: squash, - Noff: noff, - NoCommit: noCommit, - NoEdit: noEdit, + spec := &PullSpec{ RemoteName: remoteName, Remote: remote, RefSpecs: refSpecs, Branch: remoteRef, - Force: force, - }, nil + } + + for _, opt := range opts { + opt(spec) + } + + return spec, nil } func GetAbsRemoteUrl(fs filesys2.Filesys, cfg config.ReadableConfig, urlArg string) (string, string, error) { diff --git a/go/libraries/doltcore/merge/action.go b/go/libraries/doltcore/merge/action.go index 307ddbc1fc..02fbc7654c 100644 --- a/go/libraries/doltcore/merge/action.go +++ b/go/libraries/doltcore/merge/action.go @@ -28,10 +28,7 @@ import ( "github.com/dolthub/dolt/go/store/hash" ) -var ErrFailedToDetermineUnstagedDocs = errors.New("failed to determine unstaged docs") var ErrFailedToReadDatabase = errors.New("failed to read database") -var ErrMergeFailedToUpdateDocs = errors.New("failed to update docs to the new working root") -var ErrMergeFailedToUpdateRepoState = errors.New("unable to execute repo state update") var ErrFailedToDetermineMergeability = errors.New("failed to determine mergeability") type MergeSpec struct { @@ -43,21 +40,67 @@ type MergeSpec struct { StompedTblNames []string WorkingDiffs map[string]hash.Hash Squash bool - Msg string - Noff bool + NoFF bool NoCommit bool NoEdit bool Force bool - AllowEmpty bool Email string Name string Date time.Time } -// NewMergeSpec returns MergeSpec object using arguments passed into this function, which are doltdb.Roots, username, -// user email, commit msg, commitSpecStr, to squash, to noff, to force, noCommit, noEdit and date. This function -// resolves head and merge commit, and it gets current diffs between current head and working set if it exists. -func NewMergeSpec(ctx context.Context, rsr env.RepoStateReader, ddb *doltdb.DoltDB, roots doltdb.Roots, name, email, msg, commitSpecStr string, squash, noff, force, noCommit, noEdit bool, date time.Time) (*MergeSpec, error) { +type MergeSpecOpt func(*MergeSpec) + +func WithNoFF(noFF bool) MergeSpecOpt { + return func(ms *MergeSpec) { + ms.NoFF = noFF + } +} + +func WithNoCommit(noCommit bool) MergeSpecOpt { + return func(ms *MergeSpec) { + ms.NoCommit = noCommit + } +} + +func WithNoEdit(noEdit bool) MergeSpecOpt { + return func(ms *MergeSpec) { + ms.NoEdit = noEdit + } +} + +func WithForce(force bool) MergeSpecOpt { + return func(ms *MergeSpec) { + ms.Force = force + } +} + +func WithSquash(squash bool) MergeSpecOpt { + return func(ms *MergeSpec) { + ms.Squash = squash + } +} + +func WithPullSpecOpts(pullSpec *env.PullSpec) MergeSpecOpt { + return func(ms *MergeSpec) { + ms.NoEdit = pullSpec.NoEdit + ms.NoCommit = pullSpec.NoCommit + ms.Force = pullSpec.Force + ms.NoFF = pullSpec.NoFF + ms.Squash = pullSpec.Squash + } +} + +// NewMergeSpec returns a MergeSpec with the arguments provided. +func NewMergeSpec( + ctx context.Context, + rsr env.RepoStateReader, + ddb *doltdb.DoltDB, + roots doltdb.Roots, + name, email, commitSpecStr string, + date time.Time, + opts ...MergeSpecOpt, +) (*MergeSpec, error) { headCS, err := doltdb.NewCommitSpec("HEAD") if err != nil { return nil, err @@ -99,7 +142,7 @@ func NewMergeSpec(ctx context.Context, rsr env.RepoStateReader, ddb *doltdb.Dolt return nil, fmt.Errorf("%w; %s", ErrFailedToDetermineMergeability, err.Error()) } - return &MergeSpec{ + spec := &MergeSpec{ HeadH: headH, MergeH: mergeH, HeadC: headCM, @@ -107,16 +150,16 @@ func NewMergeSpec(ctx context.Context, rsr env.RepoStateReader, ddb *doltdb.Dolt MergeC: mergeCM, StompedTblNames: stompedTblNames, WorkingDiffs: workingDiffs, - Squash: squash, - Msg: msg, - Noff: noff, - NoCommit: noCommit, - NoEdit: noEdit, - Force: force, Email: email, Name: name, Date: date, - }, nil + } + + for _, opt := range opts { + opt(spec) + } + + return spec, nil } func ExecNoFFMerge(ctx context.Context, dEnv *env.DoltEnv, spec *MergeSpec) (map[string]*MergeStats, error) { diff --git a/go/libraries/doltcore/sqle/dprocedures/dolt_merge.go b/go/libraries/doltcore/sqle/dprocedures/dolt_merge.go index c118ff174e..a470ecc0c1 100644 --- a/go/libraries/doltcore/sqle/dprocedures/dolt_merge.go +++ b/go/libraries/doltcore/sqle/dprocedures/dolt_merge.go @@ -155,7 +155,7 @@ func doDoltMerge(ctx *sql.Context, args []string) (string, int, int, error) { msg = userMsg } - ws, commit, conflicts, fastForward, err := performMerge(ctx, sess, roots, ws, dbName, mergeSpec, apr.Contains(cli.NoCommitFlag), msg) + ws, commit, conflicts, fastForward, err := performMerge(ctx, sess, ws, dbName, mergeSpec, apr.Contains(cli.NoCommitFlag), msg) if err != nil || conflicts != 0 || fastForward != 0 { return commit, conflicts, fastForward, err } @@ -169,7 +169,15 @@ func doDoltMerge(ctx *sql.Context, args []string) (string, int, int, error) { // fast-forward was performed. This commits the working set if merge is successful and // 'no-commit' flag is not defined. // TODO FF merging commit with constraint violations requires `constraint verify` -func performMerge(ctx *sql.Context, sess *dsess.DoltSession, roots doltdb.Roots, ws *doltdb.WorkingSet, dbName string, spec *merge.MergeSpec, noCommit bool, msg string) (*doltdb.WorkingSet, string, int, int, error) { +func performMerge( + ctx *sql.Context, + sess *dsess.DoltSession, + ws *doltdb.WorkingSet, + dbName string, + spec *merge.MergeSpec, + noCommit bool, + msg string, +) (*doltdb.WorkingSet, string, int, int, error) { // todo: allow merges even when an existing merge is uncommitted if ws.MergeActive() { return ws, "", noConflictsOrViolations, threeWayMerge, doltdb.ErrMergeActive @@ -195,9 +203,9 @@ func performMerge(ctx *sql.Context, sess *dsess.DoltSession, roots doltdb.Roots, } if canFF { - if spec.Noff { + if spec.NoFF { var commit *doltdb.Commit - ws, commit, err = executeNoFFMerge(ctx, sess, spec, dbName, ws, dbData, noCommit) + ws, commit, err = executeNoFFMerge(ctx, sess, spec, msg, dbName, ws, noCommit) if err == doltdb.ErrUnresolvedConflictsOrViolations { // if there are unresolved conflicts, write the resulting working set back to the session and return an // error message @@ -302,7 +310,17 @@ func abortMerge(ctx *sql.Context, workingSet *doltdb.WorkingSet, roots doltdb.Ro return workingSet, nil } -func executeMerge(ctx *sql.Context, sess *dsess.DoltSession, dbName string, squash bool, head, cm *doltdb.Commit, cmSpec string, ws *doltdb.WorkingSet, opts editor.Options, workingDiffs map[string]hash.Hash) (*doltdb.WorkingSet, error) { +func executeMerge( + ctx *sql.Context, + sess *dsess.DoltSession, + dbName string, + squash bool, + head, cm *doltdb.Commit, + cmSpec string, + ws *doltdb.WorkingSet, + opts editor.Options, + workingDiffs map[string]hash.Hash, +) (*doltdb.WorkingSet, error) { result, err := merge.MergeCommits(ctx, head, cm, opts) if err != nil { switch err { @@ -369,9 +387,9 @@ func executeNoFFMerge( ctx *sql.Context, dSess *dsess.DoltSession, spec *merge.MergeSpec, + msg string, dbName string, ws *doltdb.WorkingSet, - dbData env.DbData, noCommit bool, ) (*doltdb.WorkingSet, *doltdb.Commit, error) { mergeRoot, err := spec.MergeC.GetRootValue(ctx) @@ -412,12 +430,11 @@ func executeNoFFMerge( } pendingCommit, err := dSess.NewPendingCommit(ctx, dbName, roots, actions.CommitStagedProps{ - Message: spec.Msg, - Date: spec.Date, - AllowEmpty: spec.AllowEmpty, - Force: spec.Force, - Name: spec.Name, - Email: spec.Email, + Message: msg, + Date: spec.Date, + Force: spec.Force, + Name: spec.Name, + Email: spec.Email, }) if err != nil { return nil, nil, err @@ -440,22 +457,9 @@ func createMergeSpec(ctx *sql.Context, sess *dsess.DoltSession, dbName string, a dbData, ok := sess.GetDbData(ctx, dbName) - msg, ok := apr.GetValue(cli.MessageArg) - if !ok { - // TODO probably change, but we can't open editor so it'll have to be automated - msg = "automatic SQL merge" - } - - var err error - var name, email string - if authorStr, ok := apr.GetValue(cli.AuthorParam); ok { - name, email, err = cli.ParseAuthor(authorStr) - if err != nil { - return nil, err - } - } else { - name = ctx.Client().User - email = fmt.Sprintf("%s@%s", ctx.Client().User, ctx.Client().Address) + name, email, err := getNameAndEmail(ctx, apr) + if err != nil { + return nil, err } t := ctx.QueryTime() @@ -474,7 +478,36 @@ func createMergeSpec(ctx *sql.Context, sess *dsess.DoltSession, dbName string, a if apr.Contains(cli.NoCommitFlag) && apr.Contains(cli.CommitFlag) { return nil, errors.New("cannot define both 'commit' and 'no-commit' flags at the same time") } - return merge.NewMergeSpec(ctx, dbData.Rsr, ddb, roots, name, email, msg, commitSpecStr, apr.Contains(cli.SquashParam), apr.Contains(cli.NoFFParam), apr.Contains(cli.ForceFlag), apr.Contains(cli.NoCommitFlag), apr.Contains(cli.NoEditFlag), t) + return merge.NewMergeSpec( + ctx, + dbData.Rsr, + ddb, + roots, + name, + email, + commitSpecStr, + t, + merge.WithSquash(apr.Contains(cli.SquashParam)), + merge.WithNoFF(apr.Contains(cli.NoFFParam)), + merge.WithForce(apr.Contains(cli.ForceFlag)), + merge.WithNoCommit(apr.Contains(cli.NoCommitFlag)), + merge.WithNoEdit(apr.Contains(cli.NoEditFlag)), + ) +} + +func getNameAndEmail(ctx *sql.Context, apr *argparser.ArgParseResults) (string, string, error) { + var err error + var name, email string + if authorStr, ok := apr.GetValue(cli.AuthorParam); ok { + name, email, err = cli.ParseAuthor(authorStr) + if err != nil { + return "", "", err + } + } else { + name = ctx.Client().User + email = fmt.Sprintf("%s@%s", ctx.Client().User, ctx.Client().Address) + } + return name, email, nil } func mergeRootToWorking( diff --git a/go/libraries/doltcore/sqle/dprocedures/dolt_pull.go b/go/libraries/doltcore/sqle/dprocedures/dolt_pull.go index 3d0c7ce20e..04c8d0e8d9 100644 --- a/go/libraries/doltcore/sqle/dprocedures/dolt_pull.go +++ b/go/libraries/doltcore/sqle/dprocedures/dolt_pull.go @@ -29,6 +29,7 @@ import ( "github.com/dolthub/dolt/go/libraries/doltcore/env/actions" "github.com/dolthub/dolt/go/libraries/doltcore/ref" "github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess" + "github.com/dolthub/dolt/go/store/datas" "github.com/dolthub/dolt/go/store/datas/pull" ) @@ -75,7 +76,19 @@ func doDoltPull(ctx *sql.Context, args []string) (int, int, error) { remoteRefName = apr.Arg(1) } - pullSpec, err := env.NewPullSpec(ctx, dbData.Rsr, remoteName, remoteRefName, apr.Contains(cli.SquashParam), apr.Contains(cli.NoFFParam), apr.Contains(cli.NoCommitFlag), apr.Contains(cli.NoEditFlag), apr.Contains(cli.ForceFlag), apr.NArg() == 1) + remoteOnly := apr.NArg() == 1 + pullSpec, err := env.NewPullSpec( + ctx, + dbData.Rsr, + remoteName, + remoteRefName, + remoteOnly, + env.WithSquash(apr.Contains(cli.SquashParam)), + env.WithNoFF(apr.Contains(cli.NoFFParam)), + env.WithNoCommit(apr.Contains(cli.NoCommitFlag)), + env.WithNoEdit(apr.Contains(cli.NoEditFlag)), + env.WithForce(apr.Contains(cli.ForceFlag)), + ) if err != nil { return noConflictsOrViolations, threeWayMerge, err } @@ -127,9 +140,27 @@ func doDoltPull(ctx *sql.Context, args []string) (int, int, error) { return noConflictsOrViolations, threeWayMerge, err } + headRef, err := dbData.Rsr.CWBHeadRef() + if err != nil { + return noConflictsOrViolations, threeWayMerge, err + } + + msg := fmt.Sprintf("Merge branch '%s' of %s into %s", pullSpec.Branch.GetPath(), pullSpec.Remote.Url, headRef.GetPath()) + // TODO: this could be replaced with a canFF check to test for error err = dbData.Ddb.FastForward(ctx, remoteTrackRef, srcDBCommit) - if err != nil { + if errors.Is(err, datas.ErrMergeNeeded) { + // If the remote tracking branch has diverged from the local copy, we just overwrite it + // TODO: none of this is transactional + h, err := srcDBCommit.HashOf() + if err != nil { + return noConflictsOrViolations, threeWayMerge, err + } + err = dbData.Ddb.SetHead(ctx, remoteTrackRef, h) + if err != nil { + return noConflictsOrViolations, threeWayMerge, err + } + } else if err != nil { return noConflictsOrViolations, threeWayMerge, fmt.Errorf("fetch failed; %w", err) } @@ -148,11 +179,6 @@ func doDoltPull(ctx *sql.Context, args []string) (int, int, error) { return noConflictsOrViolations, threeWayMerge, err } - headRef, err := dbData.Rsr.CWBHeadRef() - if err != nil { - return noConflictsOrViolations, threeWayMerge, err - } - uncommittedChanges, _, _, err := actions.RootHasUncommittedChanges(roots) if err != nil { return noConflictsOrViolations, threeWayMerge, err @@ -161,8 +187,7 @@ func doDoltPull(ctx *sql.Context, args []string) (int, int, error) { return noConflictsOrViolations, threeWayMerge, ErrUncommittedChanges.New() } - msg := fmt.Sprintf("Merge branch '%s' of %s into %s", pullSpec.Branch.GetPath(), pullSpec.Remote.Url, headRef.GetPath()) - ws, _, conflicts, fastForward, err = performMerge(ctx, sess, roots, ws, dbName, mergeSpec, apr.Contains(cli.NoCommitFlag), msg) + ws, _, conflicts, fastForward, err = performMerge(ctx, sess, ws, dbName, mergeSpec, apr.Contains(cli.NoCommitFlag), msg) if err != nil && !errors.Is(doltdb.ErrUpToDate, err) { return conflicts, fastForward, err } diff --git a/go/libraries/doltcore/sqle/dsess/dolt_session_test.go b/go/libraries/doltcore/sqle/dsess/dolt_session_test.go index 8b3fd1f923..f13f910913 100644 --- a/go/libraries/doltcore/sqle/dsess/dolt_session_test.go +++ b/go/libraries/doltcore/sqle/dsess/dolt_session_test.go @@ -142,14 +142,14 @@ func TestSetPersistedValue(t *testing.T) { Value: "7", }, { - Name: "bool", - Value: true, - Err: sql.ErrInvalidType, + Name: "bool", + Value: true, + ExpectedRes: "1", }, { - Name: "bool", - Value: false, - Err: sql.ErrInvalidType, + Name: "bool", + Value: false, + ExpectedRes: "0", }, { Value: complex64(7), diff --git a/go/libraries/doltcore/sqle/dsess/session.go b/go/libraries/doltcore/sqle/dsess/session.go index 8172d984e4..086af221d6 100644 --- a/go/libraries/doltcore/sqle/dsess/session.go +++ b/go/libraries/doltcore/sqle/dsess/session.go @@ -1575,7 +1575,11 @@ func setPersistedValue(conf config.WritableConfig, key string, value interface{} case string: return config.SetString(conf, key, v) case bool: - return sql.ErrInvalidType.New(v) + if v { + return config.SetInt(conf, key, 1) + } else { + return config.SetInt(conf, key, 0) + } default: return sql.ErrInvalidType.New(v) } diff --git a/go/libraries/doltcore/sqle/enginetest/dolt_server_test.go b/go/libraries/doltcore/sqle/enginetest/dolt_server_test.go index 905e779e44..75a54598f0 100755 --- a/go/libraries/doltcore/sqle/enginetest/dolt_server_test.go +++ b/go/libraries/doltcore/sqle/enginetest/dolt_server_test.go @@ -421,6 +421,57 @@ var DropDatabaseMultiSessionScriptTests = []queries.ScriptTest{ }, } +var PersistVariableTests = []queries.ScriptTest{ + { + Name: "set persisted variables with on and off", + SetUpScript: []string{ + "set @@persist.dolt_skip_replication_errors = on;", + "set @@persist.dolt_read_replica_force_pull = off;", + }, + }, + { + Name: "retrieve persisted variables", + Assertions: []queries.ScriptTestAssertion{ + { + Query: "select @@dolt_skip_replication_errors", + Expected: []sql.Row{ + {1}, + }, + }, + { + Query: "select @@dolt_read_replica_force_pull", + Expected: []sql.Row{ + {0}, + }, + }, + }, + }, + { + Name: "set persisted variables with 1 and 0", + SetUpScript: []string{ + "set @@persist.dolt_skip_replication_errors = 0;", + "set @@persist.dolt_read_replica_force_pull = 1;", + }, + }, + { + Name: "retrieve persisted variables", + Assertions: []queries.ScriptTestAssertion{ + { + Query: "select @@dolt_skip_replication_errors", + Expected: []sql.Row{ + {0}, + }, + }, + { + Query: "select @@dolt_read_replica_force_pull", + Expected: []sql.Row{ + {1}, + }, + }, + }, + }, +} + // TestDoltMultiSessionBehavior runs tests that exercise multi-session logic on a running SQL server. Statements // are sent through the server, from out of process, instead of directly to the in-process engine API. func TestDoltMultiSessionBehavior(t *testing.T) { @@ -433,6 +484,11 @@ func TestDropDatabaseMultiSessionBehavior(t *testing.T) { testMultiSessionScriptTests(t, DropDatabaseMultiSessionScriptTests) } +// TestPersistVariable tests persisting variables across server starts +func TestPersistVariable(t *testing.T) { + testSerialSessionScriptTests(t, PersistVariableTests) +} + func testMultiSessionScriptTests(t *testing.T, tests []queries.ScriptTest) { for _, test := range tests { t.Run(test.Name, func(t *testing.T) { @@ -490,6 +546,62 @@ func testMultiSessionScriptTests(t *testing.T, tests []queries.ScriptTest) { } } +// testSerialSessionScriptTests creates an environment, then for each script starts a server and runs assertions, +// stopping the server in between scripts. Unlike other script test executors, scripts may influence later scripts in +// the block. +func testSerialSessionScriptTests(t *testing.T, tests []queries.ScriptTest) { + dEnv := dtestutils.CreateTestEnv() + serverConfig := sqlserver.DefaultServerConfig() + rand.Seed(time.Now().UnixNano()) + port := 15403 + rand.Intn(25) + serverConfig = serverConfig.WithPort(port) + defer dEnv.DoltDB.Close() + + for _, test := range tests { + t.Run(test.Name, func(t *testing.T) { + sc, serverConfig := startServerOnEnv(t, serverConfig, dEnv) + err := sc.WaitForStart() + require.NoError(t, err) + + conn1, sess1 := newConnection(t, serverConfig) + + t.Run(test.Name, func(t *testing.T) { + for _, setupStatement := range test.SetUpScript { + _, err := sess1.Exec(setupStatement) + require.NoError(t, err) + } + + for _, assertion := range test.Assertions { + t.Run(assertion.Query, func(t *testing.T) { + activeSession := sess1 + rows, err := activeSession.Query(assertion.Query) + + if len(assertion.ExpectedErrStr) > 0 { + require.EqualError(t, err, assertion.ExpectedErrStr) + } else if assertion.ExpectedErr != nil { + require.True(t, assertion.ExpectedErr.Is(err), "expected error %v, got %v", assertion.ExpectedErr, err) + } else if assertion.Expected != nil { + require.NoError(t, err) + assertResultsEqual(t, assertion.Expected, rows) + } else { + require.Fail(t, "unsupported ScriptTestAssertion property: %v", assertion) + } + if rows != nil { + require.NoError(t, rows.Close()) + } + }) + } + }) + + require.NoError(t, conn1.Close()) + + sc.StopServer() + err = sc.WaitForClose() + require.NoError(t, err) + }) + } +} + func makeDestinationSlice(t *testing.T, columnTypes []*gosql.ColumnType) []interface{} { dest := make([]any, len(columnTypes)) for i, columnType := range columnTypes { @@ -548,7 +660,6 @@ func assertResultsEqual(t *testing.T, expected []sql.Row, rows *gosql.Rows) { func startServer(t *testing.T, withPort bool, host string, unixSocketPath string) (*env.DoltEnv, *sqlserver.ServerController, sqlserver.ServerConfig) { dEnv := dtestutils.CreateTestEnv() serverConfig := sqlserver.DefaultServerConfig() - if withPort { rand.Seed(time.Now().UnixNano()) port := 15403 + rand.Intn(25) @@ -561,6 +672,11 @@ func startServer(t *testing.T, withPort bool, host string, unixSocketPath string serverConfig = serverConfig.WithSocket(unixSocketPath) } + onEnv, config := startServerOnEnv(t, serverConfig, dEnv) + return dEnv, onEnv, config +} + +func startServerOnEnv(t *testing.T, serverConfig sqlserver.ServerConfig, dEnv *env.DoltEnv) (*sqlserver.ServerController, sqlserver.ServerConfig) { sc := sqlserver.NewServerController() go func() { _, _ = sqlserver.Serve(context.Background(), "0.0.0", serverConfig, sc, dEnv) @@ -568,7 +684,7 @@ func startServer(t *testing.T, withPort bool, host string, unixSocketPath string err := sc.WaitForStart() require.NoError(t, err) - return dEnv, sc, serverConfig + return sc, serverConfig } // newConnection takes sqlserver.serverConfig and opens a connection, and will return that connection with a new session diff --git a/go/libraries/doltcore/sqle/system_variables.go b/go/libraries/doltcore/sqle/system_variables.go index a18ec75336..6f74e548e1 100644 --- a/go/libraries/doltcore/sqle/system_variables.go +++ b/go/libraries/doltcore/sqle/system_variables.go @@ -59,8 +59,8 @@ func AddDoltSystemVariables() { Scope: sql.SystemVariableScope_Global, Dynamic: true, SetVarHintApplies: false, - Type: types.NewSystemStringType(dsess.ReadReplicaForcePull), - Default: int8(0), + Type: types.NewSystemBoolType(dsess.ReadReplicaForcePull), + Default: int8(1), }, { Name: dsess.SkipReplicationErrors, diff --git a/integration-tests/bats/remotes.bats b/integration-tests/bats/remotes.bats index 56b42d9436..cde038250e 100644 --- a/integration-tests/bats/remotes.bats +++ b/integration-tests/bats/remotes.bats @@ -1057,6 +1057,80 @@ SQL [[ ! "$output" =~ "another test commit" ]] || false } +@test "remotes: dolt_pull() with divergent head" { + dolt remote add test-remote http://localhost:50051/test-org/test-repo + dolt push test-remote main + dolt sql <