Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[js/web/training] added end-to-end tests #18700

Merged
merged 6 commits into from
Jan 12, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions js/web/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
"build:doc": "node ./script/generate-webgl-operator-md && node ./script/generate-webgpu-operator-md",
"pull:wasm": "node ./script/pull-prebuilt-wasm-artifacts",
"test:e2e": "node ./test/e2e/run",
"test:training:e2e": "node ./test/training/e2e/run",
"prebuild": "tsc -p . --noEmit && tsc -p lib/wasm/proxy-worker --noEmit",
"build": "node ./script/build",
"test": "tsc --build ../scripts && node ../scripts/prepare-onnx-node-tests && node ./script/test-runner-cli",
Expand Down
21 changes: 21 additions & 0 deletions js/web/test/training/e2e/browser-test-wasm.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

'use strict';

describe('Browser E2E testing for training package', function() {
it('Check that training package encompasses inference', async function() {
ort.env.wasm.numThreads = 1;
await testInferenceFunction(ort, {executionProviders: ['wasm']});
});

it('Check training functionality, all options', async function() {
ort.env.wasm.numThreads = 1;
await testTrainingFunctionAll(ort, {executionProviders: ['wasm']});
});

it('Check training functionality, minimum options', async function() {
ort.env.wasm.numThreads = 1;
await testTrainingFunctionMin(ort, {executionProviders: ['wasm']});
});
});
234 changes: 234 additions & 0 deletions js/web/test/training/e2e/common.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,234 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

'use strict';

const DATA_FOLDER = 'data/';
const TRAININGDATA_TRAIN_MODEL = DATA_FOLDER + 'training_model.onnx';
const TRAININGDATA_OPTIMIZER_MODEL = DATA_FOLDER + 'adamw.onnx';
const TRAININGDATA_EVAL_MODEL = DATA_FOLDER + 'eval_model.onnx';
const TRAININGDATA_CKPT = DATA_FOLDER + 'checkpoint.ckpt';

const trainingSessionAllOptions = {
checkpointState: TRAININGDATA_CKPT,
trainModel: TRAININGDATA_TRAIN_MODEL,
evalModel: TRAININGDATA_EVAL_MODEL,
optimizerModel: TRAININGDATA_OPTIMIZER_MODEL
}

const trainingSessionMinOptions = {
checkpointState: TRAININGDATA_CKPT,
trainModel: TRAININGDATA_TRAIN_MODEL,
}

// ASSERT METHODS

function assert(cond) {
if (!cond) throw new Error();
}

function assertStrictEquals(actual, expected) {
if (actual !== expected) {
let strRep = actual;
if (typeof actual === 'object') {
strRep = JSON.stringify(actual);
}
throw new Error(`expected: ${expected}; got: ${strRep}`);
}
}

// HELPER METHODS FOR TESTS

function generateGaussianRandom(mean=0, scale=1) {
const u = 1 - Math.random();
const v = Math.random();
const z = Math.sqrt(-2.0 * Math.log(u)) * Math.cos(2.0 * Math.PI * v);
return z * scale + mean;
}

function generateGaussianFloatArray(length) {
const array = new Float32Array(length);

for (let i = 0; i < length; i++) {
array[i] = generateGaussianRandom();
}

return array;
}

/**
* creates the TrainingSession and verifies that the input and output names of the training model loaded into the
* training session are correct.
* @param {} ort
* @param {*} createOptions
* @param {*} options
* @returns
*/
async function createTrainingSessionAndCheckTrainingModel(ort, createOptions, options) {
const trainingSession = await ort.TrainingSession.create(createOptions, options);

assertStrictEquals(trainingSession.trainingInputNames[0], 'input-0');
assertStrictEquals(trainingSession.trainingInputNames[1], 'labels');
assertStrictEquals(trainingSession.trainingInputNames.length, 2);
assertStrictEquals(trainingSession.trainingOutputNames[0], 'onnx::loss::21273');
assertStrictEquals(trainingSession.trainingOutputNames.length, 1);
return trainingSession;
}

/**
* verifies that the eval input and output names associated with the eval model loaded into the given training session
* are correct.
*/
function checkEvalModel(trainingSession) {
assertStrictEquals(trainingSession.evalInputNames[0], 'input-0');
assertStrictEquals(trainingSession.evalInputNames[1], 'labels');
assertStrictEquals(trainingSession.evalInputNames.length, 2);
assertStrictEquals(trainingSession.evalOutputNames[0], 'onnx::loss::21273');
assertStrictEquals(trainingSession.evalOutputNames.length, 1);
}

