# --- Reconstructed from colabs/super-gradients/yolo_nas_data_analysis.ipynb ---
# The source arrived as a whitespace-mangled `git format-patch` email; the code
# below restores the notebook cells in execution order as plain Python.
# Colab shell setup from the original first cell (not importable here):
#   !sudo apt install libcairo2-dev pkg-config python3-dev -qq
#   !pip install roboflow pycairo wandb sweeps -qqq
#   !pip install super_gradients

import os
import glob
import torch
import wandb
import warnings
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from matplotlib import patches
from google.colab import userdata
from torchvision.io import read_image
from torch.utils.data import DataLoader
from IPython.display import clear_output
from roboflow import Roboflow
from super_gradients.training import models, dataloaders
from super_gradients.training.dataloaders.dataloaders import (
    coco_detection_yolo_format_train,
    coco_detection_yolo_format_val,
)

warnings.filterwarnings("ignore")

# Credentials come from Colab's secret store; export them so the roboflow and
# wandb clients can pick them up from the environment.
os.environ["WANDB_API_KEY"] = userdata.get('wandb')
os.environ["ROBOFLOW_API_KEY"] = userdata.get('roboflow')

# Download version 10 of the "trash-sea" dataset in YOLOv5 directory layout.
rf = Roboflow(api_key=os.getenv("ROBOFLOW_API_KEY"))
project = rf.workspace("easyhyeon").project("trash-sea")
dataset = project.version(10).download("yolov5")

DATASET_PATH = "/content/trash-sea-10"
WANDB_PROJECT_NAME = "fconn-yolo-nas"
ENTITY = "ml-colabs"
# Single source of truth for the loader batch size. The original notebook
# hard-coded 16 in each dataloader AND in the progress-message denominator;
# keeping one constant prevents them drifting apart.
BATCH_SIZE = 16
NUM_WORKERS = 2

dataset_params = {
    'data_dir': DATASET_PATH,
    'train_images_dir': 'train/images',
    'train_labels_dir': 'train/labels',
    'val_images_dir': 'valid/images',
    'val_labels_dir': 'valid/labels',
    'test_images_dir': 'test/images',
    'test_labels_dir': 'test/labels',
    'classes': ["Buoy", "Can", "Paper", "Plastic Bag", "Plastic Bottle"],
}

train_data = coco_detection_yolo_format_train(
    dataset_params={
        'data_dir': dataset_params['data_dir'],
        'images_dir': dataset_params['train_images_dir'],
        'labels_dir': dataset_params['train_labels_dir'],
        'classes': dataset_params['classes'],
    },
    dataloader_params={
        'batch_size': BATCH_SIZE,
        'num_workers': NUM_WORKERS,
    },
)

val_data = coco_detection_yolo_format_val(
    dataset_params={
        'data_dir': dataset_params['data_dir'],
        'images_dir': dataset_params['val_images_dir'],
        'labels_dir': dataset_params['val_labels_dir'],
        'classes': dataset_params['classes'],
    },
    dataloader_params={
        'batch_size': BATCH_SIZE,
        'num_workers': NUM_WORKERS,
    },
)

test_data = coco_detection_yolo_format_val(
    dataset_params={
        'data_dir': dataset_params['data_dir'],
        'images_dir': dataset_params['test_images_dir'],
        'labels_dir': dataset_params['test_labels_dir'],
        'classes': dataset_params['classes'],
    },
    dataloader_params={
        'batch_size': BATCH_SIZE,
        'num_workers': NUM_WORKERS,
    },
)

# Drop the first five training transforms so the logged samples stay close to
# the raw images. NOTE(review): assumes slots 0-4 hold the heavy augmentations
# in this super-gradients version -- confirm before reusing.
train_data.dataset.transforms = train_data.dataset.transforms[5:]

# Per-class display colors (kept for notebook plotting cells) and the
# class-id -> name mapping used throughout the analysis.
colors = {
    0: 'red',
    1: 'green',
    2: 'blue',
    3: 'yellow',
    4: 'black',
}
classes = {
    0: "Buoy",
    1: "Can",
    2: "Paper",
    3: "Plastic Bag",
    4: "Plastic Bottle",
}


