Skip to content

Commit

Permalink
Back out "move tensor concatenation to compute step from update step" (
Browse files Browse the repository at this point in the history
…#1506)

Summary:
Pull Request resolved: #1506

OOM issues on MAST, changing AUC to optionally apply tensor cats in update or compute.

Original commit changeset: 891c8b1de5f1

Original Phabricator Diff: D51176437

Differential Revision: D51270369

fbshipit-source-id: 5f51cc5e037710956c2c6cbe5a1249b137b8f1f5
  • Loading branch information
iamzainhuda authored and facebook-github-bot committed Nov 13, 2023
1 parent e2cc13a commit c7770f9
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 79 deletions.
117 changes: 58 additions & 59 deletions torchrec/metrics/auc.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,77 +50,64 @@ def _compute_auc_helper(

def compute_auc(
n_tasks: int,
predictions: List[torch.Tensor],
labels: List[torch.Tensor],
weights: List[torch.Tensor],
predictions: torch.Tensor,
labels: torch.Tensor,
weights: torch.Tensor,
apply_bin: bool = False,
) -> torch.Tensor:
"""
Computes AUC (Area Under the Curve) for binary classification.
Args:
n_tasks (int): number of tasks.
predictions (List[torch.Tensor]): List of tensors of size (n_tasks, n_examples).
labels (List[torch.Tensor]): List of tensors of size (n_tasks, n_examples).
weights (List[torch.Tensor]): List of tensors of size (n_tasks, n_examples).
predictions (torch.Tensor): tensor of size (n_tasks, n_examples).
labels (torch.Tensor): tensor of size (n_tasks, n_examples).
weights (torch.Tensor): tensor of size (n_tasks, n_examples).
"""
# concatenate tensors along dim = -1
predictions_cat = torch.cat(predictions, dim=-1)
labels_cat = torch.cat(labels, dim=-1)
weights_cat = torch.cat(weights, dim=-1)

aucs = []
for predictions_i, labels_i, weights_i in zip(
predictions_cat, labels_cat, weights_cat
):
auc = _compute_auc_helper(predictions_i, labels_i, weights_i)
for predictions_i, labels_i, weights_i in zip(predictions, labels, weights):
auc = _compute_auc_helper(predictions_i, labels_i, weights_i, apply_bin)
aucs.append(auc.view(1))
return torch.cat(aucs)


def compute_auc_per_group(
n_tasks: int,
predictions: List[torch.Tensor],
labels: List[torch.Tensor],
weights: List[torch.Tensor],
grouping_keys: List[torch.Tensor],
predictions: torch.Tensor,
labels: torch.Tensor,
weights: torch.Tensor,
grouping_keys: torch.Tensor,
) -> torch.Tensor:
"""
Computes AUC (Area Under the Curve) for binary classification for groups of predictions/labels.
Args:
n_tasks (int): number of tasks
predictions (List[torch.Tensor]): List of tensors of size (n_tasks, n_examples)
labels (List[torch.Tensor]): List of tensors of size (n_tasks, n_examples)
weights (List[torch.Tensor]): List of tensors of size (n_tasks, n_examples)
grouping_keys (List[torch.Tensor]): List of tensors of size (n_examples,)
predictions (torch.Tensor): tensor of size (n_tasks, n_examples)
labels (torch.Tensor): tensor of size (n_tasks, n_examples)
weights (torch.Tensor): tensor of size (n_tasks, n_examples)
grouping_keys (torch.Tensor): tensor of size (n_examples,)
Returns:
torch.Tensor: tensor of size (n_tasks,), average of AUCs per group.
"""
predictions_cat = torch.cat(predictions, dim=-1)
labels_cat = torch.cat(labels, dim=-1)
weights_cat = torch.cat(weights, dim=-1)
grouping_keys_cat = torch.cat(grouping_keys, dim=-1)

aucs = []
if grouping_keys_cat.numel() != 0 and grouping_keys_cat[0] == -1:
if grouping_keys.numel() != 0 and grouping_keys[0] == -1:
# we added padding as the first elements during init to avoid floating point exception in sync()
# removing the paddings to avoid numerical errors.
grouping_keys_cat = grouping_keys_cat[1:]
predictions_cat = predictions_cat[:, 1:]
labels_cat = labels_cat[:, 1:]
weights_cat = weights_cat[:, 1:]
grouping_keys = grouping_keys[1:]
predictions = predictions[:, 1:]
labels = labels[:, 1:]
weights = weights[:, 1:]

# get unique group indices
group_indices = torch.unique(grouping_keys_cat)
group_indices = torch.unique(grouping_keys)

for (predictions_i, labels_i, weights_i) in zip(
predictions_cat, labels_cat, weights_cat
):
for (predictions_i, labels_i, weights_i) in zip(predictions, labels, weights):
# Loop over each group
auc_groups_sum = torch.tensor([0], dtype=torch.float32)
for group_idx in group_indices:
# get predictions, labels, and weights for this group
group_mask = grouping_keys_cat == group_idx
group_mask = grouping_keys == group_idx
grouped_predictions = predictions_i[group_mask]
grouped_labels = labels_i[group_mask]
grouped_weights = weights_i[group_mask]
Expand Down Expand Up @@ -254,21 +241,39 @@ def update(
predictions = predictions.float()
labels = labels.float()
weights = weights.float()

num_samples = getattr(self, PREDICTIONS)[0].size(-1)
batch_size = predictions.size(-1)
start_index = max(num_samples + batch_size - self._window_size, 0)
# Using `self.predictions =` will cause Pyre errors.
getattr(self, PREDICTIONS).append(predictions)
getattr(self, LABELS).append(labels)
getattr(self, WEIGHTS).append(weights)

getattr(self, PREDICTIONS)[0] = torch.cat(
[
cast(torch.Tensor, getattr(self, PREDICTIONS)[0])[:, start_index:],
predictions,
],
dim=-1,
)
getattr(self, LABELS)[0] = torch.cat(
[cast(torch.Tensor, getattr(self, LABELS)[0])[:, start_index:], labels],
dim=-1,
)
getattr(self, WEIGHTS)[0] = torch.cat(
[cast(torch.Tensor, getattr(self, WEIGHTS)[0])[:, start_index:], weights],
dim=-1,
)
if self._grouped_auc:
if REQUIRED_INPUTS not in kwargs or (
(grouping_keys := kwargs[REQUIRED_INPUTS].get(GROUPING_KEYS)) is None
):
raise RecMetricException(
f"Input '{GROUPING_KEYS}' are required for AUCMetricComputation grouped update"
)

getattr(self, GROUPING_KEYS).append(grouping_keys.squeeze())
getattr(self, GROUPING_KEYS)[0] = torch.cat(
[
cast(torch.Tensor, getattr(self, GROUPING_KEYS)[0])[start_index:],
grouping_keys.squeeze(),
],
dim=0,
)

def _compute(self) -> List[MetricComputationReport]:
reports = [
Expand All @@ -277,12 +282,10 @@ def _compute(self) -> List[MetricComputationReport]:
metric_prefix=MetricPrefix.WINDOW,
value=compute_auc(
self._n_tasks,
# pyre-ignore[6]
cast(torch.Tensor, getattr(self, PREDICTIONS)),
# pyre-ignore[6]
cast(torch.Tensor, getattr(self, LABELS)),
# pyre-ignore[6]
cast(torch.Tensor, getattr(self, WEIGHTS)),
cast(torch.Tensor, getattr(self, PREDICTIONS)[0]),
cast(torch.Tensor, getattr(self, LABELS)[0]),
cast(torch.Tensor, getattr(self, WEIGHTS)[0]),
self._apply_bin,
),
)
]
Expand All @@ -293,14 +296,10 @@ def _compute(self) -> List[MetricComputationReport]:
metric_prefix=MetricPrefix.WINDOW,
value=compute_auc_per_group(
self._n_tasks,
# pyre-ignore[6]
cast(torch.Tensor, getattr(self, PREDICTIONS)),
# pyre-ignore[6]
cast(torch.Tensor, getattr(self, LABELS)),
# pyre-ignore[6]
cast(torch.Tensor, getattr(self, WEIGHTS)),
# pyre-ignore[6]
cast(torch.Tensor, getattr(self, GROUPING_KEYS)),
cast(torch.Tensor, getattr(self, PREDICTIONS)[0]),
cast(torch.Tensor, getattr(self, LABELS)[0]),
cast(torch.Tensor, getattr(self, WEIGHTS)[0]),
cast(torch.Tensor, getattr(self, GROUPING_KEYS)[0]),
),
)
)
Expand Down
17 changes: 0 additions & 17 deletions torchrec/metrics/tests/test_auc.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,23 +177,6 @@ def test_calc_auc_balanced(self) -> None:
actual_auc = self.auc.compute()["auc-DefaultTask|window_auc"]
torch.allclose(expected_auc, actual_auc)

