From 87195e997f6b2e8d94dedbff4ddcf1d1cdd0a38e Mon Sep 17 00:00:00 2001 From: Chandra Sanapala Date: Fri, 27 Sep 2024 14:27:40 +0530 Subject: [PATCH] Support insert rows in a table Co-authored-by: Anshul Data --- src/from_substrait.cpp | 2 ++ src/include/to_substrait.hpp | 1 + src/to_substrait.cpp | 23 +++++++++++++++++++++++ test/c/test_substrait_c_api.cpp | 21 +++++++++++++++++++++ test/python/test_substrait.py | 23 +++++++++++++++++++++++ 5 files changed, 70 insertions(+) diff --git a/src/from_substrait.cpp b/src/from_substrait.cpp index 4b0e141..5cee1c3 100644 --- a/src/from_substrait.cpp +++ b/src/from_substrait.cpp @@ -639,6 +639,8 @@ shared_ptr SubstraitToDuckDB::TransformWriteOp(const substrait::Rel &s switch (swrite.op()) { case substrait::WriteRel::WriteOp::WriteRel_WriteOp_WRITE_OP_CTAS: return input->CreateRel(schema_name, table_name); + case substrait::WriteRel::WriteOp::WriteRel_WriteOp_WRITE_OP_INSERT: + return input->InsertRel(schema_name, table_name); case substrait::WriteRel::WriteOp::WriteRel_WriteOp_WRITE_OP_DELETE: { auto filter = std::move(input.get()->Cast()); auto context = filter.child->Cast().context; diff --git a/src/include/to_substrait.hpp b/src/include/to_substrait.hpp index e3cde98..9d3ffc9 100644 --- a/src/include/to_substrait.hpp +++ b/src/include/to_substrait.hpp @@ -61,6 +61,7 @@ class DuckDBToSubstrait { substrait::Rel *TransformExcept(LogicalOperator &dop); substrait::Rel *TransformIntersect(LogicalOperator &dop); substrait::Rel *TransformCreateTable(LogicalOperator &dop); + substrait::Rel *TransformInsertTable(LogicalOperator &dop); substrait::Rel *TransformDeleteTable(LogicalOperator &dop); static substrait::Rel *TransformDummyScan(); //! Methods to transform different LogicalGet Types (e.g., Table, Parquet) diff --git a/src/to_substrait.cpp b/src/to_substrait.cpp index e0e465f..cf3f2c8 100644 --- a/src/to_substrait.cpp +++ b/src/to_substrait.cpp @@ -1472,6 +1472,27 @@ void DuckDBToSubstrait::SetNamedTable(const TableCatalogEntry &table, substrait: named_table->add_names(table.name); } +substrait::Rel *DuckDBToSubstrait::TransformInsertTable(LogicalOperator &dop) { + auto rel = new substrait::Rel(); + auto &insert_table = dop.Cast(); + if (insert_table.children.size() != 1) { + throw InternalException("insert table expected one child, found " + to_string(insert_table.children.size())); + } + + auto writeRel = rel->mutable_write(); + writeRel->set_op(substrait::WriteRel::WriteOp::WriteRel_WriteOp_WRITE_OP_INSERT); + writeRel->set_output(substrait::WriteRel::OUTPUT_MODE_NO_OUTPUT); + + SetNamedTable(insert_table.table, writeRel); + auto schema = new substrait::NamedStruct(); + SetTableSchema(insert_table.table, schema); + writeRel->set_allocated_table_schema(schema); + + substrait::Rel *input = TransformOp(*insert_table.children[0]); + writeRel->set_allocated_input(input); + return rel; +} + substrait::Rel *DuckDBToSubstrait::TransformDeleteTable(LogicalOperator &dop) { auto rel = new substrait::Rel(); auto &logical_delete = dop.Cast(); @@ -1530,6 +1551,8 @@ substrait::Rel *DuckDBToSubstrait::TransformOp(LogicalOperator &dop) { return TransformDummyScan(); case LogicalOperatorType::LOGICAL_CREATE_TABLE: return TransformCreateTable(dop); + case LogicalOperatorType::LOGICAL_INSERT: + return TransformInsertTable(dop); case LogicalOperatorType::LOGICAL_DELETE: return TransformDeleteTable(dop); default: diff --git a/test/c/test_substrait_c_api.cpp b/test/c/test_substrait_c_api.cpp index 6e1a723..e7a7c36 100644 --- a/test/c/test_substrait_c_api.cpp +++ b/test/c/test_substrait_c_api.cpp @@ -259,6 +259,27 @@ TEST_CASE("Test C CTAS Union with Substrait API", "[substrait-api]") { REQUIRE(CHECK_COLUMN(result, 2, {1, 2, 1, 3, 2, 1, 2})); } +TEST_CASE("Test C InsertRows with Substrait API", "[substrait-api]") { + DuckDB db(nullptr); + Connection con(db); + + CreateEmployeeTable(con); + REQUIRE_NO_FAIL(con.Query("CREATE TABLE senior_employees (" + "employee_id INTEGER PRIMARY KEY, " + "name VARCHAR(100), " + "department_id INTEGER, " + "salary DECIMAL(10, 2))")); + + ExecuteViaSubstrait(con, "INSERT INTO senior_employees " + "SELECT * FROM employees WHERE salary > 80000"); + + auto result = ExecuteViaSubstrait(con, "SELECT * from senior_employees"); + REQUIRE(CHECK_COLUMN(result, 0, {1, 4})); + REQUIRE(CHECK_COLUMN(result, 1, {"John Doe", "Bob Brown"})); + REQUIRE(CHECK_COLUMN(result, 2, {1, 3})); + REQUIRE(CHECK_COLUMN(result, 3, {120000, 95000})); +} + TEST_CASE("Test C DeleteRows with Substrait API", "[substrait-api]") { DuckDB db(nullptr); Connection con(db); diff --git a/test/python/test_substrait.py b/test/python/test_substrait.py index 91dad21..1b446e6 100644 --- a/test/python/test_substrait.py +++ b/test/python/test_substrait.py @@ -152,6 +152,20 @@ def test_ctas_with_union(require): pd.testing.assert_frame_equal(query_result.df(), expected) +def test_insert_rows_into_table(require): + connection = require('substrait') + create_employee_table(connection) + create_senior_employees_table(connection) + _ = execute_via_substrait(connection, "INSERT INTO senior_employees SELECT * FROM employees WHERE salary > 80000") + + expected = pd.DataFrame({"employee_id": pd.Series([1, 4], dtype="int32"), + "name": ["John Doe", "Bob Brown"], + "department_id": pd.Series([1, 3], dtype="int32"), + "salary": pd.Series([120000, 95000], dtype="float64")}) + query_result = execute_via_substrait(connection, "SELECT * FROM senior_employees") + pd.testing.assert_frame_equal(query_result.df(), expected) + + def test_delete_rows_in_table(require): connection = require('substrait') create_employee_table(connection) @@ -208,6 +222,15 @@ def create_part_time_employee_table(connection): (7, 'Eve Green', 2, 20) """) +def create_senior_employees_table(connection): + connection.execute(""" + CREATE TABLE senior_employees ( + employee_id INTEGER PRIMARY KEY, + name VARCHAR(100), + department_id INTEGER, + salary DECIMAL(10, 2) + ) + """) def create_departments_table(connection): connection.execute("""