adds count num sequences and tokens metric #346

Merged
merged 2 commits on Mar 21, 2024
Changes from 1 commit
88 changes: 87 additions & 1 deletion fuse/eval/metrics/sequence_gen/metrics_seq_gen_common.py
@@ -18,13 +18,99 @@
"""
from typing import Optional, Tuple, List
from functools import partial

from copy import copy
import torch
import numpy as np

from fuse.eval.metrics.metrics_common import MetricPerBatchDefault


class MetricCountSeqAndTokens(MetricPerBatchDefault):
Collaborator:
General question: I'm not sure counting sequences and tokens should be defined as a metric. I don't have another suggestion; it just sounds a bit weird :)

What do you think?

Collaborator Author:
It uses the metric mechanism, and it's OK with me that it just counts some stats.

"""
Counts the total number sequences and tokens in encoder_input
"""

    def __init__(
        self,
        encoder_input: str,
        ignore_index: Optional[int] = None,
        state: Optional[dict] = None,
        **kwargs: dict,
    ) -> None:
        """
        :param encoder_input: key to the encoder_input
        :param ignore_index: token_id to ignore (not to count), typically pad token id
Collaborator:
I think it should be able to support a list of token ids to ignore, unless you want to force the user to ignore only the PAD token.

Collaborator Author:
I went with a single id to be more efficient; typically we just want to skip the padding token.

        :param state: the sequence count and token count to continue from. Should be restored when training resumes:
            use get_state() to get the state and save it upon checkpointing.
        :param kwargs: additional super class arguments
        """
        super().__init__(
            seq_num="seq_num",  # collect log_probs - output of _count_seq_and_tokens_update
Collaborator:
Obsolete comments on this line and the following one.

Collaborator Author:
👍

token_num="token_num", # collect token_num - output of _count_seq_and_tokens_update
metric_per_batch_func=None,
metric_per_batch_func_pre_collect=partial(
_count_seq_and_tokens_update,
ignore_index=ignore_index,
encoder_input_key=encoder_input,
),
result_aggregate_func=self._count_seq_and_tokens_compute,
**kwargs,
)
if state is None:
self._state = {"seq_num": 0, "token_num": 0}
else:
assert "seq_num" in state
assert "token_num" in state
self._state = state

    def _count_seq_and_tokens_compute(
        self,
        seq_num: List[np.ndarray],
Collaborator:
seq_num will be a list of numpy arrays where each entry represents a batch? If so, how often is the metric computed? Each epoch?

I forgot these :)

Collaborator Author:
Each sub-epoch, and each entry is a batch.

        token_num: List[np.ndarray],
    ) -> float:
Collaborator:
This returns a dict.

Collaborator Author:
👍


        seq_num_total = sum(seq_num)
        token_num_total = sum(token_num)
        self._state["seq_num"] += seq_num_total
        self._state["token_num"] += token_num_total
        return copy(self._state)

    def get_state(self) -> dict:
        return copy(self._state)


def _count_seq_and_tokens_update(
    batch_dict: dict,
    encoder_input_key: str,
Collaborator:
encoder_input_key: Union[str, None]

or

encoder_input_key: Optional[str]
Collaborator Author:
It's a must. Why optional?

    ignore_index: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
Collaborator:
-> dict[str, Tensor]

Collaborator Author:
👍

"""Count number of sequences and tokens
Args:
encoder_input_key:
key to encoder_input
ignore_index:
Token not to count, typically padding
Returns:
dictionary with number of sequences and tokens
"""
encoder_input = batch_dict[encoder_input_key]

# to save GPU memory
encoder_input = encoder_input.detach()

if ignore_index is not None:
mask = encoder_input.ne(ignore_index)
else:
mask = torch.ones_like(encoder_input, dtype=torch.bool)

seq_num = torch.tensor(
mask.shape[0], dtype=torch.int64, device=encoder_input.device
)
token_num = mask.sum().to(dtype=torch.int64)

return {"seq_num": seq_num.unsqueeze(0), "token_num": token_num.unsqueeze(0)}


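For reference, a minimal sketch of what the new metric counts, driving the module-level helper directly on a toy batch. The key name "encoder_input" and the padding id are assumptions for illustration; in practice the metric is fed batches by fuse's evaluation mechanism rather than called by hand.

# Toy illustration (not part of the diff) of the counting semantics.
import torch

PAD = 0
batch_dict = {
    "encoder_input": torch.tensor(
        [
            [5, 7, 9, PAD, PAD],    # 3 non-pad tokens
            [3, 2, PAD, PAD, PAD],  # 2 non-pad tokens
        ]
    )
}

out = _count_seq_and_tokens_update(
    batch_dict, encoder_input_key="encoder_input", ignore_index=PAD
)
print(out["seq_num"])    # tensor([2]) - two sequences in the batch
print(out["token_num"])  # tensor([5]) - five non-pad tokens

# The metric itself would be constructed roughly as:
#   metric = MetricCountSeqAndTokens(encoder_input="encoder_input", ignore_index=PAD)
# Its running totals can be checkpointed via metric.get_state() and passed back
# through the `state` argument when training resumes.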