From 0e6800bd386a32c10963caa03a9e9494cbeceaa8 Mon Sep 17 00:00:00 2001 From: Lixun Zhang Date: Thu, 28 Mar 2024 01:49:01 +0000 Subject: [PATCH 1/3] Set llvm flag using LLVM_FLAG env var --- include/triton/Tools/Sys/GetEnv.hpp | 6 +++--- lib/Target/HSACO/HSACOTranslation.cpp | 8 ++++++++ 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/include/triton/Tools/Sys/GetEnv.hpp b/include/triton/Tools/Sys/GetEnv.hpp index 5d2e13c45b97..7169ba4c027f 100644 --- a/include/triton/Tools/Sys/GetEnv.hpp +++ b/include/triton/Tools/Sys/GetEnv.hpp @@ -30,9 +30,9 @@ namespace mlir::triton { const std::set ENV_VARS = { - "DISABLE_MMA_V3", "TRITON_DISABLE_LINE_INFO", "DISABLE_FAST_REDUCTION", - "ENABLE_TMA", "MLIR_ENABLE_DUMP", "LLVM_IR_ENABLE_DUMP", - "AMDGCN_ENABLE_DUMP", "TRUNCATE_F32_TO_BF16"}; + "DISABLE_MMA_V3", "TRITON_DISABLE_LINE_INFO", "DISABLE_FAST_REDUCTION", + "ENABLE_TMA", "MLIR_ENABLE_DUMP", "LLVM_IR_ENABLE_DUMP", + "AMDGCN_ENABLE_DUMP", "TRUNCATE_F32_TO_BF16", "LLVM_FLAG"}; namespace tools { diff --git a/lib/Target/HSACO/HSACOTranslation.cpp b/lib/Target/HSACO/HSACOTranslation.cpp index 69e755ec3f5a..a4192d69e93a 100644 --- a/lib/Target/HSACO/HSACOTranslation.cpp +++ b/lib/Target/HSACO/HSACOTranslation.cpp @@ -135,6 +135,14 @@ std::string generate_amdgcn_assembly(llvm::Module *module, if (machine == nullptr) return ""; + std::string llvm_flag = mlir::triton::tools::getenv("LLVM_FLAG"); + if (!llvm_flag.empty()) { + std::vector args; + args.push_back((char *)("triton")); + args.push_back((char *)(llvm_flag.c_str())); + llvm::cl::ParseCommandLineOptions(args.size(), &args[0]); + } + llvm::SmallVector buffer; llvm::legacy::PassManager pass; llvm::raw_svector_ostream stream(buffer); From b56afbaaa1605278e2c6a29c3e86ac9b87e93a89 Mon Sep 17 00:00:00 2001 From: Lixun Zhang Date: Thu, 28 Mar 2024 01:50:33 +0000 Subject: [PATCH 2/3] Use BLOCK_M=128 for MI300 --- python/perf-kernels/06-fused-attention-fwd-transV.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/python/perf-kernels/06-fused-attention-fwd-transV.py b/python/perf-kernels/06-fused-attention-fwd-transV.py index 35a6da764746..64eb5e07c532 100644 --- a/python/perf-kernels/06-fused-attention-fwd-transV.py +++ b/python/perf-kernels/06-fused-attention-fwd-transV.py @@ -163,10 +163,9 @@ def forward(ctx, q, k, v, sm_scale): kpack = 1 else: ## D_HEAD = 128 - ## For fp16, pick BLOCK_M=256, num_warps=8 - ## For fp8, pick BLOCK_M=128, num_warps=4 + ## Tuning for MI300 ## TODO (zhanglx): add tuning infra for FA - BLOCK_M = 128 if TORCH_HAS_FP8E4 and q.dtype == torch.float8_e4m3fnuz else 256 + BLOCK_M = 128 BLOCK_N = 128 waves_per_eu = 2 num_warps = BLOCK_M // 32 From 9b3596d6cda36b1f280f890e357064cac88d0aff Mon Sep 17 00:00:00 2001 From: Lixun Zhang Date: Fri, 12 Apr 2024 21:50:18 -0500 Subject: [PATCH 3/3] Rename env var --- include/triton/Tools/Sys/GetEnv.hpp | 2 +- lib/Target/HSACO/HSACOTranslation.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/include/triton/Tools/Sys/GetEnv.hpp b/include/triton/Tools/Sys/GetEnv.hpp index 7169ba4c027f..4650950954c2 100644 --- a/include/triton/Tools/Sys/GetEnv.hpp +++ b/include/triton/Tools/Sys/GetEnv.hpp @@ -32,7 +32,7 @@ namespace mlir::triton { const std::set ENV_VARS = { "DISABLE_MMA_V3", "TRITON_DISABLE_LINE_INFO", "DISABLE_FAST_REDUCTION", "ENABLE_TMA", "MLIR_ENABLE_DUMP", "LLVM_IR_ENABLE_DUMP", - "AMDGCN_ENABLE_DUMP", "TRUNCATE_F32_TO_BF16", "LLVM_FLAG"}; + "AMDGCN_ENABLE_DUMP", "TRUNCATE_F32_TO_BF16", "TRITON_LLVM_FLAG"}; namespace tools { diff --git a/lib/Target/HSACO/HSACOTranslation.cpp b/lib/Target/HSACO/HSACOTranslation.cpp index a4192d69e93a..9d5de2514e51 100644 --- a/lib/Target/HSACO/HSACOTranslation.cpp +++ b/lib/Target/HSACO/HSACOTranslation.cpp @@ -135,7 +135,7 @@ std::string generate_amdgcn_assembly(llvm::Module *module, if (machine == nullptr) return ""; - std::string llvm_flag = mlir::triton::tools::getenv("LLVM_FLAG"); + std::string llvm_flag = mlir::triton::tools::getenv("TRITON_LLVM_FLAG"); if (!llvm_flag.empty()) { std::vector args; args.push_back((char *)("triton"));