From cf1054c9bcc74bd659739f34444f46d8c79837cf Mon Sep 17 00:00:00 2001
From: Daniel Han
Date: Tue, 30 Jul 2024 23:11:23 -0700
Subject: [PATCH] Update gemma2.py

---
 unsloth/models/gemma2.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/unsloth/models/gemma2.py b/unsloth/models/gemma2.py
index 2191a99c..d2bfb789 100644
--- a/unsloth/models/gemma2.py
+++ b/unsloth/models/gemma2.py
@@ -139,6 +139,11 @@ def Gemma2Attention_fast_forward(
             window = (-1, -1) if (kv_seq_len <= sw) else (sw, sw)
         pass
 
+        # FA uses 1 / sqrt for softmax_scale!
+        if not hasattr(self, "_flash_attention_softmax_scale"):
+            self._flash_attention_softmax_scale = 1.0 / self.config.query_pre_attn_scalar**0.5
+        pass
+
         Q = Q.transpose(1, 2)
         K = K.transpose(1, 2)
         V = V.transpose(1, 2)
@@ -146,7 +151,7 @@ def Gemma2Attention_fast_forward(
             Q, K, V, causal = True,
             softcap = self.config.attn_logit_softcapping,
-            softmax_scale = self.config.query_pre_attn_scalar,
+            softmax_scale = self._flash_attention_softmax_scale,
             window_size = window,
         )
         A = A.reshape(bsz, q_len, n_heads*head_dim)
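
Note (not part of the patch): the fix matters because flash_attn_func's softmax_scale multiplies Q @ K^T directly, whereas Gemma-2's query_pre_attn_scalar is the value whose square root the logits are meant to be divided by. Passing the raw scalar therefore multiplies the logits instead of dividing them. Below is a minimal sketch of the corrected scaling, assuming flash-attn >= 2.6 (for softcap support), a CUDA device, and an illustrative query_pre_attn_scalar of 256; the shapes, softcap value, and dtype are assumptions for demonstration, not taken from the patch.

    # Sketch: how the corrected softmax_scale is computed and passed to flash_attn_func.
    import math
    import torch
    from flash_attn import flash_attn_func

    # Gemma-2 configs expose query_pre_attn_scalar (256 is an illustrative value).
    query_pre_attn_scalar = 256.0
    # FlashAttention expects the multiplicative scale, i.e. 1 / sqrt(query_pre_attn_scalar),
    # not the scalar itself.
    softmax_scale = 1.0 / math.sqrt(query_pre_attn_scalar)

    bsz, seqlen, n_heads, head_dim = 1, 8, 4, 256
    # flash_attn_func takes (batch, seqlen, n_heads, head_dim) tensors in fp16/bf16 on CUDA,
    # which is why the patch transposes Q, K, V before the call.
    Q = torch.randn(bsz, seqlen, n_heads, head_dim, dtype=torch.float16, device="cuda")
    K = torch.randn_like(Q)
    V = torch.randn_like(Q)

    # Passing query_pre_attn_scalar (256) directly would scale the logits up 256x
    # instead of dividing them by 16. The softcap value of 50.0 is illustrative.
    A = flash_attn_func(Q, K, V, causal=True, softmax_scale=softmax_scale, softcap=50.0)

Caching the result on self (as the patch does with _flash_attention_softmax_scale) simply avoids recomputing the square root on every forward pass.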