From e797f53aa16b57793eb3c810a8d4b9bd1a6a923e Mon Sep 17 00:00:00 2001 From: Arthur Islamov Date: Thu, 28 Sep 2023 13:47:05 +0400 Subject: [PATCH] Attention WIP, Conv speedup --- .../webgpu/ops/3rd-party/conv2d_mm_webgpu.ts | 13 +- .../ops/3rd-party/matmul_packed_webgpu.ts | 237 +++++++++--------- js/web/lib/wasm/jsep/webgpu/ops/attentiion.ts | 75 +++--- js/web/lib/wasm/jsep/webgpu/ops/conv.ts | 158 ++++++------ js/web/lib/wasm/jsep/webgpu/ops/matmul.ts | 17 +- .../lib/wasm/jsep/webgpu/program-manager.ts | 68 ++--- 6 files changed, 288 insertions(+), 280 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts index 5c692ee59f55a..e61bce222d8ec 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts @@ -28,7 +28,7 @@ import {ConvAttributes} from '../conv'; import {Activation, activationFnSnippet, biasActivationSnippet, typeSnippet} from './activation_util'; import {utilFunctions} from './conv_util'; import {makeMatMulPackedSource, makeMatMulPackedVec4Source} from './matmul_packed_webgpu'; -import { tensorTypeToWsglStorageType } from '../common' +import { tensorTypeToWsglStorageType } from '../common'; const conv2dCommonSnippet = (isChannelsLast: boolean, fitAOuter: boolean, fitBOuter: boolean, fitInner: boolean, addBias = false, @@ -161,17 +161,14 @@ export const createConv2DMatMulProgramInfo = const outWidth = isChannelsLast ? outputShape[2] : outputShape[3]; const outHeight = isChannelsLast ? outputShape[1] : outputShape[2]; const outChannels = isChannelsLast ? outputShape[3] : outputShape[1]; - const isVec4 = (((inChannels % 4 === 0 || inChannels % 3 === 0) && isChannelsLast) || - (outWidth % 4 === 0 && !isChannelsLast)) && - outChannels % 4 === 0; + // TODO: enable vec4 for NCHW + const isVec4 = isChannelsLast && (inChannels % 4 === 0 || inChannels % 3 === 0) && outChannels % 4 === 0; // TODO: fine tune size const dispatchX = isChannelsLast ? outChannels : outWidth * outHeight; const dispatchY = isChannelsLast ? outWidth * outHeight : outChannels; - const workGroupSize: [number, number, number] = - isVec4 ? [8, 8, 1] : [dispatchX <= 4 ? 4 : 16, dispatchX > 4 && dispatchY <= 4 ? 4 : 16, 1]; - const elementsPerThread = - isVec4 ? [4, 4, 1] : [dispatchX <= 4 ? 1 : 2, dispatchX > 4 && dispatchY <= 4 ? 1 : 2, 1]; + const workGroupSize: [number, number, number] = [8, 8, 1]; + const elementsPerThread = dimAOuter <= 8 ? [4, 1, 1] : [4, 4, 1]; const dispatch = [ Math.ceil(dispatchX / workGroupSize[0] / elementsPerThread[0]), Math.ceil(dispatchY / workGroupSize[1] / elementsPerThread[1]), diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts index 7c3b012de6f26..36e842124afd6 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts @@ -86,11 +86,11 @@ export const makeMatMulPackedVec4Source = const innerElementSize = tileAWidth / workgroupSize[0]; const rowPerThreadB = tileInner / workgroupSize[1]; - if (!(((transposeA && innerElementSize === 4 && workPerThread[1] === 4) || - (!transposeA && (innerElementSize === 3 || innerElementSize === 4))) && - tileAWidth % workgroupSize[0] === 0 && tileInner % workgroupSize[1] === 0 && workPerThread[0] === 4)) { - throw new Error(`If transposeA ${transposeA} is true, innerElementSize ${ - innerElementSize} and workPerThread[1] ${workPerThread[1]} must be 4. + if (!(((transposeA && innerElementSize === 4 && workPerThread[1] === 4) || + (!transposeA && (innerElementSize === 3 || innerElementSize === 4))) && + tileAWidth % workgroupSize[0] === 0 && tileInner % workgroupSize[1] === 0 && workPerThread[0] === 4)) { + throw new Error(`If transposeA ${transposeA} is true, innerElementSize ${ + innerElementSize} and workPerThread[1] ${workPerThread[1]} must be 4. Otherwise, innerElementSize ${innerElementSize} must be 3 or 4. tileAWidth ${tileAWidth} must be divisible by workgroupSize[0]${workgroupSize[0]}. tileInner ${ tileInner} must be divisible by workgroupSize[1] ${workgroupSize[1]}. colPerThread ${ @@ -139,7 +139,7 @@ fn main(@builtin(local_invocation_id) localId : vec3, let inputRow = tileRowB + innerRow; let inputCol = tileCol; mm_Bsub[inputRow][inputCol] = mm_readB(batch, kStart + inputRow, globalCol${ - batchDims ? ', batchIndices' : ''}); + batchDims ? ', batchIndices' : ''}); } kStart = kStart + tileInner; workgroupBarrier(); @@ -161,7 +161,7 @@ fn main(@builtin(local_invocation_id) localId : vec3, mm_write(batch, globalRow + innerRow, globalCol, acc[innerRow]); } }`; - }; + }; const writeDataToSubASnippet = (transpose: boolean, batchDims?: IndicesHelper) => { if (transpose) { @@ -181,7 +181,7 @@ const writeDataToSubASnippet = (transpose: boolean, batchDims?: IndicesHelper) = }; const readDataFromSubASnippet = (transposeA: boolean) => - transposeA ? 'let ACached = mm_Asub[k][tileRow + innerRow];' : 'let ACached = mm_Asub[tileRow + innerRow][k];'; + transposeA ? 'let ACached = mm_Asub[k][tileRow + innerRow];' : 'let ACached = mm_Asub[tileRow + innerRow][k];'; // sequentialAccessByThreads means sequential data in memory is accessed by // threads, instead of a single thread (default behavior). @@ -194,17 +194,17 @@ export const makeMatMulPackedSource = const tileAWidth = transposeA ? tileAOuter : tileInner; const tileAHight = transposeA ? tileInner : tileAOuter; - if (!(tileAHight % workgroupSize[1] === 0 && tileAWidth % workgroupSize[0] === 0 && - tileInner % workgroupSize[1] === 0)) { - throw new Error(`tileAHight ${tileAHight} must be divisible by workgroupSize[1]${ - workgroupSize[1]}, tileAWidth ${tileAWidth} must be divisible by workgroupSize[0]${ - workgroupSize[0]}, tileInner ${tileInner} must be divisible by workgroupSize[1]${workgroupSize[1]}`); - } - const rowPerThreadA = tileAHight / workgroupSize[1]; - const colPerThreadA = tileAWidth / workgroupSize[0]; - const rowPerThreadB = tileInner / workgroupSize[1]; - const matmulSnippet = sequentialAccessByThreads ? - ` + if (!(tileAHight % workgroupSize[1] === 0 && tileAWidth % workgroupSize[0] === 0 && + tileInner % workgroupSize[1] === 0)) { + throw new Error(`tileAHight ${tileAHight} must be divisible by workgroupSize[1]${ + workgroupSize[1]}, tileAWidth ${tileAWidth} must be divisible by workgroupSize[0]${ + workgroupSize[0]}, tileInner ${tileInner} must be divisible by workgroupSize[1]${workgroupSize[1]}`); + } + const rowPerThreadA = tileAHight / workgroupSize[1]; + const colPerThreadA = tileAWidth / workgroupSize[0]; + const rowPerThreadB = tileInner / workgroupSize[1]; + const matmulSnippet = sequentialAccessByThreads ? + ` let localRow = i32(localId.y); let localCol = i32(localId.x); let globalRowStart = i32(workgroupId.y) * ${tileAOuter}; @@ -237,8 +237,8 @@ export const makeMatMulPackedSource = } for (var innerRow = 0; innerRow < rowPerThread; innerRow = innerRow + 1) { let ACached = ${ - transposeA ? `mm_Asub[k][localRow + innerRow * ${workgroupSize[1]}];` : - `mm_Asub[localRow + innerRow * ${workgroupSize[1]}][k];`} + transposeA ? `mm_Asub[k][localRow + innerRow * ${workgroupSize[1]}];` : + `mm_Asub[localRow + innerRow * ${workgroupSize[1]}][k];`} for (var innerCol = 0; innerCol < colPerThread; innerCol = innerCol + 1) { acc[innerRow][innerCol] = acc[innerRow][innerCol] + ACached * BCached[innerCol]; @@ -255,7 +255,7 @@ export const makeMatMulPackedSource = } } ` : - ` + ` let tileRow = i32(localId.y) * rowPerThread; let tileCol = i32(localId.x) * colPerThread; @@ -343,48 +343,49 @@ fn main(@builtin(local_invocation_id) localId : vec3, ${matmulSnippet} } `; - }; + }; const matMulReadWriteFnSource = - (component: number, hasBias: boolean, applyActivation: string, variables: IndicesHelper[]): string => { - const batchAVariable = variables[0]; - const batchBVariable = variables[1]; - const batchVariable = variables[2]; - const aVariable = variables[3]; - const bVariable = variables[4]; - const outputVariable = variables[5]; - const broadCastADims = getBroadcastDims(batchAVariable.shape, batchVariable.shape); - const broadCastBDims = getBroadcastDims(batchBVariable.shape, batchVariable.shape); - const dataType = tensorTypeToWsglStorageType(variables[0].type.tensor); - const getAIndices = () => { - const aRank = aVariable.shape.length; - const batchRank = batchVariable.shape.length; - let resStr = `var aIndices: ${aVariable.type.indices};`; - for (let i = aRank - 2 - 1, j = batchRank - 1; i >= 0; i--, j--) { - resStr += `\naIndices[${i}] = ${batchRank > 1 ? `batchIndices[${j}]` : 'batchIndices'};`; - } - broadCastADims.forEach(i => { - resStr += `\naIndices[${i}] = 0;`; - }); - resStr += `\naIndices[${aRank - 2}] = u32(row); + (component: number, hasBias: boolean, applyActivation: string, variables: IndicesHelper[], isChannelsLast = false): + string => { + const batchAVariable = variables[0]; + const batchBVariable = variables[1]; + const batchVariable = variables[2]; + const aVariable = variables[3]; + const bVariable = variables[4]; + const outputVariable = variables[5]; + const broadCastADims = getBroadcastDims(batchAVariable.shape, batchVariable.shape); + const broadCastBDims = getBroadcastDims(batchBVariable.shape, batchVariable.shape); + const dataType = tensorTypeToWsglStorageType(variables[0].type.tensor); + const getAIndices = () => { + const aRank = aVariable.shape.length; + const batchRank = batchVariable.shape.length; + let resStr = `var aIndices: ${aVariable.type.indices};`; + for (let i = aRank - 2 - 1, j = batchRank - 1; i >= 0; i--, j--) { + resStr += `\naIndices[${i}] = ${batchRank > 1 ? `batchIndices[${j}]` : 'batchIndices'};`; + } + broadCastADims.forEach(i => { + resStr += `\naIndices[${i}] = 0;`; + }); + resStr += `\naIndices[${aRank - 2}] = u32(row); aIndices[${aRank - 1}] = u32(colIn);`; - return resStr; - }; - const getBIndices = () => { - const bRank = bVariable.shape.length; - const batchRank = batchVariable.shape.length; - let resStr = `var bIndices: ${bVariable.type.indices};`; - for (let i = bRank - 2 - 1, j = batchRank - 1; i >= 0; i--, j--) { - resStr += `\nbIndices[${i}] = ${batchRank > 1 ? `batchIndices[${j}]` : 'batchIndices'};`; - } - broadCastBDims.forEach(i => { - resStr += `\nbIndices[${i}] = 0;`; - }); - resStr += `\nbIndices[${bRank - 2}] = u32(row); + return resStr; + }; + const getBIndices = () => { + const bRank = bVariable.shape.length; + const batchRank = batchVariable.shape.length; + let resStr = `var bIndices: ${bVariable.type.indices};`; + for (let i = bRank - 2 - 1, j = batchRank - 1; i >= 0; i--, j--) { + resStr += `\nbIndices[${i}] = ${batchRank > 1 ? `batchIndices[${j}]` : 'batchIndices'};`; + } + broadCastBDims.forEach(i => { + resStr += `\nbIndices[${i}] = 0;`; + }); + resStr += `\nbIndices[${bRank - 2}] = u32(row); bIndices[${bRank - 1}] = u32(colIn);`; - return resStr; - }; - const source = ` + return resStr; + }; + const source = ` fn mm_readA(batch: i32, row: i32, colIn: i32, batchIndices: ${batchVariable.type.indices}) -> ${ typeSnippet(component, dataType)} { var value = ${typeSnippet(component, dataType)}(0.0); @@ -414,61 +415,63 @@ const matMulReadWriteFnSource = if (row < dimAOuter && col < dimBOuter) { var value = valueIn; let coords = vec3(batch, row, colIn); - ${hasBias ? 'value = value + bias[colIn];' : ''} + ${hasBias ? `value = value + ${isChannelsLast ? 'bias[colIn]' : `${typeSnippet(component, dataType)}(bias[row])`};` : ''} ${applyActivation} ${outputVariable.setByIndices('vec3(coords)', 'value')} } } `; - return source; - }; + return source; + }; export const createMatmulProgramInfo = - (metadata: ProgramMetadata, inputs: readonly TensorView[], activationAttributes: InternalActivationAttributes, - outputShape: readonly number[], reshapedOutputShape?: readonly number[]): ProgramInfo => { - const aShape = inputs[0].dims; - const bShape = inputs[1].dims; - - const outerDimsA = aShape.slice(0, -2); - const outerDimsB = bShape.slice(0, -2); - const outerDims = reshapedOutputShape ? reshapedOutputShape.slice(0, -2) : outputShape.slice(0, -2); - const batchDims = inputVariable('batchDims', inputs[0].dataType, outerDims); - const batchADims = inputVariable('batchADims', inputs[0].dataType, outerDimsA); - const batchBDims = inputVariable('batchBDims', inputs[0].dataType, outerDimsB); - const variables = [batchADims, batchBDims, batchDims]; - const batchSize = ShapeUtil.size(outerDims); - - const dimAOuter = aShape[aShape.length - 2]; - const dimInner = aShape[aShape.length - 1]; - const dimBOuter = bShape[bShape.length - 1]; - const isVec4 = dimInner % 4 === 0 && dimBOuter % 4 === 0; - const {activationFunction, applyActivation} = getActicationSnippet(activationAttributes); - - // TODO: fine tune size - const elementsPerThread = dimAOuter <= 8 ? [4, 1, 1] : [4, 4, 1]; - const workgroupSize: [number, number, number] = [8, 8, 1]; - const dispatch = [ - Math.ceil(dimBOuter / workgroupSize[0] / elementsPerThread[0]), - Math.ceil(dimAOuter / workgroupSize[1] / elementsPerThread[1]), - Math.ceil(batchSize / workgroupSize[2] / elementsPerThread[2]) - ]; - - const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); - const components = isVec4 ? 4 : 1; - const A = inputVariable('a', inputs[0].dataType, [...outerDimsA, dimAOuter, dimInner / components], components); - const B = inputVariable('b', inputs[1].dataType, [...outerDimsB, dimInner, dimBOuter / components], components); - const output = - outputVariable('result', inputs[0].dataType, [batchSize, dimAOuter, dimBOuter / components], components); - variables.push(A); - variables.push(B); - variables.push(output); - const inputVariables = [A, B]; - const hasBias = inputs.length > 2; - const declareFunctions = matMulReadWriteFnSource(components, hasBias, applyActivation, variables); - if (hasBias) { - inputVariables.push(inputVariable('bias', inputs[2].dataType, [dimBOuter / components], components)); - } - const getShaderSource = (shaderHelper: ShaderHelper) => ` + (metadata: ProgramMetadata, inputs: readonly TensorView[], activationAttributes: InternalActivationAttributes, + outputShape: readonly number[], reshapedOutputShape?: readonly number[], + isChannelsLast = false /* only used for conv2dByMatMul*/): ProgramInfo => { + const aShape = inputs[0].dims; + const bShape = inputs[1].dims; + + const outerDimsA = aShape.slice(0, -2); + const outerDimsB = bShape.slice(0, -2); + const outerDims = reshapedOutputShape ? reshapedOutputShape.slice(0, -2) : outputShape.slice(0, -2); + const batchDims = inputVariable('batchDims', inputs[0].dataType, outerDims); + const batchADims = inputVariable('batchADims', inputs[0].dataType, outerDimsA); + const batchBDims = inputVariable('batchBDims', inputs[0].dataType, outerDimsB); + const variables = [batchADims, batchBDims, batchDims]; + const batchSize = ShapeUtil.size(outerDims); + + const dimAOuter = aShape[aShape.length - 2]; + const dimInner = aShape[aShape.length - 1]; + const dimBOuter = bShape[bShape.length - 1]; + const isVec4 = dimInner % 4 === 0 && dimBOuter % 4 === 0; + const {activationFunction, applyActivation} = getActicationSnippet(activationAttributes); + + // TODO: fine tune size + const elementsPerThread = dimAOuter <= 8 ? [4, 1, 1] : [4, 4, 1]; + const workgroupSize: [number, number, number] = [8, 8, 1]; + const dispatch = [ + Math.ceil(dimBOuter / workgroupSize[0] / elementsPerThread[0]), + Math.ceil(dimAOuter / workgroupSize[1] / elementsPerThread[1]), + Math.ceil(batchSize / workgroupSize[2] / elementsPerThread[2]) + ]; + + const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); + const components = isVec4 ? 4 : 1; + const A = inputVariable('a', inputs[0].dataType, [...outerDimsA, dimAOuter, dimInner / components], components); + const B = inputVariable('b', inputs[1].dataType, [...outerDimsB, dimInner, dimBOuter / components], components); + const output = + outputVariable('result', inputs[0].dataType, [batchSize, dimAOuter, dimBOuter / components], components); + variables.push(A); + variables.push(B); + variables.push(output); + const inputVariables = [A, B]; + const hasBias = inputs.length > 2; + const declareFunctions = matMulReadWriteFnSource(components, hasBias, applyActivation, variables, isChannelsLast); + if (hasBias) { + const biasComponents = isChannelsLast ? components : 1; + inputVariables.push(inputVariable('bias', inputs[2].dataType, inputs[2].dims, biasComponents)); + } + const getShaderSource = (shaderHelper: ShaderHelper) => ` const dimAOuter: i32 = ${dimAOuter}; const dimBOuter: i32 = ${dimBOuter}; const dimInner: i32 = ${dimInner}; @@ -476,13 +479,13 @@ export const createMatmulProgramInfo = ${declareFunctions} ${activationFunction} ${ - isVec4 ? makeMatMulPackedVec4Source(elementsPerThread, workgroupSize, dataType, batchDims) : - makeMatMulPackedSource(elementsPerThread, workgroupSize, dataType, batchDims)} + isVec4 ? makeMatMulPackedVec4Source(elementsPerThread, workgroupSize, dataType, batchDims) : + makeMatMulPackedSource(elementsPerThread, workgroupSize, dataType, batchDims)} ${batchDims.impl()}`; - return { - ...metadata, - outputs: [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}], - getShaderSource, - dispatchGroup: () => ({x: dispatch[0], y: dispatch[1], z: dispatch[2]}) - }; - }; \ No newline at end of file + return { + ...metadata, + outputs: [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}], + getShaderSource, + dispatchGroup: () => ({x: dispatch[0], y: dispatch[1], z: dispatch[2]}) + }; + }; \ No newline at end of file diff --git a/js/web/lib/wasm/jsep/webgpu/ops/attentiion.ts b/js/web/lib/wasm/jsep/webgpu/ops/attentiion.ts index 004176b2822c6..b88094e535cf6 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/attentiion.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/attentiion.ts @@ -15,7 +15,6 @@ import { sumVector, tensorTypeToWsglStorageType } from './common'; -import {transposeProgramMetadata} from './transpose'; export enum AttentionQkvFormat { unknown, // enum value not set, or depends on qkv projection implementation details @@ -385,7 +384,7 @@ const computeVxAttentionScore = (context: ComputeContext, probs: TensorView, v: const TILE_SIZE = 8; const dispatch = { x: Math.ceil(params.vHeadSize / TILE_SIZE), - y: Math.ceil(params.totalSequenceLength / TILE_SIZE), + y: Math.ceil(params.sequenceLength / TILE_SIZE), z: params.batchSize * params.numHeads, }; @@ -430,9 +429,11 @@ const computeVxAttentionScore = (context: ComputeContext, probs: TensorView, v: workgroupBarrier(); } - let headOffset = headIdx * M * N; +let batchIdx = workgroup_id.z / ${params.numHeads}; + let currentBatchHeadNumber = workgroup_id.z % ${params.numHeads}; + let headOffset = (batchIdx * M * ${params.numHeads} + currentBatchHeadNumber) * ${params.vHeadSize}; if (lm < M && ln < N) { - let outputIdx = headOffset + lm * N + ln; + let outputIdx = batchIdx * ${params.sequenceLength * params.vHiddenSize} + lm * ${params.vHiddenSize} + currentBatchHeadNumber * ${params.vHeadSize} + ln; output[outputIdx] = ${sumVector('value', components)}; } }`; @@ -457,41 +458,37 @@ export const applyAttention = computeVxAttentionScore(context, probs, v, parameters); // const attentionResult = computeVxAttentionScore(context, probs, v, parameters); - // - // const outputShape = [parameters.batchSize, parameters.sequenceLength, parameters.vHiddenSize]; - // const input = inputVariable('input', q.dataType, attentionResult.dims); - // const output = outputVariable('output', q.dataType, outputShape); - // const getShaderSource = (shaderHelper: ShaderHelper) => ` - // ${shaderHelper.declareVariables(input, output)} - // - // ${shaderHelper.mainStart(parameters.numHeads * parameters.batchSize)} - // let headOffset = global_idx % ${parameters.vHeadSize}; - // let sequenceIndex = (global_idx / ${parameters.vHeadSize}) % ${parameters.sequenceLength}; - // let batchIndex = global_idx / ${parameters.numHeads}; - // let headIndex = global_idx % ${parameters.numHeads}; - // // let in = input[0]; - // - // var inputOffset = ${parameters.sequenceLength * parameters.vHeadSize} * global_idx; - // var outputOffset = (batchIndex * ${parameters.sequenceLength * parameters.numHeads} + headIndex) - // * ${parameters.vHeadSize}; - // for (var j = 0; j < ${parameters.sequenceLength}; j++) { - // for (var i: u32 = 0; i < ${parameters.vHeadSize}; i++) { - // output[outputOffset + i] = input[inputOffset + i]; - // } - // inputOffset += ${parameters.vHeadSize}; - // outputOffset += ${parameters.vHiddenSize}; - // } - // }`; - // - // context.compute( - // { - // ...transposeProgramMetadata, - // cacheHint: JSON.stringify(parameters), - // outputs: [{dims: outputShape, dataType: DataType.float, gpuDataType: GpuDataType.default}], - // getShaderSource, - // dispatchGroup: () => ({ x: 1 }), - // }, - // {inputs: [attentionResult], outputs: [0]}); + +// const outputShape = [parameters.batchSize, parameters.sequenceLength, parameters.vHiddenSize]; +// const input = inputVariable('input', q.dataType, attentionResult.dims); +// const output = outputVariable('output', q.dataType, outputShape); +// const outputSize = parameters.batchSize * parameters.sequenceLength * parameters.vHeadSize * parameters.numHeads; +// const getShaderSource = (shaderHelper: ShaderHelper) => ` +// ${shaderHelper.declareVariables(input, output)} +// +// ${shaderHelper.mainStart()} +// ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} +// let h = global_idx % ${parameters.vHeadSize}; +// let n = (global_idx / ${parameters.vHeadSize}) % ${parameters.sequenceLength}; +// let s = (global_idx / (${parameters.vHeadSize * parameters.numHeads})) % ${parameters.sequenceLength}; +// let b = global_idx / (${parameters.vHeadSize * parameters.sequenceLength * parameters.numHeads}); +// +// var inputOffset = b * ${parameters.numHeads * parameters.sequenceLength * parameters.vHeadSize} + n * ${parameters.sequenceLength * parameters.vHeadSize} + s * ${parameters.vHeadSize} + h; +// var outputOffset = b * ${parameters.sequenceLength * parameters.vHiddenSize} + s * ${parameters.vHiddenSize} + n * ${parameters.vHeadSize} + h; +// +// output[outputOffset] = input[inputOffset]; +// }`; +// +// context.compute( +// { +// name: 'AttentionTranspose', +// cacheHint: JSON.stringify(parameters), +// inputTypes: [GpuDataType.default], +// outputs: [{dims: outputShape, dataType: DataType.float, gpuDataType: GpuDataType.default}], +// getShaderSource, +// dispatchGroup: () => ({ x: Math.ceil(outputSize / 64) }), +// }, +// {inputs: [attentionResult], outputs: [0]}); }; const prepare = (context: ComputeContext, parameters: AttentionParameters, attributes: AttentionAttrs) => { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts index c13cb08d7780d..a303596c4b87a 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts @@ -13,21 +13,21 @@ import {createMatmulProgramInfoLoader} from './matmul'; import {createTransposeProgramInfo, TransposeAttributes, transposeProgramMetadata} from './transpose'; export const calculateOutputShape = - (inputShape: readonly number[], kernelShape: readonly number[], dilations: readonly number[], - adjustPads: readonly number[], strides: readonly number[], isChannelLast: boolean): number[] => { - const batchSize = inputShape[0]; - const inputSpatialShape = inputShape.slice(isChannelLast ? 1 : 2, isChannelLast ? 3 : 4); - const spatialRank = inputSpatialShape.length; - const outChannels = kernelShape[0]; - const kernelSpatialShape = kernelShape.slice(2); - const dilatedKernelShape = kernelSpatialShape.map((v, i) => v + (v - 1) * (dilations[i] - 1)); - const inputSpatialShapeWithPad = inputSpatialShape.map((v, i) => v + adjustPads[i] + adjustPads[i + spatialRank]); - const outputShape = - inputSpatialShapeWithPad.map((v, i) => Math.floor((v - dilatedKernelShape[i] + strides[i]) / strides[i])); - outputShape.splice(0, 0, batchSize); - outputShape.splice(isChannelLast ? 3 : 1, 0, outChannels); - return outputShape; - }; + (inputShape: readonly number[], kernelShape: readonly number[], dilations: readonly number[], + adjustPads: readonly number[], strides: readonly number[], isChannelLast: boolean): number[] => { + const batchSize = inputShape[0]; + const inputSpatialShape = inputShape.slice(isChannelLast ? 1 : 2, isChannelLast ? 3 : 4); + const spatialRank = inputSpatialShape.length; + const outChannels = kernelShape[0]; + const kernelSpatialShape = kernelShape.slice(2); + const dilatedKernelShape = kernelSpatialShape.map((v, i) => v + (v - 1) * (dilations[i] - 1)); + const inputSpatialShapeWithPad = inputSpatialShape.map((v, i) => v + adjustPads[i] + adjustPads[i + spatialRank]); + const outputShape = + inputSpatialShapeWithPad.map((v, i) => Math.floor((v - dilatedKernelShape[i] + strides[i]) / strides[i])); + outputShape.splice(0, 0, batchSize); + outputShape.splice(isChannelLast ? 3 : 1, 0, outChannels); + return outputShape; + }; export interface ConvAttributes extends InternalActivationAttributes, AttributeWithCacheKey { readonly autoPad: string; @@ -104,8 +104,8 @@ const getAdjustedConvAttributes = (attributes: T, inpu } const pads = attributes.pads.slice(); PoolConvUtil.adjustPadsBasedOnAutoPad( - inputs[0].dims, attributes.strides, attributes.dilations, kernelShape, pads, attributes.format === 'NHWC', - attributes.autoPad); + inputs[0].dims, attributes.strides, attributes.dilations, kernelShape, pads, attributes.format === 'NHWC', + attributes.autoPad); // always return a new object so does not modify the original attributes const newAttributes: T = Object.assign({}, attributes); @@ -126,7 +126,7 @@ export const parseConvAttributes = (attributes: Record): ConvAt const wIsConst = (attributes.w_is_const as () => boolean)(); return createAttributeWithCacheKey( - {autoPad, format, dilations, group, kernelShape, pads, strides, wIsConst, ...activationAttributes}); + {autoPad, format, dilations, group, kernelShape, pads, strides, wIsConst, ...activationAttributes}); }; const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attributes: ConvAttributes): void => { @@ -134,15 +134,14 @@ const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attribut // check attributes - const hasBias = inputs.length === 3; // const hasPreluActivationWeights = false; /* TODO: add support for prelu activation weights */ - const isChannelsLast = attributes.format === 'NHWC'; - if (!isChannelsLast || attributes.group !== 1) { + if (attributes.group !== 1) { context.compute(createGroupedConvProgramInfoLoader(inputs, adjustedAttributes)); return; } - // const batchSize = context.inputs[0].dims[0]; + const isChannelsLast = attributes.format === 'NHWC'; + const hasBias = inputs.length === 3; const inputHeight = inputs[0].dims[isChannelsLast ? 1 : 2]; const inputWidth = inputs[0].dims[isChannelsLast ? 2 : 3]; const inputChannels = inputs[0].dims[isChannelsLast ? 3 : 1]; @@ -150,63 +149,75 @@ const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attribut const weightWidth = inputs[1].dims[3]; const outputShape = calculateOutputShape( - inputs[0].dims, inputs[1].dims, attributes.dilations, adjustedAttributes.pads, attributes.strides, - isChannelsLast); + inputs[0].dims, inputs[1].dims, attributes.dilations, adjustedAttributes.pads, attributes.strides, + isChannelsLast); const outHeight = outputShape[isChannelsLast ? 1 : 2]; const outWidth = outputShape[isChannelsLast ? 2 : 3]; const outChannels = outputShape[isChannelsLast ? 3 : 1]; - const batch = outputShape[0]; - const sameSize = - isChannelsLast && weightHeight === inputHeight && weightWidth === inputWidth && attributes.autoPad === 'VALID'; + const sameSize = isChannelsLast && weightHeight === inputHeight && weightWidth === inputWidth && + attributes.pads[0] === 0 && attributes.pads[1] === 0; if (sameSize || - (weightHeight === 1 && weightWidth === 1 && attributes.dilations[0] === 1 && attributes.dilations[1] === 1 && - attributes.strides[0] === 1 && attributes.strides[1] === 1 && attributes.pads[0] === 0 && - attributes.pads[1] === 0)) { + (weightHeight === 1 && weightWidth === 1 && attributes.dilations[0] === 1 && attributes.dilations[1] === 1 && + attributes.strides[0] === 1 && attributes.strides[1] === 1 && attributes.pads[0] === 0 && + attributes.pads[1] === 0)) { // conv2dByMatMul - const transposedWeight = (context.kernelCustomData.wT as TensorView | undefined) ?? + const batch = outputShape[0]; + let xReshaped, wReshaped, matmulOutputShape; + const matmulInputs = []; + if (isChannelsLast) { + const transposedWeight = (context.kernelCustomData.wT as TensorView | undefined) ?? context.compute( - { - ...transposeProgramMetadata, - cacheHint: weightTransposeAttribute.cacheKey, - get: () => createTransposeProgramInfo(inputs[1], weightTransposeAttribute.perm) - }, - {inputs: [1], outputs: [attributes.wIsConst ? -2 : -1]})[0]; - if (attributes.wIsConst && !context.kernelCustomData.wT) { - context.kernelCustomData.wT = transposedWeight; + { + ...transposeProgramMetadata, + cacheHint: weightTransposeAttribute.cacheKey, + get: () => createTransposeProgramInfo(inputs[1], weightTransposeAttribute.perm) + }, + {inputs: [1], outputs: [attributes.wIsConst ? -2 : -1]})[0]; + if (attributes.wIsConst && !context.kernelCustomData.wT) { + context.kernelCustomData.wT = transposedWeight; + } + if (sameSize) { + const sharedDim = inputHeight * inputWidth * inputChannels; + xReshaped = inputs[0].reshape([1, batch, sharedDim]); + wReshaped = transposedWeight.reshape([1, sharedDim, outChannels]); + matmulOutputShape = [1, batch, outChannels]; + } else { + xReshaped = inputs[0].reshape([batch, inputHeight * inputWidth, inputChannels]); + wReshaped = transposedWeight.reshape([1, inputChannels, outChannels]); + matmulOutputShape = [batch, outHeight * outWidth, outChannels]; + } + matmulInputs.push(xReshaped); + matmulInputs.push(wReshaped); + } else { + xReshaped = inputs[0].reshape([batch, inputChannels, inputHeight * inputWidth]); + wReshaped = inputs[1].reshape([1, outChannels, inputChannels]); + matmulOutputShape = [batch, outChannels, outHeight * outWidth]; + matmulInputs.push(wReshaped); + matmulInputs.push(xReshaped); } - - const matmulInputs = []; - matmulInputs.push(inputs[0].reshape([batch, inputHeight * inputWidth, inputChannels])); - matmulInputs.push(transposedWeight.reshape([1, inputChannels, outChannels])); if (hasBias) { matmulInputs.push(inputs[2]); } - const matmulOutputShape = [batch, outHeight * outWidth, outChannels]; context.compute( - createMatmulProgramInfoLoader(matmulInputs, adjustedAttributes, outputShape, matmulOutputShape), - {inputs: matmulInputs}); - + createMatmulProgramInfoLoader(matmulInputs, adjustedAttributes, outputShape, matmulOutputShape, isChannelsLast), + {inputs: matmulInputs}); return; } // TODO: implement conv2dWithIm2Col() - const dimAOuter = isChannelsLast ? outHeight * outWidth : outChannels; - const dimBOuter = isChannelsLast ? outChannels : outHeight * outWidth; - const dimInner = weightHeight * weightWidth * inputChannels; - const sequentialAccessByThreads = /* backend.adapterInfo.isIntel() */ true; // STEP.1: transpose weight const transposedWeight = (context.kernelCustomData.wT as TensorView | undefined) ?? - context.compute( - { - ...transposeProgramMetadata, - cacheHint: weightTransposeAttribute.cacheKey, - get: () => createTransposeProgramInfo(inputs[1], weightTransposeAttribute.perm) - }, - {inputs: [1], outputs: [attributes.wIsConst ? -2 : -1]})[0]; + context.compute( + { + ...transposeProgramMetadata, + cacheHint: weightTransposeAttribute.cacheKey, + get: () => createTransposeProgramInfo(inputs[1], weightTransposeAttribute.perm) + }, + {inputs: [1], outputs: [attributes.wIsConst ? -2 : -1]})[0]; if (attributes.wIsConst && !context.kernelCustomData.wT) { context.kernelCustomData.wT = transposedWeight; } @@ -214,19 +225,18 @@ const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attribut // STEP.2: prepare reshaped inputs const convInputs = [inputs[0], transposedWeight]; if (hasBias) { - if (!isChannelsLast && inputs[2].dims.length === 1) { - convInputs.push(inputs[2].reshape([inputs[2].dims[0], 1, 1])); - } else { - convInputs.push(inputs[2]); - } + convInputs.push(inputs[2]); } // STEP.3: compute matmul + const dimAOuter = isChannelsLast ? outHeight * outWidth : outChannels; + const dimBOuter = isChannelsLast ? outChannels : outHeight * outWidth; + const dimInner = weightHeight * weightWidth * inputChannels; context.compute( - createConv2DMatMulProgramInfoLoader( - convInputs, adjustedAttributes, outputShape, dimAOuter, dimBOuter, dimInner, hasBias, - sequentialAccessByThreads), - {inputs: convInputs}); + createConv2DMatMulProgramInfoLoader( + convInputs, adjustedAttributes, outputShape, dimAOuter, dimBOuter, dimInner, hasBias, + sequentialAccessByThreads), + {inputs: convInputs}); }; const conv1d = (context: ComputeContext, attributes: ConvAttributes): void => { @@ -234,11 +244,11 @@ const conv1d = (context: ComputeContext, attributes: ConvAttributes): void => { const isChannelLast = attributes.format === 'NHWC'; const inputs = [ context.inputs[0].reshape( - isChannelLast ? - // [N, W, C] -> [N, H=1, W, C] - [context.inputs[0].dims[0], 1, context.inputs[0].dims[1], context.inputs[0].dims[2]] : - // [N, C, W] -> [N, C, H=1, W] - [context.inputs[0].dims[0], context.inputs[0].dims[1], 1, context.inputs[0].dims[2]]), + isChannelLast ? + // [N, W, C] -> [N, H=1, W, C] + [context.inputs[0].dims[0], 1, context.inputs[0].dims[1], context.inputs[0].dims[2]] : + // [N, C, W] -> [N, C, H=1, W] + [context.inputs[0].dims[0], context.inputs[0].dims[1], 1, context.inputs[0].dims[2]]), //[FILTER_OUT_CHANNEL, FILTER_IN_CHANNEL, kW] -> [FILTER_OUT_CHANNEL, FILTER_IN_CHANNEL, kH=1, kW] context.inputs[1].reshape([context.inputs[1].dims[0], context.inputs[1].dims[1], 1, context.inputs[1].dims[2]]) ]; @@ -251,8 +261,8 @@ const conv1d = (context: ComputeContext, attributes: ConvAttributes): void => { const kernelShape = [1].concat(attributes.kernelShape); const adjustedAttributes = getAdjustedConvAttributes({...attributes, pads, strides, dilations, kernelShape}, inputs); context.compute(createGroupedConvProgramInfoLoader( - inputs, adjustedAttributes, - outputShape => isChannelLast ? [outputShape[0], outputShape[2], outputShape[3]] : [])); + inputs, adjustedAttributes, + outputShape => isChannelLast ? [outputShape[0], outputShape[2], outputShape[3]] : [])); }; export const conv = (context: ComputeContext, attributes: ConvAttributes): void => { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts index a8b4b0d6c61e4..0b6446d285183 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts @@ -12,19 +12,20 @@ import {InternalActivationAttributes} from './fuse-utils'; const createMatmulProgramMetadata = (hasBias: boolean, cacheHint: string) => ({ name: 'MatMul', inputTypes: hasBias ? [GpuDataType.default, GpuDataType.default, GpuDataType.default] : - [GpuDataType.default, GpuDataType.default], + [GpuDataType.default, GpuDataType.default], cacheHint }); export const createMatmulProgramInfoLoader = - (inputs: readonly TensorView[], activationAttributes: InternalActivationAttributes, outputShape: readonly number[], - reshapedOutputShape?: readonly number[]): ProgramInfoLoader => { - const metadata = createMatmulProgramMetadata(inputs.length > 2, activationAttributes.activationCacheKey); - return { - ...metadata, - get: () => createMatmulProgramInfo(metadata, inputs, activationAttributes, outputShape, reshapedOutputShape) - }; + (inputs: readonly TensorView[], activationAttributes: InternalActivationAttributes, outputShape: readonly number[], + reshapedOutputShape?: readonly number[], isChannelsLast = false): ProgramInfoLoader => { + const metadata = createMatmulProgramMetadata(inputs.length > 2, activationAttributes.activationCacheKey); + return { + ...metadata, + get: () => createMatmulProgramInfo( + metadata, inputs, activationAttributes, outputShape, reshapedOutputShape, isChannelsLast) }; + }; const validateInputs = (inputs: readonly TensorView[]): void => { if (!inputs || inputs.length !== 2) { diff --git a/js/web/lib/wasm/jsep/webgpu/program-manager.ts b/js/web/lib/wasm/jsep/webgpu/program-manager.ts index 45ee20dfc1364..a230a4bace3ee 100644 --- a/js/web/lib/wasm/jsep/webgpu/program-manager.ts +++ b/js/web/lib/wasm/jsep/webgpu/program-manager.ts @@ -57,40 +57,40 @@ export class ProgramManager { this.backend.pendingDispatchNumber++; - this.backend.endComputePass(); - const kernelId = this.backend.currentKernelId!; - const kernelName = this.backend.kernels.get(kernelId)![0]; - for (const output of outputs) { - const stagingBuffer = this.backend.device.createBuffer({ - size: output.buffer.size, - // eslint-disable-next-line no-bitwise - usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST, - }); - - const commandEncoder = this.backend.getCommandEncoder(); - commandEncoder?.copyBufferToBuffer( - output.buffer, - 0, // Source offset - stagingBuffer, - 0, // Destination offset - output.buffer.size, - ); - this.backend.flush(); - - stagingBuffer - .mapAsync( - GPUMapMode.READ, - 0, // Offset - output.buffer.size, - ) - .then(() => { - const copyArrayBuffer = stagingBuffer.getMappedRange(0, output.buffer.size); - const data = copyArrayBuffer.slice(0); - stagingBuffer.unmap(); - console.log(`${kernelId}|${kernelName}:`); - console.log(new Float32Array(data)); - }); - } + // this.backend.endComputePass(); + // const kernelId = this.backend.currentKernelId!; + // const kernelName = this.backend.kernels.get(kernelId)![0]; + // for (const output of outputs) { + // const stagingBuffer = this.backend.device.createBuffer({ + // size: output.buffer.size, + // // eslint-disable-next-line no-bitwise + // usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST, + // }); + // + // const commandEncoder = this.backend.getCommandEncoder(); + // commandEncoder?.copyBufferToBuffer( + // output.buffer, + // 0, // Source offset + // stagingBuffer, + // 0, // Destination offset + // output.buffer.size, + // ); + // this.backend.flush(); + // + // stagingBuffer + // .mapAsync( + // GPUMapMode.READ, + // 0, // Offset + // output.buffer.size, + // ) + // .then(() => { + // const copyArrayBuffer = stagingBuffer.getMappedRange(0, output.buffer.size); + // const data = copyArrayBuffer.slice(0); + // stagingBuffer.unmap(); + // console.log(`${kernelId}|${kernelName}:`); + // console.log(new Float32Array(data)); + // }); + // } if (profilingEnabled) { // profiling write end timestamp