-
Notifications
You must be signed in to change notification settings - Fork 2
/
AttentionPoolingLayer.py
356 lines (301 loc) · 17.7 KB
/
AttentionPoolingLayer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
from tensorflow.keras.layers import Layer
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.initializers import Zeros, glorot_normal
from tensorflow.keras.layers import Layer, Lambda
from tensorflow.keras.regularizers import l2
import numpy as np
class Dice(Layer):
    """Data Adaptive Activation Function (Dice) from DIN.

    A generalization of PReLU whose rectification point adapts to the
    distribution of the input. The input is batch-normalized (with no learned
    shift/scale), passed through a sigmoid to obtain a gate ``p``, and the
    output is ``alpha * (1 - p) * x + p * x``. Here ``alpha`` is a learned
    vector sized by the last input axis and initialized to zeros.

    Input shape
        - Arbitrary. Use the keyword argument `input_shape` (tuple of integers, does not include the samples axis) when using this layer as the first layer in a model.
    Output shape
        - Same shape as the input.
    Arguments
        - **axis**: Integer, the axis the internal BatchNormalization normalizes over (typically the features axis).
        - **epsilon**: Small float added to the variance to avoid dividing by zero.
    References
        - [Zhou G, Zhu X, Song C, et al. Deep interest network for click-through rate prediction[C]//Proceedings of the 24th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining. ACM, 2018: 1059-1068.](https://arxiv.org/pdf/1706.06978.pdf)
    """

    def __init__(self, axis=-1, epsilon=1e-9, **kwargs):
        self.axis = axis
        self.epsilon = epsilon
        super(Dice, self).__init__(**kwargs)

    def build(self, input_shape):
        # Pure normalization: center/scale are disabled, so no beta/gamma are learned here.
        self.bn = tf.keras.layers.BatchNormalization(
            axis=self.axis, epsilon=self.epsilon, center=False, scale=False)
        feature_dim = input_shape[-1]
        self.alphas = self.add_weight(
            shape=(feature_dim,),
            initializer=Zeros(),
            dtype=tf.float32,
            name='dice_alpha',
        )
        super(Dice, self).build(input_shape)
        self.uses_learning_phase = True

    def call(self, inputs, training=None, **kwargs):
        # Gate in (0, 1) derived from the normalized input.
        gate = tf.keras.activations.sigmoid(self.bn(inputs, training=training))
        return self.alphas * (1.0 - gate) * inputs + gate * inputs

    def compute_output_shape(self, input_shape):
        return input_shape

    def get_config(self):
        merged = dict(super(Dice, self).get_config())
        merged.update({'axis': self.axis, 'epsilon': self.epsilon})
        return merged
def activation_layer(activation):
    """Build a new activation layer from a name or a Layer class.

    Arguments
        - **activation**: one of the following.
          - ``"dice"`` or ``"Dice"`` creates a new :class:`Dice` layer.
          - Any other string is passed to ``tf.keras.layers.Activation``.
          - A subclass of ``Layer`` is instantiated with no arguments.

    Returns
        A new layer instance on every call.

    Raises
        ValueError: if ``activation`` is none of the accepted forms.
    """
    # Local import: sys is only needed for the Python 2 unicode check below.
    import sys

    if activation == "dice" or activation == "Dice":
        return Dice()
    if isinstance(activation, str) or (
            sys.version_info.major == 2 and isinstance(activation, (str, unicode))):  # noqa: F821 (Python 2 only)
        return tf.keras.layers.Activation(activation)
    # issubclass() raises TypeError for non-class arguments, so require a class first.
    if isinstance(activation, type) and issubclass(activation, Layer):
        return activation()
    # Wrap in a 1-tuple so a tuple-valued `activation` cannot break the formatting.
    raise ValueError(
        "Invalid activation,found %s.You should use a str or a Activation Layer Class." % (activation,))
