diff --git a/go/libraries/doltcore/env/actions/checkout.go b/go/libraries/doltcore/env/actions/checkout.go index 23bab107bd..9e83035659 100644 --- a/go/libraries/doltcore/env/actions/checkout.go +++ b/go/libraries/doltcore/env/actions/checkout.go @@ -24,6 +24,7 @@ import ( "github.com/dolthub/dolt/go/libraries/doltcore/env" "github.com/dolthub/dolt/go/libraries/doltcore/ref" "github.com/dolthub/dolt/go/libraries/doltcore/schema" + "github.com/dolthub/dolt/go/libraries/doltcore/sqle/resolve" "github.com/dolthub/dolt/go/libraries/utils/set" "github.com/dolthub/dolt/go/store/datas" "github.com/dolthub/dolt/go/store/hash" @@ -88,6 +89,36 @@ func MoveTablesFromHeadToWorking(ctx context.Context, roots doltdb.Roots, tbls [ return roots, nil } +// FindTableInRoots resolves a table by looking in all three roots (working, +// staged, head) in that order. +func FindTableInRoots(ctx *sql.Context, roots doltdb.Roots, name string) (doltdb.TableName, *doltdb.Table, bool, error) { + tbl, root, tblExists, err := resolve.Table(ctx, roots.Working, name) + if err != nil { + return doltdb.TableName{}, nil, false, err + } + if tblExists { + return tbl, root, true, nil + } + + tbl, root, tblExists, err = resolve.Table(ctx, roots.Staged, name) + if err != nil { + return doltdb.TableName{}, nil, false, err + } + if tblExists { + return tbl, root, true, nil + } + + tbl, root, tblExists, err = resolve.Table(ctx, roots.Head, name) + if err != nil { + return doltdb.TableName{}, nil, false, err + } + if tblExists { + return tbl, root, true, nil + } + + return doltdb.TableName{}, nil, false, nil +} + // RootsForBranch returns the roots needed for a branch checkout. |roots.Head| should be the pre-checkout head. The // returned roots struct has |Head| set to |branchRoot|. func RootsForBranch(ctx context.Context, roots doltdb.Roots, branchRoot doltdb.RootValue, force bool) (doltdb.Roots, error) { diff --git a/go/libraries/doltcore/env/actions/reset.go b/go/libraries/doltcore/env/actions/reset.go index bec0a40027..18fbf3c96f 100644 --- a/go/libraries/doltcore/env/actions/reset.go +++ b/go/libraries/doltcore/env/actions/reset.go @@ -25,6 +25,7 @@ import ( "github.com/dolthub/dolt/go/libraries/doltcore/env" "github.com/dolthub/dolt/go/libraries/doltcore/ref" "github.com/dolthub/dolt/go/libraries/doltcore/schema" + "github.com/dolthub/dolt/go/libraries/doltcore/sqle/resolve" "github.com/dolthub/dolt/go/store/datas" ) @@ -271,7 +272,7 @@ func getUnionedTables(ctx context.Context, tables []doltdb.TableName, stagedRoot // CleanUntracked deletes untracked tables from the working root. // Evaluates untracked tables as: all working tables - all staged tables. -func CleanUntracked(ctx context.Context, roots doltdb.Roots, tables []string, dryrun bool, force bool) (doltdb.Roots, error) { +func CleanUntracked(ctx *sql.Context, roots doltdb.Roots, tables []string, dryrun bool, force bool) (doltdb.Roots, error) { untrackedTables := make(map[doltdb.TableName]struct{}) var err error @@ -284,21 +285,24 @@ func CleanUntracked(ctx context.Context, roots doltdb.Roots, tables []string, dr for i := range tables { name := tables[i] - _, _, err = roots.Working.GetTable(ctx, doltdb.TableName{Name: name}) + resolvedName, _, tblExists, err := resolve.Table(ctx, roots.Working, name) if err != nil { return doltdb.Roots{}, err } - untrackedTables[doltdb.TableName{Name: name}] = struct{}{} + if !tblExists { + return doltdb.Roots{}, fmt.Errorf("%w: '%s'", doltdb.ErrTableNotFound, name) + } + untrackedTables[resolvedName] = struct{}{} } // untracked tables = working tables - staged tables - headTblNames, err := roots.Staged.GetTableNames(ctx, doltdb.DefaultSchemaName) + headTblNames := GetAllTableNames(ctx, roots.Staged) if err != nil { return doltdb.Roots{}, err } for _, name := range headTblNames { - delete(untrackedTables, doltdb.TableName{Name: name}) + delete(untrackedTables, name) } newRoot := roots.Working diff --git a/go/libraries/doltcore/sqle/dprocedures/dolt_checkout.go b/go/libraries/doltcore/sqle/dprocedures/dolt_checkout.go index 6a02fa4e92..57096997db 100644 --- a/go/libraries/doltcore/sqle/dprocedures/dolt_checkout.go +++ b/go/libraries/doltcore/sqle/dprocedures/dolt_checkout.go @@ -476,9 +476,20 @@ func doGlobalCheckout(ctx *sql.Context, branchName string, isForce bool, isNewBr } func checkoutTables(ctx *sql.Context, roots doltdb.Roots, name string, tables []string) error { - // TODO: schema name - roots, err := actions.MoveTablesFromHeadToWorking(ctx, roots, doltdb.ToTableNames(tables, doltdb.DefaultSchemaName)) + tableNames := make([]doltdb.TableName, len(tables)) + for i, table := range tables { + tbl, _, exists, err := actions.FindTableInRoots(ctx, roots, table) + if err != nil { + return err + } + if !exists { + return fmt.Errorf("error: given tables do not exist") + } + tableNames[i] = tbl + } + + roots, err := actions.MoveTablesFromHeadToWorking(ctx, roots, tableNames) if err != nil { if doltdb.IsRootValUnreachable(err) { rt := doltdb.GetUnreachableRootType(err)