Skip to content

Commit

Permalink
Add auto layout to classification, detection and segmentation scripts
Browse files Browse the repository at this point in the history
Signed-off-by: Serge Panev <[email protected]>
  • Loading branch information
Kh4L committed Mar 11, 2020
1 parent c2390f6 commit 21c9d60
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 6 deletions.
6 changes: 5 additions & 1 deletion scripts/classification/imagenet/train_imagenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ def parse_args():
help='whether to use group norm.')
parser.add_argument('--amp', action='store_true',
help='Use MXNet AMP for mixed precision training.')
parser.add_argument('--auto-layout', action='store_true',
help='Add layout optimization to AMP. Must be used in addition of `--amp`.')
opt = parser.parse_args()
return opt

Expand All @@ -125,8 +127,10 @@ def main():

logger.info(opt)

assert not opt.auto_layout or opt.amp, "--auto-layout needs to be used with --amp"

if opt.amp:
amp.init()
amp.init(opt.auto_layout)

batch_size = opt.batch_size
classes = 1000
Expand Down
6 changes: 5 additions & 1 deletion scripts/detection/faster_rcnn/train_faster_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,8 @@ def parse_args():
help='Whether to use static memory allocation. Memory usage will increase.')
parser.add_argument('--amp', action='store_true',
help='Use MXNet AMP for mixed precision training.')
parser.add_argument('--auto-layout', action='store_true',
help='Add layout optimization to AMP. Must be used in addition of `--amp`.')
parser.add_argument('--horovod', action='store_true',
help='Use MXNet Horovod for distributed training. Must be run with OpenMPI. '
'--gpus is ignored when using --horovod.')
Expand Down Expand Up @@ -622,8 +624,10 @@ def train(net, train_data, val_data, eval_metric, batch_size, ctx, args):
# fix seed for mxnet, numpy and python builtin random generator.
gutils.random.seed(args.seed)

assert not args.auto_layout or args.amp, "--auto-layout needs to be used with --amp"

if args.amp:
amp.init()
amp.init(args.auto_layout)

# training contexts
if args.horovod:
Expand Down
6 changes: 5 additions & 1 deletion scripts/detection/ssd/train_ssd.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ def parse_args():
'Currently supports only COCO.')
parser.add_argument('--amp', action='store_true',
help='Use MXNet AMP for mixed precision training.')
parser.add_argument('--auto-layout', action='store_true',
help='Add layout optimization to AMP. Must be used in addition of `--amp`.')
parser.add_argument('--horovod', action='store_true',
help='Use MXNet Horovod for distributed training. Must be run with OpenMPI. '
'--gpus is ignored when using --horovod.')
Expand Down Expand Up @@ -360,8 +362,10 @@ def train(net, train_data, val_data, eval_metric, ctx, args):
if __name__ == '__main__':
args = parse_args()

assert not args.auto_layout or args.amp, "--auto-layout needs to be used with --amp"

if args.amp:
amp.init()
amp.init(args.auto_layout)

if args.horovod:
hvd.init()
Expand Down
6 changes: 5 additions & 1 deletion scripts/detection/yolo/train_yolo3.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@ def parse_args():
parser.add_argument('--label-smooth', action='store_true', help='Use label smoothing.')
parser.add_argument('--amp', action='store_true',
help='Use MXNet AMP for mixed precision training.')
parser.add_argument('--auto-layout', action='store_true',
help='Add layout optimization to AMP. Must be used in addition of `--amp`.')
parser.add_argument('--horovod', action='store_true',
help='Use MXNet Horovod for distributed training. Must be run with OpenMPI. '
'--gpus is ignored when using --horovod.')
Expand Down Expand Up @@ -325,8 +327,10 @@ def train(net, train_data, val_data, eval_metric, ctx, args):
if __name__ == '__main__':
args = parse_args()

assert not args.auto_layout or args.amp, "--auto-layout needs to be used with --amp"

if args.amp:
amp.init()
amp.init(args.auto_layout)

if args.horovod:
if hvd is None:
Expand Down
6 changes: 5 additions & 1 deletion scripts/instance/mask_rcnn/train_mask_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,8 @@ def parse_args():
help='Whether to use static memory allocation. Memory usage will increase.')
parser.add_argument('--amp', action='store_true',
help='Use MXNet AMP for mixed precision training.')
parser.add_argument('--auto-layout', action='store_true',
help='Add layout optimization to AMP. Must be used in addition of `--amp`.')
parser.add_argument('--horovod', action='store_true',
help='Use MXNet Horovod for distributed training. Must be run with OpenMPI. '
'--gpus is ignored when using --horovod.')
Expand Down Expand Up @@ -700,8 +702,10 @@ def train(net, train_data, val_data, eval_metric, batch_size, ctx, logger, args)
# fix seed for mxnet, numpy and python builtin random generator.
gutils.random.seed(args.seed)

assert not args.auto_layout or args.amp, "--auto-layout needs to be used with --amp"

if args.amp:
amp.init()
amp.init(args.auto_layout)

# training contexts
if args.horovod:
Expand Down
6 changes: 5 additions & 1 deletion scripts/segmentation/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ def parse_args():
# performance related
parser.add_argument('--amp', action='store_true',
help='Use MXNet AMP for mixed precision training.')
parser.add_argument('--auto-layout', action='store_true',
help='Add layout optimization to AMP. Must be used in addition of `--amp`.')
# handle contexts
if args.no_cuda:
print('Using CPU')
Expand Down Expand Up @@ -263,8 +265,10 @@ def save_checkpoint(net, args, epoch, mIoU, is_best=False):

if __name__ == "__main__":
args = parse_args()
assert not args.auto_layout or args.amp, "--auto-layout needs to be used with --amp"

if args.amp:
amp.init()
amp.init(args.auto_layout)
# build logger
filehandler = logging.FileHandler(os.path.join(args.save_dir, args.logging_file))
streamhandler = logging.StreamHandler()
Expand Down

0 comments on commit 21c9d60

Please sign in to comment.