From 4de4cf23f4a2bb5db27a8e063987d9ef9e78b97e Mon Sep 17 00:00:00 2001
From: Soumik Rakshit <19soumik.rakshit96@gmail.com>
Date: Tue, 16 Jan 2024 13:40:40 +0000
Subject: [PATCH] update: notebook

---
 .../monai/3d_brain_tumor_segmentation.ipynb | 102 ++++++++++++++----
 1 file changed, 83 insertions(+), 19 deletions(-)

diff --git a/colabs/monai/3d_brain_tumor_segmentation.ipynb b/colabs/monai/3d_brain_tumor_segmentation.ipynb
index f83bfd48..42ca93a0 100644
--- a/colabs/monai/3d_brain_tumor_segmentation.ipynb
+++ b/colabs/monai/3d_brain_tumor_segmentation.ipynb
@@ -545,6 +545,13 @@
     "## 🤖 Creating the Model, Loss, and Optimizer"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "In this tutorial, we will train a `SegResNet` model based on the paper [3D MRI brain tumor segmentation using autoencoder regularization](https://arxiv.org/pdf/1810.11654.pdf). The `SegResNet` model comes implemented as a PyTorch Module as part of the `monai.networks` API. We also create our optimizer and learning rate scheduler."
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
@@ -552,6 +559,8 @@
    "outputs": [],
    "source": [
     "device = torch.device(\"cuda:0\")\n",
+    "\n",
+    "# create model\n",
     "model = SegResNet(\n",
     "    blocks_down=[1, 2, 2, 4],\n",
     "    blocks_up=[1, 1, 1],\n",
@@ -560,26 +569,46 @@
     "    out_channels=3,\n",
     "    dropout_prob=0.2,\n",
     ").to(device)\n",
-    "loss_function = DiceLoss(\n",
-    "    smooth_nr=config.dice_loss_smoothen_numerator,\n",
-    "    smooth_dr=config.dice_loss_smoothen_denominator,\n",
-    "    squared_pred=config.dice_loss_squared_prediction,\n",
-    "    to_onehot_y=config.dice_loss_target_onehot,\n",
-    "    sigmoid=config.dice_loss_apply_sigmoid,\n",
-    ")\n",
+    "\n",
+    "# create optimizer\n",
     "optimizer = torch.optim.Adam(\n",
     "    model.parameters(),\n",
     "    config.initial_learning_rate,\n",
     "    weight_decay=config.weight_decay,\n",
     ")\n",
+    "\n",
+    "# create learning rate scheduler\n",
     "lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(\n",
     "    optimizer, T_max=config.max_train_epochs\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We define our loss as a multi-label `DiceLoss` using the `monai.losses` API and the corresponding Dice metrics using the `monai.metrics` API."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "loss_function = DiceLoss(\n",
+    "    smooth_nr=config.dice_loss_smoothen_numerator,\n",
+    "    smooth_dr=config.dice_loss_smoothen_denominator,\n",
+    "    squared_pred=config.dice_loss_squared_prediction,\n",
+    "    to_onehot_y=config.dice_loss_target_onehot,\n",
+    "    sigmoid=config.dice_loss_apply_sigmoid,\n",
     ")\n",
     "\n",
     "dice_metric = DiceMetric(include_background=True, reduction=\"mean\")\n",
     "dice_metric_batch = DiceMetric(include_background=True, reduction=\"mean_batch\")\n",
     "post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])\n",
     "\n",
+    "# use automatic mixed-precision to accelerate training\n",
     "scaler = torch.cuda.amp.GradScaler()\n",
     "torch.backends.cudnn.benchmark = True"
    ]
@@ -604,18 +633,21 @@
     "    return _compute(input)"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## 🚝 Training and Validation\n",
+    "\n",
+    "Before we start training, let us define some metric properties which will later be logged with `wandb.log()` for tracking our training and validation experiments."
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
-    "batch_step = 0\n",
-    "validation_step = 0\n",
-    "metric_values = []\n",
-    "metric_values_tumor_core = []\n",
-    "metric_values_whole_tumor = []\n",
-    "metric_values_enhanced_tumor = []\n",
     "wandb.define_metric(\"epoch/epoch_step\")\n",
     "wandb.define_metric(\"epoch/*\", step_metric=\"epoch/epoch_step\")\n",
     "wandb.define_metric(\"batch/batch_step\")\n",
@@ -623,6 +655,27 @@
     "wandb.define_metric(\"validation/validation_step\")\n",
     "wandb.define_metric(\"validation/*\", step_metric=\"validation/validation_step\")\n",
     "\n",
+    "batch_step = 0\n",
+    "validation_step = 0\n",
+    "metric_values = []\n",
+    "metric_values_tumor_core = []\n",
+    "metric_values_whole_tumor = []\n",
+    "metric_values_enhanced_tumor = []"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### 🍭 Execute the Standard PyTorch Training Loop"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "epoch_progress_bar = tqdm(range(config.max_train_epochs), desc=\"Training:\")\n",
     "for epoch in epoch_progress_bar:\n",
     "    model.train()\n",
@@ -630,6 +683,8 @@
     "    epoch_loss = 0\n",
     "\n",
     "    total_batch_steps = len(train_dataset) // train_loader.batch_size\n",
     "    batch_progress_bar = tqdm(train_loader, total=total_batch_steps, leave=False)\n",
+    "    \n",
+    "    # Training Step\n",
     "    for batch_data in batch_progress_bar:\n",
     "        inputs, labels = (\n",
     "            batch_data[\"image\"].to(device),\n",
@@ -644,11 +699,13 @@
     "        scaler.update()\n",
     "        epoch_loss += loss.item()\n",
     "        batch_progress_bar.set_description(f\"train_loss: {loss.item():.4f}:\")\n",
+    "        ## Log batch-wise training loss to W&B\n",
     "        wandb.log({\"batch/batch_step\": batch_step, \"batch/train_loss\": loss.item()})\n",
     "        batch_step += 1\n",
     "\n",
     "    lr_scheduler.step()\n",
     "    epoch_loss /= total_batch_steps\n",
+    "    ## Log epoch-wise mean training loss and learning rate to W&B\n",
     "    wandb.log(\n",
     "        {\n",
     "            \"epoch/epoch_step\": epoch,\n",
@@ -658,6 +715,7 @@
     "    )\n",
     "    epoch_progress_bar.set_description(f\"Training: train_loss: {epoch_loss:.4f}:\")\n",
     "\n",
+    "    # Validation and model checkpointing\n",
     "    if (epoch + 1) % config.validation_intervals == 0:\n",
     "        model.eval()\n",
     "        with torch.no_grad():\n",
@@ -681,28 +739,34 @@
     "\n",
     "            checkpoint_path = os.path.join(config.checkpoint_dir, \"model.pth\")\n",
     "            torch.save(model.state_dict(), checkpoint_path)\n",
+    "            \n",
+    "            # Log and version model checkpoints using W&B artifacts.\n",
     "            artifact = wandb.Artifact(\n",
     "                name=f\"{wandb.run.id}-checkpoint\", type=\"model\"\n",
     "            )\n",
     "            artifact.add_file(local_path=checkpoint_path)\n",
     "            wandb.log_artifact(artifact, aliases=[f\"epoch_{epoch}\"])\n",
     "\n",
+    "            # Log validation metrics to the W&B dashboard.\n",
     "            wandb.log(\n",
     "                {\n",
     "                    \"validation/validation_step\": validation_step,\n",
     "                    \"validation/mean_dice\": metric_values[-1],\n",
     "                    \"validation/mean_dice_tumor_core\": metric_values_tumor_core[-1],\n",
-    "                    \"validation/mean_dice_whole_tumor\": metric_values_whole_tumor[\n",
-    "                        -1\n",
-    "                    ],\n",
-    "                    \"validation/mean_dice_enhanced_tumor\": metric_values_enhanced_tumor[\n",
-    "                        -1\n",
-    "                    ],\n",
+    "                    \"validation/mean_dice_whole_tumor\": metric_values_whole_tumor[-1],\n",
+    "                    \"validation/mean_dice_enhanced_tumor\": metric_values_enhanced_tumor[-1],\n",
     "                }\n",
     "            )\n",
     "            validation_step += 1"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## 🔱 Inference"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
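Note: the notebook's `inference` helper is only partially visible in this patch (its closing `return _compute(input)` appears as hunk context, and the code cell under the new Inference heading is truncated here). For orientation, below is a minimal sketch of the usual MONAI pattern such a helper follows, built around `monai.inferers.sliding_window_inference`; the `roi_size`, `sw_batch_size`, and `overlap` values are illustrative assumptions, not values taken from this patch.

```python
import torch
from monai.inferers import sliding_window_inference


def inference(model, input):
    # Slide a fixed-size window across the whole 3D volume so that
    # full-brain inputs fit in GPU memory. roi_size, sw_batch_size,
    # and overlap are assumed example values, not read from this patch.
    def _compute(input):
        return sliding_window_inference(
            inputs=input,
            roi_size=(240, 240, 160),
            sw_batch_size=1,
            predictor=model,
            overlap=0.5,
        )

    # Mixed precision mirrors the GradScaler-based training loop above.
    with torch.cuda.amp.autocast():
        return _compute(input)
```

Sliding-window inference evaluates overlapping 3D patches and stitches the predictions back together, which keeps GPU memory bounded when segmenting whole brain volumes.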