From ec83548f5e7bc6799d93c925fa9d5e13cf6f258a Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Tue, 16 Jul 2024 14:49:35 -0700 Subject: [PATCH] adjust gqa flash attention test threshold for rocm --- .../python/transformers/test_flash_attn_rocm.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/onnxruntime/test/python/transformers/test_flash_attn_rocm.py b/onnxruntime/test/python/transformers/test_flash_attn_rocm.py index fe7e39722237f..880f4175e00b7 100644 --- a/onnxruntime/test/python/transformers/test_flash_attn_rocm.py +++ b/onnxruntime/test/python/transformers/test_flash_attn_rocm.py @@ -35,8 +35,8 @@ def test_gqa_no_past_flash_attention(self, _, config, local, rotary, rotary_inte rotary=rotary, rotary_interleaved=rotary_interleaved, packed=packed, - rtol=0.002, - atol=0.002, + rtol=0.001, + atol=0.005, ) parity_check_gqa_prompt_no_buff( config, @@ -45,8 +45,8 @@ def test_gqa_no_past_flash_attention(self, _, config, local, rotary, rotary_inte rotary=rotary, rotary_interleaved=rotary_interleaved, packed=packed, - rtol=0.002, - atol=0.002, + rtol=0.001, + atol=0.005, ) @parameterized.expand(gqa_past_flash_attention_test_cases()) @@ -67,8 +67,8 @@ def test_gqa_past_flash_attention(self, _, config, local, rotary, rotary_interle rotary=rotary, rotary_interleaved=rotary_interleaved, packed=packed, - rtol=0.002, - atol=0.002, + rtol=0.001, + atol=0.005, ) parity_check_gqa_past_no_buff( config, @@ -77,8 +77,8 @@ def test_gqa_past_flash_attention(self, _, config, local, rotary, rotary_interle rotary=rotary, rotary_interleaved=rotary_interleaved, packed=packed, - rtol=0.002, - atol=0.002, + rtol=0.001, + atol=0.005, )