Skip to content

Commit

Permalink
add demo
Browse files Browse the repository at this point in the history
  • Loading branch information
bairdzhang committed Dec 6, 2018
1 parent 3ae1ccf commit 0eda5a6
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 2 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -124,3 +124,4 @@ submit.sh
*.eps
*.jpg
*.png
!demo/demo.jpg
4 changes: 4 additions & 0 deletions configs/default.toml
Original file line number Diff line number Diff line change
Expand Up @@ -135,3 +135,7 @@ SCALES = [100, 300, 600, 1000, 1400]
SCORE_THRESH = 0.002
GPU_ID = [0,1,2,3]
IOU_THRESH = 0.5

[TEST.DEMO]
ENABLE = false
IMAGE = "demo/demo.jpg"
Binary file added demo/demo.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
24 changes: 24 additions & 0 deletions lib/test.py
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -267,12 +267,36 @@ def inference_worker(rank,
return dets


def demo(target_test, thresh):
# Loading the network
cfg.GPU_ID = cfg.TEST.GPU_ID[0]
caffe.set_mode_gpu()
caffe.set_device(cfg.GPU_ID)
net = caffe.Net(str(target_test), str(cfg.TEST.MODEL), caffe.TEST)
pyramid = True if len(cfg.TEST.SCALES) > 1 else False
im_path = cfg.TEST.DEMO.IMAGE
dets, detect_time = detect(
net, im_path, thresh, timers=None, pyramid=pyramid)
im = cv2.imread(cfg.TEST.DEMO.IMAGE)
for i in range(dets[0].shape[0]):
if dets[0][i, -1] < thresh:
continue
cv2.rectangle(im, (int(dets[0][i, 0]), int(dets[0][i, 1])),
(int(dets[0][i, 2]), int(dets[0][i, 3])), (0, 255, 0), 2)
cv2.imwrite('/tmp/demo_res.jpg', im)
return None


def test_net(imdb,
output_dir,
target_test,
thresh=0.05,
no_cache=False,
step=0):
# Run demo
if imdb is None:
assert cfg.TEST.DEMO.ENABLE, "check your config and stderr!"
return demo(target_test, thresh)
# Initializing the timers
logger.info('Evaluating {} on {}'.format(cfg.NAME, imdb.name))

Expand Down
8 changes: 6 additions & 2 deletions train_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,12 @@ def parser():
if isinstance(cfg.TEST.GPU_ID, int):
cfg.TEST.GPU_ID = [cfg.TEST.GPU_ID]

imdb = get_imdb(cfg.TEST.DB)
output_dir = get_output_dir(imdb.name, cfg.NAME + '_' + cfg.LOG.TIME)
if not cfg.TEST.DEMO.ENABLE:
imdb = get_imdb(cfg.TEST.DB)
output_dir = get_output_dir(imdb.name, cfg.NAME + '_' + cfg.LOG.TIME)
else:
imdb = None
output_dir = get_output_dir("demo", cfg.NAME + '_' + cfg.LOG.TIME)

f = open(osp.join(output_dir, 'stderr.log'), 'w', 0)
os.dup2(f.fileno(), sys.stderr.fileno())
Expand Down

0 comments on commit 0eda5a6

Please sign in to comment.