Skip to content

Commit

Permalink
Merge pull request #6374 from dolthub/steph/db-branch
Browse files Browse the repository at this point in the history
adds support for `db/branch` syntax with `--use-db` global arg
  • Loading branch information
stephkyou authored Jul 25, 2023
2 parents 8504f02 + 8b054bb commit d2d9d6d
Show file tree
Hide file tree
Showing 15 changed files with 176 additions and 87 deletions.
10 changes: 7 additions & 3 deletions go/cmd/dolt/commands/sqlserver/sqlclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ import (
"github.com/dolthub/dolt/go/cmd/dolt/commands"
"github.com/dolthub/dolt/go/cmd/dolt/commands/engine"
"github.com/dolthub/dolt/go/libraries/doltcore/env"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
"github.com/dolthub/dolt/go/libraries/utils/argparser"
"github.com/dolthub/dolt/go/libraries/utils/filesys"
"github.com/dolthub/dolt/go/libraries/utils/iohelp"
Expand Down Expand Up @@ -475,17 +476,20 @@ func (c ConnectionQueryist) Query(ctx *sql.Context, query string) (sql.Schema, s

// BuildConnectionStringQueryist returns a Queryist that connects to the server specified by the given server config. Presence in this
// module isn't ideal, but it's the only way to get the server config into the queryist.
func BuildConnectionStringQueryist(ctx context.Context, cwdFS filesys.Filesys, creds *cli.UserPassword, apr *argparser.ArgParseResults, host string, port int, useTLS bool, database string) (cli.LateBindQueryist, error) {
func BuildConnectionStringQueryist(ctx context.Context, cwdFS filesys.Filesys, creds *cli.UserPassword, apr *argparser.ArgParseResults, host string, port int, useTLS bool, dbRev string) (cli.LateBindQueryist, error) {
clientConfig, err := GetClientConfig(cwdFS, creds, apr)
if err != nil {
return nil, err
}

parsedMySQLConfig, err := mysqlDriver.ParseDSN(ConnectionString(clientConfig, database))
// ParseDSN currently doesn't support `/` in the db name
dbName, _ := dsess.SplitRevisionDbName(dbRev)
parsedMySQLConfig, err := mysqlDriver.ParseDSN(ConnectionString(clientConfig, dbName))
if err != nil {
return nil, err
}

parsedMySQLConfig.DBName = dbRev
parsedMySQLConfig.Addr = fmt.Sprintf("%s:%d", host, port)

if useTLS {
Expand All @@ -503,7 +507,7 @@ func BuildConnectionStringQueryist(ctx context.Context, cwdFS filesys.Filesys, c

var lateBind cli.LateBindQueryist = func(ctx context.Context) (cli.Queryist, *sql.Context, func(), error) {
sqlCtx := sql.NewContext(ctx)
sqlCtx.SetCurrentDatabase(database)
sqlCtx.SetCurrentDatabase(dbRev)
return queryist, sqlCtx, func() { conn.Conn(ctx) }, nil
}

Expand Down
5 changes: 3 additions & 2 deletions go/cmd/dolt/dolt.go
Original file line number Diff line number Diff line change
Expand Up @@ -623,9 +623,10 @@ func buildLateBinder(ctx context.Context, cwdFS filesys.Filesys, mrEnv *env.Mult
}

if hasUseDb {
targetEnv = mrEnv.GetEnv(useDb)
dbName, _ := dsess.SplitRevisionDbName(useDb)
targetEnv = mrEnv.GetEnv(dbName)
if targetEnv == nil {
return nil, fmt.Errorf("The provided --use-db %s does not exist or is not a directory.", useDb)
return nil, fmt.Errorf("The provided --use-db %s does not exist.", dbName)
}
} else {
useDb = mrEnv.GetFirstDatabase()
Expand Down
2 changes: 1 addition & 1 deletion go/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ require (
github.com/dustin/go-humanize v1.0.0
github.com/fatih/color v1.13.0
github.com/flynn-archive/go-shlex v0.0.0-20150515145356-3f9db97f8568
github.com/go-sql-driver/mysql v1.6.0
github.com/go-sql-driver/mysql v1.7.2-0.20230713085235-0b18dac46f7f
github.com/gocraft/dbr/v2 v2.7.2
github.com/golang/protobuf v1.5.3 // indirect
github.com/golang/snappy v0.0.4
Expand Down
3 changes: 2 additions & 1 deletion go/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -255,8 +255,9 @@ github.com/go-pdf/fpdf v0.6.0/go.mod h1:HzcnA+A23uwogo0tp9yU+l3V+KXhiESpt1PMayhO
github.com/go-sql-driver/mysql v1.4.0/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w=
github.com/go-sql-driver/mysql v1.4.1/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w=
github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE=
github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
github.com/go-sql-driver/mysql v1.7.2-0.20230713085235-0b18dac46f7f h1:4+t8Qb99xUG/Ea00cQAiQl+gsjpK8ZYtAO8E76gRzQI=
github.com/go-sql-driver/mysql v1.7.2-0.20230713085235-0b18dac46f7f/go.mod h1:6gYm/zDt3ahdnMVTPeT/LfoBFsws1qZm5yI6FmVjB14=
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
github.com/gobwas/glob v0.2.3/go.mod h1:d3Ez4x06l9bZtSvzIay5+Yzi0fmZzPgnTbPcKjJAkT8=
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,13 @@ func TestBinlogReplicationForAllTypes(t *testing.T) {
waitForReplicaToCatchUp(t)
rows, err := replicaDatabase.Queryx("select * from db01.alltypes order by pk asc;")
require.NoError(t, err)
row := convertByteArraysToStrings(readNextRow(t, rows))
row := convertMapScanResultToStrings(readNextRow(t, rows))
require.Equal(t, "1", row["pk"])
assertValues(t, 0, row)
row = convertByteArraysToStrings(readNextRow(t, rows))
row = convertMapScanResultToStrings(readNextRow(t, rows))
require.Equal(t, "2", row["pk"])
assertValues(t, 1, row)
row = convertByteArraysToStrings(readNextRow(t, rows))
row = convertMapScanResultToStrings(readNextRow(t, rows))
require.Equal(t, "3", row["pk"])
assertNullValues(t, row)
require.False(t, rows.Next())
Expand All @@ -70,13 +70,13 @@ func TestBinlogReplicationForAllTypes(t *testing.T) {
replicaDatabase.MustExec("use db01;")
rows, err = replicaDatabase.Queryx("select * from db01.alltypes order by pk asc;")
require.NoError(t, err)
row = convertByteArraysToStrings(readNextRow(t, rows))
row = convertMapScanResultToStrings(readNextRow(t, rows))
require.Equal(t, "1", row["pk"])
assertNullValues(t, row)
row = convertByteArraysToStrings(readNextRow(t, rows))
row = convertMapScanResultToStrings(readNextRow(t, rows))
require.Equal(t, "2", row["pk"])
assertValues(t, 0, row)
row = convertByteArraysToStrings(readNextRow(t, rows))
row = convertMapScanResultToStrings(readNextRow(t, rows))
require.Equal(t, "3", row["pk"])
assertValues(t, 1, row)
require.False(t, rows.Next())
Expand Down Expand Up @@ -516,7 +516,7 @@ func assertValues(t *testing.T, assertionIndex int, row map[string]interface{})

actualValue := ""
if row[typeDesc.ColumnName()] != nil {
actualValue = row[typeDesc.ColumnName()].(string)
actualValue = fmt.Sprintf("%v", row[typeDesc.ColumnName()])
}
if typeDesc.TypeDefinition == "json" {
// LD_1 and DOLT storage formats return JSON strings slightly differently; DOLT removes spaces
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func TestBinlogReplicationFilters_ignoreTablesOnly(t *testing.T) {
// Verify that all changes from t1 were applied on the replica
rows, err := replicaDatabase.Queryx("SELECT COUNT(pk) as count, MIN(pk) as min, MAX(pk) as max from db01.t1;")
require.NoError(t, err)
row := convertByteArraysToStrings(readNextRow(t, rows))
row := convertMapScanResultToStrings(readNextRow(t, rows))
require.Equal(t, "10", row["count"])
require.Equal(t, "0", row["min"])
require.Equal(t, "9", row["max"])
Expand All @@ -65,7 +65,7 @@ func TestBinlogReplicationFilters_ignoreTablesOnly(t *testing.T) {
// Verify that no changes from t2 were applied on the replica
rows, err = replicaDatabase.Queryx("SELECT COUNT(pk) as count, MIN(pk) as min, MAX(pk) as max from db01.t2;")
require.NoError(t, err)
row = convertByteArraysToStrings(readNextRow(t, rows))
row = convertMapScanResultToStrings(readNextRow(t, rows))
require.Equal(t, "0", row["count"])
require.Equal(t, nil, row["min"])
require.Equal(t, nil, row["max"])
Expand Down Expand Up @@ -107,7 +107,7 @@ func TestBinlogReplicationFilters_doTablesOnly(t *testing.T) {
// Verify that all changes from t1 were applied on the replica
rows, err := replicaDatabase.Queryx("SELECT COUNT(pk) as count, MIN(pk) as min, MAX(pk) as max from db01.t1;")
require.NoError(t, err)
row := convertByteArraysToStrings(readNextRow(t, rows))
row := convertMapScanResultToStrings(readNextRow(t, rows))
require.Equal(t, "10", row["count"])
require.Equal(t, "0", row["min"])
require.Equal(t, "9", row["max"])
Expand All @@ -116,7 +116,7 @@ func TestBinlogReplicationFilters_doTablesOnly(t *testing.T) {
// Verify that no changes from t2 were applied on the replica
rows, err = replicaDatabase.Queryx("SELECT COUNT(pk) as count, MIN(pk) as min, MAX(pk) as max from db01.t2;")
require.NoError(t, err)
row = convertByteArraysToStrings(readNextRow(t, rows))
row = convertMapScanResultToStrings(readNextRow(t, rows))
require.Equal(t, "0", row["count"])
require.Equal(t, nil, row["min"])
require.Equal(t, nil, row["max"])
Expand Down Expand Up @@ -159,7 +159,7 @@ func TestBinlogReplicationFilters_doTablesAndIgnoreTables(t *testing.T) {
// Verify that all changes from t1 were applied on the replica
rows, err := replicaDatabase.Queryx("SELECT COUNT(pk) as count, MIN(pk) as min, MAX(pk) as max from db01.t1;")
require.NoError(t, err)
row := convertByteArraysToStrings(readNextRow(t, rows))
row := convertMapScanResultToStrings(readNextRow(t, rows))
require.Equal(t, "10", row["count"])
require.Equal(t, "0", row["min"])
require.Equal(t, "9", row["max"])
Expand All @@ -168,7 +168,7 @@ func TestBinlogReplicationFilters_doTablesAndIgnoreTables(t *testing.T) {
// Verify that no changes from t2 were applied on the replica
rows, err = replicaDatabase.Queryx("SELECT COUNT(pk) as count, MIN(pk) as min, MAX(pk) as max from db01.t2;")
require.NoError(t, err)
row = convertByteArraysToStrings(readNextRow(t, rows))
row = convertMapScanResultToStrings(readNextRow(t, rows))
require.Equal(t, "0", row["count"])
require.Equal(t, nil, row["min"])
require.Equal(t, nil, row["max"])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,13 @@ func TestBinlogReplicationMultiDb(t *testing.T) {
waitForReplicaToCatchUp(t)
rows, err := replicaDatabase.Queryx("select * from db01.t01 order by pk asc;")
require.NoError(t, err)
row := convertByteArraysToStrings(readNextRow(t, rows))
row := convertMapScanResultToStrings(readNextRow(t, rows))
require.Equal(t, "1", row["pk"])
row = convertByteArraysToStrings(readNextRow(t, rows))
row = convertMapScanResultToStrings(readNextRow(t, rows))
require.Equal(t, "3", row["pk"])
row = convertByteArraysToStrings(readNextRow(t, rows))
row = convertMapScanResultToStrings(readNextRow(t, rows))
require.Equal(t, "5", row["pk"])
row = convertByteArraysToStrings(readNextRow(t, rows))
row = convertMapScanResultToStrings(readNextRow(t, rows))
require.Equal(t, "7", row["pk"])
require.False(t, rows.Next())
require.NoError(t, rows.Close())
Expand All @@ -64,19 +64,19 @@ func TestBinlogReplicationMultiDb(t *testing.T) {
replicaDatabase.MustExec("use db01;")
rows, err = replicaDatabase.Queryx("select * from db01.dolt_diff;")
require.NoError(t, err)
row = convertByteArraysToStrings(readNextRow(t, rows))
row = convertMapScanResultToStrings(readNextRow(t, rows))
require.Equal(t, "t01", row["table_name"])
require.EqualValues(t, "1", row["data_change"])
require.EqualValues(t, "0", row["schema_change"])
row = convertByteArraysToStrings(readNextRow(t, rows))
row = convertMapScanResultToStrings(readNextRow(t, rows))
require.Equal(t, "t01", row["table_name"])
require.EqualValues(t, "1", row["data_change"])
require.EqualValues(t, "0", row["schema_change"])
row = convertByteArraysToStrings(readNextRow(t, rows))
row = convertMapScanResultToStrings(readNextRow(t, rows))
require.Equal(t, "t01", row["table_name"])
require.EqualValues(t, "1", row["data_change"])
require.EqualValues(t, "0", row["schema_change"])
row = convertByteArraysToStrings(readNextRow(t, rows))
row = convertMapScanResultToStrings(readNextRow(t, rows))
require.Equal(t, "t01", row["table_name"])
require.EqualValues(t, "0", row["data_change"])
require.EqualValues(t, "1", row["schema_change"])
Expand All @@ -88,33 +88,33 @@ func TestBinlogReplicationMultiDb(t *testing.T) {
replicaDatabase.MustExec("use db02;")
rows, err = replicaDatabase.Queryx("select * from db02.t02 order by pk asc;")
require.NoError(t, err)
row = convertByteArraysToStrings(readNextRow(t, rows))
row = convertMapScanResultToStrings(readNextRow(t, rows))
require.Equal(t, "2", row["pk"])
row = convertByteArraysToStrings(readNextRow(t, rows))
row = convertMapScanResultToStrings(readNextRow(t, rows))
require.Equal(t, "4", row["pk"])
row = convertByteArraysToStrings(readNextRow(t, rows))
row = convertMapScanResultToStrings(readNextRow(t, rows))
require.Equal(t, "6", row["pk"])
row = convertByteArraysToStrings(readNextRow(t, rows))
row = convertMapScanResultToStrings(readNextRow(t, rows))
require.Equal(t, "8", row["pk"])
require.False(t, rows.Next())
require.NoError(t, rows.Close())

// Verify db02.dolt_diff
rows, err = replicaDatabase.Queryx("select * from db02.dolt_diff;")
require.NoError(t, err)
row = convertByteArraysToStrings(readNextRow(t, rows))
row = convertMapScanResultToStrings(readNextRow(t, rows))
require.Equal(t, "t02", row["table_name"])
require.Equal(t, "1", row["data_change"])
require.Equal(t, "0", row["schema_change"])
row = convertByteArraysToStrings(readNextRow(t, rows))
row = convertMapScanResultToStrings(readNextRow(t, rows))
require.Equal(t, "t02", row["table_name"])
require.Equal(t, "1", row["data_change"])
require.Equal(t, "0", row["schema_change"])
row = convertByteArraysToStrings(readNextRow(t, rows))
row = convertMapScanResultToStrings(readNextRow(t, rows))
require.Equal(t, "t02", row["table_name"])
require.Equal(t, "1", row["data_change"])
require.Equal(t, "0", row["schema_change"])
row = convertByteArraysToStrings(readNextRow(t, rows))
row = convertMapScanResultToStrings(readNextRow(t, rows))
require.Equal(t, "t02", row["table_name"])
require.Equal(t, "0", row["data_change"])
require.Equal(t, "1", row["schema_change"])
Expand Down Expand Up @@ -148,13 +148,13 @@ func TestBinlogReplicationMultiDbTransactions(t *testing.T) {
waitForReplicaToCatchUp(t)
rows, err := replicaDatabase.Queryx("select * from db01.t01 order by pk asc;")
require.NoError(t, err)
row := convertByteArraysToStrings(readNextRow(t, rows))
row := convertMapScanResultToStrings(readNextRow(t, rows))
require.Equal(t, "1", row["pk"])
row = convertByteArraysToStrings(readNextRow(t, rows))
row = convertMapScanResultToStrings(readNextRow(t, rows))
require.Equal(t, "3", row["pk"])
row = convertByteArraysToStrings(readNextRow(t, rows))
row = convertMapScanResultToStrings(readNextRow(t, rows))
require.Equal(t, "5", row["pk"])
row = convertByteArraysToStrings(readNextRow(t, rows))
row = convertMapScanResultToStrings(readNextRow(t, rows))
require.Equal(t, "7", row["pk"])
require.False(t, rows.Next())
require.NoError(t, rows.Close())
Expand All @@ -163,11 +163,11 @@ func TestBinlogReplicationMultiDbTransactions(t *testing.T) {
replicaDatabase.MustExec("use db01;")
rows, err = replicaDatabase.Queryx("select * from db01.dolt_diff;")
require.NoError(t, err)
row = convertByteArraysToStrings(readNextRow(t, rows))
row = convertMapScanResultToStrings(readNextRow(t, rows))
require.Equal(t, "t01", row["table_name"])
require.EqualValues(t, "1", row["data_change"])
require.EqualValues(t, "0", row["schema_change"])
row = convertByteArraysToStrings(readNextRow(t, rows))
row = convertMapScanResultToStrings(readNextRow(t, rows))
require.Equal(t, "t01", row["table_name"])
require.EqualValues(t, "0", row["data_change"])
require.EqualValues(t, "1", row["schema_change"])
Expand All @@ -179,25 +179,25 @@ func TestBinlogReplicationMultiDbTransactions(t *testing.T) {
replicaDatabase.MustExec("use db02;")
rows, err = replicaDatabase.Queryx("select * from db02.t02 order by pk asc;")
require.NoError(t, err)
row = convertByteArraysToStrings(readNextRow(t, rows))
row = convertMapScanResultToStrings(readNextRow(t, rows))
require.Equal(t, "2", row["pk"])
row = convertByteArraysToStrings(readNextRow(t, rows))
row = convertMapScanResultToStrings(readNextRow(t, rows))
require.Equal(t, "4", row["pk"])
row = convertByteArraysToStrings(readNextRow(t, rows))
row = convertMapScanResultToStrings(readNextRow(t, rows))
require.Equal(t, "6", row["pk"])
row = convertByteArraysToStrings(readNextRow(t, rows))
row = convertMapScanResultToStrings(readNextRow(t, rows))
require.Equal(t, "8", row["pk"])
require.False(t, rows.Next())
require.NoError(t, rows.Close())

// Verify db02.dolt_diff
rows, err = replicaDatabase.Queryx("select * from db02.dolt_diff;")
require.NoError(t, err)
row = convertByteArraysToStrings(readNextRow(t, rows))
row = convertMapScanResultToStrings(readNextRow(t, rows))
require.Equal(t, "t02", row["table_name"])
require.Equal(t, "1", row["data_change"])
require.Equal(t, "0", row["schema_change"])
row = convertByteArraysToStrings(readNextRow(t, rows))
row = convertMapScanResultToStrings(readNextRow(t, rows))
require.Equal(t, "t02", row["table_name"])
require.Equal(t, "0", row["data_change"])
require.Equal(t, "1", row["schema_change"])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func TestBinlogReplicationAutoReconnect(t *testing.T) {
rows, err := replicaDatabase.Queryx("select min(pk) as min, max(pk) as max, count(pk) as count from db01.reconnect_test;")
require.NoError(t, err)

row := convertByteArraysToStrings(readNextRow(t, rows))
row := convertMapScanResultToStrings(readNextRow(t, rows))
require.Equal(t, "0", row["min"])
require.Equal(t, "999", row["max"])
require.Equal(t, "1000", row["count"])
Expand Down Expand Up @@ -145,7 +145,7 @@ func showReplicaStatus(t *testing.T) map[string]interface{} {
rows, err := replicaDatabase.Queryx("show replica status;")
require.NoError(t, err)
defer rows.Close()
return convertByteArraysToStrings(readNextRow(t, rows))
return convertMapScanResultToStrings(readNextRow(t, rows))
}

func configureToxiProxy(t *testing.T) {
Expand Down Expand Up @@ -184,15 +184,19 @@ func turnOnLimitDataToxic(t *testing.T) {
t.Logf("Toxiproxy proxy with limit_data toxic (1KB) started on port %d", proxyPort)
}

// convertByteArraysToStrings converts each []byte value in the specified map |m| into a string.
// convertMapScanResultToStrings converts each value in the specified map |m| into a string.
// This is necessary because MapScan doesn't honor (or know about) the correct underlying SQL types – it
// gets all results back as strings, typed as []byte.
// gets results back as strings, typed as []byte. Results also get returned as int64, which are converted to strings
// for ease of testing.
// More info at the end of this issue: https://github.com/jmoiron/sqlx/issues/225
func convertByteArraysToStrings(m map[string]interface{}) map[string]interface{} {
func convertMapScanResultToStrings(m map[string]interface{}) map[string]interface{} {
for key, value := range m {
if bytes, ok := value.([]byte); ok {
if bytes, ok := value.([]uint8); ok {
m[key] = string(bytes)
}
if i, ok := value.(int64); ok {
m[key] = strconv.FormatInt(i, 10)
}
}

return m
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ func TestBinlogReplicationServerRestart(t *testing.T) {
require.NoError(t, err)
replicaRows, err := replicaDatabase.Queryx(countMaxQuery)
require.NoError(t, err)
primaryRow := convertByteArraysToStrings(readNextRow(t, primaryRows))
replicaRow := convertByteArraysToStrings(readNextRow(t, replicaRows))
primaryRow := convertMapScanResultToStrings(readNextRow(t, primaryRows))
replicaRow := convertMapScanResultToStrings(readNextRow(t, replicaRows))
require.Equal(t, primaryRow["count"], replicaRow["count"])
require.Equal(t, primaryRow["max"], replicaRow["max"])
require.NoError(t, replicaRows.Close())
Expand Down
Loading

0 comments on commit d2d9d6d

Please sign in to comment.