From 766aa9bc0231cc63bd0c05625c8f60ba3af09620 Mon Sep 17 00:00:00 2001 From: Anwai Archit <52396323+anwai98@users.noreply.github.com> Date: Mon, 14 Oct 2024 19:34:59 +0200 Subject: [PATCH] Add CLI for benchmarking datasets on SAM models (#728) Add scripts for benchmarking SAM models on microscopy datasets --- micro_sam/automatic_segmentation.py | 93 ++- micro_sam/evaluation/benchmark_datasets.py | 721 ++++++++++++++++++ micro_sam/evaluation/evaluation.py | 7 +- micro_sam/evaluation/inference.py | 18 +- .../multi_dimensional_segmentation.py | 47 +- micro_sam/multi_dimensional_segmentation.py | 15 +- micro_sam/prompt_generators.py | 8 +- micro_sam/training/training.py | 4 +- setup.cfg | 1 + test/test_automatic_segmentation.py | 59 +- test/test_training.py | 6 +- 11 files changed, 884 insertions(+), 95 deletions(-) create mode 100644 micro_sam/evaluation/benchmark_datasets.py diff --git a/micro_sam/automatic_segmentation.py b/micro_sam/automatic_segmentation.py index 79043eee8..2561d8e2b 100644 --- a/micro_sam/automatic_segmentation.py +++ b/micro_sam/automatic_segmentation.py @@ -1,6 +1,6 @@ import os from pathlib import Path -from typing import Dict, Optional, Union, Tuple +from typing import Optional, Union, Tuple, Dict import numpy as np import imageio.v3 as imageio @@ -12,54 +12,85 @@ from .multi_dimensional_segmentation import automatic_3d_segmentation +def get_predictor_and_segmenter( + model_type: str, + checkpoint: Optional[Union[os.PathLike, str]] = None, + device: str = None, + amg: bool = False, + is_tiled: bool = False, + amg_kwargs: Dict = {} +) -> Tuple[util.SamPredictor, Union[AMGBase, InstanceSegmentationWithDecoder]]: + """Get the Segment Anything model and class for automatic instance segmentation. + + Args: + model_type: The Segment Anything model choice. + checkpoint: The filepath to the stored model checkpoints. + device: The torch device. + amg: Whether to perform automatic segmentation in AMG mode. 
+ is_tiled: Whether to return segmenter for performing segmentation in tiling window style. + + Returns: + The Segment Anything model. + The automatic instance segmentation class. + """ + # Get the device + device = util.get_device(device=device) + + # Get the predictor and state for Segment Anything models. + predictor, state = util.get_sam_model( + model_type=model_type, device=device, checkpoint_path=checkpoint, return_state=True, + ) + + # Get the segmenter for automatic segmentation. + assert isinstance(amg_kwargs, Dict), "Please ensure 'amg_kwargs' gets arguments in a dictionary." + + segmenter = get_amg( + predictor=predictor, + is_tiled=is_tiled, + decoder=get_decoder( + image_encoder=predictor.model.image_encoder, + decoder_state=state["decoder_state"], + device=device + ) if "decoder_state" in state and not amg else None, + **amg_kwargs + ) + + return predictor, segmenter + + def automatic_instance_segmentation( + predictor: util.SamPredictor, + segmenter: Union[AMGBase, InstanceSegmentationWithDecoder], input_path: Union[Union[os.PathLike, str], np.ndarray], output_path: Optional[Union[os.PathLike, str]] = None, embedding_path: Optional[Union[os.PathLike, str]] = None, - model_type: str = util._DEFAULT_MODEL, - checkpoint_path: Optional[Union[os.PathLike, str]] = None, key: Optional[str] = None, ndim: Optional[int] = None, tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, - use_amg: bool = False, - amg_kwargs: Optional[Dict] = None, + verbose: bool = True, **generate_kwargs ) -> np.ndarray: """Run automatic segmentation for the input image. Args: + predictor: The Segment Anything model. + segmenter: The automatic instance segmentation class. input_path: input_path: The input image file(s). Can either be a single image file (e.g. tif or png), or a container file (e.g. hdf5 or zarr). output_path: The output path where the instance segmentations will be saved. 
embedding_path: The path where the embeddings are cached already / will be saved. - model_type: The SegmentAnything model to use. Will use the standard vit_l model by default. - checkpoint_path: Path to a checkpoint for a custom model. key: The key to the input file. This is needed for container files (eg. hdf5 or zarr) or to load several images as 3d volume. Provide a glob patterm, eg. "*.tif", for this case. ndim: The dimensionality of the data. tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling. halo: Overlap of the tiles for tiled prediction. - use_amg: Whether to use Automatic Mask Generation (AMG) as the automatic segmentation method. - amg_kwargs: optional keyword arguments for creating the AMG or AIS class. - generate_kwargs: optional keyword arguments for the generate function onf the AMG or AIS class. + verbose: Verbosity flag. + generate_kwargs: optional keyword arguments for the generate function of the AMG or AIS class. Returns: The segmentation result. """ - predictor, state = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint_path, return_state=True) - - if "decoder_state" in state and not use_amg: # AIS - decoder = get_decoder(predictor.model.image_encoder, state["decoder_state"]) - segmenter = get_amg( - predictor=predictor, decoder=decoder, is_tiled=tile_shape is not None, - **({} if amg_kwargs is None else amg_kwargs) - ) - else: # AMG - segmenter = get_amg( - predictor=predictor, is_tiled=tile_shape is not None, **({} if amg_kwargs is None else amg_kwargs) - ) - # Load the input image file. 
if isinstance(input_path, np.ndarray): image_data = input_path @@ -77,6 +108,7 @@ def automatic_instance_segmentation( embedding_path=embedding_path, tile_shape=tile_shape, halo=halo, + verbose=verbose, **generate_kwargs ) else: @@ -88,6 +120,7 @@ def automatic_instance_segmentation( ndim=ndim, tile_shape=tile_shape, halo=halo, + verbose=verbose, ) segmenter.initialize(image=image_data, image_embeddings=image_embeddings) @@ -162,6 +195,11 @@ def main(): parser.add_argument( "--amg", action="store_true", help="Whether to use automatic mask generation with the model." ) + parser.add_argument( + "-d", "--device", default=None, + help="The device to use for the predictor. Can be one of 'cuda', 'cpu' or 'mps' (only MAC)." + "By default the most performant available device will be selected." + ) args, parameter_args = parser.parse_known_args() @@ -179,17 +217,20 @@ def _convert_argval(value): parameter_args[i].lstrip("--"): _convert_argval(parameter_args[i + 1]) for i in range(0, len(parameter_args), 2) } + predictor, segmenter = get_predictor_and_segmenter( + model_type=args.model_type, checkpoint=args.checkpoint, device=args.device, + ) + automatic_instance_segmentation( + predictor=predictor, + segmenter=segmenter, input_path=args.input_path, output_path=args.output_path, embedding_path=args.embedding_path, - model_type=args.model_type, - checkpoint_path=args.checkpoint, key=args.key, ndim=args.ndim, tile_shape=args.tile_shape, halo=args.halo, - use_amg=args.amg, **generate_kwargs, ) diff --git a/micro_sam/evaluation/benchmark_datasets.py b/micro_sam/evaluation/benchmark_datasets.py new file mode 100644 index 000000000..53d39d9b0 --- /dev/null +++ b/micro_sam/evaluation/benchmark_datasets.py @@ -0,0 +1,721 @@ +import os +import time +from glob import glob +from tqdm import tqdm +from natsort import natsorted +from typing import Union, Optional, List, Literal + +import numpy as np +import pandas as pd +import imageio.v3 as imageio +from skimage.measure import label as 
connected_components + +from nifty.tools import blocking + +import torch + +from torch_em.data import datasets + +from micro_sam import util + +from . import run_evaluation +from ..training.training import _filter_warnings +from .inference import run_inference_with_iterative_prompting +from .evaluation import run_evaluation_for_iterative_prompting +from .multi_dimensional_segmentation import segment_slices_from_ground_truth +from ..automatic_segmentation import automatic_instance_segmentation, get_predictor_and_segmenter + + +LM_2D_DATASETS = [ + "livecell", "deepbacs", "tissuenet", "neurips_cellseg", "dynamicnuclearnet", + "hpa", "covid_if", "pannuke", "lizard", "orgasegment", "omnipose", "dic_hepg2", +] + +LM_3D_DATASETS = [ + "plantseg_root", "plantseg_ovules", "gonuclear", "mouse_embryo", "embegseg", "cellseg3d" +] + +EM_2D_DATASETS = ["mitolab_tem"] + +EM_3D_DATASETS = [ + "mitoem_rat", "mitoem_human", "platynereis_nuclei", "lucchi", "mitolab", "nuc_mm_mouse", + "num_mm_zebrafish", "uro_cell", "sponge_em", "platynereis_cilia", "vnc", "asem_mito", +] + +DATASET_RETURNS_FOLDER = { + "deepbacs": "*.tif" +} + +DATASET_CONTAINER_KEYS = { + "lucchi": ["raw", "labels"], +} + + +def _download_benchmark_datasets(path, dataset_choice): + """Ensures whether all the datasets have been downloaded or not. + + Args: + path: The path to directory where the supported datasets will be downloaded + for benchmarking Segment Anything models. + dataset_choice: The choice of dataset, expects the lower case name for the dataset. + + Returns: + List of choice of dataset(s). 
+ """ + available_datasets = { + # Light Microscopy datasets + "livecell": lambda: datasets.livecell.get_livecell_data( + path=os.path.join(path, "livecell"), split="test", download=True, + ), + "deepbacs": lambda: datasets.deepbacs.get_deepbacs_data( + path=os.path.join(path, "deepbacs"), bac_type="mixed", download=True, + ), + "tissuenet": lambda: datasets.tissuenet.get_tissuenet_data( + path=os.path.join(path, "tissuenet"), split="test", download=True, + ), + "neurips_cellseg": lambda: datasets.neurips_cell_seg.get_neurips_cellseg_data( + root=os.path.join(path, "neurips_cellseg"), split="test", download=True, + ), + "plantseg_root": lambda: datasets.plantseg.get_plantseg_data( + path=os.path.join(path, "plantseg"), download=True, name="root", + ), + "plantseg_ovules": lambda: datasets.plantseg.get_plantseg_data( + path=os.path.join(path, "plantseg"), download=True, name="ovules", + ), + "covid_if": lambda: datasets.covid_if.get_covid_if_data( + path=os.path.join(path, "covid_if"), download=True, + ), + "hpa": lambda: datasets.hpa.get_hpa_segmentation_data( + path=os.path.join(path, "hpa"), download=True, + ), + "dynamicnuclearnet": lambda: datasets.dynamicnuclearnet.get_dynamicnuclearnet_data( + path=os.path.join(path, "dynamicnuclearnet"), split="test", download=True, + ), + "pannuke": lambda: datasets.pannuke.get_pannuke_data( + path=os.path.join(path, "pannuke"), download=True, folds=["fold_1", "fold_2", "fold_3"], + ), + "lizard": lambda: datasets.lizard.get_lizard_data( + path=os.path.join(path, "lizard"), download=True, + ), + "orgasegment": lambda: datasets.orgasegment.get_orgasegment_data( + path=os.path.join(path, "orgasegment"), split="eval", download=True, + ), + "omnipose": lambda: datasets.omnipose.get_omnipose_data( + path=os.path.join(path, "omnipose"), download=True, + ), + "gonuclear": lambda: datasets.gonuclear.get_gonuclear_data( + path=os.path.join(path, "gonuclear"), download=True, + ), + "mouse_embryo": lambda: 
datasets.mouse_embryo.get_mouse_embryo_data( + path=os.path.join(path, "mouse_embryo"), download=True, + ), + "embedseg_data": lambda: [ + datasets.embedseg_data.get_embedseg_data(path=os.path.join(path, "embedseg_data"), download=True, name=name) + for name in datasets.embedseg_data.URLS.keys() + ], + "cellseg_3d": lambda: datasets.cellseg_3d.get_cellseg_3d_data( + path=os.path.join(path, "cellseg_3d"), download=True, + ), + "dic_hepg2": lambda: datasets.dic_hepg2.get_dic_hepg2_data( + path=os.path.join(path, "dic_hepg2"), download=True, + ), + + # Electron Microscopy datasets + "mitoem_rat": lambda: datasets.mitoem.get_mitoem_data( + path=os.path.join(path, "mitoem"), samples="rat", split="test", download=True, + ), + "mitoem_human": lambda: datasets.mitoem.get_mitoem_data( + path=os.path.join(path, "mitoem"), samples="human", split="test", download=True, + ), + "platynereis_nuclei": lambda: datasets.platynereis.get_platy_data( + path=os.path.join(path, "platynereis"), name="nuclei", download=True, + ), + "platynereis_cilia": lambda: datasets.platynereis.get_platy_data( + path=os.path.join(path, "platynereis"), name="cilia", download=True, + ), + "lucchi": lambda: datasets.lucchi.get_lucchi_data( + path=os.path.join(path, "lucchi"), split="test", download=True, + ), + "mitolab_3d": lambda: [ + datasets.cem.get_benchmark_data( + path=os.path.join(path, "mitolab"), dataset_id=dataset_id, download=True, + ) for dataset_id in range(1, 7) + ], + "mitolab_tem": lambda: datasets.cem.get_benchmark_data( + path=os.path.join(path, "mitolab"), dataset_id=7, download=True + ), + "nuc_mm_mouse": lambda: datasets.nuc_mm.get_nuc_mm_data( + path=os.path.join(path, "nuc_mm"), sample="mouse", download=True, + ), + "nuc_mm_zebrafish": lambda: datasets.nuc_mm.get_nuc_mm_data( + path=os.path.join(path, "nuc_mm"), sample="zebrafish", download=True, + ), + "uro_cell": lambda: datasets.uro_cell.get_uro_cell_data( + path=os.path.join(path, "uro_cell"), download=True, + ), + "sponge_em": 
lambda: datasets.sponge_em.get_sponge_em_data( + path=os.path.join(path, "sponge_em"), download=True, + ), + "vnc": lambda: datasets.vnc.get_vnc_data( + path=os.path.join(path, "vnc"), download=True, + ), + "asem_mito": lambda: datasets.asem.get_asem_data( + path=os.path.join(path, "asem"), volume_ids=datasets.asem.ORGANELLES["mito"], download=True, + ) + } + + if dataset_choice is None: + dataset_choice = available_datasets.keys() + else: + if not isinstance(dataset_choice, list): + dataset_choice = [dataset_choice] + + for choice in dataset_choice: + if choice in available_datasets: + available_datasets[choice]() + else: + raise ValueError(f"'{choice}' is not a supported choice of dataset.") + + return dataset_choice + + +def _extract_slices_from_dataset(path, dataset_choice, crops_per_input=10): + """Extracts crops of desired shapes for performing evaluation in both 2d and 3d using `micro-sam`. + + Args: + path: The path to directory where the supported datasets have be downloaded + for benchmarking Segment Anything models. + dataset_choice: The name of the dataset of choice to extract crops. + crops_per_input: The maximum number of crops to extract per inputs. + extract_2d: Whether to extract 2d crops from 3d patches. + + Returns: + Filepath to the folder where extracted images are stored. + Filepath to the folder where corresponding extracted labels are stored. + The number of dimensions supported by the input. + """ + ndim = 2 if dataset_choice in [*LM_2D_DATASETS, *EM_2D_DATASETS] else 3 + tile_shape = (512, 512) if ndim == 2 else (32, 512, 512) + + # For 3d inputs, we extract both 2d and 3d crops. 
+ extract_2d_crops_from_volumes = (ndim == 3) + + available_datasets = { + # Light Microscopy datasets + "livecell": lambda: datasets.livecell.get_livecell_paths(path=path, split="test"), + "deepbacs": lambda: datasets.deepbacs.get_deepbacs_paths(path=path, split="test", bac_type="mixed"), + "tissuenet": lambda: datasets.tissuenet.get_tissuenet_paths(path=path, split="test"), + "neurips_cellseg": lambda: datasets.neurips_cell_seg.get_neurips_cellseg_paths(root=path, split="test"), + "plantseg_root": lambda: datasets.plantseg.get_plantseg_paths(path=path, name="root", split="test"), + "plantseg_ovules": lambda: datasets.plantseg.get_plantseg_paths(path=path, name="ovules", split="test"), + "covid_if": lambda: datasets.covid_if.get_covid_if_paths(path=path), + "hpa": lambda: datasets.hpa.get_hpa_segmentation_paths(path=path, split="test"), + "dynamicnuclearnet": lambda: datasets.dynamicnuclearnet.get_dynamicnuclearnet_paths(path=path, split="test"), + "pannuke": lambda: datasets.pannuke.get_pannuke_paths(path=path), + "lizard": lambda: datasets.lizard.get_lizard_paths(parth=path), + "orgasegment": lambda: datasets.orgasegment.get_orgasegment_paths(path=path, split="eval"), + "omnipose": lambda: datasets.omnipose.get_omnipose_paths(path=path, split="test"), + "gonuclear": lambda: datasets.gonuclear.get_gonuclear_paths(path-path), + "mouse_embryo": lambda: datasets.mouse_embryo.get_mouse_embryo_paths(path=path, name="nuclei", split="val"), + "embedseg_data": lambda: datasets.embedseg_data.get_embedseg_paths( + path=path, name=list(datasets.embedseg_data.URLS.keys())[0], split="test" + ), + "cellseg_3d": lambda: datasets.cellseg_3d.get_cellseg_3d_paths(path=path), + "dic_hepg2": lambda: datasets.dic_hepg2.get_dic_hepg2_paths(path=path, split="test"), + + # Electron Microscopy datasets + "mitoem_rat": lambda: datasets.mitoem.get_mitoem_paths(path=path, splits="test", samples="rat"), + "mitem_human": lambda: datasets.mitoem.get_mitoem_paths(path=path, splits="test", 
samples="human"), + "platynereis_nuclei": lambda: datasets.platynereis.get_platynereis_paths(path, sample_ids=None, name="nuclei"), + "platynereis_cilia": lambda: datasets.platynereis.get_platynereis_paths(path, sample_ids=None, name="cilia"), + "lucchi": lambda: datasets.lucchi.get_lucchi_paths(path=path, split="test"), + "mitolab_3d": lambda: ( + [rpath for i in range(1, 7) for rpath in datasets.cem.get_benchmark_paths(path=path, dataset_id=i)[0]], + [lpath for i in range(1, 7) for lpath in datasets.cem.get_benchmark_paths(path=path, dataset_id=i)[1]] + ), + "mitolab_tem": lambda: datasets.cem.get_benchmark_paths(path=path, dataset_id=7), + "nuc_mm_mouse": lambda: datasets.nuc_mm.get_nuc_mm_paths(path=path, sample="mouse", split="val"), + "nuc_mm_zebrafish": lambda: datasets.nuc_mm.get_nuc_mm_paths(path=path, sample="zebrafish", split="val"), + "uro_cell": lambda: datasets.uro_cell.get_uro_cell_paths(path=path, target="mito"), + "sponge_em": lambda: datasets.sponge_em.get_sponge_em_paths(path=path, sample_ids=None), + "vnc": lambda: datasets.vnc.get_vnc_mito_paths(path=path), + "asem_mito": lambda: datasets.asem.get_asem_paths(path=path, volume_ids=datasets.asem.ORGANELLES["mito"]) + } + + if ndim == 2: + image_paths, gt_paths = available_datasets[dataset_choice]() + + if dataset_choice in DATASET_RETURNS_FOLDER: + image_paths = glob(os.path.join(image_paths, DATASET_RETURNS_FOLDER[dataset_choice])) + gt_paths = glob(os.path.join(gt_paths, DATASET_RETURNS_FOLDER[dataset_choice])) + + image_paths, gt_paths = natsorted(image_paths), natsorted(gt_paths) + assert len(image_paths) == len(gt_paths) + + paths_set = zip(image_paths, gt_paths) + + else: + image_paths = available_datasets[dataset_choice]() + if isinstance(image_paths, str): + paths_set = [image_paths] + else: + paths_set = natsorted(image_paths) + + # Directory where we store the extracted ROIs. 
+ save_image_dir = [os.path.join(path, f"roi_{ndim}d", "inputs")] + save_gt_dir = [os.path.join(path, f"roi_{ndim}d", "labels")] + if extract_2d_crops_from_volumes: + save_image_dir.append(os.path.join(path, "roi_2d", "inputs")) + save_gt_dir.append(os.path.join(path, "roi_2d", "labels")) + + _dir_exists = [ + os.path.exists(idir) and os.path.exists(gdir) for idir, gdir in zip(save_image_dir, save_gt_dir) + ] + if all(_dir_exists): + return ndim + + [os.makedirs(idir, exist_ok=True) for idir in save_image_dir] + [os.makedirs(gdir, exist_ok=True) for gdir in save_gt_dir] + + # Logic to extract relevant patches for inference + image_counter = 1 + for per_paths in tqdm(paths_set, desc=f"Extracting patches for {dataset_choice}"): + if ndim == 2: + image_path, gt_path = per_paths + image, gt = util.load_image_data(image_path), util.load_image_data(gt_path) + else: + image_path = per_paths + image = util.load_image_data(image_path, DATASET_CONTAINER_KEYS[dataset_choice][0]) + gt = util.load_image_data(image_path, DATASET_CONTAINER_KEYS[dataset_choice][1]) + + skip_smaller_shape = (np.array(image.shape) >= np.array(tile_shape)).all() + + # Ensure ground truth has instance labels. + gt = connected_components(gt) + + if len(np.unique(gt)) == 1: # There could be labels which does not have any annotated foreground. + continue + + # Let's extract and save all the crops. + # NOTE: The first round of extraction is always to match the desired input dimensions. + image_crops, gt_crops = _get_crops_for_input(image, gt, ndim, tile_shape, skip_smaller_shape, crops_per_input) + image_counter = _save_image_label_crops( + image_crops, gt_crops, dataset_choice, ndim, image_counter, save_image_dir[0], save_gt_dir[0] + ) + + # NOTE: The next round of extraction is to get 2d crops from 3d inputs. + if extract_2d_crops_from_volumes: + curr_tile_shape = tile_shape[-2:] # NOTE: We expect 2d tile shape for this stage. 
+ + curr_image_crops, curr_gt_crops = [], [] + for per_z_im, per_z_gt in zip(image, gt): + curr_skip_smaller_shape = (np.array(per_z_im.shape) >= np.array(curr_tile_shape)).all() + + image_crops, gt_crops = _get_crops_for_input( + image=per_z_im, gt=per_z_gt, ndim=2, + tile_shape=curr_tile_shape, + skip_smaller_shape=curr_skip_smaller_shape, + crops_per_input=crops_per_input, + ) + curr_image_crops.extend(image_crops) + curr_gt_crops.extend(gt_crops) + + image_counter = _save_image_label_crops( + curr_image_crops, curr_gt_crops, dataset_choice, 2, image_counter, save_image_dir[1], save_gt_dir[1] + ) + + return ndim + + +def _get_crops_for_input(image, gt, ndim, tile_shape, skip_smaller_shape, crops_per_input): + tiling = blocking([0] * ndim, gt.shape, tile_shape) + n_tiles = tiling.numberOfBlocks + tiles = [tiling.getBlock(tile_id) for tile_id in range(n_tiles)] + crop_boxes = [ + tuple(slice(beg, end) for beg, end in zip(tile.begin, tile.end)) for tile in tiles + ] + n_ids = [idx for idx in range(len(crop_boxes))] + n_instances = [len(np.unique(gt[crop])) for crop in crop_boxes] + + # Extract the desired number of patches with higher number of instances. + image_crops, gt_crops = [], [] + for i, (per_n_instance, per_id) in enumerate(sorted(zip(n_instances, n_ids), reverse=True), start=1): + crop_box = crop_boxes[per_id] + crop_image, crop_gt = image[crop_box], gt[crop_box] + # NOTE: We avoid using the crops which do not match the desired tile shape. + if skip_smaller_shape and crop_image.shape != tile_shape: + continue + + # NOTE: There could be a case where some later patches are invalid. + if per_n_instance == 1: + break + + image_crops.append(crop_image) + gt_crops.append(crop_gt) + + # NOTE: If the number of patches extracted have been fulfiled, we stop sampling patches. 
+ if len(image_crops) > 0 and i >= crops_per_input: + break + + return image_crops, gt_crops + + +def _save_image_label_crops(image_crops, gt_crops, dataset_choice, ndim, image_counter, save_image_dir, save_gt_dir): + for image_crop, gt_crop in tqdm( + zip(image_crops, gt_crops), total=len(image_crops), desc=f"Saving {ndim}d crops for {dataset_choice}" + ): + fname = f"{dataset_choice}_{image_counter:05}.tif" + assert image_crop.shape == gt_crop.shape + imageio.imwrite(os.path.join(save_image_dir, fname), image_crop, compression="zlib") + imageio.imwrite(os.path.join(save_gt_dir, fname), gt_crop, compression="zlib") + image_counter += 1 + + return image_counter + + +def _get_image_label_paths(path, ndim): + image_paths = natsorted(glob(os.path.join(path, f"roi_{ndim}d", "inputs", "*"))) + gt_paths = natsorted(glob(os.path.join(path, f"roi_{ndim}d", "labels", "*"))) + return image_paths, gt_paths + + +def _run_automatic_segmentation_per_dataset( + image_paths: List[Union[os.PathLike, str]], + gt_paths: List[Union[os.PathLike, str]], + model_type: str, + output_folder: Union[os.PathLike, str], + ndim: Optional[int] = None, + device: Optional[Union[torch.device, str]] = None, + checkpoint_path: Optional[Union[os.PathLike, str]] = None, + run_amg: bool = False, + **auto_seg_kwargs +): + """Functionality to run automatic segmentation for multiple input files at once. + It stores the evaluated automatic segmentation results (quantitative). + + Args: + image_paths: List of filepaths for the input image data. + gt_paths: List of filepaths for the corresponding label data. + model_type: The choice of image encoder for the Segment Anything model. + output_folder: Filepath to the folder where we store all the results. + ndim: The number of input dimensions. + device: The torch device. + checkpoint_path: The filepath where the model checkpoints are stored. + run_amg: Whether to run automatic segmentation in AMG mode. 
+ auto_seg_kwargs: Additional arguments for automatic segmentation parameters. + """ + experiment_name = "AMG" if run_amg else "AIS" + fname = f"{experiment_name.lower()}_{ndim}d" + + result_path = os.path.join(output_folder, "results", f"{fname}.csv") + prediction_dir = os.path.join(output_folder, fname, "inference") + if os.path.exists(prediction_dir): + return + + os.makedirs(prediction_dir, exist_ok=True) + + # Get the predictor (and the additional instance segmentation decoder, if available). + predictor, segmenter = get_predictor_and_segmenter( + model_type=model_type, checkpoint=checkpoint_path, device=device, amg=run_amg, is_tiled=False, + ) + + for image_path in tqdm(image_paths, desc=f"Run {experiment_name} in {ndim}d"): + output_path = os.path.join(prediction_dir, os.path.basename(image_path)) + if os.path.exists(output_path): + continue + + # Run Automatic Segmentation (AMG and AIS) + automatic_instance_segmentation( + predictor=predictor, + segmenter=segmenter, + input_path=image_path, + output_path=output_path, + ndim=ndim, + verbose=False, + **auto_seg_kwargs + ) + + prediction_paths = natsorted(glob(os.path.join(prediction_dir, "*"))) + run_evaluation(gt_paths=gt_paths, prediction_paths=prediction_paths, save_path=result_path) + + +def _run_interactive_segmentation_per_dataset( + image_paths: List[Union[os.PathLike, str]], + gt_paths: List[Union[os.PathLike, str]], + output_folder: Union[os.PathLike, str], + model_type: str, + prompt_choice: Literal["box", "points"], + device: Optional[Union[torch.device, str]] = None, + ndim: Optional[int] = None, + checkpoint_path: Optional[Union[os.PathLike, str]] = None, +): + """Functionality to run interactive segmentation for multiple input files at once. + It stores the evaluated interactive segmentation results. + + Args: + image_paths: List of filepaths for the input image data. + gt_paths: List of filepaths for the corresponding label data. 
+ output_folder: Filepath to the folder where we store all the results. + model_type: The choice of model type for Segment Anything. + prompt_choice: The choice of initial prompts to begin the interactive segmentation. + device: The torch device. + ndim: The number of input dimensions. + checkpoint_path: The filepath for stored checkpoints. + """ + if ndim == 2: + # Get the Segment Anything predictor. + predictor = util.get_sam_model(model_type=model_type, device=device, checkpoint_path=checkpoint_path) + + # Run interactive instance segmentation + # (starting with box and points followed by iterative prompt-based correction) + run_inference_with_iterative_prompting( + predictor=predictor, + image_paths=image_paths, + gt_paths=gt_paths, + embedding_dir=None, # We set this to None to compute embeddings on-the-fly. + prediction_dir=os.path.join(output_folder, "interactive_segmentation_2d", f"start_with_{prompt_choice}"), + start_with_box_prompt=(prompt_choice == "box"), + # TODO: add parameter for deform over box prompts (to simulate prompts in practice). + ) + + # Evaluate the interactive instance segmentation. + run_evaluation_for_iterative_prompting( + gt_paths=gt_paths, + prediction_root=os.path.join(output_folder, "interactive_segmentation_2d", f"start_with_{prompt_choice}"), + experiment_folder=output_folder, + start_with_box_prompt=(prompt_choice == "box"), + ) + + else: + save_path = os.path.join(output_folder, "results", f"interactive_segmentation_3d_with_{prompt_choice}.csv") + if os.path.exists(save_path): + print( + f"Results for 3d interactive segmentation with '{prompt_choice}' are already stored at '{save_path}'." 
+ ) + return + + results = [] + for image_path, gt_path in tqdm( + zip(image_paths, gt_paths), total=len(image_paths), + desc=f"Run interactive segmentation in 3d with '{prompt_choice}'" + ): + prediction_dir = os.path.join(output_folder, "interactive_segmentation_3d", f"{prompt_choice}") + os.makedirs(prediction_dir, exist_ok=True) + + prediction_path = os.path.join(prediction_dir, os.path.basename(image_path)) + if os.path.exists(prediction_path): + continue + + per_vol_result = segment_slices_from_ground_truth( + volume=imageio.imread(image_path), + ground_truth=imageio.imread(gt_path), + model_type=model_type, + checkpoint_path=checkpoint_path, + save_path=prediction_path, + device=device, + interactive_seg_mode=prompt_choice, + min_size=10, + ) + results.append(per_vol_result) + + results = pd.concat(results) + results = results.groupby(results.index).mean() + results.to_csv(save_path) + + +def _run_benchmark_evaluation_series( + image_paths, gt_paths, model_type, output_folder, ndim, device, checkpoint_path, run_amg, +): + seg_kwargs = { + "image_paths": image_paths, + "gt_paths": gt_paths, + "output_folder": output_folder, + "ndim": ndim, + "model_type": model_type, + "device": device, + "checkpoint_path": checkpoint_path, + } + + # Perform: + # a. automatic segmentation (supported in both 2d and 3d, wherever relevant) + # The automatic segmentation steps below are configured in a way that AIS has priority (if decoder is found) + # Else, it runs for AMG. + # Next, we check if the user expects to run AMG as well (after the run for AIS). + + # i. Run automatic segmentation method supported with the SAM model (AMG or AIS). + _run_automatic_segmentation_per_dataset(run_amg=False, **seg_kwargs) + + # ii. Run automatic mask generation (AMG) (in case the first run is AIS). + _run_automatic_segmentation_per_dataset(run_amg=run_amg, **seg_kwargs) + + # b. 
Run interactive segmentation (supported in both 2d and 3d, wherever relevant) + _run_interactive_segmentation_per_dataset(prompt_choice="box", **seg_kwargs) + _run_interactive_segmentation_per_dataset(prompt_choice="points", **seg_kwargs) + + +def _clear_cached_items(retain, path, output_folder): + import shutil + from pathlib import Path + + REMOVE_LIST = ["data", "crops", "auto", "int"] + if retain is None: + remove_list = REMOVE_LIST + else: + assert isinstance(retain, list) + remove_list = set(REMOVE_LIST) - set(retain) + + paths = [] + # Stage 1: Remove inputs. + if "data" in remove_list or "crops" in remove_list: + all_paths = glob(os.path.join(path, "*")) + + # In case we want to remove both data and crops, we remove the data folder entirely. + if "data" in remove_list and "crops" in remove_list: + paths.extend(all_paths) + return + + # Next, we verify whether the we only remove either of data or crops. + for curr_path in all_paths: + if os.path.basename(curr_path).startswith("roi") and "crops" in remove_list: + paths.append(curr_path) + elif "data" in remove_list: + paths.append(curr_path) + + # Stage 2: Remove predictions + if "auto" in remove_list: + paths.extend(glob(os.path.join(output_folder, "amg_*"))) + paths.extend(glob(os.path.join(output_folder, "ais_*"))) + + if "int" in remove_list: + paths.extend(glob(os.path.join(output_folder, "interactive_segmentation_*"))) + + [shutil.rmtree(_path) if Path(_path).is_dir() else os.remove(_path) for _path in paths] + + +def run_benchmark_evaluations( + input_folder: Union[os.PathLike, str], + dataset_choice: str, + model_type: str = util._DEFAULT_MODEL, + output_folder: Optional[Union[str, os.PathLike]] = None, + checkpoint_path: Optional[Union[str, os.PathLike]] = None, + run_amg: bool = False, + retain: Optional[List[str]] = None, + ignore_warnings: bool = False, +): + """Run evaluation for benchmarking Segment Anything models on microscopy datasets. 
def run_benchmark_evaluations(
    input_folder: Union[os.PathLike, str],
    dataset_choice: str,
    model_type: str = util._DEFAULT_MODEL,
    output_folder: Optional[Union[str, os.PathLike]] = None,
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
    run_amg: bool = False,
    retain: Optional[List[str]] = None,
    ignore_warnings: bool = False,
):
    """Run evaluation for benchmarking Segment Anything models on microscopy datasets.

    Each selected dataset gets its own sub-folder '<output_folder>/<dataset>'. A dataset is
    skipped if '<output_folder>/<dataset>/results' already exists.

    Args:
        input_folder: The path to directory where all inputs will be stored and preprocessed.
        dataset_choice: The dataset choice. It is forwarded to '_download_benchmark_datasets',
            whose return value is iterated as the list of datasets to run.
        model_type: The model choice for SAM.
        output_folder: The path to directory where all outputs will be stored.
            NOTE: Although the default is None, a path has to be provided, because it is
            joined with the name of each dataset.
        checkpoint_path: The checkpoint path
        run_amg: Whether to run automatic segmentation in AMG mode.
        retain: Whether to retain certain parts of the benchmark runs.
            By default, removes everything besides quantitative results.
            There is the choice to retain 'data', 'crops', 'auto', or 'int'.
        ignore_warnings: Whether to ignore warnings.
    """
    start = time.time()

    with _filter_warnings(ignore_warnings):
        device = util._get_default_device()

        # Ensure if all the datasets have been installed by default.
        dataset_choice = _download_benchmark_datasets(path=input_folder, dataset_choice=dataset_choice)

        for choice in dataset_choice:
            # Use a per-dataset folder without overwriting 'output_folder' itself. Rebinding it
            # here would nest every following dataset inside the folder of the previous one.
            dataset_output_folder = os.path.join(output_folder, choice)
            result_dir = os.path.join(dataset_output_folder, "results")
            if os.path.exists(result_dir):
                continue

            os.makedirs(result_dir, exist_ok=True)

            data_path = os.path.join(input_folder, choice)

            # Extrapolate desired set from the datasets:
            # a. for 2d datasets - 2d patches with the most number of labels present
            #    (in case of volumetric data, choose 2d patches per slice).
            # b. for 3d datasets - 3d regions of interest with the most number of labels present.
            ndim = _extract_slices_from_dataset(path=data_path, dataset_choice=choice, crops_per_input=10)

            # Run inference and evaluation scripts on benchmark datasets.
            image_paths, gt_paths = _get_image_label_paths(path=data_path, ndim=ndim)
            _run_benchmark_evaluation_series(
                image_paths, gt_paths, model_type, dataset_output_folder, ndim, device, checkpoint_path, run_amg
            )

            # Run inference and evaluation scripts on '2d' crops for volumetric datasets
            if ndim == 3:
                image_paths, gt_paths = _get_image_label_paths(path=data_path, ndim=2)
                _run_benchmark_evaluation_series(
                    image_paths, gt_paths, model_type, dataset_output_folder, 2, device, checkpoint_path, run_amg
                )

            _clear_cached_items(retain=retain, path=data_path, output_folder=dataset_output_folder)

    diff = time.time() - start
    hours, rest = divmod(diff, 3600)
    minutes, seconds = divmod(rest, 60)
    print("Time taken for running benchmarks: ", f"{int(hours)}h {int(minutes)}m {seconds:.2f}s")
def main():
    """@private

    Command line entry point: parse the arguments and run the benchmark evaluation.
    """
    import argparse

    available_models = list(util.get_model_names())
    available_models = ", ".join(available_models)

    parser = argparse.ArgumentParser(
        description="Run evaluation for benchmarking Segment Anything models on microscopy datasets."
    )
    parser.add_argument(
        "-i", "--input_folder", type=str, required=True,
        help="The path to a directory where the microscopy datasets are / will be stored."
    )
    parser.add_argument(
        "-m", "--model_type", type=str, default=util._DEFAULT_MODEL,
        help=f"The segment anything model that will be used, one of {available_models}."
    )
    parser.add_argument(
        "-c", "--checkpoint_path", type=str, default=None,
        help="Checkpoint from which the SAM model will be loaded."
    )
    parser.add_argument(
        "-d", "--dataset_choice", type=str, nargs='*', default=None,
        help="The choice(s) of dataset for evaluating SAM models. Multiple datasets can be specified."
    )
    parser.add_argument(
        "-o", "--output_folder", type=str, required=True,
        help="The path where the results for automatic and interactive instance segmentation will be stored as 'csv'."
    )
    parser.add_argument(
        "--amg", action="store_true",
        help="Whether to run automatic segmentation in AMG mode (i.e. the default auto-seg approach for SAM)."
    )
    parser.add_argument(
        # Restrict the values: an unknown (e.g. misspelled) entry would otherwise be ignored
        # silently and the part the user wanted to keep would be deleted.
        "--retain", nargs="*", default=None, choices=["data", "crops", "auto", "int"],
        help="By default, the functionality removes all besides quantitative results required for running benchmarks. "
        "In case you would like to retain parts of the benchmark evaluation for visualization / reproducibility, "
        "you should choose one or multiple of 'data', 'crops', 'auto', 'int', "
        "where they are responsible for either retaining original inputs / extracted crops / "
        "predictions of automatic segmentation / predictions of interactive segmentation, respectively."
    )
    args = parser.parse_args()

    run_benchmark_evaluations(
        input_folder=args.input_folder,
        dataset_choice=args.dataset_choice,
        model_type=args.model_type,
        output_folder=args.output_folder,
        checkpoint_path=args.checkpoint_path,
        run_amg=args.amg,
        retain=args.retain,
        ignore_warnings=True,
    )


if __name__ == "__main__":
    main()
if os.path.exists(csv_path): - print(pd.read_csv(csv_path)) + print(f"Results with iterative prompting for interactive segmentation are already stored at '{csv_path}'.") return list_of_results = [] @@ -120,7 +118,6 @@ def run_evaluation_for_iterative_prompting( pred_paths = sorted(glob(os.path.join(pred_folder, "*"))) result = run_evaluation(gt_paths=gt_paths, prediction_paths=pred_paths, save_path=None) list_of_results.append(result) - print(result) res_df = pd.concat(list_of_results, ignore_index=True) res_df.to_csv(csv_path) diff --git a/micro_sam/evaluation/inference.py b/micro_sam/evaluation/inference.py index e1736fa5d..b033055f4 100644 --- a/micro_sam/evaluation/inference.py +++ b/micro_sam/evaluation/inference.py @@ -473,21 +473,21 @@ def run_inference_with_iterative_prompting( gt_paths: List[Union[str, os.PathLike]], embedding_dir: Union[str, os.PathLike], prediction_dir: Union[str, os.PathLike], - start_with_box_prompt: bool, + start_with_box_prompt: bool = True, dilation: int = 5, batch_size: int = 32, n_iterations: int = 8, use_masks: bool = False ) -> None: - """Run segment anything inference for multiple images using prompts iteratively - derived from model outputs and groundtruth + """Run Segment Anything inference for multiple images using prompts iteratively + derived from model outputs and ground-truth. Args: - predictor: The SegmentAnything predictor. + predictor: The Segment Anything predictor. image_paths: The image file paths. gt_paths: The ground-truth segmentation file paths. embedding_dir: The directory where the image embeddings will be saved or are already saved. - prediction_dir: The directory where the predictions from SegmentAnything will be saved per iteration. + prediction_dir: The directory where the predictions from Segment Anything will be saved per iteration. 
start_with_box_prompt: Whether to use the first prompt as bounding box or a single point dilation: The dilation factor for the radius around the ground-truth object around which points will not be sampled. @@ -506,8 +506,7 @@ def run_inference_with_iterative_prompting( print("The iterative prompting will make use of logits masks from previous iterations.") for image_path, gt_path in tqdm( - zip(image_paths, gt_paths), - total=len(image_paths), + zip(image_paths, gt_paths), total=len(image_paths), desc="Run inference with iterative prompting for all images", ): image_name = os.path.basename(image_path) @@ -524,7 +523,10 @@ def run_inference_with_iterative_prompting( gt = imageio.imread(gt_path).astype("uint32") gt = relabel_sequential(gt)[0] - embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr") + if embedding_dir is None: + embedding_path = None + else: + embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr") _run_inference_with_iterative_prompting_for_image( predictor, image, gt, start_with_box_prompt=start_with_box_prompt, diff --git a/micro_sam/evaluation/multi_dimensional_segmentation.py b/micro_sam/evaluation/multi_dimensional_segmentation.py index e54cafb59..07b5820f0 100644 --- a/micro_sam/evaluation/multi_dimensional_segmentation.py +++ b/micro_sam/evaluation/multi_dimensional_segmentation.py @@ -6,6 +6,8 @@ from itertools import product from typing import Union, Tuple, Optional, List, Dict +import imageio.v3 as imageio + import torch from elf.evaluation import mean_segmentation_accuracy @@ -58,8 +60,9 @@ def segment_slices_from_ground_truth( volume: np.ndarray, ground_truth: np.ndarray, model_type: str, - checkpoint_path: Union[str, os.PathLike], - embedding_path: Union[str, os.PathLike], + checkpoint_path: Optional[Union[str, os.PathLike]] = None, + embedding_path: Optional[Union[str, os.PathLike]] = None, + save_path: Optional[Union[str, os.PathLike]] = None, iou_threshold: float = 0.8, 
projection: Union[str, dict] = "mask", box_extension: Union[float, int] = 0.025, @@ -81,6 +84,7 @@ def segment_slices_from_ground_truth( model_type: Choice of segment anything model. checkpoint_path: Path to the model checkpoint. embedding_path: Path to cache the computed embeddings. + save_path: Path to store the segmentations. iou_threshold: The criterion to decide whether to link the objects in the consecutive slice's segmentation. projection: The projection (prompting) method to generate prompts for consecutive slices. box_extension: Extension factor for increasing the box size after projection. @@ -97,7 +101,7 @@ def segment_slices_from_ground_truth( # Compute the image embeddings embeddings = util.precompute_image_embeddings( - predictor=predictor, input_=volume, save_path=embedding_path, ndim=3 + predictor=predictor, input_=volume, save_path=embedding_path, ndim=3, verbose=verbose, ) # Compute instance ids (without the background) @@ -133,7 +137,7 @@ def segment_slices_from_ground_truth( _get_points, _get_box = False, True else: raise ValueError( - "The provided interactive prompting for the first slice isn't supported.", + f"The provided interactive prompting '{interactive_seg_mode}' for the first slice isn't supported." "Please choose from 'box' / 'points'." 
) @@ -145,14 +149,20 @@ def segment_slices_from_ground_truth( get_box_prompts=_get_box ) _, box_coords = util.get_centers_and_bounding_boxes(this_slice_seg) - point_prompts, point_labels, box_prompts, _ = prompt_generator(this_slice_seg, [box_coords[1]]) + point_prompts, point_labels, box_prompts, _ = prompt_generator( + segmentation=torch.from_numpy(this_slice_seg)[None, None].to(torch.float32), + bbox_coordinates=[box_coords[1]], + ) # Prompt-based segmentation on middle slice of the current object output_slice = batched_inference( - predictor=predictor, image=volume[slice_choice], batch_size=1, + predictor=predictor, + image=volume[slice_choice], + batch_size=1, boxes=box_prompts.numpy() if isinstance(box_prompts, torch.Tensor) else box_prompts, points=point_prompts.numpy() if isinstance(point_prompts, torch.Tensor) else point_prompts, - point_labels=point_labels.numpy() if isinstance(point_labels, torch.Tensor) else point_labels + point_labels=point_labels.numpy() if isinstance(point_labels, torch.Tensor) else point_labels, + verbose_embeddings=verbose, ) output_seg = np.zeros_like(ground_truth) output_seg[slice_choice][output_slice == 1] = 1 @@ -173,18 +183,25 @@ def segment_slices_from_ground_truth( # Store the entire segmented object final_segmentation[this_seg == 1] = label_id + # Save the volumetric segmentation + if save_path is not None: + imageio.imwrite(save_path, final_segmentation, compression="zlib") + # Evaluate the volumetric segmentation if skipped_label_ids: - gt_copy = ground_truth.copy() - gt_copy[np.isin(gt_copy, skipped_label_ids)] = 0 - msa = mean_segmentation_accuracy(final_segmentation, gt_copy) + curr_gt = ground_truth.copy() + curr_gt[np.isin(curr_gt, skipped_label_ids)] = 0 else: - msa = mean_segmentation_accuracy(final_segmentation, ground_truth) + curr_gt = ground_truth + + msa, sa = mean_segmentation_accuracy(final_segmentation, curr_gt, return_accuracies=True) + results = {"mSA": msa, "SA50": sa[0], "SA75": sa[5]} + results = 
pd.DataFrame.from_dict([results]) if return_segmentation: - return msa, final_segmentation + return results, final_segmentation else: - return msa + return results def _get_best_parameters_from_grid_search_combinations(result_dir, best_params_path, grid_search_values): @@ -266,7 +283,7 @@ def run_multi_dimensional_segmentation_grid_search( net_list = [] for gs_kwargs in tqdm(gs_combinations): - msa = segment_slices_from_ground_truth( + results = segment_slices_from_ground_truth( volume=volume, ground_truth=ground_truth, model_type=model_type, @@ -279,7 +296,7 @@ def run_multi_dimensional_segmentation_grid_search( **gs_kwargs ) - result_dict = {"mSA": msa, **gs_kwargs} + result_dict = {**results, **gs_kwargs} tmp_df = pd.DataFrame([result_dict]) net_list.append(tmp_df) diff --git a/micro_sam/multi_dimensional_segmentation.py b/micro_sam/multi_dimensional_segmentation.py index fd44ef64b..7ecaf14e6 100644 --- a/micro_sam/multi_dimensional_segmentation.py +++ b/micro_sam/multi_dimensional_segmentation.py @@ -397,7 +397,13 @@ def automatic_3d_segmentation( min_object_size = kwargs.pop("min_object_size", 0) image_embeddings = util.precompute_image_embeddings( - predictor=predictor, input_=volume, save_path=embedding_path, ndim=3, tile_shape=tile_shape, halo=halo, + predictor=predictor, + input_=volume, + save_path=embedding_path, + ndim=3, + tile_shape=tile_shape, + halo=halo, + verbose=verbose, ) for i in tqdm(range(segmentation.shape[0]), desc="Segment slices", disable=not verbose): @@ -415,7 +421,12 @@ def automatic_3d_segmentation( segmentation[i] = seg segmentation = merge_instance_segmentation_3d( - segmentation, beta=0.5, with_background=with_background, gap_closing=gap_closing, min_z_extent=min_z_extent + segmentation, + beta=0.5, + with_background=with_background, + gap_closing=gap_closing, + min_z_extent=min_z_extent, + verbose=verbose, ) return segmentation diff --git a/micro_sam/prompt_generators.py b/micro_sam/prompt_generators.py index 839077410..df521e4fb 
100644 --- a/micro_sam/prompt_generators.py +++ b/micro_sam/prompt_generators.py @@ -191,12 +191,8 @@ def _sample_points(self, segmentation, bbox_coordinates, center_coordinates): center_coordinates = [None] * len(segmentation) if center_coordinates is None else center_coordinates for object_mask, bbox_coords, center_coords in zip(segmentation, bbox_coordinates, center_coordinates): coord_list, label_list = [], [] - coord_list, label_list = self._sample_positive_points( - object_mask[0], center_coords, coord_list, label_list - ) - coord_list, label_list = self._sample_negative_points( - object_mask[0], bbox_coords, coord_list, label_list - ) + coord_list, label_list = self._sample_positive_points(object_mask[0], center_coords, coord_list, label_list) + coord_list, label_list = self._sample_negative_points(object_mask[0], bbox_coords, coord_list, label_list) coord_list, label_list = self._ensure_num_points(object_mask[0], coord_list, label_list) all_coords.append(coord_list) diff --git a/micro_sam/training/training.py b/micro_sam/training/training.py index 165a10ae9..6c66ccb39 100644 --- a/micro_sam/training/training.py +++ b/micro_sam/training/training.py @@ -235,8 +235,8 @@ def train_sam( t_start = time.time() - _check_loader(train_loader, with_segmentation_decoder, verify_n_labels_in_loader) - _check_loader(val_loader, with_segmentation_decoder, verify_n_labels_in_loader) + _check_loader(train_loader, with_segmentation_decoder, "train", verify_n_labels_in_loader) + _check_loader(val_loader, with_segmentation_decoder, "val", verify_n_labels_in_loader) device = get_device(device) # Get the trainable segment anything model. 
diff --git a/setup.cfg b/setup.cfg index 2c6b5e38b..d7f976b28 100644 --- a/setup.cfg +++ b/setup.cfg @@ -49,6 +49,7 @@ console_scripts = micro_sam.image_series_annotator = micro_sam.sam_annotator.image_series_annotator:main micro_sam.precompute_embeddings = micro_sam.precompute_state:main micro_sam.automatic_segmentation = micro_sam.automatic_segmentation:main + micro_sam.benchmark_sam = micro_sam.evaluation.benchmark_datasets:main # make sure it gets included in your package [options.package_data] diff --git a/test/test_automatic_segmentation.py b/test/test_automatic_segmentation.py index 47b460f45..e0bb4287a 100644 --- a/test/test_automatic_segmentation.py +++ b/test/test_automatic_segmentation.py @@ -66,89 +66,92 @@ def tearDown(self): torch.mps.empty_cache() def test_automatic_mask_generator_2d(self): - from micro_sam.automatic_segmentation import automatic_instance_segmentation + from micro_sam.automatic_segmentation import automatic_instance_segmentation, get_predictor_and_segmenter mask, image = self.mask, self.image + predictor, segmenter = get_predictor_and_segmenter( + model_type=self.model_type, amg=True, is_tiled=False, amg_kwargs={"points_per_side": 4} + ) instances = automatic_instance_segmentation( - input_path=image, model_type=self.model_type, ndim=2, use_amg=True, - amg_kwargs={"points_per_side": 4} + predictor=predictor, segmenter=segmenter, input_path=image, ndim=2, ) self.assertEqual(mask.shape, instances.shape) def test_tiled_automatic_mask_generator_2d(self): - from micro_sam.automatic_segmentation import automatic_instance_segmentation + from micro_sam.automatic_segmentation import automatic_instance_segmentation, get_predictor_and_segmenter mask, image = self.large_mask, self.large_image + predictor, segmenter = get_predictor_and_segmenter( + model_type=self.model_type, amg=True, is_tiled=True, amg_kwargs={"points_per_side": 4} + ) instances = automatic_instance_segmentation( - input_path=image, - model_type=self.model_type, - ndim=2, - 
tile_shape=self.tile_shape, - halo=self.halo, - use_amg=True, - amg_kwargs={"points_per_side": 4} + predictor=predictor, segmenter=segmenter, input_path=image, + ndim=2, tile_shape=self.tile_shape, halo=self.halo, ) self.assertEqual(mask.shape, instances.shape) def test_instance_segmentation_with_decoder_2d(self): - from micro_sam.automatic_segmentation import automatic_instance_segmentation + from micro_sam.automatic_segmentation import automatic_instance_segmentation, get_predictor_and_segmenter mask, image = self.mask, self.image + predictor, segmenter = get_predictor_and_segmenter(model_type=self.model_type, amg=False, is_tiled=False) instances = automatic_instance_segmentation( - input_path=image, model_type=self.model_type_ais, ndim=2 + predictor=predictor, segmenter=segmenter, input_path=image, ndim=2, ) self.assertEqual(mask.shape, instances.shape) def test_tiled_instance_segmentation_with_decoder_2d(self): - from micro_sam.automatic_segmentation import automatic_instance_segmentation + from micro_sam.automatic_segmentation import automatic_instance_segmentation, get_predictor_and_segmenter mask, image = self.large_mask, self.large_image + predictor, segmenter = get_predictor_and_segmenter(model_type=self.model_type, amg=False, is_tiled=True) instances = automatic_instance_segmentation( - input_path=image, model_type=self.model_type_ais, + predictor=predictor, segmenter=segmenter, input_path=image, ndim=2, tile_shape=self.tile_shape, halo=self.halo, ) self.assertEqual(mask.shape, instances.shape) @unittest.skip("Skipping long running tests by default.") def test_automatic_mask_generator_3d(self): - from micro_sam.automatic_segmentation import automatic_instance_segmentation + from micro_sam.automatic_segmentation import automatic_instance_segmentation, get_predictor_and_segmenter labels, volume = self.labels, self.volume + predictor, segmenter = get_predictor_and_segmenter(model_type=self.model_type, amg=True, is_tiled=False) instances = 
automatic_instance_segmentation( - input_path=volume, model_type=self.model_type, ndim=3, use_amg=True + predictor=predictor, segmenter=segmenter, input_path=volume, ndim=3, ) self.assertEqual(labels.shape, instances.shape) @unittest.skip("Skipping long running tests by default.") def test_tiled_automatic_mask_generator_3d(self): - from micro_sam.automatic_segmentation import automatic_instance_segmentation + from micro_sam.automatic_segmentation import automatic_instance_segmentation, get_predictor_and_segmenter labels, volume = self.large_labels, self.large_volume + predictor, segmenter = get_predictor_and_segmenter(model_type=self.model_type, amg=True, is_tiled=True) instances = automatic_instance_segmentation( - input_path=volume, - model_type=self.model_type, - ndim=3, - tile_shape=self.tile_shape, - halo=self.halo, - use_amg=True, + predictor=predictor, segmenter=segmenter, input_path=volume, + ndim=3, tile_shape=self.tile_shape, halo=self.halo, ) self.assertEqual(labels.shape, instances.shape) def test_instance_segmentation_with_decoder_3d(self): - from micro_sam.automatic_segmentation import automatic_instance_segmentation + from micro_sam.automatic_segmentation import automatic_instance_segmentation, get_predictor_and_segmenter labels, volume = self.labels, self.volume + predictor, segmenter = get_predictor_and_segmenter(model_type=self.model_type, amg=False, is_tiled=False) instances = automatic_instance_segmentation( - input_path=volume, model_type=self.model_type_ais, ndim=3, + predictor=predictor, segmenter=segmenter, input_path=volume, ndim=3, ) self.assertEqual(labels.shape, instances.shape) def test_tiled_instance_segmentation_with_decoder_3d(self): - from micro_sam.automatic_segmentation import automatic_instance_segmentation + from micro_sam.automatic_segmentation import automatic_instance_segmentation, get_predictor_and_segmenter labels, volume = self.large_labels, self.large_volume + predictor, segmenter = 
get_predictor_and_segmenter(model_type=self.model_type, amg=False, is_tiled=True) instances = automatic_instance_segmentation( - input_path=volume, model_type=self.model_type_ais, ndim=3, tile_shape=self.tile_shape, halo=self.halo, + predictor=predictor, segmenter=segmenter, input_path=volume, + ndim=3, tile_shape=self.tile_shape, halo=self.halo, ) self.assertEqual(labels.shape, instances.shape) diff --git a/test/test_training.py b/test/test_training.py index 2ad809fda..d89a15978 100644 --- a/test/test_training.py +++ b/test/test_training.py @@ -125,7 +125,7 @@ def _run_inference_and_check_results( self.assertEqual(len(pred_paths), len(label_paths)) eval_res = evaluation.run_evaluation(label_paths, pred_paths, verbose=False) - result = eval_res["sa50"].values.item() + result = eval_res["SA50"].values.item() # We check against the expected segmentation accuracy. self.assertGreater(result, expected_sa) @@ -172,7 +172,7 @@ def test_training(self): ) self._run_inference_and_check_results( export_path, model_type, prediction_dir=prediction_dir, - inference_function=box_inference, expected_sa=0.95, + inference_function=box_inference, expected_sa=0.8, ) # Check the model with interactive inference. @@ -184,7 +184,7 @@ def test_training(self): ) self._run_inference_and_check_results( export_path, model_type, prediction_dir=prediction_dir, - inference_function=iterative_inference, expected_sa=0.95, + inference_function=iterative_inference, expected_sa=0.8, )