def process_bounding_boxes_list(annots):
    """Convert one image's YOLO target rows into wandb box dicts.

    NOTE(review): each row appears to be [image_idx, class_id, cx, cy, w, h]
    in pixel units -- inferred from the indexing below; confirm against the
    dataloader's collate format.

    Args:
        annots: iterable of per-box target rows for a single image.

    Returns:
        Tuple ``(boxes, class_count)`` where ``boxes`` is a list of wandb
        "box_data" dicts and ``class_count`` maps class id -> box count.
    """
    result = []

    class_count = {i: 0 for i in range(0, 5)}

    for annot_idx, annotation in enumerate(annots):
        class_count[int(annotation[1])] += 1
        result.append({
            "position": {
                "middle": [float(annotation[2]), float(annotation[3])],
                "width": float(annotation[4]),
                "height": float(annotation[5]),
            },
            "domain": "pixel",
            "class_id": int(annotation[1]),
            "box_caption": classes[int(annotation[1])],
        })

    return result, class_count


def populate_wandb_image_samples(train_data):
    """Log every training image with its ground-truth boxes to a wandb Table.

    Creates/resumes the run ``add-image-samples`` and logs one table row per
    image: the annotated image plus per-class object counts.
    """
    wandb.init(
        project=WANDB_PROJECT_NAME,
        entity=ENTITY,
        id='add-image-samples',
        job_type="add-tables",
        resume='allow',
    )

    class_set = wandb.Classes(
        [
            {"name": "Buoy", "id": 0},
            {"name": "Can", "id": 1},
            {"name": "Paper", "id": 2},
            {"name": "Plastic Bag", "id": 3},
            {"name": "Plastic Bottle", "id": 4},
        ]
    )

    table = wandb.Table(
        columns=[
            "Annotated-Image", "Number-of-objects",
            "Number-Buoy", "Number-Can", "Number-Paper",
            "Number-Plastic-Bag", "Number-Plastic-Bottle",
        ]
    )

    img_count = 0

    for batch_idx, batch_sample in enumerate(train_data):
        batch_images = batch_sample[0]
        batch_annotations = batch_sample[1]

        # Targets come batch-flattened; column 0 is the in-batch image index,
        # so bucket the rows per image before processing.
        annots_dict = {i: [] for i in range(0, batch_images.shape[0])}
        for annot in batch_annotations:
            annots_dict[int(annot[0])].append(annot)

        for idx, image in enumerate(batch_images):

            bbox, class_count = process_bounding_boxes_list(annots_dict[idx])

            # Reverses the channel axis of the (C, H, W) tensor -- presumably
            # BGR -> RGB for display. TODO confirm the loader's channel order.
            image = image.flip(0)

            img = wandb.Image(
                image,
                boxes={
                    "ground_truth": {
                        "box_data": bbox,
                        "class_labels": classes,
                    }
                },
                classes=class_set,
            )

            table.add_data(img, len(bbox), class_count[0],
                           class_count[1], class_count[2],
                           class_count[3], class_count[4])
            img_count += 1

        # NOTE(review): upper bound -- the final batch may hold fewer than
        # BATCH_SIZE images.
        print(f"{img_count}/{len(train_data)*BATCH_SIZE} completed")

    wandb.log({"ground_truth_dataset": table})
    wandb.finish()

populate_wandb_image_samples(train_data)
# --- Remaining analysis cells (bbox statistics, spatial heatmaps), followed by
# the setup cells of colabs/super-gradients/yolo_nas_sweep_run.ipynb. ---


def process_bounding_boxes_list(annots):
    """Convert one image's YOLO target rows into wandb box dicts.

    Re-declared here exactly as in the earlier cell so this notebook cell can
    run standalone. NOTE(review): rows appear to be
    [image_idx, class_id, cx, cy, w, h] in pixels -- confirm.
    """
    result = []

    class_count = {i: 0 for i in range(0, 5)}

    for annot_idx, annotation in enumerate(annots):
        class_count[int(annotation[1])] += 1
        result.append({
            "position": {
                "middle": [float(annotation[2]), float(annotation[3])],
                "width": float(annotation[4]),
                "height": float(annotation[5]),
            },
            "domain": "pixel",
            "class_id": int(annotation[1]),
            "box_caption": classes[int(annotation[1])],
        })

    return result, class_count


