diff --git a/horde_worker_regen/amd_go_fast/amd_go_fast.py b/horde_worker_regen/amd_go_fast/amd_go_fast.py
index 1d2404d7..ba8190ca 100644
--- a/horde_worker_regen/amd_go_fast/amd_go_fast.py
+++ b/horde_worker_regen/amd_go_fast/amd_go_fast.py
@@ -5,15 +5,15 @@
 
 
 def _patch_sdpa(
-    patch_func: Callable[[Tensor, Tensor, Tensor, Tensor | None, float, bool, float | None], Tensor],
+    patch_func: Callable[[Tensor, Tensor, Tensor, Tensor | None, float, bool, float | None, bool], Tensor],
 ):
-    """(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None)"""
+    """(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, enable_gqa=False)"""
 
     torch_sdpa = torch.nn.functional.scaled_dot_product_attention
 
-    def sdpa_hijack_flash(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
+    def sdpa_hijack_flash(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, enable_gqa=False):
         try:
-            return patch_func(query, key, value, attn_mask, dropout_p, is_causal, scale)
+            return patch_func(query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa)
         except Exception:
             hidden_states = torch_sdpa(
                 query=query,
@@ -23,6 +23,7 @@ def sdpa_hijack_flash(query, key, value, attn_mask=None, dropout_p=0.0, is_causa
                 dropout_p=dropout_p,
                 is_causal=is_causal,
                 scale=scale,
+                enable_gqa=enable_gqa,
             )
         return hidden_states
 
@@ -32,7 +33,7 @@ def sdpa_hijack_flash(query, key, value, attn_mask=None, dropout_p=0.0, is_causa
 try:
     from flash_attn import flash_attn_func
 
-    def sdpa_hijack_flash(q, k, v, m, p, c, s):
+    def sdpa_hijack_flash(q, k, v, m, p, c, s, g):
         assert m is None
         result = flash_attn_func(
             q=q.transpose(1, 2),
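
For context, here is a minimal, self-contained sketch (not the project's code) of the wrapper pattern this diff updates: the patch function gets first shot at every scaled_dot_product_attention call, and on any failure the wrapper falls back to the stock PyTorch kernel, now forwarding the new enable_gqa flag (added in PyTorch 2.5) as well. The _toy_patch name, the tensor shapes, and the grouped-query check at the end are illustrative assumptions, not part of the patch.

import torch

_stock_sdpa = torch.nn.functional.scaled_dot_product_attention

def _toy_patch(query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa):
    # Stand-in for the flash_attn path: refuse anything it cannot handle so
    # the wrapper below falls back to the stock kernel.
    assert attn_mask is None
    raise RuntimeError("pretend the fast kernel is unavailable here")

def _wrapper(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, enable_gqa=False):
    try:
        return _toy_patch(query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa)
    except Exception:
        # Fallback mirrors the diff: every keyword, including enable_gqa,
        # is passed through to the original implementation.
        return _stock_sdpa(
            query=query, key=key, value=value, attn_mask=attn_mask,
            dropout_p=dropout_p, is_causal=is_causal, scale=scale, enable_gqa=enable_gqa,
        )

# Grouped-query call: 8 query heads sharing 2 key/value heads (requires PyTorch >= 2.5).
q = torch.randn(1, 8, 16, 64)
k = torch.randn(1, 2, 16, 64)
v = torch.randn(1, 2, 16, 64)
out = _wrapper(q, k, v, enable_gqa=True)
print(out.shape)  # torch.Size([1, 8, 16, 64])

Without the forwarded keyword, callers that pass enable_gqa=True would hit a TypeError inside the hijacked function before the fallback could run, which is what the signature changes in this patch avoid.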