diff --git a/onnxruntime/test/python/transformers/test_mha.py b/onnxruntime/test/python/transformers/test_mha.py index bb4f93395b160..ff473cc2ced92 100644 --- a/onnxruntime/test/python/transformers/test_mha.py +++ b/onnxruntime/test/python/transformers/test_mha.py @@ -150,15 +150,16 @@ def get_provider_support_info(provider: str, use_kv_cache: bool): return device, dtype, formats -def has_cuda_support(): +def get_compute_capability(): if torch.cuda.is_available() and "CUDAExecutionProvider" in onnxruntime.get_available_providers(): - major, _ = torch.cuda.get_device_capability() - return major >= 6 - return False + major, minor = torch.cuda.get_device_capability() + sm = major * 10 + minor + return sm + return 0 def no_kv_cache_test_cases(provider: str, comprehensive: bool): - if provider == "CUDAExecutionProvider" and not has_cuda_support(): + if provider == "CUDAExecutionProvider" and get_compute_capability() < 60: return yield @@ -221,7 +222,7 @@ def no_kv_cache_test_cases(provider: str, comprehensive: bool): def kv_cache_test_cases(provider: str, comprehensive: bool): - if provider == "CUDAExecutionProvider" and not has_cuda_support(): + if provider == "CUDAExecutionProvider" and get_compute_capability() < 60: return yield @@ -292,7 +293,7 @@ def mha_test_cases(provider: str, comprehensive: bool): def no_kv_cache_multi_thread_test_cases(provider: str, comprehensive: bool): - if provider == "CUDAExecutionProvider" and not has_cuda_support(): + if provider == "CUDAExecutionProvider" and get_compute_capability() < 60: return yield @@ -331,7 +332,7 @@ def no_kv_cache_multi_thread_test_cases(provider: str, comprehensive: bool): def kv_cache_multi_thread_test_cases(provider: str, comprehensive: bool): - if provider == "CUDAExecutionProvider" and not has_cuda_support(): + if provider == "CUDAExecutionProvider" and get_compute_capability() < 60: return yield @@ -473,14 +474,14 @@ def parity_check_mha_multi_threading( config = test_inputs[0]["config"] # For now, MHA CUDA kernel does not support causal so skip such test cases. if config.causal and config.provider == "CUDAExecutionProvider": - return + return None # Some kernel does not support certain input format. if sdpa_kernel not in [ SdpaKernel.DEFAULT, SdpaKernel.FLASH_ATTENTION, SdpaKernel.EFFICIENT_ATTENTION, ] and config.input_format in [InputFormats.Q_KV_BSNH_BSN2H]: - return + return None if verbose: print(f"create a shared session with {vars(config)}") onnx_model_str = create_multi_head_attention_onnx_model(config, use_symbolic_shape=True) @@ -582,6 +583,7 @@ def check_parity_with_config(i: int): except AssertionError as e: print(f"Failed with {vars(config)}: {e}") return e + if verbose: print(f"Passed: {vars(config)}") return None @@ -630,19 +632,18 @@ def run_mha_cuda_multi_threading(self, spda_kernel): def test_mha_cuda_multi_threading(self): self.run_mha_cuda_multi_threading(SdpaKernel.DEFAULT) - def test_mha_cuda_multi_threading_flash(self): - self.run_mha_cuda_multi_threading(SdpaKernel.FLASH_ATTENTION) - def test_mha_cuda_multi_threading_efficient(self): self.run_mha_cuda_multi_threading(SdpaKernel.EFFICIENT_ATTENTION) def test_mha_cuda_multi_threading_trt(self): - self.run_mha_cuda_multi_threading( - SdpaKernel.TRT_FUSED_ATTENTION - | SdpaKernel.TRT_FLASH_ATTENTION - | SdpaKernel.TRT_CROSS_ATTENTION - | SdpaKernel.TRT_CAUSAL_ATTENTION - ) + sm = get_compute_capability() + if sm in [75, 80, 86, 89]: + self.run_mha_cuda_multi_threading( + SdpaKernel.TRT_FUSED_ATTENTION + | SdpaKernel.TRT_FLASH_ATTENTION + | SdpaKernel.TRT_CROSS_ATTENTION + | SdpaKernel.TRT_CAUSAL_ATTENTION + ) if __name__ == "__main__":