diff --git a/pyproject.toml b/pyproject.toml index 21d8a68..34030a5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,9 @@ dev = [ 'pdoc', 'pre-commit' ] +pretrained = [ + 'cellpose[gui]' +] [project.urls] homepage = "https://github.com/janelia-cellmap/cellmap-models" @@ -39,4 +42,7 @@ repository = "https://github.com/janelia-cellmap/cellmap-models" [tool.mypy] exclude = ['setup*'] -ignore_missing_imports = true \ No newline at end of file +ignore_missing_imports = true + +[project.scripts] +cellmap.add_cellpose = "cellmap_models.cellpose:add_model" \ No newline at end of file diff --git a/src/cellmap_models/__init__.py b/src/cellmap_models/__init__.py index cd9c858..bbf2be9 100644 --- a/src/cellmap_models/__init__.py +++ b/src/cellmap_models/__init__.py @@ -3,4 +3,4 @@ """ from .utils import download_url_to_file -from .pytorch import cosem +from .pytorch import cosem, cellpose diff --git a/src/cellmap_models/pytorch/__init__.py b/src/cellmap_models/pytorch/__init__.py index d95c684..5cdcb2d 100755 --- a/src/cellmap_models/pytorch/__init__.py +++ b/src/cellmap_models/pytorch/__init__.py @@ -1 +1,2 @@ from . import cosem +from . import cellpose diff --git a/src/cellmap_models/pytorch/cellpose/README.md b/src/cellmap_models/pytorch/cellpose/README.md new file mode 100644 index 0000000..3554924 --- /dev/null +++ b/src/cellmap_models/pytorch/cellpose/README.md @@ -0,0 +1,36 @@ + +

Finetuned Cellpose Models cellpose logo

+ +This directory contains finetuned scripts for downloading Cellpose models, particularly for use with the `cellpose` package. The models are trained on a variety of cell types from CellMap FIBSEM images, and can be used for segmentation of new data. + +## Models + +... + +## Usage + +Once you have chosen a model based on the descriptions above, you can download its weights from the `cellmap-models` repository and use them as described below: + +If you would like to load a model for your own use, you can do the following: + +```python +from cellmap_models.cellpose import load_model +model = load_model('') +``` + +__If you would like to download and use a Cellpose model with the `cellpose` package or its GUI, do so by following the instructions below.__ + +First install the `cellpose` package: + +```bash +conda activate cellmap +pip install cellpose[gui] +``` + +Then you can also download model weights from the `cellmap-models` repository and add them to your local `cellpose` model directory. For example, you can run the following commands: + +```bash +cellmap.add_cellpose +``` + +where `` is the name of the model you would like to download, based on the descriptions above. For example, to download the `... diff --git a/src/cellmap_models/pytorch/cellpose/__init__.py b/src/cellmap_models/pytorch/cellpose/__init__.py new file mode 100644 index 0000000..fd75fad --- /dev/null +++ b/src/cellmap_models/pytorch/cellpose/__init__.py @@ -0,0 +1,2 @@ +from .add_model import add_model +from .load_model import load_model diff --git a/src/cellmap_models/pytorch/cellpose/add_model.py b/src/cellmap_models/pytorch/cellpose/add_model.py new file mode 100644 index 0000000..7246750 --- /dev/null +++ b/src/cellmap_models/pytorch/cellpose/add_model.py @@ -0,0 +1,29 @@ +from . import models_dict +from cellpose.io import _add_model +from cellpose.models import MODEL_DIR +from cellpose.utils import download_url_to_file + + +def add_model(model_name: str): + """Add model to cellpose + + Args: + model_name (str): model name + """ + # download model to cellpose directory + if model_name not in models_dict: + raise ValueError( + f"Model {model_name} is not available. Available models are {list(models_dict.keys())}." + ) + base_path = MODEL_DIR + + if not (base_path / f"{model_name}.pth").exists(): + print(f"Downloading {model_name} from {models_dict[model_name]}") + download_url_to_file( + models_dict[model_name], str(base_path / f"{model_name}.pth") + ) + _add_model(str(base_path / f"{model_name}.pth")) + print( + f"Added model {model_name}. This will now be available in the cellpose model list." + ) + return diff --git a/src/cellmap_models/pytorch/cellpose/load_model.py b/src/cellmap_models/pytorch/cellpose/load_model.py new file mode 100644 index 0000000..9b11bc6 --- /dev/null +++ b/src/cellmap_models/pytorch/cellpose/load_model.py @@ -0,0 +1,36 @@ +from pathlib import Path +from . import models_dict +from cellmap_models.utils import download_url_to_file +import torch + + +def load_model( + model_name: str, + base_path: str = f"{Path(__file__).parent}/models", + device: str = "cuda", +): + """Load model + + Args: + model_name (str): model name + base_path (str, optional): base path to store Torchscript model. Defaults to "./models". + device (str, optional): device. Defaults to "cuda". + + Returns: + model: model + """ + if model_name not in models_dict: + raise ValueError( + f"Model {model_name} is not available. Available models are {list(models_dict.keys())}." + ) + if not (base_path / f"{model_name}.pth").exists(): + print(f"Downloading {model_name} from {models_dict[model_name]}") + download_url_to_file( + models_dict[model_name], str(base_path / f"{model_name}.pth") + ) + if device == "cuda" and not torch.cuda.is_available(): + device = "cpu" + print("CUDA not available. Using CPU.") + model = torch.jit.load(str(base_path / f"{model_name}.pth"), device) + model.eval() + return model diff --git a/src/cellmap_models/pytorch/cosem/load_model.py b/src/cellmap_models/pytorch/cosem/load_model.py index 6930bad..c42238a 100755 --- a/src/cellmap_models/pytorch/cosem/load_model.py +++ b/src/cellmap_models/pytorch/cosem/load_model.py @@ -30,6 +30,12 @@ def get_param_dict(model_params): def load_model(checkpoint_path): + """ + Load a model from a checkpoint file. + + Args: + checkpoint_path (str): Path to the checkpoint file. + """ if not Path(checkpoint_path).exists(): checkpoint_path = Path(Path(__file__).parent / checkpoint_path) model_params = SourceFileLoader( diff --git a/src/cellmap_models/utils.py b/src/cellmap_models/utils.py index 45b3ad9..3cf1d4d 100644 --- a/src/cellmap_models/utils.py +++ b/src/cellmap_models/utils.py @@ -7,7 +7,8 @@ def download_url_to_file(url, dst, progress=True): - r"""Download object at the given URL to a local path. + # Originally from CellPose + """Download object at the given URL to a local path. Thanks to torch, slightly modified Args: url (string): URL of the object to download