Skip to content

Commit

Permalink
style: 🎨 Clean up model (down)loading
Browse files Browse the repository at this point in the history
  • Loading branch information
rhoadesScholar committed Mar 12, 2024
1 parent 7f5de47 commit a28444a
Show file tree
Hide file tree
Showing 14 changed files with 51 additions and 83 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ cython_debug/

# Pytorch
*.pth
*.pt

# Miscellaneous
.DS_Store
Expand Down
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@ This package contains the models used for segmention by the CellMap project team

## Installation

We strongly recommend using a [conda](https://docs.anaconda.com/free/miniconda/#quick-command-line-install) (or [mamba](https://mamba.readthedocs.io/en/latest/installation/micromamba-installation.html#automatic-install)) environment to install the package. First, clone the repository and then create a new environment with the required dependencies:

```bash
git clone https://github.com/janelia-cellmap/cellmap-models
cd cellmap-models
conda env create -n cellmap python=3.10
conda env create -n cellmap python=3.10 pytorch -c pytorch -c conda-forge
conda activate cellmap
pip install .
```
Expand Down
Binary file modified src/cellmap_models/__pycache__/utils.cpython-310.pyc
Binary file not shown.
2 changes: 1 addition & 1 deletion src/cellmap_models/pytorch/cellpose/add_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def add_model(model_name: Optional[str] = None):
model_name = sys.argv[1]
base_path = MODEL_DIR
get_model(model_name, base_path)
_add_model(os.path.join(base_path, f"{model_name}.pth"))
_add_model(os.path.join(base_path, f"{model_name}.pt"))
print(
f"Added model {model_name}. This will now be available in the cellpose model list."
)
Expand Down
11 changes: 7 additions & 4 deletions src/cellmap_models/pytorch/cellpose/get_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import os
from pathlib import Path
from cellpose.utils import download_url_to_file

from cellmap_models import download_url_to_file

# from cellpose.utils import download_url_to_file


def get_model(
Expand All @@ -13,16 +16,16 @@ def get_model(
model_name (str): model name
base_path (str, optional): base path to store Torchscript model. Defaults to "./models".
"""
from . import models_dict
from . import models_dict # avoid circular import

# download model to cellpose directory
if model_name not in models_dict:
raise ValueError(
f"Model {model_name} is not available. Available models are {list(models_dict.keys())}."
)
full_path = os.path.join(base_path, f"{model_name}.pth")
full_path = os.path.join(base_path, f"{model_name}.pt")
if not Path(full_path).exists():
print(f"Downloading {model_name} from {models_dict[model_name]}")
download_url_to_file(models_dict[model_name], full_path)
print("Downloaded model {model_name} to {base_path}.")
print(f"Downloaded model {model_name} to {base_path}.")
return
2 changes: 1 addition & 1 deletion src/cellmap_models/pytorch/cellpose/load_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,6 @@ def load_model(
if device == "cuda" and not torch.cuda.is_available():
device = "cpu"
print("CUDA not available. Using CPU.")
model = torch.jit.load(os.path.join(base_path, f"{model_name}.pth"), device)
model = torch.jit.load(os.path.join(base_path, f"{model_name}.pt"), device)
model.eval()
return model
28 changes: 15 additions & 13 deletions src/cellmap_models/pytorch/cosem/__init__.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
from .load_model import load_model

models_list = [
"setup04/1820500.pth",
"setup04/625000.pth",
"setup04/975000.pth",
"setup26.1/2580000.pth",
"setup26.1/650000.pth",
"setup28/1440000.pth",
"setup28/775000.pth",
"setup36/1100000.pth",
"setup36/500000.pth",
"setup45/1634500.pth",
"setup45/625000.pth",
]
models_dict = {
"setup04/1820500": "https://janelia-cosem-networks.s3.amazonaws.com/v0003.2-pytorch/cosem_models/cosem_models/setup04/1820500",
"setup04/625000": "https://janelia-cosem-networks.s3.amazonaws.com/v0003.2-pytorch/cosem_models/cosem_models/setup04/625000",
"setup04/975000": "https://janelia-cosem-networks.s3.amazonaws.com/v0003.2-pytorch/cosem_models/cosem_models/setup04/975000",
"setup26.1/2580000": "https://janelia-cosem-networks.s3.amazonaws.com/v0003.2-pytorch/cosem_models/cosem_models/setup26.1/2580000",
"setup26.1/650000": "https://janelia-cosem-networks.s3.amazonaws.com/v0003.2-pytorch/cosem_models/cosem_models/setup26.1/650000",
"setup28/1440000": "https://janelia-cosem-networks.s3.amazonaws.com/v0003.2-pytorch/cosem_models/cosem_models/setup28/1440000",
"setup28/775000": "https://janelia-cosem-networks.s3.amazonaws.com/v0003.2-pytorch/cosem_models/cosem_models/setup28/775000",
"setup36/1100000": "https://janelia-cosem-networks.s3.amazonaws.com/v0003.2-pytorch/cosem_models/cosem_models/setup36/1100000",
"setup36/500000": "https://janelia-cosem-networks.s3.amazonaws.com/v0003.2-pytorch/cosem_models/cosem_models/setup36/500000",
"setup45/1634500": "https://janelia-cosem-networks.s3.amazonaws.com/v0003.2-pytorch/cosem_models/cosem_models/setup45/1634500",
"setup45/625000": "https://janelia-cosem-networks.s3.amazonaws.com/v0003.2-pytorch/cosem_models/cosem_models/setup45/625000",
}

models_list = list(models_dict.keys())
24 changes: 20 additions & 4 deletions src/cellmap_models/pytorch/cosem/load_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from torch import nn
from pathlib import Path
from importlib.machinery import SourceFileLoader
from cellmap_models import download_url_to_file

default_params = {
"in_channels": 1,
Expand All @@ -29,21 +30,36 @@ def get_param_dict(model_params):
return param_dict


def load_model(checkpoint_path):
def load_model(checkpoint_name):
"""
Load a model from a checkpoint file.
Args:
checkpoint_path (str): Path to the checkpoint file.
checkpoint_name (str): Name of the checkpoint file.
"""
if not Path(checkpoint_path).exists():
checkpoint_path = Path(Path(__file__).parent / checkpoint_path)
from . import models_dict, models_list # avoid circular import

# Make sure the checkpoint exists
if (
checkpoint_name not in models_dict
and Path(checkpoint_name).with_suffix(".pth") not in models_list
):
raise ValueError(f"Model {checkpoint_name} not found")
checkpoint_path = Path(Path(__file__).parent / Path(checkpoint_name)).with_suffix(
".pth"
)
if not checkpoint_path.exists():
url = models_dict[checkpoint_name]
print(f"Downloading {checkpoint_name} from {url}")
download_url_to_file(url, checkpoint_path)

model_params = SourceFileLoader(
"model", str(Path(checkpoint_path).parent / "model.py")
).load_module()

model = Architecture(model_params)

print(f"Loading model from {checkpoint_path}")
checkpoint = torch.load(checkpoint_path)
new_checkpoint = deepcopy(checkpoint)
for key in checkpoint["model"].keys():
Expand Down
11 changes: 0 additions & 11 deletions src/cellmap_models/pytorch/cosem/setup04/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,3 @@
final_feature_width = 12 * 6

classes_out = 14

# download pretrained model checkpoints from s3
urls = {
"1820500": "https://janelia-cosem-networks.s3.amazonaws.com/v0003.2-pytorch/cosem_models/cosem_models/setup04/1820500",
"625000": "https://janelia-cosem-networks.s3.amazonaws.com/v0003.2-pytorch/cosem_models/cosem_models/setup04/625000",
"975000": "https://janelia-cosem-networks.s3.amazonaws.com/v0003.2-pytorch/cosem_models/cosem_models/setup04/975000",
}
for name, url in urls.items():
if not (Path(__file__).parent / f"{name}").exists():
print(f"Downloading {name} from {url}")
download_url_to_file(url, str(Path(__file__).parent / f"{name}.pth"))
10 changes: 0 additions & 10 deletions src/cellmap_models/pytorch/cosem/setup26.1/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,3 @@
final_feature_width = 12 * 6

classes_out = 3

# download pretrained model checkpoints from s3
urls = {
"2580000": "https://janelia-cosem-networks.s3.amazonaws.com/v0003.2-pytorch/cosem_models/cosem_models/setup26.1/2580000",
"650000": "https://janelia-cosem-networks.s3.amazonaws.com/v0003.2-pytorch/cosem_models/cosem_models/setup26.1/650000",
}
for name, url in urls.items():
if not (Path(__file__).parent / f"{name}").exists():
print(f"Downloading {name} from {url}")
download_url_to_file(url, str(Path(__file__).parent / f"{name}.pth"))
10 changes: 0 additions & 10 deletions src/cellmap_models/pytorch/cosem/setup28/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,3 @@
final_feature_width = 12 * 6

classes_out = 2

# download pretrained model checkpoints from s3
urls = {
"1440000": "https://janelia-cosem-networks.s3.amazonaws.com/v0003.2-pytorch/cosem_models/cosem_models/setup28/1440000",
"775000": "https://janelia-cosem-networks.s3.amazonaws.com/v0003.2-pytorch/cosem_models/cosem_models/setup28/775000",
}
for name, url in urls.items():
if not (Path(__file__).parent / f"{name}").exists():
print(f"Downloading {name} from {url}")
download_url_to_file(url, str(Path(__file__).parent / f"{name}.pth"))
10 changes: 0 additions & 10 deletions src/cellmap_models/pytorch/cosem/setup36/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,3 @@
final_feature_width = 12 * 6

classes_out = 2

# download pretrained model checkpoints from s3
urls = {
"1100000": "https://janelia-cosem-networks.s3.amazonaws.com/v0003.2-pytorch/cosem_models/cosem_models/setup36/1100000",
"500000": "https://janelia-cosem-networks.s3.amazonaws.com/v0003.2-pytorch/cosem_models/cosem_models/setup36/500000",
}
for name, url in urls.items():
if not (Path(__file__).parent / f"{name}").exists():
print(f"Downloading {name} from {url}")
download_url_to_file(url, str(Path(__file__).parent / f"{name}.pth"))
10 changes: 0 additions & 10 deletions src/cellmap_models/pytorch/cosem/setup45/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,3 @@
final_feature_width = 12 * 6

classes_out = 2

# download pretrained model checkpoints from s3
urls = {
"1634500": "https://janelia-cosem-networks.s3.amazonaws.com/v0003.2-pytorch/cosem_models/cosem_models/setup45/1634500",
"625000": "https://janelia-cosem-networks.s3.amazonaws.com/v0003.2-pytorch/cosem_models/cosem_models/setup45/625000",
}
for name, url in urls.items():
if not (Path(__file__).parent / f"{name}").exists():
print(f"Downloading {name} from {url}")
download_url_to_file(url, str(Path(__file__).parent / f"{name}.pth"))
11 changes: 3 additions & 8 deletions src/cellmap_models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@


def download_url_to_file(url, dst, progress=True):
# Originally from CellPose
"""Download object at the given URL to a local path.
r"""Download object at the given URL to a local path.
Thanks to torch, slightly modified
Args:
url (string): URL of the object to download
Expand All @@ -31,8 +30,8 @@ def download_url_to_file(url, dst, progress=True):
# We deliberately save it in a temp file and move it after
dst = os.path.expanduser(dst)
dst_dir = os.path.dirname(dst)
f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir)
try:
os.makedirs(dst_dir, exist_ok=True)
with tempfile.NamedTemporaryFile(delete=False, dir=dst_dir) as f:
with tqdm(
total=file_size,
disable=not progress,
Expand All @@ -48,7 +47,3 @@ def download_url_to_file(url, dst, progress=True):
pbar.update(len(buffer))
f.close()
shutil.move(f.name, dst)
finally:
f.close()
if os.path.exists(f.name):
os.remove(f.name)

0 comments on commit a28444a

Please sign in to comment.