From c26bd04f9ce2e1e88f15577ac86dc7c133341595 Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Thu, 18 Jul 2024 09:29:25 +0200 Subject: [PATCH 01/28] Semantic analysis for recursive CTEs Signed-off-by: Manan Gupta Signed-off-by: Andres Taylor --- .../vtgate/planbuilder/operators/ast_to_op.go | 22 +-- go/vt/vtgate/semantics/analyzer.go | 17 +-- go/vt/vtgate/semantics/analyzer_test.go | 28 +++- go/vt/vtgate/semantics/check_invalid.go | 4 - go/vt/vtgate/semantics/cte_table.go | 116 ++++++++++++++++ go/vt/vtgate/semantics/derived_table.go | 2 +- go/vt/vtgate/semantics/early_rewriter.go | 4 +- go/vt/vtgate/semantics/foreign_keys_test.go | 121 +++++++---------- go/vt/vtgate/semantics/scoper.go | 3 + .../{semantic_state.go => semantic_table.go} | 4 - ...c_state_test.go => semantic_table_test.go} | 0 go/vt/vtgate/semantics/table_collector.go | 128 +++++++++++------- 12 files changed, 297 insertions(+), 152 deletions(-) create mode 100644 go/vt/vtgate/semantics/cte_table.go rename go/vt/vtgate/semantics/{semantic_state.go => semantic_table.go} (99%) rename go/vt/vtgate/semantics/{semantic_state_test.go => semantic_table_test.go} (100%) diff --git a/go/vt/vtgate/planbuilder/operators/ast_to_op.go b/go/vt/vtgate/planbuilder/operators/ast_to_op.go index f017f77d6a3..6a35a0ad921 100644 --- a/go/vt/vtgate/planbuilder/operators/ast_to_op.go +++ b/go/vt/vtgate/planbuilder/operators/ast_to_op.go @@ -254,24 +254,30 @@ func getOperatorFromAliasedTableExpr(ctx *plancontext.PlanningContext, tableExpr panic(err) } - if vt, isVindex := tableInfo.(*semantics.VindexTable); isVindex { + switch tableInfo := tableInfo.(type) { + case *semantics.VindexTable: solves := tableID return &Vindex{ Table: VindexTable{ TableID: tableID, Alias: tableExpr, Table: tbl, - VTable: vt.Table.GetVindexTable(), + VTable: tableInfo.Table.GetVindexTable(), }, - Vindex: vt.Vindex, + Vindex: tableInfo.Vindex, Solved: solves, } + case *semantics.CTETable: + panic(vterrors.VT12001("recursive common table 
expression")) + case *semantics.RealTable: + qg := newQueryGraph() + isInfSchema := tableInfo.IsInfSchema() + qt := &QueryTable{Alias: tableExpr, Table: tbl, ID: tableID, IsInfSchema: isInfSchema} + qg.Tables = append(qg.Tables, qt) + return qg + default: + panic(vterrors.VT13001(fmt.Sprintf("unknown table type %T", tableInfo))) } - qg := newQueryGraph() - isInfSchema := tableInfo.IsInfSchema() - qt := &QueryTable{Alias: tableExpr, Table: tbl, ID: tableID, IsInfSchema: isInfSchema} - qg.Tables = append(qg.Tables, qt) - return qg case *sqlparser.DerivedTable: if onlyTable && tbl.Select.GetLimit() == nil { tbl.Select.SetOrderBy(nil) diff --git a/go/vt/vtgate/semantics/analyzer.go b/go/vt/vtgate/semantics/analyzer.go index 8bb7cc393fc..5be67c63436 100644 --- a/go/vt/vtgate/semantics/analyzer.go +++ b/go/vt/vtgate/semantics/analyzer.go @@ -357,7 +357,7 @@ func (a *analyzer) collationEnv() *collations.Environment { } func (a *analyzer) analyze(statement sqlparser.Statement) error { - _ = sqlparser.Rewrite(statement, nil, a.earlyUp) + _ = sqlparser.Rewrite(statement, a.earlyTables.down, a.earlyTables.up) if a.err != nil { return a.err } @@ -387,13 +387,13 @@ func (a *analyzer) reAnalyze(statement sqlparser.SQLNode) error { // canShortCut checks if we are dealing with a single unsharded keyspace and no tables that have managed foreign keys // if so, we can stop the analyzer early func (a *analyzer) canShortCut(statement sqlparser.Statement) (canShortCut bool) { - ks, _ := singleUnshardedKeyspace(a.earlyTables.Tables) - a.singleUnshardedKeyspace = ks != nil - if !a.singleUnshardedKeyspace { + if a.fullAnalysis { return false } - if a.fullAnalysis { + ks, _ := singleUnshardedKeyspace(a.earlyTables.Tables) + a.singleUnshardedKeyspace = ks != nil + if !a.singleUnshardedKeyspace { return false } @@ -424,13 +424,6 @@ func (a *analyzer) canShortCut(statement sqlparser.Statement) (canShortCut bool) return true } -// earlyUp collects tables in the query, so we can check -// if 
this a single unsharded query we are dealing with -func (a *analyzer) earlyUp(cursor *sqlparser.Cursor) bool { - a.earlyTables.up(cursor) - return true -} - func (a *analyzer) shouldContinue() bool { return a.err == nil } diff --git a/go/vt/vtgate/semantics/analyzer_test.go b/go/vt/vtgate/semantics/analyzer_test.go index 0fbf0911f3a..01c34639763 100644 --- a/go/vt/vtgate/semantics/analyzer_test.go +++ b/go/vt/vtgate/semantics/analyzer_test.go @@ -195,6 +195,31 @@ func TestBindingMultiTablePositive(t *testing.T) { } } +func TestBindingRecursiveCTEs(t *testing.T) { + type testCase struct { + query string + rdeps TableSet + ddeps TableSet + } + queries := []testCase{{ + query: "with recursive x as (select id from user union select x.id + 1 from x where x.id < 15) select t.id from x join x t;", + rdeps: MergeTableSets(TS0, TS1), // This is the user and `x` in the CTE + ddeps: TS3, // this is the t id + }, { + query: "WITH RECURSIVE user_cte AS (SELECT id, name FROM user WHERE id = 42 UNION ALL SELECT u.id, u.name FROM user u JOIN user_cte cte ON u.id = cte.id + 1 WHERE u.id = 42) SELECT id FROM user_cte", + rdeps: MergeTableSets(TS0, TS1, TS2), // This is the two uses of the user and `user_cte` in the CTE + ddeps: TS3, + }} + for _, query := range queries { + t.Run(query.query, func(t *testing.T) { + stmt, semTable := parseAndAnalyzeStrict(t, query.query, "user") + sel := stmt.(*sqlparser.Select) + assert.Equal(t, query.rdeps, semTable.RecursiveDeps(extract(sel, 0)), "recursive") + assert.Equal(t, query.ddeps, semTable.DirectDeps(extract(sel, 0)), "direct") + }) + } +} + func TestBindingMultiAliasedTablePositive(t *testing.T) { type testCase struct { query string @@ -887,9 +912,6 @@ func TestInvalidQueries(t *testing.T) { }, { sql: "select 1 from t1 where (id, id) in (select 1, 2, 3)", serr: "Operand should contain 2 column(s)", - }, { - sql: "WITH RECURSIVE cte (n) AS (SELECT 1 UNION ALL SELECT n + 1 FROM cte WHERE n < 5) SELECT * FROM cte", - serr: "VT12001: 
unsupported: recursive common table expression", }, { sql: "with x as (select 1), x as (select 1) select * from x", serr: "VT03013: not unique table/alias: 'x'", diff --git a/go/vt/vtgate/semantics/check_invalid.go b/go/vt/vtgate/semantics/check_invalid.go index a739e857c00..6509f5f5ee8 100644 --- a/go/vt/vtgate/semantics/check_invalid.go +++ b/go/vt/vtgate/semantics/check_invalid.go @@ -48,10 +48,6 @@ func (a *analyzer) checkForInvalidConstructs(cursor *sqlparser.Cursor) error { } case *sqlparser.Subquery: return a.checkSubqueryColumns(cursor.Parent(), node) - case *sqlparser.With: - if node.Recursive { - return vterrors.VT12001("recursive common table expression") - } case *sqlparser.Insert: if !a.singleUnshardedKeyspace && node.Action == sqlparser.ReplaceAct { return ShardedError{Inner: &UnsupportedConstruct{errString: "REPLACE INTO with sharded keyspace"}} diff --git a/go/vt/vtgate/semantics/cte_table.go b/go/vt/vtgate/semantics/cte_table.go new file mode 100644 index 00000000000..ea7bd514cca --- /dev/null +++ b/go/vt/vtgate/semantics/cte_table.go @@ -0,0 +1,116 @@ +/* +Copyright 2024 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package semantics + +import ( + "strings" + + vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" + "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/vt/vterrors" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "vitess.io/vitess/go/vt/vtgate/vindexes" +) + +// CTETable contains the information about the CTE table. 
+type CTETable struct { + tableName string + ASTNode *sqlparser.AliasedTableExpr + CTEDef +} + +var _ TableInfo = (*CTETable)(nil) + +func newCTETable(node *sqlparser.AliasedTableExpr, t sqlparser.TableName, cteDef CTEDef) *CTETable { + var name string + if node.As.IsEmpty() { + name = t.Name.String() + } else { + name = node.As.String() + } + return &CTETable{ + tableName: name, + ASTNode: node, + CTEDef: cteDef, + } +} + +func (cte *CTETable) Name() (sqlparser.TableName, error) { + return sqlparser.NewTableName(cte.tableName), nil +} + +func (cte *CTETable) GetVindexTable() *vindexes.Table { + return nil +} + +func (cte *CTETable) IsInfSchema() bool { + return false +} + +func (cte *CTETable) matches(name sqlparser.TableName) bool { + return cte.tableName == name.Name.String() && name.Qualifier.IsEmpty() +} + +func (cte *CTETable) authoritative() bool { + return cte.isAuthoritative +} + +func (cte *CTETable) GetAliasedTableExpr() *sqlparser.AliasedTableExpr { + return cte.ASTNode +} + +func (cte *CTETable) canShortCut() shortCut { + return canShortCut +} + +func (cte *CTETable) getColumns(bool) []ColumnInfo { + selExprs := cte.definition.GetColumns() + cols := make([]ColumnInfo, 0, len(selExprs)) + for _, selExpr := range selExprs { + ae, isAe := selExpr.(*sqlparser.AliasedExpr) + if !isAe { + panic(vterrors.VT12001("should not be called")) + } + cols = append(cols, ColumnInfo{ + Name: ae.ColumnName(), + }) + } + return cols +} + +func (cte *CTETable) dependencies(colName string, org originable) (dependencies, error) { + directDeps := org.tableSetFor(cte.ASTNode) + for _, columnInfo := range cte.getColumns(false) { + if strings.EqualFold(columnInfo.Name, colName) { + return createCertain(directDeps, cte.recursive(org), evalengine.NewUnknownType()), nil + } + } + + if cte.authoritative() { + return &nothing{}, nil + } + + return createUncertain(directDeps, cte.recursive(org)), nil +} + +func (cte *CTETable) getExprFor(s string) (sqlparser.Expr, error) { + return nil, 
vterrors.Errorf(vtrpcpb.Code_INTERNAL, "Unknown column '%s' in 'field list'", s) +} + +func (cte *CTETable) getTableSet(org originable) TableSet { + return org.tableSetFor(cte.ASTNode) +} diff --git a/go/vt/vtgate/semantics/derived_table.go b/go/vt/vtgate/semantics/derived_table.go index aabbe9f0b22..684966f8ac8 100644 --- a/go/vt/vtgate/semantics/derived_table.go +++ b/go/vt/vtgate/semantics/derived_table.go @@ -146,7 +146,7 @@ func (dt *DerivedTable) GetAliasedTableExpr() *sqlparser.AliasedTableExpr { } func (dt *DerivedTable) canShortCut() shortCut { - panic(vterrors.VT12001("should not be called")) + return canShortCut } // GetVindexTable implements the TableInfo interface diff --git a/go/vt/vtgate/semantics/early_rewriter.go b/go/vt/vtgate/semantics/early_rewriter.go index 611c91e512c..038a4405f91 100644 --- a/go/vt/vtgate/semantics/early_rewriter.go +++ b/go/vt/vtgate/semantics/early_rewriter.go @@ -57,7 +57,9 @@ func (r *earlyRewriter) down(cursor *sqlparser.Cursor) error { case *sqlparser.ComparisonExpr: return handleComparisonExpr(cursor, node) case *sqlparser.With: - return r.handleWith(node) + if !node.Recursive { + return r.handleWith(node) + } case *sqlparser.AliasedTableExpr: return r.handleAliasedTable(node) case *sqlparser.Delete: diff --git a/go/vt/vtgate/semantics/foreign_keys_test.go b/go/vt/vtgate/semantics/foreign_keys_test.go index e1c26ecf569..a46c67c9710 100644 --- a/go/vt/vtgate/semantics/foreign_keys_test.go +++ b/go/vt/vtgate/semantics/foreign_keys_test.go @@ -141,13 +141,10 @@ func TestGetAllManagedForeignKeys(t *testing.T) { { name: "Collect all foreign key constraints", fkManager: &fkManager{ - tables: &tableCollector{ - Tables: []TableInfo{ - tbl["t0"], - tbl["t1"], - &DerivedTable{}, - }, - }, + tables: makeTableCollector(nil, + tbl["t0"], + tbl["t1"], + &DerivedTable{}), si: &FakeSI{ KsForeignKeyMode: map[string]vschemapb.Keyspace_ForeignKeyMode{ "ks": vschemapb.Keyspace_managed, @@ -171,12 +168,10 @@ func 
TestGetAllManagedForeignKeys(t *testing.T) { { name: "keyspace not found in schema information", fkManager: &fkManager{ - tables: &tableCollector{ - Tables: []TableInfo{ - tbl["t2"], - tbl["t3"], - }, - }, + tables: makeTableCollector(nil, + tbl["t2"], + tbl["t3"], + ), si: &FakeSI{ KsForeignKeyMode: map[string]vschemapb.Keyspace_ForeignKeyMode{ "ks": vschemapb.Keyspace_managed, @@ -188,12 +183,9 @@ func TestGetAllManagedForeignKeys(t *testing.T) { { name: "Cyclic fk constraints error", fkManager: &fkManager{ - tables: &tableCollector{ - Tables: []TableInfo{ - tbl["t0"], tbl["t1"], - &DerivedTable{}, - }, - }, + tables: makeTableCollector(nil, + tbl["t0"], tbl["t1"], + &DerivedTable{}), si: &FakeSI{ KsForeignKeyMode: map[string]vschemapb.Keyspace_ForeignKeyMode{ "ks": vschemapb.Keyspace_managed, @@ -236,17 +228,11 @@ func TestFilterForeignKeysUsingUpdateExpressions(t *testing.T) { }, }, getError: func() error { return fmt.Errorf("ambiguous test error") }, - tables: &tableCollector{ - Tables: []TableInfo{ - tbl["t4"], - tbl["t5"], - }, - si: &FakeSI{ - KsForeignKeyMode: map[string]vschemapb.Keyspace_ForeignKeyMode{ - "ks": vschemapb.Keyspace_managed, - }, - }, - }, + tables: makeTableCollector(&FakeSI{ + KsForeignKeyMode: map[string]vschemapb.Keyspace_ForeignKeyMode{ + "ks": vschemapb.Keyspace_managed, + }}, tbl["t4"], + tbl["t5"]), } updateExprs := sqlparser.UpdateExprs{ &sqlparser.UpdateExpr{Name: cola, Expr: sqlparser.NewIntLiteral("1")}, @@ -350,12 +336,10 @@ func TestGetInvolvedForeignKeys(t *testing.T) { name: "Delete Query", stmt: &sqlparser.Delete{}, fkManager: &fkManager{ - tables: &tableCollector{ - Tables: []TableInfo{ - tbl["t0"], - tbl["t1"], - }, - }, + tables: makeTableCollector(nil, + tbl["t0"], + tbl["t1"], + ), si: &FakeSI{ KsForeignKeyMode: map[string]vschemapb.Keyspace_ForeignKeyMode{ "ks": vschemapb.Keyspace_managed, @@ -389,12 +373,10 @@ func TestGetInvolvedForeignKeys(t *testing.T) { cold: SingleTableSet(1), }, }, - tables: &tableCollector{ - 
Tables: []TableInfo{ - tbl["t4"], - tbl["t5"], - }, - }, + tables: makeTableCollector(nil, + tbl["t4"], + tbl["t5"], + ), si: &FakeSI{ KsForeignKeyMode: map[string]vschemapb.Keyspace_ForeignKeyMode{ "ks": vschemapb.Keyspace_managed, @@ -433,12 +415,10 @@ func TestGetInvolvedForeignKeys(t *testing.T) { Action: sqlparser.ReplaceAct, }, fkManager: &fkManager{ - tables: &tableCollector{ - Tables: []TableInfo{ - tbl["t0"], - tbl["t1"], - }, - }, + tables: makeTableCollector(nil, + tbl["t0"], + tbl["t1"], + ), si: &FakeSI{ KsForeignKeyMode: map[string]vschemapb.Keyspace_ForeignKeyMode{ "ks": vschemapb.Keyspace_managed, @@ -465,12 +445,9 @@ func TestGetInvolvedForeignKeys(t *testing.T) { Action: sqlparser.InsertAct, }, fkManager: &fkManager{ - tables: &tableCollector{ - Tables: []TableInfo{ - tbl["t0"], - tbl["t1"], - }, - }, + tables: makeTableCollector(nil, + tbl["t0"], + tbl["t1"]), si: &FakeSI{ KsForeignKeyMode: map[string]vschemapb.Keyspace_ForeignKeyMode{ "ks": vschemapb.Keyspace_managed, @@ -502,12 +479,9 @@ func TestGetInvolvedForeignKeys(t *testing.T) { colb: SingleTableSet(0), }, }, - tables: &tableCollector{ - Tables: []TableInfo{ - tbl["t6"], - tbl["t1"], - }, - }, + tables: makeTableCollector(nil, + tbl["t6"], + tbl["t1"]), si: &FakeSI{ KsForeignKeyMode: map[string]vschemapb.Keyspace_ForeignKeyMode{ "ks": vschemapb.Keyspace_managed, @@ -536,12 +510,9 @@ func TestGetInvolvedForeignKeys(t *testing.T) { name: "Insert error", stmt: &sqlparser.Insert{}, fkManager: &fkManager{ - tables: &tableCollector{ - Tables: []TableInfo{ - tbl["t2"], - tbl["t3"], - }, - }, + tables: makeTableCollector(nil, + tbl["t2"], + tbl["t3"]), si: &FakeSI{ KsForeignKeyMode: map[string]vschemapb.Keyspace_ForeignKeyMode{ "ks": vschemapb.Keyspace_managed, @@ -554,12 +525,9 @@ func TestGetInvolvedForeignKeys(t *testing.T) { name: "Update error", stmt: &sqlparser.Update{}, fkManager: &fkManager{ - tables: &tableCollector{ - Tables: []TableInfo{ - tbl["t2"], - tbl["t3"], - }, - }, + tables: 
makeTableCollector(nil, + tbl["t2"], + tbl["t3"]), si: &FakeSI{ KsForeignKeyMode: map[string]vschemapb.Keyspace_ForeignKeyMode{ "ks": vschemapb.Keyspace_managed, @@ -600,3 +568,12 @@ func pkInfo(parentTable *vindexes.Table, pCols []string, cCols []string) vindexe ChildColumns: sqlparser.MakeColumns(cCols...), } } + +func makeTableCollector(si SchemaInformation, tables ...TableInfo) *tableCollector { + return &tableCollector{ + earlyTableCollector: earlyTableCollector{ + Tables: tables, + si: si, + }, + } +} diff --git a/go/vt/vtgate/semantics/scoper.go b/go/vt/vtgate/semantics/scoper.go index ae3e5b7e88d..e775f4a52eb 100644 --- a/go/vt/vtgate/semantics/scoper.go +++ b/go/vt/vtgate/semantics/scoper.go @@ -240,6 +240,9 @@ func (s *scoper) up(cursor *sqlparser.Cursor) error { case sqlparser.AggrFunc: s.currentScope().inHavingAggr = false case sqlparser.TableExpr: + // inside joins and derived tables, we can only see the tables in the table/join. + // we also want the tables available in the outer query, for SELECT expressions and the WHERE clause, + // so we copy the tables from the current scope to the parent scope if isParentSelect(cursor) { curScope := s.currentScope() s.popScope() diff --git a/go/vt/vtgate/semantics/semantic_state.go b/go/vt/vtgate/semantics/semantic_table.go similarity index 99% rename from go/vt/vtgate/semantics/semantic_state.go rename to go/vt/vtgate/semantics/semantic_table.go index ac2fd9c1604..2eda2c5c29f 100644 --- a/go/vt/vtgate/semantics/semantic_state.go +++ b/go/vt/vtgate/semantics/semantic_table.go @@ -773,10 +773,6 @@ func singleUnshardedKeyspace(tableInfos []TableInfo) (ks *vindexes.Keyspace, tab } for _, table := range tableInfos { - if _, isDT := table.(*DerivedTable); isDT { - continue - } - sc := table.canShortCut() var vtbl *vindexes.Table diff --git a/go/vt/vtgate/semantics/semantic_state_test.go b/go/vt/vtgate/semantics/semantic_table_test.go similarity index 100% rename from go/vt/vtgate/semantics/semantic_state_test.go 
rename to go/vt/vtgate/semantics/semantic_table_test.go diff --git a/go/vt/vtgate/semantics/table_collector.go b/go/vt/vtgate/semantics/table_collector.go index 948edb37d47..5d94ba5e1e4 100644 --- a/go/vt/vtgate/semantics/table_collector.go +++ b/go/vt/vtgate/semantics/table_collector.go @@ -28,45 +28,76 @@ import ( "vitess.io/vitess/go/vt/vtgate/vindexes" ) -// tableCollector is responsible for gathering information about the tables listed in the FROM clause, -// and adding them to the current scope, plus keeping the global list of tables used in the query -type tableCollector struct { - Tables []TableInfo - scoper *scoper - si SchemaInformation - currentDb string - org originable - unionInfo map[*sqlparser.Union]unionInfo - done map[*sqlparser.AliasedTableExpr]TableInfo -} +type ( + // tableCollector is responsible for gathering information about the tables listed in the FROM clause, + // and adding them to the current scope, plus keeping the global list of tables used in the query + tableCollector struct { + earlyTableCollector + scoper *scoper + org originable + unionInfo map[*sqlparser.Union]unionInfo + } + + earlyTableCollector struct { + si SchemaInformation + currentDb string + Tables []TableInfo + done map[*sqlparser.AliasedTableExpr]TableInfo + cte map[string]CTEDef + } -type earlyTableCollector struct { - si SchemaInformation - currentDb string - Tables []TableInfo - done map[*sqlparser.AliasedTableExpr]TableInfo - withTables map[sqlparser.IdentifierCS]any + CTEDef struct { + definition sqlparser.SelectStatement + isAuthoritative bool + recursiveDeps *TableSet + } +) + +func (cte *CTEDef) recursive(org originable) (id TableSet) { + if cte.recursiveDeps != nil { + return *cte.recursiveDeps + } + + // We need to find the recursive dependencies of the CTE + // We'll do this by walking the inner query and finding all the tables + _ = sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { + ate, ok := node.(*sqlparser.AliasedTableExpr) + if 
!ok { + return true, nil + } + id = id.Merge(org.tableSetFor(ate)) + return true, nil + }, cte.definition) + return } func newEarlyTableCollector(si SchemaInformation, currentDb string) *earlyTableCollector { return &earlyTableCollector{ - si: si, - currentDb: currentDb, - done: map[*sqlparser.AliasedTableExpr]TableInfo{}, - withTables: map[sqlparser.IdentifierCS]any{}, + si: si, + currentDb: currentDb, + done: map[*sqlparser.AliasedTableExpr]TableInfo{}, + cte: map[string]CTEDef{}, } } -func (etc *earlyTableCollector) up(cursor *sqlparser.Cursor) { - switch node := cursor.Node().(type) { - case *sqlparser.AliasedTableExpr: - etc.visitAliasedTableExpr(node) - case *sqlparser.With: - for _, cte := range node.CTEs { - etc.withTables[cte.ID] = nil - } +func (etc *earlyTableCollector) down(cursor *sqlparser.Cursor) bool { + with, ok := cursor.Node().(*sqlparser.With) + if !ok { + return true + } + for _, cte := range with.CTEs { + etc.cte[cte.ID.String()] = CTEDef{definition: cte.Subquery.Select} } + return true +} +func (etc *earlyTableCollector) up(cursor *sqlparser.Cursor) bool { + ate, ok := cursor.Node().(*sqlparser.AliasedTableExpr) + if !ok { + return true + } + etc.visitAliasedTableExpr(ate) + return true } func (etc *earlyTableCollector) visitAliasedTableExpr(aet *sqlparser.AliasedTableExpr) { @@ -79,25 +110,22 @@ func (etc *earlyTableCollector) visitAliasedTableExpr(aet *sqlparser.AliasedTabl func (etc *earlyTableCollector) newTableCollector(scoper *scoper, org originable) *tableCollector { return &tableCollector{ - Tables: etc.Tables, - scoper: scoper, - si: etc.si, - currentDb: etc.currentDb, - unionInfo: map[*sqlparser.Union]unionInfo{}, - done: etc.done, - org: org, + earlyTableCollector: *etc, + scoper: scoper, + unionInfo: map[*sqlparser.Union]unionInfo{}, + org: org, } } func (etc *earlyTableCollector) handleTableName(tbl sqlparser.TableName, aet *sqlparser.AliasedTableExpr) { if tbl.Qualifier.IsEmpty() { - _, isCTE := etc.withTables[tbl.Name] + _, 
isCTE := etc.cte[tbl.Name.String()] if isCTE { // no need to handle these tables here, we wait for the late phase instead return } } - tableInfo, err := getTableInfo(aet, tbl, etc.si, etc.currentDb) + tableInfo, err := etc.getTableInfo(aet, tbl) if err != nil { // this could just be a CTE that we haven't processed, so we'll give it the benefit of the doubt for now return @@ -304,7 +332,7 @@ func (tc *tableCollector) handleTableName(node *sqlparser.AliasedTableExpr, t sq tableInfo, found = tc.done[node] if !found { - tableInfo, err = getTableInfo(node, t, tc.si, tc.currentDb) + tableInfo, err = tc.earlyTableCollector.getTableInfo(node, t) if err != nil { return err } @@ -315,12 +343,20 @@ func (tc *tableCollector) handleTableName(node *sqlparser.AliasedTableExpr, t sq return scope.addTable(tableInfo) } -func getTableInfo(node *sqlparser.AliasedTableExpr, t sqlparser.TableName, si SchemaInformation, currentDb string) (TableInfo, error) { +func (etc *earlyTableCollector) getTableInfo(node *sqlparser.AliasedTableExpr, t sqlparser.TableName) (TableInfo, error) { var tbl *vindexes.Table var vindex vindexes.Vindex + if t.Qualifier.IsEmpty() { + // CTE handling will not be used in the early table collection + cteDef, isCte := etc.cte[t.Name.String()] + if isCte { + return newCTETable(node, t, cteDef), nil + } + } + isInfSchema := sqlparser.SystemSchema(t.Qualifier.String()) var err error - tbl, vindex, _, _, _, err = si.FindTableOrVindex(t) + tbl, vindex, _, _, _, err = etc.si.FindTableOrVindex(t) if err != nil && !isInfSchema { // if we are dealing with a system table, it might not be available in the vschema, but that is OK return nil, err @@ -329,7 +365,7 @@ func getTableInfo(node *sqlparser.AliasedTableExpr, t sqlparser.TableName, si Sc tbl = newVindexTable(t.Name) } - tableInfo, err := createTable(t, node, tbl, isInfSchema, vindex, si, currentDb) + tableInfo, err := etc.createTable(t, node, tbl, isInfSchema, vindex) if err != nil { return nil, err } @@ -437,14 +473,12 
@@ func (tc *tableCollector) tableInfoFor(id TableSet) (TableInfo, error) { return tc.Tables[offset], nil } -func createTable( +func (etc *earlyTableCollector) createTable( t sqlparser.TableName, alias *sqlparser.AliasedTableExpr, tbl *vindexes.Table, isInfSchema bool, vindex vindexes.Vindex, - si SchemaInformation, - currentDb string, ) (TableInfo, error) { hint := getVindexHint(alias.Hints) @@ -458,13 +492,13 @@ func createTable( Table: tbl, VindexHint: hint, isInfSchema: isInfSchema, - collationEnv: si.Environment().CollationEnv(), + collationEnv: etc.si.Environment().CollationEnv(), } if alias.As.IsEmpty() { dbName := t.Qualifier.String() if dbName == "" { - dbName = currentDb + dbName = etc.currentDb } table.dbName = dbName From 08b10c483ba7601796623691de980711d3cb924f Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Fri, 19 Jul 2024 07:51:58 +0200 Subject: [PATCH 02/28] feat: add recurse operator and start of CTE planning Signed-off-by: Andres Taylor --- .../planbuilder/operators/SQL_builder.go | 39 ++++++++ .../vtgate/planbuilder/operators/ast_to_op.go | 39 +++++++- .../planbuilder/operators/cte_merging.go | 80 +++++++++++++++ .../planbuilder/operators/query_planning.go | 3 + go/vt/vtgate/planbuilder/operators/recurse.go | 99 +++++++++++++++++++ .../plancontext/planning_context.go | 24 +++++ .../planbuilder/testdata/from_cases.json | 44 +++++++++ .../testdata/unsupported_cases.json | 5 - go/vt/vtgate/semantics/cte_table.go | 71 +++++++++++-- go/vt/vtgate/semantics/table_collector.go | 29 +----- 10 files changed, 390 insertions(+), 43 deletions(-) create mode 100644 go/vt/vtgate/planbuilder/operators/cte_merging.go create mode 100644 go/vt/vtgate/planbuilder/operators/recurse.go diff --git a/go/vt/vtgate/planbuilder/operators/SQL_builder.go b/go/vt/vtgate/planbuilder/operators/SQL_builder.go index 08cf3c4801c..bcde782e12d 100644 --- a/go/vt/vtgate/planbuilder/operators/SQL_builder.go +++ b/go/vt/vtgate/planbuilder/operators/SQL_builder.go @@ -56,6 +56,17 @@ 
func ToSQL(ctx *plancontext.PlanningContext, op Operator) (_ sqlparser.Statement } func (qb *queryBuilder) addTable(db, tableName, alias string, tableID semantics.TableSet, hints sqlparser.IndexHints) { + if tableID.NumberOfTables() == 1 && qb.ctx.SemTable != nil { + tblInfo, err := qb.ctx.SemTable.TableInfoFor(tableID) + if err != nil { + panic(err.Error()) + } + cte, isCTE := tblInfo.(*semantics.CTETable) + if isCTE { + tableName = cte.TableName + db = "" + } + } tableExpr := sqlparser.TableName{ Name: sqlparser.NewIdentifierCS(tableName), Qualifier: sqlparser.NewIdentifierCS(db), @@ -207,6 +218,25 @@ func (qb *queryBuilder) unionWith(other *queryBuilder, distinct bool) { } } +func (qb *queryBuilder) cteWith(other *queryBuilder, name string) { + cteUnion := &sqlparser.Union{ + Left: qb.stmt.(sqlparser.SelectStatement), + Right: other.stmt.(sqlparser.SelectStatement), + } + + qb.stmt = &sqlparser.Select{ + With: &sqlparser.With{ + CTEs: []*sqlparser.CommonTableExpr{{ + ID: sqlparser.NewIdentifierCS(name), + Columns: nil, + Subquery: &sqlparser.Subquery{Select: cteUnion}, + }}, + }, + } + + qb.addTable("", name, "", "", nil) +} + type FromStatement interface { GetFrom() []sqlparser.TableExpr SetFrom([]sqlparser.TableExpr) @@ -401,6 +431,8 @@ func buildQuery(op Operator, qb *queryBuilder) { buildDelete(op, qb) case *Insert: buildDML(op, qb) + case *Recurse: + buildCTE(op, qb) default: panic(vterrors.VT13001(fmt.Sprintf("unknown operator to convert to SQL: %T", op))) } @@ -636,6 +668,13 @@ func buildHorizon(op *Horizon, qb *queryBuilder) { sqlparser.RemoveKeyspaceInCol(qb.stmt) } +func buildCTE(op *Recurse, qb *queryBuilder) { + buildQuery(op.Init, qb) + qbR := &queryBuilder{ctx: qb.ctx} + buildQuery(op.Tail, qbR) + qb.cteWith(qbR, op.Name) +} + func mergeHaving(h1, h2 *sqlparser.Where) *sqlparser.Where { switch { case h1 == nil && h2 == nil: diff --git a/go/vt/vtgate/planbuilder/operators/ast_to_op.go b/go/vt/vtgate/planbuilder/operators/ast_to_op.go index 
6a35a0ad921..a18855688c1 100644 --- a/go/vt/vtgate/planbuilder/operators/ast_to_op.go +++ b/go/vt/vtgate/planbuilder/operators/ast_to_op.go @@ -256,7 +256,6 @@ func getOperatorFromAliasedTableExpr(ctx *plancontext.PlanningContext, tableExpr switch tableInfo := tableInfo.(type) { case *semantics.VindexTable: - solves := tableID return &Vindex{ Table: VindexTable{ TableID: tableID, @@ -265,10 +264,14 @@ func getOperatorFromAliasedTableExpr(ctx *plancontext.PlanningContext, tableExpr VTable: tableInfo.Table.GetVindexTable(), }, Vindex: tableInfo.Vindex, - Solved: solves, + Solved: tableID, } case *semantics.CTETable: - panic(vterrors.VT12001("recursive common table expression")) + current := ctx.ActiveCTE() + if current != nil && current.CTEDef.Equals(tableInfo.CTEDef) { + return createDualCTETable(ctx, tableID, tableInfo) + } + return createRecursiveCTE(ctx, tableInfo) case *semantics.RealTable: qg := newQueryGraph() isInfSchema := tableInfo.IsInfSchema() @@ -298,6 +301,36 @@ func getOperatorFromAliasedTableExpr(ctx *plancontext.PlanningContext, tableExpr } } +func createDualCTETable(ctx *plancontext.PlanningContext, tableID semantics.TableSet, tableInfo *semantics.CTETable) Operator { + vschemaTable, _, _, _, _, err := ctx.VSchema.FindTableOrVindex(sqlparser.NewTableName("dual")) + if err != nil { + panic(err) + } + qtbl := &QueryTable{ + ID: tableID, + Alias: tableInfo.ASTNode, + Table: sqlparser.NewTableName("dual"), + } + return createRouteFromVSchemaTable(ctx, qtbl, vschemaTable, false, nil) +} + +func createRecursiveCTE(ctx *plancontext.PlanningContext, def *semantics.CTETable) Operator { + union, ok := def.CTEDef.Query.(*sqlparser.Union) + if !ok { + panic(vterrors.VT13001("expected UNION in recursive CTE")) + } + + init := translateQueryToOp(ctx, union.Left) + + // Push the CTE definition to the stack so that it can be used in the recursive part of the query + ctx.PushCTE(def) + tail := translateQueryToOp(ctx, union.Right) + if err := ctx.PopCTE(); err != nil 
{ + panic(err) + } + return newRecurse(def.TableName, init, tail) +} + func crossJoin(ctx *plancontext.PlanningContext, exprs sqlparser.TableExprs) Operator { var output Operator for _, tableExpr := range exprs { diff --git a/go/vt/vtgate/planbuilder/operators/cte_merging.go b/go/vt/vtgate/planbuilder/operators/cte_merging.go new file mode 100644 index 00000000000..31fe4a045c2 --- /dev/null +++ b/go/vt/vtgate/planbuilder/operators/cte_merging.go @@ -0,0 +1,80 @@ +/* +Copyright 2024 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package operators + +import ( + "vitess.io/vitess/go/vt/vtgate/engine" + "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" +) + +func tryMergeRecurse(ctx *plancontext.PlanningContext, in *Recurse) (Operator, *ApplyResult) { + op := tryMergeCTE(ctx, in.Init, in.Tail, in) + if op == nil { + return in, NoRewrite + } + + return op, Rewrote("Merged CTE") +} + +func tryMergeCTE(ctx *plancontext.PlanningContext, init, tail Operator, in *Recurse) *Route { + initRoute, tailRoute, _, routingB, a, b, sameKeyspace := prepareInputRoutes(init, tail) + if initRoute == nil || !sameKeyspace { + return nil + } + + switch { + case a == dual: + return mergeCTE(ctx, initRoute, tailRoute, routingB, in) + case a == sharded && b == sharded: + return tryMergeCTESharded(ctx, initRoute, tailRoute, in) + default: + return nil + } +} + +func tryMergeCTESharded(ctx *plancontext.PlanningContext, init, tail *Route, in *Recurse) *Route { + tblA := init.Routing.(*ShardedRouting) + tblB := tail.Routing.(*ShardedRouting) + switch tblA.RouteOpCode { + case engine.EqualUnique: + // If the two routes fully match, they can be merged together. 
+ if tblB.RouteOpCode == engine.EqualUnique { + aVdx := tblA.SelectedVindex() + bVdx := tblB.SelectedVindex() + aExpr := tblA.VindexExpressions() + bExpr := tblB.VindexExpressions() + if aVdx == bVdx && gen4ValuesEqual(ctx, aExpr, bExpr) { + return mergeCTE(ctx, init, tail, tblA, in) + } + } + } + + return nil +} + +func mergeCTE(ctx *plancontext.PlanningContext, init, tail *Route, r Routing, in *Recurse) *Route { + return &Route{ + Routing: r, + Source: &Recurse{ + Name: in.Name, + ColumnNames: in.ColumnNames, + Init: init.Source, + Tail: tail.Source, + }, + MergedWith: []*Route{tail}, + } +} diff --git a/go/vt/vtgate/planbuilder/operators/query_planning.go b/go/vt/vtgate/planbuilder/operators/query_planning.go index e88fb53edb3..2731d732156 100644 --- a/go/vt/vtgate/planbuilder/operators/query_planning.go +++ b/go/vt/vtgate/planbuilder/operators/query_planning.go @@ -102,6 +102,9 @@ func runRewriters(ctx *plancontext.PlanningContext, root Operator) Operator { return tryPushDelete(in) case *Update: return tryPushUpdate(in) + case *Recurse: + return tryMergeRecurse(ctx, in) + default: return in, NoRewrite } diff --git a/go/vt/vtgate/planbuilder/operators/recurse.go b/go/vt/vtgate/planbuilder/operators/recurse.go new file mode 100644 index 00000000000..67ddea3adc8 --- /dev/null +++ b/go/vt/vtgate/planbuilder/operators/recurse.go @@ -0,0 +1,99 @@ +/* +Copyright 2024 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package operators + +import ( + "slices" + + "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" +) + +// Recurse is used to represent a recursive CTE +type Recurse struct { + // Name is the name of the recursive CTE + Name string + + // ColumnNames is the list of column names that are sent between the two parts of the recursive CTE + ColumnNames []string + + // ColumnOffsets is the list of column offsets that are sent between the two parts of the recursive CTE + Offsets []int + + Init, Tail Operator +} + +var _ Operator = (*Recurse)(nil) + +func newRecurse(name string, init, tail Operator) *Recurse { + return &Recurse{ + Name: name, + Init: init, + Tail: tail, + } +} + +func (r *Recurse) Clone(inputs []Operator) Operator { + return &Recurse{ + Name: r.Name, + ColumnNames: slices.Clone(r.ColumnNames), + Offsets: slices.Clone(r.Offsets), + Init: inputs[0], + Tail: inputs[1], + } +} + +func (r *Recurse) Inputs() []Operator { + return []Operator{r.Init, r.Tail} +} + +func (r *Recurse) SetInputs(operators []Operator) { + r.Init = operators[0] + r.Tail = operators[1] +} + +func (r *Recurse) AddPredicate(_ *plancontext.PlanningContext, e sqlparser.Expr) Operator { + r.Tail = newFilter(r, e) + return r +} + +func (r *Recurse) AddColumn(ctx *plancontext.PlanningContext, reuseExisting bool, addToGroupBy bool, expr *sqlparser.AliasedExpr) int { + return r.Init.AddColumn(ctx, reuseExisting, addToGroupBy, expr) +} + +func (r *Recurse) AddWSColumn(*plancontext.PlanningContext, int, bool) int { + panic("implement me") +} + +func (r *Recurse) FindCol(ctx *plancontext.PlanningContext, expr sqlparser.Expr, underRoute bool) int { + return r.Init.FindCol(ctx, expr, underRoute) +} + +func (r *Recurse) GetColumns(ctx *plancontext.PlanningContext) []*sqlparser.AliasedExpr { + return r.Init.GetColumns(ctx) +} + +func (r *Recurse) GetSelectExprs(ctx *plancontext.PlanningContext) sqlparser.SelectExprs { + return r.Init.GetSelectExprs(ctx) +} 
+ +func (r *Recurse) ShortDescription() string { return "" } + +func (r *Recurse) GetOrdering(*plancontext.PlanningContext) []OrderBy { + // Recurse is a special case. It never guarantees any ordering. + return nil +} diff --git a/go/vt/vtgate/planbuilder/plancontext/planning_context.go b/go/vt/vtgate/planbuilder/plancontext/planning_context.go index 58be17febab..b56793d0e6c 100644 --- a/go/vt/vtgate/planbuilder/plancontext/planning_context.go +++ b/go/vt/vtgate/planbuilder/plancontext/planning_context.go @@ -21,6 +21,7 @@ import ( "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" + vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vtgate/engine/opcode" @@ -66,6 +67,10 @@ type PlanningContext struct { // OuterTables contains the tables that are outer to the current query // Used to set the nullable flag on the columns OuterTables semantics.TableSet + + // This is a stack of CTEs being built. It's used when we have CTEs inside CTEs, + // to remember which is the CTE currently being assembled + CurrentCTE []*semantics.CTETable } // CreatePlanningContext initializes a new PlanningContext with the given parameters. 
@@ -376,3 +381,22 @@ func (ctx *PlanningContext) ContainsAggr(e sqlparser.SQLNode) (hasAggr bool) { }, e) return } + +func (ctx *PlanningContext) PushCTE(def *semantics.CTETable) { + ctx.CurrentCTE = append(ctx.CurrentCTE, def) +} + +func (ctx *PlanningContext) PopCTE() error { + if len(ctx.CurrentCTE) == 0 { + return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "no CTE to pop") + } + ctx.CurrentCTE = ctx.CurrentCTE[:len(ctx.CurrentCTE)-1] + return nil +} + +func (ctx *PlanningContext) ActiveCTE() *semantics.CTETable { + if len(ctx.CurrentCTE) == 0 { + return nil + } + return ctx.CurrentCTE[len(ctx.CurrentCTE)-1] +} diff --git a/go/vt/vtgate/planbuilder/testdata/from_cases.json b/go/vt/vtgate/planbuilder/testdata/from_cases.json index 2e0fe429c1f..a13e1b4c69a 100644 --- a/go/vt/vtgate/planbuilder/testdata/from_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/from_cases.json @@ -4731,6 +4731,50 @@ ] } }, + { + "comment": "Merge into a single dual route", + "query": "WITH RECURSIVE cte AS (SELECT 1 as n UNION ALL SELECT n + 1 FROM cte WHERE n < 5) SELECT n FROM cte", + "plan": { + "QueryType": "SELECT", + "Original": "WITH RECURSIVE cte AS (SELECT 1 as n UNION ALL SELECT n + 1 FROM cte WHERE n < 5) SELECT n FROM cte", + "Instructions": { + "OperatorType": "Route", + "Variant": "Reference", + "Keyspace": { + "Name": "main", + "Sharded": false + }, + "FieldQuery": "with cte as (select 1 as n from dual where 1 != 1 union all select n + 1 from cte where 1 != 1) select n from cte where 1 != 1", + "Query": "with cte as (select 1 as n from dual union all select n + 1 from cte where n < 5) select n from cte", + "Table": "dual" + }, + "TablesUsed": [ + "main.dual" + ] + } + }, + { + "comment": "Recursive CTE with star projection", + "query": "WITH RECURSIVE cte (n) AS (SELECT 1 UNION ALL SELECT n + 1 FROM cte WHERE n < 5) SELECT * FROM cte", + "plan": { + "QueryType": "SELECT", + "Original": "WITH RECURSIVE cte (n) AS (SELECT 1 UNION ALL SELECT n + 1 FROM cte WHERE n < 5) SELECT 
* FROM cte", + "Instructions": { + "OperatorType": "Route", + "Variant": "Reference", + "Keyspace": { + "Name": "main", + "Sharded": false + }, + "FieldQuery": "with cte as (select 1 from dual where 1 != 1 union all select n + 1 from cte where 1 != 1) select n from cte where 1 != 1", + "Query": "with cte as (select 1 from dual union all select n + 1 from cte where n < 5) select n from cte", + "Table": "dual" + }, + "TablesUsed": [ + "main.dual" + ] + } + }, { "comment": "Cross keyspace join", "query": "select 1 from user join t1 on user.id = t1.id", diff --git a/go/vt/vtgate/planbuilder/testdata/unsupported_cases.json b/go/vt/vtgate/planbuilder/testdata/unsupported_cases.json index 0e230b3e44d..9241cec595c 100644 --- a/go/vt/vtgate/planbuilder/testdata/unsupported_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/unsupported_cases.json @@ -329,11 +329,6 @@ "query": "with user as (select aa from user where user.id=1) select ref.col from ref join user", "plan": "VT12001: unsupported: do not support CTE that use the CTE alias inside the CTE query" }, - { - "comment": "Recursive WITH", - "query": "WITH RECURSIVE cte (n) AS (SELECT 1 UNION ALL SELECT n + 1 FROM cte WHERE n < 5) SELECT * FROM cte", - "plan": "VT12001: unsupported: recursive common table expression" - }, { "comment": "Alias cannot clash with base tables", "query": "WITH user AS (SELECT col FROM user) SELECT * FROM user", diff --git a/go/vt/vtgate/semantics/cte_table.go b/go/vt/vtgate/semantics/cte_table.go index ea7bd514cca..c7b4f41f8d8 100644 --- a/go/vt/vtgate/semantics/cte_table.go +++ b/go/vt/vtgate/semantics/cte_table.go @@ -17,6 +17,7 @@ limitations under the License. package semantics import ( + "slices" "strings" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" @@ -28,7 +29,7 @@ import ( // CTETable contains the information about the CTE table. 
type CTETable struct { - tableName string + TableName string ASTNode *sqlparser.AliasedTableExpr CTEDef } @@ -42,15 +43,26 @@ func newCTETable(node *sqlparser.AliasedTableExpr, t sqlparser.TableName, cteDef } else { name = node.As.String() } + + authoritative := true + for _, expr := range cteDef.Query.GetColumns() { + _, isStar := expr.(*sqlparser.StarExpr) + if isStar { + authoritative = false + break + } + } + cteDef.isAuthoritative = authoritative + return &CTETable{ - tableName: name, + TableName: name, ASTNode: node, CTEDef: cteDef, } } func (cte *CTETable) Name() (sqlparser.TableName, error) { - return sqlparser.NewTableName(cte.tableName), nil + return sqlparser.NewTableName(cte.TableName), nil } func (cte *CTETable) GetVindexTable() *vindexes.Table { @@ -62,7 +74,7 @@ func (cte *CTETable) IsInfSchema() bool { } func (cte *CTETable) matches(name sqlparser.TableName) bool { - return cte.tableName == name.Name.String() && name.Qualifier.IsEmpty() + return cte.TableName == name.Name.String() && name.Qualifier.IsEmpty() } func (cte *CTETable) authoritative() bool { @@ -78,23 +90,28 @@ func (cte *CTETable) canShortCut() shortCut { } func (cte *CTETable) getColumns(bool) []ColumnInfo { - selExprs := cte.definition.GetColumns() + selExprs := cte.Query.GetColumns() cols := make([]ColumnInfo, 0, len(selExprs)) - for _, selExpr := range selExprs { + for i, selExpr := range selExprs { ae, isAe := selExpr.(*sqlparser.AliasedExpr) if !isAe { panic(vterrors.VT12001("should not be called")) } - cols = append(cols, ColumnInfo{ - Name: ae.ColumnName(), - }) + if len(cte.Columns) == 0 { + cols = append(cols, ColumnInfo{Name: ae.ColumnName()}) + continue + } + + // We have column aliases defined on the CTE + cols = append(cols, ColumnInfo{Name: cte.Columns[i].String()}) } return cols } func (cte *CTETable) dependencies(colName string, org originable) (dependencies, error) { directDeps := org.tableSetFor(cte.ASTNode) - for _, columnInfo := range cte.getColumns(false) { + 
columns := cte.getColumns(false) + for _, columnInfo := range columns { if strings.EqualFold(columnInfo.Name, colName) { return createCertain(directDeps, cte.recursive(org), evalengine.NewUnknownType()), nil } @@ -114,3 +131,37 @@ func (cte *CTETable) getExprFor(s string) (sqlparser.Expr, error) { func (cte *CTETable) getTableSet(org originable) TableSet { return org.tableSetFor(cte.ASTNode) } + +type CTEDef struct { + Query sqlparser.SelectStatement + isAuthoritative bool + recursiveDeps *TableSet + Columns sqlparser.Columns +} + +func (cte CTEDef) recursive(org originable) (id TableSet) { + if cte.recursiveDeps != nil { + return *cte.recursiveDeps + } + + // We need to find the recursive dependencies of the CTE + // We'll do this by walking the inner query and finding all the tables + _ = sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { + ate, ok := node.(*sqlparser.AliasedTableExpr) + if !ok { + return true, nil + } + id = id.Merge(org.tableSetFor(ate)) + return true, nil + }, cte.Query) + return +} + +func (cte CTEDef) Equals(other CTEDef) bool { + if !sqlparser.Equals.SelectStatement(cte.Query, other.Query) { + return false + } + return slices.EqualFunc(cte.Columns, other.Columns, func(a, b sqlparser.IdentifierCI) bool { + return a.Equal(b) + }) +} diff --git a/go/vt/vtgate/semantics/table_collector.go b/go/vt/vtgate/semantics/table_collector.go index 5d94ba5e1e4..84effee1e4a 100644 --- a/go/vt/vtgate/semantics/table_collector.go +++ b/go/vt/vtgate/semantics/table_collector.go @@ -45,32 +45,8 @@ type ( done map[*sqlparser.AliasedTableExpr]TableInfo cte map[string]CTEDef } - - CTEDef struct { - definition sqlparser.SelectStatement - isAuthoritative bool - recursiveDeps *TableSet - } ) -func (cte *CTEDef) recursive(org originable) (id TableSet) { - if cte.recursiveDeps != nil { - return *cte.recursiveDeps - } - - // We need to find the recursive dependencies of the CTE - // We'll do this by walking the inner query and finding all the 
tables - _ = sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { - ate, ok := node.(*sqlparser.AliasedTableExpr) - if !ok { - return true, nil - } - id = id.Merge(org.tableSetFor(ate)) - return true, nil - }, cte.definition) - return -} - func newEarlyTableCollector(si SchemaInformation, currentDb string) *earlyTableCollector { return &earlyTableCollector{ si: si, @@ -86,7 +62,10 @@ func (etc *earlyTableCollector) down(cursor *sqlparser.Cursor) bool { return true } for _, cte := range with.CTEs { - etc.cte[cte.ID.String()] = CTEDef{definition: cte.Subquery.Select} + etc.cte[cte.ID.String()] = CTEDef{ + Query: cte.Subquery.Select, + Columns: cte.Columns, + } } return true } From c32fb87b09378f6f4c67ec16b0ca8d34c826db1d Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Mon, 22 Jul 2024 16:42:51 +0200 Subject: [PATCH 03/28] first steps in adding Recurse primitive Signed-off-by: Andres Taylor --- go/vt/vtgate/engine/cached_size.go | 34 +++++ go/vt/vtgate/engine/fake_primitive_test.go | 2 +- go/vt/vtgate/engine/recurse_cte.go | 139 ++++++++++++++++++ go/vt/vtgate/engine/recurse_cte_test.go | 129 ++++++++++++++++ .../planbuilder/operators/SQL_builder.go | 4 +- .../planbuilder/operators/cte_merging.go | 10 +- .../planbuilder/operators/query_planning.go | 2 +- .../operators/{recurse.go => recurse_cte.go} | 36 ++--- 8 files changed, 329 insertions(+), 27 deletions(-) create mode 100644 go/vt/vtgate/engine/recurse_cte.go create mode 100644 go/vt/vtgate/engine/recurse_cte_test.go rename go/vt/vtgate/planbuilder/operators/{recurse.go => recurse_cte.go} (57%) diff --git a/go/vt/vtgate/engine/cached_size.go b/go/vt/vtgate/engine/cached_size.go index c05e276caa9..8c1fcc2f425 100644 --- a/go/vt/vtgate/engine/cached_size.go +++ b/go/vt/vtgate/engine/cached_size.go @@ -857,6 +857,40 @@ func (cached *Projection) CachedSize(alloc bool) int64 { } return size } + +//go:nocheckptr +func (cached *RecurseCTE) CachedSize(alloc bool) int64 { + if cached == nil { + 
return int64(0) + } + size := int64(0) + if alloc { + size += int64(48) + } + // field Init vitess.io/vitess/go/vt/vtgate/engine.Primitive + if cc, ok := cached.Init.(cachedObject); ok { + size += cc.CachedSize(true) + } + // field Recurse vitess.io/vitess/go/vt/vtgate/engine.Primitive + if cc, ok := cached.Recurse.(cachedObject); ok { + size += cc.CachedSize(true) + } + // field Vars map[string]int + if cached.Vars != nil { + size += int64(48) + hmap := reflect.ValueOf(cached.Vars) + numBuckets := int(math.Pow(2, float64((*(*uint8)(unsafe.Pointer(hmap.Pointer() + uintptr(9))))))) + numOldBuckets := (*(*uint16)(unsafe.Pointer(hmap.Pointer() + uintptr(10)))) + size += hack.RuntimeAllocSize(int64(numOldBuckets * 208)) + if len(cached.Vars) > 0 || numBuckets > 1 { + size += hack.RuntimeAllocSize(int64(numBuckets * 208)) + } + for k := range cached.Vars { + size += hack.RuntimeAllocSize(int64(len(k))) + } + } + return size +} func (cached *RenameFields) CachedSize(alloc bool) int64 { if cached == nil { return int64(0) diff --git a/go/vt/vtgate/engine/fake_primitive_test.go b/go/vt/vtgate/engine/fake_primitive_test.go index e992c2a4623..6ab54fe9e7b 100644 --- a/go/vt/vtgate/engine/fake_primitive_test.go +++ b/go/vt/vtgate/engine/fake_primitive_test.go @@ -80,7 +80,7 @@ func (f *fakePrimitive) TryExecute(ctx context.Context, vcursor VCursor, bindVar if r == nil { return nil, f.sendErr } - return r, nil + return r.Copy(), nil } func (f *fakePrimitive) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error { diff --git a/go/vt/vtgate/engine/recurse_cte.go b/go/vt/vtgate/engine/recurse_cte.go new file mode 100644 index 00000000000..f8745a3036e --- /dev/null +++ b/go/vt/vtgate/engine/recurse_cte.go @@ -0,0 +1,139 @@ +/* +Copyright 2024 The Vitess Authors. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package engine + +import ( + "context" + + "vitess.io/vitess/go/sqltypes" + querypb "vitess.io/vitess/go/vt/proto/query" +) + +// RecurseCTE is used to represent recursive CTEs +// Init is used to represent the initial query. +// It's result are then used to start the recursion on the RecurseCTE side +// The values being sent to the RecurseCTE side are stored in the Vars map - +// the key is the bindvar name and the value is the index of the column in the recursive result +type RecurseCTE struct { + Init, Recurse Primitive + + Vars map[string]int +} + +var _ Primitive = (*RecurseCTE)(nil) + +func (r *RecurseCTE) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) { + res, err := vcursor.ExecutePrimitive(ctx, r.Init, bindVars, wantfields) + if err != nil { + return nil, err + } + + // recurseRows contains the rows used in the next recursion + recurseRows := res.Rows + joinVars := make(map[string]*querypb.BindVariable) + for len(recurseRows) > 0 { + // copy over the results from the previous recursion + theseRows := recurseRows + recurseRows = nil + for _, row := range theseRows { + for k, col := range r.Vars { + joinVars[k] = sqltypes.ValueBindVariable(row[col]) + } + rresult, err := vcursor.ExecutePrimitive(ctx, r.Recurse, combineVars(bindVars, joinVars), false) + if err != nil { + return nil, err + } + recurseRows = append(recurseRows, rresult.Rows...) 
+ res.Rows = append(res.Rows, rresult.Rows...) + } + } + return res, nil +} + +func (r *RecurseCTE) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error { + if vcursor.Session().InTransaction() { + res, err := r.TryExecute(ctx, vcursor, bindVars, wantfields) + if err != nil { + return err + } + return callback(res) + } + return vcursor.StreamExecutePrimitive(ctx, r.Init, bindVars, wantfields, func(result *sqltypes.Result) error { + err := callback(result) + if err != nil { + return err + } + return r.recurse(ctx, vcursor, bindVars, result, callback) + }) +} + +func (r *RecurseCTE) recurse(ctx context.Context, vcursor VCursor, bindvars map[string]*querypb.BindVariable, result *sqltypes.Result, callback func(*sqltypes.Result) error) error { + if len(result.Rows) == 0 { + return nil + } + joinVars := make(map[string]*querypb.BindVariable) + for _, row := range result.Rows { + for k, col := range r.Vars { + joinVars[k] = sqltypes.ValueBindVariable(row[col]) + } + + err := vcursor.StreamExecutePrimitive(ctx, r.Recurse, combineVars(bindvars, joinVars), false, func(result *sqltypes.Result) error { + err := callback(result) + if err != nil { + return err + } + return r.recurse(ctx, vcursor, bindvars, result, callback) + }) + if err != nil { + return err + } + } + return nil +} + +func (r *RecurseCTE) RouteType() string { + return "RecurseCTE" +} + +func (r *RecurseCTE) GetKeyspaceName() string { + if r.Init.GetKeyspaceName() == r.Recurse.GetKeyspaceName() { + return r.Init.GetKeyspaceName() + } + return r.Init.GetKeyspaceName() + "_" + r.Recurse.GetKeyspaceName() +} + +func (r *RecurseCTE) GetTableName() string { + return r.Init.GetTableName() +} + +func (r *RecurseCTE) GetFields(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) { + return r.Init.GetFields(ctx, vcursor, bindVars) +} + +func (r *RecurseCTE) 
NeedsTransaction() bool { + return false +} + +func (r *RecurseCTE) Inputs() ([]Primitive, []map[string]any) { + return []Primitive{r.Init, r.Recurse}, nil +} + +func (r *RecurseCTE) description() PrimitiveDescription { + return PrimitiveDescription{ + OperatorType: "RecurseCTE", + } +} diff --git a/go/vt/vtgate/engine/recurse_cte_test.go b/go/vt/vtgate/engine/recurse_cte_test.go new file mode 100644 index 00000000000..674b4f3533d --- /dev/null +++ b/go/vt/vtgate/engine/recurse_cte_test.go @@ -0,0 +1,129 @@ +/* +Copyright 2024 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package engine + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + + "vitess.io/vitess/go/sqltypes" + querypb "vitess.io/vitess/go/vt/proto/query" +) + +func TestRecurseDualQuery(t *testing.T) { + // Test that the RecurseCTE primitive works as expected. 
+ // The test is testing something like this: + // WITH RECURSIVE cte AS (SELECT 1 as col1 UNION SELECT col1+1 FROM cte WHERE col1 < 5) SELECT * FROM cte; + leftPrim := &fakePrimitive{ + results: []*sqltypes.Result{ + sqltypes.MakeTestResult( + sqltypes.MakeTestFields( + "col1", + "int64", + ), + "1", + ), + }, + } + rightFields := sqltypes.MakeTestFields( + "col4", + "int64", + ) + + rightPrim := &fakePrimitive{ + results: []*sqltypes.Result{ + sqltypes.MakeTestResult( + rightFields, + "2", + ), + sqltypes.MakeTestResult( + rightFields, + "3", + ), + sqltypes.MakeTestResult( + rightFields, + "4", + ), sqltypes.MakeTestResult( + rightFields, + ), + }, + } + bv := map[string]*querypb.BindVariable{} + + cte := &RecurseCTE{ + Init: leftPrim, + Recurse: rightPrim, + Vars: map[string]int{"col1": 0}, + } + + r, err := cte.TryExecute(context.Background(), &noopVCursor{}, bv, true) + require.NoError(t, err) + + rightPrim.ExpectLog(t, []string{ + `Execute col1: type:INT64 value:"1" false`, + `Execute col1: type:INT64 value:"2" false`, + `Execute col1: type:INT64 value:"3" false`, + `Execute col1: type:INT64 value:"4" false`, + }) + + wantRes := sqltypes.MakeTestResult( + sqltypes.MakeTestFields( + "col1", + "int64", + ), + "1", + "2", + "3", + "4", + ) + expectResult(t, r, wantRes) + + // testing the streaming mode. 
+ + leftPrim.rewind() + rightPrim.rewind() + + r, err = wrapStreamExecute(cte, &noopVCursor{}, bv, true) + require.NoError(t, err) + + rightPrim.ExpectLog(t, []string{ + `StreamExecute col1: type:INT64 value:"1" false`, + `StreamExecute col1: type:INT64 value:"2" false`, + `StreamExecute col1: type:INT64 value:"3" false`, + `StreamExecute col1: type:INT64 value:"4" false`, + }) + expectResult(t, r, wantRes) + + // testing the streaming mode with transaction + + leftPrim.rewind() + rightPrim.rewind() + + r, err = wrapStreamExecute(cte, &noopVCursor{inTx: true}, bv, true) + require.NoError(t, err) + + rightPrim.ExpectLog(t, []string{ + `Execute col1: type:INT64 value:"1" false`, + `Execute col1: type:INT64 value:"2" false`, + `Execute col1: type:INT64 value:"3" false`, + `Execute col1: type:INT64 value:"4" false`, + }) + expectResult(t, r, wantRes) + +} diff --git a/go/vt/vtgate/planbuilder/operators/SQL_builder.go b/go/vt/vtgate/planbuilder/operators/SQL_builder.go index bcde782e12d..33c747a5f85 100644 --- a/go/vt/vtgate/planbuilder/operators/SQL_builder.go +++ b/go/vt/vtgate/planbuilder/operators/SQL_builder.go @@ -431,7 +431,7 @@ func buildQuery(op Operator, qb *queryBuilder) { buildDelete(op, qb) case *Insert: buildDML(op, qb) - case *Recurse: + case *RecurseCTE: buildCTE(op, qb) default: panic(vterrors.VT13001(fmt.Sprintf("unknown operator to convert to SQL: %T", op))) @@ -668,7 +668,7 @@ func buildHorizon(op *Horizon, qb *queryBuilder) { sqlparser.RemoveKeyspaceInCol(qb.stmt) } -func buildCTE(op *Recurse, qb *queryBuilder) { +func buildCTE(op *RecurseCTE, qb *queryBuilder) { buildQuery(op.Init, qb) qbR := &queryBuilder{ctx: qb.ctx} buildQuery(op.Tail, qbR) diff --git a/go/vt/vtgate/planbuilder/operators/cte_merging.go b/go/vt/vtgate/planbuilder/operators/cte_merging.go index 31fe4a045c2..7c12a85c89b 100644 --- a/go/vt/vtgate/planbuilder/operators/cte_merging.go +++ b/go/vt/vtgate/planbuilder/operators/cte_merging.go @@ -21,7 +21,7 @@ import ( 
"vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" ) -func tryMergeRecurse(ctx *plancontext.PlanningContext, in *Recurse) (Operator, *ApplyResult) { +func tryMergeRecurse(ctx *plancontext.PlanningContext, in *RecurseCTE) (Operator, *ApplyResult) { op := tryMergeCTE(ctx, in.Init, in.Tail, in) if op == nil { return in, NoRewrite @@ -30,7 +30,7 @@ func tryMergeRecurse(ctx *plancontext.PlanningContext, in *Recurse) (Operator, * return op, Rewrote("Merged CTE") } -func tryMergeCTE(ctx *plancontext.PlanningContext, init, tail Operator, in *Recurse) *Route { +func tryMergeCTE(ctx *plancontext.PlanningContext, init, tail Operator, in *RecurseCTE) *Route { initRoute, tailRoute, _, routingB, a, b, sameKeyspace := prepareInputRoutes(init, tail) if initRoute == nil || !sameKeyspace { return nil @@ -46,7 +46,7 @@ func tryMergeCTE(ctx *plancontext.PlanningContext, init, tail Operator, in *Recu } } -func tryMergeCTESharded(ctx *plancontext.PlanningContext, init, tail *Route, in *Recurse) *Route { +func tryMergeCTESharded(ctx *plancontext.PlanningContext, init, tail *Route, in *RecurseCTE) *Route { tblA := init.Routing.(*ShardedRouting) tblB := tail.Routing.(*ShardedRouting) switch tblA.RouteOpCode { @@ -66,10 +66,10 @@ func tryMergeCTESharded(ctx *plancontext.PlanningContext, init, tail *Route, in return nil } -func mergeCTE(ctx *plancontext.PlanningContext, init, tail *Route, r Routing, in *Recurse) *Route { +func mergeCTE(ctx *plancontext.PlanningContext, init, tail *Route, r Routing, in *RecurseCTE) *Route { return &Route{ Routing: r, - Source: &Recurse{ + Source: &RecurseCTE{ Name: in.Name, ColumnNames: in.ColumnNames, Init: init.Source, diff --git a/go/vt/vtgate/planbuilder/operators/query_planning.go b/go/vt/vtgate/planbuilder/operators/query_planning.go index 2731d732156..0f2445a22e7 100644 --- a/go/vt/vtgate/planbuilder/operators/query_planning.go +++ b/go/vt/vtgate/planbuilder/operators/query_planning.go @@ -102,7 +102,7 @@ func runRewriters(ctx 
*plancontext.PlanningContext, root Operator) Operator { return tryPushDelete(in) case *Update: return tryPushUpdate(in) - case *Recurse: + case *RecurseCTE: return tryMergeRecurse(ctx, in) default: diff --git a/go/vt/vtgate/planbuilder/operators/recurse.go b/go/vt/vtgate/planbuilder/operators/recurse_cte.go similarity index 57% rename from go/vt/vtgate/planbuilder/operators/recurse.go rename to go/vt/vtgate/planbuilder/operators/recurse_cte.go index 67ddea3adc8..18261e7becc 100644 --- a/go/vt/vtgate/planbuilder/operators/recurse.go +++ b/go/vt/vtgate/planbuilder/operators/recurse_cte.go @@ -23,8 +23,8 @@ import ( "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" ) -// Recurse is used to represent a recursive CTE -type Recurse struct { +// RecurseCTE is used to represent a recursive CTE +type RecurseCTE struct { // Name is the name of the recursive CTE Name string @@ -37,18 +37,18 @@ type Recurse struct { Init, Tail Operator } -var _ Operator = (*Recurse)(nil) +var _ Operator = (*RecurseCTE)(nil) -func newRecurse(name string, init, tail Operator) *Recurse { - return &Recurse{ +func newRecurse(name string, init, tail Operator) *RecurseCTE { + return &RecurseCTE{ Name: name, Init: init, Tail: tail, } } -func (r *Recurse) Clone(inputs []Operator) Operator { - return &Recurse{ +func (r *RecurseCTE) Clone(inputs []Operator) Operator { + return &RecurseCTE{ Name: r.Name, ColumnNames: slices.Clone(r.ColumnNames), Offsets: slices.Clone(r.Offsets), @@ -57,43 +57,43 @@ func (r *Recurse) Clone(inputs []Operator) Operator { } } -func (r *Recurse) Inputs() []Operator { +func (r *RecurseCTE) Inputs() []Operator { return []Operator{r.Init, r.Tail} } -func (r *Recurse) SetInputs(operators []Operator) { +func (r *RecurseCTE) SetInputs(operators []Operator) { r.Init = operators[0] r.Tail = operators[1] } -func (r *Recurse) AddPredicate(_ *plancontext.PlanningContext, e sqlparser.Expr) Operator { +func (r *RecurseCTE) AddPredicate(_ *plancontext.PlanningContext, e 
sqlparser.Expr) Operator { r.Tail = newFilter(r, e) return r } -func (r *Recurse) AddColumn(ctx *plancontext.PlanningContext, reuseExisting bool, addToGroupBy bool, expr *sqlparser.AliasedExpr) int { +func (r *RecurseCTE) AddColumn(ctx *plancontext.PlanningContext, reuseExisting bool, addToGroupBy bool, expr *sqlparser.AliasedExpr) int { return r.Init.AddColumn(ctx, reuseExisting, addToGroupBy, expr) } -func (r *Recurse) AddWSColumn(*plancontext.PlanningContext, int, bool) int { +func (r *RecurseCTE) AddWSColumn(*plancontext.PlanningContext, int, bool) int { panic("implement me") } -func (r *Recurse) FindCol(ctx *plancontext.PlanningContext, expr sqlparser.Expr, underRoute bool) int { +func (r *RecurseCTE) FindCol(ctx *plancontext.PlanningContext, expr sqlparser.Expr, underRoute bool) int { return r.Init.FindCol(ctx, expr, underRoute) } -func (r *Recurse) GetColumns(ctx *plancontext.PlanningContext) []*sqlparser.AliasedExpr { +func (r *RecurseCTE) GetColumns(ctx *plancontext.PlanningContext) []*sqlparser.AliasedExpr { return r.Init.GetColumns(ctx) } -func (r *Recurse) GetSelectExprs(ctx *plancontext.PlanningContext) sqlparser.SelectExprs { +func (r *RecurseCTE) GetSelectExprs(ctx *plancontext.PlanningContext) sqlparser.SelectExprs { return r.Init.GetSelectExprs(ctx) } -func (r *Recurse) ShortDescription() string { return "" } +func (r *RecurseCTE) ShortDescription() string { return "" } -func (r *Recurse) GetOrdering(*plancontext.PlanningContext) []OrderBy { - // Recurse is a special case. It never guarantees any ordering. +func (r *RecurseCTE) GetOrdering(*plancontext.PlanningContext) []OrderBy { + // RecurseCTE is a special case. It never guarantees any ordering. 
return nil } From 55dfef9215fa21edfc0a07030693953a28553cdd Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Thu, 25 Jul 2024 12:05:53 +0200 Subject: [PATCH 04/28] remove subquery from cte Signed-off-by: Andres Taylor --- go/vt/sqlparser/ast.go | 2 +- go/vt/sqlparser/ast_clone.go | 2 +- go/vt/sqlparser/ast_copy_on_rewrite.go | 4 ++-- go/vt/sqlparser/ast_equals.go | 2 +- go/vt/sqlparser/ast_format.go | 2 +- go/vt/sqlparser/ast_format_fast.go | 4 ++-- go/vt/sqlparser/ast_rewrite.go | 4 ++-- go/vt/sqlparser/ast_visit.go | 2 +- go/vt/sqlparser/cached_size.go | 8 +++++--- go/vt/sqlparser/sql.go | 2 +- go/vt/sqlparser/sql.y | 2 +- go/vt/vtgate/planbuilder/operators/SQL_builder.go | 2 +- go/vt/vtgate/semantics/early_rewriter.go | 2 +- go/vt/vtgate/semantics/scoper.go | 2 +- go/vt/vtgate/semantics/table_collector.go | 2 +- 15 files changed, 22 insertions(+), 20 deletions(-) diff --git a/go/vt/sqlparser/ast.go b/go/vt/sqlparser/ast.go index 8a2363331e9..938b9063011 100644 --- a/go/vt/sqlparser/ast.go +++ b/go/vt/sqlparser/ast.go @@ -161,7 +161,7 @@ type ( CommonTableExpr struct { ID IdentifierCS Columns Columns - Subquery *Subquery + Subquery SelectStatement } // ChangeColumn is used to change the column definition, can also rename the column in alter table command ChangeColumn struct { diff --git a/go/vt/sqlparser/ast_clone.go b/go/vt/sqlparser/ast_clone.go index 7a59832b867..f22a1790232 100644 --- a/go/vt/sqlparser/ast_clone.go +++ b/go/vt/sqlparser/ast_clone.go @@ -1022,7 +1022,7 @@ func CloneRefOfCommonTableExpr(n *CommonTableExpr) *CommonTableExpr { out := *n out.ID = CloneIdentifierCS(n.ID) out.Columns = CloneColumns(n.Columns) - out.Subquery = CloneRefOfSubquery(n.Subquery) + out.Subquery = CloneSelectStatement(n.Subquery) return &out } diff --git a/go/vt/sqlparser/ast_copy_on_rewrite.go b/go/vt/sqlparser/ast_copy_on_rewrite.go index caa00181f9e..0e329e24f31 100644 --- a/go/vt/sqlparser/ast_copy_on_rewrite.go +++ b/go/vt/sqlparser/ast_copy_on_rewrite.go @@ -1522,12 
+1522,12 @@ func (c *cow) copyOnRewriteRefOfCommonTableExpr(n *CommonTableExpr, parent SQLNo if c.pre == nil || c.pre(n, parent) { _ID, changedID := c.copyOnRewriteIdentifierCS(n.ID, n) _Columns, changedColumns := c.copyOnRewriteColumns(n.Columns, n) - _Subquery, changedSubquery := c.copyOnRewriteRefOfSubquery(n.Subquery, n) + _Subquery, changedSubquery := c.copyOnRewriteSelectStatement(n.Subquery, n) if changedID || changedColumns || changedSubquery { res := *n res.ID, _ = _ID.(IdentifierCS) res.Columns, _ = _Columns.(Columns) - res.Subquery, _ = _Subquery.(*Subquery) + res.Subquery, _ = _Subquery.(SelectStatement) out = &res if c.cloned != nil { c.cloned(n, out) diff --git a/go/vt/sqlparser/ast_equals.go b/go/vt/sqlparser/ast_equals.go index 2b391db630b..cf076d706e7 100644 --- a/go/vt/sqlparser/ast_equals.go +++ b/go/vt/sqlparser/ast_equals.go @@ -2193,7 +2193,7 @@ func (cmp *Comparator) RefOfCommonTableExpr(a, b *CommonTableExpr) bool { } return cmp.IdentifierCS(a.ID, b.ID) && cmp.Columns(a.Columns, b.Columns) && - cmp.RefOfSubquery(a.Subquery, b.Subquery) + cmp.SelectStatement(a.Subquery, b.Subquery) } // RefOfComparisonExpr does deep equals between the two objects. diff --git a/go/vt/sqlparser/ast_format.go b/go/vt/sqlparser/ast_format.go index e89da3dc270..587b32d4afe 100644 --- a/go/vt/sqlparser/ast_format.go +++ b/go/vt/sqlparser/ast_format.go @@ -167,7 +167,7 @@ func (node *With) Format(buf *TrackedBuffer) { // Format formats the node. func (node *CommonTableExpr) Format(buf *TrackedBuffer) { - buf.astPrintf(node, "%v%v as %v ", node.ID, node.Columns, node.Subquery) + buf.astPrintf(node, "%v%v as (%v) ", node.ID, node.Columns, node.Subquery) } // Format formats the node. 
diff --git a/go/vt/sqlparser/ast_format_fast.go b/go/vt/sqlparser/ast_format_fast.go index f04928c7dfa..c2b02711398 100644 --- a/go/vt/sqlparser/ast_format_fast.go +++ b/go/vt/sqlparser/ast_format_fast.go @@ -243,9 +243,9 @@ func (node *With) FormatFast(buf *TrackedBuffer) { func (node *CommonTableExpr) FormatFast(buf *TrackedBuffer) { node.ID.FormatFast(buf) node.Columns.FormatFast(buf) - buf.WriteString(" as ") + buf.WriteString(" as (") node.Subquery.FormatFast(buf) - buf.WriteByte(' ') + buf.WriteString(") ") } // FormatFast formats the node. diff --git a/go/vt/sqlparser/ast_rewrite.go b/go/vt/sqlparser/ast_rewrite.go index 0cad7237455..015c27a2cbd 100644 --- a/go/vt/sqlparser/ast_rewrite.go +++ b/go/vt/sqlparser/ast_rewrite.go @@ -1964,8 +1964,8 @@ func (a *application) rewriteRefOfCommonTableExpr(parent SQLNode, node *CommonTa }) { return false } - if !a.rewriteRefOfSubquery(node, node.Subquery, func(newNode, parent SQLNode) { - parent.(*CommonTableExpr).Subquery = newNode.(*Subquery) + if !a.rewriteSelectStatement(node, node.Subquery, func(newNode, parent SQLNode) { + parent.(*CommonTableExpr).Subquery = newNode.(SelectStatement) }) { return false } diff --git a/go/vt/sqlparser/ast_visit.go b/go/vt/sqlparser/ast_visit.go index d73ed076dbb..d33c2d1e055 100644 --- a/go/vt/sqlparser/ast_visit.go +++ b/go/vt/sqlparser/ast_visit.go @@ -1172,7 +1172,7 @@ func VisitRefOfCommonTableExpr(in *CommonTableExpr, f Visit) error { if err := VisitColumns(in.Columns, f); err != nil { return err } - if err := VisitRefOfSubquery(in.Subquery, f); err != nil { + if err := VisitSelectStatement(in.Subquery, f); err != nil { return err } return nil diff --git a/go/vt/sqlparser/cached_size.go b/go/vt/sqlparser/cached_size.go index 2110ea8be30..391e9a84ad3 100644 --- a/go/vt/sqlparser/cached_size.go +++ b/go/vt/sqlparser/cached_size.go @@ -834,7 +834,7 @@ func (cached *CommonTableExpr) CachedSize(alloc bool) int64 { } size := int64(0) if alloc { - size += int64(48) + size += 
int64(64) } // field ID vitess.io/vitess/go/vt/sqlparser.IdentifierCS size += cached.ID.CachedSize(false) @@ -845,8 +845,10 @@ func (cached *CommonTableExpr) CachedSize(alloc bool) int64 { size += elem.CachedSize(false) } } - // field Subquery *vitess.io/vitess/go/vt/sqlparser.Subquery - size += cached.Subquery.CachedSize(true) + // field Subquery vitess.io/vitess/go/vt/sqlparser.SelectStatement + if cc, ok := cached.Subquery.(cachedObject); ok { + size += cc.CachedSize(true) + } return size } func (cached *ComparisonExpr) CachedSize(alloc bool) int64 { diff --git a/go/vt/sqlparser/sql.go b/go/vt/sqlparser/sql.go index 196d020a36b..9912b19f323 100644 --- a/go/vt/sqlparser/sql.go +++ b/go/vt/sqlparser/sql.go @@ -10591,7 +10591,7 @@ yydefault: var yyLOCAL *CommonTableExpr //line sql.y:757 { - yyLOCAL = &CommonTableExpr{ID: yyDollar[1].identifierCS, Columns: yyDollar[2].columnsUnion(), Subquery: yyDollar[4].subqueryUnion()} + yyLOCAL = &CommonTableExpr{ID: yyDollar[1].identifierCS, Columns: yyDollar[2].columnsUnion(), Subquery: yyDollar[4].subqueryUnion().Select} } yyVAL.union = yyLOCAL case 54: diff --git a/go/vt/sqlparser/sql.y b/go/vt/sqlparser/sql.y index 5bef040c4f1..64ce957d2dd 100644 --- a/go/vt/sqlparser/sql.y +++ b/go/vt/sqlparser/sql.y @@ -755,7 +755,7 @@ with_list: common_table_expr: table_id column_list_opt AS subquery { - $$ = &CommonTableExpr{ID: $1, Columns: $2, Subquery: $4} + $$ = &CommonTableExpr{ID: $1, Columns: $2, Subquery: $4.Select} } query_expression_parens: diff --git a/go/vt/vtgate/planbuilder/operators/SQL_builder.go b/go/vt/vtgate/planbuilder/operators/SQL_builder.go index 33c747a5f85..0ab42e17038 100644 --- a/go/vt/vtgate/planbuilder/operators/SQL_builder.go +++ b/go/vt/vtgate/planbuilder/operators/SQL_builder.go @@ -229,7 +229,7 @@ func (qb *queryBuilder) cteWith(other *queryBuilder, name string) { CTEs: []*sqlparser.CommonTableExpr{{ ID: sqlparser.NewIdentifierCS(name), Columns: nil, - Subquery: &sqlparser.Subquery{Select: cteUnion}, + 
Subquery: cteUnion, }}, }, } diff --git a/go/vt/vtgate/semantics/early_rewriter.go b/go/vt/vtgate/semantics/early_rewriter.go index 038a4405f91..ee12765e984 100644 --- a/go/vt/vtgate/semantics/early_rewriter.go +++ b/go/vt/vtgate/semantics/early_rewriter.go @@ -146,7 +146,7 @@ func (r *earlyRewriter) handleAliasedTable(node *sqlparser.AliasedTableExpr) err node.As = tbl.Name } node.Expr = &sqlparser.DerivedTable{ - Select: cte.Subquery.Select, + Select: cte.Subquery, } if len(cte.Columns) > 0 { node.Columns = cte.Columns diff --git a/go/vt/vtgate/semantics/scoper.go b/go/vt/vtgate/semantics/scoper.go index e775f4a52eb..4b42f2622cb 100644 --- a/go/vt/vtgate/semantics/scoper.go +++ b/go/vt/vtgate/semantics/scoper.go @@ -370,7 +370,7 @@ func checkForInvalidAliasUse(cte *sqlparser.CommonTableExpr, name string) (err e } return err == nil } - _ = sqlparser.CopyOnRewrite(cte.Subquery.Select, down, nil, nil) + _ = sqlparser.CopyOnRewrite(cte.Subquery, down, nil, nil) return err } diff --git a/go/vt/vtgate/semantics/table_collector.go b/go/vt/vtgate/semantics/table_collector.go index 84effee1e4a..edb9f184268 100644 --- a/go/vt/vtgate/semantics/table_collector.go +++ b/go/vt/vtgate/semantics/table_collector.go @@ -63,7 +63,7 @@ func (etc *earlyTableCollector) down(cursor *sqlparser.Cursor) bool { } for _, cte := range with.CTEs { etc.cte[cte.ID.String()] = CTEDef{ - Query: cte.Subquery.Select, + Query: cte.Subquery, Columns: cte.Columns, } } From 09aa5d88b1e83a233da27689cfba5d141ff120c4 Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Mon, 29 Jul 2024 14:36:53 +0200 Subject: [PATCH 05/28] check for invalid CTEs Signed-off-by: Andres Taylor --- go/mysql/sqlerror/constants.go | 46 +++--- go/mysql/sqlerror/sql_error.go | 152 +++++++++--------- go/vt/sqlparser/ast_funcs.go | 14 ++ go/vt/vterrors/code.go | 4 + go/vt/vterrors/state.go | 4 + .../vtgate/planbuilder/operators/ast_to_op.go | 2 +- go/vt/vtgate/semantics/analyzer_test.go | 36 ++++- go/vt/vtgate/semantics/cte_table.go 
| 25 ++- go/vt/vtgate/semantics/real_table.go | 42 +++++ go/vt/vtgate/semantics/scoper.go | 15 +- go/vt/vtgate/semantics/table_collector.go | 93 +++++++++-- 11 files changed, 302 insertions(+), 131 deletions(-) diff --git a/go/mysql/sqlerror/constants.go b/go/mysql/sqlerror/constants.go index 15c590b92a8..bd4c188af14 100644 --- a/go/mysql/sqlerror/constants.go +++ b/go/mysql/sqlerror/constants.go @@ -255,27 +255,31 @@ const ( ERJSONValueTooBig = ErrorCode(3150) ERJSONDocumentTooDeep = ErrorCode(3157) - ERLockNowait = ErrorCode(3572) - ERRegexpStringNotTerminated = ErrorCode(3684) - ERRegexpBufferOverflow = ErrorCode(3684) - ERRegexpIllegalArgument = ErrorCode(3685) - ERRegexpIndexOutOfBounds = ErrorCode(3686) - ERRegexpInternal = ErrorCode(3687) - ERRegexpRuleSyntax = ErrorCode(3688) - ERRegexpBadEscapeSequence = ErrorCode(3689) - ERRegexpUnimplemented = ErrorCode(3690) - ERRegexpMismatchParen = ErrorCode(3691) - ERRegexpBadInterval = ErrorCode(3692) - ERRRegexpMaxLtMin = ErrorCode(3693) - ERRegexpInvalidBackRef = ErrorCode(3694) - ERRegexpLookBehindLimit = ErrorCode(3695) - ERRegexpMissingCloseBracket = ErrorCode(3696) - ERRegexpInvalidRange = ErrorCode(3697) - ERRegexpStackOverflow = ErrorCode(3698) - ERRegexpTimeOut = ErrorCode(3699) - ERRegexpPatternTooBig = ErrorCode(3700) - ERRegexpInvalidCaptureGroup = ErrorCode(3887) - ERRegexpInvalidFlag = ErrorCode(3900) + ERLockNowait = ErrorCode(3572) + ERCTERecursiveRequiresUnion = ErrorCode(3573) + ERCTERecursiveForbidsAggregation = ErrorCode(3575) + ERCTERecursiveForbiddenJoinOrder = ErrorCode(3576) + ERCTERecursiveRequiresSingleReference = ErrorCode(3577) + ERRegexpStringNotTerminated = ErrorCode(3684) + ERRegexpBufferOverflow = ErrorCode(3684) + ERRegexpIllegalArgument = ErrorCode(3685) + ERRegexpIndexOutOfBounds = ErrorCode(3686) + ERRegexpInternal = ErrorCode(3687) + ERRegexpRuleSyntax = ErrorCode(3688) + ERRegexpBadEscapeSequence = ErrorCode(3689) + ERRegexpUnimplemented = ErrorCode(3690) + ERRegexpMismatchParen 
= ErrorCode(3691) + ERRegexpBadInterval = ErrorCode(3692) + ERRRegexpMaxLtMin = ErrorCode(3693) + ERRegexpInvalidBackRef = ErrorCode(3694) + ERRegexpLookBehindLimit = ErrorCode(3695) + ERRegexpMissingCloseBracket = ErrorCode(3696) + ERRegexpInvalidRange = ErrorCode(3697) + ERRegexpStackOverflow = ErrorCode(3698) + ERRegexpTimeOut = ErrorCode(3699) + ERRegexpPatternTooBig = ErrorCode(3700) + ERRegexpInvalidCaptureGroup = ErrorCode(3887) + ERRegexpInvalidFlag = ErrorCode(3900) ERCharacterSetMismatch = ErrorCode(3995) diff --git a/go/mysql/sqlerror/sql_error.go b/go/mysql/sqlerror/sql_error.go index eaa49c2c537..603456a7ae9 100644 --- a/go/mysql/sqlerror/sql_error.go +++ b/go/mysql/sqlerror/sql_error.go @@ -172,80 +172,84 @@ type mysqlCode struct { } var stateToMysqlCode = map[vterrors.State]mysqlCode{ - vterrors.Undefined: {num: ERUnknownError, state: SSUnknownSQLState}, - vterrors.AccessDeniedError: {num: ERAccessDeniedError, state: SSAccessDeniedError}, - vterrors.BadDb: {num: ERBadDb, state: SSClientError}, - vterrors.BadFieldError: {num: ERBadFieldError, state: SSBadFieldError}, - vterrors.BadTableError: {num: ERBadTable, state: SSUnknownTable}, - vterrors.CantUseOptionHere: {num: ERCantUseOptionHere, state: SSClientError}, - vterrors.DataOutOfRange: {num: ERDataOutOfRange, state: SSDataOutOfRange}, - vterrors.DbCreateExists: {num: ERDbCreateExists, state: SSUnknownSQLState}, - vterrors.DbDropExists: {num: ERDbDropExists, state: SSUnknownSQLState}, - vterrors.DupFieldName: {num: ERDupFieldName, state: SSDupFieldName}, - vterrors.EmptyQuery: {num: EREmptyQuery, state: SSClientError}, - vterrors.IncorrectGlobalLocalVar: {num: ERIncorrectGlobalLocalVar, state: SSUnknownSQLState}, - vterrors.InnodbReadOnly: {num: ERInnodbReadOnly, state: SSUnknownSQLState}, - vterrors.LockOrActiveTransaction: {num: ERLockOrActiveTransaction, state: SSUnknownSQLState}, - vterrors.NoDB: {num: ERNoDb, state: SSNoDB}, - vterrors.NoSuchTable: {num: ERNoSuchTable, state: SSUnknownTable}, - 
vterrors.NotSupportedYet: {num: ERNotSupportedYet, state: SSClientError}, - vterrors.ForbidSchemaChange: {num: ERForbidSchemaChange, state: SSUnknownSQLState}, - vterrors.MixOfGroupFuncAndFields: {num: ERMixOfGroupFuncAndFields, state: SSClientError}, - vterrors.NetPacketTooLarge: {num: ERNetPacketTooLarge, state: SSNetError}, - vterrors.NonUniqError: {num: ERNonUniq, state: SSConstraintViolation}, - vterrors.NonUniqTable: {num: ERNonUniqTable, state: SSClientError}, - vterrors.NonUpdateableTable: {num: ERNonUpdateableTable, state: SSUnknownSQLState}, - vterrors.QueryInterrupted: {num: ERQueryInterrupted, state: SSQueryInterrupted}, - vterrors.SPDoesNotExist: {num: ERSPDoesNotExist, state: SSClientError}, - vterrors.SyntaxError: {num: ERSyntaxError, state: SSClientError}, - vterrors.UnsupportedPS: {num: ERUnsupportedPS, state: SSUnknownSQLState}, - vterrors.UnknownSystemVariable: {num: ERUnknownSystemVariable, state: SSUnknownSQLState}, - vterrors.UnknownTable: {num: ERUnknownTable, state: SSUnknownTable}, - vterrors.WrongGroupField: {num: ERWrongGroupField, state: SSClientError}, - vterrors.WrongNumberOfColumnsInSelect: {num: ERWrongNumberOfColumnsInSelect, state: SSWrongNumberOfColumns}, - vterrors.WrongTypeForVar: {num: ERWrongTypeForVar, state: SSClientError}, - vterrors.WrongValueForVar: {num: ERWrongValueForVar, state: SSClientError}, - vterrors.WrongValue: {num: ERWrongValue, state: SSUnknownSQLState}, - vterrors.WrongFieldWithGroup: {num: ERWrongFieldWithGroup, state: SSClientError}, - vterrors.ServerNotAvailable: {num: ERServerIsntAvailable, state: SSNetError}, - vterrors.CantDoThisInTransaction: {num: ERCantDoThisDuringAnTransaction, state: SSCantDoThisDuringAnTransaction}, - vterrors.RequiresPrimaryKey: {num: ERRequiresPrimaryKey, state: SSClientError}, - vterrors.RowIsReferenced2: {num: ERRowIsReferenced2, state: SSConstraintViolation}, - vterrors.NoReferencedRow2: {num: ErNoReferencedRow2, state: SSConstraintViolation}, - vterrors.NoSuchSession: {num: 
ERUnknownComError, state: SSNetError}, - vterrors.OperandColumns: {num: EROperandColumns, state: SSWrongNumberOfColumns}, - vterrors.WrongValueCountOnRow: {num: ERWrongValueCountOnRow, state: SSWrongValueCountOnRow}, - vterrors.WrongArguments: {num: ERWrongArguments, state: SSUnknownSQLState}, - vterrors.ViewWrongList: {num: ERViewWrongList, state: SSUnknownSQLState}, - vterrors.UnknownStmtHandler: {num: ERUnknownStmtHandler, state: SSUnknownSQLState}, - vterrors.KeyDoesNotExist: {num: ERKeyDoesNotExist, state: SSClientError}, - vterrors.UnknownTimeZone: {num: ERUnknownTimeZone, state: SSUnknownSQLState}, - vterrors.RegexpStringNotTerminated: {num: ERRegexpStringNotTerminated, state: SSUnknownSQLState}, - vterrors.RegexpBufferOverflow: {num: ERRegexpBufferOverflow, state: SSUnknownSQLState}, - vterrors.RegexpIllegalArgument: {num: ERRegexpIllegalArgument, state: SSUnknownSQLState}, - vterrors.RegexpIndexOutOfBounds: {num: ERRegexpIndexOutOfBounds, state: SSUnknownSQLState}, - vterrors.RegexpInternal: {num: ERRegexpInternal, state: SSUnknownSQLState}, - vterrors.RegexpRuleSyntax: {num: ERRegexpRuleSyntax, state: SSUnknownSQLState}, - vterrors.RegexpBadEscapeSequence: {num: ERRegexpBadEscapeSequence, state: SSUnknownSQLState}, - vterrors.RegexpUnimplemented: {num: ERRegexpUnimplemented, state: SSUnknownSQLState}, - vterrors.RegexpMismatchParen: {num: ERRegexpMismatchParen, state: SSUnknownSQLState}, - vterrors.RegexpBadInterval: {num: ERRegexpBadInterval, state: SSUnknownSQLState}, - vterrors.RegexpMaxLtMin: {num: ERRRegexpMaxLtMin, state: SSUnknownSQLState}, - vterrors.RegexpInvalidBackRef: {num: ERRegexpInvalidBackRef, state: SSUnknownSQLState}, - vterrors.RegexpLookBehindLimit: {num: ERRegexpLookBehindLimit, state: SSUnknownSQLState}, - vterrors.RegexpMissingCloseBracket: {num: ERRegexpMissingCloseBracket, state: SSUnknownSQLState}, - vterrors.RegexpInvalidRange: {num: ERRegexpInvalidRange, state: SSUnknownSQLState}, - vterrors.RegexpStackOverflow: {num: 
ERRegexpStackOverflow, state: SSUnknownSQLState}, - vterrors.RegexpTimeOut: {num: ERRegexpTimeOut, state: SSUnknownSQLState}, - vterrors.RegexpPatternTooBig: {num: ERRegexpPatternTooBig, state: SSUnknownSQLState}, - vterrors.RegexpInvalidFlag: {num: ERRegexpInvalidFlag, state: SSUnknownSQLState}, - vterrors.RegexpInvalidCaptureGroup: {num: ERRegexpInvalidCaptureGroup, state: SSUnknownSQLState}, - vterrors.CharacterSetMismatch: {num: ERCharacterSetMismatch, state: SSUnknownSQLState}, - vterrors.WrongParametersToNativeFct: {num: ERWrongParametersToNativeFct, state: SSUnknownSQLState}, - vterrors.KillDeniedError: {num: ERKillDenied, state: SSUnknownSQLState}, - vterrors.BadNullError: {num: ERBadNullError, state: SSConstraintViolation}, - vterrors.InvalidGroupFuncUse: {num: ERInvalidGroupFuncUse, state: SSUnknownSQLState}, - vterrors.VectorConversion: {num: ERVectorConversion, state: SSUnknownSQLState}, + vterrors.Undefined: {num: ERUnknownError, state: SSUnknownSQLState}, + vterrors.AccessDeniedError: {num: ERAccessDeniedError, state: SSAccessDeniedError}, + vterrors.BadDb: {num: ERBadDb, state: SSClientError}, + vterrors.BadFieldError: {num: ERBadFieldError, state: SSBadFieldError}, + vterrors.BadTableError: {num: ERBadTable, state: SSUnknownTable}, + vterrors.CantUseOptionHere: {num: ERCantUseOptionHere, state: SSClientError}, + vterrors.DataOutOfRange: {num: ERDataOutOfRange, state: SSDataOutOfRange}, + vterrors.DbCreateExists: {num: ERDbCreateExists, state: SSUnknownSQLState}, + vterrors.DbDropExists: {num: ERDbDropExists, state: SSUnknownSQLState}, + vterrors.DupFieldName: {num: ERDupFieldName, state: SSDupFieldName}, + vterrors.EmptyQuery: {num: EREmptyQuery, state: SSClientError}, + vterrors.IncorrectGlobalLocalVar: {num: ERIncorrectGlobalLocalVar, state: SSUnknownSQLState}, + vterrors.InnodbReadOnly: {num: ERInnodbReadOnly, state: SSUnknownSQLState}, + vterrors.LockOrActiveTransaction: {num: ERLockOrActiveTransaction, state: SSUnknownSQLState}, + 
vterrors.NoDB: {num: ERNoDb, state: SSNoDB}, + vterrors.NoSuchTable: {num: ERNoSuchTable, state: SSUnknownTable}, + vterrors.NotSupportedYet: {num: ERNotSupportedYet, state: SSClientError}, + vterrors.ForbidSchemaChange: {num: ERForbidSchemaChange, state: SSUnknownSQLState}, + vterrors.MixOfGroupFuncAndFields: {num: ERMixOfGroupFuncAndFields, state: SSClientError}, + vterrors.NetPacketTooLarge: {num: ERNetPacketTooLarge, state: SSNetError}, + vterrors.NonUniqError: {num: ERNonUniq, state: SSConstraintViolation}, + vterrors.NonUniqTable: {num: ERNonUniqTable, state: SSClientError}, + vterrors.NonUpdateableTable: {num: ERNonUpdateableTable, state: SSUnknownSQLState}, + vterrors.QueryInterrupted: {num: ERQueryInterrupted, state: SSQueryInterrupted}, + vterrors.SPDoesNotExist: {num: ERSPDoesNotExist, state: SSClientError}, + vterrors.SyntaxError: {num: ERSyntaxError, state: SSClientError}, + vterrors.UnsupportedPS: {num: ERUnsupportedPS, state: SSUnknownSQLState}, + vterrors.UnknownSystemVariable: {num: ERUnknownSystemVariable, state: SSUnknownSQLState}, + vterrors.UnknownTable: {num: ERUnknownTable, state: SSUnknownTable}, + vterrors.WrongGroupField: {num: ERWrongGroupField, state: SSClientError}, + vterrors.WrongNumberOfColumnsInSelect: {num: ERWrongNumberOfColumnsInSelect, state: SSWrongNumberOfColumns}, + vterrors.WrongTypeForVar: {num: ERWrongTypeForVar, state: SSClientError}, + vterrors.WrongValueForVar: {num: ERWrongValueForVar, state: SSClientError}, + vterrors.WrongValue: {num: ERWrongValue, state: SSUnknownSQLState}, + vterrors.WrongFieldWithGroup: {num: ERWrongFieldWithGroup, state: SSClientError}, + vterrors.ServerNotAvailable: {num: ERServerIsntAvailable, state: SSNetError}, + vterrors.CantDoThisInTransaction: {num: ERCantDoThisDuringAnTransaction, state: SSCantDoThisDuringAnTransaction}, + vterrors.RequiresPrimaryKey: {num: ERRequiresPrimaryKey, state: SSClientError}, + vterrors.RowIsReferenced2: {num: ERRowIsReferenced2, state: SSConstraintViolation}, + 
vterrors.NoReferencedRow2: {num: ErNoReferencedRow2, state: SSConstraintViolation}, + vterrors.NoSuchSession: {num: ERUnknownComError, state: SSNetError}, + vterrors.OperandColumns: {num: EROperandColumns, state: SSWrongNumberOfColumns}, + vterrors.WrongValueCountOnRow: {num: ERWrongValueCountOnRow, state: SSWrongValueCountOnRow}, + vterrors.WrongArguments: {num: ERWrongArguments, state: SSUnknownSQLState}, + vterrors.ViewWrongList: {num: ERViewWrongList, state: SSUnknownSQLState}, + vterrors.UnknownStmtHandler: {num: ERUnknownStmtHandler, state: SSUnknownSQLState}, + vterrors.KeyDoesNotExist: {num: ERKeyDoesNotExist, state: SSClientError}, + vterrors.UnknownTimeZone: {num: ERUnknownTimeZone, state: SSUnknownSQLState}, + vterrors.RegexpStringNotTerminated: {num: ERRegexpStringNotTerminated, state: SSUnknownSQLState}, + vterrors.RegexpBufferOverflow: {num: ERRegexpBufferOverflow, state: SSUnknownSQLState}, + vterrors.RegexpIllegalArgument: {num: ERRegexpIllegalArgument, state: SSUnknownSQLState}, + vterrors.RegexpIndexOutOfBounds: {num: ERRegexpIndexOutOfBounds, state: SSUnknownSQLState}, + vterrors.RegexpInternal: {num: ERRegexpInternal, state: SSUnknownSQLState}, + vterrors.RegexpRuleSyntax: {num: ERRegexpRuleSyntax, state: SSUnknownSQLState}, + vterrors.RegexpBadEscapeSequence: {num: ERRegexpBadEscapeSequence, state: SSUnknownSQLState}, + vterrors.RegexpUnimplemented: {num: ERRegexpUnimplemented, state: SSUnknownSQLState}, + vterrors.RegexpMismatchParen: {num: ERRegexpMismatchParen, state: SSUnknownSQLState}, + vterrors.RegexpBadInterval: {num: ERRegexpBadInterval, state: SSUnknownSQLState}, + vterrors.RegexpMaxLtMin: {num: ERRRegexpMaxLtMin, state: SSUnknownSQLState}, + vterrors.RegexpInvalidBackRef: {num: ERRegexpInvalidBackRef, state: SSUnknownSQLState}, + vterrors.RegexpLookBehindLimit: {num: ERRegexpLookBehindLimit, state: SSUnknownSQLState}, + vterrors.RegexpMissingCloseBracket: {num: ERRegexpMissingCloseBracket, state: SSUnknownSQLState}, + 
vterrors.RegexpInvalidRange: {num: ERRegexpInvalidRange, state: SSUnknownSQLState}, + vterrors.RegexpStackOverflow: {num: ERRegexpStackOverflow, state: SSUnknownSQLState}, + vterrors.RegexpTimeOut: {num: ERRegexpTimeOut, state: SSUnknownSQLState}, + vterrors.RegexpPatternTooBig: {num: ERRegexpPatternTooBig, state: SSUnknownSQLState}, + vterrors.RegexpInvalidFlag: {num: ERRegexpInvalidFlag, state: SSUnknownSQLState}, + vterrors.RegexpInvalidCaptureGroup: {num: ERRegexpInvalidCaptureGroup, state: SSUnknownSQLState}, + vterrors.CharacterSetMismatch: {num: ERCharacterSetMismatch, state: SSUnknownSQLState}, + vterrors.WrongParametersToNativeFct: {num: ERWrongParametersToNativeFct, state: SSUnknownSQLState}, + vterrors.KillDeniedError: {num: ERKillDenied, state: SSUnknownSQLState}, + vterrors.BadNullError: {num: ERBadNullError, state: SSConstraintViolation}, + vterrors.InvalidGroupFuncUse: {num: ERInvalidGroupFuncUse, state: SSUnknownSQLState}, + vterrors.VectorConversion: {num: ERVectorConversion, state: SSUnknownSQLState}, + vterrors.CTERecursiveRequiresSingleReference: {num: ERCTERecursiveRequiresSingleReference, state: SSUnknownSQLState}, + vterrors.CTERecursiveRequiresUnion: {num: ERCTERecursiveRequiresUnion, state: SSUnknownSQLState}, + vterrors.CTERecursiveForbidsAggregation: {num: ERCTERecursiveForbidsAggregation, state: SSUnknownSQLState}, + vterrors.CTERecursiveForbiddenJoinOrder: {num: ERCTERecursiveForbiddenJoinOrder, state: SSUnknownSQLState}, } func getStateToMySQLState(state vterrors.State) mysqlCode { diff --git a/go/vt/sqlparser/ast_funcs.go b/go/vt/sqlparser/ast_funcs.go index cd4d5304047..5e152065622 100644 --- a/go/vt/sqlparser/ast_funcs.go +++ b/go/vt/sqlparser/ast_funcs.go @@ -428,6 +428,20 @@ func (node *AliasedTableExpr) TableName() (TableName, error) { return tableName, nil } +// TableNameString returns a TableNameString pointing to this table expr +func (node *AliasedTableExpr) TableNameString() string { + if node.As.NotEmpty() { + return 
node.As.String() + } + + tableName, ok := node.Expr.(TableName) + if !ok { + panic(vterrors.Errorf(vtrpcpb.Code_INTERNAL, "BUG: Derived table should have an alias. This should not be possible")) + } + + return tableName.Name.String() +} + // IsEmpty returns true if TableName is nil or empty. func (node TableName) IsEmpty() bool { // If Name is empty, Qualifier is also empty. diff --git a/go/vt/vterrors/code.go b/go/vt/vterrors/code.go index 83a87503265..0b2c298f17f 100644 --- a/go/vt/vterrors/code.go +++ b/go/vt/vterrors/code.go @@ -97,6 +97,10 @@ var ( VT09023 = errorWithoutState("VT09023", vtrpcpb.Code_FAILED_PRECONDITION, "could not map %v to a keyspace id", "Unable to determine the shard for the given row.") VT09024 = errorWithoutState("VT09024", vtrpcpb.Code_FAILED_PRECONDITION, "could not map %v to a unique keyspace id: %v", "Unable to determine the shard for the given row.") VT09025 = errorWithoutState("VT09025", vtrpcpb.Code_FAILED_PRECONDITION, "atomic transaction error: %v", "Error in atomic transactions") + VT09026 = errorWithState("VT09026", vtrpcpb.Code_FAILED_PRECONDITION, CTERecursiveRequiresUnion, "Recursive Common Table Expression '%s' should contain a UNION", "") + VT09027 = errorWithState("VT09027", vtrpcpb.Code_FAILED_PRECONDITION, CTERecursiveForbidsAggregation, "Recursive Common Table Expression '%s' can contain neither aggregation nor window functions in recursive query block", "") + VT09028 = errorWithState("VT09028", vtrpcpb.Code_FAILED_PRECONDITION, CTERecursiveForbiddenJoinOrder, "In recursive query block of Recursive Common Table Expression '%s', the recursive table must neither be in the right argument of a LEFT JOIN, nor be forced to be non-first with join order hints", "") + VT09029 = errorWithState("VT09029", vtrpcpb.Code_FAILED_PRECONDITION, CTERecursiveRequiresSingleReference, "In recursive query block of Recursive Common Table Expression %s, the recursive table must be referenced only once, and not in any subquery", "") VT10001 = 
errorWithoutState("VT10001", vtrpcpb.Code_ABORTED, "foreign key constraints are not allowed", "Foreign key constraints are not allowed, see https://vitess.io/blog/2021-06-15-online-ddl-why-no-fk/.") VT10002 = errorWithoutState("VT10002", vtrpcpb.Code_ABORTED, "atomic distributed transaction not allowed: %s", "The distributed transaction cannot be committed. A rollback decision is taken.") diff --git a/go/vt/vterrors/state.go b/go/vt/vterrors/state.go index 82434df382a..1f1c5922c37 100644 --- a/go/vt/vterrors/state.go +++ b/go/vt/vterrors/state.go @@ -62,6 +62,10 @@ const ( NoReferencedRow2 UnknownStmtHandler KeyDoesNotExist + CTERecursiveRequiresSingleReference + CTERecursiveRequiresUnion + CTERecursiveForbidsAggregation + CTERecursiveForbiddenJoinOrder // not found BadDb diff --git a/go/vt/vtgate/planbuilder/operators/ast_to_op.go b/go/vt/vtgate/planbuilder/operators/ast_to_op.go index a18855688c1..58402e48a68 100644 --- a/go/vt/vtgate/planbuilder/operators/ast_to_op.go +++ b/go/vt/vtgate/planbuilder/operators/ast_to_op.go @@ -268,7 +268,7 @@ func getOperatorFromAliasedTableExpr(ctx *plancontext.PlanningContext, tableExpr } case *semantics.CTETable: current := ctx.ActiveCTE() - if current != nil && current.CTEDef.Equals(tableInfo.CTEDef) { + if current != nil && current.CTEDef == tableInfo.CTEDef { return createDualCTETable(ctx, tableID, tableInfo) } return createRecursiveCTE(ctx, tableInfo) diff --git a/go/vt/vtgate/semantics/analyzer_test.go b/go/vt/vtgate/semantics/analyzer_test.go index 01c34639763..0c42456b0ab 100644 --- a/go/vt/vtgate/semantics/analyzer_test.go +++ b/go/vt/vtgate/semantics/analyzer_test.go @@ -203,11 +203,11 @@ func TestBindingRecursiveCTEs(t *testing.T) { } queries := []testCase{{ query: "with recursive x as (select id from user union select x.id + 1 from x where x.id < 15) select t.id from x join x t;", - rdeps: MergeTableSets(TS0, TS1), // This is the user and `x` in the CTE - ddeps: TS3, // this is the t id + rdeps: TS3, + ddeps: TS3, }, 
{ query: "WITH RECURSIVE user_cte AS (SELECT id, name FROM user WHERE id = 42 UNION ALL SELECT u.id, u.name FROM user u JOIN user_cte cte ON u.id = cte.id + 1 WHERE u.id = 42) SELECT id FROM user_cte", - rdeps: MergeTableSets(TS0, TS1, TS2), // This is the two uses of the user and `user_cte` in the CTE + rdeps: TS3, ddeps: TS3, }} for _, query := range queries { @@ -220,6 +220,34 @@ func TestBindingRecursiveCTEs(t *testing.T) { } } +func TestRecursiveCTEChecking(t *testing.T) { + type testCase struct { + name, query, err string + } + queries := []testCase{{ + name: "recursive CTE using aggregation", + query: "with recursive x as (select id from user union select count(*) from x) select t.id from x join x t", + err: "VT09027: Recursive Common Table Expression 'x' can contain neither aggregation nor window functions in recursive query block", + }, { + name: "recursive CTE using grouping", + query: "with recursive x as (select id from user union select id+1 from x where id < 10 group by 1) select t.id from x join x t", + err: "VT09027: Recursive Common Table Expression 'x' can contain neither aggregation nor window functions in recursive query block", + }, { + name: "use the same recursive cte twice in definition", + query: "with recursive x as (select 1 union select id+1 from x where id < 10 union select id+2 from x where id < 20) select t.id from x", + err: "VT09029: In recursive query block of Recursive Common Table Expression x, the recursive table must be referenced only once, and not in any subquery", + }} + for _, tc := range queries { + t.Run(tc.query, func(t *testing.T) { + parse, err := sqlparser.NewTestParser().Parse(tc.query) + require.NoError(t, err) + + _, err = AnalyzeStrict(parse, "user", fakeSchemaInfo()) + require.EqualError(t, err, tc.err) + }) + } +} + func TestBindingMultiAliasedTablePositive(t *testing.T) { type testCase struct { query string @@ -978,7 +1006,7 @@ func TestScopingWithWITH(t *testing.T) { }, { query: "with c as (select x as foo 
from user), t as (select foo as id from c) select id from t", recursive: TS0, - direct: TS3, + direct: TS2, }, { query: "with t as (select foo as id from user) select t.id from t", recursive: TS0, diff --git a/go/vt/vtgate/semantics/cte_table.go b/go/vt/vtgate/semantics/cte_table.go index c7b4f41f8d8..5854a989dd8 100644 --- a/go/vt/vtgate/semantics/cte_table.go +++ b/go/vt/vtgate/semantics/cte_table.go @@ -17,7 +17,6 @@ limitations under the License. package semantics import ( - "slices" "strings" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" @@ -31,12 +30,12 @@ import ( type CTETable struct { TableName string ASTNode *sqlparser.AliasedTableExpr - CTEDef + *CTEDef } var _ TableInfo = (*CTETable)(nil) -func newCTETable(node *sqlparser.AliasedTableExpr, t sqlparser.TableName, cteDef CTEDef) *CTETable { +func newCTETable(node *sqlparser.AliasedTableExpr, t sqlparser.TableName, cteDef *CTEDef) *CTETable { var name string if node.As.IsEmpty() { name = t.Name.String() @@ -113,7 +112,7 @@ func (cte *CTETable) dependencies(colName string, org originable) (dependencies, columns := cte.getColumns(false) for _, columnInfo := range columns { if strings.EqualFold(columnInfo.Name, colName) { - return createCertain(directDeps, cte.recursive(org), evalengine.NewUnknownType()), nil + return createCertain(directDeps, directDeps, evalengine.NewUnknownType()), nil } } @@ -121,7 +120,7 @@ func (cte *CTETable) dependencies(colName string, org originable) (dependencies, return ¬hing{}, nil } - return createUncertain(directDeps, cte.recursive(org)), nil + return createUncertain(directDeps, directDeps), nil } func (cte *CTETable) getExprFor(s string) (sqlparser.Expr, error) { @@ -133,13 +132,18 @@ func (cte *CTETable) getTableSet(org originable) TableSet { } type CTEDef struct { + Name string Query sqlparser.SelectStatement isAuthoritative bool recursiveDeps *TableSet Columns sqlparser.Columns + IDForRecurse *TableSet + + // Was this CTE marked for being recursive? 
+ Recursive bool } -func (cte CTEDef) recursive(org originable) (id TableSet) { +func (cte *CTEDef) recursive(org originable) (id TableSet) { if cte.recursiveDeps != nil { return *cte.recursiveDeps } @@ -156,12 +160,3 @@ func (cte CTEDef) recursive(org originable) (id TableSet) { }, cte.Query) return } - -func (cte CTEDef) Equals(other CTEDef) bool { - if !sqlparser.Equals.SelectStatement(cte.Query, other.Query) { - return false - } - return slices.EqualFunc(cte.Columns, other.Columns, func(a, b sqlparser.IdentifierCI) bool { - return a.Equal(b) - }) -} diff --git a/go/vt/vtgate/semantics/real_table.go b/go/vt/vtgate/semantics/real_table.go index 4f1639d0897..b11b736e79d 100644 --- a/go/vt/vtgate/semantics/real_table.go +++ b/go/vt/vtgate/semantics/real_table.go @@ -20,9 +20,11 @@ import ( "strings" "vitess.io/vitess/go/mysql/collations" + "vitess.io/vitess/go/slice" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vterrors" + "vitess.io/vitess/go/vt/vtgate/evalengine" "vitess.io/vitess/go/vt/vtgate/vindexes" ) @@ -31,6 +33,7 @@ type RealTable struct { dbName, tableName string ASTNode *sqlparser.AliasedTableExpr Table *vindexes.Table + CTE *CTEDef VindexHint *sqlparser.IndexHint isInfSchema bool collationEnv *collations.Environment @@ -71,8 +74,17 @@ func (r *RealTable) IsInfSchema() bool { // GetColumns implements the TableInfo interface func (r *RealTable) getColumns(ignoreInvisbleCol bool) []ColumnInfo { if r.Table == nil { + if r.CTE != nil { + selectExprs := r.CTE.Query.GetColumns() + ci := extractColumnsFromCTE(r.CTE.Columns, selectExprs) + if ci == nil { + return ci + } + return extractSelectExprsFromCTE(selectExprs) + } return nil } + nameMap := map[string]any{} cols := make([]ColumnInfo, 0, len(r.Table.Columns)) for _, col := range r.Table.Columns { @@ -105,6 +117,36 @@ func (r *RealTable) getColumns(ignoreInvisbleCol bool) []ColumnInfo { return cols } +func extractSelectExprsFromCTE(selectExprs 
sqlparser.SelectExprs) []ColumnInfo { + var ci []ColumnInfo + for _, expr := range selectExprs { + ae, ok := expr.(*sqlparser.AliasedExpr) + if !ok { + return nil + } + ci = append(ci, ColumnInfo{ + Name: ae.ColumnName(), + Type: evalengine.NewUnknownType(), // TODO: set the proper type + }) + } + return ci +} + +func extractColumnsFromCTE(columns sqlparser.Columns, selectExprs sqlparser.SelectExprs) []ColumnInfo { + if len(columns) == 0 { + return nil + } + if len(selectExprs) != len(columns) { + panic("mismatch of columns") + } + return slice.Map(columns, func(from sqlparser.IdentifierCI) ColumnInfo { + return ColumnInfo{ + Name: from.String(), + Type: evalengine.NewUnknownType(), + } + }) +} + // GetExpr implements the TableInfo interface func (r *RealTable) GetAliasedTableExpr() *sqlparser.AliasedTableExpr { return r.ASTNode diff --git a/go/vt/vtgate/semantics/scoper.go b/go/vt/vtgate/semantics/scoper.go index 4b42f2622cb..b51cedd2338 100644 --- a/go/vt/vtgate/semantics/scoper.go +++ b/go/vt/vtgate/semantics/scoper.go @@ -17,6 +17,7 @@ limitations under the License. 
package semantics import ( + "fmt" "reflect" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" @@ -35,9 +36,10 @@ type ( binder *binder // These scopes are only used for rewriting ORDER BY 1 and GROUP BY 1 - specialExprScopes map[*sqlparser.Literal]*scope - statementIDs map[sqlparser.Statement]TableSet - si SchemaInformation + specialExprScopes map[*sqlparser.Literal]*scope + statementIDs map[sqlparser.Statement]TableSet + commonTableExprScopes []*sqlparser.CommonTableExpr + si SchemaInformation } scope struct { @@ -105,6 +107,8 @@ func (s *scoper) down(cursor *sqlparser.Cursor) error { s.currentScope().inHaving = true return nil } + case *sqlparser.CommonTableExpr: + s.commonTableExprScopes = append(s.commonTableExprScopes, node) } return nil } @@ -184,6 +188,9 @@ func (s *scoper) enterJoinScope(cursor *sqlparser.Cursor) { func (s *scoper) pushSelectScope(node *sqlparser.Select) { currScope := newScope(s.currentScope()) + if len(s.scopes) > 0 && s.scopes[len(s.scopes)-1] != s.currentScope() { + fmt.Println("BUG: scope counts did not match") + } currScope.stmtScope = true s.push(currScope) @@ -261,6 +268,8 @@ func (s *scoper) up(cursor *sqlparser.Cursor) error { s.binder.usingJoinInfo[ts] = m } } + case *sqlparser.CommonTableExpr: + s.commonTableExprScopes = s.commonTableExprScopes[:len(s.commonTableExprScopes)-1] } return nil } diff --git a/go/vt/vtgate/semantics/table_collector.go b/go/vt/vtgate/semantics/table_collector.go index edb9f184268..7c9efd79513 100644 --- a/go/vt/vtgate/semantics/table_collector.go +++ b/go/vt/vtgate/semantics/table_collector.go @@ -43,7 +43,9 @@ type ( currentDb string Tables []TableInfo done map[*sqlparser.AliasedTableExpr]TableInfo - cte map[string]CTEDef + + // cte is a map of CTE definitions that are used in the query + cte map[string]*CTEDef } ) @@ -52,7 +54,7 @@ func newEarlyTableCollector(si SchemaInformation, currentDb string) *earlyTableC si: si, currentDb: currentDb, done: map[*sqlparser.AliasedTableExpr]TableInfo{}, - cte: 
map[string]CTEDef{}, + cte: map[string]*CTEDef{}, } } @@ -62,9 +64,11 @@ func (etc *earlyTableCollector) down(cursor *sqlparser.Cursor) bool { return true } for _, cte := range with.CTEs { - etc.cte[cte.ID.String()] = CTEDef{ - Query: cte.Subquery, - Columns: cte.Columns, + etc.cte[cte.ID.String()] = &CTEDef{ + Name: cte.ID.String(), + Query: cte.Subquery, + Columns: cte.Columns, + Recursive: with.Recursive, } } return true @@ -104,7 +108,7 @@ func (etc *earlyTableCollector) handleTableName(tbl sqlparser.TableName, aet *sq return } } - tableInfo, err := etc.getTableInfo(aet, tbl) + tableInfo, err := etc.getTableInfo(aet, tbl, nil) if err != nil { // this could just be a CTE that we haven't processed, so we'll give it the benefit of the doubt for now return @@ -311,7 +315,7 @@ func (tc *tableCollector) handleTableName(node *sqlparser.AliasedTableExpr, t sq tableInfo, found = tc.done[node] if !found { - tableInfo, err = tc.earlyTableCollector.getTableInfo(node, t) + tableInfo, err = tc.earlyTableCollector.getTableInfo(node, t, tc.scoper) if err != nil { return err } @@ -322,14 +326,26 @@ func (tc *tableCollector) handleTableName(node *sqlparser.AliasedTableExpr, t sq return scope.addTable(tableInfo) } -func (etc *earlyTableCollector) getTableInfo(node *sqlparser.AliasedTableExpr, t sqlparser.TableName) (TableInfo, error) { +func (etc *earlyTableCollector) getCTE(t sqlparser.TableName) *CTEDef { + if t.Qualifier.NotEmpty() { + return nil + } + + return etc.cte[t.Name.String()] +} + +func (etc *earlyTableCollector) getTableInfo(node *sqlparser.AliasedTableExpr, t sqlparser.TableName, sc *scoper) (TableInfo, error) { var tbl *vindexes.Table var vindex vindexes.Vindex - if t.Qualifier.IsEmpty() { - // CTE handling will not be used in the early table collection - cteDef, isCte := etc.cte[t.Name.String()] - if isCte { - return newCTETable(node, t, cteDef), nil + if cteDef := etc.getCTE(t); cteDef != nil { + cte, err := etc.buildRecursiveCTE(node, t, sc, cteDef) + if err != 
nil { + return nil, err + } + if cte != nil { + // if we didn't get a table, it means we can't build a recursive CTE, + // so we need to look for a regular table instead + return cte, nil } } @@ -351,6 +367,57 @@ func (etc *earlyTableCollector) getTableInfo(node *sqlparser.AliasedTableExpr, t return tableInfo, nil } +func (etc *earlyTableCollector) buildRecursiveCTE(node *sqlparser.AliasedTableExpr, t sqlparser.TableName, sc *scoper, cteDef *CTEDef) (TableInfo, error) { + // If sc is nil, then we are in the early table collector. + // In early table collector, we don't go over the CTE definitions, so we must be seeing a usage of the CTE. + if sc != nil && len(sc.commonTableExprScopes) > 0 { + cte := sc.commonTableExprScopes[len(sc.commonTableExprScopes)-1] + if cte.ID.String() == t.Name.String() { + + if err := checkValidRecursiveCTE(cteDef); err != nil { + return nil, err + } + + cteTable := newCTETable(node, t, cteDef) + cteTableSet := SingleTableSet(len(etc.Tables)) + cteDef.IDForRecurse = &cteTableSet + if !cteDef.Recursive { + return nil, nil + } + return cteTable, nil + } + } + return &RealTable{ + tableName: node.TableNameString(), + ASTNode: node, + CTE: cteDef, + collationEnv: etc.si.Environment().CollationEnv(), + }, nil +} + +func checkValidRecursiveCTE(cteDef *CTEDef) error { + if cteDef.IDForRecurse != nil { + return vterrors.VT09029(cteDef.Name) + } + + union, isUnion := cteDef.Query.(*sqlparser.Union) + if !isUnion { + return vterrors.VT09026(cteDef.Name) + } + + firstSelect := sqlparser.GetFirstSelect(union.Right) + if firstSelect.GroupBy != nil { + return vterrors.VT09027(cteDef.Name) + } + + for _, expr := range firstSelect.GetColumns() { + if sqlparser.ContainsAggregation(expr) { + return vterrors.VT09027(cteDef.Name) + } + } + return nil +} + func (tc *tableCollector) handleDerivedTable(node *sqlparser.AliasedTableExpr, t *sqlparser.DerivedTable) error { switch sel := t.Select.(type) { case *sqlparser.Select: From 
9e8f5ba3cc2f9230eb13ef63bf208be6d05f9473 Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Tue, 30 Jul 2024 09:02:24 +0200 Subject: [PATCH 06/28] feat: use the new CTE field on RealTable Signed-off-by: Andres Taylor --- .../vtgate/planbuilder/operators/ast_to_op.go | 27 +++++++++------ .../planbuilder/operators/recurse_cte.go | 2 +- .../plancontext/planning_context.go | 34 ++++++++++++------- go/vt/vtgate/semantics/cte_table.go | 3 ++ 4 files changed, 42 insertions(+), 24 deletions(-) diff --git a/go/vt/vtgate/planbuilder/operators/ast_to_op.go b/go/vt/vtgate/planbuilder/operators/ast_to_op.go index 58402e48a68..fcceb620318 100644 --- a/go/vt/vtgate/planbuilder/operators/ast_to_op.go +++ b/go/vt/vtgate/planbuilder/operators/ast_to_op.go @@ -267,12 +267,12 @@ func getOperatorFromAliasedTableExpr(ctx *plancontext.PlanningContext, tableExpr Solved: tableID, } case *semantics.CTETable: - current := ctx.ActiveCTE() - if current != nil && current.CTEDef == tableInfo.CTEDef { - return createDualCTETable(ctx, tableID, tableInfo) - } - return createRecursiveCTE(ctx, tableInfo) + return createDualCTETable(ctx, tableID, tableInfo) case *semantics.RealTable: + if tableInfo.CTE != nil { + return createRecursiveCTE(ctx, tableInfo.CTE) + } + qg := newQueryGraph() isInfSchema := tableInfo.IsInfSchema() qt := &QueryTable{Alias: tableExpr, Table: tbl, ID: tableID, IsInfSchema: isInfSchema} @@ -314,8 +314,8 @@ func createDualCTETable(ctx *plancontext.PlanningContext, tableID semantics.Tabl return createRouteFromVSchemaTable(ctx, qtbl, vschemaTable, false, nil) } -func createRecursiveCTE(ctx *plancontext.PlanningContext, def *semantics.CTETable) Operator { - union, ok := def.CTEDef.Query.(*sqlparser.Union) +func createRecursiveCTE(ctx *plancontext.PlanningContext, def *semantics.CTEDef) Operator { + union, ok := def.Query.(*sqlparser.Union) if !ok { panic(vterrors.VT13001("expected UNION in recursive CTE")) } @@ -323,14 +323,19 @@ func createRecursiveCTE(ctx 
*plancontext.PlanningContext, def *semantics.CTETabl init := translateQueryToOp(ctx, union.Left) // Push the CTE definition to the stack so that it can be used in the recursive part of the query - ctx.PushCTE(def) + ctx.PushCTE(&plancontext.ContextCTE{ + CTEDef: def, + }) tail := translateQueryToOp(ctx, union.Right) - if err := ctx.PopCTE(); err != nil { + activeCTE, err := ctx.PopCTE() + if err != nil { panic(err) } - return newRecurse(def.TableName, init, tail) + for _, expression := range activeCTE.Expressions { + tail = tail.AddPredicate(ctx, expression.RightExpr) + } + return newRecurse(def.Name, init, tail, activeCTE.Expressions) } - func crossJoin(ctx *plancontext.PlanningContext, exprs sqlparser.TableExprs) Operator { var output Operator for _, tableExpr := range exprs { diff --git a/go/vt/vtgate/planbuilder/operators/recurse_cte.go b/go/vt/vtgate/planbuilder/operators/recurse_cte.go index 18261e7becc..8b70b52dfcf 100644 --- a/go/vt/vtgate/planbuilder/operators/recurse_cte.go +++ b/go/vt/vtgate/planbuilder/operators/recurse_cte.go @@ -39,7 +39,7 @@ type RecurseCTE struct { var _ Operator = (*RecurseCTE)(nil) -func newRecurse(name string, init, tail Operator) *RecurseCTE { +func newRecurse(name string, init, tail Operator, expressions []*plancontext.RecurseExpression) *RecurseCTE { return &RecurseCTE{ Name: name, Init: init, diff --git a/go/vt/vtgate/planbuilder/plancontext/planning_context.go b/go/vt/vtgate/planbuilder/plancontext/planning_context.go index b56793d0e6c..7e2c321eb25 100644 --- a/go/vt/vtgate/planbuilder/plancontext/planning_context.go +++ b/go/vt/vtgate/planbuilder/plancontext/planning_context.go @@ -70,7 +70,7 @@ type PlanningContext struct { // This is a stack of CTEs being built. It's used when we have CTEs inside CTEs, // to remember which is the CTE currently being assembled - CurrentCTE []*semantics.CTETable + CurrentCTE []*ContextCTE } // CreatePlanningContext initializes a new PlanningContext with the given parameters. 
@@ -382,21 +382,31 @@ func (ctx *PlanningContext) ContainsAggr(e sqlparser.SQLNode) (hasAggr bool) { return } -func (ctx *PlanningContext) PushCTE(def *semantics.CTETable) { - ctx.CurrentCTE = append(ctx.CurrentCTE, def) +type ContextCTE struct { + *semantics.CTEDef + Expressions []*RecurseExpression } -func (ctx *PlanningContext) PopCTE() error { - if len(ctx.CurrentCTE) == 0 { - return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "no CTE to pop") - } - ctx.CurrentCTE = ctx.CurrentCTE[:len(ctx.CurrentCTE)-1] - return nil +type RecurseExpression struct { + Original sqlparser.Expr + RightExpr sqlparser.Expr + LeftExpr []BindVarExpr +} + +type BindVarExpr struct { + Name string + Expr sqlparser.Expr } -func (ctx *PlanningContext) ActiveCTE() *semantics.CTETable { +func (ctx *PlanningContext) PushCTE(def *ContextCTE) { + ctx.CurrentCTE = append(ctx.CurrentCTE, def) +} + +func (ctx *PlanningContext) PopCTE() (*ContextCTE, error) { if len(ctx.CurrentCTE) == 0 { - return nil + return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "no CTE to pop") } - return ctx.CurrentCTE[len(ctx.CurrentCTE)-1] + activeCTE := ctx.CurrentCTE[len(ctx.CurrentCTE)-1] + ctx.CurrentCTE = ctx.CurrentCTE[:len(ctx.CurrentCTE)-1] + return activeCTE, nil } diff --git a/go/vt/vtgate/semantics/cte_table.go b/go/vt/vtgate/semantics/cte_table.go index 5854a989dd8..85407712d98 100644 --- a/go/vt/vtgate/semantics/cte_table.go +++ b/go/vt/vtgate/semantics/cte_table.go @@ -27,6 +27,9 @@ import ( ) // CTETable contains the information about the CTE table. +// This is a special TableInfo that is used to represent the recursive table inside a CTE. For the query: +// WITH RECURSIVE cte AS (SELECT 1 UNION ALL SELECT * FROM cte as C1) SELECT * FROM cte as C2 +// The CTE table C1 is represented by a CTETable. 
type CTETable struct { TableName string ASTNode *sqlparser.AliasedTableExpr From c38ae573b83b75646cae7ec9ca5f52fb5841ec08 Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Tue, 30 Jul 2024 09:22:59 +0200 Subject: [PATCH 07/28] feat: handle star expansion on CTEs Signed-off-by: Andres Taylor --- go/vt/vtgate/semantics/early_rewriter_test.go | 3 ++ go/vt/vtgate/semantics/real_table.go | 51 ++++++++++++++----- 2 files changed, 40 insertions(+), 14 deletions(-) diff --git a/go/vt/vtgate/semantics/early_rewriter_test.go b/go/vt/vtgate/semantics/early_rewriter_test.go index 16b3756189f..fab8211f74e 100644 --- a/go/vt/vtgate/semantics/early_rewriter_test.go +++ b/go/vt/vtgate/semantics/early_rewriter_test.go @@ -184,6 +184,9 @@ func TestExpandStar(t *testing.T) { // if we are only star-expanding authoritative tables, we don't need to stop the expansion sql: "SELECT * FROM (SELECT t2.*, 12 AS foo FROM t3, t2) as results", expSQL: "select c1, c2, foo from (select t2.c1, t2.c2, 12 as foo from t3, t2) as results", + }, { + sql: "with recursive hierarchy as (select t1.a, t1.b from t1 where t1.a is null union select t1.a, t1.b from t1 join hierarchy on t1.a = hierarchy.b) select * from hierarchy", + expSQL: "with recursive hierarchy as (select t1.a, t1.b from t1 where t1.a is null union select t1.a, t1.b from t1 join hierarchy on t1.a = hierarchy.b) select a, b from hierarchy", }} for _, tcase := range tcases { t.Run(tcase.sql, func(t *testing.T) { diff --git a/go/vt/vtgate/semantics/real_table.go b/go/vt/vtgate/semantics/real_table.go index b11b736e79d..13eaf3a6da1 100644 --- a/go/vt/vtgate/semantics/real_table.go +++ b/go/vt/vtgate/semantics/real_table.go @@ -73,18 +73,17 @@ func (r *RealTable) IsInfSchema() bool { // GetColumns implements the TableInfo interface func (r *RealTable) getColumns(ignoreInvisbleCol bool) []ColumnInfo { - if r.Table == nil { - if r.CTE != nil { - selectExprs := r.CTE.Query.GetColumns() - ci := extractColumnsFromCTE(r.CTE.Columns, selectExprs) - 
if ci == nil { - return ci - } - return extractSelectExprsFromCTE(selectExprs) - } + switch { + case r.CTE != nil: + return r.getCTEColumns() + case r.Table == nil: return nil + default: + return r.getVindexTableColumns(ignoreInvisbleCol) } +} +func (r *RealTable) getVindexTableColumns(ignoreInvisbleCol bool) []ColumnInfo { nameMap := map[string]any{} cols := make([]ColumnInfo, 0, len(r.Table.Columns)) for _, col := range r.Table.Columns { @@ -117,6 +116,35 @@ func (r *RealTable) getColumns(ignoreInvisbleCol bool) []ColumnInfo { return cols } +func (r *RealTable) getCTEColumns() []ColumnInfo { + selectExprs := r.CTE.Query.GetColumns() + ci := extractColumnsFromCTE(r.CTE.Columns, selectExprs) + if ci != nil { + return ci + } + return extractSelectExprsFromCTE(selectExprs) +} + +// Authoritative implements the TableInfo interface +func (r *RealTable) authoritative() bool { + if r.Table != nil { + return r.Table.ColumnListAuthoritative + } + if r.CTE != nil { + if len(r.CTE.Columns) > 0 { + return true + } + for _, se := range r.CTE.Query.GetColumns() { + _, isAe := se.(*sqlparser.AliasedExpr) + if !isAe { + return false + } + } + return true + } + return false +} + func extractSelectExprsFromCTE(selectExprs sqlparser.SelectExprs) []ColumnInfo { var ci []ColumnInfo for _, expr := range selectExprs { @@ -187,11 +215,6 @@ func (r *RealTable) Name() (sqlparser.TableName, error) { return r.ASTNode.TableName() } -// Authoritative implements the TableInfo interface -func (r *RealTable) authoritative() bool { - return r.Table != nil && r.Table.ColumnListAuthoritative -} - // Matches implements the TableInfo interface func (r *RealTable) matches(name sqlparser.TableName) bool { return (name.Qualifier.IsEmpty() || name.Qualifier.String() == r.dbName) && r.tableName == name.Name.String() From 705c0792660bac167a7bccd20882f9838188746c Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Wed, 31 Jul 2024 10:49:42 +0200 Subject: [PATCH 08/28] first plan with recursion at the vtgate 
level passes g Signed-off-by: Andres Taylor --- .../planbuilder/operator_transformers.go | 18 +++++ .../planbuilder/operators/SQL_builder.go | 28 +++++++- .../vtgate/planbuilder/operators/ast_to_op.go | 13 ++-- .../planbuilder/operators/cte_merging.go | 9 +-- go/vt/vtgate/planbuilder/operators/join.go | 34 ++++++++++ .../planbuilder/operators/recurse_cte.go | 64 +++++++++++++++--- .../plancontext/planning_context.go | 21 ++++-- .../planbuilder/testdata/cte_cases.json | 67 +++++++++++++++++++ .../planbuilder/testdata/from_cases.json | 22 ------ go/vt/vtgate/semantics/cte_table.go | 13 ++-- go/vt/vtgate/semantics/real_table.go | 2 +- go/vt/vtgate/semantics/table_collector.go | 12 ++-- 12 files changed, 242 insertions(+), 61 deletions(-) diff --git a/go/vt/vtgate/planbuilder/operator_transformers.go b/go/vt/vtgate/planbuilder/operator_transformers.go index 546a9854f26..ec2ef0d0f87 100644 --- a/go/vt/vtgate/planbuilder/operator_transformers.go +++ b/go/vt/vtgate/planbuilder/operator_transformers.go @@ -77,6 +77,8 @@ func transformToPrimitive(ctx *plancontext.PlanningContext, op operators.Operato return transformSequential(ctx, op) case *operators.DMLWithInput: return transformDMLWithInput(ctx, op) + case *operators.RecurseCTE: + return transformRecurseCTE(ctx, op) } return nil, vterrors.VT13001(fmt.Sprintf("unknown type encountered: %T (transformToPrimitive)", op)) @@ -981,6 +983,22 @@ func transformVindexPlan(ctx *plancontext.PlanningContext, op *operators.Vindex) return prim, nil } +func transformRecurseCTE(ctx *plancontext.PlanningContext, op *operators.RecurseCTE) (engine.Primitive, error) { + init, err := transformToPrimitive(ctx, op.Init) + if err != nil { + return nil, err + } + tail, err := transformToPrimitive(ctx, op.Tail) + if err != nil { + return nil, err + } + return &engine.RecurseCTE{ + Init: init, + Recurse: tail, + Vars: op.Vars, + }, nil +} + func generateQuery(statement sqlparser.Statement) string { buf := sqlparser.NewTrackedBuffer(dmlFormatter) 
statement.Format(buf) diff --git a/go/vt/vtgate/planbuilder/operators/SQL_builder.go b/go/vt/vtgate/planbuilder/operators/SQL_builder.go index 0ab42e17038..1da867a20c4 100644 --- a/go/vt/vtgate/planbuilder/operators/SQL_builder.go +++ b/go/vt/vtgate/planbuilder/operators/SQL_builder.go @@ -55,6 +55,22 @@ func ToSQL(ctx *plancontext.PlanningContext, op Operator) (_ sqlparser.Statement return q.stmt, q.dmlOperator, nil } +func (qb *queryBuilder) includeTable(op *Table) bool { + if qb.ctx.SemTable == nil { + return true + } + tbl, err := qb.ctx.SemTable.TableInfoFor(op.QTable.ID) + if err != nil { + return true + } + cteTbl, isCTE := tbl.(*semantics.CTETable) + if !isCTE { + return true + } + + return cteTbl.Merged +} + func (qb *queryBuilder) addTable(db, tableName, alias string, tableID semantics.TableSet, hints sqlparser.IndexHints) { if tableID.NumberOfTables() == 1 && qb.ctx.SemTable != nil { tblInfo, err := qb.ctx.SemTable.TableInfoFor(tableID) @@ -524,6 +540,11 @@ func buildLimit(op *Limit, qb *queryBuilder) { } func buildTable(op *Table, qb *queryBuilder) { + toto := qb.includeTable(op) + if !toto { + return + } + dbName := "" if op.QTable.IsInfSchema { @@ -583,6 +604,11 @@ func buildApplyJoin(op *ApplyJoin, qb *queryBuilder) { qbR := &queryBuilder{ctx: qb.ctx} buildQuery(op.RHS, qbR) + // if we have a recursive cte on the rhs, we might not have a statement + if qbR.stmt == nil { + return + } + qb.joinWith(qbR, pred, op.JoinType) } @@ -672,7 +698,7 @@ func buildCTE(op *RecurseCTE, qb *queryBuilder) { buildQuery(op.Init, qb) qbR := &queryBuilder{ctx: qb.ctx} buildQuery(op.Tail, qbR) - qb.cteWith(qbR, op.Name) + qb.cteWith(qbR, op.Def.Name) } func mergeHaving(h1, h2 *sqlparser.Where) *sqlparser.Where { diff --git a/go/vt/vtgate/planbuilder/operators/ast_to_op.go b/go/vt/vtgate/planbuilder/operators/ast_to_op.go index fcceb620318..94c124055f4 100644 --- a/go/vt/vtgate/planbuilder/operators/ast_to_op.go +++ b/go/vt/vtgate/planbuilder/operators/ast_to_op.go @@ 
-314,7 +314,7 @@ func createDualCTETable(ctx *plancontext.PlanningContext, tableID semantics.Tabl return createRouteFromVSchemaTable(ctx, qtbl, vschemaTable, false, nil) } -func createRecursiveCTE(ctx *plancontext.PlanningContext, def *semantics.CTEDef) Operator { +func createRecursiveCTE(ctx *plancontext.PlanningContext, def *semantics.CTE) Operator { union, ok := def.Query.(*sqlparser.Union) if !ok { panic(vterrors.VT13001("expected UNION in recursive CTE")) @@ -323,19 +323,16 @@ func createRecursiveCTE(ctx *plancontext.PlanningContext, def *semantics.CTEDef) init := translateQueryToOp(ctx, union.Left) // Push the CTE definition to the stack so that it can be used in the recursive part of the query - ctx.PushCTE(&plancontext.ContextCTE{ - CTEDef: def, - }) + ctx.PushCTE(def, *def.IDForRecurse) tail := translateQueryToOp(ctx, union.Right) activeCTE, err := ctx.PopCTE() if err != nil { panic(err) } - for _, expression := range activeCTE.Expressions { - tail = tail.AddPredicate(ctx, expression.RightExpr) - } - return newRecurse(def.Name, init, tail, activeCTE.Expressions) + + return newRecurse(def, init, tail, activeCTE.Expressions) } + func crossJoin(ctx *plancontext.PlanningContext, exprs sqlparser.TableExprs) Operator { var output Operator for _, tableExpr := range exprs { diff --git a/go/vt/vtgate/planbuilder/operators/cte_merging.go b/go/vt/vtgate/planbuilder/operators/cte_merging.go index 7c12a85c89b..e32a800c091 100644 --- a/go/vt/vtgate/planbuilder/operators/cte_merging.go +++ b/go/vt/vtgate/planbuilder/operators/cte_merging.go @@ -38,7 +38,7 @@ func tryMergeCTE(ctx *plancontext.PlanningContext, init, tail Operator, in *Recu switch { case a == dual: - return mergeCTE(ctx, initRoute, tailRoute, routingB, in) + return mergeCTE(initRoute, tailRoute, routingB, in) case a == sharded && b == sharded: return tryMergeCTESharded(ctx, initRoute, tailRoute, in) default: @@ -58,7 +58,7 @@ func tryMergeCTESharded(ctx *plancontext.PlanningContext, init, tail *Route, in 
aExpr := tblA.VindexExpressions() bExpr := tblB.VindexExpressions() if aVdx == bVdx && gen4ValuesEqual(ctx, aExpr, bExpr) { - return mergeCTE(ctx, init, tail, tblA, in) + return mergeCTE(init, tail, tblA, in) } } } @@ -66,11 +66,12 @@ func tryMergeCTESharded(ctx *plancontext.PlanningContext, init, tail *Route, in return nil } -func mergeCTE(ctx *plancontext.PlanningContext, init, tail *Route, r Routing, in *RecurseCTE) *Route { +func mergeCTE(init, tail *Route, r Routing, in *RecurseCTE) *Route { + in.Def.Merged = true return &Route{ Routing: r, Source: &RecurseCTE{ - Name: in.Name, + Def: in.Def, ColumnNames: in.ColumnNames, Init: init.Source, Tail: tail.Source, diff --git a/go/vt/vtgate/planbuilder/operators/join.go b/go/vt/vtgate/planbuilder/operators/join.go index 71d2e5a8048..c1a41f94827 100644 --- a/go/vt/vtgate/planbuilder/operators/join.go +++ b/go/vt/vtgate/planbuilder/operators/join.go @@ -17,6 +17,7 @@ limitations under the License. package operators import ( + "vitess.io/vitess/go/slice" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" @@ -141,11 +142,44 @@ func addJoinPredicates( if subq != nil { continue } + + // if we are inside a CTE, we need to check if we depend on the recursion table + if cte := ctx.ActiveCTE(); cte != nil && ctx.SemTable.DirectDeps(pred).IsOverlapping(cte.Id) { + original := pred + pred = breakCTEExpressionInLhsAndRhs(ctx, pred, cte) + ctx.AddJoinPredicates(original, pred) + } op = op.AddPredicate(ctx, pred) } return sqc.getRootOperator(op, nil) } +// breakCTEExpressionInLhsAndRhs breaks the expression into LHS and RHS +func breakCTEExpressionInLhsAndRhs( + ctx *plancontext.PlanningContext, + pred sqlparser.Expr, + cte *plancontext.ContextCTE, +) sqlparser.Expr { + col := breakExpressionInLHSandRHS(ctx, pred, cte.Id) + + lhsExprs := slice.Map(col.LHSExprs, func(bve BindVarExpr) plancontext.BindVarExpr { + col, ok := bve.Expr.(*sqlparser.ColName) + if !ok 
{ + panic(vterrors.VT13001("expected column name")) + } + return plancontext.BindVarExpr{ + Name: bve.Name, + Expr: col, + } + }) + cte.Expressions = append(cte.Expressions, &plancontext.RecurseExpression{ + Original: col.Original, + RightExpr: col.RHSExpr, + LeftExprs: lhsExprs, + }) + return col.RHSExpr +} + func createJoin(ctx *plancontext.PlanningContext, LHS, RHS Operator) Operator { lqg, lok := LHS.(*QueryGraph) rqg, rok := RHS.(*QueryGraph) diff --git a/go/vt/vtgate/planbuilder/operators/recurse_cte.go b/go/vt/vtgate/planbuilder/operators/recurse_cte.go index 8b70b52dfcf..f008a6daa00 100644 --- a/go/vt/vtgate/planbuilder/operators/recurse_cte.go +++ b/go/vt/vtgate/planbuilder/operators/recurse_cte.go @@ -17,7 +17,12 @@ limitations under the License. package operators import ( + "fmt" "slices" + "strings" + + "vitess.io/vitess/go/slice" + "vitess.io/vitess/go/vt/vtgate/semantics" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" @@ -25,8 +30,10 @@ import ( // RecurseCTE is used to represent a recursive CTE type RecurseCTE struct { - // Name is the name of the recursive CTE - Name string + Init, Tail Operator + + // Def is the CTE definition according to the semantics + Def *semantics.CTE // ColumnNames is the list of column names that are sent between the two parts of the recursive CTE ColumnNames []string @@ -34,24 +41,31 @@ type RecurseCTE struct { // ColumnOffsets is the list of column offsets that are sent between the two parts of the recursive CTE Offsets []int - Init, Tail Operator + // Expressions are the expressions that are needed on the recurse side of the CTE + Expressions []*plancontext.RecurseExpression + + // Vars is the map of variables that are sent between the two parts of the recursive CTE + // It's filled in at offset planning time + Vars map[string]int } var _ Operator = (*RecurseCTE)(nil) -func newRecurse(name string, init, tail Operator, expressions []*plancontext.RecurseExpression) *RecurseCTE { 
+func newRecurse(def *semantics.CTE, init, tail Operator, expressions []*plancontext.RecurseExpression) *RecurseCTE { return &RecurseCTE{ - Name: name, - Init: init, - Tail: tail, + Def: def, + Init: init, + Tail: tail, + Expressions: expressions, } } func (r *RecurseCTE) Clone(inputs []Operator) Operator { return &RecurseCTE{ - Name: r.Name, + Def: r.Def, ColumnNames: slices.Clone(r.ColumnNames), Offsets: slices.Clone(r.Offsets), + Expressions: slices.Clone(r.Expressions), Init: inputs[0], Tail: inputs[1], } @@ -91,9 +105,41 @@ func (r *RecurseCTE) GetSelectExprs(ctx *plancontext.PlanningContext) sqlparser. return r.Init.GetSelectExprs(ctx) } -func (r *RecurseCTE) ShortDescription() string { return "" } +func (r *RecurseCTE) ShortDescription() string { + if len(r.Vars) > 0 { + return fmt.Sprintf("%v", r.Vars) + } + exprs := slice.Map(r.Expressions, func(expr *plancontext.RecurseExpression) string { + return sqlparser.String(expr.Original) + }) + return fmt.Sprintf("%v %v", r.Def.Name, strings.Join(exprs, ", ")) +} func (r *RecurseCTE) GetOrdering(*plancontext.PlanningContext) []OrderBy { // RecurseCTE is a special case. It never guarantees any ordering. 
return nil } + +func (r *RecurseCTE) planOffsets(ctx *plancontext.PlanningContext) Operator { + r.Vars = make(map[string]int) + columns := r.Init.GetColumns(ctx) + for _, expr := range r.Expressions { + outer: + for _, lhsExpr := range expr.LeftExprs { + _, found := r.Vars[lhsExpr.Name] + if found { + continue + } + + for offset, column := range columns { + if lhsExpr.Expr.Name.EqualString(column.ColumnName()) { + r.Vars[lhsExpr.Name] = offset + continue outer + } + } + + panic("couldn't find column") + } + } + return r +} diff --git a/go/vt/vtgate/planbuilder/plancontext/planning_context.go b/go/vt/vtgate/planbuilder/plancontext/planning_context.go index 7e2c321eb25..91d2fc80e35 100644 --- a/go/vt/vtgate/planbuilder/plancontext/planning_context.go +++ b/go/vt/vtgate/planbuilder/plancontext/planning_context.go @@ -383,23 +383,27 @@ func (ctx *PlanningContext) ContainsAggr(e sqlparser.SQLNode) (hasAggr bool) { } type ContextCTE struct { - *semantics.CTEDef + *semantics.CTE + Id semantics.TableSet Expressions []*RecurseExpression } type RecurseExpression struct { Original sqlparser.Expr RightExpr sqlparser.Expr - LeftExpr []BindVarExpr + LeftExprs []BindVarExpr } type BindVarExpr struct { Name string - Expr sqlparser.Expr + Expr *sqlparser.ColName } -func (ctx *PlanningContext) PushCTE(def *ContextCTE) { - ctx.CurrentCTE = append(ctx.CurrentCTE, def) +func (ctx *PlanningContext) PushCTE(def *semantics.CTE, id semantics.TableSet) { + ctx.CurrentCTE = append(ctx.CurrentCTE, &ContextCTE{ + CTE: def, + Id: id, + }) } func (ctx *PlanningContext) PopCTE() (*ContextCTE, error) { @@ -410,3 +414,10 @@ func (ctx *PlanningContext) PopCTE() (*ContextCTE, error) { ctx.CurrentCTE = ctx.CurrentCTE[:len(ctx.CurrentCTE)-1] return activeCTE, nil } + +func (ctx *PlanningContext) ActiveCTE() *ContextCTE { + if len(ctx.CurrentCTE) == 0 { + return nil + } + return ctx.CurrentCTE[len(ctx.CurrentCTE)-1] +} diff --git a/go/vt/vtgate/planbuilder/testdata/cte_cases.json 
b/go/vt/vtgate/planbuilder/testdata/cte_cases.json index d6647681103..5b2266a8185 100644 --- a/go/vt/vtgate/planbuilder/testdata/cte_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/cte_cases.json @@ -2105,5 +2105,72 @@ "user.user" ] } + }, + { + "comment": "Recursive CTE that cannot be merged", + "query": "with recursive cte as (select name, id from user where manager_id is null union all select e.name, e.id from user e inner join cte on e.manager_id = cte.id) select name from cte", + "plan": { + "QueryType": "SELECT", + "Original": "with recursive cte as (select name, id from user where manager_id is null union all select e.name, e.id from user e inner join cte on e.manager_id = cte.id) select name from cte", + "Instructions": { + "OperatorType": "SimpleProjection", + "Columns": "2", + "Inputs": [ + { + "OperatorType": "RecurseCTE", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select `name`, id, `name` from `user` where 1 != 1", + "Query": "select `name`, id, `name` from `user` where manager_id is null", + "Table": "`user`" + }, + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select e.`name`, e.id from `user` as e where 1 != 1", + "Query": "select e.`name`, e.id from `user` as e where e.manager_id = :cte_id", + "Table": "`user`, dual" + } + ] + } + ] + }, + "TablesUsed": [ + "main.dual", + "user.user" + ] + } + }, + { + "comment": "Merge into a single dual route", + "query": "WITH RECURSIVE cte AS (SELECT 1 as n UNION ALL SELECT n + 1 FROM cte WHERE n < 5) SELECT n FROM cte", + "plan": { + "QueryType": "SELECT", + "Original": "WITH RECURSIVE cte AS (SELECT 1 as n UNION ALL SELECT n + 1 FROM cte WHERE n < 5) SELECT n FROM cte", + "Instructions": { + "OperatorType": "Route", + "Variant": "Reference", + "Keyspace": { + "Name": "main", + "Sharded": false + }, + "FieldQuery": "with cte as 
(select 1 as n from dual where 1 != 1 union all select n + 1 from cte where 1 != 1) select n from cte where 1 != 1", + "Query": "with cte as (select 1 as n from dual union all select n + 1 from cte where n < 5) select n from cte", + "Table": "dual" + }, + "TablesUsed": [ + "main.dual" + ] + } } ] diff --git a/go/vt/vtgate/planbuilder/testdata/from_cases.json b/go/vt/vtgate/planbuilder/testdata/from_cases.json index a13e1b4c69a..31ec3ea9b6c 100644 --- a/go/vt/vtgate/planbuilder/testdata/from_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/from_cases.json @@ -4731,28 +4731,6 @@ ] } }, - { - "comment": "Merge into a single dual route", - "query": "WITH RECURSIVE cte AS (SELECT 1 as n UNION ALL SELECT n + 1 FROM cte WHERE n < 5) SELECT n FROM cte", - "plan": { - "QueryType": "SELECT", - "Original": "WITH RECURSIVE cte AS (SELECT 1 as n UNION ALL SELECT n + 1 FROM cte WHERE n < 5) SELECT n FROM cte", - "Instructions": { - "OperatorType": "Route", - "Variant": "Reference", - "Keyspace": { - "Name": "main", - "Sharded": false - }, - "FieldQuery": "with cte as (select 1 as n from dual where 1 != 1 union all select n + 1 from cte where 1 != 1) select n from cte where 1 != 1", - "Query": "with cte as (select 1 as n from dual union all select n + 1 from cte where n < 5) select n from cte", - "Table": "dual" - }, - "TablesUsed": [ - "main.dual" - ] - } - }, { "comment": "Recursive CTE with star projection", "query": "WITH RECURSIVE cte (n) AS (SELECT 1 UNION ALL SELECT n + 1 FROM cte WHERE n < 5) SELECT * FROM cte", diff --git a/go/vt/vtgate/semantics/cte_table.go b/go/vt/vtgate/semantics/cte_table.go index 85407712d98..92c29970e9c 100644 --- a/go/vt/vtgate/semantics/cte_table.go +++ b/go/vt/vtgate/semantics/cte_table.go @@ -33,12 +33,12 @@ import ( type CTETable struct { TableName string ASTNode *sqlparser.AliasedTableExpr - *CTEDef + *CTE } var _ TableInfo = (*CTETable)(nil) -func newCTETable(node *sqlparser.AliasedTableExpr, t sqlparser.TableName, cteDef *CTEDef) 
*CTETable { +func newCTETable(node *sqlparser.AliasedTableExpr, t sqlparser.TableName, cteDef *CTE) *CTETable { var name string if node.As.IsEmpty() { name = t.Name.String() @@ -59,7 +59,7 @@ func newCTETable(node *sqlparser.AliasedTableExpr, t sqlparser.TableName, cteDef return &CTETable{ TableName: name, ASTNode: node, - CTEDef: cteDef, + CTE: cteDef, } } @@ -134,7 +134,7 @@ func (cte *CTETable) getTableSet(org originable) TableSet { return org.tableSetFor(cte.ASTNode) } -type CTEDef struct { +type CTE struct { Name string Query sqlparser.SelectStatement isAuthoritative bool @@ -144,9 +144,12 @@ type CTEDef struct { // Was this CTE marked for being recursive? Recursive bool + + // The CTE had the init and recursive parts merged + Merged bool } -func (cte *CTEDef) recursive(org originable) (id TableSet) { +func (cte *CTE) recursive(org originable) (id TableSet) { if cte.recursiveDeps != nil { return *cte.recursiveDeps } diff --git a/go/vt/vtgate/semantics/real_table.go b/go/vt/vtgate/semantics/real_table.go index 13eaf3a6da1..88b5b8725ae 100644 --- a/go/vt/vtgate/semantics/real_table.go +++ b/go/vt/vtgate/semantics/real_table.go @@ -33,7 +33,7 @@ type RealTable struct { dbName, tableName string ASTNode *sqlparser.AliasedTableExpr Table *vindexes.Table - CTE *CTEDef + CTE *CTE VindexHint *sqlparser.IndexHint isInfSchema bool collationEnv *collations.Environment diff --git a/go/vt/vtgate/semantics/table_collector.go b/go/vt/vtgate/semantics/table_collector.go index 7c9efd79513..16285307846 100644 --- a/go/vt/vtgate/semantics/table_collector.go +++ b/go/vt/vtgate/semantics/table_collector.go @@ -45,7 +45,7 @@ type ( done map[*sqlparser.AliasedTableExpr]TableInfo // cte is a map of CTE definitions that are used in the query - cte map[string]*CTEDef + cte map[string]*CTE } ) @@ -54,7 +54,7 @@ func newEarlyTableCollector(si SchemaInformation, currentDb string) *earlyTableC si: si, currentDb: currentDb, done: map[*sqlparser.AliasedTableExpr]TableInfo{}, - cte: 
map[string]*CTEDef{}, + cte: map[string]*CTE{}, } } @@ -64,7 +64,7 @@ func (etc *earlyTableCollector) down(cursor *sqlparser.Cursor) bool { return true } for _, cte := range with.CTEs { - etc.cte[cte.ID.String()] = &CTEDef{ + etc.cte[cte.ID.String()] = &CTE{ Name: cte.ID.String(), Query: cte.Subquery, Columns: cte.Columns, @@ -326,7 +326,7 @@ func (tc *tableCollector) handleTableName(node *sqlparser.AliasedTableExpr, t sq return scope.addTable(tableInfo) } -func (etc *earlyTableCollector) getCTE(t sqlparser.TableName) *CTEDef { +func (etc *earlyTableCollector) getCTE(t sqlparser.TableName) *CTE { if t.Qualifier.NotEmpty() { return nil } @@ -367,7 +367,7 @@ func (etc *earlyTableCollector) getTableInfo(node *sqlparser.AliasedTableExpr, t return tableInfo, nil } -func (etc *earlyTableCollector) buildRecursiveCTE(node *sqlparser.AliasedTableExpr, t sqlparser.TableName, sc *scoper, cteDef *CTEDef) (TableInfo, error) { +func (etc *earlyTableCollector) buildRecursiveCTE(node *sqlparser.AliasedTableExpr, t sqlparser.TableName, sc *scoper, cteDef *CTE) (TableInfo, error) { // If sc is nil, then we are in the early table collector. // In early table collector, we don't go over the CTE definitions, so we must be seeing a usage of the CTE. 
if sc != nil && len(sc.commonTableExprScopes) > 0 { @@ -395,7 +395,7 @@ func (etc *earlyTableCollector) buildRecursiveCTE(node *sqlparser.AliasedTableEx }, nil } -func checkValidRecursiveCTE(cteDef *CTEDef) error { +func checkValidRecursiveCTE(cteDef *CTE) error { if cteDef.IDForRecurse != nil { return vterrors.VT09029(cteDef.Name) } From adedb5bc4b882b80ee30098d93add8f35687c0e5 Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Wed, 31 Jul 2024 11:02:06 +0200 Subject: [PATCH 09/28] change terminology for recursice ctes to seed/term Signed-off-by: Andres Taylor --- go/vt/vtgate/engine/cached_size.go | 4 +- go/vt/vtgate/engine/recurse_cte.go | 28 ++++++------- go/vt/vtgate/engine/recurse_cte_test.go | 6 +-- .../planbuilder/operator_transformers.go | 10 ++--- .../planbuilder/operators/SQL_builder.go | 4 +- .../planbuilder/operators/cte_merging.go | 31 +++++++-------- .../planbuilder/operators/recurse_cte.go | 39 ++++++++----------- go/vt/vtgate/semantics/cte_table.go | 2 +- 8 files changed, 58 insertions(+), 66 deletions(-) diff --git a/go/vt/vtgate/engine/cached_size.go b/go/vt/vtgate/engine/cached_size.go index 8c1fcc2f425..4e6b998a222 100644 --- a/go/vt/vtgate/engine/cached_size.go +++ b/go/vt/vtgate/engine/cached_size.go @@ -868,11 +868,11 @@ func (cached *RecurseCTE) CachedSize(alloc bool) int64 { size += int64(48) } // field Init vitess.io/vitess/go/vt/vtgate/engine.Primitive - if cc, ok := cached.Init.(cachedObject); ok { + if cc, ok := cached.Seed.(cachedObject); ok { size += cc.CachedSize(true) } // field Recurse vitess.io/vitess/go/vt/vtgate/engine.Primitive - if cc, ok := cached.Recurse.(cachedObject); ok { + if cc, ok := cached.Term.(cachedObject); ok { size += cc.CachedSize(true) } // field Vars map[string]int diff --git a/go/vt/vtgate/engine/recurse_cte.go b/go/vt/vtgate/engine/recurse_cte.go index f8745a3036e..4cb45168919 100644 --- a/go/vt/vtgate/engine/recurse_cte.go +++ b/go/vt/vtgate/engine/recurse_cte.go @@ -24,12 +24,12 @@ import ( ) // 
RecurseCTE is used to represent recursive CTEs -// Init is used to represent the initial query. -// It's result are then used to start the recursion on the RecurseCTE side -// The values being sent to the RecurseCTE side are stored in the Vars map - -// the key is the bindvar name and the value is the index of the column in the recursive result +// Seed is used to represent the non-recursive part that initializes the result set. +// Its results are then used to start the recursion on the Term side +// The values being sent to the Term side are stored in the Vars map - // the key is the bindvar name and the value is the index of the column in the recursive result type RecurseCTE struct { - Init, Recurse Primitive + Seed, Term Primitive Vars map[string]int } @@ -37,7 +37,7 @@ type RecurseCTE struct { var _ Primitive = (*RecurseCTE)(nil) func (r *RecurseCTE) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) { - res, err := vcursor.ExecutePrimitive(ctx, r.Init, bindVars, wantfields) + res, err := vcursor.ExecutePrimitive(ctx, r.Seed, bindVars, wantfields) if err != nil { return nil, err } @@ -53,7 +53,7 @@ func (r *RecurseCTE) TryExecute(ctx context.Context, vcursor VCursor, bindVars m for k, col := range r.Vars { joinVars[k] = sqltypes.ValueBindVariable(row[col]) } - rresult, err := vcursor.ExecutePrimitive(ctx, r.Recurse, combineVars(bindVars, joinVars), false) + rresult, err := vcursor.ExecutePrimitive(ctx, r.Term, combineVars(bindVars, joinVars), false) if err != nil { return nil, err } @@ -72,7 +72,7 @@ func (r *RecurseCTE) TryStreamExecute(ctx context.Context, vcursor VCursor, bind } return callback(res) } - return vcursor.StreamExecutePrimitive(ctx, r.Init, bindVars, wantfields, func(result *sqltypes.Result) error { + return vcursor.StreamExecutePrimitive(ctx, r.Seed, bindVars, wantfields, func(result *sqltypes.Result) error { err := callback(result) if err != nil { return err @@ -91,7 +91,7 @@ func (r *RecurseCTE) recurse(ctx context.Context, vcursor VCursor, bindvars map[
joinVars[k] = sqltypes.ValueBindVariable(row[col]) } - err := vcursor.StreamExecutePrimitive(ctx, r.Recurse, combineVars(bindvars, joinVars), false, func(result *sqltypes.Result) error { + err := vcursor.StreamExecutePrimitive(ctx, r.Term, combineVars(bindvars, joinVars), false, func(result *sqltypes.Result) error { err := callback(result) if err != nil { return err @@ -110,18 +110,18 @@ func (r *RecurseCTE) RouteType() string { } func (r *RecurseCTE) GetKeyspaceName() string { - if r.Init.GetKeyspaceName() == r.Recurse.GetKeyspaceName() { - return r.Init.GetKeyspaceName() + if r.Seed.GetKeyspaceName() == r.Term.GetKeyspaceName() { + return r.Seed.GetKeyspaceName() } - return r.Init.GetKeyspaceName() + "_" + r.Recurse.GetKeyspaceName() + return r.Seed.GetKeyspaceName() + "_" + r.Term.GetKeyspaceName() } func (r *RecurseCTE) GetTableName() string { - return r.Init.GetTableName() + return r.Seed.GetTableName() } func (r *RecurseCTE) GetFields(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) { - return r.Init.GetFields(ctx, vcursor, bindVars) + return r.Seed.GetFields(ctx, vcursor, bindVars) } func (r *RecurseCTE) NeedsTransaction() bool { @@ -129,7 +129,7 @@ func (r *RecurseCTE) NeedsTransaction() bool { } func (r *RecurseCTE) Inputs() ([]Primitive, []map[string]any) { - return []Primitive{r.Init, r.Recurse}, nil + return []Primitive{r.Seed, r.Term}, nil } func (r *RecurseCTE) description() PrimitiveDescription { diff --git a/go/vt/vtgate/engine/recurse_cte_test.go b/go/vt/vtgate/engine/recurse_cte_test.go index 674b4f3533d..d6826284d21 100644 --- a/go/vt/vtgate/engine/recurse_cte_test.go +++ b/go/vt/vtgate/engine/recurse_cte_test.go @@ -67,9 +67,9 @@ func TestRecurseDualQuery(t *testing.T) { bv := map[string]*querypb.BindVariable{} cte := &RecurseCTE{ - Init: leftPrim, - Recurse: rightPrim, - Vars: map[string]int{"col1": 0}, + Seed: leftPrim, + Term: rightPrim, + Vars: map[string]int{"col1": 0}, } r, err := 
cte.TryExecute(context.Background(), &noopVCursor{}, bv, true) diff --git a/go/vt/vtgate/planbuilder/operator_transformers.go b/go/vt/vtgate/planbuilder/operator_transformers.go index ec2ef0d0f87..b4aaf6fc64d 100644 --- a/go/vt/vtgate/planbuilder/operator_transformers.go +++ b/go/vt/vtgate/planbuilder/operator_transformers.go @@ -984,18 +984,18 @@ func transformVindexPlan(ctx *plancontext.PlanningContext, op *operators.Vindex) } func transformRecurseCTE(ctx *plancontext.PlanningContext, op *operators.RecurseCTE) (engine.Primitive, error) { - init, err := transformToPrimitive(ctx, op.Init) + seed, err := transformToPrimitive(ctx, op.Seed) if err != nil { return nil, err } - tail, err := transformToPrimitive(ctx, op.Tail) + term, err := transformToPrimitive(ctx, op.Term) if err != nil { return nil, err } return &engine.RecurseCTE{ - Init: init, - Recurse: tail, - Vars: op.Vars, + Seed: seed, + Term: term, + Vars: op.Vars, }, nil } diff --git a/go/vt/vtgate/planbuilder/operators/SQL_builder.go b/go/vt/vtgate/planbuilder/operators/SQL_builder.go index 1da867a20c4..bddce58c620 100644 --- a/go/vt/vtgate/planbuilder/operators/SQL_builder.go +++ b/go/vt/vtgate/planbuilder/operators/SQL_builder.go @@ -695,9 +695,9 @@ func buildHorizon(op *Horizon, qb *queryBuilder) { } func buildCTE(op *RecurseCTE, qb *queryBuilder) { - buildQuery(op.Init, qb) + buildQuery(op.Seed, qb) qbR := &queryBuilder{ctx: qb.ctx} - buildQuery(op.Tail, qbR) + buildQuery(op.Term, qbR) qb.cteWith(qbR, op.Def.Name) } diff --git a/go/vt/vtgate/planbuilder/operators/cte_merging.go b/go/vt/vtgate/planbuilder/operators/cte_merging.go index e32a800c091..4252b2834f5 100644 --- a/go/vt/vtgate/planbuilder/operators/cte_merging.go +++ b/go/vt/vtgate/planbuilder/operators/cte_merging.go @@ -22,7 +22,7 @@ import ( ) func tryMergeRecurse(ctx *plancontext.PlanningContext, in *RecurseCTE) (Operator, *ApplyResult) { - op := tryMergeCTE(ctx, in.Init, in.Tail, in) + op := tryMergeCTE(ctx, in.Seed, in.Term, in) if op == 
nil { return in, NoRewrite } @@ -30,25 +30,25 @@ func tryMergeRecurse(ctx *plancontext.PlanningContext, in *RecurseCTE) (Operator return op, Rewrote("Merged CTE") } -func tryMergeCTE(ctx *plancontext.PlanningContext, init, tail Operator, in *RecurseCTE) *Route { - initRoute, tailRoute, _, routingB, a, b, sameKeyspace := prepareInputRoutes(init, tail) - if initRoute == nil || !sameKeyspace { +func tryMergeCTE(ctx *plancontext.PlanningContext, seed, term Operator, in *RecurseCTE) *Route { + seedRoute, termRoute, _, routingB, a, b, sameKeyspace := prepareInputRoutes(seed, term) + if seedRoute == nil || !sameKeyspace { return nil } switch { case a == dual: - return mergeCTE(initRoute, tailRoute, routingB, in) + return mergeCTE(seedRoute, termRoute, routingB, in) case a == sharded && b == sharded: - return tryMergeCTESharded(ctx, initRoute, tailRoute, in) + return tryMergeCTESharded(ctx, seedRoute, termRoute, in) default: return nil } } -func tryMergeCTESharded(ctx *plancontext.PlanningContext, init, tail *Route, in *RecurseCTE) *Route { - tblA := init.Routing.(*ShardedRouting) - tblB := tail.Routing.(*ShardedRouting) +func tryMergeCTESharded(ctx *plancontext.PlanningContext, seed, term *Route, in *RecurseCTE) *Route { + tblA := seed.Routing.(*ShardedRouting) + tblB := term.Routing.(*ShardedRouting) switch tblA.RouteOpCode { case engine.EqualUnique: // If the two routes fully match, they can be merged together. 
@@ -58,7 +58,7 @@ func tryMergeCTESharded(ctx *plancontext.PlanningContext, init, tail *Route, in aExpr := tblA.VindexExpressions() bExpr := tblB.VindexExpressions() if aVdx == bVdx && gen4ValuesEqual(ctx, aExpr, bExpr) { - return mergeCTE(init, tail, tblA, in) + return mergeCTE(seed, term, tblA, in) } } } @@ -66,16 +66,15 @@ func tryMergeCTESharded(ctx *plancontext.PlanningContext, init, tail *Route, in return nil } -func mergeCTE(init, tail *Route, r Routing, in *RecurseCTE) *Route { +func mergeCTE(seed, term *Route, r Routing, in *RecurseCTE) *Route { in.Def.Merged = true return &Route{ Routing: r, Source: &RecurseCTE{ - Def: in.Def, - ColumnNames: in.ColumnNames, - Init: init.Source, - Tail: tail.Source, + Def: in.Def, + Seed: seed.Source, + Term: term.Source, }, - MergedWith: []*Route{tail}, + MergedWith: []*Route{term}, } } diff --git a/go/vt/vtgate/planbuilder/operators/recurse_cte.go b/go/vt/vtgate/planbuilder/operators/recurse_cte.go index f008a6daa00..4bbcf7e8e00 100644 --- a/go/vt/vtgate/planbuilder/operators/recurse_cte.go +++ b/go/vt/vtgate/planbuilder/operators/recurse_cte.go @@ -30,17 +30,12 @@ import ( // RecurseCTE is used to represent a recursive CTE type RecurseCTE struct { - Init, Tail Operator + Seed, // used to describe the non-recursive part that initializes the result set + Term Operator // the part that repeatedly applies the recursion, processing the result set // Def is the CTE definition according to the semantics Def *semantics.CTE - // ColumnNames is the list of column names that are sent between the two parts of the recursive CTE - ColumnNames []string - - // ColumnOffsets is the list of column offsets that are sent between the two parts of the recursive CTE - Offsets []int - // Expressions are the expressions that are needed on the recurse side of the CTE Expressions []*plancontext.RecurseExpression @@ -51,11 +46,11 @@ type RecurseCTE struct { var _ Operator = (*RecurseCTE)(nil) -func newRecurse(def *semantics.CTE, init, tail 
Operator, expressions []*plancontext.RecurseExpression) *RecurseCTE { +func newRecurse(def *semantics.CTE, seed, term Operator, expressions []*plancontext.RecurseExpression) *RecurseCTE { return &RecurseCTE{ Def: def, - Init: init, - Tail: tail, + Seed: seed, + Term: term, Expressions: expressions, } } @@ -63,30 +58,28 @@ func newRecurse(def *semantics.CTE, init, tail Operator, expressions []*plancont func (r *RecurseCTE) Clone(inputs []Operator) Operator { return &RecurseCTE{ Def: r.Def, - ColumnNames: slices.Clone(r.ColumnNames), - Offsets: slices.Clone(r.Offsets), Expressions: slices.Clone(r.Expressions), - Init: inputs[0], - Tail: inputs[1], + Seed: inputs[0], + Term: inputs[1], } } func (r *RecurseCTE) Inputs() []Operator { - return []Operator{r.Init, r.Tail} + return []Operator{r.Seed, r.Term} } func (r *RecurseCTE) SetInputs(operators []Operator) { - r.Init = operators[0] - r.Tail = operators[1] + r.Seed = operators[0] + r.Term = operators[1] } func (r *RecurseCTE) AddPredicate(_ *plancontext.PlanningContext, e sqlparser.Expr) Operator { - r.Tail = newFilter(r, e) + r.Term = newFilter(r, e) return r } func (r *RecurseCTE) AddColumn(ctx *plancontext.PlanningContext, reuseExisting bool, addToGroupBy bool, expr *sqlparser.AliasedExpr) int { - return r.Init.AddColumn(ctx, reuseExisting, addToGroupBy, expr) + return r.Seed.AddColumn(ctx, reuseExisting, addToGroupBy, expr) } func (r *RecurseCTE) AddWSColumn(*plancontext.PlanningContext, int, bool) int { @@ -94,15 +87,15 @@ func (r *RecurseCTE) AddWSColumn(*plancontext.PlanningContext, int, bool) int { } func (r *RecurseCTE) FindCol(ctx *plancontext.PlanningContext, expr sqlparser.Expr, underRoute bool) int { - return r.Init.FindCol(ctx, expr, underRoute) + return r.Seed.FindCol(ctx, expr, underRoute) } func (r *RecurseCTE) GetColumns(ctx *plancontext.PlanningContext) []*sqlparser.AliasedExpr { - return r.Init.GetColumns(ctx) + return r.Seed.GetColumns(ctx) } func (r *RecurseCTE) GetSelectExprs(ctx 
*plancontext.PlanningContext) sqlparser.SelectExprs { - return r.Init.GetSelectExprs(ctx) + return r.Seed.GetSelectExprs(ctx) } func (r *RecurseCTE) ShortDescription() string { @@ -122,7 +115,7 @@ func (r *RecurseCTE) GetOrdering(*plancontext.PlanningContext) []OrderBy { func (r *RecurseCTE) planOffsets(ctx *plancontext.PlanningContext) Operator { r.Vars = make(map[string]int) - columns := r.Init.GetColumns(ctx) + columns := r.Seed.GetColumns(ctx) for _, expr := range r.Expressions { outer: for _, lhsExpr := range expr.LeftExprs { diff --git a/go/vt/vtgate/semantics/cte_table.go b/go/vt/vtgate/semantics/cte_table.go index 92c29970e9c..2d1df5a0068 100644 --- a/go/vt/vtgate/semantics/cte_table.go +++ b/go/vt/vtgate/semantics/cte_table.go @@ -145,7 +145,7 @@ type CTE struct { // Was this CTE marked for being recursive? Recursive bool - // The CTE had the init and recursive parts merged + // The CTE had the seed and term parts merged Merged bool } From 6c9ca09f32b8a0cab0511b6a7e05ac885453c570 Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Wed, 31 Jul 2024 11:17:25 +0200 Subject: [PATCH 10/28] codegen Signed-off-by: Andres Taylor --- go/vt/vtgate/engine/cached_size.go | 4 ++-- go/vt/vtgate/planbuilder/operators/SQL_builder.go | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/go/vt/vtgate/engine/cached_size.go b/go/vt/vtgate/engine/cached_size.go index 4e6b998a222..06aa9f0d6a9 100644 --- a/go/vt/vtgate/engine/cached_size.go +++ b/go/vt/vtgate/engine/cached_size.go @@ -867,11 +867,11 @@ func (cached *RecurseCTE) CachedSize(alloc bool) int64 { if alloc { size += int64(48) } - // field Init vitess.io/vitess/go/vt/vtgate/engine.Primitive + // field Seed vitess.io/vitess/go/vt/vtgate/engine.Primitive if cc, ok := cached.Seed.(cachedObject); ok { size += cc.CachedSize(true) } - // field Recurse vitess.io/vitess/go/vt/vtgate/engine.Primitive + // field Term vitess.io/vitess/go/vt/vtgate/engine.Primitive if cc, ok := cached.Term.(cachedObject); ok { 
size += cc.CachedSize(true) } diff --git a/go/vt/vtgate/planbuilder/operators/SQL_builder.go b/go/vt/vtgate/planbuilder/operators/SQL_builder.go index bddce58c620..dab19b878aa 100644 --- a/go/vt/vtgate/planbuilder/operators/SQL_builder.go +++ b/go/vt/vtgate/planbuilder/operators/SQL_builder.go @@ -55,6 +55,8 @@ func ToSQL(ctx *plancontext.PlanningContext, op Operator) (_ sqlparser.Statement return q.stmt, q.dmlOperator, nil } +// includeTable will return false if the table is a CTE, and it is not merged +// it will return true if the table is not a CTE or if it is a CTE and it is merged func (qb *queryBuilder) includeTable(op *Table) bool { if qb.ctx.SemTable == nil { return true @@ -540,8 +542,7 @@ func buildLimit(op *Limit, qb *queryBuilder) { } func buildTable(op *Table, qb *queryBuilder) { - toto := qb.includeTable(op) - if !toto { + if !qb.includeTable(op) { return } From c6277bb73a0079caf18523a366867546fa94a773 Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Wed, 31 Jul 2024 11:23:09 +0200 Subject: [PATCH 11/28] panic on the error, no the string Signed-off-by: Andres Taylor --- go/vt/vtgate/planbuilder/operators/SQL_builder.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go/vt/vtgate/planbuilder/operators/SQL_builder.go b/go/vt/vtgate/planbuilder/operators/SQL_builder.go index dab19b878aa..1f42afc2671 100644 --- a/go/vt/vtgate/planbuilder/operators/SQL_builder.go +++ b/go/vt/vtgate/planbuilder/operators/SQL_builder.go @@ -77,7 +77,7 @@ func (qb *queryBuilder) addTable(db, tableName, alias string, tableID semantics. 
if tableID.NumberOfTables() == 1 && qb.ctx.SemTable != nil { tblInfo, err := qb.ctx.SemTable.TableInfoFor(tableID) if err != nil { - panic(err.Error()) + panic(err) } cte, isCTE := tblInfo.(*semantics.CTETable) if isCTE { From 426aa01472e89fd787b37325a6607c845e4c7063 Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Thu, 1 Aug 2024 13:16:10 +0200 Subject: [PATCH 12/28] mark query as single unsharded keyspace when possible Signed-off-by: Andres Taylor --- go/vt/vtgate/semantics/analyzer.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/go/vt/vtgate/semantics/analyzer.go b/go/vt/vtgate/semantics/analyzer.go index 5be67c63436..dd06f770462 100644 --- a/go/vt/vtgate/semantics/analyzer.go +++ b/go/vt/vtgate/semantics/analyzer.go @@ -387,16 +387,16 @@ func (a *analyzer) reAnalyze(statement sqlparser.SQLNode) error { // canShortCut checks if we are dealing with a single unsharded keyspace and no tables that have managed foreign keys // if so, we can stop the analyzer early func (a *analyzer) canShortCut(statement sqlparser.Statement) (canShortCut bool) { - if a.fullAnalysis { - return false - } - ks, _ := singleUnshardedKeyspace(a.earlyTables.Tables) a.singleUnshardedKeyspace = ks != nil if !a.singleUnshardedKeyspace { return false } + if a.fullAnalysis { + return false + } + defer func() { a.canShortcut = canShortCut }() From 94d8fd55a435f8892fd4052f0aeb6f6e9edbe25e Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Thu, 1 Aug 2024 15:25:53 +0200 Subject: [PATCH 13/28] move test Signed-off-by: Andres Taylor --- .../planbuilder/testdata/cte_cases.json | 22 +++++++++++++++++++ .../planbuilder/testdata/from_cases.json | 22 ------------------- 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/go/vt/vtgate/planbuilder/testdata/cte_cases.json b/go/vt/vtgate/planbuilder/testdata/cte_cases.json index 5b2266a8185..8e7157968a8 100644 --- a/go/vt/vtgate/planbuilder/testdata/cte_cases.json +++ 
b/go/vt/vtgate/planbuilder/testdata/cte_cases.json @@ -2172,5 +2172,27 @@ "main.dual" ] } + }, + { + "comment": "Recursive CTE with star projection", + "query": "WITH RECURSIVE cte (n) AS (SELECT 1 UNION ALL SELECT n + 1 FROM cte WHERE n < 5) SELECT * FROM cte", + "plan": { + "QueryType": "SELECT", + "Original": "WITH RECURSIVE cte (n) AS (SELECT 1 UNION ALL SELECT n + 1 FROM cte WHERE n < 5) SELECT * FROM cte", + "Instructions": { + "OperatorType": "Route", + "Variant": "Reference", + "Keyspace": { + "Name": "main", + "Sharded": false + }, + "FieldQuery": "with cte as (select 1 from dual where 1 != 1 union all select n + 1 from cte where 1 != 1) select n from cte where 1 != 1", + "Query": "with cte as (select 1 from dual union all select n + 1 from cte where n < 5) select n from cte", + "Table": "dual" + }, + "TablesUsed": [ + "main.dual" + ] + } } ] diff --git a/go/vt/vtgate/planbuilder/testdata/from_cases.json b/go/vt/vtgate/planbuilder/testdata/from_cases.json index 31ec3ea9b6c..2e0fe429c1f 100644 --- a/go/vt/vtgate/planbuilder/testdata/from_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/from_cases.json @@ -4731,28 +4731,6 @@ ] } }, - { - "comment": "Recursive CTE with star projection", - "query": "WITH RECURSIVE cte (n) AS (SELECT 1 UNION ALL SELECT n + 1 FROM cte WHERE n < 5) SELECT * FROM cte", - "plan": { - "QueryType": "SELECT", - "Original": "WITH RECURSIVE cte (n) AS (SELECT 1 UNION ALL SELECT n + 1 FROM cte WHERE n < 5) SELECT * FROM cte", - "Instructions": { - "OperatorType": "Route", - "Variant": "Reference", - "Keyspace": { - "Name": "main", - "Sharded": false - }, - "FieldQuery": "with cte as (select 1 from dual where 1 != 1 union all select n + 1 from cte where 1 != 1) select n from cte where 1 != 1", - "Query": "with cte as (select 1 from dual union all select n + 1 from cte where n < 5) select n from cte", - "Table": "dual" - }, - "TablesUsed": [ - "main.dual" - ] - } - }, { "comment": "Cross keyspace join", "query": "select 1 from user join 
t1 on user.id = t1.id", From 8fd269ad98d9d6ee3145a1775f5a3ff288054bdb Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Mon, 5 Aug 2024 13:50:30 +0200 Subject: [PATCH 14/28] added end-to-end tests for recursive queries Signed-off-by: Andres Taylor --- .../vtgate/vitess_tester/cte/queries.test | 116 ++++++++++++++++++ 1 file changed, 116 insertions(+) create mode 100644 go/test/endtoend/vtgate/vitess_tester/cte/queries.test diff --git a/go/test/endtoend/vtgate/vitess_tester/cte/queries.test b/go/test/endtoend/vtgate/vitess_tester/cte/queries.test new file mode 100644 index 00000000000..eeab097d615 --- /dev/null +++ b/go/test/endtoend/vtgate/vitess_tester/cte/queries.test @@ -0,0 +1,116 @@ +# Create tables +CREATE TABLE employees +( + id INT PRIMARY KEY, + name VARCHAR(100), + manager_id INT +); + +# Insert data into the tables +INSERT INTO employees (id, name, manager_id) +VALUES (1, 'CEO', NULL), + (2, 'CTO', 1), + (3, 'CFO', 1), + (4, 'Engineer1', 2), + (5, 'Engineer2', 2), + (6, 'Accountant1', 3), + (7, 'Accountant2', 3); + +# Simple recursive CTE using literal values +WITH RECURSIVE numbers AS (SELECT 1 AS n + UNION ALL + SELECT n + 1 + FROM numbers + WHERE n < 5) +SELECT * +FROM numbers; + +# Recursive CTE joined with a normal table +WITH RECURSIVE emp_cte AS (SELECT id, name, manager_id + FROM employees + WHERE manager_id IS NULL + UNION ALL + SELECT e.id, e.name, e.manager_id + FROM employees e + INNER JOIN emp_cte cte ON e.manager_id = cte.id) +SELECT * +FROM emp_cte; + +# Recursive CTE used in a derived table outside the CTE definition +WITH RECURSIVE emp_cte AS (SELECT id, name, manager_id + FROM employees + WHERE manager_id IS NULL + UNION ALL + SELECT e.id, e.name, e.manager_id + FROM employees e + INNER JOIN emp_cte cte ON e.manager_id = cte.id) +SELECT derived.id, derived.name, derived.manager_id +FROM (SELECT * FROM emp_cte) AS derived; + +# Recursive CTE with additional computation +WITH RECURSIVE emp_cte AS (SELECT id, name, manager_id, 1 AS level 
+ FROM employees + WHERE manager_id IS NULL + UNION ALL + SELECT e.id, e.name, e.manager_id, cte.level + 1 + FROM employees e + INNER JOIN emp_cte cte ON e.manager_id = cte.id) +SELECT * +FROM emp_cte; + +# Recursive CTE with filtering in the recursive part +WITH RECURSIVE emp_cte AS (SELECT id, name, manager_id + FROM employees + WHERE manager_id IS NULL + UNION ALL + SELECT e.id, e.name, e.manager_id + FROM employees e + INNER JOIN emp_cte cte ON e.manager_id = cte.id + WHERE e.name LIKE 'Engineer%') +SELECT * +FROM emp_cte; + +# Recursive CTE with limit +WITH RECURSIVE emp_cte AS (SELECT id, name, manager_id + FROM employees + WHERE manager_id IS NULL + UNION ALL + SELECT e.id, e.name, e.manager_id + FROM employees e + INNER JOIN emp_cte cte ON e.manager_id = cte.id) +SELECT * +FROM emp_cte +LIMIT 5; + +# Recursive CTE with DISTINCT to avoid duplicates +WITH RECURSIVE distinct_emp_cte AS (SELECT DISTINCT id, name, manager_id + FROM employees + WHERE manager_id IS NULL + UNION ALL + SELECT DISTINCT e.id, e.name, e.manager_id + FROM employees e + INNER JOIN distinct_emp_cte cte ON e.manager_id = cte.id) +SELECT * +FROM distinct_emp_cte; + +# Recursive CTE with aggregation outside the CTE +WITH RECURSIVE emp_cte AS (SELECT id, name, manager_id + FROM employees + WHERE manager_id IS NULL + UNION ALL + SELECT e.id, e.name, e.manager_id + FROM employees e + INNER JOIN emp_cte cte ON e.manager_id = cte.id) +SELECT manager_id, COUNT(*) AS employee_count +FROM emp_cte +GROUP BY manager_id; + +# Recursive CTE using literal values and joined with a real table on the outside +WITH RECURSIVE literal_cte AS (SELECT 1 AS id, 'Root' AS name, NULL AS manager_id + UNION ALL + SELECT id + 1, CONCAT('Node', id + 1), id + FROM literal_cte + WHERE id < 5) +SELECT l.id, l.name, l.manager_id, e.name AS employee_name +FROM literal_cte l + LEFT JOIN employees e ON l.id = e.id; \ No newline at end of file From 5746dbaf3665f0ebce4b9fecea357d05d9069d3a Mon Sep 17 00:00:00 2001 From: Andres 
Taylor Date: Mon, 5 Aug 2024 15:51:36 +0200 Subject: [PATCH 15/28] also handle CTE on the LHS of joins Signed-off-by: Andres Taylor --- .../planbuilder/operators/SQL_builder.go | 14 +++--- .../planbuilder/testdata/cte_cases.json | 45 +++++++++++++++++++ 2 files changed, 54 insertions(+), 5 deletions(-) diff --git a/go/vt/vtgate/planbuilder/operators/SQL_builder.go b/go/vt/vtgate/planbuilder/operators/SQL_builder.go index 1f42afc2671..b887ca75950 100644 --- a/go/vt/vtgate/planbuilder/operators/SQL_builder.go +++ b/go/vt/vtgate/planbuilder/operators/SQL_builder.go @@ -605,12 +605,16 @@ func buildApplyJoin(op *ApplyJoin, qb *queryBuilder) { qbR := &queryBuilder{ctx: qb.ctx} buildQuery(op.RHS, qbR) - // if we have a recursive cte on the rhs, we might not have a statement - if qbR.stmt == nil { - return - } - qb.joinWith(qbR, pred, op.JoinType) + switch { + // if we have a recursive cte, we might be missing a statement from one of the sides + case qbR.stmt == nil: + // do nothing + case qb.stmt == nil: + qb.stmt = qbR.stmt + default: + qb.joinWith(qbR, pred, op.JoinType) + } } func buildUnion(op *Union, qb *queryBuilder) { diff --git a/go/vt/vtgate/planbuilder/testdata/cte_cases.json b/go/vt/vtgate/planbuilder/testdata/cte_cases.json index 8e7157968a8..6493222fc01 100644 --- a/go/vt/vtgate/planbuilder/testdata/cte_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/cte_cases.json @@ -2151,6 +2151,51 @@ ] } }, + { + "comment": "Recursive CTE that cannot be merged 2", + "query": "with recursive cte as (select name, id from user where manager_id is null union all select e.name, e.id from cte join user e on e.manager_id = cte.id) select name from cte", + "plan": { + "QueryType": "SELECT", + "Original": "with recursive cte as (select name, id from user where manager_id is null union all select e.name, e.id from cte join user e on e.manager_id = cte.id) select name from cte", + "Instructions": { + "OperatorType": "SimpleProjection", + "Columns": "2", + "Inputs": [ + { + 
"OperatorType": "RecurseCTE", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select `name`, id, `name` from `user` where 1 != 1", + "Query": "select `name`, id, `name` from `user` where manager_id is null", + "Table": "`user`" + }, + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select e.`name`, e.id from `user` as e where 1 != 1", + "Query": "select e.`name`, e.id from `user` as e where e.manager_id = :cte_id", + "Table": "`user`, dual" + } + ] + } + ] + }, + "TablesUsed": [ + "main.dual", + "user.user" + ] + } + }, { "comment": "Merge into a single dual route", "query": "WITH RECURSIVE cte AS (SELECT 1 as n UNION ALL SELECT n + 1 FROM cte WHERE n < 5) SELECT n FROM cte", From def7c2809ffb8c4369e176b1735165246fa5b47e Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Mon, 5 Aug 2024 16:03:22 +0200 Subject: [PATCH 16/28] make sure to mark CTE as recursive Signed-off-by: Andres Taylor --- go/vt/vtgate/planbuilder/operators/SQL_builder.go | 1 + 1 file changed, 1 insertion(+) diff --git a/go/vt/vtgate/planbuilder/operators/SQL_builder.go b/go/vt/vtgate/planbuilder/operators/SQL_builder.go index b887ca75950..eabe12aae85 100644 --- a/go/vt/vtgate/planbuilder/operators/SQL_builder.go +++ b/go/vt/vtgate/planbuilder/operators/SQL_builder.go @@ -244,6 +244,7 @@ func (qb *queryBuilder) cteWith(other *queryBuilder, name string) { qb.stmt = &sqlparser.Select{ With: &sqlparser.With{ + Recursive: true, CTEs: []*sqlparser.CommonTableExpr{{ ID: sqlparser.NewIdentifierCS(name), Columns: nil, From c7470c5a03ea0934a95143146fb5e9adf7f1477d Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Mon, 5 Aug 2024 16:47:19 +0200 Subject: [PATCH 17/28] rewrite expressions passing into ctes Signed-off-by: Andres Taylor --- .../vtgate/vitess_tester/cte/queries.test | 62 ++++++++++++++++++- 
.../planbuilder/operators/recurse_cte.go | 33 ++++++++-- .../planbuilder/testdata/cte_cases.json | 26 +++++--- go/vt/vtgate/semantics/cte_table.go | 9 +++ 4 files changed, 115 insertions(+), 15 deletions(-) diff --git a/go/test/endtoend/vtgate/vitess_tester/cte/queries.test b/go/test/endtoend/vtgate/vitess_tester/cte/queries.test index eeab097d615..baa0682a179 100644 --- a/go/test/endtoend/vtgate/vitess_tester/cte/queries.test +++ b/go/test/endtoend/vtgate/vitess_tester/cte/queries.test @@ -113,4 +113,64 @@ WITH RECURSIVE literal_cte AS (SELECT 1 AS id, 'Root' AS name, NULL AS manager_i WHERE id < 5) SELECT l.id, l.name, l.manager_id, e.name AS employee_name FROM literal_cte l - LEFT JOIN employees e ON l.id = e.id; \ No newline at end of file + LEFT JOIN employees e ON l.id = e.id; + +# Recursive CTE for generating a series of numbers +WITH RECURSIVE + number_series AS (SELECT 1 AS n + UNION ALL + SELECT n + 1 + FROM number_series + WHERE n < 5), + +# Recursive CTE that uses the number series + number_names AS (SELECT n, CONCAT('Number', n) AS name + FROM number_series) +SELECT * +FROM number_names; + +# Recursive CTE for generating a series of numbers +WITH RECURSIVE + number_series AS (SELECT 1 AS n + UNION ALL + SELECT n + 1 + FROM number_series + WHERE n < 5), + +# Independent recursive CTE for employees hierarchy + employee_hierarchy AS (SELECT id, name, manager_id + FROM employees + WHERE manager_id IS NULL + UNION ALL + SELECT e.id, e.name, e.manager_id + FROM employees e + INNER JOIN employee_hierarchy eh ON e.manager_id = eh.id) + +# Joining results from both CTEs +SELECT ns.n, ns.n AS number, eh.id, eh.name, eh.manager_id +FROM number_series ns + JOIN employee_hierarchy eh ON ns.n = eh.id; + +# Recursive CTE for generating a series of numbers +WITH RECURSIVE + number_series AS (SELECT 1 AS n + UNION ALL + SELECT n + 1 + FROM number_series + WHERE n < 5), + +# Independent recursive CTE for employees hierarchy + employee_hierarchy AS (SELECT id, name, 
manager_id + FROM employees + WHERE manager_id IS NULL + UNION ALL + SELECT e.id, e.name, e.manager_id + FROM employees e + INNER JOIN employee_hierarchy eh ON e.manager_id = eh.id) + +# Union results from both CTEs +SELECT n AS id, CONCAT('Number', n) AS name, NULL AS manager_id +FROM number_series +UNION +SELECT id, name, manager_id +FROM employee_hierarchy; \ No newline at end of file diff --git a/go/vt/vtgate/planbuilder/operators/recurse_cte.go b/go/vt/vtgate/planbuilder/operators/recurse_cte.go index 4bbcf7e8e00..a089eec2218 100644 --- a/go/vt/vtgate/planbuilder/operators/recurse_cte.go +++ b/go/vt/vtgate/planbuilder/operators/recurse_cte.go @@ -22,10 +22,10 @@ import ( "strings" "vitess.io/vitess/go/slice" - "vitess.io/vitess/go/vt/vtgate/semantics" - "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" + "vitess.io/vitess/go/vt/vtgate/semantics" ) // RecurseCTE is used to represent a recursive CTE @@ -42,6 +42,9 @@ type RecurseCTE struct { // Vars is the map of variables that are sent between the two parts of the recursive CTE // It's filled in at offset planning time Vars map[string]int + + // MyTableID is the id of the CTE + MyTableInfo *semantics.CTETable } var _ Operator = (*RecurseCTE)(nil) @@ -78,8 +81,28 @@ func (r *RecurseCTE) AddPredicate(_ *plancontext.PlanningContext, e sqlparser.Ex return r } -func (r *RecurseCTE) AddColumn(ctx *plancontext.PlanningContext, reuseExisting bool, addToGroupBy bool, expr *sqlparser.AliasedExpr) int { - return r.Seed.AddColumn(ctx, reuseExisting, addToGroupBy, expr) +func (r *RecurseCTE) AddColumn(ctx *plancontext.PlanningContext, _, _ bool, expr *sqlparser.AliasedExpr) int { + r.makeSureWeHaveTableInfo(ctx) + e := semantics.RewriteDerivedTableExpression(expr.Expr, r.MyTableInfo) + return r.Seed.FindCol(ctx, e, false) +} + +func (r *RecurseCTE) makeSureWeHaveTableInfo(ctx *plancontext.PlanningContext) { + if r.MyTableInfo == nil { + for _, 
table := range ctx.SemTable.Tables { + cte, ok := table.(*semantics.CTETable) + if !ok { + continue + } + if cte.CTE == r.Def { + r.MyTableInfo = cte + break + } + } + if r.MyTableInfo == nil { + panic(vterrors.VT13001("CTE not found")) + } + } } func (r *RecurseCTE) AddWSColumn(*plancontext.PlanningContext, int, bool) int { @@ -87,6 +110,8 @@ func (r *RecurseCTE) AddWSColumn(*plancontext.PlanningContext, int, bool) int { } func (r *RecurseCTE) FindCol(ctx *plancontext.PlanningContext, expr sqlparser.Expr, underRoute bool) int { + r.makeSureWeHaveTableInfo(ctx) + expr = semantics.RewriteDerivedTableExpression(expr, r.MyTableInfo) return r.Seed.FindCol(ctx, expr, underRoute) } diff --git a/go/vt/vtgate/planbuilder/testdata/cte_cases.json b/go/vt/vtgate/planbuilder/testdata/cte_cases.json index 6493222fc01..35ef3430325 100644 --- a/go/vt/vtgate/planbuilder/testdata/cte_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/cte_cases.json @@ -2114,7 +2114,10 @@ "Original": "with recursive cte as (select name, id from user where manager_id is null union all select e.name, e.id from user e inner join cte on e.manager_id = cte.id) select name from cte", "Instructions": { "OperatorType": "SimpleProjection", - "Columns": "2", + "ColumnNames": [ + "0:name" + ], + "Columns": "0", "Inputs": [ { "OperatorType": "RecurseCTE", @@ -2126,8 +2129,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select `name`, id, `name` from `user` where 1 != 1", - "Query": "select `name`, id, `name` from `user` where manager_id is null", + "FieldQuery": "select `name`, id from `user` where 1 != 1", + "Query": "select `name`, id from `user` where manager_id is null", "Table": "`user`" }, { @@ -2159,7 +2162,10 @@ "Original": "with recursive cte as (select name, id from user where manager_id is null union all select e.name, e.id from cte join user e on e.manager_id = cte.id) select name from cte", "Instructions": { "OperatorType": "SimpleProjection", - "Columns": "2", + "ColumnNames": [ + 
"0:name" + ], + "Columns": "0", "Inputs": [ { "OperatorType": "RecurseCTE", @@ -2171,8 +2177,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select `name`, id, `name` from `user` where 1 != 1", - "Query": "select `name`, id, `name` from `user` where manager_id is null", + "FieldQuery": "select `name`, id from `user` where 1 != 1", + "Query": "select `name`, id from `user` where manager_id is null", "Table": "`user`" }, { @@ -2209,8 +2215,8 @@ "Name": "main", "Sharded": false }, - "FieldQuery": "with cte as (select 1 as n from dual where 1 != 1 union all select n + 1 from cte where 1 != 1) select n from cte where 1 != 1", - "Query": "with cte as (select 1 as n from dual union all select n + 1 from cte where n < 5) select n from cte", + "FieldQuery": "with recursive cte as (select 1 as n from dual where 1 != 1 union all select n + 1 from cte where 1 != 1) select n from cte where 1 != 1", + "Query": "with recursive cte as (select 1 as n from dual union all select n + 1 from cte where n < 5) select n from cte", "Table": "dual" }, "TablesUsed": [ @@ -2231,8 +2237,8 @@ "Name": "main", "Sharded": false }, - "FieldQuery": "with cte as (select 1 from dual where 1 != 1 union all select n + 1 from cte where 1 != 1) select n from cte where 1 != 1", - "Query": "with cte as (select 1 from dual union all select n + 1 from cte where n < 5) select n from cte", + "FieldQuery": "with recursive cte as (select 1 from dual where 1 != 1 union all select n + 1 from cte where 1 != 1) select n from cte where 1 != 1", + "Query": "with recursive cte as (select 1 from dual union all select n + 1 from cte where n < 5) select n from cte", "Table": "dual" }, "TablesUsed": [ diff --git a/go/vt/vtgate/semantics/cte_table.go b/go/vt/vtgate/semantics/cte_table.go index 2d1df5a0068..320189ff871 100644 --- a/go/vt/vtgate/semantics/cte_table.go +++ b/go/vt/vtgate/semantics/cte_table.go @@ -127,6 +127,15 @@ func (cte *CTETable) dependencies(colName string, org originable) (dependencies, } func 
(cte *CTETable) getExprFor(s string) (sqlparser.Expr, error) { + for _, se := range cte.Query.GetColumns() { + ae, ok := se.(*sqlparser.AliasedExpr) + if !ok { + return nil, vterrors.VT09015() + } + if ae.ColumnName() == s { + return ae.Expr, nil + } + } return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "Unknown column '%s' in 'field list'", s) } From b6e04a5abc68df4452bfa8245a1452937fccebb9 Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Wed, 7 Aug 2024 13:36:45 +0200 Subject: [PATCH 18/28] handle projections on the term using expressions from the seed Signed-off-by: Andres Taylor --- .../vtgate/vitess_tester/cte/queries.test | 12 ++- go/vt/vtgate/engine/recurse_cte.go | 6 ++ .../vtgate/planbuilder/operators/ast_to_op.go | 24 ++++- go/vt/vtgate/planbuilder/operators/join.go | 20 ++-- go/vt/vtgate/planbuilder/operators/phases.go | 92 ++++++++++++++++++- .../planbuilder/operators/recurse_cte.go | 17 +++- .../plancontext/planning_context.go | 4 +- .../vtgate/planbuilder/testdata/onecase.json | 3 +- go/vt/vtgate/semantics/analyzer.go | 5 + 9 files changed, 165 insertions(+), 18 deletions(-) diff --git a/go/test/endtoend/vtgate/vitess_tester/cte/queries.test b/go/test/endtoend/vtgate/vitess_tester/cte/queries.test index baa0682a179..ca4d8b1785e 100644 --- a/go/test/endtoend/vtgate/vitess_tester/cte/queries.test +++ b/go/test/endtoend/vtgate/vitess_tester/cte/queries.test @@ -173,4 +173,14 @@ SELECT n AS id, CONCAT('Number', n) AS name, NULL AS manager_id FROM number_series UNION SELECT id, name, manager_id -FROM employee_hierarchy; \ No newline at end of file +FROM employee_hierarchy; + +WITH RECURSIVE emp_cte AS (SELECT id, 1 AS level + FROM employees + WHERE manager_id IS NULL + UNION ALL + SELECT e.id, cte.level + 1 + FROM employees e + JOIN emp_cte cte ON e.manager_id = cte.id) +SELECT * +FROM emp_cte; \ No newline at end of file diff --git a/go/vt/vtgate/engine/recurse_cte.go b/go/vt/vtgate/engine/recurse_cte.go index 4cb45168919..f50646ad271 100644 --- 
a/go/vt/vtgate/engine/recurse_cte.go +++ b/go/vt/vtgate/engine/recurse_cte.go @@ -133,7 +133,13 @@ func (r *RecurseCTE) Inputs() ([]Primitive, []map[string]any) { } func (r *RecurseCTE) description() PrimitiveDescription { + other := map[string]interface{}{ + "JoinVars": orderedStringIntMap(r.Vars), + } + return PrimitiveDescription{ OperatorType: "RecurseCTE", + Other: other, + Inputs: nil, } } diff --git a/go/vt/vtgate/planbuilder/operators/ast_to_op.go b/go/vt/vtgate/planbuilder/operators/ast_to_op.go index 94c124055f4..a8c60eb9f5c 100644 --- a/go/vt/vtgate/planbuilder/operators/ast_to_op.go +++ b/go/vt/vtgate/planbuilder/operators/ast_to_op.go @@ -320,17 +320,35 @@ func createRecursiveCTE(ctx *plancontext.PlanningContext, def *semantics.CTE) Op panic(vterrors.VT13001("expected UNION in recursive CTE")) } - init := translateQueryToOp(ctx, union.Left) + seed := translateQueryToOp(ctx, union.Left) // Push the CTE definition to the stack so that it can be used in the recursive part of the query ctx.PushCTE(def, *def.IDForRecurse) - tail := translateQueryToOp(ctx, union.Right) + + term := translateQueryToOp(ctx, union.Right) + horizon, ok := term.(*Horizon) + if !ok { + panic(vterrors.VT09027(def.Name)) + } + term = horizon.Source + horizon.Source = nil // not sure about this activeCTE, err := ctx.PopCTE() if err != nil { panic(err) } - return newRecurse(def, init, tail, activeCTE.Expressions) + return newRecurse(def, seed, term, activeCTE.Predicates, horizon, idForRecursiveTable(ctx, def)) +} + +func idForRecursiveTable(ctx *plancontext.PlanningContext, def *semantics.CTE) semantics.TableSet { + for i, table := range ctx.SemTable.Tables { + tbl, ok := table.(*semantics.CTETable) + if !ok || tbl.CTE.Name != def.Name { + continue + } + return semantics.SingleTableSet(i) + } + panic(vterrors.VT13001("recursive table not found")) } func crossJoin(ctx *plancontext.PlanningContext, exprs sqlparser.TableExprs) Operator { diff --git 
a/go/vt/vtgate/planbuilder/operators/join.go b/go/vt/vtgate/planbuilder/operators/join.go index c1a41f94827..35760bceafb 100644 --- a/go/vt/vtgate/planbuilder/operators/join.go +++ b/go/vt/vtgate/planbuilder/operators/join.go @@ -21,6 +21,7 @@ import ( "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" + "vitess.io/vitess/go/vt/vtgate/semantics" ) // Join represents a join. If we have a predicate, this is an inner join. If no predicate exists, it is a cross join @@ -146,7 +147,7 @@ func addJoinPredicates( // if we are inside a CTE, we need to check if we depend on the recursion table if cte := ctx.ActiveCTE(); cte != nil && ctx.SemTable.DirectDeps(pred).IsOverlapping(cte.Id) { original := pred - pred = breakCTEExpressionInLhsAndRhs(ctx, pred, cte) + pred = addCTEPredicate(ctx, pred, cte) ctx.AddJoinPredicates(original, pred) } op = op.AddPredicate(ctx, pred) @@ -154,13 +155,19 @@ func addJoinPredicates( return sqc.getRootOperator(op, nil) } -// breakCTEExpressionInLhsAndRhs breaks the expression into LHS and RHS -func breakCTEExpressionInLhsAndRhs( +// addCTEPredicate breaks the expression into LHS and RHS +func addCTEPredicate( ctx *plancontext.PlanningContext, pred sqlparser.Expr, cte *plancontext.ContextCTE, ) sqlparser.Expr { - col := breakExpressionInLHSandRHS(ctx, pred, cte.Id) + expr := breakCTEExpressionInLhsAndRhs(ctx, pred, cte.Id) + cte.Predicates = append(cte.Predicates, expr) + return expr.RightExpr +} + +func breakCTEExpressionInLhsAndRhs(ctx *plancontext.PlanningContext, pred sqlparser.Expr, lhsID semantics.TableSet) *plancontext.RecurseExpression { + col := breakExpressionInLHSandRHS(ctx, pred, lhsID) lhsExprs := slice.Map(col.LHSExprs, func(bve BindVarExpr) plancontext.BindVarExpr { col, ok := bve.Expr.(*sqlparser.ColName) @@ -172,12 +179,11 @@ func breakCTEExpressionInLhsAndRhs( Expr: col, } }) - cte.Expressions = append(cte.Expressions, &plancontext.RecurseExpression{ + 
return &plancontext.RecurseExpression{ Original: col.Original, RightExpr: col.RHSExpr, LeftExprs: lhsExprs, - }) - return col.RHSExpr + } } func createJoin(ctx *plancontext.PlanningContext, LHS, RHS Operator) Operator { diff --git a/go/vt/vtgate/planbuilder/operators/phases.go b/go/vt/vtgate/planbuilder/operators/phases.go index bf8e96372bc..3623efe12a0 100644 --- a/go/vt/vtgate/planbuilder/operators/phases.go +++ b/go/vt/vtgate/planbuilder/operators/phases.go @@ -207,7 +207,7 @@ func removePerformanceDistinctAboveRoute(_ *plancontext.PlanningContext, op Oper } func enableDelegateAggregation(ctx *plancontext.PlanningContext, op Operator) Operator { - return addColumnsToInput(ctx, op) + return prepareForAggregationPushing(ctx, op) } // addColumnsToInput adds columns needed by an operator to its input. @@ -341,3 +341,93 @@ func addLiteralGroupingToRHS(in *ApplyJoin) (Operator, *ApplyResult) { }) return in, NoRewrite } + +// prepareForAggregationPushing adds columns needed by an operator to its input. +// This happens only when the filter expression can be retrieved as an offset from the underlying mysql. 
+func prepareForAggregationPushing(ctx *plancontext.PlanningContext, root Operator) Operator { + addColumnsNeededByFilter := func(in Operator, _ semantics.TableSet, _ bool) (Operator, *ApplyResult) { + filter, ok := in.(*Filter) + if !ok { + return in, NoRewrite + } + + var neededAggrs []sqlparser.Expr + extractAggrs := func(cursor *sqlparser.CopyOnWriteCursor) { + node := cursor.Node() + if ctx.IsAggr(node) { + neededAggrs = append(neededAggrs, node.(sqlparser.Expr)) + } + } + + for _, expr := range filter.Predicates { + _ = sqlparser.CopyOnRewrite(expr, dontEnterSubqueries, extractAggrs, nil) + } + + if neededAggrs == nil { + return in, NoRewrite + } + + addedCols := false + aggregator := findAggregatorInSource(filter.Source) + for _, aggr := range neededAggrs { + if aggregator.FindCol(ctx, aggr, false) == -1 { + aggregator.addColumnWithoutPushing(ctx, aeWrap(aggr), false) + addedCols = true + } + } + + if addedCols { + return in, Rewrote("added columns because filter needs it") + } + return in, NoRewrite + } + + step1 := TopDown(root, TableID, addColumnsNeededByFilter, stopAtRoute) + pushDownHorizon := func(in Operator, _ semantics.TableSet, _ bool) (Operator, *ApplyResult) { + // These recursive CTEs have not been pushed under a route, so we will have to evaluate it one the vtgate + // That means that we need to turn anything that is coming from the recursion into arguments + rcte, ok := in.(*RecurseCTE) + if !ok { + return in, NoRewrite + } + hz := rcte.Horizon + hz.Source = rcte.Term + newTerm, _ := expandHorizon(ctx, hz) + pr := findProjection(newTerm) + ap, err := pr.GetAliasedProjections() + if err != nil { + panic(vterrors.VT09015()) + } + + // We need to break the expressions into LHS and RHS, and store them in the CTE for later use + expressions := slice.Map(ap, func(p *ProjExpr) *plancontext.RecurseExpression { + recurseExpression := breakCTEExpressionInLhsAndRhs(ctx, p.EvalExpr, rcte.LHSId) + p.EvalExpr = recurseExpression.RightExpr + return 
recurseExpression + }) + rcte.Expressions = append(rcte.Expressions, expressions...) + rcte.Term = newTerm + return rcte, Rewrote("expanded horizon on term side of recursive CTE") + } + + return TopDown(step1, TableID, pushDownHorizon, stopAtRoute) +} + +func findProjection(op Operator) *Projection { + for { + proj, ok := op.(*Projection) + if ok { + return proj + } + inputs := op.Inputs() + if len(inputs) != 1 { + panic(vterrors.VT13001("unexpected multiple inputs")) + } + src := inputs[0] + _, isRoute := src.(*Route) + if isRoute { + panic(vterrors.VT13001("failed to find the projection")) + } + op = src + } +} diff --git a/go/vt/vtgate/planbuilder/operators/recurse_cte.go b/go/vt/vtgate/planbuilder/operators/recurse_cte.go index a089eec2218..3267ff5727f 100644 --- a/go/vt/vtgate/planbuilder/operators/recurse_cte.go +++ b/go/vt/vtgate/planbuilder/operators/recurse_cte.go @@ -36,7 +36,7 @@ type RecurseCTE struct { // Def is the CTE definition according to the semantics Def *semantics.CTE - // Expressions are the expressions that are needed on the recurse side of the CTE + // Expressions are the predicates that are needed on the recurse side of the CTE Expressions []*plancontext.RecurseExpression // Vars is the map of variables that are sent between the two parts of the recursive CTE @@ -45,16 +45,27 @@ type RecurseCTE struct { // MyTableID is the id of the CTE MyTableInfo *semantics.CTETable + + Horizon *Horizon + LHSId semantics.TableSet } var _ Operator = (*RecurseCTE)(nil) -func newRecurse(def *semantics.CTE, seed, term Operator, expressions []*plancontext.RecurseExpression) *RecurseCTE { +func newRecurse( + def *semantics.CTE, + seed, term Operator, + expressions []*plancontext.RecurseExpression, + horizon *Horizon, + id semantics.TableSet, +) *RecurseCTE { return &RecurseCTE{ Def: def, Seed: seed, Term: term, Expressions: expressions, + Horizon: horizon, + LHSId: id, } } @@ -64,6 +75,8 @@ func (r *RecurseCTE) Clone(inputs []Operator) Operator { Expressions: 
slices.Clone(r.Expressions), Seed: inputs[0], Term: inputs[1], + Horizon: r.Horizon, + LHSId: r.LHSId, } } diff --git a/go/vt/vtgate/planbuilder/plancontext/planning_context.go b/go/vt/vtgate/planbuilder/plancontext/planning_context.go index 91d2fc80e35..3fd079e5d58 100644 --- a/go/vt/vtgate/planbuilder/plancontext/planning_context.go +++ b/go/vt/vtgate/planbuilder/plancontext/planning_context.go @@ -384,8 +384,8 @@ func (ctx *PlanningContext) ContainsAggr(e sqlparser.SQLNode) (hasAggr bool) { type ContextCTE struct { *semantics.CTE - Id semantics.TableSet - Expressions []*RecurseExpression + Id semantics.TableSet + Predicates []*RecurseExpression } type RecurseExpression struct { diff --git a/go/vt/vtgate/planbuilder/testdata/onecase.json b/go/vt/vtgate/planbuilder/testdata/onecase.json index da7543f706a..e2e13147b64 100644 --- a/go/vt/vtgate/planbuilder/testdata/onecase.json +++ b/go/vt/vtgate/planbuilder/testdata/onecase.json @@ -1,9 +1,8 @@ [ { "comment": "Add your test case here for debugging and run go test -run=One.", - "query": "", + "query": "WITH RECURSIVE emp_cte AS (SELECT id, 1 AS level FROM user WHERE manager_id IS NULL UNION ALL SELECT e.id, cte.level + 1 FROM user e JOIN emp_cte cte ON e.manager_id = cte.id) SELECT * FROM emp_cte", "plan": { - } } ] \ No newline at end of file diff --git a/go/vt/vtgate/semantics/analyzer.go b/go/vt/vtgate/semantics/analyzer.go index dd06f770462..d5bf53bd29a 100644 --- a/go/vt/vtgate/semantics/analyzer.go +++ b/go/vt/vtgate/semantics/analyzer.go @@ -448,6 +448,11 @@ func (a *analyzer) noteQuerySignature(node sqlparser.SQLNode) { if node.GroupBy != nil { a.sig.Aggregation = true } + case *sqlparser.With: + if node.Recursive { + // TODO: hacky - we should split this into it's own thing + a.sig.Aggregation = true + } case sqlparser.AggrFunc: a.sig.Aggregation = true case *sqlparser.Delete, *sqlparser.Update, *sqlparser.Insert: From 730c3b5cd7f1fa31d001965091a626eb4f6e47af Mon Sep 17 00:00:00 2001 From: Andres Taylor 
Date: Wed, 7 Aug 2024 14:03:55 +0200 Subject: [PATCH 19/28] handle projections on merged CTEs Signed-off-by: Andres Taylor --- .../planbuilder/operators/SQL_builder.go | 10 +++++ .../vtgate/planbuilder/operators/ast_to_op.go | 2 +- .../planbuilder/operators/cte_merging.go | 16 +++++--- go/vt/vtgate/planbuilder/operators/phases.go | 4 +- .../planbuilder/operators/recurse_cte.go | 40 ++++++++++++------- .../vtgate/planbuilder/testdata/onecase.json | 2 +- 6 files changed, 50 insertions(+), 24 deletions(-) diff --git a/go/vt/vtgate/planbuilder/operators/SQL_builder.go b/go/vt/vtgate/planbuilder/operators/SQL_builder.go index eabe12aae85..3cb017d1339 100644 --- a/go/vt/vtgate/planbuilder/operators/SQL_builder.go +++ b/go/vt/vtgate/planbuilder/operators/SQL_builder.go @@ -701,9 +701,19 @@ func buildHorizon(op *Horizon, qb *queryBuilder) { } func buildCTE(op *RecurseCTE, qb *queryBuilder) { + predicates := slice.Map(op.Predicates, func(jc *plancontext.RecurseExpression) sqlparser.Expr { + // since we are adding these join predicates, we need to mark to broken up version (RHSExpr) of it as done + err := qb.ctx.SkipJoinPredicates(jc.Original) + if err != nil { + panic(err) + } + return jc.Original + }) + pred := sqlparser.AndExpressions(predicates...) 
buildQuery(op.Seed, qb) qbR := &queryBuilder{ctx: qb.ctx} buildQuery(op.Term, qbR) + qbR.addPredicate(pred) qb.cteWith(qbR, op.Def.Name) } diff --git a/go/vt/vtgate/planbuilder/operators/ast_to_op.go b/go/vt/vtgate/planbuilder/operators/ast_to_op.go index a8c60eb9f5c..60a501ea434 100644 --- a/go/vt/vtgate/planbuilder/operators/ast_to_op.go +++ b/go/vt/vtgate/planbuilder/operators/ast_to_op.go @@ -337,7 +337,7 @@ func createRecursiveCTE(ctx *plancontext.PlanningContext, def *semantics.CTE) Op panic(err) } - return newRecurse(def, seed, term, activeCTE.Predicates, horizon, idForRecursiveTable(ctx, def)) + return newRecurse(ctx, def, seed, term, activeCTE.Predicates, horizon, idForRecursiveTable(ctx, def)) } func idForRecursiveTable(ctx *plancontext.PlanningContext, def *semantics.CTE) semantics.TableSet { diff --git a/go/vt/vtgate/planbuilder/operators/cte_merging.go b/go/vt/vtgate/planbuilder/operators/cte_merging.go index 4252b2834f5..b392fee67de 100644 --- a/go/vt/vtgate/planbuilder/operators/cte_merging.go +++ b/go/vt/vtgate/planbuilder/operators/cte_merging.go @@ -38,7 +38,7 @@ func tryMergeCTE(ctx *plancontext.PlanningContext, seed, term Operator, in *Recu switch { case a == dual: - return mergeCTE(seedRoute, termRoute, routingB, in) + return mergeCTE(ctx, seedRoute, termRoute, routingB, in) case a == sharded && b == sharded: return tryMergeCTESharded(ctx, seedRoute, termRoute, in) default: @@ -58,7 +58,7 @@ func tryMergeCTESharded(ctx *plancontext.PlanningContext, seed, term *Route, in aExpr := tblA.VindexExpressions() bExpr := tblB.VindexExpressions() if aVdx == bVdx && gen4ValuesEqual(ctx, aExpr, bExpr) { - return mergeCTE(seed, term, tblA, in) + return mergeCTE(ctx, seed, term, tblA, in) } } } @@ -66,14 +66,18 @@ func tryMergeCTESharded(ctx *plancontext.PlanningContext, seed, term *Route, in return nil } -func mergeCTE(seed, term *Route, r Routing, in *RecurseCTE) *Route { +func mergeCTE(ctx *plancontext.PlanningContext, seed, term *Route, r Routing, in 
*RecurseCTE) *Route { in.Def.Merged = true + hz := in.Horizon + hz.Source = term.Source + newTerm, _ := expandHorizon(ctx, hz) return &Route{ Routing: r, Source: &RecurseCTE{ - Def: in.Def, - Seed: seed.Source, - Term: term.Source, + Predicates: in.Predicates, + Def: in.Def, + Seed: seed.Source, + Term: newTerm, }, MergedWith: []*Route{term}, } diff --git a/go/vt/vtgate/planbuilder/operators/phases.go b/go/vt/vtgate/planbuilder/operators/phases.go index 3623efe12a0..2d695ebc66d 100644 --- a/go/vt/vtgate/planbuilder/operators/phases.go +++ b/go/vt/vtgate/planbuilder/operators/phases.go @@ -400,12 +400,12 @@ func prepareForAggregationPushing(ctx *plancontext.PlanningContext, root Operato } // We need to break the expressions into LHS and RHS, and store them in the CTE for later use - expressions := slice.Map(ap, func(p *ProjExpr) *plancontext.RecurseExpression { + projections := slice.Map(ap, func(p *ProjExpr) *plancontext.RecurseExpression { recurseExpression := breakCTEExpressionInLhsAndRhs(ctx, p.EvalExpr, rcte.LHSId) p.EvalExpr = recurseExpression.RightExpr return recurseExpression }) - rcte.Expressions = append(rcte.Expressions, expressions...) 
+ rcte.Projections = projections rcte.Term = newTerm return rcte, Rewrote("expanded horizon on term side of recursive CTE") } diff --git a/go/vt/vtgate/planbuilder/operators/recurse_cte.go b/go/vt/vtgate/planbuilder/operators/recurse_cte.go index 3267ff5727f..c35e51a815c 100644 --- a/go/vt/vtgate/planbuilder/operators/recurse_cte.go +++ b/go/vt/vtgate/planbuilder/operators/recurse_cte.go @@ -18,9 +18,10 @@ package operators import ( "fmt" - "slices" "strings" + "golang.org/x/exp/maps" + "vitess.io/vitess/go/slice" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vterrors" @@ -37,7 +38,8 @@ type RecurseCTE struct { Def *semantics.CTE // Expressions are the predicates that are needed on the recurse side of the CTE - Expressions []*plancontext.RecurseExpression + Predicates []*plancontext.RecurseExpression + Projections []*plancontext.RecurseExpression // Vars is the map of variables that are sent between the two parts of the recursive CTE // It's filled in at offset planning time @@ -53,28 +55,34 @@ type RecurseCTE struct { var _ Operator = (*RecurseCTE)(nil) func newRecurse( + ctx *plancontext.PlanningContext, def *semantics.CTE, seed, term Operator, - expressions []*plancontext.RecurseExpression, + predicates []*plancontext.RecurseExpression, horizon *Horizon, id semantics.TableSet, ) *RecurseCTE { + for _, pred := range predicates { + ctx.AddJoinPredicates(pred.Original, pred.RightExpr) + } return &RecurseCTE{ - Def: def, - Seed: seed, - Term: term, - Expressions: expressions, - Horizon: horizon, - LHSId: id, + Def: def, + Seed: seed, + Term: term, + Predicates: predicates, + Horizon: horizon, + LHSId: id, } } func (r *RecurseCTE) Clone(inputs []Operator) Operator { return &RecurseCTE{ - Def: r.Def, - Expressions: slices.Clone(r.Expressions), Seed: inputs[0], Term: inputs[1], + Def: r.Def, + Predicates: r.Predicates, + Projections: r.Projections, + Vars: maps.Clone(r.Vars), Horizon: r.Horizon, LHSId: r.LHSId, } @@ -140,10 +148,10 @@ func (r *RecurseCTE) 
ShortDescription() string { if len(r.Vars) > 0 { return fmt.Sprintf("%v", r.Vars) } - exprs := slice.Map(r.Expressions, func(expr *plancontext.RecurseExpression) string { + expressions := slice.Map(r.expressions(), func(expr *plancontext.RecurseExpression) string { return sqlparser.String(expr.Original) }) - return fmt.Sprintf("%v %v", r.Def.Name, strings.Join(exprs, ", ")) + return fmt.Sprintf("%v %v", r.Def.Name, strings.Join(expressions, ", ")) } func (r *RecurseCTE) GetOrdering(*plancontext.PlanningContext) []OrderBy { @@ -151,10 +159,14 @@ func (r *RecurseCTE) GetOrdering(*plancontext.PlanningContext) []OrderBy { return nil } +func (r *RecurseCTE) expressions() []*plancontext.RecurseExpression { + return append(r.Predicates, r.Projections...) +} + func (r *RecurseCTE) planOffsets(ctx *plancontext.PlanningContext) Operator { r.Vars = make(map[string]int) columns := r.Seed.GetColumns(ctx) - for _, expr := range r.Expressions { + for _, expr := range r.expressions() { outer: for _, lhsExpr := range expr.LeftExprs { _, found := r.Vars[lhsExpr.Name] diff --git a/go/vt/vtgate/planbuilder/testdata/onecase.json b/go/vt/vtgate/planbuilder/testdata/onecase.json index e2e13147b64..9d653b2f6e9 100644 --- a/go/vt/vtgate/planbuilder/testdata/onecase.json +++ b/go/vt/vtgate/planbuilder/testdata/onecase.json @@ -1,7 +1,7 @@ [ { "comment": "Add your test case here for debugging and run go test -run=One.", - "query": "WITH RECURSIVE emp_cte AS (SELECT id, 1 AS level FROM user WHERE manager_id IS NULL UNION ALL SELECT e.id, cte.level + 1 FROM user e JOIN emp_cte cte ON e.manager_id = cte.id) SELECT * FROM emp_cte", + "query": "", "plan": { } } From a2ef6739b26aa654d979843380cd65e1ac1c045f Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Wed, 7 Aug 2024 14:08:17 +0200 Subject: [PATCH 20/28] add more planner tests Signed-off-by: Andres Taylor --- .../planbuilder/testdata/cte_cases.json | 76 +++++++++++++++++++ 1 file changed, 76 insertions(+) diff --git 
a/go/vt/vtgate/planbuilder/testdata/cte_cases.json b/go/vt/vtgate/planbuilder/testdata/cte_cases.json index 35ef3430325..db890ef4882 100644 --- a/go/vt/vtgate/planbuilder/testdata/cte_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/cte_cases.json @@ -2121,6 +2121,9 @@ "Inputs": [ { "OperatorType": "RecurseCTE", + "JoinVars": { + "cte_id": 1 + }, "Inputs": [ { "OperatorType": "Route", @@ -2169,6 +2172,9 @@ "Inputs": [ { "OperatorType": "RecurseCTE", + "JoinVars": { + "cte_id": 1 + }, "Inputs": [ { "OperatorType": "Route", @@ -2245,5 +2251,75 @@ "main.dual" ] } + }, + { + "comment": "Recursive CTE calculations on the term side - merged", + "query": "WITH RECURSIVE emp_cte AS (SELECT id, 1 AS level FROM user WHERE manager_id IS NULL and id = 6 UNION ALL SELECT e.id, cte.level + 1 FROM user e JOIN emp_cte cte ON e.manager_id = cte.id and e.id = 6) SELECT * FROM emp_cte", + "plan": { + "QueryType": "SELECT", + "Original": "WITH RECURSIVE emp_cte AS (SELECT id, 1 AS level FROM user WHERE manager_id IS NULL and id = 6 UNION ALL SELECT e.id, cte.level + 1 FROM user e JOIN emp_cte cte ON e.manager_id = cte.id and e.id = 6) SELECT * FROM emp_cte", + "Instructions": { + "OperatorType": "Route", + "Variant": "EqualUnique", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "with recursive emp_cte as (select id, 1 as `level` from `user` where 1 != 1 union all select e.id, cte.`level` + 1 from cte as cte, `user` as e where 1 != 1) select id, `level` from emp_cte where 1 != 1", + "Query": "with recursive emp_cte as (select id, 1 as `level` from `user` where manager_id is null and id = 6 union all select e.id, cte.`level` + 1 from cte as cte, `user` as e where e.id = 6 and e.manager_id = cte.id) select id, `level` from emp_cte", + "Table": "`user`, dual", + "Values": [ + "6" + ], + "Vindex": "user_index" + }, + "TablesUsed": [ + "main.dual", + "user.user" + ] + } + }, + { + "comment": "Recursive CTE calculations on the term side - unmerged", + "query": 
"WITH RECURSIVE emp_cte AS (SELECT id, 1 AS level FROM user WHERE manager_id IS NULL UNION ALL SELECT e.id, cte.level + 1 FROM user e JOIN emp_cte cte ON e.manager_id = cte.id) SELECT * FROM emp_cte", + "plan": { + "QueryType": "SELECT", + "Original": "WITH RECURSIVE emp_cte AS (SELECT id, 1 AS level FROM user WHERE manager_id IS NULL UNION ALL SELECT e.id, cte.level + 1 FROM user e JOIN emp_cte cte ON e.manager_id = cte.id) SELECT * FROM emp_cte", + "Instructions": { + "OperatorType": "RecurseCTE", + "JoinVars": { + "cte_id": 0, + "cte_level": 1 + }, + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select id, 1 as `level` from `user` where 1 != 1", + "Query": "select id, 1 as `level` from `user` where manager_id is null", + "Table": "`user`" + }, + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select e.id, :cte_level + 1 as `cte.``level`` + 1` from `user` as e where 1 != 1", + "Query": "select e.id, :cte_level + 1 as `cte.``level`` + 1` from `user` as e where e.manager_id = :cte_id", + "Table": "`user`, dual" + } + ] + }, + "TablesUsed": [ + "main.dual", + "user.user" + ] + } } ] From 2b4088ee3af7f5f54f8e703e44f8d3b6b5df44de Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Thu, 8 Aug 2024 08:52:34 +0200 Subject: [PATCH 21/28] test: add cte end to end tests Signed-off-by: Andres Taylor --- .../vtgate/vitess_tester/cte/queries.test | 99 ++----------------- 1 file changed, 9 insertions(+), 90 deletions(-) diff --git a/go/test/endtoend/vtgate/vitess_tester/cte/queries.test b/go/test/endtoend/vtgate/vitess_tester/cte/queries.test index ca4d8b1785e..f554ecdcd48 100644 --- a/go/test/endtoend/vtgate/vitess_tester/cte/queries.test +++ b/go/test/endtoend/vtgate/vitess_tester/cte/queries.test @@ -82,105 +82,24 @@ SELECT * FROM emp_cte LIMIT 5; -# Recursive CTE with DISTINCT to avoid duplicates -WITH 
RECURSIVE distinct_emp_cte AS (SELECT DISTINCT id, name, manager_id - FROM employees - WHERE manager_id IS NULL - UNION ALL - SELECT DISTINCT e.id, e.name, e.manager_id - FROM employees e - INNER JOIN distinct_emp_cte cte ON e.manager_id = cte.id) -SELECT * -FROM distinct_emp_cte; - -# Recursive CTE with aggregation outside the CTE -WITH RECURSIVE emp_cte AS (SELECT id, name, manager_id - FROM employees - WHERE manager_id IS NULL - UNION ALL - SELECT e.id, e.name, e.manager_id - FROM employees e - INNER JOIN emp_cte cte ON e.manager_id = cte.id) -SELECT manager_id, COUNT(*) AS employee_count -FROM emp_cte -GROUP BY manager_id; - # Recursive CTE using literal values and joined with a real table on the outside -WITH RECURSIVE literal_cte AS (SELECT 1 AS id, 'Root' AS name, NULL AS manager_id +WITH RECURSIVE literal_cte AS (SELECT 1 AS id, 100 AS value, NULL AS manager_id UNION ALL - SELECT id + 1, CONCAT('Node', id + 1), id + SELECT id + 1, value * 2, id FROM literal_cte WHERE id < 5) -SELECT l.id, l.name, l.manager_id, e.name AS employee_name +SELECT l.id, l.value, l.manager_id, e.name AS employee_name FROM literal_cte l LEFT JOIN employees e ON l.id = e.id; -# Recursive CTE for generating a series of numbers -WITH RECURSIVE - number_series AS (SELECT 1 AS n - UNION ALL - SELECT n + 1 - FROM number_series - WHERE n < 5), - -# Recursive CTE that uses the number series - number_names AS (SELECT n, CONCAT('Number', n) AS name - FROM number_series) -SELECT * -FROM number_names; - -# Recursive CTE for generating a series of numbers -WITH RECURSIVE - number_series AS (SELECT 1 AS n - UNION ALL - SELECT n + 1 - FROM number_series - WHERE n < 5), - -# Independent recursive CTE for employees hierarchy - employee_hierarchy AS (SELECT id, name, manager_id - FROM employees - WHERE manager_id IS NULL - UNION ALL - SELECT e.id, e.name, e.manager_id - FROM employees e - INNER JOIN employee_hierarchy eh ON e.manager_id = eh.id) - -# Joining results from both CTEs -SELECT ns.n, ns.n 
AS number, eh.id, eh.name, eh.manager_id -FROM number_series ns - JOIN employee_hierarchy eh ON ns.n = eh.id; - -# Recursive CTE for generating a series of numbers -WITH RECURSIVE - number_series AS (SELECT 1 AS n - UNION ALL - SELECT n + 1 - FROM number_series - WHERE n < 5), - -# Independent recursive CTE for employees hierarchy - employee_hierarchy AS (SELECT id, name, manager_id +# Recursive CTE with aggregation outside the CTE +WITH RECURSIVE emp_cte AS (SELECT id, name, manager_id FROM employees WHERE manager_id IS NULL UNION ALL SELECT e.id, e.name, e.manager_id FROM employees e - INNER JOIN employee_hierarchy eh ON e.manager_id = eh.id) - -# Union results from both CTEs -SELECT n AS id, CONCAT('Number', n) AS name, NULL AS manager_id -FROM number_series -UNION -SELECT id, name, manager_id -FROM employee_hierarchy; - -WITH RECURSIVE emp_cte AS (SELECT id, 1 AS level - FROM employees - WHERE manager_id IS NULL - UNION ALL - SELECT e.id, cte.level + 1 - FROM employees e - JOIN emp_cte cte ON e.manager_id = cte.id) -SELECT * -FROM emp_cte; \ No newline at end of file + INNER JOIN emp_cte cte ON e.manager_id = cte.id) +SELECT manager_id, COUNT(*) AS employee_count +FROM emp_cte +GROUP BY manager_id; \ No newline at end of file From c2071318080296f241ef5f1502df99b99941a692 Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Thu, 8 Aug 2024 11:10:48 +0200 Subject: [PATCH 22/28] handle recursive CTE table IDs correctly Signed-off-by: Andres Taylor --- .gitignore | 1 + .../vtgate/vitess_tester/cte/queries.test | 2 +- .../planbuilder/operators/apply_join.go | 1 - .../vtgate/planbuilder/operators/ast_to_op.go | 6 +-- .../planbuilder/operators/cte_merging.go | 2 + go/vt/vtgate/planbuilder/operators/phases.go | 2 +- .../planbuilder/operators/recurse_cte.go | 20 ++++++-- .../planbuilder/testdata/cte_cases.json | 49 +++++++++++++++++++ 8 files changed, 73 insertions(+), 10 deletions(-) diff --git a/.gitignore b/.gitignore index 70ae13ad32d..e8c441d3bd7 100644 --- 
a/.gitignore +++ b/.gitignore @@ -92,3 +92,4 @@ report # mise files .mise.toml +/errors/ diff --git a/go/test/endtoend/vtgate/vitess_tester/cte/queries.test b/go/test/endtoend/vtgate/vitess_tester/cte/queries.test index f554ecdcd48..8014c4c187d 100644 --- a/go/test/endtoend/vtgate/vitess_tester/cte/queries.test +++ b/go/test/endtoend/vtgate/vitess_tester/cte/queries.test @@ -83,7 +83,7 @@ FROM emp_cte LIMIT 5; # Recursive CTE using literal values and joined with a real table on the outside -WITH RECURSIVE literal_cte AS (SELECT 1 AS id, 100 AS value, NULL AS manager_id +WITH RECURSIVE literal_cte AS (SELECT 1 AS id, 100 AS value, 1 AS manager_id UNION ALL SELECT id + 1, value * 2, id FROM literal_cte diff --git a/go/vt/vtgate/planbuilder/operators/apply_join.go b/go/vt/vtgate/planbuilder/operators/apply_join.go index f7bd5b131b8..4c6baab3729 100644 --- a/go/vt/vtgate/planbuilder/operators/apply_join.go +++ b/go/vt/vtgate/planbuilder/operators/apply_join.go @@ -204,7 +204,6 @@ func (aj *ApplyJoin) getJoinColumnFor(ctx *plancontext.PlanningContext, orig *sq rhs := TableID(aj.RHS) both := lhs.Merge(rhs) deps := ctx.SemTable.RecursiveDeps(e) - switch { case deps.IsSolvedBy(lhs): col.LHSExprs = []BindVarExpr{{Expr: e}} diff --git a/go/vt/vtgate/planbuilder/operators/ast_to_op.go b/go/vt/vtgate/planbuilder/operators/ast_to_op.go index 60a501ea434..4f0ab742935 100644 --- a/go/vt/vtgate/planbuilder/operators/ast_to_op.go +++ b/go/vt/vtgate/planbuilder/operators/ast_to_op.go @@ -270,7 +270,7 @@ func getOperatorFromAliasedTableExpr(ctx *plancontext.PlanningContext, tableExpr return createDualCTETable(ctx, tableID, tableInfo) case *semantics.RealTable: if tableInfo.CTE != nil { - return createRecursiveCTE(ctx, tableInfo.CTE) + return createRecursiveCTE(ctx, tableInfo.CTE, tableID) } qg := newQueryGraph() @@ -314,7 +314,7 @@ func createDualCTETable(ctx *plancontext.PlanningContext, tableID semantics.Tabl return createRouteFromVSchemaTable(ctx, qtbl, vschemaTable, false, nil) } 
-func createRecursiveCTE(ctx *plancontext.PlanningContext, def *semantics.CTE) Operator { +func createRecursiveCTE(ctx *plancontext.PlanningContext, def *semantics.CTE, outerID semantics.TableSet) Operator { union, ok := def.Query.(*sqlparser.Union) if !ok { panic(vterrors.VT13001("expected UNION in recursive CTE")) @@ -337,7 +337,7 @@ func createRecursiveCTE(ctx *plancontext.PlanningContext, def *semantics.CTE) Op panic(err) } - return newRecurse(ctx, def, seed, term, activeCTE.Predicates, horizon, idForRecursiveTable(ctx, def)) + return newRecurse(ctx, def, seed, term, activeCTE.Predicates, horizon, idForRecursiveTable(ctx, def), outerID) } func idForRecursiveTable(ctx *plancontext.PlanningContext, def *semantics.CTE) semantics.TableSet { diff --git a/go/vt/vtgate/planbuilder/operators/cte_merging.go b/go/vt/vtgate/planbuilder/operators/cte_merging.go index b392fee67de..9ca453f39c6 100644 --- a/go/vt/vtgate/planbuilder/operators/cte_merging.go +++ b/go/vt/vtgate/planbuilder/operators/cte_merging.go @@ -78,6 +78,8 @@ func mergeCTE(ctx *plancontext.PlanningContext, seed, term *Route, r Routing, in Def: in.Def, Seed: seed.Source, Term: newTerm, + LeftID: in.LeftID, + OuterID: in.OuterID, }, MergedWith: []*Route{term}, } diff --git a/go/vt/vtgate/planbuilder/operators/phases.go b/go/vt/vtgate/planbuilder/operators/phases.go index 2d695ebc66d..c0ea3df372f 100644 --- a/go/vt/vtgate/planbuilder/operators/phases.go +++ b/go/vt/vtgate/planbuilder/operators/phases.go @@ -401,7 +401,7 @@ func prepareForAggregationPushing(ctx *plancontext.PlanningContext, root Operato // We need to break the expressions into LHS and RHS, and store them in the CTE for later use projections := slice.Map(ap, func(p *ProjExpr) *plancontext.RecurseExpression { - recurseExpression := breakCTEExpressionInLhsAndRhs(ctx, p.EvalExpr, rcte.LHSId) + recurseExpression := breakCTEExpressionInLhsAndRhs(ctx, p.EvalExpr, rcte.LeftID) p.EvalExpr = recurseExpression.RightExpr return recurseExpression }) diff 
--git a/go/vt/vtgate/planbuilder/operators/recurse_cte.go b/go/vt/vtgate/planbuilder/operators/recurse_cte.go index c35e51a815c..f5f7c5ed5b0 100644 --- a/go/vt/vtgate/planbuilder/operators/recurse_cte.go +++ b/go/vt/vtgate/planbuilder/operators/recurse_cte.go @@ -48,8 +48,14 @@ type RecurseCTE struct { // MyTableID is the id of the CTE MyTableInfo *semantics.CTETable + // Horizon is stored here until we either expand it or push it under a route Horizon *Horizon - LHSId semantics.TableSet + + // The LeftID is the id of the left side of the CTE + LeftID, + + // The OuterID is the id for this use of the CTE + OuterID semantics.TableSet } var _ Operator = (*RecurseCTE)(nil) @@ -60,7 +66,7 @@ func newRecurse( seed, term Operator, predicates []*plancontext.RecurseExpression, horizon *Horizon, - id semantics.TableSet, + leftID, outerID semantics.TableSet, ) *RecurseCTE { for _, pred := range predicates { ctx.AddJoinPredicates(pred.Original, pred.RightExpr) @@ -71,7 +77,8 @@ func newRecurse( Term: term, Predicates: predicates, Horizon: horizon, - LHSId: id, + LeftID: leftID, + OuterID: outerID, } } @@ -84,7 +91,8 @@ func (r *RecurseCTE) Clone(inputs []Operator) Operator { Projections: r.Projections, Vars: maps.Clone(r.Vars), Horizon: r.Horizon, - LHSId: r.LHSId, + LeftID: r.LeftID, + OuterID: r.OuterID, } } @@ -186,3 +194,7 @@ func (r *RecurseCTE) planOffsets(ctx *plancontext.PlanningContext) Operator { } return r } + +func (r *RecurseCTE) introducesTableID() semantics.TableSet { + return r.OuterID +} diff --git a/go/vt/vtgate/planbuilder/testdata/cte_cases.json b/go/vt/vtgate/planbuilder/testdata/cte_cases.json index db890ef4882..ad5c2ea7ed3 100644 --- a/go/vt/vtgate/planbuilder/testdata/cte_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/cte_cases.json @@ -2321,5 +2321,54 @@ "user.user" ] } + }, + { + "comment": "Outer join with recursive CTE", + "query": "WITH RECURSIVE literal_cte AS (SELECT 1 AS id, 100 AS value, 1 AS manager_id UNION ALL SELECT id + 1, value * 2, 
id FROM literal_cte WHERE id < 5) SELECT l.id, l.value, l.manager_id, e.name AS employee_name FROM literal_cte l LEFT JOIN user e ON l.id = e.id", + "plan": { + "QueryType": "SELECT", + "Original": "WITH RECURSIVE literal_cte AS (SELECT 1 AS id, 100 AS value, 1 AS manager_id UNION ALL SELECT id + 1, value * 2, id FROM literal_cte WHERE id < 5) SELECT l.id, l.value, l.manager_id, e.name AS employee_name FROM literal_cte l LEFT JOIN user e ON l.id = e.id", + "Instructions": { + "OperatorType": "Join", + "Variant": "LeftJoin", + "JoinColumnIndexes": "L:0,L:1,L:2,R:0", + "JoinVars": { + "l_id": 0 + }, + "TableName": "dual_`user`", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Reference", + "Keyspace": { + "Name": "main", + "Sharded": false + }, + "FieldQuery": "with recursive literal_cte as (select 1 as id, 100 as value, 1 as manager_id from dual where 1 != 1 union all select id + 1, value * 2, id from literal_cte where 1 != 1) select l.id, l.value, l.manager_id from literal_cte where 1 != 1", + "Query": "with recursive literal_cte as (select 1 as id, 100 as value, 1 as manager_id from dual union all select id + 1, value * 2, id from literal_cte where id < 5) select l.id, l.value, l.manager_id from literal_cte", + "Table": "dual" + }, + { + "OperatorType": "Route", + "Variant": "EqualUnique", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select e.`name` as employee_name from `user` as e where 1 != 1", + "Query": "select e.`name` as employee_name from `user` as e where e.id = :l_id", + "Table": "`user`", + "Values": [ + ":l_id" + ], + "Vindex": "user_index" + } + ] + }, + "TablesUsed": [ + "main.dual", + "user.user" + ] + } } ] From 8566ebe2cd5d6c11bc3c7560bd7bb11977403b62 Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Thu, 8 Aug 2024 11:23:19 +0200 Subject: [PATCH 23/28] preserve CTE table alias Signed-off-by: Andres Taylor --- go/vt/vtgate/planbuilder/operators/SQL_builder.go | 11 ++++++++--- 
go/vt/vtgate/planbuilder/testdata/cte_cases.json | 4 ++-- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/go/vt/vtgate/planbuilder/operators/SQL_builder.go b/go/vt/vtgate/planbuilder/operators/SQL_builder.go index 3cb017d1339..092d6328e1d 100644 --- a/go/vt/vtgate/planbuilder/operators/SQL_builder.go +++ b/go/vt/vtgate/planbuilder/operators/SQL_builder.go @@ -236,7 +236,7 @@ func (qb *queryBuilder) unionWith(other *queryBuilder, distinct bool) { } } -func (qb *queryBuilder) cteWith(other *queryBuilder, name string) { +func (qb *queryBuilder) cteWith(other *queryBuilder, name, alias string) { cteUnion := &sqlparser.Union{ Left: qb.stmt.(sqlparser.SelectStatement), Right: other.stmt.(sqlparser.SelectStatement), @@ -253,7 +253,7 @@ func (qb *queryBuilder) cteWith(other *queryBuilder, name string) { }, } - qb.addTable("", name, "", "", nil) + qb.addTable("", name, alias, "", nil) } type FromStatement interface { @@ -714,7 +714,12 @@ func buildCTE(op *RecurseCTE, qb *queryBuilder) { qbR := &queryBuilder{ctx: qb.ctx} buildQuery(op.Term, qbR) qbR.addPredicate(pred) - qb.cteWith(qbR, op.Def.Name) + infoFor, err := qb.ctx.SemTable.TableInfoFor(op.OuterID) + if err != nil { + panic(err) + } + + qb.cteWith(qbR, op.Def.Name, infoFor.GetAliasedTableExpr().As.String()) } func mergeHaving(h1, h2 *sqlparser.Where) *sqlparser.Where { diff --git a/go/vt/vtgate/planbuilder/testdata/cte_cases.json b/go/vt/vtgate/planbuilder/testdata/cte_cases.json index ad5c2ea7ed3..f09c0d236b8 100644 --- a/go/vt/vtgate/planbuilder/testdata/cte_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/cte_cases.json @@ -2344,8 +2344,8 @@ "Name": "main", "Sharded": false }, - "FieldQuery": "with recursive literal_cte as (select 1 as id, 100 as value, 1 as manager_id from dual where 1 != 1 union all select id + 1, value * 2, id from literal_cte where 1 != 1) select l.id, l.value, l.manager_id from literal_cte where 1 != 1", - "Query": "with recursive literal_cte as (select 1 as id, 100 as value, 
1 as manager_id from dual union all select id + 1, value * 2, id from literal_cte where id < 5) select l.id, l.value, l.manager_id from literal_cte", + "FieldQuery": "with recursive literal_cte as (select 1 as id, 100 as value, 1 as manager_id from dual where 1 != 1 union all select id + 1, value * 2, id from literal_cte where 1 != 1) select l.id, l.value, l.manager_id from literal_cte as l where 1 != 1", + "Query": "with recursive literal_cte as (select 1 as id, 100 as value, 1 as manager_id from dual union all select id + 1, value * 2, id from literal_cte where id < 5) select l.id, l.value, l.manager_id from literal_cte as l", "Table": "dual" }, { From 5f2aa277c0463cf7fa528420fcb80490d7e5b76d Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Wed, 14 Aug 2024 13:08:17 +0200 Subject: [PATCH 24/28] changelog Signed-off-by: Andres Taylor --- changelog/21.0/21.0.0/summary.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/changelog/21.0/21.0.0/summary.md b/changelog/21.0/21.0.0/summary.md index a29e2d286ec..f24b2ee87ab 100644 --- a/changelog/21.0/21.0.0/summary.md +++ b/changelog/21.0/21.0.0/summary.md @@ -12,6 +12,7 @@ - **[New VTGate Shutdown Behavior](#new-vtgate-shutdown-behavior)** - **[Tablet Throttler: Multi-Metric support](#tablet-throttler)** - **[Allow Cross Cell Promotion in PRS](#allow-cross-cell)** + - **[Support for recursive CTEs](#recursive-cte)** ## Major Changes @@ -102,3 +103,6 @@ Metrics are assigned a default _scope_, which could be `self` (isolated to the t Up until now if the users wanted to promote a replica in a different cell than the current primary using `PlannedReparentShard`, they had to specify the new primary with the `--new-primary` flag. We have now added a new flag `--allow-cross-cell-promotion` that lets `PlannedReparentShard` choose a primary in a different cell even if no new primary is provided explicitly. + +### <a id="recursive-cte"/>Experimental support for recursive CTEs +We have added experimental support for recursive CTEs in Vitess.
We are marking it as experimental because it is not yet fully tested and may have some limitations. We are looking for feedback from the community to improve this feature. \ No newline at end of file From efda1dfcd5c86246e0ecda70d20780b4c6a01379 Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Wed, 14 Aug 2024 13:20:05 +0200 Subject: [PATCH 25/28] handle ws offsets on recursive ctes Signed-off-by: Andres Taylor --- .../planbuilder/operators/recurse_cte.go | 15 +++- .../planbuilder/testdata/cte_cases.json | 68 +++++++++++++++++++ 2 files changed, 80 insertions(+), 3 deletions(-) diff --git a/go/vt/vtgate/planbuilder/operators/recurse_cte.go b/go/vt/vtgate/planbuilder/operators/recurse_cte.go index f5f7c5ed5b0..173f14a616c 100644 --- a/go/vt/vtgate/planbuilder/operators/recurse_cte.go +++ b/go/vt/vtgate/planbuilder/operators/recurse_cte.go @@ -113,7 +113,11 @@ func (r *RecurseCTE) AddPredicate(_ *plancontext.PlanningContext, e sqlparser.Ex func (r *RecurseCTE) AddColumn(ctx *plancontext.PlanningContext, _, _ bool, expr *sqlparser.AliasedExpr) int { r.makeSureWeHaveTableInfo(ctx) e := semantics.RewriteDerivedTableExpression(expr.Expr, r.MyTableInfo) - return r.Seed.FindCol(ctx, e, false) + offset := r.Seed.FindCol(ctx, e, false) + if offset == -1 { + panic(vterrors.VT13001("CTE column not found")) + } + return offset } func (r *RecurseCTE) makeSureWeHaveTableInfo(ctx *plancontext.PlanningContext) { @@ -134,8 +138,13 @@ func (r *RecurseCTE) makeSureWeHaveTableInfo(ctx *plancontext.PlanningContext) { } } -func (r *RecurseCTE) AddWSColumn(*plancontext.PlanningContext, int, bool) int { - panic("implement me") +func (r *RecurseCTE) AddWSColumn(ctx *plancontext.PlanningContext, offset int, underRoute bool) int { + seed := r.Seed.AddWSColumn(ctx, offset, underRoute) + term := r.Term.AddWSColumn(ctx, offset, underRoute) + if seed != term { + panic(vterrors.VT13001("CTE columns don't match")) + } + return seed } func (r *RecurseCTE) FindCol(ctx *plancontext.PlanningContext, 
expr sqlparser.Expr, underRoute bool) int { diff --git a/go/vt/vtgate/planbuilder/testdata/cte_cases.json b/go/vt/vtgate/planbuilder/testdata/cte_cases.json index f09c0d236b8..35470ce77d0 100644 --- a/go/vt/vtgate/planbuilder/testdata/cte_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/cte_cases.json @@ -2370,5 +2370,73 @@ "user.user" ] } + }, + { + "comment": "Aggregation on the output of a recursive CTE", + "query": "WITH RECURSIVE emp_cte AS (SELECT id, name, manager_id FROM user WHERE manager_id IS NULL UNION ALL SELECT e.id, e.name, e.manager_id FROM user e INNER JOIN emp_cte cte ON e.manager_id = cte.id) SELECT manager_id, COUNT(*) AS employee_count FROM emp_cte GROUP BY manager_id", + "plan": { + "QueryType": "SELECT", + "Original": "WITH RECURSIVE emp_cte AS (SELECT id, name, manager_id FROM user WHERE manager_id IS NULL UNION ALL SELECT e.id, e.name, e.manager_id FROM user e INNER JOIN emp_cte cte ON e.manager_id = cte.id) SELECT manager_id, COUNT(*) AS employee_count FROM emp_cte GROUP BY manager_id", + "Instructions": { + "OperatorType": "Aggregate", + "Variant": "Ordered", + "Aggregates": "count_star(1) AS employee_count", + "GroupBy": "(0|2)", + "ResultColumns": 2, + "Inputs": [ + { + "OperatorType": "Projection", + "Expressions": [ + ":2 as manager_id", + "1 as 1", + "weight_string(:2) as weight_string(manager_id)" + ], + "Inputs": [ + { + "OperatorType": "Sort", + "Variant": "Memory", + "OrderBy": "(2|3) ASC", + "Inputs": [ + { + "OperatorType": "RecurseCTE", + "JoinVars": { + "cte_id": 0 + }, + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select dt.c0 as id, dt.c1 as `name`, dt.c2 as manager_id, weight_string(dt.c2) from (select id, `name`, manager_id from `user` where 1 != 1) as dt(c0, c1, c2) where 1 != 1", + "Query": "select dt.c0 as id, dt.c1 as `name`, dt.c2 as manager_id, weight_string(dt.c2) from (select id, `name`, manager_id from `user` where 
manager_id is null) as dt(c0, c1, c2)", + "Table": "`user`" + }, + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select e.id, e.`name`, e.manager_id, weight_string(e.manager_id) from `user` as e where 1 != 1", + "Query": "select e.id, e.`name`, e.manager_id, weight_string(e.manager_id) from `user` as e where e.manager_id = :cte_id", + "Table": "`user`, dual" + } + ] + } + ] + } + ] + } + ] + }, + "TablesUsed": [ + "main.dual", + "user.user" + ] + } } ] From f8501c02cb2e3962dce5262fe09291fe50154011 Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Wed, 14 Aug 2024 14:49:21 +0200 Subject: [PATCH 26/28] remove unused code Signed-off-by: Andres Taylor --- go/vt/vtgate/planbuilder/operators/phases.go | 44 -------------------- 1 file changed, 44 deletions(-) diff --git a/go/vt/vtgate/planbuilder/operators/phases.go b/go/vt/vtgate/planbuilder/operators/phases.go index c0ea3df372f..33b17f4e83e 100644 --- a/go/vt/vtgate/planbuilder/operators/phases.go +++ b/go/vt/vtgate/planbuilder/operators/phases.go @@ -210,50 +210,6 @@ func enableDelegateAggregation(ctx *plancontext.PlanningContext, op Operator) Op return prepareForAggregationPushing(ctx, op) } -// addColumnsToInput adds columns needed by an operator to its input. -// This happens only when the filter expression can be retrieved as an offset from the underlying mysql. 
-func addColumnsToInput(ctx *plancontext.PlanningContext, root Operator) Operator { - - addColumnsNeededByFilter := func(in Operator, _ semantics.TableSet, _ bool) (Operator, *ApplyResult) { - addedCols := false - filter, ok := in.(*Filter) - if !ok { - return in, NoRewrite - } - - var neededAggrs []sqlparser.Expr - extractAggrs := func(cursor *sqlparser.CopyOnWriteCursor) { - node := cursor.Node() - if ctx.IsAggr(node) { - neededAggrs = append(neededAggrs, node.(sqlparser.Expr)) - } - } - - for _, expr := range filter.Predicates { - _ = sqlparser.CopyOnRewrite(expr, dontEnterSubqueries, extractAggrs, nil) - } - - if neededAggrs == nil { - return in, NoRewrite - } - - aggregator := findAggregatorInSource(filter.Source) - for _, aggr := range neededAggrs { - if aggregator.FindCol(ctx, aggr, false) == -1 { - aggregator.addColumnWithoutPushing(ctx, aeWrap(aggr), false) - addedCols = true - } - } - - if addedCols { - return in, Rewrote("added columns because filter needs it") - } - return in, NoRewrite - } - - return TopDown(root, TableID, addColumnsNeededByFilter, stopAtRoute) -} - // addOrderingForAllAggregations is run we have pushed down Aggregators as far down as possible. 
func addOrderingForAllAggregations(ctx *plancontext.PlanningContext, root Operator) Operator { visitor := func(in Operator, _ semantics.TableSet, isRoot bool) (Operator, *ApplyResult) { From 3aaad014fd233eef5be51df0f225ec6fef52f2e9 Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Thu, 15 Aug 2024 09:38:20 +0200 Subject: [PATCH 27/28] make sure to respect the transaction needs Signed-off-by: Andres Taylor --- go/vt/vtgate/engine/recurse_cte.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go/vt/vtgate/engine/recurse_cte.go b/go/vt/vtgate/engine/recurse_cte.go index f50646ad271..d8e1abce4da 100644 --- a/go/vt/vtgate/engine/recurse_cte.go +++ b/go/vt/vtgate/engine/recurse_cte.go @@ -125,7 +125,7 @@ func (r *RecurseCTE) GetFields(ctx context.Context, vcursor VCursor, bindVars ma } func (r *RecurseCTE) NeedsTransaction() bool { - return false + return r.Seed.NeedsTransaction() || r.Term.NeedsTransaction() } func (r *RecurseCTE) Inputs() ([]Primitive, []map[string]any) { From 36a458d8234ba323cd87b6333809e4cf6d6013be Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Mon, 19 Aug 2024 10:27:09 +0200 Subject: [PATCH 28/28] address review comments Signed-off-by: Andres Taylor --- go/mysql/sqlerror/constants.go | 1 + go/mysql/sqlerror/sql_error.go | 1 + .../vtgate/vitess_tester/cte/queries.test | 7 +++++- go/vt/sqlparser/ast_funcs.go | 2 +- go/vt/vterrors/code.go | 5 +++++ go/vt/vterrors/state.go | 1 + go/vt/vtgate/engine/recurse_cte.go | 10 +++++++++ .../planbuilder/operators/SQL_builder.go | 10 ++++----- go/vt/vtgate/planbuilder/operators/phases.go | 22 +++++++++++++------ .../planbuilder/operators/recurse_cte.go | 2 +- .../plancontext/planning_context.go | 3 +-- go/vt/vtgate/semantics/analyzer.go | 3 +-- go/vt/vtgate/semantics/real_table.go | 20 +++++------------ go/vt/vtgate/semantics/scoper.go | 4 ---- go/vt/vtgate/semantics/semantic_table.go | 13 ++++++----- 15 files changed, 61 insertions(+), 43 deletions(-) diff --git 
a/go/mysql/sqlerror/constants.go b/go/mysql/sqlerror/constants.go index bd4c188af14..a61239ce17b 100644 --- a/go/mysql/sqlerror/constants.go +++ b/go/mysql/sqlerror/constants.go @@ -260,6 +260,7 @@ const ( ERCTERecursiveForbidsAggregation = ErrorCode(3575) ERCTERecursiveForbiddenJoinOrder = ErrorCode(3576) ERCTERecursiveRequiresSingleReference = ErrorCode(3577) + ERCTEMaxRecursionDepth = ErrorCode(3636) ERRegexpStringNotTerminated = ErrorCode(3684) ERRegexpBufferOverflow = ErrorCode(3684) ERRegexpIllegalArgument = ErrorCode(3685) diff --git a/go/mysql/sqlerror/sql_error.go b/go/mysql/sqlerror/sql_error.go index 603456a7ae9..63883760243 100644 --- a/go/mysql/sqlerror/sql_error.go +++ b/go/mysql/sqlerror/sql_error.go @@ -250,6 +250,7 @@ var stateToMysqlCode = map[vterrors.State]mysqlCode{ vterrors.CTERecursiveRequiresUnion: {num: ERCTERecursiveRequiresUnion, state: SSUnknownSQLState}, vterrors.CTERecursiveForbidsAggregation: {num: ERCTERecursiveForbidsAggregation, state: SSUnknownSQLState}, vterrors.CTERecursiveForbiddenJoinOrder: {num: ERCTERecursiveForbiddenJoinOrder, state: SSUnknownSQLState}, + vterrors.CTEMaxRecursionDepth: {num: ERCTEMaxRecursionDepth, state: SSUnknownSQLState}, } func getStateToMySQLState(state vterrors.State) mysqlCode { diff --git a/go/test/endtoend/vtgate/vitess_tester/cte/queries.test b/go/test/endtoend/vtgate/vitess_tester/cte/queries.test index 8014c4c187d..de38a21cd78 100644 --- a/go/test/endtoend/vtgate/vitess_tester/cte/queries.test +++ b/go/test/endtoend/vtgate/vitess_tester/cte/queries.test @@ -102,4 +102,9 @@ WITH RECURSIVE emp_cte AS (SELECT id, name, manager_id INNER JOIN emp_cte cte ON e.manager_id = cte.id) SELECT manager_id, COUNT(*) AS employee_count FROM emp_cte -GROUP BY manager_id; \ No newline at end of file +GROUP BY manager_id; + +--error infinite recursion +with recursive cte as (select 1 as n union all select n+1 from cte) +select * +from cte; \ No newline at end of file diff --git a/go/vt/sqlparser/ast_funcs.go 
b/go/vt/sqlparser/ast_funcs.go index 5e152065622..ae96fe9c1fe 100644 --- a/go/vt/sqlparser/ast_funcs.go +++ b/go/vt/sqlparser/ast_funcs.go @@ -436,7 +436,7 @@ func (node *AliasedTableExpr) TableNameString() string { tableName, ok := node.Expr.(TableName) if !ok { - panic(vterrors.Errorf(vtrpcpb.Code_INTERNAL, "BUG: Derived table should have an alias. This should not be possible")) + panic(vterrors.VT13001("Derived table should have an alias. This should not be possible")) } return tableName.Name.String() diff --git a/go/vt/vterrors/code.go b/go/vt/vterrors/code.go index 0b2c298f17f..31c98cef280 100644 --- a/go/vt/vterrors/code.go +++ b/go/vt/vterrors/code.go @@ -101,6 +101,7 @@ var ( VT09027 = errorWithState("VT09027", vtrpcpb.Code_FAILED_PRECONDITION, CTERecursiveForbidsAggregation, "Recursive Common Table Expression '%s' can contain neither aggregation nor window functions in recursive query block", "") VT09028 = errorWithState("VT09028", vtrpcpb.Code_FAILED_PRECONDITION, CTERecursiveForbiddenJoinOrder, "In recursive query block of Recursive Common Table Expression '%s', the recursive table must neither be in the right argument of a LEFT JOIN, nor be forced to be non-first with join order hints", "") VT09029 = errorWithState("VT09029", vtrpcpb.Code_FAILED_PRECONDITION, CTERecursiveRequiresSingleReference, "In recursive query block of Recursive Common Table Expression %s, the recursive table must be referenced only once, and not in any subquery", "") + VT09030 = errorWithState("VT09030", vtrpcpb.Code_FAILED_PRECONDITION, CTEMaxRecursionDepth, "Recursive query aborted after 1000 iterations.", "") VT10001 = errorWithoutState("VT10001", vtrpcpb.Code_ABORTED, "foreign key constraints are not allowed", "Foreign key constraints are not allowed, see https://vitess.io/blog/2021-06-15-online-ddl-why-no-fk/.") VT10002 = errorWithoutState("VT10002", vtrpcpb.Code_ABORTED, "atomic distributed transaction not allowed: %s", "The distributed transaction cannot be committed. 
A rollback decision is taken.") @@ -187,6 +188,10 @@ var ( VT09022, VT09023, VT09024, + VT09026, + VT09027, + VT09028, + VT09029, VT10001, VT10002, VT12001, diff --git a/go/vt/vterrors/state.go b/go/vt/vterrors/state.go index 1f1c5922c37..528000e9e41 100644 --- a/go/vt/vterrors/state.go +++ b/go/vt/vterrors/state.go @@ -66,6 +66,7 @@ const ( CTERecursiveRequiresUnion CTERecursiveForbidsAggregation CTERecursiveForbiddenJoinOrder + CTEMaxRecursionDepth // not found BadDb diff --git a/go/vt/vtgate/engine/recurse_cte.go b/go/vt/vtgate/engine/recurse_cte.go index d8e1abce4da..f523883d280 100644 --- a/go/vt/vtgate/engine/recurse_cte.go +++ b/go/vt/vtgate/engine/recurse_cte.go @@ -21,6 +21,7 @@ import ( "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" + "vitess.io/vitess/go/vt/vterrors" ) // RecurseCTE is used to represent recursive CTEs @@ -45,6 +46,7 @@ func (r *RecurseCTE) TryExecute(ctx context.Context, vcursor VCursor, bindVars m // recurseRows contains the rows used in the next recursion recurseRows := res.Rows joinVars := make(map[string]*querypb.BindVariable) + loops := 0 for len(recurseRows) > 0 { // copy over the results from the previous recursion theseRows := recurseRows @@ -53,12 +55,20 @@ func (r *RecurseCTE) TryExecute(ctx context.Context, vcursor VCursor, bindVars m for k, col := range r.Vars { joinVars[k] = sqltypes.ValueBindVariable(row[col]) } + // check if the context is done - we might be in a long running recursion + if err := ctx.Err(); err != nil { + return nil, err + } rresult, err := vcursor.ExecutePrimitive(ctx, r.Term, combineVars(bindVars, joinVars), false) if err != nil { return nil, err } recurseRows = append(recurseRows, rresult.Rows...) res.Rows = append(res.Rows, rresult.Rows...) 
+ loops++ + if loops > 1000 { // TODO: This should be controlled with a system variable setting + return nil, vterrors.VT09030("") + } } } return res, nil diff --git a/go/vt/vtgate/planbuilder/operators/SQL_builder.go b/go/vt/vtgate/planbuilder/operators/SQL_builder.go index 092d6328e1d..8cc23c57ae7 100644 --- a/go/vt/vtgate/planbuilder/operators/SQL_builder.go +++ b/go/vt/vtgate/planbuilder/operators/SQL_builder.go @@ -63,7 +63,7 @@ func (qb *queryBuilder) includeTable(op *Table) bool { } tbl, err := qb.ctx.SemTable.TableInfoFor(op.QTable.ID) if err != nil { - return true + panic(err) } cteTbl, isCTE := tbl.(*semantics.CTETable) if !isCTE { @@ -236,7 +236,7 @@ func (qb *queryBuilder) unionWith(other *queryBuilder, distinct bool) { } } -func (qb *queryBuilder) cteWith(other *queryBuilder, name, alias string) { +func (qb *queryBuilder) recursiveCteWith(other *queryBuilder, name, alias string) { cteUnion := &sqlparser.Union{ Left: qb.stmt.(sqlparser.SelectStatement), Right: other.stmt.(sqlparser.SelectStatement), @@ -451,7 +451,7 @@ func buildQuery(op Operator, qb *queryBuilder) { case *Insert: buildDML(op, qb) case *RecurseCTE: - buildCTE(op, qb) + buildRecursiveCTE(op, qb) default: panic(vterrors.VT13001(fmt.Sprintf("unknown operator to convert to SQL: %T", op))) } @@ -700,7 +700,7 @@ func buildHorizon(op *Horizon, qb *queryBuilder) { sqlparser.RemoveKeyspaceInCol(qb.stmt) } -func buildCTE(op *RecurseCTE, qb *queryBuilder) { +func buildRecursiveCTE(op *RecurseCTE, qb *queryBuilder) { predicates := slice.Map(op.Predicates, func(jc *plancontext.RecurseExpression) sqlparser.Expr { // since we are adding these join predicates, we need to mark to broken up version (RHSExpr) of it as done err := qb.ctx.SkipJoinPredicates(jc.Original) @@ -719,7 +719,7 @@ func buildCTE(op *RecurseCTE, qb *queryBuilder) { panic(err) } - qb.cteWith(qbR, op.Def.Name, infoFor.GetAliasedTableExpr().As.String()) + qb.recursiveCteWith(qbR, op.Def.Name, infoFor.GetAliasedTableExpr().As.String()) } 
func mergeHaving(h1, h2 *sqlparser.Where) *sqlparser.Where { diff --git a/go/vt/vtgate/planbuilder/operators/phases.go b/go/vt/vtgate/planbuilder/operators/phases.go index 33b17f4e83e..d5354e9548f 100644 --- a/go/vt/vtgate/planbuilder/operators/phases.go +++ b/go/vt/vtgate/planbuilder/operators/phases.go @@ -36,6 +36,7 @@ const ( initialPlanning pullDistinctFromUnion delegateAggregation + recursiveCTEHorizons addAggrOrdering cleanOutPerfDistinct dmlWithInput @@ -53,6 +54,8 @@ func (p Phase) String() string { return "pull distinct from UNION" case delegateAggregation: return "split aggregation between vtgate and mysql" + case recursiveCTEHorizons: + return "expand recursive CTE horizons" case addAggrOrdering: return "optimize aggregations with ORDER BY" case cleanOutPerfDistinct: @@ -72,6 +75,8 @@ func (p Phase) shouldRun(s semantics.QuerySignature) bool { return s.Union case delegateAggregation: return s.Aggregation + case recursiveCTEHorizons: + return s.RecursiveCTE case addAggrOrdering: return s.Aggregation case cleanOutPerfDistinct: @@ -93,6 +98,8 @@ func (p Phase) act(ctx *plancontext.PlanningContext, op Operator) Operator { return enableDelegateAggregation(ctx, op) case addAggrOrdering: return addOrderingForAllAggregations(ctx, op) + case recursiveCTEHorizons: + return planRecursiveCTEHorizons(ctx, op) case cleanOutPerfDistinct: return removePerformanceDistinctAboveRoute(ctx, op) case subquerySettling: @@ -301,7 +308,7 @@ func addLiteralGroupingToRHS(in *ApplyJoin) (Operator, *ApplyResult) { // prepareForAggregationPushing adds columns needed by an operator to its input. // This happens only when the filter expression can be retrieved as an offset from the underlying mysql. 
func prepareForAggregationPushing(ctx *plancontext.PlanningContext, root Operator) Operator { - addColumnsNeededByFilter := func(in Operator, _ semantics.TableSet, _ bool) (Operator, *ApplyResult) { + return TopDown(root, TableID, func(in Operator, _ semantics.TableSet, _ bool) (Operator, *ApplyResult) { filter, ok := in.(*Filter) if !ok { return in, NoRewrite @@ -336,10 +343,13 @@ func prepareForAggregationPushing(ctx *plancontext.PlanningContext, root Operato return in, Rewrote("added columns because filter needs it") } return in, NoRewrite - } + }, stopAtRoute) +} - step1 := TopDown(root, TableID, addColumnsNeededByFilter, stopAtRoute) - pushDownHorizon := func(in Operator, _ semantics.TableSet, _ bool) (Operator, *ApplyResult) { +// planRecursiveCTEHorizons expands the horizon on the term side of recursive CTEs that were not pushed under a route. +// Such CTEs are evaluated on the vtgate, so anything coming from the recursion is turned into arguments. +func planRecursiveCTEHorizons(ctx *plancontext.PlanningContext, root Operator) Operator { + return TopDown(root, TableID, func(in Operator, _ semantics.TableSet, _ bool) (Operator, *ApplyResult) { // These recursive CTEs have not been pushed under a route, so we will have to evaluate it one the vtgate // That means that we need to turn anything that is coming from the recursion into arguments rcte, ok := in.(*RecurseCTE) @@ -364,9 +374,7 @@ func prepareForAggregationPushing(ctx *plancontext.PlanningContext, root Operato rcte.Projections = projections rcte.Term = newTerm return rcte, Rewrote("expanded horizon on term side of recursive CTE") - } - - return TopDown(step1, TableID, pushDownHorizon, stopAtRoute) + }, stopAtRoute) } func findProjection(op Operator) *Projection { diff --git a/go/vt/vtgate/planbuilder/operators/recurse_cte.go b/go/vt/vtgate/planbuilder/operators/recurse_cte.go index 173f14a616c..7a8c9dcd355 100644 --- a/go/vt/vtgate/planbuilder/operators/recurse_cte.go +++ b/go/vt/vtgate/planbuilder/operators/recurse_cte.go @@ -198,7 +198,7
@@ func (r *RecurseCTE) planOffsets(ctx *plancontext.PlanningContext) Operator { } } - panic("couldn't find column") + panic(vterrors.VT13001("couldn't find column")) } } return r diff --git a/go/vt/vtgate/planbuilder/plancontext/planning_context.go b/go/vt/vtgate/planbuilder/plancontext/planning_context.go index 3fd079e5d58..00ac889c082 100644 --- a/go/vt/vtgate/planbuilder/plancontext/planning_context.go +++ b/go/vt/vtgate/planbuilder/plancontext/planning_context.go @@ -21,7 +21,6 @@ import ( "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" - vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vtgate/engine/opcode" @@ -408,7 +407,7 @@ func (ctx *PlanningContext) PushCTE(def *semantics.CTE, id semantics.TableSet) { func (ctx *PlanningContext) PopCTE() (*ContextCTE, error) { if len(ctx.CurrentCTE) == 0 { - return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "no CTE to pop") + return nil, vterrors.VT13001("no CTE to pop") } activeCTE := ctx.CurrentCTE[len(ctx.CurrentCTE)-1] ctx.CurrentCTE = ctx.CurrentCTE[:len(ctx.CurrentCTE)-1] diff --git a/go/vt/vtgate/semantics/analyzer.go b/go/vt/vtgate/semantics/analyzer.go index d5bf53bd29a..ec42f638629 100644 --- a/go/vt/vtgate/semantics/analyzer.go +++ b/go/vt/vtgate/semantics/analyzer.go @@ -450,8 +450,7 @@ func (a *analyzer) noteQuerySignature(node sqlparser.SQLNode) { } case *sqlparser.With: if node.Recursive { - // TODO: hacky - we should split this into it's own thing - a.sig.Aggregation = true + a.sig.RecursiveCTE = true } case sqlparser.AggrFunc: a.sig.Aggregation = true diff --git a/go/vt/vtgate/semantics/real_table.go b/go/vt/vtgate/semantics/real_table.go index 88b5b8725ae..399395a9edf 100644 --- a/go/vt/vtgate/semantics/real_table.go +++ b/go/vt/vtgate/semantics/real_table.go @@ -127,22 +127,14 @@ func (r *RealTable) getCTEColumns() []ColumnInfo { // Authoritative implements the TableInfo interface func 
(r *RealTable) authoritative() bool { - if r.Table != nil { + switch { + case r.Table != nil: return r.Table.ColumnListAuthoritative + case r.CTE != nil: + return r.CTE.isAuthoritative + default: + return false } - if r.CTE != nil { - if len(r.CTE.Columns) > 0 { - return true - } - for _, se := range r.CTE.Query.GetColumns() { - _, isAe := se.(*sqlparser.AliasedExpr) - if !isAe { - return false - } - } - return true - } - return false } func extractSelectExprsFromCTE(selectExprs sqlparser.SelectExprs) []ColumnInfo { diff --git a/go/vt/vtgate/semantics/scoper.go b/go/vt/vtgate/semantics/scoper.go index b51cedd2338..9d596d9ecd1 100644 --- a/go/vt/vtgate/semantics/scoper.go +++ b/go/vt/vtgate/semantics/scoper.go @@ -17,7 +17,6 @@ limitations under the License. package semantics import ( - "fmt" "reflect" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" @@ -188,9 +187,6 @@ func (s *scoper) enterJoinScope(cursor *sqlparser.Cursor) { func (s *scoper) pushSelectScope(node *sqlparser.Select) { currScope := newScope(s.currentScope()) - if len(s.scopes) > 0 && s.scopes[len(s.scopes)-1] != s.currentScope() { - fmt.Println("BUG: scope counts did not match") - } currScope.stmtScope = true s.push(currScope) diff --git a/go/vt/vtgate/semantics/semantic_table.go b/go/vt/vtgate/semantics/semantic_table.go index 2eda2c5c29f..6738546fe37 100644 --- a/go/vt/vtgate/semantics/semantic_table.go +++ b/go/vt/vtgate/semantics/semantic_table.go @@ -77,12 +77,13 @@ type ( // QuerySignature is used to identify shortcuts in the planning process QuerySignature struct { - Aggregation bool - DML bool - Distinct bool - HashJoin bool - SubQueries bool - Union bool + Aggregation bool + DML bool + Distinct bool + HashJoin bool + SubQueries bool + Union bool + RecursiveCTE bool } // SemTable contains semantic analysis information about the query.