Commit: Update gemma2.py
danielhanchen committed Aug 10, 2024
1 parent 53b22c4 commit dca3cc8
Showing 1 changed file with 2 additions and 18 deletions.
20 changes: 2 additions & 18 deletions unsloth/models/gemma2.py
@@ -155,25 +155,9 @@ def Gemma2Attention_fast_forward(
             window_size = window,
         )
         A = A.reshape(bsz, q_len, n_heads*head_dim)
-    elif attention_mask is None:
-        A = slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, kv_seq_len)
     else:
-        # Grouped query attention
-        if n_groups != 1:
-            K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim)
-            V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim)
-            K = K.reshape(bsz, n_heads, kv_seq_len, head_dim)
-            V = V.reshape(bsz, n_heads, kv_seq_len, head_dim)
-        pass
-        # Must be contiguous or else results are False!
-        # https://github.com/pytorch/pytorch/issues/112577
-        Q, K, V = Q.contiguous(), K.contiguous(), V.contiguous()
-        # Needs (batch_size, n_heads, seq_len, head_dim)
-        # is_casual and attention_mask must not be both set!
-        A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, is_causal = False)
-        # Go back to (batch_size, seq_len, n_heads, head_dim)
-        A = A.transpose(1, 2).contiguous()
-        A = A.reshape(bsz, q_len, n_heads*head_dim)
+        mask = causal_mask if attention_mask is None else attention_mask
+        A = slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, kv_seq_len)
     pass
     A = self.apply_o(self, A)
     return A, None, past_key_value
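
The net effect of the hunk: both the unmasked case (previously its own elif branch) and the masked case (previously a grouped-query SDPA fallback) now go through slow_attention_softcapping, with mask picking the causal mask when no attention mask is supplied. torch's scaled_dot_product_attention has no hook for Gemma-2's tanh logit soft-capping, so an eager path is needed whenever the soft-capping flash-attention kernel is unavailable. As a rough illustration only, below is a minimal sketch of such an eager soft-capped attention path; the function name, signature, default softcap value, and scaling are assumptions and do not reproduce unsloth's actual slow_attention_softcapping.

import torch
import torch.nn.functional as F

def softcapped_attention_sketch(Q, K, V, mask=None, softcap=50.0, scale=None):
    """Illustrative eager soft-capped attention (not unsloth's implementation).

    Q: (bsz, n_heads, q_len, head_dim)
    K, V: (bsz, n_kv_heads, kv_len, head_dim)
    mask: additive float mask broadcastable to (bsz, n_heads, q_len, kv_len), or None
    """
    bsz, n_heads, q_len, head_dim = Q.shape
    n_kv_heads, kv_len = K.shape[1], K.shape[2]
    n_groups = n_heads // n_kv_heads
    if n_groups != 1:
        # Grouped-query attention: repeat each KV head across its query group
        K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_len, head_dim)
        V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_len, head_dim)
        K = K.reshape(bsz, n_heads, kv_len, head_dim)
        V = V.reshape(bsz, n_heads, kv_len, head_dim)
    scale = head_dim ** -0.5 if scale is None else scale
    scores = (Q @ K.transpose(-2, -1)) * scale
    # Gemma-2 style logit soft-capping: squash scores into (-softcap, softcap)
    scores = softcap * torch.tanh(scores / softcap)
    if mask is not None:
        scores = scores + mask  # 0 where attended, large negative where masked
    probs = F.softmax(scores, dim=-1, dtype=torch.float32).to(Q.dtype)
    A = probs @ V  # (bsz, n_heads, q_len, head_dim)
    return A.transpose(1, 2).reshape(bsz, q_len, n_heads * head_dim)

The GQA expand/reshape mirrors the lines removed from the SDPA branch above; the only functional difference from plain eager attention is the tanh cap applied to the scores before the softmax, which is what SDPA cannot express.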
