From 56a79cf1cdd47e7272098fb50d92602d8bb7a8aa Mon Sep 17 00:00:00 2001 From: Arthur Islamov <arthur@islamov.ai> Date: Wed, 13 Sep 2023 20:00:24 +0400 Subject: [PATCH] SkipLayerNorm fix --- .../wasm/jsep/webgpu/ops/skip-layer-norm.ts | 20 ++++++++++--------- js/web/lib/wasm/session-handler.ts | 2 +- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts index a69f37b8d0828..2de86729e7660 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts @@ -114,18 +114,19 @@ const createSkipLayerNormProgramInfo = variables.push(outputVariable('inputSkipBiasSum', inputs[0].dataType, outputShape, components)); } const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); + const castToF32 = components === 1 ? 'f32' : `vec${components}f`; const getShaderSource = (shaderHelper: ShaderHelper) => ` const hiddenSize: u32 = ${hiddenSize}; const hiddenSizeVectorized: u32 = ${hiddenSize / components}; - const epsilon: ${dataType} = ${attributes.epsilon}; + const epsilon: f32 = ${attributes.epsilon}; ${shaderHelper.declareVariables(...variables)} ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize / hiddenSize)} let offset = global_idx * hiddenSizeVectorized; - var sum = ${fillVector(dataType, components)}; - var squareSum = ${fillVector(dataType, components)}; + var sum = ${fillVector('f32', components)}; + var squareSum = ${fillVector('f32', components)}; for (var i: u32 = 0; i < hiddenSizeVectorized; i++) { let skipValue = skip[offset + i]; let biasValue = ${hasBiasInput ? 'bias[i]' : '0.0'}; @@ -133,16 +134,17 @@ const createSkipLayerNormProgramInfo = let value = inputValue + skipValue + biasValue; ${hasInputSkipBiasSumOutput ? 'inputSkipBiasSum[offset + i] = value;' : ''} output[offset + i] = value; - sum += f32(value); - squareSum += f32(value) * f32(value); + let f32Value = ${castToF32}(value); + sum += f32Value; + squareSum += f32Value * f32Value; } - let mean: ${dataType} = ${sumVector('sum', components)} / ${dataType}(hiddenSize); - let variance: ${dataType} = sqrt(${sumVector('squareSum', components)} - / ${dataType}(hiddenSize) - mean * mean + epsilon); + let mean = ${sumVector('sum', components)} / f32(hiddenSize); + let variance = sqrt(${sumVector('squareSum', components)} / f32(hiddenSize) - mean * mean + epsilon); ${hasMeanOutput ? 'meanOutput[global_idx] = mean;' : ''} ${hasInvStdDevOutput ? 'invStdOutput[global_idx] = 1.0 / variance;' : ''} for (var i: u32 = 0; i < hiddenSizeVectorized; i++) { - output[offset + i] = (output[offset + i] - mean) / variance * gamma[i] + ${hasBetaInput ? 'beta[i]' : '0.0'}; + output[offset + i] = (output[offset + i] - ${dataType}(mean)) / ${dataType}(variance) * gamma[i] + + ${hasBetaInput ? 'beta[i]' : '0.0'}; } }`; const outputs = [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}]; diff --git a/js/web/lib/wasm/session-handler.ts b/js/web/lib/wasm/session-handler.ts index b19ac9993fb28..827c7fa4d0664 100644 --- a/js/web/lib/wasm/session-handler.ts +++ b/js/web/lib/wasm/session-handler.ts @@ -37,7 +37,7 @@ export class OnnxruntimeWebAssemblySessionHandler implements SessionHandler { // https://github.com/WebAssembly/memory64/pull/39 // eslint-disable-next-line @typescript-eslint/ban-ts-comment // @ts-ignore - index: 'u64', + index: 'i64', shared: true, }); promises.push(streamResponseToBuffer(weightsResponse, weightsMemory.buffer, 0).then(() => weightsMemory.buffer));