Added feature extraction output modes
bill-lotter committed Feb 24, 2017
1 parent d15ee79 commit 1d405a5
Showing 1 changed file with 51 additions and 22 deletions.
prednet.py: 73 changes (51 additions & 22 deletions)
@@ -36,13 +36,17 @@ class PredNet(Recurrent):
        A_activation: activation function for the target (A) and prediction (A_hat) units.
        LSTM_activation: activation function for the cell and hidden states of the LSTM.
        LSTM_inner_activation: activation function for the gates in the LSTM.
-        output_mode: either 'error', 'prediction', or 'all'.
+        output_mode: either 'error', 'prediction', 'all', or a layer specification (e.g. 'R2'; see below).
            Controls what is output by the PredNet.
            If 'error', the mean response of the error (E) units of each layer will be output.
                That is, the output shape will be (batch_size, nb_layers).
            If 'prediction', the frame prediction will be output.
            If 'all', the output will be the frame prediction concatenated with the mean layer errors.
                The frame prediction is flattened before concatenation.
+                The name 'all' is kept for backwards compatibility; it should not be confused with returning all of the model's layers.
+            To return the features of a particular layer, output_mode should be of the form unit_type + layer_number.
+                For instance, to return the features of the LSTM "representational" units in the lowest layer, output_mode should be specified as 'R0'.
+                The possible unit types are 'R', 'Ahat', 'A', and 'E', corresponding to the 'representation', 'prediction', 'target', and 'error' units, respectively.
        extrap_start_time: time step at which the model will start extrapolating.
            Starting at this time step, the prediction from the previous time step will be treated as the "actual" input.
        dim_ordering: 'th' or 'tf'. In 'th' mode, the channels dimension
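
As context for the new modes documented above, a minimal usage sketch (not part of this commit), with hypothetical stack sizes and input dimensions, assuming the Keras 1 functional API this file targets:

from keras.layers import Input
from keras.models import Model
from prednet import PredNet

# Hypothetical sizes, for illustration only.
stack_sizes = (3, 48, 96, 192)   # input channels, then A/Ahat channels per layer
R_stack_sizes = stack_sizes      # channels of the representation (R) units
A_filt_sizes = (3, 3, 3)         # filter sizes for the target (A) convolutions
Ahat_filt_sizes = (3, 3, 3, 3)   # filter sizes for the prediction (Ahat) convolutions
R_filt_sizes = (3, 3, 3, 3)      # filter sizes for the convolutional LSTM (R) units

# output_mode='R0' returns the lowest layer's representation features at each step.
feature_layer = PredNet(stack_sizes, R_stack_sizes, A_filt_sizes, Ahat_filt_sizes,
                        R_filt_sizes, output_mode='R0', return_sequences=True,
                        dim_ordering='th')

inputs = Input(shape=(10, 3, 128, 160))  # (time, channels, rows, cols)
model = Model(input=inputs, output=feature_layer(inputs))  # Keras 1 Model signature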
@@ -80,8 +84,16 @@ def __init__(self, stack_sizes, R_stack_sizes,
        self.LSTM_activation = activations.get(LSTM_activation)
        self.LSTM_inner_activation = activations.get(LSTM_inner_activation)

-        assert output_mode in {'prediction', 'error', 'all'}, 'output_mode must be in {prediction, error, all}'
+        default_output_modes = ['prediction', 'error', 'all']
+        layer_output_modes = [layer + str(n) for n in range(self.nb_layers) for layer in ['R', 'E', 'A', 'Ahat']]
+        assert output_mode in default_output_modes + layer_output_modes, 'Invalid output_mode: ' + str(output_mode)
        self.output_mode = output_mode
+        if self.output_mode in layer_output_modes:
+            self.output_layer_type = self.output_mode[:-1]
+            self.output_layer_num = int(self.output_mode[-1])
+        else:
+            self.output_layer_type = None
+            self.output_layer_num = None
        self.extrap_start_time = extrap_start_time

        assert dim_ordering in {'tf', 'th'}, 'dim_ordering must be in {tf, th}'
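
For concreteness, the parsing added above splits a layer mode string into a unit type and a single-digit layer index; a hypothetical illustration, not from the commit:

output_mode = 'Ahat1'
output_layer_type = output_mode[:-1]     # 'Ahat'
output_layer_num = int(output_mode[-1])  # 1 (layer indices 0-9 only)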
@@ -95,20 +107,26 @@ def __init__(self, stack_sizes, R_stack_sizes,

    def get_output_shape_for(self, input_shape):
        if self.output_mode == 'prediction':
-            if self.return_sequences:
-                return input_shape
-            else:
-                return (input_shape[0],) + input_shape[2:]
+            out_shape = input_shape[2:]
        elif self.output_mode == 'error':
-            if self.return_sequences:
-                return (input_shape[0], input_shape[1], self.nb_layers)
-            else:
-                return (input_shape[0], self.nb_layers)
+            out_shape = (self.nb_layers,)
+        elif self.output_mode == 'all':
+            out_shape = (np.prod(input_shape[2:]) + self.nb_layers,)
        else:
-            if self.return_sequences:
-                return (input_shape[0], input_shape[1], np.prod(input_shape[2:]) + self.nb_layers)
-            else:
-                return (input_shape[0], np.prod(input_shape[2:]) + self.nb_layers)
+            stack_str = 'R_stack_sizes' if self.output_layer_type == 'R' else 'stack_sizes'
+            stack_mult = 2 if self.output_layer_type == 'E' else 1
+            out_stack_size = stack_mult * getattr(self, stack_str)[self.output_layer_num]
+            out_nb_row = input_shape[self.row_axis] / 2**self.output_layer_num
+            out_nb_col = input_shape[self.column_axis] / 2**self.output_layer_num
+            if self.dim_ordering == 'th':
+                out_shape = (out_stack_size, out_nb_row, out_nb_col)
+            else:
+                out_shape = (out_nb_row, out_nb_col, out_stack_size)
+
+        if self.return_sequences:
+            return (input_shape[0], input_shape[1]) + out_shape
+        else:
+            return (input_shape[0],) + out_shape

    def get_initial_states(self, x):
        input_shape = self.input_spec[0].shape
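
As a sanity check on the shape logic in get_output_shape_for above, a worked example under assumed settings:

# Assumed: stack_sizes=(3, 48, 96, 192), dim_ordering='th',
# input_shape=(batch, time, 3, 128, 160), output_mode='E1', return_sequences=True.
out_stack_size = 2 * 48  # stack_mult=2: E units stack e_up and e_down
out_nb_row = 128 / 2**1  # 64; rows halve at each layer via 2x2 max pooling
out_nb_col = 160 / 2**1  # 80
# get_output_shape_for returns (batch, time, 96, 64, 80).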
@@ -246,20 +264,31 @@ def step(self, a, states):

            e.append(K.concatenate((e_up, e_down), axis=self.channel_axis))

+            if self.output_layer_num == l:
+                if self.output_layer_type == 'A':
+                    output = a
+                elif self.output_layer_type == 'Ahat':
+                    output = ahat
+                elif self.output_layer_type == 'R':
+                    output = r[l]
+                elif self.output_layer_type == 'E':
+                    output = e[l]
+
            if l < self.nb_layers - 1:
                a = self.conv_layers['a'][l].call(e[l])
                a = self.pool.call(a)  # target for next layer

-        if self.output_mode == 'prediction':
-            output = frame_prediction
-        else:
-            for l in range(self.nb_layers):
-                layer_error = K.mean(K.batch_flatten(e[l]), axis=-1, keepdims=True)
-                all_error = layer_error if l == 0 else K.concatenate((all_error, layer_error), axis=-1)
-            if self.output_mode == 'error':
-                output = all_error
-            else:
-                output = K.concatenate((K.batch_flatten(frame_prediction), all_error), axis=-1)
+        if self.output_layer_type is None:
+            if self.output_mode == 'prediction':
+                output = frame_prediction
+            else:
+                for l in range(self.nb_layers):
+                    layer_error = K.mean(K.batch_flatten(e[l]), axis=-1, keepdims=True)
+                    all_error = layer_error if l == 0 else K.concatenate((all_error, layer_error), axis=-1)
+                if self.output_mode == 'error':
+                    output = all_error
+                else:
+                    output = K.concatenate((K.batch_flatten(frame_prediction), all_error), axis=-1)

        states = r + c + e
        if self.extrap_start_time is not None:

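A hypothetical downstream use of the per-step features that step() now emits, continuing the earlier 'R0' sketch:

import numpy as np

# 'model' as in the earlier sketch: output_mode='R0', return_sequences=True.
X = np.random.random((8, 10, 3, 128, 160)).astype('float32')
feats = model.predict(X)         # (8, 10, 3, 128, 160): R0 feature maps per step
descriptor = feats.mean(axis=1)  # average over time -> (8, 3, 128, 160)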