Skip to content

Commit

Permalink
try/catch instead of hard code
Browse files Browse the repository at this point in the history
  • Loading branch information
HPPinata committed Dec 12, 2024
1 parent 7e9c271 commit a088ef6
Showing 1 changed file with 51 additions and 37 deletions.
88 changes: 51 additions & 37 deletions horde_worker_regen/amd_go_fast/amd_go_fast.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,54 @@
import torch
from torch import Tensor
from typing import Callable
from loguru import logger

if "AMD" in torch.cuda.get_device_name() or "Radeon" in torch.cuda.get_device_name():
try: # this import is handled via script, skipping it in mypy. If this fails somehow the module will simply not run.
from flash_attn import flash_attn_func # type: ignore

sdpa = torch.nn.functional.scaled_dot_product_attention

def sdpa_hijack(
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, enable_gqa=False
):
if query.shape[3] <= 256 and attn_mask is None and query.dtype != torch.float32:
hidden_states = flash_attn_func(
q=query.transpose(1, 2),
k=key.transpose(1, 2),
v=value.transpose(1, 2),
dropout_p=dropout_p,
causal=is_causal,
softmax_scale=scale,
).transpose(1, 2)
else:
hidden_states = sdpa(
query=query,
key=key,
value=value,
attn_mask=attn_mask,
dropout_p=dropout_p,
is_causal=is_causal,
scale=scale,
enable_gqa=enable_gqa,
)
return hidden_states

torch.nn.functional.scaled_dot_product_attention = sdpa_hijack
logger.debug("# # # AMD GO FAST # # #")
except ImportError as e:
logger.debug(f"# # # AMD GO SLOW {e} # # #")
else:
logger.debug(f"# # # AMD GO SLOW Could not detect AMD GPU from: {torch.cuda.get_device_name()} # # #")

def _patch_sdpa(
patch_func: Callable[[Tensor, Tensor, Tensor, Tensor | None, float, bool, float | None], Tensor],
):
"""(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None)"""

torch_sdpa = torch.nn.functional.scaled_dot_product_attention

def sdpa_hijack_flash(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
try:
return patch_func(query, key, value, attn_mask, dropout_p, is_causal, scale)
except Exception:
hidden_states = torch_sdpa(
query=query,
key=key,
value=value,
attn_mask=attn_mask,
dropout_p=dropout_p,
is_causal=is_causal,
scale=scale,
)
return hidden_states

torch.nn.functional.scaled_dot_product_attention = sdpa_hijack_flash


try:
from flash_attn import flash_attn_func

def sdpa_hijack_flash(q, k, v, m, p, c, s):
assert m is None
result = flash_attn_func(
q=q.transpose(1, 2),
k=k.transpose(1, 2),
v=v.transpose(1, 2),
dropout_p=p,
softmax_scale=s if s else q.shape[-1] ** (-0.5),
causal=c,
)
assert isinstance(result, Tensor)
return result.transpose(1, 2)

_patch_sdpa(sdpa_hijack_flash)
logger.debug("# # # Patched SDPA with Flash Attention # # #")
except ImportError as e:
logger.debug(f"# # # Could not load Flash Attention for hijack: {e} # # #")

NODE_CLASS_MAPPINGS = {}
NODE_DISPLAY_NAME_MAPPINGS = {}

0 comments on commit a088ef6

Please sign in to comment.