Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[js/webgpu] Refactor attributes of pool #18728

Merged
merged 13 commits into from
Dec 27, 2023
37 changes: 20 additions & 17 deletions js/web/lib/wasm/jsep/webgpu/ops/pool.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import {env} from 'onnxruntime-common';

import {TensorView} from '../../tensor-view';
import {PoolConvUtil, ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types';

import {createTensorShapeVariables, getElementAt, IndicesHelper, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from './common';
Expand Down Expand Up @@ -41,9 +40,9 @@ const getAdjustedPoolAttributesAndOutputShape = <AttributeType extends AveragePo

const newAttributes = Object.assign({}, attributes);
if (hasDilations) {
Object.assign(newAttributes, {kernelShape, strides, pads, dilations, cacheKey: attributes.cacheKey});
Object.assign(newAttributes, {kernelShape, strides, pads, dilations});
} else {
Object.assign(newAttributes, {kernelShape, strides, pads, cacheKey: attributes.cacheKey});
Object.assign(newAttributes, {kernelShape, strides, pads});
}
const outputShapeAsChannelLast = outputShapeAsChannelFirst.slice();
outputShapeAsChannelLast.push(outputShapeAsChannelLast.splice(1, 1)[0]);
Expand Down Expand Up @@ -258,6 +257,9 @@ export interface PoolCommonAttributes extends FormatAttributes {
readonly pads: readonly number[];
}

const createShaderKeyFromAttributes = (attributes: PoolCommonAttributes): string =>
(`${attributes.format as string};${attributes.ceilMode};${attributes.autoPad};${attributes.kernelShape.length}`);
fs-eire marked this conversation as resolved.
Show resolved Hide resolved

const parsePoolCommonAttributes = (attributes: Record<string, unknown>): PoolCommonAttributes => ({
format: attributes.format as FormatAttributes['format'],
autoPad: ['NOTSET', 'VALID', 'SAME_UPPER', 'SAME_LOWER'][attributes.auto_pad as number],
Expand All @@ -267,7 +269,7 @@ const parsePoolCommonAttributes = (attributes: Record<string, unknown>): PoolCom
pads: attributes.pads as [number, number, number, number]
});

export interface AveragePoolAttributes extends PoolCommonAttributes, AttributeWithCacheKey {
export interface AveragePoolAttributes extends PoolCommonAttributes {
readonly countIncludePad: boolean;
}

Expand All @@ -287,13 +289,13 @@ const createAveragePoolProgramInfo =
}
const [programUniforms, uniforms, hasPads, pwStartEnd, phStartEnd] =
getUniformAndPadInfo(outputShape, adjustedAttributes);
programUniforms.push(...createTensorShapeVariables(input.dims));
programUniforms.push(...createTensorShapeVariables(outputShape));
programUniforms.push(...createTensorShapeVariables(input.dims), ...createTensorShapeVariables(outputShape));
const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank'];
return {
name,
shaderCache: {
hint: attributes.cacheKey + hasPads + pwStartEnd + phStartEnd + adjustedAttributes.countIncludePad,
hint: createShaderKeyFromAttributes(attributes) +
`;${adjustedAttributes.countIncludePad};${hasPads};${pwStartEnd};${phStartEnd}`,
inputDependencies
},
getRunData: () => ({
Expand All @@ -315,8 +317,7 @@ export const parseAveragePoolAttributes = (attributes: Record<string, unknown>):
if (attr.ceilMode !== 0) {
throw new Error('using ceil() in shape computation is not yet supported for AveragePool');
}

return createAttributeWithCacheKey({countIncludePad, ...attr});
return {countIncludePad, ...attr};
};

export const averagePool = (context: ComputeContext, attributes: AveragePoolAttributes): void => {
Expand All @@ -338,15 +339,15 @@ const globalPoolAttributes = {

export const parseGlobalAveragePoolAttributes = (attributes: Record<string, unknown>): AveragePoolAttributes => {
const format = attributes.format as FormatAttributes['format'];
return {format, ...globalPoolAttributes, cacheKey: format};
guschmue marked this conversation as resolved.
Show resolved Hide resolved
return {format, ...globalPoolAttributes};
};

export const globalAveragePool = (context: ComputeContext, attributes: AveragePoolAttributes): void => {
validateInputs(context.inputs);
context.compute(createAveragePoolProgramInfo('GlobalAveragePool', context.inputs[0], true, attributes));
};

export interface MaxPoolAttributes extends PoolCommonAttributes, AttributeWithCacheKey {
export interface MaxPoolAttributes extends PoolCommonAttributes {
readonly storageOrder: number;
readonly dilations: number[];
}
Expand All @@ -363,11 +364,14 @@ const createMaxPoolProgramInfo =
const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank'];
const [programUniforms, uniforms, hasPads, pwStartEnd, phStartEnd] =
getUniformAndPadInfo(outputShape, adjustedAttributes);
programUniforms.push(...createTensorShapeVariables(input.dims));
programUniforms.push(...createTensorShapeVariables(outputShape));
programUniforms.push(...createTensorShapeVariables(input.dims), ...createTensorShapeVariables(outputShape));
return {
name,
shaderCache: {hint: attributes.cacheKey + hasPads, inputDependencies},
shaderCache: {
hint: createShaderKeyFromAttributes(attributes) +
`;${attributes.storageOrder};${attributes.dilations}${hasPads};${pwStartEnd};${phStartEnd}`,
inputDependencies
},
getRunData: () => ({
outputs: [{dims: outputShape, dataType: input.dataType}],
dispatchGroup: {x: Math.ceil(ShapeUtil.size(outputShape) / 64 /* workgroup size */)},
Expand Down Expand Up @@ -396,13 +400,12 @@ export const parseMaxPoolAttributes = (attributes: Record<string, unknown>): Max
if (attr.ceilMode !== 0) {
throw new Error('using ceil() in shape computation is not yet supported for MaxPool');
}

return createAttributeWithCacheKey({storageOrder, dilations, ...attr});
return {storageOrder, dilations, ...attr};
};

export const parseGlobalMaxPoolAttributes = (attributes: Record<string, unknown>): MaxPoolAttributes => {
const format = attributes.format as FormatAttributes['format'];
return {format, ...globalPoolAttributes, cacheKey: format};
return {format, ...globalPoolAttributes};
};

export const globalMaxPool = (context: ComputeContext, attributes: MaxPoolAttributes): void => {
Expand Down