From b341059291a6b3d257dce414c9494f47404f0ee4 Mon Sep 17 00:00:00 2001
From: Xu Zhao
Date: Wed, 12 Jun 2024 19:24:51 -0400
Subject: [PATCH] Install cutlass kernels.

---
Reviewer notes (free-form text after the three-dash line; not part of the
commit message):
- fmha_forward_lib.py is added with 0 lines in the diffstat (presumably an
  empty placeholder file -- confirm this is intended).
- operator.py: the library to load is now selected by whether torch.version
  has a `git_version` attribute. When it is absent the
  "//ai_codesign/gen_ai/cutlass-kernels:fmha_forward_lib" target is loaded;
  otherwise "colfax_cutlass_fmha_forward_lib.so" is loaded. On ImportError,
  IOError or AttributeError, colfax_cutlass_fmha is set to None.
- install.py: install_cutlass() and test_cutlass() are added with `pass`
  bodies, so the new `--cutlass` flag currently installs and checks nothing.

 .../operators/flash_attention/fmha_forward_lib.py    |  0
 torchbenchmark/operators/flash_attention/operator.py |  8 +++++---
 userbenchmark/triton/install.py                      | 10 ++++++++++
 3 files changed, 15 insertions(+), 3 deletions(-)
 create mode 100644 torchbenchmark/operators/flash_attention/fmha_forward_lib.py

diff --git a/torchbenchmark/operators/flash_attention/fmha_forward_lib.py b/torchbenchmark/operators/flash_attention/fmha_forward_lib.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/torchbenchmark/operators/flash_attention/operator.py b/torchbenchmark/operators/flash_attention/operator.py
index 9d4b197132..1b3390f7c8 100644
--- a/torchbenchmark/operators/flash_attention/operator.py
+++ b/torchbenchmark/operators/flash_attention/operator.py
@@ -61,9 +61,11 @@
 
 # [Optional] colfax cutlass backend
 try:
-    # colfax Flash Attention V2 for Hopper
-    # https://www.internalfb.com/code/fbsource/fbcode/ai_codesign/gen_ai/cutlass-kernels/src/fmha/README.md
-    torch.ops.load_library("//ai_codesign/gen_ai/cutlass-kernels:fmha_forward_lib")
+    if not hasattr(torch.version, "git_version"):
+        # colfax Flash Attention V2 for Hopper
+        torch.ops.load_library("//ai_codesign/gen_ai/cutlass-kernels:fmha_forward_lib")
+    else:
+        torch.ops.load_library("colfax_cutlass_fmha_forward_lib.so")
     colfax_cutlass_fmha = torch.ops.cutlass.fmha_forward
 except (ImportError, IOError, AttributeError):
     colfax_cutlass_fmha = None
diff --git a/userbenchmark/triton/install.py b/userbenchmark/triton/install.py
index d5faff068c..4a85ade546 100644
--- a/userbenchmark/triton/install.py
+++ b/userbenchmark/triton/install.py
@@ -18,10 +18,20 @@ def test_fbgemm():
     cmd = [sys.executable, "-c", '"import fbgemm_gpu.experimental.gen_ai"']
     subprocess.check_call(cmd)
 
+def install_cutlass():
+    pass
+
+def test_cutlass():
+    pass
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--fbgemm", action="store_true", help="Install FBGEMM GPU")
+    parser.add_argument("--cutlass", action="store_true", help="Install optional CUTLASS kernels")
     args = parser.parse_args()
     if args.fbgemm:
         install_fbgemm()
         test_fbgemm()
+    if args.cutlass:
+        install_cutlass()
+        test_cutlass()