# run sparse attention test sequentially (#21808)
### Description

For some reason, running the SparseAttention tests in parallel causes random failures in the CI pipeline, possibly due to running out of memory when too many tests execute at once.

This change runs those tests sequentially.
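To illustrate the mechanism, here is a minimal, self-contained sketch (the helper and class names below are placeholders for illustration, not the actual test code): `parameterized.expand` turns every config into a separate test that a parallel runner can schedule concurrently, while a plain loop inside one test method keeps all configs on a single worker.

```python
import unittest


def get_test_cases(provider: str):
    """Placeholder for the real test-case generators in test_sparse_attention.py."""
    return [f"{provider}-case-{i}" for i in range(3)]


class ParallelFriendlyStyle(unittest.TestCase):
    # Before: parameterized.expand turns every config into its own test
    # method, e.g.:
    #
    #     @parameterized.expand(get_test_cases("CUDAExecutionProvider"))
    #     def test_one_config(self, config):
    #         self.run_one(config)
    #
    # A parallel runner sees many independent tests and may execute
    # several memory-hungry configs at the same time.
    pass


class SequentialStyle(unittest.TestCase):
    # After: a single test method loops over all configs, so the runner
    # sees one test and the configs always execute one after another.
    def test_all_configs(self):
        for config in get_test_cases("CUDAExecutionProvider"):
            self.run_one(config)

    def run_one(self, config):
        # Stand-in for run_one_relevance_test in the real file.
        self.assertTrue(config)


if __name__ == "__main__":
    unittest.main()
```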
tianleiwu authored Aug 22, 2024
1 parent c0b68e7 commit 44a3923
Showing 1 changed file with 18 additions and 26 deletions: onnxruntime/test/python/transformers/test_sparse_attention.py
```diff
@@ -13,7 +13,6 @@
 import torch
 from benchmark_mha import InputFormats
 from onnx import TensorProto, helper
-from parameterized import parameterized
 from torch import Tensor
 
 from onnxruntime import InferenceSession, SessionOptions, get_available_providers
@@ -929,43 +928,36 @@ def get_test_cases(provider: str, has_past_kv: bool, comprehensive: bool, do_rot
 
 
 class TestSparseAttention(unittest.TestCase):
 
     @unittest.skipUnless(has_cuda_support(), "cuda not available")
-    def test_sparse_attention(self):
+    def test_sparse_attention_cuda(self):
         major, minor = torch.cuda.get_device_capability()
         sm = major * 10 + minor
         self.run_relevance_test(sm)
 
-    @parameterized.expand(get_simple_test_case("CPUExecutionProvider", True), skip_on_empty=True)
-    def test_simple_token_cpu(self, config: SparseAttentionConfig):
-        self.run_one_relevance_test(config)
+        for config in get_test_cases("CUDAExecutionProvider", True, comprehensive_mode):
+            self.run_one_relevance_test(config)
 
-    @parameterized.expand(get_simple_test_case("CPUExecutionProvider", False), skip_on_empty=True)
-    def test_simple_prompt_cpu(self, config: SparseAttentionConfig):
-        self.run_one_relevance_test(config)
+        for config in get_test_cases("CUDAExecutionProvider", False, comprehensive_mode):
+            self.run_one_relevance_test(config)
 
-    @parameterized.expand(
-        get_test_cases("CPUExecutionProvider", True, comprehensive_mode, do_rotary=True), skip_on_empty=True
-    )
-    def test_sparse_att_token_cpu_rotary(self, config: SparseAttentionConfig):
-        # When there is rotary, we use ORT GQA as reference: ORT GQA does not support mask so here we use dense.
-        if config.sparse_block_size * config.local_blocks > config.total_sequence_length:
+    def test_sparse_attention_cpu(self):
+        for config in get_simple_test_case("CPUExecutionProvider", True):
             self.run_one_relevance_test(config)
 
-    @parameterized.expand(get_test_cases("CUDAExecutionProvider", True, comprehensive_mode), skip_on_empty=True)
-    def test_sparse_att_token_gpu(self, config):
-        self.run_one_relevance_test(config)
+        for config in get_simple_test_case("CPUExecutionProvider", False):
+            self.run_one_relevance_test(config)
 
-    @parameterized.expand(get_test_cases("CPUExecutionProvider", True, comprehensive_mode), skip_on_empty=True)
-    def test_sparse_att_token_cpu(self, config):
-        self.run_one_relevance_test(config)
+        for config in get_test_cases("CPUExecutionProvider", True, comprehensive_mode, do_rotary=True):
+            # When there is rotary, we use ORT GQA as reference: ORT GQA does not support mask so here we use dense.
+            if config.sparse_block_size * config.local_blocks > config.total_sequence_length:
+                self.run_one_relevance_test(config)
 
-    @parameterized.expand(get_test_cases("CPUExecutionProvider", False, comprehensive_mode), skip_on_empty=True)
-    def test_sparse_att_prompt_cpu(self, config):
-        self.run_one_relevance_test(config)
+        for config in get_test_cases("CPUExecutionProvider", True, comprehensive_mode):
+            self.run_one_relevance_test(config)
 
-    @parameterized.expand(get_test_cases("CUDAExecutionProvider", False, comprehensive_mode), skip_on_empty=True)
-    def test_sparse_att_prompt_gpu(self, config):
-        self.run_one_relevance_test(config)
+        for config in get_test_cases("CPUExecutionProvider", False, comprehensive_mode):
+            self.run_one_relevance_test(config)
 
     def run_one_relevance_test(self, config: SparseAttentionConfig):
         if (not config.do_rotary) and config.total_sequence_length <= 2048:
```
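After this change the class exposes only two test entry points, `test_sparse_attention_cuda` and `test_sparse_attention_cpu`. A minimal sequential run might look like the sketch below (an illustration only; it assumes the test file's directory is on `sys.path`, and the real CI invocation may differ):

```python
import unittest

# Load just the two collapsed SparseAttention tests; each one iterates
# over its configs internally, so the cases run strictly one at a time.
suite = unittest.defaultTestLoader.loadTestsFromNames(
    [
        "test_sparse_attention.TestSparseAttention.test_sparse_attention_cuda",
        "test_sparse_attention.TestSparseAttention.test_sparse_attention_cpu",
    ]
)
unittest.TextTestRunner(verbosity=2).run(suite)
```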
