Skip to content

Commit

Permalink
[ET-VK][ez] Implement rsqrt (#6472)
Browse files Browse the repository at this point in the history
Pull Request resolved: #6456

TSIA. This op is used in Llama model architecture.
ghstack-source-id: 249709740
@exported-using-ghexport

Differential Revision: [D64840505](https://our.internmc.facebook.com/intern/diff/D64840505/)

Co-authored-by: Stephen Jia <[email protected]>
  • Loading branch information
pytorchbot and SS-JIA authored Oct 24, 2024
1 parent cb25809 commit 5692203
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 0 deletions.
1 change: 1 addition & 0 deletions backends/vulkan/partitioner/supported_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def __contains__(self, op):
exir_ops.edge.aten.sigmoid.default,
exir_ops.edge.aten.sin.default,
exir_ops.edge.aten.sqrt.default,
exir_ops.edge.aten.rsqrt.default,
exir_ops.edge.aten.tanh.default,
exir_ops.edge.aten._to_copy.default,
# Matrix Multiplication
Expand Down
2 changes: 2 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ unary_op:
OPERATOR: sin(X)
- NAME: sqrt
OPERATOR: sqrt(X)
- NAME: rsqrt
OPERATOR: (1 / sqrt(X))
- NAME: tanh
OPERATOR: tanh(clamp(X, -15.0, 15.0))
- NAME: hardshrink
Expand Down
2 changes: 2 additions & 0 deletions backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ DEFINE_ACTIVATION_FN(neg);
DEFINE_ACTIVATION_FN(sigmoid);
DEFINE_ACTIVATION_FN(sin);
DEFINE_ACTIVATION_FN(sqrt);
DEFINE_ACTIVATION_FN(rsqrt);
DEFINE_ACTIVATION_FN(tanh);
DEFINE_CLAMP_FN(clamp);
DEFINE_CLAMP_FN(hardtanh);
Expand All @@ -149,6 +150,7 @@ REGISTER_OPERATORS {
VK_REGISTER_OP(aten.sigmoid.default, sigmoid);
VK_REGISTER_OP(aten.sin.default, sin);
VK_REGISTER_OP(aten.sqrt.default, sqrt);
VK_REGISTER_OP(aten.rsqrt.default, rsqrt);
VK_REGISTER_OP(aten.tanh.default, tanh);
VK_REGISTER_OP(aten.hardshrink.default, hardshrink);
VK_REGISTER_OP(aten.hardswish.default, hardswish);
Expand Down
1 change: 1 addition & 0 deletions backends/vulkan/test/op_tests/cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -988,6 +988,7 @@ def get_softmax_inputs():
@register_test_suite(
[
"aten.sqrt.default",
"aten.rsqrt.default",
"aten.exp.default",
"aten.hardshrink.default",
"aten.sin.default",
Expand Down

0 comments on commit 5692203

Please sign in to comment.