#-------------------------------------------------------------------------------
class AttentionSequencePoolingLayer(Layer):
    """The Attentional sequence pooling operation used in DIN.

    Scores every key against the query with a :class:`LocalActivationUnit`.
    The scores are optionally softmax-normalized over the T keys and then
    used to take a weighted sum of the keys.

    Input shape
        - A list of two tensors: [query, keys]
        - query is a 3D tensor with shape: ``(batch_size, 1, embedding_size)``
        - keys is a 3D tensor with shape: ``(batch_size, T, embedding_size)``
    Output shape
        - 3D tensor with shape ``(batch_size, 1, embedding_size)``.
        - When ``return_score`` is True, the shape is instead ``(batch_size, 1, T)``.
    Arguments
        - **att_hidden_units**: list of positive integers, the attention net layer number and units in each layer.
        - **att_activation**: Activation function to use in attention net.
        - **weight_normalization**: bool. Whether to softmax-normalize the attention scores over the T keys.
        - **return_score**: bool. If True, return the attention scores instead of the weighted sum of keys.
        - **supports_masking**: bool, stored as the layer's ``supports_masking`` flag.
          NOTE(review): the incoming ``mask`` is currently not applied to the attention scores.
          Padded key positions therefore still receive weight; confirm this is intended.
    References
        - [Zhou G, Zhu X, Song C, et al. Deep interest network for click-through rate prediction[C]//Proceedings of the 24th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining. ACM, 2018: 1059-1068.](https://arxiv.org/pdf/1706.06978.pdf)
    """

    def __init__(self, att_hidden_units=(80, 40), att_activation='sigmoid', weight_normalization=False,
                 return_score=False,
                 supports_masking=False, **kwargs):
        self.att_hidden_units = att_hidden_units
        self.att_activation = att_activation
        self.weight_normalization = weight_normalization
        self.return_score = return_score
        super(AttentionSequencePoolingLayer, self).__init__(**kwargs)
        self.supports_masking = supports_masking

    def build(self, input_shape):
        # Scoring network. Its l2_reg / dropout / batch-norm / seed settings are fixed here.
        self.local_att = LocalActivationUnit(self.att_hidden_units, self.att_activation,
                                             l2_reg=0, dropout_rate=0, use_bn=False, seed=1024)
        super(AttentionSequencePoolingLayer, self).build(input_shape)  # Be sure to call this somewhere!

    def call(self, inputs, mask=None, training=None, **kwargs):
        queries, keys = inputs
        # (batch_size, T, 1): one score per key.
        attention_score = self.local_att([queries, keys], training=training)
        # -> (batch_size, 1, T) so the scores can weight the T keys in batch_dot.
        outputs = tf.transpose(attention_score, (0, 2, 1))
        if self.weight_normalization:
            outputs = tf.keras.activations.softmax(outputs)
        if not self.return_score:
            # (batch_size, 1, T) x (batch_size, T, embedding_size) -> (batch_size, 1, embedding_size)
            outputs = tf.keras.backend.batch_dot(outputs, keys)
        # Legacy Keras learning-phase bookkeeping. Versions are compared numerically:
        # plain string comparison misorders releases such as "1.9.0" vs "1.13.0".
        tf_major, tf_minor = (int(part) for part in tf.__version__.split(".")[:2])
        if (tf_major, tf_minor) < (1, 13):
            outputs._uses_learning_phase = attention_score._uses_learning_phase
        else:
            outputs._uses_learning_phase = training is not None
        return outputs

    def compute_output_shape(self, input_shape):
        if self.return_score:
            return (None, 1, input_shape[1][1])
        else:
            return (None, 1, input_shape[0][-1])

    def compute_mask(self, inputs, mask):
        # The pooled output carries no mask downstream.
        return None

    def get_config(self, ):
        config = {'att_hidden_units': self.att_hidden_units, 'att_activation': self.att_activation,
                  'weight_normalization': self.weight_normalization, 'return_score': self.return_score,
                  'supports_masking': self.supports_masking}
        base_config = super(AttentionSequencePoolingLayer, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
#-------------------------------------------------------------------------------------------------------------------------------------
#-------------------------------------------------------------------------------------------------------------------------------------
#-------------------------------------------------------------------------------------------------------------------------------------
class LocalActivationUnit(Layer):
    """The LocalActivationUnit used in DIN, with which the representation of
    user interests varies adaptively given different candidate items.

    For each of the T keys the layer builds the feature vector
    ``[query, key, query - key, query * key]``. It runs that vector through a
    :class:`DNN` and projects the result to a single score with a ``(size, 1)``
    kernel plus bias.

    Input shape
        - A list of two 3D tensors with shape: ``(batch_size, 1, embedding_size)`` and ``(batch_size, T, embedding_size)``
    Output shape
        - 3D tensor with shape: ``(batch_size, T, 1)``.
    Arguments
        - **hidden_units**: list of positive integers, the attention net layer number and units in each layer.
        - **activation**: Activation function to use in attention net.
        - **l2_reg**: float between 0 and 1. L2 regularizer strength applied to the kernel weights matrix of attention net.
        - **dropout_rate**: float in [0,1). Fraction of the units to dropout in attention net.
        - **use_bn**: bool. Whether use BatchNormalization before activation or not in attention net.
        - **seed**: A Python integer to use as random seed.
    References
        - [Zhou G, Zhu X, Song C, et al. Deep interest network for click-through rate prediction[C]//Proceedings of the 24th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining. ACM, 2018: 1059-1068.](https://arxiv.org/pdf/1706.06978.pdf)
    """

    def __init__(self, hidden_units=(64, 32), activation='sigmoid', l2_reg=0, dropout_rate=0, use_bn=False, seed=1024,
                 **kwargs):
        self.hidden_units = hidden_units
        self.activation = activation
        self.l2_reg = l2_reg
        self.dropout_rate = dropout_rate
        self.use_bn = use_bn
        self.seed = seed
        super(LocalActivationUnit, self).__init__(**kwargs)
        self.supports_masking = True

    def build(self, input_shape):
        if not isinstance(input_shape, list) or len(input_shape) != 2:
            raise ValueError('A `LocalActivationUnit` layer should be called '
                             'on a list of 2 inputs')
        if len(input_shape[0]) != 3 or len(input_shape[1]) != 3:
            raise ValueError("Unexpected inputs dimensions %d and %d, expect to be 3 dimensions" % (
                len(input_shape[0]), len(input_shape[1])))
        if input_shape[0][-1] != input_shape[1][-1] or input_shape[0][1] != 1:
            # Pass both shapes explicitly. `%` with the list itself would raise
            # TypeError ("not enough arguments") instead of this ValueError.
            raise ValueError('A `LocalActivationUnit` layer requires '
                             'inputs of a two inputs with shape (None,1,embedding_size) and (None,T,embedding_size). '
                             'Got different shapes: %s,%s' % (input_shape[0], input_shape[1]))
        # Width of the DNN output that the scoring kernel projects. With no hidden
        # layers the DNN returns its input, i.e. the 4 * embedding_size concatenation.
        size = 4 * int(input_shape[0][-1]) if len(self.hidden_units) == 0 else self.hidden_units[-1]
        self.kernel = self.add_weight(shape=(size, 1), initializer=glorot_normal(seed=self.seed), name="kernel")
        self.bias = self.add_weight(shape=(1,), initializer=Zeros(), name="bias")
        self.dnn = DNN(self.hidden_units, self.activation, self.l2_reg, self.dropout_rate, self.use_bn, seed=self.seed)
        # Final projection: tensordot over the last axis with the kernel, then add the bias.
        self.dense = tf.keras.layers.Lambda(lambda x: tf.nn.bias_add(tf.tensordot(x[0], x[1], axes=(-1, 0)), x[2]))
        super(LocalActivationUnit, self).build(input_shape)  # Be sure to call this somewhere!

    def call(self, inputs, training=None, **kwargs):
        query, keys = inputs
        # Static T; the query is repeated T times along axis 1 to align with the keys.
        keys_len = keys.get_shape()[1]
        queries = K.repeat_elements(query, keys_len, 1)
        att_input = tf.keras.layers.concatenate([queries, keys, queries - keys, queries * keys], axis=-1)
        att_out = self.dnn(att_input, training=training)
        attention_score = self.dense([att_out, self.kernel, self.bias])
        return attention_score

    def compute_output_shape(self, input_shape):
        return input_shape[1][:2] + (1,)

    def compute_mask(self, inputs, mask):
        return mask

    def get_config(self, ):
        config = {'activation': self.activation, 'hidden_units': self.hidden_units,
                  'l2_reg': self.l2_reg, 'dropout_rate': self.dropout_rate, 'use_bn': self.use_bn, 'seed': self.seed}
        base_config = super(LocalActivationUnit, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
#-------------------------------------------------------------------------------------------------------------------------------------
#-------------------------------------------------------------------------------------------------------------------------------------
#-------------------------------------------------------------------------------------------------------------------------------------
class DNN(Layer):
    """Multi-layer perceptron applied along the last axis.

    Each hidden layer computes ``x @ W_i + b_i`` and then applies, in order:
    BatchNormalization (optional), the configured activation, and Dropout.

    Input shape
        - nD tensor with shape: ``(batch_size, ..., input_dim)``. The most common situation would be a 2D input with shape ``(batch_size, input_dim)``.
    Output shape
        - nD tensor with shape: ``(batch_size, ..., hidden_units[-1])``.
        - The shape is unchanged when ``hidden_units`` is empty.
    Arguments
        - **hidden_units**: list of positive integers, the number of layers and the units in each.
        - **activation**: Activation spec passed to ``activation_layer`` (a string, ``"dice"``, or a Layer subclass).
        - **l2_reg**: float. L2 regularizer strength applied to each kernel matrix.
        - **dropout_rate**: float in [0,1). Fraction of units dropped after each layer.
        - **use_bn**: bool. Whether to apply BatchNormalization before the activation.
        - **seed**: Python integer used as the kernel initializer seed. Dropout layer ``i`` uses ``seed + i``.
    """

    def __init__(self, hidden_units, activation='relu', l2_reg=0, dropout_rate=0, use_bn=False, seed=1024, **kwargs):
        self.hidden_units = hidden_units
        self.activation = activation
        self.l2_reg = l2_reg
        self.dropout_rate = dropout_rate
        self.use_bn = use_bn
        self.seed = seed
        super(DNN, self).__init__(**kwargs)

    def build(self, input_shape):
        depth = len(self.hidden_units)
        # Widths of every layer boundary: [input_dim, h_0, h_1, ...].
        widths = [int(input_shape[-1])] + list(self.hidden_units)
        # All kernels are created before all biases (weight creation order matters for checkpoints).
        self.kernels = [
            self.add_weight(name='kernel%d' % i,
                            shape=(fan_in, fan_out),
                            initializer=glorot_normal(seed=self.seed),
                            regularizer=l2(self.l2_reg),
                            trainable=True)
            for i, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:]))
        ]
        self.bias = [
            self.add_weight(name='bias%d' % i,
                            shape=(units,),
                            initializer=Zeros(),
                            trainable=True)
            for i, units in enumerate(self.hidden_units)
        ]
        if self.use_bn:
            self.bn_layers = [tf.keras.layers.BatchNormalization() for _ in self.hidden_units]
        self.dropout_layers = [tf.keras.layers.Dropout(self.dropout_rate, seed=self.seed + i) for i in range(depth)]
        self.activation_layers = [activation_layer(self.activation) for _ in self.hidden_units]
        super(DNN, self).build(input_shape)

    def call(self, inputs, training=None, **kwargs):
        hidden = inputs
        for idx, (kernel, bias) in enumerate(zip(self.kernels, self.bias)):
            hidden = tf.keras.backend.bias_add(tf.tensordot(hidden, kernel, axes=(-1, 0)), bias)
            if self.use_bn:
                hidden = self.bn_layers[idx](hidden, training=training)
            hidden = self.activation_layers[idx](hidden)
            hidden = self.dropout_layers[idx](hidden, training=training)
        return hidden

    def compute_output_shape(self, input_shape):
        if len(self.hidden_units) == 0:
            return tuple(input_shape)
        return tuple(input_shape[:-1] + (self.hidden_units[-1],))

    def get_config(self):
        merged = dict(super(DNN, self).get_config())
        merged.update({'activation': self.activation, 'hidden_units': self.hidden_units,
                       'l2_reg': self.l2_reg, 'use_bn': self.use_bn, 'dropout_rate': self.dropout_rate,
                       'seed': self.seed})
        return merged
