Skip to content

Commit

Permalink
[js/webgpu] Refactor attributes of pool
Browse files Browse the repository at this point in the history
  • Loading branch information
axinging committed Dec 6, 2023
1 parent c012e41 commit 2988d51
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 37 deletions.
6 changes: 3 additions & 3 deletions js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,8 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
['GatherElements', [gatherElements, parseGatherElementsAttributes]],
['Gelu', [unaryOps.gelu]],
['Gemm', [gemm, parseGemmAttributes]],
['GlobalAveragePool', [pool.globalAveragePool, pool.parseGlobalAveragePoolAttributes]],
['GlobalMaxPool', [pool.globalMaxPool, pool.parseGlobalMaxPoolAttributes]],
['GlobalAveragePool', [pool.globalAveragePool]],
['GlobalMaxPool', [pool.globalMaxPool]],
['Greater', [binaryOps.greater]],
['GreaterOrEqual', [binaryOps.greaterOrEqual]],
['InstanceNormalization', [instanceNorm, parseInstanceNormAttributes]],
Expand All @@ -90,7 +90,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
['Log', [unaryOps.log]],
['MatMul', [matMul]],
// TODO: support new attributes for MaxPool-8 and MaxPool-10
['MaxPool', [pool.maxPool, pool.parseMaxPoolAttributes]],
['MaxPool', [pool.maxPool]],
['Mul', [binaryOps.mul]],
['MultiHeadAttention', [multiHeadAttention, parseMultiHeadAttentionAttributes]],
['Neg', [unaryOps.neg]],
Expand Down
64 changes: 30 additions & 34 deletions js/web/lib/wasm/jsep/webgpu/ops/pool.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import {env} from 'onnxruntime-common';

