From 6fbeafece2f528c42860e373c27f5ad44ca86b87 Mon Sep 17 00:00:00 2001
From: Zain Huda
Date: Thu, 9 Nov 2023 19:32:44 -0800
Subject: [PATCH] move tensor concatenation to compute step from update step (#1498)

Summary:
Pull Request resolved: https://github.com/pytorch/torchrec/pull/1498

We don't need to concatenate the tensors on every update step. Concatenation is expensive (each call creates a new tensor and allocates new memory, since tensors are contiguous), so we can instead call `torch.cat` once in the compute step, which only runs every `compute_interval_step` batches. This optimization should boost the performance of models using AUC with no regression in metric quality.

We've also added a unit test that makes multiple `update` calls before `compute`, ensuring tensor concatenation is done correctly across the `update` and `compute` calls.

Differential Revision: D51176437

fbshipit-source-id: 891c8b1de5f11c4aed68ab2de73cb7a1df335204
---
 torchrec/metrics/auc.py            | 117 +++++++++++++++--------------
 torchrec/metrics/tests/test_auc.py |  17 +++++
 torchrec/metrics/tests/test_gpu.py |   6 +-
 3 files changed, 79 insertions(+), 61 deletions(-)

diff --git a/torchrec/metrics/auc.py b/torchrec/metrics/auc.py
index 688f04583..368545a51 100644
--- a/torchrec/metrics/auc.py
+++ b/torchrec/metrics/auc.py
@@ -50,64 +50,77 @@ def _compute_auc_helper(
 def compute_auc(
     n_tasks: int,
-    predictions: torch.Tensor,
-    labels: torch.Tensor,
-    weights: torch.Tensor,
-    apply_bin: bool = False,
+    predictions: List[torch.Tensor],
+    labels: List[torch.Tensor],
+    weights: List[torch.Tensor],
 ) -> torch.Tensor:
     """
     Computes AUC (Area Under the Curve) for binary classification.

     Args:
         n_tasks (int): number of tasks.
-        predictions (torch.Tensor): tensor of size (n_tasks, n_examples).
-        labels (torch.Tensor): tensor of size (n_tasks, n_examples).
-        weights (torch.Tensor): tensor of size (n_tasks, n_examples).
+        predictions (List[torch.Tensor]): List of tensors of size (n_tasks, n_examples).
+        labels (List[torch.Tensor]): List of tensors of size (n_tasks, n_examples).
+        weights (List[torch.Tensor]): List of tensors of size (n_tasks, n_examples).
     """
+    # concatenate tensors along dim = -1
+    predictions_cat = torch.cat(predictions, dim=-1)
+    labels_cat = torch.cat(labels, dim=-1)
+    weights_cat = torch.cat(weights, dim=-1)
+
     aucs = []
-    for predictions_i, labels_i, weights_i in zip(predictions, labels, weights):
-        auc = _compute_auc_helper(predictions_i, labels_i, weights_i, apply_bin)
+    for predictions_i, labels_i, weights_i in zip(
+        predictions_cat, labels_cat, weights_cat
+    ):
+        auc = _compute_auc_helper(predictions_i, labels_i, weights_i)
         aucs.append(auc.view(1))
     return torch.cat(aucs)


 def compute_auc_per_group(
     n_tasks: int,
-    predictions: torch.Tensor,
-    labels: torch.Tensor,
-    weights: torch.Tensor,
-    grouping_keys: torch.Tensor,
+    predictions: List[torch.Tensor],
+    labels: List[torch.Tensor],
+    weights: List[torch.Tensor],
+    grouping_keys: List[torch.Tensor],
 ) -> torch.Tensor:
     """
     Computes AUC (Area Under the Curve) for binary classification for groups of predictions/labels.
Args: n_tasks (int): number of tasks - predictions (torch.Tensor): tensor of size (n_tasks, n_examples) - labels (torch.Tensor): tensor of size (n_tasks, n_examples) - weights (torch.Tensor): tensor of size (n_tasks, n_examples) - grouping_keys (torch.Tensor): tensor of size (n_examples,) + predictions (List[torch.Tensor]): List of tensors of size (n_tasks, n_examples) + labels (List[torch.Tensor]): List of tensors of size (n_tasks, n_examples) + weights (List[torch.Tensor]): List of tensors of size (n_tasks, n_examples) + grouping_keys (List[torch.Tensor]): List of tensors of size (n_examples,) Returns: torch.Tensor: tensor of size (n_tasks,), average of AUCs per group. """ + predictions_cat = torch.cat(predictions, dim=-1) + labels_cat = torch.cat(labels, dim=-1) + weights_cat = torch.cat(weights, dim=-1) + grouping_keys_cat = torch.cat(grouping_keys, dim=-1) + aucs = [] - if grouping_keys.numel() != 0 and grouping_keys[0] == -1: + if grouping_keys_cat.numel() != 0 and grouping_keys_cat[0] == -1: # we added padding as the first elements during init to avoid floating point exception in sync() # removing the paddings to avoid numerical errors. - grouping_keys = grouping_keys[1:] - predictions = predictions[:, 1:] - labels = labels[:, 1:] - weights = weights[:, 1:] + grouping_keys_cat = grouping_keys_cat[1:] + predictions_cat = predictions_cat[:, 1:] + labels_cat = labels_cat[:, 1:] + weights_cat = weights_cat[:, 1:] # get unique group indices - group_indices = torch.unique(grouping_keys) + group_indices = torch.unique(grouping_keys_cat) - for (predictions_i, labels_i, weights_i) in zip(predictions, labels, weights): + for (predictions_i, labels_i, weights_i) in zip( + predictions_cat, labels_cat, weights_cat + ): # Loop over each group auc_groups_sum = torch.tensor([0], dtype=torch.float32) for group_idx in group_indices: # get predictions, labels, and weights for this group - group_mask = grouping_keys == group_idx + group_mask = grouping_keys_cat == group_idx grouped_predictions = predictions_i[group_mask] grouped_labels = labels_i[group_mask] grouped_weights = weights_i[group_mask] @@ -241,25 +254,12 @@ def update( predictions = predictions.float() labels = labels.float() weights = weights.float() - num_samples = getattr(self, PREDICTIONS)[0].size(-1) - batch_size = predictions.size(-1) - start_index = max(num_samples + batch_size - self._window_size, 0) + # Using `self.predictions =` will cause Pyre errors. 
-        getattr(self, PREDICTIONS)[0] = torch.cat(
-            [
-                cast(torch.Tensor, getattr(self, PREDICTIONS)[0])[:, start_index:],
-                predictions,
-            ],
-            dim=-1,
-        )
-        getattr(self, LABELS)[0] = torch.cat(
-            [cast(torch.Tensor, getattr(self, LABELS)[0])[:, start_index:], labels],
-            dim=-1,
-        )
-        getattr(self, WEIGHTS)[0] = torch.cat(
-            [cast(torch.Tensor, getattr(self, WEIGHTS)[0])[:, start_index:], weights],
-            dim=-1,
-        )
+        getattr(self, PREDICTIONS).append(predictions)
+        getattr(self, LABELS).append(labels)
+        getattr(self, WEIGHTS).append(weights)
+
         if self._grouped_auc:
             if REQUIRED_INPUTS not in kwargs or (
                 (grouping_keys := kwargs[REQUIRED_INPUTS].get(GROUPING_KEYS)) is None
@@ -267,13 +267,8 @@ def update(
                 raise RecMetricException(
                     f"Input '{GROUPING_KEYS}' are required for AUCMetricComputation grouped update"
                 )
-            getattr(self, GROUPING_KEYS)[0] = torch.cat(
-                [
-                    cast(torch.Tensor, getattr(self, GROUPING_KEYS)[0])[start_index:],
-                    grouping_keys.squeeze(),
-                ],
-                dim=0,
-            )
+
+            getattr(self, GROUPING_KEYS).append(grouping_keys.squeeze())

     def _compute(self) -> List[MetricComputationReport]:
         reports = [
@@ -282,10 +277,12 @@ def _compute(self) -> List[MetricComputationReport]:
                 metric_prefix=MetricPrefix.WINDOW,
                 value=compute_auc(
                     self._n_tasks,
-                    cast(torch.Tensor, getattr(self, PREDICTIONS)[0]),
-                    cast(torch.Tensor, getattr(self, LABELS)[0]),
-                    cast(torch.Tensor, getattr(self, WEIGHTS)[0]),
-                    self._apply_bin,
+                    # pyre-ignore[6]
+                    cast(torch.Tensor, getattr(self, PREDICTIONS)),
+                    # pyre-ignore[6]
+                    cast(torch.Tensor, getattr(self, LABELS)),
+                    # pyre-ignore[6]
+                    cast(torch.Tensor, getattr(self, WEIGHTS)),
                 ),
             )
         ]
@@ -296,10 +293,14 @@ def _compute(self) -> List[MetricComputationReport]:
                     metric_prefix=MetricPrefix.WINDOW,
                     value=compute_auc_per_group(
                         self._n_tasks,
-                        cast(torch.Tensor, getattr(self, PREDICTIONS)[0]),
-                        cast(torch.Tensor, getattr(self, LABELS)[0]),
-                        cast(torch.Tensor, getattr(self, WEIGHTS)[0]),
-                        cast(torch.Tensor, getattr(self, GROUPING_KEYS)[0]),
+                        # pyre-ignore[6]
+                        cast(torch.Tensor, getattr(self, PREDICTIONS)),
+                        # pyre-ignore[6]
+                        cast(torch.Tensor, getattr(self, LABELS)),
+                        # pyre-ignore[6]
+                        cast(torch.Tensor, getattr(self, WEIGHTS)),
+                        # pyre-ignore[6]
+                        cast(torch.Tensor, getattr(self, GROUPING_KEYS)),
                     ),
                 )
             )
diff --git a/torchrec/metrics/tests/test_auc.py b/torchrec/metrics/tests/test_auc.py
index e81922cd4..8de096dc3 100644
--- a/torchrec/metrics/tests/test_auc.py
+++ b/torchrec/metrics/tests/test_auc.py
@@ -177,6 +177,23 @@ def test_calc_auc_balanced(self) -> None:
         actual_auc = self.auc.compute()["auc-DefaultTask|window_auc"]
         torch.allclose(expected_auc, actual_auc)

+    def test_calc_multiple_updates(self) -> None:
+        expected_auc = torch.tensor([0.4464], dtype=torch.float)
+        # first batch
+        self.labels["DefaultTask"] = torch.tensor([1, 0, 0])
+        self.predictions["DefaultTask"] = torch.tensor([0.2, 0.6, 0.8])
+        self.weights["DefaultTask"] = torch.tensor([0.13, 0.2, 0.5])
+
+        self.auc.update(**self.batches)
+        # second batch
+        self.labels["DefaultTask"] = torch.tensor([1, 1])
+        self.predictions["DefaultTask"] = torch.tensor([0.4, 0.9])
+        self.weights["DefaultTask"] = torch.tensor([0.8, 0.75])
+
+        self.auc.update(**self.batches)
+        multiple_batch = self.auc.compute()["auc-DefaultTask|window_auc"]
+        self.assertTrue(torch.allclose(expected_auc, multiple_batch))
+

 def generate_model_outputs_cases() -> Iterable[Dict[str, torch._tensor.Tensor]]:
     return [
diff --git a/torchrec/metrics/tests/test_gpu.py b/torchrec/metrics/tests/test_gpu.py
index 9ca7f9ad4..184ea5c12 100644
--- a/torchrec/metrics/tests/test_gpu.py
+++ b/torchrec/metrics/tests/test_gpu.py
@@ -48,9 +48,9 @@ def test_auc_reset(self) -> None:
             labels={"DefaultTask": model_output["label"]},
             weights={"DefaultTask": model_output["weight"]},
         )
-        self.assertEqual(len(auc._metrics_computations[0].predictions), 1)
-        self.assertEqual(len(auc._metrics_computations[0].labels), 1)
-        self.assertEqual(len(auc._metrics_computations[0].weights), 1)
+        self.assertEqual(len(auc._metrics_computations[0].predictions), 2)
+        self.assertEqual(len(auc._metrics_computations[0].labels), 2)
+        self.assertEqual(len(auc._metrics_computations[0].weights), 2)
         self.assertEqual(auc._metrics_computations[0].predictions[0].device, device)
         self.assertEqual(auc._metrics_computations[0].labels[0].device, device)
         self.assertEqual(auc._metrics_computations[0].weights[0].device, device)
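
For readers outside the torchrec codebase, the pattern this patch adopts can be shown standalone: accumulate per-batch tensors in plain Python lists during `update` (an O(1) append), and pay for a single `torch.cat` in `compute`. The sketch below is illustrative only, not TorchRec's API: `DeferredCatState` is a hypothetical name, the state handling simplifies `AUCMetricComputation` (single task, no sliding window, no distributed sync), and the AUC math is a standard weighted trapezoidal estimate rather than a copy of `_compute_auc_helper`.

from typing import List

import torch


class DeferredCatState:
    """Hypothetical, simplified stand-in for AUCMetricComputation's state."""

    def __init__(self) -> None:
        # O(1) appends: no reallocation or copy of previously seen batches.
        self.predictions: List[torch.Tensor] = []
        self.labels: List[torch.Tensor] = []
        self.weights: List[torch.Tensor] = []

    def update(
        self, predictions: torch.Tensor, labels: torch.Tensor, weights: torch.Tensor
    ) -> None:
        # Before the patch, every update ran torch.cat against the whole
        # history, reallocating and copying all prior examples each batch.
        self.predictions.append(predictions.float())
        self.labels.append(labels.float())
        self.weights.append(weights.float())

    def compute(self) -> torch.Tensor:
        # One concatenation per compute call, i.e. once every
        # compute_interval_step batches instead of once per update.
        predictions = torch.cat(self.predictions, dim=-1)
        labels = torch.cat(self.labels, dim=-1)
        weights = torch.cat(self.weights, dim=-1)
        # Weighted AUC via the trapezoidal rule over the ROC curve:
        # sort by prediction descending, accumulate weighted TP/FP counts.
        order = torch.argsort(predictions, descending=True)
        sorted_labels = labels[order]
        sorted_weights = weights[order]
        cum_tp = torch.cumsum(sorted_weights * sorted_labels, dim=0)
        cum_fp = torch.cumsum(sorted_weights * (1.0 - sorted_labels), dim=0)
        norm = cum_tp[-1] * cum_fp[-1]
        if norm == 0:  # only one class present; AUC is undefined
            return torch.tensor(0.5)
        return torch.trapz(cum_tp, cum_fp) / norm


# Feeding in the same two batches as the new test_calc_multiple_updates case
# prints roughly 0.4464, the expected value asserted there.
state = DeferredCatState()
state.update(
    torch.tensor([0.2, 0.6, 0.8]),
    torch.tensor([1.0, 0.0, 0.0]),
    torch.tensor([0.13, 0.2, 0.5]),
)
state.update(
    torch.tensor([0.4, 0.9]),
    torch.tensor([1.0, 1.0]),
    torch.tensor([0.8, 0.75]),
)
print(state.compute())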