diff --git a/oft/data/encoder.py b/oft/data/encoder.py index 7805d9e..25582b1 100644 --- a/oft/data/encoder.py +++ b/oft/data/encoder.py @@ -167,12 +167,24 @@ def decode(self, heatmaps, pos_offsets, dim_offsets, ang_offsets, grid): self.classnames[cid], pos, dim, ang, score)) return objects + + + def decode_batch(self, heatmaps, pos_offsets, dim_offsets, ang_offsets, + grids): + + boxes = list() + for hmap, pos_off, dim_off, ang_off, grid in zip(heatmaps, pos_offsets, + dim_offsets, + ang_offsets, grids): + boxes.append(self.decode(hmap, pos_off, dim_off, ang_off, grid)) + + return boxes def _decode_heatmaps(self, heatmaps): peaks = non_maximum_suppression(heatmaps, self.sigma) scores = heatmaps[peaks] classids = torch.nonzero(peaks)[:, 0] - return peaks, scores, classids + return peaks, scores.cpu(), classids.cpu() def _decode_positions(self, pos_offsets, peaks, grid): @@ -181,18 +193,20 @@ def _decode_positions(self, pos_offsets, peaks, grid): centers = (grid[1:, 1:] + grid[:-1, :-1]) / 2. 
# Un-normalize grid offsets - positions = pos_offsets.permute(0, 2, 3, 1) * self.pos_std + centers - return positions[peaks] + positions = pos_offsets.permute(0, 2, 3, 1) * self.pos_std.to(grid) \ + + centers + return positions[peaks].cpu() def _decode_dimensions(self, dim_offsets, peaks): dim_offsets = dim_offsets.permute(0, 2, 3, 1) dimensions = torch.exp( - dim_offsets * self.log_dim_std + self.log_dim_mean) - return dimensions[peaks] + dim_offsets * self.log_dim_std.to(dim_offsets) \ + + self.log_dim_mean.to(dim_offsets)) + return dimensions[peaks].cpu() def _decode_angles(self, angle_offsets, peaks): cos, sin = torch.unbind(angle_offsets, 1) - return torch.atan2(sin, cos)[peaks] + return torch.atan2(sin, cos)[peaks].cpu() diff --git a/oft/visualization/bbox.py b/oft/visualization/bbox.py index eef47fa..8640849 100644 --- a/oft/visualization/bbox.py +++ b/oft/visualization/bbox.py @@ -36,7 +36,7 @@ def draw_bbox3d(obj, calib, ax, color='b'): corners = utils.bbox_corners(obj) # Project into image coordinates - img_corners = utils.perspective(calib, corners).numpy() + img_corners = utils.perspective(calib.cpu(), corners).numpy() # Draw polygons # Front face @@ -57,7 +57,7 @@ def visualize_objects(image, calib, objects, cmap='tab20', ax=None): ax.clear() # Visualize image - ax.imshow(image.permute(1, 2, 0).numpy()) + ax.imshow(image.permute(1, 2, 0).cpu().numpy()) extents = ax.axis() # Visualize objects @@ -67,8 +67,8 @@ def visualize_objects(image, calib, objects, cmap='tab20', ax=None): # Format axis ax.axis(extents) - ax.axis('off') - ax.grid('off') + ax.axis(False) + ax.grid(False) return ax diff --git a/train.py b/train.py index a4663df..b70359e 100644 --- a/train.py +++ b/train.py @@ -76,6 +76,13 @@ def train(args, dataloader, model, encoder, optimizer, summary, epoch): # Visualize scores summary.add_figure('train/score', visualize_score(pred_encoded[0], gt_encoded[0], grid), epoch) + + # Decode predictions + preds = encoder.decode_batch(*pred_encoded, grid) 
+ + + # Visualize bounding boxes + summary.add_figure('train/bboxes', + visualize_bboxes(image, calib, objects, preds), epoch) # TODO decode and save results @@ -113,6 +120,9 @@ def validate(args, dataloader, model, encoder, summary, epoch): pred_encoded, gt_encoded, args.loss_weights) epoch_loss += loss_dict + # Decode predictions + preds = encoder.decode_batch(*pred_encoded, grid) + # Visualize predictions if i % args.vis_iter == 0: @@ -123,7 +133,11 @@ summary.add_figure('val/score', visualize_score(pred_encoded[0], gt_encoded[0], grid), epoch) - # TODO decode and save results + # Visualize bounding boxes + summary.add_figure('val/bboxes', + visualize_bboxes(image, calib, objects, preds), epoch) + + # TODO evaluate @@ -175,6 +189,22 @@ def visualize_score(scores, heatmaps, grid): return fig_score +def visualize_bboxes(image, calib, objects, preds): + + fig = plt.figure(num='bbox', figsize=(8, 6)) + fig.clear() + ax1 = plt.subplot(211) + ax2 = plt.subplot(212) + + oft.visualize_objects(image[0], calib[0], preds[0], ax=ax1) + ax1.set_title('Predictions') + + oft.visualize_objects(image[0], calib[0], objects[0], ax=ax2) + ax2.set_title('Ground truth') + + return fig + + def parse_args(): parser = ArgumentParser()