diff --git a/scripts/instance/mask_rcnn/train_mask_rcnn.py b/scripts/instance/mask_rcnn/train_mask_rcnn.py
index 4910be5afd..09bff75f04 100644
--- a/scripts/instance/mask_rcnn/train_mask_rcnn.py
+++ b/scripts/instance/mask_rcnn/train_mask_rcnn.py
@@ -113,6 +113,11 @@ def parse_args():
     parser.add_argument('--kv-store', type=str, default='nccl',
                         help='KV store options. local, device, nccl, dist_sync, dist_device_sync, '
                              'dist_async are available.')
+    parser.add_argument('--use-pinned', action='store_true',
+                        help='Whether to use pinned memory buffers to stage the input data.')
+    parser.add_argument('--pinned-buffer-size', type=int, default=0,
+                        help='Size of the staged memory buffers for input data. If the value is <= 0, '
+                             'the buffers will be dynamically reshaped. Default is 0.')
 
     args = parser.parse_args()
     if args.horovod:
@@ -124,6 +129,8 @@ def parse_args():
         args.lr = float(args.lr) if args.lr else 0.01
         args.lr_warmup = args.lr_warmup if args.lr_warmup else 1000
         args.wd = float(args.wd) if args.wd else 1e-4
+    global use_pinned, pinned_buffer_size
+    use_pinned, pinned_buffer_size = args.use_pinned, args.pinned_buffer_size
     return args
 
 
@@ -175,6 +182,7 @@ def save_params(net, logger, best_map, current_map, epoch, save_interval, prefix
 
 
 def _stage_data(i, data, ctx_list, pinned_data_stage):
+    global pinned_buffer_size
     def _get_chunk(data, storage):
         s = storage.reshape(shape=(storage.size,))
         s = s[:data.size]
@@ -192,7 +200,12 @@ def _get_chunk(data, storage):
 
     for j in range(len(storage)):
         if data[j].size > storage[j].size:
-            storage[j] = data[j].as_in_context(mx.cpu_pinned())
+            if data[j].size > pinned_buffer_size:
+                storage[j] = data[j].as_in_context(mx.cpu_pinned())
+            else:
+                storage[j] = mx.nd.zeros(shape=(pinned_buffer_size),
+                                         dtype=data[j].dtype,
+                                         ctx=mx.cpu_pinned())
 
     return [_get_chunk(d, s) for d, s in zip(data, storage)]
 
@@ -204,14 +217,19 @@ def split_and_load(batch, ctx_list):
     """Split data to 1 batch each device."""
     new_batch = []
     for i, data in enumerate(batch):
-        if isinstance(data, (list, tuple)):
-            new_data = [x.as_in_context(ctx) for x, ctx in zip(data, ctx_list)]
+        if not isinstance(data, (list, tuple)):
+            data = [data]
+        global use_pinned
+        if use_pinned:
+            staged_data = _stage_data(i, data, ctx_list, pinned_data_stage)
+            new_data = [x.as_in_context(ctx) for x, ctx in zip(staged_data, ctx_list)]
         else:
-            new_data = [data.as_in_context(ctx_list[0])]
+            new_data = [x.as_in_context(ctx) for x, ctx in zip(data, ctx_list)]
         new_batch.append(new_data)
     return new_batch
 
 
+
 def validate(net, val_data, ctx, eval_metric, args):
     """Test on validation dataset."""
     clipper = gcv.nn.bbox.BBoxClipToImage()
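For context, here is a minimal, self-contained sketch of the staging pattern the patch introduces: keep one reusable page-locked (pinned) host buffer per input slot, copy each incoming NDArray into a view of that buffer, and only then move it to the device. The helper name `stage_through_pinned`, the `_buffers` dict and `PINNED_BUFFER_SIZE` below are illustrative only; the actual script keeps its buffers in `pinned_data_stage` and sizes them from `--pinned-buffer-size`.

```python
import mxnet as mx

PINNED_BUFFER_SIZE = 1 << 20   # element count; plays the role of --pinned-buffer-size
_buffers = {}                  # slot index -> reusable pinned staging NDArray


def stage_through_pinned(slot, data):
    """Copy `data` into a reusable pinned host buffer and return a correctly
    shaped view of it, ready for a fast host-to-device transfer."""
    buf = _buffers.get(slot)
    if buf is None or data.size > buf.size:
        # Grow the buffer: either to the fixed size or, if the sample is larger,
        # to exactly the sample size (the dynamic-reshape fallback in the patch).
        size = max(data.size, PINNED_BUFFER_SIZE)
        buf = mx.nd.zeros((size,), dtype=data.dtype, ctx=mx.cpu_pinned())
        _buffers[slot] = buf
    flat = buf[:data.size]              # view over the first data.size elements
    data.reshape((-1,)).copyto(flat)    # ordinary host -> pinned host copy
    return flat.reshape(data.shape)


if __name__ == '__main__':
    ctx = mx.gpu(0) if mx.context.num_gpus() > 0 else mx.cpu()
    batch_part = mx.nd.random.uniform(shape=(2, 3, 256, 256))
    staged = stage_through_pinned(0, batch_part)
    on_device = staged.as_in_context(ctx)   # pinned -> device copy
    print(on_device.shape, on_device.context)
```

The point of staging through pinned memory is that host-to-device copies from page-locked buffers skip the driver's internal pageable-to-pinned copy and can overlap with compute; pre-sizing the buffers via `--pinned-buffer-size` additionally avoids reallocating (and re-pinning) them whenever a larger sample arrives.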