Skip to content

Commit

Permalink
[AMD] Fixed bug in setNumGeneratedGlobalLoads
Browse files Browse the repository at this point in the history
* Add a test for the presence of the OpIdx attribute
  • Loading branch information
ravil-mobile committed Oct 22, 2024
1 parent 9e8c9a9 commit 00ab1fe
Show file tree
Hide file tree
Showing 7 changed files with 80 additions and 45 deletions.
2 changes: 1 addition & 1 deletion third_party/amd/backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ def make_llir(src, metadata, options):
passes.common.add_canonicalizer(pm)
passes.common.add_cse(pm)
passes.common.add_symbol_dce(pm)
amd.passes.ttgpuir.lower_instruction_sched_hints(pm, options.instruction_sched_variant)
amd.passes.ttgpuir.lower_instruction_sched_hints(pm, options.num_stages, options.instruction_sched_variant)
if os.environ.get("TRITON_DISABLE_LINE_INFO", "0") == "0":
passes.llvmir.add_di_scope(pm)
# This pass (`add_builtin_func_to_llvmir`) serves as a temporary workaround to address the issue of excessive basic block
Expand Down
3 changes: 2 additions & 1 deletion third_party/amd/include/TritonAMDGPUToLLVM/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ std::unique_ptr<OperationPass<ModuleOp>> createConvertBuiltinFuncToLLVMPass();
std::unique_ptr<OperationPass<ModuleOp>>
createTritonAMDGPUInsertInstructionSchedHintsPass();
std::unique_ptr<OperationPass<ModuleOp>>
createTritonAMDGPULowerInstructionSchedHintsPass(std::string variant);
createTritonAMDGPULowerInstructionSchedHintsPass(int32_t numStages,
std::string variant);

#define GEN_PASS_REGISTRATION
#include "TritonAMDGPUToLLVM/Passes.h.inc"
Expand Down
4 changes: 3 additions & 1 deletion third_party/amd/include/TritonAMDGPUToLLVM/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,14 @@ def TritonAMDGPUInsertInstructionSchedHints : Pass<"triton-amdgpu-insert-instruc

def TritonAMDGPULowerInstructionSchedHints : Pass<"triton-amdgpu-lower-insert-instruction-sched-hints", "mlir::ModuleOp"> {
let summary = "Lower instruction scheduling hints to LLVM intrinsics";
let constructor = "mlir::triton::createTritonAMDGPULowerInstructionSchedHintsPass(\"\")";
let constructor = "mlir::triton::createTritonAMDGPULowerInstructionSchedHintsPass(/*numStages=*/2, /*variant=*/\"\")";

let dependentDialects = ["mlir::LLVM::LLVMDialect",
"mlir::triton::amdgpu::TritonAMDGPUDialect"];

let options = [
Option<"numStages", "num_stages", "int32_t", /*default*/"2",
"number of pipeline stages">,
Option<"variant", "variant", "std::string", /*default*/"\"default\"",
"instruction scheduling variant">,
];
Expand Down
95 changes: 63 additions & 32 deletions third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,14 @@ void setNumGeneratedGlobalLoads(triton::LoadOp op, size_t globalLoadsCount,
triton::amdgpu::InstCounterAttr::get(ctx, globalLoadsCount, type);

op->getBlock()->walk([&](triton::amdgpu::InstructionSchedHint schedHint) {
auto opIdxAttr = cast<triton::amdgpu::OpIdxAttr>(
op->getAttr(triton::amdgpu::OpIdxAttr::getMnemonic()));
assert(opIdxAttr.getValue() < 2);
if (opIdxAttr.getValue() == 0)
schedHint.setNumGlobalLoadsAAttr(counterAttr);
else
schedHint.setNumGlobalLoadsBAttr(counterAttr);
if (auto opIdxAttr = op->getAttrOfType<triton::amdgpu::OpIdxAttr>(
triton::amdgpu::OpIdxAttr::getMnemonic())) {
assert(opIdxAttr.getValue() < 2);
if (opIdxAttr.getValue() == 0)
schedHint.setNumGlobalLoadsAAttr(counterAttr);
else
schedHint.setNumGlobalLoadsBAttr(counterAttr);
}
});
}

Expand Down Expand Up @@ -72,15 +73,28 @@ void storeOpConversionCallback(triton::gpu::LocalStoreOp op,
triton::amdgpu::InstCounterAttr::get(ctx, localStoreOpCount, type);

op->getBlock()->walk([&](triton::amdgpu::InstructionSchedHint schedHint) {
auto opIdxAttr = op->getAttrOfType<triton::amdgpu::OpIdxAttr>(
triton::amdgpu::OpIdxAttr::getMnemonic());
assert(opIdxAttr.getValue() < 2);
if (opIdxAttr.getValue() == 0)
schedHint.setNumDsWritesAAttr(counterAttr);
else
schedHint.setNumDsWritesBAttr(counterAttr);
if (auto opIdxAttr = op->getAttrOfType<triton::amdgpu::OpIdxAttr>(
triton::amdgpu::OpIdxAttr::getMnemonic())) {
assert(opIdxAttr.getValue() < 2);
if (opIdxAttr.getValue() == 0)
schedHint.setNumDsWritesAAttr(counterAttr);
else
schedHint.setNumDsWritesBAttr(counterAttr);
}
});
}

// Returns the unique `tt.dot` op contained in `forOp`'s body, or failure if
// the loop contains zero or more than one dot op. Used to decide whether
// instruction scheduling hints can be attached (single-dot loops only).
llvm::FailureOr<triton::DotOp> hasSingleDotOp(scf::ForOp forOp) {
  triton::DotOp dotOp = nullptr;
  size_t dotCounter = 0;
  // Stop walking as soon as a second dot op is seen: at that point the
  // result is already known to be failure, so visiting more ops is wasted.
  forOp->walk([&dotOp, &dotCounter](triton::DotOp op) {
    dotOp = op;
    return (++dotCounter > 1) ? WalkResult::interrupt()
                              : WalkResult::advance();
  });

  if (dotCounter == 1)
    return dotOp;

  return llvm::failure();
}
} // namespace mlir::triton

