diff --git a/horde_worker_regen/amd_go_fast/amd_go_fast.py b/horde_worker_regen/amd_go_fast/amd_go_fast.py
index c81d3c16..1d2404d7 100644
--- a/horde_worker_regen/amd_go_fast/amd_go_fast.py
+++ b/horde_worker_regen/amd_go_fast/amd_go_fast.py
@@ -1,40 +1,54 @@
 import torch
+from torch import Tensor
+from typing import Callable
 from loguru import logger
 
-if "AMD" in torch.cuda.get_device_name() or "Radeon" in torch.cuda.get_device_name():
-    try:  # this import is handled via script, skipping it in mypy. If this fails somehow the module will simply not run.
-        from flash_attn import flash_attn_func  # type: ignore
-
-        sdpa = torch.nn.functional.scaled_dot_product_attention
-
-        def sdpa_hijack(
-            query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, enable_gqa=False
-        ):
-            if query.shape[3] <= 256 and attn_mask is None and query.dtype != torch.float32:
-                hidden_states = flash_attn_func(
-                    q=query.transpose(1, 2),
-                    k=key.transpose(1, 2),
-                    v=value.transpose(1, 2),
-                    dropout_p=dropout_p,
-                    causal=is_causal,
-                    softmax_scale=scale,
-                ).transpose(1, 2)
-            else:
-                hidden_states = sdpa(
-                    query=query,
-                    key=key,
-                    value=value,
-                    attn_mask=attn_mask,
-                    dropout_p=dropout_p,
-                    is_causal=is_causal,
-                    scale=scale,
-                    enable_gqa=enable_gqa,
-                )
-            return hidden_states
-
-        torch.nn.functional.scaled_dot_product_attention = sdpa_hijack
-        logger.debug("# # # AMD GO FAST # # #")
-    except ImportError as e:
-        logger.debug(f"# # # AMD GO SLOW {e} # # #")
-else:
-    logger.debug(f"# # # AMD GO SLOW Could not detect AMD GPU from: {torch.cuda.get_device_name()} # # #")
+
+def _patch_sdpa(
+    patch_func: Callable[[Tensor, Tensor, Tensor, Tensor | None, float, bool, float | None], Tensor],
+):
+    """(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None)"""
+
+    torch_sdpa = torch.nn.functional.scaled_dot_product_attention
+
+    def sdpa_hijack_flash(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
+        try:
+            return patch_func(query, key, value, attn_mask, dropout_p, is_causal, scale)
+        except Exception:
+            hidden_states = torch_sdpa(
+                query=query,
+                key=key,
+                value=value,
+                attn_mask=attn_mask,
+                dropout_p=dropout_p,
+                is_causal=is_causal,
+                scale=scale,
+            )
+            return hidden_states
+
+    torch.nn.functional.scaled_dot_product_attention = sdpa_hijack_flash
+
+
+try:
+    from flash_attn import flash_attn_func
+
+    def sdpa_hijack_flash(q, k, v, m, p, c, s):
+        assert m is None
+        result = flash_attn_func(
+            q=q.transpose(1, 2),
+            k=k.transpose(1, 2),
+            v=v.transpose(1, 2),
+            dropout_p=p,
+            softmax_scale=s if s else q.shape[-1] ** (-0.5),
+            causal=c,
+        )
+        assert isinstance(result, Tensor)
+        return result.transpose(1, 2)
+
+    _patch_sdpa(sdpa_hijack_flash)
+    logger.debug("# # # Patched SDPA with Flash Attention # # #")
+except ImportError as e:
+    logger.debug(f"# # # Could not load Flash Attention for hijack: {e} # # #")
+
+NODE_CLASS_MAPPINGS = {}
+NODE_DISPLAY_NAME_MAPPINGS = {}
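
For context, a minimal sketch (not part of the diff) of how the patched entry point behaves once this module is imported. The tensor shapes, dtype, and CUDA setup below are assumptions for illustration; `flash_attn` must be installed for the fast path to be used. When an `attn_mask` is supplied, the flash hijack's `assert m is None` raises and the wrapper silently falls back to the stock torch SDPA.

```python
import torch

# Importing the module applies the monkey-patch as a side effect.
import horde_worker_regen.amd_go_fast.amd_go_fast  # noqa: F401

# Illustrative shapes: (batch, heads, seq_len, head_dim), fp16 on CUDA.
q = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
k = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)

# No mask: the hijack routes through flash_attn_func.
out_fast = torch.nn.functional.scaled_dot_product_attention(q, k, v)

# With a mask, the flash path raises and the wrapper falls back to torch's SDPA.
mask = torch.ones(1, 8, 128, 128, device="cuda", dtype=torch.bool)
out_fallback = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)
```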