-
Notifications
You must be signed in to change notification settings - Fork 123
/
pretraining_model.py
96 lines (82 loc) · 4.59 KB
/
pretraining_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import torch
import torch.nn as nn
import numpy as np
def create_sinusoidal_embeddings(embeds):
    """Overwrite an embedding table with fixed sinusoidal position encodings.

    For row ``pos`` and column ``j`` the angle is
    ``pos / 10000 ** (2 * (j // 2) / embedding_dim)``; even columns store the
    sine of that angle and odd columns its cosine. After the table is filled
    the weight is detached and frozen (``requires_grad = False``).

    Args:
        embeds: an ``nn.Embedding``-like module exposing ``num_embeddings``,
            ``embedding_dim`` and a ``weight`` tensor of shape
            ``[num_embeddings, embedding_dim]``. Modified in place.

    Returns:
        None.
    """
    dim = embeds.embedding_dim
    weight = embeds.weight

    # Computed in float64 and cast once at the end to limit rounding error
    # in the angles for large positions.
    positions = torch.arange(embeds.num_embeddings, dtype=torch.float64).unsqueeze(1)
    # Columns 2k and 2k + 1 share the exponent 2k / dim, so a single column
    # of angles per (sin, cos) pair is enough.
    exponents = torch.arange(0, dim, 2, dtype=torch.float64) / dim
    angles = positions / torch.pow(10000.0, exponents)

    # Autograd rejects in-place writes into a view of a leaf tensor that
    # requires grad, so the table is filled with gradient tracking disabled.
    with torch.no_grad():
        weight[:, 0::2] = torch.sin(angles).to(weight)
        # With an odd embedding_dim there is one fewer odd column than even.
        weight[:, 1::2] = torch.cos(angles[:, :dim // 2]).to(weight)

    weight.detach_()
    weight.requires_grad = False
class Transformer(nn.Module):
    """ Transformer body (labelled "GPT-2 architecture" by the original author):
    token + position embeddings followed by a stack of self-attention and
    feed-forward layers, each preceded by a LayerNorm. """

    def __init__(self, embed_dim, hidden_dim, num_embeddings, num_max_positions, num_heads, num_layers, dropout,
                 sinusoidal_embeddings, causal=False):
        """ Build the embeddings and `num_layers` attention / feed-forward layers.

        embed_dim: width of the embeddings and of every layer's output.
        hidden_dim: inner width of each feed-forward block.
        num_embeddings: number of rows in the token embedding table.
        num_max_positions: number of rows in the position embedding table.
        num_heads: attention heads per layer.
        num_layers: number of attention + feed-forward layers.
        dropout: probability shared by the dropout module and the attention layers.
        sinusoidal_embeddings: when truthy, the position table is passed to
            `create_sinusoidal_embeddings`.
        causal: when truthy, `forward` masks attention to later positions.
        """
        super().__init__()
        self.causal = causal

        self.tokens_embeddings = nn.Embedding(num_embeddings, embed_dim)
        self.position_embeddings = nn.Embedding(num_max_positions, embed_dim)
        if sinusoidal_embeddings:
            create_sinusoidal_embeddings(self.position_embeddings)
        self.dropout = nn.Dropout(dropout)

        # Registration order and per-layer construction order are kept stable so
        # parameter names and seeded initialisation do not change.
        self.attentions = nn.ModuleList()
        self.feed_forwards = nn.ModuleList()
        self.layer_norms_1 = nn.ModuleList()
        self.layer_norms_2 = nn.ModuleList()
        for _ in range(num_layers):
            self_attention = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)
            widen = nn.Linear(embed_dim, hidden_dim)
            activation = nn.ReLU()
            narrow = nn.Linear(hidden_dim, embed_dim)
            self.attentions.append(self_attention)
            self.feed_forwards.append(nn.Sequential(widen, activation, narrow))
            self.layer_norms_1.append(nn.LayerNorm(embed_dim, eps=1e-12))
            self.layer_norms_2.append(nn.LayerNorm(embed_dim, eps=1e-12))

    def forward(self, x, padding_mask=None):
        """ Input has shape [seq length, batch]; returns hidden states with one
        extra trailing embedding dimension. `padding_mask` is handed to every
        attention layer as its `key_padding_mask`. """
        seq_len = len(x)
        position_ids = torch.arange(seq_len, device=x.device).unsqueeze(-1)

        token_states = self.tokens_embeddings(x)
        # One position vector per time step, repeated across the batch dimension
        position_states = self.position_embeddings(position_ids).expand_as(token_states)
        h = self.dropout(token_states + position_states)

        if self.causal:
            # -inf strictly above the diagonal: a position cannot attend to later ones
            blocked = torch.full((seq_len, seq_len), -float('Inf'), device=h.device, dtype=h.dtype)
            attn_mask = torch.triu(blocked, diagonal=1)
        else:
            attn_mask = None

        for index, attention in enumerate(self.attentions):
            # The residual branches add to the *normalised* activations
            h = self.layer_norms_1[index](h)
            attended, _ = attention(h, h, h, attn_mask=attn_mask, need_weights=False,
                                    key_padding_mask=padding_mask)
            h = self.dropout(attended) + h

            h = self.layer_norms_2[index](h)
            transformed = self.feed_forwards[index](h)
            h = self.dropout(transformed) + h
        return h
class TransformerWithLMHead(nn.Module):
    """ Transformer with a language modeling head on top (tied weights) """

    def __init__(self, config):
        """ Build the transformer body and the tied output projection.

        config: object exposing `embed_dim`, `hidden_dim`, `num_embeddings`,
            `num_max_positions`, `num_heads`, `num_layers`, `dropout`,
            `sinusoidal_embeddings`, `mlm` and `initializer_range`. The body is
            causal exactly when `config.mlm` is falsy.
        """
        super().__init__()
        self.config = config
        self.transformer = Transformer(config.embed_dim, config.hidden_dim, config.num_embeddings,
                                       config.num_max_positions, config.num_heads, config.num_layers,
                                       config.dropout, config.sinusoidal_embeddings, causal=not config.mlm)
        self.lm_head = nn.Linear(config.embed_dim, config.num_embeddings, bias=False)
        self.apply(self.init_weights)
        # Tie after initialisation so the head ends up sharing the embedding matrix
        self.tie_weights()

    def tie_weights(self):
        """ Make the output projection share its weight with the token embeddings """
        self.lm_head.weight = self.transformer.tokens_embeddings.weight

    def init_weights(self, module):
        """ initialize weights - note that nn.MultiheadAttention is already initialized by PyTorch (xavier_uniform) """
        if isinstance(module, (nn.Linear, nn.Embedding, nn.LayerNorm)):
            # Frozen weights hold fixed values (create_sinusoidal_embeddings sets
            # requires_grad = False on the position table); re-drawing them here
            # would replace the sinusoids with noise that can never be trained.
            if module.weight.requires_grad:
                # NOTE(review): LayerNorm gains are also drawn from this normal
                # distribution rather than set to 1 - confirm this is intended.
                module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        if isinstance(module, (nn.Linear, nn.LayerNorm)) and module.bias is not None:
            module.bias.data.zero_()

    def forward(self, x, labels=None, padding_mask=None):
        """ Input has shape [seq length, batch]

        Returns the logits, or `(logits, loss)` when `labels` is given. For a
        causal model the logits at step t are scored against the label at step
        t + 1; otherwise logits and labels are compared position by position.
        Label entries equal to -1 are ignored by the loss.
        """
        hidden_states = self.transformer(x, padding_mask)
        logits = self.lm_head(hidden_states)

        if labels is None:
            return logits

        if self.transformer.causal:
            shift_logits, shift_labels = logits[:-1], labels[1:]
        else:
            shift_logits, shift_labels = logits, labels
        loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
        # reshape (not view) so non-contiguous labels/logits are accepted too
        loss = loss_fct(shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.reshape(-1))
        return logits, loss