From bbc16c8b92e1b0dbceb766d08cd3871775f3f059 Mon Sep 17 00:00:00 2001 From: Janani Sriram Date: Tue, 18 Jun 2024 09:18:21 -0700 Subject: [PATCH] Add input_shape metric to jagged_sum operator Summary: Add new metric to `jagged_sum` that denotes the 0th and 2nd input dimensions, `B` and `M`, in the form `(B, '*', M)`, where the nested tensor has logical dimensions `(B, *, M)`. Display this metric once per `x` value using the `x_only = True` argument to `register_metric()`. This diff will make TritonBench's benchmark table more readable by denoting the nested tensor dimensions used per benchmark row. Reviewed By: jbschlosser Differential Revision: D58535619 --- torchbenchmark/operators/jagged_sum/operator.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/torchbenchmark/operators/jagged_sum/operator.py b/torchbenchmark/operators/jagged_sum/operator.py index b65e4c3aa..fb650e028 100644 --- a/torchbenchmark/operators/jagged_sum/operator.py +++ b/torchbenchmark/operators/jagged_sum/operator.py @@ -156,3 +156,13 @@ def gbps(self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics): / metrics.latency * GIGABYTES_PER_BYTE ) + + @register_metric(x_only=True) + def input_shape( + self, fn_name: str, example_inputs, metrics: BenchmarkOperatorMetrics + ): + return ( + example_inputs[0].shape[0], + "*", + example_inputs[0].shape[2], + ) # return (B, '*', M) for each example input