Added a sketch of the instr. sched. group barriers
ravil-mobile committed Aug 2, 2024
1 parent 1bb5868 commit c29d621
Showing 8 changed files with 221 additions and 0 deletions.
7 changes: 7 additions & 0 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
@@ -242,4 +242,11 @@ def TTG_LocalStoreOp : TTG_Op<"local_store", [DeclareOpInterfaceMethods<MemoryEf
  }];
}

def TTG_GroupSched : TTG_Op<"group_sched"> {
  let summary = "A placeholder op for instruction group scheduling";
  let description = [{
    A placeholder op for instruction group scheduling. It is inserted after dot
    ops and later lowered to AMDGPU scheduling-barrier intrinsics.
  }];
}

#endif
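The op takes no operands or results and declares no custom assembly format, so the insertion pass added below simply drops a bare marker after each dot op. A minimal sketch of the IR right after insertion, assuming the dialect's `triton_gpu.` prefix and the default generic printer (the `%acc`/`%a`/`%b`/`%c` names are illustrative):

    %acc = tt.dot %a, %b, %c : ...
    "triton_gpu.group_sched"() : () -> ()   // placeholder, erased again during lowering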
3 changes: 3 additions & 0 deletions third_party/amd/backend/compiler.py
@@ -180,6 +180,8 @@ def make_llir(src, metadata, options):
    passes.convert.add_index_to_llvmir(pm)

    passes.ttgpuir.add_allocate_shared_memory(pm)

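    # Insert group_sched placeholders after each tt.dot; they are rewritten into
    # scheduling intrinsics by add_sched_group_barriers further down.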
    amd.passes.ttgpuir.insert_sched_group_barriers(pm)
    ## __HIP_FTZ is used to control the denorm flushing behavior of exp2 op as follows:
    ## 1. If __HIP_FTZ = 1, exp2 flushes denorms in input and output regardless
    ##    of the value of kernel arg `allow_flush_denorm`.
@@ -197,6 +199,7 @@ def make_llir(src, metadata, options):
    passes.common.add_canonicalizer(pm)
    passes.common.add_cse(pm)
    passes.common.add_symbol_dce(pm)
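    # Rewrite the group_sched placeholders into llvm.amdgcn.sched.barrier /
    # llvm.amdgcn.sched.group.barrier intrinsic calls.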
    amd.passes.ttgpuir.add_sched_group_barriers(pm)
    if os.environ.get("TRITON_DISABLE_LINE_INFO", "0") == "0":
        passes.llvmir.add_di_scope(pm)
    # This pass (`add_builtin_func_to_llvmir`) serves as a temporary workaround to address the issue of excessive basic block
3 changes: 3 additions & 0 deletions third_party/amd/include/TritonAMDGPUToLLVM/Passes.h
@@ -34,6 +34,9 @@ createOptimizeLDSUsagePass(StringRef arch, int32_t customLDSLimit = 0);
std::unique_ptr<OperationPass<ModuleOp>>
createConvertTritonAMDGPUToLLVMPass(StringRef targetArch, bool ftz);
std::unique_ptr<OperationPass<ModuleOp>> createConvertBuiltinFuncToLLVMPass();
std::unique_ptr<OperationPass<ModuleOp>>
createSchedGroupBarriersInsertionPass();
std::unique_ptr<OperationPass<ModuleOp>> createSchedGroupBarriersLoweringPass();

#define GEN_PASS_REGISTRATION
#include "TritonAMDGPUToLLVM/Passes.h.inc"
14 changes: 14 additions & 0 deletions third_party/amd/include/TritonAMDGPUToLLVM/Passes.td
@@ -55,4 +55,18 @@ def ConvertBuiltinFuncToLLVM : Pass<"convert-builtin-func-to-llvm", "mlir::Modul

}

def SchedGroupBarriersInsertion : Pass<"insert-sched-group-barriers", "mlir::ModuleOp"> {
  let summary = "Insert Scheduling Group Barriers";
  let constructor = "mlir::triton::createSchedGroupBarriersInsertionPass()";

  let dependentDialects = ["mlir::LLVM::LLVMDialect"];
}

def SchedGroupBarriersLowering : Pass<"lower-sched-group-barriers", "mlir::ModuleOp"> {
  let summary = "Lower Scheduling Group Barriers to LLVM intrinsics";
  let constructor = "mlir::triton::createSchedGroupBarriersLoweringPass()";

  let dependentDialects = ["mlir::LLVM::LLVMDialect"];
}

#endif
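For orientation, a rough sketch of the LLVM-dialect code the lowering pass is meant to produce around each placeholder, assuming the llvm.call_intrinsic form used in SchedInstructions.cpp below (the SSA names %c0, %c1, %mfma, %ds_read and %n are illustrative; the masks, group sizes and repetition count are computed per block):

    // at the top of the block: a full scheduling barrier (mask = 0)
    llvm.call_intrinsic "llvm.amdgcn.sched.barrier"(%c0) : (i32) -> ()
    // at the placeholder: one run of sched.group.barrier calls per issued global load
    llvm.call_intrinsic "llvm.amdgcn.sched.group.barrier"(%mfma, %c1, %c0) : (i32, i32, i32) -> ()
    llvm.call_intrinsic "llvm.amdgcn.sched.group.barrier"(%ds_read, %n, %c0) : (i32, i32, i32) -> ()
    ...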
1 change: 1 addition & 0 deletions third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt
@@ -18,6 +18,7 @@ add_triton_library(TritonAMDGPUToLLVM
    OptimizeLDSUsage.cpp
    OptimizeLDSUtility.cpp
    SPMDOpToLLVM.cpp
    SchedInstructions.cpp

    DEPENDS
    TritonAMDGPUConversionPassIncGen
186 changes: 186 additions & 0 deletions third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp
@@ -0,0 +1,186 @@
#include "TritonAMDGPUToLLVM/Passes.h"

#include "TritonAMDGPUTransforms/MfmaGroup.h"
#include "Utility.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "triton/Analysis/AxisInfo.h"

namespace mlir {
namespace triton {
#define GEN_PASS_DEF_SCHEDGROUPBARRIERSINSERTION
#define GEN_PASS_DEF_SCHEDGROUPBARRIERSLOWERING
#include "TritonAMDGPUToLLVM/Passes.h.inc"
} // namespace triton
} // namespace mlir

using namespace mlir;

namespace {
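// Mask bits of the llvm.amdgcn.sched.barrier / llvm.amdgcn.sched.group.barrier
// intrinsics; each bit selects an instruction class the barrier applies to.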
enum class InstructionMaskEnum : int64_t {
  NONE = 0x00000000,
  VALU = 0x00000002,
  SALU = 0x00000004,
  MFMA = 0x00000008,
  ALL_VMEM = 0x00000010,
  VMEM_READ = 0x00000020,
  VMEM_WRITE = 0x00000040,
  ALL_DS = 0x00000080,
  DS_READ = 0x00000100,
  DS_WRITE = 0x00000200
};

void buildSchedGroupBarrier(PatternRewriter &builder,
                            InstructionMaskEnum maskValue, int sizeValue,
                            int groupIdValue) {
  MLIRContext *ctx = builder.getContext();
  Location loc = builder.getUnknownLoc();
  auto intrinsicName = StringAttr::get(ctx, "llvm.amdgcn.sched.group.barrier");
  LLVM::FastmathFlagsAttr defaultFlags{};
  Type i32 = builder.getI32Type();
  auto mask = builder.create<LLVM::ConstantOp>(
      loc, builder.getIntegerAttr(i32, static_cast<int64_t>(maskValue)));
  auto size = builder.create<LLVM::ConstantOp>(
      loc, builder.getIntegerAttr(i32, sizeValue));
  auto groupId = builder.create<LLVM::ConstantOp>(
      loc, builder.getIntegerAttr(i32, groupIdValue));
  builder.create<LLVM::CallIntrinsicOp>(loc, TypeRange{}, intrinsicName,
                                        ValueRange{mask, size, groupId},
                                        defaultFlags);
}

Operation *generatedSchedBarrier(PatternRewriter &rewriter,
                                 InstructionMaskEnum maskValue) {
  MLIRContext *ctx = rewriter.getContext();
  Location loc = rewriter.getUnknownLoc();
  auto intrinsicName = StringAttr::get(ctx, "llvm.amdgcn.sched.barrier");
  LLVM::FastmathFlagsAttr defaultFlags{};
  Type i32 = rewriter.getI32Type();
  auto mask = rewriter.create<LLVM::ConstantOp>(
      loc, rewriter.getIntegerAttr(i32, static_cast<int64_t>(maskValue)));
  return rewriter.create<LLVM::CallIntrinsicOp>(loc, TypeRange{}, intrinsicName,
                                                ValueRange{mask}, defaultFlags);
}

struct SchedGroupBarriersRewriter
    : public OpRewritePattern<triton::gpu::GroupSched> {
  using OpRewritePattern<triton::gpu::GroupSched>::OpRewritePattern;
  LogicalResult matchAndRewrite(triton::gpu::GroupSched schedBarrier,
                                PatternRewriter &rewriter) const override {
    Block *block = schedBarrier->getBlock();

    size_t numGlbLoads = 0;
    block->walk([&numGlbLoads](LLVM::CallOp callOp) {
      StringRef calleeName = callOp.getCallee().value();
      if (calleeName.contains("__predicated_load_vector"))
        ++numGlbLoads;
    });

    size_t numDsReads = 0;
    block->walk([&numDsReads](LLVM::LoadOp op) {
      auto operandType = op.getOperand().getType();
      if (auto ptr = llvm::dyn_cast<LLVM::LLVMPointerType>(operandType))
        if (ptr.getAddressSpace() == 3)
          ++numDsReads;
    });

    size_t numDsWrites = 0;
    block->walk([&numDsWrites](LLVM::StoreOp op) {
      auto operandType = op.getOperand(1).getType();
      if (auto ptr = llvm::dyn_cast<LLVM::LLVMPointerType>(operandType))
        if (ptr.getAddressSpace() == 3)
          ++numDsWrites;
    });

    size_t numMfmas = 0;
    block->walk([&numMfmas](Operation *op) {
      StringRef opName = op->getName().getStringRef();
      if (opName.contains("mfma"))
        ++numMfmas;
    });

    llvm::dbgs() << "group scheduling info: ["
                 << "numGlbLoads: " << numGlbLoads << ", "
                 << "numDsReads: " << numDsReads << ", "
                 << "numDsWrites: " << numDsWrites << ", "
                 << "numMfmas: " << numMfmas << "]\n";

    rewriter.setInsertionPointToStart(block);
    auto op = generatedSchedBarrier(rewriter, InstructionMaskEnum::NONE);
    rewriter.setInsertionPointAfter(schedBarrier);
    const size_t numIssues = numGlbLoads;
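    // Per-issue interleaving: for every counted global load, pin one MFMA, a
    // slice of the ds_reads, one MFMA, a slice of the ds_writes, one MFMA and
    // the remaining MFMAs of the slice into scheduling group 0. This sketch
    // assumes all four counters are non-zero and large enough for the
    // divisions below.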
    for (size_t i = 0; i < numIssues; ++i) {
      buildSchedGroupBarrier(rewriter, InstructionMaskEnum::MFMA, 1, 0);
      buildSchedGroupBarrier(rewriter, InstructionMaskEnum::DS_READ,
                             numDsReads / numIssues, 0);
      buildSchedGroupBarrier(rewriter, InstructionMaskEnum::MFMA, 1, 0);
      buildSchedGroupBarrier(rewriter, InstructionMaskEnum::DS_WRITE,
                             numDsWrites / numIssues, 0);
      buildSchedGroupBarrier(rewriter, InstructionMaskEnum::MFMA, 1, 0);
      buildSchedGroupBarrier(rewriter, InstructionMaskEnum::MFMA,
                             (numMfmas / numIssues) - 3, 0);
    }
    op = generatedSchedBarrier(rewriter, InstructionMaskEnum::NONE);
    rewriter.eraseOp(schedBarrier);
    return mlir::success();
  }
};

struct SchedGroupBarriersLowering
    : public triton::impl::SchedGroupBarriersLoweringBase<
          SchedGroupBarriersLowering> {

  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    ModuleOp mod = getOperation();

    ConversionTarget target(*ctx);
    target.addLegalDialect<LLVM::LLVMDialect>();
    target.addIllegalOp<triton::gpu::GroupSched>();

    RewritePatternSet patterns(ctx);
    patterns.add<SchedGroupBarriersRewriter>(ctx);

    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns)))) {
      signalPassFailure();
    }
  }
};

struct SchedGroupBarriersInsertion
    : public triton::impl::SchedGroupBarriersInsertionBase<
          SchedGroupBarriersInsertion> {

  void insertPlaceholder(mlir::OpBuilder &builder, triton::DotOp dot) {
    mlir::OpBuilder::InsertionGuard guard(builder);
    builder.setInsertionPointAfter(dot);
    Location loc = builder.getUnknownLoc();
    builder.create<triton::gpu::GroupSched>(loc);
  }

  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    ModuleOp mod = getOperation();

    mlir::OpBuilder builder(ctx);
    mod.walk(
        [this, &builder](triton::DotOp op) { insertPlaceholder(builder, op); });
  }
};
} // namespace

namespace mlir {
namespace triton {
std::unique_ptr<OperationPass<ModuleOp>>
createSchedGroupBarriersLoweringPass() {
  return std::make_unique<SchedGroupBarriersLowering>();
}

std::unique_ptr<OperationPass<ModuleOp>>
createSchedGroupBarriersInsertionPass() {
  return std::make_unique<SchedGroupBarriersInsertion>();
}
} // namespace triton
} // namespace mlir
1 change: 1 addition & 0 deletions third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp
@@ -57,6 +57,7 @@ class TritonLLVMConversionTarget : public ConversionTarget {
    addIllegalDialect<triton::nvidia_gpu::TritonNvidiaGPUDialect>();
    addIllegalDialect<mlir::gpu::GPUDialect>();
    addLegalOp<mlir::UnrealizedConversionCastOp>();
    addLegalOp<triton::gpu::GroupSched>();
  }
};

6 changes: 6 additions & 0 deletions third_party/amd/python/triton_amd.cc
@@ -36,13 +36,19 @@ const char *const amdTargetTriple = "amdgcn-amd-amdhsa";

void init_triton_amd_passes_ttgpuir(py::module &&m) {
  using namespace mlir::triton;
  m.def("insert_sched_group_barriers", [](mlir::PassManager &pm) {
    pm.addPass(createSchedGroupBarriersInsertionPass());
  });
  m.def("add_to_llvmir",
        [](mlir::PassManager &pm, const std::string &arch, bool ftz) {
          pm.addPass(createConvertTritonAMDGPUToLLVMPass(arch, ftz));
        });
  m.def("add_builtin_func_to_llvmir", [](mlir::PassManager &pm) {
    pm.addPass(createConvertBuiltinFuncToLLVMPass());
  });
  m.def("add_sched_group_barriers", [](mlir::PassManager &pm) {
    pm.addPass(createSchedGroupBarriersLoweringPass());
  });
  m.def("add_decompose_unsupported_conversions", [](mlir::PassManager &pm,
                                                    const std::string &arch) {
    pm.addPass(
