[BACKEND] Support convert_layout with num_ctas > 1 Using Linear Layout (triton-lang#4782)

Specifically, this PR implements layout conversion when a CGA contains
more than one CTA. In that case, a Triton tensor is split into multiple
blocks, each handled by one CTA:

```
block0 | block1
----------------
block2 | block3
```

If a conversion requires transferring data across blocks, for example from
block0 to block3, this PR cannot handle it; `isCrossCTAConversion` is used
to detect that case.
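For intuition, the toy program below models that condition on the 2x2 block
arrangement above: a conversion counts as cross-CTA exactly when some element
would have to move to a different block. This is only an illustration; the
real `isCrossCTAConversion` inspects the "block" dimension of a
`LinearLayout`, and the grid size, helper names, and element mapping here are
invented for the example.

```cpp
// Toy model only: blocks are laid out as in the 2x2 diagram above
// (block0 | block1 / block2 | block3), each covering a 2x2 tile of a 4x4
// tensor. The real check, isCrossCTAConversion, works on LinearLayout bases.
#include <cstdio>
#include <functional>
#include <utility>

// Which block owns element (row, col) of the 4x4 tensor.
static int blockOf(int row, int col) {
  return (row / 2) * 2 + (col / 2); // block0..block3
}

// A conversion is cross-CTA if any element ends up owned by a different
// block after the layout change.
static bool isCrossCTAConversionToy(
    const std::function<std::pair<int, int>(int, int)> &dstPosition) {
  for (int row = 0; row < 4; ++row) {
    for (int col = 0; col < 4; ++col) {
      auto [dstRow, dstCol] = dstPosition(row, col);
      if (blockOf(row, col) != blockOf(dstRow, dstCol))
        return true;
    }
  }
  return false;
}

int main() {
  // Swapping the two columns inside each block stays within the block:
  // not cross-CTA, so a conversion like this can be handled.
  auto withinBlock = [](int r, int c) { return std::pair<int, int>{r, c ^ 1}; };
  // Transposing the whole tensor moves, e.g., an element of block1 into
  // block2: cross-CTA, so such a conversion is rejected.
  auto transpose = [](int r, int c) { return std::pair<int, int>{c, r}; };
  std::printf("within-block swap cross-CTA? %d\n",
              isCrossCTAConversionToy(withinBlock));
  std::printf("transpose cross-CTA?         %d\n",
              isCrossCTAConversionToy(transpose));
}
```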
Jokeren authored and Luosuu committed Nov 13, 2024
1 parent f5f367d commit 405f6fa
Showing 5 changed files with 49 additions and 21 deletions.
include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h (9 additions, 0 deletions)

@@ -48,6 +48,15 @@ toLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
 // dimension, determines if the layout moves data across block boundaries.
 bool isCrossCTAConversion(const LinearLayout &layout);
 
+// Given a linear layout where the input dimensions contain a "block" dimension,
+// this method sets the "block" dimension to 0 and removes the corresponding
+// output dimensions.
+//
+// Note that this behavior differs from calling
+// `LinearLayout::sublayout(inDimNames, outDimNames)` when "block" is not in
+// `inDimNames`. The latter does not modify the output sizes.
+LinearLayout getLayoutWithinBlock(const LinearLayout &layout);
+
 // In this function, we construct a linear layout representing the
 // <shared memory offset, iteration, block> -> <tensor element index> mapping
 // for entire `src` and `dst` tensors. We determine the shape of the
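To make the new `getLayoutWithinBlock` comment concrete, here is a stand-alone
sketch of the idea: with an XOR-basis layout in the spirit of `LinearLayout`,
clearing the "block" bases restricts the mapping to the elements owned by
block 0, and the reachable output range shrinks with it, which is exactly the
difference from `sublayout` noted above. The `Bases` representation and helper
names are simplifications invented for this sketch, not the Triton API.

```cpp
// Toy model only: a "layout" maps (register, block) -> 1-D element index by
// XOR-ing one basis per set input bit, mimicking how Triton's LinearLayout
// composes bases. Names and data structures are illustrative.
#include <cstdio>
#include <map>
#include <string>
#include <vector>

using Bases = std::map<std::string, std::vector<int>>;

// Apply the layout: each input dimension contributes bases[i] whenever bit i
// of its index is set; contributions combine with XOR.
static int apply(const Bases &bases, const std::map<std::string, int> &in) {
  int out = 0;
  for (const auto &[dim, vecs] : bases) {
    int idx = in.at(dim);
    for (size_t i = 0; i < vecs.size(); ++i)
      if (idx >> i & 1)
        out ^= vecs[i];
  }
  return out;
}

// Rough analogue of getLayoutWithinBlock: drop the "block" bases so the
// layout only describes the elements owned by block 0. The reachable output
// range shrinks accordingly (a plain sublayout would keep the original size).
static Bases getLayoutWithinBlockToy(Bases bases) {
  bases["block"] = {};
  return bases;
}

int main() {
  // 4 registers per block, 4 blocks, 16 elements total: register bases reach
  // offsets {1, 2}, block bases reach offsets {4, 8}.
  Bases layout = {{"register", {1, 2}}, {"block", {4, 8}}};
  Bases withinBlock = getLayoutWithinBlockToy(layout);

  // Block 2, register 3 maps to element 11 in the full layout...
  std::printf("full layout:  %d\n",
              apply(layout, {{"register", 3}, {"block", 2}}));
  // ...but within a single block only elements 0..3 are reachable.
  std::printf("within block: %d\n",
              apply(withinBlock, {{"register", 3}, {"block", 0}}));
}
```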
include/triton/Tools/LinearLayout.h (1 addition, 1 deletion)

@@ -597,7 +597,7 @@ class LinearLayout {
 //
 // TODO(jlebar): Implement divideLeft.
 // std::optional<LinearLayout> divideLeft(const LinearLayout &divisor);
-std::optional<LinearLayout> divideRight(const LinearLayout &divisor);
+std::optional<LinearLayout> divideRight(const LinearLayout &divisor) const;
 
 // Gets a layout with only these in/out dimensions.
 //
lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp (13 additions, 13 deletions)

@@ -367,12 +367,8 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
 // The following tasks must be completed before we can remove the layoutIsOK
 // check:
 // 1. Support for AMD's MFMA and WMMA
-// 2. Handling NVIDIA's MMA layout when CTA per CGA > 1
 std::function<bool(Attribute)> layoutIsOK = [&](Attribute layout) {
 if (auto nvidiaMma = dyn_cast<NvidiaMmaEncodingAttr>(layout)) {
-if (product(getCTAsPerCGA(nvidiaMma)) > 1) {
-return false;
-}
 if (useLegacyMMAConversion) {
 return false;
 }

@@ -419,8 +415,11 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
 }
 }
 
-SmallVector<Value> outVals = transferWithinBlockOrGroupImpl(
-    inVals, conversion, op, srcLayout, dstLayout, adaptor, rewriter);
+auto srcLayoutWithinBlock = getLayoutWithinBlock(srcLayout);
+auto dstLayoutWithinBlock = getLayoutWithinBlock(dstLayout);
+SmallVector<Value> outVals =
+    transferWithinBlock(inVals, op, srcLayoutWithinBlock,
+                        dstLayoutWithinBlock, adaptor, rewriter);
 
 // Unmunge output values
 for (const auto &it : llvm::enumerate(outVals)) {

@@ -437,11 +436,11 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
 return success();
 }
 
-SmallVector<Value> transferWithinBlockOrGroupImpl(
-    ArrayRef<Value> inVals, const LinearLayout &conversion,
-    ConvertLayoutOp op, const LinearLayout &srcLayout,
-    const LinearLayout &dstLayout, OpAdaptor adaptor,
-    ConversionPatternRewriter &rewriter) const {
+SmallVector<Value>
+transferWithinBlock(ArrayRef<Value> inVals, ConvertLayoutOp op,
+                    const LinearLayout &srcLayout,
+                    const LinearLayout &dstLayout, OpAdaptor adaptor,
+                    ConversionPatternRewriter &rewriter) const {
 MLIRContext *ctx = op.getContext();
 auto loc = op.getLoc();
 

@@ -459,11 +458,12 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
 
 auto scratchConfig =
     getScratchConfigForCvt(op.getSrc().getType(), op.getType());
-auto tensorShape = convertType<unsigned, int64_t>(op.getType().getShape());
+auto tensorShapePerCTA = convertType<unsigned, int64_t>(getShapePerCTA(
+    op.getSrc().getType().getEncoding(), op.getType().getShape()));
 // Input dims: [offset, iteration, block]
 // Output dims: dimN-1, dimN-2, ..., dim0, where N is obtained from repShape
 LinearLayout sharedLayout = chooseShemLayoutForRegToRegConversion(
-    ctx, tensorShape, scratchConfig.repShape, scratchConfig.order);
+    ctx, tensorShapePerCTA, scratchConfig.repShape, scratchConfig.order);
 
 // Layout for the store from registers to shared memory.
 //
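The substitution of `tensorShapePerCTA` for the full tensor shape above relies
on the per-CTA shape being the tensor shape divided by the CTA split in each
dimension. The sketch below spells out that computation under that assumption;
it is a simplification of `getShapePerCTA`, with the split factors passed in
explicitly rather than read from the layout encoding.

```cpp
// Minimal sketch, assuming the per-CTA shape is the tensor shape divided by
// the number of CTAs the tensor is split across in each dimension (clamped
// at 1). This mirrors the intent of getShapePerCTA, not its exact code.
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

static std::vector<int64_t>
shapePerCTAToy(const std::vector<int64_t> &shape,
               const std::vector<int64_t> &ctaSplit) {
  std::vector<int64_t> result(shape.size());
  for (size_t d = 0; d < shape.size(); ++d)
    result[d] = std::max<int64_t>(1, shape[d] / ctaSplit[d]);
  return result;
}

int main() {
  // A 128x128 tensor split across a 2x2 grid of CTAs (the block0..block3
  // picture from the commit message) leaves each CTA a 64x64 tile, so the
  // shared-memory scratch layout only needs to cover 64x64.
  auto perCTA = shapePerCTAToy({128, 128}, {2, 2});
  std::printf("%lld x %lld\n", (long long)perCTA[0], (long long)perCTA[1]);
}
```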
lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp (25 additions, 6 deletions)

@@ -869,6 +869,17 @@ bool isCrossCTAConversion(const LinearLayout &layout) {
 !layout.sublayoutIsIdentity({kBlock}, {kBlock});
 }
 
+LinearLayout getLayoutWithinBlock(const LinearLayout &layout) {
+  assert(!layout.getInDimNames().empty());
+  MLIRContext *ctx = layout.getInDimNames().begin()->getContext();
+
+  StringAttr kBlock = S("block");
+  assert(layout.hasInDim(kBlock));
+  auto bases = layout.getBases();
+  bases[kBlock] = {};
+  return LinearLayout(bases, llvm::to_vector<4>(layout.getOutDimNames()));
+}
+
 LinearLayout chooseShemLayoutForRegToRegConversion(
     MLIRContext *ctx, ArrayRef<unsigned> tensorShape,
     ArrayRef<unsigned> repShape, ArrayRef<unsigned> order) {

@@ -925,11 +936,11 @@ bool canUseStMatrix(RankedTensorType tensorTy, ArrayRef<unsigned> repShape,
 if (order[0] != 1)
 return false;
 
-auto tensorShape = tensorTy.getShape();
-if (tensorShape.size() != 2)
+auto tensorShapePerCTA = getShapePerCTA(mmaLayout, tensorTy.getShape());
+if (tensorShapePerCTA.size() != 2)
 return false;
-auto numIterations = ceil<unsigned>(tensorShape[1], repShape[1]) *
-                     ceil<unsigned>(tensorShape[0], repShape[0]);
+auto numIterations = ceil<unsigned>(tensorShapePerCTA[1], repShape[1]) *
+                     ceil<unsigned>(tensorShapePerCTA[0], repShape[0]);
 if (numIterations > 1)
 return false;
 if (paddedRepShape[1] % 8 != 0)

@@ -1020,6 +1031,7 @@ std::optional<LinearLayout> chooseStMatrixLayoutNoLeadingOffset(
 StringAttr kWarp = S("warp");
 StringAttr kCol = S("dim1");
 StringAttr kRow = S("dim0");
+StringAttr kBlock = S("block");
 
 std::vector<std::vector<int>> basesReg = {{1, 0}, {2, 0}, {4, 0}};
 std::vector<std::vector<int>> basesLane = {

@@ -1039,9 +1051,16 @@
     .transposeOuts(llvm::to_vector(layout.getOutDimNames()));
 auto ret =
     combineCtaCgaWithShape(layout, mma.getCTALayout(), tensorTy.getShape());
+auto tensorShapePerCTA = getShapePerCTA(mma, tensorTy.getShape());
+llvm::SmallDenseMap<StringAttr, int64_t> namedTensorShape;
+namedTensorShape[kRow] = tensorShapePerCTA[0];
+namedTensorShape[kCol] = tensorShapePerCTA[1];
+ret = ensureLayoutNotSmallerThan(ret, namedTensorShape);
+ret = ensureLayoutNotLargerThan(ret, namedTensorShape);
 return ret.transposeOuts(llvm::to_vector(layout.getOutDimNames()))
-    .reshapeOuts(
-        {{S("offset"), ret.getTotalOutDimSize()}, {S("iteration"), 1}});
+    .reshapeOuts({{S("offset"), ret.getTotalOutDimSize()},
+                  {S("iteration"), 1}}) *
+    identityND(kBlock, {1, 1}, {0, 1}, {S("offset"), S("iteration")});
 }
 
 } // anonymous namespace
lib/Tools/LinearLayout.cpp (1 addition, 1 deletion)

@@ -640,7 +640,7 @@ LinearLayout operator*(LinearLayout inner, LinearLayout outer) {
 }
 
 std::optional<LinearLayout>
-LinearLayout::divideRight(const LinearLayout &divisor) {
+LinearLayout::divideRight(const LinearLayout &divisor) const {
 assertCommonDimsSameOrder(getOutDimNames(), divisor.getOutDimNames());
 assertCommonDimsSameOrder(getInDimNames(), divisor.getInDimNames());
 
