Perplexity gpu mem optimization #345

Merged (4 commits) on Mar 20, 2024
7 changes: 1 addition & 6 deletions fuse/eval/examples/examples_seq_gen.py
@@ -16,7 +16,6 @@
from collections import OrderedDict

from fuse.eval.metrics.sequence_gen.metrics_seq_gen_common import MetricPerplexity
from fuse.eval.metrics.metrics_common import CI

from fuse.eval.evaluator import EvaluatorDefault

@@ -47,11 +46,7 @@ def example_seq_gen_0(seed: int = 1234) -> Dict[str, Any]:
[
(
"perplexity",
CI(
Collaborator:
Why did you remove that?

I see that the ground truth remained the same... But now we don't cover CI() in the examples (i.e. test coverage).

Collaborator (Author):
We do cover CI in other examples, I hope.
I removed it because the perplexity is now computed at the batch level (more optimized) and I'm not sure it's a metric we would want a CI for.
We can support it again when necessary.

MetricPerplexity(preds="pred", target="label"),
stratum=None,
num_of_bootstraps=10,
),
MetricPerplexity(preds="pred", target="label"),
)
]
)
102 changes: 75 additions & 27 deletions fuse/eval/metrics/metrics_common.py
@@ -86,6 +86,7 @@ def __init__(
batch_post_collect_process_func: Optional[Callable] = None,
post_keys_to_collect: Optional[List[str]] = None,
collect_distributed: bool = True,
collect_ids: bool = True,
**keys_to_collect: Dict[str, str],
):
"""
@@ -99,6 +100,7 @@ def __init__(
:param post_keys_to_collect: specify the keys you want to collect after post_collect_process. Required only if post_collect_process_func or batch_post_collect_process_func are specified.
if None, will aggregate list of post_collect_process_func returned values
:param collect_distributed: if True, in multi gpu training, will collect the samples from all gpus - otherwise only rank0 will be reported.
:param collect_ids: if False, will not collect ids and will not support permutation of the data (which is used to compute confidence intervals)
:param keys_to_collect: specify the keys you want to collect from the source data
"""
super().__init__()
Expand All @@ -111,6 +113,7 @@ def __init__(
self._post_keys_to_collect = copy.copy(post_keys_to_collect)
self._id_keys = MetricCollector.DEFAULT_ID_KEYS
self._collect_distributed = collect_distributed
self._to_collect_ids = collect_ids

# reset
self.reset()
Expand Down Expand Up @@ -181,17 +184,18 @@ def collect(self, batch: Dict) -> None:
self._collected_data[name].extend(value)

# extract ids and store it in self._collected_ids
ids = None
for key in self._id_keys:
if key in batch:
ids = batch[key]
# collect distributed
if dist.is_initialized() and self._collect_distributed:
ids = self.sync_ids(ids)
break

if ids is not None:
self._collected_ids.extend(ids)
if self._to_collect_ids:
ids = None
for key in self._id_keys:
if key in batch:
ids = batch[key]
# collect distributed
if dist.is_initialized() and self._collect_distributed:
ids = self.sync_ids(ids)
break

if ids is not None:
self._collected_ids.extend(ids)

@staticmethod
def sync_tensor_data_and_concat(data: torch.Tensor) -> torch.Tensor:
@@ -239,15 +243,41 @@ def sync_ids(ids: List[Tuple[str, int]]) -> List[Any]:
return data_list

@staticmethod
def _df_dict_apply(data: pd.Series, func: Callable) -> pd.Series:
result = func(NDict(data.to_dict()))
return pd.Series(result.flatten())
def _df_dict_apply(
data: pd.Series, func: Callable, batch: bool = False
) -> pd.Series:
if batch:
Collaborator:

Not sure I understood the use case: data is always a single data point (e.g. a series) and func might be performed at the batch level? And where func expects a batch, you pass batch=True?

I'm suggesting renaming batch or adding a short description to make it clearer :))

Collaborator (Author):

The metric has some preprocessing func at the batch level.
Here we are not running a training loop, so to avoid duplication we run the function on each sample as if it were a batch. To do that we convert the single sample to a batch, call the function, and then squeeze it back to a sample.
As I mentioned, those are dark places in fuse that I would refactor if we had some time to spend on it.

# expand sample to batch
data_dict = {}
for k, v in data.to_dict().items():
if isinstance(v, torch.Tensor):
data_dict[k] = v.unsqueeze(0)
elif isinstance(v, np.ndarray):
data_dict[k] = np.expand_dims(v, axis=0)
else:
data_dict[k] = [v]
else:
data_dict = data.to_dict()

result = func(NDict(data_dict))

if batch:
# squeeze batch back to sample
result_sample = {}
for k, v in result.items():
if isinstance(v, (torch.Tensor, np.ndarray, list)):
result_sample[k] = v[0]
else:
result_sample[k] = v
result = result_sample
return pd.Series(result)
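To make the sample-to-batch-and-back round trip described in the thread above concrete, here is a small standalone sketch (illustrative only; `batch_level_func` and the field names are made up, not fuse code): a single pandas row is expanded to a batch of size 1, a batch-level function is applied, and the result is squeezed back to a sample.

```python
import numpy as np
import pandas as pd
import torch


def batch_level_func(batch: dict) -> dict:
    # toy batch-level function: expects tensors shaped [batch, ...]
    return {"doubled": batch["x"] * 2, "note": batch["note"]}


row = pd.Series({"x": torch.tensor([1.0, 2.0]), "note": "sample-0"})

# expand the single sample to a batch of size 1
as_batch = {}
for k, v in row.to_dict().items():
    if isinstance(v, torch.Tensor):
        as_batch[k] = v.unsqueeze(0)           # [2] -> [1, 2]
    elif isinstance(v, np.ndarray):
        as_batch[k] = np.expand_dims(v, axis=0)
    else:
        as_batch[k] = [v]                      # wrap scalars/strings in a list

result = batch_level_func(as_batch)

# squeeze the batch dimension back to a single sample
as_sample = {
    k: (v[0] if isinstance(v, (torch.Tensor, np.ndarray, list)) else v)
    for k, v in result.items()
}
print(as_sample)  # {'doubled': tensor([2., 4.]), 'note': 'sample-0'}
```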

@staticmethod
def _df_dict_apply_kwargs(
data: pd.Series, func: Callable, batch: bool = False, post: bool = False
) -> pd.Series:
if batch:
# expand sample to batch
kwargs = {}
for k, v in data.to_dict().items():
if isinstance(v, torch.Tensor):
@@ -266,6 +296,7 @@ def _df_dict_apply_kwargs(
result = {"post_args": result}

if batch:
# squeeze batch back to sample
result_sample = {}
for k, v in result.items():
if isinstance(v, (torch.Tensor, np.ndarray, list)):
@@ -288,6 +319,12 @@ def set(self, data: pd.DataFrame) -> None:
)
data = data.apply(pre_collect_process, axis=1)

if self._batch_pre_collect_process_func is not None:
pre_collect_process = lambda x: self._df_dict_apply(
x, self._batch_pre_collect_process_func, batch=True
)
data = data.apply(pre_collect_process, axis=1)

data_to_collect = pd.DataFrame(data=None, columns=self._keys_to_collect)
for name, key in self._keys_to_collect.items():
if key not in data.keys():
@@ -311,17 +348,20 @@

for name in data_to_collect.keys():
values = data_to_collect.loc[:, name]
self._collected_data[name].extend(values)
self._collected_data[name].extend(
[v.numpy() if isinstance(v, torch.Tensor) else v for v in values]
)

# extract ids and store it in self._collected_ids
ids = None
for key in self._id_keys:
if key in data.keys():
ids = list(data[key])
break
if self._to_collect_ids:
ids = None
for key in self._id_keys:
if key in data.keys():
ids = list(data[key])
break

if ids is not None:
self._collected_ids.extend(ids)
if ids is not None:
self._collected_ids.extend(ids)

def reset(self) -> None:
"""
@@ -338,8 +378,10 @@ def reset(self) -> None:
self._collected_data = {name: [] for name in self._post_keys_to_collect}
else:
self._collected_data = {"post_args": []}

self._collected_ids = [] # the original collected ids
if self._to_collect_ids:
self._collected_ids = [] # the original collected ids
else:
self._collected_ids = None

self._sampled_ids = None # the required ids - set by sample() method

@@ -354,7 +396,7 @@ def get(self, ids: Optional[Sequence[Hashable]] = None) -> Tuple[Dict[str, Any]]
Get collected data - collected data dictionary and collected ids.
each element in the dictionary will be a list of values from all samples
"""
if ids is None:
if ids is None or self._to_collect_ids is False:
return copy.copy(self._collected_data)
else:
# convert required ids to permutation
@@ -391,6 +433,7 @@ def __init__(
batch_post_collect_process_func: Optional[Callable] = None,
post_keys_to_collect: Optional[Sequence[str]] = None,
external_data_collector: Optional[MetricCollector] = None,
collect_ids: bool = True,
extract_ids: bool = False,
collect_distributed: bool = True,
**kwargs: Any,
@@ -402,6 +445,7 @@
:param batch_post_collect_process_func: Optional callable - see details in MetricCollector.__init__
:param post_keys_to_collect: Optional keys to collect from post_collect_process func results - see details in MetricCollector.__init__
:param external_data_collector: Optional - for cases where space optimization is required, a shared collector can be used by several metrics
:param collect_ids: if False, will not collect ids and will not support permutation of the data (which is used to compute confidence intervals)
:param extract_ids: self._extract_arguments packs all arguments for an underlying function. Set to True to also pack the ids (under the name 'ids')
:param collect_distributed: if True, in multi gpu training, will collect the samples from all gpus - otherwise only rank0 will be reported.
:param kwargs: specify keywords and value arguments you want to collect from the source data.
@@ -430,6 +474,7 @@ def __init__(
batch_post_collect_process_func,
post_keys_to_collect=post_keys_to_collect,
collect_distributed=collect_distributed,
collect_ids=collect_ids,
**self._keys_to_collect,
)
if external_data_collector is None
@@ -602,8 +647,9 @@ class MetricPerBatchDefault(MetricWithCollectorBase):
def __init__(
self,
*,
metric_per_batch_func: Callable,
metric_per_batch_func: Optional[Callable],
result_aggregate_func: Callable,
metric_per_batch_func_pre_collect: Optional[Callable] = None,
**kwargs: Any,
):
"""
@@ -616,7 +662,9 @@ def __init__(
"""

super().__init__(
batch_post_collect_process_func=metric_per_batch_func, **kwargs
batch_post_collect_process_func=metric_per_batch_func,
batch_pre_collect_process_func=metric_per_batch_func_pre_collect,
**kwargs,
)
self._result_aggregate_func = result_aggregate_func

31 changes: 23 additions & 8 deletions fuse/eval/metrics/sequence_gen/metrics_seq_gen_common.py
@@ -16,7 +16,7 @@
Created on June 30, 2021

"""
from typing import Optional, Tuple, List, Union
from typing import Optional, Tuple, List
from functools import partial

import torch
@@ -34,22 +34,27 @@ def __init__(
**kwargs: dict,
) -> None:
super().__init__(
preds=preds,
target=target,
metric_per_batch_func=partial(
_perplexity_update, ignore_index=ignore_index
log_probs="log_probs", # collect log_probs - output of _perplexity_update
Collaborator:

I think it's out of the scope of this PR but still sharing:

I don't like passing arbitrary key-values as **kwargs to the parent class where they have an actual purpose:

From MetricPerBatchDefault description:

        :param kwargs: specify keywords and value arguments you want to collect from the source data.
                can be strings (key names) and/or actual values
                to collect from the results dictionary: add a "results:" prefix to the key name

I would rather see an argument like: key_values_to_collect: dict.

What's your take on that?

Collaborator (Author):

I agree with you.
But it's indeed out of PR scope.

token_num="token_num", # collect token_num - output of _perplexity_update
metric_per_batch_func=None,
metric_per_batch_func_pre_collect=partial(
_perplexity_update,
ignore_index=ignore_index,
preds_key=preds,
targets_key=target,
),
result_aggregate_func=_perplexity_compute,
post_keys_to_collect=["log_probs", "token_num"],
collect_ids=False,
**kwargs,
)
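For what it's worth, a minimal runnable sketch of the reviewer's suggestion above (purely hypothetical; `key_values_to_collect` and `CollectorBaseSketch` do not exist in fuse and are not part of this PR): an explicit mapping keeps the keys to collect separate from the real constructor options, instead of mixing them into **kwargs.

```python
from typing import Any, Dict, Optional


class CollectorBaseSketch:
    """Toy base class: takes an explicit mapping instead of arbitrary **kwargs."""

    def __init__(
        self,
        key_values_to_collect: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ):
        # explicit mapping: name to collect -> source key (or literal value)
        self._keys_to_collect = dict(key_values_to_collect or {})
        # real options stay separate and discoverable
        self._other_options = kwargs


collector = CollectorBaseSketch(
    key_values_to_collect={"log_probs": "log_probs", "token_num": "token_num"},
    collect_ids=False,
)
print(collector._keys_to_collect)  # {'log_probs': 'log_probs', 'token_num': 'token_num'}
print(collector._other_options)    # {'collect_ids': False}
```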


# Copied internal function https://github.com/Lightning-AI/metrics/blob/825d17f32ee0b9a2a8024c89d4a09863d7eb45c3/src/torchmetrics/functional/text/perplexity.py#L68
# copied rather than imported so that internal interface changes do not affect it.
def _perplexity_update(
preds: Union[torch.Tensor, np.ndarray],
target: Union[torch.Tensor, np.ndarray],
batch_dict: dict,
preds_key: str,
targets_key: str,
ignore_index: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute intermediate statistics for Perplexity.
@@ -65,6 +70,8 @@
Log probabilities, summed over all samples
Number of tokens
"""
preds = batch_dict[preds_key]
target = batch_dict[targets_key]
if isinstance(preds, np.ndarray):
preds = torch.tensor(preds)

@@ -96,6 +103,10 @@
mask = torch.ones_like(target, dtype=torch.bool)

preds = preds[:, target].diagonal()[mask]
# avoid numerical issues: float16 preds can underflow to exactly zero, which makes log() infinite
if preds.dtype == torch.float16:
preds = preds.to(torch.float32)
preds = torch.clamp(preds, min=1e-10)
Collaborator:

Why do you clamp after moving from float16 to float32? Shouldn't float32 be more stable? I would guess the other way around - clamp before moving from float32 into float16.

Collaborator:

Plus, I think values in (0, 1e-10) cannot exist in float16, so why clamp after moving to float32?

Collaborator (Author):

I'm getting absolute zeros in float16 precision, which causes log() to be infinite.

total_log_probs = -preds.log().sum()
count = mask.sum()
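A quick standalone check of the behaviour discussed in the thread above (illustrative only, not part of the PR): values below float16's smallest positive number become exactly zero, casting up to float32 cannot recover them, and it is the clamp that keeps log() finite.

```python
import torch

# tiny probabilities underflow to exactly zero in float16
p16 = torch.tensor([1e-8, 0.3], dtype=torch.float16)
print(p16[0].item())  # 0.0 -- smaller than the smallest positive float16

# casting up to float32 does not bring the value back: log(0) is still -inf
p32 = p16.to(torch.float32)
print(p32.log()[0].item())  # -inf

# clamping after the cast keeps log() finite
clamped = torch.clamp(p32, min=1e-10)
print(clamped.log()[0].item())  # about -23.03 (= log(1e-10))
```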

@@ -106,6 +117,10 @@ def _perplexity_compute(
log_probs: List[np.ndarray],
token_num: List[np.ndarray],
) -> float:
# avoid precision loss when accumulating over large epochs
log_probs = [e.astype(np.float64) for e in log_probs]
Collaborator:

numerical stability?

Collaborator (Author):

These can be large numbers with many GPUs, so I took the safe side here.

token_num = [e.astype(np.int64) for e in token_num]

sum_log_probs = sum(log_probs)
num_total = sum(token_num)
return float(np.exp(sum_log_probs / num_total))
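To illustrate why the aggregation above casts to float64 before summing (an editor's illustration of the concern discussed in the thread, not code from the PR): a float32 accumulator silently drops small increments once the running sum is large, while float64 keeps them.

```python
import numpy as np

# float32 cannot represent n + 1 exactly once n reaches 2**24,
# so accumulating per-batch sums in float32 can silently lose contributions
big = np.float32(2**24)
print(big + np.float32(1.0) == big)  # True  -- the added increment is lost
print(np.float64(2**24) + np.float64(1.0) == np.float64(2**24))  # False -- float64 keeps it
```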
2 changes: 1 addition & 1 deletion fuse/eval/tests/test_eval.py
@@ -228,7 +228,7 @@ def test_eval_example_14(self) -> None:

def test_eval_example_seq_gen_0(self) -> None:
results = example_seq_gen_0(seed=1234)
self.assertAlmostEqual(results["metrics.perplexity.org"], 162.87, places=2)
self.assertAlmostEqual(results["metrics.perplexity"], 162.87, places=2)

def test_eval_example_seq_gen_1(self) -> None:
results = example_seq_gen_1(seed=1234)