From 5fa3e166d5f95c5d4205b617b97bf46dd0a98b29 Mon Sep 17 00:00:00 2001 From: Xu Xing Date: Thu, 11 Jan 2024 16:43:34 +0800 Subject: [PATCH 1/4] [js/webgpu] Support f16 uniform --- js/web/lib/wasm/jsep/backend-webgpu.ts | 17 ++++++++++++----- js/web/lib/wasm/jsep/webgpu/ops/pad.ts | 4 ++-- js/web/lib/wasm/jsep/webgpu/types.ts | 2 +- onnxruntime/core/providers/js/operators/pad.cc | 10 +++++----- 4 files changed, 20 insertions(+), 13 deletions(-) diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index 0148f32cdd91b..857166eb497f9 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -335,13 +335,17 @@ export class WebGpuBackend { return; } // https://www.w3.org/TR/WGSL/#alignof - const baseAlignment = data.length <= 2 ? data.length * 4 : 16; + let sizeOfElement = v.type === 'float16' ? 2 : 4; + let sizeOfVec = v.type === 'float16' ? 8 : 16; + const baseAlignment = data.length <= 2 ? data.length * sizeOfElement : sizeOfVec; currentOffset = Math.ceil(currentOffset / baseAlignment) * baseAlignment; offsets.push(currentOffset); - // When data.length > 4, the uniform variable is of type array,N>, where N = - // Math.ceil(data.length / 4) and SizeOf(vec4) = 16. The total byte length is N * - // SizeOf(vec4). - currentOffset += data.length > 4 ? Math.ceil(data.length / 4) * 16 : data.length * 4; + // For non-float16 type, when data.length > 4, the uniform variable is of type array,N>, where + // N = Math.ceil(data.length / 4) and SizeOf(vec4) = 16. The total byte length is N * + // SizeOf(vec4). For float16 type, when data.length > 4, the uniform variable is of type + // array,N>, where N = Math.ceil(data.length / 4) and SizeOf(vec4) = 8. The total byte length is + // N * SizeOf(vec4). + currentOffset += data.length > 4 ? Math.ceil(data.length / 4) * sizeOfVec : data.length * sizeOfElement; }); // Meet alignment of struct here: https://www.w3.org/TR/WGSL/#alignment-and-size. For simplicity, set @@ -356,6 +360,9 @@ export class WebGpuBackend { new Int32Array(arrayBuffer, offset, data.length).set(data); } else if (v.type === 'uint32') { new Uint32Array(arrayBuffer, offset, data.length).set(data); + } else if (v.type === 'float16') { + // TODO: use Float16Array. + new Uint16Array(arrayBuffer, offset, data.length).set(data); } else { new Float32Array(arrayBuffer, offset, data.length).set(data); } diff --git a/js/web/lib/wasm/jsep/webgpu/ops/pad.ts b/js/web/lib/wasm/jsep/webgpu/ops/pad.ts index eca3fa7d944bb..c65b741e1105a 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/pad.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/pad.ts @@ -19,8 +19,8 @@ const validateInputs = (inputs: readonly TensorView[]): void => { if (!inputs || inputs.length < 1) { throw new Error('Too few inputs'); } - if (inputs[0].dataType !== DataType.float) { - throw new Error('Input type must be float.'); + if (inputs[0].dataType !== DataType.float && inputs[0].dataType !== DataType.float16) { + throw new Error('Input type must be float or float16.'); } if (inputs.length >= 2) { diff --git a/js/web/lib/wasm/jsep/webgpu/types.ts b/js/web/lib/wasm/jsep/webgpu/types.ts index 23fa33a9bba8f..93de034592f91 100644 --- a/js/web/lib/wasm/jsep/webgpu/types.ts +++ b/js/web/lib/wasm/jsep/webgpu/types.ts @@ -25,7 +25,7 @@ export interface TensorInfo { export interface ProgramUniform { - type: 'int32'|'float32'|'uint32'; + type: 'int32'|'float16'|'float32'|'uint32'; data: number|readonly number[]; } diff --git a/onnxruntime/core/providers/js/operators/pad.cc b/onnxruntime/core/providers/js/operators/pad.cc index 24ba85cbf6e0d..83fee35481aa6 100644 --- a/onnxruntime/core/providers/js/operators/pad.cc +++ b/onnxruntime/core/providers/js/operators/pad.cc @@ -14,7 +14,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( 2, 10, kJsExecutionProvider, - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), + (*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()), Pad); ONNX_OPERATOR_VERSIONED_KERNEL_EX( @@ -24,7 +24,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( 12, kJsExecutionProvider, (*KernelDefBuilder::Create()) - .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .TypeConstraint("T", JsepSupportedFloatTypes()) .InputMemoryType(OrtMemTypeCPU, 1) .InputMemoryType(OrtMemTypeCPU, 2) .InputMemoryType(OrtMemTypeCPU, 3), @@ -37,7 +37,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( 17, kJsExecutionProvider, (*KernelDefBuilder::Create()) - .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .TypeConstraint("T", JsepSupportedFloatTypes()) .InputMemoryType(OrtMemTypeCPU, 1) .InputMemoryType(OrtMemTypeCPU, 2) .InputMemoryType(OrtMemTypeCPU, 3), @@ -50,7 +50,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( 18, kJsExecutionProvider, (*KernelDefBuilder::Create()) - .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .TypeConstraint("T", JsepSupportedFloatTypes()) .InputMemoryType(OrtMemTypeCPU, 1) .InputMemoryType(OrtMemTypeCPU, 2) .InputMemoryType(OrtMemTypeCPU, 3), @@ -62,7 +62,7 @@ ONNX_OPERATOR_KERNEL_EX( 19, kJsExecutionProvider, (*KernelDefBuilder::Create()) - .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .TypeConstraint("T", JsepSupportedFloatTypes()) .InputMemoryType(OrtMemTypeCPU, 1) .InputMemoryType(OrtMemTypeCPU, 2) .InputMemoryType(OrtMemTypeCPU, 3), From 75b83443339f6e9f904ae861251ae5decca0904f Mon Sep 17 00:00:00 2001 From: Xu Xing Date: Mon, 22 Jan 2024 10:08:46 +0800 Subject: [PATCH 2/4] Use mat2x4 for f16 --- js/web/lib/wasm/jsep/backend-webgpu.ts | 21 +++++++++---- js/web/lib/wasm/jsep/webgpu/ops/common.ts | 37 +++++++++++++++-------- 2 files changed, 39 insertions(+), 19 deletions(-) diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index 857166eb497f9..2951410589733 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -335,17 +335,26 @@ export class WebGpuBackend { return; } // https://www.w3.org/TR/WGSL/#alignof - let sizeOfElement = v.type === 'float16' ? 2 : 4; - let sizeOfVec = v.type === 'float16' ? 8 : 16; - const baseAlignment = data.length <= 2 ? data.length * sizeOfElement : sizeOfVec; + const sizeOfElement = v.type === 'float16' ? 2 : 4; + let sizeOfVec; + let baseAlignment; + if (v.type === 'float16') { + baseAlignment = data.length > 4 ? 16 : (data.length > 2 ? 8 : data.length * sizeOfElement); + sizeOfVec = data.length > 4 ? 16 : sizeOfElement * data.length; + } else { + baseAlignment = data.length <= 2 ? data.length * sizeOfElement : 16; + sizeOfVec = 16; + } currentOffset = Math.ceil(currentOffset / baseAlignment) * baseAlignment; offsets.push(currentOffset); // For non-float16 type, when data.length > 4, the uniform variable is of type array,N>, where // N = Math.ceil(data.length / 4) and SizeOf(vec4) = 16. The total byte length is N * // SizeOf(vec4). For float16 type, when data.length > 4, the uniform variable is of type - // array,N>, where N = Math.ceil(data.length / 4) and SizeOf(vec4) = 8. The total byte length is - // N * SizeOf(vec4). - currentOffset += data.length > 4 ? Math.ceil(data.length / 4) * sizeOfVec : data.length * sizeOfElement; + // array,N>, where N = Math.ceil(data.length / 8) and SizeOf(mat2x4) = 16. The total byte + // length is N * SizeOf(mat2x4). + const elementPerVecOrMat = v.type === 'float16' ? 8 : 4; + currentOffset += + data.length > 4 ? Math.ceil(data.length / elementPerVecOrMat) * sizeOfVec : data.length * sizeOfElement; }); // Meet alignment of struct here: https://www.w3.org/TR/WGSL/#alignment-and-size. For simplicity, set diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts index bc3265be955f0..e21cb5c33f781 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts @@ -330,18 +330,25 @@ export const sumVector = (name: string, components: number) => { * @param name - the name of variable. * @param index - the index of variable element. * @param length - the length of variable. + * @param type - the type of variable, optional. */ -export const getElementAt = (name: string, index: number|string, length: number): string => { - if (name.startsWith('uniforms.') && length > 4) { - if (typeof (index) === 'string') { - return `${name}[(${index}) / 4][(${index}) % 4]`; - } else { - return `${name}[${Math.floor(index / 4)}][${index % 4}]`; - } - } else { - return length > 1 ? `${name}[${index}]` : name; - } -}; +export const getElementAt = + (name: string, index: number|string, length: number, type?: UniformDataElementType): string => { + if (name.startsWith('uniforms.') && length > 4) { + if (typeof (index) === 'string') { + return `${name}[(${index}) / 4][(${index}) % 4]`; + + } else { + if (type === 'f16') { + return `${name}[${Math.floor(index / 8)}][${Math.floor(index % 8 / 4)}][${index % 8 % 4}]`; + } else { + return `${name}[${Math.floor(index / 4)}][${index % 4}]`; + } + } + } else { + return length > 1 ? `${name}[${index}]` : name; + } + }; /** * A helper function to get a IndicesHelper for a given input or output. @@ -688,7 +695,7 @@ export const internalVariable = (name: string, type: number, shapeOrRank: number|readonly number[], components: 1|2|3|4 = 1): IndicesHelper => createIndicesHelper(name, type, shapeOrRank, 'internal', components); -export type UniformDataElementType = 'u32'|'f32'|'i32'; +export type UniformDataElementType = 'u32'|'f16'|'f32'|'i32'; export type UniformsArrayType = Array<{name: string; type: UniformDataElementType; length?: number}>; /** @@ -861,7 +868,11 @@ class ShaderHelperImpl implements ShaderHelper { const uniformSnippets: string[] = []; for (const {name, type, length} of this.uniforms) { if (length && length > 4) { - uniformSnippets.push(`${name}:array, ${Math.ceil(length / 4)}>`); + if (type === 'f16') { + uniformSnippets.push(`@align(16) ${name}:array, ${Math.ceil(length / 8)}>`); + } else { + uniformSnippets.push(`${name}:array, ${Math.ceil(length / 4)}>`); + } } else { const typeTemp = length == null || length === 1 ? type : `vec${length}<${type}>`; uniformSnippets.push(`${name}:${typeTemp}`); From 565bc828ed13256c97dc43e50d7702075aa6f9f1 Mon Sep 17 00:00:00 2001 From: Xu Xing Date: Mon, 22 Jan 2024 10:13:12 +0800 Subject: [PATCH 3/4] Nit --- js/web/lib/wasm/jsep/backend-webgpu.ts | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index 2951410589733..6f80c310f7310 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -336,14 +336,14 @@ export class WebGpuBackend { } // https://www.w3.org/TR/WGSL/#alignof const sizeOfElement = v.type === 'float16' ? 2 : 4; - let sizeOfVec; + let sizeOfVecOrMat; let baseAlignment; if (v.type === 'float16') { baseAlignment = data.length > 4 ? 16 : (data.length > 2 ? 8 : data.length * sizeOfElement); - sizeOfVec = data.length > 4 ? 16 : sizeOfElement * data.length; + sizeOfVecOrMat = data.length > 4 ? 16 : sizeOfElement * data.length; } else { baseAlignment = data.length <= 2 ? data.length * sizeOfElement : 16; - sizeOfVec = 16; + sizeOfVecOrMat = 16; } currentOffset = Math.ceil(currentOffset / baseAlignment) * baseAlignment; offsets.push(currentOffset); @@ -353,8 +353,8 @@ export class WebGpuBackend { // array,N>, where N = Math.ceil(data.length / 8) and SizeOf(mat2x4) = 16. The total byte // length is N * SizeOf(mat2x4). const elementPerVecOrMat = v.type === 'float16' ? 8 : 4; - currentOffset += - data.length > 4 ? Math.ceil(data.length / elementPerVecOrMat) * sizeOfVec : data.length * sizeOfElement; + currentOffset += data.length > 4 ? Math.ceil(data.length / elementPerVecOrMat) * sizeOfVecOrMat : + data.length * sizeOfElement; }); // Meet alignment of struct here: https://www.w3.org/TR/WGSL/#alignment-and-size. For simplicity, set From b39f5058cd9ad4faffcf56b9f9a455ff823007a9 Mon Sep 17 00:00:00 2001 From: Xu Xing Date: Thu, 25 Jan 2024 13:10:23 +0800 Subject: [PATCH 4/4] Enable string mode --- js/web/lib/wasm/jsep/webgpu/ops/common.ts | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts index e21cb5c33f781..643744108c0f4 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts @@ -336,8 +336,11 @@ export const getElementAt = (name: string, index: number|string, length: number, type?: UniformDataElementType): string => { if (name.startsWith('uniforms.') && length > 4) { if (typeof (index) === 'string') { - return `${name}[(${index}) / 4][(${index}) % 4]`; - + if (type === 'f16') { + return `${name}[(${index}) / 8][(${index}) % 8 / 4][(${index}) % 8 % 4]`; + } else { + return `${name}[(${index}) / 4][(${index}) % 4]`; + } } else { if (type === 'f16') { return `${name}[${Math.floor(index / 8)}][${Math.floor(index % 8 / 4)}][${index % 8 % 4}]`;