From fa106942a7962e68f1659cd65f5a7cdb498b8c03 Mon Sep 17 00:00:00 2001
From: Xu Xing <xing.xu@intel.com>
Date: Thu, 23 Nov 2023 06:42:55 +0800
Subject: [PATCH] [js/webgpu] Refactor matmul conv to support uniforms for
 matmul (#18452)

This change refactored matmul/conv related programs to support shape
uniforms. Currently only matmul shape uniforms are fully enabled.
TODOs: add input dependencies for conv related programs, turn clipMax
and clipMin to uniforms.
---
 .../webgpu/ops/3rd-party/conv2d_mm_webgpu.ts  | 73 ++++++++--------
 .../ops/3rd-party/conv_backprop_mm_webgpu.ts  | 73 +++++++++-------
 .../jsep/webgpu/ops/3rd-party/conv_util.ts    |  6 +-
 .../ops/3rd-party/matmul_packed_webgpu.ts     | 87 ++++++++++++++-----
 js/web/lib/wasm/jsep/webgpu/ops/common.ts     | 33 +++++--
 5 files changed, 174 insertions(+), 98 deletions(-)

diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts
index 089e783d7e22f..22f942a0d9ab4 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts
@@ -21,9 +21,8 @@
 
 import {LOG_DEBUG} from '../../../log';
 import {TensorView} from '../../../tensor-view';
-import {ShapeUtil} from '../../../util';
-import {ProgramInfo} from '../../types';
-import {tensorTypeToWsglStorageType} from '../common';
+import {ProgramInfo, ProgramUniform} from '../../types';
+import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from '../common';
 import {ConvAttributes} from '../conv';
 import {getActivationSnippet} from '../fuse-utils';
 
@@ -50,9 +49,9 @@ const conv2dCommonSnippet =
       const getWSnippet = (innerElementSize: number) => {
         switch (innerElementSize) {
           case 1:
-            return 'return w[row * wShape[3] + colIn];';
+            return 'return w[row * i32(uniforms.w_shape[3]) + colIn];';
           case 4:
-            return 'return w[row * wShape[3] / 4 + colIn];';
+            return 'return w[row * i32(uniforms.w_shape[3]) / 4 + colIn];';
           default:
             throw new Error(`innerElementSize ${innerElementSize} is not supported.`);
         }
@@ -79,13 +78,13 @@ const conv2dCommonSnippet =
       col % outWidth);
     `;
 
-      const xHeight = isChannelsLast ? 'xShape[1]' : 'xShape[2]';
-      const xWidth = isChannelsLast ? 'xShape[2]' : 'xShape[3]';
+      const xHeight = isChannelsLast ? 'i32(uniforms.x_shape[1])' : 'i32(uniforms.x_shape[2])';
+      const xWidth = isChannelsLast ? 'i32(uniforms.x_shape[2])' : 'i32(uniforms.x_shape[3])';
       const row = isChannelsLast ? 'row' : 'col';
       const col = isChannelsLast ? 'col' : 'row';
       const readXSnippet = `
-    let inChannels = wShape[2];
-    let outWidth = ${isChannelsLast ? 'outShape[2]' : 'outShape[3]'};
+    let inChannels = i32(uniforms.w_shape[2]);
+    let outWidth = ${isChannelsLast ? 'i32(uniforms.result_shape[2])' : 'i32(uniforms.result_shape[3])'};
     let outRow = ${row} / outWidth;
     let outCol = ${row} % outWidth;
 
@@ -99,7 +98,7 @@ const conv2dCommonSnippet =
     // the 'same' padding type.
     if (xRow >= 0 && xRow < ${xHeight} && xCol >= 0 && xCol < ${xWidth}) {
       ${coordASnippet}
-      let xIndex = getIndexFromCoords4D(coord, xShape);
+      let xIndex = getIndexFromCoords4D(coord, vec4<i32>(uniforms.x_shape));
       ${getXSnippet(innerElementSizeX)}
     }
     return resData;`;
@@ -109,7 +108,7 @@ const conv2dCommonSnippet =
     ${readXSnippet}` :
                                                                 `
     let col = colIn * ${innerElementSizeX};
-    if (row < dimAOuter && col < dimInner) {
+    if (row < uniforms.dimAOuter && col < uniforms.dimInner) {
       ${readXSnippet}
     }
     return ${typeSnippet(innerElementSizeX, dataType)}(0.0);`) :
