From a28444a27b22ee549949aaab3298f24e29161b3d Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Tue, 12 Mar 2024 11:09:08 -0400 Subject: [PATCH] =?UTF-8?q?style:=20=F0=9F=8E=A8=20Clean=20up=20model=20(d?= =?UTF-8?q?own)loading?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 1 + README.md | 4 ++- .../__pycache__/utils.cpython-310.pyc | Bin 1619 -> 1579 bytes .../pytorch/cellpose/add_model.py | 2 +- .../pytorch/cellpose/get_model.py | 11 ++++--- .../pytorch/cellpose/load_model.py | 2 +- src/cellmap_models/pytorch/cosem/__init__.py | 28 ++++++++++-------- .../pytorch/cosem/load_model.py | 24 ++++++++++++--- .../pytorch/cosem/setup04/model.py | 11 ------- .../pytorch/cosem/setup26.1/model.py | 10 ------- .../pytorch/cosem/setup28/model.py | 10 ------- .../pytorch/cosem/setup36/model.py | 10 ------- .../pytorch/cosem/setup45/model.py | 10 ------- src/cellmap_models/utils.py | 11 ++----- 14 files changed, 51 insertions(+), 83 deletions(-) diff --git a/.gitignore b/.gitignore index 01e50ad..eb2920e 100644 --- a/.gitignore +++ b/.gitignore @@ -161,6 +161,7 @@ cython_debug/ # Pytorch *.pth +*.pt # Miscellaneous .DS_Store diff --git a/README.md b/README.md index ff5e2eb..261420b 100644 --- a/README.md +++ b/README.md @@ -18,10 +18,12 @@ This package contains the models used for segmention by the CellMap project team ## Installation +We strongly recommend using a [conda](https://docs.anaconda.com/free/miniconda/#quick-command-line-install) (or [mamba](https://mamba.readthedocs.io/en/latest/installation/micromamba-installation.html#automatic-install)) environment to install the package. First, clone the repository and then create a new environment with the required dependencies: + ```bash git clone https://github.com/janelia-cellmap/cellmap-models cd cellmap-models -conda env create -n cellmap python=3.10 +conda env create -n cellmap python=3.10 pytorch -c pytorch -c conda-forge conda activate cellmap pip install . ``` diff --git a/src/cellmap_models/__pycache__/utils.cpython-310.pyc b/src/cellmap_models/__pycache__/utils.cpython-310.pyc index b067da1065c4ab5318e8c869ad18ac5499e14c90..ee0ed8d90c502bc4f7a146695a86e272725cde0a 100644 GIT binary patch delta 357 zcmYk2Jxc>Y5Qca5_L9B5Jrgx*QYqKSmG;^M5j%@SB$WueU=mI6a-hWqh15p~iwIVC zSc#SRQ^e-}fUR~`;)FB~%-1{fz%cLfMFztl5CrYK`=yU&!9DEGp6+f%Z^>}ws!~dU zB$ZxOm9gh3ogR+8WZ5`|N}@d65Xw*8X*d=hF>bb^WDa9L;l`hUqJ0|cL14UW)%ZC8 zF(RuiX7lc|uc| zsJRveZ;Z}hs~thX^-Xt_oDTa35iLqeU+6>^j_Sq*H0`sq2Mzn_r2O)%uTdP?n6>N& z3qjZ)_PDu^zhNE+PSHVt1VL*UC6%xQRr~60*Te5P#I*?8Q8}()6*uOgW`Et|Kg73D A3IG5A delta 379 zcmY*VJ4*vW5Z>8)g|v@ZtK(3-d(d2GIw& zXI?Z|fh!z4h~$VtOkxeRn`+N85!d1wVkgkEnZfNo5PXQ~e%s|{T;vi5#I1OC%4O3@ zt%(SzNJV`GP}^8V-P1_oneWPwd!0H0KH>YyfE!S zp56J5pA-L-&wrH>ijeQTe83P&n?6o;5*!6ANEOneDQwimNLhzPF;)g5RJZe~9-Ozj ztefgIIPRQ@S9J+Y@vh$MHTgEx5|>&9u(;J$mJRGd0n~{x@xTK-p;f8yq3-9XCp%is JQc=8VjW5?0U^4&! diff --git a/src/cellmap_models/pytorch/cellpose/add_model.py b/src/cellmap_models/pytorch/cellpose/add_model.py index d51a13d..969fa01 100644 --- a/src/cellmap_models/pytorch/cellpose/add_model.py +++ b/src/cellmap_models/pytorch/cellpose/add_model.py @@ -16,7 +16,7 @@ def add_model(model_name: Optional[str] = None): model_name = sys.argv[1] base_path = MODEL_DIR get_model(model_name, base_path) - _add_model(os.path.join(base_path, f"{model_name}.pth")) + _add_model(os.path.join(base_path, f"{model_name}.pt")) print( f"Added model {model_name}. This will now be available in the cellpose model list." ) diff --git a/src/cellmap_models/pytorch/cellpose/get_model.py b/src/cellmap_models/pytorch/cellpose/get_model.py index b909895..0a43e48 100644 --- a/src/cellmap_models/pytorch/cellpose/get_model.py +++ b/src/cellmap_models/pytorch/cellpose/get_model.py @@ -1,6 +1,9 @@ import os from pathlib import Path -from cellpose.utils import download_url_to_file + +from cellmap_models import download_url_to_file + +# from cellpose.utils import download_url_to_file def get_model( @@ -13,16 +16,16 @@ def get_model( model_name (str): model name base_path (str, optional): base path to store Torchscript model. Defaults to "./models". """ - from . import models_dict + from . import models_dict # avoid circular import # download model to cellpose directory if model_name not in models_dict: raise ValueError( f"Model {model_name} is not available. Available models are {list(models_dict.keys())}." ) - full_path = os.path.join(base_path, f"{model_name}.pth") + full_path = os.path.join(base_path, f"{model_name}.pt") if not Path(full_path).exists(): print(f"Downloading {model_name} from {models_dict[model_name]}") download_url_to_file(models_dict[model_name], full_path) - print("Downloaded model {model_name} to {base_path}.") + print(f"Downloaded model {model_name} to {base_path}.") return diff --git a/src/cellmap_models/pytorch/cellpose/load_model.py b/src/cellmap_models/pytorch/cellpose/load_model.py index 63b41a3..b39dd7e 100644 --- a/src/cellmap_models/pytorch/cellpose/load_model.py +++ b/src/cellmap_models/pytorch/cellpose/load_model.py @@ -24,6 +24,6 @@ def load_model( if device == "cuda" and not torch.cuda.is_available(): device = "cpu" print("CUDA not available. Using CPU.") - model = torch.jit.load(os.path.join(base_path, f"{model_name}.pth"), device) + model = torch.jit.load(os.path.join(base_path, f"{model_name}.pt"), device) model.eval() return model diff --git a/src/cellmap_models/pytorch/cosem/__init__.py b/src/cellmap_models/pytorch/cosem/__init__.py index a07c6c8..58eac8c 100755 --- a/src/cellmap_models/pytorch/cosem/__init__.py +++ b/src/cellmap_models/pytorch/cosem/__init__.py @@ -1,15 +1,17 @@ from .load_model import load_model -models_list = [ - "setup04/1820500.pth", - "setup04/625000.pth", - "setup04/975000.pth", - "setup26.1/2580000.pth", - "setup26.1/650000.pth", - "setup28/1440000.pth", - "setup28/775000.pth", - "setup36/1100000.pth", - "setup36/500000.pth", - "setup45/1634500.pth", - "setup45/625000.pth", -] +models_dict = { + "setup04/1820500": "https://janelia-cosem-networks.s3.amazonaws.com/v0003.2-pytorch/cosem_models/cosem_models/setup04/1820500", + "setup04/625000": "https://janelia-cosem-networks.s3.amazonaws.com/v0003.2-pytorch/cosem_models/cosem_models/setup04/625000", + "setup04/975000": "https://janelia-cosem-networks.s3.amazonaws.com/v0003.2-pytorch/cosem_models/cosem_models/setup04/975000", + "setup26.1/2580000": "https://janelia-cosem-networks.s3.amazonaws.com/v0003.2-pytorch/cosem_models/cosem_models/setup26.1/2580000", + "setup26.1/650000": "https://janelia-cosem-networks.s3.amazonaws.com/v0003.2-pytorch/cosem_models/cosem_models/setup26.1/650000", + "setup28/1440000": "https://janelia-cosem-networks.s3.amazonaws.com/v0003.2-pytorch/cosem_models/cosem_models/setup28/1440000", + "setup28/775000": "https://janelia-cosem-networks.s3.amazonaws.com/v0003.2-pytorch/cosem_models/cosem_models/setup28/775000", + "setup36/1100000": "https://janelia-cosem-networks.s3.amazonaws.com/v0003.2-pytorch/cosem_models/cosem_models/setup36/1100000", + "setup36/500000": "https://janelia-cosem-networks.s3.amazonaws.com/v0003.2-pytorch/cosem_models/cosem_models/setup36/500000", + "setup45/1634500": "https://janelia-cosem-networks.s3.amazonaws.com/v0003.2-pytorch/cosem_models/cosem_models/setup45/1634500", + "setup45/625000": "https://janelia-cosem-networks.s3.amazonaws.com/v0003.2-pytorch/cosem_models/cosem_models/setup45/625000", +} + +models_list = list(models_dict.keys()) diff --git a/src/cellmap_models/pytorch/cosem/load_model.py b/src/cellmap_models/pytorch/cosem/load_model.py index c42238a..c44d691 100755 --- a/src/cellmap_models/pytorch/cosem/load_model.py +++ b/src/cellmap_models/pytorch/cosem/load_model.py @@ -6,6 +6,7 @@ from torch import nn from pathlib import Path from importlib.machinery import SourceFileLoader +from cellmap_models import download_url_to_file default_params = { "in_channels": 1, @@ -29,21 +30,36 @@ def get_param_dict(model_params): return param_dict -def load_model(checkpoint_path): +def load_model(checkpoint_name): """ Load a model from a checkpoint file. Args: - checkpoint_path (str): Path to the checkpoint file. + checkpoint_name (str): Name of the checkpoint file. """ - if not Path(checkpoint_path).exists(): - checkpoint_path = Path(Path(__file__).parent / checkpoint_path) + from . import models_dict, models_list # avoid circular import + + # Make sure the checkpoint exists + if ( + checkpoint_name not in models_dict + and Path(checkpoint_name).with_suffix(".pth") not in models_list + ): + raise ValueError(f"Model {checkpoint_name} not found") + checkpoint_path = Path(Path(__file__).parent / Path(checkpoint_name)).with_suffix( + ".pth" + ) + if not checkpoint_path.exists(): + url = models_dict[checkpoint_name] + print(f"Downloading {checkpoint_name} from {url}") + download_url_to_file(url, checkpoint_path) + model_params = SourceFileLoader( "model", str(Path(checkpoint_path).parent / "model.py") ).load_module() model = Architecture(model_params) + print(f"Loading model from {checkpoint_path}") checkpoint = torch.load(checkpoint_path) new_checkpoint = deepcopy(checkpoint) for key in checkpoint["model"].keys(): diff --git a/src/cellmap_models/pytorch/cosem/setup04/model.py b/src/cellmap_models/pytorch/cosem/setup04/model.py index 77f5656..7009e22 100755 --- a/src/cellmap_models/pytorch/cosem/setup04/model.py +++ b/src/cellmap_models/pytorch/cosem/setup04/model.py @@ -26,14 +26,3 @@ final_feature_width = 12 * 6 classes_out = 14 - -# download pretrained model checkpoints from s3 -urls = { - "1820500": "https://janelia-cosem-networks.s3.amazonaws.com/v0003.2-pytorch/cosem_models/cosem_models/setup04/1820500", - "625000": "https://janelia-cosem-networks.s3.amazonaws.com/v0003.2-pytorch/cosem_models/cosem_models/setup04/625000", - "975000": "https://janelia-cosem-networks.s3.amazonaws.com/v0003.2-pytorch/cosem_models/cosem_models/setup04/975000", -} -for name, url in urls.items(): - if not (Path(__file__).parent / f"{name}").exists(): - print(f"Downloading {name} from {url}") - download_url_to_file(url, str(Path(__file__).parent / f"{name}.pth")) diff --git a/src/cellmap_models/pytorch/cosem/setup26.1/model.py b/src/cellmap_models/pytorch/cosem/setup26.1/model.py index 80f9277..0662d22 100755 --- a/src/cellmap_models/pytorch/cosem/setup26.1/model.py +++ b/src/cellmap_models/pytorch/cosem/setup26.1/model.py @@ -26,13 +26,3 @@ final_feature_width = 12 * 6 classes_out = 3 - -# download pretrained model checkpoints from s3 -urls = { - "2580000": "https://janelia-cosem-networks.s3.amazonaws.com/v0003.2-pytorch/cosem_models/cosem_models/setup26.1/2580000", - "650000": "https://janelia-cosem-networks.s3.amazonaws.com/v0003.2-pytorch/cosem_models/cosem_models/setup26.1/650000", -} -for name, url in urls.items(): - if not (Path(__file__).parent / f"{name}").exists(): - print(f"Downloading {name} from {url}") - download_url_to_file(url, str(Path(__file__).parent / f"{name}.pth")) diff --git a/src/cellmap_models/pytorch/cosem/setup28/model.py b/src/cellmap_models/pytorch/cosem/setup28/model.py index bf6e81e..722a292 100755 --- a/src/cellmap_models/pytorch/cosem/setup28/model.py +++ b/src/cellmap_models/pytorch/cosem/setup28/model.py @@ -26,13 +26,3 @@ final_feature_width = 12 * 6 classes_out = 2 - -# download pretrained model checkpoints from s3 -urls = { - "1440000": "https://janelia-cosem-networks.s3.amazonaws.com/v0003.2-pytorch/cosem_models/cosem_models/setup28/1440000", - "775000": "https://janelia-cosem-networks.s3.amazonaws.com/v0003.2-pytorch/cosem_models/cosem_models/setup28/775000", -} -for name, url in urls.items(): - if not (Path(__file__).parent / f"{name}").exists(): - print(f"Downloading {name} from {url}") - download_url_to_file(url, str(Path(__file__).parent / f"{name}.pth")) diff --git a/src/cellmap_models/pytorch/cosem/setup36/model.py b/src/cellmap_models/pytorch/cosem/setup36/model.py index 78eae57..483d8ed 100755 --- a/src/cellmap_models/pytorch/cosem/setup36/model.py +++ b/src/cellmap_models/pytorch/cosem/setup36/model.py @@ -26,13 +26,3 @@ final_feature_width = 12 * 6 classes_out = 2 - -# download pretrained model checkpoints from s3 -urls = { - "1100000": "https://janelia-cosem-networks.s3.amazonaws.com/v0003.2-pytorch/cosem_models/cosem_models/setup36/1100000", - "500000": "https://janelia-cosem-networks.s3.amazonaws.com/v0003.2-pytorch/cosem_models/cosem_models/setup36/500000", -} -for name, url in urls.items(): - if not (Path(__file__).parent / f"{name}").exists(): - print(f"Downloading {name} from {url}") - download_url_to_file(url, str(Path(__file__).parent / f"{name}.pth")) diff --git a/src/cellmap_models/pytorch/cosem/setup45/model.py b/src/cellmap_models/pytorch/cosem/setup45/model.py index 1fa8cb7..ca24b83 100755 --- a/src/cellmap_models/pytorch/cosem/setup45/model.py +++ b/src/cellmap_models/pytorch/cosem/setup45/model.py @@ -26,13 +26,3 @@ final_feature_width = 12 * 6 classes_out = 2 - -# download pretrained model checkpoints from s3 -urls = { - "1634500": "https://janelia-cosem-networks.s3.amazonaws.com/v0003.2-pytorch/cosem_models/cosem_models/setup45/1634500", - "625000": "https://janelia-cosem-networks.s3.amazonaws.com/v0003.2-pytorch/cosem_models/cosem_models/setup45/625000", -} -for name, url in urls.items(): - if not (Path(__file__).parent / f"{name}").exists(): - print(f"Downloading {name} from {url}") - download_url_to_file(url, str(Path(__file__).parent / f"{name}.pth")) diff --git a/src/cellmap_models/utils.py b/src/cellmap_models/utils.py index 3cf1d4d..c853aaf 100644 --- a/src/cellmap_models/utils.py +++ b/src/cellmap_models/utils.py @@ -7,8 +7,7 @@ def download_url_to_file(url, dst, progress=True): - # Originally from CellPose - """Download object at the given URL to a local path. + r"""Download object at the given URL to a local path. Thanks to torch, slightly modified Args: url (string): URL of the object to download @@ -31,8 +30,8 @@ def download_url_to_file(url, dst, progress=True): # We deliberately save it in a temp file and move it after dst = os.path.expanduser(dst) dst_dir = os.path.dirname(dst) - f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir) - try: + os.makedirs(dst_dir, exist_ok=True) + with tempfile.NamedTemporaryFile(delete=False, dir=dst_dir) as f: with tqdm( total=file_size, disable=not progress, @@ -48,7 +47,3 @@ def download_url_to_file(url, dst, progress=True): pbar.update(len(buffer)) f.close() shutil.move(f.name, dst) - finally: - f.close() - if os.path.exists(f.name): - os.remove(f.name)