Skip to content

Commit

Permalink
Write results to file
Browse files Browse the repository at this point in the history
  • Loading branch information
bill-lotter committed Jul 11, 2016
1 parent 4db031d commit 9bba40c
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions kitti_evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,18 +49,20 @@
X_test = np.transpose(X_test, (1, 2, 0))
X_hat = np.transpose(X_hat, (1, 2, 0))

# Compare MSE of PredNet predictions vs. using last frame.
# Compare MSE of PredNet predictions vs. using last frame. Write results to prediction_scores.txt
mse_model = np.mean( (X_test[:, 1:] - X_hat[:, 1:])**2 ) # look at all timesteps except the first
mse_prev = np.mean( (X_test[:, :-1] - X_test[:, 1:])**2 )
print "Model MSE: %f" % mse_model
print "Previous Frame MSE: %f" % mse_prev
if not os.path.exists(results_save_dir): os.mkdir(results_save_dir)
f = open(results_save_dir + 'prediction_scores.txt', 'w')
f.write("Model MSE: %f\n" % mse_model)
f.write("Previous Frame MSE: %f" % mse_prev)
f.close()

# Plot some predictions
aspect_ratio = float(X_hat.shape[3]) / X_hat.shape[4]
plt.figure(figsize = (nt, 2*aspect_ratio))
gs = gridspec.GridSpec(2, nt)
gs.update(wspace=0.025, hspace=0.05)
if not os.path.exists(results_save_dir): os.mkdir(results_save_dir)
plot_save_dir = os.path.join(results_save_dir, 'prediction_plots/')
if not os.path.exists(plot_save_dir): os.mkdir(plot_save_dir)
for i in range(n_plot):
Expand Down

0 comments on commit 9bba40c

Please sign in to comment.