def populate_wandb_bbox(train_data):
    """Log per-box width/height/class rows for the whole train set to wandb.

    Creates/resumes the run ``add-bbox-data`` and fills a table with one row
    per ground-truth box. Note the "Class-Id" column actually receives the
    class *name* (``classes[class_id]``), matching the original notebook.
    """
    wandb.init(
        project=WANDB_PROJECT_NAME,
        entity=ENTITY,
        id='add-bbox-data',
        job_type="add-tables",
        resume='allow',
    )

    # (The original cell also built an unused wandb.Classes object here; it
    # was dead code and has been removed.)

    table = wandb.Table(
        columns=[
            "Image-Id",
            "BBox-Height",
            "BBox-Width",
            "Class-Id",
        ]
    )

    img_count = 0

    for batch_idx, batch_sample in enumerate(train_data):
        batch_images = batch_sample[0]
        batch_annotations = batch_sample[1]

        # Column 0 of each target row is the in-batch image index.
        annots_dict = {i: [] for i in range(0, batch_images.shape[0])}
        for annot in batch_annotations:
            annots_dict[int(annot[0])].append(annot)

        for idx, image in enumerate(batch_images):

            result, class_count = process_bounding_boxes_list(annots_dict[idx])

            for bbox in result:
                height = bbox["position"]["height"]
                width = bbox["position"]["width"]
                class_id = bbox["class_id"]
                table.add_data(img_count, height, width, classes[class_id])

            img_count += 1

        # NOTE(review): 16 is the dataloader batch size, so this denominator
        # over-counts when the final batch is short.
        print(f"{img_count}/{len(train_data)*16} completed")

    wandb.log({"bounding_box_information": table})
    wandb.finish()

populate_wandb_bbox(train_data)


