Commit e254a02: Modified insertion of IGLP_OPT intrinsics
ravil-mobile committed Jul 10, 2024 (parent: 13b20fe)
Showing 8 changed files with 72 additions and 21 deletions.
9 changes: 8 additions & 1 deletion third_party/amd/backend/compiler.py
@@ -166,8 +166,15 @@ def make_llir(src, metadata, options):
## depends on the value of kernel arg `allow_flush_denorm`.
## 3. __HIP_FTZ is default to 1 and not exposed as a kernel argument.
## For now it is used as a controller for developers only.
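## `AMD_OPS_SCHED_MODE` below is a developer-facing environment variable that
## selects an instruction-scheduling variant for the AMDGPU-to-LLVM lowering.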
sched_mode = os.environ.get("AMD_OPS_SCHED_MODE", "")
allowed = ["iglp-opt-0", "iglp-opt-1", ""]
if sched_mode not in allowed:
    raise RuntimeError(f'unknown mode for `AMD_OPS_SCHED_MODE`. Given `{sched_mode}`. Allowed: {", ".join(allowed)}')

__HIP_FTZ = True
amd.passes.ttgpuir.add_to_llvmir(pm, options.arch, __HIP_FTZ)
amd.passes.ttgpuir.add_to_llvmir(pm, options.arch, __HIP_FTZ, sched_mode)
passes.common.add_canonicalizer(pm)
passes.common.add_cse(pm)

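As a usage sketch (not part of the diff): with this change in place, a developer
picks a scheduling variant by setting the environment variable before the kernel
is compiled, for example:

import os
os.environ["AMD_OPS_SCHED_MODE"] = "iglp-opt-0"  # or "iglp-opt-1"; unset means default scheduling
# Compile and launch the Triton kernel as usual; make_llir() reads and
# validates the variable, then threads it into the AMDGPU-to-LLVM pass.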
3 changes: 2 additions & 1 deletion third_party/amd/include/TritonAMDGPUToLLVM/Passes.h
@@ -25,7 +25,8 @@ createDecomposeUnsupportedConversionsPass(StringRef targetArch);
} // namespace AMD

std::unique_ptr<OperationPass<ModuleOp>>
createConvertTritonAMDGPUToLLVMPass(StringRef targetArch, bool ftz);
createConvertTritonAMDGPUToLLVMPass(StringRef targetArch, bool ftz,
std::string schedMode);
std::unique_ptr<OperationPass<ModuleOp>> createConvertBuiltinFuncToLLVMPass();

#define GEN_PASS_REGISTRATION
4 changes: 3 additions & 1 deletion third_party/amd/include/TritonAMDGPUToLLVM/Passes.td
@@ -15,7 +15,7 @@ def DecomposeUnsupportedAMDConversions : Pass<"decompose-unsupported-amd-convers

def ConvertTritonAMDGPUToLLVM : Pass<"convert-triton-amdgpu-to-llvm", "mlir::ModuleOp"> {
let summary = "Convert TritonGPU to LLVM";
let constructor = "mlir::triton::createConvertTritonAMDGPUToLLVMPass(\"\", /*ftz=*/true)";
let constructor = "mlir::triton::createConvertTritonAMDGPUToLLVMPass(\"\", /*ftz=*/true, \"\")";

let dependentDialects = ["mlir::arith::ArithDialect",
"mlir::math::MathDialect",
@@ -32,6 +32,8 @@ def ConvertTritonAMDGPUToLLVM : Pass<"convert-triton-amdgpu-to-llvm", "mlir::Mod
"gfx target device architecture, e.g., gfx942">,
Option<"ftz", "ftz", "bool", /*default*/"true",
"flush denorms for math functions">,
Option<"sched", "sched", "std::string", /*default*/"\"\"",
"scheduling variants">,
];
}

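Registering `sched` as a pass option should also make the variant selectable
when the pass is driven standalone (e.g. `sched=iglp-opt-1` inside the
`--convert-triton-amdgpu-to-llvm` option string), alongside the programmatic
path used by the Python frontend.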
19 changes: 14 additions & 5 deletions third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM.cpp
@@ -9,7 +9,8 @@ using ::mlir::triton::gpu::getShapePerCTA;
namespace mlir::triton::AMD {
LogicalResult convertMFMA(triton::DotOp op, triton::DotOp::Adaptor adaptor,
const LLVMTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter);
ConversionPatternRewriter &rewriter,
StringRef schedMode);

LogicalResult convertWMMA(triton::DotOp op, triton::DotOp::Adaptor adaptor,
const LLVMTypeConverter *typeConverter,
@@ -18,7 +19,11 @@ LogicalResult convertWMMA(triton::DotOp op, triton::DotOp::Adaptor adaptor,

namespace {
struct DotOpConversion : public ConvertOpToLLVMPattern<triton::DotOp> {
using ConvertOpToLLVMPattern<triton::DotOp>::ConvertOpToLLVMPattern;
DotOpConversion(LLVMTypeConverter &typeConverter, PatternBenefit benefit,
StringRef schedMode)
: ConvertOpToLLVMPattern<triton::DotOp>(typeConverter, benefit),
schedMode(schedMode) {}

LogicalResult
matchAndRewrite(triton::DotOp op, OpAdaptor adaptor,
@@ -37,7 +42,8 @@ struct DotOpConversion : public ConvertOpToLLVMPattern<triton::DotOp> {
if (!isOuter) {
auto dEncoding = cast<RankedTensorType>(D.getType()).getEncoding();
if (isa<AMDMfmaEncodingAttr>(dEncoding) && supportMFMA(op)) {
return AMD::convertMFMA(op, adaptor, getTypeConverter(), rewriter);
return AMD::convertMFMA(op, adaptor, getTypeConverter(), rewriter,
schedMode);
}
if (isa<AMDWmmaEncodingAttr>(dEncoding)) {
return AMD::convertWMMA(op, adaptor, getTypeConverter(), rewriter);
@@ -51,14 +57,17 @@ struct DotOpConversion : public ConvertOpToLLVMPattern<triton::DotOp> {
llvm::report_fatal_error(
"Unsupported DotOp found when converting TritonGPU to LLVM.");
}

private:
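// Scheduling variant ("iglp-opt-0", "iglp-opt-1", or "" for none) forwarded
// to the MFMA lowering.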
StringRef schedMode;
};
} // namespace

namespace mlir::triton::AMD {
void populateDotOpToLLVMPatterns(LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns, int numWarps,
ModuleAxisInfoAnalysis &axisInfoAnalysis,
PatternBenefit benefit) {
patterns.add<DotOpConversion>(typeConverter, benefit);
PatternBenefit benefit, StringRef schedMode) {
patterns.add<DotOpConversion>(typeConverter, benefit, schedMode);
}
} // namespace mlir::triton::AMD
36 changes: 32 additions & 4 deletions third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp
@@ -38,20 +38,25 @@ using ::mlir::triton::gpu::SharedEncodingAttr;

using ValueTable = std::map<std::array<int, 3>, Value>;

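// Scheduling variants understood by the MFMA lowering. The numeric values of
// the IGLP_OPT_* members are passed verbatim as the i32 operand of the
// llvm.amdgcn.iglp.opt intrinsic.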
enum class SchedulingOptionsEnum { IGLP_OPT_0 = 0, IGLP_OPT_1 = 1, NONE_SCHED };

struct DotOpMFMAConversionHelper {
AMDMfmaEncodingAttr mfmaLayout;

ConversionPatternRewriter &rewriter;
const LLVMTypeConverter *typeConverter;
SchedulingOptionsEnum schedMode;
Location loc;
MLIRContext *ctx{};

explicit DotOpMFMAConversionHelper(AMDMfmaEncodingAttr mfmaLayout,
ConversionPatternRewriter &rewriter,
const LLVMTypeConverter *typeConverter,
SchedulingOptionsEnum schedMode,
Location loc)
: mfmaLayout(mfmaLayout), rewriter(rewriter),
typeConverter(typeConverter), loc(loc), ctx(mfmaLayout.getContext()) {}
typeConverter(typeConverter), schedMode(schedMode), loc(loc),
ctx(mfmaLayout.getContext()) {}

Value getThreadId() const {
auto llvmIndexTy = typeConverter->getIndexType();
@@ -70,6 +75,19 @@ struct DotOpMFMAConversionHelper {
return rewriter.create(loweredOp)->getResult(0);
}

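// Emits a call to the llvm.amdgcn.iglp.opt intrinsic with the selected
// variant id, asking the AMDGPU backend to apply the corresponding
// instruction-group-level-parallelism (IGLP) scheduling strategy to the
// surrounding block; a no-op when no scheduling mode was requested.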
void generateIglpIntrinsic() const {
if (schedMode == SchedulingOptionsEnum::NONE_SCHED)
return;
auto intrinsicName = StringAttr::get(ctx, "llvm.amdgcn.iglp.opt");
LLVM::FastmathFlagsAttr defaultFlags{};
Type i32 = rewriter.getI32Type();

auto option = rewriter.create<LLVM::ConstantOp>(
loc, rewriter.getIntegerAttr(i32, static_cast<int>(schedMode)));
rewriter.create<LLVM::CallIntrinsicOp>(loc, TypeRange{}, intrinsicName,
ValueRange{option}, defaultFlags);
}

int getNumSubmatrices(Type elementType, int mDim, int nDim) const {
if ((mDim == 64 && nDim == 4) || (mDim == 4 && nDim == 64))
return 1;
@@ -171,6 +189,8 @@ struct DotOpMFMAConversionHelper {
assert((mDim == nDim && (mDim == 32 || mDim == 16 || mDim == 4)) ||
(mDim == 64 && nDim == 4) || (mDim == 4 && nDim == 64));

generateIglpIntrinsic();

Value a = op.getA();
Value b = op.getB();
Value d = op.getD();
@@ -351,13 +371,13 @@ struct DotOpMFMAConversionHelper {
return dotOpVals;
}
};

} // namespace

namespace mlir::triton::AMD {
LogicalResult convertMFMA(triton::DotOp op, triton::DotOp::Adaptor adaptor,
const LLVMTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter) {
ConversionPatternRewriter &rewriter,
StringRef schedMode) {
auto rankedTType = [](Value tensor) {
return cast<RankedTensorType>(tensor.getType());
};
@@ -375,11 +395,19 @@ LogicalResult convertMFMA(triton::DotOp op, triton::DotOp::Adaptor adaptor,
cTensorTy.getShape()[1] == dTensorTy.getShape()[1] &&
"DotOp's $c operand should pass the same number of values as $d");

static const DenseMap<StringRef, SchedulingOptionsEnum> schedModesToEnum = {
{"iglp-opt-0", SchedulingOptionsEnum::IGLP_OPT_0},
{"iglp-opt-1", SchedulingOptionsEnum::IGLP_OPT_1},
{"", SchedulingOptionsEnum::NONE_SCHED}};
assert(schedModesToEnum.contains(schedMode) &&
"sched mode must be in the allowed set");

auto loc = op.getLoc();
auto mfmaLayout = cast<AMDMfmaEncodingAttr>(
cast<RankedTensorType>(op.getResult().getType()).getEncoding());

DotOpMFMAConversionHelper helper(mfmaLayout, rewriter, typeConverter, loc);
DotOpMFMAConversionHelper helper(mfmaLayout, rewriter, typeConverter,
schedModesToEnum.at(schedMode), loc);

return helper.convertDot(op, adaptor);
}
@@ -15,7 +15,7 @@ void populateConvertLayoutOpToLLVMPatterns(
void populateDotOpToLLVMPatterns(LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns, int numWarps,
ModuleAxisInfoAnalysis &axisInfoAnalysis,
PatternBenefit benefit);
PatternBenefit benefit, StringRef schedMode);
void populateElementwiseOpToLLVMPatterns(
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, bool ftz,
ModuleAxisInfoAnalysis &axisInfoAnalysis, ModuleAllocation &allocation,
12 changes: 8 additions & 4 deletions third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp
@@ -63,9 +63,11 @@ class TritonLLVMConversionTarget : public ConversionTarget {
struct ConvertTritonAMDGPUToLLVM
: public triton::impl::ConvertTritonAMDGPUToLLVMBase<
ConvertTritonAMDGPUToLLVM> {
explicit ConvertTritonAMDGPUToLLVM(StringRef targetArch, bool ftz) {
explicit ConvertTritonAMDGPUToLLVM(StringRef targetArch, bool ftz,
StringRef schedMode) {
this->arch = targetArch.str();
this->ftz = ftz;
this->sched = schedMode.str();
}

void getDependentDialects(DialectRegistry &registry) const override {
@@ -174,7 +176,7 @@ struct ConvertTritonAMDGPUToLLVM
mlir::triton::populateConvertLayoutOpToLLVMPatterns(
typeConverter, targetInfo, patterns, commonBenefit);
AMD::populateDotOpToLLVMPatterns(typeConverter, patterns, numWarps,
axisInfoAnalysis, AMDBenefit);
axisInfoAnalysis, AMDBenefit, sched);
AMD::populateElementwiseOpToLLVMPatterns(typeConverter, patterns, ftz,
axisInfoAnalysis, allocation,
targetInfo, AMDBenefit);
@@ -246,8 +248,10 @@ namespace mlir {
namespace triton {

std::unique_ptr<OperationPass<ModuleOp>>
createConvertTritonAMDGPUToLLVMPass(StringRef targetArch, bool ftz) {
return std::make_unique<ConvertTritonAMDGPUToLLVM>(targetArch, ftz);
createConvertTritonAMDGPUToLLVMPass(StringRef targetArch, bool ftz,
std::string schedMode) {
return std::make_unique<ConvertTritonAMDGPUToLLVM>(targetArch, ftz,
schedMode);
}

} // namespace triton
8 changes: 4 additions & 4 deletions third_party/amd/python/triton_amd.cc
@@ -34,10 +34,10 @@ namespace py = pybind11;
namespace {
void init_triton_amd_passes_ttgpuir(py::module &&m) {
using namespace mlir::triton;
m.def("add_to_llvmir",
[](mlir::PassManager &pm, const std::string &arch, bool ftz) {
pm.addPass(createConvertTritonAMDGPUToLLVMPass(arch, ftz));
});
m.def("add_to_llvmir", [](mlir::PassManager &pm, const std::string &arch,
bool ftz, const std::string &sched) {
pm.addPass(createConvertTritonAMDGPUToLLVMPass(arch, ftz, sched));
});
m.def("add_builtin_func_to_llvmir", [](mlir::PassManager &pm) {
pm.addPass(createConvertBuiltinFuncToLLVMPass());
});
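For illustration (assuming the usual `triton._C.libtriton` import path for
these bindings), the updated binding is invoked from Python exactly as in
compiler.py above:

from triton._C.libtriton import amd
# arguments: pass manager, arch, ftz, sched
amd.passes.ttgpuir.add_to_llvmir(pm, "gfx942", True, "iglp-opt-1")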
