Skip to content

Commit

Permalink
Add 3D version
Browse files Browse the repository at this point in the history
  • Loading branch information
AlbertDominguez committed Aug 24, 2024
1 parent 01d0949 commit 1d0af1b
Show file tree
Hide file tree
Showing 5 changed files with 165 additions and 73 deletions.
168 changes: 112 additions & 56 deletions napari_spotiflow/_dock_widget.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,26 @@
from magicgui import magicgui
from magicgui import widgets as mw
from magicgui.application import use_app

import functools
import time
import numpy as np

import os
from copy import deepcopy
from pathlib import Path
from typing import List, Union
from warnings import warn
import logging

import napari
from typing import List, Union
from enum import Enum
from psygnal import Signal
import numpy as np
from magicgui import magicgui
from magicgui import widgets as mw
from magicgui.application import use_app

from .utils import _prepare_input, _validate_axes

logging.basicConfig(level=logging.ERROR)
log = logging.getLogger(__name__)

BASE_IMAGE_AXES_CHOICES = ["YX", "YXC", "CYX", "TYX", "TYXC", "TCYX"]
BASE_IMAGE_AXES_CHOICES_3D = [f"Z{axes}" if "T" not in axes else f"TZ{axes[1:]}" for axes in BASE_IMAGE_AXES_CHOICES]

CURR_IMAGE_AXES_CHOICES = deepcopy(BASE_IMAGE_AXES_CHOICES)
IS_3D = False

def abspath(root, relpath):
from pathlib import Path
Expand All @@ -38,8 +40,6 @@ def change_handler(*widgets, init=True):
def decorator_change_handler(handler):
@functools.wraps(handler)
def wrapper(*args):
source = Signal.sender()
emitter = Signal.current_emitter()
return handler(*args)

for widget in widgets:
Expand All @@ -55,23 +55,27 @@ def plugin_wrapper():
# delay imports until plugin is requested by user
import torch
from spotiflow.model import Spotiflow
from spotiflow.model.pretrained import list_registered, _REGISTERED
from spotiflow.utils import normalize
from spotiflow.model.pretrained import list_registered

from napari_spotiflow import _point_layer2d_default_kwargs

def get_data(image):
image = image.data[0] if image.multiscale else image.data
return np.asarray(image)

models_reg = list_registered()
models_reg_2d = [r for r in list_registered() if not _REGISTERED[r].is_3d]
models_reg_3d = [r for r in list_registered() if _REGISTERED[r].is_3d]

if 'general' in models_reg:
models_reg = ['general'] + sorted([m for m in models_reg if m != 'general'])
if 'general' in models_reg_2d:
models_reg_2d = ['general'] + sorted([m for m in models_reg_2d if m != 'general'])
else:
models_reg = sorted(models_reg)

model_configs = dict()
model_selected = None
models_reg_2d = sorted(models_reg_2d)

if 'synth_3d' in models_reg_3d:
models_reg_3d = ['synth_3d'] + sorted([m for m in models_reg_3d if m != 'synth_3d'])
else:
models_reg_3d = sorted(models_reg_3d)

CUSTOM_MODEL = 'CUSTOM_MODEL'
model_type_choices = [('Pre-trained', Spotiflow), ('Custom', CUSTOM_MODEL)]
Expand All @@ -82,6 +86,8 @@ def get_data(image):
if len(image_layers) > 0:
ndim_first = image_layers[0].data.ndim
CURR_IMAGE_AXES_CHOICES = [c for c in BASE_IMAGE_AXES_CHOICES if len(c) == ndim_first]
else:
CURR_IMAGE_AXES_CHOICES = [""]



Expand All @@ -98,8 +104,10 @@ def get_model(model_type, model, device):


