From d4b3b72de871d44d37dde903ced30aec5d493eb3 Mon Sep 17 00:00:00 2001 From: Arthur Islamov Date: Fri, 29 Sep 2023 19:05:55 +0400 Subject: [PATCH] Fixes for SkipLayerNorm and LayerNorm --- js/web/lib/wasm/jsep/webgpu/ops/common.ts | 19 ++++++++- js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts | 41 ++++++++++--------- .../wasm/jsep/webgpu/ops/skip-layer-norm.ts | 9 ++-- 3 files changed, 44 insertions(+), 25 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts index b7e59c60c71b1..22eac96a86b70 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts @@ -263,7 +263,24 @@ export const fillVector = (dataType = 'f32', components?: number, value = '0') = return `${dataType}(${value})`; } - return `vec${components}<${dataType}>(${new Array(components).fill(value).join(',')})`; + return `vec${components}<${dataType}>(${value})`; +}; + +/** + * A helper function that casts value or vector to f32 + * @param dataType + * @param components + * @param value + */ +export const castToF32 = (dataType: string, components: number, value: string) => { + if (dataType === 'f32') { + return value; + } + if (components === 1) { + return `f32(${value})`; + } + + return `vec${components}f(${value})`; }; /** diff --git a/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts index f3dda8a701038..40a92f9e0fd69 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts @@ -1,12 +1,13 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata} from '../types'; -import {fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType} from './common'; +import {castToF32, fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType,} from './common'; export interface LayerNormAttributes extends AttributeWithCacheKey { axis: number; @@ -48,8 +49,9 @@ const createLayerNormProgramInfo = } } + // TODO: for some reason it does not work correctly with fp16 + const components = inputs[0].dataType !== DataType.float16 ? getMaxComponents(normSize) : 1; const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); - const components = getMaxComponents(normSize); const variables = [ inputVariable('x', inputs[0].dataType, inputs[0].dims, components), inputVariable('scale', scale.dataType, scale.dims, components), @@ -63,33 +65,38 @@ const createLayerNormProgramInfo = const hasInvStdOutput = outputCount > 2; if (hasMeanDataOutput) { - variables.push(outputVariable('meanDataOutput', inputs[0].dataType, meanInvStdDevDim)); + variables.push(outputVariable('meanDataOutput', DataType.float, meanInvStdDevDim)); } if (hasInvStdOutput) { - variables.push(outputVariable('invStdOutput', inputs[0].dataType, meanInvStdDevDim)); + variables.push(outputVariable('invStdOutput', DataType.float, meanInvStdDevDim)); } const getShaderSource = (shaderHelper: ShaderHelper) => ` const normSize: u32 = ${normSize / components}; - const normSizeTyped: ${dataType} = ${normSize}; - const epsilon: ${dataType} = ${attributes.epsilon}; + const epsilon: f32 = ${attributes.epsilon}; ${shaderHelper.declareVariables(...variables)} ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(normCount)} let offset = global_idx * normSize; - var meanVector = ${fillVector(dataType, components)}; - var meanSquareVector = ${fillVector(dataType, components)}; + var meanVector = ${fillVector('f32', components)}; + var meanSquareVector = ${fillVector('f32', components)}; for (var h: u32 = 0u; h < normSize; h++) { - meanVector += x[h + offset]; - meanSquareVector += x[h + offset] * x[h + offset]; + let value = ${castToF32(dataType, components, 'x[h + offset]')}; + meanVector += value; + meanSquareVector += value * value; } - let mean = ${sumVector('meanVector', components)} / normSizeTyped; - let meanSquare = sqrt(${sumVector('meanSquareVector', components)} / normSizeTyped - mean * mean + epsilon); + let mean = ${sumVector('meanVector', components)} / f32(normSize); + let meanSquare = sqrt(${sumVector('meanSquareVector', components)} + / f32(normSize) - mean * mean + epsilon); for (var j: u32 = 0; j < normSize; j++) { - output[j + offset] = (x[j + offset] - mean) / meanSquare * scale[j] ${bias ? '+ bias[j]' : ''}; + let f32input = ${castToF32(dataType, components, 'x[j + offset]')}; + let f32scale = ${castToF32(dataType, components, 'scale[j]')}; + output[j + offset] = ${variables[0].type.value}((f32input - mean) / meanSquare * f32scale + ${bias ? `+ ${castToF32(dataType, components, 'bias[j]')}` : ''} + ); } ${hasMeanDataOutput ? 'meanDataOutput[global_idx] = mean' : ''}; @@ -97,14 +104,10 @@ const createLayerNormProgramInfo = }`; const outputs = [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}]; if (hasMeanDataOutput) { - outputs.push( - {dims: meanInvStdDevDim, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}, - ); + outputs.push({dims: meanInvStdDevDim, dataType: DataType.float, gpuDataType: GpuDataType.default}); } if (hasInvStdOutput) { - outputs.push( - {dims: meanInvStdDevDim, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}, - ); + outputs.push({dims: meanInvStdDevDim, dataType: DataType.float, gpuDataType: GpuDataType.default}); } return { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts index ec91b9979898b..268ce5307eb74 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts @@ -7,7 +7,7 @@ import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types'; -import {fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType} from './common'; +import {castToF32, fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType,} from './common'; export interface SkipLayerNormAttributes extends AttributeWithCacheKey { epsilon: number; @@ -108,7 +108,6 @@ const createSkipLayerNormProgramInfo = variables.push(outputVariable('inputSkipBiasSum', inputs[0].dataType, outputShape, components)); } const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); - const castToF32 = components === 1 ? 'f32' : `vec${components}f`; const getShaderSource = (shaderHelper: ShaderHelper) => ` const hiddenSize: u32 = ${hiddenSize}; const hiddenSizeVectorized: u32 = ${hiddenSize / components}; @@ -128,7 +127,7 @@ const createSkipLayerNormProgramInfo = let value = inputValue + skipValue + biasValue; ${hasInputSkipBiasSumOutput ? 'inputSkipBiasSum[offset + i] = value;' : ''} output[offset + i] = value; - let f32Value = ${castToF32}(value); + let f32Value = ${castToF32(dataType, components, 'value')}; sum += f32Value; squareSum += f32Value * f32Value; } @@ -143,10 +142,10 @@ const createSkipLayerNormProgramInfo = }`; const outputs = [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}]; if (outputCount > 1) { - outputs.push({dims: meanInvStdDevDim, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}); + outputs.push({dims: meanInvStdDevDim, dataType: DataType.float, gpuDataType: GpuDataType.default}); } if (outputCount > 2) { - outputs.push({dims: meanInvStdDevDim, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}); + outputs.push({dims: meanInvStdDevDim, dataType: DataType.float, gpuDataType: GpuDataType.default}); } if (outputCount > 3) { outputs.push({dims: inputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default});