diff --git a/js/web/lib/wasm/jsep/webgpu/ops/batch-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/batch-norm.ts index 2fd90d7d4c9a3..391fdd39d1013 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/batch-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/batch-norm.ts @@ -64,10 +64,10 @@ const createBatchNormInferenceProgramInfo = const useShapesUniforms = enableShapesUniforms(yShape.length) && spatial; const shapeOrRank = useShapesUniforms ? yShape.length : yShape; const x = inputVariable('x', inputs[0].dataType, inputs[0].dims, components); - const scale = inputVariable('scale', inputs[1].dataType, [ShapeUtil.size(inputs[1].dims)], cComponents); - const bias = inputVariable('bias', inputs[2].dataType, [ShapeUtil.size(inputs[2].dims)], cComponents); - const inputMean = inputVariable('inputMean', inputs[3].dataType, [ShapeUtil.size(inputs[3].dims)], cComponents); - const inputVar = inputVariable('inputVar', inputs[4].dataType, [ShapeUtil.size(inputs[4].dims)], cComponents); + const scale = inputVariable('scale', inputs[1].dataType, inputs[1].dims, cComponents); + const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims, cComponents); + const inputMean = inputVariable('inputMean', inputs[3].dataType, inputs[3].dims, cComponents); + const inputVar = inputVariable('inputVar', inputs[4].dataType, inputs[4].dims, cComponents); const y = outputVariable('y', inputs[0].dataType, shapeOrRank, components); const calcCOffset = (): string => { @@ -84,11 +84,11 @@ const createBatchNormInferenceProgramInfo = let cOffset = ${y.indicesToOffset('outputIndices')};`; } else { // update C channel. - cOffset = `var cIndices = ${scale.type.indices}('0'); + cOffset = `var cIndices = ${scale.type.indices}(0); cIndices[0] = outputIndices[${yShape.length - 1}];`; // update D1 x ... x Dn channels. for (let i = 1; i < scale.rank; i++) { - cOffset += `cIndices[${i}] = outputIndices[${i + 1}];`; + cOffset += `cIndices[${i}] = outputIndices[${i}];`; } cOffset += `let cOffset = ${scale.indicesToOffset('cIndices')};`; } diff --git a/js/web/test/data/ops/batch-norm.jsonc b/js/web/test/data/ops/batch-norm.jsonc index 2771a7cc7d1c9..4ea16f290dc8f 100644 --- a/js/web/test/data/ops/batch-norm.jsonc +++ b/js/web/test/data/ops/batch-norm.jsonc @@ -397,5 +397,50 @@ ] } ] + }, + { + "name": "BatchNormalization non-spatial mode - NHWC", + "operator": "BatchNormalization", + "opset": { "domain": "com.ms.internal.nhwc", "version": 7 }, + "attributes": [{ "name": "spatial", "data": 0, "type": "int" }], + "cases": [ + { + "name": "T[3,2,1]", + "inputs": [ + { + "data": [0.2134, 0.32434, 0.5644, 0.3234, 0.4545, 0.3445], + "dims": [3, 2, 1], + "type": "float32" + }, + { + "data": [0.5, 0.6], + "dims": [1, 2], + "type": "float32" + }, + { + "data": [0.2, 0.1], + "dims": [1, 2], + "type": "float32" + }, + { + "data": [0.034, 0.342], + "dims": [1, 2], + "type": "float32" + }, + { + "data": [1, 1], + "dims": [1, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0.2897, 0.089404, 0.4652, 0.08884, 0.41025, 0.1015], + "dims": [3, 2, 1], + "type": "float32" + } + ] + } + ] } ]