Skip to content
This repository has been archived by the owner on Nov 21, 2024. It is now read-only.

Support delete rows in a table #8

Merged
merged 1 commit into from
Oct 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 17 additions & 13 deletions src/from_substrait.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "substrait/plan.pb.h"

#include "duckdb/main/relation/create_table_relation.hpp"
#include <duckdb/main/relation/delete_relation.hpp>
#include "duckdb/main/relation/table_relation.hpp"

namespace duckdb {
Expand Down Expand Up @@ -623,22 +624,25 @@ shared_ptr<Relation> SubstraitToDuckDB::TransformSetOp(const substrait::Rel &sop

shared_ptr<Relation> SubstraitToDuckDB::TransformWriteOp(const substrait::Rel &sop) {
auto &swrite = sop.write();
auto &nobj = swrite.named_table();
if (nobj.names_size() == 0) {
throw InvalidInputException("Named object must have at least one name");
}
auto table_idx = nobj.names_size() - 1;
auto table_name = nobj.names(table_idx);
string schema_name;
if (table_idx > 0) {
schema_name = nobj.names(0);
}

auto input = TransformOp(swrite.input());
switch (swrite.op()) {
case substrait::WriteRel::WriteOp::WriteRel_WriteOp_WRITE_OP_CTAS: {
auto &nobj = swrite.named_table();
if (nobj.names_size() == 0) {
throw InvalidInputException("Named object must have at least one name");
}
auto table_idx = nobj.names_size() - 1;
auto table_name = nobj.names(table_idx);
string schema_name;
if (table_idx > 0) {
schema_name = nobj.names(0);
}

auto input = TransformOp(swrite.input());
case substrait::WriteRel::WriteOp::WriteRel_WriteOp_WRITE_OP_CTAS:
return input->CreateRel(schema_name, table_name);
case substrait::WriteRel::WriteOp::WriteRel_WriteOp_WRITE_OP_DELETE: {
auto filter = std::move(input.get()->Cast<FilterRelation>());
auto context = filter.child->Cast<TableRelation>().context;
return make_shared_ptr<DeleteRelation>(filter.context, std::move(filter.condition), schema_name, table_name);
}
default:
throw NotImplementedException("Unsupported write operation" + to_string(swrite.op()));
Expand Down
3 changes: 3 additions & 0 deletions src/include/to_substrait.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ class DuckDBToSubstrait {
static vector<string> DepthFirstNames(const LogicalType &type);
static void DepthFirstNamesRecurse(vector<string> &names, const LogicalType &type);
static substrait::Expression_Literal ToExpressionLiteral(const substrait::Expression &expr);
static void SetTableSchema(const TableCatalogEntry &table, substrait::NamedStruct *schema);
static void SetNamedTable(const TableCatalogEntry &table, substrait::WriteRel *writeRel);

//! Transforms Relation Root
substrait::RelRoot *TransformRootOp(LogicalOperator &dop);
Expand All @@ -59,6 +61,7 @@ class DuckDBToSubstrait {
substrait::Rel *TransformExcept(LogicalOperator &dop);
substrait::Rel *TransformIntersect(LogicalOperator &dop);
substrait::Rel *TransformCreateTable(LogicalOperator &dop);
substrait::Rel *TransformDeleteTable(LogicalOperator &dop);
static substrait::Rel *TransformDummyScan();
//! Methods to transform different LogicalGet Types (e.g., Table, Parquet)
//! To Substrait;
Expand Down
62 changes: 62 additions & 0 deletions src/to_substrait.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1453,6 +1453,51 @@ substrait::Rel *DuckDBToSubstrait::TransformCreateTable(LogicalOperator &dop) {
return rel;
}

void DuckDBToSubstrait::SetTableSchema(const TableCatalogEntry &table, substrait::NamedStruct *schema) {
for (auto &name : table.GetColumns().GetColumnNames()) {
schema->add_names(name);
}
auto type_info = new substrait::Type_Struct();
type_info->set_nullability(substrait::Type_Nullability_NULLABILITY_REQUIRED);
for (auto &col_type : table.GetColumns().GetColumnTypes()) {
auto s_type = DuckToSubstraitType(col_type, nullptr, false);
*type_info->add_types() = s_type;
}
schema->set_allocated_struct_(type_info);
}

void DuckDBToSubstrait::SetNamedTable(const TableCatalogEntry &table, substrait::WriteRel *writeRel) {
auto named_table = writeRel->mutable_named_table();
named_table->add_names(table.schema.name);
named_table->add_names(table.name);
}

substrait::Rel *DuckDBToSubstrait::TransformDeleteTable(LogicalOperator &dop) {
auto rel = new substrait::Rel();
auto &logical_delete = dop.Cast<LogicalDelete>();
auto &table = logical_delete.table;
if (logical_delete.children.size() != 1) {
throw InternalException("Delete table expected one child, found " + to_string(logical_delete.children.size()));
}

auto writeRel = rel->mutable_write();
writeRel->set_op(substrait::WriteRel::WriteOp::WriteRel_WriteOp_WRITE_OP_DELETE);
writeRel->set_output(substrait::WriteRel::OUTPUT_MODE_NO_OUTPUT);

auto named_table = writeRel->mutable_named_table();
named_table->add_names(table.schema.name);
named_table->add_names(table.name);

SetNamedTable(logical_delete.table, writeRel);
auto schema = new substrait::NamedStruct();
SetTableSchema(logical_delete.table, schema);
writeRel->set_allocated_table_schema(schema);

substrait::Rel *input = TransformOp(*logical_delete.children[0]);
writeRel->set_allocated_input(input);
return rel;
}

substrait::Rel *DuckDBToSubstrait::TransformOp(LogicalOperator &dop) {
switch (dop.type) {
case LogicalOperatorType::LOGICAL_FILTER:
Expand Down Expand Up @@ -1485,6 +1530,8 @@ substrait::Rel *DuckDBToSubstrait::TransformOp(LogicalOperator &dop) {
return TransformDummyScan();
case LogicalOperatorType::LOGICAL_CREATE_TABLE:
return TransformCreateTable(dop);
case LogicalOperatorType::LOGICAL_DELETE:
return TransformDeleteTable(dop);
default:
throw NotImplementedException(LogicalOperatorToString(dop.type));
}
Expand All @@ -1495,8 +1542,23 @@ static bool IsSetOperation(const LogicalOperator &op) {
op.type == LogicalOperatorType::LOGICAL_INTERSECT;
}

static bool IsRowModificationOperator(const LogicalOperator &op) {
switch (op.type) {
case LogicalOperatorType::LOGICAL_INSERT:
case LogicalOperatorType::LOGICAL_DELETE:
case LogicalOperatorType::LOGICAL_UPDATE:
return true;
default:
return false;
}
}

substrait::RelRoot *DuckDBToSubstrait::TransformRootOp(LogicalOperator &dop) {
auto root_rel = new substrait::RelRoot();
if (IsRowModificationOperator(dop)) {
root_rel->set_allocated_input(TransformOp(dop));
return root_rel;
}
LogicalOperator *current_op = &dop;
bool weird_scenario = current_op->type == LogicalOperatorType::LOGICAL_PROJECTION &&
current_op->children[0]->type == LogicalOperatorType::LOGICAL_TOP_N;
Expand Down
16 changes: 15 additions & 1 deletion test/c/test_substrait_c_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -257,4 +257,18 @@ TEST_CASE("Test C CTAS Union with Substrait API", "[substrait-api]") {
REQUIRE(CHECK_COLUMN(result, 0, {1, 2, 3, 4, 5, 6, 7}));
REQUIRE(CHECK_COLUMN(result, 1, {"John Doe", "Jane Smith", "Alice Johnson", "Bob Brown", "Charlie Black", "David White", "Eve Green"}));
REQUIRE(CHECK_COLUMN(result, 2, {1, 2, 1, 3, 2, 1, 2}));
}
}

TEST_CASE("Test C DeleteRows with Substrait API", "[substrait-api]") {
DuckDB db(nullptr);
Connection con(db);

CreateEmployeeTable(con);

ExecuteViaSubstraitJSON(con, "DELETE FROM employees WHERE salary < 80000");
auto result = ExecuteViaSubstrait(con, "SELECT * from employees");
REQUIRE(CHECK_COLUMN(result, 0, {1, 2, 4}));
REQUIRE(CHECK_COLUMN(result, 1, {"John Doe", "Jane Smith", "Bob Brown"}));
REQUIRE(CHECK_COLUMN(result, 2, {1, 2, 3}));
REQUIRE(CHECK_COLUMN(result, 3, {120000, 80000, 95000}));
}
12 changes: 12 additions & 0 deletions test/python/test_substrait.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,18 @@ def test_ctas_with_union(require):
pd.testing.assert_frame_equal(query_result.df(), expected)


def test_delete_rows_in_table(require):
connection = require('substrait')
create_employee_table(connection)
connection.execute("DELETE FROM employees WHERE salary < 80000")
query_result = execute_via_substrait(connection, "SELECT * FROM employees")
expected = pd.DataFrame({"employee_id": pd.Series([1, 2, 4], dtype="int32"),
"name": ["John Doe", "Jane Smith", "Bob Brown"],
"department_id": pd.Series([1, 2, 3], dtype="int32"),
"salary": pd.Series([120000, 80000, 95000], dtype="float64")})
pd.testing.assert_frame_equal(query_result.df(), expected)


def execute_via_substrait(connection, query):
res = connection.get_substrait(query)
proto_bytes = res.fetchone()[0]
Expand Down
Loading