Skip to content

Commit

Permalink
updates from feedback and mark experimental
Browse files Browse the repository at this point in the history
  • Loading branch information
zeroshade committed Nov 11, 2024
1 parent 2dc2a13 commit 6a0aa0a
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 52 deletions.
99 changes: 50 additions & 49 deletions cpp/src/arrow/c/bridge.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2520,10 +2520,45 @@ namespace {

class AsyncRecordBatchIterator {
public:
struct TaskWithMetadata {
ArrowAsyncTask task_;
std::shared_ptr<KeyValueMetadata> metadata_;
};

struct State {
State(uint64_t queue_size, const DeviceMemoryMapper mapper)
State(uint64_t queue_size, DeviceMemoryMapper mapper)
: queue_size_{queue_size}, mapper_{std::move(mapper)} {}

Result<RecordBatchWithMetadata> next() {
TaskWithMetadata task;
{
std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock,
[&] { return !error_.ok() || !batches_.empty() || end_of_stream_; });
if (!error_.ok()) {
return error_;
}

if (batches_.empty() && end_of_stream_) {
return IterationEnd<RecordBatchWithMetadata>();
}

task = std::move(batches_.front());
batches_.pop();
}

producer_->request(producer_, 1);
ArrowDeviceArray out;
if (task.task_.extract_data(&task.task_, &out) != 0) {
std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [&] { return !error_.ok(); });
return error_;
}

ARROW_ASSIGN_OR_RAISE(auto batch, ImportDeviceRecordBatch(&out, schema_, mapper_));
return RecordBatchWithMetadata{std::move(batch), std::move(task.metadata_)};
}

const uint64_t queue_size_;
const DeviceMemoryMapper mapper_;
ArrowAsyncProducer* producer_;
Expand All @@ -2532,7 +2567,7 @@ class AsyncRecordBatchIterator {
std::mutex mutex_;
std::shared_ptr<Schema> schema_;
std::condition_variable cv_;
std::queue<std::pair<ArrowAsyncTask, std::shared_ptr<KeyValueMetadata>>> batches_;
std::queue<TaskWithMetadata> batches_;
bool end_of_stream_ = false;
Status error_{Status::OK()};
};
Expand All @@ -2547,38 +2582,7 @@ class AsyncRecordBatchIterator {

DeviceAllocationType device_type() const { return state_->device_type_; }

Result<RecordBatchWithMetadata> Next() {
std::pair<ArrowAsyncTask, std::shared_ptr<KeyValueMetadata>> task;
{
std::unique_lock<std::mutex> lock(state_->mutex_);
state_->cv_.wait(lock, [&] {
return !state_->error_.ok() || !state_->batches_.empty() ||
state_->end_of_stream_;
});
if (!state_->error_.ok()) {
return state_->error_;
}

if (state_->batches_.empty() && state_->end_of_stream_) {
return RecordBatchWithMetadata{nullptr, nullptr};
}

task = state_->batches_.front();
state_->batches_.pop();
}

state_->producer_->request(state_->producer_, 1);
ArrowDeviceArray out;
if (task.first.extract_data(&task.first, &out) != 0) {
std::unique_lock<std::mutex> lock(state_->mutex_);
state_->cv_.wait(lock, [&] { return !state_->error_.ok(); });
return state_->error_;
}

ARROW_ASSIGN_OR_RAISE(
auto batch, ImportDeviceRecordBatch(&out, state_->schema_, state_->mapper_));
return RecordBatchWithMetadata{std::move(batch), std::move(task.second)};
}
Result<RecordBatchWithMetadata> Next() { return state_->next(); }

static Future<std::shared_ptr<AsyncRecordBatchIterator::State>> Make(
AsyncRecordBatchIterator& iterator, struct ArrowAsyncDeviceStreamHandler* handler) {
Expand Down Expand Up @@ -2618,9 +2622,8 @@ class AsyncRecordBatchIterator {
private_data->fut_iterator_.MarkFinished(maybe_schema.status());
return EINVAL;
}

auto schema = maybe_schema.MoveValueUnsafe();
private_data->state_->schema_ = schema;

private_data->state_->schema_ = maybe_schema.MoveValueUnsafe();
private_data->fut_iterator_.MarkFinished(private_data->state_);
self->producer->request(self->producer,
static_cast<int64_t>(private_data->state_->queue_size_));
Expand All @@ -2643,16 +2646,16 @@ class AsyncRecordBatchIterator {
if (metadata != nullptr) {
auto maybe_decoded = DecodeMetadata(metadata);
if (!maybe_decoded.ok()) {
private_data->state_->error_ = maybe_decoded.status();
private_data->state_->error_ = std::move(maybe_decoded).status();
private_data->state_->cv_.notify_one();
return EINVAL;
}

kvmetadata = maybe_decoded->metadata;
kvmetadata = std::move(maybe_decoded->metadata);
}

std::unique_lock<std::mutex> lock(private_data->state_->mutex_);
private_data->state_->batches_.emplace(*task, std::move(kvmetadata));
private_data->state_->batches_.push({*task, std::move(kvmetadata)});
lock.unlock();
private_data->state_->cv_.notify_one();
return 0;
Expand Down Expand Up @@ -2680,14 +2683,13 @@ class AsyncRecordBatchIterator {
}

std::unique_lock<std::mutex> lock(private_data->state_->mutex_);
private_data->state_->error_ = error;
private_data->state_->error_ = std::move(error);
lock.unlock();
private_data->state_->cv_.notify_one();
}

static void release(ArrowAsyncDeviceStreamHandler* self) {
auto* private_data = reinterpret_cast<PrivateData*>(self->private_data);
delete private_data;
delete reinterpret_cast<PrivateData*>(self->private_data);
}

std::shared_ptr<State> state_;
Expand Down Expand Up @@ -2781,7 +2783,7 @@ struct AsyncProducer {
}

static int extract_data(struct ArrowAsyncTask* task, struct ArrowDeviceArray* out) {
auto private_data = reinterpret_cast<PrivateTaskData*>(task->private_data);
std::unique_ptr<PrivateTaskData> private_data{reinterpret_cast<PrivateTaskData*>(task->private_data)};
int ret = 0;
if (out != nullptr) {
auto status = ExportDeviceRecordBatch(*private_data->record_,
Expand All @@ -2791,8 +2793,7 @@ struct AsyncProducer {
private_data->producer_->error_ = status;
}
}

delete private_data;

return ret;
}

Expand Down Expand Up @@ -2841,16 +2842,16 @@ Future<> ExportAsyncRecordBatchReader(

return VisitAsyncGenerator(generator, AsyncProducer{device_type, &c_schema, handler})
.Then(
[handler]() -> Future<> {
[handler]() -> Status {
int status = handler->on_next_task(handler, nullptr, nullptr);
handler->release(handler);
if (status != 0) {
return Status::UnknownError("Received error from handler::on_next_task ",
status);
}
return Future<>::MakeFinished();
return Status::OK();
},
[handler](const Status status) -> Future<> {
[handler](const Status status) -> Status {
handler->on_error(handler, EINVAL, status.message().c_str(), nullptr);
handler->release(handler);
return status;
Expand Down
6 changes: 3 additions & 3 deletions cpp/src/arrow/c/bridge.h
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ Result<std::shared_ptr<ChunkedArray>> ImportDeviceChunkedArray(
///
/// @{

/// \brief AsyncErrorDetail is a StatusDetail that contains an error code and message
/// \brief EXPERIMENTAL: AsyncErrorDetail is a StatusDetail that contains an error code and message
/// from an asynchronous operation.
class AsyncErrorDetail : public StatusDetail {
public:
Expand Down Expand Up @@ -444,7 +444,7 @@ namespace internal {
class Executor;
}

/// \brief Create an AsyncRecordBatchReader and populate a corresponding handler to pass
/// \brief EXPERIMENTAL: Create an AsyncRecordBatchReader and populate a corresponding handler to pass
/// to a producer
///
/// The ArrowAsyncDeviceStreamHandler struct is intended to have its callbacks populated
Expand All @@ -464,7 +464,7 @@ Future<AsyncRecordBatchGenerator> CreateAsyncDeviceStreamHandler(
struct ArrowAsyncDeviceStreamHandler* handler, internal::Executor* executor,
uint64_t queue_size = 5, DeviceMemoryMapper mapper = DefaultDeviceMemoryMapper);

/// \brief Export an AsyncGenerator of record batches using a provided handler
/// \brief EXPERIMENTAL: Export an AsyncGenerator of record batches using a provided handler
///
/// This function calls the callbacks on the consumer-provided async handler as record
/// batches become available from the AsyncGenerator which is provided. It will first call
Expand Down

0 comments on commit 6a0aa0a

Please sign in to comment.