Skip to content

Commit

Permalink
Install cutlass kernels.
Browse files Browse the repository at this point in the history
  • Loading branch information
xuzhao9 committed Jun 12, 2024
1 parent ed124e1 commit b341059
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 3 deletions.
Empty file.
8 changes: 5 additions & 3 deletions torchbenchmark/operators/flash_attention/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,11 @@

# [Optional] colfax cutlass backend
try:
# colfax Flash Attention V2 for Hopper
# https://www.internalfb.com/code/fbsource/fbcode/ai_codesign/gen_ai/cutlass-kernels/src/fmha/README.md
torch.ops.load_library("//ai_codesign/gen_ai/cutlass-kernels:fmha_forward_lib")
if not hasattr(torch.version, "git_version"):
# colfax Flash Attention V2 for Hopper
torch.ops.load_library("//ai_codesign/gen_ai/cutlass-kernels:fmha_forward_lib")
else:
torch.ops.load_library("colfax_cutlass_fmha_forward_lib.so")
colfax_cutlass_fmha = torch.ops.cutlass.fmha_forward
except (ImportError, IOError, AttributeError):
colfax_cutlass_fmha = None
Expand Down
10 changes: 10 additions & 0 deletions userbenchmark/triton/install.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,20 @@ def test_fbgemm():
cmd = [sys.executable, "-c", '"import fbgemm_gpu.experimental.gen_ai"']
subprocess.check_call(cmd)

def install_cutlass():
pass

def test_cutlass():
pass

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--fbgemm", action="store_true", help="Install FBGEMM GPU")
parser.add_argument("--cutlass", action="store_true", help="Install optional CUTLASS kernels")
args = parser.parse_args()
if args.fbgemm:
install_fbgemm()
test_fbgemm()
if args.cutlass:
install_cutlass()
test_cutlass()

0 comments on commit b341059

Please sign in to comment.