diff --git a/colabs/README.md b/colabs/README.md index 141c049b..d6144c25 100644 --- a/colabs/README.md +++ b/colabs/README.md @@ -29,6 +29,7 @@ | 🦄 Fine-tune a Torchvision Model with KerasCore | [![](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/wandb/examples/blob/master/colabs/keras/keras_core/torchvision_keras.ipynb) | | 🦄 Fine-tune a Timm Model with KerasCore | [![](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/wandb/examples/blob/master/colabs/keras/keras_core/timm_keras.ipynb) | | 🦄 Medical Image Classification Tutorial using MonAI and KerasCore | [![](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/wandb/examples/blob/master/colabs/keras/keras_core/monai_medmnist_keras.ipynb) | +| 🩻 Brain tumor 3D segmentation with MONAI and Weights & Biases | [![](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/wandb/examples/blob/main/colabs/monai/3d_brain_tumor_segmentation.ipynb) | # 🏋🏽‍♂️ W&B Features diff --git a/colabs/monai/3d_brain_tumor_segmentation.ipynb b/colabs/monai/3d_brain_tumor_segmentation.ipynb new file mode 100644 index 00000000..7eba44b6 --- /dev/null +++ b/colabs/monai/3d_brain_tumor_segmentation.ipynb @@ -0,0 +1,986 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Brain tumor 3D segmentation with MONAI and Weights & Biases\n", + "\n", + "[![](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/wandb/examples/blob/main/colabs/monai/3d_brain_tumor_segmentation.ipynb)\n", + "\n", + "This tutorial shows how to construct a training workflow for a multi-label 3D brain tumor segmentation task using [MONAI](https://github.com/Project-MONAI/MONAI) and how to use the experiment tracking and data visualization features of [Weights & Biases](https://wandb.ai/site). The tutorial covers the following features:\n", + "\n", + "1. Initialize a Weights & Biases run and synchronize all configs associated with the run for reproducibility.\n", + "2. MONAI transform API:\n", + " 1. MONAI Transforms for dictionary format data.\n", + " 2. How to define a new transform according to the MONAI `transforms` API.\n", + " 3. How to randomly adjust intensity for data augmentation.\n", + "3. Data Loading and Visualization:\n", + " 1. Load Nifti images with metadata, load a list of images, and stack them.\n", + " 2. Cache IO and transforms to accelerate training and validation.\n", + " 3. Visualize the data using `wandb.Table` and the interactive segmentation overlay on Weights & Biases.\n", + "4. Training a 3D `SegResNet` model:\n", + " 1. Using the `networks`, `losses`, and `metrics` APIs from MONAI.\n", + " 2. Training the 3D `SegResNet` model using a PyTorch training loop.\n", + " 3. Tracking the training experiment using Weights & Biases.\n", + " 4. Logging and versioning model checkpoints as model artifacts on Weights & Biases.\n", + "5. Visualize and compare the predictions on the validation dataset using `wandb.Table` and the interactive segmentation overlay on Weights & Biases." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 🌴 Setup and Installation\n", + "\n", + "First, let us install the latest versions of both MONAI and Weights & Biases."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!python -c \"import monai\" || pip install -q -U \"monai[nibabel, tqdm]\"\n", + "!python -c \"import wandb\" || pip install -q -U wandb" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "import numpy as np\n", + "from tqdm.auto import tqdm\n", + "import wandb\n", + "\n", + "from monai.apps import DecathlonDataset\n", + "from monai.data import DataLoader, decollate_batch\n", + "from monai.losses import DiceLoss\n", + "from monai.config import print_config\n", + "from monai.inferers import sliding_window_inference\n", + "from monai.metrics import DiceMetric\n", + "from monai.networks.nets import SegResNet\n", + "from monai.transforms import (\n", + " Activations,\n", + " AsDiscrete,\n", + " Compose,\n", + " LoadImaged,\n", + " MapTransform,\n", + " NormalizeIntensityd,\n", + " Orientationd,\n", + " RandFlipd,\n", + " RandScaleIntensityd,\n", + " RandShiftIntensityd,\n", + " RandSpatialCropd,\n", + " Spacingd,\n", + " EnsureTyped,\n", + " EnsureChannelFirstd,\n", + ")\n", + "from monai.utils import set_determinism\n", + "\n", + "import torch\n", + "\n", + "print_config()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We will then authenticate this Colab instance to use W&B." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "wandb.login()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 🌳 Initialize a W&B Run\n", + "\n", + "We will start a new W&B run to begin tracking our experiment." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "wandb.init(project=\"monai-brain-tumor-segmentation\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Using a proper config system is a recommended best practice for reproducible machine learning. We can track the hyperparameters for every experiment using W&B." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "config = wandb.config\n", + "config.seed = 0\n", + "config.roi_size = [224, 224, 144]\n", + "config.batch_size = 1\n", + "config.num_workers = 4\n", + "config.max_train_images_visualized = 20\n", + "config.max_val_images_visualized = 20\n", + "config.dice_loss_smoothen_numerator = 0\n", + "config.dice_loss_smoothen_denominator = 1e-5\n", + "config.dice_loss_squared_prediction = True\n", + "config.dice_loss_target_onehot = False\n", + "config.dice_loss_apply_sigmoid = True\n", + "config.initial_learning_rate = 1e-4\n", + "config.weight_decay = 1e-5\n", + "config.max_train_epochs = 50\n", + "config.validation_intervals = 1\n", + "config.dataset_dir = \"./dataset/\"\n", + "config.checkpoint_dir = \"./checkpoints\"\n", + "config.inference_roi_size = (128, 128, 64)\n", + "config.max_prediction_images_visualized = 20" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We also need to set the random seed to enable or disable deterministic training."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "set_determinism(seed=config.seed)\n", + "\n", + "# Create directories\n", + "os.makedirs(config.dataset_dir, exist_ok=True)\n", + "os.makedirs(config.checkpoint_dir, exist_ok=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 💿 Data Loading and Transformation" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here, we use the `monai.transforms` API to create a custom transform that converts the multi-class labels into a multi-label segmentation task in a one-hot format." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class ConvertToMultiChannelBasedOnBratsClassesd(MapTransform):\n", + " \"\"\"\n", + " Convert labels to multi channels based on BraTS classes:\n", + " label 1 is the peritumoral edema\n", + " label 2 is the GD-enhancing tumor\n", + " label 3 is the necrotic and non-enhancing tumor core\n", + " The possible classes are TC (Tumor core), WT (Whole tumor)\n", + " and ET (Enhancing tumor).\n", + "\n", + " Reference: https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/brats_segmentation_3d.ipynb\n", + "\n", + " \"\"\"\n", + "\n", + " def __call__(self, data):\n", + " d = dict(data)\n", + " for key in self.keys:\n", + " result = []\n", + " # merge label 2 and label 3 to construct TC\n", + " result.append(torch.logical_or(d[key] == 2, d[key] == 3))\n", + " # merge labels 1, 2 and 3 to construct WT\n", + " result.append(\n", + " torch.logical_or(\n", + " torch.logical_or(d[key] == 2, d[key] == 3), d[key] == 1\n", + " )\n", + " )\n", + " # label 2 is ET\n", + " result.append(d[key] == 2)\n", + " d[key] = torch.stack(result, axis=0).float()\n", + " return d" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Next, we set up the transforms for the training and validation datasets respectively.\n",
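+ "\n",
+ "But first, here is a quick, purely illustrative check of what the custom transform defined above produces on a tiny synthetic label volume (made up for demonstration only, not part of the tutorial's pipeline):\n",
+ "\n",
+ "```python\n",
+ "# Illustrative only: apply the custom transform to a tiny synthetic label volume\n",
+ "dummy_label = torch.randint(0, 4, (4, 4, 4))  # fake BraTS-style labels in {0, 1, 2, 3}\n",
+ "converted = ConvertToMultiChannelBasedOnBratsClassesd(keys=\"label\")({\"label\": dummy_label})\n",
+ "print(converted[\"label\"].shape)  # torch.Size([3, 4, 4, 4]) -> TC, WT, and ET channels\n",
+ "```"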
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "train_transform = Compose(\n", + " [\n", + " # load 4 Nifti images and stack them together\n", + " LoadImaged(keys=[\"image\", \"label\"]),\n", + " EnsureChannelFirstd(keys=\"image\"),\n", + " EnsureTyped(keys=[\"image\", \"label\"]),\n", + " ConvertToMultiChannelBasedOnBratsClassesd(keys=\"label\"),\n", + " Orientationd(keys=[\"image\", \"label\"], axcodes=\"RAS\"),\n", + " Spacingd(\n", + " keys=[\"image\", \"label\"],\n", + " pixdim=(1.0, 1.0, 1.0),\n", + " mode=(\"bilinear\", \"nearest\"),\n", + " ),\n", + " RandSpatialCropd(\n", + " keys=[\"image\", \"label\"], roi_size=config.roi_size, random_size=False\n", + " ),\n", + " RandFlipd(keys=[\"image\", \"label\"], prob=0.5, spatial_axis=0),\n", + " RandFlipd(keys=[\"image\", \"label\"], prob=0.5, spatial_axis=1),\n", + " RandFlipd(keys=[\"image\", \"label\"], prob=0.5, spatial_axis=2),\n", + " NormalizeIntensityd(keys=\"image\", nonzero=True, channel_wise=True),\n", + " RandScaleIntensityd(keys=\"image\", factors=0.1, prob=1.0),\n", + " RandShiftIntensityd(keys=\"image\", offsets=0.1, prob=1.0),\n", + " ]\n", + ")\n", + "val_transform = Compose(\n", + " [\n", + " LoadImaged(keys=[\"image\", \"label\"]),\n", + " EnsureChannelFirstd(keys=\"image\"),\n", + " EnsureTyped(keys=[\"image\", \"label\"]),\n", + " ConvertToMultiChannelBasedOnBratsClassesd(keys=\"label\"),\n", + " Orientationd(keys=[\"image\", \"label\"], axcodes=\"RAS\"),\n", + " Spacingd(\n", + " keys=[\"image\", \"label\"],\n", + " pixdim=(1.0, 1.0, 1.0),\n", + " mode=(\"bilinear\", \"nearest\"),\n", + " ),\n", + " NormalizeIntensityd(keys=\"image\", nonzero=True, channel_wise=True),\n", + " ]\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 🍁 The Dataset\n", + "\n", + "The dataset that we will use for this experiment comes from http://medicaldecathlon.com/. We will use multimodal multisite MRI data (FLAIR, T1w, T1gd, T2w) to segment gliomas, necrotic/active tumour, and oedema. The dataset consists of 750 4D volumes (484 Training + 266 Testing).\n", + "\n", + "We will use the `DecathlonDataset` to automatically download and extract the dataset. It inherits from the MONAI `CacheDataset`, which enables us to set `cache_num=N` to cache `N` items for training and to use the default arguments to cache all the items for validation, depending on your memory size." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "train_dataset = DecathlonDataset(\n", + " root_dir=config.dataset_dir,\n", + " task=\"Task01_BrainTumour\",\n", + " transform=val_transform,\n", + " section=\"training\",\n", + " download=True,\n", + " cache_rate=0.0,\n", + " num_workers=4,\n", + ")\n", + "val_dataset = DecathlonDataset(\n", + " root_dir=config.dataset_dir,\n", + " task=\"Task01_BrainTumour\",\n", + " transform=val_transform,\n", + " section=\"validation\",\n", + " download=False,\n", + " cache_rate=0.0,\n", + " num_workers=4,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Note:** Instead of applying the `train_transform` to the `train_dataset`, we have applied `val_transform` to both the training and validation datasets. This is because, before training, we want to visualize samples from both splits of the dataset.\n",
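+ "\n",
+ "As noted above, `DecathlonDataset` inherits from `CacheDataset`, so if you have enough memory you can cache the transformed items to speed up training. The following is an optional, purely illustrative sketch that is not used in this tutorial; the variable name and the `cache_num` value are arbitrary:\n",
+ "\n",
+ "```python\n",
+ "cached_train_dataset = DecathlonDataset(\n",
+ "    root_dir=config.dataset_dir,\n",
+ "    task=\"Task01_BrainTumour\",\n",
+ "    transform=val_transform,\n",
+ "    section=\"training\",\n",
+ "    download=False,  # the dataset was already downloaded above\n",
+ "    cache_num=100,  # cache the first 100 transformed items in memory\n",
+ "    num_workers=4,\n",
+ ")\n",
+ "```"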
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 📸 Visualizing the Dataset" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Weights & Biases supports logging images, video, audio, and more. We can log rich media to explore our results and visually compare our runs, models, and datasets. We will use the [segmentation mask overlay system](https://docs.wandb.ai/guides/track/log/media#image-overlays-in-tables) to visualize our data volumes. To log segmentation masks in [tables](https://docs.wandb.ai/guides/tables), we will need to provide a `wandb.Image` object for each row in the table.\n", + "\n", + "An example is provided in the code snippet below:\n", + "\n", + "```python\n", + "table = wandb.Table(columns=[\"ID\", \"Image\"])\n", + "\n", + "for id, img, label in zip(ids, images, labels):\n", + " mask_img = wandb.Image(\n", + " img,\n", + " masks={\n", + " \"prediction\": {\"mask_data\": label, \"class_labels\": class_labels}\n", + " # ...\n", + " },\n", + " )\n", + "\n", + " table.add_data(id, mask_img)\n", + "\n", + "wandb.log({\"Table\": table})\n", + "```\n", + "\n", + "Let us now write a simple utility function that takes a sample image, its label, a `wandb.Table` object, and some associated metadata, and populates the rows of a table that will be logged to our Weights & Biases dashboard." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def log_data_samples_into_tables(\n", + " sample_image: np.array,\n", + " sample_label: np.array,\n", + " split: str = None,\n", + " data_idx: int = None,\n", + " table: wandb.Table = None,\n", + "):\n", + " num_channels, _, _, num_slices = sample_image.shape\n", + " with tqdm(total=num_slices, leave=False) as progress_bar:\n", + " for slice_idx in range(num_slices):\n", + " ground_truth_wandb_images = []\n", + " for channel_idx in range(num_channels):\n", + " # each mask channel is mapped to a distinct class id (1, 2, 3) so the overlays use separate labels\n", + " ground_truth_wandb_images.append(\n", + " wandb.Image(\n", + " sample_image[channel_idx, :, :, slice_idx],\n", + " masks={\n", + " \"ground-truth/Tumor-Core\": {\n", + " \"mask_data\": sample_label[0, :, :, slice_idx],\n", + " \"class_labels\": {0: \"background\", 1: \"Tumor Core\"},\n", + " },\n", + " \"ground-truth/Whole-Tumor\": {\n", + " \"mask_data\": sample_label[1, :, :, slice_idx] * 2,\n", + " \"class_labels\": {0: \"background\", 2: \"Whole Tumor\"},\n", + " },\n", + " \"ground-truth/Enhancing-Tumor\": {\n", + " \"mask_data\": sample_label[2, :, :, slice_idx] * 3,\n", + " \"class_labels\": {0: \"background\", 3: \"Enhancing Tumor\"},\n", + " },\n", + " },\n", + " )\n", + " )\n", + " table.add_data(split, data_idx, slice_idx, *ground_truth_wandb_images)\n", + " progress_bar.update(1)\n", + " return table" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Next, we define a `wandb.Table` object and the columns it consists of so that we can populate it with our data visualizations."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "table = wandb.Table(\n", + " columns=[\n", + " \"Split\",\n", + " \"Data Index\",\n", + " \"Slice Index\",\n", + " \"Image-Channel-0\",\n", + " \"Image-Channel-1\",\n", + " \"Image-Channel-2\",\n", + " \"Image-Channel-3\",\n", + " ]\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Then we loop over the `train_dataset` and `val_dataset` respectively to generate the visualizations for the data samples and populate the rows of the table which we would log to our dashboard." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Generate visualizations for train_dataset\n", + "max_samples = (\n", + " min(config.max_train_images_visualized, len(train_dataset))\n", + " if config.max_train_images_visualized > 0\n", + " else len(train_dataset)\n", + ")\n", + "progress_bar = tqdm(\n", + " enumerate(train_dataset[:max_samples]),\n", + " total=max_samples,\n", + " desc=\"Generating Train Dataset Visualizations:\",\n", + ")\n", + "for data_idx, sample in progress_bar:\n", + " sample_image = sample[\"image\"].detach().cpu().numpy()\n", + " sample_label = sample[\"label\"].detach().cpu().numpy()\n", + " table = log_data_samples_into_tables(\n", + " sample_image,\n", + " sample_label,\n", + " split=\"train\",\n", + " data_idx=data_idx,\n", + " table=table,\n", + " )\n", + "\n", + "# Generate visualizations for val_dataset\n", + "max_samples = (\n", + " min(config.max_val_images_visualized, len(val_dataset))\n", + " if config.max_val_images_visualized > 0\n", + " else len(val_dataset)\n", + ")\n", + "progress_bar = tqdm(\n", + " enumerate(val_dataset[:max_samples]),\n", + " total=max_samples,\n", + " desc=\"Generating Validation Dataset Visualizations:\",\n", + ")\n", + "for data_idx, sample in progress_bar:\n", + " sample_image = sample[\"image\"].detach().cpu().numpy()\n", + " sample_label = sample[\"label\"].detach().cpu().numpy()\n", + " table = log_data_samples_into_tables(\n", + " sample_image,\n", + " sample_label,\n", + " split=\"val\",\n", + " data_idx=data_idx,\n", + " table=table,\n", + " )\n", + "\n", + "# Log the table to your dashboard\n", + "wandb.log({\"Tumor-Segmentation-Data\": table})" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The data appears to us on our W&B dashboard in an interactive tabular format. We can see each channel of a particular slice from a data volume overlayed with the respective segmentation mask in each row. Let us write [Weave queries](https://docs.wandb.ai/guides/weave) to filter the data on our table and focus on one particular row.\n", + "\n", + "![](./assets/viz-1.gif)\n", + "\n", + "Let us now open an image and check how we can interact with each of the segmentation masks using the interactive overlay.\n", + "\n", + "![](./assets/viz-2.gif)\n", + "\n", + "**Note:** The labels in the dataset consist of non-overlapping masks across classes, hence, they were logged as separate masks in the overlay." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 🛫 Loading the Data\n", + "\n", + "We create the PyTorch dataloaders for loading the data from the datasets. Note that before creating the dataloaders, we set the `transform` for `train_dataset` to `train_transform` to preprocess and transform the data for training." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# apply train_transforms to the training dataset\n", + "train_dataset.transform = train_transform\n", + "\n", + "# create the train_loader\n", + "train_loader = DataLoader(\n", + " train_dataset,\n", + " batch_size=config.batch_size,\n", + " shuffle=True,\n", + " num_workers=config.num_workers,\n", + ")\n", + "\n", + "# create the val_loader\n", + "val_loader = DataLoader(\n", + " val_dataset,\n", + " batch_size=config.batch_size,\n", + " shuffle=False,\n", + " num_workers=config.num_workers,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 🤖 Creating the Model, Loss, and Optimizer" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this tutorial we will be training a `SegResNet` model based on the paper [3D MRI brain tumor segmentation using autoencoder regularization](https://arxiv.org/pdf/1810.11654.pdf). We create the `SegResNet` model that comes implemented as a PyTorch Module as part of the `monai.networks` API. We also create our optimizer and learning rate scheduler." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "device = torch.device(\"cuda:0\")\n", + "\n", + "# create model\n", + "model = SegResNet(\n", + " blocks_down=[1, 2, 2, 4],\n", + " blocks_up=[1, 1, 1],\n", + " init_filters=16,\n", + " in_channels=4,\n", + " out_channels=3,\n", + " dropout_prob=0.2,\n", + ").to(device)\n", + "\n", + "# create optimizer\n", + "optimizer = torch.optim.Adam(\n", + " model.parameters(),\n", + " config.initial_learning_rate,\n", + " weight_decay=config.weight_decay,\n", + ")\n", + "\n", + "# create learning rate scheduler\n", + "lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(\n", + " optimizer, T_max=config.max_train_epochs\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We define our loss as multi-label `DiceLoss` using the `monai.losses` API and the corresponding dice metrics using the `monai.metrics` API." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "loss_function = DiceLoss(\n", + " smooth_nr=config.dice_loss_smoothen_numerator,\n", + " smooth_dr=config.dice_loss_smoothen_denominator,\n", + " squared_pred=config.dice_loss_squared_prediction,\n", + " to_onehot_y=config.dice_loss_target_onehot,\n", + " sigmoid=config.dice_loss_apply_sigmoid,\n", + ")\n", + "\n", + "dice_metric = DiceMetric(include_background=True, reduction=\"mean\")\n", + "dice_metric_batch = DiceMetric(include_background=True, reduction=\"mean_batch\")\n", + "post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])\n", + "\n", + "# use automatic mixed-precision to accelerate training\n", + "scaler = torch.cuda.amp.GradScaler()\n", + "torch.backends.cudnn.benchmark = True" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def inference(model, input):\n", + " def _compute(input):\n", + " return sliding_window_inference(\n", + " inputs=input,\n", + " roi_size=(240, 240, 160),\n", + " sw_batch_size=1,\n", + " predictor=model,\n", + " overlap=0.5,\n", + " )\n", + "\n", + " with torch.cuda.amp.autocast():\n", + " return _compute(input)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 🚝 Training and Validation\n", + "\n", + "Before we start training, let us define some metric properties which will later be logged with `wandb.log()` for tracking our training and validation experiments." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "wandb.define_metric(\"epoch/epoch_step\")\n", + "wandb.define_metric(\"epoch/*\", step_metric=\"epoch/epoch_step\")\n", + "wandb.define_metric(\"batch/batch_step\")\n", + "wandb.define_metric(\"batch/*\", step_metric=\"batch/batch_step\")\n", + "wandb.define_metric(\"validation/validation_step\")\n", + "wandb.define_metric(\"validation/*\", step_metric=\"validation/validation_step\")\n", + "\n", + "batch_step = 0\n", + "validation_step = 0\n", + "metric_values = []\n", + "metric_values_tumor_core = []\n", + "metric_values_whole_tumor = []\n", + "metric_values_enhanced_tumor = []" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 🍭 Execute Standard PyTorch Training Loop" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Define a W&B Artifact object\n", + "artifact = wandb.Artifact(\n", + " name=f\"{wandb.run.id}-checkpoint\", type=\"model\"\n", + ")\n", + "\n", + "epoch_progress_bar = tqdm(range(config.max_train_epochs), desc=\"Training:\")\n", + "\n", + "for epoch in epoch_progress_bar:\n", + " model.train()\n", + " epoch_loss = 0\n", + "\n", + " total_batch_steps = len(train_dataset) // train_loader.batch_size\n", + " batch_progress_bar = tqdm(train_loader, total=total_batch_steps, leave=False)\n", + " \n", + " # Training Step\n", + " for batch_data in batch_progress_bar:\n", + " inputs, labels = (\n", + " batch_data[\"image\"].to(device),\n", + " batch_data[\"label\"].to(device),\n", + " )\n", + " optimizer.zero_grad()\n", + " with torch.cuda.amp.autocast():\n", + " outputs = model(inputs)\n", + " loss = loss_function(outputs, labels)\n", + " scaler.scale(loss).backward()\n", + " scaler.step(optimizer)\n", + " scaler.update()\n", + " epoch_loss += loss.item()\n", + " batch_progress_bar.set_description(f\"train_loss: {loss.item():.4f}:\")\n", + " ## Log 
batch-wise training loss to W&B\n", + " wandb.log({\"batch/batch_step\": batch_step, \"batch/train_loss\": loss.item()})\n", + " batch_step += 1\n", + "\n", + " lr_scheduler.step()\n", + " epoch_loss /= total_batch_steps\n", + " ## Log epoch-wise mean training loss and learning rate to W&B\n", + " wandb.log(\n", + " {\n", + " \"epoch/epoch_step\": epoch,\n", + " \"epoch/mean_train_loss\": epoch_loss,\n", + " \"epoch/learning_rate\": lr_scheduler.get_last_lr()[0],\n", + " }\n", + " )\n", + " epoch_progress_bar.set_description(f\"Training: train_loss: {epoch_loss:.4f}:\")\n", + "\n", + " # Validation and model checkpointing\n", + " if (epoch + 1) % config.validation_intervals == 0:\n", + " model.eval()\n", + " with torch.no_grad():\n", + " for val_data in val_loader:\n", + " val_inputs, val_labels = (\n", + " val_data[\"image\"].to(device),\n", + " val_data[\"label\"].to(device),\n", + " )\n", + " val_outputs = inference(model, val_inputs)\n", + " val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]\n", + " dice_metric(y_pred=val_outputs, y=val_labels)\n", + " dice_metric_batch(y_pred=val_outputs, y=val_labels)\n", + "\n", + " metric_values.append(dice_metric.aggregate().item())\n", + " metric_batch = dice_metric_batch.aggregate()\n", + " metric_values_tumor_core.append(metric_batch[0].item())\n", + " metric_values_whole_tumor.append(metric_batch[1].item())\n", + " metric_values_enhanced_tumor.append(metric_batch[2].item())\n", + " dice_metric.reset()\n", + " dice_metric_batch.reset()\n", + "\n", + " checkpoint_path = os.path.join(config.checkpoint_dir, \"model.pth\")\n", + " torch.save(model.state_dict(), checkpoint_path)\n", + " \n", + " # Log and version model checkpoints using W&B artifacts.\n", + " artifact.add_file(local_path=checkpoint_path)\n", + " wandb.log_artifact(artifact, aliases=[f\"epoch_{epoch}\"])\n", + "\n", + " # Log validation metrics to W&B dashboard.\n", + " wandb.log(\n", + " {\n", + " \"validation/validation_step\": validation_step,\n", + " \"validation/mean_dice\": metric_values[-1],\n", + " \"validation/mean_dice_tumor_core\": metric_values_tumor_core[-1],\n", + " \"validation/mean_dice_whole_tumor\": metric_values_whole_tumor[-1],\n", + " \"validation/mean_dice_enhanced_tumor\": metric_values_enhanced_tumor[-1],\n", + " }\n", + " )\n", + " validation_step += 1\n", + "\n", + "\n", + "# Wait for this artifact to finish logging\n", + "artifact.wait()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Instrumenting our code with `wandb.log` not only enables us to track all the metrics associated with our training and validation process, but also all the system metrics (our CPU and GPU in this case) on our W&B dashboard.\n", + "\n", + "![](./assets/viz-3.gif)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If we navigate to the artifacts tab in the W&B run dashboard, we will be able to access the different versions of model checkpoint artifacts that we logged during training.\n", + "\n", + "![](./assets/viz-4.gif)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 🔱 Inference" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Using the artifacts interface, we can select the best version of the model checkpoint artifact, in this case based on the epoch-wise mean training loss.
We can also explore the entire lineage of the artifact and use the version that we need.\n", + "\n", + "![](./assets/viz-5.gif)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let us fetch the version of the model artifact with the best epoch-wise mean training loss and load the checkpoint state dictionary into the model. Note that the artifact path in the next cell points to a checkpoint from our example project; if you are running this notebook yourself, replace it with the path of a model checkpoint artifact logged by your own run." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model_artifact = wandb.use_artifact(\n", + " \"geekyrakshit/monai-brain-tumor-segmentation/d5ex6n4a-checkpoint:v49\",\n", + " type=\"model\",\n", + ")\n", + "model_artifact_dir = model_artifact.download()\n", + "model.load_state_dict(torch.load(os.path.join(model_artifact_dir, \"model.pth\")))\n", + "model.eval()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 📸 Visualizing Predictions and Comparing with the Ground Truth Labels\n", + "\n", + "In order to visualize the predictions of the pre-trained model and compare them with the corresponding ground-truth segmentation masks using the interactive segmentation mask overlay, let us create another utility function." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def log_predictions_into_tables(\n", + " sample_image: np.array,\n", + " sample_label: np.array,\n", + " predicted_label: np.array,\n", + " split: str = None,\n", + " data_idx: int = None,\n", + " table: wandb.Table = None,\n", + "):\n", + " num_channels, _, _, num_slices = sample_image.shape\n", + " with tqdm(total=num_slices, leave=False) as progress_bar:\n", + " for slice_idx in range(num_slices):\n", + " wandb_images = []\n", + " for channel_idx in range(num_channels):\n", + " wandb_images += [\n", + " wandb.Image(\n", + " sample_image[channel_idx, :, :, slice_idx],\n", + " masks={\n", + " \"ground-truth/Tumor-Core\": {\n", + " \"mask_data\": sample_label[0, :, :, slice_idx],\n", + " \"class_labels\": {0: \"background\", 1: \"Tumor Core\"},\n", + " },\n", + " \"prediction/Tumor-Core\": {\n", + " \"mask_data\": predicted_label[0, :, :, slice_idx] * 2,\n", + " \"class_labels\": {0: \"background\", 2: \"Tumor Core\"},\n", + " },\n", + " },\n", + " ),\n", + " wandb.Image(\n", + " sample_image[channel_idx, :, :, slice_idx],\n", + " masks={\n", + " \"ground-truth/Whole-Tumor\": {\n", + " \"mask_data\": sample_label[1, :, :, slice_idx],\n", + " \"class_labels\": {0: \"background\", 1: \"Whole Tumor\"},\n", + " },\n", + " \"prediction/Whole-Tumor\": {\n", + " \"mask_data\": predicted_label[1, :, :, slice_idx] * 2,\n", + " \"class_labels\": {0: \"background\", 2: \"Whole Tumor\"},\n", + " },\n", + " },\n", + " ),\n", + " wandb.Image(\n", + " sample_image[channel_idx, :, :, slice_idx],\n", + " masks={\n", + " \"ground-truth/Enhancing-Tumor\": {\n", + " \"mask_data\": sample_label[2, :, :, slice_idx],\n", + " \"class_labels\": {0: \"background\", 1: \"Enhancing Tumor\"},\n", + " },\n", + " \"prediction/Enhancing-Tumor\": {\n", + " \"mask_data\": predicted_label[2, :, :, slice_idx] * 2,\n", + " \"class_labels\": {0: \"background\", 2: \"Enhancing Tumor\"},\n", + " },\n", + " },\n", + " ),\n", + " ]\n", + " table.add_data(split, data_idx, slice_idx, *wandb_images)\n", + " progress_bar.update(1)\n", + " return table" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# create the prediction table\n", + "prediction_table = wandb.Table(\n", + " columns=[\n", + " 
\"Split\",\n", + " \"Data Index\",\n", + " \"Slice Index\",\n", + " \"Image-Channel-0/Tumor-Core\",\n", + " \"Image-Channel-1/Tumor-Core\",\n", + " \"Image-Channel-2/Tumor-Core\",\n", + " \"Image-Channel-3/Tumor-Core\",\n", + " \"Image-Channel-0/Whole-Tumor\",\n", + " \"Image-Channel-1/Whole-Tumor\",\n", + " \"Image-Channel-2/Whole-Tumor\",\n", + " \"Image-Channel-3/Whole-Tumor\",\n", + " \"Image-Channel-0/Enhancing-Tumor\",\n", + " \"Image-Channel-1/Enhancing-Tumor\",\n", + " \"Image-Channel-2/Enhancing-Tumor\",\n", + " \"Image-Channel-3/Enhancing-Tumor\",\n", + " ]\n", + ")\n", + "\n", + "# Perform inference and visualization\n", + "with torch.no_grad():\n", + " config.max_prediction_images_visualized\n", + " max_samples = (\n", + " min(config.max_prediction_images_visualized, len(val_dataset))\n", + " if config.max_prediction_images_visualized > 0\n", + " else len(val_dataset)\n", + " )\n", + " progress_bar = tqdm(\n", + " enumerate(val_dataset[:max_samples]),\n", + " total=max_samples,\n", + " desc=\"Generating Predictions:\",\n", + " )\n", + " for data_idx, sample in progress_bar:\n", + " val_input = sample[\"image\"].unsqueeze(0).to(device)\n", + " val_output = inference(model, val_input)\n", + " val_output = post_trans(val_output[0])\n", + " prediction_table = log_predictions_into_tables(\n", + " sample_image=sample[\"image\"].cpu().numpy(),\n", + " sample_label=sample[\"label\"].cpu().numpy(),\n", + " predicted_label=val_output.cpu().numpy(),\n", + " data_idx=data_idx,\n", + " split=\"validation\",\n", + " table=prediction_table,\n", + " )\n", + "\n", + " wandb.log({\"Predictions/Tumor-Segmentation-Data\": prediction_table})\n", + "\n", + "\n", + "# End the experiment\n", + "wandb.finish()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let us see how we can analyze and compare the predicted segmentation masks and the ground-truth labels for each class using the interactive segmentation mask overlay.\n", + "\n", + "![](./assets/viz-6.gif)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can also check out the report [Brain Tumor Segmentation using MONAI and WandB](https://wandb.ai/geekyrakshit/brain-tumor-segmentation/reports/Brain-Tumor-Segmentation-using-MONAI-and-WandB---Vmlldzo0MjUzODIw) for more details regarding training a brain-tumor segmentation model using MONAI and W&B." 
+ ] + } + ], + "metadata": { + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/colabs/monai/assets/viz-1.gif b/colabs/monai/assets/viz-1.gif new file mode 100644 index 00000000..3eb97b47 Binary files /dev/null and b/colabs/monai/assets/viz-1.gif differ diff --git a/colabs/monai/assets/viz-2.gif b/colabs/monai/assets/viz-2.gif new file mode 100644 index 00000000..45ebb1db Binary files /dev/null and b/colabs/monai/assets/viz-2.gif differ diff --git a/colabs/monai/assets/viz-3.gif b/colabs/monai/assets/viz-3.gif new file mode 100644 index 00000000..df8d3954 Binary files /dev/null and b/colabs/monai/assets/viz-3.gif differ diff --git a/colabs/monai/assets/viz-4.gif b/colabs/monai/assets/viz-4.gif new file mode 100644 index 00000000..6cdea3b7 Binary files /dev/null and b/colabs/monai/assets/viz-4.gif differ diff --git a/colabs/monai/assets/viz-5.gif b/colabs/monai/assets/viz-5.gif new file mode 100644 index 00000000..2a0be82a Binary files /dev/null and b/colabs/monai/assets/viz-5.gif differ diff --git a/colabs/monai/assets/viz-6.gif b/colabs/monai/assets/viz-6.gif new file mode 100644 index 00000000..11eaa5ba Binary files /dev/null and b/colabs/monai/assets/viz-6.gif differ