Skip to content

Commit

Permalink
[js/web] optimize reduce related operators (#17957)
Browse files Browse the repository at this point in the history
### Description
<!-- Describe your changes. -->



### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
  • Loading branch information
xhcao authored Nov 2, 2023
1 parent 8978bdc commit 8d48d3e
Show file tree
Hide file tree
Showing 2 changed files with 380 additions and 11 deletions.
266 changes: 266 additions & 0 deletions js/web/lib/wasm/jsep/webgpu/ops/reduce-shared.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,266 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {DataType} from '../../../wasm-common';
import {TensorView} from '../../tensor-view';
import {ShapeUtil} from '../../util';
import {ComputeContext, ProgramInfo, ProgramShaderCacheInfo} from '../types';

import {inputVariable, outputVariable, ShaderHelper} from './common';
import {createReduceAttributesFromInputs, ReduceAttributes} from './reduce';
import {createTransposeProgramInfo} from './transpose';

const reduceOps: {[key: string]: string} = {
max: 'select(bestValue, candidate, candidate > bestValue)',
min: 'select(bestValue, candidate, candidate < bestValue)',
mean: 'bestValue + candidate',
sum: 'bestValue + candidate',
prod: 'bestValue * candidate',
sumSquare: 'bestValue + candidate * candidate',
logSumExp: 'bestValue + exp(candidate)',
l1: 'bestValue + abs(candidate)',
l2: 'bestValue + candidate * candidate',
logSum: 'bestValue + candidate'
};

const reduceSharedOps: {[key: string]: string} = {
max: 'select(bestValue, candidate, candidate > bestValue)',
min: 'select(bestValue, candidate, candidate < bestValue)',
mean: 'bestValue + candidate',
sum: 'bestValue + candidate',
prod: 'bestValue * candidate',
sumSquare: 'bestValue + candidate',
logSumExp: 'bestValue + candidate',
l1: 'bestValue + candidate',
l2: 'bestValue + candidate',
logSum: 'bestValue + candidate'
};

const reduceInitValues: {[key: string]: string} = {
max: '_A[offset]',
min: '_A[offset]',
mean: '0',
sum: '0',
prod: '1',
sumSquare: '0',
logSumExp: '0',
l1: '0',
l2: '0',
logSum: '0'
};

const reduceOutputValues: {[key: string]: string} = {
max: 'bestValue',
min: 'bestValue',
sum: 'bestValue',
prod: 'bestValue',
sumSquare: 'bestValue',
logSumExp: 'log(bestValue)',
l1: 'bestValue',
l2: 'sqrt(bestValue)',
logSum: 'log(bestValue)'
};

const getInnerMostAxes = (numInnerAxes: number, rank: number): number[] => {
const res = [];
for (let i = rank - numInnerAxes; i < rank; ++i) {
res.push(i);
}
return res;
};

const computeOutAndReduceShapes = (shape: readonly number[], axes: readonly number[]): [number[], number[]] => {
const outputShape = [];
const rank = shape.length;
for (let dim = 0; dim < rank; dim++) {
if (axes.indexOf(dim) === -1) {
outputShape.push(shape[dim]);
}
}
const reduceShape = axes.map(dim => shape[dim]);
return [outputShape, reduceShape];
};

const expandShapeToKeepDim = (shape: number[], axes: number[]): number[] => {
const rank = shape.length + axes.length;
const expandShape = [];
let shapeIdx = 0;
for (let dim = 0; dim < rank; dim++) {
if (axes.indexOf(dim) === -1) {
expandShape.push(shape[shapeIdx++]);
} else {
expandShape.push(1);
}
}
return expandShape;
};

const areAxesInnerMostDims = (axes: number[], rank: number): boolean => {
for (let i = 0; i < axes.length; ++i) {
if (axes[axes.length - i - 1] !== rank - 1 - i) {
return false;
}
}
return true;
};

const getAxesPermutation = (axes: number[], rank: number): number[] => {
const res = [];
if (!areAxesInnerMostDims(axes, rank)) {
for (let i = 0; i < rank; ++i) {
if (axes.indexOf(i) === -1) {
res.push(i);
}
}
axes.forEach(axis => res.push(axis));
}
return res;
};

export const createReduceSharedProgramInfo =
(name: string, shaderCache: ProgramShaderCacheInfo, inputs: readonly TensorView[], reduceType: string,
outputDataType: DataType, outputShape: number[], reduceShape: number[]): ProgramInfo => {
const inputShape = inputs[0].dims;

const outputSize = ShapeUtil.size(outputShape);
const reduceSize = ShapeUtil.size(reduceShape);

const input = inputVariable('_A', inputs[0].dataType, inputShape);
const output = outputVariable('output', outputDataType, outputShape);

const workgroupSize = 32;

const sharedMemorySnippet = `
var<workgroup> aBestValues : array<${output.type.storage}, ${workgroupSize}>;
`;

const getShaderSource = (shaderHelper: ShaderHelper) => `
${shaderHelper.registerUniform('reduceSize', 'u32').declareVariables(input, output)}
${sharedMemorySnippet}
fn DIV_CEIL(a : u32, b : u32) -> u32 {
return ((a - 1u) / b + 1u);
}
${shaderHelper.mainStart(workgroupSize)}
let local_idx = local_id.x;
let outputIndex = global_idx / ${workgroupSize};
let offset = outputIndex * uniforms.reduceSize;
var bestValue = ${output.type.storage}(${reduceInitValues[reduceType]});
let Length = uniforms.reduceSize;
for (var k = local_idx; k < Length; k = k + ${workgroupSize}) {
let candidate = ${output.type.storage}(${input.getByOffset('offset + k')});
bestValue = ${reduceOps[reduceType]};
}
aBestValues[local_idx] = bestValue;
workgroupBarrier();
var reduceSize = min(Length, ${workgroupSize}u);
for (var currentSize = reduceSize / 2u; reduceSize > 1u;
currentSize = reduceSize / 2u) {
let interval = DIV_CEIL(reduceSize, 2u);
if (local_idx < currentSize) {
let candidate = aBestValues[local_idx + interval];
bestValue = ${reduceSharedOps[reduceType]};
aBestValues[local_idx] = bestValue;
}
reduceSize = interval;
workgroupBarrier();
}
if (local_idx == 0u) {
${
output.setByOffset(
'outputIndex',
`${
reduceType === 'mean' ? `bestValue / ${output.type.storage}(uniforms.reduceSize)` :
`${reduceOutputValues[reduceType]}`}`)};
}
}`;

// One work group is responsible for only one element of output.
return {
name,
shaderCache,
getShaderSource,
getRunData: () => ({
outputs: [{dims: outputShape, dataType: outputDataType}],
dispatchGroup: {x: outputSize},
programUniforms: [{type: 'uint32', data: reduceSize}]
}),
};
};

const reduceCommon =
(context: ComputeContext, name: string, attributes: ReduceAttributes,
reduceType: 'sum'|'sumSquare'|'prod'|'min'|'max'|'mean'|'logSumExp'|'l1'|'l2'|'logSum'): void => {
const updatedAttributes: ReduceAttributes =
context.inputs.length === 1 ? attributes : createReduceAttributesFromInputs(context.inputs, attributes);

let updatedAxes = updatedAttributes.axes;
if (updatedAxes.length === 0 && !updatedAttributes.noopWithEmptyAxes) {
updatedAxes = context.inputs[0].dims.map((s, i) => i);
}
const normalizeAxes = ShapeUtil.normalizeAxes(updatedAxes, context.inputs[0].dims.length);

let axes = normalizeAxes;
let input = context.inputs[0];
const permutedAxes = getAxesPermutation(axes, context.inputs[0].dims.length);
if (permutedAxes.length > 0) {
input = context.compute(
createTransposeProgramInfo(context.inputs[0], permutedAxes), {inputs: [0], outputs: [-1]})[0];
axes = getInnerMostAxes(axes.length, input.dims.length);
}

const [outputShape, reduceShape] = computeOutAndReduceShapes(input.dims, axes);
let finalOutputShape = outputShape;
if (updatedAttributes.keepDims) {
finalOutputShape = expandShapeToKeepDim(outputShape, normalizeAxes);
}

context.compute(
createReduceSharedProgramInfo(
name, {hint: updatedAttributes.cacheKey, inputDependencies: ['type']}, [input], reduceType,
context.inputs[0].dataType, finalOutputShape, reduceShape),
{inputs: [input]});
};

export const reduceMeanShared = (context: ComputeContext, attributes: ReduceAttributes): void => {
reduceCommon(context, 'ReduceMeanShared', attributes, 'mean');
};

export const reduceL1Shared = (context: ComputeContext, attributes: ReduceAttributes): void => {
reduceCommon(context, 'ReduceL1Shared', attributes, 'l1');
};

export const reduceL2Shared = (context: ComputeContext, attributes: ReduceAttributes): void => {
reduceCommon(context, 'ReduceL2Shared', attributes, 'l2');
};

export const reduceLogSumExpShared = (context: ComputeContext, attributes: ReduceAttributes): void => {
reduceCommon(context, 'ReduceLogSumExpShared', attributes, 'logSumExp');
};

export const reduceMaxShared = (context: ComputeContext, attributes: ReduceAttributes): void => {
reduceCommon(context, 'ReduceMaxShared', attributes, 'max');
};

export const reduceMinShared = (context: ComputeContext, attributes: ReduceAttributes): void => {
reduceCommon(context, 'ReduceMinShared', attributes, 'min');
};

export const reduceProdShared = (context: ComputeContext, attributes: ReduceAttributes): void => {
reduceCommon(context, 'ReduceProdShared', attributes, 'prod');
};

export const reduceSumShared = (context: ComputeContext, attributes: ReduceAttributes): void => {
reduceCommon(context, 'ReduceSumShared', attributes, 'sum');
};

export const reduceSumSquareShared = (context: ComputeContext, attributes: ReduceAttributes): void => {
reduceCommon(context, 'ReduceSumSquareShared', attributes, 'sumSquare');
};

export const reduceLogSumShared = (context: ComputeContext, attributes: ReduceAttributes): void => {
reduceCommon(context, 'ReduceLogSumShared', attributes, 'logSum');
};
Loading

0 comments on commit 8d48d3e

Please sign in to comment.