diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts index 089e783d7e22f..9ec527733deeb 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts @@ -22,8 +22,8 @@ import {LOG_DEBUG} from '../../../log'; import {TensorView} from '../../../tensor-view'; import {ShapeUtil} from '../../../util'; -import {ProgramInfo} from '../../types'; -import {tensorTypeToWsglStorageType} from '../common'; +import {ProgramInfo, ProgramUniform} from '../../types'; +import {createTensorShapeVariables, enableShapesUniforms, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from '../common'; import {ConvAttributes} from '../conv'; import {getActivationSnippet} from '../fuse-utils'; @@ -32,13 +32,13 @@ import {utilFunctions} from './conv_util'; import {makeMatMulPackedSource, makeMatMulPackedVec4Source} from './matmul_packed_webgpu'; const conv2dCommonSnippet = - (isChannelsLast: boolean, fitAOuter: boolean, fitBOuter: boolean, fitInner: boolean, addBias = false, - attributes: ConvAttributes, innerElementSizeX = 4, innerElementSizeW = 4, innerElementSize = 4, - dataType = 'f32'): string => { + (xShapeStr: string, wShapeStr: string, outputShapeStr: string, isChannelsLast: boolean, fitAOuter: boolean, + fitBOuter: boolean, fitInner: boolean, addBias = false, attributes: ConvAttributes, innerElementSizeX = 4, + innerElementSizeW = 4, innerElementSize = 4, dataType = 'f32'): string => { const getXSnippet = (innerElementSize: number) => { switch (innerElementSize) { case 1: - return 'resData = x[xIndex];'; + return `resData = x[xIndex];`; case 3: return `resData = vec3<${dataType}>(x[xIndex], x[xIndex + 1], x[xIndex + 2]);`; case 4: @@ -50,9 +50,9 @@ const conv2dCommonSnippet = const getWSnippet = (innerElementSize: number) => { switch (innerElementSize) { case 1: - return 'return w[row * wShape[3] + colIn];'; + return `return w[row * i32(${wShapeStr}[3]) + colIn];`; case 4: - return 'return w[row * wShape[3] / 4 + colIn];'; + return `return w[row * i32(${wShapeStr}[3]) / 4 + colIn];`; default: throw new Error(`innerElementSize ${innerElementSize} is not supported.`); } @@ -79,13 +79,13 @@ const conv2dCommonSnippet = col % outWidth); `; - const xHeight = isChannelsLast ? 'xShape[1]' : 'xShape[2]'; - const xWidth = isChannelsLast ? 'xShape[2]' : 'xShape[3]'; + const xHeight = isChannelsLast ? `i32(${xShapeStr}[1])` : `i32(${xShapeStr}[2])`; + const xWidth = isChannelsLast ? `i32(${xShapeStr}[2])` : `i32(${xShapeStr}[3])`; const row = isChannelsLast ? 'row' : 'col'; const col = isChannelsLast ? 'col' : 'row'; const readXSnippet = ` - let inChannels = wShape[2]; - let outWidth = ${isChannelsLast ? 'outShape[2]' : 'outShape[3]'}; + let inChannels = i32(${wShapeStr}[2]); + let outWidth = ${isChannelsLast ? `i32(${outputShapeStr}[2])` : `i32(${outputShapeStr}[3])`}; let outRow = ${row} / outWidth; let outCol = ${row} % outWidth; @@ -99,7 +99,7 @@ const conv2dCommonSnippet = // the 'same' padding type. if (xRow >= 0 && xRow < ${xHeight} && xCol >= 0 && xCol < ${xWidth}) { ${coordASnippet} - let xIndex = getIndexFromCoords4D(coord, xShape); + let xIndex = getIndexFromCoords4D(coord, vec4(${xShapeStr})); ${getXSnippet(innerElementSizeX)} } return resData;`; @@ -109,7 +109,7 @@ const conv2dCommonSnippet = ${readXSnippet}` : ` let col = colIn * ${innerElementSizeX}; - if (row < dimAOuter && col < dimInner) { + if (row < uniforms.dimAOuter && col < uniforms.dimInner) { ${readXSnippet} } return ${typeSnippet(innerElementSizeX, dataType)}(0.0);`) : @@ -118,7 +118,7 @@ const conv2dCommonSnippet = ${readXSnippet}` : ` let col = colIn * ${innerElementSizeX}; - if (row < dimInner && col < dimBOuter) { + if (row < uniforms.dimInner && col < uniforms.dimBOuter) { ${readXSnippet} } return ${typeSnippet(innerElementSizeX, dataType)}(0.0);`); @@ -143,10 +143,10 @@ const conv2dCommonSnippet = fn mm_write(batch: i32, row : i32, colIn : i32, valueIn : ${resType}) { let col = colIn * ${innerElementSize}; - if (row < dimAOuter && col < dimBOuter) + if (row < uniforms.dimAOuter && col < uniforms.dimBOuter) { var value = valueIn; - let outWidth = ${isChannelsLast ? 'outShape[2]' : 'outShape[3]'}; + let outWidth = ${isChannelsLast ? `i32(${outputShapeStr}[2])` : `i32(${outputShapeStr}[3])`}; ${coordResSnippet} ${biasSnippet(addBias)} ${applyActivation} @@ -194,10 +194,29 @@ export const createConv2DMatMulProgramInfo = const elementsSize = isVec4 ? [innerElementSize, 4, 4] : [1, 1, 1]; const t = tensorTypeToWsglStorageType(inputs[0].dataType); - const declareInputs = [ - `@group(0) @binding(0) var x: array<${isVec4 && innerElementSize === 4 ? `vec4<${t}>` : t}>;`, - `@group(0) @binding(1) var w: array<${isVec4 ? `vec4<${t}>` : t}>;` - ]; + const components = isVec4 ? 4 : 1; + const enableXShapesUniforms = enableShapesUniforms(inputs[0].dims.length); + const xShapeOrRank = enableXShapesUniforms ? inputs[0].dims.length : inputs[0].dims; + + const enableWShapesUniforms = enableShapesUniforms(inputs[1].dims.length); + const wShapeOrRank = enableWShapesUniforms ? inputs[1].dims.length : inputs[1].dims; + + const enableOutputShapesUniforms = enableShapesUniforms(outputShape.length); + const outputShapeOrRank = enableOutputShapesUniforms ? outputShape.length : outputShape; + + const programUniforms: ProgramUniform[] = + [{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}]; + const x = inputVariable('x', inputs[0].dataType, xShapeOrRank, components); + const w = inputVariable('w', inputs[1].dataType, wShapeOrRank, components); + const inputVariables = [x, w]; + + if (enableXShapesUniforms) { + programUniforms.push(...createTensorShapeVariables(inputs[0].dims)); + } + if (enableWShapesUniforms) { + programUniforms.push(...createTensorShapeVariables(inputs[1].dims)); + } + let declareFunctions = ` fn setOutputAtIndex(flatIndex : i32, value : ${isVec4 ? `vec4<${t}>` : t}) { result[flatIndex] = ${isVec4 ? `vec4<${t}>` : t}(value); @@ -207,46 +226,52 @@ export const createConv2DMatMulProgramInfo = setOutputAtIndex(flatIndex ${isVec4 ? '/ 4' : ''}, value); }`; if (hasBias) { - declareInputs.push(`@group(0) @binding(2) var bias: array<${isVec4 ? `vec4<${t}>` : t}>;`); + const enableBiasShapesUniforms = enableShapesUniforms(inputs[2].dims.length); + const biasShapeOrRank = enableBiasShapesUniforms ? inputs[2].dims.length : inputs[2].dims; + const bias = inputVariable('bias', inputs[2].dataType, biasShapeOrRank, components); + inputVariables.push(bias); + if (enableBiasShapesUniforms) { + programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); + } declareFunctions += ` fn getBiasByOutputCoords(coords : vec4) -> ${isVec4 ? `vec4<${t}>` : t} { return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}]; }`; } - + const xShapeStr = enableXShapesUniforms ? 'uniforms.x_shape' : 'x_shape'; + const wShapeStr = enableWShapesUniforms ? 'uniforms.w_shape' : 'w_shape'; + const outputShapeStr = enableOutputShapesUniforms ? 'uniforms.result_shape' : 'result_shape'; + const output = outputVariable('result', inputs[0].dataType, outputShapeOrRank, components); + if (enableOutputShapesUniforms) { + programUniforms.push(...createTensorShapeVariables(outputShape)); + } return { name: 'Conv2DMatMul', shaderCache: {hint: attributes.cacheKey}, getRunData: () => ({ outputs: [{dims: outputShape, dataType: inputs[0].dataType}], dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]}, + programUniforms, }), - getShaderSource: () => ` - ${utilFunctions} + getShaderSource: (shaderHelper: ShaderHelper) => ` + ${utilFunctions(enableOutputShapesUniforms ? 'uniforms.result_strides' : 'result_strides')} //struct Uniforms { xShape : vec4, wShape : vec4, outShape : vec4, // outShapeStrides: vec3, filterDims : vec2, pad : vec2, stride : vec2, // dilation : vec2, dimAOuter : i32, dimBOuter : i32, dimInner : i32 }; - ${declareInputs.join('')} - @group(0) @binding(${declareInputs.length}) var result: array<${ - isVec4 ? `vec4<${t}>` : t}>; - //@group(0) @binding(${declareInputs.length + 1}) var uniforms: Uniforms; - - const xShape : vec4 = vec4(${inputs[0].dims.join(',')}); - const wShape : vec4 = vec4(${inputs[1].dims.join(',')}); - const outShape : vec4 = vec4(${outputShape.join(',')}); - const outShapeStrides : vec3 = vec3(${ShapeUtil.computeStrides(outputShape).slice(0, 3).join(',')}); + ${ + shaderHelper.registerUniform('dimAOuter', 'i32') + .registerUniform('dimBOuter', 'i32') + .registerUniform('dimInner', 'i32') + .declareVariables(...inputVariables, output)} const filterDims : vec2 = vec2(${attributes.kernelShape[0]}, ${attributes.kernelShape[1]}); const pad : vec2 = vec2(${attributes.pads[0]}, ${attributes.pads[1]}); const stride : vec2 = vec2(${attributes.strides[0]}, ${attributes.strides[1]}); const dilation : vec2 = vec2(${attributes.dilations[0]}, ${attributes.dilations[1]}); - const dimAOuter : i32 = ${dimAOuter}; - const dimBOuter : i32 = ${dimBOuter}; - const dimInner : i32 = ${dimInner}; ${declareFunctions} ${ conv2dCommonSnippet( - isChannelsLast, fitAOuter, fitBOuter, fitInner, hasBias, attributes, elementsSize[0], elementsSize[1], - elementsSize[2], t)} + xShapeStr, wShapeStr, outputShapeStr, isChannelsLast, fitAOuter, fitBOuter, fitInner, hasBias, + attributes, elementsSize[0], elementsSize[1], elementsSize[2], t)} ${ isVec4 ? makeMatMulPackedVec4Source(elementsPerThread, workGroupSize, t, undefined, !isChannelsLast, tileInner) : diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts index 85cf7bf87f52c..408b0d6ba6310 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts @@ -21,8 +21,8 @@ import {LOG_DEBUG} from '../../../log'; import {TensorView} from '../../../tensor-view'; -import {ShapeUtil} from '../../../util'; -import {ProgramInfo} from '../../types'; +import {ProgramInfo, ProgramUniform} from '../../types'; +import {createTensorShapeVariables, enableShapesUniforms, inputVariable, outputVariable, ShaderHelper} from '../common'; import {ConvTransposeAttributes} from '../conv-transpose'; import {getActivationSnippet} from '../fuse-utils'; @@ -31,21 +31,22 @@ import {utilFunctions} from './conv_util'; import {makeMatMulPackedSource, makeMatMulPackedVec4Source} from './matmul_packed_webgpu'; const conv2dTransposeCommonSnippet = - (isChannelsLast: boolean, addBias = false, attributes: ConvTransposeAttributes, innerElementSize = 4): string => { + (xShapeStr: string, wShapeStr: string, outputShapeStr: string, isChannelsLast: boolean, addBias = false, + attributes: ConvTransposeAttributes, innerElementSize = 4): string => { const type = typeSnippet(innerElementSize, 'f32'); const getWSnippet = (innerElementSize: number) => { switch (innerElementSize) { case 1: - return 'return W[getIndexFromCoords4D(coord, wShape)];'; + return `return w[getIndexFromCoords4D(coord, vec4(${wShapeStr}))];`; case 4: return ` let coord1 = vec4(coordX, coordY, col + 1, rowInner); let coord2 = vec4(coordX, coordY, col + 2, rowInner); let coord3 = vec4(coordX, coordY, col + 3, rowInner); - let v0 = W[getIndexFromCoords4D(coord, wShape)]; - let v1 = W[getIndexFromCoords4D(coord1, wShape)]; - let v2 = W[getIndexFromCoords4D(coord2, wShape)]; - let v3 = W[getIndexFromCoords4D(coord3, wShape)]; + let v0 = w[getIndexFromCoords4D(coord, vec4(${wShapeStr}))]; + let v1 = w[getIndexFromCoords4D(coord1, vec4(${wShapeStr}))]; + let v2 = w[getIndexFromCoords4D(coord2, vec4(${wShapeStr}))]; + let v3 = w[getIndexFromCoords4D(coord3, vec4(${wShapeStr}))]; return vec4(v0, v1, v2, v3); `; default: @@ -81,7 +82,7 @@ const conv2dTransposeCommonSnippet = const readASnippet = ` let inChannels = ${isChannelsLast ? 'outBackprop[3]' : 'outBackprop[1]'}; - let outWidth = ${isChannelsLast ? 'outShape[2]' : 'outShape[3]'}; + let outWidth = ${isChannelsLast ? `i32(${outputShapeStr}[2])` : `i32(${outputShapeStr}[3])`}; let outRow = ${row} / outWidth; let outCol = ${row} % outWidth; @@ -99,17 +100,17 @@ const conv2dTransposeCommonSnippet = let iXC = i32(xC); let xCh = ${col} % inChannels; ${coordASnippet} - return x[getIndexFromCoords4D(coord, xShape)/${innerElementSize}];`; + return x[getIndexFromCoords4D(coord, vec4(${xShapeStr}))/${innerElementSize}];`; const sampleA = isChannelsLast ? ` let col = colIn * ${innerElementSize}; - if (row < dimAOuter && col < dimInner) { + if (row < uniforms.dimAOuter && col < uniforms.dimInner) { ${readASnippet} } return ${type}(0.0);` : ` let col = colIn * ${innerElementSize}; - if (row < dimInner && col < dimBOuter) { + if (row < uniforms.dimInner && col < uniforms.dimBOuter) { ${readASnippet} } return ${type}(0.0);`; @@ -120,8 +121,8 @@ const conv2dTransposeCommonSnippet = let coordX = filterDims.x - 1 - row / (filterDims[1] * inChannels); let coordY = filterDims.y - 1 - (row / inChannels) % filterDims[1]; if (${ - isChannelsLast ? 'row < dimInner && col < dimBOuter' : - 'row < dimInner && col < dimAOuter'} && coordX >= 0 && coordY >= 0) { + isChannelsLast ? 'row < uniforms.dimInner && col < uniforms.dimBOuter' : + 'row < uniforms.dimInner && col < uniforms.dimAOuter'} && coordX >= 0 && coordY >= 0) { let rowInner = row % inChannels; let coord = vec4(coordX, coordY, col, rowInner); ${getWSnippet(innerElementSize)} @@ -142,13 +143,13 @@ const conv2dTransposeCommonSnippet = fn mm_write(batch: i32, row : i32, colIn : i32, valueInput : ${type}) { let col = colIn * ${innerElementSize}; - if (row < dimAOuter && col < dimBOuter) { + if (row < uniforms.dimAOuter && col < uniforms.dimBOuter) { var value = valueInput; - let outWidth = ${isChannelsLast ? 'outShape[2]' : 'outShape[3]'}; + let outWidth = ${isChannelsLast ? `i32(${outputShapeStr}[2])` : `i32(${outputShapeStr}[3])`}; ${coordResSnippet} ${biasSnippet(addBias)} ${applyActivation} - result[getIndexFromCoords4D(coords, outShape)/${innerElementSize}] = value; + result[getIndexFromCoords4D(coords, vec4(${outputShapeStr}))/${innerElementSize}] = value; } }`; return userCode; @@ -185,37 +186,65 @@ export const createConv2DTransposeMatMulProgramInfo = const innerElementSize = isVec4 ? 4 : 1; const tileInner = Math.max(workGroupSize[0] * innerElementSize, workGroupSize[1]); + const components = isVec4 ? 4 : 1; + const programUniforms: ProgramUniform[] = + [{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}]; + const enableXShapesUniforms = enableShapesUniforms(inputs[0].dims.length); + const xShapeOrRank = enableXShapesUniforms ? inputs[0].dims.length : inputs[0].dims; + const enableWShapesUniforms = enableShapesUniforms(inputs[1].dims.length); + const wShapeOrRank = enableWShapesUniforms ? inputs[1].dims.length : inputs[1].dims; + const x = inputVariable('x', inputs[0].dataType, xShapeOrRank, components); + const w = inputVariable('w', inputs[1].dataType, wShapeOrRank, 1); + const enableOutputShapesUniforms = enableShapesUniforms(outputShape.length); + const outputShapeOrRank = enableOutputShapesUniforms ? outputShape.length : outputShape; + const output = outputVariable('result', inputs[0].dataType, outputShapeOrRank, components); + const inputVariables = [x, w]; + if (enableXShapesUniforms) { + programUniforms.push(...createTensorShapeVariables(inputs[0].dims)); + } + if (enableWShapesUniforms) { + programUniforms.push(...createTensorShapeVariables(inputs[1].dims)); + } - - const declareInputs = [ - `@group(0) @binding(0) var x: array<${isVec4 ? 'vec4' : 'f32'}>;`, - '@group(0) @binding(1) var W: array;' - ]; let declareFunctions = ''; if (hasBias) { - declareInputs.push(`@group(0) @binding(2) var bias: array<${isVec4 ? 'vec4' : 'f32'}>;`); + const enableBiasShapesUniforms = enableShapesUniforms(inputs[2].dims.length); + const biasShapeOrRank = enableBiasShapesUniforms ? inputs[2].dims.length : inputs[2].dims; + const bias = inputVariable('bias', inputs[2].dataType, biasShapeOrRank, components); + inputVariables.push(bias); + if (enableBiasShapesUniforms) { + programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); + } declareFunctions += ` fn getBiasByOutputCoords(coords : vec4) -> ${isVec4 ? 'vec4' : 'f32'} { return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}]; }`; } + + if (enableOutputShapesUniforms) { + programUniforms.push(...createTensorShapeVariables(outputShape)); + } + + const xShapeStr = enableXShapesUniforms ? 'uniforms.x_shape' : 'x_shape'; + const wShapeStr = enableWShapesUniforms ? 'uniforms.w_shape' : 'w_shape'; + const outputShapeStr = enableOutputShapesUniforms ? 'uniforms.result_shape' : 'result_shape'; + return { name: 'Conv2DTransposeMatMul', shaderCache: {hint: attributes.cacheKey}, getRunData: () => ({ outputs: [{dims: outputShape, dataType: inputs[0].dataType}], - dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]} + dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]}, + programUniforms }), - getShaderSource: () => ` - ${utilFunctions} - ${declareInputs.join('\n')} - @group(0) @binding(${declareInputs.length}) var result: array<${ - isVec4 ? 'vec4' : 'f32'}>; + getShaderSource: (shaderHelper: ShaderHelper) => ` + ${utilFunctions(enableOutputShapesUniforms ? 'uniforms.result_strides' : 'result_strides')} + ${ + shaderHelper.registerUniform('dimAOuter', 'i32') + .registerUniform('dimBOuter', 'i32') + .registerUniform('dimInner', 'i32') + .declareVariables(...inputVariables, output)}; const outBackprop : vec4 = vec4(${inputs[0].dims.join(',')}); - const xShape : vec4 = vec4(${inputs[0].dims.join(',')}); - const wShape : vec4 = vec4(${inputs[1].dims.join(',')}); - const outShape : vec4 = vec4(${outputShape.join(',')}); - const outShapeStrides : vec3 = vec3(${ShapeUtil.computeStrides(outputShape).slice(0, 3).join(',')}); const filterDims : vec2 = vec2(${attributes.kernelShape[isChannelsLast ? 1 : 2]}, ${ attributes.kernelShape[isChannelsLast ? 2 : 3]}); const effectiveFilterDims : vec2 = filterDims + vec2( @@ -237,7 +266,9 @@ export const createConv2DTransposeMatMulProgramInfo = const dimBOuter : i32 = ${dimBOuter}; const dimInner : i32 = ${dimInner}; ${declareFunctions} - ${conv2dTransposeCommonSnippet(isChannelsLast, hasBias, attributes, innerElementSize)} + ${ + conv2dTransposeCommonSnippet( + xShapeStr, wShapeStr, outputShapeStr, isChannelsLast, hasBias, attributes, innerElementSize)} ${ isVec4 ? makeMatMulPackedVec4Source( elementsPerThread, workGroupSize, 'f32', undefined, !isChannelsLast, tileInner) : diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_util.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_util.ts index 0ba48a33fbc47..2220069a51482 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_util.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_util.ts @@ -19,13 +19,15 @@ // // modified to fit the needs of the project -export const utilFunctions = ` +export const utilFunctions = (strideStr: string) => { + return ` fn getIndexFromCoords4D(coords : vec4, shape : vec4) -> i32 { return dot(coords, vec4( shape.y * shape.z * shape.w, shape.z * shape.w, shape.w, 1)); } fn getOutputIndexFromCoords(coords : vec4) -> i32 { return dot(coords, vec4( - outShapeStrides.x, outShapeStrides.y, outShapeStrides.z, 1)); + i32(${strideStr}.x), i32(${strideStr}.y), i32(${strideStr}.z), 1)); } `; +}; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts index 335de01c596b7..bbb24024e1405 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts @@ -21,8 +21,8 @@ import {TensorView} from '../../../tensor-view'; import {ShapeUtil} from '../../../util'; -import {ProgramInfo} from '../../types'; -import {getBroadcastDims, IndicesHelper, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from '../common'; +import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types'; +import {createTensorShapeVariables, enableShapesUniforms, getBroadcastDims, IndicesHelper, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from '../common'; import {getActivationSnippet, InternalActivationAttributes} from '../fuse-utils'; import {typeSnippet} from './activation_util'; @@ -112,7 +112,7 @@ fn main(@builtin(local_invocation_id) localId : vec3, ${batchDims ? `let batchIndices = ${batchDims.offsetToIndices('u32(batch)')};` : ''} let globalRowStart = i32(workgroupId.y) * ${tileAOuter}; - let numTiles = ${splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : '(dimInner - 1) / tileInner + 1'}; + let numTiles = ${splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : '(uniforms.dimInner - 1) / tileInner + 1'}; var kStart = ${splitK ? `i32(globalId.z) * ${splitedDimInner}` : '0'}; var acc: array, rowPerThread>; @@ -322,7 +322,7 @@ fn main(@builtin(local_invocation_id) localId : vec3, @builtin(workgroup_id) workgroupId : vec3) { let batch = ${splitK ? '0' : 'i32(globalId.z)'}; ${batchDims ? `let batchIndices = ${batchDims.offsetToIndices('u32(batch)')};` : ''} - let numTiles = ${splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : '(dimInner - 1) / tileInner + 1'}; + let numTiles = ${splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : '(uniforms.dimInner - 1) / tileInner + 1'}; var kStart = ${splitK ? `i32(globalId.z) * ${splitedDimInner}` : '0'}; var acc : array, rowPerThread>; @@ -384,7 +384,7 @@ const matMulReadWriteFnSource = typeSnippet(component, dataType)} { var value = ${typeSnippet(component, dataType)}(0.0); let col = colIn * ${component}; - if(row < dimAOuter && col < dimInner) + if(row < uniforms.dimAOuter && col < uniforms.dimInner) { ${getAIndices()} value = ${aVariable.getByIndices('aIndices')}; @@ -396,7 +396,7 @@ const matMulReadWriteFnSource = typeSnippet(component, dataType)} { var value = ${typeSnippet(component, dataType)}(0.0); let col = colIn * ${component}; - if(row < dimInner && col < dimBOuter) + if(row < uniforms.dimInner && col < uniforms.dimBOuter) { ${getBIndices()} value = ${bVariable.getByIndices('bIndices')}; @@ -406,13 +406,13 @@ const matMulReadWriteFnSource = fn mm_write(batch: i32, row: i32, colIn: i32, valueIn: ${typeSnippet(component, dataType)}) { let col = colIn * ${component}; - if (row < dimAOuter && col < dimBOuter) { + if (row < uniforms.dimAOuter && col < uniforms.dimBOuter) { var value = valueIn; let coords = vec3(batch, row, colIn); ${ hasBias ? `value = value + ${isChannelsLast ? 'bias[colIn]' : `${typeSnippet(component, dataType)}(bias[row])`};` : - '' } + ''} ${applyActivation} ${outputVariable.setByIndices('vec3(coords)', 'value')} } @@ -430,8 +430,11 @@ export const createMatmulProgramInfo = const outerDimsA = aShape.slice(0, -2); const outerDimsB = bShape.slice(0, -2); + const outerDims = reshapedOutputShape ? reshapedOutputShape.slice(0, -2) : outputShape.slice(0, -2); - const batchDims = inputVariable('batchDims', inputs[0].dataType, outerDims); + const enableBatchUniforms = enableShapesUniforms(outerDims.length); + const batchShapeOrRank = enableBatchUniforms ? outerDims.length : outerDims; + const batchDims = inputVariable('batchDims', inputs[0].dataType, batchShapeOrRank, 1, true); const variables = [batchDims]; const batchShapes = [outerDimsA, outerDimsB, outerDims]; const batchSize = ShapeUtil.size(outerDims); @@ -452,39 +455,89 @@ export const createMatmulProgramInfo = const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); const components = isVec4 ? 4 : 1; - const A = inputVariable('a', inputs[0].dataType, [...outerDimsA, dimAOuter, dimInner / components], components); - const B = inputVariable('b', inputs[1].dataType, [...outerDimsB, dimInner, dimBOuter / components], components); - const output = - outputVariable('result', inputs[0].dataType, [batchSize, dimAOuter, dimBOuter / components], components); + + const aShapeTemp = [...outerDimsA, dimAOuter, dimInner / components]; + const enableAShapesUniforms = enableShapesUniforms(aShapeTemp.length); + const aShapeOrRank = enableAShapesUniforms ? aShapeTemp.length : aShapeTemp; + + const bShapeTemp = [...outerDimsB, dimInner, dimBOuter / components]; + const enableBShapesUniforms = enableShapesUniforms(bShapeTemp.length); + const bShapeOrRank = enableBShapesUniforms ? bShapeTemp.length : bShapeTemp; + + const outputShapeTemp = [batchSize, dimAOuter, dimBOuter / components]; + const enableOutputShapesUniforms = enableShapesUniforms(outputShapeTemp.length); + const outputShapeOrRank = enableOutputShapesUniforms ? outputShapeTemp.length : outputShapeTemp; + + const A = inputVariable('a', inputs[0].dataType, aShapeOrRank, components); + const B = inputVariable('b', inputs[1].dataType, bShapeOrRank, components); + const output = outputVariable('result', inputs[0].dataType, outputShapeOrRank, components); variables.push(A); variables.push(B); variables.push(output); - const inputVariables = [A, B]; + const inputVariables = [batchDims, A, B]; + const programUniforms: ProgramUniform[] = + [{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}]; + if (enableBatchUniforms) { + programUniforms.push(...createTensorShapeVariables(outerDims)); + } + if (enableAShapesUniforms) { + programUniforms.push(...createTensorShapeVariables(aShapeTemp)); + } + if (enableBShapesUniforms) { + programUniforms.push(...createTensorShapeVariables(bShapeTemp)); + } + const inputDependencies: ProgramInputTensorInfoDependency[] = []; + inputDependencies.push(enableAShapesUniforms ? 'rank' : 'dims'); + inputDependencies.push(enableBShapesUniforms ? 'rank' : 'dims'); + const hasBias = inputs.length > 2; const {activationFunction, applyActivation} = getActivationSnippet(activationAttributes, output.type.value); const declareFunctions = matMulReadWriteFnSource(components, hasBias, applyActivation, variables, batchShapes, isChannelsLast); if (hasBias) { + const enableBiasShapesUniforms = enableShapesUniforms(inputs[2].dims.length); + const biasShapeOrRank = enableBShapesUniforms ? inputs[2].dims.length : inputs[2].dims; + const biasComponents = isChannelsLast ? components : 1; - inputVariables.push(inputVariable('bias', inputs[2].dataType, inputs[2].dims, biasComponents)); + inputVariables.push(inputVariable('bias', inputs[2].dataType, biasShapeOrRank, biasComponents)); + if (enableBiasShapesUniforms) { + programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); + } + inputDependencies.push(enableBiasShapesUniforms ? 'rank' : 'dims'); } + if (enableOutputShapesUniforms) { + programUniforms.push(...createTensorShapeVariables(outputShapeTemp)); + } + const getShaderSource = (shaderHelper: ShaderHelper) => ` - const dimAOuter: i32 = ${dimAOuter}; - const dimBOuter: i32 = ${dimBOuter}; - const dimInner: i32 = ${dimInner}; - ${shaderHelper.declareVariables(...inputVariables, output)} + ${ + shaderHelper.registerUniform('dimAOuter', 'i32') + .registerUniform('dimBOuter', 'i32') + .registerUniform('dimInner', 'i32') + .declareVariables(...inputVariables, output)} ${activationFunction} ${declareFunctions} ${ isVec4 ? makeMatMulPackedVec4Source(elementsPerThread, workgroupSize, dataType, batchDims) : makeMatMulPackedSource(elementsPerThread, workgroupSize, dataType, batchDims)} - ${batchDims.impl()}`; + `; + // TODO: turn clipMax and clipMin to uniforms. return { name: 'MatMul', - shaderCache: {hint: activationAttributes.activationCacheKey}, + shaderCache: { + hint: activationAttributes.activationCacheKey + `${elementsPerThread}` + + `${activationAttributes.activation}` + + `${activationAttributes.clipMax}` + + `${activationAttributes.clipMin}` + + `${isVec4}` + + `${hasBias}` + + `${isChannelsLast}`, + inputDependencies + }, getRunData: () => ({ outputs: [{dims: outputShape, dataType: inputs[0].dataType}], - dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]} + dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]}, + programUniforms }), getShaderSource, }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts index 38dc14f23682e..e3d39f33fa49f 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts @@ -210,6 +210,11 @@ export interface IndicesHelper { * a string representing the variable name for the strides of the input or output. */ readonly strides: string; + + /** + * representing variable with uniforms, but without binding. + */ + readonly uniformOnly: boolean; } const getWgslMappedType = (type: number, components: 1|2|3|4): string|[string, string] => { @@ -335,8 +340,8 @@ export const sumVector = (name: string, components: number) => { * vec4. */ const createIndicesHelper = - (name: string, tensorType: number, shapeOrRank: number|readonly number[], isInput: boolean, - components: 1|2|3|4): IndicesHelper => { + (name: string, tensorType: number, shapeOrRank: number|readonly number[], isInput: boolean, components: 1|2|3|4, + uniformOnly = false): IndicesHelper => { const useUniform = typeof shapeOrRank === 'number'; const rank = useUniform ? shapeOrRank : shapeOrRank.length; const rankIdentity = [...new Array(rank).keys()]; @@ -358,7 +363,7 @@ const createIndicesHelper = getByIndices: false, }; - const uniformPrefix = useUniform ? 'uniforms.' : ''; + const uniformPrefix = useUniform || uniformOnly ? 'uniforms.' : ''; const shape = `${uniformPrefix}${name}_shape`; const strides = `${uniformPrefix}${name}_strides`; let o2iSnippet = ''; @@ -616,7 +621,8 @@ const createIndicesHelper = name, strides, shape, - rank + rank, + uniformOnly }; }; @@ -630,8 +636,8 @@ const createIndicesHelper = * @returns an IndicesHelper for the input. */ export const inputVariable = - (name: string, type: number, shapeOrRank: number|readonly number[], components: 1|2|3|4 = 1): IndicesHelper => - createIndicesHelper(name, type, shapeOrRank, true, components); + (name: string, type: number, shapeOrRank: number|readonly number[], components: 1|2|3|4 = 1, uniformOnly = false): + IndicesHelper => createIndicesHelper(name, type, shapeOrRank, true, components, uniformOnly); /** * Create a IndicesHelper for an output. @@ -731,7 +737,7 @@ class ShaderHelperImpl implements ShaderHelper { `; } - private declareVariable(variable: IndicesHelper, bindingIndex: number): string { + private declareVariable(variable: IndicesHelper, bindingIndex = -1): string { this.indicesHelpers.push(variable); if (variable.rank !== 0) { if (variable.shape.startsWith('uniforms.')) { @@ -743,11 +749,22 @@ class ShaderHelperImpl implements ShaderHelper { } const access = variable.usage === 'input' ? 'read' : 'read_write'; const storageType = variable.type.storage; + if (variable.uniformOnly) { + return ''; + } return `@group(0) @binding(${bindingIndex}) var ${variable.name}: array<${storageType}>;`; } declareVariables(...variables: IndicesHelper[]): string { - return variables.map(v => this.declareVariable(v, this.variableIndex++)).join('\n'); + return variables + .map(v => { + if (v.uniformOnly === true) { + return this.declareVariable(v); + } else { + return this.declareVariable(v, this.variableIndex++); + } + }) + .join('\n'); } registerUniform(name: string, type: string): ShaderHelper {