-
Notifications
You must be signed in to change notification settings - Fork 419
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[ET-VK] Implement generic reduction shader + mean, sum, amax, amin
Pull Request resolved: #6457 ## Context Introduce a generic shader to compute reduction along a single dim, and `keepdim = True`. With the generic shader template, `mean`, `sum`, `amin`, and `amax` can be implemented. ghstack-source-id: 249659867 @exported-using-ghexport Differential Revision: [D64840504](https://our.internmc.facebook.com/intern/diff/D64840504/)
- Loading branch information
Showing
13 changed files
with
418 additions
and
424 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,214 @@ | ||
/* | ||
* Copyright (c) Meta Platforms, Inc. and affiliates. | ||
* All rights reserved. | ||
* | ||
* This source code is licensed under the BSD-style license found in the | ||
* LICENSE file in the root directory of this source tree. | ||
*/ | ||
|
||
#version 450 core | ||
|
||
#define PRECISION ${PRECISION} | ||
#define VEC4_T ${texel_load_type(DTYPE, STORAGE)} | ||
|
||
${define_active_storage_type(STORAGE)} | ||
|
||
#extension GL_EXT_control_flow_attributes : require | ||
|
||
layout(std430) buffer; | ||
|
||
${layout_declare_tensor(B, "w", "tout", DTYPE, STORAGE)} | ||
${layout_declare_tensor(B, "r", "tin", DTYPE, STORAGE)} | ||
|
||
${layout_declare_ubo(B, "ivec3", "tin_limits")} | ||
${layout_declare_ubo(B, "ivec4", "tin_sizes")} | ||
|
||
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; | ||
|
||
layout(constant_id = 3) const int packed_dim = 0; | ||
layout(constant_id = 4) const int reduce_dim = 0; | ||
layout(constant_id = 5) const int group_dim = 1; | ||
|
||
// A more verbose name would be NWORKERS_PER_GROUP. This describes the number of | ||
// threads that will co-operate to compute one reduction output. There may be | ||
// multiple groups computing distinct reduction outputs within one work group. | ||
#define NWORKERS 4 | ||
|
||
// Sets an upper limit on the total size of a work group based on how many | ||
// elements are allocated in the shared memory array below. Each thread in the | ||
// work group will write into its assigned element in the shared array. | ||
#define MAX_NTHREADS 16 | ||
|
||
|
||
shared vec4 shared_vecs[MAX_NTHREADS]; | ||
|
||
#include "indexing_utils.h" | ||
|
||
int tid_to_smi(const ivec2 tid) { | ||
return tid.x + tid.y * NWORKERS; | ||
} | ||
|
||
/* | ||
* The functions below compute reduction along a single dimension for a tensor. | ||
* The shader template generalize reduction by abstracting the initial value of | ||
* the accumulator, the calculation used to update the accumulator with new | ||
* values, and a postprocessing calculation that can be used to modify the | ||
* accumulator before writing to output. | ||
* | ||
* This shader also utilize shared memory to have multiple threads help compute | ||
* the max and sum reduction operations. A total of NGROUPS x NWORKERS threads | ||
* are expected to be launched. Each group works on a unique reduction "row", and | ||
* within a group NWORKERS threads co-operate to compute the max and sum of one | ||
* "row". Each worker in the group is responsible for computing a partial output | ||
* of the "row" and uploading it to shared memory; the overall reduction output | ||
* can then be determined by aggregating the partial outputs stored in shared | ||
* memory. | ||
* | ||
* As a caveat, this shader does not currently support cases where `batch` > 1 | ||
* and the reduce dim happens to also be the batch concatenation dim. To support | ||
* this, there will need to be additional logic to set the starting value of | ||
* `scan_pos[reduce_dim]`. Since this is not expected to be a common use-case, | ||
* supporting this case is left as an exercise for when it is required. | ||
*/ | ||
|
||
// Initializing the accumulator accepts the first value in the reduction row, | ||
// since some reduction operations (i.e. amax, amin) prefer to initialize with | ||
// a data point instead of a static value. | ||
#define INIT_ACCUM(first_val) ${INIT_ACCUM} | ||
#define UPDATE_ACCUM(accum, new_val) ${UPDATE_ACCUM} | ||
// Useful for operators such as mean which want to perform a final calculation | ||
// with the accumulator. | ||
#define POSTPROCESS(accum) ${POSTPROCESS} | ||
|
||
/* | ||
* Computes reduction where the reduction dim is orthogonal to the packed dim. | ||
* This case is simpler because each element of a texel belongs to a separate | ||
* reduction "group", meaning we don't have to perform reduction along a texel. | ||
*/ | ||
void reduce_nonpacked_dim(const ivec2 tid, ivec3 scan_pos) { | ||
// shared memory index of this thread | ||
const int smi = tid_to_smi(tid); | ||
|
||
scan_pos[reduce_dim] = 0; | ||
vec4 accum = INIT_ACCUM(load_texel(tin, scan_pos)); | ||
|
||
scan_pos[reduce_dim] = tid.x; | ||
// Partially accumulate over elements i, i + NWORKERS, i + 2*NWORKERS, ... of | ||
// the reduction row | ||
for (int i = tid.x; i < tin_sizes[reduce_dim]; | ||
i += NWORKERS, scan_pos[reduce_dim] += NWORKERS) { | ||
accum = UPDATE_ACCUM(accum, load_texel(tin, scan_pos)); | ||
} | ||
// Write partial output to shared memory and synchronize work group | ||
shared_vecs[smi] = accum; | ||
barrier(); | ||
|
||
// Since the reduction row is reduced to only one element, only the "main" | ||
// thread in the group needs aggregate the partial outputs | ||
if (tid.x == 0) { | ||
// Iterate over the partial outputs to obtain the overall output | ||
int group_i = tid.y * NWORKERS; | ||
accum = shared_vecs[group_i++]; | ||
for (int i = 1; i < NWORKERS; i++, group_i++) { | ||
accum = UPDATE_ACCUM(accum, shared_vecs[group_i]); | ||
} | ||
|
||
// Determine if there are any padding elements in the final texel of the | ||
// packed dimension | ||
const int nspill = mod4(tin_sizes[packed_dim]); | ||
// Detect if this thread is working on the final texels of the packed | ||
// dimension, which may have padding elements | ||
const bool is_last_texel = | ||
scan_pos[packed_dim] == (tin_limits[packed_dim] - 1); | ||
|
||
// Explicitly set padding elements to 0 | ||
if (is_last_texel && nspill > 0) { | ||
[[unroll]] for (int i = nspill; i < 4; i++) { | ||
accum[i] = 0; | ||
} | ||
} | ||
scan_pos[reduce_dim] = tid.x; | ||
write_texel(tout, scan_pos, POSTPROCESS(accum)); | ||
} | ||
} | ||
|
||
/* | ||
* Compute reduction where the reduction dim is also the packed dim. This case is | ||
* complex because the reduction needs to occur over the individual texels. | ||
* Therefore, in this algorithm each element of the accumulator texels are | ||
* themselves partial outputs. Special care has to be taken to ignore padding | ||
* elements in texels (which occur when the size of the packed dim is not a | ||
* multiple of 4) so that they do not influence the output of reduction. | ||
*/ | ||
void reduce_packed_dim(const ivec2 tid, ivec3 scan_pos) { | ||
// shared memory index of this thread | ||
const int smi = tid_to_smi(tid); | ||
|
||
// Number of non-padding elements in the last texel in the reduction row | ||
const int nspill = mod4(tin_sizes[packed_dim]); | ||
// Only reduce up to the last "complete" texel. The last texel will need to be | ||
// handled specially if it has padding elements. | ||
const int reduce_len = tin_sizes[packed_dim] - nspill; | ||
|
||
scan_pos[reduce_dim] = 0; | ||
vec4 accum = INIT_ACCUM(vec4(load_texel(tin, scan_pos).x)); | ||
|
||
// Partially accumulate over elements i, i + NWORKERS, i + 2*NWORKERS, ... of | ||
// the reduction row | ||
scan_pos[reduce_dim] = tid.x; | ||
for (int i = tid.x * 4; i < reduce_len; | ||
i += NWORKERS * 4, scan_pos[reduce_dim] += NWORKERS) { | ||
accum = UPDATE_ACCUM(accum, load_texel(tin, scan_pos)); | ||
} | ||
// For the last texel in the dim, if there are padding elements then each | ||
// element of the texel needs to be processed individually such that the | ||
// padding elements are ignored | ||
if (scan_pos[reduce_dim] == tin_limits[reduce_dim] - 1 && nspill > 0) { | ||
const vec4 intex = load_texel(tin, scan_pos); | ||
for (int i = 0; i < nspill; i++) { | ||
accum.x = UPDATE_ACCUM(accum.x, intex[i]); | ||
} | ||
} | ||
// Write partial output to shared memory and synchronize work group | ||
shared_vecs[smi] = accum; | ||
barrier(); | ||
|
||
// Since the reduction row is reduced to only one element, only the "main" | ||
// thread in the group needs aggregate the partial outputs | ||
if (tid.x == 0) { | ||
// Iterate over the partial maximums to obtain the overall maximum | ||
int group_i = tid.y * NWORKERS; | ||
accum = shared_vecs[group_i++]; | ||
for (int i = 1; i < NWORKERS; i++, group_i++) { | ||
accum = UPDATE_ACCUM(accum, shared_vecs[group_i]); | ||
} | ||
// Each element of the texel is itself a partial maximum; iterate over the | ||
// texel to find the actual maximum | ||
float accum_final = accum.x; | ||
[[unroll]] for (int i = 1; i < 4; i++) { | ||
accum_final = UPDATE_ACCUM(accum[i], accum_final); | ||
} | ||
|
||
scan_pos[reduce_dim] = tid.x; | ||
write_texel(tout, scan_pos, POSTPROCESS(vec4(accum_final, 0, 0, 0))); | ||
} | ||
} | ||
|
||
void main() { | ||
ivec3 scan_pos = ivec3(gl_GlobalInvocationID); | ||
scan_pos[reduce_dim] = 0; | ||
|
||
const ivec2 tid = ivec2( | ||
gl_LocalInvocationID[reduce_dim], | ||
gl_LocalInvocationID[group_dim]); | ||
|
||
if (any(greaterThanEqual(scan_pos, tin_limits))) { | ||
return; | ||
} | ||
|
||
if (reduce_dim != packed_dim) { | ||
reduce_nonpacked_dim(tid, scan_pos); | ||
} else { | ||
reduce_packed_dim(tid, scan_pos); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
reduce: | ||
parameter_names_with_default_values: | ||
DTYPE: float | ||
STORAGE: texture3d | ||
INIT_ACCUM: VEC4_T(0) | ||
UPDATE_ACCUM: accum + new_val | ||
POSTPROCESS: accum | ||
generate_variant_forall: | ||
DTYPE: | ||
- VALUE: half | ||
- VALUE: float | ||
shader_variants: | ||
- NAME: sum | ||
- NAME: mean | ||
POSTPROCESS: (accum / tin_sizes[reduce_dim]) | ||
- NAME: amax | ||
INIT_ACCUM: first_val | ||
UPDATE_ACCUM: max(accum, new_val) | ||
POSTPROCESS: accum | ||
- NAME: amin | ||
INIT_ACCUM: first_val | ||
UPDATE_ACCUM: min(accum, new_val) | ||
POSTPROCESS: accum |
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.