Skip to content

Commit

Permalink
[js/web] FP16 binary and unary ops (microsoft#17515)
Browse files Browse the repository at this point in the history
### Description
Binary and unary ops with fp16 support
  • Loading branch information
dakenf authored and kleiti committed Mar 22, 2024
1 parent 857bd74 commit 3d15b6a
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 69 deletions.
32 changes: 19 additions & 13 deletions js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import {MAX_CLIP, MIN_CLIP, ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types';

import {inputVariable, outputVariable, ShaderHelper} from './common';
import {inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from './common';

type BuiltinFunctionName = string;
type ElementwiseCustomExpression = (expression: string) => string;
Expand Down Expand Up @@ -101,6 +101,9 @@ export const parseCastAttributes = (attributes: Record<string, unknown>): CastAt
export const cast = (context: ComputeContext, attributes: CastAttributes): void => {
let func: ElementwiseFunctionCall;
switch (attributes.to) {
case DataType.float16:
func = 'vec4<f16>';
break;
case DataType.float:
func = 'vec4<f32>';
break;
Expand All @@ -126,11 +129,12 @@ export interface ClipAttributes extends AttributeWithCacheKey {
}

export const clipV10 = (context: ComputeContext, attributes: ClipAttributes): void => {
const dataType = tensorTypeToWsglStorageType(context.inputs[0].dataType);
context.compute(
createElementwiseProgramInfoLoader(
context.inputs[0], 'Clip', a => `clamp(${a}, clip_min_, clip_max_)`, `
const clip_min_: vec4<f32> = vec4(f32(${attributes.min}));
const clip_max_: vec4<f32> = vec4(f32(${attributes.max}));
const clip_min_: vec4<${dataType}> = vec4(${dataType}(${attributes.min}));
const clip_max_: vec4<${dataType}> = vec4(${dataType}(${attributes.max}));
`,
attributes.cacheKey),
{inputs: [0]});
Expand Down Expand Up @@ -180,13 +184,13 @@ export const elu = (context: ComputeContext, attributes: AlphaAttributes): void
attributes.cacheKey));
};

export const erfImpl = (dataType: string) => `
const r0: f32 = 0.3275911;
const r1: f32 = 0.254829592;
const r2: f32 = -0.284496736;
const r3: f32 = 1.421413741;
const r4: f32 = -1.453152027;
const r5: f32 = 1.061405429;
export const erfImpl = (dataType: string, varType = 'f32') => `
const r0: ${varType} = 0.3275911;
const r1: ${varType} = 0.254829592;
const r2: ${varType} = -0.284496736;
const r3: ${varType} = 1.421413741;
const r4: ${varType} = -1.453152027;
const r5: ${varType} = 1.061405429;
fn erf_vf32(v: ${dataType}) -> ${dataType} {
let absv = abs(v);
Expand All @@ -195,8 +199,9 @@ fn erf_vf32(v: ${dataType}) -> ${dataType} {
}`;

export const erf = (context: ComputeContext): void => {
context.compute(
createElementwiseProgramInfoLoader(context.inputs[0], 'Erf', a => `erf_vf32(${a})`, erfImpl('vec4<f32>')));
const dataType = tensorTypeToWsglStorageType(context.inputs[0].dataType);
context.compute(createElementwiseProgramInfoLoader(
context.inputs[0], 'Erf', a => `erf_vf32(${a})`, erfImpl(`vec4<${dataType}>`, dataType)));
};

export const exp = (context: ComputeContext): void => {
Expand All @@ -208,9 +213,10 @@ export const floor = (context: ComputeContext): void => {
};

export const gelu = (context: ComputeContext): void => {
const dataType = tensorTypeToWsglStorageType(context.inputs[0].dataType);
context.compute(createElementwiseProgramInfoLoader(
context.inputs[0], 'Gelu', a => `0.5 * ${a} * (1.0 + erf_vf32(${a} * 0.7071067811865475))`,
erfImpl('vec4<f32>')));
erfImpl(`vec4<${dataType}>`, dataType)));
};

export const leakyRelu = (context: ComputeContext, attributes: AlphaAttributes): void => {
Expand Down
18 changes: 8 additions & 10 deletions onnxruntime/core/providers/js/operators/binary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,13 @@
namespace onnxruntime {
namespace js {

#define REG_ELEMENTWISE_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS) \
ONNX_OPERATOR_KERNEL_EX( \
OP_TYPE, \
kOnnxDomain, \
VERSION, \
kJsExecutionProvider, \
KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(), \
DataTypeImpl::GetTensorType<int32_t>()}), \
#define REG_ELEMENTWISE_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS) \
ONNX_OPERATOR_KERNEL_EX( \
OP_TYPE, \
kOnnxDomain, \
VERSION, \
kJsExecutionProvider, \
KernelDefBuilder().TypeConstraint("T", JsepSupportedDataTypes()), \
KERNEL_CLASS);

#define REG_ELEMENTWISE_VERSIONED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, KERNEL_CLASS) \
Expand All @@ -22,8 +21,7 @@ namespace js {
kOnnxDomain, \
VERSION_FROM, VERSION_TO, \
kJsExecutionProvider, \
KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(), \
DataTypeImpl::GetTensorType<int32_t>()}), \
KernelDefBuilder().TypeConstraint("T", JsepSupportedDataTypes()), \
KERNEL_CLASS);

JSEP_KERNEL_IMPL(Add, Add)
Expand Down
100 changes: 54 additions & 46 deletions onnxruntime/core/providers/js/operators/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,29 +6,37 @@
namespace onnxruntime {
namespace js {

#define JSEP_ELEMENTWISE_KERNEL(OP_TYPE, VERSION, TYPE, KERNEL_CLASS) \
#define JSEP_ELEMENTWISE_TYPED_KERNEL(OP_TYPE, VERSION, TYPE, KERNEL_CLASS) \
ONNX_OPERATOR_KERNEL_EX( \
OP_TYPE, kOnnxDomain, VERSION, kJsExecutionProvider, \
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<TYPE>()), \
KERNEL_CLASS);

#define JSEP_ELEMENTWISE_VERSIONED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, TYPE, KERNEL_CLASS) \
ONNX_OPERATOR_VERSIONED_KERNEL_EX( \
OP_TYPE, kOnnxDomain, VERSION_FROM, VERSION_TO, kJsExecutionProvider, \
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<TYPE>()), \
#define JSEP_ELEMENTWISE_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS) \
ONNX_OPERATOR_KERNEL_EX( \
OP_TYPE, kOnnxDomain, VERSION, kJsExecutionProvider, \
KernelDefBuilder().TypeConstraint("T", JsepSupportedFloatTypes()), \
KERNEL_CLASS);

#define JSEP_ELEMENTWISE_VERSIONED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, KERNEL_CLASS) \
ONNX_OPERATOR_VERSIONED_KERNEL_EX( \
OP_TYPE, kOnnxDomain, VERSION_FROM, VERSION_TO, kJsExecutionProvider, \
KernelDefBuilder().TypeConstraint("T", JsepSupportedFloatTypes()), \
KERNEL_CLASS);

#define JSEP_ELEMENTWISE_MULTI_TYPED_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS) \
ONNX_OPERATOR_KERNEL_EX( \
OP_TYPE, kOnnxDomain, VERSION, kJsExecutionProvider, \
KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(), \
DataTypeImpl::GetTensorType<MLFloat16>(), \
DataTypeImpl::GetTensorType<int32_t>()}), \
KERNEL_CLASS);

#define JSEP_ELEMENTWISE_MULTI_TYPED_VERSIONED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, KERNEL_CLASS) \
ONNX_OPERATOR_VERSIONED_KERNEL_EX( \
OP_TYPE, kOnnxDomain, VERSION_FROM, VERSION_TO, kJsExecutionProvider, \
KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(), \
DataTypeImpl::GetTensorType<MLFloat16>(), \
DataTypeImpl::GetTensorType<int32_t>()}), \
KERNEL_CLASS);
// math
Expand All @@ -42,115 +50,115 @@ JSEP_ELEMENTWISE_MULTI_TYPED_VERSIONED_KERNEL(Neg, 6, 12, Neg)
JSEP_ELEMENTWISE_MULTI_TYPED_KERNEL(Neg, 13, Neg)

JSEP_KERNEL_IMPL(Floor, Floor)
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Floor, 6, 12, float, Floor)
JSEP_ELEMENTWISE_KERNEL(Floor, 13, float, Floor)
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Floor, 6, 12, Floor)
JSEP_ELEMENTWISE_KERNEL(Floor, 13, Floor)

JSEP_KERNEL_IMPL(Ceil, Ceil)
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Ceil, 6, 12, float, Ceil)
JSEP_ELEMENTWISE_KERNEL(Ceil, 13, float, Ceil)
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Ceil, 6, 12, Ceil)
JSEP_ELEMENTWISE_KERNEL(Ceil, 13, Ceil)

JSEP_KERNEL_IMPL(Reciprocal, Reciprocal)
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Reciprocal, 6, 12, float, Reciprocal)
JSEP_ELEMENTWISE_KERNEL(Reciprocal, 13, float, Reciprocal)
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Reciprocal, 6, 12, Reciprocal)
JSEP_ELEMENTWISE_KERNEL(Reciprocal, 13, Reciprocal)

JSEP_KERNEL_IMPL(Sqrt, Sqrt)
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Sqrt, 6, 12, float, Sqrt)
JSEP_ELEMENTWISE_KERNEL(Sqrt, 13, float, Sqrt)
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Sqrt, 6, 12, Sqrt)
JSEP_ELEMENTWISE_KERNEL(Sqrt, 13, Sqrt)

JSEP_KERNEL_IMPL(Exp, Exp)
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Exp, 6, 12, float, Exp)
JSEP_ELEMENTWISE_KERNEL(Exp, 13, float, Exp)
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Exp, 6, 12, Exp)
JSEP_ELEMENTWISE_KERNEL(Exp, 13, Exp)

JSEP_KERNEL_IMPL(Erf, Erf)
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Erf, 9, 12, float, Erf)
JSEP_ELEMENTWISE_KERNEL(Erf, 13, float, Erf)
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Erf, 9, 12, Erf)
JSEP_ELEMENTWISE_KERNEL(Erf, 13, Erf)

JSEP_KERNEL_IMPL(Sigmoid, Sigmoid)
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Sigmoid, 6, 12, float, Sigmoid)
JSEP_ELEMENTWISE_KERNEL(Sigmoid, 13, float, Sigmoid)
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Sigmoid, 6, 12, Sigmoid)
JSEP_ELEMENTWISE_KERNEL(Sigmoid, 13, Sigmoid)

JSEP_KERNEL_IMPL(Log, Log)
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Log, 6, 12, float, Log)
JSEP_ELEMENTWISE_KERNEL(Log, 13, float, Log)
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Log, 6, 12, Log)
JSEP_ELEMENTWISE_KERNEL(Log, 13, Log)

JSEP_KERNEL_IMPL(Sin, Sin)
JSEP_ELEMENTWISE_KERNEL(Sin, 7, float, Sin)
JSEP_ELEMENTWISE_KERNEL(Sin, 7, Sin)

JSEP_KERNEL_IMPL(Cos, Cos)
JSEP_ELEMENTWISE_KERNEL(Cos, 7, float, Cos)
JSEP_ELEMENTWISE_KERNEL(Cos, 7, Cos)

JSEP_KERNEL_IMPL(Tan, Tan)
JSEP_ELEMENTWISE_KERNEL(Tan, 7, float, Tan)
JSEP_ELEMENTWISE_KERNEL(Tan, 7, Tan)

JSEP_KERNEL_IMPL(Asin, Asin)
JSEP_ELEMENTWISE_KERNEL(Asin, 7, float, Asin)
JSEP_ELEMENTWISE_KERNEL(Asin, 7, Asin)

JSEP_KERNEL_IMPL(Acos, Acos)
JSEP_ELEMENTWISE_KERNEL(Acos, 7, float, Acos)
JSEP_ELEMENTWISE_KERNEL(Acos, 7, Acos)

JSEP_KERNEL_IMPL(Atan, Atan)
JSEP_ELEMENTWISE_KERNEL(Atan, 7, float, Atan)
JSEP_ELEMENTWISE_KERNEL(Atan, 7, Atan)

JSEP_KERNEL_IMPL(Sinh, Sinh)
JSEP_ELEMENTWISE_KERNEL(Sinh, 9, float, Sinh)
JSEP_ELEMENTWISE_KERNEL(Sinh, 9, Sinh)

JSEP_KERNEL_IMPL(Cosh, Cosh)
JSEP_ELEMENTWISE_KERNEL(Cosh, 9, float, Cosh)
JSEP_ELEMENTWISE_KERNEL(Cosh, 9, Cosh)

JSEP_KERNEL_IMPL(Asinh, Asinh)
JSEP_ELEMENTWISE_KERNEL(Asinh, 9, float, Asinh)
JSEP_ELEMENTWISE_KERNEL(Asinh, 9, Asinh)

JSEP_KERNEL_IMPL(Acosh, Acosh)
JSEP_ELEMENTWISE_KERNEL(Acosh, 9, float, Acosh)
JSEP_ELEMENTWISE_KERNEL(Acosh, 9, Acosh)

JSEP_KERNEL_IMPL(Atanh, Atanh)
JSEP_ELEMENTWISE_KERNEL(Atanh, 9, float, Atanh)
JSEP_ELEMENTWISE_KERNEL(Atanh, 9, Atanh)

JSEP_KERNEL_IMPL(Tanh, Tanh)
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Tanh, 6, 12, float, Tanh)
JSEP_ELEMENTWISE_KERNEL(Tanh, 13, float, Tanh)
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Tanh, 6, 12, Tanh)
JSEP_ELEMENTWISE_KERNEL(Tanh, 13, Tanh)

JSEP_KERNEL_IMPL(Not, Not)
JSEP_ELEMENTWISE_KERNEL(Not, 1, bool, Not)
JSEP_ELEMENTWISE_TYPED_KERNEL(Not, 1, bool, Not)

// activation

JSEP_CLASS_IMPL_ATTRIBUTE_FLOAT_2_DEFAULT(ClipV10, ClipV10, min, 3.402823e+38f, max, -3.402823e+38f)
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Clip, 6, 10, float, ClipV10)
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Clip, 6, 10, ClipV10)
JSEP_KERNEL_IMPL(Clip, Clip)
ONNX_OPERATOR_VERSIONED_KERNEL_EX(Clip, kOnnxDomain, 11, 11, kJsExecutionProvider,
KernelDefBuilder()
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
.TypeConstraint("T", JsepSupportedFloatTypes())
.InputMemoryType(OrtMemTypeCPU, 1)
.InputMemoryType(OrtMemTypeCPU, 2),
Clip);
ONNX_OPERATOR_VERSIONED_KERNEL_EX(Clip, kOnnxDomain, 12, 12, kJsExecutionProvider,
KernelDefBuilder()
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
.TypeConstraint("T", JsepSupportedFloatTypes())
.InputMemoryType(OrtMemTypeCPU, 1)
.InputMemoryType(OrtMemTypeCPU, 2),
Clip);
ONNX_OPERATOR_KERNEL_EX(Clip, kOnnxDomain, 13, kJsExecutionProvider,
KernelDefBuilder()
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
.TypeConstraint("T", JsepSupportedFloatTypes())
.InputMemoryType(OrtMemTypeCPU, 1)
.InputMemoryType(OrtMemTypeCPU, 2),
Clip);

JSEP_CLASS_IMPL_ATTRIBUTE_FLOAT_DEFAULT(Elu, Elu, alpha, 1.0)
JSEP_ELEMENTWISE_KERNEL(Elu, 6, float, Elu)
JSEP_ELEMENTWISE_KERNEL(Elu, 6, Elu)

JSEP_KERNEL_IMPL(Relu, Relu)
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Relu, 6, 12, float, Relu)
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Relu, 13, 13, float, Relu)
JSEP_ELEMENTWISE_KERNEL(Relu, 14, float, Relu)
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Relu, 6, 12, Relu)
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Relu, 13, 13, Relu)
JSEP_ELEMENTWISE_KERNEL(Relu, 14, Relu)

JSEP_CLASS_IMPL_ATTRIBUTE_FLOAT_DEFAULT(LeakyRelu, LeakyRelu, alpha, 0.01)
JSEP_ELEMENTWISE_VERSIONED_KERNEL(LeakyRelu, 6, 15, float, LeakyRelu)
JSEP_ELEMENTWISE_KERNEL(LeakyRelu, 16, float, LeakyRelu)
JSEP_ELEMENTWISE_VERSIONED_KERNEL(LeakyRelu, 6, 15, LeakyRelu)
JSEP_ELEMENTWISE_KERNEL(LeakyRelu, 16, LeakyRelu)

JSEP_CLASS_IMPL_ATTRIBUTE_FLOAT_DEFAULT(ThresholdedRelu, ThresholdedRelu, alpha, 1.0)
JSEP_ELEMENTWISE_KERNEL(ThresholdedRelu, 10, float, ThresholdedRelu)
JSEP_ELEMENTWISE_KERNEL(ThresholdedRelu, 10, ThresholdedRelu)

} // namespace js
} // namespace onnxruntime

0 comments on commit 3d15b6a

Please sign in to comment.