def test_calc_multiple_updates(self) -> None:
expected_auc = torch.tensor([0.4464], dtype=torch.float)
# first batch
self.labels["DefaultTask"] = torch.tensor([1, 0, 0])
self.predictions["DefaultTask"] = torch.tensor([0.2, 0.6, 0.8])
self.weights["DefaultTask"] = torch.tensor([0.13, 0.2, 0.5])

self.auc.update(**self.batches)
# second batch
self.labels["DefaultTask"] = torch.tensor([1, 1])
self.predictions["DefaultTask"] = torch.tensor([0.4, 0.9])
self.weights["DefaultTask"] = torch.tensor([0.8, 0.75])

self.auc.update(**self.batches)
multiple_batch = self.auc.compute()["auc-DefaultTask|window_auc"]
torch.allclose(expected_auc, multiple_batch)


def generate_model_outputs_cases() -> Iterable[Dict[str, torch._tensor.Tensor]]:
return [
Expand Down
6 changes: 3 additions & 3 deletions torchrec/metrics/tests/test_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,9 @@ def test_auc_reset(self) -> None:
labels={"DefaultTask": model_output["label"]},
weights={"DefaultTask": model_output["weight"]},
)
self.assertEqual(len(auc._metrics_computations[0].predictions), 2)
self.assertEqual(len(auc._metrics_computations[0].labels), 2)
self.assertEqual(len(auc._metrics_computations[0].weights), 2)
self.assertEqual(len(auc._metrics_computations[0].predictions), 1)
self.assertEqual(len(auc._metrics_computations[0].labels), 1)
self.assertEqual(len(auc._metrics_computations[0].weights), 1)
self.assertEqual(auc._metrics_computations[0].predictions[0].device, device)
self.assertEqual(auc._metrics_computations[0].labels[0].device, device)
self.assertEqual(auc._metrics_computations[0].weights[0].device, device)
Expand Down

0 comments on commit c7770f9

Please sign in to comment.