diff --git a/experiments/quantization/LAPQ/lapq_demo.ipynb b/experiments/quantization/LAPQ/lapq_demo.ipynb index bc5fb0a..a93d691 100644 --- a/experiments/quantization/LAPQ/lapq_demo.ipynb +++ b/experiments/quantization/LAPQ/lapq_demo.ipynb @@ -1,554 +1,620 @@ { - "cells": [ - { - "attachments": {}, - "cell_type": "markdown", - "id": "498871cf", - "metadata": {}, - "source": [ - "## LAPQ\n", - "This notebook demonstrates the implimentation of the paper [Loss Aware Post-training Quantization](https://arxiv.org/abs/1911.07190)\n", - "\n", - "### Steps to quantize the pretrained model\n", - "- Load the dataset and create dataloader. A subset of training data is used for calibration.\n", - "- Load the pretrained full precision model.\n", - "- Load the configurations from the YAML file.\n", - "- Create a `LAPQ` object and pass the full precision model, dataloaders and configurations.\n", - "- Quantize the model by calling the `compress_model` method." - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "dafbd1b3", - "metadata": {}, - "outputs": [], - "source": [ - "import sys\n", - "sys.path.append(\"../../../\")\n", - "\n", - "import os\n", - "os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\" \n", - "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"1\"" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "417e9692", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/opt/conda/envs/py117/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n" - ] - } - ], - "source": [ - "import yaml\n", - "import torch\n", - "from torch.utils.data import DataLoader\n", - "from torchvision import transforms\n", - "from trailmet.datasets.classification import DatasetFactory\n", - "from trailmet.models import ModelsFactory\n", - "from trailmet.algorithms import quantize" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "c67e2359", - "metadata": {}, - "source": [ - "## Datasets" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "64a12f9a", - "metadata": {}, - "source": [ - "### Augmentations" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "4c8c6192", - "metadata": {}, - "outputs": [], - "source": [ - "stats = ((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))\n", - "\n", - "train_transform = transforms.Compose([\n", - " transforms.RandomCrop(32, padding=4, padding_mode='reflect'),\n", - " transforms.RandomHorizontalFlip(),\n", - " transforms.ToTensor(),\n", - " transforms.Normalize(*stats, inplace=True)\n", - "])\n", - "val_transform = transforms.Compose([\n", - " transforms.ToTensor(),\n", - " transforms.Normalize(*stats)\n", - "])\n", - "test_transform = transforms.Compose([\n", - " transforms.ToTensor(),\n", - " transforms.Normalize(*stats)\n", - "])\n", - "\n", - "input_transforms = {\n", - " 'train': train_transform, \n", - " 'val': val_transform, \n", - " 'test': test_transform}\n", - "\n", - "target_transforms = {\n", - " 'train': None, \n", - " 'val': None, \n", - " 'test': None}" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "0c41f2b1", - "metadata": {}, - "source": [ - "### Load Datasets" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "b377f3bb", - "metadata": {}, - "outputs": [ + "cells": [ { - "name": "stdout", - "output_type": "stream", - 
"text": [ - "Files already downloaded and verified\n", - "Files already downloaded and verified\n", - "Files already downloaded and verified\n", - "Train samples: 40000\n", - "Val samples: 10000\n", - "Test samples: 10000\n" - ] - } - ], - "source": [ - "cifar100_dataset = DatasetFactory.create_dataset(\n", - " name = 'CIFAR100', \n", - " root = './data',\n", - " split_types = ['train', 'val', 'test'],\n", - " val_fraction = 0.2,\n", - " transform = input_transforms,\n", - " target_transform = target_transforms)\n", - "\n", - "# getting the size of the different splits\n", - "print('Train samples: ',cifar100_dataset['info']['train_size'])\n", - "print('Val samples: ',cifar100_dataset['info']['val_size'])\n", - "print('Test samples: ',cifar100_dataset['info']['test_size'] )" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "eec91853", - "metadata": {}, - "source": [ - "### Define Dataloaders" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "464c2e93", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "No. of training batches: 313\n", - "No. of validation batches: 79\n", - "No. of test batches: 79\n" - ] - } - ], - "source": [ - "train_loader = DataLoader(\n", - " cifar100_dataset['train'], batch_size=128, \n", - " sampler=cifar100_dataset['train_sampler'],\n", - " num_workers=2)\n", - "val_loader = DataLoader(\n", - " cifar100_dataset['val'], batch_size=128, \n", - " sampler=cifar100_dataset['val_sampler'],\n", - " num_workers=2)\n", - "test_loader = DataLoader(\n", - " cifar100_dataset['test'], batch_size=128, \n", - " sampler=cifar100_dataset['test_sampler'],\n", - " num_workers=2)\n", - "\n", - "dataloaders = {\"train\": train_loader, \"val\": val_loader, \"test\": test_loader}\n", - "\n", - "print('No. of training batches: ', len(dataloaders['train']))\n", - "print('No. of validation batches: ', len(dataloaders['val']))\n", - "print('No. 
of test batches: ', len(dataloaders['test']))" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "85812739", - "metadata": {}, - "source": [ - "### Load Pretrained Model" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "e4db07ac", - "metadata": {}, - "outputs": [], - "source": [ - "res50_model = ModelsFactory.create_model(name='resnet50', num_classes=100, pretrained=False, insize=32)" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "0ca3e796", - "metadata": {}, - "source": [ - "### Load Method Config" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "5b9625ad", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'GPU_ID': 0,\n", - " 'SEED': 42,\n", - " 'W_BITS': 4,\n", - " 'A_BITS': 8,\n", - " 'ACT_QUANT': True,\n", - " 'CALIB_BATCHES': 4,\n", - " 'MAX_ITER': 1000,\n", - " 'MAX_FEV': 1000,\n", - " 'VERBOSE': True}" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "with open('./lapq_config.yaml', 'r') as f:\n", - " config = yaml.safe_load(f)\n", - " kwargs = config['GENERAL']\n", - " \n", - "kwargs" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "87e6c554", - "metadata": {}, - "source": [ - "### Quantization Method: BRECQ" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "e21d5073", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "==> Using seed: 42 and device: cuda:0\n" - ] + "cell_type": "markdown", + "metadata": { + "id": "498871cf" + }, + "source": [ + "## LAPQ\n", + "This notebook demonstrates the implimentation of the paper [Loss Aware Post-training Quantization](https://arxiv.org/abs/1911.07190)\n", + "\n", + "### Steps to quantize the pretrained model\n", + "- Load the dataset and create dataloader. A subset of training data is used for calibration.\n", + "- Load the pretrained full precision model.\n", + "- Load the configurations from the YAML file.\n", + "- Create a `LAPQ` object and pass the full precision model, dataloaders and configurations.\n", + "- Quantize the model by calling the `compress_model` method." + ], + "id": "498871cf" }, { - "name": "stderr", - "output_type": "stream", - "text": [ - "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33manimesh-007\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n" - ] + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "skKWRRtMsfRM", + "outputId": "a1265ade-b421-4bc1-f456-63b88c69fdbb" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount(\"/content/drive\", force_remount=True).\n" + ] + } + ], + "source": [ + "USE_COLAB = True\n", + "\n", + "if USE_COLAB:\n", + " from google.colab import drive\n", + " drive.mount(\"/content/drive\")\n", + " base_path = \"/content/drive/MyDrive/trail\"\n", + "else:\n", + " base_path = \"../../../..\"\n", + "\n", + "library_path = base_path + \"/trailmet\"\n", + "requirements_path = library_path + \"/requirements.txt\"\n", + "config_path = library_path + \"/experiments/quantization/LAPQ/lapq_config.yaml\"\n", + "weights_path = base_path + \"/weights/resnet50_cifar100_pretrained.pth\"" + ], + "id": "skKWRRtMsfRM" }, { - "data": { - "text/html": [ - "wandb version 0.15.4 is available! 
To upgrade, please run:\n", - " $ pip install wandb --upgrade" + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "P7Kar046xSEd" + }, + "outputs": [], + "source": [ + "%pip install -q -r $requirements_path" ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" + "id": "P7Kar046xSEd" }, { - "data": { - "text/html": [ - "Tracking run with wandb version 0.14.0" + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "dafbd1b3" + }, + "outputs": [], + "source": [ + "import sys\n", + "sys.path.append(library_path)" ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" + "id": "dafbd1b3" }, { - "data": { - "text/html": [ - "Run data is saved locally in /workspace/animesh_trailmet/experiments/quantization/LAPQ/wandb/run-20230625_230422-w4sdlkw4" + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "417e9692" + }, + "outputs": [], + "source": [ + "import yaml\n", + "import torch\n", + "from torch.utils.data import DataLoader\n", + "from torchvision import transforms\n", + "from trailmet.datasets.classification import DatasetFactory\n", + "from trailmet.models import resnet, mobilenet\n", + "from trailmet.algorithms import quantize" ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" + "id": "417e9692" }, { - "data": { - "text/html": [ - "Syncing run CIFAR100_8_Jun-25_23:04:20 to Weights & Biases (docs)
" + "cell_type": "markdown", + "metadata": { + "id": "c67e2359" + }, + "source": [ + "## Datasets" ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" + "id": "c67e2359" }, { - "data": { - "text/html": [ - " View project at https://wandb.ai/animesh-007/Trailmet%20LAPQ" + "cell_type": "markdown", + "metadata": { + "id": "64a12f9a" + }, + "source": [ + "### Augmentations" ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" + "id": "64a12f9a" }, { - "data": { - "text/html": [ - " View run at https://wandb.ai/animesh-007/Trailmet%20LAPQ/runs/w4sdlkw4" + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "4c8c6192" + }, + "outputs": [], + "source": [ + "stats = ((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))\n", + "\n", + "train_transform = transforms.Compose([\n", + " transforms.RandomCrop(32, padding=4, padding_mode='reflect'),\n", + " transforms.RandomHorizontalFlip(),\n", + " transforms.ToTensor(),\n", + " transforms.Normalize(*stats, inplace=True)\n", + "])\n", + "val_transform = transforms.Compose([\n", + " transforms.ToTensor(),\n", + " transforms.Normalize(*stats)\n", + "])\n", + "test_transform = transforms.Compose([\n", + " transforms.ToTensor(),\n", + " transforms.Normalize(*stats)\n", + "])\n", + "\n", + "input_transforms = {\n", + " 'train': train_transform,\n", + " 'val': val_transform,\n", + " 'test': test_transform}\n", + "\n", + "target_transforms = {\n", + " 'train': None,\n", + " 'val': None,\n", + " 'test': None}" ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" + "id": "4c8c6192" }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "testing pretrained model before quantization\n" - ] + "cell_type": "markdown", + "metadata": { + "id": "0c41f2b1" + }, + "source": [ + "### Load Datasets" + ], + "id": "0c41f2b1" }, { - "name": "stderr", - "output_type": "stream", - "text": [ - "Validating network (79 / 79 Steps) (batch time=0.01544s) (loss=9.17796) (top1=0.00000) (top5=0.00000): 100%|| 79/79 [00:04<00:00, 17.72it/s] \n" - ] + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "b377f3bb", + "outputId": "b0d816c7-cbe2-4a50-d9a2-bf9d443e9715" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Files already downloaded and verified\n", + "Files already downloaded and verified\n", + "Files already downloaded and verified\n", + "Train samples: 40000\n", + "Val samples: 10000\n", + "Test samples: 10000\n" + ] + } + ], + "source": [ + "cifar100_dataset = DatasetFactory.create_dataset(\n", + " name = 'CIFAR100',\n", + " root = './data',\n", + " split_types = ['train', 'val', 'test'],\n", + " val_fraction = 0.2,\n", + " transform = input_transforms,\n", + " target_transform = target_transforms)\n", + "\n", + "# getting the size of the different splits\n", + "print('Train samples: ',cifar100_dataset['info']['train_size'])\n", + "print('Val samples: ',cifar100_dataset['info']['val_size'])\n", + "print('Test samples: ',cifar100_dataset['info']['test_size'] )" + ], + "id": "b377f3bb" }, { - "name": "stdout", - "output_type": "stream", - "text": [ - " * acc@1 1.040 acc@5 5.190\n", - "top-1 acc: 1.04%, top-5 acc: 5.19%\n" - ] + "cell_type": "markdown", + "metadata": { + "id": "eec91853" + }, + "source": [ + "### Define Dataloaders" + ], + "id": "eec91853" }, { - "name": "stderr", - "output_type": "stream", - "text": [ - "Validating network (79 
/ 79 Steps) (batch time=0.04852s) (loss=11.44887) (top1=0.00000) (top5=0.00000): 100%|| 79/79 [00:04<00:00, 17.18it/s] \n" - ] + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "464c2e93", + "outputId": "1f95a9d9-3165-427b-c7cc-52377559a9ba" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "No. of training batches: 313\n", + "No. of validation batches: 79\n", + "No. of test batches: 79\n" + ] + } + ], + "source": [ + "train_loader = DataLoader(\n", + " cifar100_dataset['train'], batch_size=128,\n", + " sampler=cifar100_dataset['train_sampler'],\n", + " num_workers=0)\n", + "val_loader = DataLoader(\n", + " cifar100_dataset['val'], batch_size=128,\n", + " sampler=cifar100_dataset['val_sampler'],\n", + " num_workers=0)\n", + "test_loader = DataLoader(\n", + " cifar100_dataset['test'], batch_size=128,\n", + " sampler=cifar100_dataset['test_sampler'],\n", + " num_workers=0)\n", + "\n", + "dataloaders = {\"train\": train_loader, \"val\": val_loader, \"test\": test_loader}\n", + "\n", + "print('No. of training batches: ', len(dataloaders['train']))\n", + "print('No. of validation batches: ', len(dataloaders['val']))\n", + "print('No. of test batches: ', len(dataloaders['test']))" + ], + "id": "464c2e93" }, { - "name": "stdout", - "output_type": "stream", - "text": [ - " * acc@1 1.010 acc@5 5.040\n", - "==> Quantization (W4A8) accuracy before LAPQ: 1.0100 | 5.0400\n" - ] + "cell_type": "markdown", + "metadata": { + "id": "85812739" + }, + "source": [ + "### Load Pretrained Model" + ], + "id": "85812739" }, { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|██████████| 10/10 [00:07<00:00, 1.30it/s, loss=9.16, p_val=4] \n" - ] + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "e4db07ac", + "outputId": "7a942a3e-837e-4081-e6cf-26a1910646ae" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model = resnet.make_resnet50(100,32)\n", + "checkpoint = torch.load(weights_path, map_location='cuda:0')\n", + "model.load_state_dict(checkpoint['state_dict'])" + ], + "id": "e4db07ac" }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "==> using p intr : 4.09\n" - ] + "cell_type": "markdown", + "metadata": { + "id": "0ca3e796" + }, + "source": [ + "### Load Method Config" + ], + "id": "0ca3e796" }, { - "name": "stderr", - "output_type": "stream", - "text": [ - "Validating network (79 / 79 Steps) (batch time=0.04860s) (loss=9.68071) (top1=0.00000) (top5=0.00000): 100%|| 79/79 [00:04<00:00, 16.46it/s] \n" - ] + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "5b9625ad" + }, + "outputs": [], + "source": [ + "# with open(config_path, 'r') as f:\n", + "# config_all = yaml.safe_load(f)\n", + "# config = config_all['GENERAL']\n", + "\n", + "config = {\n", + " 'w_bits' : 8,\n", + " 'a_bits' : 8,\n", + " 'reduce_range': True,\n", + " 'act_quant': True,\n", + " 'max_iter': 2000,\n", + " 'max_fev': 2000,\n", + " 'calib_bs': 256,\n", + " 'calib_size': 1024,\n", + " 'seed': 42,\n", + " 'gpu_id': 0\n", + "}" + ], + "id": "5b9625ad" }, { - "name": "stdout", - "output_type": "stream", - "text": [ - " * acc@1 0.960 acc@5 5.070\n", - "==> Quantization (W4A8) accuracy before Optimization: 0.9600 | 5.0700\n", - "==> Loss after LpNormQuantization: 9.4259\n", - 
"==> Starting Powell Optimization\n" - ] + "cell_type": "markdown", + "metadata": { + "id": "87e6c554" + }, + "source": [ + "### Quantization Method: LAPQ" + ], + "id": "87e6c554" }, { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|██████████| 1000/1000 [03:28<00:00, 4.80it/s, curr_loss=4.62, min_loss=4.61]\n" - ] + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "background_save": true, + "base_uri": "https://localhost:8080/" + }, + "id": "e21d5073", + "outputId": "a3022343-845b-43b3-a9c1-5a1497b26408" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 79/79 [00:15<00:00, 4.99it/s, acc1=72.5, acc5=91.5]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "==> Full Precision Model: acc@1 72.518 | acc@5 91.525\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 79/79 [00:08<00:00, 9.10it/s, acc1=71.8, acc5=91.1]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "==> Quantization accuracy before LAPQ: acc@1 71.756 | acc@5 91.149\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 20/20 [01:00<00:00, 3.02s/it, loss=0.0998, p_val=3.9]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "==> using p-val : 3.481 with lp-loss : 0.093\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 79/79 [00:11<00:00, 6.91it/s, acc1=72.2, acc5=91.5]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "==> Quantization accuracy before optimization: acc@1 72.241 | acc@5 91.535\n", + "==> Starting Powell Optimization\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 2000/2000 [04:43<00:00, 7.06it/s, curr_loss=0.092, min_loss=0.092]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "==> Optimization completed with status: False\n", + "==> Optimized alphas :\n", + " [ 0.32192663 0.4921856 0.44087177 0.9391082 0.95024213 0.35209557\n", + " 0.56010642 0.8362113 0.14528238 0.6173208 1.22021341 0.46791642\n", + " 0.25910885 0.74059974 0.45257379 0.14750679 0.15402703 1.13943211\n", + " 0.17752448 0.38750959 1.09536373 0.26694447 0.21156799 0.70860321\n", + " 0.36488586 0.12406421 0.50658086 0.28335763 0.12680141 0.18381498\n", + " 0.83998023 0.10583398 0.16331125 0.65832013 0.09808584 0.16091051\n", + " 0.84497283 0.12403724 0.20360942 0.679278 0.17891735 0.26773851\n", + " 0.54199507 1.34550151 0.21835104 1.16213213 0.40100812 0.04182439\n", + " 0.32693615 0.75804108 0.01931836 0.23854705 0.77586533 0.35659174\n", + " 1.99467082 1.18234275 0.75984839 0.50286948 0.43017 0.54280295\n", + " 0.80808552 0.7176247 0.37142049 0.5562059 0.46602321 0.71515874\n", + " 0.30534324 0.37335814 0.67148537 0.62658289 0.61451279 0.52567379\n", + " 0.63642817 0.61713725 0.63287643 0.24550927 0.4923878 0.70035124\n", + " 0.75159778 0.36112166 0.37687481 0.66791642 0.7132149 0.49287515\n", + " 0.45697315 0.36587468 0.68228734 0.75030695 0.53571547 0.56328548\n", + " 0.47901733 0.68669396 0.31868012 0.36477025 0.611461 0.77932724\n", + " 0.33648931 0.33574072 0.35906766 0.78264685 0.34599894 0.33887686\n", + " 0.49970291 0.94860415 0.35073394 0.42563926 0.48520955 0.74453487\n", + " 0.51277885 0.53216721 0.57372373 1.96872515 0.60519236 0.67802123\n", + " 1.73508274 0.78461236 3.43784676 0.39349705 0.66659818 2.10375101\n", + " 4.73006429 
0.28074448 0.38541939 2.15557389 10.93467115]\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 79/79 [00:08<00:00, 8.95it/s, acc1=72, acc5=91.2]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "==> Final LAPQ quantization accuracy: 72.004 | 91.228\n" + ] + } + ], + "source": [ + "quantizer = quantize.lapq.LAPQ('resnet', dataloaders, **config)\n", + "qmodel = quantizer.compress_model(model)" + ], + "id": "e21d5073" }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "==> Layer-wise Scales :\n", - " [-2.71854830e+00 -3.08051988e-01 9.95627994e-01 3.81617464e+00\n", - " -9.40671566e-01 -3.61076604e+00 -1.81663796e+00 1.21657309e+00\n", - " 1.91811575e+00 2.52682898e+00 -1.42890691e+00 -1.70284810e+00\n", - " 3.04076019e+00 -1.71109412e+00 -4.41954680e-01 -2.75285367e-01\n", - " 4.77568090e+00 -8.70588564e-01 -1.46963710e+01 -6.32294023e-01\n", - " -1.39139490e+00 4.28594499e+00 -1.21911043e+02 -1.34789206e+00\n", - " 1.03977837e+00 1.19332061e+01 -1.46906333e+01 -4.85501959e-01\n", - " 1.23643283e+00 1.34744476e+01 -7.74482854e-01 -1.09169621e+00\n", - " 3.03794852e-01 1.49893110e+01 -1.44991452e+00 -6.44897627e+00\n", - " -9.93479876e-01 1.35904461e-01 2.40335316e+01 2.86178112e-01\n", - " 9.66371353e-02 1.44008197e-01 2.28464613e+01 2.84444329e-01\n", - " 9.55632382e-02 1.41487494e-01 3.59558318e+01 2.82226657e-01\n", - " 9.38017642e-02 1.43612936e-01 5.25450935e+01 2.79624492e-01\n", - " 9.47439224e-02 1.43324569e-01 1.15084450e+02 2.01696590e-01\n", - " 6.74792528e-02 1.00485601e-01 9.49488449e+01 1.02076188e-01\n", - " 2.02921063e-01 6.77159503e-02 1.01089455e-01 1.21774887e+02\n", - " 2.03220874e-01 6.74693212e-02]\n" - ] + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "background_save": true + }, + "id": "c6edd9e9", + "outputId": "3781a5d3-b56b-470a-afb5-3ff166ec5d94" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "testing quantized model\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 79/79 [01:26<00:00, 1.10s/it, acc1=72, acc5=91.2]\n" + ] + } + ], + "source": [ + "print('testing quantized model')\n", + "qmodel.to(torch.device('cpu'))\n", + "acc1, acc5 = quantizer.test(model=qmodel, dataloader=dataloaders['test'], device=torch.device('cpu'))" + ], + "id": "c6edd9e9" }, { - "name": "stderr", - "output_type": "stream", - "text": [ - "Validating network (79 / 79 Steps) (batch time=0.04774s) (loss=4.62941) (top1=0.00000) (top5=0.00000): 100%|| 79/79 [00:04<00:00, 16.54it/s]\n" - ] + "cell_type": "code", + "execution_count": 12, + "metadata": { + "id": "GMg5Wvv-P5Z9", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "5e33b0a4-ac58-4e16-b884-ec2a0eaa224f" + }, + "outputs": [ + { + "metadata": { + "tags": null + }, + "name": "stdout", + "output_type": "stream", + "text": [ + "testing full precision model\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 79/79 [03:25<00:00, 2.61s/it, acc1=72.5, acc5=91.5]\n" + ] + } + ], + "source": [ + "print('testing full precision model')\n", + "model.to(torch.device('cpu'))\n", + "acc1, acc5 = quantizer.test(model=model, dataloader=dataloaders['test'], device=torch.device('cpu'))" + ], + "id": "GMg5Wvv-P5Z9" }, { - "name": "stdout", - "output_type": "stream", - "text": [ - " * acc@1 1.090 acc@5 4.920\n", - "==> Full quantization (W4A8) accuracy: 1.0899999141693115\n" - ] - } - ], - 
"source": [ - "quantizer = quantize.lapq.LAPQ(res50_model, dataloaders, **kwargs)\n", - "\n", - "print('testing pretrained model before quantization')\n", - "_, acc1, acc5 = quantizer.test(model=res50_model, dataloader=dataloaders['test'], loss_fn=torch.nn.CrossEntropyLoss())\n", - "print(f'top-1 acc: {acc1:.2f}%, top-5 acc: {acc5:.2f}%')\n", - "\n", - "qmodel = quantizer.compress_model()" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "c6edd9e9", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "testing quantized model\n" - ] + "cell_type": "code", + "execution_count": 13, + "metadata": { + "id": "56e04513", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "966c5bcd-d964-44cc-a82a-16ed5b30b1df" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Size: 23.84 MB\n", + "Size: 95.12 MB\n" + ] + } + ], + "source": [ + "import os\n", + "def print_model_size(model):\n", + " torch.save(model.state_dict(), \"temp.p\")\n", + " print(f'Size: {os.path.getsize(\"temp.p\")/1e6:.2f} MB')\n", + " os.remove('temp.p')\n", + "\n", + "print_model_size(qmodel)\n", + "print_model_size(model)" + ], + "id": "56e04513" }, { - "name": "stderr", - "output_type": "stream", - "text": [ - "Validating network (79 / 79 Steps) (batch time=0.05487s) (loss=4.62941) (top1=0.00000) (top5=0.00000): 100%|| 79/79 [00:04<00:00, 18.34it/s]" - ] + "cell_type": "code", + "execution_count": 14, + "metadata": { + "id": "TxsdFuMrPleV" + }, + "outputs": [], + "source": [ + "torch.save(qmodel.state_dict(), \"quantized_res50_c100.pth\")" + ], + "id": "TxsdFuMrPleV" }, { - "name": "stdout", - "output_type": "stream", - "text": [ - " * acc@1 1.090 acc@5 4.920\n", - "top-1 acc: 1.09%, top-5 acc: 4.92%\n" - ] + "cell_type": "code", + "execution_count": 14, + "metadata": { + "id": "JJ4iwacl2pQw" + }, + "outputs": [], + "source": [], + "id": "JJ4iwacl2pQw" + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "provenance": [] }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n" - ] + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.13" } - ], - "source": [ - "print('testing quantized model')\n", - "_, acc1, acc5 = quantizer.test(model=qmodel, dataloader=dataloaders['test'], loss_fn=torch.nn.CrossEntropyLoss())\n", - "print(f'top-1 acc: {acc1:.2f}%, top-5 acc: {acc5:.2f}%')" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.9" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} + "nbformat": 4, + "nbformat_minor": 5 +} \ No newline at end of file diff --git a/trailmet/algorithms/algorithms.py b/trailmet/algorithms/algorithms.py index 6b63cf1..bee4926 100644 --- a/trailmet/algorithms/algorithms.py +++ b/trailmet/algorithms/algorithms.py @@ -180,7 +180,7 @@ def accuracy(self, output, target, topk=(1, )): res.append(correct_k.mul_(100.0 / batch_size)) return res - def test(self, model, dataloader, 
loss_fn=None, device=None): + def test(self, model, dataloader, loss_fn=None, device=None, progress=True): """This method is used to test the performance of the trained model.""" if device is None: device = next(model.parameters()).device @@ -188,12 +188,12 @@ def test(self, model, dataloader, loss_fn=None, device=None): model.to(device) model.eval() counter = 0 - tk1 = tqdm_notebook(dataloader, total=len(dataloader)) running_acc1 = 0 running_acc5 = 0 running_loss = 0 + pbar = tqdm_notebook(dataloader, total=len(dataloader)) if progress else dataloader with torch.no_grad(): - for images, targets in tk1: + for images, targets in pbar: counter += 1 images = images.to(device) targets = targets.to(device) @@ -204,13 +204,15 @@ def test(self, model, dataloader, loss_fn=None, device=None): if loss_fn is not None: loss = loss_fn(outputs, targets) running_loss += loss.item() - tk1.set_postfix( - loss=running_loss / counter, - acc1=running_acc1 / counter, - acc5=running_acc5 / counter, - ) + if progress: + pbar.set_postfix( + loss=running_loss / counter, + acc1=running_acc1 / counter, + acc5=running_acc5 / counter, + ) else: - tk1.set_postfix(acc1=running_acc1 / counter, + if progress: + pbar.set_postfix(acc1=running_acc1 / counter, acc5=running_acc5 / counter) if loss_fn is not None: return running_acc1 / counter, running_loss / counter diff --git a/trailmet/algorithms/quantize/__init__.py b/trailmet/algorithms/quantize/__init__.py index 4a7ff40..9c56851 100644 --- a/trailmet/algorithms/quantize/__init__.py +++ b/trailmet/algorithms/quantize/__init__.py @@ -19,53 +19,8 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -from .bitsplit import BitSplit -from .brecq import BRECQ -from .lapq import LAPQ -from .methods import ( - UniformAffineQuantizer, - AdaRoundQuantizer, - BitSplitQuantizer, - ActQuantizer, - QuantizationBase, - UniformQuantization, - ClippedUniformQuantization, - FixedClipValueQuantization, - MaxAbsStaticQuantization, - LearnedStepSizeQuantization, - LpNormQuantization, -) -from .qmodel import ( - QuantBasicBlock, - QuantBottleneck, - QuantInvertedResidual, - QuantModule, - BaseQuantBlock, - QBasicBlock, - QBottleneck, - QInvertedResidual, - ActivationModuleWrapper, - ParameterModuleWrapper, -) -from .quantize import ( - BaseQuantization, - StraightThrough, - RoundSTE, - Conv2dFunctor, - LinearFunctor, - FoldBN, -) -from .reconstruct import ( - StopForwardException, - DataSaverHook, - GetLayerInpOut, - save_inp_oup_data, - GradSaverHook, - GetLayerGrad, - save_grad_data, - LinearTempDecay, - LayerLossFunction, - layer_reconstruction, - BlockLossFunction, - block_reconstruction, -) + +from . import quantize +from . import lapq +from . import bitsplit +from . 
import brecq \ No newline at end of file diff --git a/trailmet/algorithms/quantize/_methods.py b/trailmet/algorithms/quantize/_methods.py new file mode 100644 index 0000000..6c61b5f --- /dev/null +++ b/trailmet/algorithms/quantize/_methods.py @@ -0,0 +1,105 @@ +import torch +import torch.nn as nn +from typing import Dict, Callable +from trailmet.algorithms.quantize.observers import BaseObserver, MinMaxObserver, LpNormObserver +from trailmet.algorithms.quantize.utils import reshape_qparams_by_channel + + + +OBSERVER_MAPPING: Dict[str, Callable] = { + 'min_max': MinMaxObserver, + 'lp_norm': LpNormObserver +} + + +class RoundSTE(torch.autograd.Function): + """grad enabled round function""" + @staticmethod + def forward(ctx, input): + return torch.round(input) + + @staticmethod + def backward(ctx, grad_output): + return grad_output + + +class FloorSTE(torch.autograd.Function): + """grad enabled floor function""" + @staticmethod + def forward(ctx, input): + return torch.floor(input) + + @staticmethod + def backward(ctx, grad_output): + return grad_output + + +class BaseQuantizer(nn.Module): + def __init__(self, kwargs: dict): + self.observer: BaseObserver = OBSERVER_MAPPING[kwargs.get( + 'observer', 'min_max')](**kwargs) + self.quant_min = self.observer.quant_min + self.quant_max = self.observer.quant_max + self.per_channel = kwargs.get('per_channel', False) + self.ch_axis = kwargs.get('ch_axis', 0) + self.enable_observation = True + self.enable_quantization = True + + def __register_buffer__(self, name, value): + if hasattr(self, name): + delattr(self, name) + self.register_buffer(name, value) + + def __register_parameter__(self, name, value): + if hasattr(self, name): + delattr(self, name) + self.register_parameter(name, nn.Parameter(value)) + + def quantize(self, x: torch.Tensor, scale: torch.Tensor, zero_point: torch.Tensor, + round_mode: str = 'nearest'): + if self.per_channel: + scale, zero_point = reshape_qparams_by_channel( + x, scale, zero_point, self.ch_axis) + if round_mode == 'nearest': + x_int = RoundSTE.apply(x / scale) + elif round_mode == 'stochastic': + x_floor = FloorSTE.apply(x / scale) + x_int = x_floor + torch.bernoulli((x / scale) - x_floor) + else: + raise NotImplementedError + x_quant = torch.clamp(x_int + zero_point, self.quant_min, self.quant_max) + return x_quant + + def dequantize(self, x_quant: torch.Tensor, scale: torch.Tensor, zero_point: torch.Tensor): + x_dequant = (x_quant - zero_point) * scale + return x_dequant + + def reset_bitwidth(self, n_bits: int): + self.observer.reset_bitwidth(n_bits) + self.quant_min = self.observer.quant_min + self.quant_max = self.observer.quant_max + + +class UniformQuantizer(BaseQuantizer): + def __init__(self, kwargs: dict): + super().__init__(kwargs) + self.__register_buffer__('scale', torch.tensor([1.0], dtype=torch.float)) + self.__register_buffer__('zero_point', torch.tensor([0], dtype=torch.int)) + + def forward(self, x: torch.Tensor): + if self.enable_observation: + x = self.observer(x) + + if self.enable_quantization: + self.scale, self.zero_point = self.observer.calculate_qparams() + self.scale, self.zero_point = self.scale.to(x.device), self.zero_point.to(x.device) + x_quant = self.quantize(x, self.scale, self.zero_point) + x_dequant = self.dequantize(x_quant, self.scale, self.zero_point) + return x_dequant + + return x + + +class AdaRoundQuantizer(BaseQuantizer): + def __init__(self, kwargs: dict): + super().__init__(kwargs) \ No newline at end of file diff --git 
a/trailmet/algorithms/quantize/assets/quantization_pipeline.png b/trailmet/algorithms/quantize/assets/quantization_pipeline.png new file mode 100644 index 0000000..28f3de9 Binary files /dev/null and b/trailmet/algorithms/quantize/assets/quantization_pipeline.png differ diff --git a/trailmet/algorithms/quantize/assets/quantizer_flow.png b/trailmet/algorithms/quantize/assets/quantizer_flow.png new file mode 100644 index 0000000..724d6e2 Binary files /dev/null and b/trailmet/algorithms/quantize/assets/quantizer_flow.png differ diff --git a/trailmet/algorithms/quantize/bitsplit.py b/trailmet/algorithms/quantize/bitsplit.py index 19c1206..fa5300e 100644 --- a/trailmet/algorithms/quantize/bitsplit.py +++ b/trailmet/algorithms/quantize/bitsplit.py @@ -19,113 +19,81 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. + + +import os, time import copy, random, pickle +import numpy as np import torch import torch.nn as nn from collections import OrderedDict +from tqdm import tqdm from trailmet.utils import seed_everything from trailmet.algorithms.quantize.quantize import BaseQuantization from trailmet.models.resnet import BasicBlock, Bottleneck from trailmet.models.mobilenet import InvertedResidual -from trailmet.algorithms.quantize.qmodel import ( - QBasicBlock, - QBottleneck, - QInvertedResidual, -) +from trailmet.algorithms.quantize.modules import QBasicBlock, QBottleneck, QInvertedResidual from trailmet.algorithms.quantize.methods import BitSplitQuantizer, ActQuantizer -import logging -from datetime import datetime -from tqdm import tqdm -import wandb -import pandas as pd -import numpy as np -import os -import time - -from trailmet.utils import AverageMeter, accuracy, save_checkpoint - -logger = logging.getLogger(__name__) global feat, prev_feat, conv_feat - - def hook(module, input, output): global feat feat = output.data.cpu().numpy() - - def current_input_hook(module, inputdata, outputdata): global prev_feat - prev_feat = inputdata[0].data - - + prev_feat = inputdata[0].data#.cpu()#.numpy() def conv_hook(module, inputdata, outputdata): global conv_feat - conv_feat = outputdata.data - + conv_feat = outputdata.data#.cpu()#.numpy() class QuantModel(nn.Module): - """ - Parameters - ---------- - model (nn.Module): Model to be used. - arch (str): Architecture to be used. 
- """ - def __init__(self, model: nn.Module, arch='ResNet50'): super().__init__() self.supported = { - BasicBlock: QBasicBlock, - Bottleneck: QBottleneck, - InvertedResidual: QInvertedResidual, + BasicBlock : QBasicBlock, + Bottleneck : QBottleneck, + InvertedResidual : QInvertedResidual, } - if arch == 'ResNet50': + if arch=='ResNet50': setattr(model, 'quant', ActQuantizer()) setattr(model, 'fc', nn.Sequential(ActQuantizer(), model.fc)) - if arch == 'MobileNetV2': + if arch=='MobileNetV2': + # setattr(model, 'quant1', ActQuantizer()) + # setattr(model, 'quant2', ActQuantizer()) setattr(model, 'conv2', nn.Sequential(ActQuantizer(), model.conv2)) - setattr(model, 'linear', nn.Sequential(ActQuantizer(), - model.linear)) + setattr(model, 'linear', nn.Sequential(ActQuantizer(), model.linear)) self.quant_block_refactor(model) def quant_block_refactor(self, module: nn.Module): - """Recursively modify the supported conv-blocks to add activation - quantization layers :param module: nn.Module with supported conv-block - classes in its children.""" + """ + Recursively modify the supported conv-blocks to add activation quantization layers + :param module: nn.Module with supported conv-block classes in its children + """ for name, child_module in module.named_children(): if type(child_module) in self.supported: - setattr(module, name, - self.supported[type(child_module)](child_module)) - elif isinstance(child_module, - (nn.Conv2d, nn.Linear, nn.ReLU, nn.ReLU6)): + setattr(module, name, self.supported[type(child_module)](child_module)) + elif isinstance(child_module, (nn.Conv2d, nn.Linear, nn.ReLU, nn.ReLU6)): continue - else: - self.quant_block_refactor(child_module) - + else: self.quant_block_refactor(child_module) class BitSplit(BaseQuantization): """ - Class for post-training quantization using bit-split and stitching method - based on - Towards accurate post-training network quantization via + Class for post-training quantization using bit-split and stitching method + based on - Towards accurate post-training network quantization via bit-split and stitching [https://dl.acm.org/doi/abs/10.5555/3524938.3525851] - Parameters - ---------- - model (nn.Module): Model to be used - dataloaders (dict): Dictionary with dataloaders for train, test, val - W_BITS: bitwidth for weight quantization - A_BITS: bitwidth for activation quantization - CHANNEL_WISE: apply channel-wise quantization for weights - ACT_QUANT: apply activation quantization - HEAD_STEM_PRECISION: bitwidth for first and last layer - PREC_CONFIG: list of bitwidths of the body for mixed precision - CALIB_BATCHES: num of batches in calibration dataset - LOAD_ACT_SCALES: load precomputed weight scales - LOAD_WEIGHT_SCALES: load precomputed activation scales - SAVE_PATH: path for storing quantized weights and scales + :param W_BITS: bitwidth for weight quantization + :param A_BITS: bitwidth for activation quantization + :param CHANNEL_WISE: apply channel-wise quantization for weights + :param ACT_QUANT: apply activation quantization + :param HEAD_STEM_PRECISION: bitwidth for first and last layer + :param PREC_CONFIG: list of bitwidths of the body for mixed precision + :param CALIB_BATCHES: num of batches in calibration dataset + :param LOAD_ACT_SCALES: load precomputed weight scales + :param LOAD_WEIGHT_SCALES: load precomputed activation scales + :param SAVE_PATH: path for storing quantized weights and scales """ - def __init__(self, model: nn.Module, dataloaders, **kwargs): super(BitSplit, self).__init__(**kwargs) self.model = model @@ -143,10 
+111,9 @@ def __init__(self, model: nn.Module, dataloaders, **kwargs): self.dataset = self.kwargs.get('DATASET', '') self.precision_config = self.kwargs.get('PREC_CONFIG', []) if self.precision_config: - w_prefix = str(self.precision_config[0]) + '_mix' - else: - w_prefix = str(self.w_bits) - self.prefix = self.save_path + self.arch + '_' + self.dataset + '/W' + w_prefix + w_prefix = str(self.precision_config[0])+'_mix' + else: w_prefix = str(self.w_bits) + self.prefix = self.save_path+self.arch+'_'+self.dataset+'/W'+w_prefix if not os.path.exists(self.prefix): os.makedirs(self.prefix) self.load_act_scales = self.kwargs.get('LOAD_ACT_SCALES', False) @@ -155,33 +122,7 @@ def __init__(self, model: nn.Module, dataloaders, **kwargs): self.act_quant = self.kwargs.get('ACT_QUANT', True) self.head_stem_precision = self.kwargs.get('HEAD_STEM_PRECISION', None) - self.wandb_monitor = self.kwargs.get('WANDB', 'False') - self.dataset_name = dataloaders['train'].dataset.__class__.__name__ - self.save = './checkpoints/' - - self.name = '_'.join([ - self.dataset_name, - str(self.a_bits), - datetime.now().strftime('%b-%d_%H:%M:%S'), - ]) - - os.makedirs(f'{os.getcwd()}/logs/BitSplit', exist_ok=True) - os.makedirs(self.save, exist_ok=True) - self.logger_file = f'{os.getcwd()}/logs/BitSplit/{self.name}.log' - - logging.basicConfig( - filename=self.logger_file, - format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', - datefmt='%m/%d/%Y %H:%M:%S', - level=logging.INFO, - ) - - logger.info(f'Experiment Arguments: {self.kwargs}') - - if self.wandb_monitor: - wandb.init(project='Trailmet BitSplit', name=self.name) - wandb.config.update(self.kwargs) - + def compress_model(self): self.model.to(self.device) self.qmodel = copy.deepcopy(self.model) @@ -195,378 +136,200 @@ def compress_model(self): self.act_quant_modules[-1].set_bitwidth(max(8, self.a_bits)) assert self.arch in ['MobileNetV2', 'ResNet50'] - print('==> Starting weight quantization') - logger.info('==> Starting weight quantization') + print("==> Starting weight quantization") self.weight_quantizer(load_only=self.load_weight_scales) if self.act_quant: if self.load_act_scales: - scales = np.load(self.prefix + '/act_' + str(self.a_bits) + - '_scales.npy') + scales = np.load(self.prefix+'/act_'+str(self.a_bits)+'_scales.npy') for index, q_module in enumerate(self.act_quant_modules): q_module.set_scale(scales[index]) else: - print("==> Starting '{}-bit' activation quantization".format( - self.a_bits)) - logger.info( - "==> Starting '{}-bit' activation quantization".format( - self.a_bits)) - self.act_quantizer(self.qmodel, - prefix=self.prefix, - n_batches=self.calib_batches) - - save_checkpoint( - { - 'state_dict': self.qmodel.module.state_dict(), - }, - is_best=False, - save=self.save, - ) - # save_state_dict(self.qmodel.state_dict(), self.prefix, filename='state_dict.pth') - - print('testing quantized model') - logger.info('testing quantized model') - - val_top1_acc_list = [] - val_top5_acc_list = [] - - valid_loss, valid_top1_acc, valid_top5_acc = self.test( - self.qmodel, self.dataloaders['val'], nn.CrossEntropyLoss()) - val_top1_acc_list.append(valid_top1_acc.cpu().numpy()) - val_top5_acc_list.append(valid_top5_acc.cpu().numpy()) - - df_data = np.array([ - val_top1_acc_list, - val_top5_acc_list, - ]).T - df = pd.DataFrame( - df_data, - columns=[ - 'Validation Top1', - 'Validation Top5', - ], - ) - df.to_csv( - f'{os.getcwd()}/logs/BitSplit/{self.name}.csv', - index=False, - ) - - # TODO : Use functions to process submodules of respective models 
so that adding new models in future is easier + print("==> Starting '{}-bit' activation quantization".format(self.a_bits)) + self.act_quantizer(self.qmodel, prefix=self.prefix, n_batches=self.calib_batches) + save_state_dict(self.qmodel.state_dict(), self.prefix, filename='state_dict.pth') + return self.qmodel + + +# TODO : Use functions to process submodules of respective models so that adding new models in future is easier def weight_quantizer(self, load_only=False): - """Find optimum weight quantization scales for ResNet & Mobilenet.""" + """ + Find optimum weight quantization scales for ResNet & Mobilenet + """ #### Quantizer for MobilenetV2 #### - if self.arch == 'MobileNetV2': + if self.arch=='MobileNetV2': count = 3 for i in range(len(self.model.layers)): - if len(self.model.layers[i].shortcut) > 0: - count += 4 - else: - count += 3 + if len(self.model.layers[i].shortcut)>0: count+=4 + else: count+=3 pbar = tqdm(total=count) - layer_to_block = [ - 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 3, 3, 4 - ] - assert (len(self.precision_config) == 0 or len( - self.precision_config) == 7), 'config list must be of length 7' + layer_to_block = [1,1,1,1,1,1,1,2,2,2,2,2,2,2,3,3,4] + assert len(self.precision_config)==0 or len(self.precision_config)==7, 'config list must be of length 7' # quantize first conv layer conv = self.model.conv1 conv_quan = self.qmodel.conv1 w_bit = self.w_bits - if self.head_stem_precision is not None: - w_bit = self.head_stem_precision - if self.precision_config: - w_bit = self.precision_config[0] + if self.head_stem_precision is not None: w_bit = self.head_stem_precision + if self.precision_config: w_bit = self.precision_config[0] if w_bit == 32: conv_quan.weight.data.copy_(conv.weight.data) - else: - if not load_only: - conduct_ofwa( - self.train_loader, - self.model, - self.qmodel, - conv, - conv_quan, - w_bit, - self.calib_batches, - prefix=self.prefix + '/conv1', - device=self.device, - ec=False, - ) - load_ofwa(conv, conv_quan, prefix=self.prefix + '/conv1') + else: + if not load_only: conduct_ofwa(self.train_loader, self.model, self.qmodel, conv, conv_quan, w_bit, + self.calib_batches, prefix=self.prefix+'/conv1', device=self.device, ec=False) + load_ofwa(conv, conv_quan, prefix=self.prefix+'/conv1') pbar.update(1) - time.sleep(0.1) + time.sleep(.1) # quantize blocks for layer_idx in range(len(self.model.layers)): current_layer_pretrained = self.model.layers[layer_idx] current_layer_quan = self.qmodel.layers[layer_idx] - w_bit = (self.precision_config[layer_to_block[layer_idx]] - if self.precision_config else self.w_bits) - skip = w_bit == 32 - pkl_path = self.prefix + '/layer' + str(layer_idx) + w_bit = self.precision_config[layer_to_block[layer_idx]] if self.precision_config else self.w_bits + skip = (w_bit==32) + pkl_path = self.prefix+'/layer'+str(layer_idx) # conv layers - for idx in range(1, 4): + for idx in range(1,4): conv = eval('current_layer_pretrained.conv{}'.format(idx)) conv_quan = eval('current_layer_quan.conv{}'.format(idx)) if skip: conv_quan.weight.data.copy_(conv.weight.data) - else: - if not load_only: - conduct_ofwa( - self.train_loader, - self.model, - self.qmodel, - conv, - conv_quan, - w_bit, - self.calib_batches, - prefix=pkl_path + '_conv{}'.format(idx), - device=self.device, - dw=(idx == 2), - ec=False, - ) - load_ofwa(conv, - conv_quan, - prefix=pkl_path + '_conv' + str(idx)) + else: + if not load_only: conduct_ofwa(self.train_loader, self.model, self.qmodel, conv, conv_quan, + w_bit, self.calib_batches, 
prefix=pkl_path+'_conv{}'.format(idx), + device=self.device, dw=(idx==2), ec=False) + load_ofwa(conv, conv_quan, prefix=pkl_path+'_conv'+str(idx)) pbar.update(1) - time.sleep(0.1) + time.sleep(.1) # shortcut layer - if len(current_layer_pretrained.shortcut) > 0: + if len(current_layer_pretrained.shortcut)>0: conv = current_layer_pretrained.shortcut[0] conv_quan = current_layer_quan.shortcut[0] if skip: conv_quan.weight.data.copy_(conv.weight.data) - else: - if not load_only: - conduct_ofwa( - self.train_loader, - self.model, - self.qmodel, - conv, - conv_quan, - w_bit, - self.calib_batches, - prefix=pkl_path + '_shortcut'.format(idx), - device=self.device, - ec=False, - ) - load_ofwa(conv, - conv_quan, - prefix=pkl_path + '_shortcut') + else: + if not load_only: conduct_ofwa(self.train_loader, self.model, self.qmodel, conv, conv_quan, w_bit, + self.calib_batches, prefix=pkl_path+'_shortcut'.format(idx), device=self.device, ec=False) + load_ofwa(conv, conv_quan, prefix=pkl_path+'_shortcut') pbar.update(1) - time.sleep(0.1) + time.sleep(.1) # quantize last conv layer conv = self.model.conv2 conv_quan = self.qmodel.conv2[1] - if self.precision_config: - w_bit = self.precision_config[-2] - if w_bit == 32: + if self.precision_config: w_bit = self.precision_config[-2] + if w_bit==32: conv_quan.weight.data.copy_(conv.weight.data) - else: - if not load_only: - conduct_ofwa( - self.train_loader, - self.model, - self.qmodel, - conv, - conv_quan, - w_bit, - self.calib_batches, - prefix=self.prefix + '/conv2', - device=self.device, - ec=False, - ) - load_ofwa(conv, conv_quan, prefix=self.prefix + '/conv2') + else: + if not load_only: conduct_ofwa(self.train_loader, self.model, self.qmodel, conv, conv_quan, w_bit, + self.calib_batches, prefix=self.prefix+'/conv2', device=self.device, ec=False) + load_ofwa(conv, conv_quan, prefix=self.prefix+'/conv2') pbar.update(1) - time.sleep(0.1) + time.sleep(.1) # quantize last linear layer conv = self.model.linear conv_quan = self.qmodel.linear[1] w_bit = self.w_bits - if self.head_stem_precision is not None: - w_bit = self.head_stem_precision - if self.precision_config: - w_bit = self.precision_config[-1] + if self.head_stem_precision is not None: w_bit = self.head_stem_precision + if self.precision_config: w_bit = self.precision_config[-1] if w_bit == 32: conv_quan.weight.data.copy_(conv.weight.data) - else: - if not load_only: - conduct_ofwa( - self.train_loader, - self.model, - self.qmodel, - conv, - conv_quan, - w_bit, - self.calib_batches, - prefix=self.prefix + '/linear', - device=self.device, - ec=False, - ) - load_ofwa(conv, conv_quan, prefix=self.prefix + '/linear') + else: + if not load_only: conduct_ofwa(self.train_loader, self.model, self.qmodel, conv, conv_quan, w_bit, + self.calib_batches, prefix=self.prefix+'/linear', device=self.device, ec=False) + load_ofwa(conv, conv_quan, prefix=self.prefix+'/linear') pbar.update(1) pbar.close() - #### Quantizer for Resnet50 #### - elif self.arch == 'ResNet50': + #### Quantizer for Resnet50 #### + elif self.arch=='ResNet50': count = 2 - for i in range(1, 5): + for i in range(1,5): layer = eval('self.model.layer{}'.format(i)) for j in range(len(layer)): - count += 3 - if layer[j].downsample is not None: - count += 1 + count+=3 + if(layer[j].downsample is not None): count+=1 pbar = tqdm(total=count) - # quantize first conv layer + # quantize first conv layer conv = self.model.conv1 conv_quan = self.qmodel.conv1 w_bit = self.w_bits - if self.head_stem_precision is not None: - w_bit = self.head_stem_precision - if 
self.precision_config: - w_bit = self.precision_config[0] - if w_bit == 32: + if self.head_stem_precision is not None: w_bit = self.head_stem_precision + if self.precision_config: w_bit = self.precision_config[0] + if w_bit==32: conv_quan.weight.data.copy_(conv.weight.data) - else: - if not load_only: - conduct_ofwa( - self.train_loader, - self.model, - self.qmodel, - conv, - conv_quan, - w_bit, - self.calib_batches, - prefix=self.prefix + '/conv1', - device=self.device, - ec=False, - ) - load_ofwa(conv, conv_quan, prefix=self.prefix + '/conv1') + else: + if not load_only: conduct_ofwa(self.train_loader, self.model, self.qmodel, conv, conv_quan, w_bit, + self.calib_batches, prefix=self.prefix+'/conv1', device=self.device, ec=False) + load_ofwa(conv, conv_quan, prefix=self.prefix+'/conv1') pbar.update(1) - time.sleep(0.1) - # quantize blocks + time.sleep(.1) + # quantize blocks for layer_idx in range(1, 5): - current_layer_pretrained = eval( - 'self.model.layer{}'.format(layer_idx)) - current_layer_quan = eval( - 'self.qmodel.layer{}'.format(layer_idx)) - w_bit = (self.precision_config[layer_idx] - if self.precision_config else self.w_bits) - skip = w_bit == 32 + current_layer_pretrained = eval('self.model.layer{}'.format(layer_idx)) + current_layer_quan = eval('self.qmodel.layer{}'.format(layer_idx)) + w_bit = self.precision_config[layer_idx] if self.precision_config else self.w_bits + skip = w_bit==32 for block_idx in range(len(current_layer_pretrained)): - current_block_pretrained = current_layer_pretrained[ - block_idx] + current_block_pretrained = current_layer_pretrained[block_idx] current_block_quan = current_layer_quan[block_idx] - pkl_path = (self.prefix + '/layer' + str(layer_idx) + - '_block' + str(block_idx)) + pkl_path = self.prefix+'/layer'+str(layer_idx)+'_block'+str(block_idx) # conv layers for idx in range(1, 4): - conv = eval( - 'current_block_pretrained.conv{}'.format(idx)) - conv_quan = eval( - 'current_block_quan.conv{}'.format(idx)) + conv = eval('current_block_pretrained.conv{}'.format(idx)) + conv_quan = eval('current_block_quan.conv{}'.format(idx)) if skip: conv_quan.weight.data.copy_(conv.weight.data) - else: - if not load_only: - conduct_ofwa( - self.train_loader, - self.model, - self.qmodel, - conv, - conv_quan, - w_bit, - self.calib_batches, - prefix=pkl_path + '_conv{}'.format(idx), - device=self.device, - ec=False, - ) - load_ofwa(conv, - conv_quan, - prefix=pkl_path + '_conv' + str(idx)) + else: + if not load_only: conduct_ofwa(self.train_loader, self.model, self.qmodel, conv, conv_quan, w_bit, + self.calib_batches, prefix=pkl_path+'_conv{}'.format(idx), device=self.device, ec=False) + load_ofwa(conv, conv_quan, prefix=pkl_path+'_conv'+str(idx)) pbar.update(1) - time.sleep(0.1) + time.sleep(.1) # downsample if current_block_pretrained.downsample is not None: conv = current_block_pretrained.downsample[0] conv_quan = current_block_quan.downsample[0] if skip: conv_quan.weight.data.copy_(conv.weight.data) - else: - if not load_only: - conduct_ofwa( - self.train_loader, - self.model, - self.qmodel, - conv, - conv_quan, - w_bit, - self.calib_batches, - prefix=pkl_path + '_downsample', - device=self.device, - ec=False, - ) - load_ofwa(conv, - conv_quan, - prefix=pkl_path + '_downsample') + else: + if not load_only: conduct_ofwa(self.train_loader, self.model, self.qmodel, conv, conv_quan, w_bit, + self.calib_batches, prefix=pkl_path+'_downsample', device=self.device, ec=False) + load_ofwa(conv, conv_quan, prefix=pkl_path+'_downsample') pbar.update(1) - 
time.sleep(0.1) + time.sleep(.1) # quantize last fc layer conv = self.model.fc conv_quan = self.qmodel.fc[1] w_bit = self.w_bits - if self.head_stem_precision is not None: - w_bit = self.head_stem_precision - if self.precision_config: - w_bit = self.precision_config[-1] - if w_bit == 32: + if self.head_stem_precision is not None: w_bit = self.head_stem_precision + if self.precision_config: w_bit = self.precision_config[-1] + if w_bit==32: conv_quan.weight.data.copy_(conv.weight.data) - else: - if not load_only: - conduct_ofwa( - self.train_loader, - self.model, - self.qmodel, - conv, - conv_quan, - w_bit, - self.calib_batches, - prefix=self.prefix + '/fc', - device=self.device, - ec=False, - ) - load_ofwa(conv, conv_quan, prefix=self.prefix + '/fc') + else: + if not load_only: conduct_ofwa(self.train_loader, self.model, self.qmodel, conv, conv_quan, w_bit, + self.calib_batches, prefix=self.prefix+'/fc', device=self.device, ec=False) + load_ofwa(conv, conv_quan, prefix=self.prefix+'/fc') pbar.update(1) pbar.close() - else: - raise NotImplementedError + else: raise NotImplementedError - # TODO : Write this in a more cleaner way +# TODO : Write this in a more cleaner way def act_quantizer(self, model, prefix, n_batches): - """Find optimum activation quantization scale for ResNet model based on - feature map.""" - + """ + Find optimum activation quantization scale for ResNet model based on feature map + """ # train_batches = iter(self.train_loader) # per_batch = len(next(train_batches)[1]) # act_sta_len = (n_batches+1)*per_batch def get_safe_len(x): - x /= 10 - y = 1 - while x >= 10: - x /= 10 - y *= 10 - return int(y) - + x/=10 + y=1 + while(x>=10): + x/=10 + y*=10 + return int(y) act_sta_len = 3000000 feat_buf = np.zeros(act_sta_len) scales = np.zeros(len(self.act_quant_modules)) - - pbar = tqdm( - self.act_quant_modules, - desc= - 'Activation quantization, q_module [X] (X / X Steps) (prev_layer_scale=X.X)', - bar_format='{l_bar}{r_bar}', - dynamic_ncols=True, - disable=False, - ) + + pbar = tqdm(self.act_quant_modules, total=len(self.act_quant_modules)) with torch.no_grad(): for index, q_module in enumerate(pbar): batch_iterator = iter(self.train_loader) @@ -578,138 +341,58 @@ def get_safe_len(x): model(images) feat_len = feat.size per_batch = min(get_safe_len(feat_len), 100000) - n_batches = int(act_sta_len / per_batch) + n_batches = int(act_sta_len/per_batch) repeat = True - while repeat: + while(repeat): repeat = False for batch_idx in range(0, n_batches): - pbar.set_description( - 'Activation quantization, q_module [%d] (%d / %d Steps) (prev_layer_scale=%2.5f)' - % ( - index, - batch_idx + 1, - n_batches, - scales[index - 1], - )) + pbar.set_postfix(batch=f'{batch_idx+1}/{n_batches}', prev_layer_scale=scales[index-1]) images, targets = next(batch_iterator) - images = images.cuda(device=self.device, - non_blocking=True) + images = images.cuda(device=self.device, non_blocking=True) model(images) if q_module.signed: feat_tmp = np.abs(feat).reshape(-1) else: - feat_tmp = feat[feat > 0].reshape(-1) + feat_tmp = feat[feat>0].reshape(-1) if feat_tmp.size < per_batch: - per_batch = int(per_batch / 10) - n_batches = int(n_batches * 10) + per_batch = int(per_batch/10) + n_batches = int(n_batches*10) repeat = True break np.random.shuffle(feat_tmp) - feat_buf[batch_idx * per_batch:(batch_idx + 1) * - per_batch] = feat_tmp[0:per_batch] - if not repeat: + feat_buf[batch_idx*per_batch:(batch_idx+1)*per_batch] = feat_tmp[0:per_batch] + if(not repeat): scales[index] = 
q_module.init_quantization(feat_buf) handle.remove() + # for batch_idx in range(0, n_batches): + # images, targets = next(batch_iterator) + # images = images.cuda(device=self.device, non_blocking=True) + # model(images) + # if q_module.signed: + # feat_tmp = np.abs(feat).reshape(-1) + # else: + # feat_tmp = feat[feat>0].reshape(-1) + # np.random.shuffle(feat_tmp) + # feat_buf[batch_idx*per_batch:(batch_idx+1)*per_batch] = feat_tmp[0:per_batch] + + # scales[index] = q_module.init_quantization(feat_buf) + # pbar.set_postfix(curr_layer_scale=scales[index]) + # np.save(os.path.join(prefix, 'act_'+str(self.a_bits)+'_scales.npy'), scales) + # handle.remove() + pbar.close() - np.save( - os.path.join(prefix, 'act_' + str(self.a_bits) + '_scales.npy'), - scales) + np.save(os.path.join(prefix, 'act_' + str(self.a_bits) + '_scales.npy'), scales) for index, q_module in enumerate(self.act_quant_modules): q_module.set_scale(scales[index]) - def test(self, model, dataloader, loss_fn): - batch_time = AverageMeter('Time', ':6.3f') - losses = AverageMeter('Loss', ':.4e') - top1 = AverageMeter('Acc@1', ':6.2f') - top5 = AverageMeter('Acc@5', ':6.2f') - - epoch_iterator = tqdm( - dataloader, - desc= - 'Validating network (X / X Steps) (batch time=X.Xs) (loss=X.X) (top1=X.X) (top5=X.X)', - bar_format='{l_bar}{r_bar}', - dynamic_ncols=True, - disable=False, - ) - model.eval() - model.to(self.device) - - with torch.no_grad(): - end = time.time() - - for i, (images, labels) in enumerate(epoch_iterator): - images = images.to(self.device, dtype=torch.float) - labels = labels.to(self.device) - - preds = model(images) - - loss = loss_fn(preds, labels) - - pred1, pred5 = accuracy(preds, labels, topk=(1, 5)) - - n = images.size(0) - losses.update(loss.item(), n) - top1.update(pred1[0], n) - top5.update(pred5[0], n) - - # measure elapsed time - batch_time.update(time.time() - end) - end = time.time() - - epoch_iterator.set_description( - 'Validating network (%d / %d Steps) (batch time=%2.5fs) (loss=%2.5f) (top1=%2.5f) (top5=%2.5f)' - % ( - (i + 1), - len(dataloader), - batch_time.val, - losses.val, - top1.val, - top5.val, - )) - - logger.info( - 'Validating network (%d / %d Steps) (batch time=%2.5fs) (loss=%2.5f) (top1=%2.5f) (top5=%2.5f)' - % ( - (i + 1), - len(dataloader), - batch_time.val, - losses.val, - top1.val, - top5.val, - )) - - if self.wandb_monitor: - wandb.log({ - 'val_loss': losses.val, - 'val_top1_acc': top1.val, - 'val_top5_acc': top5.val, - }) - - print(' * acc@1 {top1.avg:.3f} acc@5 {top5.avg:.3f}'.format( - top1=top1, top5=top5)) - return losses.avg, top1.avg, top5.avg - - -def conduct_ofwa( - train_loader, - model_pretrained, - model_quan, - conv, - conv_quan, - bitwidth, - n_batches, - device, - num_epochs=100, - prefix=None, - dw=False, - ec=False, -): +def conduct_ofwa(train_loader, model_pretrained, model_quan, conv, conv_quan, + bitwidth, n_batches, device, num_epochs=100, prefix=None, dw=False, ec=False): # for fc if not hasattr(conv, 'kernel_size'): - W = conv.weight.data # .cpu() + W = conv.weight.data#.cpu() W_shape = W.shape B_sav, B, alpha = BitSplitQuantizer(W.cpu().numpy(), bitwidth).ofwa() # B_sav, B, alpha = ofwa(W.cpu().numpy(), bitwidth) @@ -731,11 +414,11 @@ def conduct_ofwa( batch_iterator = iter(train_loader) # weights and bias - W = conv.weight.data # .cpu() + W = conv.weight.data#.cpu() if conv.bias is None: bias = torch.zeros(W.shape[0]).to(conv.weight.device) else: - bias = conv.bias.data # .cpu() + bias = conv.bias.data#.cpu() # feat extract per_batch = 400 @@ -747,9 +430,8 
@@ def conduct_ofwa( [prev_feat_n, prev_feat_c, prev_feat_h, prev_feat_w] = prev_feat.shape [conv_feat_n, conv_feat_c, conv_feat_h, conv_feat_w] = conv_feat.shape - X = torch.zeros(n_batches * per_batch, prev_feat_c, kernel_h, - kernel_w).to(device) - Y = torch.zeros(n_batches * per_batch, conv_feat_c).to(device) + X = torch.zeros(n_batches*per_batch, prev_feat_c, kernel_h, kernel_w).to(device) + Y = torch.zeros(n_batches*per_batch, conv_feat_c).to(device) for batch_idx in range(0, n_batches): input, target = next(batch_iterator) @@ -757,52 +439,34 @@ def conduct_ofwa( model_pretrained(input_pretrained) input_quan = input.cuda(device=device, non_blocking=True) model_quan(input_quan) - - prev_feat_pad = torch.zeros(prev_feat_n, prev_feat_c, - prev_feat_h + 2 * pad_h, - prev_feat_w + 2 * pad_w).to(device) - prev_feat_pad[:, :, pad_h:pad_h + prev_feat_h, - pad_w:pad_w + prev_feat_w] = prev_feat - prev_feat_pad = (prev_feat_pad.unfold(2, kernel_h, stride_h).unfold( - 3, kernel_w, stride_w).permute(0, 2, 3, 1, 4, 5)) - [ - feat_pad_n, - feat_pad_h, - feat_pad_w, - feat_pad_c, - feat_pad_hh, - feat_pad_ww, - ] = prev_feat_pad.shape - assert feat_pad_hh == kernel_h - assert feat_pad_ww == kernel_w - - prev_feat_pad = prev_feat_pad.reshape( - feat_pad_n * feat_pad_h * feat_pad_w, feat_pad_c, kernel_h, - kernel_w) + + prev_feat_pad = torch.zeros(prev_feat_n, prev_feat_c, prev_feat_h+2*pad_h, prev_feat_w+2*pad_w).to(device) + prev_feat_pad[:, :, pad_h:pad_h+prev_feat_h, pad_w:pad_w+prev_feat_w] = prev_feat + prev_feat_pad = prev_feat_pad.unfold(2, kernel_h, stride_h).unfold(3, kernel_w, stride_w).permute(0,2,3,1,4,5) + [feat_pad_n, feat_pad_h, feat_pad_w, feat_pad_c, feat_pad_hh, feat_pad_ww] = prev_feat_pad.shape + assert(feat_pad_hh==kernel_h) + assert(feat_pad_ww==kernel_w) + + prev_feat_pad = prev_feat_pad.reshape(feat_pad_n*feat_pad_h*feat_pad_w, feat_pad_c, kernel_h, kernel_w) rand_index = list(range(prev_feat_pad.shape[0])) random.shuffle(rand_index) rand_index = rand_index[0:per_batch] - X[per_batch * batch_idx:per_batch * - (batch_idx + 1), :] = prev_feat_pad[rand_index, :] - conv_feat_tmp = conv_feat.permute(0, 2, 3, 1).reshape( - -1, conv_feat_c) - bias - Y[per_batch * batch_idx:per_batch * - (batch_idx + 1), :] = conv_feat_tmp[rand_index, :] - + X[per_batch*batch_idx:per_batch*(batch_idx+1),:] = prev_feat_pad[rand_index, :] + conv_feat_tmp = conv_feat.permute(0,2,3,1).reshape(-1, conv_feat_c) - bias + Y[per_batch*batch_idx:per_batch*(batch_idx+1),:] = conv_feat_tmp[rand_index, :] + handle_prev.remove() handle_conv.remove() - + ## ofwa init W_shape = W.shape X = X.cpu().numpy() Y = Y.cpu().numpy() W = W.reshape(W_shape[0], -1) if dw: - B, alpha = BitSplitQuantizer(W.cpu().numpy(), - bitwidth).ofwa_rr_dw(X, Y, num_epochs) - else: - B, alpha = BitSplitQuantizer(W.cpu().numpy(), - bitwidth).ofwa_rr(X, Y, num_epochs) + B, alpha = BitSplitQuantizer(W.cpu().numpy(), bitwidth).ofwa_rr_dw(X, Y, num_epochs) + else: + B, alpha = BitSplitQuantizer(W.cpu().numpy(), bitwidth).ofwa_rr(X, Y, num_epochs) with open(prefix + '_rr_b30x400_e100.pkl', 'wb') as f: pickle.dump({'B': B, 'alpha': alpha}, f, pickle.HIGHEST_PROTOCOL) @@ -810,7 +474,7 @@ def conduct_ofwa( def load_ofwa(conv, conv_quan, prefix=None): # for fc if not hasattr(conv, 'kernel_size'): - W = conv.weight.data # .cpu() + W = conv.weight.data#.cpu() W_shape = W.shape with open(prefix + '_fwa.pkl', 'rb') as f: B_alpha = pickle.load(f) @@ -821,7 +485,7 @@ def load_ofwa(conv, conv_quan, prefix=None): return # weights and bias - W = 
conv.weight.data # .cpu() + W = conv.weight.data#.cpu() W_shape = W.shape with open(prefix + '_rr_b30x400_e100.pkl', 'rb') as f: @@ -837,8 +501,7 @@ def save_state_dict(state_dict, path, filename='state_dict.pth'): new_state_dict = OrderedDict() for key in state_dict.keys(): if '.module.' in key: - new_state_dict[key.replace('.module.', - '.')] = state_dict[key].cpu() + new_state_dict[key.replace('.module.', '.')] = state_dict[key].cpu() else: new_state_dict[key] = state_dict[key].cpu() - torch.save(new_state_dict, saved_path) + torch.save(new_state_dict, saved_path) \ No newline at end of file diff --git a/trailmet/algorithms/quantize/brecq.py b/trailmet/algorithms/quantize/brecq.py index f94c64c..59d4c08 100644 --- a/trailmet/algorithms/quantize/brecq.py +++ b/trailmet/algorithms/quantize/brecq.py @@ -19,75 +19,69 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -import copy + import torch import torch.nn as nn -import torch.distributed as dist +from typing import Union from trailmet.utils import seed_everything -from trailmet.algorithms.quantize.quantize import ( - BaseQuantization, - FoldBN, - StraightThrough, -) -from trailmet.models.resnet import BasicBlock, Bottleneck -from trailmet.models.mobilenet import InvertedResidual -from trailmet.algorithms.quantize.qmodel import ( - QuantBasicBlock, - QuantBottleneck, - QuantInvertedResidual, - QuantModule, - BaseQuantBlock, -) -from trailmet.algorithms.quantize.reconstruct import ( - layer_reconstruction, - block_reconstruction, -) - -import logging -from datetime import datetime -from tqdm import tqdm -import wandb -import pandas as pd -import numpy as np -import os -import time - -from trailmet.utils import AverageMeter, accuracy, save_checkpoint - -logger = logging.getLogger(__name__) - -supported = { - BasicBlock: QuantBasicBlock, - Bottleneck: QuantBottleneck, - InvertedResidual: QuantInvertedResidual, -} +from trailmet.algorithms.quantize.quantize import BaseQuantization, BaseQuantModel +from trailmet.algorithms.quantize.quantize import GetLayerGrad, GetLayerInpOut, BaseQuantLoss +from trailmet.algorithms.quantize.modules import StraightThrough, QuantModule, BaseQuantBlock +from trailmet.algorithms.quantize.methods import UniformAffineQuantizer, AdaRoundQuantizer + + +class QuantModel(BaseQuantModel): + def __init__(self, model: nn.Module, weight_quant_params: dict, act_quant_params: dict): + super(QuantModel, self).__init__(model, weight_quant_params, act_quant_params, fold_bn=True) + + def reset_scale_method(self, scale_method = 'mse', act_quant_reset = False): + for module in self.quant_modules: + module.weight_quantizer.scale_method = scale_method + module.weight_quantizer.inited = False + if act_quant_reset: + module.act_quantizer.scale_method = scale_method + module.act_quantizer.inited = False + + def set_head_stem_precision(self, bitwidth): + """ + Set the precision (bitwidth) for weights and activations for the first and last + layers of the model. Also ignore reconstruction for the first layer. 
+ """ + assert len(self.quant_modules) >= 2, 'Model has less than 2 quantization modules' + self.quant_modules[0].weight_quantizer.bitwidth_refactor(bitwidth) + self.quant_modules[0].act_quantizer.bitwidth_refactor(bitwidth) + self.quant_modules[-1].weight_quantizer.bitwidth_refactor(bitwidth) + self.quant_modules[-2].act_quantizer.bitwidth_refactor(bitwidth) + self.quant_modules[0].ignore_reconstruction = True + def disable_network_output_quantization(self): + """ + Disable Network Output Quantization + """ + self.quant_modules[-1].disable_act_quant = True + + class BRECQ(BaseQuantization): """ - Class for post-training quantization using block reconstruction method - based on - BRECQ: PUSHING THE LIMIT OF POST-TRAINING QUANTIZATION + Class for post-training quantization using block reconstruction method + based on - BRECQ: PUSHING THE LIMIT OF POST-TRAINING QUANTIZATION BY BLOCK RECONSTRUCTION [https://arxiv.org/abs/2102.05426] - Parameters - ---------- - model (nn.Module): Model to be used - dataloaders (dict): Dictionary with dataloaders for train, test, val - W_BITS: bitwidth for weight quantization - A_BITS: bitwidth for activation quantization - CHANNEL_WISE: apply channel_wise quantization for weights - ACT_QUANT: apply activation quantization - SET_8BIT_HEAD_STEM: Set the first and the last layer to 8-bit - NUM_SAMPLES: size of calibration dataset - WEIGHT: weight of rounding cost vs the reconstruction loss - ITERS_W: number of iteration for AdaRound - ITERS_A: number of iteration for LSQ - LR: learning rate for LSQ + :param W_BITS: bitwidth for weight quantization + :param A_BITS: bitwidth for activation quantization + :param CHANNEL_WISE: apply channel_wise quantization for weights + :param ACT_QUANT: apply activation quantization + :param SET_8BIT_HEAD_STEM: Set the first and the last layer to 8-bit + :param NUM_SAMPLES: size of calibration dataset + :param WEIGHT: weight of rounding cost vs the reconstruction loss + :param ITERS_W: number of iteration for AdaRound + :param ITERS_A: number of iteration for LSQ + :param LR: learning rate for LSQ """ - def __init__(self, model: nn.Module, dataloaders, **kwargs): super(BRECQ, self).__init__(**kwargs) - self.model = copy.deepcopy(model) + self.model = model self.train_loader = dataloaders['train'] self.test_loader = dataloaders['test'] self.kwargs = kwargs @@ -96,399 +90,271 @@ def __init__(self, model: nn.Module, dataloaders, **kwargs): self.channel_wise = self.kwargs.get('CHANNEL_WISE', True) self.act_quant = self.kwargs.get('ACT_QUANT', True) self.set_8bit_head_stem = self.kwargs.get('SET_8BIT_HEAD_STEM', False) - self.precision_config = self.kwargs.get('PREC_CONFIG', []) + self.w_budget = self.kwargs.get('W_BUDGET', None) + self.use_bits = self.kwargs.get('USE_BITS', [2,4,8]) + self.arch = self.kwargs.get('ARCH', '') + self.save_path = self.kwargs.get('SAVE_PATH', './runs/') self.num_samples = self.kwargs.get('NUM_SAMPLES', 1024) - self.weight = self.kwargs.get('WEIGHT', 0.01) + self.scale_method = self.kwargs.get('SCALE_METHOD', 'mse') + self.iters_w = self.kwargs.get('ITERS_W', 10000) self.iters_a = self.kwargs.get('ITERS_A', 10000) - self.optimizer = self.kwargs.get('OPTIMIZER', 'adam') - self.lr = self.kwargs.get('LR', 4e-4) + self.optim = self.kwargs.get('OPTIMIZER', torch.optim.adam) + self.weight = self.kwargs.get('WEIGHT', 0.01) + self.lr = self.kwargs.get('LR', 4e-5) + self.p = self.kwargs.get('P_VAL', 2.4) # Lp norm minimization for LSQ + self.gpu_id = self.kwargs.get('GPU_ID', 0) - self.calib_bs = 
self.kwargs.get('CALIB_BS', 64) + self.batch_size = self.kwargs.get('BATCH_SIZE', 64) self.seed = self.kwargs.get('SEED', 42) - self.p = 2.4 # Lp norm minimization for LSQ - self.b_start = 20 # temperature at the beginning of calibration - self.b_end = 2 # temperature at the end of calibration + self.b_start = 20 # temperature at the beginning of calibration + self.b_end = 2 # temperature at the end of calibration self.test_before_calibration = True self.device = torch.device('cuda:{}'.format(self.gpu_id)) torch.cuda.set_device(self.gpu_id) + self.calib_data = self.get_calib_samples(self.train_loader, self.num_samples) seed_everything(self.seed) - print('==> Using seed :', self.seed) - - self.wandb_monitor = self.kwargs.get('WANDB', 'False') - self.dataset_name = dataloaders['train'].dataset.__class__.__name__ - self.save = './checkpoints/' - - self.name = '_'.join([ - self.dataset_name, - f'{self.a_bits}', - f'{self.lr}', - datetime.now().strftime('%b-%d_%H:%M:%S'), - ]) - - os.makedirs(f'{os.getcwd()}/logs/BRECQ', exist_ok=True) - os.makedirs(self.save, exist_ok=True) - self.logger_file = f'{os.getcwd()}/logs/BRECQ/{self.name}.log' - - logging.basicConfig( - filename=self.logger_file, - format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', - datefmt='%m/%d/%Y %H:%M:%S', - level=logging.INFO, - ) - - logger.info(f'Experiment Arguments: {self.kwargs}') + print('==> Using seed :',self.seed) - if self.wandb_monitor: - wandb.init(project='Trailmet BRECQ', name=self.name) - wandb.config.update(self.kwargs) def compress_model(self): - """Method to build quantization parameters and finetune weights and/or - activations.""" - wq_params = { - 'n_bits': self.w_bits, - 'channel_wise': self.channel_wise, - 'scale_method': 'mse', + """ + method to build quantization parameters and finetune weights and/or activations + """ + self.model.to(self.device) + self.model.eval() + weight_quant_params = { + 'n_bits': self.w_bits, + 'channel_wise': self.channel_wise, + 'method': UniformAffineQuantizer, + 'scale_method': self.scale_method, } - aq_params = { - 'n_bits': self.a_bits, - 'channel_wise': False, - 'scale_method': 'mse', + act_quant_params = { + 'n_bits': self.a_bits, + 'channel_wise': False, + 'method': UniformAffineQuantizer, + 'scale_method': self.scale_method, 'leaf_param': self.act_quant, } - self.model = self.model.to(self.device) - self.model.eval() - self.qnn = QuantModel(model=self.model, - weight_quant_params=wq_params, - act_quant_params=aq_params) - self.qnn = self.qnn.to(self.device) + self.qnn = QuantModel(self.model, weight_quant_params, act_quant_params) + self.qnn.to(self.device) self.qnn.eval() - for i in range(len(self.precision_config)): - conf = self.precision_config[i] - self.qnn.set_layer_precision(conf[2], conf[3], conf[0], conf[1]) - print( - f'==> Layers from {conf[0]} to {conf[1]} set to precision w{conf[2]}a{conf[3]}' - ) - logger.info( - f'==> Layers from {conf[0]} to {conf[1]} set to precision w{conf[2]}a{conf[3]}' - ) - + w_compr = self.w_bits/32 if self.w_budget is None else self.w_budget + if self.w_budget is not None: + w_bits, qm_size, max_size = self.sensitivity_analysis( + self.qnn, self.test_loader, self.use_bits, self.w_budget, + self.save_path, '{}_{}_{}'.format(self.arch, w_compr, self.a_bits)) + print('==> Found optimal config for approx model size: {:.2f} MB ' \ + ' (orig {:.2f} MB)'.format(qm_size, max_size/self.w_budget)) + self.qnn.set_layer_precision(w_bits, self.a_bits) + self.qnn.reset_scale_method(self.scale_method, True) + if 
self.set_8bit_head_stem: print('==> Setting the first and the last layer to 8-bit') - logger.info('==> Setting the first and the last layer to 8-bit') - self.qnn.set_first_last_layer_to_8bit() - - self.cali_data = self.get_calib_samples(self.train_loader, - self.num_samples) - # device = next(self.qnn.parameters()).device + self.qnn.set_head_stem_precision(8) - # Initialize weight quantization parameters self.qnn.set_quant_state(True, False) print('==> Initializing weight quantization parameters') - logger.info('==> Initializing weight quantization parameters') - _ = self.qnn(self.cali_data[:self.calib_bs].to(self.device)) + _ = self.qnn(self.calib_data[:self.batch_size].to(self.device)) if self.test_before_calibration: - valid_loss, valid_top1_acc, valid_top5_acc = self.test( - self.qnn, self.test_loader, nn.CrossEntropyLoss()) - print('Quantized accuracy before brecq: {}'.format(valid_top1_acc)) - logger.info( - 'Quantized accuracy before brecq: {}'.format(valid_top1_acc)) - - # Start weight calibration + print('Quantized accuracy before brecq: {}'.format(self.test(self.qnn, self.test_loader, device=self.device))) + + # Start quantized weight calibration kwargs = dict( - cali_data=self.cali_data, - iters=self.iters_w, - weight=self.weight, + iters=self.iters_w, + opt_mode='mse', + act_quant=False, asym=True, - b_range=(self.b_start, self.b_end), - warmup=0.2, - act_quant=False, - opt_mode='mse', - optim=self.optimizer, + b_range=(self.b_start, self.b_end), + warmup=0.2, ) - print('==> Starting weight calibration') - logger.info('==> Starting weight calibration') + print('==> Starting quantized-weight rounding parameter (alpha) calibration') self.reconstruct_model(self.qnn, **kwargs) self.qnn.set_quant_state(weight_quant=True, act_quant=False) - valid_loss, valid_top1_acc, valid_top5_acc = self.test( - self.qnn, self.test_loader, nn.CrossEntropyLoss()) - print('Weight quantization accuracy: {}'.format(valid_top1_acc)) + print('Weight quantization accuracy: {}'.format(self.test(self.qnn, self.test_loader, device=self.device))) if self.act_quant: # Initialize activation quantization parameters self.qnn.set_quant_state(True, True) with torch.no_grad(): - _ = self.qnn(self.cali_data[:self.calib_bs].to(self.device)) - - # Disable output quantization because network output - # does not get involved in further computation + _ = self.qnn(self.calib_data[:self.calib_bs].to(self.device)) self.qnn.disable_network_output_quantization() - + # Start activation rounding calibration kwargs = dict( - cali_data=self.cali_data, - iters=self.iters_a, - act_quant=True, - opt_mode='mse', - lr=self.lr, - p=self.p, - optim=self.optimizer, + iters=self.iters_a, + opt_mode='mse', + act_quant=True, ) + print('==> Starting quantized-activation scaling parameter (delta) calibration') self.reconstruct_model(self.qnn, **kwargs) self.qnn.set_quant_state(weight_quant=True, act_quant=True) - valid_loss, valid_top1_acc, valid_top5_acc = self.test( - self.qnn, self.test_loader, nn.CrossEntropyLoss()) - print('Full quantization (W{}A{}) accuracy: {}'.format( - self.w_bits, self.a_bits, valid_top1_acc)) - logger.info('Full quantization (W{}A{}) accuracy: {}'.format( - self.w_bits, self.a_bits, valid_top1_acc)) + # torch.save(self.qnn.state_dict(), f'{self.save_path}/weights/{self.arch}_{w_compr}_{self.a_bits}.pth') + print('Full quantization (W{}A{}) accuracy: {}'.format(w_compr, self.a_bits, + self.test(self.qnn, self.test_loader, device=self.device))) return self.qnn - def reconstruct_model(self, model: nn.Module, 
**kwargs): - """Method for model parameters reconstruction. - Takes in quantized model and optimizes weights by applying layer-wise - reconstruction for first and last layer, and block reconstruction - otherwise. + def reconstruct_model(self, module: nn.Module, **kwargs): """ - for name, module in model.named_children(): - if isinstance(module, QuantModule): - if module.ignore_reconstruction is True: + Method for model parameters reconstruction. Takes in quantized model + and optimizes weights by applying layer-wise reconstruction for first + and last layer, and block reconstruction otherwise. + """ + for name, child_module in module.named_children(): + if isinstance(child_module, QuantModule): + if child_module.ignore_reconstruction is True: print('Ignore reconstruction of layer {}'.format(name)) - logger.info( - 'Ignore reconstruction of layer {}'.format(name)) continue else: print('Reconstruction for layer {}'.format(name)) - logger.info('Reconstruction for layer {}'.format(name)) - layer_reconstruction(self.qnn, module, **kwargs) - elif isinstance(module, BaseQuantBlock): - if module.ignore_reconstruction is True: - print('Ignore reconstruction of block {}'.format(name)) - logger.info( - 'Ignore reconstruction of block {}'.format(name)) + self.reconstruct_module(self.qnn, child_module, **kwargs) + elif isinstance(child_module, BaseQuantBlock): + if child_module.ignore_reconstruction is True: + print('Ignore reconstruction of {} block {}'.format(self._parent_name, name)) continue else: - print('Reconstruction for block {}'.format(name)) - logger.info('Reconstruction for block {}'.format(name)) - block_reconstruction(self.qnn, module, **kwargs) - else: - self.reconstruct_model(module, **kwargs) - - def test(self, model, dataloader, loss_fn): - batch_time = AverageMeter('Time', ':6.3f') - losses = AverageMeter('Loss', ':.4e') - top1 = AverageMeter('Acc@1', ':6.2f') - top5 = AverageMeter('Acc@5', ':6.2f') - - epoch_iterator = tqdm( - dataloader, - desc= - 'Validating network (X / X Steps) (batch time=X.Xs) (loss=X.X) (top1=X.X) (top5=X.X)', - bar_format='{l_bar}{r_bar}', - dynamic_ncols=True, - disable=False, - ) - - model.eval() - model.to(self.device) - - with torch.no_grad(): - end = time.time() - - for i, (images, labels) in enumerate(epoch_iterator): - images = images.to(self.device, dtype=torch.float) - labels = labels.to(self.device) - - preds = model(images) - - loss = loss_fn(preds, labels) - - pred1, pred5 = accuracy(preds, labels, topk=(1, 5)) - - n = images.size(0) - losses.update(loss.item(), n) - top1.update(pred1[0], n) - top5.update(pred5[0], n) - - # measure elapsed time - batch_time.update(time.time() - end) - end = time.time() - - epoch_iterator.set_description( - 'Validating network (%d / %d Steps) (batch time=%2.5fs) (loss=%2.5f) (top1=%2.5f) (top5=%2.5f)' - % ( - (i + 1), - len(dataloader), - batch_time.val, - losses.val, - top1.val, - top5.val, - )) - - logger.info( - 'Validating network (%d / %d Steps) (batch time=%2.5fs) (loss=%2.5f) (top1=%2.5f) (top5=%2.5f)' - % ( - (i + 1), - len(dataloader), - batch_time.val, - losses.val, - top1.val, - top5.val, - )) - - if self.wandb_monitor: - wandb.log({ - 'val_loss': losses.val, - 'val_top1_acc': top1.val, - 'val_top5_acc': top5.val, - }) - - print(' * acc@1 {top1.avg:.3f} acc@5 {top5.avg:.3f}'.format( - top1=top1, top5=top5)) - return losses.avg, top1.avg, top5.avg - - -class QuantModel(nn.Module): - """Recursively replace the normal conv2d and Linear layer to QuantModule, - to enable calculating activation statistics and 
storing scaling factors. - - Parameters - ---------- - model (nn.Module): nn.Module with nn.Conv2d or nn.Linear in its children - weight_quant_params (dict): quantization parameters like n_bits for weight - quantizer - act_quant_params(dict): quantization parameters like n_bits for activation - quantizer - """ - - def __init__( - self, - model: nn.Module, - weight_quant_params: dict = {}, - act_quant_params: dict = {}, - ): - super().__init__() - self.model = model - bn = FoldBN() - bn.search_fold_and_remove_bn(self.model) - self.quant_module_refactor(self.model, weight_quant_params, - act_quant_params) - self.quant_modules = [ - m for m in self.model.modules() if isinstance(m, QuantModule) - ] - - def quant_module_refactor( - self, - module: nn.Module, - weight_quant_params: dict = {}, - act_quant_params: dict = {}, - ): - prev_quantmodule = None - for name, child_module in module.named_children(): - if type(child_module) in supported: - setattr( - module, - name, - supported[type(child_module)](child_module, - weight_quant_params, - act_quant_params), - ) - - elif isinstance(child_module, (nn.Conv2d, nn.Linear)): - setattr( - module, - name, - QuantModule(child_module, weight_quant_params, - act_quant_params), - ) - prev_quantmodule = getattr(module, name) - - elif isinstance(child_module, (nn.ReLU, nn.ReLU6)): - if prev_quantmodule is not None: - prev_quantmodule.activation_function = child_module - setattr(module, name, StraightThrough()) - else: - continue - - elif isinstance(child_module, StraightThrough): - continue - + print('Reconstruction for {} block {}'.format(self._parent_name, name)) + self.reconstruct_module(self.qnn, child_module, **kwargs) else: - self.quant_module_refactor(child_module, weight_quant_params, - act_quant_params) - - def set_quant_state(self, - weight_quant: bool = False, - act_quant: bool = False): + self._parent_name = name + self.reconstruct_model(child_module, **kwargs) + + + def reconstruct_module(self, + model: BaseQuantModel, module: Union[QuantModule, BaseQuantBlock], + iters: int = 10000, opt_mode: str = 'mse', act_quant: bool = False, + asym: bool = False, include_act_func: bool = True, b_range: tuple = (20, 2), + warmup: float = 0.0): + + model.set_quant_state(False, False) + module.set_quant_state(True, act_quant) + round_mode = 'learned_hard_sigmoid' + opt_params = [] + + if not include_act_func: + org_act_func = module.activation_function + module.activation_function = StraightThrough() + + if not act_quant: + # Replace weight quantizer to AdaRoundQuantizer and learn alpha + if isinstance(module, QuantModule): + module.weight_quantizer = AdaRoundQuantizer( + uaq = module.weight_quantizer, round_mode = round_mode, + weight_tensor = module.org_weight.data) + module.weight_quantizer.soft_targets = True + opt_params.append(module.weight_quantizer.alpha) + + if isinstance(module, BaseQuantBlock): + for name, submodule in module.named_modules(): + if isinstance(submodule, QuantModule): + submodule.weight_quantizer = AdaRoundQuantizer( + uaq = submodule.weight_quantizer, round_mode = round_mode, + weight_tensor = submodule.org_weight.data) + submodule.weight_quantizer.soft_targets = True + opt_params.append(submodule.weight_quantizer.alpha) + + optimizer = self.optim(opt_params) + scheduler = None + else: + # Use UniformAffineQuantizer to learn delta for activations + if hasattr(module.act_quantizer, 'delta'): + opt_params.append(module.act_quantizer.delta) + if isinstance(module, BaseQuantBlock): + for name, submodule in module.named_modules(): + if 
isinstance(submodule, QuantModule) and submodule.act_quantizer.delta is not None: + opt_params.append(submodule.act_quantizer.delta) + + optimizer = self.optim(opt_params, lr = self.lr) + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=iters, eta_min=0.) + + + loss_mode = 'none' if act_quant else 'relaxation' + # rec_loss = opt_mode + loss_func = BaseQuantLoss( + module, round_loss=loss_mode, weight=self.weight, max_count=iters, + rec_loss=opt_mode, b_range=b_range, decay_start=0, warmup=warmup, p=self.p) + + # Save data before optimizing the rounding + cached_inps, cached_outs = self.save_inp_oup_data(model, module, asym, act_quant) + if opt_mode != 'mse': + cached_grads = self.save_grad_data(model, module, act_quant) + else: + cached_grads = None + + for i in range(iters): + idx = torch.randperm(cached_inps.size(0))[:self.batch_size] + cur_inp = cached_inps[idx].to(self.device) + cur_out = cached_outs[idx].to(self.device) + cur_grad = cached_grads[idx].to(self.device) if opt_mode != 'mse' else None + + optimizer.zero_grad() + out_quant = module(cur_inp) + + err = loss_func(out_quant, cur_out, cur_grad) + err.backward(retain_graph=True) + + optimizer.step() + if scheduler: + scheduler.step() + + torch.cuda.empty_cache() + + # Finish optimization, use hard rounding. + if isinstance(module, QuantModule): + module.weight_quantizer.soft_targets = False + if isinstance(module, BaseQuantBlock): + for name, submodule in module.named_modules(): + if isinstance(submodule, QuantModule): + submodule.weight_quantizer.soft_targets = False + + # Reset original activation function + if not include_act_func: + module.activation_function = org_act_func + + + def save_inp_oup_data(self, model, layer: Union[QuantModule, BaseQuantBlock], + asym: bool = False, act_quant: bool = False): """ - :param weight_quant: set True for weight quantization - :param act_quant: set True for activation quantization + Function to save input data and output data of a particular layer/block over calibration dataset. """ - for m in self.model.modules(): - if isinstance(m, (QuantModule, BaseQuantBlock)): - m.set_quant_state(weight_quant, act_quant) + get_inp_out = GetLayerInpOut(model, layer, device=self.device, asym=asym, act_quant=act_quant) + cached_batches = [] + torch.cuda.empty_cache() - def quantize_model_till(self, layer, act_quant: bool = False): - """ - :param layer: block/layer upto which model is to be quantized. - :param act_quant: set True for activation quantization - """ - self.set_quant_state(False, False) - for name, module in self.model.named_modules(): - if isinstance(module, (QuantModule, BaseQuantBlock)): - module.set_quant_state(True, act_quant) - if module == layer: - break + for i in range(int(self.calib_data.size(0) / self.batch_size)): + cur_inp, cur_out = get_inp_out(self.calib_data[i * self.batch_size:(i + 1) * self.batch_size]) + cached_batches.append((cur_inp.cpu(), cur_out.cpu())) - def forward(self, input): - return self.model(input) + cached_inps = torch.cat([x[0] for x in cached_batches]) + cached_outs = torch.cat([x[1] for x in cached_batches]) + torch.cuda.empty_cache() - def set_first_last_layer_to_8bit(self): - """Set the precision (bitwidth) used for quantizing weights and - activations to 8-bit for the first and last layers of the model. + cached_inps = cached_inps.to(self.device) + cached_outs = cached_outs.to(self.device) + return cached_inps, cached_outs + - Also ignore reconstruction for the first layer. 
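Note on the caching step above: save_inp_oup_data relies on the GetLayerInpOut helper imported from trailmet.algorithms.quantize.quantize, whose body is not shown in this excerpt. Helpers of this kind typically capture a block's input and output with a temporary forward hook. The following is a minimal sketch of that pattern only, using a hypothetical capture_inp_out function on plain nn.Module objects; it is not the actual trailmet helper, which additionally toggles quantization state for asymmetric reconstruction.

import torch
import torch.nn as nn

def capture_inp_out(model: nn.Module, layer: nn.Module, batch: torch.Tensor):
    # Run one forward pass of `model` and record the tensors seen by `layer`.
    cache = {}

    def hook(module, inputs, output):
        cache['inp'] = inputs[0].detach()
        cache['out'] = output.detach()

    handle = layer.register_forward_hook(hook)
    with torch.no_grad():
        model(batch)
    handle.remove()
    return cache['inp'], cache['out']

Batches captured this way can then be concatenated over the calibration set, which is exactly the role the cached_inps / cached_outs tensors play in the reconstruction loop above.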
+ def save_grad_data(self, model: QuantModel, layer: Union[QuantModule, BaseQuantBlock], + act_quant: bool = False): """ - assert (len(self.quant_modules) - >= 2), 'Model has less than 2 quantization modules' - self.quant_modules[0].weight_quantizer.bitwidth_refactor(8) - self.quant_modules[0].act_quantizer.bitwidth_refactor(8) - self.quant_modules[-1].weight_quantizer.bitwidth_refactor(8) - self.quant_modules[-2].act_quantizer.bitwidth_refactor(8) - self.quant_modules[0].ignore_reconstruction = True - - def disable_network_output_quantization(self): - self.quant_modules[-1].disable_act_quant = True - - def set_layer_precision(self, weight_bit=8, act_bit=8, start=0, end=None): - """Set the precision (bitwidth) used for quantizing weights and - activations for a range of layers in the model. - - :param weight_bit: number of bits to use for quantizing weights - :param act_bit: number of bits to use for quantizing activations - :param start: index of the first layer to set the precision for - (default: 0) - :param end: index of the last layer to set the precision for (default: - None, i.e., the last layer) + Function to save gradient data of a particular layer/block over calibration dataset. """ - assert start >= 0 and end >= 0, 'layer index cannot be negative' - assert start < len(self.quant_modules) and end < len( - self.quant_modules), 'layer index out of range' - - for module in self.quant_modules[start:end + 1]: - module.weight_quantizer.bitwidth_refactor(weight_bit) - if module is not self.quant_modules[-1]: - module.act_quantizer.bitwidth_refactor(act_bit) - - def synchorize_activation_statistics(self): - """Synchronize the statistics of the activation quantizers across all - distributed workers.""" - for m in self.modules(): - if isinstance(m, QuantModule): - if m.act_quantizer.delta is not None: - m.act_quantizer.delta.data /= dist.get_world_size() - dist.all_reduce(m.act_quantizer.delta.data) + get_grad = GetLayerGrad(model, layer, self.device, act_quant=act_quant) + cached_batches = [] + torch.cuda.empty_cache() + + for i in range(int(self.calib_data.size(0) / self.batch_size)): + cur_grad = get_grad(self.calib_data[i * self.batch_size:(i + 1) * self.batch_size]) + cached_batches.append(cur_grad.cpu()) + + cached_grads = torch.cat([x for x in cached_batches]) + cached_grads = cached_grads.abs() + 1.0 + # scaling to make sure its mean is 1 + # cached_grads = cached_grads * torch.sqrt(cached_grads.numel() / cached_grads.pow(2).sum()) + torch.cuda.empty_cache() + + cached_grads = cached_grads.to(self.device) + return cached_grads \ No newline at end of file diff --git a/trailmet/algorithms/quantize/info.md b/trailmet/algorithms/quantize/info.md new file mode 100644 index 0000000..9fc278f --- /dev/null +++ b/trailmet/algorithms/quantize/info.md @@ -0,0 +1,17 @@ +![quantization](./assets/quantization_pipeline.png) + +- The following quantization configuration schema has been adopted throughout all quantizable modules to ensure compatibility with `torch.ao.quantization`. + - **qscheme**: enum[per_tensor_affine, per_tensor_symmetric, per_channel_affine, per_channel_symmetric] + - **dtype**: enum[qint8, quint8, qint32] + - **scale**: type[float] + - **zero_point**: type[int] + - **quant_min**: type[int] + - **quant_max**: type[int] + + for each quantized module `qscheme` is set according to the quantization algorithm and module under consideration, then `quant_min` and `quant_max` are determined based on the compression requirement and approximate precision importance for the given module. 
`dtype` is set such that it satisfies the given `quant_min` and `quant_max` range. `scale` and `zero_point` are finally determined dynamically during calibration using the applied algorithm. + +- current implementation for deployment (x86 cpu only) stores model weights in `int8`/`int32` along with scale (`float32`) and zero_point (`int64`), effectively reducing required memory space (by about a factor of 4 for layer-wise granularity, and slightly lower for channel-wise granularity). Activations are quantized to `uint8` based on scaling factors determined during calibration (static during inference). + +- when using the `x86 backend`, we need to use 7 bits instead of 8 bits. Make sure you reduce the range for the `quant_min`, `quant_max`, ie. if dtype is `torch.quint8`, we need to set `quant_min` to be 0 and `quant_max` to be 127 (255 / 2) and if dtype is `torch.qint8`, make sure to set `quant_min` to be -64 (-128 / 2) and `quant_max` to be 63 (127 / 2). This functionality is implemented and can be enabled by setting the configuration argument `reduce_range` to True. However, no need for this in `qnnpack backend`. + +![quantizer](./assets/quantizer_flow.png) diff --git a/trailmet/algorithms/quantize/lapq.py b/trailmet/algorithms/quantize/lapq.py index cc70bd9..4d13752 100644 --- a/trailmet/algorithms/quantize/lapq.py +++ b/trailmet/algorithms/quantize/lapq.py @@ -19,460 +19,205 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -import copy + import torch import torch.nn as nn +import numpy as np import scipy.optimize as optim +from tqdm import tqdm from itertools import count from trailmet.utils import seed_everything -from trailmet.algorithms.quantize.quantize import ( - BaseQuantization, - Conv2dFunctor, - LinearFunctor, -) -from trailmet.algorithms.quantize.methods import ( - LearnedStepSizeQuantization, - FixedClipValueQuantization, -) -from trailmet.algorithms.quantize.qmodel import ( - ParameterModuleWrapper, - ActivationModuleWrapper, -) - -import logging -from datetime import datetime -from tqdm import tqdm -import wandb -import pandas as pd -import numpy as np -import os -import time - -from trailmet.utils import AverageMeter, accuracy, save_checkpoint - -logger = logging.getLogger(__name__) - +from trailmet.algorithms.quantize.quantize import BaseQuantModel, BaseQuantization +from trailmet.algorithms.quantize.modules import QuantModule, BaseQuantBlock +from trailmet.algorithms.quantize.methods import UniformSymmetricQuantizer, LpNormQuantizer + +supported = [ + "resnet", + "mobilenetv2" +] + +class QuantModel(BaseQuantModel): + def __init__(self, model: nn.Module, weight_quant_params: dict, act_quant_params: dict, + inplace=False, fuse_model=True): + super().__init__(model, weight_quant_params, act_quant_params, inplace, fuse_model) + self.weight_quantizers = [] + self.act_quantizers = [] + self.act_quantizers.append(self.model.inp_quant) + for module in self.model.modules(): + if isinstance(module, QuantModule): + self.weight_quantizers.append(module.weight_quantizer) + # if not module.disable_act_quant: + self.act_quantizers.append(module.act_quantizer) + elif isinstance(module, BaseQuantBlock): + self.act_quantizers.append(module.act_quantizer) + + def get_alphas_np(self, weight=True, act=True): + alphas = [] + quantizers = (self.weight_quantizers if weight else []) + (self.act_quantizers if act else []) + for quantizer in quantizers: + 
alphas.append(quantizer.alpha) + return torch.tensor(alphas).numpy() + + def set_alphas_np(self, alphas: np.ndarray, weight=True, act=True): + quantizers = (self.weight_quantizers if weight else []) + (self.act_quantizers if act else []) + for i, quantizer in enumerate(quantizers): + quantizer.set_params_from_alpha(alphas[i]) + class LAPQ(BaseQuantization): - """ - Parameters - ---------- - model (nn.Module): Model to be used - dataloaders (dict): Dictionary with dataloaders for train, test, val - kwargs (object): A yaml safe loaded file with information like W_BITS, A_BITS. CALIB_BATCHES, etc. - """ - - def __init__(self, model: nn.Module, dataloaders, **kwargs): - super(LAPQ, self).__init__(**kwargs) - self.model = model - self.train_loader = dataloaders['train'] - self.test_loader = dataloaders['test'] + def __init__(self, arch: str, dataloaders: dict, **kwargs): + super(LAPQ, self).__init__(dataloaders, **kwargs) + if arch not in supported: + raise ValueError(f"Network architecture '{arch}' not in supported: {supported}") self.kwargs = kwargs - self.w_bits = kwargs.get('W_BITS', 8) - self.a_bits = kwargs.get('A_BITS', 8) - self.calib_batches = kwargs.get('CALIB_BATCHES', 16) - self.act_quant = kwargs.get('ACT_QUANT', True) - self.test_before_calibration = kwargs.get('DRY_RUN', True) - self.maxiter = kwargs.get('MAX_ITER', 1) - self.maxfev = kwargs.get('MAX_FEV', 1) - self.verbose = kwargs.get('VERBOSE', True) - self.print_freq = kwargs.get('PRINT_FREQ', 20) - self.gpu_id = kwargs.get('GPU_ID', 0) - self.seed = kwargs.get('SEED', 42) + self.w_bits = kwargs.get('w_bits', 8) + self.a_bits = kwargs.get('a_bits', 8) + self.reduce_range = kwargs.get('reduce_range', True) + self.act_quant = kwargs.get('act_quant', True) + self.p_val = kwargs.get('p_val', None) + self.max_iter = kwargs.get('max_iter', 2000) + self.max_fev = kwargs.get('max_fev', 2000) + self.eval_freq = kwargs.get('eval_freq', 500) + self.verbose = kwargs.get('verbose', True) + calib_bs = kwargs.get('calib_bs', 256) + calib_size = kwargs.get('calib_size', 1024) + self.calib_data = self.get_calib_data(calib_bs, calib_size) + self.seed = kwargs.get('seed', 42) seed_everything(self.seed) - self.device = torch.device('cuda:{}'.format(self.gpu_id)) - if self.verbose: - print('==> Using seed: {} and device: cuda:{}'.format( - self.seed, self.gpu_id)) - self.calib_data = self.get_calib_samples(self.train_loader, - 64 * self.calib_batches) - self.eval_count = count(0) - self.min_loss = 1e6 - - self.wandb_monitor = self.kwargs.get('WANDB', 'False') - self.dataset_name = dataloaders['train'].dataset.__class__.__name__ - self.save = './checkpoints/' - - self.name = '_'.join([ - self.dataset_name, - f'{self.a_bits}', - datetime.now().strftime('%b-%d_%H:%M:%S'), - ]) - - os.makedirs(f'{os.getcwd()}/logs/LAPQ', exist_ok=True) - os.makedirs(self.save, exist_ok=True) - self.logger_file = f'{os.getcwd()}/logs/LAPQ/{self.name}.log' - - logging.basicConfig( - filename=self.logger_file, - format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', - datefmt='%m/%d/%Y %H:%M:%S', - level=logging.INFO, - ) - - logger.info(f'Experiment Arguments: {self.kwargs}') - - if self.wandb_monitor: - wandb.init(project='Trailmet LAPQ', name=self.name) - wandb.config.update(self.kwargs) + assert torch.cuda.is_available(), "GPU is required for calibration" + gpu_id = kwargs.get('gpu_id', 0) + self.device = torch.device('cuda:{}'.format(gpu_id)) + + + def compress_model(self, model: nn.Module, inplace: bool = False, + test_before_calibration: bool = True, + 
return_fake_quantized: bool = False, + ) -> nn.Module: + model.to(self.device) + model.eval() - def compress_model(self): - self.model.to(self.device) - self.search_absorbe_bn(self.model) - args = { - 'bit_weights': self.w_bits, - 'bit_act': self.a_bits, - 'bcorr_w': True, - 'qtype': 'lp_norm', - 'lp': 2.0, + weight_quant_params = { + 'n_bits': self.w_bits, + 'reduce_range': self.reduce_range, + 'unsigned': False, + 'symmetric': True, + 'per_channel': False, + 'quantizer': 'uniform', + 'observer': 'min_max' + } + act_quant_params = { + 'n_bits': self.a_bits, + 'reduce_range': self.reduce_range, + 'unsigned': True, + 'symmetric': False, + 'quantizer': 'uniform', + 'observer': 'min_max' } - layers = [] - layers += [ - n for n, m in self.model.named_modules() - if isinstance(m, nn.Conv2d) - ][1:-1] - if self.act_quant: - layers += [ - n for n, m in self.model.named_modules() - if isinstance(m, nn.ReLU) - ][1:-1] - layers += [ - n for n, m in self.model.named_modules() - if isinstance(m, nn.ReLU6) - ][1:-1] - - if self.test_before_calibration: - args['qtype'] = 'max_static' - cnn = copy.deepcopy(self.model) - qm = QuantModel(cnn, args, layers) - - valid_loss, valid_top1_acc, valid_top5_acc = self.test( - qm.model, self.test_loader) - - print( - '==> Quantization (W{}A{}) accuracy before LAPQ: {:.4f} | {:.4f}' - .format(self.w_bits, self.a_bits, valid_top1_acc, - valid_top5_acc)) - logger.info( - '==> Quantization (W{}A{}) accuracy before LAPQ: {:.4f} | {:.4f}' - .format(self.w_bits, self.a_bits, valid_top1_acc, - valid_top5_acc)) - del qm, cnn - - ps = np.linspace(2, 4, 10) - losses = [] - - tk1 = tqdm(ps, total=len(ps)) - - for p in tk1: - args['qtype'] = 'lp_norm' - args['lp'] = p - cnn = copy.deepcopy(self.model) - qm = QuantModel(cnn, args, layers) - loss = self.evaluate_loss(model=qm.model, device=self.device) - losses.append(loss.item()) - tk1.set_postfix(p_val=p, loss=loss.item()) - del qm, cnn - # using quadratic interpolation to approximate the optimal quantization step size ∆p∗ - z = np.polyfit(ps, losses, 2) - y = np.poly1d(z) - p_intr = y.deriv().roots[0] - - print('==> using p intr : {:.2f}'.format(p_intr)) - logger.info('==> using p intr : {:.2f}'.format(p_intr)) - args['lp'] = p_intr - quant_model = QuantModel(self.model, args, layers) - - valid_loss, valid_top1_acc, valid_top5_acc = self.test( - quant_model.model, self.test_loader) - lp_point = quant_model.get_clipping() - print( - '==> Quantization (W{}A{}) accuracy before Optimization: {:.4f} | {:.4f}' - .format(self.w_bits, self.a_bits, valid_top1_acc, valid_top5_acc)) - print('==> Loss after LpNormQuantization: {:.4f}'.format(valid_loss)) - print('==> Starting Powell Optimization') + if test_before_calibration: + qmodel = QuantModel(model, weight_quant_params, act_quant_params) + qmodel.set_quantization_state(False, False) + acc1, acc5 = self.test(qmodel, self.test_data, device=self.device, progress=True) + if self.verbose: + print('==> Full Precision Model: acc@1 {:.3f} | acc@5 {:.3f}'.format(acc1, acc5)) + qmodel.set_quantization_state(True, True) + _ = self.evaluate_loss(qmodel, self.calib_data, self.device) + qmodel.set_observation_state(False, False) + acc1, acc5 = self.test(qmodel, self.test_data, device=self.device, progress=True) + if self.verbose: + print('==> Quantization accuracy before LAPQ: acc@1 {:.3f} | acc@5 {:.3f}'.format(acc1, acc5)) + del qmodel + + weight_quant_params.update({ + 'observer': 'lp_norm', + 'p_val': self.p_val, + }) + act_quant_params.update({ + 'observer': 'lp_norm', + 'p_val': self.p_val, 
+ 'pos_dist': True, + }) + + if self.p_val is None: + p_vals = np.linspace(2,3.9,20) + losses = [] + pbar = tqdm(p_vals, total=len(p_vals)) + for p in pbar: + weight_quant_params['p_val'] = p + act_quant_params['p_val'] = p + qmodel = QuantModel(model, weight_quant_params, act_quant_params) + qmodel.set_quantization_state(True, True) + loss = self.evaluate_loss(qmodel, self.calib_data, self.device) + losses.append(loss) + pbar.set_postfix(p_val=p, loss=loss) + del qmodel + # using quadratic interpolation to approximate the optimal ∆p∗ + z = np.polyfit(p_vals, losses, 2) + y = np.poly1d(z) + self.p_val = y.deriv().roots[0] + + weight_quant_params['p_val'] = self.p_val + act_quant_params['p_val'] = self.p_val + qmodel = QuantModel(model, weight_quant_params, act_quant_params, inplace=inplace) + qmodel.set_quantization_state(True, True) + min_loss = self.evaluate_loss(qmodel, self.calib_data, self.device) + if self.verbose: + print("==> using p-val : {:.3f} with lp-loss : {:.3f}".format(self.p_val, min_loss)) - logger.info( - '==> Quantization (W{}A{}) accuracy before Optimization: {:.4f} | {:.4f}' - .format(self.w_bits, self.a_bits, valid_top1_acc, valid_top5_acc)) - logger.info( - '==> Loss after LpNormQuantization: {:.4f}'.format(valid_loss)) - logger.info('==> Starting Powell Optimization') + qmodel.set_observation_state(False, False) + acc1, acc5 = self.test(qmodel, self.test_data, device=self.device, progress=True) + if self.verbose: + print('==> Quantization accuracy before optimization: acc@1 {:.3f} | acc@5 {:.3f}'.format(acc1, acc5)) + print("==> Starting Powell Optimization") + + init_alphas = qmodel.get_alphas_np() + min_method = "Powell" + min_options = {'maxiter' : self.max_iter, 'maxfev' : self.max_fev} - min_method = 'Powell' - min_options = {'maxiter': self.maxiter, 'maxfev': self.maxfev} - init_scale = lp_point.cpu().numpy() count_iter = count(0) - + self.eval_acc = 0 + self.eval_iter = 0 + def local_search_callback(x): it = next(count_iter) - quant_model.set_clipping(x, self.device) - loss = self.evaluate_loss(quant_model.model, self.device) - if self.verbose: - print('\n==> Loss at end of iter [{}] : {:.4f}\n'.format( - it, loss.item())) + print(it) + if self.verbose and it%self.eval_freq==0: + qmodel.set_alphas_np(x) + self.eval_acc, _ = self.test(qmodel, self.test_data, device=self.device, progress=False) + self.eval_iter = it + # print('\n==> Quantization accuracy at iter [{}]: acc@1 {:.2f} | acc@5 {:.2f}\n'.format(it, acc1, acc5)) - self.pbar = tqdm(total=min(self.maxiter, self.maxfev)) + self.min_loss = 1e6 + self.pbar = tqdm(total=min(self.max_iter, self.max_fev)) res = optim.minimize( - lambda scales: self.evaluate_calibration(scales, quant_model, self. 
- device), - init_scale, - method=min_method, - options=min_options, - callback=local_search_callback, + lambda alphas: self.evaluate_calibration(alphas, qmodel), init_alphas, + method=min_method, options=min_options, callback=local_search_callback ) self.pbar.close() - scales = res.x - print('==> Layer-wise Scales :\n', scales) - logger.info(f'==> Layer-wise Scales :{scales}') - - quant_model.set_clipping(scales, self.device) - - valid_loss, valid_top1_acc, valid_top5_acc = self.test( - quant_model.model, self.test_loader) - print('==> Full quantization (W{}A{}) accuracy: {}'.format( - self.w_bits, self.a_bits, valid_top1_acc)) - - logger.info('==> Full quantization (W{}A{}) accuracy: {}'.format( - self.w_bits, self.a_bits, valid_top1_acc)) - self.qnn = copy.deepcopy(quant_model.model) - return self.qnn - - def test(self, model, dataloader, loss_fn=nn.CrossEntropyLoss()): - batch_time = AverageMeter('Time', ':6.3f') - losses = AverageMeter('Loss', ':.4e') - top1 = AverageMeter('Acc@1', ':6.2f') - top5 = AverageMeter('Acc@5', ':6.2f') - - epoch_iterator = tqdm( - dataloader, - desc= - 'Validating network (X / X Steps) (batch time=X.Xs) (loss=X.X) (top1=X.X) (top5=X.X)', - bar_format='{l_bar}{r_bar}', - dynamic_ncols=True, - disable=False, - ) - - model.eval() - model.to(self.device) - - with torch.no_grad(): - end = time.time() - - for i, (images, labels) in enumerate(epoch_iterator): - images = images.to(self.device, dtype=torch.float) - labels = labels.to(self.device) - - preds = model(images) - - loss = loss_fn(preds, labels) - - pred1, pred5 = accuracy(preds, labels, topk=(1, 5)) - - n = images.size(0) - losses.update(loss.item(), n) - top1.update(pred1[0], n) - top5.update(pred5[0], n) - - # measure elapsed time - batch_time.update(time.time() - end) - end = time.time() - - epoch_iterator.set_description( - 'Validating network (%d / %d Steps) (batch time=%2.5fs) (loss=%2.5f) (top1=%2.5f) (top5=%2.5f)' - % ( - (i + 1), - len(dataloader), - batch_time.val, - losses.val, - top1.val, - top5.val, - )) + alphas = res.x + status = res.success + if self.verbose: + print('==> Optimization completed with status:', status) + print('==> Optimized alphas :\n', alphas) + qmodel.set_alphas_np(alphas) + + acc1, acc5 = self.test(qmodel, self.test_data, device=self.device, progress=True) + if self.verbose: + print('==> Final LAPQ quantization accuracy: {:.3f} | {:.3f}'.format(acc1, acc5)) - logger.info( - 'Validating network (%d / %d Steps) (batch time=%2.5fs) (loss=%2.5f) (top1=%2.5f) (top5=%2.5f)' - % ( - (i + 1), - len(dataloader), - batch_time.val, - losses.val, - top1.val, - top5.val, - )) + if return_fake_quantized: + quantized_model = qmodel + else: + quantized_model = qmodel.convert_model_to_quantized(inplace=inplace) - if self.wandb_monitor: - wandb.log({ - 'val_loss': losses.val, - 'val_top1_acc': top1.val, - 'val_top5_acc': top5.val, - }) + return quantized_model - print(' * acc@1 {top1.avg:.3f} acc@5 {top5.avg:.3f}'.format( - top1=top1, top5=top5)) - return losses.avg, top1.avg, top5.avg - def evaluate_calibration(self, scales, QM, device): - eval_count = next(self.eval_count) - QM.set_clipping(scales, device) - loss = self.evaluate_loss(QM.model, device).item() + def evaluate_calibration(self, alphas: np.ndarray, qmodel: QuantModel): + qmodel.set_alphas_np(alphas) + loss = self.evaluate_loss(qmodel, self.calib_data, self.device) if loss < self.min_loss: self.min_loss = loss - # if self.verbose and eval_count%self.print_freq==0: - # print("==> iteration: {}, minimum loss so far: 
{:.4f}".format( - # eval_count, self.min_loss)) self.pbar.set_postfix(curr_loss=loss, min_loss=self.min_loss) self.pbar.update(1) return loss - - def evaluate_loss(self, model: nn.Module, device): - criterion = torch.nn.CrossEntropyLoss().to(device) - model.eval() - with torch.no_grad(): - if not hasattr(self, 'cal_set'): - self.cal_set = [] - for i, (images, target) in enumerate(self.train_loader): - if (i >= self.calib_batches - ): # TODO: make this robust for variable batch size - break - images = images.to(device, non_blocking=True) - target = target.to(device, non_blocking=True) - self.cal_set.append((images, target)) - res = torch.tensor([0.0]).to(device) - for i in range(len(self.cal_set)): - images, target = self.cal_set[i] - output = model(images) - loss = criterion(output, target) - res += loss - return res / len(self.cal_set) - - -class QuantModel: - """ - Parameters - ---------- - model (nn.Module): Model to be used - args (object): A yaml safe loadec file with information like bit_weights, bit_act, etc. - quantizable_layers (list): A list of the quantizable layers. - optimizer_bridge (): - """ - - def __init__(self, model, args, quantizable_layers, optimizer_bridge=None): - self.model = model - self.args = args - self.bit_weights = args['bit_weights'] - self.bit_act = args['bit_act'] - self.post_relu = True - - self.replacement_factory = { - nn.ReLU: ActivationModuleWrapper, - nn.ReLU6: ActivationModuleWrapper, - nn.Conv2d: ParameterModuleWrapper, - } - self.functor_map = { - nn.Conv2d: Conv2dFunctor, - nn.Linear: LinearFunctor, - } - self.optimizer_bridge = optimizer_bridge - - self.quantization_wrappers = [] - self.quantizable_modules = [] - self.quantizable_layers = quantizable_layers - self._pre_process_container(model) - self._create_quantization_wrappers() - self.quantization_params = LearnedStepSizeQuantization.learned_parameters( - ) - - def load_state_dict(self, state_dict): - for name, qwrapper in self.quantization_wrappers: - qwrapper.load_state_dict(state_dict) - - def freeze(self): - for n, p in self.model.named_parameters(): - # TODO: hack, make it more robust - if not np.any([qp in n for qp in self.quantization_params]): - p.requires_grad = False - - @staticmethod - def has_children(module): - try: - next(module.children()) - return True - except StopIteration: - return False - - def _create_quantization_wrappers(self): - for qm in self.quantizable_modules: - # replace module by it's wrapper - fn = (self.functor_map[type(qm.module)](qm.module) - if type(qm.module) in self.functor_map else None) - args = { - 'bits_out': self.bit_act, - 'bits_weight': self.bit_weights, - 'forward_functor': fn, - 'post_relu': self.post_relu, - 'optim_bridge': self.optimizer_bridge, - } - args.update(self.args) - if hasattr(qm, 'bn'): - args['bn'] = qm.bn - module_wrapper = self.replacement_factory[type(qm.module)]( - qm.full_name, qm.module, **args) - setattr(qm.container, qm.name, module_wrapper) - self.quantization_wrappers.append((qm.full_name, module_wrapper)) - - def _pre_process_container(self, container, prefix=''): - prev, prev_name = None, None - for name, module in container.named_children(): - # if is_bn(module) and is_absorbing(prev) and prev_name in self.quantizable_layers: - # # Pass BN module to prev module quantization wrapper for BN folding/unfolding - # self.quantizable_modules[-1].bn = module - - full_name = prefix + name - if full_name in self.quantizable_layers: - self.quantizable_modules.append( - type( - '', - (object, ), - { - 'name': name, - 'full_name': 
full_name, - 'module': module, - 'container': container, - }, - )()) - - if self.has_children(module): - # For container we call recursively - self._pre_process_container(module, full_name + '.') - - prev = module - prev_name = full_name - - def get_qwrappers(self): - return [ - qwrapper for (name, qwrapper) in self.quantization_wrappers - if qwrapper.__enabled__() - ] - - def set_clipping(self, clipping, - device): # TODO: handle device internally somehow - qwrappers = self.get_qwrappers() - for i, qwrapper in enumerate(qwrappers): - qwrapper.set_quantization( - FixedClipValueQuantization, - { - 'clip_value': clipping[i], - 'device': device - }, - ) - - def get_clipping(self): - clipping = [] - qwrappers = self.get_qwrappers() - for i, qwrapper in enumerate(qwrappers): - q = qwrapper.get_quantization() - clip_value = getattr(q, 'alpha') - clipping.append(clip_value.item()) - - return qwrappers[0].get_quantization().alpha.new_tensor(clipping) diff --git a/trailmet/algorithms/quantize/methods.py b/trailmet/algorithms/quantize/methods.py index 7858ef5..ebfacee 100644 --- a/trailmet/algorithms/quantize/methods.py +++ b/trailmet/algorithms/quantize/methods.py @@ -19,307 +19,379 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. + + import torch import torch.nn as nn import numpy as np import scipy.optimize as optim -import warnings -from trailmet.algorithms.quantize.quantize import BaseQuantization, RoundSTE -from trailmet.utils import lp_loss __all__ = [ + 'RoundSTE', + 'FloorSTE', + 'BaseQuantizer', 'UniformAffineQuantizer', 'AdaRoundQuantizer', - 'BitSplitQuantizer', - 'ActQuantizer', - 'QuantizationBase', - 'UniformQuantization', - 'ClippedUniformQuantization', - 'FixedClipValueQuantization', - 'MaxAbsStaticQuantization', - 'LearnedStepSizeQuantization', - 'LpNormQuantization', + 'UniformSymmetricQuantizer', + 'LpNormQuantizer', + 'BitSplitQuantizer' ] -"""Quantization classes:- - -[BRECQ] - - UniformAffineQuantizer - - AdaRoundQuantizer -[BitSplit] - - BitSplitQuantizer - - ActQuantizer -[LAPQ] - - QuantizationBase - - UniformQuantization - - ClippedUniformQuantization - - FixedClipValueQuantization - - MaxAbsStaticQuantization - - LearnedStepSizeQuantization - - LpNormQuantization -""" - - -class UniformAffineQuantizer(nn.Module): - """PyTorch Function that can be used for asymmetric quantization (uniform - affine quantization). 
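The RoundSTE and FloorSTE autograd functions introduced just below are straight-through estimators: the forward pass applies the non-differentiable round or floor, while the backward pass copies the incoming gradient unchanged, so quantization parameters upstream stay trainable. A small self-contained check of that behaviour follows; it mirrors the class added in this patch rather than importing it, and is illustrative only.

import torch

class RoundSTE(torch.autograd.Function):
    # forward: ordinary rounding; backward: identity (straight-through)
    @staticmethod
    def forward(ctx, x):
        return torch.round(x)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output

x = torch.tensor([0.2, 1.7, -0.6], requires_grad=True)
y = RoundSTE.apply(x).sum()
y.backward()
print(x.grad)  # tensor([1., 1., 1.]) -- the gradient passes through the rounding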
+class RoundSTE(torch.autograd.Function): + """grad enabled round function""" + @staticmethod + def forward(ctx, input): + return torch.round(input) + + @staticmethod + def backward(ctx, grad_output): + return grad_output + +class FloorSTE(torch.autograd.Function): + """grad enabled floor function""" + @staticmethod + def forward(ctx, input): + return torch.floor(input) + + @staticmethod + def backward(ctx, grad_output): + return grad_output + + +class BaseQuantizer(nn.Module): + def __init__(self, n_bits: int, reduce_range: bool, unsigned: bool, + scale, zero_point): + super(BaseQuantizer, self).__init__() + assert 2 <= n_bits <= 32, "n_bits is outside allowed range [2, 32]" + if reduce_range: # handle qint overflow in x86 backend + n_bits -= 1 + if unsigned: # use unsigned int + self.q_max = (2 ** n_bits) - 1 + self.q_min = 0 + else: + self.q_max = (2 ** (n_bits-1)) - 1 + self.q_min = -(2 ** (n_bits-1)) + 1 + self.scale = scale + self.zero_point = zero_point + self.reduce_range = reduce_range + + def __register_buffer__(self, name, value): + if hasattr(self, name): + delattr(self, name) + self.register_buffer(name, value) + + def __register_parameter__(self, name, value): + if hasattr(self, name): + delattr(self, name) + self.register_parameter(name, nn.Parameter(value)) + + def quantize(self, x: torch.Tensor, round_mode: str): + assert None not in [self.scale, self.zero_point] + if round_mode == 'nearest': + x_int = torch.round(x / self.scale) + elif round_mode == 'nearest_ste': + x_int = RoundSTE.apply(x / self.scale) + elif round_mode == 'stochastic': + x_floor = FloorSTE.apply(x / self.scale) + x_int = x_floor + torch.bernoulli((x / self.scale) - x_floor) + else: raise ValueError('wrong rounding mode') + x_quant = torch.clamp(x_int + self.zero_point, self.q_min, self.q_max) + return x_quant + + def dequantize(self, xq: torch.Tensor): + xq_float = (xq - self.zero_point) * self.scale + return xq_float + + def get_qparams(self) -> dict: + return { + "scale": self.scale, + "zero_point": self.zero_point, + "quant_max": self.q_max, + "quant_min": self.q_min, + "reduce_range": self.reduce_range + } + + + +class UniformAffineQuantizer(BaseQuantizer): + """ + PyTorch Function that can be used for asymmetric quantization (uniform affine quantization). Quantizes its argument in the forward pass, passes the gradient 'straight through' on the backward pass, ignoring the quantization that occurred. - Based on - https://arxiv.org/abs/1806.08342. - Parameters - ---------- - n_bits: number of bit for quantization - symmetric: if True, the zero_point should always be 0 - channel_wise: if True, compute scale and zero_point in each channel - scale_method: determines the quantization scale and zero point + Based on https://arxiv.org/abs/1806.08342.
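In effect, the forward pass below computes x_dq = (clamp(round(x / scale) + zero_point, q_min, q_max) - zero_point) * scale; scale and zero_point are initialized on the first call, either directly from the observed min/max range ('max') or by minimizing an Lp reconstruction error over a clipping factor alpha, as in LAPQ ('mse').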
+ :param n_bits: number of bit for quantization + :param symmetric: if True, the zero_point should always be 0 + :param channel_wise: if True, compute scale and zero_point in each channel + :param scale_method: determines the quantization scale and zero point """ - - def __init__( - self, - n_bits: int = 8, - symmetric: bool = False, - channel_wise: bool = False, - scale_method: str = 'max', - leaf_param: bool = False, - ): - super(UniformAffineQuantizer, self).__init__() - self.sym = symmetric - assert 2 <= n_bits <= 8, 'bitwidth not supported' + def __init__(self, n_bits: int = 8, unsigned: bool = False, reduce_range: bool = False, + channel_wise: bool = False, scale_method: str = 'max', leaf_param: bool = False, + inited: bool = False, **kwargs): + super(UniformAffineQuantizer, self).__init__(n_bits=n_bits, reduce_range=reduce_range, + unsigned=unsigned, scale=None, zero_point=None) + self.symmetric = False self.n_bits = n_bits - self.n_levels = 2**self.n_bits - self.delta = None - self.zero_point = None - self.inited = False - self.leaf_param = leaf_param + self.unsigned = unsigned + self.reduce_range = reduce_range self.channel_wise = channel_wise + self.leaf_param = leaf_param self.scale_method = scale_method + self.inited = inited + self.eps = torch.finfo(torch.float32).eps def forward(self, x: torch.Tensor): - if self.inited is False: + if not self.inited: + scale, zero_point = self.init_quantization_params(x, self.channel_wise) if self.leaf_param: - delta, self.zero_point = self.init_quantization_scale( - x, self.channel_wise) - self.delta = torch.nn.Parameter(delta) - # self.zero_point = torch.nn.Parameter(self.zero_point) + self.__register_parameter__('scale', scale) else: - self.delta, self.zero_point = self.init_quantization_scale( - x, self.channel_wise) + self.__register_buffer__('scale', scale) + self.__register_buffer__('zero_point', zero_point) self.inited = True - - # start quantization - # x_int = BQ.round_ste(x / self.delta) + self.zero_point - x_int = RoundSTE.apply(x / self.delta) + self.zero_point - x_quant = torch.clamp(x_int, 0, self.n_levels - 1) - x_dequant = (x_quant - self.zero_point) * self.delta + # apply fake quantization + x_quant = self.quantize(x, 'nearest_ste') + x_dequant = self.dequantize(x_quant) return x_dequant - def init_quantization_scale(self, - x: torch.Tensor, - channel_wise: bool = False): - delta, zero_point = None, None + def init_quantization_params(self, x: torch.Tensor, channel_wise = False): + scale, zero_point = None, None if channel_wise: x_clone = x.clone().detach() n_channels = x_clone.shape[0] if len(x.shape) == 4: - x_max = x_clone.abs().max(dim=-1)[0].max(dim=-1)[0].max( - dim=-1)[0] + x_max = x_clone.abs().max(dim=-1)[0].max(dim=-1)[0].max(dim=-1)[0] else: x_max = x_clone.abs().max(dim=-1)[0] - delta = x_max.clone() + scale = x_max.clone() zero_point = x_max.clone() # determine the scale and zero point channel-by-channel for c in range(n_channels): - delta[c], zero_point[c] = self.init_quantization_scale( - x_clone[c], channel_wise=False) + scale[c], zero_point[c] = self.init_quantization_params(x_clone[c], channel_wise=False) if len(x.shape) == 4: - delta = delta.view(-1, 1, 1, 1) + scale = scale.view(-1, 1, 1, 1) zero_point = zero_point.view(-1, 1, 1, 1) else: - delta = delta.view(-1, 1) + scale = scale.view(-1, 1) zero_point = zero_point.view(-1, 1) else: if 'max' in self.scale_method: - x_min = min(x.min().item(), 0) - x_max = max(x.max().item(), 0) - if 'scale' in self.scale_method: - x_min = x_min * (self.n_bits + 2) / 8 - 
x_max = x_max * (self.n_bits + 2) / 8 - - x_absmax = max(abs(x_min), x_max) - if self.sym: - x_min, x_max = -x_absmax if x_min < 0 else 0, x_absmax - - delta = float(x_max - x_min) / (self.n_levels - 1) - if delta < 1e-8: - warnings.warn( - 'Quantization range close to zero: [{}, {}]'.format( - x_min, x_max)) - delta = 1e-8 - - zero_point = torch.round(-x_min / delta) - delta = torch.tensor(delta).type_as(x) - - elif self.scale_method == 'mse': + x_min, x_max = torch.aminmax(x) + scale = (x_max - x_min) / float(self.q_max - self.q_min) + scale = torch.max(scale, self.eps) + zero_point = self.q_min - torch.round(x_min/scale).to(torch.int) + zero_point = torch.clamp(zero_point, self.q_min, self.q_max) + + elif 'mse' in self.scale_method: # For Lp norm minimization as described in LAPQ - # https://arxiv.org/abs/1911.07190 - x_max = x.max() - x_min = x.min() - best_score = 1e10 - for i in range(80): - new_max = x_max * (1.0 - (i * 0.01)) - new_min = x_min * (1.0 - (i * 0.01)) - x_q = self.quantize(x, new_max, new_min) - score = lp_loss(pred=x, tgt=x_q, p=2.4, reduction='all') - if score < best_score: - best_score = score - delta = (new_max - new_min) / (2**self.n_bits - 1) - zero_point = torch.round(-new_min / delta) + x_min, x_max = torch.aminmax(x) + with torch.no_grad(): + optim_alpha = optim.minimize_scalar( + lambda alpha: self.estimate_quant_error(x, x_max, x_min, alpha), + bounds=(0.2, 1.0)).x + scale = optim_alpha * (x_max - x_min) / float(self.q_max - self.q_min) + zero_point = torch.round( -optim_alpha * x_min / scale) else: raise NotImplementedError - - return delta, zero_point - - def quantize(self, x, max, min): - delta = (max - min) / (2**self.n_bits - 1) - zero_point = torch.round(-min / delta) - # we assume weight quantization is always signed - x_int = torch.round(x / delta) - x_quant = torch.clamp(x_int + zero_point, 0, self.n_levels - 1) - x_float_q = (x_quant - zero_point) * delta - return x_float_q + return scale, zero_point + + def estimate_quant_error(self, x: torch.Tensor, x_max, x_min, alpha, p=2.4): + scale = alpha * (x_max - x_min) / float(self.q_max - self.q_min) + scale = torch.max(scale, self.eps) + zero_point = self.q_min - torch.round(alpha * x_min / scale).to(torch.int) + zero_point = torch.clamp(zero_point, self.q_min, self.q_max) + # we simulate fake quantization and calculate error + x_int = torch.round(x / scale) + x_quant = torch.clamp(x_int + zero_point, self.q_min, self.q_max) + x_dequant = (x_quant - zero_point) * scale + q_err = torch.mean(torch.abs(x_dequant - x) ** p) + return q_err.item() def bitwidth_refactor(self, refactored_bit: int): - assert 2 <= refactored_bit <= 8, 'bitwidth not supported' - self.n_bits = refactored_bit - self.n_levels = 2**self.n_bits + # assert refactored_bit in [2,3,4,8,16,32], 'bitwidth not supported' + if self.reduce_range: + n_bits = refactored_bit - 1 + else: + n_bits = refactored_bit + if self.unsigned: + self.q_max = (2 ** n_bits) - 1 + self.q_min = 0 + else: + self.q_max = (2 ** (n_bits-1)) - 1 + self.q_min = -((2 ** (n_bits-1)) - 1) + self.inited = False def extra_repr(self): - s = ( - 'bit={n_bits}, scale_method={scale_method}, symmetric={sym}, channel_wise={channel_wise},' - ' leaf_param={leaf_param}') + s = 'bits={n_bits}, unsigned={unsigned}, symmetric={symmetric}, channel_wise={channel_wise}, ' \ + 'scale_method={scale_method}' return s.format(**self.__dict__) + + def get_qparams(self) -> dict: + qparams = super().get_qparams() + qparams.update({ + "symmetric": self.symmetric, + "channel_wise": 
self.channel_wise + }) + return qparams -class AdaRoundQuantizer(nn.Module): - """Adaptive Rounding Quantizer, used to optimize the rounding policy by - reconstructing the intermediate output. - - Based on Up or Down? Adaptive Rounding for Post-Training Quantization: - https://arxiv.org/abs/2004.10568 - https: //arxiv.org/abs/2004.10568 - Parameters - ---------- - uaq: UniformAffineQuantizer, used to initialize quantization parameters in this quantizer - round_mode: controls the forward pass in this quantizer - weight_tensor: initialize alpha +class AdaRoundQuantizer(BaseQuantizer): + """ + Adaptive Rounding Quantizer, used to optimize the rounding policy + by reconstructing the intermediate output. + Based on + Up or Down? Adaptive Rounding for Post-Training Quantization: https://arxiv.org/abs/2004.10568 + :param uaq: UniformAffineQuantizer, used to initialize quantization parameters in this quantizer + :param round_mode: controls the forward pass in this quantizer + :param weight_tensor: initialize alpha """ - def __init__( - self, - uaq: UniformAffineQuantizer, - weight_tensor: torch.Tensor, - round_mode='learned_round_sigmoid', - ): - super(AdaRoundQuantizer, self).__init__() + def __init__(self, uaq: UniformAffineQuantizer, weight_tensor: torch.Tensor, + round_mode='learned_hard_sigmoid', **kwargs): # copying all attributes from UniformAffineQuantizer - self.n_bits = uaq.n_bits - self.sym = uaq.sym - self.delta = uaq.delta - self.zero_point = uaq.zero_point - self.n_levels = uaq.n_levels - + super(AdaRoundQuantizer, self).__init__(uaq.n_bits, uaq.reduce_range, + uaq.unsigned, uaq.scale, uaq.zero_point) self.round_mode = round_mode - self.alpha = None self.soft_targets = False + self.__register_buffer__('scale', uaq.scale) + self.__register_buffer__('zero_point', uaq.zero_point) # params for sigmoid function + self.alpha = None + self.beta = 2/3 self.gamma, self.zeta = -0.1, 1.1 - self.beta = 2 / 3 - self.init_alpha(x=weight_tensor.clone()) + self.init_alpha(x = weight_tensor.clone()) - def forward(self, x): - if self.round_mode == 'nearest': - x_int = torch.round(x / self.delta) - elif self.round_mode == 'nearest_ste': - # x_int = BQ.round_ste(x / self.delta) - x_int = RoundSTE.apply(x / self.delta) - elif self.round_mode == 'stochastic': - x_floor = torch.floor(x / self.delta) - rest = (x / self.delta) - x_floor # rest of rounding - x_int = x_floor + torch.bernoulli(rest) - print('Draw stochastic sample') - elif self.round_mode == 'learned_hard_sigmoid': - x_floor = torch.floor(x / self.delta) + def forward(self, x: torch.tensor): + if self.round_mode == 'learned_hard_sigmoid': + x_floor = FloorSTE.apply(x / self.scale) if self.soft_targets: x_int = x_floor + self.get_soft_targets() else: x_int = x_floor + (self.alpha >= 0).float() + x_quant = torch.clamp(x_int + self.zero_point, 0, self.n_levels - 1) else: - raise ValueError('Wrong rounding mode') - - x_quant = torch.clamp(x_int + self.zero_point, 0, self.n_levels - 1) - x_float_q = (x_quant - self.zero_point) * self.delta - - return x_float_q + x_quant = self.quantize(x, mode = self.round_mode) + x_dequant = self.dequantize(x_quant) + return x_dequant def get_soft_targets(self): - return torch.clamp( - torch.sigmoid(self.alpha) * (self.zeta - self.gamma) + self.gamma, - 0, 1) + return torch.clamp(torch.sigmoid(self.alpha) * (self.zeta - self.gamma) + self.gamma, 0, 1) def init_alpha(self, x: torch.Tensor): - x_floor = torch.floor(x / self.delta) + x_floor = FloorSTE.apply(x / self.scale) if self.round_mode == 
'learned_hard_sigmoid': - # print('Init alpha to be FP32') - rest = (x / self.delta) - x_floor # rest of rounding [0, 1) - alpha = -torch.log( - (self.zeta - self.gamma) / - (rest - self.gamma) - 1) # => sigmoid(alpha) = rest - self.alpha = nn.Parameter(alpha) + rest = (x / self.scale) - x_floor # rest of rounding [0, 1) + alpha = -torch.log((self.zeta - self.gamma) / (rest - self.gamma) - 1) # sigmoid(alpha) = rest + self.__register_parameter__('alpha', alpha) else: raise NotImplementedError + + def extra_repr(self): + s = 'bit={n_bits}, round_mode={round_mode}, symmetric={symmetric}, channel_wise={channel_wise}' + return s.format(**self.__dict__) -class BitSplitQuantizer(object): - """ - Parameters - ---------- - W (np.ndarray): Weight vector - bitwidth (int): bitwidth to be used - """ +class UniformSymmetricQuantizer(BaseQuantizer): + def __init__(self, n_bits, reduce_range, unsigned, **kwargs): + super().__init__(n_bits, reduce_range, unsigned, scale=None, zero_point=None) + self.n_bits = n_bits + self.inited = False + self.symmetric = True + self.channel_wise = False #TODO: add support for channel wise + self.alpha = None + self.eps = torch.finfo(torch.float32).eps + + def forward(self, x: torch.Tensor): + if not self.inited: + self.init_quantization_params(x) + x_quant = self.quantize(x, 'nearest_ste') + x_dequant = self.dequantize(x_quant) + return x_dequant + + def init_quantization_params(self, x: torch.Tensor): + alpha = x.abs().max().item() + self.set_params_from_alpha(alpha) + self.inited = True + + def set_params_from_alpha(self, alpha): + self.scale = max((2 * alpha) / float(self.q_max - self.q_min), self.eps) + self.zero_point = (self.q_max + self.q_min)//2 + self.__register_buffer__('alpha', torch.tensor(alpha)) + + def extra_repr(self): + s = 'alpha={alpha}, scale={scale}, zero_point={zero_point}, q_min={q_min}, q_max={q_max}, '\ + 'symmetric={symmetric}, channel_wise={channel_wise}' + return s.format(**self.__dict__) + + def get_qparams(self) -> dict: + qparams = super().get_qparams() + qparams.update({ + "symmetric": self.symmetric, + "channel_wise": self.channel_wise + }) + return qparams + + +class LpNormQuantizer(UniformSymmetricQuantizer): + def __init__(self, n_bits, reduce_range, unsigned, p_val, **kwargs): + super().__init__(n_bits, reduce_range, unsigned) + self.p = p_val + + def init_quantization_params(self, x: torch.Tensor): + with torch.no_grad(): + optim_alpha = optim.minimize_scalar(lambda alpha: self.estimate_quant_error(x, alpha), + bounds=(x.abs().min().item(), x.abs().max().item())).x + self.set_params_from_alpha(optim_alpha) + self.inited = True + def estimate_quant_error(self, x, alpha): + scale = max((2 * alpha) / (self.q_max - self.q_min), 1e-8) + zero_point = (self.q_max + self.q_min)//2 + x_int = torch.round(x / scale) + x_quant = torch.clamp(x_int + zero_point, self.q_min, self.q_max) + x_dequant = (x_quant - zero_point) * scale + q_err = torch.mean(torch.abs(x_dequant - x) ** self.p) + return q_err.item() + + + + +class BitSplitQuantizer(object): def __init__(self, W: np.ndarray, bitwidth): self.W = W self.bitwidth = bitwidth @staticmethod def splitWeightInteger(Q, bitwidth): - """Split low-bit weight integers into a list of ternary weights.""" + """ + Split low-bit weight integers into a list of ternary weights. 
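For example (illustration only): with bitwidth=4, Q=+5 (binary 101) is split into the ternary planes [+1, 0, +1], most-significant plane first, and Q=-3 (binary 011) into [0, -1, -1]; every plane takes values in {-1, 0, +1}.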
+ """ Q_sign = np.sign(Q) Q_abs = np.abs(Q) B_sav = [] - for idx in range(bitwidth - 1): - B = Q_abs - Q_abs.astype(np.int) // 2 * 2 # get current last bit + for idx in range(bitwidth-1): + B = (Q_abs - Q_abs.astype(np.int)//2*2) # get current last bit B *= Q_sign B_sav.append(B) - Q_abs = (Q_abs.astype(np.int) // 2).astype( - np.float32) # Q_abs >> 1 + Q_abs = (Q_abs.astype(np.int)//2).astype(np.float32) # Q_abs >> 1 return B_sav[::-1] @staticmethod def splitWeightVector(W, alpha): - """Get the optimal ternary vector, given the quantization scale - alpha.""" - B = W.copy() - B = (B >= 0).astype(np.float32) * 2 - 1 + """ + Get the optimal ternary vector, given the quantization scale alpha + """ + B=W.copy() + B=(B>=0).astype(np.float32)*2-1 abs_W2 = np.abs(W) * 2 - B[abs_W2 < alpha[:, np.newaxis]] = 0 + B[abs_W2 1e-9: + alpha_old = alpha*1.1 + while(np.linalg.norm(alpha-alpha_old)>1e-9): q = self.W / alpha[:, np.newaxis] q = np.round(q) q = np.clip(q, -max_val, max_val) alpha_old = alpha - alpha = np.sum(self.W * q, axis=1) / np.sum(q * q, axis=1) + alpha = np.sum(self.W*q, axis=1) / np.sum(q*q, axis=1) return q, alpha def ofwa(self, max_epoch=50): - """Optimal Fixed Point Weight Approximation Method. - + """ + Optimal Fixed Point Weight Approximation Method. Minimize weight matrix reconstruction error using bit-split strategy. - Given quantization scale, we find the 'optimal' low-bit weights that + Given quantization scale, we find the 'optimal' low-bit weights that minimizes the weight quantization error (instead of using round-off). Initialized by "fwa". """ - assert 2 <= self.bitwidth <= 16 + assert(2 <= self.bitwidth <= 16) Q, alpha = self.fwa() B_sav = self.splitWeightInteger(Q, self.bitwidth) - alpha *= 2**(self.bitwidth - 2) - # NOTE: the position of the decimal point is not at the end. + alpha *= (2**(self.bitwidth-2)) + # NOTE: the position of the decimal point is not at the end. # E.g. 4-bit fixed-point numbers could be something like: # +1.01, -0.11, +1.10, ... instead of +101, -011, +110, ... . ### iterative optimization @@ -379,23 +449,20 @@ def ofwa(self, max_epoch=50): alpha_old = np.copy(alpha) B_sum = self.stitch(B_sav) # given Ws, optimize alpha - alpha = np.sum(self.W * B_sum, axis=1) / np.sum(B_sum * B_sum, - axis=1) - if np.linalg.norm(alpha_old - alpha) <= 1e-9: + alpha = np.sum(self.W*B_sum, axis=1) / np.sum(B_sum*B_sum, axis=1) + if np.linalg.norm(alpha_old-alpha) <= 1e-9: break # given alpha, optimize Ws - for bit in range(self.bitwidth - 1): - W_res = self.W - self.stitchExclusive( - B_sav, bit) * alpha[:, np.newaxis] - B = self.splitWeightVector(W_res * (2**bit), alpha) + for bit in range(self.bitwidth-1): + W_res = self.W - self.stitchExclusive(B_sav, bit) * alpha[:, np.newaxis] + B = self.splitWeightVector(W_res*(2**bit), alpha) B_sav[bit] = B B_sum = self.stitch(B_sav) return B_sav, B_sum, alpha def ofwa_rr(self, X: np.ndarray, Y: np.ndarray, max_epoch=100): - """Optimal Fixed Point Weight Approximation with Response - Reconstruction. - + """ + Optimal Fixed Point Weight Approximation with Response Reconstruction. Minimize activation matrix reconstruction error using bit-split strategy. Initialized by "ofwa". 
:X: K,C,d,d @@ -408,93 +475,80 @@ def ofwa_rr(self, X: np.ndarray, Y: np.ndarray, max_epoch=100): B_sav, _, alpha = self.ofwa() X = X.reshape(X.shape[0], -1) K, N = X.shape - A = np.dot(X.T, X) # N,N + A = np.dot(X.T, X) # N,N for epoch in range(max_epoch): # given Bi, optimize alpha B_sum = self.stitch(B_sav) - XB = np.dot(X, B_sum.T) # k,m - alpha = np.einsum('ij,ij->j', Y, XB) - alpha = alpha / np.einsum('ij,ij->j', XB, XB) + XB = np.dot(X, B_sum.T) # k,m + alpha = np.einsum("ij,ij->j", Y, XB) + alpha = alpha / np.einsum("ij,ij->j", XB, XB) # given alpha, optimize Bi - for bit in range(self.bitwidth - 1): + for bit in range(self.bitwidth-1): B = B_sav[bit] - B_others = self.stitchExclusive(B_sav, bit) * alpha[:, - np.newaxis] + B_others = self.stitchExclusive(B_sav, bit) * alpha[:, np.newaxis] Y_res = Y - np.dot(X, B_others.T) - T = np.dot(Y_res.T, X) # M,N + T = np.dot(Y_res.T, X) # M,N ## fix alpha, optimize B # parallel degree: M for n in range(N): B[:, n] = 0 ABn = np.dot(A[n], B.T) - lump = 2 * (ABn * (alpha / (2**bit)) - T[:, n]) # M + lump = 2 * (ABn * (alpha/(2**bit))- T[:, n]) # M B[:, n] = -np.sign(lump) - B[np.abs(lump) < (alpha / (2**bit)) * A[n, n], n] = 0 + B[np.abs(lump) < (alpha/(2**bit)) * A[n,n], n] = 0 B_sum = self.stitch(B_sav) return B_sum, alpha def ofwa_rr_dw(self, X: np.ndarray, Y: np.ndarray, max_epoch=100): - """ + ''' # X: K,M,d,d # Y: K,M # B: M,N M kernels objective: min(Y-XWA)^2 - """ + ''' # X: M,K,9 (N=d*d) B_sav, _, alpha = self.ofwa() - X = np.transpose(X.reshape(X.shape[0], X.shape[1], -1), - (1, 0, 2)) # M, K, 9 - As = np.matmul(np.transpose(X, (0, 2, 1)), X) # M, 9, 9 + X = np.transpose(X.reshape(X.shape[0], X.shape[1], -1), (1, 0, 2)) # M, K, 9 + As = np.matmul(np.transpose(X, (0, 2, 1)), X) # M, 9, 9 alpha_bk = alpha for epoch in range(max_epoch): # given Bi, optimize alpha B_sum = self.stitch(B_sav) - XB = np.matmul(X, np.expand_dims(B_sum, axis=2)) # M, K, 1 - XB = np.squeeze(XB, axis=2) # M, K + XB = np.matmul(X, np.expand_dims(B_sum, axis=2)) # M, K, 1 + XB = np.squeeze(XB, axis=2) # M, K XB = XB.T - alpha = np.einsum('ij,ij->j', Y, XB) - alpha = alpha / np.einsum('ij,ij->j', XB, XB) + alpha = np.einsum("ij,ij->j", Y, XB) + alpha = alpha / np.einsum("ij,ij->j", XB, XB) nan_pos = np.isnan(alpha) alpha[nan_pos] = alpha_bk[nan_pos] # given alpha, optimize Bi - for bit in range(self.bitwidth - 1): + for bit in range(self.bitwidth-1): B = B_sav[bit] - B_others = self.stitchExclusive(B_sav, bit) * alpha[:, - np.newaxis] - Y_res = (Y - np.squeeze( - np.matmul(X, np.expand_dims(B_others, axis=2)), axis=2).T - ) # Y_res = Y - np.dot(X, B_others.T) - - T = np.squeeze(np.matmul(np.expand_dims(Y_res.T, axis=1), X), - axis=1) # T = np.dot(Y_res.T, X) # M,N + B_others = self.stitchExclusive(B_sav, bit) * alpha[:, np.newaxis] + Y_res = Y - np.squeeze(np.matmul(X, np.expand_dims(B_others, axis=2)), axis=2).T # Y_res = Y - np.dot(X, B_others.T) + + T = np.squeeze(np.matmul(np.expand_dims(Y_res.T, axis=1), X), axis=1) #T = np.dot(Y_res.T, X) # M,N ## fix alpha, optimize B # parallel degree: M - for n in range(9): # N=9 + for n in range(9): # N=9 B[:, n] = 0 - ABn = np.diagonal(np.dot( - As[:, n], B.T)) # M #ABn = np.dot(A[n], B.T) - lump = 2 * (ABn * (alpha / (2**bit)) - T[:, n]) # M + ABn = np.diagonal(np.dot(As[:,n], B.T)) # M #ABn = np.dot(A[n], B.T) + lump = 2 * (ABn * (alpha/(2**bit))- T[:, n]) # M B[:, n] = -np.sign(lump) - B[np.abs(lump) < (alpha / (2**bit)) * As[:, n, n], n] = 0 + B[np.abs(lump) < (alpha/(2**bit)) * As[:,n,n], n] = 0 B_sum = 
self.stitch(B_sav) return B_sum, alpha + class ActQuantizer(nn.Module): - """ - Parameters - ---------- - islinear (bool): - bit_width (int): bit width to be used - """ - def __init__(self, islinear=False, bit_width=8): super(ActQuantizer, self).__init__() # self.scale = None @@ -508,7 +562,7 @@ def set_bitwidth(self, bit_width): self.bit_width = bit_width if self.signed: self.max_val = (1 << (self.bit_width - 1)) - 1 - self.min_val = -self.max_val + self.min_val = - self.max_val else: self.max_val = (1 << self.bit_width) - 1 self.min_val = 0 @@ -532,387 +586,30 @@ def set_outscale(self, out_scale): self.out_scale = torch.tensor(self.out_scale).view(1, -1, 1, 1) def init_quantization(self, x): - assert np.min(x) >= 0 - circle_detection_queue = [ - 0, - ] * 5 + assert(np.min(x)>=0) + circle_detection_queue = [0,]*5 alpha = np.max(np.fabs(x)) / self.max_val alpha_old = alpha * 0 n_iter = 0 circle_detection_queue[n_iter] = alpha - while np.sum(alpha != alpha_old): + while(np.sum(alpha!=alpha_old)): q = x / alpha q = np.clip(np.round(q), self.min_val, self.max_val) alpha_old = alpha - alpha = np.sum(x * q) / np.sum(q * q) + alpha = np.sum(x*q) / np.sum(q*q) if alpha in circle_detection_queue: break n_iter += 1 - circle_detection_queue[n_iter % 5] = alpha + circle_detection_queue[n_iter%5] = alpha return alpha def forward(self, x): if self.in_scale is None: - assert self.out_scale is None + assert(self.out_scale is None) return x if not isinstance(self.in_scale, (float, np.float32, np.float64)): self.in_scale = self.in_scale.to(x.device) if not isinstance(self.out_scale, (float, np.float32, np.float64)): self.out_scale = self.out_scale.to(x.device) # return torch.clamp(torch.round(x/self.in_scale), self.min_val, self.max_val) * self.out_scale - return (torch.clamp(RoundSTE.apply(x / self.in_scale), self.min_val, - self.max_val) * self.out_scale) - - -class QuantizationBase(object): - """ - Parameters - ---------- - module (object): Module to be used - num_bits (int): Number of bits to be used - """ - - def __init__(self, module, num_bits): - self.module = module - self.num_bits = num_bits - self.num_bins = int(2**num_bits) - self.opt_params = {} - self.named_params = [] - - def register_buffer(self, name, value): - if hasattr(self.module, name): - delattr(self.module, name) - self.module.register_buffer(name, value) - setattr(self, name, getattr(self.module, name)) - - def register_parameter(self, name, value): - if hasattr(self.module, name): - delattr(self.module, name) - self.module.register_parameter(name, nn.Parameter(value)) - setattr(self, name, getattr(self.module, name)) - - self.named_params.append((name, getattr(self.module, name))) - - def __add_optim_params__(self, optim_type, dataset, params): - learnable_params = [ - d for n, d in params if n in self.learned_parameters() - ] - self.opt_params[optim_type + '_' + dataset] = learnable_params - - def optim_parameters(self): - return self.opt_params - - def loggable_parameters(self): - return self.named_parameters() - - def named_parameters(self): - named_params = [(n, p) for n, p in self.named_params - if n in self.learned_parameters()] - return named_params - - @staticmethod - def learned_parameters(): - return [] - - -class UniformQuantization(QuantizationBase): - """ - Parameters - ---------- - module (object): Module to be used - num_bits (int): Number of bits to be used - symmetric (bool): Whether the distribution is symmetric or not - uint (bool): - stochastic (bool): if True, stochastic rounding will be done - tails (bool): - 
""" - - def __init__(self, - module, - num_bits, - symmetric, - uint=False, - stochastic=False, - tails=False): - super(UniformQuantization, self).__init__(module, num_bits) - if not symmetric and not uint: - raise RuntimeError( - "Can't perform integer quantization on non symmetric distributions." - ) - self.symmetric = symmetric - self.uint = uint - self.stochastic = stochastic - self.tails = tails - if uint: - self.qmax = 2**self.num_bits - 1 - self.qmin = 0 - else: - self.qmax = 2**(self.num_bits - 1) - 1 - self.qmin = -self.qmax - 1 - if tails: - self.qmax -= 0.5 + 1e-6 - self.qmin -= 0.5 - - def __quantize__(self, tensor, alpha): - delta = (2 if self.symmetric else 1) * alpha / (self.num_bins - 1) - delta = max(delta, 1e-8) - # quantize - if self.uint and self.symmetric: - t_q = (tensor + alpha) / delta - else: - t_q = tensor / delta - # stochastic rounding - if self.stochastic and self.module.training: - with torch.no_grad(): - noise = t_q.new_empty(t_q.shape).uniform_(-0.5, 0.5) - t_q += noise - # clamp and round - t_q = torch.clamp(t_q, self.qmin, self.qmax) - t_q = RoundSTE.apply(t_q) - assert torch.unique(t_q).shape[0] <= self.num_bins - # de-quantize - if self.uint and self.symmetric: - t_q = t_q * delta - alpha - else: - t_q = t_q * delta - return t_q - - def __quantize_gemmlowp__(self, tensor, min_, max_): - assert self.uint is True - delta = (max_ - min_) / (self.num_bins - 1) - delta = max(delta, 1e-8) - # quantize - t_q = (tensor - min_) / delta - # stochastic rounding - if self.stochastic and self.module.training: - with torch.no_grad(): - noise = t_q.new_empty(t_q.shape).uniform_(-0.5, 0.5) - t_q += noise - # clamp and round - t_q = torch.clamp(t_q, self.qmin, self.qmax) - t_q = RoundSTE.apply(t_q) - assert torch.unique(t_q).shape[0] <= self.num_bins - # de-quantize - t_q = t_q * delta + min_ - return t_q - - def __for_repr__(self): - return [ - ('bits', self.num_bits), - ('symmetric', self.symmetric), - ('tails', self.tails), - ] - - def __repr__(self): - s = '{} - ['.format(type(self).__name__) - for name, value in self.__for_repr__(): - s += '{}: {}, '.format(name, value) - return s + ']' - - -class ClippedUniformQuantization(UniformQuantization): - """ - Parameters - ---------- - module (object): Module to be used - num_bits (int): Number of bits to be used - symmetric (bool): Whether the distribution is symmetric or not - uint (bool): - stochastic (bool): if True, stochastic rounding will be done - tails (bool): - """ - - alpha_param_name = 'alpha' - - def __init__(self, - module, - num_bits, - symmetric, - uint=False, - stochastic=False, - tails=False): - super(ClippedUniformQuantization, - self).__init__(module, num_bits, symmetric, uint, stochastic, - tails) - - def __call__(self, tensor): - t_q = self.__quantize__(tensor, self.alpha) - return t_q - - def __for_repr__(self): - rpr = super(ClippedUniformQuantization, self).__for_repr__() - return [( - self.alpha_param_name, - '{:.4f}'.format(getattr(self, self.alpha_param_name).item()), - )] + rpr - - -class FixedClipValueQuantization(ClippedUniformQuantization): - """ - Parameters - ---------- - module (object): Module to be used - num_bits (int): Number of bits to be used - symmetric (bool): Whether the distribution is symmetric or not - uint (bool): - stochastic (bool): if True, stochastic rounding will be done - tails (bool): - kwargs (object): A yaml safe loaded file with information like clip_value, device. 
- """ - - def __init__(self, - module, - num_bits, - symmetric, - uint=False, - stochastic=False, - kwargs={}): - super(FixedClipValueQuantization, - self).__init__(module, num_bits, symmetric, uint, stochastic) - self.clip_value = kwargs['clip_value'] - self.device = kwargs['device'] - with torch.no_grad(): - self.register_buffer( - self.alpha_param_name, - torch.tensor([self.clip_value], - dtype=torch.float32).to(self.device), - ) - - -class MaxAbsStaticQuantization(ClippedUniformQuantization): - """ - Parameters - ---------- - module (object): Module to be used - tensor (torch.Tensor): Tensor which wpuld be quantized - num_bits (int): Number of bits to be used - symmetric (bool): Whether the distribution is symmetric or not - uint (bool): - stochastic (bool): if True, stochastic rounding will be done - """ - - def __init__( - self, - module, - tensor, - num_bits, - symmetric, - uint=False, - stochastic=False, - kwargs={}, - ): - super(MaxAbsStaticQuantization, - self).__init__(module, num_bits, symmetric, uint, stochastic) - - with torch.no_grad(): - self.register_buffer(self.alpha_param_name, - tensor.new_tensor([tensor.abs().max()])) - - -class LearnedStepSizeQuantization(ClippedUniformQuantization): - """ - Parameters - ---------- - module (object): Module to be used - tensor (torch.Tensor): Tensor which wpuld be quantized - num_bits (int): Number of bits to be used - symmetric (bool): Whether the distribution is symmetric or not - uint (bool): - stochastic (bool): if True, stochastic rounding will be done - """ - - def __init__(self, - module, - tensor, - num_bits, - symmetric, - uint=False, - stochastic=False, - **kwargs): - super(LearnedStepSizeQuantization, - self).__init__(module, num_bits, symmetric, uint, stochastic) - - with torch.no_grad(): - maxabs = tensor.abs().max() - - self.register_parameter(self.alpha_param_name, - tensor.new_tensor([maxabs])) - - self.__create_optim_params__() - - def __create_optim_params__(self): - # TODO: create default configuration - self.__add_optim_params__( - 'SGD', - 'imagenet', - [( - self.alpha_param_name, - { - 'params': [getattr(self, self.alpha_param_name)], - 'lr': 1e-3, - 'momentum': 0, - 'weight_decay': 0, - }, - )], - ) - self.__add_optim_params__( - 'SGD', - 'cifar10', - [( - self.alpha_param_name, - { - 'params': [getattr(self, self.alpha_param_name)], - 'lr': 1e-1, - 'momentum': 0, - 'weight_decay': 0, - }, - )], - ) - - @staticmethod - def learned_parameters(): - return [LearnedStepSizeQuantization.alpha_param_name] - - -class LpNormQuantization(ClippedUniformQuantization): - """ - Parameters - ---------- - module (object): Module to be used - tensor (torch.Tensor): Tensor which wpuld be quantized - num_bits (int): Number of bits to be used - symmetric (bool): Whether the distribution is symmetric or not - uint (bool): - stochastic (bool): if True, stochastic rounding will be done - tails (bool): - kwargs (object): A yaml safe loaded file with information like lp - """ - - def __init__( - self, - module, - tensor, - num_bits, - symmetric, - uint=False, - stochastic=False, - tails=False, - kwargs={}, - ): - super(LpNormQuantization, self).__init__(module, num_bits, symmetric, - uint, stochastic, tails) - - self.p = kwargs['lp'] - with torch.no_grad(): - opt_alpha = optim.minimize_scalar( - lambda alpha: self.estimate_quant_error(alpha, tensor), - bounds=(tensor.min().item(), tensor.max().item()), - ).x - - self.register_buffer(self.alpha_param_name, - tensor.new_tensor([opt_alpha])) - - def estimate_quant_error(self, alpha, x): - 
xq = self.__quantize__(x, alpha) - err = torch.mean(torch.abs(xq - x)**self.p) - return err.item() + return torch.clamp(RoundSTE.apply(x/self.in_scale), self.min_val, self.max_val) * self.out_scale \ No newline at end of file diff --git a/trailmet/algorithms/quantize/modules.py b/trailmet/algorithms/quantize/modules.py new file mode 100644 index 0000000..9ab23f5 --- /dev/null +++ b/trailmet/algorithms/quantize/modules.py @@ -0,0 +1,375 @@ +# MIT License +# +# Copyright (c) 2023 Transmute AI Lab +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Union, Dict, Callable +from trailmet.models.resnet import BasicBlock, Bottleneck +from trailmet.models.mobilenet import InvertedResidual +from trailmet.algorithms.quantize.methods import UniformAffineQuantizer, ActQuantizer +from trailmet.algorithms.quantize._methods import BaseQuantizer, UniformQuantizer, AdaRoundQuantizer +from trailmet.algorithms.quantize.utils import get_qscheme, get_dtype +from torch.ao.quantization import QConfig, FixedQParamsObserver +import torch.ao.nn.quantized as nnq +import torch.ao.nn.intrinsic as nni + +__all__ = [ + 'StraightThrough', + 'QuantModule', + 'BaseQuantBlock', + 'QuantBasicblock', + 'QuantBottleneck', + 'QuantInvertedResidual', + # old modules kept temporarily for BC + 'QBasicblock' + 'QBottleneck', + 'QInvertedResidual', +] + +QUANTIZER_MAPPING: Dict[str, Callable] = { + 'uniform': UniformQuantizer, + 'adaround': AdaRoundQuantizer +} +class StraightThrough(nn.Module): + """ + Identity Layer, same as torch.nn.modules.linear.Identity + """ + def __int__(self, *args, **kwargs) -> None: + super().__init__() + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return input + +class QuantModule(nn.Module): + """ + Wrapper Module to simulate fake quantization + """ + def __init__(self, orig_module: nn.Module, weight_qparams: dict, act_qparams: dict): + super().__init__() + self.orig_module = orig_module + if isinstance(orig_module, (nni.ConvReLU2d, nni.LinearReLU)): + assert len(orig_module)==2 + if type(orig_module[0]) == nn.Conv2d: + self.fwd_func = F.conv2d + self.fwd_kwargs = dict( + stride = orig_module[0].stride, + padding = orig_module[0].padding, + dilation = orig_module[0].dilation, + groups = orig_module[0].groups + ) + elif type(orig_module[0]) == nn.Linear: + self.fwd_func = F.linear + self.fwd_kwargs = dict() + else: + raise NotImplementedError + + if type(orig_module[1]) == nn.ReLU: + 
self.fwd_post = F.relu + else: + raise NotImplementedError + + self.weight = orig_module[0].weight + self.orig_weight = orig_module[0].weight.data.clone() + self.bias = orig_module[0].bias + + elif isinstance(orig_module, (nn.Conv2d, nn.Linear)): + if type(orig_module) == nn.Conv2d: + self.fwd_func = F.conv2d + self.fwd_kwargs = dict( + stride = orig_module.stride, + padding = orig_module.padding, + dilation = orig_module.dilation, + groups = orig_module.groups + ) + elif type(orig_module) == nn.Linear: + self.fwd_func = F.linear + self.fwd_kwargs = dict() + else: + raise NotImplementedError + + self.fwd_post = self.identity + self.weight = orig_module.weight + self.orig_weight = orig_module.weight.data.clone() + self.bias = orig_module.bias + + else: + raise NotImplementedError + + self.weight_quantizer: BaseQuantizer = QUANTIZER_MAPPING[weight_qparams.get( + 'quantizer', 'uniform')](weight_qparams) + self.act_quantizer: BaseQuantizer = QUANTIZER_MAPPING[act_qparams.get( + 'quantizer', 'uniform')](act_qparams) + + self.use_act_quant = False + self.use_weight_quant = False + self.ignore_reconstruction = False + self.extra_repr = orig_module.extra_repr + + + def forward(self, input: torch.Tensor): + if self.use_weight_quant: + weight = self.weight_quantizer(self.weight) + bias = self.bias + else: + weight = self.orig_weight + bias = self.bias + out = self.fwd_func(input, weight, bias, **self.fwd_kwargs) + out = self.fwd_post(out) + if self.use_act_quant: + out = self.act_quantizer(out) + return out + + + def identity(self, x: torch.Tensor): + return x + + def set_observation_state(self, weight_obs: bool, act_obs: bool): + self.weight_quantizer.enable_observation = weight_obs + self.act_quantizer.enable_observation = act_obs + + def set_quantization_state(self, weight_quant: bool, act_quant: bool): + self.use_weight_quant = weight_quant + self.use_act_quant = act_quant + + + +class BaseQuantBlock(nn.Module): + def __init__(self, act_qparams: dict) -> None: + super().__init__() + self.act_quantizer: BaseQuantizer = QUANTIZER_MAPPING[act_qparams.get( + 'quantizer', 'uniform')](act_qparams) + self.use_act_quant = False + self._fake_quantization = True + self.ignore_reconstruction = False + + def set_observation_state(self, weight_obs: bool, act_obs: bool): + self.act_quantizer.enable_observation = act_obs + for module in self.modules(): + if isinstance(module, QuantModule): + module.set_observation_state(weight_obs, act_obs) + + def set_quantization_state(self, weight_quant: bool, act_quant: bool): + self.use_act_quant = act_quant + for module in self.modules(): + if isinstance(module, QuantModule): + module.set_quantization_state(weight_quant, act_quant) + + + def _convert_to_quantizable_with_qconfig(self, module: nn.Module): + self._fake_quantization = False + module_attach = dict() + module_reassign = dict() + + for name, submodule in module.named_modules(): + if isinstance(submodule, QuantModule): + module_attach[name]['weight'] = submodule.weight_quantizer.observer + module_attach[name]['activation'] = submodule.act_quantizer.observer + module_reassign[name] = submodule.orig_module + + for name, orig_module in module_reassign.items(): + delattr(module, name) + setattr(module, name, orig_module) + + for name, observers in module_attach.items(): + submodule = getattr(module, name, None) + assert submodule is not None + if isinstance(submodule, nni.ConvReLU2d): # propagate qconfig + setattr(submodule[0], 'qconfig', QConfig( + weight = observers['weight'], + activation = None + )) + 
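# Note on the step below: the inner conv of a fused ConvReLU2d gets a weight-only
# qconfig here (its activation is observed after the ReLU), while the fused module
# itself carries both observers plus an 'activation_post_process' child, which
# eager-mode conversion in torch.ao.quantization is expected to read when deriving
# the activation qparams.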
setattr(submodule, 'qconfig', QConfig( + weight = observers['weight'], + activation = observers['activation'] + )) + submodule.add_module('activation_post_process', submodule.qconfig.activation()) + + + +class QuantBasicBlock(BaseQuantBlock): + def __init__(self, basicblock: BasicBlock, weight_qparams: dict, act_qparams: dict): + super().__init__(act_qparams) + # assuming all bn and relu are fused in conv + self.conv1 = QuantModule(basicblock.conv1, weight_qparams, act_qparams) + self.conv2 = QuantModule(basicblock.conv2, weight_qparams, act_qparams) + if basicblock.downsample is not None: + self.downsample = QuantModule(basicblock.downsample[0], weight_qparams, act_qparams) + else: + self.downsample = None + self.add_skip = nnq.FloatFunctional() + + + def forward(self, inp: torch.Tensor) -> torch.Tensor: + skip = inp if self.downsample is None else self.downsample(inp) + out = self.conv1(inp) + out = self.conv2(out) + out = self.add_skip.add_relu(out, skip) + if self._fake_quantization and self.use_act_quant: + out = self.act_quantizer(out) + return out + + + def _convert_to_quantizable_with_qconfig(self): + super()._convert_to_quantizable_with_qconfig(self) + setattr(self.add_skip, 'qconfig', QConfig( + weight = None, + activation = self.act_quantizer.observer + )) + self.add_skip.add_module('activation_post_process', + self.add_skip.qconfig.activation()) + + + +class QuantBottleneck(BaseQuantBlock): + def __init__(self, bottleneck: Bottleneck, weight_qparams: dict, act_qparams: dict) -> None: + super().__init__(act_qparams) + # assuming all bn and relu are fused in conv + self.conv1 = QuantModule(bottleneck.conv1, weight_qparams, act_qparams) # ConvReLU2d + self.conv2 = QuantModule(bottleneck.conv2, weight_qparams, act_qparams) # ConvReLU2d + self.conv3 = QuantModule(bottleneck.conv3, weight_qparams, act_qparams) # ConvReLU2d + if bottleneck.downsample is not None: + self.downsample = QuantModule(bottleneck.downsample[0], weight_qparams, act_qparams) # Conv2d + else: + self.downsample = None + self.add_skip = nnq.FloatFunctional() + + + def forward(self, inp: torch.Tensor) -> torch.Tensor: + skip = inp if self.downsample is None else self.downsample(inp) + out = self.conv1(inp) + out = self.conv2(out) + out = self.conv3(out) + out = self.add_skip.add_relu(out, skip) + if self._fake_quantization and self.use_act_quant: + out = self.act_quantizer(out) + return out + + + def _convert_to_quantizable_with_qconfig(self): + super()._convert_to_quantizable_with_qconfig(self) + setattr(self.add_skip, 'qconfig', QConfig( + weight = None, + activation = self.act_quantizer.observer + )) + self.add_skip.add_module('activation_post_process', + self.add_skip.qconfig.activation()) + + + +class QuantInvertedResidual(BaseQuantBlock): + def __init__(self, act_quant_params: dict = {}) -> None: + super().__init__(act_quant_params) + + + + + +class QBasicBlock(nn.Module): + expansion = 1 + def __init__(self, basic_block: BasicBlock): + super().__init__() + self.quant1 = ActQuantizer() + self.conv1 = basic_block.conv1 + self.bn1 = basic_block.bn1 + self.activ = basic_block.activ + self.quant2 = ActQuantizer() + self.conv2 = basic_block.conv2 + self.bn2 = basic_block.bn2 + self.downsample = basic_block.downsample + self.stride = basic_block.stride + + def forward(self, x): + residual = x + x = self.quant1(x) + out = self.activ(self.bn1(self.conv1(x))) + out = self.quant2(out) + out = self.bn2(self.conv2(out)) + if self.downsample is not None: + residual = self.downsample(x) + out += residual + out = 
self.activ(out) + return out + + +class QBottleneck(nn.Module): + expansion = 4 + def __init__(self, bottleneck: Bottleneck): + super().__init__() + self.quant1 = ActQuantizer() + self.conv1 = bottleneck.conv1 + self.bn1 = bottleneck.bn1 + self.quant2 = ActQuantizer() + self.conv2 = bottleneck.conv2 + self.bn2 = bottleneck.bn2 + self.quant3 = ActQuantizer() + self.conv3 = bottleneck.conv3 + self.bn3 = bottleneck.bn3 + self.activ = bottleneck.activ + self.downsample = bottleneck.downsample + self.stride = bottleneck.stride + + def forward(self, x): + residual = x + x = self.quant1(x) + out = self.activ(self.bn1(self.conv1(x))) + out = self.quant2(out) + out = self.activ(self.bn2(self.conv2(out))) + out = self.quant3(out) + out = self.bn3(self.conv3(out)) + if self.downsample is not None: + residual = self.downsample(x) + out += residual + out = self.activ(out) + return out + + +class QInvertedResidual(nn.Module): + def __init__(self, inv_res: InvertedResidual): + super().__init__() + self.stride = inv_res.stride + self.inp = inv_res.inp + self.oup = inv_res.oup + self.exp = inv_res.exp + self.quant1 = ActQuantizer(islinear=1) + self.conv1 = inv_res.conv1 + self.bn1 = inv_res.bn1 + self.quant2 = ActQuantizer(islinear=1) + self.conv2 = inv_res.conv2 + self.bn2 = inv_res.bn2 + self.quant3 = ActQuantizer(islinear=0) + self.conv3 = inv_res.conv3 + self.bn3 = inv_res.bn3 + self.shortcut = inv_res.shortcut + + def forward(self, x): + x = self.quant1(x) + out = F.relu(self.bn1(self.conv1(x))) + out = self.quant2(out) + out = F.relu(self.bn2(self.conv2(out))) + out = self.quant3(out) + out = self.bn3(self.conv3(out)) + out = out + self.shortcut(x) if self.stride==1 else out + return out + + + \ No newline at end of file diff --git a/trailmet/algorithms/quantize/observers.py b/trailmet/algorithms/quantize/observers.py new file mode 100644 index 0000000..4a93213 --- /dev/null +++ b/trailmet/algorithms/quantize/observers.py @@ -0,0 +1,217 @@ +import torch +import torch.nn as nn +import scipy.optimize as optim +from trailmet.algorithms.quantize.utils import get_dtype, get_qscheme, \ + transform_and_flatten_tensor_by_channel, reshape_qparams_by_channel, \ + fake_quantize + +class BaseObserver(nn.Module): + def __init__(self, n_bits: int = 8, reduce_range: bool = True, unsigned: bool = False, + symmetric: bool = False, per_channel: bool = False): + super(BaseObserver, self).__init__() + self.n_bits = n_bits + assert 2 <= n_bits <= 32, "n_bits is outside allowed range [2, 32]" + + if reduce_range: + n_bits -= 1 + if unsigned: + self.quant_min = 0 + self.quant_max = (2 ** n_bits) - 1 + else: + self.quant_min = -(2 ** (n_bits - 1)) + self.quant_max = (2 ** (n_bits - 1)) - 1 + + self.reduce_range = reduce_range + self.unsigned = unsigned + self.symmetric = symmetric + self.per_channel = per_channel + + self.eps = torch.tensor(1e-8, dtype=torch.float32) + self.dtype = get_dtype(self.quant_min, self.quant_max, self.reduce_range) + self.qscheme = get_qscheme(self.per_channel, self.symmetric) + self.register_buffer("min_val", torch.tensor(float("inf"))) + self.register_buffer("max_val", torch.tensor(float("-inf"))) + if (unsigned and symmetric and per_channel and reduce_range): + raise NotImplementedError( + "cannot reduce range for per-channel-symmetric unsigned quantization" + ) + self.inited = False + + def reset_bitwidth(self, n_bits): + self.n_bits = n_bits + assert 2 <= n_bits <= 32, "n_bits is outside allowed range [2, 32]" + if self.reduce_range: + n_bits -= 1 + if self.unsigned: + self.quant_min = 0 + 
self.quant_max = (2 ** n_bits) - 1 + else: + self.quant_min = -(2 ** (n_bits - 1)) + self.quant_max = (2 ** (n_bits - 1)) - 1 + + def reset_min_max_vals(self): + self.min_val.copy_(torch.tensor(float("inf"))) + self.max_val.copy_(torch.tensor(float("-inf"))) + self.inited = False + + def forward(self, x: torch.Tensor): + # update min_val and max_val from x and make inited true + return x + + @torch.jit.export + def _calculate_qparams(self, min_val: torch.Tensor, max_val: torch.Tensor): + quant_min, quant_max = self.quant_min, self.quant_max + + min_val_neg = torch.min(min_val, torch.zeros_like(min_val)) + max_val_pos = torch.max(max_val, torch.zeros_like(max_val)) + + device = min_val.device + scale = torch.ones(min_val.size(), dtype=torch.float32, device=device) + zero_point = torch.zeros(min_val.size(), dtype=torch.int64, device=device) + + if self.symmetric: + abs_max_val = torch.max(-min_val_neg, max_val_pos) + scale = (2 * abs_max_val) / float(quant_max - quant_min) + scale = torch.max(scale, self.eps) + if self.unsigned: + zero_point = zero_point.new_full(zero_point.size(), (quant_min + quant_max) // 2) + else: + scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min) + scale = torch.max(scale, self.eps) + zero_point = quant_min - torch.round(min_val_neg / scale).to(torch.int) + zero_point = torch.clamp(zero_point, quant_min, quant_max) + + # for scalar values, cast them to tensors of size 1 to keep the shape consistent + if len(scale.shape) == 0: + scale = torch.tensor([float(scale)], dtype=scale.dtype, device=device) + if len(zero_point.shape) == 0: + zero_point = torch.tensor([int(zero_point)], dtype=zero_point.dtype, device=device) + + return scale, zero_point + + @torch.jit.export + def calculate_qparams(self): + assert self.inited, "need to run observation atleast once" + return self._calculate_qparams(self.min_val, self.max_val) + + + +class MinMaxObserver(BaseObserver): + def __init__(self, n_bits: int = 8, reduce_range: bool = True, unsigned: bool = False, + symmetric: bool = False, per_channel: bool = False, ch_axis: int = 0, **kwargs): + super().__init__(n_bits, reduce_range, unsigned, symmetric, per_channel) + self.ch_axis = ch_axis + + def forward(self, x_orig: torch.Tensor): + if x_orig.numel() == 0: + return x_orig + # dtype must match because updates to buffers are done inplace + x = x_orig.clone().detach().to(self.min_val.dtype) + + if self.per_channel: + y = transform_and_flatten_tensor_by_channel(x, self.ch_axis) + min_val_cur, max_val_cur = torch.aminmax(y, dim=1) + else: + min_val_cur, max_val_cur = torch.aminmax(x) + + if not self.inited: + self.min_val = min_val_cur + self.max_val = max_val_cur + self.inited = True + else: + self.min_val = torch.min(self.min_val, min_val_cur) + self.max_val = torch.max(self.max_val, max_val_cur) + + return x_orig + + + +class LpNormObserver(BaseObserver): + def __init__(self, n_bits: int = 8, reduce_range: bool = True, unsigned: bool = False, + symmetric: bool = False, per_channel: bool = False, ch_axis: int = 0, + p_val: float = 2.4, num_iters: int = 1000, pos_dist: bool = False, **kwargs): + super().__init__(n_bits, reduce_range, unsigned, symmetric, per_channel) + self.pos_dist = pos_dist + self.ch_axis = ch_axis + self.num_iters = num_iters + self.p = p_val + + def lp_loss(self, pred: torch.Tensor, trgt: torch.Tensor, + p: float = 2.4, per_channel: bool = False): + err = (pred - trgt).abs().pow(p) + if per_channel: + err_ = transform_and_flatten_tensor_by_channel(err, self.ch_axis) + return err_.mean(1) + else: + 
return err.mean() + + def get_quant_loss_from_range(self, x, min_val, max_val): + scale, zero_point = self._calculate_qparams(min_val, max_val) + if self.per_channel: + scale, zero_point = reshape_qparams_by_channel(x, scale, zero_point, self.ch_axis) + x_q = fake_quantize(x, scale, zero_point, self.quant_min, self.quant_max) + loss = self.lp_loss(x_q, x, self.p, self.per_channel) + return loss + + def get_quant_loss_from_alpha(self, x, alpha, x_min, x_max): + min_val, max_val = x_min * alpha, x_max * alpha + scale, zero_point = self._calculate_qparams(min_val, max_val) + x_q = fake_quantize(x, scale, zero_point, self.quant_min, self.quant_max) + loss = self.lp_loss(x_q, x, self.p, False) + return loss.item() + + def perform_linear_1D_search(self, x: torch.Tensor): + pass + + def perform_fast_1D_search(self, x: torch.Tensor): + if self.per_channel: + alphas = [] + x_ = transform_and_flatten_tensor_by_channel(x, self.ch_axis) + x_min, x_max = torch.aminmax(x_, dim=1) + if self.pos_dist: + x_min = torch.zeros_like(x_min) + + for ch in range(len(x_)): + x_ch = x_[ch] + ch_min, ch_max = x_min[ch], x_max[ch] + optim_alpha = optim.minimize_scalar( + lambda alpha: self.get_quant_loss_from_alpha(x_ch, alpha, ch_min, ch_max), + bounds=(0.2, 1.0)).x + alphas.append(optim_alpha) + + alphas = torch.tensor(alphas, dtype=torch.float32, device=x.device) + min_val, max_val = x_min * alphas, x_max * alphas + + else: + x_min, x_max = torch.aminmax(x) + if self.pos_dist: + x_min = torch.zeros_like(x_min) + optim_alpha = optim.minimize_scalar( + lambda alpha: self.get_quant_loss_from_alpha(x, alpha, x_min, x_max), + bounds=(0.2, 1.0)).x + min_val, max_val = x_min * optim_alpha, x_max * optim_alpha + + return min_val, max_val + + + def forward(self, x_orig: torch.Tensor): + if x_orig.numel() == 0: + return x_orig + x = x_orig.clone().detach().to(self.min_val.dtype) + + if self.symmetric or self.pos_dist: + min_val_cur, max_val_cur = self.perform_fast_1D_search(x) + else: + raise NotImplementedError + + if not self.inited: + self.min_val = min_val_cur + self.max_val = max_val_cur + self.inited = True + else: + self.min_val = torch.min(self.min_val, min_val_cur) + self.max_val = torch.max(self.max_val, max_val_cur) + + return x_orig + + diff --git a/trailmet/algorithms/quantize/qmodel.py b/trailmet/algorithms/quantize/qmodel.py deleted file mode 100644 index 8fd2862..0000000 --- a/trailmet/algorithms/quantize/qmodel.py +++ /dev/null @@ -1,651 +0,0 @@ -# MIT License -# -# Copyright (c) 2023 Transmute AI Lab -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. -import torch -import torch.nn as nn -import torch.nn.functional as F -from typing import Union -from trailmet.models.resnet import BasicBlock, Bottleneck -from trailmet.models.mobilenet import InvertedResidual -from trailmet.algorithms.quantize.quantize import StraightThrough -from trailmet.algorithms.quantize.methods import ( - MaxAbsStaticQuantization, - LpNormQuantization, -) -from trailmet.algorithms.quantize.methods import UniformAffineQuantizer -from trailmet.algorithms.quantize.methods import ActQuantizer - -__all__ = [ - 'QuantBasicBlock', - 'QuantBottleneck', - 'QuantInvertedResidual', - 'QuantModule', - 'BaseQuantBlock', - 'QBasicBlock', - 'QBottleneck', - 'QInvertedResidual', - 'ActivationModuleWrapper', - 'ParameterModuleWrapper', -] -# ============================================ -# ***** Quantization Modules for BRECQ ******* -# ============================================ -""" -Supported quantization wrappers for pytorch modules :- - - BasicBlock(nn.Module) -> QuantBasicBlock(BaseQuantBlock(nn.Module)) - - Bottleneck(nn.Module) -> QuantBottleneck(BaseQuantBlock(nn.Module)) - - InvertedResidual(nn.Module) -> QuantInvertedResidual(BaseQuantBlock(nn.Module)) - - nn.Conv2d, nn.Linear -> QuantModule(nn.Module) -""" - - -class QuantModule(nn.Module): - """Quantized Module that can perform quantized convolution or normal - convolution. - - To activate quantization, please use set_quant_state function. - - Parameters - ---------- - org_module (nn.Module): Module to be used - weight_quant_params (dict): Weight parameters - act_quant_params (dict): Activation Parameters - disable_act_quant (bool): if True, activation layer will be disabled - se_module (nn.Module): SE Module to be used - """ - - def __init__( - self, - org_module: Union[nn.Conv2d, nn.Linear], - weight_quant_params: dict = {}, - act_quant_params: dict = {}, - disable_act_quant: bool = False, - se_module=None, - ): - super(QuantModule, self).__init__() - if isinstance(org_module, nn.Conv2d): - self.fwd_kwargs = dict( - stride=org_module.stride, - padding=org_module.padding, - dilation=org_module.dilation, - groups=org_module.groups, - ) - self.fwd_func = F.conv2d - else: - self.fwd_kwargs = dict() - self.fwd_func = F.linear - self.weight = org_module.weight - self.org_weight = org_module.weight.data.clone() - if org_module.bias is not None: - self.bias = org_module.bias - self.org_bias = org_module.bias.data.clone() - else: - self.bias = None - self.org_bias = None - # de-activate the quantized forward default - self.use_weight_quant = False - self.use_act_quant = False - self.disable_act_quant = disable_act_quant - # initialize quantizer - self.weight_quantizer = UniformAffineQuantizer(**weight_quant_params) - self.act_quantizer = UniformAffineQuantizer(**act_quant_params) - - self.activation_function = StraightThrough() - self.ignore_reconstruction = False - - self.se_module = se_module - self.extra_repr = org_module.extra_repr - - def forward(self, input: torch.Tensor): - if self.use_weight_quant: - weight = self.weight_quantizer(self.weight) - bias = self.bias - else: - weight = self.org_weight - bias = self.org_bias - out = self.fwd_func(input, weight, bias, **self.fwd_kwargs) - # disable act quantization is designed for convolution before 
elemental-wise operation, - # in that case, we apply activation function and quantization after ele-wise op. - if self.se_module is not None: - out = self.se_module(out) - out = self.activation_function(out) - if self.disable_act_quant: - return out - if self.use_act_quant: - out = self.act_quantizer(out) - return out - - def set_quant_state(self, - weight_quant: bool = False, - act_quant: bool = False): - self.use_weight_quant = weight_quant - self.use_act_quant = act_quant - - -class BaseQuantBlock(nn.Module): - """Base implementation of block structures for all networks. - - Due to the branch architecture, we have to perform activation function and - quantization after the elemental-wise add operation, therefore, we put this - part in this class. - - Parameters - ---------- - act_quant_params (dict): Activation parameters - """ - - def __init__(self, act_quant_params: dict = {}): - super().__init__() - self.use_weight_quant = False - self.use_act_quant = False - # initialize quantizer - - self.act_quantizer = UniformAffineQuantizer(**act_quant_params) - self.activation_function = StraightThrough() - - self.ignore_reconstruction = False - - def set_quant_state(self, - weight_quant: bool = False, - act_quant: bool = False): - # setting weight quantization here does not affect actual forward pass - self.use_weight_quant = weight_quant - self.use_act_quant = act_quant - for m in self.modules(): - if isinstance(m, QuantModule): - m.set_quant_state(weight_quant, act_quant) - - -class QuantBasicBlock(BaseQuantBlock): - """Implementation of Quantized BasicBlock used in ResNet-18 and ResNet-34. - - Parameters - ---------- - basic_block (object): BasicBlock which is to be used - weight_quant_params (dict): Weight parameters - act_quant_params (dict): Activation Parameters - """ - - def __init__( - self, - basic_block: BasicBlock, - weight_quant_params: dict = {}, - act_quant_params: dict = {}, - ): - super().__init__(act_quant_params) - self.conv1 = QuantModule(basic_block.conv1, weight_quant_params, - act_quant_params) - self.conv1.activation_function = basic_block.active - self.conv2 = QuantModule( - basic_block.conv2, - weight_quant_params, - act_quant_params, - disable_act_quant=True, - ) - - # modify the activation function to ReLU - self.activation_function = basic_block.active - - if basic_block.downsample is None: - self.downsample = None - else: - self.downsample = QuantModule( - basic_block.downsample[0], - weight_quant_params, - act_quant_params, - disable_act_quant=True, - ) - # copying all attributes in original block - self.stride = basic_block.stride - - def forward(self, x): - residual = x if self.downsample is None else self.downsample(x) - out = self.conv1(x) - out = self.conv2(out) - out += residual - out = self.activation_function(out) - if self.use_act_quant: - out = self.act_quantizer(out) - return out - - -class QuantBottleneck(BaseQuantBlock): - """ - Implementation of Quantized Bottleneck Block used in ResNet-50, -101 and - -152. 
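The comment above about disabling activation quantization before element-wise operations is the key design point of these block wrappers: the conv feeding a residual add is left unquantized, and the block quantizes once after the add. A minimal runnable sketch of that pattern follows, using a stand-in quantizer (fake_quant_stub is illustrative, not part of the library).

import torch
import torch.nn as nn

def fake_quant_stub(x, n_bits=8):
    # stand-in for the block's act_quantizer: uniform quantize-dequantize over the tensor range
    qmax = 2 ** n_bits - 1
    lo, hi = x.min(), x.max()
    scale = (hi - lo).clamp(min=1e-8) / qmax
    return torch.round((x - lo) / scale).clamp(0, qmax) * scale + lo

def residual_block_forward(conv1, conv2, x):
    # pattern sketched from QuantBasicBlock above: the conv feeding the element-wise add is
    # NOT activation-quantized; the sum is quantized once, after the add.
    out = fake_quant_stub(torch.relu(conv1(x)))   # conv1: activation quant enabled
    out = conv2(out)                              # conv2: activation quant disabled (feeds the add)
    out = torch.relu(out + x)                     # element-wise add with the residual
    return fake_quant_stub(out)                   # single activation quantization after the add

c1, c2 = nn.Conv2d(8, 8, 3, padding=1), nn.Conv2d(8, 8, 3, padding=1)
y = residual_block_forward(c1, c2, torch.randn(1, 8, 16, 16))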
- - Parameters - ---------- - bottleneck (object): Bottleneck to be used - weight_quant_params (dict): Weight parameters - act_quant_params (dict): Activation Parameters - """ - - def __init__( - self, - bottleneck: Bottleneck, - weight_quant_params: dict = {}, - act_quant_params: dict = {}, - ): - super().__init__(act_quant_params) - self.conv1 = QuantModule(bottleneck.conv1, weight_quant_params, - act_quant_params) - self.conv1.activation_function = bottleneck.active - self.conv2 = QuantModule(bottleneck.conv2, weight_quant_params, - act_quant_params) - self.conv2.activation_function = bottleneck.active - self.conv3 = QuantModule( - bottleneck.conv3, - weight_quant_params, - act_quant_params, - disable_act_quant=True, - ) - - # modify the activation function to ReLU - self.activation_function = bottleneck.active - - if bottleneck.downsample is None: - self.downsample = None - else: - self.downsample = QuantModule( - bottleneck.downsample[0], - weight_quant_params, - act_quant_params, - disable_act_quant=True, - ) - # copying all attributes in original block - self.stride = bottleneck.stride - - def forward(self, x): - residual = x if self.downsample is None else self.downsample(x) - out = self.conv1(x) - out = self.conv2(out) - out = self.conv3(out) - out += residual - out = self.activation_function(out) - if self.use_act_quant: - out = self.act_quantizer(out) - return out - - -class QuantInvertedResidual(BaseQuantBlock): - """Implementation of Quantized Inverted Residual Block used in MobileNetV2. - - Inverted Residual does not have activation function. - - Parameters - ---------- - inv_res (object): Inverted Residual block to be used - weight_quant_params (dict): Weight parameters - act_quant_params (dict): Activation Parameters - """ - - def __init__( - self, - inv_res: InvertedResidual, - weight_quant_params: dict = {}, - act_quant_params: dict = {}, - ): - super().__init__(act_quant_params) - self.stride = inv_res.stride - self.inp = inv_res.inp - self.oup = inv_res.oup - self.exp = inv_res.exp - self.conv1 = QuantModule(inv_res.conv1, weight_quant_params, - act_quant_params) - self.conv1.activation_function = nn.ReLU6(inplace=True) - self.conv2 = QuantModule(inv_res.conv2, weight_quant_params, - act_quant_params) - self.conv2.activation_function = nn.ReLU6(inplace=True) - self.conv3 = QuantModule(inv_res.conv3, weight_quant_params, - act_quant_params) - self.shortcut = nn.Sequential() - if self.stride == 1 and self.inp != self.oup: - self.shortcut = nn.Sequential( - QuantModule(inv_res.shortcut[0], weight_quant_params, - act_quant_params)) - # self.use_res_connect = inv_res.use_res_connect - # self.expand_ratio = inv_res.exp - # if self.expand_ratio == 1: - # self.conv = nn.Sequential( - # QuantModule(inv_res.conv[0], weight_quant_params, act_quant_params), - # QuantModule(inv_res.conv[3], weight_quant_params, act_quant_params, disable_act_quant=True), - # ) - # self.conv[0].activation_function = nn.ReLU6() - # else: - # self.conv = nn.Sequential( - # QuantModule(inv_res.conv[0], weight_quant_params, act_quant_params), - # QuantModule(inv_res.conv[3], weight_quant_params, act_quant_params), - # QuantModule(inv_res.conv[6], weight_quant_params, act_quant_params, disable_act_quant=True), - # ) - # self.conv[0].activation_function = nn.ReLU6() - # self.conv[1].activation_function = nn.ReLU6() - - def forward(self, x): - out = self.conv1(x) - out = self.conv2(out) - out = self.conv3(out) - out = out + self.shortcut(x) if self.stride == 1 else out - return out - # if 
self.use_res_connect: - # out = x + self.conv(x) - # else: - # out = self.conv(x) - # out = self.activation_function(out) - # if self.use_act_quant: - # out = self.act_quantizer(out) - # return out - - -# =============================================== -# ***** Quantization Modules for BitSplit ******* -# =============================================== -""" -Supported quantization wrappers for pytorch modules :- - - BasicBlock(nn.Module) -> QBasicBlock(nn.Module) - - Bottleneck(nn.Module) -> QBottleneck(nn.Module) - - InvertedResidual(nn.Module) -> QInvertedResidual(nn.Module) -""" - - -class QBasicBlock(nn.Module): - """ - Parameters - ---------- - basic_block (object): BasicBlock which is to be used - """ - - expansion = 1 - - def __init__(self, basic_block: BasicBlock): - super().__init__() - self.quant1 = ActQuantizer() - self.conv1 = basic_block.conv1 - self.bn1 = basic_block.bn1 - self.active = basic_block.active - self.quant2 = ActQuantizer() - self.conv2 = basic_block.conv2 - self.bn2 = basic_block.bn2 - self.downsample = basic_block.downsample - self.stride = basic_block.stride - - def forward(self, x): - residual = x - x = self.quant1(x) - out = self.active(self.bn1(self.conv1(x))) - out = self.quant2(out) - out = self.bn2(self.conv2(out)) - if self.downsample is not None: - residual = self.downsample(x) - out += residual - out = self.active(out) - return out - - -class QBottleneck(nn.Module): - """ - Parameters - ---------- - bottleneck (object): Bottleneck to be used - """ - - expansion = 4 - - def __init__(self, bottleneck: Bottleneck): - super().__init__() - self.quant1 = ActQuantizer() - self.conv1 = bottleneck.conv1 - self.bn1 = bottleneck.bn1 - self.quant2 = ActQuantizer() - self.conv2 = bottleneck.conv2 - self.bn2 = bottleneck.bn2 - self.quant3 = ActQuantizer() - self.conv3 = bottleneck.conv3 - self.bn3 = bottleneck.bn3 - self.active = bottleneck.active - self.downsample = bottleneck.downsample - self.stride = bottleneck.stride - - def forward(self, x): - residual = x - x = self.quant1(x) - out = self.active(self.bn1(self.conv1(x))) - out = self.quant2(out) - out = self.active(self.bn2(self.conv2(out))) - out = self.quant3(out) - out = self.bn3(self.conv3(out)) - if self.downsample is not None: - residual = self.downsample(x) - out += residual - out = self.active(out) - return out - - -class QInvertedResidual(nn.Module): - """ - Parameters - ---------- - inv_res (object): Inverted Residual block to be used - """ - - def __init__(self, inv_res: InvertedResidual): - super().__init__() - self.stride = inv_res.stride - self.inp = inv_res.inp - self.oup = inv_res.oup - self.exp = inv_res.exp - self.quant1 = ActQuantizer(islinear=1) - self.conv1 = inv_res.conv1 - self.bn1 = inv_res.bn1 - self.quant2 = ActQuantizer(islinear=1) - self.conv2 = inv_res.conv2 - self.bn2 = inv_res.bn2 - self.quant3 = ActQuantizer(islinear=0) - self.conv3 = inv_res.conv3 - self.bn3 = inv_res.bn3 - self.shortcut = inv_res.shortcut - - def forward(self, x): - x = self.quant1(x) - out = F.relu(self.bn1(self.conv1(x))) - out = self.quant2(out) - out = F.relu(self.bn2(self.conv2(out))) - out = self.quant3(out) - out = self.bn3(self.conv3(out)) - out = out + self.shortcut(x) if self.stride == 1 else out - return out - - -# =========================================== -# ***** Quantization Modules for LAPQ ******* -# =========================================== -""" -Supported quantization wrappers for pytorch modules :- - - nn.ReLU, nn.ReLU6 -> ActivationModuleWrapper(nn.Module) - - nn.Conv2d, nn.Linear -> 
ParameterModuleWrapper(nn.Module) -""" - -quantization_mapping = { - 'max_static': MaxAbsStaticQuantization, - 'lp_norm': LpNormQuantization, -} - - -def is_positive(module): - return isinstance(module, nn.ReLU) or isinstance(module, nn.ReLU6) - - -class ActivationModuleWrapper(nn.Module): - """ - Parameters - ---------- - name (str): Name of the wrapped module - wrapped_module (object): Module to be used - kwargs (object): A yaml safe loaded file with information like bits_out and qtype - """ - - def __init__(self, name, wrapped_module, **kwargs): - super(ActivationModuleWrapper, self).__init__() - self.name = name - self.wrapped_module = wrapped_module - self.bits_out = kwargs['bits_out'] - self.qtype = kwargs['qtype'] - self.post_relu = True - self.enabled = True - self.active = True - if self.bits_out is not None: - self.out_quantization = self.out_quantization_default = None - - def __init_out_quantization__(tensor): - self.out_quantization_default = quantization_mapping[ - self.qtype]( - self, - tensor, - self.bits_out, - symmetric=(not is_positive(wrapped_module)), - uint=True, - kwargs=kwargs, - ) - self.out_quantization = self.out_quantization_default - - self.out_quantization_init_fn = __init_out_quantization__ - - def __enabled__(self): - return self.enabled and self.active and self.bits_out is not None - - def forward(self, *input): - if self.post_relu: - out = self.wrapped_module(*input) - # Quantize output - if self.__enabled__(): - self.verify_initialized(self.out_quantization, out, - self.out_quantization_init_fn) - out = self.out_quantization(out) - else: - # Quantize output - if self.__enabled__(): - self.verify_initialized(self.out_quantization, *input, - self.out_quantization_init_fn) - out = self.out_quantization(*input) - else: - out = self.wrapped_module(*input) - return out - - @staticmethod - def verify_initialized(quantization_handle, tensor, init_fn): - if quantization_handle is None: - init_fn(tensor) - - def get_quantization(self): - return self.out_quantization - - def set_quantization(self, qtype, kwargs): - self.out_quantization = qtype( - self, - self.bits_out, - symmetric=(not is_positive(self.wrapped_module)), - uint=True, - kwargs=kwargs, - ) - - -class ParameterModuleWrapper(nn.Module): - """ - Parameters - ---------- - name (str): Name of the wrapped module - wrapped_module (object): Module to be used - kwargs (object): A yaml safe loaded file with information like bits_out, qtype, forward functor, bit_weight, etc. 
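The is_positive check above decides whether an activation quantizer can drop the sign: outputs of ReLU/ReLU6 are non-negative, so the whole unsigned grid can be spent on [0, max]. A small illustration of the effect (the numbers are only for intuition, not library behaviour):

import torch

def qparams(t_min, t_max, n_bits=8, unsigned=True):
    # unsigned grid: [0, 2^b - 1]; signed grid: [-(2^(b-1)), 2^(b-1) - 1]
    if unsigned:
        qmin, qmax = 0, 2 ** n_bits - 1
    else:
        qmin, qmax = -(2 ** (n_bits - 1)), 2 ** (n_bits - 1) - 1
    scale = (t_max - t_min) / (qmax - qmin)
    return scale, qmin, qmax

post_relu = torch.relu(torch.randn(10_000))
print(qparams(0.0, post_relu.max().item(), unsigned=True))
# finer scale: all 256 levels cover [0, max]
print(qparams(-post_relu.max().item(), post_relu.max().item(), unsigned=False))
# coarser scale: half the levels sit on x < 0, which a post-ReLU tensor never uses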
- """ - - def __init__(self, name, wrapped_module, **kwargs): - super(ParameterModuleWrapper, self).__init__() - self.name = name - self.wrapped_module = wrapped_module - self.forward_functor = kwargs['forward_functor'] - self.bit_weights = kwargs['bit_weights'] - self.bits_out = kwargs['bits_out'] - self.qtype = kwargs['qtype'] - self.bcorr_w = kwargs['bcorr_w'] - self.bn = kwargs['bn'] if 'bn' in kwargs else None - self.enabled = True - self.active = True - self.centroids_hist = {} - self.log_weights_hist = False - self.log_weights_mse = False - self.log_clustering = False - self.dynamic_weight_quantization = True - setattr(self, 'weight', wrapped_module.weight) - delattr(wrapped_module, 'weight') - if hasattr(wrapped_module, 'bias'): - setattr(self, 'bias', wrapped_module.bias) - delattr(wrapped_module, 'bias') - if self.bit_weights is not None: - self.weight_quantization_default = quantization_mapping[ - self.qtype]( - self, - self.weight, - self.bit_weights, - symmetric=True, - uint=True, - kwargs=kwargs, - ) - self.weight_quantization = self.weight_quantization_default - if not self.dynamic_weight_quantization: - self.weight_q = self.weight_quantization(self.weight) - self.weight_mse = torch.mean( - (self.weight_q - self.weight)**2).item() - - def __enabled__(self): - return self.enabled and self.active and self.bit_weights is not None - - def bias_corr(self, x, xq): - bias_q = xq.view(xq.shape[0], -1).mean(-1) - bias_orig = x.view(x.shape[0], -1).mean(-1) - bcorr = bias_q - bias_orig - return (xq - bcorr.view(bcorr.numel(), 1, 1, 1) - if len(x.shape) == 4 else xq - bcorr.view(bcorr.numel(), 1)) - - def forward(self, *input): - w = self.weight - if self.__enabled__(): - # Quantize weights - if self.dynamic_weight_quantization: - w = self.weight_quantization(self.weight) - if self.bcorr_w: - w = self.bias_corr(self.weight, w) - else: - w = self.weight_q - out = self.forward_functor( - *input, - weight=w, - bias=(self.bias if hasattr(self, 'bias') else None)) - return out - - def get_quantization(self): - return self.weight_quantization - - def set_quantization(self, qtype, kwargs): - self.weight_quantization = qtype(self, - self.bit_weights, - symmetric=True, - uint=True, - kwargs=kwargs) diff --git a/trailmet/algorithms/quantize/quantize.py b/trailmet/algorithms/quantize/quantize.py index 7579fc1..876aa55 100644 --- a/trailmet/algorithms/quantize/quantize.py +++ b/trailmet/algorithms/quantize/quantize.py @@ -19,142 +19,311 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. 
+ +import copy import torch import torch.nn as nn -import torch.nn.init as init -from tqdm import tqdm_notebook -from ..algorithms import BaseAlgorithm +import torch.nn.functional as F +import torch.ao.nn.intrinsic as nni +import torch.ao.nn.quantized as nnq +import torch.ao.nn.intrinsic.quantized as nniq +from tqdm import tqdm +from typing import Union, Callable, Dict +from trailmet.models.resnet import BasicBlock, Bottleneck +from trailmet.models.mobilenet import InvertedResidual +from trailmet.algorithms.quantize.utils import StopForwardException, DataSaverHook, \ + GradSaverHook, LinearTempDecay, Node, GraphPlotter, replace_activation_with_identity, \ + get_qscheme, get_dtype, quantized_forward +from trailmet.algorithms.quantize.modules import StraightThrough, QuantModule, \ + BaseQuantBlock, QuantBasicBlock, QuantBottleneck, QuantInvertedResidual +from trailmet.algorithms.quantize._methods import BaseQuantizer, UniformQuantizer +from trailmet.algorithms.algorithms import BaseAlgorithm +from torch.nn.utils.parametrize import type_before_parametrizations as _type +from torch.ao.nn.intrinsic.modules.fused import _FusedModule +from torch.ao.quantization import QConfig, FixedQParamsObserver +from torch.ao.quantization.stubs import QuantStub, DeQuantStub +from torch.ao.quantization.fuse_modules import fuse_modules + __all__ = [ + 'BaseQuantModel', 'BaseQuantization', - 'StraightThrough', - 'RoundSTE', - 'Conv2dFunctor', - 'LinearFunctor', - 'FoldBN', + 'BaseQuantLoss', + 'GetLayerInpOut', + 'GetLayerGrad' ] - -class BaseQuantization(BaseAlgorithm): - """Base class for quantization algorithms.""" - - def __init__(self, **kwargs): - super(BaseQuantization, self).__init__(**kwargs) - pass - - def quantize(self, model, dataloaders, method, **kwargs): - pass - - def round_ste(x: torch.Tensor): - """Implement Straight-Through Estimator for rounding operation.""" - return (x.round() - x).detach() + x - - def get_calib_samples(self, train_loader, num_samples): - """Get calibration-set samples for finetuning weights and clipping - parameters.""" - calib_data = [] - for batch in train_loader: - calib_data.append(batch[0]) - if len(calib_data) * batch[0].size(0) >= num_samples: - break - return torch.cat(calib_data, dim=0)[:num_samples] - - def absorb_bn(self, module, bn_module): - w = module.weight.data - if module.bias is None: - zeros = torch.Tensor(module.out_channels).zero_().type(w.type()) - module.bias = nn.Parameter(zeros) - b = module.bias.data - invstd = bn_module.running_var.clone().add_(bn_module.eps).pow_(-0.5) - w.mul_(invstd.view(w.size(0), 1, 1, 1).expand_as(w)) - b.add_(-bn_module.running_mean).mul_(invstd) - - if bn_module.affine: - w.mul_(bn_module.weight.data.view(w.size(0), 1, 1, 1).expand_as(w)) - b.mul_(bn_module.weight.data).add_(bn_module.bias.data) - - bn_module.register_buffer('running_mean', - torch.zeros(module.out_channels).cuda()) - bn_module.register_buffer('running_var', - torch.ones(module.out_channels).cuda()) - bn_module.register_parameter('weight', None) - bn_module.register_parameter('bias', None) - bn_module.affine = False - - def is_bn(self, m): - return isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d) - - def is_absorbing(self, m): - return (isinstance(m, nn.Conv2d) and m.groups == 1) or isinstance( - m, nn.Linear) - - def search_absorbe_bn(self, model): - prev = None - for m in model.children(): - if self.is_bn(m) and self.is_absorbing(prev): - m.absorbed = True - self.absorb_bn(prev, m) - self.search_absorbe_bn(m) - prev = m - - -class 
StraightThrough(nn.Module): - """Used to place an identity function in place of a non-differentail - operator for gradient calculation.""" - - def __int__(self): +FAKE_QUANT_MAPPING: Dict[Callable, Callable] = { + nn.Conv2d : QuantModule, + nn.Linear : QuantModule, + nni.ConvReLU2d : QuantModule, + BasicBlock : QuantBasicBlock, + Bottleneck : QuantBottleneck, + InvertedResidual : QuantInvertedResidual +} + +TRUE_QUANT_MAPPING: Dict[Callable, Callable] = { + QuantStub : nnq.Quantize, + DeQuantStub : nnq.DeQuantize, + nn.Conv2d : nnq.Conv2d, + nn.Linear : nnq.Linear, + nni.ConvReLU2d : nniq.ConvReLU2d, + nnq.FloatFunctional : nnq.QFunctional +} + + +class BaseQuantModel(nn.Module): + """base model wrapping class for quantization algorithms""" + def __init__(self, model: nn.Module, weight_quant_params: dict = {}, + act_quant_params: dict = {}, inplace = False, fuse_model=True): super().__init__() - pass - - def forward(self, input): - return input - - -class RoundSTE(torch.autograd.Function): - - @staticmethod - def forward(ctx, input): - output = torch.round(input) - return output - - @staticmethod - def backward(ctx, grad_output): - return grad_output - - -class Conv2dFunctor: - - def __init__(self, conv2d): - self.conv2d = conv2d - - def __call__(self, *input, weight, bias): - res = torch.nn.functional.conv2d(*input, weight, bias, - self.conv2d.stride, - self.conv2d.padding, - self.conv2d.dilation, - self.conv2d.groups) - return res - - -class LinearFunctor: - - def __init__(self, linear): - self.linear = linear - - def __call__(self, *input, weight, bias): - res = torch.nn.functional.linear(*input, weight, bias) - return res - - -# TODO : To migrate all BN-layer folding function calls to the ones defined inside BaseQuantization class -class FoldBN: - """Used to fold batch norm to prev linear or conv layer which helps reduce - comutational overhead during quantization.""" - - def __init__(self): - pass + if not inplace: + self.model = copy.deepcopy(model) + else: + self.model = model + self.weight_quant_params = weight_quant_params + self.act_quant_params = act_quant_params + self.quant_modules = None + setattr(self.model, 'inp_quant', StraightThrough()) + setattr(self.model, 'out_dequant', StraightThrough()) + self.model.eval() + self.convert_model_to_fake_quantized(fuse_model) + + def forward(self, x): + return self.model.forward(x) + + def convert_model_to_fake_quantized(self, fuse_model=True): + if not fuse_model: + self.search_fold_conv_bn(self.model) # Do Not Use + else: + replace_activation_with_identity(self.model, [nn.ReLU, nn.ReLU6]) + self.add_fused_conv_bn_act(self.model) + + self.input_quantizer: BaseQuantizer = self.act_quant_params.get( + 'method', UniformQuantizer)(self.act_quant_params) + self.model.inp_quant = self.input_quantizer + self._quant_module_refactor(self.model) + self.model.forward = quantized_forward.__get__(self.model, nn.Module) #TODO check functionality + + def add_fused_conv_bn_act(self, model: nn.Module): + # same functionality as torchvision.models.quantization.resnet.QuantizableResNet.fuse_model + setattr(model, "relu1", nn.ReLU(inplace=True)) + model.relu1.eval() + fuse_modules(model, ["conv1", "bn1", "relu1"], inplace=True) + for module in model.modules(): + if type(module) is BasicBlock: + setattr(module, "relu1", nn.ReLU(inplace=True)) + module.relu1.eval() + fuse_modules( + module, + [["conv1", "bn1", "relu1"], ["conv2", "bn2"]], + inplace=True + ) + if module.downsample is not None: + fuse_modules(module.downsample, ["0", "1"], inplace=True) + if 
type(module) is Bottleneck: + setattr(module, "relu1", nn.ReLU(inplace=True)) + setattr(module, "relu2", nn.ReLU(inplace=True)) + module.relu1.eval() + module.relu2.eval() + fuse_modules( + module, + [["conv1", "bn1", "relu1"], ["conv2", "bn2", "relu2"], ["conv3", "bn3"]], + inplace=True + ) + if module.downsample is not None: + fuse_modules(module.downsample, ["0", "1"], inplace=True) + + def _quant_module_refactor(self, module: nn.Module): + """ + Recursively replace Conv2d and Linear layers with QuantModule and other + supported network blocks to their respective wrappers, to enable weight + and activations quantization. + """ + for name, child_module in module.named_children(): + if type(child_module) in FAKE_QUANT_MAPPING: + setattr(module, name, FAKE_QUANT_MAPPING[type(child_module)]( + child_module, self.weight_quant_params, self.act_quant_params + )) + elif isinstance(child_module, (StraightThrough, nn.Identity, nn.ReLU)): + continue + else: + self._quant_module_refactor(child_module) + + def convert_model_to_quantized(self, inplace=True, remove_qconfig=True): + model = self.model + if not inplace: + model = copy.deepcopy(model) + + model.to(torch.device('cpu')) + model.inp_quant = QuantStub(qconfig=QConfig(weight = None, + activation = self.input_quantizer.observer)) + model.inp_quant.add_module('activation_post_process', + model.inp_quant.qconfig.activation()) + model.out_dequant = DeQuantStub() + + self._attach_qconfig_to_quantizable(model) + model = self._convert_quantizable(model, TRUE_QUANT_MAPPING, inplace) + if remove_qconfig: + self._remove_qconfig_from_quantizable(model) + return model + + def _convert_quantizable(self, module: nn.Module, mapping: dict, inplace = True): + if not inplace: + module = copy.deepcopy(module) + reassign = dict() + for name, child_module in module.named_children(): + if not isinstance(child_module, _FusedModule): # fused modules are swapped as one unit + self._convert_quantizable(child_module, mapping, True) + if type(child_module) in mapping: + reassign[name] = self._swap_module(child_module, mapping) + for name, quantized_module in reassign.items(): + delattr(module, name) + setattr(module, name, quantized_module) + return module + + def _attach_qconfig_to_quantizable(self, module: nn.Module): + module_attach = dict() + module_reassign = dict() + + for name, child_module in module.named_children(): + if isinstance(child_module, QuantModule): + module_attach[name]['weight'] = child_module.weight_quantizer.observer + module_attach[name]['activation'] = child_module.act_quantizer.observer + module_reassign[name] = child_module.orig_module + if isinstance(child_module, BaseQuantBlock): + child_module._convert_to_quantizable_with_qconfig() + else: + self._attach_qconfig_to_quantizable(child_module) + + for name, orig_module in module_reassign.items(): + delattr(module, name) + setattr(module, name, orig_module) + + for name, observers in module_attach.items(): + submodule = getattr(module, name, None) + assert submodule is not None + if isinstance(submodule, nni.ConvReLU2d): # propagate qconfig + setattr(submodule[0], 'qconfig', QConfig( + weight = observers['weight'], + activation = None + )) + setattr(submodule, 'qconfig', QConfig( + weight = observers['weight'], + activation = observers['activation'] + )) + submodule.add_module('activation_post_process', submodule.qconfig.activation()) + + def _remove_qconfig_from_quantizable(self, module: nn.Module): + for child_module in module.children(): + self._remove_qconfig_from_quantizable(child_module) 
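The fusion performed by add_fused_conv_bn_act above leans on PyTorch's eager-mode fuse_modules, which expects eval mode and named conv/bn/relu children. A minimal standalone illustration with a toy module (the module here is only for demonstration):

import torch
import torch.nn as nn
from torch.ao.quantization import fuse_modules

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(8)
        self.relu1 = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu1(self.bn1(self.conv1(x)))

m = Toy().eval()                          # fusion for post-training quantization expects eval mode
fuse_modules(m, ["conv1", "bn1", "relu1"], inplace=True)
print(type(m.conv1))                      # torch.ao.nn.intrinsic ConvReLU2d, with BN folded into the conv
print(type(m.bn1), type(m.relu1))         # both replaced by nn.Identity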
+ if hasattr(module, 'activation_post_process'): + delattr(module, 'activation_post_process') + if hasattr(module, 'qconfig'): + delattr(module, 'qconfig') + + def _swap_module(self, module: nn.Module, mapping: dict): + new_module = module + swapped = False + if hasattr(module, 'qconfig') and module.qconfig is not None: + swapped = False + if _type(module) in mapping: + qmod = mapping[_type(module)] + new_module = qmod.from_float(module) + swapped = True + # print(f">> swapped {type(module)}: {type(new_module)}") + if swapped: + pass #TODO: hook management + return new_module + + def get_weight_quantizers(self): + weight_quantizers = [] + for module in self.model.modules(): + if isinstance(module, (QuantModule, BaseQuantBlock)): + weight_quantizers.append(module.weight_quantizer) + return weight_quantizers + + def get_act_quantizers(self): + act_quantizers = [] + for module in self.model.modules(): + if isinstance(module, (QuantModule, BaseQuantBlock)): + act_quantizers.append(module.act_quantizer) + return act_quantizers + + def set_observation_state(self, weight_obs: bool = True, act_obs: bool = True): + for module in self.model.modules(): + if isinstance(module, (QuantModule, BaseQuantBlock)): + module.set_observation_state(weight_obs, act_obs) + + def set_quantization_state(self, weight_quant: bool = True, act_quant: bool = True): + """ + :param weight_quant: set True to enable weight quantization + :param act_quant: set True to enable activation quantization + """ + for module in self.model.modules(): + if isinstance(module, (QuantModule, BaseQuantBlock)): + module.set_quantization_state(weight_quant, act_quant) + + def quantize_model_till(self, layer, act_quant: bool = False): + """ + :param layer: layer upto which model is to be quantized. + :param act_quant: set True for activation quantization + """ + # TODO + self.set_quant_state(False, False) + for name, module in self.model.named_modules(): + if isinstance(module, (QuantModule, BaseQuantBlock)): + module.set_quantization_state(True, act_quant) + if module == layer: + break + + def set_layer_precision(self, weight_bits: list, act_bit: int): + """ + :param weight_bits: list of bitwidths for layer weights + :param act_bit: bitwidth for activations + """ + # TODO + quant_modules = [m for m in self.model.modules() if isinstance(m, QuantModule)] + assert len(weight_bits)==len(quant_modules) + for idx, module in enumerate(quant_modules): + module.weight_quantizer.reset_bitwidth(weight_bits[idx]) + if module is not self.quant_modules[-1]: + module.act_quantizer.reset_bitwidth(act_bit) + + def search_fold_conv_bn(self, module: nn.Module): + """ + Recursively search for BatchNorm layers, fold them into the previous + Conv2d or Linear layers and set them as a StraightThrough layer. 
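The fold referred to here is the standard identity W' = W * gamma / sqrt(var + eps) and b' = (b - mu) * gamma / sqrt(var + eps) + beta, applied per output channel. PyTorch ships the same transform as torch.nn.utils.fusion.fuse_conv_bn_eval, which makes for a quick numerical cross-check:

import torch
import torch.nn as nn
from torch.nn.utils.fusion import fuse_conv_bn_eval

conv = nn.Conv2d(3, 8, 3, padding=1).eval()
bn = nn.BatchNorm2d(8).eval()
# give BN non-trivial statistics so the comparison is meaningful
bn.running_mean.uniform_(-1, 1); bn.running_var.uniform_(0.5, 2.0)
bn.weight.data.uniform_(0.5, 1.5); bn.bias.data.uniform_(-1, 1)

fused = fuse_conv_bn_eval(conv, bn)
x = torch.randn(2, 3, 16, 16)
assert torch.allclose(bn(conv(x)), fused(x), atol=1e-5)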
+ """ + prev_module = None + for name, child_module in module.named_children(): + if self._is_bn(child_module) and self._is_absorbing(prev_module): + self._fold_bn_into_conv(prev_module, child_module) + setattr(module, name, StraightThrough()) + elif self._is_absorbing(child_module): + prev_module = child_module + else: + prev_module = self.search_fold_conv_bn(child_module) + return prev_module + + def _is_bn(self, module): + return isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d)) + + def _is_absorbing(self, module): + return isinstance(module, (nn.Conv2d, nn.Linear)) + + def _fold_bn_into_conv(self, conv_module: nn.Conv2d, bn_module: nn.BatchNorm2d): + # same as torch.nn.utils.fusion.fuse_conv_bn_eval + w, b = self._get_folded_params(conv_module, bn_module) + if conv_module.bias is None: + conv_module.bias = nn.Parameter(b) + else: + conv_module.bias.data = b + conv_module.weight.data = w + bn_module.running_mean = bn_module.bias.data + bn_module.running_var = bn_module.weight.data ** 2 - def _fold_bn(self, conv_module, bn_module): + def _get_folded_params(self, conv_module: nn.Conv2d, bn_module: nn.BatchNorm2d): w = conv_module.weight.data y_mean = bn_module.running_mean y_var = bn_module.running_var @@ -176,35 +345,304 @@ def _fold_bn(self, conv_module, bn_module): bias = beta return weight, bias - def fold_bn_into_conv(self, conv_module, bn_module): - w, b = self._fold_bn(conv_module, bn_module) - if conv_module.bias is None: - conv_module.bias = nn.Parameter(b) - else: - conv_module.bias.data = b - conv_module.weight.data = w - # set bn running stats - bn_module.running_mean = bn_module.bias.data - bn_module.running_var = bn_module.weight.data**2 - - def is_bn(self, m): - return isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d) +class BaseQuantization(BaseAlgorithm): + """base class for quantization algorithms""" + def __init__(self, dataloaders, **kwargs): + super(BaseQuantization, self).__init__(**kwargs) + self.train_data = dataloaders['train'] + self.test_data = dataloaders['test'] - def is_absorbing(self, m): - return (isinstance(m, nn.Conv2d)) or isinstance(m, nn.Linear) + def quantize(self, model, method, **kwargs): + pass - def search_fold_and_remove_bn(self, model: nn.Module): - """Method to recursively search for batch norm layers, absorb them into - the previous linear or conv layers, and set it to an identity layer.""" + def get_calib_data(self, num_samples: int, batch_size: int): + """ + Get samples for calibrating quantization parameters + """ + inp, out = [], [] + for batch in self.train_data: + inp.extend(batch[0]) + out.extend(batch[1]) + if len(inp) >= num_samples: + break + batches = [] + for i in range(0, num_samples, batch_size): + batch_inp = inp[i: i+batch_size] + batch_out = out[i: i+batch_size] + batches.append([ + torch.stack(batch_inp, dim=0).to(torch.device('cuda')), + torch.stack(batch_out, dim=0).to(torch.device('cuda')) + ]) + return batches + + def evaluate_loss(self, model, dataloader, device): + criterion = torch.nn.CrossEntropyLoss().to(device) model.eval() - prev = None - for n, m in model.named_children(): - if self.is_bn(m) and self.is_absorbing(prev): - self.fold_bn_into_conv(prev, m) - # set the bn module to straight through - setattr(model, n, StraightThrough()) - elif self.is_absorbing(m): - prev = m - else: - prev = self.search_fold_and_remove_bn(m) - return prev + res = 0 + with torch.no_grad(): + for inputs, targets in dataloader: + outputs = model(inputs) + loss = criterion(outputs, targets) + res += loss.item() + return 
res/len(dataloader) + + def sensitivity_analysis(self, qmodel: BaseQuantModel, dataloader, test_bits, + budget, save_path, exp_name): + qmodel.set_quant_state(False, False) + inputs = None + fp_outputs = None + with torch.no_grad(): + for batch_idx, (inputs, outputs) in enumerate(dataloader): + inputs = inputs.to(self.device) + fp_outputs = qmodel(inputs) + fp_outputs = F.softmax(fp_outputs, dim=1) + break + sensitivities = [[0 for i in range(len(qmodel.quant_modules))] + for j in range(len(test_bits))] + for i, layer in enumerate(qmodel.quant_modules): + for j, bit in enumerate(test_bits): + layer.set_quant_state(True, True) + layer.weight_quantizer.bitwidth_refactor(bit) + layer.weight_quantizer.inited = False + layer.weight_quantizer.scale_method = 'max' + with torch.no_grad(): + tmp_outputs = qmodel(inputs) + tmp_outputs = F.softmax(tmp_outputs, dim=1) + kld = (F.kl_div(tmp_outputs, fp_outputs, reduction='batchmean') + + F.kl_div(fp_outputs, tmp_outputs, reduction='batchmean')) / 2 + sensitivities[j][i] = kld.item() + layer.set_quant_state(False, False) + layer.weight_quantizer.scale_method = 'mse' + + gp = GraphPlotter(save_path+'/logs/plots') + gp.line_plotter(sensitivities, test_bits, '{} bit', f'{exp_name}_layer_sensitivity', + 'layer', 'sensitivity', 'log') + + weight_numels = [qmodule.weight.numel() for qmodule in qmodel.quant_modules] + node_list = self.dp_most_profit_over_cost(sensitivities, len(qmodel.quant_modules), weight_numels, test_bits) + constraint = sum(weight_numels)*32*budget / (8*1024*1024) + good_nodes = [node for node in node_list if node.cost <= constraint] + bits = [] + node = good_nodes[-1] + while(node is not None): + bits.append(node.bit) + node = node.parent + bits.reverse() + bits = bits[1:] + assert len(bits)==len(qmodel.quant_modules) + gp.line_plotter([bits], ['weight bits'], title=f'{exp_name}_layer_precisions', + xlabel='layer', ylabel='bits') + qmodel_size = 0 + for i, layer in enumerate(qmodel.quant_modules): + qmodel_size += layer.weight.numel()*bits[i]/(8*1024*1024) + return bits, qmodel_size, constraint + + def dp_most_profit_over_cost(self, sensitivities, num_layers, weight_numels, bits, constraint=100): + cost = bits + profits = [] + for line in sensitivities: + profits.append([-i for i in line]) + root = Node(cost=0, profit=0, parent=None) + current_list = [root] + for layer_id in range(num_layers): + next_list = [] + for n in current_list: + n.left = Node(n.cost + cost[0]*weight_numels[layer_id]/(8*1024*1024), + n.profit + profits[0][layer_id], + bit = bits[0], parent=n, position='left') + n.middle = Node(n.cost + cost[1]*weight_numels[layer_id]/(8*1024*1024), + n.profit + profits[1][layer_id], + bit = bits[1], parent=n, position='middle') + n.right = Node(n.cost + cost[2]*weight_numels[layer_id]/(8*1024*1024), + n.profit + profits[2][layer_id], + bit = bits[2], parent=n, position='right') + next_list.extend([n.left, n.middle, n.right]) + next_list.sort(key=lambda x: x.cost, reverse=False) + pruned_list = [] + for node in next_list: + if (len(pruned_list)==0 or pruned_list[-1].profit < node.profit) and node.cost <= constraint: + pruned_list.append(node) + else: + node.parent.__dict__[node.position] = None + current_list = pruned_list + return current_list + + +class BaseQuantLoss: + def __init__(self, module: Union[QuantModule, BaseQuantBlock], + round_loss: str = 'relaxation', weight: float = 1., rec_loss: str = 'mse', + max_count: int = 2000, b_range: tuple = (10, 2), decay_start: float = 0.0, + warmup: float = 0.0, p: float = 2.): + + 
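The per-layer sensitivity computed above is a symmetrized KL divergence between the full-precision and quantized output distributions. One way to write that helper from raw logits is sketched below (note that F.kl_div expects log-probabilities as its first argument):

import torch
import torch.nn.functional as F

def symmetric_kl(logits_q: torch.Tensor, logits_fp: torch.Tensor) -> float:
    # F.kl_div(input, target): `input` must be log-probabilities, `target` probabilities
    p_q, p_fp = F.softmax(logits_q, dim=1), F.softmax(logits_fp, dim=1)
    kl_qp = F.kl_div(F.log_softmax(logits_q, dim=1), p_fp, reduction='batchmean')
    kl_pq = F.kl_div(F.log_softmax(logits_fp, dim=1), p_q, reduction='batchmean')
    return (0.5 * (kl_qp + kl_pq)).item()

# e.g. sensitivities[j][i] = symmetric_kl(qmodel(x), fp_model(x)) for layer i at bit-width j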
self.module = module + self.round_loss = round_loss + self.weight = weight + self.rec_loss = rec_loss + self.loss_start = max_count * warmup + self.p = p + self.count = 0 + self.pbar = tqdm(total=max_count) + self.temp_decay = LinearTempDecay(max_count, + rel_start_decay=warmup + (1 - warmup) * decay_start, + start_b=b_range[0], end_b=b_range[1]) + + def __call__(self, pred, tgt, grad=None): + """ + Compute the total loss for adaptive rounding: + rec_loss is the quadratic output reconstruction loss, round_loss is + a regularization term to optimize the rounding policy + :param pred: output from quantized model + :param tgt: output from FP model + :param grad: gradients to compute fisher information + :return: total loss function + """ + self.count += 1 + if self.rec_loss == 'mse': + rec_loss = self.lp_norm(pred, tgt, self.p, reduction='none') + elif self.rec_loss == 'fisher_diag': + rec_loss = self.fisher_diag(pred, tgt, grad) + elif self.rec_loss == 'fisher_full': + rec_loss = self.fisher_full(pred, tgt, grad) + else: + raise ValueError('Not supported reconstruction loss function: {}'.format(self.rec_loss)) + + b = self.temp_decay(self.count) + if self.count < self.loss_start or self.round_loss == 'none': + b = round_loss = 0 + elif self.round_loss == 'relaxation': + round_loss = 0 + if isinstance(self.module, QuantModule): + round_vals = self.module.weight_quantizer.get_soft_targets() + round_loss += self.weight * (1 - ((round_vals - .5).abs() * 2).pow(b)).sum() + if isinstance(self.module, BaseQuantBlock): + for name, submodule in self.module.named_modules(): + if isinstance(submodule, QuantModule): + round_vals = submodule.weight_quantizer.get_soft_targets() + round_loss += self.weight * (1 - ((round_vals - .5).abs() * 2).pow(b)).sum() + else: + raise NotImplementedError + + total_loss = rec_loss + round_loss + if self.count % 100 == 0: + self.pbar.set_postfix(loss=float(total_loss), b=b) + self.pbar.update(1) + return total_loss + + @staticmethod + def lp_norm(pred, tgt, p=2.0, reduction = 'mean'): + if reduction == 'mean': + return (pred-tgt).abs().pow(p).mean() + elif reduction == 'none': + return (pred-tgt).abs().pow(p).sum(1).mean() + else: + raise KeyError + + @staticmethod + def fisher_diag(pred, tgt, grad): + return ((pred - tgt).pow(2) * grad.pow(2)).sum(1).mean() + + @staticmethod + def fisher_full(pred, tgt, grad): + a = (pred - tgt).abs() + grad = grad.abs() + batch_dotprod = torch.sum(a * grad, (1, 2, 3)).view(-1, 1, 1, 1) + return (batch_dotprod * a * grad).mean() / 100 + +class GetLayerInpOut: + """ + Get the input and output of a specified layer in a quantized model. + + :param model: quantized model for which the input and output needs to be extracted. + :param layer: the layer for which input and output needs to be extracted. + :param device: the device on which the computation needs to be performed. + :param asym: save quantized input and full precision output. [default=False] + :param act_quant: use activation quantization. 
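The 'relaxation' branch above is the AdaRound-style regularizer: with soft rounding variables h in [0, 1] it adds weight * sum(1 - |2h - 1|^b), which vanishes once every h has been pushed to 0 or 1, while the temperature b is annealed from b_range[0] down to b_range[1]. A compact illustration of how the term behaves (values are only for intuition):

import torch

def round_reg(h: torch.Tensor, b: float, weight: float = 1.0) -> torch.Tensor:
    # penalizes soft rounding values that sit between the two grid points
    return weight * (1 - ((h - 0.5).abs() * 2).pow(b)).sum()

h_undecided = torch.full((1000,), 0.5)            # every value halfway between the grid points
h_committed = torch.randint(0, 2, (1000,)).float()  # every value already snapped to 0 or 1

for b in (10.0, 6.0, 2.0):                        # high b early (tolerant), low b late (binarizing)
    print(b, round_reg(h_undecided, b).item(), round_reg(h_committed, b).item())
# committed values incur zero penalty; fully undecided values stay at weight * numel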
[default=False] + """ + def __init__(self, model: BaseQuantModel, layer: Union[QuantModule, BaseQuantBlock], + device: torch.device, asym: bool = False, act_quant: bool = False): + self.model = model + self.layer = layer + self.asym = asym + self.device = device + self.act_quant = act_quant + self.data_saver = DataSaverHook(store_input=True, store_output=True, stop_forward=True) + + def __call__(self, model_input): + """ + :param model_input: calibration data samples + :return: tuple of layer input and output + """ + self.model.eval() + self.model.set_quant_state(False, False) + + handle = self.layer.register_forward_hook(self.data_saver) + with torch.no_grad(): + try: + _ = self.model(model_input.to(self.device)) + except StopForwardException: + pass + + if self.asym: + self.data_saver.store_output = False + self.model.set_quant_state(weight_quant=True, act_quant=self.act_quant) + try: + _ = self.model(model_input.to(self.device)) + except StopForwardException: + pass + self.data_saver.store_output = True + + handle.remove() + + self.model.set_quant_state(False, False) + self.layer.set_quant_state(True, self.act_quant) + self.model.train() + + return self.data_saver.input_store[0].detach(), self.data_saver.output_store.detach() + +class GetLayerGrad: + """ + Get the gradient a specified layer in a quantized model. + + :param model: quantized model for which the input and output needs to be extracted. + :param layer: the layer for which input and output needs to be extracted. + :param device: the device on which the computation needs to be performed. + :param asym: if True, save quantized input and full precision output. [default=False] + :param act_quant: use activation quantization. [default=False] + """ + def __init__(self, model: BaseQuantModel, layer: Union[QuantModule, BaseQuantBlock], + device: torch.device, act_quant: bool = False): + self.model = model + self.layer = layer + self.device = device + self.act_quant = act_quant + self.data_saver = GradSaverHook(True) + + def __call__(self, model_input): + """ + Compute the gradients of layer output, note that we compute the + gradient by calculating the KL loss between fp model and quant model + + :param model_input: calibration data samples + :return: gradients for the layer + """ + self.model.eval() + + handle = self.layer.register_backward_hook(self.data_saver) + with torch.enable_grad(): + try: + self.model.zero_grad() + inputs = model_input.to(self.device) + self.model.set_quant_state(False, False) + out_fp = self.model(inputs) + self.model.quantize_model_till(self.layer, self.act_quant) + out_q = self.model(inputs) + loss = F.kl_div(F.log_softmax(out_q, dim=1), F.softmax(out_fp, dim=1), reduction='batchmean') + loss.backward() + except StopForwardException: + pass + + handle.remove() + self.model.set_quant_state(False, False) + self.layer.set_quant_state(True, self.act_quant) + self.model.train() + return self.data_saver.grad_out.data \ No newline at end of file diff --git a/trailmet/algorithms/quantize/reconstruct.py b/trailmet/algorithms/quantize/reconstruct.py deleted file mode 100644 index b5857bd..0000000 --- a/trailmet/algorithms/quantize/reconstruct.py +++ /dev/null @@ -1,834 +0,0 @@ -# MIT License -# -# Copyright (c) 2023 Transmute AI Lab -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, 
distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. -# source: https://github.com/yhhhli/BRECQ/tree/main/quant - -import torch -import torch.distributed as dist -import torch.nn.functional as F -from typing import Union -from tqdm import tqdm -from trailmet.algorithms.quantize.quantize import ( - StraightThrough, - BaseQuantization as BQ, -) -from trailmet.algorithms.quantize.qmodel import QuantModule, BaseQuantBlock -from trailmet.algorithms.quantize.methods import AdaRoundQuantizer -from trailmet.utils import lp_loss - -__all__ = [ - 'StopForwardException', - 'DataSaverHook', - 'GetLayerInpOut', - 'save_inp_oup_data', - 'GradSaverHook', - 'GetLayerGrad', - 'save_grad_data', - 'LinearTempDecay', - 'LayerLossFunction', - 'layer_reconstruction', - 'BlockLossFunction', - 'block_reconstruction', -] - -optimizer_map = { - 'sgd': torch.optim.SGD, - 'adam': torch.optim.Adam, - 'adagrad': torch.optim.Adagrad, - 'adadelta': torch.optim.Adadelta, -} - - -class StopForwardException(Exception): - """Used to throw and catch an exception to stop traversing the graph.""" - - pass - - -class DataSaverHook: - """Forward hook that stores the input and output of a layer. - - Parameters - ---------- - store_input (bool): If True, input of a layer will be saved, default=False - store_output (bool): If True, output of a layer will be saved, default=False - stop_forward (bool): If True, forward prop will be stopped, default=False. - """ - - def __init__(self, - store_input=False, - store_output=False, - stop_forward=False): - self.store_input = store_input - self.store_output = store_output - self.stop_forward = stop_forward - - self.input_store = None - self.output_store = None - - def __call__(self, module, input_batch, output_batch): - if self.store_input: - self.input_store = input_batch - if self.store_output: - self.output_store = output_batch - if self.stop_forward: - raise StopForwardException - - -class GetLayerInpOut: - """Get the input and output of a specified layer in a quantized model. - - model: quantized model for which the input and output needs to be extracted. - layer: the layer for which input and output needs to be extracted. - device: the device on which the computation needs to be performed. - asym: save quantized input and full precision output. [default=False] - act_quant: use activation quantization. 
[default=False] - """ - - def __init__( - self, - model, - layer: Union[QuantModule, BaseQuantBlock], - device: torch.device, - asym: bool = False, - act_quant: bool = False, - ): - self.model = model - self.layer = layer - self.asym = asym - self.device = device - self.act_quant = act_quant - self.data_saver = DataSaverHook(store_input=True, - store_output=True, - stop_forward=True) - - def __call__(self, model_input): - """ - Parameters - ---------- - model_input: calibration data samples - return: tuple of layer input and output - """ - self.model.eval() - self.model.set_quant_state(False, False) - - handle = self.layer.register_forward_hook(self.data_saver) - with torch.no_grad(): - try: - _ = self.model(model_input.to(self.device)) - except StopForwardException: - pass - - if self.asym: - self.data_saver.store_output = False - self.model.set_quant_state(weight_quant=True, - act_quant=self.act_quant) - try: - _ = self.model(model_input.to(self.device)) - except StopForwardException: - pass - self.data_saver.store_output = True - - handle.remove() - - self.model.set_quant_state(False, False) - self.layer.set_quant_state(True, self.act_quant) - self.model.train() - - return ( - self.data_saver.input_store[0].detach(), - self.data_saver.output_store.detach(), - ) - - -def save_inp_oup_data( - model, - layer: Union[QuantModule, BaseQuantBlock], - cali_data: torch.Tensor, - asym: bool = False, - act_quant: bool = False, - batch_size: int = 32, - keep_gpu: bool = True, -): - """Function to save input data and output data of a particular layer/block - over calibration dataset. - - Parameters - ---------- - model: quantized model for which the input and output needs to be extracted. - layer: the layer for which input and output needs to be extracted. - cali_data: calibration dataset - asym: save quantized input and full precision output. [default=False] - act_quant: use activation quantization. [default=False] - batch_size: mini-batch size for calibration. [default=32] - keep_gpu: put saved data on GPU for faster optimization. [default=True] - :return: input and output data - """ - device = next(model.parameters()).device - get_inp_out = GetLayerInpOut(model, - layer, - device=device, - asym=asym, - act_quant=act_quant) - cached_batches = [] - torch.cuda.empty_cache() - - for i in range(int(cali_data.size(0) / batch_size)): - cur_inp, cur_out = get_inp_out(cali_data[i * batch_size:(i + 1) * - batch_size]) - cached_batches.append((cur_inp.cpu(), cur_out.cpu())) - - cached_inps = torch.cat([x[0] for x in cached_batches]) - cached_outs = torch.cat([x[1] for x in cached_batches]) - torch.cuda.empty_cache() - if keep_gpu: - cached_inps = cached_inps.to(device) - cached_outs = cached_outs.to(device) - return cached_inps, cached_outs - - -class GradSaverHook: - """Backward hook that stores the gradients of a layer. - - Parameters - ---------- - store_grad (bool): if True, gradient of the layer will be stored - """ - - def __init__(self, store_grad=True): - self.store_grad = store_grad - self.stop_backward = False - self.grad_out = None - - def __call__(self, module, grad_input, grad_output): - if self.store_grad: - self.grad_out = grad_output[0] - if self.stop_backward: - raise StopForwardException - - -class GetLayerGrad: - """Get the gradient a specified layer in a quantized model. - - Parameters - ---------- - model: quantized model for which the input and output needs to be extracted. - layer: the layer for which input and output needs to be extracted. 
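The DataSaverHook / StopForwardException pattern above generalizes to any module: register a forward hook that stashes the tensors and raises to skip the rest of the network, which is what keeps per-layer calibration cheap. A minimal standalone version (names here are illustrative, not library API):

import torch
import torch.nn as nn

class _Stop(Exception):
    pass

def capture_inp_out(model: nn.Module, layer: nn.Module, x: torch.Tensor):
    saved = {}

    def hook(mod, inp, out):
        saved['inp'], saved['out'] = inp[0].detach(), out.detach()
        raise _Stop  # no need to run the layers after `layer`

    handle = layer.register_forward_hook(hook)
    try:
        with torch.no_grad():
            model(x)
    except _Stop:
        pass
    finally:
        handle.remove()
    return saved['inp'], saved['out']

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 16, 3), nn.ReLU())
inp, out = capture_inp_out(model, model[2], torch.randn(4, 3, 32, 32))
print(inp.shape, out.shape)   # torch.Size([4, 8, 30, 30]) torch.Size([4, 16, 28, 28])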
- device: the device on which the computation needs to be performed. - asym: if True, save quantized input and full precision output. [default=False] - act_quant: use activation quantization. [default=False] - """ - - def __init__( - self, - model, - layer: Union[QuantModule, BaseQuantBlock], - device: torch.device, - act_quant: bool = False, - ): - self.model = model - self.layer = layer - self.device = device - self.act_quant = act_quant - self.data_saver = GradSaverHook(True) - - def __call__(self, model_input): - """Compute the gradients of layer output, note that we compute the - gradient by calculating the KL loss between fp model and quant model. - - Parameters - ---------- - model_input: calibration data samples - :return: gradients for the layer - """ - self.model.eval() - - handle = self.layer.register_backward_hook(self.data_saver) - with torch.enable_grad(): - try: - self.model.zero_grad() - inputs = model_input.to(self.device) - self.model.set_quant_state(False, False) - out_fp = self.model(inputs) - self.model.quantize_model_till(self.layer, self.act_quant) - out_q = self.model(inputs) - loss = F.kl_div( - F.log_softmax(out_q, dim=1), - F.softmax(out_fp, dim=1), - reduction='batchmean', - ) - loss.backward() - except StopForwardException: - pass - - handle.remove() - self.model.set_quant_state(False, False) - self.layer.set_quant_state(True, self.act_quant) - self.model.train() - return self.data_saver.grad_out.data - - -def save_grad_data( - model, - layer: Union[QuantModule, BaseQuantBlock], - cali_data: torch.Tensor, - damping: float = 1.0, - act_quant: bool = False, - batch_size: int = 32, - keep_gpu: bool = True, -): - """Function to save gradient data of a particular layer/block over - calibration dataset. - - Parameters - ---------- - model: quantized model for which the input and output needs to be extracted. - layer: the layer for which input and output needs to be extracted. - cali_data: calibration dataset - damping: damping the second-order gradient by adding some constant in the FIM diagonal - act_quant: use activation quantization. [default=False] - batch_size: mini-batch size for calibration. [default=32] - keep_gpu: put saved data on GPU for faster optimization. [default=True] - :return: gradient data - """ - device = next(model.parameters()).device - get_grad = GetLayerGrad(model, layer, device, act_quant=act_quant) - cached_batches = [] - torch.cuda.empty_cache() - - for i in range(int(cali_data.size(0) / batch_size)): - cur_grad = get_grad(cali_data[i * batch_size:(i + 1) * batch_size]) - cached_batches.append(cur_grad.cpu()) - - cached_grads = torch.cat([x for x in cached_batches]) - cached_grads = cached_grads.abs() + 1.0 - # scaling to make sure its mean is 1 - # cached_grads = cached_grads * torch.sqrt(cached_grads.numel() / cached_grads.pow(2).sum()) - torch.cuda.empty_cache() - if keep_gpu: - cached_grads = cached_grads.to(device) - return cached_grads - - -# ================================ -# ****** Reconstruct Layer ******* -# ================================ - - -class LinearTempDecay: - """Class to implement a linear temperature decay scheduler for a given - maximum time step. - - Parameters - ---------- - t_max: maximum number of time steps to decay temperature over. - rel_start_decay: relative point in time to start the decay from the maximum time step. [default=.2] - start_b: initial temperature value. [default=10] - end_b: final temperature value. 
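To make the temperature schedule above concrete, here is a standalone copy of the decay rule evaluated at a few steps with the defaults used by the reconstruction routines (t_max=2000, b_range=(20, 2)); the printed values appear in the trailing comment:

# Standalone copy of the linear temperature decay, for illustration only.
class LinearTempDecay:
    def __init__(self, t_max, rel_start_decay=0.2, start_b=10, end_b=2):
        self.t_max, self.start_decay = t_max, rel_start_decay * t_max
        self.start_b, self.end_b = start_b, end_b

    def __call__(self, t):
        if t < self.start_decay:
            return self.start_b
        rel_t = (t - self.start_decay) / (self.t_max - self.start_decay)
        return self.end_b + (self.start_b - self.end_b) * max(0.0, 1 - rel_t)

decay = LinearTempDecay(t_max=2000, start_b=20, end_b=2)
print([round(decay(t), 1) for t in (0, 400, 800, 1200, 1600, 2000)])
# -> [20, 20.0, 15.5, 11.0, 6.5, 2.0]: flat during warm-up, then linear down to end_b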
[default=2] - """ - - def __init__( - self, - t_max: int, - rel_start_decay: float = 0.2, - start_b: int = 10, - end_b: int = 2, - ): - self.t_max = t_max - self.start_decay = rel_start_decay * t_max - self.start_b = start_b - self.end_b = end_b - - def __call__(self, t): - """Cosine annealing scheduler for temperature b. - - Parameters - ---------- - t: the current time step - :return: scheduled temperature - """ - if t < self.start_decay: - return self.start_b - else: - rel_t = (t - self.start_decay) / (self.t_max - self.start_decay) - return self.end_b + (self.start_b - self.end_b) * max( - 0.0, (1 - rel_t)) - - -class LayerLossFunction: - """ - Parameters - ---------- - layer (object): layer to be quantized - Round_loss (str): type of regularization term used to optimize rounding policy (options: relaxation, none) - Weight (float): weight of rounding loss in total loss - Rec_loss (str): type of output reconstruction loss (options: mse, fisher_diag, fisher_full) - max_count (int): number of iterations - b_range (tuple): range of rounding relaxation factor (b) with linear temp decay scheduler - decay_start (float): starting point for temp decay of b - warmup (float): fraction of iterations used for warmup before applying rounding loss - p (float): power in lp-norm computation of reconstruction loss - """ - - def __init__( - self, - layer: QuantModule, - round_loss: str = 'relaxation', - weight: float = 1.0, - rec_loss: str = 'mse', - max_count: int = 2000, - b_range: tuple = (10, 2), - decay_start: float = 0.0, - warmup: float = 0.0, - p: float = 2.0, - ): - self.layer = layer - self.round_loss = round_loss - self.weight = weight - self.rec_loss = rec_loss - self.loss_start = max_count * warmup - self.p = p - self.count = 0 - # self.pbar = tqdm(total=max_count) - self.pbar = tqdm( - total=max_count, - desc='Reconstructing Layer: Loss (X.X) b (X)', - bar_format='{l_bar}{r_bar}', - dynamic_ncols=True, - ) - self.temp_decay = LinearTempDecay( - max_count, - rel_start_decay=warmup + (1 - warmup) * decay_start, - start_b=b_range[0], - end_b=b_range[1], - ) - - def __call__(self, pred, tgt, grad=None): - """ - Compute the total loss for adaptive rounding: - rec_loss is the quadratic output reconstruction loss, round_loss is - a regularization term to optimize the rounding policy - - Parameters - ---------- - pred: output from quantized model - tgt: output from FP model - grad: gradients to compute fisher information - :return: total loss function - """ - self.count += 1 - if self.rec_loss == 'mse': - rec_loss = lp_loss(pred, tgt, p=self.p) - elif self.rec_loss == 'fisher_diag': - rec_loss = ((pred - tgt).pow(2) * grad.pow(2)).sum(1).mean() - elif self.rec_loss == 'fisher_full': - a = (pred - tgt).abs() - grad = grad.abs() - batch_dotprod = torch.sum(a * grad, (1, 2, 3)).view(-1, 1, 1, 1) - rec_loss = (batch_dotprod * a * grad).mean() / 100 - else: - raise ValueError( - 'Not supported reconstruction loss function: {}'.format( - self.rec_loss)) - - b = self.temp_decay(self.count) - if self.count < self.loss_start or self.round_loss == 'none': - b = round_loss = 0 - elif self.round_loss == 'relaxation': - round_loss = 0 - round_vals = self.layer.weight_quantizer.get_soft_targets() - round_loss += (self.weight * - (1 - ((round_vals - 0.5).abs() * 2).pow(b)).sum()) - else: - raise NotImplementedError - - total_loss = rec_loss + round_loss - - if self.count % 100 == 0: - self.pbar.set_description( - 'Reconstructing Layer: Loss ({:.3f}) b ({:.1f})'.format( - float(total_loss), b)) - # 
self.pbar.set_postfix(loss=float(total_loss), b=b) - self.pbar.update(1) - return total_loss - - -def layer_reconstruction( - model, - layer: QuantModule, - cali_data: torch.Tensor, - batch_size: int = 32, - iters: int = 20000, - weight: float = 0.001, - opt_mode: str = 'mse', - asym: bool = False, - include_act_func: bool = True, - b_range: tuple = (20, 2), - warmup: float = 0.0, - act_quant: bool = False, - lr: float = 4e-5, - p: float = 2.0, - multi_gpu: bool = False, - optim='adam', -): - """Block reconstruction to optimize the output from each layer. - - Parameters - ---------- - model: QuantModel - layer: QuantModule that needs to be optimized - cali_data: data for calibration, typically 1024 training images, as described in AdaRound - batch_size: mini-batch size for reconstruction - iters: optimization iterations for reconstruction, - weight: the weight of rounding regularization term - opt_mode: optimization mode - asym: asymmetric optimization designed in AdaRound, use quant input to reconstruct fp output - include_act_func: optimize the output after activation function - b_range: temperature range - warmup: proportion of iterations that no scheduling for temperature - act_quant: use activation quantization or not. - lr: learning rate for act delta learning - p: L_p norm minimization - multi_gpu: use multi-GPU or not, if enabled, we should sync the gradients - """ - - model.set_quant_state(False, False) - layer.set_quant_state(True, act_quant) - round_mode = 'learned_hard_sigmoid' - - if not include_act_func: - org_act_func = layer.activation_function - layer.activation_function = StraightThrough() - - if not act_quant: - # Replace weight quantizer to AdaRoundQuantizer - layer.weight_quantizer = AdaRoundQuantizer( - uaq=layer.weight_quantizer, - round_mode=round_mode, - weight_tensor=layer.org_weight.data, - ) - layer.weight_quantizer.soft_targets = True - - # Set up optimizer - opt_params = [layer.weight_quantizer.alpha] - optimizer = optimizer_map[optim](opt_params) - scheduler = None - else: - # Use UniformAffineQuantizer to learn delta - opt_params = [layer.act_quantizer.delta] - optimizer = optimizer_map[optim](opt_params, lr=lr) - scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, - T_max=iters, - eta_min=0.0) - - loss_mode = 'none' if act_quant else 'relaxation' - rec_loss = opt_mode - - loss_func = LayerLossFunction( - layer, - round_loss=loss_mode, - weight=weight, - max_count=iters, - rec_loss=rec_loss, - b_range=b_range, - decay_start=0, - warmup=warmup, - p=p, - ) - - # Save data before optimizing the rounding - cached_inps, cached_outs = save_inp_oup_data(model, layer, cali_data, asym, - act_quant, batch_size) - if opt_mode != 'mse': - cached_grads = save_grad_data(model, - layer, - cali_data, - act_quant, - batch_size=batch_size) - else: - cached_grads = None - device = 'cuda' - for i in range(iters): - idx = torch.randperm(cached_inps.size(0))[:batch_size] - cur_inp = cached_inps[idx] - cur_out = cached_outs[idx] - cur_grad = cached_grads[idx] if opt_mode != 'mse' else None - - optimizer.zero_grad() - out_quant = layer(cur_inp) - - err = loss_func(out_quant, cur_out, cur_grad) - err.backward(retain_graph=True) - if multi_gpu: - for p in opt_params: - dist.all_reduce(p.grad) - optimizer.step() - if scheduler: - scheduler.step() - - torch.cuda.empty_cache() - - # Finish optimization, use hard rounding. 
- layer.weight_quantizer.soft_targets = False - - # Reset original activation function - if not include_act_func: - layer.activation_function = org_act_func - - -# ================================= -# ******* Reconstruct Block ******* -# ================================= - - -class BlockLossFunction: - """ - Parameters - ---------- - Module (object): module or block being quantized - Round_loss (str): type of regularization term used to optimize rounding policy (options: relaxation, none) - Weight (float): weight of rounding loss in total loss - Rec_loss (str): type of output reconstruction loss (options: mse, fisher_diag, fisher_full) - max_count (int): number of iterations - b_range (tuple): range of rounding relaxation factor (b) with linear temp decay scheduler - decay_start (float): starting point for temp decay of b - warmup (float): fraction of iterations used for warmup before applying rounding loss - p (float): power in lp-norm computation of reconstruction loss - """ - - def __init__( - self, - block: BaseQuantBlock, - round_loss: str = 'relaxation', - weight: float = 1.0, - rec_loss: str = 'mse', - max_count: int = 2000, - b_range: tuple = (10, 2), - decay_start: float = 0.0, - warmup: float = 0.0, - p: float = 2.0, - ): - self.block = block - self.round_loss = round_loss - self.weight = weight - self.rec_loss = rec_loss - self.loss_start = max_count * warmup - self.p = p - self.count = 0 - # self.pbar = tqdm(total=max_count) - self.pbar = tqdm( - total=max_count, - desc='Reconstructing Block: Loss (X.X) b (X)', - bar_format='{l_bar}{r_bar}', - dynamic_ncols=True, - ) - self.temp_decay = LinearTempDecay( - max_count, - rel_start_decay=warmup + (1 - warmup) * decay_start, - start_b=b_range[0], - end_b=b_range[1], - ) - - def __call__(self, pred, tgt, grad=None): - """ - Compute the total loss for adaptive rounding: - rec_loss is the quadratic output reconstruction loss, round_loss is - a regularization term to optimize the rounding policy - - Parameters - ---------- - pred: output from quantized model - tgt: output from FP model - grad: gradients to compute fisher information - :return: total loss function - """ - self.count += 1 - if self.rec_loss == 'mse': - rec_loss = lp_loss(pred, tgt, p=self.p) - elif self.rec_loss == 'fisher_diag': - rec_loss = ((pred - tgt).pow(2) * grad.pow(2)).sum(1).mean() - elif self.rec_loss == 'fisher_full': - a = (pred - tgt).abs() - grad = grad.abs() - batch_dotprod = torch.sum(a * grad, (1, 2, 3)).view(-1, 1, 1, 1) - rec_loss = (batch_dotprod * a * grad).mean() / 100 - else: - raise ValueError( - 'Not supported reconstruction loss function: {}'.format( - self.rec_loss)) - - b = self.temp_decay(self.count) - if self.count < self.loss_start or self.round_loss == 'none': - b = round_loss = 0 - elif self.round_loss == 'relaxation': - round_loss = 0 - for name, module in self.block.named_modules(): - if isinstance(module, QuantModule): - round_vals = module.weight_quantizer.get_soft_targets() - round_loss += (self.weight * (1 - ( - (round_vals - 0.5).abs() * 2).pow(b)).sum()) - else: - raise NotImplementedError - - total_loss = rec_loss + round_loss - if self.count % 100 == 0: - self.pbar.set_description( - 'Reconstructing Block: Loss ({:.3f}) b ({:.1f})'.format( - float(total_loss), b)) - # self.pbar.set_postfix(loss=float(total_loss), b=b) - self.pbar.update(1) - return total_loss - - -def block_reconstruction( - model, - block: BaseQuantBlock, - cali_data: torch.Tensor, - batch_size: int = 32, - iters: int = 20000, - weight: float = 0.01, - 
opt_mode: str = 'mse', - asym: bool = False, - include_act_func: bool = True, - b_range: tuple = (20, 2), - warmup: float = 0.0, - act_quant: bool = False, - lr: float = 4e-5, - p: float = 2.0, - multi_gpu: bool = False, - optim='adam', -): - """Block reconstruction to optimize the output from each block. - - Parameters - ---------- - model: QuantModel - block: BaseQuantBlock that needs to be optimized - cali_data: data for calibration, typically 1024 training images, as described in AdaRound - batch_size: mini-batch size for reconstruction - iters: optimization iterations for reconstruction, - weight: the weight of rounding regularization term - opt_mode: optimization mode - asym: asymmetric optimization designed in AdaRound, use quant input to reconstruct fp output - include_act_func: optimize the output after activation function - b_range: temperature range - warmup: proportion of iterations that no scheduling for temperature - act_quant: use activation quantization or not. - lr: learning rate for act delta learning - p: L_p norm minimization - multi_gpu: use multi-GPU or not, if enabled, we should sync the gradients - """ - model.set_quant_state(False, False) - block.set_quant_state(True, act_quant) - round_mode = 'learned_hard_sigmoid' - - if not include_act_func: - org_act_func = block.activation_function - block.activation_function = StraightThrough() - - if not act_quant: - # Replace weight quantizer to AdaRoundQuantizer - for name, module in block.named_modules(): - if isinstance(module, QuantModule): - module.weight_quantizer = AdaRoundQuantizer( - uaq=module.weight_quantizer, - round_mode=round_mode, - weight_tensor=module.org_weight.data, - ) - module.weight_quantizer.soft_targets = True - - # Set up optimizer - opt_params = [] - for name, module in block.named_modules(): - if isinstance(module, QuantModule): - opt_params += [module.weight_quantizer.alpha] - optimizer = optimizer_map[optim](opt_params) - scheduler = None - else: - # Use UniformAffineQuantizer to learn delta - if hasattr(block.act_quantizer, 'delta'): - opt_params = [block.act_quantizer.delta] - else: - opt_params = [] - for name, module in block.named_modules(): - if isinstance(module, QuantModule): - if module.act_quantizer.delta is not None: - opt_params += [module.act_quantizer.delta] - optimizer = optimizer_map[optim](opt_params, lr=lr) - scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, - T_max=iters, - eta_min=0.0) - - loss_mode = 'none' if act_quant else 'relaxation' - rec_loss = opt_mode - - loss_func = BlockLossFunction( - block, - round_loss=loss_mode, - weight=weight, - max_count=iters, - rec_loss=rec_loss, - b_range=b_range, - decay_start=0, - warmup=warmup, - p=p, - ) - - # Save data before optimizing the rounding - cached_inps, cached_outs = save_inp_oup_data(model, block, cali_data, asym, - act_quant, batch_size) - if opt_mode != 'mse': - cached_grads = save_grad_data(model, - block, - cali_data, - act_quant, - batch_size=batch_size) - else: - cached_grads = None - device = 'cuda' - for i in range(iters): - idx = torch.randperm(cached_inps.size(0))[:batch_size] - cur_inp = cached_inps[idx].to(device) - cur_out = cached_outs[idx].to(device) - cur_grad = cached_grads[idx].to(device) if opt_mode != 'mse' else None - - optimizer.zero_grad() - out_quant = block(cur_inp) - - err = loss_func(out_quant, cur_out, cur_grad) - err.backward(retain_graph=True) - if multi_gpu: - for p in opt_params: - dist.all_reduce(p.grad) - optimizer.step() - if scheduler: - scheduler.step() - - 
torch.cuda.empty_cache()
-
-    # Finish optimization, use hard rounding.
-    for name, module in block.named_modules():
-        if isinstance(module, QuantModule):
-            module.weight_quantizer.soft_targets = False
-
-    # Reset original activation function
-    if not include_act_func:
-        block.activation_function = org_act_func
diff --git a/trailmet/algorithms/quantize/utils.py b/trailmet/algorithms/quantize/utils.py
new file mode 100644
index 0000000..a82e694
--- /dev/null
+++ b/trailmet/algorithms/quantize/utils.py
@@ -0,0 +1,214 @@
+# MIT License
+#
+# Copyright (c) 2023 Transmute AI Lab
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+import os
+import torch
+from plotly import graph_objects
+
+__all__ = [
+    'get_qscheme',
+    'get_dtype',
+    'replace_activation_with_identity',
+    'StopForwardException',
+    'DataSaverHook',
+    'GradSaverHook',
+    'LinearTempDecay',
+    'Node',
+    'GraphPlotter'
+]
+
+def get_qscheme(per_channel=False, symmetric=False):
+    if per_channel and symmetric:
+        return torch.per_channel_symmetric
+    elif per_channel and not symmetric:
+        return torch.per_channel_affine
+    elif not per_channel and symmetric:
+        return torch.per_tensor_symmetric
+    else:
+        return torch.per_tensor_affine
+
+def get_dtype(quant_min: int, quant_max: int, reduce_range: bool = True):
+    # byte width for qint and quint is reduced by 1 for the 'x86' backend
+    assert quant_min < quant_max
+    byte_width = 8
+    if reduce_range:
+        byte_width = 7
+    if quant_min >= 0 and quant_max < (2**byte_width):
+        return torch.quint8
+    elif quant_min >= -(2**(byte_width-1)) and quant_max < (2**(byte_width-1)):
+        return torch.qint8
+    else:
+        return torch.qint32
+
+def round_ste(x: torch.Tensor):
+    return (x.round() - x).detach() + x  # round with a straight-through gradient
+
+def fake_quantize(x: torch.Tensor, scale: torch.Tensor, zero_point: torch.Tensor,
+                  quant_min: int, quant_max: int):
+    x_int = round_ste(x / scale) + zero_point
+    x_quant = torch.clamp(x_int, quant_min, quant_max)
+    x_dequant = (x_quant - zero_point) * scale
+    return x_dequant
+
+def reshape_qparams_by_channel(x: torch.Tensor, scale: torch.Tensor,
+                               zero_point: torch.Tensor, ch_axis: int):
+    new_shape = [1] * len(x.shape)
+    new_shape[ch_axis] = x.shape[ch_axis]
+    scale = scale.reshape(new_shape)
+    zero_point = zero_point.reshape(new_shape)
+    return scale, zero_point
+
+def transform_and_flatten_tensor_by_channel(x: torch.Tensor, ch_axis: int):
+    new_axis_list = list(range(len(x.shape)))
+    new_axis_list[ch_axis] = 0
+    new_axis_list[0] = ch_axis
+    x_tran = x.permute(new_axis_list)
+    x_flat = torch.flatten(x_tran, start_dim=1)
+    return x_flat
+
+def replace_activation_with_identity(module: torch.nn.Module, activations: list) -> None:
+    reassign = dict()
+    for name, child_module in module.named_children():
+        replace_activation_with_identity(child_module, activations)
+        for activation in activations:
+            if isinstance(child_module, activation):
+                reassign[name] = torch.nn.Identity()
+    for key, value in reassign.items():
+        module._modules[key] = value
+
+def quantized_forward(self, x: torch.Tensor) -> torch.Tensor:
+    x = self.inp_quant(x)
+    x = self._forward_impl(x)  # the model must expose _forward_impl (see resnet.py change below)
+    x = self.out_dequant(x)
+    return x
+
+class StopForwardException(Exception):
+    """
+    Used to throw and catch an exception to stop traversing the graph.
+    """
+    pass
+
+class DataSaverHook:
+    """
+    Forward hook that stores the input and output of a layer.
+    """
+    def __init__(self, store_input=False, store_output=False, stop_forward=False):
+        self.store_input = store_input
+        self.store_output = store_output
+        self.stop_forward = stop_forward
+
+        self.input_store = None
+        self.output_store = None
+
+    def __call__(self, module, input_batch, output_batch):
+        if self.store_input:
+            self.input_store = input_batch
+        if self.store_output:
+            self.output_store = output_batch
+        if self.stop_forward:
+            raise StopForwardException
+
+class GradSaverHook:
+    """
+    Backward hook that stores the gradients of a layer.
+    """
+    def __init__(self, store_grad=True):
+        self.store_grad = store_grad
+        self.stop_backward = False
+        self.grad_out = None
+
+    def __call__(self, module, grad_input, grad_output):
+        if self.store_grad:
+            self.grad_out = grad_output[0]
+        if self.stop_backward:
+            raise StopForwardException
+
+
+class LinearTempDecay:
+    """
+    Class to implement a linear temperature decay scheduler for a given maximum time step.
+
+    :param t_max: maximum number of time steps to decay temperature over.
+    :param rel_start_decay: relative point in time to start the decay from the maximum time step. [default=.2]
+    :param start_b: initial temperature value. [default=10]
+    :param end_b: final temperature value. [default=2]
+
+    """
+    def __init__(self, t_max: int, rel_start_decay: float = 0.2, start_b: int = 10, end_b: int = 2):
+        self.t_max = t_max
+        self.start_decay = rel_start_decay * t_max
+        self.start_b = start_b
+        self.end_b = end_b
+
+    def __call__(self, t):
+        """
+        Linear decay schedule for temperature b.
+ :param t: the current time step + :return: scheduled temperature + """ + if t < self.start_decay: + return self.start_b + else: + rel_t = (t - self.start_decay) / (self.t_max - self.start_decay) + return self.end_b + (self.start_b - self.end_b) * max(0.0, (1 - rel_t)) + +class Node: + def __init__(self, cost=0, profit=0, bit=None, parent=None, left=None, middle=None, right=None, position='middle'): + self.parent = parent + self.left = left + self.middle = middle + self.right = right + self.position = position + self.cost = cost + self.profit = profit + self.bit = bit + + def __str__(self): + return 'cost: {:.2f} profit: {:.2f}'.format(self.cost, self.profit) + + def __repr__(self): + return self.__str__() + + +class GraphPlotter: + def __init__(self, save_dir: str = './'): + self.save_dir = save_dir + + def line_plotter(self, columns, names, name_fmt: str = '{}', + title: str = '', xlabel: str = '', ylabel: str = '', ytype: str = '-'): + data = [graph_objects.Scatter( + y = columns[i], + mode = 'lines + markers', + name = name_fmt.format(column_name), + ) for i, column_name in enumerate(names)] + layout = graph_objects.Layout( + title = title, + xaxis = dict(title=xlabel), + yaxis = dict(title=ylabel, type=ytype) + ) + fig = graph_objects.Figure(data, layout) + self.save_plot(title, fig) + + def save_plot(self, title, fig: graph_objects.Figure): + if not os.path.exists(self.save_dir): + os.makedirs(self.save_dir) + fig.write_image('{}/{}_plot.png'.format(self.save_dir, title)) \ No newline at end of file diff --git a/trailmet/models/resnet.py b/trailmet/models/resnet.py index 841325a..b40e55d 100644 --- a/trailmet/models/resnet.py +++ b/trailmet/models/resnet.py @@ -316,7 +316,7 @@ def _make_layer(self, block, planes, blocks, stride=1): return nn.Sequential(*layers) - def forward(self, x): + def _forward_impl(self, x): x = self.conv1(x) x = self.bn1(x) x = self.active(x) @@ -336,6 +336,9 @@ def forward(self, x): else: return x + def forward(self, x): + return self._forward_impl(x) + def get_bn_layers(self): bn_layers = [] for l_blocks in [self.layer1, self.layer2, self.layer3, self.layer4]:
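Below is a minimal usage sketch of the quantization helpers added in trailmet/algorithms/quantize/utils.py, fake-quantizing a convolution weight per output channel. The random weight, the signed 8-bit range, and the min/max-based scale and zero-point computation are illustrative assumptions, not values or code taken from this patch.

import torch

from trailmet.algorithms.quantize.utils import (
    fake_quantize,
    get_dtype,
    get_qscheme,
    reshape_qparams_by_channel,
    transform_and_flatten_tensor_by_channel,
)

w = torch.randn(16, 3, 3, 3)       # a conv weight, output channels on axis 0
ch_axis = 0
quant_min, quant_max = -128, 127   # assumed signed 8-bit range

# per-channel min/max -> affine scale and zero-point
w_flat = transform_and_flatten_tensor_by_channel(w, ch_axis)
w_min, w_max = w_flat.min(dim=1).values, w_flat.max(dim=1).values
scale = (w_max - w_min) / (quant_max - quant_min)
zero_point = torch.round(quant_min - w_min / scale)

# broadcast the qparams back to the weight shape and fake-quantize with STE rounding
scale_b, zp_b = reshape_qparams_by_channel(w, scale, zero_point, ch_axis)
w_fq = fake_quantize(w, scale_b, zp_b, quant_min, quant_max)

print(get_qscheme(per_channel=True, symmetric=False))      # torch.per_channel_affine
print(get_dtype(quant_min, quant_max, reduce_range=False))  # torch.qint8
print((w - w_fq).abs().max())  # per-channel error is bounded by roughly scale / 2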
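DataSaverHook caches a layer's input and output during calibration (GradSaverHook does the same for gradients), and LinearTempDecay schedules the rounding temperature b used by the relaxation loss. The following self-contained sketch exercises DataSaverHook and LinearTempDecay; the toy Sequential model, batch shape, and schedule settings are assumptions for illustration only.

import torch
import torch.nn as nn

from trailmet.algorithms.quantize.utils import (
    DataSaverHook,
    LinearTempDecay,
    StopForwardException,
)

model = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
    nn.Conv2d(8, 16, 3, padding=1), nn.ReLU(),
)
layer = model[2]

# capture the layer's input and output on one calibration batch, then cut the forward pass short
saver = DataSaverHook(store_input=True, store_output=True, stop_forward=True)
handle = layer.register_forward_hook(saver)
try:
    model(torch.randn(4, 3, 32, 32))
except StopForwardException:
    pass
handle.remove()

cached_inp = saver.input_store[0]   # forward hooks receive the input as a tuple
cached_out = saver.output_store
print(cached_inp.shape, cached_out.shape)

# temperature schedule for the rounding-relaxation loss: flat at start_b, then linear decay to end_b
decay = LinearTempDecay(t_max=2000, rel_start_decay=0.2, start_b=20, end_b=2)
print([decay(t) for t in (0, 400, 1200, 2000)])  # [20, 20.0, 11.0, 2.0]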
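The resnet.py change above splits forward() into _forward_impl() so that quantized_forward from the new utils module can wrap the model's computation between an input-quantization and an output-dequantization step. A self-contained sketch of that binding pattern follows; the TinyNet module, the QuantStub/DeQuantStub attributes, and the MethodType rebinding are assumptions made for illustration, not code from this patch.

import types

import torch
import torch.nn as nn
from torch.ao.quantization import DeQuantStub, QuantStub

from trailmet.algorithms.quantize.utils import quantized_forward

class TinyNet(nn.Module):
    # stand-in for the trailmet ResNet, which now also exposes _forward_impl()
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 4)

    def _forward_impl(self, x):
        return self.fc(x)

    def forward(self, x):
        return self._forward_impl(x)

model = TinyNet()
model.inp_quant = QuantStub()      # placeholder stubs; identity until the model is prepared/converted
model.out_dequant = DeQuantStub()

# rebind forward so every call runs inp_quant -> _forward_impl -> out_dequant
model.forward = types.MethodType(quantized_forward, model)
print(model(torch.randn(2, 8)).shape)  # torch.Size([2, 4])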