Fix for random sampler recompilations for incomplete batches #663

Open
mfylcek wants to merge 4 commits into habana_main

Conversation


@mfylcek commented Dec 30, 2024

Changes the sampler used by dummy sequences to random if all other sequences are using it. Prevents sampler recompilations.
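A minimal sketch of the idea (hedged; the helper and variable names are hypothetical, not the PR's exact code, though vllm.SamplingParams and the temperature convention are real):

    # Hedged sketch, not the PR's exact code: choose sampling params for the
    # dummy (padding) sequences based on what the real sequences in the batch
    # use. In vLLM, temperature == 0 means greedy decoding; temperature > 0
    # means random sampling.
    from vllm import SamplingParams

    def dummy_sampling_params(real_seq_group_metadata_list):
        # If every real sequence samples randomly, make the dummy sequences
        # random too, so the padded batch stays on a single sampler path and
        # no new sampler graph has to be compiled.
        if all(m.sampling_params.temperature > 0
               for m in real_seq_group_metadata_list):
            return SamplingParams(temperature=1.0)
        # Otherwise keep the previous behaviour: greedy dummy sequences.
        return SamplingParams(temperature=0.0)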

@mfylcek changed the title Sampler-aware batch_size padding → Fix for sampler recompilations when using random sampler with batch_size padding Dec 30, 2024
@mfylcek changed the title Fix for sampler recompilations when using random sampler with batch_size padding → Fix for random sampler recompilations for incomplete batches Dec 31, 2024
@mfylcek marked this pull request as ready for review December 31, 2024 13:36
@@ -1228,10 +1228,18 @@ def prepare_input_tensors(
        batch_size_padded = self.bucketing_ctx.get_padded_batch_size(
            real_batch_size, is_prompt)
        batch_size_padding = batch_size_padded - real_batch_size
        if all([
@madamczykhabana commented Jan 10, 2025

Does it have to be 'all'? I mean, wouldn't a single sample with random sampling be sufficient?

@mfylcek (Author)

We can only be sure that changing the sampler type for padded sequences will prevent (and not cause) sampler recompilations if all the sequences in a batch use the same type of sampler. For example, before batch-size padding a batch could contain 1 sequence using the greedy sampler and 2 using the random sampler. Say the closest warmed-up batch size is 4: we can add 1 sequence with the greedy sampler or 1 with the random sampler, giving groups of 2 greedy/2 random or 1 greedy/3 random. If a bucket with batch size 2 was warmed up, then I believe the sampler with 2 greedy samplings or 2 random samplings will also have been warmed up, and that prevents the recompilation.
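A toy rendering of that arithmetic (illustration only; the warmed-up group sizes are assumed, not taken from the PR):

    # Illustration only: which dummy-sampler choice keeps every sampler group
    # at an already warmed-up size? Assumed warmed-up group sizes: 2 and 4.
    from collections import Counter

    real = ["greedy", "random", "random"]   # batch of 3 before padding
    warmed_up = {2, 4}                      # assumed warmed-up group sizes

    for dummy in ("greedy", "random"):
        groups = Counter(real + [dummy])    # sampler groups after padding to 4
        ok = all(size in warmed_up for size in groups.values())
        print(dummy, dict(groups), "no recompile" if ok else "may recompile")
    # -> greedy {'greedy': 2, 'random': 2} no recompile
    # -> random {'greedy': 1, 'random': 3} may recompile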

@madamczykhabana

I'd say it's pretty situational. If we flip the situation you described, i.e. 1x sampling + 2x greedy, then it might be better to add another sampling sequence, since 2x sampling is more likely to be already warmed up than 3x greedy. Anyway, handling mixed batches is a huge PITA, as this might accidentally create a recompilation because sampler shapes are not padded, AFAIK.

Ideally we should warm up the sampler separately, with bucketing similar to, yet independent of, the current batch-size bucketing, since sampling params are aggregated before running. That's out of scope for this PR, I'm afraid.

For now I'm thinking about optimizing the most common case, which is greedy. What if we flipped the logic like this: "if there's at least one sample with temperature=0, set temperature=0 for all dummy samples"? 'any' can be faster than 'all' as it doesn't need to traverse all samples. This means that for all-temperature>0 batches we'll behave exactly the same as your original code, but in the optimistic scenario, which is all greedy, we can at least reduce the impact of the check. If the batch is mixed then, well... we might as well flip a coin ;)
