Skip to content

Commit

Permalink
Separate shaders impl from matmul.ts (OP impl) to avoid circular dependency
Browse files Browse the repository at this point in the history
  • Loading branch information
jiangzhaoming committed Nov 1, 2024
1 parent b70e48c commit d30b5b5
Show file tree
Hide file tree
Showing 4 changed files with 196 additions and 188 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ import {
getActivationSnippet,
InternalActivationAttributes,
} from '../fuse-utils';
import { convertOutputBatchIndicesToInputBatchIndices } from '../matmul';
import { convertOutputBatchIndicesToInputBatchIndices } from '../matmul-shaders';

import { typeSnippet } from './activation_util';

Expand Down
2 changes: 1 addition & 1 deletion js/web/lib/wasm/jsep/webgpu/ops/conv.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import { computeConv3DInfo, createConv3DNaiveProgramInfo } from './3rd-party/con
import { createMatmulProgramInfo } from './3rd-party/matmul_packed_webgpu';
import { createGroupedConvProgramInfo, createGroupedConvVectorizeProgramInfo } from './conv-grouped';
import { InternalActivationAttributes, parseInternalActivationAttributes } from './fuse-utils';
import { createNaiveMatmulProgramInfo } from './matmul';
import { createNaiveMatmulProgramInfo } from './matmul-shaders';
import { createTransposeProgramInfo } from './transpose';

