Fix for random sampler recompilations for incomplete batches #663
base: habana_main
Conversation
@@ -1228,10 +1228,18 @@ def prepare_input_tensors(
        batch_size_padded = self.bucketing_ctx.get_padded_batch_size(
            real_batch_size, is_prompt)
        batch_size_padding = batch_size_padded - real_batch_size
        if all([
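The hunk is truncated here, so as a self-contained illustration of the idea (this is a sketch, not the actual PR code: `SamplingParams` below is a stand-in dataclass, and `temperature == 0` meaning greedy decoding follows the discussion in this thread):

```python
from dataclasses import dataclass
from typing import List

@dataclass
class SamplingParams:
    temperature: float = 1.0  # 0.0 => greedy decoding, >0 => random sampling

def padding_temperature(batch: List[SamplingParams]) -> float:
    """Pick the temperature for dummy padding sequences (hedged sketch).

    If every real sequence in the batch uses the random sampler
    (temperature > 0), pad with random-sampler dummies so the padded
    batch keeps a single sampler type and avoids a sampler
    recompilation. Otherwise fall back to greedy (temperature 0)
    dummies, the previous default.
    """
    if all(p.temperature > 0 for p in batch):
        return 1.0  # random-sampler dummy
    return 0.0      # greedy dummy

# An all-random batch gets random dummies; anything else gets greedy ones.
assert padding_temperature([SamplingParams(0.7), SamplingParams(1.0)]) == 1.0
assert padding_temperature([SamplingParams(0.0), SamplingParams(1.0)]) == 0.0
```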
Does it have to be 'all'?
I mean, wouldn't a single sample with random sampling be sufficient?
We can only be sure that changing the sampler type for padded sequences will prevent (rather than cause) sampler recompilations if all the sequences in a batch use the same type of sampler. For example, a batch before padding could contain 1 sequence with a greedy sampler and 2 with a random sampler. Say the closest warmed-up batch size is 4: we can add 1 random-sampler sequence or 1 greedy-sampler sequence, giving either 2 greedy/2 random or 1 greedy/3 random groups. If a bucket with batch size 2 was warmed up, then I believe the sampler with 2 greedy samplings or 2 random samplings will also be warmed up, which prevents the recompilation.
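To make the grouping argument concrete, a small sketch of the scenario above (the group sizes are illustrative, not taken from the code):

```python
# 1 greedy + 2 random sequences, padded to the nearest warmed-up batch size.
greedy, random_ = 1, 2
padded_batch_size = 4
dummies = padded_batch_size - (greedy + random_)

# The single dummy can join either group, so the sampler sees one of:
pad_with_greedy = (greedy + dummies, random_)  # (2, 2)
pad_with_random = (greedy, random_ + dummies)  # (1, 3)

# If size-2 sampler shapes were already warmed up (e.g. via a
# batch-size-2 bucket), only the (2, 2) split reuses warm shapes;
# (1, 3) may trigger a sampler recompilation.
print(pad_with_greedy, pad_with_random)
```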
I'd say it's pretty situational. If we flip the situation you described, i.e. 1x random sampling + 2x greedy, then it might be better to add another random sample, as 2x random sampling is more likely to be already warmed up than 3x greedy. Anyway, handling mixed batches is a huge PITA, as this might accidentally create a recompilation because sampler shapes are not padded, afaik.
Ideally we should warm up the sampler separately, with a similar yet independent bucketing alongside the current batch-size bucketing, since sampling params are aggregated before running. That is out of scope for this PR, I'm afraid.
For now I'm thinking about optimizing the most common case, which is greedy. What if we flipped the logic like this:
"if there's at least one sample with temperature=0, set temperature=0 for all dummy samples"? 'any' can be faster than 'all', as it doesn't need to traverse all samples. This means that for all-temperature>0 batches we'll behave exactly the same as your original code, but in the optimistic scenario, which is all greedy, we can at least reduce the impact of the check. If the batch is mixed then, well... we might as well flip a coin ;)
This PR changes the sampler used by dummy (padding) sequences to random when all other sequences in the batch are using the random sampler, preventing sampler recompilations.