[TTIG-TO-LLVM] Support row-vector broadcasts and make_range #2046

Merged: 11 commits, Sep 12, 2024
test/Conversion/intel/tritongpu_to_llvm_intel_advanced_path.mlir (17 additions, 0 deletions)

@@ -195,6 +195,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 :
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 16 : i32, triton_intel_gpu.min_sg_size = 16 : i32, triton_intel_gpu.support_dpas, triton_intel_gpu.support_sg_2d_block} {

// CHECK: llvm.func spir_funccc @_Z12get_group_idj(i32) -> i64 attributes {memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, no_unwind, will_return}
// CHECK: llvm.func spir_funccc @_Z22get_sub_group_local_idv() -> i32

// CHECK-LABEL: llvm.func spir_kernelcc @broadcast(
// CHECK-SAME: [[VAL_0:%.*]]: f32) -> vector<16xf32>
@@ -209,6 +210,22 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 :
tt.return %2 : tensor<16x16xf32>
}

// CHECK-LABEL: llvm.func spir_kernelcc @broadcast_range() -> vector<16xi32>
tt.func public @broadcast_range() -> tensor<16x16xi32> {
// CHECK: [[LAST_CONST:%.*]] = llvm.mlir.constant(15 : i32) : i32
// CHECK: [[RANGE:%.*]] = llvm.insertelement [[LAST_CONST]], {{%.*}}[[[LAST_CONST]] : i32] : vector<16xi32>
// CHECK: [[LANE_ID:%.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_idv()
// CHECK: [[EXTRACT:%.*]] = llvm.extractelement [[RANGE]][[[LANE_ID]] : i32] : vector<16xi32>
// CHECK: [[EMPTY:%.*]] = llvm.mlir.poison : vector<1xi32>
// CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: [[VEC:%.*]] = llvm.insertelement [[EXTRACT]], [[EMPTY]][[[ZERO]] : i32] : vector<1xi32>
// CHECK: llvm.shufflevector [[VEC]], [[EMPTY]] [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] : vector<1xi32>
%0 = tt.make_range {start = 0 : i32, end = 16 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #warp}>>
%1 = tt.expand_dims %0 {axis = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #warp}>> -> tensor<1x16xi32, #warp>
%2 = tt.broadcast %1 : tensor<1x16xi32, #warp> -> tensor<16x16xi32>
tt.return %2 : tensor<16x16xi32>
}
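
The new @broadcast_range test above captures the full row-vector path: the range is materialized into a vector<16xi32> once, each work-item extracts the element at its sub-group local ID, and that scalar is splatted across the rows it owns. A rough scalar model of the per-lane result, as a sketch only (the helper name is hypothetical and not part of the patch):

#include <array>
#include <cstdint>

// Hypothetical per-lane model of the @broadcast_range lowering: lane `laneId`
// (0 <= laneId < 16) extracts range[laneId] and replicates it across its row.
static std::array<int32_t, 16> broadcastRangeLane(int32_t laneId) {
  std::array<int32_t, 16> range{};
  for (int32_t i = 0; i < 16; ++i)
    range[i] = i;                     // tt.make_range {start = 0, end = 16}
  std::array<int32_t, 16> result{};
  result.fill(range[laneId]);         // extractelement at lane ID, then splat
  return result;
}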

// CHECK-LABEL: llvm.func spir_kernelcc @addptr(
// CHECK-SAME: [[VAL_0:%.*]]: !llvm.ptr<1>) -> !llvm.ptr<1> attributes {triton_gen.intel_reqd_sub_group_size = [16 : i32], triton_gen.max_work_group_size = [128 : i32, 1 : i32, 1 : i32]}
tt.func public @addptr(%arg0: !tt.ptr<f16>) -> !tt.ptr<f16> {
test/TritonIntelGPU/match-target-size.mlir (1 addition, 1 deletion)

@@ -491,7 +491,7 @@ tt.func public @attn_fwd(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.pt
// -----

#warp = #triton_intel_gpu.warp<{sizePerThread = [16, 64], threadsPerWarp = [1, 1], order = [1, 0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 1 : i32} {
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
tt.func public @_attn_fwd(%arg0: i32, %arg1: !tt.ptr<i32>) {
// COM: This op primes the map of known layouts
%cst = arith.constant dense<1> : tensor<16x64xi32, #warp>
third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp (63 additions, 2 deletions)

@@ -500,6 +500,12 @@ class SplatOpConversion : public ConvertTritonGPUOpToLLVMPattern<SplatOp> {
insert_element(vecTy, poison, adaptor.getSrc(),
rewriter.create<LLVM::ConstantOp>(loc, i32_ty, 0));
Type convertedTy = typeConverter->convertType(resultType);
if (!isa<VectorType>(convertedTy)) {
// On the advanced path, the type converter reduces 1-element vectors to
// their element type, making this splat a no-op.
rewriter.replaceOp(op, adaptor.getSrc());
return success();
}
int64_t num = cast<VectorType>(convertedTy).getNumElements();
SmallVector<int32_t> indices(num, 0);
Value result = rewriter.create<LLVM::ShuffleVectorOp>(
@@ -573,8 +579,36 @@ class BroadcastOpConversion
LogicalResult
matchAndRewrite(triton::BroadcastOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOp(op, adaptor.getSrc());
return success();
constexpr unsigned subgroupSize = 16;
Review thread on this line:

Contributor: would it be better to get subgroup size from module?

Contributor Author: Done.

Contributor Author: I'm querying the triton_intel_gpu.min_sg_size attribute now, is that the correct one?

Contributor: I would query triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod), as the minimum one may not be the selected one, although it always is on the advanced path.

Contributor Author: On the advanced path, threads-per-warp is always 1 IIRC.

Contributor: Not at this stage, I can see that multiple patterns query getThreadsPerWarp in this file.

Contributor Author: You're right of course. Fixed.


auto srcShape = op.getSrc().getType().getShape();
auto dstShape = op.getType().getShape();
assert(srcShape.size() == 2 && dstShape.size() == 2 &&
"Expected 2D broadcast");
assert(dstShape[1] == subgroupSize && "Unexpected result shape");

if (srcShape[0] == dstShape[0]) {
// Example: 16x1 --> 16x16 broadcast. Each thread in the subgroup will get
// the same value, so we use the source operand directly.
rewriter.replaceOp(op, adaptor.getSrc());
return success();
}

if (srcShape[1] == dstShape[1]) {
// Example: 1x16 --> 8x16 broadcast. We have to extract the element
// corresponding to the thread's lane ID and splat it to the desired
// result size.
Location loc = op.getLoc();
Value laneId = rewriter.create<TritonGEN::SubgroupLocalIdOp>(loc, i32_ty);
Value extract = rewriter.create<LLVM::ExtractElementOp>(
loc, adaptor.getSrc(), laneId);
Value splat =
rewriter.create<mlir::triton::SplatOp>(loc, op.getType(), extract);
rewriter.replaceOp(op, splat);
return success();
}

return failure();
}
};
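
As the review thread above converged on, the hard-coded sub-group size was later replaced by a query on the enclosing module. A minimal sketch of that query as it would sit at the top of matchAndRewrite, assuming the op is nested in a mlir::ModuleOp (the same call the MatchTargetSize change further below uses):

// Sketch: derive the sub-group size from the module instead of assuming 16.
auto mod = op->getParentOfType<mlir::ModuleOp>();
int subgroupSize = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);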

@@ -650,6 +684,32 @@ class AddPtrOpConversion : public ConvertTritonGPUOpToLLVMPattern<AddPtrOp> {
}
};

class MakeRangeOpConversion
: public ConvertTritonGPUOpToLLVMPattern<MakeRangeOp> {
public:
using ConvertTritonGPUOpToLLVMPattern<
MakeRangeOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(MakeRangeOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Note: On the default path, the lowering of `tt.make_range` takes the
// tensor layout into account. To that end, there is a dedicated lowering
// pattern in `MakeRangeOpToLLVM.cpp`. However, with the assumed dense
// layout in the advanced path, we can just emit a sequence of integers.

Location loc = op->getLoc();
Value vec = rewriter.create<LLVM::UndefOp>(
loc, getTypeConverter()->convertType(op.getType()));
for (int i = op.getStart(); i < op.getEnd(); ++i) {
auto valI = LLVM::createConstantI32(loc, rewriter, i);
vec = rewriter.create<LLVM::InsertElementOp>(loc, vec, valI, valI);
}

rewriter.replaceOp(op, vec);
return success();
}
};
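
For intuition, on the advanced path the converted type of the tensor<16xi32> range appears to be a plain vector<16xi32> (see the @broadcast_range test above), so the insert-element chain in this pattern amounts to an iota sequence. A standalone reference of those semantics, illustrative only and not part of the patch:

#include <cstdint>
#include <numeric>
#include <vector>

// Reference semantics of the MakeRangeOp lowering above: a dense sequence
// start, start+1, ..., end-1, one value per vector lane.
static std::vector<int32_t> makeRangeReference(int32_t start, int32_t end) {
  std::vector<int32_t> values(static_cast<size_t>(end - start));
  std::iota(values.begin(), values.end(), start);
  return values;
}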

} // namespace

void mlir::triton::intel::populateTritonOpsToLLVMPatterns(
@@ -670,4 +730,5 @@ void mlir::triton::intel::populateTritonOpsToLLVMPatterns(
patterns.add<ReduceOpConversion>(typeConverter, benefit);
patterns.add<SubGroupTransposeOpConversion>(typeConverter, benefit);
patterns.add<SplatOpConversion>(typeConverter, benefit);
patterns.add<MakeRangeOpConversion>(typeConverter, benefit);
}
@@ -1080,7 +1080,9 @@ void MatchTargetSizePass::transformBroadcastOp(tt::BroadcastOp op) {
}

void MatchTargetSizePass::transformMakeRangeOp(tt::MakeRangeOp op) {
constexpr unsigned subgroupSize = 16;
auto mod = op->getParentOfType<mlir::ModuleOp>();
int subgroupSize = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);

unsigned start = op.getStart();
unsigned end = op.getEnd();
assert(start == 0 && end % subgroupSize == 0 && "Unsupported range");