Skip to content

Commit

Permalink
[js/webgpu] Support f16 uniform (#19098)
Browse files Browse the repository at this point in the history
### Description
<!-- Describe your changes. -->



### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
  • Loading branch information
axinging authored and fs-eire committed Mar 15, 2024
1 parent f02accb commit 3abc3db
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 26 deletions.
26 changes: 21 additions & 5 deletions js/web/lib/wasm/jsep/backend-webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -428,13 +428,26 @@ export class WebGpuBackend {
return;
}
// https://www.w3.org/TR/WGSL/#alignof
const baseAlignment = data.length <= 2 ? data.length * 4 : 16;
const sizeOfElement = v.type === 'float16' ? 2 : 4;
let sizeOfVecOrMat;
let baseAlignment;
if (v.type === 'float16') {
baseAlignment = data.length > 4 ? 16 : (data.length > 2 ? 8 : data.length * sizeOfElement);
sizeOfVecOrMat = data.length > 4 ? 16 : sizeOfElement * data.length;
} else {
baseAlignment = data.length <= 2 ? data.length * sizeOfElement : 16;
sizeOfVecOrMat = 16;
}
currentOffset = Math.ceil(currentOffset / baseAlignment) * baseAlignment;
offsets.push(currentOffset);
// When data.length > 4, the uniform variable is of type array<vec4<i32|u32|f32>,N>, where N =
// Math.ceil(data.length / 4) and SizeOf(vec4<i32|u32|f32>) = 16. The total byte length is N *
// SizeOf(vec4<i32|u32|f32>).
currentOffset += data.length > 4 ? Math.ceil(data.length / 4) * 16 : data.length * 4;
// For non-float16 type, when data.length > 4, the uniform variable is of type array<vec4<i32|u32|f32>,N>, where
// N = Math.ceil(data.length / 4) and SizeOf(vec4<i32|u32|f32>) = 16. The total byte length is N *
// SizeOf(vec4<i32|u32|f32>). For float16 type, when data.length > 4, the uniform variable is of type
// array<mat2x4<f16>,N>, where N = Math.ceil(data.length / 8) and SizeOf(mat2x4<f16>) = 16. The total byte
// length is N * SizeOf(mat2x4<f16>).
const elementPerVecOrMat = v.type === 'float16' ? 8 : 4;
currentOffset += data.length > 4 ? Math.ceil(data.length / elementPerVecOrMat) * sizeOfVecOrMat :
data.length * sizeOfElement;
});

