Skip to content

Commit

Permalink
Visualise boxes during training
Browse files Browse the repository at this point in the history
  • Loading branch information
tom-roddick committed Mar 31, 2021
1 parent 02ad217 commit cecedcf
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 11 deletions.
26 changes: 20 additions & 6 deletions oft/data/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,12 +167,24 @@ def decode(self, heatmaps, pos_offsets, dim_offsets, ang_offsets, grid):
self.classnames[cid], pos, dim, ang, score))

return objects


def decode_batch(self, heatmaps, pos_offsets, dim_offsets, ang_offsets,
grids):

boxes = list()
for hmap, pos_off, dim_off, ang_off, grid in zip(heatmaps, pos_offsets,
dim_offsets,
ang_offsets, grids):
boxes.append(self.decode(hmap, pos_off, dim_off, ang_off, grid))

return boxes

def _decode_heatmaps(self, heatmaps):
peaks = non_maximum_suppression(heatmaps, self.sigma)
scores = heatmaps[peaks]
classids = torch.nonzero(peaks)[:, 0]
return peaks, scores, classids
return peaks, scores.cpu(), classids.cpu()


def _decode_positions(self, pos_offsets, peaks, grid):
Expand All @@ -181,18 +193,20 @@ def _decode_positions(self, pos_offsets, peaks, grid):
centers = (grid[1:, 1:] + grid[:-1, :-1]) / 2.

# Un-normalize grid offsets
positions = pos_offsets.permute(0, 2, 3, 1) * self.pos_std + centers
return positions[peaks]
positions = pos_offsets.permute(0, 2, 3, 1) * self.pos_std.to(grid) \
+ centers
return positions[peaks].cpu()

def _decode_dimensions(self, dim_offsets, peaks):
dim_offsets = dim_offsets.permute(0, 2, 3, 1)
dimensions = torch.exp(
dim_offsets * self.log_dim_std + self.log_dim_mean)
return dimensions[peaks]
dim_offsets * self.log_dim_std.to(dim_offsets) \
+ self.log_dim_mean.to(dim_offsets))
return dimensions[peaks].cpu()

def _decode_angles(self, angle_offsets, peaks):
cos, sin = torch.unbind(angle_offsets, 1)
return torch.atan2(sin, cos)[peaks]
return torch.atan2(sin, cos)[peaks].cpu()



Expand Down
8 changes: 4 additions & 4 deletions oft/visualization/bbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def draw_bbox3d(obj, calib, ax, color='b'):
corners = utils.bbox_corners(obj)

# Project into image coordinates
img_corners = utils.perspective(calib, corners).numpy()
img_corners = utils.perspective(calib.cpu(), corners).numpy()

# Draw polygons
# Front face
Expand All @@ -57,7 +57,7 @@ def visualize_objects(image, calib, objects, cmap='tab20', ax=None):
ax.clear()

# Visualize image
ax.imshow(image.permute(1, 2, 0).numpy())
ax.imshow(image.permute(1, 2, 0).cpu().numpy())
extents = ax.axis()

# Visualize objects
Expand All @@ -67,8 +67,8 @@ def visualize_objects(image, calib, objects, cmap='tab20', ax=None):

# Format axis
ax.axis(extents)
ax.axis('off')
ax.grid('off')
ax.axis(False)
ax.grid(False)
return ax


Expand Down
32 changes: 31 additions & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,13 @@ def train(args, dataloader, model, encoder, optimizer, summary, epoch):
# Visualize scores
summary.add_figure('train/score',
visualize_score(pred_encoded[0], gt_encoded[0], grid), epoch)

# Decode predictions
preds = encoder.decode_batch(*pred_encoded, grid)

# Visualise bounding boxes
summary.add_figure('train/bboxes',
visualise_bboxes(image, calib, objects, preds), epoch)

# TODO decode and save results

Expand Down Expand Up @@ -113,6 +120,9 @@ def validate(args, dataloader, model, encoder, summary, epoch):
pred_encoded, gt_encoded, args.loss_weights)
epoch_loss += loss_dict

# Decode predictions
preds = encoder.decode_batch(*pred_encoded, grid)

# Visualize predictions
if i % args.vis_iter == 0:

Expand All @@ -123,7 +133,11 @@ def validate(args, dataloader, model, encoder, summary, epoch):
summary.add_figure('val/score',
visualize_score(pred_encoded[0], gt_encoded[0], grid), epoch)

# TODO decode and save results
# Visualise bounding boxes
summary.add_figure('val/bboxes',
visualise_bboxes(image, calib, objects, preds), epoch)



# TODO evaluate

Expand Down Expand Up @@ -175,6 +189,22 @@ def visualize_score(scores, heatmaps, grid):

return fig_score

def visualise_bboxes(image, calib, objects, preds):

fig = plt.figure(num='bbox', figsize=(8, 6))
fig.clear()
ax1 = plt.subplot(211)
ax2 = plt.subplot(212)

oft.visualize_objects(image[0], calib[0], preds[0], ax=ax1)
ax1.set_title('Predictions')

oft.visualize_objects(image[0], calib[0], objects[0], ax=ax2)
ax2.set_title('Ground truth')

return fig



def parse_args():
parser = ArgumentParser()
Expand Down

0 comments on commit cecedcf

Please sign in to comment.