From 784e916dee1470b9e9faffd375eb747898eeaa21 Mon Sep 17 00:00:00 2001 From: "vitess-bot[bot]" <108069721+vitess-bot[bot]@users.noreply.github.com> Date: Thu, 18 Jan 2024 11:24:10 +0200 Subject: [PATCH] Cherry-pick 78006991dd22f7bdce38d9271fa2e78c9ce09dd3 with conflicts --- .../backup/vtbackup/backup_only_test.go | 8 +- go/test/endtoend/clustertest/vtctld_test.go | 48 ++++++++++- go/test/endtoend/mysqlserver/main_test.go | 4 +- .../reparent/emergencyreparent/ers_test.go | 5 +- .../reparent/plannedreparent/reparent_test.go | 29 ++++--- go/test/endtoend/reparent/utils/utils.go | 9 +++ go/test/endtoend/vtorc/general/vtorc_test.go | 21 +++-- go/vt/schemadiff/schema.go | 9 ++- go/vt/sqlparser/ast_test.go | 75 +++++++++++++++++ go/vt/sqlparser/parser.go | 17 ++++ go/vt/sqlparser/utils.go | 20 +++++ go/vt/sqlparser/utils_test.go | 69 ++++++++++++++++ go/vt/vttablet/tabletmanager/rpc_query.go | 73 ++++++++++++++++- .../vttablet/tabletmanager/rpc_query_test.go | 80 +++++++++++++++++++ 14 files changed, 434 insertions(+), 33 deletions(-) diff --git a/go/test/endtoend/backup/vtbackup/backup_only_test.go b/go/test/endtoend/backup/vtbackup/backup_only_test.go index 5e80d5d3cc3..87eca996418 100644 --- a/go/test/endtoend/backup/vtbackup/backup_only_test.go +++ b/go/test/endtoend/backup/vtbackup/backup_only_test.go @@ -326,12 +326,12 @@ func tearDown(t *testing.T, initMysql bool) { } caughtUp := waitForReplicationToCatchup([]cluster.Vttablet{*replica1, *replica2}) require.True(t, caughtUp, "Timed out waiting for all replicas to catch up") - promoteCommands := "STOP SLAVE; RESET SLAVE ALL; RESET MASTER;" - disableSemiSyncCommands := "SET GLOBAL rpl_semi_sync_master_enabled = false; SET GLOBAL rpl_semi_sync_slave_enabled = false" + promoteCommands := []string{"STOP SLAVE", "RESET SLAVE ALL", "RESET MASTER"} + disableSemiSyncCommands := []string{"SET GLOBAL rpl_semi_sync_master_enabled = false", " SET GLOBAL rpl_semi_sync_slave_enabled = false"} for _, tablet := range []cluster.Vttablet{*primary, *replica1, *replica2} { - _, err := tablet.VttabletProcess.QueryTablet(promoteCommands, keyspaceName, true) + err := tablet.VttabletProcess.QueryTabletMultiple(promoteCommands, keyspaceName, true) require.Nil(t, err) - _, err = tablet.VttabletProcess.QueryTablet(disableSemiSyncCommands, keyspaceName, true) + err = tablet.VttabletProcess.QueryTabletMultiple(disableSemiSyncCommands, keyspaceName, true) require.Nil(t, err) } diff --git a/go/test/endtoend/clustertest/vtctld_test.go b/go/test/endtoend/clustertest/vtctld_test.go index 45643d869b1..c1b341ccd73 100644 --- a/go/test/endtoend/clustertest/vtctld_test.go +++ b/go/test/endtoend/clustertest/vtctld_test.go @@ -128,9 +128,51 @@ func testTabletStatus(t *testing.T) { } func testExecuteAsDba(t *testing.T) { - result, err := clusterInstance.VtctlclientProcess.ExecuteCommandWithOutput("ExecuteFetchAsDba", clusterInstance.Keyspaces[0].Shards[0].Vttablets[0].Alias, `SELECT 1 AS a`) - require.NoError(t, err) - assert.Equal(t, result, oneTableOutput) + tcases := []struct { + query string + result string + expectErr bool + }{ + { + query: "", + expectErr: true, + }, + { + query: "SELECT 1 AS a", + result: oneTableOutput, + }, + { + query: "SELECT 1 AS a; SELECT 1 AS a", + expectErr: true, + }, + { + query: "create table t(id int)", + result: "", + }, + { + query: "create table if not exists t(id int)", + result: "", + }, + { + query: "create table if not exists t(id int); create table if not exists t(id int);", + result: "", + }, + { + query: "create table if not exists t(id int); create table if not exists t(id int); SELECT 1 AS a", + expectErr: true, + }, + } + for _, tcase := range tcases { + t.Run(tcase.query, func(t *testing.T) { + result, err := clusterInstance.VtctlclientProcess.ExecuteCommandWithOutput("ExecuteFetchAsDba", clusterInstance.Keyspaces[0].Shards[0].Vttablets[0].Alias, tcase.query) + if tcase.expectErr { + assert.Error(t, err) + } else { + require.NoError(t, err) + assert.Equal(t, tcase.result, result) + } + }) + } } func testExecuteAsApp(t *testing.T) { diff --git a/go/test/endtoend/mysqlserver/main_test.go b/go/test/endtoend/mysqlserver/main_test.go index 42b4e6ea235..18b169e33d7 100644 --- a/go/test/endtoend/mysqlserver/main_test.go +++ b/go/test/endtoend/mysqlserver/main_test.go @@ -51,7 +51,7 @@ var ( PARTITION BY HASH( TO_DAYS(created) ) PARTITIONS 10; ` - createProcSQL = `use vt_test_keyspace; + createProcSQL = ` CREATE PROCEDURE testing() BEGIN delete from vt_insert_test; @@ -144,7 +144,7 @@ func TestMain(m *testing.M) { } primaryTabletProcess := clusterInstance.Keyspaces[0].Shards[0].PrimaryTablet().VttabletProcess - if _, err := primaryTabletProcess.QueryTablet(createProcSQL, keyspaceName, false); err != nil { + if _, err := primaryTabletProcess.QueryTablet(createProcSQL, keyspaceName, true); err != nil { return 1, err } diff --git a/go/test/endtoend/reparent/emergencyreparent/ers_test.go b/go/test/endtoend/reparent/emergencyreparent/ers_test.go index 8f6638ecb7e..fbd4770e15e 100644 --- a/go/test/endtoend/reparent/emergencyreparent/ers_test.go +++ b/go/test/endtoend/reparent/emergencyreparent/ers_test.go @@ -349,8 +349,11 @@ func TestNoReplicationStatusAndIOThreadStopped(t *testing.T) { tablets := clusterInstance.Keyspaces[0].Shards[0].Vttablets utils.ConfirmReplication(t, tablets[0], []*cluster.Vttablet{tablets[1], tablets[2], tablets[3]}) - err := clusterInstance.VtctlclientProcess.ExecuteCommand("ExecuteFetchAsDba", tablets[1].Alias, `STOP SLAVE; RESET SLAVE ALL`) + err := clusterInstance.VtctlclientProcess.ExecuteCommand("ExecuteFetchAsDba", tablets[1].Alias, `STOP SLAVE`) require.NoError(t, err) + err = clusterInstance.VtctlclientProcess.ExecuteCommand("ExecuteFetchAsDba", tablets[1].Alias, `RESET SLAVE ALL`) + require.NoError(t, err) + // err = clusterInstance.VtctlclientProcess.ExecuteCommand("ExecuteFetchAsDba", tablets[3].Alias, `STOP SLAVE IO_THREAD;`) require.NoError(t, err) // Run an additional command in the current primary which will only be acked by tablets[2] and be in its relay log. diff --git a/go/test/endtoend/reparent/plannedreparent/reparent_test.go b/go/test/endtoend/reparent/plannedreparent/reparent_test.go index 45cbeb565c7..014570d8439 100644 --- a/go/test/endtoend/reparent/plannedreparent/reparent_test.go +++ b/go/test/endtoend/reparent/plannedreparent/reparent_test.go @@ -98,7 +98,7 @@ func TestPRSWithDrainedLaggingTablet(t *testing.T) { utils.ConfirmReplication(t, tablets[0], []*cluster.Vttablet{tablets[1], tablets[2], tablets[3]}) // make tablets[1 lag from the other tablets by setting the delay to a large number - utils.RunSQL(context.Background(), t, `stop slave;CHANGE MASTER TO MASTER_DELAY = 1999;start slave;`, tablets[1]) + utils.RunSQLs(context.Background(), t, []string{`stop slave`, `CHANGE MASTER TO MASTER_DELAY = 1999`, `start slave;`}, tablets[1]) // insert another row in tablets[1 utils.ConfirmReplication(t, tablets[0], []*cluster.Vttablet{tablets[2], tablets[3]}) @@ -226,26 +226,33 @@ func reparentFromOutside(t *testing.T, clusterInstance *cluster.LocalProcessClus } // commands to convert a replica to be writable - promoteReplicaCommands := "STOP SLAVE; RESET SLAVE ALL; SET GLOBAL read_only = OFF;" - utils.RunSQL(ctx, t, promoteReplicaCommands, tablets[1]) + promoteReplicaCommands := []string{"STOP SLAVE", "RESET SLAVE ALL", "SET GLOBAL read_only = OFF"} + utils.RunSQLs(ctx, t, promoteReplicaCommands, tablets[1]) // Get primary position _, gtID := cluster.GetPrimaryPosition(t, *tablets[1], utils.Hostname) // tablets[0] will now be a replica of tablets[1 - changeReplicationSourceCommands := fmt.Sprintf("RESET MASTER; RESET SLAVE; SET GLOBAL gtid_purged = '%s';"+ - "CHANGE MASTER TO MASTER_HOST='%s', MASTER_PORT=%d, MASTER_USER='vt_repl', MASTER_AUTO_POSITION = 1;"+ - "START SLAVE;", gtID, utils.Hostname, tablets[1].MySQLPort) - utils.RunSQL(ctx, t, changeReplicationSourceCommands, tablets[0]) + changeReplicationSourceCommands := []string{ + "RESET MASTER", + "RESET SLAVE", + fmt.Sprintf("SET GLOBAL gtid_purged = '%s'", gtID), + fmt.Sprintf("CHANGE MASTER TO MASTER_HOST='%s', MASTER_PORT=%d, MASTER_USER='vt_repl', MASTER_AUTO_POSITION = 1", utils.Hostname, tablets[1].MySQLPort), + } + utils.RunSQLs(ctx, t, changeReplicationSourceCommands, tablets[0]) // Capture time when we made tablets[1 writable baseTime := time.Now().UnixNano() / 1000000000 // tablets[2 will be a replica of tablets[1 - changeReplicationSourceCommands = fmt.Sprintf("STOP SLAVE; RESET MASTER; SET GLOBAL gtid_purged = '%s';"+ - "CHANGE MASTER TO MASTER_HOST='%s', MASTER_PORT=%d, MASTER_USER='vt_repl', MASTER_AUTO_POSITION = 1;"+ - "START SLAVE;", gtID, utils.Hostname, tablets[1].MySQLPort) - utils.RunSQL(ctx, t, changeReplicationSourceCommands, tablets[2]) + changeReplicationSourceCommands = []string{ + "STOP SLAVE", + "RESET MASTER", + fmt.Sprintf("SET GLOBAL gtid_purged = '%s'", gtID), + fmt.Sprintf("CHANGE MASTER TO MASTER_HOST='%s', MASTER_PORT=%d, MASTER_USER='vt_repl', MASTER_AUTO_POSITION = 1", utils.Hostname, tablets[1].MySQLPort), + "START SLAVE", + } + utils.RunSQLs(ctx, t, changeReplicationSourceCommands, tablets[2]) // To test the downPrimary, we kill the old primary first and delete its tablet record if downPrimary { diff --git a/go/test/endtoend/reparent/utils/utils.go b/go/test/endtoend/reparent/utils/utils.go index 4ef13819e85..0f48f2b3fa8 100644 --- a/go/test/endtoend/reparent/utils/utils.go +++ b/go/test/endtoend/reparent/utils/utils.go @@ -258,6 +258,15 @@ func getMysqlConnParam(tablet *cluster.Vttablet) mysql.ConnParams { return connParams } +// RunSQLs is used to run SQL commands directly on the MySQL instance of a vttablet +func RunSQLs(ctx context.Context, t *testing.T, sqls []string, tablet *cluster.Vttablet) (results []*sqltypes.Result) { + for _, sql := range sqls { + result := RunSQL(ctx, t, sql, tablet) + results = append(results, result) + } + return results +} + // RunSQL is used to run a SQL command directly on the MySQL instance of a vttablet func RunSQL(ctx context.Context, t *testing.T, sql string, tablet *cluster.Vttablet) *sqltypes.Result { tabletParams := getMysqlConnParam(tablet) diff --git a/go/test/endtoend/vtorc/general/vtorc_test.go b/go/test/endtoend/vtorc/general/vtorc_test.go index 244cd364e7c..f7deffe20b3 100644 --- a/go/test/endtoend/vtorc/general/vtorc_test.go +++ b/go/test/endtoend/vtorc/general/vtorc_test.go @@ -189,9 +189,13 @@ func TestVTOrcRepairs(t *testing.T) { t.Run("ReplicationFromOtherReplica", func(t *testing.T) { // point replica at otherReplica - changeReplicationSourceCommand := fmt.Sprintf("STOP SLAVE; RESET SLAVE ALL;"+ - "CHANGE MASTER TO MASTER_HOST='%s', MASTER_PORT=%d, MASTER_USER='vt_repl', MASTER_AUTO_POSITION = 1; START SLAVE", utils.Hostname, otherReplica.MySQLPort) - _, err := utils.RunSQL(t, changeReplicationSourceCommand, replica, "") + changeReplicationSourceCommands := []string{ + "STOP SLAVE", + "RESET SLAVE ALL", + fmt.Sprintf("CHANGE MASTER TO MASTER_HOST='%s', MASTER_PORT=%d, MASTER_USER='vt_repl', MASTER_AUTO_POSITION = 1", utils.Hostname, otherReplica.MySQLPort), + "START SLAVE", + } + err := utils.RunSQLs(t, changeReplicationSourceCommands, replica, "") require.NoError(t, err) // wait until the source port is set back correctly by vtorc @@ -204,10 +208,13 @@ func TestVTOrcRepairs(t *testing.T) { t.Run("CircularReplication", func(t *testing.T) { // change the replication source on the primary - changeReplicationSourceCommands := fmt.Sprintf("STOP SLAVE; RESET SLAVE ALL;"+ - "CHANGE MASTER TO MASTER_HOST='%s', MASTER_PORT=%d, MASTER_USER='vt_repl', MASTER_AUTO_POSITION = 1;"+ - "START SLAVE;", replica.VttabletProcess.TabletHostname, replica.MySQLPort) - _, err := utils.RunSQL(t, changeReplicationSourceCommands, curPrimary, "") + changeReplicationSourceCommands := []string{ + "STOP SLAVE", + "RESET SLAVE ALL", + fmt.Sprintf("CHANGE MASTER TO MASTER_HOST='%s', MASTER_PORT=%d, MASTER_USER='vt_repl', MASTER_AUTO_POSITION = 1", replica.VttabletProcess.TabletHostname, replica.MySQLPort), + "START SLAVE", + } + err := utils.RunSQLs(t, changeReplicationSourceCommands, curPrimary, "") require.NoError(t, err) // wait for curPrimary to reach stable state diff --git a/go/vt/schemadiff/schema.go b/go/vt/schemadiff/schema.go index a9ef60fbb27..2e630726512 100644 --- a/go/vt/schemadiff/schema.go +++ b/go/vt/schemadiff/schema.go @@ -19,8 +19,6 @@ package schemadiff import ( "bytes" "errors" - "fmt" - "io" "sort" "strings" @@ -113,6 +111,7 @@ func NewSchemaFromQueries(queries []string) (*Schema, error) { // NewSchemaFromSQL creates a valid and normalized schema based on a SQL blob that contains // CREATE statements for various objects (tables, views) +<<<<<<< HEAD func NewSchemaFromSQL(sql string) (*Schema, error) { var statements []sqlparser.Statement tokenizer := sqlparser.NewStringTokenizer(sql) @@ -125,6 +124,12 @@ func NewSchemaFromSQL(sql string) (*Schema, error) { return nil, fmt.Errorf("could not parse statement in SQL: %v: %w", sql, err) } statements = append(statements, stmt) +======= +func NewSchemaFromSQL(env *Environment, sql string) (*Schema, error) { + statements, err := env.Parser.SplitStatements(sql) + if err != nil { + return nil, err +>>>>>>> 78006991dd (Protect `ExecuteFetchAsDBA` against multi-statements, excluding a sequence of `CREATE TABLE|VIEW`. (#14954)) } return NewSchemaFromStatements(statements) } diff --git a/go/vt/sqlparser/ast_test.go b/go/vt/sqlparser/ast_test.go index 97b93a80379..46b46957e9f 100644 --- a/go/vt/sqlparser/ast_test.go +++ b/go/vt/sqlparser/ast_test.go @@ -687,6 +687,77 @@ func TestColumns_FindColumn(t *testing.T) { } } +func TestSplitStatements(t *testing.T) { + testcases := []struct { + input string + stmts int + wantErr bool + }{ + { + input: "select * from table1; \t; \n; \n\t\t ;select * from table1;", + stmts: 2, + }, { + input: "select * from table1", + stmts: 1, + }, { + input: "select * from table1;", + stmts: 1, + }, { + input: "select * from table1; ", + stmts: 1, + }, { + input: "select * from table1; select * from table2;", + stmts: 2, + }, { + input: "create /*vt+ directive=true */ table t1 (id int); create table t2 (id int); create table t3 (id int)", + stmts: 3, + }, { + input: "create /*vt+ directive=true */ table t1 (id int); create table t2 (id int); create table t3 (id int);", + stmts: 3, + }, { + input: "select * from /* comment ; */ table1;", + stmts: 1, + }, { + input: "select * from table1 where semi = ';';", + stmts: 1, + }, { + input: "CREATE TABLE `total_data` (`id` int(11) NOT NULL AUTO_INCREMENT COMMENT 'id', " + + "`region` varchar(32) NOT NULL COMMENT 'region name, like zh; th; kepler'," + + "`data_size` bigint NOT NULL DEFAULT '0' COMMENT 'data size;'," + + "`createtime` datetime NOT NULL DEFAULT NOW() COMMENT 'create time;'," + + "`comment` varchar(100) NOT NULL DEFAULT '' COMMENT 'comment'," + + "PRIMARY KEY (`id`))", + stmts: 1, + }, { + input: "create table t1 (id int primary key); create table t2 (id int primary key);", + stmts: 2, + }, { + input: ";;; create table t1 (id int primary key);;; ;create table t2 (id int primary key);", + stmts: 2, + }, { + input: ";create table t1 ;create table t2 (id;", + wantErr: true, + }, { + // Ignore quoted semicolon + input: ";create table t1 ';';;;create table t2 (id;", + wantErr: true, + }, + } + + parser := NewTestParser() + for _, tcase := range testcases { + t.Run(tcase.input, func(t *testing.T) { + statements, err := parser.SplitStatements(tcase.input) + if tcase.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tcase.stmts, len(statements)) + } + }) + } +} + func TestSplitStatementToPieces(t *testing.T) { testcases := []struct { input string @@ -735,6 +806,10 @@ func TestSplitStatementToPieces(t *testing.T) { // Ignore quoted semicolon input: ";create table t1 ';';;;create table t2 (id;", output: "create table t1 ';';create table t2 (id", + }, { + // Ignore quoted semicolon + input: "stop replica; start replica", + output: "stop replica; start replica", }, } diff --git a/go/vt/sqlparser/parser.go b/go/vt/sqlparser/parser.go index ae630ce3dea..31389da172c 100644 --- a/go/vt/sqlparser/parser.go +++ b/go/vt/sqlparser/parser.go @@ -17,6 +17,7 @@ limitations under the License. package sqlparser import ( + "errors" "fmt" "io" "strconv" @@ -263,6 +264,22 @@ func SplitStatement(blob string) (string, string, error) { return blob, "", nil } +// SplitStatements splits a given blob into multiple SQL statements. +func (p *Parser) SplitStatements(blob string) (statements []Statement, err error) { + tokenizer := p.NewStringTokenizer(blob) + for { + stmt, err := ParseNext(tokenizer) + if err != nil { + if errors.Is(err, io.EOF) { + break + } + return nil, err + } + statements = append(statements, stmt) + } + return statements, nil +} + // SplitStatementToPieces split raw sql statement that may have multi sql pieces to sql pieces // returns the sql pieces blob contains; or error if sql cannot be parsed func SplitStatementToPieces(blob string) (pieces []string, err error) { diff --git a/go/vt/sqlparser/utils.go b/go/vt/sqlparser/utils.go index 0f3c66f2ea3..31ea4410df9 100644 --- a/go/vt/sqlparser/utils.go +++ b/go/vt/sqlparser/utils.go @@ -19,6 +19,7 @@ package sqlparser import ( "fmt" "sort" + "strings" querypb "vitess.io/vitess/go/vt/proto/query" ) @@ -160,3 +161,22 @@ func ReplaceTableQualifiers(query, olddb, newdb string) (string, error) { } return query, nil } + +// ReplaceTableQualifiersMultiQuery accepts a multi-query string and modifies it +// via ReplaceTableQualifiers, one query at a time. +func (p *Parser) ReplaceTableQualifiersMultiQuery(multiQuery, olddb, newdb string) (string, error) { + queries, err := p.SplitStatementToPieces(multiQuery) + if err != nil { + return multiQuery, err + } + var modifiedQueries []string + for _, query := range queries { + // Replace any provided sidecar database qualifiers with the correct one. + query, err := p.ReplaceTableQualifiers(query, olddb, newdb) + if err != nil { + return query, err + } + modifiedQueries = append(modifiedQueries, query) + } + return strings.Join(modifiedQueries, ";"), nil +} diff --git a/go/vt/sqlparser/utils_test.go b/go/vt/sqlparser/utils_test.go index 63c9b10ba43..56e1fdedde9 100644 --- a/go/vt/sqlparser/utils_test.go +++ b/go/vt/sqlparser/utils_test.go @@ -275,3 +275,72 @@ func TestReplaceTableQualifiers(t *testing.T) { }) } } + +func TestReplaceTableQualifiersMultiQuery(t *testing.T) { + origDB := "_vt" + tests := []struct { + name string + in string + newdb string + out string + wantErr bool + }{ + { + name: "invalid select", + in: "select frog bar person", + out: "", + wantErr: true, + }, + { + name: "simple select", + in: "select * from _vt.foo", + out: "select * from foo", + }, + { + name: "simple select with new db", + in: "select * from _vt.foo", + newdb: "_vt_test", + out: "select * from _vt_test.foo", + }, + { + name: "simple select with new db same", + in: "select * from _vt.foo where id=1", // should be unchanged + newdb: "_vt", + out: "select * from _vt.foo where id=1", + }, + { + name: "simple select with new db needing escaping", + in: "select * from _vt.foo", + newdb: "1_vt-test", + out: "select * from `1_vt-test`.foo", + }, + { + name: "multi query", + in: "select * from _vt.foo ; select * from _vt.bar", + out: "select * from foo;select * from bar", + }, + { + name: "multi query with new db", + in: "select * from _vt.foo ; select * from _vt.bar", + newdb: "_vt_test", + out: "select * from _vt_test.foo;select * from _vt_test.bar", + }, + { + name: "multi query with error", + in: "select * from _vt.foo ; select * from _vt.bar ; sel ect fr om wh at", + wantErr: true, + }, + } + parser := NewTestParser() + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := parser.ReplaceTableQualifiersMultiQuery(tt.in, origDB, tt.newdb) + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + require.Equal(t, tt.out, got, "RemoveTableQualifiers(); in: %s, out: %s", tt.in, got) + }) + } +} diff --git a/go/vt/vttablet/tabletmanager/rpc_query.go b/go/vt/vttablet/tabletmanager/rpc_query.go index 0d21cee7677..c9be8804862 100644 --- a/go/vt/vttablet/tabletmanager/rpc_query.go +++ b/go/vt/vttablet/tabletmanager/rpc_query.go @@ -24,11 +24,50 @@ import ( "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/log" "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/vt/vterrors" querypb "vitess.io/vitess/go/vt/proto/query" tabletmanagerdatapb "vitess.io/vitess/go/vt/proto/tabletmanagerdata" + "vitess.io/vitess/go/vt/proto/vtrpc" ) +// analyzeExecuteFetchAsDbaMultiQuery reutrns 'true' when at least one of the queries +// in the given SQL has a `/*vt+ allowZeroInDate=true */` directive. +func analyzeExecuteFetchAsDbaMultiQuery(sql string, parser *sqlparser.Parser) (queries []string, parseable bool, countCreate int, allowZeroInDate bool, err error) { + queries, err = parser.SplitStatementToPieces(sql) + if err != nil { + return nil, false, 0, false, err + } + if len(queries) == 0 { + return nil, false, 0, false, vterrors.Errorf(vtrpc.Code_INVALID_ARGUMENT, "no statements found in query: %s", sql) + } + parseable = true + for _, query := range queries { + // Some of the queries we receive here are legitimately non-parseable by our + // current parser, such as `CHANGE REPLICATION SOURCE TO...`. We must allow + // them and so we skip parsing errors. + stmt, err := parser.Parse(query) + if err != nil { + parseable = false + continue + } + switch stmt.(type) { + case *sqlparser.CreateTable, *sqlparser.CreateView: + countCreate++ + default: + } + + if cmnt, ok := stmt.(sqlparser.Commented); ok { + directives := cmnt.GetParsedComments().Directives() + if directives.IsSet("allowZeroInDate") { + allowZeroInDate = true + } + } + + } + return queries, parseable, countCreate, allowZeroInDate, nil +} + // ExecuteFetchAsDba will execute the given query, possibly disabling binlogs and reload schema. func (tm *TabletManager) ExecuteFetchAsDba(ctx context.Context, req *tabletmanagerdatapb.ExecuteFetchAsDbaRequest) (*querypb.QueryResult, error) { // get a connection @@ -52,25 +91,53 @@ func (tm *TabletManager) ExecuteFetchAsDba(ctx context.Context, req *tabletmanag _, _ = conn.ExecuteFetch("USE "+sqlescape.EscapeID(req.DbName), 1, false) } +<<<<<<< HEAD // Handle special possible directives var directives *sqlparser.CommentDirectives if stmt, err := sqlparser.Parse(string(req.Query)); err == nil { if cmnt, ok := stmt.(sqlparser.Commented); ok { directives = cmnt.GetParsedComments().Directives() +======= + statements, _, countCreate, allowZeroInDate, err := analyzeExecuteFetchAsDbaMultiQuery(string(req.Query), tm.SQLParser) + if err != nil { + return nil, err + } + if len(statements) > 1 { + // Up to v19, we allow multi-statement SQL in ExecuteFetchAsDba, but only for the specific case + // where all statements are CREATE TABLE or CREATE VIEW. This is to support `ApplySchema --batch-size`. + // In v20, we will not support multi statements whatsoever. + // v20 will throw an error by virtua of using ExecuteFetch instead of ExecuteFetchMulti. + if countCreate != len(statements) { + return nil, vterrors.Errorf(vtrpc.Code_INVALID_ARGUMENT, "multi statement queries are not supported in ExecuteFetchAsDba unless all are CREATE TABLE or CREATE VIEW") +>>>>>>> 78006991dd (Protect `ExecuteFetchAsDBA` against multi-statements, excluding a sequence of `CREATE TABLE|VIEW`. (#14954)) } } - if directives.IsSet("allowZeroInDate") { + if allowZeroInDate { if _, err := conn.ExecuteFetch("set @@session.sql_mode=REPLACE(REPLACE(@@session.sql_mode, 'NO_ZERO_DATE', ''), 'NO_ZERO_IN_DATE', '')", 1, false); err != nil { return nil, err } } - // Replace any provided sidecar database qualifiers with the correct one. +<<<<<<< HEAD uq, err := sqlparser.ReplaceTableQualifiers(string(req.Query), sidecar.DefaultName, sidecar.GetName()) +======= + // TODO(shlomi): we use ReplaceTableQualifiersMultiQuery for backwards compatibility. In v20 we will not accept + // multi statement queries in ExecuteFetchAsDBA. This will be rewritten as ReplaceTableQualifiers() + uq, err := tm.SQLParser.ReplaceTableQualifiersMultiQuery(string(req.Query), sidecar.DefaultName, sidecar.GetName()) +>>>>>>> 78006991dd (Protect `ExecuteFetchAsDBA` against multi-statements, excluding a sequence of `CREATE TABLE|VIEW`. (#14954)) if err != nil { return nil, err } - result, err := conn.ExecuteFetch(uq, int(req.MaxRows), true /*wantFields*/) + // TODO(shlomi): we use ExecuteFetchMulti for backwards compatibility. In v20 we will not accept + // multi statement queries in ExecuteFetchAsDBA. This will be rewritten as: + // (in v20): result, err := ExecuteFetch(uq, int(req.MaxRows), true /*wantFields*/) + result, more, err := conn.ExecuteFetchMulti(uq, int(req.MaxRows), true /*wantFields*/) + for more { + _, more, _, err = conn.ReadQueryResult(0, false) + if err != nil { + return nil, err + } + } // re-enable binlogs if necessary if req.DisableBinlogs && !conn.IsClosed() { diff --git a/go/vt/vttablet/tabletmanager/rpc_query_test.go b/go/vt/vttablet/tabletmanager/rpc_query_test.go index 87a64b2d8b7..5bc25ddfeb0 100644 --- a/go/vt/vttablet/tabletmanager/rpc_query_test.go +++ b/go/vt/vttablet/tabletmanager/rpc_query_test.go @@ -21,6 +21,7 @@ import ( "strings" "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "vitess.io/vitess/go/mysql" @@ -28,11 +29,90 @@ import ( "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/dbconfigs" "vitess.io/vitess/go/vt/mysqlctl" + "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vttablet/tabletservermock" tabletmanagerdatapb "vitess.io/vitess/go/vt/proto/tabletmanagerdata" ) +func TestAnalyzeExecuteFetchAsDbaMultiQuery(t *testing.T) { + tcases := []struct { + query string + count int + parseable bool + allowZeroInDate bool + allCreate bool + expectErr bool + }{ + { + query: "", + expectErr: true, + }, + { + query: "select * from t1 ; select * from t2", + count: 2, + parseable: true, + }, + { + query: "create table t(id int)", + count: 1, + allCreate: true, + parseable: true, + }, + { + query: "create table t(id int); create view v as select 1 from dual", + count: 2, + allCreate: true, + parseable: true, + }, + { + query: "create table t(id int); create view v as select 1 from dual; drop table t3", + count: 3, + allCreate: false, + parseable: true, + }, + { + query: "create /*vt+ allowZeroInDate=true */ table t (id int)", + count: 1, + allCreate: true, + allowZeroInDate: true, + parseable: true, + }, + { + query: "create table a (id int) ; create /*vt+ allowZeroInDate=true */ table b (id int)", + count: 2, + allCreate: true, + allowZeroInDate: true, + parseable: true, + }, + { + query: "stop replica; start replica", + count: 2, + parseable: false, + }, + { + query: "create table a (id int) ; --comment ; what", + count: 3, + parseable: false, + }, + } + for _, tcase := range tcases { + t.Run(tcase.query, func(t *testing.T) { + parser := sqlparser.NewTestParser() + queries, parseable, countCreate, allowZeroInDate, err := analyzeExecuteFetchAsDbaMultiQuery(tcase.query, parser) + if tcase.expectErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tcase.count, len(queries)) + assert.Equal(t, tcase.parseable, parseable) + assert.Equal(t, tcase.allCreate, (countCreate == len(queries))) + assert.Equal(t, tcase.allowZeroInDate, allowZeroInDate) + } + }) + } +} + func TestTabletManager_ExecuteFetchAsDba(t *testing.T) { ctx := context.Background() cp := mysql.ConnParams{}