diff --git a/README.md b/README.md
index 8cbce27..ece99aa 100755
--- a/README.md
+++ b/README.md
@@ -6,10 +6,11 @@ The PredNet is a deep recurrent convolutional neural network that is inspired by
**Check out example prediction videos [here](https://coxlab.github.io/prednet/).**
The architecture is implemented as a custom layer1 in [Keras](http://keras.io/).
-It is compatible with both [theano](http://deeplearning.net/software/theano/) and [tensorflow](https://www.tensorflow.org/) backends.
-Tested on Keras 1.2.1 with Theano 0.9.0, Tensorflow 0.12.1, and Python 2.7.
-See http://keras.io/ for instructions on installing Keras and its list of dependencies.
-For Torch implementation, see [torch-prednet](https://github.com/e-lab/torch-prednet).
+Code and model data is now compatible with Keras 2.0.
+Specifically, it has been tested on Keras 2.0.6 with Theano 0.9.0, Tensorflow 1.2.1, and Python 2.7.
+The provided weights were trained with the Theano backend.
+For previous versions of the code compatible with Keras 1.2.1, use fbcdc18.
+To convert old PredNet model files and weights for Keras 2.0 compatibility, see ```convert_model_to_keras2``` in `keras_utils.py`.
## KITTI Demo
@@ -48,6 +49,9 @@ The model download will include the original weights trained for t+1 prediction,
### Feature Extraction
Extracting the intermediate features for a given layer in the PredNet can be done using the appropriate ```output_mode``` argument. For example, to extract the hidden state of the LSTM (the "Representation" units) in the lowest layer, use ```output_mode = 'R0'```. More details can be found in the PredNet docstring.
+### Multi-Step Prediction
+The PredNet argument ```extrap_start_time``` can be used to force multi-step prediction. Starting at this time step, the prediction from the previous time step will be treated as the actual input. For example, if the model is run on a sequence of 15 timesteps with ```extrap_start_time = 10```, the last output will correspond to a t+5 prediction. In the paper, we train in this setting starting from the original t+1 trained weights, and the resulting fine-tuned weights are included in `download_models.sh`.
+
1 Note on implementation: PredNet inherits from the Recurrent layer class, i.e. it has an internal state and a step function. Given the top-down then bottom-up update sequence, it must currently be implemented in Keras as essentially a 'super' layer where all layers in the PredNet are in one PredNet 'layer'. This is less than ideal, but it seems like the most efficient way as of now. We welcome suggestions if anyone thinks of a better implementation.
diff --git a/data_utils.py b/data_utils.py
old mode 100644
new mode 100755
index f70603f..863ddbd
--- a/data_utils.py
+++ b/data_utils.py
@@ -8,18 +8,18 @@ class SequenceGenerator(Iterator):
def __init__(self, data_file, source_file, nt,
batch_size=8, shuffle=False, seed=None,
output_mode='error', sequence_start_mode='all', N_seq=None,
- dim_ordering=K.image_dim_ordering()):
+ data_format=K.image_data_format()):
self.X = hkl.load(data_file) # X will be like (n_images, nb_cols, nb_rows, nb_channels)
self.sources = hkl.load(source_file) # source for each image so when creating sequences can assure that consecutive frames are from same video
self.nt = nt
self.batch_size = batch_size
- self.dim_ordering = dim_ordering
+ self.data_format = data_format
assert sequence_start_mode in {'all', 'unique'}, 'sequence_start_mode must be in {all, unique}'
self.sequence_start_mode = sequence_start_mode
assert output_mode in {'error', 'prediction'}, 'output_mode must be in {error, prediction}'
self.output_mode = output_mode
- if self.dim_ordering == 'th':
+ if self.data_format == 'channels_first':
self.X = np.transpose(self.X, (0, 3, 1, 2))
self.im_shape = self.X[0].shape
diff --git a/download_models.sh b/download_models.sh
index d1bb597..e438daf 100755
--- a/download_models.sh
+++ b/download_models.sh
@@ -1,7 +1,5 @@
-savedir="model_data"
+savedir="model_data_keras2"
mkdir -p -- "$savedir"
-wget https://www.dropbox.com/s/n6hllbbaeh3fpj9/prednet_kitti_model.zip?dl=0 -O $savedir/prednet_kitti_model.zip
-unzip $savedir/prednet_kitti_model.zip -d $savedir
-wget https://www.dropbox.com/s/zhcp20ixvufnma8/prednet_kitti_model-extrapfinetuned.zip?dl=0 -O $savedir/prednet_kitti_model-extrapfinetuned.zip
-unzip $savedir/prednet_kitti_model-extrapfinetuned.zip -d $savedir
-wget https://www.dropbox.com/s/e9048813j2fhdcw/prednet_kitti_weights-Lall.hdf5?dl=0 -O $savedir/prednet_kitti_weights-Lall.hdf5
+wget https://www.dropbox.com/s/z7ittwfxa5css7a/model_data_keras2.zip?dl=0 -O $savedir/model_data_keras2.zip
+unzip -j $savedir/model_data_keras2.zip -d $savedir
+rm $savedir/model_data_keras2.zip
diff --git a/keras_utils.py b/keras_utils.py
new file mode 100755
index 0000000..ededcc7
--- /dev/null
+++ b/keras_utils.py
@@ -0,0 +1,58 @@
+import os
+import numpy as np
+
+from keras import backend as K
+from keras.legacy.interfaces import generate_legacy_interface, recurrent_args_preprocessor
+from keras.models import model_from_json
+
+legacy_prednet_support = generate_legacy_interface(
+ allowed_positional_args=['stack_sizes', 'R_stack_sizes',
+ 'A_filt_sizes', 'Ahat_filt_sizes', 'R_filt_sizes'],
+ conversions=[('dim_ordering', 'data_format'),
+ ('consume_less', 'implementation')],
+ value_conversions={'dim_ordering': {'tf': 'channels_last',
+ 'th': 'channels_first',
+ 'default': None},
+ 'consume_less': {'cpu': 0,
+ 'mem': 1,
+ 'gpu': 2}},
+ preprocessor=recurrent_args_preprocessor)
+
+# Convert old Keras (1.2) json models and weights to Keras 2.0
+def convert_model_to_keras2(old_json_file, old_weights_file, new_json_file, new_weights_file):
+ from prednet import PredNet
+ # If using tensorflow, it doesn't allow you to load the old weights.
+ if K.backend() != 'theano':
+ os.environ['KERAS_BACKEND'] = backend
+ reload(K)
+
+ f = open(old_json_file, 'r')
+ json_string = f.read()
+ f.close()
+ model = model_from_json(json_string, custom_objects = {'PredNet': PredNet})
+ model.load_weights(old_weights_file)
+
+ weights = model.layers[1].get_weights()
+ if weights[0].shape[0] == model.layers[1].stack_sizes[1]:
+ for i, w in enumerate(weights):
+ if w.ndim == 4:
+ weights[i] = np.transpose(w, (2, 3, 1, 0))
+ model.set_weights(weights)
+
+ model.save_weights(new_weights_file)
+ json_string = model.to_json()
+ with open(new_json_file, "w") as f:
+ f.write(json_string)
+
+
+if __name__ == '__main__':
+ old_dir = './model_data/'
+ new_dir = './model_data_keras2/'
+ if not os.path.exists(new_dir):
+ os.mkdir(new_dir)
+ for w_tag in ['', '-Lall', '-extrapfinetuned']:
+ m_tag = '' if w_tag == '-Lall' else w_tag
+ convert_model_to_keras2(old_dir + 'prednet_kitti_model' + m_tag + '.json',
+ old_dir + 'prednet_kitti_weights' + w_tag + '.hdf5',
+ new_dir + 'prednet_kitti_model' + m_tag + '.json',
+ new_dir + 'prednet_kitti_weights' + w_tag + '.hdf5')
diff --git a/kitti_evaluate.py b/kitti_evaluate.py
index b3399bc..9dc5582 100755
--- a/kitti_evaluate.py
+++ b/kitti_evaluate.py
@@ -39,18 +39,18 @@
# Create testing model (to output predictions)
layer_config = train_model.layers[1].get_config()
layer_config['output_mode'] = 'prediction'
-dim_ordering = layer_config['dim_ordering']
+data_format = layer_config['data_format'] if 'data_format' in layer_config else layer_config['dim_ordering']
test_prednet = PredNet(weights=train_model.layers[1].get_weights(), **layer_config)
input_shape = list(train_model.layers[0].batch_input_shape[1:])
input_shape[0] = nt
inputs = Input(shape=tuple(input_shape))
predictions = test_prednet(inputs)
-test_model = Model(input=inputs, output=predictions)
+test_model = Model(inputs=inputs, outputs=predictions)
-test_generator = SequenceGenerator(test_file, test_sources, nt, sequence_start_mode='unique', dim_ordering=dim_ordering)
+test_generator = SequenceGenerator(test_file, test_sources, nt, sequence_start_mode='unique', data_format=data_format)
X_test = test_generator.create_all()
X_hat = test_model.predict(X_test, batch_size)
-if dim_ordering == 'th':
+if data_format == 'channels_first':
X_test = np.transpose(X_test, (0, 1, 3, 4, 2))
X_hat = np.transpose(X_hat, (0, 1, 3, 4, 2))
diff --git a/kitti_settings.py b/kitti_settings.py
old mode 100644
new mode 100755
index b43796f..b019346
--- a/kitti_settings.py
+++ b/kitti_settings.py
@@ -4,7 +4,7 @@
# Where model weights and config will be saved if you run kitti_train.py
# If you directly download the trained weights, change to appropriate path.
-WEIGHTS_DIR = './model_data/'
+WEIGHTS_DIR = './model_data_keras2/'
# Where results (prediction plots and evaluation file) will be saved.
RESULTS_SAVE_DIR = './kitti_results/'
diff --git a/kitti_train.py b/kitti_train.py
index ea6b930..c0c97e9 100755
--- a/kitti_train.py
+++ b/kitti_train.py
@@ -37,17 +37,17 @@
N_seq_val = 100 # number of sequences to use for validation
# Model parameters
-nt = 10
n_channels, im_height, im_width = (3, 128, 160)
-input_shape = (n_channels, im_height, im_width) if K.image_dim_ordering() == 'th' else (im_height, im_width, n_channels)
+input_shape = (n_channels, im_height, im_width) if K.image_data_format() == 'channels_first' else (im_height, im_width, n_channels)
stack_sizes = (n_channels, 48, 96, 192)
R_stack_sizes = stack_sizes
A_filt_sizes = (3, 3, 3)
Ahat_filt_sizes = (3, 3, 3, 3)
R_filt_sizes = (3, 3, 3, 3)
-layer_loss_weights = np.array([1., 0., 0., 0.])
+layer_loss_weights = np.array([1., 0., 0., 0.]) # weighting for each layer in final loss; "L_0" model: [1, 0, 0, 0], "L_all": [1, 0.1, 0.1, 0.1]
layer_loss_weights = np.expand_dims(layer_loss_weights, 1)
-time_loss_weights = 1./ (nt - 1) * np.ones((nt,1))
+nt = 10 # number of timesteps used for sequences in training
+time_loss_weights = 1./ (nt - 1) * np.ones((nt,1)) # equally weight all timesteps except the first
time_loss_weights[0] = 0
@@ -60,7 +60,7 @@
errors_by_time = TimeDistributed(Dense(1, weights=[layer_loss_weights, np.zeros(1)], trainable=False), trainable=False)(errors) # calculate weighted error by layer
errors_by_time = Flatten()(errors_by_time) # will be (batch_size, nt)
final_errors = Dense(1, weights=[time_loss_weights, np.zeros(1)], trainable=False)(errors_by_time) # weight errors by time
-model = Model(input=inputs, output=final_errors)
+model = Model(inputs=inputs, outputs=final_errors)
model.compile(loss='mean_absolute_error', optimizer='adam')
train_generator = SequenceGenerator(train_file, train_sources, nt, batch_size=batch_size, shuffle=True)
@@ -72,8 +72,8 @@
if not os.path.exists(WEIGHTS_DIR): os.mkdir(WEIGHTS_DIR)
callbacks.append(ModelCheckpoint(filepath=weights_file, monitor='val_loss', save_best_only=True))
-history = model.fit_generator(train_generator, samples_per_epoch, nb_epoch, callbacks=callbacks,
- validation_data=val_generator, nb_val_samples=N_seq_val)
+history = model.fit_generator(train_generator, samples_per_epoch / batch_size, nb_epoch, callbacks=callbacks,
+ validation_data=val_generator, validation_steps=N_seq_val / batch_size)
if save_model:
json_string = model.to_json()
diff --git a/prednet.py b/prednet.py
index 2af75ac..cc7b012 100755
--- a/prednet.py
+++ b/prednet.py
@@ -3,9 +3,9 @@
from keras import backend as K
from keras import activations
from keras.layers import Recurrent
-from keras.layers import Convolution2D, UpSampling2D, MaxPooling2D
+from keras.layers import Conv2D, UpSampling2D, MaxPooling2D
from keras.engine import InputSpec
-
+from keras_utils import legacy_prednet_support
class PredNet(Recurrent):
'''PredNet architecture - Lotter 2016.
@@ -49,11 +49,9 @@ class PredNet(Recurrent):
The possible unit types are 'R', 'Ahat', 'A', and 'E' corresponding to the 'representation', 'prediction', 'target', and 'error' units respectively.
extrap_start_time: time step for which model will start extrapolating.
Starting at this time step, the prediction from the previous time step will be treated as the "actual"
- dim_ordering: 'th' or 'tf'. In 'th' mode, the channels dimension
- (the depth) is at index 1, in 'tf' mode is it at index 3.
- It defaults to the `image_dim_ordering` value found in your
+ data_format: 'channels_first' or 'channels_last'.
+ It defaults to the `image_data_format` value found in your
Keras config file at `~/.keras/keras.json`.
- If you never set it, then it will be "th".
# References
- [Deep predictive coding networks for video prediction and unsupervised learning](https://arxiv.org/abs/1605.08104)
@@ -61,12 +59,13 @@ class PredNet(Recurrent):
- [Convolutional LSTM network: a machine learning approach for precipitation nowcasting](http://arxiv.org/abs/1506.04214)
- [Predictive coding in the visual cortex: a functional interpretation of some extra-classical receptive-field effects](http://www.nature.com/neuro/journal/v2/n1/pdf/nn0199_79.pdf)
'''
+ @legacy_prednet_support
def __init__(self, stack_sizes, R_stack_sizes,
A_filt_sizes, Ahat_filt_sizes, R_filt_sizes,
pixel_max=1., error_activation='relu', A_activation='relu',
LSTM_activation='tanh', LSTM_inner_activation='hard_sigmoid',
- output_mode='error', extrap_start_time = None,
- dim_ordering=K.image_dim_ordering(), **kwargs):
+ output_mode='error', extrap_start_time=None,
+ data_format=K.image_data_format(), **kwargs):
self.stack_sizes = stack_sizes
self.nb_layers = len(stack_sizes)
assert len(R_stack_sizes) == self.nb_layers, 'len(R_stack_sizes) must equal len(stack_sizes)'
@@ -96,16 +95,15 @@ def __init__(self, stack_sizes, R_stack_sizes,
self.output_layer_num = None
self.extrap_start_time = extrap_start_time
- assert dim_ordering in {'tf', 'th'}, 'dim_ordering must be in {tf, th}'
- self.dim_ordering = dim_ordering
- self.channel_axis = -3 if dim_ordering == 'th' else -1
- self.row_axis = -2 if dim_ordering == 'th' else -3
- self.column_axis = -1 if dim_ordering == 'th' else -2
-
+ assert data_format in {'channels_last', 'channels_first'}, 'data_format must be in {channels_last, channels_first}'
+ self.data_format = data_format
+ self.channel_axis = -3 if data_format == 'channels_first' else -1
+ self.row_axis = -2 if data_format == 'channels_first' else -3
+ self.column_axis = -1 if data_format == 'channels_first' else -2
super(PredNet, self).__init__(**kwargs)
self.input_spec = [InputSpec(ndim=5)]
- def get_output_shape_for(self, input_shape):
+ def compute_output_shape(self, input_shape):
if self.output_mode == 'prediction':
out_shape = input_shape[2:]
elif self.output_mode == 'error':
@@ -118,7 +116,7 @@ def get_output_shape_for(self, input_shape):
out_stack_size = stack_mult * getattr(self, stack_str)[self.output_layer_num]
out_nb_row = input_shape[self.row_axis] / 2**self.output_layer_num
out_nb_col = input_shape[self.column_axis] / 2**self.output_layer_num
- if self.dim_ordering == 'th':
+ if self.data_format == 'channels_first':
out_shape = (out_stack_size, out_nb_row, out_nb_col)
else:
out_shape = (out_nb_row, out_nb_col, out_stack_size)
@@ -128,13 +126,13 @@ def get_output_shape_for(self, input_shape):
else:
return (input_shape[0],) + out_shape
- def get_initial_states(self, x):
+ def get_initial_state(self, x):
input_shape = self.input_spec[0].shape
init_nb_row = input_shape[self.row_axis]
init_nb_col = input_shape[self.column_axis]
base_initial_state = K.zeros_like(x) # (samples, timesteps) + image_shape
- non_channel_axis = -1 if self.dim_ordering == 'th' else -2
+ non_channel_axis = -1 if self.data_format == 'channels_first' else -2
for _ in range(2):
base_initial_state = K.sum(base_initial_state, axis=non_channel_axis)
base_initial_state = K.sum(base_initial_state, axis=1) # (samples, nb_channels)
@@ -160,7 +158,7 @@ def get_initial_states(self, x):
reducer = K.zeros((input_shape[self.channel_axis], output_size)) # (nb_channels, output_size)
initial_state = K.dot(base_initial_state, reducer) # (samples, output_size)
- if self.dim_ordering == 'th':
+ if self.data_format == 'channels_first':
output_shp = (-1, stack_size, nb_row, nb_col)
else:
output_shp = (-1, nb_row, nb_col, stack_size)
@@ -174,7 +172,7 @@ def get_initial_states(self, x):
initial_states = [T.unbroadcast(init_state, 0, 1) for init_state in initial_states]
if self.extrap_start_time is not None:
- initial_states += [K.variable(0, int)] # the last state will correspond to the current timestep
+ initial_states += [K.variable(0, int if K.backend() != 'tensorflow' else 'int32')] # the last state will correspond to the current timestep
return initial_states
def build(self, input_shape):
@@ -184,19 +182,19 @@ def build(self, input_shape):
for l in range(self.nb_layers):
for c in ['i', 'f', 'c', 'o']:
act = self.LSTM_activation if c == 'c' else self.LSTM_inner_activation
- self.conv_layers[c].append(Convolution2D(self.R_stack_sizes[l], self.R_filt_sizes[l], self.R_filt_sizes[l], border_mode='same', activation=act, dim_ordering=self.dim_ordering))
+ self.conv_layers[c].append(Conv2D(self.R_stack_sizes[l], self.R_filt_sizes[l], padding='same', activation=act, data_format=self.data_format))
act = 'relu' if l == 0 else self.A_activation
- self.conv_layers['ahat'].append(Convolution2D(self.stack_sizes[l], self.Ahat_filt_sizes[l], self.Ahat_filt_sizes[l], border_mode='same', activation=act, dim_ordering=self.dim_ordering))
+ self.conv_layers['ahat'].append(Conv2D(self.stack_sizes[l], self.Ahat_filt_sizes[l], padding='same', activation=act, data_format=self.data_format))
if l < self.nb_layers - 1:
- self.conv_layers['a'].append(Convolution2D(self.stack_sizes[l+1], self.A_filt_sizes[l], self.A_filt_sizes[l], border_mode='same', activation=self.A_activation, dim_ordering=self.dim_ordering))
+ self.conv_layers['a'].append(Conv2D(self.stack_sizes[l+1], self.A_filt_sizes[l], padding='same', activation=self.A_activation, data_format=self.data_format))
- self.upsample = UpSampling2D(dim_ordering=self.dim_ordering)
- self.pool = MaxPooling2D(dim_ordering=self.dim_ordering)
+ self.upsample = UpSampling2D(data_format=self.data_format)
+ self.pool = MaxPooling2D(data_format=self.data_format)
self.trainable_weights = []
- nb_row, nb_col = (input_shape[-2], input_shape[-1]) if self.dim_ordering == 'th' else (input_shape[-3], input_shape[-2])
+ nb_row, nb_col = (input_shape[-2], input_shape[-1]) if self.data_format == 'channels_first' else (input_shape[-3], input_shape[-2])
for c in sorted(self.conv_layers.keys()):
for l in range(len(self.conv_layers[c])):
ds_factor = 2 ** l
@@ -209,18 +207,16 @@ def build(self, input_shape):
if l < self.nb_layers - 1:
nb_channels += self.R_stack_sizes[l+1]
in_shape = (input_shape[0], nb_channels, nb_row // ds_factor, nb_col // ds_factor)
- if self.dim_ordering == 'tf': in_shape = (in_shape[0], in_shape[2], in_shape[3], in_shape[1])
- self.conv_layers[c][l].build(in_shape)
+ if self.data_format == 'channels_last': in_shape = (in_shape[0], in_shape[2], in_shape[3], in_shape[1])
+ with K.name_scope('layer_' + c + '_' + str(l)):
+ self.conv_layers[c][l].build(in_shape)
self.trainable_weights += self.conv_layers[c][l].trainable_weights
- if self.initial_weights is not None:
- self.set_weights(self.initial_weights)
- del self.initial_weights
-
self.states = [None] * self.nb_layers*3
if self.extrap_start_time is not None:
- self.t_extrap = K.variable(self.extrap_start_time, int)
+ self.t_extrap = K.variable(self.extrap_start_time, int if K.backend() != 'tensorflow' else 'int32')
+ self.states += [None] * 2 # [previous frame prediction, timestep]
def step(self, a, states):
r_tm1 = states[:self.nb_layers]
@@ -235,6 +231,7 @@ def step(self, a, states):
r = []
e = []
+ # Update R units starting from the top
for l in reversed(range(self.nb_layers)):
inputs = [r_tm1[l], e_tm1[l]]
if l < self.nb_layers - 1:
@@ -252,6 +249,7 @@ def step(self, a, states):
if l > 0:
r_up = self.upsample.call(_r)
+ # Update feedforward path starting from the bottom
for l in range(self.nb_layers):
ahat = self.conv_layers['ahat'][l].call(r[l])
if l == 0:
@@ -306,7 +304,7 @@ def get_config(self):
'A_activation': self.A_activation.__name__,
'LSTM_activation': self.LSTM_activation.__name__,
'LSTM_inner_activation': self.LSTM_inner_activation.__name__,
- 'dim_ordering': self.dim_ordering,
+ 'data_format': self.data_format,
'extrap_start_time': self.extrap_start_time,
'output_mode': self.output_mode}
base_config = super(PredNet, self).get_config()