Perplexity gpu mem optimization #345

Merged
mosheraboh merged 4 commits into master from perplexity_gpu_mem_optimizatio on Mar 20, 2024
Conversation

mosheraboh (Collaborator)

No description provided.

moshiko.raboh added 2 commits March 17, 2024 22:51
@@ -47,11 +46,7 @@ def example_seq_gen_0(seed: int = 1234) -> Dict[str, Any]:
        [
            (
                "perplexity",
                CI(
Collaborator:
Why did you remove that?

I see that the ground truth remained the same, but now we don't cover CI() in the examples (i.e., test coverage).

Collaborator (Author):

We do cover CI in other examples, I hope.
I removed it because perplexity is now computed at the batch level (more optimized), and I'm not sure it's a metric we would want a CI for.
We can bring it back when necessary.
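For context, a minimal sketch (hypothetical names and shapes, not the PR's exact code) of what a batch-level perplexity update looks like: each batch reduces to two scalars, the summed negative log-likelihood and the token count, so the per-sample values that a CI() wrapper would bootstrap over are no longer retained.

```python
import torch

def perplexity_update(probs: torch.Tensor, target: torch.Tensor, ignore_index: int = -100):
    """probs: [batch, seq_len, vocab]; target: [batch, seq_len].
    Returns the batch's summed negative log-likelihood and its token count."""
    mask = target != ignore_index
    safe_target = target.masked_fill(~mask, 0)  # any valid index; excluded by the mask below
    token_probs = probs.gather(dim=-1, index=safe_target.unsqueeze(-1)).squeeze(-1)
    neg_log_likelihood = -torch.log(token_probs[mask]).sum()
    return neg_log_likelihood, mask.sum()

# at epoch end: perplexity = exp(mean negative log-likelihood over all tokens)
# perplexity = torch.exp(total_nll / total_tokens)
```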

target=target,
metric_per_batch_func=partial(
    _perplexity_update, ignore_index=ignore_index
),
log_probs="log_probs",  # collect log_probs - output of _perplexity_update
Collaborator:

I think it's out of the scope of this PR, but still sharing:

I don't like passing arbitrary key-value pairs as **kwargs to the parent class where they have an actual purpose.

From MetricPerBatchDefault description:

        :param kwargs: specify keywords and value arguments you want to collect from the source data.
                can be strings (key names) and/or actual values
                to collect from the results dictionary: add a "results:" prefix to the key name

I would rather see an explicit argument, e.g. key_values_to_collect: dict.

What's your take on that?
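For illustration, a hypothetical sketch of the suggested explicit-argument variant (the names here are assumptions, not the actual fuse API):

```python
from typing import Any, Callable, Dict, Optional

class MetricPerBatchDefault:  # signature sketch only, not the real class
    def __init__(
        self,
        metric_per_batch_func: Callable,
        key_values_to_collect: Optional[Dict[str, Any]] = None,
    ) -> None:
        # e.g. key_values_to_collect={"log_probs": "log_probs"}; a "results:"
        # prefix on a key name would keep the collect-from-results behavior
        self._key_values_to_collect = key_values_to_collect or {}

# call sites would then name the mechanism explicitly instead of using **kwargs:
# MetricPerBatchDefault(metric_per_batch_func=..., key_values_to_collect={"log_probs": "log_probs"})
```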

Collaborator (Author):

I agree with you.
But it's indeed out of this PR's scope.

# avoid overflow
if preds.dtype == torch.float16:
    preds = preds.to(torch.float32)
preds = torch.clamp(preds, min=1e-10)
Collaborator:

Why do you clamp after moving from float16 to float32? Shouldn't float32 be more stable? I would have guessed the other way around: clamp before moving from float32 into float16.

Collaborator:

Plus, I think values in (0, 1e-10) cannot exist in float16, so why clamp after moving to float32?

Collaborator (Author):

I'm getting exact zeros in float16 precision, which causes log() to return -inf.
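To make that concrete, a small runnable illustration of why the order matters: clamping in float16 would do nothing, because the floor 1e-10 itself rounds to 0 in float16.

```python
import torch

# Probabilities below float16's smallest subnormal (~6e-8) round to exactly 0.0,
# and log(0.0) is -inf.
p = torch.tensor([1e-10], dtype=torch.float16)
print(p)                               # tensor([0.], dtype=torch.float16)
print(torch.log(p.to(torch.float32)))  # tensor([-inf])

# Upcast first, then clamp, to restore a finite floor:
p32 = torch.clamp(p.to(torch.float32), min=1e-10)
print(torch.log(p32))                  # tensor([-23.0259])
```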

@@ -106,6 +117,10 @@ def _perplexity_compute(
    log_probs: List[np.ndarray],
    token_num: List[np.ndarray],
) -> float:
    # avoid overflow on large epochs
    log_probs = [e.astype(np.float64) for e in log_probs]
Collaborator:

numerical stability?

Collaborator (Author):

These might be large numbers with many GPUs, so I took the safe side here.
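For illustration, a sketch of how the upcast fits into the epoch-level aggregation (assumed semantics: per-batch sums of negative log-likelihood with matching token counts, not the verbatim fuse code):

```python
import numpy as np
from typing import List

def perplexity_compute(log_probs: List[np.ndarray], token_num: List[np.ndarray]) -> float:
    # accumulate in float64 to stay safe across many batches and GPUs
    total_nll = np.sum([e.astype(np.float64).sum() for e in log_probs])
    total_tokens = np.sum([e.sum() for e in token_num])
    return float(np.exp(total_nll / total_tokens))
```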

def _df_dict_apply(
    data: pd.Series, func: Callable, batch: bool = False
) -> pd.Series:
    if batch:
Collaborator:

Not sure I understood the use case: data is always a single data point (e.g., a series), and func might operate at the batch level? And where func expects a batch, you pass batch=True?

I suggest renaming batch or adding a short description to make it clearer :))

Collaborator (Author):

The metric has some preprocessing func that works at the batch level.
Here we are not running a training loop, so to avoid duplication we run the function on each sample as if it were a batch. To do that, we convert the single sample to a batch, call the function, and then squeeze the result back to a sample.
As I mentioned, these are dark places in fuse that I would refactor if we had some time to spend on it.
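In pseudocode, that sample-as-batch adapter might look like this (a hypothetical sketch, not fuse's actual implementation):

```python
import numpy as np
import pandas as pd
from typing import Callable

def df_dict_apply(data: pd.Series, func: Callable, batch: bool = False) -> pd.Series:
    # when func expects a batch, wrap the single sample as a batch of size 1,
    # apply the function, then squeeze the result back to a single sample
    sample = dict(data)
    if batch:
        batched = {k: np.expand_dims(np.asarray(v), axis=0) for k, v in sample.items()}
        out = func(batched)
        out = {k: np.squeeze(np.asarray(v), axis=0) for k, v in out.items()}
    else:
        out = func(sample)
    return pd.Series(out)
```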

@SagiPolaczek (Collaborator) left a comment:

LGTM!
I added some questions inline :)

Could you please also describe the GPU memory optimization part? Maybe how you found it and what exactly fixes it (at a high level, of course).

mosheraboh merged commit bc21aad into master on Mar 20, 2024
5 checks passed
mosheraboh deleted the perplexity_gpu_mem_optimizatio branch on March 20, 2024 14:06