Skip to content

Commit

Permalink
Cleaned up imports
Browse files Browse the repository at this point in the history
  • Loading branch information
bill-lotter committed Jul 9, 2016
1 parent 38d7d9b commit 97a284e
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 12 deletions.
5 changes: 3 additions & 2 deletions kitti_evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,13 @@
import matplotlib.gridspec as gridspec

from keras import backend as K
from kers.engine.training import Model
from kers.engine import Model
from keras.layers import Input, Dense, Flatten

from prednet import PredNet
from process_kitti import SequenceGenerator
import kitti_settings
from kitti_settings import *


n_plot = 20
batch_size = 10
Expand Down
17 changes: 9 additions & 8 deletions kitti_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,17 @@
from six.moves import cPickle

from keras import backend as K
from keras.engine.training import Model
from keras.engine import Model
from keras.layers import Input, Dense, Flatten
from keras.layers.recurrent import LSTM
from keras.layers.wrappers import TimeDistributed
from keras.layers import LSTM
from keras.layers import TimeDistributed
from keras.callbacks import LearningRateScheduler, ModelCheckpoint
from keras.optimizers import Adam

from prednet import PredNet
from process_kitti import SequenceGenerator
import kitti_settings
from kitti_settings import *


save_model = True # if weights will be saved
weights_file = os.path.join(weights_dir, 'prednet_kitti_weights.hdf5') # where weights will be saved
Expand All @@ -29,9 +30,9 @@
val_sources = os.path.join(data_dir, 'sources_val.hkl')

# Training parameters
nb_epoch = 150
nb_epoch = 2 #150
batch_size = 5
samples_per_epoch = 500
samples_per_epoch = 10 #500
N_seq_val = 100 # number of sequences to use for validation

# Model parameters
Expand Down Expand Up @@ -81,8 +82,8 @@
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
plt.plot(history['loss'])
plt.plot(history['val_loss'])
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.savefig('error_curve.png')
3 changes: 1 addition & 2 deletions process_kitti.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,9 @@
import numpy as np
from scipy.misc import imread, imresize
import hickle as hkl

from keras import backend as K
from keras.preprocessing.image import Iterator
import kitti_settings
from kitti_settings import *


desired_im_sz = (128, 160)
Expand Down

0 comments on commit 97a284e

Please sign in to comment.