From 6794407eaf456ecb6011ae6bef719c2d7ead8570 Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Tue, 19 Sep 2023 22:46:24 -0700 Subject: [PATCH] Fixed filter setting --- .../ops/3rd-party/conv_backprop_mm_webgpu.ts | 72 +++++++++++-------- .../wasm/jsep/webgpu/ops/conv-transpose.ts | 4 +- js/web/test/data/ops/conv-transpose.jsonc | 2 + 3 files changed, 48 insertions(+), 30 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts index f96b4ccd18ceb..1bec542521202 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts @@ -25,17 +25,18 @@ import {ShapeUtil} from '../../../util'; import {GpuDataType, ProgramInfo, ProgramMetadata} from '../../types'; import {ConvTransposeAttributes} from '../conv-transpose'; -import {typeSnippet} from './activation_util'; +import {Activation, activationFnSnippet, biasActivationSnippet, typeSnippet} from './activation_util'; import {utilFunctions} from './conv_util'; import {makeMatMulPackedSource, makeMatMulPackedVec4Source} from './matmul_packed_webgpu'; -const conv2dTransposeCommonSnippet = (innerElementSize = 4): string => { - const getWSnippet = (innerElementSize: number) => { - switch (innerElementSize) { - case 1: - return 'return W[getIndexFromCoords4D(coord, wShape)];'; - case 4: - return ` +const conv2dTransposeCommonSnippet = + (addBias = false, activation?: Activation, hasPreluActivationWeights = false, innerElementSize = 4): string => { + const getWSnippet = (innerElementSize: number) => { + switch (innerElementSize) { + case 1: + return 'return W[getIndexFromCoords4D(coord, wShape)];'; + case 4: + return ` let coord1 = vec4(coordX, coordY, col + 1, rowInner); let coord2 = vec4(coordX, coordY, col + 2, rowInner); let coord3 = vec4(coordX, coordY, col + 3, rowInner); @@ -45,12 +46,12 @@ const conv2dTransposeCommonSnippet = (innerElementSize = 4): string => { let v3 = W[getIndexFromCoords4D(coord3, wShape)]; return vec4(v0, v1, v2, v3); `; - default: - throw new Error(`innerElementSize ${innerElementSize} is not supported.`); - } - }; + default: + throw new Error(`innerElementSize ${innerElementSize} is not supported.`); + } + }; - const readASnippet = ` + const readASnippet = ` let outRow = row / outShape[2]; let outCol = row % outShape[2]; @@ -71,12 +72,13 @@ const conv2dTransposeCommonSnippet = (innerElementSize = 4): string => { col % outBackprop[3]); return x[getIndexFromCoords4D(coord, xShape)/${innerElementSize}];`; - const sampleA = `if (row < dimAOuter && col < dimInner) { + const sampleA = `if (row < dimAOuter && col < dimInner) { ${readASnippet} } return ${typeSnippet(innerElementSize)}(0.0);`; - const userCode = ` + const userCode = ` + ${activationFnSnippet(activation, hasPreluActivationWeights, innerElementSize === 4, 4)} fn mm_readA(batch: i32, row : i32, col : i32) -> ${typeSnippet(innerElementSize)} { ${sampleA} } @@ -98,16 +100,17 @@ const conv2dTransposeCommonSnippet = (innerElementSize = 4): string => { fn mm_write(batch: i32, row : i32, col : i32, valueInput : ${typeSnippet(innerElementSize)}) { if (row < dimAOuter && col < dimBOuter) { var value = valueInput; - let outCoord = vec4( + let coords = vec4( batch, row / outShape[2], row % outShape[2], col); - result[getIndexFromCoords4D(outCoord, outShape)/${innerElementSize}] = value; + ${biasActivationSnippet(addBias, activation)} + result[getIndexFromCoords4D(coords, outShape)/${innerElementSize}] = value; } }`; - return userCode; -}; + return userCode; + }; export const createConv2DTransposeMatMulProgramInfo = (inputs: readonly TensorView[], metadata: ProgramMetadata, attributes: ConvTransposeAttributes, @@ -125,7 +128,8 @@ export const createConv2DTransposeMatMulProgramInfo = const dispatchX = isChannelsLast ? outChannels : outWidth * outHeight; const dispatchY = isChannelsLast ? outWidth * outHeight : outChannels; - const workGroupSize: [number, number, number] = isVec4 ? [8, 8, 1] : [4, 4, 1]; + const workGroupSize: [number, number, number] = + isVec4 ? [8, 8, 1] : [dispatchX <= 4 ? 4 : 16, dispatchX > 4 && dispatchY <= 4 ? 4 : 16, 1]; const elementsPerThread = isVec4 ? [4, 4, 1] : [dispatchX <= 4 ? 1 : 2, dispatchX > 4 && dispatchY <= 4 ? 1 : 2, 1]; const dispatch = [ @@ -133,23 +137,32 @@ export const createConv2DTransposeMatMulProgramInfo = Math.ceil(dispatchY / workGroupSize[1] / elementsPerThread[1]), Math.ceil(batchSize / workGroupSize[2] / elementsPerThread[2]) ]; - const innerElementSize = isVec4 ? 4 : 1; + + LOG_DEBUG('verbose', () => `[conv_backprop_mm_webgpu] dispatch = ${dispatch}`); + + const innerElementSize = isVec4 ? (inChannels % 4 !== 0 ? 3 : 4) : 1; const tileInner = Math.max(workGroupSize[0] * innerElementSize, workGroupSize[1]); - LOG_DEBUG('verbose', () => `[conv2d_mm_webgpu] dispatch = ${dispatch}`); const declareInputs = [ `@group(0) @binding(0) var x: array<${isVec4 ? 'vec4' : 'f32'}>;`, - `@group(0) @binding(1) var W: array<${isVec4 ? 'vec4' : 'f32'}>;` + '@group(0) @binding(1) var W: array;' ]; - + let declareFunctions = ''; + if (hasBias) { + declareInputs.push(`@group(0) @binding(2) var bias: array<${isVec4 ? 'vec4' : 'f32'}>;`); + declareFunctions += ` + fn getBiasByOutputCoords(coords : vec4) -> ${isVec4 ? 'vec4' : 'f32'} { + return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}]; + }`; + } return { ...metadata, outputs: [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}], dispatchGroup: () => ({x: dispatch[0], y: dispatch[1], z: dispatch[2]}), getShaderSource: () => ` ${utilFunctions} - ${declareInputs.join('')} + ${declareInputs.join('\n')} @group(0) @binding(${declareInputs.length}) var result: array<${ isVec4 ? 'vec4' : 'f32'}>; const outBackprop : vec4 = vec4(${inputs[0].dims.join(',')}); @@ -157,14 +170,17 @@ export const createConv2DTransposeMatMulProgramInfo = const wShape : vec4 = vec4(${inputs[1].dims.join(',')}); const outShape : vec4 = vec4(${outputShape.join(',')}); const outShapeStrides : vec3 = vec3(${ShapeUtil.computeStrides(outputShape).slice(0, 3).join(',')}); - const filterDims : vec2 = vec2(${attributes.kernelShape[0]}, ${attributes.kernelShape[1]}); - const pads : vec2 = vec2(${attributes.pads[0]}, ${attributes.pads[1]}); + const filterDims : vec2 = vec2(${attributes.kernelShape[isChannelsLast ? 1 : 2]}, ${ + attributes.kernelShape[isChannelsLast ? 2 : 3]}); + const pads : vec2 = vec2(i32(filterDims[0]) - 1 - (${attributes.pads[0] + attributes.pads[2]})/2, + i32(filterDims[1]) - 1 - (${attributes.pads[1] + attributes.pads[3]})/2); const strides : vec2 = vec2(${attributes.strides[0]}, ${attributes.strides[1]}); const dilation : vec2 = vec2(${attributes.dilations[0]}, ${attributes.dilations[1]}); const dimAOuter : i32 = ${dimAOuter}; const dimBOuter : i32 = ${dimBOuter}; const dimInner : i32 = ${dimInner}; - ${conv2dTransposeCommonSnippet(innerElementSize)} + ${declareFunctions} + ${conv2dTransposeCommonSnippet(hasBias, undefined, false, innerElementSize)} ${ isVec4 ? makeMatMulPackedVec4Source(elementsPerThread, workGroupSize, undefined, !isChannelsLast, tileInner) : diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts index 8a90b14fd4f91..40ca2fbd1430d 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts @@ -65,7 +65,7 @@ const getAdjustedConvTransposeAttributes = (attributes: T, inputs: readonly TensorView[]): T => { const kernelShape = attributes.kernelShape.slice(); // if kernelShape is not specified in the attributes of this op, infer it from the weight tensor dims - if (attributes.kernelShape.length === 0 || attributes.kernelShape.reduce((a, b) => a * b, 0) === 0) { + if (attributes.kernelShape.length === 0 || attributes.kernelShape.reduce((a, b) => a * b, 1) === 0) { kernelShape.length = 0; for (let i = 2; i < inputs[1].dims.length; ++i) { kernelShape.push(inputs[1].dims[i]); @@ -236,7 +236,7 @@ const convTranspose2d = const adjustedAttributes = getAdjustedConvTransposeAttributes(attributes, inputs); const isChannelsLast = attributes.format === 'NHWC'; const hasBias = inputs.length === 3; - if (adjustedAttributes.group !== 1 || hasBias) { + if (adjustedAttributes.group !== 1 || !isChannelsLast) { context.compute(createConvTranspose2DProgramInfoLoader(inputs, adjustedAttributes)); return; } diff --git a/js/web/test/data/ops/conv-transpose.jsonc b/js/web/test/data/ops/conv-transpose.jsonc index 9079e466be400..9aa5d802ac10f 100644 --- a/js/web/test/data/ops/conv-transpose.jsonc +++ b/js/web/test/data/ops/conv-transpose.jsonc @@ -289,6 +289,8 @@ { "name": "ConvTranspose with bias addition C", "operator": "ConvTranspose", + "inputShapeDefinitions": "rankOnly", + "opset": { "domain": "", "version": 17 }, "attributes": [{ "name": "kernel_shape", "data": [1, 1], "type": "ints" }], "cases": [ {