From 73dfd72f9aa83859b381987788b484142ec01d47 Mon Sep 17 00:00:00 2001
From: "moshiko.raboh"
Date: Sun, 17 Mar 2024 14:53:43 +0200
Subject: [PATCH 1/3] optimize memory usage of perplexity in multi gpu mode

---
 fuse/eval/metrics/metrics_common.py           |  7 +++--
 .../sequence_gen/metrics_seq_gen_common.py    | 30 ++++++++++++++-----
 2 files changed, 27 insertions(+), 10 deletions(-)

diff --git a/fuse/eval/metrics/metrics_common.py b/fuse/eval/metrics/metrics_common.py
index fca34a42..7abbe1b4 100644
--- a/fuse/eval/metrics/metrics_common.py
+++ b/fuse/eval/metrics/metrics_common.py
@@ -602,8 +602,9 @@ class MetricPerBatchDefault(MetricWithCollectorBase):
     def __init__(
         self,
         *,
-        metric_per_batch_func: Callable,
+        metric_per_batch_func: Optional[Callable],
         result_aggregate_func: Callable,
+        metric_per_batch_func_pre_collect: Optional[Callable] = None,
         **kwargs: Any,
     ):
         """
@@ -616,7 +617,9 @@ def __init__(
         """
 
         super().__init__(
-            batch_post_collect_process_func=metric_per_batch_func, **kwargs
+            batch_post_collect_process_func=metric_per_batch_func,
+            batch_pre_collect_process_func=metric_per_batch_func_pre_collect,
+            **kwargs,
         )
         self._result_aggregate_func = result_aggregate_func
 
diff --git a/fuse/eval/metrics/sequence_gen/metrics_seq_gen_common.py b/fuse/eval/metrics/sequence_gen/metrics_seq_gen_common.py
index 166d1243..18914741 100644
--- a/fuse/eval/metrics/sequence_gen/metrics_seq_gen_common.py
+++ b/fuse/eval/metrics/sequence_gen/metrics_seq_gen_common.py
@@ -16,7 +16,7 @@
 Created on June 30, 2021
 
 """
-from typing import Optional, Tuple, List, Union
+from typing import Optional, Tuple, List
 from functools import partial
 
 import torch
@@ -34,13 +34,16 @@ def __init__(
         **kwargs: dict,
     ) -> None:
         super().__init__(
-            preds=preds,
-            target=target,
-            metric_per_batch_func=partial(
-                _perplexity_update, ignore_index=ignore_index
+            log_probs="log_probs",  # collect log_probs - output of _perplexity_update
+            token_num="token_num",  # collect token_num - output of _perplexity_update
+            metric_per_batch_func=None,
+            metric_per_batch_func_pre_collect=partial(
+                _perplexity_update,
+                ignore_index=ignore_index,
+                preds_key=preds,
+                targets_key=target,
             ),
             result_aggregate_func=_perplexity_compute,
-            post_keys_to_collect=["log_probs", "token_num"],
             **kwargs,
         )
 
@@ -48,8 +51,9 @@
 # Copied internal function https://github.com/Lightning-AI/metrics/blob/825d17f32ee0b9a2a8024c89d4a09863d7eb45c3/src/torchmetrics/functional/text/perplexity.py#L68
 # copied and not imported to not be affected by internal interface modifications.
 def _perplexity_update(
-    preds: Union[torch.Tensor, np.ndarray],
-    target: Union[torch.Tensor, np.ndarray],
+    batch_dict: dict,
+    preds_key: str,
+    targets_key: str,
     ignore_index: Optional[int] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """Compute intermediate statistics for Perplexity.
@@ -65,6 +69,8 @@ def _perplexity_update(
         Log probabilities, summed over all samples
         Number of tokens
     """
+    preds = batch_dict[preds_key]
+    target = batch_dict[targets_key]
     if isinstance(preds, np.ndarray):
         preds = torch.tensor(preds)
 
@@ -96,6 +102,10 @@ def _perplexity_update(
         mask = torch.ones_like(target, dtype=torch.bool)
 
     preds = preds[:, target].diagonal()[mask]
+    # avoid overflow
+    if preds.dtype == torch.float16:
+        preds = preds.to(torch.float32)
+    preds = torch.clamp(preds, min=1e-10)
     total_log_probs = -preds.log().sum()
     count = mask.sum()
 
@@ -106,6 +116,10 @@ def _perplexity_compute(
     log_probs: List[np.ndarray],
     token_num: List[np.ndarray],
 ) -> float:
+    # avoid overflow when accumulating over large epochs
+    log_probs = [e.astype(np.float64) for e in log_probs]
+    token_num = [e.astype(np.int64) for e in token_num]
+
     sum_log_probs = sum(log_probs)
     num_total = sum(token_num)
     return float(np.exp(sum_log_probs / num_total))

From 2f487cc0a70c04846f43f8ba2d5be853bae5a0e9 Mon Sep 17 00:00:00 2001
From: "moshiko.raboh"
Date: Sun, 17 Mar 2024 22:51:06 +0200
Subject: [PATCH 2/3] support pre-collect metric computation and optional id collection

---
 fuse/eval/examples/examples_seq_gen.py        |  7 +-
 fuse/eval/metrics/metrics_common.py           | 95 ++++++++++++++-----
 .../sequence_gen/metrics_seq_gen_common.py    |  1 +
 fuse/eval/tests/test_eval.py                  |  2 +-
 4 files changed, 73 insertions(+), 32 deletions(-)

diff --git a/fuse/eval/examples/examples_seq_gen.py b/fuse/eval/examples/examples_seq_gen.py
index fc4367b3..346eec58 100644
--- a/fuse/eval/examples/examples_seq_gen.py
+++ b/fuse/eval/examples/examples_seq_gen.py
@@ -16,7 +16,6 @@
 from collections import OrderedDict
 
 from fuse.eval.metrics.sequence_gen.metrics_seq_gen_common import MetricPerplexity
-from fuse.eval.metrics.metrics_common import CI
 from fuse.eval.evaluator import EvaluatorDefault
 
 
@@ -47,11 +46,7 @@ def example_seq_gen_0(seed: int = 1234) -> Dict[str, Any]:
         [
             (
                 "perplexity",
-                CI(
-                    MetricPerplexity(preds="pred", target="label"),
-                    stratum=None,
-                    num_of_bootstraps=10,
-                ),
+                MetricPerplexity(preds="pred", target="label"),
             )
         ]
     )
diff --git a/fuse/eval/metrics/metrics_common.py b/fuse/eval/metrics/metrics_common.py
index 7abbe1b4..1dc2b608 100644
--- a/fuse/eval/metrics/metrics_common.py
+++ b/fuse/eval/metrics/metrics_common.py
@@ -86,6 +86,7 @@ def __init__(
         batch_post_collect_process_func: Optional[Callable] = None,
         post_keys_to_collect: Optional[List[str]] = None,
         collect_distributed: bool = True,
+        collect_ids: bool = True,
         **keys_to_collect: Dict[str, str],
     ):
         """
@@ -99,6 +100,7 @@ def __init__(
         :param post_keys_to_collect: specify the keys you want to collect after post_collect_process. Required only if post_collect_process_func or batch_post_collect_process_func are specified. if None, will aggregate list of post_collect_process_func returned values
         :param collect_distributed: if True, in multi gpu training, will collect the samples from all gpus - otherwise only rank0 will be reported.
+        :param collect_ids: if True will not collect ids and will not support permutation of data (which used to compute confidence interval)
         :param keys_to_collect: specify the keys you want to collect from the source data
         """
 
         super().__init__()
@@ -111,6 +113,7 @@ def __init__(
         self._post_keys_to_collect = copy.copy(post_keys_to_collect)
         self._id_keys = MetricCollector.DEFAULT_ID_KEYS
         self._collect_distributed = collect_distributed
+        self._to_collect_ids = collect_ids
 
         # reset
         self.reset()
@@ -181,17 +184,18 @@ def collect(self, batch: Dict) -> None:
                 self._collected_data[name].extend(value)
 
         # extract ids and store it in self._collected_ids
-        ids = None
-        for key in self._id_keys:
-            if key in batch:
-                ids = batch[key]
-                # collect distributed
-                if dist.is_initialized() and self._collect_distributed:
-                    ids = self.sync_ids(ids)
-                break
-
-        if ids is not None:
-            self._collected_ids.extend(ids)
+        if self._to_collect_ids:
+            ids = None
+            for key in self._id_keys:
+                if key in batch:
+                    ids = batch[key]
+                    # collect distributed
+                    if dist.is_initialized() and self._collect_distributed:
+                        ids = self.sync_ids(ids)
+                    break
+
+            if ids is not None:
+                self._collected_ids.extend(ids)
 
     @staticmethod
     def sync_tensor_data_and_concat(data: torch.Tensor) -> torch.Tensor:
@@ -239,15 +243,41 @@ def sync_ids(ids: List[Tuple[str, int]]) -> List[Any]:
         return data_list
 
     @staticmethod
-    def _df_dict_apply(data: pd.Series, func: Callable) -> pd.Series:
-        result = func(NDict(data.to_dict()))
-        return pd.Series(result.flatten())
+    def _df_dict_apply(
+        data: pd.Series, func: Callable, batch: bool = False
+    ) -> pd.Series:
+        if batch:
+            # expand sample to batch
+            data_dict = {}
+            for k, v in data.to_dict().items():
+                if isinstance(v, torch.Tensor):
+                    data_dict[k] = v.unsqueeze(0)
+                elif isinstance(v, np.ndarray):
+                    data_dict[k] = np.expand_dims(v, axis=0)
+                else:
+                    data_dict[k] = [v]
+        else:
+            data_dict = data.to_dict()
+
+        result = func(NDict(data_dict))
+
+        if batch:
+            # squeeze batch back to sample
+            result_sample = {}
+            for k, v in result.items():
+                if isinstance(v, (torch.Tensor, np.ndarray, list)):
+                    result_sample[k] = v[0]
+                else:
+                    result_sample[k] = v
+            result = result_sample
+        return pd.Series(result)
 
     @staticmethod
     def _df_dict_apply_kwargs(
         data: pd.Series, func: Callable, batch: bool = False, post: bool = False
     ) -> pd.Series:
         if batch:
+            # expand sample to batch
             kwargs = {}
             for k, v in data.to_dict().items():
                 if isinstance(v, torch.Tensor):
@@ -266,6 +296,7 @@ def _df_dict_apply_kwargs(
             result = {"post_args": result}
 
         if batch:
+            # squeeze batch back to sample
             result_sample = {}
             for k, v in result.items():
                 if isinstance(v, (torch.Tensor, np.ndarray, list)):
@@ -288,6 +319,12 @@ def set(self, data: pd.DataFrame) -> None:
             )
             data = data.apply(pre_collect_process, axis=1)
 
+        if self._batch_pre_collect_process_func is not None:
+            pre_collect_process = lambda x: self._df_dict_apply(
+                x, self._batch_pre_collect_process_func, batch=True
+            )
+            data = data.apply(pre_collect_process, axis=1)
+
         data_to_collect = pd.DataFrame(data=None, columns=self._keys_to_collect)
         for name, key in self._keys_to_collect.items():
             if key not in data.keys():
@@ -311,17 +348,20 @@ def set(self, data: pd.DataFrame) -> None:
 
         for name in data_to_collect.keys():
             values = data_to_collect.loc[:, name]
-            self._collected_data[name].extend(values)
+            self._collected_data[name].extend(
+                [v.numpy() if isinstance(v, torch.Tensor) else v for v in values]
+            )
 
         # extract ids and store it in self._collected_ids
-        ids = None
-        for key in self._id_keys:
-            if key in data.keys():
-                ids = list(data[key])
-                break
-
-        if ids is not None:
-            self._collected_ids.extend(ids)
+        if self._to_collect_ids:
+            ids = None
+            for key in self._id_keys:
+                if key in data.keys():
+                    ids = list(data[key])
+                    break
+
+            if ids is not None:
+                self._collected_ids.extend(ids)
 
     def reset(self) -> None:
         """
@@ -338,8 +378,10 @@ def reset(self) -> None:
             self._collected_data = {name: [] for name in self._post_keys_to_collect}
         else:
             self._collected_data = {"post_args": []}
-
-        self._collected_ids = []  # the original collected ids
+        if self._to_collect_ids:
+            self._collected_ids = []  # the original collected ids
+        else:
+            self._collected_ids = None
 
         self._sampled_ids = None  # the required ids - set by sample() method
 
@@ -354,7 +396,7 @@ def get(self, ids: Optional[Sequence[Hashable]] = None) -> Tuple[Dict[str, Any]]
         """
         Get collected data - collected data dictionary and collected ids. each element in the dictionary will be a list of values from all samples
         """
-        if ids is None:
+        if ids is None or self._to_collect_ids is False:
             return copy.copy(self._collected_data)
         else:
             # convert required ids to permutation
@@ -391,6 +433,7 @@ def __init__(
         batch_post_collect_process_func: Optional[Callable] = None,
         post_keys_to_collect: Optional[Sequence[str]] = None,
         external_data_collector: Optional[MetricCollector] = None,
+        collect_ids: bool = True,
         extract_ids: bool = False,
         collect_distributed: bool = True,
         **kwargs: Any,
@@ -402,6 +445,7 @@ def __init__(
         :param batch_post_collect_process_func: Optional callable - see details in MetricCollector.__init__
         :param post_keys_to_collect: Optional keys to collect from post_collect_process func results - see details in MetricCollector.__init__
         :param external_data_collector: Optional - in a case space optimization required and there by using shared collector for few metrics
+        :param collect_ids: if True will not collect ids and will not support permutation of data (which used to compute confidence interval)
         :param extract_ids: self._extract_arguments packs all arguments for a underlying function. Set to True, to pack also the ids (under the name 'ids')
         :param collect_distributed: if True, in multi gpu training, will collect the samples from all gpus - otherwise only rank0 will be reported.
         :param kwargs: specify keywords and value arguments you want to collect from the source data.
@@ -430,6 +474,7 @@ def __init__(
                 batch_post_collect_process_func,
                 post_keys_to_collect=post_keys_to_collect,
                 collect_distributed=collect_distributed,
+                collect_ids=collect_ids,
                 **self._keys_to_collect,
             )
             if external_data_collector is None
diff --git a/fuse/eval/metrics/sequence_gen/metrics_seq_gen_common.py b/fuse/eval/metrics/sequence_gen/metrics_seq_gen_common.py
index 18914741..a406dc11 100644
--- a/fuse/eval/metrics/sequence_gen/metrics_seq_gen_common.py
+++ b/fuse/eval/metrics/sequence_gen/metrics_seq_gen_common.py
@@ -44,6 +44,7 @@ def __init__(
                 targets_key=target,
             ),
             result_aggregate_func=_perplexity_compute,
+            collect_ids=False,
             **kwargs,
         )
 
diff --git a/fuse/eval/tests/test_eval.py b/fuse/eval/tests/test_eval.py
index b54492a9..71ded72c 100644
--- a/fuse/eval/tests/test_eval.py
+++ b/fuse/eval/tests/test_eval.py
@@ -228,7 +228,7 @@ def test_eval_example_14(self) -> None:
 
     def test_eval_example_seq_gen_0(self) -> None:
         results = example_seq_gen_0(seed=1234)
-        self.assertAlmostEqual(results["metrics.perplexity.org"], 162.87, places=2)
+        self.assertAlmostEqual(results["metrics.perplexity"], 162.87, places=2)
 
     def test_eval_example_seq_gen_1(self) -> None:
         results = example_seq_gen_1(seed=1234)

From 6ad4270bd218938649e2289f05e158f8bedb6702 Mon Sep 17 00:00:00 2001
From: "moshiko.raboh"
Date: Mon, 18 Mar 2024 13:29:04 +0200
Subject: [PATCH 3/3] fix doc

---
 fuse/eval/metrics/metrics_common.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/fuse/eval/metrics/metrics_common.py b/fuse/eval/metrics/metrics_common.py
index 1dc2b608..464503ab 100644
--- a/fuse/eval/metrics/metrics_common.py
+++ b/fuse/eval/metrics/metrics_common.py
@@ -100,7 +100,7 @@ def __init__(
         :param post_keys_to_collect: specify the keys you want to collect after post_collect_process. Required only if post_collect_process_func or batch_post_collect_process_func are specified. if None, will aggregate list of post_collect_process_func returned values
         :param collect_distributed: if True, in multi gpu training, will collect the samples from all gpus - otherwise only rank0 will be reported.
-        :param collect_ids: if True will not collect ids and will not support permutation of data (which used to compute confidence interval)
+        :param collect_ids: if False, ids will not be collected, and permutation of the data (used to compute confidence intervals) will not be supported
         :param keys_to_collect: specify the keys you want to collect from the source data
         """
 
         super().__init__()
@@ -445,7 +445,7 @@ def __init__(
         :param batch_post_collect_process_func: Optional callable - see details in MetricCollector.__init__
         :param post_keys_to_collect: Optional keys to collect from post_collect_process func results - see details in MetricCollector.__init__
         :param external_data_collector: Optional - in a case space optimization required and there by using shared collector for few metrics
-        :param collect_ids: if True will not collect ids and will not support permutation of data (which used to compute confidence interval)
+        :param collect_ids: if False, ids will not be collected, and permutation of the data (used to compute confidence intervals) will not be supported
         :param extract_ids: self._extract_arguments packs all arguments for a underlying function. Set to True, to pack also the ids (under the name 'ids')
         :param collect_distributed: if True, in multi gpu training, will collect the samples from all gpus - otherwise only rank0 will be reported.
         :param kwargs: specify keywords and value arguments you want to collect from the source data.
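
Note for reviewers: the standalone sketch below illustrates the idea behind patch 1. The names perplexity_pre_collect and perplexity_aggregate are made up for illustration and are not part of the FuseMedML API (the real implementation is _perplexity_update / _perplexity_compute above). Each batch is reduced at pre-collect time to two scalars - the summed negative log-probability and the scored-token count - so only these scalars, rather than the full prediction tensors, are stored and synced across GPUs.

    from typing import List, Optional, Tuple

    import numpy as np
    import torch


    def perplexity_pre_collect(
        preds: torch.Tensor,  # probabilities, shape [batch, seq_len, vocab]
        target: torch.Tensor,  # token ids, shape [batch, seq_len]
        ignore_index: Optional[int] = None,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Reduce one batch to (sum of -log p(target), number of scored tokens)."""
        probs = preds.reshape(-1, preds.shape[-1])
        target = target.reshape(-1)
        if ignore_index is not None:
            mask = target.ne(ignore_index)
            # replace ignored ids with a valid index; the mask drops them anyway
            target = target.where(mask, torch.zeros_like(target))
        else:
            mask = torch.ones_like(target, dtype=torch.bool)
        picked = probs[torch.arange(target.numel(), device=target.device), target][mask]
        # float16 cannot represent very small probabilities: log and sum in float32
        picked = picked.float().clamp(min=1e-10)
        return (-picked.log().sum()).cpu().numpy(), mask.sum().cpu().numpy()


    def perplexity_aggregate(
        log_probs: List[np.ndarray], token_num: List[np.ndarray]
    ) -> float:
        """Compute perplexity from the per-batch statistics collected over an epoch."""
        # accumulate in float64/int64 so long epochs do not overflow
        total = sum(float(x) for x in log_probs)
        count = sum(int(n) for n in token_num)
        return float(np.exp(total / count))

For example, on two toy batches of softmax outputs over a 5-token vocabulary:

    stats = [
        perplexity_pre_collect(
            torch.softmax(torch.randn(2, 4, 5), dim=-1),
            torch.randint(0, 5, (2, 4)),
        )
        for _ in range(2)
    ]
    log_probs, token_num = zip(*stats)
    print(perplexity_aggregate(list(log_probs), list(token_num)))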
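Patch 2's batch_pre_collect_process_func path feeds each per-sample DataFrame row through a per-batch function by expanding it to a batch of one and squeezing the result back. A rough standalone equivalent of that round trip (the helper name apply_batch_func_to_sample is invented for illustration):

    from typing import Any, Callable, Dict

    import numpy as np
    import torch


    def apply_batch_func_to_sample(
        sample: Dict[str, Any],
        batch_func: Callable[[Dict[str, Any]], Dict[str, Any]],
    ) -> Dict[str, Any]:
        # expand every value into a batch of size one
        batch = {}
        for k, v in sample.items():
            if isinstance(v, torch.Tensor):
                batch[k] = v.unsqueeze(0)
            elif isinstance(v, np.ndarray):
                batch[k] = np.expand_dims(v, axis=0)
            else:
                batch[k] = [v]
        result = batch_func(batch)
        # squeeze the leading batch dimension back out of each returned value
        return {
            k: v[0] if isinstance(v, (torch.Tensor, np.ndarray, list)) else v
            for k, v in result.items()
        }

This mirrors what MetricCollector._df_dict_apply(..., batch=True) does per row, minus the pandas/NDict plumbing.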