Skip to content

Commit

Permalink
[js/webgpu] Use DataType as uniform cpu type
Browse files Browse the repository at this point in the history
  • Loading branch information
axinging committed Jan 26, 2024
1 parent 358650d commit bd035be
Show file tree
Hide file tree
Showing 36 changed files with 152 additions and 115 deletions.
14 changes: 7 additions & 7 deletions js/web/lib/wasm/jsep/backend-webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import {Env, Tensor, TRACE, TRACE_FUNC_BEGIN, TRACE_FUNC_END} from 'onnxruntime-common';

import {tensorDataTypeEnumToString} from '../wasm-common';
import {DataType, tensorDataTypeEnumToString} from '../wasm-common';

import {configureLogger, LOG_DEBUG} from './log';
import {createView, TensorView} from './tensor-view';
Expand Down Expand Up @@ -428,10 +428,10 @@ export class WebGpuBackend {
return;
}
// https://www.w3.org/TR/WGSL/#alignof
const sizeOfElement = v.type === 'float16' ? 2 : 4;
const sizeOfElement = v.type === DataType.float16 ? 2 : 4;
let sizeOfVecOrMat;
let baseAlignment;
if (v.type === 'float16') {
if (v.type === DataType.float16) {
baseAlignment = data.length > 4 ? 16 : (data.length > 2 ? 8 : data.length * sizeOfElement);
sizeOfVecOrMat = data.length > 4 ? 16 : sizeOfElement * data.length;
} else {
Expand All @@ -445,7 +445,7 @@ export class WebGpuBackend {
// SizeOf(vec4<i32|u32|f32>). For float16 type, when data.length > 4, the uniform variable is of type
// array<mat2x4<f16>,N>, where N = Math.ceil(data.length / 8) and SizeOf(mat2x4<f16>) = 16. The total byte
// length is N * SizeOf(mat2x4<f16>).
const elementPerVecOrMat = v.type === 'float16' ? 8 : 4;
const elementPerVecOrMat = v.type === DataType.float16 ? 8 : 4;
currentOffset += data.length > 4 ? Math.ceil(data.length / elementPerVecOrMat) * sizeOfVecOrMat :
data.length * sizeOfElement;
});
Expand All @@ -458,11 +458,11 @@ export class WebGpuBackend {
programUniforms.forEach((v, i) => {
const offset = offsets[i];
const data = typeof v.data === 'number' ? [v.data] : v.data;
if (v.type === 'int32') {
if (v.type === DataType.int32) {
new Int32Array(arrayBuffer, offset, data.length).set(data);
} else if (v.type === 'uint32') {
} else if (v.type === DataType.uint32) {
new Uint32Array(arrayBuffer, offset, data.length).set(data);
} else if (v.type === 'float16') {
} else if (v.type === DataType.float16) {
// TODO: use Float16Array.
new Uint16Array(arrayBuffer, offset, data.length).set(data);
} else {
Expand Down
9 changes: 5 additions & 4 deletions js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
//
// modified to fit the needs of the project

import {DataType} from '../../../../wasm-common';
import {LOG_DEBUG} from '../../../log';
import {TensorView} from '../../../tensor-view';
import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types';
Expand Down Expand Up @@ -189,13 +190,13 @@ export const createConv2DMatMulProgramInfo =
const elementsSize = isVec4 ? [innerElementSize, 4, 4] : [1, 1, 1];

const programUniforms: ProgramUniform[] = [
{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner},
{type: 'int32', data: [attributes.pads[0], attributes.pads[1]]}, {type: 'int32', data: attributes.strides},
{type: 'int32', data: attributes.dilations}
{type: DataType.int32, data: dimAOuter}, {type: DataType.int32, data: dimBOuter},
{type: DataType.int32, data: dimInner}, {type: DataType.int32, data: [attributes.pads[0], attributes.pads[1]]},
{type: DataType.int32, data: attributes.strides}, {type: DataType.int32, data: attributes.dilations}
];
if (attributes.activation === 'Clip') {
programUniforms.push(
{type: 'float32', data: attributes.clipMax!}, {type: 'float32', data: attributes.clipMin!});
{type: DataType.float, data: attributes.clipMax!}, {type: DataType.float, data: attributes.clipMin!});
}
programUniforms.push(
...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(inputs[1].dims));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
//
// modified to fit the needs of the project

import {DataType} from '../../../../wasm-common';
import {LOG_DEBUG} from '../../../log';
import {TensorView} from '../../../tensor-view';
import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types';
Expand Down Expand Up @@ -197,13 +198,14 @@ export const createConv2DTransposeMatMulProgramInfo =
];

const programUniforms: ProgramUniform[] = [
{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner},
{type: 'int32', data: attributes.strides}, {type: 'int32', data: attributes.dilations},
{type: 'int32', data: filterDims}, {type: 'int32', data: pads}
{type: DataType.int32, data: dimAOuter}, {type: DataType.int32, data: dimBOuter},
{type: DataType.int32, data: dimInner}, {type: DataType.int32, data: attributes.strides},
{type: DataType.int32, data: attributes.dilations}, {type: DataType.int32, data: filterDims},
{type: DataType.int32, data: pads}
];
if (attributes.activation === 'Clip') {
programUniforms.push(
{type: 'float32', data: attributes.clipMax!}, {type: 'float32', data: attributes.clipMin!});
{type: DataType.float, data: attributes.clipMax!}, {type: DataType.float, data: attributes.clipMin!});
}
programUniforms.push(
...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(inputs[1].dims));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

// sampled from [@tensorflow/tfjs] tfjs-backend-webgpu/src/conv_backprop_webgpu.ts

import {DataType} from '../../../../wasm-common';
import {LOG_DEBUG} from '../../../log';
import {TensorView} from '../../../tensor-view';
import {ShapeUtil} from '../../../util';
Expand Down Expand Up @@ -264,9 +265,10 @@ export const createConvTranspose2DProgramInfo =
const outputChannelsPerGroup = wShape[1];

const programUniforms: ProgramUniform[] = [
{type: 'int32', data: outputSize}, {type: 'uint32', data: strides}, {type: 'uint32', data: filterDims},
{type: 'uint32', data: dilations}, {type: 'uint32', data: effectiveFilterDims}, {type: 'int32', data: pads},
{type: 'uint32', data: inputChannelsPerGroup}, {type: 'uint32', data: outputChannelsPerGroup},
{type: DataType.int32, data: outputSize}, {type: DataType.uint32, data: strides},
{type: DataType.uint32, data: filterDims}, {type: DataType.uint32, data: dilations},
{type: DataType.uint32, data: effectiveFilterDims}, {type: DataType.int32, data: pads},
{type: DataType.uint32, data: inputChannelsPerGroup}, {type: DataType.uint32, data: outputChannelsPerGroup},
...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(inputs[1].dims)
];
if (hasBias) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
//
// modified to fit the needs of the project

import {DataType} from '../../../../wasm-common';
import {TensorView} from '../../../tensor-view';
import {ShapeUtil} from '../../../util';
import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types';
Expand Down Expand Up @@ -408,7 +409,7 @@ const matMulReadWriteFnSource =
${
hasBias ?
`value = value + ${isChannelsLast ? 'bias[colIn]' : `${typeSnippet(component, dataType)}(bias[row])`};` :
'' }
''}
${applyActivation}
${outputVariable.setByIndices('vec3<u32>(coords)', 'value')}
}
Expand Down Expand Up @@ -447,12 +448,14 @@ export const createMatmulProgramInfo =
const bShapeTemp = [...outerDimsB, dimInner, dimBOuter / components];
const bShapeOrRank = bShapeTemp.length;
const outputShapeTemp = [batchSize, dimAOuter, dimBOuter / components];
const programUniforms: ProgramUniform[] =
[{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}];
const programUniforms: ProgramUniform[] = [
{type: DataType.int32, data: dimAOuter}, {type: DataType.int32, data: dimBOuter},
{type: DataType.int32, data: dimInner}
];
if (activationAttributes.activation === 'Clip') {
programUniforms.push(
{type: 'float32', data: activationAttributes.clipMax!},
{type: 'float32', data: activationAttributes.clipMin!});
{type: DataType.float, data: activationAttributes.clipMax!},
{type: DataType.float, data: activationAttributes.clipMin!});
}
programUniforms.push(
...createTensorShapeVariables(outerDims), ...createTensorShapeVariables(aShapeTemp),
Expand Down
30 changes: 15 additions & 15 deletions js/web/lib/wasm/jsep/webgpu/ops/attention.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {tensorDataTypeEnumToString} from '../../../wasm-common';
import {DataType} from '../../../wasm-common';
import {TensorView} from '../../tensor-view';
import {ComputeContext, GpuDataType, ProgramUniform} from '../types';

Expand Down Expand Up @@ -241,9 +241,10 @@ export const computeInPlaceSoftmax = (context: ComputeContext, input: TensorView
WG = Math.ceil(dComp / 8);
}
const elementsPerWG = Math.ceil(d / components / WG);
const tensorDataType = tensorDataTypeEnumToString(input.dataType) as ProgramUniform['type'];
const programUniforms: ProgramUniform[] =
[{type: tensorDataType, data: 1 / d}, {type: 'uint32', data: dComp}, {type: 'uint32', data: elementsPerWG}];
const programUniforms: ProgramUniform[] = [
{type: input.dataType, data: 1 / d}, {type: DataType.uint32, data: dComp},
{type: DataType.uint32, data: elementsPerWG}
];
const dataType = tensorTypeToWsglStorageType(input.dataType, components);

const getShaderSource = (shaderHelper: ShaderHelper) => {
Expand Down Expand Up @@ -336,11 +337,10 @@ const computeAttentionProbs =
y: Math.ceil(parameters.sequenceLength / TILE_SIZE),
z: parameters.batchSize * parameters.numHeads
};
const tensorDataType = tensorDataTypeEnumToString(q.dataType) as ProgramUniform['type'];
const programUniforms: ProgramUniform[] = [
{type: 'uint32', data: parameters.sequenceLength}, {type: 'uint32', data: vectorizedHeadSize},
{type: 'uint32', data: parameters.totalSequenceLength}, {type: 'uint32', data: parameters.kvSequenceLength},
{type: tensorDataType, data: alpha}
{type: DataType.uint32, data: parameters.sequenceLength}, {type: DataType.uint32, data: vectorizedHeadSize},
{type: DataType.uint32, data: parameters.totalSequenceLength},
{type: DataType.uint32, data: parameters.kvSequenceLength}, {type: q.dataType, data: alpha}
];

const inputs = [q, key];
Expand Down Expand Up @@ -430,9 +430,9 @@ const computeVxAttentionScore =
z: params.batchSize * params.numHeads
};
const programUniforms: ProgramUniform[] = [
{type: 'uint32', data: params.sequenceLength}, {type: 'uint32', data: params.totalSequenceLength},
{type: 'uint32', data: params.vHeadSize}, {type: 'uint32', data: params.numHeads},
{type: 'uint32', data: params.vHiddenSize}
{type: DataType.uint32, data: params.sequenceLength}, {type: DataType.uint32, data: params.totalSequenceLength},
{type: DataType.uint32, data: params.vHeadSize}, {type: DataType.uint32, data: params.numHeads},
{type: DataType.uint32, data: params.vHiddenSize}
];

