fix mha for the case that present kv is not consumed #21777

Closed · 5 commits

34 changes: 25 additions & 9 deletions js/web/lib/wasm/jsep/webgpu/ops/attention.ts
@@ -668,11 +668,27 @@ export const applyAttention = (
parameters: AttentionParameters,
attributes: AttentionAttrs,
) => {
- const pastSequenceLength =
- parameters.kvNumHeads !== undefined || context.outputCount > 1 ? parameters.pastSequenceLength : 0;

+ // context.outputCount comes from KernelOp and is the number of outputs the op has.
+ // If the present outputs are not consumed, we need to make sure the shaders don't
+ // generate them, since there is no buffer for them.
+ // We check by requesting the output; if it is not there, we adjust context.outputCount.
+ const presentKeyShape = [
+ parameters.batchSize,
Review comment from @tianleiwu (Contributor), Aug 16, 2024:

This shape only works for MHA and GQA. For the Attention op, output 1 has shape [2, B, N, T, H] instead of [B, N, T, H], since it concatenates present_key and present_value into a single present output.

I think here need extra code like

if (attention op) { // can we get operator name from context? Maybe we can use context.outputCount === 2 since MHA and GQA has 3 outputs if present_key are needed.
    // insert 2 at the beginning of present shape.
}

Review comment from @tianleiwu (Contributor), Aug 17, 2024:

We also need to consider another special case for GQA, where past and present share the same buffer. In that case, the length is the max sequence length. [A hedged sketch of this case appears after the attention.ts diff.]

+ parameters.kvNumHeads === undefined ? parameters.numHeads : parameters.kvNumHeads,
+ parameters.totalSequenceLength,
+ parameters.headSize,
+ ];
+ const output1 = context.output(1, presentKeyShape);
+ if (output1 === 0) {
+ context.outputCount = 1;
+ }
+ const outputCount = context.outputCount;
+ const pastSequenceLength = parameters.kvNumHeads !== undefined || outputCount > 1 ? parameters.pastSequenceLength : 0;
const totalSequenceLength = pastSequenceLength + parameters.kvSequenceLength;
+ const outputPresent = outputCount > 1;

- const inputsK = parameters.kvNumHeads === undefined && outputCount > 1 && pastKey ? [q, k, pastKey] : [q, k];
+ const inputsK = parameters.kvNumHeads === undefined && outputPresent && pastKey ? [q, k, pastKey] : [q, k];
if (attentionBias) {
inputsK.push(attentionBias);
}
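
A minimal sketch of the adjustment suggested in the Aug 16 review comment above, assuming the op can be distinguished by its output count as the reviewer proposes. The helper name is made up for illustration and none of this is code from the PR:

    // Sketch only, not PR code: the Attention op concatenates present_key and present_value
    // into a single output of shape [2, B, N, T, H], so the probed shape needs a leading 2.
    // Detecting Attention via outputCount === 2 is the reviewer's suggested heuristic.
    const adjustPresentShapeForAttention = (outputCount: number, presentShape: readonly number[]): number[] =>
      outputCount === 2 ? [2, ...presentShape] : [...presentShape];

Under that assumption, the hunk's context.output(1, presentKeyShape) probe would be passed adjustPresentShapeForAttention(context.outputCount, presentKeyShape) instead.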
@@ -683,13 +699,13 @@ export const applyAttention = (
context,
q,
k,
- outputCount > 1 ? pastKey : undefined,
+ outputPresent ? pastKey : undefined,
attentionBias,
parameters,
attributes,
pastSequenceLength,
),
- { inputs: inputsK, outputs: parameters.kvNumHeads === undefined && outputCount > 1 ? [-1, 1] : [-1] },
+ { inputs: inputsK, outputs: parameters.kvNumHeads === undefined && outputPresent ? [-1, 1] : [-1] },
)[0];

// Run Softmax
@@ -698,24 +714,24 @@ export const applyAttention = (
context,
probs,
parameters.batchSize * parameters.numHeads * parameters.sequenceLength,
- totalSequenceLength,
+ parameters.totalSequenceLength,
),
{ inputs: [probs], outputs: [] },
);

// Run AttentionScore
const inputsV =
- parameters.kvNumHeads === undefined && outputCount > 1 && pastValue ? [probs, v, pastValue] : [probs, v];
+ parameters.kvNumHeads === undefined && outputPresent && pastValue ? [probs, v, pastValue] : [probs, v];
context.compute(
createVxAttentionScoreProgramInfo(
context,
probs,
v,
- outputCount > 1 && pastValue ? pastValue : undefined,
+ outputPresent && pastValue ? pastValue : undefined,
parameters,
pastSequenceLength,
),
- { inputs: inputsV, outputs: parameters.kvNumHeads === undefined && outputCount > 1 ? [0, 2] : [0] },
+ { inputs: inputsV, outputs: parameters.kvNumHeads === undefined && outputPresent ? [0, 2] : [0] },
);
};
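
A minimal sketch of the GQA shared-buffer case raised in the Aug 17 review comment, assuming hypothetical pastPresentShareBuffer and maxSequenceLength fields on AttentionParameters; it is illustrative only, not code from the PR:

    // Sketch only, not PR code: when GQA shares the past/present KV buffer, the present output
    // is sized by the buffer's maximum sequence length rather than by totalSequenceLength.
    // The two parameter fields used for that decision are assumptions.
    const presentSequenceLength =
      parameters.kvNumHeads !== undefined && parameters.pastPresentShareBuffer
        ? parameters.maxSequenceLength
        : parameters.totalSequenceLength;
    const presentKeyShape = [
      parameters.batchSize,
      parameters.kvNumHeads === undefined ? parameters.numHeads : parameters.kvNumHeads,
      presentSequenceLength,
      parameters.headSize,
    ];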

2 changes: 1 addition & 1 deletion js/web/lib/wasm/jsep/webgpu/types.ts
@@ -183,7 +183,7 @@ export interface ComputeContext {
/**
* a number of outputs for the node
*/
- readonly outputCount: number;
+ outputCount: number;

compute(program: ProgramInfo, inputsOutputsMapping?: ComputeContextInputsOutputsMapping): TensorView[];
output(index: number, dims: readonly number[]): number;
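
Dropping readonly lets a kernel shrink its effective output count at compute time. A condensed sketch of the usage pattern from the attention.ts hunk above, relying only on what this PR shows, namely that context.output(index, dims) returns 0 when no buffer exists for that output:

    // With outputCount mutable, a kernel can probe an optional output and downgrade itself.
    const output1 = context.output(1, presentKeyShape); // 0 means the output was not consumed
    if (output1 === 0) {
      context.outputCount = 1; // the shaders will then skip generating present key/value
    }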