diff --git a/napari_spotiflow/_dock_widget.py b/napari_spotiflow/_dock_widget.py index c34883d..57abf52 100644 --- a/napari_spotiflow/_dock_widget.py +++ b/napari_spotiflow/_dock_widget.py @@ -1,24 +1,26 @@ -from magicgui import magicgui -from magicgui import widgets as mw -from magicgui.application import use_app - import functools -import time -import numpy as np - +import os from copy import deepcopy -from pathlib import Path +from typing import List, Union from warnings import warn +import logging import napari -from typing import List, Union -from enum import Enum -from psygnal import Signal +import numpy as np +from magicgui import magicgui +from magicgui import widgets as mw +from magicgui.application import use_app from .utils import _prepare_input, _validate_axes +logging.basicConfig(level=logging.ERROR) +log = logging.getLogger(__name__) + BASE_IMAGE_AXES_CHOICES = ["YX", "YXC", "CYX", "TYX", "TYXC", "TCYX"] +BASE_IMAGE_AXES_CHOICES_3D = [f"Z{axes}" if "T" not in axes else f"TZ{axes[1:]}" for axes in BASE_IMAGE_AXES_CHOICES] + CURR_IMAGE_AXES_CHOICES = deepcopy(BASE_IMAGE_AXES_CHOICES) +IS_3D = False def abspath(root, relpath): from pathlib import Path @@ -38,8 +40,6 @@ def change_handler(*widgets, init=True): def decorator_change_handler(handler): @functools.wraps(handler) def wrapper(*args): - source = Signal.sender() - emitter = Signal.current_emitter() return handler(*args) for widget in widgets: @@ -55,23 +55,27 @@ def plugin_wrapper(): # delay imports until plugin is requested by user import torch from spotiflow.model import Spotiflow + from spotiflow.model.pretrained import list_registered, _REGISTERED from spotiflow.utils import normalize - from spotiflow.model.pretrained import list_registered + from napari_spotiflow import _point_layer2d_default_kwargs def get_data(image): image = image.data[0] if image.multiscale else image.data return np.asarray(image) - models_reg = list_registered() + models_reg_2d = [r for r in list_registered() if not _REGISTERED[r].is_3d] + models_reg_3d = [r for r in list_registered() if _REGISTERED[r].is_3d] - if 'general' in models_reg: - models_reg = ['general'] + sorted([m for m in models_reg if m != 'general']) + if 'general' in models_reg_2d: + models_reg_2d = ['general'] + sorted([m for m in models_reg_2d if m != 'general']) else: - models_reg = sorted(models_reg) - - model_configs = dict() - model_selected = None + models_reg_2d = sorted(models_reg_2d) + + if 'synth_3d' in models_reg_3d: + models_reg_3d = ['synth_3d'] + sorted([m for m in models_reg_3d if m != 'synth_3d']) + else: + models_reg_3d = sorted(models_reg_3d) CUSTOM_MODEL = 'CUSTOM_MODEL' model_type_choices = [('Pre-trained', Spotiflow), ('Custom', CUSTOM_MODEL)] @@ -82,6 +86,8 @@ def get_data(image): if len(image_layers) > 0: ndim_first = image_layers[0].data.ndim CURR_IMAGE_AXES_CHOICES = [c for c in BASE_IMAGE_AXES_CHOICES if len(c) == ndim_first] + else: + CURR_IMAGE_AXES_CHOICES = [""] @@ -98,8 +104,10 @@ def get_model(model_type, model, device): DEFAULTS = dict ( + mode = '2D', model_type = Spotiflow, model2d = 'general', + model3d = 'synth_3d', norm_image = True, perc_low = 1.0, perc_high = 99.8, @@ -118,14 +126,15 @@ def get_model(model_type, model, device): # ------------------------------------------------------------------------- logo = abspath(__file__, 'resources/spotiflow_transp_small.png') - @magicgui ( label_head = dict(widget_type='Label', label=f'