const getShaderSource = (shaderHelper: ShaderHelper) => {
Expand Down Expand Up @@ -526,10 +526,10 @@ const prepare = (context: ComputeContext, parameters: AttentionParameters) => {
};
const inputs = [context.inputs[0], context.inputs[1], context.inputs[2]];
const programUniforms: ProgramUniform[] = [
{type: 'uint32', data: M}, {type: 'uint32', data: K}, {type: 'uint32', data: N},
{type: 'uint32', data: parameters.numHeads}, {type: 'uint32', data: parameters.headSize},
{type: 'uint32', data: parameters.hiddenSize},
{type: 'uint32', data: parameters.hiddenSize + parameters.hiddenSize + parameters.vHiddenSize}
{type: DataType.uint32, data: M}, {type: DataType.uint32, data: K}, {type: DataType.uint32, data: N},
{type: DataType.uint32, data: parameters.numHeads}, {type: DataType.uint32, data: parameters.headSize},
{type: DataType.uint32, data: parameters.hiddenSize},
{type: DataType.uint32, data: parameters.hiddenSize + parameters.hiddenSize + parameters.vHiddenSize}
];

const getShaderSource = (shaderHelper: ShaderHelper) => {
Expand Down
5 changes: 3 additions & 2 deletions js/web/lib/wasm/jsep/webgpu/ops/batch-norm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import {env} from 'onnxruntime-common';

import {DataType} from '../../../wasm-common';
import {TensorView} from '../../tensor-view';
import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
Expand Down Expand Up @@ -123,11 +124,11 @@ const createBatchNormInferenceProgramInfo =
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
programUniforms: useShapesUniforms ?
[
{type: 'uint32', data: outputSize},
{type: DataType.uint32, data: outputSize},
...createTensorShapeVariables(yShape),
] :
[
{type: 'uint32', data: outputSize},
{type: DataType.uint32, data: outputSize},
],
}),
};
Expand Down
4 changes: 2 additions & 2 deletions js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts
Original file line number Diff line number Diff line change
Expand Up @@ -185,13 +185,13 @@ const createBinaryOpProgramInfo =
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* component size */)},
programUniforms: useShapesUniforms ?
[
{type: 'uint32', data: Math.ceil(ShapeUtil.size(outputShape) / 4)},
{type: DataType.uint32, data: Math.ceil(ShapeUtil.size(outputShape) / 4)},
...createTensorShapeVariables(a.dims),
...createTensorShapeVariables(b.dims),
...createTensorShapeVariables(outputShape),
] :
[
{type: 'uint32', data: Math.ceil(ShapeUtil.size(outputShape) / 4)},
{type: DataType.uint32, data: Math.ceil(ShapeUtil.size(outputShape) / 4)},
],
}),
};
Expand Down
5 changes: 3 additions & 2 deletions js/web/lib/wasm/jsep/webgpu/ops/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -259,8 +259,9 @@ export const tensorTypeToWsglValueType = (type: DataType, components: 1|2|3|4 =
return typeof mappedType === 'string' ? mappedType : mappedType[1];
};

