From 18fb33bbf9b98f92e35cefe2fdcb364fadd82693 Mon Sep 17 00:00:00 2001 From: Thomas Roddick Date: Wed, 7 Aug 2019 17:16:42 +0100 Subject: [PATCH] Fix tensorboardX visualization errors Fixed a couple of errors caused by an out of date version of TensorboardX. --- train.py | 40 +++++----------------------------------- 1 file changed, 5 insertions(+), 35 deletions(-) diff --git a/train.py b/train.py index 3dbb2a8..3cedf5e 100644 --- a/train.py +++ b/train.py @@ -74,12 +74,8 @@ def train(args, dataloader, model, encoder, optimizer, summary, epoch): summary.add_image('train/image', visualize_image(image), epoch) # Visualize scores - summary.add_image('train/score', + summary.add_figure('train/score', visualize_score(pred_encoded, gt_encoded, grid), epoch) - - # Visualize uncertainty - # summary.add_image('train/uncertainty', - # visualize_uncertainty(pred_encoded[0], objects, grid), epoch) # TODO decode and save results @@ -124,12 +120,8 @@ def validate(args, dataloader, model, encoder, summary, epoch): summary.add_image('val/image', visualize_image(image), epoch) # Visualize scores - summary.add_image('val/score', + summary.add_figure('val/score', visualize_score(pred_encoded, gt_encoded, grid), epoch) - - # Visualize uncertainty - # summary.add_image('val/uncertainty', - # visualize_uncertainty(pred_encoded[0], objects, grid), epoch) # TODO decode and save results @@ -172,10 +164,7 @@ def compute_loss(pred_encoded, gt_encoded, loss_weights=[1., 1., 1., 1.]): def visualize_image(image): - fig = plt.figure('image', figsize=(8, 4)) - plt.imshow(image[0].cpu().detach().permute(1, 2, 0).numpy()) - plt.axis('off') - return oft.convert_figure(fig) + return image[0].cpu().detach() def visualize_score(preds, targets, grid): @@ -183,9 +172,6 @@ def visualize_score(preds, targets, grid): score, pos_offsets, dim_offsets, ang_offsets = preds labels, sqr_dists, gt_pos_offsets, gt_dim_offsets, gt_ang_offsets = targets - # TODO Compute - # score = oft.model.loss.compute_uncertainty(score, sqr_dists, 0.25) - # Visualize score fig_score = plt.figure(num='score', figsize=(8, 6)) fig_score.clear() @@ -193,23 +179,7 @@ def visualize_score(preds, targets, grid): oft.vis_score(score[0, 0].sigmoid(), grid[0], ax=plt.subplot(121)) oft.vis_score(labels[0, 0].float(), grid[0], ax=plt.subplot(122)) - return oft.convert_figure(fig_score) - - -def visualize_uncertainty(logvar, objects, grid): - - # Visualize uncertainty - fig_uncert = plt.figure(num='uncertainty', figsize=(6, 6)) - fig_uncert.clear() - ax = fig_uncert.gca() - ax = oft.vis_uncertainty(logvar[0, 0], objects[0], grid[0], ax=ax) - - # Setup colorbar - def cbar_fmt(x, pos): - return '{:.2f}'.format(math.exp(x)) - cbar = plt.colorbar(ax.collections[0], format=FuncFormatter(cbar_fmt)) - - return oft.convert_figure(fig_uncert) + return fig_score def parse_args(): @@ -262,7 +232,7 @@ def parse_args(): # Training options parser.add_argument('-e', '--epochs', type=int, default=600, help='number of epochs to train for') - parser.add_argument('-b', '--batch_size', type=int, default=8, + parser.add_argument('-b', '--batch-size', type=int, default=1, help='mini-batch size for training') # Experiment options