Skip to content

Commit

Permalink
fix: 🐛 Fix model download and loading paths.
Browse files Browse the repository at this point in the history
  • Loading branch information
rhoadesScholar committed Mar 8, 2024
1 parent 848eb73 commit b4d7bac
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 6 deletions.
8 changes: 3 additions & 5 deletions src/cellmap_models/pytorch/cellpose/get_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,9 @@ def get_model(
raise ValueError(
f"Model {model_name} is not available. Available models are {list(models_dict.keys())}."
)

if not Path(base_path / f"{model_name}.pth").exists():
full_path = os.path.join(base_path, f"{model_name}.pth")
if not Path(full_path).exists():
print(f"Downloading {model_name} from {models_dict[model_name]}")
download_url_to_file(
models_dict[model_name], os.path.join(base_path, f"{model_name}.pth")
)
download_url_to_file(models_dict[model_name], full_path)
print("Downloaded model {model_name} to {base_path}.")
return
3 changes: 2 additions & 1 deletion src/cellmap_models/pytorch/cellpose/load_model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from pathlib import Path
import torch
from .get_model import get_model
Expand All @@ -23,6 +24,6 @@ def load_model(
if device == "cuda" and not torch.cuda.is_available():
device = "cpu"
print("CUDA not available. Using CPU.")
model = torch.jit.load(str(base_path / f"{model_name}.pth"), device)
model = torch.jit.load(os.path.join(base_path, f"{model_name}.pth"), device)
model.eval()
return model

0 comments on commit b4d7bac

Please sign in to comment.