Skip to content

Commit

Permalink
[js/web] allow op test to use f16 type for inputs/outputs (#21664)
Browse files Browse the repository at this point in the history
### Description
allow op test to use f16 type for inputs/outputs.

This PR introduces "@petamoriken/float16" as Float16Array polyfill but
restricts it to be only used for test runner.
  • Loading branch information
fs-eire authored Aug 8, 2024
1 parent d616025 commit 5e66fcc
Show file tree
Hide file tree
Showing 5 changed files with 119 additions and 2 deletions.
13 changes: 13 additions & 0 deletions js/web/package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions js/web/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
],
"devDependencies": {
"@chiragrupani/karma-chromium-edge-launcher": "^2.2.2",
"@petamoriken/float16": "^3.8.7",
"@types/chai": "^4.3.4",
"@types/emscripten": "^1.39.6",
"@types/flatbuffers": "^1.10.0",
Expand Down
74 changes: 74 additions & 0 deletions js/web/test/data/ops/pad_f16.jsonc
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
[
{
"name": "constant 2D float16",
"operator": "Pad",
"opset": { "domain": "", "version": 10 },
"attributes": [
{ "name": "mode", "data": "constant", "type": "string" },
{ "name": "value", "data": 1.2, "type": "float" },
{ "name": "pads", "data": [3, 2, 2, 3], "type": "ints" }
],
"cases": [
{
"name": "[2,2]->[7,7]",
"inputs": [
{
"data": [1.0, 2.0, 3.0, 4.5],
"dims": [2, 2],
"type": "float16"
}
],
"outputs": [
{
"data": [
1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2,
1.2, 1.2, 1.0, 2.0, 1.2, 1.2, 1.2, 1.2, 1.2, 3.0, 4.5, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2,
1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2
],
"dims": [7, 7],
"type": "float16"
}
]
}
]
},
{
"name": "constant 2D float16",
"operator": "Pad",
"opset": { "domain": "", "version": 19 },
"attributes": [{ "name": "mode", "data": "constant", "type": "string" }],
"cases": [
{
"name": "[2,2]->[7,7]",
"inputs": [
{
"data": [1.0, 2.0, 3.0, 4.5],
"dims": [2, 2],
"type": "float16"
},
{
"data": [3, 2, 2, 3],
"dims": [4],
"type": "int64"
},
{
"data": [1.2],
"dims": [1],
"type": "float16"
}
],
"outputs": [
{
"data": [
1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2,
1.2, 1.2, 1.0, 2.0, 1.2, 1.2, 1.2, 1.2, 1.2, 3.0, 4.5, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2,
1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1.2
],
"dims": [7, 7],
"type": "float16"
}
]
}
]
}
]
4 changes: 4 additions & 0 deletions js/web/test/op-test-schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@
"properties": {
"type": {
"enum": [
"float16",
"float32",
"float64",
"int8",
Expand Down Expand Up @@ -213,6 +214,7 @@
"properties": {
"type": {
"enum": [
"float16",
"float32",
"float64",
"int8",
Expand Down Expand Up @@ -247,6 +249,7 @@
"properties": {
"type": {
"enum": [
"float16",
"float32",
"float64",
"int8",
Expand Down Expand Up @@ -283,6 +286,7 @@
"properties": {
"type": {
"enum": [
"float16",
"float32",
"float64",
"int8",
Expand Down
29 changes: 27 additions & 2 deletions js/web/test/test-runner.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {Float16Array as Float16ArrayPolyfill} from '@petamoriken/float16';
import {expect} from 'chai';
import * as ort from 'onnxruntime-common';
import {extname} from 'path';
Expand Down Expand Up @@ -391,6 +392,24 @@ export class TensorResultValidator {
case 'string':
return this.strictEqual(actual.data, expected.data);

case 'float16': {
const actualData = actual.data as Uint16Array;
const actualDataBuffer = actualData.buffer;
const actualDataByteOffset = actualData.byteOffset;
const actualDataLength = actualData.length;
const actualDataFloat32Array =
new Float32Array(new Float16ArrayPolyfill(actualDataBuffer, actualDataByteOffset, actualDataLength));

const expectedData = expected.data as Uint16Array;
const expectedDataBuffer = expectedData.buffer;
const expectedDataByteOffset = expectedData.byteOffset;
const expectedDataLength = expectedData.length;
const expectedDataFloat32Array =
new Float32Array(new Float16ArrayPolyfill(expectedDataBuffer, expectedDataByteOffset, expectedDataLength));

return this.floatEqual(actualDataFloat32Array, expectedDataFloat32Array);
}

case 'float32':
case 'float64':
return this.floatEqual(
Expand Down Expand Up @@ -919,11 +938,14 @@ async function runProtoOpTestcase(
const fetches: Record<string, Pick<ort.Tensor, 'dims'|'type'>> = {};
testCase.inputs.forEach((input, i) => {
if (input.data) {
let data: number[]|BigUint64Array|BigInt64Array = input.data;
let data: number[]|BigUint64Array|BigInt64Array|Uint16Array = input.data;
if (input.type === 'uint64') {
data = BigUint64Array.from(input.data.map(BigInt));
} else if (input.type === 'int64') {
data = BigInt64Array.from(input.data.map(BigInt));
} else if (input.type === 'float16') {
const dataArr = Float16ArrayPolyfill.from(input.data);
data = new Uint16Array(dataArr.buffer, dataArr.byteOffset, dataArr.byteLength / 2);
}
feeds[`input_${i}`] = new ort.Tensor(input.type, data, input.dims);
}
Expand All @@ -933,11 +955,14 @@ async function runProtoOpTestcase(
const expectedOutputNames: string[] = [];
testCase.outputs.forEach((output, i) => {
if (output.data) {
let data: number[]|BigUint64Array|BigInt64Array = output.data;
let data: number[]|BigUint64Array|BigInt64Array|Uint16Array = output.data;
if (output.type === 'uint64') {
data = BigUint64Array.from(output.data.map(BigInt));
} else if (output.type === 'int64') {
data = BigInt64Array.from(output.data.map(BigInt));
} else if (output.type === 'float16') {
const dataArr = Float16ArrayPolyfill.from(output.data);
data = new Uint16Array(dataArr.buffer, dataArr.byteOffset, dataArr.byteLength / 2);
}
outputs.push(new ort.Tensor(output.type, data, output.dims));
expectedOutputNames.push(`output_${i}`);
Expand Down

0 comments on commit 5e66fcc

Please sign in to comment.