diff --git a/src/cellmap_models/pytorch/cosem/load_checkpoint.py b/src/cellmap_models/pytorch/cosem/load_checkpoint.py index 081a1c4..51f19d4 100755 --- a/src/cellmap_models/pytorch/cosem/load_checkpoint.py +++ b/src/cellmap_models/pytorch/cosem/load_checkpoint.py @@ -2,8 +2,7 @@ from cellmap_models import download_url_to_file - -def download_checkpoint(checkpoint_name: str,local_folder: Path): +def download_checkpoint(checkpoint_name: str, local_folder: Path): """ download models checkpoint from s3 bucket. @@ -13,16 +12,16 @@ def download_checkpoint(checkpoint_name: str,local_folder: Path): return: checkpoint_path (Path): Path to the downloaded checkpoint. """ - from . import models_dict, models_list, model_names # avoid circular import + from . import models_dict, models_list # avoid circular import # Make sure the checkpoint exists if checkpoint_name not in models_list: raise ValueError( f"Checkpoint {checkpoint_name} not found. Available checkpoints: {models_list}" ) - + checkpoint_path = Path( - local_folder/ Path(checkpoint_name.replace(".", "_")) + local_folder / Path(checkpoint_name.replace(".", "_")) ).with_suffix(".pth") if not checkpoint_path.exists(): url = models_dict[checkpoint_name]