diff --git a/napari_spotiflow/_dock_widget.py b/napari_spotiflow/_dock_widget.py
index c34883d..57abf52 100644
--- a/napari_spotiflow/_dock_widget.py
+++ b/napari_spotiflow/_dock_widget.py
@@ -1,24 +1,26 @@
-from magicgui import magicgui
-from magicgui import widgets as mw
-from magicgui.application import use_app
-
import functools
-import time
-import numpy as np
-
+import os
from copy import deepcopy
-from pathlib import Path
+from typing import List, Union
from warnings import warn
+import logging
import napari
-from typing import List, Union
-from enum import Enum
-from psygnal import Signal
+import numpy as np
+from magicgui import magicgui
+from magicgui import widgets as mw
+from magicgui.application import use_app
from .utils import _prepare_input, _validate_axes
+logging.basicConfig(level=logging.ERROR)
+log = logging.getLogger(__name__)
+
BASE_IMAGE_AXES_CHOICES = ["YX", "YXC", "CYX", "TYX", "TYXC", "TCYX"]
+BASE_IMAGE_AXES_CHOICES_3D = [f"Z{axes}" if "T" not in axes else f"TZ{axes[1:]}" for axes in BASE_IMAGE_AXES_CHOICES]
+
CURR_IMAGE_AXES_CHOICES = deepcopy(BASE_IMAGE_AXES_CHOICES)
+IS_3D = False
def abspath(root, relpath):
from pathlib import Path
@@ -38,8 +40,6 @@ def change_handler(*widgets, init=True):
def decorator_change_handler(handler):
@functools.wraps(handler)
def wrapper(*args):
- source = Signal.sender()
- emitter = Signal.current_emitter()
return handler(*args)
for widget in widgets:
@@ -55,23 +55,27 @@ def plugin_wrapper():
# delay imports until plugin is requested by user
import torch
from spotiflow.model import Spotiflow
+ from spotiflow.model.pretrained import list_registered, _REGISTERED
from spotiflow.utils import normalize
- from spotiflow.model.pretrained import list_registered
+
from napari_spotiflow import _point_layer2d_default_kwargs
def get_data(image):
image = image.data[0] if image.multiscale else image.data
return np.asarray(image)
- models_reg = list_registered()
+ models_reg_2d = [r for r in list_registered() if not _REGISTERED[r].is_3d]
+ models_reg_3d = [r for r in list_registered() if _REGISTERED[r].is_3d]
- if 'general' in models_reg:
- models_reg = ['general'] + sorted([m for m in models_reg if m != 'general'])
+ if 'general' in models_reg_2d:
+ models_reg_2d = ['general'] + sorted([m for m in models_reg_2d if m != 'general'])
else:
- models_reg = sorted(models_reg)
-
- model_configs = dict()
- model_selected = None
+ models_reg_2d = sorted(models_reg_2d)
+
+ if 'synth_3d' in models_reg_3d:
+ models_reg_3d = ['synth_3d'] + sorted([m for m in models_reg_3d if m != 'synth_3d'])
+ else:
+ models_reg_3d = sorted(models_reg_3d)
CUSTOM_MODEL = 'CUSTOM_MODEL'
model_type_choices = [('Pre-trained', Spotiflow), ('Custom', CUSTOM_MODEL)]
@@ -82,6 +86,8 @@ def get_data(image):
if len(image_layers) > 0:
ndim_first = image_layers[0].data.ndim
CURR_IMAGE_AXES_CHOICES = [c for c in BASE_IMAGE_AXES_CHOICES if len(c) == ndim_first]
+ else:
+ CURR_IMAGE_AXES_CHOICES = [""]
@@ -98,8 +104,10 @@ def get_model(model_type, model, device):
DEFAULTS = dict (
+ mode = '2D',
model_type = Spotiflow,
model2d = 'general',
+ model3d = 'synth_3d',
norm_image = True,
perc_low = 1.0,
perc_high = 99.8,
@@ -118,14 +126,15 @@ def get_model(model_type, model, device):
# -------------------------------------------------------------------------
logo = abspath(__file__, 'resources/spotiflow_transp_small.png')
-
@magicgui (
label_head = dict(widget_type='Label', label=f'
'),
image = dict(label='Input Image'),
image_axes = dict(widget_type='RadioButtons', label='Image axes order', orientation='horizontal', choices=get_image_axes_choices, value=CURR_IMAGE_AXES_CHOICES[0]),
label_nn = dict(widget_type='Label', label='
Neural Network Prediction:'),
+ mode = dict(widget_type='RadioButtons', label='Mode', orientation='horizontal', choices=['2D', '3D'], value=DEFAULTS["mode"]),
model_type = dict(widget_type='RadioButtons', label='Model Type', orientation='horizontal', choices=model_type_choices, value=DEFAULTS['model_type']),
- model2d = dict(widget_type='ComboBox', visible=True, label='Pre-trained Model', choices=models_reg, value=DEFAULTS['model2d']),
+ model2d = dict(widget_type='ComboBox', visible=True, label='Pre-trained Model (2D)', choices=models_reg_2d, value=DEFAULTS['model2d']),
+ model3d = dict(widget_type='ComboBox', visible=True, label='Pre-trained Model (3D)', choices=models_reg_3d, value=DEFAULTS['model3d']),
model_folder = dict(widget_type='FileEdit', visible=True, label='Custom Model', mode='d'),
norm_image = dict(widget_type='CheckBox', text='Normalize Image', value=DEFAULTS['norm_image']),
scale = dict(widget_type='FloatSpinBox', label='Scale factor', min=0.5, max=2, step=0.1, value=DEFAULTS['scale']),
@@ -153,8 +162,10 @@ def plugin (
image: napari.layers.Image,
image_axes: str,
label_nn,
+ mode,
model_type,
model2d,
+ model3d,
model_folder,
norm_image,
perc_low,
@@ -172,17 +183,30 @@ def plugin (
n_tiles,
cnn_output,
progress_bar: mw.ProgressBar,
- ) -> List[napari.types.LayerDataTuple]:
- DEVICE_STR = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
-
+ ) -> list[napari.types.LayerDataTuple]:
+ if image_axes == "":
+ raise RuntimeError("Invalid axes order. If your input is 2D, please set the 2D mode. If your input is 3D, please set the 3D mode.")
+ should_use_mps = torch.backends.mps.is_available() and (not IS_3D or (os.getenv("PYTORCH_ENABLE_MPS_FALLBACK") is not None and os.getenv("PYTORCH_ENABLE_MPS_FALLBACK") != "0"))
+ DEVICE_STR = "cuda" if torch.cuda.is_available() else "mps" if should_use_mps and os.getenv("PYTORCH_ENABLE_MPS_FALLBACK") else "cpu"
+
model = get_model(
model_type,
{
- Spotiflow: model2d,
+ Spotiflow: model2d if not IS_3D else model3d,
CUSTOM_MODEL: model_folder,
}[model_type],
DEVICE_STR
)
+
+ # TODO: improve errors - display them in the GUI
+ if IS_3D:
+ assert model.config.is_3d, "Expected folder containing a 3D model. Are you sure the given folder is correct?"
+ else:
+ assert not model.config.is_3d, "Expected folder containing a 2D model. Are you sure the given folder is correct?"
+
+ if subpix and not model.config.compute_flow:
+ warn("Model was not trained to predict the stereographic flow. Will disable subpixel prediction.")
+ subpix = False
model.to(torch.device(DEVICE_STR))
try:
@@ -201,7 +225,7 @@ def plugin (
x = _prepare_input(x, image_axes)
if "T" not in image_axes:
- if len(n_tiles)==2:
+ if len(n_tiles)==2+int(IS_3D):
n_tiles = n_tiles + (1,)
if norm_image:
@@ -209,7 +233,7 @@ def plugin (
x = normalize(x, perc_low, perc_high)
else:
- if x.ndim==4 and len(n_tiles)==2:
+ if x.ndim==(4+int(IS_3D)) and len(n_tiles)==(2+int(IS_3D)):
n_tiles = n_tiles + (1,)
if norm_image:
print("Normalizing frames...")
@@ -231,7 +255,7 @@ def _progress(it, **kwargs):
return _progress
actual_prob_thresh = prob_thresh if not use_optimized else None
if "T" not in image_axes:
- actual_n_tiles = tuple(max(1,s//1024) for s in x.shape) if auto_n_tiles else n_tiles
+ actual_n_tiles = tuple(max(1,s//(1024 if not IS_3D else 128)) for s in x.shape) if auto_n_tiles else n_tiles
pred_points, details = model.predict(x,
prob_thresh=actual_prob_thresh,
n_tiles=actual_n_tiles,
@@ -247,10 +271,12 @@ def _progress(it, **kwargs):
if cnn_output:
details_prob_heatmap = details.heatmap
- details_flow = details.flow
+ if subpix:
+ details_flow = details.flow
else:
- actual_n_tiles = tuple(max(1,s//1024) for s in x.shape[1:]) if auto_n_tiles else n_tiles
+ # Predict frames
+ actual_n_tiles = tuple(max(1,s//(1024 if not IS_3D else 128)) for s in x.shape[1:]) if auto_n_tiles else n_tiles
pred_points_t, details_t = tuple(zip(*tuple(model.predict(_x,
prob_thresh=actual_prob_thresh,
n_tiles=actual_n_tiles,
@@ -267,22 +293,27 @@ def _progress(it, **kwargs):
for i,ps in enumerate(pred_points_t) for p in ps)
if cnn_output:
details_prob_heatmap = np.stack([det.heatmap for det in details_t], axis=0)
- details_flow = np.stack([det.flow for det in details_t], axis=0)
-
- if cnn_output:
- viewer.add_image(.5*(1+details_flow), name=f'Stereographic flow ({image.name})')
- viewer.add_image(details_prob_heatmap, name=f'Gaussian heatmap ({image.name})', colormap='magma')
+ if subpix:
+ details_flow = np.stack([det.flow for det in details_t], axis=0)
+ if cnn_output:
+ if subpix:
+ layers.append((.5*(1+details_flow), dict(name=f'Stereographic flow ({image.name})', scale=model.config.grid,
+ ), 'image'))
+ layers.append((details_prob_heatmap, dict(name=f'Gaussian heatmap ({image.name})',
+ colormap='magma', scale=model.config.grid), 'image'))
points_layer_name = f'Spots ({image.name})'
for l in viewer.layers:
if l.name == points_layer_name:
viewer.layers.remove(l)
- viewer.add_points(pred_points, name=points_layer_name, **_point_layer2d_default_kwargs)
+ layers.append((pred_points, dict(name=f'Spots ({image.name})',
+ **_point_layer2d_default_kwargs), 'points'))
+
progress_bar.hide()
- return
+ return layers
# # -------------------------------------------------------------------------
@@ -296,7 +327,7 @@ def _progress(it, **kwargs):
# -------------------------------------------------------------------------
widget_for_modeltype = {
- Spotiflow: plugin.model2d,
+ Spotiflow: "pretrained",
CUSTOM_MODEL: plugin.model_folder,
}
@@ -319,34 +350,40 @@ def _thr_change(active: bool):
plugin.prob_thresh,
active=not active
)
-
- @change_handler(plugin.model_type, init=True)
- def _model_type_change(model_type: Union[str, type]):
- selected = widget_for_modeltype[model_type]
- for w in set((plugin.model2d, plugin.model_folder)) - {selected}:
- w.hide()
- selected.show()
- selected.changed(selected.value)
- @change_handler(plugin.image, init=False)
+ @change_handler(plugin.image, init=True)
def _image_update(image: napari.layers.Image):
global CURR_IMAGE_AXES_CHOICES
+ global IS_3D
+ possible_axes_choices = BASE_IMAGE_AXES_CHOICES if not IS_3D else BASE_IMAGE_AXES_CHOICES_3D
if image is not None:
inp_ndim = get_data(image).ndim
- assert inp_ndim in (2,3,4), f"Invalid input dimension: {inp_ndim}. Should be 2, 3, or 4."
+ # if not IS_3D:
+ # assert inp_ndim in (2,3,4), f"Invalid input dimension: {inp_ndim}. Should be 2, 3, or 4."
+ # else:
+ # assert inp_ndim in (3,4,5), f"Invalid input dimension: {inp_ndim}. Should be 3, 4, or 5."
# Update the choices for image_axes
- CURR_IMAGE_AXES_CHOICES = [c for c in BASE_IMAGE_AXES_CHOICES if len(c) == inp_ndim]
- # Trigger event to update the choices and value of image_axes
- plugin.image_axes.changed(CURR_IMAGE_AXES_CHOICES)
- plugin.image_axes.value = CURR_IMAGE_AXES_CHOICES[0]
-
+ CURR_IMAGE_AXES_CHOICES = [c for c in possible_axes_choices if len(c) == inp_ndim]
+ else:
+ CURR_IMAGE_AXES_CHOICES = [""]
+
+ if len(CURR_IMAGE_AXES_CHOICES) == 0:
+ CURR_IMAGE_AXES_CHOICES = [""]
+ # Trigger event to update the choices and value of image_axes
+ plugin.image_axes.changed(CURR_IMAGE_AXES_CHOICES)
+ plugin.image_axes.value = CURR_IMAGE_AXES_CHOICES[0]
+
@change_handler(plugin.image_axes, init=False)
def _image_axes_update(choices: List[str]):
with plugin.image_axes.changed.blocked():
plugin.image_axes.choices = CURR_IMAGE_AXES_CHOICES
if plugin.image_axes.value not in choices:
plugin.image_axes.value = CURR_IMAGE_AXES_CHOICES[0]
+ if plugin.image_axes.value == "":
+ plugin.call_button.enabled = False
+ else:
+ plugin.call_button.enabled = True
@change_handler(plugin.norm_image)
def _norm_image_change(active: bool):
@@ -361,6 +398,25 @@ def _auto_n_tiles_change(active: bool):
active=not active
)
- return plugin
+ @change_handler(plugin.model_type, init=False)
+ def _model_type_change(model_type: Union[str, type]):
+ selected = widget_for_modeltype[model_type]
+ if selected == "pretrained":
+ selected = plugin.model2d if not IS_3D else plugin.model3d
+ for w in set((plugin.model2d, plugin.model3d, plugin.model_folder)) - {selected}:
+ w.hide()
+ selected.show()
+ selected.changed(selected.value)
+ @change_handler(plugin.mode, init=True)
+ def _model_mode_change(mode: str):
+ global IS_3D
+ IS_3D = mode == '3D'
+ selected = plugin.model3d if IS_3D else plugin.model2d
+ for w in set((plugin.model2d, plugin.model3d)) - {selected}:
+ w.hide()
+ selected.show()
+ selected.changed(selected.value)
+ _image_update(plugin.image.value)
+ return plugin
diff --git a/napari_spotiflow/_io_hooks.py b/napari_spotiflow/_io_hooks.py
index a5d0f97..29de594 100644
--- a/napari_spotiflow/_io_hooks.py
+++ b/napari_spotiflow/_io_hooks.py
@@ -11,7 +11,9 @@
import pandas as pd
from napari_builtins.io import napari_get_reader as default_napari_get_reader
-COLUMNS = ('z', 'y', 'x')
+COLUMNS_4D = ('t', 'z', 'y', 'x')
+COLUMNS_3D = ('z', 'y', 'x')
+
COLUMNS_NAME_MAP_2D = {
@@ -25,12 +27,21 @@
'axis-2' : 'x',
}
+COLUMNS_NAME_MAP_4D = {
+ 'axis-0' : 't',
+ 'axis-1' : 'z',
+ 'axis-2' : 'y',
+ 'axis-3' : 'x',
+
+}
def _load_and_parse_csv(path, **kwargs):
df = pd.read_csv(path, **kwargs)
df.columns = df.columns.str.lower()
df.columns = df.columns.str.strip()
- if 'axis-2' in df.columns:
+ if 'axis-3' in df.columns:
+ df = df.rename(columns = lambda n: COLUMNS_NAME_MAP_4D.get(n,n))
+ elif 'axis-2' in df.columns:
df = df.rename(columns = lambda n: COLUMNS_NAME_MAP_3D.get(n,n))
else:
df = df.rename(columns = lambda n: COLUMNS_NAME_MAP_2D.get(n,n))
@@ -38,7 +49,7 @@ def _load_and_parse_csv(path, **kwargs):
return df
def _validate_dataframe(df):
- return set(COLUMNS[-2:]).issubset(set(df.columns))
+ return set(COLUMNS_3D[-2:]).issubset(set(df.columns))
def _validate_path(path: Union[str, Path]):
""" checks whether path is a valid csv """
@@ -72,12 +83,11 @@ def reader_function(path):
df = _load_and_parse_csv(path)
- # if 3d
- if set(COLUMNS).issubset(set(df.columns)):
- # data = df[list(columns)].to_numpy()
+ if set(COLUMNS_4D).issubset(set(df.columns)):
+ data = df[['t','z','y','x']].to_numpy()
+ elif set(COLUMNS_3D).issubset(set(df.columns)):
data = df[['z','y','x']].to_numpy()
else:
- # data = df[list(columns[-2:])].to_numpy()
data = df[['y','x']].to_numpy()
kwargs = dict(_point_layer2d_default_kwargs)
@@ -91,6 +101,8 @@ def napari_write_points(path, data, meta):
df = pd.DataFrame(data[:,::-1], columns=['x','y'])
elif data.shape[-1]==3:
df = pd.DataFrame(data[:,::-1], columns=['x','y','z'])
+ elif data.shape[-1]==4:
+ df = pd.DataFrame(data[:,::-1], columns=['x','y','z','t'])
else:
return None
df.to_csv(path, index=False)
diff --git a/napari_spotiflow/_sample_data.py b/napari_spotiflow/_sample_data.py
index d65b0b4..394caf8 100644
--- a/napari_spotiflow/_sample_data.py
+++ b/napari_spotiflow/_sample_data.py
@@ -1,9 +1,11 @@
def _test_image_hybiss_2d():
from spotiflow import sample_data
-
return [(sample_data.test_image_hybiss_2d(), {"name": "hybiss_2d"})]
-
def _test_image_terra_2d():
from spotiflow import sample_data
return [(sample_data.test_image_terra_2d(), {"name": "terra_2d"})]
+
+def _test_image_synth_3d():
+ from spotiflow import sample_data
+ return [(sample_data.test_image_synth_3d(), {"name": "synth_3d"})]
diff --git a/napari_spotiflow/napari.yaml b/napari_spotiflow/napari.yaml
index dc30f07..c842b9a 100644
--- a/napari_spotiflow/napari.yaml
+++ b/napari_spotiflow/napari.yaml
@@ -14,6 +14,9 @@ contributions:
- id: napari-spotiflow.data.terra_2d
title: Terra (2D) sample
python_name: napari_spotiflow._sample_data:_test_image_terra_2d
+ - id: napari-spotiflow.data.synth_3d
+ title: Synthetic (3D) sample
+ python_name: napari_spotiflow._sample_data:_test_image_synth_3d
sample_data:
- key: hybiss
display_name: HybISS
@@ -21,6 +24,9 @@ contributions:
- key: terra
display_name: Terra
command: napari-spotiflow.data.terra_2d
+ - key: synth_3d
+ display_name: Synthetic (3D)
+ command: napari-spotiflow.data.synth_3d
readers:
- command: napari-spotiflow.reader
accepts_directories: false
diff --git a/napari_spotiflow/utils.py b/napari_spotiflow/utils.py
index 6903fb8..f7344d8 100644
--- a/napari_spotiflow/utils.py
+++ b/napari_spotiflow/utils.py
@@ -1,23 +1,39 @@
import numpy as np
from typing import Literal
-def _validate_axes(img: np.ndarray, axes: Literal["YX", "YXC", "CYX", "TYX", "TYXC", "TCYX"]):
+def _validate_axes(img: np.ndarray, axes: Literal["YX", "YXC", "CYX", "TYX", "TYXC", "TCYX", "ZYX", "ZYXC", "CZYX", "ZTYX", "ZTYXC", "ZTCYX"]) -> None:
assert img.ndim == len(axes), f"Image has {img.ndim} dimensions, but axes has {len(axes)} dimensions"
return
-def _prepare_input(img: np.ndarray, axes: Literal["YX", "YXC", "CYX", "TYX", "TYXC", "TCYX"]):
+def _prepare_input(img: np.ndarray, axes: Literal["YX", "YXC", "CYX", "TYX", "TYXC", "TCYX", "ZYX", "ZYXC", "CZYX", "ZTYX", "TZYXC", "TCZYX"]) -> np.ndarray:
+ """Reshape input for Spotiflow's API compatibility. If `axes` contains "Z", then assumes `img` is a volumetric (3D) image.
+
+ Args:
+ img (np.ndarray): input image to be reformatted
+ axes (Literal["YX", "YXC", "CYX", "TYX", "TYXC", "TCYX", "ZYX", "ZYXC", "ZCYX", "ZTYX", "ZTYXC", "ZTCYX"]): given axes
+
+ Raises:
+ ValueError: thrown if axis is not valid
+
+ Returns:
+ np.ndarray: reshaped NumPy array compatible with Spotiflow's `predict` API
+ """
_validate_axes(img, axes)
- if axes == "YX":
+ if axes == "YX" or axes == "ZYX":
return img[..., None]
- elif axes == "YXC":
+ elif axes == "YXC" or axes == "ZYXC":
return img
elif axes == "CYX":
- return img.transpose(1, 2, 0)
- elif axes == "TYX":
+ return img.transpose(1,2,0)
+ elif axes == "CZYX":
+ return img.transpose(1,2,3,0)
+ elif axes == "TYX" or axes == "TZYX":
return img[..., None]
- elif axes == "TYXC":
+ elif axes == "TYXC" or axes == "TZYXC":
return img
elif axes == "TCYX":
- return img.transpose(0, 2, 3, 1)
+ return img.transpose(0,2,3,1)
+ elif axes == "TCZYX":
+ return img.transpose(0,2,3,4,0)
else:
raise ValueError(f"Invalid axes: {axes}")