Skip to content

Commit

Permalink
Merge branch 'changes_2024' of github.com:dlmbl/image_restoration int…
Browse files Browse the repository at this point in the history
…o changes_2024
  • Loading branch information
Cateek committed Aug 22, 2024
2 parents 3945ebe + 7e7b76d commit e8948fb
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 94 deletions.
32 changes: 9 additions & 23 deletions 03_COSDD/exercise.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
"outputs": [],
"source": [
"import os\n",
"import logging\n",
"\n",
"import torch\n",
"import tifffile\n",
Expand Down Expand Up @@ -307,8 +306,8 @@
"metadata": {},
"outputs": [],
"source": [
"real_batch_size = 16\n",
"n_grad_batches = 1\n",
"real_batch_size = 4\n",
"n_grad_batches = 4\n",
"print(f\"Effective batch size: {real_batch_size * n_grad_batches}\")\n",
"crop_size = (256, 256)\n",
"train_split = 0.9\n",
Expand Down Expand Up @@ -408,9 +407,9 @@
"outputs": [],
"source": [
"dimensions = ... ### Insert a value here\n",
"s_code_channels = 16\n",
"s_code_channels = 32\n",
"\n",
"n_layers = 4\n",
"n_layers = 6\n",
"z_dims = [s_code_channels // 2] * n_layers\n",
"downsampling = [1] * n_layers\n",
"lvae = LadderVAE(\n",
Expand All @@ -428,8 +427,8 @@
" s_code_channels=s_code_channels,\n",
" kernel_size=5,\n",
" noise_direction=... ### Insert a value here\n",
" n_filters=16,\n",
" n_layers=3,\n",
" n_filters=32,\n",
" n_layers=4,\n",
" n_gaussians=4,\n",
" dimensions=dimensions,\n",
")\n",
Expand Down Expand Up @@ -462,7 +461,7 @@
" data_mean=low_snr.mean(),\n",
" data_std=low_snr.std(),\n",
" n_grad_batches=n_grad_batches,\n",
" checkpointed=False,\n",
" checkpointed=True,\n",
")"
]
},
Expand Down Expand Up @@ -579,7 +578,6 @@
" max_time=max_time, # Remove this time limit to train the model fully\n",
" log_every_n_steps=len(train_set) // (n_grad_batches * real_batch_size),\n",
" callbacks=[EarlyStopping(patience=patience, monitor=\"val/elbo\")],\n",
" precision=\"bf16-mixed\",\n",
")"
]
},
Expand Down Expand Up @@ -613,22 +611,12 @@
"# Exercise 2. Inference with COSDD"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"logger = logging.getLogger('pytorch_lightning')\n",
"logger.setLevel(logging.WARNING)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 2.1. Load test data\n",
"The images that we want to denoise are loaded here. These are the same that we used for training, but we'll only load 2 to speed up inference."
"The images that we want to denoise are loaded here. These are the same that we used for training, but we'll only load 10 to speed up inference."
]
},
{
Expand All @@ -638,7 +626,7 @@
"outputs": [],
"source": [
"lowsnr_path = \"./../data/mito-confocal-lowsnr.tif\"\n",
"n_test_images = 2\n",
"n_test_images = 10\n",
"# load the data\n",
"test_set = tifffile.imread(lowsnr_path)\n",
"test_set = test_set[:n_test_images, np.newaxis]\n",
Expand Down Expand Up @@ -724,7 +712,6 @@
" enable_progress_bar=False,\n",
" enable_checkpointing=False,\n",
" logger=False,\n",
" precision=\"bf16-mixed\",\n",
")"
]
},
Expand All @@ -747,7 +734,6 @@
" enable_progress_bar=False,\n",
" enable_checkpointing=False,\n",
" logger=False,\n",
" precision=\"bf16-mixed\",\n",
")"
]
},
Expand Down
95 changes: 25 additions & 70 deletions 03_COSDD/solution.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,11 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import logging\n",
"\n",
"import torch\n",
"import tifffile\n",
Expand All @@ -53,7 +52,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -120,7 +119,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {
"tags": [
"solution"
Expand Down Expand Up @@ -359,20 +358,12 @@
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Effective batch size: 16\n"
]
}
],
"source": [
"real_batch_size = 16\n",
"n_grad_batches = 1\n",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"real_batch_size = 4\n",
"n_grad_batches = 4\n",
"print(f\"Effective batch size: {real_batch_size * n_grad_batches}\")\n",
"crop_size = (256, 256)\n",
"train_split = 0.9\n",
Expand Down Expand Up @@ -472,9 +463,9 @@
"outputs": [],
"source": [
"dimensions = ... ### Insert a value here\n",
"s_code_channels = 16\n",
"s_code_channels = 32\n",
"\n",
"n_layers = 4\n",
"n_layers = 6\n",
"z_dims = [s_code_channels // 2] * n_layers\n",
"downsampling = [1] * n_layers\n",
"lvae = LadderVAE(\n",
Expand All @@ -492,8 +483,8 @@
" s_code_channels=s_code_channels,\n",
" kernel_size=5,\n",
" noise_direction=... ### Insert a value here\n",
" n_filters=16,\n",
" n_layers=3,\n",
" n_filters=32,\n",
" n_layers=4,\n",
" n_gaussians=4,\n",
" dimensions=dimensions,\n",
")\n",
Expand Down Expand Up @@ -526,35 +517,24 @@
" data_mean=low_snr.mean(),\n",
" data_std=low_snr.std(),\n",
" n_grad_batches=n_grad_batches,\n",
" checkpointed=False,\n",
" checkpointed=True,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"metadata": {
"tags": [
"solution"
]
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/TA-bs/miniforge3/envs/05_image_restoration/lib/python3.10/site-packages/pytorch_lightning/utilities/parsing.py:208: Attribute 'vae' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['vae'])`.\n",
"/home/TA-bs/miniforge3/envs/05_image_restoration/lib/python3.10/site-packages/pytorch_lightning/utilities/parsing.py:208: Attribute 'ar_decoder' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['ar_decoder'])`.\n",
"/home/TA-bs/miniforge3/envs/05_image_restoration/lib/python3.10/site-packages/pytorch_lightning/utilities/parsing.py:208: Attribute 's_decoder' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['s_decoder'])`.\n",
"/home/TA-bs/miniforge3/envs/05_image_restoration/lib/python3.10/site-packages/pytorch_lightning/utilities/parsing.py:208: Attribute 'direct_denoiser' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['direct_denoiser'])`.\n"
]
}
],
"outputs": [],
"source": [
"dimensions = 2 ### Insert a value here\n",
"s_code_channels = 16\n",
"s_code_channels = 32\n",
"\n",
"n_layers = 4\n",
"n_layers = 6\n",
"z_dims = [s_code_channels // 2] * n_layers\n",
"downsampling = [1] * n_layers\n",
"lvae = LadderVAE(\n",
Expand All @@ -572,8 +552,8 @@
" s_code_channels=s_code_channels,\n",
" kernel_size=5,\n",
" noise_direction=\"x\", ### Insert a value here\n",
" n_filters=16,\n",
" n_layers=3,\n",
" n_filters=32,\n",
" n_layers=4,\n",
" n_gaussians=4,\n",
" dimensions=dimensions,\n",
")\n",
Expand Down Expand Up @@ -606,7 +586,7 @@
" data_mean=low_snr.mean(),\n",
" data_std=low_snr.std(),\n",
" n_grad_batches=n_grad_batches,\n",
" checkpointed=False,\n",
" checkpointed=True,\n",
")"
]
},
Expand Down Expand Up @@ -723,30 +703,18 @@
" max_time=max_time, # Remove this time limit to train the model fully\n",
" log_every_n_steps=len(train_set) // (n_grad_batches * real_batch_size),\n",
" callbacks=[EarlyStopping(patience=patience, monitor=\"val/elbo\")],\n",
" precision=\"bf16-mixed\",\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"metadata": {
"tags": [
"solution"
]
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using bfloat16 Automatic Mixed Precision (AMP)\n",
"GPU available: True (cuda), used: True\n",
"TPU available: False, using: 0 TPU cores\n",
"HPU available: False, using: 0 HPUs\n"
]
}
],
"outputs": [],
"source": [
"model_name = \"mito-confocal\" ### Insert a value here\n",
"checkpoint_path = os.path.join(\"checkpoints\", model_name)\n",
Expand All @@ -764,7 +732,6 @@
" max_time=max_time, # Remove this time limit to train the model fully\n",
" log_every_n_steps=len(train_set) // (n_grad_batches * real_batch_size),\n",
" callbacks=[EarlyStopping(patience=patience, monitor=\"val/elbo\")],\n",
" precision=\"bf16-mixed\",\n",
")"
]
},
Expand Down Expand Up @@ -798,22 +765,12 @@
"# Exercise 2. Inference with COSDD"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"logger = logging.getLogger('pytorch_lightning')\n",
"logger.setLevel(logging.WARNING)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 2.1. Load test data\n",
"The images that we want to denoise are loaded here. These are the same that we used for training, but we'll only load 2 to speed up inference."
"The images that we want to denoise are loaded here. These are the same that we used for training, but we'll only load 10 to speed up inference."
]
},
{
Expand All @@ -823,7 +780,7 @@
"outputs": [],
"source": [
"lowsnr_path = \"./../data/mito-confocal-lowsnr.tif\"\n",
"n_test_images = 2\n",
"n_test_images = 10\n",
"# load the data\n",
"test_set = tifffile.imread(lowsnr_path)\n",
"test_set = test_set[:n_test_images, np.newaxis]\n",
Expand Down Expand Up @@ -909,7 +866,6 @@
" enable_progress_bar=False,\n",
" enable_checkpointing=False,\n",
" logger=False,\n",
" precision=\"bf16-mixed\",\n",
")"
]
},
Expand All @@ -932,7 +888,6 @@
" enable_progress_bar=False,\n",
" enable_checkpointing=False,\n",
" logger=False,\n",
" precision=\"bf16-mixed\",\n",
")"
]
},
Expand Down
2 changes: 1 addition & 1 deletion 04_DenoiSplit/exercise.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -796,7 +796,7 @@
},
{
"cell_type": "markdown",
"id": "ccb2694d",
"id": "337b207c",
"metadata": {},
"source": [
"<hr style=\"height:2px;\"><div class=\"alert alert-block alert-success\"><h1>End of the exercise</h1>\n",
Expand Down

0 comments on commit e8948fb

Please sign in to comment.