export const createTensorShapeVariables = (dims: readonly number[]): ProgramUniform[] =>
dims.length === 0 ? [] : [{type: 'uint32', data: dims}, {type: 'uint32', data: ShapeUtil.computeStrides(dims)}];
/**
 * Builds the uniform entries that describe a tensor shape.
 *
 * @param dims - the tensor dimensions.
 * @returns an empty list when `dims` has no entries; otherwise two `uint32` uniforms, the first carrying
 *     `dims` itself and the second carrying the result of `ShapeUtil.computeStrides(dims)`.
 */
export const createTensorShapeVariables = (dims: readonly number[]): ProgramUniform[] => {
  // Nothing to describe when there are no dimensions.
  if (dims.length === 0) {
    return [];
  }
  const strides = ShapeUtil.computeStrides(dims);
  return [{type: DataType.uint32, data: dims}, {type: DataType.uint32, data: strides}];
};

/**
* A helper function to get maximum vector size for specified data length
Expand Down
5 changes: 3 additions & 2 deletions js/web/lib/wasm/jsep/webgpu/ops/concat.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {DataType} from '../../../wasm-common';
import {TensorView} from '../../tensor-view';
import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
Expand Down Expand Up @@ -96,15 +97,15 @@ const createConcatProgramInfo = (inputs: readonly TensorView[], axis: number): P
const inputDependencies: ProgramInputTensorInfoDependency[] = [];
const inputShapeOrRanks = [];
const enableInputShapesUniforms = [];
const programUniforms: ProgramUniform[] = [{type: 'uint32', data: outputSize}];
const programUniforms: ProgramUniform[] = [{type: DataType.uint32, data: outputSize}];
for (let i = 0; i < inputs.length; ++i) {
previousSum += inputs[i].dims[adjustedAxis];
sizeInConcatAxis[i] = previousSum;
enableInputShapesUniforms.push(enableShapesUniforms(inputs[i].dims.length));
inputShapeOrRanks.push(enableInputShapesUniforms[i] ? inputs[i].dims.length : inputs[i].dims);
inputVars[i] = inputVariable(`input${i}`, dataType, inputShapeOrRanks[i]);
inputDependencies.push(enableInputShapesUniforms[i] ? 'rank' : 'dims');
programUniforms.push({type: 'uint32', data: sizeInConcatAxis[i]});
programUniforms.push({type: DataType.uint32, data: sizeInConcatAxis[i]});
}
for (let i = 0; i < inputs.length; ++i) {
if (enableInputShapesUniforms[i]) {
Expand Down
14 changes: 8 additions & 6 deletions js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {DataType} from '../../../wasm-common';
import {TensorView} from '../../tensor-view';
import {ShapeUtil} from '../../util';
import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types';
Expand Down Expand Up @@ -28,13 +29,14 @@ export const createGroupedConvProgramInfo =
const outputSize = ShapeUtil.size(outputShape);

const programUniforms: ProgramUniform[] = [
{type: 'uint32', data: outputSize}, {type: 'uint32', data: attributes.dilations},
{type: 'uint32', data: [attributes.strides[0], attributes.strides[1]]},
{type: 'uint32', data: [attributes.pads[0], attributes.pads[1]]}, {type: 'uint32', data: outputChannelsPerGroup}
{type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: attributes.dilations},
{type: DataType.uint32, data: [attributes.strides[0], attributes.strides[1]]},
{type: DataType.uint32, data: [attributes.pads[0], attributes.pads[1]]},
{type: DataType.uint32, data: outputChannelsPerGroup}
];
if (attributes.activation === 'Clip') {
programUniforms.push(
{type: 'float32', data: attributes.clipMax!}, {type: 'float32', data: attributes.clipMin!});
{type: DataType.float, data: attributes.clipMax!}, {type: DataType.float, data: attributes.clipMin!});
}
programUniforms.push(
...createTensorShapeVariables(xShape), ...createTensorShapeVariables(wShape),
Expand Down Expand Up @@ -132,8 +134,8 @@ export const createGroupedConvVectorizeProgramInfo =
const outputShapeInShader = [outputShape[0], outputShape[1], outputShape[2], outputShape[3] / components];

const programUniforms: ProgramUniform[] = [
{type: 'uint32', data: outputSize}, {type: 'int32', data: attributes.strides},
{type: 'int32', data: attributes.pads}, ...createTensorShapeVariables(xShape),
{type: DataType.uint32, data: outputSize}, {type: DataType.int32, data: attributes.strides},
{type: DataType.int32, data: attributes.pads}, ...createTensorShapeVariables(xShape),
...createTensorShapeVariables(wShape), ...createTensorShapeVariables(outputShapeInShader)
];
const xNumber = (outputNumber - 1) * attributes.strides[1] + wShape[1];
Expand Down
2 changes: 1 addition & 1 deletion js/web/lib/wasm/jsep/webgpu/ops/cumsum.ts
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ const createCumsumProgramInfo =
outputs: [{dims: inputShape, dataType: inputType}],
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
programUniforms: [
{type: 'uint32', data: outputSize}, {type: 'int32', data: axis},
{type: DataType.uint32, data: outputSize}, {type: DataType.int32, data: axis},
...createTensorShapeVariables(inputShape), ...createTensorShapeVariables(inputShape)
]

Expand Down
Loading

0 comments on commit bd035be

Please sign in to comment.