@@ -118,7 +117,7 @@ const conv2dCommonSnippet =
     ${readXSnippet}` :
                                                                 `
     let col = colIn * ${innerElementSizeX};
-    if (row < dimInner && col < dimBOuter) {
+    if (row < uniforms.dimInner && col < uniforms.dimBOuter) {
       ${readXSnippet}
     }
     return ${typeSnippet(innerElementSizeX, dataType)}(0.0);`);
@@ -143,10 +142,10 @@ const conv2dCommonSnippet =
 
     fn mm_write(batch: i32, row : i32, colIn : i32, valueIn : ${resType}) {
       let col = colIn * ${innerElementSize};
-      if (row < dimAOuter && col < dimBOuter)
+      if (row < uniforms.dimAOuter && col < uniforms.dimBOuter)
       {
       var value = valueIn;
-      let outWidth = ${isChannelsLast ? 'outShape[2]' : 'outShape[3]'};
+      let outWidth = ${isChannelsLast ? 'i32(uniforms.result_shape[2])' : 'i32(uniforms.result_shape[3])'};
       ${coordResSnippet}
       ${biasSnippet(addBias)}
       ${applyActivation}
@@ -194,10 +193,17 @@ export const createConv2DMatMulProgramInfo =
       const elementsSize = isVec4 ? [innerElementSize, 4, 4] : [1, 1, 1];
       const t = tensorTypeToWsglStorageType(inputs[0].dataType);
 
-      const declareInputs = [
-        `@group(0) @binding(0) var<storage, read> x: array<${isVec4 && innerElementSize === 4 ? `vec4<${t}>` : t}>;`,
-        `@group(0) @binding(1) var<storage, read> w: array<${isVec4 ? `vec4<${t}>` : t}>;`
-      ];
+      // TODO: support component 2, 3.
+      const components = isVec4 ? 4 : 1;
+      const programUniforms: ProgramUniform[] =
+          [{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}];
+      const x = inputVariable('x', inputs[0].dataType, inputs[0].dims.length, components);
+      const w = inputVariable('w', inputs[1].dataType, inputs[1].dims.length, components);
+      const inputVariables = [x, w];
+
+      programUniforms.push(...createTensorShapeVariables(inputs[0].dims));
+      programUniforms.push(...createTensorShapeVariables(inputs[1].dims));
+
       let declareFunctions = `
       fn setOutputAtIndex(flatIndex : i32, value : ${isVec4 ? `vec4<${t}>` : t}) {
         result[flatIndex] = ${isVec4 ? `vec4<${t}>` : t}(value);
@@ -207,41 +213,40 @@ export const createConv2DMatMulProgramInfo =
         setOutputAtIndex(flatIndex ${isVec4 ? '/ 4' : ''}, value);
       }`;
       if (hasBias) {
-        declareInputs.push(`@group(0) @binding(2) var<storage, read> bias: array<${isVec4 ? `vec4<${t}>` : t}>;`);
+        const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, components);
+        inputVariables.push(bias);
+
+        programUniforms.push(...createTensorShapeVariables(inputs[2].dims));
+
         declareFunctions += `
         fn getBiasByOutputCoords(coords : vec4<i32>) -> ${isVec4 ? `vec4<${t}>` : t} {
           return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}];
         }`;
       }
-
+      const output = outputVariable('result', inputs[0].dataType, outputShape.length, components);
+      programUniforms.push(...createTensorShapeVariables(outputShape));
       return {
         name: 'Conv2DMatMul',
         shaderCache: {hint: attributes.cacheKey},
         getRunData: () => ({
           outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
           dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]},
+          programUniforms,
         }),
-        getShaderSource: () => `
-        ${utilFunctions}
+        getShaderSource: (shaderHelper: ShaderHelper) => `
+        ${utilFunctions('uniforms.result_strides')}
         //struct Uniforms { xShape : vec4<i32>, wShape : vec4<i32>, outShape : vec4<i32>,
         //  outShapeStrides: vec3<i32>, filterDims : vec2<i32>, pad : vec2<i32>, stride : vec2<i32>,
         //  dilation : vec2<i32>, dimAOuter : i32, dimBOuter : i32, dimInner : i32 };
-        ${declareInputs.join('')}
-        @group(0) @binding(${declareInputs.length}) var<storage, read_write> result: array<${
-            isVec4 ? `vec4<${t}>` : t}>;
-        //@group(0) @binding(${declareInputs.length + 1}) var<uniform> uniforms: Uniforms;
-
-        const xShape : vec4<i32> = vec4<i32>(${inputs[0].dims.join(',')});
-        const wShape : vec4<i32> = vec4<i32>(${inputs[1].dims.join(',')});
-        const outShape : vec4<i32> = vec4<i32>(${outputShape.join(',')});
-        const outShapeStrides : vec3<i32> = vec3<i32>(${ShapeUtil.computeStrides(outputShape).slice(0, 3).join(',')});
+        ${
+            shaderHelper.registerUniform('dimAOuter', 'i32')
+                .registerUniform('dimBOuter', 'i32')
+                .registerUniform('dimInner', 'i32')
+                .declareVariables(...inputVariables, output)}
         const filterDims : vec2<i32> = vec2<i32>(${attributes.kernelShape[0]}, ${attributes.kernelShape[1]});
         const pad : vec2<i32> = vec2<i32>(${attributes.pads[0]}, ${attributes.pads[1]});
         const stride : vec2<i32> = vec2<i32>(${attributes.strides[0]}, ${attributes.strides[1]});
         const dilation : vec2<i32> = vec2<i32>(${attributes.dilations[0]}, ${attributes.dilations[1]});
-        const dimAOuter : i32 = ${dimAOuter};
-        const dimBOuter : i32 = ${dimBOuter};
-        const dimInner : i32 = ${dimInner};
         ${declareFunctions}
         ${
             conv2dCommonSnippet(
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts
index 85cf7bf87f52c..d425155857e14 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts
@@ -21,8 +21,8 @@
 
 import {LOG_DEBUG} from '../../../log';
 import {TensorView} from '../../../tensor-view';
-import {ShapeUtil} from '../../../util';
-import {ProgramInfo} from '../../types';
+import {ProgramInfo, ProgramUniform} from '../../types';
+import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from '../common';
 import {ConvTransposeAttributes} from '../conv-transpose';
 import {getActivationSnippet} from '../fuse-utils';
 
@@ -36,16 +36,16 @@ const conv2dTransposeCommonSnippet =
       const getWSnippet = (innerElementSize: number) => {
         switch (innerElementSize) {
           case 1:
-            return 'return W[getIndexFromCoords4D(coord, wShape)];';
+            return 'return w[getIndexFromCoords4D(coord, vec4<i32>(uniforms.w_shape))];';
           case 4:
             return `
             let coord1 = vec4<i32>(coordX, coordY, col + 1, rowInner);
             let coord2 = vec4<i32>(coordX, coordY, col + 2, rowInner);
             let coord3 = vec4<i32>(coordX, coordY, col + 3, rowInner);
