From e086f13255c7b1378a67683dfde87205d4433f43 Mon Sep 17 00:00:00 2001 From: Ross Barnowski Date: Tue, 9 Apr 2024 12:47:59 -0700 Subject: [PATCH 1/2] Make model loadable on devices w/out cuda. --- cellSAM/model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cellSAM/model.py b/cellSAM/model.py index 499fea4..6c39f54 100644 --- a/cellSAM/model.py +++ b/cellSAM/model.py @@ -43,7 +43,7 @@ def download_file_with_progress(url, destination): if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes: print("ERROR: Something went wrong") -def get_model(model: nn.Module = None) -> nn.Module: +def get_model(model: nn.Module = None, map_location="cpu") -> nn.Module: """ Returns a loaded CellSAM model. If model is None, downloads weights and loads the model with a progress bar. """ @@ -63,7 +63,7 @@ def get_model(model: nn.Module = None) -> nn.Module: model_path, ) model = CellSAM(config) - model.load_state_dict(torch.load(model_path)) + model.load_state_dict(torch.load(model_path, map_location=map_location)) return model def segment_cellular_image( @@ -82,7 +82,7 @@ def segment_cellular_image( if 'cuda' in device: assert torch.cuda.is_available(), "cuda is not available. Please use 'cpu' as device." - model = get_model(model).eval() + model = get_model(model, map_location=device).eval() model.bbox_threshold = bbox_threshold img = format_image_shape(img) From eae27473741e729663933371f29551e967bf389c Mon Sep 17 00:00:00 2001 From: Ross Barnowski Date: Tue, 9 Apr 2024 12:48:35 -0700 Subject: [PATCH 2/2] Add developer cruft to .gitignore. --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index c18dd8d..1e2bcb3 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ __pycache__/ +*.egg-info