From 2afa0f13464b904321b054808c1bc294eba016ae Mon Sep 17 00:00:00 2001 From: IdeaKing Date: Mon, 21 Mar 2022 16:33:29 -0400 Subject: [PATCH] inference debugs --- inference.py | 17 +++++++++++------ requirements.txt | 1 + src/dataset.py | 1 - 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/inference.py b/inference.py index ddba438..70affdc 100644 --- a/inference.py +++ b/inference.py @@ -1,7 +1,6 @@ import os import argparse import tensorflow as tf -import matplotlib.pyplot as plt from src.utils.postprocess import FilterDetections from src.utils.visualize import draw_boxes @@ -50,11 +49,11 @@ def test(image_path, image_dir, save_dir, model, prog="i-Sight") parser.add_argument("--testing-image-dir", type=str, - default="data/dataset/VOC2012/TestImages", + default="datasets/data/VOC2012/TestImages", help="Path to testing images directory.") parser.add_argument("--save-image-dir", type=str, - default="data/dataset/Tests", + default="datasets/data/Tests", help="Path to testing images directory.") parser.add_argument("--model-dir", type=str, @@ -66,7 +65,7 @@ def test(image_path, image_dir, save_dir, model, help="Size of the input image.") parser.add_argument("--labels-file", type=str, - default="data/dataset/VOC2012/labels.txt", + default="datasets/data/VOC2012/labels.txt", help="Path to labels file.") parser.add_argument("--score-threshold", type=float, @@ -79,14 +78,20 @@ def test(image_path, image_dir, save_dir, model, args=parser.parse_args() label_dict = parse_label_file( - path_to_label_file=args.labels_files) + path_to_label_file=args.labels_file) model = tf.keras.models.load_model(args.model_dir) + if os.path.exists(args.save_image_dir) == False: + os.mkdir(args.save_image_dir) + for image_path in os.listdir(args.testing_image_dir): # Test the model on the image test(image_path=image_path, + image_dir=args.testing_image_dir, + save_dir=args.save_image_dir, model=model, image_dims=args.image_dims, label_dict=label_dict, - score_threshold=args.score_threshold) + score_threshold=args.score_threshold, + iou_threshold=args.iou_threshold) diff --git a/requirements.txt b/requirements.txt index e2008fa..93db403 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,5 +3,6 @@ tensorflow==2.7.0 opencv-python albumentations matplotlib +Pillow # Web development backends \ No newline at end of file diff --git a/src/dataset.py b/src/dataset.py index 8d0e756..383b273 100644 --- a/src/dataset.py +++ b/src/dataset.py @@ -126,7 +126,6 @@ def parse_process_voc(self, file_name): # Reads a voc annotation and returns # a list of tuples containing the ground # truth boxes and its respective label - root = ET.parse(file_name).getroot() image_size = (int(root.findtext("size/width")), int(root.findtext("size/height")))