
Allow flash attention test in Windows
tianleiwu committed Jun 10, 2024
1 parent 415c5e1 commit 6adb2cd
Showing 1 changed file with 31 additions and 17 deletions.
48 changes: 31 additions & 17 deletions onnxruntime/test/python/transformers/test_flash_attn_cuda.py
@@ -20,7 +20,6 @@
from bert_padding import pad_input, unpad_input
from einops import rearrange, repeat
from onnx import TensorProto, helper
-from rotary_flash import apply_rotary_emb

from onnxruntime import InferenceSession, OrtValue, SessionOptions

@@ -737,6 +736,12 @@ def mha_func(q, k, v, config):
return output


+def rotary_options_for_current_os():
+    # Reference implementation of rotary uses triton, which is not available on Windows.
+    # So we only test rotary on Linux right now.
+    return [(False, False)] if platform.system() != "Linux" else [(True, False), (True, True), (False, False)]


def gqa_prompt_func(
q,
k,
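
The helper above gates the rotary test matrix by operating system: off Linux, only the no-rotary case survives. A minimal, self-contained sketch of the same gate (illustrative, not part of the commit):

    import platform

    def rotary_options_for_current_os():
        # triton has no Windows build, so only (rotary=False, interleaved=False)
        # is exercised outside Linux.
        if platform.system() != "Linux":
            return [(False, False)]
        return [(True, False), (True, True), (False, False)]

    for rotary, rotary_interleaved in rotary_options_for_current_os():
        print(f"rotary={rotary}, interleaved={rotary_interleaved}")
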
@@ -1161,6 +1166,13 @@ def parity_check_mha(
return all_close


+def rotary_embedding(*args, **kwargs):
+    # Use local import since triton is not available on Windows.
+    from rotary_flash import apply_rotary_emb
+
+    return apply_rotary_emb(*args, **kwargs)


def parity_check_gqa_prompt(
config,
causal=True,
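
The rotary_embedding wrapper introduced in this hunk replaces the module-level rotary_flash import, deferring it to call time so that merely importing the test module works on Windows. A sketch of the same deferred-import pattern with an explicit failure message (the try/except is our addition, not the commit's):

    def rotary_embedding_checked(*args, **kwargs):
        # Deferred import: rotary_flash pulls in triton, which is unavailable
        # on Windows, so the import runs only when a rotary path is exercised.
        try:
            from rotary_flash import apply_rotary_emb
        except ImportError as exc:
            raise RuntimeError("rotary reference requires rotary_flash/triton (Linux-only)") from exc
        return apply_rotary_emb(*args, **kwargs)
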
@@ -1250,11 +1262,12 @@ def parity_check_gqa_prompt(
angle = torch.rand(config.buffer_sequence_length, rotary_dim // 2, device="cuda") * 2 * math.pi
cos = torch.cos(angle).to(dtype=torch.float16)
sin = torch.sin(angle).to(dtype=torch.float16)

if causal or local:
-q_ro = apply_rotary_emb(q, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved)
+q_ro = rotary_embedding(q, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved)
else:
q_ro = rearrange(
-apply_rotary_emb(
+rotary_embedding(
rearrange(q, "b s h d -> b 1 (s h) d"),
cos,
sin,
@@ -1265,7 +1278,7 @@
s=config.q_sequence_length,
)
# q_ro = q
-k_ro = apply_rotary_emb(new_k, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved)
+k_ro = rotary_embedding(new_k, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved)
else:
cos, sin = None, None
q_ro, k_ro = q, new_k
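
An aside on the unchanged reference code around these call sites: in the non-causal branch, q is rearranged from "b s h d" to "b 1 (s h) d" before the rotary call, so the sequence axis collapses to length 1 and every query vector receives the rotation for the same position offset. A small einops round trip illustrating the reshape (toy shapes only):

    import torch
    from einops import rearrange

    b, s, h, d = 2, 3, 4, 8
    q = torch.randn(b, s, h, d)
    # seqlen becomes 1 and (s*h) becomes the head axis, so one shared
    # rotary position applies to all query vectors.
    q_flat = rearrange(q, "b s h d -> b 1 (s h) d")
    assert q_flat.shape == (b, 1, s * h, d)
    q_back = rearrange(q_flat, "b 1 (s h) d -> b s h d", s=s)
    assert torch.equal(q, q_back)
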
@@ -1454,11 +1467,12 @@ def parity_check_gqa_prompt_no_buff(
angle = torch.rand(config.kv_sequence_length, rotary_dim // 2, device="cuda") * 2 * math.pi
cos = torch.cos(angle).to(dtype=torch.float16)
sin = torch.sin(angle).to(dtype=torch.float16)

if causal or local:
-q_ro = apply_rotary_emb(q, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved)
+q_ro = rotary_embedding(q, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved)
else:
q_ro = rearrange(
-apply_rotary_emb(
+rotary_embedding(
rearrange(q, "b s h d -> b 1 (s h) d"),
cos,
sin,
@@ -1469,7 +1483,7 @@
s=config.q_sequence_length,
)
# q_ro = q
-k_ro = apply_rotary_emb(k_cache_ref, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved)
+k_ro = rotary_embedding(k_cache_ref, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved)
else:
cos, sin = None, None
q_ro, k_ro = q, k_cache_ref
@@ -1654,10 +1668,10 @@ def parity_check_gqa_past(
cos = torch.cos(angle).to(dtype=torch.float16)
sin = torch.sin(angle).to(dtype=torch.float16)
if causal or local:
-q_ro = apply_rotary_emb(q, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved)
+q_ro = rotary_embedding(q, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved)
else:
q_ro = rearrange(
-apply_rotary_emb(
+rotary_embedding(
rearrange(q, "b s h d -> b 1 (s h) d"),
cos,
sin,
@@ -1668,7 +1682,7 @@
s=config.sequence_length,
)
# q_ro = q
-k_ro = apply_rotary_emb(new_k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved)
+k_ro = rotary_embedding(new_k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved)
else:
cos, sin = None, None
q_ro, k_ro = q, new_k
@@ -1863,10 +1877,10 @@ def parity_check_gqa_past_no_buff(
cos = torch.cos(angle).to(dtype=torch.float16)
sin = torch.sin(angle).to(dtype=torch.float16)
if causal or local:
-q_ro = apply_rotary_emb(q, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved)
+q_ro = rotary_embedding(q, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved)
else:
q_ro = rearrange(
-apply_rotary_emb(
+rotary_embedding(
rearrange(q, "b s h d -> b 1 (s h) d"),
cos,
sin,
@@ -1877,7 +1891,7 @@
s=config.sequence_length,
)
# q_ro = q
-k_ro = apply_rotary_emb(new_k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved)
+k_ro = rotary_embedding(new_k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved)
else:
cos, sin = None, None
q_ro, k_ro = q, new_k
@@ -2063,7 +2077,7 @@ def test_gqa_no_past_memory_efficient(self):
for sq, skv in seqs:
for n, n2 in num_h:
for h in h_sizes:
-for rotary, rotary_interleaved in [(True, False), (True, True), (False, False)]:
+for rotary, rotary_interleaved in rotary_options_for_current_os():
for packed in [False, True]:
config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h)
all_close = parity_check_gqa_prompt(
@@ -2121,7 +2135,7 @@ def test_gqa_no_past_flash_attention(self):
for n, n2 in num_h:
for h in h_sizes:
for local in [False, True]:
-for rotary, rotary_interleaved in [(True, False), (True, True), (False, False)]:
+for rotary, rotary_interleaved in rotary_options_for_current_os():
for packed in [False, True]:
config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h)
all_close = parity_check_gqa_prompt(
@@ -2176,7 +2190,7 @@ def test_gqa_past_memory_efficient(self):
for s, s2 in seqs:
for n, n2 in num_h:
for h in h_sizes:
-for rotary, rotary_interleaved in [(True, False), (True, True), (False, False)]:
+for rotary, rotary_interleaved in rotary_options_for_current_os():
for packed in [False, True]:
sp = random.randint(1, s2 - s) if s2 - s > 0 else 0
config = Config(b, s, s2, sp, n, n2, h)
@@ -2235,7 +2249,7 @@ def test_gqa_past_flash_attention(self):
for n, n2 in num_h:
for h in h_sizes:
for local in [False, True]:
-for rotary, rotary_interleaved in [(True, False), (True, True), (False, False)]:
+for rotary, rotary_interleaved in rotary_options_for_current_os():
for packed in [False, True]:
sp = random.randint(1, s2 - s) if s2 - s > 0 else 0
config = Config(b, s, s2, sp, n, n2, h)
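
Design note: the commit trims the option list at its source, which keeps these unittest-style nested loops untouched. A hypothetical pytest alternative (not what the commit does) would keep all cases visible and mark the rotary ones as skipped off Linux:

    import platform

    import pytest

    linux_only = pytest.mark.skipif(
        platform.system() != "Linux",
        reason="rotary reference needs triton, which is Linux-only",
    )

    ROTARY_OPTIONS = [
        pytest.param(True, False, marks=linux_only),
        pytest.param(True, True, marks=linux_only),
        pytest.param(False, False),
    ]

    @pytest.mark.parametrize("rotary,rotary_interleaved", ROTARY_OPTIONS)
    def test_rotary_option(rotary, rotary_interleaved):
        assert isinstance(rotary, bool) and isinstance(rotary_interleaved, bool)

The trade-off: skipped cases would show up in test reports, at the cost of a parametrization style the current loops do not use.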
