From 814324c4197c51ed644829512e8e074406e292f0 Mon Sep 17 00:00:00 2001 From: Vivek Trivedi <5340687+trivedivivek@users.noreply.github.com> Date: Mon, 11 Nov 2024 13:54:46 -0800 Subject: [PATCH] [ET-VK] Removing tile input storage variable in conv_pw op and fetching the data in main loop. Also unrolling the main loop for performance improvement. This diff removes the tile input storage array in_tex from the conv_pw op, along with the separate loop that filled it. Each input texel is now fetched inside the main accumulation loop, where it is consumed immediately. The main loop is also annotated with #pragma unroll. Both changes are intended to improve performance. Differential Revision: [D64767314](https://our.internmc.facebook.com/intern/diff/D64767314/) [ghstack-poisoned] --- .../vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl index 806562d950..fedbdb0b5b 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl @@ -82,17 +82,15 @@ void main() { // During prepacking, the weight tensor has been permuted so that the // channel (IC) dim is along the x-axis, and the batch (OC) dim is along // the z-axis. - vec4 in_tex[TILE_SIZE * TILE_SIZE]; const vec4 ktex_0 = texelFetch(t_kernel, u16vec2(z + 0, gpos.z), 0); const vec4 ktex_1 = texelFetch(t_kernel, u16vec2(z + 1, gpos.z), 0); const vec4 ktex_2 = texelFetch(t_kernel, u16vec2(z + 2, gpos.z), 0); const vec4 ktex_3 = texelFetch(t_kernel, u16vec2(z + 3, gpos.z), 0); - for (int i = 0; i < TILE_SIZE * TILE_SIZE; ++i) { - in_tex[i] = texelFetch(t_in, u16vec3(ipos[i], z4), 0); - } +#pragma unroll for (int i = 0; i < TILE_SIZE * TILE_SIZE; ++i) { + const vec4 in_tex = texelFetch(t_in, u16vec3(ipos[i], z4), 0); // For 2x2 tile size algorithm works as follows. 
// To explain the calculations below, the contents of one in_tex and the // group of 4 texels loaded from t_kernel are shown: @@ -126,10 +124,10 @@ void main() { // // which is what is expressed in the following calculations. This is done // for each output position. - sum[i] = fma(in_tex[i].xxxx, ktex_0, sum[i]); - sum[i] = fma(in_tex[i].yyyy, ktex_1, sum[i]); - sum[i] = fma(in_tex[i].zzzz, ktex_2, sum[i]); - sum[i] = fma(in_tex[i].wwww, ktex_3, sum[i]); + sum[i] = fma(in_tex.xxxx, ktex_0, sum[i]); + sum[i] = fma(in_tex.yyyy, ktex_1, sum[i]); + sum[i] = fma(in_tex.zzzz, ktex_2, sum[i]); + sum[i] = fma(in_tex.wwww, ktex_3, sum[i]); } }