[js/webgpu] Fix f16 errors in unary (#18839)
### Description
This PR fixes the following errors:
```
no matching overload for operator > (vec4<f16>, vec4<f32>)
```
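
The root cause visible in the diff is that the unary shader templates hard-code `f32` literals and `vec4<f32>` constructors, so with a float16 input the generated WGSL compares a `vec4<f16>` value against a `vec4<f32>` constant, an overload WGSL does not provide. The change derives the WGSL value type from the input tensor (via `tensorTypeToWsglValueType`) and uses it for every constant in the templates. Below is a minimal sketch of that pattern; `wsglValueType` is a hypothetical, simplified stand-in for the real helper in `./common`, which takes the numeric tensor DataType and covers more types.

```
// Hypothetical, simplified stand-in for tensorTypeToWsglValueType from ./common
// (the real helper takes the numeric tensor DataType and handles more types).
const wsglValueType = (tensorDataType: 'float32' | 'float16'): 'f32' | 'f16' =>
    tensorDataType === 'float16' ? 'f16' : 'f32';

// Build the LeakyRelu shader pieces so every constant matches the input element type.
const leakyReluSnippet = (tensorDataType: 'float32' | 'float16', alpha: number) => {
  const dataType = wsglValueType(tensorDataType);
  return {
    // The comparison uses a vector of the same element type as `a`, so f16 inputs type-check.
    expression: (a: string) => `select(leaky_relu_alpha_ * ${a}, ${a}, ${a} >= vec4<${dataType}>(0.0))`,
    // The alpha constant is emitted in the matching type instead of a hard-coded f32.
    constants: `const leaky_relu_alpha_ = ${dataType}(${alpha});`,
  };
};

// With an f16 tensor the generated WGSL uses vec4<f16>(0.0) and f16(0.01),
// so both operands of `>=` share an element type.
console.log(leakyReluSnippet('float16', 0.01).expression('v'));
console.log(leakyReluSnippet('float16', 0.01).constants);
```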
qjia7 authored Dec 15, 2023
1 parent f52668c commit 4bbed4c
Showing 1 changed file with 16 additions and 12 deletions.
28 changes: 16 additions & 12 deletions js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts
@@ -7,7 +7,7 @@ import {MAX_CLIP, MIN_CLIP, ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, ProgramInfo} from '../types';

-import {inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from './common';
+import {inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglValueType} from './common';

type BuiltinFunctionName = string;
type ElementwiseCustomExpression = (expression: string) => string;
@@ -132,7 +132,7 @@ const generateClipAttributesFromInputs = (inputs: readonly TensorView[]): ClipAt

export const clip = (context: ComputeContext, clipAttributes: ClipAttributes): void => {
const attributes = context.inputs.length === 1 ? clipAttributes : generateClipAttributesFromInputs(context.inputs);
-const dataType = tensorTypeToWsglStorageType(context.inputs[0].dataType);
+const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType);
context.compute(
createElementwiseProgramInfo(
context.inputs[0], 'Clip', a => `clamp(${a}, clip_min_, clip_max_)`, `
@@ -163,15 +163,16 @@ export const parseAlphaAttributes = (attributes: Record<string, unknown>): Alpha
createAttributeWithCacheKey(attributes as {alpha: number});

export const elu = (context: ComputeContext, attributes: AlphaAttributes): void => {
+const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType);
context.compute(createElementwiseProgramInfo(
context.inputs[0], 'Elu', a => `elu_vf32(${a})`, `
-const elu_alpha_: f32 = f32(${attributes.alpha});
+const elu_alpha_ = ${dataType}(${attributes.alpha});
-fn elu_f32(a: f32) -> f32 {
+fn elu_f32(a: ${dataType}) -> ${dataType} {
return select((exp(a) - 1.0) * elu_alpha_, a, a >= 0.0);
}
-fn elu_vf32(v: vec4<f32>) -> vec4<f32> {
+fn elu_vf32(v: vec4<${dataType}>) -> vec4<${dataType}> {
return vec4(elu_f32(v.x), elu_f32(v.y), elu_f32(v.z), elu_f32(v.w));
}`,
attributes.cacheKey));
@@ -192,7 +193,7 @@ fn erf_vf32(v: ${dataType}) -> ${dataType} {
}`;

export const erf = (context: ComputeContext): void => {
-const dataType = tensorTypeToWsglStorageType(context.inputs[0].dataType);
+const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType);
context.compute(createElementwiseProgramInfo(
context.inputs[0], 'Erf', a => `erf_vf32(${a})`, erfImpl(`vec4<${dataType}>`, dataType)));
};
@@ -206,16 +207,17 @@ export const floor = (context: ComputeContext): void => {
};

export const gelu = (context: ComputeContext): void => {
-const dataType = tensorTypeToWsglStorageType(context.inputs[0].dataType);
+const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType);
context.compute(createElementwiseProgramInfo(
context.inputs[0], 'Gelu', a => `0.5 * ${a} * (1.0 + erf_vf32(${a} * 0.7071067811865475))`,
erfImpl(`vec4<${dataType}>`, dataType)));
};

export const leakyRelu = (context: ComputeContext, attributes: AlphaAttributes): void => {
+const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType);
context.compute(createElementwiseProgramInfo(
-context.inputs[0], 'LeakyRelu', a => `select(leaky_relu_alpha_ * ${a}, ${a}, ${a} >= vec4<f32>(0.0))`,
-`const leaky_relu_alpha_: f32 = f32(${attributes.alpha});`, attributes.cacheKey));
+context.inputs[0], 'LeakyRelu', a => `select(leaky_relu_alpha_ * ${a}, ${a}, ${a} >= vec4<${dataType}>(0.0))`,
+`const leaky_relu_alpha_ = ${dataType}(${attributes.alpha});`, attributes.cacheKey));
};

export const not = (context: ComputeContext): void => {
@@ -231,8 +233,9 @@ export const reciprocal = (context: ComputeContext): void => {
};

export const relu = (context: ComputeContext): void => {
+const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType);
context.compute(createElementwiseProgramInfo(
-context.inputs[0], 'Relu', a => `select(vec4<f32>(0.0), ${a}, ${a} > vec4<f32>(0.0))`));
+context.inputs[0], 'Relu', a => `select(vec4<${dataType}>(0.0), ${a}, ${a} > vec4<${dataType}>(0.0))`));
};

export const sigmoid = (context: ComputeContext): void => {
@@ -260,9 +263,10 @@ export const tanh = (context: ComputeContext): void => {
};

export const thresholdedRelu = (context: ComputeContext, attributes: AlphaAttributes): number => {
+const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType);
context.compute(createElementwiseProgramInfo(
-context.inputs[0], 'ThresholdedRelu', a => `select(vec4<f32>(0.0), ${a}, ${a} > thresholded_relu_alpha_)`,
-`const thresholded_relu_alpha_: vec4<f32> = vec4<f32>(${attributes.alpha});`, attributes.cacheKey));
+context.inputs[0], 'ThresholdedRelu', a => `select(vec4<${dataType}>(0.0), ${a}, ${a} > thresholded_relu_alpha_)`,
+`const thresholded_relu_alpha_ = vec4<${dataType}>(${attributes.alpha});`, attributes.cacheKey));
return 0;
};
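
For concreteness, the sketch below (a simplified reconstruction, not the exact helpers from unary-op.ts) contrasts the old and new Relu expression builders. With the value type set to `f16`, only the parametrized form yields a comparison whose operands share an element type, which is exactly the overload error quoted in the description.

```
// Old pattern: constants are always vec4<f32>, regardless of the input tensor's type.
const reluExprOld = (a: string) => `select(vec4<f32>(0.0), ${a}, ${a} > vec4<f32>(0.0))`;

// New pattern: constants follow the input's WGSL value type ('f32' or 'f16').
const reluExprNew = (a: string, dataType: 'f32' | 'f16') =>
    `select(vec4<${dataType}>(0.0), ${a}, ${a} > vec4<${dataType}>(0.0))`;

// With an f16 input, the old expression compares a vec4<f16> value against vec4<f32>
// constants, which triggers "no matching overload for operator > (vec4<f16>, vec4<f32>)".
console.log(reluExprOld('a'));         // select(vec4<f32>(0.0), a, a > vec4<f32>(0.0))
console.log(reluExprNew('a', 'f16'));  // select(vec4<f16>(0.0), a, a > vec4<f16>(0.0))
```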

