From c749bd997a02c7b49cbdb9569f0286041d19db08 Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Thu, 6 Jun 2024 08:21:33 -0700 Subject: [PATCH] webgpu quickgelu (#20939) --- js/web/docs/webgpu-operators.md | 1 + .../lib/wasm/jsep/webgpu/op-resolve-rules.ts | 1 + js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts | 28 +++++++++++ js/web/test/data/ops/quick-gelu.jsonc | 46 +++++++++++++++++++ .../contrib_ops/js/js_contrib_kernels.cc | 2 + onnxruntime/contrib_ops/js/quick_gelu.cc | 23 ++++++++++ onnxruntime/contrib_ops/js/quick_gelu.h | 24 ++++++++++ 7 files changed, 125 insertions(+) create mode 100644 js/web/test/data/ops/quick-gelu.jsonc create mode 100644 onnxruntime/contrib_ops/js/quick_gelu.cc create mode 100644 onnxruntime/contrib_ops/js/quick_gelu.h diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md index 3af4942c2e4aa..919b005ec4c21 100644 --- a/js/web/docs/webgpu-operators.md +++ b/js/web/docs/webgpu-operators.md @@ -74,6 +74,7 @@ Do not modify directly.* | Not | ai.onnx(1+) | | | Pad | ai.onnx(2-10,11-12,13-17,18,19+) | | | Pow | ai.onnx(7-11,12,13-14,15+) | | +| QuickGelu | com.microsoft(1+) | | | Range | ai.onnx(11+) | | | Reciprocal | ai.onnx(6-12,13+) | | | ReduceL1 | ai.onnx(1-10,11-12,13-17,18+) | | diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts index 2d2f345d0c273..ce5b4455fde60 100644 --- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -107,6 +107,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['Not', [unaryOps.not]], ['Pad', [pad]], ['Pow', [binaryOps.pow]], + ['QuickGelu', [unaryOps.quickgelu, unaryOps.parseAlphaAttributes]], ['Range', [range]], ['Reciprocal', [unaryOps.reciprocal]], ['ReduceMin', [reduceMin]], diff --git a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts index 5f105c745739e..12ba2a10cdf9f 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts @@ -314,3 +314,31 @@ export const thresholdedRelu = (context: ComputeContext, attributes: AlphaAttrib export const log = (context: ComputeContext): void => { context.compute(createElementwiseProgramInfo(context.inputs[0], 'Log', 'log')); }; + +export const quickGeluImpl = (varType: string, alpha: number) => ` +const alpha = vec4<${varType}>(${alpha}); +const one = ${varType}(1.0); +const zero = ${varType}(0.0); + +fn quick_gelu_impl(x: vec4<${varType}>) -> vec4<${varType}> { + let v = x *alpha; + var x1 : vec4<${varType}>; + for (var i = 0; i < 4; i = i + 1) { + if (v[i] >= zero) { + x1[i] = one / (one + exp(-v[i])); + } else { + x1[i] = one - one / (one + exp(v[i])); + } + } + return x * x1; +} +`; + +export const quickGeluExpression = (x: string) => `quick_gelu_impl(${x})`; + +export const quickgelu = (context: ComputeContext, attributes: AlphaAttributes): void => { + const dType = tensorTypeToWsglValueType(context.inputs[0].dataType); + context.compute(createElementwiseProgramInfo( + context.inputs[0], 'QuickGelu', quickGeluExpression, quickGeluImpl(dType, attributes.alpha), attributes.cacheKey, + context.inputs[0].dataType)); +}; diff --git a/js/web/test/data/ops/quick-gelu.jsonc b/js/web/test/data/ops/quick-gelu.jsonc new file mode 100644 index 0000000000000..a6e618fe34796 --- /dev/null +++ b/js/web/test/data/ops/quick-gelu.jsonc @@ -0,0 +1,46 @@ +[ + { + "name": "QuickGelu test", + "operator": "QuickGelu", + "opset": { "domain": "com.microsoft", "version": 1 }, + "cases": [ + { + "name": "[2x4]", + "inputs": [ + { + "data": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, -0.8], + "dims": [2, 4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0.0542447, 0.116857, 0.187484, 0.265566, 0.350388, 0.441123, 0.53689, 0.636815], + "dims": [2, 4], + "type": "float32" + } + ] + }, + { + "name": "[3x5]", + "inputs": [ + { + "data": [0.1, 0.2, 0.3, 0.4, 0.5, 1, 2, 3, 4, 5, 1.1, 1.2, 1.3, 1.4, -1.5], + "dims": [3, 5], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 0.0542447, 0.116857, 0.187484, 0.265566, 0.350388, 0.845795, 1.9356, 2.98192, 3.99558, 4.99899, 0.953383, + 1.0622, 1.17178, 1.2817, 1.39166 + ], + "dims": [3, 5], + "type": "float32" + } + ] + } + ] + } +] diff --git a/onnxruntime/contrib_ops/js/js_contrib_kernels.cc b/onnxruntime/contrib_ops/js/js_contrib_kernels.cc index 9d8f79c67d8a4..7bc3414c89978 100644 --- a/onnxruntime/contrib_ops/js/js_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/js/js_contrib_kernels.cc @@ -16,6 +16,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Gelu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, GroupQueryAttention); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, MatMulNBits); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, MultiHeadAttention); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, QuickGelu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, RotaryEmbedding); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, SkipLayerNormalization); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, SimplifiedLayerNormalization); @@ -38,6 +39,7 @@ Status RegisterJsContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/js/quick_gelu.cc b/onnxruntime/contrib_ops/js/quick_gelu.cc new file mode 100644 index 0000000000000..4bb4d5afd4109 --- /dev/null +++ b/onnxruntime/contrib_ops/js/quick_gelu.cc @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "quick_gelu.h" + +namespace onnxruntime { +namespace contrib { +namespace js { + +using onnxruntime::js::JsepSupportedFloatTypes; + +ONNX_OPERATOR_KERNEL_EX( + QuickGelu, + kMSDomain, + 1, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", JsepSupportedFloatTypes()), + QuickGelu); + +} // namespace js +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/js/quick_gelu.h b/onnxruntime/contrib_ops/js/quick_gelu.h new file mode 100644 index 0000000000000..51e39e2718d51 --- /dev/null +++ b/onnxruntime/contrib_ops/js/quick_gelu.h @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/js/js_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace js { + +using onnxruntime::js::JsKernel; + +class QuickGelu final : public JsKernel { + public: + explicit QuickGelu(const OpKernelInfo& info) : JsKernel(info) { + float alpha = info.GetAttrOrDefault("alpha", 1.0); + JSEP_INIT_KERNEL_ATTRIBUTE(QuickGelu, ({"alpha" : $1}), alpha); + } +}; + +} // namespace js +} // namespace contrib +} // namespace onnxruntime