diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
index 54965e741819..cafaed9c793a 100644
--- a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
+++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
@@ -242,4 +242,11 @@ def TTG_LocalStoreOp : TTG_Op<"local_store", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
 }
 
+def TTG_GroupSched : TTG_Op<"group_sched"> {
+  let summary = "A placeholder Op for the instruction group scheduling";
+  let description = [{
+    A placeholder Op for the instruction group scheduling.
+  }];
+}
+
 #endif
diff --git a/python/src/llvm.cc b/python/src/llvm.cc
index bd1599779a7c..663152d5f659 100644
--- a/python/src/llvm.cc
+++ b/python/src/llvm.cc
@@ -116,6 +116,10 @@ std::string translateLLVMIRToASM(llvm::Module &module,
   opt.NoInfsFPMath = false;
   opt.NoNaNsFPMath = true;
   opt.TrapUnreachable = true;
+
+  opt.MCOptions.AsmVerbose = true;
+  opt.MCOptions.PreserveAsmComments = true;
+
   std::unique_ptr<llvm::TargetMachine> machine{target->createTargetMachine(
       module.getTargetTriple(), proc, features, opt, llvm::Reloc::PIC_,
       std::nullopt,
diff --git a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py
index d2c3aa4f3f12..27917b6810f2 100644
--- a/third_party/amd/backend/compiler.py
+++ b/third_party/amd/backend/compiler.py
@@ -180,6 +180,8 @@ def make_llir(src, metadata, options):
         passes.convert.add_index_to_llvmir(pm)
         passes.ttgpuir.add_allocate_shared_memory(pm)
+
+        amd.passes.ttgpuir.insert_sched_group_barriers(pm)
         ## __HIP_FTZ is used to control the denorm flushing behavior of exp2 op as follows:
         ## 1. If __HIP_FTZ = 1, exp2 flushes denorms in input and output regardless
         ## of the value of kernel arg `allow_flush_denorm`.
@@ -197,6 +199,7 @@ def make_llir(src, metadata, options):
         passes.common.add_canonicalizer(pm)
         passes.common.add_cse(pm)
         passes.common.add_symbol_dce(pm)
+        amd.passes.ttgpuir.add_sched_group_barriers(pm)
         if os.environ.get("TRITON_DISABLE_LINE_INFO", "0") == "0":
             passes.llvmir.add_di_scope(pm)
         # This pass (`add_builtin_func_to_llvmir`) serves as a temporary workaround to address the issue of excessive basic block
diff --git a/third_party/amd/include/TritonAMDGPUToLLVM/Passes.h b/third_party/amd/include/TritonAMDGPUToLLVM/Passes.h
index be9efe4033a4..559aef8e9ec4 100644
--- a/third_party/amd/include/TritonAMDGPUToLLVM/Passes.h
+++ b/third_party/amd/include/TritonAMDGPUToLLVM/Passes.h
@@ -34,6 +34,9 @@ createOptimizeLDSUsagePass(StringRef arch, int32_t customLDSLimit = 0);
 std::unique_ptr<OperationPass<ModuleOp>>
 createConvertTritonAMDGPUToLLVMPass(StringRef targetArch, bool ftz);
 std::unique_ptr<OperationPass<ModuleOp>> createConvertBuiltinFuncToLLVMPass();
+std::unique_ptr<OperationPass<ModuleOp>>
+createSchedGroupBarriersInsertionPass();
+std::unique_ptr<OperationPass<ModuleOp>> createSchedGroupBarriersLoweringPass();
 
 #define GEN_PASS_REGISTRATION
 #include "TritonAMDGPUToLLVM/Passes.h.inc"
diff --git a/third_party/amd/include/TritonAMDGPUToLLVM/Passes.td b/third_party/amd/include/TritonAMDGPUToLLVM/Passes.td
index b27c3bf8f929..0dbdb437edac 100644
--- a/third_party/amd/include/TritonAMDGPUToLLVM/Passes.td
+++ b/third_party/amd/include/TritonAMDGPUToLLVM/Passes.td
@@ -55,4 +55,18 @@ def ConvertBuiltinFuncToLLVM : Pass<"convert-builtin-func-to-llvm", "mlir::ModuleOp"> {
 }
 
+def SchedGroupBarriersInsertion : Pass<"insert-sched-group-barriers", "mlir::ModuleOp"> {
+  let summary = "Insert Scheduling Group Barriers";
+  let constructor = "mlir::triton::createSchedGroupBarriersInsertionPass()";
+
+  let dependentDialects = ["mlir::LLVM::LLVMDialect"];
+}
+
+def SchedGroupBarriersLowering : Pass<"lower-sched-group-barriers", "mlir::ModuleOp"> {
+  let summary = "Lower Scheduling Group Barriers to LLVM intrinsics";
+  let constructor =
+      "mlir::triton::createSchedGroupBarriersLoweringPass()";
+
+  let dependentDialects = ["mlir::LLVM::LLVMDialect"];
+}
+
 #endif
diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt b/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt
index 705c4258d052..3b989242af6f 100644
--- a/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt
+++ b/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt
@@ -18,6 +18,7 @@ add_triton_library(TritonAMDGPUToLLVM
   OptimizeLDSUsage.cpp
   OptimizeLDSUtility.cpp
   SPMDOpToLLVM.cpp
+  SchedInstructions.cpp
 
 DEPENDS
   TritonAMDGPUConversionPassIncGen
diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp
new file mode 100644
index 000000000000..3b29b5d9dda6
--- /dev/null
+++ b/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp
@@ -0,0 +1,207 @@
+#include "TritonAMDGPUToLLVM/Passes.h"
+
+#include "TritonAMDGPUTransforms/MfmaGroup.h"
+#include "Utility.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "triton/Analysis/AxisInfo.h"
+
+namespace mlir {
+namespace triton {
+#define GEN_PASS_DEF_SCHEDGROUPBARRIERSINSERTION
+#define GEN_PASS_DEF_SCHEDGROUPBARRIERSLOWERING
+#include "TritonAMDGPUToLLVM/Passes.h.inc"
+} // namespace triton
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+// Bitmask values accepted by the llvm.amdgcn.sched.group.barrier /
+// llvm.amdgcn.sched.barrier intrinsics (see LLVM AMDGPU intrinsic docs).
+enum class InstructionMaskEnum : int64_t {
+  NONE = 0x0000000,
+  VALU = 0x00000002,
+  SALU = 0x00000004,
+  MFMA = 0x00000008,
+  ALL_VMEM = 0x00000010,
+  VMEM_READ = 0x00000020,
+  VMEM_WRITE = 0x00000040,
+  ALL_DS = 0x00000080,
+  DS_READ = 0x00000100,
+  DS_WRITE = 0x00000200
+};
+
+// NOTE(review): experiment kill-switch — both passes are no-ops while false.
+const bool modifyScheduling{false};
+// const bool modifyScheduling{true};
+
+// Emit one `llvm.amdgcn.sched.group.barrier(mask, size, groupId)` call at the
+// builder's current insertion point.
+void buildSchedGroupBarrier(PatternRewriter &builder,
+                            InstructionMaskEnum maskValue, int sizeValue,
+                            int groupIdValue) {
+  MLIRContext *ctx = builder.getContext();
+  Location loc = builder.getUnknownLoc();
+  auto intrinsicName = StringAttr::get(ctx, "llvm.amdgcn.sched.group.barrier");
+  LLVM::FastmathFlagsAttr defaultFlags{};
+  Type i32 = builder.getI32Type();
+  auto mask = builder.create<LLVM::ConstantOp>(
+      loc, builder.getIntegerAttr(i32, static_cast<int64_t>(maskValue)));
+  auto size = builder.create<LLVM::ConstantOp>(
+      loc, builder.getIntegerAttr(i32, sizeValue));
+  auto groupId = builder.create<LLVM::ConstantOp>(
+      loc, builder.getIntegerAttr(i32, groupIdValue));
+  builder.create<LLVM::CallIntrinsicOp>(loc, TypeRange{}, intrinsicName,
+                                        ValueRange{mask, size, groupId},
+                                        defaultFlags);
+}
+
+// Emit one `llvm.amdgcn.sched.barrier(mask)` call at the rewriter's current
+// insertion point and return the created operation.
+Operation *generatedSchedBarrier(PatternRewriter &rewriter,
+                                 InstructionMaskEnum maskValue) {
+  MLIRContext *ctx = rewriter.getContext();
+  Location loc = rewriter.getUnknownLoc();
+  auto intrinsicName = StringAttr::get(ctx, "llvm.amdgcn.sched.barrier");
+  LLVM::FastmathFlagsAttr defaultFlags{};
+  Type i32 = rewriter.getI32Type();
+  auto mask = rewriter.create<LLVM::ConstantOp>(
+      loc, rewriter.getIntegerAttr(i32, static_cast<int64_t>(maskValue)));
+  return rewriter.create<LLVM::CallIntrinsicOp>(loc, TypeRange{}, intrinsicName,
+                                                ValueRange{mask}, defaultFlags);
+}
+
+// Replaces the GroupSched placeholder with sched.barrier / sched.group.barrier
+// intrinsic calls that interleave DS reads/writes with MFMAs, sized from a
+// count of the surrounding block's instructions.
+struct SchedGroupBarriersRewriter
+    : public OpRewritePattern<triton::gpu::GroupSched> {
+  using OpRewritePattern<triton::gpu::GroupSched>::OpRewritePattern;
+  LogicalResult matchAndRewrite(triton::gpu::GroupSched schedBarrier,
+                                PatternRewriter &rewriter) const override {
+
+    Block *block = schedBarrier->getBlock();
+
+    size_t numGlbLoads = 0;
+    block->walk([&numGlbLoads](LLVM::CallOp callOp) {
+      StringRef calleeName = callOp.getCallee().value();
+      if (calleeName.contains("__predicated_load_vector"))
+        ++numGlbLoads;
+    });
+
+    size_t numDsReads = 0;
+    block->walk([&numDsReads](LLVM::LoadOp op) {
+      auto operandType = op.getOperand().getType();
+      if (auto ptr = llvm::dyn_cast<LLVM::LLVMPointerType>(operandType))
+        if (ptr.getAddressSpace() == 3)
+          ++numDsReads;
+    });
+
+    size_t numDsWrites = 0;
+    block->walk([&numDsWrites](LLVM::StoreOp op) {
+      auto operandType = op.getOperand(1).getType();
+      if (auto ptr = llvm::dyn_cast<LLVM::LLVMPointerType>(operandType))
+        if (ptr.getAddressSpace() == 3)
+          ++numDsWrites;
+    });
+
+    size_t numMfmas = 0;
+    block->walk([&numMfmas](Operation *op) {
+      StringRef opName = op->getName().getStringRef();
+      if (opName.contains("mfma"))
+        ++numMfmas;
+    });
+
+    llvm::dbgs() << "group scheduling info: ["
+                 << "numGlbLoads: " << numGlbLoads << ", "
+                 << "numDsReads: " << numDsReads << ", "
+                 << "numDsWrites: " << numDsWrites << ", "
+                 << "numMfmas: " << numMfmas << "]\n";
+
+    size_t barrierCounter{0};
+    block->walk([&barrierCounter, &rewriter](ROCDL::BarrierOp op) {
+      if (barrierCounter == 1) {
+        rewriter.setInsertionPointAfter(op);
+        return WalkResult::interrupt();
+      }
+      ++barrierCounter;
+      return WalkResult::advance();
+    });
+
+    // rewriter.setInsertionPointToStart(block);
+    generatedSchedBarrier(rewriter, InstructionMaskEnum::NONE);
+
+    rewriter.setInsertionPointAfter(schedBarrier);
+    const size_t numIssues = numGlbLoads;
+    // Guard: with no predicated global loads the divisions below would be
+    // UB (divide by zero) and (numMfmas / numIssues) - 3 would wrap.
+    if (numIssues == 0) {
+      rewriter.eraseOp(schedBarrier);
+      return mlir::success();
+    }
+    for (size_t i = 0; i < numIssues; ++i) {
+      buildSchedGroupBarrier(rewriter, InstructionMaskEnum::MFMA, 1, 0);
+      buildSchedGroupBarrier(rewriter, InstructionMaskEnum::DS_READ,
+                             numDsReads / numIssues, 0);
+      buildSchedGroupBarrier(rewriter, InstructionMaskEnum::MFMA, 1, 0);
+      buildSchedGroupBarrier(rewriter, InstructionMaskEnum::DS_WRITE,
+                             numDsWrites / numIssues, 0);
+      buildSchedGroupBarrier(rewriter, InstructionMaskEnum::MFMA, 1, 0);
+      buildSchedGroupBarrier(rewriter, InstructionMaskEnum::MFMA,
+                             (numMfmas / numIssues) - 3, 0);
+    }
+    generatedSchedBarrier(rewriter, InstructionMaskEnum::NONE);
+    rewriter.eraseOp(schedBarrier);
+    return mlir::success();
+  }
+};
+
+// Lowers every GroupSched placeholder via SchedGroupBarriersRewriter.
+struct SchedGroupBarriersLowering
+    : public triton::impl::SchedGroupBarriersLoweringBase<
+          SchedGroupBarriersLowering> {
+
+  void runOnOperation() override {
+    MLIRContext *ctx = &getContext();
+    ModuleOp mod = getOperation();
+
+    if (!modifyScheduling)
+      return;
+
+    ConversionTarget target(*ctx);
+    target.addLegalDialect<LLVM::LLVMDialect>();
+    target.addIllegalOp<triton::gpu::GroupSched>();
+
+    RewritePatternSet patterns(ctx);
+    patterns.add<SchedGroupBarriersRewriter>(ctx);
+
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns)))) {
+      signalPassFailure();
+    }
+  }
+};
+
+// Inserts a GroupSched placeholder right after each tt.dot so the lowering
+// pass can find the anchor once the module is in LLVM dialect.
+struct SchedGroupBarriersInsertion
+    : public triton::impl::SchedGroupBarriersInsertionBase<
+          SchedGroupBarriersInsertion> {
+
+  void insertPlaceholder(mlir::OpBuilder &builder, triton::DotOp dot) {
+    mlir::OpBuilder::InsertionGuard guard(builder);
+    builder.setInsertionPointAfter(dot);
+    Location loc = builder.getUnknownLoc();
+    builder.create<triton::gpu::GroupSched>(loc);
+  }
+
+  void runOnOperation() override {
+    MLIRContext *ctx = &getContext();
+    ModuleOp mod = getOperation();
+
+    if (!modifyScheduling)
+      return;
+
+    mlir::OpBuilder builder(ctx);
+    mod.walk(
+        [this, &builder](triton::DotOp op) { insertPlaceholder(builder, op); });
+  }
+};
+} // namespace
+
+namespace mlir {
+namespace triton {
+std::unique_ptr<OperationPass<ModuleOp>>
+createSchedGroupBarriersLoweringPass() {
+  return std::make_unique<SchedGroupBarriersLowering>();
+}
+
+std::unique_ptr<OperationPass<ModuleOp>>
+createSchedGroupBarriersInsertionPass() {
+  return std::make_unique<SchedGroupBarriersInsertion>();
+}
+} // namespace triton
+} // namespace mlir
diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp
index 8649911a7c2d..0753df0573d8 100644
--- a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp
+++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp
@@ -57,6 +57,7 @@ class TritonLLVMConversionTarget : public ConversionTarget {
     addIllegalDialect<triton::gpu::TritonGPUDialect>();
    addIllegalDialect<mlir::gpu::GPUDialect>();
     addLegalOp<mlir::UnrealizedConversionCastOp>();
+    addLegalOp<triton::gpu::GroupSched>();
   }
 };
