Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for common table expressions #14321

Merged
merged 20 commits into from
Oct 31, 2023
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions go/test/endtoend/vtgate/queries/derived/cte_test.go
systay marked this conversation as resolved.
Show resolved Hide resolved
systay marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/*
Copyright 2023 The Vitess Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package misc

import (
"testing"
)

func TestCTEWithOrderByLimit(t *testing.T) {
mcmp, closer := start(t)
defer closer()

mcmp.Exec("with d as (select id,name from user order by id limit 2) select music.id from music join d on music.user_id = d.id")
}

func TestCTEAggregationOnRHS(t *testing.T) {
mcmp, closer := start(t)
defer closer()

mcmp.Exec("set sql_mode = ''")
mcmp.Exec("with d as (select id, count(*) as a from user) select d.a from music join d on music.user_id = d.id group by 1")
}

func TestCTERemoveInnerOrderBy(t *testing.T) {
mcmp, closer := start(t)
defer closer()

mcmp.Exec("with toto as (select user.id as oui, music.id as non from user join music on user.id = music.user_id order by user.name) select count(*) from toto")
}

func TestCTEWithHaving(t *testing.T) {
mcmp, closer := start(t)
defer closer()

mcmp.Exec("set sql_mode = ''")
// For the given query, we can get any id back, because we aren't grouping by it.
mcmp.AssertMatchesAnyNoCompare("with s as (select id from user having count(*) >= 1) select * from s",
"[[INT64(1)]]", "[[INT64(2)]]", "[[INT64(3)]]", "[[INT64(4)]]", "[[INT64(5)]]")
}

func TestCTEColumns(t *testing.T) {
mcmp, closer := start(t)
defer closer()

mcmp.AssertMatches(`with t(id) as (SELECT id FROM user) SELECT t.id FROM t ORDER BY t.id DESC`,
`[[INT64(5)] [INT64(4)] [INT64(3)] [INT64(2)] [INT64(1)]]`)
}
22 changes: 7 additions & 15 deletions go/test/endtoend/vtgate/queries/derived/derived_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ func start(t *testing.T) (utils.MySQLCompare, func()) {

deleteAll()

mcmp.Exec("insert into music(id, user_id) values(1,1), (2,5), (3,1), (4,2), (5,3), (6,4), (7,5)")
mcmp.Exec("insert into user(id, name) values(1,'toto'), (2,'tata'), (3,'titi'), (4,'tete'), (5,'foo')")

return mcmp, func() {
deleteAll()
mcmp.Close()
Expand All @@ -49,19 +52,13 @@ func TestDerivedTableWithOrderByLimit(t *testing.T) {
mcmp, closer := start(t)
defer closer()

mcmp.Exec("insert into music(id, user_id) values(1,1), (2,5), (3,1), (4,2), (5,3), (6,4), (7,5)")
mcmp.Exec("insert into user(id, name) values(1,'toto'), (2,'tata'), (3,'titi'), (4,'tete'), (5,'foo')")

mcmp.Exec("select /*vt+ PLANNER=Gen4 */ music.id from music join (select id,name from user order by id limit 2) as d on music.user_id = d.id")
}

func TestDerivedAggregationOnRHS(t *testing.T) {
mcmp, closer := start(t)
defer closer()

mcmp.Exec("insert into music(id, user_id) values(1,1), (2,5), (3,1), (4,2), (5,3), (6,4), (7,5)")
mcmp.Exec("insert into user(id, name) values(1,'toto'), (2,'tata'), (3,'titi'), (4,'tete'), (5,'foo')")

mcmp.Exec("set sql_mode = ''")
mcmp.Exec("select /*vt+ PLANNER=Gen4 */ d.a from music join (select id, count(*) as a from user) as d on music.user_id = d.id group by 1")
}
Expand All @@ -70,28 +67,23 @@ func TestDerivedRemoveInnerOrderBy(t *testing.T) {
mcmp, closer := start(t)
defer closer()

mcmp.Exec("insert into music(id, user_id) values(1,1), (2,5), (3,1), (4,2), (5,3), (6,4), (7,5)")
mcmp.Exec("insert into user(id, name) values(1,'toto'), (2,'tata'), (3,'titi'), (4,'tete'), (5,'foo')")

mcmp.Exec("select /*vt+ PLANNER=Gen4 */ count(*) from (select user.id as oui, music.id as non from user join music on user.id = music.user_id order by user.name) as toto")
}

