Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adds support for Uint8ClampedArray #21985

Merged
merged 22 commits into from
Sep 12, 2024
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions js/common/lib/tensor-impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,16 @@ export class Tensor implements TensorInterface {
*/
constructor(
type: TensorType,
data: TensorDataType | readonly string[] | readonly number[] | readonly boolean[],
data: TensorDataType | Uint8ClampedArray | readonly string[] | readonly number[] | readonly boolean[],
dims?: readonly number[],
);
/**
* Construct a new CPU tensor object from the given data and dims. Type is inferred from data.
*/
constructor(data: TensorDataType | readonly string[] | readonly boolean[], dims?: readonly number[]);
constructor(
data: TensorDataType | Uint8ClampedArray | readonly string[] | readonly boolean[],
dims?: readonly number[],
);
/**
* Construct a new tensor object from the pinned CPU data with the given type and dims.
*
Expand Down Expand Up @@ -90,12 +93,13 @@ export class Tensor implements TensorInterface {
arg0:
| TensorType
| TensorDataType
| Uint8ClampedArray
| readonly string[]
| readonly boolean[]
| CpuPinnedConstructorParameters
| TextureConstructorParameters
| GpuBufferConstructorParameters,
arg1?: TensorDataType | readonly number[] | readonly string[] | readonly boolean[],
arg1?: TensorDataType | Uint8ClampedArray | readonly number[] | readonly string[] | readonly boolean[],
arg2?: readonly number[],
) {
// perform one-time check for BigInt/Float16Array support
Expand Down Expand Up @@ -216,6 +220,8 @@ export class Tensor implements TensorInterface {
}
} else if (arg1 instanceof typedArrayConstructor) {
data = arg1;
} else if (arg1 instanceof Uint8ClampedArray) {
data = Uint8Array.from(arg1);
} else {
throw new TypeError(`A ${type} tensor's data must be type of ${typedArrayConstructor}`);
}
Expand Down Expand Up @@ -243,6 +249,9 @@ export class Tensor implements TensorInterface {
} else {
throw new TypeError(`Invalid element type of data array: ${firstElementType}.`);
}
} else if (arg0 instanceof Uint8ClampedArray) {
type = 'uint8';
data = Uint8Array.from(arg0);
} else {
// get tensor type from TypedArray
const mappedType = NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP.get(
Expand Down
8 changes: 8 additions & 0 deletions js/common/lib/tensor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,14 @@ export interface TensorConstructor extends TensorFactory {
*/
new (data: Uint8Array, dims?: readonly number[]): TypedTensor<'uint8'>;

/**
* Construct a new uint8 tensor object from the given data and dims.
*
* @param data - Specify the CPU tensor data.
* @param dims - Specify the dimension of the tensor. If omitted, a 1-D tensor is assumed.
*/
new (data: Uint8ClampedArray, dims?: readonly number[]): TypedTensor<'uint8'>;

/**
* Construct a new uint16 tensor object from the given data and dims.
*
Expand Down
19 changes: 19 additions & 0 deletions js/common/test/type-tests/tensor/create-new-uint8.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import * as ort from 'onnxruntime-common';

// construct from Uint8Array
//
// {type-tests}|pass
new ort.Tensor(new Uint8Array(1));

// construct from Uint8ClampedArray
//
// {type-tests}|pass
new ort.Tensor(new Uint8ClampedArray(1));

// construct from type (bool), data (Uint8ClampedArray) and shape (number array)
//
// {type-tests}|fail
new ort.Tensor('bool', new Uint8ClampedArray([255, 256]), [2]);
8 changes: 8 additions & 0 deletions js/common/test/unit-tests/tensor/constructor-type.ts
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,14 @@ describe('Tensor Constructor Tests - check types', () => {
assert.equal(tensor.type, 'bool', "tensor.type should be 'bool'");
});

it('[uint8] new Tensor(uint8ClampedArray, dims): uint8 tensor can be constructed from Uint8ClampedArray', () => {
const uint8ClampedArray = new Uint8ClampedArray(2);
uint8ClampedArray[0] = 0;
uint8ClampedArray[1] = 256; // clamped
const tensor = new Tensor('uint8', uint8ClampedArray, [2]);
prathikr marked this conversation as resolved.
Show resolved Hide resolved
assert.equal(tensor.type, 'uint8', "tensor.type should be 'uint8'");
});

it("[bool] new Tensor('bool', uint8Array, dims): tensor can be constructed from Uint8Array", () => {
const tensor = new Tensor('bool', new Uint8Array([1, 0, 1, 0]), [2, 2]);
assert.equal(tensor.type, 'bool', "tensor.type should be 'bool'");
Expand Down
Loading