From ce08215be39167217e92b5ed88a610cb3b40d7ff Mon Sep 17 00:00:00 2001 From: Caroline Zhu Date: Fri, 12 Jan 2024 13:33:33 -0800 Subject: [PATCH] [js/web/training] added end-to-end tests (#18700) ## Summary * following inference's [set-up for end-to-end tests](https://github.com/microsoft/onnxruntime/tree/main/js/web/test/e2e), created an end-to-end test runner for training * this test runner copies testdata from the [trainingapi folder](https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/test/testdata/training_api) * then runs two tests (training session with evalModel & optimizer model, and training session with the minimum options), and tests if the ORT-web training package encompasses inference * these tests check * createTrainingSession * runTrainStep * runOptimizerStep if applicable * the parameters methods (getParametersSize, loadParametersBuffer, and getContiguousParameters) ## TL;DR * [`js/web/test/training/e2e/run.js`](https://github.com/microsoft/onnxruntime/compare/main...carzh:onnxruntime:carzh/training-e2e-runner?expand=1#diff-c1359c4d401f9ba69e937814219cefe5fd11b151a6ffd084c641af3c82e8216c) is responsible for setting up and running the end to end tests * [`js/web/test/training/e2e/common.js`](https://github.com/microsoft/onnxruntime/compare/main...carzh:onnxruntime:carzh/training-e2e-runner?expand=1#diff-ee5452491b7b2563d175d13d81d10f2323b12b18589aa4c5798962a8b904a4a8) contains the test function definitions (`testInferenceFunction`, `testTrainingFunctionMin`, `testTrainingFunctionAll`) ## Flow * entrypoint: user runs the following command in the terminal: `npm run test:training:e2e` * [`js/web/package.json`](https://github.com/microsoft/onnxruntime/compare/main...carzh:onnxruntime:carzh/training-e2e-runner?expand=1#diff-79275844e75c3c410bb3a71c7f59b2b633e5a3e975c804ffc47220025084da28) was modified to include an npm script that will run `run.js` which will run the end to end tests * [`js/web/test/training/e2e/run.js`](https://github.com/microsoft/onnxruntime/compare/main...carzh:onnxruntime:carzh/training-e2e-runner?expand=1#diff-c1359c4d401f9ba69e937814219cefe5fd11b151a6ffd084c641af3c82e8216c) is responsible for * detecting and installing local tarball packages of ORT-web * copying training data to the `js/web/training/e2e/data` folder * starting two Karma processes. Karma is a test runner framework that simulates testing in the browser. * In this case, the tests happen in Chrome. We can configure the tests to run in Edge and other browsers in the future. * one of these karma processes is self-hosted, meaning it pulls the ORT-web package from local * the other karma process is not self-hosted, meaning it pulls the ORT-web package from another source. In this case, we start an http server that serves the ORT-web binaries. * [`js/web/test/training/e2e/simple-http-server.js`](https://github.com/microsoft/onnxruntime/compare/main...carzh:onnxruntime:carzh/training-e2e-runner?expand=1#diff-f798ab485f3ec26c299fe5b2923574c9e4b090200ba20d490bbf6c183286993c) is responsible for starting the HTTP server and serving the ORT binary files. This code almost identical to the same code in the inference E2E tests. * [`js/web/test/training/e2e/karma.conf.js`](https://github.com/microsoft/onnxruntime/compare/main...carzh:onnxruntime:carzh/training-e2e-runner?expand=1#diff-436cfe8f670c768a04895bd4a1874a5e033f85e0e2d84941c62ff1f7c30a9f28) Karma configuration file that specifies what happens when a karma process is started. The config specifies Mocha as the testing framework, which will go through all the loaded files and run any tests that exist * [`js/web/test/training/e2e/browser-test-wasm.js`](https://github.com/microsoft/onnxruntime/compare/main...carzh:onnxruntime:carzh/training-e2e-runner?expand=1#diff-13b6155e106dddc7b531ef671186e69b2aadb8a0f4b2f3001db0991567d78221) File that contains the tests that Mocha will pick up on and run. * The test functions (such as testInference and testTrainingFunctionAll) are defined in [`js/web/test/training/e2e/common.js`](https://github.com/microsoft/onnxruntime/compare/main...carzh:onnxruntime:carzh/training-e2e-runner?expand=1#diff-ee5452491b7b2563d175d13d81d10f2323b12b18589aa4c5798962a8b904a4a8). ## Notes * I followed the [tests for training core](https://github.com/microsoft/onnxruntime/blob/b023de0bfc7acb2404dfdcc4adc060b7b72fdaa1/orttraining/orttraining/test/training_api/core/training_api_tests.cc) where they randomly generated input for the training session * E2E tests are triggered by running `npm run test:training:e2e` -- suggestions for alternative script names are appreciated!!! ## Motivation and Context - adding training bindings for web --- web/package.json | 1 + web/test/training/e2e/browser-test-wasm.js | 21 ++ web/test/training/e2e/common.js | 246 ++++++++++++++++++++ web/test/training/e2e/data/model.onnx | 16 ++ web/test/training/e2e/karma.conf.js | 54 +++++ web/test/training/e2e/package.json | 14 ++ web/test/training/e2e/run.js | 138 +++++++++++ web/test/training/e2e/simple-http-server.js | 64 +++++ 8 files changed, 554 insertions(+) create mode 100644 web/test/training/e2e/browser-test-wasm.js create mode 100644 web/test/training/e2e/common.js create mode 100644 web/test/training/e2e/data/model.onnx create mode 100644 web/test/training/e2e/karma.conf.js create mode 100644 web/test/training/e2e/package.json create mode 100644 web/test/training/e2e/run.js create mode 100644 web/test/training/e2e/simple-http-server.js diff --git a/web/package.json b/web/package.json index 9b4531d7766fe..7ffc9ba16aaa9 100644 --- a/web/package.json +++ b/web/package.json @@ -24,6 +24,7 @@ "build:doc": "node ./script/generate-webgl-operator-md && node ./script/generate-webgpu-operator-md", "pull:wasm": "node ./script/pull-prebuilt-wasm-artifacts", "test:e2e": "node ./test/e2e/run", + "test:training:e2e": "node ./test/training/e2e/run", "prebuild": "tsc -p . --noEmit && tsc -p lib/wasm/proxy-worker --noEmit", "build": "node ./script/build", "test": "tsc --build ../scripts && node ../scripts/prepare-onnx-node-tests && node ./script/test-runner-cli", diff --git a/web/test/training/e2e/browser-test-wasm.js b/web/test/training/e2e/browser-test-wasm.js new file mode 100644 index 0000000000000..fa87389f7ac46 --- /dev/null +++ b/web/test/training/e2e/browser-test-wasm.js @@ -0,0 +1,21 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +'use strict'; + +describe('Browser E2E testing for training package', function() { + it('Check that training package encompasses inference', async function() { + ort.env.wasm.numThreads = 1; + await testInferenceFunction(ort, {executionProviders: ['wasm']}); + }); + + it('Check training functionality, all options', async function() { + ort.env.wasm.numThreads = 1; + await testTrainingFunctionAll(ort, {executionProviders: ['wasm']}); + }); + + it('Check training functionality, minimum options', async function() { + ort.env.wasm.numThreads = 1; + await testTrainingFunctionMin(ort, {executionProviders: ['wasm']}); + }); +}); diff --git a/web/test/training/e2e/common.js b/web/test/training/e2e/common.js new file mode 100644 index 0000000000000..b6040b63d56b4 --- /dev/null +++ b/web/test/training/e2e/common.js @@ -0,0 +1,246 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +'use strict'; + +const DATA_FOLDER = 'data/'; +const TRAININGDATA_TRAIN_MODEL = DATA_FOLDER + 'training_model.onnx'; +const TRAININGDATA_OPTIMIZER_MODEL = DATA_FOLDER + 'adamw.onnx'; +const TRAININGDATA_EVAL_MODEL = DATA_FOLDER + 'eval_model.onnx'; +const TRAININGDATA_CKPT = DATA_FOLDER + 'checkpoint.ckpt'; + +const trainingSessionAllOptions = { + checkpointState: TRAININGDATA_CKPT, + trainModel: TRAININGDATA_TRAIN_MODEL, + evalModel: TRAININGDATA_EVAL_MODEL, + optimizerModel: TRAININGDATA_OPTIMIZER_MODEL +} + +const trainingSessionMinOptions = { + checkpointState: TRAININGDATA_CKPT, + trainModel: TRAININGDATA_TRAIN_MODEL, +} + +// ASSERT METHODS + +function assert(cond) { + if (!cond) throw new Error(); +} + +function assertStrictEquals(actual, expected) { + if (actual !== expected) { + let strRep = actual; + if (typeof actual === 'object') { + strRep = JSON.stringify(actual); + } + throw new Error(`expected: ${expected}; got: ${strRep}`); + } +} + +function assertTwoListsUnequal(list1, list2) { + if (list1.length !== list2.length) { + return; + } + for (let i = 0; i < list1.length; i++) { + if (list1[i] !== list2[i]) { + return; + } + } + throw new Error(`expected ${list1} and ${list2} to be unequal; got two equal lists`); +} + +// HELPER METHODS FOR TESTS + +function generateGaussianRandom(mean=0, scale=1) { + const u = 1 - Math.random(); + const v = Math.random(); + const z = Math.sqrt(-2.0 * Math.log(u)) * Math.cos(2.0 * Math.PI * v); + return z * scale + mean; +} + +function generateGaussianFloatArray(length) { + const array = new Float32Array(length); + + for (let i = 0; i < length; i++) { + array[i] = generateGaussianRandom(); + } + + return array; +} + +/** + * creates the TrainingSession and verifies that the input and output names of the training model loaded into the + * training session are correct. + * @param {} ort + * @param {*} createOptions + * @param {*} options + * @returns + */ +async function createTrainingSessionAndCheckTrainingModel(ort, createOptions, options) { + const trainingSession = await ort.TrainingSession.create(createOptions, options); + + assertStrictEquals(trainingSession.trainingInputNames[0], 'input-0'); + assertStrictEquals(trainingSession.trainingInputNames[1], 'labels'); + assertStrictEquals(trainingSession.trainingInputNames.length, 2); + assertStrictEquals(trainingSession.trainingOutputNames[0], 'onnx::loss::21273'); + assertStrictEquals(trainingSession.trainingOutputNames.length, 1); + return trainingSession; +} + +/** + * verifies that the eval input and output names associated with the eval model loaded into the given training session + * are correct. + */ +function checkEvalModel(trainingSession) { + assertStrictEquals(trainingSession.evalInputNames[0], 'input-0'); + assertStrictEquals(trainingSession.evalInputNames[1], 'labels'); + assertStrictEquals(trainingSession.evalInputNames.length, 2); + assertStrictEquals(trainingSession.evalOutputNames[0], 'onnx::loss::21273'); + assertStrictEquals(trainingSession.evalOutputNames.length, 1); +} + +/** + * Checks that accessing trainingSession.evalInputNames or trainingSession.evalOutputNames will throw an error if + * accessed + * @param {} trainingSession + */ +function checkNoEvalModel(trainingSession) { + try { + assertStrictEquals(trainingSession.evalInputNames, "should have thrown an error upon accessing"); + } catch (error) { + assertStrictEquals(error.message, 'This training session has no evalModel loaded.'); + } + try { + assertStrictEquals(trainingSession.evalOutputNames, "should have thrown an error upon accessing"); + } catch (error) { + assertStrictEquals(error.message, 'This training session has no evalModel loaded.'); + } +} + +/** + * runs the train step with the given inputs and checks that the tensor returned is of type float32 and has a length + * of 1 for the loss. + * @param {} trainingSession + * @param {*} feeds + * @returns + */ +var runTrainStepAndCheck = async function(trainingSession, feeds) { + const results = await trainingSession.runTrainStep(feeds); + assertStrictEquals(Object.keys(results).length, 1); + assertStrictEquals(results['onnx::loss::21273'].data.length, 1); + assertStrictEquals(results['onnx::loss::21273'].type, 'float32'); + return results; +}; + +var loadParametersBufferAndCheck = async function(trainingSession, paramsLength, constant, paramsBefore) { + // make a float32 array that is filled with the constant + const newParams = new Float32Array(paramsLength); + for (let i = 0; i < paramsLength; i++) { + newParams[i] = constant; + } + + const newParamsUint8 = new Uint8Array(newParams.buffer, newParams.byteOffset, newParams.byteLength); + + await trainingSession.loadParametersBuffer(newParamsUint8); + const paramsAfterLoad = await trainingSession.getContiguousParameters(); + + // check that the parameters have changed + assertTwoListsUnequal(paramsAfterLoad.data, paramsBefore.data); + assertStrictEquals(paramsAfterLoad.dims[0], paramsLength); + + // check that the parameters have changed to what they should be + for (let i = 0; i < paramsLength; i++) { + // round to the same number of digits (4 decimal places) + assertStrictEquals(paramsAfterLoad.data[i].toFixed(4), constant.toFixed(4)); + } + + return paramsAfterLoad; +} + +// TESTS + +var testInferenceFunction = async function(ort, options) { + const session = await ort.InferenceSession.create('data/model.onnx', options || {}); + + const dataA = Float32Array.from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]); + const dataB = Float32Array.from([10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120]); + + const fetches = + await session.run({a: new ort.Tensor('float32', dataA, [3, 4]), b: new ort.Tensor('float32', dataB, [4, 3])}); + + const c = fetches.c; + + assert(c instanceof ort.Tensor); + assert(c.dims.length === 2 && c.dims[0] === 3 && c.dims[1] === 3); + assert(c.data[0] === 700); + assert(c.data[1] === 800); + assert(c.data[2] === 900); + assert(c.data[3] === 1580); + assert(c.data[4] === 1840); + assert(c.data[5] === 2100); + assert(c.data[6] === 2460); + assert(c.data[7] === 2880); + assert(c.data[8] === 3300); +}; + +var testTrainingFunctionMin = async function(ort, options) { + const trainingSession = await createTrainingSessionAndCheckTrainingModel(ort, trainingSessionMinOptions, options); + checkNoEvalModel(trainingSession); + const input0 = new ort.Tensor('float32', generateGaussianFloatArray(2 * 784), [2, 784]); + const labels = new ort.Tensor('int32', [2, 1], [2]); + const feeds = {"input-0": input0, "labels": labels}; + + // check getParametersSize + const paramsSize = await trainingSession.getParametersSize(); + assertStrictEquals(paramsSize, 397510); + + // check getContiguousParameters + const originalParams = await trainingSession.getContiguousParameters(); + assertStrictEquals(originalParams.dims.length, 1); + assertStrictEquals(originalParams.dims[0], 397510); + assertStrictEquals(originalParams.data[0], -0.025190064683556557); + assertStrictEquals(originalParams.data[2000], -0.034044936299324036); + + await runTrainStepAndCheck(trainingSession, feeds); + + await loadParametersBufferAndCheck(trainingSession, 397510, -1.2, originalParams); +} + +var testTrainingFunctionAll = async function(ort, options) { + const trainingSession = await createTrainingSessionAndCheckTrainingModel(ort, trainingSessionAllOptions, options); + checkEvalModel(trainingSession); + + const input0 = new ort.Tensor('float32', generateGaussianFloatArray(2 * 784), [2, 784]); + const labels = new ort.Tensor('int32', [2, 1], [2]); + let feeds = {"input-0": input0, "labels": labels}; + + // check getParametersSize + const paramsSize = await trainingSession.getParametersSize(); + assertStrictEquals(paramsSize, 397510); + + // check getContiguousParameters + const originalParams = await trainingSession.getContiguousParameters(); + assertStrictEquals(originalParams.dims.length, 1); + assertStrictEquals(originalParams.dims[0], 397510); + assertStrictEquals(originalParams.data[0], -0.025190064683556557); + assertStrictEquals(originalParams.data[2000], -0.034044936299324036); + + const results = await runTrainStepAndCheck(trainingSession, feeds); + + await trainingSession.runOptimizerStep(feeds); + feeds = {"input-0": input0, "labels": labels}; + // check getContiguousParameters after optimizerStep -- that the parameters have been updated + const optimizedParams = await trainingSession.getContiguousParameters(); + assertTwoListsUnequal(originalParams.data, optimizedParams.data); + + const results2 = await runTrainStepAndCheck(trainingSession, feeds); + + // check that loss decreased after optimizer step and training again + assert(results2['onnx::loss::21273'].data < results['onnx::loss::21273'].data); + + await loadParametersBufferAndCheck(trainingSession, 397510, -1.2, optimizedParams); +} + +if (typeof module === 'object') { + module.exports = [testInferenceFunction, testTrainingFunctionMin, testTrainingFunctionAll, testTest]; +} diff --git a/web/test/training/e2e/data/model.onnx b/web/test/training/e2e/data/model.onnx new file mode 100644 index 0000000000000..088124bd48624 --- /dev/null +++ b/web/test/training/e2e/data/model.onnx @@ -0,0 +1,16 @@ + backend-test:b + +a +bc"MatMultest_matmul_2dZ +a +  + +Z +b +  + +b +c +  + +B \ No newline at end of file diff --git a/web/test/training/e2e/karma.conf.js b/web/test/training/e2e/karma.conf.js new file mode 100644 index 0000000000000..e441cb65b4125 --- /dev/null +++ b/web/test/training/e2e/karma.conf.js @@ -0,0 +1,54 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +'use strict'; + +const args = require('minimist')(process.argv.slice(2)); +const SELF_HOST = !!args['self-host']; +const ORT_MAIN = args['ort-main']; +const TEST_MAIN = args['test-main']; +if (typeof TEST_MAIN !== 'string') { + throw new Error('flag --test-main= is required'); +} +const USER_DATA = args['user-data']; +if (typeof USER_DATA !== 'string') { + throw new Error('flag --user-data= is required'); +} + +module.exports = function(config) { + const distPrefix = SELF_HOST ? './node_modules/onnxruntime-web/dist/' : 'http://localhost:8081/dist/'; + config.set({ + frameworks: ['mocha'], + files: [ + {pattern: distPrefix + ORT_MAIN}, + {pattern: './common.js'}, + {pattern: TEST_MAIN}, + {pattern: './node_modules/onnxruntime-web/dist/*.wasm', included: false, nocache: true}, + {pattern: './data/*', included: false}, + ], + plugins: [require('@chiragrupani/karma-chromium-edge-launcher'), ...config.plugins], + proxies: { + '/model.onnx': '/base/model.onnx', + '/data/': '/base/data/', + }, + client: {captureConsole: true, mocha: {expose: ['body'], timeout: 60000}}, + reporters: ['mocha'], + captureTimeout: 120000, + reportSlowerThan: 100, + browserDisconnectTimeout: 600000, + browserNoActivityTimeout: 300000, + browserDisconnectTolerance: 0, + browserSocketTimeout: 60000, + hostname: 'localhost', + browsers: [], + customLaunchers: { + Chrome_default: {base: 'ChromeHeadless', chromeDataDir: USER_DATA}, + Chrome_no_threads: { + base: 'ChromeHeadless', + chromeDataDir: USER_DATA, + // TODO: no-thread flags + }, + Edge_default: {base: 'Edge', edgeDataDir: USER_DATA} + } + }); +}; diff --git a/web/test/training/e2e/package.json b/web/test/training/e2e/package.json new file mode 100644 index 0000000000000..5f11a27de6dfc --- /dev/null +++ b/web/test/training/e2e/package.json @@ -0,0 +1,14 @@ +{ + "devDependencies": { + "@chiragrupani/karma-chromium-edge-launcher": "^2.2.2", + "fs-extra": "^11.1.0", + "globby": "^13.1.3", + "karma": "^6.4.1", + "karma-chrome-launcher": "^3.1.1", + "karma-mocha": "^2.0.1", + "karma-mocha-reporter": "^2.2.5", + "light-server": "^2.9.1", + "minimist": "^1.2.7", + "mocha": "^10.2.0" + } +} diff --git a/web/test/training/e2e/run.js b/web/test/training/e2e/run.js new file mode 100644 index 0000000000000..379a8136f3ff8 --- /dev/null +++ b/web/test/training/e2e/run.js @@ -0,0 +1,138 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +'use strict'; + +const path = require('path'); +const fs = require('fs-extra'); +const {spawn} = require('child_process'); +const startServer = require('./simple-http-server'); +const minimist = require('minimist'); + +// copy whole folder to out-side of /js/ because we need to test in a folder that no `package.json` file +// exists in its parent folder. +// here we use /build/js/e2e-training/ for the test + +const TEST_E2E_SRC_FOLDER = __dirname; +const JS_ROOT_FOLDER = path.resolve(__dirname, '../../../..'); +const TEST_E2E_RUN_FOLDER = path.resolve(JS_ROOT_FOLDER, '../build/js/e2e-training'); +const NPM_CACHE_FOLDER = path.resolve(TEST_E2E_RUN_FOLDER, '../npm_cache'); +const CHROME_USER_DATA_FOLDER = path.resolve(TEST_E2E_RUN_FOLDER, '../user_data'); +fs.emptyDirSync(TEST_E2E_RUN_FOLDER); +fs.emptyDirSync(NPM_CACHE_FOLDER); +fs.emptyDirSync(CHROME_USER_DATA_FOLDER); +fs.copySync(TEST_E2E_SRC_FOLDER, TEST_E2E_RUN_FOLDER); + +// training data to copy +const ORT_ROOT_FOLDER = path.resolve(JS_ROOT_FOLDER, '..'); +const TRAINING_DATA_FOLDER = path.resolve(ORT_ROOT_FOLDER, 'onnxruntime/test/testdata/training_api'); +const TRAININGDATA_DEST = path.resolve(TEST_E2E_RUN_FOLDER, 'data'); + +// always use a new folder as user-data-dir +let nextUserDataDirId = 0; +function getNextUserDataDir() { + const dir = path.resolve(CHROME_USER_DATA_FOLDER, nextUserDataDirId.toString()) + nextUserDataDirId++; + fs.emptyDirSync(dir); + return dir; +} + +// commandline arguments +const BROWSER = minimist(process.argv.slice(2)).browser || 'Chrome_default'; + +async function main() { + // find packed package + const {globbySync} = await import('globby'); + + const ORT_COMMON_FOLDER = path.resolve(JS_ROOT_FOLDER, 'common'); + const ORT_COMMON_PACKED_FILEPATH_CANDIDATES = globbySync('onnxruntime-common-*.tgz', {cwd: ORT_COMMON_FOLDER}); + + const PACKAGES_TO_INSTALL = []; + + if (ORT_COMMON_PACKED_FILEPATH_CANDIDATES.length === 1) { + PACKAGES_TO_INSTALL.push(path.resolve(ORT_COMMON_FOLDER, ORT_COMMON_PACKED_FILEPATH_CANDIDATES[0])); + } else if (ORT_COMMON_PACKED_FILEPATH_CANDIDATES.length > 1) { + throw new Error('multiple packages found for onnxruntime-common.'); + } + + const ORT_WEB_FOLDER = path.resolve(JS_ROOT_FOLDER, 'web'); + const ORT_WEB_PACKED_FILEPATH_CANDIDATES = globbySync('onnxruntime-web-*.tgz', {cwd: ORT_WEB_FOLDER}); + if (ORT_WEB_PACKED_FILEPATH_CANDIDATES.length !== 1) { + throw new Error('cannot find exactly single package for onnxruntime-web.'); + } + PACKAGES_TO_INSTALL.push(path.resolve(ORT_WEB_FOLDER, ORT_WEB_PACKED_FILEPATH_CANDIDATES[0])); + + // we start here: + + // install dev dependencies + await runInShell(`npm install`); + + // npm install with "--cache" to install packed packages with an empty cache folder + await runInShell(`npm install --cache "${NPM_CACHE_FOLDER}" ${PACKAGES_TO_INSTALL.map(i => `"${i}"`).join(' ')}`); + + // prepare training data + prepareTrainingDataByCopying(); + + console.log('==============================================================='); + console.log("Running self-hosted tests"); + console.log('==============================================================='); + // test cases with self-host (ort hosted in same origin) + await testAllBrowserCases({hostInKarma: true}); + + console.log('==============================================================='); + console.log("Running not self-hosted tests"); + console.log('==============================================================='); + // test cases without self-host (ort hosted in same origin) + startServer(path.resolve(TEST_E2E_RUN_FOLDER, 'node_modules', 'onnxruntime-web')); + await testAllBrowserCases({hostInKarma: false}); + + // no error occurs, exit with code 0 + process.exit(0); +} + +async function testAllBrowserCases({hostInKarma}) { + await runKarma({hostInKarma, main: './browser-test-wasm.js'}); +} + +async function runKarma({hostInKarma, main, browser = BROWSER, ortMain = 'ort.training.wasm.min.js'}) { + console.log('==============================================================='); + console.log(`Running karma with the following binary: ${ortMain}`); + console.log('==============================================================='); + const selfHostFlag = hostInKarma ? '--self-host' : ''; + await runInShell(`npx karma start --single-run --browsers ${browser} ${selfHostFlag} --ort-main=${ + ortMain} --test-main=${main} --user-data=${getNextUserDataDir()}`); +} + +async function runInShell(cmd) { + console.log('==============================================================='); + console.log(' Running command in shell:'); + console.log(' > ' + cmd); + console.log('==============================================================='); + let complete = false; + const childProcess = spawn(cmd, {shell: true, stdio: 'inherit', cwd: TEST_E2E_RUN_FOLDER}); + childProcess.on('close', function(code) { + if (code !== 0) { + process.exit(code); + } else { + complete = true; + } + }); + while (!complete) { + await delay(100); + } +} + +async function delay(ms) { + return new Promise(function(resolve) { + setTimeout(function() { + resolve(); + }, ms); + }); +} + +function prepareTrainingDataByCopying() { + fs.copySync(TRAINING_DATA_FOLDER, TRAININGDATA_DEST); + console.log(`Copied ${TRAINING_DATA_FOLDER} to ${TRAININGDATA_DEST}`); +} + +main(); diff --git a/web/test/training/e2e/simple-http-server.js b/web/test/training/e2e/simple-http-server.js new file mode 100644 index 0000000000000..a157c7dd93ad8 --- /dev/null +++ b/web/test/training/e2e/simple-http-server.js @@ -0,0 +1,64 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +'use strict'; + +// this is a simple HTTP server that enables CORS. +// following code is based on https://developer.mozilla.org/en-US/docs/Learn/Server-side/Node_server_without_framework + +const http = require('http'); +const fs = require('fs'); +const path = require('path'); + +const validRequests = { + // .wasm files + '/dist/ort-wasm.wasm': ['dist/ort-wasm.wasm', 'application/wasm'], + '/dist/ort-wasm-simd.wasm': ['dist/ort-wasm-simd.wasm', 'application/wasm'], + '/dist/ort-training-wasm-simd.wasm': ['dist/ort-training-wasm-simd.wasm', 'application/wasm'], + '/dist/ort-wasm-threaded.wasm': ['dist/ort-wasm-threaded.wasm', 'application/wasm'], + '/dist/ort-wasm-simd-threaded.wasm': ['dist/ort-wasm-simd-threaded.wasm', 'application/wasm'], + + // proxied .wasm files: + '/test-wasm-path-override/ort-wasm.wasm': ['dist/ort-training-wasm.wasm', 'application/wasm'], + //'/test-wasm-path-override/renamed.wasm': ['dist/ort-wasm.wasm', 'application/wasm'], + + // .js files + '/dist/ort.min.js': ['dist/ort.min.js', 'text/javascript'], + '/dist/ort.training.simd.wasm.min.js': ['dist/ort.training.simd.wasm.min.js', 'text/javascript'], + '/dist/ort.training.wasm.min.js': ['dist/ort.training.wasm.min.js', 'text/javascript'], + '/dist/ort.js': ['dist/ort.js', 'text/javascript'], + '/dist/ort.webgl.min.js': ['dist/ort.webgl.min.js', 'text/javascript'], + '/dist/ort.wasm.min.js': ['dist/ort.wasm.min.js', 'text/javascript'], + '/dist/ort.wasm-core.min.js': ['dist/ort.wasm-core.min.js', 'text/javascript'], +}; + +module.exports = function(dir) { + http.createServer(function(request, response) { + console.log(`request ${request.url.replace(/\n|\r/g, '')}`); + + const requestData = validRequests[request.url]; + if (!request) { + response.writeHead(404); + response.end('404'); + } else { + const [filePath, contentType] = requestData; + fs.readFile(path.resolve(dir, filePath), function(error, content) { + if (error) { + if (error.code == 'ENOENT') { + response.writeHead(404); + response.end('404'); + } else { + response.writeHead(500); + response.end('500'); + } + } else { + response.setHeader('access-control-allow-origin', '*'); + response.writeHead(200, {'Content-Type': contentType}); + response.end(content, 'utf-8'); + } + }); + } + }) + .listen(8081); + console.log('Server running at http://127.0.0.1:8081/'); + };