fix mha for the case that present kv is not consumed #21777
Conversation
// since there is no buffer for it.
// We check by requesting the output and if not there we'll adjust context.outputCount
const presentKeyShape = [
  parameters.batchSize,
This shape only works for MHA and GQA.
Attention's output 1 has shape [2, B, N, T, H] instead of [B, N, T, H], since it concatenates present_key and present_value into a single present output.
I think extra code is needed here, something like:
if (attention op) { // can we get the operator name from context? Maybe we can use context.outputCount === 2, since MHA and GQA have 3 outputs when present_key is needed.
  // insert 2 at the beginning of the present shape.
}
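A minimal sketch of that check, assuming a hypothetical isAttentionOp flag and field names such as numHeads, totalSequenceLength, and headSize on parameters (none of these are taken from the actual diff):

// Sketch only: `isAttentionOp`, `numHeads`, `totalSequenceLength` and `headSize`
// are assumed names used for illustration.
const baseShape = [
  parameters.batchSize,
  parameters.numHeads,
  parameters.totalSequenceLength,
  parameters.headSize,
];
// The Attention op stacks present_key and present_value into one `present`
// output, so its shape gains a leading dimension of 2.
const presentShape = isAttentionOp ? [2, ...baseShape] : baseShape;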
We also need to consider another special case for GQA, where past and present share buffers. In that case, the length is the max sequence length (see the sketch below).
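A sketch of how that shared-buffer case might be handled, assuming hypothetical names pastPresentShareBuffer and maxSequenceLength on parameters:

// Sketch only: `pastPresentShareBuffer` and `maxSequenceLength` are assumed names.
// When past and present share one buffer (the GQA case above), the present output
// spans the buffer's full capacity rather than just past + new tokens.
const presentSequenceLength = parameters.pastPresentShareBuffer
  ? parameters.maxSequenceLength
  : parameters.totalSequenceLength;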
With #21782 this one is no longer needed.
No description provided.