From f6e7335e2ee80c66ffd24f94146bd5274c0eb501 Mon Sep 17 00:00:00 2001
From: Arthur Islamov
Date: Wed, 20 Sep 2023 16:52:14 +0400
Subject: [PATCH] FP16 LayerNorm, InstanceNorm, SkipLayerNorm

---
 js/web/lib/wasm/jsep/webgpu/ops/common.ts     |  46 +++++
 .../lib/wasm/jsep/webgpu/ops/instance-norm.ts | 185 +++++++++++++-----
 js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts |  58 +++---
 .../wasm/jsep/webgpu/ops/skip-layer-norm.ts   |  74 +++----
 onnxruntime/contrib_ops/js/skip_layer_norm.cc |   6 +-
 .../providers/js/operators/instance_norm.cc   |  19 +-
 .../core/providers/js/operators/layer_norm.cc |  24 +--
 .../core/providers/js/operators/layer_norm.h  |   1 -
 8 files changed, 276 insertions(+), 137 deletions(-)

diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts
index c054da51a3098..b7e59c60c71b1 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts
@@ -237,6 +237,52 @@ export const tensorTypeToWsglValueType = (type: DataType, components: 1|2|3|4 =
   return typeof mappedType === 'string' ? mappedType : mappedType[1];
 };
 
+/**
+ * A helper function to get the maximum vector size for the specified data length.
+ * @param size
+ */
+export const getMaxComponents = (size: number) => {
+  // we cannot use vec3 type since it has alignment of 16 bytes
+  if (size % 4 === 0) {
+    return 4;
+  } else if (size % 2 === 0) {
+    return 2;
+  }
+
+  return 1;
+};
+
+/**
+ * A helper function that initializes a variable as a scalar or vector, e.g. f32(0) or vec4f(0,0,0,0).
+ * @param dataType
+ * @param components
+ * @param value
+ */
+export const fillVector = (dataType = 'f32', components?: number, value = '0') => {
+  if (!components || components === 1) {
+    return `${dataType}(${value})`;
+  }
+
+  return `vec${components}<${dataType}>(${new Array(components).fill(value).join(',')})`;
+};
+
+/**
+ * A helper function that returns a scalar as-is, or the sum of all components of a vector.
+ * @param name
+ * @param components
+ */
+export const sumVector = (name: string, components: number) => {
+  if (components === 4) {
+    return `(${name}.x + ${name}.y + ${name}.z + ${name}.w)`;
+  } else if (components === 2) {
+    return `(${name}.x + ${name}.y)`;
+  } else if (components === 3) {
+    return `(${name}.x + ${name}.y + ${name}.z)`;
+  }
+
+  return name;
+};
+
 /**
  * A helper function to get a IndicesHelper for a given input or output.
  *
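Note: the three helpers added above are plain string generators for the WGSL that the kernels below splice together. A minimal sketch of how they compose; the `acc` name, the length 6, and the surrounding fragment are illustrative only, not part of this patch:

    import {fillVector, getMaxComponents, sumVector} from './common';

    // 6 % 4 != 0 but 6 % 2 == 0, so the widest usable vector is vec2
    const components = getMaxComponents(6);       // -> 2
    const init = fillVector('f32', components);   // -> 'vec2<f32>(0,0)'
    const reduce = sumVector('acc', components);  // -> '(acc.x + acc.y)'

    // a WGSL fragment assembled from the generated pieces
    const wgsl = `
      var acc = ${init};
      for (var i: u32 = 0u; i < 3u; i++) { acc += input[i]; }
      let total = ${reduce};`;
    console.log(wgsl);
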
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts
index 5a148bda0a9f7..6ed811c38306b 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts
@@ -1,12 +1,13 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
+import {DataType} from '../../../wasm-common';
 import {TensorView} from '../../tensor-view';
 import {ShapeUtil} from '../../util';
 import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
 import {ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata} from '../types';
 
-import {inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from './common';
+import {fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from './common';
 
 export interface InstanceNormAttributes extends AttributeWithCacheKey {
   epsilon: number;
@@ -104,66 +105,160 @@ const createInstanceNormProgramInfo =
       };
     };
 
+const computeMean =
+    (context: ComputeContext, input: TensorView, scale: TensorView, bias: TensorView, n: number, h: number, c: number,
+     epsilon: number) => {
+      const components = getMaxComponents(c);
+      const inputHelper = inputVariable('input', input.dataType, input.dims, components);
+      const scaleHelper = inputVariable('scale', scale.dataType, scale.dims, components);
+      const biasHelper = inputVariable('bias', bias.dataType, bias.dims, components);
+
+      const WG = 64;
+      // each element packs two vectors into a [2, components] matrix (a vec2 when
+      // components == 1): sum and squaredSum in the first pass, then channelScale
+      // and channelShift in the second pass
+      const outputType = components === 1 ? `vec2f` : `mat2x${components}f`;
+      const sumCastType = components === 1 ? `f32` : `vec${components}f`;
+      const setOutputValue = (var1: string, var2: string) => {
+        return `${outputType}(${var1}, ${var2})`;
+      };
+      const unitsOfWork = n * c / components;
+      const wgSize = Math.ceil(h / WG);
+
+      const getMeanShaderSource = (shaderHelper: ShaderHelper) => `
+  const H: u32 = ${h};
+  const C: u32 = ${c / components};
+  const imageSize: u32 = ${h * c / components};
+
+  ${shaderHelper.declareVariables(inputHelper)}
+  @group(0) @binding(1) var<storage, read_write> output : array<${outputType}>;
+
+  ${shaderHelper.mainStart(WG)}
+    let currentImageNumber = global_idx / ${WG} / C;
+    let currentChannelNumber = (global_idx / ${WG}) % C;
+    let wgId = global_idx % ${WG};
+    let wgOffset = wgId * ${wgSize};
+    if (wgOffset >= H) {
+      return;
+    }
+    let wgMax = min(wgOffset + ${wgSize}, H);
+
+    let offset = currentImageNumber * imageSize + currentChannelNumber;
+    var sum = ${fillVector('f32', components)};
+    var squaredSum = ${fillVector('f32', components)};
+    for (var i: u32 = wgOffset; i < wgMax; i++) {
+      let value = ${sumCastType}(input[offset + i * C]);
+      sum += value;
+      squaredSum += value * value;
+    }
+    output[global_idx] = ${setOutputValue('sum', 'squaredSum')};
+  }`;
+
+      const meanValues = context.compute(
+          {
+            name: 'InstanceNormComputeMean',
+            inputTypes: [GpuDataType.default],
+            cacheHint: JSON.stringify({components, n, h, c}),
+            outputs: [
+              {dims: [n, c, WG, 2], dataType: DataType.float, gpuDataType: GpuDataType.default},
+            ],
+            getShaderSource: getMeanShaderSource,
+            dispatchGroup: () => ({x: n * c / components})
+          },
+          {inputs: [input], outputs: [-1]})[0];
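+
+      // second pass: reduce the 64 partial (sum, squaredSum) results of every
+      // (image, channel) pair into channelScale and channelShift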
+      const getShaderSource = (shaderHelper: ShaderHelper) => `
+  const H: u32 = ${h};
+  const C: u32 = ${c / components};
+  const imageSize: u32 = ${WG * c / components};
+  const epsilon: f32 = ${epsilon};
+
+  @group(0) @binding(0) var<storage, read> input : array<${outputType}>;
+  @group(0) @binding(1) var<storage, read> scale : array<${scaleHelper.type.storage}>;
+  @group(0) @binding(2) var<storage, read> bias : array<${biasHelper.type.storage}>;
+  @group(0) @binding(3) var<storage, read_write> output : array<${outputType}>;
+
+  ${shaderHelper.mainStart()}
+    ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(unitsOfWork)}
+    let currentImageNumber = global_idx / C;
+    let currentChannelNumber = global_idx % C;
+
+    let offset = currentImageNumber * imageSize;
+    var sum = ${fillVector('f32', components)};
+    var squaredSum = ${fillVector('f32', components)};
+    for (var i: u32 = 0; i < ${WG}; i++) {
+      let value = input[offset + i + currentChannelNumber * ${WG}];
+      sum += value[0];
+      squaredSum += value[1];
+    }
+    sum = sum / f32(H);
+    squaredSum = squaredSum / f32(H);
+    let invStdDev = 1 / sqrt(squaredSum - sum * sum + epsilon);
+    let channelScale = invStdDev * ${sumCastType}(scale[currentChannelNumber]);
+    let channelShift = ${sumCastType}(bias[currentChannelNumber]) - sum * channelScale;
+
+    output[global_idx] = ${setOutputValue('channelScale', 'channelShift')};
+  }`;
+
+      return context.compute(
+          {
+            name: 'InstanceNormComputeChannelScaleShift',
+            inputTypes: [GpuDataType.default, GpuDataType.default, GpuDataType.default],
+            cacheHint: JSON.stringify({components, n, h, c, epsilon}),
+            outputs: [
+              {dims: [n, c, 2], dataType: DataType.float, gpuDataType: GpuDataType.default},
+            ],
+            getShaderSource,
+            dispatchGroup: () => ({x: Math.ceil(unitsOfWork / 64 /* workgroup size */)})
+          },
+          {inputs: [meanValues, scale, bias], outputs: [-1]})[0];
+    };
+
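+// The NHWC variant below splits the work into computeMean (two reduction passes)
+// followed by a single elementwise fma pass.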
 const createInstanceNormNHWCProgramInfo =
-    (metadata: ProgramMetadata, inputs: readonly TensorView[], attributes: InstanceNormAttributes): ProgramInfo => {
+    (context: ComputeContext, metadata: ProgramMetadata, inputs: readonly TensorView[],
+     attributes: InstanceNormAttributes) => {
       const xShape = inputs[0].dims;
       const outputShape = xShape;
-      const outputSize = ShapeUtil.size(outputShape);
       const N = xShape[0];
       const C = xShape[xShape.length - 1];
       const H = ShapeUtil.sizeFromDimension(xShape, 1) / C;
+      const components = getMaxComponents(C);
+      const outputSize = ShapeUtil.size(outputShape) / components;
+      const inputHelper = inputVariable('input', inputs[0].dataType, inputs[0].dims, components);
+      const outputHelper = outputVariable('output', inputs[0].dataType, outputShape, components);
       const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
+      const scaleType = components === 1 ? `vec2f` : `mat2x${components}f`;
+      const scaleCastType = components === 1 ? dataType : `vec${components}<${dataType}>`;
+      // first fold mean and variance into a per-channel scale and shift (two reduction passes)
+      const channelScaleShift = computeMean(context, inputs[0], inputs[1], inputs[2], N, H, C, attributes.epsilon);
 
-      const normCount = C * N;
       const getShaderSource = (shaderHelper: ShaderHelper) => `
-  const N: u32 = ${N};
   const H: u32 = ${H};
-  const C: u32 = ${C};
-  const normSizeTyped: ${dataType} = ${H};
-  const imageSize: u32 = ${H * C};
-  const epsilon: f32 = ${attributes.epsilon};
+  const C: u32 = ${C / components};
 
-  @group(0) @binding(0) var<storage, read> x : array<${dataType}>;
-  @group(0) @binding(1) var<storage, read> scale : array<${dataType}>;
-  @group(0) @binding(2) var<storage, read> bias : array<${dataType}>;
-  @group(0) @binding(3) var<storage, read_write> output : array<${dataType}>;
+  @group(0) @binding(0) var<storage, read> input : array<${inputHelper.type.storage}>;
+  @group(0) @binding(1) var<storage, read> scaleInput : array<${scaleType}>;
+  @group(0) @binding(2) var<storage, read_write> output : array<${outputHelper.type.storage}>;
 
   ${shaderHelper.mainStart()}
-    let currentImageNumber = global_idx / C;
+    let currentImageNumber = global_idx / (C * H);
     let currentChannelNumber = global_idx % C;
 
-    // offset is channel num * N
-    let offset = currentImageNumber * imageSize;
-    if (offset >= ${outputSize}) { return; }
-    var mean: ${dataType} = 0;
-
-    for (var i: u32 = 0u; i < H; i++) {
-      mean = mean + x[offset + i * C + currentChannelNumber];
-    }
-    mean = mean / normSizeTyped;
-
-    var squaredNorm: ${dataType} = 0;
-    for (var i: u32 = 0u; i < H; i++) {
-      let deviation: f32 = x[offset + i * C + currentChannelNumber] - mean;
-      squaredNorm = squaredNorm + deviation * deviation;
-    }
-    let invStdDev = 1 / sqrt(squaredNorm / normSizeTyped + epsilon);
-    let channelScale = invStdDev * scale[currentChannelNumber];
-    let channelShift = bias[currentChannelNumber] - mean * channelScale;
-    for (var i: u32 = 0u; i < H; i++) {
-      let currentOffset = offset + i * C + currentChannelNumber;
-      output[currentOffset] = x[currentOffset] * channelScale + channelShift;
-    }
+    let scaleOffset = currentImageNumber * C + currentChannelNumber;
+    let scale = scaleInput[scaleOffset];
+    output[global_idx] = fma(input[global_idx], ${scaleCastType}(scale[0]), ${scaleCastType}(scale[1]));
   }`;
-      return {
-        ...metadata,
-        outputs: [
-          {dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default},
-        ],
-        getShaderSource,
-        dispatchGroup: () => ({x: Math.ceil(normCount / 64 /* workgroup size */)})
-      };
+      context.compute(
+          {
+            ...metadata,
+            inputTypes: [GpuDataType.default, GpuDataType.default],
+            outputs: [
+              {dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default},
+            ],
+            getShaderSource,
+            dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */)})
+          },
+          {inputs: [inputs[0], channelScaleShift]});
     };
 
 export const parseInstanceNormAttributes = (attributes: InstanceNormAttributes): InstanceNormAttributes =>
@@ -177,7 +272,7 @@ export const instanceNorm = (context: ComputeContext, attributes: InstanceNormAt
     };
 
     if (attributes.format === 'NHWC') {
-      context.compute(createInstanceNormNHWCProgramInfo(metadata, context.inputs, attributes));
+      createInstanceNormNHWCProgramInfo(context, metadata, context.inputs, attributes);
     } else {
       context.compute(createInstanceNormProgramInfo(metadata, context.inputs, attributes));
     }
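Note: the NHWC path replaces the old one-thread-per-channel loop with three dispatches: a partial reduction (64 workgroup partials per image/channel pair), a second reduction that folds mean and variance into a per-channel scale and shift, and an elementwise fma. The algebra behind the fold, checked in plain TypeScript; all values are illustrative:

    // Illustrative check of the scale/shift folding used by the shaders above.
    const x = [1, 2, 3, 4];                   // one channel of one image, H = 4
    const scale = 0.5, bias = 0.1, epsilon = 1e-5;

    const mean = x.reduce((a, b) => a + b) / x.length;            // E[x]
    const sqMean = x.reduce((a, b) => a + b * b, 0) / x.length;   // E[x^2]
    const invStdDev = 1 / Math.sqrt(sqMean - mean * mean + epsilon);

    // fold normalization into one fused multiply-add per element:
    const channelScale = invStdDev * scale;
    const channelShift = bias - mean * channelScale;
    const viaFma = x.map((v) => v * channelScale + channelShift);

    // same result as the textbook formula (x - mean) * invStdDev * scale + bias
    const direct = x.map((v) => (v - mean) * invStdDev * scale + bias);
    console.log(viaFma, direct);
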
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts
index d6a79e9460c3f..f3dda8a701038 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts
@@ -1,13 +1,12 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
-import {DataType} from '../../../wasm-common';
 import {TensorView} from '../../tensor-view';
 import {ShapeUtil} from '../../util';
 import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
 import {ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata} from '../types';
 
-import {ShaderHelper, tensorTypeToWsglStorageType} from './common';
+import {fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType} from './common';
 
 export interface LayerNormAttributes extends AttributeWithCacheKey {
   axis: number;
@@ -18,10 +17,6 @@ const validateInputs = (inputs: readonly TensorView[]): void => {
   if (!inputs || inputs.length < 2) {
     throw new Error('layerNorm requires at least 2 inputs.');
   }
-
-  if (inputs[0].dataType !== DataType.float || inputs[1].dataType !== DataType.float) {
-    throw new Error('inputs should be float type');
-  }
 };
 
 const createLayerNormProgramInfo =
@@ -32,7 +27,6 @@ const createLayerNormProgramInfo =
       const bias = inputs[2];
 
       const outputShape = xShape;
-      const outputSize = ShapeUtil.size(outputShape);
       const axis = ShapeUtil.normalizeAxis(attributes.axis, xShape.length);
       const normCount = ShapeUtil.sizeToDimension(xShape, axis);
       const normSize = ShapeUtil.sizeFromDimension(xShape, axis);
@@ -55,40 +49,44 @@ const createLayerNormProgramInfo =
       }
 
       const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
+      const components = getMaxComponents(normSize);
+      const variables = [
+        inputVariable('x', inputs[0].dataType, inputs[0].dims, components),
+        inputVariable('scale', scale.dataType, scale.dims, components),
+      ];
+      if (bias) {
+        variables.push(inputVariable('bias', bias.dataType, bias.dims, components));
+      }
+      variables.push(outputVariable('output', inputs[0].dataType, outputShape, components));
 
       const hasMeanDataOutput = outputCount > 1;
       const hasInvStdOutput = outputCount > 2;
-      let bindingIndex = 0;
+
+      if (hasMeanDataOutput) {
+        variables.push(outputVariable('meanDataOutput', inputs[0].dataType, meanInvStdDevDim));
+      }
+      if (hasInvStdOutput) {
+        variables.push(outputVariable('invStdOutput', inputs[0].dataType, meanInvStdDevDim));
+      }
+
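+      // accumulate in vectors of `components` lanes; sumVector() collapses the
+      // partials into scalars after the loop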
       const getShaderSource = (shaderHelper: ShaderHelper) => `
-  const normSize: u32 = ${normSize};
+  const normSize: u32 = ${normSize / components};
   const normSizeTyped: ${dataType} = ${normSize};
-  const epsilon: f32 = ${attributes.epsilon};
-
-  @group(0) @binding(${bindingIndex++}) var<storage, read> x : array<${dataType}>;
-  @group(0) @binding(${bindingIndex++}) var<storage, read> scale : array<${dataType}>;
-  ${bias ? `@group(0) @binding(${bindingIndex++}) var<storage, read> bias : array<${dataType}>;` : ''}
-  @group(0) @binding(${bindingIndex++}) var<storage, read_write> output : array<${dataType}>;
-  ${
-          hasMeanDataOutput ?
-              `@group(0) @binding(${bindingIndex++}) var<storage, read_write> meanDataOutput : array<${dataType}>` :
-              ''};
-  ${
-          hasInvStdOutput ?
-              `@group(0) @binding(${bindingIndex++}) var<storage, read_write> invStdOutput : array<${dataType}>` :
-              ''};
+  const epsilon: ${dataType} = ${attributes.epsilon};
 
+  ${shaderHelper.declareVariables(...variables)}
   ${shaderHelper.mainStart()}
+    ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(normCount)}
     let offset = global_idx * normSize;
-    if (offset >= ${outputSize}) { return; }
-    var mean: ${dataType} = 0;
-    var meanSquare: ${dataType} = 0;
+    var meanVector = ${fillVector(dataType, components)};
+    var meanSquareVector = ${fillVector(dataType, components)};
 
     for (var h: u32 = 0u; h < normSize; h++) {
-      mean = mean + x[h + offset];
-      meanSquare = meanSquare + x[h + offset] * x[h + offset];
+      meanVector += x[h + offset];
+      meanSquareVector += x[h + offset] * x[h + offset];
     }
-    mean = mean / normSizeTyped;
-    meanSquare = sqrt(meanSquare / normSizeTyped - mean * mean + epsilon);
+    let mean = ${sumVector('meanVector', components)} / normSizeTyped;
+    let meanSquare = sqrt(${sumVector('meanSquareVector', components)} / normSizeTyped - mean * mean + epsilon);
 
     for (var j: u32 = 0; j < normSize; j++) {
       output[j + offset] = (x[j + offset] - mean) / meanSquare * scale[j] ${bias ? '+ bias[j]' : ''};
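Note: LayerNorm now keeps partial sums in a `components`-wide vector and reduces once after the loop instead of iterating scalar by scalar. A TypeScript stand-in for the generated WGSL; `lanes` plays the role of `components`, and the input length is assumed to be a multiple of it:

    // TypeScript stand-in for the vectorized reduction in the shader above.
    function meanAndStdDev(x: Float32Array, lanes: number, epsilon = 1e-5) {
      const sum = new Float32Array(lanes);      // meanVector
      const sqSum = new Float32Array(lanes);    // meanSquareVector
      for (let i = 0; i < x.length; i += lanes) {
        for (let l = 0; l < lanes; l++) {       // one vec<lanes> '+=' in WGSL
          sum[l] += x[i + l];
          sqSum[l] += x[i + l] * x[i + l];
        }
      }
      // sumVector(...) collapses each vector into a scalar
      const mean = sum.reduce((a, b) => a + b) / x.length;
      const sqMean = sqSum.reduce((a, b) => a + b) / x.length;
      return {mean, stdDev: Math.sqrt(sqMean - mean * mean + epsilon)};
    }
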
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts
index 7bfdd73b8af18..ec91b9979898b 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts
@@ -7,7 +7,7 @@ import {ShapeUtil} from '../../util';
 import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
 import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types';
 
-import {ShaderHelper, tensorTypeToWsglStorageType} from './common';
+import {fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType} from './common';
 
 export interface SkipLayerNormAttributes extends AttributeWithCacheKey {
   epsilon: number;
@@ -18,9 +18,6 @@ const validateInputs = (inputs: readonly TensorView[]): void => {
     throw new Error('layerNorm requires at least 3 inputs.');
   }
 
-  if (inputs[0].dataType !== DataType.float || inputs[1].dataType !== DataType.float) {
-    throw new Error('inputs should be float type');
-  }
   const input: TensorView = inputs[0];
   const skip: TensorView = inputs[1];
   const gamma: TensorView = inputs[2];
@@ -84,55 +81,64 @@ const createSkipLayerNormProgramInfo =
       const meanInvStdDevDim = isTraining ? inputShape.slice(0, -1).concat(1) : [];
       const hasBetaInput = inputs.length > 3;
       const hasBiasInput = inputs.length > 4;
-      const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
       const hasMeanOutput = isTraining && outputCount > 1;
       const hasInvStdDevOutput = isTraining && outputCount > 2;
       const hasInputSkipBiasSumOutput = outputCount > 3;
-      let bindingNumber = 0;
+
+      const components = getMaxComponents(hiddenSize);
+      const variables = [
+        inputVariable('x', inputs[0].dataType, inputs[0].dims, components),
+        inputVariable('skip', inputs[1].dataType, inputs[1].dims, components),
+        inputVariable('gamma', inputs[2].dataType, inputs[2].dims, components),
+      ];
+      if (hasBetaInput) {
+        variables.push(inputVariable('beta', inputs[3].dataType, inputs[3].dims, components));
+      }
+      if (hasBiasInput) {
+        variables.push(inputVariable('bias', inputs[4].dataType, inputs[4].dims, components));
+      }
+      variables.push(outputVariable('output', inputs[0].dataType, outputShape, components));
+      if (hasMeanOutput) {
+        variables.push(outputVariable('meanOutput', DataType.float, meanInvStdDevDim));
+      }
+      if (hasInvStdDevOutput) {
+        variables.push(outputVariable('invStdOutput', DataType.float, meanInvStdDevDim));
+      }
+      if (hasInputSkipBiasSumOutput) {
+        variables.push(outputVariable('inputSkipBiasSum', inputs[0].dataType, outputShape, components));
+      }
+      const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
+      const castToF32 = components === 1 ? 'f32' : `vec${components}f`;
       const getShaderSource = (shaderHelper: ShaderHelper) => `
       const hiddenSize: u32 = ${hiddenSize};
+      const hiddenSizeVectorized: u32 = ${hiddenSize / components};
       const epsilon: f32 = ${attributes.epsilon};
 
-      @group(0) @binding(${bindingNumber++}) var<storage, read> x : array<${dataType}>;
-      @group(0) @binding(${bindingNumber++}) var<storage, read> skip : array<${dataType}>;
-      @group(0) @binding(${bindingNumber++}) var<storage, read> gamma : array<${dataType}>;
-      ${hasBetaInput ? `@group(0) @binding(${bindingNumber++}) var<storage, read> beta : array<${dataType}>;` : ''}
-      ${hasBiasInput ? `@group(0) @binding(${bindingNumber++}) var<storage, read> bias : array<${dataType}>;` : ''}
-      @group(0) @binding(${bindingNumber++}) var<storage, read_write> output : array<${dataType}>;
-      ${
-          hasMeanOutput ?
-              `@group(0) @binding(${bindingNumber++}) var<storage, read_write> meanOutput : array<${dataType}>;` :
-              ''}
-      ${
-          hasInvStdDevOutput ?
-              `@group(0) @binding(${bindingNumber++}) var<storage, read_write> invStdOutput : array<${dataType}>;` :
-              ''}
-      ${
-          hasInputSkipBiasSumOutput ?
-              `@group(0) @binding(${bindingNumber++}) var<storage, read_write> inputSkipBiasSum : array<${dataType}>;` :
-              ''}
+      ${shaderHelper.declareVariables(...variables)}
 
       ${shaderHelper.mainStart()}
         ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize / hiddenSize)}
-        let offset = global_idx * hiddenSize;
-        var sum: f32 = 0.0;
-        var squareSum: f32 = 0.0;
-        for (var i: u32 = 0; i < hiddenSize; i++) {
+        let offset = global_idx * hiddenSizeVectorized;
+        var sum = ${fillVector('f32', components)};
+        var squareSum = ${fillVector('f32', components)};
+        for (var i: u32 = 0; i < hiddenSizeVectorized; i++) {
          let skipValue = skip[offset + i];
          let biasValue = ${hasBiasInput ? 'bias[i]' : '0.0'};
          let inputValue = x[offset + i];
          let value = inputValue + skipValue + biasValue;
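+         // optionally expose x + skip + bias, then accumulate the running sums in f32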
         ${hasInputSkipBiasSumOutput ? 'inputSkipBiasSum[offset + i] = value;' : ''}
         output[offset + i] = value;
-        sum += value;
-        squareSum += value * value;
+        let f32Value = ${castToF32}(value);
+        sum += f32Value;
+        squareSum += f32Value * f32Value;
       }
-      let mean: f32 = sum / f32(hiddenSize);
-      let variance: f32 = sqrt(squareSum / f32(hiddenSize) - mean * mean + epsilon);
+      let mean = ${sumVector('sum', components)} / f32(hiddenSize);
+      let variance = sqrt(${sumVector('squareSum', components)} / f32(hiddenSize) - mean * mean + epsilon);
       ${hasMeanOutput ? 'meanOutput[global_idx] = mean;' : ''}
       ${hasInvStdDevOutput ? 'invStdOutput[global_idx] = 1.0 / variance;' : ''}
-      for (var i: u32 = 0; i < hiddenSize; i++) {
-        output[offset + i] = (output[offset + i] - mean) / variance * gamma[i] + ${hasBetaInput ? 'beta[i]' : '0.0'};
+      for (var i: u32 = 0; i < hiddenSizeVectorized; i++) {
+        output[offset + i] = (output[offset + i] - ${dataType}(mean)) / ${dataType}(variance) * gamma[i] +
+            ${hasBetaInput ? 'beta[i]' : '0.0'};
       }
     }`;
 
     const outputs = [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}];
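Note: SkipLayerNorm fuses x + skip + bias with the normalization, and for fp16 tensors it widens the accumulators to f32 (castToF32) so rounding error does not build up across the hidden dimension; only the final mean and deviation are narrowed back to the storage type. Reference semantics in scalar TypeScript, illustrative only:

    // output = normalize(x + skip + bias) * gamma + beta, one row at a time.
    // 'variance' below actually holds the standard deviation, matching the shader.
    function skipLayerNormRow(
        x: number[], skip: number[], gamma: number[], beta?: number[], bias?: number[],
        epsilon = 1e-5): number[] {
      const value = x.map((v, i) => v + skip[i] + (bias ? bias[i] : 0));
      let sum = 0, squareSum = 0;              // the f32 accumulators of the shader
      for (const v of value) {
        sum += v;
        squareSum += v * v;
      }
      const mean = sum / value.length;
      const variance = Math.sqrt(squareSum / value.length - mean * mean + epsilon);
      return value.map((v, i) => (v - mean) / variance * gamma[i] + (beta ? beta[i] : 0));
    }
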
diff --git a/onnxruntime/contrib_ops/js/skip_layer_norm.cc b/onnxruntime/contrib_ops/js/skip_layer_norm.cc
index ee315f9b31e3b..f949326e1dc95 100644
--- a/onnxruntime/contrib_ops/js/skip_layer_norm.cc
+++ b/onnxruntime/contrib_ops/js/skip_layer_norm.cc
@@ -7,14 +7,16 @@
 namespace onnxruntime {
 namespace contrib {
 namespace js {
 
+using onnxruntime::js::JsepSupportedFloatTypes;
+
 ONNX_OPERATOR_KERNEL_EX(
     SkipLayerNormalization,
     kMSDomain,
     1,
     kJsExecutionProvider,
     (*KernelDefBuilder::Create())
-        .TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
-        .TypeConstraint("U", DataTypeImpl::GetTensorType<float>()),
+        .TypeConstraint("T", JsepSupportedFloatTypes())
+        .TypeConstraint("U", JsepSupportedFloatTypes()),
     SkipLayerNorm);
 
 }  // namespace js
diff --git a/onnxruntime/core/providers/js/operators/instance_norm.cc b/onnxruntime/core/providers/js/operators/instance_norm.cc
index 9d674766a866d..b8e67a69b24d3 100644
--- a/onnxruntime/core/providers/js/operators/instance_norm.cc
+++ b/onnxruntime/core/providers/js/operators/instance_norm.cc
@@ -6,18 +6,17 @@
 namespace onnxruntime {
 namespace js {
 
-#define INSTANCE_NORM_KERNEL(op_name, domain, data_type, since_version, is_channels_last)          \
-  ONNX_OPERATOR_TYPED_KERNEL_EX(                                                                   \
-      op_name,                                                                                     \
-      domain,                                                                                      \
-      since_version,                                                                               \
-      data_type,                                                                                   \
-      kJsExecutionProvider,                                                                        \
-      (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<data_type>()), \
+#define INSTANCE_NORM_KERNEL(op_name, domain, since_version, is_channels_last)      \
+  ONNX_OPERATOR_KERNEL_EX(                                                          \
+      op_name,                                                                      \
+      domain,                                                                       \
+      since_version,                                                                \
+      kJsExecutionProvider,                                                         \
+      (*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()), \
       InstanceNorm<is_channels_last>);
 
-INSTANCE_NORM_KERNEL(InstanceNormalization, kOnnxDomain, float, 6, false)
-INSTANCE_NORM_KERNEL(InstanceNormalization, kMSInternalNHWCDomain, float, 6, true)
+INSTANCE_NORM_KERNEL(InstanceNormalization, kOnnxDomain, 6, false)
+INSTANCE_NORM_KERNEL(InstanceNormalization, kMSInternalNHWCDomain, 6, true)
 
 }  // namespace js
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/js/operators/layer_norm.cc b/onnxruntime/core/providers/js/operators/layer_norm.cc
index 46baedf5ac7af..9ba379ed09247 100644
--- a/onnxruntime/core/providers/js/operators/layer_norm.cc
+++ b/onnxruntime/core/providers/js/operators/layer_norm.cc
@@ -8,21 +8,15 @@
 namespace onnxruntime {
 namespace js {
 
-#define REGISTER_KERNEL_TYPED(T)                                \
-  ONNX_OPERATOR_TYPED_KERNEL_EX(                                \
-      LayerNormalization,                                       \
-      kOnnxDomain,                                              \
-      17,                                                       \
-      T,                                                        \
-      kJsExecutionProvider,                                     \
-      (*KernelDefBuilder::Create())                             \
-          .TypeConstraint("T", DataTypeImpl::GetTensorType<T>())  \
-          .TypeConstraint("U", DataTypeImpl::GetTensorType<T>()), \
-      LayerNorm<T>);
-
-REGISTER_KERNEL_TYPED(float)
-// REGISTER_KERNEL_TYPED(double)
-// REGISTER_KERNEL_TYPED(MLFloat16)
+ONNX_OPERATOR_KERNEL_EX(
+    LayerNormalization,
+    kOnnxDomain,
+    17,
+    kJsExecutionProvider,
+    (*KernelDefBuilder::Create())
+        .TypeConstraint("T", JsepSupportedFloatTypes())
+        .TypeConstraint("U", JsepSupportedFloatTypes()),
+    LayerNorm);
 
 }  // namespace js
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/js/operators/layer_norm.h b/onnxruntime/core/providers/js/operators/layer_norm.h
index 040fb256ff6e2..791329f3e880d 100644
--- a/onnxruntime/core/providers/js/operators/layer_norm.h
+++ b/onnxruntime/core/providers/js/operators/layer_norm.h
@@ -8,7 +8,6 @@
 namespace onnxruntime {
 namespace js {
 
-template <typename T>
 class LayerNorm : public JsKernel {
  public:
   LayerNorm(const OpKernelInfo& info) : JsKernel(info) {
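Note: the registration changes are what actually switch fp16 on. Each kernel moves from a float-only ONNX_OPERATOR_TYPED_KERNEL_EX registration to a single untyped registration whose "T"/"U" constraints use JsepSupportedFloatTypes(), the JSEP helper that admits both float32 and float16 tensors. Dropping the template parameter from LayerNorm (and the data_type parameter from INSTANCE_NORM_KERNEL) works because JSEP kernels defer execution to the WebGPU shaders, where the tensor's dtype is runtime information rather than a compile-time type.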