Skip to content

Commit

Permalink
Added A activation
Browse files Browse the repository at this point in the history
  • Loading branch information
bill-lotter committed Aug 28, 2016
1 parent 3fdd613 commit 7fdac5f
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions prednet.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class PredNet(Recurrent):
pixel_max: the maximum pixel value.
Used to clip the pixel-layer prediction.
error_activation: activation function for the error (E) units.
A_activation: activation function for the target (A) and prediction (A_hat) units.
LSTM_activation: activation function for the cell and hidden states of the LSTM.
LSTM_inner_activation: activation function for the gates in the LSTM.
output_mode: either 'error', 'prediction', or 'all'.
Expand All @@ -58,7 +59,7 @@ class PredNet(Recurrent):
'''
def __init__(self, stack_sizes, R_stack_sizes,
A_filt_sizes, Ahat_filt_sizes, R_filt_sizes,
pixel_max=1., error_activation='relu',
pixel_max=1., error_activation='relu', A_activation='relu',
LSTM_activation='tanh', LSTM_inner_activation='hard_sigmoid',
output_mode='error', extrap_start_time = None,
dim_ordering=K.image_dim_ordering(), **kwargs):
Expand All @@ -75,6 +76,7 @@ def __init__(self, stack_sizes, R_stack_sizes,

self.pixel_max = pixel_max
self.error_activation = activations.get(error_activation)
self.A_activation = activations.get(A_activation)
self.LSTM_activation = activations.get(LSTM_activation)
self.LSTM_inner_activation = activations.get(LSTM_inner_activation)

Expand Down Expand Up @@ -163,13 +165,14 @@ def build(self, input_shape):

for l in range(self.nb_layers):
for c in ['i', 'f', 'c', 'o']:
act = self.LSTM_activation.__name__ if c == 'c' else self.LSTM_inner_activation.__name__
act = self.LSTM_activation if c == 'c' else self.LSTM_inner_activation
self.conv_layers[c].append(Convolution2D(self.R_stack_sizes[l], self.R_filt_sizes[l], self.R_filt_sizes[l], border_mode='same', activation=act, dim_ordering=self.dim_ordering))

self.conv_layers['ahat'].append(Convolution2D(self.stack_sizes[l], self.Ahat_filt_sizes[l], self.Ahat_filt_sizes[l], border_mode='same', activation='relu', dim_ordering=self.dim_ordering))
act = 'relu' if l == 0 else self.A_activation
self.conv_layers['ahat'].append(Convolution2D(self.stack_sizes[l], self.Ahat_filt_sizes[l], self.Ahat_filt_sizes[l], border_mode='same', activation=act, dim_ordering=self.dim_ordering))

if l < self.nb_layers - 1:
self.conv_layers['a'].append(Convolution2D(self.stack_sizes[l+1], self.A_filt_sizes[l], self.A_filt_sizes[l], border_mode='same', activation='relu', dim_ordering=self.dim_ordering))
self.conv_layers['a'].append(Convolution2D(self.stack_sizes[l+1], self.A_filt_sizes[l], self.A_filt_sizes[l], border_mode='same', activation=self.A_activation, dim_ordering=self.dim_ordering))

self.upsample = UpSampling2D()
self.pool = MaxPooling2D()
Expand Down Expand Up @@ -269,6 +272,7 @@ def get_config(self):
'R_filt_sizes': self.R_filt_sizes,
'pixel_max': self.pixel_max,
'error_activation': self.error_activation.__name__,
'A_activation': self.A_activation.__name__,
'LSTM_activation': self.LSTM_activation.__name__,
'LSTM_inner_activation': self.LSTM_inner_activation.__name__,
'dim_ordering': self.dim_ordering,
Expand Down

0 comments on commit 7fdac5f

Please sign in to comment.