'), image = dict(label='Input Image'), image_axes = dict(widget_type='RadioButtons', label='Image axes order', orientation='horizontal', choices=get_image_axes_choices, value=CURR_IMAGE_AXES_CHOICES[0]), label_nn = dict(widget_type='Label', label='
Neural Network Prediction:'), + mode = dict(widget_type='RadioButtons', label='Mode', orientation='horizontal', choices=['2D', '3D'], value=DEFAULTS["mode"]), model_type = dict(widget_type='RadioButtons', label='Model Type', orientation='horizontal', choices=model_type_choices, value=DEFAULTS['model_type']), - model2d = dict(widget_type='ComboBox', visible=True, label='Pre-trained Model', choices=models_reg, value=DEFAULTS['model2d']), + model2d = dict(widget_type='ComboBox', visible=True, label='Pre-trained Model (2D)', choices=models_reg_2d, value=DEFAULTS['model2d']), + model3d = dict(widget_type='ComboBox', visible=True, label='Pre-trained Model (3D)', choices=models_reg_3d, value=DEFAULTS['model3d']), model_folder = dict(widget_type='FileEdit', visible=True, label='Custom Model', mode='d'), norm_image = dict(widget_type='CheckBox', text='Normalize Image', value=DEFAULTS['norm_image']), scale = dict(widget_type='FloatSpinBox', label='Scale factor', min=0.5, max=2, step=0.1, value=DEFAULTS['scale']), @@ -153,8 +162,10 @@ def plugin ( image: napari.layers.Image, image_axes: str, label_nn, + mode, model_type, model2d, + model3d, model_folder, norm_image, perc_low, @@ -172,17 +183,30 @@ def plugin ( n_tiles, cnn_output, progress_bar: mw.ProgressBar, - ) -> List[napari.types.LayerDataTuple]: - DEVICE_STR = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" - + ) -> list[napari.types.LayerDataTuple]: + if image_axes == "": + raise RuntimeError("Invalid axes order. If your input is 2D, please set the 2D mode. If your input is 3D, please set the 3D mode.") + should_use_mps = torch.backends.mps.is_available() and (not IS_3D or (os.getenv("PYTORCH_ENABLE_MPS_FALLBACK") is not None and os.getenv("PYTORCH_ENABLE_MPS_FALLBACK") != "0")) + DEVICE_STR = "cuda" if torch.cuda.is_available() else "mps" if should_use_mps and os.getenv("PYTORCH_ENABLE_MPS_FALLBACK") else "cpu" + model = get_model( model_type, { - Spotiflow: model2d, + Spotiflow: model2d if not IS_3D else model3d, CUSTOM_MODEL: model_folder, }[model_type], DEVICE_STR ) + + # TODO: improve errors - display them in the GUI + if IS_3D: + assert model.config.is_3d, "Expected folder containing a 3D model. Are you sure the given folder is correct?" + else: + assert not model.config.is_3d, "Expected folder containing a 2D model. Are you sure the given folder is correct?" + + if subpix and not model.config.compute_flow: + warn("Model was not trained to predict the stereographic flow. Will disable subpixel prediction.") + subpix = False model.to(torch.device(DEVICE_STR)) try: @@ -201,7 +225,7 @@ def plugin ( x = _prepare_input(x, image_axes) if "T" not in image_axes: - if len(n_tiles)==2: + if len(n_tiles)==2+int(IS_3D): n_tiles = n_tiles + (1,) if norm_image: @@ -209,7 +233,7 @@ def plugin ( x = normalize(x, perc_low, perc_high) else: - if x.ndim==4 and len(n_tiles)==2: + if x.ndim==(4+int(IS_3D)) and len(n_tiles)==(2+int(IS_3D)): n_tiles = n_tiles + (1,) if norm_image: print("Normalizing frames...") @@ -231,7 +255,7 @@ def _progress(it, **kwargs): return _progress actual_prob_thresh = prob_thresh if not use_optimized else None if "T" not in image_axes: - actual_n_tiles = tuple(max(1,s//1024) for s in x.shape) if auto_n_tiles else n_tiles + actual_n_tiles = tuple(max(1,s//(1024 if not IS_3D else 128)) for s in x.shape) if auto_n_tiles else n_tiles pred_points, details = model.predict(x, prob_thresh=actual_prob_thresh, n_tiles=actual_n_tiles, @@ -247,10 +271,12 @@ def _progress(it, **kwargs): if cnn_output: details_prob_heatmap = details.heatmap - details_flow = details.flow + if subpix: + details_flow = details.flow else: - actual_n_tiles = tuple(max(1,s//1024) for s in x.shape[1:]) if auto_n_tiles else n_tiles + # Predict frames + actual_n_tiles = tuple(max(1,s//(1024 if not IS_3D else 128)) for s in x.shape[1:]) if auto_n_tiles else n_tiles pred_points_t, details_t = tuple(zip(*tuple(model.predict(_x, prob_thresh=actual_prob_thresh, n_tiles=actual_n_tiles, @@ -267,22 +293,27 @@ def _progress(it, **kwargs): for i,ps in enumerate(pred_points_t) for p in ps) if cnn_output: details_prob_heatmap = np.stack([det.heatmap for det in details_t], axis=0) - details_flow = np.stack([det.flow for det in details_t], axis=0) - - if cnn_output: - viewer.add_image(.5*(1+details_flow), name=f'Stereographic flow ({image.name})') - viewer.add_image(details_prob_heatmap, name=f'Gaussian heatmap ({image.name})', colormap='magma') + if subpix: + details_flow = np.stack([det.flow for det in details_t], axis=0) + if cnn_output: + if subpix: + layers.append((.5*(1+details_flow), dict(name=f'Stereographic flow ({image.name})', scale=model.config.grid, + ), 'image')) + layers.append((details_prob_heatmap, dict(name=f'Gaussian heatmap ({image.name})', + colormap='magma', scale=model.config.grid), 'image')) points_layer_name = f'Spots ({image.name})' for l in viewer.layers: if l.name == points_layer_name: viewer.layers.remove(l) - viewer.add_points(pred_points, name=points_layer_name, **_point_layer2d_default_kwargs) + layers.append((pred_points, dict(name=f'Spots ({image.name})', + **_point_layer2d_default_kwargs), 'points')) + progress_bar.hide() - return + return layers # # ------------------------------------------------------------------------- @@ -296,7 +327,7 @@ def _progress(it, **kwargs): # ------------------------------------------------------------------------- widget_for_modeltype = { - Spotiflow: plugin.model2d, + Spotiflow: "pretrained", CUSTOM_MODEL: plugin.model_folder, } @@ -319,34 +350,40 @@ def _thr_change(active: bool): plugin.prob_thresh, active=not active ) - - @change_handler(plugin.model_type, init=True) - def _model_type_change(model_type: Union[str, type]): - selected = widget_for_modeltype[model_type] - for w in set((plugin.model2d, plugin.model_folder)) - {selected}: - w.hide() - selected.show() - selected.changed(selected.value) - @change_handler(plugin.image, init=False) + @change_handler(plugin.image, init=True) def _image_update(image: napari.layers.Image): global CURR_IMAGE_AXES_CHOICES + global IS_3D + possible_axes_choices = BASE_IMAGE_AXES_CHOICES if not IS_3D else BASE_IMAGE_AXES_CHOICES_3D if image is not None: inp_ndim = get_data(image).ndim - assert inp_ndim in (2,3,4), f"Invalid input dimension: {inp_ndim}. Should be 2, 3, or 4." + # if not IS_3D: + # assert inp_ndim in (2,3,4), f"Invalid input dimension: {inp_ndim}. Should be 2, 3, or 4." + # else: + # assert inp_ndim in (3,4,5), f"Invalid input dimension: {inp_ndim}. Should be 3, 4, or 5." # Update the choices for image_axes - CURR_IMAGE_AXES_CHOICES = [c for c in BASE_IMAGE_AXES_CHOICES if len(c) == inp_ndim] - # Trigger event to update the choices and value of image_axes - plugin.image_axes.changed(CURR_IMAGE_AXES_CHOICES) - plugin.image_axes.value = CURR_IMAGE_AXES_CHOICES[0] - + CURR_IMAGE_AXES_CHOICES = [c for c in possible_axes_choices if len(c) == inp_ndim] + else: + CURR_IMAGE_AXES_CHOICES = [""] + + if len(CURR_IMAGE_AXES_CHOICES) == 0: + CURR_IMAGE_AXES_CHOICES = [""] + # Trigger event to update the choices and value of image_axes + plugin.image_axes.changed(CURR_IMAGE_AXES_CHOICES) + plugin.image_axes.value = CURR_IMAGE_AXES_CHOICES[0] + @change_handler(plugin.image_axes, init=False) def _image_axes_update(choices: List[str]): with plugin.image_axes.changed.blocked(): plugin.image_axes.choices = CURR_IMAGE_AXES_CHOICES if plugin.image_axes.value not in choices: plugin.image_axes.value = CURR_IMAGE_AXES_CHOICES[0] + if plugin.image_axes.value == "": + plugin.call_button.enabled = False + else: + plugin.call_button.enabled = True @change_handler(plugin.norm_image) def _norm_image_change(active: bool): @@ -361,6 +398,25 @@ def _auto_n_tiles_change(active: bool): active=not active ) - return plugin + @change_handler(plugin.model_type, init=False) + def _model_type_change(model_type: Union[str, type]): + selected = widget_for_modeltype[model_type] + if selected == "pretrained": + selected = plugin.model2d if not IS_3D else plugin.model3d + for w in set((plugin.model2d, plugin.model3d, plugin.model_folder)) - {selected}: + w.hide() + selected.show() + selected.changed(selected.value) + @change_handler(plugin.mode, init=True) + def _model_mode_change(mode: str): + global IS_3D + IS_3D = mode == '3D' + selected = plugin.model3d if IS_3D else plugin.model2d + for w in set((plugin.model2d, plugin.model3d)) - {selected}: + w.hide() + selected.show() + selected.changed(selected.value) + _image_update(plugin.image.value) + return plugin diff --git a/napari_spotiflow/_io_hooks.py b/napari_spotiflow/_io_hooks.py index a5d0f97..29de594 100644 --- a/napari_spotiflow/_io_hooks.py +++ b/napari_spotiflow/_io_hooks.py @@ -11,7 +11,9 @@ import pandas as pd from napari_builtins.io import napari_get_reader as default_napari_get_reader -COLUMNS = ('z', 'y', 'x') +COLUMNS_4D = ('t', 'z', 'y', 'x') +COLUMNS_3D = ('z', 'y', 'x') + COLUMNS_NAME_MAP_2D = { @@ -25,12 +27,21 @@ 'axis-2' : 'x', } +COLUMNS_NAME_MAP_4D = { + 'axis-0' : 't', + 'axis-1' : 'z', + 'axis-2' : 'y', + 'axis-3' : 'x', + +} def _load_and_parse_csv(path, **kwargs): df = pd.read_csv(path, **kwargs) df.columns = df.columns.str.lower() df.columns = df.columns.str.strip() - if 'axis-2' in df.columns: + if 'axis-3' in df.columns: + df = df.rename(columns = lambda n: COLUMNS_NAME_MAP_4D.get(n,n)) + elif 'axis-2' in df.columns: df = df.rename(columns = lambda n: COLUMNS_NAME_MAP_3D.get(n,n)) else: df = df.rename(columns = lambda n: COLUMNS_NAME_MAP_2D.get(n,n)) @@ -38,7 +49,7 @@ def _load_and_parse_csv(path, **kwargs): return df def _validate_dataframe(df): - return set(COLUMNS[-2:]).issubset(set(df.columns)) + return set(COLUMNS_3D[-2:]).issubset(set(df.columns)) def _validate_path(path: Union[str, Path]): """ checks whether path is a valid csv """ @@ -72,12 +83,11 @@ def reader_function(path): df = _load_and_parse_csv(path) - # if 3d - if set(COLUMNS).issubset(set(df.columns)): - # data = df[list(columns)].to_numpy() + if set(COLUMNS_4D).issubset(set(df.columns)): + data = df[['t','z','y','x']].to_numpy() + elif set(COLUMNS_3D).issubset(set(df.columns)): data = df[['z','y','x']].to_numpy() else: - # data = df[list(columns[-2:])].to_numpy() data = df[['y','x']].to_numpy() kwargs = dict(_point_layer2d_default_kwargs) @@ -91,6 +101,8 @@ def napari_write_points(path, data, meta): df = pd.DataFrame(data[:,::-1], columns=['x','y']) elif data.shape[-1]==3: df = pd.DataFrame(data[:,::-1], columns=['x','y','z']) + elif data.shape[-1]==4: + df = pd.DataFrame(data[:,::-1], columns=['x','y','z','t']) else: return None df.to_csv(path, index=False) diff --git a/napari_spotiflow/_sample_data.py b/napari_spotiflow/_sample_data.py index d65b0b4..394caf8 100644 --- a/napari_spotiflow/_sample_data.py +++ b/napari_spotiflow/_sample_data.py @@ -1,9 +1,11 @@ def _test_image_hybiss_2d(): from spotiflow import sample_data - return [(sample_data.test_image_hybiss_2d(), {"name": "hybiss_2d"})] - def _test_image_terra_2d(): from spotiflow import sample_data return [(sample_data.test_image_terra_2d(), {"name": "terra_2d"})] + +def _test_image_synth_3d(): + from spotiflow import sample_data + return [(sample_data.test_image_synth_3d(), {"name": "synth_3d"})] diff --git a/napari_spotiflow/napari.yaml b/napari_spotiflow/napari.yaml index dc30f07..c842b9a 100644 --- a/napari_spotiflow/napari.yaml +++ b/napari_spotiflow/napari.yaml @@ -14,6 +14,9 @@ contributions: - id: napari-spotiflow.data.terra_2d title: Terra (2D) sample python_name: napari_spotiflow._sample_data:_test_image_terra_2d + - id: napari-spotiflow.data.synth_3d + title: Synthetic (3D) sample + python_name: napari_spotiflow._sample_data:_test_image_synth_3d sample_data: - key: hybiss display_name: HybISS @@ -21,6 +24,9 @@ contributions: - key: terra display_name: Terra command: napari-spotiflow.data.terra_2d + - key: synth_3d + display_name: Synthetic (3D) + command: napari-spotiflow.data.synth_3d readers: - command: napari-spotiflow.reader accepts_directories: false diff --git a/napari_spotiflow/utils.py b/napari_spotiflow/utils.py index 6903fb8..f7344d8 100644 --- a/napari_spotiflow/utils.py +++ b/napari_spotiflow/utils.py @@ -1,23 +1,39 @@ import numpy as np from typing import Literal -def _validate_axes(img: np.ndarray, axes: Literal["YX", "YXC", "CYX", "TYX", "TYXC", "TCYX"]): +def _validate_axes(img: np.ndarray, axes: Literal["YX", "YXC", "CYX", "TYX", "TYXC", "TCYX", "ZYX", "ZYXC", "CZYX", "ZTYX", "ZTYXC", "ZTCYX"]) -> None: assert img.ndim == len(axes), f"Image has {img.ndim} dimensions, but axes has {len(axes)} dimensions" return -def _prepare_input(img: np.ndarray, axes: Literal["YX", "YXC", "CYX", "TYX", "TYXC", "TCYX"]): +def _prepare_input(img: np.ndarray, axes: Literal["YX", "YXC", "CYX", "TYX", "TYXC", "TCYX", "ZYX", "ZYXC", "CZYX", "ZTYX", "TZYXC", "TCZYX"]) -> np.ndarray: + """Reshape input for Spotiflow's API compatibility. If `axes` contains "Z", then assumes `img` is a volumetric (3D) image. + + Args: + img (np.ndarray): input image to be reformatted + axes (Literal["YX", "YXC", "CYX", "TYX", "TYXC", "TCYX", "ZYX", "ZYXC", "ZCYX", "ZTYX", "ZTYXC", "ZTCYX"]): given axes + + Raises: + ValueError: thrown if axis is not valid + + Returns: + np.ndarray: reshaped NumPy array compatible with Spotiflow's `predict` API + """ _validate_axes(img, axes) - if axes == "YX": + if axes == "YX" or axes == "ZYX": return img[..., None] - elif axes == "YXC": + elif axes == "YXC" or axes == "ZYXC": return img elif axes == "CYX": - return img.transpose(1, 2, 0) - elif axes == "TYX": + return img.transpose(1,2,0) + elif axes == "CZYX": + return img.transpose(1,2,3,0) + elif axes == "TYX" or axes == "TZYX": return img[..., None] - elif axes == "TYXC": + elif axes == "TYXC" or axes == "TZYXC": return img elif axes == "TCYX": - return img.transpose(0, 2, 3, 1) + return img.transpose(0,2,3,1) + elif axes == "TCZYX": + return img.transpose(0,2,3,4,0) else: raise ValueError(f"Invalid axes: {axes}")