From c7d3ced22a30216c87e28a39ef785be820258ddd Mon Sep 17 00:00:00 2001 From: Thomas Roddick Date: Thu, 29 Aug 2019 12:52:01 +0100 Subject: [PATCH] Added inference script --- infer.py | 96 +++++++++++++++++++++++++++++++++++ oft/visualization/__init__.py | 1 + oft/visualization/bbox.py | 51 ++++++++++++++++++- train.py | 2 +- 4 files changed, 148 insertions(+), 2 deletions(-) create mode 100644 infer.py diff --git a/infer.py b/infer.py new file mode 100644 index 0000000..ce29670 --- /dev/null +++ b/infer.py @@ -0,0 +1,96 @@ +import time +import torch +from torchvision.transforms.functional import to_tensor +from argparse import ArgumentParser +import matplotlib.pyplot as plt + +from oft import KittiObjectDataset, OftNet, ObjectEncoder, visualize_objects + +def parse_args(): + parser = ArgumentParser() + + parser.add_argument('model-path', type=str, + help='path to checkpoint file containing trained model') + parser.add_argument('-g', '--gpu', type=int, default=0, + help='gpu to use for inference (-1 for cpu)') + + # Data options + parser.add_argument('--root', type=str, default='data/kitti', + help='root directory of the KITTI dataset') + parser.add_argument('--grid-size', type=float, nargs=2, default=(80., 80.), + help='width and depth of validation grid, in meters') + parser.add_argument('--yoffset', type=float, default=1.74, + help='vertical offset of the grid from the camera axis') + parser.add_argument('--nms-thresh', type=float, default=0.2, + help='minimum score for a positive detection') + + # Model options + parser.add_argument('--grid-height', type=float, default=4., + help='size of grid cells, in meters') + parser.add_argument('-r', '--grid-res', type=float, default=0.5, + help='size of grid cells, in meters') + parser.add_argument('--frontend', type=str, default='resnet18', + choices=['resnet18', 'resnet34'], + help='name of frontend ResNet architecture') + parser.add_argument('--topdown', type=int, default=8, + help='number of residual blocks in topdown network') + + return parser.parse_args() + + +def main(): + + # Parse command line arguments + args = parse_args() + + # Load validation dataset to visualise + dataset = KittiObjectDataset( + args.root, 'val', args.grid_size, args.grid_res, args.yoffset) + + # Build model + model = OftNet(num_classes=1, frontend=args.frontend, + topdown_layers=args.topdown, grid_res=args.grid_res, + grid_height=args.grid_height) + if args.gpu >= 0: + torch.cuda.set_device(args.gpu) + model.cuda() + + # Load checkpoint + ckpt = torch.load(args.model_path) + model.load_state_dict(ckpt['model']) + + # Create encoder + encoder = ObjectEncoder(nms_thresh=args.nms_thresh) + + # Set up plots + _, (ax1, ax2) = plt.subplots(nrows=2) + plt.ion() + + # Iterate over validation images + for _, image, calib, objects, grid in dataset: + + # Move tensors to gpu + image = to_tensor(image) + if args.gpu >= 0: + image, calib, grid = image.cuda(), calib.cuda(), grid.cuda() + + # Run model forwards + pred_encoded = model(image[None], calib[None], grid[None]) + + # Decode predictions + pred_encoded = [t[0].cpu() for t in pred_encoded] + detections = encoder.decode(*pred_encoded, grid.cpu()) + + # Visualize predictions + visualize_objects(image, calib, detections, ax=ax1) + ax1.set_title('Detections') + visualize_objects(image, calib, objects, ax=ax2) + ax2.set_title('Ground truth') + + plt.draw() + plt.pause(0.01) + time.sleep(0.5) + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/oft/visualization/__init__.py b/oft/visualization/__init__.py index d3f1186..d494079 100644 --- a/oft/visualization/__init__.py +++ b/oft/visualization/__init__.py @@ -5,6 +5,7 @@ from matplotlib.cm import ScalarMappable from .encoded import vis_score, vis_uncertainty +from .bbox import visualize_objects # def vis_score(scores, labels, ax=None): # scores = scores.sigmoid().cpu().detach().numpy() diff --git a/oft/visualization/bbox.py b/oft/visualization/bbox.py index 75f2c72..eef47fa 100644 --- a/oft/visualization/bbox.py +++ b/oft/visualization/bbox.py @@ -1,7 +1,10 @@ import math import matplotlib.pyplot as plt -from matplotlib.patches import Rectangle, Circle +from matplotlib.patches import Rectangle, Circle, Polygon +from matplotlib.lines import Line2D +from matplotlib import cm from matplotlib import transforms +from .. import utils def draw_bbox2d(objects, color='k', ax=None): @@ -25,3 +28,49 @@ def draw_bbox2d(objects, color='k', ax=None): ax.axis(limits) return ax + + +def draw_bbox3d(obj, calib, ax, color='b'): + + # Get corners of 3D bounding box + corners = utils.bbox_corners(obj) + + # Project into image coordinates + img_corners = utils.perspective(calib, corners).numpy() + + # Draw polygons + # Front face + ax.add_patch(Polygon(img_corners[[1, 3, 7, 5]], ec=color, fill=False)) + # Back face + ax.add_patch(Polygon(img_corners[[0, 2, 6, 4]], ec=color, fill=False)) + ax.add_line(Line2D(*img_corners[[0, 1]].T, c=color)) # Lower left + ax.add_line(Line2D(*img_corners[[2, 3]].T, c=color)) # Lower right + ax.add_line(Line2D(*img_corners[[4, 5]].T, c=color)) # Upper left + ax.add_line(Line2D(*img_corners[[6, 7]].T, c=color)) # Upper right + + +def visualize_objects(image, calib, objects, cmap='tab20', ax=None): + + # Create a figure if it doesn't already exist + if ax is None: + fig, ax = plt.subplots() + ax.clear() + + # Visualize image + ax.imshow(image.permute(1, 2, 0).numpy()) + extents = ax.axis() + + # Visualize objects + cmap = cm.get_cmap(cmap, len(objects)) + for i, obj in enumerate(objects): + draw_bbox3d(obj, calib, ax, cmap(i)) + + # Format axis + ax.axis(extents) + ax.axis('off') + ax.grid('off') + return ax + + + + diff --git a/train.py b/train.py index c23630d..8e2c362 100644 --- a/train.py +++ b/train.py @@ -188,7 +188,7 @@ def parse_args(): help='width and depth of validation grid, in meters') parser.add_argument('--train-grid-size', type=int, nargs=2, default=(120, 120), - help='width and depth of training grid, in meters') + help='width and depth of training grid, in pixels') parser.add_argument('--grid-jitter', type=float, nargs=3, default=[.25, .5, .25], help='magn. of random noise applied to grid coords')