update: notebook
soumik12345 committed Jan 16, 2024
1 parent 198c65c commit 4de4cf2
Showing 1 changed file with 83 additions and 19 deletions.
102 changes: 83 additions & 19 deletions colabs/monai/3d_brain_tumor_segmentation.ipynb
@@ -545,13 +545,22 @@
"## 🤖 Creating the Model, Loss, and Optimizer"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this tutorial we will be training a `SegResNet` model based on the paper [3D MRI brain tumor segmentation using autoencoder regularization](https://arxiv.org/pdf/1810.11654.pdf). We create the `SegResNet` model that comes implemented as a PyTorch Module as part of the `monai.networks` API. We also create our optimizer and learning rate scheduler."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"device = torch.device(\"cuda:0\")\n",
"\n",
"# create model\n",
"model = SegResNet(\n",
" blocks_down=[1, 2, 2, 4],\n",
" blocks_up=[1, 1, 1],\n",
@@ -560,26 +569,46 @@
" out_channels=3,\n",
" dropout_prob=0.2,\n",
").to(device)\n",
"loss_function = DiceLoss(\n",
" smooth_nr=config.dice_loss_smoothen_numerator,\n",
" smooth_dr=config.dice_loss_smoothen_denominator,\n",
" squared_pred=config.dice_loss_squared_prediction,\n",
" to_onehot_y=config.dice_loss_target_onehot,\n",
" sigmoid=config.dice_loss_apply_sigmoid,\n",
")\n",
"\n",
"# create optimizer\n",
"optimizer = torch.optim.Adam(\n",
" model.parameters(),\n",
" config.initial_learning_rate,\n",
" weight_decay=config.weight_decay,\n",
")\n",
"\n",
"# create learning rate scheduler\n",
"lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(\n",
" optimizer, T_max=config.max_train_epochs\n",
")"
]
},
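{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an optional sanity check, we can run a dummy forward pass to confirm the model's output shape. The cell below is a minimal sketch, assuming the standard 4-channel BraTS input (one channel per MRI modality) and a small spatial size divisible by the network's downsampling factor; adjust both to match your data pipeline."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional sanity check (sketch): verify the model's output shape.\n",
"# The 4 input channels and the 64^3 spatial size are assumptions;\n",
"# adjust them to match your dataset and transforms.\n",
"with torch.no_grad():\n",
" dummy_input = torch.randn(1, 4, 64, 64, 64).to(device)\n",
" dummy_output = model(dummy_input)\n",
"print(dummy_output.shape) # expected: torch.Size([1, 3, 64, 64, 64])"
]
},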
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We define our loss as multi-label `DiceLoss` using the `monai.losses` API and the corresponding dice metrics using the `monai.metrics` API."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"loss_function = DiceLoss(\n",
" smooth_nr=config.dice_loss_smoothen_numerator,\n",
" smooth_dr=config.dice_loss_smoothen_denominator,\n",
" squared_pred=config.dice_loss_squared_prediction,\n",
" to_onehot_y=config.dice_loss_target_onehot,\n",
" sigmoid=config.dice_loss_apply_sigmoid,\n",
")\n",
"\n",
"dice_metric = DiceMetric(include_background=True, reduction=\"mean\")\n",
"dice_metric_batch = DiceMetric(include_background=True, reduction=\"mean_batch\")\n",
"post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])\n",
"\n",
"# use automatic mixed-precision to accelerate training\n",
"scaler = torch.cuda.amp.GradScaler()\n",
"torch.backends.cudnn.benchmark = True"
]
@@ -604,32 +633,58 @@
" return _compute(input)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 🚝 Training and Validation\n",
"\n",
"Before we start training, let us define some metric properties which will later be logged with `wandb.log()` for tracking our training and validation experiments."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"batch_step = 0\n",
"validation_step = 0\n",
"metric_values = []\n",
"metric_values_tumor_core = []\n",
"metric_values_whole_tumor = []\n",
"metric_values_enhanced_tumor = []\n",
"wandb.define_metric(\"epoch/epoch_step\")\n",
"wandb.define_metric(\"epoch/*\", step_metric=\"epoch/epoch_step\")\n",
"wandb.define_metric(\"batch/batch_step\")\n",
"wandb.define_metric(\"batch/*\", step_metric=\"batch/batch_step\")\n",
"wandb.define_metric(\"validation/validation_step\")\n",
"wandb.define_metric(\"validation/*\", step_metric=\"validation/validation_step\")\n",
"\n",
"batch_step = 0\n",
"validation_step = 0\n",
"metric_values = []\n",
"metric_values_tumor_core = []\n",
"metric_values_whole_tumor = []\n",
"metric_values_enhanced_tumor = []"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 🍭 Execute Standard PyTorch Training Loop"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"epoch_progress_bar = tqdm(range(config.max_train_epochs), desc=\"Training:\")\n",
"for epoch in epoch_progress_bar:\n",
" model.train()\n",
" epoch_loss = 0\n",
"\n",
" total_batch_steps = len(train_dataset) // train_loader.batch_size\n",
" batch_progress_bar = tqdm(train_loader, total=total_batch_steps, leave=False)\n",
" \n",
" # Training Step\n",
" for batch_data in batch_progress_bar:\n",
" inputs, labels = (\n",
" batch_data[\"image\"].to(device),\n",
@@ -644,11 +699,13 @@
" scaler.update()\n",
" epoch_loss += loss.item()\n",
" batch_progress_bar.set_description(f\"train_loss: {loss.item():.4f}:\")\n",
" ## Log batch-wise training loss to W&B\n",
" wandb.log({\"batch/batch_step\": batch_step, \"batch/train_loss\": loss.item()})\n",
" batch_step += 1\n",
"\n",
" lr_scheduler.step()\n",
" epoch_loss /= total_batch_steps\n",
" ## Log batch-wise training loss and learning rate to W&B\n",
" wandb.log(\n",
" {\n",
" \"epoch/epoch_step\": epoch,\n",
@@ -658,6 +715,7 @@
" )\n",
" epoch_progress_bar.set_description(f\"Training: train_loss: {epoch_loss:.4f}:\")\n",
"\n",
" # Validation and model checkpointing\n",
" if (epoch + 1) % config.validation_intervals == 0:\n",
" model.eval()\n",
" with torch.no_grad():\n",
@@ -681,28 +739,34 @@
"\n",
" checkpoint_path = os.path.join(config.checkpoint_dir, \"model.pth\")\n",
" torch.save(model.state_dict(), checkpoint_path)\n",
" \n",
" # Log and versison model checkpoints using W&B artifacts.\n",
" artifact = wandb.Artifact(\n",
" name=f\"{wandb.run.id}-checkpoint\", type=\"model\"\n",
" )\n",
" artifact.add_file(local_path=checkpoint_path)\n",
" wandb.log_artifact(artifact, aliases=[f\"epoch_{epoch}\"])\n",
"\n",
" # Log validation metrics to W&B dashboard.\n",
" wandb.log(\n",
" {\n",
" \"validation/validation_step\": validation_step,\n",
" \"validation/mean_dice\": metric_values[-1],\n",
" \"validation/mean_dice_tumor_core\": metric_values_tumor_core[-1],\n",
" \"validation/mean_dice_whole_tumor\": metric_values_whole_tumor[\n",
" -1\n",
" ],\n",
" \"validation/mean_dice_enhanced_tumor\": metric_values_enhanced_tumor[\n",
" -1\n",
" ],\n",
" \"validation/mean_dice_whole_tumor\": metric_values_whole_tumor[-1],\n",
" \"validation/mean_dice_enhanced_tumor\": metric_values_enhanced_tumor[-1],\n",
" }\n",
" )\n",
" validation_step += 1"
]
},
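{
"cell_type": "markdown",
"metadata": {},
"source": [
"Because every checkpoint is logged as a versioned W&B artifact, we can restore the weights from any epoch later. The cell below is a minimal sketch, assuming an active W&B run and the artifact naming convention used in the training loop above; swap the `latest` alias for e.g. `epoch_99` to fetch a specific epoch."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: restore a versioned checkpoint from W&B artifacts.\n",
"# Assumes an active run and the naming convention used above;\n",
"# replace the \"latest\" alias with e.g. \"epoch_99\" for a specific epoch.\n",
"artifact = wandb.use_artifact(f\"{wandb.run.id}-checkpoint:latest\", type=\"model\")\n",
"artifact_dir = artifact.download()\n",
"model.load_state_dict(torch.load(os.path.join(artifact_dir, \"model.pth\")))"
]
},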
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 🔱 Inferece"
]
},
{
"cell_type": "code",
"execution_count": null,