Skip to content

Commit

Permalink
Add llama implementation based on nanoGPT (OpenGVLab#5)
Browse files Browse the repository at this point in the history
Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: Carlos Mocholí <[email protected]>
  • Loading branch information
3 people authored Mar 24, 2023
1 parent 4f78d19 commit 1598d12
Show file tree
Hide file tree
Showing 2 changed files with 397 additions and 0 deletions.
153 changes: 153 additions & 0 deletions compare.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
import math

import models.llama.model as llama
import models.nano.model as nano

import torch
import torch.nn as nn


# LLAMA XQ torch.Size([3, 32, 16, 2]) # B T nh hs
# NANO Q torch.Size([3, 16, 32, 2]) # B nh T hs

# LLAMA COS torch.Size([1, 32, 1, 2]) # 1 T 1 hs
# NANO COS torch.Size([32, 1, 1, 2]) # 1 1 T hs

def compare_rope():
x = torch.tensor([[1, 2, 3, 4], [4, 5, 6, 7], [7, 8, 9, 10]], dtype=torch.float32)
x = x[:, None, None, :]

llama_rot_x = llama.rotate_half(x)
nano_rot_x = nano.rotate_neg_half(x)

rot_x_matches = torch.allclose(llama_rot_x, nano_rot_x)

print(f"Comparing rot half\t\t{'OK' if rot_x_matches else 'KO'}")

_, seq_len, _, dim = x.shape
llama_cos_cached, llama_sin_cached = llama.precompute_cos_sin(seq_len, dim, x.dtype, x.device, base=10000)
nano_rope_cache = nano.build_rope_cache(seq_len, dim, dtype=x.dtype, device=x.device, base=10000)

cos_sin_cache_matches = torch.allclose(llama_cos_cached, nano_rope_cache[0]) and torch.allclose(llama_sin_cached, nano_rope_cache[1])

print(f"Comparing cos sin cache:\t{'OK' if cos_sin_cache_matches else 'KO'}")

nano_x_rope = nano.apply_rope(x, nano_rope_cache)
llama_x_rope, _ = llama.apply_rotary_pos_emb(x, x, llama_cos_cached, llama_sin_cached)

apply_rope_matches = torch.allclose(nano_x_rope, llama_x_rope)

print(f"Comparing apply rope:\t\t{'OK' if apply_rope_matches else 'KO'}")


def compare_rmsnorm():
block_size = 16
vocab_size = 16

sample = torch.rand(size=(2, block_size, vocab_size), dtype=torch.float32)

eps = 1e-6
llama_rmsnorm = llama.RMSNorm(vocab_size, eps=eps)(sample)
nano_rmsnorm = nano.RMSNorm(vocab_size, eps=eps)(sample)

rmsnorm_matches = torch.allclose(llama_rmsnorm, nano_rmsnorm)

print(f"Comparing rmsnorm:\t\t{'OK' if rmsnorm_matches else 'KO'}")


def copy_mlp(nano_mlp, llama_mlp):
llama_mlp.w1.weight.copy_(nano_mlp.c_fc1.weight)
llama_mlp.w3.weight.copy_(nano_mlp.c_fc2.weight)
llama_mlp.w2.weight.copy_(nano_mlp.c_proj.weight)


def copy_attention(nano_attn, llama_attn):
n_embd = nano_attn.c_attn.weight.shape[1]
llama_attn.wq.weight.copy_(nano_attn.c_attn.weight[:n_embd])
llama_attn.wk.weight.copy_(nano_attn.c_attn.weight[n_embd:-n_embd])
llama_attn.wv.weight.copy_(nano_attn.c_attn.weight[-n_embd:])
llama_attn.wo.weight.copy_(nano_attn.c_proj.weight)


def copy_block(nano_block, llama_block):
llama_block.attention_norm.weight.copy_(nano_block.rms_1.scale)
copy_attention(nano_block.attn, llama_block.attention)
llama_block.ffn_norm.weight.copy_(nano_block.rms_2.scale)
copy_mlp(nano_block.mlp, llama_block.feed_forward)


def copy_weights(nano_model, llama_model):
llama_model.tok_embeddings.weight.copy_(nano_model.transformer.wte.weight)
for nano_block, llama_block in zip(nano_model.transformer.h, llama_model.layers):
copy_block(nano_block, llama_block)
llama_model.norm.weight.copy_(nano_model.transformer.ln_f.scale)
llama_model.output.weight.copy_(nano_model.lm_head.weight)


def compare_to_llama():
block_size = 32
vocab_size = 32000
n_layer = 16
n_head = 16
n_embd = 32

nano_config = nano.LLaMAConfig(
block_size=block_size,
vocab_size=vocab_size,
n_layer=n_layer,
n_head=n_head,
n_embd=n_embd
)
llama_config = llama.ModelArgs(
dim=n_embd,
n_layers=n_layer,
n_heads=n_head,
vocab_size=vocab_size,
norm_eps=1e-6,
max_seq_length=block_size
)

batch_size = 3

token_sample = torch.randint(0, llama_config.vocab_size, size=(batch_size, llama_config.dim), dtype=torch.int64)

nano_model = nano.LLaMA(nano_config)
llama_model = llama.LLaMA(llama_config)

def _init_weights(module):
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02/math.sqrt(2 * nano_config.n_layer))
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02/math.sqrt(2 * nano_config.n_layer))

