From fbeb3a671de0a8fb052b978e67f2787375885900 Mon Sep 17 00:00:00 2001 From: Giuseppe Rossini Date: Thu, 12 Sep 2024 14:21:16 +0100 Subject: [PATCH] Introduce support for buffer operations --- include/triton/Tools/Sys/GetEnv.hpp | 1 + python/triton/compiler/code_generator.py | 27 +- python/triton/compiler/compiler.py | 11 +- python/triton/runtime/jit.py | 8 +- python/tutorials/03-matrix-multiplication.py | 61 +++-- test/Conversion/amd/buffer_load_store.mlir | 39 +++ .../TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp | 178 ++++++++----- .../amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp | 3 +- .../TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp | 7 +- .../amd/lib/TritonAMDGPUToLLVM/Utility.cpp | 237 +++++++++--------- .../amd/lib/TritonAMDGPUToLLVM/Utility.h | 10 +- .../CanonicalizePointers.cpp | 3 +- 12 files changed, 330 insertions(+), 255 deletions(-) create mode 100644 test/Conversion/amd/buffer_load_store.mlir diff --git a/include/triton/Tools/Sys/GetEnv.hpp b/include/triton/Tools/Sys/GetEnv.hpp index 43e7df13585c..e5132b6d36e5 100644 --- a/include/triton/Tools/Sys/GetEnv.hpp +++ b/include/triton/Tools/Sys/GetEnv.hpp @@ -13,6 +13,7 @@ namespace mlir::triton { inline const std::set CACHE_INVALIDATING_ENV_VARS = { // clang-format off "AMDGCN_ENABLE_DUMP", + "AMDGCN_USE_BUFFER_OPS", "DISABLE_FAST_REDUCTION", "DISABLE_LLVM_OPT", "DISABLE_MMA_V3", diff --git a/python/triton/compiler/code_generator.py b/python/triton/compiler/code_generator.py index f5f0d580f129..3e63893f44d6 100644 --- a/python/triton/compiler/code_generator.py +++ b/python/triton/compiler/code_generator.py @@ -62,24 +62,6 @@ def _is_triton_scalar(o: Any) -> bool: def _is_list_like(o: Any) -> bool: return isinstance(o, (list, tuple)) - -# def _convert_elem_to_ir_value(builder, elem, require_i64): -# if isinstance(elem, int): -# elem = tl.constexpr(elem) -# if isinstance(elem, constexpr): -# if require_i64: -# assert -2**63 <= elem.value < 2**63, f"Block pointers only support 64 bit `shape/strides`, " \ -# f"got a value {elem.value} which is out of the range" -# return builder.get_int64(elem.value) -# else: -# assert -2**31 <= elem.value < 2**31, f"Block pointers only support 32 bit `offsets/block_shape`, " \ -# f"got a value {elem.value} which is out of the range" -# return builder.get_int32(elem.value) -# elif isinstance(elem, tensor): -# return elem.handle -# assert False, f"Unsupported element type in shape/strides/offsets: {type(elem)}" - - def _check_fn_args(node, fn, args): if fn.noinline: for idx, arg in enumerate(args): @@ -452,13 +434,6 @@ def visit_FunctionDef(self, node): self.set_value(arg_name, arg_value) self.builder.set_insertion_point_to_start(entry) - # if len(arg_values) > 4: - # if arg_values[3].dtype==language.int32: - # c = self.builder.create_icmpSGE(arg_values[3].handle, self.builder.get_int32(0)) - # self.builder.create_assume(c) - - # elif arg_values[3].dtype==language.int64: - # c = self.builder.create_icmpSGE(arg_values[3].handle, self.builder.get_int64(0)) # visit function body self.visit_compound_statement(node.body) # finalize function @@ -628,7 +603,7 @@ def visit_then_else_blocks(self, node, liveins, then_block, else_block): then_defs[name] = liveins[name] # variables that are both in then and else but not in liveins # TODO: could probably be cleaned up - for name in sorted(then_defs.keys() & else_defs.keys()): + for name in then_defs.keys() & else_defs.keys(): if name in names: continue then_ty = then_defs[name].type diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index 
af33ff8705e0..3e575f56a5b3 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -35,15 +35,16 @@ def __post_init__(self): self.non_negative = set() def to_dict(self): - return {'divisible_by_16': list(self.divisible_by_16), 'equal_to_1': list(self.equal_to_1), 'within_2gb' : list(self.within_2gb), 'non_negative' : list(self.non_negative)} + return { + 'divisible_by_16': list(self.divisible_by_16), 'equal_to_1': list(self.equal_to_1), 'within_2gb': + list(self.within_2gb), 'non_negative': list(self.non_negative) + } @staticmethod def from_dict(data): return AttrsDescriptor(divisible_by_16=set(data.get('divisible_by_16', [])), - equal_to_1=set(data.get('equal_to_1', [])), - within_2gb=set(data.get('within_2gb', [])), - non_negative=set(data.get('non_negative', [])) - ) + equal_to_1=set(data.get('equal_to_1', [])), within_2gb=set(data.get('within_2gb', [])), + non_negative=set(data.get('non_negative', []))) def hash(self): key = str([sorted(x) for x in self.__dict__.values()]) diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index 8da9521a7674..ccf34bd2e8a7 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -502,13 +502,9 @@ def is_divisible_by_16(x): within_2gb = { param.num for param, arg in zip(self.params, args) - if isinstance(arg, torch.Tensor) and sys.getsizeof(arg.untyped_storage()) < 2**31 - } - non_negative = { - param.num - for param, arg in zip(self.params, args) - if isinstance(arg, int) and arg >= 0 + if isinstance(arg, torch.Tensor) and sys.getsizeof(arg.untyped_storage()) < 2**31 #=MAX_INT32 } + non_negative = {param.num for param, arg in zip(self.params, args) if isinstance(arg, int) and arg >= 0} # folded equal_to_1 and None # TODO: method to collect all folded args return AttrsDescriptor(tuple(divisible_by_16), tuple(equal_to_1), tuple(within_2gb), tuple(non_negative)) diff --git a/python/tutorials/03-matrix-multiplication.py b/python/tutorials/03-matrix-multiplication.py index 23339966f50a..91f751207b8e 100644 --- a/python/tutorials/03-matrix-multiplication.py +++ b/python/tutorials/03-matrix-multiplication.py @@ -204,21 +204,21 @@ def get_cuda_autotune_config(): def get_hip_autotune_config(): return [ - # triton.Config( - # {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 1, 'waves_per_eu': 2}, - # num_warps=4, num_stages=0), - # triton.Config( - # {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 4, 'waves_per_eu': 2}, - # num_warps=8, num_stages=0), - # triton.Config( - # {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 2}, - # num_warps=8, num_stages=0), + triton.Config( + {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 1, 'waves_per_eu': 2}, + num_warps=4, num_stages=0), + triton.Config( + {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 4, 'waves_per_eu': 2}, + num_warps=8, num_stages=0), + triton.Config( + {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 2}, + num_warps=8, num_stages=0), triton.Config( {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'waves_per_eu': 3}, num_warps=4, num_stages=0), - # triton.Config( - # {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 8}, - # num_warps=4, num_stages=0), + triton.Config( + {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 
'GROUP_SIZE_M': 1, 'waves_per_eu': 8}, + num_warps=4, num_stages=0), ] @@ -269,7 +269,6 @@ def matmul_kernel( group_id = pid // num_pid_in_group first_pid_m = group_id * GROUP_SIZE_M group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - tl.assume(group_size_m > 0) pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) pid_n = (pid % num_pid_in_group) // group_size_m @@ -422,21 +421,21 @@ def matmul(a, b, activation=""): )) -# @triton.testing.perf_report(configs) -# def benchmark(M, N, K, provider, fp8_inputs): -# a = torch.randn((M, K), device='cuda', dtype=torch.float16) -# b = torch.randn((K, N), device='cuda', dtype=torch.float16) -# if TORCH_HAS_FP8 and fp8_inputs: -# a = a.to(torch.float8_e5m2) -# b = b.T -# b = b.to(torch.float8_e5m2) -# quantiles = [0.5, 0.2, 0.8] -# if provider == ref_lib.lower(): -# ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles) -# if provider == 'triton': -# ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b), quantiles=quantiles) -# perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3) -# return perf(ms), perf(max_ms), perf(min_ms) - - -# benchmark.run(show_plots=True, print_data=True) +@triton.testing.perf_report(configs) +def benchmark(M, N, K, provider, fp8_inputs): + a = torch.randn((M, K), device='cuda', dtype=torch.float16) + b = torch.randn((K, N), device='cuda', dtype=torch.float16) + if TORCH_HAS_FP8 and fp8_inputs: + a = a.to(torch.float8_e5m2) + b = b.T + b = b.to(torch.float8_e5m2) + quantiles = [0.5, 0.2, 0.8] + if provider == ref_lib.lower(): + ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles) + if provider == 'triton': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b), quantiles=quantiles) + perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3) + return perf(ms), perf(max_ms), perf(min_ms) + + +benchmark.run(show_plots=True, print_data=True) diff --git a/test/Conversion/amd/buffer_load_store.mlir b/test/Conversion/amd/buffer_load_store.mlir new file mode 100644 index 000000000000..c99983c046d6 --- /dev/null +++ b/test/Conversion/amd/buffer_load_store.mlir @@ -0,0 +1,39 @@ +// RUN: AMDGCN_USE_BUFFER_OPS=1 triton-opt %s -split-input-file --convert-triton-amdgpu-to-llvm=arch=gfx942 --convert-builtin-func-to-llvm | FileCheck %s + +#blocked0 = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { + // CHECK-LABEL: buffer_load_store_vec8 + tt.func @buffer_load_store_vec8(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32) { + %c256_i32 = arith.constant 256 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c256_i32 : i32 + %2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0> + %3 = tt.splat %1 : i32 -> tensor<256xi32, #blocked0> + %4 = arith.addi %3, %2 : tensor<256xi32, #blocked0> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<256x!tt.ptr, #blocked0> + %6 = tt.addptr %5, %4 : tensor<256x!tt.ptr, #blocked0>, tensor<256xi32, #blocked0> + %7 = tt.splat %arg1 : !tt.ptr -> tensor<256x!tt.ptr, #blocked0> + %8 = tt.addptr %7, %4 : tensor<256x!tt.ptr, #blocked0>, tensor<256xi32, #blocked0> + // Load 8 elements from A with two vectorized load instruction + // CHECK-COUNT-5: llvm.select + // CHECK: %[[mask0:.*]] = 
llvm.select + // CHECK: rocdl.raw.ptr.buffer.load {{.*}}, %[[mask0]] + // CHECK: %[[mask1:.*]] = llvm.select + // CHECK: rocdl.raw.ptr.buffer.load {{.*}}, %[[mask1]] + // CHECK: %[[mask2:.*]] = llvm.select + // CHECK: rocdl.raw.ptr.buffer.load {{.*}}, %[[mask2]] + // CHECK: %[[mask3:.*]] = llvm.select + // CHECK: rocdl.raw.ptr.buffer.load {{.*}}, %[[mask3]] + %9 = tt.load %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x!tt.ptr, #blocked0> + %10 = tt.load %8 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x!tt.ptr, #blocked0> + %11 = arith.addf %9, %10 : tensor<256xf32, #blocked0> + // CHECK: %[[mask4:.*]] = llvm.select + // CHECK: rocdl.raw.ptr.buffer.store{{.*}}, {{.*}}, %[[mask4]] + // CHECK: %[[mask5:.*]] = llvm.select + // CHECK: rocdl.raw.ptr.buffer.store{{.*}}, {{.*}}, %[[mask5]] + %12 = tt.splat %arg2 : !tt.ptr -> tensor<256x!tt.ptr, #blocked0> + %13 = tt.addptr %12, %4 : tensor<256x!tt.ptr, #blocked0>, tensor<256xi32, #blocked0> + tt.store %13, %11 : tensor<256x!tt.ptr, #blocked0> + tt.return + } +} diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp index 15f3a56f9eb4..260be29ae6b2 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -1,19 +1,23 @@ #include "PatternTritonGPUOpToLLVM.h" #include "TargetInfo.h" #include "Utility.h" -#include "llvm/ADT/TypeSwitch.h" -#include "llvm/Support/Casting.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/Block.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Value.h" +#include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Transforms/DialectConversion.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Tools/Sys/GetEnv.hpp" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Casting.h" using namespace mlir; using namespace mlir::triton::gpu; @@ -102,8 +106,10 @@ Value redundantDataMask(Type valueTy, ConversionPatternRewriter &rewriter, // Contains some helper functions for both Load and Store conversions. 
struct LoadStoreConversionBase { explicit LoadStoreConversionBase(const AMD::TargetInfo &targetInfo, - ModuleAxisInfoAnalysis &axisAnalysisPass, const DenseSet& assumptions) - : targetInfo(targetInfo), axisAnalysisPass(axisAnalysisPass), assumptions(assumptions) {} + ModuleAxisInfoAnalysis &axisAnalysisPass, + const DenseSet &assumptions) + : targetInfo(targetInfo), axisAnalysisPass(axisAnalysisPass), + assumptions(assumptions) {} unsigned getContiguity(Value ptr) const { auto tensorTy = dyn_cast(ptr.getType()); @@ -122,39 +128,57 @@ struct LoadStoreConversionBase { return std::min(128 / pointeeBitWidth, contiguity); } - bool verifyNonNegativeExpr(Value expr) const { - // Look if the expression has been assumed positive - for (Value assume: assumptions){ - if (auto cmpOp = dyn_cast(assume.getDefiningOp())){ + // Look through the available assumption to verify if the expression has been + // assumed positive + bool verifyNonNegativeByAssumption(Value expr) const { + for (Value assume : assumptions) { + if (auto cmpOp = dyn_cast(assume.getDefiningOp())) { bool isGreaterThan = - (cmpOp.getPredicate() == arith::CmpIPredicate::sge || - cmpOp.getPredicate() == arith::CmpIPredicate::sgt); + (cmpOp.getPredicate() == arith::CmpIPredicate::sge || + cmpOp.getPredicate() == arith::CmpIPredicate::sgt); APInt cst; - if (isGreaterThan && (cmpOp.getLhs() == expr) && matchPattern(cmpOp.getRhs(), m_ConstantInt(&cst))){ + if (isGreaterThan && (cmpOp.getLhs() == expr) && + matchPattern(cmpOp.getRhs(), m_ConstantInt(&cst))) { return cst.isNonNegative(); } } } + return false; + } - // Otherwise checks if the expression comes from a function parameter and - // as a tt.non_negative attr - if (!expr.getDefiningOp()){ + // Look if the expression is a block argument with a "tt.non_negative" + // property + bool verifyNonNegativeByFunctionProperty(Value expr) const { + if (!expr.getDefiningOp()) { BlockArgument blockArg = dyn_cast(expr); if (blockArg && blockArg.getOwner()->isEntryBlock()) { Operation *op = blockArg.getOwner()->getParentOp(); - if (auto fun = dyn_cast(op)){ - if (fun.getArgAttr(blockArg.getArgNumber(), "tt.non_negative")){ + if (auto fun = dyn_cast(op)) + if (fun.getArgAttr(blockArg.getArgNumber(), "tt.non_negative")) return true; - } - } } - return false; } + return false; + } + + bool verifyNonNegativeExpr(Value expr) const { + + // Base case 1: check if the expression is contained in any assumption + if (verifyNonNegativeByAssumption(expr)) + return true; + + // Base case 2: check if the expression is a BlockArgument and if there + // is a property that states its non-negativity + if (verifyNonNegativeByFunctionProperty(expr)) + return true; + + // Recurse if the operation is defined + Operation *op = expr.getDefiningOp(); + if (!op) + return false; - // Recurse bool nonNegative = - llvm::TypeSwitch( - expr.getDefiningOp()) + llvm::TypeSwitch(expr.getDefiningOp()) .Case([&](auto broadcastOp) { return verifyNonNegativeExpr(broadcastOp.getSrc()); }) @@ -164,53 +188,80 @@ struct LoadStoreConversionBase { .Case([&](auto splatOp) { return verifyNonNegativeExpr(splatOp.getSrc()); }) - .Case([&](triton::MakeRangeOp makeRangeOp){ - return (makeRangeOp.getStart() >= 0); - }) - .Case([&](auto constIntOp){ - return true; - }) - .Case([&](auto pidOp) { - return true; + .Case([&](auto makeRangeOp) { + return makeRangeOp.getStart() >= 0 && makeRangeOp.getEnd() >= 0; }) - .Case([&](Operation *binOp) { - bool nnLhs = verifyNonNegativeExpr(binOp->getOperand(0)); - bool nnRhs = 
verifyNonNegativeExpr(binOp->getOperand(1)); - return nnLhs&&nnRhs; + .Case( + [&](auto constIntOp) { return constIntOp.value() >= 0; }) + .Case([&](auto pidOp) { return true; }) + .Case([&](auto maxOp) { + // max(a,b) >= 0 iff a>=0 || b>=0 + bool nnLhs = verifyNonNegativeExpr(maxOp.getLhs()); + bool nnRhs = verifyNonNegativeExpr(maxOp.getRhs()); + return nnLhs || nnRhs; }) .Case([&](auto remsiOp) { + // a % b >= 0 iff a>=0 return (verifyNonNegativeExpr(remsiOp.getLhs())); }) + .Case( + // Generally speaking, a OP b >= 0 iff a >= 0 && b >= 0 when + // OP != sub + [&](Operation *binOp) { + bool nnLhs = verifyNonNegativeExpr(binOp->getOperand(0)); + bool nnRhs = verifyNonNegativeExpr(binOp->getOperand(1)); + return nnLhs && nnRhs; + }) .Default([&](Operation *op) { + // Base case 3: unknown operation return false; }); - return nonNegative; + return nonNegative; } + // Quick analysis on the Triton IR to decide if we can safely use + // buffer operations bool canUseBufferOps(Value ptr) const { - // 1. Check if the pointer is uniform: i.e., if it comes from a scalar pointer(splatted) - // and non-uniform offset addition + // 1. Check if the pointer is uniform: i.e., if it comes from a scalar + // pointer (splatted) and non-uniform offset addition DenseSet nonUniformUpdates; - SmallVector queue{ptr.getDefiningOp()}; - while (!queue.empty()){ - Operation* curOp = queue.pop_back_val(); + SmallVector queue{ptr.getDefiningOp()}; + while (!queue.empty()) { + Operation *curOp = queue.pop_back_val(); if (!curOp) continue; - if (auto addPtrOp = dyn_cast(curOp)){ - if (isa(addPtrOp.getPtr().getType())){ + if (auto addPtrOp = dyn_cast(curOp)) + if (isa(addPtrOp.getPtr().getType())) nonUniformUpdates.insert(addPtrOp); - } - } for (Value operand : curOp->getOperands()) queue.push_back(operand.getDefiningOp()); } + + // 2. Check that the pointer is not a block argument. We cannot + // be sure whether the block argument has already been non-uniformly + // updated by the caller bool useBufferOps = (nonUniformUpdates.size() == 1); - if (useBufferOps){ - // 2. Check if the offset can be expressed ad 32-bits - Value offset = nonUniformUpdates.begin()->getOffset(); - useBufferOps = (cast(offset.getType()).getElementTypeBitWidth() == 32); - // 3. Check if the offset is non-negative - useBufferOps = useBufferOps&&verifyNonNegativeExpr(offset); + + if (useBufferOps) { + triton::AddPtrOp addPtrOp = (*nonUniformUpdates.begin()); + // 2. Check that the tensor pointer is not coming from a function + // argument. We have no way to know whether that pointer has + // already been updated by the caller + Value basePtr = addPtrOp.getPtr(); + auto maybeBufferArg = dyn_cast(basePtr); + useBufferOps = + !maybeBufferArg || + !isa(maybeBufferArg.getOwner()->getParentOp()); + + // 3. Check if the offset can be expressed as 32 bits + Value offset = addPtrOp.getOffset(); + useBufferOps = + useBufferOps && + (cast(offset.getType()).getElementTypeBitWidth() == + 32); + + // 4. 
Check if the offset is non-negative + useBufferOps = useBufferOps && verifyNonNegativeExpr(offset); } return useBufferOps; } @@ -232,8 +283,7 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern, LoadOpConversion(LLVMTypeConverter &converter, const AMD::TargetInfo &targetInfo, ModuleAxisInfoAnalysis &axisAnalysisPass, - const DenseSet &assumptions, - PatternBenefit benefit) + const DenseSet &assumptions, PatternBenefit benefit) : ConvertOpToLLVMPattern(converter, benefit), LoadStoreConversionBase(targetInfo, axisAnalysisPass, assumptions) {} @@ -261,7 +311,8 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern, typeConverter->convertType(getElementTypeOrSelf(valueTy)); unsigned vec = getVectorSize(ptr); unsigned numElems = getTotalElemsPerThread(ptr.getType()); - bool useBufferOps = canUseBufferOps(ptr); + bool useBufferOps = + tools::getBoolEnv("AMDGCN_USE_BUFFER_OPS") && canUseBufferOps(ptr); if (llMask) vec = std::min(vec, getMaskAlignment(mask)); @@ -333,8 +384,8 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern, falseVal = v; } - auto loadVal = - llLoad(rewriter, loc, targetInfo, ptr, vecTy, pred, falseVal, cacheMod, useBufferOps); + auto loadVal = llLoad(rewriter, loc, targetInfo, ptr, vecTy, pred, + falseVal, cacheMod, useBufferOps); for (size_t ii = 0; ii < vec; ++ii) { Value vecIdx = createIndexAttrConstant( rewriter, loc, this->getTypeConverter()->getIndexType(), ii % vec); @@ -358,8 +409,7 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern, StoreOpConversion(LLVMTypeConverter &converter, const AMD::TargetInfo &targetInfo, ModuleAxisInfoAnalysis &axisAnalysisPass, - const DenseSet &assumptions, - PatternBenefit benefit) + const DenseSet &assumptions, PatternBenefit benefit) : ConvertOpToLLVMPattern(converter, benefit), LoadStoreConversionBase(targetInfo, axisAnalysisPass, assumptions) {} @@ -376,7 +426,8 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern, auto loc = op->getLoc(); MLIRContext *ctx = rewriter.getContext(); - bool useBufferOps = canUseBufferOps(ptr); + bool useBufferOps = + tools::getBoolEnv("AMDGCN_USE_BUFFER_OPS") && canUseBufferOps(ptr); auto valueTy = value.getType(); Type valueElemTy = @@ -433,7 +484,8 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern, rewriter, loc, this->getTypeConverter()->getIndexType(), s); storeVal = insert_element(vecTy, storeVal, otherElem, indexVal); } - llStore(rewriter, loc, targetInfo, ptr, storeVal, pred, cacheMod, useBufferOps); + llStore(rewriter, loc, targetInfo, ptr, storeVal, pred, cacheMod, + useBufferOps); } // end vec rewriter.eraseOp(op); return success(); @@ -606,7 +658,7 @@ struct AtomicRMWOpConversion AtomicRMWOpConversion(LLVMTypeConverter &converter, const AMD::TargetInfo &targetInfo, ModuleAxisInfoAnalysis &axisAnalysisPass, - const DenseSet& assumptions, + const DenseSet &assumptions, PatternBenefit benefit) : ConvertOpToLLVMPattern(converter, benefit), LoadStoreConversionBase(targetInfo, axisAnalysisPass, assumptions) {} @@ -774,10 +826,10 @@ void populateLoadStoreOpToLLVMPatterns(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis, - const DenseSet& assumptions, + const DenseSet &assumptions, PatternBenefit benefit) { patterns.add(typeConverter, targetInfo, axisInfoAnalysis, assumptions, - benefit); + StoreOpConversion>(typeConverter, targetInfo, axisInfoAnalysis, + assumptions, benefit); } } // namespace mlir::triton::AMD diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp 
b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp index 72f258c9032b..004e3dfe963a 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp @@ -101,7 +101,8 @@ Value TargetInfo::loadDShared(RewriterBase &rewriter, Location loc, Value ptr, } Value falseVal = rewriter.create( loc, elemTy, rewriter.getZeroAttr(elemTy)); - return mlir::LLVM::AMD::llLoad(rewriter, loc, *this, ptr, elemTy, pred, falseVal); + return mlir::LLVM::AMD::llLoad(rewriter, loc, *this, ptr, elemTy, pred, + falseVal); } Value TargetInfo::shuffleXor(RewriterBase &rewriter, Location loc, Value val, diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp index dff8e3d50f5f..8b30d248dfb1 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp @@ -108,12 +108,11 @@ struct ConvertTritonAMDGPUToLLVM // Collect all the assumed expressions DenseSet assumptions; - mod.walk([&](Operation *op){ + mod.walk([&](Operation *op) { if (op->getName().getStringRef() == "llvm.intr.assume") assumptions.insert(op->getOperand(0)); }); - // Lower functions { mlir::LowerToLLVMOptions option(context); @@ -188,8 +187,8 @@ struct ConvertTritonAMDGPUToLLVM axisInfoAnalysis, allocation, targetInfo, AMDBenefit); AMD::populateLoadStoreOpToLLVMPatterns(typeConverter, targetInfo, patterns, - numWarps, axisInfoAnalysis, assumptions, - AMDBenefit); + numWarps, axisInfoAnalysis, + assumptions, AMDBenefit); populatePatterns7(mlir::triton::populateReduceOpToLLVMPatterns, commonBenefit); populatePatterns7(mlir::triton::populateScanOpToLLVMPatterns, diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp index da85302795b1..24e6d739b0a0 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp @@ -1,9 +1,9 @@ #include "Utility.h" #include "ConvertLayoutOpToLLVM/SharedToDotOperandHelper.h" #include "PatternTritonGPUOpToLLVM.h" -#include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "TargetInfo.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" #include "mlir/IR/BuiltinTypes.h" @@ -23,52 +23,57 @@ using mlir::triton::gpu::appendOrGetExternFuncOp; using mlir::triton::gpu::getFunctionType; - namespace { // Utility function to traverse a struct, get to the GEP contained // in the struct at position `pos` and extract its base pointer and offset -mlir::FailureOr> getBaseAndOffset(Value ptr, int64_t pos=0){ +mlir::FailureOr> getBaseAndOffset(Value ptr, + int64_t pos = 0) { // Typedef the return type here, to have more coincise code using ReturnType = mlir::FailureOr>; Operation *currentOp = ptr.getDefiningOp(); - auto res = llvm::TypeSwitch(currentOp) - .Case([&](auto gepOp) -> ReturnType { - SmallVector indices = llvm::to_vector(gepOp.getDynamicIndices()); - if (indices.size() ==1 ) - return std::make_pair(gepOp.getBase(), indices[0]); - return failure(); - }) - .Case([&](auto addrspaceCastOp)->ReturnType{ - return getBaseAndOffset(addrspaceCastOp.getArg(), pos); - }) - .Case([&](auto extractValOp)->ReturnType{ - ArrayRef position = extractValOp.getPosition(); - if (position.size() > 1) - return failure(); - return getBaseAndOffset(extractValOp.getContainer(), position[0]); - }) - 
.Case([&](auto insertValOp)->ReturnType{ - ArrayRef position = insertValOp.getPosition(); - if (position.size() > 1) - return failure(); - if (position[0] == pos) - return getBaseAndOffset(insertValOp.getValue(), 0); - return getBaseAndOffset(insertValOp.getContainer(), pos); - }) - .Default([&](Operation *op)->ReturnType{ - return failure(); - }); + auto res = + llvm::TypeSwitch(currentOp) + .Case([&](auto gepOp) -> ReturnType { + SmallVector indices = + llvm::to_vector(gepOp.getDynamicIndices()); + if (indices.size() == 1) + return std::make_pair(gepOp.getBase(), indices[0]); + return failure(); + }) + .Case([&](auto addrspaceCastOp) -> ReturnType { + return getBaseAndOffset(addrspaceCastOp.getArg(), pos); + }) + .Case([&](auto extractValOp) -> ReturnType { + ArrayRef position = extractValOp.getPosition(); + if (position.size() > 1) + return failure(); + return getBaseAndOffset(extractValOp.getContainer(), position[0]); + }) + .Case([&](auto insertValOp) -> ReturnType { + ArrayRef position = insertValOp.getPosition(); + if (position.size() > 1) + return failure(); + if (position[0] == pos) + return getBaseAndOffset(insertValOp.getValue(), 0); + return getBaseAndOffset(insertValOp.getContainer(), pos); + }) + .Default([&](Operation *op) -> ReturnType { return failure(); }); return res; } -struct BufferEmitter{ - BufferEmitter(RewriterBase &rw, Location loc, AMD::TargetInfo ti):rewriter(rw), loc(loc), targetInfo(ti){} +// Utility class to take care of buffer operation emission. We may add more +// emitters into this as needed. +struct BufferEmitter { + + BufferEmitter(RewriterBase &rw, Location loc, AMD::TargetInfo ti) + : rewriter(rw), loc(loc), targetInfo(ti) {} - // Emit a predicated rocdl.raw.ptr.buffer.load - Value emitMaskedBufferLoad(Type type, Value basePtr, Value offset, - Value pred, Value falseVal, bool nt = false) { + // Emit a predicated rocdl.raw.ptr.buffer.load. `type` needs to be a + // `VectorType` + Value emitMaskedBufferLoad(Type type, Value basePtr, Value offset, Value pred, + Value falseVal, bool nt = false) { VectorType vecTy = cast(type); SmallVector args; fillBufferArgs(vecTy, basePtr, offset, pred, nt, args); @@ -79,65 +84,66 @@ struct BufferEmitter{ return data; } - // Emit a predicated rocdl.raw.ptr.buffer.store - void emitMaskedBufferStore(Value data, Value basePtr, Value offset, Value pred, bool nt=false) { - // We only support vector types. So the caller needs to ensure we have a vector type here + // Emit a predicated rocdl.raw.ptr.buffer.store. `type` needs to be a + // `VectorType` + void emitMaskedBufferStore(Value data, Value basePtr, Value offset, + Value pred, bool nt = false) { + // We only support vector types. So the caller needs to ensure we have a + // vector type here VectorType vecTy = cast(data.getType()); Type bufferType = getBufferOpType(vecTy); if (vecTy != bufferType) data = bitcast(data, bufferType); SmallVector args{data}; fillBufferArgs(vecTy, basePtr, offset, pred, nt, args); - rewriter.create( - loc, TypeRange{}, args, ArrayRef()); + rewriter.create(loc, TypeRange{}, args, + ArrayRef()); } private: - -// Given a type, the buffer type can be either the same type -// or a packed version. 
E.g., a vector of 8xfp16 can be -// as a vector of 4xi32 -Type getBufferOpType(VectorType vecTy){ - int64_t vecSize = vecTy.getNumElements(); - Type elementType = vecTy.getElementType(); - const int valueElemNBits = std::max(8u, elementType.getIntOrFloatBitWidth()); - const size_t totalWidth = valueElemNBits * vecSize; - - // For bf16, always convert to i16 - Type bufferElementType = elementType; - if (vecTy.getElementType().isBF16()) - bufferElementType = rewriter.getI16Type(); - - // If we are dealing with a subword type (e.g., i8 or f16) but we - // still need multiple words, then pack the subwords into 32bit integers - // and update the vector length and the type - int64_t bufferVecSize = vecSize; - if (valueElemNBits < 32) { - if (totalWidth > 32) { - bufferElementType = rewriter.getI32Type(); - bufferVecSize = totalWidth / 32; - } else { - bufferElementType = rewriter.getIntegerType(totalWidth); - bufferVecSize = 1; + // Given a type, the buffer type can be either the same type + // or a packed version. E.g., a vector of 8xfp16 can be bitcasted to + // a vector of 4xi32. This usually makes the life of the backend easier + Type getBufferOpType(VectorType vecTy) { + int64_t vecSize = vecTy.getNumElements(); + Type elementType = vecTy.getElementType(); + const int valueElemNBits = + std::max(8u, elementType.getIntOrFloatBitWidth()); + const size_t totalWidth = valueElemNBits * vecSize; + + // For bf16, always convert to i16 + Type bufferElementType = elementType; + if (vecTy.getElementType().isBF16()) + bufferElementType = rewriter.getI16Type(); + + // If we are dealing with a subword type (e.g., i8 or f16) but we + // still need multiple words, then pack the subwords into 32bit integers + // and update the vector length and the type + int64_t bufferVecSize = vecSize; + if (valueElemNBits < 32) { + if (totalWidth > 32) { + bufferElementType = rewriter.getI32Type(); + bufferVecSize = totalWidth / 32; + } else { + bufferElementType = rewriter.getIntegerType(totalWidth); + bufferVecSize = 1; + } } - } - - // This is the buffer type that the buffer operation will use. It - // will be bitcast-able to the original type. So if the types - // ended up different, we simply have to emit a `bitcastOp` to convert - Type bufferType = bufferElementType; - if (bufferVecSize != vecSize) - bufferType = VectorType::get(bufferVecSize, bufferElementType); - return bufferType; -} + // This is the buffer type that the buffer operation will use. It + // will be bitcast-able to the original type. So if the types + // ended up different, we simply have to emit a `bitcastOp` to convert + Type bufferType = bufferElementType; + if (bufferVecSize != vecSize) + bufferType = VectorType::get(bufferVecSize, bufferElementType); + return bufferType; + } // Fill common buffer operation arguments. A large part of this function is // courtesy of: mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp - void fillBufferArgs(VectorType vecTy, Value basePtr, - Value offset, Value pred, bool nt, - SmallVector &args) { + void fillBufferArgs(VectorType vecTy, Value basePtr, Value offset, Value pred, + bool nt, SmallVector &args) { // 1. Create the resource descriptor // bits 0-11: dst sel, ignored by these intrinsics // bits 12-14: data format (ignored, must be nonzero, 7=float) @@ -169,7 +175,8 @@ Type getBufferOpType(VectorType vecTy){ // 2. 
Create the (masked) offset Type elementType = vecTy.getElementType(); - const int valueElemNBits = std::max(8u, elementType.getIntOrFloatBitWidth()); + const int valueElemNBits = + std::max(8u, elementType.getIntOrFloatBitWidth()); const int elementByteWidth = valueElemNBits / 8; // Please note: the index passed to GEP is not in bytes, but in number of // elements In order to pass the index to the buffer operation, we need to @@ -201,8 +208,6 @@ Type getBufferOpType(VectorType vecTy){ AMD::TargetInfo targetInfo; }; - - enum class ShflKind : uint32_t { bfly = 0, up = 1, @@ -379,36 +384,40 @@ Value llGetPid(Location loc, RewriterBase &rewriter, ModuleOp moduleOp, return rewriter.create(loc, i32_ty, blockId); } -Value llLoad(RewriterBase &rewriter, Location loc, triton::AMD::TargetInfo targetInfo, Value ptr, Type elemTy, - Value pred, Value falseVal, triton::CacheModifier cm, bool useBufferOps) { +Value llLoad(RewriterBase &rewriter, Location loc, + triton::AMD::TargetInfo targetInfo, Value ptr, Type elemTy, + Value pred, Value falseVal, triton::CacheModifier cm, + bool useBufferOps) { - if (cm == triton::CacheModifier::CG || cm == triton::CacheModifier::NONE) { - // Use a predicated buffer load intrinsic if we can. This should be optimal, - // since we don't have to emit any branch, ever. bool noCacheModifiers = (cm == triton::CacheModifier::NONE); bool nt = (cm == triton::CacheModifier::CG); - if (useBufferOps && (noCacheModifiers || nt)){ - auto maybeBaseAndOffset = getBaseAndOffset(ptr); - if (!failed(maybeBaseAndOffset)) { - BufferEmitter bufferEmitter(rewriter, loc, targetInfo); - Value basePtr = maybeBaseAndOffset->first; - Value offset = maybeBaseAndOffset->second; - Type vecType = castToVectorType(elemTy); - falseVal = bitcast(falseVal, vecType); - bool nt = (cm == triton::CacheModifier::CG); - Value vecData = bufferEmitter.emitMaskedBufferLoad( vecType, basePtr, offset, pred, falseVal, nt); - // If it is not a vector, remember to bitcast back to a scalar - vecData = bitcast(vecData, elemTy); - return vecData; + // Use a predicated buffer load intrinsic if we can. This should be optimal, + // since we don't have to emit any branch, ever. + if (cm == triton::CacheModifier::CG || cm == triton::CacheModifier::NONE) { + if (useBufferOps && (noCacheModifiers || nt)) { + auto maybeBaseAndOffset = getBaseAndOffset(ptr); + if (!failed(maybeBaseAndOffset)) { + BufferEmitter bufferEmitter(rewriter, loc, targetInfo); + Value basePtr = maybeBaseAndOffset->first; + Value offset = maybeBaseAndOffset->second; + Type vecType = castToVectorType(elemTy); + falseVal = bitcast(falseVal, vecType); + bool nt = (cm == triton::CacheModifier::CG); + Value vecData = bufferEmitter.emitMaskedBufferLoad( + vecType, basePtr, offset, pred, falseVal, nt); + // If it is not a vector, remember to bitcast back to a scalar + vecData = bitcast(vecData, elemTy); + return vecData; + } } } // Alternatively, try to emit llvm.intr.masked.load if we can. In theory the // backend should be happier because we emit less branchy code to optimize. // The backend will lower it down however it wants at some point. - if (noCacheModifiers || nt){ - // `llvm.intr.masked.load` only accepts vectors. If we see a scalar we need - // to bitcast to `vector<1xelemTy>` (and back) + if (noCacheModifiers || nt) { + // `llvm.intr.masked.load` only accepts vectors. 
If we see a scalar we + // need to bitcast to `vector<1xelemTy>` (and back) int64_t vecSize = getNumElements(elemTy); Type vecType = castToVectorType(elemTy); falseVal = bitcast(falseVal, vecType); @@ -421,7 +430,7 @@ Value llLoad(RewriterBase &rewriter, Location loc, triton::AMD::TargetInfo targe return vecData; } - // Emit a branch from MLIR + // Default strategy: emit a branch in MLIR. Type funcType = getFunctionType(elemTy, ValueRange({ptr, pred, falseVal})); auto parent = ptr.getParentRegion()->getParentOfType(); auto getLoadNameRaw = [](triton::CacheModifier cm) { @@ -447,9 +456,11 @@ Value llLoad(RewriterBase &rewriter, Location loc, triton::AMD::TargetInfo targe return loadVal; } -void llStore(RewriterBase &rewriter, Location loc, triton::AMD::TargetInfo targetInfo, Value ptr, Value val, +void llStore(RewriterBase &rewriter, Location loc, + triton::AMD::TargetInfo targetInfo, Value ptr, Value val, Value pred, triton::CacheModifier cm, bool useBufferOps) { - // Try using the predicated buffer store first + // Use a predicated buffer store intrinsic if we can. This should be optimal, + // since we don't have to emit any branch, ever. if (useBufferOps && cm == triton::CacheModifier::NONE) { auto maybeBaseAndOffset = getBaseAndOffset(ptr); if (!failed(maybeBaseAndOffset)) { @@ -465,12 +476,12 @@ void llStore(RewriterBase &rewriter, Location loc, triton::AMD::TargetInfo targe } } - // Alternatively, try to emit llvm.intr.masked.store if we can. In theory the - // backend should be happier because we emit less branchy code to optimize. - // The backend will lower it down however it wants at some point. + // Alternatively, try to emit llvm.intr.masked.store if we can. In theory + // the backend should be happier because we emit less branchy code to + // optimize. The backend will lower it down however it wants at some point. if (cm == triton::CacheModifier::NONE) { - // `llvm.intr.masked.store` only accepts vectors. If we see a scalar we need - // to bitcast to `vector<1xelemTy>` + // `llvm.intr.masked.store` only accepts vectors. If we see a scalar we + // need to bitcast to `vector<1xelemTy>` Type elemTy = val.getType(); int64_t vecSize = getNumElements(elemTy); Type vecType = castToVectorType(elemTy); @@ -478,12 +489,10 @@ void llStore(RewriterBase &rewriter, Location loc, triton::AMD::TargetInfo targe Value maskVal = createVectorMaskFromPredicate(rewriter, loc, pred, vecSize); auto op = rewriter.create(loc, val, ptr, maskVal, vecSize); - int64_t vec = cast(elemTy).getNumElements(); - Value maskVal = createVectorMaskFromPredicate(rewriter, loc, pred, vec); - rewriter.create(loc, val, ptr, maskVal, vec); return; } + // Default strategy: emit a branch in MLIR. auto ctx = ptr.getContext(); Type funcType = getFunctionType(void_ty(ctx), ValueRange({ptr, val, pred})); auto parent = ptr.getParentRegion()->getParentOfType(); diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h index 7b3861ff3cbb..00272460b116 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h @@ -29,15 +29,17 @@ Value llGetPid(Location loc, RewriterBase &rewriter, ModuleOp moduleOp, // Loads from shared or global memory with predication. 
// `otherElems` is used to mask out the elements that are not loaded -Value llLoad(RewriterBase &rewriter, Location loc, triton::AMD::TargetInfo targetInfo, Value ptr, Type elemTy, +Value llLoad(RewriterBase &rewriter, Location loc, + triton::AMD::TargetInfo targetInfo, Value ptr, Type elemTy, Value pred, Value falseVal, triton::CacheModifier cm = triton::CacheModifier::NONE, bool useBufferOp = false); // Stores to shared or global memory with predication. -void llStore(RewriterBase &rewriter, Location loc, triton::AMD::TargetInfo targetInfo, Value ptr, Value val, - Value pred, - triton::CacheModifier cm = triton::CacheModifier::NONE, bool useBufferOps=false); +void llStore(RewriterBase &rewriter, Location loc, + triton::AMD::TargetInfo targetInfo, Value ptr, Value val, + Value pred, triton::CacheModifier cm = triton::CacheModifier::NONE, + bool useBufferOps = false); } // namespace mlir::LLVM::AMD #endif diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp index 232589da3427..2efc22b3850d 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp @@ -836,7 +836,8 @@ LogicalResult PointerCanonicalizer::rewriteFunction(triton::FuncOp funcOp) { if (!isa(arg.getType())) continue; - bool is32BitPtrRange = (funcOp.getArgAttr(idx, "tt.ptr_int32_range") != nullptr); + bool is32BitPtrRange = + (funcOp.getArgAttr(idx, "tt.ptr_int32_range") != nullptr); int64_t bitness = (is32BitPtrRange ? 32 : 64); rewriter.setInsertionPointToStart(&region.front());
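
As an illustration of the pattern this lowering targets (this sketch is not part of the patch; the kernel name, tensor sizes, and block size are made up), a Triton kernel whose offsets are built only from tl.program_id and tl.arange satisfies the checks in canUseBufferOps: a single non-uniform pointer update, 32-bit offsets, and provably non-negative values. With AMDGCN_USE_BUFFER_OPS=1 set in the environment, the loads and stores below are expected to take the rocdl.raw.ptr.buffer.load / rocdl.raw.ptr.buffer.store path rather than the branchy predicated path:

import os

os.environ["AMDGCN_USE_BUFFER_OPS"] = "1"  # cache-invalidating env var from GetEnv.hpp; set before compilation

import torch
import triton
import triton.language as tl


@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)                            # non-negative (GetProgramIdOp case)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)  # muli/addi of non-negative i32 values
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)                # expected: rocdl.raw.ptr.buffer.load
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)          # expected: rocdl.raw.ptr.buffer.store


x = torch.randn(4096, device="cuda")
y = torch.randn(4096, device="cuda")
out = torch.empty_like(x)
grid = (triton.cdiv(x.numel(), 1024), )
add_kernel[grid](x, y, out, x.numel(), BLOCK_SIZE=1024)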