diff --git a/CHANGELOG.md b/CHANGELOG.md index 2cca98ae52..9398eb74c5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,22 +25,20 @@ Thank you to all who have contributed! ## [0.14.9] -### Added - ### Changed - With full, closed schema, the planner will now give a plan-time warning when it can prove an exclude path will never exclude a value (relevant issue -- https://github.com/partiql/partiql-lang/issues/91). -### Deprecated - -### Fixed - -### Removed - -### Security +### Experimental Changes +- **BREAKING**: For the _experimental_ `org.partiql.lang.domains` of `PartiqlLogical`, `PartiqlLogicalResolved`, and `PartiqlPhysical`, +the modeling of DML has changed substantially. These changes, while considered breaking changes, are part of an +experimental area of the PartiQL library and thus do not mandate a major-version bump of this library. Consumers +of these experimental APIs should be wary of these changes. ### Contributors - @alancai98 +- @dlurton +- @johnedquinn ## [0.14.8] diff --git a/partiql-ast/src/main/pig/partiql.ion b/partiql-ast/src/main/pig/partiql.ion index 7c845ac441..982be08078 100644 --- a/partiql-ast/src/main/pig/partiql.ion +++ b/partiql-ast/src/main/pig/partiql.ion @@ -802,55 +802,83 @@ may then be further optimized by selecting better implementations of each operat ) ) + (with statement + (exclude dml) + (include + // An `INSERT` DML operation, which is fundamentally different from UPDATE and DELETE + // because it lacks a FROM and WHERE clause while also including an ON CONFLICT clause. + // + // Models: INSERT INTO [AS ] [] + (dml_insert + // The target is an expression indicates the table whose data is to be manipulated. + // With current PartiQL Parser `SqlParser`, this can be an identifier or a simplified path expression + // consisting of only literal path steps (and with no wildcard or unpivot operators). + // Note: partiql_ast uses the `expr` sum type for this, which is too broad. We're not + // changing that at this time because `partiql_ast` is established public API, but we can + // use dml_target instead which has the properly narrowed domain. + (target dml_target) + (target_alias var_decl) + (rows_to_insert expr) + (on_conflict (? on_conflict)) + ) + + // Models: UPDATE [AS ] SET [WHERE ] + (dml_update + (target dml_target) + (target_alias var_decl) + (assignments (* set_assignment 0)) + (where (? expr)) + ) + + // Models DELETE + (dml_delete (from bexpr)) // note: the bexpr includes filters, etc. + ) + ) + (include - // Indicates kind of DML operation. - (sum dml_operation - (dml_insert target_alias::var_decl) - (dml_delete) - - // Represents the REPLACE statement as well as INSERT ... ON CONFLICT DO REPLACE ... - // [target-alias]: represents the alias for the table name. See the following syntactical example: - // `INSERT INTO Table1 AS << { 'id': 1, 'name': 'Arash' } >> ON CONFLICT DO REPLACE ...` - // [condition]: represents the condition by which a row should be replaced. See the following syntactical example: - // `INSERT INTO x << {'id': 1, 'name': 'John'}} >> ON CONFLICT DO REPLACE EXCLUDED WHERE ` - // [row_alias]: represents the alias given to the rows meant to be inserted/replaced. It is made optional - // since dml_replace is currently shared by REPLACE (which does not allow the aliasing of rows) and - // INSERT ... ON CONFLICT DO REPLACE ... (which aliases the rows as "EXCLUDED" for use within the [condition]). - (dml_replace target_alias::var_decl condition::(? expr) row_alias::(? var_decl)) - - // Represents the UPSERT statement as well as INSERT ... ON CONFLICT DO UPDATE ... - // [target-alias]: represents the alias for the table name. See the following syntactical example: - // `INSERT INTO Table1 AS << { 'id': 1, 'name': 'Arash' } >> ON CONFLICT DO UPDATE ...` - // [condition]: represents the condition by which a row should be replaced. See the following syntactical example: - // `INSERT INTO x << {'id': 1, 'name': 'John'}} >> ON CONFLICT DO UPDATE EXCLUDED WHERE ` - // [row_alias]: represents the alias given to the rows meant to be inserted/updated. It is made optional - // since dml_update is currently shared by UPSERT (which does not allow the aliasing of rows) and - // INSERT ... ON CONFLICT DO UPDATE ... (which aliases the rows as "EXCLUDED" for use within the [condition]). - (dml_update target_alias::var_decl condition::(? expr) row_alias::(? var_decl)) + // represents simple paths, i.e. suitable for the left side of an `=` operator within a `SET` clause. + // Example `a_field.nested_field[42]` + (record simple_path + // The first element, `a_field` in the example above. + (root identifier) + // The subsequent elements, `nested_field` and `[42]` in the example above. + (steps (* simple_path_step 0)) + ) + (sum simple_path_step + // for bracket paths steps, i.e. `[42]` in the simple_path example above. + (sps_index (index int)) + // for symbols, i.e. `nested_field` in the simple_path example above. + (sps_identifier (identifier identifier)) ) - ) - // Redefine statement.dml to be a simpler subset of the full DML functionality expressed with PartiQL's DML - // syntax. Full functionality is out of scope for now. This is factored minimally to support - // `INSERT INTO` and `DELETE FROM ... [WHERE ]` but this may need refactoring when - // `FROM ... UPDATE` and `UPDATE` is supported later. - (with statement - (exclude dml) - (include - // A DML operation, such as `INSERT`, `UPDATE` or `DELETE` - (dml - // The target is an expression that is indicates the table whose data is to be manipulated. - // With current PartiQL Parser `SqlParser`, this can be an identifier or a simplified path expression - // consisting of only literal path steps (and with no wildcard or unpivot operators). - // Note: partiql_ast uses the `expr` sum type for this, which is too broad. We're not - // changing that at this time because `partiql_ast` is established public API. - target::identifier - operation::dml_operation - rows::expr - ) + // The "target" of a DML operation, i.e. the table targeted for manipulation with INSERT, UPDATE, etc. + // This is a discrete type so it can be permuted in later domains to affect every use. + (record dml_target (identifier identifier)) + + // An assignment within a SET clause. + (record set_assignment + // The target, left of `=` + (set_target simple_path) + // The new value for the target, right of `=` + (value expr) ) + + // INSERT's ON CONFLICT Clause + (record on_conflict + (excluded_alias var_decl) + (condition (? expr)) + (action on_conflict_action) + ) + + (sum on_conflict_action + (do_update) + (do_replace) + ) + ) + + // Nodes excluded below this line will eventually have a representation in the logical algebra, but not // initially. @@ -934,11 +962,9 @@ may then be further optimized by selecting better implementations of each operat ) ) - // Replace statement.dml.target with statement.dml.uniqueId (the "resolved" corollary). - (with statement - (exclude dml) - (include (dml uniqueId::symbol operation::dml_operation rows::expr)) - ) + // Replace statement.dml.uniqueId (the "resolved" corollary). + (exclude dml_target) + (include (record dml_target (uniqueId symbol))) ) ) diff --git a/partiql-coverage/src/test/kotlin/org/partiql/coverage/api/impl/PartiQLTestExtensionTest.kt b/partiql-coverage/src/test/kotlin/org/partiql/coverage/api/impl/PartiQLTestExtensionTest.kt index 6845eba0e9..8ef3bad841 100644 --- a/partiql-coverage/src/test/kotlin/org/partiql/coverage/api/impl/PartiQLTestExtensionTest.kt +++ b/partiql-coverage/src/test/kotlin/org/partiql/coverage/api/impl/PartiQLTestExtensionTest.kt @@ -31,7 +31,6 @@ import org.partiql.coverage.api.PartiQLTest import org.partiql.coverage.api.PartiQLTestCase import org.partiql.coverage.api.PartiQLTestProvider import org.partiql.lang.CompilerPipeline -import org.partiql.lang.eval.EvaluationSession import org.partiql.lang.eval.PartiQLResult import java.lang.reflect.AnnotatedElement import java.lang.reflect.Method @@ -76,19 +75,9 @@ class PartiQLTestExtensionTest { fun test2(tc: PartiQLTestCase, result: PartiQLResult.Value) { } - @Disabled - @PartiQLTest(provider = MockProvider::class) - @JvmName("test3") - @Suppress("UNUSED") - fun test3(tc: ValidTestCase, result: PartiQLResult.Delete) { - } - - class ValidTestCase(override val session: EvaluationSession) : PartiQLTestCase - override fun provideArguments(context: ExtensionContext?): Stream = listOf( AbstractExtensionContext(ValidSignaturesProvider::class.java, "test1", PartiQLTestCase::class.java, PartiQLResult::class.java), AbstractExtensionContext(ValidSignaturesProvider::class.java, "test2", PartiQLTestCase::class.java, PartiQLResult.Value::class.java), - AbstractExtensionContext(ValidSignaturesProvider::class.java, "test3", ValidTestCase::class.java, PartiQLResult.Delete::class.java), ).map { Arguments.of(it) }.stream() } diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/compiler/PartiQLCompilerAsyncDefault.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/compiler/PartiQLCompilerAsyncDefault.kt index 0b98276430..6a326c7da0 100644 --- a/partiql-lang/src/main/kotlin/org/partiql/lang/compiler/PartiQLCompilerAsyncDefault.kt +++ b/partiql-lang/src/main/kotlin/org/partiql/lang/compiler/PartiQLCompilerAsyncDefault.kt @@ -63,7 +63,9 @@ internal class PartiQLCompilerAsyncDefault( override suspend fun compile(statement: PartiqlPhysical.Plan): PartiQLStatementAsync { return when (val stmt = statement.stmt) { - is PartiqlPhysical.Statement.Dml -> compileDml(stmt, statement.locals.size) + is PartiqlPhysical.Statement.DmlDelete, + is PartiqlPhysical.Statement.DmlInsert, + is PartiqlPhysical.Statement.DmlUpdate -> TODO("DML compilation not supported.") is PartiqlPhysical.Statement.Exec, is PartiqlPhysical.Statement.Query -> { val expression = exprConverter.compile(statement) @@ -75,7 +77,9 @@ internal class PartiQLCompilerAsyncDefault( override suspend fun compile(statement: PartiqlPhysical.Plan, details: PartiQLPlanner.PlanningDetails): PartiQLStatementAsync { return when (val stmt = statement.stmt) { - is PartiqlPhysical.Statement.Dml -> compileDml(stmt, statement.locals.size) + is PartiqlPhysical.Statement.DmlDelete, + is PartiqlPhysical.Statement.DmlInsert, + is PartiqlPhysical.Statement.DmlUpdate -> TODO("DML compilation not supported.") is PartiqlPhysical.Statement.Exec, is PartiqlPhysical.Statement.Query -> compile(statement) is PartiqlPhysical.Statement.Explain -> PartiQLStatementAsync { compileExplain(stmt, details) } @@ -93,18 +97,6 @@ internal class PartiQLCompilerAsyncDefault( PHYSICAL_TRANSFORMED } - private suspend fun compileDml(dml: PartiqlPhysical.Statement.Dml, localsSize: Int): PartiQLStatementAsync { - val rows = exprConverter.compile(dml.rows, localsSize) - return PartiQLStatementAsync { session -> - when (dml.operation) { - is PartiqlPhysical.DmlOperation.DmlReplace -> PartiQLResult.Replace(dml.uniqueId.text, (rows.eval(session) as PartiQLResult.Value).value) - is PartiqlPhysical.DmlOperation.DmlInsert -> PartiQLResult.Insert(dml.uniqueId.text, (rows.eval(session) as PartiQLResult.Value).value) - is PartiqlPhysical.DmlOperation.DmlDelete -> PartiQLResult.Delete(dml.uniqueId.text, (rows.eval(session) as PartiQLResult.Value).value) - is PartiqlPhysical.DmlOperation.DmlUpdate -> TODO("DML Update compilation not supported yet.") - } - } - } - private fun compileExplain(statement: PartiqlPhysical.Statement.Explain, details: PartiQLPlanner.PlanningDetails): PartiQLResult.Explain.Domain { return when (val target = statement.target) { is PartiqlPhysical.ExplainTarget.Domain -> compileExplainDomain(target, details) diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/compiler/PartiQLCompilerDefault.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/compiler/PartiQLCompilerDefault.kt index fbd6e426d5..791055a4f9 100644 --- a/partiql-lang/src/main/kotlin/org/partiql/lang/compiler/PartiQLCompilerDefault.kt +++ b/partiql-lang/src/main/kotlin/org/partiql/lang/compiler/PartiQLCompilerDefault.kt @@ -65,7 +65,9 @@ internal class PartiQLCompilerDefault( override fun compile(statement: PartiqlPhysical.Plan): PartiQLStatement { return when (val stmt = statement.stmt) { - is PartiqlPhysical.Statement.Dml -> compileDml(stmt, statement.locals.size) + is PartiqlPhysical.Statement.DmlDelete, + is PartiqlPhysical.Statement.DmlInsert, + is PartiqlPhysical.Statement.DmlUpdate -> TODO("DML compilation not supported.") is PartiqlPhysical.Statement.Exec, is PartiqlPhysical.Statement.Query -> { val expression = exprConverter.compile(statement) @@ -77,7 +79,9 @@ internal class PartiQLCompilerDefault( override fun compile(statement: PartiqlPhysical.Plan, details: PartiQLPlanner.PlanningDetails): PartiQLStatement { return when (val stmt = statement.stmt) { - is PartiqlPhysical.Statement.Dml -> compileDml(stmt, statement.locals.size) + is PartiqlPhysical.Statement.DmlDelete, + is PartiqlPhysical.Statement.DmlInsert, + is PartiqlPhysical.Statement.DmlUpdate -> TODO("DML compilation not supported.") is PartiqlPhysical.Statement.Exec, is PartiqlPhysical.Statement.Query -> compile(statement) is PartiqlPhysical.Statement.Explain -> PartiQLStatement { compileExplain(stmt, details) } @@ -95,18 +99,6 @@ internal class PartiQLCompilerDefault( PHYSICAL_TRANSFORMED } - private fun compileDml(dml: PartiqlPhysical.Statement.Dml, localsSize: Int): PartiQLStatement { - val rows = exprConverter.compile(dml.rows, localsSize) - return PartiQLStatement { session -> - when (dml.operation) { - is PartiqlPhysical.DmlOperation.DmlReplace -> PartiQLResult.Replace(dml.uniqueId.text, rows.eval(session)) - is PartiqlPhysical.DmlOperation.DmlInsert -> PartiQLResult.Insert(dml.uniqueId.text, rows.eval(session)) - is PartiqlPhysical.DmlOperation.DmlDelete -> PartiQLResult.Delete(dml.uniqueId.text, rows.eval(session)) - is PartiqlPhysical.DmlOperation.DmlUpdate -> TODO("DML Update compilation not supported yet.") - } - } - } - private fun compileExplain(statement: PartiqlPhysical.Statement.Explain, details: PartiQLPlanner.PlanningDetails): PartiQLResult.Explain.Domain { return when (val target = statement.target) { is PartiqlPhysical.ExplainTarget.Domain -> compileExplainDomain(target, details) diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/PhysicalPlanCompilerAsyncImpl.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/PhysicalPlanCompilerAsyncImpl.kt index 89430245e7..70f780aab1 100644 --- a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/PhysicalPlanCompilerAsyncImpl.kt +++ b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/PhysicalPlanCompilerAsyncImpl.kt @@ -226,7 +226,9 @@ internal class PhysicalPlanCompilerAsyncImpl( return when (ast) { is PartiqlPhysical.Statement.Query -> compileAstExpr(ast.expr) is PartiqlPhysical.Statement.Exec -> compileExec(ast) - is PartiqlPhysical.Statement.Dml, + is PartiqlPhysical.Statement.DmlDelete, + is PartiqlPhysical.Statement.DmlInsert, + is PartiqlPhysical.Statement.DmlUpdate -> TODO("DML compilation not supported.") is PartiqlPhysical.Statement.Explain -> { val value = ExprValue.newBoolean(true) thunkFactory.thunkEnvAsync(emptyMetaContainer()) { value } diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/PhysicalPlanCompilerImpl.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/PhysicalPlanCompilerImpl.kt index 24f66599fe..58eeee1052 100644 --- a/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/PhysicalPlanCompilerImpl.kt +++ b/partiql-lang/src/main/kotlin/org/partiql/lang/eval/physical/PhysicalPlanCompilerImpl.kt @@ -241,7 +241,9 @@ internal class PhysicalPlanCompilerImpl( return when (ast) { is PartiqlPhysical.Statement.Query -> compileAstExpr(ast.expr) is PartiqlPhysical.Statement.Exec -> compileExec(ast) - is PartiqlPhysical.Statement.Dml, + is PartiqlPhysical.Statement.DmlDelete, + is PartiqlPhysical.Statement.DmlInsert, + is PartiqlPhysical.Statement.DmlUpdate -> TODO("DML compilation not supported.") is PartiqlPhysical.Statement.Explain -> { val value = ExprValue.newBoolean(true) thunkFactory.thunkEnv(emptyMetaContainer()) { value } diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/planner/transforms/AstToLogicalVisitorTransform.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/planner/transforms/AstToLogicalVisitorTransform.kt index 376b3fc1fa..a5d1b8cc72 100644 --- a/partiql-lang/src/main/kotlin/org/partiql/lang/planner/transforms/AstToLogicalVisitorTransform.kt +++ b/partiql-lang/src/main/kotlin/org/partiql/lang/planner/transforms/AstToLogicalVisitorTransform.kt @@ -1,5 +1,6 @@ package org.partiql.lang.planner.transforms +import com.amazon.ionelement.api.ElementType import com.amazon.ionelement.api.emptyMetaContainer import com.amazon.ionelement.api.ionString import com.amazon.ionelement.api.ionSymbol @@ -17,6 +18,7 @@ import org.partiql.lang.eval.builtins.CollectionAggregationFunction import org.partiql.lang.eval.builtins.ExprFunctionCurrentUser import org.partiql.lang.eval.err import org.partiql.lang.eval.errorContextFrom +import org.partiql.lang.eval.physical.sourceLocationMeta import org.partiql.lang.eval.physical.sourceLocationMetaOrUnknown import org.partiql.lang.eval.visitors.VisitorTransformBase import org.partiql.lang.planner.PlanningProblemDetails @@ -186,6 +188,10 @@ internal class AstToLogicalVisitorTransform( } } + override fun transformOnConflict(node: PartiqlAst.OnConflict): PartiqlLogical.OnConflict { + error("Something is wrong--transformation of PartiqlAst.OnConflict has to be handled elsewhere!") + } + private fun convertGroupAsAlias(node: SymbolPrimitive, from: PartiqlAst.FromSource) = PartiqlLogical.build { val sourceAliases = getSourceAliases(from) val structFields = sourceAliases.map { alias -> @@ -341,15 +347,11 @@ internal class AstToLogicalVisitorTransform( override fun transformStatementDml(node: PartiqlAst.Statement.Dml): PartiqlLogical.Statement { require(node.operations.ops.isNotEmpty()) - // `INSERT` and `DELETE` statements are all that's needed for the current effort--and it just so - // happens that these never utilize more than one DML operation anyway. We don't need to - // support more than one DML operation until we start supporting UPDATE statements. - if (node.operations.ops.size > 1) { - problemHandler.handleUnimplementedFeature(node, "more than one DML operation") - } - return when (val dmlOp = node.operations.ops.first()) { is PartiqlAst.DmlOp.Insert -> { + if (node.operations.ops.size > 1) { + error("Malformed AST: more than 1 INSERT DML operation!") + } node.from?.let { problemHandler.handleUnimplementedFeature(dmlOp, "UPDATE / INSERT") } // Check for and block `INSERT INTO VALUES (...)` This is *no* way to support this // within without the optional comma separated list of columns that precedes `VALUES` since doing so @@ -368,36 +370,20 @@ internal class AstToLogicalVisitorTransform( } } - val target = dmlOp.target.toDmlTargetId() + val target = dmlOp.target.toDmlTarget() val alias = dmlOp.asAlias?.let { PartiqlLogical.VarDecl(it) - } ?: PartiqlLogical.VarDecl(target.name) - - val operation = when (val conflictAction = dmlOp.conflictAction) { - null -> PartiqlLogical.DmlOperation.DmlInsert(targetAlias = alias) - is PartiqlAst.ConflictAction.DoReplace -> when (conflictAction.value) { - is PartiqlAst.OnConflictValue.Excluded -> PartiqlLogical.DmlOperation.DmlReplace( - targetAlias = alias, - condition = conflictAction.condition?.let { transformExpr(it) }, - rowAlias = conflictAction.condition?.let { PartiqlLogical.VarDecl(SymbolPrimitive(EXCLUDED, emptyMetaContainer())) } - ) - } - is PartiqlAst.ConflictAction.DoUpdate -> when (conflictAction.value) { - is PartiqlAst.OnConflictValue.Excluded -> PartiqlLogical.DmlOperation.DmlUpdate( - targetAlias = alias, - condition = conflictAction.condition?.let { transformExpr(it) }, - rowAlias = conflictAction.condition?.let { PartiqlLogical.VarDecl(SymbolPrimitive(EXCLUDED, emptyMetaContainer())) } - ) - } - is PartiqlAst.ConflictAction.DoNothing -> TODO("`ON CONFLICT DO NOTHING` is not supported in logical plan yet.") - } + } ?: PartiqlLogical.VarDecl(target.identifier.name) - PartiqlLogical.Statement.Dml( - target = target, - operation = operation, - rows = transformExpr(dmlOp.values), - metas = node.metas - ) + PartiqlLogical.build { + dmlInsert( + target = target, + targetAlias = alias, + rowsToInsert = transformExpr(dmlOp.values), + metas = node.metas, + onConflict = transformConflictAction(dmlOp.conflictAction) + ) + } } // INSERT single row with VALUE is disallowed. (This variation of INSERT might be removed in a future // release of PartiQL.) @@ -410,7 +396,12 @@ internal class AstToLogicalVisitorTransform( ) INVALID_STATEMENT } + is PartiqlAst.DmlOp.Delete -> { + if (node.operations.ops.size > 1) { + error("Malformed AST: more than 1 DELETE DML operation!") + } + if (node.from == null) { // unfortunately, the AST allows malformations such as this however the parser should // never actually create an AST for a DELETE statement without a FROM clause. @@ -427,27 +418,13 @@ internal class AstToLogicalVisitorTransform( } PartiqlLogical.build { - dml( - target = from.expr.toDmlTargetId(), - operation = dmlDelete(), - // This query returns entire rows which are to be deleted, which is unfortunate - // unavoidable without knowledge of schema. PartiQL embedders may apply a - // pass over the resolved logical (or later) plan that changes this to only - // include the primary keys of the rows to be deleted. - rows = bindingsToValues( - exp = id(rowsSource.asDecl.name.text, caseSensitive(), unqualified()), - query = rows, - ), - metas = node.metas - ) + dmlDelete(from = rows, metas = node.metas) } } + else -> { problemHandler.handleProblem( - Problem( - (from?.metas?.sourceLocationMetaOrUnknown?.toProblemLocation() ?: UNKNOWN_PROBLEM_LOCATION), - PlanningProblemDetails.InvalidDmlTarget - ) + Problem((from?.metas?.sourceLocationMetaOrUnknown?.toProblemLocation() ?: UNKNOWN_PROBLEM_LOCATION), PlanningProblemDetails.InvalidDmlTarget) ) INVALID_STATEMENT } @@ -461,18 +438,77 @@ internal class AstToLogicalVisitorTransform( INVALID_STATEMENT } is PartiqlAst.DmlOp.Set -> { - problemHandler.handleProblem( - Problem(dmlOp.metas.sourceLocationMetaOrUnknown.toProblemLocation(), PlanningProblemDetails.UnimplementedFeature("SET")) - ) - INVALID_STATEMENT + val setOperations = node.operations.ops.mapNotNull { + when (it) { + is PartiqlAst.DmlOp.Set -> it + else -> { + problemHandler.handleProblem( + Problem( + it.metas.sourceLocationMeta?.toProblemLocation() ?: UNKNOWN_PROBLEM_LOCATION, + PlanningProblemDetails.CompileError("UPDATE operation other than SET") + ) + ) + null + } + } + } + + val scan = node.from as? PartiqlAst.FromSource.Scan + ?: error("Malformed AST: UPDATE's FROM property was not a scan (this should never happen)") + + val target = scan.expr.toDmlTarget() + val alias = scan.asAlias?.let { PartiqlLogical.VarDecl(it) } + ?: PartiqlLogical.VarDecl(target.identifier.name) + + PartiqlLogical.build { + dmlUpdate( + target = target, + targetAlias = alias, + assignments = setOperations.map { + setAssignment( + it.assignment.target.toSimplePath(), + transformExpr(it.assignment.value) + ) + }, + where = node.where?.let { transformExpr(it) }, + metas = node.metas, + ) + } } } } - private fun PartiqlAst.Expr.toDmlTargetId(): PartiqlLogical.Identifier { - val dmlTargetId = when (this) { + private fun transformConflictAction(conflictAction: PartiqlAst.ConflictAction?) = + when (conflictAction) { + null -> null + is PartiqlAst.ConflictAction.DoReplace -> when (conflictAction.value) { + is PartiqlAst.OnConflictValue.Excluded -> PartiqlLogical.build { + onConflict( + excludedAlias = varDecl(EXCLUDED), + condition = conflictAction.condition?.let { transformExpr(it) }, + action = doReplace(conflictAction.metas) + ) + } + } + + is PartiqlAst.ConflictAction.DoUpdate -> when (conflictAction.value) { + is PartiqlAst.OnConflictValue.Excluded -> PartiqlLogical.build { + onConflict( + excludedAlias = varDecl(EXCLUDED), + condition = conflictAction.condition?.let { transformExpr(it) }, + action = doUpdate(conflictAction.metas) + ) + } + } + + is PartiqlAst.ConflictAction.DoNothing -> + TODO("`ON CONFLICT DO NOTHING` is not supported in logical plan yet.") + } + + private fun PartiqlAst.Expr.toDmlTarget(): PartiqlLogical.DmlTarget { + return when (this) { is PartiqlAst.Expr.Id -> PartiqlLogical.build { - identifier_(name, transformCaseSensitivity(case), metas) + dmlTarget(identifier_(name, transformCaseSensitivity(case), metas)) } else -> { problemHandler.handleProblem( @@ -481,12 +517,45 @@ internal class AstToLogicalVisitorTransform( PlanningProblemDetails.InvalidDmlTarget ) ) - INVALID_DML_TARGET_ID + PartiqlLogical.build { dmlTarget(INVALID_DML_TARGET_ID) } } } - return dmlTargetId } + private fun PartiqlAst.Expr.toSimplePath(): PartiqlLogical.SimplePath = + PartiqlLogical.build { + when (val path = this@toSimplePath) { + is PartiqlAst.Expr.Id -> { + simplePath(root = identifier_(path.name, transformCaseSensitivity(path.case), metas)) + } + is PartiqlAst.Expr.Path -> when (val root = path.root) { + is PartiqlAst.Expr.Id -> simplePath( + identifier_(root.name, transformCaseSensitivity(root.case), metas), + steps = path.steps.map { step -> + when (step) { + is PartiqlAst.PathStep.PathExpr -> when (val index = step.index) { + is PartiqlAst.Expr.Lit -> { + val stepCase = transformCaseSensitivity(step.case) + if (index.value.type.isText) { + spsIdentifier(identifier(index.value.textValue, stepCase), step.metas) + } else if (index.value.type == ElementType.INT) { + spsIndex(index.value.asInt().longValue, step.metas) + } else { + error("Malformed AST: non-text and non-integer path step expression.") + } + } + else -> error("Malformed AST: non-literal path step expression") + } + else -> error("Malformed AST: non-expression path step") + } + } + ) + else -> error("Malformed AST: root expression was not an identifier") + } + else -> error("Malformed AST: non-path expression found as AST set-target") + } + } + override fun transformStatementDdl(node: PartiqlAst.Statement.Ddl): PartiqlLogical.Statement { // It is an open question whether the planner will support DDL statements directly or if they must be handled by // some other construct. For now, we just submit an error with problem details indicating these statements diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/planner/transforms/LogicalToLogicalResolvedVisitorTransform.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/planner/transforms/LogicalToLogicalResolvedVisitorTransform.kt index c5d0d5b545..652b09dc43 100644 --- a/partiql-lang/src/main/kotlin/org/partiql/lang/planner/transforms/LogicalToLogicalResolvedVisitorTransform.kt +++ b/partiql-lang/src/main/kotlin/org/partiql/lang/planner/transforms/LogicalToLogicalResolvedVisitorTransform.kt @@ -350,9 +350,9 @@ internal data class LogicalToLogicalResolvedVisitorTransform( } } - override fun transformStatementDml(node: PartiqlLogical.Statement.Dml): PartiqlLogicalResolved.Statement { + override fun transformDmlTarget(node: PartiqlLogical.DmlTarget): PartiqlLogicalResolved.DmlTarget { // We only support DML targets that are global variables. - val bindingName = BindingName(node.target.name.text, node.target.case.toBindingCase()) + val bindingName = BindingName(node.identifier.name.text, node.identifier.case.toBindingCase()) val tableUniqueId = when (val resolvedVariable = globals.resolveGlobal(bindingName)) { is GlobalResolutionResult.GlobalVariable -> resolvedVariable.uniqueId GlobalResolutionResult.Undefined -> { @@ -360,43 +360,28 @@ internal data class LogicalToLogicalResolvedVisitorTransform( Problem( node.metas.sourceLocationMetaOrUnknown.toProblemLocation(), PlanningProblemDetails.UndefinedDmlTarget( - node.target.name.text, - node.target.case is PartiqlLogical.CaseSensitivity.CaseSensitive + node.identifier.name.text, + node.identifier.case is PartiqlLogical.CaseSensitivity.CaseSensitive ) ) ) - "undefined DML target: ${node.target.name.text} - do not run" + "undefined DML target: ${node.identifier.name.text} - do not run" } } - return PartiqlLogicalResolved.build { - dml( - uniqueId = tableUniqueId, - operation = transformDmlOperation(node.operation), - rows = transformExpr(node.rows), - metas = node.metas - ) - } + return PartiqlLogicalResolved.build { dmlTarget(uniqueId = tableUniqueId) } } - override fun transformDmlOperationDmlInsert(node: PartiqlLogical.DmlOperation.DmlInsert): PartiqlLogicalResolved.DmlOperation { - return withInputScope(this.inputScope.concatenate(node.targetAlias)) { - super.transformDmlOperationDmlInsert(node) + override fun transformStatementDmlInsert_onConflict(node: PartiqlLogical.Statement.DmlInsert): PartiqlLogicalResolved.OnConflict? { + // the alias should only be accessible to the on_conflict clause. + val scope = this.inputScope.concatenate(listOfNotNull(node.targetAlias, node.onConflict?.excludedAlias)) + return withInputScope(scope) { + super.transformStatementDmlInsert_onConflict(node) } } - override fun transformDmlOperationDmlReplace(node: PartiqlLogical.DmlOperation.DmlReplace): PartiqlLogicalResolved.DmlOperation { - val scopeWithTarget = this.inputScope.concatenate(node.targetAlias) - val inputScope = node.rowAlias?.let { scopeWithTarget.concatenate(it) } ?: scopeWithTarget - return withInputScope(inputScope) { - super.transformDmlOperationDmlReplace(node) - } - } - - override fun transformDmlOperationDmlUpdate(node: PartiqlLogical.DmlOperation.DmlUpdate): PartiqlLogicalResolved.DmlOperation { - val scopeWithTarget = this.inputScope.concatenate(node.targetAlias) - val inputScope = node.rowAlias?.let { scopeWithTarget.concatenate(it) } ?: scopeWithTarget - return withInputScope(inputScope) { - super.transformDmlOperationDmlUpdate(node) + override fun transformStatementDmlUpdate(node: PartiqlLogical.Statement.DmlUpdate): PartiqlLogicalResolved.Statement { + return withInputScope(this.inputScope.concatenate(node.targetAlias)) { + super.transformStatementDmlUpdate(node) } } @@ -404,7 +389,7 @@ internal data class LogicalToLogicalResolvedVisitorTransform( * Returns a list of variables accessible from the current scope which contain variables that may contain * an unqualified variable, in the order that they should be searched. */ - fun currentDynamicResolutionCandidates(): List = + private fun currentDynamicResolutionCandidates(): List = inputScope.varDecls.filter { it.includeInDynamicResolution } override fun transformExprBindingsToValues_exp(node: PartiqlLogical.Expr.BindingsToValues): PartiqlLogicalResolved.Expr { diff --git a/partiql-lang/src/test/kotlin/org/partiql/lang/compiler/IntegrationTests.kt b/partiql-lang/src/test/kotlin/org/partiql/lang/compiler/IntegrationTests.kt deleted file mode 100644 index 556546e27a..0000000000 --- a/partiql-lang/src/test/kotlin/org/partiql/lang/compiler/IntegrationTests.kt +++ /dev/null @@ -1,196 +0,0 @@ -package org.partiql.lang.compiler - -import kotlinx.coroutines.ExperimentalCoroutinesApi -import kotlinx.coroutines.test.runTest -import org.junit.jupiter.api.Assertions.assertEquals -import org.junit.jupiter.api.Assertions.assertFalse -import org.junit.jupiter.api.Assertions.assertTrue -import org.junit.jupiter.api.Test -import org.partiql.lang.ION -import org.partiql.lang.compiler.memorydb.MemoryDatabase -import org.partiql.lang.compiler.memorydb.QueryEngine -import org.partiql.lang.eval.BAG_ANNOTATION -import org.partiql.lang.eval.BindingCase -import org.partiql.lang.eval.BindingName -import org.partiql.lang.eval.ExprValue -import org.partiql.lang.eval.toIonValue - -class TestContext { - val db = MemoryDatabase().also { - it.createTable("customer", listOf("id")) - it.createTable("more_customer", listOf("id")) - } - private val queryEngine = QueryEngine(db) - - // Executes query - fun executeAndAssert( - expectedResultAsIonText: String, - sql: String, - ) { - val expectedIon = ION.singleValue(expectedResultAsIonText) - val result = queryEngine.executeQuery(sql) - assertEquals(expectedIon, result.toIonValue(ION)) - } - - // Executes query on async evaluator - suspend fun executeAndAssertAsync( - expectedResultAsIonText: String, - sql: String, - ) { - val expectedIon = ION.singleValue(expectedResultAsIonText) - val result = queryEngine.executeQueryAsync(sql) - assertEquals(expectedIon, result.toIonValue(ION)) - } - - fun intKey(value: Int) = ExprValue.newList(listOf(ExprValue.newInt(value))) -} - -/** - * Tests the query planner with some basic DML and SFW queries against using [QueryEngine] and [MemoryDatabase]. - */ -@OptIn(ExperimentalCoroutinesApi::class) -class IntegrationTests { - @Test - fun `insert, select and delete`() { - val ctx = TestContext() - val db = ctx.db - val customerMetadata = db.findTableMetadata(BindingName("customer", BindingCase.SENSITIVE))!! - - // start by inserting 4 rows - ctx.executeAndAssert("{rows_effected:1}", "INSERT INTO customer << { 'id': 1, 'name': 'bob' } >>") - ctx.executeAndAssert("{rows_effected:1}", "INSERT INTO customer << { 'id': 2, 'name': 'jane' } >>") - ctx.executeAndAssert("{rows_effected:1}", "INSERT INTO customer << { 'id': 3, 'name': 'moe' } >>") - ctx.executeAndAssert("{rows_effected:1}", "INSERT INTO customer << { 'id': 4, 'name': 'sue' } >>") - - // assert each of the rows is present in the actual table. - assertEquals(4, db.getRowCount(customerMetadata.tableId)) - assertTrue(db.tableContainsKey(customerMetadata.tableId, ctx.intKey(1))) - assertTrue(db.tableContainsKey(customerMetadata.tableId, ctx.intKey(2))) - assertTrue(db.tableContainsKey(customerMetadata.tableId, ctx.intKey(3))) - assertTrue(db.tableContainsKey(customerMetadata.tableId, ctx.intKey(4))) - - // commented code intentionally kept. Uncomment to see detailed debug information in the console when - // this test is run - // ctx.queryEngine.enableDebugOutput = true - - // run some simple SFW queries - ctx.executeAndAssert("$BAG_ANNOTATION::[{ name: \"bob\"}]", "SELECT c.name FROM customer AS c WHERE c.id = 1") - ctx.executeAndAssert("$BAG_ANNOTATION::[{ name: \"jane\"}]", "SELECT c.name FROM customer AS c WHERE c.id = 2") - ctx.executeAndAssert("$BAG_ANNOTATION::[{ name: \"moe\"}]", "SELECT c.name FROM customer AS c WHERE c.id = 3") - ctx.executeAndAssert("$BAG_ANNOTATION::[{ name: \"sue\"}]", "SELECT c.name FROM customer AS c WHERE c.id = 4") - - // now delete 2 rows and assert that they are no longer present (test DELETE FROM with WHERE predicate) - - ctx.executeAndAssert("{rows_effected:1}", "DELETE FROM customer AS c WHERE c.id = 2") - assertEquals(3, db.getRowCount(customerMetadata.tableId)) - assertFalse(db.tableContainsKey(customerMetadata.tableId, ctx.intKey(2))) - - ctx.executeAndAssert("{rows_effected:1}", "DELETE FROM customer AS c WHERE c.id = 4") - assertFalse(db.tableContainsKey(customerMetadata.tableId, ctx.intKey(4))) - - // finally, delete all remaining rows (test DELETE FROM without WHERE predicate) - - ctx.executeAndAssert("{rows_effected:2}", "DELETE FROM customer") - assertEquals(0, db.getRowCount(customerMetadata.tableId)) - assertFalse(db.tableContainsKey(customerMetadata.tableId, ctx.intKey(1))) - assertFalse(db.tableContainsKey(customerMetadata.tableId, ctx.intKey(3))) - } - - @Test - fun `insert, select and delete async`() = runTest { - val ctx = TestContext() - val db = ctx.db - val customerMetadata = db.findTableMetadata(BindingName("customer", BindingCase.SENSITIVE))!! - - // start by inserting 4 rows - ctx.executeAndAssertAsync("{rows_effected:1}", "INSERT INTO customer << { 'id': 1, 'name': 'bob' } >>") - ctx.executeAndAssertAsync("{rows_effected:1}", "INSERT INTO customer << { 'id': 2, 'name': 'jane' } >>") - ctx.executeAndAssertAsync("{rows_effected:1}", "INSERT INTO customer << { 'id': 3, 'name': 'moe' } >>") - ctx.executeAndAssertAsync("{rows_effected:1}", "INSERT INTO customer << { 'id': 4, 'name': 'sue' } >>") - - // assert each of the rows is present in the actual table. - assertEquals(4, db.getRowCount(customerMetadata.tableId)) - assertTrue(db.tableContainsKey(customerMetadata.tableId, ctx.intKey(1))) - assertTrue(db.tableContainsKey(customerMetadata.tableId, ctx.intKey(2))) - assertTrue(db.tableContainsKey(customerMetadata.tableId, ctx.intKey(3))) - assertTrue(db.tableContainsKey(customerMetadata.tableId, ctx.intKey(4))) - - // commented code intentionally kept. Uncomment to see detailed debug information in the console when - // this test is run - // ctx.queryEngine.enableDebugOutput = true - - // run some simple SFW queries - ctx.executeAndAssertAsync("$BAG_ANNOTATION::[{ name: \"bob\"}]", "SELECT c.name FROM customer AS c WHERE c.id = 1") - ctx.executeAndAssertAsync("$BAG_ANNOTATION::[{ name: \"jane\"}]", "SELECT c.name FROM customer AS c WHERE c.id = 2") - ctx.executeAndAssertAsync("$BAG_ANNOTATION::[{ name: \"moe\"}]", "SELECT c.name FROM customer AS c WHERE c.id = 3") - ctx.executeAndAssertAsync("$BAG_ANNOTATION::[{ name: \"sue\"}]", "SELECT c.name FROM customer AS c WHERE c.id = 4") - - // now delete 2 rows and assert that they are no longer present (test DELETE FROM with WHERE predicate) - - ctx.executeAndAssertAsync("{rows_effected:1}", "DELETE FROM customer AS c WHERE c.id = 2") - assertEquals(3, db.getRowCount(customerMetadata.tableId)) - assertFalse(db.tableContainsKey(customerMetadata.tableId, ctx.intKey(2))) - - ctx.executeAndAssertAsync("{rows_effected:1}", "DELETE FROM customer AS c WHERE c.id = 4") - assertFalse(db.tableContainsKey(customerMetadata.tableId, ctx.intKey(4))) - - // finally, delete all remaining rows (test DELETE FROM without WHERE predicate) - - ctx.executeAndAssertAsync("{rows_effected:2}", "DELETE FROM customer") - assertEquals(0, db.getRowCount(customerMetadata.tableId)) - assertFalse(db.tableContainsKey(customerMetadata.tableId, ctx.intKey(1))) - assertFalse(db.tableContainsKey(customerMetadata.tableId, ctx.intKey(3))) - } - - @Test - fun `insert with select`() { - val ctx = TestContext() - val db = ctx.db - // first put some data into the customer table - ctx.executeAndAssert("{rows_effected:1}", "INSERT INTO customer << { 'id': 1, 'name': 'bob' } >>") - ctx.executeAndAssert("{rows_effected:1}", "INSERT INTO customer << { 'id': 2, 'name': 'jane' } >>") - ctx.executeAndAssert("{rows_effected:1}", "INSERT INTO customer << { 'id': 3, 'name': 'moe' } >>") - ctx.executeAndAssert("{rows_effected:1}", "INSERT INTO customer << { 'id': 4, 'name': 'sue' } >>") - - // copy that data into the more_customer table by INSERTing the result of an SFW query - ctx.executeAndAssert( - "{rows_effected:2}", - "INSERT INTO more_customer SELECT c.id, c.name FROM customer AS c WHERE c.id IN (1, 3)" - ) - - val moreCustomerMetadata = db.findTableMetadata(BindingName("more_customer", BindingCase.SENSITIVE))!! - assertEquals(2, db.getRowCount(moreCustomerMetadata.tableId)) - assertTrue(db.tableContainsKey(moreCustomerMetadata.tableId, ctx.intKey(1))) - assertTrue(db.tableContainsKey(moreCustomerMetadata.tableId, ctx.intKey(3))) - - // lastly, assert we have the correct data - ctx.executeAndAssert("$BAG_ANNOTATION::[{ name: \"bob\"}]", "SELECT c.name FROM more_customer AS c where c.id = 1") - ctx.executeAndAssert("$BAG_ANNOTATION::[{ name: \"moe\"}]", "SELECT c.name FROM more_customer AS c where c.id = 3") - } - - @Test - fun `insert with select async`() = runTest { - val ctx = TestContext() - val db = ctx.db - // first put some data into the customer table - ctx.executeAndAssertAsync("{rows_effected:1}", "INSERT INTO customer << { 'id': 1, 'name': 'bob' } >>") - ctx.executeAndAssertAsync("{rows_effected:1}", "INSERT INTO customer << { 'id': 2, 'name': 'jane' } >>") - ctx.executeAndAssertAsync("{rows_effected:1}", "INSERT INTO customer << { 'id': 3, 'name': 'moe' } >>") - ctx.executeAndAssertAsync("{rows_effected:1}", "INSERT INTO customer << { 'id': 4, 'name': 'sue' } >>") - - // copy that data into the more_customer table by INSERTing the result of an SFW query - ctx.executeAndAssertAsync( - "{rows_effected:2}", - "INSERT INTO more_customer SELECT c.id, c.name FROM customer AS c WHERE c.id IN (1, 3)" - ) - - val moreCustomerMetadata = db.findTableMetadata(BindingName("more_customer", BindingCase.SENSITIVE))!! - assertEquals(2, db.getRowCount(moreCustomerMetadata.tableId)) - assertTrue(db.tableContainsKey(moreCustomerMetadata.tableId, ctx.intKey(1))) - assertTrue(db.tableContainsKey(moreCustomerMetadata.tableId, ctx.intKey(3))) - - // lastly, assert we have the correct data - ctx.executeAndAssertAsync("$BAG_ANNOTATION::[{ name: \"bob\"}]", "SELECT c.name FROM more_customer AS c where c.id = 1") - ctx.executeAndAssertAsync("$BAG_ANNOTATION::[{ name: \"moe\"}]", "SELECT c.name FROM more_customer AS c where c.id = 3") - } -} diff --git a/partiql-lang/src/test/kotlin/org/partiql/lang/compiler/memorydb/MemoryDatabase.kt b/partiql-lang/src/test/kotlin/org/partiql/lang/compiler/memorydb/MemoryDatabase.kt deleted file mode 100644 index b81ae80a82..0000000000 --- a/partiql-lang/src/test/kotlin/org/partiql/lang/compiler/memorydb/MemoryDatabase.kt +++ /dev/null @@ -1,84 +0,0 @@ -package org.partiql.lang.compiler.memorydb - -import org.partiql.lang.eval.BindingCase -import org.partiql.lang.eval.BindingName -import org.partiql.lang.eval.ExprValue -import java.util.UUID - -class TableMetadata(val tableId: UUID, val name: String, val primaryKeyFields: List) - -/** - * This is an extremely simple in-memory "database" for the purposes of demonstrating how to use PartiQL's - * query planner. - * - * This database supports basic SFW and DML operations. - */ -class MemoryDatabase { - private val tables = ArrayList() - - /** - * Locates a table's schema by name, with optional case-insensitivity. - * - * Returns `null` if the table doesn't exist. - */ - fun findTableMetadata(bindingName: BindingName): TableMetadata? = - tables.firstOrNull { bindingName.isEquivalentTo(it.metadata.name) }?.metadata - - /** - * Returns a table's metadata, given its UUID. If no table with the given UUID exists, an exception is - * thrown. - */ - fun getTableMetadata(tableId: UUID): TableMetadata = - tables.firstOrNull { it.metadata.tableId == tableId }?.metadata - ?: error("Table with id '$tableId' does not exist!") - - /** - * Creates a table with the specified name and primary key fields. - * - * Currently, we assume that primary key fields are case-sensitive, but this is probably - * incorrect. DL TODO: verify this and change it if needed. - */ - fun createTable(tableName: String, primaryKeyFields: List): TableMetadata { - findTableMetadata(BindingName(tableName, BindingCase.SENSITIVE))?.let { - error("Table with the name '$tableName' already exists!") - } - - val metadata = TableMetadata(UUID.randomUUID(), tableName, primaryKeyFields) - val newTable = MemoryTable(metadata) - tables.add(newTable) - - return metadata - } - - private fun getTable(tableId: UUID) = - tables.firstOrNull { it.metadata.tableId == tableId } - // if this happens either the table has been dropped and the plan being executed is no longer valid - // or there's a bug in the query planner and/or one of the custom passes. - ?: error("Table with id '$tableId' does not exist!") - - fun getRowCount(tableId: UUID) = - getTable(tableId).size - - fun tableContainsKey(tableId: UUID, key: ExprValue) = - getTable(tableId).containsKey(key) - - /** Inserts the specified row.*/ - fun insert(tableId: UUID, row: ExprValue) { - val targetTable = getTable(tableId) - targetTable.insert(row) - } - - /** Deletes the specified row. */ - fun delete(tableId: UUID, row: ExprValue) { - val targetTable = getTable(tableId) - targetTable.delete(row) - } - - /** Gets a [Sequence] for the specified table. */ - fun getFullScanSequence(tableId: UUID): Sequence = getTable(tableId) - - fun getRecordByKey(tableId: UUID, key: ExprValue): ExprValue? { - val targetTable = getTable(tableId) - return targetTable[key] - } -} diff --git a/partiql-lang/src/test/kotlin/org/partiql/lang/compiler/memorydb/MemoryTable.kt b/partiql-lang/src/test/kotlin/org/partiql/lang/compiler/memorydb/MemoryTable.kt deleted file mode 100644 index 4561f22c15..0000000000 --- a/partiql-lang/src/test/kotlin/org/partiql/lang/compiler/memorydb/MemoryTable.kt +++ /dev/null @@ -1,72 +0,0 @@ -package org.partiql.lang.compiler.memorydb - -import org.partiql.lang.ION -import org.partiql.lang.eval.BindingCase -import org.partiql.lang.eval.BindingName -import org.partiql.lang.eval.DEFAULT_COMPARATOR -import org.partiql.lang.eval.ExprValue -import org.partiql.lang.eval.ExprValueType -import org.partiql.lang.eval.toIonValue -import java.util.TreeMap - -/** - * An extremely simple in-memory table, to be used with [MemoryDatabase]. - */ -class MemoryTable( - val metadata: TableMetadata -) : Sequence { - private val rows = TreeMap(DEFAULT_COMPARATOR) - - private val primaryKeyBindingNames = metadata.primaryKeyFields.map { BindingName(it, BindingCase.SENSITIVE) } - - private fun ExprValue.extractPrimaryKey(): ExprValue = - ExprValue.newList( - primaryKeyBindingNames.map { - this.bindings[it] ?: error("Row missing primary key field '${it.name}' (case: ${it.bindingCase})") - }.asIterable() - ) - - fun containsKey(key: ExprValue): Boolean { - require(key.type == ExprValueType.LIST) { "Primary key value must be a list" } - return rows.containsKey(key) - } - - val size: Int get() = rows.size - - operator fun get(key: ExprValue): ExprValue? { - require(key.type == ExprValueType.LIST) { "specified key must have type ExprValueType.LIST " } - return rows[key] - } - - fun insert(row: ExprValue) { - require(row.type == ExprValueType.STRUCT) { "Row to be inserted must be a struct" } - - val primaryKeyExprValue = row.extractPrimaryKey() - - if (rows.containsKey(primaryKeyExprValue)) { - error("Table '${this.metadata.name}' already contains a row with the specified primary key ") - } else { - // We have to detatch the ExprValue from any lazily evaluated query that may get invoked - // whenever the value is accessed. To do this we convert to Ion, which forces full materialization, - // and then create a new ExprValue based off the Ion. - val rowStruct = row.toIonValue(ION) - rows[primaryKeyExprValue] = ExprValue.of(rowStruct) - } - } - - /** - * Deletes a row from the table. [row] should at least contain all the fields which make up the - * primary key of the table, but may also contain additional rows, which are ignored. - */ - fun delete(row: ExprValue) { - require(row.type == ExprValueType.STRUCT) { "Row to be deleted must be a struct" } - val primaryKey = row.extractPrimaryKey() - rows.remove(primaryKey) - } - - override fun iterator(): Iterator = - // the call to .toList below is important to allow the table contents to be modified during query - // execution. (Otherwise we will hit a ConcurrentModificationException in the case a DELETE FROM statement - // is executed) - rows.values.toList().iterator() -} diff --git a/partiql-lang/src/test/kotlin/org/partiql/lang/compiler/memorydb/OperatorNames.kt b/partiql-lang/src/test/kotlin/org/partiql/lang/compiler/memorydb/OperatorNames.kt deleted file mode 100644 index 5ea30cde69..0000000000 --- a/partiql-lang/src/test/kotlin/org/partiql/lang/compiler/memorydb/OperatorNames.kt +++ /dev/null @@ -1,3 +0,0 @@ -package org.partiql.lang.compiler.memorydb - -const val GET_BY_KEY_PROJECT_IMPL_NAME = "custom_get_by_key" diff --git a/partiql-lang/src/test/kotlin/org/partiql/lang/compiler/memorydb/QueryEngine.kt b/partiql-lang/src/test/kotlin/org/partiql/lang/compiler/memorydb/QueryEngine.kt deleted file mode 100644 index 5b69067ab8..0000000000 --- a/partiql-lang/src/test/kotlin/org/partiql/lang/compiler/memorydb/QueryEngine.kt +++ /dev/null @@ -1,251 +0,0 @@ -package org.partiql.lang.compiler.memorydb - -import com.amazon.ionelement.api.toIonValue -import org.partiql.annotations.ExperimentalPartiQLCompilerPipeline -import org.partiql.lang.ION -import org.partiql.lang.compiler.PartiQLCompilerPipeline -import org.partiql.lang.compiler.PartiQLCompilerPipelineAsync -import org.partiql.lang.compiler.memorydb.operators.GetByKeyProjectRelationalOperatorFactory -import org.partiql.lang.compiler.memorydb.operators.GetByKeyProjectRelationalOperatorFactoryAsync -import org.partiql.lang.domains.PartiqlPhysical -import org.partiql.lang.eval.BindingCase -import org.partiql.lang.eval.BindingName -import org.partiql.lang.eval.Bindings -import org.partiql.lang.eval.EvaluationSession -import org.partiql.lang.eval.ExprValue -import org.partiql.lang.eval.PartiQLResult -import org.partiql.lang.eval.StructOrdering -import org.partiql.lang.eval.namedValue -import org.partiql.lang.planner.GlobalResolutionResult -import org.partiql.lang.planner.GlobalVariableResolver -import org.partiql.lang.planner.PartiQLPhysicalPass -import org.partiql.lang.planner.PartiQLPlannerBuilder -import org.partiql.lang.planner.StaticTypeResolver -import org.partiql.lang.planner.transforms.optimizations.createConcatWindowFunctionPass -import org.partiql.lang.planner.transforms.optimizations.createFilterScanToKeyLookupPass -import org.partiql.lang.planner.transforms.optimizations.createRemoveUselessAndsPass -import org.partiql.lang.planner.transforms.optimizations.createRemoveUselessFiltersPass -import org.partiql.lang.util.SexpAstPrettyPrinter -import org.partiql.pig.runtime.DomainNode -import org.partiql.types.BagType -import org.partiql.types.StructType -import java.util.UUID - -// The name of the database in the context variable. -internal const val DB_CONTEXT_VAR = "in-memory-database" - -/** - * This class is a demonstration of how to integrate a storage layer with the experimental PartiQLCompilerPipeline. - */ -@OptIn(ExperimentalPartiQLCompilerPipeline::class) -class QueryEngine(val db: MemoryDatabase) { - private var enableDebugOutput = false - - /** Given a [BindingName], inform the planner the unique identifier of the global variable (usually a table). */ - private val globalVariableResolver = GlobalVariableResolver { bindingName -> - // The planner has asked us to resolve a global variable named [bindingName]. let's do so and return the - // UUID of the table. This will get packaged into a (global_id ) node (a reference to an - // unambiguously global variable). - db.findTableMetadata(bindingName)?.let { tableMetadata -> - GlobalResolutionResult.GlobalVariable(tableMetadata.tableId.toString()) - } ?: GlobalResolutionResult.Undefined - } - - /** Given a global variable's unique id, informs the planner about the static type (schema) of the global variable. */ - private val staticTypeResolver = StaticTypeResolver { uniqueId -> - val tableMetadata = db.getTableMetadata(UUID.fromString(uniqueId)) - // Tables are a bag of structs. - // TODO: at some point we'll populate this with complete schema information. - BagType( - StructType( - // TODO: nothing in the planner uses the fields property yet - fields = emptyMap(), - // TODO: nothing in the planner uses the contentClosed property yet, but "technically" do have open - // content since nothing is constraining the fields in the table. - contentClosed = false, - // The FilterScanToKeyLookup pass does use this. - primaryKeyFields = tableMetadata.primaryKeyFields - ) - ) - } - - val bindings = object : Bindings { - /** - * This function is called by the `(global_id )` expression to fetch an [ExprValue] for a resolved - * global variable, which is almost always a database table. - */ - override fun get(bindingName: BindingName): ExprValue { - // TODO: PlannerPipeline may need some additional cleanup here because, perhaps very confusingly, the - // bindingName passed here contains the UUID of the table and *not* its name. It should also always - // specify a case-sensitive binding name. Really, we should reconsider if [PlannerPipeline] should use - // [Bindings] at all, perhaps another interface that's more narrow should be used instead. - // Another difference in how PlannerPipeline uses Bindings (and argument for using a new - // interface) is that the bindingName here is guaranteed to be valid because otherwise planning would have - // been aborted. If the lookup fails for some reason it would mean the plan is invalid. But with - // the [CompilerPipeline] these checks don't exist and therefore the [Bindings] implementation is - // expected to throw if the variable does not exist. - require(bindingName.bindingCase == BindingCase.SENSITIVE) { - "It is assumed that the plan evaluator will always set bindingName.bindingCase to SENSITIVE" - } - - val tableId = UUID.fromString(bindingName.name) - return ExprValue.newBag( - db.getFullScanSequence(tableId) - ) - } - } - - // session data - val session = EvaluationSession.build { - globals(bindings) - // Please note that the context here is immutable once the call to .build above - // returns, (Hopefully that will reduce the chances of it being abused.) - withContextVariable("in-memory-database", db) - } - - private fun PartiQLPlannerBuilder.plannerBlock() = this - .callback { - fun prettyPrint(label: String, data: Any) { - val padding = 10 - when (data) { - is DomainNode -> { - println("$label:") - val sexpElement = data.toIonElement() - println(SexpAstPrettyPrinter.format(sexpElement.asAnyElement().toIonValue(ION))) - } - else -> - println("$label:".padEnd(padding) + data.toString()) - } - } - if (this@QueryEngine.enableDebugOutput) { - prettyPrint("event", it.eventName) - prettyPrint("duration", it.duration) - if (it.eventName == "parse_sql") prettyPrint("input", it.input) - prettyPrint("output", it.output) - } - } - .globalVariableResolver(globalVariableResolver) - .physicalPlannerPasses( - listOf( - // TODO: push-down filters on top of scans before this pass. - PartiQLPhysicalPass { plan, problemHandler -> - createFilterScanToKeyLookupPass( - customProjectOperatorName = GET_BY_KEY_PROJECT_IMPL_NAME, - staticTypeResolver = staticTypeResolver, - createKeyValueConstructor = { recordType, keyFieldEqualityPredicates -> - require(recordType.primaryKeyFields.size == keyFieldEqualityPredicates.size) - PartiqlPhysical.build { - list( - // Key values are expressed to the in-memory storage engine as ordered list. Therefore, we need - // to ensure that the list we pass in as an argument to the custom_get_by_key project operator - // impl is in the right order. - recordType.primaryKeyFields.map { keyFieldName -> - keyFieldEqualityPredicates.single { it.keyFieldName == keyFieldName }.equivalentValue - } - ) - } - } - ).apply(plan, problemHandler) - }, - // Note that the order of the following plans is relevant--the "remove useless filters" pass - // will not work correctly if "remove useless ands" pass is not executed first. - - // After the filter-scan-to-key-lookup pass above, we may be left with some `(and ...)` expressions - // whose operands were replaced with `(lit true)`. This pass removes `(lit true)` operands from `and` - // expressions, and replaces any `and` expressions with only `(lit true)` operands with `(lit true)`. - // This happens recursively, so an entire tree of useless `(and ...)` expressions will be replaced - // with a single `(lit true)`. - // A constant folding pass might one day eliminate the need for this, but that is not within the current scope. - PartiQLPhysicalPass { plan, problemHandler -> - createRemoveUselessAndsPass().apply(plan, problemHandler) - }, - - // After the previous pass, we may have some `(filter ... )` nodes with `(lit true)` as a predicate. - // This pass removes these useless filter nodes. - PartiQLPhysicalPass { plan, problemHandler -> - createRemoveUselessFiltersPass().apply(plan, problemHandler) - }, - - PartiQLPhysicalPass { plan, problemHandler -> - createConcatWindowFunctionPass().apply(plan, problemHandler) - }, - ) - ) - - private val compilerPipeline = PartiQLCompilerPipeline.build { - planner.plannerBlock() - compiler - .customOperatorFactories( - listOf( - GetByKeyProjectRelationalOperatorFactory() - ) - ) - } - - private val compilerPipelineAsync = PartiQLCompilerPipelineAsync.build { - planner.plannerBlock() - compiler - .customOperatorFactories( - listOf( - GetByKeyProjectRelationalOperatorFactoryAsync() // using async version here - ) - ) - } - - fun executeQuery(sql: String): ExprValue { - // compile query to statement - val statement = compilerPipeline.compile(sql) - - // First step is to plan the query. - // This parses the query and runs it through all the planner passes: - // AST -> logical plan -> resolved logical plan -> default physical plan -> custom physical plan - return convertResultToExprValue(statement.eval(session)) - } - - suspend fun executeQueryAsync(sql: String): ExprValue { - // compile query to statement - val statement = compilerPipelineAsync.compile(sql) - - // First step is to plan the query. - // This parses the query and runs it through all the planner passes: - // AST -> logical plan -> resolved logical plan -> default physical plan -> custom physical plan - return convertResultToExprValue(statement.eval(session)) - } - - private fun convertResultToExprValue(result: PartiQLResult): ExprValue = - when (result) { - is PartiQLResult.Value -> result.value - is PartiQLResult.Delete -> { - val targetTableId = UUID.fromString(result.target) - var rowsEffected = 0L - result.rows.forEach { - db.delete(targetTableId, it) - rowsEffected ++ - } - ExprValue.newStruct( - listOf( - ExprValue.newInt(rowsEffected) - .namedValue(ExprValue.newString("rows_effected")) - ), - StructOrdering.UNORDERED - ) - } - is PartiQLResult.Insert -> { - val targetTableId = UUID.fromString(result.target) - var rowsEffected = 0L - result.rows.forEach { - db.insert(targetTableId, it) - rowsEffected ++ - } - ExprValue.newStruct( - listOf( - ExprValue.newInt(rowsEffected) - .namedValue(ExprValue.newString("rows_effected")) - ), - StructOrdering.UNORDERED - ) - } - is PartiQLResult.Replace -> TODO("Not implemented yet") - is PartiQLResult.Explain.Domain -> TODO("Not implemented yet") - } -} diff --git a/partiql-lang/src/test/kotlin/org/partiql/lang/compiler/memorydb/operators/GetByKeyProjectRelationalOperatorFactory.kt b/partiql-lang/src/test/kotlin/org/partiql/lang/compiler/memorydb/operators/GetByKeyProjectRelationalOperatorFactory.kt deleted file mode 100644 index 2a2778fc60..0000000000 --- a/partiql-lang/src/test/kotlin/org/partiql/lang/compiler/memorydb/operators/GetByKeyProjectRelationalOperatorFactory.kt +++ /dev/null @@ -1,108 +0,0 @@ -package org.partiql.lang.compiler.memorydb.operators - -import org.partiql.lang.compiler.memorydb.DB_CONTEXT_VAR -import org.partiql.lang.compiler.memorydb.GET_BY_KEY_PROJECT_IMPL_NAME -import org.partiql.lang.compiler.memorydb.MemoryDatabase -import org.partiql.lang.domains.PartiqlPhysical -import org.partiql.lang.eval.physical.SetVariableFunc -import org.partiql.lang.eval.physical.operators.ProjectRelationalOperatorFactory -import org.partiql.lang.eval.physical.operators.RelationExpression -import org.partiql.lang.eval.physical.operators.ValueExpression -import org.partiql.lang.eval.relation.RelationIterator -import org.partiql.lang.eval.relation.RelationScope -import org.partiql.lang.eval.relation.RelationType -import org.partiql.lang.eval.relation.relation -import java.util.UUID - -/** - * A `project` operator implementation that performs a lookup of a single record stored in a [MemoryDatabase] given its - * primary key. - * - * Operator implementations comprise two phases: - * - * - A compile phase, where one-time computation can be performed and stored in a [RelationExpression], which - * is essentially a closure. - *- An evaluation phase, where the closure is invoked. The closure returns a [RelationIterator], which is a - * coroutine created by the [relation] function. - * - * In general, the `project` operator implementations must fetch the next row from the data store, call the provided - * [SetVariableFunc] to set the variable, and then call [RelationScope.yield]. - */ - -class GetByKeyProjectRelationalOperatorFactory : ProjectRelationalOperatorFactory(GET_BY_KEY_PROJECT_IMPL_NAME) { - /** - * This function is called at compile-time to create an instance of the operator [RelationExpression] - * that will be invoked at evaluation-time. - */ - override fun create( - impl: PartiqlPhysical.Impl, - setVar: SetVariableFunc, - args: List - ): RelationExpression { - // Compile phase starts here. We should do as much pre-computation as possible to avoid repeating during the - // evaluation phase. - - // Sanity check the static and dynamic arguments of this operator. If either of these checks fail, it would - // indicate a bug in the rewrite which created this (project ...) operator. - require(impl.staticArgs.size == 1) { - "Expected one static argument to $GET_BY_KEY_PROJECT_IMPL_NAME but found ${args.size}" - } - require(args.size == 1) { - "Expected one argument to $GET_BY_KEY_PROJECT_IMPL_NAME but found ${args.size}" - } - - // Extract the key value constructor - val keyValueExpression = args.single() - - // Parse the tableId so we don't have to at evaluation-time - val tableId = UUID.fromString(impl.staticArgs.single().textValue) - - var exhausted = false - - // Finally, return a RelationExpression which evaluates the key value expression and returns a - // RelationIterator containing a single row corresponding to the key (or no rows if nothing matches) - return RelationExpression { state -> - // this code runs at evaluation-time. - - if (exhausted) { - throw IllegalStateException("Exhausted result set") - } - - // Get the current database from the EvaluationSession context. - // Please note that the state.session.context map is immutable, therefore it is not possible - // for custom operators or functions to put stuff in there. (Hopefully that will reduce the - // chances of it being abused.) - val db = state.session.context[DB_CONTEXT_VAR] as MemoryDatabase - - // Compute the value of the key using the keyValueExpression - val keyValue = keyValueExpression.invoke(state) - - // get the record requested. - val record = db.getRecordByKey(tableId, keyValue) - - exhausted = true - - // if the record was not found, return an empty relation: - if (record == null) - relation(RelationType.BAG) { - // this relation is empty because there is no call to yield() - } - else { - // Return the relation which is Kotlin-coroutine that simply projects the single record we - // found above into the one variable allowed by the project operator, yields, and then returns. - relation(RelationType.BAG) { - // `state` is sacrosanct and should not be modified outside PartiQL. PartiQL - // provides the setVar function so that embedders can safely set the value of the - // variable from within the relation without clobbering anything else. - // It is important to call setVar *before* the yield since otherwise the value - // of the variable will not be assigned before it is accessed. - setVar(state, record) - yield() - - // also note that in this case there is only one record--to return multiple records we would - // iterate over each record normally, calling `setVar` and `yield` once for each record. - } - } - } - } -} diff --git a/partiql-lang/src/test/kotlin/org/partiql/lang/compiler/memorydb/operators/GetByKeyProjectRelationalOperatorFactoryAsync.kt b/partiql-lang/src/test/kotlin/org/partiql/lang/compiler/memorydb/operators/GetByKeyProjectRelationalOperatorFactoryAsync.kt deleted file mode 100644 index 323ec13b6c..0000000000 --- a/partiql-lang/src/test/kotlin/org/partiql/lang/compiler/memorydb/operators/GetByKeyProjectRelationalOperatorFactoryAsync.kt +++ /dev/null @@ -1,108 +0,0 @@ -package org.partiql.lang.compiler.memorydb.operators - -import org.partiql.lang.compiler.memorydb.DB_CONTEXT_VAR -import org.partiql.lang.compiler.memorydb.GET_BY_KEY_PROJECT_IMPL_NAME -import org.partiql.lang.compiler.memorydb.MemoryDatabase -import org.partiql.lang.domains.PartiqlPhysical -import org.partiql.lang.eval.physical.SetVariableFunc -import org.partiql.lang.eval.physical.operators.ProjectRelationalOperatorFactoryAsync -import org.partiql.lang.eval.physical.operators.RelationExpressionAsync -import org.partiql.lang.eval.physical.operators.ValueExpressionAsync -import org.partiql.lang.eval.relation.RelationIterator -import org.partiql.lang.eval.relation.RelationScope -import org.partiql.lang.eval.relation.RelationType -import org.partiql.lang.eval.relation.relation -import java.util.UUID - -/** - * A `project` operator implementation that performs a lookup of a single record stored in a [MemoryDatabase] given its - * primary key. - * - * Operator implementations comprise two phases: - * - * - A compile phase, where one-time computation can be performed and stored in a [RelationExpressionAsync], which - * is essentially a closure. - *- An evaluation phase, where the closure is invoked. The closure returns a [RelationIterator], which is a - * coroutine created by the [relation] function. - * - * In general, the `project` operator implementations must fetch the next row from the data store, call the provided - * [SetVariableFunc] to set the variable, and then call [RelationScope.yield]. - */ - -class GetByKeyProjectRelationalOperatorFactoryAsync : ProjectRelationalOperatorFactoryAsync(GET_BY_KEY_PROJECT_IMPL_NAME) { - /** - * This function is called at compile-time to create an instance of the operator [RelationExpressionAsync] - * that will be invoked at evaluation-time. - */ - override fun create( - impl: PartiqlPhysical.Impl, - setVar: SetVariableFunc, - args: List - ): RelationExpressionAsync { - // Compile phase starts here. We should do as much pre-computation as possible to avoid repeating during the - // evaluation phase. - - // Sanity check the static and dynamic arguments of this operator. If either of these checks fail, it would - // indicate a bug in the rewrite which created this (project ...) operator. - require(impl.staticArgs.size == 1) { - "Expected one static argument to $GET_BY_KEY_PROJECT_IMPL_NAME but found ${args.size}" - } - require(args.size == 1) { - "Expected one argument to $GET_BY_KEY_PROJECT_IMPL_NAME but found ${args.size}" - } - - // Extract the key value constructor - val keyValueExpressionAsync = args.single() - - // Parse the tableId, so we don't have to at evaluation-time - val tableId = UUID.fromString(impl.staticArgs.single().textValue) - - var exhausted = false - - // Finally, return a RelationExpressionAsync which evaluates the key value expression and returns a - // RelationIterator containing a single row corresponding to the key (or no rows if nothing matches) - return RelationExpressionAsync { state -> - // this code runs at evaluation-time. - - if (exhausted) { - throw IllegalStateException("Exhausted result set") - } - - // Get the current database from the EvaluationSession context. - // Please note that the state.session.context map is immutable, therefore it is not possible - // for custom operators or functions to put stuff in there. (Hopefully that will reduce the - // chances of it being abused.) - val db = state.session.context[DB_CONTEXT_VAR] as MemoryDatabase - - // Compute the value of the key using the keyValueExpressionAsync - val keyValue = keyValueExpressionAsync.invoke(state) - - // get the record requested. - val record = db.getRecordByKey(tableId, keyValue) - - exhausted = true - - // if the record was not found, return an empty relation: - if (record == null) - relation(RelationType.BAG) { - // this relation is empty because there is no call to yield() - } - else { - // Return the relation which is Kotlin-coroutine that simply projects the single record we - // found above into the one variable allowed by the project operator, yields, and then returns. - relation(RelationType.BAG) { - // `state` is sacrosanct and should not be modified outside PartiQL. PartiQL - // provides the setVar function so that embedders can safely set the value of the - // variable from within the relation without clobbering anything else. - // It is important to call setVar *before* the yield since otherwise the value - // of the variable will not be assigned before it is accessed. - setVar(state, record) - yield() - - // also note that in this case there is only one record--to return multiple records we would - // iterate over each record normally, calling `setVar` and `yield` once for each record. - } - } - } - } -} diff --git a/partiql-lang/src/test/kotlin/org/partiql/lang/planner/transforms/AstToLogicalVisitorTransformTests.kt b/partiql-lang/src/test/kotlin/org/partiql/lang/planner/transforms/AstToLogicalVisitorTransformTests.kt index 1ca5a7e786..296f60f77c 100644 --- a/partiql-lang/src/test/kotlin/org/partiql/lang/planner/transforms/AstToLogicalVisitorTransformTests.kt +++ b/partiql-lang/src/test/kotlin/org/partiql/lang/planner/transforms/AstToLogicalVisitorTransformTests.kt @@ -74,7 +74,7 @@ class AstToLogicalVisitorTransformTests { @ParameterizedTest @ArgumentsSource(ArgumentsForToLogicalWindowTests::class) - fun `to logical (Window)`(tc: TestCase) = runTestCase(tc) + fun `to logical-Window`(tc: TestCase) = runTestCase(tc) class ArgumentsForToLogicalWindowTests : ArgumentsProviderBase() { override fun getParameters() = listOf( @@ -130,7 +130,7 @@ class AstToLogicalVisitorTransformTests { @ParameterizedTest @ArgumentsSource(ArgumentsForToLogicalSfwTests::class) - fun `to logical (SFW)`(tc: TestCase) = runTestCase(tc) + fun `to logical-SFW`(tc: TestCase) = runTestCase(tc) class ArgumentsForToLogicalSfwTests : ArgumentsProviderBase() { @@ -725,27 +725,26 @@ class AstToLogicalVisitorTransformTests { @ParameterizedTest @ArgumentsSource(ArgumentsForToLogicalDmlTests::class) - fun `to logical (DML)`(tc: TestCase) = runTestCase(tc) + fun `to logical-DML`(tc: TestCase) = runTestCase(tc) class ArgumentsForToLogicalDmlTests : ArgumentsProviderBase() { override fun getParameters() = listOf( TestCase( "INSERT INTO foo << 1 >>", PartiqlLogical.build { - dml( - identifier("foo", caseInsensitive()), - dmlInsert(varDecl("foo")), - bag(lit(ionInt(1))) + dmlInsert( + target = dmlTarget(identifier("foo", caseInsensitive())), + targetAlias = varDecl("foo"), + rowsToInsert = bag(lit(ionInt(1))) ) } ), - TestCase( "INSERT INTO foo SELECT x.* FROM 1 AS x", PartiqlLogical.build { - dml( - identifier("foo", caseInsensitive()), - dmlInsert(varDecl("foo")), - bindingsToValues( + dmlInsert( + target = dmlTarget(identifier("foo", caseInsensitive())), + targetAlias = varDecl("foo"), + rowsToInsert = bindingsToValues( struct(structFields(id("x", caseInsensitive(), unqualified()))), scan(lit(ionInt(1)), varDecl("x")) ) @@ -755,10 +754,14 @@ class AstToLogicalVisitorTransformTests { TestCase( "INSERT INTO foo SELECT x.* FROM 1 AS x ON CONFLICT DO REPLACE EXCLUDED", PartiqlLogical.build { - dml( - identifier("foo", caseInsensitive()), - dmlReplace(varDecl("foo")), - bindingsToValues( + dmlInsert( + target = dmlTarget(identifier("foo", caseInsensitive())), + targetAlias = varDecl("foo"), + onConflict = onConflict( + excludedAlias = varDecl("EXCLUDED"), + action = doReplace() + ), + rowsToInsert = bindingsToValues( struct(structFields(id("x", caseInsensitive(), unqualified()))), scan(lit(ionInt(1)), varDecl("x")) ) @@ -768,10 +771,12 @@ class AstToLogicalVisitorTransformTests { TestCase( "INSERT INTO foo SELECT x.* FROM 1 AS x ON CONFLICT DO REPLACE EXCLUDED WHERE foo.id > 2", PartiqlLogical.build { - dml( - identifier("foo", caseInsensitive()), - dmlReplace( - targetAlias = varDecl("foo"), + dmlInsert( + target = dmlTarget(identifier("foo", caseInsensitive())), + targetAlias = varDecl("foo"), + onConflict = onConflict( + excludedAlias = varDecl("EXCLUDED"), + action = doReplace(), condition = gt( listOf( path( @@ -780,10 +785,9 @@ class AstToLogicalVisitorTransformTests { ), lit(ionInt(2)) ) - ), - rowAlias = varDecl(AstToLogicalVisitorTransform.EXCLUDED) + ) ), - bindingsToValues( + rowsToInsert = bindingsToValues( struct(structFields(id("x", caseInsensitive(), unqualified()))), scan(lit(ionInt(1)), varDecl("x")) ) @@ -793,47 +797,48 @@ class AstToLogicalVisitorTransformTests { TestCase( "INSERT INTO foo AS f <<{'id': 1, 'name':'bob'}>> ON CONFLICT DO REPLACE EXCLUDED", PartiqlLogical.build { - PartiqlLogical.build { - dml( - identifier("foo", caseInsensitive()), - dmlReplace(varDecl("f")), - bag( - struct( - structField(lit(ionString("id")), lit(ionInt(1))), - structField(lit(ionString("name")), lit(ionString("bob"))) - ) + dmlInsert( + target = dmlTarget(identifier("foo", caseInsensitive())), + targetAlias = varDecl("f"), + onConflict = onConflict( + excludedAlias = varDecl("EXCLUDED"), + action = doReplace() + ), + rowsToInsert = bag( + struct( + structField(lit(ionString("id")), lit(ionInt(1))), + structField(lit(ionString("name")), lit(ionString("bob"))) ) ) - } + ) } ), TestCase( "INSERT INTO foo AS f <<{'id': 1, 'name':'bob'}>> ON CONFLICT DO REPLACE EXCLUDED WHERE f.id > 2", PartiqlLogical.build { - PartiqlLogical.build { - dml( - identifier("foo", caseInsensitive()), - dmlReplace( - varDecl("f"), - condition = gt( - listOf( - path( - id("f", caseInsensitive(), unqualified()), - listOf(pathExpr(lit(ionString("id")), caseInsensitive())) - ), - lit(ionInt(2)) - ) - ), - rowAlias = varDecl(AstToLogicalVisitorTransform.EXCLUDED) - ), - bag( - struct( - structField(lit(ionString("id")), lit(ionInt(1))), - structField(lit(ionString("name")), lit(ionString("bob"))) + dmlInsert( + target = dmlTarget(identifier("foo", caseInsensitive())), + targetAlias = varDecl("f"), + onConflict = onConflict( + excludedAlias = varDecl("EXCLUDED"), + action = doReplace(), + condition = gt( + listOf( + path( + id("f", caseInsensitive(), unqualified()), + listOf(pathExpr(lit(ionString("id")), caseInsensitive())) + ), + lit(ionInt(2)) ) ) + ), + rowsToInsert = bag( + struct( + structField(lit(ionString("id")), lit(ionInt(1))), + structField(lit(ionString("name")), lit(ionString("bob"))) + ) ) - } + ) } ), @@ -841,73 +846,75 @@ class AstToLogicalVisitorTransformTests { TestCase( "INSERT INTO foo AS f <<{'id': 1, 'name':'bob'}, {'id': 2, 'name':'alice'}>> ON CONFLICT DO REPLACE EXCLUDED WHERE f.id > 2", PartiqlLogical.build { - PartiqlLogical.build { - dml( - identifier("foo", caseInsensitive()), - dmlReplace( - varDecl("f"), - condition = gt( - listOf( - path( - id("f", caseInsensitive(), unqualified()), - listOf(pathExpr(lit(ionString("id")), caseInsensitive())) - ), - lit(ionInt(2)) - ) - ), - rowAlias = varDecl(AstToLogicalVisitorTransform.EXCLUDED) - ), - bag( - struct( - structField(lit(ionString("id")), lit(ionInt(1))), - structField(lit(ionString("name")), lit(ionString("bob"))) - ), - struct( - structField(lit(ionString("id")), lit(ionInt(2))), - structField(lit(ionString("name")), lit(ionString("alice"))) + dmlInsert( + target = dmlTarget(identifier("foo", caseInsensitive())), + targetAlias = varDecl("f"), + onConflict = onConflict( + excludedAlias = varDecl("EXCLUDED"), + action = doReplace(), + condition = gt( + listOf( + path( + id("f", caseInsensitive(), unqualified()), + listOf(pathExpr(lit(ionString("id")), caseInsensitive())) + ), + lit(ionInt(2)) ) ) + ), + rowsToInsert = bag( + struct( + structField(lit(ionString("id")), lit(ionInt(1))), + structField(lit(ionString("name")), lit(ionString("bob"))) + ), + struct( + structField(lit(ionString("id")), lit(ionInt(2))), + structField(lit(ionString("name")), lit(ionString("alice"))) + ) ) - } + ) } ), // Testing using excluded non-reserved keyword in condition TestCase( "INSERT INTO foo AS f <<{'id': 1, 'name':'bob'}>> ON CONFLICT DO REPLACE EXCLUDED WHERE excluded.id > 2", PartiqlLogical.build { - PartiqlLogical.build { - dml( - identifier("foo", caseInsensitive()), - dmlReplace( - varDecl("f"), - condition = gt( - listOf( - path( - id("excluded", caseInsensitive(), unqualified()), - listOf(pathExpr(lit(ionString("id")), caseInsensitive())) - ), - lit(ionInt(2)) - ) - ), - rowAlias = varDecl(AstToLogicalVisitorTransform.EXCLUDED) - ), - bag( - struct( - structField(lit(ionString("id")), lit(ionInt(1))), - structField(lit(ionString("name")), lit(ionString("bob"))) + dmlInsert( + target = dmlTarget(identifier("foo", caseInsensitive())), + targetAlias = varDecl("f"), + onConflict = onConflict( + excludedAlias = varDecl("EXCLUDED"), + action = doReplace(), + condition = gt( + listOf( + path( + id("excluded", caseInsensitive(), unqualified()), + listOf(pathExpr(lit(ionString("id")), caseInsensitive())) + ), + lit(ionInt(2)) ) ) + ), + rowsToInsert = bag( + struct( + structField(lit(ionString("id")), lit(ionInt(1))), + structField(lit(ionString("name")), lit(ionString("bob"))) + ) ) - } + ) } ), TestCase( "INSERT INTO foo SELECT x.* FROM 1 AS x ON CONFLICT DO UPDATE EXCLUDED", PartiqlLogical.build { - dml( - identifier("foo", caseInsensitive()), - dmlUpdate(varDecl("foo")), - bindingsToValues( + dmlInsert( + target = dmlTarget(identifier("foo", caseInsensitive())), + targetAlias = varDecl("foo"), + onConflict = onConflict( + excludedAlias = varDecl("EXCLUDED"), + action = doUpdate() + ), + rowsToInsert = bindingsToValues( struct(structFields(id("x", caseInsensitive(), unqualified()))), scan(lit(ionInt(1)), varDecl("x")) ) @@ -917,137 +924,129 @@ class AstToLogicalVisitorTransformTests { TestCase( "INSERT INTO foo AS f <<{'id': 1, 'name':'bob'}>> ON CONFLICT DO UPDATE EXCLUDED", PartiqlLogical.build { - PartiqlLogical.build { - dml( - identifier("foo", caseInsensitive()), - dmlUpdate(varDecl("f")), - bag( - struct( - structField(lit(ionString("id")), lit(ionInt(1))), - structField(lit(ionString("name")), lit(ionString("bob"))) - ) + dmlInsert( + target = dmlTarget(identifier("foo", caseInsensitive())), + targetAlias = varDecl("f"), + onConflict = onConflict( + excludedAlias = varDecl("EXCLUDED"), + action = doUpdate() + ), + rowsToInsert = bag( + struct( + structField(lit(ionString("id")), lit(ionInt(1))), + structField(lit(ionString("name")), lit(ionString("bob"))) ) ) - } + ) } ), TestCase( "INSERT INTO foo AS f <<{'id': 1, 'name':'bob'}>> ON CONFLICT DO UPDATE EXCLUDED WHERE f.id > 2", PartiqlLogical.build { - PartiqlLogical.build { - dml( - identifier("foo", caseInsensitive()), - dmlUpdate( - varDecl("f"), - condition = gt( - listOf( - path( - id("f", caseInsensitive(), unqualified()), - listOf(pathExpr(lit(ionString("id")), caseInsensitive())) - ), - lit(ionInt(2)) - ) - ), - rowAlias = varDecl(AstToLogicalVisitorTransform.EXCLUDED) - ), - bag( - struct( - structField(lit(ionString("id")), lit(ionInt(1))), - structField(lit(ionString("name")), lit(ionString("bob"))) + dmlInsert( + target = dmlTarget(identifier("foo", caseInsensitive())), + targetAlias = varDecl("f"), + onConflict = onConflict( + excludedAlias = varDecl("EXCLUDED"), + action = doUpdate(), + condition = gt( + listOf( + path( + id("f", caseInsensitive(), unqualified()), + listOf(pathExpr(lit(ionString("id")), caseInsensitive())) + ), + lit(ionInt(2)) ) ) + ), + rowsToInsert = bag( + struct( + structField(lit(ionString("id")), lit(ionInt(1))), + structField(lit(ionString("name")), lit(ionString("bob"))) + ) ) - } + ) } ), // Testing using excluded non-reserved keyword in condition TestCase( "INSERT INTO foo AS f <<{'id': 1, 'name':'bob'}>> ON CONFLICT DO UPDATE EXCLUDED WHERE excluded.id > 2", PartiqlLogical.build { - PartiqlLogical.build { - dml( - identifier("foo", caseInsensitive()), - dmlUpdate( - varDecl("f"), - condition = gt( - listOf( - path( - id("excluded", caseInsensitive(), unqualified()), - listOf(pathExpr(lit(ionString("id")), caseInsensitive())) - ), - lit(ionInt(2)) - ) - ), - rowAlias = varDecl(AstToLogicalVisitorTransform.EXCLUDED) - ), - bag( - struct( - structField(lit(ionString("id")), lit(ionInt(1))), - structField(lit(ionString("name")), lit(ionString("bob"))) + dmlInsert( + target = dmlTarget(identifier("foo", caseInsensitive())), + targetAlias = varDecl("f"), + onConflict = onConflict( + excludedAlias = varDecl("EXCLUDED"), + action = doUpdate(), + condition = gt( + listOf( + path( + id("excluded", caseInsensitive(), unqualified()), + listOf(pathExpr(lit(ionString("id")), caseInsensitive())) + ), + lit(ionInt(2)) ) ) + ), + rowsToInsert = bag( + struct( + structField(lit(ionString("id")), lit(ionInt(1))), + structField(lit(ionString("name")), lit(ionString("bob"))) + ) ) - } + ) } ), TestCase( "REPLACE INTO foo AS f <<{'id': 1, 'name':'bob'}>>", PartiqlLogical.build { - PartiqlLogical.build { - dml( - identifier("foo", caseInsensitive()), - dmlReplace(varDecl("f")), - bag( - struct( - structField(lit(ionString("id")), lit(ionInt(1))), - structField(lit(ionString("name")), lit(ionString("bob"))) - ) + dmlInsert( + target = dmlTarget(identifier("foo", caseInsensitive())), + targetAlias = varDecl("f"), + onConflict = onConflict( + excludedAlias = varDecl("EXCLUDED"), + action = doReplace() + ), + rowsToInsert = bag( + struct( + structField(lit(ionString("id")), lit(ionInt(1))), + structField(lit(ionString("name")), lit(ionString("bob"))) ) ) - } + ) } ), TestCase( "UPSERT INTO foo AS f SELECT x.* FROM 1 AS x", PartiqlLogical.build { - PartiqlLogical.build { - dml( - identifier("foo", caseInsensitive()), - dmlUpdate(varDecl("f")), - bindingsToValues( - struct(structFields(id("x", caseInsensitive(), unqualified()))), - scan(lit(ionInt(1)), varDecl("x")) - ) + dmlInsert( + target = dmlTarget(identifier("foo", caseInsensitive())), + targetAlias = varDecl("f"), + onConflict = onConflict( + excludedAlias = varDecl("EXCLUDED"), + action = doUpdate() + ), + rowsToInsert = bindingsToValues( + struct(structFields(id("x", caseInsensitive(), unqualified()))), + scan(lit(ionInt(1)), varDecl("x")) ) - } + ) } ), TestCase( "DELETE FROM y AS y", PartiqlLogical.build { - dml( - identifier("y", caseInsensitive()), - dmlDelete(), - bindingsToValues( - id("y", caseSensitive(), unqualified()), - scan(id("y", caseInsensitive(), unqualified()), varDecl("y")) - ) - ) + dmlDelete(scan(id("y", caseInsensitive(), unqualified()), varDecl("y"))) } ), TestCase( "DELETE FROM y AS y WHERE 1=1", PartiqlLogical.build { - dml( - identifier("y", caseInsensitive()), - dmlDelete(), - bindingsToValues( - id("y", caseSensitive(), unqualified()), - // this logical plan is same as previous but includes this filter - filter( - eq(lit(ionInt(1)), lit(ionInt(1))), - scan(id("y", caseInsensitive(), unqualified()), varDecl("y")) - ) + dmlDelete( + // this logical plan is same as previous but includes this filter + filter( + eq(lit(ionInt(1)), lit(ionInt(1))), + scan(id("y", caseInsensitive(), unqualified()), varDecl("y")) ) ) } @@ -1083,6 +1082,220 @@ class AstToLogicalVisitorTransformTests { ) } + @ParameterizedTest + @ArgumentsSource(ArgumentsForToLogicalDmlUpdateTests::class) + fun `to logical-DML-update`(tc: TestCase) = runTestCase(tc) + class ArgumentsForToLogicalDmlUpdateTests : ArgumentsProviderBase() { + override fun getParameters() = listOf( + TestCase( + // simple, non-nested target field + "FROM foo AS f WHERE TRUE SET a = 1", + PartiqlLogical.build { + dmlUpdate( + target = dmlTarget(identifier("foo", caseInsensitive())), + targetAlias = varDecl("f"), + assignments = listOf( + setAssignment( + setTarget = simplePath( + root = identifier("a", caseInsensitive()), + steps = emptyList() + ), + value = lit(ionInt(1)), + ), + ), + where = lit(ionBool(true)) + ) + } + ), + TestCase( + // simple, non-nested target field + "UPDATE foo SET a = 1 WHERE TRUE", + PartiqlLogical.build { + dmlUpdate( + target = dmlTarget(identifier("foo", caseInsensitive())), + targetAlias = varDecl("foo"), + assignments = listOf( + setAssignment( + setTarget = simplePath( + root = identifier("a", caseInsensitive()), + steps = emptyList() + ), + value = lit(ionInt(1)), + ), + ), + where = lit(ionBool(true)) + ) + } + ), + // simple, nested target field (depth 2) + TestCase( + "UPDATE foo SET a.b = 1 WHERE TRUE", + PartiqlLogical.build { + dmlUpdate( + target = dmlTarget(identifier("foo", caseInsensitive())), + targetAlias = varDecl("foo"), + assignments = listOf( + setAssignment( + setTarget = simplePath( + root = identifier("a", caseInsensitive()), + steps = listOf(spsIdentifier(identifier("b", caseInsensitive()))) + ), + value = lit(ionInt(1)), + ), + ), + where = lit(ionBool(true)) + ) + } + ), + // simple, nested target field (depth 3) + TestCase( + "UPDATE foo SET a.b.c = 1 WHERE TRUE", + PartiqlLogical.build { + dmlUpdate( + target = dmlTarget(identifier("foo", caseInsensitive())), + targetAlias = varDecl("foo"), + assignments = listOf( + setAssignment( + setTarget = simplePath( + root = identifier("a", caseInsensitive()), + steps = listOf( + spsIdentifier(identifier("b", caseInsensitive())), + spsIdentifier(identifier("c", caseInsensitive())) + ) + ), + value = lit(ionInt(1)), + ), + ), + where = lit(ionBool(true)) + ) + } + ), + // simple path with various types of steps + TestCase( + "UPDATE foo SET a[42].c[84] = 42 WHERE TRUE", + PartiqlLogical.build { + dmlUpdate( + target = dmlTarget(identifier("foo", caseInsensitive())), + targetAlias = varDecl("foo"), + assignments = listOf( + setAssignment( + setTarget = simplePath( + root = identifier("a", caseInsensitive()), + steps = listOf( + spsIndex(42), + spsIdentifier(identifier("c", caseInsensitive())), + spsIndex(84), + ) + ), + value = lit(ionInt(42)), + ), + ), + where = lit(ionBool(true)) + ) + } + ), + // case sensitivity of root in simple path + TestCase( + "UPDATE foo SET \"a\".b = 1 WHERE TRUE", + PartiqlLogical.build { + dmlUpdate( + target = dmlTarget(identifier("foo", caseInsensitive())), + targetAlias = varDecl("foo"), + assignments = listOf( + setAssignment( + setTarget = simplePath( + root = identifier("a", caseSensitive()), + steps = listOf(spsIdentifier(identifier("b", caseInsensitive()))) + ), + value = lit(ionInt(1)), + ), + ), + where = lit(ionBool(true)) + ) + } + ), + // case sensitivity of path step in simple path + TestCase( + "UPDATE foo SET a.\"b\" = 1 WHERE TRUE", + PartiqlLogical.build { + dmlUpdate( + target = dmlTarget(identifier("foo", caseInsensitive())), + targetAlias = varDecl("foo"), + assignments = listOf( + setAssignment( + setTarget = simplePath( + root = identifier("a", caseInsensitive()), + steps = listOf(spsIdentifier(identifier("b", caseSensitive()))) + ), + value = lit(ionInt(1)), + ), + ), + where = lit(ionBool(true)) + ) + } + ), + TestCase( + // with explicit target alias + "UPDATE foo AS f SET a = 1 WHERE TRUE", + PartiqlLogical.build { + dmlUpdate( + target = dmlTarget(identifier("foo", caseInsensitive())), + targetAlias = varDecl("f"), + assignments = listOf( + setAssignment( + setTarget = simplePath( + root = identifier("a", caseInsensitive()), + steps = emptyList() + ), + value = lit(ionInt(1)), + ), + ), + where = lit(ionBool(true)) + ) + } + ), + // with multiple set assignments. + TestCase( + "UPDATE foo SET a = 1, b.c = 2, d.e.f = 3 WHERE 42", + PartiqlLogical.build { + dmlUpdate( + target = dmlTarget(identifier("foo", caseInsensitive())), + targetAlias = varDecl("foo"), + assignments = listOf( + setAssignment( + setTarget = simplePath( + root = identifier("a", caseInsensitive()), + steps = emptyList() + ), + value = lit(ionInt(1)), + ), + setAssignment( + setTarget = simplePath( + root = identifier("b", caseInsensitive()), + steps = listOf( + spsIdentifier(identifier("c", caseInsensitive())) + ) + ), + value = lit(ionInt(2)), + ), + setAssignment( + setTarget = simplePath( + root = identifier("d", caseInsensitive()), + steps = listOf( + spsIdentifier(identifier("e", caseInsensitive())), + spsIdentifier(identifier("f", caseInsensitive())) + ) + ), + value = lit(ionInt(3)), + ), + ), + where = lit(ionInt(42)) + ) + } + ), + ) + } + data class ProblemTestCase(val id: Int, val sql: String, val expectedProblem: Problem) @ParameterizedTest @@ -1117,8 +1330,6 @@ class AstToLogicalVisitorTransformTests { // Unimplemented parts of DML ProblemTestCase(200, "FROM x AS xx INSERT INTO foo VALUES (1, 2)", unimplementedProblem("UPDATE / INSERT", 1, 14, 6)), - ProblemTestCase(201, "FROM x AS xx SET k = 5", unimplementedProblem("SET", 1, 14, 3)), - ProblemTestCase(202, "UPDATE x SET k = 5", unimplementedProblem("SET", 1, 10, 3)), ProblemTestCase(203, "UPDATE x REMOVE k", unimplementedProblem("REMOVE", 1, 10, 6)), ProblemTestCase(204, "UPDATE x INSERT INTO k << 1 >>", unimplementedProblem("UPDATE / INSERT", 1, 10, 6)), diff --git a/partiql-lang/src/test/kotlin/org/partiql/lang/planner/transforms/LogicalResolvedToDefaultPartiQLPhysicalVisitorTransformTestsPass.kt b/partiql-lang/src/test/kotlin/org/partiql/lang/planner/transforms/LogicalResolvedToDefaultPartiQLPhysicalVisitorTransformTestsPass.kt index 8c27ed2e64..558da54dc1 100644 --- a/partiql-lang/src/test/kotlin/org/partiql/lang/planner/transforms/LogicalResolvedToDefaultPartiQLPhysicalVisitorTransformTestsPass.kt +++ b/partiql-lang/src/test/kotlin/org/partiql/lang/planner/transforms/LogicalResolvedToDefaultPartiQLPhysicalVisitorTransformTestsPass.kt @@ -12,6 +12,7 @@ import org.partiql.lang.domains.PartiqlLogicalResolved import org.partiql.lang.domains.PartiqlPhysical import org.partiql.lang.errors.ProblemCollector import org.partiql.lang.util.ArgumentsProviderBase +import org.partiql.pig.runtime.SymbolPrimitive import kotlin.test.fail class LogicalResolvedToDefaultPartiQLPhysicalVisitorTransformTestsPass { @@ -224,172 +225,177 @@ class LogicalResolvedToDefaultPartiQLPhysicalVisitorTransformTestsPass { } class ArgumentsForToDMLTests : ArgumentsProviderBase() { - override fun getParameters() = listOf( - DmlTestCase( - // INSERT INTO foo VALUE 1 - PartiqlLogicalResolved.build { - dml( - uniqueId = "foo", - operation = dmlInsert(varDecl(0)), - rows = bag(lit(ionInt(1))) - ) - }, - PartiqlPhysical.build { - dml( - uniqueId = "foo", - operation = dmlInsert(varDecl(0)), - rows = bag(lit(ionInt(1))) - ) - } - ), - DmlTestCase( - // INSERT INTO foo SELECT x.* FROM 1 AS x - PartiqlLogicalResolved.build { - dml( - uniqueId = "foo", - operation = dmlInsert(varDecl(0)), - rows = bindingsToValues( - struct(structFields(localId(0))), - scan(lit(ionInt(1)), varDecl(0)) + override fun getParameters(): List { + val listOf = listOf( + DmlTestCase( + // INSERT INTO foo VALUE 1 + PartiqlLogicalResolved.build { + dmlInsert( + target = dmlTarget_(SymbolPrimitive("foo", emptyMap())), + targetAlias = varDecl(0), + rowsToInsert = bag(lit(ionInt(1))), ) - ) - }, - PartiqlPhysical.build { - dml( - uniqueId = "foo", - operation = dmlInsert(varDecl(0)), - rows = bindingsToValues( - struct(structFields(localId(0))), - scan(DEFAULT_IMPL, lit(ionInt(1)), varDecl(0)) + }, + PartiqlPhysical.build { + dmlInsert( + target = dmlTarget_(SymbolPrimitive("foo", emptyMap())), + targetAlias = varDecl(0), + rowsToInsert = bag(lit(ionInt(1))), ) - ) - } - ), - DmlTestCase( - // INSERT INTO foo [AS f] SELECT x.* FROM 1 AS x ON CONFLICT DO REPLACE EXCLUDED - PartiqlLogicalResolved.build { - dml( - uniqueId = "foo", - operation = dmlReplace(varDecl(0), rowAlias = varDecl(1)), - rows = bindingsToValues( - struct(structFields(localId(0))), - scan(lit(ionInt(1)), varDecl(0)) + } + ), + DmlTestCase( + // INSERT INTO foo SELECT x.* FROM 1 AS x + PartiqlLogicalResolved.build { + dmlInsert( + target = dmlTarget_(SymbolPrimitive("foo", emptyMap())), + targetAlias = varDecl(0), + rowsToInsert = bindingsToValues( + struct(structFields(localId(0))), + scan(lit(ionInt(1)), varDecl(0)) + ) ) - ) - }, - PartiqlPhysical.build { - dml( - uniqueId = "foo", - operation = dmlReplace(varDecl(0), rowAlias = varDecl(1)), - rows = bindingsToValues( - struct(structFields(localId(0))), - scan(DEFAULT_IMPL, lit(ionInt(1)), varDecl(0)) + }, + PartiqlPhysical.build { + dmlInsert( + target = dmlTarget_(SymbolPrimitive("foo", emptyMap())), + targetAlias = varDecl(0), + rowsToInsert = bindingsToValues( + struct(structFields(localId(0))), + scan(DEFAULT_IMPL, lit(ionInt(1)), varDecl(0)) + ) ) - ) - } - ), - DmlTestCase( - // INSERT INTO foo [AS f] SELECT x.* FROM 1 AS x ON CONFLICT DO REPLACE EXCLUDED WHERE foo.id > 1 - PartiqlLogicalResolved.build { - dml( - uniqueId = "foo", - operation = dmlReplace( + } + ), + DmlTestCase( + // INSERT INTO foo [AS f] SELECT x.* FROM 1 AS x ON CONFLICT DO REPLACE EXCLUDED + PartiqlLogicalResolved.build { + dmlInsert( + target = dmlTarget_(SymbolPrimitive("foo", emptyMap())), targetAlias = varDecl(0), - condition = gt( - listOf( - path( - localId(0), - listOf(pathExpr(lit(ionString("id")), caseInsensitive())) - ), - lit(ionInt(1)) - ) + rowsToInsert = bindingsToValues( + struct(structFields(localId(0))), + scan(lit(ionInt(1)), varDecl(0)) ), - rowAlias = varDecl(1) - ), - rows = bindingsToValues( - struct(structFields(localId(0))), - scan(lit(ionInt(1)), varDecl(0)) + onConflict = onConflict( + excludedAlias = varDecl(1), + action = doReplace() + ) ) - ) - }, - PartiqlPhysical.build { - dml( - uniqueId = "foo", - operation = dmlReplace( + }, + PartiqlPhysical.build { + dmlInsert( + target = dmlTarget_(SymbolPrimitive("foo", emptyMap())), targetAlias = varDecl(0), - condition = gt( - listOf( - path( - localId(0), - listOf(pathExpr(lit(ionString("id")), caseInsensitive())) - ), - lit(ionInt(1)) - ) + rowsToInsert = bindingsToValues( + struct(structFields(localId(0))), + scan(DEFAULT_IMPL, lit(ionInt(1)), varDecl(0)) ), - rowAlias = varDecl(1) - ), - rows = bindingsToValues( - struct(structFields(localId(0))), - scan(DEFAULT_IMPL, lit(ionInt(1)), varDecl(0)) + onConflict = onConflict( + excludedAlias = varDecl(1), + action = doReplace() + ) ) - ) - } - ), - DmlTestCase( - // DELETE FROM y AS y - PartiqlLogicalResolved.build { - dml( - uniqueId = "foo", - operation = dmlDelete(), - rows = bindingsToValues( - localId(0), - scan(globalId("y"), varDecl(0)) + } + ), + DmlTestCase( + // INSERT INTO foo [AS f] SELECT x.* FROM 1 AS x ON CONFLICT DO REPLACE EXCLUDED WHERE foo.id > 1 + PartiqlLogicalResolved.build { + dmlInsert( + target = dmlTarget_(SymbolPrimitive("foo", emptyMap())), + targetAlias = varDecl(0), + rowsToInsert = bindingsToValues( + struct(structFields(localId(0))), + scan(lit(ionInt(1)), varDecl(0)) + ), + onConflict = onConflict( + excludedAlias = varDecl(1), + action = doReplace(), + condition = gt( + listOf( + path( + localId(0), + listOf(pathExpr(lit(ionString("id")), caseInsensitive())) + ), + lit(ionInt(1)) + ) + ), + ) ) - ) - }, - PartiqlPhysical.build { - dml( - uniqueId = "foo", - operation = dmlDelete(), - rows = bindingsToValues( - localId(0), - scan(DEFAULT_IMPL, globalId("y"), varDecl(0)) + }, + PartiqlPhysical.build { + dmlInsert( + target = dmlTarget_(SymbolPrimitive("foo", emptyMap())), + targetAlias = varDecl(0), + rowsToInsert = bindingsToValues( + struct(structFields(localId(0))), + scan(DEFAULT_IMPL, lit(ionInt(1)), varDecl(0)) + ), + onConflict = onConflict( + excludedAlias = varDecl(1), + action = doReplace(), + condition = gt( + listOf( + path( + localId(0), + listOf(pathExpr(lit(ionString("id")), caseInsensitive())) + ), + lit(ionInt(1)) + ) + ), + ) ) - ) - } - ), - DmlTestCase( - // DELETE FROM y AS y WHERE 1=1 - PartiqlLogicalResolved.build { - dml( - uniqueId = "y", - operation = dmlDelete(), - rows = bindingsToValues( - localId(0), - // this logical plan is same as previous but includes this filter - filter( + } + ), + DmlTestCase( + // DELETE FROM y AS y + PartiqlLogicalResolved.build { + dmlDelete( + from = scan( + globalId("y"), + varDecl(0) + ), + ) + }, + PartiqlPhysical.build { + dmlDelete( + from = scan( + DEFAULT_IMPL, + globalId("y"), + varDecl(0) + ), + ) + } + ), + DmlTestCase( + // DELETE FROM y AS y WHERE 1=1 + PartiqlLogicalResolved.build { + dmlDelete( + from = filter( eq(lit(ionInt(1)), lit(ionInt(1))), - scan(globalId("y"), varDecl(0)) - ) + scan( + globalId("y"), + varDecl(0) + ) + ), ) - ) - }, - PartiqlPhysical.build { - dml( - uniqueId = "y", - operation = dmlDelete(), - rows = bindingsToValues( - localId(0), - // this logical plan is same as previous but includes this filter - filter( + }, + PartiqlPhysical.build { + dmlDelete( + from = filter( DEFAULT_IMPL, eq(lit(ionInt(1)), lit(ionInt(1))), - scan(DEFAULT_IMPL, globalId("y"), varDecl(0)) - ) + scan( + DEFAULT_IMPL, + globalId("y"), + varDecl(0) + ) + ), ) - ) - } - ), - ) + } + ), + ) + return listOf + } } } diff --git a/partiql-lang/src/test/kotlin/org/partiql/lang/planner/transforms/LogicalToLogicalResolvedVisitorTransformTests.kt b/partiql-lang/src/test/kotlin/org/partiql/lang/planner/transforms/LogicalToLogicalResolvedVisitorTransformTests.kt index 1bd57da69c..73ac974fcf 100644 --- a/partiql-lang/src/test/kotlin/org/partiql/lang/planner/transforms/LogicalToLogicalResolvedVisitorTransformTests.kt +++ b/partiql-lang/src/test/kotlin/org/partiql/lang/planner/transforms/LogicalToLogicalResolvedVisitorTransformTests.kt @@ -873,9 +873,9 @@ class LogicalToLogicalResolvedVisitorTransformTests { } @ParameterizedTest - @ArgumentsSource(DmlStatements::class) - fun `dml statements`(tc: TestCase) = runTestCase(tc) - class DmlStatements : ArgumentsProviderBase() { + @ArgumentsSource(InsertStatements::class) + fun `insert statements`(tc: TestCase) = runTestCase(tc) + class InsertStatements : ArgumentsProviderBase() { val EXCLUDED = AstToLogicalVisitorTransform.EXCLUDED override fun getParameters() = listOf( TestCase( @@ -958,4 +958,26 @@ class LogicalToLogicalResolvedVisitorTransformTests { ), ) } + + @ParameterizedTest + @ArgumentsSource(UpdateStatements::class) + fun `update statements`(tc: TestCase) = runTestCase(tc) + class UpdateStatements : ArgumentsProviderBase() { + override fun getParameters() = listOf( + TestCase( + """ + UPDATE foo AS f SET + bar = f.bat || f.baz, + bor['bop'] = 'biz' + WHERE + f.bork = 42 + """, + Expectation.Success( + ResolvedId(3, 31) { localId(0) }, + ResolvedId(3, 40) { localId(0) }, + ResolvedId(6, 25) { localId(0) }, + ).withLocals(localVariable("f", 0)) + ), + ) + } }