From 8d48d3e9cce4a7b5640051899f824c69a6a67c62 Mon Sep 17 00:00:00 2001 From: xhcao Date: Fri, 3 Nov 2023 03:51:48 +0800 Subject: [PATCH] [js/web] optimize reduce related operators (#17957) ### Description ### Motivation and Context --- .../lib/wasm/jsep/webgpu/ops/reduce-shared.ts | 266 ++++++++++++++++++ js/web/lib/wasm/jsep/webgpu/ops/reduce.ts | 125 +++++++- 2 files changed, 380 insertions(+), 11 deletions(-) create mode 100644 js/web/lib/wasm/jsep/webgpu/ops/reduce-shared.ts diff --git a/js/web/lib/wasm/jsep/webgpu/ops/reduce-shared.ts b/js/web/lib/wasm/jsep/webgpu/ops/reduce-shared.ts new file mode 100644 index 0000000000000..54a7414360c4e --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/reduce-shared.ts @@ -0,0 +1,266 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {DataType} from '../../../wasm-common'; +import {TensorView} from '../../tensor-view'; +import {ShapeUtil} from '../../util'; +import {ComputeContext, ProgramInfo, ProgramShaderCacheInfo} from '../types'; + +import {inputVariable, outputVariable, ShaderHelper} from './common'; +import {createReduceAttributesFromInputs, ReduceAttributes} from './reduce'; +import {createTransposeProgramInfo} from './transpose'; + +const reduceOps: {[key: string]: string} = { + max: 'select(bestValue, candidate, candidate > bestValue)', + min: 'select(bestValue, candidate, candidate < bestValue)', + mean: 'bestValue + candidate', + sum: 'bestValue + candidate', + prod: 'bestValue * candidate', + sumSquare: 'bestValue + candidate * candidate', + logSumExp: 'bestValue + exp(candidate)', + l1: 'bestValue + abs(candidate)', + l2: 'bestValue + candidate * candidate', + logSum: 'bestValue + candidate' +}; + +const reduceSharedOps: {[key: string]: string} = { + max: 'select(bestValue, candidate, candidate > bestValue)', + min: 'select(bestValue, candidate, candidate < bestValue)', + mean: 'bestValue + candidate', + sum: 'bestValue + candidate', + prod: 'bestValue * candidate', + sumSquare: 'bestValue + candidate', + logSumExp: 'bestValue + candidate', + l1: 'bestValue + candidate', + l2: 'bestValue + candidate', + logSum: 'bestValue + candidate' +}; + +const reduceInitValues: {[key: string]: string} = { + max: '_A[offset]', + min: '_A[offset]', + mean: '0', + sum: '0', + prod: '1', + sumSquare: '0', + logSumExp: '0', + l1: '0', + l2: '0', + logSum: '0' +}; + +const reduceOutputValues: {[key: string]: string} = { + max: 'bestValue', + min: 'bestValue', + sum: 'bestValue', + prod: 'bestValue', + sumSquare: 'bestValue', + logSumExp: 'log(bestValue)', + l1: 'bestValue', + l2: 'sqrt(bestValue)', + logSum: 'log(bestValue)' +}; + +const getInnerMostAxes = (numInnerAxes: number, rank: number): number[] => { + const res = []; + for (let i = rank - numInnerAxes; i < rank; ++i) { + res.push(i); + } + return res; +}; + +const computeOutAndReduceShapes = (shape: readonly number[], axes: readonly number[]): [number[], number[]] => { + const outputShape = []; + const rank = shape.length; + for (let dim = 0; dim < rank; dim++) { + if (axes.indexOf(dim) === -1) { + outputShape.push(shape[dim]); + } + } + const reduceShape = axes.map(dim => shape[dim]); + return [outputShape, reduceShape]; +}; + +const expandShapeToKeepDim = (shape: number[], axes: number[]): number[] => { + const rank = shape.length + axes.length; + const expandShape = []; + let shapeIdx = 0; + for (let dim = 0; dim < rank; dim++) { + if (axes.indexOf(dim) === -1) { + expandShape.push(shape[shapeIdx++]); + } else { + expandShape.push(1); + } + } + return expandShape; +}; + +const areAxesInnerMostDims = (axes: number[], rank: number): boolean => { + for (let i = 0; i < axes.length; ++i) { + if (axes[axes.length - i - 1] !== rank - 1 - i) { + return false; + } + } + return true; +}; + +const getAxesPermutation = (axes: number[], rank: number): number[] => { + const res = []; + if (!areAxesInnerMostDims(axes, rank)) { + for (let i = 0; i < rank; ++i) { + if (axes.indexOf(i) === -1) { + res.push(i); + } + } + axes.forEach(axis => res.push(axis)); + } + return res; +}; + +export const createReduceSharedProgramInfo = + (name: string, shaderCache: ProgramShaderCacheInfo, inputs: readonly TensorView[], reduceType: string, + outputDataType: DataType, outputShape: number[], reduceShape: number[]): ProgramInfo => { + const inputShape = inputs[0].dims; + + const outputSize = ShapeUtil.size(outputShape); + const reduceSize = ShapeUtil.size(reduceShape); + + const input = inputVariable('_A', inputs[0].dataType, inputShape); + const output = outputVariable('output', outputDataType, outputShape); + + const workgroupSize = 32; + + const sharedMemorySnippet = ` + var aBestValues : array<${output.type.storage}, ${workgroupSize}>; + `; + + const getShaderSource = (shaderHelper: ShaderHelper) => ` + ${shaderHelper.registerUniform('reduceSize', 'u32').declareVariables(input, output)} + ${sharedMemorySnippet} + fn DIV_CEIL(a : u32, b : u32) -> u32 { + return ((a - 1u) / b + 1u); + } + ${shaderHelper.mainStart(workgroupSize)} + let local_idx = local_id.x; + + let outputIndex = global_idx / ${workgroupSize}; + let offset = outputIndex * uniforms.reduceSize; + + var bestValue = ${output.type.storage}(${reduceInitValues[reduceType]}); + let Length = uniforms.reduceSize; + for (var k = local_idx; k < Length; k = k + ${workgroupSize}) { + let candidate = ${output.type.storage}(${input.getByOffset('offset + k')}); + bestValue = ${reduceOps[reduceType]}; + } + aBestValues[local_idx] = bestValue; + workgroupBarrier(); + + var reduceSize = min(Length, ${workgroupSize}u); + for (var currentSize = reduceSize / 2u; reduceSize > 1u; + currentSize = reduceSize / 2u) { + let interval = DIV_CEIL(reduceSize, 2u); + if (local_idx < currentSize) { + let candidate = aBestValues[local_idx + interval]; + bestValue = ${reduceSharedOps[reduceType]}; + aBestValues[local_idx] = bestValue; + } + reduceSize = interval; + workgroupBarrier(); + } + + if (local_idx == 0u) { + ${ + output.setByOffset( + 'outputIndex', + `${ + reduceType === 'mean' ? `bestValue / ${output.type.storage}(uniforms.reduceSize)` : + `${reduceOutputValues[reduceType]}`}`)}; + } + }`; + + // One work group is responsible for only one element of output. + return { + name, + shaderCache, + getShaderSource, + getRunData: () => ({ + outputs: [{dims: outputShape, dataType: outputDataType}], + dispatchGroup: {x: outputSize}, + programUniforms: [{type: 'uint32', data: reduceSize}] + }), + }; + }; + +const reduceCommon = + (context: ComputeContext, name: string, attributes: ReduceAttributes, + reduceType: 'sum'|'sumSquare'|'prod'|'min'|'max'|'mean'|'logSumExp'|'l1'|'l2'|'logSum'): void => { + const updatedAttributes: ReduceAttributes = + context.inputs.length === 1 ? attributes : createReduceAttributesFromInputs(context.inputs, attributes); + + let updatedAxes = updatedAttributes.axes; + if (updatedAxes.length === 0 && !updatedAttributes.noopWithEmptyAxes) { + updatedAxes = context.inputs[0].dims.map((s, i) => i); + } + const normalizeAxes = ShapeUtil.normalizeAxes(updatedAxes, context.inputs[0].dims.length); + + let axes = normalizeAxes; + let input = context.inputs[0]; + const permutedAxes = getAxesPermutation(axes, context.inputs[0].dims.length); + if (permutedAxes.length > 0) { + input = context.compute( + createTransposeProgramInfo(context.inputs[0], permutedAxes), {inputs: [0], outputs: [-1]})[0]; + axes = getInnerMostAxes(axes.length, input.dims.length); + } + + const [outputShape, reduceShape] = computeOutAndReduceShapes(input.dims, axes); + let finalOutputShape = outputShape; + if (updatedAttributes.keepDims) { + finalOutputShape = expandShapeToKeepDim(outputShape, normalizeAxes); + } + + context.compute( + createReduceSharedProgramInfo( + name, {hint: updatedAttributes.cacheKey, inputDependencies: ['type']}, [input], reduceType, + context.inputs[0].dataType, finalOutputShape, reduceShape), + {inputs: [input]}); + }; + +export const reduceMeanShared = (context: ComputeContext, attributes: ReduceAttributes): void => { + reduceCommon(context, 'ReduceMeanShared', attributes, 'mean'); +}; + +export const reduceL1Shared = (context: ComputeContext, attributes: ReduceAttributes): void => { + reduceCommon(context, 'ReduceL1Shared', attributes, 'l1'); +}; + +export const reduceL2Shared = (context: ComputeContext, attributes: ReduceAttributes): void => { + reduceCommon(context, 'ReduceL2Shared', attributes, 'l2'); +}; + +export const reduceLogSumExpShared = (context: ComputeContext, attributes: ReduceAttributes): void => { + reduceCommon(context, 'ReduceLogSumExpShared', attributes, 'logSumExp'); +}; + +export const reduceMaxShared = (context: ComputeContext, attributes: ReduceAttributes): void => { + reduceCommon(context, 'ReduceMaxShared', attributes, 'max'); +}; + +export const reduceMinShared = (context: ComputeContext, attributes: ReduceAttributes): void => { + reduceCommon(context, 'ReduceMinShared', attributes, 'min'); +}; + +export const reduceProdShared = (context: ComputeContext, attributes: ReduceAttributes): void => { + reduceCommon(context, 'ReduceProdShared', attributes, 'prod'); +}; + +export const reduceSumShared = (context: ComputeContext, attributes: ReduceAttributes): void => { + reduceCommon(context, 'ReduceSumShared', attributes, 'sum'); +}; + +export const reduceSumSquareShared = (context: ComputeContext, attributes: ReduceAttributes): void => { + reduceCommon(context, 'ReduceSumSquareShared', attributes, 'sumSquare'); +}; + +export const reduceLogSumShared = (context: ComputeContext, attributes: ReduceAttributes): void => { + reduceCommon(context, 'ReduceLogSumShared', attributes, 'logSum'); +}; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts b/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts index 44d6332852d2a..b5c956e57a9b1 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts @@ -8,6 +8,7 @@ import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-w import {ComputeContext, ProgramInfo, ProgramShaderCacheInfo} from '../types'; import {IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; +import {reduceL1Shared, reduceL2Shared, reduceLogSumExpShared, reduceLogSumShared, reduceMaxShared, reduceMeanShared, reduceMinShared, reduceProdShared, reduceSumShared, reduceSumSquareShared} from './reduce-shared'; const validateInputs = (inputs: readonly TensorView[]): void => { if (!inputs || inputs.length === 0 || inputs.length > 2) { @@ -106,7 +107,7 @@ export const createReduceProgramInfo = }; }; -const createReduceAttributesFromInputs = +export const createReduceAttributesFromInputs = (inputs: readonly TensorView[], attributes: ReduceAttributes): ReduceAttributes => { const axes: number[] = []; if (inputs[1].dims[0] > 0) { @@ -131,7 +132,7 @@ const runReduceProgram = {inputs: [0]}); }; -export const reduceLogSum = (context: ComputeContext, attributes: ReduceAttributes): void => { +const reduceLogSumNaive = (context: ComputeContext, attributes: ReduceAttributes): void => { validateInputs(context.inputs); const reduceOp: ReduceOp = (input, output) => [`var value = ${output.type.storage}(0);`, @@ -142,7 +143,7 @@ export const reduceLogSum = (context: ComputeContext, attributes: ReduceAttribut runReduceProgram(context, 'ReduceLogSum', attributes, reduceOp); }; -export const reduceL1 = (context: ComputeContext, attributes: ReduceAttributes): void => { +const reduceL1Naive = (context: ComputeContext, attributes: ReduceAttributes): void => { validateInputs(context.inputs); const reduceOp: ReduceOp = (input, output) => [`var value = ${output.type.storage}(0);`, @@ -153,7 +154,7 @@ export const reduceL1 = (context: ComputeContext, attributes: ReduceAttributes): runReduceProgram(context, 'ReduceL1', attributes, reduceOp); }; -export const reduceL2 = (context: ComputeContext, attributes: ReduceAttributes): void => { +const reduceL2Naive = (context: ComputeContext, attributes: ReduceAttributes): void => { validateInputs(context.inputs); const reduceOp: ReduceOp = (input, output) => [`var t = ${output.type.value}(0); var value = ${output.type.value}(0);`, @@ -164,7 +165,7 @@ export const reduceL2 = (context: ComputeContext, attributes: ReduceAttributes): runReduceProgram(context, 'ReduceL2', attributes, reduceOp); }; -export const reduceLogSumExp = (context: ComputeContext, attributes: ReduceAttributes): void => { +const reduceLogSumExpNaive = (context: ComputeContext, attributes: ReduceAttributes): void => { validateInputs(context.inputs); const reduceOp: ReduceOp = (input, output) => [`var value = ${output.type.storage}(0);`, @@ -175,7 +176,7 @@ export const reduceLogSumExp = (context: ComputeContext, attributes: ReduceAttri runReduceProgram(context, 'ReduceLogSumExp', attributes, reduceOp); }; -export const reduceMax = (context: ComputeContext, attributes: ReduceAttributes): void => { +const reduceMaxNaive = (context: ComputeContext, attributes: ReduceAttributes): void => { validateInputs(context.inputs); const reduceOp: ReduceOp = (input, _output, axes) => { const idxZero = []; @@ -195,7 +196,7 @@ export const reduceMax = (context: ComputeContext, attributes: ReduceAttributes) runReduceProgram(context, 'ReduceMax', attributes, reduceOp); }; -export const reduceMean = (context: ComputeContext, attributes: ReduceAttributes): void => { +const reduceMeanNaive = (context: ComputeContext, attributes: ReduceAttributes): void => { validateInputs(context.inputs); const reduceOp: ReduceOp = (input, output, axes) => { let size = 1.0; @@ -216,7 +217,7 @@ export const reduceMean = (context: ComputeContext, attributes: ReduceAttributes runReduceProgram(context, 'ReduceMean', attributes, reduceOp); }; -export const reduceMin = (context: ComputeContext, attributes: ReduceAttributes): void => { +const reduceMinNaive = (context: ComputeContext, attributes: ReduceAttributes): void => { validateInputs(context.inputs); const reduceOp: ReduceOp = (input, _output, axes) => { const idxZero = []; @@ -236,7 +237,7 @@ export const reduceMin = (context: ComputeContext, attributes: ReduceAttributes) runReduceProgram(context, 'ReduceMin', attributes, reduceOp); }; -export const reduceProd = (context: ComputeContext, attributes: ReduceAttributes): void => { +const reduceProdNaive = (context: ComputeContext, attributes: ReduceAttributes): void => { validateInputs(context.inputs); const reduceOp: ReduceOp = (input, output) => [`var value = ${output.type.storage}(1);`, @@ -247,7 +248,7 @@ export const reduceProd = (context: ComputeContext, attributes: ReduceAttributes runReduceProgram(context, 'ReduceProd', attributes, reduceOp); }; -export const reduceSum = (context: ComputeContext, attributes: ReduceAttributes): void => { +const reduceSumNaive = (context: ComputeContext, attributes: ReduceAttributes): void => { validateInputs(context.inputs); const reduceOp: ReduceOp = (input, output) => [`var value = ${output.type.storage}(0);`, @@ -258,7 +259,7 @@ export const reduceSum = (context: ComputeContext, attributes: ReduceAttributes) runReduceProgram(context, 'ReduceSum', attributes, reduceOp); }; -export const reduceSumSquare = (context: ComputeContext, attributes: ReduceAttributes): void => { +const reduceSumSquareNaive = (context: ComputeContext, attributes: ReduceAttributes): void => { validateInputs(context.inputs); const reduceOp: ReduceOp = (input, output) => [`var t = ${output.type.value}(0); var value = ${output.type.value}(0);`, @@ -269,5 +270,107 @@ export const reduceSumSquare = (context: ComputeContext, attributes: ReduceAttri runReduceProgram(context, 'ReduceSumSquare', attributes, reduceOp); }; +const useNaiveReduceMethod = + (shape: readonly number[], axes: readonly number[], noopWithEmptyAxes: boolean): boolean => { + if (axes.length === 0) { + return noopWithEmptyAxes ? true : false; + } + + let outputSize = 1; + let reduceSize = 1; + for (let dim = 0; dim < axes.length; dim++) { + if (axes.indexOf(dim) === -1) { + outputSize *= shape[dim]; + } else { + reduceSize *= shape[dim]; + } + } + + // The condition data is very rough, although considering the count of Execution Unit (EU), the potential + // work groups in a EU and the counts of loops in the naive and shared methods, also doing experiments + // on some machines. + return reduceSize < 32 && outputSize > 1024 ? true : false; + }; + +export const reduceMean = (context: ComputeContext, attributes: ReduceAttributes): void => { + if (useNaiveReduceMethod(context.inputs[0].dims, attributes.axes, attributes.noopWithEmptyAxes)) { + reduceMeanNaive(context, attributes); + } else { + reduceMeanShared(context, attributes); + } +}; + +export const reduceL1 = (context: ComputeContext, attributes: ReduceAttributes): void => { + if (useNaiveReduceMethod(context.inputs[0].dims, attributes.axes, attributes.noopWithEmptyAxes)) { + reduceL1Naive(context, attributes); + } else { + reduceL1Shared(context, attributes); + } +}; + +export const reduceL2 = (context: ComputeContext, attributes: ReduceAttributes): void => { + if (useNaiveReduceMethod(context.inputs[0].dims, attributes.axes, attributes.noopWithEmptyAxes)) { + reduceL2Naive(context, attributes); + } else { + reduceL2Shared(context, attributes); + } +}; + +export const reduceLogSumExp = (context: ComputeContext, attributes: ReduceAttributes): void => { + if (useNaiveReduceMethod(context.inputs[0].dims, attributes.axes, attributes.noopWithEmptyAxes)) { + reduceLogSumExpNaive(context, attributes); + } else { + reduceLogSumExpShared(context, attributes); + } +}; + +export const reduceMax = (context: ComputeContext, attributes: ReduceAttributes): void => { + if (useNaiveReduceMethod(context.inputs[0].dims, attributes.axes, attributes.noopWithEmptyAxes)) { + reduceMaxNaive(context, attributes); + } else { + reduceMaxShared(context, attributes); + } +}; + +export const reduceMin = (context: ComputeContext, attributes: ReduceAttributes): void => { + if (useNaiveReduceMethod(context.inputs[0].dims, attributes.axes, attributes.noopWithEmptyAxes)) { + reduceMinNaive(context, attributes); + } else { + reduceMinShared(context, attributes); + } +}; + +export const reduceProd = (context: ComputeContext, attributes: ReduceAttributes): void => { + if (useNaiveReduceMethod(context.inputs[0].dims, attributes.axes, attributes.noopWithEmptyAxes)) { + reduceProdNaive(context, attributes); + } else { + reduceProdShared(context, attributes); + } +}; + +export const reduceSum = (context: ComputeContext, attributes: ReduceAttributes): void => { + if (useNaiveReduceMethod(context.inputs[0].dims, attributes.axes, attributes.noopWithEmptyAxes)) { + reduceSumNaive(context, attributes); + } else { + reduceSumShared(context, attributes); + } +}; + +export const reduceSumSquare = (context: ComputeContext, attributes: ReduceAttributes): void => { + if (useNaiveReduceMethod(context.inputs[0].dims, attributes.axes, attributes.noopWithEmptyAxes)) { + reduceSumSquareNaive(context, attributes); + } else { + reduceSumSquareShared(context, attributes); + } +}; + +export const reduceLogSum = (context: ComputeContext, attributes: ReduceAttributes): void => { + if (useNaiveReduceMethod(context.inputs[0].dims, attributes.axes, attributes.noopWithEmptyAxes)) { + reduceLogSumNaive(context, attributes); + } else { + reduceLogSumShared(context, attributes); + } +}; + export const parseReduceAttributes = (attributes: Record): ReduceAttributes => createAttributeWithCacheKey(attributes as Omit);