From 3710a1e9987ec47d88c0944fc0399a067d1b6228 Mon Sep 17 00:00:00 2001 From: Anish Shah Date: Mon, 4 Dec 2023 16:25:05 -0500 Subject: [PATCH 1/7] Update Image_Classification_using_PyTorch_Lightning.ipynb Updates links and pip installs --- ...lassification_using_PyTorch_Lightning.ipynb | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/colabs/pytorch-lightning/Image_Classification_using_PyTorch_Lightning.ipynb b/colabs/pytorch-lightning/Image_Classification_using_PyTorch_Lightning.ipynb index 0b98bdfe..f8ddb4df 100644 --- a/colabs/pytorch-lightning/Image_Classification_using_PyTorch_Lightning.ipynb +++ b/colabs/pytorch-lightning/Image_Classification_using_PyTorch_Lightning.ipynb @@ -27,7 +27,7 @@ "\n", "# Image Classification using PyTorch Lightning ⚡️\n", "\n", - "We will build an image classification pipeline using PyTorch Lightning. We will follow this [style guide](https://pytorch-lightning.readthedocs.io/en/stable/starter/style_guide.html) to increase the readability and reproducibility of our code. A cool explanation of this available [here](https://wandb.ai/wandb/wandb-lightning/reports/Image-Classification-using-PyTorch-Lightning--VmlldzoyODk1NzY)." + "We will build an image classification pipeline using PyTorch Lightning. We will follow this [style guide](https://lightning.ai/docs/pytorch/stable/starter/style_guide.html) to increase the readability and reproducibility of our code. A cool explanation of this available [here](https://wandb.ai/wandb/wandb-lightning/reports/Image-Classification-using-PyTorch-Lightning--VmlldzoyODk1NzY)." 
] }, { @@ -46,7 +46,7 @@ "metadata": {}, "outputs": [], "source": [ - "!pip install pytorch-lightning -q\n", + "!pip install lightning -q\n", "# install weights and biases\n", "!pip install wandb -qU" ] @@ -65,9 +65,9 @@ "metadata": {}, "outputs": [], "source": [ - "import pytorch_lightning as pl\n", + "import lightning.pytorch as pl\n", "# your favorite machine learning tracking tool\n", - "from pytorch_lightning.loggers import WandbLogger\n", + "from lightning.pytorch.loggers import WandbLogger\n", "\n", "import torch\n", "from torch import nn\n", @@ -115,7 +115,7 @@ "- Apply transforms (rotate, tokenize, etc…).\n", "- Wrap inside a DataLoader.\n", "\n", - "Learn more about datamodules [here](https://pytorch-lightning.readthedocs.io/en/stable/extensions/datamodules.html). Let's build a datamodule for the Cifar-10 dataset. " + "Learn more about datamodules [here](https://lightning.ai/docs/pytorch/stable/data/datamodule.html). Let's build a datamodule for the Cifar-10 dataset. " ] }, { @@ -168,8 +168,8 @@ "source": [ "## 📱 Callbacks\n", "\n", - "A callback is a self-contained program that can be reused across projects. PyTorch Lightning comes with few [built-in callbacks](https://pytorch-lightning.readthedocs.io/en/latest/extensions/callbacks.html#built-in-callbacks) which are regularly used. \n", - "Learn more about callbacks in PyTorch Lightning [here](https://pytorch-lightning.readthedocs.io/en/latest/extensions/callbacks.html)." + "A callback is a self-contained program that can be reused across projects. PyTorch Lightning comes with few [built-in callbacks](https://lightning.ai/docs/pytorch/latest/extensions/callbacks.html#built-in-callbacks) which are regularly used. \n", + "Learn more about callbacks in PyTorch Lightning [here](https://lightning.ai/docs/pytorch/latest/extensions/callbacks.html)." 
] }, { @@ -179,7 +179,7 @@ "source": [ "### Built-in Callbacks\n", "\n", - "In this tutorial, we will use [Early Stopping](https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.callbacks.EarlyStopping.html#pytorch_lightning.callbacks.EarlyStopping) and [Model Checkpoint](https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.callbacks.ModelCheckpoint.html#pytorch_lightning.callbacks.ModelCheckpoint) built-in callbacks. They can be passed to the `Trainer`.\n" + "In this tutorial, we will use [Early Stopping](https://lightning.ai/docs/pytorch/latest/api/lightning.pytorch.callbacks.EarlyStopping.html#lightning.callbacks.EarlyStopping) and [Model Checkpoint](https://lightning.ai/docs/pytorch/latest/api/lightning.pytorch.callbacks.ModelCheckpoint.html#pytorch_lightning.callbacks.ModelCheckpoint) built-in callbacks. They can be passed to the `Trainer`.\n" ] }, { @@ -437,7 +437,7 @@ "I hope you find this report helpful. I will encourage to play with the code and train an image classifier with a dataset of your choice. \n", "\n", "Here are some resources to learn more about PyTorch Lightning:\n", - "- [Step-by-step walk-through](https://pytorch-lightning.readthedocs.io/en/latest/starter/introduction.html) - This is one of the official tutorials. Their documentation is really well written and I highly encourage it as a good learning resource.\n", + "- [Step-by-step walk-through](https://lightning.ai/docs/pytorch/latest/starter/introduction.html) - This is one of the official tutorials. Their documentation is really well written and I highly encourage it as a good learning resource.\n", "- [Use Pytorch Lightning with Weights & Biases](https://wandb.me/lightning) - This is a quick colab that you can run through to learn more about how to use W&B with PyTorch Lightning." 
] } From c63bb1b3e394133c08e2087a5db3c6753609a0d1 Mon Sep 17 00:00:00 2001 From: Anish Shah Date: Mon, 4 Dec 2023 17:03:46 -0500 Subject: [PATCH 2/7] Update fine tuning ptl notebook --- ..._a_Transformer_with_Pytorch_Lightning.ipynb | 18 +++++++++++++++--- ...lassification_using_PyTorch_Lightning.ipynb | 14 +++++++++++++- 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/colabs/pytorch-lightning/Fine_tuning_a_Transformer_with_Pytorch_Lightning.ipynb b/colabs/pytorch-lightning/Fine_tuning_a_Transformer_with_Pytorch_Lightning.ipynb index 9a2a1f10..0f91a0c9 100644 --- a/colabs/pytorch-lightning/Fine_tuning_a_Transformer_with_Pytorch_Lightning.ipynb +++ b/colabs/pytorch-lightning/Fine_tuning_a_Transformer_with_Pytorch_Lightning.ipynb @@ -64,7 +64,7 @@ "outputs": [], "source": [ "# Install some dependencies\n", - "!pip install pandas torch pytorch-lightning transformers==4.1.1 -q\n", + "!pip install pandas torch lightning transformers\n", "!pip install -Uq wandb" ] }, @@ -81,7 +81,7 @@ "import transformers\n", "import numpy as np\n", "import pandas as pd\n", - "import pytorch_lightning as pl" + "import lightning.pytorch as pl" ] }, { @@ -426,7 +426,7 @@ " gpus = -1 if torch.cuda.is_available() else 0\n", " \n", " # Construct a Trainer object with the W&B logger we created and epoch set by the config object\n", - " trainer = pl.Trainer(max_epochs=config.epochs, gpus=gpus, logger=logger)\n", + " trainer = pl.Trainer(max_epochs=config.epochs, logger=logger)\n", " \n", " # Build data loaders for our datasets, using the batch_size from our config object\n", " train_data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=config.batch_size)\n", @@ -536,6 +536,18 @@ "kernelspec": { "display_name": "Python 3", "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": 
"ipython3", + "version": "3.11.2" } }, "nbformat": 4, diff --git a/colabs/pytorch-lightning/Image_Classification_using_PyTorch_Lightning.ipynb b/colabs/pytorch-lightning/Image_Classification_using_PyTorch_Lightning.ipynb index f8ddb4df..82c9402b 100644 --- a/colabs/pytorch-lightning/Image_Classification_using_PyTorch_Lightning.ipynb +++ b/colabs/pytorch-lightning/Image_Classification_using_PyTorch_Lightning.ipynb @@ -46,7 +46,7 @@ "metadata": {}, "outputs": [], "source": [ - "!pip install lightning -q\n", + "!pip install lightning torchvision -q\n", "# install weights and biases\n", "!pip install wandb -qU" ] @@ -452,6 +452,18 @@ "kernelspec": { "display_name": "Python 3", "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.2" } }, "nbformat": 4, From d57448e7965b306a4bce3864176022f5d86262cc Mon Sep 17 00:00:00 2001 From: Anish Shah Date: Mon, 4 Dec 2023 17:19:19 -0500 Subject: [PATCH 3/7] update lightning imports --- ...ghtning_models_with_Weights_&_Biases.ipynb | 34 ++++++++++++------- .../Profile_PyTorch_Code.ipynb | 4 +-- ...rch_Lightning_and_Weights_and_Biases.ipynb | 6 ++-- ...fer_Learning_Using_PyTorch_Lightning.ipynb | 6 ++-- ...db_End_to_End_with_PyTorch_Lightning.ipynb | 14 ++++---- 5 files changed, 36 insertions(+), 28 deletions(-) diff --git a/colabs/pytorch-lightning/Optimize_Pytorch_Lightning_models_with_Weights_&_Biases.ipynb b/colabs/pytorch-lightning/Optimize_Pytorch_Lightning_models_with_Weights_&_Biases.ipynb index a4746eaf..4e5a3fc3 100644 --- a/colabs/pytorch-lightning/Optimize_Pytorch_Lightning_models_with_Weights_&_Biases.ipynb +++ b/colabs/pytorch-lightning/Optimize_Pytorch_Lightning_models_with_Weights_&_Biases.ipynb @@ -28,8 +28,8 @@ "Coupled with the [Weights & Biases 
integration](https://docs.wandb.com/library/integrations/lightning), you can quickly train and monitor models for full traceability and reproducibility with only 2 extra lines of code:\n", "\n", "```python\n", - "from pytorch_lightning.loggers import WandbLogger\n", - "from pytorch_lightning import Trainer\n", + "from lightning.pytorch.loggers import WandbLogger\n", + "from lightning.pytorch import Trainer\n", "\n", "wandb_logger = WandbLogger()\n", "trainer = Trainer(logger=wandb_logger)\n", @@ -64,7 +64,7 @@ "metadata": {}, "outputs": [], "source": [ - "!pip install -q pytorch-lightning wandb" + "!pip install -q lightning wandb torchvision" ] }, { @@ -150,6 +150,15 @@ "* Call self.log in `training_step` and `validation_step` to log the metrics" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import lightning.pytorch as pl" + ] + }, { "cell_type": "code", "execution_count": null, @@ -160,9 +169,8 @@ "from torch.nn import Linear, CrossEntropyLoss, functional as F\n", "from torch.optim import Adam\n", "from torchmetrics.functional import accuracy\n", - "from pytorch_lightning import LightningModule\n", "\n", - "class MNIST_LitModule(LightningModule):\n", + "class MNIST_LitModule(pl.LightningModule):\n", "\n", " def __init__(self, n_classes=10, n_layer_1=128, n_layer_2=256, lr=1e-3):\n", " '''method used to define our model parameters'''\n", @@ -273,7 +281,7 @@ "metadata": {}, "outputs": [], "source": [ - "from pytorch_lightning.callbacks import ModelCheckpoint\n", + "from lightning.pytorch.callbacks import ModelCheckpoint\n", "\n", "checkpoint_callback = ModelCheckpoint(monitor='val_accuracy', mode='max')" ] @@ -284,9 +292,9 @@ "source": [ "## 💡 Tracking Experiments with WandbLogger\n", "\n", - "PyTorch Lightning has a `WandbLogger` to easily log your experiments with Wights & Biases. Just pass it to your `Trainer` to log to W&B. 
See the [WandbLogger docs](https://pytorch-lightning.readthedocs.io/en/stable/extensions/generated/pytorch_lightning.loggers.WandbLogger.html#pytorch_lightning.loggers.WandbLogger) for all parameters. Note, to log the metrics to a specific W&B Team, pass your Team name to the `entity` argument in `WandbLogger`\n", + "PyTorch Lightning has a `WandbLogger` to easily log your experiments with Weights & Biases. Just pass it to your `Trainer` to log to W&B. See the [WandbLogger docs](https://lightning.ai/docs/pytorch/stable/extensions/generated/pytorch_lightning.loggers.WandbLogger.html#pytorch_lightning.loggers.WandbLogger) for all parameters. Note, to log the metrics to a specific W&B Team, pass your Team name to the `entity` argument in `WandbLogger`\n", "\n", - "#### `pytorch_lightning.loggers.WandbLogger()`\n", + "#### `lightning.pytorch.loggers.WandbLogger()`\n", "\n", "| Functionality | Argument/Function | PS |\n", "| ------ | ------ | ------ |\n", @@ -295,9 +303,9 @@ "| Organize runs by project | `WandbLogger(... ,project='my_project')` | |\n", "| Log histograms of gradients and parameters | `WandbLogger.watch(model)` | `WandbLogger.watch(model, log='all')` to log parameter histograms |\n", "| Log hyperparameters | Call `self.save_hyperparameters()` within `LightningModule.__init__()` |\n", - "| Log custom objects (images, audio, video, molecules…) | Use `WandbLogger.log_text`, `WandbLogger.log_image` and `WandbLogger.log_table` |\n", + "| Log custom objects (images, audio, video, molecules…) | Use `WandbLogger.log_text`, `WandbLogger.log_image` and `WandbLogger.log_table`, etc. |\n", "\n", - "See the [WandbLogger docs](https://pytorch-lightning.readthedocs.io/en/stable/extensions/generated/pytorch_lightning.loggers.WandbLogger.html#pytorch_lightning.loggers.WandbLogger) here for all parameters. 
" + "See the [WandbLogger docs](https://lightning.ai/docs/pytorch/stable/extensions/generated/pytorch_lightning.loggers.WandbLogger.html#pytorch_lightning.loggers.WandbLogger) here for all parameters. " ] }, { @@ -306,8 +314,8 @@ "metadata": {}, "outputs": [], "source": [ - "from pytorch_lightning.loggers import WandbLogger\n", - "from pytorch_lightning import Trainer\n", + "from lightning.pytorch.loggers import WandbLogger\n", + "from lightning.pytorch import Trainer\n", "\n", "wandb_logger = WandbLogger(project='MNIST', # group runs in \"MNIST\" project\n", " log_model='all') # log all new checkpoints during training" @@ -334,7 +342,7 @@ "metadata": {}, "outputs": [], "source": [ - "from pytorch_lightning.callbacks import Callback\n", + "from lightning.pytorch.callbacks import Callback\n", " \n", "class LogPredictionsCallback(Callback):\n", " \n", diff --git a/colabs/pytorch-lightning/Profile_PyTorch_Code.ipynb b/colabs/pytorch-lightning/Profile_PyTorch_Code.ipynb index b7968866..00b56a76 100644 --- a/colabs/pytorch-lightning/Profile_PyTorch_Code.ipynb +++ b/colabs/pytorch-lightning/Profile_PyTorch_Code.ipynb @@ -88,7 +88,7 @@ "metadata": {}, "outputs": [], "source": [ - "!pip install -q wandb pytorch_lightning torch_tb_profiler" + "!pip install -q wandb lightning torch_tb_profiler torchvision" ] }, { @@ -99,7 +99,7 @@ "source": [ "import glob\n", "\n", - "import pytorch_lightning as pl\n", + "import lightning.pytorch as pl\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", diff --git a/colabs/pytorch-lightning/Supercharge_your_Training_with_Pytorch_Lightning_and_Weights_and_Biases.ipynb b/colabs/pytorch-lightning/Supercharge_your_Training_with_Pytorch_Lightning_and_Weights_and_Biases.ipynb index b150820c..1b3aa0c3 100644 --- a/colabs/pytorch-lightning/Supercharge_your_Training_with_Pytorch_Lightning_and_Weights_and_Biases.ipynb +++ 
b/colabs/pytorch-lightning/Supercharge_your_Training_with_Pytorch_Lightning_and_Weights_and_Biases.ipynb @@ -81,7 +81,7 @@ "metadata": {}, "outputs": [], "source": [ - "!pip install -qqq wandb pytorch-lightning torchmetrics" + "!pip install -qqq wandb lightning torchmetrics" ] }, { @@ -142,7 +142,7 @@ "outputs": [], "source": [ "# ⚡ PyTorch Lightning\n", - "import pytorch_lightning as pl\n", + "import lightning.pytorch as pl\n", "import torchmetrics\n", "pl.seed_everything(hash(\"setting random seeds\") % 2**32 - 1)\n", "\n", @@ -150,7 +150,7 @@ "import wandb\n", "\n", "# ⚡ 🤝 🏋️‍♀️\n", - "from pytorch_lightning.loggers import WandbLogger\n" + "from lightning.pytorch.loggers import WandbLogger\n" ] }, { diff --git a/colabs/pytorch-lightning/Transfer_Learning_Using_PyTorch_Lightning.ipynb b/colabs/pytorch-lightning/Transfer_Learning_Using_PyTorch_Lightning.ipynb index 3a4a744b..776bc3b1 100644 --- a/colabs/pytorch-lightning/Transfer_Learning_Using_PyTorch_Lightning.ipynb +++ b/colabs/pytorch-lightning/Transfer_Learning_Using_PyTorch_Lightning.ipynb @@ -38,7 +38,7 @@ "metadata": {}, "outputs": [], "source": [ - "!pip install wandb pytorch-lightning -qqq" + "!pip install wandb lightning torchvision -qqq" ] }, { @@ -56,9 +56,9 @@ "source": [ "import os\n", "\n", - "import pytorch_lightning as pl\n", + "import lightning.pytorch as pl\n", "# your favorite machine learning tracking tool\n", - "from pytorch_lightning.loggers import WandbLogger\n", + "from lightning.pytorch.loggers import WandbLogger\n", "\n", "import torch\n", "from torch import nn\n", diff --git a/colabs/pytorch-lightning/Wandb_End_to_End_with_PyTorch_Lightning.ipynb b/colabs/pytorch-lightning/Wandb_End_to_End_with_PyTorch_Lightning.ipynb index 7a3b7edd..16d7a3b6 100644 --- a/colabs/pytorch-lightning/Wandb_End_to_End_with_PyTorch_Lightning.ipynb +++ b/colabs/pytorch-lightning/Wandb_End_to_End_with_PyTorch_Lightning.ipynb @@ -32,7 +32,7 @@ "metadata": {}, "outputs": [], "source": [ - "!pip install -q 
pytorch-lightning wandb" + "!pip install -q lightning wandb torchvision" ] }, { @@ -184,7 +184,7 @@ "outputs": [], "source": [ "from torchvision import transforms\n", - "import pytorch_lightning as pl\n", + "import lightning.pytorch as pl\n", "import torch\n", "from torch.utils.data import Dataset, DataLoader, random_split\n", "from skimage import io, transform\n", @@ -368,7 +368,7 @@ "from torch.nn import Linear, CrossEntropyLoss, functional as F\n", "from torch.optim import Adam\n", "from torchmetrics.functional import accuracy\n", - "from pytorch_lightning import LightningModule\n", + "from lightning.pytorch import LightningModule\n", "from torchvision import models\n", "\n", "class NatureLitModule(LightningModule):\n", @@ -468,7 +468,7 @@ "metadata": {}, "outputs": [], "source": [ - "from pytorch_lightning.callbacks import Callback\n", + "from lightning.pytorch.callbacks import Callback\n", "\n", "class LogPredictionsCallback(Callback):\n", "\n", @@ -525,9 +525,9 @@ "metadata": {}, "outputs": [], "source": [ - "from pytorch_lightning.callbacks import ModelCheckpoint\n", - "from pytorch_lightning.loggers import WandbLogger\n", - "from pytorch_lightning import Trainer\n", + "from lightning.pytorch.callbacks import ModelCheckpoint\n", + "from lightning.pytorch.loggers import WandbLogger\n", + "from lightning.pytorch import Trainer\n", "\n", "wandb.init(project=PROJECT_NAME,\n", " entity=ENTITY,\n", From 126f1349f124b39bbb94356f775481082665f306 Mon Sep 17 00:00:00 2001 From: Anish Shah Date: Wed, 6 Dec 2023 07:42:11 -0500 Subject: [PATCH 4/7] Update Optimize_Pytorch_Lightning_models_with_Weights_&_Biases.ipynb remove broken arg dataloader_idx --- ...ptimize_Pytorch_Lightning_models_with_Weights_&_Biases.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colabs/pytorch-lightning/Optimize_Pytorch_Lightning_models_with_Weights_&_Biases.ipynb b/colabs/pytorch-lightning/Optimize_Pytorch_Lightning_models_with_Weights_&_Biases.ipynb index 
4e5a3fc3..3cf40163 100644 --- a/colabs/pytorch-lightning/Optimize_Pytorch_Lightning_models_with_Weights_&_Biases.ipynb +++ b/colabs/pytorch-lightning/Optimize_Pytorch_Lightning_models_with_Weights_&_Biases.ipynb @@ -347,7 +347,7 @@ "class LogPredictionsCallback(Callback):\n", " \n", " def on_validation_batch_end(\n", - " self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):\n", + " self, trainer, pl_module, outputs, batch, batch_idx):\n", " \"\"\"Called when the validation batch ends.\"\"\"\n", " \n", " # `outputs` comes from `LightningModule.validation_step`\n", From 213e6a0c1bb585550c9e05c3488e558181918981 Mon Sep 17 00:00:00 2001 From: Anish Shah Date: Wed, 6 Dec 2023 08:03:53 -0500 Subject: [PATCH 5/7] fix broken flags and deprecated functions --- ...Fine_tuning_a_Transformer_with_Pytorch_Lightning.ipynb | 8 ++++---- colabs/pytorch-lightning/Profile_PyTorch_Code.ipynb | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/colabs/pytorch-lightning/Fine_tuning_a_Transformer_with_Pytorch_Lightning.ipynb b/colabs/pytorch-lightning/Fine_tuning_a_Transformer_with_Pytorch_Lightning.ipynb index 0f91a0c9..3ddb4919 100644 --- a/colabs/pytorch-lightning/Fine_tuning_a_Transformer_with_Pytorch_Lightning.ipynb +++ b/colabs/pytorch-lightning/Fine_tuning_a_Transformer_with_Pytorch_Lightning.ipynb @@ -245,7 +245,7 @@ " \n", " # Download the raw cola data from the 'zipfile' reference we added to the cola-raw artifact.\n", " raw_data_artifact = run.use_artifact(\"cola-raw:latest\")\n", - " zip_path = raw_data_artifact.get_path(\"zipfile\").download()\n", + " zip_path = raw_data_artifact.get_entry(\"zipfile\").download()\n", " !unzip -o $zip_path # jupyter hack to unzip data :P\n", " \n", " # Read in the raw data, log it to W&B as a wandb.Table\n", @@ -298,7 +298,7 @@ "\n", " # Download the preprocessed data\n", " pp_data_artifact = run.use_artifact(\"preprocessed-data:latest\")\n", - " data_path = pp_data_artifact.get_path(\"dataset\").download()\n", 
+ " data_path = pp_data_artifact.get_entry(\"dataset\").download()\n", " dataset = torch.load(data_path)\n", "\n", " # Calculate the number of samples to include in each set.\n", @@ -410,8 +410,8 @@ "\n", " # Load the datasets from the split-dataset artifact\n", " data = run.use_artifact(\"split-dataset:latest\")\n", - " train_dataset = torch.load(data.get_path(\"train-data\").download())\n", - " val_dataset = torch.load(data.get_path(\"validation-data\").download())\n", + " train_dataset = torch.load(data.get_entry(\"train-data\").download())\n", + " val_dataset = torch.load(data.get_entry(\"validation-data\").download())\n", "\n", " # Extract the config object associated with the run\n", " config = run.config\n", diff --git a/colabs/pytorch-lightning/Profile_PyTorch_Code.ipynb b/colabs/pytorch-lightning/Profile_PyTorch_Code.ipynb index 00b56a76..35040872 100644 --- a/colabs/pytorch-lightning/Profile_PyTorch_Code.ipynb +++ b/colabs/pytorch-lightning/Profile_PyTorch_Code.ipynb @@ -357,7 +357,7 @@ " with profiler:\n", " profiler_callback = TorchTensorboardProfilerCallback(profiler)\n", "\n", - " trainer = pl.Trainer(gpus=1, max_epochs=1, max_steps=total_steps,\n", + " trainer = pl.Trainer(max_epochs=1, max_steps=total_steps,\n", " logger=pl.loggers.WandbLogger(log_model=True, save_code=True),\n", " callbacks=[profiler_callback], precision=wandb.config.precision)\n", "\n", From 13fc7f51293eb6a1be804adf00dfe1f992d2f20a Mon Sep 17 00:00:00 2001 From: Anish Shah Date: Wed, 6 Dec 2023 08:30:47 -0500 Subject: [PATCH 6/7] Update Wandb_End_to_End_with_PyTorch_Lightning.ipynb --- ...db_End_to_End_with_PyTorch_Lightning.ipynb | 1414 +++++++++-------- 1 file changed, 735 insertions(+), 679 deletions(-) diff --git a/colabs/pytorch-lightning/Wandb_End_to_End_with_PyTorch_Lightning.ipynb b/colabs/pytorch-lightning/Wandb_End_to_End_with_PyTorch_Lightning.ipynb index 16d7a3b6..32a5df77 100644 --- a/colabs/pytorch-lightning/Wandb_End_to_End_with_PyTorch_Lightning.ipynb +++ 
b/colabs/pytorch-lightning/Wandb_End_to_End_with_PyTorch_Lightning.ipynb @@ -1,681 +1,737 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "\"Open\n", - "" - ] + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "Ws0nlkuGOpDy" + }, + "source": [ + "\"Open\n", + "" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "aU1q92uCOpD1" + }, + "source": [ + "\"Weights\n", + "\n", + "\n", + "\n", + "# W&B Tutorial with Pytorch Lightning" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "mHCKzEzSOpD1" + }, + "source": [ + "## 🛠️ Install `wandb` and `pytorch-lightning`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "RnyGWvwDOpD1" + }, + "outputs": [], + "source": [ + "!pip install -q lightning wandb torchvision" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "G5wpAEAoOpD2" + }, + "source": [ + "## Login to W&B either through Python or CLI\n", + "If you are using the public W&B cloud, you don't need to specify the `WANDB_HOST`.\n", + "\n", + "You can set environment variables `WANDB_API_KEY` and `WANDB_HOST` and pass them in as:\n", + "```\n", + "import os\n", + "import wandb\n", + "\n", + "wandb.login(host=os.getenv(\"WANDB_HOST\"), key=os.getenv(\"WANDB_API_KEY\"))\n", + "```\n", + "You can also login via the CLI with:\n", + "```\n", + "wandb login --host \n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "YjXGiBLQOpD2" + }, + "outputs": [], + "source": [ + "import wandb" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "r5HflvclOpD2" + }, + "outputs": [], + "source": [ + "wandb.login()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "zNuxn8BjOpD2" + }, + "source": [ + "## ⚱ Logging the Raw Training Data as an Artifact" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "X9mkr942OpD2" + }, + "outputs": 
[], + "source": [ + "#@title Enter your W&B project and entity\n", + "\n", + "# FORM VARIABLES\n", + "PROJECT_NAME = \"pytorch-lightning-e2e\" #@param {type:\"string\"}\n", + "ENTITY = \"wandb\"#@param {type:\"string\"}\n", + "\n", + "# set SIZE to \"TINY\", \"SMALL\", \"MEDIUM\", or \"LARGE\"\n", + "# to select one of these four datasets\n", + "# TINY dataset: 100 images, 30MB\n", + "# SMALL dataset: 1000 images, 312MB\n", + "# MEDIUM dataset: 5000 images, 1.5GB\n", + "# LARGE dataset: 12,000 images, 3.6GB\n", + "\n", + "SIZE = \"TINY\"\n", + "\n", + "if SIZE == \"TINY\":\n", + " src_url = \"https://storage.googleapis.com/wandb_datasets/nature_100.zip\"\n", + " src_zip = \"nature_100.zip\"\n", + " DATA_SRC = \"nature_100\"\n", + " IMAGES_PER_LABEL = 10\n", + " BALANCED_SPLITS = {\"train\" : 8, \"val\" : 1, \"test\": 1}\n", + "elif SIZE == \"SMALL\":\n", + " src_url = \"https://storage.googleapis.com/wandb_datasets/nature_1K.zip\"\n", + " src_zip = \"nature_1K.zip\"\n", + " DATA_SRC = \"nature_1K\"\n", + " IMAGES_PER_LABEL = 100\n", + " BALANCED_SPLITS = {\"train\" : 80, \"val\" : 10, \"test\": 10}\n", + "elif SIZE == \"MEDIUM\":\n", + " src_url = \"https://storage.googleapis.com/wandb_datasets/nature_12K.zip\"\n", + " src_zip = \"nature_12K.zip\"\n", + " DATA_SRC = \"inaturalist_12K/train\" # (technically a subset of only 10K images)\n", + " IMAGES_PER_LABEL = 500\n", + " BALANCED_SPLITS = {\"train\" : 400, \"val\" : 50, \"test\": 50}\n", + "elif SIZE == \"LARGE\":\n", + " src_url = \"https://storage.googleapis.com/wandb_datasets/nature_12K.zip\"\n", + " src_zip = \"nature_12K.zip\"\n", + " DATA_SRC = \"inaturalist_12K/train\" # (technically a subset of only 10K images)\n", + " IMAGES_PER_LABEL = 1000\n", + " BALANCED_SPLITS = {\"train\" : 800, \"val\" : 100, \"test\": 100}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "o8nApIhdOpD3" + }, + "outputs": [], + "source": [ + "%%capture\n", + "!curl -SL $src_url > $src_zip\n", 
+ "!unzip $src_zip" + ] + }, + { + "cell_type": "code", + "source": [ + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")" + ], + "metadata": { + "id": "ALmdQ7wISLaA" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "XQ_Kwsg9OpD3" + }, + "outputs": [], + "source": [ + "import wandb\n", + "import pandas as pd\n", + "import os\n", + "\n", + "with wandb.init(project=PROJECT_NAME, entity=ENTITY, job_type='log_datasets') as run:\n", + " img_paths = []\n", + " for root, dirs, files in os.walk('nature_100', topdown=False):\n", + " for name in files:\n", + " img_path = os.path.join(root, name)\n", + " label = img_path.split('/')[1]\n", + " img_paths.append([img_path, label])\n", + "\n", + " index_df = pd.DataFrame(columns=['image_path', 'label'], data=img_paths)\n", + " index_df.to_csv('index.csv', index=False)\n", + "\n", + " train_art = wandb.Artifact(name='Nature_100', type='raw_images', description='nature image dataset with 10 classes, 10 images per class')\n", + " train_art.add_dir('nature_100')\n", + "\n", + " # Also adding a csv indicating the labels of each image\n", + " train_art.add_file('index.csv')\n", + " wandb.log_artifact(train_art)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "1eJpOk_VOpD3" + }, + "source": [ + "## Using Artifacts in Pytorch Lightning `DataModule`'s and Pytorch `Dataset`'s\n", + "- Makes it easy to interopt your DataLoaders with new versions of datasets\n", + "- Just indicate the `name:alias` as an argument to your `Dataset` or `DataModule`\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Z2g9JRrwOpD3" + }, + "outputs": [], + "source": [ + "from torchvision import transforms\n", + "import lightning.pytorch as pl\n", + "import torch\n", + "from torch.utils.data import Dataset, DataLoader, random_split\n", + "from skimage import io, transform\n", + "from torchvision import 
transforms, utils, models\n", + "import math\n", + "\n", + "class NatureDataset(Dataset):\n", + " def __init__(self,\n", + " wandb_run,\n", + " artifact_name_alias=\"Nature_100:latest\",\n", + " local_target_dir=\"Nature_100:latest\",\n", + " transform=None):\n", + " self.local_target_dir = local_target_dir\n", + " self.transform = transform\n", + "\n", + " # Pull down the artifact locally to load it into memory\n", + " art = wandb_run.use_artifact(artifact_name_alias)\n", + " path_at = art.download(root=self.local_target_dir)\n", + "\n", + " self.ref_df = pd.read_csv(os.path.join(self.local_target_dir, 'index.csv'))\n", + " self.class_names = self.ref_df.iloc[:, 1].unique().tolist()\n", + " self.idx_to_class = {k: v for k, v in enumerate(self.class_names)}\n", + " self.class_to_idx = {v: k for k, v in enumerate(self.class_names)}\n", + "\n", + " def __len__(self):\n", + " return len(self.ref_df)\n", + "\n", + " def __getitem__(self, idx):\n", + " if torch.is_tensor(idx):\n", + " idx = idx.tolist()\n", + "\n", + " img_path = self.ref_df.iloc[idx, 0]\n", + "\n", + " image = io.imread(img_path)\n", + " label = self.ref_df.iloc[idx, 1]\n", + " label = torch.tensor(self.class_to_idx[label], dtype=torch.long)\n", + "\n", + " if self.transform:\n", + " image = self.transform(image)\n", + "\n", + " return image, label\n", + "\n", + "\n", + "class NatureDatasetModule(pl.LightningDataModule):\n", + " def __init__(self,\n", + " wandb_run,\n", + " artifact_name_alias: str = \"Nature_100:latest\",\n", + " local_target_dir: str = \"Nature_100:latest\",\n", + " batch_size: int = 16,\n", + " input_size: int = 224,\n", + " seed: int = 42):\n", + " super().__init__()\n", + " self.wandb_run = wandb_run\n", + " self.artifact_name_alias = artifact_name_alias\n", + " self.local_target_dir = local_target_dir\n", + " self.batch_size = batch_size\n", + " self.input_size = input_size\n", + " self.seed = seed\n", + "\n", + " def setup(self, stage=None):\n", + " self.nature_dataset = 
NatureDataset(wandb_run=self.wandb_run,\n", + " artifact_name_alias=self.artifact_name_alias,\n", + " local_target_dir=self.local_target_dir,\n", + " transform=transforms.Compose([transforms.ToTensor(),\n", + " transforms.CenterCrop(self.input_size),\n", + " transforms.Normalize((0.485, 0.456, 0.406),\n", + " (0.229, 0.224, 0.225))]))\n", + "\n", + " nature_length = len(self.nature_dataset)\n", + " train_size = math.floor(0.8 * nature_length)\n", + " val_size = nature_length - train_size\n", + " self.nature_train, self.nature_val = random_split(self.nature_dataset,\n", + " [train_size, val_size],\n", + " generator=torch.Generator().manual_seed(self.seed))\n", + " return self\n", + "\n", + " def train_dataloader(self):\n", + " return DataLoader(self.nature_train, batch_size=self.batch_size)\n", + "\n", + " def val_dataloader(self):\n", + " return DataLoader(self.nature_val, batch_size=self.batch_size)\n", + "\n", + " def predict_dataloader(self):\n", + " pass\n", + "\n", + " def teardown(self, stage: str):\n", + " pass" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "NLhjdvF-OpD3" + }, + "source": [ + "## How Logging in your Pytorch `LightningModule` works:\n", + "When you train the model using `Trainer`, ensure you have a `WandbLogger` instantiated and passed in as a `logger`.\n", + "\n", + "```\n", + "wandb_logger = WandbLogger(project=\"my_project\", entity=\"machine-learning\")\n", + "trainer = Trainer(logger=wandb_logger)\n", + "```\n", + "\n", + "\n", + "You can always use `wandb.log` as normal throughout the module. 
When the `WandbLogger` is used, `self.log` will also log metrics to W&B.\n", + "- To access the current run from within the `LightningModule`, you can access `Trainer.logger.experiment`, which is a `wandb.Run` object" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Oi1NpQs7OpD4" + }, + "source": [ + "### Some helper functions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "HLgYji2oOpD4" + }, + "outputs": [], + "source": [ + "# Some helper functions\n", + "\n", + "def set_parameter_requires_grad(model, feature_extracting):\n", + " if feature_extracting:\n", + " for param in model.parameters():\n", + " param.requires_grad = False\n", + "\n", + "def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True):\n", + " # Initialize these variables which will be set in this if statement. Each of these\n", + " # variables is model specific.\n", + " model_ft = None\n", + " input_size = 0\n", + "\n", + " if model_name == \"resnet\":\n", + " \"\"\" Resnet18\n", + " \"\"\"\n", + " model_ft = models.resnet18(pretrained=use_pretrained)\n", + " set_parameter_requires_grad(model_ft, feature_extract)\n", + " num_ftrs = model_ft.fc.in_features\n", + " model_ft.fc = torch.nn.Linear(num_ftrs, num_classes)\n", + " input_size = 224\n", + "\n", + " elif model_name == \"squeezenet\":\n", + " \"\"\" Squeezenet\n", + " \"\"\"\n", + " model_ft = models.squeezenet1_0(pretrained=use_pretrained)\n", + " set_parameter_requires_grad(model_ft, feature_extract)\n", + " model_ft.classifier[1] = torch.nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1))\n", + " model_ft.num_classes = num_classes\n", + " input_size = 224\n", + "\n", + " elif model_name == \"densenet\":\n", + " \"\"\" Densenet\n", + " \"\"\"\n", + " model_ft = models.densenet121(pretrained=use_pretrained)\n", + " set_parameter_requires_grad(model_ft, feature_extract)\n", + " num_ftrs = model_ft.classifier.in_features\n", + " model_ft.classifier = 
torch.nn.Linear(num_ftrs, num_classes)\n", + " input_size = 224\n", + "\n", + " else:\n", + " # Raise instead of calling exit(), which would kill the notebook kernel\n", + " raise ValueError(f\"Invalid model name: {model_name}\")\n", + "\n", + " return model_ft, input_size" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "LLRK0S6oOpD4" + }, + "source": [ + "### Writing the `LightningModule`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "k0tTtK5zOpD4" + }, + "outputs": [], + "source": [ + "import torch\n", + "from torch.nn import Linear, CrossEntropyLoss, functional as F\n", + "from torch.optim import Adam\n", + "from torchmetrics.functional import accuracy\n", + "from lightning.pytorch import LightningModule\n", + "from torchvision import models\n", + "\n", + "class NatureLitModule(LightningModule):\n", + " def __init__(self,\n", + " model_name,\n", + " num_classes=10,\n", + " feature_extract=True,\n", + " lr=0.01):\n", + " '''method used to define our model parameters'''\n", + " super().__init__()\n", + "\n", + " self.model_name = model_name\n", + " self.num_classes = num_classes\n", + " self.feature_extract = feature_extract\n", + " self.model, self.input_size = initialize_model(model_name=self.model_name,\n", + " num_classes=self.num_classes,\n", + " feature_extract=self.feature_extract)\n", + "\n", + " # loss\n", + " self.loss = CrossEntropyLoss()\n", + "\n", + " # optimizer parameters\n", + " self.lr = lr\n", + "\n", + " # save hyper-parameters to self.hparams (auto-logged by W&B)\n", + " self.save_hyperparameters()\n", + "\n", + " # Record the gradients of all the layers\n", + " wandb.watch(self.model)\n", + "\n", + " def forward(self, x):\n", + " '''method used for inference input -> output'''\n", + " x = self.model(x)\n", + "\n", + " return x\n", + "\n", + " def training_step(self, batch, batch_idx):\n", + " '''needs to return a loss from a single batch'''\n", + " preds, y, loss, acc = self._get_preds_loss_accuracy(batch)\n", + "\n", + " # Log loss and metric\n", + " self.log('train/loss', 
loss)\n", + " self.log('train/accuracy', acc)\n", + "\n", + " return loss\n", + "\n", + " def validation_step(self, batch, batch_idx):\n", + " '''used for logging metrics'''\n", + " preds, y, loss, acc = self._get_preds_loss_accuracy(batch)\n", + "\n", + " # Log loss and metric\n", + " self.log('validation/loss', loss)\n", + " self.log('validation/accuracy', acc)\n", + "\n", + " # Let's return preds to use it in a custom callback\n", + " return preds, y\n", + "\n", + " def test_step(self, batch, batch_idx):\n", + " '''used for logging metrics'''\n", + " preds, y, loss, acc = self._get_preds_loss_accuracy(batch)\n", + "\n", + " # Log loss and metric\n", + " self.log('test/loss', loss)\n", + " self.log('test/accuracy', acc)\n", + "\n", + " def configure_optimizers(self):\n", + " '''defines model optimizer'''\n", + " return Adam(self.parameters(), lr=self.lr)\n", + "\n", + "\n", + " def _get_preds_loss_accuracy(self, batch):\n", + " '''convenience function since train/valid/test steps are similar'''\n", + " x, y = batch\n", + " logits = self(x)\n", + " preds = torch.argmax(logits, dim=1)\n", + " loss = self.loss(logits, y)\n", + " acc = accuracy(preds, y, task=\"multiclass\", num_classes=self.num_classes)\n", + " return preds, y, loss, acc" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "bvZHZuZzOpD4" + }, + "source": [ + "### Instrument Callbacks to log additional things at certain points in your code" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "j9pHqaw1OpD5" + }, + "outputs": [], + "source": [ + "from lightning.pytorch.callbacks import Callback\n", + "\n", + "class LogPredictionsCallback(Callback):\n", + "\n", + " def __init__(self):\n", + " super().__init__()\n", + "\n", + "\n", + " def on_validation_epoch_start(self, trainer, pl_module):\n", + " self.batch_dfs = []\n", + " self.image_list = []\n", + " self.val_table = wandb.Table(columns=['image', 'ground_truth', 'prediction'])\n", + "\n", + "\n", + " def 
on_validation_batch_end(\n", + " self, trainer, pl_module, outputs, batch, batch_idx):\n", + " \"\"\"Called when the validation batch ends.\"\"\"\n", + "\n", + " # Append validation predictions and ground truth to log in confusion matrix\n", + " x, y = batch\n", + " preds, y = outputs\n", + " self.batch_dfs.append(pd.DataFrame({\"Ground Truth\": y.cpu().numpy(), \"Predictions\": preds.cpu().numpy()}))\n", + "\n", + " # Add wandb.Image to a table to log at the end of validation\n", + " x = x.cpu().numpy().transpose(0, 2, 3, 1)\n", + " for x_i, y_i, y_pred in list(zip(x, y, preds)):\n", + " self.image_list.append(wandb.Image(x_i, caption=f'Ground Truth: {y_i} - Prediction: {y_pred}'))\n", + " self.val_table.add_data(wandb.Image(x_i), y_i, y_pred)\n", + "\n", + "\n", + " def on_validation_epoch_end(self, trainer, pl_module):\n", + " # Collect statistics for whole validation set and log\n", + " class_names = trainer.datamodule.nature_dataset.class_names\n", + " val_df = pd.concat(self.batch_dfs)\n", + " wandb.log({\"validation_table\": self.val_table,\n", + " \"images_over_time\": self.image_list,\n", + " \"validation_conf_matrix\": wandb.plot.confusion_matrix(y_true = val_df[\"Ground Truth\"].tolist(),\n", + " preds=val_df[\"Predictions\"].tolist(),\n", + " class_names=class_names)}, step=trainer.global_step)\n", + "\n", + " del self.batch_dfs\n", + " del self.val_table\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jAiH-kb1OpD5" + }, + "source": [ + "## 🏋️‍ Main Training Loop" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "FrYTQ7uaOpD5" + }, + "outputs": [], + "source": [ + "from lightning.pytorch.callbacks import ModelCheckpoint\n", + "from lightning.pytorch.loggers import WandbLogger\n", + "from lightning.pytorch import Trainer\n", + "\n", + "wandb.init(project=PROJECT_NAME,\n", + " entity=ENTITY,\n", + " job_type='training',\n", + " config={\n", + " \"model_name\": \"squeezenet\",\n", + " 
\"batch_size\": 16\n", + " })\n", + "\n", + "wandb_logger = WandbLogger(log_model='all', checkpoint_name=f'nature-{wandb.run.id}')\n", + "\n", + "log_predictions_callback = LogPredictionsCallback()\n", + "checkpoint_callback = ModelCheckpoint(every_n_epochs=1)\n", + "\n", + "model = NatureLitModule(model_name=wandb.config['model_name']) # Access hyperparameters downstream to instantiate models/datasets\n", + "\n", + "nature_module = NatureDatasetModule(wandb_run = wandb_logger.experiment,\n", + " artifact_name_alias = \"Nature_100:latest\",\n", + " local_target_dir = \"Nature_100:latest\",\n", + " batch_size=wandb.config['batch_size'],\n", + " input_size=model.input_size)\n", + "nature_module.setup()\n", + "\n", + "trainer = Trainer(logger=wandb_logger, # W&B integration\n", + " callbacks=[log_predictions_callback, checkpoint_callback],\n", + " max_epochs=5,\n", + " log_every_n_steps=5)\n", + "trainer.fit(model, datamodule=nature_module)\n", + "\n", + "wandb.finish()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "uEII8J7UOpD5" + }, + "source": [ + "### Syncing with W&B Offline\n", + "If network communication is lost during training for some reason, you can always sync progress with `wandb sync`.\n", + "\n", + "The W&B SDK caches all logged data in a local directory `wandb`, and when you call `wandb sync`, this syncs your local state with the web app."
+ ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "xi-EhHsxOpD5" + }, + "source": [ + "## Retrieve a model checkpoint artifact and resume training\n", + "- Artifacts make it easy to track state of your training remotely and then resume training from a checkpoint" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "jiE1Bk7fOpD5" + }, + "outputs": [], + "source": [ + "#@title Enter which checkpoint you want to resume training from:\n", + "\n", + "# FORM VARIABLES\n", + "ARTIFACT_NAME_ALIAS = \"nature-oyxk79m1:v4\" #@param {type:\"string\"}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "jZZXWRatOpD5" + }, + "outputs": [], + "source": [ + "wandb.init(project=PROJECT_NAME,\n", + " entity=ENTITY,\n", + " job_type='resume_training')\n", + "\n", + "# Retrieve model checkpoint artifact and restore previous hyperparameters\n", + "model_chkpt_art = wandb.use_artifact(f'{ENTITY}/{PROJECT_NAME}/{ARTIFACT_NAME_ALIAS}')\n", + "model_chkpt_art.download() # Can change download directory by adding `root`, defaults to \"./artifacts\"\n", + "logging_run = model_chkpt_art.logged_by()\n", + "wandb.config = logging_run.config\n", + "\n", + "# Can create a new artifact name or continue logging to the old one\n", + "artifact_name = ARTIFACT_NAME_ALIAS.split(\":\")[0]\n", + "wandb_logger = WandbLogger(log_model='all', checkpoint_name=artifact_name)\n", + "\n", + "log_predictions_callback = LogPredictionsCallback()\n", + "checkpoint_callback = ModelCheckpoint(every_n_epochs=1)\n", + "\n", + "model = NatureLitModule.load_from_checkpoint(f'./artifacts/{ARTIFACT_NAME_ALIAS}/model.ckpt')\n", + "\n", + "nature_module = NatureDatasetModule(wandb_run = wandb_logger.experiment,\n", + " artifact_name_alias = \"Nature_100:latest\",\n", + " local_target_dir = \"Nature_100:latest\",\n", + " batch_size=wandb.config['batch_size'],\n", + " input_size=model.input_size)\n", + "nature_module.setup()\n", + "\n", + "\n", 
+ "\n", + "trainer = Trainer(logger=wandb_logger, # W&B integration\n", + " callbacks=[log_predictions_callback, checkpoint_callback],\n", + " max_epochs=10,\n", + " log_every_n_steps=5)\n", + "trainer.fit(model, datamodule=nature_module)\n", + "\n", + "wandb.finish()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "KTforR6EOpD5" + }, + "source": [ + "## Model Registry\n", + "After logging a bunch of checkpoints across multiple runs during experimentation, it is time to hand off the best checkpoint to the next stage of the workflow (e.g. testing, deployment).\n", + "\n", + "The model registry offers a centralized place to house the best checkpoints for all your model tasks. Any `model` artifact you log can be \"linked\" to a Registered Model. Here are the steps to start using the model registry for more organized model management:\n", + "1. Access your team's model registry by going to the team page and selecting `Model Registry`\n", + "![model registry](https://drive.google.com/uc?export=view&id=1ZtJwBsFWPTm4Sg5w8vHhRpvDSeQPwsKw)\n", + "\n", + "2. Create a new Registered Model.\n", + "![model registry](https://drive.google.com/uc?export=view&id=1RuayTZHNE0LJCxt1t0l6-2zjwiV4aDXe)\n", + "\n", + "3. Go to the artifacts tab of the project that holds all your model checkpoints\n", + "![model registry](https://drive.google.com/uc?export=view&id=1r_jlhhtcU3as8VwQ-4oAntd8YtTwElFB)\n", + "\n", + "4. Click \"Link to Registry\" for the model artifact version you want. (Alternatively you can [link a model via the API](https://docs.wandb.ai/guides/models) with `wandb.run.link_artifact`)\n", + "\n", + "**A note on linking:** The process of linking a model checkpoint is akin to \"bookmarking\" it. Each time you link a new model artifact to a Registered Model, this increments the version of the Registered Model. This helps delineate the model development side of the workflow from the model deployment/consumption side. 
The globally understood version/alias of a model should be unpolluted by all the experimental versions being generated in R&D, and thus the versioning of a Registered Model increments according to new \"bookmarked\" models as opposed to model checkpoint logging.\n", + "\n", + "\n", + "### Create a Centralized Hub for all your models\n", + "- Add a model card, tags, and Slack notifications to your Registered Model\n", + "- Change aliases to reflect when models move through different phases\n", + "- Embed the model registry in reports for model documentation and regression reports. See this report as an [example](https://api.wandb.ai/links/wandb-smle/r82bj9at)\n", + "![model registry](https://drive.google.com/uc?export=view&id=1lKPgaw-Ak4WK_91aBMcLvUMJL6pDQpgO)\n" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "provenance": [], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + } }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "\"Weights\n", - "\n", - "\n", - "\n", - "# W&B Tutorial with Pytorch Lightning" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 🛠️ Install `wandb` and `pytorch-lightning`" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "!pip install -q lightning wandb torchvision" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Login to W&B either through Python or CLI\n", - "If you are using the public W&B cloud, you don't need to specify the `WANDB_HOST`.\n", - "\n", - "You can set environment variables `WANDB_API_KEY` and `WANDB_HOST` and pass them in as:\n", - "```\n", - "import os\n", - "import wandb \n", - "\n", - "wandb.login(host=os.getenv(\"WANDB_HOST\"), key=os.getenv(\"WANDB_API_KEY\"))\n", - "```\n", - "You can also login via the CLI with: \n", - "```\n", - "wandb login --host \n", - "```" - ] - }, - { - "cell_type": 
"code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import wandb" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "wandb.login()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## ⚱ Logging the Raw Training Data as an Artifact" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#@title Enter your W&B project and entity\n", - "\n", - "# FORM VARIABLES\n", - "PROJECT_NAME = \"pytorch-lightning-e2e\" #@param {type:\"string\"}\n", - "ENTITY = \"wandb\"#@param {type:\"string\"}\n", - "\n", - "# set SIZE to \"TINY\", \"SMALL\", \"MEDIUM\", or \"LARGE\"\n", - "# to select one of these three datasets\n", - "# TINY dataset: 100 images, 30MB\n", - "# SMALL dataset: 1000 images, 312MB\n", - "# MEDIUM dataset: 5000 images, 1.5GB\n", - "# LARGE dataset: 12,000 images, 3.6GB\n", - "\n", - "SIZE = \"TINY\"\n", - "\n", - "if SIZE == \"TINY\":\n", - " src_url = \"https://storage.googleapis.com/wandb_datasets/nature_100.zip\"\n", - " src_zip = \"nature_100.zip\"\n", - " DATA_SRC = \"nature_100\"\n", - " IMAGES_PER_LABEL = 10\n", - " BALANCED_SPLITS = {\"train\" : 8, \"val\" : 1, \"test\": 1}\n", - "elif SIZE == \"SMALL\":\n", - " src_url = \"https://storage.googleapis.com/wandb_datasets/nature_1K.zip\"\n", - " src_zip = \"nature_1K.zip\"\n", - " DATA_SRC = \"nature_1K\"\n", - " IMAGES_PER_LABEL = 100\n", - " BALANCED_SPLITS = {\"train\" : 80, \"val\" : 10, \"test\": 10}\n", - "elif SIZE == \"MEDIUM\":\n", - " src_url = \"https://storage.googleapis.com/wandb_datasets/nature_12K.zip\"\n", - " src_zip = \"nature_12K.zip\"\n", - " DATA_SRC = \"inaturalist_12K/train\" # (technically a subset of only 10K images)\n", - " IMAGES_PER_LABEL = 500\n", - " BALANCED_SPLITS = {\"train\" : 400, \"val\" : 50, \"test\": 50}\n", - "elif SIZE == \"LARGE\":\n", - " src_url = 
\"https://storage.googleapis.com/wandb_datasets/nature_12K.zip\"\n", - " src_zip = \"nature_12K.zip\"\n", - " DATA_SRC = \"inaturalist_12K/train\" # (technically a subset of only 10K images)\n", - " IMAGES_PER_LABEL = 1000\n", - " BALANCED_SPLITS = {\"train\" : 800, \"val\" : 100, \"test\": 100}" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "%%capture\n", - "!curl -SL $src_url > $src_zip\n", - "!unzip $src_zip" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import wandb\n", - "import pandas as pd\n", - "import os\n", - "\n", - "with wandb.init(project=PROJECT_NAME, entity=ENTITY, job_type='log_datasets') as run:\n", - " img_paths = []\n", - " for root, dirs, files in os.walk('nature_100', topdown=False):\n", - " for name in files:\n", - " img_path = os.path.join(root, name)\n", - " label = img_path.split('/')[1]\n", - " img_paths.append([img_path, label])\n", - "\n", - " index_df = pd.DataFrame(columns=['image_path', 'label'], data=img_paths)\n", - " index_df.to_csv('index.csv', index=False)\n", - "\n", - " train_art = wandb.Artifact(name='Nature_100', type='raw_images', description='nature image dataset with 10 classes, 10 images per class')\n", - " train_art.add_dir('nature_100')\n", - "\n", - " # Also adding a csv indicating the labels of each image\n", - " train_art.add_file('index.csv')\n", - " wandb.log_artifact(train_art)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Using Artifacts in Pytorch Lightning `DataModule`'s and Pytorch `Dataset`'s\n", - "- Makes it easy to interopt your DataLoaders with new versions of datasets\n", - "- Just indicate the `name:alias` as an argument to your `Dataset` or `DataModule`\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from torchvision import transforms\n", - "import lightning.pytorch 
as pl\n", - "import torch\n", - "from torch.utils.data import Dataset, DataLoader, random_split\n", - "from skimage import io, transform\n", - "from torchvision import transforms, utils, models\n", - "import math\n", - "\n", - "class NatureDataset(Dataset):\n", - " def __init__(self, \n", - " wandb_run, \n", - " artifact_name_alias=\"Nature_100:latest\", \n", - " local_target_dir=\"Nature_100:latest\", \n", - " transform=None):\n", - " self.local_target_dir = local_target_dir\n", - " self.transform = transform\n", - "\n", - " # Pull down the artifact locally to load it into memory\n", - " art = wandb_run.use_artifact(artifact_name_alias)\n", - " path_at = art.download(root=self.local_target_dir)\n", - "\n", - " self.ref_df = pd.read_csv(os.path.join(self.local_target_dir, 'index.csv'))\n", - " self.class_names = self.ref_df.iloc[:, 1].unique().tolist()\n", - " self.idx_to_class = {k: v for k, v in enumerate(self.class_names)}\n", - " self.class_to_idx = {v: k for k, v in enumerate(self.class_names)}\n", - "\n", - " def __len__(self):\n", - " return len(self.ref_df)\n", - "\n", - " def __getitem__(self, idx):\n", - " if torch.is_tensor(idx):\n", - " idx = idx.tolist()\n", - "\n", - " img_path = self.ref_df.iloc[idx, 0]\n", - "\n", - " image = io.imread(img_path)\n", - " label = self.ref_df.iloc[idx, 1]\n", - " label = torch.tensor(self.class_to_idx[label], dtype=torch.long)\n", - "\n", - " if self.transform:\n", - " image = self.transform(image)\n", - "\n", - " return image, label\n", - "\n", - "\n", - "class NatureDatasetModule(pl.LightningDataModule):\n", - " def __init__(self,\n", - " wandb_run,\n", - " artifact_name_alias: str = \"Nature_100:latest\",\n", - " local_target_dir: str = \"Nature_100:latest\",\n", - " batch_size: int = 16,\n", - " input_size: int = 224,\n", - " seed: int = 42):\n", - " super().__init__()\n", - " self.wandb_run = wandb_run\n", - " self.artifact_name_alias = artifact_name_alias\n", - " self.local_target_dir = local_target_dir\n", - " 
self.batch_size = batch_size\n", - " self.input_size = input_size\n", - " self.seed = seed\n", - "\n", - " def setup(self, stage=None):\n", - " self.nature_dataset = NatureDataset(wandb_run=self.wandb_run,\n", - " artifact_name_alias=self.artifact_name_alias,\n", - " local_target_dir=self.local_target_dir,\n", - " transform=transforms.Compose([transforms.ToTensor(),\n", - " transforms.CenterCrop(self.input_size),\n", - " transforms.Normalize((0.485, 0.456, 0.406),\n", - " (0.229, 0.224, 0.225))]))\n", - "\n", - " nature_length = len(self.nature_dataset)\n", - " train_size = math.floor(0.8 * nature_length)\n", - " val_size = math.floor(0.2 * nature_length)\n", - " self.nature_train, self.nature_val = random_split(self.nature_dataset,\n", - " [train_size, val_size],\n", - " generator=torch.Generator().manual_seed(self.seed))\n", - " return self\n", - "\n", - " def train_dataloader(self):\n", - " return DataLoader(self.nature_train, batch_size=self.batch_size)\n", - "\n", - " def val_dataloader(self):\n", - " return DataLoader(self.nature_val, batch_size=self.batch_size)\n", - "\n", - " def predict_dataloader(self):\n", - " pass\n", - "\n", - " def teardown(self, stage: str):\n", - " pass" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "##How Logging in your Pytorch `LightningModule`works:\n", - "When you train the model using `Trainer`, ensure you have a `WandbLogger` instantiated and passed in as a `logger`. \n", - " \n", - "```\n", - "wandb_logger = WandbLogger(project=\"my_project\", entity=\"machine-learning\") \n", - "trainer = Trainer(logger=wandb_logger) \n", - "```\n", - "\n", - "\n", - "You can always use `wandb.log` as normal throughout the module. When the `WandbLogger` is used, `self.log` will also log metrics to W&B. 
\n", - "- To access the current run from within the `LightningModule`, you can access `Trainer.logger.experiment`, which is a `wandb.Run` object" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Some helper functions" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Some helper functions\n", - "\n", - "def set_parameter_requires_grad(model, feature_extracting):\n", - " if feature_extracting:\n", - " for param in model.parameters():\n", - " param.requires_grad = False\n", - "\n", - "def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True):\n", - " # Initialize these variables which will be set in this if statement. Each of these\n", - " # variables is model specific.\n", - " model_ft = None\n", - " input_size = 0\n", - "\n", - " if model_name == \"resnet\":\n", - " \"\"\" Resnet18\n", - " \"\"\"\n", - " model_ft = models.resnet18(pretrained=use_pretrained)\n", - " set_parameter_requires_grad(model_ft, feature_extract)\n", - " num_ftrs = model_ft.fc.in_features\n", - " model_ft.fc = torch.nn.Linear(num_ftrs, num_classes)\n", - " input_size = 224\n", - "\n", - " elif model_name == \"squeezenet\":\n", - " \"\"\" Squeezenet\n", - " \"\"\"\n", - " model_ft = models.squeezenet1_0(pretrained=use_pretrained)\n", - " set_parameter_requires_grad(model_ft, feature_extract)\n", - " model_ft.classifier[1] = torch.nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1))\n", - " model_ft.num_classes = num_classes\n", - " input_size = 224\n", - "\n", - " elif model_name == \"densenet\":\n", - " \"\"\" Densenet\n", - " \"\"\"\n", - " model_ft = models.densenet121(pretrained=use_pretrained)\n", - " set_parameter_requires_grad(model_ft, feature_extract)\n", - " num_ftrs = model_ft.classifier.in_features\n", - " model_ft.classifier = torch.nn.Linear(num_ftrs, num_classes)\n", - " input_size = 224\n", - "\n", - " else:\n", - " print(\"Invalid model name, 
exiting...\")\n", - " exit()\n", - "\n", - " return model_ft, input_size" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Writing the `LightningModule`" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "from torch.nn import Linear, CrossEntropyLoss, functional as F\n", - "from torch.optim import Adam\n", - "from torchmetrics.functional import accuracy\n", - "from lightning.pytorch import LightningModule\n", - "from torchvision import models\n", - "\n", - "class NatureLitModule(LightningModule):\n", - " def __init__(self,\n", - " model_name,\n", - " num_classes=10,\n", - " feature_extract=True,\n", - " lr=0.01):\n", - " '''method used to define our model parameters'''\n", - " super().__init__()\n", - "\n", - " self.model_name = model_name\n", - " self.num_classes = num_classes\n", - " self.feature_extract = feature_extract\n", - " self.model, self.input_size = initialize_model(model_name=self.model_name,\n", - " num_classes=self.num_classes,\n", - " feature_extract=True)\n", - "\n", - " # loss\n", - " self.loss = CrossEntropyLoss()\n", - "\n", - " # optimizer parameters\n", - " self.lr = lr\n", - "\n", - " # save hyper-parameters to self.hparams (auto-logged by W&B)\n", - " self.save_hyperparameters()\n", - "\n", - " # Record the gradients of all the layers\n", - " wandb.watch(self.model)\n", - "\n", - " def forward(self, x):\n", - " '''method used for inference input -> output'''\n", - " x = self.model(x)\n", - "\n", - " return x\n", - "\n", - " def training_step(self, batch, batch_idx):\n", - " '''needs to return a loss from a single batch'''\n", - " preds, y, loss, acc = self._get_preds_loss_accuracy(batch)\n", - "\n", - " # Log loss and metric\n", - " self.log('train/loss', loss)\n", - " self.log('train/accuracy', acc)\n", - "\n", - " return loss\n", - "\n", - " def validation_step(self, batch, batch_idx):\n", - " '''used for logging metrics'''\n", 
- " preds, y, loss, acc = self._get_preds_loss_accuracy(batch)\n", - "\n", - " # Log loss and metric\n", - " self.log('validation/loss', loss)\n", - " self.log('validation/accuracy', acc)\n", - "\n", - " # Let's return preds to use it in a custom callback\n", - " return preds, y\n", - "\n", - " def validation_epoch_end(self, validation_step_outputs):\n", - " \"\"\"Called when the validation ends.\"\"\"\n", - " preds, y = validation_step_outputs\n", - " all_preds = torch.stack(preds)\n", - " all_y = torch.stack(y)\n", - "\n", - " def test_step(self, batch, batch_idx):\n", - " '''used for logging metrics'''\n", - " preds, y, loss, acc = self._get_preds_loss_accuracy(batch)\n", - "\n", - " # Log loss and metric\n", - " self.log('test/loss', loss)\n", - " self.log('test/accuracy', acc)\n", - "\n", - " def configure_optimizers(self):\n", - " '''defines model optimizer'''\n", - " return Adam(self.parameters(), lr=self.lr)\n", - "\n", - "\n", - " def _get_preds_loss_accuracy(self, batch):\n", - " '''convenience function since train/valid/test steps are similar'''\n", - " x, y = batch\n", - " logits = self(x)\n", - " preds = torch.argmax(logits, dim=1)\n", - " loss = self.loss(logits, y)\n", - " acc = accuracy(preds, y, task=\"multiclass\", num_classes=10)\n", - " return preds, y, loss, acc" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Instrument Callbacks to log additional things at certain points in your code" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from lightning.pytorch.callbacks import Callback\n", - "\n", - "class LogPredictionsCallback(Callback):\n", - "\n", - " def __init__(self):\n", - " super().__init__()\n", - "\n", - " \n", - " def on_validation_epoch_start(self, trainer, pl_module):\n", - " self.batch_dfs = []\n", - " self.image_list = []\n", - " self.val_table = wandb.Table(columns=['image', 'ground_truth', 'prediction'])\n", - "\n", - " \n", - " def 
on_validation_batch_end(\n", - " self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):\n", - " \"\"\"Called when the validation batch ends.\"\"\"\n", - "\n", - " # Append validation predictions and ground truth to log in confusion matrix\n", - " x, y = batch\n", - " preds, y = outputs\n", - " self.batch_dfs.append(pd.DataFrame({\"Ground Truth\": y.numpy(), \"Predictions\": preds.numpy()}))\n", - "\n", - " # Add wandb.Image to a table to log at the end of validation\n", - " x = x.numpy().transpose(0, 2, 3, 1)\n", - " for x_i, y_i, y_pred in list(zip(x, y, preds)):\n", - " self.image_list.append(wandb.Image(x_i, caption=f'Ground Truth: {y_i} - Prediction: {y_pred}'))\n", - " self.val_table.add_data(wandb.Image(x_i), y_i, y_pred)\n", - " \n", - " \n", - " def on_validation_epoch_end(self, trainer, pl_module):\n", - " # Collect statistics for whole validation set and log\n", - " class_names = trainer.datamodule.nature_dataset.class_names\n", - " val_df = pd.concat(self.batch_dfs)\n", - " wandb.log({\"validation_table\": self.val_table,\n", - " \"images_over_time\": self.image_list,\n", - " \"validation_conf_matrix\": wandb.plot.confusion_matrix(y_true = val_df[\"Ground Truth\"].tolist(), \n", - " preds=val_df[\"Predictions\"].tolist(), \n", - " class_names=class_names)}, step=trainer.global_step)\n", - "\n", - " del self.batch_dfs\n", - " del self.val_table\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 🏋️‍ Main Training Loop" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from lightning.pytorch.callbacks import ModelCheckpoint\n", - "from lightning.pytorch.loggers import WandbLogger\n", - "from lightning.pytorch import Trainer\n", - "\n", - "wandb.init(project=PROJECT_NAME,\n", - " entity=ENTITY,\n", - " job_type='training',\n", - " config={\n", - " \"model_name\": \"squeezenet\",\n", - " \"batch_size\": 16\n", - " })\n", - "\n", - "wandb_logger = 
WandbLogger(log_model='all', checkpoint_name=f'nature-{wandb.run.id}') \n", - "\n", - "log_predictions_callback = LogPredictionsCallback()\n", - "checkpoint_callback = ModelCheckpoint(every_n_epochs=1)\n", - "\n", - "model = NatureLitModule(model_name=wandb.config['model_name']) # Access hyperparameters downstream to instantiate models/datasets\n", - "\n", - "nature_module = NatureDatasetModule(wandb_run = wandb_logger.experiment,\n", - " artifact_name_alias = \"Nature_100:latest\",\n", - " local_target_dir = \"Nature_100:latest\",\n", - " batch_size=wandb.config['batch_size'],\n", - " input_size=model.input_size)\n", - "nature_module.setup()\n", - "\n", - "trainer = Trainer(logger=wandb_logger, # W&B integration\n", - " callbacks=[log_predictions_callback, checkpoint_callback],\n", - " max_epochs=5,\n", - " log_every_n_steps=5) \n", - "trainer.fit(model, datamodule=nature_module)\n", - "\n", - "wandb.finish()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Syncing with W&B Offline\n", - "If for some reason, network communication is lost during the course of training, you can always sync progress with `wandb sync`\n", - "\n", - "The W&B sdk caches all logged data in a local directory `wandb` and when you call `wandb sync`, this syncs the your local state with the web app. 
" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Retrieve a model checkpoint artifact and resume training\n", - "- Artifacts make it easy to track state of your training remotely and then resume training from a checkpoint" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#@title Enter which checkpoint you want to resume training from:\n", - "\n", - "# FORM VARIABLES\n", - "ARTIFACT_NAME_ALIAS = \"nature-zb4swpn6:v4\" #@param {type:\"string\"}" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "wandb.init(project=PROJECT_NAME,\n", - " entity=ENTITY,\n", - " job_type='resume_training')\n", - "\n", - "# Retrieve model checkpoint artifact and restore previous hyperparameters\n", - "model_chkpt_art = wandb.use_artifact(f'{ENTITY}/{PROJECT_NAME}/{ARTIFACT_NAME_ALIAS}')\n", - "model_chkpt_art.download() # Can change download directory by adding `root`, defaults to \"./artifacts\"\n", - "logging_run = model_chkpt_art.logged_by()\n", - "wandb.config = logging_run.config\n", - "\n", - "# Can create a new artifact name or continue logging to the old one\n", - "artifact_name = ARTIFACT_NAME_ALIAS.split(\":\")[0]\n", - "wandb_logger = WandbLogger(log_model='all', checkpoint_name=artifact_name) \n", - "\n", - "log_predictions_callback = LogPredictionsCallback()\n", - "checkpoint_callback = ModelCheckpoint(every_n_epochs=1)\n", - "\n", - "model = NatureLitModule(model_name=wandb.config['model_name']) # Access hyperparameters downstream to instantiate models/datasets\n", - "\n", - "nature_module = NatureDatasetModule(wandb_run = wandb_logger.experiment,\n", - " artifact_name_alias = \"Nature_100:latest\",\n", - " local_target_dir = \"Nature_100:latest\",\n", - " batch_size=wandb.config['batch_size'],\n", - " input_size=model.input_size)\n", - "nature_module.setup()\n", - "\n", - "\n", - "\n", - "trainer = 
Trainer(logger=wandb_logger, # W&B integration\n", - " resume_from_checkpoint = f'./artifacts/{ARTIFACT_NAME_ALIAS}/model.ckpt',\n", - " callbacks=[log_predictions_callback, checkpoint_callback],\n", - " max_epochs=10,\n", - " log_every_n_steps=5) \n", - "trainer.fit(model, datamodule=nature_module)\n", - "\n", - "wandb.finish()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Model Registry \n", - "After logging a bunch of checkpoints across multiple runs during experimentation, now comes time to hand-off the best checkpoint to the next stage of the workflow (e.g. testing, deployment).\n", - "\n", - "The model registry offers a centralized place to house the best checkpoints for all your model tasks. Any `model` artifact you log can be \"linked\" to a Registered Model. Here are the steps to start using the model registry for more organized model management:\n", - "1. Access your team's model registry by going the team page and selecting `Model Registry`\n", - "![model registry](https://drive.google.com/uc?export=view&id=1ZtJwBsFWPTm4Sg5w8vHhRpvDSeQPwsKw)\n", - "\n", - "2. Create a new Registered Model. \n", - "![model registry](https://drive.google.com/uc?export=view&id=1RuayTZHNE0LJCxt1t0l6-2zjwiV4aDXe)\n", - "\n", - "3. Go to the artifacts tab of the project that holds all your model checkpoints\n", - "![model registry](https://drive.google.com/uc?export=view&id=1r_jlhhtcU3as8VwQ-4oAntd8YtTwElFB)\n", - "\n", - "4. Click \"Link to Registry\" for the model artifact version you want. (Alternatively you can [link a model via api](https://docs.wandb.ai/guides/models) with `wandb.run.link_artifact`)\n", - "\n", - "**A note on linking:** The process of linking a model checkpoint is akin to \"bookmarking\" it. Each time you link a new model artifact to a Registered Model, this increments the version of the Registered Model. This helps delineate the model development side of the workflow from the model deployment/consumption side. 
The globally understood version/alias of a model should be unpolluted from all the experimental versions being generated in R&D and thus the versioning of a Registered Model increments according to new \"bookmarked\" models as opposed to model checkpoint logging. \n", - "\n", - "\n", - "### Create a Centralized Hub for all your models\n", - "- Add a model card, tags, slack notifactions to your Registered Model\n", - "- Change aliases to reflect when models move through different phases\n", - "- Embed the model registry in reports for model documentation and regression reports. See this report as an [example](https://api.wandb.ai/links/wandb-smle/r82bj9at)\n", - "![model registry](https://drive.google.com/uc?export=view&id=1lKPgaw-Ak4WK_91aBMcLvUMJL6pDQpgO)\n" - ] - } - ], - "metadata": { - "accelerator": "GPU", - "colab": { - "include_colab_link": true, - "provenance": [], - "toc_visible": true - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file From 03020a3fb1462712e3fd785d5951a7c69c470353 Mon Sep 17 00:00:00 2001 From: Anish Shah Date: Wed, 6 Dec 2023 08:58:11 -0500 Subject: [PATCH 7/7] Update Supercharge_your_Training_with_Pytorch_Lightning_and_Weights_and_Biases.ipynb --- ...rch_Lightning_and_Weights_and_Biases.ipynb | 1824 +++++++++-------- 1 file changed, 968 insertions(+), 856 deletions(-) diff --git a/colabs/pytorch-lightning/Supercharge_your_Training_with_Pytorch_Lightning_and_Weights_and_Biases.ipynb b/colabs/pytorch-lightning/Supercharge_your_Training_with_Pytorch_Lightning_and_Weights_and_Biases.ipynb index 1b3aa0c3..018a1ad6 100644 --- a/colabs/pytorch-lightning/Supercharge_your_Training_with_Pytorch_Lightning_and_Weights_and_Biases.ipynb +++ b/colabs/pytorch-lightning/Supercharge_your_Training_with_Pytorch_Lightning_and_Weights_and_Biases.ipynb @@ -1,857 +1,969 @@ { - "cells": [ - { - "cell_type": "markdown", - 
"metadata": {}, - "source": [ - "\"Open\n", - "" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "\"Weights\n", - "\n", - "\n", - "\n", - "# ⚡ 💘 🏋️‍♀️ Supercharge your Training with PyTorch Lightning + Weights & Biases" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "\"Weights" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "At Weights & Biases, we love anything\n", - "that makes training deep learning models easier.\n", - "That's why we worked with the folks at PyTorch Lightning to\n", - "[integrate our experiment tracking tool](https://docs.wandb.com/library/integrations/lightning)\n", - "directly into\n", - "[the Lightning library](https://pytorch-lightning.readthedocs.io/en/latest/common/loggers.html#weights-and-biases).\n", - "\n", - "[PyTorch Lightning](https://pytorch-lightning.readthedocs.io/en/stable/) is a lightweight wrapper for organizing your PyTorch code and easily adding advanced features such as distributed training and 16-bit precision.\n", - "It retains all the flexibility of PyTorch,\n", - "in case you need it,\n", - "but adds some useful abstractions\n", - "and builds in some best practices.\n", - "\n", - "## What this notebook covers:\n", - "\n", - "1. Differences between PyTorch and PyTorch Lightning, including how to set up `LightningModules` and `LightningDataModules`\n", - "2. How to get basic metric logging with the [`WandbLogger`](https://pytorch-lightning.readthedocs.io/en/latest/common/loggers.html#weights-and-biases)\n", - "3. How to log media with W&B and fully customize logging with Lightning `Callbacks`\n", - "\n", - "## The interactive dashboard in W&B will look like this:\n", - "\n", - "![](https://i.imgur.com/lIbMyFR.png)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Follow along with a [video tutorial](http://wandb.me/lit-video)!" 
- ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# 🚀 Installing and importing" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "`wandb` and `pytorch-lightning` are both easily installable via [`pip`](https://pip.pypa.io/en/stable/)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "!pip install -qqq wandb lightning torchmetrics" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "PyTorch Lightning is built on top of PyTorch,\n", - "so we still need to import vanilla PyTorch." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# numpy for non-GPU array math\n", - "import numpy as np\n", - "\n", - "# 🍦 Vanilla PyTorch\n", - "import torch\n", - "from torch.nn import functional as F\n", - "from torch import nn\n", - "from torch.utils.data import DataLoader, random_split\n", - "\n", - "# 👀 Torchvision for CV\n", - "from torchvision.datasets import MNIST\n", - "from torchvision import transforms\n", - "\n", - "# remove slow mirror from list of MNIST mirrors\n", - "MNIST.mirrors = [mirror for mirror in MNIST.mirrors\n", - " if not mirror.startswith(\"http://yann.lecun.com\")]" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Much of Lightning is built on the [Modules](https://pytorch.org/docs/stable/generated/torch.nn.Module.html)\n", - "API from PyTorch,\n", - "but adds extra features\n", - "(like data loading and logging)\n", - "that are common to lots of PyTorch projects.\n", - "\n", - "Let's bring those in,\n", - "plus W&B and the integration.\n", - "\n", - "Lastly, we log in to the [Weights & Biases web service](https://wandb.ai).\n", - "If you've never used W&B,\n", - "you'll need to sign up first.\n", - "Accounts are free forever for academic and public projects." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# ⚡ PyTorch Lightning\n", - "import lightning.pytorch as pl\n", - "import torchmetrics\n", - "pl.seed_everything(hash(\"setting random seeds\") % 2**32 - 1)\n", - "\n", - "# 🏋️‍♀️ Weights & Biases\n", - "import wandb\n", - "\n", - "# ⚡ 🤝 🏋️‍♀️\n", - "from lightning.pytorch.loggers import WandbLogger\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "wandb.login()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "> _Note_: If you're executing your training in a terminal, rather than a notebook, you don't need to include `wandb.login()` in your script.\n", - "Instead, call `wandb login` in the terminal and we'll keep you logged in for future runs." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# 🏗️ Building a Model with Lightning\n", - "\n", - "In PyTorch Lightning, models are built with `LightningModule` ([docs here](https://pytorch-lightning.readthedocs.io/en/latest/lightning_module.html)), which has all the functionality of a vanilla `torch.nn.Module` (🍦) but with a few delicious cherries of added functionality on top (🍨).\n", - "These cherries are there to cut down on boilerplate and\n", - "help separate out the ML engineering code\n", - "from the actual machine learning.\n", - "\n", - "For example, the mechanics of iterating over batches\n", - "as part of an epoch are extracted away,\n", - "so long as you define what happens on the `training_step`.\n", - "\n", - "To make a working model out of a `LightningModule`,\n", - "we need to define a new `class` and add a few methods on top.\n", - "\n", - "We'll demonstrate this process with `LitMLP`,\n", - "which applies a two-layer perceptron\n", - "(aka two fully-connected layers and\n", - "a fully-connected softmax readout layer)\n", - "to input `Tensors`.\n", - "\n", - "> _Note_: It 
is common in the Lightning community to shorten \"Lightning\" to \"[Lit](https://www.urbandictionary.com/define.php?term=it%27s%20lit)\".\n", - "This sometimes it sound like\n", - "[your code was written by Travis Scott](https://www.youtube.com/watch?v=y3FCXV8oEZU).\n", - "We consider this a good thing." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 🍦 `__init__` and `forward`\n", - "\n", - "First, we need to add two methods that\n", - "are part of any vanilla PyTorch model." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Those methods are:\n", - "* `__init__` to do any setup, just like any Python class\n", - "* `forward` for inference, just like a PyTorch Module\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The `forward` pass method is standard,\n", - "and it'll be different for every project,\n", - "so we won't comment on it.\n", - "\n", - "The `__init__` method,\n", - "which `init`ializes new instances of the class,\n", - "is a good place to log hyperparameter information to `wandb`.\n", - "\n", - "This is done with the `save_hyperparameters` method,\n", - "which captures all of the arguments to the initializer\n", - "and adds them to a dictionary at `self.hparams` --\n", - "that all comes for free as part of the `LightningModule`.\n", - "\n", - "> _Note_: `hparams` is logged to `wandb` as the `config`,\n", - "so you'll never lose track of the arguments you used to run a model again!" 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "class LitMLP(pl.LightningModule):\n", - "\n", - " def __init__(self, in_dims, n_classes=10,\n", - " n_layer_1=128, n_layer_2=256, lr=1e-4):\n", - " super().__init__()\n", - "\n", - " # we flatten the input Tensors and pass them through an MLP\n", - " self.layer_1 = nn.Linear(np.prod(in_dims), n_layer_1)\n", - " self.layer_2 = nn.Linear(n_layer_1, n_layer_2)\n", - " self.layer_3 = nn.Linear(n_layer_2, n_classes)\n", - "\n", - " # log hyperparameters\n", - " self.save_hyperparameters()\n", - "\n", - " # compute the accuracy -- no need to roll your own!\n", - " self.train_acc = torchmetrics.Accuracy()\n", - " self.valid_acc = torchmetrics.Accuracy()\n", - " self.test_acc = torchmetrics.Accuracy()\n", - "\n", - " def forward(self, x):\n", - " \"\"\"\n", - " Defines a forward pass using the Stem-Learner-Task\n", - " design pattern from Deep Learning Design Patterns:\n", - " https://www.manning.com/books/deep-learning-design-patterns\n", - " \"\"\"\n", - " batch_size, *dims = x.size()\n", - "\n", - " # stem: flatten\n", - " x = x.view(batch_size, -1)\n", - "\n", - " # learner: two fully-connected layers\n", - " x = F.relu(self.layer_1(x))\n", - " x = F.relu(self.layer_2(x))\n", - " \n", - " # task: compute class logits\n", - " x = self.layer_3(x)\n", - " x = F.log_softmax(x, dim=1)\n", - "\n", - " return x\n", - "\n", - " # convenient method to get the loss on a batch\n", - " def loss(self, xs, ys):\n", - " logits = self(xs) # this calls self.forward\n", - " loss = F.nll_loss(logits, ys)\n", - " return logits, loss" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "> _Note_: for pedagogical purposes, we're splitting out\n", - "each stage of building the `LitMLP` into a different cell.\n", - "In a more typical workflow,\n", - "this would all happen in the `class` definition.\n", - "\n", - "> _Note_: if you're familiar with 
PyTorch,\n", - "you might be surprised to see we aren't taking care with `.device`s:\n", - "no `to_cuda` etc. PyTorch Lightning handles all that for you! 😎" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 🍨 `training_step` and `configure_optimizers`\n", - "Now, we add some special methods so that our `LitMLP` can be trained\n", - "using PyTorch Lightning's training API.\n", - "\n", - "> _Note_: if you've used Keras, this might be familiar.\n", - "It's very similar to the `.fit` API in that library." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Those methods are\n", - "\n", - "* `training_step`, which takes a batch and computes the loss; backprop goes through it\n", - "* `configure_optimizers`, which returns the `torch.optim.Optimizer` to apply after the `training_step`" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "> _Note_: `training_step` is part of a rich system of callbacks in PyTorch Lightning.\n", - "These callbacks are methods that get called\n", - "at specific points during training\n", - "(e.g. when a validation epoch ends),\n", - "and they are a major part of what makes\n", - "PyTorch Lightning both useful and extensible." 
- ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Here's where we add some more serious logging code.\n", - "`self.log` takes a name and value for a metric.\n", - "Under the hood, this will get passed to `wandb.log` if you're using W&B.\n", - "\n", - "The logging behavior of PyTorch Lightning is both intelligent and configurable.\n", - "For example, by passing the `on_epoch`\n", - "keyword argument here,\n", - "we'll get `_epoch`-wise averages\n", - "of the metrics logged on each `_step`,\n", - "and those metrics will be named differently\n", - "in the W&B interface.\n", - "When training in a distributed setting,\n", - "these averages will be automatically computed across nodes.\n", - "\n", - "Read more about the `log` method [in the docs](https://pytorch-lightning.readthedocs.io/en/latest/lightning_module.html#log)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def training_step(self, batch, batch_idx):\n", - " xs, ys = batch\n", - " logits, loss = self.loss(xs, ys)\n", - " preds = torch.argmax(logits, 1)\n", - "\n", - " # logging metrics we calculated by hand\n", - " self.log('train/loss', loss, on_epoch=True)\n", - " # logging a pl.Metric\n", - " self.train_acc(preds, ys)\n", - " self.log('train/acc', self.train_acc, on_epoch=True)\n", - " \n", - " return loss\n", - "\n", - "def configure_optimizers(self):\n", - " return torch.optim.Adam(self.parameters(), lr=self.hparams[\"lr\"])\n", - "\n", - "LitMLP.training_step = training_step\n", - "LitMLP.configure_optimizers = configure_optimizers" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## ➕ Optional methods for even better logging\n", - "\n", - "The code above will log our model's performance,\n", - "system metrics, and more to W&B.\n", - "\n", - "If we want to take our logging to the next level,\n", - "we need to make use of PyTorch Lightning's callback system.\n", - "\n", - "> _Note_: 
thanks to the clean design of PyTorch Lightning,\n", - "the training code below will run with or without any\n", - "of this extra logging code. Nice!" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The other callbacks we'll make use of fall into two categories:\n", - "* methods that trigger on each batch for a dataset: `validation_step` and `test_step`\n", - "* methods that trigger at the end of an epoch,\n", - "or a full pass over a given dataset: `{training, validation, test}_epoch_end`" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 💾 `test`ing and saving the model" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We use the test set to evaluate the performance of the final model,\n", - "so the `test` callbacks will be called at the end of the training pipeline.\n", - "\n", - "For performance on the `test` and `validation` sets,\n", - "we're typically less concerned about how\n", - "we do on intermediate steps and more\n", - "with how we did overall.\n", - "That's why below, we pass in\n", - "`on_step=False` and `on_epoch=True`\n", - "so that we log only `epoch`-wise metrics.\n", - "\n", - "> _Note_: That's actually the default behavior for `.log` when it's called inside of a `validation` or a `test` loop -- but not when it's called inside a `training` loop! Check out the table of default behaviors for `.log` [in the docs](https://pytorch-lightning.readthedocs.io/en/latest/lightning_module.html#log)." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def test_step(self, batch, batch_idx):\n", - " xs, ys = batch\n", - " logits, loss = self.loss(xs, ys)\n", - " preds = torch.argmax(logits, 1)\n", - "\n", - " self.test_acc(preds, ys)\n", - " self.log(\"test/loss_epoch\", loss, on_step=False, on_epoch=True)\n", - " self.log(\"test/acc_epoch\", self.test_acc, on_step=False, on_epoch=True)\n", - "\n", - "LitMLP.test_step = test_step" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We'll also take the opportunity to save the model in the\n", - "[portable `ONNX` format](https://onnx.ai/).\n", - "\n", - "\n", - "Later,\n", - "we'll see that this allows us to use the\n", - "[Netron model viewer](https://github.com/lutzroeder/netron) in W&B." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def test_epoch_end(self, test_step_outputs): # args are defined as part of pl API\n", - " dummy_input = torch.zeros(self.hparams[\"in_dims\"], device=self.device)\n", - " model_filename = \"model_final.onnx\"\n", - " self.to_onnx(model_filename, dummy_input, export_params=True)\n", - " artifact = wandb.Artifact(name=\"model.ckpt\", type=\"model\")\n", - " artifact.add_file(model_filename)\n", - " wandb.log_artifact(artifact)\n", - "\n", - "LitMLP.test_epoch_end = test_epoch_end" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 📊 Logging `Histograms`" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "For the `validation_data`,\n", - "let's track not only the `acc`uracy and `loss`,\n", - "but also the `logits`:\n", - "the un-normalized class probabilities.\n", - "That way, we can track if our network\n", - "is becoming more or less confident over time.\n", - "\n", - "There's a problem though:\n", - "`.log` wants to average,\n", - "but we'd rather look at a distribution.\n", - 
"\n", - "So instead, on every `validation_step`,\n", - "we'll `return` the `logits`,\n", - "rather than `log`ging them.\n", - "\n", - "Then, when we reach the `end`\n", - "of the `validation_epoch`,\n", - "the `logits` are available as the\n", - "`validation_step_outputs` -- a list.\n", - "\n", - "So to log we'll take those `logits`,\n", - "concatenate them together,\n", - "and turn them into a histogram with [`wandb.Histogram`](https://docs.wandb.com/library/log#histograms).\n", - "\n", - "Because we're no longer using Lightning's `.log` interface and are instead using `wandb`,\n", - "we need to drop down a level and use\n", - "`self.experiment.logger.log`." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def validation_step(self, batch, batch_idx):\n", - " xs, ys = batch\n", - " logits, loss = self.loss(xs, ys)\n", - " preds = torch.argmax(logits, 1)\n", - " self.valid_acc(preds, ys)\n", - "\n", - " self.log(\"valid/loss_epoch\", loss) # default on val/test is on_epoch only\n", - " self.log('valid/acc_epoch', self.valid_acc)\n", - " \n", - " return logits\n", - "\n", - "def validation_epoch_end(self, validation_step_outputs):\n", - " dummy_input = torch.zeros(self.hparams[\"in_dims\"], device=self.device)\n", - " model_filename = f\"model_{str(self.global_step).zfill(5)}.onnx\"\n", - " torch.onnx.export(self, dummy_input, model_filename, opset_version=11)\n", - " artifact = wandb.Artifact(name=\"model.ckpt\", type=\"model\")\n", - " artifact.add_file(model_filename)\n", - " self.logger.experiment.log_artifact(artifact)\n", - "\n", - " flattened_logits = torch.flatten(torch.cat(validation_step_outputs))\n", - " self.logger.experiment.log(\n", - " {\"valid/logits\": wandb.Histogram(flattened_logits.to(\"cpu\")),\n", - " \"global_step\": self.global_step})\n", - "\n", - "LitMLP.validation_step = validation_step\n", - "LitMLP.validation_epoch_end = validation_epoch_end" - ] - }, - { - "cell_type": 
"markdown", - "metadata": {}, - "source": [ - "Note that we're once again saving\n", - "the model in ONNX format.\n", - "That way, we can roll back our model to any given epoch --\n", - "useful in case the evaluation on the test set reveals we've overfit." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 📲 `Callback`s for extra-fancy logging" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "What we've done so far\n", - "will tell us how well our model\n", - "is using our system resources,\n", - "how well our model is training and generalizing,\n", - "and how confident it is.\n", - "\n", - "But DNNs often fail in pernicious and silent ways.\n", - "Often, the only way to notice these failures\n", - "is to look at how the model is doing\n", - "on specific examples.\n", - "\n", - "So let's additionally log some detailed information on some specific examples:\n", - "the inputs, outputs,\n", - "and `pred`ictions.\n", - "\n", - "We'll do this by writing our own `Callback` --\n", - "one that, after every `validation_epoch` ends,\n", - "logs input images and output predictions\n", - "using W&B's `Image` logger.\n", - "\n", - "> _Note_:\n", - "For more on the W&B media toolkit, read the [docs](https://docs.wandb.com/library/log#media)\n", - "or check out\n", - "[this Colab](http://wandb.me/media-colab)\n", - "to see everything it's capable of." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "class ImagePredictionLogger(pl.Callback):\n", - " def __init__(self, val_samples, num_samples=32):\n", - " super().__init__()\n", - " self.val_imgs, self.val_labels = val_samples\n", - " self.val_imgs = self.val_imgs[:num_samples]\n", - " self.val_labels = self.val_labels[:num_samples]\n", - " \n", - " def on_validation_epoch_end(self, trainer, pl_module):\n", - " val_imgs = self.val_imgs.to(device=pl_module.device)\n", - "\n", - " logits = pl_module(val_imgs)\n", - " preds = torch.argmax(logits, 1)\n", - "\n", - " trainer.logger.experiment.log({\n", - " \"examples\": [wandb.Image(x, caption=f\"Pred:{pred}, Label:{y}\") \n", - " for x, pred, y in zip(val_imgs, preds, self.val_labels)],\n", - " \"global_step\": trainer.global_step\n", - " })" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# 🛒 Loading data\n", - "\n", - "Data pipelines can be created with:\n", - "* 🍦 Vanilla Pytorch `DataLoaders`\n", - "* ⚡ Pytorch Lightning `DataModules`\n", - "\n", - "`DataModules` are more structured definition, which allows for additional optimizations such as automated distribution of workload between CPU & GPU.\n", - "Using `DataModules` is recommended whenever possible!\n", - "\n", - "A `DataModule` is also defined by an interface:\n", - "* `prepare_data` (optional) which is called only once and on 1 GPU -- typically something like the data download step we have below\n", - "* `setup`, which is called on each GPU separately and accepts `stage` to define if we are at `fit` or `test` step\n", - "* `train_dataloader`, `val_dataloader` and `test_dataloader` to load each dataset" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "class MNISTDataModule(pl.LightningDataModule):\n", - "\n", - " def __init__(self, data_dir='./', batch_size=128):\n", - " super().__init__()\n", - " 
self.data_dir = data_dir\n", - " self.batch_size = batch_size\n", - " self.transform = transforms.Compose([\n", - " transforms.ToTensor(),\n", - " transforms.Normalize((0.1307,), (0.3081,))])\n", - "\n", - " def prepare_data(self):\n", - " # download data, train then test\n", - " MNIST(self.data_dir, train=True, download=True)\n", - " MNIST(self.data_dir, train=False, download=True)\n", - "\n", - " def setup(self, stage=None):\n", - "\n", - " # we set up only relevant datasets when stage is specified\n", - " if stage == 'fit' or stage is None:\n", - " mnist = MNIST(self.data_dir, train=True, transform=self.transform)\n", - " self.mnist_train, self.mnist_val = random_split(mnist, [55000, 5000])\n", - " if stage == 'test' or stage is None:\n", - " self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)\n", - "\n", - " # we define a separate DataLoader for each of train/val/test\n", - " def train_dataloader(self):\n", - " mnist_train = DataLoader(self.mnist_train, batch_size=self.batch_size)\n", - " return mnist_train\n", - "\n", - " def val_dataloader(self):\n", - " mnist_val = DataLoader(self.mnist_val, batch_size=10 * self.batch_size)\n", - " return mnist_val\n", - "\n", - " def test_dataloader(self):\n", - " mnist_test = DataLoader(self.mnist_test, batch_size=10 * self.batch_size)\n", - " return mnist_test\n", - "\n", - "# setup data\n", - "mnist = MNISTDataModule()\n", - "mnist.prepare_data()\n", - "mnist.setup()\n", - "\n", - "# grab samples to log predictions on\n", - "samples = next(iter(mnist.val_dataloader()))" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# 👟 Making a `Trainer`\n", - "\n", - "The `DataLoader` and the `LightningModule`\n", - "are brought together by a `Trainer`,\n", - "which orchestrates data loading,\n", - "gradient calculation,\n", - "optimizer logic,\n", - "and logging. 
\n", - "\n", - "Luckily, we don't need to sub-class the `Trainer`,\n", - "we just need to configure it with keyword arguments." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "And that is where we'll use the `pytorch_lightning.loggers.WandbLogger` to connect our logging to W&B." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "wandb_logger = WandbLogger(project=\"lit-wandb\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "> _Note_: Check out [the documentation](https://docs.wandb.com/library/integrations/lightning) for customization options. I like `group`s and `tag`s!.\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We can then set up our `Trainer` and customize several options, such as gradient accumulation, half precision training and distributed computing.\n", - "\n", - "We'll stick to the basics for this example,\n", - "but half-precision training and easy scaling to distributed settings are two of the major reasons why folks like PyTorch Lightning!" 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "trainer = pl.Trainer(\n", - " logger=wandb_logger, # W&B integration\n", - " log_every_n_steps=50, # set the logging frequency\n", - " gpus=-1, # use all GPUs\n", - " max_epochs=5, # number of epochs\n", - " deterministic=True, # keep it deterministic\n", - " callbacks=[ImagePredictionLogger(samples)] # see Callbacks section\n", - " )" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# 🏃‍♀️ Running our Model" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now, let's make it all happen:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# setup model\n", - "model = LitMLP(in_dims=(1, 28, 28))\n", - "\n", - "# fit the model\n", - "trainer.fit(model, mnist)\n", - "\n", - "# evaluate the model on a test set\n", - "trainer.test(datamodule=mnist,\n", - " ckpt_path=None) # uses last-saved model\n", - "\n", - "wandb.finish()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "> _Note_: In notebooks, we need to call `wandb.finish()` to indicate when we've finished our run. This isn't necessary in scripts." 
- ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Viewing the results on wandb.ai\n", - "\n", - "Among the outputs from W&B,\n", - "you will have noticed a few URLs.\n", - "One of these is the\n", - "[run page](https://docs.wandb.ai/ref/app/pages/run-page),\n", - "which has a dashboard with all of the information logged in this run, complete with smart default charts\n", - "and more.\n", - "The run page is printed both at the start and end of training, and ends with `lit-wandb/runs/{run_id}`.\n", - "\n", - ">_Note_: When visiting your run page, it is recommended to use `global_step` as x-axis to correctly superimpose metrics logged in different stages.\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "\n", - "![image.png]()" - ] - } - ], - "metadata": { - "accelerator": "GPU", - "colab": { - "include_colab_link": true, - "provenance": [], - "toc_visible": true - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "z5w0NbyRVCKD" + }, + "source": [ + "\"Open\n", + "" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "7PI6YWwrVCKF" + }, + "source": [ + "\"Weights\n", + "\n", + "\n", + "\n", + "# ⚡ 💘 🏋️‍♀️ Supercharge your Training with PyTorch Lightning + Weights & Biases" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "mPqv_lacVCKF" + }, + "source": [ + "\"Weights" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "zJwmNjhlVCKF" + }, + "source": [ + "At Weights & Biases, we love anything\n", + "that makes training deep learning models easier.\n", + "That's why we worked with the folks at PyTorch Lightning to\n", + "[integrate our experiment tracking tool](https://docs.wandb.com/library/integrations/lightning)\n", + "directly into\n", + "[the Lightning 
library](https://pytorch-lightning.readthedocs.io/en/latest/common/loggers.html#weights-and-biases).\n", + "\n", + "[PyTorch Lightning](https://pytorch-lightning.readthedocs.io/en/stable/) is a lightweight wrapper for organizing your PyTorch code and easily adding advanced features such as distributed training and 16-bit precision.\n", + "It retains all the flexibility of PyTorch,\n", + "in case you need it,\n", + "but adds some useful abstractions\n", + "and builds in some best practices.\n", + "\n", + "## What this notebook covers:\n", + "\n", + "1. Differences between PyTorch and PyTorch Lightning, including how to set up `LightningModules` and `LightningDataModules`\n", + "2. How to get basic metric logging with the [`WandbLogger`](https://pytorch-lightning.readthedocs.io/en/latest/common/loggers.html#weights-and-biases)\n", + "3. How to log media with W&B and fully customize logging with Lightning `Callbacks`\n", + "\n", + "## The interactive dashboard in W&B will look like this:\n", + "\n", + "![](https://i.imgur.com/lIbMyFR.png)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "3-dKMdt4VCKG" + }, + "source": [ + "## Follow along with a [video tutorial](http://wandb.me/lit-video)!" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Y-QLq_y4VCKG" + }, + "source": [ + "# 🚀 Installing and importing" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "LgFEVQPiVCKG" + }, + "source": [ + "`wandb` and `lightning` are both easily installable via [`pip`](https://pip.pypa.io/en/stable/)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "YKrH3BHxVCKG" + }, + "outputs": [], + "source": [ + "!pip install -qqq wandb lightning torchmetrics onnx" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "HyRgs61mVCKG" + }, + "source": [ + "PyTorch Lightning is built on top of PyTorch,\n", + "so we still need to import vanilla PyTorch."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "UksNGoWZVCKH" + }, + "outputs": [], + "source": [ + "# numpy for non-GPU array math\n", + "import numpy as np\n", + "\n", + "# 🍦 Vanilla PyTorch\n", + "import torch\n", + "from torch.nn import functional as F\n", + "from torch import nn\n", + "from torch.utils.data import DataLoader, random_split\n", + "\n", + "# 👀 Torchvision for CV\n", + "from torchvision.datasets import MNIST\n", + "from torchvision import transforms\n", + "\n", + "# remove slow mirror from list of MNIST mirrors\n", + "MNIST.mirrors = [mirror for mirror in MNIST.mirrors\n", + " if not mirror.startswith(\"http://yann.lecun.com\")]" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "kOwGIep1VCKH" + }, + "source": [ + "Much of Lightning is built on the [Modules](https://pytorch.org/docs/stable/generated/torch.nn.Module.html)\n", + "API from PyTorch,\n", + "but adds extra features\n", + "(like data loading and logging)\n", + "that are common to lots of PyTorch projects.\n", + "\n", + "Let's bring those in,\n", + "plus W&B and the integration.\n", + "\n", + "Lastly, we log in to the [Weights & Biases web service](https://wandb.ai).\n", + "If you've never used W&B,\n", + "you'll need to sign up first.\n", + "Accounts are free forever for academic and public projects." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "6SpApyF_VCKH" + }, + "outputs": [], + "source": [ + "# ⚡ PyTorch Lightning\n", + "import lightning.pytorch as pl\n", + "import torchmetrics\n", + "pl.seed_everything(hash(\"setting random seeds\") % 2**32 - 1)\n", + "\n", + "# 🏋️‍♀️ Weights & Biases\n", + "import wandb\n", + "\n", + "# ⚡ 🤝 🏋️‍♀️\n", + "from lightning.pytorch.loggers import WandbLogger\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "WLJp3GnxVCKH" + }, + "outputs": [], + "source": [ + "wandb.login()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "bw6unCdQVCKH" + }, + "source": [ + "> _Note_: If you're executing your training in a terminal, rather than a notebook, you don't need to include `wandb.login()` in your script.\n", + "Instead, call `wandb login` in the terminal and we'll keep you logged in for future runs." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "VHlNYD9iVCKH" + }, + "source": [ + "# 🏗️ Building a Model with Lightning\n", + "\n", + "In PyTorch Lightning, models are built with `LightningModule` ([docs here](https://pytorch-lightning.readthedocs.io/en/latest/lightning_module.html)), which has all the functionality of a vanilla `torch.nn.Module` (🍦) but with a few delicious cherries of added functionality on top (🍨).\n", + "These cherries are there to cut down on boilerplate and\n", + "help separate out the ML engineering code\n", + "from the actual machine learning.\n", + "\n", + "For example, the mechanics of iterating over batches\n", + "as part of an epoch are extracted away,\n", + "so long as you define what happens on the `training_step`.\n", + "\n", + "To make a working model out of a `LightningModule`,\n", + "we need to define a new `class` and add a few methods on top.\n", + "\n", + "We'll demonstrate this process with `LitMLP`,\n", + "which applies a two-layer perceptron\n", + "(aka two fully-connected layers 
and\n", + "a fully-connected softmax readout layer)\n", + "to input `Tensors`.\n", + "\n", + "> _Note_: It is common in the Lightning community to shorten \"Lightning\" to \"[Lit](https://www.urbandictionary.com/define.php?term=it%27s%20lit)\".\n", + "This sometimes makes it sound like\n", + "[your code was written by Travis Scott](https://www.youtube.com/watch?v=y3FCXV8oEZU).\n", + "We consider this a good thing." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "RkIVeyM9VCKI" + }, + "source": [ + "## 🍦 `__init__` and `forward`\n", + "\n", + "First, we need to add two methods that\n", + "are part of any vanilla PyTorch model." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "coQH1peAVCKI" + }, + "source": [ + "Those methods are:\n", + "* `__init__` to do any setup, just like any Python class\n", + "* `forward` for inference, just like a PyTorch Module\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "xr46KZ1ZVCKI" + }, + "source": [ + "The `forward` pass method is standard,\n", + "and it'll be different for every project,\n", + "so we won't comment on it.\n", + "\n", + "The `__init__` method,\n", + "which `init`ializes new instances of the class,\n", + "is a good place to log hyperparameter information to `wandb`.\n", + "\n", + "This is done with the `save_hyperparameters` method,\n", + "which captures all of the arguments to the initializer\n", + "and adds them to a dictionary at `self.hparams` --\n", + "that all comes for free as part of the `LightningModule`.\n", + "\n", + "> _Note_: `hparams` is logged to `wandb` as the `config`,\n", + "so you'll never lose track of the arguments you used to run a model again!"
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "-16a4P3nVCKI" + }, + "outputs": [], + "source": [ + "class LitMLP(pl.LightningModule):\n", + "\n", + " def __init__(self, in_dims, n_classes=10,\n", + " n_layer_1=128, n_layer_2=256, lr=1e-4):\n", + " super().__init__()\n", + "\n", + " # we flatten the input Tensors and pass them through an MLP\n", + " self.layer_1 = nn.Linear(np.prod(in_dims), n_layer_1)\n", + " self.layer_2 = nn.Linear(n_layer_1, n_layer_2)\n", + " self.layer_3 = nn.Linear(n_layer_2, n_classes)\n", + "\n", + " # log hyperparameters\n", + " self.save_hyperparameters()\n", + "\n", + " # compute the accuracy -- no need to roll your own!\n", + " self.train_acc = torchmetrics.Accuracy(task=\"multiclass\", num_classes=n_classes)\n", + " self.valid_acc = torchmetrics.Accuracy(task=\"multiclass\", num_classes=n_classes)\n", + " self.test_acc = torchmetrics.Accuracy(task=\"multiclass\", num_classes=n_classes)\n", + "\n", + " def forward(self, x):\n", + " \"\"\"\n", + " Defines a forward pass using the Stem-Learner-Task\n", + " design pattern from Deep Learning Design Patterns:\n", + " https://www.manning.com/books/deep-learning-design-patterns\n", + " \"\"\"\n", + " batch_size, *dims = x.size()\n", + "\n", + " # stem: flatten\n", + " x = x.view(batch_size, -1)\n", + "\n", + " # learner: two fully-connected layers\n", + " x = F.relu(self.layer_1(x))\n", + " x = F.relu(self.layer_2(x))\n", + "\n", + " # task: compute class logits\n", + " x = self.layer_3(x)\n", + " x = F.log_softmax(x, dim=1)\n", + "\n", + " return x\n", + "\n", + " # convenient method to get the loss on a batch\n", + " def loss(self, xs, ys):\n", + " logits = self(xs) # this calls self.forward\n", + " loss = F.nll_loss(logits, ys)\n", + " return logits, loss" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "XHIwTEqwVCKI" + }, + "source": [ + "> _Note_: for pedagogical purposes, we're splitting out\n", + "each stage of building the 
`LitMLP` into a different cell.\n", + "In a more typical workflow,\n", + "this would all happen in the `class` definition.\n", + "\n", + "> _Note_: if you're familiar with PyTorch,\n", + "you might be surprised to see we aren't taking care with `.device`s:\n", + "no `to_cuda` etc. PyTorch Lightning handles all that for you! 😎" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "q9n788xQVCKI" + }, + "source": [ + "## 🍨 `training_step` and `configure_optimizers`\n", + "Now, we add some special methods so that our `LitMLP` can be trained\n", + "using PyTorch Lightning's training API.\n", + "\n", + "> _Note_: if you've used Keras, this might be familiar.\n", + "It's very similar to the `.fit` API in that library." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "0PTKSV_4VCKI" + }, + "source": [ + "Those methods are\n", + "\n", + "* `training_step`, which takes a batch and computes the loss; backprop goes through it\n", + "* `configure_optimizers`, which returns the `torch.optim.Optimizer` to apply after the `training_step`" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "KL-Yuk-DVCKI" + }, + "source": [ + "> _Note_: `training_step` is part of a rich system of callbacks in PyTorch Lightning.\n", + "These callbacks are methods that get called\n", + "at specific points during training\n", + "(e.g. when a validation epoch ends),\n", + "and they are a major part of what makes\n", + "PyTorch Lightning both useful and extensible." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "S_xhIcNbVCKI" + }, + "source": [ + "Here's where we add some more serious logging code.\n", + "`self.log` takes a name and value for a metric.\n", + "Under the hood, this will get passed to `wandb.log` if you're using W&B.\n", + "\n", + "The logging behavior of PyTorch Lightning is both intelligent and configurable.\n", + "For example, by passing the `on_epoch`\n", + "keyword argument here,\n", + "we'll get `_epoch`-wise averages\n", + "of the metrics logged on each `_step`,\n", + "and those metrics will be named differently\n", + "in the W&B interface.\n", + "When training in a distributed setting,\n", + "these averages will be automatically computed across nodes.\n", + "\n", + "Read more about the `log` method [in the docs](https://pytorch-lightning.readthedocs.io/en/latest/lightning_module.html#log)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "IpbUYwBDVCKJ" + }, + "outputs": [], + "source": [ + "def training_step(self, batch, batch_idx):\n", + " xs, ys = batch\n", + " logits, loss = self.loss(xs, ys)\n", + " preds = torch.argmax(logits, 1)\n", + "\n", + " # logging metrics we calculated by hand\n", + " self.log('train/loss', loss, on_epoch=True)\n", + " # logging a pl.Metric\n", + " self.train_acc(preds, ys)\n", + " self.log('train/acc', self.train_acc, on_epoch=True)\n", + "\n", + " return loss\n", + "\n", + "def configure_optimizers(self):\n", + " return torch.optim.Adam(self.parameters(), lr=self.hparams[\"lr\"])\n", + "\n", + "LitMLP.training_step = training_step\n", + "LitMLP.configure_optimizers = configure_optimizers" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "JXquZkggVCKJ" + }, + "source": [ + "## ➕ Optional methods for even better logging\n", + "\n", + "The code above will log our model's performance,\n", + "system metrics, and more to W&B.\n", + "\n", + "If we want to take our logging to the next level,\n", + "we need to 
make use of PyTorch Lightning's callback system.\n", + "\n", + "> _Note_: thanks to the clean design of PyTorch Lightning,\n", + "the training code below will run with or without any\n", + "of this extra logging code. Nice!" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "TE8qYwtOVCKJ" + }, + "source": [ + "The other callbacks we'll make use of fall into two categories:\n", + "* methods that trigger on each batch for a dataset: `validation_step` and `test_step`\n", + "* methods that trigger at the end of an epoch,\n", + "or a full pass over a given dataset: `on_{train, validation, test}_epoch_end`" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Z8JylsGXVCKJ" + }, + "source": [ + "### 💾 `test`ing and saving the model" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "AHy-SYLDVCKJ" + }, + "source": [ + "We use the test set to evaluate the performance of the final model,\n", + "so the `test` callbacks will be called at the end of the training pipeline.\n", + "\n", + "For performance on the `test` and `validation` sets,\n", + "we're typically less concerned about how\n", + "we do on intermediate steps and more\n", + "with how we did overall.\n", + "That's why below, we pass in\n", + "`on_step=False` and `on_epoch=True`\n", + "so that we log only `epoch`-wise metrics.\n", + "\n", + "> _Note_: That's actually the default behavior for `.log` when it's called inside of a `validation` or a `test` loop -- but not when it's called inside a `training` loop! Check out the table of default behaviors for `.log` [in the docs](https://pytorch-lightning.readthedocs.io/en/latest/lightning_module.html#log)."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "yeVmedTTVCKJ" + }, + "outputs": [], + "source": [ + "def test_step(self, batch, batch_idx):\n", + " xs, ys = batch\n", + " logits, loss = self.loss(xs, ys)\n", + " preds = torch.argmax(logits, 1)\n", + "\n", + " self.test_acc(preds, ys)\n", + " self.log(\"test/loss_epoch\", loss, on_step=False, on_epoch=True)\n", + " self.log(\"test/acc_epoch\", self.test_acc, on_step=False, on_epoch=True)\n", + "\n", + "LitMLP.test_step = test_step" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "qk50ovviVCKJ" + }, + "source": [ + "We'll also take the opportunity to save the model in the\n", + "[portable `ONNX` format](https://onnx.ai/).\n", + "\n", + "\n", + "Later,\n", + "we'll see that this allows us to use the\n", + "[Netron model viewer](https://github.com/lutzroeder/netron) in W&B." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "QNr-9bEpVCKJ" + }, + "outputs": [], + "source": [ + "def on_test_epoch_end(self): # args are defined as part of pl API\n", + " dummy_input = torch.zeros(self.hparams[\"in_dims\"], device=self.device)\n", + " model_filename = \"model_final.onnx\"\n", + " self.to_onnx(model_filename, dummy_input, export_params=True)\n", + " artifact = wandb.Artifact(name=\"model.ckpt\", type=\"model\")\n", + " artifact.add_file(model_filename)\n", + " wandb.log_artifact(artifact)\n", + "\n", + "LitMLP.on_test_epoch_end = on_test_epoch_end" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "JosFHArGVCKK" + }, + "source": [ + "### 📊 Logging `Histograms`" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "FxViSjO4VCKK" + }, + "source": [ + "For the `validation_data`,\n", + "let's track not only the `acc`uracy and `loss`,\n", + "but also the `logits`:\n", + "the un-normalized class probabilities.\n", + "That way, we can track if our network\n", + "is becoming more or less confident over time.\n", + 
"\n", + "There's a problem though:\n", + "`.log` wants to average,\n", + "but we'd rather look at a distribution.\n", + "\n", + "So instead, on every `validation_step`,\n", + "we'll `return` the `logits`,\n", + "rather than `log`ging them.\n", + "\n", + "Then, when we reach the `end`\n", + "of the `validation_epoch`,\n", + "the `logits` are available as the\n", + "`validation_step_outputs` -- a list.\n", + "\n", + "So to log we'll take those `logits`,\n", + "concatenate them together,\n", + "and turn them into a histogram with [`wandb.Histogram`](https://docs.wandb.com/library/log#histograms).\n", + "\n", + "Because we're no longer using Lightning's `.log` interface and are instead using `wandb`,\n", + "we need to drop down a level and use\n", + "`self.logger.experiment.log`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "VrPs4tAJVCKK" + }, + "outputs": [], + "source": [ + "def on_validation_epoch_start(self):\n", + " self.validation_step_outputs = []\n", + "\n", + "def validation_step(self, batch, batch_idx):\n", + " xs, ys = batch\n", + " logits, loss = self.loss(xs, ys)\n", + " preds = torch.argmax(logits, 1)\n", + " self.valid_acc(preds, ys)\n", + "\n", + " self.log(\"valid/loss_epoch\", loss) # default on val/test is on_epoch only\n", + " self.log('valid/acc_epoch', self.valid_acc)\n", + "\n", + " self.validation_step_outputs.append(logits)\n", + "\n", + " return logits\n", + "\n", + "def on_validation_epoch_end(self):\n", + "\n", + " validation_step_outputs = self.validation_step_outputs\n", + "\n", + " dummy_input = torch.zeros(self.hparams[\"in_dims\"], device=self.device)\n", + " model_filename = f\"model_{str(self.global_step).zfill(5)}.onnx\"\n", + " torch.onnx.export(self, dummy_input, model_filename, opset_version=11)\n", + " artifact = wandb.Artifact(name=\"model.ckpt\", type=\"model\")\n", + " artifact.add_file(model_filename)\n", + " self.logger.experiment.log_artifact(artifact)\n", + "\n", + "
flattened_logits = torch.flatten(torch.cat(validation_step_outputs))\n", + " self.logger.experiment.log(\n", + " {\"valid/logits\": wandb.Histogram(flattened_logits.to(\"cpu\")),\n", + " \"global_step\": self.global_step})\n", + "\n", + "LitMLP.on_validation_epoch_start = on_validation_epoch_start\n", + "LitMLP.validation_step = validation_step\n", + "LitMLP.on_validation_epoch_end = on_validation_epoch_end" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "617Vro8jVCKK" + }, + "source": [ + "Note that we're once again saving\n", + "the model in ONNX format.\n", + "That way, we can roll back our model to any given epoch --\n", + "useful in case the evaluation on the test set reveals we've overfit." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "yMe2e8FEVCKK" + }, + "source": [ + "### 📲 `Callback`s for extra-fancy logging" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "X-DfYBQrVCKL" + }, + "source": [ + "What we've done so far\n", + "will tell us how well our model\n", + "is using our system resources,\n", + "how well our model is training and generalizing,\n", + "and how confident it is.\n", + "\n", + "But DNNs often fail in pernicious and silent ways.\n", + "Often, the only way to notice these failures\n", + "is to look at how the model is doing\n", + "on specific examples.\n", + "\n", + "So let's additionally log some detailed information on some specific examples:\n", + "the inputs, outputs,\n", + "and `pred`ictions.\n", + "\n", + "We'll do this by writing our own `Callback` --\n", + "one that, after every `validation_epoch` ends,\n", + "logs input images and output predictions\n", + "using W&B's `Image` logger.\n", + "\n", + "> _Note_:\n", + "For more on the W&B media toolkit, read the [docs](https://docs.wandb.com/library/log#media)\n", + "or check out\n", + "[this Colab](http://wandb.me/media-colab)\n", + "to see everything it's capable of." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "VM9hxnNtVCKL" + }, + "outputs": [], + "source": [ + "class ImagePredictionLogger(pl.Callback):\n", + " def __init__(self, val_samples, num_samples=32):\n", + " super().__init__()\n", + " self.val_imgs, self.val_labels = val_samples\n", + " self.val_imgs = self.val_imgs[:num_samples]\n", + " self.val_labels = self.val_labels[:num_samples]\n", + "\n", + " def on_validation_epoch_end(self, trainer, pl_module):\n", + " val_imgs = self.val_imgs.to(device=pl_module.device)\n", + "\n", + " logits = pl_module(val_imgs)\n", + " preds = torch.argmax(logits, 1)\n", + "\n", + " trainer.logger.experiment.log({\n", + " \"examples\": [wandb.Image(x, caption=f\"Pred:{pred}, Label:{y}\")\n", + " for x, pred, y in zip(val_imgs, preds, self.val_labels)],\n", + " \"global_step\": trainer.global_step\n", + " })" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "6z2bcdwZVCKL" + }, + "source": [ + "# 🛒 Loading data\n", + "\n", + "Data pipelines can be created with:\n", + "* 🍦 Vanilla PyTorch `DataLoaders`\n", + "* ⚡ PyTorch Lightning `DataModules`\n", + "\n", + "`DataModules` are a more structured definition, which allows for additional optimizations such as automated distribution of workload between CPU & GPU.\n", + "Using `DataModules` is recommended whenever possible!\n", + "\n", + "A `DataModule` is also defined by an interface:\n", + "* `prepare_data` (optional), which is called only once and on 1 GPU -- typically something like the data download step we have below\n", + "* `setup`, which is called on each GPU separately and accepts `stage` to define whether we are at the `fit` or `test` step\n", + "* `train_dataloader`, `val_dataloader` and `test_dataloader` to load each dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "pO5M1XEnVCKL" + }, + "outputs": [], + "source": [ + "class MNISTDataModule(pl.LightningDataModule):\n", + "\n", + " def
__init__(self, data_dir='./', batch_size=128):\n", + " super().__init__()\n", + " self.data_dir = data_dir\n", + " self.batch_size = batch_size\n", + " self.transform = transforms.Compose([\n", + " transforms.ToTensor(),\n", + " transforms.Normalize((0.1307,), (0.3081,))])\n", + "\n", + " def prepare_data(self):\n", + " # download data, train then test\n", + " MNIST(self.data_dir, train=True, download=True)\n", + " MNIST(self.data_dir, train=False, download=True)\n", + "\n", + " def setup(self, stage=None):\n", + "\n", + " # we set up only relevant datasets when stage is specified\n", + " if stage == 'fit' or stage is None:\n", + " mnist = MNIST(self.data_dir, train=True, transform=self.transform)\n", + " self.mnist_train, self.mnist_val = random_split(mnist, [55000, 5000])\n", + " if stage == 'test' or stage is None:\n", + " self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)\n", + "\n", + " # we define a separate DataLoader for each of train/val/test\n", + " def train_dataloader(self):\n", + " mnist_train = DataLoader(self.mnist_train, batch_size=self.batch_size)\n", + " return mnist_train\n", + "\n", + " def val_dataloader(self):\n", + " mnist_val = DataLoader(self.mnist_val, batch_size=10 * self.batch_size)\n", + " return mnist_val\n", + "\n", + " def test_dataloader(self):\n", + " mnist_test = DataLoader(self.mnist_test, batch_size=10 * self.batch_size)\n", + " return mnist_test\n", + "\n", + "# setup data\n", + "mnist = MNISTDataModule()\n", + "mnist.prepare_data()\n", + "mnist.setup()\n", + "\n", + "# grab samples to log predictions on\n", + "samples = next(iter(mnist.val_dataloader()))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "6Ord5Pu3VCKM" + }, + "source": [ + "# 👟 Making a `Trainer`\n", + "\n", + "The `DataLoader` and the `LightningModule`\n", + "are brought together by a `Trainer`,\n", + "which orchestrates data loading,\n", + "gradient calculation,\n", + "optimizer logic,\n", + "and logging.\n", + "\n", 
+ "Luckily, we don't need to sub-class the `Trainer`;\n", + "we just need to configure it with keyword arguments." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "MDDjVOSEVCKM" + }, + "source": [ + "And that is where we'll use the `lightning.pytorch.loggers.WandbLogger` to connect our logging to W&B." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "3nibeocmVCKM" + }, + "outputs": [], + "source": [ + "wandb_logger = WandbLogger(project=\"lit-wandb\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "YDEryNsoVCKM" + }, + "source": [ + "> _Note_: Check out [the documentation](https://docs.wandb.com/library/integrations/lightning) for customization options. I like `group`s and `tag`s!\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "NLfJc41aVCKM" + }, + "source": [ + "We can then set up our `Trainer` and customize several options, such as gradient accumulation, half-precision training and distributed computing.\n", + "\n", + "We'll stick to the basics for this example,\n", + "but half-precision training and easy scaling to distributed settings are two of the major reasons why folks like PyTorch Lightning!"
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "0dciVNKHVCKN" + }, + "outputs": [], + "source": [ + "trainer = pl.Trainer(\n", + " logger=wandb_logger, # W&B integration\n", + " log_every_n_steps=50, # set the logging frequency\n", + " max_epochs=5, # number of epochs\n", + " deterministic=True, # keep it deterministic\n", + " callbacks=[ImagePredictionLogger(samples)] # see Callbacks section\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Fs6hyyU2VCKN" + }, + "source": [ + "# 🏃‍♀️ Running our Model" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "LzVnPc6_VCKN" + }, + "source": [ + "Now, let's make it all happen:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "q7VxL_o6VCKN" + }, + "outputs": [], + "source": [ + "# setup model\n", + "model = LitMLP(in_dims=(1, 28, 28))\n", + "\n", + "# fit the model\n", + "trainer.fit(model, mnist)\n", + "\n", + "# evaluate the model on a test set\n", + "trainer.test(datamodule=mnist,\n", + " ckpt_path=None) # uses last-saved model\n", + "\n", + "wandb.finish()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "cryJvmKOVCKN" + }, + "source": [ + "> _Note_: In notebooks, we need to call `wandb.finish()` to indicate when we've finished our run. This isn't necessary in scripts." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "t9jsRs-oVCKN" + }, + "source": [ + "## Viewing the results on wandb.ai\n", + "\n", + "Among the outputs from W&B,\n", + "you will have noticed a few URLs.\n", + "One of these is the\n", + "[run page](https://docs.wandb.ai/ref/app/pages/run-page),\n", + "which has a dashboard with all of the information logged in this run, complete with smart default charts\n", + "and more.\n", + "The run page is printed both at the start and end of training, and ends with `lit-wandb/runs/{run_id}`.\n", + "\n", + ">_Note_: When visiting your run page, it is recommended to use `global_step` as x-axis to correctly superimpose metrics logged in different stages.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "fgPkQGamVCKN" + }, + "source": [ + "\n", + "![image.png]()" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file