Skip to content

Commit

Permalink
adds support for Uint8ClampedArray (#21985)
Browse files Browse the repository at this point in the history
Fixes #21753
  • Loading branch information
prathikr authored Sep 12, 2024
1 parent d8e64bb commit d495e6c
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 3 deletions.
19 changes: 16 additions & 3 deletions js/common/lib/tensor-impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,16 @@ export class Tensor implements TensorInterface {
*/
constructor(
type: TensorType,
data: TensorDataType | readonly string[] | readonly number[] | readonly boolean[],
data: TensorDataType | Uint8ClampedArray | readonly string[] | readonly number[] | readonly boolean[],
dims?: readonly number[],
);
/**
* Construct a new CPU tensor object from the given data and dims. Type is inferred from data.
*/
constructor(data: TensorDataType | readonly string[] | readonly boolean[], dims?: readonly number[]);
constructor(
data: TensorDataType | Uint8ClampedArray | readonly string[] | readonly boolean[],
dims?: readonly number[],
);
/**
* Construct a new tensor object from the pinned CPU data with the given type and dims.
*
Expand Down Expand Up @@ -90,12 +93,13 @@ export class Tensor implements TensorInterface {
arg0:
| TensorType
| TensorDataType
| Uint8ClampedArray
| readonly string[]
| readonly boolean[]
| CpuPinnedConstructorParameters
| TextureConstructorParameters
| GpuBufferConstructorParameters,
arg1?: TensorDataType | readonly number[] | readonly string[] | readonly boolean[],
arg1?: TensorDataType | Uint8ClampedArray | readonly number[] | readonly string[] | readonly boolean[],
arg2?: readonly number[],
) {
// perform one-time check for BigInt/Float16Array support
Expand Down Expand Up @@ -216,6 +220,12 @@ export class Tensor implements TensorInterface {
}
} else if (arg1 instanceof typedArrayConstructor) {
data = arg1;
} else if (arg1 instanceof Uint8ClampedArray) {
if (arg0 === 'uint8') {
data = Uint8Array.from(arg1);
} else {
throw new TypeError(`A Uint8ClampedArray tensor's data must be type of uint8`);
}
} else {
throw new TypeError(`A ${type} tensor's data must be type of ${typedArrayConstructor}`);
}
Expand Down Expand Up @@ -243,6 +253,9 @@ export class Tensor implements TensorInterface {
} else {
throw new TypeError(`Invalid element type of data array: ${firstElementType}.`);
}
} else if (arg0 instanceof Uint8ClampedArray) {
type = 'uint8';
data = Uint8Array.from(arg0);
} else {
// get tensor type from TypedArray
const mappedType = NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP.get(
Expand Down
17 changes: 17 additions & 0 deletions js/common/lib/tensor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,15 @@ export interface TensorConstructor extends TensorFactory {
dims?: readonly number[],
): TypedTensor<'bool'>;

/**
* Construct a new uint8 tensor object from a Uint8ClampedArray, data and dims.
*
* @param type - Specify the element type.
* @param data - Specify the CPU tensor data.
* @param dims - Specify the dimension of the tensor. If omitted, a 1-D tensor is assumed.
*/
new (type: 'uint8', data: Uint8ClampedArray, dims?: readonly number[]): TypedTensor<'uint8'>;

/**
* Construct a new 64-bit integer typed tensor object from the given type, data and dims.
*
Expand Down Expand Up @@ -245,6 +254,14 @@ export interface TensorConstructor extends TensorFactory {
*/
new (data: Uint8Array, dims?: readonly number[]): TypedTensor<'uint8'>;

/**
* Construct a new uint8 tensor object from the given data and dims.
*
* @param data - Specify the CPU tensor data.
* @param dims - Specify the dimension of the tensor. If omitted, a 1-D tensor is assumed.
*/
new (data: Uint8ClampedArray, dims?: readonly number[]): TypedTensor<'uint8'>;

/**
* Construct a new uint16 tensor object from the given data and dims.
*
Expand Down
19 changes: 19 additions & 0 deletions js/common/test/type-tests/tensor/create-new-uint8.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import * as ort from 'onnxruntime-common';

// construct from Uint8Array
//
// {type-tests}|pass
new ort.Tensor(new Uint8Array(1));

// construct from Uint8ClampedArray
//
// {type-tests}|pass
new ort.Tensor(new Uint8ClampedArray(1));

// construct from type (bool), data (Uint8ClampedArray) and shape (number array)
//
// {type-tests}|fail|1|2769
new ort.Tensor('bool', new Uint8ClampedArray([255, 256]), [2]);
8 changes: 8 additions & 0 deletions js/common/test/unit-tests/tensor/constructor-type.ts
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,14 @@ describe('Tensor Constructor Tests - check types', () => {
assert.equal(tensor.type, 'bool', "tensor.type should be 'bool'");
});

it('[uint8] new Tensor(uint8ClampedArray, dims): uint8 tensor can be constructed from Uint8ClampedArray', () => {
const uint8ClampedArray = new Uint8ClampedArray(2);
uint8ClampedArray[0] = 0;
uint8ClampedArray[1] = 256; // clamped
const tensor = new Tensor('uint8', uint8ClampedArray, [2]);
assert.equal(tensor.type, 'uint8', "tensor.type should be 'uint8'");
});

it("[bool] new Tensor('bool', uint8Array, dims): tensor can be constructed from Uint8Array", () => {
const tensor = new Tensor('bool', new Uint8Array([1, 0, 1, 0]), [2, 2]);
assert.equal(tensor.type, 'bool', "tensor.type should be 'bool'");
Expand Down

0 comments on commit d495e6c

Please sign in to comment.