diff --git a/go/vt/sqlparser/ast.go b/go/vt/sqlparser/ast.go index 6cc921e51be..03a1b90397e 100644 --- a/go/vt/sqlparser/ast.go +++ b/go/vt/sqlparser/ast.go @@ -293,11 +293,11 @@ type ( // Union represents a UNION statement. Union struct { + With *With Left SelectStatement Right SelectStatement Distinct bool OrderBy OrderBy - With *With Limit *Limit Lock Lock Into *SelectInto diff --git a/go/vt/sqlparser/ast_clone.go b/go/vt/sqlparser/ast_clone.go index 39860a678dd..b29b4c90047 100644 --- a/go/vt/sqlparser/ast_clone.go +++ b/go/vt/sqlparser/ast_clone.go @@ -3117,10 +3117,10 @@ func CloneRefOfUnion(n *Union) *Union { return nil } out := *n + out.With = CloneRefOfWith(n.With) out.Left = CloneSelectStatement(n.Left) out.Right = CloneSelectStatement(n.Right) out.OrderBy = CloneOrderBy(n.OrderBy) - out.With = CloneRefOfWith(n.With) out.Limit = CloneRefOfLimit(n.Limit) out.Into = CloneRefOfSelectInto(n.Into) return &out diff --git a/go/vt/sqlparser/ast_copy_on_rewrite.go b/go/vt/sqlparser/ast_copy_on_rewrite.go index 8b731f1c0e2..86dda29ebcf 100644 --- a/go/vt/sqlparser/ast_copy_on_rewrite.go +++ b/go/vt/sqlparser/ast_copy_on_rewrite.go @@ -5977,18 +5977,18 @@ func (c *cow) copyOnRewriteRefOfUnion(n *Union, parent SQLNode) (out SQLNode, ch } out = n if c.pre == nil || c.pre(n, parent) { + _With, changedWith := c.copyOnRewriteRefOfWith(n.With, n) _Left, changedLeft := c.copyOnRewriteSelectStatement(n.Left, n) _Right, changedRight := c.copyOnRewriteSelectStatement(n.Right, n) _OrderBy, changedOrderBy := c.copyOnRewriteOrderBy(n.OrderBy, n) - _With, changedWith := c.copyOnRewriteRefOfWith(n.With, n) _Limit, changedLimit := c.copyOnRewriteRefOfLimit(n.Limit, n) _Into, changedInto := c.copyOnRewriteRefOfSelectInto(n.Into, n) - if changedLeft || changedRight || changedOrderBy || changedWith || changedLimit || changedInto { + if changedWith || changedLeft || changedRight || changedOrderBy || changedLimit || changedInto { res := *n + res.With, _ = _With.(*With) res.Left, _ = _Left.(SelectStatement) res.Right, _ = _Right.(SelectStatement) res.OrderBy, _ = _OrderBy.(OrderBy) - res.With, _ = _With.(*With) res.Limit, _ = _Limit.(*Limit) res.Into, _ = _Into.(*SelectInto) out = &res diff --git a/go/vt/sqlparser/ast_equals.go b/go/vt/sqlparser/ast_equals.go index bb263c65e47..9beed3a8242 100644 --- a/go/vt/sqlparser/ast_equals.go +++ b/go/vt/sqlparser/ast_equals.go @@ -4590,10 +4590,10 @@ func (cmp *Comparator) RefOfUnion(a, b *Union) bool { return false } return a.Distinct == b.Distinct && + cmp.RefOfWith(a.With, b.With) && cmp.SelectStatement(a.Left, b.Left) && cmp.SelectStatement(a.Right, b.Right) && cmp.OrderBy(a.OrderBy, b.OrderBy) && - cmp.RefOfWith(a.With, b.With) && cmp.RefOfLimit(a.Limit, b.Limit) && a.Lock == b.Lock && cmp.RefOfSelectInto(a.Into, b.Into) diff --git a/go/vt/sqlparser/ast_rewrite.go b/go/vt/sqlparser/ast_rewrite.go index e1c5cb60e59..0121695fe8c 100644 --- a/go/vt/sqlparser/ast_rewrite.go +++ b/go/vt/sqlparser/ast_rewrite.go @@ -8594,6 +8594,11 @@ func (a *application) rewriteRefOfUnion(parent SQLNode, node *Union, replacer re return true } } + if !a.rewriteRefOfWith(node, node.With, func(newNode, parent SQLNode) { + parent.(*Union).With = newNode.(*With) + }) { + return false + } if !a.rewriteSelectStatement(node, node.Left, func(newNode, parent SQLNode) { parent.(*Union).Left = newNode.(SelectStatement) }) { @@ -8609,11 +8614,6 @@ func (a *application) rewriteRefOfUnion(parent SQLNode, node *Union, replacer re }) { return false } - if !a.rewriteRefOfWith(node, node.With, func(newNode, parent SQLNode) { - parent.(*Union).With = newNode.(*With) - }) { - return false - } if !a.rewriteRefOfLimit(node, node.Limit, func(newNode, parent SQLNode) { parent.(*Union).Limit = newNode.(*Limit) }) { diff --git a/go/vt/sqlparser/ast_visit.go b/go/vt/sqlparser/ast_visit.go index 7eb418acf46..a88d689f102 100644 --- a/go/vt/sqlparser/ast_visit.go +++ b/go/vt/sqlparser/ast_visit.go @@ -3972,6 +3972,9 @@ func VisitRefOfUnion(in *Union, f Visit) error { if cont, err := f(in); err != nil || !cont { return err } + if err := VisitRefOfWith(in.With, f); err != nil { + return err + } if err := VisitSelectStatement(in.Left, f); err != nil { return err } @@ -3981,9 +3984,6 @@ func VisitRefOfUnion(in *Union, f Visit) error { if err := VisitOrderBy(in.OrderBy, f); err != nil { return err } - if err := VisitRefOfWith(in.With, f); err != nil { - return err - } if err := VisitRefOfLimit(in.Limit, f); err != nil { return err } diff --git a/go/vt/vtgate/planbuilder/testdata/from_cases.json b/go/vt/vtgate/planbuilder/testdata/from_cases.json index b6a73f4c318..714efff23c8 100644 --- a/go/vt/vtgate/planbuilder/testdata/from_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/from_cases.json @@ -4050,5 +4050,59 @@ "comment": "select with a target destination", "query": "select * from `user[-]`.user_metadata", "plan": "VT09017: SELECT with a target destination is not allowed" + }, + { + "comment": "simple WITH query", + "query": "with x as (select * from user) select * from x", + "plan": { + "QueryType": "SELECT", + "Original": "with x as (select * from user) select * from x", + "Instructions": { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select * from (select * from `user` where 1 != 1) as x where 1 != 1", + "Query": "select * from (select * from `user`) as x", + "Table": "`user`" + }, + "TablesUsed": [ + "user.user" + ] + } + }, + { + "comment": "UNION with WITH clause", + "query": "with x as (select id, foo from user) select * from x union select * from x", + "plan": { + "QueryType": "SELECT", + "Original": "with x as (select id, foo from user) select * from x union select * from x", + "Instructions": { + "OperatorType": "Distinct", + "Collations": [ + "(0:2)", + "(1:3)" + ], + "ResultColumns": 2, + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select id, foo, weight_string(id), weight_string(foo) from (select id, foo from (select id, foo from `user` where 1 != 1) as x where 1 != 1 union select id, foo from (select id, foo from `user` where 1 != 1) as x where 1 != 1) as dt where 1 != 1", + "Query": "select id, foo, weight_string(id), weight_string(foo) from (select id, foo from (select id, foo from `user`) as x union select id, foo from (select id, foo from `user`) as x) as dt", + "Table": "`user`" + } + ] + }, + "TablesUsed": [ + "user.user" + ] + } } ] diff --git a/go/vt/vtgate/planbuilder/testdata/unsupported_cases.json b/go/vt/vtgate/planbuilder/testdata/unsupported_cases.json index ea4383db911..fe51b0a1678 100644 --- a/go/vt/vtgate/planbuilder/testdata/unsupported_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/unsupported_cases.json @@ -254,16 +254,6 @@ "query": "with x as (select * from user) update x set name = 'f'", "plan": "VT12001: unsupported: WITH expression in UPDATE statement" }, - { - "comment": "unsupported with clause in select statement", - "query": "with x as (select * from user) select * from x", - "plan": "VT12001: unsupported: WITH expression in SELECT statement" - }, - { - "comment": "unsupported with clause in union statement", - "query": "with x as (select * from user) select * from x union select * from x", - "plan": "VT12001: unsupported: WITH expression in UNION statement" - }, { "comment": "insert having subquery in row values", "query": "insert into user(id, name) values ((select 1 from user where id = 1), 'A')", diff --git a/go/vt/vtgate/semantics/early_rewriter.go b/go/vt/vtgate/semantics/early_rewriter.go index 9464276a699..c7452f3b5ba 100644 --- a/go/vt/vtgate/semantics/early_rewriter.go +++ b/go/vt/vtgate/semantics/early_rewriter.go @@ -60,24 +60,29 @@ func (r *earlyRewriter) down(cursor *sqlparser.Cursor) error { case *sqlparser.With: return r.handleWith(node) case *sqlparser.AliasedTableExpr: - tbl, ok := node.Expr.(sqlparser.TableName) - if !ok || !tbl.Qualifier.IsEmpty() { - return nil - } - scope := r.scoper.currentScope() - cte := scope.findCTE(tbl.Name.String()) - if cte == nil { - return nil - } - if node.As.IsEmpty() { - node.As = tbl.Name - } - node.Expr = &sqlparser.DerivedTable{ - Select: cte.Subquery.Select, - } - if len(cte.Columns) > 0 { - node.Columns = cte.Columns - } + return r.handleAliasedTable(node) + } + return nil +} + +func (r *earlyRewriter) handleAliasedTable(node *sqlparser.AliasedTableExpr) error { + tbl, ok := node.Expr.(sqlparser.TableName) + if !ok || !tbl.Qualifier.IsEmpty() { + return nil + } + scope := r.scoper.currentScope() + cte := scope.findCTE(tbl.Name.String()) + if cte == nil { + return nil + } + if node.As.IsEmpty() { + node.As = tbl.Name + } + node.Expr = &sqlparser.DerivedTable{ + Select: cte.Subquery.Select, + } + if len(cte.Columns) > 0 { + node.Columns = cte.Columns } return nil } diff --git a/go/vt/vtgate/semantics/scoper.go b/go/vt/vtgate/semantics/scoper.go index 2827f171b85..c3685913376 100644 --- a/go/vt/vtgate/semantics/scoper.go +++ b/go/vt/vtgate/semantics/scoper.go @@ -66,6 +66,8 @@ func (s *scoper) down(cursor *sqlparser.Cursor) error { s.pushDMLScope(node) case *sqlparser.Select: s.pushSelectScope(node) + case *sqlparser.Union: + s.pushUnionScope(node) case sqlparser.TableExpr: s.enterJoinScope(cursor) case sqlparser.SelectExprs: @@ -75,14 +77,21 @@ func (s *scoper) down(cursor *sqlparser.Cursor) error { case sqlparser.GroupBy: return s.addColumnInfoForGroupBy(cursor, node) case *sqlparser.Where: - if node.Type != sqlparser.HavingClause { - break + if node.Type == sqlparser.HavingClause { + return s.createSpecialScopePostProjection(cursor.Parent()) } - return s.createSpecialScopePostProjection(cursor.Parent()) } return nil } +func (s *scoper) pushUnionScope(union *sqlparser.Union) { + currentScope := s.currentScope() + currScope := newScope(currentScope) + currScope.stmtScope = true + currScope.stmt = union + s.push(currScope) +} + func (s *scoper) addColumnInfoForGroupBy(cursor *sqlparser.Cursor, node sqlparser.GroupBy) error { err := s.createSpecialScopePostProjection(cursor.Parent()) if err != nil {