Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add callback mechanism for GUI mode #301

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
10 changes: 6 additions & 4 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def pytest_addoption(parser):


@pytest.fixture(scope="session")
def ij_fixture(request):
def ij(request):
"""
Create an ImageJ instance to be used by the whole testing environment
:param request: Pytest variable passed in to fixtures
Expand All @@ -43,12 +43,14 @@ def ij_fixture(request):
legacy = request.config.getoption("--legacy")
headless = request.config.getoption("--headless")

imagej.when_imagej_starts(lambda ij: setattr(ij, "_when_imagej_starts_result", "success"))

mode = "headless" if headless else "interactive"
ij_wrapper = imagej.init(ij_dir, mode=mode, add_legacy=legacy)
ij = imagej.init(ij_dir, mode=mode, add_legacy=legacy)

yield ij_wrapper
yield ij

ij_wrapper.dispose()
ij.dispose()


def str2bool(v):
Expand Down
53 changes: 43 additions & 10 deletions src/imagej/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@
__version__ = sj.get_version("pyimagej")

_logger = logging.getLogger(__name__)
rai_lock = threading.Lock()
_init_callbacks = []
_rai_lock = threading.Lock()

# Enable debug logging if DEBUG environment variable is set.
try:
Expand Down Expand Up @@ -1007,14 +1008,14 @@ def _op(self):
def _ra(self):
threadLocal = getattr(self, "_threadLocal", None)
if threadLocal is None:
with rai_lock:
with _rai_lock:
threadLocal = getattr(self, "_threadLocal", None)
if threadLocal is None:
threadLocal = threading.local()
self._threadLocal = threadLocal
ra = getattr(threadLocal, "ra", None)
if ra is None:
with rai_lock:
with _rai_lock:
ra = getattr(threadLocal, "ra", None)
if ra is None:
ra = self.randomAccess()
Expand Down Expand Up @@ -1212,12 +1213,27 @@ def init(
if not success:
raise RuntimeError("Failed to create a JVM with the requested environment.")

def run_callbacks(ij):
# invoke registered callback functions
for callback in _init_callbacks:
callback(ij)
return ij

if mode == Mode.GUI:
# Show the GUI and block.
global gateway

def show_gui_and_run_callbacks():
global gateway
gateway = _create_gateway()
gateway.ui().showUI()
run_callbacks(gateway)
return gateway

if macos:
# NB: This will block the calling (main) thread forever!
try:
setupGuiEnvironment(lambda: _create_gateway().ui().showUI())
setupGuiEnvironment(show_gui_and_run_callbacks)
except ModuleNotFoundError as e:
if e.msg == "No module named 'PyObjCTools'":
advice = (
Expand All @@ -1237,16 +1253,33 @@ def init(
raise
else:
# Create and show the application.
gateway = _create_gateway()
gateway.ui().showUI()
gateway = show_gui_and_run_callbacks()
# We are responsible for our own blocking.
# TODO: Poll using something better than ui().isVisible().
while gateway.ui().isVisible():
time.sleep(1)
return None
else:
# HEADLESS or INTERACTIVE mode: create the gateway and return it.
return _create_gateway()

return gateway

# HEADLESS or INTERACTIVE mode: create the gateway and return it.
return run_callbacks(_create_gateway())


def when_imagej_starts(f) -> None:
elevans marked this conversation as resolved.
Show resolved Hide resolved
"""
Registers a function to be called immediately after ImageJ2 starts.
This is useful especially with GUI mode, to perform additional
configuration and operations following initialization of ImageJ2,
because the use of GUI mode blocks the calling thread indefinitely.

:param f: Single-argument function to invoke during imagej.init().
The function will be passed the newly created ImageJ2 Gateway
as its sole argument, and called as the final action of the
init function before it returns or blocks.
"""
# Add function to the list of callbacks to invoke upon start_jvm().
global _init_callbacks
_init_callbacks.append(f)


def imagej_main():
Expand Down
7 changes: 7 additions & 0 deletions tests/test_callbacks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
def test_when_imagej_starts(ij):
"""
The ImageJ2 gateway test fixture registers a callback function via
when_imagej_starts, which injects a small piece of data into the gateway
object. We check for that data here to make sure the callback happened.
"""
assert "success" == getattr(ij, "_when_imagej_starts_result", None)
6 changes: 3 additions & 3 deletions tests/test_ctypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,14 @@


@pytest.mark.parametrize(argnames="ctype,jtype_str,value", argvalues=parameters)
def test_ctype_to_realtype(ij_fixture, ctype, jtype_str, value):
def test_ctype_to_realtype(ij, ctype, jtype_str, value):
py_type = ctype(value)
# Convert the ctype into a RealType
converted = ij_fixture.py.to_java(py_type)
converted = ij.py.to_java(py_type)
jtype = sj.jimport(jtype_str)
assert isinstance(converted, jtype)
assert converted.get() == value
# Convert the RealType back into a ctype
converted_back = ij_fixture.py.from_java(converted)
converted_back = ij.py.from_java(converted)
assert isinstance(converted_back, ctype)
assert converted_back.value == value
16 changes: 8 additions & 8 deletions tests/test_fiji.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,23 @@
# -- Tests --


def test_plugins_load_using_pairwise_stitching(ij_fixture):
def test_plugins_load_using_pairwise_stitching(ij):
try:
sj.jimport("plugin.Stitching_Pairwise")
except TypeError:
pytest.skip("No Pairwise Stitching plugin available. Skipping test.")

if not ij_fixture.legacy:
if not ij.legacy:
pytest.skip("No original ImageJ. Skipping test.")
if ij_fixture.ui().isHeadless():
if ij.ui().isHeadless():
pytest.skip("No GUI. Skipping test.")

tile1 = ij_fixture.IJ.createImage("Tile1", "8-bit random", 512, 512, 1)
tile2 = ij_fixture.IJ.createImage("Tile2", "8-bit random", 512, 512, 1)
tile1 = ij.IJ.createImage("Tile1", "8-bit random", 512, 512, 1)
tile2 = ij.IJ.createImage("Tile2", "8-bit random", 512, 512, 1)
args = {"first_image": tile1.getTitle(), "second_image": tile2.getTitle()}
ij_fixture.py.run_plugin("Pairwise stitching", args)
result_name = ij_fixture.WindowManager.getCurrentImage().getTitle()
ij.py.run_plugin("Pairwise stitching", args)
result_name = ij.WindowManager.getCurrentImage().getTitle()

ij_fixture.IJ.run("Close All", "")
ij.IJ.run("Close All", "")

assert result_name == "Tile1<->Tile2"
Loading