From f4cbf782eccc759c13a013beacc8f79534dbd642 Mon Sep 17 00:00:00 2001 From: Janani Sriram Date: Fri, 14 Jun 2024 15:57:56 -0700 Subject: [PATCH] Extend support to varying block sizes on both dimensions for 2D matrices (#2302) Summary: Pull Request resolved: https://github.com/pytorch/benchmark/pull/2302 Extend support for reducing across individual dimensions on 2-dimensional matrices by allowing for varying block sizes on both the `M` (first) and `N` (second) dimensions. The existing kernel performed a simplified reduction, assuming that the entire reduction dimension fit within one thread block. The new kernel implementation removes the need for this assumption, allowing both the reduction and the non-reduction dimensions to fit in multiple thread blocks. This implementation also enables autotuning on block sizes for both the `M` and `N` dimensions. For 1D results, add a `sum_then_buffer` configuration which decides which kernel configuration to run. `Sum_then_buffer` sums individual blocks of input and adds these sums into a buffer. `Buffer_then_sum` adds blocks of raw input into a buffer, then reduces the buffer. Reviewed By: davidberard98 Differential Revision: D58313958 fbshipit-source-id: 639ea6b7d7b92f478c0f5669a1cdc0dcb68004e3 --- torchbenchmark/operators/sum/kernels.py | 154 +++++++++++++++++------ torchbenchmark/operators/sum/operator.py | 47 ++++--- 2 files changed, 150 insertions(+), 51 deletions(-) diff --git a/torchbenchmark/operators/sum/kernels.py b/torchbenchmark/operators/sum/kernels.py index 2a39f2eb45..3572d65211 100644 --- a/torchbenchmark/operators/sum/kernels.py +++ b/torchbenchmark/operators/sum/kernels.py @@ -44,7 +44,10 @@ def triton_sum_kernel_scalar_result( @triton.autotune( configs=[ triton.Config( - {"BLOCK_SIZE_NON_REDUCE_DIM": b}, + { + "BLOCK_SIZE_NON_REDUCE_DIM": b, + "BLOCK_SIZE_REDUCE_DIM": b, + }, num_warps=w, ) for b, w in itertools.product( @@ -54,7 +57,7 @@ def triton_sum_kernel_scalar_result( key=["M", "N"], ) @triton.jit -def triton_sum_kernel_1D_result( +def triton_sum_kernel_1D_result_sum_then_buffer( input_ptr, # pointer to input matrix output_ptr, # pointer to output matrix # matrix dimensions (input) @@ -66,49 +69,128 @@ def triton_sum_kernel_1D_result( # reduction dimension dim: tl.constexpr, # dimension along which to sum ): + """ + Sum blocks of input using Triton and store in buffer + """ + pid = tl.program_id(axis=0) # i-th block of input - block_start_m, block_start_n = 0, 0 - offsets_m, offsets_n = None, None - if dim == 0: - block_start_n = pid * BLOCK_SIZE_REDUCE_DIM - # offsets have shape equal to input shape - offsets_m = block_start_m + tl.arange( + reduce_dim_len = M if dim == 0 else N + non_reduce_dim_len = N if dim == 0 else M + + buffer = tl.zeros( + (1, BLOCK_SIZE_NON_REDUCE_DIM), dtype=tl.float32 + ) # create buffer as a row tensor + + block_start_non_reduce_dim = pid * BLOCK_SIZE_NON_REDUCE_DIM + offsets_non_reduce_dim = block_start_non_reduce_dim + tl.arange( + 0, BLOCK_SIZE_NON_REDUCE_DIM + ) + mask_non_reduce_dim = offsets_non_reduce_dim < non_reduce_dim_len + + for block_start_reduce_dim in range(0, reduce_dim_len, BLOCK_SIZE_REDUCE_DIM): + offsets_reduce_dim = block_start_reduce_dim + tl.arange( 0, BLOCK_SIZE_REDUCE_DIM - ) # create 1D vector for offsets on M-th dimension - offsets_n = block_start_n + tl.arange( - 0, BLOCK_SIZE_NON_REDUCE_DIM - ) # create 1D vector for offsets on N-th dimension - elif dim == 1: - block_start_m = pid * BLOCK_SIZE_REDUCE_DIM - # offsets have shape equal to input shape - offsets_m = block_start_m + tl.arange( - 0, BLOCK_SIZE_NON_REDUCE_DIM - ) # create 1D vector for offsets on M-th dimension - offsets_n = block_start_n + tl.arange( + ) + mask_reduce_dim = offsets_reduce_dim < reduce_dim_len + + idxs, mask = None, None + if dim == 0: + idxs = ( + offsets_reduce_dim[:, None] * non_reduce_dim_len + ) + offsets_non_reduce_dim + mask = mask_reduce_dim[:, None] & mask_non_reduce_dim + elif dim == 1: + idxs = ( + offsets_non_reduce_dim[:, None] * reduce_dim_len + ) + offsets_reduce_dim + mask = mask_non_reduce_dim[:, None] & mask_reduce_dim + + input = tl.load(input_ptr + idxs, mask=mask, other=mask) + + buffer += tl.sum(input, axis=dim) + + buffer_view = buffer.reshape( + (BLOCK_SIZE_NON_REDUCE_DIM,), can_reorder=True + ) # reshape buffer to 1D, as tl.sum may return a 2D tensor + + tl.store(output_ptr + offsets_non_reduce_dim, buffer_view, mask=mask_non_reduce_dim) + + +@triton.autotune( + configs=[ + triton.Config( + { + "BLOCK_SIZE_NON_REDUCE_DIM": b, + "BLOCK_SIZE_REDUCE_DIM": b, + }, + num_warps=w, + ) + for b, w in itertools.product( + [2, 4, 8, 16], [2, 4, 8] # block sizes # number of warps + ) + ], + key=["M", "N"], +) +@triton.jit +def triton_sum_kernel_1D_result_buffer_then_sum( + input_ptr, # pointer to input matrix + output_ptr, # pointer to output matrix + # matrix dimensions (input) + M, # number of rows + N, # number of columns + # block sizes (input) + BLOCK_SIZE_NON_REDUCE_DIM: tl.constexpr, # number of elements in non-reduction dimension per block + BLOCK_SIZE_REDUCE_DIM: tl.constexpr, # number of elements in reduction dimension per block + # reduction dimension + dim: tl.constexpr, # dimension along which to sum +): + """ + Add blocks of input to a buffer and sum the buffer using Triton + """ + + pid = tl.program_id(axis=0) # i-th block of input + + reduce_dim_len = M if dim == 0 else N + non_reduce_dim_len = N if dim == 0 else M + + buffer = tl.zeros( + (BLOCK_SIZE_REDUCE_DIM, BLOCK_SIZE_NON_REDUCE_DIM), dtype=tl.float32 + ) # create buffer as a 2D tensor + + block_start_non_reduce_dim = pid * BLOCK_SIZE_NON_REDUCE_DIM + offsets_non_reduce_dim = block_start_non_reduce_dim + tl.arange( + 0, BLOCK_SIZE_NON_REDUCE_DIM + ) + mask_non_reduce_dim = offsets_non_reduce_dim < non_reduce_dim_len + + for block_start_reduce_dim in range(0, reduce_dim_len, BLOCK_SIZE_REDUCE_DIM): + offsets_reduce_dim = block_start_reduce_dim + tl.arange( 0, BLOCK_SIZE_REDUCE_DIM - ) # create 1D vector for offsets on N-th dimension + ) + mask_reduce_dim = offsets_reduce_dim < reduce_dim_len - # mask has shape equal to input shape - mask_m = offsets_m < M - mask_n = offsets_n < N + idxs, mask = None, None + if dim == 0: + idxs = ( + offsets_reduce_dim[:, None] * non_reduce_dim_len + ) + offsets_non_reduce_dim + mask = mask_reduce_dim[:, None] & mask_non_reduce_dim + elif dim == 1: + idxs = ( + offsets_non_reduce_dim[:, None] * reduce_dim_len + ) + offsets_reduce_dim + mask = mask_non_reduce_dim[:, None] & mask_reduce_dim - # create 2D matrices of pointers and masks, using above M and N vectors - idxs = (offsets_m[:, None] * N) + offsets_n - mask = mask_m[:, None] & mask_n + buffer += tl.load(input_ptr + idxs, mask=mask, other=mask) - # loaded pointers have shape equal to input shape - input = tl.load( - input_ptr + idxs, mask=mask, other=mask - ) # other=mask zeros out masked values from input + buffer_sum = tl.sum(buffer, axis=dim) - output = tl.sum(input, axis=dim) + buffer_view = buffer_sum.reshape( + (BLOCK_SIZE_NON_REDUCE_DIM,), can_reorder=True + ) # reshape buffer to 1D, as tl.sum may return a 2D tensor - # stored pointers have shape equal to output shape - if dim == 0: # store output along N-th dimension - tl.store(output_ptr + offsets_n, output, mask=mask_n) - elif dim == 1: # store output along M-th dimension - tl.store(output_ptr + offsets_m, output, mask=mask_m) + tl.store(output_ptr + offsets_non_reduce_dim, buffer_view, mask=mask_non_reduce_dim) @triton.autotune( diff --git a/torchbenchmark/operators/sum/operator.py b/torchbenchmark/operators/sum/operator.py index f6f480adfe..fd18be5873 100644 --- a/torchbenchmark/operators/sum/operator.py +++ b/torchbenchmark/operators/sum/operator.py @@ -13,7 +13,8 @@ ) from .kernels import ( - triton_sum_kernel_1D_result, + triton_sum_kernel_1D_result_buffer_then_sum, + triton_sum_kernel_1D_result_sum_then_buffer, triton_sum_kernel_2D_result_dim_1, triton_sum_kernel_scalar_result, ) @@ -28,6 +29,12 @@ def parse_op_args(args: List[str]): default=None, help="[Optional] Dimension(s) on which kernel performs reduction; e.g. --reduce-dim 0, --reduce-dim 0 1", ) + parser.add_argument( + "--sum-then-buffer", + type=int, # 1: sum then buffer, 0: buffer then sum + default=1, + help="[Optional] For 1D results, determines whether to sum individual blocks then add to a buffer or add to a buffer then sum; 1: sum then buffer, 0: buffer then sum", + ) return parser.parse_args(args) @@ -41,6 +48,7 @@ def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = Non self.reduce_dim = ( args.reduce_dim if args.reduce_dim else None ) # for 2D case, guaranteed to be a list with 1 integer + self.sum_then_buffer = args.sum_then_buffer self.sizes = range(1, 11) @register_benchmark() @@ -61,13 +69,12 @@ def triton_sum(self, x: torch.Tensor): ) # race condition in cases where BLOCK_SIZE < n_elements^2 elif x.dim() == 2 and num_output_dims == 1: M, N = x.shape - BLOCK_SIZE_M, BLOCK_SIZE_N = triton.next_power_of_2( - M - ), triton.next_power_of_2(N) grid = lambda meta: ( max( triton.cdiv(M, meta["BLOCK_SIZE_REDUCE_DIM"]), triton.cdiv(N, meta["BLOCK_SIZE_NON_REDUCE_DIM"]), + triton.cdiv(M, meta["BLOCK_SIZE_NON_REDUCE_DIM"]), + triton.cdiv(N, meta["BLOCK_SIZE_REDUCE_DIM"]), ), ) elif x.dim() == 3 and num_output_dims == 2 and self.reduce_dim[0] == 1: @@ -94,23 +101,29 @@ def _inner(): elif kernel_input.dim() == 2 and num_output_dims == 1: if self.reduce_dim[0] == 0: kernel_output = torch.empty(N, device=self.device) - BLOCK_SIZE_REDUCE_DIM = BLOCK_SIZE_M elif self.reduce_dim[0] == 1: kernel_output = torch.empty(M, device=self.device) - BLOCK_SIZE_REDUCE_DIM = BLOCK_SIZE_N else: raise Exception( f"Existing sum Triton kernels do not support reducing input with shape {kernel_input.size} along dimension(s) {self.reduce_dim}" ) - triton_sum_kernel_1D_result[grid]( - kernel_input, - kernel_output, - M=M, - N=N, - BLOCK_SIZE_REDUCE_DIM=BLOCK_SIZE_REDUCE_DIM, - dim=self.reduce_dim[0], - ) + if self.sum_then_buffer: + triton_sum_kernel_1D_result_sum_then_buffer[grid]( + kernel_input, + kernel_output, + M=M, + N=N, + dim=self.reduce_dim[0], + ) + else: + triton_sum_kernel_1D_result_buffer_then_sum[grid]( + kernel_input, + kernel_output, + M=M, + N=N, + dim=self.reduce_dim[0], + ) elif ( kernel_input.dim() == 3 and num_output_dims == 2 @@ -201,6 +214,10 @@ def best_config( if example_inputs[0].dim() == 3 and self.reduce_dim and self.reduce_dim[0] == 1: return dump_autotuner_best_config(triton_sum_kernel_2D_result_dim_1) elif self.reduce_dim and len(self.reduce_dim) < example_inputs[0].dim(): - return dump_autotuner_best_config(triton_sum_kernel_1D_result) + if self.sum_then_buffer: + return dump_autotuner_best_config( + triton_sum_kernel_1D_result_sum_then_buffer + ) + return dump_autotuner_best_config(triton_sum_kernel_1D_result_buffer_then_sum) else: return ""