[js/web/training] added end-to-end tests (#18700)
## Summary
* Following inference's [set-up for end-to-end tests](https://github.com/microsoft/onnxruntime/tree/main/js/web/test/e2e), created an end-to-end test runner for training.
* The test runner copies test data from the [`training_api` folder](https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/test/testdata/training_api).
* It then runs two training tests (a training session with evalModel and optimizer model, and a training session with the minimum options), and tests whether the ORT-web training package encompasses inference.
  * The training tests check:
    * createTrainingSession
    * runTrainStep
    * runOptimizerStep, if applicable
    * the parameter methods (getParametersSize, loadParametersBuffer, and getContiguousParameters)
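
The sequence of checks listed above can be sketched as one helper. This is a minimal sketch, not the code from `common.js`: `runTrainingChecks` is a hypothetical name, `session` is assumed to be the object returned by `ort.TrainingSession.create`, and `runOptimizerStep` is called without arguments here, which is an assumption (the test in this commit passes `feeds`).

```javascript
'use strict';

// Hypothetical helper (not part of the commit): exercises the same
// TrainingSession methods that the e2e tests cover, in the same order.
async function runTrainingChecks(session, feeds, hasOptimizer) {
  // Total number of parameter elements held by the session
  const size = await session.getParametersSize();
  const before = await session.getContiguousParameters();

  // One train step; the training model in the tests has a single loss output
  const results = await session.runTrainStep(feeds);
  const lossName = session.trainingOutputNames[0];
  const loss = results[lossName].data[0];

  // Only sessions created with an optimizerModel can take an optimizer step
  if (hasOptimizer) {
    await session.runOptimizerStep();
  }
  const after = await session.getContiguousParameters();

  // `changed` is true when at least one parameter value differs
  const changed = before.data.some((v, i) => v !== after.data[i]);
  return {size, loss, changed};
}
```

With the minimum options (no optimizer model), `changed` is expected to stay `false` after a train step alone, which is why the minimum-options test does not compare parameters before and after.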

## TL;DR
* [`js/web/test/training/e2e/run.js`](https://github.com/microsoft/onnxruntime/compare/main...carzh:onnxruntime:carzh/training-e2e-runner?expand=1#diff-c1359c4d401f9ba69e937814219cefe5fd11b151a6ffd084c641af3c82e8216c) is responsible for setting up and running the end-to-end tests
* [`js/web/test/training/e2e/common.js`](https://github.com/microsoft/onnxruntime/compare/main...carzh:onnxruntime:carzh/training-e2e-runner?expand=1#diff-ee5452491b7b2563d175d13d81d10f2323b12b18589aa4c5798962a8b904a4a8) contains the test function definitions (`testInferenceFunction`, `testTrainingFunctionMin`, `testTrainingFunctionAll`)

## Flow
* Entrypoint: the user runs the following command in the terminal: `npm run test:training:e2e`
* [`js/web/package.json`](https://github.com/microsoft/onnxruntime/compare/main...carzh:onnxruntime:carzh/training-e2e-runner?expand=1#diff-79275844e75c3c410bb3a71c7f59b2b633e5a3e975c804ffc47220025084da28) was modified to include an npm script that runs `run.js`, which in turn runs the end-to-end tests
* [`js/web/test/training/e2e/run.js`](https://github.com/microsoft/onnxruntime/compare/main...carzh:onnxruntime:carzh/training-e2e-runner?expand=1#diff-c1359c4d401f9ba69e937814219cefe5fd11b151a6ffd084c641af3c82e8216c) is responsible for
  * detecting and installing local tarball packages of ORT-web
  * copying training data to the `js/web/test/training/e2e/data` folder
  * starting two Karma processes. Karma is a test runner that launches a browser and runs the tests inside it.
    * In this case, the tests run in Chrome. We can configure the tests to run in Edge and other browsers in the future.
    * One of these Karma processes is self-hosted, meaning it loads the ORT-web package from the local `node_modules` folder.
    * The other Karma process is not self-hosted, meaning it loads the ORT-web package from another source. In this case, we start an HTTP server that serves the ORT-web binaries.
* [`js/web/test/training/e2e/simple-http-server.js`](https://github.com/microsoft/onnxruntime/compare/main...carzh:onnxruntime:carzh/training-e2e-runner?expand=1#diff-f798ab485f3ec26c299fe5b2923574c9e4b090200ba20d490bbf6c183286993c) is responsible for starting the HTTP server and serving the ORT binary files. This code is almost identical to the corresponding code in the inference E2E tests.
* [`js/web/test/training/e2e/karma.conf.js`](https://github.com/microsoft/onnxruntime/compare/main...carzh:onnxruntime:carzh/training-e2e-runner?expand=1#diff-436cfe8f670c768a04895bd4a1874a5e033f85e0e2d84941c62ff1f7c30a9f28) is the Karma configuration file that specifies what happens when a Karma process is started. The config specifies Mocha as the testing framework, which goes through all the loaded files and runs any tests it finds.
* [`js/web/test/training/e2e/browser-test-wasm.js`](https://github.com/microsoft/onnxruntime/compare/main...carzh:onnxruntime:carzh/training-e2e-runner?expand=1#diff-13b6155e106dddc7b531ef671186e69b2aadb8a0f4b2f3001db0991567d78221) contains the tests that Mocha picks up and runs.
  * The test functions (such as `testInferenceFunction` and `testTrainingFunctionAll`) are defined in [`js/web/test/training/e2e/common.js`](https://github.com/microsoft/onnxruntime/compare/main...carzh:onnxruntime:carzh/training-e2e-runner?expand=1#diff-ee5452491b7b2563d175d13d81d10f2323b12b18589aa4c5798962a8b904a4a8).
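
The two Karma processes differ only in where the ORT-web bundle is loaded from. Below is a minimal sketch of how their command lines might be assembled from the flags that `karma.conf.js` reads (`--self-host`, `--ort-main`, `--test-main`, `--user-data`). The helper name `buildKarmaArgs`, the bundle file name, and the user-data paths are assumptions for illustration; this is not the content of `run.js`.

```javascript
'use strict';

// Hypothetical helper: builds the argument list for one `karma` invocation.
function buildKarmaArgs({selfHost, ortMain, testMain, userData, browser}) {
  const args = [
    'start',
    '--single-run',
    `--browsers=${browser}`,
    `--ort-main=${ortMain}`,
    `--test-main=${testMain}`,
    `--user-data=${userData}`,
  ];
  // karma.conf.js treats the presence of --self-host as "load from node_modules"
  if (selfHost) {
    args.push('--self-host');
  }
  return args;
}

// Shared settings; the bundle name and data folders are placeholders.
const common = {
  ortMain: 'ort.training.wasm.min.js',
  testMain: './browser-test-wasm.js',
  browser: 'Chrome_default',
};
const selfHosted = buildKarmaArgs({...common, selfHost: true, userData: './user-data-self'});
const remoteHosted = buildKarmaArgs({...common, selfHost: false, userData: './user-data-remote'});

console.log(['karma', ...selfHosted].join(' '));
console.log(['karma', ...remoteHosted].join(' '));
```

In the non-self-hosted case the HTTP server has to be listening before Karma starts, since `karma.conf.js` then resolves the bundle against `http://localhost:8081/dist/`.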

## Notes
* I followed the [tests for training core](https://github.com/microsoft/onnxruntime/blob/b023de0bfc7acb2404dfdcc4adc060b7b72fdaa1/orttraining/orttraining/test/training_api/core/training_api_tests.cc), which randomly generate the input for the training session
* E2E tests are triggered by running `npm run test:training:e2e`; suggestions for alternative script names are appreciated!
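
Random input of this kind can be drawn with the Box-Muller transform, which is what `generateGaussianRandom` in `common.js` uses. The sketch below is a standalone variant for illustration only: the names `gaussian` and `gaussianFloat32Array` and the injectable `rand` parameter are assumptions, added so the sampler can be driven deterministically.

```javascript
'use strict';

// Box-Muller transform: maps two uniform samples to one normal sample.
// `rand` is injectable (an assumption for testability); defaults to Math.random.
function gaussian(mean = 0, scale = 1, rand = Math.random) {
  const u = 1 - rand();  // shift [0, 1) to (0, 1] so Math.log(u) stays finite
  const v = rand();
  const z = Math.sqrt(-2.0 * Math.log(u)) * Math.cos(2.0 * Math.PI * v);
  return z * scale + mean;
}

// Fills a Float32Array with standard-normal samples.
function gaussianFloat32Array(length, rand = Math.random) {
  const out = new Float32Array(length);
  for (let i = 0; i < length; i++) {
    out[i] = gaussian(0, 1, rand);
  }
  return out;
}

// Two rows of 784 values, matching the [2, 784] input tensor used in the tests
const batch = gaussianFloat32Array(2 * 784);
console.log(batch.length);  // 1568
```

Because the input is random, the tests assert only on shapes, types, and relative changes (for example, that parameters differ after an optimizer step), not on exact loss values.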

## Motivation and Context
- adding training bindings for web
carzh authored and mszhanyi committed Jan 15, 2024
1 parent 2a5f3f9 commit 3586e2a
Showing 9 changed files with 559 additions and 0 deletions.
1 change: 1 addition & 0 deletions js/web/package.json
Original file line number Diff line number Diff line change
@@ -24,6 +24,7 @@
"build:doc": "node ./script/generate-webgl-operator-md && node ./script/generate-webgpu-operator-md",
"pull:wasm": "node ./script/pull-prebuilt-wasm-artifacts",
"test:e2e": "node ./test/e2e/run",
"test:training:e2e": "node ./test/training/e2e/run",
"prebuild": "tsc -p . --noEmit && tsc -p lib/wasm/proxy-worker --noEmit",
"build": "node ./script/build",
"test": "tsc --build ../scripts && node ../scripts/prepare-onnx-node-tests && node ./script/test-runner-cli",
21 changes: 21 additions & 0 deletions js/web/test/training/e2e/browser-test-wasm.js
@@ -0,0 +1,21 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

'use strict';

describe('Browser E2E testing for training package', function() {
it('Check that training package encompasses inference', async function() {
ort.env.wasm.numThreads = 1;
await testInferenceFunction(ort, {executionProviders: ['wasm']});
});

it('Check training functionality, all options', async function() {
ort.env.wasm.numThreads = 1;
await testTrainingFunctionAll(ort, {executionProviders: ['wasm']});
});

it('Check training functionality, minimum options', async function() {
ort.env.wasm.numThreads = 1;
await testTrainingFunctionMin(ort, {executionProviders: ['wasm']});
});
});
246 changes: 246 additions & 0 deletions js/web/test/training/e2e/common.js
@@ -0,0 +1,246 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

'use strict';

const DATA_FOLDER = 'data/';
const TRAININGDATA_TRAIN_MODEL = DATA_FOLDER + 'training_model.onnx';
const TRAININGDATA_OPTIMIZER_MODEL = DATA_FOLDER + 'adamw.onnx';
const TRAININGDATA_EVAL_MODEL = DATA_FOLDER + 'eval_model.onnx';
const TRAININGDATA_CKPT = DATA_FOLDER + 'checkpoint.ckpt';

const trainingSessionAllOptions = {
checkpointState: TRAININGDATA_CKPT,
trainModel: TRAININGDATA_TRAIN_MODEL,
evalModel: TRAININGDATA_EVAL_MODEL,
optimizerModel: TRAININGDATA_OPTIMIZER_MODEL
}

const trainingSessionMinOptions = {
checkpointState: TRAININGDATA_CKPT,
trainModel: TRAININGDATA_TRAIN_MODEL,
}

// ASSERT METHODS

function assert(cond) {
if (!cond) throw new Error('assertion failed');
}

function assertStrictEquals(actual, expected) {
if (actual !== expected) {
let strRep = actual;
if (typeof actual === 'object') {
strRep = JSON.stringify(actual);
}
throw new Error(`expected: ${expected}; got: ${strRep}`);
}
}

function assertTwoListsUnequal(list1, list2) {
if (list1.length !== list2.length) {
return;
}
for (let i = 0; i < list1.length; i++) {
if (list1[i] !== list2[i]) {
return;
}
}
throw new Error(`expected ${list1} and ${list2} to be unequal; got two equal lists`);
}

// HELPER METHODS FOR TESTS

function generateGaussianRandom(mean=0, scale=1) {
const u = 1 - Math.random();
const v = Math.random();
const z = Math.sqrt(-2.0 * Math.log(u)) * Math.cos(2.0 * Math.PI * v);
return z * scale + mean;
}

function generateGaussianFloatArray(length) {
const array = new Float32Array(length);

for (let i = 0; i < length; i++) {
array[i] = generateGaussianRandom();
}

return array;
}

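/**
* Creates the TrainingSession and verifies that the input and output names of the training model loaded into the
* training session are correct.
* @param {*} ort the ORT-web module
* @param {*} createOptions options passed to TrainingSession.create (checkpoint state and model paths)
* @param {*} options session options, such as the execution providers
* @returns the created training session
*/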
async function createTrainingSessionAndCheckTrainingModel(ort, createOptions, options) {
const trainingSession = await ort.TrainingSession.create(createOptions, options);

assertStrictEquals(trainingSession.trainingInputNames[0], 'input-0');
assertStrictEquals(trainingSession.trainingInputNames[1], 'labels');
assertStrictEquals(trainingSession.trainingInputNames.length, 2);
assertStrictEquals(trainingSession.trainingOutputNames[0], 'onnx::loss::21273');
assertStrictEquals(trainingSession.trainingOutputNames.length, 1);
return trainingSession;
}

/**
* verifies that the eval input and output names associated with the eval model loaded into the given training session
* are correct.
*/
function checkEvalModel(trainingSession) {
assertStrictEquals(trainingSession.evalInputNames[0], 'input-0');
assertStrictEquals(trainingSession.evalInputNames[1], 'labels');
assertStrictEquals(trainingSession.evalInputNames.length, 2);
assertStrictEquals(trainingSession.evalOutputNames[0], 'onnx::loss::21273');
assertStrictEquals(trainingSession.evalOutputNames.length, 1);
}

/**
* Checks that accessing trainingSession.evalInputNames or trainingSession.evalOutputNames throws an error when no
* eval model is loaded.
* @param {*} trainingSession
*/
function checkNoEvalModel(trainingSession) {
try {
assertStrictEquals(trainingSession.evalInputNames, "should have thrown an error upon accessing");
} catch (error) {
assertStrictEquals(error.message, 'This training session has no evalModel loaded.');
}
try {
assertStrictEquals(trainingSession.evalOutputNames, "should have thrown an error upon accessing");
} catch (error) {
assertStrictEquals(error.message, 'This training session has no evalModel loaded.');
}
}

/**
* Runs the train step with the given inputs and checks that the returned loss tensor is of type float32 and has a
* length of 1.
* @param {*} trainingSession
* @param {*} feeds
* @returns the results of the train step
*/
var runTrainStepAndCheck = async function(trainingSession, feeds) {
const results = await trainingSession.runTrainStep(feeds);
assertStrictEquals(Object.keys(results).length, 1);
assertStrictEquals(results['onnx::loss::21273'].data.length, 1);
assertStrictEquals(results['onnx::loss::21273'].type, 'float32');
return results;
};

var loadParametersBufferAndCheck = async function(trainingSession, paramsLength, constant, paramsBefore) {
// make a float32 array that is filled with the constant
const newParams = new Float32Array(paramsLength);
for (let i = 0; i < paramsLength; i++) {
newParams[i] = constant;
}

const newParamsUint8 = new Uint8Array(newParams.buffer, newParams.byteOffset, newParams.byteLength);

await trainingSession.loadParametersBuffer(newParamsUint8);
const paramsAfterLoad = await trainingSession.getContiguousParameters();

// check that the parameters have changed
assertTwoListsUnequal(paramsAfterLoad.data, paramsBefore.data);
assertStrictEquals(paramsAfterLoad.dims[0], paramsLength);

// check that the parameters have changed to what they should be
for (let i = 0; i < paramsLength; i++) {
// round to the same number of digits (4 decimal places)
assertStrictEquals(paramsAfterLoad.data[i].toFixed(4), constant.toFixed(4));
}

return paramsAfterLoad;
}

// TESTS

var testInferenceFunction = async function(ort, options) {
const session = await ort.InferenceSession.create('data/model.onnx', options || {});

const dataA = Float32Array.from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]);
const dataB = Float32Array.from([10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120]);

const fetches =
await session.run({a: new ort.Tensor('float32', dataA, [3, 4]), b: new ort.Tensor('float32', dataB, [4, 3])});

const c = fetches.c;

assert(c instanceof ort.Tensor);
assert(c.dims.length === 2 && c.dims[0] === 3 && c.dims[1] === 3);
assert(c.data[0] === 700);
assert(c.data[1] === 800);
assert(c.data[2] === 900);
assert(c.data[3] === 1580);
assert(c.data[4] === 1840);
assert(c.data[5] === 2100);
assert(c.data[6] === 2460);
assert(c.data[7] === 2880);
assert(c.data[8] === 3300);
};

var testTrainingFunctionMin = async function(ort, options) {
const trainingSession = await createTrainingSessionAndCheckTrainingModel(ort, trainingSessionMinOptions, options);
checkNoEvalModel(trainingSession);
const input0 = new ort.Tensor('float32', generateGaussianFloatArray(2 * 784), [2, 784]);
const labels = new ort.Tensor('int32', [2, 1], [2]);
const feeds = {"input-0": input0, "labels": labels};

// check getParametersSize
const paramsSize = await trainingSession.getParametersSize();
assertStrictEquals(paramsSize, 397510);

// check getContiguousParameters
const originalParams = await trainingSession.getContiguousParameters();
assertStrictEquals(originalParams.dims.length, 1);
assertStrictEquals(originalParams.dims[0], 397510);
assertStrictEquals(originalParams.data[0], -0.025190064683556557);
assertStrictEquals(originalParams.data[2000], -0.034044936299324036);

await runTrainStepAndCheck(trainingSession, feeds);

await loadParametersBufferAndCheck(trainingSession, 397510, -1.2, originalParams);
}

var testTrainingFunctionAll = async function(ort, options) {
const trainingSession = await createTrainingSessionAndCheckTrainingModel(ort, trainingSessionAllOptions, options);
checkEvalModel(trainingSession);

const input0 = new ort.Tensor('float32', generateGaussianFloatArray(2 * 784), [2, 784]);
const labels = new ort.Tensor('int32', [2, 1], [2]);
let feeds = {"input-0": input0, "labels": labels};

// check getParametersSize
const paramsSize = await trainingSession.getParametersSize();
assertStrictEquals(paramsSize, 397510);

// check getContiguousParameters
const originalParams = await trainingSession.getContiguousParameters();
assertStrictEquals(originalParams.dims.length, 1);
assertStrictEquals(originalParams.dims[0], 397510);
assertStrictEquals(originalParams.data[0], -0.025190064683556557);
assertStrictEquals(originalParams.data[2000], -0.034044936299324036);

const results = await runTrainStepAndCheck(trainingSession, feeds);

await trainingSession.runOptimizerStep(feeds);
feeds = {"input-0": input0, "labels": labels};
// check getContiguousParameters after optimizerStep -- that the parameters have been updated
const optimizedParams = await trainingSession.getContiguousParameters();
assertTwoListsUnequal(originalParams.data, optimizedParams.data);

const results2 = await runTrainStepAndCheck(trainingSession, feeds);

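// check that loss decreased after optimizer step and training again
// (compare the scalar values; comparing the typed arrays directly would coerce them to strings)
assert(results2['onnx::loss::21273'].data[0] < results['onnx::loss::21273'].data[0]);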

await loadParametersBufferAndCheck(trainingSession, 397510, -1.2, optimizedParams);
}

if (typeof module === 'object') {
module.exports = [testInferenceFunction, testTrainingFunctionMin, testTrainingFunctionAll];
}
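
`loadParametersBufferAndCheck` in `common.js` hands the new parameters to `loadParametersBuffer` as a `Uint8Array` view over the memory of a `Float32Array`. Here is a minimal sketch of that reinterpretation, independent of ORT; the helper name `asBytes` is hypothetical.

```javascript
'use strict';

// Returns a byte view over the same memory as `floats` (no copy is made).
function asBytes(floats) {
  return new Uint8Array(floats.buffer, floats.byteOffset, floats.byteLength);
}

const params = new Float32Array(3).fill(-1.2);
const bytes = asBytes(params);

console.log(bytes.length);  // 12: four bytes per float32 element

// Reading the same bytes back as float32 recovers the stored values.
const roundTrip = new Float32Array(bytes.buffer, bytes.byteOffset, params.length);
console.log(roundTrip[0].toFixed(4));  // -1.2000
```

Since float32 cannot represent -1.2 exactly, the stored value differs slightly from the JavaScript double `-1.2`. This is why the test compares values with `toFixed(4)` instead of strict equality.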
16 changes: 16 additions & 0 deletions js/web/test/training/e2e/data/model.onnx
@@ -0,0 +1,16 @@
[Binary ONNX model file, not rendered as text. Recoverable strings: producer `backend-test`, graph `test_matmul_2d`, one `MatMul` node with inputs `a`, `b` and output `c`.]
54 changes: 54 additions & 0 deletions js/web/test/training/e2e/karma.conf.js
@@ -0,0 +1,54 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

'use strict';

const args = require('minimist')(process.argv.slice(2));
const SELF_HOST = !!args['self-host'];
const ORT_MAIN = args['ort-main'];
const TEST_MAIN = args['test-main'];
if (typeof TEST_MAIN !== 'string') {
throw new Error('flag --test-main=<TEST_MAIN_JS_FILE> is required');
}
const USER_DATA = args['user-data'];
if (typeof USER_DATA !== 'string') {
throw new Error('flag --user-data=<CHROME_USER_DATA_FOLDER> is required');
}

module.exports = function(config) {
const distPrefix = SELF_HOST ? './node_modules/onnxruntime-web/dist/' : 'http://localhost:8081/dist/';
config.set({
frameworks: ['mocha'],
files: [
{pattern: distPrefix + ORT_MAIN},
{pattern: './common.js'},
{pattern: TEST_MAIN},
{pattern: './node_modules/onnxruntime-web/dist/*.wasm', included: false, nocache: true},
{pattern: './data/*', included: false},
],
plugins: [require('@chiragrupani/karma-chromium-edge-launcher'), ...config.plugins],
proxies: {
'/model.onnx': '/base/model.onnx',
'/data/': '/base/data/',
},
client: {captureConsole: true, mocha: {expose: ['body'], timeout: 60000}},
reporters: ['mocha'],
captureTimeout: 120000,
reportSlowerThan: 100,
browserDisconnectTimeout: 600000,
browserNoActivityTimeout: 300000,
browserDisconnectTolerance: 0,
browserSocketTimeout: 60000,
hostname: 'localhost',
browsers: [],
customLaunchers: {
Chrome_default: {base: 'ChromeHeadless', chromeDataDir: USER_DATA},
Chrome_no_threads: {
base: 'ChromeHeadless',
chromeDataDir: USER_DATA,
// TODO: no-thread flags
},
Edge_default: {base: 'Edge', edgeDataDir: USER_DATA}
}
});
};
14 changes: 14 additions & 0 deletions js/web/test/training/e2e/package.json
@@ -0,0 +1,14 @@
{
"devDependencies": {
"@chiragrupani/karma-chromium-edge-launcher": "^2.2.2",
"fs-extra": "^11.1.0",
"globby": "^13.1.3",
"karma": "^6.4.1",
"karma-chrome-launcher": "^3.1.1",
"karma-mocha": "^2.0.1",
"karma-mocha-reporter": "^2.2.5",
"light-server": "^2.9.1",
"minimist": "^1.2.7",
"mocha": "^10.2.0"
}
}