Skip to content

Commit

Permalink
Fix tensorboardX visualization errors
Browse files Browse the repository at this point in the history
Fixed a couple of errors caused by an out of date version of
TensorboardX.
  • Loading branch information
tom-roddick committed Aug 7, 2019
1 parent 4b6073a commit 18fb33b
Showing 1 changed file with 5 additions and 35 deletions.
40 changes: 5 additions & 35 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,12 +74,8 @@ def train(args, dataloader, model, encoder, optimizer, summary, epoch):
summary.add_image('train/image', visualize_image(image), epoch)

# Visualize scores
summary.add_image('train/score',
summary.add_figure('train/score',
visualize_score(pred_encoded, gt_encoded, grid), epoch)

# Visualize uncertainty
# summary.add_image('train/uncertainty',
# visualize_uncertainty(pred_encoded[0], objects, grid), epoch)

# TODO decode and save results

Expand Down Expand Up @@ -124,12 +120,8 @@ def validate(args, dataloader, model, encoder, summary, epoch):
summary.add_image('val/image', visualize_image(image), epoch)

# Visualize scores
summary.add_image('val/score',
summary.add_figure('val/score',
visualize_score(pred_encoded, gt_encoded, grid), epoch)

# Visualize uncertainty
# summary.add_image('val/uncertainty',
# visualize_uncertainty(pred_encoded[0], objects, grid), epoch)

# TODO decode and save results

Expand Down Expand Up @@ -172,44 +164,22 @@ def compute_loss(pred_encoded, gt_encoded, loss_weights=[1., 1., 1., 1.]):


def visualize_image(image):
fig = plt.figure('image', figsize=(8, 4))
plt.imshow(image[0].cpu().detach().permute(1, 2, 0).numpy())
plt.axis('off')
return oft.convert_figure(fig)
return image[0].cpu().detach()

def visualize_score(preds, targets, grid):

# Expand tuples
score, pos_offsets, dim_offsets, ang_offsets = preds
labels, sqr_dists, gt_pos_offsets, gt_dim_offsets, gt_ang_offsets = targets

# TODO Compute
# score = oft.model.loss.compute_uncertainty(score, sqr_dists, 0.25)

# Visualize score
fig_score = plt.figure(num='score', figsize=(8, 6))
fig_score.clear()
# oft.vis_score(score[0, 0], grid[0], ax=plt.subplot(121))
oft.vis_score(score[0, 0].sigmoid(), grid[0], ax=plt.subplot(121))
oft.vis_score(labels[0, 0].float(), grid[0], ax=plt.subplot(122))

return oft.convert_figure(fig_score)


def visualize_uncertainty(logvar, objects, grid):

# Visualize uncertainty
fig_uncert = plt.figure(num='uncertainty', figsize=(6, 6))
fig_uncert.clear()
ax = fig_uncert.gca()
ax = oft.vis_uncertainty(logvar[0, 0], objects[0], grid[0], ax=ax)

# Setup colorbar
def cbar_fmt(x, pos):
return '{:.2f}'.format(math.exp(x))
cbar = plt.colorbar(ax.collections[0], format=FuncFormatter(cbar_fmt))

return oft.convert_figure(fig_uncert)
return fig_score


def parse_args():
Expand Down Expand Up @@ -262,7 +232,7 @@ def parse_args():
# Training options
parser.add_argument('-e', '--epochs', type=int, default=600,
help='number of epochs to train for')
parser.add_argument('-b', '--batch_size', type=int, default=8,
parser.add_argument('-b', '--batch-size', type=int, default=1,
help='mini-batch size for training')

# Experiment options
Expand Down

0 comments on commit 18fb33b

Please sign in to comment.