diff --git a/modules/basic/ds/arrow_utils.cc b/modules/basic/ds/arrow_utils.cc index a8389421eb..28d191d2ab 100644 --- a/modules/basic/ds/arrow_utils.cc +++ b/modules/basic/ds/arrow_utils.cc @@ -420,6 +420,27 @@ Status RecordBatchesToTable( } } +Status RecordBatchesToTableWithCast( + const std::vector>& batches, + std::shared_ptr* table) { /* Overload without an explicit schema: asks TypeLoosen(batches, out) for a schema, then forwards to the schema-taking overload. */ + std::shared_ptr out; + RETURN_ON_ERROR(TypeLoosen(batches, out)); + return RecordBatchesToTableWithCast(out, batches, table); +} + /* Casts each batch to `schema` via CastBatchToSchema, then passes the cast batches to RecordBatchesToTable(schema, outs, table). NOTE(review): `schema` is taken as a const value rather than by reference -- confirm that is intended. */ +Status RecordBatchesToTableWithCast( + const std::shared_ptr schema, + const std::vector>& batches, + std::shared_ptr* table) { + std::vector> outs; + for (auto const& batch : batches) { /* NOTE(review): `batch` is forwarded as-is and CastBatchToSchema dereferences it, while the batch overload of TypeLoosen skips nullptr entries -- confirm callers never pass a null batch. */ + std::shared_ptr out; + RETURN_ON_ERROR(CastBatchToSchema(batch, schema, out)); + outs.push_back(out); + } + return RecordBatchesToTable(schema, outs, table); +} + Status CombineRecordBatches( const std::vector>& batches, std::shared_ptr* batch) { @@ -873,6 +894,31 @@ Status TypeLoosen(const std::vector>& schemas, return Status::OK(); } +Status TypeLoosen( + const std::vector>& batches, + std::shared_ptr& schema) { /* Gathers the schema of every non-null batch and delegates to the schema-vector TypeLoosen overload. */ + std::vector> schemas; + schemas.reserve(batches.size()); + for (const auto& batch : batches) { + if (batch != nullptr) { + schemas.push_back(batch->schema()); + } + } + return TypeLoosen(schemas, schema); +} + /* Table counterpart: gathers the schema of every non-null table and delegates to the schema-vector TypeLoosen overload. */ +Status TypeLoosen(const std::vector>& tables, + std::shared_ptr& schema) { + std::vector> schemas; + schemas.reserve(tables.size()); + for (const auto& table : tables) { + if (table != nullptr) { + schemas.push_back(table->schema()); + } + } + return TypeLoosen(schemas, schema); +} + Status CastStringToBigString(const std::shared_ptr& in, const std::shared_ptr& to_type, std::shared_ptr& out) { @@ -925,6 +971,46 @@ Status GeneralCast(const std::shared_ptr& in, return Status::OK(); } +Status CastBatchToSchema(const std::shared_ptr& batch, + const std::shared_ptr& schema, + std::shared_ptr& out) { /* Writes to `out` a record batch laid out according to `schema`; when the batch schema already equals `schema`, `out` is the input batch itself. */ + if (batch->schema()->Equals(schema)) { + out = batch; + return Status::OK(); + } + /* Otherwise convert column by column; the column count is checked against the field count first (RETURN_ON_ASSERT presumably returns a non-OK Status on failure -- macro not visible here). */ + 
RETURN_ON_ASSERT(batch->num_columns() == schema->num_fields(), + "The schema of original recordbatch and expected schema is " + "not consistent"); + std::vector> new_columns; + for (int64_t i = 0; i < batch->num_columns(); ++i) { + auto col = batch->column(i); + if (batch->schema()->field(i)->type()->Equals(schema->field(i)->type())) { /* Column type already matches the target field: reuse the column unchanged. */ + new_columns.push_back(col); + continue; + } + auto from_type = batch->schema()->field(i)->type(); + auto to_type = schema->field(i)->type(); + auto array = col; + std::shared_ptr out; /* NOTE(review): this loop-local `out` shadows the output parameter of the same name for the rest of the loop body. */ + if (arrow::compute::CanCast(*from_type, *to_type)) { /* Cast paths are tried in order: GeneralCast when CanCast allows it, then utf8 to large_utf8, then a null-typed source column. */ + RETURN_ON_ERROR(GeneralCast(array, to_type, out)); + } else if (from_type->Equals(arrow::utf8()) && + to_type->Equals(arrow::large_utf8())) { + RETURN_ON_ERROR(CastStringToBigString(array, to_type, out)); + } else if (from_type->Equals(arrow::null())) { + RETURN_ON_ERROR(CastNullToOthers(array, to_type, out)); + } else { /* None of the cast paths applies to this type pair. */ + return Status::Invalid( + "Unsupported cast: To type: " + to_type->ToString() + + " vs. origin type: " + from_type->ToString()); + } + new_columns.push_back(out); + } + out = arrow::RecordBatch::Make(schema, batch->num_rows(), new_columns); + return Status::OK(); +} + Status CastTableToSchema(const std::shared_ptr& table, const std::shared_ptr& schema, std::shared_ptr& out) { diff --git a/modules/basic/ds/arrow_utils.h b/modules/basic/ds/arrow_utils.h index fbc909a21b..b55974669f 100644 --- a/modules/basic/ds/arrow_utils.h +++ b/modules/basic/ds/arrow_utils.h @@ -267,6 +267,15 @@ Status RecordBatchesToTable( const std::vector>& batches, std::shared_ptr* table); +Status RecordBatchesToTableWithCast( + const std::vector>& batches, + std::shared_ptr* table); + /* Overload taking the target schema explicitly as its first parameter. */ +Status RecordBatchesToTableWithCast( + const std::shared_ptr schema, + const std::vector>& batches, + std::shared_ptr* table); + Status CombineRecordBatches( const std::vector>& batches, std::shared_ptr* batch); @@ -348,6 +357,13 @@ const void* get_arrow_array_data(std::shared_ptr const& array); Status TypeLoosen(const 
std::vector>& schemas, std::shared_ptr& schema); +Status TypeLoosen( + const std::vector>& batches, + std::shared_ptr& schema); + /* TypeLoosen overload that takes tables instead of record batches. */ +Status TypeLoosen(const std::vector>& tables, + std::shared_ptr& schema); + Status CastStringToBigString(const std::shared_ptr& in, const std::shared_ptr& to_type, std::shared_ptr& out); @@ -360,6 +376,10 @@ Status GeneralCast(const std::shared_ptr& in, const std::shared_ptr& to_type, std::shared_ptr& out); +Status CastBatchToSchema(const std::shared_ptr& batch, + const std::shared_ptr& schema, + std::shared_ptr& out); /* Record-batch cast to `schema`; the result is written to `out`. */ + Status CastTableToSchema(const std::shared_ptr& table, const std::shared_ptr& schema, std::shared_ptr& out); diff --git a/modules/graph/loader/arrow_fragment_loader.cc b/modules/graph/loader/arrow_fragment_loader.cc index f01e06f074..288d5207ec 100644 --- a/modules/graph/loader/arrow_fragment_loader.cc +++ b/modules/graph/loader/arrow_fragment_loader.cc @@ -281,7 +281,7 @@ Status ReadTableFromVineyardDataFrame(Client& client, VLOG(10) << "read table from vineyard: total rows = " << 0; return Status::OK(); } else { - auto status = RecordBatchesToTable(batches, &table); + auto status = RecordBatchesToTableWithCast(batches, &table); if (status.ok()) { VLOG(10) << "read table from vineyard: total rows = " << table->num_rows(); @@ -389,7 +389,7 @@ GatherETables(Client& client, if (subgroup.second.empty()) { table = nullptr; // no tables at current worker } else { - VY_OK_OR_RAISE(RecordBatchesToTable(subgroup.second, &table)); + VY_OK_OR_RAISE(RecordBatchesToTableWithCast(subgroup.second, &table)); } subtables.emplace_back(table); } @@ -448,7 +448,7 @@ boost::leaf::result>> GatherVTables( if (group.second.empty()) { table = nullptr; // no tables at current worker } else { - VY_OK_OR_RAISE(RecordBatchesToTable(group.second, &table)); + VY_OK_OR_RAISE(RecordBatchesToTableWithCast(group.second, &table)); } tables.emplace_back(table); }