From 2e8e1b9a739276dead27af4a9af25783635c0b02 Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Fri, 8 Mar 2024 12:10:14 -0500 Subject: [PATCH] Add cellmap.add_cellpose script and update model loading --- pyproject.toml | 2 +- .../__pycache__/__init__.cpython-310.pyc | Bin 294 -> 313 bytes .../__pycache__/utils.cpython-310.pyc | Bin 1619 -> 1619 bytes .../pytorch/cellpose/__init__.py | 1 + .../pytorch/cellpose/add_model.py | 23 +++++--------- .../pytorch/cellpose/get_model.py | 29 ++++++++++++++++++ .../pytorch/cellpose/load_model.py | 14 ++------- 7 files changed, 42 insertions(+), 27 deletions(-) create mode 100644 src/cellmap_models/pytorch/cellpose/get_model.py diff --git a/pyproject.toml b/pyproject.toml index 34030a5..28771bb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,4 +45,4 @@ exclude = ['setup*'] ignore_missing_imports = true [project.scripts] -cellmap.add_cellpose = "cellmap_models.cellpose:add_model" \ No newline at end of file +"cellmap.add_cellpose" = "cellmap_models.pytorch.cellpose:add_model" diff --git a/src/cellmap_models/__pycache__/__init__.cpython-310.pyc b/src/cellmap_models/__pycache__/__init__.cpython-310.pyc index 9c1ecebc321467a890be7c6792da9e3b3722b982..e649df2c0ccf6054e0a9f07b82b1ca906609b885 100644 GIT binary patch delta 99 zcmZ3+w3CT9pO=@50SNjGU#AvNa(;1Y?k$ew)SR3G yAk$Bi?G}4{d`fuvyTA_x}% delta 28 icmcc2bD4)XpO=@50SGS1yh`od$lJxr$T@ii>uvyYYY1lm diff --git a/src/cellmap_models/pytorch/cellpose/__init__.py b/src/cellmap_models/pytorch/cellpose/__init__.py index c8faf0a..6e0bc43 100644 --- a/src/cellmap_models/pytorch/cellpose/__init__.py +++ b/src/cellmap_models/pytorch/cellpose/__init__.py @@ -1,5 +1,6 @@ from .add_model import add_model from .load_model import load_model +from .get_model import get_model models_dict = { "jrc_mus-epididymis-1_nuc_cp": "https://github.com/janelia-cellmap/cellmap-models/releases/download/2024.03.08/jrc_mus-epididymis-1_nuc_cp", diff --git a/src/cellmap_models/pytorch/cellpose/add_model.py b/src/cellmap_models/pytorch/cellpose/add_model.py index 7246750..d63d68c 100644 --- a/src/cellmap_models/pytorch/cellpose/add_model.py +++ b/src/cellmap_models/pytorch/cellpose/add_model.py @@ -1,27 +1,20 @@ -from . import models_dict -from cellpose.io import _add_model +import sys +from typing import Optional +from cellpose.io import add_model as _add_model from cellpose.models import MODEL_DIR -from cellpose.utils import download_url_to_file +from .get_model import get_model -def add_model(model_name: str): +def add_model(model_name: Optional[str] = None): """Add model to cellpose Args: model_name (str): model name """ - # download model to cellpose directory - if model_name not in models_dict: - raise ValueError( - f"Model {model_name} is not available. Available models are {list(models_dict.keys())}." - ) + if model_name is None: + model_name = sys.argv[1] base_path = MODEL_DIR - - if not (base_path / f"{model_name}.pth").exists(): - print(f"Downloading {model_name} from {models_dict[model_name]}") - download_url_to_file( - models_dict[model_name], str(base_path / f"{model_name}.pth") - ) + get_model(model_name, base_path) _add_model(str(base_path / f"{model_name}.pth")) print( f"Added model {model_name}. This will now be available in the cellpose model list." diff --git a/src/cellmap_models/pytorch/cellpose/get_model.py b/src/cellmap_models/pytorch/cellpose/get_model.py new file mode 100644 index 0000000..0122dca --- /dev/null +++ b/src/cellmap_models/pytorch/cellpose/get_model.py @@ -0,0 +1,29 @@ +from pathlib import Path +from cellpose.utils import download_url_to_file + + +def get_model( + model_name: str, + base_path: str = f"{Path(__file__).parent}/models", +): + """Add model to cellpose + + Args: + model_name (str): model name + base_path (str, optional): base path to store Torchscript model. Defaults to "./models". + """ + from . import models_dict + + # download model to cellpose directory + if model_name not in models_dict: + raise ValueError( + f"Model {model_name} is not available. Available models are {list(models_dict.keys())}." + ) + + if not (base_path / f"{model_name}.pth").exists(): + print(f"Downloading {model_name} from {models_dict[model_name]}") + download_url_to_file( + models_dict[model_name], str(base_path / f"{model_name}.pth") + ) + print("Downloaded model {model_name} to {base_path}.") + return diff --git a/src/cellmap_models/pytorch/cellpose/load_model.py b/src/cellmap_models/pytorch/cellpose/load_model.py index 9b11bc6..e00c6ba 100644 --- a/src/cellmap_models/pytorch/cellpose/load_model.py +++ b/src/cellmap_models/pytorch/cellpose/load_model.py @@ -1,7 +1,6 @@ from pathlib import Path -from . import models_dict -from cellmap_models.utils import download_url_to_file import torch +from .get_model import get_model def load_model( @@ -19,15 +18,8 @@ def load_model( Returns: model: model """ - if model_name not in models_dict: - raise ValueError( - f"Model {model_name} is not available. Available models are {list(models_dict.keys())}." - ) - if not (base_path / f"{model_name}.pth").exists(): - print(f"Downloading {model_name} from {models_dict[model_name]}") - download_url_to_file( - models_dict[model_name], str(base_path / f"{model_name}.pth") - ) + + get_model(model_name, base_path) if device == "cuda" and not torch.cuda.is_available(): device = "cpu" print("CUDA not available. Using CPU.")