From df9f3d9cb5b99f801415d26422d9c33f601e1e51 Mon Sep 17 00:00:00 2001
From: Xu Zhao
Date: Wed, 12 Jun 2024 16:12:56 -0400
Subject: [PATCH] Make xlformers optional

---
 torchbenchmark/operators/flash_attention/operator.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchbenchmark/operators/flash_attention/operator.py b/torchbenchmark/operators/flash_attention/operator.py
index 90a7964ad2..9d4b197132 100644
--- a/torchbenchmark/operators/flash_attention/operator.py
+++ b/torchbenchmark/operators/flash_attention/operator.py
@@ -208,7 +208,7 @@ def xformers_preprocess(
         q: torch.Tensor,
         k: torch.Tensor,
         v: torch.Tensor,
-    ) -> xformers_fmha.Inputs:
+    ):
         q_1, k_1, v_1 = permute_qkv(q, k, v, perm=(0, 2, 1, 3))
         attn_bias = xformers.ops.LowerTriangularMask() if self.causal else None
         fhma_input = xformers_fmha.Inputs(
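
Context note (not part of the patch): the only change is dropping the "-> xformers_fmha.Inputs" return annotation. By default Python evaluates annotations when the enclosing class body is executed, so referencing xformers_fmha in the signature requires the xformers import to have succeeded even if the benchmark never exercises that path; without the annotation, the name is only looked up when the method is actually called. Below is a minimal sketch of the resulting optional-dependency pattern. The guarded import path, the HAS_XFORMERS flag, the stand-in class, and the Inputs keyword arguments are assumptions for illustration; only xformers.ops.LowerTriangularMask and xformers_fmha.Inputs appear in the diff itself, and the permute_qkv step is omitted for brevity.

# Hedged sketch of an optional xformers dependency; names not visible in the
# diff are assumptions, not taken verbatim from operator.py.
try:
    import xformers.ops                          # binds the top-level xformers package
    import xformers.ops.fmha as xformers_fmha    # assumed import alias
    HAS_XFORMERS = True                          # hypothetical availability flag
except ImportError:
    HAS_XFORMERS = False


class FlashAttentionLikeOperator:  # hypothetical stand-in for the benchmark operator
    def __init__(self, causal: bool = True):
        self.causal = causal

    def xformers_preprocess(self, q, k, v):
        # With no "-> xformers_fmha.Inputs" annotation, this method can be
        # defined even when xformers is absent; the module name is only
        # resolved if the method is actually invoked.
        if not HAS_XFORMERS:
            raise RuntimeError("xformers is required for this benchmark path")
        attn_bias = xformers.ops.LowerTriangularMask() if self.causal else None
        return xformers_fmha.Inputs(query=q, key=k, value=v, attn_bias=attn_bias)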