diff --git a/demo.py b/demo.py index ededbdb..d7ccf87 100644 --- a/demo.py +++ b/demo.py @@ -5,15 +5,10 @@ import importlib import sys from detect.detector import Detector +from symbol.symbol_factory import get_symbol -CLASSES = ('aeroplane', 'bicycle', 'bird', 'boat', - 'bottle', 'bus', 'car', 'cat', 'chair', - 'cow', 'diningtable', 'dog', 'horse', - 'motorbike', 'person', 'pottedplant', - 'sheep', 'sofa', 'train', 'tvmonitor') - -def get_detector(net, prefix, epoch, data_shape, mean_pixels, ctx, - nms_thresh=0.5, force_nms=True): +def get_detector(net, prefix, epoch, data_shape, mean_pixels, ctx, num_class, + nms_thresh=0.5, force_nms=True, nms_topk=400): """ wrapper for initialize a detector @@ -31,23 +26,25 @@ def get_detector(net, prefix, epoch, data_shape, mean_pixels, ctx, mean pixel values (R, G, B) ctx : mx.ctx running context, mx.cpu() or mx.gpu(?) + num_class : int + number of classes + nms_thresh : float + non-maximum suppression threshold force_nms : bool force suppress different categories """ - sys.path.append(os.path.join(os.getcwd(), 'symbol')) if net is not None: - net = importlib.import_module("symbol_" + net) \ - .get_symbol(len(CLASSES), nms_thresh, force_nms) - detector = Detector(net, prefix + "_" + str(data_shape), epoch, \ - data_shape, mean_pixels, ctx=ctx) + net = get_symbol(net, data_shape, num_classes=num_class, nms_thresh=nms_thresh, + force_nms=force_nms, nms_topk=nms_topk) + detector = Detector(net, prefix, epoch, data_shape, mean_pixels, ctx=ctx) return detector def parse_args(): parser = argparse.ArgumentParser(description='Single-shot detection network demo') - parser.add_argument('--network', dest='network', type=str, default='vgg16_ssd_300', - choices=['vgg16_ssd_300', 'vgg16_ssd_512'], help='which network to use') + parser.add_argument('--network', dest='network', type=str, default='vgg16_reduced', + help='which network to use') parser.add_argument('--images', dest='images', type=str, default='./data/demo/dog.jpg', - help='run demo with images, use comma(without space) to seperate multiple images') + help='run demo with images, use comma to seperate multiple images') parser.add_argument('--dir', dest='dir', nargs='?', help='demo image directory, optional', type=str) parser.add_argument('--ext', dest='extension', help='image extension, optional', @@ -55,7 +52,8 @@ def parse_args(): parser.add_argument('--epoch', dest='epoch', help='epoch of trained model', default=0, type=int) parser.add_argument('--prefix', dest='prefix', help='trained model prefix', - default=os.path.join(os.getcwd(), 'model', 'ssd'), type=str) + default=os.path.join(os.getcwd(), 'model', 'ssd_vgg16_reduced_300'), + type=str) parser.add_argument('--cpu', dest='cpu', help='(override GPU) use CPU to detect', action='store_true', default=False) parser.add_argument('--gpu', dest='gpu_id', type=int, default=0, @@ -78,9 +76,29 @@ def parse_args(): help='show detection time') parser.add_argument('--deploy', dest='deploy_net', action='store_true', default=False, help='Load network from json file, rather than from symbol') + parser.add_argument('--class-names', dest='class_names', type=str, + default='aeroplane, bicycle, bird, boat, bottle, bus, \ + car, cat, chair, cow, diningtable, dog, horse, motorbike, \ + person, pottedplant, sheep, sofa, train, tvmonitor', + help='string of comma separated names, or text filename') args = parser.parse_args() return args +def parse_class_names(class_names): + """ parse # classes and class_names if applicable """ + if len(class_names) > 0: + if os.path.isfile(class_names): + # try to open it to read class names + with open(class_names, 'r') as f: + class_names = [l.strip() for l in f.readlines()] + else: + class_names = [c.strip() for c in class_names.split(',')] + for name in class_names: + assert len(name) > 0 + else: + raise RuntimeError("No valid class_name provided...") + return class_names + if __name__ == '__main__': args = parse_args() if args.cpu: @@ -93,10 +111,11 @@ def parse_args(): assert len(image_list) > 0, "No valid image specified to detect" network = None if args.deploy_net else args.network + class_names = parse_class_names(args.class_names) detector = get_detector(network, args.prefix, args.epoch, args.data_shape, (args.mean_r, args.mean_g, args.mean_b), - ctx, args.nms_thresh, args.force_nms) + ctx, len(class_names), args.nms_thresh, args.force_nms) # run detection detector.detect_and_visualize(image_list, args.dir, args.extension, - CLASSES, args.thresh, args.show_timer) + class_names, args.thresh, args.show_timer) diff --git a/deploy.py b/deploy.py index 264314a..5f6b8b0 100644 --- a/deploy.py +++ b/deploy.py @@ -5,6 +5,7 @@ import os import importlib import sys +from symbol.symbol_factory import get_symbol def parse_args(): parser = argparse.ArgumentParser(description='Convert a trained model to deploy model') @@ -14,20 +15,24 @@ def parse_args(): default=0, type=int) parser.add_argument('--prefix', dest='prefix', help='trained model prefix', default=os.path.join(os.getcwd(), 'model', 'ssd_300'), type=str) + parser.add_argument('--data-shape', dest='data_shape', type=int, default=300, + help='data shape') parser.add_argument('--num-class', dest='num_classes', help='number of classes', default=20, type=int) parser.add_argument('--nms', dest='nms_thresh', type=float, default=0.5, help='non-maximum suppression threshold, default 0.5') parser.add_argument('--force', dest='force_nms', type=bool, default=True, help='force non-maximum suppression on different class') + parser.add_argument('--topk', dest='nms_topk', type=int, default=400, + help='apply nms only to top k detections based on scores.') args = parser.parse_args() return args if __name__ == '__main__': args = parse_args() - sys.path.append(os.path.join(os.getcwd(), 'symbol')) - net = importlib.import_module("symbol_" + args.network) \ - .get_symbol(args.num_classes, args.nms_thresh, args.force_nms) + net = get_symbol(args.network).get_symbol(args.network, args.data_shape, + num_classes=args.num_classes, nms_thresh=args.nms_thresh, + force_suppress=args.force_nms, nms_topk=args.nms_topk) _, arg_params, aux_params = mx.model.load_checkpoint(args.prefix, args.epoch) # new name tmp = args.prefix.rsplit('/', 1) diff --git a/evaluate.py b/evaluate.py index a38a7f6..3c9ff44 100644 --- a/evaluate.py +++ b/evaluate.py @@ -5,12 +5,6 @@ import sys from evaluate.evaluate_net import evaluate_net -CLASSES = ('aeroplane', 'bicycle', 'bird', 'boat', - 'bottle', 'bus', 'car', 'cat', 'chair', - 'cow', 'diningtable', 'dog', 'horse', - 'motorbike', 'person', 'pottedplant', - 'sheep', 'sofa', 'train', 'tvmonitor') - def parse_args(): parser = argparse.ArgumentParser(description='Evaluate a network') parser.add_argument('--rec-path', dest='rec_path', help='which record file to use', @@ -23,7 +17,10 @@ def parse_args(): help='evaluation batch size') parser.add_argument('--num-class', dest='num_class', type=int, default=20, help='number of classes') - parser.add_argument('--class-names', dest='class_names', type=str, default=",".join(CLASSES), + parser.add_argument('--class-names', dest='class_names', type=str, + default='aeroplane, bicycle, bird, boat, bottle, bus, \ + car, cat, chair, cow, diningtable, dog, horse, motorbike, \ + person, pottedplant, sheep, sofa, train, tvmonitor', help='string of comma separated names, or text filename') parser.add_argument('--epoch', dest='epoch', help='epoch of pretrained model', default=0, type=int) diff --git a/symbol/symbol_factory.py b/symbol/symbol_factory.py index 462fe86..37030c8 100644 --- a/symbol/symbol_factory.py +++ b/symbol/symbol_factory.py @@ -113,7 +113,7 @@ def get_symbol(network, data_shape, **kwargs): kwargs : dict see symbol_builder.get_symbol for more details """ - if network.stargswith('legacy'): + if network.startswith('legacy'): return symbol_builder.import_module(network).get_symbol(**kwargs) config = get_config(network, data_shape, **kwargs).copy() config.update(kwargs) diff --git a/train.py b/train.py index fcd5fb9..62ad2c4 100644 --- a/train.py +++ b/train.py @@ -15,8 +15,8 @@ def parse_args(): default=os.path.join(os.getcwd(), 'data', 'val.rec'), type=str) parser.add_argument('--val-list', dest='val_list', help='validation list to use', default="", type=str) - parser.add_argument('--network', dest='network', type=str, default='vgg16_ssd_300', - choices=['vgg16_ssd_300', 'vgg16_ssd_512'], help='which network to use') + parser.add_argument('--network', dest='network', type=str, default='vgg16_reduced', + help='which network to use') parser.add_argument('--batch-size', dest='batch_size', type=int, default=32, help='training batch size') parser.add_argument('--resume', dest='resume', type=int, default=-1, @@ -41,7 +41,7 @@ def parse_args(): help='set image shape') parser.add_argument('--label-width', dest='label_width', type=int, default=350, help='force padding label width to sync across train and validation') - parser.add_argument('--lr', dest='learning_rate', type=float, default=0.004, + parser.add_argument('--lr', dest='learning_rate', type=float, default=0.002, help='learning rate') parser.add_argument('--momentum', dest='momentum', type=float, default=0.9, help='momentum') @@ -53,7 +53,7 @@ def parse_args(): help='green mean value') parser.add_argument('--mean-b', dest='mean_b', type=float, default=104, help='blue mean value') - parser.add_argument('--lr-steps', dest='lr_refactor_step', type=str, default='150, 200', + parser.add_argument('--lr-steps', dest='lr_refactor_step', type=str, default='80, 160', help='refactor learning rate at specified epochs') parser.add_argument('--lr-factor', dest='lr_refactor_ratio', type=str, default=0.1, help='ratio to refactor learning rate') @@ -92,9 +92,9 @@ def parse_class_names(args): num_class = args.num_class if len(args.class_names) > 0: if os.path.isfile(args.class_names): - # try to open it to read class names - with open(args.class_names, 'r') as f: - class_names = [l.strip() for l in f.readlines()] + # try to open it to read class names + with open(args.class_names, 'r') as f: + class_names = [l.strip() for l in f.readlines()] else: class_names = [c.strip() for c in args.class_names.split(',')] assert len(class_names) == num_class, str(len(class_names))