diff --git a/torchbenchmark/operators/sum/operator.py b/torchbenchmark/operators/sum/operator.py index 262061d68..f313c222c 100644 --- a/torchbenchmark/operators/sum/operator.py +++ b/torchbenchmark/operators/sum/operator.py @@ -1,4 +1,5 @@ import argparse +import itertools from typing import Callable, Generator, List, Optional, Tuple import torch @@ -19,6 +20,9 @@ triton_sum_kernel_scalar_result, ) +GIGABYTES_PER_BYTE = 1e-6 +ABSOLUTE_TOLERANCE = 1e-3 + def parse_op_args(args: List[str]): parser = argparse.ArgumentParser() @@ -132,7 +136,7 @@ def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = Non self.input_dim = args.input_dim self.reduce_dim = args.reduce_dim self.sum_then_buffer = args.sum_then_buffer - self.sizes = range(1, 11) + self.sizes = range(1, 11, 2) @register_benchmark() def triton_sum(self, x: torch.Tensor): @@ -191,11 +195,10 @@ def get_input_iter(self) -> Generator: self.input_dim <= 3 ), f"Existing sum Triton kernels do not support input dimension {self.input_dim}" - for size in self.get_x_vals(): + sizes = itertools.product(self.get_x_vals(), repeat=self.input_dim) + for size in sizes: input_tensor = torch.randn( - tuple( - [size for _ in range(self.input_dim)] - ), # tuple with self.input_dim dimensions + size, # tuple with self.input_dim dimensions device=self.device, dtype=self.dtype, ) @@ -204,7 +207,7 @@ def get_input_iter(self) -> Generator: def _get_accuracy(self, fn: Callable, baseline_fn: Callable) -> bool: output = fn() baseline_output = baseline_fn() - return torch.allclose(output, baseline_output, atol=1e-3) + return torch.allclose(output, baseline_output, atol=ABSOLUTE_TOLERANCE) @register_metric() def gbps(self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics): @@ -212,7 +215,7 @@ def gbps(self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics): example_inputs[0].element_size() * example_inputs[0].numel() / metrics.latency - * 1e-6 + * GIGABYTES_PER_BYTE ) @register_metric(skip_baseline=True) @@ -231,3 +234,9 @@ def best_config( return dump_autotuner_best_config(triton_sum_kernel_2D_result_dim_1) else: return "" + + @register_metric(x_only=True) + def input_shape( + self, fn_name: str, example_inputs, metrics: BenchmarkOperatorMetrics + ): + return example_inputs[0].shape # return (B, M) for example input