Skip to content

Commit

Permalink
Make caching embeddings optional
Browse files Browse the repository at this point in the history
  • Loading branch information
anwai98 committed Oct 21, 2024
1 parent 5724a24 commit ef9fbb0
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 14 deletions.
30 changes: 22 additions & 8 deletions micro_sam/evaluation/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,9 +550,13 @@ def run_amg(
iou_thresh_values: Optional[List[float]] = None,
stability_score_values: Optional[List[float]] = None,
peft_kwargs: Optional[Dict] = None,
cache_embeddings: bool = False
) -> str:
embedding_folder = os.path.join(experiment_folder, "embeddings") # where the precomputed embeddings are saved
os.makedirs(embedding_folder, exist_ok=True)
if cache_embeddings:
embedding_folder = os.path.join(experiment_folder, "embeddings") # where the precomputed embeddings are saved
os.makedirs(embedding_folder, exist_ok=True)
else:
embedding_folder = None

predictor = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint, peft_kwargs=peft_kwargs)
amg = AutomaticMaskGenerator(predictor)
Expand All @@ -572,9 +576,15 @@ def run_amg(
)

instance_segmentation.run_instance_segmentation_grid_search_and_inference(
amg, grid_search_values,
val_image_paths, val_gt_paths, test_image_paths,
embedding_folder, prediction_folder, gs_result_folder,
segmenter=amg,
grid_search_values=grid_search_values,
val_image_paths=val_image_paths,
val_gt_paths=val_gt_paths,
test_image_paths=test_image_paths,
embedding_dir=embedding_folder,
prediction_dir=prediction_folder,
result_dir=gs_result_folder,
experiment_folder=experiment_folder,
)
return prediction_folder

Expand All @@ -592,9 +602,13 @@ def run_instance_segmentation_with_decoder(
val_gt_paths: List[Union[str, os.PathLike]],
test_image_paths: List[Union[str, os.PathLike]],
peft_kwargs: Optional[Dict] = None,
cache_embeddings: bool = False,
) -> str:
embedding_folder = os.path.join(experiment_folder, "embeddings") # where the precomputed embeddings are saved
os.makedirs(embedding_folder, exist_ok=True)
if cache_embeddings:
embedding_folder = os.path.join(experiment_folder, "embeddings") # where the precomputed embeddings are saved
os.makedirs(embedding_folder, exist_ok=True)
else:
embedding_folder = None

predictor, decoder = get_predictor_and_decoder(
model_type=model_type, checkpoint_path=checkpoint, peft_kwargs=peft_kwargs,
Expand All @@ -616,6 +630,6 @@ def run_instance_segmentation_with_decoder(
segmenter, grid_search_values,
val_image_paths, val_gt_paths, test_image_paths,
embedding_dir=embedding_folder, prediction_dir=prediction_folder,
result_dir=gs_result_folder,
result_dir=gs_result_folder, experiment_folder=experiment_folder,
)
return prediction_folder
17 changes: 11 additions & 6 deletions micro_sam/evaluation/instance_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,12 +276,15 @@ def run_instance_segmentation_inference(
assert os.path.exists(image_path), image_path
image = imageio.imread(image_path)

embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")
image_embeddings = util.precompute_image_embeddings(
predictor, image, embedding_path, ndim=2, verbose=verbose_embeddings
)
if embedding_dir is None:
segmenter.initialize(image)
else:
embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")
image_embeddings = util.precompute_image_embeddings(
predictor, image, embedding_path, ndim=2, verbose=verbose_embeddings
)
segmenter.initialize(image, image_embeddings)

segmenter.initialize(image, image_embeddings)
masks = segmenter.generate(**generate_kwargs)

if len(masks) == 0: # the instance segmentation can have no masks, hence we just save empty labels
Expand Down Expand Up @@ -362,6 +365,7 @@ def run_instance_segmentation_grid_search_and_inference(
test_image_paths: List[Union[str, os.PathLike]],
embedding_dir: Union[str, os.PathLike],
prediction_dir: Union[str, os.PathLike],
experiment_folder: Union[str, os.PathLike],
result_dir: Union[str, os.PathLike],
fixed_generate_kwargs: Optional[Dict[str, Any]] = None,
verbose_gs: bool = True,
Expand All @@ -379,6 +383,7 @@ def run_instance_segmentation_grid_search_and_inference(
test_image_paths: The input images for inference.
embedding_dir: Folder to cache the image embeddings.
prediction_dir: Folder to save the predictions.
experiment_dir: Folder for caching best
result_dir: Folder to cache the evaluation results per image.
fixed_generate_kwargs: Fixed keyword arguments for the `generate` method of the segmenter.
verbose_gs: Whether to run the gridsearch for individual images in a verbose mode.
Expand All @@ -394,7 +399,7 @@ def run_instance_segmentation_grid_search_and_inference(
print("Best grid-search result:", best_msa, "with parmeters:\n", best_param_str)
print()

save_grid_search_best_params(best_kwargs, best_msa, Path(embedding_dir).parent)
save_grid_search_best_params(best_kwargs, best_msa, experiment_folder)

generate_kwargs = {} if fixed_generate_kwargs is None else fixed_generate_kwargs
generate_kwargs.update(best_kwargs)
Expand Down

0 comments on commit ef9fbb0

Please sign in to comment.