From e5a511e58fa4d67152ed6c349dd202a55eefb1eb Mon Sep 17 00:00:00 2001 From: AlbertDominguez Date: Thu, 7 Nov 2024 14:36:20 +0100 Subject: [PATCH] fix #3 --- napari_spotiflow/__init__.py | 6 ++++++ napari_spotiflow/_dock_widget.py | 21 ++++++++++++++------- 2 files changed, 20 insertions(+), 7 deletions(-) diff --git a/napari_spotiflow/__init__.py b/napari_spotiflow/__init__.py index ca494db..fd388cc 100644 --- a/napari_spotiflow/__init__.py +++ b/napari_spotiflow/__init__.py @@ -6,6 +6,12 @@ face_color=[1.,.5,.2], border_color=[1.,.5,.2]) +_point_layer3d_default_kwargs = dict(size=8, + symbol='ring', + opacity=1, + face_color=[1.,.5,.2], + border_color=[1.,.5,.2], + out_of_slice_display=True) # def sample_data_2d(): # from spotiflow.data import hybiss_data_2d diff --git a/napari_spotiflow/_dock_widget.py b/napari_spotiflow/_dock_widget.py index 2e4f2a8..c01090f 100644 --- a/napari_spotiflow/_dock_widget.py +++ b/napari_spotiflow/_dock_widget.py @@ -1,9 +1,10 @@ -import functools import os +os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" +import functools +import logging from copy import deepcopy from typing import List, Union from warnings import warn -import logging import napari import numpy as np @@ -53,10 +54,10 @@ def plugin_wrapper(): # delay imports until plugin is requested by user import torch from spotiflow.model import Spotiflow - from spotiflow.model.pretrained import list_registered, _REGISTERED + from spotiflow.model.pretrained import _REGISTERED, list_registered from spotiflow.utils import normalize - from napari_spotiflow import _point_layer2d_default_kwargs + from napari_spotiflow import _point_layer2d_default_kwargs, _point_layer3d_default_kwargs def get_data(image): image = image.data[0] if image.multiscale else image.data @@ -186,7 +187,7 @@ def plugin ( ) -> list[napari.types.LayerDataTuple]: if image_axes == "": raise RuntimeError("Invalid axes order. If your input is 2D, please set the 2D mode. If your input is 3D, please set the 3D mode.") - should_use_mps = torch.backends.mps.is_available() and not IS_3D and not os.getenv("PYTORCH_ENABLE_MPS_FALLBACK") == "0" + should_use_mps = torch.backends.mps.is_available() and (not IS_3D or os.getenv("PYTORCH_ENABLE_MPS_FALLBACK") != "0" or os.getenv("PYTORCH_ENABLE_MPS_FALLBACK") is not None) DEVICE_STR = "cuda" if torch.cuda.is_available() else "mps" if should_use_mps else "cpu" print(f'using device {DEVICE_STR}') @@ -244,7 +245,10 @@ def plugin ( def progress(size): def _progress(it, **kwargs): progress_bar.label = 'Spotiflow Prediction' - progress_bar.range = (0, size) + if kwargs.get("total", None) is None: + progress_bar.range = (0, size+1) + else: + progress_bar.range = (0, kwargs["total"]) progress_bar.value = 0 progress_bar.show() app.process_events() @@ -268,6 +272,7 @@ def _progress(it, **kwargs): progress_bar_wrapper=progress(np.prod(actual_n_tiles)), device=DEVICE_STR, subpix=subpix, + normalizer=None, ) if cnn_output: @@ -288,6 +293,7 @@ def _progress(it, **kwargs): verbose=True, device=DEVICE_STR, subpix=subpix, + normalizer=None, ) for _x in progress(x.shape[0])(x)))) pred_points = tuple(np.concatenate([[i], p]) @@ -308,8 +314,9 @@ def _progress(it, **kwargs): if l.name == points_layer_name: viewer.layers.remove(l) + point_layer_kwargs = _point_layer2d_default_kwargs if not IS_3D else _point_layer3d_default_kwargs layers.append((pred_points, dict(name=f'Spots ({image.name})', - **_point_layer2d_default_kwargs), 'points')) + **point_layer_kwargs), 'points')) progress_bar.hide()