[JS/WebGPU] Use non-matmul implementation for ConvTranspose in channel-first case. (microsoft#20022)

### Description
Avoid using the vec4 MatMul implementation for ConvTranspose in the channel-first (NCHW) case.



satyajandhyala authored and Ted Themistokleous committed May 7, 2024
1 parent 6954f59 commit 6fe83c2
Showing 2 changed files with 266 additions and 7 deletions.
@@ -164,17 +164,14 @@ export const createConv2DTransposeMatMulProgramInfo =
  const outWidth = isChannelsLast ? outputShape[2] : outputShape[3];
  const outHeight = isChannelsLast ? outputShape[1] : outputShape[2];
  const outChannels = isChannelsLast ? outputShape[3] : outputShape[1];
- const isVec4 =
-     isChannelsLast ? inChannels % 4 === 0 && outChannels % 4 === 0 : outWidth % 4 === 0 && outChannels % 4 === 0;
+ // TODO: enable vec4 for NCHW
+ const isVec4 = isChannelsLast && (inChannels % 4 === 0 && inChannels % 3) && outChannels % 4 === 0;

  // TODO: fine tune size
  const dispatchX = isChannelsLast ? outChannels : outWidth * outHeight;
  const dispatchY = isChannelsLast ? outWidth * outHeight : outChannels;
- const workGroupSize: [number, number, number] = isVec4 ?
-     [8, 8, 1] :
-     [(dispatchX <= 4 || dispatchY <= 4) ? 4 : 16, dispatchX > 4 && dispatchY <= 4 ? 4 : 16, 1];
- const elementsPerThread =
-     isVec4 ? [4, 4, 1] : [dispatchX <= 4 ? 1 : 4, dispatchX > 4 && dispatchY <= 4 ? 1 : 4, 1];
+ const workGroupSize: [number, number, number] = [8, 8, 1];
+ const elementsPerThread = dimAOuter <= 8 ? [4, 1, 1] : [4, 4, 1];
  const dispatch = [
    Math.ceil(dispatchX / workGroupSize[0] / elementsPerThread[0]),
    Math.ceil(dispatchY / workGroupSize[1] / elementsPerThread[1]),
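To make the effect of the single-line `isVec4` gate and the fixed `[8, 8, 1]` workgroup size easier to follow, here is a minimal standalone sketch. The helper names `canUseVec4` and `computeDispatchXY` are hypothetical and do not appear in the commit. `dimAOuter` is defined outside the visible hunk, so it is taken as a plain parameter here, and only the X and Y dispatch components are computed because the hunk is cut off before the third one. The commit relies on `inChannels % 3` being truthy; the sketch spells that check out as `!== 0`, which is assumed to be the intent.

```typescript
// Hypothetical helpers mirroring the single-line isVec4 gate and the [8, 8, 1]
// workgroup sizing from the hunk. They are not part of the commit itself.
type Shape4D = [number, number, number, number];

// vec4 is only considered for channels-last (NHWC) input.
// NCHW always returns false, matching the "TODO: enable vec4 for NCHW" comment.
function canUseVec4(isChannelsLast: boolean, inChannels: number, outChannels: number): boolean {
  // Assumption: `inChannels % 3` in the commit is a truthiness check, i.e. `!== 0`.
  return isChannelsLast && inChannels % 4 === 0 && inChannels % 3 !== 0 && outChannels % 4 === 0;
}

// Computes the X and Y workgroup counts only; the third component is not
// visible in the hunk, so it is left out here.
function computeDispatchXY(
    isChannelsLast: boolean, outputShape: Shape4D, dimAOuter: number): [number, number] {
  const outWidth = isChannelsLast ? outputShape[2] : outputShape[3];
  const outHeight = isChannelsLast ? outputShape[1] : outputShape[2];
  const outChannels = isChannelsLast ? outputShape[3] : outputShape[1];

  const dispatchX = isChannelsLast ? outChannels : outWidth * outHeight;
  const dispatchY = isChannelsLast ? outWidth * outHeight : outChannels;

  // Workgroup size is fixed; only the per-thread tile shrinks when dimAOuter is small.
  const workGroupSize: [number, number, number] = [8, 8, 1];
  const elementsPerThread = dimAOuter <= 8 ? [4, 1, 1] : [4, 4, 1];

  return [
    Math.ceil(dispatchX / workGroupSize[0] / elementsPerThread[0]),
    Math.ceil(dispatchY / workGroupSize[1] / elementsPerThread[1]),
  ];
}
```

Under these assumptions, an NHWC output of shape `[1, 16, 16, 32]` with `dimAOuter = 256` gives `[1, 8]`, and the same tensor in NCHW layout (`[1, 32, 16, 16]`) gives `[8, 1]`, while `canUseVec4` returns `false` for every NCHW input.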