From 6cf1275c5ddfe9186ac62055917a0568b303bcfb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9s=20Taylor?= Date: Thu, 22 Feb 2024 15:29:27 +0100 Subject: [PATCH] [release-18.0] Column alias expanding on ORDER BY (#15302) (#15331) Signed-off-by: Andres Taylor --- go/mysql/sqlerror/sql_error.go | 1 + go/test/endtoend/utils/cmp.go | 6 + .../queries/aggregation/aggregation_test.go | 18 +- .../vtgate/queries/dml/insert_test.go | 6 +- .../queries/lookup_queries/main_test.go | 2 +- .../endtoend/vtgate/queries/misc/misc_test.go | 2 +- .../vtgate/queries/orderby/orderby_test.go | 70 ++++ .../vtgate/queries/orderby/schema.sql | 9 + .../vtgate/queries/orderby/vschema.json | 15 +- .../vtgate/queries/union/union_test.go | 9 +- go/vt/vterrors/state.go | 1 + .../testdata/postprocess_cases.json | 134 +++++++ .../testdata/unsupported_cases.json | 5 - go/vt/vtgate/semantics/analyzer.go | 15 +- go/vt/vtgate/semantics/binder.go | 10 +- go/vt/vtgate/semantics/early_rewriter.go | 355 +++++++++++++++--- go/vt/vtgate/semantics/early_rewriter_test.go | 207 ++++++++-- go/vt/vtgate/semantics/errors.go | 14 + 18 files changed, 748 insertions(+), 131 deletions(-) diff --git a/go/mysql/sqlerror/sql_error.go b/go/mysql/sqlerror/sql_error.go index 9b1f65c82e3..6e7b8eb95e9 100644 --- a/go/mysql/sqlerror/sql_error.go +++ b/go/mysql/sqlerror/sql_error.go @@ -243,6 +243,7 @@ var stateToMysqlCode = map[vterrors.State]mysqlCode{ vterrors.CharacterSetMismatch: {num: ERCharacterSetMismatch, state: SSUnknownSQLState}, vterrors.WrongParametersToNativeFct: {num: ERWrongParametersToNativeFct, state: SSUnknownSQLState}, vterrors.KillDeniedError: {num: ERKillDenied, state: SSUnknownSQLState}, + vterrors.InvalidGroupFuncUse: {num: ERInvalidGroupFuncUse, state: SSUnknownSQLState}, } func getStateToMySQLState(state vterrors.State) mysqlCode { diff --git a/go/test/endtoend/utils/cmp.go b/go/test/endtoend/utils/cmp.go index 38726d6c3aa..8d0b56ac6b3 100644 --- a/go/test/endtoend/utils/cmp.go +++ b/go/test/endtoend/utils/cmp.go @@ -70,6 +70,12 @@ func (mcmp *MySQLCompare) AssertMatches(query, expected string) { } } +// SkipIfBinaryIsBelowVersion should be used instead of using utils.SkipIfBinaryIsBelowVersion(t, +// This is because we might be inside a Run block that has a different `t` variable +func (mcmp *MySQLCompare) SkipIfBinaryIsBelowVersion(majorVersion int, binary string) { + SkipIfBinaryIsBelowVersion(mcmp.t, majorVersion, binary) +} + // AssertMatchesAny ensures the given query produces any one of the expected results. func (mcmp *MySQLCompare) AssertMatchesAny(query string, expected ...string) { mcmp.t.Helper() diff --git a/go/test/endtoend/vtgate/queries/aggregation/aggregation_test.go b/go/test/endtoend/vtgate/queries/aggregation/aggregation_test.go index 5d43ffc4a33..835f78144ae 100644 --- a/go/test/endtoend/vtgate/queries/aggregation/aggregation_test.go +++ b/go/test/endtoend/vtgate/queries/aggregation/aggregation_test.go @@ -94,7 +94,7 @@ func TestEqualFilterOnScatter(t *testing.T) { workloads := []string{"oltp", "olap"} for _, workload := range workloads { - t.Run(workload, func(t *testing.T) { + mcmp.Run(workload, func(mcmp *utils.MySQLCompare) { utils.Exec(t, mcmp.VtConn, fmt.Sprintf("set workload = '%s'", workload)) mcmp.AssertMatches("select count(*) as a from aggr_test having 1 = 1", `[[INT64(5)]]`) @@ -179,7 +179,7 @@ func TestNotEqualFilterOnScatter(t *testing.T) { workloads := []string{"oltp", "olap"} for _, workload := range workloads { - t.Run(workload, func(t *testing.T) { + mcmp.Run(workload, func(mcmp *utils.MySQLCompare) { utils.Exec(t, mcmp.VtConn, fmt.Sprintf("set workload = '%s'", workload)) mcmp.AssertMatches("select count(*) as a from aggr_test having a != 5", `[]`) @@ -203,7 +203,7 @@ func TestLessFilterOnScatter(t *testing.T) { workloads := []string{"oltp", "olap"} for _, workload := range workloads { - t.Run(workload, func(t *testing.T) { + mcmp.Run(workload, func(mcmp *utils.MySQLCompare) { utils.Exec(t, mcmp.VtConn, fmt.Sprintf("set workload = '%s'", workload)) mcmp.AssertMatches("select count(*) as a from aggr_test having a < 10", `[[INT64(5)]]`) mcmp.AssertMatches("select count(*) as a from aggr_test having 1 < a", `[[INT64(5)]]`) @@ -226,7 +226,7 @@ func TestLessEqualFilterOnScatter(t *testing.T) { workloads := []string{"oltp", "olap"} for _, workload := range workloads { - t.Run(workload, func(t *testing.T) { + mcmp.Run(workload, func(mcmp *utils.MySQLCompare) { utils.Exec(t, mcmp.VtConn, fmt.Sprintf("set workload = '%s'", workload)) mcmp.AssertMatches("select count(*) as a from aggr_test having a <= 10", `[[INT64(5)]]`) @@ -250,7 +250,7 @@ func TestGreaterFilterOnScatter(t *testing.T) { workloads := []string{"oltp", "olap"} for _, workload := range workloads { - t.Run(workload, func(t *testing.T) { + mcmp.Run(workload, func(mcmp *utils.MySQLCompare) { utils.Exec(t, mcmp.VtConn, fmt.Sprintf("set workload = '%s'", workload)) mcmp.AssertMatches("select count(*) as a from aggr_test having a > 1", `[[INT64(5)]]`) @@ -274,7 +274,7 @@ func TestGreaterEqualFilterOnScatter(t *testing.T) { workloads := []string{"oltp", "olap"} for _, workload := range workloads { - t.Run(workload, func(t *testing.T) { + mcmp.Run(workload, func(mcmp *utils.MySQLCompare) { utils.Exec(t, mcmp.VtConn, fmt.Sprintf("set workload = '%s'", workload)) mcmp.AssertMatches("select count(*) as a from aggr_test having a >= 1", `[[INT64(5)]]`) @@ -309,7 +309,7 @@ func TestAggOnTopOfLimit(t *testing.T) { mcmp.Exec("insert into aggr_test(id, val1, val2) values(1,'a',6), (2,'a',1), (3,'b',1), (4,'c',3), (5,'c',4), (6,'b',null), (7,null,2), (8,null,null)") for _, workload := range []string{"oltp", "olap"} { - t.Run(workload, func(t *testing.T) { + mcmp.Run(workload, func(mcmp *utils.MySQLCompare) { utils.Exec(t, mcmp.VtConn, fmt.Sprintf("set workload = '%s'", workload)) mcmp.AssertMatches(" select count(*) from (select id, val1 from aggr_test where val2 < 4 limit 2) as x", "[[INT64(2)]]") mcmp.AssertMatches(" select count(val1) from (select id, val1 from aggr_test where val2 < 4 order by val1 desc limit 2) as x", "[[INT64(2)]]") @@ -342,7 +342,7 @@ func TestEmptyTableAggr(t *testing.T) { defer closer() for _, workload := range []string{"oltp", "olap"} { - t.Run(workload, func(t *testing.T) { + mcmp.Run(workload, func(mcmp *utils.MySQLCompare) { utils.Exec(t, mcmp.VtConn, fmt.Sprintf("set workload = %s", workload)) mcmp.AssertMatches(" select count(*) from t1 inner join t2 on (t1.t1_id = t2.id) where t1.value = 'foo'", "[[INT64(0)]]") mcmp.AssertMatches(" select count(*) from t2 inner join t1 on (t1.t1_id = t2.id) where t1.value = 'foo'", "[[INT64(0)]]") @@ -354,7 +354,7 @@ func TestEmptyTableAggr(t *testing.T) { mcmp.Exec("insert into t1(t1_id, `name`, `value`, shardkey) values(1,'a1','foo',100), (2,'b1','foo',200), (3,'c1','foo',300), (4,'a1','foo',100), (5,'b1','bar',200)") for _, workload := range []string{"oltp", "olap"} { - t.Run(workload, func(t *testing.T) { + mcmp.Run(workload, func(mcmp *utils.MySQLCompare) { utils.Exec(t, mcmp.VtConn, fmt.Sprintf("set workload = %s", workload)) mcmp.AssertMatches(" select count(*) from t1 inner join t2 on (t1.t1_id = t2.id) where t1.value = 'foo'", "[[INT64(0)]]") mcmp.AssertMatches(" select count(*) from t2 inner join t1 on (t1.t1_id = t2.id) where t1.value = 'foo'", "[[INT64(0)]]") diff --git a/go/test/endtoend/vtgate/queries/dml/insert_test.go b/go/test/endtoend/vtgate/queries/dml/insert_test.go index 80d0602b898..1d09d3aab51 100644 --- a/go/test/endtoend/vtgate/queries/dml/insert_test.go +++ b/go/test/endtoend/vtgate/queries/dml/insert_test.go @@ -38,7 +38,7 @@ func TestSimpleInsertSelect(t *testing.T) { mcmp.Exec("insert into u_tbl(id, num) values (1,2),(3,4)") for i, mode := range []string{"oltp", "olap"} { - t.Run(mode, func(t *testing.T) { + mcmp.Run(mode, func(mcmp *utils.MySQLCompare) { utils.Exec(t, mcmp.VtConn, fmt.Sprintf("set workload = %s", mode)) qr := mcmp.Exec(fmt.Sprintf("insert into s_tbl(id, num) select id*%d, num*%d from s_tbl where id < 10", 10+i, 20+i)) @@ -65,7 +65,7 @@ func TestFailureInsertSelect(t *testing.T) { mcmp.Exec("insert into u_tbl(id, num) values (1,2),(3,4)") for _, mode := range []string{"oltp", "olap"} { - t.Run(mode, func(t *testing.T) { + mcmp.Run(mode, func(mcmp *utils.MySQLCompare) { utils.Exec(t, mcmp.VtConn, fmt.Sprintf("set workload = %s", mode)) // primary key same @@ -386,7 +386,7 @@ func TestInsertSelectUnshardedUsingSharded(t *testing.T) { mcmp.Exec("insert into s_tbl(id, num) values (1,2),(3,4)") for _, mode := range []string{"oltp", "olap"} { - t.Run(mode, func(t *testing.T) { + mcmp.Run(mode, func(mcmp *utils.MySQLCompare) { utils.Exec(t, mcmp.VtConn, fmt.Sprintf("set workload = %s", mode)) qr := mcmp.Exec("insert into u_tbl(id, num) select id, num from s_tbl where s_tbl.id in (1,3)") assert.EqualValues(t, 2, qr.RowsAffected) diff --git a/go/test/endtoend/vtgate/queries/lookup_queries/main_test.go b/go/test/endtoend/vtgate/queries/lookup_queries/main_test.go index c385941502a..25bf78437da 100644 --- a/go/test/endtoend/vtgate/queries/lookup_queries/main_test.go +++ b/go/test/endtoend/vtgate/queries/lookup_queries/main_test.go @@ -134,7 +134,7 @@ func TestLookupQueries(t *testing.T) { (3, 'monkey', 'monkey')`) for _, workload := range []string{"olap", "oltp"} { - t.Run(workload, func(t *testing.T) { + mcmp.Run(workload, func(mcmp *utils.MySQLCompare) { utils.Exec(t, mcmp.VtConn, "set workload = "+workload) mcmp.AssertMatches("select id from user where lookup = 'apa'", "[[INT64(1)] [INT64(2)]]") diff --git a/go/test/endtoend/vtgate/queries/misc/misc_test.go b/go/test/endtoend/vtgate/queries/misc/misc_test.go index 4eb005d272d..309da1c5941 100644 --- a/go/test/endtoend/vtgate/queries/misc/misc_test.go +++ b/go/test/endtoend/vtgate/queries/misc/misc_test.go @@ -268,7 +268,7 @@ func TestAnalyze(t *testing.T) { defer closer() for _, workload := range []string{"olap", "oltp"} { - t.Run(workload, func(t *testing.T) { + mcmp.Run(workload, func(mcmp *utils.MySQLCompare) { utils.Exec(t, mcmp.VtConn, fmt.Sprintf("set workload = %s", workload)) utils.Exec(t, mcmp.VtConn, "analyze table t1") utils.Exec(t, mcmp.VtConn, "analyze table uks.unsharded") diff --git a/go/test/endtoend/vtgate/queries/orderby/orderby_test.go b/go/test/endtoend/vtgate/queries/orderby/orderby_test.go index 445a4d5a32f..09607bbd767 100644 --- a/go/test/endtoend/vtgate/queries/orderby/orderby_test.go +++ b/go/test/endtoend/vtgate/queries/orderby/orderby_test.go @@ -83,3 +83,73 @@ func TestOrderBy(t *testing.T) { mcmp.AssertMatches("select /*vt+ PLANNER=Gen4 */ id1, id2 from t4 order by reverse(id2) desc", `[[INT64(5) VARCHAR("test")] [INT64(8) VARCHAR("F")] [INT64(7) VARCHAR("e")] [INT64(6) VARCHAR("d")] [INT64(2) VARCHAR("Abc")] [INT64(4) VARCHAR("c")] [INT64(3) VARCHAR("b")] [INT64(1) VARCHAR("a")]]`) } } + +func TestOrderByComplex(t *testing.T) { + // tests written to try to trick the ORDER BY engine and planner + utils.SkipIfBinaryIsBelowVersion(t, 20, "vtgate") + + mcmp, closer := start(t) + defer closer() + + mcmp.Exec("insert into user(id, col, email) values(1,1,'a'), (2,2,'Abc'), (3,3,'b'), (4,4,'c'), (5,2,'test'), (6,1,'test'), (7,2,'a'), (8,3,'b'), (9,4,'c3'), (10,2,'d')") + + queries := []string{ + "select email, max(col) from user group by email order by col", + "select email, max(col) from user group by email order by col + 1", + "select email, max(col) from user group by email order by max(col)", + "select email, max(col) from user group by email order by max(col) + 1", + "select email, max(col) from user group by email order by min(col)", + "select email, max(col) as col from user group by email order by col", + "select email, max(col) as col from user group by email order by max(col)", + "select email, max(col) as col from user group by email order by col + 1", + "select email, max(col) as col from user group by email order by email + col", + "select email, max(col) as col from user group by email order by email + max(col)", + "select email, max(col) as col from user group by email order by email, col", + "select email, max(col) as xyz from user group by email order by email, xyz", + "select email, max(col) as xyz from user group by email order by email, max(xyz)", + "select email, max(col) as xyz from user group by email order by email, abs(xyz)", + "select email, max(col) as xyz from user group by email order by email, max(col)", + "select email, max(col) as xyz from user group by email order by email, abs(col)", + "select email, max(col) as xyz from user group by email order by xyz + email", + "select email, max(col) as xyz from user group by email order by abs(xyz) + email", + "select email, max(col) as xyz from user group by email order by abs(xyz)", + "select email, max(col) as xyz from user group by email order by abs(col)", + "select email, max(col) as max_col from user group by email order by max_col desc, length(email)", + "select email, max(col) as max_col, min(col) as min_col from user group by email order by max_col - min_col", + "select email, max(col) as col1, count(*) as col2 from user group by email order by col2 * col1", + "select email, sum(col) as sum_col from user group by email having sum_col > 10 order by sum_col / count(email)", + "select email, max(col) as max_col, char_length(email) as len_email from user group by email order by len_email, max_col desc", + "select email, max(col) as col_alias from user group by email order by case when col_alias > 100 then 0 else 1 end, col_alias", + "select email, count(*) as cnt, max(col) as max_col from user group by email order by cnt desc, max_col + cnt", + "select email, max(col) as max_col from user group by email order by if(max_col > 50, max_col, -max_col) desc", + "select email, max(col) as col, sum(col) as sum_col from user group by email order by col * sum_col desc", + "select email, max(col) as col, (select min(col) from user as u2 where u2.email = user.email) as min_col from user group by email order by col - min_col", + "select email, max(col) as max_col, (max(col) % 10) as mod_col from user group by email order by mod_col, max_col", + "select email, max(col) as 'value', count(email) as 'number' from user group by email order by 'number', 'value'", + "select email, max(col) as col, concat('email: ', email, ' col: ', max(col)) as complex_alias from user group by email order by complex_alias desc", + "select email, max(col) as max_col from user group by email union select email, min(col) as min_col from user group by email order by email", + "select email, max(col) as col from user where col > 50 group by email order by col desc", + "select email, max(col) as col from user group by email order by length(email), col", + "select email, max(col) as max_col, substring(email, 1, 3) as sub_email from user group by email order by sub_email, max_col desc", + "select email, max(col) as max_col from user group by email order by reverse(email), max_col", + "select email, max(col) as max_col from user group by email having max_col > avg(max_col) order by max_col desc", + "select email, count(*) as count, max(col) as max_col from user group by email order by count * max_col desc", + "select email, max(col) as col_alias from user group by email order by col_alias limit 10", + "select email, max(col) as col from user group by email order by col desc, email", + "select concat(email, ' ', max(col)) as combined from user group by email order by combined desc", + "select email, max(col) as max_col from user group by email order by ascii(email), max_col", + "select email, char_length(email) as email_length, max(col) as max_col from user group by email order by email_length desc, max_col", + "select email, max(col) as col from user group by email having col between 10 and 100 order by col", + "select email, max(col) as max_col, min(col) as min_col from user group by email order by max_col + min_col desc", + "select email, max(col) as 'max', count(*) as 'count' from user group by email order by 'max' desc, 'count'", + "select email, max(col) as max_col from (select email, col from user where col > 20) as filtered group by email order by max_col", + "select a.email, a.max_col from (select email, max(col) as max_col from user group by email) as a order by a.max_col desc", + "select email, max(col) as max_col from user where email like 'a%' group by email order by max_col, email", + } + + for _, query := range queries { + mcmp.Run(query, func(mcmp *utils.MySQLCompare) { + _, _ = mcmp.ExecAllowAndCompareError(query) + }) + } +} diff --git a/go/test/endtoend/vtgate/queries/orderby/schema.sql b/go/test/endtoend/vtgate/queries/orderby/schema.sql index 8f0131db357..0980d845c3d 100644 --- a/go/test/endtoend/vtgate/queries/orderby/schema.sql +++ b/go/test/endtoend/vtgate/queries/orderby/schema.sql @@ -27,3 +27,12 @@ create table t4_id2_idx ) Engine = InnoDB DEFAULT charset = utf8mb4 COLLATE = utf8mb4_general_ci; + +create table user +( + id bigint primary key, + col bigint, + email varchar(20) +) Engine = InnoDB + DEFAULT charset = utf8mb4 + COLLATE = utf8mb4_general_ci; diff --git a/go/test/endtoend/vtgate/queries/orderby/vschema.json b/go/test/endtoend/vtgate/queries/orderby/vschema.json index 14418850a35..59eb44fadcb 100644 --- a/go/test/endtoend/vtgate/queries/orderby/vschema.json +++ b/go/test/endtoend/vtgate/queries/orderby/vschema.json @@ -4,7 +4,7 @@ "hash": { "type": "hash" }, - "unicode_loose_md5" : { + "unicode_loose_md5": { "type": "unicode_loose_md5" }, "t1_id2_vdx": { @@ -54,7 +54,10 @@ "name": "hash" }, { - "columns": ["id2", "id1"], + "columns": [ + "id2", + "id1" + ], "name": "t4_id2_vdx" } ] @@ -66,6 +69,14 @@ "name": "unicode_loose_md5" } ] + }, + "user": { + "column_vindexes": [ + { + "column": "id", + "name": "hash" + } + ] } } } \ No newline at end of file diff --git a/go/test/endtoend/vtgate/queries/union/union_test.go b/go/test/endtoend/vtgate/queries/union/union_test.go index 262d7fdb168..2f287fd7b9c 100644 --- a/go/test/endtoend/vtgate/queries/union/union_test.go +++ b/go/test/endtoend/vtgate/queries/union/union_test.go @@ -57,7 +57,7 @@ func TestUnionDistinct(t *testing.T) { mcmp.Exec("insert into t2(id3, id4) values (2, 3), (3, 4), (4,4), (5,5)") for _, workload := range []string{"oltp", "olap"} { - t.Run(workload, func(t *testing.T) { + mcmp.Run(workload, func(mcmp *utils.MySQLCompare) { utils.Exec(t, mcmp.VtConn, "set workload = "+workload) mcmp.AssertMatches("select 1 union select null", "[[INT64(1)] [NULL]]") mcmp.AssertMatches("select null union select null", "[[NULL]]") @@ -69,10 +69,7 @@ func TestUnionDistinct(t *testing.T) { mcmp.AssertMatchesNoOrder("select id1 from t1 where id1 = 1 union select 452 union select id1 from t1 where id1 = 4", "[[INT64(1)] [INT64(452)] [INT64(4)]]") mcmp.AssertMatchesNoOrder("select id1, id2 from t1 union select 827, 452 union select id3,id4 from t2", "[[INT64(4) INT64(4)] [INT64(1) INT64(1)] [INT64(2) INT64(2)] [INT64(3) INT64(3)] [INT64(827) INT64(452)] [INT64(2) INT64(3)] [INT64(3) INT64(4)] [INT64(5) INT64(5)]]") - t.Run("skipped for now", func(t *testing.T) { - t.Skip() - mcmp.AssertMatches("select 1 from dual where 1 IN (select 1 as col union select 2)", "[[INT64(1)]]") - }) + mcmp.AssertMatches("select 1 from dual where 1 IN (select 1 as col union select 2)", "[[INT64(1)]]") mcmp.AssertMatches(`SELECT 1 from t1 UNION SELECT 2 from t1`, `[[INT64(1)] [INT64(2)]]`) mcmp.AssertMatches(`SELECT 5 from t1 UNION SELECT 6 from t1`, `[[INT64(5)] [INT64(6)]]`) mcmp.AssertMatchesNoOrder(`SELECT id1 from t1 UNION SELECT id2 from t1`, `[[INT64(1)] [INT64(2)] [INT64(3)] [INT64(4)]]`) @@ -95,7 +92,7 @@ func TestUnionAll(t *testing.T) { mcmp.Exec("insert into t2(id3, id4) values(3, 3), (4, 4)") for _, workload := range []string{"oltp", "olap"} { - t.Run(workload, func(t *testing.T) { + mcmp.Run(workload, func(mcmp *utils.MySQLCompare) { utils.Exec(t, mcmp.VtConn, "set workload = "+workload) // union all between two selectuniqueequal mcmp.AssertMatches("select id1 from t1 where id1 = 1 union all select id1 from t1 where id1 = 4", "[[INT64(1)]]") diff --git a/go/vt/vterrors/state.go b/go/vt/vterrors/state.go index 5e3dcf22dfb..a1c6ebef3c9 100644 --- a/go/vt/vterrors/state.go +++ b/go/vt/vterrors/state.go @@ -47,6 +47,7 @@ const ( WrongValueCountOnRow WrongValue WrongArguments + InvalidGroupFuncUse // failed precondition NoDB diff --git a/go/vt/vtgate/planbuilder/testdata/postprocess_cases.json b/go/vt/vtgate/planbuilder/testdata/postprocess_cases.json index 097877c14a5..ad1c11cc69c 100644 --- a/go/vt/vtgate/planbuilder/testdata/postprocess_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/postprocess_cases.json @@ -2205,5 +2205,139 @@ "user.user" ] } + }, + { + "comment": "ORDER BY literal works fine even when the columns have the same name", + "query": "select a.id, b.id from user as a, user_extra as b union all select 1, 2 order by 1", + "plan": { + "QueryType": "SELECT", + "Original": "select a.id, b.id from user as a, user_extra as b union all select 1, 2 order by 1", + "Instructions": { + "OperatorType": "Sort", + "Variant": "Memory", + "OrderBy": "(0|2) ASC", + "ResultColumns": 2, + "Inputs": [ + { + "OperatorType": "Concatenate", + "Inputs": [ + { + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "L:0,R:0,L:1", + "TableName": "`user`_user_extra", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select a.id, weight_string(a.id) from `user` as a where 1 != 1", + "Query": "select a.id, weight_string(a.id) from `user` as a", + "Table": "`user`" + }, + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select b.id from user_extra as b where 1 != 1", + "Query": "select b.id from user_extra as b", + "Table": "user_extra" + } + ] + }, + { + "OperatorType": "Route", + "Variant": "Reference", + "Keyspace": { + "Name": "main", + "Sharded": false + }, + "FieldQuery": "select 1, 2, weight_string(1) from dual where 1 != 1", + "Query": "select 1, 2, weight_string(1) from dual", + "Table": "dual" + } + ] + } + ] + }, + "TablesUsed": [ + "main.dual", + "user.user", + "user.user_extra" + ] + } + }, + { + "comment": "ORDER BY literal works fine even when the columns have the same name", + "query": "select a.id, b.id from user as a, user_extra as b union all select 1, 2 order by 2", + "plan": { + "QueryType": "SELECT", + "Original": "select a.id, b.id from user as a, user_extra as b union all select 1, 2 order by 2", + "Instructions": { + "OperatorType": "Sort", + "Variant": "Memory", + "OrderBy": "(1|2) ASC", + "ResultColumns": 2, + "Inputs": [ + { + "OperatorType": "Concatenate", + "Inputs": [ + { + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "L:0,R:0,R:1", + "TableName": "`user`_user_extra", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select a.id from `user` as a where 1 != 1", + "Query": "select a.id from `user` as a", + "Table": "`user`" + }, + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select b.id, weight_string(b.id) from user_extra as b where 1 != 1", + "Query": "select b.id, weight_string(b.id) from user_extra as b", + "Table": "user_extra" + } + ] + }, + { + "OperatorType": "Route", + "Variant": "Reference", + "Keyspace": { + "Name": "main", + "Sharded": false + }, + "FieldQuery": "select 1, 2, weight_string(2) from dual where 1 != 1", + "Query": "select 1, 2, weight_string(2) from dual", + "Table": "dual" + } + ] + } + ] + }, + "TablesUsed": [ + "main.dual", + "user.user", + "user.user_extra" + ] + } } ] diff --git a/go/vt/vtgate/planbuilder/testdata/unsupported_cases.json b/go/vt/vtgate/planbuilder/testdata/unsupported_cases.json index a4b6576cbff..da4f1aaf0cf 100644 --- a/go/vt/vtgate/planbuilder/testdata/unsupported_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/unsupported_cases.json @@ -354,11 +354,6 @@ "query": "select id2 from user uu where id in (select id from user where id = uu.id and user.col in (select col from (select col, id, user_id from user_extra where user_id = 5) uu where uu.user_id = uu.id))", "plan": "VT12001: unsupported: correlated subquery is only supported for EXISTS" }, - { - "comment": "rewrite of 'order by 2' that becomes 'order by id', leading to ambiguous binding.", - "query": "select a.id, b.id from user as a, user_extra as b union select 1, 2 order by 2", - "plan": "Column 'id' in field list is ambiguous" - }, { "comment": "unsupported with clause in delete statement", "query": "with x as (select * from user) delete from x", diff --git a/go/vt/vtgate/semantics/analyzer.go b/go/vt/vtgate/semantics/analyzer.go index 2f6f66b5d3d..b71c76f773b 100644 --- a/go/vt/vtgate/semantics/analyzer.go +++ b/go/vt/vtgate/semantics/analyzer.go @@ -69,6 +69,8 @@ func (a *analyzer) lateInit() { scoper: a.scoper, binder: a.binder, expandedColumns: map[sqlparser.TableName][]*sqlparser.ColName{}, + aliasMapCache: map[*sqlparser.Select]map[string]exprContainer{}, + reAnalyze: a.lateAnalyze, } } @@ -211,10 +213,6 @@ func (a *analyzer) analyzeUp(cursor *sqlparser.Cursor) bool { return false } - if err := a.scoper.up(cursor); err != nil { - a.setError(err) - return false - } if err := a.tables.up(cursor); err != nil { a.setError(err) return false @@ -235,6 +233,11 @@ func (a *analyzer) analyzeUp(cursor *sqlparser.Cursor) bool { return true } + if err := a.scoper.up(cursor); err != nil { + a.setError(err) + return false + } + a.leaveProjection(cursor) return a.shouldContinue() } @@ -325,6 +328,10 @@ func (a *analyzer) analyze(statement sqlparser.Statement) error { a.lateInit() + return a.lateAnalyze(statement) +} + +func (a *analyzer) lateAnalyze(statement sqlparser.SQLNode) error { _ = sqlparser.Rewrite(statement, a.analyzeDown, a.analyzeUp) return a.err } diff --git a/go/vt/vtgate/semantics/binder.go b/go/vt/vtgate/semantics/binder.go index 764a8e41501..1e4b1d00b47 100644 --- a/go/vt/vtgate/semantics/binder.go +++ b/go/vt/vtgate/semantics/binder.go @@ -61,7 +61,7 @@ func (b *binder) up(cursor *sqlparser.Cursor) error { currScope := b.scoper.currentScope() for _, ident := range node.Using { name := sqlparser.NewColName(ident.String()) - deps, err := b.resolveColumn(name, currScope, true) + deps, err := b.resolveColumn(name, currScope, true, true) if err != nil { return err } @@ -69,7 +69,7 @@ func (b *binder) up(cursor *sqlparser.Cursor) error { } case *sqlparser.ColName: currentScope := b.scoper.currentScope() - deps, err := b.resolveColumn(node, currentScope, false) + deps, err := b.resolveColumn(node, currentScope, false, true) if err != nil { if deps.direct.IsEmpty() || !strings.HasSuffix(err.Error(), "is ambiguous") || @@ -155,7 +155,7 @@ func (b *binder) rewriteJoinUsingColName(deps dependency, node *sqlparser.ColNam Name: sqlparser.NewIdentifierCS(alias.String()), } } - deps, err = b.resolveColumn(node, currentScope, false) + deps, err = b.resolveColumn(node, currentScope, false, true) if err != nil { return dependency{}, err } @@ -196,7 +196,7 @@ func (b *binder) setSubQueryDependencies(subq *sqlparser.Subquery, currScope *sc b.direct[subq] = subqDirectDeps.KeepOnly(tablesToKeep) } -func (b *binder) resolveColumn(colName *sqlparser.ColName, current *scope, allowMulti bool) (dependency, error) { +func (b *binder) resolveColumn(colName *sqlparser.ColName, current *scope, allowMulti, singleTableFallBack bool) (dependency, error) { var thisDeps dependencies first := true var tableName *sqlparser.TableName @@ -218,7 +218,7 @@ func (b *binder) resolveColumn(colName *sqlparser.ColName, current *scope, allow } else if err != nil { return dependency{}, err } - if current.parent == nil && len(current.tables) == 1 && first && colName.Qualifier.IsEmpty() { + if current.parent == nil && len(current.tables) == 1 && first && colName.Qualifier.IsEmpty() && singleTableFallBack { // if this is the top scope, and we still haven't been able to find a match, we know we are about to fail // we can check this last scope and see if there is a single table. if there is just one table in the scope // we assume that the column is meant to come from this table. diff --git a/go/vt/vtgate/semantics/early_rewriter.go b/go/vt/vtgate/semantics/early_rewriter.go index 68f663b29db..8bed6076d0f 100644 --- a/go/vt/vtgate/semantics/early_rewriter.go +++ b/go/vt/vtgate/semantics/early_rewriter.go @@ -33,6 +33,12 @@ type earlyRewriter struct { clause string warning string expandedColumns map[sqlparser.TableName][]*sqlparser.ColName + aliasMapCache map[*sqlparser.Select]map[string]exprContainer + + // reAnalyze is used when we are running in the late stage, after the other parts of semantic analysis + // have happened, and we are introducing or changing the AST. We invoke it so all parts of the query have been + // typed, scoped and bound correctly + reAnalyze func(n sqlparser.SQLNode) error } func (r *earlyRewriter) down(cursor *sqlparser.Cursor) error { @@ -42,15 +48,7 @@ func (r *earlyRewriter) down(cursor *sqlparser.Cursor) error { case sqlparser.SelectExprs: return r.handleSelectExprs(cursor, node) case *sqlparser.JoinTableExpr: - r.handleJoinTableExpr(node) - case sqlparser.OrderBy: - r.clause = "order clause" - iter := &orderByIterator{ - node: node, - idx: -1, - } - - return r.handleOrderByAndGroupBy(cursor.Parent(), iter) + r.handleJoinTableExprDown(node) case *sqlparser.OrExpr: rewriteOrExpr(cursor, node) case *sqlparser.NotExpr: @@ -61,13 +59,44 @@ func (r *earlyRewriter) down(cursor *sqlparser.Cursor) error { node: node, idx: -1, } - return r.handleOrderByAndGroupBy(cursor.Parent(), iter) + return r.handleGroupBy(cursor.Parent(), iter) case *sqlparser.ComparisonExpr: return handleComparisonExpr(cursor, node) } return nil } +func (r *earlyRewriter) up(cursor *sqlparser.Cursor) error { + // this rewriting is done in the `up` phase, because we need the scope to have been + // filled in with the available tables + switch node := cursor.Node().(type) { + case *sqlparser.JoinTableExpr: + return r.handleJoinTableExprUp(node) + case sqlparser.OrderBy: + r.clause = "order clause" + iter := &orderByIterator{ + node: node, + idx: -1, + r: r, + } + return r.handleOrderBy(cursor.Parent(), iter) + } + + return nil +} + +func (r *earlyRewriter) handleJoinTableExprUp(join *sqlparser.JoinTableExpr) error { + if len(join.Condition.Using) == 0 { + return nil + } + + err := rewriteJoinUsing(r.binder, join) + if err != nil { + return err + } + return r.reAnalyze(join.Condition.On) +} + func rewriteNotExpr(cursor *sqlparser.Cursor, node *sqlparser.NotExpr) { cmp, ok := node.Expr.(*sqlparser.ComparisonExpr) if !ok { @@ -83,33 +112,6 @@ func rewriteNotExpr(cursor *sqlparser.Cursor, node *sqlparser.NotExpr) { cursor.Replace(cmp) } -func (r *earlyRewriter) up(cursor *sqlparser.Cursor) error { - // this rewriting is done in the `up` phase, because we need the scope to have been - // filled in with the available tables - node, ok := cursor.Node().(*sqlparser.JoinTableExpr) - if !ok || len(node.Condition.Using) == 0 { - return nil - } - - err := rewriteJoinUsing(r.binder, node) - if err != nil { - return err - } - - // since the binder has already been over the join, we need to invoke it again, so it - // can bind columns to the right tables - sqlparser.Rewrite(node.Condition.On, nil, func(cursor *sqlparser.Cursor) bool { - innerErr := r.binder.up(cursor) - if innerErr == nil { - return true - } - - err = innerErr - return false - }) - return err -} - // handleWhereClause processes WHERE clauses, specifically the HAVING clause. func (r *earlyRewriter) handleWhereClause(node *sqlparser.Where, parent sqlparser.SQLNode) error { sel, ok := parent.(*sqlparser.Select) @@ -119,7 +121,7 @@ func (r *earlyRewriter) handleWhereClause(node *sqlparser.Where, parent sqlparse if node.Type != sqlparser.HavingClause { return nil } - expr, err := r.rewriteAliasesInOrderByHavingAndGroupBy(node.Expr, sel) + expr, err := r.rewriteAliasesInHavingAndGroupBy(node.Expr, sel) if err != nil { return err } @@ -137,8 +139,8 @@ func (r *earlyRewriter) handleSelectExprs(cursor *sqlparser.Cursor, node sqlpars return r.expandStar(cursor, node) } -// handleJoinTableExpr processes JOIN table expressions and handles the Straight Join type. -func (r *earlyRewriter) handleJoinTableExpr(node *sqlparser.JoinTableExpr) { +// handleJoinTableExprDown processes JOIN table expressions and handles the Straight Join type. +func (r *earlyRewriter) handleJoinTableExprDown(node *sqlparser.JoinTableExpr) { if node.Join != sqlparser.StraightJoinType { return } @@ -149,6 +151,7 @@ func (r *earlyRewriter) handleJoinTableExpr(node *sqlparser.JoinTableExpr) { type orderByIterator struct { node sqlparser.OrderBy idx int + r *earlyRewriter } func (it *orderByIterator) next() sqlparser.Expr { @@ -197,13 +200,57 @@ type iterator interface { replace(e sqlparser.Expr) error } -func (r *earlyRewriter) replaceLiteralsInOrderByGroupBy(e sqlparser.Expr, iter iterator) (bool, error) { +func (r *earlyRewriter) replaceLiteralsInOrderBy(e sqlparser.Expr, iter iterator) (bool, error) { + lit := getIntLiteral(e) + if lit == nil { + return false, nil + } + + newExpr, recheck, err := r.rewriteOrderByExpr(lit) + if err != nil { + return false, err + } + + if getIntLiteral(newExpr) == nil { + coll, ok := e.(*sqlparser.CollateExpr) + if ok { + coll.Expr = newExpr + newExpr = coll + } + } else { + // the expression is still a literal int. that means that we don't really need to sort by it. + // we'll just replace the number with a string instead, just like mysql would do in this situation + // mysql> explain select 1 as foo from user group by 1; + // + // mysql> show warnings; + // +-------+------+-----------------------------------------------------------------+ + // | Level | Code | Message | + // +-------+------+-----------------------------------------------------------------+ + // | Note | 1003 | /* select#1 */ select 1 AS `foo` from `test`.`user` group by '' | + // +-------+------+-----------------------------------------------------------------+ + newExpr = sqlparser.NewStrLiteral("") + } + + err = iter.replace(newExpr) + if err != nil { + return false, err + } + if recheck { + err = r.reAnalyze(newExpr) + } + if err != nil { + return false, err + } + return true, nil +} + +func (r *earlyRewriter) replaceLiteralsInGroupBy(e sqlparser.Expr, iter iterator) (bool, error) { lit := getIntLiteral(e) if lit == nil { return false, nil } - newExpr, err := r.rewriteOrderByExpr(lit) + newExpr, err := r.rewriteGroupByExpr(lit) if err != nil { return false, err } @@ -253,7 +300,41 @@ func getIntLiteral(e sqlparser.Expr) *sqlparser.Literal { } // handleOrderBy processes the ORDER BY clause. -func (r *earlyRewriter) handleOrderByAndGroupBy(parent sqlparser.SQLNode, iter iterator) error { +func (r *earlyRewriter) handleOrderBy(parent sqlparser.SQLNode, iter iterator) error { + stmt, ok := parent.(sqlparser.SelectStatement) + if !ok { + return nil + } + + sel := sqlparser.GetFirstSelect(stmt) + for e := iter.next(); e != nil; e = iter.next() { + lit, err := r.replaceLiteralsInOrderBy(e, iter) + if err != nil { + return err + } + if lit { + continue + } + + expr, err := r.rewriteAliasesInOrderBy(e, sel) + if err != nil { + return err + } + + if err = iter.replace(expr); err != nil { + return err + } + + if err = r.reAnalyze(expr); err != nil { + return err + } + } + + return nil +} + +// handleGroupBy processes the GROUP BY clause. +func (r *earlyRewriter) handleGroupBy(parent sqlparser.SQLNode, iter iterator) error { stmt, ok := parent.(sqlparser.SelectStatement) if !ok { return nil @@ -261,14 +342,14 @@ func (r *earlyRewriter) handleOrderByAndGroupBy(parent sqlparser.SQLNode, iter i sel := sqlparser.GetFirstSelect(stmt) for e := iter.next(); e != nil; e = iter.next() { - lit, err := r.replaceLiteralsInOrderByGroupBy(e, iter) + lit, err := r.replaceLiteralsInGroupBy(e, iter) if err != nil { return err } if lit { continue } - expr, err := r.rewriteAliasesInOrderByHavingAndGroupBy(e, sel) + expr, err := r.rewriteAliasesInHavingAndGroupBy(e, sel) if err != nil { return err } @@ -287,7 +368,7 @@ func (r *earlyRewriter) handleOrderByAndGroupBy(parent sqlparser.SQLNode, iter i // in SELECT points to that expression, not any table column. // - However, if the aliased expression is an aggregation and the column identifier in // the HAVING/ORDER BY clause is inside an aggregation function, the rule does not apply. -func (r *earlyRewriter) rewriteAliasesInOrderByHavingAndGroupBy(node sqlparser.Expr, sel *sqlparser.Select) (expr sqlparser.Expr, err error) { +func (r *earlyRewriter) rewriteAliasesInHavingAndGroupBy(node sqlparser.Expr, sel *sqlparser.Select) (expr sqlparser.Expr, err error) { type ExprContainer struct { expr sqlparser.Expr ambiguous bool @@ -379,7 +460,178 @@ func (r *earlyRewriter) rewriteAliasesInOrderByHavingAndGroupBy(node sqlparser.E return } -func (r *earlyRewriter) rewriteOrderByExpr(node *sqlparser.Literal) (sqlparser.Expr, error) { +// rewriteAliasesInOrderBy rewrites columns in the ORDER BY and HAVING clauses to use aliases +// from the SELECT expressions when applicable, following MySQL scoping rules: +// - A column identifier without a table qualifier that matches an alias introduced +// in SELECT points to that expression, not any table column. +// - However, if the aliased expression is an aggregation and the column identifier in +// the HAVING/ORDER BY clause is inside an aggregation function, the rule does not apply. +func (r *earlyRewriter) rewriteAliasesInOrderBy(node sqlparser.Expr, sel *sqlparser.Select) (expr sqlparser.Expr, err error) { + currentScope := r.scoper.currentScope() + if currentScope.isUnion { + // It is not safe to rewrite order by clauses in unions. + return node, nil + } + + aliases := r.getAliasMap(sel) + insideAggr := false + dontEnterSubquery := func(node, _ sqlparser.SQLNode) bool { + switch node.(type) { + case *sqlparser.Subquery: + return false + case sqlparser.AggrFunc: + insideAggr = true + } + + _, isSubq := node.(*sqlparser.Subquery) + return !isSubq + } + output := sqlparser.CopyOnRewrite(node, dontEnterSubquery, func(cursor *sqlparser.CopyOnWriteCursor) { + var col *sqlparser.ColName + + switch node := cursor.Node().(type) { + case sqlparser.AggrFunc: + insideAggr = false + return + case *sqlparser.ColName: + col = node + default: + return + } + + if !col.Qualifier.IsEmpty() { + // we are only interested in columns not qualified by table names + return + } + + item, found := aliases[col.Name.Lowered()] + if !found { + // if there is no matching alias, there is no rewriting needed + return + } + + topLevel := col == node + if !topLevel && r.isColumnOnTable(col, currentScope) { + // we only want to replace columns that are not coming from the table + return + } + + if item.ambiguous { + err = &AmbiguousColumnError{Column: sqlparser.String(col)} + } else if insideAggr && sqlparser.ContainsAggregation(item.expr) { + err = &InvalidUseOfGroupFunction{} + } + if err != nil { + cursor.StopTreeWalk() + return + } + + cursor.Replace(sqlparser.CloneExpr(item.expr)) + }, nil) + + expr = output.(sqlparser.Expr) + return +} + +func (r *earlyRewriter) isColumnOnTable(col *sqlparser.ColName, currentScope *scope) bool { + if !currentScope.stmtScope && currentScope.parent != nil { + currentScope = currentScope.parent + } + _, err := r.binder.resolveColumn(col, currentScope, false, false) + return err == nil +} + +func (r *earlyRewriter) getAliasMap(sel *sqlparser.Select) (aliases map[string]exprContainer) { + var found bool + aliases, found = r.aliasMapCache[sel] + if found { + return + } + aliases = map[string]exprContainer{} + for _, e := range sel.SelectExprs { + ae, ok := e.(*sqlparser.AliasedExpr) + if !ok { + continue + } + + var alias string + + item := exprContainer{expr: ae.Expr} + if !ae.As.IsEmpty() { + alias = ae.As.Lowered() + } else if col, ok := ae.Expr.(*sqlparser.ColName); ok { + alias = col.Name.Lowered() + } + + if old, alreadyExists := aliases[alias]; alreadyExists && !sqlparser.Equals.Expr(old.expr, item.expr) { + item.ambiguous = true + } + + aliases[alias] = item + } + return aliases +} + +type exprContainer struct { + expr sqlparser.Expr + ambiguous bool +} + +func (r *earlyRewriter) rewriteOrderByExpr(node *sqlparser.Literal) (expr sqlparser.Expr, needReAnalysis bool, err error) { + scope, found := r.scoper.specialExprScopes[node] + if !found { + return node, false, nil + } + num, err := strconv.Atoi(node.Val) + if err != nil { + return nil, false, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "error parsing column number: %s", node.Val) + } + + stmt, isSel := scope.stmt.(*sqlparser.Select) + if !isSel { + return nil, false, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "error invalid statement type, expect Select, got: %T", scope.stmt) + } + + if num < 1 || num > len(stmt.SelectExprs) { + return nil, false, vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.BadFieldError, "Unknown column '%d' in '%s'", num, r.clause) + } + + // We loop like this instead of directly accessing the offset, to make sure there are no unexpanded `*` before + for i := 0; i < num; i++ { + if _, ok := stmt.SelectExprs[i].(*sqlparser.AliasedExpr); !ok { + return nil, false, vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "cannot use column offsets in %s when using `%s`", r.clause, sqlparser.String(stmt.SelectExprs[i])) + } + } + + colOffset := num - 1 + aliasedExpr, ok := stmt.SelectExprs[colOffset].(*sqlparser.AliasedExpr) + if !ok { + return nil, false, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "don't know how to handle %s", sqlparser.String(node)) + } + + if scope.isUnion { + colName := sqlparser.NewColName(aliasedExpr.ColumnName()) + vtabl, ok := scope.tables[0].(*vTableInfo) + if !ok { + panic("BUG: not expected") + } + + // since column names can be ambiguous here, we want to do the binding by offset and not by column name + allColExprs := vtabl.cols[colOffset] + direct, recursive, typ := r.binder.org.depsForExpr(allColExprs) + r.binder.direct[colName] = direct + r.binder.recursive[colName] = recursive + if typ != nil { + r.binder.typer.exprTypes[colName] = *typ + } + + return colName, false, nil + } + + return realCloneOfColNames(aliasedExpr.Expr, false), true, nil +} + +func (r *earlyRewriter) rewriteGroupByExpr(node *sqlparser.Literal) (sqlparser.Expr, error) { scope, found := r.scoper.specialExprScopes[node] if !found { return node, nil @@ -411,13 +663,8 @@ func (r *earlyRewriter) rewriteOrderByExpr(node *sqlparser.Literal) (sqlparser.E } if scope.isUnion { - col, isCol := aliasedExpr.Expr.(*sqlparser.ColName) - - if aliasedExpr.As.IsEmpty() && isCol { - return sqlparser.NewColName(col.Name.String()), nil - } - - return sqlparser.NewColName(aliasedExpr.ColumnName()), nil + colName := sqlparser.NewColName(aliasedExpr.ColumnName()) + return colName, nil } return realCloneOfColNames(aliasedExpr.Expr, false), nil diff --git a/go/vt/vtgate/semantics/early_rewriter_test.go b/go/vt/vtgate/semantics/early_rewriter_test.go index ffb8e441348..53ddbca3486 100644 --- a/go/vt/vtgate/semantics/early_rewriter_test.go +++ b/go/vt/vtgate/semantics/early_rewriter_test.go @@ -300,42 +300,84 @@ func TestRewriteJoinUsingColumns(t *testing.T) { } -func TestOrderByGroupByLiteral(t *testing.T) { +func TestGroupByLiteral(t *testing.T) { schemaInfo := &FakeSI{ Tables: map[string]*vindexes.Table{}, } cDB := "db" tcases := []struct { - sql string - expSQL string - expErr string + sql string + expSQL string + expDeps TableSet + expErr string }{{ - sql: "select 1 as id from t1 order by 1", - expSQL: "select 1 as id from t1 order by '' asc", + sql: "select t1.col from t1 group by 1", + expSQL: "select t1.col from t1 group by t1.col", + expDeps: TS0, }, { - sql: "select t1.col from t1 order by 1", - expSQL: "select t1.col from t1 order by t1.col asc", + sql: "select t1.col as xyz from t1 group by 1", + expSQL: "select t1.col as xyz from t1 group by t1.col", + expDeps: TS0, }, { - sql: "select t1.col from t1 order by 1.0", - expSQL: "select t1.col from t1 order by 1.0 asc", + sql: "select id from t1 group by 2", + expErr: "Unknown column '2' in 'group statement'", }, { - sql: "select t1.col from t1 order by 'fubick'", - expSQL: "select t1.col from t1 order by 'fubick' asc", + sql: "select *, id from t1 group by 2", + expErr: "cannot use column offsets in group statement when using `*`", + }} + for _, tcase := range tcases { + t.Run(tcase.sql, func(t *testing.T) { + ast, err := sqlparser.Parse(tcase.sql) + require.NoError(t, err) + selectStatement := ast.(*sqlparser.Select) + st, err := Analyze(selectStatement, cDB, schemaInfo) + if tcase.expErr == "" { + require.NoError(t, err) + assert.Equal(t, tcase.expSQL, sqlparser.String(selectStatement)) + gb := selectStatement.GroupBy + deps := st.RecursiveDeps(gb[0]) + assert.Equal(t, tcase.expDeps, deps) + } else { + require.EqualError(t, err, tcase.expErr) + } + }) + } +} + +func TestOrderByLiteral(t *testing.T) { + schemaInfo := &FakeSI{ + Tables: map[string]*vindexes.Table{}, + } + cDB := "db" + tcases := []struct { + sql string + expSQL string + expDeps TableSet + expErr string + }{{ + sql: "select 1 as id from t1 order by 1", + expSQL: "select 1 as id from t1 order by '' asc", + expDeps: NoTables, }, { - sql: "select t1.col as foo from t1 order by 1", - expSQL: "select t1.col as foo from t1 order by t1.col asc", + sql: "select t1.col from t1 order by 1", + expSQL: "select t1.col from t1 order by t1.col asc", + expDeps: TS0, }, { - sql: "select t1.col from t1 group by 1", - expSQL: "select t1.col from t1 group by t1.col", + sql: "select t1.col from t1 order by 1.0", + expSQL: "select t1.col from t1 order by 1.0 asc", + expDeps: NoTables, }, { - sql: "select t1.col as xyz from t1 group by 1", - expSQL: "select t1.col as xyz from t1 group by t1.col", + sql: "select t1.col from t1 order by 'fubick'", + expSQL: "select t1.col from t1 order by 'fubick' asc", + expDeps: NoTables, }, { - sql: "select t1.col as xyz, count(*) from t1 group by 1 order by 2", - expSQL: "select t1.col as xyz, count(*) from t1 group by t1.col order by count(*) asc", + sql: "select t1.col as foo from t1 order by 1", + expSQL: "select t1.col as foo from t1 order by t1.col asc", + expDeps: TS0, }, { - sql: "select id from t1 group by 2", - expErr: "Unknown column '2' in 'group statement'", + sql: "select t1.col as xyz, count(*) from t1 group by 1 order by 2", + expSQL: "select t1.col as xyz, count(*) from t1 group by t1.col order by count(*) asc", + expDeps: TS0, }, { sql: "select id from t1 order by 2", expErr: "Unknown column '2' in 'order clause'", @@ -343,27 +385,41 @@ func TestOrderByGroupByLiteral(t *testing.T) { sql: "select *, id from t1 order by 2", expErr: "cannot use column offsets in order clause when using `*`", }, { - sql: "select *, id from t1 group by 2", - expErr: "cannot use column offsets in group statement when using `*`", + sql: "select id from t1 order by 1 collate utf8_general_ci", + expSQL: "select id from t1 order by id collate utf8_general_ci asc", + expDeps: TS0, + }, { + sql: "select id from `user` union select 1 from dual order by 1", + expSQL: "select id from `user` union select 1 from dual order by id asc", + expDeps: TS0, }, { - sql: "select id from t1 order by 1 collate utf8_general_ci", - expSQL: "select id from t1 order by id collate utf8_general_ci asc", + sql: "select id from t1 order by 2", + expErr: "Unknown column '2' in 'order clause'", }, { - sql: "select a.id from `user` union select 1 from dual order by 1", - expSQL: "select a.id from `user` union select 1 from dual order by id asc", + sql: "select a.id, b.id from user as a, user_extra as b union select 1, 2 order by 1", + expSQL: "select a.id, b.id from `user` as a, user_extra as b union select 1, 2 from dual order by id asc", + expDeps: TS0, }, { - sql: "select a.id, b.id from user as a, user_extra as b union select 1, 2 order by 1", - expErr: "Column 'id' in field list is ambiguous", + sql: "select a.id, b.id from user as a, user_extra as b union select 1, 2 order by 2", + expSQL: "select a.id, b.id from `user` as a, user_extra as b union select 1, 2 from dual order by id asc", + expDeps: TS1, + }, { + sql: "select user.id as foo from user union select col from user_extra order by 1", + expSQL: "select `user`.id as foo from `user` union select col from user_extra order by foo asc", + expDeps: MergeTableSets(TS0, TS1), }} for _, tcase := range tcases { t.Run(tcase.sql, func(t *testing.T) { ast, err := sqlparser.Parse(tcase.sql) require.NoError(t, err) selectStatement := ast.(sqlparser.SelectStatement) - _, err = Analyze(selectStatement, cDB, schemaInfo) + st, err := Analyze(selectStatement, cDB, schemaInfo) if tcase.expErr == "" { require.NoError(t, err) assert.Equal(t, tcase.expSQL, sqlparser.String(selectStatement)) + ordering := selectStatement.GetOrderBy() + deps := st.RecursiveDeps(ordering[0].Expr) + assert.Equal(t, tcase.expDeps, deps) } else { require.EqualError(t, err, tcase.expErr) } @@ -371,7 +427,7 @@ func TestOrderByGroupByLiteral(t *testing.T) { } } -func TestHavingAndOrderByColumnName(t *testing.T) { +func TestHavingColumnName(t *testing.T) { schemaInfo := &FakeSI{ Tables: map[string]*vindexes.Table{}, } @@ -384,28 +440,97 @@ func TestHavingAndOrderByColumnName(t *testing.T) { sql: "select id, sum(foo) as sumOfFoo from t1 having sumOfFoo > 1", expSQL: "select id, sum(foo) as sumOfFoo from t1 having sum(foo) > 1", }, { + sql: "select id, sum(foo) as foo from t1 having sum(foo) > 1", + expSQL: "select id, sum(foo) as foo from t1 having sum(foo) > 1", + }, { + sql: "select foo + 2 as foo from t1 having foo = 42", + expSQL: "select foo + 2 as foo from t1 having foo + 2 = 42", + }} + for _, tcase := range tcases { + t.Run(tcase.sql, func(t *testing.T) { + ast, err := sqlparser.Parse(tcase.sql) + require.NoError(t, err) + selectStatement := ast.(sqlparser.SelectStatement) + _, err = Analyze(selectStatement, cDB, schemaInfo) + if tcase.expErr == "" { + require.NoError(t, err) + assert.Equal(t, tcase.expSQL, sqlparser.String(selectStatement)) + } else { + require.EqualError(t, err, tcase.expErr) + } + }) + } +} + +func TestOrderByColumnName(t *testing.T) { + schemaInfo := &FakeSI{ + Tables: map[string]*vindexes.Table{ + "t1": { + Keyspace: &vindexes.Keyspace{Name: "ks", Sharded: true}, + Name: sqlparser.NewIdentifierCS("t1"), + Columns: []vindexes.Column{{ + Name: sqlparser.NewIdentifierCI("id"), + Type: sqltypes.VarChar, + }, { + Name: sqlparser.NewIdentifierCI("foo"), + Type: sqltypes.VarChar, + }, { + Name: sqlparser.NewIdentifierCI("bar"), + Type: sqltypes.VarChar, + }}, + ColumnListAuthoritative: true, + }, + }, + } + cDB := "db" + tcases := []struct { + sql string + expSQL string + expErr string + }{{ sql: "select id, sum(foo) as sumOfFoo from t1 order by sumOfFoo", expSQL: "select id, sum(foo) as sumOfFoo from t1 order by sum(foo) asc", }, { - sql: "select id, sum(foo) as foo from t1 having sum(foo) > 1", - expSQL: "select id, sum(foo) as foo from t1 having sum(foo) > 1", + sql: "select id, sum(foo) as sumOfFoo from t1 order by sumOfFoo + 1", + expSQL: "select id, sum(foo) as sumOfFoo from t1 order by sum(foo) + 1 asc", + }, { + sql: "select id, sum(foo) as sumOfFoo from t1 order by abs(sumOfFoo)", + expSQL: "select id, sum(foo) as sumOfFoo from t1 order by abs(sum(foo)) asc", + }, { + sql: "select id, sum(foo) as sumOfFoo from t1 order by max(sumOfFoo)", + expErr: "Invalid use of group function", + }, { + sql: "select id, sum(foo) as foo from t1 order by foo + 1", + expSQL: "select id, sum(foo) as foo from t1 order by foo + 1 asc", + }, { + sql: "select id, sum(foo) as foo from t1 order by foo", + expSQL: "select id, sum(foo) as foo from t1 order by sum(foo) asc", }, { sql: "select id, lower(min(foo)) as foo from t1 order by min(foo)", expSQL: "select id, lower(min(foo)) as foo from t1 order by min(foo) asc", }, { - // invalid according to group by rules, but still accepted by mysql - sql: "select id, t1.bar as foo from t1 group by id order by min(foo)", - expSQL: "select id, t1.bar as foo from t1 group by id order by min(t1.bar) asc", + sql: "select id, lower(min(foo)) as foo from t1 order by foo", + expSQL: "select id, lower(min(foo)) as foo from t1 order by lower(min(foo)) asc", }, { - sql: "select foo + 2 as foo from t1 having foo = 42", - expSQL: "select foo + 2 as foo from t1 having foo + 2 = 42", + sql: "select id, lower(min(foo)) as foo from t1 order by abs(foo)", + expSQL: "select id, lower(min(foo)) as foo from t1 order by abs(foo) asc", }, { - sql: "select id, b as id, count(*) from t1 order by id", + sql: "select id, t1.bar as foo from t1 group by id order by min(foo)", + expSQL: "select id, t1.bar as foo from t1 group by id order by min(foo) asc", + }, { + sql: "select id, bar as id, count(*) from t1 order by id", expErr: "Column 'id' in field list is ambiguous", }, { sql: "select id, id, count(*) from t1 order by id", expSQL: "select id, id, count(*) from t1 order by id asc", - }} + }, { + sql: "select id, count(distinct foo) k from t1 group by id order by k", + expSQL: "select id, count(distinct foo) as k from t1 group by id order by count(distinct foo) asc", + }, { + sql: "select user.id as foo from user union select col from user_extra order by foo", + expSQL: "select `user`.id as foo from `user` union select col from user_extra order by foo asc", + }, + } for _, tcase := range tcases { t.Run(tcase.sql, func(t *testing.T) { ast, err := sqlparser.Parse(tcase.sql) diff --git a/go/vt/vtgate/semantics/errors.go b/go/vt/vtgate/semantics/errors.go index 8d0b23d7f82..6a5a7e29082 100644 --- a/go/vt/vtgate/semantics/errors.go +++ b/go/vt/vtgate/semantics/errors.go @@ -51,6 +51,7 @@ type ( AmbiguousColumnError struct{ Column string } SubqueryColumnCountError struct{ Expected int } ColumnsMissingInSchemaError struct{} + InvalidUseOfGroupFunction struct{} UnsupportedMultiTablesInUpdateError struct { ExprCount int @@ -261,3 +262,16 @@ func (e *ColumnsMissingInSchemaError) Error() string { func (e *ColumnsMissingInSchemaError) ErrorCode() vtrpcpb.Code { return vtrpcpb.Code_INVALID_ARGUMENT } + +// InvalidUserOfGroupFunction +func (InvalidUseOfGroupFunction) Error() string { + return "Invalid use of group function" +} + +func (InvalidUseOfGroupFunction) ErrorCode() vtrpcpb.Code { + return vtrpcpb.Code_INVALID_ARGUMENT +} + +func (InvalidUseOfGroupFunction) ErrorState() vterrors.State { + return vterrors.InvalidGroupFuncUse +}