Skip to content

Commit

Permalink
simplify asyncio_runnable
Browse files Browse the repository at this point in the history
  • Loading branch information
cwharris committed Nov 1, 2023
1 parent 521abc2 commit a0d136e
Showing 1 changed file with 11 additions and 17 deletions.
28 changes: 11 additions & 17 deletions python/mrc/_pymrc/include/pymrc/asyncio_runnable.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,23 +167,22 @@ class CoroutineRunnableSource : public mrc::node::WritableAcceptor<T>,
public mrc::node::SourceChannelOwner<T>
{
protected:
CoroutineRunnableSource()
CoroutineRunnableSource() :
m_writer([this](T&& value) {
return this->get_writable_edge()->await_write(std::move(value));
})
{
// Set the default channel
this->set_channel(std::make_unique<mrc::channel::BufferedChannel<T>>());

m_writer = std::make_shared<BoostFutureWriter<T>>([this](T&& value) {
return this->get_writable_edge()->await_write(std::move(value));
});
}

auto get_writable_receiver() -> std::shared_ptr<BoostFutureWriter<T>>
coroutines::Task<mrc::channel::Status> async_write(T&& value)
{
return m_writer;
co_return co_await m_writer.async_write(std::move(value));
}

private:
std::shared_ptr<BoostFutureWriter<T>> m_writer;
BoostFutureWriter<T> m_writer;
};

template <typename InputT, typename OutputT>
Expand All @@ -205,7 +204,6 @@ class AsyncioRunnable : public CoroutineRunnableSink<InputT>,
coroutines::Task<> main_task(std::shared_ptr<mrc::coroutines::Scheduler> scheduler);

coroutines::Task<> process_one(InputT&& value,
std::shared_ptr<BoostFutureWriter<OutputT>> writer,
task_buffer_t& task_buffer,
std::shared_ptr<mrc::coroutines::Scheduler> on,
ExceptionCatcher& catcher);
Expand Down Expand Up @@ -236,18 +234,15 @@ void AsyncioRunnable<InputT, OutputT>::run(mrc::runnable::Context& ctx)
template <typename InputT, typename OutputT>
coroutines::Task<> AsyncioRunnable<InputT, OutputT>::main_task(std::shared_ptr<mrc::coroutines::Scheduler> scheduler)
{
// Get the generator and receiver
auto output_receiver = CoroutineRunnableSource<OutputT>::get_writable_receiver();

// Create the task buffer to limit the number of running tasks
task_buffer_t task_buffer{{.capacity = m_concurrency}};

coroutines::TaskContainer outstanding_tasks(scheduler);

ExceptionCatcher catcher{};

while (not m_stop_source.stop_requested() and not catcher.has_exception()) {
while (not m_stop_source.stop_requested() and not catcher.has_exception())
{
InputT data;

auto read_status = co_await this->async_read(data);
Expand All @@ -260,7 +255,7 @@ coroutines::Task<> AsyncioRunnable<InputT, OutputT>::main_task(std::shared_ptr<m
// Wait for an available slot in the task buffer
co_await task_buffer.write(0);

outstanding_tasks.start(this->process_one(std::move(data), output_receiver, task_buffer, scheduler, catcher));
outstanding_tasks.start(this->process_one(std::move(data), task_buffer, scheduler, catcher));
}

// Close the buffer
Expand All @@ -276,7 +271,6 @@ coroutines::Task<> AsyncioRunnable<InputT, OutputT>::main_task(std::shared_ptr<m

template <typename InputT, typename OutputT>
coroutines::Task<> AsyncioRunnable<InputT, OutputT>::process_one(InputT&& value,
std::shared_ptr<BoostFutureWriter<OutputT>> writer,
task_buffer_t& task_buffer,
std::shared_ptr<mrc::coroutines::Scheduler> on,
ExceptionCatcher& catcher)
Expand All @@ -295,7 +289,7 @@ coroutines::Task<> AsyncioRunnable<InputT, OutputT>::process_one(InputT&& value,
// Weird bug, cant directly move the value into the async_write call
auto data = std::move(*iter);

co_await writer->async_write(std::move(data));
co_await this->async_write(std::move(data));

// Advance the iterator
co_await ++iter;
Expand Down

0 comments on commit a0d136e

Please sign in to comment.