diff --git a/scripts/segmentation/test.py b/scripts/segmentation/test.py index a1066d725c..df0f736490 100644 --- a/scripts/segmentation/test.py +++ b/scripts/segmentation/test.py @@ -130,8 +130,8 @@ def test(model, args, input_transform): im_paths = dsts predicts = evaluator.parallel_forward(data) for predict, impath in zip(predicts, im_paths): - predict = mx.nd.squeeze(mx.nd.argmax(predict[0], 1)).asnumpy() + \ - testset.pred_offset + predict = mx.nd.squeeze(mx.nd.argmax(predict, 1), axis=0).\ + asnumpy() + testset.pred_offset mask = get_color_pallete(predict, args.dataset) outname = os.path.splitext(impath)[0] + '.png' mask.save(os.path.join(outdir, outname)) @@ -194,6 +194,7 @@ def benchmarking(model, args): if __name__ == "__main__": args = parse_args() + args.test_batch_size = max(1, args.ngpus) logging.basicConfig() logger = logging.getLogger('logger') logger.setLevel(logging.INFO) diff --git a/scripts/segmentation/train.py b/scripts/segmentation/train.py index aada216d17..a51a00f031 100644 --- a/scripts/segmentation/train.py +++ b/scripts/segmentation/train.py @@ -88,6 +88,8 @@ def parse_args(): # the parser args = parser.parse_args() # handle contexts + if args.ngpus == 0: + args.no_cuda = True if args.no_cuda: print('Using CPU') args.kvstore = 'local'