From e254a020b4729d939347f50539ebd9cf34838e1c Mon Sep 17 00:00:00 2001 From: ravil-mobile Date: Wed, 10 Jul 2024 13:49:11 +0000 Subject: [PATCH] Modified insertion of IGLP_OPT intrinsics --- third_party/amd/backend/compiler.py | 9 ++++- .../amd/include/TritonAMDGPUToLLVM/Passes.h | 3 +- .../amd/include/TritonAMDGPUToLLVM/Passes.td | 4 ++- .../lib/TritonAMDGPUToLLVM/DotOpToLLVM.cpp | 19 +++++++--- .../TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp | 36 ++++++++++++++++--- .../PatternTritonGPUOpToLLVM.h | 2 +- .../TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp | 12 ++++--- third_party/amd/python/triton_amd.cc | 8 ++--- 8 files changed, 72 insertions(+), 21 deletions(-) diff --git a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py index c30f24aaa602..fa95feb33c10 100644 --- a/third_party/amd/backend/compiler.py +++ b/third_party/amd/backend/compiler.py @@ -166,8 +166,15 @@ def make_llir(src, metadata, options): ## depends on the value of kernel arg `allow_flush_denorm`. ## 3. __HIP_FTZ is default to 1 and not exposed as a kernel argument. ## For now it is used as a controller for developers only. + sched_mode = "" + if "AMD_OPS_SCHED_MODE" in os.environ.keys(): + sched_mode = os.environ['AMD_OPS_SCHED_MODE'] + allowed = ["iglp-opt-0", "iglp-opt-1", ""] + if not sched_mode in allowed: + raise RuntimeError(f'unknown mode for `AMD_OPS_SCHED_MODE`. Given `{sched_mode}`. Allowed: {", ".join(allowed)}') + __HIP_FTZ = True - amd.passes.ttgpuir.add_to_llvmir(pm, options.arch, __HIP_FTZ) + amd.passes.ttgpuir.add_to_llvmir(pm, options.arch, __HIP_FTZ, sched_mode) passes.common.add_canonicalizer(pm) passes.common.add_cse(pm) diff --git a/third_party/amd/include/TritonAMDGPUToLLVM/Passes.h b/third_party/amd/include/TritonAMDGPUToLLVM/Passes.h index df5ad78494ab..16578d035602 100644 --- a/third_party/amd/include/TritonAMDGPUToLLVM/Passes.h +++ b/third_party/amd/include/TritonAMDGPUToLLVM/Passes.h @@ -25,7 +25,8 @@ createDecomposeUnsupportedConversionsPass(StringRef targetArch); } // namespace AMD std::unique_ptr> -createConvertTritonAMDGPUToLLVMPass(StringRef targetArch, bool ftz); +createConvertTritonAMDGPUToLLVMPass(StringRef targetArch, bool ftz, + std::string schedMode); std::unique_ptr> createConvertBuiltinFuncToLLVMPass(); #define GEN_PASS_REGISTRATION diff --git a/third_party/amd/include/TritonAMDGPUToLLVM/Passes.td b/third_party/amd/include/TritonAMDGPUToLLVM/Passes.td index 986c6763bbb3..f59139f22189 100644 --- a/third_party/amd/include/TritonAMDGPUToLLVM/Passes.td +++ b/third_party/amd/include/TritonAMDGPUToLLVM/Passes.td @@ -15,7 +15,7 @@ def DecomposeUnsupportedAMDConversions : Pass<"decompose-unsupported-amd-convers def ConvertTritonAMDGPUToLLVM : Pass<"convert-triton-amdgpu-to-llvm", "mlir::ModuleOp"> { let summary = "Convert TritonGPU to LLVM"; - let constructor = "mlir::triton::createConvertTritonAMDGPUToLLVMPass(\"\", /*ftz=*/true)"; + let constructor = "mlir::triton::createConvertTritonAMDGPUToLLVMPass(\"\", /*ftz=*/true, \"\")"; let dependentDialects = ["mlir::arith::ArithDialect", "mlir::math::MathDialect", @@ -32,6 +32,8 @@ def ConvertTritonAMDGPUToLLVM : Pass<"convert-triton-amdgpu-to-llvm", "mlir::Mod "gfx target device architecture, e.g., gfx942">, Option<"ftz", "ftz", "bool", /*default*/"true", "flush denorms for math functions">, + Option<"sched", "sched", "std::string", /*default*/"\"\"", + "scheduling variants">, ]; } diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM.cpp index 15237282172f..b25f70b8d1ff 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM.cpp @@ -9,7 +9,8 @@ using ::mlir::triton::gpu::getShapePerCTA; namespace mlir::triton::AMD { LogicalResult convertMFMA(triton::DotOp op, triton::DotOp::Adaptor adaptor, const LLVMTypeConverter *typeConverter, - ConversionPatternRewriter &rewriter); + ConversionPatternRewriter &rewriter, + StringRef schedMode); LogicalResult convertWMMA(triton::DotOp op, triton::DotOp::Adaptor adaptor, const LLVMTypeConverter *typeConverter, @@ -18,7 +19,11 @@ LogicalResult convertWMMA(triton::DotOp op, triton::DotOp::Adaptor adaptor, namespace { struct DotOpConversion : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + // using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + DotOpConversion(LLVMTypeConverter &typeConverter, PatternBenefit benefit, + StringRef schedMode) + : ConvertOpToLLVMPattern(typeConverter, benefit), + schedMode(schedMode) {} LogicalResult matchAndRewrite(triton::DotOp op, OpAdaptor adaptor, @@ -37,7 +42,8 @@ struct DotOpConversion : public ConvertOpToLLVMPattern { if (!isOuter) { auto dEncoding = cast(D.getType()).getEncoding(); if (isa(dEncoding) && supportMFMA(op)) { - return AMD::convertMFMA(op, adaptor, getTypeConverter(), rewriter); + return AMD::convertMFMA(op, adaptor, getTypeConverter(), rewriter, + schedMode); } if (isa(dEncoding)) { return AMD::convertWMMA(op, adaptor, getTypeConverter(), rewriter); @@ -51,6 +57,9 @@ struct DotOpConversion : public ConvertOpToLLVMPattern { llvm::report_fatal_error( "Unsupported DotOp found when converting TritonGPU to LLVM."); } + +private: + StringRef schedMode; }; } // namespace @@ -58,7 +67,7 @@ namespace mlir::triton::AMD { void populateDotOpToLLVMPatterns(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis, - PatternBenefit benefit) { - patterns.add(typeConverter, benefit); + PatternBenefit benefit, StringRef schedMode) { + patterns.add(typeConverter, benefit, schedMode); } } // namespace mlir::triton::AMD diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp index c190711f1022..5ff24739d9db 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp @@ -38,20 +38,25 @@ using ::mlir::triton::gpu::SharedEncodingAttr; using ValueTable = std::map, Value>; +enum class SchedulingOptionsEnum { IGLP_OPT_0 = 0, IGLP_OPT_1 = 1, NONE_SCHED }; + struct DotOpMFMAConversionHelper { AMDMfmaEncodingAttr mfmaLayout; ConversionPatternRewriter &rewriter; const LLVMTypeConverter *typeConverter; + SchedulingOptionsEnum schedMode; Location loc; MLIRContext *ctx{}; explicit DotOpMFMAConversionHelper(AMDMfmaEncodingAttr mfmaLayout, ConversionPatternRewriter &rewriter, const LLVMTypeConverter *typeConverter, + SchedulingOptionsEnum schedMode, Location loc) : mfmaLayout(mfmaLayout), rewriter(rewriter), - typeConverter(typeConverter), loc(loc), ctx(mfmaLayout.getContext()) {} + typeConverter(typeConverter), schedMode(schedMode), loc(loc), + ctx(mfmaLayout.getContext()) {} Value getThreadId() const { auto llvmIndexTy = typeConverter->getIndexType(); @@ -70,6 +75,19 @@ struct DotOpMFMAConversionHelper { return rewriter.create(loweredOp)->getResult(0); } + void generatedIglpIntrinsic() const { + if (schedMode == SchedulingOptionsEnum::NONE_SCHED) + return; + auto intrinsicName = StringAttr::get(ctx, "llvm.amdgcn.iglp.opt"); + LLVM::FastmathFlagsAttr defaultFlags{}; + Type i32 = rewriter.getI32Type(); + + auto option = rewriter.create( + loc, rewriter.getIntegerAttr(i32, static_cast(schedMode))); + rewriter.create(loc, TypeRange{}, intrinsicName, + ValueRange{option}, defaultFlags); + } + int getNumSubmatrices(Type elementType, int mDim, int nDim) const { if ((mDim == 64 && nDim == 4) || (mDim == 4 && nDim == 64)) return 1; @@ -171,6 +189,8 @@ struct DotOpMFMAConversionHelper { assert((mDim == nDim && (mDim == 32 || mDim == 16 || mDim == 4)) || (mDim == 64 && nDim == 4) || (mDim == 4 && nDim == 64)); + generatedIglpIntrinsic(); + Value a = op.getA(); Value b = op.getB(); Value d = op.getD(); @@ -351,13 +371,13 @@ struct DotOpMFMAConversionHelper { return dotOpVals; } }; - } // namespace namespace mlir::triton::AMD { LogicalResult convertMFMA(triton::DotOp op, triton::DotOp::Adaptor adaptor, const LLVMTypeConverter *typeConverter, - ConversionPatternRewriter &rewriter) { + ConversionPatternRewriter &rewriter, + StringRef schedMode) { auto rankedTType = [](Value tensor) { return cast(tensor.getType()); }; @@ -375,11 +395,19 @@ LogicalResult convertMFMA(triton::DotOp op, triton::DotOp::Adaptor adaptor, cTensorTy.getShape()[1] == dTensorTy.getShape()[1] && "DotOp's $c operand should pass the same number of values as $d"); + static const DenseMap schedModesToEnum = { + {"iglp-opt-0", SchedulingOptionsEnum::IGLP_OPT_0}, + {"iglp-opt-1", SchedulingOptionsEnum::IGLP_OPT_1}, + {"", SchedulingOptionsEnum::NONE_SCHED}}; + assert(schedModesToEnum.contains(schedMode) && + "sched mode must be in the allowed set"); + auto loc = op.getLoc(); auto mfmaLayout = cast( cast(op.getResult().getType()).getEncoding()); - DotOpMFMAConversionHelper helper(mfmaLayout, rewriter, typeConverter, loc); + DotOpMFMAConversionHelper helper(mfmaLayout, rewriter, typeConverter, + schedModesToEnum.at(schedMode), loc); return helper.convertDot(op, adaptor); } diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/PatternTritonGPUOpToLLVM.h b/third_party/amd/lib/TritonAMDGPUToLLVM/PatternTritonGPUOpToLLVM.h index 67e5369b8650..1a3a96874ad7 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/PatternTritonGPUOpToLLVM.h +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/PatternTritonGPUOpToLLVM.h @@ -15,7 +15,7 @@ void populateConvertLayoutOpToLLVMPatterns( void populateDotOpToLLVMPatterns(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis, - PatternBenefit benefit); + PatternBenefit benefit, StringRef schedMode); void populateElementwiseOpToLLVMPatterns( LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, bool ftz, ModuleAxisInfoAnalysis &axisInfoAnalysis, ModuleAllocation &allocation, diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp index 8649911a7c2d..b6ef605c251c 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp @@ -63,9 +63,11 @@ class TritonLLVMConversionTarget : public ConversionTarget { struct ConvertTritonAMDGPUToLLVM : public triton::impl::ConvertTritonAMDGPUToLLVMBase< ConvertTritonAMDGPUToLLVM> { - explicit ConvertTritonAMDGPUToLLVM(StringRef targetArch, bool ftz) { + explicit ConvertTritonAMDGPUToLLVM(StringRef targetArch, bool ftz, + StringRef schedMode) { this->arch = targetArch.str(); this->ftz = ftz; + this->sched = schedMode.str(); } void getDependentDialects(DialectRegistry ®istry) const override { @@ -174,7 +176,7 @@ struct ConvertTritonAMDGPUToLLVM mlir::triton::populateConvertLayoutOpToLLVMPatterns( typeConverter, targetInfo, patterns, commonBenefit); AMD::populateDotOpToLLVMPatterns(typeConverter, patterns, numWarps, - axisInfoAnalysis, AMDBenefit); + axisInfoAnalysis, AMDBenefit, sched); AMD::populateElementwiseOpToLLVMPatterns(typeConverter, patterns, ftz, axisInfoAnalysis, allocation, targetInfo, AMDBenefit); @@ -246,8 +248,10 @@ namespace mlir { namespace triton { std::unique_ptr> -createConvertTritonAMDGPUToLLVMPass(StringRef targetArch, bool ftz) { - return std::make_unique(targetArch, ftz); +createConvertTritonAMDGPUToLLVMPass(StringRef targetArch, bool ftz, + std::string schedMode) { + return std::make_unique(targetArch, ftz, + schedMode); } } // namespace triton diff --git a/third_party/amd/python/triton_amd.cc b/third_party/amd/python/triton_amd.cc index ba73746e0d37..cbf98aa7e096 100644 --- a/third_party/amd/python/triton_amd.cc +++ b/third_party/amd/python/triton_amd.cc @@ -34,10 +34,10 @@ namespace py = pybind11; namespace { void init_triton_amd_passes_ttgpuir(py::module &&m) { using namespace mlir::triton; - m.def("add_to_llvmir", - [](mlir::PassManager &pm, const std::string &arch, bool ftz) { - pm.addPass(createConvertTritonAMDGPUToLLVMPass(arch, ftz)); - }); + m.def("add_to_llvmir", [](mlir::PassManager &pm, const std::string &arch, + bool ftz, const std::string &sched) { + pm.addPass(createConvertTritonAMDGPUToLLVMPass(arch, ftz, sched)); + }); m.def("add_builtin_func_to_llvmir", [](mlir::PassManager &pm) { pm.addPass(createConvertBuiltinFuncToLLVMPass()); });