-
Notifications
You must be signed in to change notification settings - Fork 9
/
modules.py
476 lines (387 loc) · 20.3 KB
/
modules.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
import sys
import inspect
from typing import *
import torch
import torch.nn as nn
from pytorch_lightning.core.lightning import LightningModule
from utils.exp_utils import logging
class PositionEncoding(LightningModule):
    """Sinusoidal position encodings.

    For every input position ``p`` the encoding is
    ``[sin(p * f_0), ..., sin(p * f_{m-1}), cos(p * f_0), ..., cos(p * f_{m-1})]``
    with ``f_k = 10000 ** (-2k / d_pos_enc)`` and ``m = ceil(d_pos_enc / 2)``.
    """

    def __init__(self, d_pos_enc):
        super().__init__()
        # Exponents -0/d, -2/d, -4/d, ... (one per sin/cos pair).
        exponents = -torch.arange(0.0, d_pos_enc, 2.0) / d_pos_enc
        # A buffer rather than a parameter: it is stored in the state_dict
        # and moved between devices together with the module, but never
        # trained.
        self.register_buffer("inv_freq", 10000 ** exponents)

    def forward(self, n_model: torch.LongTensor) -> torch.FloatTensor:
        """Encode a 1-D tensor of positions.

        Returns a tensor of shape ``(len(n_model), 2 * ceil(d_pos_enc / 2))``.
        """
        positions = n_model.float()
        # Column (n, 1) times row (1, m) broadcasts to the (n, m) outer product.
        angles = positions.unsqueeze(1) * self.inv_freq.unsqueeze(0)
        return torch.cat((torch.sin(angles), torch.cos(angles)), dim=-1)
