Skip to content

Commit

Permalink
BiasSplitGelu fix for batch size > 1
Browse files Browse the repository at this point in the history
  • Loading branch information
dakenf committed Aug 31, 2023
1 parent 92aeb75 commit d3cbe9e
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions js/web/lib/wasm/jsep/webgpu/ops/bias-split-gelu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,10 @@ const createBiasSplitGeluProgramInfo = (metadata: ProgramMetadata, inputs: reado
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
let biasIdx = global_idx % halfChannels;
let valueLeft = input[global_idx] + bias[biasIdx];
let valueRight = input[global_idx + halfChannels] + bias[biasIdx + halfChannels];
let batchIndex = global_idx / halfChannels;
let inputOffset = biasIdx + batchIndex * halfChannels * 2;
let valueLeft = input[inputOffset] + bias[biasIdx];
let valueRight = input[inputOffset + halfChannels] + bias[biasIdx + halfChannels];
let geluRight = valueRight * 0.5 * (erf_vf32(valueRight / M_SQRT2) + 1);
${output.setByOffset('global_idx', 'valueLeft * geluRight')}
Expand Down

0 comments on commit d3cbe9e

Please sign in to comment.