diff --git a/src/cellmap_models/pytorch/cellpose/__init__.py b/src/cellmap_models/pytorch/cellpose/__init__.py index 6e0bc43..97a9ec5 100644 --- a/src/cellmap_models/pytorch/cellpose/__init__.py +++ b/src/cellmap_models/pytorch/cellpose/__init__.py @@ -1,6 +1,7 @@ from .add_model import add_model from .load_model import load_model from .get_model import get_model +from .download_checkpoint import download_checkpoint models_dict = { "jrc_mus-epididymis-1_nuc_cp": "https://github.com/janelia-cellmap/cellmap-models/releases/download/2024.03.08/jrc_mus-epididymis-1_nuc_cp", diff --git a/src/cellmap_models/pytorch/cosem/load_checkpoint.py b/src/cellmap_models/pytorch/cellpose/download_checkpoint.py similarity index 100% rename from src/cellmap_models/pytorch/cosem/load_checkpoint.py rename to src/cellmap_models/pytorch/cellpose/download_checkpoint.py diff --git a/src/cellmap_models/pytorch/cosem/.DS_Store b/src/cellmap_models/pytorch/cosem/.DS_Store deleted file mode 100644 index 4ffd9ab..0000000 Binary files a/src/cellmap_models/pytorch/cosem/.DS_Store and /dev/null differ diff --git a/src/cellmap_models/pytorch/cosem/download_checkpoint.py b/src/cellmap_models/pytorch/cosem/download_checkpoint.py new file mode 100755 index 0000000..ea406bd --- /dev/null +++ b/src/cellmap_models/pytorch/cosem/download_checkpoint.py @@ -0,0 +1,30 @@ +from pathlib import Path +from cellmap_models import download_url_to_file + + +def download_checkpoint(checkpoint_name: str, checkpoint_path: Path): + """ + download models checkpoint from s3 bucket. + + Args: + checkpoint_name (str): Name of the checkpoint file. + local_folder (Path): Local path to save the checkpoint. + return: + checkpoint_path (Path): Path to the downloaded checkpoint. + """ + from . import models_dict, models_list # avoid circular import + + # Make sure the checkpoint exists + if checkpoint_name not in models_list: + raise ValueError( + f"Checkpoint {checkpoint_name} not found. Available checkpoints: {models_list}" + ) + + if not checkpoint_path.exists(): + url = models_dict[checkpoint_name] + print(f"Downloading {checkpoint_name} from {url}") + download_url_to_file(url, checkpoint_path) + else: + print(f"Checkpoint {checkpoint_name} found at {checkpoint_path}") + + return checkpoint_path