Skip to content

Commit

Permalink
Add jagged_sum operator for padded nested tensors to TritonBench (#2305)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2305

Add a `jagged_sum` reduction operator for padded nested tensors, based on the PyTorch `sum` operator, to TritonBench. This diff uses the PyTorch function [`torch.ops.aten._jagged_to_padded_dense_forward`](https://www.internalfb.com/code/fbsource/[92c2a067ab04e3eebc999254fed4ae2fbea6def3]/fbcode/deeplearning/fbgemm/fbgemm_gpu/fb/inductor_lowerings/elementwise_ops.py?lines=26), hosted at this [GitHub pull request](pytorch/pytorch#125968), to pad each 2-dimensional tensor in a nested tensor of shape `(B, *, M)`, then reduce across the `N`-th dimension (`dim == 1`) to a `(B, M)` output tensor.

Measure accuracy of padded implementation against unpadded baseline implementation via `accuracy` TritonBench metric.

Reviewed By: davidberard98

Differential Revision: D58423489
  • Loading branch information
jananisriram authored and facebook-github-bot committed Jun 17, 2024
1 parent 48223b8 commit a8fcd5a
Showing 1 changed file with 59 additions and 41 deletions.
100 changes: 59 additions & 41 deletions torchbenchmark/operators/jagged_sum/operator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import argparse
import itertools
import math
import random
from typing import Callable, Generator, List, Optional, Tuple
Expand All @@ -14,19 +15,21 @@
register_metric,
)

random.seed(16)
torch.manual_seed(16)
seed = 16
random.seed(seed)
torch.manual_seed(seed)

GIGABYTES_PER_BYTE = 1e-6
RANDOM_CHOICE_MARGIN = 0.3
ABSOLUTE_TOLERANCE = 1e-3


def parse_op_args(args: List[str]):
parser = argparse.ArgumentParser()
parser.add_argument(
"--seqlen",
type=int,
default=100,
default=500,
help="Maximum sequence length on ragged dimension (integer)",
)
parser.add_argument(
Expand All @@ -40,6 +43,8 @@ def parse_op_args(args: List[str]):

class Operator(BenchmarkOperator):

DEFAULT_METRICS = ["latency", "accuracy"]

def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = None):
super().__init__(mode=mode, device=device, extra_args=extra_args)
self.sizes = range(4, 10, 2)
Expand All @@ -58,6 +63,17 @@ def torch_jagged_sum_no_pad(self, x: torch.Tensor):
dtype=self.dtype,
)

@register_benchmark()
def torch_jagged_sum_pad(self, x: torch.Tensor):
return lambda: torch.sum(
torch.ops.aten._jagged_to_padded_dense_forward(
x.values(),
[x.offsets()], # pyre-ignore: Undefined attribute [16]: `torch._tensor.Tensor` has no attribute `offsets`.
max_lengths=[self.seqlen], # max length of ragged dimension
),
dim=1,
) # sum along ragged dimension (dim == 1)

def get_x_val(self, example_inputs):
return len(example_inputs[0])

Expand Down Expand Up @@ -90,50 +106,52 @@ def get_input_iter(self) -> Generator:
"""

B_vals, M_vals = self.get_x_vals()

for B in B_vals:
for M in M_vals:
tensors = []

# greater sparsity --> shorter sequence lengths on ragged dimension
seqlen_avg = math.floor(
self.seqlen * (1 - self.sparsity)
) # average sequence length across all tensors in nested tensor
seqlen_margin = math.floor(
self.seqlen * RANDOM_CHOICE_MARGIN
) # use margin to constrain sequence lengths to range [seqlen_avg - seqlen_margin, seqlen_avg + seqlen_margin] to approximate an average sequence length, which correlates with sparsity

for _ in range(B):
seqlen_randint = random.randint(
max(seqlen_avg - seqlen_margin, 1),
min(seqlen_avg + seqlen_margin, self.seqlen),
)
tensor_2d = torch.randn(
(seqlen_randint, M), device=self.device, dtype=self.dtype
)
tensors.append(tensor_2d)

nt = torch.nested.nested_tensor(
tensors,
layout=torch.jagged,
device=self.device,
dtype=self.dtype,
B_M_vals = itertools.product(B_vals, M_vals)

for B, M in B_M_vals:
tensors = []

# greater sparsity --> shorter sequence lengths on ragged dimension
seqlen_avg = math.floor(
self.seqlen * (1 - self.sparsity)
) # average sequence length across all tensors in nested tensor
seqlen_margin = math.floor(
self.seqlen * RANDOM_CHOICE_MARGIN
) # use margin to constrain sequence lengths to range [seqlen_avg - seqlen_margin, seqlen_avg + seqlen_margin] to approximate an average sequence length, which correlates with sparsity

for _ in range(B):
seqlen_randint = random.randint(
max(
seqlen_avg - seqlen_margin, 1
), # seqlen_randint must be at least 1
min(
seqlen_avg + seqlen_margin, self.seqlen
), # seqlen_randint must not exceed self.seqlen
)
tensor_2d = torch.randn(
(seqlen_randint, M), device=self.device, dtype=self.dtype
)
tensors.append(tensor_2d)

yield (nt,)
nt = torch.nested.nested_tensor(
tensors,
layout=torch.jagged,
device=self.device,
dtype=self.dtype,
)

@register_metric()
def B_M(self, fn_name: str, example_inputs, metrics: BenchmarkOperatorMetrics):
return tuple([(ex.size(0), ex.size(2)) for ex in example_inputs])[
0
] # return (B, M) for each example input
yield (nt,)

def _get_accuracy(self, fn: Callable, baseline_fn: Callable) -> bool:
output = fn()
baseline_output = baseline_fn()
return torch.allclose(output, baseline_output, atol=ABSOLUTE_TOLERANCE)

@register_metric()
def gbps(self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics):
gbps = (
lambda ms: example_inputs[0].element_size()
return (
example_inputs[0].element_size()
* example_inputs[0].numel()
/ ms
/ metrics.latency
* GIGABYTES_PER_BYTE
)
return list(map(gbps, metrics.latency if metrics.latency else [0]))

0 comments on commit a8fcd5a

Please sign in to comment.