From 65c2cb7e192e87d15ffe00ddf09c9160d7a598f0 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 2 Jul 2024 19:08:06 +0000 Subject: [PATCH] move import inside test --- .../orttraining/test/python/orttraining_test_ortmodule_api.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index 7a9abb48860f2..37bc6c066a1f9 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -25,7 +25,6 @@ # Import autocasting libs from torch import nn from torch.cuda import amp -from torch.nn.attention import SDPBackend, sdpa_kernel from transformers import AdamW, AutoConfig, BertForSequenceClassification, Trainer from transformers.modeling_outputs import SequenceClassifierOutput @@ -6929,6 +6928,8 @@ def generate_inputs(batch_size, max_seq_length, vocab_size): def test_aten_attention(): + from torch.nn.attention import SDPBackend, sdpa_kernel + class _NeuralNetAttention(torch.nn.Module): def __init__(self): super().__init__()