From f450c5985dfa1447e8935e4633de51cc2e5e443c Mon Sep 17 00:00:00 2001 From: Chenyu Zhao Date: Fri, 6 Dec 2024 01:59:59 -0800 Subject: [PATCH] GAUC fix (#2619) Summary: Pull Request resolved: https://github.com/pytorch/torchrec/pull/2619 fix gauc multiple returns Reviewed By: yunjiangster, arsatis Differential Revision: D66853683 fbshipit-source-id: cf5f47a22086b2d225dc38b6dd3d90b25e53cc93 --- torchrec/metrics/gauc.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchrec/metrics/gauc.py b/torchrec/metrics/gauc.py index c509475b9..829f875c0 100644 --- a/torchrec/metrics/gauc.py +++ b/torchrec/metrics/gauc.py @@ -209,8 +209,8 @@ def _compute(self) -> List[MetricComputationReport]: name=MetricName.GAUC_NUM_SAMPLES, metric_prefix=MetricPrefix.LIFETIME, value=compute_window_auc( - self.get_window_state("auc_sum"), - self.get_window_state("num_samples"), + cast(torch.Tensor, self.auc_sum), + cast(torch.Tensor, self.num_samples), )["num_samples"], ), MetricComputationReport(