-            let v0 = W[getIndexFromCoords4D(coord, wShape)];
-            let v1 = W[getIndexFromCoords4D(coord1, wShape)];
-            let v2 = W[getIndexFromCoords4D(coord2, wShape)];
-            let v3 = W[getIndexFromCoords4D(coord3, wShape)];
+            let v0 = w[getIndexFromCoords4D(coord, vec4<i32>(uniforms.w_shape))];
+            let v1 = w[getIndexFromCoords4D(coord1, vec4<i32>(uniforms.w_shape))];
+            let v2 = w[getIndexFromCoords4D(coord2, vec4<i32>(uniforms.w_shape))];
+            let v3 = w[getIndexFromCoords4D(coord3, vec4<i32>(uniforms.w_shape))];
             return vec4<f32>(v0, v1, v2, v3);
             `;
           default:
@@ -81,7 +81,7 @@ const conv2dTransposeCommonSnippet =
 
       const readASnippet = `
       let inChannels = ${isChannelsLast ? 'outBackprop[3]' : 'outBackprop[1]'};
-      let outWidth = ${isChannelsLast ? 'outShape[2]' : 'outShape[3]'};
+      let outWidth = ${isChannelsLast ? 'i32(uniforms.result_shape[2])' : 'i32(uniforms.result_shape[3])'};
       let outRow = ${row} / outWidth;
       let outCol = ${row} % outWidth;
 
