Commit df9f3d9
Make xformers optional
xuzhao9 committed Jun 12, 2024
1 parent e915cf9 commit df9f3d9
Showing 1 changed file with 1 addition and 1 deletion.
torchbenchmark/operators/flash_attention/operator.py (2 changes: 1 addition & 1 deletion)
@@ -208,7 +208,7 @@ def xformers_preprocess(
         q: torch.Tensor,
         k: torch.Tensor,
         v: torch.Tensor,
-    ) -> xformers_fmha.Inputs:
+    ):
         q_1, k_1, v_1 = permute_qkv(q, k, v, perm=(0, 2, 1, 3))
         attn_bias = xformers.ops.LowerTriangularMask() if self.causal else None
         fhma_input = xformers_fmha.Inputs(
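
Why this one-line change makes the dependency optional: Python evaluates a function's return annotation at definition time (unless "from __future__ import annotations" is in effect), so the old signature required the name xformers_fmha to resolve whenever operator.py was imported. With the annotation removed, xformers names are only touched when the method is actually called. Below is a minimal sketch of the resulting optional-import pattern; the guarded import, the HAS_XFORMERS flag, and the simplified method body are illustrative assumptions, not code from this commit.

import torch

# Guard the optional dependency; the module still imports if xformers is absent.
try:
    import xformers.ops
    import xformers.ops.fmha as xformers_fmha
    HAS_XFORMERS = True
except ImportError:
    HAS_XFORMERS = False


class Operator:
    causal = True

    # Before the commit, the annotation "-> xformers_fmha.Inputs" was
    # evaluated when this method was defined, so importing the module
    # raised NameError whenever xformers was missing.
    def xformers_preprocess(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
    ):
        # With the annotation gone, xformers is only resolved here, at
        # call time, so we can fail gracefully instead.
        if not HAS_XFORMERS:
            raise RuntimeError("xformers is required for this backend")
        attn_bias = xformers.ops.LowerTriangularMask() if self.causal else None
        return xformers_fmha.Inputs(query=q, key=k, value=v, attn_bias=attn_bias)

Quoting the annotation ("xformers_fmha.Inputs" as a string) or adding "from __future__ import annotations" would have had the same effect, since both defer annotation evaluation.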

0 comments on commit df9f3d9
