From e36770514b34cea21437f27f7cf7fc6ba6e47c85 Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Fri, 15 Sep 2023 09:38:38 -0700 Subject: [PATCH 01/18] Add ConvTranspose implementation using MatMul. --- .../ops/3rd-party/conv_backprop_mm_webgpu.ts | 176 ++++++++++++++++++ .../wasm/jsep/webgpu/ops/conv-transpose.ts | 62 +++++- .../jsep/webgpu/ops/conv2dtranspose-mm.ts | 29 +++ 3 files changed, 266 insertions(+), 1 deletion(-) create mode 100644 js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts create mode 100644 js/web/lib/wasm/jsep/webgpu/ops/conv2dtranspose-mm.ts diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts new file mode 100644 index 0000000000000..8aaec04b13617 --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts @@ -0,0 +1,176 @@ +/** + * @license + * Copyright 2021 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +// sampled from [@tensorflow/tfjs] tfjs-backend-webgpu/src/conv_backprop_mm_webgpu.ts +// +// modified to fit the needs of the project + +import {LOG_DEBUG} from '../../../log'; +import {TensorView} from '../../../tensor'; +import {ShapeUtil} from '../../../util'; +import {GpuDataType, ProgramInfo, ProgramMetadata} from '../../types'; +import {ConvTransposeAttributes} from '../conv-transpose'; + +import {typeSnippet} from './activation_util'; +import {utilFunctions} from './conv_util'; +import {makeMatMulPackedSource, makeMatMulPackedVec4Source} from './matmul_packed_webgpu'; + +const conv2dTransposeCommonSnippet = (innerElementSize = 4): string => { + const getWSnippet = (innerElementSize: number) => { + switch (innerElementSize) { + case 1: + return 'return W[getIndexFromCoords4D(coord, wShape)];'; + case 4: + return ` + let coord1 = vec4(coordX, coordY, col + 1, rowInner); + let coord2 = vec4(coordX, coordY, col + 2, rowInner); + let coord3 = vec4(coordX, coordY, col + 3, rowInner); + let v0 = W[getIndexFromCoords4D(coord, wShape)]; + let v1 = W[getIndexFromCoords4D(coord1, wShape)]; + let v2 = W[getIndexFromCoords4D(coord2, wShape)]; + let v3 = W[getIndexFromCoords4D(coord3, wShape)]; + return vec4(v0, v1, v2, v3); + `; + default: + throw new Error(`innerElementSize ${innerElementSize} is not supported.`); + } + }; + + const readASnippet = ` + let outRow = row / outShape[2]; + let outCol = row % outShape[2]; + + let WRow = col / (filterDims[1] * outBackprop[3]); + let WCol = col / outBackprop[3] % filterDims[1]; + let xR = f32(outRow - pads[0] + WRow) / f32(strides[0]); + let xC = f32(outCol - pads[1] + WCol) / f32(strides[1]); + if (xR < 0.0 || xR >= f32(outBackprop[1]) || fract(xR) > 0.0) { + return ${typeSnippet(innerElementSize)}(0.0); + } + if (xC < 0.0 || xC >= f32(outBackprop[2]) || fract(xC) > 0.0) { + return ${typeSnippet(innerElementSize)}(0.0); + } + let coord = vec4( + batch, 
+ i32(xR), + i32(xC), + col % outBackprop[3]); + return x[getIndexFromCoords4D(coord, xShape)/${innerElementSize}];`; + + const sampleA = `if (row < dimAOuter && col < dimInner) { + ${readASnippet} + } + return ${typeSnippet(innerElementSize)}(0.0);`; + + const userCode = ` + fn mm_readA(batch: i32, row : i32, col : i32) -> ${typeSnippet(innerElementSize)} { + ${sampleA} + } + + fn mm_readB(batch: i32, row : i32, col : i32) -> ${typeSnippet(innerElementSize)} { + let coordX = filterDims.x - 1 - + row / (filterDims[1] * outBackprop[3]); + let coordY = filterDims.y - 1 - + (row / outBackprop[3]) % filterDims[1]; + if (row < dimInner && col < dimBOuter && + coordX >= 0 && coordY >= 0) { + let rowInner = row % outBackprop[3]; + let coord = vec4(coordX, coordY, col, rowInner); + ${getWSnippet(innerElementSize)} + } + return ${typeSnippet(innerElementSize)}(0.0); + } + + fn mm_write(batch: i32, row : i32, col : i32, valueInput : ${typeSnippet(innerElementSize)}) { + if (row < dimAOuter && col < dimBOuter) { + var value = valueInput; + let outCoord = vec4( + batch, + row / outShape[2], + row % outShape[2], + col); + result[getIndexFromCoords4D(outCoord, outShape)/${innerElementSize}] = value; + } + }`; + return userCode; +}; + +export const createConv2DTransposeMatMulProgramInfo = + (inputs: readonly TensorView[], metadata: ProgramMetadata, attributes: ConvTransposeAttributes, + outputShape: readonly number[], dimAOuter: number, dimBOuter: number, dimInner: number, hasBias: boolean, + sequentialAccessByThreads: boolean): ProgramInfo => { + const isChannelsLast = attributes.format === 'NHWC'; + const inChannels = isChannelsLast ? inputs[0].dims[3] : inputs[0].dims[1]; + const batchSize = outputShape[0]; + const outWidth = isChannelsLast ? outputShape[2] : outputShape[3]; + const outHeight = isChannelsLast ? outputShape[1] : outputShape[2]; + const outChannels = isChannelsLast ? outputShape[3] : outputShape[1]; + const isVec4 = (((inChannels % 4 === 0 || inChannels % 3 === 0) && isChannelsLast) || + (outWidth % 4 === 0 && !isChannelsLast)) && + outChannels % 4 === 0; + + const dispatchX = !isChannelsLast ? outChannels : outWidth * outHeight; + const dispatchY = !isChannelsLast ? outWidth * outHeight : outChannels; + const workGroupSize: [number, number, number] = + isVec4 ? [8, 8, 1] : [dispatchX <= 4 ? 4 : 16, dispatchX > 4 && dispatchY <= 4 ? 4 : 16, 1]; + const elementsPerThread = + isVec4 ? [4, 4, 1] : [dispatchX <= 4 ? 1 : 2, dispatchX > 4 && dispatchY <= 4 ? 1 : 2, 1]; + const dispatch = [ + Math.ceil(dispatchX / workGroupSize[0] / elementsPerThread[0]), + Math.ceil(dispatchY / workGroupSize[1] / elementsPerThread[1]), + Math.ceil(batchSize / workGroupSize[2] / elementsPerThread[2]) + ]; + const innerElementSize = isVec4 ? 4 : 1; + const tileInner = Math.max(workGroupSize[0] * innerElementSize, workGroupSize[1]); + + LOG_DEBUG('verbose', () => `[conv2d_mm_webgpu] dispatch = ${dispatch}`); + + const declareInputs = [ + `@group(0) @binding(0) var x: array<${isVec4 ? 'vec4' : 'f32'}>;`, + `@group(0) @binding(1) var W: array<${isVec4 ? 'vec4' : 'f32'}>;` + ]; + + return { + ...metadata, + outputs: [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}], + dispatchGroup: () => ({x: dispatch[0], y: dispatch[1], z: dispatch[2]}), + getShaderSource: () => ` + ${utilFunctions} + ${declareInputs.join('')} + @group(0) @binding(${declareInputs.length}) var result: array<${ + isVec4 ? 
'vec4' : 'f32'}>; + const outBackprop : vec4 = vec4(${inputs[0].dims.join(',')}); + const xShape : vec4 = vec4(${inputs[0].dims.join(',')}); + const wShape : vec4 = vec4(${inputs[1].dims.join(',')}); + const outShape : vec4 = vec4(${outputShape.join(',')}); + const outShapeStrides : vec3 = vec3(${ShapeUtil.computeStrides(outputShape).slice(0, 3).join(',')}); + const filterDims : vec2 = vec2(${attributes.kernelShape[0]}, ${attributes.kernelShape[1]}); + const pads : vec2 = vec2(${attributes.pads[0]}, ${attributes.pads[1]}); + const strides : vec2 = vec2(${attributes.strides[0]}, ${attributes.strides[1]}); + const dilation : vec2 = vec2(${attributes.dilations[0]}, ${attributes.dilations[1]}); + const dimAOuter : i32 = ${dimAOuter}; + const dimBOuter : i32 = ${dimBOuter}; + const dimInner : i32 = ${dimInner}; + ${conv2dTransposeCommonSnippet(innerElementSize)} + ${ + isVec4 ? + makeMatMulPackedVec4Source(elementsPerThread, workGroupSize, undefined, !isChannelsLast, tileInner) : + makeMatMulPackedSource( + elementsPerThread, workGroupSize, undefined, !isChannelsLast, tileInner, false, undefined, + sequentialAccessByThreads)}` + }; + }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts index e7d1ddf771650..8a90b14fd4f91 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts @@ -8,7 +8,9 @@ import {ComputeContext, GpuDataType, ProgramInfoLoader, ProgramMetadata} from '. import {createConvTranspose2DProgramInfo} from './3rd-party/conv_backprop_webgpu'; import {ConvAttributes} from './conv'; +import {createConv2DTransposeMatMulProgramInfoLoader} from './conv2dtranspose-mm'; import {parseInternalActivationAttributes} from './fuse-utils'; +import {createTransposeProgramInfo, TransposeAttributes, transposeProgramMetadata} from './transpose'; const computeTotalPad = (inDim: number, stride: number, adj: number, kernel: number, dilation: number, outSize: number) => @@ -226,11 +228,69 @@ const createConvTranspose2DProgramInfoLoader = }; }; +// for transposing weight tensor from [M, C/group, KH, KW] to [KH, KW, C/group, M] +const weightTransposeAttribute: TransposeAttributes = createAttributeWithCacheKey({perm: [2, 3, 1, 0]}); + const convTranspose2d = (context: ComputeContext, inputs: readonly TensorView[], attributes: ConvTransposeAttributes): void => { const adjustedAttributes = getAdjustedConvTransposeAttributes(attributes, inputs); + const isChannelsLast = attributes.format === 'NHWC'; + const hasBias = inputs.length === 3; + if (adjustedAttributes.group !== 1 || hasBias) { + context.compute(createConvTranspose2DProgramInfoLoader(inputs, adjustedAttributes)); + return; + } + const outputShape = adjustedAttributes.outputShape; + const outHeight = outputShape[isChannelsLast ? 1 : 2]; + const outWidth = outputShape[isChannelsLast ? 2 : 3]; + const outChannels = outputShape[isChannelsLast ? 3 : 1]; + const weightHeight = inputs[1].dims[2]; + const weightWidth = inputs[1].dims[3]; + // const inputHeight = inputs[0].dims[isChannelsLast ? 1 : 2]; + // const inputWidth = inputs[0].dims[isChannelsLast ? 2 : 3]; + const inputChannels = inputs[0].dims[isChannelsLast ? 3 : 1]; + + + // const dimAOuter = inputHeight * inputWidth; + // const dimBOuter = inputChannels; + // const dimInner = weightHeight * weightWidth * outChannels; + + const dimAOuter = isChannelsLast ? outHeight * outWidth : outChannels; + const dimBOuter = isChannelsLast ? 
outChannels : outHeight * outWidth; + const dimInner = weightHeight * weightWidth * inputChannels; + + const sequentialAccessByThreads = /* backend.adapterInfo.isIntel() */ true; + + + // STEP.1: transpose weight + const transposedWeight = (context.kernelCustomData.wT as TensorView | undefined) ?? + context.compute( + { + ...transposeProgramMetadata, + cacheHint: weightTransposeAttribute.cacheKey, + get: () => createTransposeProgramInfo(inputs[1], weightTransposeAttribute.perm) + }, + {inputs: [1], outputs: [attributes.wIsConst ? -2 : -1]})[0]; + if (attributes.wIsConst && !context.kernelCustomData.wT) { + context.kernelCustomData.wT = transposedWeight; + } + + // STEP.2: prepare reshaped inputs + const convTransposeInputs = [inputs[0], transposedWeight]; + if (hasBias) { + if (!isChannelsLast && inputs[2].dims.length === 1) { + convTransposeInputs.push(inputs[2].reshape([inputs[2].dims[0], 1, 1])); + } else { + convTransposeInputs.push(inputs[2]); + } + } - context.compute(createConvTranspose2DProgramInfoLoader(inputs, adjustedAttributes)); + // STEP.3: compute matmul + context.compute( + createConv2DTransposeMatMulProgramInfoLoader( + convTransposeInputs, adjustedAttributes, outputShape, dimAOuter, dimBOuter, dimInner, hasBias, + sequentialAccessByThreads), + {inputs: convTransposeInputs}); }; const convTranspose1d = (context: ComputeContext, attributes: ConvTransposeAttributes): void => { // extend the input to 2D by adding H dimension diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv2dtranspose-mm.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv2dtranspose-mm.ts new file mode 100644 index 0000000000000..f793c78d9613d --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv2dtranspose-mm.ts @@ -0,0 +1,29 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {TensorView} from '../../tensor'; +import {GpuDataType, ProgramInfoLoader, ProgramMetadata} from '../types'; + +import {createConv2DTransposeMatMulProgramInfo} from './3rd-party/conv_backprop_mm_webgpu'; +import {ConvTransposeAttributes} from './conv-transpose'; + + +const createConv2DTransposeMatMulProgramMetadata = (hasBias: boolean, cacheHint: string): ProgramMetadata => ({ + name: 'Conv2DTransposeMatMul', + inputTypes: hasBias ? [GpuDataType.default, GpuDataType.default, GpuDataType.default] : + [GpuDataType.default, GpuDataType.default], + cacheHint +}); + +export const createConv2DTransposeMatMulProgramInfoLoader = + (inputs: readonly TensorView[], attributes: ConvTransposeAttributes, outputShape: readonly number[], + dimAOuter: number, dimBOuter: number, dimInner: number, hasBias: boolean, + sequentialAccessByThreads: boolean): ProgramInfoLoader => { + const metadata = createConv2DTransposeMatMulProgramMetadata(hasBias, attributes.cacheKey); + return { + ...metadata, + get: () => createConv2DTransposeMatMulProgramInfo( + inputs, metadata, attributes, outputShape, dimAOuter, dimBOuter, dimInner, hasBias, + sequentialAccessByThreads) + }; + }; From dfeb109e51c1df3a1736856e3acae258e7510b57 Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Fri, 15 Sep 2023 11:17:20 -0700 Subject: [PATCH 02/18] Merged changes from main. 
--- .../wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts | 2 +- js/web/lib/wasm/jsep/webgpu/ops/conv2dtranspose-mm.ts | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts index 8aaec04b13617..abdbe8aeaaa94 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts @@ -20,7 +20,7 @@ // modified to fit the needs of the project import {LOG_DEBUG} from '../../../log'; -import {TensorView} from '../../../tensor'; +import {TensorView} from '../../../tensor-view'; import {ShapeUtil} from '../../../util'; import {GpuDataType, ProgramInfo, ProgramMetadata} from '../../types'; import {ConvTransposeAttributes} from '../conv-transpose'; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv2dtranspose-mm.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv2dtranspose-mm.ts index f793c78d9613d..da04b5063a9f0 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv2dtranspose-mm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv2dtranspose-mm.ts @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {TensorView} from '../../tensor'; +import {TensorView} from '../../tensor-view'; import {GpuDataType, ProgramInfoLoader, ProgramMetadata} from '../types'; import {createConv2DTransposeMatMulProgramInfo} from './3rd-party/conv_backprop_mm_webgpu'; From fdd29bc6bee6f23d8bd0b2f6868cb33d5b956f55 Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Mon, 18 Sep 2023 17:54:40 -0700 Subject: [PATCH 03/18] Hardcode dispatch. --- .../jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts index abdbe8aeaaa94..f96b4ccd18ceb 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts @@ -123,10 +123,9 @@ export const createConv2DTransposeMatMulProgramInfo = (outWidth % 4 === 0 && !isChannelsLast)) && outChannels % 4 === 0; - const dispatchX = !isChannelsLast ? outChannels : outWidth * outHeight; - const dispatchY = !isChannelsLast ? outWidth * outHeight : outChannels; - const workGroupSize: [number, number, number] = - isVec4 ? [8, 8, 1] : [dispatchX <= 4 ? 4 : 16, dispatchX > 4 && dispatchY <= 4 ? 4 : 16, 1]; + const dispatchX = isChannelsLast ? outChannels : outWidth * outHeight; + const dispatchY = isChannelsLast ? outWidth * outHeight : outChannels; + const workGroupSize: [number, number, number] = isVec4 ? [8, 8, 1] : [4, 4, 1]; const elementsPerThread = isVec4 ? [4, 4, 1] : [dispatchX <= 4 ? 1 : 2, dispatchX > 4 && dispatchY <= 4 ? 1 : 2, 1]; const dispatch = [ From 8ed9038b7c3dcdaac1453136f8c87b71e781fd84 Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Fri, 15 Sep 2023 10:58:29 -0700 Subject: [PATCH 04/18] Added a new test case to exercise vec4 version. 
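The new case keeps both the input and output channel counts at 4, so on the NHWC path the isVec4 check in conv_backprop_mm_webgpu.ts is satisfied and the vec4 kernel is exercised. With a 1x1 kernel the expected values reduce to a per-pixel matrix multiply plus bias; the following sketch (illustration only, not part of the patch) reproduces the first expected output value:

    // With a 1x1 kernel: output[n, m, y, x] = sum_c X[n, c, y, x] * W[c, m, 0, 0] + B[m]
    const X = Array.from({length: 64}, (_, i) => i + 1);  // input, dims [1, 4, 4, 4]
    const W = Array.from({length: 16}, (_, i) => i + 1);  // weight, dims [4, 4, 1, 1]
    const B = [1, 2, 3, 4];                               // bias, dims [4]
    const out = (m: number, pixel: number): number => {
      let acc = B[m];
      for (let c = 0; c < 4; ++c) {
        acc += X[c * 16 + pixel] * W[c * 4 + m];
      }
      return acc;
    };
    console.log(out(0, 0));  // 1021, the first value in the expected output
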
--- js/web/test/data/ops/conv-transpose.jsonc | 43 +++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/js/web/test/data/ops/conv-transpose.jsonc b/js/web/test/data/ops/conv-transpose.jsonc index a249dc807fa0b..9079e466be400 100644 --- a/js/web/test/data/ops/conv-transpose.jsonc +++ b/js/web/test/data/ops/conv-transpose.jsonc @@ -285,5 +285,48 @@ ] } ] + }, + { + "name": "ConvTranspose with bias addition C", + "operator": "ConvTranspose", + "attributes": [{ "name": "kernel_shape", "data": [1, 1], "type": "ints" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64 + ], + "dims": [1, 4, 4, 4], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], + "dims": [4, 4, 1, 1], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], + "dims": [4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 1021, 1049, 1077, 1105, 1133, 1161, 1189, 1217, 1245, 1273, 1301, 1329, 1357, 1385, 1413, 1441, 1122, + 1154, 1186, 1218, 1250, 1282, 1314, 1346, 1378, 1410, 1442, 1474, 1506, 1538, 1570, 1602, 1223, 1259, + 1295, 1331, 1367, 1403, 1439, 1475, 1511, 1547, 1583, 1619, 1655, 1691, 1727, 1763, 1324, 1364, 1404, + 1444, 1484, 1524, 1564, 1604, 1644, 1684, 1724, 1764, 1804, 1844, 1884, 1924 + ], + "dims": [1, 4, 4, 4], + "type": "float32" + } + ] + } + ] } ] From 6794407eaf456ecb6011ae6bef719c2d7ead8570 Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Tue, 19 Sep 2023 22:46:24 -0700 Subject: [PATCH 05/18] Fixed filter setting --- .../ops/3rd-party/conv_backprop_mm_webgpu.ts | 72 +++++++++++-------- .../wasm/jsep/webgpu/ops/conv-transpose.ts | 4 +- js/web/test/data/ops/conv-transpose.jsonc | 2 + 3 files changed, 48 insertions(+), 30 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts index f96b4ccd18ceb..1bec542521202 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts @@ -25,17 +25,18 @@ import {ShapeUtil} from '../../../util'; import {GpuDataType, ProgramInfo, ProgramMetadata} from '../../types'; import {ConvTransposeAttributes} from '../conv-transpose'; -import {typeSnippet} from './activation_util'; +import {Activation, activationFnSnippet, biasActivationSnippet, typeSnippet} from './activation_util'; import {utilFunctions} from './conv_util'; import {makeMatMulPackedSource, makeMatMulPackedVec4Source} from './matmul_packed_webgpu'; -const conv2dTransposeCommonSnippet = (innerElementSize = 4): string => { - const getWSnippet = (innerElementSize: number) => { - switch (innerElementSize) { - case 1: - return 'return W[getIndexFromCoords4D(coord, wShape)];'; - case 4: - return ` +const conv2dTransposeCommonSnippet = + (addBias = false, activation?: Activation, hasPreluActivationWeights = false, innerElementSize = 4): string => { + const getWSnippet = (innerElementSize: number) => { + switch (innerElementSize) { + case 1: + return 'return W[getIndexFromCoords4D(coord, wShape)];'; + case 4: + return ` let coord1 = vec4(coordX, coordY, col + 1, rowInner); let coord2 = vec4(coordX, coordY, col + 2, rowInner); let coord3 = vec4(coordX, coordY, col + 3, rowInner); @@ 
-45,12 +46,12 @@ const conv2dTransposeCommonSnippet = (innerElementSize = 4): string => { let v3 = W[getIndexFromCoords4D(coord3, wShape)]; return vec4(v0, v1, v2, v3); `; - default: - throw new Error(`innerElementSize ${innerElementSize} is not supported.`); - } - }; + default: + throw new Error(`innerElementSize ${innerElementSize} is not supported.`); + } + }; - const readASnippet = ` + const readASnippet = ` let outRow = row / outShape[2]; let outCol = row % outShape[2]; @@ -71,12 +72,13 @@ const conv2dTransposeCommonSnippet = (innerElementSize = 4): string => { col % outBackprop[3]); return x[getIndexFromCoords4D(coord, xShape)/${innerElementSize}];`; - const sampleA = `if (row < dimAOuter && col < dimInner) { + const sampleA = `if (row < dimAOuter && col < dimInner) { ${readASnippet} } return ${typeSnippet(innerElementSize)}(0.0);`; - const userCode = ` + const userCode = ` + ${activationFnSnippet(activation, hasPreluActivationWeights, innerElementSize === 4, 4)} fn mm_readA(batch: i32, row : i32, col : i32) -> ${typeSnippet(innerElementSize)} { ${sampleA} } @@ -98,16 +100,17 @@ const conv2dTransposeCommonSnippet = (innerElementSize = 4): string => { fn mm_write(batch: i32, row : i32, col : i32, valueInput : ${typeSnippet(innerElementSize)}) { if (row < dimAOuter && col < dimBOuter) { var value = valueInput; - let outCoord = vec4( + let coords = vec4( batch, row / outShape[2], row % outShape[2], col); - result[getIndexFromCoords4D(outCoord, outShape)/${innerElementSize}] = value; + ${biasActivationSnippet(addBias, activation)} + result[getIndexFromCoords4D(coords, outShape)/${innerElementSize}] = value; } }`; - return userCode; -}; + return userCode; + }; export const createConv2DTransposeMatMulProgramInfo = (inputs: readonly TensorView[], metadata: ProgramMetadata, attributes: ConvTransposeAttributes, @@ -125,7 +128,8 @@ export const createConv2DTransposeMatMulProgramInfo = const dispatchX = isChannelsLast ? outChannels : outWidth * outHeight; const dispatchY = isChannelsLast ? outWidth * outHeight : outChannels; - const workGroupSize: [number, number, number] = isVec4 ? [8, 8, 1] : [4, 4, 1]; + const workGroupSize: [number, number, number] = + isVec4 ? [8, 8, 1] : [dispatchX <= 4 ? 4 : 16, dispatchX > 4 && dispatchY <= 4 ? 4 : 16, 1]; const elementsPerThread = isVec4 ? [4, 4, 1] : [dispatchX <= 4 ? 1 : 2, dispatchX > 4 && dispatchY <= 4 ? 1 : 2, 1]; const dispatch = [ @@ -133,23 +137,32 @@ export const createConv2DTransposeMatMulProgramInfo = Math.ceil(dispatchY / workGroupSize[1] / elementsPerThread[1]), Math.ceil(batchSize / workGroupSize[2] / elementsPerThread[2]) ]; - const innerElementSize = isVec4 ? 4 : 1; + + LOG_DEBUG('verbose', () => `[conv_backprop_mm_webgpu] dispatch = ${dispatch}`); + + const innerElementSize = isVec4 ? (inChannels % 4 !== 0 ? 3 : 4) : 1; const tileInner = Math.max(workGroupSize[0] * innerElementSize, workGroupSize[1]); - LOG_DEBUG('verbose', () => `[conv2d_mm_webgpu] dispatch = ${dispatch}`); const declareInputs = [ `@group(0) @binding(0) var x: array<${isVec4 ? 'vec4' : 'f32'}>;`, - `@group(0) @binding(1) var W: array<${isVec4 ? 'vec4' : 'f32'}>;` + '@group(0) @binding(1) var W: array;' ]; - + let declareFunctions = ''; + if (hasBias) { + declareInputs.push(`@group(0) @binding(2) var bias: array<${isVec4 ? 'vec4' : 'f32'}>;`); + declareFunctions += ` + fn getBiasByOutputCoords(coords : vec4) -> ${isVec4 ? 'vec4' : 'f32'} { + return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? 
'/ 4' : ''}]; + }`; + } return { ...metadata, outputs: [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}], dispatchGroup: () => ({x: dispatch[0], y: dispatch[1], z: dispatch[2]}), getShaderSource: () => ` ${utilFunctions} - ${declareInputs.join('')} + ${declareInputs.join('\n')} @group(0) @binding(${declareInputs.length}) var result: array<${ isVec4 ? 'vec4' : 'f32'}>; const outBackprop : vec4 = vec4(${inputs[0].dims.join(',')}); @@ -157,14 +170,17 @@ export const createConv2DTransposeMatMulProgramInfo = const wShape : vec4 = vec4(${inputs[1].dims.join(',')}); const outShape : vec4 = vec4(${outputShape.join(',')}); const outShapeStrides : vec3 = vec3(${ShapeUtil.computeStrides(outputShape).slice(0, 3).join(',')}); - const filterDims : vec2 = vec2(${attributes.kernelShape[0]}, ${attributes.kernelShape[1]}); - const pads : vec2 = vec2(${attributes.pads[0]}, ${attributes.pads[1]}); + const filterDims : vec2 = vec2(${attributes.kernelShape[isChannelsLast ? 1 : 2]}, ${ + attributes.kernelShape[isChannelsLast ? 2 : 3]}); + const pads : vec2 = vec2(i32(filterDims[0]) - 1 - (${attributes.pads[0] + attributes.pads[2]})/2, + i32(filterDims[1]) - 1 - (${attributes.pads[1] + attributes.pads[3]})/2); const strides : vec2 = vec2(${attributes.strides[0]}, ${attributes.strides[1]}); const dilation : vec2 = vec2(${attributes.dilations[0]}, ${attributes.dilations[1]}); const dimAOuter : i32 = ${dimAOuter}; const dimBOuter : i32 = ${dimBOuter}; const dimInner : i32 = ${dimInner}; - ${conv2dTransposeCommonSnippet(innerElementSize)} + ${declareFunctions} + ${conv2dTransposeCommonSnippet(hasBias, undefined, false, innerElementSize)} ${ isVec4 ? makeMatMulPackedVec4Source(elementsPerThread, workGroupSize, undefined, !isChannelsLast, tileInner) : diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts index 8a90b14fd4f91..40ca2fbd1430d 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts @@ -65,7 +65,7 @@ const getAdjustedConvTransposeAttributes = (attributes: T, inputs: readonly TensorView[]): T => { const kernelShape = attributes.kernelShape.slice(); // if kernelShape is not specified in the attributes of this op, infer it from the weight tensor dims - if (attributes.kernelShape.length === 0 || attributes.kernelShape.reduce((a, b) => a * b, 0) === 0) { + if (attributes.kernelShape.length === 0 || attributes.kernelShape.reduce((a, b) => a * b, 1) === 0) { kernelShape.length = 0; for (let i = 2; i < inputs[1].dims.length; ++i) { kernelShape.push(inputs[1].dims[i]); @@ -236,7 +236,7 @@ const convTranspose2d = const adjustedAttributes = getAdjustedConvTransposeAttributes(attributes, inputs); const isChannelsLast = attributes.format === 'NHWC'; const hasBias = inputs.length === 3; - if (adjustedAttributes.group !== 1 || hasBias) { + if (adjustedAttributes.group !== 1 || !isChannelsLast) { context.compute(createConvTranspose2DProgramInfoLoader(inputs, adjustedAttributes)); return; } diff --git a/js/web/test/data/ops/conv-transpose.jsonc b/js/web/test/data/ops/conv-transpose.jsonc index 9079e466be400..9aa5d802ac10f 100644 --- a/js/web/test/data/ops/conv-transpose.jsonc +++ b/js/web/test/data/ops/conv-transpose.jsonc @@ -289,6 +289,8 @@ { "name": "ConvTranspose with bias addition C", "operator": "ConvTranspose", + "inputShapeDefinitions": "rankOnly", + "opset": { "domain": "", "version": 17 }, "attributes": [{ "name": "kernel_shape", "data": [1, 1], "type": 
"ints" }], "cases": [ { From 33ecf4c2cd80b3190544a6babddb9b14383bc043 Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Wed, 20 Sep 2023 14:15:06 -0700 Subject: [PATCH 06/18] Added dilation support for ConvTranspose matmul implementation. --- .../wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts index 1bec542521202..ce1724a47101d 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts @@ -57,8 +57,8 @@ const conv2dTransposeCommonSnippet = let WRow = col / (filterDims[1] * outBackprop[3]); let WCol = col / outBackprop[3] % filterDims[1]; - let xR = f32(outRow - pads[0] + WRow) / f32(strides[0]); - let xC = f32(outCol - pads[1] + WCol) / f32(strides[1]); + let xR = f32(outRow - pads[0] + dilation[0] * WRow) / f32(strides[0]); + let xC = f32(outCol - pads[1] + dilation[1] * WCol) / f32(strides[1]); if (xR < 0.0 || xR >= f32(outBackprop[1]) || fract(xR) > 0.0) { return ${typeSnippet(innerElementSize)}(0.0); } From 01c644f6f8bcfcc9539d318ecb96d7723cdfb412 Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Wed, 20 Sep 2023 15:05:33 -0700 Subject: [PATCH 07/18] Added dilation fix. --- .../ops/3rd-party/conv_backprop_mm_webgpu.ts | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts index ce1724a47101d..b78eb8c9e2ca4 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts @@ -172,8 +172,19 @@ export const createConv2DTransposeMatMulProgramInfo = const outShapeStrides : vec3 = vec3(${ShapeUtil.computeStrides(outputShape).slice(0, 3).join(',')}); const filterDims : vec2 = vec2(${attributes.kernelShape[isChannelsLast ? 1 : 2]}, ${ attributes.kernelShape[isChannelsLast ? 2 : 3]}); - const pads : vec2 = vec2(i32(filterDims[0]) - 1 - (${attributes.pads[0] + attributes.pads[2]})/2, - i32(filterDims[1]) - 1 - (${attributes.pads[1] + attributes.pads[3]})/2); + const effectiveFilterDims : vec2 = filterDims + vec2( + ${ + attributes.dilations[0] <= 1 ? + 0 : + (attributes.kernelShape[isChannelsLast ? 1 : 2] - 1) * (attributes.dilations[0] - 1)}, + ${ + attributes.dilations[1] <= 1 ? + 0 : + (attributes.kernelShape[isChannelsLast ? 
2 : 3] - 1) * (attributes.dilations[1] - 1)}); + const pads : vec2 = vec2(i32(effectiveFilterDims[0]) - 1 - (${ + attributes.pads[0] + attributes.pads[2]})/2, + i32(effectiveFilterDims[1]) - 1 - (${ + attributes.pads[1] + attributes.pads[3]})/2); const strides : vec2 = vec2(${attributes.strides[0]}, ${attributes.strides[1]}); const dilation : vec2 = vec2(${attributes.dilations[0]}, ${attributes.dilations[1]}); const dimAOuter : i32 = ${dimAOuter}; From cbf19076ff195543dded5c0d2c05bdef9bebdd79 Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Wed, 20 Sep 2023 16:24:57 -0700 Subject: [PATCH 08/18] Formatted --- js/web/test/data/ops/conv-transpose.jsonc | 1 - onnxruntime/core/providers/cpu/nn/conv_transpose.h | 1 + 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/js/web/test/data/ops/conv-transpose.jsonc b/js/web/test/data/ops/conv-transpose.jsonc index 9aa5d802ac10f..2ec2b55fccb3b 100644 --- a/js/web/test/data/ops/conv-transpose.jsonc +++ b/js/web/test/data/ops/conv-transpose.jsonc @@ -251,7 +251,6 @@ } ] }, - { "name": "ConvTranspose- pointwise", "operator": "ConvTranspose", diff --git a/onnxruntime/core/providers/cpu/nn/conv_transpose.h b/onnxruntime/core/providers/cpu/nn/conv_transpose.h index c82cd5ad49d7e..511def01be9a7 100644 --- a/onnxruntime/core/providers/cpu/nn/conv_transpose.h +++ b/onnxruntime/core/providers/cpu/nn/conv_transpose.h @@ -41,6 +41,7 @@ class ConvTranspose : public OpKernel { Status DoConvTranspose(OpKernelContext* context, bool dynamic_padding) const; private: + ConvTransposeAttributes conv_transpose_attrs_; // for pre-packing usage From 2a1d2367d497324d27bd41bb0deb93e10dc45edb Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Wed, 20 Sep 2023 16:40:49 -0700 Subject: [PATCH 09/18] Fix format. --- onnxruntime/core/providers/cpu/nn/conv_transpose.h | 1 - 1 file changed, 1 deletion(-) diff --git a/onnxruntime/core/providers/cpu/nn/conv_transpose.h b/onnxruntime/core/providers/cpu/nn/conv_transpose.h index 511def01be9a7..c82cd5ad49d7e 100644 --- a/onnxruntime/core/providers/cpu/nn/conv_transpose.h +++ b/onnxruntime/core/providers/cpu/nn/conv_transpose.h @@ -41,7 +41,6 @@ class ConvTranspose : public OpKernel { Status DoConvTranspose(OpKernelContext* context, bool dynamic_padding) const; private: - ConvTransposeAttributes conv_transpose_attrs_; // for pre-packing usage From 36689641841d9efedc134d323c5ff3770814a2f8 Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Wed, 20 Sep 2023 22:10:16 -0700 Subject: [PATCH 10/18] Minor change --- .../wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts index b78eb8c9e2ca4..38b4e841209da 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts @@ -122,7 +122,7 @@ export const createConv2DTransposeMatMulProgramInfo = const outWidth = isChannelsLast ? outputShape[2] : outputShape[3]; const outHeight = isChannelsLast ? outputShape[1] : outputShape[2]; const outChannels = isChannelsLast ? 
outputShape[3] : outputShape[1]; - const isVec4 = (((inChannels % 4 === 0 || inChannels % 3 === 0) && isChannelsLast) || + const isVec4 = ((inChannels % 4 === 0 && isChannelsLast) || (outWidth % 4 === 0 && !isChannelsLast)) && outChannels % 4 === 0; From fd374e9c7e54ed3affbacb95f75778da2c492627 Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Wed, 20 Sep 2023 23:06:53 -0700 Subject: [PATCH 11/18] Minor optimization changes. --- .../wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts | 4 ++-- js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts index ec6df438129fb..4c8922238ac5b 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts @@ -197,14 +197,14 @@ const createConvTranspose2DOpProgramShaderSource = continue; } let idyC: u32 = u32(dyC); - + var inputChannel = groupId * ${inputChannelsPerGroup}; for (var d2: u32 = 0; d2 < ${inputChannelsPerGroup}; d2 = d2 + 1) { - let inputChannel = groupId * ${inputChannelsPerGroup} + d2; let xValue = ${ isChannelsLast ? dy.get('batch', 'idyR', 'idyC', 'inputChannel') : dy.get('batch', 'inputChannel', 'idyR', 'idyC')}; let wValue = ${w.get('inputChannel', 'wOutChannel', 'u32(wRPerm)', 'u32(wCPerm)')}; dotProd = dotProd + xValue * wValue; + inputChannel = inputChannel + 1; } } } diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts index 40ca2fbd1430d..bbb1e2b610d7d 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts @@ -292,6 +292,7 @@ const convTranspose2d = sequentialAccessByThreads), {inputs: convTransposeInputs}); }; + const convTranspose1d = (context: ComputeContext, attributes: ConvTransposeAttributes): void => { // extend the input to 2D by adding H dimension const isChannelLast = attributes.format === 'NHWC'; From 4cd2334f7a3046d40031ca81ba7b79a12e0b4de9 Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Wed, 20 Sep 2023 23:12:35 -0700 Subject: [PATCH 12/18] modified logic to increase readability. --- .../jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts index 38b4e841209da..f3f4151ce06ba 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts @@ -122,9 +122,9 @@ export const createConv2DTransposeMatMulProgramInfo = const outWidth = isChannelsLast ? outputShape[2] : outputShape[3]; const outHeight = isChannelsLast ? outputShape[1] : outputShape[2]; const outChannels = isChannelsLast ? outputShape[3] : outputShape[1]; - const isVec4 = ((inChannels % 4 === 0 && isChannelsLast) || - (outWidth % 4 === 0 && !isChannelsLast)) && - outChannels % 4 === 0; + const isVec4 = + ((inChannels % 4 === 0 && outChannels % 4 === 0 && isChannelsLast) || + (outWidth % 4 === 0 && !isChannelsLast)); const dispatchX = isChannelsLast ? outChannels : outWidth * outHeight; const dispatchY = isChannelsLast ? 
outWidth * outHeight : outChannels; From 1c32af2fdebf1ce2a449820e4719e898ab29824d Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Fri, 22 Sep 2023 18:18:19 -0700 Subject: [PATCH 13/18] Add ConvTranspose MatMul support NHWC --- .../ops/3rd-party/conv_backprop_mm_webgpu.ts | 131 ++++++++++++------ .../wasm/jsep/webgpu/ops/conv-transpose.ts | 9 +- js/web/test/data/ops/conv-transpose.jsonc | 85 ++++++++++-- .../providers/js/operators/conv_transpose.h | 17 +++ 4 files changed, 177 insertions(+), 65 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts index f3f4151ce06ba..e68ea844b8fe8 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts @@ -30,7 +30,8 @@ import {utilFunctions} from './conv_util'; import {makeMatMulPackedSource, makeMatMulPackedVec4Source} from './matmul_packed_webgpu'; const conv2dTransposeCommonSnippet = - (addBias = false, activation?: Activation, hasPreluActivationWeights = false, innerElementSize = 4): string => { + (isChannelsLast: boolean, addBias = false, activation?: Activation, hasPreluActivationWeights = false, + innerElementSize = 4): string => { const getWSnippet = (innerElementSize: number) => { switch (innerElementSize) { case 1: @@ -50,63 +51,100 @@ const conv2dTransposeCommonSnippet = throw new Error(`innerElementSize ${innerElementSize} is not supported.`); } }; - + const coordASnippet = isChannelsLast ? ` + let coord = vec4(batch, iXR, iXC, xCh); + ` : + ` + let coord = vec4(batch, xCh, iXR, iXC); + `; + + const coordResSnippet = isChannelsLast ? ` + let coords = vec4( + batch, + row / outWidth, + row % outWidth, + col); + ` : + ` + let coords = vec4( + batch, + row, + col / outWidth, + col % outWidth); + `; + + const xHeight = isChannelsLast ? 'outBackprop[1]' : 'outBackprop[2]'; + const xWidth = isChannelsLast ? 'outBackprop[2]' : 'outBackprop[3]'; + const row = isChannelsLast ? 'row' : 'col'; + const col = isChannelsLast ? 'col' : 'row'; const readASnippet = ` - let outRow = row / outShape[2]; - let outCol = row % outShape[2]; - - let WRow = col / (filterDims[1] * outBackprop[3]); - let WCol = col / outBackprop[3] % filterDims[1]; - let xR = f32(outRow - pads[0] + dilation[0] * WRow) / f32(strides[0]); - let xC = f32(outCol - pads[1] + dilation[1] * WCol) / f32(strides[1]); - if (xR < 0.0 || xR >= f32(outBackprop[1]) || fract(xR) > 0.0) { + let inChannels = wShape[2]; + let outWidth = ${isChannelsLast ? 
'outShape[2]' : 'outShape[3]'}; + let outRow = ${row} / outWidth; + let outCol = ${row} % outWidth; + + let WRow = ${col} / (filterDims[1] * inChannels); + let WCol = ${col} / inChannels % filterDims[1]; + let xR = f32(outRow / strides[0] - pads[0] + dilation[0] * WRow); + let xC = f32(outCol / strides[1] - pads[1] + dilation[1] * WCol); + if (xR < 0.0 || xR >= f32(${xHeight}) || fract(xR) > 0.0) { return ${typeSnippet(innerElementSize)}(0.0); } - if (xC < 0.0 || xC >= f32(outBackprop[2]) || fract(xC) > 0.0) { + if (xC < 0.0 || xC >= f32(${xWidth}) || fract(xC) > 0.0) { return ${typeSnippet(innerElementSize)}(0.0); } - let coord = vec4( - batch, - i32(xR), - i32(xC), - col % outBackprop[3]); + let iXR = i32(xR); + let iXC = i32(xC); + let xCh = ${col} % inChannels; + ${coordASnippet} return x[getIndexFromCoords4D(coord, xShape)/${innerElementSize}];`; - const sampleA = `if (row < dimAOuter && col < dimInner) { + const sampleA = isChannelsLast ? `let col = colIn * ${innerElementSize}; + if (row < dimAOuter && col < dimInner) { + ${readASnippet} + } + return ${typeSnippet(innerElementSize)}(0.0);` : + `let col = colIn * ${innerElementSize}; + if (row < dimInner && col < dimBOuter) { + let col = colIn * ${innerElementSize}; ${readASnippet} } return ${typeSnippet(innerElementSize)}(0.0);`; + const sampleW = ` + let col = colIn; + let inChannels = wShape[2]; + let coordX = filterDims.x - 1 - row / (filterDims[1] * inChannels); + let coordY = filterDims.y - 1 - (row / inChannels) % filterDims[1]; + if (${ + isChannelsLast ? 'row < dimInner && col < dimBOuter' : + 'row < dimInner && col < dimAOuter'} && coordX >= 0 && coordY >= 0) { + let rowInner = row % inChannels; + let coord = vec4(coordX, coordY, col, rowInner); + ${getWSnippet(innerElementSize)} + } + return ${typeSnippet(innerElementSize)}(0.0); + `; + + const userCode = ` ${activationFnSnippet(activation, hasPreluActivationWeights, innerElementSize === 4, 4)} - fn mm_readA(batch: i32, row : i32, col : i32) -> ${typeSnippet(innerElementSize)} { - ${sampleA} + fn mm_readA(batch: i32, row : i32, colIn : i32) -> ${typeSnippet(innerElementSize)} { + ${isChannelsLast ? sampleA : sampleW} } - fn mm_readB(batch: i32, row : i32, col : i32) -> ${typeSnippet(innerElementSize)} { - let coordX = filterDims.x - 1 - - row / (filterDims[1] * outBackprop[3]); - let coordY = filterDims.y - 1 - - (row / outBackprop[3]) % filterDims[1]; - if (row < dimInner && col < dimBOuter && - coordX >= 0 && coordY >= 0) { - let rowInner = row % outBackprop[3]; - let coord = vec4(coordX, coordY, col, rowInner); - ${getWSnippet(innerElementSize)} - } - return ${typeSnippet(innerElementSize)}(0.0); + fn mm_readB(batch: i32, row : i32, colIn : i32) -> ${typeSnippet(innerElementSize)} { + ${isChannelsLast ? sampleW : sampleA} } - fn mm_write(batch: i32, row : i32, col : i32, valueInput : ${typeSnippet(innerElementSize)}) { + fn mm_write(batch: i32, row : i32, colIn : i32, valueInput : ${typeSnippet(innerElementSize)}) { + let col = colIn * ${innerElementSize}; if (row < dimAOuter && col < dimBOuter) { var value = valueInput; - let coords = vec4( - batch, - row / outShape[2], - row % outShape[2], - col); - ${biasActivationSnippet(addBias, activation)} - result[getIndexFromCoords4D(coords, outShape)/${innerElementSize}] = value; + let outWidth = ${isChannelsLast ? 
'outShape[2]' : 'outShape[3]'}; + ${coordResSnippet} + ${biasActivationSnippet(addBias, activation)} + result[getIndexFromCoords4D(coords, outShape)/${innerElementSize}] = value; } }`; return userCode; @@ -123,15 +161,16 @@ export const createConv2DTransposeMatMulProgramInfo = const outHeight = isChannelsLast ? outputShape[1] : outputShape[2]; const outChannels = isChannelsLast ? outputShape[3] : outputShape[1]; const isVec4 = - ((inChannels % 4 === 0 && outChannels % 4 === 0 && isChannelsLast) || - (outWidth % 4 === 0 && !isChannelsLast)); + isChannelsLast ? inChannels % 4 === 0 && outChannels % 4 === 0 : outWidth % 4 === 0 && outChannels % 4 === 0; + // TODO: fine tune size const dispatchX = isChannelsLast ? outChannels : outWidth * outHeight; const dispatchY = isChannelsLast ? outWidth * outHeight : outChannels; - const workGroupSize: [number, number, number] = - isVec4 ? [8, 8, 1] : [dispatchX <= 4 ? 4 : 16, dispatchX > 4 && dispatchY <= 4 ? 4 : 16, 1]; + const workGroupSize: [number, number, number] = isVec4 ? + [8, 8, 1] : + [(dispatchX <= 4 || dispatchY <= 4) ? 4 : 16, dispatchX > 4 && dispatchY <= 4 ? 4 : 16, 1]; const elementsPerThread = - isVec4 ? [4, 4, 1] : [dispatchX <= 4 ? 1 : 2, dispatchX > 4 && dispatchY <= 4 ? 1 : 2, 1]; + isVec4 ? [4, 4, 1] : [dispatchX <= 4 ? 1 : 4, dispatchX > 4 && dispatchY <= 4 ? 1 : 4, 1]; const dispatch = [ Math.ceil(dispatchX / workGroupSize[0] / elementsPerThread[0]), Math.ceil(dispatchY / workGroupSize[1] / elementsPerThread[1]), @@ -140,7 +179,7 @@ export const createConv2DTransposeMatMulProgramInfo = LOG_DEBUG('verbose', () => `[conv_backprop_mm_webgpu] dispatch = ${dispatch}`); - const innerElementSize = isVec4 ? (inChannels % 4 !== 0 ? 3 : 4) : 1; + const innerElementSize = isVec4 ? 4 : 1; const tileInner = Math.max(workGroupSize[0] * innerElementSize, workGroupSize[1]); @@ -191,7 +230,7 @@ export const createConv2DTransposeMatMulProgramInfo = const dimBOuter : i32 = ${dimBOuter}; const dimInner : i32 = ${dimInner}; ${declareFunctions} - ${conv2dTransposeCommonSnippet(hasBias, undefined, false, innerElementSize)} + ${conv2dTransposeCommonSnippet(isChannelsLast, hasBias, undefined, false, innerElementSize)} ${ isVec4 ? makeMatMulPackedVec4Source(elementsPerThread, workGroupSize, undefined, !isChannelsLast, tileInner) : diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts index bbb1e2b610d7d..c993102ae1596 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts @@ -236,7 +236,7 @@ const convTranspose2d = const adjustedAttributes = getAdjustedConvTransposeAttributes(attributes, inputs); const isChannelsLast = attributes.format === 'NHWC'; const hasBias = inputs.length === 3; - if (adjustedAttributes.group !== 1 || !isChannelsLast) { + if (adjustedAttributes.group !== 1) { context.compute(createConvTranspose2DProgramInfoLoader(inputs, adjustedAttributes)); return; } @@ -246,15 +246,8 @@ const convTranspose2d = const outChannels = outputShape[isChannelsLast ? 3 : 1]; const weightHeight = inputs[1].dims[2]; const weightWidth = inputs[1].dims[3]; - // const inputHeight = inputs[0].dims[isChannelsLast ? 1 : 2]; - // const inputWidth = inputs[0].dims[isChannelsLast ? 2 : 3]; const inputChannels = inputs[0].dims[isChannelsLast ? 3 : 1]; - - // const dimAOuter = inputHeight * inputWidth; - // const dimBOuter = inputChannels; - // const dimInner = weightHeight * weightWidth * outChannels; - const dimAOuter = isChannelsLast ? 
outHeight * outWidth : outChannels; const dimBOuter = isChannelsLast ? outChannels : outHeight * outWidth; const dimInner = weightHeight * weightWidth * inputChannels; diff --git a/js/web/test/data/ops/conv-transpose.jsonc b/js/web/test/data/ops/conv-transpose.jsonc index 2ec2b55fccb3b..7038e2a4f8766 100644 --- a/js/web/test/data/ops/conv-transpose.jsonc +++ b/js/web/test/data/ops/conv-transpose.jsonc @@ -28,6 +28,37 @@ } ] }, + { + "name": "ConvTranspose without bias addition A - NHWC", + "inputShapeDefinitions": "rankOnly", + "opset": { "domain": "", "version": 17 }, + "operator": "ConvTranspose", + "attributes": [{ "name": "kernel_shape", "data": [2, 2], "type": "ints" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [10, 20, 30, 40], + "dims": [1, 1, 2, 2], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [10, 40, 40, 60, 200, 160, 90, 240, 160], + "dims": [1, 1, 3, 3], + "type": "float32" + } + ] + } + ] + }, { "name": "ConvTranspose without bias addition B", "operator": "ConvTranspose", @@ -74,26 +105,22 @@ }, { "data": [ - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64 ], "dims": [4, 4, 2, 2], "type": "float32" }, { - "data": [0.1, 0.2, 0.3, 0.4], + "data": [65, 66, 67, 68], "dims": [4], "type": "float32" } ], "outputs": [ { - "data": [ - 100.0999984741211, 100.0999984741211, 100.0999984741211, 100.0999984741211, 100.19999694824219, - 100.19999694824219, 100.19999694824219, 100.19999694824219, 100.30000305175781, 100.30000305175781, - 100.30000305175781, 100.30000305175781, 100.4000015258789, 100.4000015258789, 100.4000015258789, - 100.4000015258789 - ], + "data": [3365, 3465, 3565, 3665, 3766, 3866, 3966, 4066, 4167, 4267, 4367, 4467, 4568, 4668, 4768, 4868], "dims": [1, 4, 2, 2], "type": "float32" } @@ -115,7 +142,43 @@ "type": "float32" }, { - "data": [1, 1, 1, 1], + "data": [1, 2, 3, 4], + "dims": [1, 1, 2, 2], + "type": "float32" + }, + { + "data": [5], + "dims": [1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [11, 25, 28, 19, 32, 86, 99, 55, 40, 114, 131, 67, 29, 73, 80, 41], + "dims": [1, 1, 4, 4], + "type": "float32" + } + ] + } + ] + }, + { + "name": "ConvTranspose with bias addition B - NHWC", + "operator": "ConvTranspose", + "inputShapeDefinitions": "rankOnly", + "opset": { "domain": "", "version": 17 }, + "attributes": [{ "name": "kernel_shape", "data": [2, 2], "type": "ints" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [6, 8, 7, 9, 15, 11, 8, 12, 9], + "dims": [1, 1, 3, 3], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], "dims": [1, 1, 2, 2], "type": "float32" }, @@ -127,7 +190,7 @@ ], "outputs": [ { - "data": [11, 19, 20, 12, 20, 43, 46, 23, 22, 49, 52, 25, 13, 25, 26, 14], + "data": [11, 25, 28, 19, 32, 86, 99, 55, 40, 114, 131, 67, 29, 73, 80, 41], "dims": [1, 1, 4, 4], "type": "float32" } diff --git a/onnxruntime/core/providers/js/operators/conv_transpose.h b/onnxruntime/core/providers/js/operators/conv_transpose.h index a5aeae8646373..c3babbc5ce81f 100644 --- a/onnxruntime/core/providers/js/operators/conv_transpose.h +++ 
b/onnxruntime/core/providers/js/operators/conv_transpose.h @@ -108,6 +108,23 @@ class ConvTranspose : public JsKernel { } } + Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + /*out*/ bool& is_packed, + /*out*/ PrePackedWeights* /* prepacked_weights */) override { + is_packed = false; + + if (input_idx == 1) { + // Only handle the common case of conv2D + if (tensor.Shape().NumDimensions() != 4 || tensor.SizeInBytes() == 0) { + return Status::OK(); + } + + w_is_const_ = true; + } + + return Status::OK(); + } + protected: ConvTransposeAttributes conv_transpose_attrs_; bool w_is_const_; From 8200c4c345c6c0cc8a49feca5bca9fd394d32d38 Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Mon, 25 Sep 2023 16:26:38 -0700 Subject: [PATCH 14/18] Use effectiveFilterDims. --- .../ops/3rd-party/conv_backprop_mm_webgpu.ts | 20 ++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts index e68ea844b8fe8..c710f1d134164 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts @@ -77,8 +77,9 @@ const conv2dTransposeCommonSnippet = const xWidth = isChannelsLast ? 'outBackprop[2]' : 'outBackprop[3]'; const row = isChannelsLast ? 'row' : 'col'; const col = isChannelsLast ? 'col' : 'row'; + const readASnippet = ` - let inChannels = wShape[2]; + let inChannels = ${isChannelsLast ? 'outBackprop[3]' : 'outBackprop[1]'}; let outWidth = ${isChannelsLast ? 'outShape[2]' : 'outShape[3]'}; let outRow = ${row} / outWidth; let outCol = ${row} % outWidth; @@ -99,23 +100,24 @@ const conv2dTransposeCommonSnippet = ${coordASnippet} return x[getIndexFromCoords4D(coord, xShape)/${innerElementSize}];`; - const sampleA = isChannelsLast ? `let col = colIn * ${innerElementSize}; + const sampleA = isChannelsLast ? ` + let col = colIn * ${innerElementSize}; if (row < dimAOuter && col < dimInner) { ${readASnippet} } return ${typeSnippet(innerElementSize)}(0.0);` : - `let col = colIn * ${innerElementSize}; - if (row < dimInner && col < dimBOuter) { - let col = colIn * ${innerElementSize}; + ` + let col = colIn * ${innerElementSize}; + if (row < dimInner && col < dimBOuter) { ${readASnippet} } return ${typeSnippet(innerElementSize)}(0.0);`; const sampleW = ` - let col = colIn; - let inChannels = wShape[2]; - let coordX = filterDims.x - 1 - row / (filterDims[1] * inChannels); - let coordY = filterDims.y - 1 - (row / inChannels) % filterDims[1]; + let col = colIn * ${innerElementSize}; + let inChannels = ${isChannelsLast ? 'outBackprop[3]' : 'outBackprop[1]'}; + let coordX = effectiveFilterDims.x - 1 - row / (effectiveFilterDims[1] * inChannels); + let coordY = effectiveFilterDims.y - 1 - (row / inChannels) % effectiveFilterDims[1]; if (${ isChannelsLast ? 'row < dimInner && col < dimBOuter' : 'row < dimInner && col < dimAOuter'} && coordX >= 0 && coordY >= 0) { From 1568f5eac2bd05e3088d9a7e9ff77a378bd53626 Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Mon, 25 Sep 2023 16:27:26 -0700 Subject: [PATCH 15/18] Use naive implementation if stride is non-trivial. 
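The matmul path is now taken only when every stride is 1; strided cases fall back to the existing naive kernel in conv_backprop_webgpu.ts. The product-of-strides test in the diff below is a compact way to express that condition, since for positive integer strides the product is 1 exactly when all entries are 1. An equivalent, more explicit form (sketch only, not part of the patch):

    // Fall back to the naive ConvTranspose kernel for grouped or strided cases.
    const useNaivePath = (group: number, strides: readonly number[]): boolean =>
        group !== 1 || !strides.every(s => s === 1);
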
--- js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts index c993102ae1596..eb6d43d372684 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts @@ -228,7 +228,7 @@ const createConvTranspose2DProgramInfoLoader = }; }; -// for transposing weight tensor from [M, C/group, KH, KW] to [KH, KW, C/group, M] +// for transposing weight tensor from [C, M/group, KH, KW] to [KH, KW, M/group, C] const weightTransposeAttribute: TransposeAttributes = createAttributeWithCacheKey({perm: [2, 3, 1, 0]}); const convTranspose2d = @@ -236,7 +236,7 @@ const convTranspose2d = const adjustedAttributes = getAdjustedConvTransposeAttributes(attributes, inputs); const isChannelsLast = attributes.format === 'NHWC'; const hasBias = inputs.length === 3; - if (adjustedAttributes.group !== 1) { + if (adjustedAttributes.group !== 1 || attributes.strides.reduce((a, b) => a * b, 1) !== 1) { context.compute(createConvTranspose2DProgramInfoLoader(inputs, adjustedAttributes)); return; } From 0ec2ed47287ea706fd316033fd3004d2e471350d Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Tue, 26 Sep 2023 17:39:42 -0700 Subject: [PATCH 16/18] Enable MatMul version of ConvTranspose even for non-unit strides. --- .../webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts | 4 ++-- js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts | 10 ++++++---- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts index c710f1d134164..e487ba9bfc9b1 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts @@ -86,8 +86,8 @@ const conv2dTransposeCommonSnippet = let WRow = ${col} / (filterDims[1] * inChannels); let WCol = ${col} / inChannels % filterDims[1]; - let xR = f32(outRow / strides[0] - pads[0] + dilation[0] * WRow); - let xC = f32(outCol / strides[1] - pads[1] + dilation[1] * WCol); + let xR = f32(outRow - pads[0] + dilation[0] * WRow) / f32(strides[0]); + let xC = f32(outCol - pads[1] + dilation[1] * WCol) / f32(strides[1]); if (xR < 0.0 || xR >= f32(${xHeight}) || fract(xR) > 0.0) { return ${typeSnippet(innerElementSize)}(0.0); } diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts index eb6d43d372684..5641386cce849 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts @@ -97,9 +97,11 @@ const getAdjustedConvTransposeAttributes = // always return a new object so does not modify the original attributes const newAttributes: T = Object.assign({}, attributes); - Object.assign( - newAttributes, - {kernelShape, pads, outputPadding, outputShape, dilations, strides, cacheKey: attributes.cacheKey}); + const cacheKey = attributes.cacheKey + [ + kernelShape.join('n,'), pads.join(','), strides.join(','), outputPadding.join(','), outputShape.join(','), + dilations.join(',') + ].join('_'); + Object.assign(newAttributes, {kernelShape, pads, outputPadding, outputShape, dilations, strides, cacheKey}); return newAttributes; }; @@ -236,7 +238,7 @@ const convTranspose2d = const adjustedAttributes = getAdjustedConvTransposeAttributes(attributes, inputs); const isChannelsLast 
= attributes.format === 'NHWC'; const hasBias = inputs.length === 3; - if (adjustedAttributes.group !== 1 || attributes.strides.reduce((a, b) => a * b, 1) !== 1) { + if (adjustedAttributes.group !== 1) { context.compute(createConvTranspose2DProgramInfoLoader(inputs, adjustedAttributes)); return; } From f98341f7ac25537cfa60a66b8c0ebaa101dc039c Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Wed, 27 Sep 2023 10:20:19 -0700 Subject: [PATCH 17/18] Use naive implementation for non-unit dilations. --- js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts index 5641386cce849..06c8a4476c1b5 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts @@ -238,7 +238,7 @@ const convTranspose2d = const adjustedAttributes = getAdjustedConvTransposeAttributes(attributes, inputs); const isChannelsLast = attributes.format === 'NHWC'; const hasBias = inputs.length === 3; - if (adjustedAttributes.group !== 1) { + if (adjustedAttributes.group !== 1 || adjustedAttributes.dilations.reduce((x, y) => x * y, 1) !== 1) { context.compute(createConvTranspose2DProgramInfoLoader(inputs, adjustedAttributes)); return; } From cedc66bfad92c8a1bed8f7be05c325be3217c400 Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Thu, 28 Sep 2023 09:26:59 -0700 Subject: [PATCH 18/18] Enable ConvTranspose matmul implementation even if dilations > 1. --- .../wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts | 4 ++-- js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts index e487ba9bfc9b1..3925e1cb4f564 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts @@ -116,8 +116,8 @@ const conv2dTransposeCommonSnippet = const sampleW = ` let col = colIn * ${innerElementSize}; let inChannels = ${isChannelsLast ? 'outBackprop[3]' : 'outBackprop[1]'}; - let coordX = effectiveFilterDims.x - 1 - row / (effectiveFilterDims[1] * inChannels); - let coordY = effectiveFilterDims.y - 1 - (row / inChannels) % effectiveFilterDims[1]; + let coordX = filterDims.x - 1 - row / (filterDims[1] * inChannels); + let coordY = filterDims.y - 1 - (row / inChannels) % filterDims[1]; if (${ isChannelsLast ? 'row < dimInner && col < dimBOuter' : 'row < dimInner && col < dimAOuter'} && coordX >= 0 && coordY >= 0) { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts index 06c8a4476c1b5..5641386cce849 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts @@ -238,7 +238,7 @@ const convTranspose2d = const adjustedAttributes = getAdjustedConvTransposeAttributes(attributes, inputs); const isChannelsLast = attributes.format === 'NHWC'; const hasBias = inputs.length === 3; - if (adjustedAttributes.group !== 1 || adjustedAttributes.dilations.reduce((x, y) => x * y, 1) !== 1) { + if (adjustedAttributes.group !== 1) { context.compute(createConvTranspose2DProgramInfoLoader(inputs, adjustedAttributes)); return; }
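
For reference, a minimal sketch (not part of the series) of how convTranspose2d maps ConvTranspose onto the packed matmul: the implicit GEMM is [dimAOuter x dimInner] times [dimInner x dimBOuter], with the outer dimensions swapped between NHWC and NCHW. The helper name is hypothetical; the shape conventions follow the code above.

    interface GemmDims { dimAOuter: number; dimBOuter: number; dimInner: number; }

    // Mirrors the dimension setup in convTranspose2d (conv-transpose.ts).
    const convTransposeGemmDims =
        (isChannelsLast: boolean, outputShape: readonly number[], weightDims: readonly number[],
         inputChannels: number): GemmDims => {
          const outHeight = outputShape[isChannelsLast ? 1 : 2];
          const outWidth = outputShape[isChannelsLast ? 2 : 3];
          const outChannels = outputShape[isChannelsLast ? 3 : 1];
          const kernelSize = weightDims[2] * weightDims[3];  // weight is [C, M/group, KH, KW]
          return {
            dimAOuter: isChannelsLast ? outHeight * outWidth : outChannels,
            dimBOuter: isChannelsLast ? outChannels : outHeight * outWidth,
            dimInner: kernelSize * inputChannels,
          };
        };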