[TTIG-TO-LLVM] Support row-vector broadcasts and make_range #2046

Merged: 11 commits, Sep 12, 2024
17 changes: 17 additions & 0 deletions test/Conversion/intel/tritongpu_to_llvm_intel_advanced_path.mlir
@@ -195,6 +195,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 :
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 16 : i32, triton_intel_gpu.min_sg_size = 16 : i32, triton_intel_gpu.support_dpas, triton_intel_gpu.support_sg_2d_block} {

// CHECK: llvm.func spir_funccc @_Z12get_group_idj(i32) -> i64 attributes {passthrough = ["nounwind", "willreturn", ["memory", "0"]]}
// CHECK: llvm.func spir_funccc @_Z22get_sub_group_local_idv() -> i32

// CHECK-LABEL: llvm.func spir_kernelcc @broadcast(
// CHECK-SAME: [[VAL_0:%.*]]: f32) -> vector<16xf32>
@@ -209,6 +210,22 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 :
tt.return %2 : tensor<16x16xf32>
}

// CHECK-LABEL: llvm.func spir_kernelcc @broadcast_range() -> vector<16xi32>
tt.func public @broadcast_range() -> tensor<16x16xi32> {
// CHECK: [[LAST_CONST:%.*]] = llvm.mlir.constant(15 : i32) : i32
// CHECK: [[RANGE:%.*]] = llvm.insertelement [[LAST_CONST]], {{%.*}}[[[LAST_CONST]] : i32] : vector<16xi32>
// CHECK: [[LANE_ID:%.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_idv()
// CHECK: [[EXTRACT:%.*]] = llvm.extractelement [[RANGE]][[[LANE_ID]] : i32] : vector<16xi32>
// CHECK: [[EMPTY:%.*]] = llvm.mlir.poison : vector<1xi32>
// CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: [[VEC:%.*]] = llvm.insertelement [[EXTRACT]], [[EMPTY]][[[ZERO]] : i32] : vector<1xi32>
// CHECK: llvm.shufflevector [[VEC]], [[EMPTY]] [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] : vector<1xi32>
%0 = tt.make_range {start = 0 : i32, end = 16 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #warp}>>
%1 = tt.expand_dims %0 {axis = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #warp}>> -> tensor<1x16xi32, #warp>
%2 = tt.broadcast %1 : tensor<1x16xi32, #warp> -> tensor<16x16xi32>
tt.return %2 : tensor<16x16xi32>
}

// CHECK-LABEL: llvm.func spir_kernelcc @addptr(
// CHECK-SAME: [[VAL_0:%.*]]: !llvm.ptr<1>) -> !llvm.ptr<1> attributes {triton_gen.intel_reqd_sub_group_size = [16 : i32], triton_gen.max_work_group_size = [128 : i32, 1 : i32, 1 : i32]}
tt.func public @addptr(%arg0: !tt.ptr<f16>) -> !tt.ptr<f16> {
(next file in the diff; file name not shown in this excerpt)
@@ -17,7 +17,8 @@ void populateArithOpsToLLVMPatterns(

void populateTritonOpsToLLVMPatterns(
TritonIntelGPUToLLVMTypeConverter &typeConverter,
RewritePatternSet &patterns, PatternBenefit benefit);
RewritePatternSet &patterns, PatternBenefit benefit,
bool isAdvancedPathEnabled);

void populateBarrierOpToLLVMPatterns(
TritonIntelGPUToLLVMTypeConverter &typeConverter,
3 changes: 2 additions & 1 deletion third_party/intel/lib/TritonIntelGPUToLLVM/PipelineManager.h
@@ -221,7 +221,8 @@ class TritonGPUToLLVMPipelineManager {
intel::populateBF16CastsLLVMPatterns(typeConverter, patterns, benefit);
intel::populateControlFlowOpToLLVMPattern(typeConverter, patterns,
benefit);
intel::populateTritonOpsToLLVMPatterns(typeConverter, patterns, benefit);
intel::populateTritonOpsToLLVMPatterns(typeConverter, patterns, benefit,
isAdvancedPathEnabled);
} else {
intel::populateConvertLayoutOpToLLVMPatterns(typeConverter, targetInfo,
patterns, benefit);
71 changes: 68 additions & 3 deletions third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp
@@ -499,6 +499,12 @@ class SplatOpConversion : public ConvertTritonGPUOpToLLVMPattern<SplatOp> {
insert_element(vecTy, poison, adaptor.getSrc(),
rewriter.create<LLVM::ConstantOp>(loc, i32_ty, 0));
Type convertedTy = typeConverter->convertType(resultType);
if (!isa<VectorType>(convertedTy)) {
// On the advanced path, the type converter reduces 1-element vectors to
// their element type, making this splat a no-op.
rewriter.replaceOp(op, adaptor.getSrc());
return success();
}
int64_t num = cast<VectorType>(convertedTy).getNumElements();
SmallVector<int32_t> indices(num, 0);
Value result = rewriter.create<LLVM::ShuffleVectorOp>(
@@ -573,8 +579,39 @@ class BroadcastOpConversion
LogicalResult
matchAndRewrite(triton::BroadcastOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOp(op, adaptor.getSrc());
return success();
constexpr unsigned subgroupSize = 16;
Review thread:

Contributor: would it be better to get subgroup size from module?

Contributor Author: Done.

Contributor Author: I'm querying the triton_intel_gpu.min_sg_size attribute now, is that the correct one?

Contributor: I would query triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod), as the minimum one may not be the selected one, although it always is on the advanced path.

Contributor Author: On the advanced path, threads-per-warp is always 1 IIRC.

Contributor: Not at this stage, I can see that multiple patterns query getThreadsPerWarp in this file.

Contributor Author: You're right of course. Fixed.

auto srcShape = op.getSrc().getType().getShape();
auto dstShape = op.getType().getShape();
assert(srcShape.size() == 2 && dstShape.size() == 2 &&
"Expected 2D broadcast");
assert(dstShape[1] == subgroupSize && "Unexpected result shape");

if (srcShape[0] == dstShape[0]) {
// Example: 16x1 --> 16x16 broadcast. Each thread in the subgroup will get
// the same value, so we use the source operand directly.
rewriter.replaceOp(op, adaptor.getSrc());
return success();
}

if (srcShape[1] == dstShape[1]) {
// Example: 1x16 --> 8x16 broadcast. We have to extract the element
// corresponding to the thread's lane ID and splat it to the desired
// result size.
auto loc = op.getLoc();
Value laneId = rewriter.create<mlir::gpu::LaneIdOp>(
loc, /*upper_bound=*/IntegerAttr{});
Value cast = rewriter.create<arith::IndexCastOp>(
loc, rewriter.getI32Type(), laneId);
Value extract =
rewriter.create<LLVM::ExtractElementOp>(loc, adaptor.getSrc(), cast);
Value splat =
rewriter.create<mlir::triton::SplatOp>(loc, op.getType(), extract);
rewriter.replaceOp(op, splat);
return success();
}

return failure();
}
};
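
Following the review thread above, here is a minimal sketch of how the hard-coded subgroup size could be derived from the enclosing module instead. The getThreadsPerWarp call is the reviewer's suggestion; the min_sg_size fallback mirrors the triton_intel_gpu.min_sg_size attribute visible in the test file. Exact signatures are assumptions, not code from this PR.

// Illustration only: query the subgroup size from module attributes rather
// than hard-coding 16.
static unsigned getSubgroupSize(mlir::Operation *op) {
  auto mod = op->getParentOfType<mlir::ModuleOp>();
  // The threads-per-warp value selected for this module; other patterns in
  // this file query it the same way, per the review discussion.
  int threadsPerWarp = mlir::triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
  if (threadsPerWarp > 0)
    return threadsPerWarp;
  // Alternative the author tried first: the minimum subgroup size the Intel
  // backend advertises on the module.
  if (auto minSG = mod->getAttrOfType<mlir::IntegerAttr>("triton_intel_gpu.min_sg_size"))
    return minSG.getInt();
  return 16; // conservative default matching the constant above
}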

@@ -598,11 +635,37 @@ class AddPtrOpConversion : public ConvertTritonGPUOpToLLVMPattern<AddPtrOp> {
}
};

class MakeRangeOpConversion
: public ConvertTritonGPUOpToLLVMPattern<MakeRangeOp> {
public:
using ConvertTritonGPUOpToLLVMPattern<
MakeRangeOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(MakeRangeOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// FIXME: The real lowering has to take the layout into account. Here, we're
// just emitting a sequence of ints. Use
// `third_party/intel/lib/TritonIntelGPUToLLVM/MakeRangeOpToLLVM.cpp`
// instead!
Review thread:

Contributor: Typically we have a separate conversion for Triton ops; that's why this file exists. What is the meaning of this comment here?

Contributor Author: I've rephrased the comment a bit to explain that the lowering to a sequence of ints is the correct lowering for the advanced path, assuming dense layouts there.

auto loc = op->getLoc();
Value vec = rewriter.create<LLVM::UndefOp>(
loc, getTypeConverter()->convertType(op.getType()));
for (int i = op.getStart(); i < op.getEnd(); ++i) {
auto valI = LLVM::createConstantI32(loc, rewriter, i);
vec = rewriter.create<LLVM::InsertElementOp>(loc, vec, valI, valI);
}

rewriter.replaceOp(op, vec);
return success();
}
};
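
For intuition, here is a small, purely illustrative C++ model (no MLIR, not part of this diff; all names below are invented) of what the advanced-path lowering computes: make_range materializes the dense sequence start..end-1, and the row-vector broadcast lets each lane extract the one element it owns and splat it across its per-lane result, matching the CHECK lines in the test file above.

#include <cstdint>
#include <vector>

// Mirrors the llvm.insertelement loop in MakeRangeOpConversion.
std::vector<int32_t> makeRangeModel(int32_t start, int32_t end) {
  std::vector<int32_t> vec;
  for (int32_t i = start; i < end; ++i)
    vec.push_back(i);
  return vec;
}

// Mirrors the 1x16 -> Nx16 case in BroadcastOpConversion: lane `laneId`
// reads row[laneId] and splats it into a result of length `rows`.
std::vector<int32_t> rowBroadcastModel(const std::vector<int32_t> &row,
                                       unsigned laneId, unsigned rows) {
  return std::vector<int32_t>(rows, row[laneId]);
}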

} // namespace

void mlir::triton::intel::populateTritonOpsToLLVMPatterns(
TritonIntelGPUToLLVMTypeConverter &typeConverter,
RewritePatternSet &patterns, PatternBenefit benefit) {
RewritePatternSet &patterns, PatternBenefit benefit,
bool isAdvancedPathEnabled) {
patterns.add<AddPtrOpConversion>(typeConverter, benefit);
patterns.add<AdvanceOpConversion>(typeConverter, benefit);
patterns.add<BroadcastOpConversion>(typeConverter, benefit);
Expand All @@ -617,4 +680,6 @@ void mlir::triton::intel::populateTritonOpsToLLVMPatterns(
patterns.add<MakeTensorPtrOpConversion>(typeConverter, benefit);
patterns.add<ReduceOpConversion>(typeConverter, benefit);
patterns.add<SplatOpConversion>(typeConverter, benefit);
if (isAdvancedPathEnabled)
patterns.add<MakeRangeOpConversion>(typeConverter, benefit);
}
9 changes: 7 additions & 2 deletions third_party/intel/lib/TritonIntelGPUToLLVM/TypeConverter.cpp
@@ -30,9 +30,14 @@ TritonIntelGPUToLLVMTypeConverter::TritonIntelGPUToLLVMTypeConverter(
addConversion([&](mlir::RankedTensorType type) -> mlir::Type {
unsigned num = type.getNumElements();
Type elmTy = type.getElementType();
if (!type.getEncoding() ||
isa<mlir::triton::gpu::DotOperandEncodingAttr>(type.getEncoding()))
if ((!type.getEncoding() ||
isa<mlir::triton::gpu::DotOperandEncodingAttr>(
type.getEncoding())) &&
// FIXME: Intended to exclude row vectors occurring in the attention
// mask computation; probably won't work in general.
!(type.getElementType().isInteger(32) && type.getShape()[0] == 1)) {
num /= 16;
}
if (num == 1)
return elmTy;
return mlir::VectorType::get(num, elmTy);
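
To make the amended conversion rule concrete, here is a hedged sketch of the results one would expect from this converter, assuming the hard-coded subgroup size of 16 and the converter's own header being included; it is an illustration, not a unit test from this PR.

#include <cassert>
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
// Also assumes the header declaring TritonIntelGPUToLLVMTypeConverter.

void checkExpectedConversions(TritonIntelGPUToLLVMTypeConverter &converter,
                              mlir::MLIRContext &ctx) {
  mlir::Builder b(&ctx);
  mlir::Type f32 = b.getF32Type();
  mlir::Type i32 = b.getI32Type();
  // tensor<16x16xf32> without encoding: 256 elements / 16 lanes -> vector<16xf32>.
  assert(converter.convertType(mlir::RankedTensorType::get({16, 16}, f32)) ==
         mlir::VectorType::get(16, f32));
  // tensor<16x1xf32> column vector: 16 / 16 -> a single f32 scalar.
  assert(converter.convertType(mlir::RankedTensorType::get({16, 1}, f32)) == f32);
  // tensor<1x16xi32> row vector: exempted by the FIXME above -> vector<16xi32>.
  assert(converter.convertType(mlir::RankedTensorType::get({1, 16}, i32)) ==
         mlir::VectorType::get(16, i32));
}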