diff --git a/torchbenchmark/util/triton_op.py b/torchbenchmark/util/triton_op.py index 2e3c5e350..67a2dbd0e 100644 --- a/torchbenchmark/util/triton_op.py +++ b/torchbenchmark/util/triton_op.py @@ -43,20 +43,6 @@ REGISTERED_METRICS: Dict[str, List[str]] = {} REGISTERED_X_VALS: Dict[str, str] = {} BASELINE_BENCHMARKS: Dict[str, str] = {} -BUILTIN_METRICS = [ - "latency", - "tflops", - "speedup", - "accuracy", - "compile_time", - "ncu_trace", - "ncu_rep", - "kineto_trace", - "cpu_peak_mem", - "gpu_peak_mem", - "hw_roofline", - "best_config", -] BASELINE_SKIP_METRICS = set(["speedup", "accuracy"]) X_ONLY_METRICS = set(["hw_roofline"]) PRECISION_DTYPE_MAPPING = { @@ -202,6 +188,7 @@ class BenchmarkOperatorMetrics: # extra metrics extra_metrics: Optional[Dict[str, float]] = None +BUILTIN_METRICS = set(map(lambda x: x.name, fields(BenchmarkOperatorMetrics))) - {"extra_metrics"} @dataclass class BenchmarkOperatorResult: