Skip to content
This repository has been archived by the owner on Nov 21, 2024. It is now read-only.

Support insert rows in a table #9

Merged
merged 1 commit into from
Oct 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/from_substrait.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -639,6 +639,8 @@ shared_ptr<Relation> SubstraitToDuckDB::TransformWriteOp(const substrait::Rel &s
switch (swrite.op()) {
case substrait::WriteRel::WriteOp::WriteRel_WriteOp_WRITE_OP_CTAS:
return input->CreateRel(schema_name, table_name);
case substrait::WriteRel::WriteOp::WriteRel_WriteOp_WRITE_OP_INSERT:
return input->InsertRel(schema_name, table_name);
case substrait::WriteRel::WriteOp::WriteRel_WriteOp_WRITE_OP_DELETE: {
auto filter = std::move(input.get()->Cast<FilterRelation>());
auto context = filter.child->Cast<TableRelation>().context;
Expand Down
1 change: 1 addition & 0 deletions src/include/to_substrait.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ class DuckDBToSubstrait {
substrait::Rel *TransformExcept(LogicalOperator &dop);
substrait::Rel *TransformIntersect(LogicalOperator &dop);
substrait::Rel *TransformCreateTable(LogicalOperator &dop);
substrait::Rel *TransformInsertTable(LogicalOperator &dop);
substrait::Rel *TransformDeleteTable(LogicalOperator &dop);
static substrait::Rel *TransformDummyScan();
//! Methods to transform different LogicalGet Types (e.g., Table, Parquet)
Expand Down
23 changes: 23 additions & 0 deletions src/to_substrait.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1472,6 +1472,27 @@ void DuckDBToSubstrait::SetNamedTable(const TableCatalogEntry &table, substrait:
named_table->add_names(table.name);
}

substrait::Rel *DuckDBToSubstrait::TransformInsertTable(LogicalOperator &dop) {
auto rel = new substrait::Rel();
auto &insert_table = dop.Cast<LogicalInsert>();
if (insert_table.children.size() != 1) {
throw InternalException("insert table expected one child, found " + to_string(insert_table.children.size()));
}

auto writeRel = rel->mutable_write();
writeRel->set_op(substrait::WriteRel::WriteOp::WriteRel_WriteOp_WRITE_OP_INSERT);
writeRel->set_output(substrait::WriteRel::OUTPUT_MODE_NO_OUTPUT);

SetNamedTable(insert_table.table, writeRel);
auto schema = new substrait::NamedStruct();
SetTableSchema(insert_table.table, schema);
writeRel->set_allocated_table_schema(schema);

substrait::Rel *input = TransformOp(*insert_table.children[0]);
writeRel->set_allocated_input(input);
return rel;
}

substrait::Rel *DuckDBToSubstrait::TransformDeleteTable(LogicalOperator &dop) {
auto rel = new substrait::Rel();
auto &logical_delete = dop.Cast<LogicalDelete>();
Expand Down Expand Up @@ -1530,6 +1551,8 @@ substrait::Rel *DuckDBToSubstrait::TransformOp(LogicalOperator &dop) {
return TransformDummyScan();
case LogicalOperatorType::LOGICAL_CREATE_TABLE:
return TransformCreateTable(dop);
case LogicalOperatorType::LOGICAL_INSERT:
return TransformInsertTable(dop);
case LogicalOperatorType::LOGICAL_DELETE:
return TransformDeleteTable(dop);
default:
Expand Down
21 changes: 21 additions & 0 deletions test/c/test_substrait_c_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,27 @@ TEST_CASE("Test C CTAS Union with Substrait API", "[substrait-api]") {
REQUIRE(CHECK_COLUMN(result, 2, {1, 2, 1, 3, 2, 1, 2}));
}

TEST_CASE("Test C InsertRows with Substrait API", "[substrait-api]") {
DuckDB db(nullptr);
Connection con(db);

CreateEmployeeTable(con);
REQUIRE_NO_FAIL(con.Query("CREATE TABLE senior_employees ("
"employee_id INTEGER PRIMARY KEY, "
"name VARCHAR(100), "
"department_id INTEGER, "
"salary DECIMAL(10, 2))"));

ExecuteViaSubstrait(con, "INSERT INTO senior_employees "
"SELECT * FROM employees WHERE salary > 80000");

auto result = ExecuteViaSubstrait(con, "SELECT * from senior_employees");
REQUIRE(CHECK_COLUMN(result, 0, {1, 4}));
REQUIRE(CHECK_COLUMN(result, 1, {"John Doe", "Bob Brown"}));
REQUIRE(CHECK_COLUMN(result, 2, {1, 3}));
REQUIRE(CHECK_COLUMN(result, 3, {120000, 95000}));
}

TEST_CASE("Test C DeleteRows with Substrait API", "[substrait-api]") {
DuckDB db(nullptr);
Connection con(db);
Expand Down
23 changes: 23 additions & 0 deletions test/python/test_substrait.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,20 @@ def test_ctas_with_union(require):
pd.testing.assert_frame_equal(query_result.df(), expected)


def test_insert_rows_into_table(require):
connection = require('substrait')
create_employee_table(connection)
create_senior_employees_table(connection)
_ = execute_via_substrait(connection, "INSERT INTO senior_employees SELECT * FROM employees WHERE salary > 80000")

expected = pd.DataFrame({"employee_id": pd.Series([1, 4], dtype="int32"),
"name": ["John Doe", "Bob Brown"],
"department_id": pd.Series([1, 3], dtype="int32"),
"salary": pd.Series([120000, 95000], dtype="float64")})
query_result = execute_via_substrait(connection, "SELECT * FROM senior_employees")
pd.testing.assert_frame_equal(query_result.df(), expected)


def test_delete_rows_in_table(require):
connection = require('substrait')
create_employee_table(connection)
Expand Down Expand Up @@ -208,6 +222,15 @@ def create_part_time_employee_table(connection):
(7, 'Eve Green', 2, 20)
""")

def create_senior_employees_table(connection):
connection.execute("""
CREATE TABLE senior_employees (
employee_id INTEGER PRIMARY KEY,
name VARCHAR(100),
department_id INTEGER,
salary DECIMAL(10, 2)
)
""")

def create_departments_table(connection):
connection.execute("""
Expand Down
Loading