Skip to content

Commit

Permalink
add sempahore to asyncio_runnable
Browse files Browse the repository at this point in the history
  • Loading branch information
cwharris committed Nov 3, 2023
1 parent 62e1834 commit 6ddadf7
Showing 1 changed file with 10 additions and 18 deletions.
28 changes: 10 additions & 18 deletions python/mrc/_pymrc/include/pymrc/asyncio_runnable.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,6 @@ class AsyncioRunnable : public AsyncSink<InputT>,
using task_buffer_t = mrc::coroutines::ClosableRingBuffer<size_t>;

public:
AsyncioRunnable(size_t concurrency = 8) : m_concurrency(concurrency){};
~AsyncioRunnable() override = default;

private:
Expand All @@ -199,7 +198,6 @@ class AsyncioRunnable : public AsyncSink<InputT>,
* @brief The per-value coroutine run asynchronously alongside other calls.
*/
coroutines::Task<> process_one(InputT value,
task_buffer_t& task_buffer,
std::shared_ptr<mrc::coroutines::Scheduler> on,
ExceptionCatcher& catcher);

Expand All @@ -211,7 +209,11 @@ class AsyncioRunnable : public AsyncSink<InputT>,

std::stop_source m_stop_source;

size_t m_concurrency{8};
/**
* @brief A semaphore used to control the number of outstanding operations. Acquire one before
* a beginning a task, and release it before finishing.
*/
std::counting_semaphore<8> m_task_tickets{8};
};

template <typename InputT, typename OutputT>
Expand Down Expand Up @@ -279,15 +281,16 @@ void AsyncioRunnable<InputT, OutputT>::run(mrc::runnable::Context& ctx)
template <typename InputT, typename OutputT>
coroutines::Task<> AsyncioRunnable<InputT, OutputT>::main_task(std::shared_ptr<mrc::coroutines::Scheduler> scheduler)
{
// Create the task buffer to limit the number of running tasks
task_buffer_t task_buffer{{.capacity = m_concurrency}};
co_await scheduler->yield();

coroutines::TaskContainer outstanding_tasks(scheduler);

ExceptionCatcher catcher{};

while (not m_stop_source.stop_requested() and not catcher.has_exception())
{
m_task_tickets.acquire();

InputT data;

auto read_status = co_await this->read_async(data);
Expand All @@ -297,26 +300,16 @@ coroutines::Task<> AsyncioRunnable<InputT, OutputT>::main_task(std::shared_ptr<m
break;
}

// Wait for an available slot in the task buffer
co_await task_buffer.write(0);

outstanding_tasks.start(this->process_one(std::move(data), task_buffer, scheduler, catcher));
outstanding_tasks.start(this->process_one(std::move(data), scheduler, catcher));
}

// Close the buffer
task_buffer.close();

// Now block until all tasks are complete
co_await task_buffer.completed();

co_await outstanding_tasks.garbage_collect_and_yield_until_empty();

catcher.rethrow_next_exception();
}

template <typename InputT, typename OutputT>
coroutines::Task<> AsyncioRunnable<InputT, OutputT>::process_one(InputT value,
task_buffer_t& task_buffer,
std::shared_ptr<mrc::coroutines::Scheduler> on,
ExceptionCatcher& catcher)
{
Expand Down Expand Up @@ -344,8 +337,7 @@ coroutines::Task<> AsyncioRunnable<InputT, OutputT>::process_one(InputT value,
catcher.push_exception(std::current_exception());
}

// Return the slot to the task buffer
co_await task_buffer.read();
m_task_tickets.release();
}

template <typename InputT, typename OutputT>
Expand Down

0 comments on commit 6ddadf7

Please sign in to comment.