diff --git a/prednet.py b/prednet.py
index f2f1d7e..7faddec 100755
--- a/prednet.py
+++ b/prednet.py
@@ -36,13 +36,17 @@ class PredNet(Recurrent):
         A_activation: activation function for the target (A) and prediction (A_hat) units.
         LSTM_activation: activation function for the cell and hidden states of the LSTM.
         LSTM_inner_activation: activation function for the gates in the LSTM.
-        output_mode: either 'error', 'prediction', or 'all'.
+        output_mode: either 'error', 'prediction', 'all', or a layer specification (e.g. R2, see below).
             Controls what is outputted by the PredNet.
             If 'error', the mean response of the error (E) units of each layer will be outputted.
                 That is, the output shape will be (batch_size, nb_layers).
             If 'prediction', the frame prediction will be outputted.
             If 'all', the output will be the frame prediction concatenated with the mean layer errors.
                 The frame prediction is flattened before concatenation.
+                The name 'all' is kept for backwards compatibility; it should not be confused with returning all of the layers of the model.
+            For returning the features of a particular layer, output_mode should be of the form unit_type + layer_number.
+                For instance, to return the features of the LSTM "representational" units in the lowest layer, output_mode should be specified as 'R0'.
+                The possible unit types are 'R', 'Ahat', 'A', and 'E', corresponding to the 'representation', 'prediction', 'target', and 'error' units respectively.
         extrap_start_time: time step for which model will start extrapolating.
             Starting at this time step, the prediction from the previous time step will be treated as the "actual"
         dim_ordering: 'th' or 'tf'. In 'th' mode, the channels dimension
@@ -80,8 +84,16 @@ def __init__(self, stack_sizes, R_stack_sizes,
         self.LSTM_activation = activations.get(LSTM_activation)
         self.LSTM_inner_activation = activations.get(LSTM_inner_activation)
 
-        assert output_mode in {'prediction', 'error', 'all'}, 'output_mode must be in {prediction, error, all}'
+        default_output_modes = ['prediction', 'error', 'all']
+        layer_output_modes = [layer + str(n) for n in range(self.nb_layers) for layer in ['R', 'E', 'A', 'Ahat']]
+        assert output_mode in default_output_modes + layer_output_modes, 'Invalid output_mode: ' + str(output_mode)
         self.output_mode = output_mode
+        if self.output_mode in layer_output_modes:
+            self.output_layer_type = self.output_mode[:-1]
+            self.output_layer_num = int(self.output_mode[-1])
+        else:
+            self.output_layer_type = None
+            self.output_layer_num = None
         self.extrap_start_time = extrap_start_time
 
         assert dim_ordering in {'tf', 'th'}, 'dim_ordering must be in {tf, th}'
@@ -95,20 +107,26 @@ def __init__(self, stack_sizes, R_stack_sizes,
 
     def get_output_shape_for(self, input_shape):
         if self.output_mode == 'prediction':
-            if self.return_sequences:
-                return input_shape
-            else:
-                return (input_shape[0],) + input_shape[2:]
+            out_shape = input_shape[2:]
         elif self.output_mode == 'error':
-            if self.return_sequences:
-                return (input_shape[0], input_shape[1], self.nb_layers)
-            else:
-                return (input_shape[0], self.nb_layers)
+            out_shape = (self.nb_layers,)
+        elif self.output_mode == 'all':
+            out_shape = (np.prod(input_shape[2:]) + self.nb_layers,)
         else:
-            if self.return_sequences:
-                return (input_shape[0], input_shape[1], np.prod(input_shape[2:]) + self.nb_layers)
+            stack_str = 'R_stack_sizes' if self.output_layer_type == 'R' else 'stack_sizes'
+            stack_mult = 2 if self.output_layer_type == 'E' else 1
+            out_stack_size = stack_mult * getattr(self, stack_str)[self.output_layer_num]
+            out_nb_row = input_shape[self.row_axis] / 2**self.output_layer_num
+            out_nb_col = input_shape[self.column_axis] / 2**self.output_layer_num
+            if self.dim_ordering == 'th':
+                out_shape = (out_stack_size, out_nb_row, out_nb_col)
             else:
-                return (input_shape[0], np.prod(input_shape[2:]) + self.nb_layers)
+                out_shape = (out_nb_row, out_nb_col, out_stack_size)
+
+        if self.return_sequences:
+            return (input_shape[0], input_shape[1]) + out_shape
+        else:
+            return (input_shape[0],) + out_shape
 
     def get_initial_states(self, x):
         input_shape = self.input_spec[0].shape
@@ -246,20 +264,31 @@ def step(self, a, states):
 
             e.append(K.concatenate((e_up, e_down), axis=self.channel_axis))
 
+            if self.output_layer_num == l:
+                if self.output_layer_type == 'A':
+                    output = a
+                elif self.output_layer_type == 'Ahat':
+                    output = ahat
+                elif self.output_layer_type == 'R':
+                    output = r[l]
+                elif self.output_layer_type == 'E':
+                    output = e[l]
+
             if l < self.nb_layers - 1:
                 a = self.conv_layers['a'][l].call(e[l])
                 a = self.pool.call(a)  # target for next layer
 
-        if self.output_mode == 'prediction':
-            output = frame_prediction
-        else:
-            for l in range(self.nb_layers):
-                layer_error = K.mean(K.batch_flatten(e[l]), axis=-1, keepdims=True)
-                all_error = layer_error if l == 0 else K.concatenate((all_error, layer_error), axis=-1)
-            if self.output_mode == 'error':
-                output = all_error
+        if self.output_layer_type is None:
+            if self.output_mode == 'prediction':
+                output = frame_prediction
             else:
-                output = K.concatenate((K.batch_flatten(frame_prediction), all_error), axis=-1)
+                for l in range(self.nb_layers):
+                    layer_error = K.mean(K.batch_flatten(e[l]), axis=-1, keepdims=True)
+                    all_error = layer_error if l == 0 else K.concatenate((all_error, layer_error), axis=-1)
+                if self.output_mode == 'error':
+                    output = all_error
+                else:
+                    output = K.concatenate((K.batch_flatten(frame_prediction), all_error), axis=-1)
 
         states = r + c + e
         if self.extrap_start_time is not None:
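Usage sketch for reviewers (not part of this diff): a minimal feature-extraction example built on the new layer-specific output_mode. The stack sizes, filter sizes, input shape, and variable names below are illustrative assumptions, and the snippet assumes the Keras 1.x functional API that this file targets.

```python
# Minimal sketch, assuming the Keras 1.x API used by prednet.py.
# All hyperparameters below (stack sizes, filter sizes, input shape) are
# illustrative, not values prescribed by this diff.
from keras.layers import Input
from keras.models import Model
from prednet import PredNet

nt = 10                                  # number of timesteps
stack_sizes = (3, 48, 96, 192)           # A/Ahat channels per layer (example)
R_stack_sizes = stack_sizes              # R (LSTM) channels per layer (example)
A_filt_sizes = (3, 3, 3)                 # conv filter sizes, layers 1..3
Ahat_filt_sizes = (3, 3, 3, 3)           # conv filter sizes, layers 0..3
R_filt_sizes = (3, 3, 3, 3)              # conv filter sizes, layers 0..3

# 'R0' selects the representational (LSTM) units of the lowest layer,
# following the unit_type + layer_number convention documented above.
prednet = PredNet(stack_sizes, R_stack_sizes,
                  A_filt_sizes, Ahat_filt_sizes, R_filt_sizes,
                  output_mode='R0', return_sequences=True, dim_ordering='th')

inputs = Input(shape=(nt, 3, 128, 160))  # (nt, channels, rows, cols) in 'th' mode
features = prednet(inputs)               # (batch, nt, R_stack_sizes[0], 128, 160)
feature_extractor = Model(input=inputs, output=features)
```

Consistent with get_output_shape_for above, an 'E' mode doubles the channel count (the error units concatenate positive and negative responses), and the features of layer l are spatially downsampled by a factor of 2**l in each dimension.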