From 0553d82626282c5409c73f3cc8a8ba355e68d3b6 Mon Sep 17 00:00:00 2001
From: Michal Ozery-Flato
Date: Thu, 23 May 2024 09:07:15 +0300
Subject: [PATCH] improve _perplexity_update: to consume less GPU memory. Note
 that the previous command preds[:, target] created a matrix of size [n,n].
 Another simple alternative is preds[torch.arange(preds.shape[0]), target]

---
 .../sequence_gen/metrics_seq_gen_common.py | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/fuse/eval/metrics/sequence_gen/metrics_seq_gen_common.py b/fuse/eval/metrics/sequence_gen/metrics_seq_gen_common.py
index 46bc0c37..752aa784 100644
--- a/fuse/eval/metrics/sequence_gen/metrics_seq_gen_common.py
+++ b/fuse/eval/metrics/sequence_gen/metrics_seq_gen_common.py
@@ -153,6 +153,7 @@ def __init__(
 
 # Copied internal function https://github.com/Lightning-AI/metrics/blob/825d17f32ee0b9a2a8024c89d4a09863d7eb45c3/src/torchmetrics/functional/text/perplexity.py#L68
 # copied and not imported to not be affected by internal interface modifications.
+# modifications: (1) reshape => view (2) apply mask at the beginning of computation (3) use torch.gather
 def _perplexity_update(
     batch_dict: dict,
     preds_key: str,
@@ -193,18 +194,18 @@ def _perplexity_update(
     preds = preds.detach()
     target = target.detach()
 
-    preds = preds.reshape(-1, preds.shape[-1])
-    target = target.reshape(-1)
+    preds = preds.view(-1, preds.shape[-1])
+    target = target.view(-1)
 
     if ignore_index is not None:
         mask = target.ne(ignore_index)
-        target = target.where(
-            target != ignore_index, torch.tensor(0, device=target.device)
-        )
+        target = target[mask]
+        preds = preds[mask]
+        count = mask.sum()
     else:
-        mask = torch.ones_like(target, dtype=torch.bool)
+        count = target.shape[0]
 
-    preds = preds[:, target].diagonal()[mask]
+    preds = torch.gather(preds, 1, target.view(-1, 1)).squeeze(1)
     # avoid from overflow
     if preds.dtype == torch.float16:
         preds = preds.to(torch.float32)
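
For context, a minimal standalone sketch (not part of the patch, with made-up tensor sizes) of the memory point made in the commit subject: preds[:, target] materializes an [n, n] matrix whose diagonal is then taken, whereas torch.gather (used in the patch) and the advanced-indexing alternative mentioned in the subject read the same n entries directly, so peak memory stays O(n * vocab) instead of growing quadratically with the number of tokens.

import torch

n, vocab = 4, 10                         # n flattened tokens, vocab-sized probability rows (illustrative sizes)
preds = torch.rand(n, vocab)             # per-token probabilities, shape [n, vocab]
target = torch.randint(0, vocab, (n,))   # target token ids, shape [n]

# Old approach: preds[:, target] builds an [n, n] intermediate before taking its diagonal.
old = preds[:, target].diagonal()

# Patch approach: gather picks preds[i, target[i]] directly; the result stays shape [n].
new = torch.gather(preds, 1, target.view(-1, 1)).squeeze(1)

# Alternative mentioned in the commit subject: advanced indexing, also O(n).
alt = preds[torch.arange(n), target]

assert torch.equal(old, new) and torch.equal(old, alt)

All three select the same per-token probabilities; only the size of the intermediate tensor differs, which is what the second hunk above changes.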