diff --git a/src/from_substrait.cpp b/src/from_substrait.cpp index 761724e..4b0e141 100644 --- a/src/from_substrait.cpp +++ b/src/from_substrait.cpp @@ -26,6 +26,7 @@ #include "substrait/plan.pb.h" #include "duckdb/main/relation/create_table_relation.hpp" +#include #include "duckdb/main/relation/table_relation.hpp" namespace duckdb { @@ -623,22 +624,25 @@ shared_ptr SubstraitToDuckDB::TransformSetOp(const substrait::Rel &sop shared_ptr SubstraitToDuckDB::TransformWriteOp(const substrait::Rel &sop) { auto &swrite = sop.write(); + auto &nobj = swrite.named_table(); + if (nobj.names_size() == 0) { + throw InvalidInputException("Named object must have at least one name"); + } + auto table_idx = nobj.names_size() - 1; + auto table_name = nobj.names(table_idx); + string schema_name; + if (table_idx > 0) { + schema_name = nobj.names(0); + } + auto input = TransformOp(swrite.input()); switch (swrite.op()) { - case substrait::WriteRel::WriteOp::WriteRel_WriteOp_WRITE_OP_CTAS: { - auto &nobj = swrite.named_table(); - if (nobj.names_size() == 0) { - throw InvalidInputException("Named object must have at least one name"); - } - auto table_idx = nobj.names_size() - 1; - auto table_name = nobj.names(table_idx); - string schema_name; - if (table_idx > 0) { - schema_name = nobj.names(0); - } - - auto input = TransformOp(swrite.input()); + case substrait::WriteRel::WriteOp::WriteRel_WriteOp_WRITE_OP_CTAS: return input->CreateRel(schema_name, table_name); + case substrait::WriteRel::WriteOp::WriteRel_WriteOp_WRITE_OP_DELETE: { + auto filter = std::move(input.get()->Cast()); + auto context = filter.child->Cast().context; + return make_shared_ptr(filter.context, std::move(filter.condition), schema_name, table_name); } default: throw NotImplementedException("Unsupported write operation" + to_string(swrite.op())); diff --git a/src/include/to_substrait.hpp b/src/include/to_substrait.hpp index 08b51f3..e3cde98 100644 --- a/src/include/to_substrait.hpp +++ b/src/include/to_substrait.hpp @@ -39,6 +39,8 @@ class DuckDBToSubstrait { static vector DepthFirstNames(const LogicalType &type); static void DepthFirstNamesRecurse(vector &names, const LogicalType &type); static substrait::Expression_Literal ToExpressionLiteral(const substrait::Expression &expr); + static void SetTableSchema(const TableCatalogEntry &table, substrait::NamedStruct *schema); + static void SetNamedTable(const TableCatalogEntry &table, substrait::WriteRel *writeRel); //! Transforms Relation Root substrait::RelRoot *TransformRootOp(LogicalOperator &dop); @@ -59,6 +61,7 @@ class DuckDBToSubstrait { substrait::Rel *TransformExcept(LogicalOperator &dop); substrait::Rel *TransformIntersect(LogicalOperator &dop); substrait::Rel *TransformCreateTable(LogicalOperator &dop); + substrait::Rel *TransformDeleteTable(LogicalOperator &dop); static substrait::Rel *TransformDummyScan(); //! Methods to transform different LogicalGet Types (e.g., Table, Parquet) //! To Substrait; diff --git a/src/to_substrait.cpp b/src/to_substrait.cpp index 125a540..e0e465f 100644 --- a/src/to_substrait.cpp +++ b/src/to_substrait.cpp @@ -1453,6 +1453,51 @@ substrait::Rel *DuckDBToSubstrait::TransformCreateTable(LogicalOperator &dop) { return rel; } +void DuckDBToSubstrait::SetTableSchema(const TableCatalogEntry &table, substrait::NamedStruct *schema) { + for (auto &name : table.GetColumns().GetColumnNames()) { + schema->add_names(name); + } + auto type_info = new substrait::Type_Struct(); + type_info->set_nullability(substrait::Type_Nullability_NULLABILITY_REQUIRED); + for (auto &col_type : table.GetColumns().GetColumnTypes()) { + auto s_type = DuckToSubstraitType(col_type, nullptr, false); + *type_info->add_types() = s_type; + } + schema->set_allocated_struct_(type_info); +} + +void DuckDBToSubstrait::SetNamedTable(const TableCatalogEntry &table, substrait::WriteRel *writeRel) { + auto named_table = writeRel->mutable_named_table(); + named_table->add_names(table.schema.name); + named_table->add_names(table.name); +} + +substrait::Rel *DuckDBToSubstrait::TransformDeleteTable(LogicalOperator &dop) { + auto rel = new substrait::Rel(); + auto &logical_delete = dop.Cast(); + auto &table = logical_delete.table; + if (logical_delete.children.size() != 1) { + throw InternalException("Delete table expected one child, found " + to_string(logical_delete.children.size())); + } + + auto writeRel = rel->mutable_write(); + writeRel->set_op(substrait::WriteRel::WriteOp::WriteRel_WriteOp_WRITE_OP_DELETE); + writeRel->set_output(substrait::WriteRel::OUTPUT_MODE_NO_OUTPUT); + + auto named_table = writeRel->mutable_named_table(); + named_table->add_names(table.schema.name); + named_table->add_names(table.name); + + SetNamedTable(logical_delete.table, writeRel); + auto schema = new substrait::NamedStruct(); + SetTableSchema(logical_delete.table, schema); + writeRel->set_allocated_table_schema(schema); + + substrait::Rel *input = TransformOp(*logical_delete.children[0]); + writeRel->set_allocated_input(input); + return rel; +} + substrait::Rel *DuckDBToSubstrait::TransformOp(LogicalOperator &dop) { switch (dop.type) { case LogicalOperatorType::LOGICAL_FILTER: @@ -1485,6 +1530,8 @@ substrait::Rel *DuckDBToSubstrait::TransformOp(LogicalOperator &dop) { return TransformDummyScan(); case LogicalOperatorType::LOGICAL_CREATE_TABLE: return TransformCreateTable(dop); + case LogicalOperatorType::LOGICAL_DELETE: + return TransformDeleteTable(dop); default: throw NotImplementedException(LogicalOperatorToString(dop.type)); } @@ -1495,8 +1542,23 @@ static bool IsSetOperation(const LogicalOperator &op) { op.type == LogicalOperatorType::LOGICAL_INTERSECT; } +static bool IsRowModificationOperator(const LogicalOperator &op) { + switch (op.type) { + case LogicalOperatorType::LOGICAL_INSERT: + case LogicalOperatorType::LOGICAL_DELETE: + case LogicalOperatorType::LOGICAL_UPDATE: + return true; + default: + return false; + } +} + substrait::RelRoot *DuckDBToSubstrait::TransformRootOp(LogicalOperator &dop) { auto root_rel = new substrait::RelRoot(); + if (IsRowModificationOperator(dop)) { + root_rel->set_allocated_input(TransformOp(dop)); + return root_rel; + } LogicalOperator *current_op = &dop; bool weird_scenario = current_op->type == LogicalOperatorType::LOGICAL_PROJECTION && current_op->children[0]->type == LogicalOperatorType::LOGICAL_TOP_N; diff --git a/test/c/test_substrait_c_api.cpp b/test/c/test_substrait_c_api.cpp index 6999c03..6e1a723 100644 --- a/test/c/test_substrait_c_api.cpp +++ b/test/c/test_substrait_c_api.cpp @@ -257,4 +257,18 @@ TEST_CASE("Test C CTAS Union with Substrait API", "[substrait-api]") { REQUIRE(CHECK_COLUMN(result, 0, {1, 2, 3, 4, 5, 6, 7})); REQUIRE(CHECK_COLUMN(result, 1, {"John Doe", "Jane Smith", "Alice Johnson", "Bob Brown", "Charlie Black", "David White", "Eve Green"})); REQUIRE(CHECK_COLUMN(result, 2, {1, 2, 1, 3, 2, 1, 2})); -} \ No newline at end of file +} + +TEST_CASE("Test C DeleteRows with Substrait API", "[substrait-api]") { + DuckDB db(nullptr); + Connection con(db); + + CreateEmployeeTable(con); + + ExecuteViaSubstraitJSON(con, "DELETE FROM employees WHERE salary < 80000"); + auto result = ExecuteViaSubstrait(con, "SELECT * from employees"); + REQUIRE(CHECK_COLUMN(result, 0, {1, 2, 4})); + REQUIRE(CHECK_COLUMN(result, 1, {"John Doe", "Jane Smith", "Bob Brown"})); + REQUIRE(CHECK_COLUMN(result, 2, {1, 2, 3})); + REQUIRE(CHECK_COLUMN(result, 3, {120000, 80000, 95000})); +} diff --git a/test/python/test_substrait.py b/test/python/test_substrait.py index 38590fa..91dad21 100644 --- a/test/python/test_substrait.py +++ b/test/python/test_substrait.py @@ -152,6 +152,18 @@ def test_ctas_with_union(require): pd.testing.assert_frame_equal(query_result.df(), expected) +def test_delete_rows_in_table(require): + connection = require('substrait') + create_employee_table(connection) + connection.execute("DELETE FROM employees WHERE salary < 80000") + query_result = execute_via_substrait(connection, "SELECT * FROM employees") + expected = pd.DataFrame({"employee_id": pd.Series([1, 2, 4], dtype="int32"), + "name": ["John Doe", "Jane Smith", "Bob Brown"], + "department_id": pd.Series([1, 2, 3], dtype="int32"), + "salary": pd.Series([120000, 80000, 95000], dtype="float64")}) + pd.testing.assert_frame_equal(query_result.df(), expected) + + def execute_via_substrait(connection, query): res = connection.get_substrait(query) proto_bytes = res.fetchone()[0]