Skip to content

Commit

Permalink
Reenable cacheKey
Browse files Browse the repository at this point in the history
  • Loading branch information
axinging committed Dec 15, 2023
1 parent 8a65b01 commit 7318183
Showing 1 changed file with 27 additions and 24 deletions.
51 changes: 27 additions & 24 deletions js/web/lib/wasm/jsep/webgpu/ops/pool.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import {env} from 'onnxruntime-common';

import {TensorView} from '../../tensor-view';
import {PoolConvUtil, ShapeUtil} from '../../util';
import {AttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types';

import {createTensorShapeVariables, getElementAt, IndicesHelper, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from './common';
Expand Down Expand Up @@ -40,9 +41,9 @@ const getAdjustedPoolAttributesAndOutputShape = <AttributeType extends AveragePo

const newAttributes = Object.assign({}, attributes);
if (hasDilations) {
Object.assign(newAttributes, {kernelShape, strides, pads, dilations});
Object.assign(newAttributes, {kernelShape, strides, pads, dilations, cacheKey: attributes.cacheKey});
} else {
Object.assign(newAttributes, {kernelShape, strides, pads});
Object.assign(newAttributes, {kernelShape, strides, pads, cacheKey: attributes.cacheKey});
}
const outputShapeAsChannelLast = outputShapeAsChannelFirst.slice();
outputShapeAsChannelLast.push(outputShapeAsChannelLast.splice(1, 1)[0]);
Expand All @@ -62,7 +63,7 @@ const getUniformAndPadInfo = <AttributeType extends AveragePoolAttributes|MaxPoo
const sw = attributes.strides[attributes.strides.length - 1];
const pwStart = attributes.pads[attributes.pads.length / 2 - 1];
const pwEnd = attributes.pads[attributes.pads.length - 1];
const pwStartEnd = !!(pwStart + pwEnd);
const pwStartEndNotZero = !!(pwStart + pwEnd);
programUniforms.push(
{type: 'uint32', data: kw},
{type: 'uint32', data: sw},
Expand All @@ -73,13 +74,13 @@ const getUniformAndPadInfo = <AttributeType extends AveragePoolAttributes|MaxPoo
{name: 'kw', type: 'u32'}, {name: 'sw', type: 'u32'}, {name: 'pwStart', type: 'u32'},
{name: 'pwEnd', type: 'u32'});

let phStartEnd = false;
let phStartEndNotZero = false;
if (attributes.kernelShape.length === 2) {
const kh = attributes.kernelShape[attributes.kernelShape.length - 2];
const sh = attributes.strides[attributes.strides.length - 2];
const phStart = attributes.pads[attributes.pads.length / 2 - 2];
const phEnd = attributes.pads[attributes.pads.length - 2];
phStartEnd = !!(phStart + phEnd);
phStartEndNotZero = !!(phStart + phEnd);
programUniforms.push(
{type: 'uint32', data: kh}, {type: 'uint32', data: sh}, {type: 'uint32', data: phStart},
{type: 'uint32', data: phEnd});
Expand All @@ -88,7 +89,7 @@ const getUniformAndPadInfo = <AttributeType extends AveragePoolAttributes|MaxPoo
{name: 'kh', type: 'u32'}, {name: 'sh', type: 'u32'}, {name: 'phStart', type: 'u32'},
{name: 'phEnd', type: 'u32'});
}
return [programUniforms, uniforms, true, pwStartEnd, phStartEnd];
return [programUniforms, uniforms, true, pwStartEndNotZero, phStartEndNotZero];
} else {
if (isChannelsLast) {
throw new Error('Pooling with kernelShape.length > 2 is not supported for NHWC format.');
Expand All @@ -109,8 +110,8 @@ const getUniformAndPadInfo = <AttributeType extends AveragePoolAttributes|MaxPoo

const generatePoolingCode = <AttributeType extends AveragePoolAttributes|MaxPoolAttributes>(
shaderHelper: ShaderHelper, x: IndicesHelper, rank: number, outputShapeRank: number, attributes: AttributeType,
op1: string, op2: string, start: number, uniforms: UniformsArrayType, hasPads: boolean, pwStartEnd: boolean,
phStartEnd: boolean): string => {
op1: string, op2: string, start: number, uniforms: UniformsArrayType, hasPads: boolean, pwStartEndNotZero: boolean,
phStartEndNotZero: boolean): string => {
const isChannelsLast = attributes.format === 'NHWC';
const dataType = x.type.value;
const output = outputVariable('output', x.type.tensor, outputShapeRank);
Expand All @@ -120,7 +121,7 @@ const generatePoolingCode = <AttributeType extends AveragePoolAttributes|MaxPool
let codeH = '';
let codeHEnd = '';
const dimIdxW = rank - (isChannelsLast ? 2 : 1);
if (pwStartEnd === true) {
if (pwStartEndNotZero === true) {
codeW = `
for (var i: u32 = 0u; i < uniforms.kw; i++) {
xIndices[${dimIdxW}] = indices[${dimIdxW}] * uniforms.sw - uniforms.pwStart + i;
Expand All @@ -143,7 +144,7 @@ const generatePoolingCode = <AttributeType extends AveragePoolAttributes|MaxPool

if (attributes.kernelShape.length === 2) {
const dimIdxH = rank - (isChannelsLast ? 3 : 2);
if (phStartEnd === true) {
if (phStartEndNotZero === true) {
codeH = `
for (var j: u32 = 0u; j < uniforms.kh; j++) {
xIndices[${dimIdxH}] = indices[${dimIdxH}] * uniforms.sh - uniforms.phStart + j;
Expand Down Expand Up @@ -269,7 +270,7 @@ const parsePoolCommonAttributes = (attributes: Record<string, unknown>): PoolCom
pads: attributes.pads as [number, number, number, number]
});

export interface AveragePoolAttributes extends PoolCommonAttributes {
export interface AveragePoolAttributes extends PoolCommonAttributes, AttributeWithCacheKey {
readonly countIncludePad: boolean;
}

Expand All @@ -287,15 +288,15 @@ const createAveragePoolProgramInfo =
} else {
op2 += `value /= ${dataType}(i32(uniforms.kernelSize) - pad);`;
}
const [programUniforms, uniforms, hasPads, pwStartEnd, phStartEnd] =
const [programUniforms, uniforms, hasPads, pwStartEndNotZero, phStartEndNotZero] =
getUniformAndPadInfo(outputShape, adjustedAttributes);
programUniforms.push(...createTensorShapeVariables(input.dims), ...createTensorShapeVariables(outputShape));
const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank'];
return {
name,
shaderCache: {
hint: createShaderKeyFromAttributes(attributes) +
`;${adjustedAttributes.countIncludePad};${hasPads};${pwStartEnd};${phStartEnd}`,
hint: `${attributes.cacheKey};${attributes.countIncludePad};${hasPads};${pwStartEndNotZero};${
phStartEndNotZero}`,
inputDependencies
},
getRunData: () => ({
Expand All @@ -305,7 +306,7 @@ const createAveragePoolProgramInfo =
}),
getShaderSource: shaderHelper => generatePoolingCode(
shaderHelper, x, input.dims.length, outputShape.length, adjustedAttributes, op1, op2, 0.0, uniforms,
hasPads, pwStartEnd, phStartEnd),
hasPads, pwStartEndNotZero, phStartEndNotZero),
};
};

Expand All @@ -317,7 +318,7 @@ export const parseAveragePoolAttributes = (attributes: Record<string, unknown>):
if (attr.ceilMode !== 0) {
throw new Error('using ceil() in shape computation is not yet supported for AveragePool');
}
return {countIncludePad, ...attr};
return {countIncludePad, ...attr, cacheKey: createShaderKeyFromAttributes(attr)};
};

export const averagePool = (context: ComputeContext, attributes: AveragePoolAttributes): void => {
Expand All @@ -338,15 +339,16 @@ const globalPoolAttributes = {

export const parseGlobalAveragePoolAttributes = (attributes: Record<string, unknown>): AveragePoolAttributes => {
const format = attributes.format as FormatAttributes['format'];
return {format, ...globalPoolAttributes};
const attributesWithoutCacheKey = {format, ...globalPoolAttributes};
return {...attributesWithoutCacheKey, cacheKey: createShaderKeyFromAttributes(attributesWithoutCacheKey)};
};

export const globalAveragePool = (context: ComputeContext, attributes: AveragePoolAttributes): void => {
validateInputs(context.inputs);
context.compute(createAveragePoolProgramInfo('GlobalAveragePool', context.inputs[0], true, attributes));
};

export interface MaxPoolAttributes extends PoolCommonAttributes {
export interface MaxPoolAttributes extends PoolCommonAttributes, AttributeWithCacheKey {
readonly storageOrder: number;
readonly dilations: number[];
}
Expand All @@ -361,14 +363,14 @@ const createMaxPoolProgramInfo =
const op2 = '';
const x = inputVariable('x', input.dataType, input.dims.length);
const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank'];
const [programUniforms, uniforms, hasPads, pwStartEnd, phStartEnd] =
const [programUniforms, uniforms, hasPads, pwStartEndNotZero, phStartEndNotZero] =
getUniformAndPadInfo(outputShape, adjustedAttributes);
programUniforms.push(...createTensorShapeVariables(input.dims), ...createTensorShapeVariables(outputShape));
return {
name,
shaderCache: {
hint: createShaderKeyFromAttributes(attributes) +
`;${attributes.storageOrder};${attributes.dilations}${hasPads};${pwStartEnd};${phStartEnd}`,
hint: `${attributes.cacheKey};${attributes.storageOrder};${attributes.dilations};${hasPads};${
pwStartEndNotZero};${phStartEndNotZero}`,
inputDependencies
},
getRunData: () => ({
Expand All @@ -378,7 +380,7 @@ const createMaxPoolProgramInfo =
}),
getShaderSource: shaderHelper => generatePoolingCode(
shaderHelper, x, input.dims.length, outputShape.length, adjustedAttributes, op1, op2, -1e5, uniforms,
hasPads, pwStartEnd, phStartEnd),
hasPads, pwStartEndNotZero, phStartEndNotZero),
};
};

Expand All @@ -399,12 +401,13 @@ export const parseMaxPoolAttributes = (attributes: Record<string, unknown>): Max
if (attr.ceilMode !== 0) {
throw new Error('using ceil() in shape computation is not yet supported for MaxPool');
}
return {storageOrder, dilations, ...attr};
return {storageOrder, dilations, ...attr, cacheKey: createShaderKeyFromAttributes(attr)};
};

export const parseGlobalMaxPoolAttributes = (attributes: Record<string, unknown>): MaxPoolAttributes => {
const format = attributes.format as FormatAttributes['format'];
return {format, ...globalPoolAttributes};
const attributesWithoutCacheKey = {format, ...globalPoolAttributes};
return {...attributesWithoutCacheKey, cacheKey: createShaderKeyFromAttributes(attributesWithoutCacheKey)};
};

export const globalMaxPool = (context: ComputeContext, attributes: MaxPoolAttributes): void => {
Expand Down

0 comments on commit 7318183

Please sign in to comment.