Skip to content

Commit

Permalink
fix test_asyncio_runnable test class to support multiple test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
cwharris committed Nov 1, 2023
1 parent e509bbb commit 2b2f26e
Showing 1 changed file with 39 additions and 58 deletions.
97 changes: 39 additions & 58 deletions python/mrc/_pymrc/tests/test_asyncio_runnable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,65 +52,46 @@ namespace pymrc = mrc::pymrc;
using namespace std::string_literals;
using namespace py::literals;

PYMRC_TEST_CLASS(AsyncioRunnable);

class MyAsyncioRunnable : public pymrc::AsyncioRunnable<int, unsigned int>
class TestWithPythonInterpreter : public ::testing::Test
{
mrc::coroutines::AsyncGenerator<unsigned int> on_data(int&& value) override
{
co_yield value;
};
};

// TEST_F(TestAsyncioRunnable, YieldMultipleValues)
// {
// std::atomic<unsigned int> counter = 0;
// pymrc::Pipeline p;

// pybind11::module_::import("mrc.core.coro");

// auto init = [&counter](mrc::segment::IBuilder& seg) {
// auto src = seg.make_source<int>("src", [](rxcpp::subscriber<int>& s) {
// if (s.is_subscribed())
// {
// s.on_next(5);
// s.on_next(10);
// }

// s.on_completed();
// });

// auto internal = seg.construct_object<MyAsyncioRunnable>("internal");
public:
virtual void interpreter_setup() = 0;

// auto sink = seg.make_sink<unsigned int>("sink", [&counter](unsigned int x) {
// counter.fetch_add(x, std::memory_order_relaxed);
// });
protected:
void SetUp() override;

// seg.make_edge(src, internal);
// seg.make_edge(internal, sink);
// };
void TearDown() override;

// p.make_segment("seg1"s, init);
// p.make_segment("seg2"s, init);
private:
static bool m_initialized;
};

// auto options = std::make_shared<mrc::Options>();
// options->topology().user_cpuset("0");
// // AsyncioRunnable only works with the Thread engine due to asyncio loops being thread-specific.
// options->engine_factories().set_default_engine_type(runnable::EngineType::Thread);
bool TestWithPythonInterpreter::m_initialized;

// pymrc::Executor exec{options};
// exec.register_pipeline(p);
void TestWithPythonInterpreter::SetUp()
{
if (!m_initialized)
{
m_initialized = true;
pybind11::initialize_interpreter();
interpreter_setup();
}
}

// exec.start();
// exec.join();
void TestWithPythonInterpreter::TearDown() {}

// EXPECT_EQ(counter, 30);
// }
class __attribute__((visibility("default"))) TestAsyncioRunnable : public TestWithPythonInterpreter
{
void interpreter_setup() override
{
pybind11::module_::import("mrc.core.coro");
}
};

class PythonFlatmapAsyncioRunnable : public pymrc::AsyncioRunnable<int, int>
class PythonCallbackAsyncioRunnable : public pymrc::AsyncioRunnable<int, int>
{
public:
PythonFlatmapAsyncioRunnable(pymrc::PyObjectHolder operation) : m_operation(std::move(operation)) {}
PythonCallbackAsyncioRunnable(pymrc::PyObjectHolder operation) : m_operation(std::move(operation)) {}

mrc::coroutines::AsyncGenerator<int> on_data(int&& value) override
{
Expand Down Expand Up @@ -140,10 +121,12 @@ TEST_F(TestAsyncioRunnable, UseAsyncioTasks)
{
py::object globals = py::globals();
py::exec(
"\
async def fn(value): \
return value \
",
R"(
async def fn(value):
import asyncio
await asyncio.sleep(0)
return value * 2
)",
globals);

pymrc::PyObjectHolder fn = static_cast<py::object>(globals["fn"]);
Expand All @@ -153,8 +136,6 @@ async def fn(value): \
std::atomic<unsigned int> counter = 0;
pymrc::Pipeline p;

pybind11::module_::import("mrc.core.coro");

auto init = [&counter, &fn](mrc::segment::IBuilder& seg) {
auto src = seg.make_source<int>("src", [](rxcpp::subscriber<int>& s) {
if (s.is_subscribed())
Expand All @@ -166,9 +147,9 @@ async def fn(value): \
s.on_completed();
});

auto internal = seg.construct_object<PythonFlatmapAsyncioRunnable>("internal", fn);
auto internal = seg.construct_object<PythonCallbackAsyncioRunnable>("internal", fn);

auto sink = seg.make_sink<unsigned int>("sink", [&counter](unsigned int x) {
auto sink = seg.make_sink<int>("sink", [&counter](int x) {
counter.fetch_add(x, std::memory_order_relaxed);
});

Expand All @@ -190,5 +171,5 @@ async def fn(value): \
exec.start();
exec.join();

EXPECT_EQ(counter, 30);
}
EXPECT_EQ(counter, 60);
}

0 comments on commit 2b2f26e

Please sign in to comment.