Skip to content

Commit

Permalink
inference debugs
Browse files Browse the repository at this point in the history
  • Loading branch information
IdeaKing authored and IdeaKing committed Mar 21, 2022
1 parent 6a25d43 commit 2afa0f1
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 7 deletions.
17 changes: 11 additions & 6 deletions inference.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import os
import argparse
import tensorflow as tf
import matplotlib.pyplot as plt

from src.utils.postprocess import FilterDetections
from src.utils.visualize import draw_boxes
Expand Down Expand Up @@ -50,11 +49,11 @@ def test(image_path, image_dir, save_dir, model,
prog="i-Sight")
parser.add_argument("--testing-image-dir",
type=str,
default="data/dataset/VOC2012/TestImages",
default="datasets/data/VOC2012/TestImages",
help="Path to testing images directory.")
parser.add_argument("--save-image-dir",
type=str,
default="data/dataset/Tests",
default="datasets/data/Tests",
help="Path to testing images directory.")
parser.add_argument("--model-dir",
type=str,
Expand All @@ -66,7 +65,7 @@ def test(image_path, image_dir, save_dir, model,
help="Size of the input image.")
parser.add_argument("--labels-file",
type=str,
default="data/dataset/VOC2012/labels.txt",
default="datasets/data/VOC2012/labels.txt",
help="Path to labels file.")
parser.add_argument("--score-threshold",
type=float,
Expand All @@ -79,14 +78,20 @@ def test(image_path, image_dir, save_dir, model,
args=parser.parse_args()

label_dict = parse_label_file(
path_to_label_file=args.labels_files)
path_to_label_file=args.labels_file)

model = tf.keras.models.load_model(args.model_dir)

if os.path.exists(args.save_image_dir) == False:
os.mkdir(args.save_image_dir)

for image_path in os.listdir(args.testing_image_dir):
# Test the model on the image
test(image_path=image_path,
image_dir=args.testing_image_dir,
save_dir=args.save_image_dir,
model=model,
image_dims=args.image_dims,
label_dict=label_dict,
score_threshold=args.score_threshold)
score_threshold=args.score_threshold,
iou_threshold=args.iou_threshold)
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,6 @@ tensorflow==2.7.0
opencv-python
albumentations
matplotlib
Pillow

# Web development backends
1 change: 0 additions & 1 deletion src/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,6 @@ def parse_process_voc(self, file_name):
# Reads a voc annotation and returns
# a list of tuples containing the ground
# truth boxes and its respective label

root = ET.parse(file_name).getroot()
image_size = (int(root.findtext("size/width")),
int(root.findtext("size/height")))
Expand Down

0 comments on commit 2afa0f1

Please sign in to comment.