Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[js/web] FP16 binary and unary ops #17515

Merged
merged 1 commit into from
Sep 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 19 additions & 13 deletions js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import {MAX_CLIP, MIN_CLIP, ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types';

import {inputVariable, outputVariable, ShaderHelper} from './common';
import {inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from './common';

type BuiltinFunctionName = string;
type ElementwiseCustomExpression = (expression: string) => string;
Expand Down Expand Up @@ -101,6 +101,9 @@ export const parseCastAttributes = (attributes: Record<string, unknown>): CastAt
export const cast = (context: ComputeContext, attributes: CastAttributes): void => {
let func: ElementwiseFunctionCall;
switch (attributes.to) {
case DataType.float16:
func = 'vec4<f16>';
break;
case DataType.float:
func = 'vec4<f32>';
break;
Expand All @@ -126,11 +129,12 @@ export interface ClipAttributes extends AttributeWithCacheKey {
}

export const clipV10 = (context: ComputeContext, attributes: ClipAttributes): void => {
const dataType = tensorTypeToWsglStorageType(context.inputs[0].dataType);
context.compute(
createElementwiseProgramInfoLoader(
context.inputs[0], 'Clip', a => `clamp(${a}, clip_min_, clip_max_)`, `
const clip_min_: vec4<f32> = vec4(f32(${attributes.min}));
const clip_max_: vec4<f32> = vec4(f32(${attributes.max}));
const clip_min_: vec4<${dataType}> = vec4(${dataType}(${attributes.min}));
const clip_max_: vec4<${dataType}> = vec4(${dataType}(${attributes.max}));
`,
attributes.cacheKey),
{inputs: [0]});
Expand Down Expand Up @@ -180,13 +184,13 @@ export const elu = (context: ComputeContext, attributes: AlphaAttributes): void
attributes.cacheKey));
};

export const erfImpl = (dataType: string) => `
const r0: f32 = 0.3275911;
const r1: f32 = 0.254829592;
const r2: f32 = -0.284496736;
const r3: f32 = 1.421413741;
const r4: f32 = -1.453152027;
const r5: f32 = 1.061405429;
export const erfImpl = (dataType: string, varType = 'f32') => `
const r0: ${varType} = 0.3275911;
const r1: ${varType} = 0.254829592;
const r2: ${varType} = -0.284496736;
const r3: ${varType} = 1.421413741;
const r4: ${varType} = -1.453152027;
const r5: ${varType} = 1.061405429;

fn erf_vf32(v: ${dataType}) -> ${dataType} {
fs-eire marked this conversation as resolved.
Show resolved Hide resolved
let absv = abs(v);
Expand All @@ -195,8 +199,9 @@ fn erf_vf32(v: ${dataType}) -> ${dataType} {
}`;

export const erf = (context: ComputeContext): void => {
context.compute(
createElementwiseProgramInfoLoader(context.inputs[0], 'Erf', a => `erf_vf32(${a})`, erfImpl('vec4<f32>')));
const dataType = tensorTypeToWsglStorageType(context.inputs[0].dataType);
context.compute(createElementwiseProgramInfoLoader(
context.inputs[0], 'Erf', a => `erf_vf32(${a})`, erfImpl(`vec4<${dataType}>`, dataType)));
};

export const exp = (context: ComputeContext): void => {
Expand All @@ -208,9 +213,10 @@ export const floor = (context: ComputeContext): void => {
};

export const gelu = (context: ComputeContext): void => {
const dataType = tensorTypeToWsglStorageType(context.inputs[0].dataType);
context.compute(createElementwiseProgramInfoLoader(
context.inputs[0], 'Gelu', a => `0.5 * ${a} * (1.0 + erf_vf32(${a} * 0.7071067811865475))`,
erfImpl('vec4<f32>')));
erfImpl(`vec4<${dataType}>`, dataType)));
};

export const leakyRelu = (context: ComputeContext, attributes: AlphaAttributes): void => {
Expand Down
18 changes: 8 additions & 10 deletions onnxruntime/core/providers/js/operators/binary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,13 @@
namespace onnxruntime {
namespace js {

#define REG_ELEMENTWISE_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS) \
ONNX_OPERATOR_KERNEL_EX( \
OP_TYPE, \
kOnnxDomain, \
VERSION, \
kJsExecutionProvider, \
KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(), \
DataTypeImpl::GetTensorType<int32_t>()}), \
#define REG_ELEMENTWISE_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS) \
ONNX_OPERATOR_KERNEL_EX( \
OP_TYPE, \
kOnnxDomain, \
VERSION, \
kJsExecutionProvider, \
KernelDefBuilder().TypeConstraint("T", JsepSupportedDataTypes()), \
KERNEL_CLASS);

#define REG_ELEMENTWISE_VERSIONED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, KERNEL_CLASS) \
Expand All @@ -22,8 +21,7 @@ namespace js {
kOnnxDomain, \
VERSION_FROM, VERSION_TO, \
kJsExecutionProvider, \
KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(), \
DataTypeImpl::GetTensorType<int32_t>()}), \
KernelDefBuilder().TypeConstraint("T", JsepSupportedDataTypes()), \
KERNEL_CLASS);

JSEP_KERNEL_IMPL(Add, Add)
Expand Down
100 changes: 54 additions & 46 deletions onnxruntime/core/providers/js/operators/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,29 +6,37 @@
namespace onnxruntime {
namespace js {

#define JSEP_ELEMENTWISE_KERNEL(OP_TYPE, VERSION, TYPE, KERNEL_CLASS) \
#define JSEP_ELEMENTWISE_TYPED_KERNEL(OP_TYPE, VERSION, TYPE, KERNEL_CLASS) \
ONNX_OPERATOR_KERNEL_EX( \
OP_TYPE, kOnnxDomain, VERSION, kJsExecutionProvider, \
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<TYPE>()), \
KERNEL_CLASS);

#define JSEP_ELEMENTWISE_VERSIONED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, TYPE, KERNEL_CLASS) \
ONNX_OPERATOR_VERSIONED_KERNEL_EX( \
OP_TYPE, kOnnxDomain, VERSION_FROM, VERSION_TO, kJsExecutionProvider, \
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<TYPE>()), \
#define JSEP_ELEMENTWISE_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS) \
ONNX_OPERATOR_KERNEL_EX( \
OP_TYPE, kOnnxDomain, VERSION, kJsExecutionProvider, \
KernelDefBuilder().TypeConstraint("T", JsepSupportedFloatTypes()), \
KERNEL_CLASS);

#define JSEP_ELEMENTWISE_VERSIONED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, KERNEL_CLASS) \
ONNX_OPERATOR_VERSIONED_KERNEL_EX( \
OP_TYPE, kOnnxDomain, VERSION_FROM, VERSION_TO, kJsExecutionProvider, \
KernelDefBuilder().TypeConstraint("T", JsepSupportedFloatTypes()), \
KERNEL_CLASS);

#define JSEP_ELEMENTWISE_MULTI_TYPED_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS) \
ONNX_OPERATOR_KERNEL_EX( \
OP_TYPE, kOnnxDomain, VERSION, kJsExecutionProvider, \
KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(), \
DataTypeImpl::GetTensorType<MLFloat16>(), \
DataTypeImpl::GetTensorType<int32_t>()}), \
KERNEL_CLASS);

#define JSEP_ELEMENTWISE_MULTI_TYPED_VERSIONED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, KERNEL_CLASS) \
ONNX_OPERATOR_VERSIONED_KERNEL_EX( \
OP_TYPE, kOnnxDomain, VERSION_FROM, VERSION_TO, kJsExecutionProvider, \
KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(), \
DataTypeImpl::GetTensorType<MLFloat16>(), \
DataTypeImpl::GetTensorType<int32_t>()}), \
KERNEL_CLASS);
// math
Expand All @@ -42,115 +50,115 @@ JSEP_ELEMENTWISE_MULTI_TYPED_VERSIONED_KERNEL(Neg, 6, 12, Neg)
JSEP_ELEMENTWISE_MULTI_TYPED_KERNEL(Neg, 13, Neg)

JSEP_KERNEL_IMPL(Floor, Floor)
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Floor, 6, 12, float, Floor)
JSEP_ELEMENTWISE_KERNEL(Floor, 13, float, Floor)
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Floor, 6, 12, Floor)
JSEP_ELEMENTWISE_KERNEL(Floor, 13, Floor)

JSEP_KERNEL_IMPL(Ceil, Ceil)
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Ceil, 6, 12, float, Ceil)
JSEP_ELEMENTWISE_KERNEL(Ceil, 13, float, Ceil)
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Ceil, 6, 12, Ceil)
JSEP_ELEMENTWISE_KERNEL(Ceil, 13, Ceil)

JSEP_KERNEL_IMPL(Reciprocal, Reciprocal)
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Reciprocal, 6, 12, float, Reciprocal)
JSEP_ELEMENTWISE_KERNEL(Reciprocal, 13, float, Reciprocal)
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Reciprocal, 6, 12, Reciprocal)
JSEP_ELEMENTWISE_KERNEL(Reciprocal, 13, Reciprocal)

JSEP_KERNEL_IMPL(Sqrt, Sqrt)
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Sqrt, 6, 12, float, Sqrt)
JSEP_ELEMENTWISE_KERNEL(Sqrt, 13, float, Sqrt)
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Sqrt, 6, 12, Sqrt)
JSEP_ELEMENTWISE_KERNEL(Sqrt, 13, Sqrt)

JSEP_KERNEL_IMPL(Exp, Exp)
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Exp, 6, 12, float, Exp)
JSEP_ELEMENTWISE_KERNEL(Exp, 13, float, Exp)
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Exp, 6, 12, Exp)
JSEP_ELEMENTWISE_KERNEL(Exp, 13, Exp)

JSEP_KERNEL_IMPL(Erf, Erf)
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Erf, 9, 12, float, Erf)
JSEP_ELEMENTWISE_KERNEL(Erf, 13, float, Erf)
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Erf, 9, 12, Erf)
JSEP_ELEMENTWISE_KERNEL(Erf, 13, Erf)

JSEP_KERNEL_IMPL(Sigmoid, Sigmoid)
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Sigmoid, 6, 12, float, Sigmoid)
JSEP_ELEMENTWISE_KERNEL(Sigmoid, 13, float, Sigmoid)
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Sigmoid, 6, 12, Sigmoid)
JSEP_ELEMENTWISE_KERNEL(Sigmoid, 13, Sigmoid)

JSEP_KERNEL_IMPL(Log, Log)
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Log, 6, 12, float, Log)
JSEP_ELEMENTWISE_KERNEL(Log, 13, float, Log)
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Log, 6, 12, Log)
JSEP_ELEMENTWISE_KERNEL(Log, 13, Log)

JSEP_KERNEL_IMPL(Sin, Sin)
JSEP_ELEMENTWISE_KERNEL(Sin, 7, float, Sin)
JSEP_ELEMENTWISE_KERNEL(Sin, 7, Sin)

JSEP_KERNEL_IMPL(Cos, Cos)
JSEP_ELEMENTWISE_KERNEL(Cos, 7, float, Cos)
JSEP_ELEMENTWISE_KERNEL(Cos, 7, Cos)

JSEP_KERNEL_IMPL(Tan, Tan)
JSEP_ELEMENTWISE_KERNEL(Tan, 7, float, Tan)
JSEP_ELEMENTWISE_KERNEL(Tan, 7, Tan)

JSEP_KERNEL_IMPL(Asin, Asin)
JSEP_ELEMENTWISE_KERNEL(Asin, 7, float, Asin)
JSEP_ELEMENTWISE_KERNEL(Asin, 7, Asin)

JSEP_KERNEL_IMPL(Acos, Acos)
JSEP_ELEMENTWISE_KERNEL(Acos, 7, float, Acos)
JSEP_ELEMENTWISE_KERNEL(Acos, 7, Acos)

JSEP_KERNEL_IMPL(Atan, Atan)
JSEP_ELEMENTWISE_KERNEL(Atan, 7, float, Atan)
JSEP_ELEMENTWISE_KERNEL(Atan, 7, Atan)

JSEP_KERNEL_IMPL(Sinh, Sinh)
JSEP_ELEMENTWISE_KERNEL(Sinh, 9, float, Sinh)
JSEP_ELEMENTWISE_KERNEL(Sinh, 9, Sinh)

JSEP_KERNEL_IMPL(Cosh, Cosh)
JSEP_ELEMENTWISE_KERNEL(Cosh, 9, float, Cosh)
JSEP_ELEMENTWISE_KERNEL(Cosh, 9, Cosh)

JSEP_KERNEL_IMPL(Asinh, Asinh)
JSEP_ELEMENTWISE_KERNEL(Asinh, 9, float, Asinh)
JSEP_ELEMENTWISE_KERNEL(Asinh, 9, Asinh)

JSEP_KERNEL_IMPL(Acosh, Acosh)
JSEP_ELEMENTWISE_KERNEL(Acosh, 9, float, Acosh)
JSEP_ELEMENTWISE_KERNEL(Acosh, 9, Acosh)

JSEP_KERNEL_IMPL(Atanh, Atanh)
JSEP_ELEMENTWISE_KERNEL(Atanh, 9, float, Atanh)
JSEP_ELEMENTWISE_KERNEL(Atanh, 9, Atanh)

JSEP_KERNEL_IMPL(Tanh, Tanh)
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Tanh, 6, 12, float, Tanh)
JSEP_ELEMENTWISE_KERNEL(Tanh, 13, float, Tanh)
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Tanh, 6, 12, Tanh)
JSEP_ELEMENTWISE_KERNEL(Tanh, 13, Tanh)

JSEP_KERNEL_IMPL(Not, Not)
JSEP_ELEMENTWISE_KERNEL(Not, 1, bool, Not)
JSEP_ELEMENTWISE_TYPED_KERNEL(Not, 1, bool, Not)

// activation

JSEP_CLASS_IMPL_ATTRIBUTE_FLOAT_2_DEFAULT(ClipV10, ClipV10, min, 3.402823e+38f, max, -3.402823e+38f)
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Clip, 6, 10, float, ClipV10)
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Clip, 6, 10, ClipV10)
JSEP_KERNEL_IMPL(Clip, Clip)
ONNX_OPERATOR_VERSIONED_KERNEL_EX(Clip, kOnnxDomain, 11, 11, kJsExecutionProvider,
KernelDefBuilder()
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
.TypeConstraint("T", JsepSupportedFloatTypes())
.InputMemoryType(OrtMemTypeCPU, 1)
.InputMemoryType(OrtMemTypeCPU, 2),
Clip);
ONNX_OPERATOR_VERSIONED_KERNEL_EX(Clip, kOnnxDomain, 12, 12, kJsExecutionProvider,
KernelDefBuilder()
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
.TypeConstraint("T", JsepSupportedFloatTypes())
.InputMemoryType(OrtMemTypeCPU, 1)
.InputMemoryType(OrtMemTypeCPU, 2),
Clip);
ONNX_OPERATOR_KERNEL_EX(Clip, kOnnxDomain, 13, kJsExecutionProvider,
KernelDefBuilder()
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
.TypeConstraint("T", JsepSupportedFloatTypes())
.InputMemoryType(OrtMemTypeCPU, 1)
.InputMemoryType(OrtMemTypeCPU, 2),
Clip);

JSEP_CLASS_IMPL_ATTRIBUTE_FLOAT_DEFAULT(Elu, Elu, alpha, 1.0)
JSEP_ELEMENTWISE_KERNEL(Elu, 6, float, Elu)
JSEP_ELEMENTWISE_KERNEL(Elu, 6, Elu)

JSEP_KERNEL_IMPL(Relu, Relu)
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Relu, 6, 12, float, Relu)
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Relu, 13, 13, float, Relu)
JSEP_ELEMENTWISE_KERNEL(Relu, 14, float, Relu)
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Relu, 6, 12, Relu)
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Relu, 13, 13, Relu)
JSEP_ELEMENTWISE_KERNEL(Relu, 14, Relu)

JSEP_CLASS_IMPL_ATTRIBUTE_FLOAT_DEFAULT(LeakyRelu, LeakyRelu, alpha, 0.01)
JSEP_ELEMENTWISE_VERSIONED_KERNEL(LeakyRelu, 6, 15, float, LeakyRelu)
JSEP_ELEMENTWISE_KERNEL(LeakyRelu, 16, float, LeakyRelu)
JSEP_ELEMENTWISE_VERSIONED_KERNEL(LeakyRelu, 6, 15, LeakyRelu)
JSEP_ELEMENTWISE_KERNEL(LeakyRelu, 16, LeakyRelu)

JSEP_CLASS_IMPL_ATTRIBUTE_FLOAT_DEFAULT(ThresholdedRelu, ThresholdedRelu, alpha, 1.0)
JSEP_ELEMENTWISE_KERNEL(ThresholdedRelu, 10, float, ThresholdedRelu)
JSEP_ELEMENTWISE_KERNEL(ThresholdedRelu, 10, ThresholdedRelu)

} // namespace js
} // namespace onnxruntime