-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: ✨ Add Cellpose model loading utilities.
- Loading branch information
1 parent
4f59a8d
commit a0af74f
Showing
9 changed files
with
120 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,4 +3,4 @@ | |
""" | ||
|
||
from .utils import download_url_to_file | ||
from .pytorch import cosem | ||
from .pytorch import cosem, cellpose |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,2 @@ | ||
from . import cosem | ||
from . import cellpose |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
<!-- FILEPATH: /Users/rhoadesj/Repos/cellmap-models/src/cellmap_models/pytorch/cellpose/README.md --> | ||
<h1 style="height: 2em;">Finetuned Cellpose Models <img src="https://www.cellpose.org/static/images/cellpose_transparent.png" alt="cellpose logo"></h1> | ||
|
||
This directory contains finetuned scripts for downloading Cellpose models, particularly for use with the `cellpose` package. The models are trained on a variety of cell types from CellMap FIBSEM images, and can be used for segmentation of new data. | ||
|
||
## Models | ||
|
||
... | ||
|
||
## Usage | ||
|
||
Once you have chosen a model based on the descriptions above, you can download its weights from the `cellmap-models` repository and use them as described below: | ||
|
||
If you would like to load a model for your own use, you can do the following: | ||
|
||
```python | ||
from cellmap_models.cellpose import load_model | ||
model = load_model('<model_name>') | ||
``` | ||
|
||
__If you would like to download and use a Cellpose model with the `cellpose` package or its GUI, do so by following the instructions below.__ | ||
|
||
First install the `cellpose` package: | ||
|
||
```bash | ||
conda activate cellmap | ||
pip install cellpose[gui] | ||
``` | ||
|
||
Then you can also download model weights from the `cellmap-models` repository and add them to your local `cellpose` model directory. For example, you can run the following commands: | ||
|
||
```bash | ||
cellmap.add_cellpose <model_name> | ||
``` | ||
|
||
where `<model_name>` is the name of the model you would like to download, based on the descriptions above. For example, to download the `... |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .add_model import add_model | ||
from .load_model import load_model |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
from . import models_dict | ||
from cellpose.io import _add_model | ||
from cellpose.models import MODEL_DIR | ||
from cellpose.utils import download_url_to_file | ||
|
||
|
||
def add_model(model_name: str): | ||
"""Add model to cellpose | ||
Args: | ||
model_name (str): model name | ||
""" | ||
# download model to cellpose directory | ||
if model_name not in models_dict: | ||
raise ValueError( | ||
f"Model {model_name} is not available. Available models are {list(models_dict.keys())}." | ||
) | ||
base_path = MODEL_DIR | ||
|
||
if not (base_path / f"{model_name}.pth").exists(): | ||
print(f"Downloading {model_name} from {models_dict[model_name]}") | ||
download_url_to_file( | ||
models_dict[model_name], str(base_path / f"{model_name}.pth") | ||
) | ||
_add_model(str(base_path / f"{model_name}.pth")) | ||
print( | ||
f"Added model {model_name}. This will now be available in the cellpose model list." | ||
) | ||
return |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
from pathlib import Path | ||
from . import models_dict | ||
from cellmap_models.utils import download_url_to_file | ||
import torch | ||
|
||
|
||
def load_model( | ||
model_name: str, | ||
base_path: str = f"{Path(__file__).parent}/models", | ||
device: str = "cuda", | ||
): | ||
"""Load model | ||
Args: | ||
model_name (str): model name | ||
base_path (str, optional): base path to store Torchscript model. Defaults to "./models". | ||
device (str, optional): device. Defaults to "cuda". | ||
Returns: | ||
model: model | ||
""" | ||
if model_name not in models_dict: | ||
raise ValueError( | ||
f"Model {model_name} is not available. Available models are {list(models_dict.keys())}." | ||
) | ||
if not (base_path / f"{model_name}.pth").exists(): | ||
print(f"Downloading {model_name} from {models_dict[model_name]}") | ||
download_url_to_file( | ||
models_dict[model_name], str(base_path / f"{model_name}.pth") | ||
) | ||
if device == "cuda" and not torch.cuda.is_available(): | ||
device = "cpu" | ||
print("CUDA not available. Using CPU.") | ||
model = torch.jit.load(str(base_path / f"{model_name}.pth"), device) | ||
model.eval() | ||
return model |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters