diff --git a/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts index 87cf0712b43fb..97f633c7cf47e 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts @@ -161,13 +161,14 @@ const computeMean = const meanValues = context.compute( { name: 'InstanceNormComputeMean', - inputTypes: [GpuDataType.default], - cacheHint: JSON.stringify({components, n, h, c}), - outputs: [ - {dims: [n, c, WG, 2], dataType: DataType.float, gpuDataType: GpuDataType.default}, - ], + shaderCache: {hint: JSON.stringify({components, n, h, c})}, + getRunData: () => ({ + outputs: [ + {dims: [n, c, WG, 2], dataType: DataType.float}, + ], + dispatchGroup: {x: n * c / components}, + }), getShaderSource: getMeanShaderSource, - dispatchGroup: () => ({x: n * c / components}) }, {inputs: [input], outputs: [-1]})[0]; const getShaderSource = (shaderHelper: ShaderHelper) => ` @@ -206,20 +207,20 @@ const computeMean = return context.compute( { name: 'InstanceNormComputeChannelScaleShift', - inputTypes: [GpuDataType.default, GpuDataType.default, GpuDataType.default], - cacheHint: JSON.stringify({components, n, h, c, epsilon}), - outputs: [ - {dims: [n, c, 2], dataType: DataType.float, gpuDataType: GpuDataType.default}, - ], + shaderCache: {hint: JSON.stringify({components, n, h, c, epsilon})}, + getRunData: () => ({ + outputs: [ + {dims: [n, c, 2], dataType: DataType.float}, + ], + dispatchGroup: {x: Math.ceil(unitsOfWork / 64 /* workgroup size */)}, + }), getShaderSource, - dispatchGroup: () => ({x: Math.ceil(unitsOfWork / 64 /* workgroup size */)}) }, {inputs: [meanValues, scale, bias], outputs: [-1]})[0]; }; const createInstanceNormNHWCProgramInfo = - (context: ComputeContext, metadata: ProgramMetadata, inputs: readonly TensorView[], - attributes: InstanceNormAttributes) => { + (context: ComputeContext, inputs: readonly TensorView[], attributes: InstanceNormAttributes) => { const xShape = inputs[0].dims; const outputShape = xShape; const N = xShape[0]; @@ -255,13 +256,13 @@ const createInstanceNormNHWCProgramInfo = }`; context.compute( { - ...metadata, - inputTypes: [GpuDataType.default, GpuDataType.default], - outputs: [ - {dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}, - ], + name: 'InstanceNormalization', + shaderCache: {hint: `${attributes.cacheKey}`}, + getRunData: () => ({ + outputs: [{dims: outputShape, dataType: inputs[0].dataType}], + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)} + }), getShaderSource, - dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */)}) }, {inputs: [inputs[0], channelScaleShift]}); }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts index 98fd3d1fd54d6..8a9eeecf2c68d 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts @@ -26,50 +26,50 @@ const createLayerNormProgramInfo = const scale = inputs[1]; const bias = inputs[2]; - const outputShape = xShape; - const axis = ShapeUtil.normalizeAxis(attributes.axis, xShape.length); - const normCount = ShapeUtil.sizeToDimension(xShape, axis); - const normSize = ShapeUtil.sizeFromDimension(xShape, axis); - - const scaleSize = ShapeUtil.size(scale.dims); - const biasSize = bias ? ShapeUtil.size(bias.dims) : 0; - if (scaleSize !== normSize || (bias && biasSize !== normSize)) { - throw new Error(`Size of X.shape()[axis:] == ${normSize}. + const outputShape = xShape; + const axis = ShapeUtil.normalizeAxis(attributes.axis, xShape.length); + const normCount = ShapeUtil.sizeToDimension(xShape, axis); + const normSize = ShapeUtil.sizeFromDimension(xShape, axis); + + const scaleSize = ShapeUtil.size(scale.dims); + const biasSize = bias ? ShapeUtil.size(bias.dims) : 0; + if (scaleSize !== normSize || (bias && biasSize !== normSize)) { + throw new Error(`Size of X.shape()[axis:] == ${normSize}. Size of scale and bias (if provided) must match this. Got scale size of ${scaleSize} and bias size of ${biasSize}`); - } - - const meanInvStdDevDim = []; - for (let i = 0; i < xShape.length; ++i) { - if (i < axis) { - meanInvStdDevDim.push(xShape[i]); - } else { - meanInvStdDevDim.push(1); - } - } - - const components = getMaxComponents(normSize); - const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); - const variables = [ - inputVariable('x', inputs[0].dataType, inputs[0].dims, components), - inputVariable('scale', scale.dataType, scale.dims, components), - ]; - if (bias) { - variables.push(inputVariable('bias', bias.dataType, bias.dims, components)); - } - variables.push(outputVariable('output', inputs[0].dataType, outputShape, components)); - - const hasMeanDataOutput = outputCount > 1; - const hasInvStdOutput = outputCount > 2; - - if (hasMeanDataOutput) { - variables.push(outputVariable('meanDataOutput', DataType.float, meanInvStdDevDim)); - } - if (hasInvStdOutput) { - variables.push(outputVariable('invStdOutput', DataType.float, meanInvStdDevDim)); - } - - const getShaderSource = (shaderHelper: ShaderHelper) => ` + } + + const meanInvStdDevDim = []; + for (let i = 0; i < xShape.length; ++i) { + if (i < axis) { + meanInvStdDevDim.push(xShape[i]); + } else { + meanInvStdDevDim.push(1); + } + } + + const components = getMaxComponents(normSize); + const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); + const variables = [ + inputVariable('x', inputs[0].dataType, inputs[0].dims, components), + inputVariable('scale', scale.dataType, scale.dims, components), + ]; + if (bias) { + variables.push(inputVariable('bias', bias.dataType, bias.dims, components)); + } + variables.push(outputVariable('output', inputs[0].dataType, outputShape, components)); + + const hasMeanDataOutput = outputCount > 1; + const hasInvStdOutput = outputCount > 2; + + if (hasMeanDataOutput) { + variables.push(outputVariable('meanDataOutput', DataType.float, meanInvStdDevDim)); + } + if (hasInvStdOutput) { + variables.push(outputVariable('invStdOutput', DataType.float, meanInvStdDevDim)); + } + + const getShaderSource = (shaderHelper: ShaderHelper) => ` const normSize: f32 = ${normSize}; const normSizeVectorized: u32 = ${normSize / components}; const epsilon: f32 = ${attributes.epsilon}; @@ -101,13 +101,13 @@ const createLayerNormProgramInfo = ${hasMeanDataOutput ? 'meanDataOutput[global_idx] = mean' : ''}; ${hasInvStdOutput ? 'invStdOutput[global_idx] = 1 / meanSquare' : ''}; }`; - const outputs = [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}]; - if (hasMeanDataOutput) { - outputs.push({dims: meanInvStdDevDim, dataType: DataType.float, gpuDataType: GpuDataType.default}); - } - if (hasInvStdOutput) { - outputs.push({dims: meanInvStdDevDim, dataType: DataType.float, gpuDataType: GpuDataType.default}); - } + const outputs = [{dims: outputShape, dataType: inputs[0].dataType}]; + if (hasMeanDataOutput) { + outputs.push({dims: meanInvStdDevDim, dataType: DataType.float}); + } + if (hasInvStdOutput) { + outputs.push({dims: meanInvStdDevDim, dataType: DataType.float}); + } return { name: 'LayerNormalization', diff --git a/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts index 9486fa3d0c534..7e500f865c19b 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts @@ -71,44 +71,44 @@ const validateInputs = (inputs: readonly TensorView[]): void => { }; const createSkipLayerNormProgramInfo = - (inputs: readonly TensorView[], attributes: SkipLayerNormAttributes, outputCount: number, - isTraining: boolean): ProgramInfo => { - const inputShape = inputs[0].dims; - const inputSize = ShapeUtil.size(inputShape); - const outputShape = inputShape; - const outputSize = inputSize; - const hiddenSize = inputShape.slice(-1)[0]; - const meanInvStdDevDim = isTraining ? inputShape.slice(0, -1).concat(1) : []; - const hasBetaInput = inputs.length > 3; - const hasBiasInput = inputs.length > 4; - const hasMeanOutput = isTraining && outputCount > 1; - const hasInvStdDevOutput = isTraining && outputCount > 2; - const hasInputSkipBiasSumOutput = outputCount > 3; - - const components = getMaxComponents(hiddenSize); - const variables = [ - inputVariable('x', inputs[0].dataType, inputs[0].dims, components), - inputVariable('skip', inputs[1].dataType, inputs[1].dims, components), - inputVariable('gamma', inputs[2].dataType, inputs[2].dims, components), - ]; - if (hasBetaInput) { - variables.push(inputVariable('beta', inputs[3].dataType, inputs[3].dims, components)); - } - if (hasBiasInput) { - variables.push(inputVariable('bias', inputs[4].dataType, inputs[4].dims, components)); - } - variables.push(outputVariable('output', inputs[0].dataType, outputShape, components)); - if (hasMeanOutput) { - variables.push(outputVariable('meanOutput', DataType.float, meanInvStdDevDim)); - } - if (hasInvStdDevOutput) { - variables.push(outputVariable('invStdOutput', DataType.float, meanInvStdDevDim)); - } - if (hasInputSkipBiasSumOutput) { - variables.push(outputVariable('inputSkipBiasSum', inputs[0].dataType, outputShape, components)); - } - const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); - const getShaderSource = (shaderHelper: ShaderHelper) => ` + (inputs: readonly TensorView[], attributes: SkipLayerNormAttributes, outputCount: number, isTraining: boolean): + ProgramInfo => { + const inputShape = inputs[0].dims; + const inputSize = ShapeUtil.size(inputShape); + const outputShape = inputShape; + const outputSize = inputSize; + const hiddenSize = inputShape.slice(-1)[0]; + const meanInvStdDevDim = isTraining ? inputShape.slice(0, -1).concat(1) : []; + const hasBetaInput = inputs.length > 3; + const hasBiasInput = inputs.length > 4; + const hasMeanOutput = isTraining && outputCount > 1; + const hasInvStdDevOutput = isTraining && outputCount > 2; + const hasInputSkipBiasSumOutput = outputCount > 3; + + const components = getMaxComponents(hiddenSize); + const variables = [ + inputVariable('x', inputs[0].dataType, inputs[0].dims, components), + inputVariable('skip', inputs[1].dataType, inputs[1].dims, components), + inputVariable('gamma', inputs[2].dataType, inputs[2].dims, components), + ]; + if (hasBetaInput) { + variables.push(inputVariable('beta', inputs[3].dataType, inputs[3].dims, components)); + } + if (hasBiasInput) { + variables.push(inputVariable('bias', inputs[4].dataType, inputs[4].dims, components)); + } + variables.push(outputVariable('output', inputs[0].dataType, outputShape, components)); + if (hasMeanOutput) { + variables.push(outputVariable('meanOutput', DataType.float, meanInvStdDevDim)); + } + if (hasInvStdDevOutput) { + variables.push(outputVariable('invStdOutput', DataType.float, meanInvStdDevDim)); + } + if (hasInputSkipBiasSumOutput) { + variables.push(outputVariable('inputSkipBiasSum', inputs[0].dataType, outputShape, components)); + } + const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); + const getShaderSource = (shaderHelper: ShaderHelper) => ` const hiddenSize: f32 = ${hiddenSize}; const hiddenSizeVectorized: u32 = ${hiddenSize / components}; const epsilon: f32 = ${attributes.epsilon}; @@ -140,24 +140,24 @@ const createSkipLayerNormProgramInfo = + ${hasBetaInput ? 'beta[i]' : '0.0'}; } }`; - const outputs = [{dims: outputShape, dataType: inputs[0].dataType}]; - if (outputCount > 1) { - outputs.push({dims: meanInvStdDevDim, dataType: DataType.float}); - } - if (outputCount > 2) { - outputs.push({dims: meanInvStdDevDim, dataType: DataType.float}); - } - if (outputCount > 3) { - outputs.push({dims: inputShape, dataType: inputs[0].dataType}); - } - - return { - name: 'SkipLayerNormalization', - shaderCache: {hint: attributes.cacheKey}, - getShaderSource, - getRunData: () => ({outputs, dispatchGroup: {x: Math.ceil(outputSize / hiddenSize / 64)}}), - }; - }; + const outputs = [{dims: outputShape, dataType: inputs[0].dataType}]; + if (outputCount > 1) { + outputs.push({dims: meanInvStdDevDim, dataType: DataType.float}); + } + if (outputCount > 2) { + outputs.push({dims: meanInvStdDevDim, dataType: DataType.float}); + } + if (outputCount > 3) { + outputs.push({dims: inputShape, dataType: inputs[0].dataType}); + } + + return { + name: 'SkipLayerNormalization', + shaderCache: {hint: attributes.cacheKey}, + getShaderSource, + getRunData: () => ({outputs, dispatchGroup: {x: Math.ceil(outputSize / hiddenSize / 64)}}), + }; + }; export const skipLayerNorm = (context: ComputeContext, attributes: SkipLayerNormAttributes): void => { // TODO: initialize isTraining from ComputeContext