Skip to content

Commit

Permalink
SkipLayerNorm fix
Browse files Browse the repository at this point in the history
  • Loading branch information
dakenf committed Sep 13, 2023
1 parent df4e8d6 commit 56a79cf
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 10 deletions.
20 changes: 11 additions & 9 deletions js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -114,35 +114,37 @@ const createSkipLayerNormProgramInfo =
variables.push(outputVariable('inputSkipBiasSum', inputs[0].dataType, outputShape, components));
}
const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
+ const castToF32 = components === 1 ? 'f32' : `vec${components}f`;
const getShaderSource = (shaderHelper: ShaderHelper) => `
const hiddenSize: u32 = ${hiddenSize};
const hiddenSizeVectorized: u32 = ${hiddenSize / components};
- const epsilon: ${dataType} = ${attributes.epsilon};
+ const epsilon: f32 = ${attributes.epsilon};
${shaderHelper.declareVariables(...variables)}
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize / hiddenSize)}
let offset = global_idx * hiddenSizeVectorized;
- var sum = ${fillVector(dataType, components)};
- var squareSum = ${fillVector(dataType, components)};
+ var sum = ${fillVector('f32', components)};
+ var squareSum = ${fillVector('f32', components)};
for (var i: u32 = 0; i < hiddenSizeVectorized; i++) {
let skipValue = skip[offset + i];
let biasValue = ${hasBiasInput ? 'bias[i]' : '0.0'};
let inputValue = x[offset + i];
let value = inputValue + skipValue + biasValue;
${hasInputSkipBiasSumOutput ? 'inputSkipBiasSum[offset + i] = value;' : ''}
output[offset + i] = value;
- sum += f32(value);
- squareSum += f32(value) * f32(value);
+ let f32Value = ${castToF32}(value);
+ sum += f32Value;
+ squareSum += f32Value * f32Value;
}
- let mean: ${dataType} = ${sumVector('sum', components)} / ${dataType}(hiddenSize);
- let variance: ${dataType} = sqrt(${sumVector('squareSum', components)}
- / ${dataType}(hiddenSize) - mean * mean + epsilon);
+ let mean = ${sumVector('sum', components)} / f32(hiddenSize);
+ let variance = sqrt(${sumVector('squareSum', components)} / f32(hiddenSize) - mean * mean + epsilon);
${hasMeanOutput ? 'meanOutput[global_idx] = mean;' : ''}
${hasInvStdDevOutput ? 'invStdOutput[global_idx] = 1.0 / variance;' : ''}
for (var i: u32 = 0; i < hiddenSizeVectorized; i++) {
- output[offset + i] = (output[offset + i] - mean) / variance * gamma[i] + ${hasBetaInput ? 'beta[i]' : '0.0'};
+ output[offset + i] = (output[offset + i] - ${dataType}(mean)) / ${dataType}(variance) * gamma[i]
+ + ${hasBetaInput ? 'beta[i]' : '0.0'};
}
}`;
const outputs = [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}];
Expand Down
2 changes: 1 addition & 1 deletion js/web/lib/wasm/session-handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ export class OnnxruntimeWebAssemblySessionHandler implements SessionHandler {
// https://github.com/WebAssembly/memory64/pull/39
// eslint-disable-next-line @typescript-eslint/ban-ts-comment
// @ts-ignore
- index: 'u64',
+ index: 'i64',
shared: true,
});
promises.push(streamResponseToBuffer(weightsResponse, weightsMemory.buffer, 0).then(() => weightsMemory.buffer));
Expand Down

0 comments on commit 56a79cf

Please sign in to comment.