Skip to content

Commit

Permalink
Unify table/batch schemas when reading chunks from vineyard streams (v…
Browse files Browse the repository at this point in the history
…6d-io#1485)

Fixes v6d-io#1484

Signed-off-by: Tao He <[email protected]>
  • Loading branch information
sighingnow authored Jul 21, 2023
1 parent 540511e commit 70d32f2
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 3 deletions.
86 changes: 86 additions & 0 deletions modules/basic/ds/arrow_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,27 @@ Status RecordBatchesToTable(
}
}

// Builds a single table from `batches` whose schemas may differ.
//
// The common schema is first derived from all batches via TypeLoosen; the
// actual per-batch casting and table assembly is then delegated to the
// overload that takes an explicit schema. Any error from either step is
// returned to the caller unchanged.
Status RecordBatchesToTableWithCast(
    const std::vector<std::shared_ptr<arrow::RecordBatch>>& batches,
    std::shared_ptr<arrow::Table>* table) {
  std::shared_ptr<arrow::Schema> unified_schema;
  RETURN_ON_ERROR(TypeLoosen(batches, unified_schema));
  return RecordBatchesToTableWithCast(unified_schema, batches, table);
}

// Casts every batch in `batches` to `schema` and assembles the results into
// `*table` via RecordBatchesToTable.
//
// Parameters:
//   schema   the target schema every batch is cast to. Taken by value to stay
//            in sync with the declaration in arrow_utils.h.
//   batches  the input record batches; null entries are skipped.
//   table    output pointer handed through to RecordBatchesToTable.
//
// Returns the first error reported by CastBatchToSchema, otherwise the status
// of RecordBatchesToTable.
Status RecordBatchesToTableWithCast(
    const std::shared_ptr<arrow::Schema> schema,
    const std::vector<std::shared_ptr<arrow::RecordBatch>>& batches,
    std::shared_ptr<arrow::Table>* table) {
  std::vector<std::shared_ptr<arrow::RecordBatch>> casted_batches;
  casted_batches.reserve(batches.size());
  for (auto const& batch : batches) {
    // TypeLoosen ignores null batches when computing the unified schema, so
    // they must be ignored here as well: CastBatchToSchema dereferences the
    // batch unconditionally.
    if (batch == nullptr) {
      continue;
    }
    std::shared_ptr<arrow::RecordBatch> casted;
    RETURN_ON_ERROR(CastBatchToSchema(batch, schema, casted));
    casted_batches.push_back(casted);
  }
  return RecordBatchesToTable(schema, casted_batches, table);
}