DEFAULTS = dict (
mode = '2D',
model_type = Spotiflow,
model2d = 'general',
model3d = 'synth_3d',
norm_image = True,
perc_low = 1.0,
perc_high = 99.8,
Expand All @@ -118,14 +126,15 @@ def get_model(model_type, model, device):
# -------------------------------------------------------------------------

logo = abspath(__file__, 'resources/spotiflow_transp_small.png')

@magicgui (
label_head = dict(widget_type='Label', label=f'<h1><img src="{logo}"></h1>'),
image = dict(label='Input Image'),
image_axes = dict(widget_type='RadioButtons', label='Image axes order', orientation='horizontal', choices=get_image_axes_choices, value=CURR_IMAGE_AXES_CHOICES[0]),
label_nn = dict(widget_type='Label', label='<br><b>Neural Network Prediction:</b>'),
mode = dict(widget_type='RadioButtons', label='Mode', orientation='horizontal', choices=['2D', '3D'], value=DEFAULTS["mode"]),
model_type = dict(widget_type='RadioButtons', label='Model Type', orientation='horizontal', choices=model_type_choices, value=DEFAULTS['model_type']),
model2d = dict(widget_type='ComboBox', visible=True, label='Pre-trained Model', choices=models_reg, value=DEFAULTS['model2d']),
model2d = dict(widget_type='ComboBox', visible=True, label='Pre-trained Model (2D)', choices=models_reg_2d, value=DEFAULTS['model2d']),
model3d = dict(widget_type='ComboBox', visible=True, label='Pre-trained Model (3D)', choices=models_reg_3d, value=DEFAULTS['model3d']),
model_folder = dict(widget_type='FileEdit', visible=True, label='Custom Model', mode='d'),
norm_image = dict(widget_type='CheckBox', text='Normalize Image', value=DEFAULTS['norm_image']),
scale = dict(widget_type='FloatSpinBox', label='Scale factor', min=0.5, max=2, step=0.1, value=DEFAULTS['scale']),
Expand Down Expand Up @@ -153,8 +162,10 @@ def plugin (
image: napari.layers.Image,
image_axes: str,
label_nn,
mode,
model_type,
model2d,
model3d,
model_folder,
norm_image,
perc_low,
Expand All @@ -172,17 +183,30 @@ def plugin (
n_tiles,
cnn_output,
progress_bar: mw.ProgressBar,
) -> List[napari.types.LayerDataTuple]:
DEVICE_STR = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

) -> list[napari.types.LayerDataTuple]:
if image_axes == "":
raise RuntimeError("Invalid axes order. If your input is 2D, please set the 2D mode. If your input is 3D, please set the 3D mode.")
should_use_mps = torch.backends.mps.is_available() and (not IS_3D or (os.getenv("PYTORCH_ENABLE_MPS_FALLBACK") is not None and os.getenv("PYTORCH_ENABLE_MPS_FALLBACK") != "0"))
DEVICE_STR = "cuda" if torch.cuda.is_available() else "mps" if should_use_mps and os.getenv("PYTORCH_ENABLE_MPS_FALLBACK") else "cpu"

model = get_model(
model_type,
{
Spotiflow: model2d,
Spotiflow: model2d if not IS_3D else model3d,
CUSTOM_MODEL: model_folder,
}[model_type],
DEVICE_STR
)

# TODO: improve errors - display them in the GUI
if IS_3D:
assert model.config.is_3d, "Expected folder containing a 3D model. Are you sure the given folder is correct?"
else:
assert not model.config.is_3d, "Expected folder containing a 2D model. Are you sure the given folder is correct?"

if subpix and not model.config.compute_flow:
warn("Model was not trained to predict the stereographic flow. Will disable subpixel prediction.")
subpix = False

model.to(torch.device(DEVICE_STR))
try:
Expand All @@ -201,15 +225,15 @@ def plugin (
x = _prepare_input(x, image_axes)

if "T" not in image_axes:
if len(n_tiles)==2:
if len(n_tiles)==2+int(IS_3D):
n_tiles = n_tiles + (1,)

if norm_image:
print("Normalizing image...")
x = normalize(x, perc_low, perc_high)

else:
if x.ndim==4 and len(n_tiles)==2:
if x.ndim==(4+int(IS_3D)) and len(n_tiles)==(2+int(IS_3D)):
n_tiles = n_tiles + (1,)
if norm_image:
print("Normalizing frames...")
Expand All @@ -231,7 +255,7 @@ def _progress(it, **kwargs):
return _progress
actual_prob_thresh = prob_thresh if not use_optimized else None
if "T" not in image_axes:
actual_n_tiles = tuple(max(1,s//1024) for s in x.shape) if auto_n_tiles else n_tiles
actual_n_tiles = tuple(max(1,s//(1024 if not IS_3D else 128)) for s in x.shape) if auto_n_tiles else n_tiles
pred_points, details = model.predict(x,
prob_thresh=actual_prob_thresh,
n_tiles=actual_n_tiles,
Expand All @@ -247,10 +271,12 @@ def _progress(it, **kwargs):

if cnn_output:
details_prob_heatmap = details.heatmap
details_flow = details.flow
if subpix:
details_flow = details.flow

else:
actual_n_tiles = tuple(max(1,s//1024) for s in x.shape[1:]) if auto_n_tiles else n_tiles
# Predict frames
actual_n_tiles = tuple(max(1,s//(1024 if not IS_3D else 128)) for s in x.shape[1:]) if auto_n_tiles else n_tiles
pred_points_t, details_t = tuple(zip(*tuple(model.predict(_x,
prob_thresh=actual_prob_thresh,
n_tiles=actual_n_tiles,
Expand All @@ -267,22 +293,27 @@ def _progress(it, **kwargs):
for i,ps in enumerate(pred_points_t) for p in ps)
if cnn_output:
details_prob_heatmap = np.stack([det.heatmap for det in details_t], axis=0)
details_flow = np.stack([det.flow for det in details_t], axis=0)

if cnn_output:
viewer.add_image(.5*(1+details_flow), name=f'Stereographic flow ({image.name})')
viewer.add_image(details_prob_heatmap, name=f'Gaussian heatmap ({image.name})', colormap='magma')
if subpix:
details_flow = np.stack([det.flow for det in details_t], axis=0)

if cnn_output:
if subpix:
layers.append((.5*(1+details_flow), dict(name=f'Stereographic flow ({image.name})', scale=model.config.grid,
), 'image'))
layers.append((details_prob_heatmap, dict(name=f'Gaussian heatmap ({image.name})',
colormap='magma', scale=model.config.grid), 'image'))
points_layer_name = f'Spots ({image.name})'
for l in viewer.layers:
if l.name == points_layer_name:
viewer.layers.remove(l)

viewer.add_points(pred_points, name=points_layer_name, **_point_layer2d_default_kwargs)
layers.append((pred_points, dict(name=f'Spots ({image.name})',
**_point_layer2d_default_kwargs), 'points'))


progress_bar.hide()

return
return layers

# # -------------------------------------------------------------------------

Expand All @@ -296,7 +327,7 @@ def _progress(it, **kwargs):
# -------------------------------------------------------------------------

widget_for_modeltype = {
Spotiflow: plugin.model2d,
Spotiflow: "pretrained",
CUSTOM_MODEL: plugin.model_folder,
}

Expand All @@ -319,34 +350,40 @@ def _thr_change(active: bool):
plugin.prob_thresh,
active=not active
)

@change_handler(plugin.model_type, init=True)
def _model_type_change(model_type: Union[str, type]):
selected = widget_for_modeltype[model_type]
for w in set((plugin.model2d, plugin.model_folder)) - {selected}:
w.hide()
selected.show()
selected.changed(selected.value)

@change_handler(plugin.image, init=False)
@change_handler(plugin.image, init=True)
def _image_update(image: napari.layers.Image):
global CURR_IMAGE_AXES_CHOICES
global IS_3D
possible_axes_choices = BASE_IMAGE_AXES_CHOICES if not IS_3D else BASE_IMAGE_AXES_CHOICES_3D
if image is not None:
inp_ndim = get_data(image).ndim
assert inp_ndim in (2,3,4), f"Invalid input dimension: {inp_ndim}. Should be 2, 3, or 4."
# if not IS_3D:
# assert inp_ndim in (2,3,4), f"Invalid input dimension: {inp_ndim}. Should be 2, 3, or 4."
# else:
# assert inp_ndim in (3,4,5), f"Invalid input dimension: {inp_ndim}. Should be 3, 4, or 5."
# Update the choices for image_axes
CURR_IMAGE_AXES_CHOICES = [c for c in BASE_IMAGE_AXES_CHOICES if len(c) == inp_ndim]

# Trigger event to update the choices and value of image_axes
plugin.image_axes.changed(CURR_IMAGE_AXES_CHOICES)
plugin.image_axes.value = CURR_IMAGE_AXES_CHOICES[0]

CURR_IMAGE_AXES_CHOICES = [c for c in possible_axes_choices if len(c) == inp_ndim]
else:
CURR_IMAGE_AXES_CHOICES = [""]

if len(CURR_IMAGE_AXES_CHOICES) == 0:
CURR_IMAGE_AXES_CHOICES = [""]
# Trigger event to update the choices and value of image_axes
plugin.image_axes.changed(CURR_IMAGE_AXES_CHOICES)
plugin.image_axes.value = CURR_IMAGE_AXES_CHOICES[0]

@change_handler(plugin.image_axes, init=False)
def _image_axes_update(choices: List[str]):
with plugin.image_axes.changed.blocked():
plugin.image_axes.choices = CURR_IMAGE_AXES_CHOICES
if plugin.image_axes.value not in choices:
plugin.image_axes.value = CURR_IMAGE_AXES_CHOICES[0]
if plugin.image_axes.value == "":
plugin.call_button.enabled = False
else:
plugin.call_button.enabled = True

@change_handler(plugin.norm_image)
def _norm_image_change(active: bool):
Expand All @@ -361,6 +398,25 @@ def _auto_n_tiles_change(active: bool):
active=not active
)

return plugin
@change_handler(plugin.model_type, init=False)
def _model_type_change(model_type: Union[str, type]):
selected = widget_for_modeltype[model_type]
if selected == "pretrained":
selected = plugin.model2d if not IS_3D else plugin.model3d
for w in set((plugin.model2d, plugin.model3d, plugin.model_folder)) - {selected}:
w.hide()
selected.show()
selected.changed(selected.value)

@change_handler(plugin.mode, init=True)
def _model_mode_change(mode: str):
global IS_3D
IS_3D = mode == '3D'
selected = plugin.model3d if IS_3D else plugin.model2d
for w in set((plugin.model2d, plugin.model3d)) - {selected}:
w.hide()
selected.show()
selected.changed(selected.value)
_image_update(plugin.image.value)

return plugin
Loading

0 comments on commit 1d0af1b

Please sign in to comment.