diff --git a/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts
index 97f633c7cf47e..3a84844544c96 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts
@@ -7,7 +7,7 @@ import {ShapeUtil} from '../../util';
 import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
 import {ComputeContext, ProgramInfo} from '../types';
 
-import {fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from './common';
+import {fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType} from './common';
 
 export interface InstanceNormAttributes extends AttributeWithCacheKey {
   epsilon: number;
@@ -26,22 +26,25 @@ const createInstanceNormProgramInfo =
       const axis = 2;
       const normCount = ShapeUtil.sizeToDimension(xShape, axis);
       const normSize = ShapeUtil.sizeFromDimension(xShape, axis);
+      const components = getMaxComponents(normSize);
+      const normPackedSize = normSize / components;
       const C = xShape[1];
-      const x = inputVariable('x', inputs[0].dataType, [xShape[0], xShape[1], normSize]);
+      const x = inputVariable('x', inputs[0].dataType, [xShape[0], xShape[1], normPackedSize], components);
       const scale = inputVariable('scale', inputs[1].dataType, inputs[1].dims);
       const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims);
-      const output = outputVariable('output', inputs[0].dataType, [xShape[0], xShape[1], normSize]);
+      const output = outputVariable('output', inputs[0].dataType, [xShape[0], xShape[1], normPackedSize], components);
       const variables = [x, scale, bias, output];
       const dataType = x.type.value;
+      const f32Type = components === 1 ? 'f32' : `vec${components}<f32>`;
       const workgroupSize = 64;
       const getShaderSource = (shaderHelper: ShaderHelper) => `
   const C: u32 = ${C};
   const normSize: u32 = ${normSize};
   const epsilon: f32 = ${attributes.epsilon};
-  var<workgroup> meanShared : ${dataType};
-  var<workgroup> squaredNormShared : ${dataType};
-  var<workgroup> workgroupShared : array<${dataType}, ${workgroupSize}>;
+  var<workgroup> meanShared : f32;
+  var<workgroup> squaredNormShared : f32;
+  var<workgroup> workgroupShared : array<${f32Type}, ${workgroupSize}>;
   const workgroupSize = ${workgroupSize}u;
 
   ${shaderHelper.declareVariables(...variables)}
   ${shaderHelper.mainStart(workgroupSize)}
@@ -51,9 +54,9 @@ const createInstanceNormProgramInfo =
     let localIndex = local_id.x;
 
     // initialize workgroup memory
-    var initial: ${dataType} = 0;
-    for (var h = localIndex; h < normSize; h += workgroupSize) {
-      initial = initial + ${x.get('batch', 'channel', 'h')};
+    var initial = ${f32Type}(0);
+    for (var h = localIndex; h < ${normPackedSize}; h += workgroupSize) {
+      initial = initial + ${f32Type}(${x.get('batch', 'channel', 'h')});
     }
     workgroupShared[localIndex] = initial;
     workgroupBarrier();
@@ -66,14 +69,14 @@ const createInstanceNormProgramInfo =
       workgroupBarrier();
     }
    if (localIndex == 0) {
-      meanShared = workgroupShared[0] / ${dataType}(normSize);
+      meanShared = ${sumVector('workgroupShared[0]', components)} / f32(normSize);
    }
    workgroupBarrier();
 
    // reinitialize workgroup memory.
-    initial = 0;
-    for (var h = localIndex; h < normSize; h += workgroupSize) {
-      let deviation = ${x.get('batch', 'channel', 'h')} - meanShared;
+    initial = ${f32Type}(0);
+    for (var h = localIndex; h < ${normPackedSize}; h += workgroupSize) {
+      let deviation = ${f32Type}(${x.get('batch', 'channel', 'h')}) - ${f32Type}(meanShared);
       initial = initial + deviation * deviation;
    }
    workgroupShared[localIndex] = initial;
@@ -87,15 +90,16 @@ const createInstanceNormProgramInfo =
       workgroupBarrier();
    }
    if (localIndex == 0) {
-      squaredNormShared = workgroupShared[0];
+      squaredNormShared = ${sumVector('workgroupShared[0]', components)};
    }
    workgroupBarrier();
 
-    let invStdDev = 1 / sqrt(squaredNormShared / ${dataType}(normSize) + epsilon);
-    let channelScale = invStdDev * ${scale.getByOffset('channel')};
-    let channelShift = ${bias.getByOffset('channel')} - meanShared * channelScale;
-    for (var h = localIndex; h < normSize; h += workgroupSize) {
-      let value = ${x.get('batch', 'channel', 'h')} * channelScale + channelShift;
+    let invStdDev = 1 / sqrt(squaredNormShared / f32(normSize) + epsilon);
+    let channelScale = invStdDev * f32(${scale.getByOffset('channel')});
+    let channelShift = f32(${bias.getByOffset('channel')}) - meanShared * channelScale;
+    for (var h = localIndex; h < ${normPackedSize}; h += workgroupSize) {
+      let value = ${x.get('batch', 'channel', 'h')} * ${dataType}(${f32Type}(channelScale)) + ${dataType}(${
+          f32Type}(channelShift));
       ${output.set('batch', 'channel', 'h', 'value')};
    }
  }`;
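
For readers unfamiliar with the helpers this patch leans on: `getMaxComponents` (already imported) picks the widest vector width that evenly divides `normSize`, and the newly imported `sumVector` emits a WGSL expression that collapses a `vecN` accumulator back into a scalar, which is how the vectorized workgroup reduction still produces scalar `meanShared` and `squaredNormShared` values. The sketch below only illustrates that contract as it is used in the diff; the actual bodies in `./common.ts` may differ.

```ts
// Illustrative sketch only: the real helpers live in ops/common.ts and may differ in detail.
// They are shown here to make the vectorization in the diff above easier to follow.

// Assumed behaviour: pick the widest component count (4, 2, or 1) that divides `size` evenly.
const getMaxComponents = (size: number): number => {
  if (size % 4 === 0) {
    return 4;
  } else if (size % 2 === 0) {
    return 2;
  }
  return 1;
};

// Assumed behaviour: return a WGSL expression summing the lanes of a vecN value,
// or the value itself when it is already a scalar (components === 1).
const sumVector = (name: string, components: number): string => {
  if (components === 4) {
    return `(${name}.x + ${name}.y + ${name}.z + ${name}.w)`;
  }
  if (components === 2) {
    return `(${name}.x + ${name}.y)`;
  }
  return name;
};

// Example: normSize = 12 packs into vec4 loads (normPackedSize = 3), and the final
// reduction emitted into the shader is a plain lane sum over workgroupShared[0].
const components = getMaxComponents(12);                 // 4
console.log(12 / components);                            // 3
console.log(sumVector('workgroupShared[0]', components));
// -> (workgroupShared[0].x + workgroupShared[0].y + workgroupShared[0].z + workgroupShared[0].w)
```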