diff --git a/micro_sam/evaluation/instance_segmentation.py b/micro_sam/evaluation/instance_segmentation.py index b24b31ad..51244d8c 100644 --- a/micro_sam/evaluation/instance_segmentation.py +++ b/micro_sam/evaluation/instance_segmentation.py @@ -247,7 +247,7 @@ def run_instance_segmentation_grid_search( def run_instance_segmentation_inference( segmenter: Union[AMGBase, InstanceSegmentationWithDecoder], image_paths: List[Union[str, os.PathLike]], - embedding_dir: Union[str, os.PathLike], + embedding_dir: Optional[Union[str, os.PathLike]], prediction_dir: Union[str, os.PathLike], generate_kwargs: Optional[Dict[str, Any]] = None, ) -> None: @@ -279,13 +279,16 @@ def run_instance_segmentation_inference( image = imageio.imread(image_path) if embedding_dir is None: - segmenter.initialize(image) + embedding_path = None else: + assert predictor is not None embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr") - image_embeddings = util.precompute_image_embeddings( - predictor, image, embedding_path, ndim=2, verbose=verbose_embeddings - ) - segmenter.initialize(image, image_embeddings) + + image_embeddings = util.precompute_image_embeddings( + predictor, image, embedding_path, ndim=2, verbose=verbose_embeddings + ) + + segmenter.initialize(image, image_embeddings) masks = segmenter.generate(**generate_kwargs) @@ -365,7 +368,7 @@ def run_instance_segmentation_grid_search_and_inference( val_image_paths: List[Union[str, os.PathLike]], val_gt_paths: List[Union[str, os.PathLike]], test_image_paths: List[Union[str, os.PathLike]], - embedding_dir: Union[str, os.PathLike], + embedding_dir: Optional[Union[str, os.PathLike]], prediction_dir: Union[str, os.PathLike], experiment_folder: Union[str, os.PathLike], result_dir: Union[str, os.PathLike],