Skip to content

Commit

Permalink
[ET-VK][ez] Implement rsqrt
Browse files Browse the repository at this point in the history
Differential Revision: D64840505

Pull Request resolved: #6456
  • Loading branch information
SS-JIA authored Oct 23, 2024
1 parent 169ddbf commit 9a110cd
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 0 deletions.
1 change: 1 addition & 0 deletions backends/vulkan/partitioner/supported_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def __contains__(self, op):
exir_ops.edge.aten.sigmoid.default,
exir_ops.edge.aten.sin.default,
exir_ops.edge.aten.sqrt.default,
exir_ops.edge.aten.rsqrt.default,
exir_ops.edge.aten.tanh.default,
exir_ops.edge.aten._to_copy.default,
# Matrix Multiplication
Expand Down
2 changes: 2 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ unary_op:
OPERATOR: sin(X)
- NAME: sqrt
OPERATOR: sqrt(X)
- NAME: rsqrt
OPERATOR: (1 / sqrt(X))
- NAME: tanh
OPERATOR: tanh(clamp(X, -15.0, 15.0))
- NAME: hardshrink
Expand Down
2 changes: 2 additions & 0 deletions backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ DEFINE_ACTIVATION_FN(neg);
DEFINE_ACTIVATION_FN(sigmoid);
DEFINE_ACTIVATION_FN(sin);
DEFINE_ACTIVATION_FN(sqrt);
DEFINE_ACTIVATION_FN(rsqrt);
DEFINE_ACTIVATION_FN(tanh);
DEFINE_CLAMP_FN(clamp);
DEFINE_CLAMP_FN(hardtanh);
Expand All @@ -149,6 +150,7 @@ REGISTER_OPERATORS {
VK_REGISTER_OP(aten.sigmoid.default, sigmoid);
VK_REGISTER_OP(aten.sin.default, sin);
VK_REGISTER_OP(aten.sqrt.default, sqrt);
VK_REGISTER_OP(aten.rsqrt.default, rsqrt);
VK_REGISTER_OP(aten.tanh.default, tanh);
VK_REGISTER_OP(aten.hardshrink.default, hardshrink);
VK_REGISTER_OP(aten.hardswish.default, hardswish);
Expand Down
1 change: 1 addition & 0 deletions backends/vulkan/test/op_tests/cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -988,6 +988,7 @@ def get_softmax_inputs():
@register_test_suite(
[
"aten.sqrt.default",
"aten.rsqrt.default",
"aten.exp.default",
"aten.hardshrink.default",
"aten.sin.default",
Expand Down

0 comments on commit 9a110cd

Please sign in to comment.