namespace {
Expand Down Expand Up @@ -119,8 +133,9 @@ Operation *createIglpOpt(PatternRewriter &rewriter, Location loc, int value) {
struct InstructionSchedHintsRewriter
: public OpRewritePattern<triton::amdgpu::InstructionSchedHint> {

InstructionSchedHintsRewriter(mlir::MLIRContext *ctx, std::string variant)
: OpRewritePattern(ctx) {
InstructionSchedHintsRewriter(mlir::MLIRContext *ctx, int32_t numStages,
std::string variant)
: OpRewritePattern(ctx), numStages(numStages) {
std::transform(variant.begin(), variant.end(), variant.begin(),
[](unsigned char c) { return std::tolower(c); });

Expand All @@ -130,6 +145,13 @@ struct InstructionSchedHintsRewriter
.Case("iglp1", SchedulingType::IGLP1)
.Case("ck_v3", SchedulingType::CK_V3)
.Default(SchedulingType::UNKNOWN);

if (this->numStages < 2) {
this->schedulingType = SchedulingType::NONE;
llvm::dbgs() << "[" << getDebugName() << "]: "
<< "ignoring instruction scheduling due to a very low num. "
"stages value. Must be >= 2\n";
}
}

enum class SchedulingType : uint32_t {
Expand Down Expand Up @@ -160,6 +182,11 @@ struct InstructionSchedHintsRewriter
const uint32_t numBufferLoadInstB =
schedHint.getNumGlobalLoadsB().getValue();

assert(numBufferLoadInstA &&
"buffer load count for tile A must be initialized");
assert(numBufferLoadInstB &&
"buffer load count for tile B must be initialized");

const uint32_t numMfmaInst = schedHint.getNumMMAs().getValue();

auto mfmaType = cast<RankedTensorType>(schedHint.getNumMMAs().getType());
Expand All @@ -184,7 +211,7 @@ struct InstructionSchedHintsRewriter

// stage 1
const auto numMfmaStage1 = numMfmaInst - (numDsreadAMfma + numDsreadBMfma);
const auto num_mfma_per_issue =
const auto numMfmaPerIssue =
numMfmaStage1 / (numBufferLoadInstA + numBufferLoadInstB);

const auto numDswritePerIssueA = numDsWriteInstA / numBufferLoadInstA;
Expand All @@ -203,7 +230,7 @@ struct InstructionSchedHintsRewriter
rewriter, loc, mlir::amdgpu::sched_barrier_opt_enum::vmem_read, 1, 0);
createSchedGroupBarrier(rewriter, loc,
mlir::amdgpu::sched_barrier_opt_enum::mfma_wmma,
num_mfma_per_issue - numDswritePerIssueA, 0);
numMfmaPerIssue - numDswritePerIssueA, 0);
}

for (size_t i = 0; i < numBufferLoadInstB; ++i) {
Expand All @@ -219,7 +246,7 @@ struct InstructionSchedHintsRewriter
rewriter, loc, mlir::amdgpu::sched_barrier_opt_enum::vmem_read, 1, 0);
createSchedGroupBarrier(rewriter, loc,
mlir::amdgpu::sched_barrier_opt_enum::mfma_wmma,
num_mfma_per_issue - numDswritePerIssueB, 0);
numMfmaPerIssue - numDswritePerIssueB, 0);
}

// stage 2
Expand Down Expand Up @@ -308,14 +335,17 @@ struct InstructionSchedHintsRewriter
}

private:
int32_t numStages;
SchedulingType schedulingType;
};

