Skip to content

Commit

Permalink
Merge pull request #1 from weigertlab/mps_fix
Browse files Browse the repository at this point in the history
Mps fix
  • Loading branch information
AlbertDominguez authored Aug 29, 2024
2 parents b044e7b + cfb26b6 commit 44d6bfd
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions napari_spotiflow/_dock_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,8 +186,9 @@ def plugin (
) -> list[napari.types.LayerDataTuple]:
if image_axes == "":
raise RuntimeError("Invalid axes order. If your input is 2D, please set the 2D mode. If your input is 3D, please set the 3D mode.")
should_use_mps = torch.backends.mps.is_available() and (not IS_3D or (os.getenv("PYTORCH_ENABLE_MPS_FALLBACK") is not None and os.getenv("PYTORCH_ENABLE_MPS_FALLBACK") != "0"))
DEVICE_STR = "cuda" if torch.cuda.is_available() else "mps" if should_use_mps and os.getenv("PYTORCH_ENABLE_MPS_FALLBACK") else "cpu"
should_use_mps = torch.backends.mps.is_available() and not IS_3D and not os.getenv("PYTORCH_ENABLE_MPS_FALLBACK") == "0"
DEVICE_STR = "cuda" if torch.cuda.is_available() else "mps" if should_use_mps else "cpu"
print(f'using device {DEVICE_STR}')

model = get_model(
model_type,
Expand Down

0 comments on commit 44d6bfd

Please sign in to comment.