Skip to content

Commit

Permalink
Fixed typos
Browse files Browse the repository at this point in the history
  • Loading branch information
bill-lotter committed Jul 11, 2016
1 parent 6c8fea0 commit 70f124c
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def preprocess(self, X):
return X.astype(np.float32) / 255

def create_all(self):
X_all = np.zeros((self.N_sequences, nt) + self.im_shape, np.float32)
X_all = np.zeros((self.N_sequences, self.nt) + self.im_shape, np.float32)
for i, idx in enumerate(self.possible_starts):
X_all[i] = self.preprocess(self.X[idx:idx+self.nt])
return X_all
6 changes: 3 additions & 3 deletions kitti_evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@
X_test = test_generator.create_all()
X_hat = test_model.predict(X_test, batch_size)
if K.image_dim_ordering() == 'th':
X_test = np.transpose(X_test, (1, 2, 0))
X_hat = np.transpose(X_hat, (1, 2, 0))
X_test = np.transpose(X_test, (0, 1, 3, 4, 2))
X_hat = np.transpose(X_hat, (0, 1, 3, 4, 2))

# Compare MSE of PredNet predictions vs. using last frame. Write results to prediction_scores.txt
mse_model = np.mean( (X_test[:, 1:] - X_hat[:, 1:])**2 ) # look at all timesteps except the first
Expand All @@ -73,7 +73,7 @@
if t==0: plt.ylabel('Actual')

plt.subplot(gs[t + nt])
plt.imshow(X_hat[i,t], (1, 2, 0)), interpolation='none')
plt.imshow(X_hat[i,t], interpolation='none')
plt.tick_params(axis='both', which='both', bottom='off', top='off', left='off', right='off', labelbottom='off', labelleft='off')
if t==0: plt.ylabel('Predicted')

Expand Down

0 comments on commit 70f124c

Please sign in to comment.