-
Notifications
You must be signed in to change notification settings - Fork 3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[js/web/training] added end-to-end tests (#18700)
## Summary * following inference's [set-up for end-to-end tests](https://github.com/microsoft/onnxruntime/tree/main/js/web/test/e2e), created an end-to-end test runner for training * this test runner copies testdata from the [trainingapi folder](https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/test/testdata/training_api) * then runs two tests (training session with evalModel & optimizer model, and training session with the minimum options), and tests if the ORT-web training package encompasses inference * these tests check * createTrainingSession * runTrainStep * runOptimizerStep if applicable * the parameters methods (getParametersSize, loadParametersBuffer, and getContiguousParameters) ## TL;DR * [`js/web/test/training/e2e/run.js`](https://github.com/microsoft/onnxruntime/compare/main...carzh:onnxruntime:carzh/training-e2e-runner?expand=1#diff-c1359c4d401f9ba69e937814219cefe5fd11b151a6ffd084c641af3c82e8216c) is responsible for setting up and running the end to end tests * [`js/web/test/training/e2e/common.js`](https://github.com/microsoft/onnxruntime/compare/main...carzh:onnxruntime:carzh/training-e2e-runner?expand=1#diff-ee5452491b7b2563d175d13d81d10f2323b12b18589aa4c5798962a8b904a4a8) contains the test function definitions (`testInferenceFunction`, `testTrainingFunctionMin`, `testTrainingFunctionAll`) ## Flow * entrypoint: user runs the following command in the terminal: `npm run test:training:e2e` * [`js/web/package.json`](https://github.com/microsoft/onnxruntime/compare/main...carzh:onnxruntime:carzh/training-e2e-runner?expand=1#diff-79275844e75c3c410bb3a71c7f59b2b633e5a3e975c804ffc47220025084da28) was modified to include an npm script that will run `run.js` which will run the end to end tests * [`js/web/test/training/e2e/run.js`](https://github.com/microsoft/onnxruntime/compare/main...carzh:onnxruntime:carzh/training-e2e-runner?expand=1#diff-c1359c4d401f9ba69e937814219cefe5fd11b151a6ffd084c641af3c82e8216c) is responsible for * detecting and installing local tarball packages of ORT-web * copying training data to the `js/web/training/e2e/data` folder * starting two Karma processes. Karma is a test runner framework that simulates testing in the browser. * In this case, the tests happen in Chrome. We can configure the tests to run in Edge and other browsers in the future. * one of these karma processes is self-hosted, meaning it pulls the ORT-web package from local * the other karma process is not self-hosted, meaning it pulls the ORT-web package from another source. In this case, we start an http server that serves the ORT-web binaries. * [`js/web/test/training/e2e/simple-http-server.js`](https://github.com/microsoft/onnxruntime/compare/main...carzh:onnxruntime:carzh/training-e2e-runner?expand=1#diff-f798ab485f3ec26c299fe5b2923574c9e4b090200ba20d490bbf6c183286993c) is responsible for starting the HTTP server and serving the ORT binary files. This code almost identical to the same code in the inference E2E tests. * [`js/web/test/training/e2e/karma.conf.js`](https://github.com/microsoft/onnxruntime/compare/main...carzh:onnxruntime:carzh/training-e2e-runner?expand=1#diff-436cfe8f670c768a04895bd4a1874a5e033f85e0e2d84941c62ff1f7c30a9f28) Karma configuration file that specifies what happens when a karma process is started. The config specifies Mocha as the testing framework, which will go through all the loaded files and run any tests that exist * [`js/web/test/training/e2e/browser-test-wasm.js`](https://github.com/microsoft/onnxruntime/compare/main...carzh:onnxruntime:carzh/training-e2e-runner?expand=1#diff-13b6155e106dddc7b531ef671186e69b2aadb8a0f4b2f3001db0991567d78221) File that contains the tests that Mocha will pick up on and run. * The test functions (such as testInference and testTrainingFunctionAll) are defined in [`js/web/test/training/e2e/common.js`](https://github.com/microsoft/onnxruntime/compare/main...carzh:onnxruntime:carzh/training-e2e-runner?expand=1#diff-ee5452491b7b2563d175d13d81d10f2323b12b18589aa4c5798962a8b904a4a8). ## Notes * I followed the [tests for training core](https://github.com/microsoft/onnxruntime/blob/b023de0bfc7acb2404dfdcc4adc060b7b72fdaa1/orttraining/orttraining/test/training_api/core/training_api_tests.cc) where they randomly generated input for the training session * E2E tests are triggered by running `npm run test:training:e2e` -- suggestions for alternative script names are appreciated!!! ## Motivation and Context - adding training bindings for web
- Loading branch information
Showing
9 changed files
with
559 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
'use strict'; | ||
|
||
describe('Browser E2E testing for training package', function() { | ||
it('Check that training package encompasses inference', async function() { | ||
ort.env.wasm.numThreads = 1; | ||
await testInferenceFunction(ort, {executionProviders: ['wasm']}); | ||
}); | ||
|
||
it('Check training functionality, all options', async function() { | ||
ort.env.wasm.numThreads = 1; | ||
await testTrainingFunctionAll(ort, {executionProviders: ['wasm']}); | ||
}); | ||
|
||
it('Check training functionality, minimum options', async function() { | ||
ort.env.wasm.numThreads = 1; | ||
await testTrainingFunctionMin(ort, {executionProviders: ['wasm']}); | ||
}); | ||
}); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,246 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
'use strict'; | ||
|
||
const DATA_FOLDER = 'data/'; | ||
const TRAININGDATA_TRAIN_MODEL = DATA_FOLDER + 'training_model.onnx'; | ||
const TRAININGDATA_OPTIMIZER_MODEL = DATA_FOLDER + 'adamw.onnx'; | ||
const TRAININGDATA_EVAL_MODEL = DATA_FOLDER + 'eval_model.onnx'; | ||
const TRAININGDATA_CKPT = DATA_FOLDER + 'checkpoint.ckpt'; | ||
|
||
const trainingSessionAllOptions = { | ||
checkpointState: TRAININGDATA_CKPT, | ||
trainModel: TRAININGDATA_TRAIN_MODEL, | ||
evalModel: TRAININGDATA_EVAL_MODEL, | ||
optimizerModel: TRAININGDATA_OPTIMIZER_MODEL | ||
} | ||
|
||
const trainingSessionMinOptions = { | ||
checkpointState: TRAININGDATA_CKPT, | ||
trainModel: TRAININGDATA_TRAIN_MODEL, | ||
} | ||
|
||
// ASSERT METHODS | ||
|
||
function assert(cond) { | ||
if (!cond) throw new Error(); | ||
} | ||
|
||
function assertStrictEquals(actual, expected) { | ||
if (actual !== expected) { | ||
let strRep = actual; | ||
if (typeof actual === 'object') { | ||
strRep = JSON.stringify(actual); | ||
} | ||
throw new Error(`expected: ${expected}; got: ${strRep}`); | ||
} | ||
} | ||
|
||
function assertTwoListsUnequal(list1, list2) { | ||
if (list1.length !== list2.length) { | ||
return; | ||
} | ||
for (let i = 0; i < list1.length; i++) { | ||
if (list1[i] !== list2[i]) { | ||
return; | ||
} | ||
} | ||
throw new Error(`expected ${list1} and ${list2} to be unequal; got two equal lists`); | ||
} | ||
|
||
// HELPER METHODS FOR TESTS | ||
|
||
function generateGaussianRandom(mean=0, scale=1) { | ||
const u = 1 - Math.random(); | ||
const v = Math.random(); | ||
const z = Math.sqrt(-2.0 * Math.log(u)) * Math.cos(2.0 * Math.PI * v); | ||
return z * scale + mean; | ||
} | ||
|
||
function generateGaussianFloatArray(length) { | ||
const array = new Float32Array(length); | ||
|
||
for (let i = 0; i < length; i++) { | ||
array[i] = generateGaussianRandom(); | ||
} | ||
|
||
return array; | ||
} | ||
|
||
/** | ||
* creates the TrainingSession and verifies that the input and output names of the training model loaded into the | ||
* training session are correct. | ||
* @param {} ort | ||
* @param {*} createOptions | ||
* @param {*} options | ||
* @returns | ||
*/ | ||
async function createTrainingSessionAndCheckTrainingModel(ort, createOptions, options) { | ||
const trainingSession = await ort.TrainingSession.create(createOptions, options); | ||
|
||
assertStrictEquals(trainingSession.trainingInputNames[0], 'input-0'); | ||
assertStrictEquals(trainingSession.trainingInputNames[1], 'labels'); | ||
assertStrictEquals(trainingSession.trainingInputNames.length, 2); | ||
assertStrictEquals(trainingSession.trainingOutputNames[0], 'onnx::loss::21273'); | ||
assertStrictEquals(trainingSession.trainingOutputNames.length, 1); | ||
return trainingSession; | ||
} | ||
|
||
/** | ||
* verifies that the eval input and output names associated with the eval model loaded into the given training session | ||
* are correct. | ||
*/ | ||
function checkEvalModel(trainingSession) { | ||
assertStrictEquals(trainingSession.evalInputNames[0], 'input-0'); | ||
assertStrictEquals(trainingSession.evalInputNames[1], 'labels'); | ||
assertStrictEquals(trainingSession.evalInputNames.length, 2); | ||
assertStrictEquals(trainingSession.evalOutputNames[0], 'onnx::loss::21273'); | ||
assertStrictEquals(trainingSession.evalOutputNames.length, 1); | ||
} | ||
|
||
/** | ||
* Checks that accessing trainingSession.evalInputNames or trainingSession.evalOutputNames will throw an error if | ||
* accessed | ||
* @param {} trainingSession | ||
*/ | ||
function checkNoEvalModel(trainingSession) { | ||
try { | ||
assertStrictEquals(trainingSession.evalInputNames, "should have thrown an error upon accessing"); | ||
} catch (error) { | ||
assertStrictEquals(error.message, 'This training session has no evalModel loaded.'); | ||
} | ||
try { | ||
assertStrictEquals(trainingSession.evalOutputNames, "should have thrown an error upon accessing"); | ||
} catch (error) { | ||
assertStrictEquals(error.message, 'This training session has no evalModel loaded.'); | ||
} | ||
} | ||
|
||
/** | ||
* runs the train step with the given inputs and checks that the tensor returned is of type float32 and has a length | ||
* of 1 for the loss. | ||
* @param {} trainingSession | ||
* @param {*} feeds | ||
* @returns | ||
*/ | ||
var runTrainStepAndCheck = async function(trainingSession, feeds) { | ||
const results = await trainingSession.runTrainStep(feeds); | ||
assertStrictEquals(Object.keys(results).length, 1); | ||
assertStrictEquals(results['onnx::loss::21273'].data.length, 1); | ||
assertStrictEquals(results['onnx::loss::21273'].type, 'float32'); | ||
return results; | ||
}; | ||
|
||
var loadParametersBufferAndCheck = async function(trainingSession, paramsLength, constant, paramsBefore) { | ||
// make a float32 array that is filled with the constant | ||
const newParams = new Float32Array(paramsLength); | ||
for (let i = 0; i < paramsLength; i++) { | ||
newParams[i] = constant; | ||
} | ||
|
||
const newParamsUint8 = new Uint8Array(newParams.buffer, newParams.byteOffset, newParams.byteLength); | ||
|
||
await trainingSession.loadParametersBuffer(newParamsUint8); | ||
const paramsAfterLoad = await trainingSession.getContiguousParameters(); | ||
|
||
// check that the parameters have changed | ||
assertTwoListsUnequal(paramsAfterLoad.data, paramsBefore.data); | ||
assertStrictEquals(paramsAfterLoad.dims[0], paramsLength); | ||
|
||
// check that the parameters have changed to what they should be | ||
for (let i = 0; i < paramsLength; i++) { | ||
// round to the same number of digits (4 decimal places) | ||
assertStrictEquals(paramsAfterLoad.data[i].toFixed(4), constant.toFixed(4)); | ||
} | ||
|
||
return paramsAfterLoad; | ||
} | ||
|
||
// TESTS | ||
|
||
var testInferenceFunction = async function(ort, options) { | ||
const session = await ort.InferenceSession.create('data/model.onnx', options || {}); | ||
|
||
const dataA = Float32Array.from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]); | ||
const dataB = Float32Array.from([10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120]); | ||
|
||
const fetches = | ||
await session.run({a: new ort.Tensor('float32', dataA, [3, 4]), b: new ort.Tensor('float32', dataB, [4, 3])}); | ||
|
||
const c = fetches.c; | ||
|
||
assert(c instanceof ort.Tensor); | ||
assert(c.dims.length === 2 && c.dims[0] === 3 && c.dims[1] === 3); | ||
assert(c.data[0] === 700); | ||
assert(c.data[1] === 800); | ||
assert(c.data[2] === 900); | ||
assert(c.data[3] === 1580); | ||
assert(c.data[4] === 1840); | ||
assert(c.data[5] === 2100); | ||
assert(c.data[6] === 2460); | ||
assert(c.data[7] === 2880); | ||
assert(c.data[8] === 3300); | ||
}; | ||
|
||
var testTrainingFunctionMin = async function(ort, options) { | ||
const trainingSession = await createTrainingSessionAndCheckTrainingModel(ort, trainingSessionMinOptions, options); | ||
checkNoEvalModel(trainingSession); | ||
const input0 = new ort.Tensor('float32', generateGaussianFloatArray(2 * 784), [2, 784]); | ||
const labels = new ort.Tensor('int32', [2, 1], [2]); | ||
const feeds = {"input-0": input0, "labels": labels}; | ||
|
||
// check getParametersSize | ||
const paramsSize = await trainingSession.getParametersSize(); | ||
assertStrictEquals(paramsSize, 397510); | ||
|
||
// check getContiguousParameters | ||
const originalParams = await trainingSession.getContiguousParameters(); | ||
assertStrictEquals(originalParams.dims.length, 1); | ||
assertStrictEquals(originalParams.dims[0], 397510); | ||
assertStrictEquals(originalParams.data[0], -0.025190064683556557); | ||
assertStrictEquals(originalParams.data[2000], -0.034044936299324036); | ||
|
||
await runTrainStepAndCheck(trainingSession, feeds); | ||
|
||
await loadParametersBufferAndCheck(trainingSession, 397510, -1.2, originalParams); | ||
} | ||
|
||
var testTrainingFunctionAll = async function(ort, options) { | ||
const trainingSession = await createTrainingSessionAndCheckTrainingModel(ort, trainingSessionAllOptions, options); | ||
checkEvalModel(trainingSession); | ||
|
||
const input0 = new ort.Tensor('float32', generateGaussianFloatArray(2 * 784), [2, 784]); | ||
const labels = new ort.Tensor('int32', [2, 1], [2]); | ||
let feeds = {"input-0": input0, "labels": labels}; | ||
|
||
// check getParametersSize | ||
const paramsSize = await trainingSession.getParametersSize(); | ||
assertStrictEquals(paramsSize, 397510); | ||
|
||
// check getContiguousParameters | ||
const originalParams = await trainingSession.getContiguousParameters(); | ||
assertStrictEquals(originalParams.dims.length, 1); | ||
assertStrictEquals(originalParams.dims[0], 397510); | ||
assertStrictEquals(originalParams.data[0], -0.025190064683556557); | ||
assertStrictEquals(originalParams.data[2000], -0.034044936299324036); | ||
|
||
const results = await runTrainStepAndCheck(trainingSession, feeds); | ||
|
||
await trainingSession.runOptimizerStep(feeds); | ||
feeds = {"input-0": input0, "labels": labels}; | ||
// check getContiguousParameters after optimizerStep -- that the parameters have been updated | ||
const optimizedParams = await trainingSession.getContiguousParameters(); | ||
assertTwoListsUnequal(originalParams.data, optimizedParams.data); | ||
|
||
const results2 = await runTrainStepAndCheck(trainingSession, feeds); | ||
|
||
// check that loss decreased after optimizer step and training again | ||
assert(results2['onnx::loss::21273'].data < results['onnx::loss::21273'].data); | ||
|
||
await loadParametersBufferAndCheck(trainingSession, 397510, -1.2, optimizedParams); | ||
} | ||
|
||
if (typeof module === 'object') { | ||
module.exports = [testInferenceFunction, testTrainingFunctionMin, testTrainingFunctionAll, testTest]; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
backend-test:b | ||
a | ||
bc"MatMultest_matmul_2dZ | ||
a | ||
Z | ||
b | ||
b | ||
c | ||
B |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
'use strict'; | ||
|
||
const args = require('minimist')(process.argv.slice(2)); | ||
const SELF_HOST = !!args['self-host']; | ||
const ORT_MAIN = args['ort-main']; | ||
const TEST_MAIN = args['test-main']; | ||
if (typeof TEST_MAIN !== 'string') { | ||
throw new Error('flag --test-main=<TEST_MAIN_JS_FILE> is required'); | ||
} | ||
const USER_DATA = args['user-data']; | ||
if (typeof USER_DATA !== 'string') { | ||
throw new Error('flag --user-data=<CHROME_USER_DATA_FOLDER> is required'); | ||
} | ||
|
||
module.exports = function(config) { | ||
const distPrefix = SELF_HOST ? './node_modules/onnxruntime-web/dist/' : 'http://localhost:8081/dist/'; | ||
config.set({ | ||
frameworks: ['mocha'], | ||
files: [ | ||
{pattern: distPrefix + ORT_MAIN}, | ||
{pattern: './common.js'}, | ||
{pattern: TEST_MAIN}, | ||
{pattern: './node_modules/onnxruntime-web/dist/*.wasm', included: false, nocache: true}, | ||
{pattern: './data/*', included: false}, | ||
], | ||
plugins: [require('@chiragrupani/karma-chromium-edge-launcher'), ...config.plugins], | ||
proxies: { | ||
'/model.onnx': '/base/model.onnx', | ||
'/data/': '/base/data/', | ||
}, | ||
client: {captureConsole: true, mocha: {expose: ['body'], timeout: 60000}}, | ||
reporters: ['mocha'], | ||
captureTimeout: 120000, | ||
reportSlowerThan: 100, | ||
browserDisconnectTimeout: 600000, | ||
browserNoActivityTimeout: 300000, | ||
browserDisconnectTolerance: 0, | ||
browserSocketTimeout: 60000, | ||
hostname: 'localhost', | ||
browsers: [], | ||
customLaunchers: { | ||
Chrome_default: {base: 'ChromeHeadless', chromeDataDir: USER_DATA}, | ||
Chrome_no_threads: { | ||
base: 'ChromeHeadless', | ||
chromeDataDir: USER_DATA, | ||
// TODO: no-thread flags | ||
}, | ||
Edge_default: {base: 'Edge', edgeDataDir: USER_DATA} | ||
} | ||
}); | ||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
{ | ||
"devDependencies": { | ||
"@chiragrupani/karma-chromium-edge-launcher": "^2.2.2", | ||
"fs-extra": "^11.1.0", | ||
"globby": "^13.1.3", | ||
"karma": "^6.4.1", | ||
"karma-chrome-launcher": "^3.1.1", | ||
"karma-mocha": "^2.0.1", | ||
"karma-mocha-reporter": "^2.2.5", | ||
"light-server": "^2.9.1", | ||
"minimist": "^1.2.7", | ||
"mocha": "^10.2.0" | ||
} | ||
} |
Oops, something went wrong.