diff --git a/micro_sam/evaluation/inference.py b/micro_sam/evaluation/inference.py index b033055f..5d86067a 100644 --- a/micro_sam/evaluation/inference.py +++ b/micro_sam/evaluation/inference.py @@ -550,9 +550,13 @@ def run_amg( iou_thresh_values: Optional[List[float]] = None, stability_score_values: Optional[List[float]] = None, peft_kwargs: Optional[Dict] = None, + cache_embeddings: bool = False ) -> str: - embedding_folder = os.path.join(experiment_folder, "embeddings") # where the precomputed embeddings are saved - os.makedirs(embedding_folder, exist_ok=True) + if cache_embeddings: + embedding_folder = os.path.join(experiment_folder, "embeddings") # where the precomputed embeddings are saved + os.makedirs(embedding_folder, exist_ok=True) + else: + embedding_folder = None predictor = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint, peft_kwargs=peft_kwargs) amg = AutomaticMaskGenerator(predictor) @@ -572,9 +576,15 @@ def run_amg( ) instance_segmentation.run_instance_segmentation_grid_search_and_inference( - amg, grid_search_values, - val_image_paths, val_gt_paths, test_image_paths, - embedding_folder, prediction_folder, gs_result_folder, + segmenter=amg, + grid_search_values=grid_search_values, + val_image_paths=val_image_paths, + val_gt_paths=val_gt_paths, + test_image_paths=test_image_paths, + embedding_dir=embedding_folder, + prediction_dir=prediction_folder, + result_dir=gs_result_folder, + experiment_folder=experiment_folder, ) return prediction_folder @@ -592,9 +602,13 @@ def run_instance_segmentation_with_decoder( val_gt_paths: List[Union[str, os.PathLike]], test_image_paths: List[Union[str, os.PathLike]], peft_kwargs: Optional[Dict] = None, + cache_embeddings: bool = False, ) -> str: - embedding_folder = os.path.join(experiment_folder, "embeddings") # where the precomputed embeddings are saved - os.makedirs(embedding_folder, exist_ok=True) + if cache_embeddings: + embedding_folder = os.path.join(experiment_folder, "embeddings") # where the precomputed embeddings are saved + os.makedirs(embedding_folder, exist_ok=True) + else: + embedding_folder = None predictor, decoder = get_predictor_and_decoder( model_type=model_type, checkpoint_path=checkpoint, peft_kwargs=peft_kwargs, @@ -616,6 +630,6 @@ def run_instance_segmentation_with_decoder( segmenter, grid_search_values, val_image_paths, val_gt_paths, test_image_paths, embedding_dir=embedding_folder, prediction_dir=prediction_folder, - result_dir=gs_result_folder, + result_dir=gs_result_folder, experiment_folder=experiment_folder, ) return prediction_folder diff --git a/micro_sam/evaluation/instance_segmentation.py b/micro_sam/evaluation/instance_segmentation.py index 5e657190..dbee0959 100644 --- a/micro_sam/evaluation/instance_segmentation.py +++ b/micro_sam/evaluation/instance_segmentation.py @@ -276,12 +276,15 @@ def run_instance_segmentation_inference( assert os.path.exists(image_path), image_path image = imageio.imread(image_path) - embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr") - image_embeddings = util.precompute_image_embeddings( - predictor, image, embedding_path, ndim=2, verbose=verbose_embeddings - ) + if embedding_dir is None: + segmenter.initialize(image) + else: + embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr") + image_embeddings = util.precompute_image_embeddings( + predictor, image, embedding_path, ndim=2, verbose=verbose_embeddings + ) + segmenter.initialize(image, image_embeddings) - segmenter.initialize(image, image_embeddings) masks = segmenter.generate(**generate_kwargs) if len(masks) == 0: # the instance segmentation can have no masks, hence we just save empty labels @@ -362,6 +365,7 @@ def run_instance_segmentation_grid_search_and_inference( test_image_paths: List[Union[str, os.PathLike]], embedding_dir: Union[str, os.PathLike], prediction_dir: Union[str, os.PathLike], + experiment_folder: Union[str, os.PathLike], result_dir: Union[str, os.PathLike], fixed_generate_kwargs: Optional[Dict[str, Any]] = None, verbose_gs: bool = True, @@ -379,6 +383,7 @@ def run_instance_segmentation_grid_search_and_inference( test_image_paths: The input images for inference. embedding_dir: Folder to cache the image embeddings. prediction_dir: Folder to save the predictions. + experiment_dir: Folder for caching best result_dir: Folder to cache the evaluation results per image. fixed_generate_kwargs: Fixed keyword arguments for the `generate` method of the segmenter. verbose_gs: Whether to run the gridsearch for individual images in a verbose mode. @@ -394,7 +399,7 @@ def run_instance_segmentation_grid_search_and_inference( print("Best grid-search result:", best_msa, "with parmeters:\n", best_param_str) print() - save_grid_search_best_params(best_kwargs, best_msa, Path(embedding_dir).parent) + save_grid_search_best_params(best_kwargs, best_msa, experiment_folder) generate_kwargs = {} if fixed_generate_kwargs is None else fixed_generate_kwargs generate_kwargs.update(best_kwargs)