Skip to content
This repository has been archived by the owner on Nov 21, 2024. It is now read-only.

Commit

Permalink
Support insert rows in a table
Browse files Browse the repository at this point in the history
Co-authored-by: Anshul Data <[email protected]>
  • Loading branch information
scgkiran committed Oct 2, 2024
1 parent aa4c846 commit 87195e9
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/from_substrait.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -639,6 +639,8 @@ shared_ptr<Relation> SubstraitToDuckDB::TransformWriteOp(const substrait::Rel &s
switch (swrite.op()) {
case substrait::WriteRel::WriteOp::WriteRel_WriteOp_WRITE_OP_CTAS:
return input->CreateRel(schema_name, table_name);
case substrait::WriteRel::WriteOp::WriteRel_WriteOp_WRITE_OP_INSERT:
return input->InsertRel(schema_name, table_name);
case substrait::WriteRel::WriteOp::WriteRel_WriteOp_WRITE_OP_DELETE: {
auto filter = std::move(input.get()->Cast<FilterRelation>());
auto context = filter.child->Cast<TableRelation>().context;
Expand Down
1 change: 1 addition & 0 deletions src/include/to_substrait.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ class DuckDBToSubstrait {
substrait::Rel *TransformExcept(LogicalOperator &dop);
substrait::Rel *TransformIntersect(LogicalOperator &dop);
substrait::Rel *TransformCreateTable(LogicalOperator &dop);
substrait::Rel *TransformInsertTable(LogicalOperator &dop);
substrait::Rel *TransformDeleteTable(LogicalOperator &dop);
static substrait::Rel *TransformDummyScan();
//! Methods to transform different LogicalGet Types (e.g., Table, Parquet)
Expand Down
23 changes: 23 additions & 0 deletions src/to_substrait.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1472,6 +1472,27 @@ void DuckDBToSubstrait::SetNamedTable(const TableCatalogEntry &table, substrait:
named_table->add_names(table.name);
}

substrait::Rel *DuckDBToSubstrait::TransformInsertTable(LogicalOperator &dop) {
auto rel = new substrait::Rel();
auto &insert_table = dop.Cast<LogicalInsert>();
if (insert_table.children.size() != 1) {
throw InternalException("insert table expected one child, found " + to_string(insert_table.children.size()));
}

auto writeRel = rel->mutable_write();
writeRel->set_op(substrait::WriteRel::WriteOp::WriteRel_WriteOp_WRITE_OP_INSERT);
writeRel->set_output(substrait::WriteRel::OUTPUT_MODE_NO_OUTPUT);

SetNamedTable(insert_table.table, writeRel);
auto schema = new substrait::NamedStruct();
SetTableSchema(insert_table.table, schema);
writeRel->set_allocated_table_schema(schema);

substrait::Rel *input = TransformOp(*insert_table.children[0]);
writeRel->set_allocated_input(input);
return rel;
}

substrait::Rel *DuckDBToSubstrait::TransformDeleteTable(LogicalOperator &dop) {
auto rel = new substrait::Rel();
auto &logical_delete = dop.Cast<LogicalDelete>();
Expand Down Expand Up @@ -1530,6 +1551,8 @@ substrait::Rel *DuckDBToSubstrait::TransformOp(LogicalOperator &dop) {
return TransformDummyScan();
case LogicalOperatorType::LOGICAL_CREATE_TABLE:
return TransformCreateTable(dop);
case LogicalOperatorType::LOGICAL_INSERT:
return TransformInsertTable(dop);
case LogicalOperatorType::LOGICAL_DELETE:
return TransformDeleteTable(dop);
default:
Expand Down
21 changes: 21 additions & 0 deletions test/c/test_substrait_c_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,27 @@ TEST_CASE("Test C CTAS Union with Substrait API", "[substrait-api]") {
REQUIRE(CHECK_COLUMN(result, 2, {1, 2, 1, 3, 2, 1, 2}));
}

TEST_CASE("Test C InsertRows with Substrait API", "[substrait-api]") {
DuckDB db(nullptr);
Connection con(db);

CreateEmployeeTable(con);
REQUIRE_NO_FAIL(con.Query("CREATE TABLE senior_employees ("
"employee_id INTEGER PRIMARY KEY, "
"name VARCHAR(100), "
"department_id INTEGER, "
"salary DECIMAL(10, 2))"));

ExecuteViaSubstrait(con, "INSERT INTO senior_employees "
"SELECT * FROM employees WHERE salary > 80000");

auto result = ExecuteViaSubstrait(con, "SELECT * from senior_employees");
REQUIRE(CHECK_COLUMN(result, 0, {1, 4}));
REQUIRE(CHECK_COLUMN(result, 1, {"John Doe", "Bob Brown"}));
REQUIRE(CHECK_COLUMN(result, 2, {1, 3}));
REQUIRE(CHECK_COLUMN(result, 3, {120000, 95000}));
}

TEST_CASE("Test C DeleteRows with Substrait API", "[substrait-api]") {
DuckDB db(nullptr);
Connection con(db);
Expand Down
23 changes: 23 additions & 0 deletions test/python/test_substrait.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,20 @@ def test_ctas_with_union(require):
pd.testing.assert_frame_equal(query_result.df(), expected)


def test_insert_rows_into_table(require):
connection = require('substrait')
create_employee_table(connection)
create_senior_employees_table(connection)
_ = execute_via_substrait(connection, "INSERT INTO senior_employees SELECT * FROM employees WHERE salary > 80000")

expected = pd.DataFrame({"employee_id": pd.Series([1, 4], dtype="int32"),
"name": ["John Doe", "Bob Brown"],
"department_id": pd.Series([1, 3], dtype="int32"),
"salary": pd.Series([120000, 95000], dtype="float64")})
query_result = execute_via_substrait(connection, "SELECT * FROM senior_employees")
pd.testing.assert_frame_equal(query_result.df(), expected)


def test_delete_rows_in_table(require):
connection = require('substrait')
create_employee_table(connection)
Expand Down Expand Up @@ -208,6 +222,15 @@ def create_part_time_employee_table(connection):
(7, 'Eve Green', 2, 20)
""")

def create_senior_employees_table(connection):
connection.execute("""
CREATE TABLE senior_employees (
employee_id INTEGER PRIMARY KEY,
name VARCHAR(100),
department_id INTEGER,
salary DECIMAL(10, 2)
)
""")

def create_departments_table(connection):
connection.execute("""
Expand Down

0 comments on commit 87195e9

Please sign in to comment.