Skip to content

Commit

Permalink
Add CLI for benchmarking datasets on SAM models (#728)
Browse files Browse the repository at this point in the history
Add scripts for benchmarking SAM models on microscopy datasets
  • Loading branch information
anwai98 authored Oct 14, 2024
1 parent c48d68f commit 766aa9b
Show file tree
Hide file tree
Showing 11 changed files with 884 additions and 95 deletions.
93 changes: 67 additions & 26 deletions micro_sam/automatic_segmentation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
from pathlib import Path
from typing import Dict, Optional, Union, Tuple
from typing import Optional, Union, Tuple, Dict

import numpy as np
import imageio.v3 as imageio
Expand All @@ -12,54 +12,85 @@
from .multi_dimensional_segmentation import automatic_3d_segmentation


def get_predictor_and_segmenter(
    model_type: str,
    checkpoint: Optional[Union[os.PathLike, str]] = None,
    device: Optional[str] = None,
    amg: bool = False,
    is_tiled: bool = False,
    amg_kwargs: Optional[Dict] = None,
) -> Tuple[util.SamPredictor, Union[AMGBase, InstanceSegmentationWithDecoder]]:
    """Get the Segment Anything model and class for automatic instance segmentation.

    Args:
        model_type: The Segment Anything model choice.
        checkpoint: The filepath to the stored model checkpoints.
        device: The torch device. Passed through `util.get_device`, which resolves `None`.
        amg: Whether to perform automatic segmentation in AMG mode.
            If True, no decoder is created even if the checkpoint provides a decoder state.
        is_tiled: Whether to return segmenter for performing segmentation in tiling window style.
        amg_kwargs: Optional keyword arguments forwarded to `get_amg` when creating the segmenter.
            `None` (the default) is treated as an empty dictionary.

    Returns:
        The Segment Anything model.
        The automatic instance segmentation class.

    Raises:
        TypeError: If `amg_kwargs` is neither `None` nor a dictionary.
    """
    # Validate before loading the model, so that bad arguments fail fast.
    # A fresh dict per call avoids sharing one mutable default between calls.
    if amg_kwargs is None:
        amg_kwargs = {}
    elif not isinstance(amg_kwargs, dict):
        raise TypeError("Please ensure 'amg_kwargs' gets arguments in a dictionary.")

    # Get the device
    device = util.get_device(device=device)

    # Get the predictor and state for Segment Anything models.
    predictor, state = util.get_sam_model(
        model_type=model_type, device=device, checkpoint_path=checkpoint, return_state=True,
    )

    # The decoder is only built if the checkpoint contains a decoder state
    # and AMG mode was not explicitly requested.
    decoder = None
    if "decoder_state" in state and not amg:
        decoder = get_decoder(
            image_encoder=predictor.model.image_encoder,
            decoder_state=state["decoder_state"],
            device=device,
        )

    # Get the segmenter for automatic segmentation.
    segmenter = get_amg(
        predictor=predictor,
        is_tiled=is_tiled,
        decoder=decoder,
        **amg_kwargs,
    )

    return predictor, segmenter


def automatic_instance_segmentation(
predictor: util.SamPredictor,
segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
input_path: Union[Union[os.PathLike, str], np.ndarray],
output_path: Optional[Union[os.PathLike, str]] = None,
embedding_path: Optional[Union[os.PathLike, str]] = None,
model_type: str = util._DEFAULT_MODEL,
checkpoint_path: Optional[Union[os.PathLike, str]] = None,
key: Optional[str] = None,
ndim: Optional[int] = None,
tile_shape: Optional[Tuple[int, int]] = None,
halo: Optional[Tuple[int, int]] = None,
use_amg: bool = False,
amg_kwargs: Optional[Dict] = None,
verbose: bool = True,
**generate_kwargs
) -> np.ndarray:
"""Run automatic segmentation for the input image.
Args:
predictor: The Segment Anything model.
segmenter: The automatic instance segmentation class.
input_path: The input image file(s). Can either be a single image file (e.g. tif or png),
or a container file (e.g. hdf5 or zarr).
output_path: The output path where the instance segmentations will be saved.
embedding_path: The path where the embeddings are cached already / will be saved.
model_type: The SegmentAnything model to use. Will use the standard vit_l model by default.
checkpoint_path: Path to a checkpoint for a custom model.
key: The key to the input file. This is needed for container files (e.g. hdf5 or zarr)
or to load several images as 3d volume. Provide a glob pattern, e.g. "*.tif", for this case.
ndim: The dimensionality of the data.
tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
halo: Overlap of the tiles for tiled prediction.
use_amg: Whether to use Automatic Mask Generation (AMG) as the automatic segmentation method.
amg_kwargs: optional keyword arguments for creating the AMG or AIS class.
generate_kwargs: optional keyword arguments for the generate function onf the AMG or AIS class.
verbose: Verbosity flag.
generate_kwargs: optional keyword arguments for the generate function of the AMG or AIS class.
Returns:
The segmentation result.
"""
predictor, state = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint_path, return_state=True)

if "decoder_state" in state and not use_amg: # AIS
decoder = get_decoder(predictor.model.image_encoder, state["decoder_state"])
segmenter = get_amg(
predictor=predictor, decoder=decoder, is_tiled=tile_shape is not None,
**({} if amg_kwargs is None else amg_kwargs)
)
else: # AMG
segmenter = get_amg(
predictor=predictor, is_tiled=tile_shape is not None, **({} if amg_kwargs is None else amg_kwargs)
)

# Load the input image file.
if isinstance(input_path, np.ndarray):
image_data = input_path
Expand All @@ -77,6 +108,7 @@ def automatic_instance_segmentation(
embedding_path=embedding_path,
tile_shape=tile_shape,
halo=halo,
verbose=verbose,
**generate_kwargs
)
else:
Expand All @@ -88,6 +120,7 @@ def automatic_instance_segmentation(
ndim=ndim,
tile_shape=tile_shape,
halo=halo,
verbose=verbose,
)

segmenter.initialize(image=image_data, image_embeddings=image_embeddings)
Expand Down Expand Up @@ -162,6 +195,11 @@ def main():
parser.add_argument(
"--amg", action="store_true", help="Whether to use automatic mask generation with the model."
)
parser.add_argument(
"-d", "--device", default=None,
help="The device to use for the predictor. Can be one of 'cuda', 'cpu' or 'mps' (only MAC)."
"By default the most performant available device will be selected."
)

args, parameter_args = parser.parse_known_args()

Expand All @@ -179,17 +217,20 @@ def _convert_argval(value):
parameter_args[i].lstrip("--"): _convert_argval(parameter_args[i + 1]) for i in range(0, len(parameter_args), 2)
}

predictor, segmenter = get_predictor_and_segmenter(
model_type=args.model_type, checkpoint=args.checkpoint, device=args.device,
)

automatic_instance_segmentation(
predictor=predictor,
segmenter=segmenter,
input_path=args.input_path,
output_path=args.output_path,
embedding_path=args.embedding_path,
model_type=args.model_type,
checkpoint_path=args.checkpoint,
key=args.key,
ndim=args.ndim,
tile_shape=args.tile_shape,
halo=args.halo,
use_amg=args.amg,
**generate_kwargs,
)

Expand Down
Loading

0 comments on commit 766aa9b

Please sign in to comment.