Skip to content

Commit

Permalink
support cellpose checkpoint download
Browse files Browse the repository at this point in the history
  • Loading branch information
mzouink committed Mar 18, 2024
1 parent 599b7ec commit d85f13a
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/cellmap_models/pytorch/cellpose/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .add_model import add_model
from .load_model import load_model
from .get_model import get_model
from .download_checkpoint import download_checkpoint

models_dict = {
"jrc_mus-epididymis-1_nuc_cp": "https://github.com/janelia-cellmap/cellmap-models/releases/download/2024.03.08/jrc_mus-epididymis-1_nuc_cp",
Expand Down
Binary file removed src/cellmap_models/pytorch/cosem/.DS_Store
Binary file not shown.
30 changes: 30 additions & 0 deletions src/cellmap_models/pytorch/cosem/download_checkpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from pathlib import Path
from cellmap_models import download_url_to_file


def download_checkpoint(checkpoint_name: str, checkpoint_path: Path):
"""
download models checkpoint from s3 bucket.
Args:
checkpoint_name (str): Name of the checkpoint file.
local_folder (Path): Local path to save the checkpoint.
return:
checkpoint_path (Path): Path to the downloaded checkpoint.
"""
from . import models_dict, models_list # avoid circular import

# Make sure the checkpoint exists
if checkpoint_name not in models_list:
raise ValueError(
f"Checkpoint {checkpoint_name} not found. Available checkpoints: {models_list}"
)

if not checkpoint_path.exists():
url = models_dict[checkpoint_name]
print(f"Downloading {checkpoint_name} from {url}")
download_url_to_file(url, checkpoint_path)
else:
print(f"Checkpoint {checkpoint_name} found at {checkpoint_path}")

return checkpoint_path

0 comments on commit d85f13a

Please sign in to comment.