From 97b5a4a33d33226d9cc28b3e62cd85219f91902d Mon Sep 17 00:00:00 2001 From: Vivek Trivedi <5340687+trivedivivek@users.noreply.github.com> Date: Tue, 26 Nov 2024 13:01:49 -0800 Subject: [PATCH] [ET-VK] Rearranging code in permute op shader to reduce heavy math ops and improve performance. Pull Request resolved: https://github.com/pytorch/executorch/pull/7014 The diff rearranges Permute op shader code in executorch vulkan backend to reduce heavy math operations and improve performance. The change also include adding an extension to support explicit arithmetic types and changing the data type of the position variable to u16vec3. ghstack-source-id: 255546339 @exported-using-ghexport Differential Revision: [D66174765](https://our.internmc.facebook.com/intern/diff/D66174765/) --- .../runtime/graph/ops/glsl/permute.glsl | 36 +++++++++++-------- 1 file changed, 22 insertions(+), 14 deletions(-) diff --git a/backends/vulkan/runtime/graph/ops/glsl/permute.glsl b/backends/vulkan/runtime/graph/ops/glsl/permute.glsl index 8414d811fc..5378099d03 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/permute.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/permute.glsl @@ -36,8 +36,10 @@ layout(set = 0, binding = 4) uniform PRECISION restrict Block { layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; +#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require + void main() { - const ivec3 pos = ivec3(gl_GlobalInvocationID); + const u16vec3 pos = u16vec3(gl_GlobalInvocationID); if (any(greaterThanEqual(pos, out_limits))) { return; @@ -46,28 +48,34 @@ void main() { const int out_channel_4up = int(ch_info.x); const int in_channel_4up = int(ch_info.y); const int out_batch = int(sizes[3]); - const int max_dst_index = out_batch * out_channel_4up; VEC4_T outval = VEC4_T(0.0); + ivec4 v = ivec4(0); // holds b,c,h,w + + v[out_ndims[2]] = pos.y; + v[out_ndims[3]] = pos.x; + + const int dst_index = pos.z << 2; + int dst_out_index = dst_index / out_channel_4up; + int dst_out_lane = dst_index % out_channel_4up; - for (int j = 0; j < 4; ++j) { - int dst_index = pos.z * 4 + j; - if (dst_index >= max_dst_index) { + for (int j = 0; j < 4; ++j, ++dst_out_lane) { + if (dst_out_index >= out_batch) { // out of range break; } - ivec4 v = ivec4(0); // holds b,c,h,w - v[out_ndims[0]] = dst_index / out_channel_4up; - v[out_ndims[1]] = dst_index % out_channel_4up; - v[out_ndims[2]] = pos.y; - v[out_ndims[3]] = pos.x; + if (dst_out_lane == out_channel_4up) { + dst_out_lane = 0; + dst_out_index++; + } + + v[out_ndims[0]] = dst_out_index; + v[out_ndims[1]] = dst_out_lane; int src_index = v[0] * in_channel_4up + v[1]; - int w = v[3]; - int h = v[2]; - VEC4_T inval = VEC4_T(texelFetch(image_in, ivec3(w, h, src_index / 4), 0)); - outval[j] = inval[src_index % 4]; + VEC4_T inval = VEC4_T(texelFetch(image_in, u16vec3(v[3], v[2], src_index >> 2), 0)); + outval[j] = inval[src_index & 0x3]; } imageStore(image_out, pos, outval);