Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CLI for benchmarking datasets on SAM models #728

Merged
merged 19 commits into from
Oct 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 67 additions & 26 deletions micro_sam/automatic_segmentation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
from pathlib import Path
from typing import Dict, Optional, Union, Tuple
from typing import Optional, Union, Tuple, Dict

import numpy as np
import imageio.v3 as imageio
Expand All @@ -12,54 +12,85 @@
from .multi_dimensional_segmentation import automatic_3d_segmentation


def get_predictor_and_segmenter(
model_type: str,
checkpoint: Optional[Union[os.PathLike, str]] = None,
device: str = None,
amg: bool = False,
is_tiled: bool = False,
amg_kwargs: Dict = {}
) -> Tuple[util.SamPredictor, Union[AMGBase, InstanceSegmentationWithDecoder]]:
"""Get the Segment Anything model and class for automatic instance segmentation.

Args:
model_type: The Segment Anything model choice.
checkpoint: The filepath to the stored model checkpoints.
device: The torch device.
amg: Whether to perform automatic segmentation in AMG mode.
is_tiled: Whether to return segmenter for performing segmentation in tiling window style.

Returns:
The Segment Anything model.
The automatic instance segmentation class.
"""
# Get the device
device = util.get_device(device=device)

# Get the predictor and state for Segment Anything models.
predictor, state = util.get_sam_model(
model_type=model_type, device=device, checkpoint_path=checkpoint, return_state=True,
)

# Get the segmenter for automatic segmentation.
assert isinstance(amg_kwargs, Dict), "Please ensure 'amg_kwargs' gets arguments in a dictionary."

segmenter = get_amg(
predictor=predictor,
is_tiled=is_tiled,
decoder=get_decoder(
image_encoder=predictor.model.image_encoder,
decoder_state=state["decoder_state"],
device=device
) if "decoder_state" in state and not amg else None,
**amg_kwargs
)

return predictor, segmenter


def automatic_instance_segmentation(
predictor: util.SamPredictor,
segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
input_path: Union[Union[os.PathLike, str], np.ndarray],
output_path: Optional[Union[os.PathLike, str]] = None,
embedding_path: Optional[Union[os.PathLike, str]] = None,
model_type: str = util._DEFAULT_MODEL,
checkpoint_path: Optional[Union[os.PathLike, str]] = None,
key: Optional[str] = None,
ndim: Optional[int] = None,
tile_shape: Optional[Tuple[int, int]] = None,
halo: Optional[Tuple[int, int]] = None,
use_amg: bool = False,
amg_kwargs: Optional[Dict] = None,
verbose: bool = True,
**generate_kwargs
) -> np.ndarray:
"""Run automatic segmentation for the input image.

Args:
predictor: The Segment Anything model.
segmenter: The automatic instance segmentation class.
input_path: input_path: The input image file(s). Can either be a single image file (e.g. tif or png),
or a container file (e.g. hdf5 or zarr).
output_path: The output path where the instance segmentations will be saved.
embedding_path: The path where the embeddings are cached already / will be saved.
model_type: The SegmentAnything model to use. Will use the standard vit_l model by default.
checkpoint_path: Path to a checkpoint for a custom model.
key: The key to the input file. This is needed for container files (eg. hdf5 or zarr)
or to load several images as 3d volume. Provide a glob patterm, eg. "*.tif", for this case.
ndim: The dimensionality of the data.
tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
halo: Overlap of the tiles for tiled prediction.
use_amg: Whether to use Automatic Mask Generation (AMG) as the automatic segmentation method.
amg_kwargs: optional keyword arguments for creating the AMG or AIS class.
generate_kwargs: optional keyword arguments for the generate function onf the AMG or AIS class.
verbose: Verbosity flag.
generate_kwargs: optional keyword arguments for the generate function of the AMG or AIS class.

Returns:
The segmentation result.
"""
predictor, state = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint_path, return_state=True)

if "decoder_state" in state and not use_amg: # AIS
decoder = get_decoder(predictor.model.image_encoder, state["decoder_state"])
segmenter = get_amg(
predictor=predictor, decoder=decoder, is_tiled=tile_shape is not None,
**({} if amg_kwargs is None else amg_kwargs)
)
else: # AMG
segmenter = get_amg(
predictor=predictor, is_tiled=tile_shape is not None, **({} if amg_kwargs is None else amg_kwargs)
)

# Load the input image file.
if isinstance(input_path, np.ndarray):
image_data = input_path
Expand All @@ -77,6 +108,7 @@ def automatic_instance_segmentation(
embedding_path=embedding_path,
tile_shape=tile_shape,
halo=halo,
verbose=verbose,
**generate_kwargs
)
else:
Expand All @@ -88,6 +120,7 @@ def automatic_instance_segmentation(
ndim=ndim,
tile_shape=tile_shape,
halo=halo,
verbose=verbose,
)

segmenter.initialize(image=image_data, image_embeddings=image_embeddings)
Expand Down Expand Up @@ -162,6 +195,11 @@ def main():
parser.add_argument(
"--amg", action="store_true", help="Whether to use automatic mask generation with the model."
)
parser.add_argument(
"-d", "--device", default=None,
help="The device to use for the predictor. Can be one of 'cuda', 'cpu' or 'mps' (only MAC)."
"By default the most performant available device will be selected."
)

args, parameter_args = parser.parse_known_args()

Expand All @@ -179,17 +217,20 @@ def _convert_argval(value):
parameter_args[i].lstrip("--"): _convert_argval(parameter_args[i + 1]) for i in range(0, len(parameter_args), 2)
}

predictor, segmenter = get_predictor_and_segmenter(
model_type=args.model_type, checkpoint=args.checkpoint, device=args.device,
)
anwai98 marked this conversation as resolved.
Show resolved Hide resolved

automatic_instance_segmentation(
predictor=predictor,
segmenter=segmenter,
input_path=args.input_path,
output_path=args.output_path,
embedding_path=args.embedding_path,
model_type=args.model_type,
checkpoint_path=args.checkpoint,
key=args.key,
ndim=args.ndim,
tile_shape=args.tile_shape,
halo=args.halo,
use_amg=args.amg,
**generate_kwargs,
)

Expand Down
Loading
Loading