forked from DINESH-WebDev/mistral-src
-
Notifications
You must be signed in to change notification settings - Fork 0
/
one_file_ref.py
365 lines (299 loc) · 11.8 KB
/
one_file_ref.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
import torch
from torch import nn
from dataclasses import dataclass
from pathlib import Path
import fire
import json
from typing import Optional, Tuple, List
from sentencepiece import SentencePieceProcessor
@dataclass
class ModelArgs:
    """Hyper-parameters consumed by the model classes in this file.

    Every field except ``max_batch_size`` is required.  Field order is
    significant for positional construction.
    """

    # Model width: input size of the attention/feed-forward projections and
    # the embedding dimension.
    dim: int
    # Number of TransformerBlock layers stacked in the Transformer.
    n_layers: int
    # Size of a single attention head.
    head_dim: int
    # Inner width of the feed-forward network.
    hidden_dim: int
    # Number of query heads.
    n_heads: int
    # Number of key/value heads (queries share them n_heads // n_kv_heads to one).
    n_kv_heads: int
    # Length of the rotating key/value cache and width of the attention band.
    sliding_window: int
    # Epsilon passed to every RMSNorm.
    norm_eps: float
    # Number of rows of the embedding table and size of the output logits.
    vocab_size: int

    # Leading dimension of the key/value caches; Transformer.from_folder
    # overwrites it after reading params.json.
    max_batch_size: int = 0
def repeat_kv(keys: torch.Tensor, values: torch.Tensor, repeats: int):
    """Repeat each slice along dimension 2 of ``keys`` and ``values``.

    Every entry along dim 2 is duplicated ``repeats`` times in place
    (a, b -> a, a, b, b for ``repeats == 2``), so that the key/value heads
    line up with a larger number of query heads.

    Returns:
        The ``(keys, values)`` pair after expansion.
    """
    expanded_keys, expanded_values = (
        tensor.repeat_interleave(repeats, dim=2) for tensor in (keys, values)
    )
    return expanded_keys, expanded_values
