Skip to content

Commit

Permalink
Use pure-uniform batch broadcasting, remove inputBias, refactor param
Browse files Browse the repository at this point in the history
  • Loading branch information
jiangzhaoming committed Nov 7, 2024
1 parent d567dfe commit 2ce8aa0
Show file tree
Hide file tree
Showing 2 changed files with 207 additions and 152 deletions.
63 changes: 37 additions & 26 deletions js/web/lib/wasm/jsep/webgpu/ops/matmul-shaders.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,35 @@ import {/*BroadcastUtil,*/ ShapeUtil} from '../../util';
import {/*ComputeContext,*/ ProgramInfo, ProgramUniform} from '../types';

// import {createMatmulProgramInfo} from './3rd-party/matmul_packed_webgpu';
import {createTensorShapeVariables, getBroadcastDims, getMaxComponents, IndicesHelper, inputVariable, internalVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from './common';
import {createTensorShapeVariables, getElementAt, getMaxComponents, IndicesHelper, inputVariable, internalVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from './common';
import {appendActivationUniforms, appendActivationUniformsData, getActivationSnippet, InternalActivationAttributes} from './fuse-utils';

// Generates shader statements that translate the batch indices of the output into the batch
// indices of one input, relying only on the ranks known at generation time and on the input
// shape that the shader reads at run time.
//
// targetIndicesName: name of the shader variable that receives the input indices.
// inputVariable:     helper of the input; provides `shape`, `rank` and `indicesSet`.
// inputBatchRank:    number of batch dimensions of the input.
// outputBatchRank:   number of batch dimensions of the output.
// batchIndicesName:  name of the shader variable that holds the output batch indices.
//
// Returns the generated statements as a single string, wrapped in a leading and a trailing
// newline; one if/else statement is produced per input batch dimension.
export const convertOutputBatchIndicesToInputBatchIndices = (
  targetIndicesName: string,
  inputVariable: IndicesHelper,
  inputBatchRank: number,
  outputBatchRank: number,
  batchIndicesName: string,
) => {
  // Precondition: outputBatchRank >= inputBatchRank. The input batch dimensions are matched
  // against the trailing output batch dimensions, so this many leading output batch
  // dimensions have no counterpart in the input and are skipped.
  const skippedOutputDims = outputBatchRank - inputBatchRank;

  // Emits the statement for one input batch dimension: a dimension whose extent is 1 always
  // gets index 0, any other dimension takes the index of its matching output batch dimension.
  const emitDimension = (_: unknown, dim: number): string => {
    const extent = getElementAt(inputVariable.shape, dim, inputVariable.rank);
    const outputIndex = getElementAt(batchIndicesName, dim + skippedOutputDims, outputBatchRank);
    const followOutput = inputVariable.indicesSet(targetIndicesName, dim, outputIndex);
    const pinToZero = inputVariable.indicesSet(targetIndicesName, dim, 0);
    return ['', `if (${extent} != 1) {`, followOutput, '} else {', pinToZero, '}'].join('\n');
  };

  const statements = Array.from({ length: inputBatchRank }, emitDimension);
  return '\n' + statements.join('') + '\n';
};

export const createNaiveMatmulProgramInfo = (
inputs: readonly TensorView[],
activationAttributes: InternalActivationAttributes,
Expand Down Expand Up @@ -63,10 +89,6 @@ export const createNaiveMatmulProgramInfo = (
}`;
}

const outerDimsA = aShape.slice(0, -2);
const outerDimsB = bShape.slice(0, -2);
const broadCastADims = getBroadcastDims(outerDimsA, outerDims);
const broadCastBDims = getBroadcastDims(outerDimsB, outerDims);
const uniforms: UniformsArrayType = [
{ name: 'output_size', type: 'u32' },
{ name: 'M', type: 'u32' },
Expand All @@ -75,25 +97,6 @@ export const createNaiveMatmulProgramInfo = (
];
appendActivationUniforms(activationAttributes, uniforms);

const getIndices = (variable: IndicesHelper, broadCastDims: number[]) => {
const rank = variable.rank;
const name = variable.name;
if (rank === 2) {
return `var ${name}_indices = ${variable.type.indices}(0u, 0u);`;
}
const batchRank = batchDims.rank;
let resStr = `var ${name}_indices: ${variable.type.indices};`;
for (let i = rank - 2 - 1, j = batchRank - 1; i >= 0; i--, j--) {
resStr += `\n${name}_indices[${i}] = ${batchRank > 1 ? `batch_indices[${j}]` : 'batch_indices'};`;
}
broadCastDims.forEach((i) => {
resStr += `\n${name}_indices[${i}] = 0;`;
});
resStr += `${name}_indices[${rank - 2}] = 0u;
${name}_indices[${rank - 1}] = 0u;`;
return resStr;
};

const calcResult = (): string => {
let calcStr = `var a_data: ${a.type.value};`;
for (let i = 0; i < aComponents; i++) {
Expand Down Expand Up @@ -125,9 +128,17 @@ export const createNaiveMatmulProgramInfo = (
let batch = index1 / stride1;
${outputShape.length === 2 ? '' : `let batch_indices = ${batchDims.offsetToIndices('batch')};`}
${getIndices(a, broadCastADims)}
var a_indices: ${a.type.indices};
${convertOutputBatchIndicesToInputBatchIndices('a_indices', a, a.rank - 2, batchDims.rank, 'batch_indices')}
${a.indicesSet('a_indices', a.rank - 2, 0)}
${a.indicesSet('a_indices', a.rank - 1, 0)}
let a_offset = ${a.indicesToOffset('a_indices')};
${getIndices(b, broadCastBDims)}
var b_indices: ${b.type.indices};
${convertOutputBatchIndicesToInputBatchIndices('b_indices', b, b.rank - 2, batchDims.rank, 'batch_indices')}
${b.indicesSet('b_indices', b.rank - 2, 0)}
${b.indicesSet('b_indices', b.rank - 1, 0)}
let b_offset = ${b.indicesToOffset('b_indices')};
var values: array<${output.type.value}, ${outputNumber}>;
for (var k: u32 = 0u; k < uniforms.K; k = k + ${aComponents}) {
Expand Down
Loading

0 comments on commit 2ce8aa0

Please sign in to comment.