diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index 7a9abb48860f2..37bc6c066a1f9 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -25,7 +25,6 @@ # Import autocasting libs from torch import nn from torch.cuda import amp -from torch.nn.attention import SDPBackend, sdpa_kernel from transformers import AdamW, AutoConfig, BertForSequenceClassification, Trainer from transformers.modeling_outputs import SequenceClassifierOutput @@ -6929,6 +6928,8 @@ def generate_inputs(batch_size, max_seq_length, vocab_size): def test_aten_attention(): + from torch.nn.attention import SDPBackend, sdpa_kernel + class _NeuralNetAttention(torch.nn.Module): def __init__(self): super().__init__()