Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[js/webgpu] Refactor attributes of pool #18728

Merged
merged 13 commits into from
Dec 27, 2023
8 changes: 4 additions & 4 deletions js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
['Atanh', [unaryOps.atanh]],
['Attention', [attention, parseAttentionAttributes]],
// TODO: support new attributes for AveragePool-10
['AveragePool', [pool.averagePool, pool.parseAveragePoolAttributes]],
['AveragePool', [pool.averagePool]],
['BatchNormalization', [batchNorm]],
['BiasAdd', [biasAdd]],
['BiasSplitGelu', [biasSplitGelu]],
Expand All @@ -78,8 +78,8 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
['GatherElements', [gatherElements, parseGatherElementsAttributes]],
['Gelu', [unaryOps.gelu]],
['Gemm', [gemm, parseGemmAttributes]],
['GlobalAveragePool', [pool.globalAveragePool, pool.parseGlobalAveragePoolAttributes]],
guschmue marked this conversation as resolved.
Show resolved Hide resolved
['GlobalMaxPool', [pool.globalMaxPool, pool.parseGlobalMaxPoolAttributes]],
['GlobalAveragePool', [pool.globalAveragePool]],
['GlobalMaxPool', [pool.globalMaxPool]],
['Greater', [binaryOps.greater]],
['GreaterOrEqual', [binaryOps.greaterOrEqual]],
['InstanceNormalization', [instanceNorm, parseInstanceNormAttributes]],
Expand All @@ -90,7 +90,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
['Log', [unaryOps.log]],
['MatMul', [matMul]],
// TODO: support new attributes for MaxPool-8 and MaxPool-10
['MaxPool', [pool.maxPool, pool.parseMaxPoolAttributes]],
['MaxPool', [pool.maxPool]],
['Mul', [binaryOps.mul]],
['MultiHeadAttention', [multiHeadAttention, parseMultiHeadAttentionAttributes]],
['Neg', [unaryOps.neg]],
Expand Down
66 changes: 31 additions & 35 deletions js/web/lib/wasm/jsep/webgpu/ops/pool.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import {env} from 'onnxruntime-common';

import {TensorView} from '../../tensor-view';
import {PoolConvUtil, ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types';

import {createTensorShapeVariables, getElementAt, IndicesHelper, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from './common';
Expand Down Expand Up @@ -41,9 +40,9 @@ const getAdjustedPoolAttributesAndOutputShape = <AttributeType extends AveragePo

const newAttributes = Object.assign({}, attributes);
if (hasDilations) {
Object.assign(newAttributes, {kernelShape, strides, pads, dilations, cacheKey: attributes.cacheKey});
Object.assign(newAttributes, {kernelShape, strides, pads, dilations});
} else {
Object.assign(newAttributes, {kernelShape, strides, pads, cacheKey: attributes.cacheKey});
Object.assign(newAttributes, {kernelShape, strides, pads});
}
const outputShapeAsChannelLast = outputShapeAsChannelFirst.slice();
outputShapeAsChannelLast.push(outputShapeAsChannelLast.splice(1, 1)[0]);
Expand Down Expand Up @@ -246,7 +245,7 @@ const generatePoolingCode = <AttributeType extends AveragePoolAttributes|MaxPool
}
};

