diff --git a/torchbenchmark/util/triton_op.py b/torchbenchmark/util/triton_op.py index 07da2198ad..6d763b6ca3 100644 --- a/torchbenchmark/util/triton_op.py +++ b/torchbenchmark/util/triton_op.py @@ -210,7 +210,10 @@ def select_metric(m): # Append x_val_only metrics for x_only_metric in x_only_metrics: x_only_metric_dict = asdict(y_val[y_val_keys[0]]) - row.append(x_only_metric_dict[x_only_metric]) + if "extra_metrics" in x_only_metric_dict and x_only_metric in x_only_metric_dict["extra_metrics"]: + row.append(x_only_metric_dict["extra_metrics"][x_only_metric]) + else: + row.append(x_only_metric_dict[x_only_metric]) for k, _label in y_val_keys: metrics_dict = asdict(y_val[k]) if metrics_dict["error_msg"]: