diff --git a/cpp/mrc/src/public/segment/builder.cpp b/cpp/mrc/src/public/segment/builder.cpp index 68dd0530f..f8a063e20 100644 --- a/cpp/mrc/src/public/segment/builder.cpp +++ b/cpp/mrc/src/public/segment/builder.cpp @@ -49,14 +49,14 @@ void Builder::init_module(std::shared_ptr smodule) VLOG(2) << "Initializing module: " << m_namespace_prefix; smodule->m_module_instance_registered_namespace = m_namespace_prefix; smodule->initialize(*this); - ns_pop(); // TODO(Devin): Maybe a better way to do this with compile time type ledger. if (std::dynamic_pointer_cast(smodule) != nullptr) { - VLOG(2) << "Registering persistent module -> '" << smodule->component_prefix() << "'"; + VLOG(2) << "Registering persistent module -> '" << m_namespace_prefix << "'"; m_backend.add_module(m_namespace_prefix, smodule); } + ns_pop(); } std::shared_ptr Builder::get_ingress(std::string name, std::type_index type_index) diff --git a/external/utilities b/external/utilities index 429ae1840..ad81faac9 160000 --- a/external/utilities +++ b/external/utilities @@ -1 +1 @@ -Subproject commit 429ae1840d02cf4cd2d4f1bbbac85a61f7ee8907 +Subproject commit ad81faac968919678988dc33cc89c85ddd52a643 diff --git a/python/tests/test_segment_modules.py b/python/tests/test_segment_modules.py index 08c78fc36..8b6b96908 100644 --- a/python/tests/test_segment_modules.py +++ b/python/tests/test_segment_modules.py @@ -117,7 +117,6 @@ def on_complete(): def test_py_constructor(): - config = {"config_key_1": True} registry = mrc.ModuleRegistry @@ -312,11 +311,41 @@ def on_complete(): def test_py_module_nesting(): - def gen_data(): - for i in range(0, 43): - yield True + def init_wrapper(builder: mrc.Builder): + global packet_count + packet_count = 0 + + def on_next(input): global packet_count packet_count += 1 + logging.info("Sinking {}".format(input)) + + def on_error(): + pass + + def on_complete(): + pass + + nested_mod = builder.load_module("NestedModule", "mrc_unittest", "ModuleNestingTest_mod1", {}) + nested_sink = builder.make_sink("nested_sink", on_next, on_error, on_complete) + + builder.make_edge(nested_mod.output_port("nested_module_output"), nested_sink) + + pipeline = mrc.Pipeline() + pipeline.make_segment("ModuleNesting_Segment", init_wrapper) + + options = mrc.Options() + options.topology.user_cpuset = "0-1" + + executor = mrc.Executor(options) + executor.register_pipeline(pipeline) + executor.start() + executor.join() + + assert packet_count == 4 + + +def test_py_modules_dont_overwrite(): def init_wrapper(builder: mrc.Builder): global packet_count @@ -334,9 +363,18 @@ def on_complete(): pass nested_mod = builder.load_module("NestedModule", "mrc_unittest", "ModuleNestingTest_mod1", {}) + + # Make sure we can't re-register the same name + with pytest.raises(RuntimeError): + this_should_fail = builder.load_module( # noqa + "NestedModule", "mrc_unittest", "ModuleNestingTest_mod1", {}) + + nested_mod2 = builder.load_module("NestedModule", "mrc_unittest", "ModuleNestingTest_mod2", {}) nested_sink = builder.make_sink("nested_sink", on_next, on_error, on_complete) + nested_sink2 = builder.make_sink("nested_sink2", on_next, on_error, on_complete) builder.make_edge(nested_mod.output_port("nested_module_output"), nested_sink) + builder.make_edge(nested_mod2.output_port("nested_module_output"), nested_sink2) pipeline = mrc.Pipeline() pipeline.make_segment("ModuleNesting_Segment", init_wrapper) @@ -349,7 +387,7 @@ def on_complete(): executor.start() executor.join() - assert packet_count == 4 + assert packet_count == 8 if (__name__ in ("__main__", )): @@ -358,5 +396,6 @@ def on_complete(): test_py_module_as_sink() test_py_module_chaining() test_py_module_nesting() + test_py_modules_dont_overwrite() test_py_constructor() test_py_module_initialization()