class MultiHeadAttention(LightningModule):
    """Multi-head self-attention with relative position encodings.

    This follows the Transformer-XL formulation. Queries are computed from
    the current segment only. Keys and values are computed from the cached
    memory concatenated with the current segment. The raw score between
    query ``i`` and key ``j`` is the sum of:

    * a content term: ``(q_i + u) . k_j``
    * a position term: ``(q_i + v) . R_{i-j}``, where ``R`` are the projected
      relative position encodings, aligned with ``_rel_shift``.

    The attended values are projected back to ``d_output``, added to the
    input (residual) and layer-normalised.

    Every sequence tensor is batch-first: ``(n_batch, seq_len, features)``.
    """

    def __init__(self, d_input: int, d_output: int, d_pos_enc: int, n_head: int,
                 d_head: int, dropout: float, dropout_attn: float,
                 debug: bool = False, skip_debug: Optional[List[str]] = None):
        super().__init__()
        self.debug = debug
        # Class names whose debug logging is suppressed. Copy the list so no
        # caller-owned (or shared default) list is aliased. ``None`` means
        # "suppress nothing", as the old ``[None]`` default did.
        self.skip_debug = list(skip_debug) if skip_debug is not None else []
        # TODO: Relax this constraint (the residual connection requires it).
        assert d_input == d_output, \
            "MultiHeadAttention: d_input must equal d_output"
        self.d_input, self.d_output = d_input, d_output
        self.d_head = d_head
        self.n_head = n_head
        # Queries only come from the current segment, never from memory.
        # No bias: these are plain projections.
        self.linear_q = nn.Linear(d_input, d_head * n_head, bias=False)
        # Keys and values for all heads in a single matmul (hence the ``* 2``).
        self.linear_kv = nn.Linear(d_input, d_head * n_head * 2, bias=False)
        # Projects the sinusoidal position encodings into head space.
        self.linear_p = nn.Linear(d_pos_enc, d_head * n_head, bias=False)
        # Scaled dot-product attention factor.
        self.scale = 1 / (d_head ** 0.5)
        self.dropout_attn = nn.Dropout(dropout_attn)
        # Projects the concatenated heads back to the output dimension.
        self.layer_out = nn.Linear(self.d_head * self.n_head, self.d_output,
                                   bias=False)
        self.norm_out = nn.LayerNorm(self.d_output)
        self.dropout = nn.Dropout(dropout)

    # ----------------------------------------------------------------- debug
    def _debug_enabled(self) -> bool:
        """True when debug logging is on and not suppressed for this class."""
        return self.debug and self.__class__.__name__ not in self.skip_debug

    def _emit(self, caller: str, text: str) -> None:
        """Write one debug line in the ``"<Class>, <caller>:  <text>"`` format."""
        logging(f"{self.__class__.__name__}, {caller}:  {text}")

    def _log_size(self, name: str, tensor: Optional[torch.Tensor]) -> None:
        """Log ``tensor.size()`` (or ``None`` for a missing tensor)."""
        if self._debug_enabled():
            size = None if tensor is None else tensor.size()
            self._emit(inspect.currentframe().f_back.f_code.co_name,
                       f"{name}: {size}")

    def _log_linear(self, name: str, layer: nn.Linear) -> None:
        """Log the in/out feature sizes of a linear layer."""
        if self._debug_enabled():
            self._emit(inspect.currentframe().f_back.f_code.co_name,
                       f"{name}: in {layer.in_features} x "
                       f"out {layer.out_features}")

    # ------------------------------------------------------------- attention
    def _rel_shift(self, x):
        """Align position scores so column ``j`` of row ``i`` uses distance ``i - j``.

        ``x`` is ``(n_batch, n_query, n_key, ...)``, where the key axis was
        scored against position encodings ordered from the largest distance
        to the smallest. This is the padding/reshape trick from
        Transformer-XL, applied along the (query, key) axes of each batch
        element independently so that scores never leak across the batch.
        """
        n_batch, n_query, n_key = x.shape[:3]
        trailing = tuple(x.shape[3:])
        # DIMS: zero_pad -> (n_batch, n_query, 1, ...)
        zero_pad = torch.zeros((n_batch, n_query, 1) + trailing,
                               device=x.device, dtype=x.dtype)
        # DIMS: padded -> (n_batch, n_query, 1 + n_key, ...)
        padded = torch.cat([zero_pad, x], dim=2)
        # Reinterpret each batch element's block as (1 + n_key, n_query, ...)
        # and drop the first row; this shifts row i left by (n_query - 1 - i).
        # DIMS: reinterpreted -> (n_batch, 1 + n_key, n_query, ...)
        reinterpreted = padded.view((n_batch, n_key + 1, n_query) + trailing)
        # The slice is not contiguous across the batch axis, so use reshape.
        return reinterpreted[:, 1:].reshape(x.shape)

    def forward(self, segment: torch.FloatTensor, pos_encs: torch.FloatTensor,
                memories: torch.FloatTensor, u: torch.FloatTensor,
                v: torch.FloatTensor, mask: Optional[torch.FloatTensor] = None,
                ):
        """Attend from ``segment`` over ``memories`` + ``segment``.

        Args:
            segment: current segment, ``(n_batch, n_model, d_input)``.
            pos_encs: position encodings, ``(n_mems + n_model, d_pos_enc)``.
                One row per key, ordered from the largest relative distance
                to the smallest (the order the caller builds them in).
            memories: cached states, ``(n_batch, n_mems, d_input)``. ``None``
                or an empty tensor (e.g. ``torch.empty(0)``) means no memory.
            u: content bias, ``(n_head, d_head)``.
            v: position bias, ``(n_head, d_head)``.
            mask: optional boolean mask where True marks forbidden
                (query, key) pairs. It is indexed as ``mask[None, :, :, :]``,
                so it is expected to be 3-D, presumably
                ``(n_model, n_mems + n_model, 1)`` as built by the caller.

        Returns:
            ``(n_batch, n_model, d_output)`` tensor: the layer-normalised
            residual sum of ``segment`` and the attention output.
        """
        self._log_size("segment", segment)
        self._log_size("pos_encs", pos_encs)
        self._log_size("memories", memories)
        self._log_size("u", u)
        self._log_size("v", v)
        self._log_size("mask", mask)

        n_batch, n_model, _ = segment.shape
        n_head, d_head = self.n_head, self.d_head

        # Memory is concatenated in front of the segment along the sequence
        # axis (dim 1), so its length is size(1) -- not size(0), the batch.
        if memories is None or memories.numel() == 0:
            n_current_mems = 0
            memory_cat_input = segment
        else:
            n_current_mems = memories.size(1)
            memory_cat_input = torch.cat([memories, segment], dim=1)
        n_keys = n_current_mems + n_model
        # DIMS: memory_cat_input -> (n_batch, n_keys, d_input)
        self._log_size("memory_cat_input", memory_cat_input)

        # DIMS: queries -> (n_batch, n_model, n_head * d_head)
        self._log_linear("linear_q", self.linear_q)
        queries = self.linear_q(segment)
        self._log_size("queries", queries)
        queries = queries.view(n_batch, n_model, n_head, d_head)

        # DIMS: keys, values -> (n_batch, n_keys, n_head * d_head)
        keys, values = torch.chunk(self.linear_kv(memory_cat_input), 2, dim=-1)
        self._log_size("keys", keys)
        self._log_size("values", values)
        keys = keys.view(n_batch, n_keys, n_head, d_head)
        values = values.view(n_batch, n_keys, n_head, d_head)

        # Content-based term ((a) + (c) in the paper).
        # DIMS: content_attn -> (n_batch, n_model, n_keys, n_head)
        content_attn = torch.einsum("bihd,bjhd->bijh", queries + u, keys)
        self._log_size("content_attn", content_attn)

        # Position-based term ((b) + (d) in the paper); it depends only on
        # the relative position of each key, not on its content.
        # DIMS: positions -> (n_keys, n_head * d_head)
        self._log_linear("linear_p", self.linear_p)
        positions = self.linear_p(pos_encs)
        self._log_size("positions", positions)
        positions = positions.view(n_keys, n_head, d_head)
        # DIMS: position_attn -> (n_batch, n_model, n_keys, n_head)
        position_attn = torch.einsum("bihd,jhd->bijh", queries + v, positions)
        self._log_size("position_attn", position_attn)
        position_attn = self._rel_shift(position_attn)

        # DIMS: attn -> (n_batch, n_model, n_keys, n_head)
        attn = content_attn + position_attn
        self._log_size("attn", attn)

        if mask is not None and mask.any().item():
            # Prepend a batch axis so the mask broadcasts over the batch.
            padded_mask = mask[None, :, :, :]
            self._log_size("padded_mask", padded_mask)
            attn = attn.masked_fill(padded_mask, -float('inf'))
            self._log_size("attn", attn)

        # Scale, then normalise over the KEY axis (dim 2); the weighted sum
        # below contracts that same axis ``j``.
        attn = torch.softmax(attn * self.scale, dim=2)
        attn = self.dropout_attn(attn)

        # DIMS: attn_weighted_values -> (n_batch, n_model, n_head * d_head)
        attn_weighted_values = torch.einsum(
            "bijh,bjhd->bihd", attn, values,
        ).reshape(n_batch, n_model, n_head * d_head)
        self._log_size("attn_weighted_values", attn_weighted_values)

        # Project back to d_output, add the residual and normalise.
        output = segment + self.dropout(self.layer_out(attn_weighted_values))
        output = self.norm_out(output)
        self._log_size("output", output)
        return output
