adds count num sequences and tokens metric #346
@@ -18,13 +18,99 @@
"""
from typing import Optional, Tuple, List
from functools import partial

from copy import copy
import torch
import numpy as np

from fuse.eval.metrics.metrics_common import MetricPerBatchDefault


class MetricCountSeqAndTokens(MetricPerBatchDefault):
    """
    Counts the total number of sequences and tokens in encoder_input
    """

    def __init__(
        self,
        encoder_input: str,
        ignore_index: Optional[int] = None,
        state: Optional[dict] = None,
        **kwargs: dict,
    ) -> None:
        """
        :param encoder_input: key to the encoder_input
        :param ignore_index: token_id to ignore (not to count), typically pad token id
Review comment: I think it should be able to support a list of token ids to ignore. Unless you want to enforce the user to ignore only the PAD one.
Reply: I went with just one to be more efficient - typically we would like to just skip the padding.
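As an aside, a minimal sketch (not part of the PR) of what the reviewer's suggestion could look like with a tensor of ignored token ids via torch.isin; the names are illustrative, and the single-id ne() path in the diff stays cheaper, which matches the efficiency argument in the reply:

    import torch

    # Hypothetical: mask out several token ids (e.g. pad, bos, eos) at once.
    encoder_input = torch.tensor([[5, 7, 0, 0], [3, 0, 0, 0]])  # toy batch, 0 = pad
    ignore_ids = torch.tensor([0, 3])  # hypothetical list of ids to skip

    mask = ~torch.isin(encoder_input, ignore_ids.to(encoder_input.device))
    token_num = mask.sum()  # counts only tokens not in ignore_ids
    print(token_num)  # tensor(2) -> the two non-ignored tokens (5 and 7)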
        :param state: the sequence count and token count to continue from. Should be restored when we continue training.
                      Use get_state() to get the state and save it upon checkpointing.
        :param kwargs: additional super class arguments
        """
        super().__init__(
            seq_num="seq_num",  # collect log_probs - output of _count_seq_and_tokens_update
Review comment: obsolete comments in this line and the following one
Reply: 👍
            token_num="token_num",  # collect token_num - output of _count_seq_and_tokens_update
            metric_per_batch_func=None,
            metric_per_batch_func_pre_collect=partial(
                _count_seq_and_tokens_update,
                ignore_index=ignore_index,
                encoder_input_key=encoder_input,
            ),
            result_aggregate_func=self._count_seq_and_tokens_compute,
            **kwargs,
        )
        if state is None:
            self._state = {"seq_num": 0, "token_num": 0}
        else:
            assert "seq_num" in state
            assert "token_num" in state
            self._state = state

    def _count_seq_and_tokens_compute(
        self,
        seq_num: List[np.ndarray],
Review comment: I forgot these :)
Reply: each sub epoch and each entry is a batch.
        token_num: List[np.ndarray],
    ) -> float:
Review comment: returns a dict
Reply: 👍
        seq_num_total = sum(seq_num)
        token_num_total = sum(token_num)
        self._state["seq_num"] += seq_num_total
        self._state["token_num"] += token_num_total
        return copy(self._state)

    def get_state(self) -> dict:
        return copy(self._state)


def _count_seq_and_tokens_update(
    batch_dict: dict,
    encoder_input_key: str,
Review comment: encoder_input_key: Union[str, None] or encoder_input_key: Optional[str]
Reply: It's a must. Why optional?
    ignore_index: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
Review comment: -> dict[str, Tensor]
Reply: 👍
    """Count number of sequences and tokens
    Args:
        encoder_input_key:
            key to encoder_input
        ignore_index:
            Token not to count, typically padding
    Returns:
        dictionary with number of sequences and tokens
    """
    encoder_input = batch_dict[encoder_input_key]

    # to save GPU memory
    encoder_input = encoder_input.detach()

    if ignore_index is not None:
        mask = encoder_input.ne(ignore_index)
    else:
        mask = torch.ones_like(encoder_input, dtype=torch.bool)

    seq_num = torch.tensor(
        mask.shape[0], dtype=torch.int64, device=encoder_input.device
    )
    token_num = mask.sum().to(dtype=torch.int64)

    return {"seq_num": seq_num.unsqueeze(0), "token_num": token_num.unsqueeze(0)}

class MetricPerplexity(MetricPerBatchDefault):
    def __init__(
        self,
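For readers skimming the diff, a minimal self-contained sanity check (not part of the PR) of the masking and counting logic in _count_seq_and_tokens_update, with the relevant lines inlined on a toy batch; the pad id of 0 is an assumption:

    import torch

    pad_id = 0  # assumed pad token id
    encoder_input = torch.tensor(
        [
            [11, 12, 13, pad_id],
            [21, 22, pad_id, pad_id],
        ]
    )

    mask = encoder_input.ne(pad_id)                     # True for every non-pad token
    seq_num = torch.tensor(mask.shape[0], dtype=torch.int64)  # batch size = number of sequences
    token_num = mask.sum().to(dtype=torch.int64)        # non-pad tokens only

    print(seq_num.item(), token_num.item())  # 2 sequences, 5 non-pad tokens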
General question (review comment): I'm not sure counting the sequences and tokens should be defined as a metric. I don't have another suggestion, it just sounds weird :) What do you think?
Reply: It uses the metric mechanism, and it's OK to me that it just counts some stats.
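For context, a hedged sketch of the checkpoint/resume flow that the state argument and get_state() in the diff are designed for; the key name "data.encoder_input", the pad id, and the surrounding training loop are assumptions, not part of the PR:

    # from <module added in this PR> import MetricCountSeqAndTokens  # import path depends on where the file lives

    # Fresh run: counts start at zero.
    metric = MetricCountSeqAndTokens(encoder_input="data.encoder_input", ignore_index=0)

    # ... training runs and the metric accumulates seq_num / token_num ...

    # At checkpoint time, persist the running totals alongside the model weights.
    checkpoint = {"count_metric_state": metric.get_state()}

    # On resume, hand the saved counts back so totals keep growing from where they stopped.
    restored = MetricCountSeqAndTokens(
        encoder_input="data.encoder_input",
        ignore_index=0,
        state=checkpoint["count_metric_state"],
    )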