Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve reviewability of WebNN operations baseline #1

Merged
merged 7 commits into from
Feb 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .eslintrc.js
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ module.exports = {
'indent': 2,
'require-jsdoc': 'off',
'max-len': ['error', {'code': 100}],
'prefer-rest-params': 'off'
},
extends: [
'eslint:recommended',
Expand Down
49 changes: 7 additions & 42 deletions src/batch_normalization.js
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,8 @@

import {add, sub, mul, div, pow} from './binary.js';
import {reshape} from './reshape.js';
import {Tensor, Scalar} from './tensor.js';

/**
* Check the tensor whether it is a 1-D tensor and its length is equal to `expectedSize`.
* @param {Tensor} a
* @param {Number} expectedSize
* @param {String} name
*/
function check1DTensorWithSize(a, expectedSize, name) {
if (a) {
if (a.rank !== 1) {
throw new Error(`The parameter ${name} is not a 1-D tensor.`);
} else {
if (a.shape[0] !== expectedSize) {
throw new Error(`The length ${a.shape[0]} of the ${name} values is not equal to the ` +
`size ${expectedSize} of the input dimension denoted by options.axis.`);
}
}
}
}
import {Tensor, Scalar} from './lib/tensor.js';
import {validateBatchNormalizationParams} from './lib/validate-input.js';

/**
* Normalize the tensor values of input features across the batch dimension using
Expand All @@ -32,25 +14,10 @@ function check1DTensorWithSize(a, expectedSize, name) {
* @param {MLBatchNormalizationOptions} [options]
* @return {Tensor}
*/
export function batchNormalization(input, mean, variance, options = {}) {
let axis = options.axis;
if (axis !== undefined) {
if (!Number.isInteger(axis)) {
throw new Error(`Invalid axis ${axis}, axis should be an integer.`);
}
} else {
axis = 1;
}
const dim = input.shape[axis];
check1DTensorWithSize(mean, dim, 'mean');
check1DTensorWithSize(variance, dim, 'variance');
const scale = options.scale;
check1DTensorWithSize(scale, dim, 'scale');
const bias = options.bias;
check1DTensorWithSize(bias, dim, 'bias');
const epsilon = options.epsilon ? options.epsilon : 1e-5;
const activation = options.activation;
// The output tensor of the same shape as the input tensor.
export function batchNormalization(input, mean, variance, {axis=1, scale, bias, epsilon=1e-5,
activation = (x) => x} = {}) {
validateBatchNormalizationParams(...arguments);
// The output tensor has the same shape as the input tensor.
let output = new Tensor(input.shape);
const shape = new Array(input.rank).fill(1);
shape[axis] = -1;
Expand All @@ -63,8 +30,6 @@ export function batchNormalization(input, mean, variance, options = {}) {
if (bias) {
output = add(output, reshape(bias, shape));
}
if (activation) {
output = activation(output);
}
output = activation(output);
return output;
}
4 changes: 2 additions & 2 deletions src/binary.js
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
'use strict';

import {broadcast, getBroadcastShape} from './broadcast.js';
import {Tensor, sizeOfShape} from './tensor.js';
import {broadcast, getBroadcastShape} from './lib/broadcast.js';
import {Tensor, sizeOfShape} from './lib/tensor.js';

/**
* Compute the element-wise binary operation of two input tensors.
Expand Down
19 changes: 3 additions & 16 deletions src/clamp.js
Original file line number Diff line number Diff line change
@@ -1,31 +1,18 @@
'use strict';

import {Tensor} from './tensor.js';
import {Tensor} from './lib/tensor.js';

/**
* Clamp the input tensor element-wise within a range specified by the minimum and maximum values.
* @param {Tensor} input
* @param {MLClampOptions} [options]
* @return {Tensor}
*/
export function clamp(input, options = {}) {
export function clamp(input, {minValue=-Infinity, maxValue=Infinity} = {}) {
const output = new Tensor(input.shape);
for (let i = 0; i < input.size; ++i) {
const x = input.getValueByIndex(i);
let y;
if (options.minValue === undefined) {
if (options.maxValue === undefined) {
y = x;
} else {
y = Math.min(x, options.maxValue);
}
} else {
if (options.maxValue === undefined) {
y = Math.max(x, options.minValue);
} else {
y = Math.min(Math.max(x, options.minValue), options.maxValue);
}
}
const y = Math.min(Math.max(x, minValue), maxValue);
output.setValueByIndex(i, y);
}
return output;
Expand Down
28 changes: 4 additions & 24 deletions src/concat.js
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
'use strict';

import {Tensor, sizeOfShape} from './tensor.js';
import {Tensor, sizeOfShape} from './lib/tensor.js';
import {validateConcatParams} from './lib/validate-input.js';

/**
* Concatenates the input tensors along a given axis.
Expand All @@ -9,32 +10,11 @@ import {Tensor, sizeOfShape} from './tensor.js';
* @return {Tensor}
*/
export function concat(inputs, axis) {
const rank = inputs[0].rank;
if (!Number.isInteger(axis)) {
throw new Error(`Invalid axis ${axis}, axis should be an integer.`);
} else {
if (axis < 0 || axis >= rank) {
throw new Error(`Invalid axis ${axis}, axis should be in the interval [0, ${rank}).`);
}
}
validateConcatParams(...arguments);
const inputShape = inputs[0].shape;
const outputShape = inputShape.slice();
for (let i = 1; i < inputs.length; ++i) {
if (inputs[i].rank !== rank) {
throw new Error('All input tensors should have the same rank.');
} else {
const shape = inputs[i].shape;
for (let j = 0; j < inputShape.length; ++j) {
if (j !== axis) {
if (inputShape[j] !== shape[j]) {
throw new Error('All input tensors should have the same shape, ' +
'except for the size of the dimension to concatenate on.');
}
} else {
outputShape[axis] += shape[axis];
}
}
}
outputShape[axis] += inputs[i].shape[axis];
}
const output = new Tensor(outputShape);
for (let i = 0; i < sizeOfShape(outputShape); ++i) {
Expand Down
97 changes: 23 additions & 74 deletions src/conv2d.js
Original file line number Diff line number Diff line change
@@ -1,65 +1,32 @@
'use strict';

import {Tensor} from './tensor.js';
import {Tensor} from './lib/tensor.js';
import {validateConv2dParams} from './lib/validate-input.js';
import {computePaddingForAutoPad} from './lib/compute-padding.js';
import {transpose} from './transpose.js';

/**
* Compute the beginning and ending pad given input, filter and stride.
* @param {String} autoPad
* @param {Number} inputSize
* @param {Number} effectiveFilterSize
* @param {Number} stride
* @return {Array} [paddingBegin, paddingEnd]
*/
export function computePaddingForAutoPad(autoPad, inputSize, effectiveFilterSize, stride) {
const outSize = Math.ceil(inputSize / stride);
const neededInput = (outSize - 1) * stride + effectiveFilterSize;
const totalPadding = neededInput > inputSize ? neededInput - inputSize : 0;
let paddingBegin;
let paddingEnd;
switch (autoPad) {
case 'same-upper':
paddingBegin = Math.floor(totalPadding / 2);
paddingEnd = Math.floor((totalPadding + 1) / 2);
break;
case 'same-lower':
paddingBegin = Math.floor((totalPadding + 1) / 2);
paddingEnd = Math.floor(totalPadding / 2);
break;
default:
throw new Error('The autoPad is invalid.');
}
return [paddingBegin, paddingEnd];
}

/**
* Compute a 2-D convolution given 4-D input and filter tensors.
* @param {Tensor} input
* @param {Tensor} filter
* @param {MLConv2dOptions} options
* @return {Tensor}
*/
export function conv2d(input, filter, options = {}) {
if (input.rank !== 4) {
throw new Error('The input should be a 4-D tensor.');
}

if (filter.rank !== 4) {
throw new Error('The filter should be a 4-D tensor.');
}

const padding = options.padding ? options.padding : [0, 0, 0, 0];
const strides = options.strides ? options.strides : [1, 1];
const groups = options.groups ? options.groups : 1;
const dilations = options.dilations ? options.dilations : [1, 1];
const activation = options.activation;

const inputLayout = options.inputLayout ? options.inputLayout : 'nchw';
export function conv2d(input, filter, {padding = [0, 0, 0, 0],
strides = [1, 1],
groups = 1,
dilations = [1, 1],
activation = (x) => x,
inputLayout = 'nchw',
filterLayout = 'oihw',
bias,
autoPad = 'explicit',
}
= {}) {
if (inputLayout === 'nhwc') {
// nhwc -> nchw
input = transpose(input, {permutation: [0, 3, 1, 2]});
}
const filterLayout = options.filterLayout ? options.filterLayout : 'oihw';
if (filterLayout === 'hwio') {
// hwio -> oihw
filter = transpose(filter, {permutation: [3, 2, 0, 1]});
Expand All @@ -70,43 +37,27 @@ export function conv2d(input, filter, options = {}) {
// ihwo -> oihw
filter = transpose(filter, {permutation: [3, 0, 1, 2]});
}
validateConv2dParams(input, filter, {groups, bias});

const batchCount = input.shape[0];
const inputChannels = input.shape[1];
const inputHeight = input.shape[2];
const inputWidth = input.shape[3];
const outputChannels = filter.shape[0];
const filterInputChannels = filter.shape[1];
const filterHeight = filter.shape[2];
const filterWidth = filter.shape[3];
const strideHeight = strides[0];
const strideWidth = strides[1];
const dilationHeight = dilations[0];
const dilationWidth = dilations[1];
const [batchCount, inputChannels, inputHeight, inputWidth] = input.shape;
const [outputChannels, , filterHeight, filterWidth] = filter.shape;
const [strideHeight, strideWidth] = strides;
const [dilationHeight, dilationWidth] = dilations;
const effectiveFilterHeight = filterHeight + (filterHeight - 1) * (dilationHeight - 1);
const effectiveFilterWidth = filterWidth + (filterWidth - 1) * (dilationWidth - 1);

if (inputChannels !== filterInputChannels * groups) {
throw new Error('The input channels of filter is invalid.');
}

const bias = options.bias;
if (bias && (bias.rank !== 1 || bias.shape[0] != outputChannels)) {
throw new Error('the bias should be a 1-D tensor with the shape of [output_channels].');
}

let beginningPaddingHeight;
let endingPaddingHeight;
let beginningPaddingWidth;
let endingPaddingWidth;
if (options.autoPad === undefined || options.autoPad === 'explicit') {
if (autoPad === 'explicit') {
[beginningPaddingHeight, endingPaddingHeight, beginningPaddingWidth, endingPaddingWidth] =
padding;
} else {
[beginningPaddingHeight, endingPaddingHeight] = computePaddingForAutoPad(
options.autoPad, inputHeight, effectiveFilterHeight, strideHeight);
autoPad, inputHeight, effectiveFilterHeight, strideHeight);
[beginningPaddingWidth, endingPaddingWidth] = computePaddingForAutoPad(
options.autoPad, inputWidth, effectiveFilterWidth, strideWidth);
autoPad, inputWidth, effectiveFilterWidth, strideWidth);
}

const outputShape = new Array(4);
Expand Down Expand Up @@ -178,9 +129,7 @@ export function conv2d(input, filter, options = {}) {
}
}

if (activation) {
output = activation(output);
}
output = activation(output);

if (inputLayout === 'nhwc') {
// nchw -> nhwc
Expand Down
26 changes: 11 additions & 15 deletions src/gemm.js
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

import {add, mul} from './binary.js';
import {matmul} from './matmul.js';
import {Scalar} from './tensor.js';
import {Scalar} from './lib/tensor.js';
import {validateGemmParams} from './lib/validate-input.js';
import {transpose} from './transpose.js';

/**
Expand All @@ -13,19 +14,15 @@ import {transpose} from './transpose.js';
* @param {MLGemmOptions} options
* @return {Tensor}
*/
export function gemm(a, b, options = {}) {
if (a.rank !== 2) {
throw new Error('The input a is not a 2-D tensor.');
}
if (b.rank !== 2) {
throw new Error('The input b is not a 2-D tensor.');
}
const c = options.c ? options.c : undefined;
const alpha = new Scalar(options.alpha ? options.alpha : 1.0);
const beta = new Scalar(options.beta ? options.beta : 1.0);
const aTranspose = options.aTranspose ? options.aTranspose : false;
const bTranspose = options.bTranspose ? options.bTranspose : false;

export function gemm(a, b, {c = new Scalar(0.0),
alpha: fAlpha = 1.0,
beta: fBeta = 1.0,
aTranspose = false,
bTranspose = false,
} = {}) {
validateGemmParams(...arguments);
const alpha = new Scalar(fAlpha);
const beta = new Scalar(fBeta);
if (aTranspose) {
a = transpose(a);
}
Expand All @@ -35,7 +32,6 @@ export function gemm(a, b, options = {}) {
}

let output = matmul(mul(a, alpha), b);

if (c) {
output = add(output, mul(c, beta));
}
Expand Down
Loading