/**
* Checks that accessing trainingSession.evalInputNames or trainingSession.evalOutputNames will throw an error if
* accessed
* @param {} trainingSession
*/
function checkNoEvalModel(trainingSession) {
try {
assertStrictEquals(trainingSession.evalInputNames, "should have thrown an error upon accessing");
} catch (error) {
assertStrictEquals(error.message, 'This training session has no evalModel loaded.');
}
try {
assertStrictEquals(trainingSession.evalOutputNames, "should have thrown an error upon accessing");
} catch (error) {
assertStrictEquals(error.message, 'This training session has no evalModel loaded.');
}
}

/**
* runs the train step with the given inputs and checks that the tensor returned is of type float32 and has a length
* of 1 for the loss.
* @param {} trainingSession
* @param {*} feeds
* @returns
*/
var runTrainStepAndCheck = async function(trainingSession, feeds) {
const results = await trainingSession.runTrainStep(feeds);
assertStrictEquals(Object.keys(results).length, 1);
assertStrictEquals(results['onnx::loss::21273'].data.length, 1);
assertStrictEquals(results['onnx::loss::21273'].type, 'float32');
return results;
};

var loadParametersBufferAndCheck = async function(trainingSession, paramsLength, constant, paramsBefore) {
// make a float32 array that is filled with the constant
const newParams = new Float32Array(paramsLength);
for (let i = 0; i < paramsLength; i++) {
newParams[i] = constant;
}

const newParamsUint8 = new Uint8Array(newParams.buffer, newParams.byteOffset, newParams.byteLength);

await trainingSession.loadParametersBuffer(newParamsUint8);
const paramsAfterLoad = await trainingSession.getContiguousParameters();

// check that the parameters have changed
assert(paramsAfterLoad.data != paramsBefore.data);
assertStrictEquals(paramsAfterLoad.dims[0], paramsLength);

// check that the parameters have changed to what they should be
for (let i = 0; i < paramsLength; i++) {
// round to the same number of digits (4 decimal places)
assertStrictEquals(paramsAfterLoad.data[i].toFixed(4), constant.toFixed(4));
}

return paramsAfterLoad;
}

// TESTS

var testInferenceFunction = async function(ort, options) {
const session = await ort.InferenceSession.create('data/model.onnx', options || {});

const dataA = Float32Array.from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]);
const dataB = Float32Array.from([10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120]);

const fetches =
await session.run({a: new ort.Tensor('float32', dataA, [3, 4]), b: new ort.Tensor('float32', dataB, [4, 3])});

const c = fetches.c;

assert(c instanceof ort.Tensor);
assert(c.dims.length === 2 && c.dims[0] === 3 && c.dims[1] === 3);
assert(c.data[0] === 700);
assert(c.data[1] === 800);
assert(c.data[2] === 900);
assert(c.data[3] === 1580);
assert(c.data[4] === 1840);
assert(c.data[5] === 2100);
assert(c.data[6] === 2460);
assert(c.data[7] === 2880);
assert(c.data[8] === 3300);
};

var testTrainingFunctionMin = async function(ort, options) {
const trainingSession = await createTrainingSessionAndCheckTrainingModel(ort, trainingSessionMinOptions, options);
checkNoEvalModel(trainingSession);
const input0 = new ort.Tensor('float32', generateGaussianFloatArray(2 * 784), [2, 784]);
const labels = new ort.Tensor('int32', [2, 1], [2]);
const feeds = {"input-0": input0, "labels": labels};

// check getParametersSize
const paramsSize = await trainingSession.getParametersSize();
assertStrictEquals(paramsSize, 397510);

// check getContiguousParameters
const originalParams = await trainingSession.getContiguousParameters();
assertStrictEquals(originalParams.dims.length, 1);
assertStrictEquals(originalParams.dims[0], 397510);
assertStrictEquals(originalParams.data[0], -0.025190064683556557);
assertStrictEquals(originalParams.data[2000], -0.034044936299324036);

await runTrainStepAndCheck(trainingSession, feeds);

await loadParametersBufferAndCheck(trainingSession, 397510, -1.2, originalParams);
}

