diff --git a/CMakeLists.txt b/CMakeLists.txt index f180c46..6d279fe 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -10,96 +10,100 @@ include_directories(src/include) include_directories(third_party/substrait) include_directories(third_party/) +# refer source by absolute path. So that we can use the same source in the child profile too (i.e. test/c) +set(THIRD_PARTY_DIR "${CMAKE_CURRENT_SOURCE_DIR}/third_party") set(PROTOBUF_SOURCES - third_party/google/protobuf/any.cc - third_party/google/protobuf/any.pb.cc - third_party/google/protobuf/any_lite.cc - third_party/google/protobuf/arena.cc - third_party/google/protobuf/arenastring.cc - third_party/google/protobuf/descriptor.cc - third_party/google/protobuf/descriptor.pb.cc - third_party/google/protobuf/descriptor_database.cc - third_party/google/protobuf/dynamic_message.cc - third_party/google/protobuf/empty.pb.cc - third_party/google/protobuf/extension_set.cc - third_party/google/protobuf/extension_set_heavy.cc - third_party/google/protobuf/generated_enum_util.cc - third_party/google/protobuf/generated_message_bases.cc - third_party/google/protobuf/generated_message_reflection.cc - third_party/google/protobuf/generated_message_table_driven.cc - third_party/google/protobuf/generated_message_table_driven_lite.cc - third_party/google/protobuf/generated_message_util.cc - third_party/google/protobuf/implicit_weak_message.cc - third_party/google/protobuf/inlined_string_field.cc - third_party/google/protobuf/map.cc - third_party/google/protobuf/map_field.cc - third_party/google/protobuf/message.cc - third_party/google/protobuf/message_lite.cc - third_party/google/protobuf/parse_context.cc - third_party/google/protobuf/port_def.inc - third_party/google/protobuf/port_undef.inc - third_party/google/protobuf/reflection_ops.cc - third_party/google/protobuf/repeated_field.cc - third_party/google/protobuf/repeated_ptr_field.cc - third_party/google/protobuf/text_format.cc - third_party/google/protobuf/unknown_field_set.cc - 
third_party/google/protobuf/wire_format.cc - third_party/google/protobuf/wire_format_lite.cc - third_party/google/protobuf/io/coded_stream.cc - third_party/google/protobuf/io/io_win32.cc - third_party/google/protobuf/io/strtod.cc - third_party/google/protobuf/io/tokenizer.cc - third_party/google/protobuf/io/zero_copy_stream.cc - third_party/google/protobuf/io/zero_copy_stream_impl.cc - third_party/google/protobuf/io/zero_copy_stream_impl_lite.cc - third_party/google/protobuf/stubs/common.cc - third_party/google/protobuf/stubs/int128.cc - third_party/google/protobuf/stubs/status.cc - third_party/google/protobuf/stubs/stringpiece.cc - third_party/google/protobuf/stubs/stringprintf.cc - third_party/google/protobuf/stubs/structurally_valid.cc - third_party/google/protobuf/stubs/strutil.cc - third_party/google/protobuf/stubs/substitute.cc - third_party/google/protobuf/stubs/bytestream.cc - third_party/google/protobuf/util/json_util.cc - third_party/google/protobuf/util/internal/datapiece.cc - third_party/google/protobuf/util/internal/default_value_objectwriter.cc - third_party/google/protobuf/util/internal/error_listener.cc - third_party/google/protobuf/util/internal/json_escaping.cc - third_party/google/protobuf/util/internal/json_objectwriter.cc - third_party/google/protobuf/util/internal/json_stream_parser.cc - third_party/google/protobuf/util/json_util.cc - third_party/google/protobuf/util/internal/object_writer.cc - third_party/google/protobuf/util/internal/proto_writer.cc - third_party/google/protobuf/util/internal/protostream_objectsource.cc - third_party/google/protobuf/util/internal/protostream_objectwriter.cc - third_party/google/protobuf/source_context.pb.cc - third_party/google/protobuf/stubs/statusor.cc - third_party/google/protobuf/stubs/time.cc - third_party/google/protobuf/type.pb.cc - third_party/google/protobuf/wrappers.pb.cc - third_party/google/protobuf/struct.pb.cc - third_party/google/protobuf/util/internal/type_info.cc - 
third_party/google/protobuf/util/internal/field_mask_utility.cc - third_party/google/protobuf/util/type_resolver_util.cc - third_party/google/protobuf/util/internal/utility.cc) + ${THIRD_PARTY_DIR}/google/protobuf/any.cc + ${THIRD_PARTY_DIR}/google/protobuf/any.pb.cc + ${THIRD_PARTY_DIR}/google/protobuf/any_lite.cc + ${THIRD_PARTY_DIR}/google/protobuf/arena.cc + ${THIRD_PARTY_DIR}/google/protobuf/arenastring.cc + ${THIRD_PARTY_DIR}/google/protobuf/descriptor.cc + ${THIRD_PARTY_DIR}/google/protobuf/descriptor.pb.cc + ${THIRD_PARTY_DIR}/google/protobuf/descriptor_database.cc + ${THIRD_PARTY_DIR}/google/protobuf/dynamic_message.cc + ${THIRD_PARTY_DIR}/google/protobuf/empty.pb.cc + ${THIRD_PARTY_DIR}/google/protobuf/extension_set.cc + ${THIRD_PARTY_DIR}/google/protobuf/extension_set_heavy.cc + ${THIRD_PARTY_DIR}/google/protobuf/generated_enum_util.cc + ${THIRD_PARTY_DIR}/google/protobuf/generated_message_bases.cc + ${THIRD_PARTY_DIR}/google/protobuf/generated_message_reflection.cc + ${THIRD_PARTY_DIR}/google/protobuf/generated_message_table_driven.cc + ${THIRD_PARTY_DIR}/google/protobuf/generated_message_table_driven_lite.cc + ${THIRD_PARTY_DIR}/google/protobuf/generated_message_util.cc + ${THIRD_PARTY_DIR}/google/protobuf/implicit_weak_message.cc + ${THIRD_PARTY_DIR}/google/protobuf/inlined_string_field.cc + ${THIRD_PARTY_DIR}/google/protobuf/map.cc + ${THIRD_PARTY_DIR}/google/protobuf/map_field.cc + ${THIRD_PARTY_DIR}/google/protobuf/message.cc + ${THIRD_PARTY_DIR}/google/protobuf/message_lite.cc + ${THIRD_PARTY_DIR}/google/protobuf/parse_context.cc + ${THIRD_PARTY_DIR}/google/protobuf/port_def.inc + ${THIRD_PARTY_DIR}/google/protobuf/port_undef.inc + ${THIRD_PARTY_DIR}/google/protobuf/reflection_ops.cc + ${THIRD_PARTY_DIR}/google/protobuf/repeated_field.cc + ${THIRD_PARTY_DIR}/google/protobuf/repeated_ptr_field.cc + ${THIRD_PARTY_DIR}/google/protobuf/text_format.cc + ${THIRD_PARTY_DIR}/google/protobuf/unknown_field_set.cc + 
${THIRD_PARTY_DIR}/google/protobuf/wire_format.cc + ${THIRD_PARTY_DIR}/google/protobuf/wire_format_lite.cc + ${THIRD_PARTY_DIR}/google/protobuf/io/coded_stream.cc + ${THIRD_PARTY_DIR}/google/protobuf/io/io_win32.cc + ${THIRD_PARTY_DIR}/google/protobuf/io/strtod.cc + ${THIRD_PARTY_DIR}/google/protobuf/io/tokenizer.cc + ${THIRD_PARTY_DIR}/google/protobuf/io/zero_copy_stream.cc + ${THIRD_PARTY_DIR}/google/protobuf/io/zero_copy_stream_impl.cc + ${THIRD_PARTY_DIR}/google/protobuf/io/zero_copy_stream_impl_lite.cc + ${THIRD_PARTY_DIR}/google/protobuf/stubs/common.cc + ${THIRD_PARTY_DIR}/google/protobuf/stubs/int128.cc + ${THIRD_PARTY_DIR}/google/protobuf/stubs/status.cc + ${THIRD_PARTY_DIR}/google/protobuf/stubs/stringpiece.cc + ${THIRD_PARTY_DIR}/google/protobuf/stubs/stringprintf.cc + ${THIRD_PARTY_DIR}/google/protobuf/stubs/structurally_valid.cc + ${THIRD_PARTY_DIR}/google/protobuf/stubs/strutil.cc + ${THIRD_PARTY_DIR}/google/protobuf/stubs/substitute.cc + ${THIRD_PARTY_DIR}/google/protobuf/stubs/bytestream.cc + ${THIRD_PARTY_DIR}/google/protobuf/util/json_util.cc + ${THIRD_PARTY_DIR}/google/protobuf/util/internal/datapiece.cc + ${THIRD_PARTY_DIR}/google/protobuf/util/internal/default_value_objectwriter.cc + ${THIRD_PARTY_DIR}/google/protobuf/util/internal/error_listener.cc + ${THIRD_PARTY_DIR}/google/protobuf/util/internal/json_escaping.cc + ${THIRD_PARTY_DIR}/google/protobuf/util/internal/json_objectwriter.cc + ${THIRD_PARTY_DIR}/google/protobuf/util/internal/json_stream_parser.cc + ${THIRD_PARTY_DIR}/google/protobuf/util/json_util.cc + ${THIRD_PARTY_DIR}/google/protobuf/util/internal/object_writer.cc + ${THIRD_PARTY_DIR}/google/protobuf/util/internal/proto_writer.cc + ${THIRD_PARTY_DIR}/google/protobuf/util/internal/protostream_objectsource.cc + ${THIRD_PARTY_DIR}/google/protobuf/util/internal/protostream_objectwriter.cc + ${THIRD_PARTY_DIR}/google/protobuf/source_context.pb.cc + ${THIRD_PARTY_DIR}/google/protobuf/stubs/statusor.cc + 
${THIRD_PARTY_DIR}/google/protobuf/stubs/time.cc + ${THIRD_PARTY_DIR}/google/protobuf/type.pb.cc + ${THIRD_PARTY_DIR}/google/protobuf/wrappers.pb.cc + ${THIRD_PARTY_DIR}/google/protobuf/struct.pb.cc + ${THIRD_PARTY_DIR}/google/protobuf/util/internal/type_info.cc + ${THIRD_PARTY_DIR}/google/protobuf/util/internal/field_mask_utility.cc + ${THIRD_PARTY_DIR}/google/protobuf/util/type_resolver_util.cc + ${THIRD_PARTY_DIR}/google/protobuf/util/internal/utility.cc) set(SUBSTRAIT_SOURCES - third_party/substrait/substrait/algebra.pb.cc - third_party/substrait/substrait/capabilities.pb.cc - third_party/substrait/substrait/function.pb.cc - third_party/substrait/substrait/parameterized_types.pb.cc - third_party/substrait/substrait/plan.pb.cc - third_party/substrait/substrait/type.pb.cc - third_party/substrait/substrait/type_expressions.pb.cc - third_party/substrait/substrait/extensions/extensions.pb.cc) + ${THIRD_PARTY_DIR}/substrait/substrait/algebra.pb.cc + ${THIRD_PARTY_DIR}/substrait/substrait/capabilities.pb.cc + ${THIRD_PARTY_DIR}/substrait/substrait/function.pb.cc + ${THIRD_PARTY_DIR}/substrait/substrait/parameterized_types.pb.cc + ${THIRD_PARTY_DIR}/substrait/substrait/plan.pb.cc + ${THIRD_PARTY_DIR}/substrait/substrait/type.pb.cc + ${THIRD_PARTY_DIR}/substrait/substrait/type_expressions.pb.cc + ${THIRD_PARTY_DIR}/substrait/substrait/extensions/extensions.pb.cc) +# refer source by absolute path. So that we can use the same source in the child profile too (i.e. 
test/c) +set(EXTENSION_SOURCES_DIR "${CMAKE_CURRENT_SOURCE_DIR}/src") set(EXTENSION_SOURCES - src/to_substrait.cpp - src/from_substrait.cpp - src/substrait_extension.cpp - src/custom_extensions.cpp - src/custom_extensions_generated.cpp + ${EXTENSION_SOURCES_DIR}/to_substrait.cpp + ${EXTENSION_SOURCES_DIR}/from_substrait.cpp + ${EXTENSION_SOURCES_DIR}/substrait_extension.cpp + ${EXTENSION_SOURCES_DIR}/custom_extensions.cpp + ${EXTENSION_SOURCES_DIR}/custom_extensions_generated.cpp ${SUBSTRAIT_SOURCES} ${PROTOBUF_SOURCES}) @@ -107,6 +111,7 @@ add_library(${EXTENSION_NAME} STATIC ${EXTENSION_SOURCES}) set(PARAMETERS "-warnings") build_loadable_extension(${TARGET_NAME} ${PARAMETERS} ${EXTENSION_SOURCES}) +add_subdirectory(test/c) install( TARGETS ${EXTENSION_NAME} diff --git a/extension_config.cmake b/extension_config.cmake index 7665aeb..1051573 100644 --- a/extension_config.cmake +++ b/extension_config.cmake @@ -5,4 +5,5 @@ duckdb_extension_load(substrait SOURCE_DIR ${CMAKE_CURRENT_LIST_DIR} INCLUDE_DIR ${CMAKE_CURRENT_LIST_DIR}/src/include LOAD_TESTS + DONT_LINK ) diff --git a/src/from_substrait.cpp b/src/from_substrait.cpp index 08b2683..f96a9dc 100644 --- a/src/from_substrait.cpp +++ b/src/from_substrait.cpp @@ -11,6 +11,7 @@ #include "duckdb/main/relation/aggregate_relation.hpp" #include "duckdb/main/relation/filter_relation.hpp" #include "duckdb/main/relation/order_relation.hpp" +#include "duckdb/main/relation/create_table_relation.hpp" #include "duckdb/main/connection.hpp" #include "duckdb/parser/parser.hpp" #include "duckdb/common/exception.hpp" @@ -621,6 +622,93 @@ shared_ptr SubstraitToDuckDB::TransformSetOp(const substrait::Rel &sop return make_shared_ptr(std::move(lhs), std::move(rhs), type); } +shared_ptr SubstraitToDuckDB::TransformDdlOp(const substrait::Rel &sop) { + auto &sddl = sop.ddl(); + auto ddl_op = sddl.op(); + if (ddl_op == substrait::DdlRel::DdlOp::DdlRel_DdlOp_DDL_OP_CREATE || + ddl_op == 
substrait::DdlRel::DdlOp::DdlRel_DdlOp_DDL_OP_CREATE_OR_REPLACE) { + switch (sddl.object()) { + case substrait::DdlRel::DdlObject::DdlRel_DdlObject_DDL_OBJECT_TABLE: { + if (sddl.write_type_case() != substrait::DdlRel::WriteTypeCase::kNamedObject) { + throw NotImplementedException("Only NamedObject is supported in CreateTable"); + } + auto &nobj = sddl.named_object(); + if (nobj.names_size() == 0) { + throw InvalidInputException("Named object must have at least one name"); + } + auto table_idx = nobj.names_size() - 1; + auto table_name = nobj.names(table_idx); + string schema_name; + if (table_idx > 0) { + schema_name = nobj.names(0); + } + + auto &col_names = sddl.table_schema().names(); + auto col_types = sddl.table_schema().struct_().types(); + if (col_names.size() != col_types.size()) { + throw NotImplementedException("Column names and types count do not match"); + } + vector column_definitions; + for (size_t i = 0; i < col_names.size(); i++) { + auto &scol_type = col_types[i]; + auto type = SubstraitToDuckType(scol_type); + column_definitions.push_back(ColumnDefinition(col_names[i], type)); + } + unique_ptr table_desc = make_uniq(); + table_desc->columns = std::move(column_definitions); + table_desc->table = table_name; + table_desc->schema = schema_name; + + std::cout << "Creating table1 " << schema_name << " " << table_name << std::endl; + shared_ptr table = make_shared_ptr(con.context, std::move(table_desc)); + auto create_table_rel = table->CreateRel(schema_name, table_name, false); + std::cout << "Creating table2" << std::endl; + return create_table_rel; + } + case substrait::DdlRel::DdlObject::DdlRel_DdlObject_DDL_OBJECT_VIEW: + throw NotImplementedException("Only CreateTable is supported "); + default: { + throw NotImplementedException("Unsupported DDL object"); + } + } + } + else + { + throw NotImplementedException("Unsupported DDL operation"); + } +} + +shared_ptr SubstraitToDuckDB::TransformWriteOp(const substrait::Rel &sop) { + auto &swrite = 
sop.write(); + auto &nobj = swrite.named_table(); + if (nobj.names_size() == 0) { + throw InvalidInputException("Named object must have at least one name"); + } + auto table_idx = nobj.names_size() - 1; + auto table_name = nobj.names(table_idx); + string schema_name; + if (table_idx > 0) { + schema_name = nobj.names(0); + } + + auto input = TransformOp(swrite.input()); + + switch (swrite.op()) + { + case substrait::WriteRel::WriteOp::WriteRel_WriteOp_WRITE_OP_CTAS: { + + auto create_table_rel = input->CreateRel(schema_name, table_name, false); + return create_table_rel; + } + case substrait::WriteRel::WriteOp::WriteRel_WriteOp_WRITE_OP_INSERT: { + auto insert_rel = input->InsertRel(schema_name, table_name); + return insert_rel; + } + default: + throw NotImplementedException("Unsupported write operation"); + } +} + shared_ptr SubstraitToDuckDB::TransformOp(const substrait::Rel &sop) { switch (sop.rel_type_case()) { case substrait::Rel::RelTypeCase::kJoin: @@ -641,6 +729,10 @@ shared_ptr SubstraitToDuckDB::TransformOp(const substrait::Rel &sop) { return TransformSortOp(sop); case substrait::Rel::RelTypeCase::kSet: return TransformSetOp(sop); + case substrait::Rel::RelTypeCase::kDdl: + return TransformDdlOp(sop); + case substrait::Rel::RelTypeCase::kWrite: + return TransformWriteOp(sop); default: throw InternalException("Unsupported relation type " + to_string(sop.rel_type_case())); } @@ -700,6 +792,9 @@ shared_ptr SubstraitToDuckDB::TransformRootOp(const substrait::RelRoot } } + if (sop.input().rel_type_case() == substrait::Rel::RelTypeCase::kWrite) { + return child; + } return make_shared_ptr(child, std::move(expressions), aliases); } diff --git a/src/include/from_substrait.hpp b/src/include/from_substrait.hpp index ffaaf92..3007bba 100644 --- a/src/include/from_substrait.hpp +++ b/src/include/from_substrait.hpp @@ -28,6 +28,8 @@ class SubstraitToDuckDB { shared_ptr TransformReadOp(const substrait::Rel &sop); shared_ptr TransformSortOp(const substrait::Rel &sop); 
shared_ptr TransformSetOp(const substrait::Rel &sop); + shared_ptr TransformDdlOp(const substrait::Rel &sop); + shared_ptr TransformWriteOp(const substrait::Rel &sop); //! Transform Substrait Expressions to DuckDB Expressions unique_ptr TransformExpr(const substrait::Expression &sexpr); diff --git a/src/include/to_substrait.hpp b/src/include/to_substrait.hpp index 06cd8b6..19aedfa 100644 --- a/src/include/to_substrait.hpp +++ b/src/include/to_substrait.hpp @@ -52,11 +52,14 @@ class DuckDBToSubstrait { substrait::Rel *TransformComparisonJoin(LogicalOperator &dop); substrait::Rel *TransformAggregateGroup(LogicalOperator &dop); substrait::Rel *TransformGet(LogicalOperator &dop); + substrait::Rel *TransformExpressionGet(LogicalOperator &dop); substrait::Rel *TransformCrossProduct(LogicalOperator &dop); substrait::Rel *TransformUnion(LogicalOperator &dop); substrait::Rel *TransformDistinct(LogicalOperator &dop); substrait::Rel *TransformExcept(LogicalOperator &dop); substrait::Rel *TransformIntersect(LogicalOperator &dop); + substrait::Rel *TransformCreateTable(LogicalOperator &dop); + substrait::Rel *TransformInsertTable(LogicalOperator &dop); static substrait::Rel *TransformDummyScan(); //! Methods to transform different LogicalGet Types (e.g., Table, Parquet) //! 
To Substrait; diff --git a/src/to_substrait.cpp b/src/to_substrait.cpp index 61f46d7..317e55c 100644 --- a/src/to_substrait.cpp +++ b/src/to_substrait.cpp @@ -1313,6 +1313,42 @@ substrait::Rel *DuckDBToSubstrait::TransformGet(LogicalOperator &dop) { return get_rel; } +static substrait::Expression_Literal GetLiteralFromSubstraitExpression(const substrait::Expression &expr) { + substrait::Expression_Literal literal_field; + switch (expr.rex_type_case()) + { + case substrait::Expression::kLiteral: + literal_field = expr.literal(); + break; + case substrait::Expression::kCast: + literal_field = GetLiteralFromSubstraitExpression(expr.cast().input()); + break; + default: + throw NotImplementedException("Unimplemented type of expression to fetch literal"); + } + return literal_field; +} + +substrait::Rel *DuckDBToSubstrait::TransformExpressionGet(LogicalOperator &dop) { + auto get_rel = new substrait::Rel(); + auto &dget = dop.Cast(); + + auto sget = get_rel->mutable_read(); + auto virtual_table = sget->mutable_virtual_table(); + + for (auto &row : dget.expressions) { + auto row_item = virtual_table->add_values(); + for (auto &expr : row) { + auto s_expr = new substrait::Expression(); + TransformExpr(*expr, *s_expr); + *row_item->add_fields() = GetLiteralFromSubstraitExpression(*s_expr); + delete s_expr; + } + } + return get_rel; +} + + substrait::Rel *DuckDBToSubstrait::TransformCrossProduct(LogicalOperator &dop) { auto rel = new substrait::Rel(); auto sub_cross_prod = rel->mutable_cross(); @@ -1391,6 +1427,101 @@ substrait::Rel *DuckDBToSubstrait::TransformIntersect(LogicalOperator &dop) { return rel; } +substrait::Rel *DuckDBToSubstrait::TransformCreateTable(LogicalOperator &dop) { + auto rel = new substrait::Rel(); + auto &create_table = dop.Cast(); + auto &create_info = create_table.info.get()->Base(); + + auto schema = new substrait::NamedStruct(); + auto type_info = new substrait::Type_Struct(); + 
type_info->set_nullability(substrait::Type_Nullability_NULLABILITY_REQUIRED); + for (auto &name : create_info.columns.GetColumnNames()) { + schema->add_names(name); + } + for (auto &col_type : create_info.columns.GetColumnTypes()) { + auto s_type = DuckToSubstraitType(col_type, nullptr, false); + *type_info->add_types() = s_type; + } + schema->set_allocated_struct_(type_info); + + if (create_table.children.size() == 0) { + // This is create table with schema + auto ddl = rel->mutable_ddl(); + ddl->set_op(substrait::DdlRel::DdlOp::DdlRel_DdlOp_DDL_OP_CREATE); + ddl->set_object(substrait::DdlRel::DDL_OBJECT_TABLE); + ddl->set_allocated_table_schema(schema); + + auto named_object = ddl->mutable_named_object(); + named_object->add_names(create_info.schema); + named_object->add_names(create_info.table); + return rel; + } + + // This is create table as select + substrait::Rel *input = nullptr; + switch (create_table.children[0]->type) + { + case LogicalOperatorType::LOGICAL_PROJECTION: + input = TransformProjection(*create_table.children[0]); + break; + default: + throw NotImplementedException("Create table with more than one child"); + } + auto write = rel->mutable_write(); + write->set_allocated_table_schema(schema); + write->set_allocated_input(input); + write->set_op(substrait::WriteRel::WriteOp::WriteRel_WriteOp_WRITE_OP_CTAS); + auto named_object = write->mutable_named_table(); + named_object->add_names(create_info.schema); + named_object->add_names(create_info.table); + + return rel; +} + +substrait::Rel *DuckDBToSubstrait::TransformInsertTable(LogicalOperator &dop) { + auto rel = new substrait::Rel(); + auto &insert_table = dop.Cast(); + auto &table = insert_table.table; + auto writeRel = rel->mutable_write(); + writeRel->set_op(substrait::WriteRel::WriteOp::WriteRel_WriteOp_WRITE_OP_INSERT); + writeRel->set_output(substrait::WriteRel::OUTPUT_MODE_NO_OUTPUT); + // set named_table + auto named_table = writeRel->mutable_named_table(); + 
named_table->add_names(table.schema.name); + named_table->add_names(table.name); + + // set table_schema + auto schema = new substrait::NamedStruct(); + for (auto &name : table.GetColumns().GetColumnNames()) { + schema->add_names(name); + } + auto type_info = new substrait::Type_Struct(); + type_info->set_nullability(substrait::Type_Nullability_NULLABILITY_REQUIRED); + for (auto &col_type : table.GetColumns().GetColumnTypes()) { + auto s_type = DuckToSubstraitType(col_type, nullptr, false); + *type_info->add_types() = s_type; + } + schema->set_allocated_struct_(type_info); + + if (insert_table.children.size() == 0) { + // This is insert table with values + // TODO: Set input relation + return rel; + } + // This is insert as select: the child operator produces the rows to insert + substrait::Rel *input = nullptr; + switch (insert_table.children[0]->type) + { + case LogicalOperatorType::LOGICAL_PROJECTION: + input = TransformProjection(*insert_table.children[0]); + break; + default: + throw NotImplementedException("Create table with more than one child"); + } + writeRel->set_allocated_input(input); + return rel; +} + substrait::Rel *DuckDBToSubstrait::TransformOp(LogicalOperator &dop) { switch (dop.type) { case LogicalOperatorType::LOGICAL_FILTER: @@ -1409,6 +1540,8 @@ substrait::Rel *DuckDBToSubstrait::TransformOp(LogicalOperator &dop) { return TransformAggregateGroup(dop); case LogicalOperatorType::LOGICAL_GET: return TransformGet(dop); + case LogicalOperatorType::LOGICAL_EXPRESSION_GET: + return TransformExpressionGet(dop); case LogicalOperatorType::LOGICAL_CROSS_PRODUCT: return TransformCrossProduct(dop); case LogicalOperatorType::LOGICAL_UNION: @@ -1421,6 +1554,10 @@ substrait::Rel *DuckDBToSubstrait::TransformOp(LogicalOperator &dop) { return TransformIntersect(dop); case LogicalOperatorType::LOGICAL_DUMMY_SCAN: return TransformDummyScan(); + case LogicalOperatorType::LOGICAL_CREATE_TABLE: + return TransformCreateTable(dop); + case LogicalOperatorType::LOGICAL_INSERT: + return 
TransformInsertTable(dop); default: throw InternalException(LogicalOperatorToString(dop.type)); } @@ -1431,8 +1568,24 @@ static bool IsSetOperation(const LogicalOperator &op) { op.type == LogicalOperatorType::LOGICAL_INTERSECT; } +static bool IsWriteORUpdateOperation(const LogicalOperator &op) { + switch (op.type) + { + case LogicalOperatorType::LOGICAL_CREATE_TABLE: + case LogicalOperatorType::LOGICAL_INSERT: + case LogicalOperatorType::LOGICAL_DELETE: + case LogicalOperatorType::LOGICAL_UPDATE: + return true; + } + return false; +} + substrait::RelRoot *DuckDBToSubstrait::TransformRootOp(LogicalOperator &dop) { auto root_rel = new substrait::RelRoot(); + if (IsWriteORUpdateOperation(dop)) { + root_rel->set_allocated_input(TransformOp(dop)); + return root_rel; + } LogicalOperator *current_op = &dop; bool weird_scenario = current_op->type == LogicalOperatorType::LOGICAL_PROJECTION && current_op->children[0]->type == LogicalOperatorType::LOGICAL_TOP_N; @@ -1451,6 +1604,10 @@ substrait::RelRoot *DuckDBToSubstrait::TransformRootOp(LogicalOperator &dop) { continue; } if (current_op->children.size() != 1) { + if (current_op->type == LogicalOperatorType::LOGICAL_CREATE_TABLE) { + // Create table nodes have multiple children, but we don't care about them + break; + } throw InternalException("Root node has more than 1, or 0 children (%d) up to " "reaching a projection node. 
Type %d", current_op->children.size(), current_op->type); @@ -1458,6 +1615,11 @@ substrait::RelRoot *DuckDBToSubstrait::TransformRootOp(LogicalOperator &dop) { current_op = current_op->children[0].get(); } root_rel->set_allocated_input(TransformOp(dop)); + if (current_op->type != LogicalOperatorType::LOGICAL_PROJECTION) { + // No projection on top of the root, we don't have any aliases + return root_rel; + } + auto &dproj = current_op->Cast(); if (!weird_scenario) { for (auto &expression : dproj.expressions) { diff --git a/test/c/CMakeLists.txt b/test/c/CMakeLists.txt index e84f702..1660f20 100644 --- a/test/c/CMakeLists.txt +++ b/test/c/CMakeLists.txt @@ -12,10 +12,15 @@ include_directories(../../duckdb/src/include) include_directories(../../duckdb/test/include) include_directories(../../duckdb/third_party/catch) -set(ALL_SOURCES test_substrait_c_api.cpp) +# get all source files from test/helpers +file(GLOB TEST_HELPER_SOURCES "../../duckdb/test/helpers/*.cpp") +set(ALL_SOURCES ${EXTENSION_SOURCES} ${TEST_HELPER_SOURCES} test_substrait_c_api.cpp) +# this add_executable is needed to make unit test a target +add_executable(test_substrait test_substrait_c_api.cpp ${ALL_SOURCES}) +# add duckdb static library to the test so that duckdb symbols are available +target_link_libraries(test_substrait duckdb_static) -add_library_unity(test_substrait OBJECT ${ALL_SOURCES}) set(ALL_OBJECT_FILES ${ALL_OBJECT_FILES} $ PARENT_SCOPE) diff --git a/test/c/test_substrait_c_api.cpp b/test/c/test_substrait_c_api.cpp index 3f97646..fd3808b 100644 --- a/test/c/test_substrait_c_api.cpp +++ b/test/c/test_substrait_c_api.cpp @@ -1,19 +1,22 @@ +#define CATCH_CONFIG_RUNNER #include "catch.hpp" #include "test_helpers.hpp" -#include "duckdb/parser/parser.hpp" -#include "duckdb/planner/logical_operator.hpp" #include "duckdb/main/connection_manager.hpp" #include "substrait_extension.hpp" -#include -#include using namespace duckdb; using namespace std; +int main(int argc, char* argv[]) { + // Call 
Catch2's session to run tests + return Catch::Session().run(argc, argv); +} + TEST_CASE("Test C Get and To Substrait API", "[substrait-api]") { DuckDB db(nullptr); - db.LoadExtension(); + SubstraitExtension substrait_extension; + substrait_extension.Load(db); Connection con(db); con.EnableQueryVerification(); // create the database @@ -33,7 +36,8 @@ TEST_CASE("Test C Get and To Substrait API", "[substrait-api]") { TEST_CASE("Test C Get and To Json-Substrait API", "[substrait-api]") { DuckDB db(nullptr); - db.LoadExtension(); + SubstraitExtension substrait_extension; + substrait_extension.Load(db); Connection con(db); con.EnableQueryVerification(); // create the database @@ -50,3 +54,107 @@ TEST_CASE("Test C Get and To Json-Substrait API", "[substrait-api]") { REQUIRE_THROWS(con.FromSubstraitJSON("this is not valid")); } + +TEST_CASE("Test C Get and To Substrait API for Insert from select", "[substrait-api]") { + DuckDB db(nullptr); + SubstraitExtension substrait_extension; + substrait_extension.Load(db); + Connection con(db); + con.EnableQueryVerification(); + // set up the source table t1 and the target table t2 + // create first table and populate data + REQUIRE_NO_FAIL(con.Query("CREATE TABLE t1(i INTEGER)")); + REQUIRE_NO_FAIL( + con.Query("INSERT INTO t1 VALUES (1), (2), (3), (NULL)")); + + REQUIRE_NO_FAIL(con.Query("CREATE TABLE t2(i INTEGER)")); + + // Issue substrait query for insert as select + auto proto = con.GetSubstrait("insert into t2 from (select * from t1)"); + auto result = con.FromSubstrait(proto); + + // the insert is expected to return the number of rows inserted + REQUIRE(CHECK_COLUMN(result, 0, {4})); +} + +TEST_CASE("Test C Get and To Json-Substrait API for Insert from select", "[substrait-api]") { + DuckDB db(nullptr); + SubstraitExtension substrait_extension; + substrait_extension.Load(db); + Connection con(db); + con.EnableQueryVerification(); + // set up the source table t1 and the target table t2 + // create first table and populate data + REQUIRE_NO_FAIL(con.Query("CREATE TABLE t1(i INTEGER)")); + 
REQUIRE_NO_FAIL( + con.Query("INSERT INTO t1 VALUES (1), (2), (3), (NULL)")); + + REQUIRE_NO_FAIL(con.Query("CREATE TABLE t2(i INTEGER)")); + // Issue substrait query for insert as select + auto json = con.GetSubstraitJSON("insert into t2 from (select * from t1)"); + // round trip + auto result = con.FromSubstraitJSON(json); + + // the insert is expected to return the number of rows inserted + REQUIRE(CHECK_COLUMN(result, 0, {4})); +} + +TEST_CASE("Test C Get and To Substrait API for Insert from virtual table", "[substrait-api]") { + DuckDB db(nullptr); + SubstraitExtension substrait_extension; + substrait_extension.Load(db); + Connection con(db); + con.EnableQueryVerification(); + + REQUIRE_NO_FAIL(con.Query("CREATE TABLE t1(i INTEGER)")); + + auto proto = con.GetSubstrait("INSERT INTO t1 VALUES (1), (2), (3), (NULL)"); + auto result = con.FromSubstrait(proto); + // the insert is expected to return the number of rows inserted + REQUIRE(CHECK_COLUMN(result, 0, {4})); +} + +TEST_CASE("Test C Get and To JSON-Substrait API for Insert from virtual table", "[substrait-api]") { + DuckDB db(nullptr); + SubstraitExtension substrait_extension; + substrait_extension.Load(db); + Connection con(db); + con.EnableQueryVerification(); + + REQUIRE_NO_FAIL(con.Query("CREATE TABLE t1(i INTEGER)")); + + auto json = con.GetSubstraitJSON("INSERT INTO t1 VALUES (1), (2), (3), (NULL)"); + auto result = con.FromSubstraitJSON(json); + // the insert is expected to return the number of rows inserted + REQUIRE(CHECK_COLUMN(result, 0, {4})); +} + +TEST_CASE("Test C Get and To Substrait API for Select from virtual table", "[substrait-api]") { + DuckDB db(nullptr); + SubstraitExtension substrait_extension; + substrait_extension.Load(db); + Connection con(db); + con.EnableQueryVerification(); + + + auto json = con.GetSubstrait("SELECT * FROM (VALUES (1, 2), (3, 4))"); + auto result = con.FromSubstrait(json); + // the select is expected to return the rows of the virtual table + REQUIRE(CHECK_COLUMN(result, 0, {1, 
3})); + REQUIRE(CHECK_COLUMN(result, 1, {2, 4})); +} + +TEST_CASE("Test C Get and To JSON-Substrait API for Select from virtual table", "[substrait-api]") { + DuckDB db(nullptr); + SubstraitExtension substrait_extension; + substrait_extension.Load(db); + Connection con(db); + con.EnableQueryVerification(); + + + auto json = con.GetSubstraitJSON("SELECT * FROM (VALUES (1, 2), (3, 4))"); + auto result = con.FromSubstraitJSON(json); + // the select is expected to return the rows of the virtual table + REQUIRE(CHECK_COLUMN(result, 0, {1, 3})); + REQUIRE(CHECK_COLUMN(result, 1, {2, 4})); +} \ No newline at end of file