Skip to content

Commit

Permalink
Add simple fused Triton kernel for jagged_sum operator
Browse files Browse the repository at this point in the history
Summary:
Add Triton kernel implementation to `jagged_sum` operator in TritonBench. This Triton kernel performs a sum along the ragged dimension of a nested tensor of logical dimensions `(B, *, M)`, where `*` is the ragged dimension.  It loads in blocks of the `values` tensor along its last dimension `M`, reduces each block of variable length along its first dimension `*`, and stores each of `B` reductions in an output tensor of shape `(B, M)`.

This Triton kernel is benchmarked against two PyTorch implementations, one which does not pad blocks of variable length and one which does pad.

Reviewed By: davidberard98

Differential Revision: D58549297
  • Loading branch information
jananisriram authored and facebook-github-bot committed Jun 20, 2024
1 parent caa76d8 commit e16c063
Show file tree
Hide file tree
Showing 2 changed files with 218 additions and 3 deletions.
155 changes: 155 additions & 0 deletions torchbenchmark/operators/jagged_sum/kernels.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
import itertools

import triton
import triton.language as tl


BLOCK_SIZES = [2**n for n in range(2, 11, 3)]
NUM_WARPS = [2, 4, 8]
NUM_STAGES = [2, 4, 8]


@triton.autotune(
configs=[
triton.Config(
{
"BLOCK_SIZE_RAGGED": b_r,
"BLOCK_SIZE_M": b_m,
},
num_warps=w,
num_stages=s,
)
for b_r, b_m, w, s in itertools.product(
BLOCK_SIZES, # block sizes on non-reduction dimension
BLOCK_SIZES, # block sizes on reduction dimension
NUM_WARPS, # number of warps
NUM_STAGES, # number of stages
)
],
key=["M"],
)
@triton.jit
def triton_jagged_sum_kernel_simple_fused_sum_then_buffer(
input_ptr_values, # pointer to input values (2D tensor)
input_ptr_offsets, # pointer to input offsets (1D tensor)
output_ptr, # pointer to output tensor (2D tensor)
# matrix dimensions (input)
M, # number of elements in M-th dimension, with logical dimensions (B, *, M)
MAX_SEQLEN, # max length of ragged dimension
# block sizes (input)
BLOCK_SIZE_RAGGED: tl.constexpr, # number of elements in ragged dimension per block, with logical dimensions (B, *, M)
BLOCK_SIZE_M: tl.constexpr, # number of elements in M-th dimension per block, with logical dimensions (B, *, M)
):
pid = tl.program_id(axis=0) # i-th tensor in nested tensor
pid_ragged = pid // tl.cdiv(M, BLOCK_SIZE_M)
pid_m = pid % tl.cdiv(M, BLOCK_SIZE_M)

buffer = tl.zeros(
(1, BLOCK_SIZE_M), dtype=tl.float32
) # create buffer as a row tensor

block_start_m = pid_m * BLOCK_SIZE_M
offsets_m = block_start_m + tl.arange(0, BLOCK_SIZE_M)
mask_m = offsets_m < M

ragged_start, ragged_end = tl.load(input_ptr_offsets + pid_ragged), tl.load(
input_ptr_offsets + (pid_ragged + 1)
) # load start and end offsets for current program, similar to offsets[i] and offsets[i + 1]

for block_pos in range(
0, MAX_SEQLEN, BLOCK_SIZE_RAGGED
): # loop over ragged dimension, ranging until maximum seqlen
block_start_ragged = ragged_start + block_pos # offset block position by start of current program
offsets_ragged = block_start_ragged + tl.arange(0, BLOCK_SIZE_RAGGED)
mask_ragged = offsets_ragged < ragged_end

idxs = (offsets_ragged[:, None] * M) + offsets_m
mask = mask_ragged[:, None] & mask_m

input = tl.load(input_ptr_values + idxs, mask=mask, other=0)

buffer += tl.sum(input, axis=0)

buffer_view = buffer.reshape(
(BLOCK_SIZE_M,),
) # reshape buffer to 1D, as tl.sum may return a 2D tensor

output_offsets = offsets_m + (
pid_ragged * M
) # output is offset by both ragged dimension and M-th dimension
output_mask = output_offsets < (M * (pid_ragged + 1))

tl.store(output_ptr + output_offsets, buffer_view, mask=output_mask)


@triton.autotune(
configs=[
triton.Config(
{
"BLOCK_SIZE_RAGGED": b_r,
"BLOCK_SIZE_M": b_m,
},
num_warps=w,
num_stages=s,
)
for b_r, b_m, w, s in itertools.product(
BLOCK_SIZES, # block sizes on non-reduction dimension
BLOCK_SIZES, # block sizes on reduction dimension
NUM_WARPS, # number of warps
NUM_STAGES, # number of stages
)
],
key=["M"],
)
@triton.jit
def triton_jagged_sum_kernel_simple_fused_buffer_then_sum(
input_ptr_values, # pointer to input values (2D tensor)
input_ptr_offsets, # pointer to input offsets (1D tensor)
output_ptr, # pointer to output tensor (2D tensor)
# matrix dimensions (input)
M, # number of elements in M-th dimension, with logical dimensions (B, *, M)
MAX_SEQLEN, # max length of ragged dimension
# block sizes (input)
BLOCK_SIZE_RAGGED: tl.constexpr, # number of elements in ragged dimension per block, with logical dimensions (B, *, M)
BLOCK_SIZE_M: tl.constexpr, # number of elements in M-th dimension per block, with logical dimensions (B, *, M)
):
pid = tl.program_id(axis=0) # i-th tensor in nested tensor
pid_ragged = pid // tl.cdiv(M, BLOCK_SIZE_M)
pid_m = pid % tl.cdiv(M, BLOCK_SIZE_M)

