diff --git a/colabs/pytorch-lightning/Wandb_End_to_End_with_PyTorch_Lightning.ipynb b/colabs/pytorch-lightning/Wandb_End_to_End_with_PyTorch_Lightning.ipynb
index 16d7a3b6..32a5df77 100644
--- a/colabs/pytorch-lightning/Wandb_End_to_End_with_PyTorch_Lightning.ipynb
+++ b/colabs/pytorch-lightning/Wandb_End_to_End_with_PyTorch_Lightning.ipynb
@@ -1,681 +1,737 @@
{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "\n",
- ""
- ]
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Ws0nlkuGOpDy"
+ },
+ "source": [
+ "\n",
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "aU1q92uCOpD1"
+ },
+ "source": [
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "# W&B Tutorial with Pytorch Lightning"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "mHCKzEzSOpD1"
+ },
+ "source": [
+ "## 🛠️ Install `wandb` and `pytorch-lightning`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "RnyGWvwDOpD1"
+ },
+ "outputs": [],
+ "source": [
+ "!pip install -q lightning wandb torchvision"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "G5wpAEAoOpD2"
+ },
+ "source": [
+ "## Login to W&B either through Python or CLI\n",
+ "If you are using the public W&B cloud, you don't need to specify the `WANDB_HOST`.\n",
+ "\n",
+ "You can set environment variables `WANDB_API_KEY` and `WANDB_HOST` and pass them in as:\n",
+ "```\n",
+ "import os\n",
+ "import wandb\n",
+ "\n",
+ "wandb.login(host=os.getenv(\"WANDB_HOST\"), key=os.getenv(\"WANDB_API_KEY\"))\n",
+ "```\n",
+ "You can also login via the CLI with:\n",
+ "```\n",
+ "wandb login --host \n",
+ "```"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "YjXGiBLQOpD2"
+ },
+ "outputs": [],
+ "source": [
+ "import wandb"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "r5HflvclOpD2"
+ },
+ "outputs": [],
+ "source": [
+ "wandb.login()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "zNuxn8BjOpD2"
+ },
+ "source": [
+ "## ⚱ Logging the Raw Training Data as an Artifact"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "X9mkr942OpD2"
+ },
+ "outputs": [],
+ "source": [
+ "#@title Enter your W&B project and entity\n",
+ "\n",
+ "# FORM VARIABLES\n",
+ "PROJECT_NAME = \"pytorch-lightning-e2e\" #@param {type:\"string\"}\n",
+ "ENTITY = \"wandb\"#@param {type:\"string\"}\n",
+ "\n",
+ "# set SIZE to \"TINY\", \"SMALL\", \"MEDIUM\", or \"LARGE\"\n",
+ "# to select one of these three datasets\n",
+ "# TINY dataset: 100 images, 30MB\n",
+ "# SMALL dataset: 1000 images, 312MB\n",
+ "# MEDIUM dataset: 5000 images, 1.5GB\n",
+ "# LARGE dataset: 12,000 images, 3.6GB\n",
+ "\n",
+ "SIZE = \"TINY\"\n",
+ "\n",
+ "if SIZE == \"TINY\":\n",
+ " src_url = \"https://storage.googleapis.com/wandb_datasets/nature_100.zip\"\n",
+ " src_zip = \"nature_100.zip\"\n",
+ " DATA_SRC = \"nature_100\"\n",
+ " IMAGES_PER_LABEL = 10\n",
+ " BALANCED_SPLITS = {\"train\" : 8, \"val\" : 1, \"test\": 1}\n",
+ "elif SIZE == \"SMALL\":\n",
+ " src_url = \"https://storage.googleapis.com/wandb_datasets/nature_1K.zip\"\n",
+ " src_zip = \"nature_1K.zip\"\n",
+ " DATA_SRC = \"nature_1K\"\n",
+ " IMAGES_PER_LABEL = 100\n",
+ " BALANCED_SPLITS = {\"train\" : 80, \"val\" : 10, \"test\": 10}\n",
+ "elif SIZE == \"MEDIUM\":\n",
+ " src_url = \"https://storage.googleapis.com/wandb_datasets/nature_12K.zip\"\n",
+ " src_zip = \"nature_12K.zip\"\n",
+ " DATA_SRC = \"inaturalist_12K/train\" # (technically a subset of only 10K images)\n",
+ " IMAGES_PER_LABEL = 500\n",
+ " BALANCED_SPLITS = {\"train\" : 400, \"val\" : 50, \"test\": 50}\n",
+ "elif SIZE == \"LARGE\":\n",
+ " src_url = \"https://storage.googleapis.com/wandb_datasets/nature_12K.zip\"\n",
+ " src_zip = \"nature_12K.zip\"\n",
+ " DATA_SRC = \"inaturalist_12K/train\" # (technically a subset of only 10K images)\n",
+ " IMAGES_PER_LABEL = 1000\n",
+ " BALANCED_SPLITS = {\"train\" : 800, \"val\" : 100, \"test\": 100}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "o8nApIhdOpD3"
+ },
+ "outputs": [],
+ "source": [
+ "%%capture\n",
+ "!curl -SL $src_url > $src_zip\n",
+ "!unzip $src_zip"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
+ ],
+ "metadata": {
+ "id": "ALmdQ7wISLaA"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "XQ_Kwsg9OpD3"
+ },
+ "outputs": [],
+ "source": [
+ "import wandb\n",
+ "import pandas as pd\n",
+ "import os\n",
+ "\n",
+ "with wandb.init(project=PROJECT_NAME, entity=ENTITY, job_type='log_datasets') as run:\n",
+ " img_paths = []\n",
+ " for root, dirs, files in os.walk('nature_100', topdown=False):\n",
+ " for name in files:\n",
+ " img_path = os.path.join(root, name)\n",
+ " label = img_path.split('/')[1]\n",
+ " img_paths.append([img_path, label])\n",
+ "\n",
+ " index_df = pd.DataFrame(columns=['image_path', 'label'], data=img_paths)\n",
+ " index_df.to_csv('index.csv', index=False)\n",
+ "\n",
+ " train_art = wandb.Artifact(name='Nature_100', type='raw_images', description='nature image dataset with 10 classes, 10 images per class')\n",
+ " train_art.add_dir('nature_100')\n",
+ "\n",
+ " # Also adding a csv indicating the labels of each image\n",
+ " train_art.add_file('index.csv')\n",
+ " wandb.log_artifact(train_art)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "1eJpOk_VOpD3"
+ },
+ "source": [
+ "## Using Artifacts in Pytorch Lightning `DataModule`'s and Pytorch `Dataset`'s\n",
+ "- Makes it easy to interopt your DataLoaders with new versions of datasets\n",
+ "- Just indicate the `name:alias` as an argument to your `Dataset` or `DataModule`\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "Z2g9JRrwOpD3"
+ },
+ "outputs": [],
+ "source": [
+ "from torchvision import transforms\n",
+ "import lightning.pytorch as pl\n",
+ "import torch\n",
+ "from torch.utils.data import Dataset, DataLoader, random_split\n",
+ "from skimage import io, transform\n",
+ "from torchvision import transforms, utils, models\n",
+ "import math\n",
+ "\n",
+ "class NatureDataset(Dataset):\n",
+ " def __init__(self,\n",
+ " wandb_run,\n",
+ " artifact_name_alias=\"Nature_100:latest\",\n",
+ " local_target_dir=\"Nature_100:latest\",\n",
+ " transform=None):\n",
+ " self.local_target_dir = local_target_dir\n",
+ " self.transform = transform\n",
+ "\n",
+ " # Pull down the artifact locally to load it into memory\n",
+ " art = wandb_run.use_artifact(artifact_name_alias)\n",
+ " path_at = art.download(root=self.local_target_dir)\n",
+ "\n",
+ " self.ref_df = pd.read_csv(os.path.join(self.local_target_dir, 'index.csv'))\n",
+ " self.class_names = self.ref_df.iloc[:, 1].unique().tolist()\n",
+ " self.idx_to_class = {k: v for k, v in enumerate(self.class_names)}\n",
+ " self.class_to_idx = {v: k for k, v in enumerate(self.class_names)}\n",
+ "\n",
+ " def __len__(self):\n",
+ " return len(self.ref_df)\n",
+ "\n",
+ " def __getitem__(self, idx):\n",
+ " if torch.is_tensor(idx):\n",
+ " idx = idx.tolist()\n",
+ "\n",
+ " img_path = self.ref_df.iloc[idx, 0]\n",
+ "\n",
+ " image = io.imread(img_path)\n",
+ " label = self.ref_df.iloc[idx, 1]\n",
+ " label = torch.tensor(self.class_to_idx[label], dtype=torch.long)\n",
+ "\n",
+ " if self.transform:\n",
+ " image = self.transform(image)\n",
+ "\n",
+ " return image, label\n",
+ "\n",
+ "\n",
+ "class NatureDatasetModule(pl.LightningDataModule):\n",
+ " def __init__(self,\n",
+ " wandb_run,\n",
+ " artifact_name_alias: str = \"Nature_100:latest\",\n",
+ " local_target_dir: str = \"Nature_100:latest\",\n",
+ " batch_size: int = 16,\n",
+ " input_size: int = 224,\n",
+ " seed: int = 42):\n",
+ " super().__init__()\n",
+ " self.wandb_run = wandb_run\n",
+ " self.artifact_name_alias = artifact_name_alias\n",
+ " self.local_target_dir = local_target_dir\n",
+ " self.batch_size = batch_size\n",
+ " self.input_size = input_size\n",
+ " self.seed = seed\n",
+ "\n",
+ " def setup(self, stage=None):\n",
+ " self.nature_dataset = NatureDataset(wandb_run=self.wandb_run,\n",
+ " artifact_name_alias=self.artifact_name_alias,\n",
+ " local_target_dir=self.local_target_dir,\n",
+ " transform=transforms.Compose([transforms.ToTensor(),\n",
+ " transforms.CenterCrop(self.input_size),\n",
+ " transforms.Normalize((0.485, 0.456, 0.406),\n",
+ " (0.229, 0.224, 0.225))]))\n",
+ "\n",
+ " nature_length = len(self.nature_dataset)\n",
+ " train_size = math.floor(0.8 * nature_length)\n",
+ " val_size = math.floor(0.2 * nature_length)\n",
+ " self.nature_train, self.nature_val = random_split(self.nature_dataset,\n",
+ " [train_size, val_size],\n",
+ " generator=torch.Generator().manual_seed(self.seed))\n",
+ " return self\n",
+ "\n",
+ " def train_dataloader(self):\n",
+ " return DataLoader(self.nature_train, batch_size=self.batch_size)\n",
+ "\n",
+ " def val_dataloader(self):\n",
+ " return DataLoader(self.nature_val, batch_size=self.batch_size)\n",
+ "\n",
+ " def predict_dataloader(self):\n",
+ " pass\n",
+ "\n",
+ " def teardown(self, stage: str):\n",
+ " pass"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "NLhjdvF-OpD3"
+ },
+ "source": [
+ "##How Logging in your Pytorch `LightningModule`works:\n",
+ "When you train the model using `Trainer`, ensure you have a `WandbLogger` instantiated and passed in as a `logger`.\n",
+ "\n",
+ "```\n",
+ "wandb_logger = WandbLogger(project=\"my_project\", entity=\"machine-learning\")\n",
+ "trainer = Trainer(logger=wandb_logger)\n",
+ "```\n",
+ "\n",
+ "\n",
+ "You can always use `wandb.log` as normal throughout the module. When the `WandbLogger` is used, `self.log` will also log metrics to W&B.\n",
+ "- To access the current run from within the `LightningModule`, you can access `Trainer.logger.experiment`, which is a `wandb.Run` object"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Oi1NpQs7OpD4"
+ },
+ "source": [
+ "### Some helper functions"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "HLgYji2oOpD4"
+ },
+ "outputs": [],
+ "source": [
+ "# Some helper functions\n",
+ "\n",
+ "def set_parameter_requires_grad(model, feature_extracting):\n",
+ " if feature_extracting:\n",
+ " for param in model.parameters():\n",
+ " param.requires_grad = False\n",
+ "\n",
+ "def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True):\n",
+ " # Initialize these variables which will be set in this if statement. Each of these\n",
+ " # variables is model specific.\n",
+ " model_ft = None\n",
+ " input_size = 0\n",
+ "\n",
+ " if model_name == \"resnet\":\n",
+ " \"\"\" Resnet18\n",
+ " \"\"\"\n",
+ " model_ft = models.resnet18(pretrained=use_pretrained)\n",
+ " set_parameter_requires_grad(model_ft, feature_extract)\n",
+ " num_ftrs = model_ft.fc.in_features\n",
+ " model_ft.fc = torch.nn.Linear(num_ftrs, num_classes)\n",
+ " input_size = 224\n",
+ "\n",
+ " elif model_name == \"squeezenet\":\n",
+ " \"\"\" Squeezenet\n",
+ " \"\"\"\n",
+ " model_ft = models.squeezenet1_0(pretrained=use_pretrained)\n",
+ " set_parameter_requires_grad(model_ft, feature_extract)\n",
+ " model_ft.classifier[1] = torch.nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1))\n",
+ " model_ft.num_classes = num_classes\n",
+ " input_size = 224\n",
+ "\n",
+ " elif model_name == \"densenet\":\n",
+ " \"\"\" Densenet\n",
+ " \"\"\"\n",
+ " model_ft = models.densenet121(pretrained=use_pretrained)\n",
+ " set_parameter_requires_grad(model_ft, feature_extract)\n",
+ " num_ftrs = model_ft.classifier.in_features\n",
+ " model_ft.classifier = torch.nn.Linear(num_ftrs, num_classes)\n",
+ " input_size = 224\n",
+ "\n",
+ " else:\n",
+ " print(\"Invalid model name, exiting...\")\n",
+ " exit()\n",
+ "\n",
+ " return model_ft, input_size"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "LLRK0S6oOpD4"
+ },
+ "source": [
+ "### Writing the `LightningModule`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "k0tTtK5zOpD4"
+ },
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "from torch.nn import Linear, CrossEntropyLoss, functional as F\n",
+ "from torch.optim import Adam\n",
+ "from torchmetrics.functional import accuracy\n",
+ "from lightning.pytorch import LightningModule\n",
+ "from torchvision import models\n",
+ "\n",
+ "class NatureLitModule(LightningModule):\n",
+ " def __init__(self,\n",
+ " model_name,\n",
+ " num_classes=10,\n",
+ " feature_extract=True,\n",
+ " lr=0.01):\n",
+ " '''method used to define our model parameters'''\n",
+ " super().__init__()\n",
+ "\n",
+ " self.model_name = model_name\n",
+ " self.num_classes = num_classes\n",
+ " self.feature_extract = feature_extract\n",
+ " self.model, self.input_size = initialize_model(model_name=self.model_name,\n",
+ " num_classes=self.num_classes,\n",
+ " feature_extract=True)\n",
+ "\n",
+ " # loss\n",
+ " self.loss = CrossEntropyLoss()\n",
+ "\n",
+ " # optimizer parameters\n",
+ " self.lr = lr\n",
+ "\n",
+ " # save hyper-parameters to self.hparams (auto-logged by W&B)\n",
+ " self.save_hyperparameters()\n",
+ "\n",
+ " # Record the gradients of all the layers\n",
+ " wandb.watch(self.model)\n",
+ "\n",
+ " def forward(self, x):\n",
+ " '''method used for inference input -> output'''\n",
+ " x = self.model(x)\n",
+ "\n",
+ " return x\n",
+ "\n",
+ " def training_step(self, batch, batch_idx):\n",
+ " '''needs to return a loss from a single batch'''\n",
+ " preds, y, loss, acc = self._get_preds_loss_accuracy(batch)\n",
+ "\n",
+ " # Log loss and metric\n",
+ " self.log('train/loss', loss)\n",
+ " self.log('train/accuracy', acc)\n",
+ "\n",
+ " return loss\n",
+ "\n",
+ " def validation_step(self, batch, batch_idx):\n",
+ " '''used for logging metrics'''\n",
+ " preds, y, loss, acc = self._get_preds_loss_accuracy(batch)\n",
+ "\n",
+ " # Log loss and metric\n",
+ " self.log('validation/loss', loss)\n",
+ " self.log('validation/accuracy', acc)\n",
+ "\n",
+ " # Let's return preds to use it in a custom callback\n",
+ " return preds, y\n",
+ "\n",
+ " def test_step(self, batch, batch_idx):\n",
+ " '''used for logging metrics'''\n",
+ " preds, y, loss, acc = self._get_preds_loss_accuracy(batch)\n",
+ "\n",
+ " # Log loss and metric\n",
+ " self.log('test/loss', loss)\n",
+ " self.log('test/accuracy', acc)\n",
+ "\n",
+ " def configure_optimizers(self):\n",
+ " '''defines model optimizer'''\n",
+ " return Adam(self.parameters(), lr=self.lr)\n",
+ "\n",
+ "\n",
+ " def _get_preds_loss_accuracy(self, batch):\n",
+ " '''convenience function since train/valid/test steps are similar'''\n",
+ " x, y = batch\n",
+ " logits = self(x)\n",
+ " preds = torch.argmax(logits, dim=1)\n",
+ " loss = self.loss(logits, y)\n",
+ " acc = accuracy(preds, y, task=\"multiclass\", num_classes=10)\n",
+ " return preds, y, loss, acc"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "bvZHZuZzOpD4"
+ },
+ "source": [
+ "### Instrument Callbacks to log additional things at certain points in your code"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "j9pHqaw1OpD5"
+ },
+ "outputs": [],
+ "source": [
+ "from lightning.pytorch.callbacks import Callback\n",
+ "\n",
+ "class LogPredictionsCallback(Callback):\n",
+ "\n",
+ " def __init__(self):\n",
+ " super().__init__()\n",
+ "\n",
+ "\n",
+ " def on_validation_epoch_start(self, trainer, pl_module):\n",
+ " self.batch_dfs = []\n",
+ " self.image_list = []\n",
+ " self.val_table = wandb.Table(columns=['image', 'ground_truth', 'prediction'])\n",
+ "\n",
+ "\n",
+ " def on_validation_batch_end(\n",
+ " self, trainer, pl_module, outputs, batch, batch_idx):\n",
+ " \"\"\"Called when the validation batch ends.\"\"\"\n",
+ "\n",
+ " # Append validation predictions and ground truth to log in confusion matrix\n",
+ " x, y = batch\n",
+ " preds, y = outputs\n",
+ " self.batch_dfs.append(pd.DataFrame({\"Ground Truth\": y.cpu().numpy(), \"Predictions\": preds.cpu().numpy()}))\n",
+ "\n",
+ " # Add wandb.Image to a table to log at the end of validation\n",
+ " x = x.cpu().numpy().transpose(0, 2, 3, 1)\n",
+ " for x_i, y_i, y_pred in list(zip(x, y, preds)):\n",
+ " self.image_list.append(wandb.Image(x_i, caption=f'Ground Truth: {y_i} - Prediction: {y_pred}'))\n",
+ " self.val_table.add_data(wandb.Image(x_i), y_i, y_pred)\n",
+ "\n",
+ "\n",
+ " def on_validation_epoch_end(self, trainer, pl_module):\n",
+ " # Collect statistics for whole validation set and log\n",
+ " class_names = trainer.datamodule.nature_dataset.class_names\n",
+ " val_df = pd.concat(self.batch_dfs)\n",
+ " wandb.log({\"validation_table\": self.val_table,\n",
+ " \"images_over_time\": self.image_list,\n",
+ " \"validation_conf_matrix\": wandb.plot.confusion_matrix(y_true = val_df[\"Ground Truth\"].tolist(),\n",
+ " preds=val_df[\"Predictions\"].tolist(),\n",
+ " class_names=class_names)}, step=trainer.global_step)\n",
+ "\n",
+ " del self.batch_dfs\n",
+ " del self.val_table\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "jAiH-kb1OpD5"
+ },
+ "source": [
+ "## 🏋️ Main Training Loop"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "FrYTQ7uaOpD5"
+ },
+ "outputs": [],
+ "source": [
+ "from lightning.pytorch.callbacks import ModelCheckpoint\n",
+ "from lightning.pytorch.loggers import WandbLogger\n",
+ "from lightning.pytorch import Trainer\n",
+ "\n",
+ "wandb.init(project=PROJECT_NAME,\n",
+ " entity=ENTITY,\n",
+ " job_type='training',\n",
+ " config={\n",
+ " \"model_name\": \"squeezenet\",\n",
+ " \"batch_size\": 16\n",
+ " })\n",
+ "\n",
+ "wandb_logger = WandbLogger(log_model='all', checkpoint_name=f'nature-{wandb.run.id}')\n",
+ "\n",
+ "log_predictions_callback = LogPredictionsCallback()\n",
+ "checkpoint_callback = ModelCheckpoint(every_n_epochs=1)\n",
+ "\n",
+ "model = NatureLitModule(model_name=wandb.config['model_name']) # Access hyperparameters downstream to instantiate models/datasets\n",
+ "\n",
+ "nature_module = NatureDatasetModule(wandb_run = wandb_logger.experiment,\n",
+ " artifact_name_alias = \"Nature_100:latest\",\n",
+ " local_target_dir = \"Nature_100:latest\",\n",
+ " batch_size=wandb.config['batch_size'],\n",
+ " input_size=model.input_size)\n",
+ "nature_module.setup()\n",
+ "\n",
+ "trainer = Trainer(logger=wandb_logger, # W&B integration\n",
+ " callbacks=[log_predictions_callback, checkpoint_callback],\n",
+ " max_epochs=5,\n",
+ " log_every_n_steps=5)\n",
+ "trainer.fit(model, datamodule=nature_module)\n",
+ "\n",
+ "wandb.finish()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "uEII8J7UOpD5"
+ },
+ "source": [
+ "### Syncing with W&B Offline\n",
+ "If for some reason, network communication is lost during the course of training, you can always sync progress with `wandb sync`\n",
+ "\n",
+ "The W&B sdk caches all logged data in a local directory `wandb` and when you call `wandb sync`, this syncs the your local state with the web app."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "xi-EhHsxOpD5"
+ },
+ "source": [
+ "## Retrieve a model checkpoint artifact and resume training\n",
+ "- Artifacts make it easy to track state of your training remotely and then resume training from a checkpoint"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "jiE1Bk7fOpD5"
+ },
+ "outputs": [],
+ "source": [
+ "#@title Enter which checkpoint you want to resume training from:\n",
+ "\n",
+ "# FORM VARIABLES\n",
+ "ARTIFACT_NAME_ALIAS = \"nature-oyxk79m1:v4\" #@param {type:\"string\"}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "jZZXWRatOpD5"
+ },
+ "outputs": [],
+ "source": [
+ "wandb.init(project=PROJECT_NAME,\n",
+ " entity=ENTITY,\n",
+ " job_type='resume_training')\n",
+ "\n",
+ "# Retrieve model checkpoint artifact and restore previous hyperparameters\n",
+ "model_chkpt_art = wandb.use_artifact(f'{ENTITY}/{PROJECT_NAME}/{ARTIFACT_NAME_ALIAS}')\n",
+ "model_chkpt_art.download() # Can change download directory by adding `root`, defaults to \"./artifacts\"\n",
+ "logging_run = model_chkpt_art.logged_by()\n",
+ "wandb.config = logging_run.config\n",
+ "\n",
+ "# Can create a new artifact name or continue logging to the old one\n",
+ "artifact_name = ARTIFACT_NAME_ALIAS.split(\":\")[0]\n",
+ "wandb_logger = WandbLogger(log_model='all', checkpoint_name=artifact_name)\n",
+ "\n",
+ "log_predictions_callback = LogPredictionsCallback()\n",
+ "checkpoint_callback = ModelCheckpoint(every_n_epochs=1)\n",
+ "\n",
+ "model = NatureLitModule.load_from_checkpoint(f'./artifacts/{ARTIFACT_NAME_ALIAS}/model.ckpt')\n",
+ "\n",
+ "nature_module = NatureDatasetModule(wandb_run = wandb_logger.experiment,\n",
+ " artifact_name_alias = \"Nature_100:latest\",\n",
+ " local_target_dir = \"Nature_100:latest\",\n",
+ " batch_size=wandb.config['batch_size'],\n",
+ " input_size=model.input_size)\n",
+ "nature_module.setup()\n",
+ "\n",
+ "\n",
+ "\n",
+ "trainer = Trainer(logger=wandb_logger, # W&B integration\n",
+ " callbacks=[log_predictions_callback, checkpoint_callback],\n",
+ " max_epochs=10,\n",
+ " log_every_n_steps=5)\n",
+ "trainer.fit(model, datamodule=nature_module)\n",
+ "\n",
+ "wandb.finish()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "KTforR6EOpD5"
+ },
+ "source": [
+ "## Model Registry\n",
+ "After logging a bunch of checkpoints across multiple runs during experimentation, now comes time to hand-off the best checkpoint to the next stage of the workflow (e.g. testing, deployment).\n",
+ "\n",
+ "The model registry offers a centralized place to house the best checkpoints for all your model tasks. Any `model` artifact you log can be \"linked\" to a Registered Model. Here are the steps to start using the model registry for more organized model management:\n",
+ "1. Access your team's model registry by going the team page and selecting `Model Registry`\n",
+ "![model registry](https://drive.google.com/uc?export=view&id=1ZtJwBsFWPTm4Sg5w8vHhRpvDSeQPwsKw)\n",
+ "\n",
+ "2. Create a new Registered Model.\n",
+ "![model registry](https://drive.google.com/uc?export=view&id=1RuayTZHNE0LJCxt1t0l6-2zjwiV4aDXe)\n",
+ "\n",
+ "3. Go to the artifacts tab of the project that holds all your model checkpoints\n",
+ "![model registry](https://drive.google.com/uc?export=view&id=1r_jlhhtcU3as8VwQ-4oAntd8YtTwElFB)\n",
+ "\n",
+ "4. Click \"Link to Registry\" for the model artifact version you want. (Alternatively you can [link a model via api](https://docs.wandb.ai/guides/models) with `wandb.run.link_artifact`)\n",
+ "\n",
+ "**A note on linking:** The process of linking a model checkpoint is akin to \"bookmarking\" it. Each time you link a new model artifact to a Registered Model, this increments the version of the Registered Model. This helps delineate the model development side of the workflow from the model deployment/consumption side. The globally understood version/alias of a model should be unpolluted from all the experimental versions being generated in R&D and thus the versioning of a Registered Model increments according to new \"bookmarked\" models as opposed to model checkpoint logging.\n",
+ "\n",
+ "\n",
+ "### Create a Centralized Hub for all your models\n",
+ "- Add a model card, tags, slack notifactions to your Registered Model\n",
+ "- Change aliases to reflect when models move through different phases\n",
+ "- Embed the model registry in reports for model documentation and regression reports. See this report as an [example](https://api.wandb.ai/links/wandb-smle/r82bj9at)\n",
+ "![model registry](https://drive.google.com/uc?export=view&id=1lKPgaw-Ak4WK_91aBMcLvUMJL6pDQpgO)\n"
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "provenance": [],
+ "toc_visible": true
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ }
},
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "\n",
- "\n",
- "\n",
- "\n",
- "# W&B Tutorial with Pytorch Lightning"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## 🛠️ Install `wandb` and `pytorch-lightning`"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "!pip install -q lightning wandb torchvision"
- ]
- },
- {
- "attachments": {},
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Login to W&B either through Python or CLI\n",
- "If you are using the public W&B cloud, you don't need to specify the `WANDB_HOST`.\n",
- "\n",
- "You can set environment variables `WANDB_API_KEY` and `WANDB_HOST` and pass them in as:\n",
- "```\n",
- "import os\n",
- "import wandb \n",
- "\n",
- "wandb.login(host=os.getenv(\"WANDB_HOST\"), key=os.getenv(\"WANDB_API_KEY\"))\n",
- "```\n",
- "You can also login via the CLI with: \n",
- "```\n",
- "wandb login --host \n",
- "```"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "import wandb"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "wandb.login()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## ⚱ Logging the Raw Training Data as an Artifact"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "#@title Enter your W&B project and entity\n",
- "\n",
- "# FORM VARIABLES\n",
- "PROJECT_NAME = \"pytorch-lightning-e2e\" #@param {type:\"string\"}\n",
- "ENTITY = \"wandb\"#@param {type:\"string\"}\n",
- "\n",
- "# set SIZE to \"TINY\", \"SMALL\", \"MEDIUM\", or \"LARGE\"\n",
- "# to select one of these three datasets\n",
- "# TINY dataset: 100 images, 30MB\n",
- "# SMALL dataset: 1000 images, 312MB\n",
- "# MEDIUM dataset: 5000 images, 1.5GB\n",
- "# LARGE dataset: 12,000 images, 3.6GB\n",
- "\n",
- "SIZE = \"TINY\"\n",
- "\n",
- "if SIZE == \"TINY\":\n",
- " src_url = \"https://storage.googleapis.com/wandb_datasets/nature_100.zip\"\n",
- " src_zip = \"nature_100.zip\"\n",
- " DATA_SRC = \"nature_100\"\n",
- " IMAGES_PER_LABEL = 10\n",
- " BALANCED_SPLITS = {\"train\" : 8, \"val\" : 1, \"test\": 1}\n",
- "elif SIZE == \"SMALL\":\n",
- " src_url = \"https://storage.googleapis.com/wandb_datasets/nature_1K.zip\"\n",
- " src_zip = \"nature_1K.zip\"\n",
- " DATA_SRC = \"nature_1K\"\n",
- " IMAGES_PER_LABEL = 100\n",
- " BALANCED_SPLITS = {\"train\" : 80, \"val\" : 10, \"test\": 10}\n",
- "elif SIZE == \"MEDIUM\":\n",
- " src_url = \"https://storage.googleapis.com/wandb_datasets/nature_12K.zip\"\n",
- " src_zip = \"nature_12K.zip\"\n",
- " DATA_SRC = \"inaturalist_12K/train\" # (technically a subset of only 10K images)\n",
- " IMAGES_PER_LABEL = 500\n",
- " BALANCED_SPLITS = {\"train\" : 400, \"val\" : 50, \"test\": 50}\n",
- "elif SIZE == \"LARGE\":\n",
- " src_url = \"https://storage.googleapis.com/wandb_datasets/nature_12K.zip\"\n",
- " src_zip = \"nature_12K.zip\"\n",
- " DATA_SRC = \"inaturalist_12K/train\" # (technically a subset of only 10K images)\n",
- " IMAGES_PER_LABEL = 1000\n",
- " BALANCED_SPLITS = {\"train\" : 800, \"val\" : 100, \"test\": 100}"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "%%capture\n",
- "!curl -SL $src_url > $src_zip\n",
- "!unzip $src_zip"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "import wandb\n",
- "import pandas as pd\n",
- "import os\n",
- "\n",
- "with wandb.init(project=PROJECT_NAME, entity=ENTITY, job_type='log_datasets') as run:\n",
- " img_paths = []\n",
- " for root, dirs, files in os.walk('nature_100', topdown=False):\n",
- " for name in files:\n",
- " img_path = os.path.join(root, name)\n",
- " label = img_path.split('/')[1]\n",
- " img_paths.append([img_path, label])\n",
- "\n",
- " index_df = pd.DataFrame(columns=['image_path', 'label'], data=img_paths)\n",
- " index_df.to_csv('index.csv', index=False)\n",
- "\n",
- " train_art = wandb.Artifact(name='Nature_100', type='raw_images', description='nature image dataset with 10 classes, 10 images per class')\n",
- " train_art.add_dir('nature_100')\n",
- "\n",
- " # Also adding a csv indicating the labels of each image\n",
- " train_art.add_file('index.csv')\n",
- " wandb.log_artifact(train_art)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Using Artifacts in Pytorch Lightning `DataModule`'s and Pytorch `Dataset`'s\n",
- "- Makes it easy to interopt your DataLoaders with new versions of datasets\n",
- "- Just indicate the `name:alias` as an argument to your `Dataset` or `DataModule`\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "from torchvision import transforms\n",
- "import lightning.pytorch as pl\n",
- "import torch\n",
- "from torch.utils.data import Dataset, DataLoader, random_split\n",
- "from skimage import io, transform\n",
- "from torchvision import transforms, utils, models\n",
- "import math\n",
- "\n",
- "class NatureDataset(Dataset):\n",
- " def __init__(self, \n",
- " wandb_run, \n",
- " artifact_name_alias=\"Nature_100:latest\", \n",
- " local_target_dir=\"Nature_100:latest\", \n",
- " transform=None):\n",
- " self.local_target_dir = local_target_dir\n",
- " self.transform = transform\n",
- "\n",
- " # Pull down the artifact locally to load it into memory\n",
- " art = wandb_run.use_artifact(artifact_name_alias)\n",
- " path_at = art.download(root=self.local_target_dir)\n",
- "\n",
- " self.ref_df = pd.read_csv(os.path.join(self.local_target_dir, 'index.csv'))\n",
- " self.class_names = self.ref_df.iloc[:, 1].unique().tolist()\n",
- " self.idx_to_class = {k: v for k, v in enumerate(self.class_names)}\n",
- " self.class_to_idx = {v: k for k, v in enumerate(self.class_names)}\n",
- "\n",
- " def __len__(self):\n",
- " return len(self.ref_df)\n",
- "\n",
- " def __getitem__(self, idx):\n",
- " if torch.is_tensor(idx):\n",
- " idx = idx.tolist()\n",
- "\n",
- " img_path = self.ref_df.iloc[idx, 0]\n",
- "\n",
- " image = io.imread(img_path)\n",
- " label = self.ref_df.iloc[idx, 1]\n",
- " label = torch.tensor(self.class_to_idx[label], dtype=torch.long)\n",
- "\n",
- " if self.transform:\n",
- " image = self.transform(image)\n",
- "\n",
- " return image, label\n",
- "\n",
- "\n",
- "class NatureDatasetModule(pl.LightningDataModule):\n",
- " def __init__(self,\n",
- " wandb_run,\n",
- " artifact_name_alias: str = \"Nature_100:latest\",\n",
- " local_target_dir: str = \"Nature_100:latest\",\n",
- " batch_size: int = 16,\n",
- " input_size: int = 224,\n",
- " seed: int = 42):\n",
- " super().__init__()\n",
- " self.wandb_run = wandb_run\n",
- " self.artifact_name_alias = artifact_name_alias\n",
- " self.local_target_dir = local_target_dir\n",
- " self.batch_size = batch_size\n",
- " self.input_size = input_size\n",
- " self.seed = seed\n",
- "\n",
- " def setup(self, stage=None):\n",
- " self.nature_dataset = NatureDataset(wandb_run=self.wandb_run,\n",
- " artifact_name_alias=self.artifact_name_alias,\n",
- " local_target_dir=self.local_target_dir,\n",
- " transform=transforms.Compose([transforms.ToTensor(),\n",
- " transforms.CenterCrop(self.input_size),\n",
- " transforms.Normalize((0.485, 0.456, 0.406),\n",
- " (0.229, 0.224, 0.225))]))\n",
- "\n",
- " nature_length = len(self.nature_dataset)\n",
- " train_size = math.floor(0.8 * nature_length)\n",
- " val_size = math.floor(0.2 * nature_length)\n",
- " self.nature_train, self.nature_val = random_split(self.nature_dataset,\n",
- " [train_size, val_size],\n",
- " generator=torch.Generator().manual_seed(self.seed))\n",
- " return self\n",
- "\n",
- " def train_dataloader(self):\n",
- " return DataLoader(self.nature_train, batch_size=self.batch_size)\n",
- "\n",
- " def val_dataloader(self):\n",
- " return DataLoader(self.nature_val, batch_size=self.batch_size)\n",
- "\n",
- " def predict_dataloader(self):\n",
- " pass\n",
- "\n",
- " def teardown(self, stage: str):\n",
- " pass"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
-        "## How Logging in your PyTorch `LightningModule` works:\n",
-        "When you train the model using `Trainer`, ensure you have a `WandbLogger` instantiated and passed in as the `logger`. \n",
-        " \n",
-        "```\n",
-        "wandb_logger = WandbLogger(project=\"my_project\", entity=\"machine-learning\") \n",
-        "trainer = Trainer(logger=wandb_logger) \n",
-        "```\n",
-        "\n",
-        "\n",
-        "You can always use `wandb.log` as normal throughout the module. When the `WandbLogger` is used, `self.log` will also log metrics to W&B. \n",
-        "- To access the current run from within the `LightningModule`, use `self.logger.experiment`, which is a `wandb.Run` object"
- ]
- },
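-    {
-     "cell_type": "markdown",
-     "metadata": {},
-     "source": [
-        "For instance, a minimal sketch (the module and metric names here are illustrative, not part of this notebook) of reaching the underlying `wandb.Run` from inside a `LightningModule`:\n",
-        "```\n",
-        "class MyLitModule(LightningModule):\n",
-        "    def training_step(self, batch, batch_idx):\n",
-        "        loss = self.loss(self(batch[0]), batch[1])\n",
-        "        # self.logger.experiment is the wandb.Run when a WandbLogger is attached\n",
-        "        self.logger.experiment.log({'train/custom_metric': loss.item()})\n",
-        "        return loss\n",
-        "```"
-     ]
-    },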
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Some helper functions"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Some helper functions\n",
- "\n",
- "def set_parameter_requires_grad(model, feature_extracting):\n",
- " if feature_extracting:\n",
- " for param in model.parameters():\n",
- " param.requires_grad = False\n",
- "\n",
- "def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True):\n",
- " # Initialize these variables which will be set in this if statement. Each of these\n",
- " # variables is model specific.\n",
- " model_ft = None\n",
- " input_size = 0\n",
- "\n",
- " if model_name == \"resnet\":\n",
- " \"\"\" Resnet18\n",
- " \"\"\"\n",
-        "        # `pretrained=` is deprecated/removed in recent torchvision; use the weights API\n",
-        "        model_ft = models.resnet18(weights=models.ResNet18_Weights.DEFAULT if use_pretrained else None)\n",
- " set_parameter_requires_grad(model_ft, feature_extract)\n",
- " num_ftrs = model_ft.fc.in_features\n",
- " model_ft.fc = torch.nn.Linear(num_ftrs, num_classes)\n",
- " input_size = 224\n",
- "\n",
- " elif model_name == \"squeezenet\":\n",
- " \"\"\" Squeezenet\n",
- " \"\"\"\n",
-        "        model_ft = models.squeezenet1_0(weights=models.SqueezeNet1_0_Weights.DEFAULT if use_pretrained else None)\n",
- " set_parameter_requires_grad(model_ft, feature_extract)\n",
- " model_ft.classifier[1] = torch.nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1))\n",
- " model_ft.num_classes = num_classes\n",
- " input_size = 224\n",
- "\n",
- " elif model_name == \"densenet\":\n",
- " \"\"\" Densenet\n",
- " \"\"\"\n",
-        "        model_ft = models.densenet121(weights=models.DenseNet121_Weights.DEFAULT if use_pretrained else None)\n",
- " set_parameter_requires_grad(model_ft, feature_extract)\n",
- " num_ftrs = model_ft.classifier.in_features\n",
- " model_ft.classifier = torch.nn.Linear(num_ftrs, num_classes)\n",
- " input_size = 224\n",
- "\n",
-        "    else:\n",
-        "        # raising is safer than exit() inside a notebook kernel\n",
-        "        raise ValueError(f'Invalid model name: {model_name}')\n",
- "\n",
- " return model_ft, input_size"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Writing the `LightningModule`"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "import torch\n",
- "from torch.nn import Linear, CrossEntropyLoss, functional as F\n",
- "from torch.optim import Adam\n",
- "from torchmetrics.functional import accuracy\n",
- "from lightning.pytorch import LightningModule\n",
- "from torchvision import models\n",
- "\n",
- "class NatureLitModule(LightningModule):\n",
- " def __init__(self,\n",
- " model_name,\n",
- " num_classes=10,\n",
- " feature_extract=True,\n",
- " lr=0.01):\n",
- " '''method used to define our model parameters'''\n",
- " super().__init__()\n",
- "\n",
- " self.model_name = model_name\n",
- " self.num_classes = num_classes\n",
- " self.feature_extract = feature_extract\n",
-        "        self.model, self.input_size = initialize_model(model_name=self.model_name,\n",
-        "                                                        num_classes=self.num_classes,\n",
-        "                                                        feature_extract=self.feature_extract)\n",
- "\n",
- " # loss\n",
- " self.loss = CrossEntropyLoss()\n",
- "\n",
- " # optimizer parameters\n",
- " self.lr = lr\n",
- "\n",
- " # save hyper-parameters to self.hparams (auto-logged by W&B)\n",
- " self.save_hyperparameters()\n",
- "\n",
-        "        # Record the gradients of all the layers (requires an active wandb run)\n",
-        "        wandb.watch(self.model)\n",
- "\n",
- " def forward(self, x):\n",
- " '''method used for inference input -> output'''\n",
- " x = self.model(x)\n",
- "\n",
- " return x\n",
- "\n",
- " def training_step(self, batch, batch_idx):\n",
- " '''needs to return a loss from a single batch'''\n",
- " preds, y, loss, acc = self._get_preds_loss_accuracy(batch)\n",
- "\n",
- " # Log loss and metric\n",
- " self.log('train/loss', loss)\n",
- " self.log('train/accuracy', acc)\n",
- "\n",
- " return loss\n",
- "\n",
- " def validation_step(self, batch, batch_idx):\n",
- " '''used for logging metrics'''\n",
- " preds, y, loss, acc = self._get_preds_loss_accuracy(batch)\n",
- "\n",
- " # Log loss and metric\n",
- " self.log('validation/loss', loss)\n",
- " self.log('validation/accuracy', acc)\n",
- "\n",
- " # Let's return preds to use it in a custom callback\n",
- " return preds, y\n",
- "\n",
-        "    # NOTE: `validation_epoch_end` was removed in Lightning 2.x; per-epoch\n",
-        "    # aggregation of validation outputs is handled by the\n",
-        "    # LogPredictionsCallback defined later in this notebook.\n",
- "\n",
- " def test_step(self, batch, batch_idx):\n",
- " '''used for logging metrics'''\n",
- " preds, y, loss, acc = self._get_preds_loss_accuracy(batch)\n",
- "\n",
- " # Log loss and metric\n",
- " self.log('test/loss', loss)\n",
- " self.log('test/accuracy', acc)\n",
- "\n",
- " def configure_optimizers(self):\n",
- " '''defines model optimizer'''\n",
- " return Adam(self.parameters(), lr=self.lr)\n",
- "\n",
- "\n",
- " def _get_preds_loss_accuracy(self, batch):\n",
- " '''convenience function since train/valid/test steps are similar'''\n",
- " x, y = batch\n",
- " logits = self(x)\n",
- " preds = torch.argmax(logits, dim=1)\n",
- " loss = self.loss(logits, y)\n",
-        "        acc = accuracy(preds, y, task=\"multiclass\", num_classes=self.num_classes)\n",
- " return preds, y, loss, acc"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Instrument Callbacks to log additional things at certain points in your code"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "from lightning.pytorch.callbacks import Callback\n",
- "\n",
- "class LogPredictionsCallback(Callback):\n",
- "\n",
- " def __init__(self):\n",
- " super().__init__()\n",
- "\n",
- " \n",
- " def on_validation_epoch_start(self, trainer, pl_module):\n",
- " self.batch_dfs = []\n",
- " self.image_list = []\n",
- " self.val_table = wandb.Table(columns=['image', 'ground_truth', 'prediction'])\n",
- "\n",
- " \n",
-        "    def on_validation_batch_end(\n",
-        "        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):\n",
- " \"\"\"Called when the validation batch ends.\"\"\"\n",
- "\n",
- " # Append validation predictions and ground truth to log in confusion matrix\n",
-        "        x, y = batch\n",
-        "        preds, y = outputs\n",
-        "        # Move tensors to the CPU before converting to numpy (needed on GPU runs)\n",
-        "        preds, y = preds.cpu(), y.cpu()\n",
-        "        self.batch_dfs.append(pd.DataFrame({\"Ground Truth\": y.numpy(), \"Predictions\": preds.numpy()}))\n",
-        "\n",
-        "        # Add wandb.Image to a table to log at the end of validation\n",
-        "        x = x.cpu().numpy().transpose(0, 2, 3, 1)\n",
-        "        for x_i, y_i, y_pred in zip(x, y, preds):\n",
-        "            self.image_list.append(wandb.Image(x_i, caption=f'Ground Truth: {y_i.item()} - Prediction: {y_pred.item()}'))\n",
-        "            self.val_table.add_data(wandb.Image(x_i), y_i.item(), y_pred.item())\n",
- " \n",
- " \n",
- " def on_validation_epoch_end(self, trainer, pl_module):\n",
- " # Collect statistics for whole validation set and log\n",
- " class_names = trainer.datamodule.nature_dataset.class_names\n",
- " val_df = pd.concat(self.batch_dfs)\n",
- " wandb.log({\"validation_table\": self.val_table,\n",
- " \"images_over_time\": self.image_list,\n",
- " \"validation_conf_matrix\": wandb.plot.confusion_matrix(y_true = val_df[\"Ground Truth\"].tolist(), \n",
- " preds=val_df[\"Predictions\"].tolist(), \n",
- " class_names=class_names)}, step=trainer.global_step)\n",
- "\n",
- " del self.batch_dfs\n",
- " del self.val_table\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## 🏋️ Main Training Loop"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "from lightning.pytorch.callbacks import ModelCheckpoint\n",
- "from lightning.pytorch.loggers import WandbLogger\n",
- "from lightning.pytorch import Trainer\n",
- "\n",
- "wandb.init(project=PROJECT_NAME,\n",
- " entity=ENTITY,\n",
- " job_type='training',\n",
- " config={\n",
- " \"model_name\": \"squeezenet\",\n",
- " \"batch_size\": 16\n",
- " })\n",
- "\n",
- "wandb_logger = WandbLogger(log_model='all', checkpoint_name=f'nature-{wandb.run.id}') \n",
- "\n",
- "log_predictions_callback = LogPredictionsCallback()\n",
- "checkpoint_callback = ModelCheckpoint(every_n_epochs=1)\n",
- "\n",
- "model = NatureLitModule(model_name=wandb.config['model_name']) # Access hyperparameters downstream to instantiate models/datasets\n",
- "\n",
- "nature_module = NatureDatasetModule(wandb_run = wandb_logger.experiment,\n",
- " artifact_name_alias = \"Nature_100:latest\",\n",
- " local_target_dir = \"Nature_100:latest\",\n",
- " batch_size=wandb.config['batch_size'],\n",
- " input_size=model.input_size)\n",
- "nature_module.setup()\n",
- "\n",
- "trainer = Trainer(logger=wandb_logger, # W&B integration\n",
- " callbacks=[log_predictions_callback, checkpoint_callback],\n",
- " max_epochs=5,\n",
- " log_every_n_steps=5) \n",
- "trainer.fit(model, datamodule=nature_module)\n",
- "\n",
- "wandb.finish()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Syncing with W&B Offline\n",
-        "If network communication is lost during training, you can always sync progress afterwards with `wandb sync`.\n",
-        "\n",
-        "The W&B SDK caches all logged data in a local `wandb` directory; calling `wandb sync` pushes your local state to the web app. "
- ]
- },
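-    {
-     "cell_type": "markdown",
-     "metadata": {},
-     "source": [
-        "As a sketch, an offline run can be synced later from the shell (the script name and run directory here are illustrative):\n",
-        "```\n",
-        "# train with logging cached locally only\n",
-        "WANDB_MODE=offline python train.py\n",
-        "\n",
-        "# later, when the network is back, push the cached run to the server\n",
-        "wandb sync wandb/offline-run-20230101_120000-abc123\n",
-        "```"
-     ]
-    },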
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Retrieve a model checkpoint artifact and resume training\n",
- "- Artifacts make it easy to track state of your training remotely and then resume training from a checkpoint"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "#@title Enter which checkpoint you want to resume training from:\n",
- "\n",
- "# FORM VARIABLES\n",
- "ARTIFACT_NAME_ALIAS = \"nature-zb4swpn6:v4\" #@param {type:\"string\"}"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "wandb.init(project=PROJECT_NAME,\n",
- " entity=ENTITY,\n",
- " job_type='resume_training')\n",
- "\n",
- "# Retrieve model checkpoint artifact and restore previous hyperparameters\n",
- "model_chkpt_art = wandb.use_artifact(f'{ENTITY}/{PROJECT_NAME}/{ARTIFACT_NAME_ALIAS}')\n",
- "model_chkpt_art.download() # Can change download directory by adding `root`, defaults to \"./artifacts\"\n",
- "logging_run = model_chkpt_art.logged_by()\n",
-        "wandb.config.update(logging_run.config)  # update (rather than reassign) so values are synced to the run\n",
- "\n",
- "# Can create a new artifact name or continue logging to the old one\n",
- "artifact_name = ARTIFACT_NAME_ALIAS.split(\":\")[0]\n",
- "wandb_logger = WandbLogger(log_model='all', checkpoint_name=artifact_name) \n",
- "\n",
- "log_predictions_callback = LogPredictionsCallback()\n",
- "checkpoint_callback = ModelCheckpoint(every_n_epochs=1)\n",
- "\n",
- "model = NatureLitModule(model_name=wandb.config['model_name']) # Access hyperparameters downstream to instantiate models/datasets\n",
- "\n",
- "nature_module = NatureDatasetModule(wandb_run = wandb_logger.experiment,\n",
- " artifact_name_alias = \"Nature_100:latest\",\n",
- " local_target_dir = \"Nature_100:latest\",\n",
- " batch_size=wandb.config['batch_size'],\n",
- " input_size=model.input_size)\n",
- "nature_module.setup()\n",
- "\n",
-        "\n",
-        "trainer = Trainer(logger=wandb_logger,  # W&B integration\n",
-        "                  callbacks=[log_predictions_callback, checkpoint_callback],\n",
-        "                  max_epochs=10,\n",
-        "                  log_every_n_steps=5) \n",
-        "# `resume_from_checkpoint` was removed in Lightning 2.x; pass `ckpt_path` to `fit` instead\n",
-        "trainer.fit(model, datamodule=nature_module,\n",
-        "            ckpt_path=f'./artifacts/{ARTIFACT_NAME_ALIAS}/model.ckpt')\n",
- "\n",
- "wandb.finish()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Model Registry \n",
-        "After logging checkpoints across multiple runs during experimentation, it is time to hand off the best checkpoint to the next stage of the workflow (e.g. testing, deployment).\n",
- "\n",
- "The model registry offers a centralized place to house the best checkpoints for all your model tasks. Any `model` artifact you log can be \"linked\" to a Registered Model. Here are the steps to start using the model registry for more organized model management:\n",
-        "1. Access your team's model registry by going to the team page and selecting `Model Registry`\n",
- "![model registry](https://drive.google.com/uc?export=view&id=1ZtJwBsFWPTm4Sg5w8vHhRpvDSeQPwsKw)\n",
- "\n",
- "2. Create a new Registered Model. \n",
- "![model registry](https://drive.google.com/uc?export=view&id=1RuayTZHNE0LJCxt1t0l6-2zjwiV4aDXe)\n",
- "\n",
- "3. Go to the artifacts tab of the project that holds all your model checkpoints\n",
- "![model registry](https://drive.google.com/uc?export=view&id=1r_jlhhtcU3as8VwQ-4oAntd8YtTwElFB)\n",
- "\n",
-        "4. Click \"Link to Registry\" for the model artifact version you want. (Alternatively, you can [link a model via the API](https://docs.wandb.ai/guides/models) with `wandb.run.link_artifact`)\n",
- "\n",
-        "**A note on linking:** Linking a model checkpoint is akin to \"bookmarking\" it. Each time you link a new model artifact to a Registered Model, the Registered Model's version increments. This separates the model development side of the workflow from the model deployment/consumption side: the globally understood version/alias of a model stays insulated from the experimental versions generated in R&D, so a Registered Model's version increments with each newly \"bookmarked\" model rather than with every logged checkpoint. \n",
- "\n",
- "\n",
- "### Create a Centralized Hub for all your models\n",
-        "- Add a model card, tags, and Slack notifications to your Registered Model\n",
- "- Change aliases to reflect when models move through different phases\n",
- "- Embed the model registry in reports for model documentation and regression reports. See this report as an [example](https://api.wandb.ai/links/wandb-smle/r82bj9at)\n",
- "![model registry](https://drive.google.com/uc?export=view&id=1lKPgaw-Ak4WK_91aBMcLvUMJL6pDQpgO)\n"
- ]
- }
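-    ,
-    {
-     "cell_type": "markdown",
-     "metadata": {},
-     "source": [
-        "As a sketch of linking via the API (the registered-model path below is illustrative; substitute your own):\n",
-        "```\n",
-        "run = wandb.init(project=PROJECT_NAME, entity=ENTITY, job_type='link-model')\n",
-        "art = run.use_artifact('nature-zb4swpn6:v4', type='model')\n",
-        "run.link_artifact(art, 'model-registry/Nature Classifier')\n",
-        "run.finish()\n",
-        "```"
-     ]
-    }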
- ],
- "metadata": {
- "accelerator": "GPU",
- "colab": {
- "include_colab_link": true,
- "provenance": [],
- "toc_visible": true
- },
- "kernelspec": {
- "display_name": "Python 3",
- "name": "python3"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 0
-}
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
\ No newline at end of file