diff --git a/third_party/amd/python/triton_amd.cc b/third_party/amd/python/triton_amd.cc
index a6ef2fec7c67..c738ecec7b33 100644
--- a/third_party/amd/python/triton_amd.cc
+++ b/third_party/amd/python/triton_amd.cc
@@ -36,6 +36,9 @@ const char *const amdTargetTriple = "amdgcn-amd-amdhsa";
 void init_triton_amd_passes_ttgpuir(py::module &&m) {
   using namespace mlir::triton;
+  m.def("insert_sched_group_barriers", [](mlir::PassManager &pm) {
+    pm.addPass(createSchedGroupBarriersInsertionPass());
+  });
   m.def("add_to_llvmir",
         [](mlir::PassManager &pm, const std::string &arch, bool ftz) {
           pm.addPass(createConvertTritonAMDGPUToLLVMPass(arch, ftz));
@@ -43,6 +46,9 @@ void init_triton_amd_passes_ttgpuir(py::module &&m) {
   m.def("add_builtin_func_to_llvmir", [](mlir::PassManager &pm) {
     pm.addPass(createConvertBuiltinFuncToLLVMPass());
   });
+  m.def("add_sched_group_barriers", [](mlir::PassManager &pm) {
+    pm.addPass(createSchedGroupBarriersLoweringPass());
+  });
   m.def("add_decompose_unsupported_conversions",
         [](mlir::PassManager &pm, const std::string &arch) {
           pm.addPass(