diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts index e6d4039d8131b..d013279cdd915 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts @@ -163,17 +163,14 @@ export const createConv2DMatMulProgramInfo = const outWidth = isChannelsLast ? outputShape[2] : outputShape[3]; const outHeight = isChannelsLast ? outputShape[1] : outputShape[2]; const outChannels = isChannelsLast ? outputShape[3] : outputShape[1]; - const isVec4 = (((inChannels % 4 === 0 || inChannels % 3 === 0) && isChannelsLast) || - (outWidth % 4 === 0 && !isChannelsLast)) && - outChannels % 4 === 0; + // TODO: enable vec4 for NCHW + const isVec4 = isChannelsLast && (inChannels % 4 === 0 || inChannels % 3 === 0) && outChannels % 4 === 0; // TODO: fine tune size const dispatchX = isChannelsLast ? outChannels : outWidth * outHeight; const dispatchY = isChannelsLast ? outWidth * outHeight : outChannels; - const workGroupSize: [number, number, number] = - isVec4 ? [8, 8, 1] : [dispatchX <= 4 ? 4 : 16, dispatchX > 4 && dispatchY <= 4 ? 4 : 16, 1]; - const elementsPerThread = - isVec4 ? [4, 4, 1] : [dispatchX <= 4 ? 1 : 2, dispatchX > 4 && dispatchY <= 4 ? 1 : 2, 1]; + const workGroupSize: [number, number, number] = [8, 8, 1]; + const elementsPerThread = dimAOuter <= 8 ? [4, 1, 1] : [4, 4, 1]; const dispatch = [ Math.ceil(dispatchX / workGroupSize[0] / elementsPerThread[0]), Math.ceil(dispatchY / workGroupSize[1] / elementsPerThread[1]), diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts index 82f8c82291f4b..ef591fe2bc15c 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts @@ -90,8 +90,8 @@ export const makeMatMulPackedVec4Source = workPerThread[0]} must be 4.`); } return ` -var mm_Asub : array, ${tileAWidth / innerElementSize}>, ${tileAHight}>; -var mm_Bsub : array, ${tileBOuter / workPerThread[0]}>, ${tileInner}>; +var mm_Asub: array, ${tileAWidth / innerElementSize}>, ${tileAHight}>; +var mm_Bsub: array, ${tileBOuter / workPerThread[0]}>, ${tileInner}>; const rowPerThread = ${workPerThread[1]}; const colPerThread = ${workPerThread[0]}; @@ -339,7 +339,8 @@ fn main(@builtin(local_invocation_id) localId : vec3, }; const matMulReadWriteFnSource = - (component: number, hasBias: boolean, applyActivation: string, variables: IndicesHelper[]): string => { + (component: number, hasBias: boolean, applyActivation: string, variables: IndicesHelper[], + isChannelsLast = false): string => { const batchAVariable = variables[0]; const batchBVariable = variables[1]; const batchVariable = variables[2]; @@ -407,7 +408,10 @@ const matMulReadWriteFnSource = if (row < dimAOuter && col < dimBOuter) { var value = valueIn; let coords = vec3(batch, row, colIn); - ${hasBias ? 'value = value + bias[colIn];' : ''} + ${ + hasBias ? + `value = value + ${isChannelsLast ? 'bias[colIn]' : `${typeSnippet(component, dataType)}(bias[row])`};` : + '' } ${applyActivation} ${outputVariable.setByIndices('vec3(coords)', 'value')} } @@ -418,7 +422,8 @@ const matMulReadWriteFnSource = export const createMatmulProgramInfo = (metadata: ProgramMetadata, inputs: readonly TensorView[], activationAttributes: InternalActivationAttributes, - outputShape: readonly number[], reshapedOutputShape?: readonly number[]): ProgramInfo => { + outputShape: readonly number[], reshapedOutputShape?: readonly number[], + isChannelsLast = false /* only used for conv2dByMatMul*/): ProgramInfo => { const aShape = inputs[0].dims; const bShape = inputs[1].dims; @@ -457,9 +462,10 @@ export const createMatmulProgramInfo = variables.push(output); const inputVariables = [A, B]; const hasBias = inputs.length > 2; - const declareFunctions = matMulReadWriteFnSource(components, hasBias, applyActivation, variables); + const declareFunctions = matMulReadWriteFnSource(components, hasBias, applyActivation, variables, isChannelsLast); if (hasBias) { - inputVariables.push(inputVariable('bias', inputs[2].dataType, [dimBOuter / components], components)); + const biasComponents = isChannelsLast ? components : 1; + inputVariables.push(inputVariable('bias', inputs[2].dataType, inputs[2].dims, biasComponents)); } const getShaderSource = (shaderHelper: ShaderHelper) => ` const dimAOuter: i32 = ${dimAOuter}; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts index 7afc3ce1b9d77..2ba0183fd977d 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts @@ -134,15 +134,14 @@ const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attribut // check attributes - const hasBias = inputs.length === 3; // const hasPreluActivationWeights = false; /* TODO: add support for prelu activation weights */ - const isChannelsLast = attributes.format === 'NHWC'; - if (!isChannelsLast || attributes.group !== 1) { + if (attributes.group !== 1) { context.compute(createGroupedConvProgramInfoLoader(inputs, adjustedAttributes)); return; } - // const batchSize = context.inputs[0].dims[0]; + const isChannelsLast = attributes.format === 'NHWC'; + const hasBias = inputs.length === 3; const inputHeight = inputs[0].dims[isChannelsLast ? 1 : 2]; const inputWidth = inputs[0].dims[isChannelsLast ? 2 : 3]; const inputChannels = inputs[0].dims[isChannelsLast ? 3 : 1]; @@ -155,47 +154,59 @@ const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attribut const outHeight = outputShape[isChannelsLast ? 1 : 2]; const outWidth = outputShape[isChannelsLast ? 2 : 3]; const outChannels = outputShape[isChannelsLast ? 3 : 1]; - const batch = outputShape[0]; - const sameSize = - isChannelsLast && weightHeight === inputHeight && weightWidth === inputWidth && attributes.autoPad === 'VALID'; + const sameSize = isChannelsLast && weightHeight === inputHeight && weightWidth === inputWidth && + attributes.pads[0] === 0 && attributes.pads[1] === 0; if (sameSize || (weightHeight === 1 && weightWidth === 1 && attributes.dilations[0] === 1 && attributes.dilations[1] === 1 && attributes.strides[0] === 1 && attributes.strides[1] === 1 && attributes.pads[0] === 0 && attributes.pads[1] === 0)) { // conv2dByMatMul - const transposedWeight = (context.kernelCustomData.wT as TensorView | undefined) ?? - context.compute( - { - ...transposeProgramMetadata, - cacheHint: weightTransposeAttribute.cacheKey, - get: () => createTransposeProgramInfo(inputs[1], weightTransposeAttribute.perm) - }, - {inputs: [1], outputs: [attributes.wIsConst ? -2 : -1]})[0]; - if (attributes.wIsConst && !context.kernelCustomData.wT) { - context.kernelCustomData.wT = transposedWeight; - } - + const batch = outputShape[0]; + let xReshaped, wReshaped, matmulOutputShape; const matmulInputs = []; - matmulInputs.push(inputs[0].reshape([batch, inputHeight * inputWidth, inputChannels])); - matmulInputs.push(transposedWeight.reshape([1, inputChannels, outChannels])); + if (isChannelsLast) { + const transposedWeight = (context.kernelCustomData.wT as TensorView | undefined) ?? + context.compute( + { + ...transposeProgramMetadata, + cacheHint: weightTransposeAttribute.cacheKey, + get: () => createTransposeProgramInfo(inputs[1], weightTransposeAttribute.perm) + }, + {inputs: [1], outputs: [attributes.wIsConst ? -2 : -1]})[0]; + if (attributes.wIsConst && !context.kernelCustomData.wT) { + context.kernelCustomData.wT = transposedWeight; + } + if (sameSize) { + const sharedDim = inputHeight * inputWidth * inputChannels; + xReshaped = inputs[0].reshape([1, batch, sharedDim]); + wReshaped = transposedWeight.reshape([1, sharedDim, outChannels]); + matmulOutputShape = [1, batch, outChannels]; + } else { + xReshaped = inputs[0].reshape([batch, inputHeight * inputWidth, inputChannels]); + wReshaped = transposedWeight.reshape([1, inputChannels, outChannels]); + matmulOutputShape = [batch, outHeight * outWidth, outChannels]; + } + matmulInputs.push(xReshaped); + matmulInputs.push(wReshaped); + } else { + xReshaped = inputs[0].reshape([batch, inputChannels, inputHeight * inputWidth]); + wReshaped = inputs[1].reshape([1, outChannels, inputChannels]); + matmulOutputShape = [batch, outChannels, outHeight * outWidth]; + matmulInputs.push(wReshaped); + matmulInputs.push(xReshaped); + } if (hasBias) { matmulInputs.push(inputs[2]); } - const matmulOutputShape = [batch, outHeight * outWidth, outChannels]; context.compute( - createMatmulProgramInfoLoader(matmulInputs, adjustedAttributes, outputShape, matmulOutputShape), + createMatmulProgramInfoLoader(matmulInputs, adjustedAttributes, outputShape, matmulOutputShape, isChannelsLast), {inputs: matmulInputs}); - return; } // TODO: implement conv2dWithIm2Col() - const dimAOuter = isChannelsLast ? outHeight * outWidth : outChannels; - const dimBOuter = isChannelsLast ? outChannels : outHeight * outWidth; - const dimInner = weightHeight * weightWidth * inputChannels; - const sequentialAccessByThreads = /* backend.adapterInfo.isIntel() */ true; // STEP.1: transpose weight @@ -214,14 +225,13 @@ const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attribut // STEP.2: prepare reshaped inputs const convInputs = [inputs[0], transposedWeight]; if (hasBias) { - if (!isChannelsLast && inputs[2].dims.length === 1) { - convInputs.push(inputs[2].reshape([inputs[2].dims[0], 1, 1])); - } else { - convInputs.push(inputs[2]); - } + convInputs.push(inputs[2]); } // STEP.3: compute matmul + const dimAOuter = isChannelsLast ? outHeight * outWidth : outChannels; + const dimBOuter = isChannelsLast ? outChannels : outHeight * outWidth; + const dimInner = weightHeight * weightWidth * inputChannels; context.compute( createConv2DMatMulProgramInfoLoader( convInputs, adjustedAttributes, outputShape, dimAOuter, dimBOuter, dimInner, hasBias, diff --git a/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts index 7dadf9a6205ea..647d588bb605e 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts @@ -17,11 +17,12 @@ const createMatmulProgramMetadata = (hasBias: boolean, cacheHint: string) => ({ export const createMatmulProgramInfoLoader = (inputs: readonly TensorView[], activationAttributes: InternalActivationAttributes, outputShape: readonly number[], - reshapedOutputShape?: readonly number[]): ProgramInfoLoader => { + reshapedOutputShape?: readonly number[], isChannelsLast = false): ProgramInfoLoader => { const metadata = createMatmulProgramMetadata(inputs.length > 2, activationAttributes.activationCacheKey); return { ...metadata, - get: () => createMatmulProgramInfo(metadata, inputs, activationAttributes, outputShape, reshapedOutputShape) + get: () => createMatmulProgramInfo( + metadata, inputs, activationAttributes, outputShape, reshapedOutputShape, isChannelsLast) }; }; diff --git a/js/web/test/data/ops/conv.jsonc b/js/web/test/data/ops/conv.jsonc index 928192bb219f2..219e15eb4648f 100644 --- a/js/web/test/data/ops/conv.jsonc +++ b/js/web/test/data/ops/conv.jsonc @@ -125,6 +125,42 @@ } ] }, + { + "name": "conv with bias addition C", + "operator": "Conv", + "inputShapeDefinitions": "rankOnly", + "opset": { "domain": "", "version": 17 }, + "attributes": [{ "name": "kernel_shape", "data": [2, 2], "type": "ints" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], + "dims": [3, 1, 2, 2], + "type": "float32" + }, + { + "data": [1, 1, 1, 1, 2, 3, 4, 5], + "dims": [2, 1, 2, 2], + "type": "float32" + }, + { + "data": [5, 6], + "dims": [2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [15, 46, 31, 102, 47, 158], + "dims": [3, 2, 1, 1], + "type": "float32" + } + ] + } + ] + }, { "name": "conv - group - A", "operator": "Conv",