[AMD] Added instruction scheduling for CK's V3 pipelining
ravil-mobile committed Oct 4, 2024
1 parent 06210a4 commit cf97e35
Showing 13 changed files with 158 additions and 41 deletions.
@@ -28,7 +28,15 @@ constexpr int patternBenefitClampOptimizedPattern = 20;
constexpr int patternBenefitConvertLayoutOptimizedPattern = 20;

struct BackendCallbacks {
std::function<void(triton::gpu::LocalStoreOp, size_t, Type)>
/**
* A backend-specific callback for appending auxiliary data during
* `LocalStoreOp` conversion.
*
* @param[in] op The rewritten `LocalStoreOp`.
* @param[in] llvmOpCount The number of issued LLVM instructions.
* @param[in] llvmOpType The input type of the issued LLVM instructions.
*/
std::function<void(triton::gpu::LocalStoreOp op, size_t llvmOpCount, Type llvmOpType)>
localStoreOpConversion = nullptr;
};
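// For illustration, an AMD backend might populate the callback roughly as
// follows (a minimal sketch based only on the struct above; the lambda body is
// an assumption):
//
//   BackendCallbacks callbacks;
//   callbacks.localStoreOpConversion = [](triton::gpu::LocalStoreOp op,
//                                         size_t llvmOpCount, Type llvmOpType) {
//     // e.g., record the count and type to drive instruction scheduling hints.
//   };
//
// The populated struct is then handed to populateMemoryOpToLLVMPattern below.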

@@ -37,6 +45,10 @@ void populateElementwiseOpToLLVMPatterns(
ModuleAxisInfoAnalysis &axisInfoAnalysis, const TargetInfoBase &targetInfo,
PatternBenefit benefit);

// The given callback is invoked at the end of a successful rewrite. The
// callback receives 1) the current source op, 2) the number of issued LLVM
// instructions, and 3) their input type. Each MLIR backend can provide its own
// callback and thus handle backend-specific behavior.
void populateMemoryOpToLLVMPattern(
LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo,
RewritePatternSet &patterns, PatternBenefit benefit,
6 changes: 4 additions & 2 deletions lib/Conversion/TritonGPUToLLVM/Utility.cpp
@@ -420,8 +420,10 @@ void storeDistributedToShared(MemDescType dstTy, RankedTensorType srcTy,
store(vec, vecAddr)
.setAlignment(vecTy.getNumElements() *
elemLlvmTy.getIntOrFloatBitWidth() / 8);
++((*llvmOpCount).first);
(*llvmOpCount).second = vecTy;
if (llvmOpCount) {
++(llvmOpCount->first);
llvmOpCount->second = vecTy;
}
});

if (!success)
@@ -35,10 +35,8 @@ class TritonAMDGPU_Attr<string name, list<Trait> traits = [],
def TritonAMDGPU_OpIdxAttr : TritonAMDGPU_Attr<"OpIdx"> {
let cppNamespace = "::mlir::triton::amdgpu";
let mnemonic = "OpIdx";

let summary = "An operand index attribute.";
let description = [{
An operand index attribute.

The attribute is a way to describe which input argument of the target
operation (e.g., `tt.dot`) the result of a given operation belongs to.
}];
@@ -47,15 +45,13 @@ def TritonAMDGPU_OpIdxAttr : TritonAMDGPU_Attr<"OpIdx"> {
let assemblyFormat = "`<` $value `>`";
}
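// Given the assembly format above, an instance of this attribute prints
// roughly as `#amdgpu.OpIdx<0>`, presumably marking a value that feeds
// operand A (index 0) of the target `tt.dot`.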

def TritonAMDGPU_LLVMInstructionCounterAttr : TritonAMDGPU_Attr<"LLVMInstructionCounter"> {
def TritonAMDGPU_InstCounter : TritonAMDGPU_Attr<"InstCounter"> {
let cppNamespace = "::mlir::triton::amdgpu";
let mnemonic = "LLVMInstructionCounter";

let mnemonic = "InstCounter";
let summary = "An instruction counter attribute.";
let description = [{
An instruction counter attribute.

The attribute holds the number of issued LLVM instructions of a specific kind as well as
the data type.
The attribute holds the number of issued LLVM instructions of a specific kind as well as
the data type.
}];

let parameters = (ins "uint32_t":$value, "Type":$type);
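// For illustration, the C++ builder implied by these parameters is used later
// in this commit roughly as
//   auto counterAttr = amdgpu::InstCounterAttr::get(ctx, /*value=*/8, f16Ty);
// i.e., recording that eight LLVM instructions with f16 input type were issued
// (the concrete count and type here are made up).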
@@ -50,20 +50,20 @@ def InstructionSchedHint : TT_AMDGPU_Op<"instruction_sched_hint", []> {
}];

let arguments = (ins
TritonAMDGPU_LLVMInstructionCounterAttr:$numDsReadsA,
TritonAMDGPU_LLVMInstructionCounterAttr:$numDsReadsB,
TritonAMDGPU_LLVMInstructionCounterAttr:$numDsWritesA,
TritonAMDGPU_LLVMInstructionCounterAttr:$numDsWritesB,
TritonAMDGPU_LLVMInstructionCounterAttr:$numGlobalLoadsA,
TritonAMDGPU_LLVMInstructionCounterAttr:$numGlobalLoadsB,
TritonAMDGPU_LLVMInstructionCounterAttr:$numMMAs
TritonAMDGPU_InstCounter:$numDsReadsA,
TritonAMDGPU_InstCounter:$numDsReadsB,
TritonAMDGPU_InstCounter:$numDsWritesA,
TritonAMDGPU_InstCounter:$numDsWritesB,
TritonAMDGPU_InstCounter:$numGlobalLoadsA,
TritonAMDGPU_InstCounter:$numGlobalLoadsB,
TritonAMDGPU_InstCounter:$numMMAs
);

let builders = [
OpBuilder<(ins), [{
auto ctx = $_state.getContext();
auto type = IntegerType::get(ctx, 32);
auto emptyAttr = amdgpu::LLVMInstructionCounterAttr::get(ctx, 0, type);
auto emptyAttr = amdgpu::InstCounterAttr::get(ctx, 0, type);
build($_builder, $_state, emptyAttr, emptyAttr, emptyAttr, emptyAttr,
emptyAttr, emptyAttr, emptyAttr);
}]>
6 changes: 4 additions & 2 deletions third_party/amd/include/TritonAMDGPUToLLVM/Passes.td
@@ -59,14 +59,16 @@ def InsertInstructionSchedHints : Pass<"insert-instruction-sched-hints", "mlir::
let summary = "Insert instruction scheduling hints after the dot ops in the main loop";
let constructor = "mlir::triton::createInsertInstructionSchedHintsPass()";

let dependentDialects = ["mlir::LLVM::LLVMDialect"];
let dependentDialects = ["mlir::LLVM::LLVMDialect",
"mlir::triton::amdgpu::TritonAMDGPUDialect"];
}

def LowerInstructionSchedHints : Pass<"lower-insert-instruction-sched-hints", "mlir::ModuleOp"> {
let summary = "Lower instruction scheduling hints to LLVM intrinsics";
let constructor = "mlir::triton::createLowerInstructionSchedHintsPass(\"\")";

let dependentDialects = ["mlir::LLVM::LLVMDialect"];
let dependentDialects = ["mlir::LLVM::LLVMDialect",
"mlir::triton::amdgpu::TritonAMDGPUDialect"];

let options = [
Option<"variant", "variant", "std::string", /*default*/"\"default\"",
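// For illustration, the CK V3 schedule added by this commit is selected
// through this option, e.g.
// `-lower-insert-instruction-sched-hints='variant=ck_v3'`
// (the exact pass-pipeline syntax is an assumption).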
@@ -24,7 +24,6 @@
#include "../TritonAMDGPUToLLVM/SchedInstructions.h"
#include "SharedToDotOperandHelper.h"
#include "Utility.h"
#include "third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h"

using ::mlir::triton::gpu::AMDMfmaEncodingAttr;
using ::mlir::triton::gpu::DotOperandEncodingAttr;
@@ -24,7 +24,6 @@
#include "../TritonAMDGPUToLLVM/SchedInstructions.h"
#include "SharedToDotOperandHelper.h"
#include "Utility.h"
#include "third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h"

using ::mlir::triton::gpu::AMDWmmaEncodingAttr;
using ::mlir::triton::gpu::DotOperandEncodingAttr;
@@ -25,7 +25,6 @@
#include "TritonAMDGPUTransforms/MfmaGroup.h"
#include "Utility.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h"

using namespace mlir;
using namespace mlir::triton;
@@ -25,7 +25,6 @@
#include "../TritonAMDGPUToLLVM/SchedInstructions.h"
#include "Utility.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h"

namespace mlir::triton::AMD {
namespace {
@@ -6,7 +6,6 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h"
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"

128 changes: 117 additions & 11 deletions third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp
@@ -14,13 +14,11 @@ namespace mlir::triton {
using namespace mlir;

namespace mlir::triton {

void setNumGeneratedMMAs(DotOp op, size_t mmaCount, unsigned m, unsigned n,
unsigned k, Type elementType) {
auto *ctx = op->getContext();
auto mmaType = RankedTensorType::get({m, n, k}, elementType);
auto counterAttr =
amdgpu::LLVMInstructionCounterAttr::get(ctx, mmaCount, mmaType);
auto counterAttr = amdgpu::InstCounterAttr::get(ctx, mmaCount, mmaType);

op->getBlock()->walk([&](amdgpu::InstructionSchedHint schedHint) {
schedHint.setNumMMAsAttr(counterAttr);
@@ -30,8 +28,7 @@ void setNumGeneratedMMAs(DotOp op, size_t mmaCount, unsigned m, unsigned n,
void setNumGeneratedGlobalLoads(triton::LoadOp op, size_t globalLoadsCount,
Type type) {
MLIRContext *ctx = op->getContext();
auto counterAttr =
amdgpu::LLVMInstructionCounterAttr::get(ctx, globalLoadsCount, type);
auto counterAttr = amdgpu::InstCounterAttr::get(ctx, globalLoadsCount, type);

op->getBlock()->walk([&](amdgpu::InstructionSchedHint schedHint) {
auto opIdxAttr =
@@ -47,8 +44,7 @@ void setNumGeneratedGlobalLoads(triton::LoadOp op, size_t globalLoadsCount,
void setNumGeneratedDsReads(gpu::LocalLoadOp op, size_t dsReadsCount,
Type type) {
auto *ctx = op->getContext();
auto counterAttr =
amdgpu::LLVMInstructionCounterAttr::get(ctx, dsReadsCount, type);
auto counterAttr = amdgpu::InstCounterAttr::get(ctx, dsReadsCount, type);

op->getBlock()->walk([&](amdgpu::InstructionSchedHint schedHint) {
Value dst = op.getResult();
@@ -67,8 +63,7 @@ void setNumGeneratedDsReads(gpu::LocalLoadOp op, size_t dsReadsCount,
void storeOpConversionCallback(triton::gpu::LocalStoreOp op,
size_t localStoreOpCount, Type type) {
MLIRContext *ctx = op->getContext();
auto counterAttr =
amdgpu::LLVMInstructionCounterAttr::get(ctx, localStoreOpCount, type);
auto counterAttr = amdgpu::InstCounterAttr::get(ctx, localStoreOpCount, type);

op->getBlock()->walk([&](amdgpu::InstructionSchedHint schedHint) {
auto opIdxAttr =
@@ -159,10 +154,116 @@ struct InstructionSchedHintsRewriter
.Case("default", SchedulingType::NONE)
.Case("iglp0", SchedulingType::IGLP0)
.Case("iglp1", SchedulingType::IGLP1)
.Case("ck_v3", SchedulingType::CK_V3)
.Default(SchedulingType::UNKNOWN);
}

enum class SchedulingType : uint32_t { NONE = 0, IGLP0, IGLP1, UNKNOWN };
enum class SchedulingType : uint32_t {
NONE = 0,
IGLP0,
IGLP1,
CK_V3,
UNKNOWN
};

// This is the implementation of CK's V3 pipelining (see
// ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp).
// This scheduling requires one register buffer and one LDS buffer, combined
// with local (LDS to registers) and global (HBM to registers) data
// prefetching.
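// As a rough illustration of the rate computation below (with made-up counts):
// for mfmaCycle = 32 and dsReadAIssueCycle = 8, dsReadAMfmaRate =
// (32 - 4 + 2 * 8 - 1) / (2 * 8) = 2, so numDsReadInstA = 8 ds_reads for
// operand A are covered by numDsreadAMfma = (8 + 2 - 1) / 2 = 4 MFMA slots in
// stage 2.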
void createCKV3Schedule(PatternRewriter &rewriter, Location loc,
amdgpu::InstructionSchedHint schedHint) {
const uint32_t numDsReadInstA = schedHint.getNumDsReadsA().getValue();
const uint32_t numDsReadInstB = schedHint.getNumDsReadsB().getValue();

const uint32_t numDsWriteInstA = schedHint.getNumDsWritesA().getValue();
const uint32_t numDsWriteInstB = schedHint.getNumDsWritesB().getValue();

const uint32_t numBufferLoadInstA =
schedHint.getNumGlobalLoadsA().getValue();
const uint32_t numBufferLoadInstB =
schedHint.getNumGlobalLoadsB().getValue();

const uint32_t numMfmaInst = schedHint.getNumMMAs().getValue();

auto mfmaType = cast<RankedTensorType>(schedHint.getNumMMAs().getType());
const uint32_t nPerXDL = mfmaType.getShape()[1];
const uint32_t mfmaCycle = nPerXDL == 16 ? 16 : 32;

auto dsReadsAType = cast<VectorType>(schedHint.getNumDsReadsA().getType());
auto dsReadsBType = cast<VectorType>(schedHint.getNumDsReadsB().getType());

const uint32_t dsReadAIssueCycle = dsReadsAType.getShape()[0] == 16 ? 8 : 4;
const uint32_t dsReadBIssueCycle = dsReadsBType.getShape()[0] == 16 ? 8 : 4;

const auto dsReadAMfmaRate =
(mfmaCycle - 4 + 2 * dsReadAIssueCycle - 1) / (2 * dsReadAIssueCycle);
const auto dsReadBMfmaRate =
(mfmaCycle - 4 + 2 * dsReadBIssueCycle - 1) / (2 * dsReadBIssueCycle);

const auto numDsreadAMfma =
(numDsReadInstA + dsReadAMfmaRate - 1) / dsReadAMfmaRate;
const auto numDsreadBMfma =
(numDsReadInstB + dsReadBMfmaRate - 1) / dsReadBMfmaRate;

// Stage 1: interleave LDS writes and global (buffer) loads with MFMAs.
const auto numMfmaStage1 = numMfmaInst - (numDsreadAMfma + numDsreadBMfma);
const auto num_mfma_per_issue =
numMfmaStage1 / (numBufferLoadInstA + numBufferLoadInstB);

const auto numDswritePerIssueA = numDsWriteInstA / numBufferLoadInstA;
const auto numDswritePerIssueB = numDsWriteInstB / numBufferLoadInstB;

for (size_t i = 0; i < numBufferLoadInstA; ++i) {
for (size_t idswrite = 0; idswrite < numDswritePerIssueA; ++idswrite) {
createSchedGroupBarrier(rewriter, loc, InstructionKindMask::DS_WRITE, 1,
0);
createSchedGroupBarrier(rewriter, loc, InstructionKindMask::MFMA, 1, 0);
}
createSchedGroupBarrier(rewriter, loc, InstructionKindMask::VMEM_READ, 1,
0);
createSchedGroupBarrier(rewriter, loc, InstructionKindMask::MFMA,
num_mfma_per_issue - numDswritePerIssueA, 0);
}

for (size_t i = 0; i < numBufferLoadInstB; ++i) {
for (size_t idswrite = 0; idswrite < numDswritePerIssueB; ++idswrite) {
createSchedGroupBarrier(rewriter, loc, InstructionKindMask::DS_WRITE, 1,
0);
createSchedGroupBarrier(rewriter, loc, InstructionKindMask::MFMA, 1, 0);
}
createSchedGroupBarrier(rewriter, loc, InstructionKindMask::VMEM_READ, 1,
0);
createSchedGroupBarrier(rewriter, loc, InstructionKindMask::MFMA,
num_mfma_per_issue - numDswritePerIssueB, 0);
}

// Stage 2: interleave LDS reads with the remaining MFMAs.
for (size_t i = 0; i < numDsreadAMfma; ++i) {
if ((numDsReadInstA - (i + 1) * dsReadAMfmaRate) >= dsReadAMfmaRate) {
createSchedGroupBarrier(rewriter, loc, InstructionKindMask::DS_READ,
dsReadAMfmaRate, 0);
} else {
createSchedGroupBarrier(
rewriter, loc, InstructionKindMask::DS_READ,
numDsReadInstA - (numDsreadAMfma - 1) * dsReadAMfmaRate, 0);
}
createSchedGroupBarrier(rewriter, loc, InstructionKindMask::MFMA, 1, 0);
}

for (size_t i = 0; i < numDsreadBMfma; ++i) {
if ((numDsReadInstB - (i + 1) * dsReadBMfmaRate) >= dsReadBMfmaRate) {
createSchedGroupBarrier(rewriter, loc, InstructionKindMask::DS_READ,
dsReadBMfmaRate, 0);
} else {
createSchedGroupBarrier(
rewriter, loc, InstructionKindMask::DS_READ,
numDsReadInstB - (numDsreadBMfma - 1) * dsReadBMfmaRate, 0);
}
createSchedGroupBarrier(rewriter, loc, InstructionKindMask::MFMA, 1, 0);
}
}

LogicalResult
matchAndRewrite(amdgpu::InstructionSchedHint instructionSchedHint,
@@ -180,7 +281,8 @@ struct InstructionSchedHintsRewriter
// not supposed to be used together with IGLP OPT according to the AMDGPU
// backend documentation.
const bool limitSchedulingRange =
!(schedulingType == SchedulingType::IGLP0 ||
!(schedulingType == SchedulingType::NONE ||
schedulingType == SchedulingType::IGLP0 ||
schedulingType == SchedulingType::IGLP1);
Location loc = instructionSchedHint->getLoc();
Block *block = instructionSchedHint->getBlock();
@@ -198,6 +300,10 @@
createIglpOpt(rewriter, loc, static_cast<int>(schedulingType) - 1);
break;
}
case SchedulingType::CK_V3: {
createCKV3Schedule(rewriter, loc, instructionSchedHint);
break;
}
case SchedulingType::NONE:
[[fallthrough]];
default: {
3 changes: 3 additions & 0 deletions third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.h
@@ -5,6 +5,9 @@
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

// The following functions are used to collect and set side-channel information
// during the conversion/lowering to LLVM in order to facilitate instruction
// scheduling controls.
namespace mlir::triton {
void setNumGeneratedMMAs(DotOp op, size_t mmaCount, unsigned m, unsigned n,
unsigned k, Type elementType);
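// For illustration (hypothetical values; the real call sites in this commit
// derive them from the selected MFMA layout), a dot lowering reports its MFMA
// usage as
//   setNumGeneratedMMAs(dotOp, /*mmaCount=*/16, /*m=*/32, /*n=*/32, /*k=*/8,
//                       f16ElemTy);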
5 changes: 3 additions & 2 deletions third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp
@@ -73,8 +73,9 @@ struct ConvertTritonAMDGPUToLLVM
}

void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<LLVM::LLVMDialect, NVVM::NVVMDialect,
mlir::ROCDL::ROCDLDialect>();
registry
.insert<LLVM::LLVMDialect, NVVM::NVVMDialect, mlir::ROCDL::ROCDLDialect,
mlir::triton::amdgpu::TritonAMDGPUDialect>();
}

void runOnOperation() override {
