Skip to content

Commit

Permalink
Add materials for hands-on workshops (#742)
Browse files Browse the repository at this point in the history
Add materials for workshop and minor fixes
  • Loading branch information
anwai98 authored Oct 18, 2024
1 parent a01fad1 commit 69f3c01
Show file tree
Hide file tree
Showing 8 changed files with 1,282 additions and 25 deletions.
36 changes: 21 additions & 15 deletions micro_sam/automatic_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ def automatic_instance_segmentation(
embedding_path: The path where the embeddings are cached already / will be saved.
key: The key to the input file. This is needed for container files (eg. hdf5 or zarr)
or to load several images as 3d volume. Provide a glob patterm, eg. "*.tif", for this case.
ndim: The dimensionality of the data.
ndim: The dimensionality of the data. By default the dimensionality of the data will be used.
If you have RGB data you have to specify this explicitly, e.g. pass ndim=2 for 2d segmentation of RGB.
tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
halo: Overlap of the tiles for tiled prediction.
verbose: Verbosity flag.
Expand All @@ -102,21 +103,12 @@ def automatic_instance_segmentation(
else:
image_data = util.load_image_data(input_path, key)

if ndim == 3 or image_data.ndim == 3:
if image_data.ndim != 3:
raise ValueError(f"The inputs do not correspond to three dimensional inputs: '{image_data.ndim}'")
ndim = image_data.ndim if ndim is None else ndim

if ndim == 2:
if image_data.ndim != 2 or image_data.shape[-1] != 3:
raise ValueError(f"The inputs does not match the shape expectation of 2d inputs: {image_data.shape}")

instances = automatic_3d_segmentation(
volume=image_data,
predictor=predictor,
segmentor=segmenter,
embedding_path=embedding_path,
tile_shape=tile_shape,
halo=halo,
verbose=verbose,
**generate_kwargs
)
else:
# Precompute the image embeddings.
image_embeddings = util.precompute_image_embeddings(
predictor=predictor,
Expand All @@ -142,6 +134,20 @@ def automatic_instance_segmentation(
instances = np.zeros(this_shape, dtype="uint32")
else:
instances = mask_data_to_segmentation(masks, with_background=True, min_object_size=0)
else:
if image_data.ndim != 3 or image_data.shape[-1] != 3:
raise ValueError(f"The inputs does not match the shape expectation of 3d inputs: {image_data.shape}")

instances = automatic_3d_segmentation(
volume=image_data,
predictor=predictor,
segmentor=segmenter,
embedding_path=embedding_path,
tile_shape=tile_shape,
halo=halo,
verbose=verbose,
**generate_kwargs
)

if output_path is not None:
# Save the instance segmentation
Expand Down
5 changes: 5 additions & 0 deletions micro_sam/training/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,11 @@ def __call__(self, x, y):
#


def normalize_to_8bit(raw):
raw = normalize(raw) * 255
return raw


class ResizeRawTrafo:
def __init__(self, desired_shape, do_rescaling=False, padding="constant"):
self.desired_shape = desired_shape
Expand Down
17 changes: 7 additions & 10 deletions notebooks/sam_finetuning.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -263,9 +263,7 @@
"import torch\n",
"\n",
"import torch_em\n",
"from torch_em.model import UNETR\n",
"from torch_em.util.debug import check_loader\n",
"from torch_em.loss import DiceBasedDistanceLoss\n",
"from torch_em.util.util import get_random_colors\n",
"from torch_em.transform.label import PerObjectDistanceTransform\n",
"\n",
Expand Down Expand Up @@ -610,9 +608,9 @@
"# It supports image data in various formats. Here, we load image data and labels from the two\n",
"# folders with tif images that were downloaded by the example data functionality, by specifying\n",
"# `raw_key` and `label_key` as `*.tif`. This means all images in the respective folders that end with\n",
"# .tif will be loadded.\n",
"# .tif will be loaded.\n",
"# The function supports many other file formats. For example, if you have tif stacks with multiple slices\n",
"# instead of multiple tif images in a foldder, then you can pass raw_key=label_key=None.\n",
"# instead of multiple tif images in a folder, then you can pass raw_key=label_key=None.\n",
"# For more information, here is the documentation: https://github.com/constantinpape/torch-em/blob/main/torch_em/data/datasets/README.md\n",
"\n",
"# Load images from multiple files in folder via pattern (here: all tif files)\n",
Expand Down Expand Up @@ -950,13 +948,15 @@
"outputs": [],
"source": [
"def run_automatic_instance_segmentation(image, checkpoint_path, model_type=\"vit_b_lm\", device=None):\n",
" \"\"\"Automatic Instance Segmentation by training an additional instance decoder in SAM.\n",
" \"\"\"Automatic Instance Segmentation trained with an additional instance segmentation decoder in SAM.\n",
"\n",
" NOTE: It is supported only for `µsam` models.\n",
" \n",
" Args:\n",
" image: The input image.\n",
" checkpoint: Filepath to the model checkpoints.\n",
" model_type: The choice of the `µsam` model.\n",
" device: The torch device.\n",
" \n",
" Returns:\n",
" The instance segmentation.\n",
Expand Down Expand Up @@ -1392,18 +1392,15 @@
"assert os.path.exists(best_checkpoint), \"Please train the model first to run inference on the finetuned model.\"\n",
"assert train_instance_segmentation is True, \"Oops. You didn't opt for finetuning using the decoder-based automatic instance segmentation.\"\n",
"\n",
"# # Let's check the first 5 images. Feel free to comment out the line below to run inference on all images.\n",
"# Let's check the first 5 images. Feel free to comment out the line below to run inference on all images.\n",
"image_paths = image_paths[:5]\n",
"\n",
"for image_path in image_paths:\n",
" image = imageio.imread(image_path)\n",
" \n",
" # Predicted instances\n",
" prediction = run_automatic_instance_segmentation(\n",
" image=image,\n",
" checkpoint_path=best_checkpoint,\n",
" model_type=model_type,\n",
" device=device\n",
" image=image, checkpoint_path=best_checkpoint, model_type=model_type, device=device\n",
" )\n",
"\n",
" # Visualize the predictions\n",
Expand Down
102 changes: 102 additions & 0 deletions workshops/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# Hands-On Analysis using `micro-sam`

## Upcoming Workshops:
1. I2K 2024 (Milan, Italy)
2. Virtual I2K 2024 (Online)

## Introduction

In this document, we walk you through different steps involved to participate in hands-on image annotation experiments our tool.

- Here is our [official documentation](https://computational-cell-analytics.github.io/micro-sam/) for detailed explanation of our tools, library and the finetuned models.
- Here is the playlist for our [tutorial videos](https://youtube.com/playlist?list=PLwYZXQJ3f36GQPpKCrSbHjGiH39X4XjSO&si=3q-cIRD6KuoZFmAM) hosted on YouTube, elaborating in detail on the features of our tools.

## Steps:

### Step 1: Download the Datasets

- We provide the script `download_datasets.py` for automatic download of datasets to be used for interactive annotation using `micro-sam`.
- You can run the script as follows:
```bash
$ python download_datasets.py -i <DATA_DIRECTORY> -d <DATASET_NAME>
```
where, `DATA_DIRECTORY` is the filepath to the directory where the datasets will be downloaded, and `DATASET_NAME` is the name of the dataset (run `python download_datasets.py -h` in the terminal for more details).

> NOTE: We have chosen a) subset of the CellPose `cyto` dataset, b) one volume from the EmbedSeg `Mouse-Skull-Nuclei-CBG` dataset from the train split (namely, `X1.tif`), c) one volume from the Platynereis `Membrane` dataset from the train split (namely, `train_data_membrane_02.n5`) and d) the entire `HPA` dataset for the following tasks in `micro-sam`.
### Step 2: Download the Precomputed Embeddings

- We provide the script `download_embeddings.py` for automatic download of precompute image embeddings for volumetric data to be used for interactive annotation using `micro-sam`.
- You can run the script as follows:

```bash
$ python download_embeddings -e <EMBEDDING_DIRECTORY> -d <DATASET_NAME>
```
where, `EMBEDDING_DIRECTORY` is the filepath to the directory where the precomputed image embeddings will be downloaded, and `DATASET_NAME` is the name of the dataset (run `python download_embeddings.py -h` in the terminal for more details).

### Additional Section: Precompute the Embeddings Yourself!

Here is an example guide to precompute the image embeddings (eg. for volumetric data).

#### EmbedSeg

```bash
$ micro_sam.precompute_embeddings -i data/embedseg/Mouse-Skull-Nuclei-CBG/train/images/X1.tif # Filepath where inputs are stored.
-m vit_b # You can provide name for any model of your choice (supported by 'micro-sam') (eg. 'vit_b_lm').
-e embeddings/embedseg/vit_b/embedseg_Mouse-Skull-Nuclei-CBG_train_X1 # Filepath where computed embeddings will be cached.
```

#### Platynereis

```bash
$ micro_sam.precompute_embeddings -i data/platynereis/membrane/train_data_membrane_02.n5 # Filepath where inputs are stored.
-k volumes/raw/s1 # Key to access the data group in container-style data structures.
-m vit_b # You can provide name for any model of your choice (supported by 'micro-sam') (eg. 'vit_b_em_organelles').
-e embeddings/platynereis/vit_b/platynereis_train_data_membrane_02 # Filepath where computed embeddings will be cached.
```

### Step 3: Run the `micro-sam` Annotators (WIP)

Run the `micro-sam` annotators with the following scripts:

We recommend using the napari GUI for the interactive annotation. You can use the widget to specify all the essential parameters (eg. the choice of model, the filepath to the precomputed embeddings, etc).

TODO: add more details here.

There is another option to use `micro-sam`'s CLI to start our annotator tools.

#### 2D Annotator (Cell Segmentation in Light Microscopy):

```bash
$ micro_sam.annotator_2d -i data/cellpose/cyto/test/... # Filepath where the 2d image is stored.
-m vit_b # You can provide name for any model of your choice (supported by 'micro-sam') (eg. 'vit_b_lm')
[OPTIONAL] -e embeddings/cellpose/vit_b/... # Filepath where the computed embeddings will be cached (you can choose to not pass it to compute the embeddings on-the-fly).
```

#### 3D Annotator (EmbedSeg - Nuclei Segmentation in Light Microscopy):

```bash
$ micro_sam.annotator_3d -i data/embedseg/Mouse-Skull-Nuclei-CBG/train/images/X1.tif # Filepath where the 3d volume is stored.
-m vit_b # You can provide name for any model of your choice (supported by 'micro-sam') (eg. 'vit_b_lm')
-e embeddings/embedseg/vit_b/embedseg_Mouse-Skull-Nuclei-CBG_train_X1.zarr # Filepath where the computed embeddings will be cached (we RECOMMEND to provide paths to the downloaded embeddings OR you can choose to not pass it to compute the embeddings on-the-fly).
```

#### 3D Annotator (Platynereis - Membrane Segmentation in Electron Microscopy):

```bash
$ micro_sam.annotator_3d -i data/platynereis/membrane/train_data_membrane_02.n5 # Filepath where the 2d image is stored.
-k volumes/raw/s1 # Key to access the data group in container-style data structures.
-m vit_b # You can provide name for any model of your choice (supported by 'micro-sam') (eg. 'vit_b_em_organelles')
-e embeddings/platynereis/vit_b/... # Filepath where the computed embeddings will be cached (we RECOMMEND to provide paths to the downloaded embeddings OR you can choose to not pass it to compute the embeddings on-the-fly).
```

#### Image Series Annotator (Multiple Light Microscopy 2D Images for Cell Segmentation):

```bash
$ micro_sam.image_series_annotator -i ...
-m vit_b # You can provide name for any model of your choice (supported by 'micro-sam') (eg. 'vit_b_lm')
```

### Step 4: Finetune Segment Anything on Microscopy Images (WIP)

- We provide a notebook `finetune_sam.ipynb` / `finetune_sam.py` for finetuning Segment Anything Model for cell segmentation in confocal microscopy images.
140 changes: 140 additions & 0 deletions workshops/download_datasets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
import os
from glob import glob
from natsort import natsorted

from torch_em.data import datasets
from torch_em.util.image import load_data


def _download_sample_data(path, data_dir, url, checksum, download):
if os.path.exists(data_dir):
return

os.makedirs(path, exist_ok=True)

zip_path = os.path.join(path, "data.zip")
datasets.util.download_source(path=zip_path, url=url, download=download, checksum=checksum)
datasets.util.unzip(zip_path=zip_path, dst=path)


def _get_cellpose_sample_data_paths(path, download):
data_dir = os.path.join(path, "cellpose", "cyto", "test")

url = "https://owncloud.gwdg.de/index.php/s/slIxlmsglaz0HBE/download"
checksum = "4d1ce7afa6417d051b93d6db37675abc60afe68daf2a4a5db0c787d04583ce8a"

_download_sample_data(path, data_dir, url, checksum, download)

raw_paths = natsorted(glob(os.path.join(data_dir, "*_img.png")))
label_paths = natsorted(glob(os.path.join(data_dir, "*_masks.png")))

return raw_paths, label_paths


def _get_hpa_data_paths(path, split, download):
urls = [
"https://owncloud.gwdg.de/index.php/s/zp1Fmm4zEtLuhy4/download", # train
"https://owncloud.gwdg.de/index.php/s/yV7LhGbGfvFGRBE/download", # val
"https://owncloud.gwdg.de/index.php/s/8tLY5jPmpw37beM/download", # test
]
checksums = [
"6e5f3ec6b0d505511bea752adaf35529f6b9bb9e7729ad3bdd90ffe5b2d302ab", # train
"4d7a4188cc3d3877b3cf1fbad5f714ced9af4e389801e2136623eac2fde78e9c", # val
"8963ff47cdef95cefabb8941f33a3916258d19d10f532a209bab849d07f9abfe", # test
]
splits = ["train", "val", "test"]
assert split in splits, f"'{split}' is not a valid split."

for url, checksum, _split in zip(urls, checksums, splits):
data_dir = os.path.join(path, _split)
_download_sample_data(path, data_dir, url, checksum, download)

raw_paths = natsorted(glob(os.path.join(path, split, "images", "*.tif")))

if split == "test": # The 'test' split for HPA does not have labels.
return raw_paths, None
else:
label_paths = natsorted(glob(os.path.join(path, split, "labels", "*.tif")))
return raw_paths, label_paths


def _get_dataset_paths(path, dataset_name, view=False):
dataset_paths = {
# 2d LM dataset for cell segmentation
"cellpose": lambda: _get_cellpose_sample_data_paths(path=os.path.join(path, "cellpose"), download=True),
"hpa": lambda: _get_hpa_data_paths(path=os.path.join(path, "hpa"), download=True, split="train"),
# 3d LM dataset for nuclei segmentation
"embedseg": lambda: datasets.embedseg_data.get_embedseg_paths(
path=os.path.join(path, "embedseg"), name="Mouse-Skull-Nuclei-CBG", split="train", download=True,
),
# 3d EM dataset for membrane segmentation
"platynereis": lambda: datasets.platynereis.get_platynereis_paths(
path=os.path.join(path, "platynereis"), sample_ids=None, name="cells", download=True,
),
}

dataset_keys = {
"cellpose": [None, None],
"embedseg": [None, None],
"platynereis": ["volumes/raw/s1", "volumes/labels/segmentation/s1"]
}

if dataset_name is None: # Download all datasets.
dataset_names = list(dataset_paths.keys())
else: # Download specific datasets.
dataset_names = [dataset_name]

for dname in dataset_names:
if dname not in dataset_paths:
raise ValueError(
f"'{dname}' is not a supported dataset enabled for download. "
f"Please choose from {list(dataset_paths.keys())}."
)

paths = dataset_paths[dname]()
print(f"'{dataset_name}' is download at {path}.")

if view:
import napari

if isinstance(paths, tuple): # datasets with explicit raw and label paths
raw_paths, label_paths = paths
else:
raw_paths = label_paths = paths

raw_key, label_key = dataset_keys[dname]
for raw_path, label_path in zip(raw_paths, label_paths):
raw = load_data(raw_path, raw_key)
labels = load_data(label_path, label_key)

v = napari.Viewer()
v.add_image(raw)
v.add_labels(labels)
napari.run()

break # comment this line out in case you would like to visualize all samples.


def main():
import argparse
parser = argparse.ArgumentParser(description="Download the dataset necessary for the workshop.")
parser.add_argument(
"-i", "--input_path", type=str, default="./data",
help="The filepath to the folder where the image data will be downloaded. "
"By default, the data will be stored in your current working directory at './data'."
)
parser.add_argument(
"-d", "--dataset_name", type=str, default=None,
help="The choice of dataset you would like to download. By default, it downloads all the datasets. "
"Optionally, you can choose to download either of 'cellpose', 'hpa', 'embedseg' or 'platynereis'."
)
parser.add_argument(
"-v", "--view", action="store_true", help="Whether to view the downloaded data."
)
args = parser.parse_args()

_get_dataset_paths(path=args.input_path, dataset_name=args.dataset_name, view=args.view)


if __name__ == "__main__":
main()
Loading

0 comments on commit 69f3c01

Please sign in to comment.