export const calculateOutputShape = (
Expand Down
191 changes: 191 additions & 0 deletions js/web/lib/wasm/jsep/webgpu/ops/matmul-shaders.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import { DataType } from '../../../wasm-common';
import { TensorView } from '../../tensor-view';
import { ShapeUtil } from '../../util';
import { ProgramInfo, ProgramUniform } from '../types';

import {
createTensorShapeVariables,
getElementAt,
getMaxComponents,
IndicesHelper,
inputVariable,
internalVariable,
outputVariable,
ShaderHelper,
tensorTypeToWsglStorageType,
UniformsArrayType,
} from './common';
import {
appendActivationUniforms,
appendActivationUniformsData,
getActivationSnippet,
InternalActivationAttributes,
} from './fuse-utils';

// Helper that convert output batch indices to input batch indices using only the rank and
// the shape information in uniform
/**
 * Generates WGSL that maps output batch indices back to input batch indices,
 * using only the input's rank and the shape information available in uniforms.
 *
 * Implements broadcasting on the batch dimensions: for each input batch dim,
 * if its size is not 1 the corresponding (right-aligned) output batch index is
 * copied; if it is 1 the index is pinned to 0.
 *
 * @param targetIndicesName - name of the WGSL indices variable to write into.
 * @param inputVar - IndicesHelper of the input tensor (was `inputVariable`;
 *   renamed so it no longer shadows the `inputVariable` helper imported from
 *   `./common`).
 * @param inputBatchRank - number of batch dims of the input.
 * @param outputBatchRank - number of batch dims of the output; assumed to be
 *   >= inputBatchRank.
 * @param batchIndicesName - name of the WGSL variable holding output batch indices.
 * @returns a WGSL snippet performing the index conversion.
 */
export const convertOutputBatchIndicesToInputBatchIndices = (
  targetIndicesName: string,
  inputVar: IndicesHelper,
  inputBatchRank: number,
  outputBatchRank: number,
  batchIndicesName: string,
) => {
  // Assume outputBatchRank >= inputBatchRank; the leading
  // (outputBatchRank - inputBatchRank) output batch dims have no counterpart
  // in the input and are skipped (numpy-style right-aligned broadcasting).
  const extendingInputRank = outputBatchRank - inputBatchRank;
  return `
${Array.from({ length: inputBatchRank })
  .map(
    (_, i) => `
if (${getElementAt(inputVar.shape, i, inputVar.rank)} != 1) {
  ${inputVar.indicesSet(targetIndicesName, i, getElementAt(batchIndicesName, i + extendingInputRank, outputBatchRank))}
} else {
  ${inputVar.indicesSet(targetIndicesName, i, 0)}
}`,
  )
  .join('')}
`;
};

/**
 * Builds the ProgramInfo for a naive (non-tiled) MatMul WGSL shader.
 *
 * Each shader invocation computes an `outputNumber x components` tile of the
 * output: `outputNumber` rows along M times `components` vectorized columns
 * along N, accumulating over K with `fma`.
 *
 * @param inputs - [A, B] and optionally a bias tensor as inputs[2].
 * @param activationAttributes - fused activation appended after the matmul.
 * @param outputShape - logical output shape.
 * @param reshapedOutputShape - when given, its leading (non-matrix) dims are
 *   used as the in-shader batch dims instead of outputShape's.
 * @param isChannelsLast - only used for conv2dByMatMul; selects how the bias
 *   is indexed and added.
 * @param squeezeOutputShapeFunction - optional transform applied to
 *   outputShape to produce the final output dims in getRunData.
 */
export const createNaiveMatmulProgramInfo = (
  inputs: readonly TensorView[],
  activationAttributes: InternalActivationAttributes,
  outputShape: readonly number[],
  reshapedOutputShape?: readonly number[],
  isChannelsLast = false /* only used for conv2dByMatMul */,
  squeezeOutputShapeFunction?: (shape: readonly number[]) => number[],
): ProgramInfo => {
  const aShape = inputs[0].dims;
  const bShape = inputs[1].dims;

  // GEMM dimensions: A is [..., M, K], B is [..., K, N], output is [..., M, N].
  const M = aShape[aShape.length - 2];
  const N = bShape[bShape.length - 1];
  const K = aShape[aShape.length - 1];
  // Vector widths (from getMaxComponents) for the N, K and M dimensions.
  const components = getMaxComponents(N);
  const aComponents = getMaxComponents(K);
  const outputNumber = getMaxComponents(M);
  // Number of invocations; each produces outputNumber * components output elements.
  const outputSize = ShapeUtil.size(outputShape) / components / outputNumber;
  const hasBias = inputs.length > 2;
  // Batch dims are everything except the trailing [M, N] pair.
  const outerDims = reshapedOutputShape ? reshapedOutputShape.slice(0, -2) : outputShape.slice(0, -2);
  const batchSize = ShapeUtil.size(outerDims);
  // The shader always views the output as rank-3: [batch, M, N].
  const outputShapeInShader = [batchSize, M, N];

  // Scalar uniforms first; activation uniforms and tensor shapes follow.
  const programUniforms: ProgramUniform[] = [
    { type: DataType.uint32, data: outputSize },
    { type: DataType.uint32, data: M },
    { type: DataType.uint32, data: N },
    { type: DataType.uint32, data: K },
  ];
  appendActivationUniformsData(activationAttributes, programUniforms);
  programUniforms.push(...createTensorShapeVariables(outerDims, aShape, bShape));
  if (hasBias) {
    programUniforms.push(...createTensorShapeVariables(inputs[2].dims));
  }
  programUniforms.push(...createTensorShapeVariables(outputShapeInShader));

  const getShaderSource = (shaderHelper: ShaderHelper) => {
    // batch_dims carries the output batch shape via uniforms (internal variable).
    const batchDims = internalVariable('batch_dims', inputs[0].dataType, outerDims.length);
    const a = inputVariable('a', inputs[0].dataType, aShape.length, aComponents);
    const b = inputVariable('b', inputs[1].dataType, bShape.length, components);
    const output = outputVariable('output', inputs[0].dataType, outputShapeInShader.length, components);
    const baseType = tensorTypeToWsglStorageType(output.type.tensor);
    const applyActivation = getActivationSnippet(activationAttributes, output.type.value, baseType);
    const inputVariables = [a, b];
    let processBias = '';
    if (hasBias) {
      // Channels-last bias is vectorized per column; otherwise added per row as a scalar splat.
      const biasComponents = isChannelsLast ? components : 1;
      inputVariables.push(inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, biasComponents));
      processBias = `${
        isChannelsLast ? `value += bias[col / ${biasComponents}];` : `value += ${output.type.value}(bias[row + i]);`
      }`;
    }

    const uniforms: UniformsArrayType = [
      { name: 'output_size', type: 'u32' },
      { name: 'M', type: 'u32' },
      { name: 'N', type: 'u32' },
      { name: 'K', type: 'u32' },
    ];
    appendActivationUniforms(activationAttributes, uniforms);

    // Emits the unrolled inner-product step for one k-chunk: loads aComponents
    // columns of B, then for each of the outputNumber rows loads a (possibly
    // vectorized) element of A and accumulates with fma.
    const calcResult = (): string => {
      let calcStr = `var a_data: ${a.type.value};`;
      for (let i = 0; i < aComponents; i++) {
        calcStr += `
        let b_data${i} = b[(b_offset + (k + ${i}) * uniforms.N + col) / ${components}];`;
      }
      for (let i = 0; i < outputNumber; i++) {
        calcStr += `a_data = a[(a_offset + (row + ${i}) * uniforms.K + k) / ${aComponents}];`;

        for (let j = 0; j < aComponents; j++) {
          calcStr += `
        values[${i}] = fma(${b.type.value}(a_data${aComponents === 1 ? '' : `[${j}]`}), b_data${j}, values[${i}]);\n`;
        }
      }
      return calcStr;
    };

    // Shader main: derive (batch, row, col) from global_idx, resolve broadcast
    // batch offsets for A and B, accumulate over K, then add bias, apply the
    // fused activation and store each row's vectorized result.
    return `
  ${shaderHelper
    .registerUniforms(uniforms)
    .registerInternalVariables(batchDims)
    .declareVariables(...inputVariables, output)}
  ${shaderHelper.mainStart()}
    ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
    let col = (global_idx % (uniforms.N / ${components})) * ${components};
    var index1 = global_idx / (uniforms.N / ${components});
    let stride1 = uniforms.M / ${outputNumber};
    let row = (index1 % stride1) * ${outputNumber};
    let batch = index1 / stride1;
    ${outputShape.length === 2 ? '' : `let batch_indices = ${batchDims.offsetToIndices('batch')};`}
    var a_indices: ${a.type.indices};
    ${convertOutputBatchIndicesToInputBatchIndices('a_indices', a, a.rank - 2, batchDims.rank, 'batch_indices')}
    ${a.indicesSet('a_indices', a.rank - 2, 0)}
    ${a.indicesSet('a_indices', a.rank - 1, 0)}
    let a_offset = ${a.indicesToOffset('a_indices')};
    var b_indices: ${b.type.indices};
    ${convertOutputBatchIndicesToInputBatchIndices('b_indices', b, b.rank - 2, batchDims.rank, 'batch_indices')}
    ${b.indicesSet('b_indices', b.rank - 2, 0)}
    ${b.indicesSet('b_indices', b.rank - 1, 0)}
    let b_offset = ${b.indicesToOffset('b_indices')};
    var values: array<${output.type.value}, ${outputNumber}>;
    for (var k: u32 = 0u; k < uniforms.K; k = k + ${aComponents}) {
      ${calcResult()}
    }
    for (var i = 0u; i < ${outputNumber}u; i++) {
      var value = values[i];
      ${processBias}
      ${applyActivation}
      let cur_indices = ${output.type.indices}(batch, row + i, col);
      let offset = ${output.indicesToOffset('cur_indices')};
      ${output.setByOffset(`offset / ${components}`, 'value')};
    }
  }
  `;
  };
  return {
    name: 'MatMulNaive',
    // Cache hint must cover everything that changes the generated source
    // besides the input ranks: activation, vector widths and bias layout.
    shaderCache: {
      hint: `${activationAttributes.activation};${components};${aComponents};${outputNumber};${isChannelsLast}`,
      inputDependencies: hasBias ? ['rank', 'rank', 'rank'] : ['rank', 'rank'],
    },
    getRunData: () => ({
      outputs: [
        {
          dims: squeezeOutputShapeFunction ? squeezeOutputShapeFunction(outputShape) : outputShape,
          dataType: inputs[0].dataType,
        },
      ],
      dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) },
      programUniforms,
    }),
    getShaderSource,
  };
};
Loading

0 comments on commit d30b5b5

Please sign in to comment.