diff --git a/backends/vulkan/partitioner/supported_ops.py b/backends/vulkan/partitioner/supported_ops.py index 92699be0f8..09759b0d0e 100644 --- a/backends/vulkan/partitioner/supported_ops.py +++ b/backends/vulkan/partitioner/supported_ops.py @@ -76,6 +76,7 @@ def __contains__(self, op): exir_ops.edge.aten.sigmoid.default, exir_ops.edge.aten.sin.default, exir_ops.edge.aten.sqrt.default, + exir_ops.edge.aten.rsqrt.default, exir_ops.edge.aten.tanh.default, exir_ops.edge.aten._to_copy.default, # Matrix Multiplication diff --git a/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml b/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml index 2b9f0032f4..77a334a05e 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml @@ -32,6 +32,8 @@ unary_op: OPERATOR: sin(X) - NAME: sqrt OPERATOR: sqrt(X) + - NAME: rsqrt + OPERATOR: (1 / sqrt(X)) - NAME: tanh OPERATOR: tanh(clamp(X, -15.0, 15.0)) - NAME: hardshrink diff --git a/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp b/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp index 64df1aff85..62922e8d9e 100644 --- a/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp @@ -129,6 +129,7 @@ DEFINE_ACTIVATION_FN(neg); DEFINE_ACTIVATION_FN(sigmoid); DEFINE_ACTIVATION_FN(sin); DEFINE_ACTIVATION_FN(sqrt); +DEFINE_ACTIVATION_FN(rsqrt); DEFINE_ACTIVATION_FN(tanh); DEFINE_CLAMP_FN(clamp); DEFINE_CLAMP_FN(hardtanh); @@ -149,6 +150,7 @@ REGISTER_OPERATORS { VK_REGISTER_OP(aten.sigmoid.default, sigmoid); VK_REGISTER_OP(aten.sin.default, sin); VK_REGISTER_OP(aten.sqrt.default, sqrt); + VK_REGISTER_OP(aten.rsqrt.default, rsqrt); VK_REGISTER_OP(aten.tanh.default, tanh); VK_REGISTER_OP(aten.hardshrink.default, hardshrink); VK_REGISTER_OP(aten.hardswish.default, hardswish); diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py index 889d3282aa..fb30522209 100644 --- a/backends/vulkan/test/op_tests/cases.py +++ b/backends/vulkan/test/op_tests/cases.py @@ -988,6 +988,7 @@ def get_softmax_inputs(): @register_test_suite( [ "aten.sqrt.default", + "aten.rsqrt.default", "aten.exp.default", "aten.hardshrink.default", "aten.sin.default",