From a3f0b000482892559456624ccb0b75fda3d3b595 Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Wed, 23 Oct 2024 09:45:18 -0700 Subject: [PATCH] [ET-VK] Implement generic reduction shader + mean, sum, amax, amin ## Context Introduce a generic shader to compute reduction along a single dim, and `keepdim = True`. With the generic shader template, `mean`, `sum`, `amin`, and `amax` can be implemented. Differential Revision: [D64840504](https://our.internmc.facebook.com/intern/diff/D64840504/) [ghstack-poisoned] --- backends/vulkan/partitioner/supported_ops.py | 4 + .../vulkan/runtime/graph/ops/glsl/reduce.glsl | 214 ++++++++++++++++++ .../vulkan/runtime/graph/ops/glsl/reduce.yaml | 29 +++ .../vulkan/runtime/graph/ops/impl/Reduce.cpp | 123 ++++++++++ backends/vulkan/test/op_tests/cases.py | 63 ++++-- .../test/op_tests/utils/gen_correctness_vk.py | 2 + 6 files changed, 411 insertions(+), 24 deletions(-) create mode 100644 backends/vulkan/runtime/graph/ops/glsl/reduce.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/reduce.yaml create mode 100644 backends/vulkan/runtime/graph/ops/impl/Reduce.cpp diff --git a/backends/vulkan/partitioner/supported_ops.py b/backends/vulkan/partitioner/supported_ops.py index 09759b0d0e..3d2acc6a08 100644 --- a/backends/vulkan/partitioner/supported_ops.py +++ b/backends/vulkan/partitioner/supported_ops.py @@ -89,6 +89,10 @@ def __contains__(self, op): # Reduction exir_ops.edge.aten._log_softmax.default, exir_ops.edge.aten._softmax.default, + exir_ops.edge.aten.mean.dim, + exir_ops.edge.aten.sum.dim_IntList, + exir_ops.edge.aten.amax.default, + exir_ops.edge.aten.amin.default, # 2D Pooling exir_ops.edge.aten.avg_pool2d.default, exir_ops.edge.aten.max_pool2d_with_indices.default, diff --git a/backends/vulkan/runtime/graph/ops/glsl/reduce.glsl b/backends/vulkan/runtime/graph/ops/glsl/reduce.glsl new file mode 100644 index 0000000000..b0de34e68c --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/reduce.glsl @@ -0,0 +1,214 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} +#define VEC4_T ${texel_load_type(DTYPE, STORAGE)} + +${define_active_storage_type(STORAGE)} + +#extension GL_EXT_control_flow_attributes : require + +layout(std430) buffer; + +${layout_declare_tensor(B, "w", "tout", DTYPE, STORAGE)} +${layout_declare_tensor(B, "r", "tin", DTYPE, STORAGE)} + +${layout_declare_ubo(B, "ivec3", "tin_limits")} +${layout_declare_ubo(B, "ivec4", "tin_sizes")} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +layout(constant_id = 3) const int packed_dim = 0; +layout(constant_id = 4) const int reduce_dim = 0; +layout(constant_id = 5) const int group_dim = 1; + +// A more verbose name would be NWORKERS_PER_GROUP. This describes the number of +// threads that will co-operate to compute one reduction output. There may be +// multiple groups computing distinct reduction outputs within one work group. +#define NWORKERS 4 + +// Sets an upper limit on the total size of a work group based on how many +// elements are allocated in the shared memory array below. Each thread in the +// work group will write into its assigned element in the shared array. 
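+// For example, with NWORKERS = 4, a MAX_NTHREADS of 16 allows at most
+// 16 / 4 = 4 reduction groups (i.e. NGROUPS <= 4) per work group.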
+#define MAX_NTHREADS 16
+
+
+shared vec4 shared_vecs[MAX_NTHREADS];
+
+#include "indexing_utils.h"
+
+int tid_to_smi(const ivec2 tid) {
+  return tid.x + tid.y * NWORKERS;
+}
+
+/*
+ * The functions below compute reduction along a single dimension for a tensor.
+ * The shader template generalizes reduction by abstracting the initial value of
+ * the accumulator, the calculation used to update the accumulator with new
+ * values, and a postprocessing calculation that can be used to modify the
+ * accumulator before writing to output.
+ *
+ * This shader also utilizes shared memory so that multiple threads can help
+ * compute the reduction. A total of NGROUPS x NWORKERS threads are expected to
+ * be launched. Each group works on a unique reduction "row", and within a group
+ * NWORKERS threads co-operate to compute the reduction output of one "row".
+ * Each worker in the group is responsible for computing a partial output of
+ * the "row" and uploading it to shared memory; the overall reduction output
+ * can then be determined by aggregating the partial outputs stored in shared
+ * memory.
+ *
+ * As a caveat, this shader does not currently support cases where `batch` > 1
+ * and the reduce dim happens to also be the batch concatenation dim. To support
+ * this, there will need to be additional logic to set the starting value of
+ * `scan_pos[reduce_dim]`. Since this is not expected to be a common use-case,
+ * supporting it is deferred until it is required.
+ */
+
+// The accumulator initializer accepts the first value in the reduction row,
+// since some reduction operations (e.g. amax, amin) prefer to initialize with
+// a data point instead of a static value.
+#define INIT_ACCUM(first_val) ${INIT_ACCUM}
+#define UPDATE_ACCUM(accum, new_val) ${UPDATE_ACCUM}
+// Useful for operators such as mean which want to perform a final calculation
+// with the accumulator.
+#define POSTPROCESS(accum) ${POSTPROCESS}
+
+/*
+ * Computes reduction where the reduction dim is orthogonal to the packed dim.
+ * This case is simpler because each element of a texel belongs to a separate
+ * reduction "group", meaning we don't have to perform reduction along a texel.
+ */
+void reduce_nonpacked_dim(const ivec2 tid, ivec3 scan_pos) {
+  // shared memory index of this thread
+  const int smi = tid_to_smi(tid);
+
+  scan_pos[reduce_dim] = 0;
+  vec4 accum = INIT_ACCUM(load_texel(tin, scan_pos));
+
+  scan_pos[reduce_dim] = tid.x;
+  // Partially accumulate over elements i, i + NWORKERS, i + 2*NWORKERS, ... of
+  // the reduction row
+  for (int i = tid.x; i < tin_sizes[reduce_dim];
+       i += NWORKERS, scan_pos[reduce_dim] += NWORKERS) {
+    accum = UPDATE_ACCUM(accum, load_texel(tin, scan_pos));
+  }
+  // Write partial output to shared memory and synchronize work group
+  shared_vecs[smi] = accum;
+  barrier();
+
+  // Since the reduction row is reduced to only one element, only the "main"
+  // thread in the group needs to aggregate the partial outputs
+  if (tid.x == 0) {
+    // Iterate over the partial outputs to obtain the overall output
+    int group_i = tid.y * NWORKERS;
+    accum = shared_vecs[group_i++];
+    for (int i = 1; i < NWORKERS; ++i, group_i++) {
+      accum = UPDATE_ACCUM(accum, shared_vecs[group_i]);
+    }
+
+    // Determine if there are any padding elements in the final texel of the
+    // packed dimension
+    const int nspill = mod4(tin_sizes[packed_dim]);
+    // Detect if this thread is working on the final texels of the packed
+    // dimension, which may have padding elements
+    const bool is_last_texel =
+        scan_pos[packed_dim] == (tin_limits[packed_dim] - 1);
+
+    // Explicitly set padding elements to 0
+    if (is_last_texel && nspill > 0) {
+      [[unroll]] for (int i = nspill; i < 4; ++i) {
+        accum[i] = 0;
+      }
+    }
+    scan_pos[reduce_dim] = tid.x;
+    write_texel(tout, scan_pos, POSTPROCESS(accum));
+  }
+}
+
+/*
+ * Computes reduction where the reduction dim is also the packed dim. This case
+ * is more complex because the reduction needs to occur over the individual
+ * texels. Therefore, in this algorithm each element of the accumulator texel is
+ * itself a partial output. Special care has to be taken to ignore padding
+ * elements in texels (which occur when the size of the packed dim is not a
+ * multiple of 4) so that they do not influence the output of the reduction.
+ */
+void reduce_packed_dim(const ivec2 tid, ivec3 scan_pos) {
+  // shared memory index of this thread
+  const int smi = tid_to_smi(tid);
+
+  // Number of non-padding elements in the last texel in the reduction row
+  const int nspill = mod4(tin_sizes[packed_dim]);
+  // Only reduce up to the last "complete" texel. The last texel will need to be
+  // handled specially if it has padding elements.
+  const int reduce_len = tin_sizes[packed_dim] - nspill;
+
+  scan_pos[reduce_dim] = 0;
+  vec4 accum = INIT_ACCUM(vec4(load_texel(tin, scan_pos).x));
+
+  // Partially accumulate over elements i, i + NWORKERS, i + 2*NWORKERS, ... of
+  // the reduction row
+  scan_pos[reduce_dim] = tid.x;
+  for (int i = tid.x * 4; i < reduce_len;
+       i += NWORKERS * 4, scan_pos[reduce_dim] += NWORKERS) {
+    accum = UPDATE_ACCUM(accum, load_texel(tin, scan_pos));
+  }
+  // For the last texel in the dim, if there are padding elements then each
+  // element of the texel needs to be processed individually so that the
+  // padding elements are ignored
+  if (scan_pos[reduce_dim] == tin_limits[reduce_dim] - 1 && nspill > 0) {
+    const vec4 intex = load_texel(tin, scan_pos);
+    for (int i = 0; i < nspill; ++i) {
+      accum.x = UPDATE_ACCUM(accum.x, intex[i]);
+    }
+  }
+  // Write partial output to shared memory and synchronize work group
+  shared_vecs[smi] = accum;
+  barrier();
+
+  // Since the reduction row is reduced to only one element, only the "main"
+  // thread in the group needs to aggregate the partial outputs
+  if (tid.x == 0) {
+    // Iterate over the partial outputs to obtain the overall output
+    int group_i = tid.y * NWORKERS;
+    accum = shared_vecs[group_i++];
+    for (int i = 1; i < NWORKERS; ++i, group_i++) {
+      accum = UPDATE_ACCUM(accum, shared_vecs[group_i]);
+    }
+    // Each element of the texel is itself a partial output; combine the texel
+    // elements to obtain the final output
+    float accum_final = accum.x;
+    [[unroll]] for (int i = 1; i < 4; ++i) {
+      accum_final = UPDATE_ACCUM(accum[i], accum_final);
+    }
+
+    scan_pos[reduce_dim] = tid.x;
+    write_texel(tout, scan_pos, POSTPROCESS(vec4(accum_final, 0, 0, 0)));
+  }
+}
+
+void main() {
+  ivec3 scan_pos = ivec3(gl_GlobalInvocationID);
+  scan_pos[reduce_dim] = 0;
+
+  const ivec2 tid = ivec2(
+      gl_LocalInvocationID[reduce_dim],
+      gl_LocalInvocationID[group_dim]);
+
+  if (any(greaterThanEqual(scan_pos, tin_limits))) {
+    return;
+  }
+
+  if (reduce_dim != packed_dim) {
+    reduce_nonpacked_dim(tid, scan_pos);
+  } else {
+    reduce_packed_dim(tid, scan_pos);
+  }
+}
diff --git a/backends/vulkan/runtime/graph/ops/glsl/reduce.yaml b/backends/vulkan/runtime/graph/ops/glsl/reduce.yaml
new file mode 100644
index 0000000000..21a7132b8d
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/glsl/reduce.yaml
@@ -0,0 +1,29 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+reduce:
+  parameter_names_with_default_values:
+    DTYPE: float
+    STORAGE: texture3d
+    INIT_ACCUM: VEC4_T(0)
+    UPDATE_ACCUM: accum + new_val
+    POSTPROCESS: accum
+  generate_variant_forall:
+    DTYPE:
+      - VALUE: half
+      - VALUE: float
+  shader_variants:
+    - NAME: sum
+    - NAME: mean
+      POSTPROCESS: (accum / tin_sizes[reduce_dim])
+    - NAME: amax
+      INIT_ACCUM: first_val
+      UPDATE_ACCUM: max(accum, new_val)
+      POSTPROCESS: accum
+    - NAME: amin
+      INIT_ACCUM: first_val
+      UPDATE_ACCUM: min(accum, new_val)
+      POSTPROCESS: accum
diff --git a/backends/vulkan/runtime/graph/ops/impl/Reduce.cpp b/backends/vulkan/runtime/graph/ops/impl/Reduce.cpp
new file mode 100644
index 0000000000..ffef14731f
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/impl/Reduce.cpp
@@ -0,0 +1,123 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include 
+
+#include 
+
+#include 
+#include 
+
+namespace vkcompute {
+
+using namespace utils;
+
+void resize_reduce_node(
+    ComputeGraph* graph,
+    const std::vector<ArgGroup>& args,
+    const std::vector<ValueRef>& extra_args) {
+  (void)extra_args;
+  vTensorPtr out = graph->get_tensor(args[0].refs[0]);
+  vTensorPtr in = graph->get_tensor(args[1].refs[0]);
+
+  std::vector<int64_t> in_sizes = in->sizes();
+  // out->virtual_resize(in_sizes);
+}
+
+void add_reduce_node(
+    ComputeGraph& graph,
+    ValueRef in,
+    const int dim,
+    ValueRef out,
+    const std::string& op_name) {
+  VK_CHECK_COND(
+      !graph.is_buffer_storage(in) && !graph.is_buffer_storage(out),
+      "Vulkan reduction only supports texture storage");
+
+  const int64_t ndim = graph.dim_of(in);
+
+  int32_t reduce_dim = dim;
+  reduce_dim = normalize(reduce_dim, ndim);
+  reduce_dim = nchw_dim_to_whcn_dim(reduce_dim, ndim);
+
+  // Check that the concat dim is not the reduction dim if the tensor has a
+  // batch dim greater than 1.
+  if (graph.dim_of(in) == 4 && graph.size_at(0, in) > 1) {
+    VK_CHECK_COND(
+        graph.concat_dim_of(in) != reduce_dim,
+        "Reduce shader currently does not support concat dim == reduce dim");
+    VK_CHECK_COND(
+        graph.concat_dim_of(out) != reduce_dim,
+        "Reduce shader currently does not support concat dim == reduce dim");
+  }
+
+  vkapi::ShaderInfo shader_descriptor;
+  std::string kernel_name = op_name;
+  kernel_name.reserve(kShaderNameReserve);
+  add_dtype_suffix(kernel_name, graph.dtype_of(out));
+
+  // This should match the value of MAX_NTHREADS in the reduce shader.
+  constexpr uint32_t max_nthreads = 16;
+
+  const uint32_t nworkers_per_group = 4;
+  const uint32_t ngroups = 4;
+  VK_CHECK_COND(nworkers_per_group * ngroups <= max_nthreads);
+
+  utils::uvec3 global_wg_size = graph.logical_limits_of(out);
+  global_wg_size[reduce_dim] = 1;
+
+  utils::uvec3 local_wg_size{1, 1, 1};
+  local_wg_size[reduce_dim] = nworkers_per_group;
+  const int other_dim_1 = (reduce_dim + 1) % 3;
+  const int other_dim_2 = (reduce_dim + 2) % 3;
+  int32_t group_dim;
+  if (global_wg_size[other_dim_1] > global_wg_size[other_dim_2]) {
+    local_wg_size[other_dim_1] = ngroups;
+    group_dim = other_dim_1;
+  } else {
+    local_wg_size[other_dim_2] = ngroups;
+    group_dim = other_dim_2;
+  }
+
+  graph.execute_nodes().emplace_back(new DispatchNode(
+      graph,
+      // shader_descriptor,
+      VK_KERNEL_FROM_STR(kernel_name),
+      global_wg_size,
+      local_wg_size,
+      // Inputs and Outputs
+      {{out, vkapi::kWrite}, {in, vkapi::kRead}},
+      // Shader params buffers
+      {graph.logical_limits_ubo(in), graph.sizes_ubo(in)},
+      // Specialization Constants
+      {graph.packed_dim_of(out), reduce_dim, group_dim},
+      // Resizing Logic
+      resize_reduce_node));
+}
+
+#define DEFINE_REDUCE_FN(op_name, out_arg_idx)                           \
+  void op_name(ComputeGraph& graph, const std::vector<ValueRef>& args) { \
+    const IntListPtr dims_list = graph.get_int_list(args[1]);            \
+    VK_CHECK_COND(dims_list->size() == 1);                               \
+    return add_reduce_node(                                              \
+        graph, args[0], dims_list->at(0), args[out_arg_idx], #op_name);  \
+  }
+
+DEFINE_REDUCE_FN(sum, 4)
+DEFINE_REDUCE_FN(mean, 4)
+DEFINE_REDUCE_FN(amax, 3)
+DEFINE_REDUCE_FN(amin, 3)
+
+REGISTER_OPERATORS {
+  VK_REGISTER_OP(aten.sum.dim_IntList, sum);
+  VK_REGISTER_OP(aten.mean.dim, mean);
+  VK_REGISTER_OP(aten.amax.default, amax);
+  VK_REGISTER_OP(aten.amin.default, amin);
+}
+
+} // namespace vkcompute
diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py
index fb30522209..304636f2fb 100644
--- a/backends/vulkan/test/op_tests/cases.py
+++ 
b/backends/vulkan/test/op_tests/cases.py @@ -952,32 +952,35 @@ def get_split_tensor_inputs(): return test_suite +def get_reduce_inputs(is_softmax: bool = False): + bool_arg = False if is_softmax else True + return [ + ((L), 0, bool_arg), + ((L), -1, bool_arg), + ((M, L), 0, bool_arg), + ((M, L), 1, bool_arg), + ((L, M), -1, bool_arg), + ((M, L), -2, bool_arg), + ((S, S1, S2), 0, bool_arg), + ((S, S1, S2), 1, bool_arg), + ((S, S1, S2), 2, bool_arg), + ((S, S1, S2), -1, bool_arg), + ((S, S1, S2), -2, bool_arg), + ((S, S1, S2), -3, bool_arg), + ((1, S, S1, S2), 1, bool_arg), + ((1, S, S1, S2), 2, bool_arg), + ((1, S, S1, S2), 3, bool_arg), + ((1, S, S1, S2), -1, bool_arg), + ((1, S, S1, S2), -2, bool_arg), + ((1, S, S1, S2), -3, bool_arg), + # Test batches > 1 where the reduction dim is not the concat dim + ((S, S2, S1, 128), -1, bool_arg), + ] + + @register_test_suite(["aten._softmax.default", "aten._log_softmax.default"]) def get_softmax_inputs(): - test_suite = VkTestSuite( - [ - ((L), 0, False), - ((L), -1, False), - ((M, L), 0, False), - ((M, L), 1, False), - ((L, M), -1, False), - ((M, L), -2, False), - ((S, S1, S2), 0, False), - ((S, S1, S2), 1, False), - ((S, S1, S2), 2, False), - ((S, S1, S2), -1, False), - ((S, S1, S2), -2, False), - ((S, S1, S2), -3, False), - ((1, S, S1, S2), 1, False), - ((1, S, S1, S2), 2, False), - ((1, S, S1, S2), 3, False), - ((1, S, S1, S2), -1, False), - ((1, S, S1, S2), -2, False), - ((1, S, S1, S2), -3, False), - # Test batches > 1 where the reduction dim is not the concat dim - ((S, S2, S1, 128), -1, False), - ] - ) + test_suite = VkTestSuite(get_reduce_inputs(is_softmax=True)) test_suite.layouts = [ "utils::kWidthPacked", "utils::kChannelsPacked", @@ -985,6 +988,18 @@ def get_softmax_inputs(): return test_suite +@register_test_suite( + ["aten.amax.default", "aten.amin.default", "aten.sum.dim_IntList", "aten.mean.dim"] +) +def get_reduce_op_inputs(): + test_suite = VkTestSuite(get_reduce_inputs()) + test_suite.layouts = [ + "utils::kChannelsPacked", + "utils::kWidthPacked", + ] + return test_suite + + @register_test_suite( [ "aten.sqrt.default", diff --git a/backends/vulkan/test/op_tests/utils/gen_correctness_vk.py b/backends/vulkan/test/op_tests/utils/gen_correctness_vk.py index 6c165a777d..32ba925e9e 100644 --- a/backends/vulkan/test/op_tests/utils/gen_correctness_vk.py +++ b/backends/vulkan/test/op_tests/utils/gen_correctness_vk.py @@ -141,6 +141,8 @@ def gen_parameterization(self) -> str: std::cout << "vulkan: " << std::endl; print(t2, 150); std::cout << std::endl; + print(at::abs(t2 - t1), 150); + std::cout << std::endl; } return is_close; }
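
To summarize the shader template's contract, here is a minimal scalar sketch (reference only, not part of the patch; all names are illustrative) of how the `INIT_ACCUM` / `UPDATE_ACCUM` / `POSTPROCESS` slots from `reduce.yaml` compose into the four variants, ignoring texel packing and the shared-memory work sharing:

```python
# Reference-only sketch; these functions do not exist in the codebase. They
# mirror the INIT_ACCUM / UPDATE_ACCUM / POSTPROCESS slots that reduce.yaml
# fills in for each variant, applied to a single reduction "row".
def reduce_row(row, init_accum, update_accum, postprocess):
    accum = init_accum(row[0])  # VEC4_T(0) for sum/mean, first_val for amax/amin
    for val in row:
        accum = update_accum(accum, val)
    return postprocess(accum)


def sum_row(row):
    return reduce_row(row, lambda first: 0.0, lambda a, v: a + v, lambda a: a)


def mean_row(row):
    return reduce_row(
        row, lambda first: 0.0, lambda a, v: a + v, lambda a: a / len(row)
    )


def amax_row(row):
    return reduce_row(row, lambda first: first, max, lambda a: a)


def amin_row(row):
    return reduce_row(row, lambda first: first, min, lambda a: a)


# e.g. sum_row([1.0, 2.0, 3.0]) == 6.0 and amax_row([1.0, 2.0, 3.0]) == 3.0
```

Note that `mean` reuses the `sum` accumulator and differs only in its postprocess step, which is exactly how the `mean` variant is declared in `reduce.yaml`.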
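The dispatch-size selection in `add_reduce_node` can be sketched as follows. This is illustrative Python mirroring the C++ above; `logical_limits` stands in for `graph.logical_limits_of(out)` and the helper name is an assumption, not an existing API:

```python
# Illustrative sketch of the work group selection in add_reduce_node.
NWORKERS_PER_GROUP = 4  # must match NWORKERS in reduce.glsl
NGROUPS = 4             # NWORKERS_PER_GROUP * NGROUPS <= MAX_NTHREADS (16)


def pick_wg_sizes(logical_limits, reduce_dim):
    # One output texel per reduction row, so the reduce dim collapses to 1
    # in the global work group size.
    global_wg_size = list(logical_limits)
    global_wg_size[reduce_dim] = 1

    # NWORKERS_PER_GROUP threads co-operate along the reduce dim; NGROUPS rows
    # are packed along whichever of the two remaining dims is larger.
    local_wg_size = [1, 1, 1]
    local_wg_size[reduce_dim] = NWORKERS_PER_GROUP
    other_1 = (reduce_dim + 1) % 3
    other_2 = (reduce_dim + 2) % 3
    if global_wg_size[other_1] > global_wg_size[other_2]:
        group_dim = other_1
    else:
        group_dim = other_2
    local_wg_size[group_dim] = NGROUPS
    return global_wg_size, local_wg_size, group_dim
```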