adds count num sequences and tokens metric #346

Merged: 2 commits, Mar 21, 2024
45 changes: 44 additions & 1 deletion fuse/eval/examples/examples_seq_gen.py
@@ -15,7 +15,10 @@
from typing import Any, Dict
from collections import OrderedDict

from fuse.eval.metrics.sequence_gen.metrics_seq_gen_common import MetricPerplexity
from fuse.eval.metrics.sequence_gen.metrics_seq_gen_common import (
MetricPerplexity,
MetricCountSeqAndTokens,
)

from fuse.eval.evaluator import EvaluatorDefault

@@ -94,3 +97,43 @@ def example_seq_gen_1(seed: int = 1234) -> Dict[str, Any]:
print(results)

return results


def example_seq_gen_2() -> Dict[str, Any]:
"""
Example/Test for the count sequences and tokens metric - batch mode
"""

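# 10 sequences of 500 tokens each; the value 4999 (the ignore_index below) appears exactly once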
encoder_input_tokens = torch.arange(5000).reshape(10, 500)
data = {
"encoder_input_tokens": list(encoder_input_tokens),
"id": list(range(10)),
}
data = pd.DataFrame(data)

# Working with pytorch dataloader mode
dynamic_pipeline = PipelineDefault(
"test",
[
(OpReadDataframe(data, key_column="id"), dict()),
],
)
ds = DatasetDefault(sample_ids=len(data), dynamic_pipeline=dynamic_pipeline)
ds.create()
dl = DataLoader(ds, collate_fn=CollateDefault())
metrics = OrderedDict(
[
(
"count",
MetricCountSeqAndTokens(
encoder_input="encoder_input_tokens", ignore_index=4999
),
)
]
)

evaluator = EvaluatorDefault()
results = evaluator.eval(ids=None, data=dl, metrics=metrics, batch_size=0)
print(results)

return results
90 changes: 88 additions & 2 deletions fuse/eval/metrics/sequence_gen/metrics_seq_gen_common.py
@@ -16,15 +16,101 @@
Created on June 30, 2021

"""
from typing import Optional, Tuple, List
from typing import Optional, Tuple, List, Dict
from functools import partial

from copy import copy
import torch
import numpy as np

from fuse.eval.metrics.metrics_common import MetricPerBatchDefault


class MetricCountSeqAndTokens(MetricPerBatchDefault):
Collaborator:
General question: I'm not sure counting the sequences and tokens should be defined as a metric. I don't have another suggestion; it just sounds weird :)

What do you think?

Collaborator (author):
It uses the metric mechanism, and it's fine by me that it just counts some stats.
"""
Counts the total number of sequences and tokens in encoder_input
"""

def __init__(
self,
encoder_input: str,
ignore_index: Optional[int] = None,
state: Optional[dict] = None,
**kwargs: dict,
) -> None:
"""
:param encoder_input: key to the encoder_input
:param ignore_index: token_id to ignore (not to count), typically pad token id
Collaborator:
I think it should be able to support a list of token ids to ignore, unless you want to force the user to ignore only the PAD one.

Collaborator (author):
I went with just one to be more efficient - typically we just want to skip the padding.
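A minimal sketch of the reviewer's multi-id variant, assuming torch.isin (available in modern PyTorch); the toy tensor and ignore ids are illustrative, not from this PR, which keeps the single-id ne() path:

    import torch

    encoder_input = torch.tensor([[5, 7, 0, 0], [3, 2, 0, 0]])  # toy batch; 0 = pad, 2 = eos
    # single ignore id, as merged:
    mask_single = encoder_input.ne(0)
    # multiple ignore ids, the reviewer's suggestion (not merged):
    ignore_ids = torch.tensor([0, 2], device=encoder_input.device)
    mask_multi = ~torch.isin(encoder_input, ignore_ids)
    print(mask_single.sum().item(), mask_multi.sum().item())  # 4 3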

:param state: the sequence count and token count to continue from. Should be restored when training resumes;
use get_state() to get the state and save it upon checkpointing.
:param kwargs: additional super class arguments
"""
super().__init__(
seq_num="seq_num", # collect seq_num - output of _count_seq_and_tokens_update
token_num="token_num", # collect token_num - output of _count_seq_and_tokens_update
metric_per_batch_func=None,
metric_per_batch_func_pre_collect=partial(
_count_seq_and_tokens_update,
ignore_index=ignore_index,
encoder_input_key=encoder_input,
),
result_aggregate_func=self._count_seq_and_tokens_compute,
**kwargs,
)
if state is None:
self._state = {"seq_num": 0, "token_num": 0}
else:
assert "seq_num" in state
assert "token_num" in state
self._state = state

def _count_seq_and_tokens_compute(
self,
seq_num: List[np.ndarray],
Collaborator:
seq_num will be a numpy array such that each entry represents a batch? If so, how often are the metrics calculated? Each epoch?

I forgot these :)

Collaborator (author):
Each sub-epoch, and each entry is a batch.

token_num: List[np.ndarray],
) -> dict:

seq_num_total = sum(seq_num)
token_num_total = sum(token_num)
self._state["seq_num"] += seq_num_total
self._state["token_num"] += token_num_total
return copy(self._state)

def get_state(self) -> dict:
return copy(self._state)
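A minimal sketch of the state round-trip described in the constructor docstring; the checkpointing context is hypothetical, but the state argument, get_state(), and the import path all come from this diff:

    from fuse.eval.metrics.sequence_gen.metrics_seq_gen_common import (
        MetricCountSeqAndTokens,
    )

    metric = MetricCountSeqAndTokens(encoder_input="encoder_input_tokens", ignore_index=4999)
    # ... an evaluator run collects batches and the metric accumulates counts ...
    saved = metric.get_state()  # e.g. {"seq_num": 10, "token_num": 4999}; save with the checkpoint
    # when training resumes, pass the saved dict back so the totals keep growing:
    metric = MetricCountSeqAndTokens(
        encoder_input="encoder_input_tokens",
        ignore_index=4999,
        state=saved,
    )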


def _count_seq_and_tokens_update(
batch_dict: dict,
encoder_input_key: str,
Collaborator:
encoder_input_key: Union[str, None]

or

encoder_input_key: Optional[str]

Collaborator (author):
It's a must. Why optional?

ignore_index: Optional[int] = None,
) -> Dict[str, torch.Tensor]:
"""Count number of sequences and tokens
Args:
batch_dict:
batch dictionary that holds the encoder input
encoder_input_key:
key to encoder_input
ignore_index:
Token not to count, typically padding
Returns:
dictionary with number of sequences and tokens
"""
encoder_input = batch_dict[encoder_input_key]

# to save GPU memory
encoder_input = encoder_input.detach()

if ignore_index is not None:
mask = encoder_input.ne(ignore_index)
else:
mask = torch.ones_like(encoder_input, dtype=torch.bool)

seq_num = torch.tensor(
mask.shape[0], dtype=torch.int64, device=encoder_input.device
)
token_num = mask.sum().to(dtype=torch.int64)

return {"seq_num": seq_num.unsqueeze(0), "token_num": token_num.unsqueeze(0)}
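The helper can be sanity-checked directly on the same tensors example_seq_gen_2 builds: torch.arange(5000) contains the value 4999 exactly once, so exactly one token is masked out, which is where the token count asserted in the test below comes from:

    import torch

    from fuse.eval.metrics.sequence_gen.metrics_seq_gen_common import (
        _count_seq_and_tokens_update,
    )

    batch = {"encoder_input_tokens": torch.arange(5000).reshape(10, 500)}
    out = _count_seq_and_tokens_update(
        batch, encoder_input_key="encoder_input_tokens", ignore_index=4999
    )
    print(out)  # {'seq_num': tensor([10]), 'token_num': tensor([4999])}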


class MetricPerplexity(MetricPerBatchDefault):
def __init__(
self,
11 changes: 10 additions & 1 deletion fuse/eval/tests/test_eval.py
@@ -48,7 +48,11 @@
example_seg_4,
)

from fuse.eval.examples.examples_seq_gen import example_seq_gen_0, example_seq_gen_1
from fuse.eval.examples.examples_seq_gen import (
example_seq_gen_0,
example_seq_gen_1,
example_seq_gen_2,
)

from fuse.eval.examples.examples_stats import example_pearson_correlation

@@ -234,6 +238,11 @@ def test_eval_example_seq_gen_1(self) -> None:
results = example_seq_gen_1(seed=1234)
self.assertAlmostEqual(results["metrics.perplexity"], 162.87, places=2)

def test_eval_example_seq_gen_2(self) -> None:
results = example_seq_gen_2()
self.assertAlmostEqual(results["metrics.count.seq_num"], 10)
self.assertAlmostEqual(results["metrics.count.token_num"], 4999)

def test_pearson_correlation(self) -> None:
res = example_pearson_correlation()
self.assertAlmostEqual(res["metrics.pearsonr"], 1.0, places=2)