buffer = tl.zeros(
(BLOCK_SIZE_RAGGED, BLOCK_SIZE_M), dtype=tl.float32
) # create buffer as a row tensor

block_start_m = pid_m * BLOCK_SIZE_M
offsets_m = block_start_m + tl.arange(0, BLOCK_SIZE_M)
mask_m = offsets_m < M

ragged_start, ragged_end = tl.load(input_ptr_offsets + pid_ragged), tl.load(
input_ptr_offsets + (pid_ragged + 1)
) # load start and end offsets for current program, similar to offsets[i] and offsets[i + 1]

for block_pos in range(
0, MAX_SEQLEN, BLOCK_SIZE_RAGGED
): # loop over ragged dimension, ranging until maximum seqlen
block_start_ragged = ragged_start + block_pos # offset block position by start of current program
offsets_ragged = block_start_ragged + tl.arange(0, BLOCK_SIZE_RAGGED)
mask_ragged = offsets_ragged < ragged_end

idxs = (offsets_ragged[:, None] * M) + offsets_m
mask = mask_ragged[:, None] & mask_m

buffer += tl.load(input_ptr_values + idxs, mask=mask, other=0)

buffer_sum = tl.sum(buffer, axis=0)

buffer_view = buffer_sum.reshape(
(BLOCK_SIZE_M,),
) # reshape buffer to 1D, as tl.sum may return a 2D tensor

output_offsets = offsets_m + (
pid_ragged * M
) # output is offset by both ragged dimension and M-th dimension
output_mask = output_offsets < (M * (pid_ragged + 1))

tl.store(output_ptr + output_offsets, buffer_view, mask=output_mask)
66 changes: 63 additions & 3 deletions torchbenchmark/operators/jagged_sum/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@
register_metric,
)

from .kernels import (
triton_jagged_sum_kernel_simple_fused_buffer_then_sum,
triton_jagged_sum_kernel_simple_fused_sum_then_buffer,
)

seed = 16
random.seed(seed)
torch.manual_seed(seed)
Expand All @@ -38,21 +43,57 @@ def parse_op_args(args: List[str]):
default=0.5,
help="Average sparsity for nested tensor (float, (0.0-1.0))",
)
parser.add_argument(
"--sum-then-buffer",
type=int, # 1: sum then buffer, 0: buffer then sum
default=1,
help="[Optional] For Triton kernels, determines whether to sum individual blocks then add to a buffer or add to a buffer then sum; 1: sum then buffer, 0: buffer then sum",
)
return parser.parse_args(args)


def execute_kernel_simple_fused(x, max_seqlen, sum_then_buffer):
B, M = x.shape[0], x.shape[2]
grid = lambda meta: ((len(x.offsets()) - 1) * triton.cdiv(M, meta["BLOCK_SIZE_M"]),)
kernel_output = torch.zeros((B, M), device=x.device)

if sum_then_buffer:
triton_jagged_sum_kernel_simple_fused_sum_then_buffer[grid](
x.values(),
x.offsets(),
kernel_output,
M=M,
MAX_SEQLEN=max_seqlen,
)
else:
triton_jagged_sum_kernel_simple_fused_buffer_then_sum[grid](
x.values(),
x.offsets(),
kernel_output,
M=M,
MAX_SEQLEN=max_seqlen,
)

return kernel_output


class Operator(BenchmarkOperator):

DEFAULT_METRICS = ["latency", "accuracy"]
use_cuda_graphs = False # enables GPU/CPU sync (for methods like NestedTensor unbind)
use_cuda_graphs = (
False # enables GPU/CPU sync (for methods like NestedTensor unbind)
)

def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = None):
super().__init__(mode=mode, device=device, extra_args=extra_args)
self.sizes = range(4, 10, 2)
self.sizes = list(range(2, 8, 2)) + list(
range(8, 12)
) # bias towards larger sizes, which are more representative of real-world shapes

args = parse_op_args(self.extra_args)
self.seqlen = args.seqlen
self.sparsity = args.sparsity
self.sum_then_buffer = args.sum_then_buffer

@register_benchmark(baseline=True)
def torch_jagged_sum_no_pad(self, x: torch.Tensor):
Expand All @@ -75,6 +116,13 @@ def torch_jagged_sum_pad(self, x: torch.Tensor):
dim=1,
) # sum along ragged dimension (dim == 1)

@register_benchmark()
def triton_jagged_sum_no_pad(self, x: torch.Tensor):
def _inner():
return execute_kernel_simple_fused(x, self.seqlen, self.sum_then_buffer)

return _inner

def get_x_val(self, example_inputs):
return len(example_inputs[0])

Expand Down Expand Up @@ -156,7 +204,7 @@ def gbps(self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics):
/ metrics.latency
* GIGABYTES_PER_BYTE
)

@register_metric(x_only=True)
def input_shape(
self, fn_name: str, example_inputs, metrics: BenchmarkOperatorMetrics
Expand All @@ -166,3 +214,15 @@ def input_shape(
"*",
example_inputs[0].shape[2],
) # return (B, '*', M) for each example input

@register_metric(skip_baseline=True)
def best_config(
self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics
) -> str:
if self.sum_then_buffer:
return dump_autotuner_best_config(
triton_jagged_sum_kernel_simple_fused_sum_then_buffer
)
return dump_autotuner_best_config(
triton_jagged_sum_kernel_simple_fused_buffer_then_sum
)

0 comments on commit e16c063

Please sign in to comment.