From f943856f2a12c2c1984395e2432f98cc41ec9dbf Mon Sep 17 00:00:00 2001
From: Jack Zhang
Date: Tue, 12 Nov 2024 18:44:25 -0500
Subject: [PATCH] Fix flaky ET attention test (#6795)

* Fix flaky ET attention test

* Use assert_close

* Remove msg from assert_close

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

---------

Co-authored-by: Mengwei Liu
---
 extension/llm/modules/test/test_attention.py | 11 +++++++----
 pytest.ini                                   |  1 -
 2 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/extension/llm/modules/test/test_attention.py b/extension/llm/modules/test/test_attention.py
index 9ae136a213..565e8c67d7 100644
--- a/extension/llm/modules/test/test_attention.py
+++ b/extension/llm/modules/test/test_attention.py
@@ -13,6 +13,7 @@
     MultiHeadAttention as ETMultiHeadAttention,
 )
 from executorch.runtime import Runtime
+from torch.testing import assert_close
 from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
 from torchtune.modules.attention import MultiHeadAttention as TTMultiHeadAttention
 
@@ -94,7 +95,7 @@ def test_attention_eager(self):
         et_res = self.et_mha(self.x, self.x)  # Self attention.
         tt_res = self.tt_mha(self.x, self.x)  # Self attention.
 
-        self.assertTrue(torch.allclose(et_res, tt_res))
+        assert_close(et_res, tt_res)
 
         # test with kv cache
         self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=20)
@@ -124,7 +125,8 @@ def test_attention_eager(self):
         tt_res = self.tt_mha(
             self.x, self.x, input_pos=next_input_pos
         )  # Self attention with input pos.
-        self.assertTrue(torch.allclose(et_res, tt_res))
+
+        assert_close(et_res, tt_res)
 
     def test_attention_export(self):
         # Self attention.
@@ -136,7 +138,8 @@
         )
         et_res = et_mha_ep.module()(self.x, self.x, input_pos=self.input_pos)
         tt_res = self.tt_mha(self.x, self.x, input_pos=self.input_pos)
-        self.assertTrue(torch.allclose(et_res, tt_res))
+
+        assert_close(et_res, tt_res)
 
         # TODO: KV cache.
 
@@ -162,6 +165,6 @@ def test_attention_executorch(self):
         et_res = method.execute((self.x, self.x, self.input_pos))
         tt_res = self.tt_mha(self.x, self.x, input_pos=self.input_pos)
 
-        self.assertTrue(torch.allclose(et_res[0], tt_res, atol=1e-06))
+        assert_close(et_res[0], tt_res)
 
         # TODO: KV cache.

diff --git a/pytest.ini b/pytest.ini
index 03c015c397..a5041504ae 100644
--- a/pytest.ini
+++ b/pytest.ini
@@ -39,7 +39,6 @@ addopts =
     backends/xnnpack/test
     # extension/
     extension/llm/modules/test
-    --ignore=extension/llm/modules/test/test_mha.py
     extension/pybindings/test
     # Runtime
     runtime
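
Note: a minimal standalone sketch (not part of the patch; the tensors are
illustrative) of what torch.testing.assert_close buys over the old
assertTrue(torch.allclose(...)) pattern. assert_close picks default
rtol/atol based on the tensors' dtype, and on failure raises an
AssertionError describing the mismatch (number of mismatched elements,
greatest absolute and relative difference), whereas a failed assertTrue
reports only "False is not true":

    import torch
    from torch.testing import assert_close

    a = torch.ones(2, 3)
    b = a + 1e-3  # perturbation outside the default float32 tolerances

    # Old pattern: allclose returns a bare bool, so a failing assertTrue
    # gives no diagnostics about where or how badly the tensors differ.
    assert torch.allclose(a, b) is False

    # New pattern: raises AssertionError with a detailed mismatch report,
    # e.g. the greatest absolute/relative difference found.
    assert_close(a, b)  # raises here

The dtype-aware defaults are also why the patch can drop the hand-tuned
atol=1e-06 from the test_attention_executorch comparison.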