Skip to content

Commit

Permalink
[ET-VK] Rearranging code in permute op shader to reduce heavy math op…
Browse files Browse the repository at this point in the history
…s and improve performance. (#7095)

Pull Request resolved: #7014

This diff rearranges the permute op shader code in the ExecuTorch Vulkan backend to reduce heavy math operations and improve performance. The change also adds the `GL_EXT_shader_explicit_arithmetic_types_int16` extension to support explicit arithmetic types and changes the data type of the position variable to `u16vec3`.
ghstack-source-id: 255546339
@exported-using-ghexport

Differential Revision: [D66174765](https://our.internmc.facebook.com/intern/diff/D66174765/)

Co-authored-by: Vivek Trivedi <[email protected]>
  • Loading branch information
pytorchbot and trivedivivek authored Nov 26, 2024
1 parent b8fbc48 commit b9a1762
Showing 1 changed file with 22 additions and 14 deletions.
36 changes: 22 additions & 14 deletions backends/vulkan/runtime/graph/ops/glsl/permute.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,10 @@ layout(set = 0, binding = 4) uniform PRECISION restrict Block {

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require

void main() {
const ivec3 pos = ivec3(gl_GlobalInvocationID);
const u16vec3 pos = u16vec3(gl_GlobalInvocationID);

if (any(greaterThanEqual(pos, out_limits))) {
return;
Expand All @@ -46,28 +48,34 @@ void main() {
const int out_channel_4up = int(ch_info.x);
const int in_channel_4up = int(ch_info.y);
const int out_batch = int(sizes[3]);
const int max_dst_index = out_batch * out_channel_4up;
VEC4_T outval = VEC4_T(0.0);
ivec4 v = ivec4(0); // holds b,c,h,w

v[out_ndims[2]] = pos.y;
v[out_ndims[3]] = pos.x;

const int dst_index = pos.z << 2;
int dst_out_index = dst_index / out_channel_4up;
int dst_out_lane = dst_index % out_channel_4up;

for (int j = 0; j < 4; ++j) {
int dst_index = pos.z * 4 + j;
if (dst_index >= max_dst_index) {
for (int j = 0; j < 4; ++j, ++dst_out_lane) {
if (dst_out_index >= out_batch) {
// out of range
break;
}

ivec4 v = ivec4(0); // holds b,c,h,w
v[out_ndims[0]] = dst_index / out_channel_4up;
v[out_ndims[1]] = dst_index % out_channel_4up;
v[out_ndims[2]] = pos.y;
v[out_ndims[3]] = pos.x;
if (dst_out_lane == out_channel_4up) {
dst_out_lane = 0;
dst_out_index++;
}

v[out_ndims[0]] = dst_out_index;
v[out_ndims[1]] = dst_out_lane;

int src_index = v[0] * in_channel_4up + v[1];
int w = v[3];
int h = v[2];

VEC4_T inval = VEC4_T(texelFetch(image_in, ivec3(w, h, src_index / 4), 0));
outval[j] = inval[src_index % 4];
VEC4_T inval = VEC4_T(texelFetch(image_in, u16vec3(v[3], v[2], src_index >> 2), 0));
outval[j] = inval[src_index & 0x3];
}

imageStore(image_out, pos, outval);
Expand Down

0 comments on commit b9a1762

Please sign in to comment.