Skip to content

Commit

Permalink
move import inside test
Browse files Browse the repository at this point in the history
  • Loading branch information
prathikr committed Jul 2, 2024
1 parent 8219ec9 commit 65c2cb7
Showing 1 changed file with 2 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
# Import autocasting libs
from torch import nn
from torch.cuda import amp
from torch.nn.attention import SDPBackend, sdpa_kernel
from transformers import AdamW, AutoConfig, BertForSequenceClassification, Trainer
from transformers.modeling_outputs import SequenceClassifierOutput

Expand Down Expand Up @@ -6929,6 +6928,8 @@ def generate_inputs(batch_size, max_seq_length, vocab_size):


def test_aten_attention():
from torch.nn.attention import SDPBackend, sdpa_kernel

class _NeuralNetAttention(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down

0 comments on commit 65c2cb7

Please sign in to comment.