From 66d3b07736619994396d890b4f149f421abf3a7c Mon Sep 17 00:00:00 2001 From: David Gardner Date: Mon, 22 Jan 2024 10:09:41 -0800 Subject: [PATCH] Add new test for a variation on #360 when a module has been loaded and added to the segment [no ci] --- python/tests/test_pipeline.py | 56 +++++++++++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/python/tests/test_pipeline.py b/python/tests/test_pipeline.py index 6c268be06..55176517f 100644 --- a/python/tests/test_pipeline.py +++ b/python/tests/test_pipeline.py @@ -22,6 +22,8 @@ import pytest import mrc +# Required to register sample modules with the ModuleRegistry +import mrc.tests.sample_modules import mrc.tests.test_edges_cpp as m # from mrc.core.options import PlacementStrategy @@ -484,6 +486,59 @@ def on_next(input): executor.join() +def test_module_init_error(): + """ + Test for variation on issue #360 where the error is raised in the module init function + Reproduces Morpheus issue #953 + """ + + def gen_data(): + for i in range(10): + yield i + + def init(builder: mrc.Builder): + + def on_next(input): + pass + + def on_error(): + pass + + def on_complete(): + pass + + config = {"config_key_1": True} + + registry = mrc.ModuleRegistry + + source1 = builder.make_source("src1", gen_data) + source2 = builder.make_source("src2", gen_data) + fn_constructor = registry.get_module_constructor("SimpleModule", "mrc_unittest") + simple_mod = fn_constructor("ModuleInitializationTest_mod2", config) + sink = builder.make_sink("sink", on_next, on_error, on_complete) + + builder.init_module(simple_mod) + builder.make_edge(source1, simple_mod.input_port("input1")) + builder.make_edge(source2, simple_mod.input_port("input2")) + builder.make_edge(simple_mod.output_port("output1"), sink) + builder.make_edge(simple_mod.output_port("output2"), sink) + + raise RuntimeError("Test for #360") + + pipe = mrc.Pipeline() + + pipe.make_segment("segment", init) + + options = mrc.Options() + + executor = mrc.Executor(options) + executor.register_pipeline(pipe) + + with pytest.raises(RuntimeError): + executor.start() + executor.join() + + if (__name__ in ("__main__", )): test_dynamic_port_creation_good() test_dynamic_port_creation_bad() @@ -491,3 +546,4 @@ def on_next(input): test_dynamic_port_get_ingress_egress() test_dynamic_port_with_type_get_ingress_egress() test_segment_init_error() + test_module_init_error()