nano_model.apply(_init_weights)

with torch.no_grad():
copy_weights(nano_model, llama_model)

llama_embed = llama_model.tok_embeddings(token_sample)
nano_embed = nano_model.transformer.wte(token_sample)
embed_matches = torch.allclose(llama_embed, nano_embed)

print(f"Comparing embed:\t\t{'OK' if embed_matches else 'KO'}")

seq_len = token_sample.shape[1]
mask = torch.full((1, 1, seq_len, seq_len), float("-inf"))
mask = torch.triu(mask, diagonal=1)
llama_block_out = llama_model.layers[0](llama_embed, llama_model.cos_cached, llama_model.sin_cached, mask)
nano_block_out = nano_model.transformer.h[0](nano_embed)
block_matches = torch.allclose(llama_block_out, nano_block_out)

print(f"Comparing block out:\t\t{'OK' if block_matches else 'KO'}")

expected = llama_model(token_sample)
out, _ = nano_model(token_sample)
forward_matches = torch.allclose(out, expected)

print(f"Comparing forward:\t\t{'OK' if forward_matches else 'KO'}")


if __name__ == "__main__":
compare_rope()
compare_rmsnorm()
compare_to_llama()
244 changes: 244 additions & 0 deletions models/nano/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,244 @@
"""
Full definition of a LLaMA Language Model, all of it in this single file.
Based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT.
"""

import math
from dataclasses import dataclass

import torch
import torch.nn as nn
from torch.nn import functional as F


# Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/transformers/rope/__init__.py
# MIT licensed: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license

def build_rope_cache(seq_len, n_elem, dtype, device, base=10000):
"""
Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/transformers/rope/__init__.py
MIT licensed: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license
"""
# $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
theta = 1. / (base ** (torch.arange(0, n_elem, 2, dtype=dtype, device=device) / n_elem))

# Create position indexes `[0, 1, ..., seq_len - 1]`
seq_idx = torch.arange(seq_len, device=device, dtype=dtype)

# Calculate the product of position index and $\theta_i$
idx_theta = torch.outer(seq_idx, theta)

# Concatenate so that for row $m$ we have
# $[m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}, m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}]$
idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)

# Cache them
cos_cache = idx_theta2.cos()[None, None, :, :]
sin_cache = idx_theta2.sin()[None, None, :, :]

return torch.stack((cos_cache, sin_cache), dim=0)


def rotate_neg_half(x: torch.Tensor):
# $\frac{d}{2}$
d_2 = x.shape[-1] // 2

# Calculate $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
return torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)


def apply_rope(x: torch.Tensor, rope_cache):
neg_half_x = rotate_neg_half(x)
cos, sin = rope_cache
return (x * cos) + (neg_half_x * sin)


# Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py
# BSD 3-Clause License
class RMSNorm(nn.Module):

def __init__(self, size, dim=-1, eps=1e-8):
super().__init__()
self.scale = nn.Parameter(torch.ones(size))
self.eps = eps
self.dim = dim

def forward(self, x):
# NOTE: the original RMSNorm paper implementation is not equivalent
# norm_x = x.norm(2, dim=self.dim, keepdim=True)
# rms_x = norm_x * d_x ** (-1. / 2)
# x_normed = x / (rms_x + self.eps)

