From 480f2ef31c7c19e216bf21f57d08fa1895e5fc73 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 3 Sep 2024 13:52:12 -0700 Subject: [PATCH] Gemma faster inference (#987) * Update _utils.py * Update _utils.py * Update _utils.py * Update _utils.py * Update _utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * update token retrieval logic (#952) * Fix DPO (#947) * Update _utils.py * Update _utils.py * Update _utils.py * Update _utils.py * Update _utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * update hf token retrieval logic --------- Co-authored-by: Daniel Han * Update llama.py * get_token * Update README.md * Update gemma2.py * Update rms_layernorm.py * synchronize * Update gemma2.py * Update rms_layernorm.py * Update rms_layernorm.py * Update rms_layernorm.py * layernorm * Update rms_layernorm.py * Update gemma2.py * Update rms_layernorm.py * Update rms_layernorm.py * revert * Gemma * Update rms_layernorm.py * Update rms_layernorm.py * Update rms_layernorm.py * Update rms_layernorm.py * Update rms_layernorm.py * Update rms_layernorm.py * Update rms_layernorm.py * Update rms_layernorm.py * Update rms_layernorm.py * Update rms_layernorm.py * Update rms_layernorm.py * Update rms_layernorm.py * Update rms_layernorm.py * Update rms_layernorm.py * Update rms_layernorm.py * Update rms_layernorm.py * Update rms_layernorm.py * Update rms_layernorm.py * Update rms_layernorm.py * Update rms_layernorm.py * Update rms_layernorm.py * Update rms_layernorm.py * Update rms_layernorm.py * Update gemma2.py * Change UnslothTrainingArguments base class to SFTConfig (#979) * Cohere * Update trainer.py * Cohere * Cohere * New models * Update llama.py * Update llama.py * Update cohere.py * Update llama.py * Update cohere.py * retry * Update fast_lora.py * Update llama.py * Update fast_lora.py * Update llama.py * Update llama.py * Update cross_entropy_loss.py * _apply_lora_mlp * Update _utils.py * Gemma fixes * Update llama.py * Update flex_attention.py --------- Co-authored-by: Hafedh <70411813+not-lain@users.noreply.github.com> Co-authored-by: Tuan Pham <82665400+vTuanpham@users.noreply.github.com> --- unsloth/kernels/__init__.py | 6 ++++- unsloth/kernels/flex_attention.py | 37 +++++++++++++++++++++++++++++++ unsloth/models/_utils.py | 4 ++++ unsloth/models/gemma2.py | 6 ++++- unsloth/models/llama.py | 11 +++++++++ 5 files changed, 62 insertions(+), 2 deletions(-) diff --git a/unsloth/kernels/__init__.py b/unsloth/kernels/__init__.py index c2de979a..26f632ee 100644 --- a/unsloth/kernels/__init__.py +++ b/unsloth/kernels/__init__.py @@ -33,7 +33,11 @@ ) from .utils import fast_dequantize, fast_gemv, QUANT_STATE, fast_linear_forward, matmul_lora -from .flex_attention import HAS_FLEX_ATTENTION, slow_attention_softcapping +from .flex_attention import ( + HAS_FLEX_ATTENTION, + slow_attention_softcapping, + slow_inference_attention_softcapping, +) if HAS_FLEX_ATTENTION: from .flex_attention import ( diff --git a/unsloth/kernels/flex_attention.py b/unsloth/kernels/flex_attention.py index a992a023..9cf999e2 100644 --- a/unsloth/kernels/flex_attention.py +++ b/unsloth/kernels/flex_attention.py @@ -80,3 +80,40 @@ def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len): return A pass + +torch_matmul = torch.matmul +torch_tanh = torch.tanh +torch_nn_functional_softmax = torch.nn.functional.softmax +def slow_inference_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len): + n_heads = self.num_heads + head_dim = 
self.head_dim + n_kv_heads = self.num_key_value_heads + n_groups = self.num_key_value_groups + + # Grouped query attention + K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim) + V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim) + K = K.reshape(bsz, n_heads, q_len, head_dim) + V = V.reshape(bsz, n_heads, q_len, head_dim) + + # See https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e + # Gemma 9b should use 256 and not 224 (hs / nah). 27b uses the below + # We default to using the config file itself + # s = self.config.hidden_size // self.config.num_attention_heads + s = self.config.query_pre_attn_scalar + t = self.config.attn_logit_softcapping + + Q = Q * torch.tensor(s**-0.5, dtype = Q.dtype) # Follow Keras exactly + A = torch_matmul(Q, K.transpose(2, 3)) + + # Logit softcapping + A /= t; torch_tanh(A, out = A); A *= t; + A += causal_mask[:q_len, :q_len] + # Much slower in torch compile! + # A.masked_fill_(causal_mask[:q_len, :q_len], -float("inf")) + A = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32).to(Q.dtype) + A = torch_matmul(A, V) + A = A.transpose(1, 2).contiguous() + A = A.reshape(bsz, q_len, n_heads*head_dim) + return A +pass diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index ea9a0c53..242d234d 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -39,6 +39,8 @@ "create_boolean_mask", "torch_amp_custom_fwd", "torch_amp_custom_bwd", + "accelerate_old_send_to_device", + "accelerate_new_send_to_device", ] import torch @@ -287,6 +289,7 @@ def patch_mistral_nemo_config(config): import accelerate.utils.operations if hasattr(accelerate.utils.operations, "send_to_device") and \ accelerate.utils.operations.send_to_device.__name__ != "_fixed_send_to_device": + accelerate_old_send_to_device = accelerate.utils.operations.send_to_device from accelerate.utils.operations import * send_to_device = inspect.getsource(accelerate.utils.operations.send_to_device) send_to_device = re.sub( @@ -296,6 +299,7 @@ def patch_mistral_nemo_config(config): ).replace("def send_to_device", "def _fixed_send_to_device") exec(send_to_device) # accelerate.utils.operations.send_to_device = _fixed_send_to_device + accelerate_new_send_to_device = _fixed_send_to_device pass pass # ============================================= diff --git a/unsloth/models/gemma2.py b/unsloth/models/gemma2.py index 6858f525..218849ef 100644 --- a/unsloth/models/gemma2.py +++ b/unsloth/models/gemma2.py @@ -157,7 +157,10 @@ def Gemma2Attention_fast_forward( A = A.reshape(bsz, q_len, n_heads*head_dim) else: mask = causal_mask if attention_mask is None else attention_mask - A = slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, kv_seq_len) + fx = slow_inference_attention_softcapping \ + if "_flag_for_generation" in kwargs else \ + slow_attention_softcapping + A = fx(Q, K, V, causal_mask, self, bsz, kv_seq_len) pass A = self.apply_o(self, A) return A, None, past_key_value @@ -192,6 +195,7 @@ def Gemma2DecoderLayer_fast_forward( output_attentions=output_attentions, use_cache=use_cache, padding_mask=padding_mask, + _flag_for_generation=True, ) hidden_states = fast_rms_layernorm_inference_gemma(self.post_attention_layernorm, hidden_states, out_weight) hidden_states += residual diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 5ccf906a..3fcb8a76 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -953,6 +953,8 @@ def _CausalLM_fast_forward( if bsz == 1 and q_len == 1: 
logits = torch.mv(lm_head, hidden_states.ravel().to(lm_head.dtype)) logits = logits.unsqueeze(0).unsqueeze(0) + elif num_logits_to_keep != 0: + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :].to(lm_head.dtype)) else: logits = self.lm_head(hidden_states.to(lm_head.dtype)) pass @@ -1368,8 +1370,14 @@ def _fast_generate(*args, **kwargs): pass internal_model._flag_for_generation = True + # Must patch accelerate for Xformers + import accelerate.utils.operations + accelerate.utils.operations.send_to_device = accelerate_new_send_to_device + # For newer HF kwargs["cache_implementation"] = "dynamic" + # For num_logits_to_keep + kwargs["num_logits_to_keep"] = 1 # Remove token_type_ids kwargs.pop("token_type_ids", None) @@ -1402,6 +1410,9 @@ def _fast_generate(*args, **kwargs): pass if hasattr(internal_model, "_flag_for_generation"): del internal_model._flag_for_generation + # Return accelerate back + accelerate.utils.operations.send_to_device = accelerate_old_send_to_device + return output pass return _fast_generate
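
Notes (not part of the patch)

The flex_attention.py hunk applies logit softcapping in place with the sequence `A /= t; torch_tanh(A, out = A); A *= t`, which is the usual softcap A = t * tanh(A / t) written to reuse the existing tensor instead of allocating temporaries. A minimal standalone check of that equivalence is sketched below; the value t = 30.0 is the attn_logit_softcapping typically found in Gemma-2 configs and is illustrative here, since the patch reads it from self.config rather than hard-coding it.

    import torch

    t = 30.0                       # attn_logit_softcapping (illustrative; read from config in the patch)
    A = torch.randn(2, 8, 16, 16)  # dummy attention scores: (bsz, n_heads, q_len, q_len)

    # Reference formulation of logit softcapping
    ref = t * torch.tanh(A / t)

    # In-place formulation matching slow_inference_attention_softcapping
    B = A.clone()
    B /= t
    torch.tanh(B, out = B)
    B *= t

    assert torch.allclose(ref, B)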
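End to end, the new path is reached through ordinary generation. The sketch below is illustrative only and not part of the diff: the model name, sequence length, and generation arguments are assumptions, and it presumes the standard unsloth workflow in which the patched generate wrapper (_fast_generate above) is active during inference. That wrapper sets the internal _flag_for_generation marker, switches cache_implementation to "dynamic", and requests num_logits_to_keep = 1, so Gemma-2 attention takes the slow_inference_attention_softcapping branch and only the final position is projected through lm_head.

    from unsloth import FastLanguageModel

    # Model name and settings are assumptions for illustration.
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name     = "unsloth/gemma-2-9b-it",
        max_seq_length = 2048,
        load_in_4bit   = True,
    )

    # Enable unsloth's inference mode; generation is then expected to run through
    # the patched _fast_generate wrapper shown in the llama.py hunk above.
    FastLanguageModel.for_inference(model)

    inputs  = tokenizer("Why is the sky blue?", return_tensors = "pt").to(model.device)
    outputs = model.generate(**inputs, max_new_tokens = 64, use_cache = True)
    print(tokenizer.batch_decode(outputs, skip_special_tokens = True)[0])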