Skip to content

Commit

Permalink
Add register op
Browse files Browse the repository at this point in the history
  • Loading branch information
xuzhao9 committed Jun 15, 2024
1 parent 248f154 commit 9c37293
Showing 1 changed file with 13 additions and 10 deletions.
23 changes: 13 additions & 10 deletions userbenchmark/triton/install.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
REPO_PATH = Path(os.path.abspath(__file__)).parent.parent.parent
FBGEMM_PATH = REPO_PATH.joinpath("submodules", "FBGEMM", "fbgemm_gpu")
CUDA_HOME = "/usr/local/cuda" if not "CUDA_HOME" in os.environ else os.environ["CUDA_HOME"]
FBGEMM_CUTLASS_PATH = FBGEMM_PATH.joinpath("third_party", "cutlass")
FBGEMM_CUTLASS_PATH = FBGEMM_PATH.parent.joinpath("third_party", "cutlass")
COLFAX_CUTLASS_PATH = REPO_PATH.joinpath("submodules", "cutlass-kernels")
COLFAX_CUTLASS_TRITONBENCH_PATH = REPO_PATH.joinpath("userbenchmark", "triton", "cutlass-kernel")

Expand All @@ -25,27 +25,28 @@
"-Xcompiler=-fno-strict-aliasing",
"-Xcompiler=-fPIE",
"-Xcompiler=-lcuda",
"-DNDEBUG",
"-DCUTLASS_TEST_LEVEL=0",
"-DCUTLASS_DEBUG_TRACE_LEVEL=0",
"-DCUTLASS_TEST_ENABLE_CACHED_RESULTS=1",
"-DCUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED=1",
"-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1",
]
PREPROCESSOR_FLAGS = [
f"-I{str(COLFAX_CUTLASS_PATH.joinpath("lib").resolve())}",
f"-I{str(COLFAX_CUTLASS_PATH.joinpath("include").resolve())}",
f"-I{str(FBGEMM_CUTLASS_PATH.joinpath("include").resolve())}",
f"-I{str(FBGEMM_CUTLASS_PATH.joinpath("examples", "commmon").resolve())}",
f"-I{str(FBGEMM_CUTLASS_PATH.joinpath("tools", "util", "include").resolve())}",
f"-I{str(COLFAX_CUTLASS_PATH.joinpath('lib').resolve())}",
f"-I{str(COLFAX_CUTLASS_PATH.joinpath('include').resolve())}",
f"-I{str(FBGEMM_CUTLASS_PATH.joinpath('include').resolve())}",
f"-I{str(FBGEMM_CUTLASS_PATH.joinpath('examples', 'commmon').resolve())}",
f"-I{str(FBGEMM_CUTLASS_PATH.joinpath('tools', 'util', 'include').resolve())}",
f"-I{CUDA_HOME}/include",
f"-rpath,'{CUDA_HOME}/lib64'",
f"-rpath,'{CUDA_HOME}/lib'"
f"-Wl,-rpath,'{CUDA_HOME}/lib64'",
f"-Wl,-rpath,'{CUDA_HOME}/lib'"
]
FMHA_SOURCES = [
# Source 1
f"{str(COLFAX_CUTLASS_PATH.joinpath("src", "fmha", "fmha_forward.cu").resolve())}"
f"{str(COLFAX_CUTLASS_PATH.joinpath('src', 'fmha', 'fmha_forward.cu').resolve())}",
# Source 2
f"{str(COLFAX_CUTLASS_TRITONBENCH_PATH.joinpath("src", "fmha", "register_op.cu").resolve())}"
f"{str(COLFAX_CUTLASS_TRITONBENCH_PATH.joinpath('src', 'fmha', 'register_op.cu').resolve())}",
"-o",
"fmha_forward_lib",
]
Expand All @@ -69,6 +70,8 @@ def install_cutlass():
cmd.extend(PREPROCESSOR_FLAGS)
cmd.extend(NVCC_FLAGS)
cmd.extend(FMHA_SOURCES)
print(" ".join(cmd))
print(str(output_dir.resolve()))
subprocess.check_call(cmd, cwd=str(output_dir.resolve()))
return str(output_dir.joinpath(FMHA_SOURCES[-1]).resolve())

Expand Down

0 comments on commit 9c37293

Please sign in to comment.