
Commit

fix build break
fs-eire committed Nov 16, 2023
1 parent 955d1a8 commit 2735eaf
Showing 2 changed files with 8 additions and 8 deletions.
12 changes: 6 additions & 6 deletions js/web/lib/wasm/jsep/webgpu/ops/attention.ts
@@ -317,7 +317,7 @@ export const computeInPlaceSoftmax = (context: ComputeContext, input: TensorView
 };
 
 const computeAttentionProbs =
-    (context: ComputeContext, q: TensorView, key: TensorView, bias: TensorView|undefined,
+    (context: ComputeContext, q: TensorView, key: TensorView, _bias: TensorView|undefined,
      parameters: AttentionParameters, attributes: AttentionAttrs) => {
       const probsShape = [
         parameters.batchSize, parameters.numHeads, parameters.sequenceLength,
@@ -443,7 +443,7 @@ const computeVxAttentionScore =
   const K: u32 = ${params.totalSequenceLength}u;
   const numHeads: u32 = ${params.numHeads}u;
   const TILE_SIZE = ${TILE_SIZE}u;
 
   var<workgroup> tileQ: array<${probsHelper.type.storage}, ${TILE_SIZE * TILE_SIZE}>;
   var<workgroup> tileK: array<${probsHelper.type.storage}, ${TILE_SIZE * TILE_SIZE}>;
@@ -482,7 +482,7 @@ const computeVxAttentionScore =
       let currentBatchHeadNumber = workgroup_id.z % ${params.numHeads};
       let headOffset = (batchIdx * M * ${params.numHeads} + currentBatchHeadNumber) * ${params.vHeadSize};
       if (m < M && n < N) {
-        let outputIdx = batchIdx * ${params.sequenceLength * params.vHiddenSize} + m * ${params.vHiddenSize}
+        let outputIdx = batchIdx * ${params.sequenceLength * params.vHiddenSize} + m * ${params.vHiddenSize}
           + currentBatchHeadNumber * ${params.vHeadSize} + n;
         output[outputIdx] = value;
       }
@@ -502,8 +502,8 @@ const computeVxAttentionScore =
 };
 
 export const applyAttention =
-    (context: ComputeContext, q: TensorView, k: TensorView, v: TensorView, maskIndex: TensorView|undefined,
-     past: TensorView|undefined, pastKey: TensorView|undefined, pastValue: TensorView|undefined,
+    (context: ComputeContext, q: TensorView, k: TensorView, v: TensorView, _maskIndex: TensorView|undefined,
+     _past: TensorView|undefined, _pastKey: TensorView|undefined, _pastValue: TensorView|undefined,
      relativePositionBias: TensorView|undefined, parameters: AttentionParameters, attributes: AttentionAttrs) => {
       const probs = computeAttentionProbs(context, q, k, relativePositionBias, parameters, attributes);

@@ -538,7 +538,7 @@ const prepare = (context: ComputeContext, parameters: AttentionParameters) => {
   const numHeads: u32 = ${parameters.numHeads};
   const ldb = ${parameters.hiddenSize + parameters.hiddenSize + parameters.vHiddenSize}u;
   const TILE_SIZE = ${TILE_SIZE}u;
 
   var<workgroup> tileInput: array<${dataType}, ${TILE_SIZE * TILE_SIZE}>;
   var<workgroup> tileWeightQ: array<${dataType}, ${TILE_SIZE * TILE_SIZE}>;
   var<workgroup> tileWeightK: array<${dataType}, ${TILE_SIZE * TILE_SIZE}>;
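The attention.ts half of the fix only renames parameters that the function bodies never read (bias, maskIndex, past, pastKey, pastValue) to underscore-prefixed names. A minimal sketch of why that unbreaks the build, assuming the project compiles with TypeScript's noUnusedParameters option or an equivalent lint rule; the functions below are illustrative, not from the repository:

// With "noUnusedParameters": true in tsconfig.json, an unused parameter is a
// compile error unless its name starts with an underscore.
const scaleBroken = (q: Float32Array, bias: Float32Array|undefined): Float32Array =>
    q;  // error TS6133: 'bias' is declared but its value is never read.
const scaleFixed = (q: Float32Array, _bias: Float32Array|undefined): Float32Array =>
    q;  // OK: the leading underscore opts the parameter out of the check.

The rename leaves the call shape of computeAttentionProbs and applyAttention intact, so no callers need to change.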
4 changes: 2 additions & 2 deletions js/web/lib/wasm/jsep/webgpu/ops/multi-head-attentiion.ts
@@ -280,7 +280,7 @@ const maybeTransposeToBNSHAndAddBias =
         reshapedInput = input.reshape([batchSize, sequenceLength, numHeads, headSize]);
       }
       return context.compute(
-          createTransposeProgramInfo(input.dataType, reshapedInput.dims.length, weightTransposeAttribute.perm),
+          createTransposeProgramInfo(reshapedInput, weightTransposeAttribute.perm),
           {inputs: [reshapedInput], outputs: [-1]})[0];
     } else {
       if (sequenceLength === 1) {
@@ -290,7 +290,7 @@ const maybeTransposeToBNSHAndAddBias =
       addBiasTranspose(context, input, bias, batchSize, sequenceLength, numHeads * headSize, biasOffset!);
       reshapedInput = reshapedInput.reshape([batchSize, sequenceLength, numHeads, headSize]);
       return context.compute(
-          createTransposeProgramInfo(input.dataType, reshapedInput.dims.length, weightTransposeAttribute.perm),
+          createTransposeProgramInfo(reshapedInput, weightTransposeAttribute.perm),
           {inputs: [reshapedInput], outputs: [-1]})[0];
     }
   }
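Both call sites in multi-head-attentiion.ts switch from passing (dataType, rank, perm) to passing the reshaped TensorView plus perm, which is what broke once the transpose helper changed shape. A hypothetical sketch of that signature change, with simplified types; the real declaration lives in the transpose op and is assumed here rather than shown in this commit:

// Simplified stand-in for the TensorView fields the helper needs.
interface TensorViewLike {
  readonly dataType: number;
  readonly dims: readonly number[];
}

// Assumed old shape: callers supplied the data type and rank separately.
//   createTransposeProgramInfo(dataType: number, rank: number, perm: number[])
// Assumed new shape: the helper reads both from the tensor view itself.
const createTransposeProgramInfo = (input: TensorViewLike, perm: readonly number[]) => ({
  name: 'Transpose',
  dataType: input.dataType,  // no longer a separate argument
  rank: input.dims.length,   // derived from the view, so it cannot disagree
  perm,
});

Taking the view itself removes the chance of a caller passing a rank that does not match the tensor being transposed, which is presumably why the helper changed upstream and why these stale call sites stopped compiling.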
