[js/webgpu] Support f16 uniform #19098

Merged: 5 commits, Jan 26, 2024
26 changes: 21 additions & 5 deletions js/web/lib/wasm/jsep/backend-webgpu.ts
@@ -440,13 +440,26 @@ export class WebGpuBackend {
     return;
   }
   // https://www.w3.org/TR/WGSL/#alignof
-  const baseAlignment = data.length <= 2 ? data.length * 4 : 16;
+  const sizeOfElement = v.type === 'float16' ? 2 : 4;
+  let sizeOfVecOrMat;
+  let baseAlignment;
+  if (v.type === 'float16') {
+    baseAlignment = data.length > 4 ? 16 : (data.length > 2 ? 8 : data.length * sizeOfElement);
+    sizeOfVecOrMat = data.length > 4 ? 16 : sizeOfElement * data.length;
+  } else {
+    baseAlignment = data.length <= 2 ? data.length * sizeOfElement : 16;
+    sizeOfVecOrMat = 16;
+  }
   currentOffset = Math.ceil(currentOffset / baseAlignment) * baseAlignment;
   offsets.push(currentOffset);
-  // When data.length > 4, the uniform variable is of type array<vec4<i32|u32|f32>,N>, where N =
-  // Math.ceil(data.length / 4) and SizeOf(vec4<i32|u32|f32>) = 16. The total byte length is N *
-  // SizeOf(vec4<i32|u32|f32>).
-  currentOffset += data.length > 4 ? Math.ceil(data.length / 4) * 16 : data.length * 4;
+  // For non-float16 type, when data.length > 4, the uniform variable is of type array<vec4<i32|u32|f32>,N>, where
+  // N = Math.ceil(data.length / 4) and SizeOf(vec4<i32|u32|f32>) = 16. The total byte length is N *
+  // SizeOf(vec4<i32|u32|f32>). For float16 type, when data.length > 4, the uniform variable is of type
+  // array<mat2x4<f16>,N>, where N = Math.ceil(data.length / 8) and SizeOf(mat2x4<f16>) = 16. The total byte
+  // length is N * SizeOf(mat2x4<f16>).
+  const elementPerVecOrMat = v.type === 'float16' ? 8 : 4;
+  currentOffset += data.length > 4 ? Math.ceil(data.length / elementPerVecOrMat) * sizeOfVecOrMat :
+      data.length * sizeOfElement;
 });

 // Meet alignment of struct here: https://www.w3.org/TR/WGSL/#alignment-and-size. For simplicity, set
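
To make the offset arithmetic in this hunk easier to follow, here is a minimal standalone sketch that applies the same alignment and size rules to a list of uniforms. The names `UniformSketch` and `computeUniformOffsets` are hypothetical and do not appear in the PR. Skipping empty arrays is an assumption about the surrounding code, and the final struct-level padding mentioned in the truncated comment above is deliberately left out.

```typescript
// Hypothetical types and names, for illustration only.
type UniformType = 'int32'|'float16'|'float32'|'uint32';
interface UniformSketch {
  type: UniformType;
  data: number|readonly number[];
}

// Returns the byte offset of each uniform and the unpadded end offset.
function computeUniformOffsets(uniforms: readonly UniformSketch[]): {offsets: number[]; endOffset: number} {
  let currentOffset = 0;
  const offsets: number[] = [];
  for (const v of uniforms) {
    const data = typeof v.data === 'number' ? [v.data] : v.data;
    if (data.length === 0) {
      continue;  // assumption: empty uniforms occupy no space
    }
    const sizeOfElement = v.type === 'float16' ? 2 : 4;
    let sizeOfVecOrMat: number;
    let baseAlignment: number;
    if (v.type === 'float16') {
      // f16: 2, vec2<f16>: 4, vec3/vec4<f16>: 8, array<mat2x4<f16>, N>: 16
      baseAlignment = data.length > 4 ? 16 : (data.length > 2 ? 8 : data.length * sizeOfElement);
      sizeOfVecOrMat = data.length > 4 ? 16 : sizeOfElement * data.length;
    } else {
      // 32-bit scalar: 4, vec2: 8, vec3/vec4 and array<vec4<...>, N>: 16
      baseAlignment = data.length <= 2 ? data.length * sizeOfElement : 16;
      sizeOfVecOrMat = 16;
    }
    // Round the running offset up to the alignment of this member.
    currentOffset = Math.ceil(currentOffset / baseAlignment) * baseAlignment;
    offsets.push(currentOffset);
    const elementPerVecOrMat = v.type === 'float16' ? 8 : 4;
    currentOffset += data.length > 4 ? Math.ceil(data.length / elementPerVecOrMat) * sizeOfVecOrMat :
                                       data.length * sizeOfElement;
  }
  return {offsets, endOffset: currentOffset};
}

const layout = computeUniformOffsets([
  {type: 'float16', data: 0x3c00},                       // f16 at offset 0, 2 bytes
  {type: 'float16', data: [1, 2, 3]},                    // vec3<f16>, aligned to 8
  {type: 'float32', data: [1, 2, 3, 4, 5]},              // array<vec4<f32>, 2>, aligned to 16
  {type: 'float16', data: [1, 2, 3, 4, 5, 6, 7, 8, 9]},  // array<mat2x4<f16>, 2>, aligned to 16
]);
console.log(layout.offsets, layout.endOffset);  // [0, 8, 16, 48] 80
```

Nine f16 values occupy two `mat2x4<f16>` slots (32 bytes), which is half of what the same count would take as `array<vec4<f32>, 3>`.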
@@ -461,6 +474,9 @@ export class WebGpuBackend {
   new Int32Array(arrayBuffer, offset, data.length).set(data);
 } else if (v.type === 'uint32') {
   new Uint32Array(arrayBuffer, offset, data.length).set(data);
+} else if (v.type === 'float16') {
+  // TODO: use Float16Array.
+  new Uint16Array(arrayBuffer, offset, data.length).set(data);

Contributor:

Do we need to preprocess the CPU data before passing it to Uint16Array when that data is treated as float32? For example, in Pad versions 2 through 10 (pad2~pad10), constant_value is a float rather than an input of type T.
cc @jzm-intel

Contributor:

Sorry, my example does not apply to this case. If constant_value is a float, the code takes the float32 path directly, not the float16 path. But the concern still stands.

Contributor Author:

Yes, the user should make sure that "data" holds float16 bits. For example, if "data" is 1.0, the corresponding float16 bit pattern is 0x3c00. I am trying to add a warning message to tell users that this should be used in unit tests only: https://github.com/microsoft/onnxruntime/pull/19307/files#diff-006fc84d3997f96a29b8033bd2075d6a0a9509211bd5812a6b934fc74fedfd9dR520

 } else {
   new Float32Array(arrayBuffer, offset, data.length).set(data);
 }
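
The review thread in this hunk establishes that a 'float16' uniform must already contain float16 bit patterns (1.0 corresponds to 0x3c00), because the values are copied into a Uint16Array unchanged. The sketch below shows one way a caller might produce those bits from ordinary JavaScript numbers. It is not part of this PR: `float32ToFloat16Bits` and `toFloat16UniformData` are hypothetical helper names, and rounding to nearest with ties to even is an assumption about the desired behavior.

```typescript
// Hypothetical helper: convert a JS number to IEEE 754 binary16 bits
// (round to nearest, ties to even). Not part of the PR.
function float32ToFloat16Bits(value: number): number {
  const f32 = new Float32Array(1);
  const u32 = new Uint32Array(f32.buffer);
  f32[0] = value;  // first round the double to float32
  const x = u32[0];
  const sign = (x >>> 16) & 0x8000;
  const exponent = (x >>> 23) & 0xff;
  const mantissa = x & 0x7fffff;

  if (exponent === 0xff) {
    // Infinity or NaN (keep a quiet-NaN bit set for NaN).
    return sign | 0x7c00 | (mantissa !== 0 ? 0x200 : 0);
  }
  const e = exponent - 127 + 15;  // re-bias the exponent for binary16
  if (e >= 0x1f) {
    return sign | 0x7c00;  // too large: overflow to infinity
  }
  if (e <= 0) {
    if (e < -10) {
      return sign;  // too small even for a subnormal: signed zero
    }
    // Subnormal result: shift the 24-bit significand (with implicit 1) right.
    const significand = mantissa | 0x800000;
    const shift = 14 - e;
    let sub = significand >>> shift;
    const rem = significand & ((1 << shift) - 1);
    const halfway = 1 << (shift - 1);
    if (rem > halfway || (rem === halfway && (sub & 1) === 1)) {
      sub++;
    }
    return sign | sub;
  }
  // Normal result: drop the low 13 mantissa bits with rounding.
  let half = (e << 10) | (mantissa >>> 13);
  const rem = mantissa & 0x1fff;
  if (rem > 0x1000 || (rem === 0x1000 && (half & 1) === 1)) {
    half++;  // a carry into the exponent field is the correct rounding result
  }
  return sign | half;
}

// Hypothetical helper: build the `data` array for a 'float16' uniform.
function toFloat16UniformData(values: readonly number[]): number[] {
  return values.map(float32ToFloat16Bits);
}

const padValue = {type: 'float16' as const, data: toFloat16UniformData([1.0, -2.0, 0.5])};
console.log(padValue.data.map(b => '0x' + b.toString(16)));  // ['0x3c00', '0xc000', '0x3800']
```

Doing the conversion on the caller's side keeps the backend copy a plain memcpy-style write, which matches the TODO about switching to Float16Array later.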
40 changes: 27 additions & 13 deletions js/web/lib/wasm/jsep/webgpu/ops/common.ts
@@ -330,18 +330,28 @@ export const sumVector = (name: string, components: number) => {
  * @param name - the name of variable.
  * @param index - the index of variable element.
  * @param length - the length of variable.
+ * @param type - the type of variable, optional.
  */
-export const getElementAt = (name: string, index: number|string, length: number): string => {
-  if (name.startsWith('uniforms.') && length > 4) {
-    if (typeof (index) === 'string') {
-      return `${name}[(${index}) / 4][(${index}) % 4]`;
-    } else {
-      return `${name}[${Math.floor(index / 4)}][${index % 4}]`;
-    }
-  } else {
-    return length > 1 ? `${name}[${index}]` : name;
-  }
-};
+export const getElementAt =
+    (name: string, index: number|string, length: number, type?: UniformDataElementType): string => {
+      if (name.startsWith('uniforms.') && length > 4) {
+        if (typeof (index) === 'string') {
+          if (type === 'f16') {
+            return `${name}[(${index}) / 8][(${index}) % 8 / 4][(${index}) % 8 % 4]`;
+          } else {
+            return `${name}[(${index}) / 4][(${index}) % 4]`;
+          }
+        } else {
+          if (type === 'f16') {
+            return `${name}[${Math.floor(index / 8)}][${Math.floor(index % 8 / 4)}][${index % 8 % 4}]`;
+          } else {
+            return `${name}[${Math.floor(index / 4)}][${index % 4}]`;
+          }
+        }
+      } else {
+        return length > 1 ? `${name}[${index}]` : name;
+      }
+    };

 /**
  * A helper function to get a IndicesHelper for a given input or output.
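
The three-level indexing added to `getElementAt` follows from the `array<mat2x4<f16>, N>` layout: each matrix holds 8 values as 2 columns of 4, so flat index i maps to matrix i / 8, column (i % 8) / 4, and component i % 4. The sketch below is a self-contained copy of the new function from this hunk with a few sample calls. The uniform name `uniforms.pads` is made up for illustration.

```typescript
type UniformDataElementType = 'u32'|'f16'|'f32'|'i32';

// Standalone copy of the new getElementAt from the diff, for experimentation.
const getElementAt =
    (name: string, index: number|string, length: number, type?: UniformDataElementType): string => {
      if (name.startsWith('uniforms.') && length > 4) {
        if (typeof index === 'string') {
          // Index is a WGSL expression: emit the arithmetic into the shader source.
          return type === 'f16' ? `${name}[(${index}) / 8][(${index}) % 8 / 4][(${index}) % 8 % 4]` :
                                  `${name}[(${index}) / 4][(${index}) % 4]`;
        }
        // Index is known at shader-generation time: fold the arithmetic now.
        return type === 'f16' ?
            `${name}[${Math.floor(index / 8)}][${Math.floor(index % 8 / 4)}][${index % 8 % 4}]` :
            `${name}[${Math.floor(index / 4)}][${index % 4}]`;
      }
      return length > 1 ? `${name}[${index}]` : name;
    };

// 'uniforms.pads' is a hypothetical uniform name.
console.log(getElementAt('uniforms.pads', 13, 16, 'f16'));  // uniforms.pads[1][1][1]
console.log(getElementAt('uniforms.pads', 13, 16, 'f32'));  // uniforms.pads[3][1]
console.log(getElementAt('uniforms.pads', 'i', 16, 'f16')); // uniforms.pads[(i) / 8][(i) % 8 / 4][(i) % 8 % 4]
console.log(getElementAt('uniforms.pads', 2, 3, 'f16'));    // uniforms.pads[2]
```

Uniforms with four or fewer elements are declared as a scalar or a plain vector, so they keep the single-level index regardless of element type.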
@@ -688,7 +698,7 @@ export const internalVariable =
     (name: string, type: number, shapeOrRank: number|readonly number[], components: 1|2|3|4 = 1): IndicesHelper =>
         createIndicesHelper(name, type, shapeOrRank, 'internal', components);

-export type UniformDataElementType = 'u32'|'f32'|'i32';
+export type UniformDataElementType = 'u32'|'f16'|'f32'|'i32';
 export type UniformsArrayType = Array<{name: string; type: UniformDataElementType; length?: number}>;

 /**
@@ -861,7 +871,11 @@ class ShaderHelperImpl implements ShaderHelper {
 const uniformSnippets: string[] = [];
 for (const {name, type, length} of this.uniforms) {
   if (length && length > 4) {
-    uniformSnippets.push(`${name}:array<vec4<${type}>, ${Math.ceil(length / 4)}>`);
+    if (type === 'f16') {
+      uniformSnippets.push(`@align(16) ${name}:array<mat2x4<${type}>, ${Math.ceil(length / 8)}>`);
+    } else {
+      uniformSnippets.push(`${name}:array<vec4<${type}>, ${Math.ceil(length / 4)}>`);
+    }
   } else {
     const typeTemp = length == null || length === 1 ? type : `vec${length}<${type}>`;
     uniformSnippets.push(`${name}:${typeTemp}`);
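
As a rough illustration of what this loop emits, the sketch below extracts the member-declaration logic into a standalone function and prints the WGSL struct members for a few uniforms. `uniformMemberDeclaration` and the uniform names are hypothetical; only the string templates are taken from the diff.

```typescript
type UniformDataElementType = 'u32'|'f16'|'f32'|'i32';
interface UniformEntry {
  name: string;
  type: UniformDataElementType;
  length?: number;
}

// Hypothetical helper: one WGSL struct member declaration per uniform.
function uniformMemberDeclaration({name, type, length}: UniformEntry): string {
  if (length && length > 4) {
    if (type === 'f16') {
      // 8 f16 values per mat2x4<f16>; @align(16) matches the 16-byte offset used on the CPU side.
      return `@align(16) ${name}:array<mat2x4<${type}>, ${Math.ceil(length / 8)}>`;
    }
    // 4 values per vec4 for the 32-bit types.
    return `${name}:array<vec4<${type}>, ${Math.ceil(length / 4)}>`;
  }
  const typeTemp = length == null || length === 1 ? type : `vec${length}<${type}>`;
  return `${name}:${typeTemp}`;
}

// Uniform names below are made up for illustration.
const members = [
  {name: 'pads', type: 'f16' as const, length: 9},
  {name: 'strides', type: 'u32' as const, length: 6},
  {name: 'shape', type: 'i32' as const, length: 3},
  {name: 'constant_value', type: 'f16' as const},
].map(uniformMemberDeclaration);
console.log(members.join(',\n'));
// @align(16) pads:array<mat2x4<f16>, 2>,
// strides:array<vec4<u32>, 2>,
// shape:vec3<i32>,
// constant_value:f16
```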
4 changes: 2 additions & 2 deletions js/web/lib/wasm/jsep/webgpu/ops/pad.ts
@@ -19,8 +19,8 @@ const validateInputs = (inputs: readonly TensorView[]): void => {
   if (!inputs || inputs.length < 1) {
     throw new Error('Too few inputs');
   }
-  if (inputs[0].dataType !== DataType.float) {
-    throw new Error('Input type must be float.');
+  if (inputs[0].dataType !== DataType.float && inputs[0].dataType !== DataType.float16) {
+    throw new Error('Input type must be float or float16.');
   }

   if (inputs.length >= 2) {
2 changes: 1 addition & 1 deletion js/web/lib/wasm/jsep/webgpu/types.ts
@@ -24,7 +24,7 @@ export interface TensorInfo {
 }

 export interface ProgramUniform {
-  type: 'int32'|'float32'|'uint32';
+  type: 'int32'|'float16'|'float32'|'uint32';
   data: number|readonly number[];
 }

10 changes: 5 additions & 5 deletions onnxruntime/core/providers/js/operators/pad.cc
@@ -14,7 +14,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
     2,
     10,
     kJsExecutionProvider,
-    (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
+    (*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()),
     Pad);

 ONNX_OPERATOR_VERSIONED_KERNEL_EX(
@@ -24,7 +24,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
     12,
     kJsExecutionProvider,
     (*KernelDefBuilder::Create())
-        .TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
+        .TypeConstraint("T", JsepSupportedFloatTypes())

Contributor:

For versions <= 10, we can use JsepSupportedFloatTypes(). For later versions (> 10), should we use JsepSupportedDataTypes()?

         .InputMemoryType(OrtMemTypeCPU, 1)
         .InputMemoryType(OrtMemTypeCPU, 2)
         .InputMemoryType(OrtMemTypeCPU, 3),
@@ -37,7 +37,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
     17,
     kJsExecutionProvider,
     (*KernelDefBuilder::Create())
-        .TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
+        .TypeConstraint("T", JsepSupportedFloatTypes())
         .InputMemoryType(OrtMemTypeCPU, 1)
         .InputMemoryType(OrtMemTypeCPU, 2)
         .InputMemoryType(OrtMemTypeCPU, 3),
@@ -50,7 +50,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
     18,
     kJsExecutionProvider,
     (*KernelDefBuilder::Create())
-        .TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
+        .TypeConstraint("T", JsepSupportedFloatTypes())
         .InputMemoryType(OrtMemTypeCPU, 1)
         .InputMemoryType(OrtMemTypeCPU, 2)
         .InputMemoryType(OrtMemTypeCPU, 3),
@@ -62,7 +62,7 @@ ONNX_OPERATOR_KERNEL_EX(
     19,
     kJsExecutionProvider,
     (*KernelDefBuilder::Create())
-        .TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
+        .TypeConstraint("T", JsepSupportedFloatTypes())
         .InputMemoryType(OrtMemTypeCPU, 1)
         .InputMemoryType(OrtMemTypeCPU, 2)
         .InputMemoryType(OrtMemTypeCPU, 3),