diff --git a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts index 119609e06f5a3..51114d8a99dd1 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts @@ -7,7 +7,7 @@ import {MAX_CLIP, MIN_CLIP, ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext, ProgramInfo} from '../types'; -import {inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from './common'; +import {inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglValueType} from './common'; type BuiltinFunctionName = string; type ElementwiseCustomExpression = (expression: string) => string; @@ -132,7 +132,7 @@ const generateClipAttributesFromInputs = (inputs: readonly TensorView[]): ClipAt export const clip = (context: ComputeContext, clipAttributes: ClipAttributes): void => { const attributes = context.inputs.length === 1 ? clipAttributes : generateClipAttributesFromInputs(context.inputs); - const dataType = tensorTypeToWsglStorageType(context.inputs[0].dataType); + const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType); context.compute( createElementwiseProgramInfo( context.inputs[0], 'Clip', a => `clamp(${a}, clip_min_, clip_max_)`, ` @@ -163,15 +163,16 @@ export const parseAlphaAttributes = (attributes: Record): Alpha createAttributeWithCacheKey(attributes as {alpha: number}); export const elu = (context: ComputeContext, attributes: AlphaAttributes): void => { + const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType); context.compute(createElementwiseProgramInfo( context.inputs[0], 'Elu', a => `elu_vf32(${a})`, ` - const elu_alpha_: f32 = f32(${attributes.alpha}); + const elu_alpha_ = ${dataType}(${attributes.alpha}); - fn elu_f32(a: f32) -> f32 { + fn elu_f32(a: ${dataType}) -> ${dataType} { return select((exp(a) - 1.0) * elu_alpha_, a, a >= 0.0); } - fn elu_vf32(v: vec4) -> vec4 { + fn elu_vf32(v: vec4<${dataType}>) -> vec4<${dataType}> { return vec4(elu_f32(v.x), elu_f32(v.y), elu_f32(v.z), elu_f32(v.w)); }`, attributes.cacheKey)); @@ -192,7 +193,7 @@ fn erf_vf32(v: ${dataType}) -> ${dataType} { }`; export const erf = (context: ComputeContext): void => { - const dataType = tensorTypeToWsglStorageType(context.inputs[0].dataType); + const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType); context.compute(createElementwiseProgramInfo( context.inputs[0], 'Erf', a => `erf_vf32(${a})`, erfImpl(`vec4<${dataType}>`, dataType))); }; @@ -206,16 +207,17 @@ export const floor = (context: ComputeContext): void => { }; export const gelu = (context: ComputeContext): void => { - const dataType = tensorTypeToWsglStorageType(context.inputs[0].dataType); + const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType); context.compute(createElementwiseProgramInfo( context.inputs[0], 'Gelu', a => `0.5 * ${a} * (1.0 + erf_vf32(${a} * 0.7071067811865475))`, erfImpl(`vec4<${dataType}>`, dataType))); }; export const leakyRelu = (context: ComputeContext, attributes: AlphaAttributes): void => { + const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType); context.compute(createElementwiseProgramInfo( - context.inputs[0], 'LeakyRelu', a => `select(leaky_relu_alpha_ * ${a}, ${a}, ${a} >= vec4(0.0))`, - `const leaky_relu_alpha_: f32 = f32(${attributes.alpha});`, attributes.cacheKey)); + context.inputs[0], 'LeakyRelu', a => `select(leaky_relu_alpha_ * ${a}, ${a}, ${a} >= vec4<${dataType}>(0.0))`, + `const leaky_relu_alpha_ = ${dataType}(${attributes.alpha});`, attributes.cacheKey)); }; export const not = (context: ComputeContext): void => { @@ -231,8 +233,9 @@ export const reciprocal = (context: ComputeContext): void => { }; export const relu = (context: ComputeContext): void => { + const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType); context.compute(createElementwiseProgramInfo( - context.inputs[0], 'Relu', a => `select(vec4(0.0), ${a}, ${a} > vec4(0.0))`)); + context.inputs[0], 'Relu', a => `select(vec4<${dataType}>(0.0), ${a}, ${a} > vec4<${dataType}>(0.0))`)); }; export const sigmoid = (context: ComputeContext): void => { @@ -260,9 +263,10 @@ export const tanh = (context: ComputeContext): void => { }; export const thresholdedRelu = (context: ComputeContext, attributes: AlphaAttributes): number => { + const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType); context.compute(createElementwiseProgramInfo( - context.inputs[0], 'ThresholdedRelu', a => `select(vec4(0.0), ${a}, ${a} > thresholded_relu_alpha_)`, - `const thresholded_relu_alpha_: vec4 = vec4(${attributes.alpha});`, attributes.cacheKey)); + context.inputs[0], 'ThresholdedRelu', a => `select(vec4<${dataType}>(0.0), ${a}, ${a} > thresholded_relu_alpha_)`, + `const thresholded_relu_alpha_ = vec4<${dataType}>(${attributes.alpha});`, attributes.cacheKey)); return 0; };