From 56a79cf1cdd47e7272098fb50d92602d8bb7a8aa Mon Sep 17 00:00:00 2001
From: Arthur Islamov <arthur@islamov.ai>
Date: Wed, 13 Sep 2023 20:00:24 +0400
Subject: [PATCH] SkipLayerNorm fix

---
 .../wasm/jsep/webgpu/ops/skip-layer-norm.ts   | 20 ++++++++++---------
 js/web/lib/wasm/session-handler.ts            |  2 +-
 2 files changed, 12 insertions(+), 10 deletions(-)

diff --git a/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts
index a69f37b8d0828..2de86729e7660 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts
@@ -114,18 +114,19 @@ const createSkipLayerNormProgramInfo =
         variables.push(outputVariable('inputSkipBiasSum', inputs[0].dataType, outputShape, components));
       }
       const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
+      const castToF32 = components === 1 ? 'f32' : `vec${components}f`;
       const getShaderSource = (shaderHelper: ShaderHelper) => `
       const hiddenSize: u32 = ${hiddenSize};
       const hiddenSizeVectorized: u32 = ${hiddenSize / components};
-      const epsilon: ${dataType} = ${attributes.epsilon};
+      const epsilon: f32 = ${attributes.epsilon};
 
       ${shaderHelper.declareVariables(...variables)}
 
       ${shaderHelper.mainStart()}
         ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize / hiddenSize)}
         let offset = global_idx * hiddenSizeVectorized;
-        var sum = ${fillVector(dataType, components)};
-        var squareSum = ${fillVector(dataType, components)};
+        var sum = ${fillVector('f32', components)};
+        var squareSum = ${fillVector('f32', components)};
         for (var i: u32 = 0; i < hiddenSizeVectorized; i++) {
           let skipValue = skip[offset + i];
           let biasValue = ${hasBiasInput ? 'bias[i]' : '0.0'};
@@ -133,16 +134,17 @@ const createSkipLayerNormProgramInfo =
           let value = inputValue + skipValue + biasValue;
           ${hasInputSkipBiasSumOutput ? 'inputSkipBiasSum[offset + i] = value;' : ''}
           output[offset + i] = value;
-          sum += f32(value);
-          squareSum += f32(value) * f32(value);
+          let f32Value = ${castToF32}(value);
+          sum += f32Value;
+          squareSum += f32Value * f32Value;
         }
-        let mean: ${dataType} = ${sumVector('sum', components)} / ${dataType}(hiddenSize);
-        let variance: ${dataType} = sqrt(${sumVector('squareSum', components)} 
-          / ${dataType}(hiddenSize) - mean * mean + epsilon);
+        let mean = ${sumVector('sum', components)} / f32(hiddenSize);
+        let variance = sqrt(${sumVector('squareSum', components)} / f32(hiddenSize) - mean * mean + epsilon);
         ${hasMeanOutput ? 'meanOutput[global_idx] = mean;' : ''}
         ${hasInvStdDevOutput ? 'invStdOutput[global_idx] = 1.0 / variance;' : ''}
         for (var i: u32 = 0; i < hiddenSizeVectorized; i++) {
-          output[offset + i] = (output[offset + i] - mean) / variance * gamma[i] + ${hasBetaInput ? 'beta[i]' : '0.0'};
+          output[offset + i] = (output[offset + i] - ${dataType}(mean)) / ${dataType}(variance) * gamma[i]
+           + ${hasBetaInput ? 'beta[i]' : '0.0'};
         }
       }`;
       const outputs = [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}];
diff --git a/js/web/lib/wasm/session-handler.ts b/js/web/lib/wasm/session-handler.ts
index b19ac9993fb28..827c7fa4d0664 100644
--- a/js/web/lib/wasm/session-handler.ts
+++ b/js/web/lib/wasm/session-handler.ts
@@ -37,7 +37,7 @@ export class OnnxruntimeWebAssemblySessionHandler implements SessionHandler {
         // https://github.com/WebAssembly/memory64/pull/39
         // eslint-disable-next-line @typescript-eslint/ban-ts-comment
         // @ts-ignore
-        index: 'u64',
+        index: 'i64',
         shared: true,
       });
       promises.push(streamResponseToBuffer(weightsResponse, weightsMemory.buffer, 0).then(() => weightsMemory.buffer));