Skip to content

Commit

Permalink
Add unit tests on CPU for TritonBench features (#2323)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2323

Add unit tests that run on the CPU to verify the behavior of the following:
- `x_only = True` for metric registration in [`register_metric()`](https://www.internalfb.com/code/fbsource/[731f07681fbbb38750aee3b165137e39fa6cee50]/fbcode/pytorch/benchmark/torchbenchmark/util/triton_op.py?lines=337)
- custom `label` argument for benchmark registration in [`register_benchmark()`](https://www.internalfb.com/code/fbsource/[731f07681fbbb38750aee3b165137e39fa6cee50]/fbcode/pytorch/benchmark/torchbenchmark/util/triton_op.py?lines=316)

Reviewed By: xuzhao9

Differential Revision: D58558868
  • Loading branch information
jananisriram authored and facebook-github-bot committed Jun 20, 2024
1 parent caa76d8 commit 3e9ceec
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 0 deletions.
1 change: 1 addition & 0 deletions torchbenchmark/operators/test_op/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .operator import Operator
44 changes: 44 additions & 0 deletions torchbenchmark/operators/test_op/operator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from typing import Generator, List, Optional

import torch

from torchbenchmark.util.triton_op import (
BenchmarkOperator,
BenchmarkOperatorMetrics,
register_benchmark,
register_metric,
)


class Operator(BenchmarkOperator):

DEFAULT_METRICS = ["test_metric"]

def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = None):
super().__init__(mode=mode, device=device, extra_args=extra_args)

@register_benchmark(label="new_op_label")
def test_op(self, x: torch.Tensor):
return lambda: x

def get_x_val(self, example_inputs):
return example_inputs[0].shape

def get_x_vals(self) -> List[int]:
return [2**n for n in [1, 2, 3]]

def get_input_iter(self) -> Generator:
for x in self.get_x_vals():
yield (torch.Tensor(torch.randn(x, device=self.device, dtype=self.dtype)),)

@register_metric(x_only=True)
def test_metric(
self, fn_name: str, example_inputs, metrics: BenchmarkOperatorMetrics
):
return [ex.shape[0] + 2 for ex in example_inputs]

@register_metric()
def test_metric_per_benchmark(
self, fn_name: str, example_inputs, metrics: BenchmarkOperatorMetrics
):
return [ex.shape[0] + 3 for ex in example_inputs]
1 change: 1 addition & 0 deletions torchbenchmark/util/triton_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,7 @@ def __init__(self, mode: str, device: str, extra_args: Optional[List[str]]=None)
self._only = _split_params_by_comma(self.tb_args.only)
self._input_id = self.tb_args.input_id
self._num_inputs = self.tb_args.num_inputs
self.device = device

# Run the post initialization
def __post__init__(self):
Expand Down

0 comments on commit 3e9ceec

Please sign in to comment.