var testTrainingFunctionAll = async function(ort, options) {
const trainingSession = await createTrainingSessionAndCheckTrainingModel(ort, trainingSessionAllOptions, options);
checkEvalModel(trainingSession);

const input0 = new ort.Tensor('float32', generateGaussianFloatArray(2 * 784), [2, 784]);
const labels = new ort.Tensor('int32', [2, 1], [2]);
let feeds = {"input-0": input0, "labels": labels};

// check getParametersSize
const paramsSize = await trainingSession.getParametersSize();
assertStrictEquals(paramsSize, 397510);

// check getContiguousParameters
const originalParams = await trainingSession.getContiguousParameters();
assertStrictEquals(originalParams.dims.length, 1);
assertStrictEquals(originalParams.dims[0], 397510);
assertStrictEquals(originalParams.data[0], -0.025190064683556557);
assertStrictEquals(originalParams.data[2000], -0.034044936299324036);

const results = await runTrainStepAndCheck(trainingSession, feeds);

await trainingSession.runOptimizerStep(feeds);
feeds = {"input-0": input0, "labels": labels};
// check getContiguousParameters after optimizerStep -- that the parameters have been updated
const optimizedParams = await trainingSession.getContiguousParameters();
assert(originalParams.data != optimizedParams.data);
carzh marked this conversation as resolved.
Show resolved Hide resolved

const results2 = await runTrainStepAndCheck(trainingSession, feeds);

// check that loss decreased after optimizer step and training again
assert(results2['onnx::loss::21273'].data < results['onnx::loss::21273'].data);

await loadParametersBufferAndCheck(trainingSession, 397510, -1.2, optimizedParams);
}

if (typeof module === 'object') {
module.exports = [testInferenceFunction, testTrainingFunctionMin, testTrainingFunctionAll, testTest];
}
16 changes: 16 additions & 0 deletions js/web/test/training/e2e/data/model.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
 backend-test:b

a
bc"MatMultest_matmul_2dZ
a


Z
b


b
c


B
56 changes: 56 additions & 0 deletions js/web/test/training/e2e/karma.conf.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
askhade marked this conversation as resolved.
Show resolved Hide resolved
// Licensed under the MIT License.

'use strict';

const args = require('minimist')(process.argv.slice(2));
const SELF_HOST = !!args['self-host'];
const ORT_MAIN = args['ort-main'];
const TEST_MAIN = args['test-main'];
if (typeof TEST_MAIN !== 'string') {
throw new Error('flag --test-main=<TEST_MAIN_JS_FILE> is required');
}
const USER_DATA = args['user-data'];
if (typeof USER_DATA !== 'string') {
throw new Error('flag --user-data=<CHROME_USER_DATA_FOLDER> is required');
}

module.exports = function(config) {
const distPrefix = SELF_HOST ? './node_modules/onnxruntime-web/dist/' : 'http://localhost:8081/dist/';
config.set({
frameworks: ['mocha'],
files: [
{pattern: distPrefix + ORT_MAIN},
{pattern: './common.js'},
{pattern: TEST_MAIN},
{pattern: './node_modules/onnxruntime-web/dist/*.wasm', included: false, nocache: true},
{pattern: './data/*', included: false},
],
plugins: [require('@chiragrupani/karma-chromium-edge-launcher'), ...config.plugins],
proxies: {
'/model.onnx': '/base/model.onnx',
'/data/': '/base/data/',
'/test-wasm-path-override/ort-wasm.wasm': '/base/node_modules/onnxruntime-web/dist/ort-wasm.wasm',
'/test-wasm-path-override/renamed.wasm': '/base/node_modules/onnxruntime-web/dist/ort-wasm.wasm',
carzh marked this conversation as resolved.
Show resolved Hide resolved
},
client: {captureConsole: true, mocha: {expose: ['body'], timeout: 60000}},
reporters: ['mocha'],
captureTimeout: 120000,
reportSlowerThan: 100,
browserDisconnectTimeout: 600000,
browserNoActivityTimeout: 300000,
browserDisconnectTolerance: 0,
browserSocketTimeout: 60000,
hostname: 'localhost',
browsers: [],
customLaunchers: {
Chrome_default: {base: 'ChromeHeadless', chromeDataDir: USER_DATA},
Chrome_no_threads: {
base: 'ChromeHeadless',
chromeDataDir: USER_DATA,
// TODO: no-thread flags
},
Edge_default: {base: 'Edge', edgeDataDir: USER_DATA}
}
});
};
14 changes: 14 additions & 0 deletions js/web/test/training/e2e/package.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
{
"devDependencies": {
"@chiragrupani/karma-chromium-edge-launcher": "^2.2.2",
"fs-extra": "^11.1.0",
"globby": "^13.1.3",
"karma": "^6.4.1",
"karma-chrome-launcher": "^3.1.1",
"karma-mocha": "^2.0.1",
"karma-mocha-reporter": "^2.2.5",
"light-server": "^2.9.1",
"minimist": "^1.2.7",
"mocha": "^10.2.0"
}
}
Loading
Loading