Commit

Update gemma2.py
danielhanchen committed Jul 31, 2024
1 parent 86b71c4 commit cf1054c
Showing 1 changed file with 6 additions and 1 deletion.
unsloth/models/gemma2.py (7 changes: 6 additions & 1 deletion)
@@ -139,14 +139,19 @@ def Gemma2Attention_fast_forward(
             window = (-1, -1) if (kv_seq_len <= sw) else (sw, sw)
         pass
 
+        # FA uses 1 / sqrt for softmax_scale!
+        if not hasattr(self, "_flash_attention_softmax_scale"):
+            self._flash_attention_softmax_scale = 1.0 / self.config.query_pre_attn_scalar**0.5
+        pass
+
         Q = Q.transpose(1, 2)
         K = K.transpose(1, 2)
         V = V.transpose(1, 2)
         A = flash_attn_func(
             Q, K, V,
             causal = True,
             softcap = self.config.attn_logit_softcapping,
-            softmax_scale = self.config.query_pre_attn_scalar,
+            softmax_scale = self._flash_attention_softmax_scale,
             window_size = window,
         )
         A = A.reshape(bsz, q_len, n_heads*head_dim)
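
Why the change matters: flash_attn_func applies softmax_scale multiplicatively, computing softmax(softmax_scale * Q @ K^T) @ V, so the value passed must already be the reciprocal square root of query_pre_attn_scalar. The deleted line passed the raw scalar instead. Below is a minimal sketch of the size of the discrepancy, assuming query_pre_attn_scalar = 256 (the value in the Gemma-2 2B/9B configs; an illustrative assumption, not read from this diff):

    import math

    # Assumed example value: Gemma-2 configs set query_pre_attn_scalar to the
    # head dimension (256 for the 2B and 9B models); check your own config.
    query_pre_attn_scalar = 256

    old_scale = query_pre_attn_scalar               # what the deleted line passed
    new_scale = 1.0 / query_pre_attn_scalar ** 0.5  # what the commit passes (0.0625)

    assert math.isclose(new_scale, 1.0 / math.sqrt(256))

    # The old value inflated attention logits by a factor of
    # query_pre_attn_scalar ** 1.5 relative to the intended 1/sqrt scaling.
    print(old_scale / new_scale)  # 4096.0

The hasattr guard caches the reciprocal square root on the attention module, so the square root is computed once per layer rather than on every forward pass.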
