Skip to content

Commit

Permalink
Improve `_perplexity_update` to consume less GPU memory.
Browse files Browse the repository at this point in the history
Note that the previous expression `preds[:, target]` created a matrix of size [n, n].
Another simple alternative would be `preds[torch.arange(preds.shape[0]), target]`, which gathers one probability per row without materializing the [n, n] intermediate.
  • Loading branch information
michalozeryflato committed May 23, 2024
1 parent df0cf14 commit 0553d82
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions fuse/eval/metrics/sequence_gen/metrics_seq_gen_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ def __init__(

# Copied internal function https://github.com/Lightning-AI/metrics/blob/825d17f32ee0b9a2a8024c89d4a09863d7eb45c3/src/torchmetrics/functional/text/perplexity.py#L68
# copied and not imported to not be affected by internal interface modifications.
# modifications: (1) reshape => view (2) apply mask at the beginning of computation (3) use torch.gather
def _perplexity_update(
batch_dict: dict,
preds_key: str,
Expand Down Expand Up @@ -193,18 +194,18 @@ def _perplexity_update(
preds = preds.detach()
target = target.detach()

preds = preds.reshape(-1, preds.shape[-1])
target = target.reshape(-1)
preds = preds.view(-1, preds.shape[-1])
target = target.view(-1)

if ignore_index is not None:
mask = target.ne(ignore_index)
target = target.where(
target != ignore_index, torch.tensor(0, device=target.device)
)
target = target[mask]
preds = preds[mask]
count = mask.sum()
else:
mask = torch.ones_like(target, dtype=torch.bool)
count = target.shape[0]

preds = preds[:, target].diagonal()[mask]
preds = torch.gather(preds, 1, target.view(-1, 1)).squeeze(1)
# avoid from overflow
if preds.dtype == torch.float16:
preds = preds.to(torch.float32)
Expand Down

0 comments on commit 0553d82

Please sign in to comment.