diff --git a/go/cmd/dolt/cli/arg_parser_helpers.go b/go/cmd/dolt/cli/arg_parser_helpers.go index 4adfc5bcef..e441a7a626 100644 --- a/go/cmd/dolt/cli/arg_parser_helpers.go +++ b/go/cmd/dolt/cli/arg_parser_helpers.go @@ -97,10 +97,11 @@ func CreateMergeArgParser() *argparser.ArgParser { } func CreateRebaseArgParser() *argparser.ArgParser { - ap := argparser.NewArgParserWithMaxArgs("merge", 1) + ap := argparser.NewArgParserWithMaxArgs("rebase", 1) ap.TooManyArgsErrorFunc = func(receivedArgs []string) error { return errors.New("rebase takes at most one positional argument.") } + ap.SupportsString(EmptyParam, "", "empty", "How to handle commits that are not empty to start, but which become empty after rebasing. Valid values are: drop (default) or keep") ap.SupportsFlag(AbortParam, "", "Abort an interactive rebase and return the working set to the pre-rebase state") ap.SupportsFlag(ContinueFlag, "", "Continue an interactive rebase after adjusting the rebase plan") ap.SupportsFlag(InteractiveFlag, "i", "Start an interactive rebase") @@ -172,6 +173,9 @@ func CreateCheckoutArgParser() *argparser.ArgParser { func CreateCherryPickArgParser() *argparser.ArgParser { ap := argparser.NewArgParserWithMaxArgs("cherrypick", 1) ap.SupportsFlag(AbortParam, "", "Abort the current conflict resolution process, and revert all changes from the in-process cherry-pick operation.") + ap.SupportsFlag(AllowEmptyFlag, "", "Allow empty commits to be cherry-picked. "+ "Note that use of this option only keeps commits that were initially empty. "+ "Commits which become empty due to a previous commit will cause the cherry-pick to fail.") ap.TooManyArgsErrorFunc = func(receivedArgs []string) error { return errors.New("cherry-picking multiple commits is not supported yet.") } diff --git a/go/cmd/dolt/cli/cli_context.go b/go/cmd/dolt/cli/cli_context.go index cae0f45e51..659608836e 100644 --- a/go/cmd/dolt/cli/cli_context.go +++ b/go/cmd/dolt/cli/cli_context.go @@ -27,7 +27,12 @@ import ( // LateBindQueryist is a function that will be called the first time Queryist is needed for use. Input is a context which // is appropriate for the call to commence. Output is a Queryist, a sql.Context, a closer function, and an error. -// The closer function is called when the Queryist is no longer needed, typically a defer right after getting it. +// +// The closer function is called when the Queryist is no longer needed, typically a defer right after getting it. If a nil +// closer function is returned, then the caller knows that the queryist returned is being managed by another command. Effectively +// this means you are running in another command's session. This is particularly interesting when running a \checkout in a +// dolt sql session. It makes sense to do so in the context of `dolt sql`, but not in the context of `dolt checkout` when +// connected to a remote server. type LateBindQueryist func(ctx context.Context) (Queryist, *sql.Context, func(), error) // CliContext is used to pass top level command information down to subcommands. 
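The nil-closer contract described above is easiest to see from the caller's side. Below is a minimal sketch of the intended consumption pattern — the command body and function name are hypothetical; `cliCtx.QueryEngine`, `cli.RemoteUnsupportedMsg`, and `engine.SqlEngine` come from the surrounding changes:

```go
import (
	"context"
	"fmt"

	"github.com/dolthub/dolt/go/cmd/dolt/cli"
	"github.com/dolthub/dolt/go/libraries/doltcore/sqle/engine"
)

// execWithSession sketches how a command honors the nil-closer contract.
func execWithSession(ctx context.Context, commandStr string, cliCtx cli.CliContext) int {
	queryist, sqlCtx, closeFunc, err := cliCtx.QueryEngine(ctx)
	if err != nil {
		return 1
	}
	if closeFunc != nil {
		// We own this session: clean it up when done, and verify the engine
		// is local before running anything a remote connection can't support.
		defer closeFunc()
		if _, ok := queryist.(*engine.SqlEngine); !ok {
			cli.Println(fmt.Sprintf(cli.RemoteUnsupportedMsg, commandStr))
			return 1
		}
	}
	// A nil closeFunc means another command (e.g. `dolt sql`) owns the session,
	// so both the cleanup and the local-engine guard are skipped.
	_ = sqlCtx // queries would be issued through queryist and sqlCtx here
	return 0
}
```

This is exactly the shape the checkout and merge commands move to in this diff: the remote-engine guard now lives inside the `closeFunc != nil` block.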
diff --git a/go/cmd/dolt/cli/flags.go b/go/cmd/dolt/cli/flags.go index 441af53c0d..b059de50eb 100644 --- a/go/cmd/dolt/cli/flags.go +++ b/go/cmd/dolt/cli/flags.go @@ -35,6 +35,7 @@ const ( DeleteForceFlag = "D" DepthFlag = "depth" DryRunFlag = "dry-run" + EmptyParam = "empty" ForceFlag = "force" GraphFlag = "graph" HardResetParam = "hard" diff --git a/go/cmd/dolt/cli/messages.go b/go/cmd/dolt/cli/messages.go index 58f57668db..db71cae941 100644 --- a/go/cmd/dolt/cli/messages.go +++ b/go/cmd/dolt/cli/messages.go @@ -19,5 +19,5 @@ package cli const ( // Single variable - the name of the command. `dolt ` is how the commandString is formatted in calls to the Exec method // for dolt commands. - RemoteUnsupportedMsg = "%s can not currently be used when there is a local server running. Please stop your dolt sql-server and try again." + RemoteUnsupportedMsg = "%s can not currently be used when there is a local server running. Please stop your dolt sql-server or connect using `dolt sql` instead." ) diff --git a/go/cmd/dolt/commands/checkout.go b/go/cmd/dolt/commands/checkout.go index 3b98236c53..c09631ab31 100644 --- a/go/cmd/dolt/commands/checkout.go +++ b/go/cmd/dolt/commands/checkout.go @@ -97,15 +97,17 @@ func (cmd CheckoutCmd) Exec(ctx context.Context, commandStr string, args []strin } if closeFunc != nil { defer closeFunc() - } - _, ok := queryEngine.(*engine.SqlEngine) - if !ok { - // Currently checkout does not fully support remote connections. Prevent them from being used until we have better - // CLI session support. - msg := fmt.Sprintf(cli.RemoteUnsupportedMsg, commandStr) - cli.Println(msg) - return 1 + // We only check for this case when checkout is the first command in a session. The reason for this is that checkout + // when connected to a remote server will not work as it won't set the branch. But when operating within the context + // of another session, specifically a \checkout in a dolt sql session, this makes sense. Since no closeFunc would be + // returned, we don't need to check for this case. + _, ok := queryEngine.(*engine.SqlEngine) + if !ok { + msg := fmt.Sprintf(cli.RemoteUnsupportedMsg, commandStr) + cli.Println(msg) + return 1 + } } // Argument validation in the CLI is strictly nice to have. 
The stored procedure will do the same, but the errors @@ -165,8 +167,8 @@ func (cmd CheckoutCmd) Exec(ctx context.Context, commandStr string, args []strin return HandleVErrAndExitCode(errhand.BuildDError("no 'message' field in response from %s", sqlQuery).Build(), usage) } - var message string - if message, ok = rows[0][1].(string); !ok { + message, ok := rows[0][1].(string) + if !ok { return HandleVErrAndExitCode(errhand.BuildDError("expected string value for 'message' field in response from %s ", sqlQuery).Build(), usage) } diff --git a/go/cmd/dolt/commands/cherry-pick.go b/go/cmd/dolt/commands/cherry-pick.go index 1414d4dbe1..d45272855e 100644 --- a/go/cmd/dolt/commands/cherry-pick.go +++ b/go/cmd/dolt/commands/cherry-pick.go @@ -20,8 +20,6 @@ import ( "strings" "github.com/dolthub/go-mysql-server/sql" - "github.com/gocraft/dbr/v2" - "github.com/gocraft/dbr/v2/dialect" "gopkg.in/src-d/go-errors.v1" "github.com/dolthub/dolt/go/cmd/dolt/cli" @@ -122,11 +120,11 @@ func (cmd CherryPickCmd) Exec(ctx context.Context, commandStr string, args []str return HandleVErrAndExitCode(errhand.BuildDError("cherry-picking multiple commits is not supported yet").SetPrintUsage().Build(), usage) } - err = cherryPick(queryist, sqlCtx, apr) + err = cherryPick(queryist, sqlCtx, apr, args) return HandleVErrAndExitCode(errhand.VerboseErrorFromError(err), usage) } -func cherryPick(queryist cli.Queryist, sqlCtx *sql.Context, apr *argparser.ArgParseResults) error { +func cherryPick(queryist cli.Queryist, sqlCtx *sql.Context, apr *argparser.ArgParseResults, args []string) error { cherryStr := apr.Arg(0) if len(cherryStr) == 0 { return fmt.Errorf("error: cannot cherry-pick empty string") @@ -154,7 +152,7 @@ hint: commit your changes (dolt commit -am \"\") or reset them (dolt re return fmt.Errorf("error: failed to set @@dolt_force_transaction_commit: %w", err) } - q, err := dbr.InterpolateForDialect("call dolt_cherry_pick(?)", []interface{}{cherryStr}, dialect.MySQL) + q, err := interpolateStoredProcedureCall("DOLT_CHERRY_PICK", args) if err != nil { return fmt.Errorf("error: failed to interpolate query: %w", err) } @@ -200,7 +198,7 @@ hint: commit your changes (dolt commit -am \"\") or reset them (dolt re if succeeded { // on success, print the commit info commit, err := getCommitInfo(queryist, sqlCtx, commitHash) - if err != nil { + if commit == nil || err != nil { return fmt.Errorf("error: failed to get commit metadata for ref '%s': %v", commitHash, err) } diff --git a/go/cmd/dolt/commands/log.go b/go/cmd/dolt/commands/log.go index 03e09226f0..455b0fd917 100644 --- a/go/cmd/dolt/commands/log.go +++ b/go/cmd/dolt/commands/log.go @@ -284,6 +284,9 @@ func logCommits(apr *argparser.ArgParseResults, commitHashes []sql.Row, queryist for _, hash := range commitHashes { cmHash := hash[0].(string) commit, err := getCommitInfo(queryist, sqlCtx, cmHash) if err != nil { return err } + if commit == nil { + return fmt.Errorf("no commits found for ref %s", cmHash) + } diff --git a/go/cmd/dolt/commands/log_graph.go b/go/cmd/dolt/commands/log_graph.go index d2c8fbc7ba..2d25cd83b2 100644 --- a/go/cmd/dolt/commands/log_graph.go +++ b/go/cmd/dolt/commands/log_graph.go @@ -391,7 +391,7 @@ func printOneLineGraph(graph [][]string, pager *outputpager.Pager, apr *argparse pager.Writer.Write([]byte("\n")) } - pager.Writer.Write([]byte(fmt.Sprintf("%s %s ", strings.Join(graph[commits[i].Row], ""), color.YellowString("commit%s ", commits[i].Commit.commitHash)))) + pager.Writer.Write([]byte(fmt.Sprintf("%s %s ", 
strings.Join(graph[commits[i].Row], ""), color.YellowString("commit %s ", commits[i].Commit.commitHash)))) if decoration != "no" { printRefs(pager, &commits[i].Commit, decoration) } @@ -436,7 +436,7 @@ func printGraphAndCommitsInfo(graph [][]string, pager *outputpager.Pager, apr *a last_commit_row := commits[len(commits)-1].Row printCommitMetadata(graph, pager, last_commit_row, len(graph[last_commit_row]), commits[len(commits)-1], decoration) for _, line := range commits[len(commits)-1].formattedMessage { - pager.Writer.Write([]byte(color.WhiteString("\t", line))) + pager.Writer.Write([]byte(color.WhiteString("\t%s", line))) pager.Writer.Write([]byte("\n")) } } @@ -556,7 +556,7 @@ func drawCommitDotsAndBranchPaths(commits []*commitInfoWithChildren, commitsMap } for i := col + 1; i < parent.Col-verticalDistance+1; i++ { if graph[row][i] == " " { - graph[row][i] = branchColor.Sprintf("s-") + graph[row][i] = branchColor.Sprintf("-") } } } diff --git a/go/cmd/dolt/commands/merge.go b/go/cmd/dolt/commands/merge.go index 1f5892a755..8666de33d3 100644 --- a/go/cmd/dolt/commands/merge.go +++ b/go/cmd/dolt/commands/merge.go @@ -106,12 +106,12 @@ func (cmd MergeCmd) Exec(ctx context.Context, commandStr string, args []string, } if closeFunc != nil { defer closeFunc() - } - if _, ok := queryist.(*engine.SqlEngine); !ok { - msg := fmt.Sprintf(cli.RemoteUnsupportedMsg, commandStr) - cli.Println(msg) - return 1 + if _, ok := queryist.(*engine.SqlEngine); !ok { + msg := fmt.Sprintf(cli.RemoteUnsupportedMsg, commandStr) + cli.Println(msg) + return 1 + } } ok := validateDoltMergeArgs(apr, usage, cliCtx) diff --git a/go/cmd/dolt/commands/rebase.go b/go/cmd/dolt/commands/rebase.go index d73e21834e..cc56d1137a 100644 --- a/go/cmd/dolt/commands/rebase.go +++ b/go/cmd/dolt/commands/rebase.go @@ -23,8 +23,6 @@ import ( "strings" "github.com/dolthub/go-mysql-server/sql" - "github.com/gocraft/dbr/v2" - "github.com/gocraft/dbr/v2/dialect" "github.com/dolthub/dolt/go/cmd/dolt/cli" "github.com/dolthub/dolt/go/cmd/dolt/errhand" @@ -102,7 +100,7 @@ func (cmd RebaseCmd) Exec(ctx context.Context, commandStr string, args []string, return HandleVErrAndExitCode(errhand.VerboseErrorFromError(err), usage) } - query, err := constructInterpolatedDoltRebaseQuery(apr) + query, err := interpolateStoredProcedureCall("DOLT_REBASE", args) if err != nil { return HandleVErrAndExitCode(errhand.VerboseErrorFromError(err), usage) } @@ -181,30 +179,6 @@ func (cmd RebaseCmd) Exec(ctx context.Context, commandStr string, args []string, return HandleVErrAndExitCode(nil, usage) } -// constructInterpolatedDoltRebaseQuery generates the sql query necessary to call the DOLT_REBASE() function. -// Also interpolates this query to prevent sql injection. -func constructInterpolatedDoltRebaseQuery(apr *argparser.ArgParseResults) (string, error) { - var params []interface{} - var args []string - - if apr.NArg() == 1 { - params = append(params, apr.Arg(0)) - args = append(args, "?") - } - if apr.Contains(cli.InteractiveFlag) { - args = append(args, "'--interactive'") - } - if apr.Contains(cli.ContinueFlag) { - args = append(args, "'--continue'") - } - if apr.Contains(cli.AbortParam) { - args = append(args, "'--abort'") - } - - query := fmt.Sprintf("CALL DOLT_REBASE(%s);", strings.Join(args, ", ")) - return dbr.InterpolateForDialect(query, params, dialect.MySQL) -} - // getRebasePlan opens an editor for users to edit the rebase plan and returns the parsed rebase plan from the editor. 
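Both rebase (above) and cherry-pick now funnel their raw CLI arguments through the shared `interpolateStoredProcedureCall` helper added to utils.go later in this diff. A hypothetical unit test sketching the expected behavior — the quoting comes from dbr's MySQL dialect, and the test itself is illustrative, not part of the change:

```go
import (
	"testing"

	"github.com/stretchr/testify/require"
)

func TestInterpolateStoredProcedureCall(t *testing.T) {
	// buildPlaceholdersString(2) yields "?, ?", and dbr interpolation then
	// quotes each argument for the MySQL dialect.
	query, err := interpolateStoredProcedureCall("DOLT_REBASE", []string{"--interactive", "main"})
	require.NoError(t, err)
	require.Equal(t, "CALL DOLT_REBASE('--interactive', 'main');", query)
}
```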
func getRebasePlan(cliCtx cli.CliContext, sqlCtx *sql.Context, queryist cli.Queryist, rebaseBranch, currentBranch string) (*rebase.RebasePlan, error) { if cli.ExecuteWithStdioRestored == nil { diff --git a/go/cmd/dolt/commands/show.go b/go/cmd/dolt/commands/show.go index 43a7208af7..3e6cde95ad 100644 --- a/go/cmd/dolt/commands/show.go +++ b/go/cmd/dolt/commands/show.go @@ -30,19 +30,18 @@ import ( "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" "github.com/dolthub/dolt/go/libraries/doltcore/env" "github.com/dolthub/dolt/go/libraries/utils/argparser" - "github.com/dolthub/dolt/go/store/datas" "github.com/dolthub/dolt/go/store/hash" + "github.com/dolthub/dolt/go/store/types" "github.com/dolthub/dolt/go/store/util/outputpager" ) var hashRegex = regexp.MustCompile(`^#?[0-9a-v]{32}$`) type showOpts struct { - showParents bool - pretty bool - decoration string - specRefs []string - resolvedNonCommitSpecs map[string]string + showParents bool + pretty bool + decoration string + specRefs []string *diffDisplaySettings } @@ -120,33 +119,6 @@ func (cmd ShowCmd) Exec(ctx context.Context, commandStr string, args []string, d opts.diffDisplaySettings = parseDiffDisplaySettings(apr) - // Decide if we're going to use dolt or sql for this execution. - // We can use SQL in the following cases: - // 1. `--no-pretty` is not set, so we will be producing "pretty" output. - // 2. opts.specRefs contains values that are NOT commit hashes. - // In all other cases, we should use DoltEnv - allSpecRefsAreCommits := true - allSpecRefsAreNonCommits := true - resolvedNonCommitSpecs := map[string]string{} - for _, specRef := range opts.specRefs { - isNonCommitSpec, resolvedValue, err := resolveNonCommitSpec(ctx, dEnv, specRef) - if err != nil { - err = fmt.Errorf("error resolving spec ref '%s': %w", specRef, err) - return handleErrAndExit(err) - } - allSpecRefsAreNonCommits = allSpecRefsAreNonCommits && isNonCommitSpec - allSpecRefsAreCommits = allSpecRefsAreCommits && !isNonCommitSpec - - if isNonCommitSpec { - resolvedNonCommitSpecs[specRef] = resolvedValue - } - } - - if !allSpecRefsAreCommits && !allSpecRefsAreNonCommits { - err = fmt.Errorf("cannot mix commit hashes and non-commit spec refs") - return handleErrAndExit(err) - } - queryist, sqlCtx, closeFunc, err := cliCtx.QueryEngine(ctx) if err != nil { return HandleVErrAndExitCode(errhand.VerboseErrorFromError(err), usage) @@ -155,81 +127,138 @@ func (cmd ShowCmd) Exec(ctx context.Context, commandStr string, args []string, d defer closeFunc() } - useDoltEnv := !opts.pretty || (len(opts.specRefs) > 0 && allSpecRefsAreNonCommits) - if useDoltEnv { + // There are two response formats: + // - "pretty", which shows commits in a human-readable fashion + // - "raw", which shows the underlying SerialMessage + // All responses should be in the same format. + // The pretty format is preferred unless the --no-pretty flag is provided. + // But only commits are supported in the "pretty" format. + // Thus, if the --no-pretty flag is not set, then we require that either all the references are commits, or none of them are. 
+ + isDEnvRequired := false + + if !opts.pretty { + isDEnvRequired = true + } + for _, specRef := range opts.specRefs { + upperCaseSpecRef := strings.ToUpper(specRef) + if !hashRegex.MatchString(specRef) && upperCaseSpecRef != "HEAD" { + isDEnvRequired = true + } + } + + if isDEnvRequired { + // use dEnv instead of the SQL engine _, ok := queryist.(*engine.SqlEngine) if !ok { - cli.PrintErrln("`dolt show --no-pretty` or `dolt show NON_COMMIT_REF` only supported in local mode.") + cli.PrintErrln("`dolt show --no-pretty` or `dolt show (BRANCHNAME)` only supported in local mode.") return 1 } if !opts.pretty && !dEnv.DoltDB.Format().UsesFlatbuffers() { - cli.PrintErrln("dolt show --no-pretty is not supported when using old LD_1 storage format.") + cli.PrintErrln("`dolt show --no-pretty` or `dolt show (BRANCHNAME)` is not supported when using old LD_1 storage format.") return 1 } - opts.resolvedNonCommitSpecs = resolvedNonCommitSpecs - err = printObjects(ctx, dEnv, opts) - return handleErrAndExit(err) - } else { - err = printObjectsPretty(queryist, sqlCtx, opts) - return handleErrAndExit(err) } -} -// resolveNonCommitSpec resolves a non-commit spec ref. -// A non-commit spec ref in this context is a ref that is returned by `dolt show --no-pretty` but is NOT a commit hash. -// These refs need env.DoltEnv in order to be resolved to a human-readable value. -func resolveNonCommitSpec(ctx context.Context, dEnv *env.DoltEnv, specRef string) (isNonCommitSpec bool, resolvedValue string, err error) { - isNonCommitSpec = false - resolvedValue = "" + specRefs := opts.specRefs + if len(specRefs) == 0 { + specRefs = []string{"HEAD"} + } - roots, err := dEnv.Roots(ctx) + for _, specRef := range specRefs { + // If --no-pretty was supplied, always display the raw contents of the referenced object. + if !opts.pretty { + err := printRawValue(ctx, dEnv, specRef) + if err != nil { + return handleErrAndExit(err) + } + continue + } + + // If the argument is a commit, display it in the "pretty" format. + // But if it's a hash, we don't know whether it's a commit until we query the engine. 
+ commitInfo, err := getCommitSpecPretty(queryist, sqlCtx, opts, specRef) + if commitInfo == nil { + // Hash is not a commit + _, ok := queryist.(*engine.SqlEngine) + if !ok { + cli.PrintErrln("`dolt show (NON_COMMIT_HASH)` only supported in local mode.") + return 1 + } + + if !opts.pretty && !dEnv.DoltDB.Format().UsesFlatbuffers() { + cli.PrintErrln("`dolt show (NON_COMMIT_HASH)` is not supported when using old LD_1 storage format.") + return 1 + } + value, err := getValueFromRefSpec(ctx, dEnv, specRef) + if err != nil { + return handleErrAndExit(fmt.Errorf("error resolving spec ref '%s': %w", specRef, err)) + } + cli.Println(value.Kind(), value.HumanReadableString()) + continue + } else { + // Hash is a commit + err = fetchAndPrintCommit(queryist, sqlCtx, opts, commitInfo) + if err != nil { + return handleErrAndExit(err) + } + continue + } + } + return 0 +} + +func printRawValue(ctx context.Context, dEnv *env.DoltEnv, specRef string) error { + value, err := getValueFromRefSpec(ctx, dEnv, specRef) if err != nil { - return isNonCommitSpec, resolvedValue, err + return fmt.Errorf("error resolving spec ref '%s': %w", specRef, err) } + cli.Println(value.Kind(), value.HumanReadableString()) + return nil +} +func getValueFromRefSpec(ctx context.Context, dEnv *env.DoltEnv, specRef string) (types.Value, error) { + var refHash hash.Hash + var err error + roots, err := dEnv.Roots(ctx) + if err != nil { + return nil, err + } upperCaseSpecRef := strings.ToUpper(specRef) - if upperCaseSpecRef == doltdb.Working || upperCaseSpecRef == doltdb.Staged || hashRegex.MatchString(specRef) { - var refHash hash.Hash - var err error - if upperCaseSpecRef == doltdb.Working { - refHash, err = roots.Working.HashOf() - } else if upperCaseSpecRef == doltdb.Staged { - refHash, err = roots.Staged.HashOf() - } else { - refHash, err = parseHashString(specRef) - } + if upperCaseSpecRef == doltdb.Working { + refHash, err = roots.Working.HashOf() + } else if upperCaseSpecRef == doltdb.Staged { + refHash, err = roots.Staged.HashOf() + } else if hashRegex.MatchString(specRef) { + refHash, err = parseHashString(specRef) + } else { + commitSpec, err := doltdb.NewCommitSpec(specRef) if err != nil { - return isNonCommitSpec, resolvedValue, err + return nil, err } - value, err := dEnv.DoltDB.ValueReadWriter().ReadValue(ctx, refHash) + headRef, err := dEnv.RepoStateReader().CWBHeadRef() + if err != nil { + return nil, err + } + optionalCommit, err := dEnv.DoltDB.Resolve(ctx, commitSpec, headRef) if err != nil { - return isNonCommitSpec, resolvedValue, err - } - if value == nil { - return isNonCommitSpec, resolvedValue, fmt.Errorf("Unable to resolve object ref %s", specRef) + return nil, err }
- _, err = doltdb.NewCommitFromValue(ctx, dEnv.DoltDB.ValueReadWriter(), dEnv.DoltDB.NodeStore(), value) - - if err == datas.ErrNotACommit { - if !dEnv.DoltDB.Format().UsesFlatbuffers() { - return isNonCommitSpec, resolvedValue, fmt.Errorf("dolt show cannot show non-commit objects when using the old LD_1 storage format: %s is not a commit", specRef) - } - isNonCommitSpec = true - resolvedValue = fmt.Sprintln(value.Kind(), value.HumanReadableString()) - return isNonCommitSpec, resolvedValue, nil - } else if err == nil { - isNonCommitSpec = false - return isNonCommitSpec, resolvedValue, nil - } else { - return isNonCommitSpec, resolvedValue, err + commit, ok := optionalCommit.ToCommit() + if !ok { + return nil, doltdb.ErrGhostCommitEncountered } - } else { // specRef is a CommitSpec, which must resolve to a Commit. - isNonCommitSpec = false - return isNonCommitSpec, resolvedValue, nil + return commit.Value(), nil + } + if err != nil { + return nil, err + } + value, err := dEnv.DoltDB.ValueReadWriter().ReadValue(ctx, refHash) + if err != nil { + return nil, err + } + if value == nil { + return nil, fmt.Errorf("Unable to resolve object ref %s", specRef) } + return value, nil } func (cmd ShowCmd) validateArgs(apr *argparser.ArgParseResults) errhand.VerboseError { @@ -266,57 +295,6 @@ func parseShowArgs(apr *argparser.ArgParseResults) (*showOpts, error) { }, nil } -func printObjects(ctx context.Context, dEnv *env.DoltEnv, opts *showOpts) error { - if len(opts.specRefs) == 0 { - headSpec, err := dEnv.RepoStateReader().CWBHeadSpec() - if err != nil { - return err - } - - headRef, err := dEnv.RepoStateReader().CWBHeadRef() - if err != nil { - return err - } - - optCmt, err := dEnv.DoltDB.Resolve(ctx, headSpec, headRef) - if err != nil { - return err - } - commit, ok := optCmt.ToCommit() - if !ok { - return doltdb.ErrGhostCommitEncountered - } - - value := commit.Value() - cli.Println(value.Kind(), value.HumanReadableString()) - } - - for _, specRef := range opts.specRefs { - resolvedValue, ok := opts.resolvedNonCommitSpecs[specRef] - if !ok { - return fmt.Errorf("fatal: unable to resolve object ref %s", specRef) - } - cli.Println(resolvedValue) - } - - return nil -} - -func printObjectsPretty(queryist cli.Queryist, sqlCtx *sql.Context, opts *showOpts) error { - if len(opts.specRefs) == 0 { - return printCommitSpecPretty(queryist, sqlCtx, opts, "HEAD") - } - - for _, specRef := range opts.specRefs { - err := printCommitSpecPretty(queryist, sqlCtx, opts, specRef) - if err != nil { - return err - } - } - - return nil -} - // parseHashString converts a string representing a hash into a hash.Hash. 
func parseHashString(hashStr string) (hash.Hash, error) { unprefixed := strings.TrimPrefix(hashStr, "#") @@ -327,24 +305,19 @@ func parseHashString(hashStr string) (hash.Hash, error) { return parsedHash, nil } -func printCommitSpecPretty(queryist cli.Queryist, sqlCtx *sql.Context, opts *showOpts, commitRef string) error { +func getCommitSpecPretty(queryist cli.Queryist, sqlCtx *sql.Context, opts *showOpts, commitRef string) (commit *CommitInfo, err error) { if strings.HasPrefix(commitRef, "#") { commitRef = strings.TrimPrefix(commitRef, "#") } - commit, err := getCommitInfo(queryist, sqlCtx, commitRef) + commit, err = getCommitInfo(queryist, sqlCtx, commitRef) if err != nil { - return fmt.Errorf("error: failed to get commit metadata for ref '%s': %v", commitRef, err) + return commit, fmt.Errorf("error: failed to get commit metadata for ref '%s': %v", commitRef, err) } - - err = printCommit(queryist, sqlCtx, opts, commit) - if err != nil { - return err - } - return nil + return } -func printCommit(queryist cli.Queryist, sqlCtx *sql.Context, opts *showOpts, commit *CommitInfo) error { +func fetchAndPrintCommit(queryist cli.Queryist, sqlCtx *sql.Context, opts *showOpts, commit *CommitInfo) error { cmHash := commit.commitHash parents := commit.parentHashes diff --git a/go/cmd/dolt/commands/sql.go b/go/cmd/dolt/commands/sql.go index 39f28d0ef3..6646f5062a 100644 --- a/go/cmd/dolt/commands/sql.go +++ b/go/cmd/dolt/commands/sql.go @@ -786,7 +786,7 @@ func execShell(sqlCtx *sql.Context, qryist cli.Queryist, format engine.PrintResu sqlCtx := sql.NewContext(subCtx, sql.WithSession(sqlCtx.Session)) - subCmd, foundCmd := findSlashCmd(query) + subCmd, foundCmd := isSlashQuery(query) if foundCmd { err := handleSlashCommand(sqlCtx, subCmd, query, cliCtx) if err != nil { @@ -831,6 +831,15 @@ func execShell(sqlCtx *sql.Context, qryist cli.Queryist, format engine.PrintResu return nil } +func isSlashQuery(query string) (cli.Command, bool) { + // strip leading whitespace + query = strings.TrimLeft(query, " \t\n\r\v\f") + if strings.HasPrefix(query, "\\") { + return findSlashCmd(query[1:]) + } + return nil, false +} + // postCommandUpdate is a helper function that is run after the shell has completed a command. It updates the database // if needed, and generates new prompts for the shell (based on the branch and if the workspace is dirty). 
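Unlike the old direct `findSlashCmd(query)` call, the new wrapper tolerates leading whitespace and strips the backslash before the lookup. Two illustrative calls (hypothetical; assumes `\checkout` resolves through the slashCmds table extended in sql_slash.go below):

```go
// Leading whitespace is trimmed, then the "\" prefix selects slash dispatch.
if cmd, found := isSlashQuery("   \\checkout main"); found {
	fmt.Printf("dispatching slash command %T\n", cmd) // e.g. commands.CheckoutCmd
}
if _, found := isSlashQuery("SELECT 1;"); !found {
	fmt.Println("not a slash command; executed as ordinary SQL")
}
```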
func postCommandUpdate(sqlCtx *sql.Context, qryist cli.Queryist) (string, string) { diff --git a/go/cmd/dolt/commands/sql_slash.go b/go/cmd/dolt/commands/sql_slash.go index fa1d1c8088..f1d392a844 100644 --- a/go/cmd/dolt/commands/sql_slash.go +++ b/go/cmd/dolt/commands/sql_slash.go @@ -31,6 +31,7 @@ var slashCmds = []cli.Command{ StatusCmd{}, DiffCmd{}, LogCmd{}, + ShowCmd{}, AddCmd{}, CommitCmd{}, CheckoutCmd{}, diff --git a/go/cmd/dolt/commands/utils.go b/go/cmd/dolt/commands/utils.go index a2b9d04b5a..9c60d717f9 100644 --- a/go/cmd/dolt/commands/utils.go +++ b/go/cmd/dolt/commands/utils.go @@ -631,14 +631,14 @@ func printRefs(pager *outputpager.Pager, comm *CommitInfo, decoration string) { yellow := color.New(color.FgYellow) boldCyan := color.New(color.FgCyan, color.Bold) - pager.Writer.Write([]byte(yellow.Sprintf(" ("))) + pager.Writer.Write([]byte(yellow.Sprintf("("))) if comm.isHead { pager.Writer.Write([]byte(boldCyan.Sprintf("HEAD -> "))) } joinedReferences := strings.Join(references, yellow.Sprint(", ")) - pager.Writer.Write([]byte(yellow.Sprintf("%s)", joinedReferences))) + pager.Writer.Write([]byte(yellow.Sprintf("%s) ", joinedReferences))) } // getCommitInfo returns the commit info for the given ref. @@ -657,7 +657,8 @@ func getCommitInfo(queryist cli.Queryist, sqlCtx *sql.Context, ref string) (*Com return nil, fmt.Errorf("error getting logs for ref '%s': %v", ref, err) } if len(rows) == 0 { - return nil, fmt.Errorf("no commits found for ref %s", ref) + // No commit with this hash exists + return nil, nil } row := rows[0] @@ -826,3 +827,25 @@ func HandleVErrAndExitCode(verr errhand.VerboseError, usage cli.UsagePrinter) in return 0 } + +// interpolateStoredProcedureCall returns an interpolated query to call |storedProcedureName| with the arguments +// |args|. +func interpolateStoredProcedureCall(storedProcedureName string, args []string) (string, error) { + query := fmt.Sprintf("CALL %s(%s);", storedProcedureName, buildPlaceholdersString(len(args))) + return dbr.InterpolateForDialect(query, stringSliceToInterfaceSlice(args), dialect.MySQL) +} + +// stringSliceToInterfaceSlice converts the string slice |ss| into an interface slice with the same values. +func stringSliceToInterfaceSlice(ss []string) []interface{} { + retSlice := make([]interface{}, 0, len(ss)) + for _, s := range ss { + retSlice = append(retSlice, s) + } + return retSlice +} + +// buildPlaceholdersString returns a placeholder string to use in an interpolated query with the specified +// |count| of parameter placeholders. +func buildPlaceholdersString(count int) string { + return strings.Join(make([]string, count), "?, ") + "?" +} diff --git a/go/cmd/dolt/commands/version.go b/go/cmd/dolt/commands/version.go index fe1d77fb3a..1eef9b0a8c 100644 --- a/go/cmd/dolt/commands/version.go +++ b/go/cmd/dolt/commands/version.go @@ -185,6 +185,14 @@ func checkAndPrintVersionOutOfDateWarning(curVersion string, dEnv *env.DoltEnv) } } + // If we still don't have a valid latestRelease, even after trying to query it, then skip the out of date + // check and print a warning message. This can happen for example, if we get a 403 from GitHub when + // querying for the latest release tag. 
+ if latestRelease == "" { + cli.Printf(color.YellowString("Warning: unable to query latest released Dolt version")) + return nil + } + // if there were new releases in the last week, the latestRelease stored might be behind the current version built isOutOfDate, verr := isOutOfDate(curVersion, latestRelease) if verr != nil { diff --git a/go/cmd/dolt/dolt.go b/go/cmd/dolt/dolt.go index d2ecddfe0d..1390473a4a 100644 --- a/go/cmd/dolt/dolt.go +++ b/go/cmd/dolt/dolt.go @@ -510,6 +510,11 @@ func runMain() int { return 1 } + if dEnv.CfgLoadErr != nil { + cli.PrintErrln(color.RedString("Failed to load the global config. %v", dEnv.CfgLoadErr)) + return 1 + } + strMetricsDisabled := dEnv.Config.GetStringOrDefault(config.MetricsDisabled, "false") var metricsEmitter events.Emitter metricsEmitter = events.NewFileEmitter(homeDir, dbfactory.DoltDir) @@ -520,10 +525,6 @@ func runMain() int { events.SetGlobalCollector(events.NewCollector(doltversion.Version, metricsEmitter)) - if dEnv.CfgLoadErr != nil { - cli.PrintErrln(color.RedString("Failed to load the global config. %v", dEnv.CfgLoadErr)) - return 1 - } globalConfig, ok := dEnv.Config.GetConfig(env.GlobalConfig) if !ok { cli.PrintErrln(color.RedString("Failed to get global config")) diff --git a/go/cmd/dolt/doltversion/version.go b/go/cmd/dolt/doltversion/version.go index 890d73f639..ffcd6f4cbd 100644 --- a/go/cmd/dolt/doltversion/version.go +++ b/go/cmd/dolt/doltversion/version.go @@ -16,5 +16,5 @@ package doltversion const ( - Version = "1.42.10" + Version = "1.42.13" ) diff --git a/go/gen/fb/serial/workingset.go b/go/gen/fb/serial/workingset.go index 63948a2e67..424b406ac3 100644 --- a/go/gen/fb/serial/workingset.go +++ b/go/gen/fb/serial/workingset.go @@ -531,7 +531,31 @@ func (rcv *RebaseState) MutateOntoCommitAddr(j int, n byte) bool { return false } -const RebaseStateNumFields = 3 +func (rcv *RebaseState) EmptyCommitHandling() byte { + o := flatbuffers.UOffsetT(rcv._tab.Offset(10)) + if o != 0 { + return rcv._tab.GetByte(o + rcv._tab.Pos) + } + return 0 +} + +func (rcv *RebaseState) MutateEmptyCommitHandling(n byte) bool { + return rcv._tab.MutateByteSlot(10, n) +} + +func (rcv *RebaseState) CommitBecomesEmptyHandling() byte { + o := flatbuffers.UOffsetT(rcv._tab.Offset(12)) + if o != 0 { + return rcv._tab.GetByte(o + rcv._tab.Pos) + } + return 0 +} + +func (rcv *RebaseState) MutateCommitBecomesEmptyHandling(n byte) bool { + return rcv._tab.MutateByteSlot(12, n) +} + +const RebaseStateNumFields = 5 func RebaseStateStart(builder *flatbuffers.Builder) { builder.StartObject(RebaseStateNumFields) @@ -554,6 +578,12 @@ func RebaseStateAddOntoCommitAddr(builder *flatbuffers.Builder, ontoCommitAddr f func RebaseStateStartOntoCommitAddrVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT { return builder.StartVector(1, numElems, 1) } +func RebaseStateAddEmptyCommitHandling(builder *flatbuffers.Builder, emptyCommitHandling byte) { + builder.PrependByteSlot(3, emptyCommitHandling, 0) +} +func RebaseStateAddCommitBecomesEmptyHandling(builder *flatbuffers.Builder, commitBecomesEmptyHandling byte) { + builder.PrependByteSlot(4, commitBecomesEmptyHandling, 0) +} func RebaseStateEnd(builder *flatbuffers.Builder) flatbuffers.UOffsetT { return builder.EndObject() } diff --git a/go/go.mod b/go/go.mod index 79e4a8a597..11ac2a254c 100644 --- a/go/go.mod +++ b/go/go.mod @@ -57,7 +57,7 @@ require ( github.com/cespare/xxhash/v2 v2.2.0 github.com/creasty/defaults v1.6.0 github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2 - 
github.com/dolthub/go-mysql-server v0.18.2-0.20240808231249-e035ac0ed25a + github.com/dolthub/go-mysql-server v0.18.2-0.20240819234152-ac84f8593e99 github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63 github.com/dolthub/swiss v0.1.0 github.com/goccy/go-json v0.10.2 diff --git a/go/go.sum b/go/go.sum index 3c2efd039c..0337921a31 100644 --- a/go/go.sum +++ b/go/go.sum @@ -183,8 +183,8 @@ github.com/dolthub/fslock v0.0.3 h1:iLMpUIvJKMKm92+N1fmHVdxJP5NdyDK5bK7z7Ba2s2U= github.com/dolthub/fslock v0.0.3/go.mod h1:QWql+P17oAAMLnL4HGB5tiovtDuAjdDTPbuqx7bYfa0= github.com/dolthub/go-icu-regex v0.0.0-20230524105445-af7e7991c97e h1:kPsT4a47cw1+y/N5SSCkma7FhAPw7KeGmD6c9PBZW9Y= github.com/dolthub/go-icu-regex v0.0.0-20230524105445-af7e7991c97e/go.mod h1:KPUcpx070QOfJK1gNe0zx4pA5sicIK1GMikIGLKC168= -github.com/dolthub/go-mysql-server v0.18.2-0.20240808231249-e035ac0ed25a h1:t5lkm+LGwj8xnDs+jiONt26fAhtWG/Blk0Ucvr8gN8w= -github.com/dolthub/go-mysql-server v0.18.2-0.20240808231249-e035ac0ed25a/go.mod h1:PwuemL+YK+YiWcUFhknixeqNLjJNfCx7KDsHNajx9fM= +github.com/dolthub/go-mysql-server v0.18.2-0.20240819234152-ac84f8593e99 h1:GKa4Wu7SS0FDJkdGBRR2hCMXXlqdNmsqraRl+ZKYW4U= +github.com/dolthub/go-mysql-server v0.18.2-0.20240819234152-ac84f8593e99/go.mod h1:nbdOzd0ceWONE80vbfwoRBjut7z3CIj69ZgDF/cKuaA= github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63 h1:OAsXLAPL4du6tfbBgK0xXHZkOlos63RdKYS3Sgw/dfI= github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63/go.mod h1:lV7lUeuDhH5thVGDCKXbatwKy2KW80L4rMT46n+Y2/Q= github.com/dolthub/ishell v0.0.0-20240701202509-2b217167d718 h1:lT7hE5k+0nkBdj/1UOSFwjWpNxf+LCApbRHgnCA17XE= diff --git a/go/go.work.sum b/go/go.work.sum index f740cf0835..37a6dc28ca 100644 --- a/go/go.work.sum +++ b/go/go.work.sum @@ -318,6 +318,10 @@ github.com/cpuguy83/go-md2man/v2 v2.0.2 h1:p1EgwI/C7NhT0JmVkwCD2ZBK8j4aeHQX2pMHH github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/creack/pty v1.1.9 h1:uDmaGzcdjhF4i/plgjmEsriH11Y0o7RKapEf/LDaM3w= github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM= +github.com/dolthub/go-mysql-server v0.18.2-0.20240812011431-f3892cc42bbf h1:F4OT8cjaQzGlLne9vp7/q0i5QFsQE2OUWIaL5thO5qA= +github.com/dolthub/go-mysql-server v0.18.2-0.20240812011431-f3892cc42bbf/go.mod h1:PwuemL+YK+YiWcUFhknixeqNLjJNfCx7KDsHNajx9fM= +github.com/dolthub/vitess v0.0.0-20240807181005-71d735078e24 h1:/zCd98CLZURqK85jQ+qRmEMx/dpXz85F1/Et7gqMGkk= +github.com/dolthub/vitess v0.0.0-20240807181005-71d735078e24/go.mod h1:uBvlRluuL+SbEWTCZ68o0xvsdYZER3CEG/35INdzfJM= github.com/eapache/go-resiliency v1.1.0 h1:1NtRmCAqadE2FN4ZcN6g90TP3uk8cg9rn9eNK2197aU= github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21 h1:YEetp8/yCZMuEPMUDHG0CW/brkkEp8mzqk2+ODEitlw= github.com/eapache/queue v1.1.0 h1:YOEu7KNc61ntiQlcEeUIoDTJ2o8mQznoNvUhiigpIqc= diff --git a/go/libraries/doltcore/cherry_pick/cherry_pick.go b/go/libraries/doltcore/cherry_pick/cherry_pick.go index 1b6af8d27b..1952741405 100644 --- a/go/libraries/doltcore/cherry_pick/cherry_pick.go +++ b/go/libraries/doltcore/cherry_pick/cherry_pick.go @@ -37,11 +37,38 @@ type CherryPickOptions struct { // CommitMessage is optional, and controls the message for the new commit. CommitMessage string + + // CommitBecomesEmptyHandling describes how commits that do not start off as empty, but become empty after applying + // the changes, should be handled. 
For example, if cherry-picking a change from another branch, but the changes + // have already been applied on the target branch in another commit, the new commit will be empty. Note that this + // is distinct from how to handle commits that start off empty. By default, in Git, the cherry-pick command will + // stop when processing a commit that becomes empty and allow the user to take additional action. Dolt doesn't + // support this flow, so instead, Dolt's default is to fail the cherry-pick operation. In Git rebase, and in Dolt + // rebase, the default for handling commits that become empty while being processed is to drop them. + CommitBecomesEmptyHandling doltdb.EmptyCommitHandling + + // EmptyCommitHandling describes how commits that start off as empty should be handled. Note that this is distinct + // from how to handle commits that start off with changes, but become empty after applying the changes. In Git + // and Dolt cherry-pick implementations, the default action is to fail when an empty commit is specified. In Git + // and Dolt rebase implementations, the default action is to keep commits that start off as empty. + EmptyCommitHandling doltdb.EmptyCommitHandling +} + +// NewCherryPickOptions creates a new CherryPickOptions instance, filled out with default values for cherry-pick. +func NewCherryPickOptions() CherryPickOptions { + return CherryPickOptions{ + Amend: false, + CommitMessage: "", + CommitBecomesEmptyHandling: doltdb.ErrorOnEmptyCommit, + EmptyCommitHandling: doltdb.ErrorOnEmptyCommit, + } } // CherryPick replays a commit, specified by |options.Commit|, and applies it as a new commit to the current HEAD. If -// successful, the hash of the new commit is returned. If the cherry-pick results in merge conflicts, the merge result -// is returned. If any unexpected error occur, it is returned. +// successful and a new commit is created, the hash of the new commit is returned. If successful, but no new commit +// was created (for example, when dropping an empty commit), then the first return parameter will be the empty string. +// If the cherry-pick results in merge conflicts, the merge result is returned. If the operation is not successful for +// any reason, then the error return parameter will be populated. 
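To make the new options concrete, here is a sketch of how a caller could map the CLI's `--allow-empty` flag onto them. This is illustrative only — the real wiring goes through the `dolt_cherry_pick` stored procedure, and `applyCherryPickFlags` is a hypothetical helper:

```go
import (
	"github.com/dolthub/go-mysql-server/sql"

	"github.com/dolthub/dolt/go/cmd/dolt/cli"
	"github.com/dolthub/dolt/go/libraries/doltcore/cherry_pick"
	"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
	"github.com/dolthub/dolt/go/libraries/doltcore/merge"
	"github.com/dolthub/dolt/go/libraries/utils/argparser"
)

func applyCherryPickFlags(ctx *sql.Context, apr *argparser.ArgParseResults) (string, *merge.Result, error) {
	opts := cherry_pick.NewCherryPickOptions()
	if apr.Contains(cli.AllowEmptyFlag) {
		// --allow-empty keeps commits that *start* empty; commits that *become*
		// empty still follow CommitBecomesEmptyHandling, which stays at the
		// cherry-pick default of ErrorOnEmptyCommit.
		opts.EmptyCommitHandling = doltdb.KeepEmptyCommit
	}
	// Per the doc comment above: an empty hash with a nil error means the
	// commit was dropped; a non-nil merge.Result means conflicts to resolve.
	return cherry_pick.CherryPick(ctx, apr.Arg(0), opts)
}
```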
func CherryPick(ctx *sql.Context, commit string, options CherryPickOptions) (string, *merge.Result, error) { doltSession := dsess.DSessFromSess(ctx.Session) dbName := ctx.GetCurrentDatabase() @@ -51,7 +78,7 @@ func CherryPick(ctx *sql.Context, commit string, options CherryPickOptions) (str return "", nil, fmt.Errorf("failed to get roots for current session") } - mergeResult, commitMsg, err := cherryPick(ctx, doltSession, roots, dbName, commit) + mergeResult, commitMsg, err := cherryPick(ctx, doltSession, roots, dbName, commit, options.EmptyCommitHandling) if err != nil { return "", mergeResult, err } @@ -94,6 +121,17 @@ func CherryPick(ctx *sql.Context, commit string, options CherryPickOptions) (str if options.Amend { commitProps.Amend = true } + if options.EmptyCommitHandling == doltdb.KeepEmptyCommit { + commitProps.AllowEmpty = true + } + + if options.CommitBecomesEmptyHandling == doltdb.DropEmptyCommit { + commitProps.SkipEmpty = true + } else if options.CommitBecomesEmptyHandling == doltdb.KeepEmptyCommit { + commitProps.AllowEmpty = true + } else if options.CommitBecomesEmptyHandling == doltdb.StopOnEmptyCommit { + return "", nil, fmt.Errorf("stop on empty commit is not currently supported") + } // NOTE: roots are old here (after staging the tables) and need to be refreshed roots, ok = doltSession.GetRoots(ctx, dbName) @@ -106,7 +144,11 @@ func CherryPick(ctx *sql.Context, commit string, options CherryPickOptions) (str return "", nil, err } if pendingCommit == nil { - return "", nil, errors.New("nothing to commit") + if commitProps.SkipEmpty { + return "", nil, nil + } else if !commitProps.AllowEmpty { + return "", nil, errors.New("nothing to commit") + } } newCommit, err := doltSession.DoltCommit(ctx, dbName, doltSession.GetTransaction(), pendingCommit) @@ -166,7 +208,7 @@ func AbortCherryPick(ctx *sql.Context, dbName string) error { // cherryPick checks that the current working set is clean, verifies the cherry-pick commit is not a merge commit // or a commit without parent commit, performs merge and returns the new working set root value and // the commit message of cherry-picked commit as the commit message of the new commit created during this command. -func cherryPick(ctx *sql.Context, dSess *dsess.DoltSession, roots doltdb.Roots, dbName, cherryStr string) (*merge.Result, string, error) { +func cherryPick(ctx *sql.Context, dSess *dsess.DoltSession, roots doltdb.Roots, dbName, cherryStr string, emptyCommitHandling doltdb.EmptyCommitHandling) (*merge.Result, string, error) { // check for clean working set wsOnlyHasIgnoredTables, err := diff.WorkingSetContainsOnlyIgnoredTables(ctx, roots) if err != nil { @@ -241,6 +283,24 @@ func cherryPick(ctx *sql.Context, dSess *dsess.DoltSession, roots doltdb.Roots, return nil, "", err } + isEmptyCommit, err := rootsEqual(cherryRoot, parentRoot) + if err != nil { + return nil, "", err + } + if isEmptyCommit { + switch emptyCommitHandling { + case doltdb.KeepEmptyCommit: + // No action; keep processing the empty commit + case doltdb.DropEmptyCommit: + return nil, "", nil + case doltdb.ErrorOnEmptyCommit: + return nil, "", fmt.Errorf("The previous cherry-pick commit is empty. 
" + + "Use --allow-empty to cherry-pick empty commits.") + default: + return nil, "", fmt.Errorf("Unsupported empty commit handling options: %v", emptyCommitHandling) + } + } + dbState, ok, err := dSess.LookupDbState(ctx, dbName) if err != nil { return nil, "", err @@ -269,7 +329,7 @@ func cherryPick(ctx *sql.Context, dSess *dsess.DoltSession, roots doltdb.Roots, } } - if headRootHash.Equal(workingRootHash) { + if headRootHash.Equal(workingRootHash) && !isEmptyCommit { return nil, "", fmt.Errorf("no changes were made, nothing to commit") } @@ -299,6 +359,20 @@ func cherryPick(ctx *sql.Context, dSess *dsess.DoltSession, roots doltdb.Roots, return result, cherryCommitMeta.Description, nil } +func rootsEqual(root1, root2 doltdb.RootValue) (bool, error) { + root1Hash, err := root1.HashOf() + if err != nil { + return false, err + } + + root2Hash, err := root2.HashOf() + if err != nil { + return false, err + } + + return root1Hash.Equal(root2Hash), nil +} + // stageCherryPickedTables stages the tables from |mergeStats| that don't have any merge artifacts – i.e. // tables that don't have any data or schema conflicts and don't have any constraint violations. func stageCherryPickedTables(ctx *sql.Context, mergeStats map[string]*merge.MergeStats) (err error) { diff --git a/go/libraries/doltcore/doltdb/system_table.go b/go/libraries/doltcore/doltdb/system_table.go index 87d59e57b2..909a95d773 100644 --- a/go/libraries/doltcore/doltdb/system_table.go +++ b/go/libraries/doltcore/doltdb/system_table.go @@ -185,16 +185,13 @@ var generatedSystemTables = []string{ RemotesTableName, } -var generatedSystemViewPrefixes = []string{ - DoltBlameViewPrefix, -} - var generatedSystemTablePrefixes = []string{ DoltDiffTablePrefix, DoltCommitDiffTablePrefix, DoltHistoryTablePrefix, DoltConfTablePrefix, DoltConstViolTablePrefix, + DoltWorkspaceTablePrefix, } const ( @@ -264,6 +261,8 @@ const ( DoltConfTablePrefix = "dolt_conflicts_" // DoltConstViolTablePrefix is the prefix assigned to all the generated constraint violation tables DoltConstViolTablePrefix = "dolt_constraint_violations_" + // DoltWorkspaceTablePrefix is the prefix assigned to all the generated workspace tables + DoltWorkspaceTablePrefix = "dolt_workspace_" ) const ( diff --git a/go/libraries/doltcore/doltdb/workingset.go b/go/libraries/doltcore/doltdb/workingset.go index 7467467a9c..5c0a42ed07 100755 --- a/go/libraries/doltcore/doltdb/workingset.go +++ b/go/libraries/doltcore/doltdb/workingset.go @@ -29,6 +29,29 @@ import ( "github.com/dolthub/dolt/go/store/types" ) +// EmptyCommitHandling describes how a cherry-pick action should handle empty commits. This applies to commits that +// start off as empty, as well as commits whose changes are applied, but are redundant, and become empty. Note that +// cherry-pick and rebase treat these two cases separately – commits that start as empty versus commits that become +// empty while being rebased or cherry-picked. +type EmptyCommitHandling int + +const ( + // ErrorOnEmptyCommit instructs a cherry-pick or rebase operation to fail with an error when an empty commit + // is encountered. + ErrorOnEmptyCommit = iota + + // DropEmptyCommit instructs a cherry-pick or rebase operation to drop empty commits and to not create new + // commits for them. + DropEmptyCommit + + // KeepEmptyCommit instructs a cherry-pick or rebase operation to keep empty commits. 
+ KeepEmptyCommit + + // StopOnEmptyCommit instructs a cherry-pick or rebase operation to stop and let the user take additional action + // to decide how to handle an empty commit. + StopOnEmptyCommit +) + // RebaseState tracks the state of an in-progress rebase action. It records the name of the branch being rebased, the // commit onto which the new commits will be rebased, and the root value of the previous working set, which is used if // the rebase is aborted and the working set needs to be restored to its previous state. @@ -36,6 +59,13 @@ type RebaseState struct { preRebaseWorking RootValue ontoCommit *Commit branch string + + // commitBecomesEmptyHandling specifies how to handle a commit that contains changes, but when cherry-picked, + // results in no changes being applied. + commitBecomesEmptyHandling EmptyCommitHandling + + // emptyCommitHandling specifies how to handle empty commits that contain no changes. + emptyCommitHandling EmptyCommitHandling } // Branch returns the name of the branch being actively rebased. This is the branch that will be updated to point @@ -55,6 +85,14 @@ func (rs RebaseState) PreRebaseWorkingRoot() RootValue { return rs.preRebaseWorking } +func (rs RebaseState) EmptyCommitHandling() EmptyCommitHandling { + return rs.emptyCommitHandling +} + +func (rs RebaseState) CommitBecomesEmptyHandling() EmptyCommitHandling { + return rs.commitBecomesEmptyHandling +} + type MergeState struct { // the source commit commit *Commit @@ -257,11 +295,13 @@ func (ws WorkingSet) StartMerge(commit *Commit, commitSpecStr string) *WorkingSe // the branch that is being rebased, and |previousRoot| is root value of the branch being rebased. The HEAD and STAGED // root values of the branch being rebased must match |previousRoot|; WORKING may be a different root value, but ONLY // if it contains only ignored tables. 
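For orientation, a rebase entry point would be expected to hand the Git-style defaults to the two new `StartRebase` parameters below. A sketch under that assumption — the actual call site lives in the dolt_rebase machinery, which is not part of this hunk:

```go
// Rebase defaults mirror Git: drop commits that become empty during the
// replay, keep commits that started out empty.
ws, err := workingSet.StartRebase(ctx, ontoCommit, branchName, preRebaseRoot,
	doltdb.DropEmptyCommit, // commitBecomesEmptyHandling: rebase default
	doltdb.KeepEmptyCommit, // emptyCommitHandling: rebase default
)
if err != nil {
	return err
}
```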
-func (ws WorkingSet) StartRebase(ctx *sql.Context, ontoCommit *Commit, branch string, previousRoot RootValue) (*WorkingSet, error) { +func (ws WorkingSet) StartRebase(ctx *sql.Context, ontoCommit *Commit, branch string, previousRoot RootValue, commitBecomesEmptyHandling EmptyCommitHandling, emptyCommitHandling EmptyCommitHandling) (*WorkingSet, error) { ws.rebaseState = &RebaseState{ - ontoCommit: ontoCommit, - preRebaseWorking: previousRoot, - branch: branch, + ontoCommit: ontoCommit, + preRebaseWorking: previousRoot, + branch: branch, + commitBecomesEmptyHandling: commitBecomesEmptyHandling, + emptyCommitHandling: emptyCommitHandling, } ontoRoot, err := ontoCommit.GetRootValue(ctx) @@ -472,9 +512,11 @@ func newWorkingSet(ctx context.Context, name string, vrw types.ValueReadWriter, } rebaseState = &RebaseState{ - preRebaseWorking: preRebaseWorkingRoot, - ontoCommit: ontoCommit, - branch: dsws.RebaseState.Branch(ctx), + preRebaseWorking: preRebaseWorkingRoot, + ontoCommit: ontoCommit, + branch: dsws.RebaseState.Branch(ctx), + commitBecomesEmptyHandling: EmptyCommitHandling(dsws.RebaseState.CommitBecomesEmptyHandling(ctx)), + emptyCommitHandling: EmptyCommitHandling(dsws.RebaseState.EmptyCommitHandling(ctx)), } } @@ -570,7 +612,8 @@ func (ws *WorkingSet) writeValues(ctx context.Context, db *DoltDB, meta *datas.W return nil, err } - rebaseState = datas.NewRebaseState(preRebaseWorking.TargetHash(), dCommit.Addr(), ws.rebaseState.branch) + rebaseState = datas.NewRebaseState(preRebaseWorking.TargetHash(), dCommit.Addr(), ws.rebaseState.branch, + uint8(ws.rebaseState.commitBecomesEmptyHandling), uint8(ws.rebaseState.emptyCommitHandling)) } return &datas.WorkingSetSpec{ diff --git a/go/libraries/doltcore/merge/merge_prolly_rows.go b/go/libraries/doltcore/merge/merge_prolly_rows.go index 8f3caaf71a..4466abb191 100644 --- a/go/libraries/doltcore/merge/merge_prolly_rows.go +++ b/go/libraries/doltcore/merge/merge_prolly_rows.go @@ -26,7 +26,6 @@ import ( "github.com/dolthub/go-mysql-server/sql/expression" "github.com/dolthub/go-mysql-server/sql/transform" "github.com/dolthub/go-mysql-server/sql/types" - "golang.org/x/exp/maps" errorkinds "gopkg.in/src-d/go-errors.v1" "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" @@ -1978,20 +1977,20 @@ func (m *valueMerger) processColumn(ctx *sql.Context, i int, left, right, base v } func (m *valueMerger) mergeJSONAddr(ctx context.Context, baseAddr []byte, leftAddr []byte, rightAddr []byte) (resultAddr []byte, conflict bool, err error) { - baseDoc, err := tree.NewJSONDoc(hash.New(baseAddr), m.ns).ToJSONDocument(ctx) + baseDoc, err := tree.NewJSONDoc(hash.New(baseAddr), m.ns).ToIndexedJSONDocument(ctx) if err != nil { return nil, true, err } - leftDoc, err := tree.NewJSONDoc(hash.New(leftAddr), m.ns).ToJSONDocument(ctx) + leftDoc, err := tree.NewJSONDoc(hash.New(leftAddr), m.ns).ToIndexedJSONDocument(ctx) if err != nil { return nil, true, err } - rightDoc, err := tree.NewJSONDoc(hash.New(rightAddr), m.ns).ToJSONDocument(ctx) + rightDoc, err := tree.NewJSONDoc(hash.New(rightAddr), m.ns).ToIndexedJSONDocument(ctx) if err != nil { return nil, true, err } - mergedDoc, conflict, err := mergeJSON(ctx, baseDoc, leftDoc, rightDoc) + mergedDoc, conflict, err := mergeJSON(ctx, m.ns, baseDoc, leftDoc, rightDoc) if err != nil { return nil, true, err } @@ -1999,35 +1998,36 @@ func (m *valueMerger) mergeJSONAddr(ctx context.Context, baseAddr []byte, leftAd return nil, true, nil } - mergedVal, err := mergedDoc.ToInterface() - if err != nil { - return nil, true, err - } - 
mergedBytes, err := json.Marshal(mergedVal) - if err != nil { - return nil, true, err - } - mergedAddr, err := tree.SerializeBytesToAddr(ctx, m.ns, bytes.NewReader(mergedBytes), len(mergedBytes)) + root, err := tree.SerializeJsonToAddr(ctx, m.ns, mergedDoc) if err != nil { return nil, true, err } + mergedAddr := root.HashOf() return mergedAddr[:], false, nil - } -func mergeJSON(ctx context.Context, base types.JSONDocument, left types.JSONDocument, right types.JSONDocument) (resultDoc types.JSONDocument, conflict bool, err error) { +func mergeJSON(ctx context.Context, ns tree.NodeStore, base, left, right sql.JSONWrapper) (resultDoc sql.JSONWrapper, conflict bool, err error) { // First, deserialize each value into JSON. // We can only merge if the value at all three commits is a JSON object. - baseObject, baseIsObject := base.Val.(types.JsonObject) - leftObject, leftIsObject := left.Val.(types.JsonObject) - rightObject, rightIsObject := right.Val.(types.JsonObject) + baseIsObject, err := tree.IsJsonObject(base) + if err != nil { + return nil, true, err + } + leftIsObject, err := tree.IsJsonObject(left) + if err != nil { + return nil, true, err + } + rightIsObject, err := tree.IsJsonObject(right) + if err != nil { + return nil, true, err + } if !baseIsObject || !leftIsObject || !rightIsObject { // At least one of the commits does not have a JSON object. // If both left and right have the same value, use that value. // But if they differ, this is an unresolvable merge conflict. - cmp, err := left.Compare(right) + cmp, err := types.CompareJSON(left, right) if err != nil { return types.JSONDocument{}, true, err } @@ -2039,26 +2039,83 @@ func mergeJSON(ctx context.Context, base types.JSONDocument, left types.JSONDocu } } - mergedObject := maps.Clone(leftObject) - merged := types.JSONDocument{Val: mergedObject} + indexedBase, isBaseIndexed := base.(tree.IndexedJsonDocument) + indexedLeft, isLeftIndexed := left.(tree.IndexedJsonDocument) + indexedRight, isRightIndexed := right.(tree.IndexedJsonDocument) + + // We only do three way merges on values read from tables right now, which are read in as tree.IndexedJsonDocument. + + var leftDiffer tree.IJsonDiffer + if isBaseIndexed && isLeftIndexed { + leftDiffer, err = tree.NewIndexedJsonDiffer(ctx, indexedBase, indexedLeft) + if err != nil { + return nil, true, err + } + } else { + baseObject, err := base.ToInterface() + if err != nil { + return nil, true, err + } + leftObject, err := left.ToInterface() + if err != nil { + return nil, true, err + } + leftDiffer = tree.NewJsonDiffer(baseObject.(types.JsonObject), leftObject.(types.JsonObject)) + } + + var rightDiffer tree.IJsonDiffer + if isBaseIndexed && isRightIndexed { + rightDiffer, err = tree.NewIndexedJsonDiffer(ctx, indexedBase, indexedRight) + if err != nil { + return nil, true, err + } + } else { + baseObject, err := base.ToInterface() + if err != nil { + return nil, true, err + } + rightObject, err := right.ToInterface() + if err != nil { + return nil, true, err + } + rightDiffer = tree.NewJsonDiffer(baseObject.(types.JsonObject), rightObject.(types.JsonObject)) + } - threeWayDiffer := NewThreeWayJsonDiffer(baseObject, leftObject, rightObject) + threeWayDiffer := ThreeWayJsonDiffer{ + leftDiffer: leftDiffer, + rightDiffer: rightDiffer, + ns: ns, + } // Compute the merged object by applying diffs to the left object as needed. + // If the left object isn't an IndexedJsonDocument, we make one. 
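A sketch of what the rewritten `mergeJSON` computes, using plain `types.JSONDocument` values (hypothetical data; `ctx` and `ns` are assumed from the caller, and plain documents exercise the non-indexed fallback that diffs via `tree.NewJsonDiffer`):

```go
base := types.JSONDocument{Val: types.JsonObject{"a": float64(1)}}
left := types.JSONDocument{Val: types.JsonObject{"a": float64(1), "b": float64(2)}}
right := types.JSONDocument{Val: types.JsonObject{"a": float64(5)}}

merged, conflict, err := mergeJSON(ctx, ns, base, left, right)
// conflict == false, err == nil. Starting from left, the loop below applies
// right's DiffOpRightModify at $.a, yielding {"a": 5, "b": 2}. Divergent
// modifications of the same key would surface as conflict == true instead.
_, _, _ = merged, conflict, err
```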
+ var ok bool + var merged tree.IndexedJsonDocument + if merged, ok = left.(tree.IndexedJsonDocument); !ok { + root, err := tree.SerializeJsonToAddr(ctx, ns, left) + if err != nil { + return types.JSONDocument{}, true, err + } + merged = tree.NewIndexedJsonDocument(ctx, root, ns) + } + for { threeWayDiff, err := threeWayDiffer.Next(ctx) if err == io.EOF { return merged, false, nil } + if err != nil { + return types.JSONDocument{}, true, err + } switch threeWayDiff.Op { - case tree.DiffOpRightAdd, tree.DiffOpConvergentAdd, tree.DiffOpRightModify, tree.DiffOpConvergentModify: - _, _, err := merged.Set(ctx, threeWayDiff.Key, threeWayDiff.Right) + case tree.DiffOpRightAdd, tree.DiffOpConvergentAdd, tree.DiffOpRightModify, tree.DiffOpConvergentModify, tree.DiffOpDivergentModifyResolved: + merged, _, err = merged.SetWithKey(ctx, threeWayDiff.Key, threeWayDiff.Right) if err != nil { return types.JSONDocument{}, true, err } case tree.DiffOpRightDelete, tree.DiffOpConvergentDelete: - _, _, err := merged.Remove(ctx, threeWayDiff.Key) + merged, _, err = merged.RemoveWithKey(ctx, threeWayDiff.Key) if err != nil { return types.JSONDocument{}, true, err } diff --git a/go/libraries/doltcore/merge/schema_merge_test.go b/go/libraries/doltcore/merge/schema_merge_test.go index 9f0acbd068..f92d194797 100644 --- a/go/libraries/doltcore/merge/schema_merge_test.go +++ b/go/libraries/doltcore/merge/schema_merge_test.go @@ -16,9 +16,13 @@ package merge_test import ( "context" + "fmt" "testing" "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/expression" + "github.com/dolthub/go-mysql-server/sql/expression/function/json" + sqltypes "github.com/dolthub/go-mysql-server/sql/types" "github.com/shopspring/decimal" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -35,6 +39,7 @@ import ( "github.com/dolthub/dolt/go/libraries/doltcore/sqle/writer" "github.com/dolthub/dolt/go/libraries/doltcore/table/editor" "github.com/dolthub/dolt/go/store/hash" + "github.com/dolthub/dolt/go/store/prolly/tree" "github.com/dolthub/dolt/go/store/types" ) @@ -105,6 +110,9 @@ func TestSchemaMerge(t *testing.T) { t.Run("json merge tests", func(t *testing.T) { testSchemaMerge(t, jsonMergeTests) }) + t.Run("large json merge tests", func(t *testing.T) { + testSchemaMerge(t, jsonMergeLargeDocumentTests(t)) + }) } var columnAddDropTests = []schemaMergeTest{ @@ -1207,6 +1215,13 @@ var jsonMergeTests = []schemaMergeTest{ right: singleRow(1, 1, 2, `{ "key1": "value1", "key2": "value4" }`), merged: singleRow(1, 2, 2, `{ "key1": "value3", "key2": "value4" }`), }, + { + name: `parallel array modification`, + ancestor: singleRow(1, 1, 1, `{"a": [1, 2, 1], "b":0, "c":0}`), + left: singleRow(1, 2, 1, `{"a": [2, 1, 2], "b":1, "c":0}`), + right: singleRow(1, 1, 2, `{"a": [2, 1, 2], "b":0, "c":1}`), + merged: singleRow(1, 2, 2, `{"a": [2, 1, 2], "b":1, "c":1}`), + }, { name: `parallel deletion`, ancestor: singleRow(1, 1, 1, `{ "key1": "value1" }`), @@ -1337,7 +1352,7 @@ var jsonMergeTests = []schemaMergeTest{ // Which array element should go first? // We avoid making assumptions and flag this as a conflict. 
name: "object inside array conflict", - ancestor: singleRow(1, 1, 1, `{ "key1": [ { } ] }`), + ancestor: singleRow(1, 1, 1, `{ "key1": [ ] }`), left: singleRow(1, 2, 1, `{ "key1": [ { "key2": "value2" } ] }`), right: singleRow(1, 1, 2, `{ "key1": [ { "key3": "value3" } ] }`), dataConflict: true, @@ -1354,10 +1369,244 @@ var jsonMergeTests = []schemaMergeTest{ right: singleRow(1, 1, 2, `{ "key1": [ 1, 2 ] }`), dataConflict: true, }, + { + // Regression test: Older versions of json documents could accidentally think that $.aa is a child + // of $.a and see this as a conflict, even though it isn't one. + name: "false positive conflict", + ancestor: singleRow(1, 1, 1, `{ "a": 1, "aa":2 }`), + left: singleRow(1, 2, 1, `{ "aa":2 }`), + right: singleRow(1, 1, 2, `{ "a": 1, "aa": 3 }`), + merged: singleRow(1, 2, 2, `{ "aa": 3 }`), + }, }, }, } +// newIndexedJsonDocumentFromValue creates an IndexedJsonDocument from a provided value. +func newIndexedJsonDocumentFromValue(t *testing.T, ctx context.Context, ns tree.NodeStore, v interface{}) tree.IndexedJsonDocument { + doc, _, err := sqltypes.JSON.Convert(v) + require.NoError(t, err) + root, err := tree.SerializeJsonToAddr(ctx, ns, doc.(sql.JSONWrapper)) + require.NoError(t, err) + return tree.NewIndexedJsonDocument(ctx, root, ns) +} + +// createLargeDocumentForTesting creates a JSON document large enough to be split across multiple chunks. +// This is useful for testing mutation operations in large documents. +// Every different possible jsonPathType appears on a chunk boundary, for better test coverage: +// chunk 0 key: $[6].children[2].children[0].number(endOfValue) +// chunk 2 key: $[7].children[5].children[4].children[2].children(arrayInitialElement) +// chunk 5 key: $[8].children[6].children[4].children[3].children[0](startOfValue) +// chunk 8 key: $[8].children[7].children[6].children[5].children[3].children[2].children[1](objectInitialElement) +func createLargeDocumentForTesting(t *testing.T, ctx *sql.Context, ns tree.NodeStore) tree.IndexedJsonDocument { + leafDoc := make(map[string]interface{}) + leafDoc["number"] = float64(1.0) + leafDoc["string"] = "dolt" + var docExpression sql.Expression = expression.NewLiteral(newIndexedJsonDocumentFromValue(t, ctx, ns, leafDoc), sqltypes.JSON) + var err error + + for level := 0; level < 8; level++ { + docExpression, err = json.NewJSONInsert(docExpression, expression.NewLiteral(fmt.Sprintf("$.level%d", level), sqltypes.Text), docExpression) + require.NoError(t, err) + } + doc, err := docExpression.Eval(ctx, nil) + require.NoError(t, err) + return newIndexedJsonDocumentFromValue(t, ctx, ns, doc) +} + +func jsonMergeLargeDocumentTests(t *testing.T) []schemaMergeTest { + // Test for each possible case in the three-way merge code. + // Test for multiple diffs in the same chunk, + // multiple diffs in adjacent chunks (with a moved boundary) + // and multiple diffs in non-adjacent chunks. 
+ + ctx := sql.NewEmptyContext() + ns := tree.NewTestNodeStore() + + largeObject := createLargeDocumentForTesting(t, ctx, ns) + + insert := func(document sqltypes.MutableJSON, path string, val interface{}) sqltypes.MutableJSON { + jsonVal, inRange, err := sqltypes.JSON.Convert(val) + require.NoError(t, err) + require.True(t, (bool)(inRange)) + newDoc, changed, err := document.Insert(ctx, path, jsonVal.(sql.JSONWrapper)) + require.NoError(t, err) + require.True(t, changed) + return newDoc + } + + set := func(document sqltypes.MutableJSON, path string, val interface{}) sqltypes.MutableJSON { + jsonVal, inRange, err := sqltypes.JSON.Convert(val) + require.NoError(t, err) + require.True(t, (bool)(inRange)) + newDoc, changed, err := document.Replace(ctx, path, jsonVal.(sql.JSONWrapper)) + require.NoError(t, err) + require.True(t, changed) + return newDoc + } + + delete := func(document sqltypes.MutableJSON, path string) sqltypes.MutableJSON { + newDoc, changed, err := document.Remove(ctx, path) + require.True(t, changed) + require.NoError(t, err) + return newDoc + } + + var largeJsonMergeTests = []schemaMergeTest{ + { + name: "json merge", + ancestor: *tbl(sch("CREATE TABLE t (id int PRIMARY KEY, a int, b int, j json)")), + left: tbl(sch("CREATE TABLE t (id int PRIMARY KEY, a int, b int, j json)")), + right: tbl(sch("CREATE TABLE t (id int PRIMARY KEY, a int, b int, j json)")), + merged: *tbl(sch("CREATE TABLE t (id int PRIMARY KEY, a int, b int, j json)")), + dataTests: []dataTest{ + { + name: "parallel insertion", + ancestor: singleRow(1, 1, 1, largeObject), + left: singleRow(1, 2, 1, insert(largeObject, "$.a", 1)), + right: singleRow(1, 1, 2, insert(largeObject, "$.a", 1)), + merged: singleRow(1, 2, 2, insert(largeObject, "$.a", 1)), + }, + { + name: "convergent insertion", + ancestor: singleRow(1, 1, 1, largeObject), + left: singleRow(1, 2, 1, insert(largeObject, "$.a", 1)), + right: singleRow(1, 1, 2, insert(largeObject, "$.z", 2)), + merged: singleRow(1, 2, 2, insert(insert(largeObject, "$.a", 1), "$.z", 2)), + }, + { + name: "multiple insertions", + ancestor: singleRow(1, 1, 1, largeObject), + left: singleRow(1, 2, 1, insert(insert(largeObject, "$.a1", 1), "$.z2", 2)), + right: singleRow(1, 1, 2, insert(insert(largeObject, "$.a2", 3), "$.z1", 4)), + merged: singleRow(1, 2, 2, insert(insert(insert(insert(largeObject, "$.z1", 4), "$.z2", 2), "$.a2", 3), "$.a1", 1)), + }, + { + name: "convergent insertion with escaped quotes in keys", + ancestor: singleRow(1, 1, 1, largeObject), + left: singleRow(1, 2, 1, insert(largeObject, `$."\"a\""`, 1)), + right: singleRow(1, 1, 2, insert(largeObject, `$."\"z\""`, 2)), + merged: singleRow(1, 2, 2, insert(insert(largeObject, `$."\"a\""`, 1), `$."\"z\""`, 2)), + }, + { + name: "parallel modification", + ancestor: singleRow(1, 1, 1, largeObject), + left: singleRow(1, 2, 1, set(largeObject, "$.level7", 1)), + right: singleRow(1, 1, 2, set(largeObject, "$.level7", 1)), + merged: singleRow(1, 2, 2, set(largeObject, "$.level7", 1)), + }, + { + name: "convergent modification", + ancestor: singleRow(1, 1, 1, insert(largeObject, "$.a", 1)), + left: singleRow(1, 2, 1, set(insert(largeObject, "$.a", 1), "$.level7", 2)), + right: singleRow(1, 1, 2, set(insert(largeObject, "$.a", 1), "$.a", 3)), + merged: singleRow(1, 2, 2, set(insert(largeObject, "$.a", 3), "$.level7", 2)), + }, + { + name: `parallel deletion`, + ancestor: singleRow(1, 1, 1, insert(largeObject, "$.a", 1)), + left: singleRow(1, 2, 1, largeObject), + right: singleRow(1, 1, 2, largeObject), + merged: 
singleRow(1, 2, 2, largeObject), + }, + { + name: `convergent deletion`, + ancestor: singleRow(1, 1, 1, insert(insert(largeObject, "$.a", 1), "$.z", 2)), + left: singleRow(1, 2, 1, insert(largeObject, "$.a", 1)), + right: singleRow(1, 1, 2, insert(largeObject, "$.z", 2)), + merged: singleRow(1, 2, 2, largeObject), + }, + { + name: `divergent insertion`, + ancestor: singleRow(1, 1, 1, largeObject), + left: singleRow(1, 2, 1, insert(largeObject, "$.z", 1)), + right: singleRow(1, 1, 2, insert(largeObject, "$.z", 2)), + dataConflict: true, + }, + { + name: `divergent modification`, + ancestor: singleRow(1, 1, 1, largeObject), + left: singleRow(1, 2, 1, set(largeObject, "$.level7", 1)), + right: singleRow(1, 1, 2, set(largeObject, "$.level7", 2)), + dataConflict: true, + }, + { + name: `divergent modification and deletion`, + ancestor: singleRow(1, 1, 1, insert(largeObject, "$.a", 1)), + left: singleRow(1, 2, 1, insert(largeObject, "$.a", 2)), + right: singleRow(1, 1, 2, largeObject), + dataConflict: true, + }, + { + name: `nested insertion`, + ancestor: singleRow(1, 1, 1, insert(largeObject, "$.level7.level4.new", map[string]interface{}{})), + left: singleRow(1, 2, 1, insert(largeObject, "$.level7.level4.new", map[string]interface{}{"a": 1})), + right: singleRow(1, 1, 2, insert(largeObject, "$.level7.level4.new", map[string]interface{}{"b": 2})), + merged: singleRow(1, 2, 2, insert(largeObject, "$.level7.level4.new", map[string]interface{}{"a": 1, "b": 2})), + }, + { + name: `nested insertion with escaped quotes in keys`, + ancestor: singleRow(1, 1, 1, insert(largeObject, `$.level7.level4."\"new\""`, map[string]interface{}{})), + left: singleRow(1, 2, 1, insert(largeObject, `$.level7.level4."\"new\""`, map[string]interface{}{"a": 1})), + right: singleRow(1, 1, 2, insert(largeObject, `$.level7.level4."\"new\""`, map[string]interface{}{"b": 2})), + merged: singleRow(1, 2, 2, insert(largeObject, `$.level7.level4."\"new\""`, map[string]interface{}{"a": 1, "b": 2})), + }, + { + name: "nested convergent modification", + ancestor: singleRow(1, 1, 1, largeObject), + left: singleRow(1, 2, 1, set(largeObject, "$.level7.level4", 1)), + right: singleRow(1, 1, 2, set(largeObject, "$.level7.level5", 2)), + merged: singleRow(1, 2, 2, set(set(largeObject, "$.level7.level5", 2), "$.level7.level4", 1)), + }, + { + name: `nested deletion`, + ancestor: singleRow(1, 1, 1, largeObject), + left: singleRow(1, 2, 1, delete(largeObject, "$.level7")), + right: singleRow(1, 1, 2, delete(largeObject, "$.level6")), + merged: singleRow(1, 2, 2, delete(delete(largeObject, "$.level6"), "$.level7")), + }, + { + name: "complicated nested merge", + ancestor: singleRow(1, 1, 1, largeObject), + left: singleRow(1, 2, 1, delete(set(insert(largeObject, "$.added", 7), "$.level5.level1", 5), "$.level4")), + right: singleRow(1, 1, 2, delete(set(insert(largeObject, "$.level6.added", 8), "$.level1", 6), "$.level5.level2")), + merged: singleRow(1, 2, 2, delete(set(insert(delete(set(insert(largeObject, "$.added", 7), "$.level5.level1", 5), "$.level4"), "$.level6.added", 8), "$.level1", 6), "$.level5.level2")), + }, + { + name: "changing types", + ancestor: singleRow(1, 1, 1, largeObject), + left: singleRow(1, 2, 1, set(largeObject, "$.level3.number", `"dolt"`)), + right: singleRow(1, 1, 2, set(largeObject, "$.level4.string", 4)), + merged: singleRow(1, 2, 2, set(set(largeObject, "$.level3.number", `"dolt"`), "$.level4.string", 4)), + }, + { + name: "changing types conflict", + ancestor: singleRow(1, 1, 1, largeObject), + left: singleRow(1, 
2, 1, set(largeObject, "$.level4.string", []interface{}{})), + right: singleRow(1, 1, 2, set(largeObject, "$.level4.string", 4)), + dataConflict: true, + }, + { + name: "object insert and modify conflict", + ancestor: singleRow(1, 1, 1, largeObject), + left: singleRow(1, 1, 1, insert(largeObject, "$.level5.a", 5)), + right: singleRow(1, 1, 2, set(largeObject, "$.level5", 6)), + dataConflict: true, + }, + { + name: "object insert and delete conflict", + ancestor: singleRow(1, 1, 1, largeObject), + left: singleRow(1, 1, 1, insert(largeObject, "$.level5.a", 5)), + right: singleRow(1, 1, 2, delete(largeObject, "$.level5")), + dataConflict: true, + }, + }, + }, + } + + return largeJsonMergeTests +} + func testSchemaMerge(t *testing.T, tests []schemaMergeTest) { t.Run("merge left to right", func(t *testing.T) { testSchemaMergeHelper(t, tests, false) diff --git a/go/libraries/doltcore/merge/three_way_json_differ.go b/go/libraries/doltcore/merge/three_way_json_differ.go index 5bb89dff3e..d406b27596 100644 --- a/go/libraries/doltcore/merge/three_way_json_differ.go +++ b/go/libraries/doltcore/merge/three_way_json_differ.go @@ -18,24 +18,18 @@ import ( "bytes" "context" "io" - "strings" + "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/types" "github.com/dolthub/dolt/go/store/prolly/tree" ) type ThreeWayJsonDiffer struct { - leftDiffer, rightDiffer tree.JsonDiffer + leftDiffer, rightDiffer tree.IJsonDiffer leftCurrentDiff, rightCurrentDiff *tree.JsonDiff leftIsDone, rightIsDone bool -} - -func NewThreeWayJsonDiffer(base, left, right types.JsonObject) ThreeWayJsonDiffer { - return ThreeWayJsonDiffer{ - leftDiffer: tree.NewJsonDiffer("$", base, left), - rightDiffer: tree.NewJsonDiffer("$", base, right), - } + ns tree.NodeStore } type ThreeWayJsonDiff struct { @@ -43,13 +37,13 @@ type ThreeWayJsonDiff struct { Op tree.DiffOp // a partial set of document values are set // depending on the diffOp - Key string - Base, Left, Right, Merged *types.JSONDocument + Key []byte + Left, Right, Merged sql.JSONWrapper } func (differ *ThreeWayJsonDiffer) Next(ctx context.Context) (ThreeWayJsonDiff, error) { for { - err := differ.loadNextDiff() + err := differ.loadNextDiff(ctx) if err != nil { return ThreeWayJsonDiff{}, err } @@ -66,13 +60,22 @@ func (differ *ThreeWayJsonDiffer) Next(ctx context.Context) (ThreeWayJsonDiff, e // !differ.rightIsDone && !differ.leftIsDone leftDiff := differ.leftCurrentDiff rightDiff := differ.rightCurrentDiff - cmp := bytes.Compare([]byte(leftDiff.Key), []byte(rightDiff.Key)) + leftKey := leftDiff.Key + rightKey := rightDiff.Key + + cmp := bytes.Compare(leftKey, rightKey) + // If both sides modify the same array to different values, we currently consider that to be a conflict. + // This may be relaxed in the future. + if cmp != 0 && tree.JsonKeysModifySameArray(leftKey, rightKey) { + result := ThreeWayJsonDiff{ + Op: tree.DiffOpDivergentModifyConflict, + } + return result, nil + } if cmp > 0 { - if strings.HasPrefix(leftDiff.Key, rightDiff.Key) { - // The left diff must be replacing or deleting an object, - // and the right diff makes changes to that object. - // Note the fact that all keys in these paths are quoted means we don't have to worry about - // one key being a prefix of the other and triggering a false positive here. + if tree.IsJsonKeyPrefix(leftKey, rightKey) { + // The right diff must be replacing or deleting an object, + // and the left diff makes changes to that object. 
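			// To make the prefix check concrete: the key for $.key1.key2 has the key
			// for $.key1 as a path prefix, so IsJsonKeyPrefix reports true; but the
			// key for $.aa does NOT have the key for $.a as a path prefix, which is
			// exactly the false positive a byte-wise strings.HasPrefix check could
			// produce (see the "false positive conflict" regression test in this change).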
			result := ThreeWayJsonDiff{
				Op: tree.DiffOpDivergentModifyConflict,
			}
@@ -82,11 +85,9 @@ func (differ *ThreeWayJsonDiffer) Next(ctx context.Context) (ThreeWayJsonDiff, e
 			// key only changed on right
 			return differ.processRightSideOnlyDiff(), nil
 		} else if cmp < 0 {
-			if strings.HasPrefix(rightDiff.Key, leftDiff.Key) {
+			if tree.IsJsonKeyPrefix(rightKey, leftKey) {
				// The left diff must be replacing or deleting an object,
				// and the right diff makes changes to that object.
-				// Note the fact that all keys in these paths are quoted means we don't have to worry about
-				// one key being a prefix of the other and triggering a false positive here.
				result := ThreeWayJsonDiff{
					Op: tree.DiffOpDivergentModifyConflict,
				}
@@ -101,12 +102,12 @@ func (differ *ThreeWayJsonDiffer) Next(ctx context.Context) (ThreeWayJsonDiff, e
		if differ.leftCurrentDiff.From == nil {
			// Key did not exist at base, so both sides are inserts.
			// Check that they're inserting the same value.
-			valueCmp, err := differ.leftCurrentDiff.To.Compare(differ.rightCurrentDiff.To)
+			valueCmp, err := types.CompareJSON(differ.leftCurrentDiff.To, differ.rightCurrentDiff.To)
			if err != nil {
				return ThreeWayJsonDiff{}, err
			}
			if valueCmp == 0 {
-				return differ.processMergedDiff(tree.DiffOpConvergentModify, differ.leftCurrentDiff.To), nil
+				return differ.processMergedDiff(tree.DiffOpConvergentAdd, differ.leftCurrentDiff.To), nil
			} else {
				return differ.processMergedDiff(tree.DiffOpDivergentModifyConflict, nil), nil
			}
@@ -120,24 +121,24 @@ func (differ *ThreeWayJsonDiffer) Next(ctx context.Context) (ThreeWayJsonDiff, e
		// If the key existed at base, we can do a recursive three-way merge to resolve
		// changes to the values.
		// This shouldn't be necessary: if it's an object on all three branches, the original diff is recursive.
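		// For example (illustrative values only): with
		//
		//	base:  {"a": {"x": 1}}
		//	left:  {"a": {"x": 1, "y": 2}}
		//	right: {"a": {"x": 1, "z": 3}}
		//
		// the recursive merge resolves the divergently modified value to
		// {"a": {"x": 1, "y": 2, "z": 3}} without reporting a conflict.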
-		mergedValue, conflict, err := mergeJSON(ctx, *differ.leftCurrentDiff.From,
-			*differ.leftCurrentDiff.To,
-			*differ.rightCurrentDiff.To)
+		mergedValue, conflict, err := mergeJSON(ctx, differ.ns, differ.leftCurrentDiff.From,
+			differ.leftCurrentDiff.To,
+			differ.rightCurrentDiff.To)
 		if err != nil {
 			return ThreeWayJsonDiff{}, err
 		}
 		if conflict {
 			return differ.processMergedDiff(tree.DiffOpDivergentModifyConflict, nil), nil
 		} else {
-			return differ.processMergedDiff(tree.DiffOpDivergentModifyResolved, &mergedValue), nil
+			return differ.processMergedDiff(tree.DiffOpDivergentModifyResolved, mergedValue), nil
 		}
 	}
 }
 }
-func (differ *ThreeWayJsonDiffer) loadNextDiff() error {
+func (differ *ThreeWayJsonDiffer) loadNextDiff(ctx context.Context) error {
 	if differ.leftCurrentDiff == nil && !differ.leftIsDone {
-		newLeftDiff, err := differ.leftDiffer.Next()
+		newLeftDiff, err := differ.leftDiffer.Next(ctx)
 		if err == io.EOF {
 			differ.leftIsDone = true
 		} else if err != nil {
@@ -147,7 +148,7 @@ func (differ *ThreeWayJsonDiffer) loadNextDiff() error {
 		}
 	}
 	if differ.rightCurrentDiff == nil && !differ.rightIsDone {
-		newRightDiff, err := differ.rightDiffer.Next()
+		newRightDiff, err := differ.rightDiffer.Next(ctx)
 		if err == io.EOF {
 			differ.rightIsDone = true
 		} else if err != nil {
@@ -172,9 +173,8 @@ func (differ *ThreeWayJsonDiffer) processRightSideOnlyDiff() ThreeWayJsonDiff {
 	case tree.RemovedDiff:
 		result := ThreeWayJsonDiff{
-			Op:   tree.DiffOpRightDelete,
-			Key:  differ.rightCurrentDiff.Key,
-			Base: differ.rightCurrentDiff.From,
+			Op:  tree.DiffOpRightDelete,
+			Key: differ.rightCurrentDiff.Key,
 		}
 		differ.rightCurrentDiff = nil
 		return result
@@ -183,7 +183,6 @@ func (differ *ThreeWayJsonDiffer) processRightSideOnlyDiff() ThreeWayJsonDiff {
 		result := ThreeWayJsonDiff{
 			Op:    tree.DiffOpRightModify,
 			Key:   differ.rightCurrentDiff.Key,
-			Base:  differ.rightCurrentDiff.From,
 			Right: differ.rightCurrentDiff.To,
 		}
 		differ.rightCurrentDiff = nil
@@ -193,11 +192,10 @@
 	}
 }
-func (differ *ThreeWayJsonDiffer) processMergedDiff(op tree.DiffOp, merged *types.JSONDocument) ThreeWayJsonDiff {
+func (differ *ThreeWayJsonDiffer) processMergedDiff(op tree.DiffOp, merged sql.JSONWrapper) ThreeWayJsonDiff {
 	result := ThreeWayJsonDiff{
 		Op:     op,
 		Key:    differ.leftCurrentDiff.Key,
-		Base:   differ.leftCurrentDiff.From,
 		Left:   differ.leftCurrentDiff.To,
 		Right:  differ.rightCurrentDiff.To,
 		Merged: merged,
diff --git a/go/libraries/doltcore/schema/index_coll.go b/go/libraries/doltcore/schema/index_coll.go
index c09f256ae7..00c8c71c6d 100644
--- a/go/libraries/doltcore/schema/index_coll.go
+++ b/go/libraries/doltcore/schema/index_coll.go
@@ -41,7 +41,7 @@ type IndexCollection interface {
 	Equals(other IndexCollection) bool
 	// GetByName returns the index with the given name, or nil if it does not exist.
 	GetByName(indexName string) Index
-	// GetByName returns the index with a matching case-insensitive name, the bool return value indicates if a match was found.
+	// GetByNameCaseInsensitive returns the index with a matching case-insensitive name; the bool return value indicates whether a match was found.
 	GetByNameCaseInsensitive(indexName string) (Index, bool)
 	// GetIndexByColumnNames returns whether the collection contains an index that has this exact collection and ordering of columns.
GetIndexByColumnNames(cols ...string) (Index, bool) @@ -173,10 +173,6 @@ func (ixc *indexCollectionImpl) AddIndex(indexes ...Index) { if ok { ixc.removeIndex(oldNamedIndex) } - oldTaggedIndex := ixc.containsColumnTagCollection(index.tags...) - if oldTaggedIndex != nil { - ixc.removeIndex(oldTaggedIndex) - } ixc.indexes[lowerName] = index for _, tag := range index.tags { ixc.colTagToIndex[tag] = append(ixc.colTagToIndex[tag], index) diff --git a/go/libraries/doltcore/schema/index_test.go b/go/libraries/doltcore/schema/index_test.go index 5c3ad696c9..c5b65ebadf 100644 --- a/go/libraries/doltcore/schema/index_test.go +++ b/go/libraries/doltcore/schema/index_test.go @@ -84,56 +84,30 @@ func TestIndexCollectionAddIndex(t *testing.T) { indexColl.clear(t) } - const prefix = "new_" - - t.Run("Tag Overwrites", func(t *testing.T) { + t.Run("Duplicate column set", func(t *testing.T) { for _, testIndex := range testIndexes { indexColl.AddIndex(testIndex) newIndex := testIndex.copy() - newIndex.name = prefix + testIndex.name + newIndex.name = "dupe_" + testIndex.name indexColl.AddIndex(newIndex) assert.Equal(t, newIndex, indexColl.GetByName(newIndex.Name())) - assert.Nil(t, indexColl.GetByName(testIndex.Name())) + assert.Equal(t, testIndex, indexColl.GetByName(testIndex.Name())) assert.Contains(t, indexColl.AllIndexes(), newIndex) - assert.NotContains(t, indexColl.AllIndexes(), testIndex) + assert.Contains(t, indexColl.AllIndexes(), testIndex) for _, tag := range newIndex.IndexedColumnTags() { assert.Contains(t, indexColl.IndexesWithTag(tag), newIndex) - assert.NotContains(t, indexColl.IndexesWithTag(tag), testIndex) + assert.Contains(t, indexColl.IndexesWithTag(tag), testIndex) } for _, col := range newIndex.ColumnNames() { assert.Contains(t, indexColl.IndexesWithColumn(col), newIndex) - assert.NotContains(t, indexColl.IndexesWithColumn(col), testIndex) + assert.Contains(t, indexColl.IndexesWithColumn(col), testIndex) } assert.True(t, indexColl.Contains(newIndex.Name())) - assert.False(t, indexColl.Contains(testIndex.Name())) + assert.True(t, indexColl.Contains(testIndex.Name())) assert.True(t, indexColl.hasIndexOnColumns(newIndex.ColumnNames()...)) assert.True(t, indexColl.hasIndexOnTags(newIndex.IndexedColumnTags()...)) } }) - - t.Run("Name Overwrites", func(t *testing.T) { - // should be able to reduce collection to one index - lastStanding := &indexImpl{ - name: "none", - tags: []uint64{4}, - allTags: []uint64{4, 1, 2}, - indexColl: indexColl, - } - - for _, testIndex := range testIndexes { - lastStanding.name = prefix + testIndex.name - indexColl.AddIndex(lastStanding) - } - - assert.Equal(t, map[string]*indexImpl{lastStanding.name: lastStanding}, indexColl.indexes) - for tag, indexes := range indexColl.colTagToIndex { - if tag == 4 { - assert.Equal(t, indexes, []*indexImpl{lastStanding}) - } else { - assert.Empty(t, indexes) - } - } - }) } func TestIndexCollectionAddIndexByColNames(t *testing.T) { diff --git a/go/libraries/doltcore/schema/typeinfo/int_test.go b/go/libraries/doltcore/schema/typeinfo/int_test.go index 2b6b12f66a..cf0928dcb4 100644 --- a/go/libraries/doltcore/schema/typeinfo/int_test.go +++ b/go/libraries/doltcore/schema/typeinfo/int_test.go @@ -230,7 +230,7 @@ func TestIntParseValue(t *testing.T) { { Int64Type, "100.5", - 100, + 101, false, }, { diff --git a/go/libraries/doltcore/sqle/database.go b/go/libraries/doltcore/sqle/database.go index 1151f6f2c5..25af5c0ad0 100644 --- a/go/libraries/doltcore/sqle/database.go +++ b/go/libraries/doltcore/sqle/database.go @@ -397,6 +397,17 @@ 
func (db Database) getTableInsensitive(ctx *sql.Context, head *doltdb.Commit, ds return nil, false, err } return dt, true, nil + case strings.HasPrefix(lwrName, doltdb.DoltWorkspaceTablePrefix): + sess := dsess.DSessFromSess(ctx.Session) + roots, _ := sess.GetRoots(ctx, db.RevisionQualifiedName()) + + userTable := tblName[len(doltdb.DoltWorkspaceTablePrefix):] + + dt, err := dtables.NewWorkspaceTable(ctx, userTable, roots) + if err != nil { + return nil, false, err + } + return dt, true, nil } var dt sql.Table diff --git a/go/libraries/doltcore/sqle/database_provider.go b/go/libraries/doltcore/sqle/database_provider.go index 8abe6ca785..6413a5f7dd 100644 --- a/go/libraries/doltcore/sqle/database_provider.go +++ b/go/libraries/doltcore/sqle/database_provider.go @@ -1289,12 +1289,12 @@ func (p *DoltDatabaseProvider) SessionDatabase(ctx *sql.Context, name string) (d } // Function implements the FunctionProvider interface -func (p *DoltDatabaseProvider) Function(_ *sql.Context, name string) (sql.Function, error) { +func (p *DoltDatabaseProvider) Function(_ *sql.Context, name string) (sql.Function, bool) { fn, ok := p.functions[strings.ToLower(name)] if !ok { - return nil, sql.ErrFunctionNotFound.New(name) + return nil, false } - return fn, nil + return fn, true } func (p *DoltDatabaseProvider) Register(d sql.ExternalStoredProcedureDetails) { diff --git a/go/libraries/doltcore/sqle/dolt_diff_table_function.go b/go/libraries/doltcore/sqle/dolt_diff_table_function.go index 6d0b01c861..8200aa99e8 100644 --- a/go/libraries/doltcore/sqle/dolt_diff_table_function.go +++ b/go/libraries/doltcore/sqle/dolt_diff_table_function.go @@ -186,7 +186,7 @@ func (dtf *DiffTableFunction) RowIter(ctx *sql.Context, _ sql.Row) (sql.RowIter, ddb := sqledb.DbData().Ddb dp := dtables.NewDiffPartition(dtf.tableDelta.ToTable, dtf.tableDelta.FromTable, toCommitStr, fromCommitStr, dtf.toDate, dtf.fromDate, dtf.tableDelta.ToSch, dtf.tableDelta.FromSch) - return dtables.NewDiffPartitionRowIter(*dp, ddb, dtf.joiner), nil + return dtables.NewDiffPartitionRowIter(dp, ddb, dtf.joiner), nil } // findMatchingDelta returns the best matching table delta for the table name diff --git a/go/libraries/doltcore/sqle/dolt_patch_table_function.go b/go/libraries/doltcore/sqle/dolt_patch_table_function.go index 6c5bba841b..1a6732e60d 100644 --- a/go/libraries/doltcore/sqle/dolt_patch_table_function.go +++ b/go/libraries/doltcore/sqle/dolt_patch_table_function.go @@ -632,7 +632,7 @@ func getDiffQuery(ctx *sql.Context, dbData env.DbData, td diff.TableDelta, fromR diffQuerySqlSch, projections := getDiffQuerySqlSchemaAndProjections(diffPKSch.Schema, columnsWithDiff) dp := dtables.NewDiffPartition(td.ToTable, td.FromTable, toRefDetails.hashStr, fromRefDetails.hashStr, toRefDetails.commitTime, fromRefDetails.commitTime, td.ToSch, td.FromSch) - ri := dtables.NewDiffPartitionRowIter(*dp, dbData.Ddb, j) + ri := dtables.NewDiffPartitionRowIter(dp, dbData.Ddb, j) return diffQuerySqlSch, projections, ri, nil } diff --git a/go/libraries/doltcore/sqle/dprocedures/dolt_cherry_pick.go b/go/libraries/doltcore/sqle/dprocedures/dolt_cherry_pick.go index 645f9bafe3..7c6d39021a 100644 --- a/go/libraries/doltcore/sqle/dprocedures/dolt_cherry_pick.go +++ b/go/libraries/doltcore/sqle/dprocedures/dolt_cherry_pick.go @@ -24,6 +24,7 @@ import ( "github.com/dolthub/dolt/go/cmd/dolt/cli" "github.com/dolthub/dolt/go/libraries/doltcore/branch_control" "github.com/dolthub/dolt/go/libraries/doltcore/cherry_pick" + "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" ) var 
ErrEmptyCherryPick = errors.New("cannot cherry-pick empty string") @@ -95,7 +96,14 @@ func doDoltCherryPick(ctx *sql.Context, args []string) (string, int, int, int, e return "", 0, 0, 0, ErrEmptyCherryPick } - commit, mergeResult, err := cherry_pick.CherryPick(ctx, cherryStr, cherry_pick.CherryPickOptions{}) + cherryPickOptions := cherry_pick.NewCherryPickOptions() + + // If --allow-empty is specified, then empty commits are allowed to be cherry-picked + if apr.Contains(cli.AllowEmptyFlag) { + cherryPickOptions.EmptyCommitHandling = doltdb.KeepEmptyCommit + } + + commit, mergeResult, err := cherry_pick.CherryPick(ctx, cherryStr, cherryPickOptions) if err != nil { return "", 0, 0, 0, err } diff --git a/go/libraries/doltcore/sqle/dprocedures/dolt_rebase.go b/go/libraries/doltcore/sqle/dprocedures/dolt_rebase.go index ca789beda6..d3296ed2a6 100644 --- a/go/libraries/doltcore/sqle/dprocedures/dolt_rebase.go +++ b/go/libraries/doltcore/sqle/dprocedures/dolt_rebase.go @@ -17,6 +17,7 @@ package dprocedures import ( "errors" "fmt" + "strings" "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/types" @@ -31,6 +32,7 @@ import ( "github.com/dolthub/dolt/go/libraries/doltcore/rebase" "github.com/dolthub/dolt/go/libraries/doltcore/ref" "github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess" + "github.com/dolthub/dolt/go/libraries/utils/argparser" ) var doltRebaseProcedureSchema = []*sql.Column{ @@ -134,6 +136,15 @@ func doDoltRebase(ctx *sql.Context, args []string) (int, string, error) { } default: + commitBecomesEmptyHandling, err := processCommitBecomesEmptyParams(apr) + if err != nil { + return 1, "", err + } + + // The default, in rebase, for handling commits that start off empty is to keep them + // TODO: Add support for --keep-empty and --no-keep-empty flags + emptyCommitHandling := doltdb.EmptyCommitHandling(doltdb.KeepEmptyCommit) + if apr.NArg() == 0 { return 1, "", fmt.Errorf("not enough args") } else if apr.NArg() > 1 { @@ -142,7 +153,7 @@ func doDoltRebase(ctx *sql.Context, args []string) (int, string, error) { if !apr.Contains(cli.InteractiveFlag) { return 1, "", fmt.Errorf("non-interactive rebases not currently supported") } - err = startRebase(ctx, apr.Arg(0)) + err = startRebase(ctx, apr.Arg(0), commitBecomesEmptyHandling, emptyCommitHandling) if err != nil { return 1, "", err } @@ -158,7 +169,33 @@ func doDoltRebase(ctx *sql.Context, args []string) (int, string, error) { } } -func startRebase(ctx *sql.Context, upstreamPoint string) error { +// processCommitBecomesEmptyParams examines the parsed arguments in |apr| for the "empty" arg +// and returns the empty commit handling strategy to use when a commit being rebased becomes +// empty. If an invalid argument value is encountered, an error is returned. +func processCommitBecomesEmptyParams(apr *argparser.ArgParseResults) (doltdb.EmptyCommitHandling, error) { + commitBecomesEmptyParam, isCommitBecomesEmptySpecified := apr.GetValue(cli.EmptyParam) + if !isCommitBecomesEmptySpecified { + // If no option is specified, then by default, commits that become empty are dropped. Git has the same + // default for non-interactive rebases; for interactive rebases, Git uses the default action of "stop" to + // let the user examine the changes and decide what to do next. We don't support the "stop" action yet, so + // we default to "drop" even in the interactive rebase case. 
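	// For example, using the stored procedure syntax exercised by the tests in this change:
	//
	//	CALL dolt_rebase('-i', '--empty', 'keep', 'main');  -- keep commits that become empty
	//	CALL dolt_rebase('-i', '--empty', 'drop', 'main');  -- drop them (the default)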
+		return doltdb.DropEmptyCommit, nil
+	}
+
+	if strings.EqualFold(commitBecomesEmptyParam, "keep") {
+		return doltdb.KeepEmptyCommit, nil
+	} else if strings.EqualFold(commitBecomesEmptyParam, "drop") {
+		return doltdb.DropEmptyCommit, nil
+	} else {
+		return -1, fmt.Errorf("unsupported option for the empty flag (%s); "+
+			"only 'keep' or 'drop' are allowed", commitBecomesEmptyParam)
+	}
+}
+
+// startRebase starts a new interactive rebase operation. |upstreamPoint| specifies the commit that the new rebased
+// commits will be based on, |commitBecomesEmptyHandling| specifies how to handle commits that are not empty, but
+// do not produce any changes when applied, and |emptyCommitHandling| specifies how to handle empty commits.
+func startRebase(ctx *sql.Context, upstreamPoint string, commitBecomesEmptyHandling doltdb.EmptyCommitHandling, emptyCommitHandling doltdb.EmptyCommitHandling) error {
 	if upstreamPoint == "" {
 		return fmt.Errorf("no upstream branch specified")
 	}
@@ -245,7 +282,8 @@ func startRebase(ctx *sql.Context, upstreamPoint string) error {
 		return err
 	}
-	newWorkingSet, err := workingSet.StartRebase(ctx, upstreamCommit, rebaseBranch, branchRoots.Working)
+	newWorkingSet, err := workingSet.StartRebase(ctx, upstreamCommit, rebaseBranch, branchRoots.Working,
+		commitBecomesEmptyHandling, emptyCommitHandling)
 	if err != nil {
 		return err
 	}
@@ -415,7 +453,9 @@ func continueRebase(ctx *sql.Context) (string, error) {
 	}
 	for _, step := range rebasePlan.Steps {
-		err = processRebasePlanStep(ctx, &step)
+		err = processRebasePlanStep(ctx, &step,
+			workingSet.RebaseState().CommitBecomesEmptyHandling(),
+			workingSet.RebaseState().EmptyCommitHandling())
 		if err != nil {
 			return "", err
 		}
@@ -471,7 +511,8 @@ func continueRebase(ctx *sql.Context) (string, error) {
 	}, doltSession.Provider(), nil)
 }
-func processRebasePlanStep(ctx *sql.Context, planStep *rebase.RebasePlanStep) error {
+func processRebasePlanStep(ctx *sql.Context, planStep *rebase.RebasePlanStep,
+	commitBecomesEmptyHandling doltdb.EmptyCommitHandling, emptyCommitHandling doltdb.EmptyCommitHandling) error {
 	// Make sure we have a transaction opened for the session
 	// NOTE: After our first call to cherry-pick, the tx is committed, so a new tx needs to be started
 	// as we process additional rebase actions.
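// As a sketch of how a rebase plan flows through processRebasePlanStep (the plan
// entries here are hypothetical; the option handling mirrors the switch below):
//
//	pick   aaa111  -> cherry-pick with the rebase's empty-commit options
//	reword bbb222  -> same, plus options.CommitMessage = planStep.CommitMsg
//	squash ccc333  -> same, plus options.Amend = true and a squashed commit message
//	drop   ddd444  -> no cherry-pick; the step is skipped entirely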
@@ -483,19 +524,24 @@ func processRebasePlanStep(ctx *sql.Context, planStep *rebase.RebasePlanStep) er
 		}
 	}
+	// Override the default empty commit handling options for cherry-pick, since
+	// rebase has slightly different defaults
+	options := cherry_pick.NewCherryPickOptions()
+	options.CommitBecomesEmptyHandling = commitBecomesEmptyHandling
+	options.EmptyCommitHandling = emptyCommitHandling
+
 	switch planStep.Action {
 	case rebase.RebaseActionDrop:
 		return nil
 	case rebase.RebaseActionPick, rebase.RebaseActionReword:
-		options := cherry_pick.CherryPickOptions{}
 		if planStep.Action == rebase.RebaseActionReword {
 			options.CommitMessage = planStep.CommitMsg
 		}
 		return handleRebaseCherryPick(ctx, planStep.CommitHash, options)
 	case rebase.RebaseActionSquash, rebase.RebaseActionFixup:
-		options := cherry_pick.CherryPickOptions{Amend: true}
+		options.Amend = true
 		if planStep.Action == rebase.RebaseActionSquash {
 			commitMessage, err := squashCommitMessage(ctx, planStep.CommitHash)
 			if err != nil {
diff --git a/go/libraries/doltcore/sqle/dtables/column_diff_table.go b/go/libraries/doltcore/sqle/dtables/column_diff_table.go
index 3e18c307c5..27b4417cdf 100644
--- a/go/libraries/doltcore/sqle/dtables/column_diff_table.go
+++ b/go/libraries/doltcore/sqle/dtables/column_diff_table.go
@@ -527,7 +527,7 @@ func calculateColDelta(ctx *sql.Context, ddb *doltdb.DoltDB, delta *diff.TableDe
 	now := time.Now() // accurate commit time returned elsewhere
 	// TODO: schema name?
 	dp := NewDiffPartition(delta.ToTable, delta.FromTable, delta.ToName.Name, delta.FromName.Name, (*dtypes.Timestamp)(&now), (*dtypes.Timestamp)(&now), delta.ToSch, delta.FromSch)
-	ri := NewDiffPartitionRowIter(*dp, ddb, j)
+	ri := NewDiffPartitionRowIter(dp, ddb, j)
 	var resultColNames []string
 	var resultDiffTypes []string
diff --git a/go/libraries/doltcore/sqle/dtables/diff_iter.go b/go/libraries/doltcore/sqle/dtables/diff_iter.go
index 5e284a8a30..464c781375 100644
--- a/go/libraries/doltcore/sqle/dtables/diff_iter.go
+++ b/go/libraries/doltcore/sqle/dtables/diff_iter.go
@@ -34,7 +34,9 @@ import (
 	"github.com/dolthub/dolt/go/store/val"
 )
-type diffRowItr struct {
+// ldDiffRowItr is a sql.RowIter implementation which iterates over an LD formatted DB in order to generate the
+// dolt_diff_{table} results. This is legacy code at this point, as the DOLT format is what we'll support going forward.
+type ldDiffRowItr struct {
 	ad       diff.RowDiffer
 	diffSrc  *diff.RowDiffSource
 	joiner   *rowconv.Joiner
@@ -43,7 +45,7 @@ type diffRowItr struct {
 	toCommitInfo commitInfo
 }
-var _ sql.RowIter = &diffRowItr{}
+var _ sql.RowIter = &ldDiffRowItr{}
 type commitInfo struct {
 	name types.String
@@ -52,7 +54,7 @@ type commitInfo struct {
 	dateTag uint64
 }
-func newNomsDiffIter(ctx *sql.Context, ddb *doltdb.DoltDB, joiner *rowconv.Joiner, dp DiffPartition, lookup sql.IndexLookup) (*diffRowItr, error) {
+func newLdDiffIter(ctx *sql.Context, ddb *doltdb.DoltDB, joiner *rowconv.Joiner, dp DiffPartition, lookup sql.IndexLookup) (*ldDiffRowItr, error) {
 	fromData, fromSch, err := tableData(ctx, dp.from, ddb)
 	if err != nil {
@@ -110,7 +112,7 @@ func newNomsDiffIter(ctx *sql.Context, ddb *doltdb.DoltDB, joiner *rowconv.Joine
 	src := diff.NewRowDiffSource(rd, joiner, ctx.Warn)
 	src.AddInputRowConversion(fromConv, toConv)
-	return &diffRowItr{
+	return &ldDiffRowItr{
 		ad:      rd,
 		diffSrc: src,
 		joiner:  joiner,
@@ -121,7 +123,7 @@ func newNomsDiffIter(ctx *sql.Context, ddb *doltdb.DoltDB, joiner *rowconv.Joine
 }
 // Next returns the next row
-func (itr *diffRowItr) Next(ctx *sql.Context) (sql.Row, error) {
+func (itr *ldDiffRowItr) Next(ctx *sql.Context) (sql.Row, error) {
 	r, err := itr.diffSrc.NextDiff()
 	if err != nil {
@@ -180,7 +182,7 @@ func (itr *diffRowItr) Next(ctx *sql.Context) (sql.Row, error) {
 }
 // Close closes the iterator
-func (itr *diffRowItr) Close(*sql.Context) (err error) {
+func (itr *ldDiffRowItr) Close(*sql.Context) (err error) {
 	defer itr.ad.Close()
 	defer func() {
 		closeErr := itr.diffSrc.Close()
@@ -203,7 +205,6 @@ type prollyDiffIter struct {
 	fromSch, toSch             schema.Schema
 	targetFromSch, targetToSch schema.Schema
 	fromConverter, toConverter ProllyRowConverter
-	fromVD, toVD               val.TupleDesc
 	keyless bool
 	fromCm  commitInfo2
@@ -285,8 +286,6 @@ func newProllyDiffIter(ctx *sql.Context, dp DiffPartition, targetFromSchema, tar
 		return prollyDiffIter{}, err
 	}
-	fromVD := fsch.GetValueDescriptor()
-	toVD := tsch.GetValueDescriptor()
 	keyless := schema.IsKeyless(targetFromSchema) && schema.IsKeyless(targetToSchema)
 	child, cancel := context.WithCancel(ctx)
 	iter := prollyDiffIter{
@@ -298,8 +297,6 @@ func newProllyDiffIter(ctx *sql.Context, dp DiffPartition, targetFromSchema, tar
 		targetToSch:   targetToSchema,
 		fromConverter: fromConverter,
 		toConverter:   toConverter,
-		fromVD:        fromVD,
-		toVD:          toVD,
 		keyless: keyless,
 		fromCm:  fromCm,
 		toCm:    toCm,
@@ -372,7 +369,7 @@ func (itr prollyDiffIter) queueRows(ctx context.Context) {
 // todo(andy): copy string fields
 func (itr prollyDiffIter) makeDiffRowItr(ctx context.Context, d tree.Diff) (*repeatingRowIter, error) {
 	if !itr.keyless {
-		r, err := itr.getDiffRow(ctx, d)
+		r, err := itr.getDiffTableRow(ctx, d)
 		if err != nil {
 			return nil, err
 		}
@@ -404,7 +401,7 @@ func (itr prollyDiffIter) getDiffRowAndCardinality(ctx context.Context, d tree.D
 		}
 	}
-	r, err = itr.getDiffRow(ctx, d)
+	r, err = itr.getDiffTableRow(ctx, d)
 	if err != nil {
 		return nil, 0, err
 	}
@@ -412,7 +409,9 @@ func (itr prollyDiffIter) getDiffRowAndCardinality(ctx context.Context, d tree.D
 	return r, n, nil
 }
-func (itr prollyDiffIter) getDiffRow(ctx context.Context, dif tree.Diff) (row sql.Row, err error) {
+// getDiffTableRow returns a row for the diff table given the diff type and the row from the source and target tables. The
+// output schema is intended for the dolt_diff_* tables and the dolt_diff table function.
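// As an illustration (assuming the usual dolt_diff_* layout of to_ columns,
// then from_ columns, then the diff type): for a table t(pk, c1), a modified
// row comes back roughly as
//
//	(to_pk, to_c1, to_commit, to_commit_date, from_pk, from_c1, from_commit, from_commit_date, 'modified')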
+func (itr prollyDiffIter) getDiffTableRow(ctx context.Context, dif tree.Diff) (row sql.Row, err error) { tLen := schemaSize(itr.targetToSch) fLen := schemaSize(itr.targetFromSch) @@ -500,16 +499,15 @@ func maybeTime(t *time.Time) interface{} { var _ sql.RowIter = (*diffPartitionRowIter)(nil) type diffPartitionRowIter struct { - diffPartitions *DiffPartitions ddb *doltdb.DoltDB joiner *rowconv.Joiner - currentPartition *sql.Partition + currentPartition *DiffPartition currentRowIter *sql.RowIter } -func NewDiffPartitionRowIter(partition sql.Partition, ddb *doltdb.DoltDB, joiner *rowconv.Joiner) *diffPartitionRowIter { +func NewDiffPartitionRowIter(partition *DiffPartition, ddb *doltdb.DoltDB, joiner *rowconv.Joiner) *diffPartitionRowIter { return &diffPartitionRowIter{ - currentPartition: &partition, + currentPartition: partition, ddb: ddb, joiner: joiner, } @@ -518,16 +516,10 @@ func NewDiffPartitionRowIter(partition sql.Partition, ddb *doltdb.DoltDB, joiner func (itr *diffPartitionRowIter) Next(ctx *sql.Context) (sql.Row, error) { for { if itr.currentPartition == nil { - nextPartition, err := itr.diffPartitions.Next(ctx) - if err != nil { - return nil, err - } - itr.currentPartition = &nextPartition + return nil, io.EOF } - if itr.currentRowIter == nil { - dp := (*itr.currentPartition).(DiffPartition) - rowIter, err := dp.GetRowIter(ctx, itr.ddb, itr.joiner, sql.IndexLookup{}) + rowIter, err := itr.currentPartition.GetRowIter(ctx, itr.ddb, itr.joiner, sql.IndexLookup{}) if err != nil { return nil, err } @@ -538,12 +530,7 @@ func (itr *diffPartitionRowIter) Next(ctx *sql.Context) (sql.Row, error) { if err == io.EOF { itr.currentPartition = nil itr.currentRowIter = nil - - if itr.diffPartitions == nil { - return nil, err - } - - continue + return nil, err } else if err != nil { return nil, err } else { diff --git a/go/libraries/doltcore/sqle/dtables/diff_table.go b/go/libraries/doltcore/sqle/dtables/diff_table.go index 1651c4c56b..f20e95f71a 100644 --- a/go/libraries/doltcore/sqle/dtables/diff_table.go +++ b/go/libraries/doltcore/sqle/dtables/diff_table.go @@ -674,7 +674,7 @@ func (dp DiffPartition) GetRowIter(ctx *sql.Context, ddb *doltdb.DoltDB, joiner if types.IsFormat_DOLT(ddb.Format()) { return newProllyDiffIter(ctx, dp, dp.fromSch, dp.toSch) } else { - return newNomsDiffIter(ctx, ddb, joiner, dp, lookup) + return newLdDiffIter(ctx, ddb, joiner, dp, lookup) } } diff --git a/go/libraries/doltcore/sqle/dtables/workspace.go b/go/libraries/doltcore/sqle/dtables/workspace.go new file mode 100644 index 0000000000..30481c71be --- /dev/null +++ b/go/libraries/doltcore/sqle/dtables/workspace.go @@ -0,0 +1,587 @@ +// Copyright 2024 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+package dtables
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"io"
+
+	"github.com/dolthub/go-mysql-server/sql"
+	sqltypes "github.com/dolthub/go-mysql-server/sql/types"
+
+	"github.com/dolthub/dolt/go/libraries/doltcore/diff"
+	"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
+	"github.com/dolthub/dolt/go/libraries/doltcore/doltdb/durable"
+	"github.com/dolthub/dolt/go/libraries/doltcore/schema"
+	"github.com/dolthub/dolt/go/libraries/doltcore/sqle/index"
+	"github.com/dolthub/dolt/go/libraries/doltcore/sqle/resolve"
+	"github.com/dolthub/dolt/go/libraries/doltcore/sqle/sqlutil"
+	"github.com/dolthub/dolt/go/store/prolly"
+	"github.com/dolthub/dolt/go/store/prolly/tree"
+	"github.com/dolthub/dolt/go/store/types"
+	"github.com/dolthub/dolt/go/store/val"
+)
+
+type WorkspaceTable struct {
+	roots         doltdb.Roots
+	tableName     string
+	nomsSchema    schema.Schema
+	sqlSchema     sql.Schema
+	stagedDeltas  *diff.TableDelta
+	workingDeltas *diff.TableDelta
+
+	headSchema schema.Schema
+}
+
+var _ sql.Table = (*WorkspaceTable)(nil)
+
+func NewWorkspaceTable(ctx *sql.Context, tblName string, roots doltdb.Roots) (sql.Table, error) {
+	stageDlt, err := diff.GetTableDeltas(ctx, roots.Head, roots.Staged)
+	if err != nil {
+		return nil, err
+	}
+	var stgDel *diff.TableDelta
+	for _, delta := range stageDlt {
+		if delta.FromName.Name == tblName || delta.ToName.Name == tblName {
+			stgDel = &delta
+			break
+		}
+	}
+
+	workingDlt, err := diff.GetTableDeltas(ctx, roots.Head, roots.Working)
+	if err != nil {
+		return nil, err
+	}
+
+	var wkDel *diff.TableDelta
+	for _, delta := range workingDlt {
+		if delta.FromName.Name == tblName || delta.ToName.Name == tblName {
+			wkDel = &delta
+			break
+		}
+	}
+
+	if wkDel == nil && stgDel == nil {
+		emptyTable := emptyWorkspaceTable{tableName: tblName}
+		return &emptyTable, nil
+	}
+
+	var fromSch schema.Schema
+	if stgDel != nil && stgDel.FromTable != nil {
+		fromSch, err = stgDel.FromTable.GetSchema(ctx)
+		if err != nil {
+			return nil, err
+		}
+	} else if wkDel != nil && wkDel.FromTable != nil {
+		fromSch, err = wkDel.FromTable.GetSchema(ctx)
+		if err != nil {
+			return nil, err
+		}
+	}
+
+	toSch := fromSch
+	if wkDel != nil && wkDel.ToTable != nil {
+		toSch, err = wkDel.ToTable.GetSchema(ctx)
+		if err != nil {
+			return nil, err
+		}
+	} else if stgDel != nil && stgDel.ToTable != nil {
+		toSch, err = stgDel.ToTable.GetSchema(ctx)
+		if err != nil {
+			return nil, err
+		}
+	}
+	if fromSch == nil && toSch == nil {
+		return nil, errors.New("Runtime error: from and to schemas are both nil")
+	}
+	if fromSch == nil {
+		fromSch = toSch
+	}
+
+	totalSch, err := workspaceSchema(fromSch, toSch)
+	if err != nil {
+		return nil, err
+	}
+	finalSch, err := sqlutil.FromDoltSchema("", "", totalSch)
+	if err != nil {
+		return nil, err
+	}
+
+	return &WorkspaceTable{
+		roots:         roots,
+		tableName:     tblName,
+		nomsSchema:    totalSch,
+		sqlSchema:     finalSch.Schema,
+		stagedDeltas:  stgDel,
+		workingDeltas: wkDel,
+		headSchema:    fromSch,
+	}, nil
+}
+
+func (wt *WorkspaceTable) Name() string {
+	return doltdb.DoltWorkspaceTablePrefix + wt.tableName
+}
+
+func (wt *WorkspaceTable) String() string {
+	return wt.Name()
+}
+
+func (wt *WorkspaceTable) Schema() sql.Schema {
+	return wt.sqlSchema
+}
+
+// workspaceSchema returns the schema for the dolt_workspace_* table based on the schemas from the from and to tables.
+// Either may be nil, in which case the nil argument will use the schema of the non-nil argument.
+func workspaceSchema(fromSch, toSch schema.Schema) (schema.Schema, error) {
+	if fromSch == nil && toSch == nil {
+		return nil, errors.New("Runtime error: non-nil argument required to workspaceSchema")
+	} else if fromSch == nil {
+		fromSch = toSch
+	} else if toSch == nil {
+		toSch = fromSch
+	}
+
+	cols := make([]schema.Column, 0, 3+toSch.GetAllCols().Size()+fromSch.GetAllCols().Size())
+
+	cols = append(cols,
+		schema.NewColumn("id", 0, types.UintKind, true),
+		schema.NewColumn("staged", 0, types.BoolKind, false),
+		schema.NewColumn("diff_type", 0, types.StringKind, false),
+	)
+
+	transformer := func(sch schema.Schema, namer func(string) string) error {
+		return sch.GetAllCols().Iter(func(tag uint64, col schema.Column) (stop bool, err error) {
+			c, err := schema.NewColumnWithTypeInfo(
+				namer(col.Name),
+				uint64(len(cols)),
+				col.TypeInfo,
+				false,
+				col.Default,
+				false,
+				col.Comment)
+			if err != nil {
+				return true, err
+			}
+			cols = append(cols, c)
+			return false, nil
+		})
+	}
+
+	err := transformer(toSch, diff.ToColNamer)
+	if err != nil {
+		return nil, err
+	}
+	err = transformer(fromSch, diff.FromColNamer)
+	if err != nil {
+		return nil, err
+	}
+
+	return schema.UnkeyedSchemaFromCols(schema.NewColCollection(cols...)), nil
+}
+
+func (wt *WorkspaceTable) Collation() sql.CollationID { return sql.Collation_Default }
+
+type WorkspacePartitionItr struct {
+	partition *WorkspacePartition
+}
+
+func (w *WorkspacePartitionItr) Close(_ *sql.Context) error {
+	return nil
+}
+
+func (w *WorkspacePartitionItr) Next(_ *sql.Context) (sql.Partition, error) {
+	if w.partition == nil {
+		return nil, io.EOF
+	}
+	ans := w.partition
+	w.partition = nil
+	return ans, nil
+}
+
+type WorkspacePartition struct {
+	name       string
+	base       *doltdb.Table
+	baseSch    schema.Schema
+	working    *doltdb.Table
+	workingSch schema.Schema
+	staging    *doltdb.Table
+	stagingSch schema.Schema
+}
+
+var _ sql.Partition = (*WorkspacePartition)(nil)
+
+func (w *WorkspacePartition) Key() []byte {
+	return []byte(w.name)
+}
+
+func (wt *WorkspaceTable) Partitions(ctx *sql.Context) (sql.PartitionIter, error) {
+	_, baseTable, baseTableExists, err := resolve.Table(ctx, wt.roots.Head, wt.tableName)
+	if err != nil {
+		return nil, err
+	}
+	var baseSchema schema.Schema = schema.EmptySchema
+	if baseTableExists {
+		if baseSchema, err = baseTable.GetSchema(ctx); err != nil {
+			return nil, err
+		}
+	}
+
+	_, stagingTable, stagingTableExists, err := resolve.Table(ctx, wt.roots.Staged, wt.tableName)
+	if err != nil {
+		return nil, err
+	}
+	var stagingSchema schema.Schema = schema.EmptySchema
+	if stagingTableExists {
+		if stagingSchema, err = stagingTable.GetSchema(ctx); err != nil {
+			return nil, err
+		}
+	}
+
+	_, workingTable, workingTableExists, err := resolve.Table(ctx, wt.roots.Working, wt.tableName)
+	if err != nil {
+		return nil, err
+	}
+	var workingSchema schema.Schema = schema.EmptySchema
+	if workingTableExists {
+		if workingSchema, err = workingTable.GetSchema(ctx); err != nil {
+			return nil, err
+		}
+	}
+
+	part := WorkspacePartition{
+		name:       wt.Name(),
+		base:       baseTable,
+		baseSch:    baseSchema,
+		staging:    stagingTable,
+		stagingSch: stagingSchema,
+		working:    workingTable,
+		workingSch: workingSchema,
+	}
+
+	return &WorkspacePartitionItr{&part}, nil
+}
+
+func (wt *WorkspaceTable) PartitionRows(ctx *sql.Context, part sql.Partition) (sql.RowIter, error) {
+	wp, ok := part.(*WorkspacePartition)
+	if !ok {
+		return nil,
fmt.Errorf("Runtime Exception: expected a WorkspacePartition, got %T", part) + } + + return newWorkspaceDiffIter(ctx, *wp) +} + +// workspaceDiffIter enables the iteration over the diff information between the HEAD, STAGING, and WORKING roots. +type workspaceDiffIter struct { + base prolly.Map + working prolly.Map + staging prolly.Map + + baseConverter ProllyRowConverter + workingConverter ProllyRowConverter + stagingConverter ProllyRowConverter + + tgtBaseSch schema.Schema + tgtWorkingSch schema.Schema + tgtStagingSch schema.Schema + + rows chan sql.Row + errChan chan error + cancel func() +} + +func (itr workspaceDiffIter) Next(ctx *sql.Context) (sql.Row, error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case err := <-itr.errChan: + return nil, err + case row, ok := <-itr.rows: + if !ok { + return nil, io.EOF + } + return row, nil + } +} + +func (itr workspaceDiffIter) Close(c *sql.Context) error { + itr.cancel() + return nil +} + +// getWorkspaceTableRow returns a row for the diff table given the diff type and the row from the source and target tables. The +// output schema is intended for dolt_workspace_* tables. +func getWorkspaceTableRow( + ctx context.Context, + rowId int, + staged bool, + toSch schema.Schema, + fromSch schema.Schema, + toConverter ProllyRowConverter, + fromConverter ProllyRowConverter, + dif tree.Diff, +) (row sql.Row, err error) { + tLen := schemaSize(toSch) + fLen := schemaSize(fromSch) + + if fLen == 0 && dif.Type == tree.AddedDiff { + fLen = tLen + } else if tLen == 0 && dif.Type == tree.RemovedDiff { + tLen = fLen + } + + row = make(sql.Row, 3+tLen+fLen) + + row[0] = rowId + row[1] = staged + row[2] = diffTypeString(dif) + + idx := 3 + + if dif.Type != tree.RemovedDiff { + err = toConverter.PutConverted(ctx, val.Tuple(dif.Key), val.Tuple(dif.To), row[idx:idx+tLen]) + if err != nil { + return nil, err + } + } + idx += tLen + + if dif.Type != tree.AddedDiff { + err = fromConverter.PutConverted(ctx, val.Tuple(dif.Key), val.Tuple(dif.From), row[idx:idx+fLen]) + if err != nil { + return nil, err + } + } + + return row, nil +} + +// queueWorkspaceRows is similar to prollyDiffIter.queueRows, but for workspaces. It performs two seperate calls +// to prolly.DiffMaps, one for staging and one for working. The end result is queueing the rows from both maps +// into the "rows" channel of the workspaceDiffIter. 
+func (itr *workspaceDiffIter) queueWorkspaceRows(ctx context.Context) { + k1 := schema.EmptySchema == itr.tgtStagingSch || schema.IsKeyless(itr.tgtStagingSch) + k2 := schema.EmptySchema == itr.tgtBaseSch || schema.IsKeyless(itr.tgtBaseSch) + k3 := schema.EmptySchema == itr.tgtWorkingSch || schema.IsKeyless(itr.tgtWorkingSch) + + keyless := k1 && k2 && k3 + + idx := 0 + + err := prolly.DiffMaps(ctx, itr.base, itr.staging, false, func(ctx context.Context, d tree.Diff) error { + rows, err := itr.makeWorkspaceRows(ctx, idx, true, itr.tgtStagingSch, itr.tgtBaseSch, keyless, itr.stagingConverter, itr.baseConverter, d) + if err != nil { + return err + } + for _, r := range rows { + select { + case <-ctx.Done(): + return ctx.Err() + case itr.rows <- r: + idx++ + continue + } + } + return nil + }) + + if err != nil && err != io.EOF { + select { + case <-ctx.Done(): + case itr.errChan <- err: + } + return + } + + err = prolly.DiffMaps(ctx, itr.staging, itr.working, false, func(ctx context.Context, d tree.Diff) error { + rows, err := itr.makeWorkspaceRows(ctx, idx, false, itr.tgtWorkingSch, itr.tgtStagingSch, keyless, itr.workingConverter, itr.stagingConverter, d) + if err != nil { + return err + } + for _, r := range rows { + select { + case <-ctx.Done(): + return ctx.Err() + case itr.rows <- r: + idx++ + continue + } + } + return nil + }) + + // we need to drain itr.rows before returning io.EOF + close(itr.rows) +} + +// makeWorkspaceRows takes the diff information from the prolly.DiffMaps and converts it into a slice of rows. In the case +// of tables with a primary key, this method will return a single row. For tables without a primary key, it will return +// 1 or more rows. The rows returned are in the full schema that the workspace table returns, so the workspace table columns +// (id, staged, diff_type) are included in the returned rows with the populated values. +func (itr *workspaceDiffIter) makeWorkspaceRows( + ctx context.Context, + idx int, + staging bool, + toSch schema.Schema, + fromSch schema.Schema, + keyless bool, + toConverter ProllyRowConverter, + fromConverter ProllyRowConverter, + d tree.Diff, +) ([]sql.Row, error) { + n := uint64(1) + if keyless { + switch d.Type { + case tree.AddedDiff: + n = val.ReadKeylessCardinality(val.Tuple(d.To)) + case tree.RemovedDiff: + n = val.ReadKeylessCardinality(val.Tuple(d.From)) + case tree.ModifiedDiff: + fN := val.ReadKeylessCardinality(val.Tuple(d.From)) + tN := val.ReadKeylessCardinality(val.Tuple(d.To)) + if fN < tN { + n = tN - fN + d.Type = tree.AddedDiff + } else { + n = fN - tN + d.Type = tree.RemovedDiff + } + } + } + + ans := make([]sql.Row, n) + for i := uint64(0); i < n; i++ { + r, err := getWorkspaceTableRow(ctx, idx, staging, toSch, fromSch, toConverter, fromConverter, d) + if err != nil { + return nil, err + } + ans[i] = r + idx++ + } + return ans, nil +} + +// newWorkspaceDiffIter takes a WorkspacePartition and returns a workspaceDiffIter. The workspaceDiffIter is used to iterate +// over the diff information from the prolly.DiffMaps. 
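// A minimal usage sketch (mirroring PartitionRows above): build the iterator
// from a WorkspacePartition, then drain it until io.EOF:
//
//	iter, err := newWorkspaceDiffIter(ctx, *wp)
//	if err != nil {
//		return nil, err
//	}
//	for {
//		row, err := iter.Next(ctx)
//		if err == io.EOF {
//			break
//		}
//		// ... use row
//	}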
+func newWorkspaceDiffIter(ctx *sql.Context, wp WorkspacePartition) (workspaceDiffIter, error) { + var base, working, staging prolly.Map + + if wp.base != nil { + idx, err := wp.base.GetRowData(ctx) + if err != nil { + return workspaceDiffIter{}, err + } + base = durable.ProllyMapFromIndex(idx) + } + + if wp.staging != nil { + idx, err := wp.staging.GetRowData(ctx) + if err != nil { + return workspaceDiffIter{}, err + } + staging = durable.ProllyMapFromIndex(idx) + } + + if wp.working != nil { + idx, err := wp.working.GetRowData(ctx) + if err != nil { + return workspaceDiffIter{}, err + } + working = durable.ProllyMapFromIndex(idx) + } + + var nodeStore tree.NodeStore + if wp.base != nil { + nodeStore = wp.base.NodeStore() + } else if wp.staging != nil { + nodeStore = wp.staging.NodeStore() + } else if wp.working != nil { + nodeStore = wp.working.NodeStore() + } else { + return workspaceDiffIter{}, errors.New("no base, staging, or working table") + } + + baseConverter, err := NewProllyRowConverter(wp.baseSch, wp.baseSch, ctx.Warn, nodeStore) + if err != nil { + return workspaceDiffIter{}, err + } + + stagingConverter, err := NewProllyRowConverter(wp.stagingSch, wp.stagingSch, ctx.Warn, nodeStore) + if err != nil { + return workspaceDiffIter{}, err + } + + workingConverter, err := NewProllyRowConverter(wp.workingSch, wp.workingSch, ctx.Warn, nodeStore) + if err != nil { + return workspaceDiffIter{}, err + } + + child, cancel := context.WithCancel(ctx) + iter := workspaceDiffIter{ + base: base, + working: working, + staging: staging, + + tgtBaseSch: wp.baseSch, + tgtWorkingSch: wp.workingSch, + tgtStagingSch: wp.stagingSch, + + baseConverter: baseConverter, + workingConverter: workingConverter, + stagingConverter: stagingConverter, + + rows: make(chan sql.Row, 64), + errChan: make(chan error), + cancel: cancel, + } + + go func() { + iter.queueWorkspaceRows(child) + }() + + return iter, nil +} + +type emptyWorkspaceTable struct { + tableName string +} + +var _ sql.Table = (*emptyWorkspaceTable)(nil) + +func (e emptyWorkspaceTable) Name() string { + return doltdb.DoltWorkspaceTablePrefix + e.tableName +} + +func (e emptyWorkspaceTable) String() string { + return e.Name() +} + +func (e emptyWorkspaceTable) Schema() sql.Schema { + return []*sql.Column{ + {Name: "id", Type: sqltypes.Int32, Nullable: false}, + {Name: "staged", Type: sqltypes.Boolean, Nullable: false}, + } +} + +func (e emptyWorkspaceTable) Collation() sql.CollationID { return sql.Collation_Default } + +func (e emptyWorkspaceTable) Partitions(c *sql.Context) (sql.PartitionIter, error) { + return index.SinglePartitionIterFromNomsMap(nil), nil +} + +func (e emptyWorkspaceTable) PartitionRows(c *sql.Context, partition sql.Partition) (sql.RowIter, error) { + return sql.RowsToRowIter(), nil +} diff --git a/go/libraries/doltcore/sqle/enginetest/dolt_engine_test.go b/go/libraries/doltcore/sqle/enginetest/dolt_engine_test.go index 64ef85798a..4ee3f350c3 100644 --- a/go/libraries/doltcore/sqle/enginetest/dolt_engine_test.go +++ b/go/libraries/doltcore/sqle/enginetest/dolt_engine_test.go @@ -2052,3 +2052,8 @@ func TestStatsAutoRefreshConcurrency(t *testing.T) { wg.Wait() } } + +func TestDoltWorkspace(t *testing.T) { + harness := newDoltEnginetestHarness(t) + RunDoltWorkspaceTests(t, harness) +} diff --git a/go/libraries/doltcore/sqle/enginetest/dolt_engine_tests.go b/go/libraries/doltcore/sqle/enginetest/dolt_engine_tests.go index c84dc35b33..56cc97d6ce 100755 --- a/go/libraries/doltcore/sqle/enginetest/dolt_engine_tests.go +++ 
b/go/libraries/doltcore/sqle/enginetest/dolt_engine_tests.go @@ -1960,3 +1960,13 @@ func RunDoltReflogTestsPrepared(t *testing.T, h DoltEnginetestHarness) { }() } } + +func RunDoltWorkspaceTests(t *testing.T, h DoltEnginetestHarness) { + for _, script := range DoltWorkspaceScriptTests { + func() { + h = h.NewHarness(t) + defer h.Close() + enginetest.TestScript(t, h, script) + }() + } +} diff --git a/go/libraries/doltcore/sqle/enginetest/dolt_queries.go b/go/libraries/doltcore/sqle/enginetest/dolt_queries.go index 6e6f4bdf51..52c61dbbcd 100644 --- a/go/libraries/doltcore/sqle/enginetest/dolt_queries.go +++ b/go/libraries/doltcore/sqle/enginetest/dolt_queries.go @@ -7176,6 +7176,7 @@ var DoltSystemVariables = []queries.ScriptTest{ {"dolt_remote_branches"}, {"dolt_remotes"}, {"dolt_status"}, + {"dolt_workspace_test"}, {"test"}, }, }, diff --git a/go/libraries/doltcore/sqle/enginetest/dolt_queries_merge.go b/go/libraries/doltcore/sqle/enginetest/dolt_queries_merge.go index e09ecf46a4..07530feb53 100644 --- a/go/libraries/doltcore/sqle/enginetest/dolt_queries_merge.go +++ b/go/libraries/doltcore/sqle/enginetest/dolt_queries_merge.go @@ -2438,6 +2438,86 @@ var MergeScripts = []queries.ScriptTest{ }, }, }, + { + // Ensure that column defaults are normalized to the same thing, so they merge with no issue + Name: "merge with float column default", + SetUpScript: []string{ + "create table t (f float);", + "call dolt_commit('-Am', 'setup');", + "call dolt_branch('other');", + "alter table t modify column f float default '1.00';", + "call dolt_commit('-Am', 'change default on main');", + "call dolt_checkout('other');", + "alter table t modify column f float default '1.000000000';", + "call dolt_commit('-Am', 'change default on other');", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "call dolt_merge('main')", + Expected: []sql.Row{{doltCommit, 0, 0, "merge successful"}}, + }, + }, + }, + { + // Ensure that column defaults are normalized to the same thing, so they merge with no issue + Name: "merge with float 1.23 column default", + SetUpScript: []string{ + "create table t (f float);", + "call dolt_commit('-Am', 'setup');", + "call dolt_branch('other');", + "alter table t modify column f float default '1.23000';", + "call dolt_commit('-Am', 'change default on main');", + "call dolt_checkout('other');", + "alter table t modify column f float default '1.23000000000';", + "call dolt_commit('-Am', 'change default on other');", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "call dolt_merge('main')", + Expected: []sql.Row{{doltCommit, 0, 0, "merge successful"}}, + }, + }, + }, + { + // Ensure that column defaults are normalized to the same thing, so they merge with no issue + Name: "merge with decimal 1.23 column default", + SetUpScript: []string{ + "create table t (d decimal(20, 10));", + "call dolt_commit('-Am', 'setup');", + "call dolt_branch('other');", + "alter table t modify column d decimal(20, 10) default '1.23000';", + "call dolt_commit('-Am', 'change default on main');", + "call dolt_checkout('other');", + "alter table t modify column d decimal(20, 10) default '1.23000000000';", + "call dolt_commit('-Am', 'change default on other');", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "call dolt_merge('main')", + Expected: []sql.Row{{doltCommit, 0, 0, "merge successful"}}, + }, + }, + }, + { + // Ensure that column defaults are normalized to the same thing, so they merge with no issue + Name: "merge with different types", + SetUpScript: []string{ + 
"create table t (f float);", + "call dolt_commit('-Am', 'setup');", + "call dolt_branch('other');", + "alter table t modify column f float default 1.23;", + "call dolt_commit('-Am', 'change default on main');", + "call dolt_checkout('other');", + "alter table t modify column f float default '1.23';", + "call dolt_commit('-Am', 'change default on other');", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "call dolt_merge('main')", + Expected: []sql.Row{{doltCommit, 0, 0, "merge successful"}}, + }, + }, + }, } var KeylessMergeCVsAndConflictsScripts = []queries.ScriptTest{ diff --git a/go/libraries/doltcore/sqle/enginetest/dolt_queries_rebase.go b/go/libraries/doltcore/sqle/enginetest/dolt_queries_rebase.go index 988b552407..d6565bf163 100644 --- a/go/libraries/doltcore/sqle/enginetest/dolt_queries_rebase.go +++ b/go/libraries/doltcore/sqle/enginetest/dolt_queries_rebase.go @@ -239,6 +239,133 @@ var DoltRebaseScriptTests = []queries.ScriptTest{ }, }, }, + { + Name: "dolt_rebase: rebased commit becomes empty; --empty not specified", + SetUpScript: []string{ + "create table t (pk int primary key);", + "call dolt_commit('-Am', 'creating table t');", + "call dolt_branch('branch1');", + + "insert into t values (0);", + "call dolt_commit('-am', 'inserting row 0 on main');", + + "call dolt_checkout('branch1');", + "insert into t values (0);", + "call dolt_commit('-am', 'inserting row 0 on branch1');", + "insert into t values (10);", + "call dolt_commit('-am', 'inserting row 10 on branch1');", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "call dolt_rebase('-i', 'main');", + Expected: []sql.Row{{0, "interactive rebase started on branch dolt_rebase_branch1; " + + "adjust the rebase plan in the dolt_rebase table, then " + + "continue rebasing by calling dolt_rebase('--continue')"}}, + }, + { + Query: "select active_branch();", + Expected: []sql.Row{{"dolt_rebase_branch1"}}, + }, + { + Query: "call dolt_rebase('--continue');", + Expected: []sql.Row{{0, "Successfully rebased and updated refs/heads/branch1"}}, + }, + { + Query: "select message from dolt_log;", + Expected: []sql.Row{ + {"inserting row 10 on branch1"}, + {"inserting row 0 on main"}, + {"creating table t"}, + {"Initialize data repository"}, + }, + }, + }, + }, + { + Name: "dolt_rebase: rebased commit becomes empty; --empty=keep", + SetUpScript: []string{ + "create table t (pk int primary key);", + "call dolt_commit('-Am', 'creating table t');", + "call dolt_branch('branch1');", + + "insert into t values (0);", + "call dolt_commit('-am', 'inserting row 0 on main');", + + "call dolt_checkout('branch1');", + "insert into t values (0);", + "call dolt_commit('-am', 'inserting row 0 on branch1');", + "insert into t values (10);", + "call dolt_commit('-am', 'inserting row 10 on branch1');", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "call dolt_rebase('-i', '--empty', 'keep', 'main');", + Expected: []sql.Row{{0, "interactive rebase started on branch dolt_rebase_branch1; " + + "adjust the rebase plan in the dolt_rebase table, then " + + "continue rebasing by calling dolt_rebase('--continue')"}}, + }, + { + Query: "select active_branch();", + Expected: []sql.Row{{"dolt_rebase_branch1"}}, + }, + { + Query: "call dolt_rebase('--continue');", + Expected: []sql.Row{{0, "Successfully rebased and updated refs/heads/branch1"}}, + }, + { + Query: "select message from dolt_log;", + Expected: []sql.Row{ + {"inserting row 10 on branch1"}, + {"inserting row 0 on branch1"}, + {"inserting row 0 on main"}, + 
{"creating table t"}, + {"Initialize data repository"}, + }, + }, + }, + }, + { + Name: "dolt_rebase: rebased commit becomes empty; --empty=drop", + SetUpScript: []string{ + "create table t (pk int primary key);", + "call dolt_commit('-Am', 'creating table t');", + "call dolt_branch('branch1');", + + "insert into t values (0);", + "call dolt_commit('-am', 'inserting row 0 on main');", + + "call dolt_checkout('branch1');", + "insert into t values (0);", + "call dolt_commit('-am', 'inserting row 0 on branch1');", + "insert into t values (10);", + "call dolt_commit('-am', 'inserting row 10 on branch1');", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "call dolt_rebase('-i', '--empty', 'drop', 'main');", + Expected: []sql.Row{{0, "interactive rebase started on branch dolt_rebase_branch1; " + + "adjust the rebase plan in the dolt_rebase table, then " + + "continue rebasing by calling dolt_rebase('--continue')"}}, + }, + { + Query: "select active_branch();", + Expected: []sql.Row{{"dolt_rebase_branch1"}}, + }, + { + Query: "call dolt_rebase('--continue');", + Expected: []sql.Row{{0, "Successfully rebased and updated refs/heads/branch1"}}, + }, + { + Query: "select message from dolt_log;", + Expected: []sql.Row{ + {"inserting row 10 on branch1"}, + {"inserting row 0 on main"}, + {"creating table t"}, + {"Initialize data repository"}, + }, + }, + }, + }, { Name: "dolt_rebase: no commits to rebase", SetUpScript: []string{ diff --git a/go/libraries/doltcore/sqle/enginetest/dolt_queries_workspace.go b/go/libraries/doltcore/sqle/enginetest/dolt_queries_workspace.go new file mode 100644 index 0000000000..4b80059204 --- /dev/null +++ b/go/libraries/doltcore/sqle/enginetest/dolt_queries_workspace.go @@ -0,0 +1,336 @@ +// Copyright 2024 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package enginetest + +import ( + "github.com/dolthub/go-mysql-server/enginetest/queries" + "github.com/dolthub/go-mysql-server/sql" +) + +var DoltWorkspaceScriptTests = []queries.ScriptTest{ + + { + Name: "dolt_workspace_* multiple edits of a single row", + SetUpScript: []string{ + "create table tbl (pk int primary key, val int);", + "call dolt_commit('-Am', 'creating table t');", + + "insert into tbl values (42,42);", + "insert into tbl values (43,43);", + "call dolt_commit('-am', 'inserting 2 rows at HEAD');", + + "update tbl set val=51 where pk=42;", + "call dolt_add('tbl');", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "select * from dolt_workspace_tbl", + Expected: []sql.Row{ + {0, true, "modified", 42, 51, 42, 42}, + }, + }, + { + Query: "update tbl set val= 108 where pk = 42;", + }, + { + Query: "select * from dolt_workspace_tbl", + Expected: []sql.Row{ + {0, true, "modified", 42, 51, 42, 42}, + {1, false, "modified", 42, 108, 42, 51}, + }, + }, + { + Query: "call dolt_add('tbl');", + }, + { + Query: "select * from dolt_workspace_tbl", + Expected: []sql.Row{ + {0, true, "modified", 42, 108, 42, 42}, + }, + }, + }, + }, + { + Name: "dolt_workspace_* single unstaged row", + SetUpScript: []string{ + "create table tbl (pk int primary key, val int);", + "call dolt_commit('-Am', 'creating table t');", + + "insert into tbl values (42,42);", + "insert into tbl values (43,43);", + "call dolt_commit('-am', 'inserting 2 rows at HEAD');", + + "update tbl set val=51 where pk=42;", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "select * from dolt_workspace_tbl", + Expected: []sql.Row{ + {0, false, "modified", 42, 51, 42, 42}, + }, + }, + }, + }, + { + Name: "dolt_workspace_* inserted row", + SetUpScript: []string{ + "create table tbl (pk int primary key, val int);", + "call dolt_commit('-Am', 'creating table t');", + + "insert into tbl values (42,42);", + "insert into tbl values (43,43);", + "call dolt_commit('-am', 'inserting 2 rows at HEAD');", + "insert into tbl values (44,44);", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "select * from dolt_workspace_tbl", + Expected: []sql.Row{ + {0, false, "added", 44, 44, nil, nil}, + }, + }, + { + Query: "call dolt_add('tbl');", + }, + { + Query: "select * from dolt_workspace_tbl", + Expected: []sql.Row{ + {0, true, "added", 44, 44, nil, nil}, + }, + }, + { + Query: "update tbl set val = 108 where pk = 44;", + }, + { + Query: "select * from dolt_workspace_tbl", + Expected: []sql.Row{ + {0, true, "added", 44, 44, nil, nil}, + {1, false, "modified", 44, 108, 44, 44}, + }, + }, + }, + }, + { + Name: "dolt_workspace_* deleted row", + SetUpScript: []string{ + "create table tbl (pk int primary key, val int);", + "call dolt_commit('-Am', 'creating table t');", + + "insert into tbl values (42,42);", + "insert into tbl values (43,43);", + "call dolt_commit('-am', 'inserting 2 rows at HEAD');", + "delete from tbl where pk = 42;", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "select * from dolt_workspace_tbl", + Expected: []sql.Row{ + {0, false, "removed", nil, nil, 42, 42}, + }, + }, + { + Query: "call dolt_add('tbl');", + }, + { + Query: "select * from dolt_workspace_tbl", + Expected: []sql.Row{ + {0, true, "removed", nil, nil, 42, 42}, + }, + }, + }, + }, + { + Name: "dolt_workspace_* clean workspace", + SetUpScript: []string{ + "create table tbl (pk int primary key, val int);", + "call dolt_commit('-Am', 'creating table t');", + + "insert into tbl values (42,42);", + "insert into tbl 
values (43,43);", + "call dolt_commit('-am', 'inserting 2 rows at HEAD');", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "select * from dolt_workspace_tbl", + Expected: []sql.Row{}, + }, + { + Query: "select * from dolt_workspace_unknowntable", + Expected: []sql.Row{}, + }, + }, + }, + + { + Name: "dolt_workspace_* created table", + SetUpScript: []string{ + "create table tbl (pk int primary key, val int);", + "insert into tbl values (42,42);", + "insert into tbl values (43,43);", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "select * from dolt_workspace_tbl", + Expected: []sql.Row{ + {0, false, "added", 42, 42, nil, nil}, + {1, false, "added", 43, 43, nil, nil}, + }, + }, + { + Query: "call dolt_add('tbl');", + }, + { + Query: "select * from dolt_workspace_tbl", + Expected: []sql.Row{ + {0, true, "added", 42, 42, nil, nil}, + {1, true, "added", 43, 43, nil, nil}, + }, + }, + }, + }, + { + Name: "dolt_workspace_* dropped table", + SetUpScript: []string{ + "create table tbl (pk int primary key, val int);", + "call dolt_commit('-Am', 'creating table t');", + + "insert into tbl values (42,42);", + "insert into tbl values (43,43);", + "call dolt_commit('-am', 'inserting rows 3 rows at HEAD');", + "drop table tbl", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "select * from dolt_workspace_tbl", + Expected: []sql.Row{ + {0, false, "removed", nil, nil, 42, 42}, + {1, false, "removed", nil, nil, 43, 43}, + }, + }, + { + Query: "call dolt_add('tbl');", + }, + { + Query: "select * from dolt_workspace_tbl", + Expected: []sql.Row{ + {0, true, "removed", nil, nil, 42, 42}, + {1, true, "removed", nil, nil, 43, 43}, + }, + }, + }, + }, + + { + Name: "dolt_workspace_* keyless table", + SetUpScript: []string{ + "create table tbl (x int, y int);", + "insert into tbl values (42,42);", + "insert into tbl values (42,42);", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "select * from dolt_workspace_tbl", + Expected: []sql.Row{ + {0, false, "added", 42, 42, nil, nil}, + {1, false, "added", 42, 42, nil, nil}, + }, + }, + + { + Query: "call dolt_add('tbl');", + }, + { + Query: "select * from dolt_workspace_tbl", + Expected: []sql.Row{ + {0, true, "added", 42, 42, nil, nil}, + {1, true, "added", 42, 42, nil, nil}, + }, + }, + { + Query: "insert into tbl values (42,42);", + }, + { + Query: "select * from dolt_workspace_tbl", + Expected: []sql.Row{ + {0, true, "added", 42, 42, nil, nil}, + {1, true, "added", 42, 42, nil, nil}, + {2, false, "added", 42, 42, nil, nil}, + }, + }, + }, + }, + { + Name: "dolt_workspace_* schema change", + SetUpScript: []string{ + "create table tbl (pk int primary key, val int);", + "call dolt_commit('-Am', 'creating table t');", + + "insert into tbl values (42,42);", + "insert into tbl values (43,43);", + "call dolt_commit('-am', 'inserting rows 3 rows at HEAD');", + + "update tbl set val=51 where pk=42;", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "select * from dolt_workspace_tbl", + Expected: []sql.Row{ + {0, false, "modified", 42, 51, 42, 42}, + }, + }, + + { + Query: "ALTER TABLE tbl ADD COLUMN newcol CHAR(36)", + }, + { + Query: "select * from dolt_workspace_tbl", + Expected: []sql.Row{ + {0, false, "modified", 42, 51, nil, 42, 42}, + }, + }, + { + Query: "call dolt_add('tbl')", + }, + { + Query: "select * from dolt_workspace_tbl", + Expected: []sql.Row{ + {0, true, "modified", 42, 51, nil, 42, 42}, + }, + }, + /* Three schemas are possible by having a schema change staged then altering the 
schema again. + Currently, it's unclear if/how dolt_workspace_* can/should present this since it's all about data changes, not schema changes. + { + Query: "ALTER TABLE tbl ADD COLUMN newcol2 float", + }, + { + Query: "select * from dolt_workspace_tbl", + Expected: []sql.Row{ + {0, true, "modified", 42, 51, nil, 42, 42}, + }, + }, + { + Query: "update tbl set val=59 where pk=42", + }, + { + Query: "select * from dolt_workspace_tbl", + Expected: []sql.Row{ + {0, true, "modified", 42, 51, nil, 42, 42}, + {1, false, "modified", 42, 59, nil, nil, 42, 42}, // + }, + }, + */ + }, + }, +} diff --git a/go/libraries/doltcore/sqle/expranalysis/expranalysis.go b/go/libraries/doltcore/sqle/expranalysis/expranalysis.go index dc5ae3a412..054ebe325d 100644 --- a/go/libraries/doltcore/sqle/expranalysis/expranalysis.go +++ b/go/libraries/doltcore/sqle/expranalysis/expranalysis.go @@ -52,7 +52,7 @@ func ResolveDefaultExpression(ctx *sql.Context, tableName string, sch schema.Sch return nil, fmt.Errorf("unable to find default or generated expression") } - return expr.Expr, nil + return expr, nil } // ResolveCheckExpression returns a sql.Expression for the check provided diff --git a/go/libraries/doltcore/sqle/sqlddl_test.go b/go/libraries/doltcore/sqle/sqlddl_test.go index b96e399abf..63531ff6ad 100644 --- a/go/libraries/doltcore/sqle/sqlddl_test.go +++ b/go/libraries/doltcore/sqle/sqlddl_test.go @@ -255,8 +255,8 @@ func TestCreateTable(t *testing.T) { schemaNewColumnWDefVal(t, "iso_code_3", 8427, gmstypes.MustCreateStringWithDefaults(sqltypes.VarChar, 3), false, `''`), schemaNewColumnWDefVal(t, "iso_country", 7151, gmstypes.MustCreateStringWithDefaults(sqltypes.VarChar, 255), false, `''`, schema.NotNullConstraint{}), schemaNewColumnWDefVal(t, "country", 879, gmstypes.MustCreateStringWithDefaults(sqltypes.VarChar, 255), false, `''`, schema.NotNullConstraint{}), - schemaNewColumnWDefVal(t, "lat", 3502, gmstypes.Float32, false, "0.0", schema.NotNullConstraint{}), - schemaNewColumnWDefVal(t, "lon", 9907, gmstypes.Float32, false, "0.0", schema.NotNullConstraint{})), + schemaNewColumnWDefVal(t, "lat", 3502, gmstypes.Float32, false, "0", schema.NotNullConstraint{}), + schemaNewColumnWDefVal(t, "lon", 9907, gmstypes.Float32, false, "0", schema.NotNullConstraint{})), }, } diff --git a/go/libraries/doltcore/sqle/sqlfmt/schema_fmt.go b/go/libraries/doltcore/sqle/sqlfmt/schema_fmt.go index ec837f7922..08eac5d7c1 100644 --- a/go/libraries/doltcore/sqle/sqlfmt/schema_fmt.go +++ b/go/libraries/doltcore/sqle/sqlfmt/schema_fmt.go @@ -201,6 +201,10 @@ func GenerateCreateTableColumnDefinition(col schema.Column, tableCollation sql.C func GenerateCreateTableIndentedColumnDefinition(col schema.Column, tableCollation sql.CollationID) string { var defaultVal, genVal, onUpdateVal *sql.ColumnDefaultValue if col.Default != "" { + // hacky way to determine if column default is an expression + if col.Default[0] != '(' && col.Default[len(col.Default)-1] != ')' && col.Default[0] != '\'' && col.Default[len(col.Default)-1] != '\'' { + col.Default = fmt.Sprintf("'%s'", col.Default) + } defaultVal = sql.NewUnresolvedColumnDefaultValue(col.Default) } if col.Generated != "" { diff --git a/go/performance/microsysbench/sysbench_test.go b/go/performance/microsysbench/sysbench_test.go index ed1158ee0d..0a2795016f 100644 --- a/go/performance/microsysbench/sysbench_test.go +++ b/go/performance/microsysbench/sysbench_test.go @@ -90,6 +90,22 @@ func BenchmarkSelectRandomPoints(b *testing.B) { }) } +func BenchmarkSelectRandomRanges(b *testing.B) { + 
benchmarkSysbenchQuery(b, func(int) string { + var sb strings.Builder + sb.Grow(120) + sb.WriteString("SELECT count(k) FROM sbtest1 WHERE ") + sep := "" + for i := 1; i < 10; i++ { + start := rand.Intn(tableSize) + fmt.Fprintf(&sb, "%sk between %s and %s", sep, strconv.Itoa(start), strconv.Itoa(start+5)) + sep = " OR " + } + sb.WriteString(";") + return sb.String() + }) +} + func benchmarkSysbenchQuery(b *testing.B, getQuery func(int) string) { ctx, eng := setupBenchmark(b, dEnv) for i := 0; i < b.N; i++ { diff --git a/go/serial/workingset.fbs b/go/serial/workingset.fbs index 1057668acb..cee89b20a6 100644 --- a/go/serial/workingset.fbs +++ b/go/serial/workingset.fbs @@ -52,6 +52,12 @@ table RebaseState { // The commit that we are rebasing onto. onto_commit_addr:[ubyte] (required); + + // How to handle commits that start off empty + empty_commit_handling:uint8; + + // How to handle commits that become empty during rebasing + commit_becomes_empty_handling:uint8; } // KEEP THIS IN SYNC WITH fileidentifiers.go diff --git a/go/store/cmd/noms/noms_show.go b/go/store/cmd/noms/noms_show.go index 86816324f0..587140df6c 100644 --- a/go/store/cmd/noms/noms_show.go +++ b/go/store/cmd/noms/noms_show.go @@ -244,7 +244,12 @@ func outputEncodedValue(ctx context.Context, w io.Writer, value types.Value) err if err != nil { return err } - return tree.OutputProllyNodeBytes(w, node) + fmt.Fprintf(w, "(rows %d, depth %d) #%s {", + node.Count(), node.Level()+1, node.HashOf().String()[:8]) + err = tree.OutputAddressMapNode(w, node) + fmt.Fprintf(w, "}\n") + return err + default: return types.WriteEncodedValue(ctx, w, value) } diff --git a/go/store/datas/dataset.go b/go/store/datas/dataset.go index 29b7ceaedc..5f09a08559 100644 --- a/go/store/datas/dataset.go +++ b/go/store/datas/dataset.go @@ -162,9 +162,11 @@ type WorkingSetHead struct { } type RebaseState struct { - preRebaseWorkingAddr *hash.Hash - ontoCommitAddr *hash.Hash - branch string + preRebaseWorkingAddr *hash.Hash + ontoCommitAddr *hash.Hash + branch string + commitBecomesEmptyHandling uint8 + emptyCommitHandling uint8 } func (rs *RebaseState) PreRebaseWorkingAddr() hash.Hash { @@ -186,6 +188,14 @@ func (rs *RebaseState) OntoCommit(ctx context.Context, vr types.ValueReader) (*C return nil, nil } +func (rs *RebaseState) CommitBecomesEmptyHandling(_ context.Context) uint8 { + return rs.commitBecomesEmptyHandling +} + +func (rs *RebaseState) EmptyCommitHandling(_ context.Context) uint8 { + return rs.emptyCommitHandling +} + type MergeState struct { preMergeWorkingAddr *hash.Hash fromCommitAddr *hash.Hash @@ -433,7 +443,10 @@ func (h serialWorkingSetHead) HeadWorkingSet() (*WorkingSetHead, error) { ret.RebaseState = NewRebaseState( hash.New(rebaseState.PreWorkingRootAddrBytes()), hash.New(rebaseState.OntoCommitAddrBytes()), - string(rebaseState.BranchBytes())) + string(rebaseState.BranchBytes()), + rebaseState.CommitBecomesEmptyHandling(), + rebaseState.EmptyCommitHandling(), + ) } return &ret, nil diff --git a/go/store/datas/workingset.go b/go/store/datas/workingset.go index 61c869905f..ac4e77ed6b 100755 --- a/go/store/datas/workingset.go +++ b/go/store/datas/workingset.go @@ -192,6 +192,8 @@ func workingset_flatbuffer(working hash.Hash, staged *hash.Hash, mergeState *Mer serial.RebaseStateAddPreWorkingRootAddr(builder, preRebaseRootAddrOffset) serial.RebaseStateAddBranch(builder, branchOffset) serial.RebaseStateAddOntoCommitAddr(builder, ontoAddrOffset) + serial.RebaseStateAddCommitBecomesEmptyHandling(builder, rebaseState.commitBecomesEmptyHandling) + 
serial.RebaseStateAddEmptyCommitHandling(builder, rebaseState.emptyCommitHandling) rebaseStateOffset = serial.RebaseStateEnd(builder) } @@ -260,11 +262,13 @@ func NewMergeState( } } -func NewRebaseState(preRebaseWorkingRoot hash.Hash, commitAddr hash.Hash, branch string) *RebaseState { +func NewRebaseState(preRebaseWorkingRoot hash.Hash, commitAddr hash.Hash, branch string, commitBecomesEmptyHandling uint8, emptyCommitHandling uint8) *RebaseState { return &RebaseState{ - preRebaseWorkingAddr: &preRebaseWorkingRoot, - ontoCommitAddr: &commitAddr, - branch: branch, + preRebaseWorkingAddr: &preRebaseWorkingRoot, + ontoCommitAddr: &commitAddr, + branch: branch, + commitBecomesEmptyHandling: commitBecomesEmptyHandling, + emptyCommitHandling: emptyCommitHandling, } } diff --git a/go/store/nbs/race_off.go b/go/store/nbs/race_off.go new file mode 100644 index 0000000000..afd1ad72ce --- /dev/null +++ b/go/store/nbs/race_off.go @@ -0,0 +1,22 @@ +// Copyright 2024 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build !race +// +build !race + +package nbs + +func isRaceEnabled() bool { + return false +} diff --git a/go/store/nbs/race_on.go b/go/store/nbs/race_on.go new file mode 100644 index 0000000000..3d8ab55a75 --- /dev/null +++ b/go/store/nbs/race_on.go @@ -0,0 +1,22 @@ +// Copyright 2024 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
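race_off.go above and race_on.go below are the two halves of a conventional Go build-tag switch: exactly one of them compiles into the nbs package, so isRaceEnabled() reports whether the binary was built with -race. A hedged sketch of the consuming side; TestHugeIndex is a hypothetical name, and TestParseLargeTableIndex below applies the same guard:

package nbs

import "testing"

func TestHugeIndex(t *testing.T) {
	if isRaceEnabled() {
		// The race detector multiplies memory overhead, so skip tests
		// that allocate multi-gigabyte index buffers.
		t.SkipNow()
	}
	buf := make([]byte, 1<<31) // ~2 GiB index buffer, tolerable without -race
	_ = buf
}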
+ +//go:build race +// +build race + +package nbs + +func isRaceEnabled() bool { + return true +} diff --git a/go/store/nbs/table_index.go b/go/store/nbs/table_index.go index 43e2873348..bf210c4378 100644 --- a/go/store/nbs/table_index.go +++ b/go/store/nbs/table_index.go @@ -234,12 +234,14 @@ func newOnHeapTableIndex(indexBuff []byte, offsetsBuff1 []byte, count uint32, to return onHeapTableIndex{}, ErrWrongBufferSize } - tuples := indexBuff[:prefixTupleSize*count] - lengths := indexBuff[prefixTupleSize*count : prefixTupleSize*count+lengthSize*count] - suffixes := indexBuff[prefixTupleSize*count+lengthSize*count : indexSize(count)] + cnt64 := uint64(count) + + tuples := indexBuff[:prefixTupleSize*cnt64] + lengths := indexBuff[prefixTupleSize*cnt64 : prefixTupleSize*cnt64+lengthSize*cnt64] + suffixes := indexBuff[prefixTupleSize*cnt64+lengthSize*cnt64 : indexSize(count)] footer := indexBuff[indexSize(count):] - chunks2 := count / 2 + chunks2 := cnt64 / 2 r := NewOffsetsReader(bytes.NewReader(lengths)) _, err := io.ReadFull(r, offsetsBuff1) @@ -369,7 +371,7 @@ func (ti onHeapTableIndex) findPrefix(prefix uint64) (idx uint32) { } func (ti onHeapTableIndex) tupleAt(idx uint32) (prefix uint64, ord uint32) { - off := int64(prefixTupleSize * idx) + off := prefixTupleSize * int64(idx) b := ti.prefixTuples[off : off+prefixTupleSize] prefix = binary.BigEndian.Uint64(b[:]) @@ -378,13 +380,13 @@ func (ti onHeapTableIndex) tupleAt(idx uint32) (prefix uint64, ord uint32) { } func (ti onHeapTableIndex) prefixAt(idx uint32) uint64 { - off := int64(prefixTupleSize * idx) + off := prefixTupleSize * int64(idx) b := ti.prefixTuples[off : off+hash.PrefixLen] return binary.BigEndian.Uint64(b) } func (ti onHeapTableIndex) ordinalAt(idx uint32) uint32 { - off := int64(prefixTupleSize*idx) + hash.PrefixLen + off := prefixTupleSize*int64(idx) + hash.PrefixLen b := ti.prefixTuples[off : off+ordinalSize] return binary.BigEndian.Uint32(b) } @@ -394,10 +396,10 @@ func (ti onHeapTableIndex) offsetAt(ord uint32) uint64 { chunks1 := ti.count - ti.count/2 var b []byte if ord < chunks1 { - off := int64(offsetSize * ord) + off := offsetSize * int64(ord) b = ti.offsets1[off : off+offsetSize] } else { - off := int64(offsetSize * (ord - chunks1)) + off := offsetSize * int64(ord-chunks1) b = ti.offsets2[off : off+offsetSize] } return binary.BigEndian.Uint64(b) @@ -406,7 +408,7 @@ func (ti onHeapTableIndex) offsetAt(ord uint32) uint64 { func (ti onHeapTableIndex) ordinals() ([]uint32, error) { // todo: |o| is not accounted for in the memory quota o := make([]uint32, ti.count) - for i, off := uint32(0), 0; i < ti.count; i, off = i+1, off+prefixTupleSize { + for i, off := uint32(0), uint64(0); i < ti.count; i, off = i+1, off+prefixTupleSize { b := ti.prefixTuples[off+hash.PrefixLen : off+prefixTupleSize] o[i] = binary.BigEndian.Uint32(b) } @@ -416,7 +418,7 @@ func (ti onHeapTableIndex) ordinals() ([]uint32, error) { func (ti onHeapTableIndex) prefixes() ([]uint64, error) { // todo: |p| is not accounted for in the memory quota p := make([]uint64, ti.count) - for i, off := uint32(0), 0; i < ti.count; i, off = i+1, off+prefixTupleSize { + for i, off := uint32(0), uint64(0); i < ti.count; i, off = i+1, off+prefixTupleSize { b := ti.prefixTuples[off : off+hash.PrefixLen] p[i] = binary.BigEndian.Uint64(b) } @@ -425,7 +427,7 @@ func (ti onHeapTableIndex) prefixes() ([]uint64, error) { func (ti onHeapTableIndex) hashAt(idx uint32) hash.Hash { // Get tuple - off := int64(prefixTupleSize * idx) + off := prefixTupleSize * int64(idx) tuple := 
ti.prefixTuples[off : off+prefixTupleSize] // Get prefix, ordinal, and suffix diff --git a/go/store/nbs/table_index_test.go b/go/store/nbs/table_index_test.go index 3eabd99595..cf984d581f 100644 --- a/go/store/nbs/table_index_test.go +++ b/go/store/nbs/table_index_test.go @@ -56,6 +56,66 @@ func TestParseTableIndex(t *testing.T) { } } +func TestParseLargeTableIndex(t *testing.T) { + if isRaceEnabled() { + t.SkipNow() + } + + // This is large enough for the NBS table index to overflow uint32s on certain index calculations. + numChunks := uint32(320331063) + idxSize := indexSize(numChunks) + sz := idxSize + footerSize + idxBuf := make([]byte, sz) + copy(idxBuf[idxSize+12:], magicNumber) + binary.BigEndian.PutUint32(idxBuf[idxSize:], numChunks) + binary.BigEndian.PutUint64(idxBuf[idxSize+4:], uint64(numChunks)*4*1024) + + var prefix uint64 + + off := 0 + // Write Tuples + for i := uint32(0); i < numChunks; i++ { + binary.BigEndian.PutUint64(idxBuf[off:], prefix) + binary.BigEndian.PutUint32(idxBuf[off+hash.PrefixLen:], i) + prefix += 2 + off += prefixTupleSize + } + + // Write Lengths + for i := uint32(0); i < numChunks; i++ { + binary.BigEndian.PutUint32(idxBuf[off:], 4*1024) + off += lengthSize + } + + // Write Suffixes + for i := uint32(0); i < numChunks; i++ { + off += hash.SuffixLen + } + + idx, err := parseTableIndex(context.Background(), idxBuf, &UnlimitedQuotaProvider{}) + require.NoError(t, err) + h := &hash.Hash{} + h[7] = 2 + ord, err := idx.lookupOrdinal(h) + require.NoError(t, err) + assert.Equal(t, uint32(1), ord) + h[7] = 1 + ord, err = idx.lookupOrdinal(h) + require.NoError(t, err) + assert.Equal(t, numChunks, ord) + // This is the end of the chunk, not the beginning. + assert.Equal(t, uint64(8*1024), idx.offsetAt(1)) + assert.Equal(t, uint64(2), idx.prefixAt(1)) + assert.Equal(t, uint32(1), idx.ordinalAt(1)) + h[7] = 2 + assert.Equal(t, *h, idx.hashAt(1)) + entry, ok, err := idx.lookup(h) + require.NoError(t, err) + assert.True(t, ok) + assert.Equal(t, uint64(4*1024), entry.Offset()) + assert.Equal(t, uint32(4*1024), entry.Length()) +} + func BenchmarkFindPrefix(b *testing.B) { ctx := context.Background() f, err := os.Open("testdata/0oa7mch34jg1rvghrnhr4shrp2fm4ftd.idx") diff --git a/go/store/prolly/tree/diff.go b/go/store/prolly/tree/diff.go index 2fc0a9956d..9fcd841e70 100644 --- a/go/store/prolly/tree/diff.go +++ b/go/store/prolly/tree/diff.go @@ -128,6 +128,13 @@ func DifferFromCursors[K ~[]byte, O Ordering[K]]( } func (td Differ[K, O]) Next(ctx context.Context) (diff Diff, err error) { + return td.next(ctx, true) +} + +// next finds the next diff and then conditionally advances the cursors past the modified chunks. +// In most cases, we want to advance the cursors, but in some circumstances the caller may want to access the cursors +// and then advance them manually. +func (td Differ[K, O]) next(ctx context.Context, advanceCursors bool) (diff Diff, err error) { for td.from.Valid() && td.from.compare(td.fromStop) < 0 && td.to.Valid() && td.to.compare(td.toStop) < 0 { f := td.from.CurrentKey() @@ -136,16 +143,16 @@ func (td Differ[K, O]) Next(ctx context.Context) (diff Diff, err error) { switch { case cmp < 0: - return sendRemoved(ctx, td.from) + return sendRemoved(ctx, td.from, advanceCursors) case cmp > 0: - return sendAdded(ctx, td.to) + return sendAdded(ctx, td.to, advanceCursors) case cmp == 0: // If the cursor schema has changed, then all rows should be considered modified. // If the cursor schema hasn't changed, rows are modified iff their bytes have changed. 
if td.considerAllRowsModified || !equalcursorValues(td.from, td.to) { - return sendModified(ctx, td.from, td.to) + return sendModified(ctx, td.from, td.to, advanceCursors) } // advance both cursors since we have already determined that they are equal. This needs to be done because @@ -166,42 +173,46 @@ func (td Differ[K, O]) Next(ctx context.Context) (diff Diff, err error) { } if td.from.Valid() && td.from.compare(td.fromStop) < 0 { - return sendRemoved(ctx, td.from) + return sendRemoved(ctx, td.from, advanceCursors) } if td.to.Valid() && td.to.compare(td.toStop) < 0 { - return sendAdded(ctx, td.to) + return sendAdded(ctx, td.to, advanceCursors) } return Diff{}, io.EOF } -func sendRemoved(ctx context.Context, from *cursor) (diff Diff, err error) { +func sendRemoved(ctx context.Context, from *cursor, advanceCursors bool) (diff Diff, err error) { diff = Diff{ Type: RemovedDiff, Key: from.CurrentKey(), From: from.currentValue(), } - if err = from.advance(ctx); err != nil { - return Diff{}, err + if advanceCursors { + if err = from.advance(ctx); err != nil { + return Diff{}, err + } } return } -func sendAdded(ctx context.Context, to *cursor) (diff Diff, err error) { +func sendAdded(ctx context.Context, to *cursor, advanceCursors bool) (diff Diff, err error) { diff = Diff{ Type: AddedDiff, Key: to.CurrentKey(), To: to.currentValue(), } - if err = to.advance(ctx); err != nil { - return Diff{}, err + if advanceCursors { + if err = to.advance(ctx); err != nil { + return Diff{}, err + } } return } -func sendModified(ctx context.Context, from, to *cursor) (diff Diff, err error) { +func sendModified(ctx context.Context, from, to *cursor, advanceCursors bool) (diff Diff, err error) { diff = Diff{ Type: ModifiedDiff, Key: from.CurrentKey(), @@ -209,11 +220,13 @@ func sendModified(ctx context.Context, from, to *cursor) (diff Diff, err error) To: to.currentValue(), } - if err = from.advance(ctx); err != nil { - return Diff{}, err - } - if err = to.advance(ctx); err != nil { - return Diff{}, err + if advanceCursors { + if err = from.advance(ctx); err != nil { + return Diff{}, err + } + if err = to.advance(ctx); err != nil { + return Diff{}, err + } } return } diff --git a/go/store/prolly/tree/indexed_json_diff.go b/go/store/prolly/tree/indexed_json_diff.go new file mode 100644 index 0000000000..bf4bc2e427 --- /dev/null +++ b/go/store/prolly/tree/indexed_json_diff.go @@ -0,0 +1,293 @@ +// Copyright 2024 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
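The new file below is the main consumer of the advanceCursors flag added above: it seeks to differing chunks with next(ctx, false), then walks those chunks itself with JsonCursors. Draining either IJsonDiffer implementation looks the same from the outside; a minimal sketch in package tree, where printJsonDiffs is a hypothetical helper:

package tree

import (
	"context"
	"fmt"
	"io"
)

func printJsonDiffs(ctx context.Context, from, to IndexedJsonDocument) error {
	differ, err := NewIndexedJsonDiffer(ctx, from, to)
	if err != nil {
		return err
	}
	for {
		d, err := differ.Next(ctx)
		if err == io.EOF {
			return nil // diff stream exhausted
		}
		if err != nil {
			return err
		}
		// d.Key is a binary location key; d.From and d.To are sql.JSONWrapper values.
		fmt.Printf("%v: %v -> %v\n", d.Key, d.From, d.To)
	}
}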
+ +package tree + +import ( + "context" + "io" + + "github.com/dolthub/go-mysql-server/sql/types" + "golang.org/x/exp/slices" +) + +type IndexedJsonDiffer struct { + differ Differ[jsonLocationKey, jsonLocationOrdering] + currentFromCursor, currentToCursor *JsonCursor + from, to IndexedJsonDocument + started bool +} + +var _ IJsonDiffer = &IndexedJsonDiffer{} + +func NewIndexedJsonDiffer(ctx context.Context, from, to IndexedJsonDocument) (*IndexedJsonDiffer, error) { + differ, err := DifferFromRoots[jsonLocationKey, jsonLocationOrdering](ctx, from.m.NodeStore, to.m.NodeStore, from.m.Root, to.m.Root, jsonLocationOrdering{}, false) + if err != nil { + return nil, err + } + // We want to diff the prolly tree as if it was an address map pointing to the individual blob fragments, rather + // than diffing the blob fragments themselves. We can accomplish this by just replacing the cursors in the differ + // with their parents. + differ.from = differ.from.parent + differ.to = differ.to.parent + differ.fromStop = differ.fromStop.parent + differ.toStop = differ.toStop.parent + + if differ.from == nil || differ.to == nil { + // This can happen when either document fits in a single chunk. + // We don't use the chunk differ in this case, and instead we create the cursors without it. + diffKey := []byte{byte(startOfValue)} + currentFromCursor, err := newJsonCursorAtStartOfChunk(ctx, from.m.NodeStore, from.m.Root, diffKey) + if err != nil { + return nil, err + } + currentToCursor, err := newJsonCursorAtStartOfChunk(ctx, to.m.NodeStore, to.m.Root, diffKey) + if err != nil { + return nil, err + } + return &IndexedJsonDiffer{ + differ: differ, + from: from, + to: to, + currentFromCursor: currentFromCursor, + currentToCursor: currentToCursor, + }, nil + } + + return &IndexedJsonDiffer{ + differ: differ, + from: from, + to: to, + }, nil +} + +// Next computes the next diff between the two JSON documents. +// To accomplish this, it uses the underlying Differ to find chunks that have changed, and uses a pair of JsonCursors +// to walk corresponding chunks. +func (jd *IndexedJsonDiffer) Next(ctx context.Context) (diff JsonDiff, err error) { + // helper function to advance a JsonCursor and set it to nil if it reaches the end of a chunk + advanceCursor := func(jCur **JsonCursor) (err error) { + if (*jCur).jsonScanner.atEndOfChunk() { + err = (*jCur).cur.advance(ctx) + if err != nil { + return err + } + *jCur = nil + } else { + err = (*jCur).jsonScanner.AdvanceToNextLocation() + if err != nil { + return err + } + } + return nil + } + + newAddedDiff := func(key []byte) (JsonDiff, error) { + addedValue, err := jd.currentToCursor.NextValue(ctx) + if err != nil { + return JsonDiff{}, err + } + err = advanceCursor(&jd.currentToCursor) + if err != nil { + return JsonDiff{}, err + } + return JsonDiff{ + Key: key, + To: types.NewLazyJSONDocument(addedValue), + Type: AddedDiff, + }, nil + } + + newRemovedDiff := func(key []byte) (JsonDiff, error) { + removedValue, err := jd.currentFromCursor.NextValue(ctx) + if err != nil { + return JsonDiff{}, err + } + err = advanceCursor(&jd.currentFromCursor) + if err != nil { + return JsonDiff{}, err + } + return JsonDiff{ + Key: key, + From: types.NewLazyJSONDocument(removedValue), + Type: RemovedDiff, + }, nil + } + + for { + if jd.currentFromCursor == nil && jd.currentToCursor == nil { + if jd.differ.from == nil || jd.differ.to == nil { + // One of the documents fits in a single chunk. We must have walked the entire document by now. 
+ return JsonDiff{}, io.EOF + } + + // Either this is the first iteration, or the last iteration exhausted both chunks at the same time. + // (ie, both chunks ended at the same JSON path). We can use `Differ.Next` to seek to the next difference. + // Passing advanceCursors=false means that instead of using the returned diff, we read the cursors out of + // the differ, and advance them manually once we've walked the chunk. + _, err := jd.differ.next(ctx, false) + if err != nil { + return JsonDiff{}, err + } + + jd.currentFromCursor, err = newJsonCursorFromCursor(ctx, jd.differ.from) + if err != nil { + return JsonDiff{}, err + } + jd.currentToCursor, err = newJsonCursorFromCursor(ctx, jd.differ.to) + if err != nil { + return JsonDiff{}, err + } + } else if jd.currentFromCursor == nil { + // We exhausted the current `from` chunk but not the `to` chunk. Since the chunk boundaries don't align on + // the same key, we need to continue into the next chunk. + + jd.currentFromCursor, err = newJsonCursorFromCursor(ctx, jd.differ.from) + if err != nil { + return JsonDiff{}, err + } + + err = advanceCursor(&jd.currentFromCursor) + if err != nil { + return JsonDiff{}, err + } + continue + } else if jd.currentToCursor == nil { + // We exhausted the current `to` chunk but not the `from` chunk. Since the chunk boundaries don't align on + // the same key, we need to continue into the next chunk. + + jd.currentToCursor, err = newJsonCursorFromCursor(ctx, jd.differ.to) + if err != nil { + return JsonDiff{}, err + } + + err = advanceCursor(&jd.currentToCursor) + if err != nil { + return JsonDiff{}, err + } + continue + } + // Both cursors point to chunks that are different between the two documents. + // We must be in one of the following states: + // 1) Both cursors have the JSON path with the same values: + // - This location has not changed, advance both cursors and continue. + // 2) Both cursors have the same JSON path but different values: + // - The value at that path has been modified. + // 3) Both cursors point to the start of a value, but the paths differ: + // - A value has been inserted or deleted in the beginning/middle of an object. + // 4) One cursor points to the start of a value, while the other cursor points to the end of that value's parent: + // - A value has been inserted or deleted at the end of an object or array. + // + // The following states aren't actually possible because we will encounter state 2 first. + // 5) One cursor points to the initial element of an object/array, while the other points to the end of that same path: + // - A value has been changed from an object/array to a scalar, or vice-versa. + // 6) One cursor points to the initial element of an object, while the other points to the initial element of an array: + // - The value has been changed from an object to an array, or vice-versa. + + fromScanner := &jd.currentFromCursor.jsonScanner + toScanner := &jd.currentToCursor.jsonScanner + fromScannerAtStartOfValue := fromScanner.atStartOfValue() + toScannerAtStartOfValue := toScanner.atStartOfValue() + fromCurrentLocation := fromScanner.currentPath + toCurrentLocation := toScanner.currentPath + + if !fromScannerAtStartOfValue && !toScannerAtStartOfValue { + // Neither cursor points to the start of a value. + // This should only be possible if they're at the same location. + // Do a sanity check, then continue. 
+ if compareJsonLocations(fromCurrentLocation, toCurrentLocation) != 0 { + return JsonDiff{}, jsonParseError + } + err = advanceCursor(&jd.currentFromCursor) + if err != nil { + return JsonDiff{}, err + } + err = advanceCursor(&jd.currentToCursor) + if err != nil { + return JsonDiff{}, err + } + continue + } + + if fromScannerAtStartOfValue && toScannerAtStartOfValue { + cmp := compareJsonLocations(fromCurrentLocation, toCurrentLocation) + switch cmp { + case 0: + key := fromCurrentLocation.Clone().key + + // Both sides have the same key. If they're both an object or both an array, continue. + // Otherwise, compare them and possibly return a modification. + if (fromScanner.current() == '{' && toScanner.current() == '{') || + (fromScanner.current() == '[' && toScanner.current() == '[') { + err = advanceCursor(&jd.currentFromCursor) + if err != nil { + return JsonDiff{}, err + } + err = advanceCursor(&jd.currentToCursor) + if err != nil { + return JsonDiff{}, err + } + continue + } + + fromValue, err := jd.currentFromCursor.NextValue(ctx) + if err != nil { + return JsonDiff{}, err + } + toValue, err := jd.currentToCursor.NextValue(ctx) + if err != nil { + return JsonDiff{}, err + } + if !slices.Equal(fromValue, toValue) { + // Case 2: The value at this path has been modified + return JsonDiff{ + Key: key, + From: types.NewLazyJSONDocument(fromValue), + To: types.NewLazyJSONDocument(toValue), + Type: ModifiedDiff, + }, nil + } + // Case 1: This location has not changed + continue + + case -1: + // Case 3: A value has been removed from an object + key := fromCurrentLocation.Clone().key + return newRemovedDiff(key) + case 1: + // Case 3: A value has been added to an object + key := toCurrentLocation.Clone().key + return newAddedDiff(key) + } + } + + if !fromScannerAtStartOfValue && toScannerAtStartOfValue { + if fromCurrentLocation.getScannerState() != endOfValue { + return JsonDiff{}, jsonParseError + } + // Case 4: A value has been inserted at the end of an object or array. + key := toCurrentLocation.Clone().key + return newAddedDiff(key) + } + + if fromScannerAtStartOfValue && !toScannerAtStartOfValue { + if toCurrentLocation.getScannerState() != endOfValue { + return JsonDiff{}, jsonParseError + } + // Case 4: A value has been removed from the end of an object or array. + key := fromCurrentLocation.Clone().key + return newRemovedDiff(key) + } + } +} diff --git a/go/store/prolly/tree/json_cursor.go b/go/store/prolly/tree/json_cursor.go index e7d6be61c2..b980fcd2ef 100644 --- a/go/store/prolly/tree/json_cursor.go +++ b/go/store/prolly/tree/json_cursor.go @@ -40,7 +40,10 @@ func getPreviousKey(ctx context.Context, cur *cursor) ([]byte, error) { if !cur2.Valid() { return nil, nil } - key := cur2.parent.CurrentKey() + key := cur2.CurrentKey() + if len(key) == 0 { + key = cur2.parent.CurrentKey() + } err = errorIfNotSupportedLocation(key) if err != nil { return nil, err @@ -53,24 +56,40 @@ func getPreviousKey(ctx context.Context, cur *cursor) ([]byte, error) { // in the document. If the location does not exist in the document, the resulting JsonCursor // will be at the location where the value would be if it was inserted. 
func newJsonCursor(ctx context.Context, ns NodeStore, root Node, startKey jsonLocation, forRemoval bool) (jCur *JsonCursor, found bool, err error) { - cur, err := newCursorAtKey(ctx, ns, root, startKey.key, jsonLocationOrdering{}) + jcur, err := newJsonCursorAtStartOfChunk(ctx, ns, root, startKey.key) if err != nil { return nil, false, err } + found, err = jcur.AdvanceToLocation(ctx, startKey, forRemoval) + return jcur, found, err +} + +func newJsonCursorAtStartOfChunk(ctx context.Context, ns NodeStore, root Node, startKey []byte) (jCur *JsonCursor, err error) { + cur, err := newCursorAtKey(ctx, ns, root, startKey, jsonLocationOrdering{}) + if err != nil { + return nil, err + } + return newJsonCursorFromCursor(ctx, cur) +} + +func newJsonCursorFromCursor(ctx context.Context, cur *cursor) (*JsonCursor, error) { previousKey, err := getPreviousKey(ctx, cur) if err != nil { - return nil, false, err + return nil, err + } + if !cur.isLeaf() { + nd, err := fetchChild(ctx, cur.nrw, cur.currentRef()) + if err != nil { + return nil, err + } + return newJsonCursorFromCursor(ctx, &cursor{nd: nd, parent: cur, nrw: cur.nrw}) } jsonBytes := cur.currentValue() jsonDecoder := ScanJsonFromMiddleWithKey(jsonBytes, previousKey) jcur := JsonCursor{cur: cur, jsonScanner: jsonDecoder} - found, err = jcur.AdvanceToLocation(ctx, startKey, forRemoval) - if err != nil { - return nil, found, err - } - return &jcur, found, nil + return &jcur, nil } func (j JsonCursor) Valid() bool { diff --git a/go/store/prolly/tree/json_diff.go b/go/store/prolly/tree/json_diff.go index 61f716d3be..a52856c47a 100644 --- a/go/store/prolly/tree/json_diff.go +++ b/go/store/prolly/tree/json_diff.go @@ -16,17 +16,22 @@ package tree import ( "bytes" - "fmt" + "context" "io" "reflect" "strings" + "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/types" ) +type IJsonDiffer interface { + Next(ctx context.Context) (JsonDiff, error) +} + type JsonDiff struct { - Key string - From, To *types.JSONDocument + Key []byte + From, To sql.JSONWrapper Type DiffType } @@ -37,31 +42,44 @@ type jsonKeyPair struct { // JsonDiffer computes the diff between two JSON objects. type JsonDiffer struct { - root string + root []byte currentFromPair, currentToPair *jsonKeyPair from, to types.JSONIter subDiffer *JsonDiffer } -func NewJsonDiffer(root string, from, to types.JsonObject) JsonDiffer { +var _ IJsonDiffer = &JsonDiffer{} + +func NewJsonDiffer(from, to types.JsonObject) *JsonDiffer { + fromIter := types.NewJSONIter(from) + toIter := types.NewJSONIter(to) + return &JsonDiffer{ + root: []byte{byte(startOfValue)}, + from: fromIter, + to: toIter, + } +} + +func (differ *JsonDiffer) newSubDiffer(key string, from, to types.JsonObject) JsonDiffer { fromIter := types.NewJSONIter(from) toIter := types.NewJSONIter(to) + newRoot := differ.appendKey(key) return JsonDiffer{ - root: root, + root: newRoot, from: fromIter, to: toIter, } } -func (differ *JsonDiffer) appendKey(key string) string { +func (differ *JsonDiffer) appendKey(key string) []byte { escapedKey := strings.Replace(key, "\"", "\\\"", -1) - return fmt.Sprintf("%s.\"%s\"", differ.root, escapedKey) + return append(append(differ.root, beginObjectKey), []byte(escapedKey)...) 
} -func (differ *JsonDiffer) Next() (diff JsonDiff, err error) { +func (differ *JsonDiffer) Next(ctx context.Context) (diff JsonDiff, err error) { for { if differ.subDiffer != nil { - diff, err := differ.subDiffer.Next() + diff, err := differ.subDiffer.Next(ctx) if err == io.EOF { differ.subDiffer = nil differ.currentFromPair = nil @@ -116,7 +134,7 @@ func (differ *JsonDiffer) Next() (diff JsonDiff, err error) { switch from := fromValue.(type) { case types.JsonObject: // Recursively compare the objects to generate diffs. - subDiffer := NewJsonDiffer(differ.appendKey(key), from, toValue.(types.JsonObject)) + subDiffer := differ.newSubDiffer(key, from, toValue.(types.JsonObject)) differ.subDiffer = &subDiffer continue case types.JsonArray: diff --git a/go/store/prolly/tree/json_diff_test.go b/go/store/prolly/tree/json_diff_test.go index 4fa4528825..4d95573175 100644 --- a/go/store/prolly/tree/json_diff_test.go +++ b/go/store/prolly/tree/json_diff_test.go @@ -15,34 +15,49 @@ package tree import ( + "bytes" + "context" + "fmt" "io" "testing" + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/expression" + "github.com/dolthub/go-mysql-server/sql/expression/function/json" + "github.com/dolthub/go-mysql-server/sql/types" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) type jsonDiffTest struct { name string - from, to types.JsonObject + from, to sql.JSONWrapper expectedDiffs []JsonDiff } +func makeJsonPathKey(parts ...string) []byte { + result := []byte{byte(startOfValue)} + for _, part := range parts { + result = append(result, beginObjectKey) + result = append(result, []byte(part)...) + } + return result +} + var simpleJsonDiffTests = []jsonDiffTest{ { name: "empty object, no modifications", - from: types.JsonObject{}, - to: types.JsonObject{}, + from: types.JSONDocument{Val: types.JsonObject{}}, + to: types.JSONDocument{Val: types.JsonObject{}}, expectedDiffs: nil, }, { name: "insert into empty object", - from: types.JsonObject{}, - to: types.JsonObject{"a": 1}, + from: types.JSONDocument{Val: types.JsonObject{}}, + to: types.JSONDocument{Val: types.JsonObject{"a": 1}}, expectedDiffs: []JsonDiff{ { - Key: "$.\"a\"", + Key: makeJsonPathKey(`a`), From: nil, To: &types.JSONDocument{Val: 1}, Type: AddedDiff, @@ -51,11 +66,11 @@ var simpleJsonDiffTests = []jsonDiffTest{ }, { name: "delete from object", - from: types.JsonObject{"a": 1}, - to: types.JsonObject{}, + from: types.JSONDocument{Val: types.JsonObject{"a": 1}}, + to: types.JSONDocument{Val: types.JsonObject{}}, expectedDiffs: []JsonDiff{ { - Key: "$.\"a\"", + Key: makeJsonPathKey(`a`), From: &types.JSONDocument{Val: 1}, To: nil, Type: RemovedDiff, @@ -64,11 +79,11 @@ var simpleJsonDiffTests = []jsonDiffTest{ }, { name: "modify object", - from: types.JsonObject{"a": 1}, - to: types.JsonObject{"a": 2}, + from: types.JSONDocument{Val: types.JsonObject{"a": 1}}, + to: types.JSONDocument{Val: types.JsonObject{"a": 2}}, expectedDiffs: []JsonDiff{ { - Key: "$.\"a\"", + Key: makeJsonPathKey(`a`), From: &types.JSONDocument{Val: 1}, To: &types.JSONDocument{Val: 2}, Type: ModifiedDiff, @@ -77,11 +92,11 @@ var simpleJsonDiffTests = []jsonDiffTest{ }, { name: "nested insert", - from: types.JsonObject{"a": types.JsonObject{}}, - to: types.JsonObject{"a": types.JsonObject{"b": 1}}, + from: types.JSONDocument{Val: types.JsonObject{"a": types.JsonObject{}}}, + to: types.JSONDocument{Val: types.JsonObject{"a": types.JsonObject{"b": 1}}}, expectedDiffs: []JsonDiff{ { - Key: "$.\"a\".\"b\"", + 
Key: makeJsonPathKey(`a`, `b`), To: &types.JSONDocument{Val: 1}, Type: AddedDiff, }, @@ -89,11 +104,11 @@ var simpleJsonDiffTests = []jsonDiffTest{ }, { name: "nested delete", - from: types.JsonObject{"a": types.JsonObject{"b": 1}}, - to: types.JsonObject{"a": types.JsonObject{}}, + from: types.JSONDocument{Val: types.JsonObject{"a": types.JsonObject{"b": 1}}}, + to: types.JSONDocument{Val: types.JsonObject{"a": types.JsonObject{}}}, expectedDiffs: []JsonDiff{ { - Key: "$.\"a\".\"b\"", + Key: makeJsonPathKey(`a`, `b`), From: &types.JSONDocument{Val: 1}, Type: RemovedDiff, }, @@ -101,11 +116,11 @@ var simpleJsonDiffTests = []jsonDiffTest{ }, { name: "nested modify", - from: types.JsonObject{"a": types.JsonObject{"b": 1}}, - to: types.JsonObject{"a": types.JsonObject{"b": 2}}, + from: types.JSONDocument{Val: types.JsonObject{"a": types.JsonObject{"b": 1}}}, + to: types.JSONDocument{Val: types.JsonObject{"a": types.JsonObject{"b": 2}}}, expectedDiffs: []JsonDiff{ { - Key: "$.\"a\".\"b\"", + Key: makeJsonPathKey(`a`, `b`), From: &types.JSONDocument{Val: 1}, To: &types.JSONDocument{Val: 2}, Type: ModifiedDiff, @@ -114,11 +129,11 @@ var simpleJsonDiffTests = []jsonDiffTest{ }, { name: "insert object", - from: types.JsonObject{"a": types.JsonObject{}}, - to: types.JsonObject{"a": types.JsonObject{"b": types.JsonObject{"c": 3}}}, + from: types.JSONDocument{Val: types.JsonObject{"a": types.JsonObject{}}}, + to: types.JSONDocument{Val: types.JsonObject{"a": types.JsonObject{"b": types.JsonObject{"c": 3}}}}, expectedDiffs: []JsonDiff{ { - Key: "$.\"a\".\"b\"", + Key: makeJsonPathKey(`a`, `b`), To: &types.JSONDocument{Val: types.JsonObject{"c": 3}}, Type: AddedDiff, }, @@ -126,11 +141,11 @@ var simpleJsonDiffTests = []jsonDiffTest{ }, { name: "modify to object", - from: types.JsonObject{"a": types.JsonObject{"b": 2}}, - to: types.JsonObject{"a": types.JsonObject{"b": types.JsonObject{"c": 3}}}, + from: types.JSONDocument{Val: types.JsonObject{"a": types.JsonObject{"b": 2}}}, + to: types.JSONDocument{Val: types.JsonObject{"a": types.JsonObject{"b": types.JsonObject{"c": 3}}}}, expectedDiffs: []JsonDiff{ { - Key: "$.\"a\".\"b\"", + Key: makeJsonPathKey(`a`, `b`), From: &types.JSONDocument{Val: 2}, To: &types.JSONDocument{Val: types.JsonObject{"c": 3}}, Type: ModifiedDiff, @@ -139,24 +154,76 @@ var simpleJsonDiffTests = []jsonDiffTest{ }, { name: "modify from object", - from: types.JsonObject{"a": types.JsonObject{"b": 2}}, - to: types.JsonObject{"a": 1}, + from: types.JSONDocument{Val: types.JsonObject{"a": types.JsonObject{"b": 2}}}, + to: types.JSONDocument{Val: types.JsonObject{"a": 1}}, expectedDiffs: []JsonDiff{ { - Key: "$.\"a\"", + Key: makeJsonPathKey(`a`), From: &types.JSONDocument{Val: types.JsonObject{"b": 2}}, To: &types.JSONDocument{Val: 1}, Type: ModifiedDiff, }, }, }, + { + name: "modify to array", + from: types.JSONDocument{Val: types.JsonObject{"a": types.JsonObject{"b": "foo"}}}, + to: types.JSONDocument{Val: types.JsonObject{"a": types.JsonObject{"b": types.JsonArray{1, 2}}}}, + expectedDiffs: []JsonDiff{ + { + Key: makeJsonPathKey(`a`, `b`), + From: &types.JSONDocument{Val: "foo"}, + To: &types.JSONDocument{Val: types.JsonArray{1, 2}}, + Type: ModifiedDiff, + }, + }, + }, + { + name: "modify from array", + from: types.JSONDocument{Val: types.JsonObject{"a": types.JsonArray{1, 2}}}, + to: types.JSONDocument{Val: types.JsonObject{"a": 1}}, + expectedDiffs: []JsonDiff{ + { + Key: makeJsonPathKey(`a`), + From: &types.JSONDocument{Val: types.JsonArray{1, 2}}, + To: 
&types.JSONDocument{Val: 1}, + Type: ModifiedDiff, + }, + }, + }, + { + name: "array to object", + from: types.JSONDocument{Val: types.JsonObject{"a": types.JsonArray{1, 2}}}, + to: types.JSONDocument{Val: types.JsonObject{"a": types.JsonObject{"b": types.JsonObject{"c": 3}}}}, + expectedDiffs: []JsonDiff{ + { + Key: makeJsonPathKey(`a`), + From: &types.JSONDocument{Val: types.JsonArray{1, 2}}, + To: &types.JSONDocument{Val: types.JsonObject{"b": types.JsonObject{"c": 3}}}, + Type: ModifiedDiff, + }, + }, + }, + { + name: "object to array", + from: types.JSONDocument{Val: types.JsonObject{"a": types.JsonObject{"b": 2}}}, + to: types.JSONDocument{Val: types.JsonObject{"a": types.JsonArray{1, 2}}}, + expectedDiffs: []JsonDiff{ + { + Key: makeJsonPathKey(`a`), + From: &types.JSONDocument{Val: types.JsonObject{"b": 2}}, + To: &types.JSONDocument{Val: types.JsonArray{1, 2}}, + Type: ModifiedDiff, + }, + }, + }, { name: "remove object", - from: types.JsonObject{"a": types.JsonObject{"b": types.JsonObject{"c": 3}}}, - to: types.JsonObject{"a": types.JsonObject{}}, + from: types.JSONDocument{Val: types.JsonObject{"a": types.JsonObject{"b": types.JsonObject{"c": 3}}}}, + to: types.JSONDocument{Val: types.JsonObject{"a": types.JsonObject{}}}, expectedDiffs: []JsonDiff{ { - Key: "$.\"a\".\"b\"", + Key: makeJsonPathKey(`a`, `b`), From: &types.JSONDocument{Val: types.JsonObject{"c": 3}}, Type: RemovedDiff, }, @@ -164,17 +231,17 @@ var simpleJsonDiffTests = []jsonDiffTest{ }, { name: "insert escaped double quotes", - from: types.JsonObject{"\"a\"": "1"}, - to: types.JsonObject{"b": "\"2\""}, + from: types.JSONDocument{Val: types.JsonObject{"\"a\"": "1"}}, + to: types.JSONDocument{Val: types.JsonObject{"b": "\"2\""}}, expectedDiffs: []JsonDiff{ { - Key: "$.\"\\\"a\\\"\"", + Key: makeJsonPathKey(`"a"`), From: &types.JSONDocument{Val: "1"}, To: nil, Type: RemovedDiff, }, { - Key: "$.\"b\"", + Key: makeJsonPathKey(`b`), From: nil, To: &types.JSONDocument{Val: "\"2\""}, Type: AddedDiff, @@ -183,32 +250,32 @@ var simpleJsonDiffTests = []jsonDiffTest{ }, { name: "modifications returned in lexographic order", - from: types.JsonObject{"a": types.JsonObject{"1": "i"}, "aa": 2, "b": 6}, - to: types.JsonObject{"": 1, "a": types.JsonObject{}, "aa": 3, "bb": 5}, + from: types.JSONDocument{Val: types.JsonObject{"a": types.JsonObject{"1": "i"}, "aa": 2, "b": 6}}, + to: types.JSONDocument{Val: types.JsonObject{"": 1, "a": types.JsonObject{}, "aa": 3, "bb": 5}}, expectedDiffs: []JsonDiff{ { - Key: "$.\"\"", + Key: makeJsonPathKey(``), To: &types.JSONDocument{Val: 1}, Type: AddedDiff, }, { - Key: "$.\"a\".\"1\"", + Key: makeJsonPathKey(`a`, `1`), From: &types.JSONDocument{Val: "i"}, Type: RemovedDiff, }, { - Key: "$.\"aa\"", + Key: makeJsonPathKey(`aa`), From: &types.JSONDocument{Val: 2}, To: &types.JSONDocument{Val: 3}, Type: ModifiedDiff, }, { - Key: "$.\"b\"", + Key: makeJsonPathKey(`b`), From: &types.JSONDocument{Val: 6}, Type: RemovedDiff, }, { - Key: "$.\"bb\"", + Key: makeJsonPathKey(`bb`), To: &types.JSONDocument{Val: 5}, Type: AddedDiff, }, @@ -216,10 +283,152 @@ var simpleJsonDiffTests = []jsonDiffTest{ }, } +func largeJsonDiffTests(t *testing.T) []jsonDiffTest { + ctx := sql.NewEmptyContext() + ns := NewTestNodeStore() + + insert := func(document types.MutableJSON, path string, val interface{}) types.MutableJSON { + jsonVal, inRange, err := types.JSON.Convert(val) + require.NoError(t, err) + require.True(t, (bool)(inRange)) + newDoc, changed, err := document.Insert(ctx, path, jsonVal.(sql.JSONWrapper)) + 
require.NoError(t, err) + require.True(t, changed) + return newDoc + } + + set := func(document types.MutableJSON, path string, val interface{}) types.MutableJSON { + jsonVal, inRange, err := types.JSON.Convert(val) + require.NoError(t, err) + require.True(t, (bool)(inRange)) + newDoc, changed, err := document.Replace(ctx, path, jsonVal.(sql.JSONWrapper)) + require.NoError(t, err) + require.True(t, changed) + return newDoc + } + + lookup := func(document types.SearchableJSON, path string) sql.JSONWrapper { + newDoc, err := document.Lookup(ctx, path) + require.NoError(t, err) + return newDoc + } + + remove := func(document types.MutableJSON, path string) types.MutableJSON { + newDoc, changed, err := document.Remove(ctx, path) + require.True(t, changed) + require.NoError(t, err) + return newDoc + } + + largeObject := createLargeArraylessDocumentForTesting(t, ctx, ns) + return []jsonDiffTest{ + { + name: "nested insert", + from: largeObject, + to: insert(largeObject, "$.level7.newKey", 2), + expectedDiffs: []JsonDiff{ + { + Key: makeJsonPathKey(`level7`, `newKey`), + From: nil, + To: &types.JSONDocument{Val: 2}, + Type: AddedDiff, + }, + }, + }, + { + name: "nested remove", + from: largeObject, + to: remove(largeObject, "$.level7.level6"), + expectedDiffs: []JsonDiff{ + { + Key: makeJsonPathKey(`level7`, `level6`), + From: lookup(largeObject, "$.level7.level6"), + To: nil, + Type: RemovedDiff, + }, + }, + }, + { + name: "nested modification 1", + from: largeObject, + to: set(largeObject, "$.level7.level5", 2), + expectedDiffs: []JsonDiff{ + { + Key: makeJsonPathKey(`level7`, `level5`), + From: lookup(largeObject, "$.level7.level5"), + To: &types.JSONDocument{Val: 2}, + Type: ModifiedDiff, + }, + }, + }, + { + name: "nested modification 2", + from: largeObject, + to: set(largeObject, "$.level7.level4", 1), + expectedDiffs: []JsonDiff{ + { + Key: makeJsonPathKey(`level7`, `level4`), + From: lookup(largeObject, "$.level7.level4"), + To: &types.JSONDocument{Val: 1}, + Type: ModifiedDiff, + }, + }, + }, + { + name: "convert object to array", + from: largeObject, + to: set(largeObject, "$.level7.level6", []interface{}{}), + expectedDiffs: []JsonDiff{ + { + Key: makeJsonPathKey(`level7`, `level6`), + From: lookup(largeObject, "$.level7.level6"), + To: &types.JSONDocument{Val: []interface{}{}}, + Type: ModifiedDiff, + }, + }, + }, + { + name: "convert array to object", + from: set(largeObject, "$.level7.level6", []interface{}{}), + to: largeObject, + expectedDiffs: []JsonDiff{ + { + Key: makeJsonPathKey(`level7`, `level6`), + From: &types.JSONDocument{Val: []interface{}{}}, + To: lookup(largeObject, "$.level7.level6"), + Type: ModifiedDiff, + }, + }, + }, + } +} + +// createLargeArraylessDocumentForTesting creates a JSON document large enough to be split across multiple chunks that +does not contain arrays. This makes it easier to write tests for three-way merging, since we can't currently merge
+func createLargeArraylessDocumentForTesting(t *testing.T, ctx *sql.Context, ns NodeStore) IndexedJsonDocument { + leafDoc := make(map[string]interface{}) + leafDoc["number"] = float64(1.0) + leafDoc["string"] = "dolt" + var docExpression sql.Expression = expression.NewLiteral(newIndexedJsonDocumentFromValue(t, ctx, ns, leafDoc), types.JSON) + var err error + + for level := 0; level < 8; level++ { + docExpression, err = json.NewJSONInsert(docExpression, expression.NewLiteral(fmt.Sprintf("$.level%d", level), types.Text), docExpression) + require.NoError(t, err) + } + doc, err := docExpression.Eval(ctx, nil) + require.NoError(t, err) + return newIndexedJsonDocumentFromValue(t, ctx, ns, doc) +} + func TestJsonDiff(t *testing.T) { t.Run("simple tests", func(t *testing.T) { runTestBatch(t, simpleJsonDiffTests) }) + t.Run("large document tests", func(t *testing.T) { + runTestBatch(t, largeJsonDiffTests(t)) + }) } func runTestBatch(t *testing.T, tests []jsonDiffTest) { @@ -231,16 +440,42 @@ func runTestBatch(t *testing.T, tests []jsonDiffTest) { } func runTest(t *testing.T, test jsonDiffTest) { - differ := NewJsonDiffer("$", test.from, test.to) + ctx := context.Background() + ns := NewTestNodeStore() + from := newIndexedJsonDocumentFromValue(t, ctx, ns, test.from) + to := newIndexedJsonDocumentFromValue(t, ctx, ns, test.to) + differ, err := NewIndexedJsonDiffer(ctx, from, to) + require.NoError(t, err) var actualDiffs []JsonDiff for { - diff, err := differ.Next() + diff, err := differ.Next(ctx) if err == io.EOF { break } - assert.NoError(t, err) + require.NoError(t, err) actualDiffs = append(actualDiffs, diff) } - require.Equal(t, test.expectedDiffs, actualDiffs) + diffsEqual := func(expected, actual JsonDiff) bool { + if expected.Type != actual.Type { + return false + } + if !bytes.Equal(expected.Key, actual.Key) { + return false + } + cmp, err := types.CompareJSON(expected.From, actual.From) + require.NoError(t, err) + if cmp != 0 { + return false + } + cmp, err = types.CompareJSON(expected.To, actual.To) + require.NoError(t, err) + + return cmp == 0 + } + require.Equal(t, len(test.expectedDiffs), len(actualDiffs)) + for i, expected := range test.expectedDiffs { + actual := actualDiffs[i] + require.True(t, diffsEqual(expected, actual), fmt.Sprintf("Expected: %v\nActual: %v", expected, actual)) + } } diff --git a/go/store/prolly/tree/json_indexed_document.go b/go/store/prolly/tree/json_indexed_document.go index 0c8c3eb38c..b0ac82f43b 100644 --- a/go/store/prolly/tree/json_indexed_document.go +++ b/go/store/prolly/tree/json_indexed_document.go @@ -19,9 +19,11 @@ import ( "database/sql/driver" "encoding/json" "fmt" + "io" "sync" "github.com/dolthub/go-mysql-server/sql" + sqljson "github.com/dolthub/go-mysql-server/sql/expression/function/json" "github.com/dolthub/go-mysql-server/sql/types" ) @@ -205,7 +207,7 @@ func (i IndexedJsonDocument) tryInsert(ctx context.Context, path string, val sql return i.insertIntoCursor(ctx, keyPath, jsonCursor, val) } -func (i IndexedJsonDocument) insertIntoCursor(ctx context.Context, keyPath jsonLocation, jsonCursor *JsonCursor, val sql.JSONWrapper) (types.MutableJSON, bool, error) { +func (i IndexedJsonDocument) insertIntoCursor(ctx context.Context, keyPath jsonLocation, jsonCursor *JsonCursor, val sql.JSONWrapper) (IndexedJsonDocument, bool, error) { cursorPath := jsonCursor.GetCurrentPath() // If the inserted path is equivalent to "$" (which also includes "$[0]" on non-arrays), do nothing. 
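
Note on the new differ API shown in the rewritten runTest earlier in this diff: callers construct the differ once, then pull diffs with Next(ctx) until it returns io.EOF. A minimal consumption sketch against the exported API (the collectDiffs helper and the example package name are illustrative, not part of this change):

package example

import (
	"context"
	"io"

	"github.com/dolthub/dolt/go/store/prolly/tree"
)

// collectDiffs drains an IndexedJsonDiffer into a slice, mirroring the loop
// in runTest; io.EOF signals that every difference has been produced.
func collectDiffs(ctx context.Context, from, to tree.IndexedJsonDocument) ([]tree.JsonDiff, error) {
	differ, err := tree.NewIndexedJsonDiffer(ctx, from, to)
	if err != nil {
		return nil, err
	}
	var diffs []tree.JsonDiff
	for {
		d, err := differ.Next(ctx)
		if err == io.EOF {
			return diffs, nil
		}
		if err != nil {
			return nil, err
		}
		diffs = append(diffs, d)
	}
}
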
@@ -241,7 +243,7 @@ func (i IndexedJsonDocument) insertIntoCursor(ctx context.Context, keyPath jsonL jsonChunker, err := newJsonChunker(ctx, jsonCursor, i.m.NodeStore) if err != nil { - return nil, false, err + return IndexedJsonDocument{}, false, err } originalValue, err := jsonCursor.NextValue(ctx) @@ -251,15 +253,18 @@ func (i IndexedJsonDocument) insertIntoCursor(ctx context.Context, keyPath jsonL insertedValueBytes, err := types.MarshallJson(val) if err != nil { - return nil, false, err + return IndexedJsonDocument{}, false, err } jsonChunker.appendJsonToBuffer([]byte(fmt.Sprintf("[%s,%s]", originalValue, insertedValueBytes))) - jsonChunker.processBuffer(ctx) + err = jsonChunker.processBuffer(ctx) + if err != nil { + return IndexedJsonDocument{}, false, err + } newRoot, err := jsonChunker.Done(ctx) if err != nil { - return nil, false, err + return IndexedJsonDocument{}, false, err } return NewIndexedJsonDocument(ctx, newRoot, i.m.NodeStore), true, nil @@ -285,14 +290,14 @@ func (i IndexedJsonDocument) insertIntoCursor(ctx context.Context, keyPath jsonL insertedValueBytes, err := types.MarshallJson(val) if err != nil { - return nil, false, err + return IndexedJsonDocument{}, false, err } // The key is guaranteed to not exist in the source doc. The cursor is pointing to the start of the subsequent object, // which will be the insertion point for the added value. jsonChunker, err := newJsonChunker(ctx, jsonCursor, i.m.NodeStore) if err != nil { - return nil, false, err + return IndexedJsonDocument{}, false, err } // If required, adds a comma before writing the value. @@ -302,18 +307,21 @@ func (i IndexedJsonDocument) insertIntoCursor(ctx context.Context, keyPath jsonL // If the value is a newly inserted key, write the key. if !keyLastPathElement.isArrayIndex { - jsonChunker.appendJsonToBuffer([]byte(fmt.Sprintf(`"%s":`, keyLastPathElement.key))) + jsonChunker.appendJsonToBuffer([]byte(fmt.Sprintf(`"%s":`, escapeKey(keyLastPathElement.key)))) } // Manually set the chunker's path and offset to the start of the value we're about to insert. jsonChunker.jScanner.valueOffset = len(jsonChunker.jScanner.jsonBuffer) jsonChunker.jScanner.currentPath = keyPath jsonChunker.appendJsonToBuffer(insertedValueBytes) - jsonChunker.processBuffer(ctx) + err = jsonChunker.processBuffer(ctx) + if err != nil { + return IndexedJsonDocument{}, false, err + } newRoot, err := jsonChunker.Done(ctx) if err != nil { - return nil, false, err + return IndexedJsonDocument{}, false, err } return NewIndexedJsonDocument(ctx, newRoot, i.m.NodeStore), true, nil @@ -343,10 +351,17 @@ func (i IndexedJsonDocument) tryRemove(ctx context.Context, path string) (types. if err != nil { return nil, false, err } + return i.removeWithLocation(ctx, keyPath) +} + +func (i IndexedJsonDocument) RemoveWithKey(ctx context.Context, key []byte) (IndexedJsonDocument, bool, error) { + return i.removeWithLocation(ctx, jsonPathFromKey(key)) +} +func (i IndexedJsonDocument) removeWithLocation(ctx context.Context, keyPath jsonLocation) (IndexedJsonDocument, bool, error) { jsonCursor, found, err := newJsonCursor(ctx, i.m.NodeStore, i.m.Root, keyPath, true) if err != nil { - return nil, false, err + return IndexedJsonDocument{}, false, err } if !found { // The key does not exist in the document. @@ -356,7 +371,7 @@ func (i IndexedJsonDocument) tryRemove(ctx context.Context, path string) (types. // The cursor is now pointing to the end of the value prior to the one being removed. 
jsonChunker, err := newJsonChunker(ctx, jsonCursor, i.m.NodeStore) if err != nil { - return nil, false, err + return IndexedJsonDocument{}, false, err } startofRemovedLocation := jsonCursor.GetCurrentPath() @@ -367,7 +382,7 @@ func (i IndexedJsonDocument) tryRemove(ctx context.Context, path string) (types. keyPath.setScannerState(endOfValue) _, err = jsonCursor.AdvanceToLocation(ctx, keyPath, false) if err != nil { - return nil, false, err + return IndexedJsonDocument{}, false, err } // If removing the first element of an object/array, skip past the comma, and set the chunker as if it's @@ -379,7 +394,7 @@ func (i IndexedJsonDocument) tryRemove(ctx context.Context, path string) (types. newRoot, err := jsonChunker.Done(ctx) if err != nil { - return nil, false, err + return IndexedJsonDocument{}, false, err } return NewIndexedJsonDocument(ctx, newRoot, i.m.NodeStore), true, nil @@ -406,10 +421,17 @@ func (i IndexedJsonDocument) trySet(ctx context.Context, path string, val sql.JS if err != nil { return nil, false, err } + return i.setWithLocation(ctx, keyPath, val) +} +func (i IndexedJsonDocument) SetWithKey(ctx context.Context, key []byte, val sql.JSONWrapper) (IndexedJsonDocument, bool, error) { + return i.setWithLocation(ctx, jsonPathFromKey(key), val) +} + +func (i IndexedJsonDocument) setWithLocation(ctx context.Context, keyPath jsonLocation, val sql.JSONWrapper) (IndexedJsonDocument, bool, error) { jsonCursor, found, err := newJsonCursor(ctx, i.m.NodeStore, i.m.Root, keyPath, false) if err != nil { - return nil, false, err + return IndexedJsonDocument{}, false, err } // The supplied path may be 0-indexing into a scalar, which is the same as referencing the scalar. Remove @@ -480,32 +502,35 @@ func (i IndexedJsonDocument) tryReplace(ctx context.Context, path string, val sq return i.replaceIntoCursor(ctx, keyPath, jsonCursor, val) } -func (i IndexedJsonDocument) replaceIntoCursor(ctx context.Context, keyPath jsonLocation, jsonCursor *JsonCursor, val sql.JSONWrapper) (types.MutableJSON, bool, error) { +func (i IndexedJsonDocument) replaceIntoCursor(ctx context.Context, keyPath jsonLocation, jsonCursor *JsonCursor, val sql.JSONWrapper) (IndexedJsonDocument, bool, error) { // The cursor is now pointing to the start of the value being replaced. jsonChunker, err := newJsonChunker(ctx, jsonCursor, i.m.NodeStore) if err != nil { - return nil, false, err + return IndexedJsonDocument{}, false, err } // Advance the cursor to the end of the value being removed. 
keyPath.setScannerState(endOfValue) _, err = jsonCursor.AdvanceToLocation(ctx, keyPath, false) if err != nil { - return nil, false, err + return IndexedJsonDocument{}, false, err } insertedValueBytes, err := types.MarshallJson(val) if err != nil { - return nil, false, err + return IndexedJsonDocument{}, false, err } jsonChunker.appendJsonToBuffer(insertedValueBytes) - jsonChunker.processBuffer(ctx) + err = jsonChunker.processBuffer(ctx) + if err != nil { + return IndexedJsonDocument{}, false, err + } newRoot, err := jsonChunker.Done(ctx) if err != nil { - return nil, false, err + return IndexedJsonDocument{}, false, err } return NewIndexedJsonDocument(ctx, newRoot, i.m.NodeStore), true, nil @@ -548,3 +573,137 @@ func (i IndexedJsonDocument) GetBytes() (bytes []byte, err error) { // TODO: Add context parameter to JSONBytes.GetBytes return getBytesFromIndexedJsonMap(i.ctx, i.m) } + +func (i IndexedJsonDocument) getFirstCharacter(ctx context.Context) (byte, error) { + stopIterationError := fmt.Errorf("stop") + var firstCharacter byte + err := i.m.WalkNodes(ctx, func(ctx context.Context, nd Node) error { + if nd.IsLeaf() { + firstCharacter = nd.GetValue(0)[0] + return stopIterationError + } + return nil + }) + if err != stopIterationError { + return 0, err + } + return firstCharacter, nil +} + +func (i IndexedJsonDocument) getTypeCategory() (jsonTypeCategory, error) { + firstCharacter, err := i.getFirstCharacter(i.ctx) + if err != nil { + return 0, err + } + return getTypeCategoryFromFirstCharacter(firstCharacter), nil +} + +func GetTypeCategory(wrapper sql.JSONWrapper) (jsonTypeCategory, error) { + switch doc := wrapper.(type) { + case IndexedJsonDocument: + return doc.getTypeCategory() + case *types.LazyJSONDocument: + return getTypeCategoryFromFirstCharacter(doc.Bytes[0]), nil + default: + val, err := doc.ToInterface() + if err != nil { + return 0, err + } + return getTypeCategoryOfValue(val) + } +} + +// Type implements types.ComparableJson +func (i IndexedJsonDocument) Type(ctx context.Context) (string, error) { + firstCharacter, err := i.getFirstCharacter(ctx) + if err != nil { + return "", err + } + + switch firstCharacter { + case '{': + return "OBJECT", nil + case '[': + return "ARRAY", nil + } + // At this point the value must be a scalar, so it's okay to just load the whole thing. + val, err := i.ToInterface() + if err != nil { + return "", err + } + return sqljson.TypeOfJsonValue(val), nil +} + +// Compare implements types.ComparableJson +func (i IndexedJsonDocument) Compare(other interface{}) (int, error) { + thisTypeCategory, err := i.getTypeCategory() + if err != nil { + return 0, err + } + + otherIndexedDocument, ok := other.(IndexedJsonDocument) + if !ok { + val, err := i.ToInterface() + if err != nil { + return 0, err + } + otherVal := other + if otherWrapper, ok := other.(sql.JSONWrapper); ok { + otherVal, err = otherWrapper.ToInterface() + if err != nil { + return 0, err + } + } + return types.CompareJSON(val, otherVal) + } + + otherTypeCategory, err := otherIndexedDocument.getTypeCategory() + if err != nil { + return 0, err + } + if thisTypeCategory < otherTypeCategory { + return -1, nil + } + if thisTypeCategory > otherTypeCategory { + return 1, nil + } + switch thisTypeCategory { + case jsonTypeNull: + return 0, nil + case jsonTypeArray, jsonTypeObject: + // To compare two values that are both arrays or both objects, we must locate the first location + // where they differ. 
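+		// At the first divergence, the diff type determines the ordering:
+		// an AddedDiff means `other` has content at a path that this document
+		// lacks, a RemovedDiff means the reverse, and a ModifiedDiff defers to
+		// comparing the two conflicting values directly.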
+
+		jsonDiffer, err := NewIndexedJsonDiffer(i.ctx, i, otherIndexedDocument)
+		if err != nil {
+			return 0, err
+		}
+		firstDiff, err := jsonDiffer.Next(i.ctx)
+		if err == io.EOF {
+			// The two documents have no differences.
+			return 0, nil
+		}
+		if err != nil {
+			return 0, err
+		}
+		switch firstDiff.Type {
+		case AddedDiff:
+			// A key is present in other but not this.
+			return -1, nil
+		case RemovedDiff:
+			return 1, nil
+		case ModifiedDiff:
+			// Since both modified values have already been loaded into memory,
+			// we can just compare them.
+			return types.JSON.Compare(firstDiff.From, firstDiff.To)
+		default:
+			panic("Impossible diff type")
+		}
+	default:
+		val, err := i.ToInterface()
+		if err != nil {
+			return 0, err
+		}
+		return types.CompareJSON(val, other)
+	}
+}
diff --git a/go/store/prolly/tree/json_indexed_document_test.go b/go/store/prolly/tree/json_indexed_document_test.go
index 5a9b7e0e0a..7a7adc5637 100644
--- a/go/store/prolly/tree/json_indexed_document_test.go
+++ b/go/store/prolly/tree/json_indexed_document_test.go
@@ -27,6 +27,7 @@ import (
 	"github.com/dolthub/go-mysql-server/sql"
 	"github.com/dolthub/go-mysql-server/sql/expression/function/json/jsontests"
 	"github.com/dolthub/go-mysql-server/sql/types"
+	typetests "github.com/dolthub/go-mysql-server/sql/types/jsontests"
 	"github.com/stretchr/testify/require"
 )
@@ -301,3 +302,156 @@ func TestIndexedJsonDocument_ContainsPath(t *testing.T) {
 	testCases := jsontests.JsonContainsPathTestCases(t, convertToIndexedJsonDocument)
 	jsontests.RunJsonTests(t, testCases)
 }
+
+func TestJsonCompare(t *testing.T) {
+	ctx := sql.NewEmptyContext()
+	ns := NewTestNodeStore()
+	convertToIndexedJsonDocument := func(t *testing.T, left, right interface{}) (interface{}, interface{}) {
+		if left != nil {
+			left = newIndexedJsonDocumentFromValue(t, ctx, ns, left)
+		}
+		if right != nil {
+			right = newIndexedJsonDocumentFromValue(t, ctx, ns, right)
+		}
+		return left, right
+	}
+	convertOnlyLeftToIndexedJsonDocument := func(t *testing.T, left, right interface{}) (interface{}, interface{}) {
+		if left != nil {
+			left = newIndexedJsonDocumentFromValue(t, ctx, ns, left)
+		}
+		if right != nil {
+			rightJSON, inRange, err := types.JSON.Convert(right)
+			require.NoError(t, err)
+			require.True(t, bool(inRange))
+			rightInterface, err := rightJSON.(sql.JSONWrapper).ToInterface()
+			require.NoError(t, err)
+			right = types.JSONDocument{Val: rightInterface}
+		}
+		return left, right
+	}
+	convertOnlyRightToIndexedJsonDocument := func(t *testing.T, left, right interface{}) (interface{}, interface{}) {
+		right, left = convertOnlyLeftToIndexedJsonDocument(t, right, left)
+		return left, right
+	}
+
+	t.Run("small documents", func(t *testing.T) {
+		tests := append(typetests.JsonCompareTests, typetests.JsonCompareNullsTests...)
+ t.Run("compare two indexed json documents", func(t *testing.T) { + typetests.RunJsonCompareTests(t, tests, convertToIndexedJsonDocument) + }) + t.Run("compare indexed json document with non-indexed", func(t *testing.T) { + typetests.RunJsonCompareTests(t, tests, convertOnlyLeftToIndexedJsonDocument) + }) + t.Run("compare non-indexed json document with indexed", func(t *testing.T) { + typetests.RunJsonCompareTests(t, tests, convertOnlyRightToIndexedJsonDocument) + }) + }) + + noError := func(j types.MutableJSON, changed bool, err error) types.MutableJSON { + require.NoError(t, err) + require.True(t, changed) + return j + } + + largeArray := createLargeDocumentForTesting(t, ctx, ns) + largeObjectWrapper, err := largeArray.Lookup(ctx, "$[7]") + largeObject := newIndexedJsonDocumentFromValue(t, ctx, ns, largeObjectWrapper) + require.NoError(t, err) + largeDocTests := []typetests.JsonCompareTest{ + { + Name: "large object < boolean", + Left: largeObject, + Right: true, + Cmp: -1, + }, + { + Name: "large object > string", + Left: largeObject, + Right: `"test"`, + Cmp: 1, + }, + { + Name: "large object > number", + Left: largeObject, + Right: 1, + Cmp: 1, + }, + { + Name: "large object > null", + Left: largeObject, + Right: `null`, + Cmp: 1, + }, + { + Name: "inserting into beginning of object makes it greater", + Left: largeObject, + Right: noError(largeObject.Insert(ctx, "$.a", types.MustJSON("1"))), + Cmp: -1, + }, + { + Name: "inserting into end of object makes it greater", + Left: largeObject, + Right: noError(largeObject.Insert(ctx, "$.z", types.MustJSON("1"))), + Cmp: -1, + }, + { + Name: "large array < boolean", + Left: largeArray, + Right: true, + Cmp: -1, + }, + { + Name: "large array > string", + Left: largeArray, + Right: `"test"`, + Cmp: 1, + }, + { + Name: "large array > number", + Left: largeArray, + Right: 1, + Cmp: 1, + }, + { + Name: "large array > null", + Left: largeArray, + Right: `null`, + Cmp: 1, + }, + { + Name: "inserting into end of array makes it greater", + Left: largeArray, + Right: noError(largeArray.ArrayAppend("$", types.MustJSON("1"))), + Cmp: -1, + }, + { + Name: "inserting high value into beginning of array makes it greater", + Left: largeArray, + Right: noError(largeArray.ArrayInsert("$[0]", types.MustJSON("true"))), + Cmp: -1, + }, + { + Name: "inserting low value into beginning of array makes it less", + Left: largeArray, + Right: noError(largeArray.ArrayInsert("$[0]", types.MustJSON("1"))), + Cmp: 1, + }, + { + Name: "large array > large object", + Left: largeArray, + Right: largeObject, + Cmp: 1, + }, + } + t.Run("large documents", func(t *testing.T) { + t.Run("compare two indexed json documents", func(t *testing.T) { + typetests.RunJsonCompareTests(t, largeDocTests, convertToIndexedJsonDocument) + }) + t.Run("compare indexed json document with non-indexed", func(t *testing.T) { + typetests.RunJsonCompareTests(t, largeDocTests, convertOnlyLeftToIndexedJsonDocument) + }) + t.Run("compare non-indexed json document with indexed", func(t *testing.T) { + typetests.RunJsonCompareTests(t, largeDocTests, convertOnlyRightToIndexedJsonDocument) + }) + }) +} diff --git a/go/store/prolly/tree/json_location.go b/go/store/prolly/tree/json_location.go index b6c8a413be..c532f07ac9 100644 --- a/go/store/prolly/tree/json_location.go +++ b/go/store/prolly/tree/json_location.go @@ -186,10 +186,31 @@ const ( lexStateEscapedQuotedKey lexState = 5 ) +func escapeKey(key []byte) []byte { + return bytes.Replace(key, []byte(`"`), []byte(`\"`), -1) +} + func unescapeKey(key []byte) 
 	return bytes.Replace(key, []byte(`\"`), []byte(`"`), -1)
 }
 
+// IsJsonKeyPrefix computes whether one key encodes a json location that is a prefix of another.
+// Example: $.a is a prefix of $.a.b, but not $.aa
+func IsJsonKeyPrefix(path, prefix []byte) bool {
+	return bytes.HasPrefix(path, prefix) && (path[len(prefix)] == beginArrayKey || path[len(prefix)] == beginObjectKey)
+}
+
+// JsonKeysModifySameArray reports whether both keys descend into the same array: it returns true
+// iff an array-element marker appears within the keys' shared prefix.
+func JsonKeysModifySameArray(leftKey, rightKey []byte) bool {
+	i := 0
+	for i < len(leftKey) && i < len(rightKey) && leftKey[i] == rightKey[i] {
+		if leftKey[i] == beginArrayKey {
+			return true
+		}
+		i++
+	}
+	return false
+}
+
 func jsonPathElementsFromMySQLJsonPath(pathBytes []byte) (jsonLocation, error) {
 	location := newRootLocation()
 	state := lexStatePath
@@ -417,6 +438,14 @@ type jsonLocationOrdering struct{}
 
 var _ Ordering[[]byte] = jsonLocationOrdering{}
 
 func (jsonLocationOrdering) Compare(left, right []byte) int {
+	// A JSON document that fits entirely in a single chunk has no keys, so an empty key sorts before any other.
+	if len(left) == 0 && len(right) == 0 {
+		return 0
+	} else if len(left) == 0 {
+		return -1
+	} else if len(right) == 0 {
+		return 1
+	}
 	leftPath := jsonPathFromKey(left)
 	rightPath := jsonPathFromKey(right)
 	return compareJsonLocations(leftPath, rightPath)
diff --git a/go/store/prolly/tree/json_scanner.go b/go/store/prolly/tree/json_scanner.go
index 413b44ca45..8a7ffcf28f 100644
--- a/go/store/prolly/tree/json_scanner.go
+++ b/go/store/prolly/tree/json_scanner.go
@@ -160,7 +160,7 @@ func (s *JsonScanner) acceptValue() error {
 const endOfFile byte = 0xFF
 
 // current returns the current byte being parsed, or 0xFF if we've reached the end of the file.
-// (Since the JSON is UTF-8, the 0xFF byte cannot otherwise appear within in.)
+// (Since the JSON is UTF-8, the 0xFF byte cannot otherwise appear within it.)
 func (s JsonScanner) current() byte {
 	if s.valueOffset >= len(s.jsonBuffer) {
 		return endOfFile
diff --git a/go/store/prolly/tree/json_type_categories.go b/go/store/prolly/tree/json_type_categories.go
new file mode 100644
index 0000000000..ac640e2c1f
--- /dev/null
+++ b/go/store/prolly/tree/json_type_categories.go
@@ -0,0 +1,78 @@
+// Copyright 2024 Dolthub, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tree
+
+import (
+	"fmt"
+
+	"github.com/dolthub/go-mysql-server/sql"
+	"github.com/shopspring/decimal"
+)
+
+type jsonTypeCategory int
+
+const (
+	jsonTypeNull jsonTypeCategory = iota
+	jsonTypeNumber
+	jsonTypeString
+	jsonTypeObject
+	jsonTypeArray
+	jsonTypeBoolean
+)
+
+// getTypeCategoryOfValue maps an in-memory JSON value to its MySQL comparison category.
+func getTypeCategoryOfValue(val interface{}) (jsonTypeCategory, error) {
+	if val == nil {
+		return jsonTypeNull, nil
+	}
+	switch val.(type) {
+	case map[string]interface{}:
+		return jsonTypeObject, nil
+	case []interface{}:
+		return jsonTypeArray, nil
+	case bool:
+		return jsonTypeBoolean, nil
+	case string:
+		return jsonTypeString, nil
+	case decimal.Decimal, int8, int16, int32, int64, uint8, uint16, uint32, uint64, float32, float64:
+		return jsonTypeNumber, nil
+	}
+	return 0, fmt.Errorf("expected json value, got %v", val)
+}
+
+// getTypeCategoryFromFirstCharacter returns the type category of a JSON value by inspecting its first byte.
+func getTypeCategoryFromFirstCharacter(c byte) jsonTypeCategory {
+	switch c {
+	case '{':
+		return jsonTypeObject
+	case '[':
+		return jsonTypeArray
+	case 'n':
+		return jsonTypeNull
+	case 't', 'f':
+		return jsonTypeBoolean
+	case '"':
+		return jsonTypeString
+	default:
+		return jsonTypeNumber
+	}
+}
+
+func IsJsonObject(json sql.JSONWrapper) (bool, error) {
+	valType, err := GetTypeCategory(json)
+	if err != nil {
+		return false, err
+	}
+	return valType == jsonTypeObject, nil
+}
diff --git a/go/utils/publishrelease/install.sh b/go/utils/publishrelease/install.sh
index 6bd2f2bf72..b212efe0dc 100644
--- a/go/utils/publishrelease/install.sh
+++ b/go/utils/publishrelease/install.sh
@@ -17,11 +17,11 @@ _() {
 set -euo pipefail
 
-DOLT_VERSION=__DOLT_VERSION__
-RELEASES_BASE_URL=https://github.com/dolthub/dolt/releases/download/v"$DOLT_VERSION"
-INSTALL_URL=$RELEASES_BASE_URL/install.sh
+DOLT_VERSION='__DOLT_VERSION__'
+RELEASES_BASE_URL="https://github.com/dolthub/dolt/releases/download/v$DOLT_VERSION"
+INSTALL_URL="$RELEASES_BASE_URL/install.sh"
 
-CURL_USER_AGENT=${CURL_USER_AGENT:-dolt-installer}
+CURL_USER_AGENT="${CURL_USER_AGENT:-dolt-installer}"
 
 OS=
 ARCH=
@@ -30,84 +30,88 @@ WORK_DIR=
 PLATFORM_TUPLE=
 
 error() {
-  if [ $# != 0 ]; then
-    echo -e "\e[0;31m""$@""\e[0m" >&2
+  if [ "$#" != 0 ]; then
+    printf '\e[0;31m%s\e[0m\n' "$*" >&2
   fi
 }
 
 fail() {
   local error_code="$1"
   shift
-  echo "*** INSTALLATION FAILED ***" >&2
-  echo "" >&2
+  echo '*** INSTALLATION FAILED ***' >&2
+  echo '' >&2
   error "$@"
-  echo "" >&2
+  echo '' >&2
   exit 1
 }
 
 assert_linux_or_macos() {
-  OS=`uname`
-  ARCH=`uname -m`
-  if [ "$OS" != Linux -a "$OS" != Darwin ]; then
-    fail "E_UNSUPPORTED_OS" "dolt install.sh only supports macOS and Linux."
+  OS="$(uname)"
+  ARCH="$(uname -m)"
+  if [ "$OS" != 'Linux' ] && [ "$OS" != 'Darwin' ]; then
+    fail 'E_UNSUPPORTED_OS' 'dolt install.sh only supports macOS and Linux.'
   fi
 
   # Translate aarch64 to arm64, since that's what GOARCH calls it
-  if [ "$ARCH" == "aarch64" ]; then
-    ARCH="arm64"
+  if [ "$ARCH" == 'aarch64' ]; then
+    ARCH='arm64'
   fi
 
-  if [ "$ARCH-$OS" != "x86_64-Linux" -a "$ARCH-$OS" != "x86_64-Darwin" -a "$ARCH-$OS" != "arm64-Darwin" -a "$ARCH-$OS" != "arm64-Linux" ]; then
-    fail "E_UNSUPPOSED_ARCH" "dolt install.sh only supports installing dolt on x86_64, x86, Linux-aarch64, or Darwin-arm64."
+ if [ "$ARCH-$OS" != 'x86_64-Linux' ] && [ "$ARCH-$OS" != 'x86_64-Darwin' ] && [ "$ARCH-$OS" != 'arm64-Linux' ] && [ "$ARCH-$OS" != 'arm64-Darwin' ]; then + fail 'E_UNSUPPOSED_ARCH' 'dolt install.sh only supports installing dolt on Linux-x86_64, Darwin-x86_64, Linux-aarch64, or Darwin-arm64.' fi - if [ "$OS" == Linux ]; then + if [ "$OS" == 'Linux' ]; then PLATFORM_TUPLE=linux else PLATFORM_TUPLE=darwin fi - if [ "$ARCH" == x86_64 ]; then - PLATFORM_TUPLE=$PLATFORM_TUPLE-amd64 - elif [ "$ARCH" == arm64 ]; then - PLATFORM_TUPLE=$PLATFORM_TUPLE-arm64 + + if [ "$ARCH" == 'x86_64' ]; then + PLATFORM_TUPLE="$PLATFORM_TUPLE-amd64" + else + PLATFORM_TUPLE="$PLATFORM_TUPLE-arm64" fi } assert_dependencies() { - type -p curl > /dev/null || fail "E_CURL_MISSING" "Please install curl(1)." - type -p tar > /dev/null || fail "E_TAR_MISSING" "Please install tar(1)." - type -p uname > /dev/null || fail "E_UNAME_MISSING" "Please install uname(1)." - type -p install > /dev/null || fail "E_INSTALL_MISSING" "Please install install(1)." - type -p mktemp > /dev/null || fail "E_MKTEMP_MISSING" "Please install mktemp(1)." + type -p curl > /dev/null || fail 'E_CURL_MISSING' 'Please install curl(1).' + type -p tar > /dev/null || fail 'E_TAR_MISSING' 'Please install tar(1).' + type -p uname > /dev/null || fail 'E_UNAME_MISSING' 'Please install uname(1).' + type -p install > /dev/null || fail 'E_INSTALL_MISSING' 'Please install install(1).' + type -p mktemp > /dev/null || fail 'E_MKTEMP_MISSING' 'Please install mktemp(1).' } assert_uid_zero() { - uid=`id -u` + uid="$(id -u)" if [ "$uid" != 0 ]; then - fail "E_UID_NONZERO" "dolt install.sh must run as root; please try running with sudo or running\n\`curl $INSTALL_URL | sudo bash\`." + fail 'E_UID_NONZERO' "dolt install.sh must run as root; please try running with sudo or running\n\`curl $INSTALL_URL | sudo bash\`." fi } create_workdir() { - WORK_DIR=`mktemp -d -t dolt-installer.XXXXXX` + WORK_DIR="$(mktemp -d -t dolt-installer.XXXXXX)" cleanup() { rm -rf "$WORK_DIR" } + trap cleanup EXIT cd "$WORK_DIR" } install_binary_release() { - local FILE=dolt-$PLATFORM_TUPLE.tar.gz - local URL=$RELEASES_BASE_URL/$FILE - echo "Downloading:" $URL + local FILE="dolt-$PLATFORM_TUPLE.tar.gz" + local URL="$RELEASES_BASE_URL/$FILE" + + echo "Downloading: $URL" curl -A "$CURL_USER_AGENT" -fsL "$URL" > "$FILE" tar zxf "$FILE" - echo "Installing dolt /usr/local/bin." - [ -d /usr/local/bin ] || install -o 0 -g 0 -d /usr/local/bin - install -o 0 -g 0 dolt-$PLATFORM_TUPLE/bin/dolt /usr/local/bin + + echo 'Installing dolt into /usr/local/bin.' + [ ! -d /usr/local/bin ] && install -o 0 -g 0 -d /usr/local/bin + install -o 0 -g 0 "dolt-$PLATFORM_TUPLE/bin/dolt" /usr/local/bin install -o 0 -g 0 -d /usr/local/share/doc/dolt/ - install -o 0 -g 0 -m 644 dolt-$PLATFORM_TUPLE/LICENSES /usr/local/share/doc/dolt/ + install -o 0 -g 0 -m 644 "dolt-$PLATFORM_TUPLE/LICENSES" /usr/local/share/doc/dolt/ } assert_linux_or_macos diff --git a/integration-tests/bats/cherry-pick.bats b/integration-tests/bats/cherry-pick.bats index 79a3cba0c3..8421af0e4e 100644 --- a/integration-tests/bats/cherry-pick.bats +++ b/integration-tests/bats/cherry-pick.bats @@ -75,12 +75,19 @@ teardown() { [[ "$output" =~ "ancestor" ]] || false } -@test "cherry-pick: no changes" { +@test "cherry-pick: empty commit handling" { dolt commit --allow-empty -am "empty commit" dolt checkout main + + # If an empty commit is cherry-picked, Git will stop the cherry-pick and allow you to manually commit it + # with the --allow-empty flag. 
We don't support that yet, so instead, empty commits generate an error. run dolt cherry-pick branch1 + [ "$status" -eq "1" ] + [[ "$output" =~ "The previous cherry-pick commit is empty. Use --allow-empty to cherry-pick empty commits." ]] || false + + # If the --allow-empty flag is specified, then empty commits can be automatically cherry-picked. + run dolt cherry-pick --allow-empty branch1 [ "$status" -eq "0" ] - [[ "$output" =~ "No changes were made" ]] || false } @test "cherry-pick: invalid hash" { diff --git a/integration-tests/bats/log.bats b/integration-tests/bats/log.bats index f5a1b878b0..245f70d1ad 100755 --- a/integration-tests/bats/log.bats +++ b/integration-tests/bats/log.bats @@ -844,10 +844,11 @@ export NO_COLOR=1 [[ "${lines[7]}" =~ "Author:" ]] || false # Author: [[ "${lines[8]}" =~ "Date:" ]] || false # Date: [[ "${lines[9]}" =~ "Initialize data repository" ]] || false # Initialize data repository + [[ ! "${lines[9]}" =~ "%!(EXTRA string=" ]] || false run dolt log --graph --oneline [ "$status" -eq 0 ] - + [[ "${lines[0]}" =~ \* ]] || false [[ ! "$output" =~ "Author" ]] || false [[ ! "$output" =~ "Date" ]] || false diff --git a/integration-tests/bats/ls.bats b/integration-tests/bats/ls.bats index 498bff7cc6..6841f4d374 100755 --- a/integration-tests/bats/ls.bats +++ b/integration-tests/bats/ls.bats @@ -60,7 +60,7 @@ teardown() { @test "ls: --system shows system tables" { run dolt ls --system [ "$status" -eq 0 ] - [ "${#lines[@]}" -eq 20 ] + [ "${#lines[@]}" -eq 22 ] [[ "$output" =~ "System tables:" ]] || false [[ "$output" =~ "dolt_status" ]] || false [[ "$output" =~ "dolt_commits" ]] || false @@ -81,6 +81,8 @@ teardown() { [[ "$output" =~ "dolt_conflicts_table_two" ]] || false [[ "$output" =~ "dolt_diff_table_two" ]] || false [[ "$output" =~ "dolt_commit_diff_table_two" ]] || false + [[ "$output" =~ "dolt_workspace_table_one" ]] || false + [[ "$output" =~ "dolt_workspace_table_two" ]] || false } @test "ls: --all shows tables in working set and system tables" { diff --git a/integration-tests/bats/no-repo.bats b/integration-tests/bats/no-repo.bats index 7b19c59983..3a030fbb70 100755 --- a/integration-tests/bats/no-repo.bats +++ b/integration-tests/bats/no-repo.bats @@ -422,3 +422,46 @@ NOT_VALID_REPO_ERROR="The current directory is not a valid dolt repository." 
[ "$status" -eq 1 ] [[ "$output" =~ "Unknown Command notarealcommand" ]] || false } + +@test "no-repo: the global dolt directory is not accessible due to permissions" { + noPermissionsDir=$(mktemp -d -t noPermissions-XXXX) + chmod 000 $noPermissionsDir + DOLT_ROOT_PATH=$noPermissionsDir + + run dolt version + [ "$status" -eq 1 ] + [[ "$output" =~ "Failed to load the global config" ]] || false + [[ "$output" =~ "permission denied" ]] || false + + run dolt sql + [ "$status" -eq 1 ] + [[ "$output" =~ "Failed to load the global config" ]] || false + [[ "$output" =~ "permission denied" ]] || false + + run dolt sql-server + [ "$status" -eq 1 ] + [[ "$output" =~ "Failed to load the global config" ]] || false + [[ "$output" =~ "permission denied" ]] || false +} + +@test "no-repo: the global dolt directory is accessible, but not writable" { + noPermissionsDir=$(mktemp -d -t noPermissions-XXXX) + chmod 000 $noPermissionsDir + chmod a+x $noPermissionsDir + DOLT_ROOT_PATH=$noPermissionsDir + + run dolt version + [ "$status" -eq 1 ] + [[ "$output" =~ "Failed to load the global config" ]] || false + [[ "$output" =~ "permission denied" ]] || false + + run dolt sql + [ "$status" -eq 1 ] + [[ "$output" =~ "Failed to load the global config" ]] || false + [[ "$output" =~ "permission denied" ]] || false + + run dolt sql-server + [ "$status" -eq 1 ] + [[ "$output" =~ "Failed to load the global config" ]] || false + [[ "$output" =~ "permission denied" ]] || false +} diff --git a/integration-tests/bats/rebase.bats b/integration-tests/bats/rebase.bats index 67795e43ca..e0b35a1803 100755 --- a/integration-tests/bats/rebase.bats +++ b/integration-tests/bats/rebase.bats @@ -4,17 +4,14 @@ load $BATS_TEST_DIRNAME/helper/common.bash setup() { setup_common dolt sql -q "CREATE table t1 (pk int primary key, c int);" - dolt add t1 - dolt commit -m "main commit 1" + dolt commit -Am "main commit 1" dolt branch b1 dolt sql -q "INSERT INTO t1 VALUES (1,1);" - dolt add t1 - dolt commit -m "main commit 2" + dolt commit -am "main commit 2" dolt checkout b1 dolt sql -q "CREATE table t2 (pk int primary key);" - dolt add t2 - dolt commit -m "b1 commit 1" + dolt commit -Am "b1 commit 1" dolt checkout main } @@ -380,3 +377,43 @@ setupCustomEditorScript() { [ "$status" -eq 0 ] ! [[ "$output" =~ "dolt_rebase_b1" ]] || false } + +@test "rebase: rebase with commits that become empty" { + setupCustomEditorScript + + # Apply the same change to b1 that was applied to main in it's most recent commit + # and tag the tip of b1, so we can go reset back to this commit + dolt checkout b1 + dolt sql -q "INSERT INTO t1 VALUES (1,1);" + dolt commit -am "repeating change from main on b1" + dolt tag testStartPoint + + # By default, dolt will drop the empty commit + run dolt rebase -i main + [ "$status" -eq 0 ] + [[ "$output" =~ "Successfully rebased and updated refs/heads/b1" ]] || false + + # Make sure the commit that became empty doesn't appear in the commit log + run dolt log + [[ ! $output =~ "repeating change from main on b1" ]] || false + + # Reset back to the test start point and repeat the rebase with --empty=drop (the default) + dolt reset --hard testStartPoint + run dolt rebase -i --empty=drop main + [ "$status" -eq 0 ] + [[ "$output" =~ "Successfully rebased and updated refs/heads/b1" ]] || false + + # Make sure the commit that became empty does NOT appear in the commit log + run dolt log + [[ ! 
$output =~ "repeating change from main on b1" ]] || false + + # Reset back to the test start point and repeat the rebase with --empty=keep + dolt reset --hard testStartPoint + run dolt rebase -i --empty=keep main + [ "$status" -eq 0 ] + [[ "$output" =~ "Successfully rebased and updated refs/heads/b1" ]] || false + + # Make sure the commit that became empty appears in the commit log + run dolt log + [[ $output =~ "repeating change from main on b1" ]] || false +} diff --git a/integration-tests/bats/show.bats b/integration-tests/bats/show.bats index 06ae97b7f5..f3918ec234 100644 --- a/integration-tests/bats/show.bats +++ b/integration-tests/bats/show.bats @@ -140,6 +140,25 @@ assert_has_key_value() { assert_has_key "ParentClosure" "$output" } +@test "show: --no-pretty commit hash" { + dolt commit --allow-empty -m "commit: initialize table1" + hash=$(dolt sql -q "select dolt_hashof('head');" -r csv | tail -n 1) + run dolt show --no-pretty $hash + [ $status -eq 0 ] + [[ "$output" =~ "SerialMessage" ]] || false + assert_has_key "Name" "$output" + assert_has_key_value "Name" "Bats Tests" "$output" + assert_has_key_value "Desc" "commit: initialize table1" "$output" + assert_has_key_value "Name" "Bats Tests" "$output" + assert_has_key_value "Email" "bats@email.fake" "$output" + assert_has_key "Timestamp" "$output" + assert_has_key "UserTimestamp" "$output" + assert_has_key_value "Height" "2" "$output" + assert_has_key "RootValue" "$output" + assert_has_key "Parents" "$output" + assert_has_key "ParentClosure" "$output" +} + @test "show: HEAD root" { dolt sql -q "create table table1 (pk int PRIMARY KEY)" dolt sql -q "insert into table1 values (1), (2), (3)" diff --git a/integration-tests/bats/sql-cherry-pick.bats b/integration-tests/bats/sql-cherry-pick.bats index 5942736b80..5f279e7971 100644 --- a/integration-tests/bats/sql-cherry-pick.bats +++ b/integration-tests/bats/sql-cherry-pick.bats @@ -87,14 +87,21 @@ SQL [[ "$output" =~ "ancestor" ]] || false } -@test "sql-cherry-pick: no changes" { +@test "sql-cherry-pick: empty commit handling" { run dolt sql<>" + # Now match patternB + expect { + -re $patternB { + puts "<>" + eval $action + } + timeout { + puts "<>" + exit 1 + } + eof { + puts "<>" + exit 1 + } + failed { + puts "<>" + exit 1 + } + } } timeout { - puts "<>"; + puts "<>" exit 1 } eof { - puts "<>"; + puts "<>" exit 1 } failed { - puts "<>"; + puts "<>" exit 1 } } } - spawn dolt sql -expect_with_defaults {dolt-repo-[0-9]+/main\*> } { send "\\commit -A -m \"sql-shell-slash-cmds commit\"\r"; } +expect_with_defaults {dolt-repo-[0-9]+/main\*> } { send "\\commit -A -m \"sql-shell-slash-cmds commit\"\r"; } + +expect_with_defaults {dolt-repo-[0-9]+/main> } { send "\\log -n 1;\r"; } + +expect_with_defaults_2 {sql-shell-slash-cmds commit} {dolt-repo-[0-9]+/main> } { send "\\status\r"; } + +expect_with_defaults_2 {nothing to commit, working tree clean} {dolt-repo-[0-9]+/main> } { send "\\checkout -b br1\r"; } + +expect_with_defaults_2 {Switched to branch 'br1'} {dolt-repo-[0-9]+/br1> } { send "\\commit --allow-empty -m \"empty cmt\"\r"; } + +expect_with_defaults_2 {empty cmt} {dolt-repo-[0-9]+/br1> } { send "\\checkout main\r"; } + +expect_with_defaults_2 {Switched to branch 'main'} {dolt-repo-[0-9]+/main> } { send "\\commit --allow-empty -m \"main cmt\"\r"; } + +expect_with_defaults_2 {main cmt} {dolt-repo-[0-9]+/main> } { send "\\merge br1\r"; } + +expect_with_defaults_2 {Everything up-to-date} {dolt-repo-[0-9]+/main> } { send "\\show\r"; } + +expect_with_defaults_2 {Merge branch 'br1'} 
{dolt-repo-[0-9]+/main> } { send "\\log -n 3\r"; } + +expect_with_defaults_2 {empty cmt} {dolt-repo-[0-9]+/main> } { send "\\checkout br1\r"; } -expect_with_defaults {dolt-repo-[0-9]+/main> } { send "\\log -n 1;\r"; } +expect_with_defaults_2 {Switched to branch 'br1'} {dolt-repo-[0-9]+/br1> } { send "\\merge main\r"; } -expect_with_defaults_2 {sql-shell-slash-cmds commit} {dolt-repo-[0-9]+/main> } { send "\\status\r"; } +expect_with_defaults_2 {Fast-forward} {dolt-repo-[0-9]+/br1> } { send "\\reset HEAD~3;\r"; } -expect_with_defaults {dolt-repo-[0-9]+/main> } { send "\\reset HEAD~1;\r"; } +expect_with_defaults {dolt-repo-[0-9]+/br1\*> } { send "\\diff\r"; } -expect_with_defaults {dolt-repo-[0-9]+/main\*> } { send "\\diff\r"; } +expect_with_defaults_2 {diff --dolt a/test b/test} {dolt-repo-[0-9]+/br1\*> } { send "\\reset main\r"; } -expect_with_defaults_2 {diff --dolt a/tbl b/tbl} {dolt-repo-[0-9]+/main\*> } {send "quit\r";} +expect_with_defaults {dolt-repo-[0-9]+/br1> } { send "quit\r" } expect eof exit
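
Closing note on the JSON comparison semantics exercised earlier in this patch: types.CompareJSON orders values by MySQL's type precedence (boolean > array > object > string > number > null) before looking at contents, and same-category arrays compare element-wise. A small illustrative sketch, not part of the patch (the package and function names are invented):

package example

import (
	"fmt"

	"github.com/dolthub/go-mysql-server/sql/types"
)

// Demo shows the two ordering rules the large-document compare tests assert:
// type precedence first, then element-wise content for same-category values.
func Demo() {
	// Different categories: boolean outranks array, so contents are ignored.
	cmp, err := types.CompareJSON(
		types.JSONDocument{Val: true},
		types.JSONDocument{Val: []interface{}{1, 2, 3}},
	)
	fmt.Println(cmp, err) // 1 <nil>

	// Same category: arrays compare element-wise, which is why the tests
	// expect inserting true at $[0] to make the large array greater and
	// inserting a number there to make it less.
	cmp, err = types.CompareJSON(
		types.JSONDocument{Val: []interface{}{true, 1}},
		types.JSONDocument{Val: []interface{}{1, 1}},
	)
	fmt.Println(cmp, err) // 1 <nil>: true sorts above 1 at index 0
}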