Skip to content

Commit

Permalink
Added inference script
Browse files Browse the repository at this point in the history
  • Loading branch information
tom-roddick committed Aug 29, 2019
1 parent 7692dd0 commit c7d3ced
Show file tree
Hide file tree
Showing 4 changed files with 148 additions and 2 deletions.
96 changes: 96 additions & 0 deletions infer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import time
import torch
from torchvision.transforms.functional import to_tensor
from argparse import ArgumentParser
import matplotlib.pyplot as plt

from oft import KittiObjectDataset, OftNet, ObjectEncoder, visualize_objects

def parse_args():
parser = ArgumentParser()

parser.add_argument('model-path', type=str,
help='path to checkpoint file containing trained model')
parser.add_argument('-g', '--gpu', type=int, default=0,
help='gpu to use for inference (-1 for cpu)')

# Data options
parser.add_argument('--root', type=str, default='data/kitti',
help='root directory of the KITTI dataset')
parser.add_argument('--grid-size', type=float, nargs=2, default=(80., 80.),
help='width and depth of validation grid, in meters')
parser.add_argument('--yoffset', type=float, default=1.74,
help='vertical offset of the grid from the camera axis')
parser.add_argument('--nms-thresh', type=float, default=0.2,
help='minimum score for a positive detection')

# Model options
parser.add_argument('--grid-height', type=float, default=4.,
help='size of grid cells, in meters')
parser.add_argument('-r', '--grid-res', type=float, default=0.5,
help='size of grid cells, in meters')
parser.add_argument('--frontend', type=str, default='resnet18',
choices=['resnet18', 'resnet34'],
help='name of frontend ResNet architecture')
parser.add_argument('--topdown', type=int, default=8,
help='number of residual blocks in topdown network')

return parser.parse_args()


def main():

# Parse command line arguments
args = parse_args()

# Load validation dataset to visualise
dataset = KittiObjectDataset(
args.root, 'val', args.grid_size, args.grid_res, args.yoffset)

# Build model
model = OftNet(num_classes=1, frontend=args.frontend,
topdown_layers=args.topdown, grid_res=args.grid_res,
grid_height=args.grid_height)
if args.gpu >= 0:
torch.cuda.set_device(args.gpu)
model.cuda()

# Load checkpoint
ckpt = torch.load(args.model_path)
model.load_state_dict(ckpt['model'])

# Create encoder
encoder = ObjectEncoder(nms_thresh=args.nms_thresh)

# Set up plots
_, (ax1, ax2) = plt.subplots(nrows=2)
plt.ion()

# Iterate over validation images
for _, image, calib, objects, grid in dataset:

# Move tensors to gpu
image = to_tensor(image)
if args.gpu >= 0:
image, calib, grid = image.cuda(), calib.cuda(), grid.cuda()

# Run model forwards
pred_encoded = model(image[None], calib[None], grid[None])

# Decode predictions
pred_encoded = [t[0].cpu() for t in pred_encoded]
detections = encoder.decode(*pred_encoded, grid.cpu())

# Visualize predictions
visualize_objects(image, calib, detections, ax=ax1)
ax1.set_title('Detections')
visualize_objects(image, calib, objects, ax=ax2)
ax2.set_title('Ground truth')

plt.draw()
plt.pause(0.01)
time.sleep(0.5)


if __name__ == '__main__':
main()
1 change: 1 addition & 0 deletions oft/visualization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from matplotlib.cm import ScalarMappable

from .encoded import vis_score, vis_uncertainty
from .bbox import visualize_objects

# def vis_score(scores, labels, ax=None):
# scores = scores.sigmoid().cpu().detach().numpy()
Expand Down
51 changes: 50 additions & 1 deletion oft/visualization/bbox.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import math
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle, Circle
from matplotlib.patches import Rectangle, Circle, Polygon
from matplotlib.lines import Line2D
from matplotlib import cm
from matplotlib import transforms
from .. import utils

def draw_bbox2d(objects, color='k', ax=None):

Expand All @@ -25,3 +28,49 @@ def draw_bbox2d(objects, color='k', ax=None):

ax.axis(limits)
return ax


def draw_bbox3d(obj, calib, ax, color='b'):

# Get corners of 3D bounding box
corners = utils.bbox_corners(obj)

# Project into image coordinates
img_corners = utils.perspective(calib, corners).numpy()

# Draw polygons
# Front face
ax.add_patch(Polygon(img_corners[[1, 3, 7, 5]], ec=color, fill=False))
# Back face
ax.add_patch(Polygon(img_corners[[0, 2, 6, 4]], ec=color, fill=False))
ax.add_line(Line2D(*img_corners[[0, 1]].T, c=color)) # Lower left
ax.add_line(Line2D(*img_corners[[2, 3]].T, c=color)) # Lower right
ax.add_line(Line2D(*img_corners[[4, 5]].T, c=color)) # Upper left
ax.add_line(Line2D(*img_corners[[6, 7]].T, c=color)) # Upper right


def visualize_objects(image, calib, objects, cmap='tab20', ax=None):

# Create a figure if it doesn't already exist
if ax is None:
fig, ax = plt.subplots()
ax.clear()

# Visualize image
ax.imshow(image.permute(1, 2, 0).numpy())
extents = ax.axis()

# Visualize objects
cmap = cm.get_cmap(cmap, len(objects))
for i, obj in enumerate(objects):
draw_bbox3d(obj, calib, ax, cmap(i))

# Format axis
ax.axis(extents)
ax.axis('off')
ax.grid('off')
return ax




2 changes: 1 addition & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def parse_args():
help='width and depth of validation grid, in meters')
parser.add_argument('--train-grid-size', type=int, nargs=2,
default=(120, 120),
help='width and depth of training grid, in meters')
help='width and depth of training grid, in pixels')
parser.add_argument('--grid-jitter', type=float, nargs=3,
default=[.25, .5, .25],
help='magn. of random noise applied to grid coords')
Expand Down

0 comments on commit c7d3ced

Please sign in to comment.