Fix flash attention kernel

xuzhao9 committed Jun 12, 2024
1 parent c13df57 commit 38a6a8a
Showing 2 changed files with 100 additions and 22 deletions.
52 changes: 30 additions & 22 deletions torchbenchmark/operators/flash_attention/operator.py
@@ -1,4 +1,8 @@
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
This benchmark script is based on the benchmark code from:
@@ -10,17 +14,18 @@
https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html
* Flash-V2: the FA-V2 from //ai_codesign/gen_ai/flash_attention_v2:flash_attention_v2,
* SDPA: the torch.nn.attention version of FA-V2
* [optional] Flash-V2: the FA-V2 from //ai_codesign/gen_ai/flash_attention_v2:flash_attention_v2,
which was imported from https://github.com/Dao-AILab/flash-attention
* Xformers: the memory-efficient attention from xformers:
* [optional] Xformers: the memory-efficient attention from xformers:
https://fburl.com/code/cuorcm9h
* [optional] Xformers-Splitk: the triton-splitk FMHA kernel from xformers:
https://fburl.com/code/awt36vjj
Disabled by default because it failed with some configs. Note that
the relevant benchmark only works with causal = False at the moment.
Known to work with "--batch=8 --n-heads=8 --xformers-splitk"
@@ -29,26 +34,32 @@
import argparse
import math
import os
from typing import Callable, Optional

import numpy

import torch
import triton # @manual=//triton:triton
import xformers # @manual=//fair/xformers:xformers
import xformers.ops.fmha as xformers_fmha # @manual=//fair/xformers:xformers

from triton.ops.flash_attention import attention as triton_op_FA2
from torchbenchmark.util.kernels.triton_fused_attention import attention as triton_tutorial_FA2
from aikl.gpu.triton.tests.test_fmha_utils import (
    generate_qkv,
    make_packed_qkv,
    permute_qkv,
)
from flash_attn.flash_attn_interface import flash_attn_qkvpacked_func as flash_attn_func
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.nn.functional import scaled_dot_product_attention as sdpa
from triton.ops.flash_attention import attention as triton_op_FA2

from typing import Callable, Optional

# [Optional] flash_attn_func
try:
    from .test_fmha_utils import make_packed_qkv
    from flash_attn.flash_attn_interface import flash_attn_qkvpacked_func as flash_attn_func
except (ImportError, IOError, AttributeError):
    pass

# [Optional] xformers backend
try:
    import xformers # @manual=//fair/xformers:xformers
    import xformers.ops.fmha as xformers_fmha # @manual=//fair/xformers:xformers
    from .test_fmha_utils import permute_qkv
except (ImportError, IOError, AttributeError):
    pass

# [Optional] colfax cutlass backend
try:
    # colfax Flash Attention V2 for Hopper
    # https://www.internalfb.com/code/fbsource/fbcode/ai_codesign/gen_ai/cutlass-kernels/src/fmha/README.md
@@ -57,6 +68,7 @@
except (ImportError, IOError, AttributeError):
    colfax_cutlass_fmha = None

# [Optional] ThunderKittens backend
try:
    import h100_fwd as tk_fwd
    import h100_fwd_causal as tk_fwd_causal
@@ -65,10 +77,6 @@
    tk_fwd_causal = None

from typing import Any, Generator, List

import torch
import triton
import triton.language as tl
from torchbenchmark.util.input import input_filter

from torchbenchmark.util.triton_op import (
@@ -208,7 +216,7 @@ def xformers_preprocess(
        )
        return fhma_input

    @register_benchmark()
    @register_benchmark(enabled=False)
    def xformers(
        self,
        q: torch.Tensor,
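For context on the operator.py changes above: judging by the imports kept in the diff (SDPBackend, sdpa_kernel, scaled_dot_product_attention), the benchmark's "SDPA" entry exercises PyTorch's scaled_dot_product_attention restricted to the flash-attention backend. The sketch below is illustrative only and not part of the commit; the tensor sizes, float16 dtype, and CUDA device are assumptions.

import torch
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.nn.functional import scaled_dot_product_attention as sdpa

# Assumed benchmark-like sizes: (batch, num_heads, seq_len, head_dim).
BATCH, H, N_CTX, D_HEAD = 8, 8, 1024, 64
q, k, v = (
    torch.randn(BATCH, H, N_CTX, D_HEAD, dtype=torch.float16, device="cuda")
    for _ in range(3)
)

# Restrict SDPA to the flash-attention backend, mirroring the benchmarked path.
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = sdpa(q, k, v, is_causal=True)

print(out.shape)  # torch.Size([8, 8, 1024, 64])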
70 changes: 70 additions & 0 deletions torchbenchmark/operators/flash_attention/test_fmha_utils.py
@@ -0,0 +1,70 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Tuple

import torch


def generate_qkv(
    BATCH: int,
    H: int,
    N_CTX: int,
    D_HEAD: int,
    dtype: torch.dtype,
    device: str = "cuda",
    requires_grad: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    torch.manual_seed(20)
    q = torch.randn(
        (BATCH, H, N_CTX, D_HEAD),
        dtype=dtype,
        device=device,
        requires_grad=requires_grad,
    )
    k = torch.randn(
        (BATCH, H, N_CTX, D_HEAD),
        dtype=dtype,
        device=device,
        requires_grad=requires_grad,
    )
    v = torch.randn(
        (BATCH, H, N_CTX, D_HEAD),
        dtype=dtype,
        device=device,
        requires_grad=requires_grad,
    )
    return (q, k, v)


def permute_qkv(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    perm: Tuple[int],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    q_1 = torch.permute(q, perm)
    k_1 = torch.permute(k, perm)
    v_1 = torch.permute(v, perm)
    return (q_1, k_1, v_1)


def make_packed_qkv(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
) -> torch.Tensor:
    """
    Make a packed qkv tensor for flash_attention:
    from 3 * (batch, num_head, seq, head_dim) -> (batch, seq, 3, num_head, head_dim)
    """
    assert (
        q.size() == k.size() == v.size()
    ), f"{q.size()=}, {k.size()=}, {v.size()=} must be equal!"
    (BATCH, H, N_CTX, D_HEAD) = q.size()
    (q_1, k_1, v_1) = permute_qkv(q, k, v, perm=(0, 2, 1, 3))
    qkv = torch.cat([q_1, k_1, v_1], dim=2)
    return torch.reshape(qkv, (BATCH, N_CTX, 3, H, D_HEAD))
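A minimal usage sketch for the new helpers above, not part of the commit. It assumes a CUDA device (generate_qkv's default) and illustrative sizes; the packed layout is the (batch, seq, 3, num_head, head_dim) format described in the make_packed_qkv docstring, which the optional qkv-packed flash_attention path in operator.py consumes.

import torch
from torchbenchmark.operators.flash_attention.test_fmha_utils import (
    generate_qkv,
    make_packed_qkv,
    permute_qkv,
)

# Illustrative sizes: (batch, num_head, seq, head_dim).
q, k, v = generate_qkv(BATCH=8, H=8, N_CTX=1024, D_HEAD=64, dtype=torch.float16)

# Pack 3 x (batch, num_head, seq, head_dim) into (batch, seq, 3, num_head, head_dim).
qkv = make_packed_qkv(q, k, v)
assert qkv.shape == (8, 1024, 3, 8, 64)

# Or only reorder the layout, e.g. to (batch, seq, num_head, head_dim),
# the layout used by the xformers memory-efficient attention path.
q_x, k_x, v_x = permute_qkv(q, k, v, perm=(0, 2, 1, 3))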
