Skip to content
This repository has been archived by the owner on Nov 21, 2024. It is now read-only.

Commit

Permalink
Support CTAS (#7)
Browse files Browse the repository at this point in the history
* Support CTAS to/from substrait
* Add python & C test cases
  • Loading branch information
scgkiran authored Sep 27, 2024
1 parent 9a3deee commit a0b1fa4
Show file tree
Hide file tree
Showing 6 changed files with 495 additions and 2 deletions.
40 changes: 40 additions & 0 deletions src/from_substrait.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "google/protobuf/util/json_util.h"
#include "substrait/plan.pb.h"

#include "duckdb/main/relation/create_table_relation.hpp"
#include "duckdb/main/relation/table_relation.hpp"

namespace duckdb {
Expand Down Expand Up @@ -620,6 +621,30 @@ shared_ptr<Relation> SubstraitToDuckDB::TransformSetOp(const substrait::Rel &sop
return make_shared_ptr<SetOpRelation>(std::move(lhs), std::move(rhs), type);
}

shared_ptr<Relation> SubstraitToDuckDB::TransformWriteOp(const substrait::Rel &sop) {
auto &swrite = sop.write();

switch (swrite.op()) {
case substrait::WriteRel::WriteOp::WriteRel_WriteOp_WRITE_OP_CTAS: {
auto &nobj = swrite.named_table();
if (nobj.names_size() == 0) {
throw InvalidInputException("Named object must have at least one name");
}
auto table_idx = nobj.names_size() - 1;
auto table_name = nobj.names(table_idx);
string schema_name;
if (table_idx > 0) {
schema_name = nobj.names(0);
}

auto input = TransformOp(swrite.input());
return input->CreateRel(schema_name, table_name);
}
default:
throw NotImplementedException("Unsupported write operation" + to_string(swrite.op()));
}
}

shared_ptr<Relation> SubstraitToDuckDB::TransformOp(const substrait::Rel &sop) {
switch (sop.rel_type_case()) {
case substrait::Rel::RelTypeCase::kJoin:
Expand All @@ -640,6 +665,8 @@ shared_ptr<Relation> SubstraitToDuckDB::TransformOp(const substrait::Rel &sop) {
return TransformSortOp(sop);
case substrait::Rel::RelTypeCase::kSet:
return TransformSetOp(sop);
case substrait::Rel::RelTypeCase::kWrite:
return TransformWriteOp(sop);
default:
throw InternalException("Unsupported relation type " + to_string(sop.rel_type_case()));
}
Expand Down Expand Up @@ -699,6 +726,19 @@ shared_ptr<Relation> SubstraitToDuckDB::TransformRootOp(const substrait::RelRoot
}
}

if (sop.input().rel_type_case() == substrait::Rel::RelTypeCase::kWrite) {
auto write = sop.input().write();
switch (write.op()) {
case substrait::WriteRel::WriteOp::WriteRel_WriteOp_WRITE_OP_CTAS: {
const auto create_table = static_cast<CreateTableRelation *>(child.get());
auto proj = make_shared_ptr<ProjectionRelation>(create_table->child, std::move(expressions), aliases);
return proj->CreateRel(create_table->schema_name, create_table->table_name);
}
default:
return child;
}
}

return make_shared_ptr<ProjectionRelation>(child, std::move(expressions), aliases);
}

Expand Down
1 change: 1 addition & 0 deletions src/include/from_substrait.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class SubstraitToDuckDB {
shared_ptr<Relation> TransformReadOp(const substrait::Rel &sop);
shared_ptr<Relation> TransformSortOp(const substrait::Rel &sop);
shared_ptr<Relation> TransformSetOp(const substrait::Rel &sop);
shared_ptr<Relation> TransformWriteOp(const substrait::Rel &sop);

//! Transform Substrait Expressions to DuckDB Expressions
unique_ptr<ParsedExpression> TransformExpr(const substrait::Expression &sexpr);
Expand Down
2 changes: 2 additions & 0 deletions src/include/to_substrait.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class DuckDBToSubstrait {
//! In case of struct types we might we do DFS to get all names
static vector<string> DepthFirstNames(const LogicalType &type);
static void DepthFirstNamesRecurse(vector<string> &names, const LogicalType &type);
static substrait::Expression_Literal ToExpressionLiteral(const substrait::Expression &expr);

//! Transforms Relation Root
substrait::RelRoot *TransformRootOp(LogicalOperator &dop);
Expand All @@ -57,6 +58,7 @@ class DuckDBToSubstrait {
substrait::Rel *TransformDistinct(LogicalOperator &dop);
substrait::Rel *TransformExcept(LogicalOperator &dop);
substrait::Rel *TransformIntersect(LogicalOperator &dop);
substrait::Rel *TransformCreateTable(LogicalOperator &dop);
static substrait::Rel *TransformDummyScan();
//! Methods to transform different LogicalGet Types (e.g., Table, Parquet)
//! To Substrait;
Expand Down
45 changes: 43 additions & 2 deletions src/to_substrait.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,7 @@ void DuckDBToSubstrait::TransformBetweenExpression(Expression &dexpr, substrait:
args_types.emplace_back(DuckToSubstraitType(dcomp.lower->return_type));
args_types.emplace_back(DuckToSubstraitType(dcomp.upper->return_type));
scalar_fun->set_function_reference(RegisterFunction("between", args_types));

auto sarg = scalar_fun->add_arguments();
TransformExpr(*dcomp.input, *sarg->mutable_value(), 0);
sarg = scalar_fun->add_arguments();
Expand Down Expand Up @@ -1381,7 +1381,8 @@ substrait::Rel *DuckDBToSubstrait::TransformDistinct(LogicalOperator &dop) {
set_op->set_op(substrait::SetRel_SetOp::SetRel_SetOp_SET_OP_INTERSECTION_PRIMARY);
break;
default:
throw NotImplementedException("Found unexpected child type in Distinct operator");
throw NotImplementedException("Found unexpected child type in Distinct operator " +
LogicalOperatorToString(set_operation_p->type));
}
auto &set_operation = set_operation_p->Cast<LogicalSetOperation>();

Expand Down Expand Up @@ -1417,6 +1418,41 @@ substrait::Rel *DuckDBToSubstrait::TransformIntersect(LogicalOperator &dop) {
return rel;
}

substrait::Rel *DuckDBToSubstrait::TransformCreateTable(LogicalOperator &dop) {
auto rel = new substrait::Rel();
auto &create_table = dop.Cast<LogicalCreateTable>();
auto &create_info = create_table.info.get()->Base();
if (create_table.children.size() != 1) {
if (create_table.children.size() == 0) {
throw NotImplementedException("Create table without children not implemented");
}
throw InternalException("Create table with more than one child is not supported");
}

auto schema = new substrait::NamedStruct();
auto type_info = new substrait::Type_Struct();
for (auto &name : create_info.columns.GetColumnNames()) {
schema->add_names(name);
}
for (auto &col_type : create_info.columns.GetColumnTypes()) {
auto s_type = DuckToSubstraitType(col_type, nullptr, false);
*type_info->add_types() = s_type;
}
schema->set_allocated_struct_(type_info);

// This is CreateTableAsSelect
substrait::Rel *input = TransformOp(*create_table.children[0]);
auto write = rel->mutable_write();
write->set_allocated_table_schema(schema);
write->set_allocated_input(input);
write->set_op(substrait::WriteRel::WriteOp::WriteRel_WriteOp_WRITE_OP_CTAS);
auto named_table = write->mutable_named_table();
named_table->add_names(create_info.schema);
named_table->add_names(create_info.table);

return rel;
}

substrait::Rel *DuckDBToSubstrait::TransformOp(LogicalOperator &dop) {
switch (dop.type) {
case LogicalOperatorType::LOGICAL_FILTER:
Expand Down Expand Up @@ -1447,6 +1483,8 @@ substrait::Rel *DuckDBToSubstrait::TransformOp(LogicalOperator &dop) {
return TransformIntersect(dop);
case LogicalOperatorType::LOGICAL_DUMMY_SCAN:
return TransformDummyScan();
case LogicalOperatorType::LOGICAL_CREATE_TABLE:
return TransformCreateTable(dop);
default:
throw NotImplementedException(LogicalOperatorToString(dop.type));
}
Expand Down Expand Up @@ -1477,6 +1515,9 @@ substrait::RelRoot *DuckDBToSubstrait::TransformRootOp(LogicalOperator &dop) {
continue;
}
if (current_op->children.size() != 1) {
if (current_op->type == LogicalOperatorType::LOGICAL_CREATE_TABLE) {
break;
}
throw InternalException("Root node has more than 1, or 0 children (%d) up to "
"reaching a projection node. Type %d",
current_op->children.size(), current_op->type);
Expand Down
213 changes: 213 additions & 0 deletions test/c/test_substrait_c_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,216 @@ TEST_CASE("Test C Get and To Json-Substrait API", "[substrait-api]") {

REQUIRE_THROWS(con.FromSubstraitJSON("this is not valid"));
}

duckdb::unique_ptr<QueryResult> ExecuteViaSubstrait(Connection &con, const string &sql) {
auto proto = con.GetSubstrait(sql);
return con.FromSubstrait(proto);
}

duckdb::unique_ptr<QueryResult> ExecuteViaSubstraitJSON(Connection &con, const string &sql) {
auto json_str = con.GetSubstraitJSON(sql);
return con.FromSubstraitJSON(json_str);
}

void CreateEmployeeTable(Connection& con) {
REQUIRE_NO_FAIL(con.Query("CREATE TABLE employees ("
"employee_id INTEGER PRIMARY KEY, "
"name VARCHAR(100), "
"department_id INTEGER, "
"salary DECIMAL(10, 2))"));

REQUIRE_NO_FAIL(con.Query("INSERT INTO employees VALUES "
"(1, 'John Doe', 1, 120000), "
"(2, 'Jane Smith', 2, 80000), "
"(3, 'Alice Johnson', 1, 50000), "
"(4, 'Bob Brown', 3, 95000), "
"(5, 'Charlie Black', 2, 60000)"));
}

void CreatePartTimeEmployeeTable(Connection& con) {
REQUIRE_NO_FAIL(con.Query("CREATE TABLE part_time_employees ("
"id INTEGER PRIMARY KEY, "
"name VARCHAR(100), "
"department_id INTEGER, "
"hourly_rate DECIMAL(10, 2))"));

REQUIRE_NO_FAIL(con.Query("INSERT INTO part_time_employees VALUES "
"(6, 'David White', 1, 30000), "
"(7, 'Eve Green', 2, 40000)"));
}

void CreateDepartmentsTable(Connection& con) {
REQUIRE_NO_FAIL(con.Query("CREATE TABLE departments (department_id INTEGER PRIMARY KEY, department_name VARCHAR(100))"));

REQUIRE_NO_FAIL(con.Query("INSERT INTO departments VALUES "
"(1, 'HR'), "
"(2, 'Engineering'), "
"(3, 'Finance')"));
}

TEST_CASE("Test C CTAS Select columns with Substrait API", "[substrait-api]") {
DuckDB db(nullptr);
Connection con(db);

CreateEmployeeTable(con);

ExecuteViaSubstraitJSON(con, "CREATE TABLE employee_salaries AS "
"SELECT name, salary FROM employees"
);

auto result = ExecuteViaSubstrait(con, "SELECT * from employee_salaries");
REQUIRE(CHECK_COLUMN(result, 0, {"John Doe", "Jane Smith", "Alice Johnson", "Bob Brown", "Charlie Black"}));
REQUIRE(CHECK_COLUMN(result, 1, {120000, 80000, 50000, 95000, 60000}));
}

TEST_CASE("Test C CTAS Filter with Substrait API", "[substrait-api]") {
DuckDB db(nullptr);
Connection con(db);

CreateEmployeeTable(con);

ExecuteViaSubstraitJSON(con, "CREATE TABLE filtered_employees AS "
"SELECT * FROM employees "
"WHERE salary > 80000;"
);

auto result = ExecuteViaSubstrait(con, "SELECT * from filtered_employees");
REQUIRE(CHECK_COLUMN(result, 0, {1, 4}));
REQUIRE(CHECK_COLUMN(result, 1, {"John Doe", "Bob Brown"}));
REQUIRE(CHECK_COLUMN(result, 2, {1, 3}));
REQUIRE(CHECK_COLUMN(result, 3, {120000, 95000}));
}

TEST_CASE("Test C CTAS Case_When with Substrait API", "[substrait-api]") {
DuckDB db(nullptr);
Connection con(db);

CreateEmployeeTable(con);

ExecuteViaSubstraitJSON(con, "CREATE TABLE categorized_employees AS "
"SELECT name, "
"CASE "
"WHEN salary > 100000 THEN 'High' "
"WHEN salary BETWEEN 60000 AND 100000 THEN 'Medium' "
"ELSE 'Low' "
"END AS salary_category "
"FROM employees"
);

auto result = ExecuteViaSubstrait(con, "SELECT * from categorized_employees");
REQUIRE(CHECK_COLUMN(result, 0, {"John Doe", "Jane Smith", "Alice Johnson", "Bob Brown", "Charlie Black"}));
REQUIRE(CHECK_COLUMN(result, 1, {"High", "Medium", "Low", "Medium", "Medium"}));
}

TEST_CASE("Test C CTAS OrderBy with Substrait API", "[substrait-api]") {
DuckDB db(nullptr);
Connection con(db);

CreateEmployeeTable(con);

ExecuteViaSubstraitJSON(con, "CREATE TABLE ordered_employees AS "
"SELECT * FROM employees "
"ORDER BY salary DESC"
);

auto result = ExecuteViaSubstrait(con, "SELECT * from ordered_employees");
REQUIRE(CHECK_COLUMN(result, 0, {1, 4, 2, 5, 3}));
REQUIRE(CHECK_COLUMN(result, 1, {"John Doe", "Bob Brown", "Jane Smith", "Charlie Black", "Alice Johnson"}));
REQUIRE(CHECK_COLUMN(result, 2, {1, 3, 2, 2, 1}));
REQUIRE(CHECK_COLUMN(result, 3, {120000, 95000, 80000, 60000, 50000}));
}

TEST_CASE("Test C CTAS SubQuery with Substrait API", "[substrait-api]") {
DuckDB db(nullptr);
Connection con(db);

CreateEmployeeTable(con);

ExecuteViaSubstraitJSON(con, "CREATE TABLE high_salary_employees AS "
"SELECT * "
"FROM ( "
"SELECT employee_id, name, salary "
"FROM employees "
"WHERE salary > 100000)"
);

auto result = ExecuteViaSubstrait(con, "SELECT * from high_salary_employees");
REQUIRE(CHECK_COLUMN(result, 0, {1}));
REQUIRE(CHECK_COLUMN(result, 1, {"John Doe"}));
REQUIRE(CHECK_COLUMN(result, 2, {120000}));
}

TEST_CASE("Test C CTAS Distinct with Substrait API", "[substrait-api]") {
SKIP_TEST("SKIP: Distinct operator has unsupported child type"); // TODO fix TransformDistinct
return;

DuckDB db(nullptr);
Connection con(db);

CreateEmployeeTable(con);
ExecuteViaSubstraitJSON(con, "CREATE TABLE unique_departments AS "
"SELECT DISTINCT department_id FROM employees"
);

auto result = ExecuteViaSubstrait(con, "SELECT * from unique_departments");
REQUIRE(CHECK_COLUMN(result, 0, {1, 2, 3}));
}

TEST_CASE("Test C CTAS Aggregation with Substrait API", "[substrait-api]") {
DuckDB db(nullptr);
Connection con(db);

CreateEmployeeTable(con);

ExecuteViaSubstraitJSON(con, "CREATE TABLE department_summary AS "
"SELECT department_id, COUNT(*) AS employee_count "
"FROM employees "
"GROUP BY department_id"
);

auto result = ExecuteViaSubstrait(con, "SELECT * from department_summary");
REQUIRE(CHECK_COLUMN(result, 0, {1, 2, 3}));
REQUIRE(CHECK_COLUMN(result, 1, {2, 2, 1}));
}

TEST_CASE("Test C CTAS Join with Substrait API", "[substrait-api]") {
DuckDB db(nullptr);
Connection con(db);

CreateEmployeeTable(con);
CreateDepartmentsTable(con);

ExecuteViaSubstraitJSON(con, "CREATE TABLE employee_departments AS "
"SELECT e.employee_id, e.name, d.department_name "
"FROM employees e "
"JOIN departments d "
"ON e.department_id = d.department_id"
);

auto result = ExecuteViaSubstrait(con, "SELECT * from employee_departments");
REQUIRE(CHECK_COLUMN(result, 0, {1, 2, 3, 4, 5}));
REQUIRE(CHECK_COLUMN(result, 1, {"John Doe", "Jane Smith", "Alice Johnson", "Bob Brown", "Charlie Black"}));
REQUIRE(CHECK_COLUMN(result, 2, {"HR", "Engineering", "HR", "Finance", "Engineering"}));
}

TEST_CASE("Test C CTAS Union with Substrait API", "[substrait-api]") {
DuckDB db(nullptr);
Connection con(db);

CreateEmployeeTable(con);
CreatePartTimeEmployeeTable(con);

ExecuteViaSubstraitJSON(con, "CREATE TABLE all_employees AS "
"SELECT employee_id, name, department_id, salary "
"FROM employees "
"UNION "
"SELECT id, name, department_id, hourly_rate * 2000 AS salary "
"FROM part_time_employees "
"ORDER BY employee_id"
);

auto result = ExecuteViaSubstrait(con, "SELECT * from all_employees");
REQUIRE(CHECK_COLUMN(result, 0, {1, 2, 3, 4, 5, 6, 7}));
REQUIRE(CHECK_COLUMN(result, 1, {"John Doe", "Jane Smith", "Alice Johnson", "Bob Brown", "Charlie Black", "David White", "Eve Green"}));
REQUIRE(CHECK_COLUMN(result, 2, {1, 2, 1, 3, 2, 1, 2}));
}
Loading

0 comments on commit a0b1fa4

Please sign in to comment.