From 5692203783f6d63448e58b0558380b49b47daad3 Mon Sep 17 00:00:00 2001 From: pytorchbot Date: Wed, 23 Oct 2024 19:23:31 -0700 Subject: [PATCH] [ET-VK][ez] Implement rsqrt (#6472) Pull Request resolved: https://github.com/pytorch/executorch/pull/6456 TSIA. This op is used in Llama model architecture. ghstack-source-id: 249709740 @exported-using-ghexport Differential Revision: [D64840505](https://our.internmc.facebook.com/intern/diff/D64840505/) Co-authored-by: Stephen Jia --- backends/vulkan/partitioner/supported_ops.py | 1 + backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml | 2 ++ backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp | 2 ++ backends/vulkan/test/op_tests/cases.py | 1 + 4 files changed, 6 insertions(+) diff --git a/backends/vulkan/partitioner/supported_ops.py b/backends/vulkan/partitioner/supported_ops.py index 92699be0f8..09759b0d0e 100644 --- a/backends/vulkan/partitioner/supported_ops.py +++ b/backends/vulkan/partitioner/supported_ops.py @@ -76,6 +76,7 @@ def __contains__(self, op): exir_ops.edge.aten.sigmoid.default, exir_ops.edge.aten.sin.default, exir_ops.edge.aten.sqrt.default, + exir_ops.edge.aten.rsqrt.default, exir_ops.edge.aten.tanh.default, exir_ops.edge.aten._to_copy.default, # Matrix Multiplication diff --git a/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml b/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml index 2b9f0032f4..77a334a05e 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml @@ -32,6 +32,8 @@ unary_op: OPERATOR: sin(X) - NAME: sqrt OPERATOR: sqrt(X) + - NAME: rsqrt + OPERATOR: (1 / sqrt(X)) - NAME: tanh OPERATOR: tanh(clamp(X, -15.0, 15.0)) - NAME: hardshrink diff --git a/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp b/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp index 64df1aff85..62922e8d9e 100644 --- a/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp @@ -129,6 +129,7 @@ DEFINE_ACTIVATION_FN(neg); DEFINE_ACTIVATION_FN(sigmoid); DEFINE_ACTIVATION_FN(sin); DEFINE_ACTIVATION_FN(sqrt); +DEFINE_ACTIVATION_FN(rsqrt); DEFINE_ACTIVATION_FN(tanh); DEFINE_CLAMP_FN(clamp); DEFINE_CLAMP_FN(hardtanh); @@ -149,6 +150,7 @@ REGISTER_OPERATORS { VK_REGISTER_OP(aten.sigmoid.default, sigmoid); VK_REGISTER_OP(aten.sin.default, sin); VK_REGISTER_OP(aten.sqrt.default, sqrt); + VK_REGISTER_OP(aten.rsqrt.default, rsqrt); VK_REGISTER_OP(aten.tanh.default, tanh); VK_REGISTER_OP(aten.hardshrink.default, hardshrink); VK_REGISTER_OP(aten.hardswish.default, hardswish); diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py index 889d3282aa..fb30522209 100644 --- a/backends/vulkan/test/op_tests/cases.py +++ b/backends/vulkan/test/op_tests/cases.py @@ -988,6 +988,7 @@ def get_softmax_inputs(): @register_test_suite( [ "aten.sqrt.default", + "aten.rsqrt.default", "aten.exp.default", "aten.hardshrink.default", "aten.sin.default",