diff --git a/python/mrc/tests/utils.cpp b/python/mrc/tests/utils.cpp index a01c061ed..b05695c86 100644 --- a/python/mrc/tests/utils.cpp +++ b/python/mrc/tests/utils.cpp @@ -23,6 +23,7 @@ #include #include #include +#include #include #include @@ -32,18 +33,7 @@ namespace mrc::pytests { namespace py = pybind11; -// Simple test class which acquires the GIL in it's destructor -struct ObjUsingGil -{ - ObjUsingGil() = default; - ~ObjUsingGil() - { - LOG(INFO) << "ObjUsingGil::~ObjUsingGil()"; - py::gil_scoped_acquire gil; - LOG(INFO) << "ObjUsingGil::~ObjUsingGil()+gil"; - } -}; - +// Simple test class which invokes Python's GC. Needed to repro # struct ObjCallingGC { ObjCallingGC() = default; @@ -76,8 +66,6 @@ PYBIND11_MODULE(utils, py_mod) }, py::arg("msg") = ""); - py::class_(py_mod, "ObjUsingGil").def(py::init<>()); - py::class_(py_mod, "ObjCallingGC").def(py::init<>()).def_static("finalize", &ObjCallingGC::finalize); py_mod.attr("__version__") = MRC_CONCAT_STR(mrc_VERSION_MAJOR << "." << mrc_VERSION_MINOR << "." diff --git a/python/tests/test_gil_tls.py b/python/tests/test_gil_tls.py index 56248d75a..8c2b232d5 100644 --- a/python/tests/test_gil_tls.py +++ b/python/tests/test_gil_tls.py @@ -1,53 +1,30 @@ -import gc import threading import weakref import mrc from mrc.tests.utils import ObjCallingGC -from mrc.tests.utils import ObjUsingGil TLS = threading.local() -class Holder: +def test_gc_called_in_thread_finalizer(): + mrc.logging.log("Building pipeline") - def __init__(self, obj): - """Intentionally create a cycle to delay obj's destruction""" - self.obj = obj - self.cycle = self + def source_gen(): + mrc.logging.log("source_gen") + x = ObjCallingGC() + weakref.finalize(x, x.finalize) + TLS.x = x + yield x - def __del__(self): - mrc.logging.log("Holder.__del__") - self.obj = None + def init_seg(builder: mrc.Builder): + builder.make_source("souce_gen", source_gen) + pipe = mrc.Pipeline() + pipe.make_segment("seg1", init_seg) -class ThreadTest(threading.Thread): - - def _create_obs(self): - TLS.h = Holder(ObjUsingGil()) - TLS.ocg = ObjCallingGC() - weakref.finalize(TLS.ocg, TLS.ocg.finalize) - - def run(self): - mrc.logging.log("Running thread") - self._create_obs() - mrc.logging.log("Thread complete") - - -def test_gil_tls(): - t = ThreadTest() - t.start() - t.join() - mrc.logging.log("Thread joined") - - -def main(): - mrc.logging.init_logging(__name__) - gc.disable() - gc.set_debug(gc.DEBUG_STATS) - test_gil_tls() - mrc.logging.log("Exiting main") - - -if __name__ == "__main__": - main() + options = mrc.Options() + executor = mrc.Executor(options) + executor.register_pipeline(pipe) + executor.start() + executor.join()