Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
include enable_gqa
Browse files Browse the repository at this point in the history
HPPinata authored Nov 24, 2024
1 parent a54bac3 commit a5201f3
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions horde_worker_regen/amd_go_fast/amd_go_fast.py
Original file line number Diff line number Diff line change
@@ -5,15 +5,15 @@


def _patch_sdpa(
patch_func: Callable[[Tensor, Tensor, Tensor, Tensor | None, float, bool, float | None], Tensor],
patch_func: Callable[[Tensor, Tensor, Tensor, Tensor | None, float, bool, float | None, bool], Tensor],
):
"""(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None)"""
"""(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, enable_gqa=False)"""

torch_sdpa = torch.nn.functional.scaled_dot_product_attention

def sdpa_hijack_flash(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
def sdpa_hijack_flash(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, enable_gqa=False):
try:
return patch_func(query, key, value, attn_mask, dropout_p, is_causal, scale)
return patch_func(query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa)
except Exception:
hidden_states = torch_sdpa(
query=query,
@@ -23,6 +23,7 @@ def sdpa_hijack_flash(query, key, value, attn_mask=None, dropout_p=0.0, is_causa
dropout_p=dropout_p,
is_causal=is_causal,
scale=scale,
enable_gqa=enable_gqa,
)
return hidden_states

@@ -32,7 +33,7 @@ def sdpa_hijack_flash(query, key, value, attn_mask=None, dropout_p=0.0, is_causa
try:
from flash_attn import flash_attn_func

def sdpa_hijack_flash(q, k, v, m, p, c, s):
def sdpa_hijack_flash(q, k, v, m, p, c, s, g):
assert m is None
result = flash_attn_func(
q=q.transpose(1, 2),

0 comments on commit a5201f3

Please sign in to comment.