Skip to content

Commit

Permalink
Add asymmetric shapes to test sum Triton kernels
Browse files Browse the repository at this point in the history
Summary:
Add asymmetric shapes to `get_input_iter()` in order to test accuracy and performance of `sum` Triton kernel implementations against PyTorch.

This diff generates tensors with dimensions of different sizes. For example, a 2D asymmetric tensor would have shape `(n, n + 3)`; a 3D asymmetric tensor would have shape `(n, n + 3, n + 6)`.

Reviewed By: jbschlosser

Differential Revision: D58509022
  • Loading branch information
jananisriram authored and facebook-github-bot committed Jun 17, 2024
1 parent f5d3de1 commit b64dc24
Showing 1 changed file with 16 additions and 7 deletions.
23 changes: 16 additions & 7 deletions torchbenchmark/operators/sum/operator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import argparse
import itertools
from typing import Callable, Generator, List, Optional, Tuple

import torch
Expand All @@ -19,6 +20,9 @@
triton_sum_kernel_scalar_result,
)

GIGABYTES_PER_BYTE = 1e-6
ABSOLUTE_TOLERANCE = 1e-3


def parse_op_args(args: List[str]):
parser = argparse.ArgumentParser()
Expand Down Expand Up @@ -132,7 +136,7 @@ def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = Non
self.input_dim = args.input_dim
self.reduce_dim = args.reduce_dim
self.sum_then_buffer = args.sum_then_buffer
self.sizes = range(1, 11)
self.sizes = range(1, 11, 2)

@register_benchmark()
def triton_sum(self, x: torch.Tensor):
Expand Down Expand Up @@ -191,11 +195,10 @@ def get_input_iter(self) -> Generator:
self.input_dim <= 3
), f"Existing sum Triton kernels do not support input dimension {self.input_dim}"

for size in self.get_x_vals():
sizes = itertools.product(self.get_x_vals(), repeat=self.input_dim)
for size in sizes:
input_tensor = torch.randn(
tuple(
[size for _ in range(self.input_dim)]
), # tuple with self.input_dim dimensions
size, # tuple with self.input_dim dimensions
device=self.device,
dtype=self.dtype,
)
Expand All @@ -204,15 +207,15 @@ def get_input_iter(self) -> Generator:
def _get_accuracy(self, fn: Callable, baseline_fn: Callable) -> bool:
output = fn()
baseline_output = baseline_fn()
return torch.allclose(output, baseline_output, atol=1e-3)
return torch.allclose(output, baseline_output, atol=ABSOLUTE_TOLERANCE)

@register_metric()
def gbps(self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics):
return (
example_inputs[0].element_size()
* example_inputs[0].numel()
/ metrics.latency
* 1e-6
* GIGABYTES_PER_BYTE
)

@register_metric(skip_baseline=True)
Expand All @@ -231,3 +234,9 @@ def best_config(
return dump_autotuner_best_config(triton_sum_kernel_2D_result_dim_1)
else:
return ""

@register_metric(x_only=True)
def input_shape(
self, fn_name: str, example_inputs, metrics: BenchmarkOperatorMetrics
):
return example_inputs[0].shape # return (B, M) for example input

0 comments on commit b64dc24

Please sign in to comment.