From 656ca66186c7fd362abd8f33915bd0f96483bf43 Mon Sep 17 00:00:00 2001 From: Xu Xing Date: Fri, 26 Jan 2024 07:37:05 +0800 Subject: [PATCH] [js/webgpu] Support uniforms for conv, conv transpose, conv grouped (#18753) --- .../webgpu/ops/3rd-party/conv2d_mm_webgpu.ts | 125 +++++++------ .../ops/3rd-party/conv_backprop_mm_webgpu.ts | 154 ++++++++-------- .../ops/3rd-party/conv_backprop_webgpu.ts | 174 +++++++++++------- .../ops/3rd-party/matmul_packed_webgpu.ts | 108 +++++------ .../lib/wasm/jsep/webgpu/ops/conv-grouped.ts | 86 +++++---- .../wasm/jsep/webgpu/ops/conv-transpose.ts | 15 +- js/web/lib/wasm/jsep/webgpu/ops/conv.ts | 18 +- js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts | 39 ++-- js/web/lib/wasm/jsep/webgpu/ops/matmul.ts | 43 +++-- 9 files changed, 418 insertions(+), 344 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts index 3638938df7dbe..1a03621512888 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts @@ -21,8 +21,8 @@ import {LOG_DEBUG} from '../../../log'; import {TensorView} from '../../../tensor-view'; -import {ProgramInfo, ProgramUniform} from '../../types'; -import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from '../common'; +import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types'; +import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from '../common'; import {ConvAttributes} from '../conv'; import {getActivationSnippet} from '../fuse-utils'; @@ -88,10 +88,10 @@ const conv2dCommonSnippet = let outRow = ${row} / outWidth; let outCol = ${row} % outWidth; - let WRow = ${col} / (filterDims[1] * inChannels); - let WCol = ${col} / inChannels % filterDims[1]; - let xRow = outRow * stride[0] + dilation[0] * WRow - pad[0]; - let xCol = outCol * stride[1] + dilation[1] * WCol - pad[1]; + let WRow = ${col} / (i32(uniforms.w_shape[1]) * inChannels); + let WCol = ${col} / inChannels % i32(uniforms.w_shape[1]); + let xRow = outRow * uniforms.stride[0] + uniforms.dilation[0] * WRow - uniforms.pad[0]; + let xCol = outCol * uniforms.stride[1] + uniforms.dilation[1] * WCol - uniforms.pad[1]; let xCh = ${col} % inChannels; var resData = ${typeSnippet(innerElementSizeX, dataType)}(0.0); // The bounds checking is always needed since we use it to pad zero for @@ -108,7 +108,7 @@ const conv2dCommonSnippet = ${readXSnippet}` : ` let col = colIn * ${innerElementSizeX}; - if (row < uniforms.dimAOuter && col < uniforms.dimInner) { + if (row < uniforms.dim_a_outer && col < uniforms.dim_inner) { ${readXSnippet} } return ${typeSnippet(innerElementSizeX, dataType)}(0.0);`) : @@ -117,7 +117,7 @@ const conv2dCommonSnippet = ${readXSnippet}` : ` let col = colIn * ${innerElementSizeX}; - if (row < uniforms.dimInner && col < uniforms.dimBOuter) { + if (row < uniforms.dim_inner && col < uniforms.dim_b_outer) { ${readXSnippet} } return ${typeSnippet(innerElementSizeX, dataType)}(0.0);`); @@ -129,9 +129,8 @@ const conv2dCommonSnippet = isChannelsLast ? typeSnippet(innerElementSizeX, dataType) : typeSnippet(innerElementSizeW, dataType); const bType = isChannelsLast ? typeSnippet(innerElementSizeW, dataType) : typeSnippet(innerElementSizeX, dataType); - const {activationFunction, applyActivation} = getActivationSnippet(attributes, resType); + const applyActivation = getActivationSnippet(attributes, resType); const userCode = ` - ${activationFunction} fn mm_readA(batch: i32, row : i32, colIn : i32) -> ${aType} { ${isChannelsLast ? sampleX : sampleW} } @@ -142,7 +141,7 @@ const conv2dCommonSnippet = fn mm_write(batch: i32, row : i32, colIn : i32, valueIn : ${resType}) { let col = colIn * ${innerElementSize}; - if (row < uniforms.dimAOuter && col < uniforms.dimBOuter) + if (row < uniforms.dim_a_outer && col < uniforms.dim_b_outer) { var value = valueIn; let outWidth = ${isChannelsLast ? 'i32(uniforms.result_shape[2])' : 'i32(uniforms.result_shape[3])'}; @@ -181,31 +180,46 @@ export const createConv2DMatMulProgramInfo = LOG_DEBUG('verbose', () => `[conv2d_mm_webgpu] dispatch = ${dispatch}`); const innerElementSize = isVec4 ? (isChannelsLast && inChannels % 4 !== 0 ? 3 : 4) : 1; - const tileAOuter = workGroupSize[1] * elementsPerThread[1]; const tileBOuter = workGroupSize[0] * elementsPerThread[0]; const tileInner = Math.max(workGroupSize[0] * innerElementSize, workGroupSize[1]); - const fitAOuter = dimAOuter % tileAOuter === 0; const fitBOuter = dimBOuter % tileBOuter === 0; const fitInner = dimInner % tileInner === 0; - const elementsSize = isVec4 ? [innerElementSize, 4, 4] : [1, 1, 1]; - const t = tensorTypeToWsglStorageType(inputs[0].dataType); - // TODO: support component 2, 3. - const components = isVec4 ? 4 : 1; - const programUniforms: ProgramUniform[] = - [{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}]; - const x = - inputVariable('x', inputs[0].dataType, inputs[0].dims.length, innerElementSize === 3 ? 1 : innerElementSize); - const w = inputVariable('w', inputs[1].dataType, inputs[1].dims.length, components); - const inputVariables = [x, w]; + const programUniforms: ProgramUniform[] = [ + {type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}, + {type: 'int32', data: [attributes.pads[0], attributes.pads[1]]}, {type: 'int32', data: attributes.strides}, + {type: 'int32', data: attributes.dilations} + ]; + if (attributes.activation === 'Clip') { + programUniforms.push( + {type: 'float32', data: attributes.clipMax!}, {type: 'float32', data: attributes.clipMin!}); + } + programUniforms.push( + ...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(inputs[1].dims)); + const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank']; + if (hasBias) { + programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); + inputDependencies.push('rank'); + } + programUniforms.push(...createTensorShapeVariables(outputShape)); - programUniforms.push(...createTensorShapeVariables(inputs[0].dims)); - programUniforms.push(...createTensorShapeVariables(inputs[1].dims)); + const getShaderSource = (shaderHelper: ShaderHelper) => { + const uniforms: UniformsArrayType = [ + {name: 'dim_a_outer', type: 'i32'}, {name: 'dim_b_outer', type: 'i32'}, {name: 'dim_inner', type: 'i32'}, + {name: 'pad', type: 'i32', length: 2}, {name: 'stride', type: 'i32', length: 2}, + {name: 'dilation', type: 'i32', length: 2} + ]; + if (attributes.activation === 'Clip') { + uniforms.push({name: 'clip_max', type: 'f32'}, {name: 'clip_min', type: 'f32'}); + } - let declareFunctions = ` + // TODO: support component 2, 3. + const components = isVec4 ? 4 : 1; + const t = tensorTypeToWsglStorageType(inputs[0].dataType); + let declareFunctions = ` fn setOutputAtIndex(flatIndex : i32, value : ${isVec4 ? `vec4<${t}>` : t}) { result[flatIndex] = ${isVec4 ? `vec4<${t}>` : t}(value); } @@ -213,51 +227,50 @@ export const createConv2DMatMulProgramInfo = let flatIndex = getOutputIndexFromCoords(vec4(d0, d1, d2, d3)); setOutputAtIndex(flatIndex ${isVec4 ? '/ 4' : ''}, value); }`; - if (hasBias) { - const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, components); - inputVariables.push(bias); - - programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); - - declareFunctions += ` + const x = inputVariable( + 'x', inputs[0].dataType, inputs[0].dims.length, innerElementSize === 3 ? 1 : innerElementSize); + const w = inputVariable('w', inputs[1].dataType, inputs[1].dims.length, components); + const inputVariables = [x, w]; + const output = outputVariable('result', inputs[0].dataType, outputShape.length, components); + if (hasBias) { + const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, components); + inputVariables.push(bias); + declareFunctions += ` fn getBiasByOutputCoords(coords : vec4) -> ${isVec4 ? `vec4<${t}>` : t} { return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}]; }`; - } - const output = outputVariable('result', inputs[0].dataType, outputShape.length, components); - programUniforms.push(...createTensorShapeVariables(outputShape)); - return { - name: 'Conv2DMatMul', - shaderCache: {hint: attributes.cacheKey}, - getRunData: () => ({ - outputs: [{dims: outputShape, dataType: inputs[0].dataType}], - dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]}, - programUniforms, - }), - getShaderSource: (shaderHelper: ShaderHelper) => ` + } + + return ` ${utilFunctions('uniforms.result_strides')} //struct Uniforms { xShape : vec4, wShape : vec4, outShape : vec4, // outShapeStrides: vec3, filterDims : vec2, pad : vec2, stride : vec2, // dilation : vec2, dimAOuter : i32, dimBOuter : i32, dimInner : i32 }; - ${ - shaderHelper.registerUniform('dimAOuter', 'i32') - .registerUniform('dimBOuter', 'i32') - .registerUniform('dimInner', 'i32') - .declareVariables(...inputVariables, output)} - const filterDims : vec2 = vec2(${attributes.kernelShape[0]}, ${attributes.kernelShape[1]}); - const pad : vec2 = vec2(${attributes.pads[0]}, ${attributes.pads[1]}); - const stride : vec2 = vec2(${attributes.strides[0]}, ${attributes.strides[1]}); - const dilation : vec2 = vec2(${attributes.dilations[0]}, ${attributes.dilations[1]}); + ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)} ${declareFunctions} ${ conv2dCommonSnippet( isChannelsLast, fitAOuter, fitBOuter, fitInner, hasBias, attributes, elementsSize[0], elementsSize[1], elementsSize[2], t)} - ${ + ${ isVec4 ? makeMatMulPackedVec4Source(elementsPerThread, workGroupSize, t, undefined, !isChannelsLast, tileInner) : makeMatMulPackedSource( elementsPerThread, workGroupSize, t, undefined, !isChannelsLast, tileInner, false, undefined, - sequentialAccessByThreads)}` + sequentialAccessByThreads)}`; + }; + return { + name: 'Conv2DMatMul', + shaderCache: { + hint: `${attributes.cacheKey};${innerElementSize};${isVec4};${fitAOuter};${fitBOuter};${fitInner};${ + tileAOuter};${tileBOuter};${tileInner}`, + inputDependencies + }, + getRunData: () => ({ + outputs: [{dims: outputShape, dataType: inputs[0].dataType}], + dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]}, + programUniforms, + }), + getShaderSource }; }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts index d425155857e14..33e50a9a39cb9 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts @@ -21,8 +21,8 @@ import {LOG_DEBUG} from '../../../log'; import {TensorView} from '../../../tensor-view'; -import {ProgramInfo, ProgramUniform} from '../../types'; -import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from '../common'; +import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types'; +import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from '../common'; import {ConvTransposeAttributes} from '../conv-transpose'; import {getActivationSnippet} from '../fuse-utils'; @@ -74,21 +74,21 @@ const conv2dTransposeCommonSnippet = col % outWidth); `; - const xHeight = isChannelsLast ? 'outBackprop[1]' : 'outBackprop[2]'; - const xWidth = isChannelsLast ? 'outBackprop[2]' : 'outBackprop[3]'; + const xHeight = isChannelsLast ? 'i32(uniforms.x_shape[1])' : 'i32(uniforms.x_shape[2])'; + const xWidth = isChannelsLast ? 'i32(uniforms.x_shape[2])' : 'i32(uniforms.x_shape[3])'; const row = isChannelsLast ? 'row' : 'col'; const col = isChannelsLast ? 'col' : 'row'; const readASnippet = ` - let inChannels = ${isChannelsLast ? 'outBackprop[3]' : 'outBackprop[1]'}; + let inChannels = ${isChannelsLast ? 'i32(uniforms.x_shape[3])' : 'i32(uniforms.x_shape[1])'}; let outWidth = ${isChannelsLast ? 'i32(uniforms.result_shape[2])' : 'i32(uniforms.result_shape[3])'}; let outRow = ${row} / outWidth; let outCol = ${row} % outWidth; - let WRow = ${col} / (filterDims[1] * inChannels); - let WCol = ${col} / inChannels % filterDims[1]; - let xR = f32(outRow - pads[0] + dilation[0] * WRow) / f32(strides[0]); - let xC = f32(outCol - pads[1] + dilation[1] * WCol) / f32(strides[1]); + let WRow = ${col} / (uniforms.filter_dims[1] * inChannels); + let WCol = ${col} / inChannels % uniforms.filter_dims[1]; + let xR = f32(outRow - uniforms.pads[0] + uniforms.dilations[0] * WRow) / f32(uniforms.strides[0]); + let xC = f32(outCol - uniforms.pads[1] + uniforms.dilations[1] * WCol) / f32(uniforms.strides[1]); if (xR < 0.0 || xR >= f32(${xHeight}) || fract(xR) > 0.0) { return ${type}(0.0); } @@ -103,25 +103,25 @@ const conv2dTransposeCommonSnippet = const sampleA = isChannelsLast ? ` let col = colIn * ${innerElementSize}; - if (row < uniforms.dimAOuter && col < uniforms.dimInner) { + if (row < uniforms.dim_a_outer && col < uniforms.dim_inner) { ${readASnippet} } return ${type}(0.0);` : ` let col = colIn * ${innerElementSize}; - if (row < uniforms.dimInner && col < uniforms.dimBOuter) { + if (row < uniforms.dim_inner && col < uniforms.dim_b_outer) { ${readASnippet} } return ${type}(0.0);`; const sampleW = ` let col = colIn * ${innerElementSize}; - let inChannels = ${isChannelsLast ? 'outBackprop[3]' : 'outBackprop[1]'}; - let coordX = filterDims.x - 1 - row / (filterDims[1] * inChannels); - let coordY = filterDims.y - 1 - (row / inChannels) % filterDims[1]; + let inChannels = ${isChannelsLast ? 'i32(uniforms.x_shape[3])' : 'i32(uniforms.x_shape[1])'}; + let coordX = uniforms.filter_dims[0] - 1 - row / (uniforms.filter_dims[1] * inChannels); + let coordY = uniforms.filter_dims[1] - 1 - (row / inChannels) % uniforms.filter_dims[1]; if (${ - isChannelsLast ? 'row < uniforms.dimInner && col < uniforms.dimBOuter' : - 'row < uniforms.dimInner && col < uniforms.dimAOuter'} && coordX >= 0 && coordY >= 0) { + isChannelsLast ? 'row < uniforms.dim_inner && col < uniforms.dim_b_outer' : + 'row < uniforms.dim_inner && col < uniforms.dim_a_outer'} && coordX >= 0 && coordY >= 0) { let rowInner = row % inChannels; let coord = vec4(coordX, coordY, col, rowInner); ${getWSnippet(innerElementSize)} @@ -129,9 +129,8 @@ const conv2dTransposeCommonSnippet = return ${type}(0.0); `; - const {activationFunction, applyActivation} = getActivationSnippet(attributes, type); + const applyActivation = getActivationSnippet(attributes, type); const userCode = ` - ${activationFunction} fn mm_readA(batch: i32, row : i32, colIn : i32) -> ${type} { ${isChannelsLast ? sampleA : sampleW} } @@ -142,7 +141,7 @@ const conv2dTransposeCommonSnippet = fn mm_write(batch: i32, row : i32, colIn : i32, valueInput : ${type}) { let col = colIn * ${innerElementSize}; - if (row < uniforms.dimAOuter && col < uniforms.dimBOuter) { + if (row < uniforms.dim_a_outer && col < uniforms.dim_b_outer) { var value = valueInput; let outWidth = ${isChannelsLast ? 'i32(uniforms.result_shape[2])' : 'i32(uniforms.result_shape[3])'}; ${coordResSnippet} @@ -186,65 +185,64 @@ export const createConv2DTransposeMatMulProgramInfo = const innerElementSize = isVec4 ? 4 : 1; const tileInner = Math.max(workGroupSize[0] * innerElementSize, workGroupSize[1]); const components = isVec4 ? 4 : 1; - const programUniforms: ProgramUniform[] = - [{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}]; - const x = inputVariable('x', inputs[0].dataType, inputs[0].dims.length, components); - const w = inputVariable('w', inputs[1].dataType, inputs[1].dims.length, 1); - const output = outputVariable('result', inputs[0].dataType, outputShape.length, components); - const inputVariables = [x, w]; - programUniforms.push(...createTensorShapeVariables(inputs[0].dims)); - programUniforms.push(...createTensorShapeVariables(inputs[1].dims)); + const filterDims = + [attributes.kernelShape[isChannelsLast ? 1 : 2], attributes.kernelShape[isChannelsLast ? 2 : 3]]; + const effectiveFilterDims = [ + filterDims[0] + (attributes.dilations[0] <= 1 ? 0 : (filterDims[0] - 1) * (attributes.dilations[0] - 1)), + filterDims[1] + (attributes.dilations[1] <= 1 ? 0 : (filterDims[1] - 1) * (attributes.dilations[1] - 1)) + ]; + const pads = [ + effectiveFilterDims[0] - 1 - Math.floor((attributes.pads[0] + attributes.pads[2]) / 2), + effectiveFilterDims[1] - 1 - Math.floor((attributes.pads[1] + attributes.pads[3]) / 2) + ]; - let declareFunctions = ''; + const programUniforms: ProgramUniform[] = [ + {type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}, + {type: 'int32', data: attributes.strides}, {type: 'int32', data: attributes.dilations}, + {type: 'int32', data: filterDims}, {type: 'int32', data: pads} + ]; + if (attributes.activation === 'Clip') { + programUniforms.push( + {type: 'float32', data: attributes.clipMax!}, {type: 'float32', data: attributes.clipMin!}); + } + programUniforms.push( + ...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(inputs[1].dims)); + + const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank']; if (hasBias) { - const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, components); - inputVariables.push(bias); programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); - - declareFunctions += ` - fn getBiasByOutputCoords(coords : vec4) -> ${isVec4 ? 'vec4' : 'f32'} { - return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}]; - }`; + inputDependencies.push('rank'); } - programUniforms.push(...createTensorShapeVariables(outputShape)); - return { - name: 'Conv2DTransposeMatMul', - shaderCache: {hint: attributes.cacheKey}, - getRunData: () => ({ - outputs: [{dims: outputShape, dataType: inputs[0].dataType}], - dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]}, - programUniforms - }), - getShaderSource: (shaderHelper: ShaderHelper) => ` + const getShaderSource = (shaderHelper: ShaderHelper) => { + const x = inputVariable('x', inputs[0].dataType, inputs[0].dims.length, components); + const w = inputVariable('w', inputs[1].dataType, inputs[1].dims.length, 1); + const output = outputVariable('result', inputs[0].dataType, outputShape.length, components); + const inputVariables = [x, w]; + + let declareFunctions = ''; + if (hasBias) { + const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, components); + inputVariables.push(bias); + declareFunctions += ` + fn getBiasByOutputCoords(coords : vec4) -> ${isVec4 ? 'vec4' : 'f32'} { + return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}]; + }`; + } + + const uniforms: UniformsArrayType = [ + {name: 'dim_a_outer', type: 'i32'}, {name: 'dim_b_outer', type: 'i32'}, {name: 'dim_inner', type: 'i32'}, + {name: 'strides', type: 'i32', length: 2}, {name: 'dilations', type: 'i32', length: 2}, + {name: 'filter_dims', type: 'i32', length: filterDims.length}, + {name: 'pads', type: 'i32', length: pads.length} + ]; + if (attributes.activation === 'Clip') { + uniforms.push({name: 'clip_max', type: 'f32'}, {name: 'clip_min', type: 'f32'}); + } + return ` ${utilFunctions('uniforms.result_strides')} - ${ - shaderHelper.registerUniform('dimAOuter', 'i32') - .registerUniform('dimBOuter', 'i32') - .registerUniform('dimInner', 'i32') - .declareVariables(...inputVariables, output)}; - const outBackprop : vec4 = vec4(${inputs[0].dims.join(',')}); - const filterDims : vec2 = vec2(${attributes.kernelShape[isChannelsLast ? 1 : 2]}, ${ - attributes.kernelShape[isChannelsLast ? 2 : 3]}); - const effectiveFilterDims : vec2 = filterDims + vec2( - ${ - attributes.dilations[0] <= 1 ? - 0 : - (attributes.kernelShape[isChannelsLast ? 1 : 2] - 1) * (attributes.dilations[0] - 1)}, - ${ - attributes.dilations[1] <= 1 ? - 0 : - (attributes.kernelShape[isChannelsLast ? 2 : 3] - 1) * (attributes.dilations[1] - 1)}); - const pads : vec2 = vec2(i32(effectiveFilterDims[0]) - 1 - (${ - attributes.pads[0] + attributes.pads[2]})/2, - i32(effectiveFilterDims[1]) - 1 - (${ - attributes.pads[1] + attributes.pads[3]})/2); - const strides : vec2 = vec2(${attributes.strides[0]}, ${attributes.strides[1]}); - const dilation : vec2 = vec2(${attributes.dilations[0]}, ${attributes.dilations[1]}); - const dimAOuter : i32 = ${dimAOuter}; - const dimBOuter : i32 = ${dimBOuter}; - const dimInner : i32 = ${dimInner}; + ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)}; ${declareFunctions} ${conv2dTransposeCommonSnippet(isChannelsLast, hasBias, attributes, innerElementSize)} ${ @@ -252,6 +250,18 @@ export const createConv2DTransposeMatMulProgramInfo = elementsPerThread, workGroupSize, 'f32', undefined, !isChannelsLast, tileInner) : makeMatMulPackedSource( elementsPerThread, workGroupSize, 'f32', undefined, !isChannelsLast, tileInner, false, - undefined, sequentialAccessByThreads)}` + undefined, sequentialAccessByThreads)}`; + }; + + return { + name: 'Conv2DTransposeMatMul', + shaderCache: + {hint: `${attributes.cacheKey};${elementsPerThread};${workGroupSize};${isVec4}`, inputDependencies}, + getRunData: () => ({ + outputs: [{dims: outputShape, dataType: inputs[0].dataType}], + dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]}, + programUniforms + }), + getShaderSource }; }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts index 50b0841a0200a..380efc8bc577a 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts @@ -20,24 +20,18 @@ import {LOG_DEBUG} from '../../../log'; import {TensorView} from '../../../tensor-view'; import {ShapeUtil} from '../../../util'; -import {ProgramInfo} from '../../types'; -import {inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from '../common'; +import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types'; +import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from '../common'; import {ConvTransposeAttributes} from '../conv-transpose'; const createConvTranspose2DOpProgramShaderSource = - (shaderHelper: ShaderHelper, inputs: readonly TensorView[], attributes: ConvTransposeAttributes, - outputShape: readonly number[], hasBias: boolean, is1DimensionDispatch: boolean, isVec4 = false, - dataType: string): string => { - const isChannelsLast = attributes.format === 'NHWC'; + (shaderHelper: ShaderHelper, inputs: readonly TensorView[], outputShape: readonly number[], hasBias: boolean, + is1DimensionDispatch: boolean, isVec4 = false, dataType: string, uniforms: UniformsArrayType, + isChannelsLast = false): string => { const rowDim = isChannelsLast ? 1 : 2; const colDim = isChannelsLast ? 2 : 3; const channelDim = isChannelsLast ? 3 : 1; - const outputSize = ShapeUtil.size(outputShape); const workPerThread = isVec4 ? 2 : 1; - const group = attributes.group; - const wShape = inputs[1].dims; - const inputChannelsPerGroup = wShape[0] / group; - const outputChannelsPerGroup = wShape[1]; let declareFunctions = ` fn setOutputAtIndex(flatIndex : u32, value : ${isVec4 ? `vec4<${dataType}>` : dataType}) { @@ -50,20 +44,21 @@ const createConvTranspose2DOpProgramShaderSource = }`; } const components = isVec4 ? 4 : 1; - const w = inputVariable('W', inputs[1].dataType, inputs[1].dims, components); - const dy = inputVariable('Dy', inputs[0].dataType, inputs[0].dims, components); + const w = inputVariable('W', inputs[1].dataType, inputs[1].dims.length, components); + const dy = inputVariable('Dy', inputs[0].dataType, inputs[0].dims.length, components); const inputVariables = [dy, w]; if (hasBias) { - inputVariables.push(inputVariable('bias', inputs[2].dataType, [outputShape[channelDim]], components)); + inputVariables.push(inputVariable('bias', inputs[2].dataType, [outputShape[channelDim]].length, components)); } - const output = outputVariable('result', inputs[0].dataType, outputShape, components); + const output = outputVariable('result', inputs[0].dataType, outputShape.length, components); + const codeSnippet4 = `{ - let batch: u32 = ${is1DimensionDispatch ? 'global_id.z' : 'workgroup_id.z'} / outShape[1]; - let r = ${is1DimensionDispatch ? 'global_id.z' : 'workgroup_id.z'} % outShape[1]; + let batch: u32 = ${is1DimensionDispatch ? 'global_id.z' : 'workgroup_id.z'} / uniforms.result_shape[1]; + let r = ${is1DimensionDispatch ? 'global_id.z' : 'workgroup_id.z'} % uniforms.result_shape[1]; let c = ${is1DimensionDispatch ? 'global_id.y' : 'workgroup_id.y'} * ${workPerThread}; let d1: u32 = ${is1DimensionDispatch ? 'global_id.x' : 'workgroup_id.x'} * 4; - let dyCorner = vec2(i32(r), i32(c)) - vec2(pads); + let dyCorner = vec2(i32(r), i32(c)) - vec2(uniforms.pads); // Convolve dy(?, ?, d2) with w(:, :, d1, d2) to compute dx(xR, xC, d1). // ? = to be determined. : = across all values in that axis. @@ -71,29 +66,29 @@ const createConvTranspose2DOpProgramShaderSource = for (var i = 0; i < ${workPerThread}; i++) { dotProd[i] = vec4<${dataType}>(0.0); } - for (var wR: u32 = 0; wR < filterDims[0]; wR = wR + 1) { - var dyR = (${dataType}(dyCorner.x) + ${dataType}(wR)) / ${dataType}(strides.x); - let wRPerm = filterDims[0] - 1 - wR; - if (dyR < 0.0 || dyR >= ${dataType}(outBackprop[1]) || + for (var wR: u32 = 0; wR < uniforms.filter_dims[0]; wR = wR + 1) { + var dyR = (${dataType}(dyCorner.x) + ${dataType}(wR)) / ${dataType}(uniforms.strides.x); + let wRPerm = uniforms.filter_dims[0] - 1 - wR; + if (dyR < 0.0 || dyR >= ${dataType}(uniforms.Dy_shape[1]) || fract(dyR) > 0.0 || wRPerm < 0) { continue; } let idyR: u32 = u32(dyR); - for (var wC: u32 = 0; wC < filterDims[1]; wC = wC + 1) { - let dyC = (${dataType}(dyCorner.y) + ${dataType}(wC)) / ${dataType}(strides.y); - let dyC2 = (${dataType}(dyCorner.y) + 1.0 + ${dataType}(wC)) / ${dataType}(strides.y); - let wCPerm = filterDims[1] - 1 - wC; + for (var wC: u32 = 0; wC < uniforms.filter_dims[1]; wC = wC + 1) { + let dyC = (${dataType}(dyCorner.y) + ${dataType}(wC)) / ${dataType}(uniforms.strides.y); + let dyC2 = (${dataType}(dyCorner.y) + 1.0 + ${dataType}(wC)) / ${dataType}(uniforms.strides.y); + let wCPerm = uniforms.filter_dims[1] - 1 - wC; if (wCPerm < 0) { continue; } var bDyCVal = true; var bDyCVal2 = true; - if (dyC < 0.0 || dyC >= ${dataType}(outBackprop[2]) || + if (dyC < 0.0 || dyC >= ${dataType}(uniforms.Dy_shape[2]) || fract(dyC) > 0.0) { bDyCVal = false; } - if (dyC2 < 0.0 || dyC2 >= ${dataType}(outBackprop[2]) || + if (dyC2 < 0.0 || dyC2 >= ${dataType}(uniforms.Dy_shape[2]) || fract(dyC2) > 0.0) { bDyCVal2 = false; } @@ -101,7 +96,7 @@ const createConvTranspose2DOpProgramShaderSource = let idyC: u32 = u32(dyC); let idyC2: u32 = u32(dyC2); if (bDyCVal && bDyCVal2) { - let d2Length = outBackprop[3]; + let d2Length = uniforms.Dy_shape[3]; for (var d2 :u32 = 0; d2 < d2Length; d2 = d2 + 4) { let wValue0 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1', 'd2')}; let wValue1 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 1', 'd2')}; @@ -123,7 +118,7 @@ const createConvTranspose2DOpProgramShaderSource = dot(xValue, wValue3)); } } else if (bDyCVal) { - let d2Length = outBackprop[${channelDim}]; + let d2Length = uniforms.Dy_shape[${channelDim}]; for (var d2: u32 = 0; d2 < d2Length; d2 = d2 + 4) { let wValue0 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1', 'd2')}; let wValue1 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 1', 'd2')}; @@ -138,7 +133,7 @@ const createConvTranspose2DOpProgramShaderSource = dotProd[0] = dotProd[0] + tmpval; } } else if (bDyCVal2) { - let d2Length = outBackprop[3]; + let d2Length = uniforms.Dy_shape[3]; for (var d2: u32 = 0; d2 < d2Length; d2 = d2 + 4) { let wValue0 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1', 'd2')}; let wValue1 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 1', 'd2')}; @@ -167,39 +162,39 @@ const createConvTranspose2DOpProgramShaderSource = let d1 = ${output.indicesGet('outputIndices', channelDim)}; let r = ${output.indicesGet('outputIndices', rowDim)}; let c = ${output.indicesGet('outputIndices', colDim)}; - let dyCorner = vec2(i32(r), i32(c)) - pads; + let dyCorner = vec2(i32(r), i32(c)) - uniforms.pads; let dyRCorner = dyCorner.x; let dyCCorner = dyCorner.y; - let groupId = d1 / ${outputChannelsPerGroup}; - let wOutChannel = d1 - groupId * ${outputChannelsPerGroup}; + let groupId = d1 / uniforms.output_channels_per_group; + let wOutChannel = d1 - groupId * uniforms.output_channels_per_group; // Convolve dy(?, ?, d2) with w(:, :, d1, d2) to compute dx(xR, xC, d1). // ? = to be determined. : = across all values in that axis. var dotProd = ${dataType}(0.0); - for (var wR: u32 = 0; wR < effectiveFilterDims.x; wR = wR + 1) { - if (wR % dilations.x != 0) { + for (var wR: u32 = 0; wR < uniforms.effective_filter_dims.x; wR = wR + 1) { + if (wR % uniforms.dilations.x != 0) { continue; } - let dyR = (${dataType}(dyRCorner) + ${dataType}(wR)) / ${dataType}(strides[0]); - let wRPerm = filterDims.x - 1 - wR / dilations.x; - if (dyR < 0.0 || dyR >= ${dataType}(outBackprop[${rowDim}]) || fract(dyR) > 0.0 || + let dyR = (${dataType}(dyRCorner) + ${dataType}(wR)) / ${dataType}(uniforms.strides[0]); + let wRPerm = uniforms.filter_dims.x - 1 - wR / uniforms.dilations.x; + if (dyR < 0.0 || dyR >= ${dataType}(uniforms.Dy_shape[${rowDim}]) || fract(dyR) > 0.0 || wRPerm < 0) { continue; } let idyR: u32 = u32(dyR); - for (var wC: u32 = 0; wC < effectiveFilterDims.y; wC = wC + 1) { - if (wC % dilations.y != 0) { + for (var wC: u32 = 0; wC < uniforms.effective_filter_dims.y; wC = wC + 1) { + if (wC % uniforms.dilations.y != 0) { continue; } - let dyC = (${dataType}(dyCCorner) + ${dataType}(wC)) / ${dataType}(strides.y); - let wCPerm = filterDims.y - 1 - wC / dilations.y; - if (dyC < 0.0 || dyC >= ${dataType}(outBackprop[${colDim}]) || + let dyC = (${dataType}(dyCCorner) + ${dataType}(wC)) / ${dataType}(uniforms.strides.y); + let wCPerm = uniforms.filter_dims.y - 1 - wC / uniforms.dilations.y; + if (dyC < 0.0 || dyC >= ${dataType}(uniforms.Dy_shape[${colDim}]) || fract(dyC) > 0.0 || wCPerm < 0) { continue; } let idyC: u32 = u32(dyC); - var inputChannel = groupId * ${inputChannelsPerGroup}; - for (var d2: u32 = 0; d2 < ${inputChannelsPerGroup}; d2 = d2 + 1) { + var inputChannel = groupId * uniforms.input_channels_per_group; + for (var d2: u32 = 0; d2 < uniforms.input_channels_per_group; d2 = d2 + 1) { let xValue = ${ isChannelsLast ? dy.get('batch', 'idyR', 'idyC', 'inputChannel') : dy.get('batch', 'inputChannel', 'idyR', 'idyC')}; @@ -214,27 +209,11 @@ const createConvTranspose2DOpProgramShaderSource = `; return ` - ${shaderHelper.declareVariables(...inputVariables, output)} + ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)} ${declareFunctions} - const outShape : vec4 = vec4(${outputShape.join(',')}); - const outBackprop : vec4 = vec4(${inputs[0].dims.join(',')}); - const strides : vec2 = vec2(${attributes.strides[0]}, ${attributes.strides[1]}); - const filterDims : vec2 = vec2(${attributes.kernelShape[isChannelsLast ? 1 : 2]}, ${ - attributes.kernelShape[isChannelsLast ? 2 : 3]}); - const dilations : vec2 = vec2(${attributes.dilations[0]}, ${attributes.dilations[1]}); - const effectiveFilterDims : vec2 = filterDims + vec2( - ${ - attributes.dilations[0] <= 1 ? - 0 : - (attributes.kernelShape[isChannelsLast ? 1 : 2] - 1) * (attributes.dilations[0] - 1)}, - ${ - attributes.dilations[1] <= 1 ? - 0 : - (attributes.kernelShape[isChannelsLast ? 2 : 3] - 1) * (attributes.dilations[1] - 1)}); - const pads : vec2 = vec2(i32(effectiveFilterDims[0]) - 1 - (${attributes.pads[0] + attributes.pads[2]})/2, - i32(effectiveFilterDims[1]) - 1 - (${attributes.pads[1] + attributes.pads[3]})/2); + ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}; + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}; ${isVec4 ? codeSnippet4 : codeSnippet}}`; }; @@ -257,19 +236,72 @@ export const createConvTranspose2DProgramInfo = ]; LOG_DEBUG('verbose', () => `[conv2d_backprop_webgpu] dispatch = ${dispatch}`); - const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); + const isChannelsLast = attributes.format === 'NHWC'; + const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank']; + const strides = [attributes.strides[0], attributes.strides[1]]; + const filterDims = + [attributes.kernelShape[isChannelsLast ? 1 : 2], attributes.kernelShape[isChannelsLast ? 2 : 3]]; + const dilations = [attributes.dilations[0], attributes.dilations[1]]; + const effectiveFilterDims = [ + filterDims[0] + + (attributes.dilations[0] <= 1 ? + 0 : + (attributes.kernelShape[isChannelsLast ? 1 : 2] - 1) * (attributes.dilations[0] - 1)), + filterDims[1] + + (attributes.dilations[1] <= 1 ? + 0 : + (attributes.kernelShape[isChannelsLast ? 2 : 3] - 1) * (attributes.dilations[1] - 1)) + ]; + const pads = [ + effectiveFilterDims[0] - 1 - Math.floor((attributes.pads[0] + attributes.pads[2]) / 2), + effectiveFilterDims[1] - 1 - Math.floor(attributes.pads[1] + attributes.pads[3]) / 2 + ]; + + const isVec4 = false; + const group = attributes.group; + const wShape = inputs[1].dims; + const inputChannelsPerGroup = wShape[0] / group; + const outputChannelsPerGroup = wShape[1]; + + const programUniforms: ProgramUniform[] = [ + {type: 'int32', data: outputSize}, {type: 'uint32', data: strides}, {type: 'uint32', data: filterDims}, + {type: 'uint32', data: dilations}, {type: 'uint32', data: effectiveFilterDims}, {type: 'int32', data: pads}, + {type: 'uint32', data: inputChannelsPerGroup}, {type: 'uint32', data: outputChannelsPerGroup}, + ...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(inputs[1].dims) + ]; + if (hasBias) { + programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); + inputDependencies.push('rank'); + } + programUniforms.push(...createTensorShapeVariables(outputShape)); + + const is1DimensionDispatch = dispatch[1] === 1 && dispatch[2] === 1; + const getShaderSource = (shaderHelper: ShaderHelper) => { + const uniforms: UniformsArrayType = [ + {name: 'output_size', type: 'u32'}, {name: 'strides', type: 'u32', length: strides.length}, + {name: 'filter_dims', type: 'u32', length: filterDims.length}, + {name: 'dilations', type: 'u32', length: filterDims.length}, + {name: 'effective_filter_dims', type: 'u32', length: effectiveFilterDims.length}, + {name: 'pads', type: 'i32', length: pads.length}, {name: 'input_channels_per_group', type: 'u32'}, + {name: 'output_channels_per_group', type: 'u32'} + ]; + const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); + return `${ + createConvTranspose2DOpProgramShaderSource( + shaderHelper, inputs, outputShape, hasBias, is1DimensionDispatch, isVec4, dataType, uniforms, + isChannelsLast)}`; + }; return { name: 'ConvTranspose2D', - shaderCache: {hint: attributes.cacheKey}, + shaderCache: {hint: `${attributes.cacheKey};`, inputDependencies}, getRunData: () => ({ dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]}, outputs: [{ dims: squeezeOutputShapeFunction ? squeezeOutputShapeFunction(outputShape) : outputShape, dataType: inputs[0].dataType - }] + }], + programUniforms }), - getShaderSource: (shaderHelper: ShaderHelper) => createConvTranspose2DOpProgramShaderSource( - shaderHelper, inputs, attributes, outputShape, hasBias, dispatch[1] === 1 && dispatch[2] === 1, false, - dataType), + getShaderSource }; }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts index 47ec16a296712..ee71110245252 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts @@ -22,7 +22,7 @@ import {TensorView} from '../../../tensor-view'; import {ShapeUtil} from '../../../util'; import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types'; -import {createTensorShapeVariables, enableShapesUniforms, getBroadcastDims, IndicesHelper, inputVariable, internalVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from '../common'; +import {createTensorShapeVariables, getBroadcastDims, IndicesHelper, inputVariable, internalVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from '../common'; import {getActivationSnippet, InternalActivationAttributes} from '../fuse-utils'; import {typeSnippet} from './activation_util'; @@ -112,14 +112,14 @@ fn main(@builtin(local_invocation_id) localId : vec3, ${batchDims ? `let batchIndices = ${batchDims.offsetToIndices('u32(batch)')};` : ''} let globalRowStart = i32(workgroupId.y) * ${tileAOuter}; - let numTiles = ${splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : '(uniforms.dimInner - 1) / tileInner + 1'}; + let num_tiles = ${splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : '(uniforms.dim_inner - 1) / tileInner + 1'}; var kStart = ${splitK ? `i32(globalId.z) * ${splitedDimInner}` : '0'}; var acc: array, rowPerThread>; // Loop over shared dimension. let tileRowB = localRow * ${rowPerThreadB}; - for (var t = 0; t < numTiles; t = t + 1) { + for (var t = 0; t < num_tiles; t = t + 1) { // Load one tile of A into local memory. for (var innerRow = 0; innerRow < rowPerThread; innerRow = innerRow + 1) { let inputRow = tileRow + innerRow; @@ -204,7 +204,7 @@ export const makeMatMulPackedSource = let globalColStart = i32(workgroupId.x) * ${tileBOuter}; // Loop over shared dimension. - for (var t = 0; t < numTiles; t = t + 1) { + for (var t = 0; t < num_tiles; t = t + 1) { // Load one tile of A into local memory. for (var inputRow = localRow; inputRow < ${tileAHight}; inputRow = inputRow + ${workgroupSize[1]}) { for (var inputCol = localCol; inputCol < ${tileAWidth}; inputCol = inputCol + ${workgroupSize[0]}) { @@ -260,7 +260,7 @@ let tileRowA = i32(localId.y) * ${rowPerThreadA}; let tileColA = i32(localId.x) * ${colPerThreadA}; let tileRowB = i32(localId.y) * ${rowPerThreadB}; // Loop over shared dimension. -for (var t = 0; t < numTiles; t = t + 1) { +for (var t = 0; t < num_tiles; t = t + 1) { // Load one tile of A into local memory. for (var innerRow = 0; innerRow < ${rowPerThreadA}; innerRow = innerRow + 1) { for (var innerCol = 0; innerCol < ${colPerThreadA}; innerCol = innerCol + 1) { @@ -322,7 +322,8 @@ fn main(@builtin(local_invocation_id) localId : vec3, @builtin(workgroup_id) workgroupId : vec3) { let batch = ${splitK ? '0' : 'i32(globalId.z)'}; ${batchDims ? `let batchIndices = ${batchDims.offsetToIndices('u32(batch)')};` : ''} - let numTiles = ${splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : '(uniforms.dimInner - 1) / tileInner + 1'}; + let num_tiles = ${ + splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : '(uniforms.dim_inner - 1) / tileInner + 1'}; var kStart = ${splitK ? `i32(globalId.z) * ${splitedDimInner}` : '0'}; var acc : array, rowPerThread>; @@ -379,7 +380,7 @@ const matMulReadWriteFnSource = typeSnippet(component, dataType)} { var value = ${typeSnippet(component, dataType)}(0.0); let col = colIn * ${component}; - if(row < uniforms.dimAOuter && col < uniforms.dimInner) + if(row < uniforms.dim_a_outer && col < uniforms.dim_inner) { ${getAIndices()} value = ${aVariable.getByIndices('aIndices')}; @@ -391,7 +392,7 @@ const matMulReadWriteFnSource = typeSnippet(component, dataType)} { var value = ${typeSnippet(component, dataType)}(0.0); let col = colIn * ${component}; - if(row < uniforms.dimInner && col < uniforms.dimBOuter) + if(row < uniforms.dim_inner && col < uniforms.dim_b_outer) { ${getBIndices()} value = ${bVariable.getByIndices('bIndices')}; @@ -401,7 +402,7 @@ const matMulReadWriteFnSource = fn mm_write(batch: i32, row: i32, colIn: i32, valueIn: ${typeSnippet(component, dataType)}) { let col = colIn * ${component}; - if (row < uniforms.dimAOuter && col < uniforms.dimBOuter) { + if (row < uniforms.dim_a_outer && col < uniforms.dim_b_outer) { var value = valueIn; let coords = vec3(batch, row, colIn); ${ @@ -422,16 +423,10 @@ export const createMatmulProgramInfo = isChannelsLast = false /* only used for conv2dByMatMul*/): ProgramInfo => { const aShape = inputs[0].dims; const bShape = inputs[1].dims; - const outerDimsA = aShape.slice(0, -2); const outerDimsB = bShape.slice(0, -2); - const outerDims = reshapedOutputShape ? reshapedOutputShape.slice(0, -2) : outputShape.slice(0, -2); - const enableBatchUniforms = enableShapesUniforms(outerDims.length); - const batchShapeOrRank = enableBatchUniforms ? outerDims.length : outerDims; - const batchDims = internalVariable('batchDims', inputs[0].dataType, batchShapeOrRank, 1); const batchSize = ShapeUtil.size(outerDims); - const dimAOuter = aShape[aShape.length - 2]; const dimInner = aShape[aShape.length - 1]; const dimBOuter = bShape[bShape.length - 1]; @@ -446,72 +441,67 @@ export const createMatmulProgramInfo = Math.ceil(batchSize / workgroupSize[2] / elementsPerThread[2]) ]; - const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); const components = isVec4 ? 4 : 1; - const aShapeTemp = [...outerDimsA, dimAOuter, dimInner / components]; - const enableAShapesUniforms = enableShapesUniforms(aShapeTemp.length); - const aShapeOrRank = enableAShapesUniforms ? aShapeTemp.length : aShapeTemp; - + const aShapeOrRank = aShapeTemp.length; const bShapeTemp = [...outerDimsB, dimInner, dimBOuter / components]; - const enableBShapesUniforms = enableShapesUniforms(bShapeTemp.length); - const bShapeOrRank = enableBShapesUniforms ? bShapeTemp.length : bShapeTemp; - + const bShapeOrRank = bShapeTemp.length; const outputShapeTemp = [batchSize, dimAOuter, dimBOuter / components]; - - const A = inputVariable('a', inputs[0].dataType, aShapeOrRank, components); - const B = inputVariable('b', inputs[1].dataType, bShapeOrRank, components); - const output = outputVariable('result', inputs[0].dataType, outputShapeTemp.length, components); - const inputVariables = [A, B]; const programUniforms: ProgramUniform[] = [{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}]; - if (enableBatchUniforms) { - programUniforms.push(...createTensorShapeVariables(outerDims)); + if (activationAttributes.activation === 'Clip') { + programUniforms.push( + {type: 'float32', data: activationAttributes.clipMax!}, + {type: 'float32', data: activationAttributes.clipMin!}); } - if (enableAShapesUniforms) { - programUniforms.push(...createTensorShapeVariables(aShapeTemp)); - } - if (enableBShapesUniforms) { - programUniforms.push(...createTensorShapeVariables(bShapeTemp)); - } - const inputDependencies: ProgramInputTensorInfoDependency[] = []; - inputDependencies.push(enableAShapesUniforms ? 'rank' : 'dims'); - inputDependencies.push(enableBShapesUniforms ? 'rank' : 'dims'); + programUniforms.push( + ...createTensorShapeVariables(outerDims), ...createTensorShapeVariables(aShapeTemp), + ...createTensorShapeVariables(bShapeTemp)); + const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank']; const hasBias = inputs.length > 2; - const {activationFunction, applyActivation} = getActivationSnippet(activationAttributes, output.type.value); - const declareFunctions = matMulReadWriteFnSource( - components, hasBias, applyActivation, [batchDims, A, B, output], [outerDimsA, outerDimsB, outerDims], - isChannelsLast); if (hasBias) { - const biasComponents = isChannelsLast ? components : 1; - inputVariables.push(inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, biasComponents)); programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); - inputDependencies.push('rank'); } programUniforms.push(...createTensorShapeVariables(outputShapeTemp)); - const getShaderSource = (shaderHelper: ShaderHelper) => ` + const getShaderSource = (shaderHelper: ShaderHelper) => { + const batchShapeOrRank = outerDims.length; + const batchDims = internalVariable('batchDims', inputs[0].dataType, batchShapeOrRank, 1); + const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); + + const A = inputVariable('a', inputs[0].dataType, aShapeOrRank, components); + const B = inputVariable('b', inputs[1].dataType, bShapeOrRank, components); + const output = outputVariable('result', inputs[0].dataType, outputShapeTemp.length, components); + const inputVariables = [A, B]; + if (hasBias) { + const biasComponents = isChannelsLast ? components : 1; + inputVariables.push(inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, biasComponents)); + } + const uniforms: UniformsArrayType = + [{name: 'dim_a_outer', type: 'i32'}, {name: 'dim_b_outer', type: 'i32'}, {name: 'dim_inner', type: 'i32'}]; + if (activationAttributes.activation === 'Clip') { + uniforms.push({name: 'clip_max', type: 'f32'}, {name: 'clip_min', type: 'f32'}); + } + const applyActivation = getActivationSnippet(activationAttributes, output.type.value); + const declareFunctions = matMulReadWriteFnSource( + components, hasBias, applyActivation, [batchDims, A, B, output], [outerDimsA, outerDimsB, outerDims], + isChannelsLast); + return ` ${ - shaderHelper.registerUniform('dimAOuter', 'i32') - .registerUniform('dimBOuter', 'i32') - .registerUniform('dimInner', 'i32') - .registerInternalVariables(batchDims) - .declareVariables(...inputVariables, output)} - ${activationFunction} + shaderHelper.registerUniforms(uniforms).registerInternalVariables(batchDims).declareVariables( + ...inputVariables, output)} ${declareFunctions} ${ - isVec4 ? makeMatMulPackedVec4Source(elementsPerThread, workgroupSize, dataType, batchDims) : - makeMatMulPackedSource(elementsPerThread, workgroupSize, dataType, batchDims)} + isVec4 ? makeMatMulPackedVec4Source(elementsPerThread, workgroupSize, dataType, batchDims) : + makeMatMulPackedSource(elementsPerThread, workgroupSize, dataType, batchDims)} `; - // TODO: turn clipMax and clipMin to uniforms. + }; return { name: 'MatMul', shaderCache: { - hint: activationAttributes.activationCacheKey + `${elementsPerThread}` + - `${isVec4}` + - `${isChannelsLast}`, + hint: `${elementsPerThread};${activationAttributes.activation};${isVec4};${isChannelsLast}`, inputDependencies }, getRunData: () => ({ diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts index 21b4953d3f90c..f81d6577890c5 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts @@ -3,9 +3,9 @@ import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; -import {ProgramInfo, ProgramUniform} from '../types'; +import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types'; -import {createTensorShapeVariables, getMaxComponents, inputVariable, outputVariable, ShaderHelper} from './common'; +import {createTensorShapeVariables, getMaxComponents, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from './common'; import {calculateOutputShape, ConvAttributes} from './conv'; import {getActivationSnippet} from './fuse-utils'; @@ -27,52 +27,75 @@ export const createGroupedConvProgramInfo = xShape, wShape, attributes.dilations, attributes.pads, attributes.strides, isChannelLast); const outputSize = ShapeUtil.size(outputShape); - const output = outputVariable('output', inputs[0].dataType, outputShape); - const {activationFunction, applyActivation} = getActivationSnippet(attributes, output.type.value); - const x = inputVariable('x', inputs[0].dataType, xShape); - const w = inputVariable('w', inputs[1].dataType, wShape); - const inputVars = [x, w]; + const programUniforms: ProgramUniform[] = [ + {type: 'uint32', data: outputSize}, {type: 'uint32', data: attributes.dilations}, + {type: 'uint32', data: [attributes.strides[0], attributes.strides[1]]}, + {type: 'uint32', data: [attributes.pads[0], attributes.pads[1]]}, {type: 'uint32', data: outputChannelsPerGroup} + ]; + if (attributes.activation === 'Clip') { + programUniforms.push( + {type: 'float32', data: attributes.clipMax!}, {type: 'float32', data: attributes.clipMin!}); + } + programUniforms.push( + ...createTensorShapeVariables(xShape), ...createTensorShapeVariables(wShape), + ...createTensorShapeVariables(outputShape)); + const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank']; if (hasBias) { - inputVars.push(inputVariable('b', inputs[2].dataType, inputs[2].dims)); + programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); + inputDependencies.push('rank'); } + programUniforms.push(...createTensorShapeVariables(outputShape)); - const getShaderSource = (shaderHelper: ShaderHelper) => ` - const strides: vec2 = vec2(${attributes.strides[0]}u, ${attributes.strides[1]}u); - const pads: vec2 = vec2(${attributes.pads[0]}u, ${attributes.pads[1]}u); - - ${shaderHelper.declareVariables(...inputVars, output)} + const getShaderSource = (shaderHelper: ShaderHelper) => { + const output = outputVariable('output', inputs[0].dataType, outputShape.length); + const applyActivation = getActivationSnippet(attributes, output.type.value); + const x = inputVariable('x', inputs[0].dataType, xShape.length); + const w = inputVariable('w', inputs[1].dataType, wShape.length); + const inputVars = [x, w]; + if (hasBias) { + inputVars.push(inputVariable('b', inputs[2].dataType, inputs[2].dims)); + } - ${activationFunction} + const uniforms: UniformsArrayType = [ + {name: 'output_size', type: 'u32'}, {name: 'dilations', type: 'u32', length: attributes.dilations.length}, + {name: 'strides', type: 'u32', length: 2}, {name: 'pads', type: 'u32', length: 2}, + {name: 'output_channels_per_group', type: 'u32'} + ]; + if (attributes.activation === 'Clip') { + uniforms.push({name: 'clip_max', type: 'f32'}, {name: 'clip_min', type: 'f32'}); + } + return ` + ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVars, output)} ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} let outputIndices = ${output.offsetToIndices('global_idx')}; let batch: u32 = outputIndices[0]; let output_channel: u32 = outputIndices[${isChannelLast ? 3 : 1}]; let xRCCorner: vec2 = vec2(outputIndices[${isChannelLast ? 1 : 2}], outputIndices[${ - isChannelLast ? 2 : 3}]) * strides - pads; - let group_id: u32 = output_channel / ${outputChannelsPerGroup}u; + isChannelLast ? 2 : 3}]) * uniforms.strides - uniforms.pads; + let group_id: u32 = output_channel / uniforms.output_channels_per_group; var value: ${output.type.value} = ${output.type.value}(0); - for (var wInChannel: u32 = 0u; wInChannel < ${wShape[1]}u; wInChannel++) { - let input_channel = group_id * ${wShape[1]}u + wInChannel; - for (var wHeight: u32 = 0u; wHeight < ${wShape[2]}u; wHeight++) { - let xHeight = xRCCorner.x + wHeight * ${attributes.dilations[0]}u; + for (var wInChannel: u32 = 0u; wInChannel < uniforms.w_shape[1]; wInChannel++) { + let input_channel = group_id * uniforms.w_shape[1] + wInChannel; + for (var wHeight: u32 = 0u; wHeight < uniforms.w_shape[2]; wHeight++) { + let xHeight = xRCCorner.x + wHeight * uniforms.dilations[0]; - if (xHeight < 0u || xHeight >= ${xShape[isChannelLast ? 1 : 2]}u) { + if (xHeight < 0u || xHeight >= uniforms.x_shape[${isChannelLast ? 1 : 2}]) { continue; } - for (var wWidth: u32 = 0u; wWidth < ${wShape[3]}u; wWidth++) { - let xWidth = xRCCorner.y + wWidth * ${attributes.dilations[1]}u; - if (xWidth < 0u || xWidth >= ${xShape[isChannelLast ? 2 : 3]}u) { + for (var wWidth: u32 = 0u; wWidth < uniforms.w_shape[3]; wWidth++) { + let xWidth = xRCCorner.y + wWidth * uniforms.dilations[1]; + if (xWidth < 0u || xWidth >= uniforms.x_shape[${isChannelLast ? 2 : 3}]) { continue; } let xVal = ${ - isChannelLast ? x.get('batch', 'xHeight', 'xWidth', 'input_channel') : - x.get('batch', 'input_channel', 'xHeight', 'xWidth')}; + isChannelLast ? x.get('batch', 'xHeight', 'xWidth', 'input_channel') : + x.get('batch', 'input_channel', 'xHeight', 'xWidth')}; let wVal = ${w.get('output_channel', 'wInChannel', 'wHeight', 'wWidth')}; value += xVal*wVal; } @@ -82,15 +105,17 @@ export const createGroupedConvProgramInfo = ${applyActivation} ${output.setByOffset('global_idx', 'value')} }`; + }; return { name: 'GroupedConv', - shaderCache: {hint: attributes.cacheKey}, + shaderCache: {hint: attributes.cacheKey, inputDependencies}, getRunData: () => ({ outputs: [{ dims: squeezeOutputShapeFunction ? squeezeOutputShapeFunction(outputShape) : outputShape, dataType: inputs[0].dataType }], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, + programUniforms }), getShaderSource, }; @@ -114,7 +139,7 @@ export const createGroupedConvVectorizeProgramInfo = const xNumber = (outputNumber - 1) * attributes.strides[1] + wShape[1]; const getShaderSource = (shaderHelper: ShaderHelper) => { const output = outputVariable('output', inputs[0].dataType, outputShapeInShader.length, components); - const {activationFunction, applyActivation} = getActivationSnippet(attributes, output.type.value); + const applyActivation = getActivationSnippet(attributes, output.type.value); const x = inputVariable('x', inputs[0].dataType, xShape.length, components); const w = inputVariable('w', inputs[1].dataType, wShape.length, components); const inputVars = [x, w]; @@ -129,7 +154,6 @@ export const createGroupedConvVectorizeProgramInfo = .registerUniform('strides', 'i32', 2) .registerUniform('pads', 'i32', 2) .declareVariables(...inputVars, output)} - ${activationFunction} ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} let width0 = uniforms.output_shape[3]; @@ -179,7 +203,7 @@ export const createGroupedConvVectorizeProgramInfo = return { name: 'GroupedConv-Vectorize', shaderCache: { - hint: `${attributes.activationCacheKey};${components};${outputNumber};${xNumber};${wShape[0]};${wShape[1]}`, + hint: `${attributes.cacheKey};${components};${outputNumber};${xNumber};${wShape[0]};${wShape[1]}`, inputDependencies: hasBias ? ['rank', 'rank', 'type'] : ['rank', 'rank'] }, getRunData: () => ({ diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts index 32b1d52ed94ca..33d16754c737a 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts @@ -2,7 +2,6 @@ // Licensed under the MIT License. import {TensorView} from '../../tensor-view'; -import {createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext} from '../types'; import {createConv2DTransposeMatMulProgramInfo} from './3rd-party/conv_backprop_mm_webgpu'; @@ -59,7 +58,6 @@ export interface ConvTransposeAttributes extends ConvAttributes { readonly outputShape: readonly number[]; } - const getAdjustedConvTransposeAttributes = (attributes: T, inputs: readonly TensorView[]): T => { const kernelShape = attributes.kernelShape.slice(); @@ -96,11 +94,7 @@ const getAdjustedConvTransposeAttributes = // always return a new object so does not modify the original attributes const newAttributes: T = Object.assign({}, attributes); - const cacheKey = attributes.cacheKey + [ - kernelShape.join('n,'), pads.join(','), strides.join(','), outputPadding.join(','), outputShape.join(','), - dilations.join(',') - ].join('_'); - Object.assign(newAttributes, {kernelShape, pads, outputPadding, outputShape, dilations, strides, cacheKey}); + Object.assign(newAttributes, {kernelShape, pads, outputPadding, outputShape, dilations, strides}); return newAttributes; }; @@ -119,7 +113,7 @@ export const parseConvTransposeAttributes = (attributes: Record const wIsConst = (attributes.wIsConst as () => boolean)(); const outputPadding = attributes.outputPadding as [number, number, number, number]; const outputShape = attributes.outputShape as [number, number]; - return createAttributeWithCacheKey({ + return { autoPad, format, dilations, @@ -130,8 +124,9 @@ export const parseConvTransposeAttributes = (attributes: Record pads, strides, wIsConst, - ...activationAttributes - }); + ...activationAttributes, + cacheKey: `${attributes.format};${activationAttributes.activation};` + }; }; const validateInputs = (inputs: readonly TensorView[], attributes: ConvTransposeAttributes): void => { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts index 7af2c5db49f40..5afec0389fac8 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts @@ -3,7 +3,7 @@ import {TensorView} from '../../tensor-view'; import {PoolConvUtil} from '../../util'; -import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; +import {AttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext} from '../types'; import {createConv2DMatMulProgramInfo} from './3rd-party/conv2d_mm_webgpu'; @@ -110,7 +110,7 @@ const getAdjustedConvAttributes = (attributes: T, inpu // always return a new object so does not modify the original attributes const newAttributes: T = Object.assign({}, attributes); - Object.assign(newAttributes, {kernelShape, pads, cacheKey: attributes.cacheKey}); + Object.assign(newAttributes, {kernelShape, pads}); return newAttributes; }; @@ -126,8 +126,18 @@ export const parseConvAttributes = (attributes: Record): ConvAt const strides = attributes.strides as [number, number]; const wIsConst = (attributes.w_is_const as () => boolean)(); - return createAttributeWithCacheKey( - {autoPad, format, dilations, group, kernelShape, pads, strides, wIsConst, ...activationAttributes}); + return { + autoPad, + format, + dilations, + group, + kernelShape, + pads, + strides, + wIsConst, + ...activationAttributes, + cacheKey: `${attributes.format};${activationAttributes.activation};` + }; }; const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attributes: ConvAttributes): void => { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts b/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts index 0b5c0db2b5112..2e0aa33a957dc 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts @@ -7,30 +7,21 @@ export interface InternalActivationAttributes { readonly activation: string; readonly clipMin?: number; readonly clipMax?: number; - readonly activationCacheKey: string; } -export const getActivationSnippet = (attributes: InternalActivationAttributes, valueType: string): - {activationFunction: string; applyActivation: string} => { - switch (attributes.activation) { - case 'Relu': - return {activationFunction: '', applyActivation: `value = max(value, ${valueType}(0.0));`}; - case 'Sigmoid': - return { - activationFunction: '', - applyActivation: `value = (${valueType}(1.0) / (${valueType}(1.0) + exp(-value)));` - }; - case 'Clip': - return { - activationFunction: `const clip_min_=${valueType}(${attributes.clipMin!});const clip_max_=${valueType}(${ - attributes.clipMax!});`, - applyActivation: 'value = clamp(value, clip_min_, clip_max_);' - }; - // TODO: adding other activations that can be fused. - default: - return {activationFunction: '', applyActivation: ''}; - } - }; +export const getActivationSnippet = (attributes: InternalActivationAttributes, valueType: string): string => { + switch (attributes.activation) { + case 'Relu': + return `value = max(value, ${valueType}(0.0));`; + case 'Sigmoid': + return `value = (${valueType}(1.0) / (${valueType}(1.0) + exp(-value)));`; + case 'Clip': + return `value = clamp(value, ${valueType}(uniforms.clip_min), ${valueType}(uniforms.clip_max));`; + // TODO: adding other activations that can be fused. + default: + return ''; + } +}; export const parseInternalActivationAttributes = (attributes: Record|undefined): InternalActivationAttributes => { @@ -38,7 +29,7 @@ export const parseInternalActivationAttributes = if (activation === 'Clip') { const [clipMin, clipMax] = attributes?.activation_params as [number, number] || [MIN_CLIP, MAX_CLIP]; - return {activation, clipMax, clipMin, activationCacheKey: `${activation}:${clipMin},${clipMax}`}; + return {activation, clipMax, clipMin}; } - return {activation, activationCacheKey: activation}; + return {activation}; }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts index de9309d1e436f..c946ea6366123 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts @@ -6,7 +6,7 @@ import {BroadcastUtil, ShapeUtil} from '../../util'; import {ComputeContext, ProgramInfo, ProgramUniform} from '../types'; import {createMatmulProgramInfo} from './3rd-party/matmul_packed_webgpu'; -import {createTensorShapeVariables, getBroadcastDims, getMaxComponents, IndicesHelper, inputVariable, internalVariable, outputVariable, ShaderHelper,} from './common'; +import {createTensorShapeVariables, getBroadcastDims, getMaxComponents, IndicesHelper, inputVariable, internalVariable, outputVariable, ShaderHelper, UniformsArrayType,} from './common'; import {getActivationSnippet, InternalActivationAttributes} from './fuse-utils'; export const createNaiveMatmulProgramInfo = @@ -27,11 +27,19 @@ export const createNaiveMatmulProgramInfo = const outerDims = reshapedOutputShape ? reshapedOutputShape.slice(0, -2) : outputShape.slice(0, -2); const batchSize = ShapeUtil.size(outerDims); const outputShapeInShader = [batchSize, M, N]; + const programUniforms: ProgramUniform[] = [ {type: 'uint32', data: outputSize}, {type: 'uint32', data: M}, {type: 'uint32', data: N}, - {type: 'uint32', data: K}, ...createTensorShapeVariables(outerDims), ...createTensorShapeVariables(aShape), - ...createTensorShapeVariables(bShape) + {type: 'uint32', data: K} ]; + if (activationAttributes.activation === 'Clip') { + programUniforms.push( + {type: 'float32', data: activationAttributes.clipMax!}, + {type: 'float32', data: activationAttributes.clipMin!}); + } + programUniforms.push( + ...createTensorShapeVariables(outerDims), ...createTensorShapeVariables(aShape), + ...createTensorShapeVariables(bShape)); if (hasBias) { programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); } @@ -42,7 +50,7 @@ export const createNaiveMatmulProgramInfo = const a = inputVariable('a', inputs[0].dataType, aShape.length, aComponents); const b = inputVariable('b', inputs[1].dataType, bShape.length, components); const output = outputVariable('output', inputs[0].dataType, outputShapeInShader.length, components); - const {activationFunction, applyActivation} = getActivationSnippet(activationAttributes, output.type.value); + const applyActivation = getActivationSnippet(activationAttributes, output.type.value); const inputVariables = [a, b]; let processBias = ''; if (hasBias) { @@ -57,6 +65,14 @@ export const createNaiveMatmulProgramInfo = const outerDimsB = bShape.slice(0, -2); const broadCastADims = getBroadcastDims(outerDimsA, outerDims); const broadCastBDims = getBroadcastDims(outerDimsB, outerDims); + const uniforms: UniformsArrayType = [ + {name: 'output_size', type: 'u32'}, {name: 'M', type: 'u32'}, {name: 'N', type: 'u32'}, + {name: 'K', type: 'u32'} + ]; + if (activationAttributes.activation === 'Clip') { + uniforms.push({name: 'clip_max', type: 'f32'}, {name: 'clip_min', type: 'f32'}); + } + const getIndices = (variable: IndicesHelper, broadCastDims: number[]) => { const rank = variable.rank; const name = variable.name; @@ -96,15 +112,10 @@ export const createNaiveMatmulProgramInfo = return ` ${ - shaderHelper.registerUniform('outputSize', 'u32') - .registerUniform('M', 'u32') - .registerUniform('N', 'u32') - .registerUniform('K', 'u32') - .registerInternalVariables(batchDims) - .declareVariables(...inputVariables, output)} - ${activationFunction} + shaderHelper.registerUniforms(uniforms).registerInternalVariables(batchDims).declareVariables( + ...inputVariables, output)} ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} let col = (global_idx % (uniforms.N / ${components})) * ${components}; var index1 = global_idx / (uniforms.N / ${components}); let stride1 = uniforms.M / ${outputNumber}; @@ -134,8 +145,7 @@ export const createNaiveMatmulProgramInfo = return { name: 'MatMulNaive', shaderCache: { - hint: `${activationAttributes.activationCacheKey}_${components}_${aComponents}_${outputNumber}_${ - isChannelsLast}`, + hint: `${activationAttributes.activation};${components};${aComponents};${outputNumber};${isChannelsLast}`, inputDependencies: hasBias ? ['rank', 'rank', 'rank'] : ['rank', 'rank'] }, getRunData: () => ({ @@ -166,9 +176,8 @@ export const matMul = (context: ComputeContext): void => { const N = outputShape[outputShape.length - 1]; const K = context.inputs[0].dims[context.inputs[0].dims.length - 1]; if (N < 8 && K < 8) { - context.compute( - createNaiveMatmulProgramInfo(context.inputs, {activation: '', activationCacheKey: ''}, outputShape)); + context.compute(createNaiveMatmulProgramInfo(context.inputs, {activation: ''}, outputShape)); } else { - context.compute(createMatmulProgramInfo(context.inputs, {activation: '', activationCacheKey: ''}, outputShape)); + context.compute(createMatmulProgramInfo(context.inputs, {activation: ''}, outputShape)); } };