diff --git a/.gitignore b/.gitignore index c18dd8d..1e2bcb3 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ __pycache__/ +*.egg-info diff --git a/cellSAM/model.py b/cellSAM/model.py index 499fea4..6c39f54 100644 --- a/cellSAM/model.py +++ b/cellSAM/model.py @@ -43,7 +43,7 @@ def download_file_with_progress(url, destination): if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes: print("ERROR: Something went wrong") -def get_model(model: nn.Module = None) -> nn.Module: +def get_model(model: nn.Module = None, map_location="cpu") -> nn.Module: """ Returns a loaded CellSAM model. If model is None, downloads weights and loads the model with a progress bar. """ @@ -63,7 +63,7 @@ def get_model(model: nn.Module = None) -> nn.Module: model_path, ) model = CellSAM(config) - model.load_state_dict(torch.load(model_path)) + model.load_state_dict(torch.load(model_path, map_location=map_location)) return model def segment_cellular_image( @@ -82,7 +82,7 @@ def segment_cellular_image( if 'cuda' in device: assert torch.cuda.is_available(), "cuda is not available. Please use 'cpu' as device." - model = get_model(model).eval() + model = get_model(model, map_location=device).eval() model.bbox_threshold = bbox_threshold img = format_image_shape(img)