Introduce support for buffer operations
giuseros committed Sep 12, 2024
1 parent 378170c commit fbeb3a6
Showing 12 changed files with 330 additions and 255 deletions.
1 change: 1 addition & 0 deletions include/triton/Tools/Sys/GetEnv.hpp
@@ -13,6 +13,7 @@ namespace mlir::triton {
inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
// clang-format off
"AMDGCN_ENABLE_DUMP",
"AMDGCN_USE_BUFFER_OPS",
"DISABLE_FAST_REDUCTION",
"DISABLE_LLVM_OPT",
"DISABLE_MMA_V3",
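
Because AMDGCN_USE_BUFFER_OPS is listed among the cache-invalidating variables, flipping it changes the compilation cache key and forces kernels to be rebuilt. A minimal, hedged sketch of opting in from Python, assuming the AMD backend reads the variable at kernel-compile time (the new lit test below runs triton-opt with AMDGCN_USE_BUFFER_OPS=1):

# Hedged sketch: enable AMD buffer operations for subsequent kernel compilations.
# Assumption: the AMD backend checks AMDGCN_USE_BUFFER_OPS when a kernel is compiled,
# and the cache-invalidating list above ensures a change triggers recompilation.
import os

os.environ["AMDGCN_USE_BUFFER_OPS"] = "1"  # set before compiling or launching any kernel
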
27 changes: 1 addition & 26 deletions python/triton/compiler/code_generator.py
@@ -62,24 +62,6 @@ def _is_triton_scalar(o: Any) -> bool:
def _is_list_like(o: Any) -> bool:
return isinstance(o, (list, tuple))


# def _convert_elem_to_ir_value(builder, elem, require_i64):
# if isinstance(elem, int):
# elem = tl.constexpr(elem)
# if isinstance(elem, constexpr):
# if require_i64:
# assert -2**63 <= elem.value < 2**63, f"Block pointers only support 64 bit `shape/strides`, " \
# f"got a value {elem.value} which is out of the range"
# return builder.get_int64(elem.value)
# else:
# assert -2**31 <= elem.value < 2**31, f"Block pointers only support 32 bit `offsets/block_shape`, " \
# f"got a value {elem.value} which is out of the range"
# return builder.get_int32(elem.value)
# elif isinstance(elem, tensor):
# return elem.handle
# assert False, f"Unsupported element type in shape/strides/offsets: {type(elem)}"


def _check_fn_args(node, fn, args):
if fn.noinline:
for idx, arg in enumerate(args):
@@ -452,13 +434,6 @@ def visit_FunctionDef(self, node):
self.set_value(arg_name, arg_value)

self.builder.set_insertion_point_to_start(entry)
# if len(arg_values) > 4:
# if arg_values[3].dtype==language.int32:
# c = self.builder.create_icmpSGE(arg_values[3].handle, self.builder.get_int32(0))
# self.builder.create_assume(c)

# elif arg_values[3].dtype==language.int64:
# c = self.builder.create_icmpSGE(arg_values[3].handle, self.builder.get_int64(0))
# visit function body
self.visit_compound_statement(node.body)
# finalize function
@@ -628,7 +603,7 @@ def visit_then_else_blocks(self, node, liveins, then_block, else_block):
then_defs[name] = liveins[name]
# variables that are both in then and else but not in liveins
# TODO: could probably be cleaned up
for name in sorted(then_defs.keys() & else_defs.keys()):
for name in then_defs.keys() & else_defs.keys():
if name in names:
continue
then_ty = then_defs[name].type
11 changes: 6 additions & 5 deletions python/triton/compiler/compiler.py
@@ -35,15 +35,16 @@ def __post_init__(self):
self.non_negative = set()

def to_dict(self):
return {'divisible_by_16': list(self.divisible_by_16), 'equal_to_1': list(self.equal_to_1), 'within_2gb' : list(self.within_2gb), 'non_negative' : list(self.non_negative)}
return {
'divisible_by_16': list(self.divisible_by_16), 'equal_to_1': list(self.equal_to_1), 'within_2gb':
list(self.within_2gb), 'non_negative': list(self.non_negative)
}

@staticmethod
def from_dict(data):
return AttrsDescriptor(divisible_by_16=set(data.get('divisible_by_16', [])),
equal_to_1=set(data.get('equal_to_1', [])),
within_2gb=set(data.get('within_2gb', [])),
non_negative=set(data.get('non_negative', []))
)
equal_to_1=set(data.get('equal_to_1', [])), within_2gb=set(data.get('within_2gb', [])),
non_negative=set(data.get('non_negative', [])))

def hash(self):
key = str([sorted(x) for x in self.__dict__.values()])
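
to_dict/from_dict now carry the two new attribute sets, within_2gb and non_negative. A hedged round-trip sketch using only the keyword constructor visible in from_dict above; field defaults and the exact shape of the output are assumptions:

# Hedged sketch: build a descriptor from a dict and serialize it back.
from triton.compiler.compiler import AttrsDescriptor

desc = AttrsDescriptor.from_dict({
    'divisible_by_16': [0, 1],  # args whose values are divisible by 16
    'equal_to_1': [],
    'within_2gb': [0],          # tensor args whose storage is below 2**31 bytes
    'non_negative': [3],        # int args known to be >= 0
})
print(desc.to_dict())           # expected to echo the same four keys back as lists
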
8 changes: 2 additions & 6 deletions python/triton/runtime/jit.py
@@ -502,13 +502,9 @@ def is_divisible_by_16(x):
within_2gb = {
param.num
for param, arg in zip(self.params, args)
if isinstance(arg, torch.Tensor) and sys.getsizeof(arg.untyped_storage()) < 2**31
}
non_negative = {
param.num
for param, arg in zip(self.params, args)
if isinstance(arg, int) and arg >= 0
if isinstance(arg, torch.Tensor) and sys.getsizeof(arg.untyped_storage()) < 2**31  # storage below 2**31 bytes (2 GiB)
}
non_negative = {param.num for param, arg in zip(self.params, args) if isinstance(arg, int) and arg >= 0}
# folded equal_to_1 and None
# TODO: method to collect all folded args
return AttrsDescriptor(tuple(divisible_by_16), tuple(equal_to_1), tuple(within_2gb), tuple(non_negative))
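
The JIT now tags tensor arguments whose untyped storage is smaller than 2**31 bytes (within_2gb) and integer arguments known to be non-negative (non_negative), presumably so the backend can prove that 32-bit, non-negative offsets are safe for buffer operations. A hedged, standalone sketch of the same classification, with hypothetical arguments standing in for self.params and args:

# Hedged sketch: mirror of the argument classification added above; names are hypothetical.
import sys
import torch

args = (torch.empty(1024), 128, -1)   # a tensor, a non-negative int, a negative int
param_nums = range(len(args))         # stand-in for param.num in the real code

within_2gb = {n for n, a in zip(param_nums, args)
              if isinstance(a, torch.Tensor) and sys.getsizeof(a.untyped_storage()) < 2**31}
non_negative = {n for n, a in zip(param_nums, args) if isinstance(a, int) and a >= 0}

print(within_2gb)    # {0}: the tensor's storage is well under 2 GiB
print(non_negative)  # {1}: only the non-negative integer qualifies
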
61 changes: 30 additions & 31 deletions python/tutorials/03-matrix-multiplication.py
@@ -204,21 +204,21 @@ def get_cuda_autotune_config():

def get_hip_autotune_config():
return [
# triton.Config(
# {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 1, 'waves_per_eu': 2},
# num_warps=4, num_stages=0),
# triton.Config(
# {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 4, 'waves_per_eu': 2},
# num_warps=8, num_stages=0),
# triton.Config(
# {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 2},
# num_warps=8, num_stages=0),
triton.Config(
{'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 1, 'waves_per_eu': 2},
num_warps=4, num_stages=0),
triton.Config(
{'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 4, 'waves_per_eu': 2},
num_warps=8, num_stages=0),
triton.Config(
{'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 2},
num_warps=8, num_stages=0),
triton.Config(
{'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'waves_per_eu': 3},
num_warps=4, num_stages=0),
# triton.Config(
# {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 8},
# num_warps=4, num_stages=0),
triton.Config(
{'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 8},
num_warps=4, num_stages=0),
]


@@ -269,7 +269,6 @@ def matmul_kernel(
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
tl.assume(group_size_m > 0)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m

@@ -422,21 +421,21 @@ def matmul(a, b, activation=""):
))


# @triton.testing.perf_report(configs)
# def benchmark(M, N, K, provider, fp8_inputs):
# a = torch.randn((M, K), device='cuda', dtype=torch.float16)
# b = torch.randn((K, N), device='cuda', dtype=torch.float16)
# if TORCH_HAS_FP8 and fp8_inputs:
# a = a.to(torch.float8_e5m2)
# b = b.T
# b = b.to(torch.float8_e5m2)
# quantiles = [0.5, 0.2, 0.8]
# if provider == ref_lib.lower():
# ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles)
# if provider == 'triton':
# ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b), quantiles=quantiles)
# perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)
# return perf(ms), perf(max_ms), perf(min_ms)


# benchmark.run(show_plots=True, print_data=True)
@triton.testing.perf_report(configs)
def benchmark(M, N, K, provider, fp8_inputs):
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
if TORCH_HAS_FP8 and fp8_inputs:
a = a.to(torch.float8_e5m2)
b = b.T
b = b.to(torch.float8_e5m2)
quantiles = [0.5, 0.2, 0.8]
if provider == ref_lib.lower():
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles)
if provider == 'triton':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b), quantiles=quantiles)
perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)
return perf(ms), perf(max_ms), perf(min_ms)


benchmark.run(show_plots=True, print_data=True)
39 changes: 39 additions & 0 deletions test/Conversion/amd/buffer_load_store.mlir
@@ -0,0 +1,39 @@
// RUN: AMDGCN_USE_BUFFER_OPS=1 triton-opt %s -split-input-file --convert-triton-amdgpu-to-llvm=arch=gfx942 --convert-builtin-func-to-llvm | FileCheck %s

#blocked0 = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: buffer_load_store_vec8
tt.func @buffer_load_store_vec8(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32) {
%c256_i32 = arith.constant 256 : i32
%0 = tt.get_program_id x : i32
%1 = arith.muli %0, %c256_i32 : i32
%2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0>
%3 = tt.splat %1 : i32 -> tensor<256xi32, #blocked0>
%4 = arith.addi %3, %2 : tensor<256xi32, #blocked0>
%5 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<256x!tt.ptr<f32>, #blocked0>
%6 = tt.addptr %5, %4 : tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xi32, #blocked0>
%7 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<256x!tt.ptr<f32>, #blocked0>
%8 = tt.addptr %7, %4 : tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xi32, #blocked0>
// Load 8 elements each from A and B, with two vectorized load instructions per tensor
// CHECK-COUNT-5: llvm.select
// CHECK: %[[mask0:.*]] = llvm.select
// CHECK: rocdl.raw.ptr.buffer.load {{.*}}, %[[mask0]]
// CHECK: %[[mask1:.*]] = llvm.select
// CHECK: rocdl.raw.ptr.buffer.load {{.*}}, %[[mask1]]
// CHECK: %[[mask2:.*]] = llvm.select
// CHECK: rocdl.raw.ptr.buffer.load {{.*}}, %[[mask2]]
// CHECK: %[[mask3:.*]] = llvm.select
// CHECK: rocdl.raw.ptr.buffer.load {{.*}}, %[[mask3]]
%9 = tt.load %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x!tt.ptr<f32>, #blocked0>
%10 = tt.load %8 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x!tt.ptr<f32>, #blocked0>
%11 = arith.addf %9, %10 : tensor<256xf32, #blocked0>
// CHECK: %[[mask4:.*]] = llvm.select
// CHECK: rocdl.raw.ptr.buffer.store{{.*}}, {{.*}}, %[[mask4]]
// CHECK: %[[mask5:.*]] = llvm.select
// CHECK: rocdl.raw.ptr.buffer.store{{.*}}, {{.*}}, %[[mask5]]
%12 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<256x!tt.ptr<f32>, #blocked0>
%13 = tt.addptr %12, %4 : tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xi32, #blocked0>
tt.store %13, %11 : tensor<256x!tt.ptr<f32>, #blocked0>
tt.return
}
}