[js/webgpu] Use DataType as uniform cpu type (microsoft#19281)
This saves converting the data type to a string via tensorDataTypeEnumToString.
axinging authored Jan 31, 2024
1 parent bc2b24b commit 04823cd
Showing 37 changed files with 148 additions and 108 deletions.
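
Every file expanded below makes the same mechanical substitution: the type field of each ProgramUniform switches from a string literal such as 'uint32' to the numeric DataType enum exported by wasm-common. A minimal sketch of the before/after shape of one uniform entry (the import paths and the value 16 are illustrative placeholders, not lines taken from the commit):

import {DataType} from './lib/wasm/wasm-common';              // illustrative path
import {ProgramUniform} from './lib/wasm/jsep/webgpu/types';  // illustrative path

// Before this commit a uniform's CPU-side type was a string literal:
//   {type: 'uint32', data: 16}
// After it, the same entry carries the numeric enum directly, so no
// string conversion is needed when the uniform buffer is packed.
const uniforms: ProgramUniform[] = [{type: DataType.uint32, data: 16}];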
18 changes: 10 additions & 8 deletions web/lib/wasm/jsep/backend-webgpu.ts
@@ -3,7 +3,7 @@

import {Env, Tensor, TRACE, TRACE_FUNC_BEGIN, TRACE_FUNC_END} from 'onnxruntime-common';

-import {tensorDataTypeEnumToString} from '../wasm-common';
+import {DataType, tensorDataTypeEnumToString} from '../wasm-common';

import {configureLogger, LOG_DEBUG} from './log';
import {createView, TensorView} from './tensor-view';
@@ -453,10 +453,10 @@ export class WebGpuBackend {
return;
}
// https://www.w3.org/TR/WGSL/#alignof
-const sizeOfElement = v.type === 'float16' ? 2 : 4;
+const sizeOfElement = v.type === DataType.float16 ? 2 : 4;
let sizeOfVecOrMat;
let baseAlignment;
-if (v.type === 'float16') {
+if (v.type === DataType.float16) {
baseAlignment = data.length > 4 ? 16 : (data.length > 2 ? 8 : data.length * sizeOfElement);
sizeOfVecOrMat = data.length > 4 ? 16 : sizeOfElement * data.length;
} else {
@@ -470,7 +470,7 @@
// SizeOf(vec4<i32|u32|f32>). For float16 type, when data.length > 4, the uniform variable is of type
// array<mat2x4<f16>,N>, where N = Math.ceil(data.length / 8) and SizeOf(mat2x4<f16>) = 16. The total byte
// length is N * SizeOf(mat2x4<f16>).
-const elementPerVecOrMat = v.type === 'float16' ? 8 : 4;
+const elementPerVecOrMat = v.type === DataType.float16 ? 8 : 4;
currentOffset += data.length > 4 ? Math.ceil(data.length / elementPerVecOrMat) * sizeOfVecOrMat :
data.length * sizeOfElement;
});
@@ -483,15 +483,17 @@
programUniforms.forEach((v, i) => {
const offset = offsets[i];
const data = typeof v.data === 'number' ? [v.data] : v.data;
-if (v.type === 'int32') {
+if (v.type === DataType.int32) {
new Int32Array(arrayBuffer, offset, data.length).set(data);
-} else if (v.type === 'uint32') {
+} else if (v.type === DataType.uint32) {
new Uint32Array(arrayBuffer, offset, data.length).set(data);
-} else if (v.type === 'float16') {
+} else if (v.type === DataType.float16) {
// TODO: use Float16Array.
new Uint16Array(arrayBuffer, offset, data.length).set(data);
-} else {
+} else if (v.type === DataType.float) {
new Float32Array(arrayBuffer, offset, data.length).set(data);
+} else {
+throw new Error(`Unsupported uniform type: ${tensorDataTypeEnumToString(v.type)}`);
}
});

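As an aside (not part of the commit), the float16 branch of the sizing rule described in the comments above, written out as a self-contained TypeScript sketch; the helper name is invented for illustration:

// Byte length of a single float16 uniform entry, following the WGSL
// alignment rule used in backend-webgpu.ts above.
const float16UniformByteLength = (dataLength: number): number => {
  const sizeOfElement = 2;                         // f16 occupies 2 bytes
  const sizeOfVecOrMat = dataLength > 4 ? 16 : sizeOfElement * dataLength;
  const elementPerVecOrMat = 8;                    // mat2x4<f16> packs 8 elements
  return dataLength > 4 ? Math.ceil(dataLength / elementPerVecOrMat) * sizeOfVecOrMat :
                          dataLength * sizeOfElement;
};

// Example: 10 float16 values become array<mat2x4<f16>, 2>,
// i.e. ceil(10 / 8) = 2 matrices * 16 bytes = 32 bytes.
console.log(float16UniformByteLength(10));  // 32
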
7 changes: 4 additions & 3 deletions web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts
@@ -19,6 +19,7 @@
//
// modified to fit the needs of the project

+import {DataType} from '../../../../wasm-common';
import {LOG_DEBUG} from '../../../log';
import {TensorView} from '../../../tensor-view';
import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types';
@@ -189,9 +190,9 @@ export const createConv2DMatMulProgramInfo =
const elementsSize = isVec4 ? [innerElementSize, 4, 4] : [1, 1, 1];

const programUniforms: ProgramUniform[] = [
-{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner},
-{type: 'int32', data: [attributes.pads[0], attributes.pads[1]]}, {type: 'int32', data: attributes.strides},
-{type: 'int32', data: attributes.dilations}
+{type: DataType.int32, data: dimAOuter}, {type: DataType.int32, data: dimBOuter},
+{type: DataType.int32, data: dimInner}, {type: DataType.int32, data: [attributes.pads[0], attributes.pads[1]]},
+{type: DataType.int32, data: attributes.strides}, {type: DataType.int32, data: attributes.dilations}
];
appendActivationUniformsData(attributes, programUniforms);
programUniforms.push(
web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts
@@ -19,6 +19,7 @@
//
// modified to fit the needs of the project

+import {DataType} from '../../../../wasm-common';
import {LOG_DEBUG} from '../../../log';
import {TensorView} from '../../../tensor-view';
import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types';
@@ -197,9 +198,10 @@ export const createConv2DTransposeMatMulProgramInfo =
];

const programUniforms: ProgramUniform[] = [
-{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner},
-{type: 'int32', data: attributes.strides}, {type: 'int32', data: attributes.dilations},
-{type: 'int32', data: filterDims}, {type: 'int32', data: pads}
+{type: DataType.int32, data: dimAOuter}, {type: DataType.int32, data: dimBOuter},
+{type: DataType.int32, data: dimInner}, {type: DataType.int32, data: attributes.strides},
+{type: DataType.int32, data: attributes.dilations}, {type: DataType.int32, data: filterDims},
+{type: DataType.int32, data: pads}
];
appendActivationUniformsData(attributes, programUniforms);
programUniforms.push(
web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts
@@ -17,6 +17,7 @@

// sampled from [@tensorflow/tfjs] tfjs-backend-webgpu/src/conv_backprop_webgpu.ts

+import {DataType} from '../../../../wasm-common';
import {LOG_DEBUG} from '../../../log';
import {TensorView} from '../../../tensor-view';
import {ShapeUtil} from '../../../util';
@@ -264,9 +265,10 @@ export const createConvTranspose2DProgramInfo =
const outputChannelsPerGroup = wShape[1];

const programUniforms: ProgramUniform[] = [
-{type: 'int32', data: outputSize}, {type: 'uint32', data: strides}, {type: 'uint32', data: filterDims},
-{type: 'uint32', data: dilations}, {type: 'uint32', data: effectiveFilterDims}, {type: 'int32', data: pads},
-{type: 'uint32', data: inputChannelsPerGroup}, {type: 'uint32', data: outputChannelsPerGroup},
+{type: DataType.int32, data: outputSize}, {type: DataType.uint32, data: strides},
+{type: DataType.uint32, data: filterDims}, {type: DataType.uint32, data: dilations},
+{type: DataType.uint32, data: effectiveFilterDims}, {type: DataType.int32, data: pads},
+{type: DataType.uint32, data: inputChannelsPerGroup}, {type: DataType.uint32, data: outputChannelsPerGroup},
...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(inputs[1].dims)
];
if (hasBias) {
web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts
@@ -19,6 +19,7 @@
//
// modified to fit the needs of the project

+import {DataType} from '../../../../wasm-common';
import {TensorView} from '../../../tensor-view';
import {ShapeUtil} from '../../../util';
import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types';
@@ -447,8 +448,10 @@ export const createMatmulProgramInfo =
const bShapeTemp = [...outerDimsB, dimInner, dimBOuter / components];
const bRank = bShapeTemp.length;
const outputShapeTemp = [batchSize, dimAOuter, dimBOuter / components];
-const programUniforms: ProgramUniform[] =
-[{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}];
+const programUniforms: ProgramUniform[] = [
+{type: DataType.int32, data: dimAOuter}, {type: DataType.int32, data: dimBOuter},
+{type: DataType.int32, data: dimInner}
+];
appendActivationUniformsData(activationAttributes, programUniforms);
programUniforms.push(
...createTensorShapeVariables(outerDims), ...createTensorShapeVariables(aShapeTemp),
30 changes: 15 additions & 15 deletions web/lib/wasm/jsep/webgpu/ops/attention.ts
@@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

-import {tensorDataTypeEnumToString} from '../../../wasm-common';
+import {DataType} from '../../../wasm-common';
import {TensorView} from '../../tensor-view';
import {ComputeContext, GpuDataType, ProgramUniform} from '../types';

@@ -241,9 +241,10 @@ export const computeInPlaceSoftmax = (context: ComputeContext, input: TensorView
WG = Math.ceil(dComp / 8);
}
const elementsPerWG = Math.ceil(d / components / WG);
-const tensorDataType = tensorDataTypeEnumToString(input.dataType) as ProgramUniform['type'];
-const programUniforms: ProgramUniform[] =
-[{type: tensorDataType, data: 1 / d}, {type: 'uint32', data: dComp}, {type: 'uint32', data: elementsPerWG}];
+const programUniforms: ProgramUniform[] = [
+{type: input.dataType, data: 1 / d}, {type: DataType.uint32, data: dComp},
+{type: DataType.uint32, data: elementsPerWG}
+];
const dataType = tensorTypeToWsglStorageType(input.dataType, components);

const getShaderSource = (shaderHelper: ShaderHelper) => {
@@ -336,11 +337,10 @@ const computeAttentionProbs =
y: Math.ceil(parameters.sequenceLength / TILE_SIZE),
z: parameters.batchSize * parameters.numHeads
};
-const tensorDataType = tensorDataTypeEnumToString(q.dataType) as ProgramUniform['type'];
const programUniforms: ProgramUniform[] = [
-{type: 'uint32', data: parameters.sequenceLength}, {type: 'uint32', data: vectorizedHeadSize},
-{type: 'uint32', data: parameters.totalSequenceLength}, {type: 'uint32', data: parameters.kvSequenceLength},
-{type: tensorDataType, data: alpha}
+{type: DataType.uint32, data: parameters.sequenceLength}, {type: DataType.uint32, data: vectorizedHeadSize},
+{type: DataType.uint32, data: parameters.totalSequenceLength},
+{type: DataType.uint32, data: parameters.kvSequenceLength}, {type: q.dataType, data: alpha}
];

const inputs = [q, key];
@@ -430,9 +430,9 @@ const computeVxAttentionScore =
z: params.batchSize * params.numHeads
};
const programUniforms: ProgramUniform[] = [
-{type: 'uint32', data: params.sequenceLength}, {type: 'uint32', data: params.totalSequenceLength},
-{type: 'uint32', data: params.vHeadSize}, {type: 'uint32', data: params.numHeads},
-{type: 'uint32', data: params.vHiddenSize}
+{type: DataType.uint32, data: params.sequenceLength}, {type: DataType.uint32, data: params.totalSequenceLength},
+{type: DataType.uint32, data: params.vHeadSize}, {type: DataType.uint32, data: params.numHeads},
+{type: DataType.uint32, data: params.vHiddenSize}
];

const getShaderSource = (shaderHelper: ShaderHelper) => {
@@ -526,10 +526,10 @@ const prepare = (context: ComputeContext, parameters: AttentionParameters) => {
};
const inputs = [context.inputs[0], context.inputs[1], context.inputs[2]];
const programUniforms: ProgramUniform[] = [
-{type: 'uint32', data: M}, {type: 'uint32', data: K}, {type: 'uint32', data: N},
-{type: 'uint32', data: parameters.numHeads}, {type: 'uint32', data: parameters.headSize},
-{type: 'uint32', data: parameters.hiddenSize},
-{type: 'uint32', data: parameters.hiddenSize + parameters.hiddenSize + parameters.vHiddenSize}
+{type: DataType.uint32, data: M}, {type: DataType.uint32, data: K}, {type: DataType.uint32, data: N},
+{type: DataType.uint32, data: parameters.numHeads}, {type: DataType.uint32, data: parameters.headSize},
+{type: DataType.uint32, data: parameters.hiddenSize},
+{type: DataType.uint32, data: parameters.hiddenSize + parameters.hiddenSize + parameters.vHiddenSize}
];

const getShaderSource = (shaderHelper: ShaderHelper) => {
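The attention.ts hunks show the saving named in the commit message: a uniform whose type should match an input tensor no longer goes through tensorDataTypeEnumToString; the tensor's numeric dataType is passed through as-is. A hedged sketch of that pattern (the helper name is invented; the import paths are the ones attention.ts itself uses):

import {TensorView} from '../../tensor-view';
import {ProgramUniform} from '../types';

// Build the softmax scale uniform directly from the tensor's element type.
// Previously this required
//   tensorDataTypeEnumToString(input.dataType) as ProgramUniform['type'].
const makeScaleUniform = (input: TensorView, d: number): ProgramUniform =>
    ({type: input.dataType, data: 1 / d});
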
5 changes: 3 additions & 2 deletions web/lib/wasm/jsep/webgpu/ops/batch-norm.ts
@@ -3,6 +3,7 @@

import {env} from 'onnxruntime-common';

+import {DataType} from '../../../wasm-common';
import {TensorView} from '../../tensor-view';
import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
@@ -123,11 +124,11 @@ const createBatchNormInferenceProgramInfo =
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
programUniforms: useShapesUniforms ?
[
-{type: 'uint32', data: outputSize},
+{type: DataType.uint32, data: outputSize},
...createTensorShapeVariables(yShape),
] :
[
-{type: 'uint32', data: outputSize},
+{type: DataType.uint32, data: outputSize},
],
}),
};
2 changes: 1 addition & 1 deletion web/lib/wasm/jsep/webgpu/ops/binary-op.ts
@@ -179,7 +179,7 @@ const createBinaryOpProgramInfo =
outputs: [{dims: outputShape, dataType: outputDataType}],
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* component size */)},
programUniforms: [
-{type: 'uint32', data: Math.ceil(ShapeUtil.size(outputShape) / 4)},
+{type: DataType.uint32, data: Math.ceil(ShapeUtil.size(outputShape) / 4)},
...createTensorShapeVariables(a.dims),
...createTensorShapeVariables(b.dims),
...createTensorShapeVariables(outputShape),
5 changes: 3 additions & 2 deletions web/lib/wasm/jsep/webgpu/ops/common.ts
@@ -259,8 +259,9 @@ export const tensorTypeToWsglValueType = (type: DataType, components: 1|2|3|4 =
return typeof mappedType === 'string' ? mappedType : mappedType[1];
};

-export const createTensorShapeVariables = (dims: readonly number[]): ProgramUniform[] =>
-dims.length === 0 ? [] : [{type: 'uint32', data: dims}, {type: 'uint32', data: ShapeUtil.computeStrides(dims)}];
+export const createTensorShapeVariables = (dims: readonly number[]): ProgramUniform[] => dims.length === 0 ?
+[] :
+[{type: DataType.uint32, data: dims}, {type: DataType.uint32, data: ShapeUtil.computeStrides(dims)}];

/**
* A helper function to get maximum vector size for specified data length
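For reference, a short usage sketch of the helper reformatted above (worked out by hand, not taken from the commit): a shape of [2, 3, 4] has row-major strides [12, 4, 1], and both arrays are emitted as uint32 uniforms.

import {createTensorShapeVariables} from './common';  // import site is illustrative

const shapeUniforms = createTensorShapeVariables([2, 3, 4]);
// => [{type: DataType.uint32, data: [2, 3, 4]},
//     {type: DataType.uint32, data: [12, 4, 1]}]  // ShapeUtil.computeStrides([2, 3, 4])
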
5 changes: 3 additions & 2 deletions web/lib/wasm/jsep/webgpu/ops/concat.ts
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

+import {DataType} from '../../../wasm-common';
import {TensorView} from '../../tensor-view';
import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
@@ -95,14 +96,14 @@ const createConcatProgramInfo = (inputs: readonly TensorView[], axis: number): P
let previousSum = 0;
const inputDependencies: ProgramInputTensorInfoDependency[] = [];
const inputRanks = [];
-const programUniforms: ProgramUniform[] = [{type: 'uint32', data: outputSize}];
+const programUniforms: ProgramUniform[] = [{type: DataType.uint32, data: outputSize}];
for (let i = 0; i < inputs.length; ++i) {
previousSum += inputs[i].dims[adjustedAxis];
sizeInConcatAxis[i] = previousSum;
inputRanks.push(inputs[i].dims.length);
inputVars[i] = inputVariable(`input${i}`, dataType, inputRanks[i]);
inputDependencies.push('rank');
-programUniforms.push({type: 'uint32', data: sizeInConcatAxis[i]});
+programUniforms.push({type: DataType.uint32, data: sizeInConcatAxis[i]});
}
for (let i = 0; i < inputs.length; ++i) {
programUniforms.push(...createTensorShapeVariables(inputs[i].dims));
13 changes: 8 additions & 5 deletions web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

+import {DataType} from '../../../wasm-common';
import {TensorView} from '../../tensor-view';
import {ShapeUtil} from '../../util';
import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types';
@@ -28,9 +29,10 @@ export const createGroupedConvProgramInfo =
const outputSize = ShapeUtil.size(outputShape);

const programUniforms: ProgramUniform[] = [
-{type: 'uint32', data: outputSize}, {type: 'uint32', data: attributes.dilations},
-{type: 'uint32', data: [attributes.strides[0], attributes.strides[1]]},
-{type: 'uint32', data: [attributes.pads[0], attributes.pads[1]]}, {type: 'uint32', data: outputChannelsPerGroup}
+{type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: attributes.dilations},
+{type: DataType.uint32, data: [attributes.strides[0], attributes.strides[1]]},
+{type: DataType.uint32, data: [attributes.pads[0], attributes.pads[1]]},
+{type: DataType.uint32, data: outputChannelsPerGroup}
];
appendActivationUniformsData(attributes, programUniforms);
programUniforms.push(
@@ -127,8 +129,9 @@ export const createGroupedConvVectorizeProgramInfo =
const outputShapeInShader = [outputShape[0], outputShape[1], outputShape[2], outputShape[3] / components];

const programUniforms: ProgramUniform[] = [
-{type: 'uint32', data: outputSize}, {type: 'int32', data: [attributes.strides[0], attributes.strides[1]]},
-{type: 'int32', data: [attributes.pads[0], attributes.pads[1]]}
+{type: DataType.uint32, data: outputSize},
+{type: DataType.int32, data: [attributes.strides[0], attributes.strides[1]]},
+{type: DataType.int32, data: [attributes.pads[0], attributes.pads[1]]}
];
appendActivationUniformsData(attributes, programUniforms);
programUniforms.push(
2 changes: 1 addition & 1 deletion web/lib/wasm/jsep/webgpu/ops/cumsum.ts
@@ -54,7 +54,7 @@ const createCumsumProgramInfo =
outputs: [{dims: inputShape, dataType: inputType}],
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
programUniforms: [
-{type: 'uint32', data: outputSize}, {type: 'int32', data: axis},
+{type: DataType.uint32, data: outputSize}, {type: DataType.int32, data: axis},
...createTensorShapeVariables(inputShape), ...createTensorShapeVariables(inputShape)
]

7 changes: 5 additions & 2 deletions web/lib/wasm/jsep/webgpu/ops/einsum.ts
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

+import {DataType} from '../../../wasm-common';
import {TensorView} from '../../tensor-view';
import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
@@ -272,8 +273,10 @@ const createEinsumProgramInfo =
// filter is added to make sure that dimValue is never 0.
const programUniformsInit: ProgramUniform[] =
uniformsSymbols.filter((symbol) => einsumEquation.symbolToInfo.has(symbol))
-.map((symbol) => ({type: 'uint32', data: einsumEquation.symbolToInfo.get(symbol)?.dimValue || 0}));
-programUniformsInit.push({type: 'uint32', data: outputSize});
+.map(
+(symbol) =>
+({type: DataType.uint32, data: einsumEquation.symbolToInfo.get(symbol)?.dimValue || 0}));
+programUniformsInit.push({type: DataType.uint32, data: outputSize});
const programUniforms: ProgramUniform[] =
inputShapes.map((dims, _) => [...createTensorShapeVariables(dims)])
.reduce((acc, inputProgramUniforms) => acc.concat(inputProgramUniforms), programUniformsInit);
2 changes: 1 addition & 1 deletion web/lib/wasm/jsep/webgpu/ops/expand.ts
@@ -85,7 +85,7 @@ const createExpandProgramInfo = (inputs: readonly TensorView[]): ProgramInfo =>
};

const programUniforms: ProgramUniform[] = [
-{type: 'uint32', data: outputSize}, ...createTensorShapeVariables(inputShape),
+{type: DataType.uint32, data: outputSize}, ...createTensorShapeVariables(inputShape),
...createTensorShapeVariables(outputShape)
];
return {
7 changes: 5 additions & 2 deletions web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

+import {DataType} from '../../../wasm-common';
import {MAX_CLIP, MIN_CLIP} from '../../util';
import {ProgramUniform} from '../types';

@@ -36,9 +37,11 @@ export const getActivationSnippet = (attributes: InternalActivationAttributes, v
export const appendActivationUniformsData =
(attributes: InternalActivationAttributes, programUniform: ProgramUniform[]) => {
if (attributes.activation === 'Clip') {
-programUniform.push({type: 'float32', data: attributes.clipMax!}, {type: 'float32', data: attributes.clipMin!});
+programUniform.push(
+{type: DataType.float, data: attributes.clipMax!}, {type: DataType.float, data: attributes.clipMin!});
} else if (attributes.activation === 'HardSigmoid') {
-programUniform.push({type: 'float32', data: attributes.alpha!}, {type: 'float32', data: attributes.beta!});
+programUniform.push(
+{type: DataType.float, data: attributes.alpha!}, {type: DataType.float, data: attributes.beta!});
}
};

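One naming detail in the hunks above: the old 'float32' string maps to DataType.float rather than a DataType.float32 member, since the enum follows the ONNX tensor element type names. A small sketch of the correspondence, assuming the helpers exported by wasm-common (path as imported by fuse-utils.ts):

import {DataType, tensorDataTypeEnumToString} from '../../../wasm-common';

console.log(tensorDataTypeEnumToString(DataType.float));    // 'float32'
console.log(tensorDataTypeEnumToString(DataType.float16));  // 'float16'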