diff --git a/js/common/lib/tensor-impl.ts b/js/common/lib/tensor-impl.ts index 4e0ef821dde57..342f5e3a467eb 100644 --- a/js/common/lib/tensor-impl.ts +++ b/js/common/lib/tensor-impl.ts @@ -51,13 +51,16 @@ export class Tensor implements TensorInterface { */ constructor( type: TensorType, - data: TensorDataType | readonly string[] | readonly number[] | readonly boolean[], + data: TensorDataType | Uint8ClampedArray | readonly string[] | readonly number[] | readonly boolean[], dims?: readonly number[], ); /** * Construct a new CPU tensor object from the given data and dims. Type is inferred from data. */ - constructor(data: TensorDataType | readonly string[] | readonly boolean[], dims?: readonly number[]); + constructor( + data: TensorDataType | Uint8ClampedArray | readonly string[] | readonly boolean[], + dims?: readonly number[], + ); /** * Construct a new tensor object from the pinned CPU data with the given type and dims. * @@ -90,12 +93,13 @@ export class Tensor implements TensorInterface { arg0: | TensorType | TensorDataType + | Uint8ClampedArray | readonly string[] | readonly boolean[] | CpuPinnedConstructorParameters | TextureConstructorParameters | GpuBufferConstructorParameters, - arg1?: TensorDataType | readonly number[] | readonly string[] | readonly boolean[], + arg1?: TensorDataType | Uint8ClampedArray | readonly number[] | readonly string[] | readonly boolean[], arg2?: readonly number[], ) { // perform one-time check for BigInt/Float16Array support @@ -216,6 +220,12 @@ export class Tensor implements TensorInterface { } } else if (arg1 instanceof typedArrayConstructor) { data = arg1; + } else if (arg1 instanceof Uint8ClampedArray) { + if (arg0 === 'uint8') { + data = Uint8Array.from(arg1); + } else { + throw new TypeError(`A Uint8ClampedArray tensor's data must be type of uint8`); + } } else { throw new TypeError(`A ${type} tensor's data must be type of ${typedArrayConstructor}`); } @@ -243,6 +253,9 @@ export class Tensor implements TensorInterface { } else { throw new TypeError(`Invalid element type of data array: ${firstElementType}.`); } + } else if (arg0 instanceof Uint8ClampedArray) { + type = 'uint8'; + data = Uint8Array.from(arg0); } else { // get tensor type from TypedArray const mappedType = NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP.get( diff --git a/js/common/lib/tensor.ts b/js/common/lib/tensor.ts index 70396bbe1e9a3..8a1197994393b 100644 --- a/js/common/lib/tensor.ts +++ b/js/common/lib/tensor.ts @@ -192,6 +192,15 @@ export interface TensorConstructor extends TensorFactory { dims?: readonly number[], ): TypedTensor<'bool'>; + /** + * Construct a new uint8 tensor object from a Uint8ClampedArray, data and dims. + * + * @param type - Specify the element type. + * @param data - Specify the CPU tensor data. + * @param dims - Specify the dimension of the tensor. If omitted, a 1-D tensor is assumed. + */ + new (type: 'uint8', data: Uint8ClampedArray, dims?: readonly number[]): TypedTensor<'uint8'>; + /** * Construct a new 64-bit integer typed tensor object from the given type, data and dims. * @@ -245,6 +254,14 @@ export interface TensorConstructor extends TensorFactory { */ new (data: Uint8Array, dims?: readonly number[]): TypedTensor<'uint8'>; + /** + * Construct a new uint8 tensor object from the given data and dims. + * + * @param data - Specify the CPU tensor data. + * @param dims - Specify the dimension of the tensor. If omitted, a 1-D tensor is assumed. + */ + new (data: Uint8ClampedArray, dims?: readonly number[]): TypedTensor<'uint8'>; + /** * Construct a new uint16 tensor object from the given data and dims. * diff --git a/js/common/test/type-tests/tensor/create-new-uint8.ts b/js/common/test/type-tests/tensor/create-new-uint8.ts new file mode 100644 index 0000000000000..46438f97ca2e7 --- /dev/null +++ b/js/common/test/type-tests/tensor/create-new-uint8.ts @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import * as ort from 'onnxruntime-common'; + +// construct from Uint8Array +// +// {type-tests}|pass +new ort.Tensor(new Uint8Array(1)); + +// construct from Uint8ClampedArray +// +// {type-tests}|pass +new ort.Tensor(new Uint8ClampedArray(1)); + +// construct from type (bool), data (Uint8ClampedArray) and shape (number array) +// +// {type-tests}|fail|1|2769 +new ort.Tensor('bool', new Uint8ClampedArray([255, 256]), [2]); diff --git a/js/common/test/unit-tests/tensor/constructor-type.ts b/js/common/test/unit-tests/tensor/constructor-type.ts index def711684d7f5..02390800e8611 100644 --- a/js/common/test/unit-tests/tensor/constructor-type.ts +++ b/js/common/test/unit-tests/tensor/constructor-type.ts @@ -82,6 +82,14 @@ describe('Tensor Constructor Tests - check types', () => { assert.equal(tensor.type, 'bool', "tensor.type should be 'bool'"); }); + it('[uint8] new Tensor(uint8ClampedArray, dims): uint8 tensor can be constructed from Uint8ClampedArray', () => { + const uint8ClampedArray = new Uint8ClampedArray(2); + uint8ClampedArray[0] = 0; + uint8ClampedArray[1] = 256; // clamped + const tensor = new Tensor('uint8', uint8ClampedArray, [2]); + assert.equal(tensor.type, 'uint8', "tensor.type should be 'uint8'"); + }); + it("[bool] new Tensor('bool', uint8Array, dims): tensor can be constructed from Uint8Array", () => { const tensor = new Tensor('bool', new Uint8Array([1, 0, 1, 0]), [2, 2]); assert.equal(tensor.type, 'bool', "tensor.type should be 'bool'");