diff --git a/pyproject.toml b/pyproject.toml
index 21d8a68..34030a5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -32,6 +32,9 @@ dev = [
'pdoc',
'pre-commit'
]
+pretrained = [
+ 'cellpose[gui]'
+]
[project.urls]
homepage = "https://github.com/janelia-cellmap/cellmap-models"
@@ -39,4 +42,7 @@ repository = "https://github.com/janelia-cellmap/cellmap-models"
[tool.mypy]
exclude = ['setup*']
-ignore_missing_imports = true
\ No newline at end of file
+ignore_missing_imports = true
+
+[project.scripts]
+cellmap.add_cellpose = "cellmap_models.cellpose:add_model"
\ No newline at end of file
diff --git a/src/cellmap_models/__init__.py b/src/cellmap_models/__init__.py
index cd9c858..bbf2be9 100644
--- a/src/cellmap_models/__init__.py
+++ b/src/cellmap_models/__init__.py
@@ -3,4 +3,4 @@
"""
from .utils import download_url_to_file
-from .pytorch import cosem
+from .pytorch import cosem, cellpose
diff --git a/src/cellmap_models/pytorch/__init__.py b/src/cellmap_models/pytorch/__init__.py
index d95c684..5cdcb2d 100755
--- a/src/cellmap_models/pytorch/__init__.py
+++ b/src/cellmap_models/pytorch/__init__.py
@@ -1 +1,2 @@
from . import cosem
+from . import cellpose
diff --git a/src/cellmap_models/pytorch/cellpose/README.md b/src/cellmap_models/pytorch/cellpose/README.md
new file mode 100644
index 0000000..3554924
--- /dev/null
+++ b/src/cellmap_models/pytorch/cellpose/README.md
@@ -0,0 +1,36 @@
+
+
Finetuned Cellpose Models
+
+This directory contains finetuned scripts for downloading Cellpose models, particularly for use with the `cellpose` package. The models are trained on a variety of cell types from CellMap FIBSEM images, and can be used for segmentation of new data.
+
+## Models
+
+...
+
+## Usage
+
+Once you have chosen a model based on the descriptions above, you can download its weights from the `cellmap-models` repository and use them as described below:
+
+If you would like to load a model for your own use, you can do the following:
+
+```python
+from cellmap_models.cellpose import load_model
+model = load_model('')
+```
+
+__If you would like to download and use a Cellpose model with the `cellpose` package or its GUI, do so by following the instructions below.__
+
+First install the `cellpose` package:
+
+```bash
+conda activate cellmap
+pip install cellpose[gui]
+```
+
+Then you can also download model weights from the `cellmap-models` repository and add them to your local `cellpose` model directory. For example, you can run the following commands:
+
+```bash
+cellmap.add_cellpose
+```
+
+where `` is the name of the model you would like to download, based on the descriptions above. For example, to download the `...
diff --git a/src/cellmap_models/pytorch/cellpose/__init__.py b/src/cellmap_models/pytorch/cellpose/__init__.py
new file mode 100644
index 0000000..fd75fad
--- /dev/null
+++ b/src/cellmap_models/pytorch/cellpose/__init__.py
@@ -0,0 +1,2 @@
+from .add_model import add_model
+from .load_model import load_model
diff --git a/src/cellmap_models/pytorch/cellpose/add_model.py b/src/cellmap_models/pytorch/cellpose/add_model.py
new file mode 100644
index 0000000..7246750
--- /dev/null
+++ b/src/cellmap_models/pytorch/cellpose/add_model.py
@@ -0,0 +1,29 @@
+from . import models_dict
+from cellpose.io import _add_model
+from cellpose.models import MODEL_DIR
+from cellpose.utils import download_url_to_file
+
+
+def add_model(model_name: str):
+ """Add model to cellpose
+
+ Args:
+ model_name (str): model name
+ """
+ # download model to cellpose directory
+ if model_name not in models_dict:
+ raise ValueError(
+ f"Model {model_name} is not available. Available models are {list(models_dict.keys())}."
+ )
+ base_path = MODEL_DIR
+
+ if not (base_path / f"{model_name}.pth").exists():
+ print(f"Downloading {model_name} from {models_dict[model_name]}")
+ download_url_to_file(
+ models_dict[model_name], str(base_path / f"{model_name}.pth")
+ )
+ _add_model(str(base_path / f"{model_name}.pth"))
+ print(
+ f"Added model {model_name}. This will now be available in the cellpose model list."
+ )
+ return
diff --git a/src/cellmap_models/pytorch/cellpose/load_model.py b/src/cellmap_models/pytorch/cellpose/load_model.py
new file mode 100644
index 0000000..9b11bc6
--- /dev/null
+++ b/src/cellmap_models/pytorch/cellpose/load_model.py
@@ -0,0 +1,36 @@
+from pathlib import Path
+from . import models_dict
+from cellmap_models.utils import download_url_to_file
+import torch
+
+
+def load_model(
+ model_name: str,
+ base_path: str = f"{Path(__file__).parent}/models",
+ device: str = "cuda",
+):
+ """Load model
+
+ Args:
+ model_name (str): model name
+ base_path (str, optional): base path to store Torchscript model. Defaults to "./models".
+ device (str, optional): device. Defaults to "cuda".
+
+ Returns:
+ model: model
+ """
+ if model_name not in models_dict:
+ raise ValueError(
+ f"Model {model_name} is not available. Available models are {list(models_dict.keys())}."
+ )
+ if not (base_path / f"{model_name}.pth").exists():
+ print(f"Downloading {model_name} from {models_dict[model_name]}")
+ download_url_to_file(
+ models_dict[model_name], str(base_path / f"{model_name}.pth")
+ )
+ if device == "cuda" and not torch.cuda.is_available():
+ device = "cpu"
+ print("CUDA not available. Using CPU.")
+ model = torch.jit.load(str(base_path / f"{model_name}.pth"), device)
+ model.eval()
+ return model
diff --git a/src/cellmap_models/pytorch/cosem/load_model.py b/src/cellmap_models/pytorch/cosem/load_model.py
index 6930bad..c42238a 100755
--- a/src/cellmap_models/pytorch/cosem/load_model.py
+++ b/src/cellmap_models/pytorch/cosem/load_model.py
@@ -30,6 +30,12 @@ def get_param_dict(model_params):
def load_model(checkpoint_path):
+ """
+ Load a model from a checkpoint file.
+
+ Args:
+ checkpoint_path (str): Path to the checkpoint file.
+ """
if not Path(checkpoint_path).exists():
checkpoint_path = Path(Path(__file__).parent / checkpoint_path)
model_params = SourceFileLoader(
diff --git a/src/cellmap_models/utils.py b/src/cellmap_models/utils.py
index 45b3ad9..3cf1d4d 100644
--- a/src/cellmap_models/utils.py
+++ b/src/cellmap_models/utils.py
@@ -7,7 +7,8 @@
def download_url_to_file(url, dst, progress=True):
- r"""Download object at the given URL to a local path.
+ # Originally from CellPose
+ """Download object at the given URL to a local path.
Thanks to torch, slightly modified
Args:
url (string): URL of the object to download