diff --git a/test/Conversion/intel/tritongpu_to_llvm_intel_advanced_path.mlir b/test/Conversion/intel/tritongpu_to_llvm_intel_advanced_path.mlir
index ed2a7b5484..911a898495 100644
--- a/test/Conversion/intel/tritongpu_to_llvm_intel_advanced_path.mlir
+++ b/test/Conversion/intel/tritongpu_to_llvm_intel_advanced_path.mlir
@@ -195,6 +195,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 :
 
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 16 : i32, triton_intel_gpu.min_sg_size = 16 : i32, triton_intel_gpu.support_dpas, triton_intel_gpu.support_sg_2d_block} {
   // CHECK: llvm.func spir_funccc @_Z12get_group_idj(i32) -> i64 attributes {memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, no_unwind, will_return}
+  // CHECK: llvm.func spir_funccc @_Z22get_sub_group_local_idv() -> i32
 
   // CHECK-LABEL: llvm.func spir_kernelcc @broadcast(
   // CHECK-SAME:    [[VAL_0:%.*]]: f32) -> vector<16xf32>
@@ -209,6 +210,22 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 :
     tt.return %2 : tensor<16x16xf32>
   }
 
+  // CHECK-LABEL: llvm.func spir_kernelcc @broadcast_range() -> vector<16xi32>
+  tt.func public @broadcast_range() -> tensor<16x16xi32> {
+    // CHECK: [[LAST_CONST:%.*]] = llvm.mlir.constant(15 : i32) : i32
+    // CHECK: [[RANGE:%.*]] = llvm.insertelement [[LAST_CONST]], {{%.*}}[[[LAST_CONST]] : i32] : vector<16xi32>
+    // CHECK: [[LANE_ID:%.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_idv()
+    // CHECK: [[EXTRACT:%.*]] = llvm.extractelement [[RANGE]][[[LANE_ID]] : i32] : vector<16xi32>
+    // CHECK: [[EMPTY:%.*]] = llvm.mlir.poison : vector<1xi32>
+    // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
+    // CHECK: [[VEC:%.*]] = llvm.insertelement [[EXTRACT]], [[EMPTY]][[[ZERO]] : i32] : vector<1xi32>
+    // CHECK: llvm.shufflevector [[VEC]], [[EMPTY]] [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] : vector<1xi32>
+    %0 = tt.make_range {start = 0 : i32, end = 16 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #warp}>>
+    %1 = tt.expand_dims %0 {axis = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #warp}>> -> tensor<1x16xi32, #warp>
+    %2 = tt.broadcast %1 : tensor<1x16xi32, #warp> -> tensor<16x16xi32>
+    tt.return %2 : tensor<16x16xi32>
+  }
+
   // CHECK-LABEL: llvm.func spir_kernelcc @addptr(
   // CHECK-SAME:    [[VAL_0:%.*]]: !llvm.ptr<1>) -> !llvm.ptr<1> attributes {triton_gen.intel_reqd_sub_group_size = [16 : i32], triton_gen.max_work_group_size = [128 : i32, 1 : i32, 1 : i32]}
   tt.func public @addptr(%arg0: !tt.ptr<f32>) -> !tt.ptr<f32> {
diff --git a/test/TritonIntelGPU/match-target-size.mlir b/test/TritonIntelGPU/match-target-size.mlir
index 1fe8fa90ff..9c2a5f5e9f 100644
--- a/test/TritonIntelGPU/match-target-size.mlir
+++ b/test/TritonIntelGPU/match-target-size.mlir
@@ -491,7 +491,7 @@ tt.func public @attn_fwd(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.pt
 // -----
 
 #warp = #triton_intel_gpu.warp<{sizePerThread = [16, 64], threadsPerWarp = [1, 1], order = [1, 0]}>
-module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 1 : i32} {
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
   tt.func public @_attn_fwd(%arg0: i32, %arg1: !tt.ptr<f32>) {
     // COM: This op primes the map of known layouts
     %cst = arith.constant dense<1> : tensor<16x64xi32, #warp>
diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp
index f23b098393..f4cabc7d02 100644
--- a/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp
+++ b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp
@@ -500,6 +500,12 @@ class SplatOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::SplatOp
         insert_element(vecTy, poison, adaptor.getSrc(),
                        rewriter.create<LLVM::ConstantOp>(loc, i32_ty, 0));
     Type convertedTy = typeConverter->convertType(resultType);
+    if (!isa<VectorType>(convertedTy)) {
+      // On the advanced path, the type converter reduces 1-element vectors to
+      // their element type, making this splat a no-op.
+      rewriter.replaceOp(op, adaptor.getSrc());
+      return success();
+    }
     int64_t num = cast<VectorType>(convertedTy).getNumElements();
     SmallVector<int32_t> indices(num, 0);
     Value result = rewriter.create<LLVM::ShuffleVectorOp>(
@@ -573,8 +579,36 @@ class BroadcastOpConversion
   LogicalResult
   matchAndRewrite(triton::BroadcastOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    rewriter.replaceOp(op, adaptor.getSrc());
-    return success();
+    constexpr unsigned subgroupSize = 16;
+
+    auto srcShape = op.getSrc().getType().getShape();
+    auto dstShape = op.getType().getShape();
+    assert(srcShape.size() == 2 && dstShape.size() == 2 &&
+           "Expected 2D broadcast");
+    assert(dstShape[1] == subgroupSize && "Unexpected result shape");
+
+    if (srcShape[0] == dstShape[0]) {
+      // Example: 16x1 --> 16x16 broadcast. Each thread in the subgroup will get
+      // the same value, so we use the source operand directly.
+      rewriter.replaceOp(op, adaptor.getSrc());
+      return success();
+    }
+
+    if (srcShape[1] == dstShape[1]) {
+      // Example: 1x16 --> 8x16 broadcast. We have to extract the element
+      // corresponding to the thread's lane ID and splat it to the desired
+      // result size.
+      Location loc = op.getLoc();
+      Value laneId = rewriter.create<TritonGEN::SubgroupLocalIdOp>(loc, i32_ty);
+      Value extract = rewriter.create<LLVM::ExtractElementOp>(
+          loc, adaptor.getSrc(), laneId);
+      Value splat =
+          rewriter.create<triton::SplatOp>(loc, op.getType(), extract);
+      rewriter.replaceOp(op, splat);
+      return success();
+    }
+
+    return failure();
   }
 };
 
@@ -650,6 +684,32 @@ class AddPtrOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::AddPtr
   }
 };
 
+class MakeRangeOpConversion
+    : public ConvertTritonGPUOpToLLVMPattern<MakeRangeOp> {
+public:
+  using ConvertTritonGPUOpToLLVMPattern<
+      MakeRangeOp>::ConvertTritonGPUOpToLLVMPattern;
+  LogicalResult
+  matchAndRewrite(MakeRangeOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Note: On the default path, the lowering of `tt.make_range` takes the
+    // tensor layout into account. To that end, there is a dedicated lowering
+    // pattern in `MakeRangeOpToLLVM.cpp`. However, with the assumed dense
+    // layout in the advanced path, we can just emit a sequence of integers.
+
+    Location loc = op->getLoc();
+    Value vec = rewriter.create<LLVM::PoisonOp>(
+        loc, getTypeConverter()->convertType(op.getType()));
+    for (int i = op.getStart(); i < op.getEnd(); ++i) {
+      auto valI = LLVM::createConstantI32(loc, rewriter, i);
+      vec = rewriter.create<LLVM::InsertElementOp>(loc, vec, valI, valI);
+    }
+
+    rewriter.replaceOp(op, vec);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::triton::intel::populateTritonOpsToLLVMPatterns(
@@ -670,4 +730,5 @@ void mlir::triton::intel::populateTritonOpsToLLVMPatterns(
   patterns.add<SplatOpConversion>(typeConverter, benefit);
   patterns.add<BroadcastOpConversion>(typeConverter, benefit);
   patterns.add<AddPtrOpConversion>(typeConverter, benefit);
+  patterns.add<MakeRangeOpConversion>(typeConverter, benefit);
 }
diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/MatchTargetSize.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/MatchTargetSize.cpp
index 66e64060fb..61b075b863 100644
--- a/third_party/intel/lib/TritonIntelGPUTransforms/MatchTargetSize.cpp
+++ b/third_party/intel/lib/TritonIntelGPUTransforms/MatchTargetSize.cpp
@@ -1080,7 +1080,9 @@ void MatchTargetSizePass::transformBroadcastOp(tt::BroadcastOp op) {
 }
 
 void MatchTargetSizePass::transformMakeRangeOp(tt::MakeRangeOp op) {
-  constexpr unsigned subgroupSize = 16;
+  auto mod = op->getParentOfType<ModuleOp>();
+  int subgroupSize = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
+
   unsigned start = op.getStart();
   unsigned end = op.getEnd();
   assert(start == 0 && end % subgroupSize == 0 && "Unsupported range");
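
Reviewer note (not part of the patch): the `broadcast_range` test pins down the intended semantics — `tt.make_range` materializes 0..15 as a vector, and the 1x16 -> 16x16 `tt.broadcast` hands each of the 16 subgroup lanes the element at its own lane ID, splatted across the lane's elements. The following standalone C++ sketch simulates that behavior on the host; the fixed subgroup size of 16 and all names are illustrative assumptions, not backend APIs:

```cpp
#include <array>
#include <cstdio>

// Host-side simulation of the advanced-path lowering exercised by the
// broadcast_range test:
//   %0 = tt.make_range {start = 0, end = 16}   -> insertelement chain
//   %2 = tt.broadcast 1x16 -> 16x16            -> extract at lane id + splat
int main() {
  constexpr unsigned subgroupSize = 16; // matches triton_gpu.threads-per-warp

  // tt.make_range: MakeRangeOpConversion unrolls start..end into a vector
  // of i32 constants.
  std::array<int, subgroupSize> range{};
  for (unsigned i = 0; i < subgroupSize; ++i)
    range[i] = static_cast<int>(i);

  // tt.broadcast (1x16 -> 16x16): each lane extracts range[laneId]
  // (llvm.extractelement) and replicates it over all of its elements
  // (llvm.insertelement into a 1-element vector + llvm.shufflevector with
  // all-zero indices).
  for (unsigned laneId = 0; laneId < subgroupSize; ++laneId) {
    std::array<int, subgroupSize> perLane{};
    perLane.fill(range[laneId]);
    std::printf("lane %2u: all %u elements == %2d\n", laneId, subgroupSize,
                perLane[0]);
  }
  return 0;
}
```

Each lane ends up holding its own lane ID replicated sixteen times, which is exactly the pattern the FileCheck lines verify: `llvm.extractelement` at `[[LANE_ID]]` followed by the splat via `llvm.insertelement` and `llvm.shufflevector`.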