Mistral Nemo
danielhanchen committed Jul 19, 2024
1 parent 1524504 commit c1d3493
Showing 3 changed files with 96 additions and 15 deletions.
31 changes: 27 additions & 4 deletions unsloth/models/_utils.py
@@ -65,8 +65,26 @@

# =============================================
# Edits all Config files to enable RoPE Scaling for all models
from transformers import PretrainedConfig

# Transformers had to update for Mistral Nemo 12b since Attention is (5120, 4096) now.
def patch_mistral_nemo_config(config):
if "head_dim (" not in config:
add_head_dim = "If it is not specified, will default to `8`.\n"\
" head_dim (`int`, *optional*, defaults to `hidden_size // num_attention_heads`):\n"\
" The attention head dimension."
config = config.replace("If it is not specified, will default to `8`.", add_head_dim)

add_head_dim = "num_key_value_heads=8,\n head_dim=None,"
config = config.replace("num_key_value_heads=8,", add_head_dim)

add_head_dim = "self.sliding_window = sliding_window\n self.head_dim = head_dim or hidden_size // num_attention_heads\n"
config = config.replace("self.sliding_window = sliding_window", add_head_dim)
pass
return config
pass
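
For orientation only (not part of the commit): a minimal sketch of roughly what the patched MistralConfig.__init__ ends up containing once the three string replacements above are applied and exec'd. Everything except the two inserted head_dim lines is assumed from the stock transformers config.

# Illustration only -- roughly the patched constructor; surrounding arguments are assumed.
def __init__(
    self,
    hidden_size = 4096,
    num_attention_heads = 32,
    num_key_value_heads = 8,
    head_dim = None,          # inserted by the second replacement
    sliding_window = None,
    **kwargs,
):
    self.hidden_size = hidden_size
    self.num_attention_heads = num_attention_heads
    self.num_key_value_heads = num_key_value_heads
    self.sliding_window = sliding_window
    # inserted by the third replacement:
    self.head_dim = head_dim or hidden_size // num_attention_heads

With this in place, a Nemo-style config with hidden_size = 5120 and an explicit head_dim no longer has to satisfy head_dim * num_attention_heads == hidden_size.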

from transformers import __version__ as transformers_version
from transformers import PretrainedConfig
model_architectures = ["llama", "mistral", "gemma", "gemma2", "qwen2",]

for model_name in model_architectures:
@@ -87,8 +105,14 @@
r"\n self.rope_scaling = rope_scaling\n",
config,
)
exec(config, globals())

# Just for Mistral Nemo
if model_name == "mistral":
if Version(transformers_version) <= Version("4.42.4"):
config = patch_mistral_nemo_config(config)
pass

exec(config, globals())
exec(f"import {config_filepath}", globals())
exec(f"{config_filepath}.{config_filename} = {config_filename}", globals())
pass
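
For readers unfamiliar with this file, a tiny standalone example (not unsloth code) of the source-patch-and-exec pattern the loop above relies on: grab a class's source text, rewrite it, re-exec it, then rebind the name. The ToyConfig class here is purely hypothetical.

# Illustration only: the general "patch source, re-exec, rebind" pattern, on a toy class.
import inspect

class ToyConfig:
    def __init__(self, hidden_size = 4096):
        self.hidden_size = hidden_size

source = inspect.getsource(ToyConfig)
source = source.replace(
    "self.hidden_size = hidden_size",
    "self.hidden_size = hidden_size\n        self.head_dim = hidden_size // 32",
)
exec(source, globals())          # redefines ToyConfig with the extra attribute
print(ToyConfig().head_dim)      # 128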
@@ -97,7 +121,6 @@
# =============================================
# torch.cuda.amp.custom_fwd is deprecated >= 2.4
import torch
from packaging.version import Version
if Version(torch.__version__) < Version("2.4.0"):
torch_amp_custom_fwd = torch.cuda.amp.custom_fwd
torch_amp_custom_bwd = torch.cuda.amp.custom_bwd
@@ -748,7 +771,7 @@ def patch_linear_scaling(
"self.rotary_emb = .+?\)", function,
flags = re.DOTALL | re.MULTILINE,
)
if len(rotary_emb) == 0: return
if len(rotary_emb) == 0: return None, function
rotary_emb = rotary_emb[0]
function = function.replace(rotary_emb, fix_rope_function, 1)
function = exec_code + "\n\n" + function
58 changes: 48 additions & 10 deletions unsloth/models/llama.py
@@ -158,6 +158,14 @@ def LlamaAttention_fast_forward_inference(
self.temp_QA = torch.empty((2, bsz, 1, attention_size), dtype = dtype, device = "cuda:0")
self.temp_KV = torch.empty((2, bsz, 1, n_kv_heads*head_dim), dtype = dtype, device = "cuda:0")
self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = "cuda:0")

# Mistral Nemo 12b has weird dimensions
if attention_size != self.hidden_size:
self.temp_O = torch.empty((1, bsz, self.hidden_size), dtype = dtype, device = "cuda:0")
else:
self.temp_O = self.temp_QA[1][:,:,:self.hidden_size]
pass

self.attention = torch.empty((bsz, n_heads, 1, KV_CACHE_INCREMENT+seq_len), dtype = dtype, device = "cuda:0")
self.scalar = 1.0 / math_sqrt(self.head_dim)
self.half_head_dim = head_dim // 2
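
A quick sanity check on the dimensions (not part of the diff; values assumed from Mistral Nemo 12B's published config): the attention scratch buffers are attention_size = n_heads * head_dim wide, but o_proj's output is hidden_size wide, and for Nemo those differ, so the output can no longer reuse a slice of temp_QA.

# Illustration only, assumed Mistral Nemo 12B dimensions.
hidden_size = 5120
n_heads, head_dim = 32, 128
attention_size = n_heads * head_dim   # 4096
# A temp_QA[1][:, :, :hidden_size] slice only has attention_size = 4096 columns,
# which cannot hold o_proj's hidden_size = 5120 wide output, hence the separate temp_O.
assert attention_size != hidden_size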
@@ -239,7 +247,7 @@ def LlamaAttention_fast_forward_inference(
pass
A = A.transpose(1, 2)
A = A.reshape(bsz, 1, attention_size)
A = fast_linear_forward(self.o_proj, A, out = self.temp_QA[1][:,:,:self.hidden_size])
A = fast_linear_forward(self.o_proj, A, out = self.temp_O)
return A, (Kn, Vn)
pass

@@ -335,6 +343,9 @@ def LlamaAttention_fast_forward(
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]

# Extend RoPE dynamically to fit in VRAM
self.rotary_emb.extend_rope_embedding(V, seq_len = kv_seq_len)

if position_ids is None:
cos = self.rotary_emb.cos_cached
sin = self.rotary_emb.sin_cached
@@ -662,6 +673,12 @@ def LlamaModel_fast_forward(
offloaded_gradient_checkpointing = True
pass

# Check for Flex Attention
# if IS_GEMMA2 and HAS_FLEX_ATTENTION:
# if not (seq_length % FLEX_ATTENTION_PADDING == 0):
# USE_FLEX_ATTENTION = True


# Gemma2 has alternating SWA and global attn
if IS_GEMMA2 and not hasattr(self, "SWA_mask"):
n = self.config.max_position_embeddings
@@ -965,19 +982,21 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
# Dynamic RoPE: start the cache at a max of 4 * 8192 tokens, then grow it iteratively as needed
self.current_rope_size = min(4 * 8192, self.max_position_embeddings)

# Build here to make `torch.jit.trace` work.
self._set_cos_sin_cache(seq_len=max_position_embeddings, device=device, dtype=torch.get_default_dtype())
self._set_cos_sin_cache(seq_len=self.current_rope_size, device=device, dtype=torch.get_default_dtype())
pass

def _set_cos_sin_cache(self, seq_len, device, dtype):
# Note: on the original Llama codebase, these tensors are created on the target device (and not on CPU) and
# in FP32. They are applied (multiplied) in FP32 as well.
self.max_seq_len_cached = seq_len
self.current_rope_size = seq_len
inv_freq = 1.0 / (
self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device="cpu").float() / self.dim)
)
t = torch.arange(self.max_seq_len_cached, device="cpu", dtype=torch.int64).float()
t = torch.arange(self.current_rope_size, device="cpu", dtype=torch.int64).float()

freqs = torch.outer(t, inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
@@ -988,14 +1007,21 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):

def forward(self, x, position_ids=None, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
if seq_len > self.max_seq_len_cached:
if seq_len > self.current_rope_size:
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

return (
self.cos_cached[:seq_len].to(dtype=x.dtype),
self.sin_cached[:seq_len].to(dtype=x.dtype),
self.cos_cached[:seq_len].to(dtype = x.dtype),
self.sin_cached[:seq_len].to(dtype = x.dtype),
)
pass

def extend_rope_embedding(self, x, seq_len):
if seq_len <= self.current_rope_size: return
# Iteratively grow by increments of 8192
self.current_rope_size = int(round(seq_len / 8192)) * 8192
self._set_cos_sin_cache(self.current_rope_size, device = "cuda:0", dtype = x.dtype)
pass
pass
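
A minimal standalone sketch (not the class above) of the dynamic RoPE cache policy, assuming the same 8192-token granularity; this sketch rounds the new size up so the cache always covers seq_len.

# Illustration only: start small, regrow to a multiple of 8192 only when needed.
def next_rope_size(current_rope_size, seq_len, step = 8192):
    if seq_len <= current_rope_size:
        return current_rope_size                    # cache already large enough
    return ((seq_len + step - 1) // step) * step    # round up to a multiple of step

size = min(4 * 8192, 131072)            # initial cache; 131072 is an assumed model maximum
size = next_rope_size(size, 40000)      # grows to 40960
size = next_rope_size(size, 20000)      # unchanged, still 40960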


@@ -1010,11 +1036,11 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, s
pass

def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
self.current_rope_size = seq_len
inv_freq = 1.0 / (
self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device="cpu").float() / self.dim)
)
t = torch.arange(self.max_seq_len_cached, device="cpu", dtype=torch.int64).float()
t = torch.arange(self.current_rope_size, device="cpu", dtype=torch.int64).float()
t = t / self.scaling_factor

freqs = torch.outer(t, inv_freq)
@@ -1134,6 +1160,15 @@ def from_pretrained(
f"\ / Bfloat16 = {str(SUPPORTS_BFLOAT16).upper()}. FA [Xformers = {xformers_version}. FA2 = {HAS_FLASH_ATTENTION}]\n"\
f' "-____-" Free Apache license: http://github.com/unslothai/unsloth'
print(statistics)

# Warn about fast transfers
old_hf_transfer = os.environ.get("HF_HUB_ENABLE_HF_TRANSFER", "0")
if os.environ.get("HF_HUB_ENABLE_HF_TRANSFER", "0") == "1":
print("Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!")
pass
# Return old flag
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = old_hf_transfer

model_patcher.pre_patch()
get_statistics() # For debugging - we use a download counter to see if environments are not breaking

@@ -1215,6 +1250,8 @@ def from_pretrained(
attn_implementation = "eager",
**kwargs,
)
# Return old flag
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = old_hf_transfer
# We currently only support NVIDIA GPUs - AMD / Intel is a work in progress!
post_check = check_nvidia()

@@ -2081,4 +2118,5 @@ def for_training(model, use_gradient_checkpointing = True):
internal_model._saved_temp_tokenizer.padding_side = "right"
pass
pass
pass
pass

22 changes: 21 additions & 1 deletion unsloth/models/mistral.py
@@ -270,6 +270,24 @@ def MistralForCausalLM_fast_forward(
pass


# Transformers had to update for Mistral Nemo 12b since Attention is (5120, 4096) now.
def patch_mistral_nemo_attention(function):
function = function.replace(
"(self.head_dim * self.num_heads) != self.hidden_size",
"False",
)
function = function.replace(
"self.head_dim = self.hidden_size // self.num_heads",
"self.head_dim = config.head_dim",
)
function = function.replace(
"self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)",
"self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)",
)
return function
pass
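
To make the effect of the third replacement concrete, a small runnable sketch (not unsloth code; dimensions assumed from Mistral Nemo 12B) of why o_proj must be resized: its input is num_heads * head_dim = 4096, not hidden_size = 5120.

# Illustration only: the o_proj shape before and after the patch, on Nemo-like dims.
import torch.nn as nn

hidden_size, num_heads, head_dim = 5120, 32, 128
old_o_proj = nn.Linear(hidden_size, hidden_size, bias = False)           # 5120 -> 5120, mismatches the checkpoint
new_o_proj = nn.Linear(num_heads * head_dim, hidden_size, bias = False)  # 4096 -> 5120, matches the (5120, 4096) weights
print(old_o_proj.weight.shape, new_o_proj.weight.shape)   # torch.Size([5120, 5120]) torch.Size([5120, 4096])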


class FastMistralModel(FastLlamaModel):

@staticmethod
@@ -280,7 +298,9 @@ def pre_patch():
scaled_rope_module = LlamaLinearScalingRotaryEmbedding,
attention_module = MistralAttention,
)
if init_name is not None:
# Just for Mistral Nemo models!
function = patch_mistral_nemo_attention(function)
if True:#init_name is not None:
exec(function, globals())
MistralAttention.__init__ = eval(init_name)
pass
