Skip to content

Commit

Permalink
[ET-VK] Implement generic reduction shader + mean, sum, amax, amin
Browse files Browse the repository at this point in the history
Pull Request resolved: #6457

## Context

Introduce a generic shader to compute reduction along a single dim, and `keepdim = True`. With the generic shader template, `mean`, `sum`, `amin`, and `amax` can be implemented.
ghstack-source-id: 249659867
@exported-using-ghexport

Differential Revision: [D64840504](https://our.internmc.facebook.com/intern/diff/D64840504/)
  • Loading branch information
SS-JIA committed Oct 23, 2024
1 parent 57837e2 commit bdaec7e
Show file tree
Hide file tree
Showing 13 changed files with 418 additions and 424 deletions.
1 change: 0 additions & 1 deletion backends/vulkan/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ runtime.python_library(
"//executorch/backends/transforms:fuse_conv_with_clamp",
"//executorch/backends/transforms:fuse_dequant_linear",
"//executorch/backends/transforms:fuse_view_copy",
"//executorch/backends/transforms:mean_to_sum_div",
"//executorch/backends/transforms:remove_clone_ops",
"//executorch/backends/vulkan/_passes:vulkan_passes",
"//executorch/exir:graph_module",
Expand Down
7 changes: 4 additions & 3 deletions backends/vulkan/partitioner/supported_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,10 @@ def __contains__(self, op):
# Reduction
exir_ops.edge.aten._log_softmax.default,
exir_ops.edge.aten._softmax.default,
exir_ops.edge.aten.mean.dim,
exir_ops.edge.aten.sum.dim_IntList,
exir_ops.edge.aten.amax.default,
exir_ops.edge.aten.amin.default,
# 2D Pooling
exir_ops.edge.aten.avg_pool2d.default,
exir_ops.edge.aten.max_pool2d_with_indices.default,
Expand All @@ -101,9 +105,6 @@ def __contains__(self, op):
]

NO_DYNAMIC_SHAPE = [
# Reduction
exir_ops.edge.aten.mean.dim,
exir_ops.edge.aten.sum.dim_IntList,
# Normalization
exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
exir_ops.edge.aten.native_layer_norm.default,
Expand Down
214 changes: 214 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/reduce.glsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#version 450 core

#define PRECISION ${PRECISION}
#define VEC4_T ${texel_load_type(DTYPE, STORAGE)}

${define_active_storage_type(STORAGE)}

#extension GL_EXT_control_flow_attributes : require

layout(std430) buffer;

${layout_declare_tensor(B, "w", "tout", DTYPE, STORAGE)}
${layout_declare_tensor(B, "r", "tin", DTYPE, STORAGE)}

${layout_declare_ubo(B, "ivec3", "tin_limits")}
${layout_declare_ubo(B, "ivec4", "tin_sizes")}

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

layout(constant_id = 3) const int packed_dim = 0;
layout(constant_id = 4) const int reduce_dim = 0;
layout(constant_id = 5) const int group_dim = 1;

// A more verbose name would be NWORKERS_PER_GROUP. This describes the number of
// threads that will co-operate to compute one reduction output. There may be
// multiple groups computing distinct reduction outputs within one work group.
#define NWORKERS 4

// Sets an upper limit on the total size of a work group based on how many
// elements are allocated in the shared memory array below. Each thread in the
// work group will write into its assigned element in the shared array.
#define MAX_NTHREADS 16


shared vec4 shared_vecs[MAX_NTHREADS];

#include "indexing_utils.h"

int tid_to_smi(const ivec2 tid) {
return tid.x + tid.y * NWORKERS;
}

/*
* The functions below compute reduction along a single dimension for a tensor.
* The shader template generalize reduction by abstracting the initial value of
* the accumulator, the calculation used to update the accumulator with new
* values, and a postprocessing calculation that can be used to modify the
* accumulator before writing to output.
*
* This shader also utilize shared memory to have multiple threads help compute
* the max and sum reduction operations. A total of NGROUPS x NWORKERS threads
* are expected to be launched. Each group works on a unique reduction "row", and
* within a group NWORKERS threads co-operate to compute the max and sum of one
* "row". Each worker in the group is responsible for computing a partial output
* of the "row" and uploading it to shared memory; the overall reduction output
* can then be determined by aggregating the partial outputs stored in shared
* memory.
*
* As a caveat, this shader does not currently support cases where `batch` > 1
* and the reduce dim happens to also be the batch concatenation dim. To support
* this, there will need to be additional logic to set the starting value of
* `scan_pos[reduce_dim]`. Since this is not expected to be a common use-case,
* supporting this case is left as an exercise for when it is required.
*/

// Initializing the accumulator accepts the first value in the reduction row,
// since some reduction operations (i.e. amax, amin) prefer to initialize with
// a data point instead of a static value.
#define INIT_ACCUM(first_val) ${INIT_ACCUM}
#define UPDATE_ACCUM(accum, new_val) ${UPDATE_ACCUM}
// Useful for operators such as mean which want to perform a final calculation
// with the accumulator.
#define POSTPROCESS(accum) ${POSTPROCESS}

/*
* Computes reduction where the reduction dim is orthogonal to the packed dim.
* This case is simpler because each element of a texel belongs to a separate
* reduction "group", meaning we don't have to perform reduction along a texel.
*/
void reduce_nonpacked_dim(const ivec2 tid, ivec3 scan_pos) {
// shared memory index of this thread
const int smi = tid_to_smi(tid);

scan_pos[reduce_dim] = 0;
vec4 accum = INIT_ACCUM(load_texel(tin, scan_pos));

scan_pos[reduce_dim] = tid.x;
// Partially accumulate over elements i, i + NWORKERS, i + 2*NWORKERS, ... of
// the reduction row
for (int i = tid.x; i < tin_sizes[reduce_dim];
i += NWORKERS, scan_pos[reduce_dim] += NWORKERS) {
accum = UPDATE_ACCUM(accum, load_texel(tin, scan_pos));
}
// Write partial output to shared memory and synchronize work group
shared_vecs[smi] = accum;
barrier();

// Since the reduction row is reduced to only one element, only the "main"
// thread in the group needs aggregate the partial outputs
if (tid.x == 0) {
// Iterate over the partial outputs to obtain the overall output
int group_i = tid.y * NWORKERS;
accum = shared_vecs[group_i++];
for (int i = 1; i < NWORKERS; i++, group_i++) {
accum = UPDATE_ACCUM(accum, shared_vecs[group_i]);
}

// Determine if there are any padding elements in the final texel of the
// packed dimension
const int nspill = mod4(tin_sizes[packed_dim]);
// Detect if this thread is working on the final texels of the packed
// dimension, which may have padding elements
const bool is_last_texel =
scan_pos[packed_dim] == (tin_limits[packed_dim] - 1);

// Explicitly set padding elements to 0
if (is_last_texel && nspill > 0) {
[[unroll]] for (int i = nspill; i < 4; i++) {
accum[i] = 0;
}
}
scan_pos[reduce_dim] = tid.x;
write_texel(tout, scan_pos, POSTPROCESS(accum));
}
}

/*
* Compute reduction where the reduction dim is also the packed dim. This case is
* complex because the reduction needs to occur over the individual texels.
* Therefore, in this algorithm each element of the accumulator texels are
* themselves partial outputs. Special care has to be taken to ignore padding
* elements in texels (which occur when the size of the packed dim is not a
* multiple of 4) so that they do not influence the output of reduction.
*/
void reduce_packed_dim(const ivec2 tid, ivec3 scan_pos) {
// shared memory index of this thread
const int smi = tid_to_smi(tid);

// Number of non-padding elements in the last texel in the reduction row
const int nspill = mod4(tin_sizes[packed_dim]);
// Only reduce up to the last "complete" texel. The last texel will need to be
// handled specially if it has padding elements.
const int reduce_len = tin_sizes[packed_dim] - nspill;

scan_pos[reduce_dim] = 0;
vec4 accum = INIT_ACCUM(vec4(load_texel(tin, scan_pos).x));

// Partially accumulate over elements i, i + NWORKERS, i + 2*NWORKERS, ... of
// the reduction row
scan_pos[reduce_dim] = tid.x;
for (int i = tid.x * 4; i < reduce_len;
i += NWORKERS * 4, scan_pos[reduce_dim] += NWORKERS) {
accum = UPDATE_ACCUM(accum, load_texel(tin, scan_pos));
}
// For the last texel in the dim, if there are padding elements then each
// element of the texel needs to be processed individually such that the
// padding elements are ignored
if (scan_pos[reduce_dim] == tin_limits[reduce_dim] - 1 && nspill > 0) {
const vec4 intex = load_texel(tin, scan_pos);
for (int i = 0; i < nspill; i++) {
accum.x = UPDATE_ACCUM(accum.x, intex[i]);
}
}
// Write partial output to shared memory and synchronize work group
shared_vecs[smi] = accum;
barrier();

// Since the reduction row is reduced to only one element, only the "main"
// thread in the group needs aggregate the partial outputs
if (tid.x == 0) {
// Iterate over the partial maximums to obtain the overall maximum
int group_i = tid.y * NWORKERS;
accum = shared_vecs[group_i++];
for (int i = 1; i < NWORKERS; i++, group_i++) {
accum = UPDATE_ACCUM(accum, shared_vecs[group_i]);
}
// Each element of the texel is itself a partial maximum; iterate over the
// texel to find the actual maximum
float accum_final = accum.x;
[[unroll]] for (int i = 1; i < 4; i++) {
accum_final = UPDATE_ACCUM(accum[i], accum_final);
}

scan_pos[reduce_dim] = tid.x;
write_texel(tout, scan_pos, POSTPROCESS(vec4(accum_final, 0, 0, 0)));
}
}

void main() {
ivec3 scan_pos = ivec3(gl_GlobalInvocationID);
scan_pos[reduce_dim] = 0;

const ivec2 tid = ivec2(
gl_LocalInvocationID[reduce_dim],
gl_LocalInvocationID[group_dim]);

if (any(greaterThanEqual(scan_pos, tin_limits))) {
return;
}

if (reduce_dim != packed_dim) {
reduce_nonpacked_dim(tid, scan_pos);
} else {
reduce_packed_dim(tid, scan_pos);
}
}
29 changes: 29 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/reduce.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

reduce:
parameter_names_with_default_values:
DTYPE: float
STORAGE: texture3d
INIT_ACCUM: VEC4_T(0)
UPDATE_ACCUM: accum + new_val
POSTPROCESS: accum
generate_variant_forall:
DTYPE:
- VALUE: half
- VALUE: float
shader_variants:
- NAME: sum
- NAME: mean
POSTPROCESS: (accum / tin_sizes[reduce_dim])
- NAME: amax
INIT_ACCUM: first_val
UPDATE_ACCUM: max(accum, new_val)
POSTPROCESS: accum
- NAME: amin
INIT_ACCUM: first_val
UPDATE_ACCUM: min(accum, new_val)
POSTPROCESS: accum
108 changes: 0 additions & 108 deletions backends/vulkan/runtime/graph/ops/glsl/sum_dim.glsl

This file was deleted.

16 changes: 0 additions & 16 deletions backends/vulkan/runtime/graph/ops/glsl/sum_dim.yaml

This file was deleted.

Loading

0 comments on commit bdaec7e

Please sign in to comment.