From 3d15b6a5e960542326eda85f6b8f8896275d22bd Mon Sep 17 00:00:00 2001 From: Arthur Islamov Date: Tue, 19 Sep 2023 02:43:32 +0400 Subject: [PATCH] [js/web] FP16 binary and unary ops (#17515) ### Description Binary and unary ops with fp16 support --- js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts | 32 +++--- .../core/providers/js/operators/binary.cc | 18 ++-- .../core/providers/js/operators/unary.cc | 100 ++++++++++-------- 3 files changed, 81 insertions(+), 69 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts index 7e52954734216..f08d7a77d1099 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts @@ -7,7 +7,7 @@ import {MAX_CLIP, MIN_CLIP, ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types'; -import {inputVariable, outputVariable, ShaderHelper} from './common'; +import {inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from './common'; type BuiltinFunctionName = string; type ElementwiseCustomExpression = (expression: string) => string; @@ -101,6 +101,9 @@ export const parseCastAttributes = (attributes: Record): CastAt export const cast = (context: ComputeContext, attributes: CastAttributes): void => { let func: ElementwiseFunctionCall; switch (attributes.to) { + case DataType.float16: + func = 'vec4'; + break; case DataType.float: func = 'vec4'; break; @@ -126,11 +129,12 @@ export interface ClipAttributes extends AttributeWithCacheKey { } export const clipV10 = (context: ComputeContext, attributes: ClipAttributes): void => { + const dataType = tensorTypeToWsglStorageType(context.inputs[0].dataType); context.compute( createElementwiseProgramInfoLoader( context.inputs[0], 'Clip', a => `clamp(${a}, clip_min_, clip_max_)`, ` - const clip_min_: vec4 = vec4(f32(${attributes.min})); - const clip_max_: vec4 = vec4(f32(${attributes.max})); + const clip_min_: vec4<${dataType}> = vec4(${dataType}(${attributes.min})); + const clip_max_: vec4<${dataType}> = vec4(${dataType}(${attributes.max})); `, attributes.cacheKey), {inputs: [0]}); @@ -180,13 +184,13 @@ export const elu = (context: ComputeContext, attributes: AlphaAttributes): void attributes.cacheKey)); }; -export const erfImpl = (dataType: string) => ` -const r0: f32 = 0.3275911; -const r1: f32 = 0.254829592; -const r2: f32 = -0.284496736; -const r3: f32 = 1.421413741; -const r4: f32 = -1.453152027; -const r5: f32 = 1.061405429; +export const erfImpl = (dataType: string, varType = 'f32') => ` +const r0: ${varType} = 0.3275911; +const r1: ${varType} = 0.254829592; +const r2: ${varType} = -0.284496736; +const r3: ${varType} = 1.421413741; +const r4: ${varType} = -1.453152027; +const r5: ${varType} = 1.061405429; fn erf_vf32(v: ${dataType}) -> ${dataType} { let absv = abs(v); @@ -195,8 +199,9 @@ fn erf_vf32(v: ${dataType}) -> ${dataType} { }`; export const erf = (context: ComputeContext): void => { - context.compute( - createElementwiseProgramInfoLoader(context.inputs[0], 'Erf', a => `erf_vf32(${a})`, erfImpl('vec4'))); + const dataType = tensorTypeToWsglStorageType(context.inputs[0].dataType); + context.compute(createElementwiseProgramInfoLoader( + context.inputs[0], 'Erf', a => `erf_vf32(${a})`, erfImpl(`vec4<${dataType}>`, dataType))); }; export const exp = (context: ComputeContext): void => { @@ -208,9 +213,10 @@ export const floor = (context: ComputeContext): void => { }; export const gelu = (context: ComputeContext): void => { + const dataType = tensorTypeToWsglStorageType(context.inputs[0].dataType); context.compute(createElementwiseProgramInfoLoader( context.inputs[0], 'Gelu', a => `0.5 * ${a} * (1.0 + erf_vf32(${a} * 0.7071067811865475))`, - erfImpl('vec4'))); + erfImpl(`vec4<${dataType}>`, dataType))); }; export const leakyRelu = (context: ComputeContext, attributes: AlphaAttributes): void => { diff --git a/onnxruntime/core/providers/js/operators/binary.cc b/onnxruntime/core/providers/js/operators/binary.cc index 98f7ca6e613b0..e61cb1094736d 100644 --- a/onnxruntime/core/providers/js/operators/binary.cc +++ b/onnxruntime/core/providers/js/operators/binary.cc @@ -6,14 +6,13 @@ namespace onnxruntime { namespace js { -#define REG_ELEMENTWISE_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS) \ - ONNX_OPERATOR_KERNEL_EX( \ - OP_TYPE, \ - kOnnxDomain, \ - VERSION, \ - kJsExecutionProvider, \ - KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType(), \ - DataTypeImpl::GetTensorType()}), \ +#define REG_ELEMENTWISE_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS) \ + ONNX_OPERATOR_KERNEL_EX( \ + OP_TYPE, \ + kOnnxDomain, \ + VERSION, \ + kJsExecutionProvider, \ + KernelDefBuilder().TypeConstraint("T", JsepSupportedDataTypes()), \ KERNEL_CLASS); #define REG_ELEMENTWISE_VERSIONED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, KERNEL_CLASS) \ @@ -22,8 +21,7 @@ namespace js { kOnnxDomain, \ VERSION_FROM, VERSION_TO, \ kJsExecutionProvider, \ - KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType(), \ - DataTypeImpl::GetTensorType()}), \ + KernelDefBuilder().TypeConstraint("T", JsepSupportedDataTypes()), \ KERNEL_CLASS); JSEP_KERNEL_IMPL(Add, Add) diff --git a/onnxruntime/core/providers/js/operators/unary.cc b/onnxruntime/core/providers/js/operators/unary.cc index 5e972e43e4566..e9bbfabcf86bd 100644 --- a/onnxruntime/core/providers/js/operators/unary.cc +++ b/onnxruntime/core/providers/js/operators/unary.cc @@ -6,22 +6,29 @@ namespace onnxruntime { namespace js { -#define JSEP_ELEMENTWISE_KERNEL(OP_TYPE, VERSION, TYPE, KERNEL_CLASS) \ +#define JSEP_ELEMENTWISE_TYPED_KERNEL(OP_TYPE, VERSION, TYPE, KERNEL_CLASS) \ ONNX_OPERATOR_KERNEL_EX( \ OP_TYPE, kOnnxDomain, VERSION, kJsExecutionProvider, \ KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), \ KERNEL_CLASS); -#define JSEP_ELEMENTWISE_VERSIONED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, TYPE, KERNEL_CLASS) \ - ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ - OP_TYPE, kOnnxDomain, VERSION_FROM, VERSION_TO, kJsExecutionProvider, \ - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), \ +#define JSEP_ELEMENTWISE_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS) \ + ONNX_OPERATOR_KERNEL_EX( \ + OP_TYPE, kOnnxDomain, VERSION, kJsExecutionProvider, \ + KernelDefBuilder().TypeConstraint("T", JsepSupportedFloatTypes()), \ + KERNEL_CLASS); + +#define JSEP_ELEMENTWISE_VERSIONED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, KERNEL_CLASS) \ + ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ + OP_TYPE, kOnnxDomain, VERSION_FROM, VERSION_TO, kJsExecutionProvider, \ + KernelDefBuilder().TypeConstraint("T", JsepSupportedFloatTypes()), \ KERNEL_CLASS); #define JSEP_ELEMENTWISE_MULTI_TYPED_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS) \ ONNX_OPERATOR_KERNEL_EX( \ OP_TYPE, kOnnxDomain, VERSION, kJsExecutionProvider, \ KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType(), \ + DataTypeImpl::GetTensorType(), \ DataTypeImpl::GetTensorType()}), \ KERNEL_CLASS); @@ -29,6 +36,7 @@ namespace js { ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ OP_TYPE, kOnnxDomain, VERSION_FROM, VERSION_TO, kJsExecutionProvider, \ KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType(), \ + DataTypeImpl::GetTensorType(), \ DataTypeImpl::GetTensorType()}), \ KERNEL_CLASS); // math @@ -42,115 +50,115 @@ JSEP_ELEMENTWISE_MULTI_TYPED_VERSIONED_KERNEL(Neg, 6, 12, Neg) JSEP_ELEMENTWISE_MULTI_TYPED_KERNEL(Neg, 13, Neg) JSEP_KERNEL_IMPL(Floor, Floor) -JSEP_ELEMENTWISE_VERSIONED_KERNEL(Floor, 6, 12, float, Floor) -JSEP_ELEMENTWISE_KERNEL(Floor, 13, float, Floor) +JSEP_ELEMENTWISE_VERSIONED_KERNEL(Floor, 6, 12, Floor) +JSEP_ELEMENTWISE_KERNEL(Floor, 13, Floor) JSEP_KERNEL_IMPL(Ceil, Ceil) -JSEP_ELEMENTWISE_VERSIONED_KERNEL(Ceil, 6, 12, float, Ceil) -JSEP_ELEMENTWISE_KERNEL(Ceil, 13, float, Ceil) +JSEP_ELEMENTWISE_VERSIONED_KERNEL(Ceil, 6, 12, Ceil) +JSEP_ELEMENTWISE_KERNEL(Ceil, 13, Ceil) JSEP_KERNEL_IMPL(Reciprocal, Reciprocal) -JSEP_ELEMENTWISE_VERSIONED_KERNEL(Reciprocal, 6, 12, float, Reciprocal) -JSEP_ELEMENTWISE_KERNEL(Reciprocal, 13, float, Reciprocal) +JSEP_ELEMENTWISE_VERSIONED_KERNEL(Reciprocal, 6, 12, Reciprocal) +JSEP_ELEMENTWISE_KERNEL(Reciprocal, 13, Reciprocal) JSEP_KERNEL_IMPL(Sqrt, Sqrt) -JSEP_ELEMENTWISE_VERSIONED_KERNEL(Sqrt, 6, 12, float, Sqrt) -JSEP_ELEMENTWISE_KERNEL(Sqrt, 13, float, Sqrt) +JSEP_ELEMENTWISE_VERSIONED_KERNEL(Sqrt, 6, 12, Sqrt) +JSEP_ELEMENTWISE_KERNEL(Sqrt, 13, Sqrt) JSEP_KERNEL_IMPL(Exp, Exp) -JSEP_ELEMENTWISE_VERSIONED_KERNEL(Exp, 6, 12, float, Exp) -JSEP_ELEMENTWISE_KERNEL(Exp, 13, float, Exp) +JSEP_ELEMENTWISE_VERSIONED_KERNEL(Exp, 6, 12, Exp) +JSEP_ELEMENTWISE_KERNEL(Exp, 13, Exp) JSEP_KERNEL_IMPL(Erf, Erf) -JSEP_ELEMENTWISE_VERSIONED_KERNEL(Erf, 9, 12, float, Erf) -JSEP_ELEMENTWISE_KERNEL(Erf, 13, float, Erf) +JSEP_ELEMENTWISE_VERSIONED_KERNEL(Erf, 9, 12, Erf) +JSEP_ELEMENTWISE_KERNEL(Erf, 13, Erf) JSEP_KERNEL_IMPL(Sigmoid, Sigmoid) -JSEP_ELEMENTWISE_VERSIONED_KERNEL(Sigmoid, 6, 12, float, Sigmoid) -JSEP_ELEMENTWISE_KERNEL(Sigmoid, 13, float, Sigmoid) +JSEP_ELEMENTWISE_VERSIONED_KERNEL(Sigmoid, 6, 12, Sigmoid) +JSEP_ELEMENTWISE_KERNEL(Sigmoid, 13, Sigmoid) JSEP_KERNEL_IMPL(Log, Log) -JSEP_ELEMENTWISE_VERSIONED_KERNEL(Log, 6, 12, float, Log) -JSEP_ELEMENTWISE_KERNEL(Log, 13, float, Log) +JSEP_ELEMENTWISE_VERSIONED_KERNEL(Log, 6, 12, Log) +JSEP_ELEMENTWISE_KERNEL(Log, 13, Log) JSEP_KERNEL_IMPL(Sin, Sin) -JSEP_ELEMENTWISE_KERNEL(Sin, 7, float, Sin) +JSEP_ELEMENTWISE_KERNEL(Sin, 7, Sin) JSEP_KERNEL_IMPL(Cos, Cos) -JSEP_ELEMENTWISE_KERNEL(Cos, 7, float, Cos) +JSEP_ELEMENTWISE_KERNEL(Cos, 7, Cos) JSEP_KERNEL_IMPL(Tan, Tan) -JSEP_ELEMENTWISE_KERNEL(Tan, 7, float, Tan) +JSEP_ELEMENTWISE_KERNEL(Tan, 7, Tan) JSEP_KERNEL_IMPL(Asin, Asin) -JSEP_ELEMENTWISE_KERNEL(Asin, 7, float, Asin) +JSEP_ELEMENTWISE_KERNEL(Asin, 7, Asin) JSEP_KERNEL_IMPL(Acos, Acos) -JSEP_ELEMENTWISE_KERNEL(Acos, 7, float, Acos) +JSEP_ELEMENTWISE_KERNEL(Acos, 7, Acos) JSEP_KERNEL_IMPL(Atan, Atan) -JSEP_ELEMENTWISE_KERNEL(Atan, 7, float, Atan) +JSEP_ELEMENTWISE_KERNEL(Atan, 7, Atan) JSEP_KERNEL_IMPL(Sinh, Sinh) -JSEP_ELEMENTWISE_KERNEL(Sinh, 9, float, Sinh) +JSEP_ELEMENTWISE_KERNEL(Sinh, 9, Sinh) JSEP_KERNEL_IMPL(Cosh, Cosh) -JSEP_ELEMENTWISE_KERNEL(Cosh, 9, float, Cosh) +JSEP_ELEMENTWISE_KERNEL(Cosh, 9, Cosh) JSEP_KERNEL_IMPL(Asinh, Asinh) -JSEP_ELEMENTWISE_KERNEL(Asinh, 9, float, Asinh) +JSEP_ELEMENTWISE_KERNEL(Asinh, 9, Asinh) JSEP_KERNEL_IMPL(Acosh, Acosh) -JSEP_ELEMENTWISE_KERNEL(Acosh, 9, float, Acosh) +JSEP_ELEMENTWISE_KERNEL(Acosh, 9, Acosh) JSEP_KERNEL_IMPL(Atanh, Atanh) -JSEP_ELEMENTWISE_KERNEL(Atanh, 9, float, Atanh) +JSEP_ELEMENTWISE_KERNEL(Atanh, 9, Atanh) JSEP_KERNEL_IMPL(Tanh, Tanh) -JSEP_ELEMENTWISE_VERSIONED_KERNEL(Tanh, 6, 12, float, Tanh) -JSEP_ELEMENTWISE_KERNEL(Tanh, 13, float, Tanh) +JSEP_ELEMENTWISE_VERSIONED_KERNEL(Tanh, 6, 12, Tanh) +JSEP_ELEMENTWISE_KERNEL(Tanh, 13, Tanh) JSEP_KERNEL_IMPL(Not, Not) -JSEP_ELEMENTWISE_KERNEL(Not, 1, bool, Not) +JSEP_ELEMENTWISE_TYPED_KERNEL(Not, 1, bool, Not) // activation JSEP_CLASS_IMPL_ATTRIBUTE_FLOAT_2_DEFAULT(ClipV10, ClipV10, min, 3.402823e+38f, max, -3.402823e+38f) -JSEP_ELEMENTWISE_VERSIONED_KERNEL(Clip, 6, 10, float, ClipV10) +JSEP_ELEMENTWISE_VERSIONED_KERNEL(Clip, 6, 10, ClipV10) JSEP_KERNEL_IMPL(Clip, Clip) ONNX_OPERATOR_VERSIONED_KERNEL_EX(Clip, kOnnxDomain, 11, 11, kJsExecutionProvider, KernelDefBuilder() - .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .TypeConstraint("T", JsepSupportedFloatTypes()) .InputMemoryType(OrtMemTypeCPU, 1) .InputMemoryType(OrtMemTypeCPU, 2), Clip); ONNX_OPERATOR_VERSIONED_KERNEL_EX(Clip, kOnnxDomain, 12, 12, kJsExecutionProvider, KernelDefBuilder() - .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .TypeConstraint("T", JsepSupportedFloatTypes()) .InputMemoryType(OrtMemTypeCPU, 1) .InputMemoryType(OrtMemTypeCPU, 2), Clip); ONNX_OPERATOR_KERNEL_EX(Clip, kOnnxDomain, 13, kJsExecutionProvider, KernelDefBuilder() - .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .TypeConstraint("T", JsepSupportedFloatTypes()) .InputMemoryType(OrtMemTypeCPU, 1) .InputMemoryType(OrtMemTypeCPU, 2), Clip); JSEP_CLASS_IMPL_ATTRIBUTE_FLOAT_DEFAULT(Elu, Elu, alpha, 1.0) -JSEP_ELEMENTWISE_KERNEL(Elu, 6, float, Elu) +JSEP_ELEMENTWISE_KERNEL(Elu, 6, Elu) JSEP_KERNEL_IMPL(Relu, Relu) -JSEP_ELEMENTWISE_VERSIONED_KERNEL(Relu, 6, 12, float, Relu) -JSEP_ELEMENTWISE_VERSIONED_KERNEL(Relu, 13, 13, float, Relu) -JSEP_ELEMENTWISE_KERNEL(Relu, 14, float, Relu) +JSEP_ELEMENTWISE_VERSIONED_KERNEL(Relu, 6, 12, Relu) +JSEP_ELEMENTWISE_VERSIONED_KERNEL(Relu, 13, 13, Relu) +JSEP_ELEMENTWISE_KERNEL(Relu, 14, Relu) JSEP_CLASS_IMPL_ATTRIBUTE_FLOAT_DEFAULT(LeakyRelu, LeakyRelu, alpha, 0.01) -JSEP_ELEMENTWISE_VERSIONED_KERNEL(LeakyRelu, 6, 15, float, LeakyRelu) -JSEP_ELEMENTWISE_KERNEL(LeakyRelu, 16, float, LeakyRelu) +JSEP_ELEMENTWISE_VERSIONED_KERNEL(LeakyRelu, 6, 15, LeakyRelu) +JSEP_ELEMENTWISE_KERNEL(LeakyRelu, 16, LeakyRelu) JSEP_CLASS_IMPL_ATTRIBUTE_FLOAT_DEFAULT(ThresholdedRelu, ThresholdedRelu, alpha, 1.0) -JSEP_ELEMENTWISE_KERNEL(ThresholdedRelu, 10, float, ThresholdedRelu) +JSEP_ELEMENTWISE_KERNEL(ThresholdedRelu, 10, ThresholdedRelu) } // namespace js } // namespace onnxruntime