diff --git a/go/cmd/dolt/commands/ci/init.go b/go/cmd/dolt/commands/ci/init.go index 147f1e2e8e..1308161d0c 100644 --- a/go/cmd/dolt/commands/ci/init.go +++ b/go/cmd/dolt/commands/ci/init.go @@ -16,6 +16,7 @@ package ci import ( "context" + "fmt" "github.com/dolthub/dolt/go/cmd/dolt/cli" "github.com/dolthub/dolt/go/cmd/dolt/commands" "github.com/dolthub/dolt/go/cmd/dolt/errhand" @@ -89,8 +90,19 @@ func (cmd InitCmd) Exec(ctx context.Context, commandStr string, args []string, d return commands.HandleVErrAndExitCode(errhand.VerboseErrorFromError(err), usage) } - tc := sqle.NewDoltCITablesCreator(sqlCtx, db) var verr errhand.VerboseError + tc := sqle.NewDoltCITablesCreator(sqlCtx, db) + + hasTables, err := tc.HasTables(sqlCtx) + if err != nil { + verr = errhand.VerboseErrorFromError(err) + } + + if hasTables { + verr = errhand.VerboseErrorFromError(fmt.Errorf("dolt ci has already been initialized")) + return commands.HandleVErrAndExitCode(verr, usage) + } + if err = tc.CreateTables(sqlCtx); err != nil { verr = errhand.VerboseErrorFromError(err) } diff --git a/go/libraries/doltcore/sqle/dolt_ci_tables_creator.go b/go/libraries/doltcore/sqle/dolt_ci_tables_creator.go index 1a22a1c41a..0bf132a7cc 100644 --- a/go/libraries/doltcore/sqle/dolt_ci_tables_creator.go +++ b/go/libraries/doltcore/sqle/dolt_ci_tables_creator.go @@ -22,7 +22,18 @@ import ( "github.com/dolthub/go-mysql-server/sql" ) +var ExpectedDoltCITables = []string{ + doltdb.WorkflowsTableName, + doltdb.WorkflowEventsTableName, +} + type DoltCITablesCreator interface { + // HasTables is used to check whether the database + // already contains dolt ci tables. If any expected tables are missing, + // an error is returned + HasTables(ctx *sql.Context) (bool, error) + + // CreateTables creates all tables required for dolt ci CreateTables(ctx *sql.Context) error } @@ -51,6 +62,29 @@ func (d *doltCITablesCreator) createTables(ctx *sql.Context) error { return d.workflowEventsTC.CreateTable(ctx) } +func (d *doltCITablesCreator) HasTables(ctx *sql.Context) (bool, error) { + dbName := ctx.GetCurrentDatabase() + dSess := dsess.DSessFromSess(ctx.Session) + ws, err := dSess.WorkingSet(ctx, dbName) + if err != nil { + return false, err + } + + root := ws.WorkingRoot() + + for _, tableName := range ExpectedDoltCITables { + found, err := root.HasTable(ctx, doltdb.TableName{Name: tableName}) + if err != nil { + return false, err + } + if !found { + return false, fmt.Errorf("required dolt ci table `%s` not found", doltdb.WorkflowsTableName) + } + } + + return true, nil +} + func (d *doltCITablesCreator) CreateTables(ctx *sql.Context) error { if err := dsess.CheckAccessForDb(d.ctx, d.db, branch_control.Permissions_Write); err != nil { return err