From 01c644f6f8bcfcc9539d318ecb96d7723cdfb412 Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Wed, 20 Sep 2023 15:05:33 -0700 Subject: [PATCH] Added dilation fix. --- .../ops/3rd-party/conv_backprop_mm_webgpu.ts | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts index ce1724a47101d..b78eb8c9e2ca4 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts @@ -172,8 +172,19 @@ export const createConv2DTransposeMatMulProgramInfo = const outShapeStrides : vec3 = vec3(${ShapeUtil.computeStrides(outputShape).slice(0, 3).join(',')}); const filterDims : vec2 = vec2(${attributes.kernelShape[isChannelsLast ? 1 : 2]}, ${ attributes.kernelShape[isChannelsLast ? 2 : 3]}); - const pads : vec2 = vec2(i32(filterDims[0]) - 1 - (${attributes.pads[0] + attributes.pads[2]})/2, - i32(filterDims[1]) - 1 - (${attributes.pads[1] + attributes.pads[3]})/2); + const effectiveFilterDims : vec2 = filterDims + vec2( + ${ + attributes.dilations[0] <= 1 ? + 0 : + (attributes.kernelShape[isChannelsLast ? 1 : 2] - 1) * (attributes.dilations[0] - 1)}, + ${ + attributes.dilations[1] <= 1 ? + 0 : + (attributes.kernelShape[isChannelsLast ? 2 : 3] - 1) * (attributes.dilations[1] - 1)}); + const pads : vec2 = vec2(i32(effectiveFilterDims[0]) - 1 - (${ + attributes.pads[0] + attributes.pads[2]})/2, + i32(effectiveFilterDims[1]) - 1 - (${ + attributes.pads[1] + attributes.pads[3]})/2); const strides : vec2 = vec2(${attributes.strides[0]}, ${attributes.strides[1]}); const dilation : vec2 = vec2(${attributes.dilations[0]}, ${attributes.dilations[1]}); const dimAOuter : i32 = ${dimAOuter};