diff --git a/prednet.py b/prednet.py index 7faddec..b2e02c7 100755 --- a/prednet.py +++ b/prednet.py @@ -119,9 +119,9 @@ def get_output_shape_for(self, input_shape): out_nb_row = input_shape[self.row_axis] / 2**self.output_layer_num out_nb_col = input_shape[self.column_axis] / 2**self.output_layer_num if self.dim_ordering == 'th': - out_shape = (-1, out_stack_size, out_nb_row, out_nb_col) + out_shape = (out_stack_size, out_nb_row, out_nb_col) else: - out_shape = (-1, out_nb_row, out_nb_col, out_stack_size) + out_shape = (out_nb_row, out_nb_col, out_stack_size) if self.return_sequences: return (input_shape[0], input_shape[1]) + out_shape