Skip to content

Commit

Permalink
include enable_gqa
Browse files Browse the repository at this point in the history
  • Loading branch information
HPPinata committed Dec 12, 2024
1 parent a088ef6 commit 682a81a
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions horde_worker_regen/amd_go_fast/amd_go_fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@


def _patch_sdpa(
patch_func: Callable[[Tensor, Tensor, Tensor, Tensor | None, float, bool, float | None], Tensor],
patch_func: Callable[[Tensor, Tensor, Tensor, Tensor | None, float, bool, float | None, bool], Tensor],
):
"""(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None)"""
"""(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, enable_gqa=False)"""

torch_sdpa = torch.nn.functional.scaled_dot_product_attention

def sdpa_hijack_flash(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
def sdpa_hijack_flash(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, enable_gqa=False):
try:
return patch_func(query, key, value, attn_mask, dropout_p, is_causal, scale)
return patch_func(query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa)
except Exception:
hidden_states = torch_sdpa(
query=query,
Expand All @@ -23,6 +23,7 @@ def sdpa_hijack_flash(query, key, value, attn_mask=None, dropout_p=0.0, is_causa
dropout_p=dropout_p,
is_causal=is_causal,
scale=scale,
enable_gqa=enable_gqa,
)
return hidden_states

Expand All @@ -32,7 +33,7 @@ def sdpa_hijack_flash(query, key, value, attn_mask=None, dropout_p=0.0, is_causa
try:
from flash_attn import flash_attn_func

def sdpa_hijack_flash(q, k, v, m, p, c, s):
def sdpa_hijack_flash(q, k, v, m, p, c, s, g):
assert m is None
result = flash_attn_func(
q=q.transpose(1, 2),
Expand Down

0 comments on commit 682a81a

Please sign in to comment.