diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts index fee872f4120e3..52bb490f454ea 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts @@ -358,7 +358,7 @@ const matMulReadWriteFnSource = resStr += `\naIndices[${i}] = 0;`; }); resStr += `\naIndices[${aRank - 2}] = u32(row); - aIndices[${aRank - 1}] = u32(colIn);`; + aIndices[${aRank - 1}] = u32(col);`; return resStr; }; const getBIndices = () => { @@ -372,7 +372,7 @@ const matMulReadWriteFnSource = resStr += `\nbIndices[${i}] = 0;`; }); resStr += `\nbIndices[${bRank - 2}] = u32(row); - bIndices[${bRank - 1}] = u32(colIn);`; + bIndices[${bRank - 1}] = u32(col);`; return resStr; }; const source = ` @@ -383,7 +383,8 @@ const matMulReadWriteFnSource = if(row < dimAOuter && col < dimInner) { ${getAIndices()} - value = ${aVariable.getByIndices('aIndices')}; + // TODO: recover to use getByIndices/setByIndices when dxc is supported. + value = a[${aVariable.indicesToOffset('aIndices')}${component === 1 ? '' : `/ ${component}`}]; } return value; } @@ -395,7 +396,7 @@ const matMulReadWriteFnSource = if(row < dimInner && col < dimBOuter) { ${getBIndices()} - value = ${bVariable.getByIndices('bIndices')}; + value = b[${bVariable.indicesToOffset('bIndices')}${component === 1 ? '' : `/ ${component}`}]; } return value; } @@ -404,10 +405,11 @@ const matMulReadWriteFnSource = let col = colIn * ${component}; if (row < dimAOuter && col < dimBOuter) { var value = valueIn; - let coords = vec3(batch, row, colIn); + let coords = vec3(batch, row, col); ${hasBias ? 'value = value + bias[colIn];' : ''} ${applyActivation} - ${outputVariable.setByIndices('vec3(coords)', 'value')} + result[${outputVariable.indicesToOffset('vec3(coords)')}${ + component === 1 ? '' : `/ ${component}`}] = value; } } `; @@ -446,10 +448,9 @@ export const createMatmulProgramInfo = ]; const components = isVec4 ? 4 : 1; - const A = inputVariable('a', inputs[0].dataType, [...outerDimsA, dimAOuter, dimInner / components], components); - const B = inputVariable('b', inputs[1].dataType, [...outerDimsB, dimInner, dimBOuter / components], components); - const output = - outputVariable('result', inputs[0].dataType, [batchSize, dimAOuter, dimBOuter / components], components); + const A = inputVariable('a', inputs[0].dataType, [...outerDimsA, dimAOuter, dimInner], components); + const B = inputVariable('b', inputs[1].dataType, [...outerDimsB, dimInner, dimBOuter], components); + const output = outputVariable('result', inputs[0].dataType, [batchSize, dimAOuter, dimBOuter], components); variables.push(A); variables.push(B); variables.push(output); @@ -457,7 +458,7 @@ export const createMatmulProgramInfo = const hasBias = inputs.length > 2; const declareFunctions = matMulReadWriteFnSource(component, hasBias, applyActivation, variables); if (hasBias) { - inputVariables.push(inputVariable('bias', inputs[2].dataType, [dimBOuter / components], components)); + inputVariables.push(inputVariable('bias', inputs[2].dataType, [dimBOuter], components)); } const getShaderSource = (shaderHelper: ShaderHelper) => ` const dimAOuter: i32 = ${dimAOuter};