Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[js/webgpu] add ceilmode support to pool #21231

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions js/web/lib/wasm/jsep/util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -344,10 +344,11 @@ export class PoolConvUtil {
* @param pads Padding for the beginning and ending along each axis.
* @param autoPad DEPRECATED attribute supported for legacy models. Specifies how to implicitly calculate pads in each
* dimension. Can take values NOTSET, SAME_UPPER, SAME_LOWER, or VALID.
* @param ceilMode: 0=floor, 1=ceil
*/
static computePoolOutputShape(
isGlobalOperator: boolean, inputDims: readonly number[], strides: number[], dilations: number[],
kernelShape: number[], pads: number[], autoPad?: string): number[] {
kernelShape: number[], pads: number[], autoPad?: string, ceilMode?: number): number[] {
if (inputDims.length <= 0) {
throw new Error('input shape must be of size greater than 0');
}
Expand All @@ -356,7 +357,7 @@ export class PoolConvUtil {
const outputDims = [inputDims[0], inputDims[1]];

PoolConvUtil.computeShapeHelper(
isGlobalOperator, inputDims, outputDims, strides, dilations, kernelShape, pads, autoPad);
isGlobalOperator, inputDims, outputDims, strides, dilations, kernelShape, pads, autoPad, ceilMode);
return outputDims;
}

Expand Down Expand Up @@ -389,7 +390,8 @@ export class PoolConvUtil {
// adjust pads based on 'autoPad' attribute prior to shape computation
private static computeShapeHelper(
isGlobalOperator: boolean, inputDims: readonly number[], outputDims: number[], strides: readonly number[],
dilations: readonly number[], kernelShape: readonly number[], pads: number[], autoPad?: string) {
dilations: readonly number[], kernelShape: readonly number[], pads: number[], autoPad?: string,
ceilMode?: number) {
if (isGlobalOperator) {
for (let dim = 0; dim < inputDims.length - 2; dim++) {
outputDims.push(1);
Expand All @@ -398,7 +400,7 @@ export class PoolConvUtil {
for (let dim = 0; dim < inputDims.length - 2; dim++) {
outputDims.push(PoolConvUtil.adjustPadAndReturnShape(
inputDims[dim + 2], strides[dim], dilations[dim], kernelShape[dim], pads, dim, dim + inputDims.length - 2,
autoPad));
autoPad, ceilMode));
}
}
}
Expand All @@ -407,7 +409,8 @@ export class PoolConvUtil {
// adjusts pad value for given 'autoPad' string and computes output shape along a particular dimension
private static adjustPadAndReturnShape(
inSize: number, stride: number, dilation: number, kernel: number, pads: number[], padHeadIndex: number,
padTailIndex: number, autoPad?: string): number {
padTailIndex: number, autoPad?: string, ceilMode?: number): number {
const ceilFunc = (ceilMode) ? Math.ceil : Math.floor;
const dkernel = dilation * (kernel - 1) + 1;
if (autoPad && autoPad !== 'NOTSET') {
switch (autoPad) {
Expand All @@ -425,13 +428,13 @@ export class PoolConvUtil {
pads[padHeadIndex] =
(autoPad === 'SAME_LOWER') ? Math.floor((padNeeded + 1) / 2) : Math.floor(padNeeded / 2);
pads[padTailIndex] = padNeeded - pads[padHeadIndex];
return Math.floor(((inSize + padNeeded - kernel) / stride) + 1);
return ceilFunc(((inSize + padNeeded - kernel) / stride) + 1);
}
default:
throw new Error('Unsupported AutoPad type');
}
} else {
return Math.floor(((inSize + pads[padHeadIndex] + pads[padTailIndex] - dkernel) / stride) + 1);
return ceilFunc(((inSize + pads[padHeadIndex] + pads[padTailIndex] - dkernel) / stride) + 1);
}
}
}
Expand Down
12 changes: 3 additions & 9 deletions js/web/lib/wasm/jsep/webgpu/ops/pool.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ const getAdjustedPoolAttributesAndOutputShape = <AttributeType extends AveragePo
PoolConvUtil.adjustPoolAttributes(isGlobalOperator, inputShapeAsChannelFirst, kernelShape, strides, dilations, pads);

const outputShapeAsChannelFirst = PoolConvUtil.computePoolOutputShape(
isGlobalOperator, inputShapeAsChannelFirst, strides, dilations, kernelShape, pads, attributes.autoPad);
isGlobalOperator, inputShapeAsChannelFirst, strides, dilations, kernelShape, pads, attributes.autoPad,
attributes.ceilMode);

const newAttributes = Object.assign({}, attributes);
if (hasDilations) {
Expand Down Expand Up @@ -319,10 +320,6 @@ export const parseAveragePoolAttributes = (attributes: Record<string, unknown>):
const countIncludePad = (attributes.count_include_pad as number) === 0 ? false : true;

const attr = parsePoolCommonAttributes(attributes);
// TODO: support attribute 'ceil_mode'
if (attr.ceilMode !== 0) {
throw new Error('using ceil() in shape computation is not yet supported for AveragePool');
}
const averagePoolAttributes = {countIncludePad, ...attr, cacheKey: ''};
return {...averagePoolAttributes, cacheKey: createAveragePoolShaderKeyFromAttributes(averagePoolAttributes)};
};
Expand Down Expand Up @@ -397,13 +394,10 @@ export const parseMaxPoolAttributes = (attributes: Record<string, unknown>): Max
const dilations = attributes.dilations as [number, number];

const attr = parsePoolCommonAttributes(attributes);
// TODO: support attribute 'ceil_mode' and 'storage_order'
// TODO: support attribute 'storage_order'
if (storageOrder !== 0) {
throw new Error('column major storage order is not yet supported for MaxPool');
}
if (attr.ceilMode !== 0) {
throw new Error('using ceil() in shape computation is not yet supported for MaxPool');
}
const maxPoolAttributes = {storageOrder, dilations, ...attr, cacheKey: ''};
return {...maxPoolAttributes, cacheKey: createMaxPoolShaderKeyFromAttributes(maxPoolAttributes)};
};
Expand Down
2 changes: 1 addition & 1 deletion js/web/test/suite-test-list.jsonc
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@
"test_atanh_example",
"test_atanh",
// "test_averagepool_1d_default",
// "test_averagepool_2d_ceil",
"test_averagepool_2d_ceil",
"test_averagepool_2d_default",
"test_averagepool_2d_pads_count_include_pad",
"test_averagepool_2d_pads",
Expand Down
Loading