Skip to content

Commit

Permalink
[js/webgpu] Increase workgroupSize if only one workgroup is dispached (
Browse files Browse the repository at this point in the history
…#22709)

#22031

For reduce related ops, we should increase workgroupSize to improve
parallelism if only one workgroup is dispatched.

The total ReduceMean time becomes 8.98 ms from 77.79 ms on my iGPUs.
  • Loading branch information
qjia7 authored and guschmue committed Dec 2, 2024
1 parent d1e442e commit b906b30
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 9 deletions.
15 changes: 10 additions & 5 deletions js/web/lib/wasm/jsep/webgpu/ops/reduce-shared.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import { DataType } from '../../../wasm-common';
import { TensorView } from '../../tensor-view';
import { ShapeUtil } from '../../util';
import { ComputeContext, ProgramInfo, ProgramShaderCacheInfo } from '../types';
import { ComputeContext, ProgramInfo } from '../types';

import { inputVariable, outputVariable, ShaderHelper } from './common';
import { createReduceAttributesFromInputs, ReduceAttributes } from './reduce';
Expand Down Expand Up @@ -119,7 +119,7 @@ const getAxesPermutation = (axes: number[], rank: number): number[] => {

export const createReduceSharedProgramInfo = (
name: string,
shaderCache: ProgramShaderCacheInfo,
cacheKey: string,
inputs: readonly TensorView[],
reduceType: string,
outputDataType: DataType,
Expand All @@ -134,7 +134,11 @@ export const createReduceSharedProgramInfo = (
const input = inputVariable('_A', inputs[0].dataType, inputShape);
const output = outputVariable('output', outputDataType, outputShape);

const workgroupSize = 32;
let workgroupSize = 64;
// If only one workgroup is dispatched, increase workgroupSize to improve parallelism.
if (outputSize === 1) {
workgroupSize = 256;
}

const sharedMemorySnippet = `
var<workgroup> aBestValues : array<f32, ${workgroupSize}>;
Expand Down Expand Up @@ -188,7 +192,8 @@ export const createReduceSharedProgramInfo = (
// One work group is responsible for only one element of output.
return {
name,
shaderCache,
// Note that in JSEP, WG size is not included in cache by default, but WebGPU EP it is.
shaderCache: { hint: `${cacheKey};${workgroupSize}`, inputDependencies: ['type'] },
getShaderSource,
getRunData: () => ({
outputs: [{ dims: outputShape, dataType: outputDataType }],
Expand Down Expand Up @@ -233,7 +238,7 @@ const reduceCommon = (
context.compute(
createReduceSharedProgramInfo(
name,
{ hint: updatedAttributes.cacheKey, inputDependencies: ['type'] },
updatedAttributes.cacheKey,
[input],
reduceType,
context.inputs[0].dataType,
Expand Down
12 changes: 8 additions & 4 deletions js/web/lib/wasm/jsep/webgpu/ops/softmax.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ const createSoftmaxProgramInfo = (context: ComputeContext, attributes: SoftmaxAt
const input = context.inputs[0];
const inputShape = input.dims;
const outputSize = ShapeUtil.size(inputShape);
const WG = 64;
const inputRank = inputShape.length;
const axis = ShapeUtil.normalizeAxis(attributes.axis, inputRank);
const isTransposeRequired = axis < inputShape.length - 1;
Expand All @@ -60,7 +59,11 @@ const createSoftmaxProgramInfo = (context: ComputeContext, attributes: SoftmaxAt
const rows = outputSize / cols;
const components = getMaxComponents(cols);
const packedCols = cols / components;

let WG = 64;
// If only one workgroup is dispatched, increase workgroupSize to improve parallelism.
if (rows === 1) {
WG = 256;
}
const maxVector = (name: string, components: number) => {
if (components === 4) {
return `max(max(${name}.x, ${name}.y), max(${name}.z, ${name}.w))`;
Expand Down Expand Up @@ -95,7 +98,7 @@ const createSoftmaxProgramInfo = (context: ComputeContext, attributes: SoftmaxAt
result[index] = value;
}
${shaderHelper.registerUniform('packedCols', 'i32').declareVariables(x, output)}
${shaderHelper.mainStart()}
${shaderHelper.mainStart(WG)}
let gindex = i32(global_idx);
let lindex = i32(local_idx);
const wg = ${WG};
Expand Down Expand Up @@ -156,7 +159,8 @@ const createSoftmaxProgramInfo = (context: ComputeContext, attributes: SoftmaxAt
const result = context.compute(
{
name: 'Softmax',
shaderCache: { hint: `${components}`, inputDependencies: ['type'] },
// Note that in JSEP, WG size is not included in cache by default, but WebGPU EP it is.
shaderCache: { hint: `${components};${WG}`, inputDependencies: ['type'] },
getRunData: () => ({
outputs: [{ dims: transposedInputShape, dataType: transposedInput.dataType }],
dispatchGroup: { x: rows },
Expand Down

0 comments on commit b906b30

Please sign in to comment.