Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ET-VK] Implement generic reduction shader + mean, sum, amax, amin #6457

Merged
merged 4 commits into from
Oct 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion backends/vulkan/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ runtime.python_library(
"//executorch/backends/transforms:fuse_conv_with_clamp",
"//executorch/backends/transforms:fuse_dequant_linear",
"//executorch/backends/transforms:fuse_view_copy",
"//executorch/backends/transforms:mean_to_sum_div",
"//executorch/backends/transforms:remove_clone_ops",
"//executorch/backends/vulkan/_passes:vulkan_passes",
"//executorch/exir:graph_module",
Expand Down
7 changes: 4 additions & 3 deletions backends/vulkan/partitioner/supported_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,10 @@ def __contains__(self, op):
# Reduction
exir_ops.edge.aten._log_softmax.default,
exir_ops.edge.aten._softmax.default,
exir_ops.edge.aten.mean.dim,
exir_ops.edge.aten.sum.dim_IntList,
exir_ops.edge.aten.amax.default,
exir_ops.edge.aten.amin.default,
# 2D Pooling
exir_ops.edge.aten.avg_pool2d.default,
exir_ops.edge.aten.max_pool2d_with_indices.default,
Expand All @@ -101,9 +105,6 @@ def __contains__(self, op):
]

NO_DYNAMIC_SHAPE = [
# Reduction
exir_ops.edge.aten.mean.dim,
exir_ops.edge.aten.sum.dim_IntList,
# Normalization
exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
exir_ops.edge.aten.native_layer_norm.default,
Expand Down
214 changes: 214 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/reduce.glsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#version 450 core

#define PRECISION ${PRECISION}
#define VEC4_T ${texel_load_type(DTYPE, STORAGE)}

${define_active_storage_type(STORAGE)}

#extension GL_EXT_control_flow_attributes : require

layout(std430) buffer;

${layout_declare_tensor(B, "w", "tout", DTYPE, STORAGE)}
${layout_declare_tensor(B, "r", "tin", DTYPE, STORAGE)}

${layout_declare_ubo(B, "ivec3", "tin_limits")}
${layout_declare_ubo(B, "ivec4", "tin_sizes")}

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

layout(constant_id = 3) const int packed_dim = 0;
layout(constant_id = 4) const int reduce_dim = 0;
layout(constant_id = 5) const int group_dim = 1;

// A more verbose name would be NWORKERS_PER_GROUP. This describes the number of
// threads that will co-operate to compute one reduction output. There may be
// multiple groups computing distinct reduction outputs within one work group.
#define NWORKERS 4

// Sets an upper limit on the total size of a work group based on how many
// elements are allocated in the shared memory array below. Each thread in the
// work group will write into its assigned element in the shared array.
#define MAX_NTHREADS 16


shared vec4 shared_vecs[MAX_NTHREADS];

#include "indexing_utils.h"

// Flattens a (worker index, group index) pair into the index of this thread's
// element in shared_vecs. Elements belonging to one group are contiguous, with
// NWORKERS elements per group.
int tid_to_smi(const ivec2 tid) {
  const int group_base = tid.y * NWORKERS;
  return group_base + tid.x;
}

/*
* The functions below compute reduction along a single dimension for a tensor.
* The shader template generalizes reduction by abstracting the initial value of
* the accumulator, the calculation used to update the accumulator with new
* values, and a postprocessing calculation that can be used to modify the
* accumulator before writing to output.
*
* This shader also utilizes shared memory so that multiple threads can
* co-operate on each reduction. A total of NGROUPS x NWORKERS threads are
* expected to be launched. Each group works on a unique reduction "row", and
* within a group NWORKERS threads co-operate to compute the reduction of one
* "row". Each worker in the group is responsible for computing a partial output
* of the "row" and uploading it to shared memory; the overall reduction output
* can then be determined by aggregating the partial outputs stored in shared
* memory.
*
* As a caveat, this shader does not currently support cases where `batch` > 1
* and the reduce dim happens to also be the batch concatenation dim. To support
* this, there will need to be additional logic to set the starting value of
* `scan_pos[reduce_dim]`. Since this is not expected to be a common use-case,
* supporting this case is left as an exercise for when it is required.
*/

// Initializing the accumulator accepts the first value in the reduction row,
// since some reduction operations (e.g. amax, amin) prefer to initialize with
// a data point instead of a static value.
#define INIT_ACCUM(first_val) ${INIT_ACCUM}
#define UPDATE_ACCUM(accum, new_val) ${UPDATE_ACCUM}
// Useful for operators such as mean which want to perform a final calculation
// with the accumulator.
#define POSTPROCESS(accum) ${POSTPROCESS}

/*
* Computes reduction where the reduction dim is orthogonal to the packed dim.
* This case is simpler because each element of a texel belongs to a separate
* reduction "group", meaning we don't have to perform reduction along a texel.
*/
void reduce_nonpacked_dim(const ivec2 tid, ivec3 scan_pos) {
  // tid.x is this thread's worker index within its group; tid.y selects the
  // group, and therefore the run of NWORKERS shared memory elements it owns.
  const int worker = tid.x;
  const int group_base = tid.y * NWORKERS;

  // Seed the accumulator from the first texel of the reduction row. Whether
  // the texel value is actually used depends on how INIT_ACCUM is defined.
  scan_pos[reduce_dim] = 0;
  vec4 acc = INIT_ACCUM(load_texel(tin, scan_pos));

  // Strided partial reduction: this worker visits positions worker,
  // worker + NWORKERS, worker + 2 * NWORKERS, ... along the reduction row.
  const int row_len = tin_sizes[reduce_dim];
  for (scan_pos[reduce_dim] = worker; scan_pos[reduce_dim] < row_len;
       scan_pos[reduce_dim] += NWORKERS) {
    acc = UPDATE_ACCUM(acc, load_texel(tin, scan_pos));
  }

  // Publish the partial result, then wait until every worker has done so.
  shared_vecs[tid_to_smi(tid)] = acc;
  barrier();

  // The row collapses to a single output texel, so only worker 0 of each
  // group combines the partial results and writes the output.
  if (worker != 0) {
    return;
  }

  acc = shared_vecs[group_base];
  for (int w = 1; w < NWORKERS; w++) {
    acc = UPDATE_ACCUM(acc, shared_vecs[group_base + w]);
  }

  // Number of valid elements in the final texel along the packed dim; zero
  // means that texel has no padding.
  const int nspill = mod4(tin_sizes[packed_dim]);
  // True when this thread's texel is the last one along the packed dim, which
  // is the only texel that can contain padding elements.
  const bool on_last_packed_texel =
      scan_pos[packed_dim] == (tin_limits[packed_dim] - 1);

  // Force padding elements to 0 before postprocessing and writing out.
  if (on_last_packed_texel && nspill > 0) {
    [[unroll]] for (int lane = nspill; lane < 4; lane++) {
      acc[lane] = 0;
    }
  }

  // worker is 0 here, so the result lands at position 0 along the reduce dim.
  scan_pos[reduce_dim] = worker;
  write_texel(tout, scan_pos, POSTPROCESS(acc));
}

/*
* Compute reduction where the reduction dim is also the packed dim. This case is
* complex because the reduction needs to occur over the individual texels.
* Therefore, in this algorithm each element of an accumulator texel is itself a
* partial output. Special care has to be taken to ignore padding elements in
* texels (which occur when the size of the packed dim is not a multiple of 4)
* so that they do not influence the output of reduction.
*/
void reduce_packed_dim(const ivec2 tid, ivec3 scan_pos) {
  // tid.x is this thread's worker index within its group; tid.y selects the
  // group, and therefore the run of NWORKERS shared memory elements it owns.
  const int worker = tid.x;
  const int group_base = tid.y * NWORKERS;

  // Number of valid (non-padding) elements in the last texel of the row; zero
  // means the row length is a multiple of 4 and there is no padding.
  const int nspill = mod4(tin_sizes[packed_dim]);
  // Element count covered by "complete" texels only. A partially filled last
  // texel is handled separately below.
  const int reduce_len = tin_sizes[packed_dim] - nspill;

  // Seed every lane of the accumulator with the first element of the row
  // (vec4(scalar) broadcasts). Whether the value is actually used depends on
  // how INIT_ACCUM is defined.
  scan_pos[reduce_dim] = 0;
  vec4 acc = INIT_ACCUM(vec4(load_texel(tin, scan_pos).x));

  // Strided partial reduction over complete texels: this worker visits texels
  // worker, worker + NWORKERS, ... while elem tracks the index of the first
  // element of the texel being visited.
  scan_pos[reduce_dim] = worker;
  for (int elem = worker * 4; elem < reduce_len; elem += NWORKERS * 4) {
    acc = UPDATE_ACCUM(acc, load_texel(tin, scan_pos));
    scan_pos[reduce_dim] += NWORKERS;
  }

  // If this worker stopped on the final texel and that texel has padding,
  // fold in only its valid elements, one at a time, via lane x.
  const bool at_partial_tail =
      nspill > 0 && scan_pos[reduce_dim] == tin_limits[reduce_dim] - 1;
  if (at_partial_tail) {
    const vec4 tail = load_texel(tin, scan_pos);
    for (int lane = 0; lane < nspill; lane++) {
      acc.x = UPDATE_ACCUM(acc.x, tail[lane]);
    }
  }

  // Publish the partial result, then wait until every worker has done so.
  shared_vecs[tid_to_smi(tid)] = acc;
  barrier();

  // The row collapses to a single output element, so only worker 0 of each
  // group combines the partial results and writes the output.
  if (worker != 0) {
    return;
  }

  // Combine the per-worker partial results of this group.
  acc = shared_vecs[group_base];
  for (int w = 1; w < NWORKERS; w++) {
    acc = UPDATE_ACCUM(acc, shared_vecs[group_base + w]);
  }

  // Each lane of the texel is itself a partial result; fold the four lanes
  // into one scalar.
  float result = acc.x;
  [[unroll]] for (int lane = 1; lane < 4; lane++) {
    result = UPDATE_ACCUM(acc[lane], result);
  }

  // worker is 0 here, so the result lands at position 0 along the reduce dim.
  // Only lane x carries data; the remaining lanes are written as 0.
  scan_pos[reduce_dim] = worker;
  write_texel(tout, scan_pos, POSTPROCESS(vec4(result, 0, 0, 0)));
}

// Entry point: works out which reduction row this invocation belongs to and
// dispatches to the packed or non-packed reduction routine.
void main() {
  // tid.x: worker index along the reduce dim; tid.y: group index along the
  // group dim. Both are taken from the local (per work group) invocation ID.
  const ivec2 tid = ivec2(
      gl_LocalInvocationID[reduce_dim],
      gl_LocalInvocationID[group_dim]);

  // Start of the reduction row: the global position with the reduce dim
  // zeroed, so every worker of a group starts from the same position.
  ivec3 row_origin = ivec3(gl_GlobalInvocationID);
  row_origin[reduce_dim] = 0;

  // Rows that fall outside the input extents have nothing to reduce.
  // NOTE(review): invocations that exit here never reach the barrier() inside
  // the reduce_* functions, while other groups in the same work group may.
  // Confirm the dispatch sizes make this safe.
  if (any(greaterThanEqual(row_origin, tin_limits))) {
    return;
  }

  if (reduce_dim == packed_dim) {
    reduce_packed_dim(tid, row_origin);
  } else {
    reduce_nonpacked_dim(tid, row_origin);
  }
}
29 changes: 29 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/reduce.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Variant table for the reduce.glsl shader template. INIT_ACCUM, UPDATE_ACCUM
# and POSTPROCESS are substituted into the bodies of the macros of the same
# name in reduce.glsl.
reduce:
# Defaults shared by all variants unless overridden below: a zero-initialized
# accumulator, additive update, and no postprocessing (i.e. a plain sum).
parameter_names_with_default_values:
DTYPE: float
STORAGE: texture3d
INIT_ACCUM: VEC4_T(0)
UPDATE_ACCUM: accum + new_val
POSTPROCESS: accum
# DTYPE values to generate variants for (presumably one per shader variant
# below; confirm against the shader codegen).
generate_variant_forall:
DTYPE:
- VALUE: half
- VALUE: float
shader_variants:
# sum: uses the defaults unchanged.
- NAME: sum
# mean: accumulates like sum, then divides by the size of the reduced dim.
- NAME: mean
POSTPROCESS: (accum / tin_sizes[reduce_dim])
# amax: seeds the accumulator with the first value of the row and keeps the
# element-wise maximum.
- NAME: amax
INIT_ACCUM: first_val
UPDATE_ACCUM: max(accum, new_val)
POSTPROCESS: accum
# amin: seeds the accumulator with the first value of the row and keeps the
# element-wise minimum.
- NAME: amin
INIT_ACCUM: first_val
UPDATE_ACCUM: min(accum, new_val)
POSTPROCESS: accum
108 changes: 0 additions & 108 deletions backends/vulkan/runtime/graph/ops/glsl/sum_dim.glsl

This file was deleted.

16 changes: 0 additions & 16 deletions backends/vulkan/runtime/graph/ops/glsl/sum_dim.yaml

This file was deleted.

Loading
Loading