diff --git a/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts index 663009e520020..328bc8fe52038 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts @@ -87,44 +87,52 @@ const computeMean = (context: ComputeContext, input: TensorView, scale: TensorVi const inputHelper = inputVariable('input', input.dataType, input.dims, components); const scaleHelper = inputVariable('scale', scale.dataType, scale.dims, components); const biasHelper = inputVariable('bias', bias.dataType, bias.dims, components); + const dataType = tensorTypeToWsglStorageType(input.dataType); const WG = 64; // we will store channel scale and channel shift in [2, components] matrix // or in vec2 when components == 1 - const outputType = components === 1 ? 'array>' : `array`; + const outputType = components === 1 ? `vec2<${dataType}>` : `mat2x${components}<${dataType}>`; const setOutputValue = (var1: string, var2: string) => { - return components === 1 - ? `vec2f(${var1}, ${var2})` - : `mat2x${components}(${var1}, ${var2})`; + return `${outputType}(${var1}, ${var2})`; }; const unitsOfWork = n * c / components; const wgSize = Math.ceil(h / WG); + + let divisor = `${dataType}(H)`; + if (input.dataType === DataType.float16 && h > 65504) { + divisor = `f16(${h / 2}) / 2.0h`; + } + const getMeanShaderSource = (shaderHelper: ShaderHelper) => ` const H: u32 = ${h}; const C: u32 = ${c / components}; const imageSize: u32 = ${h * c / components}; ${shaderHelper.declareVariables(inputHelper)} - @group(0) @binding(1) var output : ${outputType}; + @group(0) @binding(1) var output : array<${outputType}>; - ${shaderHelper.mainStart(64)} + ${shaderHelper.mainStart(WG)} let currentImageNumber = global_idx / ${WG} / C; let currentChannelNumber = (global_idx / ${WG}) % C; - let wgId = global_idx % 64; + let wgId = global_idx % ${WG}; let wgOffset = wgId * ${wgSize}; - if (wgOffset > H) { + if (wgOffset >= H) { return; } let wgMax = min(wgOffset + ${wgSize}, H); let offset = currentImageNumber * imageSize + currentChannelNumber; - var sum: ${inputHelper.type.storage} = ${fillVector('f32', components)}; - var squaredSum: ${inputHelper.type.storage} = ${fillVector('f32', components)}; + var sum: ${inputHelper.type.storage} = ${fillVector(dataType, components)}; + var squaredSum: ${inputHelper.type.storage} = ${fillVector(dataType, components)}; for (var i: u32 = wgOffset; i < wgMax; i++) { let value = input[offset + i * C]; sum += value; squaredSum += value * value; } + // we need to divide it here to avoid fp16 overflow + sum = sum / ${divisor}; + squaredSum = squaredSum / ${divisor}; output[global_idx] = ${setOutputValue('sum', 'squaredSum')}; }`; @@ -144,12 +152,12 @@ const computeMean = (context: ComputeContext, input: TensorView, scale: TensorVi const H: u32 = ${h}; const C: u32 = ${c / components}; const imageSize: u32 = ${WG * c / components}; - const epsilon: f32 = ${epsilon}; + const epsilon: ${dataType} = ${epsilon}; - @group(0) @binding(0) var input : ${outputType}; + @group(0) @binding(0) var input : array<${outputType}>; @group(0) @binding(1) var scale : array<${scaleHelper.type.storage}>; @group(0) @binding(2) var bias : array<${biasHelper.type.storage}>; - @group(0) @binding(3) var output : ${outputType}; + @group(0) @binding(3) var output : array<${outputType}>; ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(unitsOfWork)} @@ -157,29 +165,27 @@ const computeMean = (context: ComputeContext, input: TensorView, scale: TensorVi let currentChannelNumber = global_idx % C; let offset = currentImageNumber * imageSize; - var sum: ${inputHelper.type.storage} = ${fillVector('f32', components)}; - var squaredSum: ${inputHelper.type.storage} = ${fillVector('f32', components)}; + var sum: ${inputHelper.type.storage} = ${fillVector(dataType, components)}; + var squaredSum: ${inputHelper.type.storage} = ${fillVector(dataType, components)}; for (var i: u32 = 0; i < ${WG}; i++) { - let value = input[offset + i]; + let value = input[offset + i + currentChannelNumber * ${WG}]; sum += value[0]; squaredSum += value[1]; } - let mean = sum / f32(H); - let invStdDev = 1 / sqrt(squaredSum / f32(H) - mean * mean + epsilon); + let invStdDev = 1 / sqrt(squaredSum - sum * sum + epsilon); let channelScale = invStdDev * scale[currentChannelNumber]; - let channelShift = bias[currentChannelNumber] - mean * channelScale; + let channelShift = bias[currentChannelNumber] - sum * channelScale; output[global_idx] = ${setOutputValue('channelScale', 'channelShift')}; }`; return context.compute( { - name: 'InstanceNormComputeMean', + name: 'InstanceNormComputeChannelScaleShift', inputTypes: [GpuDataType.default, GpuDataType.default, GpuDataType.default], cacheHint: JSON.stringify({ components, n, h, c, epsilon }), outputs: [ {dims: [n, c, 2], dataType: DataType.float, gpuDataType: GpuDataType.default}, - // {dims: [h * c], dataType: DataType.float, gpuDataType: GpuDataType.default}, ], getShaderSource, dispatchGroup: () => ({x: Math.ceil(unitsOfWork / 64 /* workgroup size */)}) @@ -201,7 +207,8 @@ const createInstanceNormNHWCProgramInfo = const inputHelper = inputVariable('input', inputs[0].dataType, inputs[0].dims, components); const outputHelper = outputVariable('output', inputs[0].dataType, outputShape, components); - const scaleType = components === 1 ? 'vec2' : `mat2x${components}f`; + const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); + const scaleType = components === 1 ? `vec2<${dataType}>` : `mat2x${components}<${dataType}>`; // first compute mean const channelScaleShift = computeMean(context, inputs[0], inputs[1], inputs[2], N, H, C, attributes.epsilon);