[LLVMGPUVectorDistribute] Add support for inter-subgroup multi_reduction (iree-org#19596)

This commit adds support for distributing multi_reductions where the
reduction dimension(s) are distributed across subgroups.

We first perform the existing reduction distribution; however, this
leaves partial reductions across subgroups.

Thereafter, we insert a transfer_write / transfer_read pair through
shared memory to achieve a layout change that re-distributes the
reduction subgroup tiles into the element tile. Finally, we perform
another multi_reduction to complete the reduction.
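A rough sketch of the resulting IR, under assumed shapes (a 32-wide
parallel dimension and a reduced dimension split across two subgroups);
the function name, shapes, and the omitted distribution layout
attributes are illustrative only, not taken from the patch:

    // Per-subgroup partials (rank restored on the reduced dim) are
    // round-tripped through workgroup memory, then reduced once more.
    func.func @complete_reduction(
        %partials: vector<32x2xf32>, %acc: vector<32xf32>,
        %buf: memref<32x2xf32, #gpu.address_space<workgroup>>)
        -> vector<32xf32> {
      %c0 = arith.constant 0 : index
      %pad = arith.constant 0.0 : f32
      gpu.barrier
      // Each subgroup writes its partial reduction to shared memory.
      vector.transfer_write %partials, %buf[%c0, %c0]
          {in_bounds = [true, true]}
          : vector<32x2xf32>, memref<32x2xf32, #gpu.address_space<workgroup>>
      gpu.barrier
      // Reload so the per-subgroup partials land in the element tile.
      %reloaded = vector.transfer_read %buf[%c0, %c0], %pad
          {in_bounds = [true, true]}
          : memref<32x2xf32, #gpu.address_space<workgroup>>, vector<32x2xf32>
      // A second multi_reduction completes the reduction.
      %res = vector.multi_reduction <add>, %reloaded, %acc [1]
          : vector<32x2xf32> to vector<32xf32>
      return %res : vector<32xf32>
    }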

closes: iree-org#19578

---------

Signed-off-by: Manupa Karunaratne <[email protected]>
manupak authored Jan 14, 2025
1 parent 21b0101 commit 01c9f14
Showing 6 changed files with 409 additions and 151 deletions.
@@ -414,8 +414,9 @@ static int64_t getShuffleWidth(NestedLayoutAttr layout, int64_t dim) {
/// by doing a butterfly shuffle.
/// 3. Accumulator Reduce: Each thread reduces it's intermediate reduced
/// results with the accumulator it holds.
/// Currently, reduction across warps is not supported, but it would just add
/// another step, Warp Reduce, where threads do an atomic addition on a buffer.
/// 4. Subgroup Reduce: each subgroup stores its partial reductions
/// to shared memory; these are then reloaded with a layout that places
/// the partial reductions inside threads.
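/// (Illustrative example: with a subgroup tile of 2 on the reduced
/// dimension, two subgroups each hold a partial result after step 3;
/// step 4 round-trips them through shared memory so that a second
/// multi_reduction can combine them.)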
struct DistributeMultiReduction final
: OpDistributionPattern<vector::MultiDimReductionOp> {
using OpDistributionPattern::OpDistributionPattern;
@@ -460,7 +461,6 @@ struct DistributeMultiReduction final
}

Location loc = multiReduceOp.getLoc();

SmallVector<bool> reducedDims = multiReduceOp.getReductionMask();
int64_t rank = srcVector.getType().getRank();

@@ -492,47 +492,65 @@ struct DistributeMultiReduction final
assert(locallyReduced && "result should have been a vector");

// Flatten the locally reduced value.
VectorValue threadReduced = locallyReduced;
VectorType shaped = locallyReduced.getType();
int64_t numElements = shaped.getNumElements();
SmallVector<int64_t> flatShape(1, numElements);
VectorType flatVecType = VectorType::get(flatShape, elemTy);
VectorValue flat =
rewriter.create<vector::ShapeCastOp>(loc, flatVecType, locallyReduced);
bool hasThreadReductions =
llvm::any_of(multiReduceOp.getReductionDims(), [&](int64_t rDim) {
return srcLayout.getThreadTile()[rDim] > 1;
});
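// Thread-level reduction is only needed when a reduced dim is
// actually distributed across threads.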
if (hasThreadReductions) {
int64_t numElements = shaped.getNumElements();
SmallVector<int64_t> flatShape(1, numElements);
VectorType flatVecType = VectorType::get(flatShape, elemTy);
VectorValue flat = rewriter.create<vector::ShapeCastOp>(loc, flatVecType,
locallyReduced);

// Do inter-thread/warp reduce.
FailureOr<VectorValue> threadReducedFlat = doThreadReduction(
rewriter, srcLayout, flat, multiReduceOp.getKind(), reducedDims);
if (failed(threadReducedFlat)) {
return failure();
}

// Do inter-thread/warp reduce.
FailureOr<VectorValue> threadReduced = doThreadReduction(
rewriter, srcLayout, flat, multiReduceOp.getKind(), reducedDims);
if (failed(threadReduced)) {
return failure();
// Do reduction against accumulator, which needs to be done after thread
// reduction.
threadReduced = rewriter.create<vector::ShapeCastOp>(
loc, shaped, threadReducedFlat.value());
}

// Do reduction against accumulator, which needs to be done after thread
// reduction.
VectorValue unflattened = rewriter.create<vector::ShapeCastOp>(
loc, shaped, threadReduced.value());

if (!accVector) {
// Broadcast the scalar (e.g., f32) to a vector type (e.g., vector<f32>)
// because the following implementation requires the operand to be a
// vector.
disAcc = rewriter.create<vector::BroadcastOp>(loc, shaped, disAcc);
}

Value accReduction = vector::makeArithReduction(
rewriter, loc, multiReduceOp.getKind(), unflattened, disAcc);
auto accReduced = dyn_cast<VectorValue>(accReduction);
if (!accReduced) {
return failure();
}

if (resVector) {
replaceOpWithDistributedValues(rewriter, multiReduceOp, accReduced);
} else {
Value accReducedVal = rewriter.create<vector::ExtractOp>(
loc, accReduction, ArrayRef{int64_t(0)});
replaceOpWithDistributedValues(rewriter, multiReduceOp, accReducedVal);
bool hasSubgroupReductions =
llvm::any_of(multiReduceOp.getReductionDims(), [&](int64_t rDim) {
return srcLayout.getSubgroupTile()[rDim] > 1;
});
// We can exit here if there is no reduction across subgroups.
if (!hasSubgroupReductions) {
Value accReduction = vector::makeArithReduction(
rewriter, loc, multiReduceOp.getKind(), threadReduced, disAcc);
auto accReduced = dyn_cast<VectorValue>(accReduction);
if (!accReduced) {
return failure();
}
if (resVector) {
replaceOpWithDistributedValues(rewriter, multiReduceOp, accReduced);
} else {
Value accReducedVal = rewriter.create<vector::ExtractOp>(
loc, accReduction, ArrayRef{int64_t(0)});
replaceOpWithDistributedValues(rewriter, multiReduceOp, accReducedVal);
}
return success();
}

// Do inter-subgroup reductions.
Value subgroupReduced = doSubgroupReduction(
rewriter, loc, srcVector, srcLayout, multiReduceOp.getReductionDims(),
threadReduced, multiReduceOp.getKind(), acc, signature[resVector]);
rewriter.replaceOp(multiReduceOp, subgroupReduced);
return success();
}

@@ -569,10 +587,185 @@ struct DistributeMultiReduction final

res = rewriter.create<vector::InsertOp>(loc, extracted, res, i);
}

return res;
}

// The reductions across subgroups are performed as follows:
// 1) Recover the rank of the subgroup-local result to match the
//    input vector.
// 2) Write the subgroup-local reduced vector to shared memory.
// 3) Read it back such that the partially reduced subgroup tile
//    is read as the element tile.
// 4) Perform a second reduction to complete the reduction.
Value doSubgroupReduction(PatternRewriter &rewriter, Location loc,
VectorValue srcVector, NestedLayoutAttr srcLayout,
ArrayRef<int64_t> reductionDims,
VectorValue threadReduced,
vector::CombiningKind kind, Value acc,
VectorLayoutInterface resLayout) const {
// Subgroup-local / thread-local vector.multi_reduce operations
// will remove the reduction dimensions by definition.
// e.g.:
// p1 x p2 x p3 x r2 x r1 --> p1 x p2 x p3
// However, the reduction is not complete until the inter-subgroup
// results are combined. Therefore, we need to maintain the rank to
// get them back to the SIMD domain and re-layout the vector.
// Thus, we re-insert the reduction dimensions in their original
// positions:
// p1 x p2 x p3 -> p1 x p2 x p3 x 1 x 1
int64_t rank = srcLayout.getRank();
SmallVector<int64_t> partialReducedDistributedShape =
srcLayout.getDistributedShape();
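// The distributed shape concatenates the batch, outer, and element
// tile shapes, i.e. three groups of `rank` dims; set the reduced
// dims to 1 in each group.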
for (int64_t tileGroupIdx : llvm::seq<int64_t>(3)) {
int64_t tileGroupOffset = tileGroupIdx * rank;
for (int64_t rDim : reductionDims) {
partialReducedDistributedShape[tileGroupOffset + rDim] = 1;
}
}
VectorType partialReducedDistributedType = VectorType::get(
partialReducedDistributedShape, srcVector.getType().getElementType());
Value isoRankThreadReduced = rewriter.create<vector::ShapeCastOp>(
loc, partialReducedDistributedType, threadReduced);

SmallVector<int64_t> preDistrShape =
srcLayout.getUndistributedPackedShape();
SmallVector<int64_t> partialReductionShape =
llvm::to_vector(srcVector.getType().getShape());
for (int64_t rDim : reductionDims) {
// The first #rank elements of the packed shape form the subgroup
// tile. Here we replace the input shape with the subgroup tile
// because every tile other than the subgroup tile has been reduced.
partialReductionShape[rDim] = preDistrShape[rDim];
}
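// E.g. (illustrative): a 32x64 input reduced along dim 1 with a
// subgroup tile of [1, 2] gives a 32x2 buffer: one partial value
// per subgroup along the reduced dim.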
auto workgroupMemoryAddressSpace = Attribute(gpu::AddressSpaceAttr::get(
rewriter.getContext(), gpu::AddressSpace::Workgroup));
MemRefType allocType = MemRefType::get(
partialReductionShape, srcVector.getType().getElementType(),
AffineMap(), workgroupMemoryAddressSpace);
auto alloc = rewriter.create<memref::AllocOp>(loc, allocType);
VectorType unDistributedType = VectorType::get(
partialReductionShape, srcVector.getType().getElementType());
Value undistrWrite = rewriter.create<IREE::VectorExt::ToSIMDOp>(
loc, unDistributedType, isoRankThreadReduced);
Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
SmallVector<Value> indices(unDistributedType.getRank(), c0);
SmallVector<bool> inBounds(unDistributedType.getRank(), true);
// Insert gpu.barrier to make sure the previous iteration of the
// batch loop has fully read the subgroup partial reductions.
rewriter.create<gpu::BarrierOp>(loc);
auto write = rewriter.create<vector::TransferWriteOp>(
loc, undistrWrite, alloc, indices, inBounds);
// Set the layout signature for the write.
// We need to set the layout on the srcVector / first operand.
auto unitAttr = UnitAttr::get(rewriter.getContext());
{
SmallVector<int64_t> subgroupTileLens =
llvm::to_vector(srcLayout.getSubgroupTile());
SmallVector<int64_t> batchTileLens =
llvm::to_vector(srcLayout.getBatchTile());
SmallVector<int64_t> outerTileLens =
llvm::to_vector(srcLayout.getOuterTile());
SmallVector<int64_t> threadTileLens =
llvm::to_vector(srcLayout.getThreadTile());
SmallVector<int64_t> elementTileLens =
llvm::to_vector(srcLayout.getElementTile());
SmallVector<int64_t> subgroupStrides =
llvm::to_vector(srcLayout.getSubgroupStrides());
SmallVector<int64_t> threadStrides =
llvm::to_vector(srcLayout.getThreadStrides());
// Replace the reduced tiles with unit dimension.
for (int64_t rDim : reductionDims) {
batchTileLens[rDim] = 1;
outerTileLens[rDim] = 1;
threadTileLens[rDim] = 1;
elementTileLens[rDim] = 1;
threadStrides[rDim] = 0;
}
auto interSubGroupLayout = IREE::VectorExt::NestedLayoutAttr::get(
rewriter.getContext(), subgroupTileLens, batchTileLens, outerTileLens,
threadTileLens, elementTileLens, subgroupStrides, threadStrides);
auto writeAttrs =
SmallVector<Attribute>(write->getNumOperands(), unitAttr);
writeAttrs[0] = interSubGroupLayout;
ArrayAttr writeOperandsAttr =
ArrayAttr::get(rewriter.getContext(), writeAttrs);
ArrayAttr writeResultsAttr = ArrayAttr::get(rewriter.getContext(), {});
setSignatureForRedistribution(rewriter, write.getOperation(),
writeOperandsAttr, writeResultsAttr);
}
// Insert gpu.barrier to ensure all partial reductions have been
// written before they are read back.
rewriter.create<gpu::BarrierOp>(write.getLoc());
auto read = rewriter.create<vector::TransferReadOp>(loc, unDistributedType,
alloc, indices);
// Create a new layout where the subgroup dims are squashed into
// the element tile.
IREE::VectorExt::NestedLayoutAttr intraSubGroupLayout;
{
// We intentionally set the subgroup tile to 1.
SmallVector<int64_t> subgroupTileLens =
llvm::to_vector(srcLayout.getSubgroupTile());
SmallVector<int64_t> batchTileLens =
llvm::to_vector(srcLayout.getBatchTile());
SmallVector<int64_t> outerTileLens =
llvm::to_vector(srcLayout.getOuterTile());
SmallVector<int64_t> threadTileLens =
llvm::to_vector(srcLayout.getThreadTile());
SmallVector<int64_t> elementTileLens =
llvm::to_vector(srcLayout.getElementTile());
SmallVector<int64_t> subgroupStrides =
llvm::to_vector(srcLayout.getSubgroupStrides());
SmallVector<int64_t> threadStrides =
llvm::to_vector(srcLayout.getThreadStrides());
for (int64_t rDim : reductionDims) {
subgroupTileLens[rDim] = 1;
batchTileLens[rDim] = 1;
outerTileLens[rDim] = 1;
threadTileLens[rDim] = 1;
// The partial reductions that were across subgroups will
// be loaded as the element tile. We can revisit whether this
// needs to be something else, such as the thread tile.
elementTileLens[rDim] = srcLayout.getSubgroupTile()[rDim];
subgroupStrides[rDim] = 0;
threadStrides[rDim] = 0;
}
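// E.g. (illustrative): a reduced dim with a subgroup tile of 2
// becomes an element tile of 2 here, so each thread reloads both
// partials and the second reduction can combine them locally.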
intraSubGroupLayout = IREE::VectorExt::NestedLayoutAttr::get(
rewriter.getContext(), subgroupTileLens, batchTileLens, outerTileLens,
threadTileLens, elementTileLens, subgroupStrides, threadStrides);
auto readAttrs = SmallVector<Attribute>(read->getNumOperands(), unitAttr);
ArrayAttr readOperandsAttr =
ArrayAttr::get(rewriter.getContext(), readAttrs);
ArrayAttr readResultsAttr =
ArrayAttr::get(rewriter.getContext(), {intraSubGroupLayout});
setSignatureForRedistribution(rewriter, read.getOperation(),
readOperandsAttr, readResultsAttr);
}

// A newly created reduction to complete the reduction
// of the data that was otherwise on different subgroups.
auto secondReduction = rewriter.create<vector::MultiDimReductionOp>(
loc, kind, read, acc, reductionDims);
{
auto reduceAttrs =
SmallVector<Attribute>(secondReduction->getNumOperands(), unitAttr);
reduceAttrs[0] = intraSubGroupLayout;
ArrayAttr reduceResultsAttr =
ArrayAttr::get(rewriter.getContext(), {unitAttr});
if (auto dstLayout = dyn_cast_or_null<NestedLayoutAttr>(resLayout)) {
reduceAttrs[1] = dstLayout;
reduceResultsAttr = ArrayAttr::get(rewriter.getContext(), {dstLayout});
}
ArrayAttr reduceOperandsAttr =
ArrayAttr::get(rewriter.getContext(), reduceAttrs);
setSignatureForRedistribution(rewriter, secondReduction.getOperation(),
reduceOperandsAttr, reduceResultsAttr);
}
return secondReduction.getResult();
}

int64_t subgroupSize;
int64_t maxBitsPerShuffle;
};
@@ -217,6 +217,7 @@ struct VectorDistributionListener : public RewriterBase::Listener {
void notifyOperationModified(Operation *op) override {
if (op->hasAttr(kVectorLayoutRedistributeAttrName) &&
op->hasAttrOfType<ArrayAttr>(kVectorLayoutFetcherStorageAttrName)) {
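// Drop the marker so the op is only queued for redistribution once.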
op->removeAttr(kVectorLayoutRedistributeAttrName);
toBeDistributed.push_back(op);
}
}
@@ -35,6 +35,7 @@ iree_lit_test_suite(
"gpu_lower_to_ukernels.mlir",
"gpu_nested_layout_contract_amdgpu.mlir",
"gpu_nested_layout_vector_distribution.mlir",
"gpu_nested_layout_vector_distribution_multi_reduce.mlir",
"gpu_nested_layout_vector_distribution_step.mlir",
"gpu_pad_operands.mlir",
"gpu_pipeline.mlir",
@@ -30,6 +30,7 @@ iree_lit_test_suite(
"gpu_lower_to_ukernels.mlir"
"gpu_nested_layout_contract_amdgpu.mlir"
"gpu_nested_layout_vector_distribution.mlir"
"gpu_nested_layout_vector_distribution_multi_reduce.mlir"
"gpu_nested_layout_vector_distribution_step.mlir"
"gpu_pack_to_instrinsics.mlir"
"gpu_pad_operands.mlir"
