diff --git a/micro_sam/automatic_segmentation.py b/micro_sam/automatic_segmentation.py index e6d908f9e..6cf9a4868 100644 --- a/micro_sam/automatic_segmentation.py +++ b/micro_sam/automatic_segmentation.py @@ -87,7 +87,8 @@ def automatic_instance_segmentation( embedding_path: The path where the embeddings are cached already / will be saved. key: The key to the input file. This is needed for container files (eg. hdf5 or zarr) or to load several images as 3d volume. Provide a glob patterm, eg. "*.tif", for this case. - ndim: The dimensionality of the data. + ndim: The dimensionality of the data. By default the dimensionality of the data will be used. + If you have RGB data you have to specify this explicitly, e.g. pass ndim=2 for 2d segmentation of RGB. tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling. halo: Overlap of the tiles for tiled prediction. verbose: Verbosity flag. @@ -102,21 +103,12 @@ def automatic_instance_segmentation( else: image_data = util.load_image_data(input_path, key) - if ndim == 3 or image_data.ndim == 3: - if image_data.ndim != 3: - raise ValueError(f"The inputs do not correspond to three dimensional inputs: '{image_data.ndim}'") + ndim = image_data.ndim if ndim is None else ndim + + if ndim == 2: + if image_data.ndim != 2 or image_data.shape[-1] != 3: + raise ValueError(f"The inputs does not match the shape expectation of 2d inputs: {image_data.shape}") - instances = automatic_3d_segmentation( - volume=image_data, - predictor=predictor, - segmentor=segmenter, - embedding_path=embedding_path, - tile_shape=tile_shape, - halo=halo, - verbose=verbose, - **generate_kwargs - ) - else: # Precompute the image embeddings. image_embeddings = util.precompute_image_embeddings( predictor=predictor, @@ -142,6 +134,20 @@ def automatic_instance_segmentation( instances = np.zeros(this_shape, dtype="uint32") else: instances = mask_data_to_segmentation(masks, with_background=True, min_object_size=0) + else: + if image_data.ndim != 3 or image_data.shape[-1] != 3: + raise ValueError(f"The inputs does not match the shape expectation of 3d inputs: {image_data.shape}") + + instances = automatic_3d_segmentation( + volume=image_data, + predictor=predictor, + segmentor=segmenter, + embedding_path=embedding_path, + tile_shape=tile_shape, + halo=halo, + verbose=verbose, + **generate_kwargs + ) if output_path is not None: # Save the instance segmentation diff --git a/micro_sam/training/util.py b/micro_sam/training/util.py index 7ecf41cd0..f29cbb670 100644 --- a/micro_sam/training/util.py +++ b/micro_sam/training/util.py @@ -246,6 +246,11 @@ def __call__(self, x, y): # +def normalize_to_8bit(raw): + raw = normalize(raw) * 255 + return raw + + class ResizeRawTrafo: def __init__(self, desired_shape, do_rescaling=False, padding="constant"): self.desired_shape = desired_shape diff --git a/notebooks/sam_finetuning.ipynb b/notebooks/sam_finetuning.ipynb index 3c8f007cc..b0bc80a17 100644 --- a/notebooks/sam_finetuning.ipynb +++ b/notebooks/sam_finetuning.ipynb @@ -263,9 +263,7 @@ "import torch\n", "\n", "import torch_em\n", - "from torch_em.model import UNETR\n", "from torch_em.util.debug import check_loader\n", - "from torch_em.loss import DiceBasedDistanceLoss\n", "from torch_em.util.util import get_random_colors\n", "from torch_em.transform.label import PerObjectDistanceTransform\n", "\n", @@ -610,9 +608,9 @@ "# It supports image data in various formats. 
Here, we load image data and labels from the two\n", "# folders with tif images that were downloaded by the example data functionality, by specifying\n", "# `raw_key` and `label_key` as `*.tif`. This means all images in the respective folders that end with\n", - "# .tif will be loadded.\n", + "# .tif will be loaded.\n", "# The function supports many other file formats. For example, if you have tif stacks with multiple slices\n", - "# instead of multiple tif images in a foldder, then you can pass raw_key=label_key=None.\n", + "# instead of multiple tif images in a folder, then you can pass raw_key=label_key=None.\n", "# For more information, here is the documentation: https://github.com/constantinpape/torch-em/blob/main/torch_em/data/datasets/README.md\n", "\n", "# Load images from multiple files in folder via pattern (here: all tif files)\n", @@ -950,13 +948,15 @@ "outputs": [], "source": [ "def run_automatic_instance_segmentation(image, checkpoint_path, model_type=\"vit_b_lm\", device=None):\n", - " \"\"\"Automatic Instance Segmentation by training an additional instance decoder in SAM.\n", + " \"\"\"Automatic Instance Segmentation trained with an additional instance segmentation decoder in SAM.\n", "\n", " NOTE: It is supported only for `µsam` models.\n", " \n", " Args:\n", " image: The input image.\n", + " checkpoint: Filepath to the model checkpoints.\n", " model_type: The choice of the `µsam` model.\n", + " device: The torch device.\n", " \n", " Returns:\n", " The instance segmentation.\n", @@ -1392,7 +1392,7 @@ "assert os.path.exists(best_checkpoint), \"Please train the model first to run inference on the finetuned model.\"\n", "assert train_instance_segmentation is True, \"Oops. You didn't opt for finetuning using the decoder-based automatic instance segmentation.\"\n", "\n", - "# # Let's check the first 5 images. Feel free to comment out the line below to run inference on all images.\n", + "# Let's check the first 5 images. Feel free to comment out the line below to run inference on all images.\n", "image_paths = image_paths[:5]\n", "\n", "for image_path in image_paths:\n", @@ -1400,10 +1400,7 @@ " \n", " # Predicted instances\n", " prediction = run_automatic_instance_segmentation(\n", - " image=image,\n", - " checkpoint_path=best_checkpoint,\n", - " model_type=model_type,\n", - " device=device\n", + " image=image, checkpoint_path=best_checkpoint, model_type=model_type, device=device\n", " )\n", "\n", " # Visualize the predictions\n", diff --git a/workshops/README.md b/workshops/README.md new file mode 100644 index 000000000..ae14719d5 --- /dev/null +++ b/workshops/README.md @@ -0,0 +1,102 @@ +# Hands-On Analysis using `micro-sam` + +## Upcoming Workshops: +1. I2K 2024 (Milan, Italy) +2. Virtual I2K 2024 (Online) + +## Introduction + +In this document, we walk you through different steps involved to participate in hands-on image annotation experiments our tool. + +- Here is our [official documentation](https://computational-cell-analytics.github.io/micro-sam/) for detailed explanation of our tools, library and the finetuned models. +- Here is the playlist for our [tutorial videos](https://youtube.com/playlist?list=PLwYZXQJ3f36GQPpKCrSbHjGiH39X4XjSO&si=3q-cIRD6KuoZFmAM) hosted on YouTube, elaborating in detail on the features of our tools. + +## Steps: + +### Step 1: Download the Datasets + +- We provide the script `download_datasets.py` for automatic download of datasets to be used for interactive annotation using `micro-sam`. 
+- You can run the script as follows: +```bash +$ python download_datasets.py -i -d +``` +where, `DATA_DIRECTORY` is the filepath to the directory where the datasets will be downloaded, and `DATASET_NAME` is the name of the dataset (run `python download_datasets.py -h` in the terminal for more details). + +> NOTE: We have chosen a) subset of the CellPose `cyto` dataset, b) one volume from the EmbedSeg `Mouse-Skull-Nuclei-CBG` dataset from the train split (namely, `X1.tif`), c) one volume from the Platynereis `Membrane` dataset from the train split (namely, `train_data_membrane_02.n5`) and d) the entire `HPA` dataset for the following tasks in `micro-sam`. + +### Step 2: Download the Precomputed Embeddings + +- We provide the script `download_embeddings.py` for automatic download of precompute image embeddings for volumetric data to be used for interactive annotation using `micro-sam`. +- You can run the script as follows: + +```bash +$ python download_embeddings -e -d +``` +where, `EMBEDDING_DIRECTORY` is the filepath to the directory where the precomputed image embeddings will be downloaded, and `DATASET_NAME` is the name of the dataset (run `python download_embeddings.py -h` in the terminal for more details). + +### Additional Section: Precompute the Embeddings Yourself! + +Here is an example guide to precompute the image embeddings (eg. for volumetric data). + +#### EmbedSeg + +```bash +$ micro_sam.precompute_embeddings -i data/embedseg/Mouse-Skull-Nuclei-CBG/train/images/X1.tif # Filepath where inputs are stored. + -m vit_b # You can provide name for any model of your choice (supported by 'micro-sam') (eg. 'vit_b_lm'). + -e embeddings/embedseg/vit_b/embedseg_Mouse-Skull-Nuclei-CBG_train_X1 # Filepath where computed embeddings will be cached. +``` + +#### Platynereis + +```bash +$ micro_sam.precompute_embeddings -i data/platynereis/membrane/train_data_membrane_02.n5 # Filepath where inputs are stored. + -k volumes/raw/s1 # Key to access the data group in container-style data structures. + -m vit_b # You can provide name for any model of your choice (supported by 'micro-sam') (eg. 'vit_b_em_organelles'). + -e embeddings/platynereis/vit_b/platynereis_train_data_membrane_02 # Filepath where computed embeddings will be cached. +``` + +### Step 3: Run the `micro-sam` Annotators (WIP) + +Run the `micro-sam` annotators with the following scripts: + +We recommend using the napari GUI for the interactive annotation. You can use the widget to specify all the essential parameters (eg. the choice of model, the filepath to the precomputed embeddings, etc). + +TODO: add more details here. + +There is another option to use `micro-sam`'s CLI to start our annotator tools. + +#### 2D Annotator (Cell Segmentation in Light Microscopy): + +```bash +$ micro_sam.annotator_2d -i data/cellpose/cyto/test/... # Filepath where the 2d image is stored. + -m vit_b # You can provide name for any model of your choice (supported by 'micro-sam') (eg. 'vit_b_lm') + [OPTIONAL] -e embeddings/cellpose/vit_b/... # Filepath where the computed embeddings will be cached (you can choose to not pass it to compute the embeddings on-the-fly). +``` + +#### 3D Annotator (EmbedSeg - Nuclei Segmentation in Light Microscopy): + +```bash +$ micro_sam.annotator_3d -i data/embedseg/Mouse-Skull-Nuclei-CBG/train/images/X1.tif # Filepath where the 3d volume is stored. + -m vit_b # You can provide name for any model of your choice (supported by 'micro-sam') (eg. 
'vit_b_lm') + -e embeddings/embedseg/vit_b/embedseg_Mouse-Skull-Nuclei-CBG_train_X1.zarr # Filepath where the computed embeddings will be cached (we RECOMMEND to provide paths to the downloaded embeddings OR you can choose to not pass it to compute the embeddings on-the-fly). +``` + +#### 3D Annotator (Platynereis - Membrane Segmentation in Electron Microscopy): + +```bash +$ micro_sam.annotator_3d -i data/platynereis/membrane/train_data_membrane_02.n5 # Filepath where the 2d image is stored. + -k volumes/raw/s1 # Key to access the data group in container-style data structures. + -m vit_b # You can provide name for any model of your choice (supported by 'micro-sam') (eg. 'vit_b_em_organelles') + -e embeddings/platynereis/vit_b/... # Filepath where the computed embeddings will be cached (we RECOMMEND to provide paths to the downloaded embeddings OR you can choose to not pass it to compute the embeddings on-the-fly). +``` + +#### Image Series Annotator (Multiple Light Microscopy 2D Images for Cell Segmentation): + +```bash +$ micro_sam.image_series_annotator -i ... + -m vit_b # You can provide name for any model of your choice (supported by 'micro-sam') (eg. 'vit_b_lm') +``` + +### Step 4: Finetune Segment Anything on Microscopy Images (WIP) + +- We provide a notebook `finetune_sam.ipynb` / `finetune_sam.py` for finetuning Segment Anything Model for cell segmentation in confocal microscopy images. diff --git a/workshops/download_datasets.py b/workshops/download_datasets.py new file mode 100644 index 000000000..e96d95eaf --- /dev/null +++ b/workshops/download_datasets.py @@ -0,0 +1,140 @@ +import os +from glob import glob +from natsort import natsorted + +from torch_em.data import datasets +from torch_em.util.image import load_data + + +def _download_sample_data(path, data_dir, url, checksum, download): + if os.path.exists(data_dir): + return + + os.makedirs(path, exist_ok=True) + + zip_path = os.path.join(path, "data.zip") + datasets.util.download_source(path=zip_path, url=url, download=download, checksum=checksum) + datasets.util.unzip(zip_path=zip_path, dst=path) + + +def _get_cellpose_sample_data_paths(path, download): + data_dir = os.path.join(path, "cellpose", "cyto", "test") + + url = "https://owncloud.gwdg.de/index.php/s/slIxlmsglaz0HBE/download" + checksum = "4d1ce7afa6417d051b93d6db37675abc60afe68daf2a4a5db0c787d04583ce8a" + + _download_sample_data(path, data_dir, url, checksum, download) + + raw_paths = natsorted(glob(os.path.join(data_dir, "*_img.png"))) + label_paths = natsorted(glob(os.path.join(data_dir, "*_masks.png"))) + + return raw_paths, label_paths + + +def _get_hpa_data_paths(path, split, download): + urls = [ + "https://owncloud.gwdg.de/index.php/s/zp1Fmm4zEtLuhy4/download", # train + "https://owncloud.gwdg.de/index.php/s/yV7LhGbGfvFGRBE/download", # val + "https://owncloud.gwdg.de/index.php/s/8tLY5jPmpw37beM/download", # test + ] + checksums = [ + "6e5f3ec6b0d505511bea752adaf35529f6b9bb9e7729ad3bdd90ffe5b2d302ab", # train + "4d7a4188cc3d3877b3cf1fbad5f714ced9af4e389801e2136623eac2fde78e9c", # val + "8963ff47cdef95cefabb8941f33a3916258d19d10f532a209bab849d07f9abfe", # test + ] + splits = ["train", "val", "test"] + assert split in splits, f"'{split}' is not a valid split." 
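+    # NOTE: the loop below fetches all three splits (not only the requested one), so that later
+    # calls for a different split can reuse the data; '_download_sample_data' skips existing folders.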
+ + for url, checksum, _split in zip(urls, checksums, splits): + data_dir = os.path.join(path, _split) + _download_sample_data(path, data_dir, url, checksum, download) + + raw_paths = natsorted(glob(os.path.join(path, split, "images", "*.tif"))) + + if split == "test": # The 'test' split for HPA does not have labels. + return raw_paths, None + else: + label_paths = natsorted(glob(os.path.join(path, split, "labels", "*.tif"))) + return raw_paths, label_paths + + +def _get_dataset_paths(path, dataset_name, view=False): + dataset_paths = { + # 2d LM dataset for cell segmentation + "cellpose": lambda: _get_cellpose_sample_data_paths(path=os.path.join(path, "cellpose"), download=True), + "hpa": lambda: _get_hpa_data_paths(path=os.path.join(path, "hpa"), download=True, split="train"), + # 3d LM dataset for nuclei segmentation + "embedseg": lambda: datasets.embedseg_data.get_embedseg_paths( + path=os.path.join(path, "embedseg"), name="Mouse-Skull-Nuclei-CBG", split="train", download=True, + ), + # 3d EM dataset for membrane segmentation + "platynereis": lambda: datasets.platynereis.get_platynereis_paths( + path=os.path.join(path, "platynereis"), sample_ids=None, name="cells", download=True, + ), + } + + dataset_keys = { + "cellpose": [None, None], + "embedseg": [None, None], + "platynereis": ["volumes/raw/s1", "volumes/labels/segmentation/s1"] + } + + if dataset_name is None: # Download all datasets. + dataset_names = list(dataset_paths.keys()) + else: # Download specific datasets. + dataset_names = [dataset_name] + + for dname in dataset_names: + if dname not in dataset_paths: + raise ValueError( + f"'{dname}' is not a supported dataset enabled for download. " + f"Please choose from {list(dataset_paths.keys())}." + ) + + paths = dataset_paths[dname]() + print(f"'{dataset_name}' is download at {path}.") + + if view: + import napari + + if isinstance(paths, tuple): # datasets with explicit raw and label paths + raw_paths, label_paths = paths + else: + raw_paths = label_paths = paths + + raw_key, label_key = dataset_keys[dname] + for raw_path, label_path in zip(raw_paths, label_paths): + raw = load_data(raw_path, raw_key) + labels = load_data(label_path, label_key) + + v = napari.Viewer() + v.add_image(raw) + v.add_labels(labels) + napari.run() + + break # comment this line out in case you would like to visualize all samples. + + +def main(): + import argparse + parser = argparse.ArgumentParser(description="Download the dataset necessary for the workshop.") + parser.add_argument( + "-i", "--input_path", type=str, default="./data", + help="The filepath to the folder where the image data will be downloaded. " + "By default, the data will be stored in your current working directory at './data'." + ) + parser.add_argument( + "-d", "--dataset_name", type=str, default=None, + help="The choice of dataset you would like to download. By default, it downloads all the datasets. " + "Optionally, you can choose to download either of 'cellpose', 'hpa', 'embedseg' or 'platynereis'." + ) + parser.add_argument( + "-v", "--view", action="store_true", help="Whether to view the downloaded data." 
+ ) + args = parser.parse_args() + + _get_dataset_paths(path=args.input_path, dataset_name=args.dataset_name, view=args.view) + + +if __name__ == "__main__": + main() diff --git a/workshops/download_embeddings.py b/workshops/download_embeddings.py new file mode 100644 index 000000000..6df079f53 --- /dev/null +++ b/workshops/download_embeddings.py @@ -0,0 +1,89 @@ +import os + +from torch_em.data.datasets.util import download_source, unzip + + +URLS = { + "lucchi": [ + "https://owncloud.gwdg.de/index.php/s/kQMA1B8L9LOvYrl/download", # vit_b + "https://owncloud.gwdg.de/index.php/s/U8xs6moRg0cQhkS/download", # vit_b_em_organelles + ], + "embedseg": [ + "https://owncloud.gwdg.de/index.php/s/EF9ZdMzYjDjl8fd/download", # vit_b + "https://owncloud.gwdg.de/index.php/s/7IVekm8K7ln7yQ6/download", # vit_b_lm + ], + "platynereis": [ + "https://owncloud.gwdg.de/index.php/s/1OgOEeMIK9Ok2Kj/download", # vit_b + "https://owncloud.gwdg.de/index.php/s/i9DrXe6YFL8jvgP/download", # vit_b_em_organelles + ], +} + +CHECKSUMS = { + "lucchi": [ + "e0d064765f1758a1a0823b2c02d399caa5cae0d8ac5a1e2ed96548a647717433", # vit_b + "e0b5ab781c42e6f68b746fc056c918d56559ccaeedb4e4f2848b1e5e8f1bec58", # vit_b_em_organelles + ], + "embedseg": [ + "82f5351486e484dda5a3a327381458515c89da5dda8a48a0b1ab96ef10d23f02", # vit_b + "80fd701c01b81bbfb32beed6e2ece8c5706625dbc451776d8ba1c22253f097b9", # vit_b_lm + ], + "platynereis": [ + "95c5e31c5e55e94780568f3fb8a3fdf33f8586a4c6a375d28dccba6567f37a47", # vit_b + "3d8d91313656fde271a48ea0a3552762f2536955a357ffb43e7c43b5b27e0627", # vit_b_em_organelles + ], +} + + +def _download_embeddings(embedding_dir, dataset_name): + if dataset_name is None: # Download embeddings for all datasets. + dataset_names = list(URLS.keys()) + else: # Download embeddings for specific dataset. + dataset_names = [dataset_name] + + for dname in dataset_names: + if dname not in URLS: + raise ValueError( + f"'{dname}' does not have precomputed embeddings to download. Please choose from {list(URLS.keys())}" + ) + + urls = URLS[dname] + checksums = CHECKSUMS[dname] + + data_embedding_dir = os.path.join(embedding_dir, dname) + os.makedirs(data_embedding_dir, exist_ok=True) + + # Download the precomputed embeddings as zipfiles and unzip the embeddings per model. + for url, checksum in zip(urls, checksums): + if all([p.startswith("vit_b") for p in os.listdir(data_embedding_dir)]): + continue + + zip_path = os.path.join(data_embedding_dir, "embeddings.zip") + download_source(path=zip_path, url=url, download=True, checksum=checksum) + unzip(zip_path=zip_path, dst=data_embedding_dir) + + print(f"The precompted embeddings for '{dname}' are downloaded at f{data_embedding_dir}") + + +def main(): + import argparse + parser = argparse.ArgumentParser( + description="Download the precomputed image embeddings necessary for interactive annotation." + ) + parser.add_argument( + "-e", "--embedding_dir", type=str, default="./embeddings", + help="The filepath to the folder where the precomputed image embeddings will be downloaded. " + "By default, the embeddings will be stored in your current working directory at './embeddings'." + ) + parser.add_argument( + "-d", "--dataset_name", type=str, default=None, + help="The choice of volumetric dataset for which you would like to download the embeddings. " + "By default, it downloads all the precomputed embeddings. Optionally, you can choose to download either of the " + "volumetric datasets: 'lucchi', 'embedseg' or 'platynereis'." 
+ ) + args = parser.parse_args() + + _download_embeddings(embedding_dir=args.embedding_dir, dataset_name=args.dataset_name) + + +if __name__ == "__main__": + main() diff --git a/workshops/finetune_sam.ipynb b/workshops/finetune_sam.ipynb new file mode 100644 index 000000000..62938d4ea --- /dev/null +++ b/workshops/finetune_sam.ipynb @@ -0,0 +1,563 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Finetuning Segment Anything with `µsam`\n", + "\n", + "This notebook shows how to use Segment Anything for Microscopy to fine-tune a Segment Anything Model (SAM) on an open-source data with multiple channels.\n", + "\n", + "We use confocal microscopy images from the HPA Kaggle Challenge for protein identification (from [Ouyang et al.](https://doi.org/10.1038/s41592-019-0658-6)) in this notebook for the cell segmentation task. The functionalities shown here should work for your (microscopy) images too." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Running this notebook\n", + "\n", + "If you have an environment with `µsam` on your computer you can run this notebook in there. You can follow the [installation instructions](https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#installation) to install it on your computer.\n", + "\n", + "You can also run this notebook in the cloud on [Kaggle Notebooks](https://www.kaggle.com/code/). This service offers free usage of a GPU to speed up running the code. The next cells will take care of the installation for you if you are using it." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Check if we are running this notebook on kaggle, google colab or local compute resources.\n", + "\n", + "import os\n", + "current_spot = os.getcwd()\n", + "\n", + "if current_spot.startswith(\"/kaggle/working\"):\n", + " print(\"Kaggle says hi!\")\n", + " root_dir = \"/kaggle/working\"\n", + "\n", + "elif current_spot.startswith(\"/content\"):\n", + " print(\"Google Colab says hi!\")\n", + " print(\" NOTE: The scripts have not been tested on Google Colab, you might need to adapt the installations a bit.\")\n", + " root_dir = \"/content\"\n", + "\n", + " # You might need to install condacolab on Google Colab to be able to install packages using conda / mamba\n", + " # !pip install -q condacolab\n", + " # import condacolab\n", + " # condacolab.install()\n", + "\n", + "else:\n", + " msg = \"You are using a behind-the-scenes resource. Follow our installation instructions here:\"\n", + " msg += \" https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#installation\"\n", + " print(msg)\n", + " root_dir = \"\" # overwrite to set the root directory, where the data, checkpoints, and all relevant stuff will be stored" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Installation\n", + "\n", + "The next cells will install the `micro_sam` library on Kaggle Notebooks. 
**Please skip these cells and go to `Importing the libraries` if you are running the notebook on your own computer.**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!git clone --quiet https://github.com/computational-cell-analytics/micro-sam.git\n", + "tmp_dir = os.path.join(root_dir, \"micro-sam\")\n", + "!pip install --quiet $tmp_dir" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!git clone --quiet https://github.com/constantinpape/torch-em.git\n", + "tmp_dir = os.path.join(root_dir, \"torch-em\")\n", + "!pip install --quiet $tmp_dir" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!git clone --quiet https://github.com/constantinpape/elf.git\n", + "tmp_dir = os.path.join(root_dir, \"elf\")\n", + "!pip install --quiet $tmp_dir" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Known Issues on **Kaggle Notebooks**:\n", + "\n", + "1. `warning libmamba Cache file \"/opt/conda/pkgs/cache/2ce54b42.json\" was modified by another program` (multiples lines of such warnings)\n", + " - We have received this warning while testing this notebook on Kaggle. It does not lead to any issues while making use of the installed packages. You can proceed and ignore the warnings." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!mamba install -q -y -c conda-forge nifty affogato zarr z5py\n", + "!pip uninstall -y --quiet qtpy # qtpy is not supported in Kaggle / Google Colab, let's remove it to avoid errors." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Importing the libraries" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "from glob import glob\n", + "from typing import List\n", + "from natsort import natsorted\n", + "from IPython.display import FileLink\n", + "\n", + "import imageio.v3 as imageio\n", + "from matplotlib import pyplot as plt\n", + "from skimage.measure import label as connected_components\n", + "\n", + "import torch\n", + "\n", + "from torch_em.data import datasets\n", + "from torch_em.util.debug import check_loader\n", + "from torch_em.util.util import get_random_colors\n", + "\n", + "import micro_sam.training as sam_training\n", + "from micro_sam.training.util import normalize_to_8bit\n", + "from micro_sam.automatic_segmentation import get_predictor_and_segmenter, automatic_instance_segmentation" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Let's download the dataset\n", + "\n", + "First, we download the images and corresponding labels stored as `tif` files." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Download the data into a directory\n", + "DATA_FOLDER = os.path.join(root_dir, \"hpa\")\n", + "\n", + "URLS = [\n", + " \"https://owncloud.gwdg.de/index.php/s/zp1Fmm4zEtLuhy4/download\", # train\n", + " \"https://owncloud.gwdg.de/index.php/s/yV7LhGbGfvFGRBE/download\", # val\n", + " \"https://owncloud.gwdg.de/index.php/s/8tLY5jPmpw37beM/download\", # test\n", + "]\n", + "\n", + "CHECKSUMS = [\n", + " \"6e5f3ec6b0d505511bea752adaf35529f6b9bb9e7729ad3bdd90ffe5b2d302ab\", # train\n", + " \"4d7a4188cc3d3877b3cf1fbad5f714ced9af4e389801e2136623eac2fde78e9c\", # val\n", + " \"8963ff47cdef95cefabb8941f33a3916258d19d10f532a209bab849d07f9abfe\", # test\n", + "]\n", + "\n", + "SPLITS = [\"train\", \"val\", \"test\"]\n", + "\n", + "for url, checksum, split in zip(URLS, CHECKSUMS, SPLITS):\n", + " data_dir = os.path.join(DATA_FOLDER, split)\n", + " if os.path.exists(data_dir):\n", + " continue\n", + " \n", + " os.makedirs(DATA_FOLDER, exist_ok=True)\n", + " zip_path = os.path.join(DATA_FOLDER, \"data.zip\")\n", + " datasets.util.download_source(path=zip_path, url=url, download=True, checksum=checksum)\n", + " datasets.util.unzip(zip_path=zip_path, dst=DATA_FOLDER)\n", + "\n", + "# Get filepaths to the image data.\n", + "train_image_paths = natsorted(glob(os.path.join(DATA_FOLDER, split, \"images\", \"*.tif\")))\n", + "val_image_paths = natsorted(glob(os.path.join(DATA_FOLDER, split, \"images\", \"*.tif\")))\n", + "test_image_paths = natsorted(glob(os.path.join(DATA_FOLDER, split, \"images\", \"*.tif\")))\n", + "\n", + "# Get filepaths to the label data.\n", + "train_label_paths = natsorted(glob(os.path.join(DATA_FOLDER, split, \"labels\", \"*.tif\")))\n", + "val_label_paths = natsorted(glob(os.path.join(DATA_FOLDER, split, \"labels\", \"*.tif\")))\n", + "\n", + "print(f\"The inputs have been preprocessed and stored at: '{DATA_FOLDER}'\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Let's understand our inputs' data structure." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for image_path, label_path in zip(train_image_paths, train_label_paths): # Checking the inputs for the train split.\n", + " image = imageio.imread(image_path)\n", + " labels = imageio.imread(label_path)\n", + "\n", + " # The images should be of shape: H, W, 4 -> where, 4 is the number of channels.\n", + " if (image.ndim == 3 and image.shape[-1] == 3) or image.ndim == 2:\n", + " print(f\"Inputs '{image.shape}' match the channel expectations.\")\n", + " else:\n", + " print(f\"Inputs '{image.shape}' must match the channel expectations (of either one or three channels).\")\n", + "\n", + " # The labels should be of shape: H, W\n", + " print(f\"Shape of corresponding labels: '{labels.shape}'\")\n", + "\n", + " break # comment this line out in case you would like to verify the shapes for all inputs." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Segment Anything accepts inputs of either 1 channel or 3 channels. To fine-tune Segment Anything on our data, we must select either 1 channel or 3 channels out of the 4 channels available.\n", + "\n", + "Let's make the choice to choose the `microtubule` (first channel), `protein` (second channel) and `nuclei` (third channel) for finetuning Segment Anything." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# We remove the 'er' channel, i.e. the last channel.\n", + "def preprocess_inputs(image_paths: List[str]):\n", + " for image_path in image_paths:\n", + " image = imageio.imread(image_path)\n", + "\n", + " if image.ndim == 3 and image.shape[-1] == 4: # Convert 4 channel inputs to 3 channels.\n", + " image = image[..., :-1]\n", + " imageio.imwrite(image_path, image)\n", + "\n", + "preprocess_inputs(train_image_paths)\n", + "preprocess_inputs(val_label_paths)\n", + "preprocess_inputs(test_image_paths)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Let's create the dataloaders\n", + "\n", + "Our task is to segment cells in confocal microscopy images. The dataset comes from https://zenodo.org/records/4665863, and the dataloader has been implemented in [torch-em](https://github.com/constantinpape/torch-em/blob/main/torch_em/data/datasets/light_microscopy/hpa.py)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### First, let's visualize how our samples look." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for image_path, label_path in zip(train_image_paths, train_label_paths): # Visualize inputs for the train split.\n", + " image = imageio.imread(image_path)\n", + " labels = imageio.imread(label_path)\n", + "\n", + " fig, ax = plt.subplots(1, 2, figsize=(10, 10))\n", + " ax[0].imshow(image, cmap=\"gray\")\n", + " ax[0].set_title(\"Input Image\")\n", + " ax[0].axis(\"off\")\n", + " \n", + " labels = connected_components(labels)\n", + " ax[1].imshow(labels, cmap=get_random_colors(labels), interpolation=\"nearest\")\n", + " ax[1].set_title(\"Ground Truth Instances\")\n", + " ax[1].axis(\"off\")\n", + " \n", + " plt.show()\n", + " plt.close()\n", + " \n", + " break # comment this out in case you want to visualize all the images" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Next, let's create the dataloaders." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# micro_sam.training.default_sam_loader is a convenience function to build a torch dataloader.\n", + "# from image data and labels for training segmentation models.\n", + "# This is wrapped around the 'torch_em.default_segmentation_loader'.\n", + "# It supports image data in various formats. 
Here, we load image data and corresponding labels by providing\n", + "# filepaths to the respective tif files that were download and preprocessed using the functionality above.\n", + "# Next, we create a list of filepaths for the image and label data by fetching all '*.tif' files in the\n", + "# respective directories.\n", + "# For more information, here is the documentation: https://github.com/constantinpape/torch-em/blob/main/torch_em/data/datasets/README.md\n", + "# Here is a detailed notebook on finetuning Segment Anything: https://github.com/computational-cell-analytics/micro-sam/blob/master/notebooks/sam_finetuning.ipynb\n", + "\n", + "# Load images from tif stacks by setting `raw_key` and `label_key` to None.\n", + "raw_key, label_key = None, None" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# The script below returns the train or val data loader for finetuning Segment Anything.\n", + "\n", + "# The data loader must be a torch dataloader that returns `x, y` tensors,\n", + "# where `x` is the image data and `y` are the corresponding labels.\n", + "# The labels have to be in a label mask instance segmentation format.\n", + "# i.e. a tensor of the same spatial shape as `x`, with each object mask having its own ID.\n", + "# IMPORTANT: the ID 0 is reserved for backgroun, and the IDS must be consecutive.\n", + "\n", + "# Here, we use `micro_sam.training.default_sam_loader` for creating the suitable data loader from\n", + "# the HPA data. You can either adapt this for your own data or write a suitable torch dataloader yourself.\n", + "# Here is a quickstart notebook to create your own dataloaders: https://github.com/constantinpape/torch-em/blob/main/notebooks/tutorial_create_dataloaders.ipynb\n", + "\n", + "batch_size = 1 # the training batch size\n", + "patch_shape = (512, 512) # the size of patches for training\n", + "\n", + "# Train an additional convolutional decoder for end-to-end automatic instance segmentation\n", + "train_instance_segmentation = True\n", + "\n", + "# The dataloader internally takes care of adding label transforms: i.e. 
used to convert the ground-truth\n", + "# labels to the desired instances for finetuning Segment Anythhing, or, to learn the foreground and distances\n", + "# to the object centers and object boundaries for automatic segmentation.\n", + "\n", + "train_loader = sam_training.default_sam_loader(\n", + " raw_paths=train_image_paths,\n", + " raw_key=raw_key,\n", + " label_paths=train_label_paths,\n", + " label_key=label_key,\n", + " is_seg_dataset=False,\n", + " patch_shape=patch_shape,\n", + " with_channels=True,\n", + " with_segmentation_decoder=train_instance_segmentation,\n", + " batch_size=batch_size,\n", + " shuffle=True,\n", + " raw_transform=normalize_to_8bit,\n", + " n_samples=100,\n", + ")\n", + "\n", + "val_loader = sam_training.default_sam_loader(\n", + " raw_paths=val_image_paths,\n", + " raw_key=raw_key,\n", + " label_paths=val_label_paths,\n", + " label_key=label_key,\n", + " is_seg_dataset=False,\n", + " patch_shape=patch_shape,\n", + " with_channels=True,\n", + " with_segmentation_decoder=train_instance_segmentation,\n", + " batch_size=batch_size,\n", + " raw_transform=normalize_to_8bit,\n", + " shuffle=True,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Let's check how our samples lookm from the dataloader.\n", + "check_loader(train_loader, 4, plt=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Run the actual model finetuning" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# All hyperparameters for training.\n", + "n_objects_per_batch = 5 # the number of objects per batch that will be sampled\n", + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\" # the device/GPU used for training\n", + "n_epochs = 5 # how long we train (in epochs)\n", + "\n", + "# The model_type determines which base model is used to initialize the weights that are finetuned.\n", + "# We use vit_b here because it can be trained faster. Note that vit_h usually yields higher quality results.\n", + "model_type = \"vit_b\"\n", + "\n", + "# The name of the checkpoint. The checkpoints will be stored in './checkpoints/'\n", + "checkpoint_name = \"sam_hpa\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**NOTE**: The user needs to decide whether to finetune the Segment Anything model, or the `µsam`'s \"finetuned microscopy models\" for their dataset. Here, we finetune on the Segment Anything model for simplicity. 
For example, if you choose to finetune the model from the light microscopy generalist models, you need to update the `model_type` to `vit_b_lm` and it takes care of initializing the model with the desired weights)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Run training\n", + "sam_training.train_sam(\n", + " name=checkpoint_name,\n", + " save_root=os.path.join(root_dir, \"models\"),\n", + " model_type=model_type,\n", + " train_loader=train_loader,\n", + " val_loader=val_loader,\n", + " n_epochs=n_epochs,\n", + " n_objects_per_batch=n_objects_per_batch,\n", + " with_segmentation_decoder=train_instance_segmentation,\n", + " device=device,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Let's spot our best checkpoint and download it to get started with the annotation tool\n", + "best_checkpoint = os.path.join(\"models\", \"checkpoints\", checkpoint_name, \"best.pt\")\n", + "\n", + "# # Download link is automatically generated for the best model.\n", + "print(\"Click here \\u2193\")\n", + "FileLink(best_checkpoint)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Let's run the automatic instance segmentation (AIS)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def run_automatic_instance_segmentation(image, checkpoint, model_type=\"vit_b\", device=None):\n", + " \"\"\"Automatic Instance Segmentation trained with an additional instance segmentation decoder in SAM.\n", + "\n", + " NOTE: It is supported only for `µsam` models.\n", + " \n", + " Args:\n", + " image: The input image.\n", + " checkpoint: Filepath to the model checkpoints.\n", + " model_type: The choice of the `µsam` model.\n", + " device: The torch device.\n", + "\n", + " Returns:\n", + " The instance segmentation.\n", + " \"\"\"\n", + " # Step 1: Get the 'predictor' and 'segmenter' to perform automatic instance segmentation.\n", + " predictor, segmenter = get_predictor_and_segmenter(model_type=model_type, checkpoint=checkpoint, device=device)\n", + "\n", + " # Step 2: Get the instance segmentation for the given image.\n", + " instances = automatic_instance_segmentation(predictor=predictor, segmenter=segmenter, input_path=image, ndim=2)\n", + "\n", + " return instances" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "assert os.path.exists(best_checkpoint), \"Please train the model first to run inference on the finetuned model.\"\n", + "assert train_instance_segmentation is True, \"Oops. 
You didn't opt for finetuning using the decoder-based automatic instance segmentation.\"\n", + "\n", + "for image_path in test_image_paths:\n", + " image = imageio.imread(image_path)\n", + " \n", + " # Predicted instances\n", + " prediction = run_automatic_instance_segmentation(\n", + " image=image, checkpoint_path=best_checkpoint, model_type=model_type, device=device\n", + " )\n", + "\n", + " # Visualize the predictions\n", + " fig, ax = plt.subplots(1, 2, figsize=(10, 10))\n", + "\n", + " ax[0].imshow(image, cmap=\"gray\")\n", + " ax[0].axis(\"off\")\n", + " ax[0].set_title(\"Input Image\")\n", + "\n", + " ax[1].imshow(prediction, cmap=get_random_colors(prediction), interpolation=\"nearest\")\n", + " ax[1].axis(\"off\")\n", + " ax[1].set_title(\"Predictions (AIS)\")\n", + "\n", + " plt.show()\n", + " plt.close()\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### What next?\n", + "\n", + "It's time to get started with your custom finetuned model using the annotator tool. Here is the documentation on how to get started with `µsam`: [Annotation Tools](https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#annotation-tools)\n", + "\n", + "Happy annotating!" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "*This notebook was last ran on October 18, 2024*" + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/workshops/finetune_sam.py b/workshops/finetune_sam.py new file mode 100644 index 000000000..14ad6360a --- /dev/null +++ b/workshops/finetune_sam.py @@ -0,0 +1,355 @@ +"""Finetuning Segment Anything using µsam. + +This python script shows how to use Segment Anything for Microscopy to fine-tune a Segment Anything Model (SAM) +on an open-source data with multiple channels. + +We use confocal microscopy images from the HPA Kaggle Challenge for protein identification +(from Ouyang et al. - https://doi.org/10.1038/s41592-019-0658-6) in this script for the cell segmentation task. +The functionalities shown here should work for your (microscopy) images too. +""" + +import os +from typing import Union, Tuple, Literal, Optional, List + +import imageio.v3 as imageio +from matplotlib import pyplot as plt +from skimage.measure import label as connected_components + +import torch +from torch.utils.data import DataLoader + +from torch_em.util.debug import check_loader +from torch_em.util.util import get_random_colors + +from micro_sam import util +import micro_sam.training as sam_training +from micro_sam.training.util import normalize_to_8bit +from micro_sam.automatic_segmentation import get_predictor_and_segmenter, automatic_instance_segmentation + +from download_datasets import _get_hpa_data_paths + + +def download_dataset( + path: Union[os.PathLike, str], split: Literal['train', 'val', 'test'], download: bool = True, +) -> Tuple[List[str], List[str]]: + """Download the HPA dataset. + + This functionality downloads the images and corresponding labels stored as `tif` files. + + Args: + path: Filepath to the directory where the data will be stored. + split: The choice of data split. Either 'train', 'val' or 'test'. + download: Whether to download the dataset. + + Returns: + List of filepaths for the image data. + List of filepaths for the label data. 
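+
+    Example (minimal usage sketch, with './data' as a placeholder directory):
+        image_paths, label_paths = download_dataset("./data", split="train")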
+ """ + data_path = os.path.join(path, "hpa") + image_paths, label_paths = _get_hpa_data_paths(path=data_path, split=split, download=download) + return image_paths, label_paths + + +def verify_inputs(image_paths: List[str], label_paths: List[str]): + """Verify the downloaded inputs and preprocess them. + + Args: + image_paths: List of filepaths for the image data. + label_paths: List of filepaths for the label data. + """ + for image_path, label_path in zip(image_paths, label_paths): + image = imageio.imread(image_path) + labels = imageio.imread(label_path) + + # The images should be of shape: H, W, 4 -> where, 4 is the number of channels. + if (image.ndim == 3 and image.shape[-1] == 3) or image.ndim == 2: + print(f"Inputs '{image.shape}' match the channel expectations.") + else: + print(f"Inputs '{image.shape}' must match the channel expectations (of either one or three channels).") + + # The labels should be of shape: H, W + print(f"Shape of corresponding labels: '{labels.shape}'") + + break # comment this line out in case you would like to verify the shapes for all inputs. + + +def preprocess_inputs(image_paths: List[str]): + """Preprocess the input images. + + Args: + image_paths: List of filepaths for the image data. + """ + # We remove the 'er' channel, i.e. the last channel. + for image_path in image_paths: + image = imageio.imread(image_path) + + if image.ndim == 3 and image.shape[-1] == 4: # Convert 4 channel inputs to 3 channels. + image = image[..., :-1] + imageio.imwrite(image_path, image) + + +def visualize_inputs(image_paths: List[str], label_paths: List[str]): + """Visualize the images and corresponding labels. + + Args: + image_paths: List of filepaths for the image data. + label_paths: List of filepaths for the label data. + """ + for image_path, label_path in zip(image_paths, label_paths): + image = imageio.imread(image_path) + labels = imageio.imread(label_path) + + fig, ax = plt.subplots(1, 2, figsize=(10, 10)) + ax[0].imshow(image, cmap="gray") + ax[0].set_title("Input Image") + ax[0].axis("off") + + labels = connected_components(labels) + ax[1].imshow(labels, cmap=get_random_colors(labels), interpolation="nearest") + ax[1].set_title("Ground Truth Instances") + ax[1].axis("off") + + plt.show() + plt.close() + + break # comment this out in case you want to visualize all the images + + +def get_dataloaders( + train_image_paths: List[str], + train_label_paths: List[str], + val_image_paths: List[str], + val_label_paths: List[str], + view: bool, + train_instance_segmentation: bool, +) -> Tuple[DataLoader, DataLoader]: + """Get the HPA dataloaders for cell segmentation. + + Args: + train_image_paths: List of filepaths for the training image data. + train_label_paths: List of filepaths for the training label data. + val_image_paths: List of filepaths for the validation image data. + val_label_paths: List of filepaths for the validation label data. + view: Whether to view the samples out of training dataloader. + train_instance_segmentation: Whether to finetune SAM with additional instance segmentation decoder. + + Returns: + The PyTorch DataLoader for training. + The PyTorch DataLoader for validation. + """ + # Load images from tif stacks by setting `raw_key` and `label_key` to None. 
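+    # After preprocessing the images have three channels, hence 'with_channels=True' is passed below
+    # and 'normalize_to_8bit' is used as the raw transform to scale the intensities to [0, 255].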
+ raw_key, label_key = None, None + + batch_size = 1 # the training batch size + patch_shape = (512, 512) # the size of patches for training + + train_loader = sam_training.default_sam_loader( + raw_paths=train_image_paths, + raw_key=raw_key, + label_paths=train_label_paths, + label_key=label_key, + is_seg_dataset=False, + patch_shape=patch_shape, + with_channels=True, + with_segmentation_decoder=train_instance_segmentation, + batch_size=batch_size, + shuffle=True, + raw_transform=normalize_to_8bit, + n_samples=100, + ) + val_loader = sam_training.default_sam_loader( + raw_paths=val_image_paths, + raw_key=raw_key, + label_paths=val_label_paths, + label_key=label_key, + is_seg_dataset=False, + patch_shape=patch_shape, + with_channels=True, + with_segmentation_decoder=train_instance_segmentation, + batch_size=batch_size, + shuffle=True, + raw_transform=normalize_to_8bit, + ) + + if view: + check_loader(train_loader, 4, plt=True) + + return train_loader, val_loader + + +def run_finetuning( + train_loader: DataLoader, + val_loader: DataLoader, + save_root: Optional[Union[os.PathLike, str]], + train_instance_segmentation: bool, + device: Union[torch.device, str], + model_type: str, + overwrite: bool, +) -> str: + """Run finetuning for the Segment Anything model on microscopy images. + + Args: + train_loader: The PyTorch dataloader used for training. + val_loader: The PyTorch dataloader used for validation. + save_root: The filepath to the folder where the model checkpoints and tensorboard logs are stored. + train_instance_segmentation: Whether to finetune SAM with additional instance segmentation decoder. + device: The torch device. + model_type: The choice of Segment Anything model (connotated by the size of image encoder). + overwrite: Whether to overwrite the already finetuned model checkpoints. + + Returns: + Filepath where the (best) model checkpoint is stored. + """ + # All hyperparameters for training. + n_objects_per_batch = 5 # the number of objects per batch that will be sampled + n_epochs = 5 # how long we train (in epochs) + + # The name of the checkpoint. The checkpoints will be stored in './checkpoints/' + checkpoint_name = "sam_hpa" + + # Let's spot our best checkpoint and run inference for automatic instance segmentation. + if save_root is None: + save_root = os.getcwd() + + best_checkpoint = os.path.join(save_root, "checkpoints", checkpoint_name, "best.pt") + if os.path.exists(best_checkpoint) and not overwrite: + print( + "It looks like the training has completed. You must pass the argument '--overwrite' to overwrite " + "the already finetuned model (or provide a new filepath at '--save_root' for training new models)." + ) + return best_checkpoint + + # Run training + sam_training.train_sam( + name=checkpoint_name, + save_root=save_root, + model_type=model_type, + train_loader=train_loader, + val_loader=val_loader, + n_epochs=n_epochs, + n_objects_per_batch=n_objects_per_batch, + with_segmentation_decoder=train_instance_segmentation, + device=device, + ) + + return best_checkpoint + + +def run_instance_segmentation_with_decoder( + test_image_paths: List[str], model_type: str, checkpoint: Union[os.PathLike, str], device: Union[torch.device, str], +): + """Run automatic instance segmentation (AIS). + + Args: + test_image_paths: List of filepaths for the test image data. + model_type: The choice of Segment Anything model (connotated by the size of image encoder). + checkpoint: Filepath to the finetuned model checkpoints. + device: The torch device used for inference. 
+ """ + assert os.path.exists(checkpoint), "Please train the model first to run inference on the finetuned model." + + # Get the 'predictor' and 'segmenter' to perform automatic instance segmentation. + predictor, segmenter = get_predictor_and_segmenter(model_type=model_type, checkpoint=checkpoint, device=device) + + for image_path in test_image_paths: + image = imageio.imread(image_path) + image = normalize_to_8bit(image) + + # Predicting the instances. + prediction = automatic_instance_segmentation(predictor=predictor, segmenter=segmenter, input_path=image, ndim=2) + + # Visualize the predictions + fig, ax = plt.subplots(1, 2, figsize=(10, 10)) + + ax[0].imshow(image, cmap="gray") + ax[0].axis("off") + ax[0].set_title("Input Image") + + ax[1].imshow(prediction, cmap=get_random_colors(prediction), interpolation="nearest") + ax[1].axis("off") + ax[1].set_title("Predictions (AIS)") + + plt.show() + plt.close() + + +def main(): + import argparse + parser = argparse.ArgumentParser(description="Run finetuning for Segment Anything model for microscopy images.") + parser.add_argument( + "-i", "--input_path", type=str, default="./data", + help="The filepath to the folder where the image data will be downloaded. " + "By default, the data will be stored in your current working directory at './data'." + ) + parser.add_argument( + "-s", "--save_root", type=str, default=None, + help="The filepath to store the model checkpoint and tensorboard logs. " + "By default, they will be stored in your current working directory at 'checkpoints' and 'logs'." + ) + parser.add_argument( + "--view", action="store_true", + help="Whether to visualize the raw inputs, samples from the dataloader, instance segmentation outputs, etc." + ) + parser.add_argument( + "--overwrite", action="store_true", help="Whether to overwrite the already finetuned model checkpoints." + ) + parser.add_argument( + "--device", type=str, default=None, help="The choice of device to run training and inference." + ) + args = parser.parse_args() + + device = util.get_device(args.device) # the device / GPU used for training and inference. + + # The model_type determines which base model is used to initialize the weights that are finetuned. + # We use vit_b here because it can be trained faster. Note that vit_h usually yields higher quality results. + model_type = "vit_b" + + # Train an additional convolutional decoder for end-to-end automatic instance segmentation + train_instance_segmentation = True + + # Step 1: Download the dataset. + train_image_paths, train_label_paths = download_dataset(path=args.input_path, split="train") + val_image_paths, val_label_paths = download_dataset(path=args.input_path, split="val") + test_image_paths, _ = download_dataset(path=args.input_path, split="test") + + # Step 2: Verify the spatial shape of inputs (only for the 'train' split) + verify_inputs(image_paths=train_image_paths, label_paths=train_label_paths) + + # Step 3: Preprocess input images. + preprocess_inputs(image_paths=train_image_paths) + preprocess_inputs(image_paths=val_image_paths) + preprocess_inputs(image_paths=test_image_paths) + + if args.view: + # Step 3(a): Visualize the images and corresponding labels (only for the 'train' split) + visualize_inputs(image_paths=train_image_paths, label_paths=train_label_paths) + + # Step 4: Get the dataloaders. 
+ train_loader, val_loader = get_dataloaders( + train_image_paths=train_image_paths, + train_label_paths=train_label_paths, + val_image_paths=val_image_paths, + val_label_paths=val_label_paths, + view=args.view, + train_instance_segmentation=train_instance_segmentation, + ) + + # Step 5: Run the finetuning for Segment Anything Model. + checkpoint_path = run_finetuning( + train_loader=train_loader, + val_loader=val_loader, + save_root=args.save_root, + train_instance_segmentation=train_instance_segmentation, + device=device, + model_type=model_type, + overwrite=args.overwrite, + ) + + # Step 6: Run automatic instance segmentation using the finetuned model. + run_instance_segmentation_with_decoder( + test_image_paths=test_image_paths, model_type=model_type, checkpoint=checkpoint_path, device=device, + ) + + +if __name__ == "__main__": + main()
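+
+
+# Example invocation (a sketch; adjust the paths and device to your setup):
+#   python finetune_sam.py -i ./data -s ./models --view --device cuda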