diff --git a/micro_sam/automatic_segmentation.py b/micro_sam/automatic_segmentation.py
index 2561d8e2..3470ea4c 100644
--- a/micro_sam/automatic_segmentation.py
+++ b/micro_sam/automatic_segmentation.py
@@ -97,21 +97,10 @@ def automatic_instance_segmentation(
     else:
         image_data = util.load_image_data(input_path, key)
 
-    if ndim == 3 or image_data.ndim == 3:
-        if image_data.ndim != 3:
-            raise ValueError(f"The inputs do not correspond to three dimensional inputs: '{image_data.ndim}'")
+    if ndim == 2:
+        assert image_data.ndim == 2 or image_data.shape[-1] == 3, \
+            f"The input does not match the shape expectation of 2d inputs: {image_data.shape}"
 
-        instances = automatic_3d_segmentation(
-            volume=image_data,
-            predictor=predictor,
-            segmentor=segmenter,
-            embedding_path=embedding_path,
-            tile_shape=tile_shape,
-            halo=halo,
-            verbose=verbose,
-            **generate_kwargs
-        )
-    else:
         # Precompute the image embeddings.
         image_embeddings = util.precompute_image_embeddings(
             predictor=predictor,
@@ -137,6 +126,20 @@ def automatic_instance_segmentation(
             instances = np.zeros(this_shape, dtype="uint32")
         else:
             instances = mask_data_to_segmentation(masks, with_background=True, min_object_size=0)
+    else:
+        if image_data.ndim != 3:
+            raise ValueError(f"The inputs do not correspond to three dimensional inputs: '{image_data.ndim}'")
+
+        instances = automatic_3d_segmentation(
+            volume=image_data,
+            predictor=predictor,
+            segmentor=segmenter,
+            embedding_path=embedding_path,
+            tile_shape=tile_shape,
+            halo=halo,
+            verbose=verbose,
+            **generate_kwargs
+        )
 
     if output_path is not None:
         # Save the instance segmentation
diff --git a/micro_sam/training/util.py b/micro_sam/training/util.py
index 7ecf41cd..f29cbb67 100644
--- a/micro_sam/training/util.py
+++ b/micro_sam/training/util.py
@@ -246,6 +246,11 @@ def __call__(self, x, y):
 #
 
 
+def normalize_to_8bit(raw):
+    """Normalize the raw data and rescale it to the [0, 255] value range (the dtype stays float)."""
+    raw = normalize(raw) * 255
+    return raw
+
+
 class ResizeRawTrafo:
     def __init__(self, desired_shape, do_rescaling=False, padding="constant"):
         self.desired_shape = desired_shape
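
A minimal sanity sketch for the new `normalize_to_8bit` helper above, assuming that `normalize` (imported from `torch_em.transform.raw` in this module) rescales the input to the [0, 1] range; note that the result stays floating point within the 8-bit value range:

    import numpy as np
    from micro_sam.training.util import normalize_to_8bit

    raw = np.random.randint(0, 2 ** 16, size=(256, 256)).astype("uint16")
    out = normalize_to_8bit(raw)
    assert out.min() >= 0 and out.max() <= 255  # float values in the 8-bit range
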
diff --git a/workshops/download_datasets.py b/workshops/download_datasets.py
index f148cb73..e96d95ea 100644
--- a/workshops/download_datasets.py
+++ b/workshops/download_datasets.py
@@ -6,7 +6,7 @@
 from torch_em.util.image import load_data
 
 
-def _download_sample_data(path, data_dir, download, url, checksum):
+def _download_sample_data(path, data_dir, url, checksum, download):
     if os.path.exists(data_dir):
         return
 
@@ -23,7 +23,7 @@ def _get_cellpose_sample_data_paths(path, download):
     url = "https://owncloud.gwdg.de/index.php/s/slIxlmsglaz0HBE/download"
     checksum = "4d1ce7afa6417d051b93d6db37675abc60afe68daf2a4a5db0c787d04583ce8a"
 
-    _download_sample_data(path, data_dir, download, url, checksum)
+    _download_sample_data(path, data_dir, url, checksum, download)
 
     raw_paths = natsorted(glob(os.path.join(data_dir, "*_img.png")))
     label_paths = natsorted(glob(os.path.join(data_dir, "*_masks.png")))
 
     return raw_paths, label_paths
 
 
-def _get_hpa_data_paths(path, download):
+def _get_hpa_data_paths(path, split, download):
     urls = [
         "https://owncloud.gwdg.de/index.php/s/zp1Fmm4zEtLuhy4/download",  # train
         "https://owncloud.gwdg.de/index.php/s/yV7LhGbGfvFGRBE/download",  # val
     ]
     checksums = [
         "8963ff47cdef95cefabb8941f33a3916258d19d10f532a209bab849d07f9abfe",  # test
     ]
     splits = ["train", "val", "test"]
+    assert split in splits, f"'{split}' is not a valid split."
 
-    for url, checksum, split in zip(urls, checksums, splits):
-        data_dir = os.path.join(path, split)
-        _download_sample_data(path, data_dir, download, url, checksum)
+    for url, checksum, _split in zip(urls, checksums, splits):
+        data_dir = os.path.join(path, _split)
+        _download_sample_data(path, data_dir, url, checksum, download)
 
-    # NOTE: For visualization, we choose the train set.
-    raw_paths = natsorted(glob(os.path.join(data_dir, "train", "images", "*.tif")))
-    label_paths = natsorted(glob(os.path.join(data_dir, "train", "labels", "*.tif")))
+    raw_paths = natsorted(glob(os.path.join(path, split, "images", "*.tif")))
 
-    return raw_paths, label_paths
+    if split == "test":  # The 'test' split for HPA does not have labels.
+        return raw_paths, None
+    else:
+        label_paths = natsorted(glob(os.path.join(path, split, "labels", "*.tif")))
+        return raw_paths, label_paths
 
 
 def _get_dataset_paths(path, dataset_name, view=False):
     dataset_paths = {
         # 2d LM dataset for cell segmentation
         "cellpose": lambda: _get_cellpose_sample_data_paths(path=os.path.join(path, "cellpose"), download=True),
-        "hpa": lambda: _get_hpa_data_paths(path=os.path.join(path, "hpa"), download=True),
+        "hpa": lambda: _get_hpa_data_paths(path=os.path.join(path, "hpa"), download=True, split="train"),
         # 3d LM dataset for nuclei segmentation
         "embedseg": lambda: datasets.embedseg_data.get_embedseg_paths(
             path=os.path.join(path, "embedseg"), name="Mouse-Skull-Nuclei-CBG", split="train", download=True,
diff --git a/workshops/finetune_sam.py b/workshops/finetune_sam.py
index 0d139340..be08b9bb 100644
--- a/workshops/finetune_sam.py
+++ b/workshops/finetune_sam.py
@@ -8,15 +8,9 @@
 The functionalities shown here should work for your (microscopy) images too.
 """
 
-
 import os
-from glob import glob
-from pathlib import Path
-from natsort import natsorted
-from typing import Union, Tuple
+from typing import Union, Tuple, Literal, List
 
-import h5py
-import numpy as np
 import imageio.v3 as imageio
 from matplotlib import pyplot as plt
 from skimage.measure import label as connected_components
@@ -26,91 +20,80 @@
 from torch_em.util.debug import check_loader
 from torch_em.util.util import get_random_colors
-from torch_em.data.datasets.light_microscopy.hpa import get_hpa_segmentation_paths
 
 import micro_sam.training as sam_training
+from micro_sam.training.util import normalize_to_8bit
 from micro_sam.automatic_segmentation import get_predictor_and_segmenter, automatic_instance_segmentation
 
+from download_datasets import _get_hpa_data_paths
+
 
-def download_dataset(path: Union[os.PathLike, str]) -> Tuple[str, str]:
+def download_dataset(
+    path: Union[os.PathLike, str], split: Literal['train', 'val', 'test'], download: bool = True,
+) -> Tuple[List[str], List[str]]:
     """Download the HPA dataset.
 
-    This functionality downloads the images, assorts the input data (thanks to `torch-em`)
-    and stores the images and corresponding labels as `tif` files.
+    This functionality downloads the images and corresponding labels stored as `tif` files.
 
     Args:
         path: Filepath to the directory where the data will be stored.
+        split: The choice of data split. Either 'train', 'val' or 'test'.
+        download: Whether to download the dataset.
 
     Returns:
-        Filepath to the folder for the image data.
-        Filepath to the folder for the label data.
+        List of filepaths for the image data.
+        List of filepaths for the label data.
     """
-    # Download the data into a directory
-    volume_paths = get_hpa_segmentation_paths(path=os.path.join(path, "hpa"), split="test", download=True)
-
-    # Store inputs as tif files
-    image_dir = os.path.join(path, "hpa", "preprocessed", "images")
-    label_dir = os.path.join(path, "hpa", "preprocessed", "labels")
-    os.makedirs(image_dir, exist_ok=True)
-    os.makedirs(label_dir, exist_ok=True)
-
-    for volume_path in volume_paths:
-        fname = Path(volume_path).stem
-
-        with h5py.File(volume_path, "r") as f:
-            # Get the channel-wise inputs
-            image = np.stack(
-                [f["raw/microtubules"][:], f["raw/protein"][:], f["raw/nuclei"][:], f["raw/er"][:]], axis=-1
-            )
-            # labels = f["labels"][:]
-
-            image_path = os.path.join(image_dir, f"{fname}.tif")
-            # label_path = os.path.join(label_dir, f"{fname}.tif")
-
-            imageio.imwrite(image_path, image, compression="zlib")
-            # imageio.imwrite(label_path, labels, compression="zlib")
-
-    print(f"The inputs have been preprocessed and stored at: '{os.path.join(path, 'hpa', 'preprocessed')}'")
+    data_path = os.path.join(path, "hpa")
+    image_paths, label_paths = _get_hpa_data_paths(path=data_path, split=split, download=download)
+    return image_paths, label_paths
 
-    return image_dir, label_dir
 
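
For orientation, the reworked `download_dataset` is used further below in `main()` like this (the `./data` root mirrors the script's default `--input_path`); note that the 'test' split comes back without labels:

    train_image_paths, train_label_paths = download_dataset(path="./data", split="train")
    val_image_paths, val_label_paths = download_dataset(path="./data", split="val")
    test_image_paths, _ = download_dataset(path="./data", split="test")  # no labels for 'test'
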
""" - # Download the data into a directory - volume_paths = get_hpa_segmentation_paths(path=os.path.join(path, "hpa"), split="test", download=True) - - # Store inputs as tif files - image_dir = os.path.join(path, "hpa", "preprocessed", "images") - label_dir = os.path.join(path, "hpa", "preprocessed", "labels") - os.makedirs(image_dir, exist_ok=True) - os.makedirs(label_dir, exist_ok=True) - - for volume_path in volume_paths: - fname = Path(volume_path).stem - - with h5py.File(volume_path, "r") as f: - # Get the channel-wise inputs - image = np.stack( - [f["raw/microtubules"][:], f["raw/protein"][:], f["raw/nuclei"][:], f["raw/er"][:]], axis=-1 - ) - # labels = f["labels"][:] - - image_path = os.path.join(image_dir, f"{fname}.tif") - # label_path = os.path.join(label_dir, f"{fname}.tif") - - imageio.imwrite(image_path, image, compression="zlib") - # imageio.imwrite(label_path, labels, compression="zlib") - - print(f"The inputs have been preprocessed and stored at: '{os.path.join(path, 'hpa', 'preprocessed')}'") + data_path = os.path.join(path, "hpa") + image_paths, label_paths = _get_hpa_data_paths(path=data_path, split=split, download=download) + return image_paths, label_paths - return image_dir, label_dir +def verify_inputs(image_paths: List[str], label_paths: List[str]): + """Verify the downloaded inputs and preprocess them. -def verify_inputs(image_dir: Union[os.PathLike, str], label_dir: Union[os.PathLike, str]): - """Verify the downloaded inputs. + Args: + image_paths: List of filepaths for the image data. + label_paths: List of filepaths for the label data. """ - image_paths = natsorted(glob(os.path.join(image_dir, "*.tif"))) - label_paths = natsorted(glob(os.path.join(label_dir, "*.tif"))) - for image_path, label_path in zip(image_paths, label_paths): image = imageio.imread(image_path) labels = imageio.imread(label_path) # The images should be of shape: H, W, 4 -> where, 4 is the number of channels. - print(f"Shape of inputs: '{image.shape}'") + if (image.ndim == 3 and image.shape[-1] == 3) or image.ndim == 2: + print(f"Inputs '{image.shape}' match the channel expectations.") + else: + print(f"Inputs '{image.shape}' must match the channel expectations (of either one or three channels).") + # The labels should be of shape: H, W print(f"Shape of corresponding labels: '{labels.shape}'") break # comment this line out in case you would like to verify the shapes for all inputs. -def preprocess_inputs(image_dir: Union[os.PathLike, str]): +def preprocess_inputs(image_paths: List[str]): """Preprocess the input images. - """ - image_paths = natsorted(glob(os.path.join(image_dir, "*.tif"))) + Args: + image_paths: List of filepaths for the image data. + """ # We remove the 'er' channel, i.e. the last channel. - for image_path in zip(image_paths): + for image_path in image_paths: image = imageio.imread(image_path) - image = image[..., :-1] - imageio.imwrite(image_path, image) + if image.ndim == 3 and image.shape[-1] == 4: # Convert 4 channel inputs to 3 channels. + image = image[..., :-1] + imageio.imwrite(image_path, image) -def visualize_inputs(image_dir: Union[os.PathLike, str], label_dir: Union[os.PathLike, str]): - """ - """ - image_paths = natsorted(glob(os.path.join(image_dir, "*.tif"))) - label_paths = natsorted(glob(os.path.join(label_dir, "*.tif"))) +def visualize_inputs(image_paths: List[str], label_paths: List[str]): + """Visualize the images and corresponding labels. + + Args: + image_paths: List of filepaths for the image data. + label_paths: List of filepaths for the label data. 
+ """ for image_path, label_path in zip(image_paths, label_paths): image = imageio.imread(image_path) labels = imageio.imread(label_path) @@ -132,25 +115,30 @@ def visualize_inputs(image_dir: Union[os.PathLike, str], label_dir: Union[os.Pat def get_dataloaders( - image_dir: Union[os.PathLike, str], - label_dir: Union[os.PathLike, str], + train_image_paths: List[str], + train_label_paths: List[str], + val_image_paths: List[str], + val_label_paths: List[str], view: bool, train_instance_segmentation: bool, ) -> Tuple[DataLoader, DataLoader]: - """ - """ - # Get filepaths to the image and corresponding label data. - image_paths = natsorted(glob(os.path.join(image_dir, "*.tif"))) - label_paths = natsorted(glob(os.path.join(label_dir, "*.tif"))) + """Get the HPA dataloaders for cell segmentation. + Args: + train_image_paths: List of filepaths for the training image data. + train_label_paths: List of filepaths for the training label data. + val_image_paths: List of filepaths for the validation image data. + val_label_paths: List of filepaths for the validation label data. + view: Whether to view the samples out of training dataloader. + train_instance_segmentation: Whether to finetune SAM with additional instance segmentation decoder. + + Returns: + The PyTorch DataLoader for training. + The PyTorch DataLoader for validation. + """ # Load images from tif stacks by setting `raw_key` and `label_key` to None. raw_key, label_key = None, None - # Split the image and corresponding labels to establish train-test split. - # Here, we select the first 2000 images for the train split and the other frames for the val split. - train_image_paths, val_image_paths = image_paths[:2000], image_paths[2000:] - train_label_paths, val_label_paths = label_paths[:2000], label_paths[2000:] - batch_size = 1 # the training batch size patch_shape = (512, 512) # the size of patches for training @@ -165,6 +153,8 @@ def get_dataloaders( with_segmentation_decoder=train_instance_segmentation, batch_size=batch_size, shuffle=True, + raw_transform=normalize_to_8bit, + n_samples=100, ) val_loader = sam_training.default_sam_loader( raw_paths=val_image_paths, @@ -177,6 +167,7 @@ def get_dataloaders( with_segmentation_decoder=train_instance_segmentation, batch_size=batch_size, shuffle=True, + raw_transform=normalize_to_8bit, ) if view: @@ -189,26 +180,36 @@ def run_finetuning( train_loader: DataLoader, val_loader: DataLoader, save_root: Union[os.PathLike, str], - train_instance_segmentation: bool + train_instance_segmentation: bool, + device: Union[torch.device, str], + model_type: str, + overwrite: bool, ) -> str: """ """ # All hyperparameters for training. n_objects_per_batch = 5 # the number of objects per batch that will be sampled - device = "cuda" if torch.cuda.is_available() else "cpu" # the device/GPU used for training - n_epochs = 10 # how long we train (in epochs) - - # The model_type determines which base model is used to initialize the weights that are finetuned. - # We use vit_b here because it can be trained faster. Note that vit_h usually yields higher quality results. - model_type = "vit_b" + n_epochs = 5 # how long we train (in epochs) # The name of the checkpoint. The checkpoints will be stored in './checkpoints/' checkpoint_name = "sam_hpa" + # Let's spot our best checkpoint and run inference for automatic instance segmentation. 
-def run_automatic_instance_segmentation(
+def run_instance_segmentation_with_decoder(
+    test_image_paths: List[str],
     model_type: str,
     checkpoint: Union[os.PathLike, str],
     device: Union[torch.device, str],
-    test_image_dir: Union[os.PathLike, str],
 ):
-    """
-    """
+    """Run automatic instance segmentation (AIS) with the finetuned model.
+
+    Args:
+        test_image_paths: List of filepaths for the test image data.
+        model_type: The Segment Anything model that was finetuned.
+        checkpoint: Filepath to the finetuned model checkpoint.
+        device: The device used for inference.
+    """
@@ -237,16 +235,12 @@ def run_automatic_instance_segmentation(
     # Get the 'predictor' and 'segmenter' to perform automatic instance segmentation.
     predictor, segmenter = get_predictor_and_segmenter(model_type=model_type, checkpoint=checkpoint, device=device)
 
-    # Let's check the last 10 images. Feel free to comment out the line below to run inference on all images.
-    image_paths = image_paths[-10:]
-
-    for image_path in image_paths:
+    for image_path in test_image_paths:
         image = imageio.imread(image_path)
+        image = normalize_to_8bit(image)
 
-        # Predicted instances
-        prediction = run_automatic_instance_segmentation(
-            image=image, checkpoint_path=checkpoint, model_type=model_type, device=device
-        )
+        # Predicting the instances.
+        prediction = automatic_instance_segmentation(predictor=predictor, segmenter=segmenter, input_path=image, ndim=2)
 
         # Visualize the predictions
         fig, ax = plt.subplots(1, 2, figsize=(10, 10))
@@ -265,44 +259,77 @@ def main():
     import argparse
 
-    parser = argparse.ArgumentParser()
-    parser.add_argument("-i", "--input_path", type=str, default="./data")
-    parser.add_argument("--view", action="store_true")
+    parser = argparse.ArgumentParser(description="Run finetuning of the Segment Anything Model for microscopy images.")
+    parser.add_argument(
+        "-i", "--input_path", type=str, default="./data",
+        help="The filepath to the folder where the image data will be downloaded. "
+        "By default, the data will be stored in your current working directory at './data'."
+    )
+    parser.add_argument(
+        "-s", "--save_root", type=str, default=None,
+        help="The filepath to store the model checkpoint and tensorboard logs. "
+        "By default, they will be stored in your current working directory at 'checkpoints' and 'logs'."
+    )
+    parser.add_argument(
+        "--view", action="store_true",
+        help="Whether to visualize the raw inputs, samples from the dataloader, instance segmentation outputs, etc."
+    )
+    parser.add_argument(
+        "--overwrite", action="store_true", help="Whether to overwrite the already finetuned model checkpoints."
+    )
     args = parser.parse_args()
 
+    device = "cuda" if torch.cuda.is_available() else "cpu"  # the device / GPU used for training and inference.
+
+    # The model_type determines which base model is used to initialize the weights that are finetuned.
+    # We use vit_b here because it can be trained faster. Note that vit_h usually yields higher quality results.
+    model_type = "vit_b"
+
+    # Train an additional convolutional decoder for end-to-end automatic instance segmentation.
+    train_instance_segmentation = True
+
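
The device probe above only checks for CUDA; an optional, slightly broader variant (assuming a recent PyTorch build with MPS support) would be:

    import torch

    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():  # Apple-silicon GPUs
        device = "mps"
    else:
        device = "cpu"
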
     # Step 1: Download the dataset.
-    image_dir, label_dir = download_dataset(path=args.input_path)
+    train_image_paths, train_label_paths = download_dataset(path=args.input_path, split="train")
+    val_image_paths, val_label_paths = download_dataset(path=args.input_path, split="val")
+    test_image_paths, _ = download_dataset(path=args.input_path, split="test")
 
-    # Step 2: Verify the spatial shape of inputs.
-    verify_inputs(image_dir=image_dir, label_dir=label_dir)
+    # Step 2: Verify the spatial shape of inputs (only for the 'train' split).
+    verify_inputs(image_paths=train_image_paths, label_paths=train_label_paths)
 
     # Step 3: Preprocess input images.
-    preprocess_inputs(image_dir=image_dir)
-
-    # Step 4: Visualize the images and corresponding labels.
-    visualize_inputs(image_dir=image_dir, label_dir=label_dir)
+    preprocess_inputs(image_paths=train_image_paths)
+    preprocess_inputs(image_paths=val_image_paths)
+    preprocess_inputs(image_paths=test_image_paths)
 
-    # Step 5: Get the dataloaders.
-    # Train an additional convolutional decoder for end-to-end automatic instance segmentation
-    train_instance_segmentation = True
+    if args.view:
+        # Step 3(a): Visualize the images and corresponding labels (only for the 'train' split).
+        visualize_inputs(image_paths=train_image_paths, label_paths=train_label_paths)
 
+    # Step 4: Get the dataloaders.
     train_loader, val_loader = get_dataloaders(
-        image_dir=image_dir,
-        label_dir=label_dir,
+        train_image_paths=train_image_paths,
+        train_label_paths=train_label_paths,
+        val_image_paths=val_image_paths,
+        val_label_paths=val_label_paths,
         view=args.view,
         train_instance_segmentation=train_instance_segmentation,
    )
 
-    # Step 6: Run the finetuning for Segment Anything Model.
+    # Step 5: Run the finetuning of the Segment Anything Model.
     checkpoint_path = run_finetuning(
         train_loader=train_loader,
         val_loader=val_loader,
         save_root=args.save_root,
         train_instance_segmentation=train_instance_segmentation,
+        device=device,
+        model_type=model_type,
+        overwrite=args.overwrite,
     )
 
-    # Step 7: Run automatic instance segmentation using the finetuned model.
-    run_automatic_instance_segmentation()
+    # Step 6: Run automatic instance segmentation using the finetuned model.
+    run_instance_segmentation_with_decoder(
+        test_image_paths=test_image_paths, model_type=model_type, checkpoint=checkpoint_path, device=device,
+    )
 
 
 if __name__ == "__main__":
     main()
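
Finally, a hedged sketch of reusing the finetuned checkpoint outside of this script, mirroring `run_instance_segmentation_with_decoder` above (the image path and the CPU device are illustrative):

    import imageio.v3 as imageio

    from micro_sam.automatic_segmentation import get_predictor_and_segmenter, automatic_instance_segmentation
    from micro_sam.training.util import normalize_to_8bit

    predictor, segmenter = get_predictor_and_segmenter(
        model_type="vit_b", checkpoint="./checkpoints/sam_hpa/best.pt", device="cpu",
    )
    image = normalize_to_8bit(imageio.imread("hpa_image.tif"))  # illustrative input
    segmentation = automatic_instance_segmentation(
        predictor=predictor, segmenter=segmenter, input_path=image, ndim=2,
    )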