diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index 78147ffc09ab7..39f8c2a6d0db3 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -593,7 +593,6 @@ export class WebGpuBackend { } else if (v.type === DataType.uint32) { new Uint32Array(arrayBuffer, offset, data.length).set(data); } else if (v.type === DataType.float16) { - // TODO: use Float16Array. new Uint16Array(arrayBuffer, offset, data.length).set(data); } else if (v.type === DataType.float) { new Float32Array(arrayBuffer, offset, data.length).set(data); diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts index 3f326881079f0..b31fbc6255c41 100644 --- a/js/web/lib/wasm/jsep/init.ts +++ b/js/web/lib/wasm/jsep/init.ts @@ -3,8 +3,9 @@ import { Env } from 'onnxruntime-common'; +import { calculateTensorSizeInBytes, DataType } from '../wasm-common'; + import type { OrtWasmModule } from '../wasm-types'; -import { DataType, calculateTensorSizeInBytes } from '../wasm-common'; import { WebGpuBackend } from './backend-webgpu'; import { LOG_DEBUG } from './log'; @@ -22,6 +23,14 @@ class TensorViewImpl implements TensorView { public readonly dims: readonly number[], ) {} + getUint16Array(): Uint16Array { + if (this.dataType !== DataType.float16 && this.dataType !== DataType.uint16) { + throw new Error('Invalid data type'); + } + const elementCount = ShapeUtil.size(this.dims); + return elementCount === 0 ? 
new Uint16Array() : new Uint16Array(this.module.HEAP8.buffer, this.data, elementCount); + } + getFloat32Array(): Float32Array { if (this.dataType !== DataType.float) { throw new Error('Invalid data type'); diff --git a/js/web/lib/wasm/jsep/tensor-view.ts b/js/web/lib/wasm/jsep/tensor-view.ts index defc418c29264..5f1fdfa4534cd 100644 --- a/js/web/lib/wasm/jsep/tensor-view.ts +++ b/js/web/lib/wasm/jsep/tensor-view.ts @@ -28,6 +28,11 @@ export interface TensorView { readonly dataType: number; readonly dims: readonly number[]; + /** + * get a Uint16Array data view of the tensor data (raw 16-bit values for float16 or uint16 tensors). tensor data must be on CPU. + */ + getUint16Array(): Uint16Array; + /** * get a Float32Array data view of the tensor data. tensor data must be on CPU. */ diff --git a/js/web/lib/wasm/jsep/webgpu/ops/pad.ts b/js/web/lib/wasm/jsep/webgpu/ops/pad.ts index 4951bd0192baf..eb08cc7b713f2 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/pad.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/pad.ts @@ -165,8 +165,10 @@ const createPadProgramInfo = (inputs: readonly TensorView[], attributes: PadAttr { type: DataType.uint32, data: outputSize }, { type: DataType.int32, data: attributes.pads }, ]; + + const isValueFromInput = inputs.length >= 3 && inputs[2].data; if (attributes.mode === 0) { - programUniforms.push({ type: inputs[0].dataType, data: attributes.value }); + programUniforms.push({ type: isValueFromInput ? inputs[2].dataType : DataType.float, data: attributes.value }); } programUniforms.push(...createTensorShapeVariables(inputs[0].dims, outputShape)); @@ -182,7 +184,7 @@ const createPadProgramInfo = (inputs: readonly TensorView[], attributes: PadAttr { name: 'pads', type: 'i32', length: attributes.pads.length }, ]; if (attributes.mode === 0) { - uniforms.push({ name: 'constant_value', type: dataType as UniformDataElementType }); + uniforms.push({ name: 'constant_value', type: (isValueFromInput ? 
dataType : 'f32') as UniformDataElementType }); } return ` @@ -200,7 +202,7 @@ const createPadProgramInfo = (inputs: readonly TensorView[], attributes: PadAttr return { name: 'Pad', - shaderCache: { hint: `${attributes.mode}`, inputDependencies }, + shaderCache: { hint: `${attributes.mode}${isValueFromInput}`, inputDependencies }, getRunData: () => ({ outputs: [{ dims: outputShape, dataType: inputs[0].dataType }], dispatchGroup: { x: Math.ceil(ShapeUtil.size(outputShape) / 64 /* workgroup size */) }, @@ -213,7 +215,12 @@ const createPadProgramInfo = (inputs: readonly TensorView[], attributes: PadAttr const createPadAttributesFromInputs = (inputs: readonly TensorView[], attributes: PadAttributes): PadAttributes => { if (inputs.length > 1) { const bigInt64Pads = inputs[1].getBigInt64Array(); - const value = inputs.length >= 3 && inputs[2].data ? inputs[2].getFloat32Array()[0] : 0.0; + const value = + inputs.length >= 3 && inputs[2].data + ? inputs[2].dataType === DataType.float16 + ? inputs[2].getUint16Array()[0] + : inputs[2].getFloat32Array()[0] + : 0.0; const inputRank = inputs[0].dims.length; const updatePads = new Int32Array(2 * inputRank).fill(0); diff --git a/js/web/test/data/ops/pad_f16.jsonc b/js/web/test/data/ops/pad_f16.jsonc index 44c61b8a95382..9e497cd2215ca 100644 --- a/js/web/test/data/ops/pad_f16.jsonc +++ b/js/web/test/data/ops/pad_f16.jsonc @@ -1,6 +1,6 @@ [ { - "name": "constant 2D float16", + "name": "constant 2D float16 v10", "operator": "Pad", "opset": { "domain": "", "version": 10 }, "attributes": [ @@ -33,7 +33,7 @@ ] }, { - "name": "constant 2D float16", + "name": "constant 2D float16 v19", "operator": "Pad", "opset": { "domain": "", "version": 19 }, "attributes": [{ "name": "mode", "data": "constant", "type": "string" }], diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index edbaeb6f4095c..829e55a625102 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -1385,6 +1385,7 
@@ "reduce-min.jsonc", "relu.jsonc", "gelu.jsonc", + "pad_f16.jsonc", "pad.jsonc", "pad-big.jsonc", "pow.jsonc",