Skip to content

Commit

Permalink
webgpu quickgelu (#20939)
Browse files Browse the repository at this point in the history
  • Loading branch information
guschmue authored Jun 6, 2024
1 parent 5b87544 commit c749bd9
Show file tree
Hide file tree
Showing 7 changed files with 125 additions and 0 deletions.
1 change: 1 addition & 0 deletions js/web/docs/webgpu-operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ Do not modify directly.*
| Not | ai.onnx(1+) | |
| Pad | ai.onnx(2-10,11-12,13-17,18,19+) | |
| Pow | ai.onnx(7-11,12,13-14,15+) | |
| QuickGelu | com.microsoft(1+) | |
| Range | ai.onnx(11+) | |
| Reciprocal | ai.onnx(6-12,13+) | |
| ReduceL1 | ai.onnx(1-10,11-12,13-17,18+) | |
Expand Down
1 change: 1 addition & 0 deletions js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
['Not', [unaryOps.not]],
['Pad', [pad]],
['Pow', [binaryOps.pow]],
['QuickGelu', [unaryOps.quickgelu, unaryOps.parseAlphaAttributes]],
['Range', [range]],
['Reciprocal', [unaryOps.reciprocal]],
['ReduceMin', [reduceMin]],
Expand Down
28 changes: 28 additions & 0 deletions js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts
Original file line number Diff line number Diff line change
Expand Up @@ -314,3 +314,31 @@ export const thresholdedRelu = (context: ComputeContext, attributes: AlphaAttrib
/**
 * Dispatches the elementwise 'Log' program for the first input tensor.
 *
 * The program is built by `createElementwiseProgramInfo` with the expression
 * name 'log' (presumably the WGSL builtin of that name; confirm against the helper).
 */
export const log = (context: ComputeContext): void => {
  const input = context.inputs[0];
  const programInfo = createElementwiseProgramInfo(input, 'Log', 'log');
  context.compute(programInfo);
};

/**
 * Builds the WGSL helper source used by the QuickGelu operator.
 *
 * The generated code declares module-scope constants `alpha`, `one` and `zero`
 * and a function `quick_gelu_impl(x)` that returns `x * sigmoid(alpha * x)` for a
 * 4-component vector. The sigmoid is evaluated per component with a branch on the
 * sign of `alpha * x`, so `exp` is only ever called with a non-positive argument.
 *
 * @param varType WGSL scalar type name interpolated into the shader (e.g. `f32`).
 * @param alpha scaling factor applied to the input before the sigmoid; it is
 *     interpolated using JavaScript's default number-to-string conversion.
 * @returns the WGSL source text, starting and ending with a newline.
 */
export const quickGeluImpl = (varType: string, alpha: number) => {
  const vecType = `vec4<${varType}>`;
  const shaderLines = [
    '',
    `const alpha = ${vecType}(${alpha});`,
    `const one = ${varType}(1.0);`,
    `const zero = ${varType}(0.0);`,
    `fn quick_gelu_impl(x: ${vecType}) -> ${vecType} {`,
    'let v = x *alpha;',
    `var x1 : ${vecType};`,
    'for (var i = 0; i < 4; i = i + 1) {',
    // Non-negative branch: plain logistic form.
    'if (v[i] >= zero) {',
    'x1[i] = one / (one + exp(-v[i]));',
    // Negative branch: equivalent form that keeps the exp argument negative.
    '} else {',
    'x1[i] = one - one / (one + exp(v[i]));',
    '}',
    '}',
    'return x * x1;',
    '}',
    '',
  ];
  return shaderLines.join('\n');
};

/**
 * Wraps a WGSL expression in a call to `quick_gelu_impl`.
 *
 * @param x WGSL expression text to pass as the single argument.
 * @returns the call expression as WGSL source text.
 */
export const quickGeluExpression = (x: string) => {
  return 'quick_gelu_impl(' + x + ')';
};

/**
 * Runs the QuickGelu operator on the first input tensor.
 *
 * The WGSL value type is derived from the input's data type, the helper source is
 * generated by `quickGeluImpl` with `attributes.alpha`, and the resulting
 * elementwise program is keyed by `attributes.cacheKey`.
 *
 * @param context compute context supplying the input tensor and the dispatch entry point.
 * @param attributes operator attributes; `alpha` and `cacheKey` are read here.
 */
export const quickgelu = (context: ComputeContext, attributes: AlphaAttributes): void => {
  const input = context.inputs[0];
  const valueType = tensorTypeToWsglValueType(input.dataType);
  const helperSource = quickGeluImpl(valueType, attributes.alpha);
  const programInfo = createElementwiseProgramInfo(
      input, 'QuickGelu', quickGeluExpression, helperSource, attributes.cacheKey, input.dataType);
  context.compute(programInfo);
};
46 changes: 46 additions & 0 deletions js/web/test/data/ops/quick-gelu.jsonc
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
[
{
"name": "QuickGelu test",
"operator": "QuickGelu",
"opset": { "domain": "com.microsoft", "version": 1 },
"cases": [
{
"name": "[2x4]",
"inputs": [
{
"data": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, -0.8],
"dims": [2, 4],
"type": "float32"
}
],
"outputs": [
{
"data": [0.0542447, 0.116857, 0.187484, 0.265566, 0.350388, 0.441123, 0.53689, -0.163185],
"dims": [2, 4],
"type": "float32"
}
]
},
{
"name": "[3x5]",
"inputs": [
{
"data": [0.1, 0.2, 0.3, 0.4, 0.5, 1, 2, 3, 4, 5, 1.1, 1.2, 1.3, 1.4, -1.5],
"dims": [3, 5],
"type": "float32"
}
],
"outputs": [
{
"data": [
0.0542447, 0.116857, 0.187484, 0.265566, 0.350388, 0.845795, 1.9356, 2.98192, 3.99558, 4.99899, 0.953383,
1.0622, 1.17178, 1.2817, -0.108338
],
"dims": [3, 5],
"type": "float32"
}
]
}
]
}
]
2 changes: 2 additions & 0 deletions onnxruntime/contrib_ops/js/js_contrib_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Gelu);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, GroupQueryAttention);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, MatMulNBits);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, MultiHeadAttention);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, QuickGelu);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, RotaryEmbedding);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, SkipLayerNormalization);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, SimplifiedLayerNormalization);
Expand All @@ -38,6 +39,7 @@ Status RegisterJsContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, GroupQueryAttention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, MatMulNBits)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, MultiHeadAttention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, QuickGelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, RotaryEmbedding)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1,
SkipLayerNormalization)>,
Expand Down
23 changes: 23 additions & 0 deletions onnxruntime/contrib_ops/js/quick_gelu.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "quick_gelu.h"

