From 019ef7396b910309b96d2426e2162db888b6ce95 Mon Sep 17 00:00:00 2001
From: Anish Shah
Date: Tue, 26 Dec 2023 13:51:40 -0500
Subject: [PATCH] Create Track_PyTorch_Lightning_with_Fabric_and_Wandb.ipynb

Add notebook for upcoming fabric logger
---
 ...orch_Lightning_with_Fabric_and_Wandb.ipynb | 794 ++++++++++++++++++
 1 file changed, 794 insertions(+)
 create mode 100644 colabs/pytorch-lightning/Track_PyTorch_Lightning_with_Fabric_and_Wandb.ipynb

diff --git a/colabs/pytorch-lightning/Track_PyTorch_Lightning_with_Fabric_and_Wandb.ipynb b/colabs/pytorch-lightning/Track_PyTorch_Lightning_with_Fabric_and_Wandb.ipynb
new file mode 100644
index 00000000..43f8170e
--- /dev/null
+++ b/colabs/pytorch-lightning/Track_PyTorch_Lightning_with_Fabric_and_Wandb.ipynb
@@ -0,0 +1,794 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "tXO1WCH-WacK"
+   },
+   "source": [
+    "<a href=\"https://colab.research.google.com/github/wandb/examples/blob/master/colabs/pytorch-lightning/Track_PyTorch_Lightning_with_Fabric_and_Wandb.ipynb\" target=\"_blank\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "hXQeOXR4WuAU"
+   },
+   "source": [
+    "<img src=\"https://wandb.me/logo-im-png\" width=\"400\" alt=\"Weights & Biases\" />\n",
+    "\n",
+    "# ⚡ Track PyTorch Lightning with Fabric and Wandb"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "rutMtofCW4VG"
+   },
+   "source": [
+    "<img src=\"https://wandb.me/mini-diagram\" width=\"600\" alt=\"Weights & Biases\" />"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "z5_TyWecW7j2"
+   },
+   "source": [
+    "At Weights & Biases, we love anything\n",
+    "that makes training deep learning models easier.\n",
+    "That's why we worked with the folks at PyTorch Lightning to\n",
+    "[integrate our experiment tracking tool](https://docs.wandb.com/library/integrations/lightning)\n",
+    "directly into the Fabric library of PyTorch Lightning.\n",
+    "\n",
+    "[PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/) is a lightweight wrapper for organizing your PyTorch code and easily adding advanced features such as distributed training and 16-bit precision.\n",
+    "It retains all the flexibility of PyTorch,\n",
+    "in case you need it,\n",
+    "but adds some useful abstractions\n",
+    "and builds in some best practices.\n",
+    "\n",
+    "[PyTorch Fabric](https://lightning.ai/docs/fabric/stable/) lets you scale PyTorch models across\n",
+    "distributed machines while\n",
+    "maintaining full control of your\n",
+    "training loop.\n",
+    "\n",
+    "## What this notebook covers:\n",
+    "\n",
+    "1. How to get basic metric logging with the `WandbLogger`\n",
+    "2. How to log media with W&B\n",
+    "\n",
+    "## The interactive dashboard in W&B will look like this:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "9Q-cZ3PA-w4m"
+   },
+   "outputs": [],
+   "source": [
+    "%%capture\n",
+    "# Remove when PR merged: until then, install Lightning from the branch that adds the Fabric WandbLogger\n",
+    "!pip install git+https://github.com/ash0ts/lightning.git@fabric-wandb\n",
+    "!pip install wandb"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "x0zNPAVIa70W"
+   },
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "\n",
+    "# Paste your W&B API key here, or leave it empty and authenticate via `wandb.login()` below\n",
+    "os.environ[\"WANDB_API_KEY\"] = \"\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "z90ZAP09BQyx"
+   },
+   "outputs": [],
+   "source": [
+    "import wandb\n",
+    "\n",
+    "wandb.login()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "yRlO9oLtVlic"
+   },
+   "outputs": [],
+   "source": [
+    "import lightning as L\n",
+    "import torch\n",
+    "import torchvision as tv\n",
+    "from lightning.fabric.loggers import WandbLogger"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "KoA6WKajd9e9"
+   },
+   "source": [
+    "## 💡 Tracking Experiments with WandbLogger\n",
+    "\n",
+    "PyTorch Lightning has a `WandbLogger` to easily log your experiments with Weights & Biases. In Fabric, just pass it to the `loggers` argument of `L.Fabric` to log to W&B. See the WandbLogger docs for all parameters. Note: to log metrics to a specific W&B team, pass the team name to the `entity` argument of `WandbLogger`.\n",
+    "\n",
+    "#### `lightning.fabric.loggers.WandbLogger()`\n",
+    "\n",
+    "| Functionality | Argument/Function | Notes |\n",
+    "| ------ | ------ | ------ |\n",
+    "| Logging models | `WandbLogger(..., log_model='all')` or `WandbLogger(..., log_model=True)` | Log all models if `log_model='all'`, or only at the end of training if `log_model=True` |\n",
+    "| Set custom run names | `WandbLogger(..., name='my_run_name')` | |\n",
+    "| Organize runs by project | `WandbLogger(..., project='my_project')` | |\n",
+    "| Log histograms of gradients and parameters | `WandbLogger.watch(model)` | `WandbLogger.watch(model, log='all')` to log parameter histograms |\n",
+    "| Log hyperparameters | `WandbLogger.log_hyperparams(params_dict)` | |\n",
+    "| Log custom objects (images, audio, video, molecules…) | Use `WandbLogger.log_text`, `WandbLogger.log_image`, `WandbLogger.log_table`, etc. | |"
+   ]
+  },
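+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "For reference, a fuller constructor call might look like the sketch below. It only illustrates the arguments from the table above; the project, run name, and team entity are placeholder values, and the cell is left commented out so this walkthrough creates a single logger in the next cell."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# A sketch only; the `name` and `entity` values below are hypothetical placeholders\n",
+    "# logger = WandbLogger(\n",
+    "#     project=\"my_project\",  # organize runs by project\n",
+    "#     name=\"my_run_name\",    # custom display name for this run\n",
+    "#     entity=\"my-team\",      # log to a specific W&B team\n",
+    "#     log_model=True,        # upload model checkpoints at the end of training\n",
+    "# )"
+   ]
+  },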
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "NMcmlO5RVn2M"
+   },
+   "outputs": [],
+   "source": [
+    "logger = WandbLogger(project=\"Cifar10_ptl_fabric\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "yXgF106teqTo"
+   },
+   "source": [
+    "## Log custom hyperparameters and configurations"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "-6BH5NhxVruQ"
+   },
+   "outputs": [],
+   "source": [
+    "lr = 0.001\n",
+    "batch_size = 16\n",
+    "num_epochs = 5\n",
+    "classes = ('plane', 'car', 'bird', 'cat',\n",
+    "           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')\n",
+    "log_images_after_n_batches = 200\n",
+    "\n",
+    "logger.log_hyperparams({\n",
+    "    \"lr\": lr,\n",
+    "    \"batch_size\": batch_size,\n",
+    "    \"num_epochs\": num_epochs,\n",
+    "    \"classes\": classes,\n",
+    "    \"log_images_after_n_batches\": log_images_after_n_batches\n",
+    "})"
+   ]
+  },
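+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "As a quick sanity check, the values logged above can be read back from the run's config. As elsewhere in this notebook, `logger.experiment` is the underlying `wandb` run object."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# The run's config holds the hyperparameters we just logged\n",
+    "print(dict(logger.experiment.config))"
+   ]
+  },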
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "VzXAr3kWgBJR"
+   },
+   "source": [
+    "## Save Data to Weights & Biases Artifacts\n",
+    "\n",
+    "Logging the dataset as an Artifact lets us audit our data and gives each experiment a direct lineage back to the exact dataset version it was trained on."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "2gMCexD5gM1q"
+   },
+   "outputs": [],
+   "source": [
+    "root_folder = \"data\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "5TT53dYtVzhj"
+   },
+   "outputs": [],
+   "source": [
+    "train_dataset = tv.datasets.CIFAR10(root_folder, download=True,\n",
+    "                                    train=True,\n",
+    "                                    transform=tv.transforms.ToTensor())\n",
+    "test_dataset = tv.datasets.CIFAR10(root_folder, download=True,\n",
+    "                                   train=False,\n",
+    "                                   transform=tv.transforms.ToTensor())"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "xitzA_KrgH_y"
+   },
+   "outputs": [],
+   "source": [
+    "data_folder = train_dataset.base_folder  # same as test_dataset.base_folder"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "5qWJMqUEgctV"
+   },
+   "outputs": [],
+   "source": [
+    "data_art = wandb.Artifact(name=\"cifar10\", type=\"dataset\")\n",
+    "data_art.add_dir(os.path.join(root_folder, data_folder))\n",
+    "logger.experiment.log_artifact(data_art)"
+   ]
+  },
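+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "A later run could consume this Artifact instead of downloading CIFAR-10 again. The sketch below is left commented out to avoid a second download here; the `cifar10:latest` alias refers to the artifact logged above."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# `use_artifact` records the dataset as an input of the current run\n",
+    "# (completing the lineage) before downloading its contents\n",
+    "# artifact = logger.experiment.use_artifact(\"cifar10:latest\")\n",
+    "# data_dir = artifact.download()"
+   ]
+  },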
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "pse_7xRehZTY"
+   },
+   "source": [
+    "## Configure our Model and Training"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "zurvz8nuVvsB"
+   },
+   "outputs": [],
+   "source": [
+    "# CIFAR-10 has 10 classes, so size the classifier head accordingly\n",
+    "model = tv.models.resnet18(num_classes=10)\n",
+    "optimizer = torch.optim.SGD(model.parameters(), lr=lr)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "oRDFGDNUvvX5"
+   },
+   "outputs": [],
+   "source": [
+    "class TableLoggingCallback:\n",
+    "    def __init__(self, wandb_logger):\n",
+    "        self.wandb_logger = wandb_logger\n",
+    "        self.table = wandb.Table(columns=[\"image\", \"prediction\", \"ground_truth\"])\n",
+    "\n",
+    "    def on_test_batch_end(self, images, predictions, ground_truths):\n",
+    "        for image, prediction, ground_truth in zip(images, predictions, ground_truths):\n",
+    "            self.table.add_data(wandb.Image(image), prediction, ground_truth)\n",
+    "\n",
+    "    def on_model_epoch_end(self):\n",
+    "        prediction_table = self.table\n",
+    "        self.wandb_logger.experiment.log({\"prediction_table\": prediction_table})  # You can directly access the run object via `experiment`\n",
+    "\n",
+    "        # We could also use\n",
+    "        # (1) wandb_logger.log_metrics()\n",
+    "        # (2) wandb_logger.log_table()\n",
+    "\n",
+    "        # Start a fresh table for the next epoch\n",
+    "        self.table = wandb.Table(columns=[\"image\", \"prediction\", \"ground_truth\"])"
+   ]
+  },
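+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Fabric's callback mechanism is minimal: `fabric.call(\"hook_name\", ...)` invokes a method of that name, with the given arguments, on every callback object registered via `L.Fabric(callbacks=[...])`, and skips callbacks that don't define it. A tiny self-contained sketch (the hook name here is made up for illustration):"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class PrintingCallback:\n",
+    "    def on_demo_hook(self, message):\n",
+    "        print(f\"on_demo_hook received: {message}\")\n",
+    "\n",
+    "demo_fabric = L.Fabric(callbacks=[PrintingCallback()])\n",
+    "demo_fabric.call(\"on_demo_hook\", message=\"hello from fabric.call\")"
+   ]
+  },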
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "46hLsuXthhlI"
+   },
+   "source": [
+    "Load our model, data, loggers, and callbacks into PyTorch Fabric"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "BBqRdWCX-YMd"
+   },
+   "outputs": [],
+   "source": [
+    "tlc = TableLoggingCallback(logger)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "PFByZcKq-tta"
+   },
+   "outputs": [],
+   "source": [
+    "fabric = L.Fabric(loggers=[logger], callbacks=[tlc])\n",
+    "fabric.launch()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "VFSmQIxXV4pX"
+   },
+   "outputs": [],
+   "source": [
+    "model, optimizer = fabric.setup(model, optimizer)\n",
+    "\n",
+    "train_dataloader = fabric.setup_dataloaders(torch.utils.data.DataLoader(train_dataset, batch_size=batch_size))\n",
+    "test_dataloader = fabric.setup_dataloaders(torch.utils.data.DataLoader(test_dataset, batch_size=batch_size))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "vFqs6e1nHxja"
+   },
+   "source": [
+    "## Run training and log test predictions\n",
+    "\n",
+    "For every epoch, we run a training pass and a test pass. Every `log_images_after_n_batches` test batches, we log a batch of test images captioned with the prediction and label, and we collect the test predictions in a `wandb.Table()` via our custom callback."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "Ood7rUH9hpRI"
+   },
+   "source": [
+    "No additional dependencies outside the Torch modeling you're used to!"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "A5IFtxq7nX5M"
+   },
+   "outputs": [],
+   "source": [
+    "logger.watch(model)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "sc1vOoqKZvst"
+   },
+   "outputs": [],
+   "source": [
+    "for epoch in range(num_epochs):\n",
+    "    # Training loop\n",
+    "    fabric.print(f\"Epoch: {epoch}\")\n",
+    "    model.train()  # (re-)enable training-mode behavior such as batchnorm updates\n",
+    "    cum_loss = 0\n",
+    "\n",
+    "    # Batch by batch of data from the training dataset\n",
+    "    for batch in train_dataloader:\n",
+    "        inputs, labels = batch\n",
+    "        optimizer.zero_grad()\n",
+    "        outputs = model(inputs)\n",
+    "        loss = torch.nn.functional.cross_entropy(outputs, labels)\n",
+    "        cum_loss += loss.item()\n",
+    "        fabric.backward(loss)\n",
+    "        optimizer.step()\n",
+    "\n",
+    "        fabric.log_dict({\"loss\": loss.item()})  # Stream per-batch training metrics\n",
+    "\n",
+    "    fabric.log_dict({\"avg_loss\": cum_loss / len(train_dataloader)})  # Stream per-epoch training metrics\n",
+    "\n",
+    "    # Validation loop\n",
+    "    model.eval()  # Switch batchnorm/dropout layers to evaluation mode\n",
+    "    correct = 0\n",
+    "    total = 0\n",
+    "    class_correct = [0.0] * 10\n",
+    "    class_total = [0.0] * 10\n",
+    "\n",
+    "    with torch.no_grad():\n",
+    "\n",
+    "        # Batch by batch of data from the testing dataset\n",
+    "        for batch_ctr, batch in enumerate(test_dataloader):\n",
+    "            images, labels = batch\n",
+    "            outputs = model(images)\n",
+    "            _, predicted = torch.max(outputs, 1)\n",
+    "\n",
+    "            # Overall test accuracy\n",
+    "            total += labels.size(0)\n",
+    "            correct += (predicted == labels).sum().item()\n",
+    "\n",
+    "            # Per-class accuracy\n",
+    "            c = (predicted == labels).squeeze()\n",
+    "            for i in range(labels.size(0)):\n",
+    "                label = labels[i]\n",
+    "                class_correct[label] += c[i].item()\n",
+    "                class_total[label] += 1\n",
+    "\n",
+    "            if batch_ctr % log_images_after_n_batches == 0:\n",
+    "\n",
+    "                # Test images labeled with the class prediction for qualitative analysis\n",
+    "                predictions = [classes[prediction] for prediction in predicted]\n",
+    "                label_names = [classes[truth] for truth in labels]\n",
+    "                loggable_images = list(images)\n",
+    "\n",
+    "                captions = [\n",
+    "                    f\"pred: {pred}\\nlabel: {truth}\" for pred, truth in zip(predictions, label_names)\n",
+    "                ]\n",
+    "\n",
+    "                logger.log_image(key=\"test_image_batch\", images=loggable_images, step=None, caption=captions)  # Automatically construct and log wandb.Images\n",
+    "\n",
+    "                # Could also directly log the list below via fabric.log_dict\n",
+    "                # [wandb.Image(image, caption=pred) for image, pred in zip(images, predictions)]\n",
+    "\n",
+    "                fabric.call(\"on_test_batch_end\", images=loggable_images, predictions=predictions, ground_truths=label_names)  # Populate per-batch data within our table\n",
+    "\n",
+    "    # Calculate cumulative test metrics\n",
+    "    test_acc = 100 * correct / total\n",
+    "    class_acc = {f\"{classes[i]}_acc\": 100 * class_correct[i] / class_total[i] for i in range(10) if class_total[i] > 0}\n",
+    "    loggable_dict = {\n",
+    "        \"test_acc\": test_acc,\n",
+    "    }\n",
+    "    loggable_dict.update(class_acc)\n",
+    "\n",
+    "    fabric.log_dict(loggable_dict)  # Stream per-epoch validation metrics\n",
+    "    fabric.call(\"on_model_epoch_end\")  # Save the epoch's test prediction table to the dashboard"
+   ]
+  },
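+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Before finishing the run, we can persist the trained weights. A minimal sketch: `fabric.save` writes a checkpoint file, and logging that file as an Artifact versions the model alongside the run. The checkpoint path and artifact name below are our own placeholder choices."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Save a checkpoint of the model and optimizer state via Fabric\n",
+    "state = {\"model\": model, \"optimizer\": optimizer}\n",
+    "fabric.save(\"resnet18_cifar10.ckpt\", state)\n",
+    "\n",
+    "# Version the checkpoint file as a W&B Artifact attached to this run\n",
+    "model_art = wandb.Artifact(name=\"resnet18-cifar10\", type=\"model\")\n",
+    "model_art.add_file(\"resnet18_cifar10.ckpt\")\n",
+    "logger.experiment.log_artifact(model_art)"
+   ]
+  },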
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "jfwdrnr6irWh"
+   },
+   "source": [
+    "Finish our experiment!"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "SixRB0B9g4rI"
+   },
+   "outputs": [],
+   "source": [
+    "logger.experiment.finish()"
+   ]
+  }
+ ],
+ "metadata": {
+  "accelerator": "GPU",
+  "colab": {
+   "gpuType": "T4",
+   "provenance": []
+  },
+  "kernelspec": {
+   "display_name": "Python 3",
+   "name": "python3"
+  },
+  "language_info": {
+   "name": "python"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}