diff --git a/fuse/eval/examples/examples_seq_gen.py b/fuse/eval/examples/examples_seq_gen.py
index 346eec58..32a1b51a 100644
--- a/fuse/eval/examples/examples_seq_gen.py
+++ b/fuse/eval/examples/examples_seq_gen.py
@@ -15,7 +15,10 @@
 from typing import Any, Dict
 from collections import OrderedDict
 
-from fuse.eval.metrics.sequence_gen.metrics_seq_gen_common import MetricPerplexity
+from fuse.eval.metrics.sequence_gen.metrics_seq_gen_common import (
+    MetricPerplexity,
+    MetricCountSeqAndTokens,
+)
 
 from fuse.eval.evaluator import EvaluatorDefault
 
@@ -94,3 +97,43 @@ def example_seq_gen_1(seed: int = 1234) -> Dict[str, Any]:
     print(results)
 
     return results
+
+
+def example_seq_gen_2() -> Dict[str, Any]:
+    """
+    Example/Test for the count sequences and tokens metric - pytorch dataloader mode
+    """
+
+    encoder_input_tokens = torch.arange(5000).reshape(10, 500)
+    data = {
+        "encoder_input_tokens": list(encoder_input_tokens),
+        "id": list(range(10)),
+    }
+    data = pd.DataFrame(data)
+
+    # Working with pytorch dataloader mode
+    dynamic_pipeline = PipelineDefault(
+        "test",
+        [
+            (OpReadDataframe(data, key_column="id"), dict()),
+        ],
+    )
+    ds = DatasetDefault(sample_ids=len(data), dynamic_pipeline=dynamic_pipeline)
+    ds.create()
+    dl = DataLoader(ds, collate_fn=CollateDefault())
+    metrics = OrderedDict(
+        [
+            (
+                "count",
+                MetricCountSeqAndTokens(
+                    encoder_input="encoder_input_tokens", ignore_index=4999
+                ),
+            )
+        ]
+    )
+
+    evaluator = EvaluatorDefault()
+    results = evaluator.eval(ids=None, data=dl, metrics=metrics, batch_size=0)
+    print(results)
+
+    return results
diff --git a/fuse/eval/metrics/sequence_gen/metrics_seq_gen_common.py b/fuse/eval/metrics/sequence_gen/metrics_seq_gen_common.py
index a406dc11..c5e2947c 100644
--- a/fuse/eval/metrics/sequence_gen/metrics_seq_gen_common.py
+++ b/fuse/eval/metrics/sequence_gen/metrics_seq_gen_common.py
@@ -16,15 +16,101 @@ Created on June 30, 2021
 
 """
 
-from typing import Optional, Tuple, List
+from typing import Optional, Tuple, List, Dict
 from functools import partial
-
+from copy import copy
 import torch
 import numpy as np
 
 from fuse.eval.metrics.metrics_common import MetricPerBatchDefault
 
 
+class MetricCountSeqAndTokens(MetricPerBatchDefault):
+    """
+    Counts the total number of sequences and tokens in encoder_input
+    """
+
+    def __init__(
+        self,
+        encoder_input: str,
+        ignore_index: Optional[int] = None,
+        state: Optional[dict] = None,
+        **kwargs: dict,
+    ) -> None:
+        """
+        :param encoder_input: key to the encoder_input
+        :param ignore_index: token_id to ignore (not to count), typically the pad token id
+        :param state: the sequence and token counts to resume from; should be restored when training continues.
+                      Use get_state() to read the state and save it upon checkpointing.
+        :param kwargs: additional super class arguments
+        """
+        super().__init__(
+            seq_num="seq_num",  # collect seq_num - output of _count_seq_and_tokens_update
+            token_num="token_num",  # collect token_num - output of _count_seq_and_tokens_update
+            metric_per_batch_func=None,
+            metric_per_batch_func_pre_collect=partial(
+                _count_seq_and_tokens_update,
+                ignore_index=ignore_index,
+                encoder_input_key=encoder_input,
+            ),
+            result_aggregate_func=self._count_seq_and_tokens_compute,
+            **kwargs,
+        )
+        if state is None:
+            self._state = {"seq_num": 0, "token_num": 0}
+        else:
+            assert "seq_num" in state
+            assert "token_num" in state
+            self._state = state
+
+    def _count_seq_and_tokens_compute(
+        self,
+        seq_num: List[np.ndarray],
+        token_num: List[np.ndarray],
+    ) -> dict:
+
+        seq_num_total = sum(seq_num)
+        token_num_total = sum(token_num)
+        self._state["seq_num"] += seq_num_total
+        self._state["token_num"] += token_num_total
+        return copy(self._state)
+
+    def get_state(self) -> dict:
+        return copy(self._state)
+
+
+def _count_seq_and_tokens_update(
+    batch_dict: dict,
+    encoder_input_key: str,
+    ignore_index: Optional[int] = None,
+) -> Dict[str, torch.Tensor]:
+    """Count number of sequences and tokens
+    Args:
+        encoder_input_key:
+            key to encoder_input
+        ignore_index:
+            token not to count, typically padding
+    Returns:
+        dictionary with the number of sequences and tokens
+    """
+    encoder_input = batch_dict[encoder_input_key]
+
+    # to save GPU memory
+    encoder_input = encoder_input.detach()
+
+    if ignore_index is not None:
+        mask = encoder_input.ne(ignore_index)
+    else:
+        mask = torch.ones_like(encoder_input, dtype=torch.bool)
+
+    seq_num = torch.tensor(
+        mask.shape[0], dtype=torch.int64, device=encoder_input.device
+    )
+    token_num = mask.sum().to(dtype=torch.int64)
+
+    return {"seq_num": seq_num.unsqueeze(0), "token_num": token_num.unsqueeze(0)}
+
+
 class MetricPerplexity(MetricPerBatchDefault):
     def __init__(
        self,
diff --git a/fuse/eval/tests/test_eval.py b/fuse/eval/tests/test_eval.py
index 71ded72c..0fbd9252 100644
--- a/fuse/eval/tests/test_eval.py
+++ b/fuse/eval/tests/test_eval.py
@@ -48,7 +48,11 @@
     example_seg_4,
 )
 
-from fuse.eval.examples.examples_seq_gen import example_seq_gen_0, example_seq_gen_1
+from fuse.eval.examples.examples_seq_gen import (
+    example_seq_gen_0,
+    example_seq_gen_1,
+    example_seq_gen_2,
+)
 
 from fuse.eval.examples.examples_stats import example_pearson_correlation
 
@@ -234,6 +238,11 @@ def test_eval_example_seq_gen_1(self) -> None:
         results = example_seq_gen_1(seed=1234)
         self.assertAlmostEqual(results["metrics.perplexity"], 162.87, places=2)
 
+    def test_eval_example_seq_gen_2(self) -> None:
+        results = example_seq_gen_2()
+        self.assertAlmostEqual(results["metrics.count.seq_num"], 10)
+        self.assertAlmostEqual(results["metrics.count.token_num"], 4999)
+
     def test_pearson_correlation(self) -> None:
         res = example_pearson_correlation()
         self.assertAlmostEqual(res["metrics.pearsonr"], 1.0, places=2)
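
Reviewer note: a minimal usage sketch of the new metric, assuming only the names this diff introduces (MetricCountSeqAndTokens, _count_seq_and_tokens_update, get_state). The "tokens" key, the batch values, and the saved_state dict are hypothetical illustrations, not library fixtures; calling the underscore-prefixed helper directly is for demonstration only.

    import torch

    from fuse.eval.metrics.sequence_gen.metrics_seq_gen_common import (
        MetricCountSeqAndTokens,
        _count_seq_and_tokens_update,
    )

    # Two sequences of length 4, right-padded with token id 0 (illustrative values).
    batch_dict = {"tokens": torch.tensor([[5, 7, 2, 0], [3, 9, 0, 0]])}

    # Per-batch update: counts rows as sequences and non-pad entries as tokens.
    out = _count_seq_and_tokens_update(
        batch_dict, encoder_input_key="tokens", ignore_index=0
    )
    assert out["seq_num"].item() == 2  # two rows in the batch
    assert out["token_num"].item() == 5  # 3 + 2 entries differ from the pad id

    # Resuming after a checkpoint: pass the previously saved counts via `state`,
    # and persist metric.get_state() alongside the model checkpoint.
    saved_state = {"seq_num": 1000, "token_num": 250000}  # hypothetical saved counts
    metric = MetricCountSeqAndTokens(
        encoder_input="tokens", ignore_index=0, state=saved_state
    )
    assert metric.get_state() == saved_state

This mirrors how example_seq_gen_2 drives the metric through EvaluatorDefault: there, arange(5000).reshape(10, 500) yields 10 sequences and 5000 tokens, and ignore_index=4999 drops exactly one token, hence the expected 10 / 4999 in the test.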