class Positionwise_FeedFwd(LightningModule):
    """Position-wise feed-forward sub-layer with a residual connection.

    Computes ``LayerNorm(x + Dropout(W2 Dropout(ReLU(W1 x + b1)) + b2))``,
    applied independently at every position of the last axis.
    """

    def __init__(self, d_input, d_FF_inner, dropout):
        super().__init__()
        self.d_input, self.d_FF_inner, self.dropout = d_input, d_FF_inner, dropout
        # Widen to the inner size, then project back to the input size.
        expand = nn.Linear(d_input, d_FF_inner)
        contract = nn.Linear(d_FF_inner, d_input)
        self.ff = nn.Sequential(
            expand,
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            contract,
            nn.Dropout(dropout),
        )
        self.layer_norm = nn.LayerNorm(d_input)

    def forward(self, sequence: torch.FloatTensor) -> torch.FloatTensor:
        """Return ``layer_norm(sequence + ff(sequence))``; the shape is unchanged."""
        residual = sequence
        transformed = self.ff(sequence)
        return self.layer_norm(residual + transformed)
class DecoderBlock(LightningModule):
    """One decoder layer: relative multi-head attention, then feed-forward.

    Both sub-layers apply their own residual connection and LayerNorm.
    """

    def __init__(self, n_head: int, d_input: int, d_output: int, d_pos_enc: int,
                 d_head_inner: int, d_FF_inner: int, dropout: float = 0.0,
                 dropout_attn: float = 0.0, debug: bool = False,
                 skip_debug: Optional[List[str]] = None):
        super().__init__()
        self.debug = debug
        # Avoid a shared mutable default; ``None`` means "suppress nothing".
        self.skip_debug = list(skip_debug) if skip_debug is not None else []
        # NOTE: the attribute name (including the "heat" typo) is kept
        # because it determines the state_dict keys of saved checkpoints.
        self.multi_heat_attn = MultiHeadAttention(d_input, d_output, d_pos_enc,
                                                  n_head=n_head,
                                                  d_head=d_head_inner,
                                                  dropout=dropout,
                                                  dropout_attn=dropout_attn,
                                                  debug=debug,
                                                  skip_debug=self.skip_debug)
        self.feed_fwd = Positionwise_FeedFwd(d_input, d_FF_inner, dropout)

    def forward(self, input_: torch.FloatTensor,  # (n_batch, n_model, d_input)
                pos_embs: torch.FloatTensor,  # (n_mems + n_model, d_pos_enc)
                u: torch.FloatTensor,  # (n_head, d_head)
                v: torch.FloatTensor,  # (n_head, d_head)
                mask=None,
                mems=None,
                ):
        """Run attention over ``mems`` + ``input_``, then the feed-forward net.

        ``mems=None`` means "no cached memory". It is replaced by an empty
        tensor because the attention layer dereferences the memory tensor
        and would fail on ``None``.
        """
        if mems is None:
            mems = input_.new_empty(0)
        attended = self.multi_heat_attn(input_, pos_embs, mems, u, v, mask=mask)
        return self.feed_fwd(attended)
