Skip to content

Commit

Permalink
[js/web] FP16 binary and unary ops (microsoft#17515)
Browse files Browse the repository at this point in the history
### Description
Adds fp16 (half-precision float) support to the binary and unary ops.
  • Loading branch information
dakenf authored Sep 18, 2023
1 parent f58c425 commit b8b218e
Showing 1 changed file with 19 additions and 13 deletions.
32 changes: 19 additions & 13 deletions web/lib/wasm/jsep/webgpu/ops/unary-op.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import {MAX_CLIP, MIN_CLIP, ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types';

import {inputVariable, outputVariable, ShaderHelper} from './common';
import {inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from './common';

type BuiltinFunctionName = string;
type ElementwiseCustomExpression = (expression: string) => string;
Expand Down Expand Up @@ -101,6 +101,9 @@ export const parseCastAttributes = (attributes: Record<string, unknown>): CastAt
export const cast = (context: ComputeContext, attributes: CastAttributes): void => {
let func: ElementwiseFunctionCall;
switch (attributes.to) {
case DataType.float16:
func = 'vec4<f16>';
break;
case DataType.float:
func = 'vec4<f32>';
break;
Expand All @@ -126,11 +129,12 @@ export interface ClipAttributes extends AttributeWithCacheKey {
}

export const clipV10 = (context: ComputeContext, attributes: ClipAttributes): void => {
const dataType = tensorTypeToWsglStorageType(context.inputs[0].dataType);
context.compute(
createElementwiseProgramInfoLoader(
context.inputs[0], 'Clip', a => `clamp(${a}, clip_min_, clip_max_)`, `
const clip_min_: vec4<f32> = vec4(f32(${attributes.min}));
const clip_max_: vec4<f32> = vec4(f32(${attributes.max}));
const clip_min_: vec4<${dataType}> = vec4(${dataType}(${attributes.min}));
const clip_max_: vec4<${dataType}> = vec4(${dataType}(${attributes.max}));
`,
attributes.cacheKey),
{inputs: [0]});
Expand Down Expand Up @@ -180,13 +184,13 @@ export const elu = (context: ComputeContext, attributes: AlphaAttributes): void
attributes.cacheKey));
};

export const erfImpl = (dataType: string) => `
const r0: f32 = 0.3275911;
const r1: f32 = 0.254829592;
const r2: f32 = -0.284496736;
const r3: f32 = 1.421413741;
const r4: f32 = -1.453152027;
const r5: f32 = 1.061405429;
export const erfImpl = (dataType: string, varType = 'f32') => `
const r0: ${varType} = 0.3275911;
const r1: ${varType} = 0.254829592;
const r2: ${varType} = -0.284496736;
const r3: ${varType} = 1.421413741;
const r4: ${varType} = -1.453152027;
const r5: ${varType} = 1.061405429;
fn erf_vf32(v: ${dataType}) -> ${dataType} {
let absv = abs(v);
Expand All @@ -195,8 +199,9 @@ fn erf_vf32(v: ${dataType}) -> ${dataType} {
}`;

/**
 * Runs the element-wise Erf op on the first input tensor.
 *
 * The string returned by `tensorTypeToWsglStorageType` for the input's data type is used both as
 * the `vec4` element type of the shader helper and as the scalar type of the constants emitted by
 * `erfImpl`, so the generated code is no longer hard-coded to `f32`.
 *
 * @param context compute context whose first input is the tensor to transform.
 */
export const erf = (context: ComputeContext): void => {
  const dataType = tensorTypeToWsglStorageType(context.inputs[0].dataType);
  // Dispatch exactly once, with the typed helper; the earlier `erfImpl('vec4<f32>')` dispatch is
  // superseded by this call and must not run alongside it.
  context.compute(createElementwiseProgramInfoLoader(
      context.inputs[0], 'Erf', a => `erf_vf32(${a})`, erfImpl(`vec4<${dataType}>`, dataType)));
};

export const exp = (context: ComputeContext): void => {
Expand All @@ -208,9 +213,10 @@ export const floor = (context: ComputeContext): void => {
};

/**
 * Runs the element-wise Gelu op on the first input tensor.
 *
 * Each element `x` is mapped to `0.5 * x * (1.0 + erf(x * 0.7071067811865475))`, where the
 * constant is approximately 1/sqrt(2). The erf helper is generated by `erfImpl` using the storage
 * type string that `tensorTypeToWsglStorageType` returns for the input's data type, instead of a
 * hard-coded `vec4<f32>`.
 *
 * @param context compute context whose first input is the tensor to transform.
 */
export const gelu = (context: ComputeContext): void => {
  const dataType = tensorTypeToWsglStorageType(context.inputs[0].dataType);
  context.compute(createElementwiseProgramInfoLoader(
      context.inputs[0], 'Gelu', a => `0.5 * ${a} * (1.0 + erf_vf32(${a} * 0.7071067811865475))`,
      // Only the typed helper is passed; the former `erfImpl('vec4<f32>')` argument is replaced.
      erfImpl(`vec4<${dataType}>`, dataType)));
};

export const leakyRelu = (context: ComputeContext, attributes: AlphaAttributes): void => {
Expand Down

0 comments on commit b8b218e

Please sign in to comment.