Add AMP to ImageNet classification and segmentation scripts + auto layout #1201
base: master
ImageNet classification training script:
@@ -6,6 +6,7 @@
 from mxnet import gluon, nd
 from mxnet import autograd as ag
 from mxnet.gluon.data.vision import transforms
+from mxnet.contrib import amp

 import gluoncv as gcv
 gcv.utils.check_version('0.6.0')

@@ -104,6 +105,10 @@ def parse_args():
                         help='name of training log file')
     parser.add_argument('--use-gn', action='store_true',
                         help='whether to use group norm.')
+    parser.add_argument('--amp', action='store_true',
+                        help='Use MXNet AMP for mixed precision training.')
+    parser.add_argument('--auto-layout', action='store_true',
+                        help='Add layout optimization to AMP. Must be used together with `--amp`.')
     opt = parser.parse_args()
     return opt

@@ -121,6 +126,11 @@ def main():

     logger.info(opt)

+    assert not opt.auto_layout or opt.amp, "--auto-layout needs to be used with --amp"
+
+    if opt.amp:
+        amp.init(layout_optimization=opt.auto_layout)
Review thread on `amp.init(layout_optimization=...)`:
- Referring to definition of
- It's an internal feature, it will be added soon
- Thanks for clarification.
- Just curiously, when setting both

     batch_size = opt.batch_size
     classes = 1000
     num_training_samples = 1281167

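For readers unfamiliar with the flow, here is a minimal sketch of how the two new flags are meant to interact at startup. The helper name `maybe_init_amp` is made up for illustration, and the `layout_optimization` keyword of `amp.init()` is assumed to exist only in internal/nightly MXNet builds, as the thread above notes:

```python
from mxnet.contrib import amp

def maybe_init_amp(opt):
    # --auto-layout only makes sense on top of AMP, so reject it on its own.
    assert not opt.auto_layout or opt.amp, "--auto-layout needs to be used with --amp"
    if opt.amp:
        # Patches eligible operators to run in reduced precision; when requested,
        # also lets AMP pick a faster tensor layout for convolution-heavy nets.
        # NOTE: layout_optimization is an assumed/internal keyword, not yet in
        # stock MXNet releases.
        amp.init(layout_optimization=opt.auto_layout)
```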
@@ -347,10 +357,13 @@ def train(ctx):
     for k, v in net.collect_params('.*beta|.*gamma|.*bias').items():
         v.wd_mult = 0.0

-    trainer = gluon.Trainer(net.collect_params(), optimizer, optimizer_params)
+    trainer = gluon.Trainer(net.collect_params(), optimizer, optimizer_params, update_on_kvstore=(False if opt.amp else None))
     if opt.resume_states != '':
         trainer.load_states(opt.resume_states)

+    if opt.amp:
Review comment: Here may need change to
+        amp.init_trainer(trainer)
+
     if opt.label_smoothing or opt.mixup:
         sparse_label_loss = False
     else:

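Regarding the `update_on_kvstore` change: AMP's dynamic loss scaling needs to unscale (and, on overflow, skip) the gradients before the optimizer applies them, which is only possible when the update runs through the trainer rather than on the KVStore. A rough, self-contained sketch of the intended setup order, with a placeholder model and a `use_amp` flag standing in for `opt.amp`:

```python
import mxnet as mx
from mxnet import gluon
from mxnet.contrib import amp

use_amp = True            # placeholder for opt.amp
if use_amp:
    amp.init()            # patch operators before the network is built

net = gluon.nn.Dense(10)  # placeholder model
net.initialize(ctx=mx.cpu())

# With AMP, weight updates must happen in the trainer (not on the KVStore)
# so the loss scale can be applied and removed around each update.
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.1},
                        update_on_kvstore=False if use_amp else None)
if use_amp:
    amp.init_trainer(trainer)  # attach the dynamic loss scaler to this trainer
```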
@@ -402,8 +415,13 @@ def train(ctx):
                                 p.astype('float32', copy=False)) for yhat, y, p in zip(outputs, label, teacher_prob)]
                 else:
                     loss = [L(yhat, y.astype(opt.dtype, copy=False)) for yhat, y in zip(outputs, label)]
-            for l in loss:
-                l.backward()
+            if opt.amp:
Review comment: Here may need change to
+                with amp.scale_loss(loss, trainer) as scaled_loss:
+                    ag.backward(scaled_loss)
+            else:
+                for l in loss:
+                    l.backward()

             trainer.step(batch_size)

             if opt.mixup:

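Putting the pieces together, one AMP training iteration looks roughly like the sketch below. The tiny Dense model, random data, and CPU context are placeholders chosen so the snippet is self-contained; in practice AMP targets GPU training:

```python
import mxnet as mx
from mxnet import autograd, gluon
from mxnet.contrib import amp

amp.init()  # enable mixed precision before building the network

net = gluon.nn.Dense(10)
net.initialize(ctx=mx.cpu())
trainer = gluon.Trainer(net.collect_params(), 'sgd',
                        {'learning_rate': 0.1}, update_on_kvstore=False)
amp.init_trainer(trainer)

loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
data = mx.nd.random.uniform(shape=(8, 32))
label = mx.nd.random.randint(0, 10, shape=(8,)).astype('float32')

with autograd.record():
    out = net(data)
    loss = loss_fn(out, label)

# Scale the loss before backward so reduced-precision gradients do not underflow;
# trainer.step() then divides the gradients by the same scale.
with amp.scale_loss(loss, trainer) as scaled_loss:
    autograd.backward(scaled_loss)

trainer.step(data.shape[0])
```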
Semantic segmentation training script:
@@ -8,6 +8,7 @@
 import mxnet as mx
 from mxnet import gluon, autograd
 from mxnet.gluon.data.vision import transforms
+from mxnet.contrib import amp

 import gluoncv
 gluoncv.utils.check_version('0.6.0')

@@ -99,6 +100,11 @@ def parse_args():
     # synchronized Batch Normalization
     parser.add_argument('--syncbn', action='store_true', default=False,
                         help='using Synchronized Cross-GPU BatchNorm')
+    # performance related
+    parser.add_argument('--amp', action='store_true',
Review comment: We usually add
+                        help='Use MXNet AMP for mixed precision training.')
+    parser.add_argument('--auto-layout', action='store_true',
+                        help='Add layout optimization to AMP. Must be used together with `--amp`.')
     # the parser
     args = parser.parse_args()

@@ -229,7 +235,12 @@ def __init__(self, args, logger):
                 v.wd_mult = 0.0

         self.optimizer = gluon.Trainer(self.net.module.collect_params(), args.optimizer,
-                                       optimizer_params, kvstore=kv)
+                                       optimizer_params, update_on_kvstore=(False if args.amp else None))
Review comment: May I know why
+
+        if args.amp:
+            amp.init_trainer(self.optimizer)

         # evaluation metrics
         self.metric = gluoncv.utils.metrics.SegmentationMetric(trainset.num_class)

@@ -241,7 +252,11 @@ def training(self, epoch):
                 outputs = self.net(data.astype(args.dtype, copy=False))
                 losses = self.criterion(outputs, target)
                 mx.nd.waitall()
-                autograd.backward(losses)
+                if args.amp:
+                    with amp.scale_loss(losses, self.optimizer) as scaled_losses:
+                        autograd.backward(scaled_losses)
+                else:
+                    autograd.backward(losses)
             self.optimizer.step(self.args.batch_size)
             for loss in losses:
                 train_loss += np.mean(loss.asnumpy()) / len(losses)

@@ -281,7 +296,10 @@ def save_checkpoint(net, args, epoch, mIoU, is_best=False):

 if __name__ == "__main__":
     args = parse_args()
+    assert not args.auto_layout or args.amp, "--auto-layout needs to be used with --amp"

+    if args.amp:
+        amp.init(layout_optimization=args.auto_layout)
     # build logger
     filehandler = logging.FileHandler(os.path.join(args.save_dir, args.logging_file))
     streamhandler = logging.StreamHandler()

Review comment: Could you also add an option like `--target-dtype`, since we now have not only `float16` for AMP but also `bfloat16`? Then we can pass `target-dtype` to `amp.init()` to enable float16/bfloat16 training for GPU and CPU respectively. Thanks.
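As a rough illustration of this suggestion (not part of the patch): `amp.init()` already accepts a `target_dtype` argument, so the flag could be wired through roughly as below, assuming an MXNet version whose AMP supports `bfloat16`. The `--target-dtype` option itself is hypothetical:

```python
import argparse
from mxnet.contrib import amp

parser = argparse.ArgumentParser()
parser.add_argument('--amp', action='store_true',
                    help='Use MXNet AMP for mixed precision training.')
# Hypothetical flag from the review: float16 targets GPUs, bfloat16 targets CPUs.
parser.add_argument('--target-dtype', type=str, default='float16',
                    choices=['float16', 'bfloat16'],
                    help='Reduced-precision dtype AMP should cast to.')
args = parser.parse_args()

if args.amp:
    # target_dtype selects which low-precision type AMP casts eligible ops to.
    amp.init(target_dtype=args.target_dtype)
```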