Perplexity gpu mem optimization #345
@@ -86,6 +86,7 @@ def __init__(
         batch_post_collect_process_func: Optional[Callable] = None,
         post_keys_to_collect: Optional[List[str]] = None,
         collect_distributed: bool = True,
+        collect_ids: bool = True,
         **keys_to_collect: Dict[str, str],
     ):
         """
@@ -99,6 +100,7 @@ def __init__(
         :param post_keys_to_collect: specify the keys you want to collect after post_collect_process. Required only if post_collect_process_func or batch_post_collect_process_func is specified.
                                      If None, will aggregate the list of values returned by post_collect_process_func.
         :param collect_distributed: if True, in multi-GPU training, will collect the samples from all GPUs - otherwise only rank 0 will be reported.
+        :param collect_ids: if False, will not collect ids and will not support permutation of the data (which is used to compute confidence intervals)
         :param keys_to_collect: specify the keys you want to collect from the source data
         """
         super().__init__()
@@ -111,6 +113,7 @@ def __init__(
         self._post_keys_to_collect = copy.copy(post_keys_to_collect)
         self._id_keys = MetricCollector.DEFAULT_ID_KEYS
         self._collect_distributed = collect_distributed
+        self._to_collect_ids = collect_ids

         # reset
         self.reset()
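Note: a minimal construction sketch of the new flag (the class and its arguments are taken from this diff; the import path and the collected key are assumptions):

    from fuse.eval.metrics.metrics_common import MetricCollector  # assumed path

    # skip id collection; permutation of the data (and therefore CI) is not supported
    collector = MetricCollector(
        collect_distributed=True,
        collect_ids=False,
        log_probs="log_probs",  # **keys_to_collect: result name -> key in the source data
    )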
@@ -181,17 +184,18 @@ def collect(self, batch: Dict) -> None:
             self._collected_data[name].extend(value)

-        # extract ids and store it in self._collected_ids
-        ids = None
-        for key in self._id_keys:
-            if key in batch:
-                ids = batch[key]
-                # collect distributed
-                if dist.is_initialized() and self._collect_distributed:
-                    ids = self.sync_ids(ids)
-                break
-
-        if ids is not None:
-            self._collected_ids.extend(ids)
+        if self._to_collect_ids:
+            ids = None
+            for key in self._id_keys:
+                if key in batch:
+                    ids = batch[key]
+                    # collect distributed
+                    if dist.is_initialized() and self._collect_distributed:
+                        ids = self.sync_ids(ids)
+                    break
+
+            if ids is not None:
+                self._collected_ids.extend(ids)

     @staticmethod
     def sync_tensor_data_and_concat(data: torch.Tensor) -> torch.Tensor:
@@ -239,15 +243,41 @@ def sync_ids(ids: List[Tuple[str, int]]) -> List[Any]:
         return data_list

     @staticmethod
-    def _df_dict_apply(data: pd.Series, func: Callable) -> pd.Series:
-        result = func(NDict(data.to_dict()))
-        return pd.Series(result.flatten())
+    def _df_dict_apply(
+        data: pd.Series, func: Callable, batch: bool = False
+    ) -> pd.Series:
+        if batch:
+            # expand sample to batch
+            data_dict = {}
+            for k, v in data.to_dict().items():
+                if isinstance(v, torch.Tensor):
+                    data_dict[k] = v.unsqueeze(0)
+                elif isinstance(v, np.ndarray):
+                    data_dict[k] = np.expand_dims(v, axis=0)
+                else:
+                    data_dict[k] = [v]
+        else:
+            data_dict = data.to_dict()
+
+        result = func(NDict(data_dict))
+
+        if batch:
+            # squeeze batch back to sample
+            result_sample = {}
+            for k, v in result.items():
+                if isinstance(v, (torch.Tensor, np.ndarray, list)):
+                    result_sample[k] = v[0]
+                else:
+                    result_sample[k] = v
+            result = result_sample
+        return pd.Series(result)

     @staticmethod
     def _df_dict_apply_kwargs(
         data: pd.Series, func: Callable, batch: bool = False, post: bool = False
     ) -> pd.Series:
         if batch:
             # expand sample to batch
             kwargs = {}
             for k, v in data.to_dict().items():
                 if isinstance(v, torch.Tensor):

Review thread:
Reviewer: Not sure I understood the use-case. I'm suggesting to rename …
Author: The metric has some preprocessing func at the batch level.
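Note: a self-contained sketch of the expand/squeeze round trip that _df_dict_apply performs when batch=True (the sample contents are illustrative):

    import numpy as np
    import torch

    sample = {"logits": torch.zeros(3), "mask": np.zeros(3), "label": 1}

    # expand a single sample to a batch of size 1, per value type
    batch = {
        k: v.unsqueeze(0)
        if isinstance(v, torch.Tensor)
        else np.expand_dims(v, axis=0)
        if isinstance(v, np.ndarray)
        else [v]
        for k, v in sample.items()
    }

    # ... a batch-level function would run on `batch` here ...

    # squeeze the batch of size 1 back to a single sample
    restored = {
        k: v[0] if isinstance(v, (torch.Tensor, np.ndarray, list)) else v
        for k, v in batch.items()
    }
    assert restored["label"] == sample["label"]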
@@ -266,6 +296,7 @@ def _df_dict_apply_kwargs(
             result = {"post_args": result}

+        if batch:
             # squeeze batch back to sample
             result_sample = {}
             for k, v in result.items():
                 if isinstance(v, (torch.Tensor, np.ndarray, list)):
@@ -288,6 +319,12 @@ def set(self, data: pd.DataFrame) -> None:
             )
             data = data.apply(pre_collect_process, axis=1)

+        if self._batch_pre_collect_process_func is not None:
+            pre_collect_process = lambda x: self._df_dict_apply(
+                x, self._batch_pre_collect_process_func, batch=True
+            )
+            data = data.apply(pre_collect_process, axis=1)
+
         data_to_collect = pd.DataFrame(data=None, columns=self._keys_to_collect)
         for name, key in self._keys_to_collect.items():
             if key not in data.keys():
@@ -311,17 +348,20 @@ def set(self, data: pd.DataFrame) -> None:

         for name in data_to_collect.keys():
             values = data_to_collect.loc[:, name]
-            self._collected_data[name].extend(values)
+            self._collected_data[name].extend(
+                [v.numpy() if isinstance(v, torch.Tensor) else v for v in values]
+            )

-        # extract ids and store it in self._collected_ids
-        ids = None
-        for key in self._id_keys:
-            if key in data.keys():
-                ids = list(data[key])
-                break
-
-        if ids is not None:
-            self._collected_ids.extend(ids)
+        if self._to_collect_ids:
+            ids = None
+            for key in self._id_keys:
+                if key in data.keys():
+                    ids = list(data[key])
+                    break
+
+            if ids is not None:
+                self._collected_ids.extend(ids)

     def reset(self) -> None:
         """
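Note: the list comprehension above normalizes collected values to numpy so that no torch objects are retained; a small illustration (assuming, as Tensor.numpy() requires, detached CPU tensors):

    import torch

    values = [torch.randn(4), 1.0, "sample_0"]
    collected = [v.numpy() if isinstance(v, torch.Tensor) else v for v in values]
    # collected now holds an np.ndarray, a float and a str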
@@ -338,8 +378,10 @@ def reset(self) -> None:
             self._collected_data = {name: [] for name in self._post_keys_to_collect}
         else:
             self._collected_data = {"post_args": []}

-        self._collected_ids = []  # the original collected ids
+        if self._to_collect_ids:
+            self._collected_ids = []  # the original collected ids
+        else:
+            self._collected_ids = None

         self._sampled_ids = None  # the required ids - set by sample() method
@@ -354,7 +396,7 @@ def get(self, ids: Optional[Sequence[Hashable]] = None) -> Tuple[Dict[str, Any]]:
         Get collected data - the collected data dictionary and the collected ids.
         Each element in the dictionary will be a list of values from all samples.
         """
-        if ids is None:
+        if ids is None or self._to_collect_ids is False:
             return copy.copy(self._collected_data)
         else:
             # convert required ids to permutation
@@ -391,6 +433,7 @@ def __init__(
         batch_post_collect_process_func: Optional[Callable] = None,
         post_keys_to_collect: Optional[Sequence[str]] = None,
         external_data_collector: Optional[MetricCollector] = None,
+        collect_ids: bool = True,
         extract_ids: bool = False,
         collect_distributed: bool = True,
         **kwargs: Any,
@@ -402,6 +445,7 @@ def __init__(
         :param batch_post_collect_process_func: Optional callable - see details in MetricCollector.__init__
         :param post_keys_to_collect: Optional keys to collect from post_collect_process func results - see details in MetricCollector.__init__
         :param external_data_collector: Optional - for the case where space optimization is required, achieved by sharing a single collector between several metrics
+        :param collect_ids: if False, will not collect ids and will not support permutation of the data (which is used to compute confidence intervals)
         :param extract_ids: self._extract_arguments packs all arguments for the underlying function. Set to True to also pack the ids (under the name 'ids')
         :param collect_distributed: if True, in multi-GPU training, will collect the samples from all GPUs - otherwise only rank 0 will be reported.
         :param kwargs: specify the keyword and value arguments you want to collect from the source data.
@@ -430,6 +474,7 @@ def __init__(
                 batch_post_collect_process_func,
                 post_keys_to_collect=post_keys_to_collect,
                 collect_distributed=collect_distributed,
+                collect_ids=collect_ids,
                 **self._keys_to_collect,
             )
             if external_data_collector is None
@@ -602,8 +647,9 @@ class MetricPerBatchDefault(MetricWithCollectorBase):
     def __init__(
         self,
         *,
-        metric_per_batch_func: Callable,
+        metric_per_batch_func: Optional[Callable],
         result_aggregate_func: Callable,
+        metric_per_batch_func_pre_collect: Optional[Callable] = None,
         **kwargs: Any,
     ):
         """
@@ -616,7 +662,9 @@ def __init__(
         """

         super().__init__(
-            batch_post_collect_process_func=metric_per_batch_func, **kwargs
+            batch_post_collect_process_func=metric_per_batch_func,
+            batch_pre_collect_process_func=metric_per_batch_func_pre_collect,
+            **kwargs,
         )
         self._result_aggregate_func = result_aggregate_func
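Note: to make the pre-collect vs. post-collect distinction concrete, a generic, self-contained sketch of the ordering (plain Python, not this repo's API):

    # pre-collect: runs per batch on the raw data, so large inputs can be
    # reduced before anything is stored
    def pre_collect(batch: dict) -> dict:
        return {"stat": float(sum(batch["logits"]))}

    collected = []
    for batch in ({"logits": [0.1, 0.2]}, {"logits": [0.3]}):
        collected.append(pre_collect(batch)["stat"])  # only a scalar is kept

    # aggregation: runs once, at the end, on the collected values
    print(sum(collected))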
--- second file ---
@@ -16,7 +16,7 @@
 Created on June 30, 2021

 """
-from typing import Optional, Tuple, List, Union
+from typing import Optional, Tuple, List
 from functools import partial

 import torch
@@ -34,22 +34,27 @@ def __init__(
         **kwargs: dict,
     ) -> None:
         super().__init__(
-            preds=preds,
-            target=target,
-            metric_per_batch_func=partial(
-                _perplexity_update, ignore_index=ignore_index
+            log_probs="log_probs",  # collect log_probs - output of _perplexity_update
+            token_num="token_num",  # collect token_num - output of _perplexity_update
+            metric_per_batch_func=None,
+            metric_per_batch_func_pre_collect=partial(
+                _perplexity_update,
+                ignore_index=ignore_index,
+                preds_key=preds,
+                targets_key=target,
             ),
             result_aggregate_func=_perplexity_compute,
+            post_keys_to_collect=["log_probs", "token_num"],
+            collect_ids=False,
             **kwargs,
         )


 # Copied internal function https://github.com/Lightning-AI/metrics/blob/825d17f32ee0b9a2a8024c89d4a09863d7eb45c3/src/torchmetrics/functional/text/perplexity.py#L68
 # copied and not imported, to not be affected by internal interface modifications.
 def _perplexity_update(
-    preds: Union[torch.Tensor, np.ndarray],
-    target: Union[torch.Tensor, np.ndarray],
+    batch_dict: dict,
+    preds_key: str,
+    targets_key: str,
     ignore_index: Optional[int] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """Compute intermediate statistics for Perplexity.

Review thread:
Reviewer: I think it's out of the scope of this PR, but still sharing: I don't like passing arbitrary key-values as … I would rather see an argument like … What's your take on that?
Author: I agree with you.
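Note: the effect of this rewiring as a numeric sketch (illustrative tensors, not the repo's API) - each batch is reduced to two scalars before collection, so the full logits are never stored:

    import torch

    collected = {"log_probs": [], "token_num": []}
    for _ in range(3):  # stand-in batches
        probs = torch.rand(4, 7).clamp(min=1e-10)  # stand-in token probabilities
        collected["log_probs"].append(-probs.log().sum().item())  # one scalar per batch
        collected["token_num"].append(probs.numel())

    # epoch-end aggregation, mirroring _perplexity_compute
    perplexity = float(
        torch.exp(torch.tensor(sum(collected["log_probs"]) / sum(collected["token_num"])))
    )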
@@ -65,6 +70,8 @@ def _perplexity_update(
         Log probabilities, summed over all samples
         Number of tokens
     """
+    preds = batch_dict[preds_key]
+    target = batch_dict[targets_key]
     if isinstance(preds, np.ndarray):
         preds = torch.tensor(preds)
@@ -96,6 +103,10 @@ def _perplexity_update(
         mask = torch.ones_like(target, dtype=torch.bool)

     preds = preds[:, target].diagonal()[mask]
+    # avoid underflow: float16 probabilities can flush to zero, making log() infinite
+    if preds.dtype == torch.float16:
+        preds = preds.to(torch.float32)
+    preds = torch.clamp(preds, min=1e-10)
     total_log_probs = -preds.log().sum()
     count = mask.sum()

Review thread:
Reviewer: Why do you clamp after moving from …? Plus, I think that values in …
Author: I'm getting absolute zeros in float16 precision, which causes log() to be infinity.
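Note: a quick demonstration of the failure mode described in the reply (values illustrative):

    import torch

    p = torch.tensor([1e-8], dtype=torch.float16)  # flushes to zero in float16
    print(p)        # tensor([0.], dtype=torch.float16)
    print(p.log())  # tensor([-inf], dtype=torch.float16)

    p = torch.clamp(p.to(torch.float32), min=1e-10)  # the diff's order: cast, then clamp
    print(p.log())  # tensor([-23.0259]) - finite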
@@ -106,6 +117,10 @@ def _perplexity_compute(
     log_probs: List[np.ndarray],
     token_num: List[np.ndarray],
 ) -> float:
+    # avoid overflow on large epochs
+    log_probs = [e.astype(np.float64) for e in log_probs]
+    token_num = [e.astype(np.int64) for e in token_num]
+
     sum_log_probs = sum(log_probs)
     num_total = sum(token_num)
     return float(np.exp(sum_log_probs / num_total))

Review thread:
Reviewer: Numerical stability?
Author: It might be large numbers with many GPUs, so I took the safe side here.
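Note: for reference, the aggregated value is perplexity = exp(total negative log-probability / total token count); a sketch with the promoted dtypes (numbers illustrative):

    import numpy as np

    log_probs = [np.float32(1234.5), np.float32(987.6)]  # per-batch -sum(log p)
    token_num = [np.int32(4096), np.int32(4096)]

    # promote before summing: long epochs on many GPUs can exceed int32 and
    # lose float32 precision
    sum_log_probs = sum(e.astype(np.float64) for e in log_probs)
    num_total = sum(e.astype(np.int64) for e in token_num)
    print(float(np.exp(sum_log_probs / num_total)))  # ~1.31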
Review thread:
Reviewer: Why did you remove that? I see that the ground truth remained the same, but now we don't cover CI() in the examples (i.e. test coverage).
Author: We do cover CI in other examples, I hope. I removed it because the perplexity is now computed at the batch level (more optimized), and I'm not sure it's a metric we would want a CI for. We can add it back when necessary.
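Note: on the CI point - a generic bootstrap sketch (not this repo's CI() implementation) of why per-sample values and ids are needed: the interval comes from re-aggregating the metric over resampled sample ids, which a batch-level scalar cannot support.

    import numpy as np

    rng = np.random.default_rng(0)
    per_sample_scores = rng.random(1000)  # illustrative per-sample metric values
    n = len(per_sample_scores)

    # resample ids with replacement and re-aggregate to get the interval
    boot = [per_sample_scores[rng.integers(0, n, n)].mean() for _ in range(200)]
    lo, hi = np.percentile(boot, [2.5, 97.5])
    print(lo, hi)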