// Meet alignment of struct here: https://www.w3.org/TR/WGSL/#alignment-and-size. For simplicity, set
Expand All @@ -449,6 +462,9 @@ export class WebGpuBackend {
new Int32Array(arrayBuffer, offset, data.length).set(data);
} else if (v.type === 'uint32') {
new Uint32Array(arrayBuffer, offset, data.length).set(data);
} else if (v.type === 'float16') {
// TODO: use Float16Array.
new Uint16Array(arrayBuffer, offset, data.length).set(data);
} else {
new Float32Array(arrayBuffer, offset, data.length).set(data);
}
Expand Down
40 changes: 27 additions & 13 deletions js/web/lib/wasm/jsep/webgpu/ops/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -330,18 +330,28 @@ export const sumVector = (name: string, components: number) => {
* @param name - the name of variable.
* @param index - the index of variable element.
* @param length - the length of variable.
* @param type - the type of variable, optional.
*/
export const getElementAt = (name: string, index: number|string, length: number): string => {
if (name.startsWith('uniforms.') && length > 4) {
if (typeof (index) === 'string') {
return `${name}[(${index}) / 4][(${index}) % 4]`;
} else {
return `${name}[${Math.floor(index / 4)}][${index % 4}]`;
}
} else {
return length > 1 ? `${name}[${index}]` : name;
}
};
export const getElementAt =
(name: string, index: number|string, length: number, type?: UniformDataElementType): string => {
if (name.startsWith('uniforms.') && length > 4) {
if (typeof (index) === 'string') {
if (type === 'f16') {
return `${name}[(${index}) / 8][(${index}) % 8 / 4][(${index}) % 8 % 4]`;
} else {
return `${name}[(${index}) / 4][(${index}) % 4]`;
}
} else {
if (type === 'f16') {
return `${name}[${Math.floor(index / 8)}][${Math.floor(index % 8 / 4)}][${index % 8 % 4}]`;
} else {
return `${name}[${Math.floor(index / 4)}][${index % 4}]`;
}
}
} else {
return length > 1 ? `${name}[${index}]` : name;
}
};

/**
* A helper function to get a IndicesHelper for a given input or output.
Expand Down Expand Up @@ -688,7 +698,7 @@ export const internalVariable =
(name: string, type: number, shapeOrRank: number|readonly number[], components: 1|2|3|4 = 1): IndicesHelper =>
createIndicesHelper(name, type, shapeOrRank, 'internal', components);

export type UniformDataElementType = 'u32'|'f32'|'i32';
export type UniformDataElementType = 'u32'|'f16'|'f32'|'i32';
export type UniformsArrayType = Array<{name: string; type: UniformDataElementType; length?: number}>;

/**
Expand Down Expand Up @@ -861,7 +871,11 @@ class ShaderHelperImpl implements ShaderHelper {
const uniformSnippets: string[] = [];
for (const {name, type, length} of this.uniforms) {
if (length && length > 4) {
uniformSnippets.push(`${name}:array<vec4<${type}>, ${Math.ceil(length / 4)}>`);
if (type === 'f16') {
uniformSnippets.push(`@align(16) ${name}:array<mat2x4<${type}>, ${Math.ceil(length / 8)}>`);
} else {
uniformSnippets.push(`${name}:array<vec4<${type}>, ${Math.ceil(length / 4)}>`);
}
} else {
const typeTemp = length == null || length === 1 ? type : `vec${length}<${type}>`;
uniformSnippets.push(`${name}:${typeTemp}`);
Expand Down
4 changes: 2 additions & 2 deletions js/web/lib/wasm/jsep/webgpu/ops/pad.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ const validateInputs = (inputs: readonly TensorView[]): void => {
if (!inputs || inputs.length < 1) {
throw new Error('Too few inputs');
}
if (inputs[0].dataType !== DataType.float) {
throw new Error('Input type must be float.');
if (inputs[0].dataType !== DataType.float && inputs[0].dataType !== DataType.float16) {
throw new Error('Input type must be float or float16.');
}

if (inputs.length >= 2) {
Expand Down
2 changes: 1 addition & 1 deletion js/web/lib/wasm/jsep/webgpu/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ export interface TensorInfo {
}

export interface ProgramUniform {
type: 'int32'|'float32'|'uint32';
type: 'int32'|'float16'|'float32'|'uint32';
data: number|readonly number[];
}

Expand Down
10 changes: 5 additions & 5 deletions onnxruntime/core/providers/js/operators/pad.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
2,
10,
kJsExecutionProvider,
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
(*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()),
Pad);

ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Expand All @@ -24,7 +24,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
12,
kJsExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
.TypeConstraint("T", JsepSupportedFloatTypes())
.InputMemoryType(OrtMemTypeCPU, 1)
.InputMemoryType(OrtMemTypeCPU, 2)
.InputMemoryType(OrtMemTypeCPU, 3),
Expand All @@ -37,7 +37,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
17,
kJsExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
.TypeConstraint("T", JsepSupportedFloatTypes())
.InputMemoryType(OrtMemTypeCPU, 1)
.InputMemoryType(OrtMemTypeCPU, 2)
.InputMemoryType(OrtMemTypeCPU, 3),
Expand All @@ -50,7 +50,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
18,
kJsExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
.TypeConstraint("T", JsepSupportedFloatTypes())
.InputMemoryType(OrtMemTypeCPU, 1)
.InputMemoryType(OrtMemTypeCPU, 2)
.InputMemoryType(OrtMemTypeCPU, 3),
Expand All @@ -62,7 +62,7 @@ ONNX_OPERATOR_KERNEL_EX(
19,
kJsExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
.TypeConstraint("T", JsepSupportedFloatTypes())
.InputMemoryType(OrtMemTypeCPU, 1)
.InputMemoryType(OrtMemTypeCPU, 2)
.InputMemoryType(OrtMemTypeCPU, 3),
Expand Down

0 comments on commit 3abc3db

Please sign in to comment.