[JS/Web] Added uniforms to Reduce, Resize and Split Ops. (#18727)
### Description
Added uniforms to the Reduce, Resize and Split ops.


### Motivation and Context
Improve performance: passing shapes and sizes as shader uniforms and caching compiled programs by input rank (`inputDependencies: ['rank']`) lets WebGPU shaders be reused across input shapes instead of being recompiled for every shape.
satyajandhyala authored Dec 12, 2023
1 parent 81796a3 commit 0ca8454
Showing 7 changed files with 219 additions and 204 deletions.
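
The recurring pattern across these diffs: shape-dependent values move out of the generated WGSL text and into shader uniforms (`uniforms.output_size`, `uniforms.input_shape`, ...), shader variables are declared by rank rather than by concrete shape, and programs declare `inputDependencies: ['rank']` so one compiled shader serves every input shape of a given rank. The sketch below is a simplified, non-authoritative illustration of that pattern assembled from helpers that appear in this commit (`registerUniform`, `createTensorShapeVariables`, `programUniforms`); the op itself (an identity copy) and its name are invented, and the `TensorView` import path is an assumption.

```typescript
// Minimal sketch of the uniform-based ProgramInfo pattern used throughout this commit.
// Import paths mirror js/web/lib/wasm/jsep/webgpu/ops/reduce.ts; the TensorView path is assumed.
import {TensorView} from '../../tensor-view';
import {ShapeUtil} from '../../util';
import {ProgramInfo} from '../types';
import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from './common';

export const createIdentitySketchProgramInfo = (input: TensorView): ProgramInfo => {
  const outputShape = input.dims.slice();
  const outputSize = ShapeUtil.size(outputShape);

  const getShaderSource = (shaderHelper: ShaderHelper) => {
    // Declared by rank, not by shape, so the WGSL text does not change when the shape changes.
    const inputVar = inputVariable('_A', input.dataType, input.dims.length);
    const outputVar = outputVariable('output', input.dataType, outputShape.length);
    return `
  ${shaderHelper.registerUniform('output_size', 'u32').declareVariables(inputVar, outputVar)}
  ${shaderHelper.mainStart()}
    ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
    ${outputVar.setByOffset('global_idx', inputVar.getByOffset('global_idx'))}
  }`;
  };

  return {
    name: 'IdentitySketch',
    // Cache the compiled shader by input rank only; any shape of that rank reuses it.
    shaderCache: {inputDependencies: ['rank']},
    getShaderSource,
    getRunData: () => ({
      outputs: [{dims: outputShape, dataType: input.dataType}],
      dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
      // Shape-dependent data is supplied per dispatch as uniforms instead of being baked into WGSL.
      programUniforms: [
        {type: 'uint32', data: outputSize},
        ...createTensorShapeVariables(input.dims),
        ...createTensorShapeVariables(outputShape),
      ],
    }),
  };
};
```
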
22 changes: 11 additions & 11 deletions js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts
@@ -23,7 +23,7 @@ import {multiHeadAttention, parseMultiHeadAttentionAttributes} from './ops/multi
import {pad, parsePadAttributes} from './ops/pad';
import * as pool from './ops/pool';
import {range} from './ops/range';
import {parseReduceAttributes, reduceL1, reduceL2, reduceLogSum, reduceLogSumExp, reduceMax, reduceMean, reduceMin, reduceProd, reduceSum, reduceSumSquare} from './ops/reduce';
import {reduceL1, reduceL2, reduceLogSum, reduceLogSumExp, reduceMax, reduceMean, reduceMin, reduceProd, reduceSum, reduceSumSquare} from './ops/reduce';
import {parseResizeAttributes, resize} from './ops/resize';
import {parseSkipLayerNormAttributes, skipLayerNorm} from './ops/skip-layer-norm';
import {parseSliceAttributes, slice} from './ops/slice';
@@ -99,16 +99,16 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
['Pow', [binaryOps.pow]],
['Range', [range]],
['Reciprocal', [unaryOps.reciprocal]],
['ReduceMin', [reduceMin, parseReduceAttributes]],
['ReduceMean', [reduceMean, parseReduceAttributes]],
['ReduceMax', [reduceMax, parseReduceAttributes]],
['ReduceSum', [reduceSum, parseReduceAttributes]],
['ReduceProd', [reduceProd, parseReduceAttributes]],
['ReduceL1', [reduceL1, parseReduceAttributes]],
['ReduceL2', [reduceL2, parseReduceAttributes]],
['ReduceLogSum', [reduceLogSum, parseReduceAttributes]],
['ReduceLogSumExp', [reduceLogSumExp, parseReduceAttributes]],
['ReduceSumSquare', [reduceSumSquare, parseReduceAttributes]],
['ReduceMin', [reduceMin]],
['ReduceMean', [reduceMean]],
['ReduceMax', [reduceMax]],
['ReduceSum', [reduceSum]],
['ReduceProd', [reduceProd]],
['ReduceL1', [reduceL1]],
['ReduceL2', [reduceL2]],
['ReduceLogSum', [reduceLogSum]],
['ReduceLogSumExp', [reduceLogSumExp]],
['ReduceSumSquare', [reduceSumSquare]],
['Relu', [unaryOps.relu]],
['Resize', [resize, parseResizeAttributes]],
['Sigmoid', [unaryOps.sigmoid]],
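
Each entry in this table maps an ONNX op type to its WebGPU implementation and, optionally, an attribute parser. The Reduce* entries drop their second element because the `parseReduceAttributes` export is removed from `reduce.ts` later in this diff, so attribute handling now lives inside the reduce implementations themselves. A rough, hypothetical sketch of how such an entry might be consumed (only `WEBGPU_OP_RESOLVE_RULES`, `resize`, `parseResizeAttributes` and `reduceSum` are real names here):

```typescript
// Illustrative only: resolving an op against a rules map shaped like WEBGPU_OP_RESOLVE_RULES.
// The local type aliases are simplifications, not the real OperatorImplementation type.
type OpImpl = (context: unknown, attributes?: unknown) => void;
type AttributeParser = (attributes: Record<string, unknown>) => unknown;

const resolveOp = (opType: string): [OpImpl, AttributeParser?] => {
  const rule = WEBGPU_OP_RESOLVE_RULES.get(opType);
  if (!rule) {
    throw new Error(`unsupported op: ${opType}`);
  }
  return rule as unknown as [OpImpl, AttributeParser?];
};

// resolveOp('Resize')    -> [resize, parseResizeAttributes]  (parser kept)
// resolveOp('ReduceSum') -> [reduceSum]                      (parser dropped by this commit)
```
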
32 changes: 16 additions & 16 deletions js/web/lib/wasm/jsep/webgpu/ops/argminmax.ts
@@ -33,23 +33,23 @@ export const argMin = (context: ComputeContext, attributes: ArgMinMaxAttributes)
const idxZero = [];
for (let k = 0; k < input.rank; k++) {
if (axes.indexOf(k) >= 0 || axes.length === 0) {
idxZero.push(`inputIndices[${k}] = 0;`); // first element
idxZero.push(`input_indices[${k}] = 0;`); // first element
}
}
return [
`${idxZero.join('\n')}`, `var value = ${input.getByOffset('inputOffset')};\nvar bestIndex : i32 = 0;`,
`if (${input.getByOffset('inputOffset')} ${attributes.selectLastIndex > 0 ? '<=' : '<'} value) {
value = ${input.getByOffset('inputOffset')};
bestIndex = i32(lastIndex);
`${idxZero.join('\n')}`, `var value = ${input.getByIndices('input_indices')};\nvar best_index : i32 = 0;`,
`if (${input.getByIndices('input_indices')} ${attributes.selectLastIndex > 0 ? '<=' : '<'} value) {
value = ${input.getByIndices('input_indices')};
best_index = i32(last_index);
}`,
'', output.setByOffset('global_idx', 'bestIndex')
'', output.setByOffset('global_idx', 'best_index')
];
};

context.compute(
createReduceProgramInfo(
'ArgMin', {hint: attributes.cacheKey}, [context.inputs[0]], argMinMaxOp, [attributes.axis], DataType.int64,
attributes.keepDims),
'ArgMin', {hint: attributes.cacheKey, inputDependencies: ['rank']}, [context.inputs[0]], argMinMaxOp,
[attributes.axis], DataType.int64, attributes.keepDims),
{inputs: [0]});
};

@@ -59,23 +59,23 @@ export const argMax = (context: ComputeContext, attributes: ArgMinMaxAttributes)
const idxZero = [];
for (let k = 0; k < input.rank; k++) {
if (axes.indexOf(k) >= 0 || axes.length === 0) {
idxZero.push(`inputIndices[${k}] = 0;`); // first element
idxZero.push(`input_indices[${k}] = 0;`); // first element
}
}
return [
`${idxZero.join('\n')}`, `var value = ${input.getByOffset('inputOffset')};\nvar bestIndex : i32 = 0;`,
`if (${input.getByOffset('inputOffset')} ${attributes.selectLastIndex > 0 ? '>=' : '>'} value) {
value = ${input.getByOffset('inputOffset')};
bestIndex = i32(lastIndex);
`${idxZero.join('\n')}`, `var value = ${input.getByIndices('input_indices')};\nvar best_index : i32 = 0;`,
`if (${input.getByIndices('input_indices')} ${attributes.selectLastIndex > 0 ? '>=' : '>'} value) {
value = ${input.getByIndices('input_indices')};
best_index = i32(last_index);
}`,
'', output.setByOffset('global_idx', 'bestIndex')
'', output.setByOffset('global_idx', 'best_index')
];
};

context.compute(
createReduceProgramInfo(
'argMax', {hint: attributes.cacheKey}, [context.inputs[0]], argMinMaxOp, [attributes.axis], DataType.int64,
attributes.keepDims),
'argMax', {hint: attributes.cacheKey, inputDependencies: ['rank']}, [context.inputs[0]], argMinMaxOp,
[attributes.axis], DataType.int64, attributes.keepDims),
{inputs: [0]});
};

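
These argmin/argmax snippets plug into the same `ReduceOp` contract as the reduce kernels: a function returning WGSL fragments for (0) initialising the reduced indices, (1) initialising the accumulator, (2) the per-element accumulation emitted inside the generated loops, (3) finalisation, and optionally (4+) a custom output write. The renames to `input_indices`, `best_index` and `last_index` keep the snippets in sync with the snake_case identifiers the uniform-based reduce shader template now emits. Below is a hedged sketch of an invented op written against that contract; it assumes `ReduceOp` stays exported from `./reduce`, as it is in this diff.

```typescript
// Hypothetical ReduceOp computing max(|x|) over the reduced axes; the op is invented for
// illustration and relies on the template's default write of `value` to the output.
import {ReduceOp} from './reduce';

export const reduceAbsMaxOp: ReduceOp = (input, output) => [
  '',                                                                  // (0) no extra index initialisation needed
  `var value = ${output.type.storage}(0);`,                            // (1) abs() is never negative, so 0 is a safe start
  `value = max(value, abs(${input.getByIndices('input_indices')}));`,  // (2) runs once per reduced element
  '',                                                                  // (3) nothing to finalise
  // no fifth element, so the shader template emits output.setByOffset('global_idx', 'value')
];
```
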
4 changes: 2 additions & 2 deletions js/web/lib/wasm/jsep/webgpu/ops/cumsum.ts
@@ -7,7 +7,7 @@ import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, ProgramInfo} from '../types';

import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from './common';
import {createTensorShapeVariables, getElementAt, inputVariable, outputVariable, ShaderHelper} from './common';


export interface CumSumAttributes extends AttributeWithCacheKey {
@@ -26,7 +26,7 @@ const createCumsumProgramInfo =
const axis = ShapeUtil.normalizeAxis(axisValue, rank);
const getShaderSource = (shaderHelper: ShaderHelper) => {
const index = ` i32(${input.indicesGet('inputIndices', 'uniforms.axis')}) `;
const max = rank === 1 ? 'i32(uniforms.input_shape)' : 'i32(uniforms.input_shape[uniforms.axis])';
const max = getElementAt('uniforms.input_shape', 'uniforms.axis', rank);
const lowerLimit = attributes.reverse ? index + (attributes.exclusive ? ' + 1' : '') : '0';
const upperLimit = attributes.reverse ? max : index + (attributes.exclusive ? '' : ' + 1');
return `
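
`getElementAt` (imported from `./common` in the new code) centralises the rank-1 special case that the removed line handled inline: a rank-1 shape uniform is a scalar and cannot be indexed. Inferred from the line it replaces, the helper presumably behaves roughly as sketched below; this is not the actual implementation, which may also handle cases such as uniform arrays packed into vec4s.

```typescript
// Sketch only: approximates the behaviour implied by the cumsum.ts line this commit replaces.
const getElementAtSketch = (name: string, index: string|number, length: number): string =>
    length === 1 ? name : `${name}[${index}]`;

// getElementAtSketch('uniforms.input_shape', 'uniforms.axis', rank)
//   rank === 1 -> 'uniforms.input_shape'
//   rank  >  1 -> 'uniforms.input_shape[uniforms.axis]'
```
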
114 changes: 56 additions & 58 deletions js/web/lib/wasm/jsep/webgpu/ops/reduce.ts
@@ -7,7 +7,7 @@ import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, ProgramInfo, ProgramShaderCacheInfo} from '../types';

import {IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common';
import {createTensorShapeVariables, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common';
import {reduceL1Shared, reduceL2Shared, reduceLogSumExpShared, reduceLogSumShared, reduceMaxShared, reduceMeanShared, reduceMinShared, reduceProdShared, reduceSumShared, reduceSumSquareShared} from './reduce-shared';

const validateInputs = (inputs: readonly TensorView[]): void => {
@@ -30,14 +30,14 @@ export type ReduceOp =
(input: IndicesHelper, output: IndicesHelper,
axes: readonly number[]) => [string, string, string, string, ...string[]];

const noOp: ReduceOp = (input) => ['', '', `var value = ${input.getByOffset('inputOffset')};`, ''];
const noOp: ReduceOp = (input) => ['', '', `var value = ${input.getByIndices('input_indices')};`, ''];
export const createReduceProgramInfo =
(name: string, shaderCache: ProgramShaderCacheInfo, inputs: readonly TensorView[], reduceOp: ReduceOp,
axesInput: number[], outputDataType: DataType, keepDims = false, noopWithEmptyAxes = false): ProgramInfo => {
const outputShape: number[] = [];
const inputShape = inputs[0].dims;

const axes = ShapeUtil.normalizeAxes(axesInput, inputs[0].dims.length);
const inputRank = inputShape.length;
const axes = ShapeUtil.normalizeAxes(axesInput, inputRank);
const reduceOnAllAxes = !noopWithEmptyAxes && axes.length === 0;
inputShape.forEach((d, i) => {
if (reduceOnAllAxes || axes.indexOf(i) >= 0) {
@@ -48,61 +48,62 @@ export const createReduceProgramInfo =
outputShape.push(d);
}
});

const idxCopy: string[] = []; // copy output indexes to input indexes

const input = inputVariable('_A', inputs[0].dataType, inputShape);
const output = outputVariable('output', outputDataType, outputShape);
const ops = reduceOp(input, output, axes);
const inputOffsetAssignment = `inputOffset = ${input.indicesToOffset('inputIndices')};`;
const initinputOffsetLet = `let ${inputOffsetAssignment};`;
const initinputOffsetVar = `var ${inputOffsetAssignment};`;
const initinputOffset = (ops[1] === '') ? '' : initinputOffsetVar;
let reduceOps = ((ops[1] === '') ? initinputOffsetLet : inputOffsetAssignment) + '\n' + ops[2];

for (let k = 0, l = 0; k < inputs[0].dims.length; k++) {
// if this axis is reduced
if (reduceOnAllAxes || axes.indexOf(k) >= 0) {
if (keepDims) {
const outputRank = outputShape.length;
const outputSize = ShapeUtil.size(outputShape);
const getShaderSource = (shaderHelper: ShaderHelper) => {
const idxCopy: string[] = []; // copy output indexes to input indexes

const input = inputVariable('_A', inputs[0].dataType, inputRank);
const output = outputVariable('output', outputDataType, outputRank);
const ops = reduceOp(input, output, axes);
let reduceOps = ops[2];

for (let k = 0, l = 0; k < inputRank; k++) {
// if this axis is reduced
if (reduceOnAllAxes || axes.indexOf(k) >= 0) {
if (keepDims) {
l++;
}
// loop over the d-th axis
reduceOps = `for(var j${k}: u32 = 0; j${k} < ${inputShape[k]}; j${k}++) {
${ops[2].includes('last_index') ? `let last_index = j${k};` : ''}
${input.indicesSet('input_indices', k, `j${k}`)}
${reduceOps}
}`;
} else {
idxCopy.push(`${input.indicesSet('input_indices', k, output.indicesGet('output_indices', l))};`);
l++;
}
// loop over the d-th axis
reduceOps = `for(var j${k}: u32 = 0; j${k} < ${inputs[0].dims[k]}; j${k}++) {
${ops[2].includes('lastIndex') ? `let lastIndex = j${k};` : ''}
${input.indicesSet('inputIndices', k, `j${k}`)}
${reduceOps}
}`;
} else {
idxCopy.push(`${input.indicesSet('inputIndices', k, output.indicesGet('outputIndices', l))};`);
l++;
}
}
return `
const outputSize = ShapeUtil.size(outputShape);
const getShaderSource = (shaderHelper: ShaderHelper) => `
${shaderHelper.declareVariables(input, output)}
${shaderHelper.registerUniform('output_size', 'u32').declareVariables(input, output)}
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
var inputIndices: ${input.type.indices};
let outputIndices = ${output.offsetToIndices('global_idx')};
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
var input_indices: ${input.type.indices};
let output_indices = ${output.offsetToIndices('global_idx')};
${idxCopy.join('\n')}
${ops[0]} // init ops for reduce max/min
${initinputOffset}
${ops[1]}
${reduceOps}
${ops[3]}
${ops.length === 4 ? output.setByOffset('global_idx', 'value') : ops.slice(4).join('\n')}
}`;
};

return {
name,
shaderCache,
getShaderSource,
getRunData: () => ({
outputs: [{dims: outputShape, dataType: outputDataType}],
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
programUniforms: [
{type: 'uint32', data: outputSize}, ...createTensorShapeVariables(inputShape),
...createTensorShapeVariables(outputShape)
]
}),
};
};
@@ -125,7 +126,7 @@ const runReduceProgram =

context.compute(
createReduceProgramInfo(
name, {hint: updatedAttributes.cacheKey}, [inputs[0]],
name, {hint: updatedAttributes.cacheKey, inputDependencies: ['rank']}, [inputs[0]],
updatedAttributes.noopWithEmptyAxes && updatedAttributes.axes.length === 0 ? noOp : reduceOp,
updatedAttributes.axes, inputs[0].dataType, updatedAttributes.keepDims,
updatedAttributes.noopWithEmptyAxes),
@@ -137,7 +138,7 @@ const reduceLogSumNaive = (context: ComputeContext, attributes: ReduceAttributes
const reduceOp: ReduceOp = (input, output) =>
[`var value = ${output.type.storage}(0);`,
'',
`value += ${input.getByOffset('inputOffset')};`,
`value += ${input.getByIndices('input_indices')};`,
'value = log(value);',
];
runReduceProgram(context, 'ReduceLogSum', attributes, reduceOp);
@@ -148,7 +149,7 @@ const reduceL1Naive = (context: ComputeContext, attributes: ReduceAttributes): v
const reduceOp: ReduceOp = (input, output) =>
[`var value = ${output.type.storage}(0);`,
'',
`value += abs(${input.getByOffset('inputOffset')});`,
`value += abs(${input.getByIndices('input_indices')});`,
'',
];
runReduceProgram(context, 'ReduceL1', attributes, reduceOp);
@@ -159,7 +160,7 @@ const reduceL2Naive = (context: ComputeContext, attributes: ReduceAttributes): v
const reduceOp: ReduceOp = (input, output) =>
[`var t = ${output.type.value}(0); var value = ${output.type.value}(0);`,
'',
`t = ${input.getByOffset('inputOffset')}; value += (t * t);`,
`t = ${input.getByIndices('input_indices')}; value += (t * t);`,
'value = sqrt(value);',
];
runReduceProgram(context, 'ReduceL2', attributes, reduceOp);
@@ -170,7 +171,7 @@ const reduceLogSumExpNaive = (context: ComputeContext, attributes: ReduceAttribu
const reduceOp: ReduceOp = (input, output) =>
[`var value = ${output.type.storage}(0);`,
'',
`value += exp(${input.getByOffset('inputOffset')});`,
`value += exp(${input.getByIndices('input_indices')});`,
'value = log(value);',
];
runReduceProgram(context, 'ReduceLogSumExp', attributes, reduceOp);
@@ -182,14 +183,14 @@ const reduceMaxNaive = (context: ComputeContext, attributes: ReduceAttributes):
const idxZero = [];
for (let k = 0; k < input.rank; k++) {
if (axes.indexOf(k) >= 0 || axes.length === 0) {
idxZero.push(input.indicesSet('inputIndices', k, 0));
idxZero.push(input.indicesSet('input_indices', k, 0));
}
}

return [
`${idxZero.join('\n')}`,
`var value = ${input.getByOffset('inputOffset')};`,
`value = max(value, ${input.getByOffset('inputOffset')});`,
`var value = ${input.getByIndices('input_indices')};`,
`value = max(value, ${input.getByIndices('input_indices')});`,
'',
];
};
@@ -210,7 +211,7 @@ const reduceMeanNaive = (context: ComputeContext, attributes: ReduceAttributes):
return [
'var sum = f32(0);',
'',
`sum += f32(${input.getByOffset('inputOffset')});`,
`sum += f32(${input.getByIndices('input_indices')});`,
`let value = ${output.type.value}(sum / ${size});`,
];
};
@@ -223,14 +224,14 @@ const reduceMinNaive = (context: ComputeContext, attributes: ReduceAttributes):
const idxZero = [];
for (let k = 0; k < input.rank; k++) {
if (axes.indexOf(k) >= 0 || axes.length === 0) {
idxZero.push(`inputIndices[${k}] = 0;`); // first element
idxZero.push(`input_indices[${k}] = 0;`); // first element
}
}

return [
`${idxZero.join('\n')}`,
`var value = ${input.getByOffset('inputOffset')};`,
`value = min(value, ${input.getByOffset('inputOffset')});`,
`var value = ${input.getByIndices('input_indices')};`,
`value = min(value, ${input.getByIndices('input_indices')});`,
'',
];
};
@@ -242,7 +243,7 @@ const reduceProdNaive = (context: ComputeContext, attributes: ReduceAttributes):
const reduceOp: ReduceOp = (input, output) =>
[`var value = ${output.type.storage}(1);`,
'',
`value *= ${input.getByOffset('inputOffset')};`,
`value *= ${input.getByIndices('input_indices')};`,
'',
];
runReduceProgram(context, 'ReduceProd', attributes, reduceOp);
@@ -253,7 +254,7 @@ const reduceSumNaive = (context: ComputeContext, attributes: ReduceAttributes):
const reduceOp: ReduceOp = (input, output) =>
[`var value = ${output.type.storage}(0);`,
'',
`value += ${input.getByOffset('inputOffset')};`,
`value += ${input.getByIndices('input_indices')};`,
'',
];
runReduceProgram(context, 'ReduceSum', attributes, reduceOp);
@@ -264,7 +265,7 @@ const reduceSumSquareNaive = (context: ComputeContext, attributes: ReduceAttribu
const reduceOp: ReduceOp = (input, output) =>
[`var t = ${output.type.value}(0); var value = ${output.type.value}(0);`,
'',
`t = ${input.getByOffset('inputOffset')}; value += t * t;`,
`t = ${input.getByIndices('input_indices')}; value += t * t;`,
'',
];
runReduceProgram(context, 'ReduceSumSquare', attributes, reduceOp);
@@ -273,7 +274,7 @@ const reduceSumSquareNaive = (context: ComputeContext, attributes: ReduceAttribu
const useNaiveReduceMethod =
(shape: readonly number[], axes: readonly number[], noopWithEmptyAxes: boolean): boolean => {
if (axes.length === 0) {
return noopWithEmptyAxes ? true : false;
return noopWithEmptyAxes;
}

let outputSize = 1;
@@ -289,7 +290,7 @@ const useNaiveReduceMethod =
// The condition data is very rough, although considering the count of Execution Unit (EU), the potential
// work groups in a EU and the counts of loops in the naive and shared methods, also doing experiments
// on some machines.
return reduceSize < 32 && outputSize > 1024 ? true : false;
return reduceSize < 32 && outputSize > 1024;
};

export const reduceMean = (context: ComputeContext, attributes: ReduceAttributes): void => {
@@ -371,6 +372,3 @@ export const reduceLogSum = (context: ComputeContext, attributes: ReduceAttribut
reduceLogSumShared(context, attributes);
}
};

export const parseReduceAttributes = (attributes: Record<string, unknown>): ReduceAttributes =>
createAttributeWithCacheKey(attributes as Omit<ReduceAttributes, keyof AttributeWithCacheKey>);
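
The `useNaiveReduceMethod` heuristic earlier in this file keeps the naive one-thread-per-output kernel only when each output element reduces over few inputs but there are many output elements; otherwise the shared (workgroup) implementations from `reduce-shared.ts` are used. Below is a small self-contained check of that boundary, assuming (as the surrounding code suggests) that `outputSize` is the product of the kept dimensions and `reduceSize` the product of the reduced ones; the helper is a re-implementation for illustration, not the module's private function.

```typescript
// Illustrative re-implementation of the naive-vs-shared decision, applied to made-up shapes.
const wouldUseNaiveReduce = (shape: readonly number[], axes: readonly number[]): boolean => {
  let outputSize = 1;
  let reduceSize = 1;
  shape.forEach((d, i) => {
    if (axes.indexOf(i) >= 0) {
      reduceSize *= d;   // this dimension is reduced away
    } else {
      outputSize *= d;   // this dimension survives into the output
    }
  });
  return reduceSize < 32 && outputSize > 1024;
};

wouldUseNaiveReduce([2048, 16], [1]);  // true:  16 < 32 and 2048 > 1024 -> naive per-output loop
wouldUseNaiveReduce([16, 2048], [1]);  // false: reduceSize is large      -> shared workgroup path
```
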
(Diffs for the remaining 3 changed files are not shown here.)
