[AMD] Added instruction scheduling for the CK's V3 pipelining
ravil-mobile committed Oct 7, 2024
1 parent 5f9bb95 commit 2d9123e
Showing 17 changed files with 507 additions and 42 deletions.
@@ -27,15 +27,33 @@ constexpr int patternBenefitPrioritizeOverLLVMConversions = 10;
constexpr int patternBenefitClampOptimizedPattern = 20;
constexpr int patternBenefitConvertLayoutOptimizedPattern = 20;

struct BackendCallbacks {
/**
* A backend-specific callback invoked during `LocalStoreOp` conversion,
* e.g., to collect auxiliary data about the issued instructions.
*
* @param[in] op The rewritten `LocalStoreOp`.
* @param[in] llvmOpCount The number of issued LLVM instructions.
* @param[in] llvmOpType The input type of the issued LLVM instructions.
*/
std::function<void(triton::gpu::LocalStoreOp op, size_t llvmOpCount,
Type llvmOpType)>
localStoreOpConversion = nullptr;
};

void populateElementwiseOpToLLVMPatterns(
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
ModuleAxisInfoAnalysis &axisInfoAnalysis, const TargetInfoBase &targetInfo,
PatternBenefit benefit);

void populateMemoryOpToLLVMPattern(LLVMTypeConverter &typeConverter,
const TargetInfoBase &targetInfo,
RewritePatternSet &patterns,
PatternBenefit benefit);
// The given callback is invoked at the end of a successful rewrite. It
// receives 1) the rewritten source op, 2) the number of issued LLVM
// instructions, and 3) their input type. Each MLIR backend can therefore
// supply a callback to handle backend-specific behavior.
void populateMemoryOpToLLVMPattern(
LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo,
RewritePatternSet &patterns, PatternBenefit benefit,
std::optional<BackendCallbacks> backendCallbacks = std::nullopt);
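
For illustration, a backend could hook into the rewrite like this (a sketch: the lambda body and surrounding setup are assumptions, not the actual AMD wiring):

BackendCallbacks callbacks;
callbacks.localStoreOpConversion = [](triton::gpu::LocalStoreOp op,
                                      size_t llvmOpCount, Type llvmOpType) {
  // e.g., record how many LLVM stores of which vector type were issued
  // while lowering this LocalStoreOp.
};
mlir::triton::populateMemoryOpToLLVMPattern(typeConverter, targetInfo,
                                            patterns, benefit, callbacks);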

void populateAssertOpToLLVMPattern(LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns,
10 changes: 5 additions & 5 deletions include/triton/Conversion/TritonGPUToLLVM/Utility.h
@@ -1360,11 +1360,11 @@ SmallVector<Value> loadSharedToDistributed(RankedTensorType dstTy,
Location loc, RewriterBase &rewriter,
const TargetInfoBase &target);

void storeDistributedToShared(MemDescType dstTy, RankedTensorType srcTy,
Type elemLlvmTy, ArrayRef<Value> srcVals,
Value smemBase, ArrayRef<Value> dstStrides,
Location loc, RewriterBase &rewriter,
const TargetInfoBase &target);
void storeDistributedToShared(
MemDescType dstTy, RankedTensorType srcTy, Type elemLlvmTy,
ArrayRef<Value> srcVals, Value smemBase, ArrayRef<Value> dstStrides,
Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
std::pair<size_t, Type> *const llvmOpCount = nullptr);
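
For illustration, a caller interested in the emitted-store statistics might use the new out-parameter like this (a sketch; variable names mirror the parameters above):

std::pair<size_t, Type> stats{0, Type()};
storeDistributedToShared(dstTy, srcTy, elemLlvmTy, srcVals, smemBase,
                         dstStrides, loc, rewriter, targetInfo, &stats);
// stats.first:  number of LLVM vector stores issued
// stats.second: the vector type used by each store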

inline Value getStructFromSharedMemoryObject(Location loc,
const SharedMemoryObject &smemObj,
36 changes: 25 additions & 11 deletions lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp
@@ -15,12 +15,11 @@ using namespace mlir::triton::gpu;
// blocked -> shared.
// Swizzling in shared memory to avoid bank conflicts. Normally used for
// A/B operands of dots.
void lowerDistributedToShared(Location loc, Value src, Value dst,
Value adaptorSrc,
const SharedMemoryObject &smemObj,
const LLVMTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter,
const TargetInfoBase &targetInfo) {
void lowerDistributedToShared(
Location loc, Value src, Value dst, Value adaptorSrc,
const SharedMemoryObject &smemObj, const LLVMTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter, const TargetInfoBase &targetInfo,
std::pair<size_t, Type> *const llvmOpCount = nullptr) {
auto srcTy = cast<RankedTensorType>(src.getType());
auto dstTy = cast<MemDescType>(dst.getType());
auto outOrd = mlir::cast<SharedEncodingAttr>(dstTy.getEncoding()).getOrder();
@@ -33,7 +32,7 @@ void lowerDistributedToShared(Location loc, Value src, Value dst,
auto dstStrides = smemObj.getStrides();
auto inVals = unpackLLElements(loc, adaptorSrc, rewriter);
storeDistributedToShared(dstTy, srcTy, elemTy, inVals, smemBase, dstStrides,
loc, rewriter, targetInfo);
loc, rewriter, targetInfo, llvmOpCount);
}

struct LocalAllocOpConversion
@@ -185,12 +184,15 @@ struct LocalStoreOpConversion
public:
using ConvertOpToLLVMPattern<
triton::gpu::LocalStoreOp>::ConvertOpToLLVMPattern;
using BackendCallbackType =
decltype(BackendCallbacks::localStoreOpConversion);

LocalStoreOpConversion(const LLVMTypeConverter &converter,
const TargetInfoBase &targetInfo,
BackendCallbackType backendCallback,
PatternBenefit benefit = 1)
: ConvertOpToLLVMPattern<triton::gpu::LocalStoreOp>(converter, benefit),
targetInfo(targetInfo) {}
targetInfo(targetInfo), backendCallback(backendCallback) {}

LogicalResult
matchAndRewrite(triton::gpu::LocalStoreOp op, OpAdaptor adaptor,
@@ -200,24 +202,36 @@ struct LocalStoreOpConversion
getTypeConverter()->convertType(op.getDst().getType().getElementType());
auto smemObj = LLVM::getSharedMemoryObjectFromStruct(
op.getLoc(), adaptor.getDst(), llvmElemTy, rewriter);

std::pair<size_t, Type> llvmOpCount;
lowerDistributedToShared(op.getLoc(), op.getSrc(), op.getDst(),
adaptor.getSrc(), smemObj, getTypeConverter(),
rewriter, targetInfo);
rewriter, targetInfo, &llvmOpCount);

if (backendCallback)
(backendCallback)(op, llvmOpCount.first, llvmOpCount.second);

rewriter.eraseOp(op);
return success();
}

private:
const TargetInfoBase &targetInfo;
BackendCallbackType backendCallback;
};

} // namespace

void mlir::triton::populateMemoryOpToLLVMPattern(
LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo,
RewritePatternSet &patterns, PatternBenefit benefit) {
RewritePatternSet &patterns, PatternBenefit benefit,
std::optional<BackendCallbacks> backendCallbacks) {
patterns.add<LocalAllocOpConversion>(typeConverter, targetInfo, benefit);
patterns.add<LocalDeallocOpConversion>(typeConverter, benefit);
patterns.add<LocalLoadOpConversion>(typeConverter, targetInfo, benefit);
patterns.add<LocalStoreOpConversion>(typeConverter, targetInfo, benefit);

auto backendCall =
backendCallbacks ? backendCallbacks->localStoreOpConversion : nullptr;
patterns.add<LocalStoreOpConversion>(typeConverter, targetInfo, backendCall,
benefit);
}
8 changes: 7 additions & 1 deletion lib/Conversion/TritonGPUToLLVM/Utility.cpp
@@ -405,7 +405,8 @@ void storeDistributedToShared(MemDescType dstTy, RankedTensorType srcTy,
Type elemLlvmTy, ArrayRef<Value> srcVals,
Value smemBase, ArrayRef<Value> dstStrides,
Location loc, RewriterBase &rewriter,
const TargetInfoBase &target) {
const TargetInfoBase &target,
std::pair<size_t, Type> *const llvmOpCount) {
bool success = emitTransferBetweenRegistersAndShared(
srcTy, dstTy, elemLlvmTy, /*maxVecElems=*/std::nullopt, smemBase,
dstStrides, loc, rewriter, target, [&](VectorType vecTy, Value vecAddr) {
@@ -419,7 +420,12 @@
store(vec, vecAddr)
.setAlignment(vecTy.getNumElements() *
elemLlvmTy.getIntOrFloatBitWidth() / 8);
if (llvmOpCount) {
++(llvmOpCount->first);
llvmOpCount->second = vecTy;
}
});

if (!success)
llvm::report_fatal_error("Failed to emit transfer from register to shared");
}
@@ -32,4 +32,31 @@ class TritonAMDGPU_Attr<string name, list<Trait> traits = [],
: AttrDef<TritonAMDGPU_Dialect, name, traits, baseCppClass> {
}

def TritonAMDGPU_OpIdxAttr : TritonAMDGPU_Attr<"OpIdx"> {
let cppNamespace = "::mlir::triton::amdgpu";
let mnemonic = "OpIdx";
let summary = "An operand index attribute.";
let description = [{
The attribute describes which input operand of the target operation
(e.g., `tt.dot`) the result of the annotated operation contributes to.
}];

let parameters = (ins "uint32_t":$value);
let assemblyFormat = "`<` $value `>`";
}
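
For example, a sketch of tagging an operation that produces the A operand (index 0) of a `tt.dot`; the attribute name string used here is an assumption:

auto opIdxAttr = mlir::triton::amdgpu::OpIdxAttr::get(ctx, /*value=*/0);
producerOp->setAttr("amdgpu.OpIdx", opIdxAttr);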

def TritonAMDGPU_InstCounter : TritonAMDGPU_Attr<"InstCounter"> {
let cppNamespace = "::mlir::triton::amdgpu";
let mnemonic = "InstCounter";
let summary = "An instruction counter attribute.";
let description = [{
The attribute holds the number of LLVM instructions of a specific kind
issued during lowering, together with their data type.
}];

let parameters = (ins "uint32_t":$value, "Type":$type);
let assemblyFormat = "`<` params `>`";
}
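
As a sketch, a counter recording, e.g., eight issued instructions of type vector<4xf16> would be built with the TableGen-generated getter:

auto vecTy = mlir::VectorType::get({4}, mlir::FloatType::getF16(ctx));
auto counter = mlir::triton::amdgpu::InstCounterAttr::get(ctx, /*value=*/8, vecTy);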


#endif
@@ -35,6 +35,9 @@ def TritonAMDGPU_Dialect : Dialect {
}];

let dependentDialects = [];

let useDefaultAttributePrinterParser = 1;
let usePropertiesForAttributes = 1;
}

#endif
@@ -49,7 +49,27 @@ def InstructionSchedHint : TT_AMDGPU_Op<"instruction_sched_hint", []> {
interleave for better instruction level parallelism.
}];

let assemblyFormat = [{attr-dict}];
let arguments = (ins
TritonAMDGPU_InstCounter:$numDsReadsA,
TritonAMDGPU_InstCounter:$numDsReadsB,
TritonAMDGPU_InstCounter:$numDsWritesA,
TritonAMDGPU_InstCounter:$numDsWritesB,
TritonAMDGPU_InstCounter:$numGlobalLoadsA,
TritonAMDGPU_InstCounter:$numGlobalLoadsB,
TritonAMDGPU_InstCounter:$numMMAs
);

let builders = [
OpBuilder<(ins), [{
auto ctx = $_state.getContext();
auto type = IntegerType::get(ctx, 32);
auto emptyAttr = amdgpu::InstCounterAttr::get(ctx, 0, type);
build($_builder, $_state, emptyAttr, emptyAttr, emptyAttr, emptyAttr,
emptyAttr, emptyAttr, emptyAttr);
}]>
];

let assemblyFormat = [{ attr-dict }];
}
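
A sketch of how an insertion pass might create the hint right after a dot operation, relying on the zero-argument builder above (the exact insertion logic is an assumption):

OpBuilder builder(dotOp);
builder.setInsertionPointAfter(dotOp);
builder.create<mlir::triton::amdgpu::InstructionSchedHint>(dotOp.getLoc());
// All seven counters start at zero; the lowering patterns fill them in
// later (e.g., via setNumGeneratedMMAs / setNumGeneratedDsReads).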

#endif
6 changes: 4 additions & 2 deletions third_party/amd/include/TritonAMDGPUToLLVM/Passes.td
@@ -59,14 +59,16 @@ def InsertInstructionSchedHints : Pass<"insert-instruction-sched-hints", "mlir::
let summary = "Insert instruction scheduling hints after the dot ops in the main loop";
let constructor = "mlir::triton::createInsertInstructionSchedHintsPass()";

let dependentDialects = ["mlir::LLVM::LLVMDialect"];
let dependentDialects = ["mlir::LLVM::LLVMDialect",
"mlir::triton::amdgpu::TritonAMDGPUDialect"];
}

def LowerInstructionSchedHints : Pass<"lower-insert-instruction-sched-hints", "mlir::ModuleOp"> {
let summary = "Lower instruction scheduling hints to LLVM intrinsics";
let constructor = "mlir::triton::createLowerInstructionSchedHintsPass(\"\")";

let dependentDialects = ["mlir::LLVM::LLVMDialect"];
let dependentDialects = ["mlir::LLVM::LLVMDialect",
"mlir::triton::amdgpu::TritonAMDGPUDialect"];

let options = [
Option<"variant", "variant", "std::string", /*default*/"\"default\"",
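
For reference, a minimal sketch of adding both scheduling passes to a pipeline, using the constructors declared above (the pass-manager setup itself is assumed):

mlir::PassManager pm(&context);
pm.addPass(mlir::triton::createInsertInstructionSchedHintsPass());
pm.addPass(mlir::triton::createLowerInstructionSchedHintsPass("default"));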
6 changes: 6 additions & 0 deletions third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp
@@ -23,6 +23,9 @@

#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OperationSupport.h"

#include "llvm/ADT/TypeSwitch.h"

// clang-format off
#include "Dialect/TritonAMDGPU/IR/Dialect.h"
@@ -44,5 +47,8 @@ void mlir::triton::amdgpu::TritonAMDGPUDialect::initialize() {
>();
}

#define GET_ATTRDEF_CLASSES
#include "Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.cpp.inc"

#define GET_OP_CLASSES
#include "Dialect/TritonAMDGPU/IR/Ops.cpp.inc"
@@ -21,6 +21,7 @@
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
#include "../PatternTritonGPUOpToLLVM.h"
#include "../TritonAMDGPUToLLVM/SchedInstructions.h"
#include "SharedToDotOperandHelper.h"
#include "Utility.h"

@@ -330,6 +331,7 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,
int elemsPerLoad = numOfElems / loadsPerThread;
assert(numOfElems % loadsPerThread == 0);

VectorType loadVecTy = vec_ty(elemTy, elemsPerLoad);
for (int b = 0; b < repB; ++b) {
int operandSize = shape[rank - 1] * shape[rank - 2];
Value batchOffset = mul(i32_val(operandSize),
@@ -340,7 +342,6 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,
for (int k = 0; k < numRepK; ++k) {
auto vecTy = vec_ty(resElemTy, numOfElems);
for (unsigned loadId = 0; loadId < loadsPerThread; ++loadId) {
auto loadVecTy = vec_ty(elemTy, elemsPerLoad);
Value loadOffset;
loadOffset = offsets[nonK * loadsPerThread * numRepK +
k * loadsPerThread + loadId];
@@ -357,6 +358,14 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,
}
}

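// Each of the repB * numRepNonK * numRepK repetitions above issues
// loadsPerThread vectorized ds_read ops; record the total for the
// consuming local_load.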
for (auto op : tensor.getUsers()) {
if (auto localLoadOp = llvm::dyn_cast<triton::gpu::LocalLoadOp>(op)) {
const size_t numDsReadsCount =
repB * numRepNonK * numRepK * loadsPerThread;
setNumGeneratedDsReads(localLoadOp, numDsReadsCount, loadVecTy);
}
}

MLIRContext *ctx = mfmaLayout.getContext();
Type structTy = LLVM::LLVMStructType::getLiteral(
ctx, SmallVector<Type>(loadedValues.size(), loadedValues[0].getType()));
@@ -21,6 +21,7 @@
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
#include "../PatternTritonGPUOpToLLVM.h"
#include "../TritonAMDGPUToLLVM/SchedInstructions.h"
#include "SharedToDotOperandHelper.h"
#include "Utility.h"

@@ -212,6 +213,7 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,
int loadsPerThread = offsets.size() / (numRepNonK * numRepK);
int elemsPerLoad = numElemsPerThreadPerRep / loadsPerThread;
assert(numElemsPerThreadPerRep % loadsPerThread == 0);
auto loadVecTy = vec_ty(elemTy, elemsPerLoad);
for (int b = 0; b < repB; ++b) {
int operandSize = shape[rank - 1] * shape[rank - 2];
Value batchOffset = mul(i32_val(operandSize),
@@ -221,7 +223,6 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,
auto vecTy = vec_ty(resElemTy, numElemsPerThreadPerRep);
Value valVec = undef(vecTy);
for (unsigned loadId = 0; loadId < loadsPerThread; ++loadId) {
auto loadVecTy = vec_ty(elemTy, elemsPerLoad);
Value loadOffset = offsets[nonK * loadsPerThread * numRepK +
k * loadsPerThread + loadId];
loadOffset = add(loadOffset, batchOffset);
@@ -237,6 +238,14 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,
}
}

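// Each of the repB * numRepNonK * numRepK repetitions above issues
// loadsPerThread vectorized ds_read ops; record the total for the
// consuming local_load.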
for (auto op : tensor.getUsers()) {
if (auto localLoadOp = llvm::dyn_cast<triton::gpu::LocalLoadOp>(op)) {
const size_t numDsReadsCount =
repB * numRepNonK * numRepK * loadsPerThread;
setNumGeneratedDsReads(localLoadOp, numDsReadsCount, loadVecTy);
}
}

MLIRContext *ctx = wmmaLayout.getContext();
Type structTy = LLVM::LLVMStructType::getLiteral(
ctx, SmallVector<Type>(loadedValues.size(), loadedValues[0].getType()));
10 changes: 9 additions & 1 deletion third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp
@@ -21,9 +21,9 @@
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
#include "../PatternTritonGPUOpToLLVM.h"
#include "../TritonAMDGPUToLLVM/SchedInstructions.h"
#include "TritonAMDGPUTransforms/MfmaGroup.h"
#include "Utility.h"

#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"

using namespace mlir;
@@ -263,6 +263,14 @@ struct DotOpMFMAConversionHelper {
Type structTy = LLVM::LLVMStructType::getLiteral(
ctx, SmallVector<Type>(fc.size(), dstElemTy));
Value res = packLLElements(loc, typeConverter, fc, rewriter, structTy);

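// Record how many MFMA intrinsics were generated for this tt.dot: each
// (B, M, N, K) repetition issues kWidth / kBase of them.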
const size_t mmaCount =
numRepB * numRepM * numRepN * numRepK * kWidth / kBase;
setNumGeneratedMMAs(op, mmaCount, maybeMfmaInsn->getMDim(),
maybeMfmaInsn->getNDim(), maybeMfmaInsn->getKDim(),
elemTyA);

rewriter.replaceOp(op, res);

return success();
5 changes: 5 additions & 0 deletions third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp
@@ -22,6 +22,7 @@
*/

#include "../PatternTritonGPUOpToLLVM.h"
#include "../TritonAMDGPUToLLVM/SchedInstructions.h"
#include "Utility.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"

@@ -326,6 +327,10 @@ LogicalResult convertDot(DotOp op, DotOpAdaptor adaptor,
Type structTy = LLVM::LLVMStructType::getLiteral(
wmmaLayout.getContext(), SmallVector<Type>(fc.size(), dstElemTy));
Value res = packLLElements(loc, typeConverter, fc, rewriter, structTy);

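// Record how many WMMA intrinsics were generated for this dot: one per
// (B, M, N, K) repetition.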
const size_t mmaCount = numRepB * numRepM * numRepN * numRepK;
setNumGeneratedMMAs(op, mmaCount, mnkDim[0], mnkDim[1], mnkDim[2], elemTy);

rewriter.replaceOp(op, res);
return success();
}