From 99fd17e6b627b6974d5321126aaabafa29f05a03 Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Tue, 24 Oct 2023 10:10:34 +0200 Subject: [PATCH] test: add scoping WITH tests Signed-off-by: Andres Taylor --- go/vt/vtgate/semantics/analyzer_test.go | 110 ++++++++++++++++++++++-- 1 file changed, 102 insertions(+), 8 deletions(-) diff --git a/go/vt/vtgate/semantics/analyzer_test.go b/go/vt/vtgate/semantics/analyzer_test.go index ffd7fd31f23..61dd1ee4a9d 100644 --- a/go/vt/vtgate/semantics/analyzer_test.go +++ b/go/vt/vtgate/semantics/analyzer_test.go @@ -180,14 +180,6 @@ func TestBindingMultiTablePositive(t *testing.T) { query: "select case t.col when s.col then r.col else u.col end from t, s, r, w, u", deps: MergeTableSets(TS0, TS1, TS2, TS4), numberOfTables: 4, - // }, { - // TODO: move to subquery - // make sure that we don't let sub-query dependencies leak out by mistake - // query: "select t.col + (select 42 from s) from t", - // deps: TS0, - // }, { - // query: "select (select 42 from s where r.id = s.id) from r", - // deps: TS0 | TS1, }, { query: "select u1.a + u2.a from u1, u2", deps: MergeTableSets(TS0, TS1), @@ -990,6 +982,108 @@ func TestScopingWDerivedTables(t *testing.T) { } } +func TestScopingWithWITH(t *testing.T) { + queries := []struct { + query string + errorMessage string + recursive, direct TableSet + }{ + { + query: "with t as (select x as id from user) select id from t", + recursive: TS0, + direct: TS1, + }, { + query: "with t as (select foo as id from user) select id from t", + recursive: TS0, + direct: TS1, + }, { + query: "with c as (select x as foo from user), t as (select foo as id from c) select id from t", + recursive: TS0, + direct: TS2, + }, { + query: "with t as (select foo as id from user) select t.id from t", + recursive: TS0, + direct: TS1, + }, { + query: "select t.id2 from (select foo as id from user) as t", + errorMessage: "column 't.id2' not found", + }, { + query: "with t as (select 42 as id) select id from t", + recursive: T0, + direct: TS1, + }, { + query: "with t as (select 42 as id) select t.id from t", + recursive: T0, + direct: TS1, + }, { + query: "with t as (select 42 as id) select ks.t.id from t", + errorMessage: "column 'ks.t.id' not found", + }, { + query: "with t as (select id, id from user) select * from t", + errorMessage: "Duplicate column name 'id'", + }, { + query: "with t as (select id as baz from user) select t.baz = 1 from t", + direct: TS1, + recursive: TS0, + }, { + query: "with t as (select * from user, music) select t.id from t", + direct: TS2, + recursive: MergeTableSets(TS0, TS1), + }, { + query: "with t as (select * from user, music) select t.id from t order by t.id", + direct: TS2, + recursive: MergeTableSets(TS0, TS1), + }, { + query: "with t as (select * from user) select t.id from t join user as u on t.id = u.id", + direct: TS1, + recursive: TS0, + }, { + query: "with t as (select t1.id, t1.col1 from t1 join t2) select t.col1 from t3 ua join t", + direct: TS3, + recursive: TS1, + }, { + query: "with uu as (select id from t1) select uu.test from uu", + errorMessage: "column 'uu.test' not found", + }, { + query: "with uu as (select id as col from t1) select uu.id from uu", + errorMessage: "column 'uu.id' not found", + }, { + query: "select uu.id from (select id as col from t1) uu", + errorMessage: "column 'uu.id' not found", + }, { + query: "select uu.id from (select id from t1) as uu where exists (select * from t2 as uu where uu.id = uu.uid)", + direct: TS1, + recursive: TS0, + }, { + query: "select 1 from user uu where exists (select 1 from user where exists (select 1 from (select 1 from t1) uu where uu.user_id = uu.id))", + direct: T0, + recursive: T0, + }} + for _, query := range queries { + t.Run(query.query, func(t *testing.T) { + parse, err := sqlparser.Parse(query.query) + require.NoError(t, err) + st, err := Analyze(parse, "user", &FakeSI{ + Tables: map[string]*vindexes.Table{ + "t": {Name: sqlparser.NewIdentifierCS("t")}, + }, + }) + + switch { + case query.errorMessage != "" && err != nil: + require.EqualError(t, err, query.errorMessage) + case query.errorMessage != "": + require.EqualError(t, st.NotUnshardedErr, query.errorMessage) + default: + require.NoError(t, err) + sel := parse.(*sqlparser.Select) + assert.Equal(t, query.recursive, st.RecursiveDeps(extract(sel, 0)), "RecursiveDeps") + assert.Equal(t, query.direct, st.DirectDeps(extract(sel, 0)), "DirectDeps") + } + }) + } +} + func TestJoinPredicateDependencies(t *testing.T) { // create table t() // create table t1(id bigint)