diff --git a/extension/llm/modules/test/test_attention.py b/extension/llm/modules/test/test_attention.py
index f7da41942b..b792fd9b1e 100644
--- a/extension/llm/modules/test/test_attention.py
+++ b/extension/llm/modules/test/test_attention.py
@@ -21,13 +21,10 @@
 from torchtune.modules.attention import MultiHeadAttention as TTMultiHeadAttention
 
 
-torch.manual_seed(0)
-
-
 class AttentionTest(unittest.TestCase):
     def setUp(self):
         super().setUp()
-
+        torch.manual_seed(0)
         # Constants
         self.embed_dim = 2048
         self.num_heads = 32
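
The patch moves `torch.manual_seed(0)` from module import time into `setUp()`, which `unittest` runs before every test method, so each test starts from an identical RNG state regardless of how many tests run or in what order. A minimal sketch of that effect (the `SeedPlacementDemo` class and its test names are hypothetical illustrations, not part of this patch):

```python
import unittest

import torch


class SeedPlacementDemo(unittest.TestCase):
    """Hypothetical demo: seeding in setUp() re-fixes the RNG before each test."""

    def setUp(self):
        super().setUp()
        torch.manual_seed(0)  # runs before every test method, not once per module

    def test_draws_are_reproducible(self):
        a = torch.randn(4)    # first draw after the setUp() seed
        torch.manual_seed(0)  # rewind the RNG to the same state
        b = torch.randn(4)    # replays the identical draw
        self.assertTrue(torch.equal(a, b))

    def test_state_is_independent_of_ordering(self):
        # Because setUp() reseeded, this first draw matches the first draw
        # in every other test method, no matter which ran earlier.
        first = torch.randn(4)
        self.assertEqual(first.shape, (4,))


if __name__ == "__main__":
    unittest.main()
```

With a module-level seed, the RNG state each test observes depends on how many random draws earlier tests consumed; per-test seeding removes that ordering dependence.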