diff --git a/backends/vulkan/runtime/graph/ops/glsl/permute.glsl b/backends/vulkan/runtime/graph/ops/glsl/permute.glsl index 5378099d03..59d6aecdc1 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/permute.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/permute.glsl @@ -19,15 +19,9 @@ layout(std430) buffer; layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out; layout(set = 0, binding = 1) uniform PRECISION ${SAMPLER_T[NDIM][DTYPE]} image_in; -layout(set = 0, binding = 2) uniform PRECISION restrict OutLimits { - ivec3 out_limits; -}; - -layout(set = 0, binding = 3) uniform PRECISION restrict Sizes { +layout(push_constant) uniform PRECISION restrict Block { + ivec4 out_limits; ivec4 sizes; -}; - -layout(set = 0, binding = 4) uniform PRECISION restrict Block { // output dims ivec4 out_ndims; // x = output channels aligned to 4, y = input channels aligned to 4 @@ -41,7 +35,7 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; void main() { const u16vec3 pos = u16vec3(gl_GlobalInvocationID); - if (any(greaterThanEqual(pos, out_limits))) { + if (any(greaterThanEqual(pos, out_limits.xyz))) { return; } diff --git a/backends/vulkan/runtime/graph/ops/impl/Permute.cpp b/backends/vulkan/runtime/graph/ops/impl/Permute.cpp index c107f288f3..a56925751e 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Permute.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Permute.cpp @@ -75,13 +75,7 @@ void add_permute_node( int32_t out_c_aligned = utils::align_up_4(out_channels); int32_t in_c_aligned = utils::align_up_4(in_channels); - const struct Block final { - ivec4 out_ndims; - ivec2 ch_info; - } params{ - out_dims, - {out_c_aligned, in_c_aligned}, - }; + const ivec2 ch_info = {out_c_aligned, in_c_aligned}; graph.execute_nodes().emplace_back(new DispatchNode( graph, @@ -90,14 +84,16 @@ void add_permute_node( graph.create_local_wg_size(out), {{out, vkapi::MemoryAccessType::WRITE}, {in, vkapi::MemoryAccessType::READ}}, - {t_out->logical_limits_ubo(), - t_out->sizes_ubo(), - graph.create_params_buffer(params)}, + {}, // Specialization Constants {}, // Resizing Logic nullptr, - {})); + {}, + {{graph.logical_limits_pc_of(out), + graph.sizes_pc_of(out), + PushConstantDataInfo(&out_dims, sizeof(out_dims)), + PushConstantDataInfo(&ch_info, sizeof(ch_info))}})); } void add_permute_node(