diff --git a/colabs/mosaicml/MosaicML_Composer_and_wandb.ipynb b/colabs/mosaicml/MosaicML_Composer_and_wandb.ipynb
index 6b427d89..9a36ca45 100644
--- a/colabs/mosaicml/MosaicML_Composer_and_wandb.ipynb
+++ b/colabs/mosaicml/MosaicML_Composer_and_wandb.ipynb
@@ -1,456 +1,506 @@
{
- "cells": [
- {
- "cell_type": "markdown",
- "id": "0196e78a",
- "metadata": {},
- "source": [
- "\n",
- ""
- ]
- },
- {
- "cell_type": "markdown",
- "id": "362c2ed6",
- "metadata": {},
- "source": [
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "# Running fast with MosaicML Composer and Weight and Biases"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "c740179e",
- "metadata": {},
- "source": [
- "[MosaicML Composer](https://docs.mosaicml.com) is a library for training neural networks better, faster, and cheaper. It contains many state-of-the-art methods for accelerating neural network training and improving generalization, along with an optional Trainer API that makes composing many different enhancements easy.\n",
- "\n",
- "Coupled with [Weights & Biases integration](https://docs.mosaicml.com/en/v0.5.0/trainer/logging.html), you can quickly train and monitor models for full traceability and reproducibility with only 2 extra lines of code:\n",
- "\n",
- "```python\n",
- "from composer.loggers import WandBLogger\n",
- "wandb_logger = WandBLogger()\n",
- "```"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "0eca27e4",
- "metadata": {},
- "source": [
- "W&B integration with Composer can automatically:\n",
- "* log your configuration parameters\n",
- "* log your losses and metrics\n",
- "* log gradients and parameter distributions\n",
- "* log your model\n",
- "* keep track of your code\n",
- "* log your system metrics (GPU, CPU, memory, temperature, etc)"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "5640b1a0",
- "metadata": {},
- "source": [
- "### 🛠️ Installation and set-up\n",
- "\n",
- "We need to install the following libraries:\n",
- "* [mosaicml-composer](https://docs.mosaicml.com/en/v0.5.0/getting_started/installation.html) to set up and train our models\n",
- "* [wandb](https://docs.wandb.ai/) to instrument our training"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "764b0904",
- "metadata": {},
- "outputs": [],
- "source": [
- "!pip install -Uq wandb mosaicml fastcore"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "deb25a49",
- "metadata": {},
- "source": [
- "## Getting Started with Composer 🔥"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "e713fe81",
- "metadata": {},
- "source": [
- "Composer gives you access to a set of functions to speedup your models and infuse them with state of the art methods. For instance, you can insert [BlurPool](https://docs.mosaicml.com/en/latest/method_cards/blurpool.html) into your CNN by calling `CF.apply_blurpool(model)` into your PyTorch model. Take a look at all the [functional](https://docs.mosaicml.com/en/latest/functional_api.html) methods available."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "2c9413bd",
- "metadata": {},
- "outputs": [],
- "source": [
- "import logging\n",
- "from composer import functional as CF\n",
- "import torchvision.models as models\n",
- "\n",
- "logging.basicConfig(level=logging.INFO)\n",
- "model = models.resnet50()\n",
- "\n",
- "# replace some layers with blurpool\n",
- "CF.apply_blurpool(model);\n",
- "# replace some layers with squeeze-excite\n",
- "CF.apply_squeeze_excite(model, latent_channels=64, min_channels=128);"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "0eb3e57b",
- "metadata": {},
- "source": [
- "> 💡 you can use this upgraded model with your favourite PyTorch training or... "
- ]
- },
- {
- "cell_type": "markdown",
- "id": "02c31ab7",
- "metadata": {},
- "source": [
- "## Use the `Trainer` class with Weights and Biases 🏋️♀️\n",
- "\n",
- "W&B integration with MosaicML-Composer is built into the `Trainer` and can be configured to add extra functionalities through `WandBLogger`:\n",
- "\n",
- "* logging of Artifacts: Use `log_artifacts=True` to log model checkpoints as `wandb.Artifacts`. You can setup how often by passing an int value to `log_artifacts_every_n_batches` (default = 100)\n",
- "* you can also pass any parameter that you would pass to `wandb.init` in `init_params` as a dictionary. For example, you could pass `init_params = {\"project\":\"try_mosaicml\", \"name\":\"benchmark\", \"entity\":\"user_name\"}`.\n",
- "\n",
- "For more details refer to [Logger documentation](https://docs.mosaicml.com/en/latest/api_reference/composer.loggers.wandb_logger.html#composer.loggers.wandb_logger.WandBLogger) and [Wandb docs](https://docs.wandb.ai)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "1e12ce9d",
- "metadata": {},
- "outputs": [],
- "source": [
- "EPOCHS = 5\n",
- "BS = 32"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "4dbc1493",
- "metadata": {},
- "outputs": [],
- "source": [
- "import wandb\n",
- "\n",
- "from torch.utils.data import DataLoader\n",
- "from torchvision import datasets\n",
- "\n",
- "import torchvision.transforms as T\n",
- "\n",
- "from composer import Callback, State, Logger, Trainer\n",
- "from composer.models import MNIST_Classifier\n",
- "from composer.loggers import WandBLogger, TQDMLogger\n",
- "from composer.callbacks import SpeedMonitor, LRMonitor"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "ef7be365",
- "metadata": {},
- "source": [
- "let's grab a copy of MNIST from `torchvision`"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "b298a861",
- "metadata": {},
- "outputs": [],
- "source": [
- "train_dataset = datasets.MNIST(\".\", train=True, download=True, transform=T.ToTensor())\n",
- "eval_dataset = datasets.MNIST(\".\", train=False, download=True, transform=T.ToTensor())\n",
- "\n",
- "train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=BS, num_workers=2)\n",
- "eval_dataloader = DataLoader(eval_dataset, batch_size=2*BS, num_workers=2)"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "b798a9ed",
- "metadata": {},
- "source": [
- "we can import a simple ConvNet model to try"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "6498cf78",
- "metadata": {},
- "outputs": [],
- "source": [
- "model = MNIST_Classifier()"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "8e9daaa5",
- "metadata": {},
- "source": [
- "### 📊 Tracking the experiment\n",
- "> we define the `wandb.init` params here"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "73bccc9f",
- "metadata": {},
- "outputs": [],
- "source": [
- "# config params to log\n",
- "config = {\"epochs\":EPOCHS, \n",
- " \"batch_size\":BS,\n",
- " \"model_name\":\"MNIST_Classifier\"}\n",
- "\n",
- "# these will get passed to wandb.init(**init_params)\n",
- "init_params = {\"project\":\"mosaic_ml\", \n",
- " \"name\":\"mnist_baseline\",\n",
- " \"config\":config}\n",
- "\n",
- "# setup of the logger \n",
- "wandb_logger = WandBLogger(init_params=init_params)\n",
- "\n",
- "# we also add progressbar logging\n",
- "progress_logger = TQDMLogger()\n",
- "\n",
- "loggers = [progress_logger, wandb_logger]"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "0dc1f7f6",
- "metadata": {},
- "source": [
- "we are able to tweak what are we logging using `Callbacks` into the `Trainer` class."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "9b470f1a",
- "metadata": {},
- "outputs": [],
- "source": [
- "callbacks = [LRMonitor(), # Logs the learning rate\n",
- " SpeedMonitor(), # Logs the training throughput\n",
- " ]"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "28920cc2",
- "metadata": {},
- "source": [
- "we include callbacks that measure the model throughput (and the learning rate) and logs them to Weights & Biases. [Callbacks](https://docs.mosaicml.com/en/latest/trainer/callbacks.html) control what is being logged, whereas loggers specify where the information is being saved. For more information on loggers, see [Logging](https://docs.mosaicml.com/en/latest/trainer/logging.html)."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "d1cd982a",
- "metadata": {},
- "outputs": [],
- "source": [
- "trainer = Trainer(\n",
- " model=model,\n",
- " train_dataloader=train_dataloader,\n",
- " eval_dataloader=eval_dataloader,\n",
- " max_duration=f\"{EPOCHS}ep\",\n",
- " loggers=loggers,\n",
- " callbacks=callbacks,\n",
- " device=\"gpu\", # to train on GPU\n",
- " precision=\"amp\", # use mixed precision training, nice speed bump\n",
- "\n",
- ")"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "0888b0c5",
- "metadata": {},
- "source": [
- "once we are ready to train we call `fit`"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "2ca3468a",
- "metadata": {},
- "outputs": [],
- "source": [
- "trainer.fit()"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "598495ee",
- "metadata": {},
- "source": [
- "## ⚙️ Advanced: Using callbacks to log sample predictions\n",
- "\n",
- "> Composer is extensible through its callback system.\n",
- "\n",
- "We create a custom callback to automatically log sample predictions during validation."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "00401b44",
- "metadata": {},
- "outputs": [],
- "source": [
- "class LogPredictions(Callback):\n",
- " def __init__(self, num_samples=100):\n",
- " super().__init__()\n",
- " self.num_samples = num_samples\n",
- " self.data = []\n",
- " \n",
- " def eval_batch_end(self, state: State, logger: Logger):\n",
- " \"\"\"Compute predictions per batch and stores them on self.data\"\"\"\n",
- " if state.timer.epoch == state.max_duration: # on last val epoch\n",
- " if len(self.data) < self.num_samples:\n",
- " n = self.num_samples\n",
- " x, y = state.batch_pair\n",
- " outputs = state.outputs.argmax(-1)\n",
- " data = [[wandb.Image(x_i), y_i, y_pred] for x_i, y_i, y_pred in list(zip(x[:n], y[:n], outputs[:n]))]\n",
- " self.data += data\n",
- " \n",
- " def eval_end(self, state: State, logger: Logger):\n",
- " \"Create a wandb.Table and logs it\"\n",
- " columns = ['image', 'ground truth', 'prediction']\n",
- " table = wandb.Table(columns=columns, data=self.data[:self.num_samples])\n",
- " wandb.log({'predictions_table':table}, step=int(state.timer.batch))"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "6161475c",
- "metadata": {},
- "source": [
- "we add `LogPredictions` to the other callbacks"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "94d39bd5",
- "metadata": {},
- "outputs": [],
- "source": [
- "callbacks.append(LogPredictions())"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "8ea986a8",
- "metadata": {},
- "outputs": [],
- "source": [
- "trainer = Trainer(\n",
- " model=model,\n",
- " train_dataloader=train_dataloader,\n",
- " eval_dataloader=eval_dataloader,\n",
- " max_duration=f\"{EPOCHS}ep\",\n",
- " loggers=loggers,\n",
- " callbacks=callbacks,\n",
- " device=\"gpu\", # to train on GPU\n",
- " precision=\"amp\", # use mixed precision training, nice speed bump\n",
- "\n",
- ")"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "a4deb712",
- "metadata": {},
- "source": [
- "Once we're ready to train, we just call the `fit` method."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "c8b00679",
- "metadata": {},
- "outputs": [],
- "source": [
- "trainer.fit()"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "2c2e89f6",
- "metadata": {},
- "source": [
- "We can monitor losses, metrics, gradients, parameters and sample predictions as the model trains."
- ]
- },
- {
- "cell_type": "markdown",
- "id": "f4889215",
- "metadata": {},
- "source": [
- "![composer.png](https://i.imgur.com/VFZLOB3.png?1)"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "97f57640",
- "metadata": {},
- "source": [
- "## 📚 Resources\n",
- "\n",
- "* We are excited to showcase this early support of [MosaicML-Composer](https://docs.mosaicml.com/en/latest/index.html) go ahead and try this new state of the art framework."
- ]
- },
- {
- "cell_type": "markdown",
- "id": "c6722706",
- "metadata": {},
- "source": [
- "## ❓ Questions about W&B\n",
- "\n",
- "If you have any questions about using W&B to track your model performance and predictions, please reach out to the [wandb community](https://community.wandb.ai)."
- ]
- }
- ],
- "metadata": {
- "accelerator": "GPU",
- "colab": {
- "include_colab_link": true,
- "provenance": [],
- "toc_visible": true
- },
- "kernelspec": {
- "display_name": "Python 3",
- "name": "python3"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 5
-}
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "0196e78a",
+ "metadata": {
+ "id": "0196e78a"
+ },
+ "source": [
+ "\n",
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "362c2ed6",
+ "metadata": {
+ "id": "362c2ed6"
+ },
+ "source": [
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "# Running fast with MosaicML Composer and Weight and Biases"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "c740179e",
+ "metadata": {
+ "id": "c740179e"
+ },
+ "source": [
+ "[MosaicML Composer](https://docs.mosaicml.com) is a library for training neural networks better, faster, and cheaper. It contains many state-of-the-art methods for accelerating neural network training and improving generalization, along with an optional Trainer API that makes composing many different enhancements easy.\n",
+ "\n",
+ "Coupled with [Weights & Biases integration](https://docs.wandb.ai/guides/integrations/composer), you can quickly train and monitor models for full traceability and reproducibility with only 2 extra lines of code:\n",
+ "\n",
+ "```python\n",
+ "from composer import Trainer\n",
+ "from composer.loggers import WandBLogger\n",
+ "\n",
+ "wandb_logger = WandBLogger(init_params=init_params)\n",
+ "trainer = Trainer(..., logger=wandb_logger)\n",
+ "```"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "0eca27e4",
+ "metadata": {
+ "id": "0eca27e4"
+ },
+ "source": [
+ "W&B integration with Composer can automatically:\n",
+ "* log your configuration parameters\n",
+ "* log your losses and metrics\n",
+ "* log gradients and parameter distributions\n",
+ "* log your model\n",
+ "* keep track of your code\n",
+ "* log your system metrics (GPU, CPU, memory, temperature, etc)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "5640b1a0",
+ "metadata": {
+ "id": "5640b1a0"
+ },
+ "source": [
+ "### 🛠️ Installation and set-up\n",
+ "\n",
+ "We need to install the following libraries:\n",
+ "* [mosaicml-composer](https://docs.mosaicml.com/en/v0.5.0/getting_started/installation.html) to set up and train our models\n",
+ "* [wandb](https://docs.wandb.ai/) to instrument our training"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "764b0904",
+ "metadata": {
+ "id": "764b0904"
+ },
+ "outputs": [],
+ "source": [
+ "!pip install -Uq wandb mosaicml"
+ ]
+ },
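+ {
+ "cell_type": "markdown",
+ "source": [
+ "If this machine isn't authenticated with Weights & Biases yet, we can log in now: `wandb.login()` prompts for an API key the first time and caches the credentials afterwards."
+ ],
+ "metadata": {
+ "id": "wandb-login-md"
+ },
+ "id": "wandb-login-md"
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import wandb\n",
+ "\n",
+ "# prompts for an API key on first use, then caches the credentials\n",
+ "wandb.login()"
+ ],
+ "metadata": {
+ "id": "wandb-login-code"
+ },
+ "id": "wandb-login-code",
+ "execution_count": null,
+ "outputs": []
+ },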
+ {
+ "cell_type": "markdown",
+ "id": "02c31ab7",
+ "metadata": {
+ "id": "02c31ab7"
+ },
+ "source": [
+ "## Use the Composer `Trainer` class with Weights and Biases 🏋️♀️\n",
+ "\n",
+ "W&B integration with MosaicML-Composer is built into the `Trainer` and can be configured to add extra functionalities through `WandBLogger`:\n",
+ "\n",
+ "* logging of Artifacts: Use `log_artifacts=True` to log model checkpoints as `wandb.Artifacts`. You can setup how often by passing an int value to `log_artifacts_every_n_batches` (default = 100)\n",
+ "* you can also pass any parameter that you would pass to `wandb.init` in `init_params` as a dictionary. For example, you could pass `init_params = {\"project\":\"try_mosaicml\", \"name\":\"benchmark\", \"entity\":\"user_name\"}`.\n",
+ "\n",
+ "For more details refer to [Logger documentation](https://docs.mosaicml.com/en/latest/api_reference/composer.loggers.wandb_logger.html#composer.loggers.wandb_logger.WandBLogger) and [Wandb docs](https://docs.wandb.ai)"
+ ]
+ },
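+ {
+ "cell_type": "markdown",
+ "source": [
+ "Note that checkpoint Artifacts only show up if the `Trainer` is actually saving checkpoints. A minimal sketch, assuming Composer's `save_folder`/`save_interval` Trainer arguments (the folder name is just an example) and reusing the objects we build below:\n",
+ "\n",
+ "```python\n",
+ "trainer = Trainer(\n",
+ "    model=model,\n",
+ "    train_dataloader=train_dataloader,\n",
+ "    max_duration=\"1ep\",\n",
+ "    save_folder=\"checkpoints\",  # hypothetical local folder for checkpoints\n",
+ "    save_interval=\"1ep\",        # save once per epoch\n",
+ "    loggers=[WandBLogger(log_artifacts=True)],\n",
+ ")\n",
+ "```"
+ ],
+ "metadata": {
+ "id": "ckpt-artifacts-md"
+ },
+ "id": "ckpt-artifacts-md"
+ },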
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "1e12ce9d",
+ "metadata": {
+ "id": "1e12ce9d"
+ },
+ "outputs": [],
+ "source": [
+ "EPOCHS = 5\n",
+ "BS = 32"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "4dbc1493",
+ "metadata": {
+ "id": "4dbc1493"
+ },
+ "outputs": [],
+ "source": [
+ "import wandb\n",
+ "\n",
+ "from torchvision import datasets, transforms\n",
+ "from torch.utils.data import DataLoader\n",
+ "\n",
+ "from composer import Callback, State, Logger, Trainer\n",
+ "from composer.models import mnist_model\n",
+ "from composer.loggers import WandBLogger\n",
+ "from composer.callbacks import SpeedMonitor, LRMonitor\n",
+ "from composer.algorithms import LabelSmoothing, CutMix, ChannelsLast"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "ef7be365",
+ "metadata": {
+ "id": "ef7be365"
+ },
+ "source": [
+ "let's grab a copy of MNIST from `torchvision`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "b298a861",
+ "metadata": {
+ "id": "b298a861"
+ },
+ "outputs": [],
+ "source": [
+ "transform = transforms.Compose([transforms.ToTensor()])\n",
+ "dataset = datasets.MNIST(\"data\", train=True, download=True, transform=transform)\n",
+ "train_dataloader = DataLoader(dataset, batch_size=128)"
+ ]
+ },
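+ {
+ "cell_type": "markdown",
+ "source": [
+ "As a quick sanity check (a small sketch, nothing Composer-specific), we can peek at one batch and confirm the tensor shapes."
+ ],
+ "metadata": {
+ "id": "peek-batch-md"
+ },
+ "id": "peek-batch-md"
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# grab a single batch from the dataloader\n",
+ "x, y = next(iter(train_dataloader))\n",
+ "x.shape, y.shape  # expecting (torch.Size([BS, 1, 28, 28]), torch.Size([BS]))"
+ ],
+ "metadata": {
+ "id": "peek-batch-code"
+ },
+ "id": "peek-batch-code",
+ "execution_count": null,
+ "outputs": []
+ },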
+ {
+ "cell_type": "markdown",
+ "id": "b798a9ed",
+ "metadata": {
+ "id": "b798a9ed"
+ },
+ "source": [
+ "we can import a simple ConvNet model to try"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "6498cf78",
+ "metadata": {
+ "id": "6498cf78"
+ },
+ "outputs": [],
+ "source": [
+ "model = mnist_model(num_classes=10)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "8e9daaa5",
+ "metadata": {
+ "id": "8e9daaa5"
+ },
+ "source": [
+ "### 📊 Tracking the experiment\n",
+ "> we define the `wandb.init` params here"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "73bccc9f",
+ "metadata": {
+ "id": "73bccc9f"
+ },
+ "outputs": [],
+ "source": [
+ "# config params to log\n",
+ "config = {\"epochs\":EPOCHS,\n",
+ " \"batch_size\":BS,\n",
+ " \"model_name\":\"MNIST_Classifier\"}\n",
+ "\n",
+ "# these will get passed to wandb.init(**init_params)\n",
+ "wandb_init_kwargs = {\"config\":config}\n",
+ "\n",
+ "# setup of the logger\n",
+ "wandb_logger = WandBLogger(project=\"mnist-composer\",\n",
+ " log_artifacts=True,\n",
+ " init_kwargs=wandb_init_kwargs)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "0dc1f7f6",
+ "metadata": {
+ "id": "0dc1f7f6"
+ },
+ "source": [
+ "we are able to tweak what are we logging using `Callbacks` into the `Trainer` class."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 72,
+ "id": "9b470f1a",
+ "metadata": {
+ "id": "9b470f1a"
+ },
+ "outputs": [],
+ "source": [
+ "callbacks = [LRMonitor(), # Logs the learning rate\n",
+ " SpeedMonitor(), # Logs the training throughput\n",
+ " ]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "28920cc2",
+ "metadata": {
+ "id": "28920cc2"
+ },
+ "source": [
+ "we include callbacks that measure the model throughput (and the learning rate) and logs them to Weights & Biases. [Callbacks](https://docs.mosaicml.com/en/latest/trainer/callbacks.html) control what is being logged, whereas loggers specify where the information is being saved. For more information on loggers, see [Logging](https://docs.mosaicml.com/en/latest/trainer/logging.html)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "trainer = Trainer(\n",
+ " model=mnist_model(num_classes=10),\n",
+ " train_dataloader=train_dataloader,\n",
+ " max_duration=\"2ep\",\n",
+ " loggers=[wandb_logger], # Pass your WandbLogger\n",
+ " callbacks=callbacks,\n",
+ " algorithms=[\n",
+ " LabelSmoothing(smoothing=0.1),\n",
+ " CutMix(alpha=1.0),\n",
+ " ChannelsLast(),\n",
+ " ]\n",
+ ")"
+ ],
+ "metadata": {
+ "id": "kmUiJQZoGU5D"
+ },
+ "id": "kmUiJQZoGU5D",
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "id": "0888b0c5",
+ "metadata": {
+ "id": "0888b0c5"
+ },
+ "source": [
+ "once we are ready to train we call `fit`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "2ca3468a",
+ "metadata": {
+ "id": "2ca3468a"
+ },
+ "outputs": [],
+ "source": [
+ "trainer.fit()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "We close the Trainer to properly finish all callbacks and loggers"
+ ],
+ "metadata": {
+ "id": "_4U7TodlIgPy"
+ },
+ "id": "_4U7TodlIgPy"
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "trainer.close()"
+ ],
+ "metadata": {
+ "id": "dTWX_MFZIfSF"
+ },
+ "id": "dTWX_MFZIfSF",
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "id": "598495ee",
+ "metadata": {
+ "id": "598495ee"
+ },
+ "source": [
+ "## ⚙️ Advanced: Using callbacks to log sample predictions\n",
+ "\n",
+ "> Composer is extensible through its callback system.\n",
+ "\n",
+ "We create a custom callback to automatically log sample predictions during validation."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 73,
+ "id": "00401b44",
+ "metadata": {
+ "id": "00401b44"
+ },
+ "outputs": [],
+ "source": [
+ "class LogPredictions(Callback):\n",
+ " def __init__(self, num_samples=100):\n",
+ " super().__init__()\n",
+ " self.num_samples = num_samples\n",
+ " self.data = []\n",
+ "\n",
+ " def batch_end(self, state: State, logger: Logger):\n",
+ " \"\"\"Compute predictions per batch and stores them on self.data\"\"\"\n",
+ " if len(self.data) < self.num_samples:\n",
+ " n = self.num_samples\n",
+ " x, y = state.batch\n",
+ " outputs = state.outputs.argmax(-1)\n",
+ " data = [[wandb.Image(x_i), y_i, y_pred] for x_i, y_i, y_pred in list(zip(x[:n], y[:n], outputs[:n]))]\n",
+ " self.data += data\n",
+ "\n",
+ " def epoch_end(self, state: State, logger: Logger):\n",
+ " \"Create a wandb.Table and logs it\"\n",
+ " columns = ['image', 'ground truth', 'prediction']\n",
+ " table = wandb.Table(columns=columns, data=self.data[:self.num_samples])\n",
+ " wandb.log({'predictions_table':table}, step=int(state.timestamp.batch))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "6161475c",
+ "metadata": {
+ "id": "6161475c"
+ },
+ "source": [
+ "we add `LogPredictions` to the other callbacks"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 74,
+ "id": "94d39bd5",
+ "metadata": {
+ "id": "94d39bd5"
+ },
+ "outputs": [],
+ "source": [
+ "callbacks.append(LogPredictions())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "8ea986a8",
+ "metadata": {
+ "id": "8ea986a8"
+ },
+ "outputs": [],
+ "source": [
+ "trainer = Trainer(\n",
+ " model=mnist_model(num_classes=10),\n",
+ " train_dataloader=train_dataloader,\n",
+ " max_duration=\"2ep\",\n",
+ " loggers=[wandb_logger], # Pass your WandbLogger\n",
+ " callbacks=callbacks,\n",
+ " algorithms=[\n",
+ " LabelSmoothing(smoothing=0.1),\n",
+ " CutMix(alpha=1.0),\n",
+ " ChannelsLast(),\n",
+ " ]\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "a4deb712",
+ "metadata": {
+ "id": "a4deb712"
+ },
+ "source": [
+ "Once we're ready to train, we just call the `fit` method."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c8b00679",
+ "metadata": {
+ "id": "c8b00679"
+ },
+ "outputs": [],
+ "source": [
+ "trainer.fit()\n",
+ "trainer.close()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "2c2e89f6",
+ "metadata": {
+ "id": "2c2e89f6"
+ },
+ "source": [
+ "We can monitor losses, metrics, gradients, parameters and sample predictions as the model trains."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "f4889215",
+ "metadata": {
+ "id": "f4889215"
+ },
+ "source": [
+ "![composer.png](https://i.imgur.com/VFZLOB3.png?1)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "97f57640",
+ "metadata": {
+ "id": "97f57640"
+ },
+ "source": [
+ "## 📚 Resources\n",
+ "\n",
+ "* We are excited to showcase this early support of [MosaicML-Composer](https://docs.mosaicml.com/en/latest/index.html) go ahead and try this new state of the art framework."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "c6722706",
+ "metadata": {
+ "id": "c6722706"
+ },
+ "source": [
+ "## ❓ Questions about W&B\n",
+ "\n",
+ "If you have any questions about using W&B to track your model performance and predictions, please reach out to the [wandb community](https://community.wandb.ai)."
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "provenance": [],
+ "toc_visible": true
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
\ No newline at end of file