class Transformer_XL(LightningModule):
    """Stack of decoder blocks with Transformer-XL segment-level memory.

    Layout conventions: ``data`` is batch-first, ``(n_batch, n_model,
    d_model)``. The attention layers concatenate memory and the current
    segment along dim 1, so memory tensors use the same batch-first layout,
    ``(n_batch, n_mems, d)``. An empty 1-D tensor (as produced by
    ``init_memory``) stands for "no memory yet".
    """

    def __init__(self, n_layer: int, d_hidden, d_pos_enc, n_head: int,
                 d_head: int, d_FF_inner: int, d_model: int, dropout: float,
                 dropout_attn: float, n_model: int, n_mems: int,
                 debug: bool = False, skip_debug: Optional[List[str]] = None):
        super().__init__()
        self.debug = debug
        # Avoid a shared mutable default; ``None`` means "suppress nothing".
        self.skip_debug = list(skip_debug) if skip_debug is not None else []
        self.n_layer = n_layer
        self.n_head, self.d_head, self.d_ff_inner = n_head, d_head, d_FF_inner
        self.d_model = d_model
        # Position encoding (shared by all layers).
        self.pos_enc = PositionEncoding(d_pos_enc)
        # Core transformer.
        self.dropout = nn.Dropout(dropout)
        self.layers = nn.ModuleList()
        for idx in range(n_layer):
            # All layers map d_hidden -> d_hidden, except that layer 0
            # consumes d_model and layer n_layer - 1 produces d_model.
            input_size = d_model if idx == 0 else d_hidden
            output_size = d_model if idx == n_layer - 1 else d_hidden
            self.layers.append(
                DecoderBlock(n_head, input_size, output_size, d_pos_enc,
                             d_head_inner=d_head, d_FF_inner=d_FF_inner,
                             dropout=dropout, dropout_attn=dropout_attn,
                             debug=debug, skip_debug=self.skip_debug))
        # TODO explore other loss functions e.g. CTCLoss()
        self.loss_fn = nn.MSELoss()
        self.seq_len, self.mem_len = n_model, n_mems
        # u and v are global parameters since position encodings are too.
        self.u = nn.Parameter(
            torch.zeros(self.n_head, self.d_head, dtype=torch.float))
        self.v = nn.Parameter(
            torch.zeros(self.n_head, self.d_head, dtype=torch.float))

    def _log(self, name: str, value) -> None:
        """Emit ``"<Class>, <caller>:  <name>: <value>"`` when debugging is on."""
        if self.debug and (self.__class__.__name__ not in self.skip_debug):
            caller = inspect.currentframe().f_back.f_code.co_name
            logging(f"{self.__class__.__name__}, {caller}:  {name}: {value}")

    @staticmethod
    def _memory_length(memory: torch.Tensor) -> int:
        """Number of cached time steps in a batch-first memory tensor.

        The empty 1-D placeholder from ``init_memory`` has length 0.
        """
        return memory.size(1) if memory.dim() > 1 else 0

    def init_memory(self, device=torch.device("cpu")) -> List[
            torch.FloatTensor]:
        """Return ``n_layer + 1`` empty placeholders, one per hidden state."""
        return [torch.empty(0).to(device=device, dtype=torch.float)
                for _ in range(self.n_layer + 1)]

    def update_memory(self,
                      previous_memory: List[torch.FloatTensor],
                      hidden_states: List[torch.FloatTensor], ):
        """Keep the most recent ``self.mem_len`` steps of memory + new states.

        Both lists hold one batch-first tensor per layer input. When
        ``seq_len < self.mem_len``, part of the previous memory carries over.
        The result is detached, so no gradient flows into the next segment.
        """
        assert len(hidden_states) == len(previous_memory)
        mem_len = self._memory_length(previous_memory[0])
        seq_len = hidden_states[0].size(1)
        end_idx = mem_len + seq_len
        beg_idx = max(0, end_idx - self.mem_len)
        with torch.no_grad():
            new_memory = []
            for m, h in zip(previous_memory, hidden_states):
                # DIMS: cat -> (n_batch, mem_len + seq_len, d)
                cat = h if m.numel() == 0 else torch.cat([m, h], dim=1)
                # DIMS: -> (n_batch, <= self.mem_len, d)
                new_memory.append(cat[:, beg_idx:end_idx].detach())
        return new_memory

    def reset_length(self, seq_len, ext_len, mem_len):
        """Change segment / memory lengths; ``ext_len`` is accepted but unused."""
        self.seq_len = seq_len
        self.mem_len = mem_len

    def forward(self, data: torch.FloatTensor,
                target: torch.FloatTensor,
                memory: Optional[List[torch.FloatTensor]] = None,
                ) -> Dict[str, torch.Tensor]:
        """Run one segment through the stack and compute the loss.

        Args:
            data: ``(n_batch, n_model, d_model)`` input segment.
            target: regression target for the LAST time step. It is compared
                by MSE against ``layer_out[:, -1, :]``, so it presumably has
                shape ``(n_batch, d_model)`` -- confirm against the caller.
            memory: ``n_layer + 1`` batch-first memory tensors from the
                previous call. ``None`` starts from an empty memory.

        Returns:
            dict with ``"loss"``, ``"layer_out"`` (final layer output after
            dropout) and ``"memory"`` (updated, detached memory).
        """
        self._log("data", data.size())
        self._log("target", target.size())
        if memory is None:
            memory = self.init_memory(data.device)
        self._log("memory", len(memory))
        assert len(memory) == len(self.layers) + 1

        _, n_sequence, _ = data.size()
        # Cached length is along the sequence axis (dim 1), not the batch.
        prev_seq = self._memory_length(memory[0])

        # Causal mask: query i may see every memory slot and keys up to
        # itself; True marks forbidden (query, key) pairs.
        # DIMS: dec_attn_mask -> (n_sequence, n_sequence + prev_seq, 1)
        dec_attn_mask = torch.triu(
            torch.ones((n_sequence, n_sequence + prev_seq), device=data.device),
            diagonal=1 + prev_seq,
        ).bool()[..., None]
        self._log("dec_attn_mask", dec_attn_mask.size())

        current_segment = self.dropout(data)
        # Relative distances from largest to smallest, one per key.
        pos_idxs = torch.arange(n_sequence + prev_seq - 1, -1, -1.0,
                                dtype=torch.float, device=current_segment.device)
        pos_embs = self.dropout(self.pos_enc(pos_idxs))

        # Main part of the forward pass. hidden_states[l] is the input of
        # layer l and becomes that layer's memory for the next segment.
        hidden_states = [current_segment]
        layer_out = current_segment
        for mem, layer in zip(memory, self.layers):
            layer_out = layer(layer_out, pos_embs, self.u, self.v,
                              mask=dec_attn_mask, mems=mem)
            hidden_states.append(layer_out)
            layer_out = self.dropout(layer_out)
            self._log("layer_out", layer_out.size())

        # Only the last time step is regressed onto the target.
        loss = self.loss_fn(layer_out[:, -1, :], target)

        # The refreshed memory is detached inside update_memory, so it is
        # treated as a constant and not back-propagated through.
        new_memory = self.update_memory(memory, hidden_states)
        return {"loss": loss, "layer_out": layer_out, "memory": new_memory}