From 4bcd309904f3d5e58ca01d616f7b1061fa9c2893 Mon Sep 17 00:00:00 2001 From: Jingning Tang Date: Thu, 14 Nov 2024 19:07:19 +0000 Subject: [PATCH 1/7] introduce AMD inThreadTranspose for K-major dot operand --- bin/RegisterTritonDialects.h | 1 + .../TritonGPU/IR/LinearLayoutConversions.h | 8 +- .../TritonGPU/IR/LinearLayoutConversions.cpp | 99 ++++++++- test/TritonGPU/amd/in-thread-transpose.mlir | 54 +++++ .../include/TritonAMDGPUTransforms/Passes.h | 2 + .../include/TritonAMDGPUTransforms/Passes.td | 15 ++ .../lib/TritonAMDGPUTransforms/CMakeLists.txt | 1 + .../inThreadTranspose.cpp | 207 ++++++++++++++++++ third_party/amd/python/triton_amd.cc | 2 + .../TritonGPU/LinearLayoutConversionsTest.cpp | 33 +++ 10 files changed, 420 insertions(+), 2 deletions(-) create mode 100644 test/TritonGPU/amd/in-thread-transpose.mlir create mode 100644 third_party/amd/lib/TritonAMDGPUTransforms/inThreadTranspose.cpp diff --git a/bin/RegisterTritonDialects.h b/bin/RegisterTritonDialects.h index e873965e479a..270fd95b198d 100644 --- a/bin/RegisterTritonDialects.h +++ b/bin/RegisterTritonDialects.h @@ -67,6 +67,7 @@ inline void registerTritonDialects(mlir::DialectRegistry ®istry) { mlir::registerTritonAMDGPUConvertToBufferOps(); mlir::triton::registerTritonAMDGPUInsertInstructionSchedHints(); mlir::triton::registerTritonAMDGPULowerInstructionSchedHints(); + mlir::registerTritonAMDGPUInThreadTranspose(); // TODO: register Triton & TritonGPU passes registry diff --git a/include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h b/include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h index 7c81b2496cdf..492bf3697123 100644 --- a/include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h +++ b/include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h @@ -37,12 +37,18 @@ namespace mlir::triton::gpu { // to compute the linear layout for MMAv3 (i.e. Hopper) shared layouts (i.e. // shared layouts with hasLeadingOffset == true) but is otherwise unused. // +// inThreadTranspose is a flag indicating if transpose should be performed while +// the data resides in thread-local registers. This is set to true on AMD +// platform when non-KContig matrix is about to be written into LDS (shared +// memory) but is otherwise unused. More details are provided in the +// transpose2D() function in LinearLayoutConversions.cpp. // Returns std::nullopt if the given layout can't be converted to an LL. // TODO(jlebar): Remove the std::optional once all layouts are supported. // std::optional toLinearLayout(ArrayRef shape, Attribute layout, - std::optional elemBitWidth = std::nullopt); + std::optional elemBitWidth = std::nullopt, + bool inThreadTranspose = false); // Given a linear layout where the input dimensions contain a "block" dimension, // this method sets the "block" dimension to 0 and removes the corresponding diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp index 32152190b6e6..fc93365e2ae8 100644 --- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp @@ -43,6 +43,60 @@ SmallVector permuteDimNames(const SmallVector &names, return ret; } +// For sizePerThread = [4, 8], the regular linear layout will express it as +// the following +// - register=1 -> (1, 0) +// register=2 -> (2, 0) +// register=4 -> (4, 0) +// register=8 -> (0, 1) +// register=16 -> (0, 2) +// where out dims are: [dim1 (size 8), dim0 (size 4)] +// If we take the binary form, it will be an identity matrix. 
If we traverse
+// from the dim of 4, it will be like the following
+// - register=1 -> (0, 1)
+//   register=2 -> (0, 2)
+//   register=4 -> (1, 0)
+//   register=8 -> (2, 0)
+//   register=16 -> (4, 0)
+// where out dims are: [dim1 (size 8), dim0 (size 4)]
+// Inside this function we only change how the register layout is generated:
+// it comes from the newly introduced transpose2D, while the lane and warp
+// layouts still come from identityStandardND.
+// Note that simply reversing the for-loop in identityStandardND will not work,
+// because it would change the most minor dimension from dim1 to dim0 and still
+// produce an identity matrix.
+LinearLayout transpose2D(StringAttr inDimName, ArrayRef shape,
+                         ArrayRef order) {
+  assert(shape.size() == order.size());
+  assert((order.size() == 2) && "only support dim of 2 now");
+
+  MLIRContext *ctx = inDimName.getContext();
+  StringAttr kRegister = S("register");
+
+  std::vector> bases;
+  // traverse 2nd dimension (K-dim in GEMM case)
+  int dim = order[1];
+  for (int basis = 1; basis < shape[dim]; basis <<= 1) {
+    bases.push_back({0, basis});
+  }
+  // traverse 1st dimension (N-dim for the GEMM non-KContig B-tensor);
+  // this is the dimension that is consecutive when loaded from global memory
+  dim = order[0];
+  for (int basis = 1; basis < shape[dim]; basis <<= 1) {
+    bases.push_back({basis, 0});
+  }
+
+  auto dimMinor = "dim" + std::to_string(order[0]);
+  auto dimMajor = "dim" + std::to_string(order[1]);
+  StringAttr kDimMinor = S(dimMinor);
+  StringAttr kDimMajor = S(dimMajor);
+  auto ret = LinearLayout(
+      {{kRegister, bases}},
+      {{kDimMinor, shape[order[0]]}, {kDimMajor, shape[order[1]]}}, false);
+
+  return ret;
+}
+
 // Make a LinearLayout that maps a block-id to an N-dimensional index.
 //
 // The tensor is split up into CTAsPerCGA pieces, which are distributed among
@@ -239,6 +293,45 @@ LinearLayout sharedToLinearLayoutLeadingOffset(ArrayRef shape,
   return combineCtaCgaWithShape(tileLayout, shared.getCTALayout(), shape);
 }
 
+// This function converts a BlockedEncodingAttr to a linear layout in a special
+// way. It accompanies the AMDGPUInThreadTranspose pass to transpose a
+// non-KContig tensor into KContig form prior to writing it into LDS (shared
+// memory). The conversion treats sizePerThread as a 2D matrix and uses a
+// different access pattern.
+//
+// For example, consider the following blocked layout generated by
+// AMDGPUInThreadTranspose: #blocked1 = #triton_gpu.blocked<{sizePerThread =
+// [4, 8], threadsPerWarp = [2, 32], warpsPerCTA = [8, 1], order = [1, 0]}>.
+// Since sizePerThread is 2D, there are two ways to traverse it: along the dim
+// of 8 or along the dim of 4. The regular toLinearLayout() would go through it
+// along the leading order, i.e. the dim of 8, but since we want to transpose
+// in-thread, we iterate along the 2nd order, i.e. the dim of 4, so that the 4
+// elements adjacent along K can be packed into a single vector. The AMD
+// backend LLVM compiler then places those elements in consecutive VGPRs and
+// writes data that is contiguous in the K dimension into LDS. In this way we
+// guarantee vectorized ds_read, and ds_write can be vectorized to 64-bit or
+// 32-bit depending on the block size and the number of warps.
+//
+// The function is named ThreadRake because each thread rakes through multiple
+// rows at the same time, as opposed to each warp raking through a cluster of
+// rows, or the Triton way, which iterates through every available warp and
+// then tiles that over the entire block.
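+//
+// As a concrete reference (it matches the inThreadTranspose_4x8 unit test
+// added in this patch): for shape [64, 256] with the #blocked1 layout above,
+// the register bases become [(1, 0), (2, 0), (0, 1), (0, 2), (0, 4)] over
+// (dim0, dim1), i.e. the first two bases step along dim0 (the K dim of the
+// B operand here) so that 4 K-contiguous elements sit next to each other in
+// registers, while the lane and warp bases are the same ones that
+// identityStandardND would produce.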
+LinearLayout blockedToLinearLayoutThreadRake(ArrayRef shape, + BlockedEncodingAttr blocked) { + MLIRContext *ctx = blocked.getContext(); + int rank = shape.size(); + auto outDimNames = standardOutDimNames(ctx, rank); + const auto &order = blocked.getOrder(); + auto sizePerThread = blocked.getSizePerThread(); + + auto ctaLayout = + transpose2D(S("register"), sizePerThread, order) * + identityStandardND(S("lane"), blocked.getThreadsPerWarp(), order) * + identityStandardND(S("warp"), blocked.getWarpsPerCTA(), order); + + return combineCtaCgaWithShape(ctaLayout, blocked.getCTALayout(), shape); +} + } // anonymous namespace std::optional @@ -755,9 +848,13 @@ SliceEncodingAttr::toLinearLayout(ArrayRef shape) const { std::optional toLinearLayout(ArrayRef shape, Attribute layout, - std::optional elemBitWidth /*= std::nullopt*/) { + std::optional elemBitWidth /*= std::nullopt*/, + bool inThreadTranspose /*= false*/) { // Layouts are distributed or shared if (auto distributed = dyn_cast(layout)) { + auto blocked = dyn_cast(distributed); + if (blocked && inThreadTranspose) + return blockedToLinearLayoutThreadRake(shape, blocked); return distributed.toLinearLayout(shape); } else if (auto shared = dyn_cast(layout)) { if (shared.getHasLeadingOffset()) { diff --git a/test/TritonGPU/amd/in-thread-transpose.mlir b/test/TritonGPU/amd/in-thread-transpose.mlir new file mode 100644 index 000000000000..747bfd6477f6 --- /dev/null +++ b/test/TritonGPU/amd/in-thread-transpose.mlir @@ -0,0 +1,54 @@ +// RUN: triton-opt %s -split-input-file -tritonamdgpu-in-thread-transpose | FileCheck %s + +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 32], warpsPerCTA = [8, 1], order = [1, 0]}> +#mma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "hip:gfx942", "triton_gpu.threads-per-warp" = 64 : i32} { + +// CHECK: [[threadrake_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [4, 8], threadsPerWarp = [2, 32], warpsPerCTA = [8, 1], order = [1, 0]}> +// CHECK: [[load_ptr:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x256x!tt.ptr, [[threadrake_layout]]> +// CHECK: {{.*}} = tt.load [[load_ptr]] : tensor<64x256x!tt.ptr, [[threadrake_layout]]> + tt.func public @threadRake_transpose_b(%arg0: tensor<256x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>, %arg1: tensor<64x256x!tt.ptr, #blocked1>) { + %cst_0 = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> + %1 = tt.load %arg1 : tensor<64x256x!tt.ptr, #blocked1> + %2 = triton_gpu.convert_layout %1 : tensor<64x256xf16, #blocked1> -> tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> + %6 = tt.dot %arg0, %2, %cst_0 : tensor<256x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<256x256xf32, #mma> + tt.return + } +} + +// ----- + +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0]}> +#mma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 2], instrShape = [32, 32], isTransposed = true}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "hip:gfx942", "triton_gpu.threads-per-warp" = 64 : i32} { + +// CHECK: [[threadrake_layout:#.*]] = 
#triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0]}> +// CHECK: [[load_ptr:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<32x128x!tt.ptr, [[threadrake_layout]]> +// CHECK: {{.*}} = tt.load [[load_ptr]] : tensor<32x128x!tt.ptr, [[threadrake_layout]]> + tt.func public @threadRake_transpose_b_no_change(%arg0: tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>, %arg1: tensor<32x128x!tt.ptr, #blocked1>) { + %cst_0 = arith.constant dense<0.000000e+00> : tensor<256x128xf32, #mma> + %1 = tt.load %arg1 : tensor<32x128x!tt.ptr, #blocked1> + %2 = triton_gpu.convert_layout %1 : tensor<32x128xf16, #blocked1> -> tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> + %6 = tt.dot %arg0, %2, %cst_0 : tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<256x128xf32, #mma> + tt.return + } +} + + +// ----- +#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [32, 2], warpsPerCTA = [1, 8], order = [0, 1]}> +#mma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "hip:gfx942", "triton_gpu.threads-per-warp" = 64 : i32} { + +// CHECK-NOT: {{.*}} = triton_gpu.convert_layout {{.*blocked.*}} -> {{.*blocked.*}} + tt.func public @threadRake_no_transpose(%arg0: tensor<256x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>, %arg1: tensor<64x256x!tt.ptr, #blocked1>) { + %cst_0 = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> + %1 = tt.load %arg1 : tensor<64x256x!tt.ptr, #blocked1> + %2 = triton_gpu.convert_layout %1 : tensor<64x256xf16, #blocked1> -> tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> + %6 = tt.dot %arg0, %2, %cst_0 : tensor<256x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<256x256xf32, #mma> + tt.return + } +} diff --git a/third_party/amd/include/TritonAMDGPUTransforms/Passes.h b/third_party/amd/include/TritonAMDGPUTransforms/Passes.h index 630a1e903562..894e50322f1a 100644 --- a/third_party/amd/include/TritonAMDGPUTransforms/Passes.h +++ b/third_party/amd/include/TritonAMDGPUTransforms/Passes.h @@ -27,6 +27,8 @@ std::unique_ptr createTritonAMDGPUCanonicalizePointersPass(); std::unique_ptr createTritonAMDGPUConvertToBufferOpsPass(); +std::unique_ptr createTritonAMDGPUInThreadTransposePass(); + /// Generate the code for registering passes. 
#define GEN_PASS_REGISTRATION
 #include "TritonAMDGPUTransforms/Passes.h.inc"
diff --git a/third_party/amd/include/TritonAMDGPUTransforms/Passes.td b/third_party/amd/include/TritonAMDGPUTransforms/Passes.td
index 6bee6da5fb45..867e2208b78f 100644
--- a/third_party/amd/include/TritonAMDGPUTransforms/Passes.td
+++ b/third_party/amd/include/TritonAMDGPUTransforms/Passes.td
@@ -124,4 +124,19 @@ def TritonAMDGPUConvertToBufferOps : Pass<"tritonamdgpu-convert-buffer-ops", "ml
   let dependentDialects = ["mlir::triton::amdgpu::TritonAMDGPUDialect"];
 }
 
+def TritonAMDGPUInThreadTranspose: Pass<"tritonamdgpu-in-thread-transpose", "mlir::ModuleOp"> {
+  let summary = "Transpose the K-outer dot operand while it is still in registers, right before writing to LDS";
+
+  let description = [{
+    Transpose a non-KContig dot operand (i.e. one that is not contiguous along the K dimension) right before its
+    data is written into LDS. The transpose happens right after the data has been loaded from global memory into
+    thread-local registers; it promotes (but does not guarantee) vectorized LDS writes while letting
+    SharedEncodingAttr guarantee vectorized LDS reads, at the cost of a few VALU instructions for the in-thread transpose.
+  }];
+
+  let constructor = "mlir::createTritonAMDGPUInThreadTransposePass()";
+
+  let dependentDialects = [];
+}
+
 #endif
diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt b/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt
index aef5886b11d8..88cb0f68b1ef 100644
--- a/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt
+++ b/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt
@@ -6,6 +6,7 @@ add_triton_library(TritonAMDGPUTransforms
   ReorderInstructions.cpp
   StreamPipeline.cpp
   MfmaGroup.cpp
+  inThreadTranspose.cpp
 
   DEPENDS
   TritonAMDGPUIR
diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/inThreadTranspose.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/inThreadTranspose.cpp
new file mode 100644
index 000000000000..a15d81f4caa0
--- /dev/null
+++ b/third_party/amd/lib/TritonAMDGPUTransforms/inThreadTranspose.cpp
@@ -0,0 +1,207 @@
+#include "TritonAMDGPUTransforms/Passes.h"
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "triton/Dialect/Triton/IR/Utility.h"
+#include "triton/Dialect/TritonGPU/IR/Dialect.h"
+#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
+#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "tritonamdgpu-in-thread-transpose"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
+
+#define GEN_PASS_CLASSES
+#include "TritonAMDGPUTransforms/Passes.h.inc"
+
+using namespace mlir;
+namespace tt = mlir::triton;
+namespace ttg = mlir::triton::gpu;
+
+static Type getNewType(Type type, Attribute encoding) {
+  RankedTensorType tensorType = dyn_cast(type);
+  return RankedTensorType::get(tensorType.getShape(),
+                               tensorType.getElementType(), encoding);
+}
+
+// This function is mostly copied over from coalesce.cpp since it uses almost
+// the same functionality.
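+// Given the new `encoding`, it wraps every tensor operand of `op` in a
+// convert_layout to that encoding, recreates `op` with result types that use
+// the new encoding, converts the results back to their original layouts, and
+// finally replaces all uses of the old op before erasing it.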
+void convertLayout(Attribute encoding, Operation *op) { + OpBuilder builder(op); + SmallVector newArgs; + for (auto operand : op->getOperands()) { + auto tensorType = dyn_cast(operand.getType()); + if (tensorType) { + Type newType = getNewType(tensorType, encoding); + newArgs.push_back(builder.create( + op->getLoc(), newType, operand)); + } else { + newArgs.push_back(operand); + } + } + + // Convert output types + SmallVector newTypes; + for (auto t : op->getResultTypes()) { + newTypes.push_back(getNewType(t, encoding)); + } + + // Construct new op with the new encoding + Operation *newOp = builder.create(op->getLoc(), op->getName().getIdentifier(), + newArgs, newTypes, op->getAttrs()); + + // Cast the results back to the original layout + for (size_t i = 0; i < op->getNumResults(); i++) { + Value newResult = newOp->getResult(i); + if (newTypes[i] != op->getResultTypes()[i]) { + newResult = builder.create( + op->getLoc(), op->getResult(i).getType(), newResult); + } + op->getResult(i).replaceAllUsesWith(newResult); + } + op->erase(); +} + +SmallVector getLoadInsts(Operation *op) { + SmallVector ret; + auto v = op->getOperand(0); + auto prevOp = v.getDefiningOp(); + if (isa(prevOp)) { + // Deal with the case that convert_layout intakes from scf.if, etc. + LDBG("Dealing with scf blocks"); + auto idx = cast(v).getResultNumber(); + llvm::SmallVector yieldOps; + prevOp->walk([&](Operation *op) { + if (auto yieldOp = dyn_cast(op)) { + yieldOps.push_back(yieldOp); + } + }); + + for (auto yieldOp : yieldOps) { + auto maybeLoadOp = yieldOp.getOperand(idx).getDefiningOp(); + if (isa(maybeLoadOp)) + ret.push_back(maybeLoadOp); + } + } else if (isa(prevOp)) { + // regular case + LDBG("Regular cases"); + ret.push_back(prevOp); + } else { + // can't find any loadOp + LDBG("we assume load->convert_layout->dot chain but we cannot find it."); + } + return ret; +} + +bool needCvtToThreadRaked(Value operand) { + auto opTensorTy = cast(operand.getType()); + auto opEnc = opTensorTy.getEncoding(); + auto opDotOpEnc = dyn_cast(opEnc); + // dotOperand has to have dotOp and MFMA encoding + if (!opDotOpEnc) + return false; + if (!isa(opDotOpEnc.getParent())) { + LDBG("Operand's parent encoding is not MFMA"); + return false; + } + auto cvtOp = operand.getDefiningOp(); + // make sure the previous op is convert_layout + if (!cvtOp || !isa(cvtOp)) + return false; + auto cvtOperand = cvtOp->getOperand(0); + auto cvtOperandEnc = + cast(cvtOperand.getType()).getEncoding(); + auto blockedEnc = dyn_cast(cvtOperandEnc); + // make sure it is converted from blocked layout + if (!blockedEnc) + return false; + // check whether it's contiguous on K dimension + int kDimNum = opDotOpEnc.getOpIdx() == 0 ? 1 : 0; + auto order = blockedEnc.getOrder(); + if (order[0] != kDimNum) { + return true; + } + + return false; +} + +ttg::BlockedEncodingAttr getThreadRakedBlockedEnc(Value operand, + ModuleOp &mod) { + // get the K dim according to dotOp operand's index + auto tensorTy = cast(operand.getType()); + auto shape = tensorTy.getShape(); + auto opEnc = tensorTy.getEncoding(); + auto opDotOpEnc = dyn_cast(opEnc); + int kDimNum = opDotOpEnc.getOpIdx() == 0 ? 
1 : 0; + // get the current blocked encoding + auto cvtOperand = operand.getDefiningOp()->getOperand(0); + auto cvtOperandEnc = + cast(cvtOperand.getType()).getEncoding(); + auto blockedEnc = dyn_cast(cvtOperandEnc); + // compute the sizePerThread for the new encoding + auto sizePerThread = blockedEnc.getSizePerThread(); + auto elemsPerIter = product(sizePerThread); + auto elemsTotal = blockedEnc.getTotalElemsPerThread(shape, tensorTy); + // we need to know how many iteration each thread will load + LDBG("elemsPerIter = " << elemsPerIter << "; elemsTotal = " << elemsTotal); + auto numMaxIters = elemsTotal / elemsPerIter; + auto bitwidth = tensorTy.getElementType().getIntOrFloatBitWidth(); + // Current the widest is set to ds_write_b64 + auto newKOuterDim = std::min(numMaxIters, 64 / bitwidth); + LDBG("Choose the minimum of numIters: " << numMaxIters << " and numDtype: " + << 64 / bitwidth); + SmallVector newSizePerThread(sizePerThread); + newSizePerThread[kDimNum] = newKOuterDim; + + // return the new blocked encoding + auto order = blockedEnc.getOrder(); + int numWarps = ttg::TritonGPUDialect::getNumWarps(mod); + int threadsPerWarp = ttg::TritonGPUDialect::getThreadsPerWarp(mod); + int numCTAs = ttg::TritonGPUDialect::getNumCTAs(mod); + return ttg::BlockedEncodingAttr::get(mod.getContext(), shape, + newSizePerThread, order, numWarps, + threadsPerWarp, numCTAs); +} + +class TritonAMDGPUInThreadTransposePass + : public TritonAMDGPUInThreadTransposeBase< + TritonAMDGPUInThreadTransposePass> { + +public: + TritonAMDGPUInThreadTransposePass() = default; + + void runOnOperation() override { + ModuleOp m = getOperation(); + + m.walk([&](Operation *op) { + auto dotOp = dyn_cast(op); + if (!dotOp) + return; + + LDBG("DotOp under inspection: " << dotOp); + auto mod = op->getParentOfType(); + + // helper function + auto cvtNonKContigDotOperand = [&](Value op) { + if (needCvtToThreadRaked(op)) { + auto loadOps = getLoadInsts(op.getDefiningOp()); + // when we cannot find the associated loadOp + if (!loadOps.size()) + return; + auto newBlockedEnc = getThreadRakedBlockedEnc(op, mod); + LDBG("newBlockedEnc = " << newBlockedEnc); + for (auto loadOp : loadOps) + convertLayout(newBlockedEnc, (Operation *)loadOp); + } + }; + + cvtNonKContigDotOperand(dotOp.getA()); + cvtNonKContigDotOperand(dotOp.getB()); + }); + } +}; + +std::unique_ptr mlir::createTritonAMDGPUInThreadTransposePass() { + return std::make_unique(); +} diff --git a/third_party/amd/python/triton_amd.cc b/third_party/amd/python/triton_amd.cc index 8132773fc2a1..e26941ee006d 100644 --- a/third_party/amd/python/triton_amd.cc +++ b/third_party/amd/python/triton_amd.cc @@ -74,6 +74,8 @@ void init_triton_amd_passes_ttgpuir(py::module &&m) { mlir::createTritonAMDGPUReorderInstructionsPass); ADD_PASS_WRAPPER_2("add_stream_pipeline", mlir::createTritonAMDGPUStreamPipelinePass, int, int); + ADD_PASS_WRAPPER_0("add_in_thread_tranpose", + mlir::createTritonAMDGPUInThreadTransposePass); } void addControlConstant(llvm::Module *module, const char *name, diff --git a/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp b/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp index af6242b59662..59a3b1959ea6 100644 --- a/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp +++ b/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp @@ -307,6 +307,39 @@ TEST_F(LinearLayoutConversionsTest, Blocked4D) { {S("dim0"), S("dim1"), S("dim2"), S("dim3")})); } +TEST_F(LinearLayoutConversionsTest, inThreadTranspose_4x8) { + auto ll = toLinearLayout( + 
{64, 256}, + blocked({4, 8}, {2, 32}, {8, 1}, {1, 1}, {1, 1}, {1, 0}, {1, 0}), + std::nullopt, true); + EXPECT_EQ(ll, + LinearLayout( + { + {S("register"), {{1, 0}, {2, 0}, {0, 1}, {0, 2}, {0, 4}}}, + {S("lane"), + {{0, 8}, {0, 16}, {0, 32}, {0, 64}, {0, 128}, {4, 0}}}, + {S("warp"), {{8, 0}, {16, 0}, {32, 0}}}, + {S("block"), {}}, + }, + {S("dim0"), S("dim1")})); +} + +TEST_F(LinearLayoutConversionsTest, inThreadTranspose_1x8) { + auto ll = toLinearLayout( + {32, 128}, + blocked({1, 8}, {4, 16}, {8, 1}, {1, 1}, {1, 1}, {1, 0}, {1, 0}), + std::nullopt, true); + EXPECT_EQ(ll, LinearLayout( + { + {S("register"), {{0, 1}, {0, 2}, {0, 4}}}, + {S("lane"), + {{0, 8}, {0, 16}, {0, 32}, {0, 64}, {1, 0}, {2, 0}}}, + {S("warp"), {{4, 0}, {8, 0}, {16, 0}}}, + {S("block"), {}}, + }, + {S("dim0"), S("dim1")})); +} + TEST_F(LinearLayoutConversionsTest, MMAv2_16x16) { EXPECT_EQ(toLinearLayout({16, 16}, mma(2, 0, {16, 8}, {1, 1}, {1, 1}, {1, 1}, {0, 1})), From cb22067cf3ab44b984a25cfa27acdef4be14ca13 Mon Sep 17 00:00:00 2001 From: Jingning Tang Date: Fri, 15 Nov 2024 04:24:07 +0000 Subject: [PATCH 2/7] added sharedEncodingAttr and lowerToLLVM for inThreadTranspose --- .../Conversion/TritonGPUToLLVM/Utility.h | 3 +- .../Dialect/TritonGPU/IR/TritonGPUAttrDefs.td | 37 ++++---- .../TritonGPUToLLVM/MemoryOpToLLVM.cpp | 10 ++- lib/Conversion/TritonGPUToLLVM/Utility.cpp | 84 ++++++++++++++----- test/TritonGPU/amd/amd-instruction-sched.mlir | 4 +- third_party/amd/backend/compiler.py | 2 + .../LoadStoreOpToLLVM.cpp | 3 +- 7 files changed, 99 insertions(+), 44 deletions(-) diff --git a/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/include/triton/Conversion/TritonGPUToLLVM/Utility.h index d9c3acbf712d..b9a0cf65bd11 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/Utility.h +++ b/include/triton/Conversion/TritonGPUToLLVM/Utility.h @@ -1145,7 +1145,7 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target, RankedTensorType registerTy, triton::gpu::MemDescType sharedTy, Type elemLlvmTy, std::optional maxVecElems, Value shmemBase, ArrayRef shmemStrides, Location loc, RewriterBase &rewriter, - const TargetInfoBase &target, + const TargetInfoBase &target, bool crossGrain, std::function perVectorCallback); inline DenseMap getSwizzledSharedPtrs( @@ -1321,6 +1321,7 @@ void storeDistributedToShared( triton::gpu::MemDescType dstTy, RankedTensorType srcTy, Type elemLlvmTy, ArrayRef srcVals, Value smemBase, ArrayRef dstStrides, Location loc, RewriterBase &rewriter, const TargetInfoBase &target, + bool crossGrain = false, std::pair *const llvmOpCount = nullptr); inline Value getStructFromSharedMemoryObject(Location loc, diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index b900c3d2e3b7..0d45f5fad42d 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -282,29 +282,32 @@ compared to 1*64 when the hasLeadingOffset is false. 
if (needTrans) kDimNum = 1 - kDimNum; bool isKDimInner = (order[0] == kDimNum); - if (isKDimInner) { - const int numBanks = 32; - const int bankBitWidth = 32; - const int SIMDWidth = 16; + const int numBanks = 32; + const int bankBitWidth = 32; + const int SIMDWidth = 16; - // number of inner dimension rows per one pattern repeat - int innerDimLength = shape[order[0]]; - int elemsPerOneBanksRow = (numBanks * bankBitWidth) / typeWidthInBit; + // number of inner dimension rows per one pattern repeat + unsigned innerDim = isKDimInner ? order[0] : order[1]; + int innerDimLength = shape[innerDim]; + int elemsPerOneBanksRow = (numBanks * bankBitWidth) / typeWidthInBit; - int perPhase = std::max(1, elemsPerOneBanksRow / innerDimLength); - // vecSize is set to kWidth of the dotop layout - int vecSize = dotOpEnc.getKWidth(); - int maxPhase = std::min(SIMDWidth / perPhase, innerDimLength / vecSize); + int perPhase = std::max(1, elemsPerOneBanksRow / innerDimLength); + // vecSize is set to kWidth of the dotop layout + int vecSize = dotOpEnc.getKWidth(); + int maxPhase = std::min(SIMDWidth / perPhase, innerDimLength / vecSize); - // TODO (zhanglx): figure out better parameters for mfma4 - if (mfmaEnc.getMDim() == 4) - maxPhase = 4; + // TODO (zhanglx): figure out better parameters for mfma4 + if (mfmaEnc.getMDim() == 4) + maxPhase = 4; + if (isKDimInner) { return get(context, vecSize, perPhase, maxPhase, order, CTALayout); } else { - // Do not swizzle in case k dimension is not innermost. - // In this case accesses will go in different banks even without swizzling. - return get(context, 1, 1, 1, order, CTALayout); + // swap order because blocked layout has non-KContig but in LDS it will be KContig + SmallVector newOrder(order); + std::swap(newOrder[0], newOrder[1]); + // TODO: set inThreadTranspose to true since we want to use special swizzling + return $_get(context, vecSize, perPhase, maxPhase, newOrder, CTALayout, false); } } diff --git a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp index 3488a686134e..979a85b6bda3 100644 --- a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp @@ -23,6 +23,14 @@ void lowerDistributedToShared( auto srcTy = cast(src.getType()); auto dstTy = cast(dst.getType()); auto outOrd = mlir::cast(dstTy.getEncoding()).getOrder(); + bool crossGrain = false; + // only set crossGrain if it is blocked->shared. This is not a problem for + // NV path because for non-KContig tensor their blocked and shared layout + // still have the same order. 
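+  // For example, for the non-KContig B operand handled by the AMD
+  // inThreadTranspose path, the blocked (register) layout keeps order [1, 0]
+  // while the shared encoding built for it swaps its order to [0, 1], so the
+  // order mismatch checked below is what selects the transposing store path.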
+ if (auto blocked = dyn_cast(srcTy.getEncoding())) { + auto inOrd = blocked.getOrder(); + crossGrain = inOrd[0] != outOrd[0]; + } assert(srcTy.getShape().size() <= 2 || (srcTy.getShape().size() == 3 && outOrd[2] == 0) && "Unexpected rank of ConvertLayout(blocked->shared)"); @@ -32,7 +40,7 @@ void lowerDistributedToShared( auto dstStrides = smemObj.getStrides(); auto inVals = unpackLLElements(loc, adaptorSrc, rewriter); storeDistributedToShared(dstTy, srcTy, elemTy, inVals, smemBase, dstStrides, - loc, rewriter, targetInfo, llvmOpCount); + loc, rewriter, targetInfo, crossGrain, llvmOpCount); } struct GlobalScratchAllocOpConversion diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp index c681cd344ce8..e05fd4b81ec8 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -162,7 +162,7 @@ bool emitTransferBetweenRegistersAndShared( RankedTensorType registerTy, triton::gpu::MemDescType sharedTy, Type elemLlvmTy, std::optional maxVecElems, Value shmemBase, ArrayRef shmemStrides, Location loc, RewriterBase &rewriter, - const TargetInfoBase &target, + const TargetInfoBase &target, bool crossGrain, std::function perVectorCallback) { MLIRContext *ctx = rewriter.getContext(); @@ -174,8 +174,12 @@ bool emitTransferBetweenRegistersAndShared( StringAttr kLane = str_attr("lane"); StringAttr kWarp = str_attr("warp"); - std::optional regLayout = - triton::gpu::toLinearLayout(shape, registerTy.getEncoding()); + std::optional regLayout = LinearLayout::empty(); + auto regEncoding = registerTy.getEncoding(); + // setting elemBitWidth to std::nullopt is fine because that is only used for + // shared layout + regLayout = + triton::gpu::toLinearLayout(shape, regEncoding, std::nullopt, crossGrain); std::optional sharedLayout = triton::gpu::toLinearLayout( shape, sharedTy.getEncoding(), elemLlvmTy.getIntOrFloatBitWidth()); if (!regLayout.has_value() || !sharedLayout.has_value()) { @@ -280,7 +284,7 @@ SmallVector loadSharedToDistributed(RankedTensorType dstTy, SmallVector ret; bool success = emitTransferBetweenRegistersAndShared( dstTy, srcTy, elemLlvmTy, /*maxVecElems=*/std::nullopt, smemObj.getBase(), - smemObj.getStrides(), loc, rewriter, target, + smemObj.getStrides(), loc, rewriter, target, /*crossGrain = */ false, [&](VectorType vecTy, Value vecAddr) { auto vecVal = load(vecTy, vecAddr); vecVal.setAlignment(vecTy.getNumElements() * @@ -301,26 +305,62 @@ void storeDistributedToShared(triton::gpu::MemDescType dstTy, ArrayRef srcVals, Value smemBase, ArrayRef dstStrides, Location loc, RewriterBase &rewriter, - const TargetInfoBase &target, + const TargetInfoBase &target, bool crossGrain, std::pair *const llvmOpCount) { - bool success = emitTransferBetweenRegistersAndShared( + bool success; + std::function perVectorCallback; + if (!crossGrain) { + // callback for every situation except the non-KContig dotOperand + // blocked->shared on AMD platform + perVectorCallback = [&](VectorType vecTy, Value vecAddr) { + ArrayRef vals = srcVals.take_front(vecTy.getNumElements()); + srcVals = srcVals.drop_front(vecTy.getNumElements()); + + Value vec = undef(vecTy); + for (int i = 0; i < vals.size(); i++) { + vec = insert_element(vec, vals[i], i32_val(i)); + } + store(vec, vecAddr) + .setAlignment(vecTy.getNumElements() * + elemLlvmTy.getIntOrFloatBitWidth() / 8); + }; + } else { + // This section is only for inThreadTranspose for AMD path, where we want to + // transpose during the blocked->shared tranfer. 
+ // For example, the thread-local register holds a [4, 8] section of matrix, + // where it is contiguous on the dim of 8. We want the perVectorCallback to + // access the column of 4 elements, 8 times, instead of row of 8 elements, + // 4 times like the callback above. For the specific example, the variables + // accessed or derived below will be the following: + // sizePerThread: [4, 8] + // order: [1, 0] + // numElemsPerIter: 4 x 8 = 32 + // colIndex: initialized as 0, increment to 8 every time callback is called + // innerVecSize: 8, since it is the vector size of inner dimension + auto blockedEncoding = dyn_cast(srcTy.getEncoding()); + auto sizePerThread = blockedEncoding.getSizePerThread(); + auto order = blockedEncoding.getOrder(); + unsigned int numElemsPerIter = product(sizePerThread); + unsigned int colIndex = 0; + unsigned int innerVecSize = sizePerThread[order[0]]; + perVectorCallback = [&](VectorType vecTy, Value vecAddr) { + Value vec = undef(vecTy); + auto startPos = colIndex / innerVecSize * + numElemsPerIter + // start pos of different iter + colIndex % innerVecSize; // start pos of single iter + for (int i = 0; i < vecTy.getNumElements(); i++) { + auto idx = startPos + i * innerVecSize; // iterate within a vector + vec = insert_element(vec, srcVals[idx], i32_val(i)); + } + colIndex++; + store(vec, vecAddr) + .setAlignment(vecTy.getNumElements() * + elemLlvmTy.getIntOrFloatBitWidth() / 8); + }; + } + success = emitTransferBetweenRegistersAndShared( srcTy, dstTy, elemLlvmTy, /*maxVecElems=*/std::nullopt, smemBase, - dstStrides, loc, rewriter, target, [&](VectorType vecTy, Value vecAddr) { - ArrayRef vals = srcVals.take_front(vecTy.getNumElements()); - srcVals = srcVals.drop_front(vecTy.getNumElements()); - - Value vec = undef(vecTy); - for (int i = 0; i < vals.size(); i++) { - vec = insert_element(vec, vals[i], i32_val(i)); - } - store(vec, vecAddr) - .setAlignment(vecTy.getNumElements() * - elemLlvmTy.getIntOrFloatBitWidth() / 8); - if (llvmOpCount) { - ++(llvmOpCount->first); - llvmOpCount->second = vecTy; - } - }); + dstStrides, loc, rewriter, target, crossGrain, perVectorCallback); if (!success) llvm::report_fatal_error("Failed to emit transfer from register to shared"); diff --git a/test/TritonGPU/amd/amd-instruction-sched.mlir b/test/TritonGPU/amd/amd-instruction-sched.mlir index 8cc3ae64f44c..c550f8ea31b4 100644 --- a/test/TritonGPU/amd/amd-instruction-sched.mlir +++ b/test/TritonGPU/amd/amd-instruction-sched.mlir @@ -50,7 +50,7 @@ module { // INSTR_COUNT_NS1-SAME: isBufferLoadsAEnabled = false // INSTR_COUNT_NS1-SAME: isBufferLoadsBEnabled = false // INSTR_COUNT_NS1-SAME: numDsReadsA = #amdgpu.InstCounter<8, vector<4xf16>> - // INSTR_COUNT_NS1-SAME: numDsReadsB = #amdgpu.InstCounter<32, vector<1xf16>> + // INSTR_COUNT_NS1-SAME: numDsReadsB = #amdgpu.InstCounter<8, vector<4xf16>> // INSTR_COUNT_NS1-SAME: numDsWritesA = #amdgpu.InstCounter<0, none> // INSTR_COUNT_NS1-SAME: numDsWritesB = #amdgpu.InstCounter<0, none> // INSTR_COUNT_NS1-SAME: numGlobalLoadsA = #amdgpu.InstCounter<4, vector<4xf16>> @@ -61,7 +61,7 @@ module { // INSTR_COUNT_NS2-SAME: isBufferLoadsAEnabled = false // INSTR_COUNT_NS2-SAME: isBufferLoadsBEnabled = false // INSTR_COUNT_NS2-SAME: numDsReadsA = #amdgpu.InstCounter<8, vector<4xf16>> - // INSTR_COUNT_NS2-SAME: numDsReadsB = #amdgpu.InstCounter<32, vector<1xf16>> + // INSTR_COUNT_NS2-SAME: numDsReadsB = #amdgpu.InstCounter<8, vector<4xf16>> // INSTR_COUNT_NS2-SAME: numDsWritesA = #amdgpu.InstCounter<4, vector<4xf16>> // INSTR_COUNT_NS2-SAME: 
numDsWritesB = #amdgpu.InstCounter<4, vector<4xf16>> // INSTR_COUNT_NS2-SAME: numGlobalLoadsA = #amdgpu.InstCounter<4, vector<4xf16>> diff --git a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py index 81b07f2e7d86..f9566e21fe28 100644 --- a/third_party/amd/backend/compiler.py +++ b/third_party/amd/backend/compiler.py @@ -225,6 +225,8 @@ def make_ttgir(mod, metadata, options): passes.ttgpuir.add_optimize_thread_locality(pm) amd.passes.ttgpuir.add_accelerate_matmul(pm, options.arch, options.matrix_instr_nonkdim, options.kpack) passes.ttgpuir.add_remove_layout_conversions(pm) + amd.passes.ttgpuir.add_in_thread_tranpose(pm) + passes.ttgpuir.add_remove_layout_conversions(pm) amd.passes.ttgpuir.add_optimize_epilogue(pm) passes.ttgpuir.add_optimize_dot_operands(pm, True) diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp index d2cef405ebdf..63347b1da25f 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -936,7 +936,8 @@ struct AsyncCopyGlobalToLocalOpConversion SmallVector shmemAddrs; bool ok = emitTransferBetweenRegistersAndShared( srcTy, dstTy, resElemTy, maxVec, smemObj.base, smemObj.strides, loc, - rewriter, targetInfo, [&](VectorType vecTy_, Value shmemAddr) { + rewriter, targetInfo, /*crossGrain = */ false, + [&](VectorType vecTy_, Value shmemAddr) { vecTy = vecTy_; shmemAddrs.push_back(shmemAddr); }); From 565f6e934f6e9826e51843e40ac64a376562feb9 Mon Sep 17 00:00:00 2001 From: Jingning Tang Date: Thu, 21 Nov 2024 00:24:35 +0000 Subject: [PATCH 3/7] added fix from Ravil for instruction hint --- lib/Conversion/TritonGPUToLLVM/Utility.cpp | 8 ++++++++ test/TritonGPU/amd/amd-instruction-sched.mlir | 2 +- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp index e05fd4b81ec8..52c303527425 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -323,6 +323,10 @@ void storeDistributedToShared(triton::gpu::MemDescType dstTy, store(vec, vecAddr) .setAlignment(vecTy.getNumElements() * elemLlvmTy.getIntOrFloatBitWidth() / 8); + if (llvmOpCount) { + ++(llvmOpCount->first); + llvmOpCount->second = vecTy; + } }; } else { // This section is only for inThreadTranspose for AMD path, where we want to @@ -356,6 +360,10 @@ void storeDistributedToShared(triton::gpu::MemDescType dstTy, store(vec, vecAddr) .setAlignment(vecTy.getNumElements() * elemLlvmTy.getIntOrFloatBitWidth() / 8); + if (llvmOpCount) { + ++(llvmOpCount->first); + llvmOpCount->second = vecTy; + } }; } success = emitTransferBetweenRegistersAndShared( diff --git a/test/TritonGPU/amd/amd-instruction-sched.mlir b/test/TritonGPU/amd/amd-instruction-sched.mlir index c550f8ea31b4..efc6964da48d 100644 --- a/test/TritonGPU/amd/amd-instruction-sched.mlir +++ b/test/TritonGPU/amd/amd-instruction-sched.mlir @@ -63,7 +63,7 @@ module { // INSTR_COUNT_NS2-SAME: numDsReadsA = #amdgpu.InstCounter<8, vector<4xf16>> // INSTR_COUNT_NS2-SAME: numDsReadsB = #amdgpu.InstCounter<8, vector<4xf16>> // INSTR_COUNT_NS2-SAME: numDsWritesA = #amdgpu.InstCounter<4, vector<4xf16>> - // INSTR_COUNT_NS2-SAME: numDsWritesB = #amdgpu.InstCounter<4, vector<4xf16>> + // INSTR_COUNT_NS2-SAME: numDsWritesB = #amdgpu.InstCounter<16, vector<1xf16>> // INSTR_COUNT_NS2-SAME: numGlobalLoadsA = #amdgpu.InstCounter<4, 
vector<4xf16>> // INSTR_COUNT_NS2-SAME: numGlobalLoadsB = #amdgpu.InstCounter<4, vector<4xf16>> // INSTR_COUNT_NS2-SAME: numMMAs = #amdgpu.InstCounter<16, tensor<32x32x8xf16>> From ba56f56de2a761cc53ff6591b097389ac2f18332 Mon Sep 17 00:00:00 2001 From: Jingning Tang Date: Sun, 24 Nov 2024 06:13:55 +0000 Subject: [PATCH 4/7] fixed test_dot --- .../amd/lib/TritonAMDGPUTransforms/inThreadTranspose.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/inThreadTranspose.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/inThreadTranspose.cpp index a15d81f4caa0..405fef881411 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/inThreadTranspose.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/inThreadTranspose.cpp @@ -133,6 +133,7 @@ ttg::BlockedEncodingAttr getThreadRakedBlockedEnc(Value operand, auto shape = tensorTy.getShape(); auto opEnc = tensorTy.getEncoding(); auto opDotOpEnc = dyn_cast(opEnc); + auto kWidth = opDotOpEnc.getKWidth(); int kDimNum = opDotOpEnc.getOpIdx() == 0 ? 1 : 0; // get the current blocked encoding auto cvtOperand = operand.getDefiningOp()->getOperand(0); @@ -149,6 +150,8 @@ ttg::BlockedEncodingAttr getThreadRakedBlockedEnc(Value operand, auto bitwidth = tensorTy.getElementType().getIntOrFloatBitWidth(); // Current the widest is set to ds_write_b64 auto newKOuterDim = std::min(numMaxIters, 64 / bitwidth); + // the new vectorization needs to be bound by kWidth as well + newKOuterDim = std::min(newKOuterDim, kWidth); LDBG("Choose the minimum of numIters: " << numMaxIters << " and numDtype: " << 64 / bitwidth); SmallVector newSizePerThread(sizePerThread); From 45af9fc3ccc06249f29fa1b2fce1bc1f19ac7f63 Mon Sep 17 00:00:00 2001 From: Jingning Tang Date: Mon, 25 Nov 2024 19:16:50 +0000 Subject: [PATCH 5/7] enforce inThreadTranspose to 2D to fix test_dot3d --- lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp | 4 +++- .../amd/lib/TritonAMDGPUTransforms/inThreadTranspose.cpp | 6 ++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp index 979a85b6bda3..1fd93f9bb83e 100644 --- a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp @@ -28,8 +28,10 @@ void lowerDistributedToShared( // NV path because for non-KContig tensor their blocked and shared layout // still have the same order. 
if (auto blocked = dyn_cast(srcTy.getEncoding())) { + auto rank = blocked.getOrder().size(); auto inOrd = blocked.getOrder(); - crossGrain = inOrd[0] != outOrd[0]; + // it has to be 2D and blocked's and shared's order mismatch + crossGrain = (rank == 2) && (inOrd[0] != outOrd[0]); } assert(srcTy.getShape().size() <= 2 || (srcTy.getShape().size() == 3 && outOrd[2] == 0) && diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/inThreadTranspose.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/inThreadTranspose.cpp index 405fef881411..d0d93e6e7afe 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/inThreadTranspose.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/inThreadTranspose.cpp @@ -95,6 +95,7 @@ SmallVector getLoadInsts(Operation *op) { } bool needCvtToThreadRaked(Value operand) { + auto opTensorTy = cast(operand.getType()); auto opEnc = opTensorTy.getEncoding(); auto opDotOpEnc = dyn_cast(opEnc); @@ -116,6 +117,11 @@ bool needCvtToThreadRaked(Value operand) { // make sure it is converted from blocked layout if (!blockedEnc) return false; + auto rank = blockedEnc.getOrder().size(); + if (rank != 2) { + LDBG("inThreadRake only supports 2D case right now"); + return false; + } // check whether it's contiguous on K dimension int kDimNum = opDotOpEnc.getOpIdx() == 0 ? 1 : 0; auto order = blockedEnc.getOrder(); From d77a7913ff76f34ab9117fd5124d43e9de78e5d8 Mon Sep 17 00:00:00 2001 From: Jingning Tang Date: Wed, 27 Nov 2024 22:03:16 +0000 Subject: [PATCH 6/7] triton_gpu to ttg --- test/TritonGPU/amd/in-thread-transpose.mlir | 46 ++++++++++----------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/test/TritonGPU/amd/in-thread-transpose.mlir b/test/TritonGPU/amd/in-thread-transpose.mlir index 747bfd6477f6..9a3074ba0611 100644 --- a/test/TritonGPU/amd/in-thread-transpose.mlir +++ b/test/TritonGPU/amd/in-thread-transpose.mlir @@ -1,54 +1,54 @@ // RUN: triton-opt %s -split-input-file -tritonamdgpu-in-thread-transpose | FileCheck %s -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 32], warpsPerCTA = [8, 1], order = [1, 0]}> -#mma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 32], warpsPerCTA = [8, 1], order = [1, 0]}> +#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "hip:gfx942", "triton_gpu.threads-per-warp" = 64 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { -// CHECK: [[threadrake_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [4, 8], threadsPerWarp = [2, 32], warpsPerCTA = [8, 1], order = [1, 0]}> -// CHECK: [[load_ptr:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x256x!tt.ptr, [[threadrake_layout]]> +// CHECK: [[threadrake_layout:#.*]] = #ttg.blocked<{sizePerThread = [4, 8], threadsPerWarp = [2, 32], warpsPerCTA = [8, 1], order = [1, 0]}> +// CHECK: [[load_ptr:%.*]] = ttg.convert_layout {{.*}} -> tensor<64x256x!tt.ptr, [[threadrake_layout]]> // CHECK: {{.*}} = tt.load [[load_ptr]] : tensor<64x256x!tt.ptr, [[threadrake_layout]]> - tt.func public @threadRake_transpose_b(%arg0: tensor<256x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>, %arg1: tensor<64x256x!tt.ptr, #blocked1>) { 
+ tt.func public @threadRake_transpose_b(%arg0: tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>, %arg1: tensor<64x256x!tt.ptr, #blocked1>) { %cst_0 = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> %1 = tt.load %arg1 : tensor<64x256x!tt.ptr, #blocked1> - %2 = triton_gpu.convert_layout %1 : tensor<64x256xf16, #blocked1> -> tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> - %6 = tt.dot %arg0, %2, %cst_0 : tensor<256x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<256x256xf32, #mma> + %2 = ttg.convert_layout %1 : tensor<64x256xf16, #blocked1> -> tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> + %6 = tt.dot %arg0, %2, %cst_0 : tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<256x256xf32, #mma> tt.return } } // ----- -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0]}> -#mma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 2], instrShape = [32, 32], isTransposed = true}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0]}> +#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 2], instrShape = [32, 32], isTransposed = true}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "hip:gfx942", "triton_gpu.threads-per-warp" = 64 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { -// CHECK: [[threadrake_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0]}> -// CHECK: [[load_ptr:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<32x128x!tt.ptr, [[threadrake_layout]]> +// CHECK: [[threadrake_layout:#.*]] = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0]}> +// CHECK: [[load_ptr:%.*]] = ttg.convert_layout {{.*}} -> tensor<32x128x!tt.ptr, [[threadrake_layout]]> // CHECK: {{.*}} = tt.load [[load_ptr]] : tensor<32x128x!tt.ptr, [[threadrake_layout]]> - tt.func public @threadRake_transpose_b_no_change(%arg0: tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>, %arg1: tensor<32x128x!tt.ptr, #blocked1>) { + tt.func public @threadRake_transpose_b_no_change(%arg0: tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>, %arg1: tensor<32x128x!tt.ptr, #blocked1>) { %cst_0 = arith.constant dense<0.000000e+00> : tensor<256x128xf32, #mma> %1 = tt.load %arg1 : tensor<32x128x!tt.ptr, #blocked1> - %2 = triton_gpu.convert_layout %1 : tensor<32x128xf16, #blocked1> -> tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> - %6 = tt.dot %arg0, %2, %cst_0 : tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<256x128xf32, #mma> + %2 = ttg.convert_layout %1 : tensor<32x128xf16, #blocked1> -> tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> + %6 = tt.dot %arg0, %2, %cst_0 : tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<32x128xf16, 
#ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<256x128xf32, #mma> tt.return } } // ----- -#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [32, 2], warpsPerCTA = [1, 8], order = [0, 1]}> -#mma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}> +#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [32, 2], warpsPerCTA = [1, 8], order = [0, 1]}> +#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "hip:gfx942", "triton_gpu.threads-per-warp" = 64 : i32} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { -// CHECK-NOT: {{.*}} = triton_gpu.convert_layout {{.*blocked.*}} -> {{.*blocked.*}} - tt.func public @threadRake_no_transpose(%arg0: tensor<256x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>, %arg1: tensor<64x256x!tt.ptr, #blocked1>) { +// CHECK-NOT: {{.*}} = ttg.convert_layout {{.*blocked.*}} -> {{.*blocked.*}} + tt.func public @threadRake_no_transpose(%arg0: tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>, %arg1: tensor<64x256x!tt.ptr, #blocked1>) { %cst_0 = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> %1 = tt.load %arg1 : tensor<64x256x!tt.ptr, #blocked1> - %2 = triton_gpu.convert_layout %1 : tensor<64x256xf16, #blocked1> -> tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> - %6 = tt.dot %arg0, %2, %cst_0 : tensor<256x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<256x256xf32, #mma> + %2 = ttg.convert_layout %1 : tensor<64x256xf16, #blocked1> -> tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> + %6 = tt.dot %arg0, %2, %cst_0 : tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<256x256xf32, #mma> tt.return } } From 3ca69ea57585e9cfa93dac01c5aec2bf84912c66 Mon Sep 17 00:00:00 2001 From: Jingning Tang Date: Wed, 27 Nov 2024 22:04:23 +0000 Subject: [PATCH 7/7] fix --- lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp index fc93365e2ae8..6e8775e8b75e 100644 --- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp @@ -300,7 +300,7 @@ LinearLayout sharedToLinearLayoutLeadingOffset(ArrayRef shape, // pattern. // // For example, consider the following blocked layout generated by -// AMDGPUInThreadTranspose: #blocked1 = #triton_gpu.blocked<{sizePerThread = +// AMDGPUInThreadTranspose: #blocked1 = #ttg.blocked<{sizePerThread = // [4, 8], threadsPerWarp = [2, 32], warpsPerCTA = [8, 1], order = [1, 0]}>. // Here since sizePerThread is 2D, there could be two ways to traverse it: along // the dim of 8 or the dim of 4. The regular toLinearLayout() would go through