namespace onnxruntime {
namespace contrib {
namespace js {

using onnxruntime::js::JsepSupportedFloatTypes;

// Kernel registration for the QuickGelu contrib op: domain kMSDomain, version 1,
// execution provider kJsExecutionProvider. The "T" type constraint is limited to
// JsepSupportedFloatTypes(), and the implementing class is QuickGelu.
ONNX_OPERATOR_KERNEL_EX(
QuickGelu,
kMSDomain,
1,
kJsExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", JsepSupportedFloatTypes()),
QuickGelu);

} // namespace js
} // namespace contrib
} // namespace onnxruntime
24 changes: 24 additions & 0 deletions onnxruntime/contrib_ops/js/quick_gelu.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/providers/js/js_kernel.h"

namespace onnxruntime {
namespace contrib {
namespace js {

using onnxruntime::js::JsKernel;

// JS execution provider kernel for the com.microsoft QuickGelu contrib op.
//
// The constructor reads the "alpha" attribute and forwards it to the JavaScript
// side as the kernel attribute object {"alpha": <value>}.
class QuickGelu final : public JsKernel {
 public:
  explicit QuickGelu(const OpKernelInfo& info) : JsKernel(info) {
    // Fall back to the operator's documented default (1.702) rather than 1.0 so
    // that a node without an explicit "alpha" behaves as the operator specifies,
    // independent of whether defaults were filled in earlier.
    float alpha = info.GetAttrOrDefault<float>("alpha", 1.702f);
    JSEP_INIT_KERNEL_ATTRIBUTE(QuickGelu, ({"alpha" : $1}), alpha);
  }
};

} // namespace js
} // namespace contrib
} // namespace onnxruntime

0 comments on commit c749bd9

Please sign in to comment.