
Add Continuous Decoding support in GQA #21523

Merged · 28 commits merged into main · Sep 13, 2024
Conversation

@aciddelgado (Contributor) commented on Jul 26, 2024

Description

This PR adds support for continuous decoding with batch_size = 1 inputs. GQA can now take arbitrary-length input, with seqlens_k set to total_sequence_length - 1 and the sequence length of q/k/v set to new_sequence_length.

This change does not affect the default behavior of GQA.

Motivation and Context

Prior to this change, GQA could not accept sequence_length > 1 inputs when past context was present. Supporting this case is essential for continuous decoding, which is one of our current efforts in ORT-GenAI.
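
As an illustration, here is a minimal sketch (not taken from the PR) of how a caller might drive a model containing a GroupQueryAttention node through a multi-token decoding step after this change. The model path, tensor names, and hidden sizes are hypothetical and depend on how the model was exported; past key/value cache inputs are omitted for brevity.

```python
# Minimal sketch, assuming a decoder graph with a GroupQueryAttention node.
# Input names and shapes below are hypothetical and depend on the exported model.
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("decoder_with_gqa.onnx")  # hypothetical model path

new_tokens = 4                       # new_sequence_length for this decoding step
past_len = 12                        # tokens already held in the KV cache
total_seq_len = past_len + new_tokens

feeds = {
    # q/k/v carry only the new tokens: [batch=1, new_sequence_length, hidden]
    "query": np.zeros((1, new_tokens, 1024), dtype=np.float16),
    "key":   np.zeros((1, new_tokens, 256),  dtype=np.float16),
    "value": np.zeros((1, new_tokens, 256),  dtype=np.float16),
    # Per the PR description: seqlens_k = total_sequence_length - 1
    "seqlens_k": np.array([total_seq_len - 1], dtype=np.int32),
    "total_sequence_length": np.array(total_seq_len, dtype=np.int32),
}
outputs = sess.run(None, feeds)
```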

BowenBao added a commit to BowenBao/onnxruntime-genai that referenced this pull request Aug 2, 2024
Results are validated with model-generate.py by using an int4 quantized
model as the original model's assistant. The output sequence is the same,
and increased tokens per second (tps) is observed.

NOTE: Only MHA decoder-only models, batch size 1, CPU, and greedy (select top)
search are supported in this initial version. GQA needs microsoft/onnxruntime#21523
to support seqlen > 1 in the token phase.

* Updated builder.py to produce an MHA graph that supports seqlen > 1
  in the token phase.
* Introduced speculative decoding, currently through a separate Generator
  class. This can potentially be merged with the existing Generator at
  either the API level or the implementation level.
* Extended various components to support speculative search; previously,
  most methods were hardcoded assuming seqlen == 1 in the token phase.
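
For context, the draft-and-verify loop that this commit builds toward looks roughly like the sketch below. This is illustrative only and is not the onnxruntime-genai implementation: draft_model and target_model are hypothetical callables standing in for the int4 assistant and the original model, each returning per-position next-token logits. Step 2 is the seqlen > 1 token-phase call that requires the MHA/GQA changes described above.

```python
# Illustrative sketch of greedy speculative decoding, assuming hypothetical
# draft_model/target_model callables that map a token list to per-position logits.
import numpy as np

def speculative_step(target_model, draft_model, tokens, k=4):
    # 1) The assistant drafts k candidate tokens one at a time (cheap).
    draft = list(tokens)
    for _ in range(k):
        logits = draft_model(draft)               # hypothetical callable
        draft.append(int(np.argmax(logits[-1])))

    # 2) The target model scores the whole drafted chunk in ONE forward pass;
    #    this is the seqlen > 1 token-phase call enabled by the changes above.
    logits = target_model(draft)                  # hypothetical callable

    # 3) Greedily accept drafted tokens while they match the target's choices;
    #    on the first mismatch, take the target's token and stop.
    accepted = list(tokens)
    for i in range(len(tokens), len(draft)):
        best = int(np.argmax(logits[i - 1]))
        accepted.append(best)
        if best != draft[i]:
            break
    return accepted
```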
@aciddelgado marked this pull request as ready for review on September 9, 2024 at 23:23
@aciddelgado changed the title from "Add Interactive Decoding support in GQA" to "Add Continuous Decoding support in GQA" on Sep 10, 2024
@yufenglee (Member) previously approved these changes on Sep 11, 2024 and left a comment:


:shipit:

Review thread on docs/ContribOperators.md (outdated, resolved)
@tianleiwu (Contributor) commented:

Please fix PREfast warnings.

@tianleiwu (Contributor) commented:

Python format failed. Please run lintrunner.

Review thread on docs/ContribOperators.md (outdated, resolved)
@aciddelgado merged commit 7e2c722 into main on Sep 13, 2024
87 checks passed
@aciddelgado deleted the aciddelgado/gqa_interactive branch on September 13, 2024 at 20:21
BowenBao added a commit to BowenBao/onnxruntime-genai that referenced this pull request on Oct 15, 2024 (same commit message as above).