export interface FormatAttributes {
export interface FormatAttributes extends Record<string, unknown> {
readonly format: 'NHWC'|'NCHW';
}

Expand All @@ -258,7 +257,10 @@ export interface PoolCommonAttributes extends FormatAttributes {
readonly pads: readonly number[];
}

const parsePoolCommonAttributes = (attributes: Record<string, unknown>): PoolCommonAttributes => ({
const createShaderKeyFromAttributes = (attributes: PoolCommonAttributes): string =>
(`${attributes.format as string};${attributes.ceilMode};${attributes.autoPad};${attributes.kernelShape.length}`);
fs-eire marked this conversation as resolved.
Show resolved Hide resolved

const parsePoolCommonAttributes = (attributes: PoolCommonAttributes): PoolCommonAttributes => ({
format: attributes.format as FormatAttributes['format'],
autoPad: ['NOTSET', 'VALID', 'SAME_UPPER', 'SAME_LOWER'][attributes.auto_pad as number],
ceilMode: attributes.ceil_mode as number,
Expand All @@ -267,7 +269,7 @@ const parsePoolCommonAttributes = (attributes: Record<string, unknown>): PoolCom
pads: attributes.pads as [number, number, number, number]
});

export interface AveragePoolAttributes extends PoolCommonAttributes, AttributeWithCacheKey {
export interface AveragePoolAttributes extends PoolCommonAttributes {
readonly countIncludePad: boolean;
}

Expand All @@ -287,13 +289,13 @@ const createAveragePoolProgramInfo =
}
const [programUniforms, uniforms, hasPads, pwStartEnd, phStartEnd] =
getUniformAndPadInfo(outputShape, adjustedAttributes);
programUniforms.push(...createTensorShapeVariables(input.dims));
programUniforms.push(...createTensorShapeVariables(outputShape));
programUniforms.push(...createTensorShapeVariables(input.dims), ...createTensorShapeVariables(outputShape));
const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank'];
return {
name,
shaderCache: {
hint: attributes.cacheKey + hasPads + pwStartEnd + phStartEnd + adjustedAttributes.countIncludePad,
hint: createShaderKeyFromAttributes(attributes) +
`;${adjustedAttributes.countIncludePad};${hasPads};${pwStartEnd};${phStartEnd}`,
inputDependencies
},
getRunData: () => ({
Expand All @@ -307,21 +309,21 @@ const createAveragePoolProgramInfo =
};
};

export const parseAveragePoolAttributes = (attributes: Record<string, unknown>): AveragePoolAttributes => {
const parseAveragePoolAttributes = (attributes: AveragePoolAttributes): AveragePoolAttributes => {
const countIncludePad = (attributes.count_include_pad as number) === 0 ? false : true;

const attr = parsePoolCommonAttributes(attributes);
// TODO: support attribute 'ceil_mode'
if (attr.ceilMode !== 0) {
throw new Error('using ceil() in shape computation is not yet supported for AveragePool');
}

return createAttributeWithCacheKey({countIncludePad, ...attr});
return {countIncludePad, ...attr};
};

export const averagePool = (context: ComputeContext, attributes: AveragePoolAttributes): void => {
validateInputs(context.inputs);
context.compute(createAveragePoolProgramInfo('AveragePool', context.inputs[0], false, attributes));
context.compute(
createAveragePoolProgramInfo('AveragePool', context.inputs[0], false, parseAveragePoolAttributes(attributes)));
};

const globalPoolAttributes = {
Expand All @@ -336,17 +338,13 @@ const globalPoolAttributes = {
cacheKey: ''
};

export const parseGlobalAveragePoolAttributes = (attributes: Record<string, unknown>): AveragePoolAttributes => {
const format = attributes.format as FormatAttributes['format'];
return {format, ...globalPoolAttributes, cacheKey: format};
guschmue marked this conversation as resolved.
Show resolved Hide resolved
};

export const globalAveragePool = (context: ComputeContext, attributes: AveragePoolAttributes): void => {
validateInputs(context.inputs);
context.compute(createAveragePoolProgramInfo('GlobalAveragePool', context.inputs[0], true, attributes));
context.compute(createAveragePoolProgramInfo(
'GlobalAveragePool', context.inputs[0], true, {format: attributes.format, ...globalPoolAttributes}));
};

export interface MaxPoolAttributes extends PoolCommonAttributes, AttributeWithCacheKey {
export interface MaxPoolAttributes extends PoolCommonAttributes {
readonly storageOrder: number;
readonly dilations: number[];
}
Expand All @@ -363,11 +361,14 @@ const createMaxPoolProgramInfo =
const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank'];
const [programUniforms, uniforms, hasPads, pwStartEnd, phStartEnd] =
getUniformAndPadInfo(outputShape, adjustedAttributes);
programUniforms.push(...createTensorShapeVariables(input.dims));
programUniforms.push(...createTensorShapeVariables(outputShape));
programUniforms.push(...createTensorShapeVariables(input.dims), ...createTensorShapeVariables(outputShape));
return {
name,
shaderCache: {hint: attributes.cacheKey + hasPads, inputDependencies},
shaderCache: {
hint: createShaderKeyFromAttributes(attributes) +
`;${attributes.storageOrder};${attributes.dilations}${hasPads};${pwStartEnd};${phStartEnd}`,
inputDependencies
},
getRunData: () => ({
outputs: [{dims: outputShape, dataType: input.dataType}],
dispatchGroup: {x: Math.ceil(ShapeUtil.size(outputShape) / 64 /* workgroup size */)},
Expand All @@ -379,12 +380,7 @@ const createMaxPoolProgramInfo =
};
};

export const maxPool = (context: ComputeContext, attributes: MaxPoolAttributes): void => {
validateInputs(context.inputs);
context.compute(createMaxPoolProgramInfo('MaxPool', context.inputs[0], false, attributes));
};

export const parseMaxPoolAttributes = (attributes: Record<string, unknown>): MaxPoolAttributes => {
const parseMaxPoolAttributes = (attributes: MaxPoolAttributes): MaxPoolAttributes => {
const storageOrder = attributes.storage_order as number;
const dilations = attributes.dilations as [number, number];

Expand All @@ -396,16 +392,16 @@ export const parseMaxPoolAttributes = (attributes: Record<string, unknown>): Max
if (attr.ceilMode !== 0) {
throw new Error('using ceil() in shape computation is not yet supported for MaxPool');
}

return createAttributeWithCacheKey({storageOrder, dilations, ...attr});
return {storageOrder, dilations, ...attr};
};

export const parseGlobalMaxPoolAttributes = (attributes: Record<string, unknown>): MaxPoolAttributes => {
const format = attributes.format as FormatAttributes['format'];
return {format, ...globalPoolAttributes, cacheKey: format};
export const maxPool = (context: ComputeContext, attributes: MaxPoolAttributes): void => {
validateInputs(context.inputs);
context.compute(createMaxPoolProgramInfo('MaxPool', context.inputs[0], false, parseMaxPoolAttributes(attributes)));
};

export const globalMaxPool = (context: ComputeContext, attributes: MaxPoolAttributes): void => {
validateInputs(context.inputs);
context.compute(createMaxPoolProgramInfo('GlobalMaxPool', context.inputs[0], true, attributes));
context.compute(createMaxPoolProgramInfo(
'GlobalMaxPool', context.inputs[0], true, {format: attributes.format, ...globalPoolAttributes}));
};