Commit cc76248: Made compatible with Keras 2.0

bill-lotter committed Aug 15, 2017 (1 parent: fbcdc18)

Showing 8 changed files with 117 additions and 59 deletions.
12 changes: 8 additions & 4 deletions README.md
@@ -6,10 +6,11 @@ The PredNet is a deep recurrent convolutional neural network that is inspired by
**Check out example prediction videos [here](https://coxlab.github.io/prednet/).**

The architecture is implemented as a custom layer<sup>1</sup> in [Keras](http://keras.io/).
-It is compatible with both [theano](http://deeplearning.net/software/theano/) and [tensorflow](https://www.tensorflow.org/) backends.
-Tested on Keras 1.2.1 with Theano 0.9.0, Tensorflow 0.12.1, and Python 2.7.
-See http://keras.io/ for instructions on installing Keras and its list of dependencies.
-For Torch implementation, see [torch-prednet](https://github.com/e-lab/torch-prednet).
+Code and model data are now compatible with Keras 2.0.
+Specifically, the code has been tested on Keras 2.0.6 with Theano 0.9.0, Tensorflow 1.2.1, and Python 2.7.
+The provided weights were trained with the Theano backend.
+For previous versions of the code compatible with Keras 1.2.1, use commit fbcdc18.
+To convert old PredNet model files and weights for Keras 2.0 compatibility, see ```convert_model_to_keras2``` in `keras_utils.py`.
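
For illustration, a minimal conversion call might look like the following sketch; the paths mirror the ```__main__``` block of `keras_utils.py` and are assumptions about where the old Keras 1.2 files live:

```python
import os
from keras_utils import convert_model_to_keras2

new_dir = './model_data_keras2/'
if not os.path.exists(new_dir):
    os.mkdir(new_dir)

# Convert the base t+1 model json/weights from Keras 1.2 to Keras 2.0 format.
convert_model_to_keras2('./model_data/prednet_kitti_model.json',
                        './model_data/prednet_kitti_weights.hdf5',
                        new_dir + 'prednet_kitti_model.json',
                        new_dir + 'prednet_kitti_weights.hdf5')
```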
<br>

## KITTI Demo
@@ -48,6 +49,9 @@ The model download will include the original weights trained for t+1 prediction,
### Feature Extraction
Extracting the intermediate features for a given layer in the PredNet can be done using the appropriate ```output_mode``` argument. For example, to extract the hidden state of the LSTM (the "Representation" units) in the lowest layer, use ```output_mode = 'R0'```. More details can be found in the PredNet docstring.
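
For illustration, a minimal sketch of extracting `R0` features, reusing the model-rebuild pattern from `kitti_evaluate.py` (the file paths are assumptions based on the Keras 2 model data from `download_models.sh`):

```python
from keras.layers import Input
from keras.models import Model, model_from_json
from prednet import PredNet

# Load the trained t+1 model.
with open('./model_data_keras2/prednet_kitti_model.json') as f:
    train_model = model_from_json(f.read(), custom_objects={'PredNet': PredNet})
train_model.load_weights('./model_data_keras2/prednet_kitti_weights.hdf5')

# Rebuild the PredNet layer so it outputs the lowest-layer representation units.
layer_config = train_model.layers[1].get_config()
layer_config['output_mode'] = 'R0'
feature_prednet = PredNet(weights=train_model.layers[1].get_weights(), **layer_config)

inputs = Input(shape=train_model.layers[0].batch_input_shape[1:])
feature_model = Model(inputs=inputs, outputs=feature_prednet(inputs))
# feature_model.predict(X) returns the R0 activations at each timestep.
```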

+### Multi-Step Prediction
+The PredNet argument ```extrap_start_time``` can be used to force multi-step prediction. Starting at this time step, the prediction from the previous time step will be treated as the actual input. For example, if the model is run on a sequence of 15 timesteps with ```extrap_start_time = 10```, the last output will correspond to a t+5 prediction. In the paper, we train in this setting starting from the original t+1 trained weights, and the resulting fine-tuned weights are included in `download_models.sh`.
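
For illustration, a minimal sketch of forcing extrapolation on a loaded model (assumes `train_model` was loaded as in the feature-extraction sketch above):

```python
layer_config = train_model.layers[1].get_config()
layer_config['output_mode'] = 'prediction'
layer_config['extrap_start_time'] = 10  # predictions are fed back as input from t = 10 on
extrap_prednet = PredNet(weights=train_model.layers[1].get_weights(), **layer_config)

input_shape = list(train_model.layers[0].batch_input_shape[1:])
input_shape[0] = 15  # 15-timestep sequences -> the last output is a t+5 prediction
inputs = Input(shape=tuple(input_shape))
extrap_model = Model(inputs=inputs, outputs=extrap_prednet(inputs))
```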

<br>

<sup>1</sup> Note on implementation: PredNet inherits from the Recurrent layer class, i.e. it has an internal state and a step function. Given the top-down then bottom-up update sequence, it must currently be implemented in Keras as essentially a 'super' layer where all layers in the PredNet are in one PredNet 'layer'. This is less than ideal, but it seems like the most efficient way as of now. We welcome suggestions if anyone thinks of a better implementation.
6 changes: 3 additions & 3 deletions data_utils.py
100644 → 100755
@@ -8,18 +8,18 @@ class SequenceGenerator(Iterator):
def __init__(self, data_file, source_file, nt,
batch_size=8, shuffle=False, seed=None,
output_mode='error', sequence_start_mode='all', N_seq=None,
-                 dim_ordering=K.image_dim_ordering()):
+                 data_format=K.image_data_format()):
self.X = hkl.load(data_file) # X will be like (n_images, nb_cols, nb_rows, nb_channels)
self.sources = hkl.load(source_file) # source for each image so when creating sequences can assure that consecutive frames are from same video
self.nt = nt
self.batch_size = batch_size
-        self.dim_ordering = dim_ordering
+        self.data_format = data_format
assert sequence_start_mode in {'all', 'unique'}, 'sequence_start_mode must be in {all, unique}'
self.sequence_start_mode = sequence_start_mode
assert output_mode in {'error', 'prediction'}, 'output_mode must be in {error, prediction}'
self.output_mode = output_mode

-        if self.dim_ordering == 'th':
+        if self.data_format == 'channels_first':
self.X = np.transpose(self.X, (0, 3, 1, 2))
self.im_shape = self.X[0].shape

10 changes: 4 additions & 6 deletions download_models.sh
@@ -1,7 +1,5 @@
savedir="model_data"
savedir="model_data_keras2"
mkdir -p -- "$savedir"
-wget https://www.dropbox.com/s/n6hllbbaeh3fpj9/prednet_kitti_model.zip?dl=0 -O $savedir/prednet_kitti_model.zip
-unzip $savedir/prednet_kitti_model.zip -d $savedir
-wget https://www.dropbox.com/s/zhcp20ixvufnma8/prednet_kitti_model-extrapfinetuned.zip?dl=0 -O $savedir/prednet_kitti_model-extrapfinetuned.zip
-unzip $savedir/prednet_kitti_model-extrapfinetuned.zip -d $savedir
-wget https://www.dropbox.com/s/e9048813j2fhdcw/prednet_kitti_weights-Lall.hdf5?dl=0 -O $savedir/prednet_kitti_weights-Lall.hdf5
+wget https://www.dropbox.com/s/z7ittwfxa5css7a/model_data_keras2.zip?dl=0 -O $savedir/model_data_keras2.zip
+unzip -j $savedir/model_data_keras2.zip -d $savedir
+rm $savedir/model_data_keras2.zip
58 changes: 58 additions & 0 deletions keras_utils.py
@@ -0,0 +1,58 @@
import os
import numpy as np

from keras import backend as K
from keras.legacy.interfaces import generate_legacy_interface, recurrent_args_preprocessor
from keras.models import model_from_json

legacy_prednet_support = generate_legacy_interface(
allowed_positional_args=['stack_sizes', 'R_stack_sizes',
'A_filt_sizes', 'Ahat_filt_sizes', 'R_filt_sizes'],
conversions=[('dim_ordering', 'data_format'),
('consume_less', 'implementation')],
value_conversions={'dim_ordering': {'tf': 'channels_last',
'th': 'channels_first',
'default': None},
'consume_less': {'cpu': 0,
'mem': 1,
'gpu': 2}},
preprocessor=recurrent_args_preprocessor)

# Convert old Keras (1.2) json models and weights to Keras 2.0
def convert_model_to_keras2(old_json_file, old_weights_file, new_json_file, new_weights_file):
from prednet import PredNet
    # Tensorflow won't load the old Theano-ordered weights, so force the theano backend.
    if K.backend() != 'theano':
        os.environ['KERAS_BACKEND'] = 'theano'
        reload(K)  # reload is a Python 2 builtin; re-imports keras.backend with the new setting

    with open(old_json_file, 'r') as f:
        json_string = f.read()
model = model_from_json(json_string, custom_objects = {'PredNet': PredNet})
model.load_weights(old_weights_file)

    weights = model.layers[1].get_weights()
    # Old Theano-ordered conv kernels are (output_channels, input_channels, rows, cols);
    # Keras 2 expects (rows, cols, input_channels, output_channels), so transpose 4D kernels.
    if weights[0].shape[0] == model.layers[1].stack_sizes[1]:
for i, w in enumerate(weights):
if w.ndim == 4:
weights[i] = np.transpose(w, (2, 3, 1, 0))
model.set_weights(weights)

model.save_weights(new_weights_file)
json_string = model.to_json()
with open(new_json_file, "w") as f:
f.write(json_string)


if __name__ == '__main__':
old_dir = './model_data/'
new_dir = './model_data_keras2/'
if not os.path.exists(new_dir):
os.mkdir(new_dir)
for w_tag in ['', '-Lall', '-extrapfinetuned']:
        m_tag = '' if w_tag == '-Lall' else w_tag  # the L_all weights share the base model json
convert_model_to_keras2(old_dir + 'prednet_kitti_model' + m_tag + '.json',
old_dir + 'prednet_kitti_weights' + w_tag + '.hdf5',
new_dir + 'prednet_kitti_model' + m_tag + '.json',
new_dir + 'prednet_kitti_weights' + w_tag + '.hdf5')
8 changes: 4 additions & 4 deletions kitti_evaluate.py
@@ -39,18 +39,18 @@
# Create testing model (to output predictions)
layer_config = train_model.layers[1].get_config()
layer_config['output_mode'] = 'prediction'
-dim_ordering = layer_config['dim_ordering']
+data_format = layer_config['data_format'] if 'data_format' in layer_config else layer_config['dim_ordering']
test_prednet = PredNet(weights=train_model.layers[1].get_weights(), **layer_config)
input_shape = list(train_model.layers[0].batch_input_shape[1:])
input_shape[0] = nt
inputs = Input(shape=tuple(input_shape))
predictions = test_prednet(inputs)
-test_model = Model(input=inputs, output=predictions)
+test_model = Model(inputs=inputs, outputs=predictions)

-test_generator = SequenceGenerator(test_file, test_sources, nt, sequence_start_mode='unique', dim_ordering=dim_ordering)
+test_generator = SequenceGenerator(test_file, test_sources, nt, sequence_start_mode='unique', data_format=data_format)
X_test = test_generator.create_all()
X_hat = test_model.predict(X_test, batch_size)
-if dim_ordering == 'th':
+if data_format == 'channels_first':
X_test = np.transpose(X_test, (0, 1, 3, 4, 2))
X_hat = np.transpose(X_hat, (0, 1, 3, 4, 2))

2 changes: 1 addition & 1 deletion kitti_settings.py
100644 → 100755
@@ -4,7 +4,7 @@

# Where model weights and config will be saved if you run kitti_train.py
# If you directly download the trained weights, change to appropriate path.
-WEIGHTS_DIR = './model_data/'
+WEIGHTS_DIR = './model_data_keras2/'

# Where results (prediction plots and evaluation file) will be saved.
RESULTS_SAVE_DIR = './kitti_results/'
14 changes: 7 additions & 7 deletions kitti_train.py
@@ -37,17 +37,17 @@
N_seq_val = 100 # number of sequences to use for validation

# Model parameters
-nt = 10
n_channels, im_height, im_width = (3, 128, 160)
-input_shape = (n_channels, im_height, im_width) if K.image_dim_ordering() == 'th' else (im_height, im_width, n_channels)
+input_shape = (n_channels, im_height, im_width) if K.image_data_format() == 'channels_first' else (im_height, im_width, n_channels)
stack_sizes = (n_channels, 48, 96, 192)
R_stack_sizes = stack_sizes
A_filt_sizes = (3, 3, 3)
Ahat_filt_sizes = (3, 3, 3, 3)
R_filt_sizes = (3, 3, 3, 3)
-layer_loss_weights = np.array([1., 0., 0., 0.])
+layer_loss_weights = np.array([1., 0., 0., 0.])  # weighting for each layer in final loss; "L_0" model: [1, 0, 0, 0], "L_all": [1, 0.1, 0.1, 0.1]
layer_loss_weights = np.expand_dims(layer_loss_weights, 1)
-time_loss_weights = 1./ (nt - 1) * np.ones((nt,1))
+nt = 10  # number of timesteps used for sequences in training
+time_loss_weights = 1./ (nt - 1) * np.ones((nt,1))  # equally weight all timesteps except the first
time_loss_weights[0] = 0


@@ -60,7 +60,7 @@
errors_by_time = TimeDistributed(Dense(1, weights=[layer_loss_weights, np.zeros(1)], trainable=False), trainable=False)(errors) # calculate weighted error by layer
errors_by_time = Flatten()(errors_by_time) # will be (batch_size, nt)
final_errors = Dense(1, weights=[time_loss_weights, np.zeros(1)], trainable=False)(errors_by_time) # weight errors by time
-model = Model(input=inputs, output=final_errors)
+model = Model(inputs=inputs, outputs=final_errors)
model.compile(loss='mean_absolute_error', optimizer='adam')

train_generator = SequenceGenerator(train_file, train_sources, nt, batch_size=batch_size, shuffle=True)
@@ -72,8 +72,8 @@
if not os.path.exists(WEIGHTS_DIR): os.mkdir(WEIGHTS_DIR)
callbacks.append(ModelCheckpoint(filepath=weights_file, monitor='val_loss', save_best_only=True))

-history = model.fit_generator(train_generator, samples_per_epoch, nb_epoch, callbacks=callbacks,
-                validation_data=val_generator, nb_val_samples=N_seq_val)
+history = model.fit_generator(train_generator, samples_per_epoch / batch_size, nb_epoch, callbacks=callbacks,
+                validation_data=val_generator, validation_steps=N_seq_val / batch_size)

if save_model:
json_string = model.to_json()