@@ -99,17 +99,17 @@ const conv2dTransposeCommonSnippet =
       let iXC = i32(xC);
       let xCh = ${col} % inChannels;
       ${coordASnippet}
-      return x[getIndexFromCoords4D(coord, xShape)/${innerElementSize}];`;
+      return x[getIndexFromCoords4D(coord, vec4<i32>(uniforms.x_shape))/${innerElementSize}];`;
 
       const sampleA = isChannelsLast ? `
       let col = colIn * ${innerElementSize};
-      if (row < dimAOuter && col < dimInner) {
+      if (row < uniforms.dimAOuter && col < uniforms.dimInner) {
         ${readASnippet}
       }
       return ${type}(0.0);` :
                                        `
       let col = colIn * ${innerElementSize};
-      if (row < dimInner && col < dimBOuter) {
+      if (row < uniforms.dimInner && col < uniforms.dimBOuter) {
         ${readASnippet}
       }
       return ${type}(0.0);`;
@@ -120,8 +120,8 @@ const conv2dTransposeCommonSnippet =
       let coordX = filterDims.x - 1 - row / (filterDims[1] * inChannels);
       let coordY = filterDims.y - 1 - (row / inChannels) % filterDims[1];
       if (${
-          isChannelsLast ? 'row < dimInner && col < dimBOuter' :
-                           'row < dimInner && col < dimAOuter'}  && coordX >= 0 && coordY >= 0) {
+          isChannelsLast ? 'row < uniforms.dimInner && col < uniforms.dimBOuter' :
+                           'row < uniforms.dimInner && col < uniforms.dimAOuter'}  && coordX >= 0 && coordY >= 0) {
         let rowInner = row % inChannels;
         let coord = vec4<i32>(coordX, coordY, col, rowInner);
         ${getWSnippet(innerElementSize)}
@@ -142,13 +142,13 @@ const conv2dTransposeCommonSnippet =
 
   fn mm_write(batch: i32, row : i32, colIn : i32, valueInput : ${type}) {
     let col = colIn * ${innerElementSize};
-    if (row < dimAOuter && col < dimBOuter) {
+    if (row < uniforms.dimAOuter && col < uniforms.dimBOuter) {
       var value = valueInput;
-      let outWidth = ${isChannelsLast ? 'outShape[2]' : 'outShape[3]'};
+      let outWidth = ${isChannelsLast ? 'i32(uniforms.result_shape[2])' : 'i32(uniforms.result_shape[3])'};
       ${coordResSnippet}
       ${biasSnippet(addBias)}
       ${applyActivation}
-      result[getIndexFromCoords4D(coords, outShape)/${innerElementSize}] = value;
+      result[getIndexFromCoords4D(coords, vec4<i32>(uniforms.result_shape))/${innerElementSize}] = value;
     }
   }`;
       return userCode;
@@ -185,37 +185,46 @@ export const createConv2DTransposeMatMulProgramInfo =
 
       const innerElementSize = isVec4 ? 4 : 1;
       const tileInner = Math.max(workGroupSize[0] * innerElementSize, workGroupSize[1]);
+      const components = isVec4 ? 4 : 1;
+      const programUniforms: ProgramUniform[] =
+          [{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}];
+      const x = inputVariable('x', inputs[0].dataType, inputs[0].dims.length, components);
+      const w = inputVariable('w', inputs[1].dataType, inputs[1].dims.length, 1);
+      const output = outputVariable('result', inputs[0].dataType, outputShape.length, components);
+      const inputVariables = [x, w];
+      programUniforms.push(...createTensorShapeVariables(inputs[0].dims));
+      programUniforms.push(...createTensorShapeVariables(inputs[1].dims));
 
-
-      const declareInputs = [
-        `@group(0) @binding(0) var<storage, read> x: array<${isVec4 ? 'vec4<f32>' : 'f32'}>;`,
-        '@group(0) @binding(1) var<storage, read> W: array<f32>;'
-      ];
       let declareFunctions = '';
       if (hasBias) {
-        declareInputs.push(`@group(0) @binding(2) var<storage, read> bias: array<${isVec4 ? 'vec4<f32>' : 'f32'}>;`);
+        const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, components);
+        inputVariables.push(bias);
+        programUniforms.push(...createTensorShapeVariables(inputs[2].dims));
+
         declareFunctions += `
         fn getBiasByOutputCoords(coords : vec4<i32>) -> ${isVec4 ? 'vec4<f32>' : 'f32'} {
           return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}];
         }`;
       }
+
+      programUniforms.push(...createTensorShapeVariables(outputShape));
+
       return {
         name: 'Conv2DTransposeMatMul',
         shaderCache: {hint: attributes.cacheKey},
         getRunData: () => ({
           outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
-          dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]}
+          dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]},
+          programUniforms
         }),
-        getShaderSource: () => `
-        ${utilFunctions}
-        ${declareInputs.join('\n')}
-        @group(0) @binding(${declareInputs.length}) var<storage, read_write> result: array<${
-            isVec4 ? 'vec4<f32>' : 'f32'}>;
+        getShaderSource: (shaderHelper: ShaderHelper) => `
+        ${utilFunctions('uniforms.result_strides')}
+        ${
+            shaderHelper.registerUniform('dimAOuter', 'i32')
+                .registerUniform('dimBOuter', 'i32')
+                .registerUniform('dimInner', 'i32')
+                .declareVariables(...inputVariables, output)};
         const outBackprop : vec4<i32> = vec4<i32>(${inputs[0].dims.join(',')});
-        const xShape : vec4<i32> = vec4<i32>(${inputs[0].dims.join(',')});
-        const wShape : vec4<i32> = vec4<i32>(${inputs[1].dims.join(',')});
-        const outShape : vec4<i32> = vec4<i32>(${outputShape.join(',')});
-        const outShapeStrides : vec3<i32> = vec3<i32>(${ShapeUtil.computeStrides(outputShape).slice(0, 3).join(',')});
         const filterDims : vec2<i32> = vec2<i32>(${attributes.kernelShape[isChannelsLast ? 1 : 2]}, ${
             attributes.kernelShape[isChannelsLast ? 2 : 3]});
         const effectiveFilterDims : vec2<i32> = filterDims + vec2<i32>(
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_util.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_util.ts
index 0ba48a33fbc47..6f2c0231104dc 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_util.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_util.ts
@@ -19,13 +19,13 @@
 //
 // modified to fit the needs of the project
 
-export const utilFunctions = `
+export const utilFunctions = (strideStr: string) => (`
 fn getIndexFromCoords4D(coords : vec4<i32>, shape : vec4<i32>) -> i32 {
   return dot(coords, vec4<i32>(
       shape.y * shape.z * shape.w, shape.z * shape.w, shape.w, 1));
 }
 fn getOutputIndexFromCoords(coords : vec4<i32>) -> i32 {
   return dot(coords, vec4<i32>(
-    outShapeStrides.x, outShapeStrides.y, outShapeStrides.z, 1));
+    i32(${strideStr}.x), i32(${strideStr}.y), i32(${strideStr}.z), 1));
 }
-`;
+`);
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts
index 335de01c596b7..3e520571779e4 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts
@@ -21,8 +21,8 @@
 
 import {TensorView} from '../../../tensor-view';
 import {ShapeUtil} from '../../../util';
-import {ProgramInfo} from '../../types';
-import {getBroadcastDims, IndicesHelper, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from '../common';
+import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types';
+import {createTensorShapeVariables, enableShapesUniforms, getBroadcastDims, IndicesHelper, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from '../common';
 import {getActivationSnippet, InternalActivationAttributes} from '../fuse-utils';
 
 import {typeSnippet} from './activation_util';
@@ -112,7 +112,7 @@ fn main(@builtin(local_invocation_id) localId : vec3<u32>,
   ${batchDims ? `let batchIndices = ${batchDims.offsetToIndices('u32(batch)')};` : ''}
   let globalRowStart = i32(workgroupId.y) * ${tileAOuter};
 
-  let numTiles = ${splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : '(dimInner - 1) / tileInner + 1'};
+  let numTiles = ${splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : '(uniforms.dimInner - 1) / tileInner + 1'};
   var kStart = ${splitK ? `i32(globalId.z) * ${splitedDimInner}` : '0'};
 
   var acc: array<vec4<${type}>, rowPerThread>;
@@ -322,7 +322,7 @@ fn main(@builtin(local_invocation_id) localId : vec3<u32>,
         @builtin(workgroup_id) workgroupId : vec3<u32>) {
     let batch = ${splitK ? '0' : 'i32(globalId.z)'};
     ${batchDims ? `let batchIndices = ${batchDims.offsetToIndices('u32(batch)')};` : ''}
-    let numTiles = ${splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : '(dimInner - 1) / tileInner + 1'};
+    let numTiles = ${splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : '(uniforms.dimInner - 1) / tileInner + 1'};
     var kStart = ${splitK ? `i32(globalId.z) * ${splitedDimInner}` : '0'};
 
     var acc : array<array<${type}, colPerThread>, rowPerThread>;
@@ -384,7 +384,7 @@ const matMulReadWriteFnSource =
           typeSnippet(component, dataType)} {
       var value = ${typeSnippet(component, dataType)}(0.0);
       let col = colIn * ${component};
-      if(row < dimAOuter && col < dimInner)
+      if(row < uniforms.dimAOuter && col < uniforms.dimInner)
       {
         ${getAIndices()}
         value = ${aVariable.getByIndices('aIndices')};
@@ -396,7 +396,7 @@ const matMulReadWriteFnSource =
           typeSnippet(component, dataType)} {
       var value = ${typeSnippet(component, dataType)}(0.0);
       let col = colIn * ${component};
-      if(row < dimInner && col < dimBOuter)
+      if(row < uniforms.dimInner && col < uniforms.dimBOuter)
       {
         ${getBIndices()}
         value = ${bVariable.getByIndices('bIndices')};
@@ -406,7 +406,7 @@ const matMulReadWriteFnSource =
 
     fn mm_write(batch: i32, row: i32, colIn: i32, valueIn: ${typeSnippet(component, dataType)}) {
       let col = colIn * ${component};
-      if (row < dimAOuter && col < dimBOuter) {
+      if (row < uniforms.dimAOuter && col < uniforms.dimBOuter) {
         var value = valueIn;
         let coords = vec3<i32>(batch, row, colIn);
         ${
@@ -430,8 +430,11 @@ export const createMatmulProgramInfo =
 
       const outerDimsA = aShape.slice(0, -2);
       const outerDimsB = bShape.slice(0, -2);
+
       const outerDims = reshapedOutputShape ? reshapedOutputShape.slice(0, -2) : outputShape.slice(0, -2);
-      const batchDims = inputVariable('batchDims', inputs[0].dataType, outerDims);
+      const enableBatchUniforms = enableShapesUniforms(outerDims.length);
+      const batchShapeOrRank = enableBatchUniforms ? outerDims.length : outerDims;
+      const batchDims = inputVariable('batchDims', inputs[0].dataType, batchShapeOrRank, 1, true);
       const variables = [batchDims];
       const batchShapes = [outerDimsA, outerDimsB, outerDims];
       const batchSize = ShapeUtil.size(outerDims);
@@ -452,39 +455,81 @@ export const createMatmulProgramInfo =
 
       const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
       const components = isVec4 ? 4 : 1;
-      const A = inputVariable('a', inputs[0].dataType, [...outerDimsA, dimAOuter, dimInner / components], components);
-      const B = inputVariable('b', inputs[1].dataType, [...outerDimsB, dimInner, dimBOuter / components], components);
-      const output =
-          outputVariable('result', inputs[0].dataType, [batchSize, dimAOuter, dimBOuter / components], components);
+
+      const aShapeTemp = [...outerDimsA, dimAOuter, dimInner / components];
+      const enableAShapesUniforms = enableShapesUniforms(aShapeTemp.length);
+      const aShapeOrRank = enableAShapesUniforms ? aShapeTemp.length : aShapeTemp;
+
+      const bShapeTemp = [...outerDimsB, dimInner, dimBOuter / components];
+      const enableBShapesUniforms = enableShapesUniforms(bShapeTemp.length);
+      const bShapeOrRank = enableBShapesUniforms ? bShapeTemp.length : bShapeTemp;
+
+      const outputShapeTemp = [batchSize, dimAOuter, dimBOuter / components];
+
+      const A = inputVariable('a', inputs[0].dataType, aShapeOrRank, components);
+      const B = inputVariable('b', inputs[1].dataType, bShapeOrRank, components);
+      const output = outputVariable('result', inputs[0].dataType, outputShapeTemp.length, components);
       variables.push(A);
       variables.push(B);
       variables.push(output);
-      const inputVariables = [A, B];
+      const inputVariables = [batchDims, A, B];
+      const programUniforms: ProgramUniform[] =
+          [{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}];
+      if (enableBatchUniforms) {
+        programUniforms.push(...createTensorShapeVariables(outerDims));
+      }
+      if (enableAShapesUniforms) {
+        programUniforms.push(...createTensorShapeVariables(aShapeTemp));
+      }
+      if (enableBShapesUniforms) {
+        programUniforms.push(...createTensorShapeVariables(bShapeTemp));
+      }
+      const inputDependencies: ProgramInputTensorInfoDependency[] = [];
+      inputDependencies.push(enableAShapesUniforms ? 'rank' : 'dims');
+      inputDependencies.push(enableBShapesUniforms ? 'rank' : 'dims');
+
       const hasBias = inputs.length > 2;
       const {activationFunction, applyActivation} = getActivationSnippet(activationAttributes, output.type.value);
       const declareFunctions =
           matMulReadWriteFnSource(components, hasBias, applyActivation, variables, batchShapes, isChannelsLast);
       if (hasBias) {
         const biasComponents = isChannelsLast ? components : 1;
-        inputVariables.push(inputVariable('bias', inputs[2].dataType, inputs[2].dims, biasComponents));
+        inputVariables.push(inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, biasComponents));
+        programUniforms.push(...createTensorShapeVariables(inputs[2].dims));
+
+        inputDependencies.push('rank');
       }
+      programUniforms.push(...createTensorShapeVariables(outputShapeTemp));
+
       const getShaderSource = (shaderHelper: ShaderHelper) => `
-  const dimAOuter: i32 = ${dimAOuter};
-  const dimBOuter: i32 = ${dimBOuter};
-  const dimInner: i32 = ${dimInner};
-  ${shaderHelper.declareVariables(...inputVariables, output)}
+  ${
+          shaderHelper.registerUniform('dimAOuter', 'i32')
+              .registerUniform('dimBOuter', 'i32')
+              .registerUniform('dimInner', 'i32')
+              .declareVariables(...inputVariables, output)}
   ${activationFunction}
   ${declareFunctions}
   ${
           isVec4 ? makeMatMulPackedVec4Source(elementsPerThread, workgroupSize, dataType, batchDims) :
                    makeMatMulPackedSource(elementsPerThread, workgroupSize, dataType, batchDims)}
-                   ${batchDims.impl()}`;
+                   `;
+      // TODO: turn clipMax and clipMin to uniforms.
       return {
         name: 'MatMul',
-        shaderCache: {hint: activationAttributes.activationCacheKey},
+        shaderCache: {
+          hint: activationAttributes.activationCacheKey + `${elementsPerThread}` +
+              `${activationAttributes.activation}` +
+              `${activationAttributes.clipMax}` +
+              `${activationAttributes.clipMin}` +
+              `${isVec4}` +
+              `${hasBias}` +
+              `${isChannelsLast}`,
+          inputDependencies
+        },
         getRunData: () => ({
           outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
-          dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]}
+          dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]},
+          programUniforms
         }),
         getShaderSource,
       };
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts
index 014d9d02f6f10..f7ae18998b218 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts
@@ -210,6 +210,11 @@ export interface IndicesHelper {
    * a string representing the variable name for the strides of the input or output.
    */
   readonly strides: string;
+
+  /**
+   * representing variable with uniforms, but without binding.
+   */
+  readonly uniformOnly: boolean;
 }
 
 const getWgslMappedType = (type: number, components: 1|2|3|4): string|[string, string] => {
@@ -335,8 +340,8 @@ export const sumVector = (name: string, components: number) => {
  *    vec4.
  */
 const createIndicesHelper =
-    (name: string, tensorType: number, shapeOrRank: number|readonly number[], isInput: boolean,
-     components: 1|2|3|4): IndicesHelper => {
+    (name: string, tensorType: number, shapeOrRank: number|readonly number[], isInput: boolean, components: 1|2|3|4,
+     uniformOnly = false): IndicesHelper => {
       const useUniform = typeof shapeOrRank === 'number';
       const rank = useUniform ? shapeOrRank : shapeOrRank.length;
       const rankIdentity = [...new Array(rank).keys()];
@@ -358,7 +363,7 @@ const createIndicesHelper =
         getByIndices: false,
       };
 
-      const uniformPrefix = useUniform ? 'uniforms.' : '';
+      const uniformPrefix = useUniform || uniformOnly ? 'uniforms.' : '';
       const shape = `${uniformPrefix}${name}_shape`;
       const strides = `${uniformPrefix}${name}_strides`;
       let o2iSnippet = '';
@@ -616,7 +621,8 @@ const createIndicesHelper =
         name,
         strides,
         shape,
-        rank
+        rank,
+        uniformOnly
       };
     };
 
@@ -630,8 +636,8 @@ const createIndicesHelper =
  * @returns an IndicesHelper for the input.
  */
 export const inputVariable =
-    (name: string, type: number, shapeOrRank: number|readonly number[], components: 1|2|3|4 = 1): IndicesHelper =>
-        createIndicesHelper(name, type, shapeOrRank, true, components);
+    (name: string, type: number, shapeOrRank: number|readonly number[], components: 1|2|3|4 = 1, uniformOnly = false):
+        IndicesHelper => createIndicesHelper(name, type, shapeOrRank, true, components, uniformOnly);
 
 /**
  * Create a IndicesHelper for an output.
@@ -734,7 +740,7 @@ class ShaderHelperImpl implements ShaderHelper {
   `;
   }
 
-  private declareVariable(variable: IndicesHelper, bindingIndex: number): string {
+  private declareVariable(variable: IndicesHelper, bindingIndex = -1): string {
     this.indicesHelpers.push(variable);
     if (variable.rank !== 0) {
       if (variable.shape.startsWith('uniforms.')) {
@@ -744,13 +750,24 @@ class ShaderHelperImpl implements ShaderHelper {
         this.uniforms.push({name: variable.strides.replace('uniforms.', ''), type: variable.type.indices});
       }
     }
+    if (variable.uniformOnly) {
+      return '';
+    }
     const access = variable.usage === 'input' ? 'read' : 'read_write';
     const storageType = variable.type.storage;
     return `@group(0) @binding(${bindingIndex}) var<storage, ${access}> ${variable.name}: array<${storageType}>;`;
   }
 
   declareVariables(...variables: IndicesHelper[]): string {
-    return variables.map(v => this.declareVariable(v, this.variableIndex++)).join('\n');
+    return variables
+        .map(v => {
+          if (v.uniformOnly === true) {
+            return this.declareVariable(v);
+          } else {
+            return this.declareVariable(v, this.variableIndex++);
+          }
+        })
+        .join('\n');
   }
 
   registerUniform(name: string, type: string): ShaderHelper {