diff --git a/compare.py b/compare.py
new file mode 100644
index 0000000..7aa54ee
--- /dev/null
+++ b/compare.py
@@ -0,0 +1,153 @@
+import math
+
+import models.llama.model as llama
+import models.nano.model as nano
+
+import torch
+import torch.nn as nn
+
+
+# Shapes observed while debugging (B=3, T=32, n_head=16, head_size=2):
+# LLAMA XQ torch.Size([3, 32, 16, 2])  # B T nh hs
+# NANO Q torch.Size([3, 16, 32, 2])  # B nh T hs
+
+# LLAMA COS torch.Size([1, 32, 1, 2])  # 1 T 1 hs
+# NANO COS torch.Size([1, 1, 32, 2])  # 1 1 T hs
+
+
+def compare_rope():
+    # x has singleton middle dimensions, so it acts as a length-1 sequence in both
+    # layouts and RoPE is exercised at position 0 only.
+    x = torch.tensor([[1, 2, 3, 4], [4, 5, 6, 7], [7, 8, 9, 10]], dtype=torch.float32)
+    x = x[:, None, None, :]
+
+    llama_rot_x = llama.rotate_half(x)
+    nano_rot_x = nano.rotate_neg_half(x)
+
+    rot_x_matches = torch.allclose(llama_rot_x, nano_rot_x)
+
+    print(f"Comparing rot half:\t\t{'OK' if rot_x_matches else 'KO'}")
+
+    _, seq_len, _, dim = x.shape
+    llama_cos_cached, llama_sin_cached = llama.precompute_cos_sin(seq_len, dim, x.dtype, x.device, base=10000)
+    nano_rope_cache = nano.build_rope_cache(seq_len, dim, dtype=x.dtype, device=x.device, base=10000)
+
+    cos_sin_cache_matches = torch.allclose(llama_cos_cached, nano_rope_cache[0]) and torch.allclose(llama_sin_cached, nano_rope_cache[1])
+
+    print(f"Comparing cos sin cache:\t{'OK' if cos_sin_cache_matches else 'KO'}")
+
+    nano_x_rope = nano.apply_rope(x, nano_rope_cache)
+    llama_x_rope, _ = llama.apply_rotary_pos_emb(x, x, llama_cos_cached, llama_sin_cached)
+
+    apply_rope_matches = torch.allclose(nano_x_rope, llama_x_rope)
+
+    print(f"Comparing apply rope:\t\t{'OK' if apply_rope_matches else 'KO'}")
+
+
+def compare_rmsnorm():
+    block_size = 16
+    vocab_size = 16
+
+    sample = torch.rand(size=(2, block_size, vocab_size), dtype=torch.float32)
+
+    eps = 1e-6
+    llama_rmsnorm = llama.RMSNorm(vocab_size, eps=eps)(sample)
+    nano_rmsnorm = nano.RMSNorm(vocab_size, eps=eps)(sample)
+
+    rmsnorm_matches = torch.allclose(llama_rmsnorm, nano_rmsnorm)
+
+    print(f"Comparing rmsnorm:\t\t{'OK' if rmsnorm_matches else 'KO'}")
+
+
+def copy_mlp(nano_mlp, llama_mlp):
+    llama_mlp.w1.weight.copy_(nano_mlp.c_fc1.weight)
+    llama_mlp.w3.weight.copy_(nano_mlp.c_fc2.weight)
+    llama_mlp.w2.weight.copy_(nano_mlp.c_proj.weight)
+
+
+def copy_attention(nano_attn, llama_attn):
+    # the fused c_attn weight is (3 * n_embd, n_embd), stacked as [q; k; v] along dim 0
+    n_embd = nano_attn.c_attn.weight.shape[1]
+    llama_attn.wq.weight.copy_(nano_attn.c_attn.weight[:n_embd])
+    llama_attn.wk.weight.copy_(nano_attn.c_attn.weight[n_embd:-n_embd])
+    llama_attn.wv.weight.copy_(nano_attn.c_attn.weight[-n_embd:])
+    llama_attn.wo.weight.copy_(nano_attn.c_proj.weight)
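+
+
+def check_qkv_split(nano_attn, llama_attn):
+    # Optional sanity check, not wired into compare_to_llama: after copy_attention has run,
+    # re-concatenating wq/wk/wv along dim 0 must reproduce the fused c_attn weight, which
+    # confirms the [q; k; v] slice ordering assumed above.
+    fused = torch.cat([llama_attn.wq.weight, llama_attn.wk.weight, llama_attn.wv.weight], dim=0)
+    assert torch.equal(fused, nano_attn.c_attn.weight)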
+
+
+def copy_block(nano_block, llama_block):
+    llama_block.attention_norm.weight.copy_(nano_block.rms_1.scale)
+    copy_attention(nano_block.attn, llama_block.attention)
+    llama_block.ffn_norm.weight.copy_(nano_block.rms_2.scale)
+    copy_mlp(nano_block.mlp, llama_block.feed_forward)
+
+
+def copy_weights(nano_model, llama_model):
+    llama_model.tok_embeddings.weight.copy_(nano_model.transformer.wte.weight)
+    for nano_block, llama_block in zip(nano_model.transformer.h, llama_model.layers):
+        copy_block(nano_block, llama_block)
+    llama_model.norm.weight.copy_(nano_model.transformer.ln_f.scale)
+    llama_model.output.weight.copy_(nano_model.lm_head.weight)
+
+
+def compare_to_llama():
+    block_size = 32
+    vocab_size = 32000
+    n_layer = 16
+    n_head = 16
+    n_embd = 32
+
+    nano_config = nano.LLaMAConfig(
+        block_size=block_size,
+        vocab_size=vocab_size,
+        n_layer=n_layer,
+        n_head=n_head,
+        n_embd=n_embd
+    )
+    llama_config = llama.ModelArgs(
+        dim=n_embd,
+        n_layers=n_layer,
+        n_heads=n_head,
+        vocab_size=vocab_size,
+        norm_eps=1e-6,
+        max_seq_length=block_size
+    )
+
+    batch_size = 3
+
+    # second dimension is the sequence length, which must equal block_size here
+    token_sample = torch.randint(0, llama_config.vocab_size, size=(batch_size, block_size), dtype=torch.int64)
+
+    nano_model = nano.LLaMA(nano_config)
+    llama_model = llama.LLaMA(llama_config)
+
+    def _init_weights(module):
+        if isinstance(module, nn.Linear):
+            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02/math.sqrt(2 * nano_config.n_layer))
+        elif isinstance(module, nn.Embedding):
+            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02/math.sqrt(2 * nano_config.n_layer))
+
+    nano_model.apply(_init_weights)
+
+    with torch.no_grad():
+        copy_weights(nano_model, llama_model)
+
+    llama_embed = llama_model.tok_embeddings(token_sample)
+    nano_embed = nano_model.transformer.wte(token_sample)
+    embed_matches = torch.allclose(llama_embed, nano_embed)
+
+    print(f"Comparing embed:\t\t{'OK' if embed_matches else 'KO'}")
+
+    seq_len = token_sample.shape[1]
+    mask = torch.full((1, 1, seq_len, seq_len), float("-inf"))
+    mask = torch.triu(mask, diagonal=1)
+    llama_block_out = llama_model.layers[0](llama_embed, llama_model.cos_cached, llama_model.sin_cached, mask)
+    nano_block_out = nano_model.transformer.h[0](nano_embed)
+    block_matches = torch.allclose(llama_block_out, nano_block_out)
+
+    print(f"Comparing block out:\t\t{'OK' if block_matches else 'KO'}")
+
+    expected = llama_model(token_sample)
+    out, _ = nano_model(token_sample)
+    forward_matches = torch.allclose(out, expected)
+
+    print(f"Comparing forward:\t\t{'OK' if forward_matches else 'KO'}")
+
+
+if __name__ == "__main__":
+    compare_rope()
+    compare_rmsnorm()
+    compare_to_llama()
diff --git a/models/nano/model.py b/models/nano/model.py
new file mode 100644
index 0000000..ef7c3f3
--- /dev/null
+++ b/models/nano/model.py
@@ -0,0 +1,244 @@
+"""
+Full definition of a LLaMA Language Model, all of it in this single file.
+Based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT.
+"""
+
+import math
+from dataclasses import dataclass
+
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+
+
+# Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/transformers/rope/__init__.py
+# MIT licensed: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license
+
+def build_rope_cache(seq_len, n_elem, dtype, device, base=10000):
+    """Precompute the RoPE cos/sin cache.
+
+    Returns a tensor of shape (2, 1, 1, seq_len, n_elem): index 0 holds the cos values and
+    index 1 the sin values, broadcastable against q/k laid out as (B, n_head, T, head_size).
+    """
+    # $\Theta = {\theta_i = 10000^{-\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
+    theta = 1. / (base ** (torch.arange(0, n_elem, 2, dtype=dtype, device=device) / n_elem))
+
+    # Create position indexes `[0, 1, ..., seq_len - 1]`
+    seq_idx = torch.arange(seq_len, device=device, dtype=dtype)
+
+    # Calculate the product of position index and $\theta_i$
+    idx_theta = torch.outer(seq_idx, theta)
+
+    # Concatenate so that for row $m$ we have
+    # $[m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}, m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}]$
+    idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)
+
+    # Cache them
+    cos_cache = idx_theta2.cos()[None, None, :, :]
+    sin_cache = idx_theta2.sin()[None, None, :, :]
+
+    return torch.stack((cos_cache, sin_cache), dim=0)
+
+
+def rotate_neg_half(x: torch.Tensor):
+    # $\frac{d}{2}$
+    d_2 = x.shape[-1] // 2
+
+    # Calculate $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
+    return torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)
+
+
+def apply_rope(x: torch.Tensor, rope_cache):
+    neg_half_x = rotate_neg_half(x)
+    cos, sin = rope_cache
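+    # Written out per feature pair: with theta_i shared by cache positions i and i + d/2,
+    # token position m maps (x_i, x_{i + d/2}) to
+    #   (x_i * cos(m * theta_i) - x_{i + d/2} * sin(m * theta_i),
+    #    x_{i + d/2} * cos(m * theta_i) + x_i * sin(m * theta_i)),
+    # i.e. the standard RoPE 2-D rotation, expressed via the "rotate half" layout
+    # instead of complex multiplication.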
+    return (x * cos) + (neg_half_x * sin)
+
+
+# Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py
+# BSD 3-Clause License
+class RMSNorm(nn.Module):
+
+    def __init__(self, size, dim=-1, eps=1e-8):
+        super().__init__()
+        self.scale = nn.Parameter(torch.ones(size))
+        self.eps = eps
+        self.dim = dim
+
+    def forward(self, x):
+        # NOTE: the original RMSNorm paper implementation is not equivalent
+        # norm_x = x.norm(2, dim=self.dim, keepdim=True)
+        # rms_x = norm_x * d_x ** (-1. / 2)
+        # x_normed = x / (rms_x + self.eps)
+
+        norm_x = torch.mean(x * x, dim=self.dim, keepdim=True)
+        x_normed = x * torch.rsqrt(norm_x + self.eps)
+
+        return self.scale * x_normed
+
+
+class CausalSelfAttention(nn.Module):
+
+    def __init__(self, config):
+        super().__init__()
+        assert config.n_embd % config.n_head == 0
+
+        # key, query, value projections for all heads, but in a batch
+        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=False)
+        # output projection
+        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
+        # regularization
+        self.n_head = config.n_head
+        self.n_embd = config.n_embd
+        # flash attention makes the GPU go brrrrr, but support is only in PyTorch >= 2.0
+        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
+        self.flash = False  # force the manual attention path below
+        if not self.flash:
+            # print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
+            self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
+                                        .view(1, 1, config.block_size, config.block_size))
+
+        self.rope_cache = build_rope_cache(
+            seq_len=config.block_size,
+            n_elem=config.n_embd // config.n_head,
+            dtype=self.c_attn.weight.dtype,
+            device=self.c_attn.weight.device,
+        )
+
+    def forward(self, x):
+        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)
+
+        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
+        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
+
+        head_size = C // self.n_head
+        k = k.view(B, T, self.n_head, head_size).transpose(1, 2)  # (B, nh, T, hs)
+        q = q.view(B, T, self.n_head, head_size).transpose(1, 2)  # (B, nh, T, hs)
+        v = v.view(B, T, self.n_head, head_size).transpose(1, 2)  # (B, nh, T, hs)
+
+        q = apply_rope(q, self.rope_cache)
+        k = apply_rope(k, self.rope_cache)
+
+        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
+        if self.flash:
+            # efficient attention using Flash Attention CUDA kernels
+            y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True)
+        else:
+            # manual implementation of attention
+            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
+            att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
+            att = F.softmax(att, dim=-1)
+            y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
+        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side
+
+        # output projection
+        y = self.c_proj(y)
+
+        return y
+
+
+class MLP(nn.Module):
+
+    def __init__(self, config):
+        super().__init__()
+        hidden_dim = 4 * config.n_embd
+        n_hidden = int(2 * hidden_dim / 3)
+        N = 256
+        # round n_hidden up to the nearest multiple of N
+        n_hidden = ((n_hidden - 1) // N) * N + N
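+        # Worked example under the 7B defaults (n_embd=4096, assumed here for illustration):
+        # hidden_dim = 16384, n_hidden = int(2 * 16384 / 3) = 10922, rounded up to the next
+        # multiple of 256 -> 11008, the feed-forward width of the reference 7B model.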
+
+        self.c_fc1 = nn.Linear(config.n_embd, n_hidden, bias=False)
+        self.c_fc2 = nn.Linear(config.n_embd, n_hidden, bias=False)
+        self.c_proj = nn.Linear(n_hidden, config.n_embd, bias=False)
+
+    def forward(self, x):
+        x = F.silu(self.c_fc1(x)) * self.c_fc2(x)
+        x = self.c_proj(x)
+        return x
+
+
+class Block(nn.Module):
+
+    def __init__(self, config):
+        super().__init__()
+        self.rms_1 = RMSNorm(config.n_embd, eps=1e-6)
+        self.attn = CausalSelfAttention(config)
+        self.rms_2 = RMSNorm(config.n_embd, eps=1e-6)
+        self.mlp = MLP(config)
+
+    def forward(self, x):
+        x = x + self.attn(self.rms_1(x))
+        x = x + self.mlp(self.rms_2(x))
+        return x
+
+
+@dataclass
+class LLaMAConfig:
+    # default values correspond to the 7B model
+    block_size: int = 4096
+    vocab_size: int = 32000
+    n_layer: int = 32
+    n_head: int = 32
+    n_embd: int = 4096
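+
+
+# Back-of-the-envelope size for the defaults above: each block carries 4 * n_embd^2 attention
+# weights plus 3 * n_embd * 11008 feed-forward weights (~202M), so 32 blocks together with the
+# two 32000 x 4096 embedding/output matrices come to roughly 6.7B parameters, hence "7B".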
+
+
+class LLaMA(nn.Module):
+
+    def __init__(self, config):
+        super().__init__()
+        assert config.vocab_size is not None
+        assert config.block_size is not None
+        self.config = config
+
+        self.transformer = nn.ModuleDict(dict(
+            wte = nn.Embedding(config.vocab_size, config.n_embd),
+            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
+            ln_f = RMSNorm(config.n_embd, eps=1e-6),
+        ))
+        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+
+        # init all weights
+        self.apply(self._init_weights)
+        # # apply special scaled init to the residual projections, per GPT-2 paper
+        # for pn, p in self.named_parameters():
+        #     if pn.endswith('c_proj.weight'):
+        #         torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))
+
+        # report number of parameters
+        # print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))
+
+    def get_num_params(self):
+        """Return the total number of parameters in the model."""
+        n_params = sum(p.numel() for p in self.parameters())
+        return n_params
+
+    def _init_weights(self, module):
+        if isinstance(module, nn.Linear):
+            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+        elif isinstance(module, nn.Embedding):
+            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+
+    def forward(self, idx, targets=None):
+        b, t = idx.size()
+        assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
+
+        # forward the LLaMA model itself
+        x = self.transformer.wte(idx)  # token embeddings of shape (b, t, n_embd)
+
+        for block in self.transformer.h:
+            x = block(x)
+        x = self.transformer.ln_f(x)
+
+        if targets is not None:
+            # if we are given some desired targets also calculate the loss
+            logits = self.lm_head(x)
+            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
+        else:
+            logits = self.lm_head(x)
+            loss = None
+
+        return logits, loss
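+
+
+if __name__ == "__main__":
+    # Minimal smoke test (not part of the original module); the tiny hyperparameters below
+    # are arbitrary. Note the RoPE cache is built for exactly block_size positions, so the
+    # input must use full-length sequences.
+    config = LLaMAConfig(block_size=8, vocab_size=16, n_layer=2, n_head=2, n_embd=8)
+    model = LLaMA(config)
+    tokens = torch.randint(0, config.vocab_size, (2, config.block_size))
+    logits, loss = model(tokens)
+    assert logits.shape == (2, config.block_size, config.vocab_size)
+    assert loss is None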