From 26fae58982da9c5b57bde290b4efb05a24f0f942 Mon Sep 17 00:00:00 2001 From: Christopher Harris Date: Fri, 27 Oct 2023 01:05:25 +0000 Subject: [PATCH] add asynciorunnable --- cpp/mrc/include/mrc/coroutines/scheduler.hpp | 18 - cpp/mrc/src/public/coroutines/scheduler.cpp | 12 +- .../_pymrc/include/pymrc/asyncio_runnable.hpp | 386 ++++++++++++++++++ .../include/pymrc/asyncio_scheduler.hpp | 126 ++++++ 4 files changed, 513 insertions(+), 29 deletions(-) create mode 100644 python/mrc/_pymrc/include/pymrc/asyncio_runnable.hpp create mode 100644 python/mrc/_pymrc/include/pymrc/asyncio_scheduler.hpp diff --git a/cpp/mrc/include/mrc/coroutines/scheduler.hpp b/cpp/mrc/include/mrc/coroutines/scheduler.hpp index 1b0aac502..f799db83a 100644 --- a/cpp/mrc/include/mrc/coroutines/scheduler.hpp +++ b/cpp/mrc/include/mrc/coroutines/scheduler.hpp @@ -25,13 +25,8 @@ #include #include -// IWYU thinks this is needed, but it's not -// IWYU pragma: no_include "mrc/coroutines/task_container.hpp" - namespace mrc::coroutines { -class TaskContainer; // IWYU pragma: keep - /** * @brief Scheduler base class * @@ -75,9 +70,6 @@ class Scheduler : public std::enable_shared_from_this */ [[nodiscard]] virtual auto schedule() -> Operation; - // Enqueues a message without waiting for it. Must return void since the caller will not get the return value - virtual void schedule(Task&& task); - /** * Schedules any coroutine handle that is ready to be resumed. * @param handle The coroutine handle to schedule. @@ -103,13 +95,6 @@ class Scheduler : public std::enable_shared_from_this protected: virtual auto on_thread_start(std::size_t) -> void; - /** - * @brief Get the task container object - * - * @return TaskContainer& - */ - TaskContainer& get_task_container() const; - private: /** * @brief When co_await schedule() is called, this function will be executed by the awaiter. Each scheduler @@ -123,9 +108,6 @@ class Scheduler : public std::enable_shared_from_this mutable std::mutex m_mutex; - // Maintains the lifetime of fire-and-forget tasks scheduled with schedule(Task&& task) - std::unique_ptr m_task_container; - thread_local static Scheduler* m_thread_local_scheduler; thread_local static std::size_t m_thread_id; }; diff --git a/cpp/mrc/src/public/coroutines/scheduler.cpp b/cpp/mrc/src/public/coroutines/scheduler.cpp index af2e70294..f4f7776ac 100644 --- a/cpp/mrc/src/public/coroutines/scheduler.cpp +++ b/cpp/mrc/src/public/coroutines/scheduler.cpp @@ -39,18 +39,13 @@ std::coroutine_handle<> Scheduler::Operation::await_suspend(std::coroutine_handl return m_scheduler.schedule_operation(this); } -Scheduler::Scheduler() : m_task_container(new TaskContainer(*this)) {} +Scheduler::Scheduler() = default; auto Scheduler::schedule() -> Operation { return Operation{*this}; } -void Scheduler::schedule(Task&& task) -{ - return m_task_container->start(std::move(task)); -} - auto Scheduler::yield() -> Operation { return schedule(); @@ -77,9 +72,4 @@ auto Scheduler::on_thread_start(std::size_t thread_id) -> void m_thread_local_scheduler = this; } -TaskContainer& Scheduler::get_task_container() const -{ - return *m_task_container; -} - } // namespace mrc::coroutines diff --git a/python/mrc/_pymrc/include/pymrc/asyncio_runnable.hpp b/python/mrc/_pymrc/include/pymrc/asyncio_runnable.hpp new file mode 100644 index 000000000..88cd91773 --- /dev/null +++ b/python/mrc/_pymrc/include/pymrc/asyncio_runnable.hpp @@ -0,0 +1,386 @@ +#pragma once + +#include "pymrc/asyncio_scheduler.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace mrc::pymrc { + +template +using Task = mrc::coroutines::Task; + +class ExceptionCatcher +{ + public: + void set_exception(std::exception_ptr ex) + { + auto lock = std::lock_guard(m_mutex); + m_exceptions.push(ex); + } + + bool has_exception() + { + auto lock = std::lock_guard(m_mutex); + return not m_exceptions.empty(); + } + + void rethrow_next_exception() + { + auto lock = std::lock_guard(m_mutex); + + if (m_exceptions.empty()) + { + return; + } + + auto ex = m_exceptions.front(); + m_exceptions.pop(); + + std::rethrow_exception(ex); + } + + private: + std::mutex m_mutex{}; + std::queue m_exceptions{}; +}; + +template +class BoostFutureAwaiter +{ + class Awaiter; + + public: + BoostFutureAwaiter(std::function fn) : m_fn(std::move(fn)) {} + + template + auto operator()(ArgsT&&... args) -> Awaiter + { + // Make a copy of m_fn here so we can call this operator again + return Awaiter(m_fn, std::forward(args)...); + } + + private: + class Awaiter + { + public: + using return_t = typename std::function::result_type; + + template + Awaiter(std::function fn, ArgsT&&... args) + { + m_future = boost::fibers::async(boost::fibers::launch::post, fn, std::forward(args)...); + } + + bool await_ready() noexcept + { + return false; + } + + bool await_suspend(std::coroutine_handle<> continuation) noexcept + { + // Launch a new fiber that waits on the future and then resumes the coroutine + boost::fibers::async( + boost::fibers::launch::post, + [this](std::coroutine_handle<> continuation) { + // Wait on the future + m_future.wait(); + + // Resume the coroutine + continuation.resume(); + }, + std::move(continuation)); + + return true; + } + + auto await_resume() noexcept + { + return m_future.get(); + } + + private: + boost::fibers::future m_future; + std::function)> m_inner_fn; + }; + + std::function m_fn; +}; + +template +class IReadable +{ + public: + virtual ~IReadable() = default; + virtual Task async_read(T& value) = 0; +}; + +template +class BoostFutureReader : public IReadable +{ + public: + template + BoostFutureReader(FuncT&& fn) : m_awaiter(std::forward(fn)) + {} + + Task async_read(T& value) override + { + co_return co_await m_awaiter(std::ref(value)); + } + + private: + BoostFutureAwaiter m_awaiter; +}; + +template +class IWritable +{ + public: + virtual ~IWritable() = default; + virtual Task async_write(T&& value) = 0; +}; + +template +class BoostFutureWriter : public IWritable +{ + public: + template + BoostFutureWriter(FuncT&& fn) : m_awaiter(std::forward(fn)) + {} + + Task async_write(T&& value) override + { + co_return co_await m_awaiter(std::move(value)); + } + + private: + BoostFutureAwaiter m_awaiter; +}; + +template +class CoroutineRunnableSink : public mrc::node::WritableProvider, + public mrc::node::ReadableAcceptor, + public mrc::node::SinkChannelOwner +{ + protected: + CoroutineRunnableSink() + { + // Set the default channel + this->set_channel(std::make_unique>()); + } + + auto build_readable_generator(std::stop_token stop_token) -> mrc::coroutines::AsyncGenerator + { + auto read_awaiter = BoostFutureReader([this](T& value) { + return this->get_readable_edge()->await_read(value); + }); + + while (!stop_token.stop_requested()) + { + T value; + + // Pull a message off of the upstream channel + auto status = co_await read_awaiter.async_read(std::ref(value)); + + if (status != mrc::channel::Status::success) + { + break; + } + + co_yield std::move(value); + } + + co_return; + } +}; + +template +class CoroutineRunnableSource : public mrc::node::WritableAcceptor, + public mrc::node::ReadableProvider, + public mrc::node::SourceChannelOwner +{ + protected: + CoroutineRunnableSource() + { + // Set the default channel + this->set_channel(std::make_unique>()); + } + + // auto build_readable_generator(std::stop_token stop_token) + // -> mrc::coroutines::AsyncGenerator + // { + // while (!stop_token.stop_requested()) + // { + // co_yield mrc::coroutines::detail::VoidValue{}; + // } + + // co_return; + // } + + auto build_writable_receiver() -> std::shared_ptr> + { + return std::make_shared>([this](T&& value) { + return this->get_writable_edge()->await_write(std::move(value)); + }); + } +}; + +template +class AsyncioRunnable : public CoroutineRunnableSink, + public CoroutineRunnableSource, + public mrc::runnable::RunnableWithContext<> +{ + using state_t = mrc::runnable::Runnable::State; + using task_buffer_t = mrc::coroutines::ClosableRingBuffer; + + public: + AsyncioRunnable(size_t concurrency = 8) : m_concurrency(concurrency){}; + ~AsyncioRunnable() override = default; + + private: + void run(mrc::runnable::Context& ctx) override; + void on_state_update(const state_t& state) final; + + Task main_task(std::shared_ptr scheduler); + + Task process_one(InputT&& value, + std::shared_ptr> writer, + task_buffer_t& task_buffer, + std::shared_ptr on, + ExceptionCatcher& catcher); + + virtual mrc::coroutines::AsyncGenerator on_data(InputT&& value) = 0; + + std::stop_source m_stop_source; + + size_t m_concurrency{8}; +}; + +template +void AsyncioRunnable::run(mrc::runnable::Context& ctx) +{ + // auto& scheduler = ctx.scheduler(); + + // TODO(MDD): Eventually we should get this from the context object. For now, just create it directly + auto scheduler = std::make_shared(m_concurrency); + + // Now use the scheduler to run the main task until it is complete + scheduler->run_until_complete(this->main_task(scheduler)); + + // Need to drop the output edges + mrc::node::SourceProperties::release_edge_connection(); + mrc::node::SinkProperties::release_edge_connection(); +} + +template +Task AsyncioRunnable::main_task(std::shared_ptr scheduler) +{ + // Get the generator and receiver + auto input_generator = CoroutineRunnableSink::build_readable_generator(m_stop_source.get_token()); + auto output_receiver = CoroutineRunnableSource::build_writable_receiver(); + + // Create the task buffer to limit the number of running tasks + task_buffer_t task_buffer{{.capacity = m_concurrency}}; + + size_t i = 0; + + auto iter = co_await input_generator.begin(); + + coroutines::TaskContainer outstanding_tasks(scheduler); + + ExceptionCatcher catcher{}; + + while (not catcher.has_exception() and iter != input_generator.end()) + { + // Weird bug, cant directly move the value into the process_one call + auto data = std::move(*iter); + + // Wait for an available slot in the task buffer + co_await task_buffer.write(i); + + outstanding_tasks.start(this->process_one(std::move(data), output_receiver, task_buffer, scheduler, catcher)); + + // Advance the iterator + co_await ++iter; + ++i; + } + + // Close the buffer + task_buffer.close(); + + // Now block until all tasks are complete + co_await task_buffer.completed(); + + co_await outstanding_tasks.garbage_collect_and_yield_until_empty(); + + catcher.rethrow_next_exception(); +} + +template +Task AsyncioRunnable::process_one(InputT&& value, + std::shared_ptr> writer, + task_buffer_t& task_buffer, + std::shared_ptr on, + ExceptionCatcher& catcher) +{ + co_await on->yield(); + + try + { + // Call the on_data function + auto on_data_gen = this->on_data(std::move(value)); + + auto iter = co_await on_data_gen.begin(); + + while (iter != on_data_gen.end()) + { + // Weird bug, cant directly move the value into the async_write call + auto data = std::move(*iter); + + co_await writer->async_write(std::move(data)); + + // Advance the iterator + co_await ++iter; + } + } catch (...) + { + // TODO(cwharris): communicate error back to the runnable's main main task + catcher.set_exception(std::current_exception()); + } + + // Return the slot to the task buffer + co_await task_buffer.read(); +} + +template +void AsyncioRunnable::on_state_update(const state_t& state) +{ + switch (state) + { + case state_t::Stop: + // Do nothing, we wait for the upstream channel to return closed + // m_stop_source.request_stop(); + break; + + case state_t::Kill: + + m_stop_source.request_stop(); + break; + + default: + break; + } +} + +} // namespace mrc::pymrc diff --git a/python/mrc/_pymrc/include/pymrc/asyncio_scheduler.hpp b/python/mrc/_pymrc/include/pymrc/asyncio_scheduler.hpp new file mode 100644 index 000000000..26e30fd58 --- /dev/null +++ b/python/mrc/_pymrc/include/pymrc/asyncio_scheduler.hpp @@ -0,0 +1,126 @@ +#pragma once + +#include "pymrc/coro.hpp" +#include "pymrc/utilities/acquire_gil.hpp" + +#include +#include +#include +#include +#include + +namespace py = pybind11; + +namespace mrc::pymrc { +class AsyncioScheduler : public mrc::coroutines::Scheduler +{ + public: + AsyncioScheduler(size_t concurrency) {} + + std::string description() const override + { + return "AsyncioScheduler"; + } + + void resume(std::coroutine_handle<> coroutine) override + { + if (coroutine.done()) + { + LOG(WARNING) << "AsyncioScheduler::resume() > Attempted to resume a completed coroutine"; + return; + } + + py::gil_scoped_acquire gil; + + auto& loop = this->get_loop(); + + // TODO(MDD): Check whether or not we need thread safe version + loop.attr("call_soon_threadsafe")(py::cpp_function([this, handle = std::move(coroutine)]() { + if (handle.done()) + { + LOG(WARNING) << "AsyncioScheduler::resume() > Attempted to resume a completed coroutine"; + return; + } + + py::gil_scoped_release nogil; + + handle.resume(); + })); + } + + mrc::pymrc::PyHolder& init_loop() + { + CHECK_EQ(PyGILState_Check(), 1) << "Must have the GIL when calling AsyncioScheduler::init_loop()"; + + std::unique_lock lock(m_mutex); + + if (m_loop) + { + return m_loop; + } + + auto asyncio_mod = py::module_::import("asyncio"); + + py::object loop; + + try + { + // Otherwise check if one is already allocated + loop = asyncio_mod.attr("get_running_loop")(); + } catch (std::runtime_error&) + { + // Need to create a loop + LOG(INFO) << "AsyncioScheduler::run() > Creating new event loop"; + + // Gets (or more likely, creates) an event loop and runs it forever until stop is called + loop = asyncio_mod.attr("new_event_loop")(); + + // Set the event loop as the current event loop + asyncio_mod.attr("set_event_loop")(loop); + } + + m_loop = std::move(loop); + + return m_loop; + } + + // Runs the task until its complete + void run_until_complete(coroutines::Task<>&& task) + { + mrc::pymrc::AcquireGIL gil; + + auto& loop = this->init_loop(); + + LOG(INFO) << "AsyncioScheduler::run() > Calling run_until_complete() on main_task()"; + + // Use the BoostFibersMainPyAwaitable to allow fibers to be progressed + loop.attr("run_until_complete")(mrc::pymrc::coro::BoostFibersMainPyAwaitable(std::move(task))); + } + + private: + std::coroutine_handle<> schedule_operation(Operation* operation) override + { + this->resume(std::move(operation->m_awaiting_coroutine)); + + return std::noop_coroutine(); + } + + mrc::pymrc::PyHolder& get_loop() + { + if (!m_loop) + { + throw std::runtime_error("Must call init_loop() before get_loop()"); + } + + // TODO(MDD): Check that we are on the same thread as the loop + return m_loop; + } + + std::mutex m_mutex; + + std::atomic_size_t m_outstanding{0}; + + mrc::pymrc::PyHolder m_loop; +}; + +} // namespace mrc::pymrc