Status CombineRecordBatches(
const std::vector<std::shared_ptr<arrow::RecordBatch>>& batches,
std::shared_ptr<arrow::RecordBatch>* batch) {
Expand Down Expand Up @@ -873,6 +894,31 @@ Status TypeLoosen(const std::vector<std::shared_ptr<arrow::Schema>>& schemas,
return Status::OK();
}

// Computes a unified schema for a list of record batches.
//
// Gathers the schema of every non-null batch and forwards the collection to
// the schema-based TypeLoosen overload, which writes the result to `schema`.
Status TypeLoosen(
    const std::vector<std::shared_ptr<arrow::RecordBatch>>& batches,
    std::shared_ptr<arrow::Schema>& schema) {
  std::vector<std::shared_ptr<arrow::Schema>> batch_schemas;
  batch_schemas.reserve(batches.size());
  for (const auto& item : batches) {
    // Null entries carry no schema information; leave them out.
    if (item == nullptr) {
      continue;
    }
    batch_schemas.push_back(item->schema());
  }
  return TypeLoosen(batch_schemas, schema);
}

// Computes a unified schema for a list of tables.
//
// Gathers the schema of every non-null table and forwards the collection to
// the schema-based TypeLoosen overload, which writes the result to `schema`.
Status TypeLoosen(const std::vector<std::shared_ptr<arrow::Table>>& tables,
                  std::shared_ptr<arrow::Schema>& schema) {
  std::vector<std::shared_ptr<arrow::Schema>> table_schemas;
  table_schemas.reserve(tables.size());
  for (const auto& item : tables) {
    // Null entries carry no schema information; leave them out.
    if (item == nullptr) {
      continue;
    }
    table_schemas.push_back(item->schema());
  }
  return TypeLoosen(table_schemas, schema);
}

Status CastStringToBigString(const std::shared_ptr<arrow::Array>& in,
const std::shared_ptr<arrow::DataType>& to_type,
std::shared_ptr<arrow::Array>& out) {
Expand Down Expand Up @@ -925,6 +971,46 @@ Status GeneralCast(const std::shared_ptr<arrow::Array>& in,
return Status::OK();
}

// Casts the columns of `batch` so that the resulting record batch conforms to
// `schema`.
//
// Parameters:
//   batch   the record batch to convert; must not be null.
//   schema  the expected schema; must not be null and must have the same
//           number of fields as `batch` has columns.
//   out     receives the converted batch. When the batch's schema already
//           equals `schema`, `out` is set to `batch` itself (no copy).
//
// Columns whose type already matches are reused as-is. Otherwise the column is
// converted, trying in order: a general Arrow cast (when
// arrow::compute::CanCast allows it), utf8 -> large_utf8, and null -> any
// type.
//
// Returns an assertion error for null inputs or a column-count mismatch,
// Status::Invalid for an unsupported type pair, or the error reported by the
// underlying cast helper.
Status CastBatchToSchema(const std::shared_ptr<arrow::RecordBatch>& batch,
                         const std::shared_ptr<arrow::Schema>& schema,
                         std::shared_ptr<arrow::RecordBatch>& out) {
  RETURN_ON_ASSERT(batch != nullptr,
                   "The recordbatch to cast must not be null");
  RETURN_ON_ASSERT(schema != nullptr, "The expected schema must not be null");

  const auto& batch_schema = batch->schema();
  if (batch_schema->Equals(schema)) {
    out = batch;
    return Status::OK();
  }

  RETURN_ON_ASSERT(batch->num_columns() == schema->num_fields(),
                   "The schema of original recordbatch and expected schema is "
                   "not consistent");

  // Arrow's column/field accessors are indexed by `int`, so use `int` here
  // rather than narrowing a wider index on every call.
  const int num_columns = batch->num_columns();
  std::vector<std::shared_ptr<arrow::Array>> new_columns;
  new_columns.reserve(static_cast<size_t>(num_columns));
  for (int i = 0; i < num_columns; ++i) {
    const auto& from_type = batch_schema->field(i)->type();
    const auto& to_type = schema->field(i)->type();
    const auto column = batch->column(i);
    if (from_type->Equals(to_type)) {
      new_columns.push_back(column);
      continue;
    }

    // Named distinctly from the `out` parameter to avoid shadowing it.
    std::shared_ptr<arrow::Array> casted;
    if (arrow::compute::CanCast(*from_type, *to_type)) {
      RETURN_ON_ERROR(GeneralCast(column, to_type, casted));
    } else if (from_type->Equals(arrow::utf8()) &&
               to_type->Equals(arrow::large_utf8())) {
      RETURN_ON_ERROR(CastStringToBigString(column, to_type, casted));
    } else if (from_type->Equals(arrow::null())) {
      RETURN_ON_ERROR(CastNullToOthers(column, to_type, casted));
    } else {
      return Status::Invalid(
          "Unsupported cast: To type: " + to_type->ToString() +
          " vs. origin type: " + from_type->ToString());
    }
    new_columns.push_back(casted);
  }
  out = arrow::RecordBatch::Make(schema, batch->num_rows(), new_columns);
  return Status::OK();
}

Status CastTableToSchema(const std::shared_ptr<arrow::Table>& table,
const std::shared_ptr<arrow::Schema>& schema,
std::shared_ptr<arrow::Table>& out) {
Expand Down
20 changes: 20 additions & 0 deletions modules/basic/ds/arrow_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,15 @@ Status RecordBatchesToTable(
const std::vector<std::shared_ptr<arrow::RecordBatch>>& batches,
std::shared_ptr<arrow::Table>* table);

/**
 * @brief Build a table from record batches whose schemas may differ.
 *
 * The target schema is derived from the batches themselves (via TypeLoosen)
 * and every batch is cast to it before the table is assembled.
 *
 * @param batches The input record batches.
 * @param table Output: the resulting table.
 */
Status RecordBatchesToTableWithCast(
const std::vector<std::shared_ptr<arrow::RecordBatch>>& batches,
std::shared_ptr<arrow::Table>* table);

/**
 * @brief Build a table from record batches, casting each batch to the given
 * schema first.
 *
 * Each batch is converted with CastBatchToSchema and the converted batches
 * are then passed to RecordBatchesToTable together with `schema`.
 *
 * @param schema The schema every batch is cast to.
 * @param batches The input record batches.
 * @param table Output: the resulting table.
 */
Status RecordBatchesToTableWithCast(
const std::shared_ptr<arrow::Schema> schema,
const std::vector<std::shared_ptr<arrow::RecordBatch>>& batches,
std::shared_ptr<arrow::Table>* table);

Status CombineRecordBatches(
const std::vector<std::shared_ptr<arrow::RecordBatch>>& batches,
std::shared_ptr<arrow::RecordBatch>* batch);
Expand Down Expand Up @@ -348,6 +357,13 @@ const void* get_arrow_array_data(std::shared_ptr<arrow::Array> const& array);
Status TypeLoosen(const std::vector<std::shared_ptr<arrow::Schema>>& schemas,
std::shared_ptr<arrow::Schema>& schema);

/**
 * @brief Compute a unified schema from the schemas of the given record
 * batches. Null batches are ignored.
 *
 * @param batches The record batches whose schemas are unified.
 * @param schema Output: the unified schema.
 */
Status TypeLoosen(
const std::vector<std::shared_ptr<arrow::RecordBatch>>& batches,
std::shared_ptr<arrow::Schema>& schema);

/**
 * @brief Compute a unified schema from the schemas of the given tables. Null
 * tables are ignored.
 *
 * @param tables The tables whose schemas are unified.
 * @param schema Output: the unified schema.
 */
Status TypeLoosen(const std::vector<std::shared_ptr<arrow::Table>>& tables,
std::shared_ptr<arrow::Schema>& schema);

Status CastStringToBigString(const std::shared_ptr<arrow::Array>& in,
const std::shared_ptr<arrow::DataType>& to_type,
std::shared_ptr<arrow::Array>& out);
Expand All @@ -360,6 +376,10 @@ Status GeneralCast(const std::shared_ptr<arrow::Array>& in,
const std::shared_ptr<arrow::DataType>& to_type,
std::shared_ptr<arrow::Array>& out);

/**
 * @brief Cast the columns of a record batch so that it conforms to the given
 * schema.
 *
 * The batch is returned unchanged when its schema already equals `schema`.
 * The number of columns must match the number of fields in `schema`; a type
 * pair that cannot be cast yields an invalid status.
 *
 * @param batch The record batch to cast.
 * @param schema The expected schema.
 * @param out Output: the record batch conforming to `schema`.
 */
Status CastBatchToSchema(const std::shared_ptr<arrow::RecordBatch>& batch,
const std::shared_ptr<arrow::Schema>& schema,
std::shared_ptr<arrow::RecordBatch>& out);

Status CastTableToSchema(const std::shared_ptr<arrow::Table>& table,
const std::shared_ptr<arrow::Schema>& schema,
std::shared_ptr<arrow::Table>& out);
Expand Down
6 changes: 3 additions & 3 deletions modules/graph/loader/arrow_fragment_loader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ Status ReadTableFromVineyardDataFrame(Client& client,
VLOG(10) << "read table from vineyard: total rows = " << 0;
return Status::OK();
} else {
auto status = RecordBatchesToTable(batches, &table);
auto status = RecordBatchesToTableWithCast(batches, &table);
if (status.ok()) {
VLOG(10) << "read table from vineyard: total rows = "
<< table->num_rows();
Expand Down Expand Up @@ -389,7 +389,7 @@ GatherETables(Client& client,
if (subgroup.second.empty()) {
table = nullptr; // no tables at current worker
} else {
VY_OK_OR_RAISE(RecordBatchesToTable(subgroup.second, &table));
VY_OK_OR_RAISE(RecordBatchesToTableWithCast(subgroup.second, &table));
}
subtables.emplace_back(table);
}
Expand Down Expand Up @@ -448,7 +448,7 @@ boost::leaf::result<std::vector<std::shared_ptr<arrow::Table>>> GatherVTables(
if (group.second.empty()) {
table = nullptr; // no tables at current worker
} else {
VY_OK_OR_RAISE(RecordBatchesToTable(group.second, &table));
VY_OK_OR_RAISE(RecordBatchesToTableWithCast(group.second, &table));
}
tables.emplace_back(table);
}
Expand Down

0 comments on commit 70d32f2

Please sign in to comment.