diff --git a/onnxruntime/test/python/transformers/test_sparse_attention.py b/onnxruntime/test/python/transformers/test_sparse_attention.py
index f18bcdba65579..688e6250fecbd 100644
--- a/onnxruntime/test/python/transformers/test_sparse_attention.py
+++ b/onnxruntime/test/python/transformers/test_sparse_attention.py
@@ -13,7 +13,6 @@
 import torch
 from benchmark_mha import InputFormats
 from onnx import TensorProto, helper
-from parameterized import parameterized
 from torch import Tensor
 
 from onnxruntime import InferenceSession, SessionOptions, get_available_providers
@@ -929,43 +928,36 @@ def get_test_cases(provider: str, has_past_kv: bool, comprehensive: bool, do_rot
 
 
 class TestSparseAttention(unittest.TestCase):
-    def test_sparse_attention(self):
+    @unittest.skipUnless(has_cuda_support(), "cuda not available")
+    def test_sparse_attention_cuda(self):
         major, minor = torch.cuda.get_device_capability()
         sm = major * 10 + minor
         self.run_relevance_test(sm)
 
-    @parameterized.expand(get_simple_test_case("CPUExecutionProvider", True), skip_on_empty=True)
-    def test_simple_token_cpu(self, config: SparseAttentionConfig):
-        self.run_one_relevance_test(config)
+        for config in get_test_cases("CUDAExecutionProvider", True, comprehensive_mode):
+            self.run_one_relevance_test(config)
 
-    @parameterized.expand(get_simple_test_case("CPUExecutionProvider", False), skip_on_empty=True)
-    def test_simple_prompt_cpu(self, config: SparseAttentionConfig):
-        self.run_one_relevance_test(config)
+        for config in get_test_cases("CUDAExecutionProvider", False, comprehensive_mode):
+            self.run_one_relevance_test(config)
 
-    @parameterized.expand(
-        get_test_cases("CPUExecutionProvider", True, comprehensive_mode, do_rotary=True), skip_on_empty=True
-    )
-    def test_sparse_att_token_cpu_rotary(self, config: SparseAttentionConfig):
-        # When there is rotary, we use ORT GQA as reference: ORT GQA does not support mask so here we use dense.
-        if config.sparse_block_size * config.local_blocks > config.total_sequence_length:
+    def test_sparse_attention_cpu(self):
+        for config in get_simple_test_case("CPUExecutionProvider", True):
             self.run_one_relevance_test(config)
 
-    @parameterized.expand(get_test_cases("CUDAExecutionProvider", True, comprehensive_mode), skip_on_empty=True)
-    def test_sparse_att_token_gpu(self, config):
-        self.run_one_relevance_test(config)
+        for config in get_simple_test_case("CPUExecutionProvider", False):
+            self.run_one_relevance_test(config)
 
-    @parameterized.expand(get_test_cases("CPUExecutionProvider", True, comprehensive_mode), skip_on_empty=True)
-    def test_sparse_att_token_cpu(self, config):
-        self.run_one_relevance_test(config)
+        for config in get_test_cases("CPUExecutionProvider", True, comprehensive_mode, do_rotary=True):
+            # When there is rotary, we use ORT GQA as reference: ORT GQA does not support mask so here we use dense.
+            if config.sparse_block_size * config.local_blocks > config.total_sequence_length:
+                self.run_one_relevance_test(config)
 
-    @parameterized.expand(get_test_cases("CPUExecutionProvider", False, comprehensive_mode), skip_on_empty=True)
-    def test_sparse_att_prompt_cpu(self, config):
-        self.run_one_relevance_test(config)
+        for config in get_test_cases("CPUExecutionProvider", True, comprehensive_mode):
+            self.run_one_relevance_test(config)
 
-    @parameterized.expand(get_test_cases("CUDAExecutionProvider", False, comprehensive_mode), skip_on_empty=True)
-    def test_sparse_att_prompt_gpu(self, config):
-        self.run_one_relevance_test(config)
+        for config in get_test_cases("CPUExecutionProvider", False, comprehensive_mode):
+            self.run_one_relevance_test(config)
 
     def run_one_relevance_test(self, config: SparseAttentionConfig):
         if (not config.do_rotary) and config.total_sequence_length <= 2048:
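Note: the new @unittest.skipUnless guard calls a has_cuda_support() helper that is defined elsewhere in the test file and is not shown in this hunk. A minimal sketch of such a check, assuming it combines a torch CUDA probe with the ONNX Runtime provider list (both imports already appear in this file), could look like:

    import torch

    from onnxruntime import get_available_providers


    def has_cuda_support() -> bool:
        # Sketch only, not necessarily the file's exact implementation:
        # require both a visible CUDA device and the ORT CUDA execution provider.
        return torch.cuda.is_available() and "CUDAExecutionProvider" in get_available_providers()

Gating the whole CUDA test method this way replaces the per-case skip_on_empty=True behavior that parameterized.expand provided, while the plain for loops over get_test_cases/get_simple_test_case keep the same case coverage without the extra dependency.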