From c18b2307efe534ef055a690b802a0e70678f657a Mon Sep 17 00:00:00 2001 From: Shlomi Noach <2607934+shlomi-noach@users.noreply.github.com> Date: Mon, 15 Jan 2024 14:56:43 +0200 Subject: [PATCH] ExecuteFetchAsDBA(): handle 'allowZeroInDate' and batched queries Signed-off-by: Shlomi Noach <2607934+shlomi-noach@users.noreply.github.com> --- go/vt/sqlparser/utils.go | 20 ++++++ go/vt/sqlparser/utils_test.go | 69 +++++++++++++++++++ go/vt/vttablet/tabletmanager/rpc_query.go | 61 +++++++++++++--- .../vttablet/tabletmanager/rpc_query_test.go | 29 ++++++++ 4 files changed, 171 insertions(+), 8 deletions(-) diff --git a/go/vt/sqlparser/utils.go b/go/vt/sqlparser/utils.go index 16c3e4ce976..b785128917f 100644 --- a/go/vt/sqlparser/utils.go +++ b/go/vt/sqlparser/utils.go @@ -19,6 +19,7 @@ package sqlparser import ( "fmt" "sort" + "strings" querypb "vitess.io/vitess/go/vt/proto/query" ) @@ -160,3 +161,22 @@ func (p *Parser) ReplaceTableQualifiers(query, olddb, newdb string) (string, err } return query, nil } + +// ReplaceTableQualifiersMultiQuery accepts a multi-query string and modifies it +// via ReplaceTableQualifiers, one query at a time. +func (p *Parser) ReplaceTableQualifiersMultiQuery(multiQuery, olddb, newdb string) (string, error) { + queries, err := p.SplitStatementToPieces(multiQuery) + if err != nil { + return multiQuery, err + } + var modifiedQueries []string + for _, query := range queries { + // Replace any provided sidecar database qualifiers with the correct one. + query, err := p.ReplaceTableQualifiers(query, olddb, newdb) + if err != nil { + return query, err + } + modifiedQueries = append(modifiedQueries, query) + } + return strings.Join(modifiedQueries, ";"), nil +} diff --git a/go/vt/sqlparser/utils_test.go b/go/vt/sqlparser/utils_test.go index b2833a8187c..64339211917 100644 --- a/go/vt/sqlparser/utils_test.go +++ b/go/vt/sqlparser/utils_test.go @@ -278,3 +278,72 @@ func TestReplaceTableQualifiers(t *testing.T) { }) } } + +func TestReplaceTableQualifiersMultiQuery(t *testing.T) { + origDB := "_vt" + tests := []struct { + name string + in string + newdb string + out string + wantErr bool + }{ + { + name: "invalid select", + in: "select frog bar person", + out: "", + wantErr: true, + }, + { + name: "simple select", + in: "select * from _vt.foo", + out: "select * from foo", + }, + { + name: "simple select with new db", + in: "select * from _vt.foo", + newdb: "_vt_test", + out: "select * from _vt_test.foo", + }, + { + name: "simple select with new db same", + in: "select * from _vt.foo where id=1", // should be unchanged + newdb: "_vt", + out: "select * from _vt.foo where id=1", + }, + { + name: "simple select with new db needing escaping", + in: "select * from _vt.foo", + newdb: "1_vt-test", + out: "select * from `1_vt-test`.foo", + }, + { + name: "multi query", + in: "select * from _vt.foo ; select * from _vt.bar", + out: "select * from foo;select * from bar", + }, + { + name: "multi query with new db", + in: "select * from _vt.foo ; select * from _vt.bar", + newdb: "_vt_test", + out: "select * from _vt_test.foo;select * from _vt_test.bar", + }, + { + name: "multi query with error", + in: "select * from _vt.foo ; select * from _vt.bar ; sel ect fr om wh at", + wantErr: true, + }, + } + parser := NewTestParser() + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := parser.ReplaceTableQualifiersMultiQuery(tt.in, origDB, tt.newdb) + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + require.Equal(t, tt.out, got, "RemoveTableQualifiers(); in: %s, out: %s", tt.in, got) + }) + } +} diff --git a/go/vt/vttablet/tabletmanager/rpc_query.go b/go/vt/vttablet/tabletmanager/rpc_query.go index 229353e7f17..088ff53282f 100644 --- a/go/vt/vttablet/tabletmanager/rpc_query.go +++ b/go/vt/vttablet/tabletmanager/rpc_query.go @@ -18,17 +18,43 @@ package tabletmanager import ( "context" + "errors" + "io" "vitess.io/vitess/go/constants/sidecar" "vitess.io/vitess/go/sqlescape" "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/log" "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/vt/vterrors" querypb "vitess.io/vitess/go/vt/proto/query" tabletmanagerdatapb "vitess.io/vitess/go/vt/proto/tabletmanagerdata" + "vitess.io/vitess/go/vt/proto/vtrpc" ) +// queriesHaveAllowZeroInDateDirective reutrns 'true' when at least one of the queries +// in the given SQL has a `/*vt+ allowZeroInDate=true */` directive. +func queriesHaveAllowZeroInDateDirective(sql string, parser *sqlparser.Parser) bool { + tokenizer := parser.NewStringTokenizer(sql) + for { + stmt, err := sqlparser.ParseNext(tokenizer) + if err != nil { + if errors.Is(err, io.EOF) { + break + } + return false + } + if cmnt, ok := stmt.(sqlparser.Commented); ok { + directives := cmnt.GetParsedComments().Directives() + if directives.IsSet("allowZeroInDate") { + return true + } + } + } + return false +} + // ExecuteFetchAsDba will execute the given query, possibly disabling binlogs and reload schema. func (tm *TabletManager) ExecuteFetchAsDba(ctx context.Context, req *tabletmanagerdatapb.ExecuteFetchAsDbaRequest) (*querypb.QueryResult, error) { if err := tm.waitForGrantsToHaveApplied(ctx); err != nil { @@ -55,25 +81,44 @@ func (tm *TabletManager) ExecuteFetchAsDba(ctx context.Context, req *tabletmanag _, _ = conn.ExecuteFetch("USE "+sqlescape.EscapeID(req.DbName), 1, false) } - // Handle special possible directives - var directives *sqlparser.CommentDirectives - if stmt, err := tm.SQLParser.Parse(string(req.Query)); err == nil { + allowZeroInDate := false + tokenizer := tm.SQLParser.NewStringTokenizer(string(req.Query)) + for { + stmt, err := sqlparser.ParseNext(tokenizer) + if err != nil { + if errors.Is(err, io.EOF) { + break + } + return nil, vterrors.Errorf(vtrpc.Code_INVALID_ARGUMENT, "could not parse statement in ExecuteFetchAsDba: %v: %v", string(req.Query), err) + } if cmnt, ok := stmt.(sqlparser.Commented); ok { - directives = cmnt.GetParsedComments().Directives() + directives := cmnt.GetParsedComments().Directives() + if directives.IsSet("allowZeroInDate") { + // --allow-zero-in-date Applies to DDLs. As a backport solution to + // https://github.com/vitessio/vitess/issues/14952, it is enough that + // one of the DDLs has the `allowZeroInDate` directive, that we allow + // zero in date for all queries. + allowZeroInDate = true + } } } - if directives.IsSet("allowZeroInDate") { + if allowZeroInDate { if _, err := conn.ExecuteFetch("set @@session.sql_mode=REPLACE(REPLACE(@@session.sql_mode, 'NO_ZERO_DATE', ''), 'NO_ZERO_IN_DATE', '')", 1, false); err != nil { return nil, err } } - // Replace any provided sidecar database qualifiers with the correct one. - uq, err := tm.SQLParser.ReplaceTableQualifiers(string(req.Query), sidecar.DefaultName, sidecar.GetName()) + uq, err := tm.SQLParser.ReplaceTableQualifiersMultiQuery(string(req.Query), sidecar.DefaultName, sidecar.GetName()) if err != nil { return nil, err } - result, err := conn.ExecuteFetch(uq, int(req.MaxRows), true /*wantFields*/) + result, more, err := conn.ExecuteFetchMulti(uq, int(req.MaxRows), true /*wantFields*/) + for more { + _, more, _, err = conn.ReadQueryResult(0, false) + if err != nil { + return nil, err + } + } // re-enable binlogs if necessary if req.DisableBinlogs && !conn.IsClosed() { diff --git a/go/vt/vttablet/tabletmanager/rpc_query_test.go b/go/vt/vttablet/tabletmanager/rpc_query_test.go index af7791b5374..8aa12bf0d82 100644 --- a/go/vt/vttablet/tabletmanager/rpc_query_test.go +++ b/go/vt/vttablet/tabletmanager/rpc_query_test.go @@ -21,6 +21,7 @@ import ( "strings" "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "vitess.io/vitess/go/mysql" @@ -28,11 +29,39 @@ import ( "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/dbconfigs" "vitess.io/vitess/go/vt/mysqlctl" + "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vttablet/tabletservermock" tabletmanagerdatapb "vitess.io/vitess/go/vt/proto/tabletmanagerdata" ) +func TestQueriesHaveAllowZeroInDateDirective(t *testing.T) { + tcases := []struct { + query string + expected bool + }{ + { + query: "create table t(id int)", + expected: false, + }, + { + query: "create /*vt+ allowZeroInDate=true */ table t (id int)", + expected: true, + }, + { + query: "create table a (id int) ; create /*vt+ allowZeroInDate=true */ table b (id int)", + expected: true, + }, + } + for _, tcase := range tcases { + t.Run(tcase.query, func(t *testing.T) { + parser := sqlparser.NewTestParser() + got := queriesHaveAllowZeroInDateDirective(tcase.query, parser) + assert.Equal(t, tcase.expected, got) + }) + } +} + func TestTabletManager_ExecuteFetchAsDba(t *testing.T) { ctx := context.Background() cp := mysql.ConnParams{}