Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pipeline: Add support to eval micro bs configuration #4859

Merged
merged 2 commits into from
Jan 3, 2024

Conversation

nelyahu
Copy link
Contributor

@nelyahu nelyahu commented Dec 21, 2023

When running evaluation the general memory consumption is reduced. Mainly due to absence of gradients, and hanging FWD activations. It allows to increase the micro-bs and improve the evaluation performance. This commits add the option to pass num_micro_batches to eval_batch(), as the current assumption is that same micro-bs and global-bs is used, so same number micro batches will take place.
This commit also modifies _scale_loss_by_gas in runtime/engine.py to consider number of eval micro batches for loss scaling instead of training gas.

When running evaluation the general memory consumption is reduced.
Mainly due to absence of gradients, and hanging FWD activations.
It allows to increase the micro-bs and improve the evaluation performance.
This commits add the option to pass num_micro_batches to eval_batch(), as
the current assumption is that same micro-bs and global-bs is used, so
same number micro batches will take place.
This commit also modifies _scale_loss_by_gas in runtime/engine.py to
consider number of eval micro batches for loss scaling instead of training
gas.
Copy link
Contributor

@ShadenSmith ShadenSmith left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you!!

@loadams loadams added this pull request to the merge queue Jan 3, 2024
@github-merge-queue github-merge-queue bot removed this pull request from the merge queue due to failed status checks Jan 3, 2024
@loadams loadams added this pull request to the merge queue Jan 3, 2024
Merged via the queue into microsoft:master with commit ac84cf3 Jan 3, 2024
14 checks passed
@nelyahu nelyahu deleted the pp_eval_micro_bs branch February 4, 2024 11:58
mauryaavinash95 pushed a commit to mauryaavinash95/DeepSpeed that referenced this pull request Feb 17, 2024
When running evaluation the general memory consumption is reduced.
Mainly due to absence of gradients, and hanging FWD activations. It
allows to increase the micro-bs and improve the evaluation performance.
This commits add the option to pass num_micro_batches to eval_batch(),
as the current assumption is that same micro-bs and global-bs is used,
so same number micro batches will take place.
This commit also modifies _scale_loss_by_gas in runtime/engine.py to
consider number of eval micro batches for loss scaling instead of
training gas.

Co-authored-by: Logan Adams <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants