From e16c06340b4c418fded96e35b9a33fef68affec6 Mon Sep 17 00:00:00 2001
From: Janani Sriram
Date: Thu, 20 Jun 2024 14:20:30 -0700
Subject: [PATCH] Add simple fused Triton kernel for jagged_sum operator

Summary:
Add a Triton kernel implementation to the `jagged_sum` operator in TritonBench.
This Triton kernel performs a sum along the ragged dimension of a nested tensor
of logical dimensions `(B, *, M)`, where `*` is the ragged dimension. It loads
blocks of the `values` tensor along its last dimension `M`, reduces each
variable-length block along its first dimension `*`, and stores each of the `B`
reductions in an output tensor of shape `(B, M)`.

This Triton kernel is benchmarked against two PyTorch implementations, one
which does not pad the variable-length rows and one which does. A usage sketch
for the fused kernel is appended after the diff below.

Reviewed By: davidberard98

Differential Revision: D58549297
---
 .../operators/jagged_sum/kernels.py  | 155 ++++++++++++++++++
 .../operators/jagged_sum/operator.py |  66 +++++++-
 2 files changed, 218 insertions(+), 3 deletions(-)
 create mode 100644 torchbenchmark/operators/jagged_sum/kernels.py

diff --git a/torchbenchmark/operators/jagged_sum/kernels.py b/torchbenchmark/operators/jagged_sum/kernels.py
new file mode 100644
index 000000000..ea71634aa
--- /dev/null
+++ b/torchbenchmark/operators/jagged_sum/kernels.py
@@ -0,0 +1,155 @@
+import itertools
+
+import triton
+import triton.language as tl
+
+
+BLOCK_SIZES = [2**n for n in range(2, 11, 3)]
+NUM_WARPS = [2, 4, 8]
+NUM_STAGES = [2, 4, 8]
+
+
+@triton.autotune(
+    configs=[
+        triton.Config(
+            {
+                "BLOCK_SIZE_RAGGED": b_r,
+                "BLOCK_SIZE_M": b_m,
+            },
+            num_warps=w,
+            num_stages=s,
+        )
+        for b_r, b_m, w, s in itertools.product(
+            BLOCK_SIZES,  # block sizes on ragged (reduction) dimension
+            BLOCK_SIZES,  # block sizes on M-th (non-reduction) dimension
+            NUM_WARPS,  # number of warps
+            NUM_STAGES,  # number of stages
+        )
+    ],
+    key=["M"],
+)
+@triton.jit
+def triton_jagged_sum_kernel_simple_fused_sum_then_buffer(
+    input_ptr_values,  # pointer to input values (2D tensor)
+    input_ptr_offsets,  # pointer to input offsets (1D tensor)
+    output_ptr,  # pointer to output tensor (2D tensor)
+    # matrix dimensions (input)
+    M,  # number of elements in M-th dimension, with logical dimensions (B, *, M)
+    MAX_SEQLEN,  # max length of ragged dimension
+    # block sizes (input)
+    BLOCK_SIZE_RAGGED: tl.constexpr,  # number of elements in ragged dimension per block, with logical dimensions (B, *, M)
+    BLOCK_SIZE_M: tl.constexpr,  # number of elements in M-th dimension per block, with logical dimensions (B, *, M)
+):
+    pid = tl.program_id(axis=0)  # i-th tensor in nested tensor
+    pid_ragged = pid // tl.cdiv(M, BLOCK_SIZE_M)
+    pid_m = pid % tl.cdiv(M, BLOCK_SIZE_M)
+
+    buffer = tl.zeros(
+        (1, BLOCK_SIZE_M), dtype=tl.float32
+    )  # create buffer as a row tensor
+
+    block_start_m = pid_m * BLOCK_SIZE_M
+    offsets_m = block_start_m + tl.arange(0, BLOCK_SIZE_M)
+    mask_m = offsets_m < M
+
+    ragged_start, ragged_end = tl.load(input_ptr_offsets + pid_ragged), tl.load(
+        input_ptr_offsets + (pid_ragged + 1)
+    )  # load start and end offsets for current program, similar to offsets[i] and offsets[i + 1]
+
+    for block_pos in range(
+        0, MAX_SEQLEN, BLOCK_SIZE_RAGGED
+    ):  # loop over ragged dimension, ranging until maximum seqlen
+        block_start_ragged = ragged_start + block_pos  # offset block position by start of current program
+        offsets_ragged = block_start_ragged + tl.arange(0, BLOCK_SIZE_RAGGED)
+        mask_ragged = offsets_ragged < ragged_end
+
+        idxs = (offsets_ragged[:, None] * M) + offsets_m
+        mask = mask_ragged[:, None] & mask_m
+
+        input = tl.load(input_ptr_values + idxs, mask=mask, other=0)
+
+        buffer += tl.sum(input, axis=0)
+
+    buffer_view = buffer.reshape(
+        (BLOCK_SIZE_M,),
+    )  # reshape buffer to 1D, as tl.sum may return a 2D tensor
+
+    output_offsets = offsets_m + (
+        pid_ragged * M
+    )  # output is offset by both ragged dimension and M-th dimension
+    output_mask = output_offsets < (M * (pid_ragged + 1))
+
+    tl.store(output_ptr + output_offsets, buffer_view, mask=output_mask)
+
+
+@triton.autotune(
+    configs=[
+        triton.Config(
+            {
+                "BLOCK_SIZE_RAGGED": b_r,
+                "BLOCK_SIZE_M": b_m,
+            },
+            num_warps=w,
+            num_stages=s,
+        )
+        for b_r, b_m, w, s in itertools.product(
+            BLOCK_SIZES,  # block sizes on ragged (reduction) dimension
+            BLOCK_SIZES,  # block sizes on M-th (non-reduction) dimension
+            NUM_WARPS,  # number of warps
+            NUM_STAGES,  # number of stages
+        )
+    ],
+    key=["M"],
+)
+@triton.jit
+def triton_jagged_sum_kernel_simple_fused_buffer_then_sum(
+    input_ptr_values,  # pointer to input values (2D tensor)
+    input_ptr_offsets,  # pointer to input offsets (1D tensor)
+    output_ptr,  # pointer to output tensor (2D tensor)
+    # matrix dimensions (input)
+    M,  # number of elements in M-th dimension, with logical dimensions (B, *, M)
+    MAX_SEQLEN,  # max length of ragged dimension
+    # block sizes (input)
+    BLOCK_SIZE_RAGGED: tl.constexpr,  # number of elements in ragged dimension per block, with logical dimensions (B, *, M)
+    BLOCK_SIZE_M: tl.constexpr,  # number of elements in M-th dimension per block, with logical dimensions (B, *, M)
+):
+    pid = tl.program_id(axis=0)  # i-th tensor in nested tensor
+    pid_ragged = pid // tl.cdiv(M, BLOCK_SIZE_M)
+    pid_m = pid % tl.cdiv(M, BLOCK_SIZE_M)
+
+    buffer = tl.zeros(
+        (BLOCK_SIZE_RAGGED, BLOCK_SIZE_M), dtype=tl.float32
+    )  # create 2D buffer that accumulates unreduced blocks
+
+    block_start_m = pid_m * BLOCK_SIZE_M
+    offsets_m = block_start_m + tl.arange(0, BLOCK_SIZE_M)
+    mask_m = offsets_m < M
+
+    ragged_start, ragged_end = tl.load(input_ptr_offsets + pid_ragged), tl.load(
+        input_ptr_offsets + (pid_ragged + 1)
+    )  # load start and end offsets for current program, similar to offsets[i] and offsets[i + 1]
+
+    for block_pos in range(
+        0, MAX_SEQLEN, BLOCK_SIZE_RAGGED
+    ):  # loop over ragged dimension, ranging until maximum seqlen
+        block_start_ragged = ragged_start + block_pos  # offset block position by start of current program
+        offsets_ragged = block_start_ragged + tl.arange(0, BLOCK_SIZE_RAGGED)
+        mask_ragged = offsets_ragged < ragged_end
+
+        idxs = (offsets_ragged[:, None] * M) + offsets_m
+        mask = mask_ragged[:, None] & mask_m
+
+        buffer += tl.load(input_ptr_values + idxs, mask=mask, other=0)
+
+    buffer_sum = tl.sum(buffer, axis=0)
+
+    buffer_view = buffer_sum.reshape(
+        (BLOCK_SIZE_M,),
+    )  # reshape buffer to 1D, as tl.sum may return a 2D tensor
+
+    output_offsets = offsets_m + (
+        pid_ragged * M
+    )  # output is offset by both ragged dimension and M-th dimension
+    output_mask = output_offsets < (M * (pid_ragged + 1))
+
+    tl.store(output_ptr + output_offsets, buffer_view, mask=output_mask)
diff --git a/torchbenchmark/operators/jagged_sum/operator.py b/torchbenchmark/operators/jagged_sum/operator.py
index fb650e028..775982698 100644
--- a/torchbenchmark/operators/jagged_sum/operator.py
+++ b/torchbenchmark/operators/jagged_sum/operator.py
@@ -15,6 +15,11 @@
     register_metric,
 )
 
+from .kernels import (
+    triton_jagged_sum_kernel_simple_fused_buffer_then_sum,
+    triton_jagged_sum_kernel_simple_fused_sum_then_buffer,
+)
+
 seed = 16
 random.seed(seed)
 torch.manual_seed(seed)
@@ -38,21 +43,57 @@ def parse_op_args(args: List[str]):
         default=0.5,
         help="Average sparsity for nested tensor (float, (0.0-1.0))",
     )
+    parser.add_argument(
+        "--sum-then-buffer",
+        type=int,  # 1: sum then buffer, 0: buffer then sum
+        default=1,
+        help="[Optional] For Triton kernels, determines whether to sum individual blocks then add to a buffer or add to a buffer then sum; 1: sum then buffer, 0: buffer then sum",
+    )
     return parser.parse_args(args)
 
 
+def execute_kernel_simple_fused(x, max_seqlen, sum_then_buffer):
+    B, M = x.shape[0], x.shape[2]
+    grid = lambda meta: ((len(x.offsets()) - 1) * triton.cdiv(M, meta["BLOCK_SIZE_M"]),)
+    kernel_output = torch.zeros((B, M), device=x.device)
+
+    if sum_then_buffer:
+        triton_jagged_sum_kernel_simple_fused_sum_then_buffer[grid](
+            x.values(),
+            x.offsets(),
+            kernel_output,
+            M=M,
+            MAX_SEQLEN=max_seqlen,
+        )
+    else:
+        triton_jagged_sum_kernel_simple_fused_buffer_then_sum[grid](
+            x.values(),
+            x.offsets(),
+            kernel_output,
+            M=M,
+            MAX_SEQLEN=max_seqlen,
+        )
+
+    return kernel_output
+
+
 class Operator(BenchmarkOperator):
 
     DEFAULT_METRICS = ["latency", "accuracy"]
-    use_cuda_graphs = False  # enables GPU/CPU sync (for methods like NestedTensor unbind)
+    use_cuda_graphs = (
+        False  # enables GPU/CPU sync (for methods like NestedTensor unbind)
+    )
 
     def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = None):
         super().__init__(mode=mode, device=device, extra_args=extra_args)
-        self.sizes = range(4, 10, 2)
+        self.sizes = list(range(2, 8, 2)) + list(
+            range(8, 12)
+        )  # bias towards larger sizes, which are more representative of real-world shapes
         args = parse_op_args(self.extra_args)
         self.seqlen = args.seqlen
         self.sparsity = args.sparsity
+        self.sum_then_buffer = args.sum_then_buffer
 
     @register_benchmark(baseline=True)
     def torch_jagged_sum_no_pad(self, x: torch.Tensor):
@@ -75,6 +116,13 @@ def torch_jagged_sum_pad(self, x: torch.Tensor):
             dim=1,
         )  # sum along ragged dimension (dim == 1)
 
+    @register_benchmark()
+    def triton_jagged_sum_no_pad(self, x: torch.Tensor):
+        def _inner():
+            return execute_kernel_simple_fused(x, self.seqlen, self.sum_then_buffer)
+
+        return _inner
+
     def get_x_val(self, example_inputs):
         return len(example_inputs[0])
 
@@ -156,7 +204,7 @@ def gbps(self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics):
             / metrics.latency
             * GIGABYTES_PER_BYTE
         )
-    
+
     @register_metric(x_only=True)
     def input_shape(
         self, fn_name: str, example_inputs, metrics: BenchmarkOperatorMetrics
@@ -166,3 +214,15 @@ def input_shape(
             "*",
             example_inputs[0].shape[2],
         )  # return (B, '*', M) for each example input
+
+    @register_metric(skip_baseline=True)
+    def best_config(
+        self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics
+    ) -> str:
+        if self.sum_then_buffer:
+            return dump_autotuner_best_config(
+                triton_jagged_sum_kernel_simple_fused_sum_then_buffer
+            )
+        return dump_autotuner_best_config(
+            triton_jagged_sum_kernel_simple_fused_buffer_then_sum
+        )
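
Usage sketch (not part of the diff above): a minimal check of the fused kernel
wrapper against a plain PyTorch reduction over the same nested tensor. The
sizes (B=8, M=128, max_seqlen=64), the CUDA device, and the jagged-layout
nested-tensor construction are illustrative assumptions and require a recent
PyTorch that supports the torch.jagged layout; the only name taken from the
patch is `execute_kernel_simple_fused`, and the import assumes the TritonBench
repo root is on PYTHONPATH.

import torch

from torchbenchmark.operators.jagged_sum.operator import execute_kernel_simple_fused

B, M, max_seqlen = 8, 128, 64  # illustrative sizes, not prescribed by the patch
device = "cuda"

# Build a nested tensor of logical shape (B, *, M): each row gets a random
# ragged length in [1, max_seqlen], producing the (values, offsets) pair the
# kernel indexes into.
lengths = torch.randint(1, max_seqlen + 1, (B,))
x = torch.nested.nested_tensor(
    [torch.randn(int(n), M, device=device) for n in lengths],
    layout=torch.jagged,
)

# Reference: sum each variable-length row along its ragged dimension.
expected = torch.stack([row.sum(dim=0) for row in x.unbind()])

# Fused Triton kernel, sum-then-buffer variant (pass 0 for buffer-then-sum).
actual = execute_kernel_simple_fused(x, max_seqlen, sum_then_buffer=1)

torch.testing.assert_close(actual, expected, rtol=1e-3, atol=1e-3)

Both variants write the same (B, M) output; only the accumulation order inside
each program differs, which is what the --sum-then-buffer flag toggles in the
benchmark.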