From 3e174c24cd3f4dd7275f4afd406346645f0ff493 Mon Sep 17 00:00:00 2001 From: trivedivivek <5340687+trivedivivek@users.noreply.github.com> Date: Mon, 11 Nov 2024 19:18:57 -0600 Subject: [PATCH] [ET-VK] Reduced int precision for texture coordinates in conv2d_pw op, to reduce shader register pressure and slightly improve performance. Differential Revision: D64766910 Pull Request resolved: https://github.com/pytorch/executorch/pull/6764 --- .../runtime/graph/ops/glsl/conv2d_pw.glsl | 24 ++++++++++--------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl index 9621a3b600..806562d950 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl @@ -32,13 +32,15 @@ ${layout_declare_ubo(8, "float", "out_min", "float", "out_max")} layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; +#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require + /* * Computes a 2D pointwise convolution of an NxN output tile. Calculating an * output tile for pointwise convolution is more efficient because the kernel * size is only 1x1, making it easier to re-use loaded texels from t_kernel. */ void main() { - const ivec3 gpos = ivec3(gl_GlobalInvocationID); + const u16vec3 gpos = u16vec3(gl_GlobalInvocationID); // Output position for TILE_SIZE = 2 // +--------+--------+ @@ -46,10 +48,10 @@ void main() { // +--------+--------+ // | pos[2] | pos[3] | // +--------+--------+ - ivec3 pos[TILE_SIZE * TILE_SIZE]; + u16vec3 pos[TILE_SIZE * TILE_SIZE]; for (int y = 0, i = 0; y < TILE_SIZE; ++y) { for (int x = 0; x < TILE_SIZE; ++x) { - pos[i] = ivec3( + pos[i] = u16vec3( gpos.x * TILE_SIZE + x, gpos.y * TILE_SIZE + y, gpos.z); i++; } @@ -64,13 +66,13 @@ void main() { // Compute the index of the input texture that needs to be loaded for each // output position. Note that negative indices can be produced indicating that // the top-left element is in a region added by padding. - ivec2 ipos[TILE_SIZE * TILE_SIZE]; + u16vec2 ipos[TILE_SIZE * TILE_SIZE]; for (int i = 0; i < TILE_SIZE * TILE_SIZE; ++i) { - ipos[i] = pos[i].xy * stride - padding; + ipos[i] = pos[i].xy * u16vec2(stride) - u16vec2(padding); } vec4 sum[TILE_SIZE * TILE_SIZE]; - sum[0] = texelFetch(t_bias, ivec2(gpos.z, 0), 0); + sum[0] = texelFetch(t_bias, u16vec2(gpos.z, 0), 0); for (int i = 1; i < TILE_SIZE * TILE_SIZE; ++i) { sum[i] = sum[0]; } @@ -81,13 +83,13 @@ void main() { // channel (IC) dim is along the x-axis, and the batch (OC) dim is along // the z-axis. vec4 in_tex[TILE_SIZE * TILE_SIZE]; - const vec4 ktex_0 = texelFetch(t_kernel, ivec2(z + 0, gpos.z), 0); - const vec4 ktex_1 = texelFetch(t_kernel, ivec2(z + 1, gpos.z), 0); - const vec4 ktex_2 = texelFetch(t_kernel, ivec2(z + 2, gpos.z), 0); - const vec4 ktex_3 = texelFetch(t_kernel, ivec2(z + 3, gpos.z), 0); + const vec4 ktex_0 = texelFetch(t_kernel, u16vec2(z + 0, gpos.z), 0); + const vec4 ktex_1 = texelFetch(t_kernel, u16vec2(z + 1, gpos.z), 0); + const vec4 ktex_2 = texelFetch(t_kernel, u16vec2(z + 2, gpos.z), 0); + const vec4 ktex_3 = texelFetch(t_kernel, u16vec2(z + 3, gpos.z), 0); for (int i = 0; i < TILE_SIZE * TILE_SIZE; ++i) { - in_tex[i] = texelFetch(t_in, ivec3(ipos[i], z4), 0); + in_tex[i] = texelFetch(t_in, u16vec3(ipos[i], z4), 0); } for (int i = 0; i < TILE_SIZE * TILE_SIZE; ++i) {