From 8140a90c609a5a5e8c3ffcf9e019beef75361de4 Mon Sep 17 00:00:00 2001 From: pytorchbot Date: Wed, 23 Oct 2024 19:24:03 -0700 Subject: [PATCH] [ET-VK] Implement generic reduction shader + mean, sum, amax, amin (#6473) Pull Request resolved: https://github.com/pytorch/executorch/pull/6457 ## Context Introduce a generic shader to compute reduction along a single dim, and `keepdim = True`. With the generic shader template, `mean`, `sum`, `amin`, and `amax` can be implemented. ghstack-source-id: 249709743 @exported-using-ghexport Differential Revision: [D64840504](https://our.internmc.facebook.com/intern/diff/D64840504/) --------- Co-authored-by: Stephen Jia --- backends/vulkan/TARGETS | 1 - backends/vulkan/partitioner/supported_ops.py | 7 +- .../vulkan/runtime/graph/ops/glsl/reduce.glsl | 214 ++++++++++++++++++ .../vulkan/runtime/graph/ops/glsl/reduce.yaml | 29 +++ .../runtime/graph/ops/glsl/sum_dim.glsl | 108 --------- .../runtime/graph/ops/glsl/sum_dim.yaml | 16 -- .../graph/ops/glsl/sum_dim_keepdim.glsl | 95 -------- .../graph/ops/glsl/sum_dim_keepdim.yaml | 16 -- .../vulkan/runtime/graph/ops/impl/Reduce.cpp | 123 ++++++++++ .../vulkan/runtime/graph/ops/impl/Sum.cpp | 159 ------------- backends/vulkan/test/op_tests/cases.py | 63 ++++-- backends/vulkan/test/test_vulkan_delegate.py | 9 + backends/vulkan/vulkan_preprocess.py | 2 - 13 files changed, 418 insertions(+), 424 deletions(-) create mode 100644 backends/vulkan/runtime/graph/ops/glsl/reduce.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/reduce.yaml delete mode 100644 backends/vulkan/runtime/graph/ops/glsl/sum_dim.glsl delete mode 100644 backends/vulkan/runtime/graph/ops/glsl/sum_dim.yaml delete mode 100644 backends/vulkan/runtime/graph/ops/glsl/sum_dim_keepdim.glsl delete mode 100644 backends/vulkan/runtime/graph/ops/glsl/sum_dim_keepdim.yaml create mode 100644 backends/vulkan/runtime/graph/ops/impl/Reduce.cpp delete mode 100644 backends/vulkan/runtime/graph/ops/impl/Sum.cpp diff --git a/backends/vulkan/TARGETS b/backends/vulkan/TARGETS index dd49512c08..4e0e83f276 100644 --- a/backends/vulkan/TARGETS +++ b/backends/vulkan/TARGETS @@ -27,7 +27,6 @@ runtime.python_library( "//executorch/backends/transforms:fuse_conv_with_clamp", "//executorch/backends/transforms:fuse_dequant_linear", "//executorch/backends/transforms:fuse_view_copy", - "//executorch/backends/transforms:mean_to_sum_div", "//executorch/backends/transforms:remove_clone_ops", "//executorch/backends/vulkan/_passes:vulkan_passes", "//executorch/exir:graph_module", diff --git a/backends/vulkan/partitioner/supported_ops.py b/backends/vulkan/partitioner/supported_ops.py index 09759b0d0e..5a85c5f0ec 100644 --- a/backends/vulkan/partitioner/supported_ops.py +++ b/backends/vulkan/partitioner/supported_ops.py @@ -89,6 +89,10 @@ def __contains__(self, op): # Reduction exir_ops.edge.aten._log_softmax.default, exir_ops.edge.aten._softmax.default, + exir_ops.edge.aten.mean.dim, + exir_ops.edge.aten.sum.dim_IntList, + exir_ops.edge.aten.amax.default, + exir_ops.edge.aten.amin.default, # 2D Pooling exir_ops.edge.aten.avg_pool2d.default, exir_ops.edge.aten.max_pool2d_with_indices.default, @@ -101,9 +105,6 @@ def __contains__(self, op): ] NO_DYNAMIC_SHAPE = [ - # Reduction - exir_ops.edge.aten.mean.dim, - exir_ops.edge.aten.sum.dim_IntList, # Normalization exir_ops.edge.aten._native_batch_norm_legit_no_training.default, exir_ops.edge.aten.native_layer_norm.default, diff --git a/backends/vulkan/runtime/graph/ops/glsl/reduce.glsl b/backends/vulkan/runtime/graph/ops/glsl/reduce.glsl new file mode 100644 index 0000000000..7a6263d9f5 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/reduce.glsl @@ -0,0 +1,214 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} +#define VEC4_T ${texel_load_type(DTYPE, STORAGE)} + +${define_active_storage_type(STORAGE)} + +#extension GL_EXT_control_flow_attributes : require + +layout(std430) buffer; + +${layout_declare_tensor(B, "w", "tout", DTYPE, STORAGE)} +${layout_declare_tensor(B, "r", "tin", DTYPE, STORAGE)} + +${layout_declare_ubo(B, "ivec3", "tin_limits")} +${layout_declare_ubo(B, "ivec4", "tin_sizes")} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +layout(constant_id = 3) const int packed_dim = 0; +layout(constant_id = 4) const int reduce_dim = 0; +layout(constant_id = 5) const int group_dim = 1; + +// A more verbose name would be NWORKERS_PER_GROUP. This describes the number of +// threads that will co-operate to compute one reduction output. There may be +// multiple groups computing distinct reduction outputs within one work group. +#define NWORKERS 4 + +// Sets an upper limit on the total size of a work group based on how many +// elements are allocated in the shared memory array below. Each thread in the +// work group will write into its assigned element in the shared array. +#define MAX_NTHREADS 16 + + +shared vec4 shared_vecs[MAX_NTHREADS]; + +#include "indexing_utils.h" + +int tid_to_smi(const ivec2 tid) { + return tid.x + tid.y * NWORKERS; +} + +/* + * The functions below compute reduction along a single dimension for a tensor. + * The shader template generalize reduction by abstracting the initial value of + * the accumulator, the calculation used to update the accumulator with new + * values, and a postprocessing calculation that can be used to modify the + * accumulator before writing to output. + * + * This shader also utilize shared memory to have multiple threads help compute + * the max and sum reduction operations. A total of NGROUPS x NWORKERS threads + * are expected to be launched. Each group works on a unique reduction "row", and + * within a group NWORKERS threads co-operate to compute the max and sum of one + * "row". Each worker in the group is responsible for computing a partial output + * of the "row" and uploading it to shared memory; the overall reduction output + * can then be determined by aggregating the partial outputs stored in shared + * memory. + * + * As a caveat, this shader does not currently support cases where `batch` > 1 + * and the reduce dim happens to also be the batch concatenation dim. To support + * this, there will need to be additional logic to set the starting value of + * `scan_pos[reduce_dim]`. Since this is not expected to be a common use-case, + * supporting this case is left as an exercise for when it is required. + */ + +// Initializing the accumulator accepts the first value in the reduction row, +// since some reduction operations (i.e. amax, amin) prefer to initialize with +// a data point instead of a static value. +#define INIT_ACCUM(first_val) ${INIT_ACCUM} +#define UPDATE_ACCUM(accum, new_val) ${UPDATE_ACCUM} +// Useful for operators such as mean which want to perform a final calculation +// with the accumulator. +#define POSTPROCESS(accum) ${POSTPROCESS} + +/* + * Computes reduction where the reduction dim is orthogonal to the packed dim. + * This case is simpler because each element of a texel belongs to a separate + * reduction "group", meaning we don't have to perform reduction along a texel. + */ +void reduce_nonpacked_dim(const ivec2 tid, ivec3 scan_pos) { + // shared memory index of this thread + const int smi = tid_to_smi(tid); + + scan_pos[reduce_dim] = 0; + vec4 accum = INIT_ACCUM(load_texel(tin, scan_pos)); + + scan_pos[reduce_dim] = tid.x; + // Partially accumulate over elements i, i + NWORKERS, i + 2*NWORKERS, ... of + // the reduction row + for (int i = tid.x; i < tin_sizes[reduce_dim]; + i += NWORKERS, scan_pos[reduce_dim] += NWORKERS) { + accum = UPDATE_ACCUM(accum, load_texel(tin, scan_pos)); + } + // Write partial output to shared memory and synchronize work group + shared_vecs[smi] = accum; + barrier(); + + // Since the reduction row is reduced to only one element, only the "main" + // thread in the group needs aggregate the partial outputs + if (tid.x == 0) { + // Iterate over the partial outputs to obtain the overall output + int group_i = tid.y * NWORKERS; + accum = shared_vecs[group_i++]; + for (int i = 1; i < NWORKERS; i++, group_i++) { + accum = UPDATE_ACCUM(accum, shared_vecs[group_i]); + } + + // Determine if there are any padding elements in the final texel of the + // packed dimension + const int nspill = mod4(tin_sizes[packed_dim]); + // Detect if this thread is working on the final texels of the packed + // dimension, which may have padding elements + const bool is_last_texel = + scan_pos[packed_dim] == (tin_limits[packed_dim] - 1); + + // Explicitly set padding elements to 0 + if (is_last_texel && nspill > 0) { + [[unroll]] for (int i = nspill; i < 4; i++) { + accum[i] = 0; + } + } + scan_pos[reduce_dim] = tid.x; + write_texel(tout, scan_pos, POSTPROCESS(accum)); + } +} + +/* + * Compute reduction where the reduction dim is also the packed dim. This case is + * complex because the reduction needs to occur over the individual texels. + * Therefore, in this algorithm each element of the accumulator texels are + * themselves partial outputs. Special care has to be taken to ignore padding + * elements in texels (which occur when the size of the packed dim is not a + * multiple of 4) so that they do not influence the output of reduction. + */ +void reduce_packed_dim(const ivec2 tid, ivec3 scan_pos) { + // shared memory index of this thread + const int smi = tid_to_smi(tid); + + // Number of non-padding elements in the last texel in the reduction row + const int nspill = mod4(tin_sizes[packed_dim]); + // Only reduce up to the last "complete" texel. The last texel will need to be + // handled specially if it has padding elements. + const int reduce_len = tin_sizes[packed_dim] - nspill; + + scan_pos[reduce_dim] = 0; + vec4 accum = INIT_ACCUM(vec4(load_texel(tin, scan_pos).x)); + + // Partially accumulate over elements i, i + NWORKERS, i + 2*NWORKERS, ... of + // the reduction row + scan_pos[reduce_dim] = tid.x; + for (int i = tid.x * 4; i < reduce_len; + i += NWORKERS * 4, scan_pos[reduce_dim] += NWORKERS) { + accum = UPDATE_ACCUM(accum, load_texel(tin, scan_pos)); + } + // For the last texel in the dim, if there are padding elements then each + // element of the texel needs to be processed individually such that the + // padding elements are ignored + if (scan_pos[reduce_dim] == tin_limits[reduce_dim] - 1 && nspill > 0) { + const vec4 intex = load_texel(tin, scan_pos); + for (int i = 0; i < nspill; i++) { + accum.x = UPDATE_ACCUM(accum.x, intex[i]); + } + } + // Write partial output to shared memory and synchronize work group + shared_vecs[smi] = accum; + barrier(); + + // Since the reduction row is reduced to only one element, only the "main" + // thread in the group needs aggregate the partial outputs + if (tid.x == 0) { + // Iterate over the partial maximums to obtain the overall maximum + int group_i = tid.y * NWORKERS; + accum = shared_vecs[group_i++]; + for (int i = 1; i < NWORKERS; i++, group_i++) { + accum = UPDATE_ACCUM(accum, shared_vecs[group_i]); + } + // Each element of the texel is itself a partial maximum; iterate over the + // texel to find the actual maximum + float accum_final = accum.x; + [[unroll]] for (int i = 1; i < 4; i++) { + accum_final = UPDATE_ACCUM(accum[i], accum_final); + } + + scan_pos[reduce_dim] = tid.x; + write_texel(tout, scan_pos, POSTPROCESS(vec4(accum_final, 0, 0, 0))); + } +} + +void main() { + ivec3 scan_pos = ivec3(gl_GlobalInvocationID); + scan_pos[reduce_dim] = 0; + + const ivec2 tid = ivec2( + gl_LocalInvocationID[reduce_dim], + gl_LocalInvocationID[group_dim]); + + if (any(greaterThanEqual(scan_pos, tin_limits))) { + return; + } + + if (reduce_dim != packed_dim) { + reduce_nonpacked_dim(tid, scan_pos); + } else { + reduce_packed_dim(tid, scan_pos); + } +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/reduce.yaml b/backends/vulkan/runtime/graph/ops/glsl/reduce.yaml new file mode 100644 index 0000000000..21a7132b8d --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/reduce.yaml @@ -0,0 +1,29 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +reduce: + parameter_names_with_default_values: + DTYPE: float + STORAGE: texture3d + INIT_ACCUM: VEC4_T(0) + UPDATE_ACCUM: accum + new_val + POSTPROCESS: accum + generate_variant_forall: + DTYPE: + - VALUE: half + - VALUE: float + shader_variants: + - NAME: sum + - NAME: mean + POSTPROCESS: (accum / tin_sizes[reduce_dim]) + - NAME: amax + INIT_ACCUM: first_val + UPDATE_ACCUM: max(accum, new_val) + POSTPROCESS: accum + - NAME: amin + INIT_ACCUM: first_val + UPDATE_ACCUM: min(accum, new_val) + POSTPROCESS: accum diff --git a/backends/vulkan/runtime/graph/ops/glsl/sum_dim.glsl b/backends/vulkan/runtime/graph/ops/glsl/sum_dim.glsl deleted file mode 100644 index 03cd94fb3d..0000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/sum_dim.glsl +++ /dev/null @@ -1,108 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#version 450 core - -#define PRECISION ${PRECISION} - -#include "broadcasting_utils.h" -#include "indexing_utils.h" - -layout(std430) buffer; - -layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out; -layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in; - -layout(set = 0, binding = 2) uniform PRECISION restrict OutLimits { - ivec3 out_limits; -}; - -// dim to sum -layout(set = 0, binding = 3) uniform PRECISION restrict DimVal { - int dim; -}; - -// size of dim (in the input) -layout(set = 0, binding = 4) uniform PRECISION restrict DimSize { - int dim_size; -}; - -layout(set = 0, binding = 5) uniform PRECISION restrict Channel { - int flattened_channels; -}; - -layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; - -layout(constant_id = 3) const int packed_dim = C_DIM; - -/* - * Returns a new tensor with values summed along dimension dim - * Dimension dim is squeezed - * For each pos: - * - Iterate over the out_texel and the summed dimension - * - For H,W; rearrange pos.x, pos.y - * - For C,H,W; - * When CHW are summed, batch moves into channel - * The src N is determined by pos.z * 4 + out_index - */ - -void main() { - const ivec3 pos = ivec3(gl_GlobalInvocationID); - - if (any(greaterThanEqual(pos, out_limits))) { - return; - } - - vec4 out_texel = vec4(0); - - int src_n; - int src_c; - - // Batch - if (dim == 0) { - for (int batch = 0; batch < dim_size; ++batch) { - src_n = batch; - src_c = pos.z; - int src_z = src_n * flattened_channels + src_c; - vec4 v = texelFetch(image_in, ivec3(pos.x, pos.y, src_z), 0); - out_texel += v; - } - imageStore(image_out, pos, out_texel); - } - - // Channel - else if (dim == 1) { - for (int out_index = 0; out_index < 4; ++out_index) { - for (int channel = 0; channel < dim_size; ++channel) { - src_n = pos.z * 4 + out_index; - src_c = channel; - int src_z = - src_n * flattened_channels + src_c / 4; - vec4 v = texelFetch(image_in, ivec3(pos.x, pos.y, src_z), 0); - out_texel[out_index] += v[channel % 4]; - } - } - imageStore(image_out, pos, out_texel); - } - - // Height, Width - else { - for (int out_index = 0; out_index < 4; ++out_index) { - src_n = pos.z * 4 + out_index; - src_c = pos.y; - int src_z = src_n * flattened_channels + src_c / 4; - for (int hw = 0; hw < dim_size; ++hw) { - vec4 v = (dim == 2) - ? texelFetch(image_in, ivec3(pos.x, hw, src_z), 0) // Height - : texelFetch(image_in, ivec3(hw, pos.x, src_z), 0); // Width - out_texel[out_index] += v[pos.y % 4]; - } - } - imageStore(image_out, pos, out_texel); - } -} diff --git a/backends/vulkan/runtime/graph/ops/glsl/sum_dim.yaml b/backends/vulkan/runtime/graph/ops/glsl/sum_dim.yaml deleted file mode 100644 index de3fddce88..0000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/sum_dim.yaml +++ /dev/null @@ -1,16 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -sum_dim: - parameter_names_with_default_values: - NDIM: 3 - DTYPE: float - generate_variant_forall: - DTYPE: - - VALUE: half - - VALUE: float - shader_variants: - - NAME: sum_dim diff --git a/backends/vulkan/runtime/graph/ops/glsl/sum_dim_keepdim.glsl b/backends/vulkan/runtime/graph/ops/glsl/sum_dim_keepdim.glsl deleted file mode 100644 index 64d37a13e8..0000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/sum_dim_keepdim.glsl +++ /dev/null @@ -1,95 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#version 450 core - -#define PRECISION ${PRECISION} - -#include "indexing_utils.h" - -layout(std430) buffer; - -layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out; -layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in; - -layout(set = 0, binding = 2) uniform PRECISION restrict OutLimits { - ivec3 out_limits; -}; - -// dim to sum -layout(set = 0, binding = 3) uniform PRECISION restrict DimVal { - int dim; -}; - -// size of dim (in the input) -layout(set = 0, binding = 4) uniform PRECISION restrict DimSize { - int dim_size; -}; - -layout(set = 0, binding = 5) uniform PRECISION restrict Channel { - int flattened_channels; -}; - -layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; - -layout(constant_id = 3) const int packed_dim = C_DIM; - -/* - * Returns a new tensor with values summed along dimension dim. - * Output and input have same number of dimensions. - * summed dimension is of size 1. - */ - -void main() { - const ivec3 pos = ivec3(gl_GlobalInvocationID); - - if (any(greaterThanEqual(pos, out_limits))) { - return; - } - - vec4 out_texel = vec4(0); - - int src_n; - int src_c; - - // Batch - if (dim == 0) { - for (int batch = 0; batch < dim_size; ++batch) { - src_n = batch; - src_c = pos.z; - int src_z = src_n * flattened_channels + src_c; - out_texel += texelFetch(image_in, ivec3(pos.x, pos.y, src_z), 0); - } - imageStore(image_out, pos, out_texel); - } - - // Channel - else if (dim == 1) { - for (int out_index = 0; out_index < 4; ++out_index) { - for (int channel = 0; channel < dim_size; ++channel) { - src_n = pos.z; - src_c = channel; - int src_z = src_n * flattened_channels + src_c / 4; - vec4 v = texelFetch(image_in, ivec3(pos.x, pos.y, src_z), 0); - out_texel[out_index] += v[channel % 4]; - } - } - imageStore(image_out, pos, out_texel); - } - - // Height, Width - else { - for (int hw = 0; hw < dim_size; ++hw) { - vec4 v = (dim == 2) - ? texelFetch(image_in, ivec3(pos.x, hw, pos.z), 0) // Height - : texelFetch(image_in, ivec3(hw, pos.y, pos.z), 0); // Width - out_texel += v; - } - imageStore(image_out, pos, out_texel); - } -} diff --git a/backends/vulkan/runtime/graph/ops/glsl/sum_dim_keepdim.yaml b/backends/vulkan/runtime/graph/ops/glsl/sum_dim_keepdim.yaml deleted file mode 100644 index f74bf229e5..0000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/sum_dim_keepdim.yaml +++ /dev/null @@ -1,16 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -sum_dim_keepdim: - parameter_names_with_default_values: - NDIM: 3 - DTYPE: float - generate_variant_forall: - DTYPE: - - VALUE: half - - VALUE: float - shader_variants: - - NAME: sum_dim_keepdim diff --git a/backends/vulkan/runtime/graph/ops/impl/Reduce.cpp b/backends/vulkan/runtime/graph/ops/impl/Reduce.cpp new file mode 100644 index 0000000000..9b1cdf824d --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/Reduce.cpp @@ -0,0 +1,123 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#include +#include + +namespace vkcompute { + +using namespace utils; + +void resize_reduce_node( + ComputeGraph* graph, + const std::vector& args, + const std::vector& extra_args) { + (void)extra_args; + vTensorPtr out = graph->get_tensor(args[0].refs[0]); + vTensorPtr in = graph->get_tensor(args[1].refs[0]); + + int dim = extra_args[0]; + + std::vector new_sizes = in->sizes(); + new_sizes[normalize(dim, new_sizes.size())] = 1; + out->virtual_resize(new_sizes); +} + +void add_reduce_node( + ComputeGraph& graph, + ValueRef in, + const int dim, + ValueRef out, + const std::string& op_name) { + VK_CHECK_COND( + !graph.is_buffer_storage(in) && !graph.is_buffer_storage(out), + "Vulkan reduction only supports texture storage"); + + const int64_t ndim = graph.dim_of(in); + + int32_t reduce_dim = dim; + reduce_dim = normalize(reduce_dim, ndim); + reduce_dim = nchw_dim_to_whcn_dim(reduce_dim, ndim); + + // Check that the concat dim is not the reduction dim, if the tensor has a + // batch dim greater than 1. + if (graph.dim_of(in) == 4 && graph.size_at(0, in) > 1) { + VK_CHECK_COND(graph.concat_dim_of(in) != reduce_dim); + VK_CHECK_COND(graph.concat_dim_of(out) != reduce_dim); + } + + vkapi::ShaderInfo shader_descriptor; + std::string kernel_name = op_name; + kernel_name.reserve(kShaderNameReserve); + add_dtype_suffix(kernel_name, graph.dtype_of(out)); + + // This should match the value of MAX_NTHREADS in the softmax shader. + constexpr uint32_t max_nthreads = 16; + + const uint32_t nworkers_per_group = 4; + const uint32_t ngroups = 4; + VK_CHECK_COND(nworkers_per_group * ngroups <= max_nthreads); + + utils::uvec3 global_wg_size = graph.logical_limits_of(out); + global_wg_size[reduce_dim] = 1; + + utils::uvec3 local_wg_size{1, 1, 1}; + local_wg_size[reduce_dim] = nworkers_per_group; + const int other_dim_1 = (reduce_dim + 1) % 3; + const int other_dim_2 = (reduce_dim + 2) % 3; + int32_t group_dim; + if (global_wg_size[other_dim_1] > global_wg_size[other_dim_2]) { + local_wg_size[other_dim_1] = ngroups; + group_dim = other_dim_1; + } else { + local_wg_size[other_dim_2] = ngroups; + group_dim = other_dim_2; + } + + graph.execute_nodes().emplace_back(new DispatchNode( + graph, + // shader_descriptor, + VK_KERNEL_FROM_STR(kernel_name), + global_wg_size, + local_wg_size, + // Inputs and Outputs + {{out, vkapi::kWrite}, {in, vkapi::kRead}}, + // Shader params buffers + {graph.logical_limits_ubo(in), graph.sizes_ubo(in)}, + // Specialization Constants + {graph.packed_dim_of(out), reduce_dim, group_dim}, + // Resizing Logic + resize_reduce_node, + {dim})); +} + +#define DEFINE_REDUCE_FN(op_name, out_arg_idx) \ + void op_name(ComputeGraph& graph, const std::vector& args) { \ + const IntListPtr dims_list = graph.get_int_list(args[1]); \ + VK_CHECK_COND(dims_list->size() == 1); \ + return add_reduce_node( \ + graph, args[0], dims_list->at(0), args[out_arg_idx], #op_name); \ + } + +DEFINE_REDUCE_FN(sum, 4) +DEFINE_REDUCE_FN(mean, 4) +DEFINE_REDUCE_FN(amax, 3) +DEFINE_REDUCE_FN(amin, 3) + +REGISTER_OPERATORS { + VK_REGISTER_OP(aten.sum.dim_IntList, sum); + VK_REGISTER_OP(aten.mean.dim, mean); + VK_REGISTER_OP(aten.amax.default, amax); + VK_REGISTER_OP(aten.amin.default, amin); +} + +} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Sum.cpp b/backends/vulkan/runtime/graph/ops/impl/Sum.cpp deleted file mode 100644 index 7dd3762ecf..0000000000 --- a/backends/vulkan/runtime/graph/ops/impl/Sum.cpp +++ /dev/null @@ -1,159 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#include - -#include - -#include -#include - -#include -#include -#include - -namespace vkcompute { - -std::vector -calc_out_sizes(api::vTensor& self, int64_t dim, bool keepdim) { - std::vector output_size = self.sizes(); - if (keepdim) { - output_size.at(dim) = 1; - } else { - output_size.erase(output_size.begin() + dim); - } - return output_size; -} - -void resize_sum_node( - ComputeGraph* graph, - const std::vector& args, - const std::vector& extra_args) { - (void)args; - vTensorPtr out = graph->get_tensor(extra_args[0]); - vTensorPtr in = graph->get_tensor(extra_args[1]); - - const auto dim = extra_args[2]; - const auto keepdim = extra_args[3]; - - std::vector output_size = calc_out_sizes(*in, dim, keepdim); - - out->virtual_resize(output_size); -} - -void check_sum_args(const api::vTensor& in, const api::vTensor& out) { - VK_CHECK_COND(check_packed_dim_is(in, WHCN::kChannelsDim)); - VK_CHECK_COND(check_packed_dim_is(out, WHCN::kChannelsDim)); -} - -void add_sum_dim_node( - ComputeGraph& graph, - const ValueRef in, - const int64_t dim, - const bool keepdim, - const ValueRef out) { - vTensorPtr t_out = graph.get_tensor(out); - vTensorPtr t_input = graph.get_tensor(in); - - check_sum_args(*t_input, *t_out); - - int64_t in_dim = t_input->sizes().size(); - int32_t channel = - in_dim > 2 ? static_cast(t_input->sizes()[in_dim - 3]) : 1; - uint32_t dim_size = t_input->sizes()[dim]; - - std::string kernel_name("sum_dim"); - kernel_name.reserve(kShaderNameReserve); - if (keepdim) { - kernel_name += "_keepdim"; - } - add_dtype_suffix(kernel_name, *t_out); - - graph.execute_nodes().emplace_back(new DispatchNode( - graph, - VK_KERNEL_FROM_STR(kernel_name), - graph.create_global_wg_size(out), - graph.create_local_wg_size(out), - // Inputs and Outputs - {{out, vkapi::MemoryAccessType::WRITE}, - {in, vkapi::MemoryAccessType::READ}}, - // Shader params buffers - {t_out->logical_limits_ubo(), - graph.create_params_buffer(dim + 4 - in_dim), - graph.create_params_buffer(dim_size), - graph.create_params_buffer(int(ceil(channel / 4.0)))}, - // Specialization Constants - {}, - // Resizing Logic - resize_sum_node, - {out, in, static_cast(dim), keepdim})); -} - -ValueRef add_node( - ComputeGraph& graph, - const ValueRef input, - const int dim, - const bool keepdim, - const vkapi::ScalarType dtype = vkapi::kFloat) { - std::vector output_size = - calc_out_sizes(*(graph.get_tensor(input)), dim, keepdim); - return graph.add_tensor(output_size, dtype, utils::kChannelsPacked); -} - -void add_sum_dim_IntList( - ComputeGraph& graph, - const ValueRef in, - const ValueRef opt_dim, - const ValueRef keepdim, - const ValueRef out) { - bool keepdim_val = graph.get_bool(keepdim); - - std::set dims_set; - const auto dims_to_sum = *graph.get_int_list(opt_dim); - int64_t in_dim = graph.get_tensor(in)->sizes().size(); - - if (dims_to_sum.empty()) { - // If dim is not specified, reduce over all dims - for (int64_t i = 0; i < in_dim; ++i) { - dims_set.insert(i); - } - } else { - for (const auto& dim : dims_to_sum) { - // Normalize (negative) dim into range [0, self.dim() - 1] - int64_t dim_normalized = normalize(dim, in_dim); - dims_set.insert(dim_normalized); - } - } - - // Reduce the higher dimensionalities first, otherwise when keepdim is - // false, it will be reducing the wrong dimension. - // We add intermediate nodes before the final output node, so we traverse - // until `std::prev(dims_set.rend())`. The final output node is added after - // the for loop. - ValueRef input = in; - for (auto dim = dims_set.rbegin(); dim != std::prev(dims_set.rend()); ++dim) { - ValueRef tmp_node = add_node(graph, input, *dim, keepdim_val); - add_sum_dim_node(graph, input, *dim, keepdim_val, tmp_node); - input = tmp_node; - } - // We add the final output node. - add_sum_dim_node(graph, input, *dims_set.begin(), keepdim_val, out); -} - -void sum_dim_IntList(ComputeGraph& graph, const std::vector& args) { - // args[3] represents dtype, however serialization of `ScalarType` is not - // supported yet. Since our usecase for this op is always float/half, it's - // removed from parameters for now. - return add_sum_dim_IntList(graph, args[0], args[1], args[2], args[4]); -} - -REGISTER_OPERATORS { - VK_REGISTER_OP(aten.sum.dim_IntList, sum_dim_IntList); -} - -} // namespace vkcompute diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py index fb30522209..304636f2fb 100644 --- a/backends/vulkan/test/op_tests/cases.py +++ b/backends/vulkan/test/op_tests/cases.py @@ -952,32 +952,35 @@ def get_split_tensor_inputs(): return test_suite +def get_reduce_inputs(is_softmax: bool = False): + bool_arg = False if is_softmax else True + return [ + ((L), 0, bool_arg), + ((L), -1, bool_arg), + ((M, L), 0, bool_arg), + ((M, L), 1, bool_arg), + ((L, M), -1, bool_arg), + ((M, L), -2, bool_arg), + ((S, S1, S2), 0, bool_arg), + ((S, S1, S2), 1, bool_arg), + ((S, S1, S2), 2, bool_arg), + ((S, S1, S2), -1, bool_arg), + ((S, S1, S2), -2, bool_arg), + ((S, S1, S2), -3, bool_arg), + ((1, S, S1, S2), 1, bool_arg), + ((1, S, S1, S2), 2, bool_arg), + ((1, S, S1, S2), 3, bool_arg), + ((1, S, S1, S2), -1, bool_arg), + ((1, S, S1, S2), -2, bool_arg), + ((1, S, S1, S2), -3, bool_arg), + # Test batches > 1 where the reduction dim is not the concat dim + ((S, S2, S1, 128), -1, bool_arg), + ] + + @register_test_suite(["aten._softmax.default", "aten._log_softmax.default"]) def get_softmax_inputs(): - test_suite = VkTestSuite( - [ - ((L), 0, False), - ((L), -1, False), - ((M, L), 0, False), - ((M, L), 1, False), - ((L, M), -1, False), - ((M, L), -2, False), - ((S, S1, S2), 0, False), - ((S, S1, S2), 1, False), - ((S, S1, S2), 2, False), - ((S, S1, S2), -1, False), - ((S, S1, S2), -2, False), - ((S, S1, S2), -3, False), - ((1, S, S1, S2), 1, False), - ((1, S, S1, S2), 2, False), - ((1, S, S1, S2), 3, False), - ((1, S, S1, S2), -1, False), - ((1, S, S1, S2), -2, False), - ((1, S, S1, S2), -3, False), - # Test batches > 1 where the reduction dim is not the concat dim - ((S, S2, S1, 128), -1, False), - ] - ) + test_suite = VkTestSuite(get_reduce_inputs(is_softmax=True)) test_suite.layouts = [ "utils::kWidthPacked", "utils::kChannelsPacked", @@ -985,6 +988,18 @@ def get_softmax_inputs(): return test_suite +@register_test_suite( + ["aten.amax.default", "aten.amin.default", "aten.sum.dim_IntList", "aten.mean.dim"] +) +def get_reduce_op_inputs(): + test_suite = VkTestSuite(get_reduce_inputs()) + test_suite.layouts = [ + "utils::kChannelsPacked", + "utils::kWidthPacked", + ] + return test_suite + + @register_test_suite( [ "aten.sqrt.default", diff --git a/backends/vulkan/test/test_vulkan_delegate.py b/backends/vulkan/test/test_vulkan_delegate.py index f9820f825e..54db1a4b77 100644 --- a/backends/vulkan/test/test_vulkan_delegate.py +++ b/backends/vulkan/test/test_vulkan_delegate.py @@ -725,6 +725,9 @@ def forward(self, x): self.lower_module_and_test_output(module, sample_inputs) + @unittest.skip( + "Reduce shader does not support multiple reduction axes at the moment" + ) def test_vulkan_backend_sum_dim_list(self): class SumModule(torch.nn.Module): def __init__(self): @@ -744,6 +747,9 @@ def forward(self, x): memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED], ) + @unittest.skip( + "Reduce shader does not support multiple reduction axes at the moment" + ) def test_vulkan_backend_sum(self): class SumModule(torch.nn.Module): def __init__(self): @@ -1441,6 +1447,9 @@ def forward(self, x): self.lower_unary_module_and_test_output(GeluModule()) + @unittest.skip( + "Reduce shader does not support multiple reduction axes at the moment" + ) def test_vulkan_backend_mean(self): class MeanModule(torch.nn.Module): def __init__(self, dims, keepdim=True): diff --git a/backends/vulkan/vulkan_preprocess.py b/backends/vulkan/vulkan_preprocess.py index 777a56e364..0e116ad2c4 100644 --- a/backends/vulkan/vulkan_preprocess.py +++ b/backends/vulkan/vulkan_preprocess.py @@ -15,7 +15,6 @@ from executorch.backends.transforms.fuse_conv_with_clamp import FuseClampPass from executorch.backends.transforms.fuse_dequant_linear import FuseDequantLinearPass from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform -from executorch.backends.transforms.mean_to_sum_div import MeanToSumDiv from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform from executorch.backends.vulkan._passes import RemoveLocalScalarDenseOpsTransform @@ -65,7 +64,6 @@ def preprocess( # noqa: C901 FuseViewCopyTransform(), FuseBatchNormWithConvPass(program), FuseClampPass(), - MeanToSumDiv(), SpecPropPass(), ConstraintBasedSymShapeEvalPass(), RemoveLocalScalarDenseOpsTransform(),