From cff6a7f11e5a687b359b3d5c28063afb37fa81dc Mon Sep 17 00:00:00 2001 From: carzh Date: Tue, 29 Aug 2023 21:50:18 +0000 Subject: [PATCH 01/19] added training build configuration --- js/web/lib/build-def.d.ts | 4 ++++ js/web/lib/wasm/wasm-factory.ts | 27 ++++++++++++++++++++------- js/web/script/build.ts | 4 ++++ js/web/webpack.config.js | 10 ++++++++++ 4 files changed, 38 insertions(+), 7 deletions(-) diff --git a/js/web/lib/build-def.d.ts b/js/web/lib/build-def.d.ts index 2049b2663ead3..939486926f319 100644 --- a/js/web/lib/build-def.d.ts +++ b/js/web/lib/build-def.d.ts @@ -30,6 +30,10 @@ interface BuildDefinitions { * defines whether to disable multi-threading feature in WebAssembly backend in the build. */ DISABLE_WASM_THREAD: boolean; + /** + * defines whether to enable training APIs in WebAssembly backend. + */ + ENABLE_TRAINING: boolean; } declare let BUILD_DEFS: BuildDefinitions; diff --git a/js/web/lib/wasm/wasm-factory.ts b/js/web/lib/wasm/wasm-factory.ts index 7648f0c473f07..4ccb06e281b9d 100644 --- a/js/web/lib/wasm/wasm-factory.ts +++ b/js/web/lib/wasm/wasm-factory.ts @@ -8,8 +8,14 @@ import {OrtWasmModule} from './binding/ort-wasm'; import {OrtWasmThreadedModule} from './binding/ort-wasm-threaded'; /* eslint-disable @typescript-eslint/no-require-imports */ -const ortWasmFactory: EmscriptenModuleFactory = - BUILD_DEFS.DISABLE_WEBGPU ? require('./binding/ort-wasm.js') : require('./binding/ort-wasm-simd.jsep.js'); +let ortWasmFactory: EmscriptenModuleFactory; + +if (BUILD_DEFS.ENABLE_TRAINING) { + ortWasmFactory = require('./binding/ort-training-wasm-simd.js'); +} +else { + ortWasmFactory = BUILD_DEFS.DISABLE_WEBGPU ? require('./binding/ort-wasm.js') : require('./binding/ort-wasm-simd.jsep.js'); +} const ortWasmFactoryThreaded: EmscriptenModuleFactory = !BUILD_DEFS.DISABLE_WASM_THREAD ? (BUILD_DEFS.DISABLE_WEBGPU ? require('./binding/ort-wasm-threaded.js') : @@ -71,12 +77,19 @@ const isSimdSupported = (): boolean => { } }; -const getWasmFileName = (useSimd: boolean, useThreads: boolean) => { +const getWasmFileName = (useSimd: boolean, useThreads: boolean, useTraining: boolean) => { + let wasmArtifact : string = 'ort'; + if (useTraining) { + wasmArtifact += '-training'; + } + wasmArtifact += '-wasm'; + if (useSimd) { + wasmArtifact += '-simd'; + } if (useThreads) { - return useSimd ? 'ort-wasm-simd-threaded.wasm' : 'ort-wasm-threaded.wasm'; - } else { - return useSimd ? 'ort-wasm-simd.wasm' : 'ort-wasm.wasm'; + wasmArtifact += '-threaded'; } + return wasmArtifact + '.wasm'; }; export const initializeWebAssembly = async(flags: Env.WebAssemblyFlags): Promise => { @@ -102,7 +115,7 @@ export const initializeWebAssembly = async(flags: Env.WebAssemblyFlags): Promise const wasmPaths = flags.wasmPaths; const wasmPrefixOverride = typeof wasmPaths === 'string' ? wasmPaths : undefined; - const wasmFileName = getWasmFileName(useSimd, useThreads); + const wasmFileName = getWasmFileName(useSimd, useThreads, BUILD_DEFS.ENABLE_TRAINING); const wasmPathOverride = typeof wasmPaths === 'object' ? wasmPaths[wasmFileName] : undefined; let isTimeout = false; diff --git a/js/web/script/build.ts b/js/web/script/build.ts index 03510ae86b85f..bda7da0ab2a8b 100644 --- a/js/web/script/build.ts +++ b/js/web/script/build.ts @@ -33,6 +33,7 @@ const FILTER = args.f || args.filter; const ROOT_FOLDER = path.join(__dirname, '..'); const WASM_BINDING_FOLDER = path.join(ROOT_FOLDER, 'lib', 'wasm', 'binding'); const WASM_BINDING_JS_PATH = path.join(WASM_BINDING_FOLDER, 'ort-wasm.js'); +const TRAINING_WASM_BINDING_SIMD_JS_PATH = path.join(WASM_BINDING_FOLDER, 'ort-training-wasm-simd.js'); const WASM_BINDING_THREADED_JS_PATH = path.join(WASM_BINDING_FOLDER, 'ort-wasm-threaded.js'); const WASM_BINDING_SIMD_THREADED_JSEP_JS_PATH = path.join(WASM_BINDING_FOLDER, 'ort-wasm-simd-threaded.jsep.js'); const WASM_BINDING_THREADED_WORKER_JS_PATH = path.join(WASM_BINDING_FOLDER, 'ort-wasm-threaded.worker.js'); @@ -45,6 +46,7 @@ const WASM_DIST_FOLDER = path.join(ROOT_FOLDER, 'dist'); const WASM_WASM_PATH = path.join(WASM_DIST_FOLDER, 'ort-wasm.wasm'); const WASM_THREADED_WASM_PATH = path.join(WASM_DIST_FOLDER, 'ort-wasm-threaded.wasm'); const WASM_SIMD_WASM_PATH = path.join(WASM_DIST_FOLDER, 'ort-wasm-simd.wasm'); +const TRAINING_WASM_SIMD_WASM_PATH = path.join(WASM_DIST_FOLDER, 'ort-training-wasm-simd.wasm'); const WASM_SIMD_THREADED_WASM_PATH = path.join(WASM_DIST_FOLDER, 'ort-wasm-simd-threaded.wasm'); const WASM_SIMD_JSEP_WASM_PATH = path.join(WASM_DIST_FOLDER, 'ort-wasm-simd.jsep.wasm'); const WASM_SIMD_THREADED_JSEP_WASM_PATH = path.join(WASM_DIST_FOLDER, 'ort-wasm-simd-threaded.jsep.wasm'); @@ -75,6 +77,8 @@ if (WASM) { validateFile(WASM_SIMD_THREADED_WASM_PATH); validateFile(WASM_SIMD_JSEP_WASM_PATH); validateFile(WASM_SIMD_THREADED_JSEP_WASM_PATH); + validateFile(TRAINING_WASM_BINDING_SIMD_JS_PATH); + validateFile(TRAINING_WASM_SIMD_WASM_PATH); } catch (e) { npmlog.error('Build', `WebAssembly files are not ready. build WASM first. ERR: ${e}`); throw e; diff --git a/js/web/webpack.config.js b/js/web/webpack.config.js index 81c69ffdcf6bf..7115a3e3b4258 100644 --- a/js/web/webpack.config.js +++ b/js/web/webpack.config.js @@ -57,6 +57,7 @@ const DEFAULT_BUILD_DEFS = { DISABLE_WASM: false, DISABLE_WASM_PROXY: false, DISABLE_WASM_THREAD: false, + ENABLE_TRAINING: false }; // common config for release bundle @@ -289,6 +290,15 @@ module.exports = () => { buildOrtWebConfig({suffix: '.es6.min', target: 'es6'}), // ort-web.es5.min.js buildOrtWebConfig({suffix: '.es5.min', target: 'es5'}), + + // ort.wasm.min.js + buildOrtConfig({ + suffix: '.training.wasm.min', + build_defs: { + ...DEFAULT_BUILD_DEFS, + ENABLE_TRAINING: true, + } + }), ); case 'node': From 7a1365d93272567feda3a6682ad6a2cb79422898 Mon Sep 17 00:00:00 2001 From: carzh Date: Tue, 29 Aug 2023 22:13:34 +0000 Subject: [PATCH 02/19] edited wasm factory + added training backend --- js/web/lib/backend-wasm-with-training.ts | 17 +++++++++++++++++ js/web/lib/backend-wasm.ts | 2 +- js/web/lib/index.ts | 11 +++++++---- js/web/lib/wasm/wasm-factory.ts | 22 +++++++++------------- 4 files changed, 34 insertions(+), 18 deletions(-) create mode 100644 js/web/lib/backend-wasm-with-training.ts diff --git a/js/web/lib/backend-wasm-with-training.ts b/js/web/lib/backend-wasm-with-training.ts new file mode 100644 index 0000000000000..6d31861fb1bcd --- /dev/null +++ b/js/web/lib/backend-wasm-with-training.ts @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {InferenceSession, TrainingSessionHandler} from 'onnxruntime-common'; + +import {OnnxruntimeWebAssemblyBackend} from './backend-wasm' + +class OnnxruntimeTrainingWebAssemblyBackend extends OnnxruntimeWebAssemblyBackend { + async createTrainingSessionHandler( + checkpointStateUriOrBuffer: string|Uint8Array, trainModelUriOrBuffer: string|Uint8Array, + evalModelUriOrBuffer: string|Uint8Array, optimizerModelUriOrBuffer: string|Uint8Array, + options: InferenceSession.SessionOptions): Promise { + throw new Error('Method not implemented yet.'); + } +} + +export const wasmBackend = new OnnxruntimeTrainingWebAssemblyBackend(); diff --git a/js/web/lib/backend-wasm.ts b/js/web/lib/backend-wasm.ts index 04108c2ad0f66..3f06eea5329a0 100644 --- a/js/web/lib/backend-wasm.ts +++ b/js/web/lib/backend-wasm.ts @@ -32,7 +32,7 @@ export const initializeFlags = (): void => { } }; -class OnnxruntimeWebAssemblyBackend implements Backend { +export class OnnxruntimeWebAssemblyBackend implements Backend { async init(): Promise { // populate wasm flags initializeFlags(); diff --git a/js/web/lib/index.ts b/js/web/lib/index.ts index d5ed536034f3e..03e1108b31d12 100644 --- a/js/web/lib/index.ts +++ b/js/web/lib/index.ts @@ -7,7 +7,7 @@ // So we import code inside the if-clause to allow terser remove the code safely. export * from 'onnxruntime-common'; -import {registerBackend, env} from 'onnxruntime-common'; +import {registerBackend, env, Backend} from 'onnxruntime-common'; import {version} from './version'; if (!BUILD_DEFS.DISABLE_WEBGL) { @@ -16,14 +16,17 @@ if (!BUILD_DEFS.DISABLE_WEBGL) { } if (!BUILD_DEFS.DISABLE_WASM) { - const wasmBackend = require('./backend-wasm').wasmBackend; + const wasmBackend = BUILD_DEFS.ENABLE_TRAINING ? require('./backend-wasm-with-training').wasmBackend : + require('./backend-wasm').wasmBackend; if (!BUILD_DEFS.DISABLE_WEBGPU && typeof navigator !== 'undefined' && navigator.gpu) { registerBackend('webgpu', wasmBackend, 5); } registerBackend('cpu', wasmBackend, 10); registerBackend('wasm', wasmBackend, 10); - registerBackend('xnnpack', wasmBackend, 9); - registerBackend('webnn', wasmBackend, 9); + if (!BUILD_DEFS.ENABLE_TRAINING) { + registerBackend('xnnpack', wasmBackend, 9); + registerBackend('webnn', wasmBackend, 9); + } } Object.defineProperty(env.versions, 'web', {value: version, enumerable: true}); diff --git a/js/web/lib/wasm/wasm-factory.ts b/js/web/lib/wasm/wasm-factory.ts index 4ccb06e281b9d..1ce06b0f63b88 100644 --- a/js/web/lib/wasm/wasm-factory.ts +++ b/js/web/lib/wasm/wasm-factory.ts @@ -12,9 +12,9 @@ let ortWasmFactory: EmscriptenModuleFactory; if (BUILD_DEFS.ENABLE_TRAINING) { ortWasmFactory = require('./binding/ort-training-wasm-simd.js'); -} -else { - ortWasmFactory = BUILD_DEFS.DISABLE_WEBGPU ? require('./binding/ort-wasm.js') : require('./binding/ort-wasm-simd.jsep.js'); +} else { + ortWasmFactory = + BUILD_DEFS.DISABLE_WEBGPU ? require('./binding/ort-wasm.js') : require('./binding/ort-wasm-simd.jsep.js'); } const ortWasmFactoryThreaded: EmscriptenModuleFactory = !BUILD_DEFS.DISABLE_WASM_THREAD ? @@ -78,18 +78,14 @@ const isSimdSupported = (): boolean => { }; const getWasmFileName = (useSimd: boolean, useThreads: boolean, useTraining: boolean) => { - let wasmArtifact : string = 'ort'; - if (useTraining) { - wasmArtifact += '-training'; - } - wasmArtifact += '-wasm'; if (useSimd) { - wasmArtifact += '-simd'; - } - if (useThreads) { - wasmArtifact += '-threaded'; + if (useTraining) { + return 'ort-training-wasm-simd.wasm'; + } + return useThreads ? 'ort-wasm-simd-threaded.wasm' : 'ort-wasm-simd.wasm'; + } else { + return useThreads ? 'ort-wasm-threaded.wasm' : 'ort-wasm.wasm'; } - return wasmArtifact + '.wasm'; }; export const initializeWebAssembly = async(flags: Env.WebAssemblyFlags): Promise => { From 71f9dbc94735908d4093e4fa114b62867e29306c Mon Sep 17 00:00:00 2001 From: carzh Date: Thu, 31 Aug 2023 21:02:16 +0000 Subject: [PATCH 03/19] added package.json modification for training + minimized training artifact --- js/web/lib/backend-wasm-with-training.ts | 4 +++- js/web/lib/index.ts | 2 +- js/web/package.json | 4 ++++ js/web/script/build.ts | 26 ++++++++++++++++++++++++ js/web/webpack.config.js | 5 +++-- 5 files changed, 37 insertions(+), 4 deletions(-) diff --git a/js/web/lib/backend-wasm-with-training.ts b/js/web/lib/backend-wasm-with-training.ts index 6d31861fb1bcd..5fac10da7791d 100644 --- a/js/web/lib/backend-wasm-with-training.ts +++ b/js/web/lib/backend-wasm-with-training.ts @@ -3,15 +3,17 @@ import {InferenceSession, TrainingSessionHandler} from 'onnxruntime-common'; -import {OnnxruntimeWebAssemblyBackend} from './backend-wasm' +import {OnnxruntimeWebAssemblyBackend} from './backend-wasm'; class OnnxruntimeTrainingWebAssemblyBackend extends OnnxruntimeWebAssemblyBackend { + /* eslint-disable @typescript-eslint/no-unused-vars */ async createTrainingSessionHandler( checkpointStateUriOrBuffer: string|Uint8Array, trainModelUriOrBuffer: string|Uint8Array, evalModelUriOrBuffer: string|Uint8Array, optimizerModelUriOrBuffer: string|Uint8Array, options: InferenceSession.SessionOptions): Promise { throw new Error('Method not implemented yet.'); } + /* eslint-enable @typescript-eslint/no-unused-vars */ } export const wasmBackend = new OnnxruntimeTrainingWebAssemblyBackend(); diff --git a/js/web/lib/index.ts b/js/web/lib/index.ts index 03e1108b31d12..30d0af5dc9e47 100644 --- a/js/web/lib/index.ts +++ b/js/web/lib/index.ts @@ -7,7 +7,7 @@ // So we import code inside the if-clause to allow terser remove the code safely. export * from 'onnxruntime-common'; -import {registerBackend, env, Backend} from 'onnxruntime-common'; +import {registerBackend, env} from 'onnxruntime-common'; import {version} from './version'; if (!BUILD_DEFS.DISABLE_WEBGL) { diff --git a/js/web/package.json b/js/web/package.json index 8ae5b733e5f21..798434bf5d574 100644 --- a/js/web/package.json +++ b/js/web/package.json @@ -78,6 +78,10 @@ "./webgpu": { "types": "./types.d.ts", "default": "./dist/ort.webgpu.min.js" + }, + "./training": { + "types": "./types.d.ts", + "default": "./dist/ort-web.training.wasm.min.js" } }, "types": "./types.d.ts", diff --git a/js/web/script/build.ts b/js/web/script/build.ts index bda7da0ab2a8b..ff94d1a806a09 100644 --- a/js/web/script/build.ts +++ b/js/web/script/build.ts @@ -34,6 +34,7 @@ const ROOT_FOLDER = path.join(__dirname, '..'); const WASM_BINDING_FOLDER = path.join(ROOT_FOLDER, 'lib', 'wasm', 'binding'); const WASM_BINDING_JS_PATH = path.join(WASM_BINDING_FOLDER, 'ort-wasm.js'); const TRAINING_WASM_BINDING_SIMD_JS_PATH = path.join(WASM_BINDING_FOLDER, 'ort-training-wasm-simd.js'); +const TRAINING_WASM_BINDING_SIMD_MIN_JS_PATH = path.join(WASM_BINDING_FOLDER, 'ort-training-wasm-simd.min.js'); const WASM_BINDING_THREADED_JS_PATH = path.join(WASM_BINDING_FOLDER, 'ort-wasm-threaded.js'); const WASM_BINDING_SIMD_THREADED_JSEP_JS_PATH = path.join(WASM_BINDING_FOLDER, 'ort-wasm-simd-threaded.jsep.js'); const WASM_BINDING_THREADED_WORKER_JS_PATH = path.join(WASM_BINDING_FOLDER, 'ort-wasm-threaded.worker.js'); @@ -93,6 +94,31 @@ if (WASM) { */ `; + npmlog.info('Build', 'Minimizing file "ort-training-wasm-simd.js"...'); + try { + const terser = spawnSync( + 'npx', + [ + 'terser', TRAINING_WASM_BINDING_SIMD_JS_PATH, '--compress', 'passes=2', '--format', 'comments=false', + '--mangle', 'reserved=[_scriptDir]', '--module' + ], + {shell: true, encoding: 'utf-8', cwd: ROOT_FOLDER}); + if (terser.status !== 0) { + console.error(terser.error); + process.exit(terser.status === null ? undefined : terser.status); + } + + fs.writeFileSync(TRAINING_WASM_BINDING_SIMD_MIN_JS_PATH, terser.stdout); + fs.writeFileSync(TRAINING_WASM_BINDING_SIMD_JS_PATH, `${COPYRIGHT_BANNER}${terser.stdout}`); + + validateFile(TRAINING_WASM_BINDING_SIMD_MIN_JS_PATH); + validateFile(TRAINING_WASM_BINDING_SIMD_JS_PATH); + } catch (e) { + npmlog.error('Build', `Failed to run terser on ort-training-wasm-simd.js. ERR: ${e}`); + throw e; + } + npmlog.info('Build', 'Minimizing file "ort-training-wasm-simd.js"... DONE'); + npmlog.info('Build', 'Minimizing file "ort-wasm-threaded.js"...'); try { const terser = spawnSync( diff --git a/js/web/webpack.config.js b/js/web/webpack.config.js index 7115a3e3b4258..1ac421b5246d9 100644 --- a/js/web/webpack.config.js +++ b/js/web/webpack.config.js @@ -104,6 +104,7 @@ function buildConfig({filename, format, target, mode, devtool, build_defs}) { config.resolve.alias['./binding/ort-wasm-threaded.js'] = './binding/ort-wasm-threaded.min.js'; config.resolve.alias['./binding/ort-wasm-threaded-simd.jsep.js'] = './binding/ort-wasm-threaded-simd.jsep.min.js'; config.resolve.alias['./binding/ort-wasm-threaded.worker.js'] = './binding/ort-wasm-threaded.min.worker.js'; + config.resolve.alias['./binding/ort-training-wasm-simd.js'] = './binding/ort-training-wasm-simd.min.js'; const options = defaultTerserPluginOptions(target); options.terserOptions.format.preamble = COPYRIGHT_BANNER; @@ -291,8 +292,8 @@ module.exports = () => { // ort-web.es5.min.js buildOrtWebConfig({suffix: '.es5.min', target: 'es5'}), - // ort.wasm.min.js - buildOrtConfig({ + // ort-web.training.wasm.min.js + buildOrtWebConfig({ suffix: '.training.wasm.min', build_defs: { ...DEFAULT_BUILD_DEFS, From 5adbcf407e5c5b2c2797ff9568b3a4d80ccf2a74 Mon Sep 17 00:00:00 2001 From: carzh Date: Fri, 22 Sep 2023 21:21:42 +0000 Subject: [PATCH 04/19] applied suggestions --- js/web/lib/backend-wasm-with-training.ts | 8 +++----- js/web/lib/wasm/wasm-factory.ts | 6 +++--- js/web/webpack.config.js | 2 ++ 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/js/web/lib/backend-wasm-with-training.ts b/js/web/lib/backend-wasm-with-training.ts index 5fac10da7791d..af5b575c87a7f 100644 --- a/js/web/lib/backend-wasm-with-training.ts +++ b/js/web/lib/backend-wasm-with-training.ts @@ -6,14 +6,12 @@ import {InferenceSession, TrainingSessionHandler} from 'onnxruntime-common'; import {OnnxruntimeWebAssemblyBackend} from './backend-wasm'; class OnnxruntimeTrainingWebAssemblyBackend extends OnnxruntimeWebAssemblyBackend { - /* eslint-disable @typescript-eslint/no-unused-vars */ async createTrainingSessionHandler( - checkpointStateUriOrBuffer: string|Uint8Array, trainModelUriOrBuffer: string|Uint8Array, - evalModelUriOrBuffer: string|Uint8Array, optimizerModelUriOrBuffer: string|Uint8Array, - options: InferenceSession.SessionOptions): Promise { + _checkpointStateUriOrBuffer: string|Uint8Array, _trainModelUriOrBuffer: string|Uint8Array, + _evalModelUriOrBuffer: string|Uint8Array, _optimizerModelUriOrBuffer: string|Uint8Array, + _options: InferenceSession.SessionOptions): Promise { throw new Error('Method not implemented yet.'); } - /* eslint-enable @typescript-eslint/no-unused-vars */ } export const wasmBackend = new OnnxruntimeTrainingWebAssemblyBackend(); diff --git a/js/web/lib/wasm/wasm-factory.ts b/js/web/lib/wasm/wasm-factory.ts index 1ce06b0f63b88..08192b4174272 100644 --- a/js/web/lib/wasm/wasm-factory.ts +++ b/js/web/lib/wasm/wasm-factory.ts @@ -77,9 +77,9 @@ const isSimdSupported = (): boolean => { } }; -const getWasmFileName = (useSimd: boolean, useThreads: boolean, useTraining: boolean) => { +const getWasmFileName = (useSimd: boolean, useThreads: boolean) => { if (useSimd) { - if (useTraining) { + if (BUILD_DEFS.ENABLE_TRAINING) { return 'ort-training-wasm-simd.wasm'; } return useThreads ? 'ort-wasm-simd-threaded.wasm' : 'ort-wasm-simd.wasm'; @@ -111,7 +111,7 @@ export const initializeWebAssembly = async(flags: Env.WebAssemblyFlags): Promise const wasmPaths = flags.wasmPaths; const wasmPrefixOverride = typeof wasmPaths === 'string' ? wasmPaths : undefined; - const wasmFileName = getWasmFileName(useSimd, useThreads, BUILD_DEFS.ENABLE_TRAINING); + const wasmFileName = getWasmFileName(useSimd, useThreads); const wasmPathOverride = typeof wasmPaths === 'object' ? wasmPaths[wasmFileName] : undefined; let isTimeout = false; diff --git a/js/web/webpack.config.js b/js/web/webpack.config.js index 1ac421b5246d9..28d30604dc8d3 100644 --- a/js/web/webpack.config.js +++ b/js/web/webpack.config.js @@ -298,6 +298,8 @@ module.exports = () => { build_defs: { ...DEFAULT_BUILD_DEFS, ENABLE_TRAINING: true, + DISABLE_WEBGL: true, + DISABLE_WEBGPU: true, } }), ); From 9068ab41adb8855c3eba4cce0e21afd17b5ff06b Mon Sep 17 00:00:00 2001 From: carzh Date: Fri, 6 Oct 2023 14:29:30 -0700 Subject: [PATCH 05/19] create training session implementation + supporting changes --- js/common/lib/training-session-impl.ts | 24 ++- js/web/lib/backend-wasm-with-training.ts | 12 +- js/web/lib/wasm/binding/ort-wasm.d.ts | 3 + js/web/lib/wasm/proxy-messages.ts | 7 +- js/web/lib/wasm/proxy-wrapper.ts | 14 ++ .../lib/wasm/session-handler-for-training.ts | 72 ++++++++ js/web/lib/wasm/session-handler.ts | 6 +- js/web/lib/wasm/wasm-core-impl.ts | 7 + js/web/lib/wasm/wasm-training-core-impl.ts | 162 ++++++++++++++++++ onnxruntime/wasm/api.cc | 37 ++++ onnxruntime/wasm/api.h | 23 +++ 11 files changed, 356 insertions(+), 11 deletions(-) create mode 100644 js/web/lib/wasm/session-handler-for-training.ts create mode 100644 js/web/lib/wasm/wasm-training-core-impl.ts diff --git a/js/common/lib/training-session-impl.ts b/js/common/lib/training-session-impl.ts index f06d06bda035f..bde1fb0c373ea 100644 --- a/js/common/lib/training-session-impl.ts +++ b/js/common/lib/training-session-impl.ts @@ -4,8 +4,11 @@ import {TrainingSessionHandler} from './backend.js'; import {InferenceSession as InferenceSession} from './inference-session.js'; import {TrainingSession as TrainingSessionInterface, TrainingSessionCreateOptions} from './training-session.js'; +import { resolveBackend } from './backend-impl.js'; type SessionOptions = InferenceSession.SessionOptions; +const noBackendErrMsg: string = "Training backend could not be resolved. " + + "Make sure you\'re using the correct configuration & WebAssembly files."; export class TrainingSession implements TrainingSessionInterface { private constructor(handler: TrainingSessionHandler) { @@ -20,9 +23,26 @@ export class TrainingSession implements TrainingSessionInterface { return this.handler.outputNames; } - static async create(_trainingOptions: TrainingSessionCreateOptions, _sessionOptions?: SessionOptions): + static async create(trainingOptions: TrainingSessionCreateOptions,sessionOptions?: SessionOptions): Promise { - throw new Error('Method not implemented'); + let checkpointState: string|Uint8Array = trainingOptions.checkpointState; + let trainModel: string|Uint8Array = trainingOptions.trainModel; + let evalModel: string|Uint8Array = trainingOptions.evalModel ? trainingOptions.evalModel : ''; + let optimizerModel: string|Uint8Array = trainingOptions.optimizerModel ? trainingOptions.optimizerModel : ''; + let options: SessionOptions = sessionOptions ? sessionOptions : {}; + + // get backend hints + const eps = options.executionProviders || []; + const backendHints = eps.map(i => typeof i === 'string' ? i : i.name); + const backend = await resolveBackend(backendHints); + if (backend.createTrainingSessionHandler) { + const handler = + await backend.createTrainingSessionHandler(checkpointState, trainModel, evalModel, optimizerModel, options); + return new TrainingSession(handler); + } + else { + throw new Error(noBackendErrMsg); + } } async loadParametersBuffer(_array: Uint8Array, _trainableOnly: boolean): Promise { diff --git a/js/web/lib/backend-wasm-with-training.ts b/js/web/lib/backend-wasm-with-training.ts index af5b575c87a7f..1796259d0f335 100644 --- a/js/web/lib/backend-wasm-with-training.ts +++ b/js/web/lib/backend-wasm-with-training.ts @@ -4,13 +4,17 @@ import {InferenceSession, TrainingSessionHandler} from 'onnxruntime-common'; import {OnnxruntimeWebAssemblyBackend} from './backend-wasm'; +import {OnnxruntimeWebAssemblyTrainingSessionHandler} from './wasm/session-handler-for-training'; class OnnxruntimeTrainingWebAssemblyBackend extends OnnxruntimeWebAssemblyBackend { async createTrainingSessionHandler( - _checkpointStateUriOrBuffer: string|Uint8Array, _trainModelUriOrBuffer: string|Uint8Array, - _evalModelUriOrBuffer: string|Uint8Array, _optimizerModelUriOrBuffer: string|Uint8Array, - _options: InferenceSession.SessionOptions): Promise { - throw new Error('Method not implemented yet.'); + checkpointStateUriOrBuffer: string|Uint8Array, trainModelUriOrBuffer: string|Uint8Array, + evalModelUriOrBuffer: string|Uint8Array, optimizerModelUriOrBuffer: string|Uint8Array, + options: InferenceSession.SessionOptions): Promise { + const handler = new OnnxruntimeWebAssemblyTrainingSessionHandler(); + await handler.createTrainingSession( + checkpointStateUriOrBuffer, trainModelUriOrBuffer, evalModelUriOrBuffer, optimizerModelUriOrBuffer, options); + return Promise.resolve(handler); } } diff --git a/js/web/lib/wasm/binding/ort-wasm.d.ts b/js/web/lib/wasm/binding/ort-wasm.d.ts index b7b2ff4537095..c2b2bdf628d85 100644 --- a/js/web/lib/wasm/binding/ort-wasm.d.ts +++ b/js/web/lib/wasm/binding/ort-wasm.d.ts @@ -102,6 +102,9 @@ export interface OrtWasmModule extends EmscriptenModule { _OrtTrainingCopyParametersFromBuffer? (trainingHandle: number, parametersBuffer: number, parameterCount: number, trainableOnly: boolean): number; + _OrtTrainingGetInputOutputCount?(sessionHandle: number, inputCountOffset: number, outputCountOffset: number): number; + _OrtTrainingGetInputOutputName?(sessionHandle: number, index: number, isInput: boolean): number; + _OrtTrainingReleaseSession?(trainingHandle: number): void; // #endregion diff --git a/js/web/lib/wasm/proxy-messages.ts b/js/web/lib/wasm/proxy-messages.ts index 43f70c23f7193..99c68076e62f7 100644 --- a/js/web/lib/wasm/proxy-messages.ts +++ b/js/web/lib/wasm/proxy-messages.ts @@ -73,5 +73,10 @@ interface MesssageEndProfiling extends MessageError { in ?: number; } +interface MessageIsOrtEnvInitialized extends MessageError { + type: 'is-ort-env-initialized'; + out?: boolean; +} + export type OrtWasmMessage = MessageInitWasm|MessageInitOrt|MessageCreateSessionAllocate|MessageCreateSessionFinalize| - MessageCreateSession|MessageReleaseSession|MessageRun|MesssageEndProfiling; + MessageCreateSession|MessageReleaseSession|MessageRun|MesssageEndProfiling|MessageIsOrtEnvInitialized; diff --git a/js/web/lib/wasm/proxy-wrapper.ts b/js/web/lib/wasm/proxy-wrapper.ts index 202209ed3bfed..881f905b803c3 100644 --- a/js/web/lib/wasm/proxy-wrapper.ts +++ b/js/web/lib/wasm/proxy-wrapper.ts @@ -24,6 +24,7 @@ const createSessionCallbacks: Array> = []; const runCallbacks: Array> = []; const endProfilingCallbacks: Array> = []; +const isOrtEnvInitializedCallbacks: Array> = []; const ensureWorker = (): void => { if (initializing || !initialized || aborted || !proxyWorker) { @@ -242,3 +243,16 @@ export const endProfiling = async(sessionId: number): Promise => { core.endProfiling(sessionId); } }; + +export const isOrtEnvInitialized = async(): Promise => { + if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) { + ensureWorker(); + return new Promise((resolve, reject) => { + isOrtEnvInitializedCallbacks.push([resolve, reject]); + const message: OrtWasmMessage = {type: 'is-ort-env-initialized'}; + proxyWorker!.postMessage(message); + }); + } else { + return core.isOrtEnvInitialized(); + } +} diff --git a/js/web/lib/wasm/session-handler-for-training.ts b/js/web/lib/wasm/session-handler-for-training.ts new file mode 100644 index 0000000000000..5f694371919fc --- /dev/null +++ b/js/web/lib/wasm/session-handler-for-training.ts @@ -0,0 +1,72 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {env, InferenceSession, SessionHandler, TrainingSessionHandler} from 'onnxruntime-common'; + +import {SerializableModeldata} from './proxy-messages'; +import {createSessionAllocate, initRuntime, isOrtEnvInitialized} from './wasm-core-impl'; +import {createCheckpointHandle, createTrainingSessionHandle, releaseTrainingSessionAndCheckpoint} from './wasm-training-core-impl'; + +export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSessionHandler { + loadParametersBuffer(array: Uint8Array, trainableOnly: boolean): Promise { + throw new Error('Method not implemented.'); + } + getContiguousParameters(trainableOnly: boolean): Promise { + throw new Error('Method not implemented.'); + } + private sessionId: number; + private checkpointId: number; + + inputNames: string[]; + outputNames: string[]; + + inputEncodedNames: number[]; + outputEncodedNames: number[]; + + async uriOrBufferToHeap(uriOrBuffer: string|Uint8Array): Promise { + let buffer: Uint8Array; + if (typeof uriOrBuffer === 'string') { + const response = await fetch(uriOrBuffer); + const arrayBuffer = await response.arrayBuffer(); + buffer = new Uint8Array(arrayBuffer); + } else { + buffer = uriOrBuffer; + } + return createSessionAllocate(buffer); + } + + async createTrainingSession( + checkpointStateUriOrBuffer: string|Uint8Array, trainModelUriOrBuffer: string|Uint8Array, + evalModelUriOrBuffer: string|Uint8Array, optimizerModelUriOrBuffer: string|Uint8Array, + options: InferenceSession.SessionOptions) { + if (!isOrtEnvInitialized()) { + await initRuntime(env); + } + let checkpointData: SerializableModeldata = await this.uriOrBufferToHeap(checkpointStateUriOrBuffer); + let trainModelData: SerializableModeldata = await this.uriOrBufferToHeap(trainModelUriOrBuffer); + // 0 is supposed to be the nullptr + let evalModelData: SerializableModeldata = [0, 0]; + let optimizerModelData: SerializableModeldata = [0, 0]; + + if (evalModelUriOrBuffer !== '') { + evalModelData = await this.uriOrBufferToHeap(evalModelUriOrBuffer); + } + if (optimizerModelUriOrBuffer !== '') { + optimizerModelData = await this.uriOrBufferToHeap(optimizerModelUriOrBuffer); + } + + this.checkpointId = createCheckpointHandle(checkpointData); + [[this.sessionId, this.inputNames, this.outputNames], this.inputEncodedNames, this.outputEncodedNames] = + createTrainingSessionHandle(this.checkpointId, trainModelData, evalModelData, optimizerModelData, options); + } + + async dispose(): Promise { + return releaseTrainingSessionAndCheckpoint(this.checkpointId, this.sessionId, this.inputEncodedNames, this.outputEncodedNames); + } + + async runTrainStep( + _feeds: SessionHandler.FeedsType, _fetches: SessionHandler.FetchesType, + _options: InferenceSession.RunOptions): Promise { + throw new Error('Method not implemented yet.'); + } +} diff --git a/js/web/lib/wasm/session-handler.ts b/js/web/lib/wasm/session-handler.ts index 7bc467449c33a..baf48ed5b6f57 100644 --- a/js/web/lib/wasm/session-handler.ts +++ b/js/web/lib/wasm/session-handler.ts @@ -6,10 +6,9 @@ import {env, InferenceSession, InferenceSessionHandler, SessionHandler, Tensor} import {promisify} from 'util'; import {SerializableModeldata, TensorMetadata} from './proxy-messages'; -import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, initializeRuntime, releaseSession, run} from './proxy-wrapper'; +import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, initializeRuntime, releaseSession, run, isOrtEnvInitialized} from './proxy-wrapper'; import {isGpuBufferSupportedType} from './wasm-common'; -let runtimeInitialized: boolean; let runtimeInitializationPromise: Promise|undefined; const encodeTensorMetadata = (tensor: Tensor, getName: () => string): TensorMetadata => { @@ -58,13 +57,12 @@ export class OnnxruntimeWebAssemblySessionHandler implements InferenceSessionHan } async loadModel(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): Promise { - if (!runtimeInitialized) { + if (!isOrtEnvInitialized()) { if (!runtimeInitializationPromise) { runtimeInitializationPromise = initializeRuntime(env); } await runtimeInitializationPromise; runtimeInitializationPromise = undefined; - runtimeInitialized = true; } if (typeof pathOrBuffer === 'string') { diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index 5b49a1d4202e3..770328934deba 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -57,6 +57,8 @@ export const initRuntime = async(env: Env): Promise => { const initJsep = require('./jsep/init').init; await initJsep(getInstance(), env); } + + ortEnvInitialized = true; }; /** @@ -92,6 +94,11 @@ type SessionMetadata = [ ]; const activeSessions = new Map(); +let ortEnvInitialized = false; + +export const isOrtEnvInitialized = (): boolean => { + return ortEnvInitialized; +}; /** * allocate the memory and memcpy the model bytes, preparing for creating an instance of InferenceSession. diff --git a/js/web/lib/wasm/wasm-training-core-impl.ts b/js/web/lib/wasm/wasm-training-core-impl.ts new file mode 100644 index 0000000000000..9216a5b1f4361 --- /dev/null +++ b/js/web/lib/wasm/wasm-training-core-impl.ts @@ -0,0 +1,162 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +// import {InferenceSession, Tensor} from 'onnxruntime-common'; +import {InferenceSession} from 'onnxruntime-common'; + +import {SerializableModeldata, SerializableSessionMetadata } from './proxy-messages'; +// import {setRunOptions} from './run-options'; +import {setSessionOptions} from './session-options'; +// import {tensorDataTypeEnumToString, tensorDataTypeStringToEnum, tensorTypeToTypedArrayConstructor} from './wasm-common'; +import {getInstance} from './wasm-factory'; +import {checkLastError} from './wasm-utils'; +// import {allocWasmString, checkLastError} from './wasm-utils'; +// import { prepareInputOutputTensor } from './wasm-core-impl'; + +const throwNoTrainingFuncsError = (): void => { + throw TypeError( + 'Built without training APIs enabled. Make sure to use the onnxruntime-training package for training functionality.'); +}; + +export const createCheckpointHandle = (checkpointData: SerializableModeldata): number => { + const wasm = getInstance(); + + let checkpointHandle = 0; + + try { + if (wasm._OrtTrainingLoadCheckpoint) { + checkpointHandle = wasm._OrtTrainingLoadCheckpoint(checkpointData[0], checkpointData[1]); + } else { + throwNoTrainingFuncsError(); + } + + if (checkpointHandle === 0) { + checkLastError('Error occurred when trying to create a CheckpointState.'); + } + return checkpointHandle; + } catch (e) { + if (wasm._OrtTrainingReleaseCheckpoint && checkpointHandle !== 0) { + wasm._OrtTrainingReleaseCheckpoint(checkpointHandle); + } + throw e; + } finally { + // free buffer from wasm heap + wasm._OrtFree(checkpointData[0]); + } +}; + +export const createTrainingSessionHandle = + (checkpointHandle: number, trainModelData: SerializableModeldata, evalModelData: SerializableModeldata, + optimizerModelData: SerializableModeldata, + options: InferenceSession.SessionOptions): [SerializableSessionMetadata, number[], number[]] => { + const wasm = getInstance(); + + let trainingSessionHandle = 0; + let sessionOptionsHandle = 0; + let allocs: number[] = []; + let inputNamesUTF8Encoded: number[] = []; + let outputNamesUTF8Encoded: number[] = []; + + let inputNames: string[] = []; + let outputNames: string[] = []; + + try { + [sessionOptionsHandle, allocs] = setSessionOptions(options); + if (wasm._OrtTrainingCreateSession) { + trainingSessionHandle = wasm._OrtTrainingCreateSession( + sessionOptionsHandle, checkpointHandle, trainModelData[0], trainModelData[1], evalModelData[0], + evalModelData[1], optimizerModelData[0], optimizerModelData[1]); + } else { + throwNoTrainingFuncsError(); + } + + if (trainingSessionHandle === 0) { + checkLastError('Error occurred when trying to create a TrainingSession.'); + } + + [inputNames, inputNamesUTF8Encoded, outputNames, outputNamesUTF8Encoded] = + getTrainingModelInputOutputNames(trainingSessionHandle); + + return [[trainingSessionHandle, inputNames, outputNames], inputNamesUTF8Encoded, outputNamesUTF8Encoded]; + } catch (e) { + if (wasm._OrtTrainingReleaseSession && trainingSessionHandle !== 0) { + wasm._OrtTrainingReleaseSession(trainingSessionHandle); + } + throw e; + } finally { + wasm._free(trainModelData[0]); + wasm._free(evalModelData[0]); + wasm._free(optimizerModelData[0]); + + if (sessionOptionsHandle !== 0) { + wasm._OrtReleaseSessionOptions(sessionOptionsHandle); + } + allocs.forEach(alloc => wasm._free(alloc)); + inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); + outputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); + } + }; + + const getTrainingModelInputOutputNames = (trainingSessionId: number): [string[], number[], string[], number[]] => { + const [inputCount, outputCount] = getTrainingModelInputOutputCount(trainingSessionId); + + const [inputNames, inputNamesUTF8Encoded] = getTrainingNamesLoop(trainingSessionId, inputCount, true); + const [outputNames, outputNamesUTF8Encoded] = getTrainingNamesLoop(trainingSessionId, outputCount, false); + + return [inputNames, inputNamesUTF8Encoded, outputNames, outputNamesUTF8Encoded]; +} + + const getTrainingModelInputOutputCount = (trainingSessionId: number): [number, number] => { + const wasm = getInstance(); + const stack = wasm.stackSave(); + try { + const dataOffset = wasm.stackAlloc(8); + if (wasm._OrtTrainingGetInputOutputCount) { + const errorCode = wasm._OrtTrainingGetInputOutputCount(trainingSessionId, dataOffset, dataOffset + 4); + if (errorCode !== 0) { + checkLastError('Can\'t get session input/output count.'); + } + return [wasm.HEAP32[dataOffset / 4], wasm.HEAP32[dataOffset / 4 + 1]]; + } else { + throwNoTrainingFuncsError(); + // unreachable code -- placeholder to prevent linting errors + return [0, 0]; + } + } finally { + wasm.stackRestore(stack); + } +} + +const getTrainingNamesLoop = (trainingSessionId: number, count: number, isInput: boolean): [string[], number[]] => { + const names = []; + const wasm = getInstance(); + + const namesUTF8Encoded = []; + + for (let i = 0; i < count; i++) { + if (wasm._OrtTrainingGetInputOutputName) { + const name = wasm._OrtTrainingGetInputOutputName(trainingSessionId, i, isInput); + if (name === 0) { + checkLastError('Can\'t get input or output name'); + } + + namesUTF8Encoded.push(name); + names.push(wasm.UTF8ToString(name)); + } else { + throwNoTrainingFuncsError; + } + } + return [names, namesUTF8Encoded]; +} + +export const releaseTrainingSessionAndCheckpoint = (checkpointId: number, sessionId: number, inputNamesUTF8Encoded: number[], outputNamesUTF8Encoded: number[]): void => { + const wasm = getInstance(); + inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); + outputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); + + if (wasm._OrtTrainingReleaseCheckpoint) { + wasm._OrtTrainingReleaseCheckpoint(checkpointId); + } + if (wasm._OrtTrainingReleaseSession) { + wasm._OrtTrainingReleaseSession(sessionId); + } +} diff --git a/onnxruntime/wasm/api.cc b/onnxruntime/wasm/api.cc index 968eece361724..1f41e341e4cbd 100644 --- a/onnxruntime/wasm/api.cc +++ b/onnxruntime/wasm/api.cc @@ -493,6 +493,14 @@ char* OrtEndProfiling(ort_session_handle_t session) { #define CHECK_TRAINING_STATUS(ORT_API_NAME, ...) \ CheckStatus(Ort::GetTrainingApi().ORT_API_NAME(__VA_ARGS__)) +#define RETURN_TRAINING_ERROR_CODE_IF_ERROR(ORT_API_NAME, ...) \ + do { \ + int error_code = CHECK_TRAINING_STATUS(ORT_API_NAME, __VA_ARGS__); \ + if (error_code != ORT_OK) { \ + return error_code; \ + } \ + } while (false) + ort_training_checkpoint_handle_t EMSCRIPTEN_KEEPALIVE OrtTrainingLoadCheckpoint(void* checkpoint_data_buffer, size_t checkpoint_size) { OrtCheckpointState* checkpoint_state = nullptr; @@ -571,6 +579,35 @@ int EMSCRIPTEN_KEEPALIVE OrtTrainingCopyParametersFromBuffer(ort_training_sessio return CHECK_TRAINING_STATUS(CopyBufferToParameters, training_handle, parameters_buffer, trainable_only); } +int EMSCRIPTEN_KEEPALIVE OrtTrainingGetInputOutputCount(ort_training_session_handle_t training_handle, + size_t* input_count, + size_t* output_count) { + RETURN_TRAINING_ERROR_CODE_IF_ERROR(TrainingSessionGetTrainingModelInputCount, training_handle, input_count); + RETURN_TRAINING_ERROR_CODE_IF_ERROR(TrainingSessionGetTrainingModelOutputCount, training_handle, output_count); + return ORT_OK; +} + +char* EMSCRIPTEN_KEEPALIVE OrtTrainingGetInputOutputName(ort_training_session_handle_t training_handle, + size_t index, + bool isInput) { + OrtAllocator* allocator = nullptr; + RETURN_NULLPTR_IF_ERROR(GetAllocatorWithDefaultOptions, &allocator); + + char* name = nullptr; + + if (isInput) { + return (CHECK_TRAINING_STATUS(TrainingSessionGetTrainingModelInputName, training_handle, index, + allocator, &name) == ORT_OK) + ? name + : nullptr; + } else { + return (CHECK_TRAINING_STATUS(TrainingSessionGetTrainingModelOutputName, training_handle, index, + allocator, &name) == ORT_OK) + ? name + : nullptr; + } +} + void EMSCRIPTEN_KEEPALIVE OrtTrainingReleaseSession(ort_training_session_handle_t training_handle) { Ort::GetTrainingApi().ReleaseTrainingSession(training_handle); } diff --git a/onnxruntime/wasm/api.h b/onnxruntime/wasm/api.h index 9a0664697f0ff..ea21eb8a9e8c8 100644 --- a/onnxruntime/wasm/api.h +++ b/onnxruntime/wasm/api.h @@ -432,6 +432,29 @@ int EMSCRIPTEN_KEEPALIVE OrtTrainingCopyParametersFromBuffer(ort_training_sessio size_t parameter_count, bool trainable_only); +/** + * Gets the input count and output count of the training model associated with the given training handle. + * @param traning_handle handle of the traning session + * @param input_count [out] a pointer to a size_t variable to accept input_count + * @param output_count [out] a pointer to a size_t variable to accept output_count + * @returns ORT error code. If not zero, call OrtGetLastError() to get a detailed error message. + */ +int EMSCRIPTEN_KEEPALIVE OrtTrainingGetInputOutputCount(ort_training_session_handle_t training_handle, + size_t* input_count, + size_t* output_count); + +/** + * Gets the input or output name at the specified index associated with the training model from the + * given training session. + * @param traning_handle handle of the traning session + * @param index the input or output index + * @param isInput if true, this method retrieves an input name. If false, this method retrieves an output name. + * @returns a pointer to a buffer which contains C-style string. Caller must release the C style string after use by + */ +char* EMSCRIPTEN_KEEPALIVE OrtTrainingGetInputOutputName(ort_training_session_handle_t training_handle, + size_t index, + bool isInput); + /** * @brief Release the specified ORT training session. * From 5202e3b9363753fd363c4265b88865671dc0d442 Mon Sep 17 00:00:00 2001 From: carzh Date: Fri, 6 Oct 2023 16:26:40 -0700 Subject: [PATCH 06/19] fixed variable names + enforced wasmBackend being a singleton with suggested fix of adding backend-wasm-inference.ts --- js/web/lib/backend-wasm-inference.ts | 5 +++++ ...ckend-wasm-with-training.ts => backend-wasm-training.ts} | 0 js/web/lib/backend-wasm.ts | 2 -- js/web/lib/index.ts | 6 +++--- js/web/lib/wasm/wasm-factory.ts | 4 ++-- 5 files changed, 10 insertions(+), 7 deletions(-) create mode 100644 js/web/lib/backend-wasm-inference.ts rename js/web/lib/{backend-wasm-with-training.ts => backend-wasm-training.ts} (100%) diff --git a/js/web/lib/backend-wasm-inference.ts b/js/web/lib/backend-wasm-inference.ts new file mode 100644 index 0000000000000..059a30bd45a57 --- /dev/null +++ b/js/web/lib/backend-wasm-inference.ts @@ -0,0 +1,5 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import { OnnxruntimeWebAssemblyBackend } from "./backend-wasm"; +export const wasmBackend = new OnnxruntimeWebAssemblyBackend(); diff --git a/js/web/lib/backend-wasm-with-training.ts b/js/web/lib/backend-wasm-training.ts similarity index 100% rename from js/web/lib/backend-wasm-with-training.ts rename to js/web/lib/backend-wasm-training.ts diff --git a/js/web/lib/backend-wasm.ts b/js/web/lib/backend-wasm.ts index 60468581968e3..5740263583031 100644 --- a/js/web/lib/backend-wasm.ts +++ b/js/web/lib/backend-wasm.ts @@ -51,5 +51,3 @@ export class OnnxruntimeWebAssemblyBackend implements Backend { return Promise.resolve(handler); } } - -export const wasmBackend = new OnnxruntimeWebAssemblyBackend(); diff --git a/js/web/lib/index.ts b/js/web/lib/index.ts index 30d0af5dc9e47..09f478dc2ddbe 100644 --- a/js/web/lib/index.ts +++ b/js/web/lib/index.ts @@ -16,14 +16,14 @@ if (!BUILD_DEFS.DISABLE_WEBGL) { } if (!BUILD_DEFS.DISABLE_WASM) { - const wasmBackend = BUILD_DEFS.ENABLE_TRAINING ? require('./backend-wasm-with-training').wasmBackend : - require('./backend-wasm').wasmBackend; + const wasmBackend = BUILD_DEFS.DISABLE_TRAINING ? require('./backend-wasm-inference').wasmBackend : + require('./backend-wasm-training').wasmBackend; if (!BUILD_DEFS.DISABLE_WEBGPU && typeof navigator !== 'undefined' && navigator.gpu) { registerBackend('webgpu', wasmBackend, 5); } registerBackend('cpu', wasmBackend, 10); registerBackend('wasm', wasmBackend, 10); - if (!BUILD_DEFS.ENABLE_TRAINING) { + if (BUILD_DEFS.DISABLE_TRAINING) { registerBackend('xnnpack', wasmBackend, 9); registerBackend('webnn', wasmBackend, 9); } diff --git a/js/web/lib/wasm/wasm-factory.ts b/js/web/lib/wasm/wasm-factory.ts index 87fa0a7771d3c..2b7d492cc70ba 100644 --- a/js/web/lib/wasm/wasm-factory.ts +++ b/js/web/lib/wasm/wasm-factory.ts @@ -10,7 +10,7 @@ import {OrtWasmThreadedModule} from './binding/ort-wasm-threaded'; /* eslint-disable @typescript-eslint/no-require-imports */ let ortWasmFactory: EmscriptenModuleFactory; -if (BUILD_DEFS.ENABLE_TRAINING) { +if (!BUILD_DEFS.DISABLE_TRAINING) { ortWasmFactory = require('./binding/ort-training-wasm-simd.js'); } else { ortWasmFactory = @@ -79,7 +79,7 @@ const isSimdSupported = (): boolean => { const getWasmFileName = (useSimd: boolean, useThreads: boolean) => { if (useSimd) { - if (BUILD_DEFS.ENABLE_TRAINING) { + if (!BUILD_DEFS.DISABLE_TRAINING) { return 'ort-training-wasm-simd.wasm'; } return useThreads ? 'ort-wasm-simd-threaded.wasm' : 'ort-wasm-simd.wasm'; From a277347111af09bc5b7623d42215a36732c4f7e9 Mon Sep 17 00:00:00 2001 From: carzh Date: Fri, 6 Oct 2023 17:08:47 -0700 Subject: [PATCH 07/19] format + lint --- js/web/lib/backend-wasm-inference.ts | 2 +- js/web/lib/index.ts | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/js/web/lib/backend-wasm-inference.ts b/js/web/lib/backend-wasm-inference.ts index 059a30bd45a57..475a0243ebd3d 100644 --- a/js/web/lib/backend-wasm-inference.ts +++ b/js/web/lib/backend-wasm-inference.ts @@ -1,5 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import { OnnxruntimeWebAssemblyBackend } from "./backend-wasm"; +import {OnnxruntimeWebAssemblyBackend} from './backend-wasm'; export const wasmBackend = new OnnxruntimeWebAssemblyBackend(); diff --git a/js/web/lib/index.ts b/js/web/lib/index.ts index 09f478dc2ddbe..4b3d874534398 100644 --- a/js/web/lib/index.ts +++ b/js/web/lib/index.ts @@ -17,7 +17,7 @@ if (!BUILD_DEFS.DISABLE_WEBGL) { if (!BUILD_DEFS.DISABLE_WASM) { const wasmBackend = BUILD_DEFS.DISABLE_TRAINING ? require('./backend-wasm-inference').wasmBackend : - require('./backend-wasm-training').wasmBackend; + require('./backend-wasm-training').wasmBackend; if (!BUILD_DEFS.DISABLE_WEBGPU && typeof navigator !== 'undefined' && navigator.gpu) { registerBackend('webgpu', wasmBackend, 5); } From e775933eab58fdd94d7a9d5e8cc616cf6659af9c Mon Sep 17 00:00:00 2001 From: carzh Date: Wed, 11 Oct 2023 14:05:50 -0700 Subject: [PATCH 08/19] minor tweak to remove placeholder return statement --- js/web/lib/wasm/wasm-training-core-impl.ts | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/js/web/lib/wasm/wasm-training-core-impl.ts b/js/web/lib/wasm/wasm-training-core-impl.ts index 9216a5b1f4361..c55b4976c9d0d 100644 --- a/js/web/lib/wasm/wasm-training-core-impl.ts +++ b/js/web/lib/wasm/wasm-training-core-impl.ts @@ -12,10 +12,7 @@ import {checkLastError} from './wasm-utils'; // import {allocWasmString, checkLastError} from './wasm-utils'; // import { prepareInputOutputTensor } from './wasm-core-impl'; -const throwNoTrainingFuncsError = (): void => { - throw TypeError( - 'Built without training APIs enabled. Make sure to use the onnxruntime-training package for training functionality.'); -}; +const NO_TRAIN_FUNCS_MSG = 'Built without training APIs enabled. Make sure to use the onnxruntime-training package for training functionality.'); export const createCheckpointHandle = (checkpointData: SerializableModeldata): number => { const wasm = getInstance(); @@ -26,7 +23,7 @@ export const createCheckpointHandle = (checkpointData: SerializableModeldata): n if (wasm._OrtTrainingLoadCheckpoint) { checkpointHandle = wasm._OrtTrainingLoadCheckpoint(checkpointData[0], checkpointData[1]); } else { - throwNoTrainingFuncsError(); + throw new Error(NO_TRAIN_FUNCS_MSG); } if (checkpointHandle === 0) { @@ -66,7 +63,7 @@ export const createTrainingSessionHandle = sessionOptionsHandle, checkpointHandle, trainModelData[0], trainModelData[1], evalModelData[0], evalModelData[1], optimizerModelData[0], optimizerModelData[1]); } else { - throwNoTrainingFuncsError(); + throw new Error(NO_TRAIN_FUNCS_MSG); } if (trainingSessionHandle === 0) { @@ -117,9 +114,7 @@ export const createTrainingSessionHandle = } return [wasm.HEAP32[dataOffset / 4], wasm.HEAP32[dataOffset / 4 + 1]]; } else { - throwNoTrainingFuncsError(); - // unreachable code -- placeholder to prevent linting errors - return [0, 0]; + throw new Error(NO_TRAIN_FUNCS_MSG); } } finally { wasm.stackRestore(stack); @@ -142,7 +137,7 @@ const getTrainingNamesLoop = (trainingSessionId: number, count: number, isInput: namesUTF8Encoded.push(name); names.push(wasm.UTF8ToString(name)); } else { - throwNoTrainingFuncsError; + throw new Error(NO_TRAIN_FUNCS_MSG); } } return [names, namesUTF8Encoded]; From c745a6570e5b7c46d573e57302e52f9570984799 Mon Sep 17 00:00:00 2001 From: carzh Date: Wed, 11 Oct 2023 14:26:18 -0700 Subject: [PATCH 09/19] format --- js/common/lib/training-session-impl.ts | 17 ++++---- js/web/lib/backend-wasm-training.ts | 2 +- js/web/lib/wasm/binding/ort-wasm.d.ts | 2 +- js/web/lib/wasm/proxy-wrapper.ts | 2 +- .../lib/wasm/session-handler-for-training.ts | 5 ++- js/web/lib/wasm/session-handler.ts | 4 +- js/web/lib/wasm/wasm-training-core-impl.ts | 39 +++++++++---------- 7 files changed, 35 insertions(+), 36 deletions(-) diff --git a/js/common/lib/training-session-impl.ts b/js/common/lib/training-session-impl.ts index bde1fb0c373ea..3cb6ec572c9a6 100644 --- a/js/common/lib/training-session-impl.ts +++ b/js/common/lib/training-session-impl.ts @@ -1,14 +1,14 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +import {resolveBackend} from './backend-impl.js'; import {TrainingSessionHandler} from './backend.js'; import {InferenceSession as InferenceSession} from './inference-session.js'; import {TrainingSession as TrainingSessionInterface, TrainingSessionCreateOptions} from './training-session.js'; -import { resolveBackend } from './backend-impl.js'; type SessionOptions = InferenceSession.SessionOptions; -const noBackendErrMsg: string = "Training backend could not be resolved. " + - "Make sure you\'re using the correct configuration & WebAssembly files."; +const noBackendErrMsg: string = 'Training backend could not be resolved. ' + + 'Make sure you\'re using the correct configuration & WebAssembly files.'; export class TrainingSession implements TrainingSessionInterface { private constructor(handler: TrainingSessionHandler) { @@ -23,7 +23,7 @@ export class TrainingSession implements TrainingSessionInterface { return this.handler.outputNames; } - static async create(trainingOptions: TrainingSessionCreateOptions,sessionOptions?: SessionOptions): + static async create(trainingOptions: TrainingSessionCreateOptions, sessionOptions?: SessionOptions): Promise { let checkpointState: string|Uint8Array = trainingOptions.checkpointState; let trainModel: string|Uint8Array = trainingOptions.trainModel; @@ -36,11 +36,10 @@ export class TrainingSession implements TrainingSessionInterface { const backendHints = eps.map(i => typeof i === 'string' ? i : i.name); const backend = await resolveBackend(backendHints); if (backend.createTrainingSessionHandler) { - const handler = - await backend.createTrainingSessionHandler(checkpointState, trainModel, evalModel, optimizerModel, options); - return new TrainingSession(handler); - } - else { + const handler = + await backend.createTrainingSessionHandler(checkpointState, trainModel, evalModel, optimizerModel, options); + return new TrainingSession(handler); + } else { throw new Error(noBackendErrMsg); } } diff --git a/js/web/lib/backend-wasm-training.ts b/js/web/lib/backend-wasm-training.ts index 1796259d0f335..98e40807aa29c 100644 --- a/js/web/lib/backend-wasm-training.ts +++ b/js/web/lib/backend-wasm-training.ts @@ -11,7 +11,7 @@ class OnnxruntimeTrainingWebAssemblyBackend extends OnnxruntimeWebAssemblyBacken checkpointStateUriOrBuffer: string|Uint8Array, trainModelUriOrBuffer: string|Uint8Array, evalModelUriOrBuffer: string|Uint8Array, optimizerModelUriOrBuffer: string|Uint8Array, options: InferenceSession.SessionOptions): Promise { - const handler = new OnnxruntimeWebAssemblyTrainingSessionHandler(); + const handler = new OnnxruntimeWebAssemblyTrainingSessionHandler(); await handler.createTrainingSession( checkpointStateUriOrBuffer, trainModelUriOrBuffer, evalModelUriOrBuffer, optimizerModelUriOrBuffer, options); return Promise.resolve(handler); diff --git a/js/web/lib/wasm/binding/ort-wasm.d.ts b/js/web/lib/wasm/binding/ort-wasm.d.ts index c2b2bdf628d85..060fb1e756ef9 100644 --- a/js/web/lib/wasm/binding/ort-wasm.d.ts +++ b/js/web/lib/wasm/binding/ort-wasm.d.ts @@ -102,7 +102,7 @@ export interface OrtWasmModule extends EmscriptenModule { _OrtTrainingCopyParametersFromBuffer? (trainingHandle: number, parametersBuffer: number, parameterCount: number, trainableOnly: boolean): number; - _OrtTrainingGetInputOutputCount?(sessionHandle: number, inputCountOffset: number, outputCountOffset: number): number; + _OrtTrainingGetInputOutputCount?(sessionHandle: number, inputCountOffset: number, outputCountOffset: number): number; _OrtTrainingGetInputOutputName?(sessionHandle: number, index: number, isInput: boolean): number; _OrtTrainingReleaseSession?(trainingHandle: number): void; diff --git a/js/web/lib/wasm/proxy-wrapper.ts b/js/web/lib/wasm/proxy-wrapper.ts index 6bb1bb733f194..55edd2106130d 100644 --- a/js/web/lib/wasm/proxy-wrapper.ts +++ b/js/web/lib/wasm/proxy-wrapper.ts @@ -264,4 +264,4 @@ export const isOrtEnvInitialized = async(): Promise => { } else { return core.isOrtEnvInitialized(); } -} +}; diff --git a/js/web/lib/wasm/session-handler-for-training.ts b/js/web/lib/wasm/session-handler-for-training.ts index 5f694371919fc..e0a9db97b951d 100644 --- a/js/web/lib/wasm/session-handler-for-training.ts +++ b/js/web/lib/wasm/session-handler-for-training.ts @@ -61,12 +61,13 @@ export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSes } async dispose(): Promise { - return releaseTrainingSessionAndCheckpoint(this.checkpointId, this.sessionId, this.inputEncodedNames, this.outputEncodedNames); + return releaseTrainingSessionAndCheckpoint( + this.checkpointId, this.sessionId, this.inputEncodedNames, this.outputEncodedNames); } async runTrainStep( _feeds: SessionHandler.FeedsType, _fetches: SessionHandler.FetchesType, _options: InferenceSession.RunOptions): Promise { - throw new Error('Method not implemented yet.'); + throw new Error('Method not implemented yet.'); } } diff --git a/js/web/lib/wasm/session-handler.ts b/js/web/lib/wasm/session-handler.ts index 2e2239b948b0f..a5017a920f38b 100644 --- a/js/web/lib/wasm/session-handler.ts +++ b/js/web/lib/wasm/session-handler.ts @@ -5,7 +5,7 @@ import {readFile} from 'node:fs/promises'; import {env, InferenceSession, InferenceSessionHandler, SessionHandler, Tensor} from 'onnxruntime-common'; import {SerializableModeldata, TensorMetadata} from './proxy-messages'; -import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, initializeRuntime, releaseSession, run, isOrtEnvInitialized} from './proxy-wrapper'; +import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, initializeRuntime, isOrtEnvInitialized, releaseSession, run} from './proxy-wrapper'; import {isGpuBufferSupportedType} from './wasm-common'; let runtimeInitializationPromise: Promise|undefined; @@ -56,7 +56,7 @@ export class OnnxruntimeWebAssemblySessionHandler implements InferenceSessionHan } async loadModel(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): Promise { - if (!isOrtEnvInitialized()) { + if (!(await isOrtEnvInitialized())) { if (!runtimeInitializationPromise) { runtimeInitializationPromise = initializeRuntime(env); } diff --git a/js/web/lib/wasm/wasm-training-core-impl.ts b/js/web/lib/wasm/wasm-training-core-impl.ts index c55b4976c9d0d..7de1601cd73b5 100644 --- a/js/web/lib/wasm/wasm-training-core-impl.ts +++ b/js/web/lib/wasm/wasm-training-core-impl.ts @@ -3,16 +3,13 @@ // import {InferenceSession, Tensor} from 'onnxruntime-common'; import {InferenceSession} from 'onnxruntime-common'; -import {SerializableModeldata, SerializableSessionMetadata } from './proxy-messages'; -// import {setRunOptions} from './run-options'; +import {SerializableModeldata, SerializableSessionMetadata} from './proxy-messages'; import {setSessionOptions} from './session-options'; -// import {tensorDataTypeEnumToString, tensorDataTypeStringToEnum, tensorTypeToTypedArrayConstructor} from './wasm-common'; import {getInstance} from './wasm-factory'; import {checkLastError} from './wasm-utils'; -// import {allocWasmString, checkLastError} from './wasm-utils'; -// import { prepareInputOutputTensor } from './wasm-core-impl'; -const NO_TRAIN_FUNCS_MSG = 'Built without training APIs enabled. Make sure to use the onnxruntime-training package for training functionality.'); +const NO_TRAIN_FUNCS_MSG = + 'Built without training APIs enabled. Make sure to use the onnxruntime-training package for training functionality.'; export const createCheckpointHandle = (checkpointData: SerializableModeldata): number => { const wasm = getInstance(); @@ -93,7 +90,7 @@ export const createTrainingSessionHandle = } }; - const getTrainingModelInputOutputNames = (trainingSessionId: number): [string[], number[], string[], number[]] => { +const getTrainingModelInputOutputNames = (trainingSessionId: number): [string[], number[], string[], number[]] => { const [inputCount, outputCount] = getTrainingModelInputOutputCount(trainingSessionId); const [inputNames, inputNamesUTF8Encoded] = getTrainingNamesLoop(trainingSessionId, inputCount, true); @@ -102,7 +99,7 @@ export const createTrainingSessionHandle = return [inputNames, inputNamesUTF8Encoded, outputNames, outputNamesUTF8Encoded]; } - const getTrainingModelInputOutputCount = (trainingSessionId: number): [number, number] => { +const getTrainingModelInputOutputCount = (trainingSessionId: number): [number, number] => { const wasm = getInstance(); const stack = wasm.stackSave(); try { @@ -143,15 +140,17 @@ const getTrainingNamesLoop = (trainingSessionId: number, count: number, isInput: return [names, namesUTF8Encoded]; } -export const releaseTrainingSessionAndCheckpoint = (checkpointId: number, sessionId: number, inputNamesUTF8Encoded: number[], outputNamesUTF8Encoded: number[]): void => { - const wasm = getInstance(); - inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); - outputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); - - if (wasm._OrtTrainingReleaseCheckpoint) { - wasm._OrtTrainingReleaseCheckpoint(checkpointId); - } - if (wasm._OrtTrainingReleaseSession) { - wasm._OrtTrainingReleaseSession(sessionId); - } -} +export const releaseTrainingSessionAndCheckpoint = + (checkpointId: number, sessionId: number, inputNamesUTF8Encoded: number[], outputNamesUTF8Encoded: number[]): + void => { + const wasm = getInstance(); + inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); + outputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); + + if (wasm._OrtTrainingReleaseCheckpoint) { + wasm._OrtTrainingReleaseCheckpoint(checkpointId); + } + if (wasm._OrtTrainingReleaseSession) { + wasm._OrtTrainingReleaseSession(sessionId); + } + } From beca6dc3f989bd04c81cb4b14bd02daa0f83f6a0 Mon Sep 17 00:00:00 2001 From: carzh Date: Wed, 11 Oct 2023 15:04:45 -0700 Subject: [PATCH 10/19] lint fixes --- js/common/lib/training-session-impl.ts | 11 +++++------ js/web/lib/wasm/session-handler-for-training.ts | 8 ++++---- js/web/lib/wasm/wasm-core-impl.ts | 3 ++- js/web/lib/wasm/wasm-training-core-impl.ts | 9 +++++---- 4 files changed, 16 insertions(+), 15 deletions(-) diff --git a/js/common/lib/training-session-impl.ts b/js/common/lib/training-session-impl.ts index 3cb6ec572c9a6..921730d86d136 100644 --- a/js/common/lib/training-session-impl.ts +++ b/js/common/lib/training-session-impl.ts @@ -25,11 +25,9 @@ export class TrainingSession implements TrainingSessionInterface { static async create(trainingOptions: TrainingSessionCreateOptions, sessionOptions?: SessionOptions): Promise { - let checkpointState: string|Uint8Array = trainingOptions.checkpointState; - let trainModel: string|Uint8Array = trainingOptions.trainModel; - let evalModel: string|Uint8Array = trainingOptions.evalModel ? trainingOptions.evalModel : ''; - let optimizerModel: string|Uint8Array = trainingOptions.optimizerModel ? trainingOptions.optimizerModel : ''; - let options: SessionOptions = sessionOptions ? sessionOptions : {}; + const evalModel: string|Uint8Array = trainingOptions.evalModel || ''; + const optimizerModel: string|Uint8Array = trainingOptions.optimizerModel || ''; + const options: SessionOptions = sessionOptions || {}; // get backend hints const eps = options.executionProviders || []; @@ -37,7 +35,8 @@ export class TrainingSession implements TrainingSessionInterface { const backend = await resolveBackend(backendHints); if (backend.createTrainingSessionHandler) { const handler = - await backend.createTrainingSessionHandler(checkpointState, trainModel, evalModel, optimizerModel, options); + await backend.createTrainingSessionHandler(trainingOptions.checkpointState, trainingOptions.trainModel, + evalModel, optimizerModel, options); return new TrainingSession(handler); } else { throw new Error(noBackendErrMsg); diff --git a/js/web/lib/wasm/session-handler-for-training.ts b/js/web/lib/wasm/session-handler-for-training.ts index e0a9db97b951d..83d133b9a5157 100644 --- a/js/web/lib/wasm/session-handler-for-training.ts +++ b/js/web/lib/wasm/session-handler-for-training.ts @@ -8,10 +8,10 @@ import {createSessionAllocate, initRuntime, isOrtEnvInitialized} from './wasm-co import {createCheckpointHandle, createTrainingSessionHandle, releaseTrainingSessionAndCheckpoint} from './wasm-training-core-impl'; export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSessionHandler { - loadParametersBuffer(array: Uint8Array, trainableOnly: boolean): Promise { + async loadParametersBuffer(_array: Uint8Array, _trainableOnly: boolean): Promise { throw new Error('Method not implemented.'); } - getContiguousParameters(trainableOnly: boolean): Promise { + async getContiguousParameters(_trainableOnly: boolean): Promise { throw new Error('Method not implemented.'); } private sessionId: number; @@ -42,8 +42,8 @@ export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSes if (!isOrtEnvInitialized()) { await initRuntime(env); } - let checkpointData: SerializableModeldata = await this.uriOrBufferToHeap(checkpointStateUriOrBuffer); - let trainModelData: SerializableModeldata = await this.uriOrBufferToHeap(trainModelUriOrBuffer); + const checkpointData: SerializableModeldata = await this.uriOrBufferToHeap(checkpointStateUriOrBuffer); + const trainModelData: SerializableModeldata = await this.uriOrBufferToHeap(trainModelUriOrBuffer); // 0 is supposed to be the nullptr let evalModelData: SerializableModeldata = [0, 0]; let optimizerModelData: SerializableModeldata = [0, 0]; diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index 770328934deba..7d3c8e263087d 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -10,6 +10,8 @@ import {dataLocationStringToEnum, getTensorElementSize, isGpuBufferSupportedType import {getInstance} from './wasm-factory'; import {allocWasmString, checkLastError} from './wasm-utils'; +let ortEnvInitialized = false; + /** * get the input/output count of the session. * @param sessionHandle the handle representing the session. should be non-zero. @@ -94,7 +96,6 @@ type SessionMetadata = [ ]; const activeSessions = new Map(); -let ortEnvInitialized = false; export const isOrtEnvInitialized = (): boolean => { return ortEnvInitialized; diff --git a/js/web/lib/wasm/wasm-training-core-impl.ts b/js/web/lib/wasm/wasm-training-core-impl.ts index 7de1601cd73b5..130bc06737deb 100644 --- a/js/web/lib/wasm/wasm-training-core-impl.ts +++ b/js/web/lib/wasm/wasm-training-core-impl.ts @@ -1,6 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -// import {InferenceSession, Tensor} from 'onnxruntime-common'; + import {InferenceSession} from 'onnxruntime-common'; import {SerializableModeldata, SerializableSessionMetadata} from './proxy-messages'; @@ -9,7 +9,8 @@ import {getInstance} from './wasm-factory'; import {checkLastError} from './wasm-utils'; const NO_TRAIN_FUNCS_MSG = - 'Built without training APIs enabled. Make sure to use the onnxruntime-training package for training functionality.'; + 'Built without training APIs enabled. ' + + 'Make sure to use the onnxruntime-training package for training functionality.'; export const createCheckpointHandle = (checkpointData: SerializableModeldata): number => { const wasm = getInstance(); @@ -97,7 +98,7 @@ const getTrainingModelInputOutputNames = (trainingSessionId: number): [string[], const [outputNames, outputNamesUTF8Encoded] = getTrainingNamesLoop(trainingSessionId, outputCount, false); return [inputNames, inputNamesUTF8Encoded, outputNames, outputNamesUTF8Encoded]; -} +}; const getTrainingModelInputOutputCount = (trainingSessionId: number): [number, number] => { const wasm = getInstance(); @@ -116,7 +117,7 @@ const getTrainingModelInputOutputCount = (trainingSessionId: number): [number, n } finally { wasm.stackRestore(stack); } -} +}; const getTrainingNamesLoop = (trainingSessionId: number, count: number, isInput: boolean): [string[], number[]] => { const names = []; From b37e77f83e2fa60035f107c8c3de7ad00dc81dfd Mon Sep 17 00:00:00 2001 From: carzh Date: Wed, 11 Oct 2023 22:28:53 +0000 Subject: [PATCH 11/19] lint + format --- js/common/lib/training-session-impl.ts | 5 +- js/web/lib/wasm/wasm-core-impl.ts | 4 +- js/web/lib/wasm/wasm-training-core-impl.ts | 107 ++++++++++----------- onnxruntime/wasm/api.cc | 2 +- 4 files changed, 57 insertions(+), 61 deletions(-) diff --git a/js/common/lib/training-session-impl.ts b/js/common/lib/training-session-impl.ts index 921730d86d136..47e67879e66ce 100644 --- a/js/common/lib/training-session-impl.ts +++ b/js/common/lib/training-session-impl.ts @@ -34,9 +34,8 @@ export class TrainingSession implements TrainingSessionInterface { const backendHints = eps.map(i => typeof i === 'string' ? i : i.name); const backend = await resolveBackend(backendHints); if (backend.createTrainingSessionHandler) { - const handler = - await backend.createTrainingSessionHandler(trainingOptions.checkpointState, trainingOptions.trainModel, - evalModel, optimizerModel, options); + const handler = await backend.createTrainingSessionHandler( + trainingOptions.checkpointState, trainingOptions.trainModel, evalModel, optimizerModel, options); return new TrainingSession(handler); } else { throw new Error(noBackendErrMsg); diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index 7d3c8e263087d..947242945c665 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -97,9 +97,7 @@ type SessionMetadata = [ const activeSessions = new Map(); -export const isOrtEnvInitialized = (): boolean => { - return ortEnvInitialized; -}; +export const isOrtEnvInitialized = (): boolean => ortEnvInitialized; /** * allocate the memory and memcpy the model bytes, preparing for creating an instance of InferenceSession. diff --git a/js/web/lib/wasm/wasm-training-core-impl.ts b/js/web/lib/wasm/wasm-training-core-impl.ts index 130bc06737deb..6baac86b2e885 100644 --- a/js/web/lib/wasm/wasm-training-core-impl.ts +++ b/js/web/lib/wasm/wasm-training-core-impl.ts @@ -8,8 +8,7 @@ import {setSessionOptions} from './session-options'; import {getInstance} from './wasm-factory'; import {checkLastError} from './wasm-utils'; -const NO_TRAIN_FUNCS_MSG = - 'Built without training APIs enabled. ' + +const NO_TRAIN_FUNCS_MSG = 'Built without training APIs enabled. ' + 'Make sure to use the onnxruntime-training package for training functionality.'; export const createCheckpointHandle = (checkpointData: SerializableModeldata): number => { @@ -39,6 +38,56 @@ export const createCheckpointHandle = (checkpointData: SerializableModeldata): n } }; +const getTrainingModelInputOutputCount = (trainingSessionId: number): [number, number] => { + const wasm = getInstance(); + const stack = wasm.stackSave(); + try { + const dataOffset = wasm.stackAlloc(8); + if (wasm._OrtTrainingGetInputOutputCount) { + const errorCode = wasm._OrtTrainingGetInputOutputCount(trainingSessionId, dataOffset, dataOffset + 4); + if (errorCode !== 0) { + checkLastError('Can\'t get session input/output count.'); + } + return [wasm.HEAP32[dataOffset / 4], wasm.HEAP32[dataOffset / 4 + 1]]; + } else { + throw new Error(NO_TRAIN_FUNCS_MSG); + } + } finally { + wasm.stackRestore(stack); + } +}; + +const getTrainingNamesLoop = (trainingSessionId: number, count: number, isInput: boolean): [string[], number[]] => { + const names = []; + const wasm = getInstance(); + + const namesUTF8Encoded = []; + + for (let i = 0; i < count; i++) { + if (wasm._OrtTrainingGetInputOutputName) { + const name = wasm._OrtTrainingGetInputOutputName(trainingSessionId, i, isInput); + if (name === 0) { + checkLastError('Can\'t get input or output name'); + } + + namesUTF8Encoded.push(name); + names.push(wasm.UTF8ToString(name)); + } else { + throw new Error(NO_TRAIN_FUNCS_MSG); + } + } + return [names, namesUTF8Encoded]; +}; + +const getTrainingModelInputOutputNames = (trainingSessionId: number): [string[], number[], string[], number[]] => { + const [inputCount, outputCount] = getTrainingModelInputOutputCount(trainingSessionId); + + const [inputNames, inputNamesUTF8Encoded] = getTrainingNamesLoop(trainingSessionId, inputCount, true); + const [outputNames, outputNamesUTF8Encoded] = getTrainingNamesLoop(trainingSessionId, outputCount, false); + + return [inputNames, inputNamesUTF8Encoded, outputNames, outputNamesUTF8Encoded]; +}; + export const createTrainingSessionHandle = (checkpointHandle: number, trainModelData: SerializableModeldata, evalModelData: SerializableModeldata, optimizerModelData: SerializableModeldata, @@ -70,8 +119,8 @@ export const createTrainingSessionHandle = [inputNames, inputNamesUTF8Encoded, outputNames, outputNamesUTF8Encoded] = getTrainingModelInputOutputNames(trainingSessionHandle); - return [[trainingSessionHandle, inputNames, outputNames], inputNamesUTF8Encoded, outputNamesUTF8Encoded]; + } catch (e) { if (wasm._OrtTrainingReleaseSession && trainingSessionHandle !== 0) { wasm._OrtTrainingReleaseSession(trainingSessionHandle); @@ -91,56 +140,6 @@ export const createTrainingSessionHandle = } }; -const getTrainingModelInputOutputNames = (trainingSessionId: number): [string[], number[], string[], number[]] => { - const [inputCount, outputCount] = getTrainingModelInputOutputCount(trainingSessionId); - - const [inputNames, inputNamesUTF8Encoded] = getTrainingNamesLoop(trainingSessionId, inputCount, true); - const [outputNames, outputNamesUTF8Encoded] = getTrainingNamesLoop(trainingSessionId, outputCount, false); - - return [inputNames, inputNamesUTF8Encoded, outputNames, outputNamesUTF8Encoded]; -}; - -const getTrainingModelInputOutputCount = (trainingSessionId: number): [number, number] => { - const wasm = getInstance(); - const stack = wasm.stackSave(); - try { - const dataOffset = wasm.stackAlloc(8); - if (wasm._OrtTrainingGetInputOutputCount) { - const errorCode = wasm._OrtTrainingGetInputOutputCount(trainingSessionId, dataOffset, dataOffset + 4); - if (errorCode !== 0) { - checkLastError('Can\'t get session input/output count.'); - } - return [wasm.HEAP32[dataOffset / 4], wasm.HEAP32[dataOffset / 4 + 1]]; - } else { - throw new Error(NO_TRAIN_FUNCS_MSG); - } - } finally { - wasm.stackRestore(stack); - } -}; - -const getTrainingNamesLoop = (trainingSessionId: number, count: number, isInput: boolean): [string[], number[]] => { - const names = []; - const wasm = getInstance(); - - const namesUTF8Encoded = []; - - for (let i = 0; i < count; i++) { - if (wasm._OrtTrainingGetInputOutputName) { - const name = wasm._OrtTrainingGetInputOutputName(trainingSessionId, i, isInput); - if (name === 0) { - checkLastError('Can\'t get input or output name'); - } - - namesUTF8Encoded.push(name); - names.push(wasm.UTF8ToString(name)); - } else { - throw new Error(NO_TRAIN_FUNCS_MSG); - } - } - return [names, namesUTF8Encoded]; -} - export const releaseTrainingSessionAndCheckpoint = (checkpointId: number, sessionId: number, inputNamesUTF8Encoded: number[], outputNamesUTF8Encoded: number[]): void => { @@ -154,4 +153,4 @@ export const releaseTrainingSessionAndCheckpoint = if (wasm._OrtTrainingReleaseSession) { wasm._OrtTrainingReleaseSession(sessionId); } - } + }; diff --git a/onnxruntime/wasm/api.cc b/onnxruntime/wasm/api.cc index 1f41e341e4cbd..2645d6f05222f 100644 --- a/onnxruntime/wasm/api.cc +++ b/onnxruntime/wasm/api.cc @@ -493,7 +493,7 @@ char* OrtEndProfiling(ort_session_handle_t session) { #define CHECK_TRAINING_STATUS(ORT_API_NAME, ...) \ CheckStatus(Ort::GetTrainingApi().ORT_API_NAME(__VA_ARGS__)) -#define RETURN_TRAINING_ERROR_CODE_IF_ERROR(ORT_API_NAME, ...) \ +#define RETURN_TRAINING_ERROR_CODE_IF_ERROR(ORT_API_NAME, ...) \ do { \ int error_code = CHECK_TRAINING_STATUS(ORT_API_NAME, __VA_ARGS__); \ if (error_code != ORT_OK) { \ From d2f3d880198dcc79d3549e68720107fcf192c9bb Mon Sep 17 00:00:00 2001 From: carzh Date: Thu, 12 Oct 2023 16:45:44 -0700 Subject: [PATCH 12/19] added isOrtEnvInitialized case to proxy wrapper --- js/web/lib/wasm/proxy-wrapper.ts | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/js/web/lib/wasm/proxy-wrapper.ts b/js/web/lib/wasm/proxy-wrapper.ts index 55edd2106130d..d7f2e957c2e0c 100644 --- a/js/web/lib/wasm/proxy-wrapper.ts +++ b/js/web/lib/wasm/proxy-wrapper.ts @@ -93,6 +93,12 @@ const onProxyWorkerMessage = (ev: MessageEvent): void => { endProfilingCallbacks.shift()![0](); } break; + case 'is-ort-env-initialized': + if (ev.data.err) { + isOrtEnvInitializedCallbacks.shift()![1](ev.data.err); + } else { + isOrtEnvInitializedCallbacks.shift()![0](ev.data.out!); + } default: } }; From 05a708fd868ad242babd5cd17909d39995a7f2d2 Mon Sep 17 00:00:00 2001 From: carzh Date: Fri, 13 Oct 2023 10:33:00 -0700 Subject: [PATCH 13/19] fixed proxy wrapper case statement fixed proxy worker run format --- js/web/lib/wasm/proxy-worker/main.ts | 10 +++++++++- js/web/lib/wasm/proxy-wrapper.ts | 1 + 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/js/web/lib/wasm/proxy-worker/main.ts b/js/web/lib/wasm/proxy-worker/main.ts index fe8bd9b11b191..1f4595818e5c0 100644 --- a/js/web/lib/wasm/proxy-worker/main.ts +++ b/js/web/lib/wasm/proxy-worker/main.ts @@ -4,7 +4,7 @@ /// import {OrtWasmMessage} from '../proxy-messages'; -import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, extractTransferableBuffers, initRuntime, releaseSession, run} from '../wasm-core-impl'; +import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, extractTransferableBuffers, initRuntime, isOrtEnvInitialized, releaseSession, run} from '../wasm-core-impl'; import {initializeWebAssembly} from '../wasm-factory'; self.onmessage = (ev: MessageEvent): void => { @@ -89,6 +89,14 @@ self.onmessage = (ev: MessageEvent): void => { postMessage({type: 'end-profiling', err} as OrtWasmMessage); } break; + case 'is-ort-env-initialized': + try { + const ortEnvInitialized = isOrtEnvInitialized(); + postMessage({type: 'is-ort-env-initialized', out: ortEnvInitialized} as OrtWasmMessage); + } catch (err) { + postMessage({type: 'is-ort-env-initialized', err} as OrtWasmMessage); + } + break; default: } }; diff --git a/js/web/lib/wasm/proxy-wrapper.ts b/js/web/lib/wasm/proxy-wrapper.ts index d7f2e957c2e0c..069a1fa452dbc 100644 --- a/js/web/lib/wasm/proxy-wrapper.ts +++ b/js/web/lib/wasm/proxy-wrapper.ts @@ -99,6 +99,7 @@ const onProxyWorkerMessage = (ev: MessageEvent): void => { } else { isOrtEnvInitializedCallbacks.shift()![0](ev.data.out!); } + break; default: } }; From daa1023d530d8f496d3689e8e6c7cd4ab82435e1 Mon Sep 17 00:00:00 2001 From: carzh Date: Mon, 23 Oct 2023 12:17:33 -0700 Subject: [PATCH 14/19] updated getInputOutputCount and getInputOutputNames signature, added more informative error message --- js/web/lib/wasm/binding/ort-wasm.d.ts | 5 ++- js/web/lib/wasm/wasm-training-core-impl.ts | 13 +++--- onnxruntime/wasm/api.cc | 50 ++++++++++++++++------ onnxruntime/wasm/api.h | 14 ++++-- 4 files changed, 57 insertions(+), 25 deletions(-) diff --git a/js/web/lib/wasm/binding/ort-wasm.d.ts b/js/web/lib/wasm/binding/ort-wasm.d.ts index 060fb1e756ef9..bb5f4c1090ea6 100644 --- a/js/web/lib/wasm/binding/ort-wasm.d.ts +++ b/js/web/lib/wasm/binding/ort-wasm.d.ts @@ -102,8 +102,9 @@ export interface OrtWasmModule extends EmscriptenModule { _OrtTrainingCopyParametersFromBuffer? (trainingHandle: number, parametersBuffer: number, parameterCount: number, trainableOnly: boolean): number; - _OrtTrainingGetInputOutputCount?(sessionHandle: number, inputCountOffset: number, outputCountOffset: number): number; - _OrtTrainingGetInputOutputName?(sessionHandle: number, index: number, isInput: boolean): number; + _OrtTrainingGetInputOutputCount?(sessionHandle: number, inputCountOffset: number, outputCountOffset: number, + isEvalModel: boolean): number; + _OrtTrainingGetInputOutputName?(sessionHandle: number, index: number, isInput: boolean, isEvalModel: boolean): number; _OrtTrainingReleaseSession?(trainingHandle: number): void; // #endregion diff --git a/js/web/lib/wasm/wasm-training-core-impl.ts b/js/web/lib/wasm/wasm-training-core-impl.ts index 6baac86b2e885..9f0ee2ea9dfe4 100644 --- a/js/web/lib/wasm/wasm-training-core-impl.ts +++ b/js/web/lib/wasm/wasm-training-core-impl.ts @@ -8,17 +8,20 @@ import {setSessionOptions} from './session-options'; import {getInstance} from './wasm-factory'; import {checkLastError} from './wasm-utils'; -const NO_TRAIN_FUNCS_MSG = 'Built without training APIs enabled. ' + - 'Make sure to use the onnxruntime-training package for training functionality.'; +const NO_TRAIN_FUNCS_MSG = + `Built without training API's enabled. Use the onnxruntime-web/training import for training \ + functionality, and make sure that all the correct artifacts are built & moved to the correct folder if \ + using a custom build. Check https://onnxruntime.ai/docs/build/web.html for more information.`; export const createCheckpointHandle = (checkpointData: SerializableModeldata): number => { const wasm = getInstance(); + const [checkpointDataOffset, checkpointDataLength] = checkpointData; let checkpointHandle = 0; try { if (wasm._OrtTrainingLoadCheckpoint) { - checkpointHandle = wasm._OrtTrainingLoadCheckpoint(checkpointData[0], checkpointData[1]); + checkpointHandle = wasm._OrtTrainingLoadCheckpoint(checkpointDataOffset, checkpointDataLength); } else { throw new Error(NO_TRAIN_FUNCS_MSG); } @@ -44,7 +47,7 @@ const getTrainingModelInputOutputCount = (trainingSessionId: number): [number, n try { const dataOffset = wasm.stackAlloc(8); if (wasm._OrtTrainingGetInputOutputCount) { - const errorCode = wasm._OrtTrainingGetInputOutputCount(trainingSessionId, dataOffset, dataOffset + 4); + const errorCode = wasm._OrtTrainingGetInputOutputCount(trainingSessionId, dataOffset, dataOffset + 4, false); if (errorCode !== 0) { checkLastError('Can\'t get session input/output count.'); } @@ -65,7 +68,7 @@ const getTrainingNamesLoop = (trainingSessionId: number, count: number, isInput: for (let i = 0; i < count; i++) { if (wasm._OrtTrainingGetInputOutputName) { - const name = wasm._OrtTrainingGetInputOutputName(trainingSessionId, i, isInput); + const name = wasm._OrtTrainingGetInputOutputName(trainingSessionId, i, isInput, false); if (name === 0) { checkLastError('Can\'t get input or output name'); } diff --git a/onnxruntime/wasm/api.cc b/onnxruntime/wasm/api.cc index 2645d6f05222f..f8375c0a77ae3 100644 --- a/onnxruntime/wasm/api.cc +++ b/onnxruntime/wasm/api.cc @@ -581,30 +581,52 @@ int EMSCRIPTEN_KEEPALIVE OrtTrainingCopyParametersFromBuffer(ort_training_sessio int EMSCRIPTEN_KEEPALIVE OrtTrainingGetInputOutputCount(ort_training_session_handle_t training_handle, size_t* input_count, - size_t* output_count) { - RETURN_TRAINING_ERROR_CODE_IF_ERROR(TrainingSessionGetTrainingModelInputCount, training_handle, input_count); - RETURN_TRAINING_ERROR_CODE_IF_ERROR(TrainingSessionGetTrainingModelOutputCount, training_handle, output_count); - return ORT_OK; + size_t* output_count, + bool isEvalModel) { + if (isEvalModel) { + RETURN_TRAINING_ERROR_CODE_IF_ERROR(TrainingSessionGetEvalModelInputCount, training_handle, input_count); + RETURN_TRAINING_ERROR_CODE_IF_ERROR(TrainingSessionGetEvalModelOutputCount, training_handle, output_count); + return ORT_OK; + } else { + RETURN_TRAINING_ERROR_CODE_IF_ERROR(TrainingSessionGetTrainingModelInputCount, training_handle, input_count); + RETURN_TRAINING_ERROR_CODE_IF_ERROR(TrainingSessionGetTrainingModelOutputCount, training_handle, output_count); + return ORT_OK; + } } char* EMSCRIPTEN_KEEPALIVE OrtTrainingGetInputOutputName(ort_training_session_handle_t training_handle, size_t index, - bool isInput) { + bool isInput, + bool isEvalModel) { OrtAllocator* allocator = nullptr; RETURN_NULLPTR_IF_ERROR(GetAllocatorWithDefaultOptions, &allocator); char* name = nullptr; - if (isInput) { - return (CHECK_TRAINING_STATUS(TrainingSessionGetTrainingModelInputName, training_handle, index, - allocator, &name) == ORT_OK) - ? name - : nullptr; + if (isEvalModel) { + if (isInput) { + return (CHECK_TRAINING_STATUS(TrainingSessionGetEvalModelInputName, training_handle, index, + allocator, &name) == ORT_OK) + ? name + : nullptr; + } else { + return (CHECK_TRAINING_STATUS(TrainingSessionGetEvalModelOutputName, training_handle, index, + allocator, &name) == ORT_OK) + ? name + : nullptr; + } } else { - return (CHECK_TRAINING_STATUS(TrainingSessionGetTrainingModelOutputName, training_handle, index, - allocator, &name) == ORT_OK) - ? name - : nullptr; + if (isInput) { + return (CHECK_TRAINING_STATUS(TrainingSessionGetTrainingModelInputName, training_handle, index, + allocator, &name) == ORT_OK) + ? name + : nullptr; + } else { + return (CHECK_TRAINING_STATUS(TrainingSessionGetTrainingModelOutputName, training_handle, index, + allocator, &name) == ORT_OK) + ? name + : nullptr; + } } } diff --git a/onnxruntime/wasm/api.h b/onnxruntime/wasm/api.h index ea21eb8a9e8c8..d7bc84c0f00bd 100644 --- a/onnxruntime/wasm/api.h +++ b/onnxruntime/wasm/api.h @@ -433,27 +433,33 @@ int EMSCRIPTEN_KEEPALIVE OrtTrainingCopyParametersFromBuffer(ort_training_sessio bool trainable_only); /** - * Gets the input count and output count of the training model associated with the given training handle. + * Gets the input count and output count of the training or eval model associated with the given training handle. * @param traning_handle handle of the traning session * @param input_count [out] a pointer to a size_t variable to accept input_count * @param output_count [out] a pointer to a size_t variable to accept output_count + * @param isEvalModel when false, returns input & output count of the training model. When true, returns input & output + * count of the eval model. * @returns ORT error code. If not zero, call OrtGetLastError() to get a detailed error message. */ int EMSCRIPTEN_KEEPALIVE OrtTrainingGetInputOutputCount(ort_training_session_handle_t training_handle, size_t* input_count, - size_t* output_count); + size_t* output_count, + bool isEvalModel); /** - * Gets the input or output name at the specified index associated with the training model from the + * Gets the input or output name at the specified index associated with the training or eval model from the * given training session. * @param traning_handle handle of the traning session * @param index the input or output index * @param isInput if true, this method retrieves an input name. If false, this method retrieves an output name. + * @param isEvalModel when false, returns input & output names of the training model. When true, returns input & output + * names of the eval model. * @returns a pointer to a buffer which contains C-style string. Caller must release the C style string after use by */ char* EMSCRIPTEN_KEEPALIVE OrtTrainingGetInputOutputName(ort_training_session_handle_t training_handle, size_t index, - bool isInput); + bool isInput, + bool isEvalModel); /** * @brief Release the specified ORT training session. From 389f08e810d37a79ddd0ae12807fb72196d92f27 Mon Sep 17 00:00:00 2001 From: carzh Date: Mon, 23 Oct 2023 12:19:13 -0700 Subject: [PATCH 15/19] updated parameter names according to suggestions --- js/web/lib/wasm/binding/ort-wasm.d.ts | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/js/web/lib/wasm/binding/ort-wasm.d.ts b/js/web/lib/wasm/binding/ort-wasm.d.ts index bb5f4c1090ea6..a8f098483c977 100644 --- a/js/web/lib/wasm/binding/ort-wasm.d.ts +++ b/js/web/lib/wasm/binding/ort-wasm.d.ts @@ -102,9 +102,9 @@ export interface OrtWasmModule extends EmscriptenModule { _OrtTrainingCopyParametersFromBuffer? (trainingHandle: number, parametersBuffer: number, parameterCount: number, trainableOnly: boolean): number; - _OrtTrainingGetInputOutputCount?(sessionHandle: number, inputCountOffset: number, outputCountOffset: number, + _OrtTrainingGetInputOutputCount?(trainingHandle: number, inputCount: number, outputCount: number, isEvalModel: boolean): number; - _OrtTrainingGetInputOutputName?(sessionHandle: number, index: number, isInput: boolean, isEvalModel: boolean): number; + _OrtTrainingGetInputOutputName?(trainingHandle: number, index: number, isInput: boolean, isEvalModel: boolean): number; _OrtTrainingReleaseSession?(trainingHandle: number): void; // #endregion From 70505c4b0ca667cdfbddc244fc91e2fbf17c49b0 Mon Sep 17 00:00:00 2001 From: carzh Date: Tue, 24 Oct 2023 11:27:23 -0700 Subject: [PATCH 16/19] format + lint --- js/web/lib/wasm/binding/ort-wasm.d.ts | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/js/web/lib/wasm/binding/ort-wasm.d.ts b/js/web/lib/wasm/binding/ort-wasm.d.ts index a8f098483c977..def706f53fc3a 100644 --- a/js/web/lib/wasm/binding/ort-wasm.d.ts +++ b/js/web/lib/wasm/binding/ort-wasm.d.ts @@ -102,9 +102,10 @@ export interface OrtWasmModule extends EmscriptenModule { _OrtTrainingCopyParametersFromBuffer? (trainingHandle: number, parametersBuffer: number, parameterCount: number, trainableOnly: boolean): number; - _OrtTrainingGetInputOutputCount?(trainingHandle: number, inputCount: number, outputCount: number, - isEvalModel: boolean): number; - _OrtTrainingGetInputOutputName?(trainingHandle: number, index: number, isInput: boolean, isEvalModel: boolean): number; + _OrtTrainingGetInputOutputCount? + (trainingHandle: number, inputCount: number, outputCount: number, isEvalModel: boolean): number; + _OrtTrainingGetInputOutputName? + (trainingHandle: number, index: number, isInput: boolean, isEvalModel: boolean): number; _OrtTrainingReleaseSession?(trainingHandle: number): void; // #endregion From a5bc60c29910f18e7a0972752651307c9762d565 Mon Sep 17 00:00:00 2001 From: carzh Date: Tue, 24 Oct 2023 14:15:46 -0700 Subject: [PATCH 17/19] lint fix -- changed multiline string to singlequote string --- js/web/lib/wasm/wasm-training-core-impl.ts | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/js/web/lib/wasm/wasm-training-core-impl.ts b/js/web/lib/wasm/wasm-training-core-impl.ts index 9f0ee2ea9dfe4..613ebb057a9d3 100644 --- a/js/web/lib/wasm/wasm-training-core-impl.ts +++ b/js/web/lib/wasm/wasm-training-core-impl.ts @@ -9,9 +9,9 @@ import {getInstance} from './wasm-factory'; import {checkLastError} from './wasm-utils'; const NO_TRAIN_FUNCS_MSG = - `Built without training API's enabled. Use the onnxruntime-web/training import for training \ - functionality, and make sure that all the correct artifacts are built & moved to the correct folder if \ - using a custom build. Check https://onnxruntime.ai/docs/build/web.html for more information.`; + 'Built without training API\'s enabled. Use the onnxruntime-web/training import for training ' + + 'functionality, and make sure that all the correct artifacts are built & moved to the correct folder if ' + + 'using a custom build. Check https://onnxruntime.ai/docs/build/web.html for more information.'; export const createCheckpointHandle = (checkpointData: SerializableModeldata): number => { const wasm = getInstance(); From ae936e5457283ac226a2c6cf669a478681ac1e16 Mon Sep 17 00:00:00 2001 From: Caroline Zhu Date: Wed, 25 Oct 2023 13:41:21 -0700 Subject: [PATCH 18/19] Apply naming suggestions from code review Co-authored-by: Ashwini Khade --- js/web/lib/wasm/wasm-training-core-impl.ts | 4 ++-- onnxruntime/wasm/api.cc | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/js/web/lib/wasm/wasm-training-core-impl.ts b/js/web/lib/wasm/wasm-training-core-impl.ts index 613ebb057a9d3..6277fbec94ef7 100644 --- a/js/web/lib/wasm/wasm-training-core-impl.ts +++ b/js/web/lib/wasm/wasm-training-core-impl.ts @@ -41,7 +41,7 @@ export const createCheckpointHandle = (checkpointData: SerializableModeldata): n } }; -const getTrainingModelInputOutputCount = (trainingSessionId: number): [number, number] => { +const getModelInputOutputCount = (trainingSessionId: number, IsEvalModel: boolean): [number, number] => { const wasm = getInstance(); const stack = wasm.stackSave(); try { @@ -60,7 +60,7 @@ const getTrainingModelInputOutputCount = (trainingSessionId: number): [number, n } }; -const getTrainingNamesLoop = (trainingSessionId: number, count: number, isInput: boolean): [string[], number[]] => { +const getModelInputOutputNamesLoop = (trainingSessionId: number, count: number, isInput: boolean, IsEvalModel:boolean): [string[], number[]] => { const names = []; const wasm = getInstance(); diff --git a/onnxruntime/wasm/api.cc b/onnxruntime/wasm/api.cc index f8375c0a77ae3..8346da48a7ab8 100644 --- a/onnxruntime/wasm/api.cc +++ b/onnxruntime/wasm/api.cc @@ -579,7 +579,7 @@ int EMSCRIPTEN_KEEPALIVE OrtTrainingCopyParametersFromBuffer(ort_training_sessio return CHECK_TRAINING_STATUS(CopyBufferToParameters, training_handle, parameters_buffer, trainable_only); } -int EMSCRIPTEN_KEEPALIVE OrtTrainingGetInputOutputCount(ort_training_session_handle_t training_handle, +int EMSCRIPTEN_KEEPALIVE OrtTrainingGetModelInputOutputCount(ort_training_session_handle_t training_handle, size_t* input_count, size_t* output_count, bool isEvalModel) { From 316a6e7801d7aad9409df789a7e6c5ca2aa0a5ba Mon Sep 17 00:00:00 2001 From: carzh Date: Wed, 25 Oct 2023 14:29:13 -0700 Subject: [PATCH 19/19] implemented naming suggestions + fixed training session & checkpoint release order --- js/web/lib/wasm/binding/ort-wasm.d.ts | 4 +- js/web/lib/wasm/wasm-training-core-impl.ts | 57 ++++++++++++---------- onnxruntime/wasm/api.cc | 14 +++--- onnxruntime/wasm/api.h | 16 +++--- 4 files changed, 47 insertions(+), 44 deletions(-) diff --git a/js/web/lib/wasm/binding/ort-wasm.d.ts b/js/web/lib/wasm/binding/ort-wasm.d.ts index def706f53fc3a..00431a4e86d5b 100644 --- a/js/web/lib/wasm/binding/ort-wasm.d.ts +++ b/js/web/lib/wasm/binding/ort-wasm.d.ts @@ -102,9 +102,9 @@ export interface OrtWasmModule extends EmscriptenModule { _OrtTrainingCopyParametersFromBuffer? (trainingHandle: number, parametersBuffer: number, parameterCount: number, trainableOnly: boolean): number; - _OrtTrainingGetInputOutputCount? + _OrtTrainingGetModelInputOutputCount? (trainingHandle: number, inputCount: number, outputCount: number, isEvalModel: boolean): number; - _OrtTrainingGetInputOutputName? + _OrtTrainingGetModelInputOutputName? (trainingHandle: number, index: number, isInput: boolean, isEvalModel: boolean): number; _OrtTrainingReleaseSession?(trainingHandle: number): void; diff --git a/js/web/lib/wasm/wasm-training-core-impl.ts b/js/web/lib/wasm/wasm-training-core-impl.ts index 6277fbec94ef7..4830b5d2b5e80 100644 --- a/js/web/lib/wasm/wasm-training-core-impl.ts +++ b/js/web/lib/wasm/wasm-training-core-impl.ts @@ -41,13 +41,14 @@ export const createCheckpointHandle = (checkpointData: SerializableModeldata): n } }; -const getModelInputOutputCount = (trainingSessionId: number, IsEvalModel: boolean): [number, number] => { +const getModelInputOutputCount = (trainingSessionId: number, isEvalModel: boolean): [number, number] => { const wasm = getInstance(); const stack = wasm.stackSave(); try { const dataOffset = wasm.stackAlloc(8); - if (wasm._OrtTrainingGetInputOutputCount) { - const errorCode = wasm._OrtTrainingGetInputOutputCount(trainingSessionId, dataOffset, dataOffset + 4, false); + if (wasm._OrtTrainingGetModelInputOutputCount) { + const errorCode = + wasm._OrtTrainingGetModelInputOutputCount(trainingSessionId, dataOffset, dataOffset + 4, isEvalModel); if (errorCode !== 0) { checkLastError('Can\'t get session input/output count.'); } @@ -60,33 +61,35 @@ const getModelInputOutputCount = (trainingSessionId: number, IsEvalModel: boolea } }; -const getModelInputOutputNamesLoop = (trainingSessionId: number, count: number, isInput: boolean, IsEvalModel:boolean): [string[], number[]] => { - const names = []; - const wasm = getInstance(); +const getModelInputOutputNamesLoop = + (trainingSessionId: number, count: number, isInput: boolean, isEvalModel: boolean): [string[], number[]] => { + const names = []; + const wasm = getInstance(); - const namesUTF8Encoded = []; + const namesUTF8Encoded = []; - for (let i = 0; i < count; i++) { - if (wasm._OrtTrainingGetInputOutputName) { - const name = wasm._OrtTrainingGetInputOutputName(trainingSessionId, i, isInput, false); - if (name === 0) { - checkLastError('Can\'t get input or output name'); - } + for (let i = 0; i < count; i++) { + if (wasm._OrtTrainingGetModelInputOutputName) { + const name = wasm._OrtTrainingGetModelInputOutputName(trainingSessionId, i, isInput, isEvalModel); + if (name === 0) { + checkLastError('Can\'t get input or output name'); + } - namesUTF8Encoded.push(name); - names.push(wasm.UTF8ToString(name)); - } else { - throw new Error(NO_TRAIN_FUNCS_MSG); - } - } - return [names, namesUTF8Encoded]; -}; + namesUTF8Encoded.push(name); + names.push(wasm.UTF8ToString(name)); + } else { + throw new Error(NO_TRAIN_FUNCS_MSG); + } + } + return [names, namesUTF8Encoded]; + }; const getTrainingModelInputOutputNames = (trainingSessionId: number): [string[], number[], string[], number[]] => { - const [inputCount, outputCount] = getTrainingModelInputOutputCount(trainingSessionId); + const [inputCount, outputCount] = getModelInputOutputCount(trainingSessionId, false); - const [inputNames, inputNamesUTF8Encoded] = getTrainingNamesLoop(trainingSessionId, inputCount, true); - const [outputNames, outputNamesUTF8Encoded] = getTrainingNamesLoop(trainingSessionId, outputCount, false); + const [inputNames, inputNamesUTF8Encoded] = getModelInputOutputNamesLoop(trainingSessionId, inputCount, true, false); + const [outputNames, outputNamesUTF8Encoded] = + getModelInputOutputNamesLoop(trainingSessionId, outputCount, false, false); return [inputNames, inputNamesUTF8Encoded, outputNames, outputNamesUTF8Encoded]; }; @@ -150,10 +153,10 @@ export const releaseTrainingSessionAndCheckpoint = inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); outputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); - if (wasm._OrtTrainingReleaseCheckpoint) { - wasm._OrtTrainingReleaseCheckpoint(checkpointId); - } if (wasm._OrtTrainingReleaseSession) { wasm._OrtTrainingReleaseSession(sessionId); } + if (wasm._OrtTrainingReleaseCheckpoint) { + wasm._OrtTrainingReleaseCheckpoint(checkpointId); + } }; diff --git a/onnxruntime/wasm/api.cc b/onnxruntime/wasm/api.cc index 8346da48a7ab8..0e58bb4f93f7f 100644 --- a/onnxruntime/wasm/api.cc +++ b/onnxruntime/wasm/api.cc @@ -580,9 +580,9 @@ int EMSCRIPTEN_KEEPALIVE OrtTrainingCopyParametersFromBuffer(ort_training_sessio } int EMSCRIPTEN_KEEPALIVE OrtTrainingGetModelInputOutputCount(ort_training_session_handle_t training_handle, - size_t* input_count, - size_t* output_count, - bool isEvalModel) { + size_t* input_count, + size_t* output_count, + bool isEvalModel) { if (isEvalModel) { RETURN_TRAINING_ERROR_CODE_IF_ERROR(TrainingSessionGetEvalModelInputCount, training_handle, input_count); RETURN_TRAINING_ERROR_CODE_IF_ERROR(TrainingSessionGetEvalModelOutputCount, training_handle, output_count); @@ -594,10 +594,10 @@ int EMSCRIPTEN_KEEPALIVE OrtTrainingGetModelInputOutputCount(ort_training_sessio } } -char* EMSCRIPTEN_KEEPALIVE OrtTrainingGetInputOutputName(ort_training_session_handle_t training_handle, - size_t index, - bool isInput, - bool isEvalModel) { +char* EMSCRIPTEN_KEEPALIVE OrtTrainingGetModelInputOutputName(ort_training_session_handle_t training_handle, + size_t index, + bool isInput, + bool isEvalModel) { OrtAllocator* allocator = nullptr; RETURN_NULLPTR_IF_ERROR(GetAllocatorWithDefaultOptions, &allocator); diff --git a/onnxruntime/wasm/api.h b/onnxruntime/wasm/api.h index d7bc84c0f00bd..2cd1515d191c8 100644 --- a/onnxruntime/wasm/api.h +++ b/onnxruntime/wasm/api.h @@ -441,10 +441,10 @@ int EMSCRIPTEN_KEEPALIVE OrtTrainingCopyParametersFromBuffer(ort_training_sessio * count of the eval model. * @returns ORT error code. If not zero, call OrtGetLastError() to get a detailed error message. */ -int EMSCRIPTEN_KEEPALIVE OrtTrainingGetInputOutputCount(ort_training_session_handle_t training_handle, - size_t* input_count, - size_t* output_count, - bool isEvalModel); +int EMSCRIPTEN_KEEPALIVE OrtTrainingGetModelInputOutputCount(ort_training_session_handle_t training_handle, + size_t* input_count, + size_t* output_count, + bool isEvalModel); /** * Gets the input or output name at the specified index associated with the training or eval model from the @@ -456,10 +456,10 @@ int EMSCRIPTEN_KEEPALIVE OrtTrainingGetInputOutputCount(ort_training_session_han * names of the eval model. * @returns a pointer to a buffer which contains C-style string. Caller must release the C style string after use by */ -char* EMSCRIPTEN_KEEPALIVE OrtTrainingGetInputOutputName(ort_training_session_handle_t training_handle, - size_t index, - bool isInput, - bool isEvalModel); +char* EMSCRIPTEN_KEEPALIVE OrtTrainingGetModelInputOutputName(ort_training_session_handle_t training_handle, + size_t index, + bool isInput, + bool isEvalModel); /** * @brief Release the specified ORT training session.