diff --git a/colabs/pytorch-lightning/Wandb_End_to_End_with_PyTorch_Lightning.ipynb b/colabs/pytorch-lightning/Wandb_End_to_End_with_PyTorch_Lightning.ipynb index 16d7a3b6..32a5df77 100644 --- a/colabs/pytorch-lightning/Wandb_End_to_End_with_PyTorch_Lightning.ipynb +++ b/colabs/pytorch-lightning/Wandb_End_to_End_with_PyTorch_Lightning.ipynb @@ -1,681 +1,737 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "\"Open\n", - "" - ] + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "Ws0nlkuGOpDy" + }, + "source": [ + "\"Open\n", + "" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "aU1q92uCOpD1" + }, + "source": [ + "\"Weights\n", + "\n", + "\n", + "\n", + "# W&B Tutorial with Pytorch Lightning" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "mHCKzEzSOpD1" + }, + "source": [ + "## 🛠️ Install `wandb` and `pytorch-lightning`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "RnyGWvwDOpD1" + }, + "outputs": [], + "source": [ + "!pip install -q lightning wandb torchvision" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "G5wpAEAoOpD2" + }, + "source": [ + "## Login to W&B either through Python or CLI\n", + "If you are using the public W&B cloud, you don't need to specify the `WANDB_HOST`.\n", + "\n", + "You can set environment variables `WANDB_API_KEY` and `WANDB_HOST` and pass them in as:\n", + "```\n", + "import os\n", + "import wandb\n", + "\n", + "wandb.login(host=os.getenv(\"WANDB_HOST\"), key=os.getenv(\"WANDB_API_KEY\"))\n", + "```\n", + "You can also login via the CLI with:\n", + "```\n", + "wandb login --host \n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "YjXGiBLQOpD2" + }, + "outputs": [], + "source": [ + "import wandb" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "r5HflvclOpD2" + }, + "outputs": [], + "source": [ + "wandb.login()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "zNuxn8BjOpD2" + }, + "source": [ + "## ⚱ Logging the Raw Training Data as an Artifact" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "X9mkr942OpD2" + }, + "outputs": [], + "source": [ + "#@title Enter your W&B project and entity\n", + "\n", + "# FORM VARIABLES\n", + "PROJECT_NAME = \"pytorch-lightning-e2e\" #@param {type:\"string\"}\n", + "ENTITY = \"wandb\"#@param {type:\"string\"}\n", + "\n", + "# set SIZE to \"TINY\", \"SMALL\", \"MEDIUM\", or \"LARGE\"\n", + "# to select one of these three datasets\n", + "# TINY dataset: 100 images, 30MB\n", + "# SMALL dataset: 1000 images, 312MB\n", + "# MEDIUM dataset: 5000 images, 1.5GB\n", + "# LARGE dataset: 12,000 images, 3.6GB\n", + "\n", + "SIZE = \"TINY\"\n", + "\n", + "if SIZE == \"TINY\":\n", + " src_url = \"https://storage.googleapis.com/wandb_datasets/nature_100.zip\"\n", + " src_zip = \"nature_100.zip\"\n", + " DATA_SRC = \"nature_100\"\n", + " IMAGES_PER_LABEL = 10\n", + " BALANCED_SPLITS = {\"train\" : 8, \"val\" : 1, \"test\": 1}\n", + "elif SIZE == \"SMALL\":\n", + " src_url = \"https://storage.googleapis.com/wandb_datasets/nature_1K.zip\"\n", + " src_zip = \"nature_1K.zip\"\n", + " DATA_SRC = \"nature_1K\"\n", + " IMAGES_PER_LABEL = 100\n", + " BALANCED_SPLITS = {\"train\" : 80, \"val\" : 10, \"test\": 10}\n", + "elif SIZE == \"MEDIUM\":\n", + " src_url = \"https://storage.googleapis.com/wandb_datasets/nature_12K.zip\"\n", + " src_zip = \"nature_12K.zip\"\n", + " DATA_SRC = \"inaturalist_12K/train\" # (technically a subset of only 10K images)\n", + " IMAGES_PER_LABEL = 500\n", + " BALANCED_SPLITS = {\"train\" : 400, \"val\" : 50, \"test\": 50}\n", + "elif SIZE == \"LARGE\":\n", + " src_url = \"https://storage.googleapis.com/wandb_datasets/nature_12K.zip\"\n", + " src_zip = \"nature_12K.zip\"\n", + " DATA_SRC = \"inaturalist_12K/train\" # (technically a subset of only 10K images)\n", + " IMAGES_PER_LABEL = 1000\n", + " BALANCED_SPLITS = {\"train\" : 800, \"val\" : 100, \"test\": 100}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "o8nApIhdOpD3" + }, + "outputs": [], + "source": [ + "%%capture\n", + "!curl -SL $src_url > $src_zip\n", + "!unzip $src_zip" + ] + }, + { + "cell_type": "code", + "source": [ + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")" + ], + "metadata": { + "id": "ALmdQ7wISLaA" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "XQ_Kwsg9OpD3" + }, + "outputs": [], + "source": [ + "import wandb\n", + "import pandas as pd\n", + "import os\n", + "\n", + "with wandb.init(project=PROJECT_NAME, entity=ENTITY, job_type='log_datasets') as run:\n", + " img_paths = []\n", + " for root, dirs, files in os.walk('nature_100', topdown=False):\n", + " for name in files:\n", + " img_path = os.path.join(root, name)\n", + " label = img_path.split('/')[1]\n", + " img_paths.append([img_path, label])\n", + "\n", + " index_df = pd.DataFrame(columns=['image_path', 'label'], data=img_paths)\n", + " index_df.to_csv('index.csv', index=False)\n", + "\n", + " train_art = wandb.Artifact(name='Nature_100', type='raw_images', description='nature image dataset with 10 classes, 10 images per class')\n", + " train_art.add_dir('nature_100')\n", + "\n", + " # Also adding a csv indicating the labels of each image\n", + " train_art.add_file('index.csv')\n", + " wandb.log_artifact(train_art)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "1eJpOk_VOpD3" + }, + "source": [ + "## Using Artifacts in Pytorch Lightning `DataModule`'s and Pytorch `Dataset`'s\n", + "- Makes it easy to interopt your DataLoaders with new versions of datasets\n", + "- Just indicate the `name:alias` as an argument to your `Dataset` or `DataModule`\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Z2g9JRrwOpD3" + }, + "outputs": [], + "source": [ + "from torchvision import transforms\n", + "import lightning.pytorch as pl\n", + "import torch\n", + "from torch.utils.data import Dataset, DataLoader, random_split\n", + "from skimage import io, transform\n", + "from torchvision import transforms, utils, models\n", + "import math\n", + "\n", + "class NatureDataset(Dataset):\n", + " def __init__(self,\n", + " wandb_run,\n", + " artifact_name_alias=\"Nature_100:latest\",\n", + " local_target_dir=\"Nature_100:latest\",\n", + " transform=None):\n", + " self.local_target_dir = local_target_dir\n", + " self.transform = transform\n", + "\n", + " # Pull down the artifact locally to load it into memory\n", + " art = wandb_run.use_artifact(artifact_name_alias)\n", + " path_at = art.download(root=self.local_target_dir)\n", + "\n", + " self.ref_df = pd.read_csv(os.path.join(self.local_target_dir, 'index.csv'))\n", + " self.class_names = self.ref_df.iloc[:, 1].unique().tolist()\n", + " self.idx_to_class = {k: v for k, v in enumerate(self.class_names)}\n", + " self.class_to_idx = {v: k for k, v in enumerate(self.class_names)}\n", + "\n", + " def __len__(self):\n", + " return len(self.ref_df)\n", + "\n", + " def __getitem__(self, idx):\n", + " if torch.is_tensor(idx):\n", + " idx = idx.tolist()\n", + "\n", + " img_path = self.ref_df.iloc[idx, 0]\n", + "\n", + " image = io.imread(img_path)\n", + " label = self.ref_df.iloc[idx, 1]\n", + " label = torch.tensor(self.class_to_idx[label], dtype=torch.long)\n", + "\n", + " if self.transform:\n", + " image = self.transform(image)\n", + "\n", + " return image, label\n", + "\n", + "\n", + "class NatureDatasetModule(pl.LightningDataModule):\n", + " def __init__(self,\n", + " wandb_run,\n", + " artifact_name_alias: str = \"Nature_100:latest\",\n", + " local_target_dir: str = \"Nature_100:latest\",\n", + " batch_size: int = 16,\n", + " input_size: int = 224,\n", + " seed: int = 42):\n", + " super().__init__()\n", + " self.wandb_run = wandb_run\n", + " self.artifact_name_alias = artifact_name_alias\n", + " self.local_target_dir = local_target_dir\n", + " self.batch_size = batch_size\n", + " self.input_size = input_size\n", + " self.seed = seed\n", + "\n", + " def setup(self, stage=None):\n", + " self.nature_dataset = NatureDataset(wandb_run=self.wandb_run,\n", + " artifact_name_alias=self.artifact_name_alias,\n", + " local_target_dir=self.local_target_dir,\n", + " transform=transforms.Compose([transforms.ToTensor(),\n", + " transforms.CenterCrop(self.input_size),\n", + " transforms.Normalize((0.485, 0.456, 0.406),\n", + " (0.229, 0.224, 0.225))]))\n", + "\n", + " nature_length = len(self.nature_dataset)\n", + " train_size = math.floor(0.8 * nature_length)\n", + " val_size = math.floor(0.2 * nature_length)\n", + " self.nature_train, self.nature_val = random_split(self.nature_dataset,\n", + " [train_size, val_size],\n", + " generator=torch.Generator().manual_seed(self.seed))\n", + " return self\n", + "\n", + " def train_dataloader(self):\n", + " return DataLoader(self.nature_train, batch_size=self.batch_size)\n", + "\n", + " def val_dataloader(self):\n", + " return DataLoader(self.nature_val, batch_size=self.batch_size)\n", + "\n", + " def predict_dataloader(self):\n", + " pass\n", + "\n", + " def teardown(self, stage: str):\n", + " pass" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "NLhjdvF-OpD3" + }, + "source": [ + "##How Logging in your Pytorch `LightningModule`works:\n", + "When you train the model using `Trainer`, ensure you have a `WandbLogger` instantiated and passed in as a `logger`.\n", + "\n", + "```\n", + "wandb_logger = WandbLogger(project=\"my_project\", entity=\"machine-learning\")\n", + "trainer = Trainer(logger=wandb_logger)\n", + "```\n", + "\n", + "\n", + "You can always use `wandb.log` as normal throughout the module. When the `WandbLogger` is used, `self.log` will also log metrics to W&B.\n", + "- To access the current run from within the `LightningModule`, you can access `Trainer.logger.experiment`, which is a `wandb.Run` object" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Oi1NpQs7OpD4" + }, + "source": [ + "### Some helper functions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "HLgYji2oOpD4" + }, + "outputs": [], + "source": [ + "# Some helper functions\n", + "\n", + "def set_parameter_requires_grad(model, feature_extracting):\n", + " if feature_extracting:\n", + " for param in model.parameters():\n", + " param.requires_grad = False\n", + "\n", + "def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True):\n", + " # Initialize these variables which will be set in this if statement. Each of these\n", + " # variables is model specific.\n", + " model_ft = None\n", + " input_size = 0\n", + "\n", + " if model_name == \"resnet\":\n", + " \"\"\" Resnet18\n", + " \"\"\"\n", + " model_ft = models.resnet18(pretrained=use_pretrained)\n", + " set_parameter_requires_grad(model_ft, feature_extract)\n", + " num_ftrs = model_ft.fc.in_features\n", + " model_ft.fc = torch.nn.Linear(num_ftrs, num_classes)\n", + " input_size = 224\n", + "\n", + " elif model_name == \"squeezenet\":\n", + " \"\"\" Squeezenet\n", + " \"\"\"\n", + " model_ft = models.squeezenet1_0(pretrained=use_pretrained)\n", + " set_parameter_requires_grad(model_ft, feature_extract)\n", + " model_ft.classifier[1] = torch.nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1))\n", + " model_ft.num_classes = num_classes\n", + " input_size = 224\n", + "\n", + " elif model_name == \"densenet\":\n", + " \"\"\" Densenet\n", + " \"\"\"\n", + " model_ft = models.densenet121(pretrained=use_pretrained)\n", + " set_parameter_requires_grad(model_ft, feature_extract)\n", + " num_ftrs = model_ft.classifier.in_features\n", + " model_ft.classifier = torch.nn.Linear(num_ftrs, num_classes)\n", + " input_size = 224\n", + "\n", + " else:\n", + " print(\"Invalid model name, exiting...\")\n", + " exit()\n", + "\n", + " return model_ft, input_size" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "LLRK0S6oOpD4" + }, + "source": [ + "### Writing the `LightningModule`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "k0tTtK5zOpD4" + }, + "outputs": [], + "source": [ + "import torch\n", + "from torch.nn import Linear, CrossEntropyLoss, functional as F\n", + "from torch.optim import Adam\n", + "from torchmetrics.functional import accuracy\n", + "from lightning.pytorch import LightningModule\n", + "from torchvision import models\n", + "\n", + "class NatureLitModule(LightningModule):\n", + " def __init__(self,\n", + " model_name,\n", + " num_classes=10,\n", + " feature_extract=True,\n", + " lr=0.01):\n", + " '''method used to define our model parameters'''\n", + " super().__init__()\n", + "\n", + " self.model_name = model_name\n", + " self.num_classes = num_classes\n", + " self.feature_extract = feature_extract\n", + " self.model, self.input_size = initialize_model(model_name=self.model_name,\n", + " num_classes=self.num_classes,\n", + " feature_extract=True)\n", + "\n", + " # loss\n", + " self.loss = CrossEntropyLoss()\n", + "\n", + " # optimizer parameters\n", + " self.lr = lr\n", + "\n", + " # save hyper-parameters to self.hparams (auto-logged by W&B)\n", + " self.save_hyperparameters()\n", + "\n", + " # Record the gradients of all the layers\n", + " wandb.watch(self.model)\n", + "\n", + " def forward(self, x):\n", + " '''method used for inference input -> output'''\n", + " x = self.model(x)\n", + "\n", + " return x\n", + "\n", + " def training_step(self, batch, batch_idx):\n", + " '''needs to return a loss from a single batch'''\n", + " preds, y, loss, acc = self._get_preds_loss_accuracy(batch)\n", + "\n", + " # Log loss and metric\n", + " self.log('train/loss', loss)\n", + " self.log('train/accuracy', acc)\n", + "\n", + " return loss\n", + "\n", + " def validation_step(self, batch, batch_idx):\n", + " '''used for logging metrics'''\n", + " preds, y, loss, acc = self._get_preds_loss_accuracy(batch)\n", + "\n", + " # Log loss and metric\n", + " self.log('validation/loss', loss)\n", + " self.log('validation/accuracy', acc)\n", + "\n", + " # Let's return preds to use it in a custom callback\n", + " return preds, y\n", + "\n", + " def test_step(self, batch, batch_idx):\n", + " '''used for logging metrics'''\n", + " preds, y, loss, acc = self._get_preds_loss_accuracy(batch)\n", + "\n", + " # Log loss and metric\n", + " self.log('test/loss', loss)\n", + " self.log('test/accuracy', acc)\n", + "\n", + " def configure_optimizers(self):\n", + " '''defines model optimizer'''\n", + " return Adam(self.parameters(), lr=self.lr)\n", + "\n", + "\n", + " def _get_preds_loss_accuracy(self, batch):\n", + " '''convenience function since train/valid/test steps are similar'''\n", + " x, y = batch\n", + " logits = self(x)\n", + " preds = torch.argmax(logits, dim=1)\n", + " loss = self.loss(logits, y)\n", + " acc = accuracy(preds, y, task=\"multiclass\", num_classes=10)\n", + " return preds, y, loss, acc" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "bvZHZuZzOpD4" + }, + "source": [ + "### Instrument Callbacks to log additional things at certain points in your code" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "j9pHqaw1OpD5" + }, + "outputs": [], + "source": [ + "from lightning.pytorch.callbacks import Callback\n", + "\n", + "class LogPredictionsCallback(Callback):\n", + "\n", + " def __init__(self):\n", + " super().__init__()\n", + "\n", + "\n", + " def on_validation_epoch_start(self, trainer, pl_module):\n", + " self.batch_dfs = []\n", + " self.image_list = []\n", + " self.val_table = wandb.Table(columns=['image', 'ground_truth', 'prediction'])\n", + "\n", + "\n", + " def on_validation_batch_end(\n", + " self, trainer, pl_module, outputs, batch, batch_idx):\n", + " \"\"\"Called when the validation batch ends.\"\"\"\n", + "\n", + " # Append validation predictions and ground truth to log in confusion matrix\n", + " x, y = batch\n", + " preds, y = outputs\n", + " self.batch_dfs.append(pd.DataFrame({\"Ground Truth\": y.cpu().numpy(), \"Predictions\": preds.cpu().numpy()}))\n", + "\n", + " # Add wandb.Image to a table to log at the end of validation\n", + " x = x.cpu().numpy().transpose(0, 2, 3, 1)\n", + " for x_i, y_i, y_pred in list(zip(x, y, preds)):\n", + " self.image_list.append(wandb.Image(x_i, caption=f'Ground Truth: {y_i} - Prediction: {y_pred}'))\n", + " self.val_table.add_data(wandb.Image(x_i), y_i, y_pred)\n", + "\n", + "\n", + " def on_validation_epoch_end(self, trainer, pl_module):\n", + " # Collect statistics for whole validation set and log\n", + " class_names = trainer.datamodule.nature_dataset.class_names\n", + " val_df = pd.concat(self.batch_dfs)\n", + " wandb.log({\"validation_table\": self.val_table,\n", + " \"images_over_time\": self.image_list,\n", + " \"validation_conf_matrix\": wandb.plot.confusion_matrix(y_true = val_df[\"Ground Truth\"].tolist(),\n", + " preds=val_df[\"Predictions\"].tolist(),\n", + " class_names=class_names)}, step=trainer.global_step)\n", + "\n", + " del self.batch_dfs\n", + " del self.val_table\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jAiH-kb1OpD5" + }, + "source": [ + "## 🏋️‍ Main Training Loop" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "FrYTQ7uaOpD5" + }, + "outputs": [], + "source": [ + "from lightning.pytorch.callbacks import ModelCheckpoint\n", + "from lightning.pytorch.loggers import WandbLogger\n", + "from lightning.pytorch import Trainer\n", + "\n", + "wandb.init(project=PROJECT_NAME,\n", + " entity=ENTITY,\n", + " job_type='training',\n", + " config={\n", + " \"model_name\": \"squeezenet\",\n", + " \"batch_size\": 16\n", + " })\n", + "\n", + "wandb_logger = WandbLogger(log_model='all', checkpoint_name=f'nature-{wandb.run.id}')\n", + "\n", + "log_predictions_callback = LogPredictionsCallback()\n", + "checkpoint_callback = ModelCheckpoint(every_n_epochs=1)\n", + "\n", + "model = NatureLitModule(model_name=wandb.config['model_name']) # Access hyperparameters downstream to instantiate models/datasets\n", + "\n", + "nature_module = NatureDatasetModule(wandb_run = wandb_logger.experiment,\n", + " artifact_name_alias = \"Nature_100:latest\",\n", + " local_target_dir = \"Nature_100:latest\",\n", + " batch_size=wandb.config['batch_size'],\n", + " input_size=model.input_size)\n", + "nature_module.setup()\n", + "\n", + "trainer = Trainer(logger=wandb_logger, # W&B integration\n", + " callbacks=[log_predictions_callback, checkpoint_callback],\n", + " max_epochs=5,\n", + " log_every_n_steps=5)\n", + "trainer.fit(model, datamodule=nature_module)\n", + "\n", + "wandb.finish()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "uEII8J7UOpD5" + }, + "source": [ + "### Syncing with W&B Offline\n", + "If for some reason, network communication is lost during the course of training, you can always sync progress with `wandb sync`\n", + "\n", + "The W&B sdk caches all logged data in a local directory `wandb` and when you call `wandb sync`, this syncs the your local state with the web app." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "xi-EhHsxOpD5" + }, + "source": [ + "## Retrieve a model checkpoint artifact and resume training\n", + "- Artifacts make it easy to track state of your training remotely and then resume training from a checkpoint" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "jiE1Bk7fOpD5" + }, + "outputs": [], + "source": [ + "#@title Enter which checkpoint you want to resume training from:\n", + "\n", + "# FORM VARIABLES\n", + "ARTIFACT_NAME_ALIAS = \"nature-oyxk79m1:v4\" #@param {type:\"string\"}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "jZZXWRatOpD5" + }, + "outputs": [], + "source": [ + "wandb.init(project=PROJECT_NAME,\n", + " entity=ENTITY,\n", + " job_type='resume_training')\n", + "\n", + "# Retrieve model checkpoint artifact and restore previous hyperparameters\n", + "model_chkpt_art = wandb.use_artifact(f'{ENTITY}/{PROJECT_NAME}/{ARTIFACT_NAME_ALIAS}')\n", + "model_chkpt_art.download() # Can change download directory by adding `root`, defaults to \"./artifacts\"\n", + "logging_run = model_chkpt_art.logged_by()\n", + "wandb.config = logging_run.config\n", + "\n", + "# Can create a new artifact name or continue logging to the old one\n", + "artifact_name = ARTIFACT_NAME_ALIAS.split(\":\")[0]\n", + "wandb_logger = WandbLogger(log_model='all', checkpoint_name=artifact_name)\n", + "\n", + "log_predictions_callback = LogPredictionsCallback()\n", + "checkpoint_callback = ModelCheckpoint(every_n_epochs=1)\n", + "\n", + "model = NatureLitModule.load_from_checkpoint(f'./artifacts/{ARTIFACT_NAME_ALIAS}/model.ckpt')\n", + "\n", + "nature_module = NatureDatasetModule(wandb_run = wandb_logger.experiment,\n", + " artifact_name_alias = \"Nature_100:latest\",\n", + " local_target_dir = \"Nature_100:latest\",\n", + " batch_size=wandb.config['batch_size'],\n", + " input_size=model.input_size)\n", + "nature_module.setup()\n", + "\n", + "\n", + "\n", + "trainer = Trainer(logger=wandb_logger, # W&B integration\n", + " callbacks=[log_predictions_callback, checkpoint_callback],\n", + " max_epochs=10,\n", + " log_every_n_steps=5)\n", + "trainer.fit(model, datamodule=nature_module)\n", + "\n", + "wandb.finish()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "KTforR6EOpD5" + }, + "source": [ + "## Model Registry\n", + "After logging a bunch of checkpoints across multiple runs during experimentation, now comes time to hand-off the best checkpoint to the next stage of the workflow (e.g. testing, deployment).\n", + "\n", + "The model registry offers a centralized place to house the best checkpoints for all your model tasks. Any `model` artifact you log can be \"linked\" to a Registered Model. Here are the steps to start using the model registry for more organized model management:\n", + "1. Access your team's model registry by going the team page and selecting `Model Registry`\n", + "![model registry](https://drive.google.com/uc?export=view&id=1ZtJwBsFWPTm4Sg5w8vHhRpvDSeQPwsKw)\n", + "\n", + "2. Create a new Registered Model.\n", + "![model registry](https://drive.google.com/uc?export=view&id=1RuayTZHNE0LJCxt1t0l6-2zjwiV4aDXe)\n", + "\n", + "3. Go to the artifacts tab of the project that holds all your model checkpoints\n", + "![model registry](https://drive.google.com/uc?export=view&id=1r_jlhhtcU3as8VwQ-4oAntd8YtTwElFB)\n", + "\n", + "4. Click \"Link to Registry\" for the model artifact version you want. (Alternatively you can [link a model via api](https://docs.wandb.ai/guides/models) with `wandb.run.link_artifact`)\n", + "\n", + "**A note on linking:** The process of linking a model checkpoint is akin to \"bookmarking\" it. Each time you link a new model artifact to a Registered Model, this increments the version of the Registered Model. This helps delineate the model development side of the workflow from the model deployment/consumption side. The globally understood version/alias of a model should be unpolluted from all the experimental versions being generated in R&D and thus the versioning of a Registered Model increments according to new \"bookmarked\" models as opposed to model checkpoint logging.\n", + "\n", + "\n", + "### Create a Centralized Hub for all your models\n", + "- Add a model card, tags, slack notifactions to your Registered Model\n", + "- Change aliases to reflect when models move through different phases\n", + "- Embed the model registry in reports for model documentation and regression reports. See this report as an [example](https://api.wandb.ai/links/wandb-smle/r82bj9at)\n", + "![model registry](https://drive.google.com/uc?export=view&id=1lKPgaw-Ak4WK_91aBMcLvUMJL6pDQpgO)\n" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "provenance": [], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + } }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "\"Weights\n", - "\n", - "\n", - "\n", - "# W&B Tutorial with Pytorch Lightning" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 🛠️ Install `wandb` and `pytorch-lightning`" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "!pip install -q lightning wandb torchvision" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Login to W&B either through Python or CLI\n", - "If you are using the public W&B cloud, you don't need to specify the `WANDB_HOST`.\n", - "\n", - "You can set environment variables `WANDB_API_KEY` and `WANDB_HOST` and pass them in as:\n", - "```\n", - "import os\n", - "import wandb \n", - "\n", - "wandb.login(host=os.getenv(\"WANDB_HOST\"), key=os.getenv(\"WANDB_API_KEY\"))\n", - "```\n", - "You can also login via the CLI with: \n", - "```\n", - "wandb login --host \n", - "```" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import wandb" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "wandb.login()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## ⚱ Logging the Raw Training Data as an Artifact" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#@title Enter your W&B project and entity\n", - "\n", - "# FORM VARIABLES\n", - "PROJECT_NAME = \"pytorch-lightning-e2e\" #@param {type:\"string\"}\n", - "ENTITY = \"wandb\"#@param {type:\"string\"}\n", - "\n", - "# set SIZE to \"TINY\", \"SMALL\", \"MEDIUM\", or \"LARGE\"\n", - "# to select one of these three datasets\n", - "# TINY dataset: 100 images, 30MB\n", - "# SMALL dataset: 1000 images, 312MB\n", - "# MEDIUM dataset: 5000 images, 1.5GB\n", - "# LARGE dataset: 12,000 images, 3.6GB\n", - "\n", - "SIZE = \"TINY\"\n", - "\n", - "if SIZE == \"TINY\":\n", - " src_url = \"https://storage.googleapis.com/wandb_datasets/nature_100.zip\"\n", - " src_zip = \"nature_100.zip\"\n", - " DATA_SRC = \"nature_100\"\n", - " IMAGES_PER_LABEL = 10\n", - " BALANCED_SPLITS = {\"train\" : 8, \"val\" : 1, \"test\": 1}\n", - "elif SIZE == \"SMALL\":\n", - " src_url = \"https://storage.googleapis.com/wandb_datasets/nature_1K.zip\"\n", - " src_zip = \"nature_1K.zip\"\n", - " DATA_SRC = \"nature_1K\"\n", - " IMAGES_PER_LABEL = 100\n", - " BALANCED_SPLITS = {\"train\" : 80, \"val\" : 10, \"test\": 10}\n", - "elif SIZE == \"MEDIUM\":\n", - " src_url = \"https://storage.googleapis.com/wandb_datasets/nature_12K.zip\"\n", - " src_zip = \"nature_12K.zip\"\n", - " DATA_SRC = \"inaturalist_12K/train\" # (technically a subset of only 10K images)\n", - " IMAGES_PER_LABEL = 500\n", - " BALANCED_SPLITS = {\"train\" : 400, \"val\" : 50, \"test\": 50}\n", - "elif SIZE == \"LARGE\":\n", - " src_url = \"https://storage.googleapis.com/wandb_datasets/nature_12K.zip\"\n", - " src_zip = \"nature_12K.zip\"\n", - " DATA_SRC = \"inaturalist_12K/train\" # (technically a subset of only 10K images)\n", - " IMAGES_PER_LABEL = 1000\n", - " BALANCED_SPLITS = {\"train\" : 800, \"val\" : 100, \"test\": 100}" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "%%capture\n", - "!curl -SL $src_url > $src_zip\n", - "!unzip $src_zip" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import wandb\n", - "import pandas as pd\n", - "import os\n", - "\n", - "with wandb.init(project=PROJECT_NAME, entity=ENTITY, job_type='log_datasets') as run:\n", - " img_paths = []\n", - " for root, dirs, files in os.walk('nature_100', topdown=False):\n", - " for name in files:\n", - " img_path = os.path.join(root, name)\n", - " label = img_path.split('/')[1]\n", - " img_paths.append([img_path, label])\n", - "\n", - " index_df = pd.DataFrame(columns=['image_path', 'label'], data=img_paths)\n", - " index_df.to_csv('index.csv', index=False)\n", - "\n", - " train_art = wandb.Artifact(name='Nature_100', type='raw_images', description='nature image dataset with 10 classes, 10 images per class')\n", - " train_art.add_dir('nature_100')\n", - "\n", - " # Also adding a csv indicating the labels of each image\n", - " train_art.add_file('index.csv')\n", - " wandb.log_artifact(train_art)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Using Artifacts in Pytorch Lightning `DataModule`'s and Pytorch `Dataset`'s\n", - "- Makes it easy to interopt your DataLoaders with new versions of datasets\n", - "- Just indicate the `name:alias` as an argument to your `Dataset` or `DataModule`\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from torchvision import transforms\n", - "import lightning.pytorch as pl\n", - "import torch\n", - "from torch.utils.data import Dataset, DataLoader, random_split\n", - "from skimage import io, transform\n", - "from torchvision import transforms, utils, models\n", - "import math\n", - "\n", - "class NatureDataset(Dataset):\n", - " def __init__(self, \n", - " wandb_run, \n", - " artifact_name_alias=\"Nature_100:latest\", \n", - " local_target_dir=\"Nature_100:latest\", \n", - " transform=None):\n", - " self.local_target_dir = local_target_dir\n", - " self.transform = transform\n", - "\n", - " # Pull down the artifact locally to load it into memory\n", - " art = wandb_run.use_artifact(artifact_name_alias)\n", - " path_at = art.download(root=self.local_target_dir)\n", - "\n", - " self.ref_df = pd.read_csv(os.path.join(self.local_target_dir, 'index.csv'))\n", - " self.class_names = self.ref_df.iloc[:, 1].unique().tolist()\n", - " self.idx_to_class = {k: v for k, v in enumerate(self.class_names)}\n", - " self.class_to_idx = {v: k for k, v in enumerate(self.class_names)}\n", - "\n", - " def __len__(self):\n", - " return len(self.ref_df)\n", - "\n", - " def __getitem__(self, idx):\n", - " if torch.is_tensor(idx):\n", - " idx = idx.tolist()\n", - "\n", - " img_path = self.ref_df.iloc[idx, 0]\n", - "\n", - " image = io.imread(img_path)\n", - " label = self.ref_df.iloc[idx, 1]\n", - " label = torch.tensor(self.class_to_idx[label], dtype=torch.long)\n", - "\n", - " if self.transform:\n", - " image = self.transform(image)\n", - "\n", - " return image, label\n", - "\n", - "\n", - "class NatureDatasetModule(pl.LightningDataModule):\n", - " def __init__(self,\n", - " wandb_run,\n", - " artifact_name_alias: str = \"Nature_100:latest\",\n", - " local_target_dir: str = \"Nature_100:latest\",\n", - " batch_size: int = 16,\n", - " input_size: int = 224,\n", - " seed: int = 42):\n", - " super().__init__()\n", - " self.wandb_run = wandb_run\n", - " self.artifact_name_alias = artifact_name_alias\n", - " self.local_target_dir = local_target_dir\n", - " self.batch_size = batch_size\n", - " self.input_size = input_size\n", - " self.seed = seed\n", - "\n", - " def setup(self, stage=None):\n", - " self.nature_dataset = NatureDataset(wandb_run=self.wandb_run,\n", - " artifact_name_alias=self.artifact_name_alias,\n", - " local_target_dir=self.local_target_dir,\n", - " transform=transforms.Compose([transforms.ToTensor(),\n", - " transforms.CenterCrop(self.input_size),\n", - " transforms.Normalize((0.485, 0.456, 0.406),\n", - " (0.229, 0.224, 0.225))]))\n", - "\n", - " nature_length = len(self.nature_dataset)\n", - " train_size = math.floor(0.8 * nature_length)\n", - " val_size = math.floor(0.2 * nature_length)\n", - " self.nature_train, self.nature_val = random_split(self.nature_dataset,\n", - " [train_size, val_size],\n", - " generator=torch.Generator().manual_seed(self.seed))\n", - " return self\n", - "\n", - " def train_dataloader(self):\n", - " return DataLoader(self.nature_train, batch_size=self.batch_size)\n", - "\n", - " def val_dataloader(self):\n", - " return DataLoader(self.nature_val, batch_size=self.batch_size)\n", - "\n", - " def predict_dataloader(self):\n", - " pass\n", - "\n", - " def teardown(self, stage: str):\n", - " pass" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "##How Logging in your Pytorch `LightningModule`works:\n", - "When you train the model using `Trainer`, ensure you have a `WandbLogger` instantiated and passed in as a `logger`. \n", - " \n", - "```\n", - "wandb_logger = WandbLogger(project=\"my_project\", entity=\"machine-learning\") \n", - "trainer = Trainer(logger=wandb_logger) \n", - "```\n", - "\n", - "\n", - "You can always use `wandb.log` as normal throughout the module. When the `WandbLogger` is used, `self.log` will also log metrics to W&B. \n", - "- To access the current run from within the `LightningModule`, you can access `Trainer.logger.experiment`, which is a `wandb.Run` object" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Some helper functions" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Some helper functions\n", - "\n", - "def set_parameter_requires_grad(model, feature_extracting):\n", - " if feature_extracting:\n", - " for param in model.parameters():\n", - " param.requires_grad = False\n", - "\n", - "def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True):\n", - " # Initialize these variables which will be set in this if statement. Each of these\n", - " # variables is model specific.\n", - " model_ft = None\n", - " input_size = 0\n", - "\n", - " if model_name == \"resnet\":\n", - " \"\"\" Resnet18\n", - " \"\"\"\n", - " model_ft = models.resnet18(pretrained=use_pretrained)\n", - " set_parameter_requires_grad(model_ft, feature_extract)\n", - " num_ftrs = model_ft.fc.in_features\n", - " model_ft.fc = torch.nn.Linear(num_ftrs, num_classes)\n", - " input_size = 224\n", - "\n", - " elif model_name == \"squeezenet\":\n", - " \"\"\" Squeezenet\n", - " \"\"\"\n", - " model_ft = models.squeezenet1_0(pretrained=use_pretrained)\n", - " set_parameter_requires_grad(model_ft, feature_extract)\n", - " model_ft.classifier[1] = torch.nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1))\n", - " model_ft.num_classes = num_classes\n", - " input_size = 224\n", - "\n", - " elif model_name == \"densenet\":\n", - " \"\"\" Densenet\n", - " \"\"\"\n", - " model_ft = models.densenet121(pretrained=use_pretrained)\n", - " set_parameter_requires_grad(model_ft, feature_extract)\n", - " num_ftrs = model_ft.classifier.in_features\n", - " model_ft.classifier = torch.nn.Linear(num_ftrs, num_classes)\n", - " input_size = 224\n", - "\n", - " else:\n", - " print(\"Invalid model name, exiting...\")\n", - " exit()\n", - "\n", - " return model_ft, input_size" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Writing the `LightningModule`" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "from torch.nn import Linear, CrossEntropyLoss, functional as F\n", - "from torch.optim import Adam\n", - "from torchmetrics.functional import accuracy\n", - "from lightning.pytorch import LightningModule\n", - "from torchvision import models\n", - "\n", - "class NatureLitModule(LightningModule):\n", - " def __init__(self,\n", - " model_name,\n", - " num_classes=10,\n", - " feature_extract=True,\n", - " lr=0.01):\n", - " '''method used to define our model parameters'''\n", - " super().__init__()\n", - "\n", - " self.model_name = model_name\n", - " self.num_classes = num_classes\n", - " self.feature_extract = feature_extract\n", - " self.model, self.input_size = initialize_model(model_name=self.model_name,\n", - " num_classes=self.num_classes,\n", - " feature_extract=True)\n", - "\n", - " # loss\n", - " self.loss = CrossEntropyLoss()\n", - "\n", - " # optimizer parameters\n", - " self.lr = lr\n", - "\n", - " # save hyper-parameters to self.hparams (auto-logged by W&B)\n", - " self.save_hyperparameters()\n", - "\n", - " # Record the gradients of all the layers\n", - " wandb.watch(self.model)\n", - "\n", - " def forward(self, x):\n", - " '''method used for inference input -> output'''\n", - " x = self.model(x)\n", - "\n", - " return x\n", - "\n", - " def training_step(self, batch, batch_idx):\n", - " '''needs to return a loss from a single batch'''\n", - " preds, y, loss, acc = self._get_preds_loss_accuracy(batch)\n", - "\n", - " # Log loss and metric\n", - " self.log('train/loss', loss)\n", - " self.log('train/accuracy', acc)\n", - "\n", - " return loss\n", - "\n", - " def validation_step(self, batch, batch_idx):\n", - " '''used for logging metrics'''\n", - " preds, y, loss, acc = self._get_preds_loss_accuracy(batch)\n", - "\n", - " # Log loss and metric\n", - " self.log('validation/loss', loss)\n", - " self.log('validation/accuracy', acc)\n", - "\n", - " # Let's return preds to use it in a custom callback\n", - " return preds, y\n", - "\n", - " def validation_epoch_end(self, validation_step_outputs):\n", - " \"\"\"Called when the validation ends.\"\"\"\n", - " preds, y = validation_step_outputs\n", - " all_preds = torch.stack(preds)\n", - " all_y = torch.stack(y)\n", - "\n", - " def test_step(self, batch, batch_idx):\n", - " '''used for logging metrics'''\n", - " preds, y, loss, acc = self._get_preds_loss_accuracy(batch)\n", - "\n", - " # Log loss and metric\n", - " self.log('test/loss', loss)\n", - " self.log('test/accuracy', acc)\n", - "\n", - " def configure_optimizers(self):\n", - " '''defines model optimizer'''\n", - " return Adam(self.parameters(), lr=self.lr)\n", - "\n", - "\n", - " def _get_preds_loss_accuracy(self, batch):\n", - " '''convenience function since train/valid/test steps are similar'''\n", - " x, y = batch\n", - " logits = self(x)\n", - " preds = torch.argmax(logits, dim=1)\n", - " loss = self.loss(logits, y)\n", - " acc = accuracy(preds, y, task=\"multiclass\", num_classes=10)\n", - " return preds, y, loss, acc" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Instrument Callbacks to log additional things at certain points in your code" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from lightning.pytorch.callbacks import Callback\n", - "\n", - "class LogPredictionsCallback(Callback):\n", - "\n", - " def __init__(self):\n", - " super().__init__()\n", - "\n", - " \n", - " def on_validation_epoch_start(self, trainer, pl_module):\n", - " self.batch_dfs = []\n", - " self.image_list = []\n", - " self.val_table = wandb.Table(columns=['image', 'ground_truth', 'prediction'])\n", - "\n", - " \n", - " def on_validation_batch_end(\n", - " self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):\n", - " \"\"\"Called when the validation batch ends.\"\"\"\n", - "\n", - " # Append validation predictions and ground truth to log in confusion matrix\n", - " x, y = batch\n", - " preds, y = outputs\n", - " self.batch_dfs.append(pd.DataFrame({\"Ground Truth\": y.numpy(), \"Predictions\": preds.numpy()}))\n", - "\n", - " # Add wandb.Image to a table to log at the end of validation\n", - " x = x.numpy().transpose(0, 2, 3, 1)\n", - " for x_i, y_i, y_pred in list(zip(x, y, preds)):\n", - " self.image_list.append(wandb.Image(x_i, caption=f'Ground Truth: {y_i} - Prediction: {y_pred}'))\n", - " self.val_table.add_data(wandb.Image(x_i), y_i, y_pred)\n", - " \n", - " \n", - " def on_validation_epoch_end(self, trainer, pl_module):\n", - " # Collect statistics for whole validation set and log\n", - " class_names = trainer.datamodule.nature_dataset.class_names\n", - " val_df = pd.concat(self.batch_dfs)\n", - " wandb.log({\"validation_table\": self.val_table,\n", - " \"images_over_time\": self.image_list,\n", - " \"validation_conf_matrix\": wandb.plot.confusion_matrix(y_true = val_df[\"Ground Truth\"].tolist(), \n", - " preds=val_df[\"Predictions\"].tolist(), \n", - " class_names=class_names)}, step=trainer.global_step)\n", - "\n", - " del self.batch_dfs\n", - " del self.val_table\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 🏋️‍ Main Training Loop" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from lightning.pytorch.callbacks import ModelCheckpoint\n", - "from lightning.pytorch.loggers import WandbLogger\n", - "from lightning.pytorch import Trainer\n", - "\n", - "wandb.init(project=PROJECT_NAME,\n", - " entity=ENTITY,\n", - " job_type='training',\n", - " config={\n", - " \"model_name\": \"squeezenet\",\n", - " \"batch_size\": 16\n", - " })\n", - "\n", - "wandb_logger = WandbLogger(log_model='all', checkpoint_name=f'nature-{wandb.run.id}') \n", - "\n", - "log_predictions_callback = LogPredictionsCallback()\n", - "checkpoint_callback = ModelCheckpoint(every_n_epochs=1)\n", - "\n", - "model = NatureLitModule(model_name=wandb.config['model_name']) # Access hyperparameters downstream to instantiate models/datasets\n", - "\n", - "nature_module = NatureDatasetModule(wandb_run = wandb_logger.experiment,\n", - " artifact_name_alias = \"Nature_100:latest\",\n", - " local_target_dir = \"Nature_100:latest\",\n", - " batch_size=wandb.config['batch_size'],\n", - " input_size=model.input_size)\n", - "nature_module.setup()\n", - "\n", - "trainer = Trainer(logger=wandb_logger, # W&B integration\n", - " callbacks=[log_predictions_callback, checkpoint_callback],\n", - " max_epochs=5,\n", - " log_every_n_steps=5) \n", - "trainer.fit(model, datamodule=nature_module)\n", - "\n", - "wandb.finish()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Syncing with W&B Offline\n", - "If for some reason, network communication is lost during the course of training, you can always sync progress with `wandb sync`\n", - "\n", - "The W&B sdk caches all logged data in a local directory `wandb` and when you call `wandb sync`, this syncs the your local state with the web app. " - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Retrieve a model checkpoint artifact and resume training\n", - "- Artifacts make it easy to track state of your training remotely and then resume training from a checkpoint" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#@title Enter which checkpoint you want to resume training from:\n", - "\n", - "# FORM VARIABLES\n", - "ARTIFACT_NAME_ALIAS = \"nature-zb4swpn6:v4\" #@param {type:\"string\"}" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "wandb.init(project=PROJECT_NAME,\n", - " entity=ENTITY,\n", - " job_type='resume_training')\n", - "\n", - "# Retrieve model checkpoint artifact and restore previous hyperparameters\n", - "model_chkpt_art = wandb.use_artifact(f'{ENTITY}/{PROJECT_NAME}/{ARTIFACT_NAME_ALIAS}')\n", - "model_chkpt_art.download() # Can change download directory by adding `root`, defaults to \"./artifacts\"\n", - "logging_run = model_chkpt_art.logged_by()\n", - "wandb.config = logging_run.config\n", - "\n", - "# Can create a new artifact name or continue logging to the old one\n", - "artifact_name = ARTIFACT_NAME_ALIAS.split(\":\")[0]\n", - "wandb_logger = WandbLogger(log_model='all', checkpoint_name=artifact_name) \n", - "\n", - "log_predictions_callback = LogPredictionsCallback()\n", - "checkpoint_callback = ModelCheckpoint(every_n_epochs=1)\n", - "\n", - "model = NatureLitModule(model_name=wandb.config['model_name']) # Access hyperparameters downstream to instantiate models/datasets\n", - "\n", - "nature_module = NatureDatasetModule(wandb_run = wandb_logger.experiment,\n", - " artifact_name_alias = \"Nature_100:latest\",\n", - " local_target_dir = \"Nature_100:latest\",\n", - " batch_size=wandb.config['batch_size'],\n", - " input_size=model.input_size)\n", - "nature_module.setup()\n", - "\n", - "\n", - "\n", - "trainer = Trainer(logger=wandb_logger, # W&B integration\n", - " resume_from_checkpoint = f'./artifacts/{ARTIFACT_NAME_ALIAS}/model.ckpt',\n", - " callbacks=[log_predictions_callback, checkpoint_callback],\n", - " max_epochs=10,\n", - " log_every_n_steps=5) \n", - "trainer.fit(model, datamodule=nature_module)\n", - "\n", - "wandb.finish()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Model Registry \n", - "After logging a bunch of checkpoints across multiple runs during experimentation, now comes time to hand-off the best checkpoint to the next stage of the workflow (e.g. testing, deployment).\n", - "\n", - "The model registry offers a centralized place to house the best checkpoints for all your model tasks. Any `model` artifact you log can be \"linked\" to a Registered Model. Here are the steps to start using the model registry for more organized model management:\n", - "1. Access your team's model registry by going the team page and selecting `Model Registry`\n", - "![model registry](https://drive.google.com/uc?export=view&id=1ZtJwBsFWPTm4Sg5w8vHhRpvDSeQPwsKw)\n", - "\n", - "2. Create a new Registered Model. \n", - "![model registry](https://drive.google.com/uc?export=view&id=1RuayTZHNE0LJCxt1t0l6-2zjwiV4aDXe)\n", - "\n", - "3. Go to the artifacts tab of the project that holds all your model checkpoints\n", - "![model registry](https://drive.google.com/uc?export=view&id=1r_jlhhtcU3as8VwQ-4oAntd8YtTwElFB)\n", - "\n", - "4. Click \"Link to Registry\" for the model artifact version you want. (Alternatively you can [link a model via api](https://docs.wandb.ai/guides/models) with `wandb.run.link_artifact`)\n", - "\n", - "**A note on linking:** The process of linking a model checkpoint is akin to \"bookmarking\" it. Each time you link a new model artifact to a Registered Model, this increments the version of the Registered Model. This helps delineate the model development side of the workflow from the model deployment/consumption side. The globally understood version/alias of a model should be unpolluted from all the experimental versions being generated in R&D and thus the versioning of a Registered Model increments according to new \"bookmarked\" models as opposed to model checkpoint logging. \n", - "\n", - "\n", - "### Create a Centralized Hub for all your models\n", - "- Add a model card, tags, slack notifactions to your Registered Model\n", - "- Change aliases to reflect when models move through different phases\n", - "- Embed the model registry in reports for model documentation and regression reports. See this report as an [example](https://api.wandb.ai/links/wandb-smle/r82bj9at)\n", - "![model registry](https://drive.google.com/uc?export=view&id=1lKPgaw-Ak4WK_91aBMcLvUMJL6pDQpgO)\n" - ] - } - ], - "metadata": { - "accelerator": "GPU", - "colab": { - "include_colab_link": true, - "provenance": [], - "toc_visible": true - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file