import {TensorView} from '../../tensor-view';
import {PoolConvUtil, ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types';

import {createTensorShapeVariables, getElementAt, IndicesHelper, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from './common';
Expand Down Expand Up @@ -41,9 +40,9 @@ const getAdjustedPoolAttributesAndOutputShape = <AttributeType extends AveragePo

const newAttributes = Object.assign({}, attributes);
if (hasDilations) {
Object.assign(newAttributes, {kernelShape, strides, pads, dilations, cacheKey: attributes.cacheKey});
Object.assign(newAttributes, {kernelShape, strides, pads, dilations});
} else {
Object.assign(newAttributes, {kernelShape, strides, pads, cacheKey: attributes.cacheKey});
Object.assign(newAttributes, {kernelShape, strides, pads});
}
const outputShapeAsChannelLast = outputShapeAsChannelFirst.slice();
outputShapeAsChannelLast.push(outputShapeAsChannelLast.splice(1, 1)[0]);
Expand Down Expand Up @@ -258,6 +257,9 @@ export interface PoolCommonAttributes extends FormatAttributes {
readonly pads: readonly number[];
}

const createShaderKeyFromAttributes = (attributes: PoolCommonAttributes): string =>
(`${attributes.format as string};${attributes.ceilMode};${attributes.autoPad};${attributes.kernelShape.length}`);

const parsePoolCommonAttributes = (attributes: Record<string, unknown>): PoolCommonAttributes => ({
format: attributes.format as FormatAttributes['format'],
autoPad: ['NOTSET', 'VALID', 'SAME_UPPER', 'SAME_LOWER'][attributes.auto_pad as number],
Expand All @@ -267,7 +269,7 @@ const parsePoolCommonAttributes = (attributes: Record<string, unknown>): PoolCom
pads: attributes.pads as [number, number, number, number]
});

export interface AveragePoolAttributes extends PoolCommonAttributes, AttributeWithCacheKey {
export interface AveragePoolAttributes extends PoolCommonAttributes {
readonly countIncludePad: boolean;
}

Expand All @@ -287,13 +289,13 @@ const createAveragePoolProgramInfo =
}
const [programUniforms, uniforms, hasPads, pwStartEnd, phStartEnd] =
getUniformAndPadInfo(outputShape, adjustedAttributes);
programUniforms.push(...createTensorShapeVariables(input.dims));
programUniforms.push(...createTensorShapeVariables(outputShape));
programUniforms.push(...createTensorShapeVariables(input.dims), ...createTensorShapeVariables(outputShape));
const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank'];
return {
name,
shaderCache: {
hint: attributes.cacheKey + hasPads + pwStartEnd + phStartEnd + adjustedAttributes.countIncludePad,
hint: createShaderKeyFromAttributes(attributes) +
`;${adjustedAttributes.countIncludePad};${hasPads};${pwStartEnd};${phStartEnd}`,
inputDependencies
},
getRunData: () => ({
Expand All @@ -307,21 +309,21 @@ const createAveragePoolProgramInfo =
};
};

export const parseAveragePoolAttributes = (attributes: Record<string, unknown>): AveragePoolAttributes => {
const parseAveragePoolAttributes = (attributes: Record<string, unknown>): AveragePoolAttributes => {
const countIncludePad = (attributes.count_include_pad as number) === 0 ? false : true;

const attr = parsePoolCommonAttributes(attributes);
// TODO: support attribute 'ceil_mode'
if (attr.ceilMode !== 0) {
throw new Error('using ceil() in shape computation is not yet supported for AveragePool');
}

return createAttributeWithCacheKey({countIncludePad, ...attr});
return {countIncludePad, ...attr};
};

export const averagePool = (context: ComputeContext, attributes: AveragePoolAttributes): void => {
export const averagePool = (context: ComputeContext, attributes: Record<string, unknown>): void => {
validateInputs(context.inputs);
context.compute(createAveragePoolProgramInfo('AveragePool', context.inputs[0], false, attributes));
context.compute(
createAveragePoolProgramInfo('AveragePool', context.inputs[0], false, parseAveragePoolAttributes(attributes)));
};

const globalPoolAttributes = {
Expand All @@ -336,17 +338,13 @@ const globalPoolAttributes = {
cacheKey: ''
};

export const parseGlobalAveragePoolAttributes = (attributes: Record<string, unknown>): AveragePoolAttributes => {
const format = attributes.format as FormatAttributes['format'];
return {format, ...globalPoolAttributes, cacheKey: format};
};

export const globalAveragePool = (context: ComputeContext, attributes: AveragePoolAttributes): void => {
validateInputs(context.inputs);
context.compute(createAveragePoolProgramInfo('GlobalAveragePool', context.inputs[0], true, attributes));
context.compute(createAveragePoolProgramInfo(
'GlobalAveragePool', context.inputs[0], true, {format: attributes.format, ...globalPoolAttributes}));
};

export interface MaxPoolAttributes extends PoolCommonAttributes, AttributeWithCacheKey {
export interface MaxPoolAttributes extends PoolCommonAttributes {
readonly storageOrder: number;
readonly dilations: number[];
}
Expand All @@ -363,11 +361,14 @@ const createMaxPoolProgramInfo =
const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank'];
const [programUniforms, uniforms, hasPads, pwStartEnd, phStartEnd] =
getUniformAndPadInfo(outputShape, adjustedAttributes);
programUniforms.push(...createTensorShapeVariables(input.dims));
programUniforms.push(...createTensorShapeVariables(outputShape));
programUniforms.push(...createTensorShapeVariables(input.dims), ...createTensorShapeVariables(outputShape));
return {
name,
shaderCache: {hint: attributes.cacheKey + hasPads, inputDependencies},
shaderCache: {
hint: createShaderKeyFromAttributes(attributes) +
`;${attributes.storageOrder};${attributes.dilations}${hasPads};${pwStartEnd};${phStartEnd}`,
inputDependencies
},
getRunData: () => ({
outputs: [{dims: outputShape, dataType: input.dataType}],
dispatchGroup: {x: Math.ceil(ShapeUtil.size(outputShape) / 64 /* workgroup size */)},
Expand All @@ -379,12 +380,7 @@ const createMaxPoolProgramInfo =
};
};

export const maxPool = (context: ComputeContext, attributes: MaxPoolAttributes): void => {
validateInputs(context.inputs);
context.compute(createMaxPoolProgramInfo('MaxPool', context.inputs[0], false, attributes));
};

export const parseMaxPoolAttributes = (attributes: Record<string, unknown>): MaxPoolAttributes => {
const parseMaxPoolAttributes = (attributes: Record<string, unknown>): MaxPoolAttributes => {
const storageOrder = attributes.storage_order as number;
const dilations = attributes.dilations as [number, number];

Expand All @@ -396,16 +392,16 @@ export const parseMaxPoolAttributes = (attributes: Record<string, unknown>): Max
if (attr.ceilMode !== 0) {
throw new Error('using ceil() in shape computation is not yet supported for MaxPool');
}

return createAttributeWithCacheKey({storageOrder, dilations, ...attr});
return {storageOrder, dilations, ...attr};
};

export const parseGlobalMaxPoolAttributes = (attributes: Record<string, unknown>): MaxPoolAttributes => {
const format = attributes.format as FormatAttributes['format'];
return {format, ...globalPoolAttributes, cacheKey: format};
export const maxPool = (context: ComputeContext, attributes: Record<string, unknown>): void => {
validateInputs(context.inputs);
context.compute(createMaxPoolProgramInfo('MaxPool', context.inputs[0], false, parseMaxPoolAttributes(attributes)));
};

export const globalMaxPool = (context: ComputeContext, attributes: MaxPoolAttributes): void => {
validateInputs(context.inputs);
context.compute(createMaxPoolProgramInfo('GlobalMaxPool', context.inputs[0], true, attributes));
context.compute(createMaxPoolProgramInfo(
'GlobalMaxPool', context.inputs[0], true, {format: attributes.format, ...globalPoolAttributes}));
};

0 comments on commit 2988d51

Please sign in to comment.