From a2a29a41a7246c01d59c9e023eee2cc6e9852765 Mon Sep 17 00:00:00 2001 From: Bill Lotter Date: Sat, 9 Jul 2016 18:16:13 -0400 Subject: [PATCH] Moved paths to kitti_settings --- kitti_evaluate.py | 22 ++++++++++++---------- kitti_settings.py | 10 ++++++++++ kitti_train.py | 45 ++++++++++++++++++++++----------------------- process_kitti.py | 29 ++++++++++++----------------- 4 files changed, 56 insertions(+), 50 deletions(-) create mode 100644 kitti_settings.py diff --git a/kitti_evaluate.py b/kitti_evaluate.py index b7389fc..72c1d5a 100644 --- a/kitti_evaluate.py +++ b/kitti_evaluate.py @@ -3,8 +3,11 @@ Calculates mean-squared error and plots predictions. ''' +import os import numpy as np from six.moves import cPickle +import matplotlib +matplotlib.use('Agg') import matplotlib.pyplot as plt import matplotlib.gridspec as gridspec @@ -14,18 +17,17 @@ from prednet import PredNet from process_kitti import SequenceGenerator +import kitti_settings -plot_save_dir = './kitti_plots/' n_plot = 20 - -weights_file = 'prednet_weights.hdf5' -config_file = 'prednet_config.pkl' -test_file = './kitti_data/X_test.hkl' -test_sources = './kitti_data/sources_test.hkl' - -batch_size = 5 +batch_size = 10 nt = 10 +weights_file = os.path.join(weights_dir, 'prednet_kitti_weights.hdf5') +config_file = os.path.join(weights_dir, 'prednet_kitti_config.pkl') +test_file = os.path.join(data_dir, 'X_test.hkl') +test_sources = os.path.join(data_dir, 'sources_test.hkl') + # Load trained model config = cPickle.load(open(config_file)) train_model = Model.from_config(config, custom_objects = {'PredNet': PredNet}) @@ -54,7 +56,7 @@ plt.figure(figsize = (nt, 2*aspect_ratio)) gs = gridspec.GridSpec(2, nt) gs.update(wspace=0.025, hspace=0.05) -if not os.path.exists(plot_save_dir): os.mkdir(plot_save_dir) +if not os.path.exists(eval_save_dir): os.mkdir(eval_save_dir) for i in range(n_plot): for t in range(nt): plt.subplot(gs[t]) @@ -67,5 +69,5 @@ plt.tick_params(axis='both', which='both', bottom='off', top='off', left='off', right='off', labelbottom='off', labelleft='off') if t==0: plt.ylabel('Predicted') - plt.savefig(plot_save_dir + 'plot_' + str(i) + '.png') + plt.savefig(os.path.join(eval_save_dir, 'plot_' + str(i) + '.png')) plt.clf() diff --git a/kitti_settings.py b/kitti_settings.py new file mode 100644 index 0000000..83b4f6f --- /dev/null +++ b/kitti_settings.py @@ -0,0 +1,10 @@ +# Where KITTI data will be saved if you run process_kitti.py +# If you directly download the processed data, change to the path of the data. +data_dir = './kitti_data/' + +# Where model weights and config will be saved if you run kitti_train.py +# If you directly download the trained weights, change to appropriate path. +weights_dir = './' + +# Where evaluation results (prediction plots) will be saved. +eval_save_dir = './' diff --git a/kitti_train.py b/kitti_train.py index a068e41..3064bfc 100644 --- a/kitti_train.py +++ b/kitti_train.py @@ -2,6 +2,7 @@ Train PredNet on KITTI sequences. (Geiger et al. 2013, http://www.cvlibs.net/datasets/kitti/) ''' +import os import numpy as np from six.moves import cPickle @@ -11,25 +12,27 @@ from keras.layers.recurrent import LSTM from keras.layers.wrappers import TimeDistributed from keras.callbacks import EarlyStopping, ModelCheckpoint +from keras.optimizers import Adam from prednet import PredNet from process_kitti import SequenceGenerator +import kitti_settings + +save_model = True # if weights will be saved +weights_file = os.path.join(weights_dir, 'prednet_kitti_weights.hdf5') # where weights will be saved +config_file = os.path.join(weights_dir, 'prednet_kitti_config.pkl') + +# Data files +train_file = os.path.join(data_dir, 'X_train.hkl') +train_sources = os.path.join(data_dir, 'sources_train.hkl') +val_file = os.path.join(data_dir, 'X_val.hkl') +val_sources = os.path.join(data_dir, 'sources_val.hkl') # Training parameters -nb_epoch = 2 +nb_epoch = 100 batch_size = 5 -samples_per_epoch = 100 #500 +samples_per_epoch = 500 N_seq_val = 100 # number of sequences to use for validation -use_early_stopping = True -patience = 5 -save_model = True -save_name = 'prednet_' - -# Data files -train_file = './kitti_data/X_train.hkl' -train_sources = './kitti_data/sources_train.hkl' -val_file = './kitti_data/X_val.hkl' -val_sources = './kitti_data/sources_val.hkl' # Model parameters nt = 10 @@ -55,23 +58,19 @@ errors_by_time = Flatten()(errors_by_time) # will be (batch_size, nt) final_errors = Dense(1, weights=[time_loss_weights, np.zeros(1)], trainable=False)(errors_by_time) # weight errors by time model = Model(input=inputs, output=final_errors) -model.compile(loss='mean_absolute_error', optimizer='adam') +optimizer = Adam(lr=0.0005) +model.compile(loss='mean_absolute_error', optimizer=optimizer) train_generator = SequenceGenerator(train_file, train_sources, nt, batch_size=batch_size) val_generator = SequenceGenerator(val_file, val_sources, nt, batch_size=batch_size, N_seq=N_seq_val) callbacks = [] -if use_early_stopping: - callbacks.append(EarlyStopping(monitor='val_loss', patience=patience)) if save_model: - callbacks.append(ModelCheckpoint(filepath=save_name + 'weights.hdf5', monitor='val_loss', save_best_only=True)) + if not os.path.exists(weights_dir): os.mkdir(weights_dir) + callbacks.append(ModelCheckpoint(filepath=weights_file, monitor='val_loss', save_best_only=True)) -history = model.fit_generator(train_generator, samples_per_epoch, nb_epoch, callbacks=callbacks, - validation_data=val_generator, nb_val_samples=val_generator.N_sequences/batch_size) +model.fit_generator(train_generator, samples_per_epoch, nb_epoch, callbacks=callbacks, + validation_data=val_generator, nb_val_samples=N_seq_val) if save_model: config = model.get_config() - cPickle.dump(config, open(save_name + '_config.pkl', 'w')) - -#TODO: remove -from helpers import plot_training_error -plot_training_error(train_err = history.history['loss'], val_err = history.history['val_loss'], run_name = 'kitti', out_file = 'error_plot.jpg') + cPickle.dump(config, open(config_file, 'w')) diff --git a/process_kitti.py b/process_kitti.py index 14cb68a..3ef9028 100644 --- a/process_kitti.py +++ b/process_kitti.py @@ -12,22 +12,20 @@ from keras import backend as K from keras.preprocessing.image import Iterator +import kitti_settings -np.random.seed(123) -data_dir = './kitti_data/' desired_im_sz = (128, 160) n_val_by_cat = {'city': 1} # number of recordings to use for validation out of each category n_test_by_cat = {'city': 1, 'residential': 1, 'road': 1} # number of recordings for testing - categories = ['city', 'residential', 'road'] -if data_dir[-1] != '/': data_dir += '/' +np.random.seed(123) if not os.path.exists(data_dir): os.mkdir(data_dir) # Download raw zip files by scraping KITTI website def download_data(): - base_dir = data_dir + 'raw/' + base_dir = os.path.join(data_dir, 'raw/') if not os.path.exists(base_dir): os.mkdir(base_dir) for c in categories: url = "http://www.cvlibs.net/datasets/kitti/raw_data.php?type=" + c @@ -47,7 +45,7 @@ def download_data(): # unzip images def extract_data(): for c in categories: - c_dir = data_dir + 'raw/' + c + '/' + c_dir = os.path.join(data_dir, 'raw/', c + '/') _, _, zip_files = os.walk(c_dir).next() for f in zip_files: print 'unpacking: ' + f @@ -61,7 +59,7 @@ def extract_data(): def process_data(): splits = {s: [] for s in ['train', 'test', 'val']} for c in categories: # Randomly assign recordings to training and testing. Cross-validation done across entire recordings. - c_dir = data_dir + 'raw/' + c + '/' + c_dir = os.path.join(data_dir, 'raw', c + '/') _, folders, _ = os.walk(c_dir).next() folders = np.random.permutation(folders) n_val = 0 if c not in n_val_by_cat else n_val_by_cat[c] @@ -74,7 +72,7 @@ def process_data(): im_list = [] source_list = [] # corresponds to recording that image came from for category, folder in splits[split]: - im_dir = data_dir + 'raw/' + category + '/' + folder + '/' + folder[:10] + '/' + folder + '/image_03/data/' + im_dir = os.path.join('raw/', category, folder, folder[:10], folder, '/image_03/data/') _, _, files = os.walk(im_dir).next() im_list += [im_dir + f for f in sorted(files)] source_list += [category + '-' + folder] * len(files) @@ -85,8 +83,8 @@ def process_data(): im = imread(im_file) X[i] = process_im(im, desired_im_sz) - hkl.dump(X, data_dir + 'X_' + split + '.hkl') - hkl.dump(source_list, data_dir + 'sources_' + split + '.hkl') + hkl.dump(X, os.path.join(data_dir, 'X_' + split + '.hkl')) + hkl.dump(source_list, os.path.join(data_dir, 'sources_' + split + '.hkl')) # resize and crop image @@ -144,9 +142,9 @@ def next(self): for i, idx in enumerate(index_array): idx = self.possible_starts[idx] batch_x[i] = self.preprocess(self.X[idx:idx+self.nt]) - if output_mode == 'error': # model outputs errors, so y should be zeros + if self.output_mode == 'error': # model outputs errors, so y should be zeros batch_y = np.zeros(current_batch_size, np.float32) - elif output_mode == 'prediction': # output actual pixels + elif self.output_mode == 'prediction': # output actual pixels batch_y = batch_x return batch_x, batch_y @@ -161,9 +159,6 @@ def create_all(self): if __name__ == '__main__': - import time - t0 = time.time() download_data() - print 'time to download: ' + str( (time.time() - t0)/60 ) - #extract_data() - #process_data() + extract_data() + process_data()