Skip to content

Commit

Permalink
Fix automatic segmentation cli - check tile shape and return expected…
Browse files Browse the repository at this point in the history
  • Loading branch information
anwai98 authored and psobolewskiPhD committed Nov 29, 2024
1 parent d38bead commit 652900f
Showing 1 changed file with 12 additions and 32 deletions.
44 changes: 12 additions & 32 deletions micro_sam/automatic_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,9 @@
import numpy as np
import imageio.v3 as imageio

from torch_em.data.datasets.util import split_kwargs

from . import util
from .instance_segmentation import (
get_amg, get_decoder, mask_data_to_segmentation, InstanceSegmentationWithDecoder,
AMGBase, AutomaticMaskGenerator, TiledAutomaticMaskGenerator
get_amg, get_decoder, mask_data_to_segmentation, InstanceSegmentationWithDecoder, AMGBase
)
from .multi_dimensional_segmentation import automatic_3d_segmentation

Expand All @@ -33,7 +30,7 @@ def get_predictor_and_segmenter(
Otherwise AIS will be used, which requires a special segmentation decoder.
If not specified AIS will be used if it is available and otherwise AMG will be used.
is_tiled: Whether to return segmenter for performing segmentation in tiling window style.
kwargs: Keyword arguments for the automatic mask generation class.
kwargs: Keyword arguments for the automatic instance segmentation class.
Returns:
The Segment Anything model.
Expand All @@ -49,17 +46,20 @@ def get_predictor_and_segmenter(

if amg is None:
amg = "decoder_state" not in state

if amg:
decoder = None
else:
if "decoder_state" not in state:
raise RuntimeError("You have passed 'amg=False', but your model does not contain a segmentation decoder.")
raise RuntimeError("You have passed amg=False, but your model does not contain a segmentation decoder.")
decoder_state = state["decoder_state"]
decoder = get_decoder(image_encoder=predictor.model.image_encoder, decoder_state=decoder_state, device=device)

segmenter = get_amg(predictor=predictor, is_tiled=is_tiled, decoder=decoder, **kwargs)

segmenter = get_amg(
predictor=predictor,
is_tiled=is_tiled,
decoder=decoder,
**kwargs
)
return predictor, segmenter


Expand Down Expand Up @@ -134,7 +134,6 @@ def automatic_instance_segmentation(
instances = np.zeros(this_shape, dtype="uint32")
else:
instances = mask_data_to_segmentation(masks, with_background=True, min_object_size=0)

else:
if (image_data.ndim != 3) and (image_data.ndim != 4 and image_data.shape[-1] != 3):
raise ValueError(f"The inputs does not match the shape expectation of 3d inputs: {image_data.shape}")
Expand Down Expand Up @@ -192,7 +191,7 @@ def main():
)
parser.add_argument(
"-c", "--checkpoint", default=None,
help="Checkpoint from which the SAM model will be loaded."
help="Checkpoint from which the SAM model will be loaded loaded."
)
parser.add_argument(
"--tile_shape", nargs="+", type=int, help="The tile shape for using tiled prediction.", default=None
Expand All @@ -205,8 +204,7 @@ def main():
help="The number of spatial dimensions in the data. Please specify this if your data has a channel dimension."
)
parser.add_argument(
"--mode", type=str, default=None,
help="The choice of automatic segmentation with the Segment Anything models. Either 'amg' or 'ais'."
"--amg", action="store_true", help="Whether to use automatic mask generation with the model."
)
parser.add_argument(
"-d", "--device", default=None,
Expand All @@ -226,33 +224,15 @@ def _convert_argval(value):

# NOTE: the script below allows the possibility to catch additional parsed arguments which correspond to
# the automatic segmentation post-processing parameters (eg. 'center_distance_threshold' in AIS)
extra_kwargs = {
generate_kwargs = {
parameter_args[i].lstrip("--"): _convert_argval(parameter_args[i + 1]) for i in range(0, len(parameter_args), 2)
}

# Separate extra arguments as per where they should be passed in the automatic segmentation class.
# This is done to ensure the extra arguments are allocated to the desired location.
# eg. for AMG, 'points_per_side' is expected by '__init__',
# and 'stability_score_thresh' is expected in 'generate' method.
amg_class = AutomaticMaskGenerator if args.tile_shape is None else TiledAutomaticMaskGenerator
amg_kwargs, generate_kwargs = split_kwargs(amg_class, **extra_kwargs)

# Validate for the expected automatic segmentation mode.
# By default, it is set to 'None', i.e. searches for the decoder state to prioritize AIS for finetuned models.
# Otherwise, runs AMG for all models in any case.
amg = None
if args.mode is not None:
assert args.mode in ["ais", "amg"], \
f"'{args.mode}' is not a valid automatic segmentation mode. Please choose either 'amg' or 'ais'."
amg = (args.mode == "amg")

predictor, segmenter = get_predictor_and_segmenter(
model_type=args.model_type,
checkpoint=args.checkpoint,
device=args.device,
amg=amg,
is_tiled=args.tile_shape is not None,
**amg_kwargs,
)

automatic_instance_segmentation(
Expand Down

0 comments on commit 652900f

Please sign in to comment.