def populate_wandb_spatial_heatmaps(train_data):
    """Accumulate one occupancy heatmap per class and log them to wandb.

    Every ground-truth box paints +1 over its pixel area in a 224x224 map
    for its class. NOTE(review): assumes the transformed images are 224x224
    -- confirm against the dataloader's output resolution.
    """
    wandb.init(
        project=WANDB_PROJECT_NAME,
        entity=ENTITY,
        id='add-heatmap',
        job_type="add-tables",
        resume='allow',
    )

    heatmaps = [np.zeros((224, 224, 1), dtype=np.float32) for _ in classes]
    annotation_counts = {i: 0 for i in range(len(classes))}

    table = wandb.Table(columns=["Class-Id", "Class-Name", "Spatial-Heatmap",
                                 "Num-Total-Objects"])

    for batch_idx, batch_sample in enumerate(train_data):
        batch_images = batch_sample[0]
        batch_annotations = batch_sample[1]

        for annot in batch_annotations:
            class_idx = int(annot[1])

            midpoint_x = int(annot[2])
            midpoint_y = int(annot[3])
            width = int(annot[4])
            height = int(annot[5])

            # BUGFIX: clamp the lower bounds at 0. For boxes near the image
            # edge the original produced negative slice starts, which NumPy
            # interprets as "from the end" and so painted the opposite edge
            # of the heatmap.
            x_min = max(midpoint_x - (width // 2), 0)
            x_max = midpoint_x + (width // 2)

            y_min = max(midpoint_y - (height // 2), 0)
            y_max = midpoint_y + (height // 2)

            heatmaps[class_idx][y_min:y_max, x_min:x_max] += 1

            annotation_counts[class_idx] += 1

        print(f"{batch_idx+1}/{len(train_data)} batches completed")

    for class_idx in range(len(classes)):
        heatmap = wandb.Image(
            heatmaps[class_idx],
            caption=classes[class_idx],
        )
        table.add_data(class_idx, classes[class_idx], heatmap, annotation_counts[class_idx])

    wandb.log({"spatial_heatmap_information": table})
    wandb.finish()

populate_wandb_spatial_heatmaps(train_data)


# --- colabs/super-gradients/yolo_nas_sweep_run.ipynb: installation, imports,
# reproducibility, dataset download, and constants. Colab shell setup:
#   !sudo apt install libcairo2-dev pkg-config python3-dev -qq
#   !pip install roboflow pycairo wandb sweeps -qqq
#   !pip install super_gradients

import os
import glob
import torch
import wandb
import warnings
import pandas as pd

from google.colab import userdata
from torchvision.io import read_image
from torch.utils.data import DataLoader
from IPython.display import clear_output

from super_gradients.training import models, Trainer, dataloaders
from super_gradients.training.losses import PPYoloELoss
from super_gradients.training.metrics import DetectionMetrics_050
from super_gradients.training.models.detection_models.pp_yolo_e import PPYoloEPostPredictionCallback
from super_gradients.training.dataloaders.dataloaders import coco_detection_yolo_format_train, coco_detection_yolo_format_val

os.environ["WANDB_API_KEY"] = userdata.get('wandb')
os.environ["ROBOFLOW_API_KEY"] = userdata.get('roboflow')

# Fix every torch RNG and force deterministic cuDNN kernels so sweep trials
# differ only through their sampled hyperparameters.
seed = 42
torch.manual_seed(seed)

if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Download version 10 of the "trash-sea" dataset in YOLOv5 layout.
from roboflow import Roboflow
rf = Roboflow(api_key=os.getenv("ROBOFLOW_API_KEY"))
project = rf.workspace("easyhyeon").project("trash-sea")
dataset = project.version(10).download("yolov5")

ENTITY = "ml-colabs"
SWEEP_NUM_RUNS = 100
WANDB_PROJECT_NAME = "fconn-yolo-nas"
DATASET_PATH = "/content/trash-sea-10"
# --- Sweep configuration and trial function. ---

# BUGFIX: the original notebook referenced WANDB_EXP_NAME in the sweep config,
# in wandb.init(), and in Trainer(...) without ever defining it anywhere in
# either notebook, so the cell raised NameError on first run. Define it before
# first use.
WANDB_EXP_NAME = "yolo-nas-sweep"

# Bayesian-optimization sweep maximizing validation mAP@0.50. The parameter
# names must match the keys read from wandb.config inside main_call().
sweep_configuration = {
    "name": WANDB_EXP_NAME,
    "metric": {"name": "Valid_mAP@0.50", "goal": "maximize"},
    "method": "bayes",
    "parameters": {
        "batch_size": {"values": [16, 24, 32]},
        "optimizer": {"values": ["Adam", "SGD", "RMSProp", "AdamW"]},
        "ema_decay": {"min": 0.5, "max": 0.9},
        "ema_decay_type": {"values": ["constant", "threshold"]},
        "cosine_lr_ratio": {"min": 0.01, "max": 0.4},
        "iou_loss_weight": {"min": 0.25, "max": 2.0},
        "dfl_loss_weight": {"min": 0.25, "max": 2.0},
        "classification_loss_weight": {"min": 0.25, "max": 2.0},
        "model_flavor": {"values": ["yolo_nas_s", "yolo_nas_m", "yolo_nas_l"]},
        "weight_decay": {"min": 0.0001, "max": 0.01},
    },
}


def main_call():
    """Run one sweep trial: build loaders, train YOLO-NAS, log to wandb.

    Invoked by wandb.agent; reads all hyperparameters from wandb.config.
    Trains for 5 epochs on the trash-sea dataset and relies on the
    super-gradients wandb logger for metric/checkpoint upload.
    """

    CHECKPOINT_DIR = 'checkpoints'

    # BUGFIX: the original passed id=WANDB_EXP_NAME with resume="allow" here,
    # which made every one of the SWEEP_NUM_RUNS trials resume the *same*
    # wandb run instead of logging separately. Inside a sweep, let the agent
    # assign a fresh run id per trial.
    wandb.init(
        project=WANDB_PROJECT_NAME,
        entity=ENTITY,
        save_code=True,
    )

    config = wandb.config

    dataset_params = {
        'data_dir': DATASET_PATH,
        'train_images_dir': 'train/images',
        'train_labels_dir': 'train/labels',
        'val_images_dir': 'valid/images',
        'val_labels_dir': 'valid/labels',
        'test_images_dir': 'test/images',
        'test_labels_dir': 'test/labels',
        'classes': ["Buoy", "Can", "Paper", "Plastic Bag", "Plastic Bottle"],
    }

    train_data = coco_detection_yolo_format_train(
        dataset_params={
            'data_dir': dataset_params['data_dir'],
            'images_dir': dataset_params['train_images_dir'],
            'labels_dir': dataset_params['train_labels_dir'],
            'classes': dataset_params['classes'],
        },
        dataloader_params={
            'batch_size': config["batch_size"],
            'num_workers': 4,
        },
    )

    val_data = coco_detection_yolo_format_val(
        dataset_params={
            'data_dir': dataset_params['data_dir'],
            'images_dir': dataset_params['val_images_dir'],
            'labels_dir': dataset_params['val_labels_dir'],
            'classes': dataset_params['classes'],
        },
        dataloader_params={
            'batch_size': config["batch_size"],
            'num_workers': 4,
        },
    )

    test_data = coco_detection_yolo_format_val(
        dataset_params={
            'data_dir': dataset_params['data_dir'],
            'images_dir': dataset_params['test_images_dir'],
            'labels_dir': dataset_params['test_labels_dir'],
            'classes': dataset_params['classes'],
        },
        dataloader_params={
            'batch_size': config["batch_size"],
            'num_workers': 4,
        },
    )

    # Drop the first training transform. NOTE(review): the analysis notebook
    # slices [5:] instead -- confirm which augmentations occupy these slots.
    train_data.dataset.transforms = train_data.dataset.transforms[1:]

    # COCO-pretrained backbone, detection head resized to our 5 classes.
    model = models.get(
        config["model_flavor"],
        num_classes=len(dataset_params['classes']),
        pretrained_weights="coco",
    )

    train_params = {
        'silent_mode': False,
        "average_best_models": True,
        "warmup_mode": "linear_epoch_step",
        "warmup_initial_lr": 1e-6,
        "lr_warmup_epochs": 3,
        "initial_lr": 1e-3,
        "lr_mode": "cosine",
        "cosine_final_lr_ratio": config["cosine_lr_ratio"],
        "optimizer": config["optimizer"],
        "optimizer_params": {
            "weight_decay": config["weight_decay"],
        },
        "zero_weight_decay_on_bias_and_bn": True,
        "ema": True,
        "ema_params": {
            "decay": config["ema_decay"],
            "decay_type": config["ema_decay_type"],
        },
        # Short budget per trial: the sweep explores hyperparameters, not
        # final convergence.
        "max_epochs": 5,
        "mixed_precision": False,
        "loss": PPYoloELoss(
            use_static_assigner=False,
            num_classes=len(dataset_params['classes']),
            reg_max=16,
            iou_loss_weight=config["iou_loss_weight"],
            dfl_loss_weight=config["dfl_loss_weight"],
            classification_loss_weight=config["classification_loss_weight"],
        ),
        "valid_metrics_list": [
            DetectionMetrics_050(
                score_thres=0.1,
                top_k_predictions=300,
                num_cls=len(dataset_params['classes']),
                normalize_targets=True,
                post_prediction_callback=PPYoloEPostPredictionCallback(
                    score_threshold=0.01,
                    nms_top_k=1000,
                    max_predictions=300,
                    nms_threshold=0.7,
                ),
            )
        ],
        # Must match the sweep metric ("Valid_" prefix is added by the logger).
        "metric_to_watch": 'mAP@0.50',
        "sg_logger": "wandb_sg_logger",
        "sg_logger_params": {
            "project_name": WANDB_PROJECT_NAME,
            "save_checkpoints_remote": True,
            "save_tensorboard_remote": True,
            "save_logs_remote": True,
            "entity": ENTITY,
        },
    }

    trainer = Trainer(
        experiment_name=WANDB_EXP_NAME,
        ckpt_root_dir=CHECKPOINT_DIR,
    )

    trainer.train(
        model=model,
        training_params=train_params,
        train_loader=train_data,
        valid_loader=val_data,
    )

    wandb.finish()


# --- Execute the sweep: register it, then run SWEEP_NUM_RUNS trials. ---
sweep_id = wandb.sweep(
    sweep=sweep_configuration,
    project="yolo-nas-sweep",
)

wandb.agent(sweep_id, function=main_call, count=SWEEP_NUM_RUNS)