Moved paths to kitti_settings

bill-lotter committed Jul 9, 2016
1 parent 8f8106f commit a2a29a4
Showing 4 changed files with 56 additions and 50 deletions.
22 changes: 12 additions & 10 deletions kitti_evaluate.py
@@ -3,8 +3,11 @@
Calculates mean-squared error and plots predictions.
'''

+import os
import numpy as np
from six.moves import cPickle
+import matplotlib
+matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

@@ -14,18 +17,17 @@
from prednet import PredNet
from process_kitti import SequenceGenerator
+from kitti_settings import *

-plot_save_dir = './kitti_plots/'
n_plot = 20
-
-weights_file = 'prednet_weights.hdf5'
-config_file = 'prednet_config.pkl'
-test_file = './kitti_data/X_test.hkl'
-test_sources = './kitti_data/sources_test.hkl'
-
-batch_size = 5
+batch_size = 10
nt = 10

+weights_file = os.path.join(weights_dir, 'prednet_kitti_weights.hdf5')
+config_file = os.path.join(weights_dir, 'prednet_kitti_config.pkl')
+test_file = os.path.join(data_dir, 'X_test.hkl')
+test_sources = os.path.join(data_dir, 'sources_test.hkl')

# Load trained model
config = cPickle.load(open(config_file))
train_model = Model.from_config(config, custom_objects = {'PredNet': PredNet})
@@ -54,7 +56,7 @@
plt.figure(figsize = (nt, 2*aspect_ratio))
gs = gridspec.GridSpec(2, nt)
gs.update(wspace=0.025, hspace=0.05)
-if not os.path.exists(plot_save_dir): os.mkdir(plot_save_dir)
+if not os.path.exists(eval_save_dir): os.mkdir(eval_save_dir)
for i in range(n_plot):
    for t in range(nt):
        plt.subplot(gs[t])
@@ -67,5 +69,5 @@
        plt.tick_params(axis='both', which='both', bottom='off', top='off', left='off', right='off', labelbottom='off', labelleft='off')
        if t==0: plt.ylabel('Predicted')

-    plt.savefig(plot_save_dir + 'plot_' + str(i) + '.png')
+    plt.savefig(os.path.join(eval_save_dir, 'plot_' + str(i) + '.png'))
    plt.clf()
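A note on the new matplotlib lines: matplotlib.use('Agg') selects a non-interactive backend so the evaluation script can write PNGs on a display-less machine (e.g. a GPU server), and the backend has to be chosen before pyplot is first imported, which is why the call sits between the two matplotlib imports. A minimal sketch of the pattern:

    import matplotlib
    matplotlib.use('Agg')    # must precede the pyplot import
    import matplotlib.pyplot as plt

    plt.plot([0, 1], [0, 1])
    plt.savefig('demo.png')  # renders straight to disk; no display required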
10 changes: 10 additions & 0 deletions kitti_settings.py
@@ -0,0 +1,10 @@
+# Where KITTI data will be saved if you run process_kitti.py
+# If you directly download the processed data, change this to the path of the data.
+data_dir = './kitti_data/'
+
+# Where model weights and config will be saved if you run kitti_train.py
+# If you directly download the trained weights, change this to the appropriate path.
+weights_dir = './'
+
+# Where evaluation results (prediction plots) will be saved.
+eval_save_dir = './'
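Since the other scripts pick these names up via a star import, retargeting the whole pipeline only requires editing this one file. For example, on a machine where the processed data and trained weights were downloaded rather than generated (all three paths below are hypothetical):

    # kitti_settings.py, edited for pre-downloaded artifacts
    data_dir = '/data/kitti_hkl/'        # hypothetical location of the processed .hkl files
    weights_dir = '/models/prednet/'     # hypothetical location of weights + config
    eval_save_dir = './kitti_results/'   # prediction plots will be written here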
45 changes: 22 additions & 23 deletions kitti_train.py
@@ -2,6 +2,7 @@
Train PredNet on KITTI sequences. (Geiger et al. 2013, http://www.cvlibs.net/datasets/kitti/)
'''

+import os
import numpy as np
from six.moves import cPickle

@@ -11,25 +12,27 @@
from keras.layers.recurrent import LSTM
from keras.layers.wrappers import TimeDistributed
from keras.callbacks import EarlyStopping, ModelCheckpoint
+from keras.optimizers import Adam

from prednet import PredNet
from process_kitti import SequenceGenerator
+from kitti_settings import *

+save_model = True # if weights will be saved
+weights_file = os.path.join(weights_dir, 'prednet_kitti_weights.hdf5') # where weights will be saved
+config_file = os.path.join(weights_dir, 'prednet_kitti_config.pkl')
+
+# Data files
+train_file = os.path.join(data_dir, 'X_train.hkl')
+train_sources = os.path.join(data_dir, 'sources_train.hkl')
+val_file = os.path.join(data_dir, 'X_val.hkl')
+val_sources = os.path.join(data_dir, 'sources_val.hkl')
+
# Training parameters
-nb_epoch = 2
+nb_epoch = 100
batch_size = 5
-samples_per_epoch = 100 #500
+samples_per_epoch = 500
N_seq_val = 100 # number of sequences to use for validation
use_early_stopping = True
patience = 5
-save_model = True
-save_name = 'prednet_'
-
-# Data files
-train_file = './kitti_data/X_train.hkl'
-train_sources = './kitti_data/sources_train.hkl'
-val_file = './kitti_data/X_val.hkl'
-val_sources = './kitti_data/sources_val.hkl'

# Model parameters
nt = 10
@@ -55,23 +58,19 @@
errors_by_time = Flatten()(errors_by_time) # will be (batch_size, nt)
final_errors = Dense(1, weights=[time_loss_weights, np.zeros(1)], trainable=False)(errors_by_time) # weight errors by time
model = Model(input=inputs, output=final_errors)
-model.compile(loss='mean_absolute_error', optimizer='adam')
+optimizer = Adam(lr=0.0005)
+model.compile(loss='mean_absolute_error', optimizer=optimizer)

train_generator = SequenceGenerator(train_file, train_sources, nt, batch_size=batch_size)
val_generator = SequenceGenerator(val_file, val_sources, nt, batch_size=batch_size, N_seq=N_seq_val)
callbacks = []
if use_early_stopping:
    callbacks.append(EarlyStopping(monitor='val_loss', patience=patience))
if save_model:
-    callbacks.append(ModelCheckpoint(filepath=save_name + 'weights.hdf5', monitor='val_loss', save_best_only=True))
+    if not os.path.exists(weights_dir): os.mkdir(weights_dir)
+    callbacks.append(ModelCheckpoint(filepath=weights_file, monitor='val_loss', save_best_only=True))

-history = model.fit_generator(train_generator, samples_per_epoch, nb_epoch, callbacks=callbacks,
-                              validation_data=val_generator, nb_val_samples=val_generator.N_sequences/batch_size)
+model.fit_generator(train_generator, samples_per_epoch, nb_epoch, callbacks=callbacks,
+                    validation_data=val_generator, nb_val_samples=N_seq_val)

if save_model:
    config = model.get_config()
-    cPickle.dump(config, open(save_name + '_config.pkl', 'w'))
-
-#TODO: remove
-from helpers import plot_training_error
-plot_training_error(train_err = history.history['loss'], val_err = history.history['val_loss'], run_name = 'kitti', out_file = 'error_plot.jpg')
+    cPickle.dump(config, open(config_file, 'w'))
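A side note on the final_errors line kept as context above: the Dense layer is created with fixed weights and trainable=False, so it acts as a frozen weighted average that collapses the per-timestep errors into the scalar the model trains against. A minimal sketch of that pattern in the same Keras 1 style (the time_loss_weights values here are an assumed example that zeroes out t=0, not necessarily those used in the elided code):

    import numpy as np
    from keras.layers import Input, Dense
    from keras.models import Model

    nt = 10
    # assumed example: ignore t=0 (nothing has been predicted yet),
    # weight the remaining timesteps equally
    time_loss_weights = 1. / (nt - 1) * np.ones((nt, 1))
    time_loss_weights[0] = 0

    errors_by_time = Input(shape=(nt,))
    final_errors = Dense(1, weights=[time_loss_weights, np.zeros(1)], trainable=False)(errors_by_time)
    model = Model(input=errors_by_time, output=final_errors)  # Keras 1 API, as in the diff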
29 changes: 12 additions & 17 deletions process_kitti.py
@@ -12,22 +12,20 @@

from keras import backend as K
from keras.preprocessing.image import Iterator
+from kitti_settings import *

+np.random.seed(123)

-data_dir = './kitti_data/'
desired_im_sz = (128, 160)
n_val_by_cat = {'city': 1} # number of recordings to use for validation out of each category
n_test_by_cat = {'city': 1, 'residential': 1, 'road': 1} # number of recordings for testing

categories = ['city', 'residential', 'road']

-if data_dir[-1] != '/': data_dir += '/'
-np.random.seed(123)
if not os.path.exists(data_dir): os.mkdir(data_dir)

# Download raw zip files by scraping KITTI website
def download_data():
-    base_dir = data_dir + 'raw/'
+    base_dir = os.path.join(data_dir, 'raw/')
    if not os.path.exists(base_dir): os.mkdir(base_dir)
    for c in categories:
        url = "http://www.cvlibs.net/datasets/kitti/raw_data.php?type=" + c
@@ -47,7 +45,7 @@ def download_data():
# unzip images
def extract_data():
    for c in categories:
-        c_dir = data_dir + 'raw/' + c + '/'
+        c_dir = os.path.join(data_dir, 'raw/', c + '/')
        _, _, zip_files = os.walk(c_dir).next()
        for f in zip_files:
            print 'unpacking: ' + f
@@ -61,7 +59,7 @@ def extract_data():
def process_data():
    splits = {s: [] for s in ['train', 'test', 'val']}
    for c in categories:  # Randomly assign recordings to training and testing. Cross-validation done across entire recordings.
-        c_dir = data_dir + 'raw/' + c + '/'
+        c_dir = os.path.join(data_dir, 'raw', c + '/')
        _, folders, _ = os.walk(c_dir).next()
        folders = np.random.permutation(folders)
        n_val = 0 if c not in n_val_by_cat else n_val_by_cat[c]
@@ -74,7 +72,7 @@ def process_data():
        im_list = []
        source_list = []  # corresponds to recording that image came from
        for category, folder in splits[split]:
-            im_dir = data_dir + 'raw/' + category + '/' + folder + '/' + folder[:10] + '/' + folder + '/image_03/data/'
+            im_dir = os.path.join(data_dir, 'raw', category, folder, folder[:10], folder, 'image_03/data/')
            _, _, files = os.walk(im_dir).next()
            im_list += [im_dir + f for f in sorted(files)]
            source_list += [category + '-' + folder] * len(files)
@@ -85,8 +83,8 @@ def process_data():
            im = imread(im_file)
            X[i] = process_im(im, desired_im_sz)

-        hkl.dump(X, data_dir + 'X_' + split + '.hkl')
-        hkl.dump(source_list, data_dir + 'sources_' + split + '.hkl')
+        hkl.dump(X, os.path.join(data_dir, 'X_' + split + '.hkl'))
+        hkl.dump(source_list, os.path.join(data_dir, 'sources_' + split + '.hkl'))


# resize and crop image
@@ -144,9 +142,9 @@ def next(self):
        for i, idx in enumerate(index_array):
            idx = self.possible_starts[idx]
            batch_x[i] = self.preprocess(self.X[idx:idx+self.nt])
-        if output_mode == 'error':  # model outputs errors, so y should be zeros
+        if self.output_mode == 'error':  # model outputs errors, so y should be zeros
            batch_y = np.zeros(current_batch_size, np.float32)
-        elif output_mode == 'prediction':  # output actual pixels
+        elif self.output_mode == 'prediction':  # output actual pixels
            batch_y = batch_x
        return batch_x, batch_y

@@ -161,9 +159,6 @@ def create_all(self):


if __name__ == '__main__':
-    import time
-    t0 = time.time()
    download_data()
-    print 'time to download: ' + str( (time.time() - t0)/60 )
-    #extract_data()
-    #process_data()
+    extract_data()
+    process_data()
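One subtlety behind the os.path.join changes throughout this file: join treats any component that starts with a slash as absolute and silently discards everything before it, which is why the sub-path pieces handed to it must not carry leading slashes. A quick illustration (POSIX path semantics):

    import os
    os.path.join('kitti_data', 'image_03/data/')   # -> 'kitti_data/image_03/data/'
    os.path.join('kitti_data', '/image_03/data/')  # -> '/image_03/data/'  (prefix dropped!)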