func TestDerivedTableWithHaving(t *testing.T) {
mcmp, closer := start(t)
defer closer()

mcmp.Exec("insert into music(id, user_id) values(1,1), (2,5), (3,1), (4,2), (5,3), (6,4), (7,5)")
mcmp.Exec("insert into user(id, name) values(1,'toto'), (2,'tata'), (3,'titi'), (4,'tete'), (5,'foo')")

mcmp.Exec("set sql_mode = ''")
// For the given query, we can get any id back, because we aren't grouping by it.
mcmp.AssertMatchesAnyNoCompare("select /*vt+ PLANNER=Gen4 */ * from (select id from user having count(*) >= 1) s", "[[INT64(1)]]", "[[INT64(2)]]", "[[INT64(3)]]", "[[INT64(4)]]", "[[INT64(5)]]")
mcmp.AssertMatchesAnyNoCompare("select /*vt+ PLANNER=Gen4 */ * from (select id from user having count(*) >= 1) s",
"[[INT64(1)]]", "[[INT64(2)]]", "[[INT64(3)]]", "[[INT64(4)]]", "[[INT64(5)]]")
}

func TestDerivedTableColumns(t *testing.T) {
mcmp, closer := start(t)
defer closer()

mcmp.Exec("insert into user(id, name) values(1,'toto'), (2,'tata'), (3,'titi'), (4,'tete'), (5,'foo')")
mcmp.AssertMatches(`SELECT /*vt+ PLANNER=gen4 */ t.id FROM (SELECT id FROM user) AS t(id) ORDER BY t.id DESC`, `[[INT64(5)] [INT64(4)] [INT64(3)] [INT64(2)] [INT64(1)]]`)
mcmp.AssertMatches(`SELECT /*vt+ PLANNER=gen4 */ t.id FROM (SELECT id FROM user) AS t(id) ORDER BY t.id DESC`,
`[[INT64(5)] [INT64(4)] [INT64(3)] [INT64(2)] [INT64(1)]]`)
}
8 changes: 4 additions & 4 deletions go/vt/sqlparser/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ type (

// With contains the lists of common table expression and specifies if it is recursive or not
With struct {
ctes []*CommonTableExpr
CTEs []*CommonTableExpr
Recursive bool
}

Expand Down Expand Up @@ -259,12 +259,12 @@ type (
Distinct bool
StraightJoinHint bool
SQLCalcFoundRows bool
// The From field must be the first AST element of this struct so the rewriter sees it first
// The With field needs to come before the FROM clause, so any CTEs have been handled before we analyze it
With *With
From []TableExpr
Comments *ParsedComments
SelectExprs SelectExprs
Where *Where
With *With
GroupBy GroupBy
Having *Where
Windows NamedWindows
Expand Down Expand Up @@ -293,11 +293,11 @@ type (

// Union represents a UNION statement.
Union struct {
With *With
Left SelectStatement
Right SelectStatement
Distinct bool
OrderBy OrderBy
With *With
Limit *Limit
Lock Lock
Into *SelectInto
Expand Down
6 changes: 3 additions & 3 deletions go/vt/sqlparser/ast_clone.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

26 changes: 13 additions & 13 deletions go/vt/sqlparser/ast_copy_on_rewrite.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions go/vt/sqlparser/ast_equals.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 6 additions & 3 deletions go/vt/sqlparser/ast_format.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,16 +133,19 @@ func (node *Insert) Format(buf *TrackedBuffer) {

// Format formats the node.
func (node *With) Format(buf *TrackedBuffer) {
if len(node.CTEs) == 0 {
return
}
buf.astPrintf(node, "with ")

if node.Recursive {
buf.astPrintf(node, "recursive ")
}
ctesLength := len(node.ctes)
ctesLength := len(node.CTEs)
for i := 0; i < ctesLength-1; i++ {
buf.astPrintf(node, "%v, ", node.ctes[i])
buf.astPrintf(node, "%v, ", node.CTEs[i])
}
buf.astPrintf(node, "%v", node.ctes[ctesLength-1])
buf.astPrintf(node, "%v", node.CTEs[ctesLength-1])
}

// Format formats the node.
Expand Down
9 changes: 6 additions & 3 deletions go/vt/sqlparser/ast_format_fast.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading