Fix image classification training scripts and readme #309

Merged (7 commits, Oct 9, 2018)

6 changes: 4 additions & 2 deletions gluoncv/model_zoo/resnetv1b.py
@@ -197,9 +197,11 @@ def _make_layer(self, stage_index, block, planes, blocks, strides=1, dilation=1,
with downsample.name_scope():
if avg_down:
if dilation == 1:
downsample.add(nn.AvgPool2D(pool_size=strides, strides=strides))
downsample.add(nn.AvgPool2D(pool_size=strides, strides=strides,
ceil_mode=True, count_include_pad=False))
else:
downsample.add(nn.AvgPool2D(pool_size=1, strides=1))
downsample.add(nn.AvgPool2D(pool_size=1, strides=1,
ceil_mode=True, count_include_pad=False))
downsample.add(nn.Conv2D(channels=planes * block.expansion, kernel_size=1,
strides=1, use_bias=False))
downsample.add(norm_layer(**self.norm_kwargs))
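
Adding `ceil_mode=True` and `count_include_pad=False` keeps the avg-down shortcut's spatial size aligned with the strided convolution on the main path when the feature map size is odd, and `count_include_pad=False` keeps the average in the last, partially covered window correct. A minimal shape sketch, assuming MXNet Gluon and arbitrary toy dimensions:

```
from mxnet import nd
from mxnet.gluon import nn

x = nd.ones((1, 64, 7, 7))  # odd spatial size, as occurs in later ResNet stages

floor_pool = nn.AvgPool2D(pool_size=2, strides=2)  # previous behaviour (ceil_mode=False)
ceil_pool = nn.AvgPool2D(pool_size=2, strides=2,
                         ceil_mode=True, count_include_pad=False)

print(floor_pool(x).shape)  # (1, 64, 3, 3): too small to add to the residual branch
print(ceil_pool(x).shape)   # (1, 64, 4, 4): matches a stride-2 conv on a 7x7 input
```
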
152 changes: 3 additions & 149 deletions scripts/classification/cifar/README.md
@@ -1,150 +1,4 @@
# CIFAR10

Here we present examples of training ResNet and Wide ResNet models on the CIFAR10 dataset.

The main training script is `train.py`. The script takes various parameters, so we provide suggested settings and the corresponding results.

We also experiment with the [Mix-Up augmentation method](https://arxiv.org/abs/1710.09412) and compare results for each model.

## Models

We offer `ResNetV1`, `ResNetV2`, and `WideResNet` models in various configurations. The following is a list of available pretrained models and their accuracy on CIFAR10:

| Model | Accuracy |
|------------------|----------|
| ResNet20_v1 | 0.9160 |
| ResNet56_v1 | 0.9387 |
| ResNet110_v1 | 0.9471 |
| ResNet20_v2 | 0.9158 |
| ResNet56_v2 | 0.9413 |
| ResNet110_v2 | 0.9484 |
| WideResNet16_10 | 0.9614 |
| WideResNet28_10 | 0.9667 |
| WideResNet40_8 | 0.9673 |

## Demo

Before training your own model, you may want to see how a trained model performs.

Here we provide a script, `demo.py`, that loads a pre-trained model and predicts on an image (a minimal sketch of this workflow follows the parameter list below).

**Execution**

```
python demo.py --model cifar_resnet110_v2 --input-pic ~/Pictures/demo.jpg
```

**Parameters Explained**

- `--model`: The model to use.
- `--saved-params`: the path to a locally saved model.
- `--input-pic`: the path to the input picture file.
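
For reference, a minimal sketch of roughly what such a demo does, assuming a pretrained model from the GluonCV model zoo; the image path, normalization statistics, and class-name list below are commonly used CIFAR10 values and are assumptions here, not copied from `demo.py`:

```
import mxnet as mx
from mxnet.gluon.data.vision import transforms
from gluoncv.model_zoo import get_model

net = get_model('cifar_resnet110_v2', pretrained=True)

# Commonly used CIFAR10 mean/std (an assumption, not taken from demo.py)
transform = transforms.Compose([
    transforms.Resize(32),
    transforms.ToTensor(),
    transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
])

img = mx.image.imread('demo.jpg')               # HWC uint8 image (placeholder path)
pred = net(transform(img).expand_dims(axis=0))  # add a batch dimension

classes = ['airplane', 'automobile', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck']
idx = int(mx.nd.argmax(pred, axis=1).asscalar())
print('predicted:', classes[idx])
```
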

## Training

Training can be done by either `train.py` or `train_mixup.py`.

Training a ResNet110_v2 model can be done with

```
python train.py --num-epochs 240 --mode hybrid --num-gpus 2 -j 32 --batch-size 64\
--wd 0.0001 --lr 0.1 --lr-decay 0.1 --lr-decay-epoch 80,160 --model cifar_resnet110_v2
```

With mixup, the command is

```
python train_mixup.py --num-epochs 350 --mode hybrid --num-gpus 2 -j 32 --batch-size 64\
--wd 0.0001 --lr 0.1 --lr-decay 0.1 --lr-decay-epoch 150,250 --model cifar_resnet110_v2
```

To get results from a different ResNet, modify `--model`.

Results:

| Model | Accuracy | Mix-Up |
|--------------|----------|--------|
| ResNet20_v1 | 0.9115 | 0.9161 |
| ResNet20_v2 | 0.9117 | 0.9119 |
| ResNet56_v2 | 0.9307 | 0.9414 |
| ResNet110_v2 | 0.9414 | 0.9447 |

Pretrained models:

| Model | Accuracy |
|--------------|----------|
| ResNet20_v1 | 0.9160 |
| ResNet56_v1 | 0.9387 |
| ResNet110_v1 | 0.9471 |
| ResNet20_v2 | 0.9130 |
| ResNet56_v2 | 0.9413 |
| ResNet110_v2 | 0.9464 |

obtained with the following command:

```
python train_mixup.py --num-epochs 450 --mode hybrid --num-gpus 2 -j 32 --batch-size 64 --wd 0.0001 --lr 0.1 --lr-decay 0.1 --lr-decay-epoch 150,250 --model cifar_resnet20_v1
```

## Wide ResNet

Training a WRN-28-10 model can be done with

```
python train.py --num-epochs 200 --mode hybrid --num-gpus 2 -j 32 --batch-size 64\
--wd 0.0005 --lr 0.1 --lr-decay 0.2 --lr-decay-epoch 60,120,160\
--model cifar_wideresnet28 --width-factor 10
```

With mixup, the command is

```
python train_mixup.py --num-epochs 350 --mode hybrid --num-gpus 2 -j 32 --batch-size 64\
--wd 0.0001 --lr 0.1 --lr-decay 0.1 --lr-decay-epoch 80,160,240\
--model cifar_wideresnet28 --width-factor 10
```

To get results from a different WRN, modify `--model` and `--width-factor`.

Results:

| Model | Accuracy | Mix-Up |
|--------------|----------|--------|
| WRN-16-10 | 0.9527 | 0.9602 |
| WRN-28-10 | 0.9584 | 0.9667 |
| WRN-40-8 | 0.9559 | 0.9620 |

Pretrained models:

| Model | Accuracy |
|------------------|----------|
| WideResNet16_10 | 0.9614 |
| WideResNet28_10 | 0.9667 |
| WideResNet40_8 | 0.9673 |

obtained with the following command:

```
python train_mixup.py --num-epochs 500 --mode hybrid --num-gpus 2 -j 32 --batch-size 64 --wd 0.0001 --lr 0.1 --lr-decay 0.1 --lr-decay-epoch 100,200,300 --model cifar_wideresnet16_10
```

**Parameters Explained**

- `--batch-size`: per-device batch size for the training.
- `--num-gpus`: the number of GPUs to use for computation; the default is `0`, which means CPU only.
- `--model`: The model to train. For `CIFAR10` we offer [`ResNet`](https://github.com/dmlc/gluon-cv/blob/master/gluoncv/model_zoo/cifarresnet.py) and [`WideResNet`](https://github.com/dmlc/gluon-cv/blob/master/gluoncv/model_zoo/cifarwideresnet.py) as options.
- `--num-data-workers`/`-j`: the number of data processing workers.
- `--num-epochs`: the number of training epochs.
- `--lr`: the initial learning rate in training.
- `--momentum`: the momentum parameter.
- `--wd`: the weight decay parameter.
- `--lr-decay`: the learning rate decay factor.
- `--lr-decay-period`: the learning rate decay period, i.e. for every `--lr-decay-period` epochs, the learning rate will decay by a factor of `--lr-decay`.
- `--lr-decay-epoch`: epochs at which the learning rate decays by a factor of `--lr-decay` (see the sketch after this list).
- `--width-factor`: the width factor of the `WideResNet` model.
- `--drop-rate`: the dropout rate of the `WideResNet` model.
- `--mode`: whether to use `hybrid` mode to speed up the training process.
- `--save-period`: the model is saved to disk every `--save-period` epochs.
- `--save-dir`: the directory to save the models.
- `--logging-dir`: the directory to save the training logs.
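
For illustration, a minimal sketch of how a step schedule defined by `--lr`, `--lr-decay`, and `--lr-decay-epoch` is conventionally interpreted (not the script's exact code; the values match the ResNet110_v2 command above):

```
# Step learning-rate schedule: decay by lr_decay at each epoch in lr_decay_epochs.
lr, lr_decay = 0.1, 0.1
lr_decay_epochs = [80, 160]  # e.g. --lr-decay-epoch 80,160

def lr_at(epoch):
    return lr * lr_decay ** sum(epoch >= e for e in lr_decay_epochs)

print(lr_at(0), lr_at(80), lr_at(160))  # 0.1, then 0.01, then 0.001 (up to float rounding)
```
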
# Image Classification on CIFAR10

Please refer to [GluonCV Model Zoo](http://gluon-cv.mxnet.io/model_zoo/index.html#image-classification)
for available pretrained models, training hyper-parameters, etc.
29 changes: 3 additions & 26 deletions scripts/classification/imagenet/README.md
@@ -1,27 +1,4 @@
# Image Classification

Here we present examples of training Gluon models on image classification tasks.

## ImageNet

Here we present examples of training ResNet on the ImageNet dataset.

The main training script is `train_imagenet.py`. The script takes various parameters, so we provide suggested settings and the corresponding results.

### ResNet50_v2

Training a ResNet50_v2 can be done with:

```
python train_imagenet.py --batch-size 64 --num-gpus 4 -j 32 --mode hybrid\
--num-epochs 120 --lr 0.1 --momentum 0.9 --wd 0.0001\
--lr-decay 0.1 --lr-decay-epoch 30,60,90 --model resnet50_v2
```

Results:

| Model | Top-1 Error | Top-5 Error |
|--------------|-------------|-------------|
| ResNet50_v2 | 0.2428 | 0.0738 |
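
The reported errors are one minus the corresponding accuracies. A minimal sketch (toy batch, not the script's actual evaluation loop) of how top-1/top-5 error are typically computed with MXNet metrics:

```
import mxnet as mx
from mxnet import nd

acc_top1 = mx.metric.Accuracy()
acc_top5 = mx.metric.TopKAccuracy(top_k=5)

labels = nd.array([3, 7])                   # toy ground-truth class indices
logits = nd.random.normal(shape=(2, 1000))  # toy network outputs for 1000 classes

acc_top1.update([labels], [logits])
acc_top5.update([labels], [logits])

_, top1 = acc_top1.get()
_, top5 = acc_top5.get()
print('err-top1=%f err-top5=%f' % (1 - top1, 1 - top5))
```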

# Image Classification on ImageNet

Please refer to [GluonCV Model Zoo](http://gluon-cv.mxnet.io/model_zoo/index.html#image-classification)
for available pretrained models, training hyper-parameters, etc.
70 changes: 44 additions & 26 deletions scripts/classification/imagenet/train_imagenet.py
@@ -1,4 +1,4 @@
import argparse, time, logging, os
import argparse, time, logging, os, math

import numpy as np
import mxnet as mx
@@ -61,6 +61,8 @@
help='type of model to use. see vision_model for options.')
parser.add_argument('--input-size', type=int, default=224,
help='size of the input image size. default is 224')
parser.add_argument('--crop-ratio', type=float, default=0.875,
help='Crop ratio during validation. default is 0.875')
parser.add_argument('--use-pretrained', action='store_true',
help='enable using pretrained model from gluon.')
parser.add_argument('--use_se', action='store_true',
@@ -77,12 +79,18 @@
help='whether to remove weight decay on bias, and beta/gamma for batchnorm layers.')
parser.add_argument('--batch-norm', action='store_true',
help='enable batch normalization or not in vgg. default is false.')
parser.add_argument('--log-interval', type=int, default=50,
help='Number of batches to wait before logging.')
parser.add_argument('--save-frequency', type=int, default=10,
help='frequency of model saving.')
parser.add_argument('--save-dir', type=str, default='params',
help='directory of saved models')
parser.add_argument('--resume-epoch', type=int, default=0,
help='epoch to resume training from.')
parser.add_argument('--resume-params', type=str, default='',
help='path of parameters to load from.')
parser.add_argument('--resume-states', type=str, default='',
help='path of trainer state to load from.')
parser.add_argument('--log-interval', type=int, default=50,
help='Number of batches to wait before logging.')
parser.add_argument('--logging-file', type=str, default='train_imagenet.log',
help='name of training log file')
opt = parser.parse_args()
@@ -137,6 +145,8 @@

net = get_model(model_name, **kwargs)
net.cast(opt.dtype)
if opt.resume_params != '':
net.load_parameters(opt.resume_params, ctx = context)

# Two functions for reading data from record file or raw images
def get_data_rec(rec_train, rec_train_idx, rec_val, rec_val_idx, batch_size, num_workers):
@@ -147,6 +157,8 @@ def get_data_rec(rec_train, rec_train_idx, rec_val, rec_val_idx, batch_size, num
jitter_param = 0.4
lighting_param = 0.1
input_size = opt.input_size
crop_ratio = opt.crop_ratio if opt.crop_ratio > 0 else 0.875
resize = int(math.ceil(input_size / crop_ratio))
mean_rgb = [123.68, 116.779, 103.939]
std_rgb = [58.393, 57.12, 57.375]
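
The new `--crop-ratio` option replaces the hard-coded `resize = 256` for validation: the image is resized so that the `input_size` center crop covers a fixed fraction of it. A quick arithmetic check with assumed input sizes:

```
import math

def val_resize(input_size, crop_ratio=0.875):
    # Resize so that the center crop of input_size covers crop_ratio of the image.
    return int(math.ceil(input_size / crop_ratio))

print(val_resize(224))  # 256, the previously hard-coded value
print(val_resize(320))  # 366, scales with larger input sizes
```
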

@@ -187,7 +199,7 @@ def batch_fn(batch, ctx):
shuffle = False,
batch_size = batch_size,

resize = 256,
resize = resize,
data_shape = (3, input_size, input_size),
mean_r = mean_rgb[0],
mean_g = mean_rgb[1],
@@ -203,6 +215,8 @@ def get_data_loader(data_dir, batch_size, num_workers):
jitter_param = 0.4
lighting_param = 0.1
input_size = opt.input_size
crop_ratio = opt.crop_ratio if opt.crop_ratio > 0 else 0.875
resize = int(math.ceil(input_size / crop_ratio))

def batch_fn(batch, ctx):
data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
@@ -219,7 +233,7 @@ def batch_fn(batch, ctx):
normalize
])
transform_test = transforms.Compose([
transforms.Resize(256, keep_ratio=True),
transforms.Resize(resize, keep_ratio=True),
transforms.CenterCrop(input_size),
transforms.ToTensor(),
normalize
@@ -256,22 +270,22 @@ def batch_fn(batch, ctx):
save_dir = ''
save_frequency = 0

def label_transform(label, classes, eta=0.0):
ind = label.astype('int')
res = nd.zeros((ind.shape[0], classes), ctx = label.context)
res += eta/classes
res[nd.arange(ind.shape[0], ctx = label.context), ind] = 1 - eta + eta/classes
def mixup_transform(label, classes, lam=1, eta=0.0):
if isinstance(label, nd.NDArray):
label = [label]
res = []
for l in label:
y1 = l.one_hot(classes, on_value = 1 - eta + eta/classes, off_value = eta/classes)
y2 = l[::-1].one_hot(classes, on_value = 1 - eta + eta/classes, off_value = eta/classes)
res.append(lam*y1 + (1-lam)*y2)
return res

def smooth(label, classes, eta=0.1):
if isinstance(label, nd.NDArray):
label = [label]
smoothed = []
for l in label:
ind = l.astype('int')
res = nd.zeros((ind.shape[0], classes), ctx = l.context)
res += eta/classes
res[nd.arange(ind.shape[0], ctx = l.context), ind] = 1 - eta + eta/classes
res = l.one_hot(classes, on_value = 1 - eta + eta/classes, off_value = eta/classes)
smoothed.append(res)
return smoothed
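
A minimal usage sketch of the refactored `mixup_transform` logic on a toy batch (the alpha value, shapes, and labels are assumed toy values): the soft targets are the same convex combination of the batch and its reverse as the mixed inputs.

```
import numpy as np
from mxnet import nd

classes = 4
lam = np.random.beta(0.2, 0.2)          # a mixup alpha of 0.2 is an assumed value

X = nd.random.uniform(shape=(3, 3, 32, 32))
y = nd.array([0, 2, 1])

X_mix = lam * X + (1 - lam) * X[::-1]   # same mixing used for the inputs in the training loop
y1 = y.one_hot(classes)
y2 = y[::-1].one_hot(classes)
y_mix = lam * y1 + (1 - lam) * y2       # soft targets; each row still sums to 1
```
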

@@ -293,13 +307,16 @@ def test(ctx, val_data):
def train(ctx):
if isinstance(ctx, mx.Context):
ctx = [ctx]
net.initialize(mx.init.MSRAPrelu(), ctx=ctx)
if opt.resume_params == '':
net.initialize(mx.init.MSRAPrelu(), ctx=ctx)

if opt.no_wd:
for k, v in net.collect_params('.*beta|.*gamma|.*bias').items():
v.wd_mult = 0.0

trainer = gluon.Trainer(net.collect_params(), optimizer, optimizer_params)
if opt.resume_states != '':
trainer.load_states(opt.resume_states)

if opt.label_smoothing or opt.mixup:
L = gluon.loss.SoftmaxCrossEntropyLoss(sparse_label=False)
@@ -308,7 +325,7 @@ def train(ctx):

best_val_score = 1

for epoch in range(opt.num_epochs):
for epoch in range(opt.resume_epoch, opt.num_epochs):
tic = time.time()
if opt.use_rec:
train_data.reset()
@@ -322,21 +339,16 @@
lam = np.random.beta(opt.mixup_alpha, opt.mixup_alpha)
if epoch >= opt.num_epochs - opt.mixup_off_epoch:
lam = 1
data_mixup = [lam*X + (1-lam)*X[::-1] for X in data]
data = [lam*X + (1-lam)*X[::-1] for X in data]

label_mixup = []
if opt.label_smoothing:
eta = 0.1
else:
eta = 0.0
for Y in label:
y1 = label_transform(Y, classes, eta)
y2 = label_transform(Y[::-1], classes, eta)
label_mixup.append(lam*y1 + (1-lam)*y2)
label = mixup_transform(label, classes, lam, eta)

data = data_mixup
label = label_mixup
elif opt.label_smoothing:
hard_label = label
label = smooth(label, classes)
Review comment: So this label is not used now? After adding line 362 the train_metric seems to be updated with hard_label, irrelevant to the smoothed label.

Reply from the PR author: label is used in loss calculation.

with ag.record():
@@ -352,7 +364,10 @@
for out in outputs]
train_metric.update(label, output_softmax)
else:
train_metric.update(label, outputs)
if opt.label_smoothing:
train_metric.update(hard_label, outputs)
else:
train_metric.update(label, outputs)

if opt.log_interval and not (i+1)%opt.log_interval:
train_metric_name, train_metric_score = train_metric.get()
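
Regarding the review thread above, a minimal sketch (toy batch, 4 classes assumed) of the pattern this hunk implements: the smoothed targets feed `SoftmaxCrossEntropyLoss(sparse_label=False)`, while the training metric is still updated with the original hard labels.

```
import mxnet as mx
from mxnet import nd, gluon

classes, eta = 4, 0.1
hard_label = nd.array([0, 2, 3, 1])

# Smoothed targets, as produced by smooth() above
soft_label = hard_label.one_hot(classes,
                                on_value=1 - eta + eta / classes,
                                off_value=eta / classes)

L = gluon.loss.SoftmaxCrossEntropyLoss(sparse_label=False)
outputs = nd.random.normal(shape=(4, classes))  # stand-in for net(X)

loss = L(outputs, soft_label)           # the loss sees the smoothed label
metric = mx.metric.Accuracy()
metric.update([hard_label], [outputs])  # accuracy compares against the hard label
```
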
@@ -370,15 +385,18 @@
logger.info('[Epoch %d] speed: %d samples/sec\ttime cost: %f'%(epoch, throughput, time.time()-tic))
logger.info('[Epoch %d] validation: err-top1=%f err-top5=%f'%(epoch, err_top1_val, err_top5_val))

if err_top1_val < best_val_score and epoch > 50:
if err_top1_val < best_val_score:
best_val_score = err_top1_val
net.save_parameters('%s/%.4f-imagenet-%s-%d-best.params'%(save_dir, best_val_score, model_name, epoch))
trainer.save_states('%s/%.4f-imagenet-%s-%d-best.states'%(save_dir, best_val_score, model_name, epoch))

if save_frequency and save_dir and (epoch + 1) % save_frequency == 0:
net.save_parameters('%s/imagenet-%s-%d.params'%(save_dir, model_name, epoch))
trainer.save_states('%s/imagenet-%s-%d.states'%(save_dir, model_name, epoch))

if save_frequency and save_dir:
net.save_parameters('%s/imagenet-%s-%d.params'%(save_dir, model_name, opt.num_epochs-1))
trainer.save_states('%s/imagenet-%s-%d.states'%(save_dir, model_name, opt.num_epochs-1))

def main():
if opt.mode == 'hybrid':