#-------------------------------------------------------------------------------
#-------------------------------------------------------------------------------
class NoMask(tf.keras.layers.Layer):
    """Identity layer that drops any incoming Keras mask.

    ``call`` returns its input unchanged, and ``compute_mask`` always reports
    no mask. Layers after this one therefore receive ``mask=None``.
    """

    def __init__(self, **kwargs):
        super(NoMask, self).__init__(**kwargs)

    def build(self, input_shape):
        # No weights to create; just mark the layer as built.
        super(NoMask, self).build(input_shape)

    def call(self, x, mask=None, **kwargs):
        # Pass-through: the mask argument is accepted and ignored.
        return x

    def compute_mask(self, inputs, mask):
        # Explicitly stop mask propagation.
        return None
#-------------------------------------------------------------------------------
#-------------------------------------------------------------------------------
class PredictionLayer(Layer):
    """Final output layer: optional global bias, optional sigmoid, reshape to one column.

    Arguments
        - **task**: str, one of ``"binary"``, ``"multiclass"`` or ``"regression"``.
          ``"binary"`` applies a sigmoid; the other tasks leave the values unchanged.
        - **use_bias**: bool. Whether to add a single learned global bias before the activation.

    Output shape
        - The result is reshaped to ``(-1, 1)``.
        - NOTE(review): for ``"multiclass"`` inputs with more than one column this
          flattens the classes into extra rows; confirm that is intended.

    Raises
        ValueError: if ``task`` is not one of the supported values.
    """

    def __init__(self, task='binary', use_bias=True, **kwargs):
        if task not in ["binary", "multiclass", "regression"]:
            raise ValueError("task must be binary,multiclass or regression")
        self.task = task
        self.use_bias = use_bias
        super(PredictionLayer, self).__init__(**kwargs)

    def build(self, input_shape):
        if self.use_bias:
            self.global_bias = self.add_weight(
                shape=(1,), initializer=Zeros(), name="global_bias")
        # Be sure to call this somewhere!
        super(PredictionLayer, self).build(input_shape)

    def call(self, inputs, **kwargs):
        x = inputs
        if self.use_bias:
            x = tf.keras.backend.bias_add(x, self.global_bias, data_format='channels_last')
        if self.task == "binary":
            x = tf.keras.activations.sigmoid(x)
        # Reshape directly instead of building a new Lambda layer on every call.
        return tf.reshape(x, (-1, 1))

    def compute_output_shape(self, input_shape):
        return (None, 1)

    def get_config(self, ):
        config = {'task': self.task, 'use_bias': self.use_bias}
        base_config = super(PredictionLayer, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
def Reshape(tensor_input):
    """Return ``tensor_input`` reshaped into a single column, i.e. shape ``(-1, 1)``."""
    return tf.reshape(tensor_input, (-1, 1))
def Transpose(inp):
    """Swap the last two axes of a rank-3 tensor: ``(a, b, c) -> (a, c, b)``."""
    return tf.transpose(inp, perm=(0, 2, 1))