Skip to content

Commit

Permalink
Refactors code (shapePerCTATile, isa<>)
Browse files Browse the repository at this point in the history
  • Loading branch information
hmalgewatta committed Nov 19, 2024
1 parent 7b20b6b commit 36c425f
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 25 deletions.
4 changes: 2 additions & 2 deletions test/Conversion/amd/invalid_extractslice_to_llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
// Invalid size
#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
tt.func @invalid_size_input(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) {
// expected-error @+1 {{sizes [256, 2] must be a multiple of shapePerCTA [256, 16]}}
// expected-error @+1 {{sizes [256, 2] must be a multiple of shapePerCTATile [256, 16]}}
%1 = amdgpu.extract_slice %arg0 [0,0] : tensor<256x128xi32, #blocked1> to tensor<256x2xi32, #blocked1>
tt.return
}
Expand Down Expand Up @@ -33,7 +33,7 @@ tt.func @invalid_size_input(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibili
// Invalid offset, not multiple of shapePerCTATile
#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
tt.func @invalid_offset_input(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) {
// expected-error @+1 {{offset [0, 5] must be a multiple of shapePerCTA [256, 16]}}
// expected-error @+1 {{offset [0, 5] must be a multiple of shapePerCTATile [256, 16]}}
%1 = amdgpu.extract_slice %arg0 [0,5] : tensor<256x128xi32, #blocked1> to tensor<256x16xi32, #blocked1>
tt.return
}
Expand Down
27 changes: 17 additions & 10 deletions third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,13 +75,16 @@ LogicalResult ExtractSliceOp::verify() {
return emitError("result layout must match source layout");

auto srcShape = srcTy.getShape();
auto shapePerCTA = mlir::triton::gpu::getShapePerCTATile(srcLayout, srcShape);
shapePerCTA[0] = std::min(static_cast<unsigned>(srcShape[0]), shapePerCTA[0]);
shapePerCTA[1] = std::min(static_cast<unsigned>(srcShape[1]), shapePerCTA[1]);
auto shapePerCTATile =
mlir::triton::gpu::getShapePerCTATile(srcLayout, srcShape);
shapePerCTATile[0] =
std::min(static_cast<unsigned>(srcShape[0]), shapePerCTATile[0]);
shapePerCTATile[1] =
std::min(static_cast<unsigned>(srcShape[1]), shapePerCTATile[1]);

// ExtractSlice only supports slicing where offsets and sizes are multiples of
// shapePerCTA. This condition ensures that slice has the same layout as the
// original tensor.
// shapePerCTATile. This condition ensures that slice has the same layout as
// the original tensor.

auto offsets = getStaticOffsets();
if (offsets.size() != 2) {
Expand Down Expand Up @@ -110,14 +113,18 @@ LogicalResult ExtractSliceOp::verify() {
sizes.push_back(resultDimSize);
}

if (sizes[0] % shapePerCTA[0] != 0 || sizes[1] % shapePerCTA[1] != 0) {
return emitError("incorrect static size");
if (sizes[0] % shapePerCTATile[0] != 0 ||
sizes[1] % shapePerCTATile[1] != 0) {
return emitError() << "sizes [" << sizes
<< "] must be a multiple of shapePerCTATile ["
<< shapePerCTATile << "]";
}

if (offsets[0] % shapePerCTA[0] != 0 || offsets[1] % shapePerCTA[1] != 0) {
if (offsets[0] % shapePerCTATile[0] != 0 ||
offsets[1] % shapePerCTATile[1] != 0) {
return emitError() << "offset [" << offsets
<< "] must be a multiple of shapePerCTA [" << shapePerCTA
<< "]";
<< "] must be a multiple of shapePerCTATile ["
<< shapePerCTATile << "]";
}

return success();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,11 @@ struct ExtractSliceOpConversion
auto order = triton::gpu::getOrder(srcLayout);

// Calculate valid total number of workers in each dimension
auto shapePerCTA = triton::gpu::getShapePerCTATile(srcLayout, srcShape);
shapePerCTA[0] =
std::min(static_cast<unsigned>(srcShape[0]), shapePerCTA[0]);
shapePerCTA[1] =
std::min(static_cast<unsigned>(srcShape[1]), shapePerCTA[1]);
auto shapePerCTATile = triton::gpu::getShapePerCTATile(srcLayout, srcShape);
shapePerCTATile[0] =
std::min(static_cast<unsigned>(srcShape[0]), shapePerCTATile[0]);
shapePerCTATile[1] =
std::min(static_cast<unsigned>(srcShape[1]), shapePerCTATile[1]);

// Rank == 2 checked in the verifier
SmallVector<int64_t, 2> sizes;
Expand All @@ -85,12 +85,12 @@ struct ExtractSliceOpConversion
auto offsets = op.getStaticOffsets();

// Calculate offsets and sizes in terms of CTA units.
std::array<int64_t,2> CTAOffsets{offsets[0] / shapePerCTA[0],
offsets[1] / shapePerCTA[1]};
std::array<int64_t,2> CTASizes{sizes[0] / shapePerCTA[0],
sizes[1] / shapePerCTA[1]};
std::array<int64_t,2> CTAPerShape{srcShape[0] / shapePerCTA[0],
srcShape[1] / shapePerCTA[1]};
std::array<int64_t, 2> CTAOffsets{offsets[0] / shapePerCTATile[0],
offsets[1] / shapePerCTATile[1]};
std::array<int64_t, 2> CTASizes{sizes[0] / shapePerCTATile[0],
sizes[1] / shapePerCTATile[1]};
std::array<int64_t, 2> CTAPerShape{srcShape[0] / shapePerCTATile[0],
srcShape[1] / shapePerCTATile[1]};

// The diagram above illustrates the graphical representation of the
// skipElems, tensorStride, and lastIdx variables.
Expand Down Expand Up @@ -124,8 +124,8 @@ struct ExtractSliceOpConversion
matchAndRewrite(amdgpu::ExtractSliceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto srcTy = op.getSource().getType();
if (isa<BlockedEncodingAttr>(op.getSource().getType().getEncoding()) ||
isa<AMDMfmaEncodingAttr>(op.getSource().getType().getEncoding())) {
if (isa<BlockedEncodingAttr, AMDMfmaEncodingAttr>(
op.getSource().getType().getEncoding())) {
return processLayout(op, adaptor, rewriter);
} else {
assert(false && "Unsupported layout in viewSlice.");
Expand Down

0 comments on commit 36c425f

Please sign in to comment.