From a908e9283385399c378bfef932b0ec1f2a171476 Mon Sep 17 00:00:00 2001 From: Alexander Efimov Date: Tue, 2 Jul 2024 17:06:13 +0200 Subject: [PATCH 1/3] Relax dot operand constrains with FMA based dot This PR: - Refactors FMA dot implementation - Supports dot3d in FMA path - Fixes several issues in operand offset computation - Enables small dot operands --- .../Conversion/TritonGPUToLLVM/Utility.h | 16 + include/triton/Dialect/TritonGPU/IR/Dialect.h | 10 + lib/Analysis/Utility.cpp | 12 +- .../SharedToDotOperandFMA.cpp | 336 ++++++++---------- .../TritonGPUToLLVM/DotOpToLLVM/FMA.cpp | 98 +++-- .../TritonToTritonGPUPass.cpp | 5 + lib/Dialect/TritonGPU/IR/Dialect.cpp | 52 +-- python/test/unit/language/test_core.py | 20 +- third_party/amd/backend/compiler.py | 28 +- third_party/nvidia/backend/compiler.py | 11 +- .../SharedToDotOperandMMAv2.cpp | 17 - 11 files changed, 313 insertions(+), 292 deletions(-) diff --git a/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/include/triton/Conversion/TritonGPUToLLVM/Utility.h index b209a02b4bb3..6b17387e80db 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/Utility.h +++ b/include/triton/Conversion/TritonGPUToLLVM/Utility.h @@ -1473,6 +1473,22 @@ inline bool isLayoutMmaV1(Attribute layout) { return isMmaV1; } +inline SharedMemoryObject +getExpandedSharedMemoryObject(ConversionPatternRewriter &rewriter, Location loc, + SharedMemoryObject smemObj, + ArrayRef shape) { + auto strides = smemObj.getStrides(); + auto offsets = smemObj.getOffsets(); + auto rank = strides.size(); + if (rank == 3) + return smemObj; + strides.insert(strides.begin(), i32_val(shape[0] * shape[1])); + offsets.insert(offsets.begin(), i32_val(0)); + auto expandedSmemObj = SharedMemoryObject( + smemObj.getBase(), smemObj.getBaseElemType(), strides, offsets); + return expandedSmemObj; +} + } // namespace mlir #endif diff --git a/include/triton/Dialect/TritonGPU/IR/Dialect.h b/include/triton/Dialect/TritonGPU/IR/Dialect.h index 16e6506e5bad..361d042fc19d 100644 --- a/include/triton/Dialect/TritonGPU/IR/Dialect.h +++ b/include/triton/Dialect/TritonGPU/IR/Dialect.h @@ -131,6 +131,16 @@ void dumpHWLayout(RankedTensorType tensorType); // Return a string representation of the layout of the tensor. 
std::string getLayoutStr(RankedTensorType tensorType, bool useHWPointOfView); +template +llvm::SmallVector expandMatrixShapeWithBatch(llvm::ArrayRef s) { + llvm::SmallVector expanded(3 - s.size(), 1); + expanded.append(s.begin(), s.end()); + return expanded; +} + +llvm::SmallVector +expandMatrixOrderWithBatch(llvm::ArrayRef o); + } // namespace gpu } // namespace triton } // namespace mlir diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index 933f062d8191..197b0e4a2713 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -482,12 +482,18 @@ bool supportMMA(triton::DotOp op, int version) { // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-884-f16 auto aElemTy = op.getA().getType().getElementType(); auto bElemTy = op.getB().getType().getElementType(); + auto retType = op.getType(); + auto retShapePerCTA = getShapePerCTA(retType); + auto rank = retShapePerCTA.size(); + auto aTensorTy = cast(op.getA().getType()); + auto aShape = aTensorTy.getShape(); + auto encoding = cast(aTensorTy.getEncoding()); + if (retShapePerCTA[rank - 2] < 16 || retShapePerCTA[rank - 1] < 16 || + aShape[rank - 1] < 16) + return false; if (version == 3) { if (triton::tools::getBoolEnv("DISABLE_MMA_V3")) return false; - auto retType = op.getType(); - auto retShapePerCTA = getShapePerCTA(retType); - auto rank = retShapePerCTA.size(); auto mod = op->getParentOfType(); int numWarps = TritonGPUDialect::getNumWarps(mod); // TODO(Keren): for now, fallback to MMAv2 if handling batch matmul. diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp index b7bd5fbc3432..d019ea9b787c 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp @@ -1,5 +1,6 @@ #include "mlir/Support/LLVM.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" using ValueTable = std::map, Value>; using ::mlir::LLVM::delinearize; @@ -7,6 +8,8 @@ using ::mlir::LLVM::getSharedMemoryObjectFromStruct; using ::mlir::LLVM::getStridesFromShapeAndOrder; using ::mlir::LLVM::linearize; using ::mlir::triton::gpu::DotOperandEncodingAttr; +using ::mlir::triton::gpu::expandMatrixOrderWithBatch; +using ::mlir::triton::gpu::expandMatrixShapeWithBatch; using ::mlir::triton::gpu::getContigPerThread; using ::mlir::triton::gpu::getOrder; using ::mlir::triton::gpu::getShapePerCTA; @@ -15,47 +18,6 @@ using ::mlir::triton::gpu::getTotalElemsPerThread; using ::mlir::triton::gpu::isaDistributedLayout; using ::mlir::triton::gpu::SharedEncodingAttr; -SmallVector -getThreadIds(Value threadId, ArrayRef shapePerCTATile, - ArrayRef sizePerThread, ArrayRef order, - ConversionPatternRewriter &rewriter, Location loc) { - int dim = order.size(); - SmallVector threadIds(dim); - for (unsigned k = 0; k < dim - 1; k++) { - Value dimK = i32_val(shapePerCTATile[order[k]] / sizePerThread[order[k]]); - Value rem = urem(threadId, dimK); - threadId = udiv(threadId, dimK); - threadIds[order[k]] = rem; - } - Value dimK = i32_val(shapePerCTATile[order[dim - 1]]); - threadIds[order[dim - 1]] = urem(threadId, dimK); - return threadIds; -} - -// Get shapePerCTATile for M or N axis. 
-int getShapePerCTATileForMN(BlockedEncodingAttr layout, bool isM) { - auto order = layout.getOrder(); - auto shapePerCTATile = getShapePerCTATile(layout); - - int mShapePerCTATile = - order[0] == 1 ? shapePerCTATile[order[1]] : shapePerCTATile[order[0]]; - int nShapePerCTATile = - order[0] == 0 ? shapePerCTATile[order[1]] : shapePerCTATile[order[0]]; - return isM ? mShapePerCTATile : nShapePerCTATile; -} - -// Get sizePerThread for M or N axis. -int getSizePerThreadForMN(BlockedEncodingAttr layout, bool isM) { - auto order = layout.getOrder(); - auto sizePerThread = getSizePerThread(layout); - - int mSizePerThread = - order[0] == 1 ? sizePerThread[order[1]] : sizePerThread[order[0]]; - int nSizePerThread = - order[0] == 0 ? sizePerThread[order[1]] : sizePerThread[order[0]]; - return isM ? mSizePerThread : nSizePerThread; -} - Value getStructFromValueTable(ArrayRef vals, ConversionPatternRewriter &rewriter, Location loc, const LLVMTypeConverter *typeConverter, @@ -71,154 +33,154 @@ Value getStructFromValueTable(ArrayRef vals, return packLLElements(loc, typeConverter, elems, rewriter, structTy); } -ValueTable getValueTableFromStruct(Value val, int K, int n0, int shapePerCTA, - int sizePerThread, - ConversionPatternRewriter &rewriter, - Location loc, - const LLVMTypeConverter *typeConverter, - Type type) { - ValueTable res; - auto elems = unpackLLElements(loc, val, rewriter); - int index = 0; - for (unsigned k = 0; k < K; ++k) { - for (unsigned m = 0; m < n0; m += shapePerCTA) - for (unsigned mm = 0; mm < sizePerThread; ++mm) { - res[{m + mm, k}] = elems[index++]; - } - } - return res; +SmallVector swizzleIndices(ConversionPatternRewriter &rewriter, + Location loc, SmallVector rawIndices, + SharedEncodingAttr layout) { + const auto &order = layout.getOrder(); + auto rank = order.size(); + + if (layout.getMaxPhase() == 1) + return rawIndices; + + auto vec = i32_val(layout.getVec()); + auto perPhase = i32_val(layout.getPerPhase()); + auto maxPhase = i32_val(layout.getMaxPhase()); + + auto fastIdx = rawIndices[order[0]]; + auto secondIdx = rawIndices[order[1]]; + // Original algorithm taken from getSwizzledSharedPtrs function + // (TritonGPUToLLVMBase.h) + // + // phase = (secondIdx // perPhase) % maxPhase + // swizzledGroup = ((fastIdx // vec) ^ phase) * vec + // groupRemainder = fastIdx % vec + // colOff = swizzledGroup + groupRemainder + auto phase = urem(udiv(secondIdx, perPhase), maxPhase); + auto swizzledGroup = mul(xor_(udiv(fastIdx, vec), phase), vec); + auto groupRemainder = urem(fastIdx, vec); + auto colOff = add(swizzledGroup, groupRemainder); + + SmallVector swizzledIndices = rawIndices; + swizzledIndices[order[0]] = colOff; + + return swizzledIndices; } -Value loadAFMA(Value A, Value llA, BlockedEncodingAttr dLayout, Value thread, - Location loc, const LLVMTypeConverter *typeConverter, - ConversionPatternRewriter &rewriter) { - auto aTensorTy = cast(A.getType()); - auto aLayout = cast(aTensorTy.getEncoding()); - auto aShapePerCTA = getShapePerCTA(aTensorTy); - - auto aOrder = aLayout.getOrder(); - auto order = dLayout.getOrder(); - - bool isARow = aOrder[0] == 1; - - auto aSmem = getSharedMemoryObjectFromStruct( - loc, llA, typeConverter->convertType(aTensorTy.getElementType()), +Value loadFMAOp(Value dotOp, Value llA, BlockedEncodingAttr dLayout, + Value thread, Location loc, + const LLVMTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter, const int dotOpNo) { + auto ctx = dotOp.getContext(); + const int bDim = 0; + const int kDim = dotOpNo == 0 ? 
2 : 1; + const int nonKDim = dotOpNo == 0 ? 1 : 2; + auto opTensorTy = cast(dotOp.getType()); + auto opLayout = cast(opTensorTy.getEncoding()); + auto opShapePerCTA = + expandMatrixShapeWithBatch(ArrayRef(getShapePerCTA(opTensorTy))); + + auto order = expandMatrixOrderWithBatch(dLayout.getOrder()); + + auto origSmem = getSharedMemoryObjectFromStruct( + loc, llA, typeConverter->convertType(opTensorTy.getElementType()), rewriter); - Value strideAM = aSmem.strides[0]; - Value strideAK = aSmem.strides[1]; - Value strideA0 = isARow ? strideAK : strideAM; - Value strideA1 = isARow ? strideAM : strideAK; - int aNumPtr = 8; - int K = aShapePerCTA[1]; - int M = aShapePerCTA[0]; - - auto shapePerCTATile = getShapePerCTATile(dLayout); - auto sizePerThread = getSizePerThread(dLayout); - - Value _0 = i32_val(0); - - Value mContig = i32_val(sizePerThread[order[1]]); + auto smem = getExpandedSharedMemoryObject(rewriter, loc, origSmem, + opTensorTy.getShape()); + auto strides = smem.strides; + int B = opShapePerCTA[bDim]; + int K = opShapePerCTA[kDim]; + int NonK = opShapePerCTA[nonKDim]; + + auto shapePerCTATile = + expandMatrixShapeWithBatch(ArrayRef(getShapePerCTATile(dLayout))); + auto sizePerThread = + expandMatrixShapeWithBatch(ArrayRef(getSizePerThread(dLayout))); + auto threadsPerWarp = + expandMatrixShapeWithBatch(ArrayRef(dLayout.getThreadsPerWarp())); + auto warpsPerCTA = + expandMatrixShapeWithBatch(ArrayRef(dLayout.getWarpsPerCTA())); // threadId in blocked layout - auto threadIds = getThreadIds(thread, shapePerCTATile, sizePerThread, order, - rewriter, loc); - Value threadIdM = threadIds[0]; - - Value offA0 = isARow ? _0 : mul(threadIdM, mContig); - Value offA1 = isARow ? mul(threadIdM, mContig) : _0; - SmallVector aOff(aNumPtr); - for (int i = 0; i < aNumPtr; ++i) { - aOff[i] = add(mul(offA0, strideA0), mul(offA1, strideA1)); - } - auto elemTy = typeConverter->convertType(aTensorTy.getElementType()); - - Type ptrTy = ptr_ty(rewriter.getContext(), 3); - SmallVector aPtrs(aNumPtr); - for (int i = 0; i < aNumPtr; ++i) - aPtrs[i] = gep(ptrTy, elemTy, aSmem.base, aOff[i]); - - SmallVector vas; - - int mShapePerCTATile = getShapePerCTATileForMN(dLayout, true /*isM*/); - int mSizePerThread = getSizePerThreadForMN(dLayout, true /*isM*/); - - for (unsigned k = 0; k < K; ++k) - for (unsigned m = 0; m < M; m += mShapePerCTATile) - for (unsigned mm = 0; mm < mSizePerThread; ++mm) { - Value offset = - add(mul(i32_val(m + mm), strideAM), mul(i32_val(k), strideAK)); - Value pa = gep(ptrTy, elemTy, aPtrs[0], offset); - Value va = load(elemTy, pa); - vas.emplace_back(va); - } - - return getStructFromValueTable(vas, rewriter, loc, typeConverter, elemTy); -} - -Value loadBFMA(Value B, Value llB, BlockedEncodingAttr dLayout, Value thread, - Location loc, const LLVMTypeConverter *typeConverter, - ConversionPatternRewriter &rewriter) { - auto bTensorTy = cast(B.getType()); - auto bLayout = cast(bTensorTy.getEncoding()); - auto bShapePerCTA = getShapePerCTA(bTensorTy); - - auto bOrder = bLayout.getOrder(); - auto order = dLayout.getOrder(); - - bool isBRow = bOrder[0] == 1; - - auto bSmem = getSharedMemoryObjectFromStruct( - loc, llB, typeConverter->convertType(bTensorTy.getElementType()), - rewriter); - Value strideBN = bSmem.strides[1]; - Value strideBK = bSmem.strides[0]; - Value strideB0 = isBRow ? strideBN : strideBK; - Value strideB1 = isBRow ? 
strideBK : strideBN; - int bNumPtr = 8; - int K = bShapePerCTA[0]; - int N = bShapePerCTA[1]; - - auto shapePerCTATile = getShapePerCTATile(dLayout); - auto sizePerThread = getSizePerThread(dLayout); - - Value _0 = i32_val(0); - - Value nContig = i32_val(sizePerThread[order[0]]); - - // threadId in blocked layout - auto threadIds = getThreadIds(thread, shapePerCTATile, sizePerThread, order, - rewriter, loc); - Value threadIdN = threadIds[1]; - - Value offB0 = isBRow ? mul(threadIdN, nContig) : _0; - Value offB1 = isBRow ? _0 : mul(threadIdN, nContig); - SmallVector bOff(bNumPtr); - for (int i = 0; i < bNumPtr; ++i) { - bOff[i] = add(mul(offB0, strideB0), mul(offB1, strideB1)); - } - auto elemTy = typeConverter->convertType(bTensorTy.getElementType()); - - Type ptrTy = ptr_ty(rewriter.getContext(), 3); - SmallVector bPtrs(bNumPtr); - for (int i = 0; i < bNumPtr; ++i) - bPtrs[i] = gep(ptrTy, elemTy, bSmem.base, bOff[i]); - - SmallVector vbs; - - int nShapePerCTATile = getShapePerCTATileForMN(dLayout, false /*isM*/); - int nSizePerThread = getSizePerThreadForMN(dLayout, false /*isM*/); - - for (unsigned k = 0; k < K; ++k) - for (unsigned n = 0; n < N; n += nShapePerCTATile) - for (unsigned nn = 0; nn < nSizePerThread; ++nn) { - Value offset = - add(mul(i32_val(n + nn), strideBN), mul(i32_val(k), strideBK)); - Value pb = gep(ptrTy, elemTy, bPtrs[0], offset); - Value vb = load(elemTy, pb); - vbs.emplace_back(vb); - } - - return getStructFromValueTable(vbs, rewriter, loc, typeConverter, elemTy); + auto warpSize = i32_val(triton::gpu::getWarpSize(dLayout)); + auto laneId = urem(thread, warpSize); + auto warpId = udiv(thread, warpSize); + auto laneIds = + mlir::LLVM::delinearize(rewriter, loc, laneId, threadsPerWarp, order); + auto warpIds = + mlir::LLVM::delinearize(rewriter, loc, warpId, warpsPerCTA, order); + auto sizePerWarpB = sizePerThread[bDim] * threadsPerWarp[bDim]; + auto sizePerWarpNonK = sizePerThread[nonKDim] * threadsPerWarp[nonKDim]; + + Value bTileOffset = mul(laneIds[bDim], i32_val(sizePerThread[bDim])); + bTileOffset = add(bTileOffset, mul(warpIds[bDim], i32_val(sizePerWarpB))); + Value nonKTileOffset = mul(laneIds[nonKDim], i32_val(sizePerThread[nonKDim])); + nonKTileOffset = + add(nonKTileOffset, mul(warpIds[nonKDim], i32_val(sizePerWarpNonK))); + + auto elemTy = typeConverter->convertType(opTensorTy.getElementType()); + Type ptrTy = ptr_ty(ctx, 3); + + unsigned vectorSize = order[0] == kDim ? 
K : sizePerThread[order[0]]; + if (opLayout.getMaxPhase() > 1) + vectorSize = std::min(vectorSize, opLayout.getVec()); + // limit vector size with maximum width of load available on hardware + // TODO: get maximum vector size from target hardware info + vectorSize = std::min(16u, vectorSize); + auto vecTy = vec_ty(elemTy, vectorSize); + + unsigned dimStep[3] = {1, 1, 1}; + dimStep[order[0]] = vectorSize; + + int shapePerCTABTile = shapePerCTATile[bDim]; + int shapePerCTANonKTile = shapePerCTATile[nonKDim]; + int sizeBPerThread = sizePerThread[bDim]; + int sizeNonKPerThread = sizePerThread[nonKDim]; + int numBTiles = std::max(1, B / shapePerCTABTile); + int numNonKTiles = std::max(1, NonK / shapePerCTANonKTile); + + SmallVector opValues(numBTiles * sizeBPerThread * K * numNonKTiles * + sizeNonKPerThread); + + for (unsigned bTile = 0; bTile < numBTiles; ++bTile) + for (unsigned b = 0; b < sizeBPerThread; b += dimStep[bDim]) + for (unsigned k = 0; k < K; k += dimStep[kDim]) + for (unsigned nonKTile = 0; nonKTile < numNonKTiles; ++nonKTile) + for (unsigned nonK = 0; nonK < sizeNonKPerThread; + nonK += dimStep[nonKDim]) { + SmallVector rawIndices(3); + rawIndices[bDim] = + add(bTileOffset, i32_val(bTile * shapePerCTABTile + b)); + rawIndices[nonKDim] = add( + nonKTileOffset, i32_val(nonKTile * shapePerCTANonKTile + nonK)); + rawIndices[kDim] = i32_val(k); + + SmallVector swizzledIndices = + swizzleIndices(rewriter, loc, rawIndices, opLayout); + + Value offset = i32_val(0); + for (int dim = 0; dim < order.size(); ++dim) + offset = add(offset, mul(urem(swizzledIndices[dim], + i32_val(opShapePerCTA[dim])), + strides[dim])); + + Value elemAddr = gep(ptrTy, elemTy, smem.base, offset); + Value vecAddr = bitcast(elemAddr, ptr_ty(ctx, 3)); + Value vec = load(vecTy, elemAddr); + for (int elem = 0; elem < vectorSize; ++elem) { + int outIdx[3] = {}; + outIdx[bDim] = bTile * sizeBPerThread + b; + outIdx[kDim] = k; + outIdx[nonKDim] = nonKTile * sizeNonKPerThread + nonK; + outIdx[order[0]] += elem; + int idx = (outIdx[bDim] * K + outIdx[kDim]) * numNonKTiles * + sizeNonKPerThread + + outIdx[nonKDim]; + opValues[idx] = extract_element(elemTy, vec, i32_val(elem)); + } + } + + return getStructFromValueTable(opValues, rewriter, loc, typeConverter, + elemTy); } namespace SharedToDotOperandFMA { @@ -226,9 +188,7 @@ Value convertLayout(int opIdx, Value val, Value llVal, BlockedEncodingAttr dLayout, Value thread, Location loc, const LLVMTypeConverter *typeConverter, ConversionPatternRewriter &rewriter) { - if (opIdx == 0) - return loadAFMA(val, llVal, dLayout, thread, loc, typeConverter, rewriter); - else - return loadBFMA(val, llVal, dLayout, thread, loc, typeConverter, rewriter); + return loadFMAOp(val, llVal, dLayout, thread, loc, typeConverter, rewriter, + opIdx); } } // namespace SharedToDotOperandFMA diff --git a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMA.cpp b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMA.cpp index afb5bf01d48b..29bb10d4a5a1 100644 --- a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMA.cpp +++ b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMA.cpp @@ -1,29 +1,30 @@ #include "mlir/Support/LLVM.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" using namespace mlir; using namespace mlir::triton; +using namespace ::mlir::triton::gpu; -using ::mlir::triton::gpu::DotOperandEncodingAttr; +using ::mlir::triton::gpu::expandMatrixOrderWithBatch; +using 
::mlir::triton::gpu::expandMatrixShapeWithBatch; using ::mlir::triton::gpu::getShapePerCTA; -using ::mlir::triton::gpu::NvidiaMmaEncodingAttr; +using ::mlir::triton::gpu::getSizePerThread; -using ValueTableFMA = std::map, Value>; +using ValueTableFMA = std::map, Value>; static ValueTableFMA -getValueTableFromStructFMA(Value val, int K, int n0, int shapePerCTATile, - int sizePerThread, - ConversionPatternRewriter &rewriter, Location loc, - const LLVMTypeConverter *typeConverter, Type type) { +getValueTableFromStructFMA(Value val, int batch, int nonK, int K, + ConversionPatternRewriter &rewriter, Location loc) { ValueTableFMA res; auto elems = unpackLLElements(loc, val, rewriter); + assert(elems.size() == K * nonK * batch); int index = 0; - for (unsigned k = 0; k < K; ++k) { - for (unsigned m = 0; m < n0; m += shapePerCTATile) - for (unsigned mm = 0; mm < sizePerThread; ++mm) { - res[{m + mm, k}] = elems[index++]; - } - } + for (unsigned b = 0; b < batch; ++b) + for (unsigned k = 0; k < K; ++k) + for (unsigned i = 0; i < nonK; ++i) + res[{b, i, k}] = elems[index++]; return res; } @@ -39,61 +40,56 @@ LogicalResult convertFMADot(triton::DotOp op, triton::DotOp::Adaptor adaptor, auto D = op.getResult(); auto aTensorTy = cast(A.getType()); - auto bTensorTy = cast(B.getType()); auto dTensorTy = cast(D.getType()); + auto dElemTy = dTensorTy.getElementType(); - auto aShapePerCTA = getShapePerCTA(aTensorTy); - auto bShapePerCTA = getShapePerCTA(bTensorTy); + SmallVector aShapePerCTA = + expandMatrixShapeWithBatch(ArrayRef(getShapePerCTA(aTensorTy))); + auto dShapePerCTA = + expandMatrixShapeWithBatch(ArrayRef(getShapePerCTA(dTensorTy))); BlockedEncodingAttr dLayout = cast(dTensorTy.getEncoding()); - auto order = dLayout.getOrder(); + auto order = expandMatrixOrderWithBatch(dLayout.getOrder()); auto cc = unpackLLElements(loc, adaptor.getC(), rewriter); Value llA = adaptor.getA(); Value llB = adaptor.getB(); - auto sizePerThread = getSizePerThread(dLayout); - auto shapePerCTATile = getShapePerCTATile(dLayout); + auto sizePerThread = + expandMatrixShapeWithBatch(ArrayRef(getSizePerThread(dLayout))); + auto shapePerCTATile = + expandMatrixShapeWithBatch(ArrayRef(getShapePerCTATile(dLayout))); - int K = aShapePerCTA[1]; - int M = aShapePerCTA[0]; - int N = bShapePerCTA[1]; + int K = aShapePerCTA[2]; - int mShapePerCTATile = - order[0] == 1 ? shapePerCTATile[order[1]] : shapePerCTATile[order[0]]; - int mSizePerThread = - order[0] == 1 ? sizePerThread[order[1]] : sizePerThread[order[0]]; - int nShapePerCTATile = - order[0] == 0 ? shapePerCTATile[order[1]] : shapePerCTATile[order[0]]; - int nSizePerThread = - order[0] == 0 ? 
sizePerThread[order[1]] : sizePerThread[order[0]]; + unsigned retSize[3]; + for (int i = 0; i < 3; ++i) { + unsigned numRep = dShapePerCTA[i] / shapePerCTATile[i]; + numRep = std::max(static_cast(1), numRep); + retSize[i] = numRep * sizePerThread[i]; + } auto has = - getValueTableFromStructFMA(llA, K, M, mShapePerCTATile, mSizePerThread, - rewriter, loc, typeConverter, aTensorTy); + getValueTableFromStructFMA(llA, retSize[0], retSize[1], K, rewriter, loc); auto hbs = - getValueTableFromStructFMA(llB, K, N, nShapePerCTATile, nSizePerThread, - rewriter, loc, typeConverter, bTensorTy); + getValueTableFromStructFMA(llB, retSize[0], retSize[2], K, rewriter, loc); SmallVector ret = cc; - bool isCRow = order[0] == 1; - - for (unsigned k = 0; k < K; k++) { - for (unsigned m = 0; m < M; m += mShapePerCTATile) - for (unsigned n = 0; n < N; n += nShapePerCTATile) - for (unsigned mm = 0; mm < mSizePerThread; ++mm) - for (unsigned nn = 0; nn < nSizePerThread; ++nn) { - int mIdx = m / mShapePerCTATile * mSizePerThread + mm; - int nIdx = n / nShapePerCTATile * nSizePerThread + nn; - - int z = isCRow - ? mIdx * N / nShapePerCTATile * mSizePerThread + nIdx - : nIdx * M / mShapePerCTATile * nSizePerThread + mIdx; - ret[z] = rewriter.create(loc, has[{m + mm, k}], - hbs[{n + nn, k}], ret[z]); - } - } + + for (unsigned b = 0; b < retSize[0]; ++b) + for (unsigned m = 0; m < retSize[1]; ++m) + for (unsigned n = 0; n < retSize[2]; ++n) { + unsigned idx[] = {b, m, n}; + unsigned linearIdx = 0; + for (auto dim : llvm::reverse(order)) { + linearIdx = linearIdx * retSize[dim] + idx[dim]; + } + for (unsigned k = 0; k < K; ++k) { + ret[linearIdx] = rewriter.create( + loc, has[{b, m, k}], hbs[{b, n, k}], ret[linearIdx]); + } + } auto res = packLLElements(loc, typeConverter, ret, rewriter, dTensorTy); rewriter.replaceOp(op, res); diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp index 4aa2712ec939..daa788d5c9af 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp @@ -235,6 +235,11 @@ struct TritonDotPattern : public OpConversionPattern { retSizePerThread[rank - 1] = 4; retSizePerThread[rank - 2] = 4; } + retSizePerThread[rank - 1] = std::min( + retSizePerThread[rank - 1], static_cast(origShape[rank - 1])); + retSizePerThread[rank - 2] = std::min( + retSizePerThread[rank - 2], static_cast(origShape[rank - 2])); + SmallVector retOrder(rank); for (unsigned i = 0; i < rank; ++i) retOrder[i] = rank - 1 - i; diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index d5b5d459a910..db75ea0c5c07 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -939,29 +939,26 @@ unsigned DotOperandEncodingAttr::getTotalElemsPerThread(ArrayRef shape, getKWidth(), getOpIdx()); } if (auto blockedLayout = mlir::dyn_cast(getParent())) { - auto shapePerCTA = getShapePerCTA(*this, shape); - auto shapePerCTATile = ::getShapePerCTATile(blockedLayout); - auto order = blockedLayout.getOrder(); - auto sizePerThread = ::getSizePerThread(blockedLayout); - - int K = getOpIdx() == 0 ? shapePerCTA[1] : shapePerCTA[0]; - int otherDim = getOpIdx() == 1 ? shapePerCTA[1] : shapePerCTA[0]; - - bool isM = getOpIdx() == 0; - - int mSizePerThread = - order[0] == 1 ? sizePerThread[order[1]] : sizePerThread[order[0]]; - int nSizePerThread = - order[0] == 0 ? 
sizePerThread[order[1]] : sizePerThread[order[0]]; - int sizePerThreadMN = isM ? mSizePerThread : nSizePerThread; - - int mShapePerCTATile = - order[0] == 1 ? shapePerCTATile[order[1]] : shapePerCTATile[order[0]]; - int nShapePerCTATile = - order[0] == 0 ? shapePerCTATile[order[1]] : shapePerCTATile[order[0]]; - int shapePerCTAMNTile = isM ? mShapePerCTATile : nShapePerCTATile; - - return K * std::max(otherDim / shapePerCTAMNTile, 1) * sizePerThreadMN; + auto shapePerCTA = + expandMatrixShapeWithBatch(ArrayRef(getShapePerCTA(*this, shape))); + auto shapePerCTATile = expandMatrixShapeWithBatch( + ArrayRef(::getShapePerCTATile(blockedLayout))); + auto sizePerThread = + expandMatrixShapeWithBatch(ArrayRef(::getSizePerThread(blockedLayout))); + + int batchDim = 0; + int kDim = getOpIdx() == 0 ? 2 : 1; + int nonKDim = getOpIdx() == 0 ? 1 : 2; + + int batchSize = + std::max(shapePerCTA[batchDim] / shapePerCTATile[batchDim], 1) * + sizePerThread[batchDim]; + int kSize = shapePerCTA[kDim]; + int nonKSize = + std::max(shapePerCTA[nonKDim] / shapePerCTATile[nonKDim], 1) * + sizePerThread[nonKDim]; + + return batchSize * kSize * nonKSize; } llvm_unreachable("unknown dot operand parent layout"); return 0; @@ -3165,6 +3162,15 @@ std::string mlir::triton::gpu::getLayoutStr(RankedTensorType tensorType, return layoutStr; } +llvm::SmallVector +mlir::triton::gpu::expandMatrixOrderWithBatch(llvm::ArrayRef o) { + int oldRank = o.size(); + llvm::SmallVector expanded(3, 0); + for (int i = 0; i < oldRank; ++i) + expanded[i] += o[i] + 3 - oldRank; + return expanded; +} + void mlir::triton::gpu::dumpLayout(RankedTensorType tensorType) { llvm::errs() << getLayoutStr(tensorType, /*useHWPointOfView=*/false); } diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 9e5ff8a2ce37..2c9cc77fff92 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -3094,7 +3094,11 @@ def convert_fp8_to_fp32(x, device, dtype_str): ([(16, 16, 8, 4, False, False, 'None', 'ieee', 'float32', 'float32', 1), (32, 16, 8, 4, False, False, 'None', 'ieee', 'float16', 'float16', 1)] if "gfx9" in get_arch() else []) + [(128, 128, 64, 4, False, False, 'chain-dot', 'ieee', float8_type, 'float32', 1) - for float8_type in ["float8e5", "float8e4nv"]]) + for float8_type in ["float8e5", "float8e4nv"]] + + [(*shape_nw, False, False, epilogue, 'ieee', in_dtype, out_dtype, 1) + for shape_nw in [(2, 2, 16, 1), (1, 64, 64, 1), (64, 2, 64, 2), (64, 64, 4, 4)] + for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols'] + for in_dtype, out_dtype in [('float16', 'float16'), ('float32', 'float32')]]) @pytest.mark.parametrize("num_ctas", num_ctas_list) def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dtype, out_dtype, kpack, num_ctas, device): if is_interpreter(): @@ -3284,6 +3288,9 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid return # make sure ld/st are vectorized ptx = pgm.asm['ptx'] + is_fma = K < 16 or N < 16 or M < 16 + if is_fma: + return if (K > 16 or N > 16 or M > 16) and (M * N // (num_warps * 32) >= 4): # XXX: skip small sizes because they are not vectorized assert 'ld.global.v4' in ptx @@ -3327,7 +3334,14 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid for in_dtype_str, out_dtype_str in [('int8', 'int8'), ('float16', 'float16'), ('float16', 'float32'), ('float32', 'float32')]] + # Large block sizes - [(4, 4, 128, 128, 64, 64, 64, 'float16', 'float16')]) 
+ [(4, 4, 128, 128, 64, 64, 64, 'float16', 'float16')] + + # Small block sizes + [(B, num_warps, M, N, K, BLOCK_M, BLOCK_N, in_dtype_str, out_dtype_str) + for B in [1, 2, 8] + for num_warps in [1, 2, 4] + for BLOCK_M, BLOCK_N in [(1, 32), (32, 2), (8, 8)] + for M, N, K in [(32, 32, 32)] + for in_dtype_str, out_dtype_str in [('float16', 'float16'), ('float32', 'float32')]]) def test_dot3d(B, num_warps, M, N, K, BLOCK_M, BLOCK_N, in_dtype_str, out_dtype_str, device): if is_hip(): # hip does not support tf32 precision, so use ieee for all tests @@ -3398,6 +3412,8 @@ def kernel( if in_dtype_str == 'int8': out = numpy_random((B, M, N), dtype_str='int32', rs=rs) else: + x *= 0.1 + y *= 0.1 out = numpy_random((B, M, N), dtype_str=out_dtype_str, rs=rs) x_tri = to_triton(x, device=device) diff --git a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py index 83b3fed52ac8..386ae927a26b 100644 --- a/third_party/amd/backend/compiler.py +++ b/third_party/amd/backend/compiler.py @@ -13,16 +13,30 @@ def min_dot_size(target: GPUTarget): + + def fma_supported(lhsType, rhsType): + return lhsType == rhsType and (lhsType.is_fp16() or lhsType.is_fp32()) + + def gfx94_limits(lhsType, rhsType): + if fma_supported(lhsType.scalar, rhsType.scalar): + return (1, 1, 1) + # CDNA 3.0 supports k==8 in all mfma variants except for int8 + # (where the smallest `k` supported is 16) + return (16, 16, 16) if (lhsType.scalar.is_int8() or rhsType.scalar.is_int8()) else (16, 16, 8) + + def gfx9_limits(lhsType, rhsType): + if fma_supported(lhsType.scalar, rhsType.scalar): + return (1, 1, 1) + # CDNA 2.0 always supports `k==8` + return (16, 16, 8) + arch_str = target.arch - # CDNA 3.0 supports k==8 in all mfma variants except for int8 - # (where the smallest `k` supported is 16) if "gfx94" in arch_str: - return lambda lhsType, rhsType: (16, 16, 16) if (lhsType.is_int8() or rhsType.is_int8()) else (16, 16, 8) - # CDNA 2.0 always supports `k==8` + return gfx94_limits if "gfx9" in arch_str: - return lambda lhsType, rhsType: (16, 16, 8) - # Other architectures will only support 16,16,16 - return lambda lhsType, rhsType: (16, 16, 16) + return gfx9_limits + # Other architectures will only support 16,16,16 with mfma instructions + return lambda lhsType, rhsType: (1, 1, 1) if fma_supported(lhsType.scalar, rhsType.scalar) else (16, 16, 16) @dataclass(frozen=True) diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py index ea1d79f9ba93..1e7ee03ef7c2 100644 --- a/third_party/nvidia/backend/compiler.py +++ b/third_party/nvidia/backend/compiler.py @@ -15,7 +15,16 @@ def min_dot_size(target: GPUTarget): - return lambda lhsType, rhsType: (16, 32, 16) if lhsType.is_int8() else (16, 16, 16) + + def fma_supported(lhsType, rhsType): + return lhsType == rhsType and (lhsType.is_fp16() or lhsType.is_fp32()) + + def limits(lhsType, rhsType): + if fma_supported(lhsType.scalar, rhsType.scalar): + return (1, 1, 1) + return (16, 16, 16) + + return limits @functools.lru_cache() diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp index d1086c189d33..09162f5c6e9a 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp @@ -743,23 +743,6 @@ MemDescType getExpandedDesc(MemDescType descTy) { return 
expandedDesc; } -SharedMemoryObject -getExpandedSharedMemoryObject(ConversionPatternRewriter &rewriter, Location loc, - SharedMemoryObject smemObj, - ArrayRef shape) { - auto strides = smemObj.getStrides(); - auto offsets = smemObj.getOffsets(); - auto rank = strides.size(); - if (rank == 3) - return smemObj; - auto expandedStrides = insertValue(strides, 0, i32_val(shape[0] * shape[1])); - auto expandedOffsets = insertValue(offsets, 0, i32_val(0)); - auto expandedSmemObj = - SharedMemoryObject(smemObj.getBase(), smemObj.getBaseElemType(), - expandedStrides, expandedOffsets); - return expandedSmemObj; -} - namespace SharedToDotOperandMMAv2 { Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, Location loc, Value tensor, DotOperandEncodingAttr encoding, From 199d2b157cd8fcc415b8a1fd557e866d932c744c Mon Sep 17 00:00:00 2001 From: Alexander Efimov Date: Fri, 16 Aug 2024 15:13:28 +0000 Subject: [PATCH 2/3] implement separate conversion path for unswizzled tensor to improve compiltion time and reduce number of instructions in assembly, fix bug with wrong order field used for share mem load size computation --- .../SharedToDotOperandFMA.cpp | 177 ++++++++++++------ 1 file changed, 122 insertions(+), 55 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp index d019ea9b787c..30900a8b4df1 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp @@ -33,13 +33,15 @@ Value getStructFromValueTable(ArrayRef vals, return packLLElements(loc, typeConverter, elems, rewriter, structTy); } +bool isSwizzled(SharedEncodingAttr layout) { return layout.getMaxPhase() != 1; } + SmallVector swizzleIndices(ConversionPatternRewriter &rewriter, Location loc, SmallVector rawIndices, SharedEncodingAttr layout) { const auto &order = layout.getOrder(); auto rank = order.size(); - if (layout.getMaxPhase() == 1) + if (!isSwizzled(layout)) return rawIndices; auto vec = i32_val(layout.getVec()); @@ -66,16 +68,52 @@ SmallVector swizzleIndices(ConversionPatternRewriter &rewriter, return swizzledIndices; } +/** + * @brief put elements from Value vec to appropriate indexes in opValues array + * + * This function maps elements of 3d sub-tensor in linear array. 
+ * Axes are arranged in following order from fastest to slowest: [nonKdim, kDim, + * bDim] + */ +void storeValuesInLinearVector(PatternRewriter &rewriter, Location loc, + SmallVector &opValues, Value vec, + unsigned kElems, unsigned nonKElems, + unsigned kIdx, unsigned nonKIdx, unsigned bIdx, + int kDim, int nonKDim, int bDim, int fastDim) { + auto vecTy = cast(vec.getType()); + auto vectorSize = vecTy.getNumElements(); + auto elemTy = vecTy.getElementType(); + for (int elem = 0; elem < vectorSize; ++elem) { + unsigned outIdx[3] = {}; + outIdx[bDim] = bIdx; + outIdx[kDim] = kIdx; + outIdx[nonKDim] = nonKIdx; + + outIdx[fastDim] += elem; + auto idx = outIdx[bDim] * kElems * nonKElems + outIdx[kDim] * nonKElems + + outIdx[nonKDim]; + opValues[idx] = extract_element(elemTy, vec, i32_val(elem)); + } +} + Value loadFMAOp(Value dotOp, Value llA, BlockedEncodingAttr dLayout, Value thread, Location loc, const LLVMTypeConverter *typeConverter, ConversionPatternRewriter &rewriter, const int dotOpNo) { + auto ctaSplit = dLayout.getCTALayout().getCTASplitNum(); + for (auto split : ctaSplit) { + if (split != 1) + llvm::report_fatal_error("tensors splited in CGA(thread group clusters) " + "are not supported in FMA dot yet."); + } + auto ctx = dotOp.getContext(); const int bDim = 0; const int kDim = dotOpNo == 0 ? 2 : 1; const int nonKDim = dotOpNo == 0 ? 1 : 2; auto opTensorTy = cast(dotOp.getType()); - auto opLayout = cast(opTensorTy.getEncoding()); + auto opTensorShape = expandMatrixShapeWithBatch(opTensorTy.getShape()); + auto sharedLayout = cast(opTensorTy.getEncoding()); auto opShapePerCTA = expandMatrixShapeWithBatch(ArrayRef(getShapePerCTA(opTensorTy))); @@ -87,9 +125,9 @@ Value loadFMAOp(Value dotOp, Value llA, BlockedEncodingAttr dLayout, auto smem = getExpandedSharedMemoryObject(rewriter, loc, origSmem, opTensorTy.getShape()); auto strides = smem.strides; - int B = opShapePerCTA[bDim]; - int K = opShapePerCTA[kDim]; - int NonK = opShapePerCTA[nonKDim]; + int B = opTensorShape[bDim]; + int K = opTensorShape[kDim]; + int NonK = opTensorShape[nonKDim]; auto shapePerCTATile = expandMatrixShapeWithBatch(ArrayRef(getShapePerCTATile(dLayout))); @@ -100,7 +138,6 @@ Value loadFMAOp(Value dotOp, Value llA, BlockedEncodingAttr dLayout, auto warpsPerCTA = expandMatrixShapeWithBatch(ArrayRef(dLayout.getWarpsPerCTA())); - // threadId in blocked layout auto warpSize = i32_val(triton::gpu::getWarpSize(dLayout)); auto laneId = urem(thread, warpSize); auto warpId = udiv(thread, warpSize); @@ -120,64 +157,94 @@ Value loadFMAOp(Value dotOp, Value llA, BlockedEncodingAttr dLayout, auto elemTy = typeConverter->convertType(opTensorTy.getElementType()); Type ptrTy = ptr_ty(ctx, 3); - unsigned vectorSize = order[0] == kDim ? K : sizePerThread[order[0]]; - if (opLayout.getMaxPhase() > 1) - vectorSize = std::min(vectorSize, opLayout.getVec()); - // limit vector size with maximum width of load available on hardware - // TODO: get maximum vector size from target hardware info - vectorSize = std::min(16u, vectorSize); + auto sharedOrder = expandMatrixOrderWithBatch(sharedLayout.getOrder()); + unsigned vectorSize = + sharedOrder[0] == kDim ? 
K : sizePerThread[sharedOrder[0]]; + if (sharedLayout.getMaxPhase() > 1) + vectorSize = std::min(vectorSize, sharedLayout.getVec()); auto vecTy = vec_ty(elemTy, vectorSize); unsigned dimStep[3] = {1, 1, 1}; - dimStep[order[0]] = vectorSize; + dimStep[sharedOrder[0]] = vectorSize; - int shapePerCTABTile = shapePerCTATile[bDim]; - int shapePerCTANonKTile = shapePerCTATile[nonKDim]; - int sizeBPerThread = sizePerThread[bDim]; - int sizeNonKPerThread = sizePerThread[nonKDim]; - int numBTiles = std::max(1, B / shapePerCTABTile); - int numNonKTiles = std::max(1, NonK / shapePerCTANonKTile); + auto shapePerCTABTile = shapePerCTATile[bDim]; + auto shapePerCTANonKTile = shapePerCTATile[nonKDim]; + auto sizeBPerThread = sizePerThread[bDim]; + auto sizeNonKPerThread = sizePerThread[nonKDim]; + auto numBTiles = std::max(1u, B / shapePerCTABTile); + auto numNonKTiles = std::max(1u, NonK / shapePerCTANonKTile); SmallVector opValues(numBTiles * sizeBPerThread * K * numNonKTiles * sizeNonKPerThread); + if (isSwizzled(sharedLayout)) { + for (unsigned bTile = 0; bTile < numBTiles; ++bTile) + for (unsigned b = 0; b < sizeBPerThread; b += dimStep[bDim]) + for (unsigned k = 0; k < K; k += dimStep[kDim]) + for (unsigned nonKTile = 0; nonKTile < numNonKTiles; ++nonKTile) + for (unsigned nonK = 0; nonK < sizeNonKPerThread; + nonK += dimStep[nonKDim]) { + SmallVector rawIndices(3); + rawIndices[bDim] = + add(bTileOffset, i32_val(bTile * shapePerCTABTile + b)); + rawIndices[nonKDim] = + add(nonKTileOffset, + i32_val(nonKTile * shapePerCTANonKTile + nonK)); + rawIndices[kDim] = i32_val(k); + + SmallVector swizzledIndices = + swizzleIndices(rewriter, loc, rawIndices, sharedLayout); - for (unsigned bTile = 0; bTile < numBTiles; ++bTile) - for (unsigned b = 0; b < sizeBPerThread; b += dimStep[bDim]) - for (unsigned k = 0; k < K; k += dimStep[kDim]) - for (unsigned nonKTile = 0; nonKTile < numNonKTiles; ++nonKTile) - for (unsigned nonK = 0; nonK < sizeNonKPerThread; - nonK += dimStep[nonKDim]) { - SmallVector rawIndices(3); - rawIndices[bDim] = - add(bTileOffset, i32_val(bTile * shapePerCTABTile + b)); - rawIndices[nonKDim] = add( - nonKTileOffset, i32_val(nonKTile * shapePerCTANonKTile + nonK)); - rawIndices[kDim] = i32_val(k); - - SmallVector swizzledIndices = - swizzleIndices(rewriter, loc, rawIndices, opLayout); - - Value offset = i32_val(0); - for (int dim = 0; dim < order.size(); ++dim) - offset = add(offset, mul(urem(swizzledIndices[dim], - i32_val(opShapePerCTA[dim])), - strides[dim])); - - Value elemAddr = gep(ptrTy, elemTy, smem.base, offset); - Value vecAddr = bitcast(elemAddr, ptr_ty(ctx, 3)); - Value vec = load(vecTy, elemAddr); - for (int elem = 0; elem < vectorSize; ++elem) { - int outIdx[3] = {}; - outIdx[bDim] = bTile * sizeBPerThread + b; - outIdx[kDim] = k; - outIdx[nonKDim] = nonKTile * sizeNonKPerThread + nonK; - outIdx[order[0]] += elem; - int idx = (outIdx[bDim] * K + outIdx[kDim]) * numNonKTiles * - sizeNonKPerThread + - outIdx[nonKDim]; - opValues[idx] = extract_element(elemTy, vec, i32_val(elem)); + Value offset = i32_val(0); + for (int dim = 0; dim < order.size(); ++dim) { + auto wrappedDimIndex = + urem(swizzledIndices[dim], i32_val(opTensorShape[dim])); + auto dimOffset = mul(wrappedDimIndex, strides[dim]); + offset = add(offset, dimOffset); + } + + Value elemAddr = gep(ptrTy, elemTy, smem.base, offset); + Value vec = load(vecTy, elemAddr); + storeValuesInLinearVector( + rewriter, loc, opValues, vec, /*kElems*/ K, + /*nonKElems*/ numNonKTiles * sizeNonKPerThread, /*kIdx*/ k, + /*nonKIdx*/ 
nonKTile * sizeNonKPerThread + nonK, + /*bIdx*/ bTile * sizeBPerThread + b, kDim, nonKDim, bDim, + sharedOrder[0]); } - } + } else { + auto bOffset = mul(urem(bTileOffset, i32_val(B)), strides[bDim]); + auto nonKOffset = + mul(urem(nonKTileOffset, i32_val(NonK)), strides[nonKDim]); + Value threadDependantOffset = add(bOffset, nonKOffset); + + auto basePtr = gep(ptrTy, elemTy, smem.base, threadDependantOffset); + + for (unsigned bTile = 0; bTile < numBTiles; ++bTile) + for (unsigned b = 0; b < sizeBPerThread; b += dimStep[bDim]) + for (unsigned k = 0; k < K; k += dimStep[kDim]) + for (unsigned nonKTile = 0; nonKTile < numNonKTiles; ++nonKTile) + for (unsigned nonK = 0; nonK < sizeNonKPerThread; + nonK += dimStep[nonKDim]) { + SmallVector offsetIndices(3); + offsetIndices[bDim] = i32_val((bTile * shapePerCTABTile + b) % B); + offsetIndices[nonKDim] = + i32_val((nonKTile * shapePerCTANonKTile + nonK) % NonK); + offsetIndices[kDim] = i32_val(k); + + Value offset = i32_val(0); + for (int dim = 0; dim < order.size(); ++dim) + offset = add(offset, mul(offsetIndices[dim], strides[dim])); + + Value elemAddr = gep(ptrTy, elemTy, basePtr, offset); + Value vec = load(vecTy, elemAddr); + storeValuesInLinearVector( + rewriter, loc, opValues, vec, /*kElems*/ K, + /*nonKElems*/ numNonKTiles * sizeNonKPerThread, /*kIdx*/ k, + /*nonKIdx*/ nonKTile * sizeNonKPerThread + nonK, + /*bIdx*/ bTile * sizeBPerThread + b, kDim, nonKDim, bDim, + sharedOrder[0]); + } + } return getStructFromValueTable(opValues, rewriter, loc, typeConverter, elemTy); From 90a467ac47351fa5fdb28fdff1e1b54b46b888f2 Mon Sep 17 00:00:00 2001 From: Alexander Efimov Date: Wed, 28 Aug 2024 16:41:26 +0200 Subject: [PATCH 3/3] [AMD] Emit AMD specific intrinsics for dot This PR: - Makes AccelerateAMDMatmul pass to emit FMA i8xi8->i32 and fp16xfp16->fp32 cases - Extends AMD FMA Dot code generation with new v_dot instructions for fp16xfp16 and int8 dtypes --- .../Conversion/TritonGPUToLLVM/Utility.h | 2 + .../Dialect/TritonGPU/Transforms/Utility.h | 3 + lib/Dialect/TritonGPU/Transforms/Utility.cpp | 15 ++ test/Conversion/amd/tritongpu_to_llvm.mlir | 36 ++++ .../amd/accelerate-amd-matmul-fma.mlir | 107 ++++++++++ .../amd/accelerate-amd-matmul-wmma-gen1.mlir | 4 +- .../amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt | 1 + .../lib/TritonAMDGPUToLLVM/DotOpToLLVM.cpp | 6 +- .../TritonAMDGPUToLLVM/DotOpToLLVM/FMA.cpp | 192 ++++++++++++++++++ .../AccelerateAMDMatmul.cpp | 185 ++++++++++++----- 10 files changed, 495 insertions(+), 56 deletions(-) create mode 100644 test/TritonGPU/amd/accelerate-amd-matmul-fma.mlir create mode 100644 third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/FMA.cpp diff --git a/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/include/triton/Conversion/TritonGPUToLLVM/Utility.h index 6b17387e80db..231fdf5929fb 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/Utility.h +++ b/include/triton/Conversion/TritonGPUToLLVM/Utility.h @@ -102,6 +102,8 @@ using namespace mlir::triton; #define undef(...) rewriter.create(loc, __VA_ARGS__) #define null(...) rewriter.create(loc, __VA_ARGS__) #define call(...) rewriter.create(loc, __VA_ARGS__) +#define call_intrinsic(...) 
\ + rewriter.create(loc, __VA_ARGS__) // Types #define int_ty(width) rewriter.getIntegerType(width) diff --git a/include/triton/Dialect/TritonGPU/Transforms/Utility.h b/include/triton/Dialect/TritonGPU/Transforms/Utility.h index 98fae2326b42..bf0c6dc2f8a0 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/Utility.h +++ b/include/triton/Dialect/TritonGPU/Transforms/Utility.h @@ -180,6 +180,9 @@ bool isPureUnaryInlineAsm(Operation *op); // read the compute capability from the module attributes int getNVIDIAComputeCapability(Operation *module); +// read the amd target from the module attributes +StringRef getAMDArch(Operation *module); + } // namespace mlir #endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_ diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp index 57e41e55ff4f..1cf4fdb0b133 100644 --- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp @@ -855,6 +855,21 @@ int getNVIDIAComputeCapability(Operation *module) { return computeCapability; } +StringRef getAMDArch(Operation *module) { + assert(module->hasAttr(triton::AttrTargetName) && + "Expected a target attribute on the module operation"); + + StringAttr targetAttr = + cast(module->getAttr(triton::AttrTargetName)); + + StringRef ref = targetAttr.strref(); + assert(ref.starts_with("hip:") && + "expected target attribute to be prefixed with \"cuda:\""); + + StringRef archStr = ref.drop_front(4); // drop the "hip:" + return archStr; +} + namespace { /// Detect dead arguments in scf.for op by assuming all the values are dead and diff --git a/test/Conversion/amd/tritongpu_to_llvm.mlir b/test/Conversion/amd/tritongpu_to_llvm.mlir index 65c0681368f3..acac50d4ae6f 100644 --- a/test/Conversion/amd/tritongpu_to_llvm.mlir +++ b/test/Conversion/amd/tritongpu_to_llvm.mlir @@ -34,3 +34,39 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : tt.return } } + +// ----- + +// CHECK-LABEL: v_dot_i8 +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 8], warpsPerCTA = [2, 2], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"triton_gpu.target" = "hip:gfx942", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func @v_dot_i8(%arg0: tensor<16x16xi8, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>, %arg1: tensor<16x16xi8, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>, %arg2: tensor<16x16xi32, #blocked>) { + // CHECK-4: llvm.call_intrinsic "llvm.amdgcn.sdot4" + %0 = tt.dot %arg0, %arg1, %arg2, inputPrecision = ieee : tensor<16x16xi8, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<16x16xi8, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<16x16xi32, #blocked> + tt.return + } +} + +// ----- + +// CHECK-LABEL: v_dot_fp16 +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 8], warpsPerCTA = [2, 2], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"triton_gpu.target" = "hip:gfx942", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func @v_dot_fp16(%arg0: tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>, %arg1: tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>, %arg2: tensor<16x16xf32, #blocked>) { + // CHECK-8: llvm.call_intrinsic "llvm.amdgcn.fdot2" + %0 = tt.dot 
%arg0, %arg1, %arg2, inputPrecision = ieee : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<16x16xf32, #blocked> + tt.return + } +} + +// ----- + +// CHECK-LABEL: v_dot_fp16_fp16 +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 8], warpsPerCTA = [2, 2], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"triton_gpu.target" = "hip:gfx942", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func @v_dot_fp16_fp16(%arg0: tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>, %arg1: tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>, %arg2: tensor<16x16xf16, #blocked>) { + // CHECK-4: llvm.call_intrinsic "llvm.amdgcn.sdot4" + %0 = tt.dot %arg0, %arg1, %arg2, inputPrecision = ieee : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<16x16xf16, #blocked> + tt.return + } +} diff --git a/test/TritonGPU/amd/accelerate-amd-matmul-fma.mlir b/test/TritonGPU/amd/accelerate-amd-matmul-fma.mlir new file mode 100644 index 000000000000..15324f19c45e --- /dev/null +++ b/test/TritonGPU/amd/accelerate-amd-matmul-fma.mlir @@ -0,0 +1,107 @@ +// RUN: triton-opt %s -split-input-file --tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942' | FileCheck %s + +// CHECK: fma_dot_fp16_fp16 +// CHECK: tt.dot {{.*}} : tensor<2x64xf16, {{.*}}> * tensor<64x64xf16, {{.*}}> -> tensor<2x64xf16, {{.*}}> +#blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [8, 8], warpsPerCTA = [2, 2], order = [1, 0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx942", "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func public @fma_dot_fp16_fp16( + %arg0: tensor<2x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>, + %arg1: tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>, + %arg2: tensor<2x64x!tt.ptr, #blocked> ) { + %cst = arith.constant dense<0.0> : tensor<2x64xf16, #blocked> + %1 = tt.dot %arg0, %arg1, %cst : tensor<2x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<2x64xf16, #blocked> + tt.store %arg2, %1 : tensor<2x64x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + +// CHECK: fma_dot_fp32_fp32 +// CHECK: tt.dot {{.*}} : tensor<2x64xf32, {{.*}}> * tensor<64x64xf32, {{.*}}> -> tensor<2x64xf32, {{.*}}> +#blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [8, 8], warpsPerCTA = [2, 2], order = [1, 0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx942", "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func public @fma_dot_fp32_fp32( + %arg0: tensor<2x64xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>, + %arg1: tensor<64x64xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>, + %arg2: tensor<2x64x!tt.ptr, #blocked> ) { + %cst = arith.constant dense<0.0> : tensor<2x64xf32, #blocked> + %1 = tt.dot %arg0, %arg1, %cst : tensor<2x64xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x64xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<2x64xf32, #blocked> + tt.store %arg2, %1 : tensor<2x64x!tt.ptr, #blocked> + 
tt.return + } +} + +// ----- + +// CHECK: #[[BLOCKED:.*]] = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [8, 8], warpsPerCTA = [2, 2], order = [1, 0]}> +// CHECK: fma_dot_i8 +// CHECK: tt.dot {{.*}} : tensor<2x64xi8, #triton_gpu.dot_op<{opIdx = 0, parent = #[[BLOCKED]]}>> * tensor<64x64xi8, #triton_gpu.dot_op<{opIdx = 1, parent = #[[BLOCKED]]}>> -> tensor<2x64xi32, #[[BLOCKED]]> +#blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [8, 8], warpsPerCTA = [2, 2], order = [1, 0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx942", "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func public @fma_dot_i8( + %arg0: tensor<2x64xi8, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>, + %arg1: tensor<64x64xi8, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>, + %arg2: tensor<2x64x!tt.ptr, #blocked> ) { + %cst = arith.constant dense<0> : tensor<2x64xi32, #blocked> + %1 = tt.dot %arg0, %arg1, %cst : tensor<2x64xi8, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x64xi8, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<2x64xi32, #blocked> + tt.store %arg2, %1 : tensor<2x64x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + +// CHECK: #[[BLOCKED:.*]] = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [8, 8], warpsPerCTA = [2, 2], order = [1, 0]}> +// CHECK: fma_dot_f16 +// CHECK: tt.dot {{.*}} : tensor<2x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[BLOCKED]]}>> * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[BLOCKED]]}>> -> tensor<2x64xf32, #[[BLOCKED]]> +#blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [8, 8], warpsPerCTA = [2, 2], order = [1, 0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx942", "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func public @fma_dot_f16( + %arg0: tensor<2x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>, + %arg1: tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>, + %arg2: tensor<2x64x!tt.ptr, #blocked> ) { + %cst = arith.constant dense<0.0> : tensor<2x64xf32, #blocked> + %1 = tt.dot %arg0, %arg1, %cst : tensor<2x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<2x64xf32, #blocked> + tt.store %arg2, %1 : tensor<2x64x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + +// CHECK: #[[BLOCKED:.*]] = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [8, 8], warpsPerCTA = [2, 2], order = [1, 0]}> +// CHECK: fma_dot_f8 +// CHECK: tt.dot {{.*}} : tensor<2x64xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #[[BLOCKED]]}>> * tensor<64x64xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #[[BLOCKED]]}>> -> tensor<2x64xf32, #[[BLOCKED]]> +#blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [8, 8], warpsPerCTA = [2, 2], order = [1, 0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx942", "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func public @fma_dot_f8( + %arg0: tensor<2x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>, + %arg1: tensor<64x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>, + %arg2: tensor<2x64x!tt.ptr, #blocked> ) { + %cst = arith.constant dense<0.0> : tensor<2x64xf32, #blocked> + %1 = tt.dot %arg0, %arg1, %cst : tensor<2x64xf8E5M2, 
#triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<2x64xf32, #blocked> + tt.store %arg2, %1 : tensor<2x64x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + +// CHECK: fma_dot_i8_i8 +// CHECK-DAG: %[[A:.*]] = arith.sitofp +// CHECK-DAG: %[[B:.*]] = arith.sitofp +// CHECK: %[[D:.*]] = tt.dot %[[A]], %[[B]], {{.*}} : tensor<2x64xf16, {{.*}}> * tensor<64x64xf16, {{.*}}> -> tensor<2x64xf16, {{.*}}> +// CHECK: arith.fptosi %[[D]] +#blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [8, 8], warpsPerCTA = [2, 2], order = [1, 0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx942", "triton_gpu.threads-per-warp" = 64 : i32} { + tt.func public @fma_dot_i8_i8( + %arg0: tensor<2x64xi8, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>, + %arg1: tensor<64x64xi8, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>, + %arg2: tensor<2x64x!tt.ptr, #blocked> ) { + %cst = arith.constant dense<0> : tensor<2x64xi8, #blocked> + %1 = tt.dot %arg0, %arg1, %cst : tensor<2x64xi8, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x64xi8, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<2x64xi8, #blocked> + tt.store %arg2, %1 : tensor<2x64x!tt.ptr, #blocked> + tt.return + } +} diff --git a/test/TritonGPU/amd/accelerate-amd-matmul-wmma-gen1.mlir b/test/TritonGPU/amd/accelerate-amd-matmul-wmma-gen1.mlir index 7d3e8c23bed3..57e2ebdad568 100644 --- a/test/TritonGPU/amd/accelerate-amd-matmul-wmma-gen1.mlir +++ b/test/TritonGPU/amd/accelerate-amd-matmul-wmma-gen1.mlir @@ -133,14 +133,12 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-SAME: %[[DOT3_ARG_B:.+]]: tensor<64x32xi16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DOT_OP_PARENT]]}>> %1: tensor<64x32xi16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>, %2: tensor<128x32x!tt.ptr, #blocked>) { - // CHECK: %[[DOT3_ARG_C:.+]] = arith.constant dense<0> : tensor<128x32xi16, #[[DOT_OP_PARENT]]> + // CHECK: %[[DOT3_OP_C:.+]] = arith.constant dense<0.000000e+00> : tensor<128x32xf32, #[[DOT_OP_PARENT]]> %3 = arith.constant dense<0> : tensor<128x32xi16, #blocked> // CHECK: %[[DOT3_OP_A:.+]] = arith.sitofp %[[DOT3_ARG_A]] // CHECK-SAME: to tensor<128x64xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DOT_OP_PARENT]] // CHECK: %[[DOT3_OP_B:.+]] = arith.sitofp %[[DOT3_ARG_B]] // CHECK-SAME: to tensor<64x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DOT_OP_PARENT]] - // CHECK: %[[DOT3_OP_C:.+]] = arith.sitofp %[[DOT3_ARG_C]] - // CHECK-SAME: to tensor<128x32xf32, #[[DOT_OP_PARENT]] // CHECK: %[[DOT3_FMA_RES:.+]] = tt.dot %[[DOT3_OP_A]], %[[DOT3_OP_B]], %[[DOT3_OP_C]] // CHECK-SAME: -> tensor<128x32xf32, #[[DOT_OP_PARENT]]> %4 = tt.dot %0, %1, %3 : tensor<128x64xi16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x32xi16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x32xi16, #blocked> diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt b/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt index 705c4258d052..87799c51e494 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt @@ -3,6 +3,7 @@ add_triton_library(TritonAMDGPUToLLVM ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp ConvertLayoutOpToLLVM/SharedToDotOperandWMMA.cpp ConvertLayoutOpToLLVM.cpp + DotOpToLLVM/FMA.cpp DotOpToLLVM/MFMA.cpp 
diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt b/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt
index 705c4258d052..87799c51e494 100644
--- a/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt
+++ b/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt
@@ -3,6 +3,7 @@ add_triton_library(TritonAMDGPUToLLVM
     ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp
     ConvertLayoutOpToLLVM/SharedToDotOperandWMMA.cpp
     ConvertLayoutOpToLLVM.cpp
+    DotOpToLLVM/FMA.cpp
     DotOpToLLVM/MFMA.cpp
     DotOpToLLVM/WMMA.cpp
     DotOpToLLVM.cpp
diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM.cpp
index 15237282172f..bb22b212c5e3 100644
--- a/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM.cpp
+++ b/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM.cpp
@@ -7,6 +7,10 @@ using ::mlir::triton::gpu::AMDWmmaEncodingAttr;
 using ::mlir::triton::gpu::getShapePerCTA;
 
 namespace mlir::triton::AMD {
+LogicalResult convertAMDFMADot(triton::DotOp op, triton::DotOp::Adaptor adaptor,
+                               const LLVMTypeConverter *typeConverter,
+                               ConversionPatternRewriter &rewriter);
+
 LogicalResult convertMFMA(triton::DotOp op, triton::DotOp::Adaptor adaptor,
                           const LLVMTypeConverter *typeConverter,
                           ConversionPatternRewriter &rewriter);
@@ -46,7 +50,7 @@ struct DotOpConversion : public ConvertOpToLLVMPattern<triton::DotOp> {
     if (isa<BlockedEncodingAttr>(
             cast<RankedTensorType>(D.getType()).getEncoding()))
-      return convertFMADot(op, adaptor, getTypeConverter(), rewriter);
+      return AMD::convertAMDFMADot(op, adaptor, getTypeConverter(), rewriter);
 
     llvm::report_fatal_error(
         "Unsupported DotOp found when converting TritonGPU to LLVM.");
diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/FMA.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/FMA.cpp
new file mode 100644
index 000000000000..cb9b010968da
--- /dev/null
+++ b/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/FMA.cpp
@@ -0,0 +1,192 @@
+#include "mlir/Support/LLVM.h"
+#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
+#include "triton/Dialect/TritonGPU/IR/Dialect.h"
+#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
+
+using namespace mlir;
+using namespace mlir::triton;
+using namespace ::mlir::triton::gpu;
+
+using ::mlir::triton::gpu::expandMatrixOrderWithBatch;
+using ::mlir::triton::gpu::expandMatrixShapeWithBatch;
+using ::mlir::triton::gpu::getShapePerCTA;
+using ::mlir::triton::gpu::getSizePerThread;
+
+namespace {
+
+using ValueTableFMA = std::map<std::tuple<unsigned, unsigned, unsigned>, Value>;
+
+static ValueTableFMA
+getValueTableFromStructFMA(Value val, int batch, int nonK, int K,
+                           ConversionPatternRewriter &rewriter, Location loc) {
+  ValueTableFMA res;
+  auto elems = unpackLLElements(loc, val, rewriter);
+  assert(elems.size() == K * nonK * batch);
+  int index = 0;
+  for (unsigned b = 0; b < batch; ++b)
+    for (unsigned k = 0; k < K; ++k)
+      for (unsigned i = 0; i < nonK; ++i)
+        res[{b, i, k}] = elems[index++];
+  return res;
+}
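+
+// For illustration: the packed LLVM struct is ordered with batch outermost,
+// then K, then the non-K dimension (matching the loop nest above), while the
+// table is keyed as {batch, nonK, k}. E.g. with batch = 1, nonK = 2, K = 2
+// the unpacked elements land as:
+//   elems[0] -> res[{0, 0, 0}], elems[1] -> res[{0, 1, 0}],
+//   elems[2] -> res[{0, 0, 1}], elems[3] -> res[{0, 1, 1}].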
+
+struct DotIntrinsic {
+  int vectorSize;
+  Type outElemTy;
+  StringRef intrinsicName;
+  SmallVector<Value> additionalArgs;
+};
+
+DotIntrinsic chooseIntrinsic(ConversionPatternRewriter &rewriter, Location loc,
+                             triton::DotOp op) {
+  auto aOpTy = cast<RankedTensorType>(op.getA().getType());
+  auto aElemTy = aOpTy.getElementType();
+  auto dOpTy = cast<RankedTensorType>(op.getD().getType());
+  auto dElemTy = dOpTy.getElementType();
+  auto mod = op->getParentOfType<ModuleOp>();
+  auto arch = getAMDArch(mod);
+  DotIntrinsic chosenOp;
+  bool dotAvailable = arch == "gfx908" || arch == "gfx90a" ||
+                      arch.starts_with("gfx94") || arch.starts_with("gfx11") ||
+                      arch.starts_with("gfx103");
+  if (dotAvailable) {
+    if (aElemTy.isF16() && dElemTy.isF32()) {
+      chosenOp.vectorSize = 2;
+      chosenOp.outElemTy = f32_ty;
+      chosenOp.intrinsicName = "llvm.amdgcn.fdot2";
+      chosenOp.additionalArgs = {false_val()};
+      return chosenOp;
+    }
+    if (aElemTy.isInteger(8) && dElemTy.isInteger(32)) {
+      chosenOp.vectorSize = 4;
+      chosenOp.outElemTy = i32_ty;
+      chosenOp.intrinsicName = "llvm.amdgcn.sdot4";
+      chosenOp.additionalArgs = {false_val()};
+      return chosenOp;
+    }
+  }
+  // choose one of FMA intrinsics
+  assert(aElemTy.isIntOrFloat() && !aElemTy.isIntOrIndex());
+  assert(aElemTy == dElemTy);
+  assert(cast<RankedTensorType>(op.getA().getType()).getElementType() ==
+         dElemTy);
+  chosenOp.vectorSize = 1;
+  chosenOp.outElemTy = aElemTy;
+  if (aElemTy.isF32())
+    chosenOp.intrinsicName = "llvm.fmuladd.f32";
+  if (aElemTy.isF16())
+    chosenOp.intrinsicName = "llvm.fmuladd.f16";
+  chosenOp.additionalArgs = {};
+  return chosenOp;
+}
+
+Value packOperand(ConversionPatternRewriter &rewriter, Location loc,
+                  ValueTableFMA scalarValues, unsigned b, unsigned nonK,
+                  unsigned k, unsigned vectorSize) {
+  if (vectorSize == 1)
+    return scalarValues[{b, nonK, k}];
+  auto elemTy = scalarValues[{b, nonK, k}].getType();
+  auto vecTy = vec_ty(elemTy, vectorSize);
+  Value vec = undef(vecTy);
+  for (int elem = 0; elem < vectorSize; ++elem) {
+    vec = insert_element(vecTy, vec, scalarValues[{b, nonK, k + elem}],
+                         i32_val(elem));
+  }
+  if (elemTy.isInteger(8)) {
+    assert(vectorSize == 4);
+    vec = bitcast(vec, i32_ty);
+  }
+  return vec;
+}
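+
+// For illustration: on the sdot4 path four i8 lanes are packed into a single
+// i32 operand, so one k step covers four multiply-accumulates, roughly
+//   d = llvm.amdgcn.sdot4(bitcast<i32>(a[k..k+3]), bitcast<i32>(b[k..k+3]), d, /*clamp=*/false)
+// The fdot2 path keeps a <2 x half> vector, and vectorSize == 1 falls back to
+// a scalar llvm.fmuladd.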
+
+Value generateDotOp(ConversionPatternRewriter &rewriter, Location loc,
+                    DotIntrinsic op, Value a, Value b, Value c) {
+  SmallVector<Value> args{a, b, c};
+  args.append(op.additionalArgs.begin(), op.additionalArgs.end());
+  SmallVector<Type> argTypes;
+  for (auto arg : args)
+    argTypes.push_back(arg.getType());
+  auto funcType = LLVM::LLVMFunctionType::get(op.outElemTy, argTypes);
+  auto d = call_intrinsic(op.outElemTy, op.intrinsicName, args);
+  return d.getResult(0);
+}
+
+} // namespace
+
+namespace mlir::triton::AMD {
+
+LogicalResult convertAMDFMADot(triton::DotOp op, triton::DotOp::Adaptor adaptor,
+                               const LLVMTypeConverter *typeConverter,
+                               ConversionPatternRewriter &rewriter) {
+  auto *ctx = rewriter.getContext();
+  auto loc = op.getLoc();
+
+  auto A = op.getA();
+  auto B = op.getB();
+  auto C = op.getC();
+  auto D = op.getResult();
+
+  auto aTensorTy = cast<RankedTensorType>(A.getType());
+  auto dTensorTy = cast<RankedTensorType>(D.getType());
+  auto dElemTy = dTensorTy.getElementType();
+
+  SmallVector<int64_t> aShapePerCTA =
+      expandMatrixShapeWithBatch(ArrayRef(getShapePerCTA(aTensorTy)));
+  auto dShapePerCTA =
+      expandMatrixShapeWithBatch(ArrayRef(getShapePerCTA(dTensorTy)));
+
+  BlockedEncodingAttr dLayout =
+      cast<BlockedEncodingAttr>(dTensorTy.getEncoding());
+  auto order = expandMatrixOrderWithBatch(dLayout.getOrder());
+  auto cc = unpackLLElements(loc, adaptor.getC(), rewriter);
+
+  Value llA = adaptor.getA();
+  Value llB = adaptor.getB();
+
+  auto sizePerThread =
+      expandMatrixShapeWithBatch(ArrayRef(getSizePerThread(dLayout)));
+  auto shapePerCTATile =
+      expandMatrixShapeWithBatch(ArrayRef(getShapePerCTATile(dLayout)));
+
+  int K = aShapePerCTA[2];
+
+  unsigned retSize[3];
+  for (int i = 0; i < 3; ++i) {
+    unsigned numRep = dShapePerCTA[i] / shapePerCTATile[i];
+    numRep = std::max(static_cast<unsigned>(1), numRep);
+    retSize[i] = numRep * sizePerThread[i];
+  }
+
+  auto has =
+      getValueTableFromStructFMA(llA, retSize[0], retSize[1], K, rewriter, loc);
+  auto hbs =
+      getValueTableFromStructFMA(llB, retSize[0], retSize[2], K, rewriter, loc);
+
+  SmallVector<Value> ret = cc;
+  auto selectedOp = chooseIntrinsic(rewriter, loc, op);
+
+  for (unsigned b = 0; b < retSize[0]; ++b)
+    for (unsigned m = 0; m < retSize[1]; ++m)
+      for (unsigned n = 0; n < retSize[2]; ++n) {
+        unsigned idx[] = {b, m, n};
+        unsigned linearIdx = 0;
+        for (auto dim : llvm::reverse(order)) {
+          linearIdx = linearIdx * retSize[dim] + idx[dim];
+        }
+        for (unsigned k = 0; k < K; k += selectedOp.vectorSize) {
+          auto aOp =
+              packOperand(rewriter, loc, has, b, m, k, selectedOp.vectorSize);
+          auto bOp =
+              packOperand(rewriter, loc, hbs, b, n, k, selectedOp.vectorSize);
+          ret[linearIdx] = generateDotOp(rewriter, loc, selectedOp, aOp, bOp,
+                                         ret[linearIdx]);
+        }
+      }
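+
+  // For illustration: ret starts as the unpacked C elements, each k step
+  // above folds one fmuladd/fdot2/sdot4 into that accumulator, and linearIdx
+  // linearizes {b, m, n} following the register order of the blocked layout
+  // so the accumulators can be repacked below in the same order as C.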
+
+  auto res = packLLElements(loc, typeConverter, ret, rewriter, dTensorTy);
+  rewriter.replaceOp(op, res);
+
+  return success();
+}
+
+} // namespace mlir::triton::AMD
diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp
index bf976a8138dc..7143fcb6602e 100644
--- a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp
+++ b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp
@@ -521,58 +521,8 @@ static void decomposeMixedModeDotOp(ModuleOp mod) {
                       ? AElType
                       : BElType;
   } else {
-    // FMA case.
-    Type AElType = dotOp.getA().getType().getElementType();
-    Type DElType = D.getType().getElementType();
-
-    // Convert int operands to FP32 to apply FMA case
-    // Do it here instead of introducing new pattern because the pass is more
-    // about MMA dots.
-    // TODO: Introduce new pass for FMA dots legalization.
-    if (AElType.isIntOrIndex()) {
-      assert(dotOp.getB().getType().getElementType().isIntOrIndex() &&
-             dotOp.getC().getType().getElementType().isIntOrIndex() &&
-             DElType.isIntOrIndex());
-      auto convertTensorIToFP = [&](Value v) -> Value {
-        RankedTensorType vTy = cast<RankedTensorType>(v.getType());
-        Type dstType = vTy.cloneWith(std::nullopt, builder.getF32Type());
-        Type srcElType = vTy.getElementType();
-        return !srcElType.isUnsignedInteger()
-                   ? builder
-                         .create<arith::SIToFPOp>(dotOp.getLoc(), dstType, v)
-                         .getResult()
-                   : builder
-                         .create<arith::UIToFPOp>(dotOp.getLoc(), dstType, v)
-                         .getResult();
-      };
-      auto convertTensorFPToI = [&](Type dstElType, Value v) -> Value {
-        RankedTensorType vTy = cast<RankedTensorType>(v.getType());
-        Type dstType = vTy.cloneWith(std::nullopt, dstElType);
-        return !dstElType.isUnsignedInteger()
-                   ? builder
-                         .create<arith::FPToSIOp>(dotOp.getLoc(), dstType, v)
-                         .getResult()
-                   : builder
-                         .create<arith::FPToUIOp>(dotOp.getLoc(), dstType, v)
-                         .getResult();
-      };
-
-      auto newAOperand = convertTensorIToFP(dotOp.getA());
-      auto newBOperand = convertTensorIToFP(dotOp.getB());
-      auto newCOperand = convertTensorIToFP(dotOp.getC());
-      auto newDot = builder.create<tt::DotOp>(
-          dotOp.getLoc(), newCOperand.getType(), newAOperand, newBOperand,
-          newCOperand, dotOp.getInputPrecision(),
-          dotOp.getMaxNumImpreciseAcc());
-      auto newD = convertTensorFPToI(DElType, newDot.getResult());
-      D.replaceAllUsesWith(newD);
-      dotOp.erase();
-      return;
-    }
-
-    if (AElType == DElType)
-      return;
-    promoteType = DElType;
+    // FMA case is processed in AccelerateBlocked
+    return;
   }
   Location loc = dotOp.getLoc();
   Value promotedA = promoteOperand(builder, loc, dotOp.getA(), promoteType);
@@ -669,6 +619,136 @@ class BlockedToWMMA : public RewritePattern {
     return success();
   }
 };
+
+class AccelerateBlocked : public mlir::RewritePattern {
+  StringRef arch;
+
+public:
+  AccelerateBlocked(mlir::MLIRContext *context, StringRef arch)
+      : mlir::RewritePattern(tt::DotOp::getOperationName(), 1, context),
+        arch(arch) {}
+
+  bool isFloat(Type t) const { return t.isIntOrFloat() && !t.isIntOrIndex(); }
+
+  Value castToElTy(mlir::PatternRewriter &rewriter, Value v, Type elTy) const {
+    Location loc = v.getLoc();
+    auto srcTy = cast<RankedTensorType>(v.getType());
+    auto dstTy = srcTy.cloneWith(std::nullopt, elTy);
+    if (srcTy == dstTy)
+      return v;
+    auto srcElTy = srcTy.getElementType();
+    auto dstElTy = dstTy.getElementType();
+    if (isFloat(srcElTy) && isFloat(dstElTy))
+      return rewriter.create(loc, dstTy, v);
+    if (!isFloat(srcElTy) && isFloat(dstElTy))
+      return rewriter.create(loc, dstTy, v);
+    if (isFloat(srcElTy) && !isFloat(dstElTy))
+      return rewriter.create(loc, dstTy, v);
+    assert(false && "int -> int cast is unexpected in FMA legalization");
+    return Value();
+  }
+
+  bool legalFMAForm(tt::DotOp dotOp) const {
+    auto expectedElTy = dotOp.getA().getType().getElementType();
+    for (auto operand : dotOp.getOperands()) {
+      auto opTy = cast<RankedTensorType>(operand.getType());
+      auto elTy = opTy.getElementType();
+      if (elTy != expectedElTy)
+        return false;
+      if (!elTy.isF16() && !elTy.isF32())
+        return false;
+    }
+    return true;
+  }
+
+  mlir::LogicalResult
+  matchAndRewrite(mlir::Operation *op,
+                  mlir::PatternRewriter &rewriter) const override {
+    auto dotOp = cast<tt::DotOp>(op);
+    auto a = dotOp.getA();
+    auto b = dotOp.getB();
+    auto c = dotOp.getC();
+    auto d = dotOp.getD();
+
+    if (!llvm::isa<ttg::BlockedEncodingAttr>(d.getType().getEncoding()))
+      return failure();
+
+    Type aElTy = a.getType().getElementType();
+    Type bElTy = b.getType().getElementType();
+    Type cElTy = c.getType().getElementType();
+    Type dElTy = d.getType().getElementType();
+
+    int rank = a.getType().getShape().size();
+    int k = a.getType().getShape()[rank - 1];
+
+    bool dotAvailable = arch == "gfx908" || arch == "gfx90a" ||
+                        arch.starts_with("gfx94") ||
+                        arch.starts_with("gfx11") || arch.starts_with("gfx103");
+
+    // Try Fp16 x Fp16 -> Fp32 dot
+    if (dotAvailable && aElTy.isF16() && bElTy.isF16() && cElTy.isF32() &&
+        dElTy.isF32()) {
+      if (k % 2 == 0) {
+        // nothing to do for this dot
+        return failure();
+      }
+      // if k % 2 != 0: can not use DOT instruction, continue with FMA
+    }
+    // Try I8 x I8 -> I32 dot
+    if (dotAvailable && aElTy.isInteger(8) && bElTy.isInteger(8) &&
+        cElTy.isInteger(32) && dElTy.isInteger(32)) {
+      if (k % 4 == 0) {
+        // nothing to do for this dot
+        return failure();
+      }
+      // if k % 4 != 0: can not use DOT instruction, continue with FMA
+    }
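+
+    // For illustration: on a dot-capable arch an f16 x f16 -> f32 dot with
+    // even K and an i8 x i8 -> i32 dot with K % 4 == 0 are left untouched and
+    // later lower to fdot2/sdot4; everything else is promoted below to a
+    // common f16/f32 element type. E.g. an i8 x i8 -> i8 dot becomes
+    // sitofp(a), sitofp(b), an f16 dot, and an fptosi back to i8.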
+
+    // check that dot is not legalized already
+    if (legalFMAForm(dotOp)) {
+      return failure();
+    }
+
+    // Legalize dot for FMA case
+
+    // find common type, larger or equal of all operand types
+    SmallVector<Type> opElTy{aElTy, bElTy, cElTy, dElTy};
+    unsigned maxBitsize = 8;
+    for (auto elTy : opElTy)
+      maxBitsize = std::max(maxBitsize, elTy.getIntOrFloatBitWidth());
+    assert(maxBitsize <= 32);
+    Type commonTy =
+        maxBitsize <= 16 ? rewriter.getF16Type() : rewriter.getF32Type();
+
+    // check that type is compatible with all operands
+    // fallback to fp32 if not
+    if (commonTy.isF16()) {
+      for (auto elTy : opElTy) {
+        if (elTy.isInteger() && elTy.getIntOrFloatBitWidth() > 8) {
+          commonTy = rewriter.getF32Type();
+          break;
+        }
+        if (elTy.isBF16()) {
+          commonTy = rewriter.getF32Type();
+          break;
+        }
+      }
+    }
+
+    auto newA = castToElTy(rewriter, a, commonTy);
+    auto newB = castToElTy(rewriter, b, commonTy);
+    auto newC = castToElTy(rewriter, c, commonTy);
+
+    auto newDot = rewriter.create<tt::DotOp>(
+        dotOp.getLoc(), newC.getType(), newA, newB, newC,
+        dotOp.getInputPrecision(), dotOp.getMaxNumImpreciseAcc());
+    auto newD = castToElTy(rewriter, newDot.getResult(), dElTy);
+    d.replaceAllUsesWith(newD);
+    dotOp.erase();
+    return success();
+  }
+};
+
 } // namespace
 
 #define GEN_PASS_CLASSES
@@ -705,6 +785,7 @@ class TritonAMDGPUAccelerateMatmulPass
     default:
       break;
     }
+    patterns.add<AccelerateBlocked>(context, archGenerationName);
     if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) {
       signalPassFailure();
     }