Skip to content

Commit

Permalink
[js/webgpu] Refactor attributes of pool (microsoft#18728)
Browse files Browse the repository at this point in the history
  • Loading branch information
axinging authored Dec 27, 2023
1 parent b7f75ae commit 8376e55
Showing 1 changed file with 33 additions and 28 deletions.
61 changes: 33 additions & 28 deletions web/lib/wasm/jsep/webgpu/ops/pool.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import {env} from 'onnxruntime-common';

import {TensorView} from '../../tensor-view';
import {PoolConvUtil, ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {AttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types';

import {createTensorShapeVariables, getElementAt, IndicesHelper, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from './common';
Expand Down Expand Up @@ -63,7 +63,7 @@ const getUniformAndPadInfo = <AttributeType extends AveragePoolAttributes|MaxPoo
const sw = attributes.strides[attributes.strides.length - 1];
const pwStart = attributes.pads[attributes.pads.length / 2 - 1];
const pwEnd = attributes.pads[attributes.pads.length - 1];
const pwStartEnd = !!(pwStart + pwEnd);
const pwStartEndNotZero = !!(pwStart + pwEnd);
programUniforms.push(
{type: 'uint32', data: kw},
{type: 'uint32', data: sw},
Expand All @@ -74,13 +74,13 @@ const getUniformAndPadInfo = <AttributeType extends AveragePoolAttributes|MaxPoo
{name: 'kw', type: 'u32'}, {name: 'sw', type: 'u32'}, {name: 'pwStart', type: 'u32'},
{name: 'pwEnd', type: 'u32'});

let phStartEnd = false;
let phStartEndNotZero = false;
if (attributes.kernelShape.length === 2) {
const kh = attributes.kernelShape[attributes.kernelShape.length - 2];
const sh = attributes.strides[attributes.strides.length - 2];
const phStart = attributes.pads[attributes.pads.length / 2 - 2];
const phEnd = attributes.pads[attributes.pads.length - 2];
phStartEnd = !!(phStart + phEnd);
phStartEndNotZero = !!(phStart + phEnd);
programUniforms.push(
{type: 'uint32', data: kh}, {type: 'uint32', data: sh}, {type: 'uint32', data: phStart},
{type: 'uint32', data: phEnd});
Expand All @@ -89,7 +89,7 @@ const getUniformAndPadInfo = <AttributeType extends AveragePoolAttributes|MaxPoo
{name: 'kh', type: 'u32'}, {name: 'sh', type: 'u32'}, {name: 'phStart', type: 'u32'},
{name: 'phEnd', type: 'u32'});
}
return [programUniforms, uniforms, true, pwStartEnd, phStartEnd];
return [programUniforms, uniforms, true, pwStartEndNotZero, phStartEndNotZero];
} else {
if (isChannelsLast) {
throw new Error('Pooling with kernelShape.length > 2 is not supported for NHWC format.');
Expand All @@ -110,8 +110,8 @@ const getUniformAndPadInfo = <AttributeType extends AveragePoolAttributes|MaxPoo

const generatePoolingCode = <AttributeType extends AveragePoolAttributes|MaxPoolAttributes>(
shaderHelper: ShaderHelper, x: IndicesHelper, rank: number, outputShapeRank: number, attributes: AttributeType,
op1: string, op2: string, start: number, uniforms: UniformsArrayType, hasPads: boolean, pwStartEnd: boolean,
phStartEnd: boolean): string => {
op1: string, op2: string, start: number, uniforms: UniformsArrayType, hasPads: boolean, pwStartEndNotZero: boolean,
phStartEndNotZero: boolean): string => {
const isChannelsLast = attributes.format === 'NHWC';
const dataType = x.type.value;
const output = outputVariable('output', x.type.tensor, outputShapeRank);
Expand All @@ -121,7 +121,7 @@ const generatePoolingCode = <AttributeType extends AveragePoolAttributes|MaxPool
let codeH = '';
let codeHEnd = '';
const dimIdxW = rank - (isChannelsLast ? 2 : 1);
if (pwStartEnd === true) {
if (pwStartEndNotZero) {
codeW = `
for (var i: u32 = 0u; i < uniforms.kw; i++) {
xIndices[${dimIdxW}] = indices[${dimIdxW}] * uniforms.sw - uniforms.pwStart + i;
Expand All @@ -144,7 +144,7 @@ const generatePoolingCode = <AttributeType extends AveragePoolAttributes|MaxPool

if (attributes.kernelShape.length === 2) {
const dimIdxH = rank - (isChannelsLast ? 3 : 2);
if (phStartEnd === true) {
if (phStartEndNotZero) {
codeH = `
for (var j: u32 = 0u; j < uniforms.kh; j++) {
xIndices[${dimIdxH}] = indices[${dimIdxH}] * uniforms.sh - uniforms.phStart + j;
Expand Down Expand Up @@ -258,6 +258,15 @@ export interface PoolCommonAttributes extends FormatAttributes {
readonly pads: readonly number[];
}

const createShaderKeyFromAttributes = (attributes: PoolCommonAttributes): string =>
(`${attributes.format};${attributes.ceilMode};${attributes.autoPad};${attributes.kernelShape.length}`);

const createAveragePoolShaderKeyFromAttributes = (attributes: AveragePoolAttributes): string =>
(`${createShaderKeyFromAttributes(attributes)};${attributes.countIncludePad}`);

const createMaxPoolShaderKeyFromAttributes = (attributes: MaxPoolAttributes): string =>
(`${createShaderKeyFromAttributes(attributes)};${attributes.storageOrder};${attributes.dilations}`);

const parsePoolCommonAttributes = (attributes: Record<string, unknown>): PoolCommonAttributes => ({
format: attributes.format as FormatAttributes['format'],
autoPad: ['NOTSET', 'VALID', 'SAME_UPPER', 'SAME_LOWER'][attributes.auto_pad as number],
Expand Down Expand Up @@ -285,25 +294,22 @@ const createAveragePoolProgramInfo =
} else {
op2 += `value /= ${dataType}(i32(uniforms.kernelSize) - pad);`;
}
const [programUniforms, uniforms, hasPads, pwStartEnd, phStartEnd] =
const [programUniforms, uniforms, hasPads, pwStartEndNotZero, phStartEndNotZero] =
getUniformAndPadInfo(outputShape, adjustedAttributes);
programUniforms.push(...createTensorShapeVariables(input.dims));
programUniforms.push(...createTensorShapeVariables(outputShape));
programUniforms.push(...createTensorShapeVariables(input.dims), ...createTensorShapeVariables(outputShape));
const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank'];
return {
name,
shaderCache: {
hint: attributes.cacheKey + hasPads + pwStartEnd + phStartEnd + adjustedAttributes.countIncludePad,
inputDependencies
},
shaderCache:
{hint: `${attributes.cacheKey};${hasPads};${pwStartEndNotZero};${phStartEndNotZero}`, inputDependencies},
getRunData: () => ({
outputs: [{dims: outputShape, dataType: input.dataType}],
dispatchGroup: {x: Math.ceil(ShapeUtil.size(outputShape) / 64 /* workgroup size */)},
programUniforms
}),
getShaderSource: shaderHelper => generatePoolingCode(
shaderHelper, x, input.dims.length, outputShape.length, adjustedAttributes, op1, op2, 0.0, uniforms,
hasPads, pwStartEnd, phStartEnd),
hasPads, pwStartEndNotZero, phStartEndNotZero),
};
};

Expand All @@ -315,8 +321,8 @@ export const parseAveragePoolAttributes = (attributes: Record<string, unknown>):
if (attr.ceilMode !== 0) {
throw new Error('using ceil() in shape computation is not yet supported for AveragePool');
}

return createAttributeWithCacheKey({countIncludePad, ...attr});
const averagePoolAttributes = {countIncludePad, ...attr, cacheKey: ''};
return {...averagePoolAttributes, cacheKey: createAveragePoolShaderKeyFromAttributes(averagePoolAttributes)};
};

export const averagePool = (context: ComputeContext, attributes: AveragePoolAttributes): void => {
Expand All @@ -332,8 +338,7 @@ const globalPoolAttributes = {
strides: [],
pads: [],
storageOrder: 0,
dilations: [],
cacheKey: ''
dilations: []
};

export const parseGlobalAveragePoolAttributes = (attributes: Record<string, unknown>): AveragePoolAttributes => {
Expand Down Expand Up @@ -361,21 +366,21 @@ const createMaxPoolProgramInfo =
const op2 = '';
const x = inputVariable('x', input.dataType, input.dims.length);
const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank'];
const [programUniforms, uniforms, hasPads, pwStartEnd, phStartEnd] =
const [programUniforms, uniforms, hasPads, pwStartEndNotZero, phStartEndNotZero] =
getUniformAndPadInfo(outputShape, adjustedAttributes);
programUniforms.push(...createTensorShapeVariables(input.dims));
programUniforms.push(...createTensorShapeVariables(outputShape));
programUniforms.push(...createTensorShapeVariables(input.dims), ...createTensorShapeVariables(outputShape));
return {
name,
shaderCache: {hint: attributes.cacheKey + hasPads, inputDependencies},
shaderCache:
{hint: `${attributes.cacheKey};${hasPads};${pwStartEndNotZero};${phStartEndNotZero}`, inputDependencies},
getRunData: () => ({
outputs: [{dims: outputShape, dataType: input.dataType}],
dispatchGroup: {x: Math.ceil(ShapeUtil.size(outputShape) / 64 /* workgroup size */)},
programUniforms
}),
getShaderSource: shaderHelper => generatePoolingCode(
shaderHelper, x, input.dims.length, outputShape.length, adjustedAttributes, op1, op2, -1e5, uniforms,
hasPads, pwStartEnd, phStartEnd),
hasPads, pwStartEndNotZero, phStartEndNotZero),
};
};

Expand All @@ -396,8 +401,8 @@ export const parseMaxPoolAttributes = (attributes: Record<string, unknown>): Max
if (attr.ceilMode !== 0) {
throw new Error('using ceil() in shape computation is not yet supported for MaxPool');
}

return createAttributeWithCacheKey({storageOrder, dilations, ...attr});
const maxPoolAttributes = {storageOrder, dilations, ...attr, cacheKey: ''};
return {...maxPoolAttributes, cacheKey: createMaxPoolShaderKeyFromAttributes(maxPoolAttributes)};
};

export const parseGlobalMaxPoolAttributes = (attributes: Record<string, unknown>): MaxPoolAttributes => {
Expand Down

0 comments on commit 8376e55

Please sign in to comment.