From 2b2f26eb094dcad6c1dcaf934fcf028fb44f9e8b Mon Sep 17 00:00:00 2001 From: Christopher Harris Date: Wed, 1 Nov 2023 18:29:11 +0000 Subject: [PATCH] fix test_asyncio_runnable test class to support multiple test cases --- .../_pymrc/tests/test_asyncio_runnable.cpp | 97 ++++++++----------- 1 file changed, 39 insertions(+), 58 deletions(-) diff --git a/python/mrc/_pymrc/tests/test_asyncio_runnable.cpp b/python/mrc/_pymrc/tests/test_asyncio_runnable.cpp index 265876f32..5d15a9a3c 100644 --- a/python/mrc/_pymrc/tests/test_asyncio_runnable.cpp +++ b/python/mrc/_pymrc/tests/test_asyncio_runnable.cpp @@ -52,65 +52,46 @@ namespace pymrc = mrc::pymrc; using namespace std::string_literals; using namespace py::literals; -PYMRC_TEST_CLASS(AsyncioRunnable); - -class MyAsyncioRunnable : public pymrc::AsyncioRunnable +class TestWithPythonInterpreter : public ::testing::Test { - mrc::coroutines::AsyncGenerator on_data(int&& value) override - { - co_yield value; - }; -}; - -// TEST_F(TestAsyncioRunnable, YieldMultipleValues) -// { -// std::atomic counter = 0; -// pymrc::Pipeline p; - -// pybind11::module_::import("mrc.core.coro"); - -// auto init = [&counter](mrc::segment::IBuilder& seg) { -// auto src = seg.make_source("src", [](rxcpp::subscriber& s) { -// if (s.is_subscribed()) -// { -// s.on_next(5); -// s.on_next(10); -// } - -// s.on_completed(); -// }); - -// auto internal = seg.construct_object("internal"); + public: + virtual void interpreter_setup() = 0; -// auto sink = seg.make_sink("sink", [&counter](unsigned int x) { -// counter.fetch_add(x, std::memory_order_relaxed); -// }); + protected: + void SetUp() override; -// seg.make_edge(src, internal); -// seg.make_edge(internal, sink); -// }; + void TearDown() override; -// p.make_segment("seg1"s, init); -// p.make_segment("seg2"s, init); + private: + static bool m_initialized; +}; -// auto options = std::make_shared(); -// options->topology().user_cpuset("0"); -// // AsyncioRunnable only works with the Thread engine due to asyncio loops being thread-specific. -// options->engine_factories().set_default_engine_type(runnable::EngineType::Thread); +bool TestWithPythonInterpreter::m_initialized; -// pymrc::Executor exec{options}; -// exec.register_pipeline(p); +void TestWithPythonInterpreter::SetUp() +{ + if (!m_initialized) + { + m_initialized = true; + pybind11::initialize_interpreter(); + interpreter_setup(); + } +} -// exec.start(); -// exec.join(); +void TestWithPythonInterpreter::TearDown() {} -// EXPECT_EQ(counter, 30); -// } +class __attribute__((visibility("default"))) TestAsyncioRunnable : public TestWithPythonInterpreter +{ + void interpreter_setup() override + { + pybind11::module_::import("mrc.core.coro"); + } +}; -class PythonFlatmapAsyncioRunnable : public pymrc::AsyncioRunnable +class PythonCallbackAsyncioRunnable : public pymrc::AsyncioRunnable { public: - PythonFlatmapAsyncioRunnable(pymrc::PyObjectHolder operation) : m_operation(std::move(operation)) {} + PythonCallbackAsyncioRunnable(pymrc::PyObjectHolder operation) : m_operation(std::move(operation)) {} mrc::coroutines::AsyncGenerator on_data(int&& value) override { @@ -140,10 +121,12 @@ TEST_F(TestAsyncioRunnable, UseAsyncioTasks) { py::object globals = py::globals(); py::exec( - "\ -async def fn(value): \ - return value \ - ", + R"( + async def fn(value): + import asyncio + await asyncio.sleep(0) + return value * 2 + )", globals); pymrc::PyObjectHolder fn = static_cast(globals["fn"]); @@ -153,8 +136,6 @@ async def fn(value): \ std::atomic counter = 0; pymrc::Pipeline p; - pybind11::module_::import("mrc.core.coro"); - auto init = [&counter, &fn](mrc::segment::IBuilder& seg) { auto src = seg.make_source("src", [](rxcpp::subscriber& s) { if (s.is_subscribed()) @@ -166,9 +147,9 @@ async def fn(value): \ s.on_completed(); }); - auto internal = seg.construct_object("internal", fn); + auto internal = seg.construct_object("internal", fn); - auto sink = seg.make_sink("sink", [&counter](unsigned int x) { + auto sink = seg.make_sink("sink", [&counter](int x) { counter.fetch_add(x, std::memory_order_relaxed); }); @@ -190,5 +171,5 @@ async def fn(value): \ exec.start(); exec.join(); - EXPECT_EQ(counter, 30); -} \ No newline at end of file + EXPECT_EQ(counter, 60); +}