diff --git a/torchbenchmark/operators/jagged_sum/operator.py b/torchbenchmark/operators/jagged_sum/operator.py index 0d1a3ff748..cb93b26d40 100644 --- a/torchbenchmark/operators/jagged_sum/operator.py +++ b/torchbenchmark/operators/jagged_sum/operator.py @@ -1,4 +1,5 @@ import argparse +import itertools import math import random from typing import Callable, Generator, List, Optional, Tuple @@ -14,11 +15,13 @@ register_metric, ) -random.seed(16) -torch.manual_seed(16) +seed = 16 +random.seed(seed) +torch.manual_seed(seed) GIGABYTES_PER_BYTE = 1e-6 RANDOM_CHOICE_MARGIN = 0.3 +ABSOLUTE_TOLERANCE = 1e-3 def parse_op_args(args: List[str]): @@ -26,7 +29,7 @@ def parse_op_args(args: List[str]): parser.add_argument( "--seqlen", type=int, - default=100, + default=500, help="Maximum sequence length on ragged dimension (integer)", ) parser.add_argument( @@ -40,6 +43,8 @@ def parse_op_args(args: List[str]): class Operator(BenchmarkOperator): + DEFAULT_METRICS = ["latency", "accuracy"] + def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = None): super().__init__(mode=mode, device=device, extra_args=extra_args) self.sizes = range(4, 10, 2) @@ -58,6 +63,17 @@ def torch_jagged_sum_no_pad(self, x: torch.Tensor): dtype=self.dtype, ) + @register_benchmark() + def torch_jagged_sum_pad(self, x: torch.Tensor): + return lambda: torch.sum( + torch.ops.aten._jagged_to_padded_dense_forward( + x.values(), + [x.offsets()], # pyre-ignore: Undefined attribute [16]: `torch._tensor.Tensor` has no attribute `offsets`. + max_lengths=[self.seqlen], # max length of ragged dimension + ), + dim=1, + ) # sum along ragged dimension (dim == 1) + def get_x_val(self, example_inputs): return len(example_inputs[0]) @@ -90,50 +106,52 @@ def get_input_iter(self) -> Generator: """ B_vals, M_vals = self.get_x_vals() - - for B in B_vals: - for M in M_vals: - tensors = [] - - # greater sparsity --> shorter sequence lengths on ragged dimension - seqlen_avg = math.floor( - self.seqlen * (1 - self.sparsity) - ) # average sequence length across all tensors in nested tensor - seqlen_margin = math.floor( - self.seqlen * RANDOM_CHOICE_MARGIN - ) # use margin to constrain sequence lengths to range [seqlen_avg - seqlen_margin, seqlen_avg + seqlen_margin] to approximate an average sequence length, which correlates with sparsity - - for _ in range(B): - seqlen_randint = random.randint( - max(seqlen_avg - seqlen_margin, 1), - min(seqlen_avg + seqlen_margin, self.seqlen), - ) - tensor_2d = torch.randn( - (seqlen_randint, M), device=self.device, dtype=self.dtype - ) - tensors.append(tensor_2d) - - nt = torch.nested.nested_tensor( - tensors, - layout=torch.jagged, - device=self.device, - dtype=self.dtype, + B_M_vals = itertools.product(B_vals, M_vals) + + for B, M in B_M_vals: + tensors = [] + + # greater sparsity --> shorter sequence lengths on ragged dimension + seqlen_avg = math.floor( + self.seqlen * (1 - self.sparsity) + ) # average sequence length across all tensors in nested tensor + seqlen_margin = math.floor( + self.seqlen * RANDOM_CHOICE_MARGIN + ) # use margin to constrain sequence lengths to range [seqlen_avg - seqlen_margin, seqlen_avg + seqlen_margin] to approximate an average sequence length, which correlates with sparsity + + for _ in range(B): + seqlen_randint = random.randint( + max( + seqlen_avg - seqlen_margin, 1 + ), # seqlen_randint must be at least 1 + min( + seqlen_avg + seqlen_margin, self.seqlen + ), # seqlen_randint must not exceed self.seqlen ) + tensor_2d = torch.randn( + (seqlen_randint, M), device=self.device, dtype=self.dtype + ) + tensors.append(tensor_2d) - yield (nt,) + nt = torch.nested.nested_tensor( + tensors, + layout=torch.jagged, + device=self.device, + dtype=self.dtype, + ) - @register_metric() - def B_M(self, fn_name: str, example_inputs, metrics: BenchmarkOperatorMetrics): - return tuple([(ex.size(0), ex.size(2)) for ex in example_inputs])[ - 0 - ] # return (B, M) for each example input + yield (nt,) + + def _get_accuracy(self, fn: Callable, baseline_fn: Callable) -> bool: + output = fn() + baseline_output = baseline_fn() + return torch.allclose(output, baseline_output, atol=ABSOLUTE_TOLERANCE) @register_metric() def gbps(self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics): - gbps = ( - lambda ms: example_inputs[0].element_size() + return ( + example_inputs[0].element_size() * example_inputs[0].numel() - / ms + / metrics.latency * GIGABYTES_PER_BYTE ) - return list(map(gbps, metrics.latency if metrics.latency else [0]))