struct TritonAMDGPULowerInstructionSchedHints
: public triton::impl::TritonAMDGPULowerInstructionSchedHintsBase<
TritonAMDGPULowerInstructionSchedHints> {

explicit TritonAMDGPULowerInstructionSchedHints(std::string variant) {
explicit TritonAMDGPULowerInstructionSchedHints(int32_t numStages,
std::string variant) {
this->numStages = numStages;
this->variant = variant;
}

Expand All @@ -331,7 +361,8 @@ struct TritonAMDGPULowerInstructionSchedHints
target.addLegalOp<ROCDL::SchedGroupBarrier>();

RewritePatternSet patterns(ctx);
patterns.add<InstructionSchedHintsRewriter>(ctx, this->variant);
patterns.add<InstructionSchedHintsRewriter>(ctx, this->numStages,
this->variant);

if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns)))) {
Expand All @@ -343,24 +374,22 @@ struct TritonAMDGPULowerInstructionSchedHints
struct TritonAMDGPUInsertInstructionSchedHints
: public triton::impl::TritonAMDGPUInsertInstructionSchedHintsBase<
TritonAMDGPUInsertInstructionSchedHints> {

void runOnOperation() override {
MLIRContext *ctx = &getContext();
ModuleOp mod = getOperation();

mod.walk([this, ctx](scf::ForOp forOp) {
triton::DotOp dot = nullptr;
size_t dotCounter = 0;
forOp->walk([&dot, &dotCounter](triton::DotOp op) {
dot = op;
++dotCounter;
});
auto maybeSingleDotOp = hasSingleDotOp(forOp);

// Note, instruction schedule barriers are inserted only in the case of
// a single `tt.dot` op in a `scf::ForOp` scope in the current
// implementation.
if (dotCounter == 1) {
if (llvm::succeeded(maybeSingleDotOp)) {
triton::DotOp dotOp = maybeSingleDotOp.value();
mlir::OpBuilder rewriter(ctx);
rewriter.setInsertionPointAfter(dot);
rewriter.create<triton::amdgpu::InstructionSchedHint>(dot->getLoc());
rewriter.setInsertionPointAfter(dotOp);
rewriter.create<triton::amdgpu::InstructionSchedHint>(dotOp->getLoc());
}
});
}
Expand All @@ -369,8 +398,10 @@ struct TritonAMDGPUInsertInstructionSchedHints

namespace mlir::triton {
std::unique_ptr<OperationPass<ModuleOp>>
createTritonAMDGPULowerInstructionSchedHintsPass(std::string variant) {
return std::make_unique<TritonAMDGPULowerInstructionSchedHints>(variant);
createTritonAMDGPULowerInstructionSchedHintsPass(int32_t numStages,
std::string variant) {
return std::make_unique<TritonAMDGPULowerInstructionSchedHints>(numStages,
variant);
}

std::unique_ptr<OperationPass<ModuleOp>>
Expand Down
1 change: 1 addition & 0 deletions third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ void setNumGeneratedDsReads(gpu::LocalLoadOp op, size_t numDsReadsCount,
Type type);
void storeOpConversionCallback(triton::gpu::LocalStoreOp op, size_t llvmOpCount,
Type type);
llvm::FailureOr<triton::DotOp> hasSingleDotOp(scf::ForOp forOp);
} // namespace mlir::triton

#endif
15 changes: 7 additions & 8 deletions third_party/amd/lib/TritonAMDGPUTransforms/StreamPipelineV2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LLVM.h"
#include "third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h"
#include "third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.h"
#include "triton/Analysis/AxisInfo.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
Expand Down Expand Up @@ -368,17 +369,15 @@ FailureOr<Operation *> rewindUnaryOps(Value value) {
return failure();
}

// Annotate each `tt.LoadOp` instruction with its corresponding gemm operand
// index. Note, this is a part of the instruction scheduling routine. Currently,
// we support `forOp`s which contain only a single `tt.DotOp` in the bodies.
void StreamPipeliner::labelLoadOpsForTritonDot() {
mlir::MLIRContext *ctx = forOp->getContext();
auto maybeSingleDotOp = triton::hasSingleDotOp(forOp);

triton::DotOp dotOp;
size_t dotCounter = 0;
forOp->walk([&dotCounter, &dotOp](triton::DotOp op) {
dotOp = op;
++dotCounter;
});

if (dotCounter == 1) {
if (llvm::succeeded(maybeSingleDotOp)) {
triton::DotOp dotOp = maybeSingleDotOp.value();
for (auto [opIdx, dotOperand] : llvm::enumerate(dotOp->getOperands())) {
auto maybeLoadOp = rewindUnaryOps<triton::LoadOp>(dotOperand);
if (llvm::succeeded(maybeLoadOp)) {
Expand Down
5 changes: 3 additions & 2 deletions third_party/amd/python/triton_amd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,9 @@ void init_triton_amd_passes_ttgpuir(py::module &&m) {
pm.addPass(createTritonAMDGPUInsertInstructionSchedHintsPass());
});
m.def("lower_instruction_sched_hints",
[](mlir::PassManager &pm, std::string variant) {
pm.addPass(createTritonAMDGPULowerInstructionSchedHintsPass(variant));
[](mlir::PassManager &pm, int32_t numStages, std::string variant) {
pm.addPass(createTritonAMDGPULowerInstructionSchedHintsPass(numStages,
variant));
});
m.def("add_decompose_unsupported_conversions", [](mlir::PassManager &pm,
const std::string &arch) {
Expand Down

0 comments on commit 00ab1fe

Please sign in to comment.