From 7ce75261ff965899c1c620cb8ab35bef7177d1cb Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Wed, 23 Oct 2024 09:45:15 -0700 Subject: [PATCH 1/3] [ET-VK][ez] Implement rsqrt TSIA. This op is used in Llama model architecture. Differential Revision: [D64840505](https://our.internmc.facebook.com/intern/diff/D64840505/) [ghstack-poisoned] --- backends/vulkan/partitioner/supported_ops.py | 1 + backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml | 2 ++ backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp | 2 ++ backends/vulkan/test/op_tests/cases.py | 1 + 4 files changed, 6 insertions(+) diff --git a/backends/vulkan/partitioner/supported_ops.py b/backends/vulkan/partitioner/supported_ops.py index 92699be0f8..09759b0d0e 100644 --- a/backends/vulkan/partitioner/supported_ops.py +++ b/backends/vulkan/partitioner/supported_ops.py @@ -76,6 +76,7 @@ def __contains__(self, op): exir_ops.edge.aten.sigmoid.default, exir_ops.edge.aten.sin.default, exir_ops.edge.aten.sqrt.default, + exir_ops.edge.aten.rsqrt.default, exir_ops.edge.aten.tanh.default, exir_ops.edge.aten._to_copy.default, # Matrix Multiplication diff --git a/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml b/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml index 2b9f0032f4..77a334a05e 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml @@ -32,6 +32,8 @@ unary_op: OPERATOR: sin(X) - NAME: sqrt OPERATOR: sqrt(X) + - NAME: rsqrt + OPERATOR: (1 / sqrt(X)) - NAME: tanh OPERATOR: tanh(clamp(X, -15.0, 15.0)) - NAME: hardshrink diff --git a/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp b/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp index 64df1aff85..62922e8d9e 100644 --- a/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp @@ -129,6 +129,7 @@ DEFINE_ACTIVATION_FN(neg); DEFINE_ACTIVATION_FN(sigmoid); DEFINE_ACTIVATION_FN(sin); DEFINE_ACTIVATION_FN(sqrt); +DEFINE_ACTIVATION_FN(rsqrt); DEFINE_ACTIVATION_FN(tanh); DEFINE_CLAMP_FN(clamp); DEFINE_CLAMP_FN(hardtanh); @@ -149,6 +150,7 @@ REGISTER_OPERATORS { VK_REGISTER_OP(aten.sigmoid.default, sigmoid); VK_REGISTER_OP(aten.sin.default, sin); VK_REGISTER_OP(aten.sqrt.default, sqrt); + VK_REGISTER_OP(aten.rsqrt.default, rsqrt); VK_REGISTER_OP(aten.tanh.default, tanh); VK_REGISTER_OP(aten.hardshrink.default, hardshrink); VK_REGISTER_OP(aten.hardswish.default, hardswish); diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py index 889d3282aa..fb30522209 100644 --- a/backends/vulkan/test/op_tests/cases.py +++ b/backends/vulkan/test/op_tests/cases.py @@ -988,6 +988,7 @@ def get_softmax_inputs(): @register_test_suite( [ "aten.sqrt.default", + "aten.rsqrt.default", "aten.exp.default", "aten.hardshrink.default", "aten.sin.default", From 10a879fc8182428b2cbe273d64c917cfd8cfab09 Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Wed, 23 Oct 2024 10:58:08 -0700 Subject: [PATCH 2/3] Update on "[ET-VK][ez] Implement rsqrt" TSIA. This op is used in Llama model architecture. Differential Revision: [D64840505](https://our.internmc.facebook.com/intern/diff/D64840505/) [ghstack-poisoned] From 48e36d92f2b6e450b6d49f96646504d066e23f14 Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Wed, 23 Oct 2024 14:47:06 -0700 Subject: [PATCH 3/3] Update on "[ET-VK][ez] Implement rsqrt" TSIA. This op is used in Llama model architecture. Differential Revision: [D64840505](https://our.internmc.facebook.com/intern/diff/D64840505/) [ghstack-poisoned]