def _reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """
    View ``freqs_cis`` so that it broadcasts against ``x``.

    freqs_cis: must have shape (x.shape[1], x.shape[-1])
    x: tensor with at least two dimensions

    The result has as many dimensions as ``x``: dimension 1 and the last
    dimension keep the sizes of ``x``, every other dimension has size 1.
    """
    rank = x.ndim
    assert 1 < rank
    expected = (x.shape[1], x.shape[-1])
    assert freqs_cis.shape == expected, (freqs_cis.shape, expected)
    # Start from all-ones and restore only the two dimensions that are kept.
    target = [1] * rank
    target[1] = x.shape[1]
    target[-1] = x.shape[-1]
    return freqs_cis.view(*target)
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Multiply queries and keys by the complex factors in ``freqs_cis``.

    The last dimension of each input is read as consecutive (real, imag)
    pairs, multiplied in float32 complex arithmetic by ``freqs_cis``
    (broadcast via ``_reshape_for_broadcast``), and flattened back from
    dimension 3 onwards.  Each result is cast back to the dtype of its input.
    """

    def _to_complex(t: torch.Tensor) -> torch.Tensor:
        # Pair up neighbouring elements of the last dimension.
        pairs = t.float().reshape(*t.shape[:-1], -1, 2)
        return torch.view_as_complex(pairs)

    q_complex = _to_complex(xq)
    k_complex = _to_complex(xk)
    rotation = _reshape_for_broadcast(freqs_cis, q_complex)
    q_rotated = torch.view_as_real(q_complex * rotation).flatten(3)
    k_rotated = torch.view_as_real(k_complex * rotation).flatten(3)
    return q_rotated.type_as(xq), k_rotated.type_as(xk)
class Attention(nn.Module):
    """Multi-head attention with shared key/value heads and a rotating cache.

    Queries use ``n_heads`` heads while keys/values use ``n_kv_heads`` heads;
    each key/value head is repeated ``n_heads // n_kv_heads`` times before the
    attention product.  Keys and values are stored in a cache of
    ``sliding_window`` rows per batch entry, written at
    ``position % sliding_window``.

    The caches are registered as non-persistent buffers: they follow
    ``Module.to(device=..., dtype=...)`` together with the weights, and they
    are not part of the ``state_dict`` (so checkpoints load unchanged).
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args

        self.n_heads: int = args.n_heads
        self.n_kv_heads: int = args.n_kv_heads

        self.repeats = self.n_heads // self.n_kv_heads
        self.sliding_window = self.args.sliding_window

        self.scale = self.args.head_dim**-0.5

        self.wq = nn.Linear(args.dim, args.n_heads * args.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False)

        cache_shape = (
            args.max_batch_size,
            args.sliding_window,
            self.n_kv_heads,
            self.args.head_dim,
        )
        # Buffers instead of tensors pinned with .cuda(): the caches must live
        # on the same device and use the same dtype as the projections that
        # are scattered into them.  persistent=False keeps them out of the
        # state_dict.
        self.register_buffer(
            "cache_k", torch.empty(cache_shape, dtype=torch.float16), persistent=False
        )
        self.register_buffer(
            "cache_v", torch.empty(cache_shape, dtype=torch.float16), persistent=False
        )

    def forward(
        self, x: torch.Tensor, freqs_cis: torch.Tensor, positions: torch.Tensor, mask: Optional[torch.Tensor]
    ) -> torch.Tensor:
        """Run attention over ``x`` and update the key/value cache.

        Args:
            x: input of shape (bsz, seqlen, dim).
            freqs_cis: complex rotary factors for the current positions,
                forwarded to ``apply_rotary_emb``.
            positions: 1-D tensor with the position of every token in ``x``.
            mask: additive mask broadcast over batch and heads, or ``None``.

        Returns:
            Tensor of shape (bsz, seqlen, dim).
        """
        bsz, seqlen, _ = x.shape

        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        xq = xq.view(bsz, seqlen, self.n_heads, self.args.head_dim)
        xk = xk.view(bsz, seqlen, self.n_kv_heads, self.args.head_dim)
        xv = xv.view(bsz, seqlen, self.n_kv_heads, self.args.head_dim)
        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        # The cache is a rotating buffer: only the last `sliding_window`
        # tokens are written, each at row `position % sliding_window`.
        scatter_pos = (positions[-self.sliding_window:] % self.sliding_window)[None, :, None, None]
        scatter_pos = scatter_pos.repeat(bsz, 1, self.n_kv_heads, self.args.head_dim)
        self.cache_k[:bsz].scatter_(dim=1, index=scatter_pos, src=xk[:, -self.sliding_window:])
        self.cache_v[:bsz].scatter_(dim=1, index=scatter_pos, src=xv[:, -self.sliding_window:])

        if positions.shape[0] > 1:
            # prefill: attend over the keys/values computed in this call
            key, value = repeat_kv(xk, xv, self.repeats)
        else:
            # single position: attend over the cache rows written so far
            # (the slice is clipped to the cache length once it has wrapped)
            cur_pos = positions[-1].item() + 1
            key, value = repeat_kv(
                self.cache_k[:bsz, :cur_pos, ...], self.cache_v[:bsz, :cur_pos, ...], self.repeats
            )

        query = xq.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)
        # scores : [bsz, n_heads, seqlen | 1, seqlen]
        scores = torch.matmul(query, key.transpose(2, 3)) * self.scale

        if mask is not None:
            scores += mask[None, None, ...]

        # softmax in float32, then back to the dtype of the queries
        scores = scores.float()
        scores = nn.functional.softmax(scores, dim=-1).type_as(query)
        output = torch.matmul(scores, value)  # (bs, n_local_heads, slen, head_dim)
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.wo(output)
class FeedForward(nn.Module):
    """Gated feed-forward network: ``w2(silu(w1(x)) * w3(x))``.

    ``w1`` and ``w3`` project from ``args.dim`` to ``args.hidden_dim``;
    ``w2`` projects back to ``args.dim``.  None of the layers has a bias.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        model_dim, inner_dim = args.dim, args.hidden_dim
        self.w1 = nn.Linear(model_dim, inner_dim, bias=False)
        self.w2 = nn.Linear(inner_dim, model_dim, bias=False)
        self.w3 = nn.Linear(model_dim, inner_dim, bias=False)

    def forward(self, x) -> torch.Tensor:
        """Apply the gated projection to ``x`` and return the result."""
        gate = nn.functional.silu(self.w1(x))
        value = self.w3(x)
        return self.w2(gate * value)
class RMSNorm(torch.nn.Module):
    """Root-mean-square normalisation over the last dimension.

    The input is divided by ``sqrt(mean(x**2) + eps)`` (computed in float32)
    and then multiplied by a learned per-feature ``weight`` initialised to 1.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """Scale ``x`` by the reciprocal RMS of its last dimension."""
        mean_square = torch.pow(x, 2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        """Normalise in float32, cast back to ``x``'s dtype, apply ``weight``."""
        normed = self._norm(x.float())
        return normed.type_as(x) * self.weight
class TransformerBlock(nn.Module):
    """One layer: normalised attention and feed-forward, each with a residual.

    ``out = h + feed_forward(ffn_norm(h))`` where
    ``h = x + attention(attention_norm(x))``.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.n_heads = args.n_heads
        self.dim = args.dim
        # Sub-modules are created in the same order as before so that
        # parameter registration order is unchanged.
        self.attention = Attention(args)
        self.feed_forward = FeedForward(args=args)
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

    def forward(
        self, x: torch.Tensor, freqs_cis: torch.Tensor, positions: torch.Tensor, mask: Optional[torch.Tensor]
    ) -> torch.Tensor:
        """Apply the attention and feed-forward sub-layers to ``x``.

        ``freqs_cis``, ``positions`` and ``mask`` are passed through to the
        attention module unchanged.
        """
        attn_out = self.attention.forward(self.attention_norm(x), freqs_cis, positions, mask)
        hidden = x + attn_out
        return hidden + self.feed_forward.forward(self.ffn_norm(hidden))
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
    """Build the table of unit complex numbers used for rotary embeddings.

    Entry ``[t, k]`` is ``exp(1j * t * theta ** (-2k / dim))`` for
    ``t`` in ``range(end)`` and ``k`` in ``range(dim // 2)``.

    Returns:
        complex64 tensor of shape (end, dim // 2).
    """
    exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
    inv_freq = 1.0 / (theta ** exponents)
    steps = torch.arange(end, device=inv_freq.device)  # type: ignore
    angles = torch.outer(steps, inv_freq).float()  # type: ignore
    unit_magnitude = torch.ones_like(angles)
    return torch.polar(unit_magnitude, angles)  # complex64
class Transformer(nn.Module):
    """Token embedding, a stack of ``TransformerBlock`` layers, a final
    ``RMSNorm`` and a linear projection to vocabulary logits.

    The rotary table ``freqs_cis`` is a plain tensor attribute (not a
    parameter or buffer), so ``Module.to(dtype=...)`` leaves its complex64
    dtype untouched.  It is moved to the device of the activations on demand
    in ``forward`` instead of being pinned to CUDA at construction time.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        self.n_layers = args.n_layers
        assert self.vocab_size > 0

        self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)

        self.layers = torch.nn.ModuleList(
            [TransformerBlock(args=args) for _ in range(args.n_layers)]
        )

        self.norm = RMSNorm(args.dim, eps=args.norm_eps)

        self.output = nn.Linear(args.dim, args.vocab_size, bias=False)

        # Rotary factors for positions 0 .. 127_999; relocated lazily in forward().
        self.freqs_cis = precompute_freqs_cis(self.args.head_dim, 128_000)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
    ):
        """Compute logits for ``input_ids``.

        Args:
            input_ids: token ids of shape (bsz, seqlen).
            positions: 1-D tensor with the position of each of the ``seqlen``
                tokens; used to index the rotary table and the caches.

        Returns:
            float32 logits of shape (bsz, seqlen, vocab_size).
        """
        h = self.tok_embeddings(input_ids)

        # Keep the rotary table next to the activations (one-time transfer).
        if self.freqs_cis.device != h.device:
            self.freqs_cis = self.freqs_cis.to(h.device)
        freqs_cis = self.freqs_cis[positions]

        mask: Optional[torch.Tensor] = None
        if input_ids.shape[1] > 1:
            seqlen = input_ids.shape[1]
            tensor = torch.full(
                (seqlen, seqlen),
                dtype=h.dtype,
                fill_value=1,
                device=h.device,
            )
            # causal part: keep entries with column <= row
            mask = torch.tril(tensor, diagonal=0).to(h.dtype)
            # make the mask banded to account for sliding window
            mask = torch.triu(mask, diagonal=-self.args.sliding_window)
            # additive form: log(1) = 0 for kept entries, log(0) = -inf otherwise
            mask = torch.log(mask)

        for layer in self.layers:
            h = layer(h, freqs_cis, positions, mask)

        return self.output(self.norm(h)).float()

    @staticmethod
    def from_folder(folder: Path, max_batch_size: int = 1, device="cuda", dtype=torch.float16):
        """Build a model from ``params.json`` and ``consolidated.00.pth``.

        Args:
            folder: directory containing both files.
            max_batch_size: stored in the model arguments before construction.
            device: device the module is moved to.
            dtype: floating-point dtype the module is cast to.

        Returns:
            The ``Transformer`` with the checkpoint weights loaded.
        """
        with open(folder / 'params.json', 'r') as f:
            model_args = ModelArgs(**json.load(f))
        model_args.max_batch_size = max_batch_size
        model = Transformer(model_args).to(device=device, dtype=dtype)
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        # map_location="cpu" stages the checkpoint in host memory regardless of
        # the device it was saved from; load_state_dict then copies each tensor
        # into the already-placed parameters.
        loaded = torch.load(folder / 'consolidated.00.pth', map_location="cpu")
        model.load_state_dict(loaded)
        return model
class Tokenizer:
    """Small wrapper around a ``SentencePieceProcessor``.

    ``encode`` prepends the processor's BOS id; ``decode`` delegates to the
    processor unchanged.
    """

    def __init__(self, model_path: str):
        """Load the model file at ``model_path`` (asserted to exist)."""
        assert Path(model_path).exists(), model_path
        self._model = SentencePieceProcessor(model_file=model_path)
        processor = self._model
        assert processor.vocab_size() == processor.get_piece_size()

    @property
    def eos_id(self) -> int:
        """End-of-sequence id reported by the underlying processor."""
        processor = self._model
        return processor.eos_id()

    @property
    def pad_id(self) -> int:
        """Padding id reported by the underlying processor."""
        processor = self._model
        return processor.pad_id()

    def encode(self, s: str) -> List[int]:
        """Return the BOS id followed by the encoded ids of ``s``."""
        ids = [self._model.bos_id()]
        ids.extend(self._model.encode(s))
        return ids

    def decode(self, t: List[int]) -> str:
        """Turn a list of ids back into text."""
        text = self._model.decode(t)
        return text
@torch.no_grad()
def generate(prompts: List[str], model: "Transformer", tokenizer: "Tokenizer", max_tokens: int):
    """Greedily continue a batch of prompts.

    The first ``min(prompt lengths)`` tokens of every prompt are processed in
    one pre-fill call.  Afterwards one token per step is produced for
    ``max_tokens`` steps: the arg-max of the last log-probabilities, except
    that prompts longer than the pre-fill keep their own token at that
    position.

    All tensors are created on the device of the model's parameters, so the
    function works wherever the model lives instead of assuming CUDA.

    Args:
        prompts: non-empty list of prompt strings.
        model: module called as ``model.forward(input_ids, positions)`` and
            returning logits of shape (bsz, seqlen, vocab).
        tokenizer: object providing ``encode``, ``decode`` and ``pad_id``.
        max_tokens: number of decoding steps.

    Returns:
        ``(texts, logprobs)``.  ``texts`` holds, per prompt, the decoded
        pre-fill tokens followed by the tokens selected during decoding; it is
        an empty list when ``max_tokens`` is 0.  ``logprobs`` has shape
        (bsz, min_prompt_len - 1 + max_tokens) and contains the
        log-probability of each pre-fill token after the first and of each
        selected token.

    Raises:
        ValueError: if ``prompts`` is empty.
    """
    if not prompts:
        raise ValueError("prompts must not be empty")

    # Follow the model rather than hard-coding "cuda".
    device = next(model.parameters()).device

    encoded_prompts = [tokenizer.encode(prompt) for prompt in prompts]
    prompt_lens = [len(x) for x in encoded_prompts]
    min_prompt_len = min(prompt_lens)
    max_prompt_len = max(prompt_lens)

    # Right-pad every prompt with pad_id; the mask marks real prompt tokens.
    input_tokens = torch.full(
        (len(prompts), max_prompt_len), tokenizer.pad_id, dtype=torch.long, device=device
    )
    for i, encoded in enumerate(encoded_prompts):
        input_tokens[i, :len(encoded)] = torch.tensor(encoded).to(input_tokens)
    input_mask = input_tokens != tokenizer.pad_id

    # pre-fill
    positions = torch.arange(0, min_prompt_len, device=device)
    logits = model.forward(input_tokens[:, :min_prompt_len], positions)
    logprobs = nn.functional.log_softmax(logits, dim=-1)

    # decode
    generated = []
    all_logprobs = [
        # log-probability assigned to each pre-fill token by its predecessor
        logprobs[:, :-1, :].gather(2, input_tokens[:, 1:min_prompt_len, None]).squeeze(-1),
    ]
    cur_pos = min_prompt_len
    for _ in range(max_tokens):
        next_token = torch.argmax(logprobs[:, -1, :], dim=-1)
        if cur_pos < input_mask.shape[1]:
            # still inside a longer prompt: keep that prompt's own token
            next_token = torch.where(input_mask[:, cur_pos], input_tokens[:, cur_pos], next_token)
        all_logprobs.append(
            logprobs[:, -1, :].gather(1, next_token[:, None]),
        )
        generated.append(next_token[:, None])
        logits = model.forward(next_token[:, None], torch.LongTensor([cur_pos]).to(next_token))
        logprobs = nn.functional.log_softmax(logits, dim=-1)
        cur_pos += 1

    all_logprobs = torch.cat(all_logprobs, 1)
    res = []
    if max_tokens > 0:
        generated = torch.cat(generated, 1)
        for i, x in enumerate(encoded_prompts):
            res.append(tokenizer.decode(x[:min_prompt_len] + generated[i].tolist()))
    return res, all_logprobs
def demo(model_path: str, max_tokens: int = 35):
    """Load tokenizer and model from ``model_path`` and print completions.

    ``tokenizer.model`` and the checkpoint files are read from the folder
    ``model_path``.  Three fixed prompts are continued for ``max_tokens``
    steps; each result is printed followed by a separator line.
    """
    root = Path(model_path)
    tokenizer = Tokenizer(str(root / "tokenizer.model"))
    transformer = Transformer.from_folder(root, max_batch_size=3)

    prompts = [
        "This is a test",
        "This is another test",
        "This is a third test, mistral AI is very good at testing. ",
    ]
    completions, _logprobs = generate(
        prompts,
        transformer,
        tokenizer,
        max_tokens=max_tokens,
    )

    for text in completions:
        print(text)
        print("=====================")
# Script entry point: when the file is executed directly, hand `demo` to
# fire.Fire so that its parameters (model_path, max_tokens) are taken from
# the command line.
if __name__ == "__main__":
    fire.Fire(demo)