From fd947a9bb65d138e06255f6a139f7073a021ce64 Mon Sep 17 00:00:00 2001
From: Tong He
Date: Mon, 8 Oct 2018 18:22:59 -0700
Subject: [PATCH] Fix image classification training scripts and readme (#309)

* fix script, remove outdated readme text

* fix metric for smoothed label

* enable resuming from previous params and states

* fix

* Trigger CI

* fix scripts
---
 gluoncv/model_zoo/resnetv1b.py                 |   6 +-
 scripts/classification/cifar/README.md         | 152 +-----------------
 scripts/classification/imagenet/README.md      |  29 +---
 .../classification/imagenet/train_imagenet.py  |  70 +++++---
 .../imagenet/verify_pretrained.py              |   5 +-
 5 files changed, 56 insertions(+), 206 deletions(-)

diff --git a/gluoncv/model_zoo/resnetv1b.py b/gluoncv/model_zoo/resnetv1b.py
index c9e9ee2266..01ea66914d 100644
--- a/gluoncv/model_zoo/resnetv1b.py
+++ b/gluoncv/model_zoo/resnetv1b.py
@@ -197,9 +197,11 @@ def _make_layer(self, stage_index, block, planes, blocks, strides=1, dilation=1,
             with downsample.name_scope():
                 if avg_down:
                     if dilation == 1:
-                        downsample.add(nn.AvgPool2D(pool_size=strides, strides=strides))
+                        downsample.add(nn.AvgPool2D(pool_size=strides, strides=strides,
+                                                    ceil_mode=True, count_include_pad=False))
                     else:
-                        downsample.add(nn.AvgPool2D(pool_size=1, strides=1))
+                        downsample.add(nn.AvgPool2D(pool_size=1, strides=1,
+                                                    ceil_mode=True, count_include_pad=False))
                     downsample.add(nn.Conv2D(channels=planes * block.expansion,
                                              kernel_size=1, strides=1, use_bias=False))
                     downsample.add(norm_layer(**self.norm_kwargs))
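As a standalone illustration (not part of the patch, assumes only MXNet): with the default `ceil_mode=False`, `nn.AvgPool2D` floors the output size on odd feature maps and drops the last row/column, so the avg-down shortcut can disagree with the stride-2 conv branch's output shape. The new flags keep every pixel and average partial windows correctly:

```python
# Standalone sketch: effect of ceil_mode / count_include_pad on AvgPool2D.
import mxnet as mx
from mxnet.gluon import nn

x = mx.nd.ones((1, 1, 7, 7))  # odd spatial size, common in later ResNet stages

floor_pool = nn.AvgPool2D(pool_size=2, strides=2)  # default: ceil_mode=False
ceil_pool = nn.AvgPool2D(pool_size=2, strides=2,
                         ceil_mode=True, count_include_pad=False)

print(floor_pool(x).shape)  # (1, 1, 3, 3): the 7th row/column is dropped
print(ceil_pool(x).shape)   # (1, 1, 4, 4): kept via an extra partial window
# count_include_pad=False averages each partial window over real pixels only,
# so the border outputs stay 1.0 instead of being diluted by implicit zeros.
```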
diff --git a/scripts/classification/cifar/README.md b/scripts/classification/cifar/README.md
index 5ca8df71b1..50d01dda0c 100644
--- a/scripts/classification/cifar/README.md
+++ b/scripts/classification/cifar/README.md
@@ -1,150 +1,4 @@
-# CIFAR10
-
-Here we present examples of training resnet/wide-resnet on CIFAR10 dataset.
-
-The main training script is `train.py`. The script takes various parameters, thus we offer suggested parameters, and corresponding results.
-
-We also experiment the [Mix-Up augmentation method](https://arxiv.org/abs/1710.09412), and compare results for each model.
-
-## Models
-
-We offer models in `ResNetV1`, `ResNetV2` and `WideResNet`, with various parameters. Following is a list of available pretrained models for certain parameters, and their accuracy on CIFAR10:
-
-| Model            | Accuracy |
-|------------------|----------|
-| ResNet20_v1      | 0.9160   |
-| ResNet56_v1      | 0.9387   |
-| ResNet110_v1     | 0.9471   |
-| ResNet20_v2      | 0.9158   |
-| ResNet56_v2      | 0.9413   |
-| ResNet110_v2     | 0.9484   |
-| WideResNet16_10  | 0.9614   |
-| WideResNet28_10  | 0.9667   |
-| WideResNet40_8   | 0.9673   |
-
-## Demo
-
-Before training your own model, you may want to take a look at how it will look like.
-
-Here we provide you a script `demo.py` to load a pre-trained model and predict on an image.
-
-**Execution**
-
-```
-python demo --model cifar_resnet110_v2 --input-pic ~/Pictures/demo.jpg
-```
-
-**Parameters Explained**
-
-- `--model`: The model to use.
-- `--saved-params`: the path to a locally saved model.
-- `--input-pic`: the path to the input picture file.
-
-## Training
-
-Training can be done by either `train.py` or `train_mixup.py`.
-
-Training a model on ResNet110_v2 can be done with
-
-```
-python train.py --num-epochs 240 --mode hybrid --num-gpus 2 -j 32 --batch-size 64\
-    --wd 0.0001 --lr 0.1 --lr-decay 0.1 --lr-decay-epoch 80,160 --model cifar_resnet110_v2
-```
-
-With mixup, the command is
-
-```
-python train_mixup.py --num-epochs 350 --mode hybrid --num-gpus 2 -j 32 --batch-size 64\
-    --wd 0.0001 --lr 0.1 --lr-decay 0.1 --lr-decay-epoch 150,250 --model cifar_resnet110_v2
-```
-
-To get results from a different ResNet, modify `--model`.
-
-Results:
-
-| Model        | Accuracy | Mix-Up |
-|--------------|----------|--------|
-| ResNet20_v1  | 0.9115   | 0.9161 |
-| ResNet20_v2  | 0.9117   | 0.9119 |
-| ResNet56_v2  | 0.9307   | 0.9414 |
-| ResNet110_v2 | 0.9414   | 0.9447 |
-
-Pretrained Model:
-
-| Model        | Accuracy |
-|--------------|----------|
-| ResNet20_v1  | 0.9160   |
-| ResNet56_v1  | 0.9387   |
-| ResNet110_v1 | 0.9471   |
-| ResNet20_v2  | 0.9130   |
-| ResNet56_v2  | 0.9413   |
-| ResNet110_v2 | 0.9464   |
-
-by script:
-
-```
-python train_mixup.py --num-epochs 450 --mode hybrid --num-gpus 2 -j 32 --batch-size 64 --wd 0.0001 --lr 0.1 --lr-decay 0.1 --lr-decay-epoch 150,250 --model cifar_resnet20_v1
-```
-
-## Wide ResNet
-
-Training a model on WRN-28-10 can be done with
-
-```
-python train.py --num-epochs 200 --mode hybrid --num-gpus 2 -j 32 --batch-size 64\
-    --wd 0.0005 --lr 0.1 --lr-decay 0.2 --lr-decay-epoch 60,120,160\
-    --model cifar_wideresnet28 --width-factor 10
-```
-
-With mixup, the command is
-
-```
-python train_mixup.py --num-epochs 350 --mode hybrid --num-gpus 2 -j 32 --batch-size 64\
-    --wd 0.0001 --lr 0.1 --lr-decay 0.1 --lr-decay-epoch 80,160,240\
-    --model cifar_wideresnet28 --width-factor 10
-```
-
-To get results from a different WRN, modify `--model` and `--width-factor`.
-
-Results:
-
-| Model        | Accuracy | Mix-Up |
-|--------------|----------|--------|
-| WRN-16-10    | 0.9527   | 0.9602 |
-| WRN-28-10    | 0.9584   | 0.9667 |
-| WRN-40-8     | 0.9559   | 0.9620 |
-
-Pretrained Model:
-
-| Model            | Accuracy |
-|------------------|----------|
-| WideResNet20_v1  | 0.9614   |
-| WideResNet56_v1  | 0.9667   |
-| WideResNet110_v1 | 0.9673   |
-
-by scripts:
-
-```
-python train_mixup.py --num-epochs 500 --mode hybrid --num-gpus 2 -j 32 --batch-size 64 --wd 0.0001 --lr 0.1 --lr-decay 0.1 --lr-decay-epoch 100,200,300 --model cifar_wideresnet16_10
-```
-
-**Parameters Explained**
-
-- `--batch-size`: per-device batch size for the training.
-- `--num-gpus`: the number of GPUs to use for computation, default is `0` and it means only using CPU.
-- `--model`: The model to train. For `CIFAR10` we offer [`ResNet`](https://github.com/dmlc/gluon-cv/blob/master/gluoncv/model_zoo/cifarresnet.py) and [`WideResNet`](https://github.com/dmlc/gluon-cv/blob/master/gluoncv/model_zoo/cifarwideresnet.py) as options.
-- `--num-data-workers`/`-j`: the number of data processing workers.
-- `--num-epochs`: the number of training epochs.
-- `--lr`: the initial learning rate in training.
-- `--momentum`: the momentum parameter.
-- `--wd`: the weight decay parameter.
-- `--lr-decay`: the learning rate decay factor.
-- `--lr-decay-period`: the learning rate decay period, i.e. for every `--lr-decay-period` epochs, the learning rate will decay by a factor of `--lr-decay`.
-- `--lr-decay-epoch`: epochs at which the learning rate decay by a factor of `--lr-decay`.
-- `--width-factor`: parameters for `WideResNet` model.
-- `--drop-rate`: parameters for `WideResNet` model.
-- `--mode`: whether to use `hybrid` mode to speed up the training process.
-- `--save-period`: for every `--save-period`, the model will be saved to disk.
-- `--save-dir`: the directory to save the models.
-- `--logging-dir`: the directory to save the training logs.
+# Image Classification on CIFAR10
+
+Please refer to [GluonCV Model Zoo](http://gluon-cv.mxnet.io/model_zoo/index.html#image-classification)
+for available pretrained models, training hyper-parameters, etc.
diff --git a/scripts/classification/imagenet/README.md b/scripts/classification/imagenet/README.md
index e6a877227c..92d2b85902 100644
--- a/scripts/classification/imagenet/README.md
+++ b/scripts/classification/imagenet/README.md
@@ -1,27 +1,4 @@
-# Image Classification
-
-Here we present an examples to train gluon on image classification tasks.
-
-## ImageNet
-
-Here we present examples of training resnet on ImageNet dataset.
-
-The main training script is `train_imagenet.py`. The script takes various parameters, thus we offer suggested parameters, and corresponding results.
-
-### ResNet50_v2
-
-Training a ResNet50_v2 can be done with:
-
-```
-python train_imagenet.py --batch-size 64 --num-gpus 4 -j 32 --mode hybrid\
-    --num-epochs 120 --lr 0.1 --momentum 0.9 --wd 0.0001\
-    --lr-decay 0.1 --lr-decay-epoch 30,60,90 --model resnet50_v2
-```
-
-Results:
-
-| Model       | Top-1 Error | Top-5 Error |
-|-------------|-------------|-------------|
-| ResNet50_v2 | 0.2428      | 0.0738      |
-
+# Image Classification on ImageNet
+
+Please refer to [GluonCV Model Zoo](http://gluon-cv.mxnet.io/model_zoo/index.html#image-classification)
+for available pretrained models, training hyper-parameters, etc.
diff --git a/scripts/classification/imagenet/train_imagenet.py b/scripts/classification/imagenet/train_imagenet.py
index d6ea7211a8..ea18d0dd14 100644
--- a/scripts/classification/imagenet/train_imagenet.py
+++ b/scripts/classification/imagenet/train_imagenet.py
@@ -1,4 +1,4 @@
-import argparse, time, logging, os
+import argparse, time, logging, os, math

 import numpy as np
 import mxnet as mx
@@ -61,6 +61,8 @@
                     help='type of model to use. see vision_model for options.')
 parser.add_argument('--input-size', type=int, default=224,
                     help='size of the input image size. default is 224')
+parser.add_argument('--crop-ratio', type=float, default=0.875,
+                    help='Crop ratio during validation. default is 0.875')
 parser.add_argument('--use-pretrained', action='store_true',
                     help='enable using pretrained model from gluon.')
 parser.add_argument('--use_se', action='store_true',
@@ -77,12 +79,18 @@
                     help='whether to remove weight decay on bias, and beta/gamma for batchnorm layers.')
 parser.add_argument('--batch-norm', action='store_true',
                     help='enable batch normalization or not in vgg. default is false.')
-parser.add_argument('--log-interval', type=int, default=50,
-                    help='Number of batches to wait before logging.')
 parser.add_argument('--save-frequency', type=int, default=10,
                     help='frequency of model saving.')
 parser.add_argument('--save-dir', type=str, default='params',
                     help='directory of saved models')
+parser.add_argument('--resume-epoch', type=int, default=0,
+                    help='epoch to resume training from.')
+parser.add_argument('--resume-params', type=str, default='',
+                    help='path of parameters to load from.')
+parser.add_argument('--resume-states', type=str, default='',
+                    help='path of trainer state to load from.')
+parser.add_argument('--log-interval', type=int, default=50,
+                    help='Number of batches to wait before logging.')
 parser.add_argument('--logging-file', type=str, default='train_imagenet.log',
                     help='name of training log file')
 opt = parser.parse_args()
@@ -137,6 +145,8 @@
 net = get_model(model_name, **kwargs)
 net.cast(opt.dtype)
+if opt.resume_params != '':
+    net.load_parameters(opt.resume_params, ctx = context)

 # Two functions for reading data from record file or raw images
 def get_data_rec(rec_train, rec_train_idx, rec_val, rec_val_idx, batch_size, num_workers):
@@ -147,6 +157,8 @@ def get_data_rec(rec_train, rec_train_idx, rec_val, rec_val_idx, batch_size, num_workers):
     jitter_param = 0.4
     lighting_param = 0.1
     input_size = opt.input_size
+    crop_ratio = opt.crop_ratio if opt.crop_ratio > 0 else 0.875
+    resize = int(math.ceil(input_size / crop_ratio))
     mean_rgb = [123.68, 116.779, 103.939]
     std_rgb = [58.393, 57.12, 57.375]

@@ -187,7 +199,7 @@ def batch_fn(batch, ctx):
         shuffle             = False,
         batch_size          = batch_size,
-        resize              = 256,
+        resize              = resize,
         data_shape          = (3, input_size, input_size),
         mean_r              = mean_rgb[0],
         mean_g              = mean_rgb[1],
@@ -203,6 +215,8 @@ def get_data_loader(data_dir, batch_size, num_workers):
     jitter_param = 0.4
     lighting_param = 0.1
     input_size = opt.input_size
+    crop_ratio = opt.crop_ratio if opt.crop_ratio > 0 else 0.875
+    resize = int(math.ceil(input_size / crop_ratio))

     def batch_fn(batch, ctx):
         data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
@@ -219,7 +233,7 @@ def batch_fn(batch, ctx):
         normalize
     ])
     transform_test = transforms.Compose([
-        transforms.Resize(256, keep_ratio=True),
+        transforms.Resize(resize, keep_ratio=True),
         transforms.CenterCrop(input_size),
         transforms.ToTensor(),
         normalize
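As a standalone illustration (not part of the patch): the new `--crop-ratio` flag generalizes the hard-coded resize of 256 in both loaders above, so that `CenterCrop(input_size)` keeps `crop_ratio` of the resized shorter edge. A quick sanity check of the arithmetic (the 299-input case is only an example of the kind of model, such as InceptionV3, that benefits):

```python
# Validation resize target: resize = ceil(input_size / crop_ratio), so that
# CenterCrop(input_size) keeps exactly crop_ratio of the resized shorter edge.
import math

def val_resize(input_size, crop_ratio=0.875):
    return int(math.ceil(input_size / crop_ratio))

print(val_resize(224))        # 256 -- reproduces the old hard-coded value
print(val_resize(299))        # 342 -- e.g. for a 299-input model like InceptionV3
print(224 / val_resize(224))  # 0.875 -- the crop ratio, recovered
```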
@@ -256,11 +270,14 @@ def batch_fn(batch, ctx):
     save_dir = ''
     save_frequency = 0

-def label_transform(label, classes, eta=0.0):
-    ind = label.astype('int')
-    res = nd.zeros((ind.shape[0], classes), ctx = label.context)
-    res += eta/classes
-    res[nd.arange(ind.shape[0], ctx = label.context), ind] = 1 - eta + eta/classes
+def mixup_transform(label, classes, lam=1, eta=0.0):
+    if isinstance(label, nd.NDArray):
+        label = [label]
+    res = []
+    for l in label:
+        y1 = l.one_hot(classes, on_value = 1 - eta + eta/classes, off_value = eta/classes)
+        y2 = l[::-1].one_hot(classes, on_value = 1 - eta + eta/classes, off_value = eta/classes)
+        res.append(lam*y1 + (1-lam)*y2)
     return res

 def smooth(label, classes, eta=0.1):
@@ -268,10 +285,7 @@ def smooth(label, classes, eta=0.1):
         label = [label]
     smoothed = []
     for l in label:
-        ind = l.astype('int')
-        res = nd.zeros((ind.shape[0], classes), ctx = l.context)
-        res += eta/classes
-        res[nd.arange(ind.shape[0], ctx = l.context), ind] = 1 - eta + eta/classes
+        res = l.one_hot(classes, on_value = 1 - eta + eta/classes, off_value = eta/classes)
         smoothed.append(res)
     return smoothed
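As a standalone illustration (not part of the patch): both rewritten helpers lean on `NDArray.one_hot`, which accepts custom `on_value`/`off_value`, instead of filling a zero matrix by hand. A small check of what they compute; the batch of 2, `classes=4`, `eta=0.1`, and `lam=0.7` are arbitrary toy values:

```python
# Toy check of the one_hot-based label helpers from the hunks above.
from mxnet import nd

label = nd.array([2, 0])
classes, eta, lam = 4, 0.1, 0.7
on_value = 1 - eta + eta / classes   # 0.925 on the true class
off_value = eta / classes            # 0.025 everywhere else

# smooth(): soft targets for label smoothing.
smoothed = label.one_hot(classes, on_value=on_value, off_value=off_value)
print(smoothed)  # rows: [0.025 0.025 0.925 0.025] and [0.925 0.025 0.025 0.025]

# mixup_transform(): blend each target with the batch-reversed one, mirroring
# the image mixing data = [lam*X + (1-lam)*X[::-1] for X in data].
rev = label[::-1].one_hot(classes, on_value=on_value, off_value=off_value)
print(lam * smoothed + (1 - lam) * rev)
```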
@@ -293,13 +307,16 @@ def test(ctx, val_data):
 def train(ctx):
     if isinstance(ctx, mx.Context):
         ctx = [ctx]
-    net.initialize(mx.init.MSRAPrelu(), ctx=ctx)
+    if opt.resume_params == '':
+        net.initialize(mx.init.MSRAPrelu(), ctx=ctx)

     if opt.no_wd:
         for k, v in net.collect_params('.*beta|.*gamma|.*bias').items():
             v.wd_mult = 0.0

     trainer = gluon.Trainer(net.collect_params(), optimizer, optimizer_params)
+    if opt.resume_states != '':
+        trainer.load_states(opt.resume_states)

     if opt.label_smoothing or opt.mixup:
         L = gluon.loss.SoftmaxCrossEntropyLoss(sparse_label=False)
@@ -308,7 +325,7 @@ def train(ctx):

     best_val_score = 1

-    for epoch in range(opt.num_epochs):
+    for epoch in range(opt.resume_epoch, opt.num_epochs):
         tic = time.time()
         if opt.use_rec:
             train_data.reset()
@@ -322,21 +339,16 @@ def train(ctx):
                 lam = np.random.beta(opt.mixup_alpha, opt.mixup_alpha)
                 if epoch >= opt.num_epochs - opt.mixup_off_epoch:
                     lam = 1
-                data_mixup = [lam*X + (1-lam)*X[::-1] for X in data]
+                data = [lam*X + (1-lam)*X[::-1] for X in data]

-                label_mixup = []
                 if opt.label_smoothing:
                     eta = 0.1
                 else:
                     eta = 0.0
-                for Y in label:
-                    y1 = label_transform(Y, classes, eta)
-                    y2 = label_transform(Y[::-1], classes, eta)
-                    label_mixup.append(lam*y1 + (1-lam)*y2)
+                label = mixup_transform(label, classes, lam, eta)

-                data = data_mixup
-                label = label_mixup
             elif opt.label_smoothing:
+                hard_label = label
                 label = smooth(label, classes)

             with ag.record():
@@ -352,7 +364,10 @@ def train(ctx):
                                    for out in outputs]
                 train_metric.update(label, output_softmax)
             else:
-                train_metric.update(label, outputs)
+                if opt.label_smoothing:
+                    train_metric.update(hard_label, outputs)
+                else:
+                    train_metric.update(label, outputs)

             if opt.log_interval and not (i+1)%opt.log_interval:
                 train_metric_name, train_metric_score = train_metric.get()
@@ -370,15 +385,18 @@ def train(ctx):
         logger.info('[Epoch %d] speed: %d samples/sec\ttime cost: %f'%(epoch, throughput, time.time()-tic))
         logger.info('[Epoch %d] validation: err-top1=%f err-top5=%f'%(epoch, err_top1_val, err_top5_val))

-        if err_top1_val < best_val_score and epoch > 50:
+        if err_top1_val < best_val_score:
             best_val_score = err_top1_val
             net.save_parameters('%s/%.4f-imagenet-%s-%d-best.params'%(save_dir, best_val_score, model_name, epoch))
+            trainer.save_states('%s/%.4f-imagenet-%s-%d-best.states'%(save_dir, best_val_score, model_name, epoch))

         if save_frequency and save_dir and (epoch + 1) % save_frequency == 0:
             net.save_parameters('%s/imagenet-%s-%d.params'%(save_dir, model_name, epoch))
+            trainer.save_states('%s/imagenet-%s-%d.states'%(save_dir, model_name, epoch))

     if save_frequency and save_dir:
         net.save_parameters('%s/imagenet-%s-%d.params'%(save_dir, model_name, opt.num_epochs-1))
+        trainer.save_states('%s/imagenet-%s-%d.states'%(save_dir, model_name, opt.num_epochs-1))

 def main():
     if opt.mode == 'hybrid':
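A standalone sketch (not part of the patch) of the save/resume round trip the new flags enable; the model name, file names, and epoch numbers here are made up for illustration:

```python
# Sketch of checkpointing and resuming both halves of training state,
# mirroring --resume-params / --resume-states / --resume-epoch above.
import mxnet as mx
from mxnet import gluon
from gluoncv.model_zoo import get_model

net = get_model('resnet50_v1b', classes=1000)
net.initialize(mx.init.MSRAPrelu())
trainer = gluon.Trainer(net.collect_params(), 'nag',
                        {'learning_rate': 0.1, 'momentum': 0.9, 'wd': 0.0001})

# ... after training through epoch 9, checkpoint weights and optimizer state:
net.save_parameters('imagenet-resnet50_v1b-9.params')
trainer.save_states('imagenet-resnet50_v1b-9.states')  # momentum buffers etc.

# A later run restores both, then continues from epoch 10 so the learning-rate
# schedule and momentum pick up where they left off.
net.load_parameters('imagenet-resnet50_v1b-9.params')
trainer.load_states('imagenet-resnet50_v1b-9.states')
for epoch in range(10, 120):
    pass  # training loop body goes here
```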
diff --git a/scripts/classification/imagenet/verify_pretrained.py b/scripts/classification/imagenet/verify_pretrained.py
index 678a0d89e7..66f80cacca 100644
--- a/scripts/classification/imagenet/verify_pretrained.py
+++ b/scripts/classification/imagenet/verify_pretrained.py
@@ -1,7 +1,6 @@
-import argparse, os
+import argparse, os, math

 import mxnet as mx
-import math
 from mxnet import gluon, nd, image
 from mxnet.gluon.nn import Block, HybridBlock
 from mxnet.gluon.data.vision import transforms
@@ -68,7 +67,7 @@
 ratio set as 0.875; Set the crop as ceil(input-size/ratio)
 """
 crop_ratio = opt.crop_ratio if opt.crop_ratio > 0 else 0.875
-resize = math.ceil(input_size/crop_ratio)
+resize = int(math.ceil(input_size/crop_ratio))
 transform_test = transforms.Compose([
     transforms.Resize(resize, keep_ratio=True),
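As a standalone illustration (not part of the patch): the `int(...)` wrapper is the substance of this last fix. On Python 2, `math.ceil` returns a float, which is not a valid integer size for `transforms.Resize`; on Python 3 it already returns an int, so the wrapper makes the script behave the same on both:

```python
# Why int(math.ceil(...)): math.ceil returns a float on Python 2.
import math

input_size, crop_ratio = 224, 0.875
resize = math.ceil(input_size / crop_ratio)
print(resize)       # Python 2: 256.0 (float); Python 3: 256 (int)
print(int(resize))  # 256 on both, a size transforms.Resize accepts
```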