norm_x = x.norm(2, dim=self.dim, keepdim=True)
norm_x = torch.mean(x*x, dim=self.dim, keepdim=True)
x_normed = x * torch.rsqrt(norm_x + self.eps)

return self.scale * x_normed


class CausalSelfAttention(nn.Module):

def __init__(self, config):
super().__init__()
assert config.n_embd % config.n_head == 0

# key, query, value projections for all heads, but in a batch
self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=False)
# output projection
self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
# regularization
self.n_head = config.n_head
self.n_embd = config.n_embd
# flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
self.flash = False
if not self.flash:
# print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
.view(1, 1, config.block_size, config.block_size))

self.rope_cache = build_rope_cache(
seq_len=config.block_size,
n_elem=config.n_embd // config.n_head,
dtype=self.c_attn.weight.dtype,
device=self.c_attn.weight.device,
)

def forward(self, x):
B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)

# calculate query, key, values for all heads in batch and move head forward to be the batch dim
q, k, v = self.c_attn(x).split(self.n_embd, dim=2)

head_size = C // self.n_head
k = k.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs)
q = q.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs)
v = v.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs)

q = apply_rope(q, self.rope_cache)
k = apply_rope(k, self.rope_cache)

# causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
if self.flash:
# efficient attention using Flash Attention CUDA kernels
y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True)
else:
# manual implementation of attention
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
att = F.softmax(att, dim=-1)
y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side

# output projection
y = self.c_proj(y)

return y


class MLP(nn.Module):

def __init__(self, config):
super().__init__()
hidden_dim = 4 * config.n_embd
n_hidden = int(2 * hidden_dim / 3)
N = 256
# ensure n_hidden is multiple of N
n_hidden = ((n_hidden - 1) // N) * N + N

self.c_fc1 = nn.Linear(config.n_embd, n_hidden, bias=False)
self.c_fc2 = nn.Linear(config.n_embd, n_hidden, bias=False)
self.c_proj = nn.Linear(n_hidden, config.n_embd, bias=False)

def forward(self, x):
x = F.silu(self.c_fc1(x)) * self.c_fc2(x)
x = self.c_proj(x)
return x


class Block(nn.Module):

def __init__(self, config):
super().__init__()
self.rms_1 = RMSNorm(config.n_embd, eps=1e-6)
self.attn = CausalSelfAttention(config)
self.rms_2 = RMSNorm(config.n_embd, eps=1e-6)
self.mlp = MLP(config)

def forward(self, x):
x = x + self.attn(self.rms_1(x))
x = x + self.mlp(self.rms_2(x))
return x


@dataclass
class LLaMAConfig:
block_size: int = 4096 # 7B
vocab_size: int = 32000
n_layer: int = 32
n_head: int = 32
n_embd: int = 4096


class LLaMA(nn.Module):

def __init__(self, config):
super().__init__()
assert config.vocab_size is not None
assert config.block_size is not None
self.config = config

self.transformer = nn.ModuleDict(dict(
wte = nn.Embedding(config.vocab_size, config.n_embd),
h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
ln_f = RMSNorm(config.n_embd, eps=1e-6),
))
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

# init all weights
self.apply(self._init_weights)
# # apply special scaled init to the residual projections, per GPT-2 paper
# for pn, p in self.named_parameters():
# if pn.endswith('c_proj.weight'):
# torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))

# report number of parameters
# print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))

def get_num_params(self):
"""
Return the number of parameters in the model.
For non-embedding count (default), the position embeddings get subtracted.
The token embeddings would too, except due to the parameter sharing these
params are actually used as weights in the final layer, so we include them.
"""
n_params = sum(p.numel() for p in self.parameters())
return n_params

def _init_weights(self, module):
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

def forward(self, idx, targets=None):
device = idx.device
b, t = idx.size()
assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"

# forward the LLaMA model itself
x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)

for block in self.transformer.h:
x = block(x)
x = self.transformer.ln_f(x)

if targets is not None:
# if we are given some desired targets also calculate the loss
logits = self.lm_head(x)
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
else:
logits = self.lm_head(x)
loss = None

return logits, loss

0 comments on commit 1598d12

Please sign in to comment.