hang by the wierd result
ricky40403 committed Apr 20, 2020
1 parent a814c8c commit 2a54fbf
Showing 9 changed files with 808 additions and 343 deletions.
55 changes: 48 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
@@ -18,22 +18,63 @@ and zeroQ will prefer to generate data that will not pay attention to the data d

But I can not reproduce the beautiful generated data in the paper. :sweat_smile::sweat_smile::sweat_smile:

## Experiment



<table>
<tr><th> model </th> <th> QuanType </th> <th> W/A bit </th> <th> top1 </th> <th> top5 </th></tr>
<tr><th rowspan="9"> resnet18 </th> <th colspan="2"> fp </th> <th> 69.758 </th> <th> 89.078 </th></tr>
<tr><th rowspan="4"> zeroQ </th> <th> 8/8 </th> <th> 69.230 </th> <th> 88.840 </th></tr>
<tr><th> 4/8 </th> <th> 57.582 </th> <th> 81.182 </th></tr>
<tr><th> 8/4 </th> <th> 1.130 </th> <th> 3.056 </th></tr>
<tr><th> 4/4 </th> <th> 0.708 </th> <th> 2.396 </th></tr>
<tr><th rowspan="4"> GDFQ </th> <th> 8/8 </th> <th> </th> <th> </th></tr>
<tr><th> 4/8 </th> <th> </th> <th> </th></tr>
<tr><th> 8/4 </th> <th> </th> <th> </th></tr>
<tr><th> 4/4 </th> <th> </th> <th> </th></tr>
</table>

```
I also tried cloning the [original zeroQ repository](https://github.com/amirgholami/ZeroQ/blob/ba37f793dbcb9f966b58f6b8d1e9de3c34a11b8c/classification/utils/quantize_model.py#L36) and simply setting every weight_bit to 4; the accuracy is about 10%.
I get about 24.16% when using the pytorchcv model, but only 2.16% when using torchvision's model.
```


## Training

* The floating-point model is loaded from torchvision, so the architecture name must match a torchvision model name.
You may refer to https://pytorch.org/docs/stable/torchvision/models.html
* The default batch size is 256; other batch sizes follow the rules relating learning rate, iterations, and batch size.
Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour: <https://arxiv.org/abs/1706.02677v1>
* The default batch size is 32.
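
The batch-size rule referenced above is the linear scaling rule from the cited paper: when the minibatch size is multiplied by k, the learning rate is multiplied by k. A minimal sketch, assuming a reference batch size of 256. The function names, and the choice to shrink or grow the iteration count by the same factor, are illustrative assumptions and are not taken from this repository:

```python
def scale_learning_rate(base_lr, batch_size, base_batch_size=256):
    """Linear scaling rule: the learning rate grows in proportion to batch size."""
    return base_lr * batch_size / base_batch_size


def scale_iterations(base_iters, batch_size, base_batch_size=256):
    """Keep the total number of samples seen roughly constant (assumed policy)."""
    return max(1, round(base_iters * base_batch_size / batch_size))


# Hypothetical values: going from batch size 256 to 64.
lr_for_64 = scale_learning_rate(1e-3, 64)
iters_for_64 = scale_iterations(200, 64)
```

With these helpers, a quarter-size batch gets a quarter of the learning rate and four times as many iterations.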

Default is 4 bit
Ex: training with resnet
```
python train.py [imagenet path]
optional arguments:
-a , --arch model architecture
-m , --method zeroQ, GDFQ
--n_epochs GDFQ's training epochs
--n_iter training iterations per training epoch
--batch_size batch size
--q_lr learning rate of GDFQ's quantization model
--g_lr learning rate of GDFQ's generator model
-qa quantization activation bit
-qw quantization weight bit
-qb quantization bias bit
```

Ex: Training with resnet
```
python train.py -a resnet18
python train.py -a resnet18 --batch_size 64
```

Training with vgg16_bn with 8 bit activation, 8bit weight, 8 bit bias
Ex: Training with vgg16_bn with 8 bit activation, 8bit weight, 8 bit bias
```
python train.py -a vgg16_bn -qa 8 -qw 8 -qb 8
```
@@ -43,10 +84,10 @@ python train.py -a vgg16_bn -qa 8 -qw 8 -qb 8
2. The toy experiment can not generate the beautiful output, maybe something wrong. (Any advice or PR is welcome)

### Todo
- [ ] add zeroQ training.
- [x] add zeroQ training.
- [ ] Check the effect of the BNS and KL.

### Note
The performance has not been tested yet.
The performance did reach the number in the paper.
So it may still have some bugs for now.
All the results are based on fake quantization, not true low-bit inference.
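
For readers unfamiliar with the term, fake quantization rounds values onto a low-bit grid and immediately maps them back to floats, so all arithmetic still runs in floating point. A minimal pure-Python sketch, assuming an asymmetric min-max uniform quantizer; the function `fake_quantize` is hypothetical and is not the quantizer in `utils/quantize_model`, which is not shown in this commit:

```python
def fake_quantize(values, num_bits):
    """Quantize to 2**num_bits levels, then dequantize back to floats."""
    lo, hi = min(values), max(values)
    levels = 2 ** num_bits - 1
    if hi == lo:
        # A constant input needs no quantization grid.
        return list(values)
    scale = (hi - lo) / levels
    out = []
    for v in values:
        q = round((v - lo) / scale)   # integer code in [0, levels]
        out.append(lo + q * scale)    # map the code back to a float
    return out


# Hypothetical weights quantized to 2 bits (at most 4 distinct values).
weights = [-1.0, -0.3, 0.2, 0.9]
quantized = fake_quantize(weights, 2)
```

The output has the same float type as the input, which is why accuracy measured this way only simulates low-bit inference.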
140 changes: 84 additions & 56 deletions train.py
@@ -7,71 +7,99 @@
import torchvision.datasets as datasets
import torchvision.transforms as transforms

from utils.quantize_model import *
from train_script import train_GDFQ, train_zeroQ
from utils.val import validation
from train_script import train_GDFQ

from utils.quantize_model import *
from pytorchcv.model_provider import get_model as ptcv_get_model


parser = argparse.ArgumentParser()
parser.add_argument('data', metavar='DIR', help='path to dataset')
parser.add_argument("-a", "--arch", type=str, default="resnet18", help="number of epochs of training")
parser.add_argument("--n_epochs", type=int, default=400, help="number of epochs of training")
parser.add_argument("--n_iter", type=int, default=200, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=32, help="size of the batches")
parser.add_argument("--q_lr", type=float, default=1e-6, help="adam: learning rate")
parser.add_argument("--g_lr", type=float, default=1e-3, help="adam: learning rate")
parser.add_argument("-qa", "--quan_a_bit", type=int, default=4, help=" quan activation bit")
parser.add_argument("-qw", "--quan_w_bit", type=int, default=4, help=" quan weight bit")
parser.add_argument("-qb", "--quan_b_bit", type=int, default=4, help=" quan bias bit")


parser.add_argument("-a", "--arch", type=str,
                    default="resnet18", help="model architecture")
parser.add_argument("-m", "--method", type=str, default="GDFQ",
                    help="method of training")
parser.add_argument("--n_epochs", type=int, default=400,
                    help="number of epochs of training")
parser.add_argument("--n_iter", type=int, default=200,
                    help="number of training iterations per epoch")
parser.add_argument("--batch_size", type=int, default=32,
                    help="size of the batches")
parser.add_argument("--q_lr", type=float, default=1e-6,
                    help="adam: learning rate of the quantized model")
parser.add_argument("--g_lr", type=float, default=1e-3,
                    help="adam: learning rate of the generator")
parser.add_argument("-qa", "--quan_a_bit", type=int,
                    default=4, help=" quan activation bit")
parser.add_argument("-qw", "--quan_w_bit", type=int,
                    default=4, help=" quan weight bit")
parser.add_argument("-qb", "--quan_b_bit", type=int,
                    default=4, help=" quan bias bit")


def main():

    args = parser.parse_args()
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(valdir,
                             transforms.Compose([
                                 transforms.Resize(256),
                                 transforms.CenterCrop(224),
                                 transforms.ToTensor(),
                                 normalize,
                             ])),
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=4,
        pin_memory=True
    )

    criterion = nn.CrossEntropyLoss().cuda()
    FP_model = getattr(torch_model, args.arch)(pretrained=True)
    fp_1, fp_5 = validation(val_loader, FP_model, criterion)

    Q_model = quantize_model(FP_model, args.quan_a_bit, args.quan_w_bit, args.quan_b_bit)
    Q_model = freeze_bn(Q_model)
    Q_model = freeze_act(Q_model)

    q_init_1, q_init5 = validation(val_loader, Q_model, criterion)


    Q_model = un_freeze_act(Q_model)
    Q_model = train_GDFQ.train_GDFQ(FP_model, Q_model, val_loader, criterion,
                                    batch_size = args.batch_size,
                                    total_epoch = args.n_epochs, iter_per_epoch=args.n_iter,
                                    q_lr = args.q_lr, g_lr = args.g_lr)
    Q_model = freeze_act(Q_model)
    q_final_1, q_final_5 = validation(val_loader, Q_model, criterion)

    print("FP Model ==> Top1: {}, Top5: {}".format(fp_1, fp_5))
    print("Q Model Initial ==> Top1: {}, Top5: {}".format(q_init_1, q_init5))
    print("Q Model Final ==> Top1: {}, Top5: {}".format(q_final_1, q_final_5))


    args = parser.parse_args()

    # restrict method input
    assert args.method in ["zeroQ", "GDFQ"]


    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True

    # prepare validation data
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(valdir,
                             transforms.Compose([
                                 transforms.Resize(256),
                                 transforms.CenterCrop(224),
                                 transforms.ToTensor(),
                                 normalize,
                             ])),
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=4,
        pin_memory=True
    )

    for_inception = args.arch.startswith('inception')

    # FP_model = getattr(torch_model, args.arch)(pretrained=True)
    FP_model = ptcv_get_model(args.arch, pretrained=True)
    # fp_1, fp_5 = validation(val_loader, FP_model)

    Q_model = quantize_model(FP_model, args.quan_a_bit,
                             args.quan_w_bit, args.quan_b_bit)


    # _, _ = validation(val_loader, Q_model, criterion)

    # exit()

    if "GDFQ" == args.method:
        Q_model = train_GDFQ.train_GDFQ(FP_model, Q_model, val_loader,
                                        batch_size=args.batch_size,
                                        total_epoch=args.n_epochs, iter_per_epoch=args.n_iter,
                                        q_lr=args.q_lr, g_lr=args.g_lr,
                                        for_incep=for_inception)

    elif "zeroQ" == args.method:

        Q_model = train_zeroQ.train_zeroQ(FP_model, Q_model,
                                          val_loader,
                                          batch_size=args.batch_size,
                                          for_incep=for_inception)
        exit()
    Q_model = freeze_act(Q_model)
    q_final_1, q_final_5 = validation(val_loader, Q_model)

    # print("FP Model ==> Top1: {}, Top5: {}".format(fp_1, fp_5))
    print("Q Model Final ==> Top1: {}, Top5: {}".format(q_final_1, q_final_5))


if __name__ == "__main__":
    main()
    main()
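
The `assert args.method in ["zeroQ", "GDFQ"]` check in train.py could also be expressed through argparse itself, which reports an invalid value as a usage error and lists the accepted values in `--help`. A minimal sketch of that alternative; `build_parser` is a hypothetical helper, and only the `--method` option is reproduced here:

```python
import argparse


def build_parser():
    """Build a parser that only accepts the two supported methods."""
    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--method", type=str, default="GDFQ",
                        choices=["zeroQ", "GDFQ"],
                        help="method of training")
    return parser


# Parsing an explicit argument list instead of sys.argv, for illustration.
args = build_parser().parse_args(["-m", "zeroQ"])
```

Unlike `assert`, the `choices` check is not skipped when Python runs with `-O`.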
113 changes: 113 additions & 0 deletions train_script/get_zeroQ_data.py
@@ -0,0 +1,113 @@
# reference from official zeroQ
# https://github.com/amirgholami/ZeroQ/blob/ba37f793dbcb9f966b58f6b8d1e9de3c34a11b8c/classification/utils/data_utils.py
# @file Different utility functions
# Copyright (c) Yaohui Cai, Zhewei Yao, Zhen Dong, Amir Gholami
# All rights reserved.
# This file is part of ZeroQ repository.
#
# ZeroQ is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# ZeroQ is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with ZeroQ repository. If not, see <http://www.gnu.org/licenses/>.
#*

from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
import torch


class UniformDataset(Dataset):
    """
    get random uniform samples with mean 0 and variance 1
    """
    def __init__(self, length, size, transform):
        self.length = length
        self.transform = transform
        self.size = size

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        # var[U(-128, 127)] = (127 - (-128))**2 / 12 = 5418.75
        sample = (torch.randint(high=255, size=self.size).float() -
                  127.5) / 5418.75
        return sample


def getRandomData(dataset='cifar10', batch_size=512, for_inception=False):
    """
    get random sample dataloader
    dataset: name of the dataset
    batch_size: the batch size of random data
    for_inception: whether the data is for Inception because inception has input size 299 rather than 224
    """
    if dataset == 'cifar10':
        size = (3, 32, 32)
        num_data = 10000
    elif dataset == 'imagenet':
        num_data = 10000
        if not for_inception:
            size = (3, 224, 224)
        else:
            size = (3, 299, 299)
    else:
        raise NotImplementedError
    dataset = UniformDataset(length=10000, size=size, transform=None)
    data_loader = DataLoader(dataset,
                             batch_size=batch_size,
                             shuffle=False,
                             num_workers=32)
    return data_loader


def getTestData(dataset='imagenet',
                batch_size=1024,
                path='data/imagenet',
                for_inception=False):
    """
    Get dataloader of testset
    dataset: name of the dataset
    batch_size: the batch size of random data
    path: the path to the data
    for_inception: whether the data is for Inception because inception has input size 299 rather than 224
    """
    if dataset == 'imagenet':
        input_size = 299 if for_inception else 224
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        test_dataset = datasets.ImageFolder(
            path + 'val',
            transforms.Compose([
                transforms.Resize(int(input_size / 0.875)),
                transforms.CenterCrop(input_size),
                transforms.ToTensor(),
                normalize,
            ]))
        test_loader = DataLoader(test_dataset,
                                 batch_size=batch_size,
                                 shuffle=False,
                                 num_workers=32)
        return test_loader
    elif dataset == 'cifar10':
        data_dir = '/rscratch/yaohuic/data/'
        normalize = transforms.Normalize(mean=(0.4914, 0.4822, 0.4465),
                                         std=(0.2023, 0.1994, 0.2010))
        transform_test = transforms.Compose([transforms.ToTensor(), normalize])

        test_dataset = datasets.CIFAR10(root=data_dir,
                                        train=False,
                                        transform=transform_test)
        test_loader = DataLoader(test_dataset,
                                 batch_size=batch_size,
                                 shuffle=False,
                                 num_workers=32)
        return test_loader
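
A note on the scaling in `UniformDataset.__getitem__`: the comment computes a variance (255**2 / 12 = 5418.75), and the sample is divided by that value. Scaling to unit variance would require dividing by the standard deviation, the square root of the variance. The sketch below checks this exactly with the standard library, assuming the integers 0..254 that `torch.randint(high=255)` can produce (`high` is exclusive). The helper `scaled_std` is hypothetical, and whether the upstream scaling is intentional is not stated in the source:

```python
import math
import statistics

# torch.randint(high=255) draws integers from 0 to 254 inclusive.
VALUES = list(range(255))


def scaled_std(divisor, offset=127.5):
    """Population standard deviation of (v - offset) / divisor over VALUES."""
    return statistics.pstdev([(v - offset) / divisor for v in VALUES])


std_dividing_by_variance = scaled_std(5418.75)           # as in the file
std_dividing_by_std = scaled_std(math.sqrt(5418.75))     # unit-variance scaling
```

Dividing by the variance leaves samples with a standard deviation far below 1, while dividing by its square root gives a value very close to 1.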