Skip to content

Commit

Permalink
Merge pull request #63 from mei1127/add_cast
Browse files Browse the repository at this point in the history
Implement cast
  • Loading branch information
huningxin authored Dec 21, 2023
2 parents 058aa77 + 431fb3c commit 64c784e
Show file tree
Hide file tree
Showing 4 changed files with 233 additions and 5 deletions.
1 change: 1 addition & 0 deletions .eslintrc.js
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ module.exports = {
'chai': 'readonly',
'BigInt': 'readonly',
'BigInt64Array': 'readonly',
'BigUint64Array': 'readonly',
},
rules: {
'semi': 'error',
Expand Down
46 changes: 46 additions & 0 deletions src/cast.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
'use strict';

import {Tensor} from '../src/lib/tensor.js';

/**
* Cast each element in the input tensor to the target data type.
* @param {Tensor} input
* @return {Tensor}
*/

export function cast(input, type) {
let outputArray;
switch (type) {
case 'int8':
outputArray = new Int8Array(input.data);
break;
case 'uint8':
outputArray = new Uint8Array(input.data);
break;
case 'int32':
outputArray = new Int32Array(input.data);
break;
case 'uint32':
outputArray = new Uint32Array(input.data);
break;
case 'int64':
outputArray = new BigInt64Array(Array.from(input.data, (num) => BigInt(Math.trunc(num))));
break;
case 'float32':
outputArray = new Float32Array(input.data);
break;
case 'float64':
outputArray = new Float64Array(input.data);
break;
case 'float16':
// TODO: https://github.com/webmachinelearning/webnn-baseline/issues/66
throw new Error('Unsupported output type: float16');
case 'uint64':
// TODO: https://github.com/webmachinelearning/webnn-baseline/issues/67
throw new Error('Unsupported output type: uint64');
default:
throw new Error('Unsupported output type: ' + type);
}
const output = new Tensor(input.shape, outputArray);
return output;
}
175 changes: 175 additions & 0 deletions test/cast_test.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
'use strict';

import {Tensor} from '../src/lib/tensor.js';
import {cast} from '../src/cast.js';
import * as utils from './utils.js';

describe('test cast', function() {
const InputDataType = {
int8: Int8Array,
uint8: Uint8Array,
int32: Int32Array,
uint32: Uint32Array,
int64: BigInt64Array,
float32: Float32Array,
float64: Float64Array,
};
function testCast(input, type, expected) {
let tensorInput;
if (input.type) {
tensorInput = new Tensor(input.shape, new InputDataType[input.type](input.data));
} else {
tensorInput = new Tensor(input.shape, input.data);
}
const outputTensor = cast(tensorInput, type);
utils.checkShape(outputTensor, expected.shape);
utils.checkValue(outputTensor, expected.data);
}

it('cast float64 to int8', function() {
const input = {
shape: [5],
data: [
-0.25, 0.25, 3.75, 14, -14,
],
};
const expected = {
shape: [5],
data: [
0, 0, 3, 14, -14,
],
};
testCast(input, 'int8', expected);
});

it('cast float64 to uint8', function() {
const input = {
shape: [5],
data: [
0.25, 0.75, 3.75, 14, 15,
],
};
const expected = {
shape: [5],
data: [
0, 0, 3, 14, 15,
],
};
testCast(input, 'uint8', expected);
});

it('cast float64 to int32', function() {
const input = {
shape: [5],
data: [
-0.25, 0.75, 3.21, 1234, -1234,
],
};
const expected = {
shape: [5],
data: [
0, 0, 3, 1234, -1234,
],
};
testCast(input, 'int32', expected);
});

it('cast float64 to uint32', function() {
const input = {
shape: [5],
data: [
0.75, 0.25, 3.21, 14, 15,
],
};
const expected = {
shape: [5],
data: [
0, 0, 3, 14, 15,
],
};
testCast(input, 'uint32', expected);
});

it('cast float64 to int64', function() {
const input = {
shape: [5],
data: [
-0.25, 0.75, 3.21, 1234, -1234,
],
};
const expected = {
shape: [5],
data: [
0n, 0n, 3n, 1234n, -1234n,
],
};
testCast(input, 'int64', expected);
});

it('cast float64 to float32', function() {
const input = {
shape: [5],
data: [
-0.25, 0.25, 3.21, 1234, -1234,
],
};
const expected = {
shape: [5],
data: [
-0.25, 0.25, 3.2100000381469727, 1234, -1234,
],
};
testCast(input, 'float32', expected);
});

it('cast int32 to float32', function() {
const input = {
shape: [5],
data: [
0, 1, -2, -3, 3,
],
type: 'int32',
};
const expected = {
shape: [5],
data: [
0, 1, -2, -3, 3,
],
};
testCast(input, 'float32', expected);
});

it('cast uint32 to float64', function() {
const input = {
shape: [5],
data: [
0, 1, 22, 33, 33,
],
type: 'uint32',
};
const expected = {
shape: [5],
data: [
0, 1, 22, 33, 33,
],
};
testCast(input, 'float64', expected);
});

it('cast float32 to float64', function() {
const input = {
shape: [5],
data: [
0, 0.1, 0.2, -300, 993,
],
type: 'float32',
};
const expected = {
shape: [5],
data: [
0, 0.10000000149011612, 0.20000000298023224, -300, 993,
],
};
testCast(input, 'float64', expected);
});
});
16 changes: 11 additions & 5 deletions test/utils.js
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,17 @@ function getBitwise(value) {
* false: The distance between a and b is far away from the given ULP distance.
*/
assert.isAlmostEqualUlp = function(a, b, nulp, message) {
const aBitwise = getBitwise(a);
const bBitwise = getBitwise(b);
let distance = aBitwise - bBitwise;
distance = distance >= 0 ? distance : -distance;
return assert.isTrue(distance <= nulp, message);
if (typeof(a) == 'number') {
const aBitwise = getBitwise(a);
const bBitwise = getBitwise(b);
let distance = aBitwise - bBitwise;
distance = distance >= 0 ? distance : -distance;
return assert.isTrue(distance <= nulp, message);
} else {
let distance = a - b;
distance = distance >= 0n ? distance : -distance;
return assert.isTrue(distance <= nulp, message);
}
};

export function checkValue(tensor, expected, nulp = 0) {
Expand Down

0 comments on commit 64c784e

Please sign in to comment.