diff --git a/napari_spotiflow/_dock_widget.py b/napari_spotiflow/_dock_widget.py index f65998d..2e4f2a8 100644 --- a/napari_spotiflow/_dock_widget.py +++ b/napari_spotiflow/_dock_widget.py @@ -186,8 +186,9 @@ def plugin ( ) -> list[napari.types.LayerDataTuple]: if image_axes == "": raise RuntimeError("Invalid axes order. If your input is 2D, please set the 2D mode. If your input is 3D, please set the 3D mode.") - should_use_mps = torch.backends.mps.is_available() and (not IS_3D or (os.getenv("PYTORCH_ENABLE_MPS_FALLBACK") is not None and os.getenv("PYTORCH_ENABLE_MPS_FALLBACK") != "0")) - DEVICE_STR = "cuda" if torch.cuda.is_available() else "mps" if should_use_mps and os.getenv("PYTORCH_ENABLE_MPS_FALLBACK") else "cpu" + should_use_mps = torch.backends.mps.is_available() and not IS_3D and not os.getenv("PYTORCH_ENABLE_MPS_FALLBACK") == "0" + DEVICE_STR = "cuda" if torch.cuda.is_available() else "mps" if should_use_mps else "cpu" + print(f'using device {DEVICE_STR}') model = get_model( model_type,