From 44a3923ba5711a6a0e2e3933669bc3ae2a4a9446 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Wed, 21 Aug 2024 17:24:58 -0700 Subject: [PATCH] run sparse attention test sequentially (#21808) ### Description For some reason, running SparseAttention tests in parallel causes random failure in CI pipeline. Maybe due to out of memory when too many tests are running in parallel. This will run those tests sequentially. --- .../transformers/test_sparse_attention.py | 44 ++++++++----------- 1 file changed, 18 insertions(+), 26 deletions(-) diff --git a/onnxruntime/test/python/transformers/test_sparse_attention.py b/onnxruntime/test/python/transformers/test_sparse_attention.py index f18bcdba65579..688e6250fecbd 100644 --- a/onnxruntime/test/python/transformers/test_sparse_attention.py +++ b/onnxruntime/test/python/transformers/test_sparse_attention.py @@ -13,7 +13,6 @@ import torch from benchmark_mha import InputFormats from onnx import TensorProto, helper -from parameterized import parameterized from torch import Tensor from onnxruntime import InferenceSession, SessionOptions, get_available_providers @@ -929,43 +928,36 @@ def get_test_cases(provider: str, has_past_kv: bool, comprehensive: bool, do_rot class TestSparseAttention(unittest.TestCase): + @unittest.skipUnless(has_cuda_support(), "cuda not available") - def test_sparse_attention(self): + def test_sparse_attention_cuda(self): major, minor = torch.cuda.get_device_capability() sm = major * 10 + minor self.run_relevance_test(sm) - @parameterized.expand(get_simple_test_case("CPUExecutionProvider", True), skip_on_empty=True) - def test_simple_token_cpu(self, config: SparseAttentionConfig): - self.run_one_relevance_test(config) + for config in get_test_cases("CUDAExecutionProvider", True, comprehensive_mode): + self.run_one_relevance_test(config) - @parameterized.expand(get_simple_test_case("CPUExecutionProvider", False), skip_on_empty=True) - def test_simple_prompt_cpu(self, config: SparseAttentionConfig): - 
self.run_one_relevance_test(config) + for config in get_test_cases("CUDAExecutionProvider", False, comprehensive_mode): + self.run_one_relevance_test(config) - @parameterized.expand( - get_test_cases("CPUExecutionProvider", True, comprehensive_mode, do_rotary=True), skip_on_empty=True - ) - def test_sparse_att_token_cpu_rotary(self, config: SparseAttentionConfig): - # When there is rotary, we use ORT GQA as reference: ORT GQA does not support mask so here we use dense. - if config.sparse_block_size * config.local_blocks > config.total_sequence_length: + def test_sparse_attention_cpu(self): + for config in get_simple_test_case("CPUExecutionProvider", True): self.run_one_relevance_test(config) - @parameterized.expand(get_test_cases("CUDAExecutionProvider", True, comprehensive_mode), skip_on_empty=True) - def test_sparse_att_token_gpu(self, config): - self.run_one_relevance_test(config) + for config in get_simple_test_case("CPUExecutionProvider", False): + self.run_one_relevance_test(config) - @parameterized.expand(get_test_cases("CPUExecutionProvider", True, comprehensive_mode), skip_on_empty=True) - def test_sparse_att_token_cpu(self, config): - self.run_one_relevance_test(config) + for config in get_test_cases("CPUExecutionProvider", True, comprehensive_mode, do_rotary=True): + # When there is rotary, we use ORT GQA as reference: ORT GQA does not support mask so here we use dense. 
+ if config.sparse_block_size * config.local_blocks > config.total_sequence_length: + self.run_one_relevance_test(config) - @parameterized.expand(get_test_cases("CPUExecutionProvider", False, comprehensive_mode), skip_on_empty=True) - def test_sparse_att_prompt_cpu(self, config): - self.run_one_relevance_test(config) + for config in get_test_cases("CPUExecutionProvider", True, comprehensive_mode): + self.run_one_relevance_test(config) - @parameterized.expand(get_test_cases("CUDAExecutionProvider", False, comprehensive_mode), skip_on_empty=True) - def test_sparse_att_prompt_gpu(self, config): - self.run_one_relevance_test(config) + for config in get_test_cases("CPUExecutionProvider", False, comprehensive_mode): + self.run_one_relevance_test(config) def run_one_relevance_test(self, config: SparseAttentionConfig): if (not config.do_rotary) and config.total_sequence_length <= 2048: