From 06b6f2937409b85f7a8f3b38a73c70dc9588d23c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9s=20Taylor?= Date: Mon, 19 Aug 2024 11:35:47 +0200 Subject: [PATCH] Support recursive CTEs (#16427) Signed-off-by: Manan Gupta Signed-off-by: Andres Taylor --- .gitignore | 1 + changelog/21.0/21.0.0/summary.md | 4 + go/mysql/sqlerror/constants.go | 47 +-- go/mysql/sqlerror/sql_error.go | 153 ++++---- .../vtgate/vitess_tester/cte/queries.test | 110 ++++++ go/vt/sqlparser/ast.go | 2 +- go/vt/sqlparser/ast_clone.go | 2 +- go/vt/sqlparser/ast_copy_on_rewrite.go | 4 +- go/vt/sqlparser/ast_equals.go | 2 +- go/vt/sqlparser/ast_format.go | 2 +- go/vt/sqlparser/ast_format_fast.go | 4 +- go/vt/sqlparser/ast_funcs.go | 14 + go/vt/sqlparser/ast_rewrite.go | 4 +- go/vt/sqlparser/ast_visit.go | 2 +- go/vt/sqlparser/cached_size.go | 8 +- go/vt/sqlparser/sql.go | 2 +- go/vt/sqlparser/sql.y | 2 +- go/vt/vterrors/code.go | 9 + go/vt/vterrors/state.go | 5 + go/vt/vtgate/engine/cached_size.go | 34 ++ go/vt/vtgate/engine/fake_primitive_test.go | 2 +- go/vt/vtgate/engine/recurse_cte.go | 155 ++++++++ go/vt/vtgate/engine/recurse_cte_test.go | 129 +++++++ .../planbuilder/operator_transformers.go | 18 + .../planbuilder/operators/SQL_builder.go | 88 ++++- .../planbuilder/operators/apply_join.go | 1 - .../vtgate/planbuilder/operators/ast_to_op.go | 79 ++++- .../planbuilder/operators/cte_merging.go | 86 +++++ go/vt/vtgate/planbuilder/operators/join.go | 40 +++ go/vt/vtgate/planbuilder/operators/phases.go | 144 +++++--- .../planbuilder/operators/query_planning.go | 3 + .../planbuilder/operators/recurse_cte.go | 209 +++++++++++ .../plancontext/planning_context.go | 44 +++ .../planbuilder/testdata/cte_cases.json | 333 ++++++++++++++++++ .../vtgate/planbuilder/testdata/onecase.json | 1 - .../testdata/unsupported_cases.json | 5 - go/vt/vtgate/semantics/analyzer.go | 13 +- go/vt/vtgate/semantics/analyzer_test.go | 58 ++- go/vt/vtgate/semantics/check_invalid.go | 4 - go/vt/vtgate/semantics/cte_table.go | 
177 ++++++++++ go/vt/vtgate/semantics/derived_table.go | 2 +- go/vt/vtgate/semantics/early_rewriter.go | 6 +- go/vt/vtgate/semantics/early_rewriter_test.go | 3 + go/vt/vtgate/semantics/foreign_keys_test.go | 121 +++---- go/vt/vtgate/semantics/real_table.go | 69 +++- go/vt/vtgate/semantics/scoper.go | 16 +- .../{semantic_state.go => semantic_table.go} | 17 +- ...c_state_test.go => semantic_table_test.go} | 0 go/vt/vtgate/semantics/table_collector.go | 174 ++++++--- 49 files changed, 2075 insertions(+), 333 deletions(-) create mode 100644 go/test/endtoend/vtgate/vitess_tester/cte/queries.test create mode 100644 go/vt/vtgate/engine/recurse_cte.go create mode 100644 go/vt/vtgate/engine/recurse_cte_test.go create mode 100644 go/vt/vtgate/planbuilder/operators/cte_merging.go create mode 100644 go/vt/vtgate/planbuilder/operators/recurse_cte.go create mode 100644 go/vt/vtgate/semantics/cte_table.go rename go/vt/vtgate/semantics/{semantic_state.go => semantic_table.go} (99%) rename go/vt/vtgate/semantics/{semantic_state_test.go => semantic_table_test.go} (100%) diff --git a/.gitignore b/.gitignore index 70ae13ad32d..e8c441d3bd7 100644 --- a/.gitignore +++ b/.gitignore @@ -92,3 +92,4 @@ report # mise files .mise.toml +/errors/ diff --git a/changelog/21.0/21.0.0/summary.md b/changelog/21.0/21.0.0/summary.md index a29e2d286ec..f24b2ee87ab 100644 --- a/changelog/21.0/21.0.0/summary.md +++ b/changelog/21.0/21.0.0/summary.md @@ -12,6 +12,7 @@ - **[New VTGate Shutdown Behavior](#new-vtgate-shutdown-behavior)** - **[Tablet Throttler: Multi-Metric support](#tablet-throttler)** - **[Allow Cross Cell Promotion in PRS](#allow-cross-cell)** + - **[Support for recursive CTEs](#recursive-cte)** ## Major Changes @@ -102,3 +103,6 @@ Metrics are assigned a default _scope_, which could be `self` (isolated to the t Up until now if the users wanted to promote a replica in a different cell than the current primary using `PlannedReparentShard`, they had to specify the new primary with the 
`--new-primary` flag. We have now added a new flag `--allow-cross-cell-promotion` that lets `PlannedReparentShard` choose a primary in a different cell even if no new primary is provided explicitly. + +### <a id="recursive-cte"/>Experimental support for recursive CTEs +We have added experimental support for recursive CTEs in Vitess. We are marking it as experimental because it is not yet fully tested and may have some limitations. We are looking for feedback from the community to improve this feature. \ No newline at end of file diff --git a/go/mysql/sqlerror/constants.go b/go/mysql/sqlerror/constants.go index 15c590b92a8..a61239ce17b 100644 --- a/go/mysql/sqlerror/constants.go +++ b/go/mysql/sqlerror/constants.go @@ -255,27 +255,32 @@ const ( ERJSONValueTooBig = ErrorCode(3150) ERJSONDocumentTooDeep = ErrorCode(3157) - ERLockNowait = ErrorCode(3572) - ERRegexpStringNotTerminated = ErrorCode(3684) - ERRegexpBufferOverflow = ErrorCode(3684) - ERRegexpIllegalArgument = ErrorCode(3685) - ERRegexpIndexOutOfBounds = ErrorCode(3686) - ERRegexpInternal = ErrorCode(3687) - ERRegexpRuleSyntax = ErrorCode(3688) - ERRegexpBadEscapeSequence = ErrorCode(3689) - ERRegexpUnimplemented = ErrorCode(3690) - ERRegexpMismatchParen = ErrorCode(3691) - ERRegexpBadInterval = ErrorCode(3692) - ERRRegexpMaxLtMin = ErrorCode(3693) - ERRegexpInvalidBackRef = ErrorCode(3694) - ERRegexpLookBehindLimit = ErrorCode(3695) - ERRegexpMissingCloseBracket = ErrorCode(3696) - ERRegexpInvalidRange = ErrorCode(3697) - ERRegexpStackOverflow = ErrorCode(3698) - ERRegexpTimeOut = ErrorCode(3699) - ERRegexpPatternTooBig = ErrorCode(3700) - ERRegexpInvalidCaptureGroup = ErrorCode(3887) - ERRegexpInvalidFlag = ErrorCode(3900) + ERLockNowait = ErrorCode(3572) + ERCTERecursiveRequiresUnion = ErrorCode(3573) + ERCTERecursiveForbidsAggregation = ErrorCode(3575) + ERCTERecursiveForbiddenJoinOrder = ErrorCode(3576) + ERCTERecursiveRequiresSingleReference = ErrorCode(3577) + ERCTEMaxRecursionDepth = ErrorCode(3636) + 
ERRegexpStringNotTerminated = ErrorCode(3684) + ERRegexpBufferOverflow = ErrorCode(3684) + ERRegexpIllegalArgument = ErrorCode(3685) + ERRegexpIndexOutOfBounds = ErrorCode(3686) + ERRegexpInternal = ErrorCode(3687) + ERRegexpRuleSyntax = ErrorCode(3688) + ERRegexpBadEscapeSequence = ErrorCode(3689) + ERRegexpUnimplemented = ErrorCode(3690) + ERRegexpMismatchParen = ErrorCode(3691) + ERRegexpBadInterval = ErrorCode(3692) + ERRRegexpMaxLtMin = ErrorCode(3693) + ERRegexpInvalidBackRef = ErrorCode(3694) + ERRegexpLookBehindLimit = ErrorCode(3695) + ERRegexpMissingCloseBracket = ErrorCode(3696) + ERRegexpInvalidRange = ErrorCode(3697) + ERRegexpStackOverflow = ErrorCode(3698) + ERRegexpTimeOut = ErrorCode(3699) + ERRegexpPatternTooBig = ErrorCode(3700) + ERRegexpInvalidCaptureGroup = ErrorCode(3887) + ERRegexpInvalidFlag = ErrorCode(3900) ERCharacterSetMismatch = ErrorCode(3995) diff --git a/go/mysql/sqlerror/sql_error.go b/go/mysql/sqlerror/sql_error.go index eaa49c2c537..63883760243 100644 --- a/go/mysql/sqlerror/sql_error.go +++ b/go/mysql/sqlerror/sql_error.go @@ -172,80 +172,85 @@ type mysqlCode struct { } var stateToMysqlCode = map[vterrors.State]mysqlCode{ - vterrors.Undefined: {num: ERUnknownError, state: SSUnknownSQLState}, - vterrors.AccessDeniedError: {num: ERAccessDeniedError, state: SSAccessDeniedError}, - vterrors.BadDb: {num: ERBadDb, state: SSClientError}, - vterrors.BadFieldError: {num: ERBadFieldError, state: SSBadFieldError}, - vterrors.BadTableError: {num: ERBadTable, state: SSUnknownTable}, - vterrors.CantUseOptionHere: {num: ERCantUseOptionHere, state: SSClientError}, - vterrors.DataOutOfRange: {num: ERDataOutOfRange, state: SSDataOutOfRange}, - vterrors.DbCreateExists: {num: ERDbCreateExists, state: SSUnknownSQLState}, - vterrors.DbDropExists: {num: ERDbDropExists, state: SSUnknownSQLState}, - vterrors.DupFieldName: {num: ERDupFieldName, state: SSDupFieldName}, - vterrors.EmptyQuery: {num: EREmptyQuery, state: SSClientError}, - 
vterrors.IncorrectGlobalLocalVar: {num: ERIncorrectGlobalLocalVar, state: SSUnknownSQLState}, - vterrors.InnodbReadOnly: {num: ERInnodbReadOnly, state: SSUnknownSQLState}, - vterrors.LockOrActiveTransaction: {num: ERLockOrActiveTransaction, state: SSUnknownSQLState}, - vterrors.NoDB: {num: ERNoDb, state: SSNoDB}, - vterrors.NoSuchTable: {num: ERNoSuchTable, state: SSUnknownTable}, - vterrors.NotSupportedYet: {num: ERNotSupportedYet, state: SSClientError}, - vterrors.ForbidSchemaChange: {num: ERForbidSchemaChange, state: SSUnknownSQLState}, - vterrors.MixOfGroupFuncAndFields: {num: ERMixOfGroupFuncAndFields, state: SSClientError}, - vterrors.NetPacketTooLarge: {num: ERNetPacketTooLarge, state: SSNetError}, - vterrors.NonUniqError: {num: ERNonUniq, state: SSConstraintViolation}, - vterrors.NonUniqTable: {num: ERNonUniqTable, state: SSClientError}, - vterrors.NonUpdateableTable: {num: ERNonUpdateableTable, state: SSUnknownSQLState}, - vterrors.QueryInterrupted: {num: ERQueryInterrupted, state: SSQueryInterrupted}, - vterrors.SPDoesNotExist: {num: ERSPDoesNotExist, state: SSClientError}, - vterrors.SyntaxError: {num: ERSyntaxError, state: SSClientError}, - vterrors.UnsupportedPS: {num: ERUnsupportedPS, state: SSUnknownSQLState}, - vterrors.UnknownSystemVariable: {num: ERUnknownSystemVariable, state: SSUnknownSQLState}, - vterrors.UnknownTable: {num: ERUnknownTable, state: SSUnknownTable}, - vterrors.WrongGroupField: {num: ERWrongGroupField, state: SSClientError}, - vterrors.WrongNumberOfColumnsInSelect: {num: ERWrongNumberOfColumnsInSelect, state: SSWrongNumberOfColumns}, - vterrors.WrongTypeForVar: {num: ERWrongTypeForVar, state: SSClientError}, - vterrors.WrongValueForVar: {num: ERWrongValueForVar, state: SSClientError}, - vterrors.WrongValue: {num: ERWrongValue, state: SSUnknownSQLState}, - vterrors.WrongFieldWithGroup: {num: ERWrongFieldWithGroup, state: SSClientError}, - vterrors.ServerNotAvailable: {num: ERServerIsntAvailable, state: SSNetError}, - 
vterrors.CantDoThisInTransaction: {num: ERCantDoThisDuringAnTransaction, state: SSCantDoThisDuringAnTransaction}, - vterrors.RequiresPrimaryKey: {num: ERRequiresPrimaryKey, state: SSClientError}, - vterrors.RowIsReferenced2: {num: ERRowIsReferenced2, state: SSConstraintViolation}, - vterrors.NoReferencedRow2: {num: ErNoReferencedRow2, state: SSConstraintViolation}, - vterrors.NoSuchSession: {num: ERUnknownComError, state: SSNetError}, - vterrors.OperandColumns: {num: EROperandColumns, state: SSWrongNumberOfColumns}, - vterrors.WrongValueCountOnRow: {num: ERWrongValueCountOnRow, state: SSWrongValueCountOnRow}, - vterrors.WrongArguments: {num: ERWrongArguments, state: SSUnknownSQLState}, - vterrors.ViewWrongList: {num: ERViewWrongList, state: SSUnknownSQLState}, - vterrors.UnknownStmtHandler: {num: ERUnknownStmtHandler, state: SSUnknownSQLState}, - vterrors.KeyDoesNotExist: {num: ERKeyDoesNotExist, state: SSClientError}, - vterrors.UnknownTimeZone: {num: ERUnknownTimeZone, state: SSUnknownSQLState}, - vterrors.RegexpStringNotTerminated: {num: ERRegexpStringNotTerminated, state: SSUnknownSQLState}, - vterrors.RegexpBufferOverflow: {num: ERRegexpBufferOverflow, state: SSUnknownSQLState}, - vterrors.RegexpIllegalArgument: {num: ERRegexpIllegalArgument, state: SSUnknownSQLState}, - vterrors.RegexpIndexOutOfBounds: {num: ERRegexpIndexOutOfBounds, state: SSUnknownSQLState}, - vterrors.RegexpInternal: {num: ERRegexpInternal, state: SSUnknownSQLState}, - vterrors.RegexpRuleSyntax: {num: ERRegexpRuleSyntax, state: SSUnknownSQLState}, - vterrors.RegexpBadEscapeSequence: {num: ERRegexpBadEscapeSequence, state: SSUnknownSQLState}, - vterrors.RegexpUnimplemented: {num: ERRegexpUnimplemented, state: SSUnknownSQLState}, - vterrors.RegexpMismatchParen: {num: ERRegexpMismatchParen, state: SSUnknownSQLState}, - vterrors.RegexpBadInterval: {num: ERRegexpBadInterval, state: SSUnknownSQLState}, - vterrors.RegexpMaxLtMin: {num: ERRRegexpMaxLtMin, state: SSUnknownSQLState}, - 
vterrors.RegexpInvalidBackRef: {num: ERRegexpInvalidBackRef, state: SSUnknownSQLState}, - vterrors.RegexpLookBehindLimit: {num: ERRegexpLookBehindLimit, state: SSUnknownSQLState}, - vterrors.RegexpMissingCloseBracket: {num: ERRegexpMissingCloseBracket, state: SSUnknownSQLState}, - vterrors.RegexpInvalidRange: {num: ERRegexpInvalidRange, state: SSUnknownSQLState}, - vterrors.RegexpStackOverflow: {num: ERRegexpStackOverflow, state: SSUnknownSQLState}, - vterrors.RegexpTimeOut: {num: ERRegexpTimeOut, state: SSUnknownSQLState}, - vterrors.RegexpPatternTooBig: {num: ERRegexpPatternTooBig, state: SSUnknownSQLState}, - vterrors.RegexpInvalidFlag: {num: ERRegexpInvalidFlag, state: SSUnknownSQLState}, - vterrors.RegexpInvalidCaptureGroup: {num: ERRegexpInvalidCaptureGroup, state: SSUnknownSQLState}, - vterrors.CharacterSetMismatch: {num: ERCharacterSetMismatch, state: SSUnknownSQLState}, - vterrors.WrongParametersToNativeFct: {num: ERWrongParametersToNativeFct, state: SSUnknownSQLState}, - vterrors.KillDeniedError: {num: ERKillDenied, state: SSUnknownSQLState}, - vterrors.BadNullError: {num: ERBadNullError, state: SSConstraintViolation}, - vterrors.InvalidGroupFuncUse: {num: ERInvalidGroupFuncUse, state: SSUnknownSQLState}, - vterrors.VectorConversion: {num: ERVectorConversion, state: SSUnknownSQLState}, + vterrors.Undefined: {num: ERUnknownError, state: SSUnknownSQLState}, + vterrors.AccessDeniedError: {num: ERAccessDeniedError, state: SSAccessDeniedError}, + vterrors.BadDb: {num: ERBadDb, state: SSClientError}, + vterrors.BadFieldError: {num: ERBadFieldError, state: SSBadFieldError}, + vterrors.BadTableError: {num: ERBadTable, state: SSUnknownTable}, + vterrors.CantUseOptionHere: {num: ERCantUseOptionHere, state: SSClientError}, + vterrors.DataOutOfRange: {num: ERDataOutOfRange, state: SSDataOutOfRange}, + vterrors.DbCreateExists: {num: ERDbCreateExists, state: SSUnknownSQLState}, + vterrors.DbDropExists: {num: ERDbDropExists, state: SSUnknownSQLState}, + 
vterrors.DupFieldName: {num: ERDupFieldName, state: SSDupFieldName}, + vterrors.EmptyQuery: {num: EREmptyQuery, state: SSClientError}, + vterrors.IncorrectGlobalLocalVar: {num: ERIncorrectGlobalLocalVar, state: SSUnknownSQLState}, + vterrors.InnodbReadOnly: {num: ERInnodbReadOnly, state: SSUnknownSQLState}, + vterrors.LockOrActiveTransaction: {num: ERLockOrActiveTransaction, state: SSUnknownSQLState}, + vterrors.NoDB: {num: ERNoDb, state: SSNoDB}, + vterrors.NoSuchTable: {num: ERNoSuchTable, state: SSUnknownTable}, + vterrors.NotSupportedYet: {num: ERNotSupportedYet, state: SSClientError}, + vterrors.ForbidSchemaChange: {num: ERForbidSchemaChange, state: SSUnknownSQLState}, + vterrors.MixOfGroupFuncAndFields: {num: ERMixOfGroupFuncAndFields, state: SSClientError}, + vterrors.NetPacketTooLarge: {num: ERNetPacketTooLarge, state: SSNetError}, + vterrors.NonUniqError: {num: ERNonUniq, state: SSConstraintViolation}, + vterrors.NonUniqTable: {num: ERNonUniqTable, state: SSClientError}, + vterrors.NonUpdateableTable: {num: ERNonUpdateableTable, state: SSUnknownSQLState}, + vterrors.QueryInterrupted: {num: ERQueryInterrupted, state: SSQueryInterrupted}, + vterrors.SPDoesNotExist: {num: ERSPDoesNotExist, state: SSClientError}, + vterrors.SyntaxError: {num: ERSyntaxError, state: SSClientError}, + vterrors.UnsupportedPS: {num: ERUnsupportedPS, state: SSUnknownSQLState}, + vterrors.UnknownSystemVariable: {num: ERUnknownSystemVariable, state: SSUnknownSQLState}, + vterrors.UnknownTable: {num: ERUnknownTable, state: SSUnknownTable}, + vterrors.WrongGroupField: {num: ERWrongGroupField, state: SSClientError}, + vterrors.WrongNumberOfColumnsInSelect: {num: ERWrongNumberOfColumnsInSelect, state: SSWrongNumberOfColumns}, + vterrors.WrongTypeForVar: {num: ERWrongTypeForVar, state: SSClientError}, + vterrors.WrongValueForVar: {num: ERWrongValueForVar, state: SSClientError}, + vterrors.WrongValue: {num: ERWrongValue, state: SSUnknownSQLState}, + vterrors.WrongFieldWithGroup: {num: 
ERWrongFieldWithGroup, state: SSClientError}, + vterrors.ServerNotAvailable: {num: ERServerIsntAvailable, state: SSNetError}, + vterrors.CantDoThisInTransaction: {num: ERCantDoThisDuringAnTransaction, state: SSCantDoThisDuringAnTransaction}, + vterrors.RequiresPrimaryKey: {num: ERRequiresPrimaryKey, state: SSClientError}, + vterrors.RowIsReferenced2: {num: ERRowIsReferenced2, state: SSConstraintViolation}, + vterrors.NoReferencedRow2: {num: ErNoReferencedRow2, state: SSConstraintViolation}, + vterrors.NoSuchSession: {num: ERUnknownComError, state: SSNetError}, + vterrors.OperandColumns: {num: EROperandColumns, state: SSWrongNumberOfColumns}, + vterrors.WrongValueCountOnRow: {num: ERWrongValueCountOnRow, state: SSWrongValueCountOnRow}, + vterrors.WrongArguments: {num: ERWrongArguments, state: SSUnknownSQLState}, + vterrors.ViewWrongList: {num: ERViewWrongList, state: SSUnknownSQLState}, + vterrors.UnknownStmtHandler: {num: ERUnknownStmtHandler, state: SSUnknownSQLState}, + vterrors.KeyDoesNotExist: {num: ERKeyDoesNotExist, state: SSClientError}, + vterrors.UnknownTimeZone: {num: ERUnknownTimeZone, state: SSUnknownSQLState}, + vterrors.RegexpStringNotTerminated: {num: ERRegexpStringNotTerminated, state: SSUnknownSQLState}, + vterrors.RegexpBufferOverflow: {num: ERRegexpBufferOverflow, state: SSUnknownSQLState}, + vterrors.RegexpIllegalArgument: {num: ERRegexpIllegalArgument, state: SSUnknownSQLState}, + vterrors.RegexpIndexOutOfBounds: {num: ERRegexpIndexOutOfBounds, state: SSUnknownSQLState}, + vterrors.RegexpInternal: {num: ERRegexpInternal, state: SSUnknownSQLState}, + vterrors.RegexpRuleSyntax: {num: ERRegexpRuleSyntax, state: SSUnknownSQLState}, + vterrors.RegexpBadEscapeSequence: {num: ERRegexpBadEscapeSequence, state: SSUnknownSQLState}, + vterrors.RegexpUnimplemented: {num: ERRegexpUnimplemented, state: SSUnknownSQLState}, + vterrors.RegexpMismatchParen: {num: ERRegexpMismatchParen, state: SSUnknownSQLState}, + vterrors.RegexpBadInterval: {num: 
ERRegexpBadInterval, state: SSUnknownSQLState}, + vterrors.RegexpMaxLtMin: {num: ERRRegexpMaxLtMin, state: SSUnknownSQLState}, + vterrors.RegexpInvalidBackRef: {num: ERRegexpInvalidBackRef, state: SSUnknownSQLState}, + vterrors.RegexpLookBehindLimit: {num: ERRegexpLookBehindLimit, state: SSUnknownSQLState}, + vterrors.RegexpMissingCloseBracket: {num: ERRegexpMissingCloseBracket, state: SSUnknownSQLState}, + vterrors.RegexpInvalidRange: {num: ERRegexpInvalidRange, state: SSUnknownSQLState}, + vterrors.RegexpStackOverflow: {num: ERRegexpStackOverflow, state: SSUnknownSQLState}, + vterrors.RegexpTimeOut: {num: ERRegexpTimeOut, state: SSUnknownSQLState}, + vterrors.RegexpPatternTooBig: {num: ERRegexpPatternTooBig, state: SSUnknownSQLState}, + vterrors.RegexpInvalidFlag: {num: ERRegexpInvalidFlag, state: SSUnknownSQLState}, + vterrors.RegexpInvalidCaptureGroup: {num: ERRegexpInvalidCaptureGroup, state: SSUnknownSQLState}, + vterrors.CharacterSetMismatch: {num: ERCharacterSetMismatch, state: SSUnknownSQLState}, + vterrors.WrongParametersToNativeFct: {num: ERWrongParametersToNativeFct, state: SSUnknownSQLState}, + vterrors.KillDeniedError: {num: ERKillDenied, state: SSUnknownSQLState}, + vterrors.BadNullError: {num: ERBadNullError, state: SSConstraintViolation}, + vterrors.InvalidGroupFuncUse: {num: ERInvalidGroupFuncUse, state: SSUnknownSQLState}, + vterrors.VectorConversion: {num: ERVectorConversion, state: SSUnknownSQLState}, + vterrors.CTERecursiveRequiresSingleReference: {num: ERCTERecursiveRequiresSingleReference, state: SSUnknownSQLState}, + vterrors.CTERecursiveRequiresUnion: {num: ERCTERecursiveRequiresUnion, state: SSUnknownSQLState}, + vterrors.CTERecursiveForbidsAggregation: {num: ERCTERecursiveForbidsAggregation, state: SSUnknownSQLState}, + vterrors.CTERecursiveForbiddenJoinOrder: {num: ERCTERecursiveForbiddenJoinOrder, state: SSUnknownSQLState}, + vterrors.CTEMaxRecursionDepth: {num: ERCTEMaxRecursionDepth, state: SSUnknownSQLState}, } func 
getStateToMySQLState(state vterrors.State) mysqlCode { diff --git a/go/test/endtoend/vtgate/vitess_tester/cte/queries.test b/go/test/endtoend/vtgate/vitess_tester/cte/queries.test new file mode 100644 index 00000000000..de38a21cd78 --- /dev/null +++ b/go/test/endtoend/vtgate/vitess_tester/cte/queries.test @@ -0,0 +1,110 @@ +# Create tables +CREATE TABLE employees +( + id INT PRIMARY KEY, + name VARCHAR(100), + manager_id INT +); + +# Insert data into the tables +INSERT INTO employees (id, name, manager_id) +VALUES (1, 'CEO', NULL), + (2, 'CTO', 1), + (3, 'CFO', 1), + (4, 'Engineer1', 2), + (5, 'Engineer2', 2), + (6, 'Accountant1', 3), + (7, 'Accountant2', 3); + +# Simple recursive CTE using literal values +WITH RECURSIVE numbers AS (SELECT 1 AS n + UNION ALL + SELECT n + 1 + FROM numbers + WHERE n < 5) +SELECT * +FROM numbers; + +# Recursive CTE joined with a normal table +WITH RECURSIVE emp_cte AS (SELECT id, name, manager_id + FROM employees + WHERE manager_id IS NULL + UNION ALL + SELECT e.id, e.name, e.manager_id + FROM employees e + INNER JOIN emp_cte cte ON e.manager_id = cte.id) +SELECT * +FROM emp_cte; + +# Recursive CTE used in a derived table outside the CTE definition +WITH RECURSIVE emp_cte AS (SELECT id, name, manager_id + FROM employees + WHERE manager_id IS NULL + UNION ALL + SELECT e.id, e.name, e.manager_id + FROM employees e + INNER JOIN emp_cte cte ON e.manager_id = cte.id) +SELECT derived.id, derived.name, derived.manager_id +FROM (SELECT * FROM emp_cte) AS derived; + +# Recursive CTE with additional computation +WITH RECURSIVE emp_cte AS (SELECT id, name, manager_id, 1 AS level + FROM employees + WHERE manager_id IS NULL + UNION ALL + SELECT e.id, e.name, e.manager_id, cte.level + 1 + FROM employees e + INNER JOIN emp_cte cte ON e.manager_id = cte.id) +SELECT * +FROM emp_cte; + +# Recursive CTE with filtering in the recursive part +WITH RECURSIVE emp_cte AS (SELECT id, name, manager_id + FROM employees + WHERE manager_id IS NULL + UNION ALL + 
SELECT e.id, e.name, e.manager_id + FROM employees e + INNER JOIN emp_cte cte ON e.manager_id = cte.id + WHERE e.name LIKE 'Engineer%') +SELECT * +FROM emp_cte; + +# Recursive CTE with limit +WITH RECURSIVE emp_cte AS (SELECT id, name, manager_id + FROM employees + WHERE manager_id IS NULL + UNION ALL + SELECT e.id, e.name, e.manager_id + FROM employees e + INNER JOIN emp_cte cte ON e.manager_id = cte.id) +SELECT * +FROM emp_cte +LIMIT 5; + +# Recursive CTE using literal values and joined with a real table on the outside +WITH RECURSIVE literal_cte AS (SELECT 1 AS id, 100 AS value, 1 AS manager_id + UNION ALL + SELECT id + 1, value * 2, id + FROM literal_cte + WHERE id < 5) +SELECT l.id, l.value, l.manager_id, e.name AS employee_name +FROM literal_cte l + LEFT JOIN employees e ON l.id = e.id; + +# Recursive CTE with aggregation outside the CTE +WITH RECURSIVE emp_cte AS (SELECT id, name, manager_id + FROM employees + WHERE manager_id IS NULL + UNION ALL + SELECT e.id, e.name, e.manager_id + FROM employees e + INNER JOIN emp_cte cte ON e.manager_id = cte.id) +SELECT manager_id, COUNT(*) AS employee_count +FROM emp_cte +GROUP BY manager_id; + +--error infinite recursion +with recursive cte as (select 1 as n union all select n+1 from cte) +select * +from cte; \ No newline at end of file diff --git a/go/vt/sqlparser/ast.go b/go/vt/sqlparser/ast.go index 8a2363331e9..938b9063011 100644 --- a/go/vt/sqlparser/ast.go +++ b/go/vt/sqlparser/ast.go @@ -161,7 +161,7 @@ type ( CommonTableExpr struct { ID IdentifierCS Columns Columns - Subquery *Subquery + Subquery SelectStatement } // ChangeColumn is used to change the column definition, can also rename the column in alter table command ChangeColumn struct { diff --git a/go/vt/sqlparser/ast_clone.go b/go/vt/sqlparser/ast_clone.go index 7a59832b867..f22a1790232 100644 --- a/go/vt/sqlparser/ast_clone.go +++ b/go/vt/sqlparser/ast_clone.go @@ -1022,7 +1022,7 @@ func CloneRefOfCommonTableExpr(n *CommonTableExpr) *CommonTableExpr { 
out := *n out.ID = CloneIdentifierCS(n.ID) out.Columns = CloneColumns(n.Columns) - out.Subquery = CloneRefOfSubquery(n.Subquery) + out.Subquery = CloneSelectStatement(n.Subquery) return &out } diff --git a/go/vt/sqlparser/ast_copy_on_rewrite.go b/go/vt/sqlparser/ast_copy_on_rewrite.go index caa00181f9e..0e329e24f31 100644 --- a/go/vt/sqlparser/ast_copy_on_rewrite.go +++ b/go/vt/sqlparser/ast_copy_on_rewrite.go @@ -1522,12 +1522,12 @@ func (c *cow) copyOnRewriteRefOfCommonTableExpr(n *CommonTableExpr, parent SQLNo if c.pre == nil || c.pre(n, parent) { _ID, changedID := c.copyOnRewriteIdentifierCS(n.ID, n) _Columns, changedColumns := c.copyOnRewriteColumns(n.Columns, n) - _Subquery, changedSubquery := c.copyOnRewriteRefOfSubquery(n.Subquery, n) + _Subquery, changedSubquery := c.copyOnRewriteSelectStatement(n.Subquery, n) if changedID || changedColumns || changedSubquery { res := *n res.ID, _ = _ID.(IdentifierCS) res.Columns, _ = _Columns.(Columns) - res.Subquery, _ = _Subquery.(*Subquery) + res.Subquery, _ = _Subquery.(SelectStatement) out = &res if c.cloned != nil { c.cloned(n, out) diff --git a/go/vt/sqlparser/ast_equals.go b/go/vt/sqlparser/ast_equals.go index 2b391db630b..cf076d706e7 100644 --- a/go/vt/sqlparser/ast_equals.go +++ b/go/vt/sqlparser/ast_equals.go @@ -2193,7 +2193,7 @@ func (cmp *Comparator) RefOfCommonTableExpr(a, b *CommonTableExpr) bool { } return cmp.IdentifierCS(a.ID, b.ID) && cmp.Columns(a.Columns, b.Columns) && - cmp.RefOfSubquery(a.Subquery, b.Subquery) + cmp.SelectStatement(a.Subquery, b.Subquery) } // RefOfComparisonExpr does deep equals between the two objects. diff --git a/go/vt/sqlparser/ast_format.go b/go/vt/sqlparser/ast_format.go index e89da3dc270..587b32d4afe 100644 --- a/go/vt/sqlparser/ast_format.go +++ b/go/vt/sqlparser/ast_format.go @@ -167,7 +167,7 @@ func (node *With) Format(buf *TrackedBuffer) { // Format formats the node. 
func (node *CommonTableExpr) Format(buf *TrackedBuffer) { - buf.astPrintf(node, "%v%v as %v ", node.ID, node.Columns, node.Subquery) + buf.astPrintf(node, "%v%v as (%v) ", node.ID, node.Columns, node.Subquery) } // Format formats the node. diff --git a/go/vt/sqlparser/ast_format_fast.go b/go/vt/sqlparser/ast_format_fast.go index f04928c7dfa..c2b02711398 100644 --- a/go/vt/sqlparser/ast_format_fast.go +++ b/go/vt/sqlparser/ast_format_fast.go @@ -243,9 +243,9 @@ func (node *With) FormatFast(buf *TrackedBuffer) { func (node *CommonTableExpr) FormatFast(buf *TrackedBuffer) { node.ID.FormatFast(buf) node.Columns.FormatFast(buf) - buf.WriteString(" as ") + buf.WriteString(" as (") node.Subquery.FormatFast(buf) - buf.WriteByte(' ') + buf.WriteString(") ") } // FormatFast formats the node. diff --git a/go/vt/sqlparser/ast_funcs.go b/go/vt/sqlparser/ast_funcs.go index cd4d5304047..ae96fe9c1fe 100644 --- a/go/vt/sqlparser/ast_funcs.go +++ b/go/vt/sqlparser/ast_funcs.go @@ -428,6 +428,20 @@ func (node *AliasedTableExpr) TableName() (TableName, error) { return tableName, nil } +// TableNameString returns the alias of this table expr if one is set, otherwise the name of the underlying table; it panics if the expr is neither aliased nor a TableName. +func (node *AliasedTableExpr) TableNameString() string { + if node.As.NotEmpty() { + return node.As.String() + } + + tableName, ok := node.Expr.(TableName) + if !ok { + panic(vterrors.VT13001("Derived table should have an alias. This should not be possible")) + } + + return tableName.Name.String() +} + // IsEmpty returns true if TableName is nil or empty. func (node TableName) IsEmpty() bool { // If Name is empty, Qualifier is also empty. 
diff --git a/go/vt/sqlparser/ast_rewrite.go b/go/vt/sqlparser/ast_rewrite.go index 0cad7237455..015c27a2cbd 100644 --- a/go/vt/sqlparser/ast_rewrite.go +++ b/go/vt/sqlparser/ast_rewrite.go @@ -1964,8 +1964,8 @@ func (a *application) rewriteRefOfCommonTableExpr(parent SQLNode, node *CommonTa }) { return false } - if !a.rewriteRefOfSubquery(node, node.Subquery, func(newNode, parent SQLNode) { - parent.(*CommonTableExpr).Subquery = newNode.(*Subquery) + if !a.rewriteSelectStatement(node, node.Subquery, func(newNode, parent SQLNode) { + parent.(*CommonTableExpr).Subquery = newNode.(SelectStatement) }) { return false } diff --git a/go/vt/sqlparser/ast_visit.go b/go/vt/sqlparser/ast_visit.go index d73ed076dbb..d33c2d1e055 100644 --- a/go/vt/sqlparser/ast_visit.go +++ b/go/vt/sqlparser/ast_visit.go @@ -1172,7 +1172,7 @@ func VisitRefOfCommonTableExpr(in *CommonTableExpr, f Visit) error { if err := VisitColumns(in.Columns, f); err != nil { return err } - if err := VisitRefOfSubquery(in.Subquery, f); err != nil { + if err := VisitSelectStatement(in.Subquery, f); err != nil { return err } return nil diff --git a/go/vt/sqlparser/cached_size.go b/go/vt/sqlparser/cached_size.go index 2110ea8be30..391e9a84ad3 100644 --- a/go/vt/sqlparser/cached_size.go +++ b/go/vt/sqlparser/cached_size.go @@ -834,7 +834,7 @@ func (cached *CommonTableExpr) CachedSize(alloc bool) int64 { } size := int64(0) if alloc { - size += int64(48) + size += int64(64) } // field ID vitess.io/vitess/go/vt/sqlparser.IdentifierCS size += cached.ID.CachedSize(false) @@ -845,8 +845,10 @@ func (cached *CommonTableExpr) CachedSize(alloc bool) int64 { size += elem.CachedSize(false) } } - // field Subquery *vitess.io/vitess/go/vt/sqlparser.Subquery - size += cached.Subquery.CachedSize(true) + // field Subquery vitess.io/vitess/go/vt/sqlparser.SelectStatement + if cc, ok := cached.Subquery.(cachedObject); ok { + size += cc.CachedSize(true) + } return size } func (cached *ComparisonExpr) CachedSize(alloc bool) int64 { 
diff --git a/go/vt/sqlparser/sql.go b/go/vt/sqlparser/sql.go index 196d020a36b..9912b19f323 100644 --- a/go/vt/sqlparser/sql.go +++ b/go/vt/sqlparser/sql.go @@ -10591,7 +10591,7 @@ yydefault: var yyLOCAL *CommonTableExpr //line sql.y:757 { - yyLOCAL = &CommonTableExpr{ID: yyDollar[1].identifierCS, Columns: yyDollar[2].columnsUnion(), Subquery: yyDollar[4].subqueryUnion()} + yyLOCAL = &CommonTableExpr{ID: yyDollar[1].identifierCS, Columns: yyDollar[2].columnsUnion(), Subquery: yyDollar[4].subqueryUnion().Select} } yyVAL.union = yyLOCAL case 54: diff --git a/go/vt/sqlparser/sql.y b/go/vt/sqlparser/sql.y index 5bef040c4f1..64ce957d2dd 100644 --- a/go/vt/sqlparser/sql.y +++ b/go/vt/sqlparser/sql.y @@ -755,7 +755,7 @@ with_list: common_table_expr: table_id column_list_opt AS subquery { - $$ = &CommonTableExpr{ID: $1, Columns: $2, Subquery: $4} + $$ = &CommonTableExpr{ID: $1, Columns: $2, Subquery: $4.Select} } query_expression_parens: diff --git a/go/vt/vterrors/code.go b/go/vt/vterrors/code.go index 83a87503265..31c98cef280 100644 --- a/go/vt/vterrors/code.go +++ b/go/vt/vterrors/code.go @@ -97,6 +97,11 @@ var ( VT09023 = errorWithoutState("VT09023", vtrpcpb.Code_FAILED_PRECONDITION, "could not map %v to a keyspace id", "Unable to determine the shard for the given row.") VT09024 = errorWithoutState("VT09024", vtrpcpb.Code_FAILED_PRECONDITION, "could not map %v to a unique keyspace id: %v", "Unable to determine the shard for the given row.") VT09025 = errorWithoutState("VT09025", vtrpcpb.Code_FAILED_PRECONDITION, "atomic transaction error: %v", "Error in atomic transactions") + VT09026 = errorWithState("VT09026", vtrpcpb.Code_FAILED_PRECONDITION, CTERecursiveRequiresUnion, "Recursive Common Table Expression '%s' should contain a UNION", "") + VT09027 = errorWithState("VT09027", vtrpcpb.Code_FAILED_PRECONDITION, CTERecursiveForbidsAggregation, "Recursive Common Table Expression '%s' can contain neither aggregation nor window functions in recursive query block", "") + 
VT09028 = errorWithState("VT09028", vtrpcpb.Code_FAILED_PRECONDITION, CTERecursiveForbiddenJoinOrder, "In recursive query block of Recursive Common Table Expression '%s', the recursive table must neither be in the right argument of a LEFT JOIN, nor be forced to be non-first with join order hints", "") + VT09029 = errorWithState("VT09029", vtrpcpb.Code_FAILED_PRECONDITION, CTERecursiveRequiresSingleReference, "In recursive query block of Recursive Common Table Expression %s, the recursive table must be referenced only once, and not in any subquery", "") + VT09030 = errorWithState("VT09030", vtrpcpb.Code_FAILED_PRECONDITION, CTEMaxRecursionDepth, "Recursive query aborted after 1000 iterations.", "") VT10001 = errorWithoutState("VT10001", vtrpcpb.Code_ABORTED, "foreign key constraints are not allowed", "Foreign key constraints are not allowed, see https://vitess.io/blog/2021-06-15-online-ddl-why-no-fk/.") VT10002 = errorWithoutState("VT10002", vtrpcpb.Code_ABORTED, "atomic distributed transaction not allowed: %s", "The distributed transaction cannot be committed. 
A rollback decision is taken.") @@ -183,6 +188,11 @@ var ( VT09022, VT09023, VT09024, + VT09026, + VT09027, + VT09028, + VT09029, + VT09030, VT10001, VT10002, VT12001, diff --git a/go/vt/vterrors/state.go b/go/vt/vterrors/state.go index 82434df382a..528000e9e41 100644 --- a/go/vt/vterrors/state.go +++ b/go/vt/vterrors/state.go @@ -62,6 +62,11 @@ const ( NoReferencedRow2 UnknownStmtHandler KeyDoesNotExist + CTERecursiveRequiresSingleReference + CTERecursiveRequiresUnion + CTERecursiveForbidsAggregation + CTERecursiveForbiddenJoinOrder + CTEMaxRecursionDepth // not found BadDb diff --git a/go/vt/vtgate/engine/cached_size.go b/go/vt/vtgate/engine/cached_size.go index c05e276caa9..06aa9f0d6a9 100644 --- a/go/vt/vtgate/engine/cached_size.go +++ b/go/vt/vtgate/engine/cached_size.go @@ -857,6 +857,40 @@ func (cached *Projection) CachedSize(alloc bool) int64 { } return size } + +//go:nocheckptr +func (cached *RecurseCTE) CachedSize(alloc bool) int64 { + if cached == nil { + return int64(0) + } + size := int64(0) + if alloc { + size += int64(48) + } + // field Seed vitess.io/vitess/go/vt/vtgate/engine.Primitive + if cc, ok := cached.Seed.(cachedObject); ok { + size += cc.CachedSize(true) + } + // field Term vitess.io/vitess/go/vt/vtgate/engine.Primitive + if cc, ok := cached.Term.(cachedObject); ok { + size += cc.CachedSize(true) + } + // field Vars map[string]int + if cached.Vars != nil { + size += int64(48) + hmap := reflect.ValueOf(cached.Vars) + numBuckets := int(math.Pow(2, float64((*(*uint8)(unsafe.Pointer(hmap.Pointer() + uintptr(9))))))) + numOldBuckets := (*(*uint16)(unsafe.Pointer(hmap.Pointer() + uintptr(10)))) + size += hack.RuntimeAllocSize(int64(numOldBuckets * 208)) + if len(cached.Vars) > 0 || numBuckets > 1 { + size += hack.RuntimeAllocSize(int64(numBuckets * 208)) + } + for k := range cached.Vars { + size += hack.RuntimeAllocSize(int64(len(k))) + } + } + return size +} func (cached *RenameFields) CachedSize(alloc bool) int64 { if cached == nil { return int64(0) diff
--git a/go/vt/vtgate/engine/fake_primitive_test.go b/go/vt/vtgate/engine/fake_primitive_test.go index e992c2a4623..6ab54fe9e7b 100644 --- a/go/vt/vtgate/engine/fake_primitive_test.go +++ b/go/vt/vtgate/engine/fake_primitive_test.go @@ -80,7 +80,7 @@ func (f *fakePrimitive) TryExecute(ctx context.Context, vcursor VCursor, bindVar if r == nil { return nil, f.sendErr } - return r, nil + return r.Copy(), nil } func (f *fakePrimitive) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error { diff --git a/go/vt/vtgate/engine/recurse_cte.go b/go/vt/vtgate/engine/recurse_cte.go new file mode 100644 index 00000000000..f523883d280 --- /dev/null +++ b/go/vt/vtgate/engine/recurse_cte.go @@ -0,0 +1,155 @@ +/* +Copyright 2024 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package engine + +import ( + "context" + + "vitess.io/vitess/go/sqltypes" + querypb "vitess.io/vitess/go/vt/proto/query" + "vitess.io/vitess/go/vt/vterrors" +) + +// RecurseCTE is used to represent recursive CTEs +// Seed is used to represent the non-recursive part that initializes the result set. 
+// Its results are then used to start the recursion on the Term side +// The values being sent to the Term side are stored in the Vars map - +// the key is the bindvar name and the value is the index of the column in the recursive result +type RecurseCTE struct { + Seed, Term Primitive + + Vars map[string]int +} + +var _ Primitive = (*RecurseCTE)(nil) + +func (r *RecurseCTE) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) { + res, err := vcursor.ExecutePrimitive(ctx, r.Seed, bindVars, wantfields) + if err != nil { + return nil, err + } + + // recurseRows contains the rows used in the next recursion + recurseRows := res.Rows + joinVars := make(map[string]*querypb.BindVariable) + loops := 0 + for len(recurseRows) > 0 { + // copy over the results from the previous recursion + theseRows := recurseRows + recurseRows = nil + for _, row := range theseRows { + for k, col := range r.Vars { + joinVars[k] = sqltypes.ValueBindVariable(row[col]) + } + // check if the context is done - we might be in a long running recursion + if err := ctx.Err(); err != nil { + return nil, err + } + rresult, err := vcursor.ExecutePrimitive(ctx, r.Term, combineVars(bindVars, joinVars), false) + if err != nil { + return nil, err + } + recurseRows = append(recurseRows, rresult.Rows...) + res.Rows = append(res.Rows, rresult.Rows...)
+ loops++ + if loops > 1000 { // TODO: This should be controlled with a system variable setting + return nil, vterrors.VT09030("") + } + } + } + return res, nil +} + +func (r *RecurseCTE) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error { + if vcursor.Session().InTransaction() { + res, err := r.TryExecute(ctx, vcursor, bindVars, wantfields) + if err != nil { + return err + } + return callback(res) + } + return vcursor.StreamExecutePrimitive(ctx, r.Seed, bindVars, wantfields, func(result *sqltypes.Result) error { + err := callback(result) + if err != nil { + return err + } + return r.recurse(ctx, vcursor, bindVars, result, callback) + }) +} + +func (r *RecurseCTE) recurse(ctx context.Context, vcursor VCursor, bindvars map[string]*querypb.BindVariable, result *sqltypes.Result, callback func(*sqltypes.Result) error) error { + if len(result.Rows) == 0 { + return nil + } + joinVars := make(map[string]*querypb.BindVariable) + for _, row := range result.Rows { + for k, col := range r.Vars { + joinVars[k] = sqltypes.ValueBindVariable(row[col]) + } + + err := vcursor.StreamExecutePrimitive(ctx, r.Term, combineVars(bindvars, joinVars), false, func(result *sqltypes.Result) error { + err := callback(result) + if err != nil { + return err + } + return r.recurse(ctx, vcursor, bindvars, result, callback) + }) + if err != nil { + return err + } + } + return nil +} + +func (r *RecurseCTE) RouteType() string { + return "RecurseCTE" +} + +func (r *RecurseCTE) GetKeyspaceName() string { + if r.Seed.GetKeyspaceName() == r.Term.GetKeyspaceName() { + return r.Seed.GetKeyspaceName() + } + return r.Seed.GetKeyspaceName() + "_" + r.Term.GetKeyspaceName() +} + +func (r *RecurseCTE) GetTableName() string { + return r.Seed.GetTableName() +} + +func (r *RecurseCTE) GetFields(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) { + 
return r.Seed.GetFields(ctx, vcursor, bindVars) +} + +func (r *RecurseCTE) NeedsTransaction() bool { + return r.Seed.NeedsTransaction() || r.Term.NeedsTransaction() +} + +func (r *RecurseCTE) Inputs() ([]Primitive, []map[string]any) { + return []Primitive{r.Seed, r.Term}, nil +} + +func (r *RecurseCTE) description() PrimitiveDescription { + other := map[string]interface{}{ + "JoinVars": orderedStringIntMap(r.Vars), + } + + return PrimitiveDescription{ + OperatorType: "RecurseCTE", + Other: other, + Inputs: nil, + } +} diff --git a/go/vt/vtgate/engine/recurse_cte_test.go b/go/vt/vtgate/engine/recurse_cte_test.go new file mode 100644 index 00000000000..d6826284d21 --- /dev/null +++ b/go/vt/vtgate/engine/recurse_cte_test.go @@ -0,0 +1,129 @@ +/* +Copyright 2024 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package engine + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + + "vitess.io/vitess/go/sqltypes" + querypb "vitess.io/vitess/go/vt/proto/query" +) + +func TestRecurseDualQuery(t *testing.T) { + // Test that the RecurseCTE primitive works as expected. 
+ // The test is testing something like this: + // WITH RECURSIVE cte AS (SELECT 1 as col1 UNION SELECT col1+1 FROM cte WHERE col1 < 5) SELECT * FROM cte; + leftPrim := &fakePrimitive{ + results: []*sqltypes.Result{ + sqltypes.MakeTestResult( + sqltypes.MakeTestFields( + "col1", + "int64", + ), + "1", + ), + }, + } + rightFields := sqltypes.MakeTestFields( + "col4", + "int64", + ) + + rightPrim := &fakePrimitive{ + results: []*sqltypes.Result{ + sqltypes.MakeTestResult( + rightFields, + "2", + ), + sqltypes.MakeTestResult( + rightFields, + "3", + ), + sqltypes.MakeTestResult( + rightFields, + "4", + ), sqltypes.MakeTestResult( + rightFields, + ), + }, + } + bv := map[string]*querypb.BindVariable{} + + cte := &RecurseCTE{ + Seed: leftPrim, + Term: rightPrim, + Vars: map[string]int{"col1": 0}, + } + + r, err := cte.TryExecute(context.Background(), &noopVCursor{}, bv, true) + require.NoError(t, err) + + rightPrim.ExpectLog(t, []string{ + `Execute col1: type:INT64 value:"1" false`, + `Execute col1: type:INT64 value:"2" false`, + `Execute col1: type:INT64 value:"3" false`, + `Execute col1: type:INT64 value:"4" false`, + }) + + wantRes := sqltypes.MakeTestResult( + sqltypes.MakeTestFields( + "col1", + "int64", + ), + "1", + "2", + "3", + "4", + ) + expectResult(t, r, wantRes) + + // testing the streaming mode. 
+ + leftPrim.rewind() + rightPrim.rewind() + + r, err = wrapStreamExecute(cte, &noopVCursor{}, bv, true) + require.NoError(t, err) + + rightPrim.ExpectLog(t, []string{ + `StreamExecute col1: type:INT64 value:"1" false`, + `StreamExecute col1: type:INT64 value:"2" false`, + `StreamExecute col1: type:INT64 value:"3" false`, + `StreamExecute col1: type:INT64 value:"4" false`, + }) + expectResult(t, r, wantRes) + + // testing the streaming mode with transaction + + leftPrim.rewind() + rightPrim.rewind() + + r, err = wrapStreamExecute(cte, &noopVCursor{inTx: true}, bv, true) + require.NoError(t, err) + + rightPrim.ExpectLog(t, []string{ + `Execute col1: type:INT64 value:"1" false`, + `Execute col1: type:INT64 value:"2" false`, + `Execute col1: type:INT64 value:"3" false`, + `Execute col1: type:INT64 value:"4" false`, + }) + expectResult(t, r, wantRes) + +} diff --git a/go/vt/vtgate/planbuilder/operator_transformers.go b/go/vt/vtgate/planbuilder/operator_transformers.go index 546a9854f26..b4aaf6fc64d 100644 --- a/go/vt/vtgate/planbuilder/operator_transformers.go +++ b/go/vt/vtgate/planbuilder/operator_transformers.go @@ -77,6 +77,8 @@ func transformToPrimitive(ctx *plancontext.PlanningContext, op operators.Operato return transformSequential(ctx, op) case *operators.DMLWithInput: return transformDMLWithInput(ctx, op) + case *operators.RecurseCTE: + return transformRecurseCTE(ctx, op) } return nil, vterrors.VT13001(fmt.Sprintf("unknown type encountered: %T (transformToPrimitive)", op)) @@ -981,6 +983,22 @@ func transformVindexPlan(ctx *plancontext.PlanningContext, op *operators.Vindex) return prim, nil } +func transformRecurseCTE(ctx *plancontext.PlanningContext, op *operators.RecurseCTE) (engine.Primitive, error) { + seed, err := transformToPrimitive(ctx, op.Seed) + if err != nil { + return nil, err + } + term, err := transformToPrimitive(ctx, op.Term) + if err != nil { + return nil, err + } + return &engine.RecurseCTE{ + Seed: seed, + Term: term, + Vars: op.Vars, + }, 
nil +} + func generateQuery(statement sqlparser.Statement) string { buf := sqlparser.NewTrackedBuffer(dmlFormatter) statement.Format(buf) diff --git a/go/vt/vtgate/planbuilder/operators/SQL_builder.go b/go/vt/vtgate/planbuilder/operators/SQL_builder.go index 08cf3c4801c..8cc23c57ae7 100644 --- a/go/vt/vtgate/planbuilder/operators/SQL_builder.go +++ b/go/vt/vtgate/planbuilder/operators/SQL_builder.go @@ -55,7 +55,36 @@ func ToSQL(ctx *plancontext.PlanningContext, op Operator) (_ sqlparser.Statement return q.stmt, q.dmlOperator, nil } +// includeTable will return false if the table is a CTE, and it is not merged +// it will return true if the table is not a CTE or if it is a CTE and it is merged +func (qb *queryBuilder) includeTable(op *Table) bool { + if qb.ctx.SemTable == nil { + return true + } + tbl, err := qb.ctx.SemTable.TableInfoFor(op.QTable.ID) + if err != nil { + panic(err) + } + cteTbl, isCTE := tbl.(*semantics.CTETable) + if !isCTE { + return true + } + + return cteTbl.Merged +} + func (qb *queryBuilder) addTable(db, tableName, alias string, tableID semantics.TableSet, hints sqlparser.IndexHints) { + if tableID.NumberOfTables() == 1 && qb.ctx.SemTable != nil { + tblInfo, err := qb.ctx.SemTable.TableInfoFor(tableID) + if err != nil { + panic(err) + } + cte, isCTE := tblInfo.(*semantics.CTETable) + if isCTE { + tableName = cte.TableName + db = "" + } + } tableExpr := sqlparser.TableName{ Name: sqlparser.NewIdentifierCS(tableName), Qualifier: sqlparser.NewIdentifierCS(db), @@ -207,6 +236,26 @@ func (qb *queryBuilder) unionWith(other *queryBuilder, distinct bool) { } } +func (qb *queryBuilder) recursiveCteWith(other *queryBuilder, name, alias string) { + cteUnion := &sqlparser.Union{ + Left: qb.stmt.(sqlparser.SelectStatement), + Right: other.stmt.(sqlparser.SelectStatement), + } + + qb.stmt = &sqlparser.Select{ + With: &sqlparser.With{ + Recursive: true, + CTEs: []*sqlparser.CommonTableExpr{{ + ID: sqlparser.NewIdentifierCS(name), + Columns: nil, + Subquery: 
cteUnion, + }}, + }, + } + + qb.addTable("", name, alias, "", nil) +} + type FromStatement interface { GetFrom() []sqlparser.TableExpr SetFrom([]sqlparser.TableExpr) @@ -401,6 +450,8 @@ func buildQuery(op Operator, qb *queryBuilder) { buildDelete(op, qb) case *Insert: buildDML(op, qb) + case *RecurseCTE: + buildRecursiveCTE(op, qb) default: panic(vterrors.VT13001(fmt.Sprintf("unknown operator to convert to SQL: %T", op))) } @@ -492,6 +543,10 @@ func buildLimit(op *Limit, qb *queryBuilder) { } func buildTable(op *Table, qb *queryBuilder) { + if !qb.includeTable(op) { + return + } + dbName := "" if op.QTable.IsInfSchema { @@ -551,7 +606,16 @@ func buildApplyJoin(op *ApplyJoin, qb *queryBuilder) { qbR := &queryBuilder{ctx: qb.ctx} buildQuery(op.RHS, qbR) - qb.joinWith(qbR, pred, op.JoinType) + + switch { + // if we have a recursive cte, we might be missing a statement from one of the sides + case qbR.stmt == nil: + // do nothing + case qb.stmt == nil: + qb.stmt = qbR.stmt + default: + qb.joinWith(qbR, pred, op.JoinType) + } } func buildUnion(op *Union, qb *queryBuilder) { @@ -636,6 +700,28 @@ func buildHorizon(op *Horizon, qb *queryBuilder) { sqlparser.RemoveKeyspaceInCol(qb.stmt) } +func buildRecursiveCTE(op *RecurseCTE, qb *queryBuilder) { + predicates := slice.Map(op.Predicates, func(jc *plancontext.RecurseExpression) sqlparser.Expr { + // since we are adding these join predicates, we need to mark to broken up version (RHSExpr) of it as done + err := qb.ctx.SkipJoinPredicates(jc.Original) + if err != nil { + panic(err) + } + return jc.Original + }) + pred := sqlparser.AndExpressions(predicates...) 
+ buildQuery(op.Seed, qb) + qbR := &queryBuilder{ctx: qb.ctx} + buildQuery(op.Term, qbR) + qbR.addPredicate(pred) + infoFor, err := qb.ctx.SemTable.TableInfoFor(op.OuterID) + if err != nil { + panic(err) + } + + qb.recursiveCteWith(qbR, op.Def.Name, infoFor.GetAliasedTableExpr().As.String()) +} + func mergeHaving(h1, h2 *sqlparser.Where) *sqlparser.Where { switch { case h1 == nil && h2 == nil: diff --git a/go/vt/vtgate/planbuilder/operators/apply_join.go b/go/vt/vtgate/planbuilder/operators/apply_join.go index f7bd5b131b8..4c6baab3729 100644 --- a/go/vt/vtgate/planbuilder/operators/apply_join.go +++ b/go/vt/vtgate/planbuilder/operators/apply_join.go @@ -204,7 +204,6 @@ func (aj *ApplyJoin) getJoinColumnFor(ctx *plancontext.PlanningContext, orig *sq rhs := TableID(aj.RHS) both := lhs.Merge(rhs) deps := ctx.SemTable.RecursiveDeps(e) - switch { case deps.IsSolvedBy(lhs): col.LHSExprs = []BindVarExpr{{Expr: e}} diff --git a/go/vt/vtgate/planbuilder/operators/ast_to_op.go b/go/vt/vtgate/planbuilder/operators/ast_to_op.go index f017f77d6a3..4f0ab742935 100644 --- a/go/vt/vtgate/planbuilder/operators/ast_to_op.go +++ b/go/vt/vtgate/planbuilder/operators/ast_to_op.go @@ -254,24 +254,33 @@ func getOperatorFromAliasedTableExpr(ctx *plancontext.PlanningContext, tableExpr panic(err) } - if vt, isVindex := tableInfo.(*semantics.VindexTable); isVindex { - solves := tableID + switch tableInfo := tableInfo.(type) { + case *semantics.VindexTable: return &Vindex{ Table: VindexTable{ TableID: tableID, Alias: tableExpr, Table: tbl, - VTable: vt.Table.GetVindexTable(), + VTable: tableInfo.Table.GetVindexTable(), }, - Vindex: vt.Vindex, - Solved: solves, + Vindex: tableInfo.Vindex, + Solved: tableID, } + case *semantics.CTETable: + return createDualCTETable(ctx, tableID, tableInfo) + case *semantics.RealTable: + if tableInfo.CTE != nil { + return createRecursiveCTE(ctx, tableInfo.CTE, tableID) + } + + qg := newQueryGraph() + isInfSchema := tableInfo.IsInfSchema() + qt := 
&QueryTable{Alias: tableExpr, Table: tbl, ID: tableID, IsInfSchema: isInfSchema} + qg.Tables = append(qg.Tables, qt) + return qg + default: + panic(vterrors.VT13001(fmt.Sprintf("unknown table type %T", tableInfo))) } - qg := newQueryGraph() - isInfSchema := tableInfo.IsInfSchema() - qt := &QueryTable{Alias: tableExpr, Table: tbl, ID: tableID, IsInfSchema: isInfSchema} - qg.Tables = append(qg.Tables, qt) - return qg case *sqlparser.DerivedTable: if onlyTable && tbl.Select.GetLimit() == nil { tbl.Select.SetOrderBy(nil) @@ -292,6 +301,56 @@ func getOperatorFromAliasedTableExpr(ctx *plancontext.PlanningContext, tableExpr } } +func createDualCTETable(ctx *plancontext.PlanningContext, tableID semantics.TableSet, tableInfo *semantics.CTETable) Operator { + vschemaTable, _, _, _, _, err := ctx.VSchema.FindTableOrVindex(sqlparser.NewTableName("dual")) + if err != nil { + panic(err) + } + qtbl := &QueryTable{ + ID: tableID, + Alias: tableInfo.ASTNode, + Table: sqlparser.NewTableName("dual"), + } + return createRouteFromVSchemaTable(ctx, qtbl, vschemaTable, false, nil) +} + +func createRecursiveCTE(ctx *plancontext.PlanningContext, def *semantics.CTE, outerID semantics.TableSet) Operator { + union, ok := def.Query.(*sqlparser.Union) + if !ok { + panic(vterrors.VT13001("expected UNION in recursive CTE")) + } + + seed := translateQueryToOp(ctx, union.Left) + + // Push the CTE definition to the stack so that it can be used in the recursive part of the query + ctx.PushCTE(def, *def.IDForRecurse) + + term := translateQueryToOp(ctx, union.Right) + horizon, ok := term.(*Horizon) + if !ok { + panic(vterrors.VT09027(def.Name)) + } + term = horizon.Source + horizon.Source = nil // not sure about this + activeCTE, err := ctx.PopCTE() + if err != nil { + panic(err) + } + + return newRecurse(ctx, def, seed, term, activeCTE.Predicates, horizon, idForRecursiveTable(ctx, def), outerID) +} + +func idForRecursiveTable(ctx *plancontext.PlanningContext, def *semantics.CTE) semantics.TableSet { + 
for i, table := range ctx.SemTable.Tables { + tbl, ok := table.(*semantics.CTETable) + if !ok || tbl.CTE.Name != def.Name { + continue + } + return semantics.SingleTableSet(i) + } + panic(vterrors.VT13001("recursive table not found")) +} + func crossJoin(ctx *plancontext.PlanningContext, exprs sqlparser.TableExprs) Operator { var output Operator for _, tableExpr := range exprs { diff --git a/go/vt/vtgate/planbuilder/operators/cte_merging.go b/go/vt/vtgate/planbuilder/operators/cte_merging.go new file mode 100644 index 00000000000..9ca453f39c6 --- /dev/null +++ b/go/vt/vtgate/planbuilder/operators/cte_merging.go @@ -0,0 +1,86 @@ +/* +Copyright 2024 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package operators + +import ( + "vitess.io/vitess/go/vt/vtgate/engine" + "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" +) + +func tryMergeRecurse(ctx *plancontext.PlanningContext, in *RecurseCTE) (Operator, *ApplyResult) { + op := tryMergeCTE(ctx, in.Seed, in.Term, in) + if op == nil { + return in, NoRewrite + } + + return op, Rewrote("Merged CTE") +} + +func tryMergeCTE(ctx *plancontext.PlanningContext, seed, term Operator, in *RecurseCTE) *Route { + seedRoute, termRoute, _, routingB, a, b, sameKeyspace := prepareInputRoutes(seed, term) + if seedRoute == nil || !sameKeyspace { + return nil + } + + switch { + case a == dual: + return mergeCTE(ctx, seedRoute, termRoute, routingB, in) + case a == sharded && b == sharded: + return tryMergeCTESharded(ctx, seedRoute, termRoute, in) + default: + return nil + } +} + +func tryMergeCTESharded(ctx *plancontext.PlanningContext, seed, term *Route, in *RecurseCTE) *Route { + tblA := seed.Routing.(*ShardedRouting) + tblB := term.Routing.(*ShardedRouting) + switch tblA.RouteOpCode { + case engine.EqualUnique: + // If the two routes fully match, they can be merged together. 
+ if tblB.RouteOpCode == engine.EqualUnique { + aVdx := tblA.SelectedVindex() + bVdx := tblB.SelectedVindex() + aExpr := tblA.VindexExpressions() + bExpr := tblB.VindexExpressions() + if aVdx == bVdx && gen4ValuesEqual(ctx, aExpr, bExpr) { + return mergeCTE(ctx, seed, term, tblA, in) + } + } + } + + return nil +} + +func mergeCTE(ctx *plancontext.PlanningContext, seed, term *Route, r Routing, in *RecurseCTE) *Route { + in.Def.Merged = true + hz := in.Horizon + hz.Source = term.Source + newTerm, _ := expandHorizon(ctx, hz) + return &Route{ + Routing: r, + Source: &RecurseCTE{ + Predicates: in.Predicates, + Def: in.Def, + Seed: seed.Source, + Term: newTerm, + LeftID: in.LeftID, + OuterID: in.OuterID, + }, + MergedWith: []*Route{term}, + } +} diff --git a/go/vt/vtgate/planbuilder/operators/join.go b/go/vt/vtgate/planbuilder/operators/join.go index 71d2e5a8048..35760bceafb 100644 --- a/go/vt/vtgate/planbuilder/operators/join.go +++ b/go/vt/vtgate/planbuilder/operators/join.go @@ -17,9 +17,11 @@ limitations under the License. package operators import ( + "vitess.io/vitess/go/slice" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" + "vitess.io/vitess/go/vt/vtgate/semantics" ) // Join represents a join. If we have a predicate, this is an inner join. 
If no predicate exists, it is a cross join @@ -141,11 +143,49 @@ func addJoinPredicates( if subq != nil { continue } + + // if we are inside a CTE, we need to check if we depend on the recursion table + if cte := ctx.ActiveCTE(); cte != nil && ctx.SemTable.DirectDeps(pred).IsOverlapping(cte.Id) { + original := pred + pred = addCTEPredicate(ctx, pred, cte) + ctx.AddJoinPredicates(original, pred) + } op = op.AddPredicate(ctx, pred) } return sqc.getRootOperator(op, nil) } +// addCTEPredicate breaks the expression into LHS and RHS +func addCTEPredicate( + ctx *plancontext.PlanningContext, + pred sqlparser.Expr, + cte *plancontext.ContextCTE, +) sqlparser.Expr { + expr := breakCTEExpressionInLhsAndRhs(ctx, pred, cte.Id) + cte.Predicates = append(cte.Predicates, expr) + return expr.RightExpr +} + +func breakCTEExpressionInLhsAndRhs(ctx *plancontext.PlanningContext, pred sqlparser.Expr, lhsID semantics.TableSet) *plancontext.RecurseExpression { + col := breakExpressionInLHSandRHS(ctx, pred, lhsID) + + lhsExprs := slice.Map(col.LHSExprs, func(bve BindVarExpr) plancontext.BindVarExpr { + col, ok := bve.Expr.(*sqlparser.ColName) + if !ok { + panic(vterrors.VT13001("expected column name")) + } + return plancontext.BindVarExpr{ + Name: bve.Name, + Expr: col, + } + }) + return &plancontext.RecurseExpression{ + Original: col.Original, + RightExpr: col.RHSExpr, + LeftExprs: lhsExprs, + } +} + func createJoin(ctx *plancontext.PlanningContext, LHS, RHS Operator) Operator { lqg, lok := LHS.(*QueryGraph) rqg, rok := RHS.(*QueryGraph) diff --git a/go/vt/vtgate/planbuilder/operators/phases.go b/go/vt/vtgate/planbuilder/operators/phases.go index bf8e96372bc..d5354e9548f 100644 --- a/go/vt/vtgate/planbuilder/operators/phases.go +++ b/go/vt/vtgate/planbuilder/operators/phases.go @@ -36,6 +36,7 @@ const ( initialPlanning pullDistinctFromUnion delegateAggregation + recursiveCTEHorizons addAggrOrdering cleanOutPerfDistinct dmlWithInput @@ -53,6 +54,8 @@ func (p Phase) String() string { 
return "pull distinct from UNION" case delegateAggregation: return "split aggregation between vtgate and mysql" + case recursiveCTEHorizons: + return "expand recursive CTE horizons" case addAggrOrdering: return "optimize aggregations with ORDER BY" case cleanOutPerfDistinct: @@ -72,6 +75,8 @@ func (p Phase) shouldRun(s semantics.QuerySignature) bool { return s.Union case delegateAggregation: return s.Aggregation + case recursiveCTEHorizons: + return s.RecursiveCTE case addAggrOrdering: return s.Aggregation case cleanOutPerfDistinct: @@ -93,6 +98,8 @@ func (p Phase) act(ctx *plancontext.PlanningContext, op Operator) Operator { return enableDelegateAggregation(ctx, op) case addAggrOrdering: return addOrderingForAllAggregations(ctx, op) + case recursiveCTEHorizons: + return planRecursiveCTEHorizons(ctx, op) case cleanOutPerfDistinct: return removePerformanceDistinctAboveRoute(ctx, op) case subquerySettling: @@ -207,51 +214,7 @@ func removePerformanceDistinctAboveRoute(_ *plancontext.PlanningContext, op Oper } func enableDelegateAggregation(ctx *plancontext.PlanningContext, op Operator) Operator { - return addColumnsToInput(ctx, op) -} - -// addColumnsToInput adds columns needed by an operator to its input. -// This happens only when the filter expression can be retrieved as an offset from the underlying mysql. 
-func addColumnsToInput(ctx *plancontext.PlanningContext, root Operator) Operator { - - addColumnsNeededByFilter := func(in Operator, _ semantics.TableSet, _ bool) (Operator, *ApplyResult) { - addedCols := false - filter, ok := in.(*Filter) - if !ok { - return in, NoRewrite - } - - var neededAggrs []sqlparser.Expr - extractAggrs := func(cursor *sqlparser.CopyOnWriteCursor) { - node := cursor.Node() - if ctx.IsAggr(node) { - neededAggrs = append(neededAggrs, node.(sqlparser.Expr)) - } - } - - for _, expr := range filter.Predicates { - _ = sqlparser.CopyOnRewrite(expr, dontEnterSubqueries, extractAggrs, nil) - } - - if neededAggrs == nil { - return in, NoRewrite - } - - aggregator := findAggregatorInSource(filter.Source) - for _, aggr := range neededAggrs { - if aggregator.FindCol(ctx, aggr, false) == -1 { - aggregator.addColumnWithoutPushing(ctx, aeWrap(aggr), false) - addedCols = true - } - } - - if addedCols { - return in, Rewrote("added columns because filter needs it") - } - return in, NoRewrite - } - - return TopDown(root, TableID, addColumnsNeededByFilter, stopAtRoute) + return prepareForAggregationPushing(ctx, op) } // addOrderingForAllAggregations is run we have pushed down Aggregators as far down as possible. @@ -341,3 +304,94 @@ func addLiteralGroupingToRHS(in *ApplyJoin) (Operator, *ApplyResult) { }) return in, NoRewrite } + +// prepareForAggregationPushing adds columns needed by an operator to its input. +// This happens only when the filter expression can be retrieved as an offset from the underlying mysql. 
+func prepareForAggregationPushing(ctx *plancontext.PlanningContext, root Operator) Operator { + return TopDown(root, TableID, func(in Operator, _ semantics.TableSet, _ bool) (Operator, *ApplyResult) { + filter, ok := in.(*Filter) + if !ok { + return in, NoRewrite + } + + var neededAggrs []sqlparser.Expr + extractAggrs := func(cursor *sqlparser.CopyOnWriteCursor) { + node := cursor.Node() + if ctx.IsAggr(node) { + neededAggrs = append(neededAggrs, node.(sqlparser.Expr)) + } + } + + for _, expr := range filter.Predicates { + _ = sqlparser.CopyOnRewrite(expr, dontEnterSubqueries, extractAggrs, nil) + } + + if neededAggrs == nil { + return in, NoRewrite + } + + addedCols := false + aggregator := findAggregatorInSource(filter.Source) + for _, aggr := range neededAggrs { + if aggregator.FindCol(ctx, aggr, false) == -1 { + aggregator.addColumnWithoutPushing(ctx, aeWrap(aggr), false) + addedCols = true + } + } + + if addedCols { + return in, Rewrote("added columns because filter needs it") + } + return in, NoRewrite + }, stopAtRoute) +} + +// planRecursiveCTEHorizons expands the horizon on the term side of every RecurseCTE that was not pushed under a route. +// The term projections are broken into LHS and RHS expressions, so that values coming from the recursion become arguments.
+func planRecursiveCTEHorizons(ctx *plancontext.PlanningContext, root Operator) Operator { + return TopDown(root, TableID, func(in Operator, _ semantics.TableSet, _ bool) (Operator, *ApplyResult) { + // These recursive CTEs have not been pushed under a route, so we will have to evaluate them on the vtgate + // That means that we need to turn anything that is coming from the recursion into arguments + rcte, ok := in.(*RecurseCTE) + if !ok { + return in, NoRewrite + } + hz := rcte.Horizon + hz.Source = rcte.Term + newTerm, _ := expandHorizon(ctx, hz) + pr := findProjection(newTerm) + ap, err := pr.GetAliasedProjections() + if err != nil { + panic(vterrors.VT09015()) + } + + // We need to break the expressions into LHS and RHS, and store them in the CTE for later use + projections := slice.Map(ap, func(p *ProjExpr) *plancontext.RecurseExpression { + recurseExpression := breakCTEExpressionInLhsAndRhs(ctx, p.EvalExpr, rcte.LeftID) + p.EvalExpr = recurseExpression.RightExpr + return recurseExpression + }) + rcte.Projections = projections + rcte.Term = newTerm + return rcte, Rewrote("expanded horizon on term side of recursive CTE") + }, stopAtRoute) +} + +func findProjection(op Operator) *Projection { + for { + proj, ok := op.(*Projection) + if ok { + return proj + } + inputs := op.Inputs() + if len(inputs) != 1 { + panic(vterrors.VT13001("unexpected multiple inputs")) + } + src := inputs[0] + _, isRoute := src.(*Route) + if isRoute { + panic(vterrors.VT13001("failed to find the projection")) + } + op = src + } +} diff --git a/go/vt/vtgate/planbuilder/operators/query_planning.go b/go/vt/vtgate/planbuilder/operators/query_planning.go index e88fb53edb3..0f2445a22e7 100644 --- a/go/vt/vtgate/planbuilder/operators/query_planning.go +++ b/go/vt/vtgate/planbuilder/operators/query_planning.go @@ -102,6 +102,9 @@ func runRewriters(ctx *plancontext.PlanningContext, root Operator) Operator { return tryPushDelete(in) case *Update: return tryPushUpdate(in) + case *RecurseCTE: + return
tryMergeRecurse(ctx, in) + default: return in, NoRewrite } diff --git a/go/vt/vtgate/planbuilder/operators/recurse_cte.go b/go/vt/vtgate/planbuilder/operators/recurse_cte.go new file mode 100644 index 00000000000..7a8c9dcd355 --- /dev/null +++ b/go/vt/vtgate/planbuilder/operators/recurse_cte.go @@ -0,0 +1,209 @@ +/* +Copyright 2024 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package operators + +import ( + "fmt" + "strings" + + "golang.org/x/exp/maps" + + "vitess.io/vitess/go/slice" + "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/vt/vterrors" + "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" + "vitess.io/vitess/go/vt/vtgate/semantics" +) + +// RecurseCTE is used to represent a recursive CTE +type RecurseCTE struct { + Seed, // used to describe the non-recursive part that initializes the result set + Term Operator // the part that repeatedly applies the recursion, processing the result set + + // Def is the CTE definition according to the semantics + Def *semantics.CTE + + // Predicates and Projections are the expressions that are needed on the recurse side of the CTE + Predicates []*plancontext.RecurseExpression + Projections []*plancontext.RecurseExpression + + // Vars is the map of variables that are sent between the two parts of the recursive CTE + // It's filled in at offset planning time + Vars map[string]int + + // MyTableInfo is the table info of the CTE + MyTableInfo *semantics.CTETable + + // Horizon is stored here until we either expand it or push it
under a route + Horizon *Horizon + + // The LeftID is the id of the left side of the CTE + LeftID, + + // The OuterID is the id for this use of the CTE + OuterID semantics.TableSet +} + +var _ Operator = (*RecurseCTE)(nil) + +func newRecurse( + ctx *plancontext.PlanningContext, + def *semantics.CTE, + seed, term Operator, + predicates []*plancontext.RecurseExpression, + horizon *Horizon, + leftID, outerID semantics.TableSet, +) *RecurseCTE { + for _, pred := range predicates { + ctx.AddJoinPredicates(pred.Original, pred.RightExpr) + } + return &RecurseCTE{ + Def: def, + Seed: seed, + Term: term, + Predicates: predicates, + Horizon: horizon, + LeftID: leftID, + OuterID: outerID, + } +} + +func (r *RecurseCTE) Clone(inputs []Operator) Operator { + return &RecurseCTE{ + Seed: inputs[0], + Term: inputs[1], + Def: r.Def, + Predicates: r.Predicates, + Projections: r.Projections, + Vars: maps.Clone(r.Vars), + Horizon: r.Horizon, + LeftID: r.LeftID, + OuterID: r.OuterID, + } +} + +func (r *RecurseCTE) Inputs() []Operator { + return []Operator{r.Seed, r.Term} +} + +func (r *RecurseCTE) SetInputs(operators []Operator) { + r.Seed = operators[0] + r.Term = operators[1] +} + +func (r *RecurseCTE) AddPredicate(_ *plancontext.PlanningContext, e sqlparser.Expr) Operator { + r.Term = newFilter(r, e) + return r +} + +func (r *RecurseCTE) AddColumn(ctx *plancontext.PlanningContext, _, _ bool, expr *sqlparser.AliasedExpr) int { + r.makeSureWeHaveTableInfo(ctx) + e := semantics.RewriteDerivedTableExpression(expr.Expr, r.MyTableInfo) + offset := r.Seed.FindCol(ctx, e, false) + if offset == -1 { + panic(vterrors.VT13001("CTE column not found")) + } + return offset +} + +func (r *RecurseCTE) makeSureWeHaveTableInfo(ctx *plancontext.PlanningContext) { + if r.MyTableInfo == nil { + for _, table := range ctx.SemTable.Tables { + cte, ok := table.(*semantics.CTETable) + if !ok { + continue + } + if cte.CTE == r.Def { + r.MyTableInfo = cte + break + } + } + if r.MyTableInfo == nil { + 
panic(vterrors.VT13001("CTE not found")) + } + } +} + +func (r *RecurseCTE) AddWSColumn(ctx *plancontext.PlanningContext, offset int, underRoute bool) int { + seed := r.Seed.AddWSColumn(ctx, offset, underRoute) + term := r.Term.AddWSColumn(ctx, offset, underRoute) + if seed != term { + panic(vterrors.VT13001("CTE columns don't match")) + } + return seed +} + +func (r *RecurseCTE) FindCol(ctx *plancontext.PlanningContext, expr sqlparser.Expr, underRoute bool) int { + r.makeSureWeHaveTableInfo(ctx) + expr = semantics.RewriteDerivedTableExpression(expr, r.MyTableInfo) + return r.Seed.FindCol(ctx, expr, underRoute) +} + +func (r *RecurseCTE) GetColumns(ctx *plancontext.PlanningContext) []*sqlparser.AliasedExpr { + return r.Seed.GetColumns(ctx) +} + +func (r *RecurseCTE) GetSelectExprs(ctx *plancontext.PlanningContext) sqlparser.SelectExprs { + return r.Seed.GetSelectExprs(ctx) +} + +func (r *RecurseCTE) ShortDescription() string { + if len(r.Vars) > 0 { + return fmt.Sprintf("%v", r.Vars) + } + expressions := slice.Map(r.expressions(), func(expr *plancontext.RecurseExpression) string { + return sqlparser.String(expr.Original) + }) + return fmt.Sprintf("%v %v", r.Def.Name, strings.Join(expressions, ", ")) +} + +func (r *RecurseCTE) GetOrdering(*plancontext.PlanningContext) []OrderBy { + // RecurseCTE is a special case. It never guarantees any ordering. + return nil +} + +func (r *RecurseCTE) expressions() []*plancontext.RecurseExpression { + return append(r.Predicates, r.Projections...) 
+} + +func (r *RecurseCTE) planOffsets(ctx *plancontext.PlanningContext) Operator { + r.Vars = make(map[string]int) + columns := r.Seed.GetColumns(ctx) + for _, expr := range r.expressions() { + outer: + for _, lhsExpr := range expr.LeftExprs { + _, found := r.Vars[lhsExpr.Name] + if found { + continue + } + + for offset, column := range columns { + if lhsExpr.Expr.Name.EqualString(column.ColumnName()) { + r.Vars[lhsExpr.Name] = offset + continue outer + } + } + + panic(vterrors.VT13001("couldn't find column")) + } + } + return r +} + +func (r *RecurseCTE) introducesTableID() semantics.TableSet { + return r.OuterID +} diff --git a/go/vt/vtgate/planbuilder/plancontext/planning_context.go b/go/vt/vtgate/planbuilder/plancontext/planning_context.go index 58be17febab..00ac889c082 100644 --- a/go/vt/vtgate/planbuilder/plancontext/planning_context.go +++ b/go/vt/vtgate/planbuilder/plancontext/planning_context.go @@ -66,6 +66,10 @@ type PlanningContext struct { // OuterTables contains the tables that are outer to the current query // Used to set the nullable flag on the columns OuterTables semantics.TableSet + + // This is a stack of CTEs being built. It's used when we have CTEs inside CTEs, + // to remember which is the CTE currently being assembled + CurrentCTE []*ContextCTE } // CreatePlanningContext initializes a new PlanningContext with the given parameters. 
@@ -376,3 +380,43 @@ func (ctx *PlanningContext) ContainsAggr(e sqlparser.SQLNode) (hasAggr bool) { }, e) return } + +type ContextCTE struct { + *semantics.CTE + Id semantics.TableSet + Predicates []*RecurseExpression +} + +type RecurseExpression struct { + Original sqlparser.Expr + RightExpr sqlparser.Expr + LeftExprs []BindVarExpr +} + +type BindVarExpr struct { + Name string + Expr *sqlparser.ColName +} + +func (ctx *PlanningContext) PushCTE(def *semantics.CTE, id semantics.TableSet) { + ctx.CurrentCTE = append(ctx.CurrentCTE, &ContextCTE{ + CTE: def, + Id: id, + }) +} + +func (ctx *PlanningContext) PopCTE() (*ContextCTE, error) { + if len(ctx.CurrentCTE) == 0 { + return nil, vterrors.VT13001("no CTE to pop") + } + activeCTE := ctx.CurrentCTE[len(ctx.CurrentCTE)-1] + ctx.CurrentCTE = ctx.CurrentCTE[:len(ctx.CurrentCTE)-1] + return activeCTE, nil +} + +func (ctx *PlanningContext) ActiveCTE() *ContextCTE { + if len(ctx.CurrentCTE) == 0 { + return nil + } + return ctx.CurrentCTE[len(ctx.CurrentCTE)-1] +} diff --git a/go/vt/vtgate/planbuilder/testdata/cte_cases.json b/go/vt/vtgate/planbuilder/testdata/cte_cases.json index d6647681103..35470ce77d0 100644 --- a/go/vt/vtgate/planbuilder/testdata/cte_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/cte_cases.json @@ -2105,5 +2105,338 @@ "user.user" ] } + }, + { + "comment": "Recursive CTE that cannot be merged", + "query": "with recursive cte as (select name, id from user where manager_id is null union all select e.name, e.id from user e inner join cte on e.manager_id = cte.id) select name from cte", + "plan": { + "QueryType": "SELECT", + "Original": "with recursive cte as (select name, id from user where manager_id is null union all select e.name, e.id from user e inner join cte on e.manager_id = cte.id) select name from cte", + "Instructions": { + "OperatorType": "SimpleProjection", + "ColumnNames": [ + "0:name" + ], + "Columns": "0", + "Inputs": [ + { + "OperatorType": "RecurseCTE", + "JoinVars": { + "cte_id": 1 + 
}, + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select `name`, id from `user` where 1 != 1", + "Query": "select `name`, id from `user` where manager_id is null", + "Table": "`user`" + }, + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select e.`name`, e.id from `user` as e where 1 != 1", + "Query": "select e.`name`, e.id from `user` as e where e.manager_id = :cte_id", + "Table": "`user`, dual" + } + ] + } + ] + }, + "TablesUsed": [ + "main.dual", + "user.user" + ] + } + }, + { + "comment": "Recursive CTE that cannot be merged 2", + "query": "with recursive cte as (select name, id from user where manager_id is null union all select e.name, e.id from cte join user e on e.manager_id = cte.id) select name from cte", + "plan": { + "QueryType": "SELECT", + "Original": "with recursive cte as (select name, id from user where manager_id is null union all select e.name, e.id from cte join user e on e.manager_id = cte.id) select name from cte", + "Instructions": { + "OperatorType": "SimpleProjection", + "ColumnNames": [ + "0:name" + ], + "Columns": "0", + "Inputs": [ + { + "OperatorType": "RecurseCTE", + "JoinVars": { + "cte_id": 1 + }, + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select `name`, id from `user` where 1 != 1", + "Query": "select `name`, id from `user` where manager_id is null", + "Table": "`user`" + }, + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select e.`name`, e.id from `user` as e where 1 != 1", + "Query": "select e.`name`, e.id from `user` as e where e.manager_id = :cte_id", + "Table": "`user`, dual" + } + ] + } + ] + }, + "TablesUsed": [ + "main.dual", + "user.user" + ] + } + }, + { + "comment": 
"Merge into a single dual route", + "query": "WITH RECURSIVE cte AS (SELECT 1 as n UNION ALL SELECT n + 1 FROM cte WHERE n < 5) SELECT n FROM cte", + "plan": { + "QueryType": "SELECT", + "Original": "WITH RECURSIVE cte AS (SELECT 1 as n UNION ALL SELECT n + 1 FROM cte WHERE n < 5) SELECT n FROM cte", + "Instructions": { + "OperatorType": "Route", + "Variant": "Reference", + "Keyspace": { + "Name": "main", + "Sharded": false + }, + "FieldQuery": "with recursive cte as (select 1 as n from dual where 1 != 1 union all select n + 1 from cte where 1 != 1) select n from cte where 1 != 1", + "Query": "with recursive cte as (select 1 as n from dual union all select n + 1 from cte where n < 5) select n from cte", + "Table": "dual" + }, + "TablesUsed": [ + "main.dual" + ] + } + }, + { + "comment": "Recursive CTE with star projection", + "query": "WITH RECURSIVE cte (n) AS (SELECT 1 UNION ALL SELECT n + 1 FROM cte WHERE n < 5) SELECT * FROM cte", + "plan": { + "QueryType": "SELECT", + "Original": "WITH RECURSIVE cte (n) AS (SELECT 1 UNION ALL SELECT n + 1 FROM cte WHERE n < 5) SELECT * FROM cte", + "Instructions": { + "OperatorType": "Route", + "Variant": "Reference", + "Keyspace": { + "Name": "main", + "Sharded": false + }, + "FieldQuery": "with recursive cte as (select 1 from dual where 1 != 1 union all select n + 1 from cte where 1 != 1) select n from cte where 1 != 1", + "Query": "with recursive cte as (select 1 from dual union all select n + 1 from cte where n < 5) select n from cte", + "Table": "dual" + }, + "TablesUsed": [ + "main.dual" + ] + } + }, + { + "comment": "Recursive CTE calculations on the term side - merged", + "query": "WITH RECURSIVE emp_cte AS (SELECT id, 1 AS level FROM user WHERE manager_id IS NULL and id = 6 UNION ALL SELECT e.id, cte.level + 1 FROM user e JOIN emp_cte cte ON e.manager_id = cte.id and e.id = 6) SELECT * FROM emp_cte", + "plan": { + "QueryType": "SELECT", + "Original": "WITH RECURSIVE emp_cte AS (SELECT id, 1 AS level FROM user WHERE 
manager_id IS NULL and id = 6 UNION ALL SELECT e.id, cte.level + 1 FROM user e JOIN emp_cte cte ON e.manager_id = cte.id and e.id = 6) SELECT * FROM emp_cte", + "Instructions": { + "OperatorType": "Route", + "Variant": "EqualUnique", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "with recursive emp_cte as (select id, 1 as `level` from `user` where 1 != 1 union all select e.id, cte.`level` + 1 from cte as cte, `user` as e where 1 != 1) select id, `level` from emp_cte where 1 != 1", + "Query": "with recursive emp_cte as (select id, 1 as `level` from `user` where manager_id is null and id = 6 union all select e.id, cte.`level` + 1 from cte as cte, `user` as e where e.id = 6 and e.manager_id = cte.id) select id, `level` from emp_cte", + "Table": "`user`, dual", + "Values": [ + "6" + ], + "Vindex": "user_index" + }, + "TablesUsed": [ + "main.dual", + "user.user" + ] + } + }, + { + "comment": "Recursive CTE calculations on the term side - unmerged", + "query": "WITH RECURSIVE emp_cte AS (SELECT id, 1 AS level FROM user WHERE manager_id IS NULL UNION ALL SELECT e.id, cte.level + 1 FROM user e JOIN emp_cte cte ON e.manager_id = cte.id) SELECT * FROM emp_cte", + "plan": { + "QueryType": "SELECT", + "Original": "WITH RECURSIVE emp_cte AS (SELECT id, 1 AS level FROM user WHERE manager_id IS NULL UNION ALL SELECT e.id, cte.level + 1 FROM user e JOIN emp_cte cte ON e.manager_id = cte.id) SELECT * FROM emp_cte", + "Instructions": { + "OperatorType": "RecurseCTE", + "JoinVars": { + "cte_id": 0, + "cte_level": 1 + }, + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select id, 1 as `level` from `user` where 1 != 1", + "Query": "select id, 1 as `level` from `user` where manager_id is null", + "Table": "`user`" + }, + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select e.id, 
:cte_level + 1 as `cte.``level`` + 1` from `user` as e where 1 != 1", + "Query": "select e.id, :cte_level + 1 as `cte.``level`` + 1` from `user` as e where e.manager_id = :cte_id", + "Table": "`user`, dual" + } + ] + }, + "TablesUsed": [ + "main.dual", + "user.user" + ] + } + }, + { + "comment": "Outer join with recursive CTE", + "query": "WITH RECURSIVE literal_cte AS (SELECT 1 AS id, 100 AS value, 1 AS manager_id UNION ALL SELECT id + 1, value * 2, id FROM literal_cte WHERE id < 5) SELECT l.id, l.value, l.manager_id, e.name AS employee_name FROM literal_cte l LEFT JOIN user e ON l.id = e.id", + "plan": { + "QueryType": "SELECT", + "Original": "WITH RECURSIVE literal_cte AS (SELECT 1 AS id, 100 AS value, 1 AS manager_id UNION ALL SELECT id + 1, value * 2, id FROM literal_cte WHERE id < 5) SELECT l.id, l.value, l.manager_id, e.name AS employee_name FROM literal_cte l LEFT JOIN user e ON l.id = e.id", + "Instructions": { + "OperatorType": "Join", + "Variant": "LeftJoin", + "JoinColumnIndexes": "L:0,L:1,L:2,R:0", + "JoinVars": { + "l_id": 0 + }, + "TableName": "dual_`user`", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Reference", + "Keyspace": { + "Name": "main", + "Sharded": false + }, + "FieldQuery": "with recursive literal_cte as (select 1 as id, 100 as value, 1 as manager_id from dual where 1 != 1 union all select id + 1, value * 2, id from literal_cte where 1 != 1) select l.id, l.value, l.manager_id from literal_cte as l where 1 != 1", + "Query": "with recursive literal_cte as (select 1 as id, 100 as value, 1 as manager_id from dual union all select id + 1, value * 2, id from literal_cte where id < 5) select l.id, l.value, l.manager_id from literal_cte as l", + "Table": "dual" + }, + { + "OperatorType": "Route", + "Variant": "EqualUnique", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select e.`name` as employee_name from `user` as e where 1 != 1", + "Query": "select e.`name` as employee_name from `user` as e where e.id 
= :l_id", + "Table": "`user`", + "Values": [ + ":l_id" + ], + "Vindex": "user_index" + } + ] + }, + "TablesUsed": [ + "main.dual", + "user.user" + ] + } + }, + { + "comment": "Aggregation on the output of a recursive CTE", + "query": "WITH RECURSIVE emp_cte AS (SELECT id, name, manager_id FROM user WHERE manager_id IS NULL UNION ALL SELECT e.id, e.name, e.manager_id FROM user e INNER JOIN emp_cte cte ON e.manager_id = cte.id) SELECT manager_id, COUNT(*) AS employee_count FROM emp_cte GROUP BY manager_id", + "plan": { + "QueryType": "SELECT", + "Original": "WITH RECURSIVE emp_cte AS (SELECT id, name, manager_id FROM user WHERE manager_id IS NULL UNION ALL SELECT e.id, e.name, e.manager_id FROM user e INNER JOIN emp_cte cte ON e.manager_id = cte.id) SELECT manager_id, COUNT(*) AS employee_count FROM emp_cte GROUP BY manager_id", + "Instructions": { + "OperatorType": "Aggregate", + "Variant": "Ordered", + "Aggregates": "count_star(1) AS employee_count", + "GroupBy": "(0|2)", + "ResultColumns": 2, + "Inputs": [ + { + "OperatorType": "Projection", + "Expressions": [ + ":2 as manager_id", + "1 as 1", + "weight_string(:2) as weight_string(manager_id)" + ], + "Inputs": [ + { + "OperatorType": "Sort", + "Variant": "Memory", + "OrderBy": "(2|3) ASC", + "Inputs": [ + { + "OperatorType": "RecurseCTE", + "JoinVars": { + "cte_id": 0 + }, + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select dt.c0 as id, dt.c1 as `name`, dt.c2 as manager_id, weight_string(dt.c2) from (select id, `name`, manager_id from `user` where 1 != 1) as dt(c0, c1, c2) where 1 != 1", + "Query": "select dt.c0 as id, dt.c1 as `name`, dt.c2 as manager_id, weight_string(dt.c2) from (select id, `name`, manager_id from `user` where manager_id is null) as dt(c0, c1, c2)", + "Table": "`user`" + }, + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + 
"FieldQuery": "select e.id, e.`name`, e.manager_id, weight_string(e.manager_id) from `user` as e where 1 != 1", + "Query": "select e.id, e.`name`, e.manager_id, weight_string(e.manager_id) from `user` as e where e.manager_id = :cte_id", + "Table": "`user`, dual" + } + ] + } + ] + } + ] + } + ] + }, + "TablesUsed": [ + "main.dual", + "user.user" + ] + } } ] diff --git a/go/vt/vtgate/planbuilder/testdata/onecase.json b/go/vt/vtgate/planbuilder/testdata/onecase.json index da7543f706a..9d653b2f6e9 100644 --- a/go/vt/vtgate/planbuilder/testdata/onecase.json +++ b/go/vt/vtgate/planbuilder/testdata/onecase.json @@ -3,7 +3,6 @@ "comment": "Add your test case here for debugging and run go test -run=One.", "query": "", "plan": { - } } ] \ No newline at end of file diff --git a/go/vt/vtgate/planbuilder/testdata/unsupported_cases.json b/go/vt/vtgate/planbuilder/testdata/unsupported_cases.json index 0e230b3e44d..9241cec595c 100644 --- a/go/vt/vtgate/planbuilder/testdata/unsupported_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/unsupported_cases.json @@ -329,11 +329,6 @@ "query": "with user as (select aa from user where user.id=1) select ref.col from ref join user", "plan": "VT12001: unsupported: do not support CTE that use the CTE alias inside the CTE query" }, - { - "comment": "Recursive WITH", - "query": "WITH RECURSIVE cte (n) AS (SELECT 1 UNION ALL SELECT n + 1 FROM cte WHERE n < 5) SELECT * FROM cte", - "plan": "VT12001: unsupported: recursive common table expression" - }, { "comment": "Alias cannot clash with base tables", "query": "WITH user AS (SELECT col FROM user) SELECT * FROM user", diff --git a/go/vt/vtgate/semantics/analyzer.go b/go/vt/vtgate/semantics/analyzer.go index 8bb7cc393fc..ec42f638629 100644 --- a/go/vt/vtgate/semantics/analyzer.go +++ b/go/vt/vtgate/semantics/analyzer.go @@ -357,7 +357,7 @@ func (a *analyzer) collationEnv() *collations.Environment { } func (a *analyzer) analyze(statement sqlparser.Statement) error { - _ = 
sqlparser.Rewrite(statement, nil, a.earlyUp) + _ = sqlparser.Rewrite(statement, a.earlyTables.down, a.earlyTables.up) if a.err != nil { return a.err } @@ -424,13 +424,6 @@ func (a *analyzer) canShortCut(statement sqlparser.Statement) (canShortCut bool) return true } -// earlyUp collects tables in the query, so we can check -// if this a single unsharded query we are dealing with -func (a *analyzer) earlyUp(cursor *sqlparser.Cursor) bool { - a.earlyTables.up(cursor) - return true -} - func (a *analyzer) shouldContinue() bool { return a.err == nil } @@ -455,6 +448,10 @@ func (a *analyzer) noteQuerySignature(node sqlparser.SQLNode) { if node.GroupBy != nil { a.sig.Aggregation = true } + case *sqlparser.With: + if node.Recursive { + a.sig.RecursiveCTE = true + } case sqlparser.AggrFunc: a.sig.Aggregation = true case *sqlparser.Delete, *sqlparser.Update, *sqlparser.Insert: diff --git a/go/vt/vtgate/semantics/analyzer_test.go b/go/vt/vtgate/semantics/analyzer_test.go index 0fbf0911f3a..0c42456b0ab 100644 --- a/go/vt/vtgate/semantics/analyzer_test.go +++ b/go/vt/vtgate/semantics/analyzer_test.go @@ -195,6 +195,59 @@ func TestBindingMultiTablePositive(t *testing.T) { } } +func TestBindingRecursiveCTEs(t *testing.T) { + type testCase struct { + query string + rdeps TableSet + ddeps TableSet + } + queries := []testCase{{ + query: "with recursive x as (select id from user union select x.id + 1 from x where x.id < 15) select t.id from x join x t;", + rdeps: TS3, + ddeps: TS3, + }, { + query: "WITH RECURSIVE user_cte AS (SELECT id, name FROM user WHERE id = 42 UNION ALL SELECT u.id, u.name FROM user u JOIN user_cte cte ON u.id = cte.id + 1 WHERE u.id = 42) SELECT id FROM user_cte", + rdeps: TS3, + ddeps: TS3, + }} + for _, query := range queries { + t.Run(query.query, func(t *testing.T) { + stmt, semTable := parseAndAnalyzeStrict(t, query.query, "user") + sel := stmt.(*sqlparser.Select) + assert.Equal(t, query.rdeps, semTable.RecursiveDeps(extract(sel, 0)), "recursive") + 
assert.Equal(t, query.ddeps, semTable.DirectDeps(extract(sel, 0)), "direct") + }) + } +} + +func TestRecursiveCTEChecking(t *testing.T) { + type testCase struct { + name, query, err string + } + queries := []testCase{{ + name: "recursive CTE using aggregation", + query: "with recursive x as (select id from user union select count(*) from x) select t.id from x join x t", + err: "VT09027: Recursive Common Table Expression 'x' can contain neither aggregation nor window functions in recursive query block", + }, { + name: "recursive CTE using grouping", + query: "with recursive x as (select id from user union select id+1 from x where id < 10 group by 1) select t.id from x join x t", + err: "VT09027: Recursive Common Table Expression 'x' can contain neither aggregation nor window functions in recursive query block", + }, { + name: "use the same recursive cte twice in definition", + query: "with recursive x as (select 1 union select id+1 from x where id < 10 union select id+2 from x where id < 20) select t.id from x", + err: "VT09029: In recursive query block of Recursive Common Table Expression x, the recursive table must be referenced only once, and not in any subquery", + }} + for _, tc := range queries { + t.Run(tc.query, func(t *testing.T) { + parse, err := sqlparser.NewTestParser().Parse(tc.query) + require.NoError(t, err) + + _, err = AnalyzeStrict(parse, "user", fakeSchemaInfo()) + require.EqualError(t, err, tc.err) + }) + } +} + func TestBindingMultiAliasedTablePositive(t *testing.T) { type testCase struct { query string @@ -887,9 +940,6 @@ func TestInvalidQueries(t *testing.T) { }, { sql: "select 1 from t1 where (id, id) in (select 1, 2, 3)", serr: "Operand should contain 2 column(s)", - }, { - sql: "WITH RECURSIVE cte (n) AS (SELECT 1 UNION ALL SELECT n + 1 FROM cte WHERE n < 5) SELECT * FROM cte", - serr: "VT12001: unsupported: recursive common table expression", }, { sql: "with x as (select 1), x as (select 1) select * from x", serr: "VT03013: not unique 
table/alias: 'x'", @@ -956,7 +1006,7 @@ func TestScopingWithWITH(t *testing.T) { }, { query: "with c as (select x as foo from user), t as (select foo as id from c) select id from t", recursive: TS0, - direct: TS3, + direct: TS2, }, { query: "with t as (select foo as id from user) select t.id from t", recursive: TS0, diff --git a/go/vt/vtgate/semantics/check_invalid.go b/go/vt/vtgate/semantics/check_invalid.go index a739e857c00..6509f5f5ee8 100644 --- a/go/vt/vtgate/semantics/check_invalid.go +++ b/go/vt/vtgate/semantics/check_invalid.go @@ -48,10 +48,6 @@ func (a *analyzer) checkForInvalidConstructs(cursor *sqlparser.Cursor) error { } case *sqlparser.Subquery: return a.checkSubqueryColumns(cursor.Parent(), node) - case *sqlparser.With: - if node.Recursive { - return vterrors.VT12001("recursive common table expression") - } case *sqlparser.Insert: if !a.singleUnshardedKeyspace && node.Action == sqlparser.ReplaceAct { return ShardedError{Inner: &UnsupportedConstruct{errString: "REPLACE INTO with sharded keyspace"}} diff --git a/go/vt/vtgate/semantics/cte_table.go b/go/vt/vtgate/semantics/cte_table.go new file mode 100644 index 00000000000..320189ff871 --- /dev/null +++ b/go/vt/vtgate/semantics/cte_table.go @@ -0,0 +1,177 @@ +/* +Copyright 2024 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package semantics + +import ( + "strings" + + vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" + "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/vt/vterrors" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "vitess.io/vitess/go/vt/vtgate/vindexes" +) + +// CTETable contains the information about the CTE table. +// This is a special TableInfo that is used to represent the recursive table inside a CTE. For the query: +// WITH RECURSIVE cte AS (SELECT 1 UNION ALL SELECT * FROM cte as C1) SELECT * FROM cte as C2 +// The CTE table C1 is represented by a CTETable. +type CTETable struct { + TableName string + ASTNode *sqlparser.AliasedTableExpr + *CTE +} + +var _ TableInfo = (*CTETable)(nil) + +func newCTETable(node *sqlparser.AliasedTableExpr, t sqlparser.TableName, cteDef *CTE) *CTETable { + var name string + if node.As.IsEmpty() { + name = t.Name.String() + } else { + name = node.As.String() + } + + authoritative := true + for _, expr := range cteDef.Query.GetColumns() { + _, isStar := expr.(*sqlparser.StarExpr) + if isStar { + authoritative = false + break + } + } + cteDef.isAuthoritative = authoritative + + return &CTETable{ + TableName: name, + ASTNode: node, + CTE: cteDef, + } +} + +func (cte *CTETable) Name() (sqlparser.TableName, error) { + return sqlparser.NewTableName(cte.TableName), nil +} + +func (cte *CTETable) GetVindexTable() *vindexes.Table { + return nil +} + +func (cte *CTETable) IsInfSchema() bool { + return false +} + +func (cte *CTETable) matches(name sqlparser.TableName) bool { + return cte.TableName == name.Name.String() && name.Qualifier.IsEmpty() +} + +func (cte *CTETable) authoritative() bool { + return cte.isAuthoritative +} + +func (cte *CTETable) GetAliasedTableExpr() *sqlparser.AliasedTableExpr { + return cte.ASTNode +} + +func (cte *CTETable) canShortCut() shortCut { + return canShortCut +} + +func (cte *CTETable) getColumns(bool) []ColumnInfo { + selExprs := cte.Query.GetColumns() + cols := make([]ColumnInfo, 0, len(selExprs)) + 
for i, selExpr := range selExprs { + ae, isAe := selExpr.(*sqlparser.AliasedExpr) + if !isAe { + panic(vterrors.VT12001("should not be called")) + } + if len(cte.Columns) == 0 { + cols = append(cols, ColumnInfo{Name: ae.ColumnName()}) + continue + } + + // We have column aliases defined on the CTE + cols = append(cols, ColumnInfo{Name: cte.Columns[i].String()}) + } + return cols +} + +func (cte *CTETable) dependencies(colName string, org originable) (dependencies, error) { + directDeps := org.tableSetFor(cte.ASTNode) + columns := cte.getColumns(false) + for _, columnInfo := range columns { + if strings.EqualFold(columnInfo.Name, colName) { + return createCertain(directDeps, directDeps, evalengine.NewUnknownType()), nil + } + } + + if cte.authoritative() { + return &nothing{}, nil + } + + return createUncertain(directDeps, directDeps), nil +} + +func (cte *CTETable) getExprFor(s string) (sqlparser.Expr, error) { + for _, se := range cte.Query.GetColumns() { + ae, ok := se.(*sqlparser.AliasedExpr) + if !ok { + return nil, vterrors.VT09015() + } + if ae.ColumnName() == s { + return ae.Expr, nil + } + } + return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "Unknown column '%s' in 'field list'", s) +} + +func (cte *CTETable) getTableSet(org originable) TableSet { + return org.tableSetFor(cte.ASTNode) +} + +type CTE struct { + Name string + Query sqlparser.SelectStatement + isAuthoritative bool + recursiveDeps *TableSet + Columns sqlparser.Columns + IDForRecurse *TableSet + + // Was this CTE marked for being recursive? 
+ Recursive bool + + // The CTE had the seed and term parts merged + Merged bool +} + +func (cte *CTE) recursive(org originable) (id TableSet) { + if cte.recursiveDeps != nil { + return *cte.recursiveDeps + } + + // We need to find the recursive dependencies of the CTE + // We'll do this by walking the inner query and finding all the tables + _ = sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { + ate, ok := node.(*sqlparser.AliasedTableExpr) + if !ok { + return true, nil + } + id = id.Merge(org.tableSetFor(ate)) + return true, nil + }, cte.Query) + return +} diff --git a/go/vt/vtgate/semantics/derived_table.go b/go/vt/vtgate/semantics/derived_table.go index aabbe9f0b22..684966f8ac8 100644 --- a/go/vt/vtgate/semantics/derived_table.go +++ b/go/vt/vtgate/semantics/derived_table.go @@ -146,7 +146,7 @@ func (dt *DerivedTable) GetAliasedTableExpr() *sqlparser.AliasedTableExpr { } func (dt *DerivedTable) canShortCut() shortCut { - panic(vterrors.VT12001("should not be called")) + return canShortCut } // GetVindexTable implements the TableInfo interface diff --git a/go/vt/vtgate/semantics/early_rewriter.go b/go/vt/vtgate/semantics/early_rewriter.go index 611c91e512c..ee12765e984 100644 --- a/go/vt/vtgate/semantics/early_rewriter.go +++ b/go/vt/vtgate/semantics/early_rewriter.go @@ -57,7 +57,9 @@ func (r *earlyRewriter) down(cursor *sqlparser.Cursor) error { case *sqlparser.ComparisonExpr: return handleComparisonExpr(cursor, node) case *sqlparser.With: - return r.handleWith(node) + if !node.Recursive { + return r.handleWith(node) + } case *sqlparser.AliasedTableExpr: return r.handleAliasedTable(node) case *sqlparser.Delete: @@ -144,7 +146,7 @@ func (r *earlyRewriter) handleAliasedTable(node *sqlparser.AliasedTableExpr) err node.As = tbl.Name } node.Expr = &sqlparser.DerivedTable{ - Select: cte.Subquery.Select, + Select: cte.Subquery, } if len(cte.Columns) > 0 { node.Columns = cte.Columns diff --git a/go/vt/vtgate/semantics/early_rewriter_test.go 
b/go/vt/vtgate/semantics/early_rewriter_test.go index 16b3756189f..fab8211f74e 100644 --- a/go/vt/vtgate/semantics/early_rewriter_test.go +++ b/go/vt/vtgate/semantics/early_rewriter_test.go @@ -184,6 +184,9 @@ func TestExpandStar(t *testing.T) { // if we are only star-expanding authoritative tables, we don't need to stop the expansion sql: "SELECT * FROM (SELECT t2.*, 12 AS foo FROM t3, t2) as results", expSQL: "select c1, c2, foo from (select t2.c1, t2.c2, 12 as foo from t3, t2) as results", + }, { + sql: "with recursive hierarchy as (select t1.a, t1.b from t1 where t1.a is null union select t1.a, t1.b from t1 join hierarchy on t1.a = hierarchy.b) select * from hierarchy", + expSQL: "with recursive hierarchy as (select t1.a, t1.b from t1 where t1.a is null union select t1.a, t1.b from t1 join hierarchy on t1.a = hierarchy.b) select a, b from hierarchy", }} for _, tcase := range tcases { t.Run(tcase.sql, func(t *testing.T) { diff --git a/go/vt/vtgate/semantics/foreign_keys_test.go b/go/vt/vtgate/semantics/foreign_keys_test.go index e1c26ecf569..a46c67c9710 100644 --- a/go/vt/vtgate/semantics/foreign_keys_test.go +++ b/go/vt/vtgate/semantics/foreign_keys_test.go @@ -141,13 +141,10 @@ func TestGetAllManagedForeignKeys(t *testing.T) { { name: "Collect all foreign key constraints", fkManager: &fkManager{ - tables: &tableCollector{ - Tables: []TableInfo{ - tbl["t0"], - tbl["t1"], - &DerivedTable{}, - }, - }, + tables: makeTableCollector(nil, + tbl["t0"], + tbl["t1"], + &DerivedTable{}), si: &FakeSI{ KsForeignKeyMode: map[string]vschemapb.Keyspace_ForeignKeyMode{ "ks": vschemapb.Keyspace_managed, @@ -171,12 +168,10 @@ func TestGetAllManagedForeignKeys(t *testing.T) { { name: "keyspace not found in schema information", fkManager: &fkManager{ - tables: &tableCollector{ - Tables: []TableInfo{ - tbl["t2"], - tbl["t3"], - }, - }, + tables: makeTableCollector(nil, + tbl["t2"], + tbl["t3"], + ), si: &FakeSI{ KsForeignKeyMode: map[string]vschemapb.Keyspace_ForeignKeyMode{ "ks": 
vschemapb.Keyspace_managed, @@ -188,12 +183,9 @@ func TestGetAllManagedForeignKeys(t *testing.T) { { name: "Cyclic fk constraints error", fkManager: &fkManager{ - tables: &tableCollector{ - Tables: []TableInfo{ - tbl["t0"], tbl["t1"], - &DerivedTable{}, - }, - }, + tables: makeTableCollector(nil, + tbl["t0"], tbl["t1"], + &DerivedTable{}), si: &FakeSI{ KsForeignKeyMode: map[string]vschemapb.Keyspace_ForeignKeyMode{ "ks": vschemapb.Keyspace_managed, @@ -236,17 +228,11 @@ func TestFilterForeignKeysUsingUpdateExpressions(t *testing.T) { }, }, getError: func() error { return fmt.Errorf("ambiguous test error") }, - tables: &tableCollector{ - Tables: []TableInfo{ - tbl["t4"], - tbl["t5"], - }, - si: &FakeSI{ - KsForeignKeyMode: map[string]vschemapb.Keyspace_ForeignKeyMode{ - "ks": vschemapb.Keyspace_managed, - }, - }, - }, + tables: makeTableCollector(&FakeSI{ + KsForeignKeyMode: map[string]vschemapb.Keyspace_ForeignKeyMode{ + "ks": vschemapb.Keyspace_managed, + }}, tbl["t4"], + tbl["t5"]), } updateExprs := sqlparser.UpdateExprs{ &sqlparser.UpdateExpr{Name: cola, Expr: sqlparser.NewIntLiteral("1")}, @@ -350,12 +336,10 @@ func TestGetInvolvedForeignKeys(t *testing.T) { name: "Delete Query", stmt: &sqlparser.Delete{}, fkManager: &fkManager{ - tables: &tableCollector{ - Tables: []TableInfo{ - tbl["t0"], - tbl["t1"], - }, - }, + tables: makeTableCollector(nil, + tbl["t0"], + tbl["t1"], + ), si: &FakeSI{ KsForeignKeyMode: map[string]vschemapb.Keyspace_ForeignKeyMode{ "ks": vschemapb.Keyspace_managed, @@ -389,12 +373,10 @@ func TestGetInvolvedForeignKeys(t *testing.T) { cold: SingleTableSet(1), }, }, - tables: &tableCollector{ - Tables: []TableInfo{ - tbl["t4"], - tbl["t5"], - }, - }, + tables: makeTableCollector(nil, + tbl["t4"], + tbl["t5"], + ), si: &FakeSI{ KsForeignKeyMode: map[string]vschemapb.Keyspace_ForeignKeyMode{ "ks": vschemapb.Keyspace_managed, @@ -433,12 +415,10 @@ func TestGetInvolvedForeignKeys(t *testing.T) { Action: sqlparser.ReplaceAct, }, fkManager: 
&fkManager{ - tables: &tableCollector{ - Tables: []TableInfo{ - tbl["t0"], - tbl["t1"], - }, - }, + tables: makeTableCollector(nil, + tbl["t0"], + tbl["t1"], + ), si: &FakeSI{ KsForeignKeyMode: map[string]vschemapb.Keyspace_ForeignKeyMode{ "ks": vschemapb.Keyspace_managed, @@ -465,12 +445,9 @@ func TestGetInvolvedForeignKeys(t *testing.T) { Action: sqlparser.InsertAct, }, fkManager: &fkManager{ - tables: &tableCollector{ - Tables: []TableInfo{ - tbl["t0"], - tbl["t1"], - }, - }, + tables: makeTableCollector(nil, + tbl["t0"], + tbl["t1"]), si: &FakeSI{ KsForeignKeyMode: map[string]vschemapb.Keyspace_ForeignKeyMode{ "ks": vschemapb.Keyspace_managed, @@ -502,12 +479,9 @@ func TestGetInvolvedForeignKeys(t *testing.T) { colb: SingleTableSet(0), }, }, - tables: &tableCollector{ - Tables: []TableInfo{ - tbl["t6"], - tbl["t1"], - }, - }, + tables: makeTableCollector(nil, + tbl["t6"], + tbl["t1"]), si: &FakeSI{ KsForeignKeyMode: map[string]vschemapb.Keyspace_ForeignKeyMode{ "ks": vschemapb.Keyspace_managed, @@ -536,12 +510,9 @@ func TestGetInvolvedForeignKeys(t *testing.T) { name: "Insert error", stmt: &sqlparser.Insert{}, fkManager: &fkManager{ - tables: &tableCollector{ - Tables: []TableInfo{ - tbl["t2"], - tbl["t3"], - }, - }, + tables: makeTableCollector(nil, + tbl["t2"], + tbl["t3"]), si: &FakeSI{ KsForeignKeyMode: map[string]vschemapb.Keyspace_ForeignKeyMode{ "ks": vschemapb.Keyspace_managed, @@ -554,12 +525,9 @@ func TestGetInvolvedForeignKeys(t *testing.T) { name: "Update error", stmt: &sqlparser.Update{}, fkManager: &fkManager{ - tables: &tableCollector{ - Tables: []TableInfo{ - tbl["t2"], - tbl["t3"], - }, - }, + tables: makeTableCollector(nil, + tbl["t2"], + tbl["t3"]), si: &FakeSI{ KsForeignKeyMode: map[string]vschemapb.Keyspace_ForeignKeyMode{ "ks": vschemapb.Keyspace_managed, @@ -600,3 +568,12 @@ func pkInfo(parentTable *vindexes.Table, pCols []string, cCols []string) vindexe ChildColumns: sqlparser.MakeColumns(cCols...), } } + +func makeTableCollector(si 
SchemaInformation, tables ...TableInfo) *tableCollector { + return &tableCollector{ + earlyTableCollector: earlyTableCollector{ + Tables: tables, + si: si, + }, + } +} diff --git a/go/vt/vtgate/semantics/real_table.go b/go/vt/vtgate/semantics/real_table.go index 4f1639d0897..399395a9edf 100644 --- a/go/vt/vtgate/semantics/real_table.go +++ b/go/vt/vtgate/semantics/real_table.go @@ -20,9 +20,11 @@ import ( "strings" "vitess.io/vitess/go/mysql/collations" + "vitess.io/vitess/go/slice" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vterrors" + "vitess.io/vitess/go/vt/vtgate/evalengine" "vitess.io/vitess/go/vt/vtgate/vindexes" ) @@ -31,6 +33,7 @@ type RealTable struct { dbName, tableName string ASTNode *sqlparser.AliasedTableExpr Table *vindexes.Table + CTE *CTE VindexHint *sqlparser.IndexHint isInfSchema bool collationEnv *collations.Environment @@ -70,9 +73,17 @@ func (r *RealTable) IsInfSchema() bool { // GetColumns implements the TableInfo interface func (r *RealTable) getColumns(ignoreInvisbleCol bool) []ColumnInfo { - if r.Table == nil { + switch { + case r.CTE != nil: + return r.getCTEColumns() + case r.Table == nil: return nil + default: + return r.getVindexTableColumns(ignoreInvisbleCol) } +} + +func (r *RealTable) getVindexTableColumns(ignoreInvisbleCol bool) []ColumnInfo { nameMap := map[string]any{} cols := make([]ColumnInfo, 0, len(r.Table.Columns)) for _, col := range r.Table.Columns { @@ -105,6 +116,57 @@ func (r *RealTable) getColumns(ignoreInvisbleCol bool) []ColumnInfo { return cols } +func (r *RealTable) getCTEColumns() []ColumnInfo { + selectExprs := r.CTE.Query.GetColumns() + ci := extractColumnsFromCTE(r.CTE.Columns, selectExprs) + if ci != nil { + return ci + } + return extractSelectExprsFromCTE(selectExprs) +} + +// Authoritative implements the TableInfo interface +func (r *RealTable) authoritative() bool { + switch { + case r.Table != nil: + return r.Table.ColumnListAuthoritative + case 
r.CTE != nil: + return r.CTE.isAuthoritative + default: + return false + } +} + +func extractSelectExprsFromCTE(selectExprs sqlparser.SelectExprs) []ColumnInfo { + var ci []ColumnInfo + for _, expr := range selectExprs { + ae, ok := expr.(*sqlparser.AliasedExpr) + if !ok { + return nil + } + ci = append(ci, ColumnInfo{ + Name: ae.ColumnName(), + Type: evalengine.NewUnknownType(), // TODO: set the proper type + }) + } + return ci +} + +func extractColumnsFromCTE(columns sqlparser.Columns, selectExprs sqlparser.SelectExprs) []ColumnInfo { + if len(columns) == 0 { + return nil + } + if len(selectExprs) != len(columns) { + panic("mismatch of columns") + } + return slice.Map(columns, func(from sqlparser.IdentifierCI) ColumnInfo { + return ColumnInfo{ + Name: from.String(), + Type: evalengine.NewUnknownType(), + } + }) +} + // GetExpr implements the TableInfo interface func (r *RealTable) GetAliasedTableExpr() *sqlparser.AliasedTableExpr { return r.ASTNode @@ -145,11 +207,6 @@ func (r *RealTable) Name() (sqlparser.TableName, error) { return r.ASTNode.TableName() } -// Authoritative implements the TableInfo interface -func (r *RealTable) authoritative() bool { - return r.Table != nil && r.Table.ColumnListAuthoritative -} - // Matches implements the TableInfo interface func (r *RealTable) matches(name sqlparser.TableName) bool { return (name.Qualifier.IsEmpty() || name.Qualifier.String() == r.dbName) && r.tableName == name.Name.String() diff --git a/go/vt/vtgate/semantics/scoper.go b/go/vt/vtgate/semantics/scoper.go index ae3e5b7e88d..9d596d9ecd1 100644 --- a/go/vt/vtgate/semantics/scoper.go +++ b/go/vt/vtgate/semantics/scoper.go @@ -35,9 +35,10 @@ type ( binder *binder // These scopes are only used for rewriting ORDER BY 1 and GROUP BY 1 - specialExprScopes map[*sqlparser.Literal]*scope - statementIDs map[sqlparser.Statement]TableSet - si SchemaInformation + specialExprScopes map[*sqlparser.Literal]*scope + statementIDs map[sqlparser.Statement]TableSet + 
commonTableExprScopes []*sqlparser.CommonTableExpr + si SchemaInformation } scope struct { @@ -105,6 +106,8 @@ func (s *scoper) down(cursor *sqlparser.Cursor) error { s.currentScope().inHaving = true return nil } + case *sqlparser.CommonTableExpr: + s.commonTableExprScopes = append(s.commonTableExprScopes, node) } return nil } @@ -240,6 +243,9 @@ func (s *scoper) up(cursor *sqlparser.Cursor) error { case sqlparser.AggrFunc: s.currentScope().inHavingAggr = false case sqlparser.TableExpr: + // inside joins and derived tables, we can only see the tables in the table/join. + // we also want the tables available in the outer query, for SELECT expressions and the WHERE clause, + // so we copy the tables from the current scope to the parent scope if isParentSelect(cursor) { curScope := s.currentScope() s.popScope() @@ -258,6 +264,8 @@ func (s *scoper) up(cursor *sqlparser.Cursor) error { s.binder.usingJoinInfo[ts] = m } } + case *sqlparser.CommonTableExpr: + s.commonTableExprScopes = s.commonTableExprScopes[:len(s.commonTableExprScopes)-1] } return nil } @@ -367,7 +375,7 @@ func checkForInvalidAliasUse(cte *sqlparser.CommonTableExpr, name string) (err e } return err == nil } - _ = sqlparser.CopyOnRewrite(cte.Subquery.Select, down, nil, nil) + _ = sqlparser.CopyOnRewrite(cte.Subquery, down, nil, nil) return err } diff --git a/go/vt/vtgate/semantics/semantic_state.go b/go/vt/vtgate/semantics/semantic_table.go similarity index 99% rename from go/vt/vtgate/semantics/semantic_state.go rename to go/vt/vtgate/semantics/semantic_table.go index ac2fd9c1604..6738546fe37 100644 --- a/go/vt/vtgate/semantics/semantic_state.go +++ b/go/vt/vtgate/semantics/semantic_table.go @@ -77,12 +77,13 @@ type ( // QuerySignature is used to identify shortcuts in the planning process QuerySignature struct { - Aggregation bool - DML bool - Distinct bool - HashJoin bool - SubQueries bool - Union bool + Aggregation bool + DML bool + Distinct bool + HashJoin bool + SubQueries bool + Union bool + 
RecursiveCTE bool } // SemTable contains semantic analysis information about the query. @@ -773,10 +774,6 @@ func singleUnshardedKeyspace(tableInfos []TableInfo) (ks *vindexes.Keyspace, tab } for _, table := range tableInfos { - if _, isDT := table.(*DerivedTable); isDT { - continue - } - sc := table.canShortCut() var vtbl *vindexes.Table diff --git a/go/vt/vtgate/semantics/semantic_state_test.go b/go/vt/vtgate/semantics/semantic_table_test.go similarity index 100% rename from go/vt/vtgate/semantics/semantic_state_test.go rename to go/vt/vtgate/semantics/semantic_table_test.go diff --git a/go/vt/vtgate/semantics/table_collector.go b/go/vt/vtgate/semantics/table_collector.go index 948edb37d47..16285307846 100644 --- a/go/vt/vtgate/semantics/table_collector.go +++ b/go/vt/vtgate/semantics/table_collector.go @@ -28,45 +28,59 @@ import ( "vitess.io/vitess/go/vt/vtgate/vindexes" ) -// tableCollector is responsible for gathering information about the tables listed in the FROM clause, -// and adding them to the current scope, plus keeping the global list of tables used in the query -type tableCollector struct { - Tables []TableInfo - scoper *scoper - si SchemaInformation - currentDb string - org originable - unionInfo map[*sqlparser.Union]unionInfo - done map[*sqlparser.AliasedTableExpr]TableInfo -} +type ( + // tableCollector is responsible for gathering information about the tables listed in the FROM clause, + // and adding them to the current scope, plus keeping the global list of tables used in the query + tableCollector struct { + earlyTableCollector + scoper *scoper + org originable + unionInfo map[*sqlparser.Union]unionInfo + } -type earlyTableCollector struct { - si SchemaInformation - currentDb string - Tables []TableInfo - done map[*sqlparser.AliasedTableExpr]TableInfo - withTables map[sqlparser.IdentifierCS]any -} + earlyTableCollector struct { + si SchemaInformation + currentDb string + Tables []TableInfo + done map[*sqlparser.AliasedTableExpr]TableInfo + + // 
cte is a map of CTE definitions that are used in the query + cte map[string]*CTE + } +) func newEarlyTableCollector(si SchemaInformation, currentDb string) *earlyTableCollector { return &earlyTableCollector{ - si: si, - currentDb: currentDb, - done: map[*sqlparser.AliasedTableExpr]TableInfo{}, - withTables: map[sqlparser.IdentifierCS]any{}, + si: si, + currentDb: currentDb, + done: map[*sqlparser.AliasedTableExpr]TableInfo{}, + cte: map[string]*CTE{}, } } -func (etc *earlyTableCollector) up(cursor *sqlparser.Cursor) { - switch node := cursor.Node().(type) { - case *sqlparser.AliasedTableExpr: - etc.visitAliasedTableExpr(node) - case *sqlparser.With: - for _, cte := range node.CTEs { - etc.withTables[cte.ID] = nil +func (etc *earlyTableCollector) down(cursor *sqlparser.Cursor) bool { + with, ok := cursor.Node().(*sqlparser.With) + if !ok { + return true + } + for _, cte := range with.CTEs { + etc.cte[cte.ID.String()] = &CTE{ + Name: cte.ID.String(), + Query: cte.Subquery, + Columns: cte.Columns, + Recursive: with.Recursive, } } + return true +} +func (etc *earlyTableCollector) up(cursor *sqlparser.Cursor) bool { + ate, ok := cursor.Node().(*sqlparser.AliasedTableExpr) + if !ok { + return true + } + etc.visitAliasedTableExpr(ate) + return true } func (etc *earlyTableCollector) visitAliasedTableExpr(aet *sqlparser.AliasedTableExpr) { @@ -79,25 +93,22 @@ func (etc *earlyTableCollector) visitAliasedTableExpr(aet *sqlparser.AliasedTabl func (etc *earlyTableCollector) newTableCollector(scoper *scoper, org originable) *tableCollector { return &tableCollector{ - Tables: etc.Tables, - scoper: scoper, - si: etc.si, - currentDb: etc.currentDb, - unionInfo: map[*sqlparser.Union]unionInfo{}, - done: etc.done, - org: org, + earlyTableCollector: *etc, + scoper: scoper, + unionInfo: map[*sqlparser.Union]unionInfo{}, + org: org, } } func (etc *earlyTableCollector) handleTableName(tbl sqlparser.TableName, aet *sqlparser.AliasedTableExpr) { if tbl.Qualifier.IsEmpty() { - _, isCTE := 
etc.withTables[tbl.Name] + _, isCTE := etc.cte[tbl.Name.String()] if isCTE { // no need to handle these tables here, we wait for the late phase instead return } } - tableInfo, err := getTableInfo(aet, tbl, etc.si, etc.currentDb) + tableInfo, err := etc.getTableInfo(aet, tbl, nil) if err != nil { // this could just be a CTE that we haven't processed, so we'll give it the benefit of the doubt for now return @@ -304,7 +315,7 @@ func (tc *tableCollector) handleTableName(node *sqlparser.AliasedTableExpr, t sq tableInfo, found = tc.done[node] if !found { - tableInfo, err = getTableInfo(node, t, tc.si, tc.currentDb) + tableInfo, err = tc.earlyTableCollector.getTableInfo(node, t, tc.scoper) if err != nil { return err } @@ -315,12 +326,32 @@ func (tc *tableCollector) handleTableName(node *sqlparser.AliasedTableExpr, t sq return scope.addTable(tableInfo) } -func getTableInfo(node *sqlparser.AliasedTableExpr, t sqlparser.TableName, si SchemaInformation, currentDb string) (TableInfo, error) { +func (etc *earlyTableCollector) getCTE(t sqlparser.TableName) *CTE { + if t.Qualifier.NotEmpty() { + return nil + } + + return etc.cte[t.Name.String()] +} + +func (etc *earlyTableCollector) getTableInfo(node *sqlparser.AliasedTableExpr, t sqlparser.TableName, sc *scoper) (TableInfo, error) { var tbl *vindexes.Table var vindex vindexes.Vindex + if cteDef := etc.getCTE(t); cteDef != nil { + cte, err := etc.buildRecursiveCTE(node, t, sc, cteDef) + if err != nil { + return nil, err + } + if cte != nil { + // buildRecursiveCTE returns a nil table (and no error) when this reference can't be built as a recursive CTE; + // only return when we actually got a table, otherwise we fall through and look for a regular table instead + return cte, nil + } + } + isInfSchema := sqlparser.SystemSchema(t.Qualifier.String()) var err error - tbl, vindex, _, _, _, err = si.FindTableOrVindex(t) + tbl, vindex, _, _, _, err = etc.si.FindTableOrVindex(t) if err != nil && !isInfSchema { // if we are dealing with a system table, it might not be available in the vschema, but that is OK return nil, err @@ -329,13
+360,64 @@ func getTableInfo(node *sqlparser.AliasedTableExpr, t sqlparser.TableName, si Sc tbl = newVindexTable(t.Name) } - tableInfo, err := createTable(t, node, tbl, isInfSchema, vindex, si, currentDb) + tableInfo, err := etc.createTable(t, node, tbl, isInfSchema, vindex) if err != nil { return nil, err } return tableInfo, nil } +func (etc *earlyTableCollector) buildRecursiveCTE(node *sqlparser.AliasedTableExpr, t sqlparser.TableName, sc *scoper, cteDef *CTE) (TableInfo, error) { + // If sc is nil, then we are in the early table collector. + // In early table collector, we don't go over the CTE definitions, so we must be seeing a usage of the CTE. + if sc != nil && len(sc.commonTableExprScopes) > 0 { + cte := sc.commonTableExprScopes[len(sc.commonTableExprScopes)-1] + if cte.ID.String() == t.Name.String() { + + if err := checkValidRecursiveCTE(cteDef); err != nil { + return nil, err + } + + cteTable := newCTETable(node, t, cteDef) + cteTableSet := SingleTableSet(len(etc.Tables)) + cteDef.IDForRecurse = &cteTableSet + if !cteDef.Recursive { + return nil, nil + } + return cteTable, nil + } + } + return &RealTable{ + tableName: node.TableNameString(), + ASTNode: node, + CTE: cteDef, + collationEnv: etc.si.Environment().CollationEnv(), + }, nil +} + +func checkValidRecursiveCTE(cteDef *CTE) error { + if cteDef.IDForRecurse != nil { + return vterrors.VT09029(cteDef.Name) + } + + union, isUnion := cteDef.Query.(*sqlparser.Union) + if !isUnion { + return vterrors.VT09026(cteDef.Name) + } + + firstSelect := sqlparser.GetFirstSelect(union.Right) + if firstSelect.GroupBy != nil { + return vterrors.VT09027(cteDef.Name) + } + + for _, expr := range firstSelect.GetColumns() { + if sqlparser.ContainsAggregation(expr) { + return vterrors.VT09027(cteDef.Name) + } + } + return nil +} + func (tc *tableCollector) handleDerivedTable(node *sqlparser.AliasedTableExpr, t *sqlparser.DerivedTable) error { switch sel := t.Select.(type) { case *sqlparser.Select: @@ -437,14 +519,12 @@ func 
(tc *tableCollector) tableInfoFor(id TableSet) (TableInfo, error) { return tc.Tables[offset], nil } -func createTable( +func (etc *earlyTableCollector) createTable( t sqlparser.TableName, alias *sqlparser.AliasedTableExpr, tbl *vindexes.Table, isInfSchema bool, vindex vindexes.Vindex, - si SchemaInformation, - currentDb string, ) (TableInfo, error) { hint := getVindexHint(alias.Hints) @@ -458,13 +538,13 @@ func createTable( Table: tbl, VindexHint: hint, isInfSchema: isInfSchema, - collationEnv: si.Environment().CollationEnv(), + collationEnv: etc.si.Environment().CollationEnv(), } if alias.As.IsEmpty() { dbName := t.Qualifier.String() if dbName == "" { - dbName = currentDb + dbName = etc.currentDb } table.dbName = dbName