From cff6a7f11e5a687b359b3d5c28063afb37fa81dc Mon Sep 17 00:00:00 2001 From: carzh Date: Tue, 29 Aug 2023 21:50:18 +0000 Subject: [PATCH 01/18] added training build configuration --- js/web/lib/build-def.d.ts | 4 ++++ js/web/lib/wasm/wasm-factory.ts | 27 ++++++++++++++++++++------- js/web/script/build.ts | 4 ++++ js/web/webpack.config.js | 10 ++++++++++ 4 files changed, 38 insertions(+), 7 deletions(-) diff --git a/js/web/lib/build-def.d.ts b/js/web/lib/build-def.d.ts index 2049b2663ead3..939486926f319 100644 --- a/js/web/lib/build-def.d.ts +++ b/js/web/lib/build-def.d.ts @@ -30,6 +30,10 @@ interface BuildDefinitions { * defines whether to disable multi-threading feature in WebAssembly backend in the build. */ DISABLE_WASM_THREAD: boolean; + /** + * defines whether to enable training APIs in WebAssembly backend. + */ + ENABLE_TRAINING: boolean; } declare let BUILD_DEFS: BuildDefinitions; diff --git a/js/web/lib/wasm/wasm-factory.ts b/js/web/lib/wasm/wasm-factory.ts index 7648f0c473f07..4ccb06e281b9d 100644 --- a/js/web/lib/wasm/wasm-factory.ts +++ b/js/web/lib/wasm/wasm-factory.ts @@ -8,8 +8,14 @@ import {OrtWasmModule} from './binding/ort-wasm'; import {OrtWasmThreadedModule} from './binding/ort-wasm-threaded'; /* eslint-disable @typescript-eslint/no-require-imports */ -const ortWasmFactory: EmscriptenModuleFactory = - BUILD_DEFS.DISABLE_WEBGPU ? require('./binding/ort-wasm.js') : require('./binding/ort-wasm-simd.jsep.js'); +let ortWasmFactory: EmscriptenModuleFactory; + +if (BUILD_DEFS.ENABLE_TRAINING) { + ortWasmFactory = require('./binding/ort-training-wasm-simd.js'); +} +else { + ortWasmFactory = BUILD_DEFS.DISABLE_WEBGPU ? require('./binding/ort-wasm.js') : require('./binding/ort-wasm-simd.jsep.js'); +} const ortWasmFactoryThreaded: EmscriptenModuleFactory = !BUILD_DEFS.DISABLE_WASM_THREAD ? (BUILD_DEFS.DISABLE_WEBGPU ? 
require('./binding/ort-wasm-threaded.js') : @@ -71,12 +77,19 @@ const isSimdSupported = (): boolean => { } }; -const getWasmFileName = (useSimd: boolean, useThreads: boolean) => { +const getWasmFileName = (useSimd: boolean, useThreads: boolean, useTraining: boolean) => { + let wasmArtifact : string = 'ort'; + if (useTraining) { + wasmArtifact += '-training'; + } + wasmArtifact += '-wasm'; + if (useSimd) { + wasmArtifact += '-simd'; + } if (useThreads) { - return useSimd ? 'ort-wasm-simd-threaded.wasm' : 'ort-wasm-threaded.wasm'; - } else { - return useSimd ? 'ort-wasm-simd.wasm' : 'ort-wasm.wasm'; + wasmArtifact += '-threaded'; } + return wasmArtifact + '.wasm'; }; export const initializeWebAssembly = async(flags: Env.WebAssemblyFlags): Promise => { @@ -102,7 +115,7 @@ export const initializeWebAssembly = async(flags: Env.WebAssemblyFlags): Promise const wasmPaths = flags.wasmPaths; const wasmPrefixOverride = typeof wasmPaths === 'string' ? wasmPaths : undefined; - const wasmFileName = getWasmFileName(useSimd, useThreads); + const wasmFileName = getWasmFileName(useSimd, useThreads, BUILD_DEFS.ENABLE_TRAINING); const wasmPathOverride = typeof wasmPaths === 'object' ? 
wasmPaths[wasmFileName] : undefined; let isTimeout = false; diff --git a/js/web/script/build.ts b/js/web/script/build.ts index 03510ae86b85f..bda7da0ab2a8b 100644 --- a/js/web/script/build.ts +++ b/js/web/script/build.ts @@ -33,6 +33,7 @@ const FILTER = args.f || args.filter; const ROOT_FOLDER = path.join(__dirname, '..'); const WASM_BINDING_FOLDER = path.join(ROOT_FOLDER, 'lib', 'wasm', 'binding'); const WASM_BINDING_JS_PATH = path.join(WASM_BINDING_FOLDER, 'ort-wasm.js'); +const TRAINING_WASM_BINDING_SIMD_JS_PATH = path.join(WASM_BINDING_FOLDER, 'ort-training-wasm-simd.js'); const WASM_BINDING_THREADED_JS_PATH = path.join(WASM_BINDING_FOLDER, 'ort-wasm-threaded.js'); const WASM_BINDING_SIMD_THREADED_JSEP_JS_PATH = path.join(WASM_BINDING_FOLDER, 'ort-wasm-simd-threaded.jsep.js'); const WASM_BINDING_THREADED_WORKER_JS_PATH = path.join(WASM_BINDING_FOLDER, 'ort-wasm-threaded.worker.js'); @@ -45,6 +46,7 @@ const WASM_DIST_FOLDER = path.join(ROOT_FOLDER, 'dist'); const WASM_WASM_PATH = path.join(WASM_DIST_FOLDER, 'ort-wasm.wasm'); const WASM_THREADED_WASM_PATH = path.join(WASM_DIST_FOLDER, 'ort-wasm-threaded.wasm'); const WASM_SIMD_WASM_PATH = path.join(WASM_DIST_FOLDER, 'ort-wasm-simd.wasm'); +const TRAINING_WASM_SIMD_WASM_PATH = path.join(WASM_DIST_FOLDER, 'ort-training-wasm-simd.wasm'); const WASM_SIMD_THREADED_WASM_PATH = path.join(WASM_DIST_FOLDER, 'ort-wasm-simd-threaded.wasm'); const WASM_SIMD_JSEP_WASM_PATH = path.join(WASM_DIST_FOLDER, 'ort-wasm-simd.jsep.wasm'); const WASM_SIMD_THREADED_JSEP_WASM_PATH = path.join(WASM_DIST_FOLDER, 'ort-wasm-simd-threaded.jsep.wasm'); @@ -75,6 +77,8 @@ if (WASM) { validateFile(WASM_SIMD_THREADED_WASM_PATH); validateFile(WASM_SIMD_JSEP_WASM_PATH); validateFile(WASM_SIMD_THREADED_JSEP_WASM_PATH); + validateFile(TRAINING_WASM_BINDING_SIMD_JS_PATH); + validateFile(TRAINING_WASM_SIMD_WASM_PATH); } catch (e) { npmlog.error('Build', `WebAssembly files are not ready. build WASM first. 
ERR: ${e}`); throw e; diff --git a/js/web/webpack.config.js b/js/web/webpack.config.js index 81c69ffdcf6bf..7115a3e3b4258 100644 --- a/js/web/webpack.config.js +++ b/js/web/webpack.config.js @@ -57,6 +57,7 @@ const DEFAULT_BUILD_DEFS = { DISABLE_WASM: false, DISABLE_WASM_PROXY: false, DISABLE_WASM_THREAD: false, + ENABLE_TRAINING: false }; // common config for release bundle @@ -289,6 +290,15 @@ module.exports = () => { buildOrtWebConfig({suffix: '.es6.min', target: 'es6'}), // ort-web.es5.min.js buildOrtWebConfig({suffix: '.es5.min', target: 'es5'}), + + // ort.wasm.min.js + buildOrtConfig({ + suffix: '.training.wasm.min', + build_defs: { + ...DEFAULT_BUILD_DEFS, + ENABLE_TRAINING: true, + } + }), ); case 'node': From 7a1365d93272567feda3a6682ad6a2cb79422898 Mon Sep 17 00:00:00 2001 From: carzh Date: Tue, 29 Aug 2023 22:13:34 +0000 Subject: [PATCH 02/18] edited wasm factory + added training backend --- js/web/lib/backend-wasm-with-training.ts | 17 +++++++++++++++++ js/web/lib/backend-wasm.ts | 2 +- js/web/lib/index.ts | 11 +++++++---- js/web/lib/wasm/wasm-factory.ts | 22 +++++++++------------- 4 files changed, 34 insertions(+), 18 deletions(-) create mode 100644 js/web/lib/backend-wasm-with-training.ts diff --git a/js/web/lib/backend-wasm-with-training.ts b/js/web/lib/backend-wasm-with-training.ts new file mode 100644 index 0000000000000..6d31861fb1bcd --- /dev/null +++ b/js/web/lib/backend-wasm-with-training.ts @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +import {InferenceSession, TrainingSessionHandler} from 'onnxruntime-common'; + +import {OnnxruntimeWebAssemblyBackend} from './backend-wasm' + +class OnnxruntimeTrainingWebAssemblyBackend extends OnnxruntimeWebAssemblyBackend { + async createTrainingSessionHandler( + checkpointStateUriOrBuffer: string|Uint8Array, trainModelUriOrBuffer: string|Uint8Array, + evalModelUriOrBuffer: string|Uint8Array, optimizerModelUriOrBuffer: string|Uint8Array, + options: InferenceSession.SessionOptions): Promise { + throw new Error('Method not implemented yet.'); + } +} + +export const wasmBackend = new OnnxruntimeTrainingWebAssemblyBackend(); diff --git a/js/web/lib/backend-wasm.ts b/js/web/lib/backend-wasm.ts index 04108c2ad0f66..3f06eea5329a0 100644 --- a/js/web/lib/backend-wasm.ts +++ b/js/web/lib/backend-wasm.ts @@ -32,7 +32,7 @@ export const initializeFlags = (): void => { } }; -class OnnxruntimeWebAssemblyBackend implements Backend { +export class OnnxruntimeWebAssemblyBackend implements Backend { async init(): Promise { // populate wasm flags initializeFlags(); diff --git a/js/web/lib/index.ts b/js/web/lib/index.ts index d5ed536034f3e..03e1108b31d12 100644 --- a/js/web/lib/index.ts +++ b/js/web/lib/index.ts @@ -7,7 +7,7 @@ // So we import code inside the if-clause to allow terser remove the code safely. export * from 'onnxruntime-common'; -import {registerBackend, env} from 'onnxruntime-common'; +import {registerBackend, env, Backend} from 'onnxruntime-common'; import {version} from './version'; if (!BUILD_DEFS.DISABLE_WEBGL) { @@ -16,14 +16,17 @@ if (!BUILD_DEFS.DISABLE_WEBGL) { } if (!BUILD_DEFS.DISABLE_WASM) { - const wasmBackend = require('./backend-wasm').wasmBackend; + const wasmBackend = BUILD_DEFS.ENABLE_TRAINING ? 
require('./backend-wasm-with-training').wasmBackend : + require('./backend-wasm').wasmBackend; if (!BUILD_DEFS.DISABLE_WEBGPU && typeof navigator !== 'undefined' && navigator.gpu) { registerBackend('webgpu', wasmBackend, 5); } registerBackend('cpu', wasmBackend, 10); registerBackend('wasm', wasmBackend, 10); - registerBackend('xnnpack', wasmBackend, 9); - registerBackend('webnn', wasmBackend, 9); + if (!BUILD_DEFS.ENABLE_TRAINING) { + registerBackend('xnnpack', wasmBackend, 9); + registerBackend('webnn', wasmBackend, 9); + } } Object.defineProperty(env.versions, 'web', {value: version, enumerable: true}); diff --git a/js/web/lib/wasm/wasm-factory.ts b/js/web/lib/wasm/wasm-factory.ts index 4ccb06e281b9d..1ce06b0f63b88 100644 --- a/js/web/lib/wasm/wasm-factory.ts +++ b/js/web/lib/wasm/wasm-factory.ts @@ -12,9 +12,9 @@ let ortWasmFactory: EmscriptenModuleFactory; if (BUILD_DEFS.ENABLE_TRAINING) { ortWasmFactory = require('./binding/ort-training-wasm-simd.js'); -} -else { - ortWasmFactory = BUILD_DEFS.DISABLE_WEBGPU ? require('./binding/ort-wasm.js') : require('./binding/ort-wasm-simd.jsep.js'); +} else { + ortWasmFactory = + BUILD_DEFS.DISABLE_WEBGPU ? require('./binding/ort-wasm.js') : require('./binding/ort-wasm-simd.jsep.js'); } const ortWasmFactoryThreaded: EmscriptenModuleFactory = !BUILD_DEFS.DISABLE_WASM_THREAD ? @@ -78,18 +78,14 @@ const isSimdSupported = (): boolean => { }; const getWasmFileName = (useSimd: boolean, useThreads: boolean, useTraining: boolean) => { - let wasmArtifact : string = 'ort'; - if (useTraining) { - wasmArtifact += '-training'; - } - wasmArtifact += '-wasm'; if (useSimd) { - wasmArtifact += '-simd'; - } - if (useThreads) { - wasmArtifact += '-threaded'; + if (useTraining) { + return 'ort-training-wasm-simd.wasm'; + } + return useThreads ? 'ort-wasm-simd-threaded.wasm' : 'ort-wasm-simd.wasm'; + } else { + return useThreads ? 
'ort-wasm-threaded.wasm' : 'ort-wasm.wasm'; } - return wasmArtifact + '.wasm'; }; export const initializeWebAssembly = async(flags: Env.WebAssemblyFlags): Promise => { From 71f9dbc94735908d4093e4fa114b62867e29306c Mon Sep 17 00:00:00 2001 From: carzh Date: Thu, 31 Aug 2023 21:02:16 +0000 Subject: [PATCH 03/18] added package.json modification for training + minimized training artifact --- js/web/lib/backend-wasm-with-training.ts | 4 +++- js/web/lib/index.ts | 2 +- js/web/package.json | 4 ++++ js/web/script/build.ts | 26 ++++++++++++++++++++++++ js/web/webpack.config.js | 5 +++-- 5 files changed, 37 insertions(+), 4 deletions(-) diff --git a/js/web/lib/backend-wasm-with-training.ts b/js/web/lib/backend-wasm-with-training.ts index 6d31861fb1bcd..5fac10da7791d 100644 --- a/js/web/lib/backend-wasm-with-training.ts +++ b/js/web/lib/backend-wasm-with-training.ts @@ -3,15 +3,17 @@ import {InferenceSession, TrainingSessionHandler} from 'onnxruntime-common'; -import {OnnxruntimeWebAssemblyBackend} from './backend-wasm' +import {OnnxruntimeWebAssemblyBackend} from './backend-wasm'; class OnnxruntimeTrainingWebAssemblyBackend extends OnnxruntimeWebAssemblyBackend { + /* eslint-disable @typescript-eslint/no-unused-vars */ async createTrainingSessionHandler( checkpointStateUriOrBuffer: string|Uint8Array, trainModelUriOrBuffer: string|Uint8Array, evalModelUriOrBuffer: string|Uint8Array, optimizerModelUriOrBuffer: string|Uint8Array, options: InferenceSession.SessionOptions): Promise { throw new Error('Method not implemented yet.'); } + /* eslint-enable @typescript-eslint/no-unused-vars */ } export const wasmBackend = new OnnxruntimeTrainingWebAssemblyBackend(); diff --git a/js/web/lib/index.ts b/js/web/lib/index.ts index 03e1108b31d12..30d0af5dc9e47 100644 --- a/js/web/lib/index.ts +++ b/js/web/lib/index.ts @@ -7,7 +7,7 @@ // So we import code inside the if-clause to allow terser remove the code safely. 
export * from 'onnxruntime-common'; -import {registerBackend, env, Backend} from 'onnxruntime-common'; +import {registerBackend, env} from 'onnxruntime-common'; import {version} from './version'; if (!BUILD_DEFS.DISABLE_WEBGL) { diff --git a/js/web/package.json b/js/web/package.json index 8ae5b733e5f21..798434bf5d574 100644 --- a/js/web/package.json +++ b/js/web/package.json @@ -78,6 +78,10 @@ "./webgpu": { "types": "./types.d.ts", "default": "./dist/ort.webgpu.min.js" + }, + "./training": { + "types": "./types.d.ts", + "default": "./dist/ort-web.training.wasm.min.js" } }, "types": "./types.d.ts", diff --git a/js/web/script/build.ts b/js/web/script/build.ts index bda7da0ab2a8b..ff94d1a806a09 100644 --- a/js/web/script/build.ts +++ b/js/web/script/build.ts @@ -34,6 +34,7 @@ const ROOT_FOLDER = path.join(__dirname, '..'); const WASM_BINDING_FOLDER = path.join(ROOT_FOLDER, 'lib', 'wasm', 'binding'); const WASM_BINDING_JS_PATH = path.join(WASM_BINDING_FOLDER, 'ort-wasm.js'); const TRAINING_WASM_BINDING_SIMD_JS_PATH = path.join(WASM_BINDING_FOLDER, 'ort-training-wasm-simd.js'); +const TRAINING_WASM_BINDING_SIMD_MIN_JS_PATH = path.join(WASM_BINDING_FOLDER, 'ort-training-wasm-simd.min.js'); const WASM_BINDING_THREADED_JS_PATH = path.join(WASM_BINDING_FOLDER, 'ort-wasm-threaded.js'); const WASM_BINDING_SIMD_THREADED_JSEP_JS_PATH = path.join(WASM_BINDING_FOLDER, 'ort-wasm-simd-threaded.jsep.js'); const WASM_BINDING_THREADED_WORKER_JS_PATH = path.join(WASM_BINDING_FOLDER, 'ort-wasm-threaded.worker.js'); @@ -93,6 +94,31 @@ if (WASM) { */ `; + npmlog.info('Build', 'Minimizing file "ort-training-wasm-simd.js"...'); + try { + const terser = spawnSync( + 'npx', + [ + 'terser', TRAINING_WASM_BINDING_SIMD_JS_PATH, '--compress', 'passes=2', '--format', 'comments=false', + '--mangle', 'reserved=[_scriptDir]', '--module' + ], + {shell: true, encoding: 'utf-8', cwd: ROOT_FOLDER}); + if (terser.status !== 0) { + console.error(terser.error); + process.exit(terser.status === null ? 
undefined : terser.status); + } + + fs.writeFileSync(TRAINING_WASM_BINDING_SIMD_MIN_JS_PATH, terser.stdout); + fs.writeFileSync(TRAINING_WASM_BINDING_SIMD_JS_PATH, `${COPYRIGHT_BANNER}${terser.stdout}`); + + validateFile(TRAINING_WASM_BINDING_SIMD_MIN_JS_PATH); + validateFile(TRAINING_WASM_BINDING_SIMD_JS_PATH); + } catch (e) { + npmlog.error('Build', `Failed to run terser on ort-training-wasm-simd.js. ERR: ${e}`); + throw e; + } + npmlog.info('Build', 'Minimizing file "ort-training-wasm-simd.js"... DONE'); + npmlog.info('Build', 'Minimizing file "ort-wasm-threaded.js"...'); try { const terser = spawnSync( diff --git a/js/web/webpack.config.js b/js/web/webpack.config.js index 7115a3e3b4258..1ac421b5246d9 100644 --- a/js/web/webpack.config.js +++ b/js/web/webpack.config.js @@ -104,6 +104,7 @@ function buildConfig({filename, format, target, mode, devtool, build_defs}) { config.resolve.alias['./binding/ort-wasm-threaded.js'] = './binding/ort-wasm-threaded.min.js'; config.resolve.alias['./binding/ort-wasm-threaded-simd.jsep.js'] = './binding/ort-wasm-threaded-simd.jsep.min.js'; config.resolve.alias['./binding/ort-wasm-threaded.worker.js'] = './binding/ort-wasm-threaded.min.worker.js'; + config.resolve.alias['./binding/ort-training-wasm-simd.js'] = './binding/ort-training-wasm-simd.min.js'; const options = defaultTerserPluginOptions(target); options.terserOptions.format.preamble = COPYRIGHT_BANNER; @@ -291,8 +292,8 @@ module.exports = () => { // ort-web.es5.min.js buildOrtWebConfig({suffix: '.es5.min', target: 'es5'}), - // ort.wasm.min.js - buildOrtConfig({ + // ort-web.training.wasm.min.js + buildOrtWebConfig({ suffix: '.training.wasm.min', build_defs: { ...DEFAULT_BUILD_DEFS, From 5adbcf407e5c5b2c2797ff9568b3a4d80ccf2a74 Mon Sep 17 00:00:00 2001 From: carzh Date: Fri, 22 Sep 2023 21:21:42 +0000 Subject: [PATCH 04/18] applied suggestions --- js/web/lib/backend-wasm-with-training.ts | 8 +++----- js/web/lib/wasm/wasm-factory.ts | 6 +++--- js/web/webpack.config.js | 2 ++ 
3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/js/web/lib/backend-wasm-with-training.ts b/js/web/lib/backend-wasm-with-training.ts index 5fac10da7791d..af5b575c87a7f 100644 --- a/js/web/lib/backend-wasm-with-training.ts +++ b/js/web/lib/backend-wasm-with-training.ts @@ -6,14 +6,12 @@ import {InferenceSession, TrainingSessionHandler} from 'onnxruntime-common'; import {OnnxruntimeWebAssemblyBackend} from './backend-wasm'; class OnnxruntimeTrainingWebAssemblyBackend extends OnnxruntimeWebAssemblyBackend { - /* eslint-disable @typescript-eslint/no-unused-vars */ async createTrainingSessionHandler( - checkpointStateUriOrBuffer: string|Uint8Array, trainModelUriOrBuffer: string|Uint8Array, - evalModelUriOrBuffer: string|Uint8Array, optimizerModelUriOrBuffer: string|Uint8Array, - options: InferenceSession.SessionOptions): Promise { + _checkpointStateUriOrBuffer: string|Uint8Array, _trainModelUriOrBuffer: string|Uint8Array, + _evalModelUriOrBuffer: string|Uint8Array, _optimizerModelUriOrBuffer: string|Uint8Array, + _options: InferenceSession.SessionOptions): Promise { throw new Error('Method not implemented yet.'); } - /* eslint-enable @typescript-eslint/no-unused-vars */ } export const wasmBackend = new OnnxruntimeTrainingWebAssemblyBackend(); diff --git a/js/web/lib/wasm/wasm-factory.ts b/js/web/lib/wasm/wasm-factory.ts index 1ce06b0f63b88..08192b4174272 100644 --- a/js/web/lib/wasm/wasm-factory.ts +++ b/js/web/lib/wasm/wasm-factory.ts @@ -77,9 +77,9 @@ const isSimdSupported = (): boolean => { } }; -const getWasmFileName = (useSimd: boolean, useThreads: boolean, useTraining: boolean) => { +const getWasmFileName = (useSimd: boolean, useThreads: boolean) => { if (useSimd) { - if (useTraining) { + if (BUILD_DEFS.ENABLE_TRAINING) { return 'ort-training-wasm-simd.wasm'; } return useThreads ? 
'ort-wasm-simd-threaded.wasm' : 'ort-wasm-simd.wasm'; @@ -111,7 +111,7 @@ export const initializeWebAssembly = async(flags: Env.WebAssemblyFlags): Promise const wasmPaths = flags.wasmPaths; const wasmPrefixOverride = typeof wasmPaths === 'string' ? wasmPaths : undefined; - const wasmFileName = getWasmFileName(useSimd, useThreads, BUILD_DEFS.ENABLE_TRAINING); + const wasmFileName = getWasmFileName(useSimd, useThreads); const wasmPathOverride = typeof wasmPaths === 'object' ? wasmPaths[wasmFileName] : undefined; let isTimeout = false; diff --git a/js/web/webpack.config.js b/js/web/webpack.config.js index 1ac421b5246d9..28d30604dc8d3 100644 --- a/js/web/webpack.config.js +++ b/js/web/webpack.config.js @@ -298,6 +298,8 @@ module.exports = () => { build_defs: { ...DEFAULT_BUILD_DEFS, ENABLE_TRAINING: true, + DISABLE_WEBGL: true, + DISABLE_WEBGPU: true, } }), ); From 9068ab41adb8855c3eba4cce0e21afd17b5ff06b Mon Sep 17 00:00:00 2001 From: carzh Date: Fri, 6 Oct 2023 14:29:30 -0700 Subject: [PATCH 05/18] create training session implementation + supporting changes --- js/common/lib/training-session-impl.ts | 24 ++- js/web/lib/backend-wasm-with-training.ts | 12 +- js/web/lib/wasm/binding/ort-wasm.d.ts | 3 + js/web/lib/wasm/proxy-messages.ts | 7 +- js/web/lib/wasm/proxy-wrapper.ts | 14 ++ .../lib/wasm/session-handler-for-training.ts | 72 ++++++++ js/web/lib/wasm/session-handler.ts | 6 +- js/web/lib/wasm/wasm-core-impl.ts | 7 + js/web/lib/wasm/wasm-training-core-impl.ts | 162 ++++++++++++++++++ onnxruntime/wasm/api.cc | 37 ++++ onnxruntime/wasm/api.h | 23 +++ 11 files changed, 356 insertions(+), 11 deletions(-) create mode 100644 js/web/lib/wasm/session-handler-for-training.ts create mode 100644 js/web/lib/wasm/wasm-training-core-impl.ts diff --git a/js/common/lib/training-session-impl.ts b/js/common/lib/training-session-impl.ts index f06d06bda035f..bde1fb0c373ea 100644 --- a/js/common/lib/training-session-impl.ts +++ b/js/common/lib/training-session-impl.ts @@ -4,8 +4,11 @@ 
import {TrainingSessionHandler} from './backend.js'; import {InferenceSession as InferenceSession} from './inference-session.js'; import {TrainingSession as TrainingSessionInterface, TrainingSessionCreateOptions} from './training-session.js'; +import { resolveBackend } from './backend-impl.js'; type SessionOptions = InferenceSession.SessionOptions; +const noBackendErrMsg: string = "Training backend could not be resolved. " + + "Make sure you\'re using the correct configuration & WebAssembly files."; export class TrainingSession implements TrainingSessionInterface { private constructor(handler: TrainingSessionHandler) { @@ -20,9 +23,26 @@ export class TrainingSession implements TrainingSessionInterface { return this.handler.outputNames; } - static async create(_trainingOptions: TrainingSessionCreateOptions, _sessionOptions?: SessionOptions): + static async create(trainingOptions: TrainingSessionCreateOptions,sessionOptions?: SessionOptions): Promise { - throw new Error('Method not implemented'); + let checkpointState: string|Uint8Array = trainingOptions.checkpointState; + let trainModel: string|Uint8Array = trainingOptions.trainModel; + let evalModel: string|Uint8Array = trainingOptions.evalModel ? trainingOptions.evalModel : ''; + let optimizerModel: string|Uint8Array = trainingOptions.optimizerModel ? trainingOptions.optimizerModel : ''; + let options: SessionOptions = sessionOptions ? sessionOptions : {}; + + // get backend hints + const eps = options.executionProviders || []; + const backendHints = eps.map(i => typeof i === 'string' ? 
i : i.name); + const backend = await resolveBackend(backendHints); + if (backend.createTrainingSessionHandler) { + const handler = + await backend.createTrainingSessionHandler(checkpointState, trainModel, evalModel, optimizerModel, options); + return new TrainingSession(handler); + } + else { + throw new Error(noBackendErrMsg); + } } async loadParametersBuffer(_array: Uint8Array, _trainableOnly: boolean): Promise { diff --git a/js/web/lib/backend-wasm-with-training.ts b/js/web/lib/backend-wasm-with-training.ts index af5b575c87a7f..1796259d0f335 100644 --- a/js/web/lib/backend-wasm-with-training.ts +++ b/js/web/lib/backend-wasm-with-training.ts @@ -4,13 +4,17 @@ import {InferenceSession, TrainingSessionHandler} from 'onnxruntime-common'; import {OnnxruntimeWebAssemblyBackend} from './backend-wasm'; +import {OnnxruntimeWebAssemblyTrainingSessionHandler} from './wasm/session-handler-for-training'; class OnnxruntimeTrainingWebAssemblyBackend extends OnnxruntimeWebAssemblyBackend { async createTrainingSessionHandler( - _checkpointStateUriOrBuffer: string|Uint8Array, _trainModelUriOrBuffer: string|Uint8Array, - _evalModelUriOrBuffer: string|Uint8Array, _optimizerModelUriOrBuffer: string|Uint8Array, - _options: InferenceSession.SessionOptions): Promise { - throw new Error('Method not implemented yet.'); + checkpointStateUriOrBuffer: string|Uint8Array, trainModelUriOrBuffer: string|Uint8Array, + evalModelUriOrBuffer: string|Uint8Array, optimizerModelUriOrBuffer: string|Uint8Array, + options: InferenceSession.SessionOptions): Promise { + const handler = new OnnxruntimeWebAssemblyTrainingSessionHandler(); + await handler.createTrainingSession( + checkpointStateUriOrBuffer, trainModelUriOrBuffer, evalModelUriOrBuffer, optimizerModelUriOrBuffer, options); + return Promise.resolve(handler); } } diff --git a/js/web/lib/wasm/binding/ort-wasm.d.ts b/js/web/lib/wasm/binding/ort-wasm.d.ts index b7b2ff4537095..c2b2bdf628d85 100644 --- a/js/web/lib/wasm/binding/ort-wasm.d.ts +++ 
b/js/web/lib/wasm/binding/ort-wasm.d.ts @@ -102,6 +102,9 @@ export interface OrtWasmModule extends EmscriptenModule { _OrtTrainingCopyParametersFromBuffer? (trainingHandle: number, parametersBuffer: number, parameterCount: number, trainableOnly: boolean): number; + _OrtTrainingGetInputOutputCount?(sessionHandle: number, inputCountOffset: number, outputCountOffset: number): number; + _OrtTrainingGetInputOutputName?(sessionHandle: number, index: number, isInput: boolean): number; + _OrtTrainingReleaseSession?(trainingHandle: number): void; // #endregion diff --git a/js/web/lib/wasm/proxy-messages.ts b/js/web/lib/wasm/proxy-messages.ts index 43f70c23f7193..99c68076e62f7 100644 --- a/js/web/lib/wasm/proxy-messages.ts +++ b/js/web/lib/wasm/proxy-messages.ts @@ -73,5 +73,10 @@ interface MesssageEndProfiling extends MessageError { in ?: number; } +interface MessageIsOrtEnvInitialized extends MessageError { + type: 'is-ort-env-initialized'; + out?: boolean; +} + export type OrtWasmMessage = MessageInitWasm|MessageInitOrt|MessageCreateSessionAllocate|MessageCreateSessionFinalize| - MessageCreateSession|MessageReleaseSession|MessageRun|MesssageEndProfiling; + MessageCreateSession|MessageReleaseSession|MessageRun|MesssageEndProfiling|MessageIsOrtEnvInitialized; diff --git a/js/web/lib/wasm/proxy-wrapper.ts b/js/web/lib/wasm/proxy-wrapper.ts index 202209ed3bfed..881f905b803c3 100644 --- a/js/web/lib/wasm/proxy-wrapper.ts +++ b/js/web/lib/wasm/proxy-wrapper.ts @@ -24,6 +24,7 @@ const createSessionCallbacks: Array> = []; const runCallbacks: Array> = []; const endProfilingCallbacks: Array> = []; +const isOrtEnvInitializedCallbacks: Array> = []; const ensureWorker = (): void => { if (initializing || !initialized || aborted || !proxyWorker) { @@ -242,3 +243,16 @@ export const endProfiling = async(sessionId: number): Promise => { core.endProfiling(sessionId); } }; + +export const isOrtEnvInitialized = async(): Promise => { + if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) { + 
ensureWorker(); + return new Promise((resolve, reject) => { + isOrtEnvInitializedCallbacks.push([resolve, reject]); + const message: OrtWasmMessage = {type: 'is-ort-env-initialized'}; + proxyWorker!.postMessage(message); + }); + } else { + return core.isOrtEnvInitialized(); + } +} diff --git a/js/web/lib/wasm/session-handler-for-training.ts b/js/web/lib/wasm/session-handler-for-training.ts new file mode 100644 index 0000000000000..5f694371919fc --- /dev/null +++ b/js/web/lib/wasm/session-handler-for-training.ts @@ -0,0 +1,72 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {env, InferenceSession, SessionHandler, TrainingSessionHandler} from 'onnxruntime-common'; + +import {SerializableModeldata} from './proxy-messages'; +import {createSessionAllocate, initRuntime, isOrtEnvInitialized} from './wasm-core-impl'; +import {createCheckpointHandle, createTrainingSessionHandle, releaseTrainingSessionAndCheckpoint} from './wasm-training-core-impl'; + +export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSessionHandler { + loadParametersBuffer(array: Uint8Array, trainableOnly: boolean): Promise { + throw new Error('Method not implemented.'); + } + getContiguousParameters(trainableOnly: boolean): Promise { + throw new Error('Method not implemented.'); + } + private sessionId: number; + private checkpointId: number; + + inputNames: string[]; + outputNames: string[]; + + inputEncodedNames: number[]; + outputEncodedNames: number[]; + + async uriOrBufferToHeap(uriOrBuffer: string|Uint8Array): Promise { + let buffer: Uint8Array; + if (typeof uriOrBuffer === 'string') { + const response = await fetch(uriOrBuffer); + const arrayBuffer = await response.arrayBuffer(); + buffer = new Uint8Array(arrayBuffer); + } else { + buffer = uriOrBuffer; + } + return createSessionAllocate(buffer); + } + + async createTrainingSession( + checkpointStateUriOrBuffer: string|Uint8Array, trainModelUriOrBuffer: 
string|Uint8Array, + evalModelUriOrBuffer: string|Uint8Array, optimizerModelUriOrBuffer: string|Uint8Array, + options: InferenceSession.SessionOptions) { + if (!isOrtEnvInitialized()) { + await initRuntime(env); + } + let checkpointData: SerializableModeldata = await this.uriOrBufferToHeap(checkpointStateUriOrBuffer); + let trainModelData: SerializableModeldata = await this.uriOrBufferToHeap(trainModelUriOrBuffer); + // 0 is supposed to be the nullptr + let evalModelData: SerializableModeldata = [0, 0]; + let optimizerModelData: SerializableModeldata = [0, 0]; + + if (evalModelUriOrBuffer !== '') { + evalModelData = await this.uriOrBufferToHeap(evalModelUriOrBuffer); + } + if (optimizerModelUriOrBuffer !== '') { + optimizerModelData = await this.uriOrBufferToHeap(optimizerModelUriOrBuffer); + } + + this.checkpointId = createCheckpointHandle(checkpointData); + [[this.sessionId, this.inputNames, this.outputNames], this.inputEncodedNames, this.outputEncodedNames] = + createTrainingSessionHandle(this.checkpointId, trainModelData, evalModelData, optimizerModelData, options); + } + + async dispose(): Promise { + return releaseTrainingSessionAndCheckpoint(this.checkpointId, this.sessionId, this.inputEncodedNames, this.outputEncodedNames); + } + + async runTrainStep( + _feeds: SessionHandler.FeedsType, _fetches: SessionHandler.FetchesType, + _options: InferenceSession.RunOptions): Promise { + throw new Error('Method not implemented yet.'); + } +} diff --git a/js/web/lib/wasm/session-handler.ts b/js/web/lib/wasm/session-handler.ts index 7bc467449c33a..baf48ed5b6f57 100644 --- a/js/web/lib/wasm/session-handler.ts +++ b/js/web/lib/wasm/session-handler.ts @@ -6,10 +6,9 @@ import {env, InferenceSession, InferenceSessionHandler, SessionHandler, Tensor} import {promisify} from 'util'; import {SerializableModeldata, TensorMetadata} from './proxy-messages'; -import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, initializeRuntime, releaseSession, run} 
from './proxy-wrapper'; +import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, initializeRuntime, releaseSession, run, isOrtEnvInitialized} from './proxy-wrapper'; import {isGpuBufferSupportedType} from './wasm-common'; -let runtimeInitialized: boolean; let runtimeInitializationPromise: Promise|undefined; const encodeTensorMetadata = (tensor: Tensor, getName: () => string): TensorMetadata => { @@ -58,13 +57,12 @@ export class OnnxruntimeWebAssemblySessionHandler implements InferenceSessionHan } async loadModel(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): Promise { - if (!runtimeInitialized) { + if (!isOrtEnvInitialized()) { if (!runtimeInitializationPromise) { runtimeInitializationPromise = initializeRuntime(env); } await runtimeInitializationPromise; runtimeInitializationPromise = undefined; - runtimeInitialized = true; } if (typeof pathOrBuffer === 'string') { diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index 5b49a1d4202e3..770328934deba 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -57,6 +57,8 @@ export const initRuntime = async(env: Env): Promise => { const initJsep = require('./jsep/init').init; await initJsep(getInstance(), env); } + + ortEnvInitialized = true; }; /** @@ -92,6 +94,11 @@ type SessionMetadata = [ ]; const activeSessions = new Map(); +let ortEnvInitialized = false; + +export const isOrtEnvInitialized = (): boolean => { + return ortEnvInitialized; +}; /** * allocate the memory and memcpy the model bytes, preparing for creating an instance of InferenceSession. diff --git a/js/web/lib/wasm/wasm-training-core-impl.ts b/js/web/lib/wasm/wasm-training-core-impl.ts new file mode 100644 index 0000000000000..9216a5b1f4361 --- /dev/null +++ b/js/web/lib/wasm/wasm-training-core-impl.ts @@ -0,0 +1,162 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+// import {InferenceSession, Tensor} from 'onnxruntime-common'; +import {InferenceSession} from 'onnxruntime-common'; + +import {SerializableModeldata, SerializableSessionMetadata } from './proxy-messages'; +// import {setRunOptions} from './run-options'; +import {setSessionOptions} from './session-options'; +// import {tensorDataTypeEnumToString, tensorDataTypeStringToEnum, tensorTypeToTypedArrayConstructor} from './wasm-common'; +import {getInstance} from './wasm-factory'; +import {checkLastError} from './wasm-utils'; +// import {allocWasmString, checkLastError} from './wasm-utils'; +// import { prepareInputOutputTensor } from './wasm-core-impl'; + +const throwNoTrainingFuncsError = (): void => { + throw TypeError( + 'Built without training APIs enabled. Make sure to use the onnxruntime-training package for training functionality.'); +}; + +export const createCheckpointHandle = (checkpointData: SerializableModeldata): number => { + const wasm = getInstance(); + + let checkpointHandle = 0; + + try { + if (wasm._OrtTrainingLoadCheckpoint) { + checkpointHandle = wasm._OrtTrainingLoadCheckpoint(checkpointData[0], checkpointData[1]); + } else { + throwNoTrainingFuncsError(); + } + + if (checkpointHandle === 0) { + checkLastError('Error occurred when trying to create a CheckpointState.'); + } + return checkpointHandle; + } catch (e) { + if (wasm._OrtTrainingReleaseCheckpoint && checkpointHandle !== 0) { + wasm._OrtTrainingReleaseCheckpoint(checkpointHandle); + } + throw e; + } finally { + // free buffer from wasm heap + wasm._OrtFree(checkpointData[0]); + } +}; + +export const createTrainingSessionHandle = + (checkpointHandle: number, trainModelData: SerializableModeldata, evalModelData: SerializableModeldata, + optimizerModelData: SerializableModeldata, + options: InferenceSession.SessionOptions): [SerializableSessionMetadata, number[], number[]] => { + const wasm = getInstance(); + + let trainingSessionHandle = 0; + let sessionOptionsHandle = 0; + let allocs: 
number[] = []; + let inputNamesUTF8Encoded: number[] = []; + let outputNamesUTF8Encoded: number[] = []; + + let inputNames: string[] = []; + let outputNames: string[] = []; + + try { + [sessionOptionsHandle, allocs] = setSessionOptions(options); + if (wasm._OrtTrainingCreateSession) { + trainingSessionHandle = wasm._OrtTrainingCreateSession( + sessionOptionsHandle, checkpointHandle, trainModelData[0], trainModelData[1], evalModelData[0], + evalModelData[1], optimizerModelData[0], optimizerModelData[1]); + } else { + throwNoTrainingFuncsError(); + } + + if (trainingSessionHandle === 0) { + checkLastError('Error occurred when trying to create a TrainingSession.'); + } + + [inputNames, inputNamesUTF8Encoded, outputNames, outputNamesUTF8Encoded] = + getTrainingModelInputOutputNames(trainingSessionHandle); + + return [[trainingSessionHandle, inputNames, outputNames], inputNamesUTF8Encoded, outputNamesUTF8Encoded]; + } catch (e) { + if (wasm._OrtTrainingReleaseSession && trainingSessionHandle !== 0) { + wasm._OrtTrainingReleaseSession(trainingSessionHandle); + } + throw e; + } finally { + wasm._free(trainModelData[0]); + wasm._free(evalModelData[0]); + wasm._free(optimizerModelData[0]); + + if (sessionOptionsHandle !== 0) { + wasm._OrtReleaseSessionOptions(sessionOptionsHandle); + } + allocs.forEach(alloc => wasm._free(alloc)); + inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); + outputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); + } + }; + + const getTrainingModelInputOutputNames = (trainingSessionId: number): [string[], number[], string[], number[]] => { + const [inputCount, outputCount] = getTrainingModelInputOutputCount(trainingSessionId); + + const [inputNames, inputNamesUTF8Encoded] = getTrainingNamesLoop(trainingSessionId, inputCount, true); + const [outputNames, outputNamesUTF8Encoded] = getTrainingNamesLoop(trainingSessionId, outputCount, false); + + return [inputNames, inputNamesUTF8Encoded, outputNames, outputNamesUTF8Encoded]; +} + + const 
getTrainingModelInputOutputCount = (trainingSessionId: number): [number, number] => { + const wasm = getInstance(); + const stack = wasm.stackSave(); + try { + const dataOffset = wasm.stackAlloc(8); + if (wasm._OrtTrainingGetInputOutputCount) { + const errorCode = wasm._OrtTrainingGetInputOutputCount(trainingSessionId, dataOffset, dataOffset + 4); + if (errorCode !== 0) { + checkLastError('Can\'t get session input/output count.'); + } + return [wasm.HEAP32[dataOffset / 4], wasm.HEAP32[dataOffset / 4 + 1]]; + } else { + throwNoTrainingFuncsError(); + // unreachable code -- placeholder to prevent linting errors + return [0, 0]; + } + } finally { + wasm.stackRestore(stack); + } +} + +const getTrainingNamesLoop = (trainingSessionId: number, count: number, isInput: boolean): [string[], number[]] => { + const names = []; + const wasm = getInstance(); + + const namesUTF8Encoded = []; + + for (let i = 0; i < count; i++) { + if (wasm._OrtTrainingGetInputOutputName) { + const name = wasm._OrtTrainingGetInputOutputName(trainingSessionId, i, isInput); + if (name === 0) { + checkLastError('Can\'t get input or output name'); + } + + namesUTF8Encoded.push(name); + names.push(wasm.UTF8ToString(name)); + } else { + throwNoTrainingFuncsError; + } + } + return [names, namesUTF8Encoded]; +} + +export const releaseTrainingSessionAndCheckpoint = (checkpointId: number, sessionId: number, inputNamesUTF8Encoded: number[], outputNamesUTF8Encoded: number[]): void => { + const wasm = getInstance(); + inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); + outputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); + + if (wasm._OrtTrainingReleaseCheckpoint) { + wasm._OrtTrainingReleaseCheckpoint(checkpointId); + } + if (wasm._OrtTrainingReleaseSession) { + wasm._OrtTrainingReleaseSession(sessionId); + } +} diff --git a/onnxruntime/wasm/api.cc b/onnxruntime/wasm/api.cc index 968eece361724..1f41e341e4cbd 100644 --- a/onnxruntime/wasm/api.cc +++ b/onnxruntime/wasm/api.cc @@ -493,6 +493,14 @@ 
char* OrtEndProfiling(ort_session_handle_t session) { #define CHECK_TRAINING_STATUS(ORT_API_NAME, ...) \ CheckStatus(Ort::GetTrainingApi().ORT_API_NAME(__VA_ARGS__)) +#define RETURN_TRAINING_ERROR_CODE_IF_ERROR(ORT_API_NAME, ...) \ + do { \ + int error_code = CHECK_TRAINING_STATUS(ORT_API_NAME, __VA_ARGS__); \ + if (error_code != ORT_OK) { \ + return error_code; \ + } \ + } while (false) + ort_training_checkpoint_handle_t EMSCRIPTEN_KEEPALIVE OrtTrainingLoadCheckpoint(void* checkpoint_data_buffer, size_t checkpoint_size) { OrtCheckpointState* checkpoint_state = nullptr; @@ -571,6 +579,35 @@ int EMSCRIPTEN_KEEPALIVE OrtTrainingCopyParametersFromBuffer(ort_training_sessio return CHECK_TRAINING_STATUS(CopyBufferToParameters, training_handle, parameters_buffer, trainable_only); } +int EMSCRIPTEN_KEEPALIVE OrtTrainingGetInputOutputCount(ort_training_session_handle_t training_handle, + size_t* input_count, + size_t* output_count) { + RETURN_TRAINING_ERROR_CODE_IF_ERROR(TrainingSessionGetTrainingModelInputCount, training_handle, input_count); + RETURN_TRAINING_ERROR_CODE_IF_ERROR(TrainingSessionGetTrainingModelOutputCount, training_handle, output_count); + return ORT_OK; +} + +char* EMSCRIPTEN_KEEPALIVE OrtTrainingGetInputOutputName(ort_training_session_handle_t training_handle, + size_t index, + bool isInput) { + OrtAllocator* allocator = nullptr; + RETURN_NULLPTR_IF_ERROR(GetAllocatorWithDefaultOptions, &allocator); + + char* name = nullptr; + + if (isInput) { + return (CHECK_TRAINING_STATUS(TrainingSessionGetTrainingModelInputName, training_handle, index, + allocator, &name) == ORT_OK) + ? name + : nullptr; + } else { + return (CHECK_TRAINING_STATUS(TrainingSessionGetTrainingModelOutputName, training_handle, index, + allocator, &name) == ORT_OK) + ? 
name + : nullptr; + } +} + void EMSCRIPTEN_KEEPALIVE OrtTrainingReleaseSession(ort_training_session_handle_t training_handle) { Ort::GetTrainingApi().ReleaseTrainingSession(training_handle); } diff --git a/onnxruntime/wasm/api.h b/onnxruntime/wasm/api.h index 9a0664697f0ff..ea21eb8a9e8c8 100644 --- a/onnxruntime/wasm/api.h +++ b/onnxruntime/wasm/api.h @@ -432,6 +432,29 @@ int EMSCRIPTEN_KEEPALIVE OrtTrainingCopyParametersFromBuffer(ort_training_sessio size_t parameter_count, bool trainable_only); +/** + * Gets the input count and output count of the training model associated with the given training handle. + * @param training_handle handle of the training session + * @param input_count [out] a pointer to a size_t variable to accept input_count + * @param output_count [out] a pointer to a size_t variable to accept output_count + * @returns ORT error code. If not zero, call OrtGetLastError() to get a detailed error message. + */ +int EMSCRIPTEN_KEEPALIVE OrtTrainingGetInputOutputCount(ort_training_session_handle_t training_handle, + size_t* input_count, + size_t* output_count); + +/** + * Gets the input or output name at the specified index associated with the training model from the + * given training session. + * @param training_handle handle of the training session + * @param index the input or output index + * @param isInput if true, this method retrieves an input name. If false, this method retrieves an output name. + * @returns a pointer to a buffer which contains C-style string. Caller must release the C style string after use by calling OrtFree(). + */ +char* EMSCRIPTEN_KEEPALIVE OrtTrainingGetInputOutputName(ort_training_session_handle_t training_handle, + size_t index, + bool isInput); + /** * @brief Release the specified ORT training session. 
* From 5202e3b9363753fd363c4265b88865671dc0d442 Mon Sep 17 00:00:00 2001 From: carzh Date: Fri, 6 Oct 2023 16:26:40 -0700 Subject: [PATCH 06/18] fixed variable names + enforced wasmBackend being a singleton with suggested fix of adding backend-wasm-inference.ts --- js/web/lib/backend-wasm-inference.ts | 5 +++++ ...ckend-wasm-with-training.ts => backend-wasm-training.ts} | 0 js/web/lib/backend-wasm.ts | 2 -- js/web/lib/index.ts | 6 +++--- js/web/lib/wasm/wasm-factory.ts | 4 ++-- 5 files changed, 10 insertions(+), 7 deletions(-) create mode 100644 js/web/lib/backend-wasm-inference.ts rename js/web/lib/{backend-wasm-with-training.ts => backend-wasm-training.ts} (100%) diff --git a/js/web/lib/backend-wasm-inference.ts b/js/web/lib/backend-wasm-inference.ts new file mode 100644 index 0000000000000..059a30bd45a57 --- /dev/null +++ b/js/web/lib/backend-wasm-inference.ts @@ -0,0 +1,5 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import { OnnxruntimeWebAssemblyBackend } from "./backend-wasm"; +export const wasmBackend = new OnnxruntimeWebAssemblyBackend(); diff --git a/js/web/lib/backend-wasm-with-training.ts b/js/web/lib/backend-wasm-training.ts similarity index 100% rename from js/web/lib/backend-wasm-with-training.ts rename to js/web/lib/backend-wasm-training.ts diff --git a/js/web/lib/backend-wasm.ts b/js/web/lib/backend-wasm.ts index 60468581968e3..5740263583031 100644 --- a/js/web/lib/backend-wasm.ts +++ b/js/web/lib/backend-wasm.ts @@ -51,5 +51,3 @@ export class OnnxruntimeWebAssemblyBackend implements Backend { return Promise.resolve(handler); } } - -export const wasmBackend = new OnnxruntimeWebAssemblyBackend(); diff --git a/js/web/lib/index.ts b/js/web/lib/index.ts index 30d0af5dc9e47..09f478dc2ddbe 100644 --- a/js/web/lib/index.ts +++ b/js/web/lib/index.ts @@ -16,14 +16,14 @@ if (!BUILD_DEFS.DISABLE_WEBGL) { } if (!BUILD_DEFS.DISABLE_WASM) { - const wasmBackend = BUILD_DEFS.ENABLE_TRAINING ? 
require('./backend-wasm-with-training').wasmBackend : - require('./backend-wasm').wasmBackend; + const wasmBackend = BUILD_DEFS.DISABLE_TRAINING ? require('./backend-wasm-inference').wasmBackend : + require('./backend-wasm-training').wasmBackend; if (!BUILD_DEFS.DISABLE_WEBGPU && typeof navigator !== 'undefined' && navigator.gpu) { registerBackend('webgpu', wasmBackend, 5); } registerBackend('cpu', wasmBackend, 10); registerBackend('wasm', wasmBackend, 10); - if (!BUILD_DEFS.ENABLE_TRAINING) { + if (BUILD_DEFS.DISABLE_TRAINING) { registerBackend('xnnpack', wasmBackend, 9); registerBackend('webnn', wasmBackend, 9); } diff --git a/js/web/lib/wasm/wasm-factory.ts b/js/web/lib/wasm/wasm-factory.ts index 87fa0a7771d3c..2b7d492cc70ba 100644 --- a/js/web/lib/wasm/wasm-factory.ts +++ b/js/web/lib/wasm/wasm-factory.ts @@ -10,7 +10,7 @@ import {OrtWasmThreadedModule} from './binding/ort-wasm-threaded'; /* eslint-disable @typescript-eslint/no-require-imports */ let ortWasmFactory: EmscriptenModuleFactory; -if (BUILD_DEFS.ENABLE_TRAINING) { +if (!BUILD_DEFS.DISABLE_TRAINING) { ortWasmFactory = require('./binding/ort-training-wasm-simd.js'); } else { ortWasmFactory = @@ -79,7 +79,7 @@ const isSimdSupported = (): boolean => { const getWasmFileName = (useSimd: boolean, useThreads: boolean) => { if (useSimd) { - if (BUILD_DEFS.ENABLE_TRAINING) { + if (!BUILD_DEFS.DISABLE_TRAINING) { return 'ort-training-wasm-simd.wasm'; } return useThreads ? 
'ort-wasm-simd-threaded.wasm' : 'ort-wasm-simd.wasm'; From a277347111af09bc5b7623d42215a36732c4f7e9 Mon Sep 17 00:00:00 2001 From: carzh Date: Fri, 6 Oct 2023 17:08:47 -0700 Subject: [PATCH 07/18] format + lint --- js/web/lib/backend-wasm-inference.ts | 2 +- js/web/lib/index.ts | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/js/web/lib/backend-wasm-inference.ts b/js/web/lib/backend-wasm-inference.ts index 059a30bd45a57..475a0243ebd3d 100644 --- a/js/web/lib/backend-wasm-inference.ts +++ b/js/web/lib/backend-wasm-inference.ts @@ -1,5 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import { OnnxruntimeWebAssemblyBackend } from "./backend-wasm"; +import {OnnxruntimeWebAssemblyBackend} from './backend-wasm'; export const wasmBackend = new OnnxruntimeWebAssemblyBackend(); diff --git a/js/web/lib/index.ts b/js/web/lib/index.ts index 09f478dc2ddbe..4b3d874534398 100644 --- a/js/web/lib/index.ts +++ b/js/web/lib/index.ts @@ -17,7 +17,7 @@ if (!BUILD_DEFS.DISABLE_WEBGL) { if (!BUILD_DEFS.DISABLE_WASM) { const wasmBackend = BUILD_DEFS.DISABLE_TRAINING ? 
require('./backend-wasm-inference').wasmBackend : - require('./backend-wasm-training').wasmBackend; + require('./backend-wasm-training').wasmBackend; if (!BUILD_DEFS.DISABLE_WEBGPU && typeof navigator !== 'undefined' && navigator.gpu) { registerBackend('webgpu', wasmBackend, 5); } From e775933eab58fdd94d7a9d5e8cc616cf6659af9c Mon Sep 17 00:00:00 2001 From: carzh Date: Wed, 11 Oct 2023 14:05:50 -0700 Subject: [PATCH 08/18] minor tweak to remove placeholder return statement --- js/web/lib/wasm/wasm-training-core-impl.ts | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/js/web/lib/wasm/wasm-training-core-impl.ts b/js/web/lib/wasm/wasm-training-core-impl.ts index 9216a5b1f4361..c55b4976c9d0d 100644 --- a/js/web/lib/wasm/wasm-training-core-impl.ts +++ b/js/web/lib/wasm/wasm-training-core-impl.ts @@ -12,10 +12,7 @@ import {checkLastError} from './wasm-utils'; // import {allocWasmString, checkLastError} from './wasm-utils'; // import { prepareInputOutputTensor } from './wasm-core-impl'; -const throwNoTrainingFuncsError = (): void => { - throw TypeError( - 'Built without training APIs enabled. Make sure to use the onnxruntime-training package for training functionality.'); -}; +const NO_TRAIN_FUNCS_MSG = 'Built without training APIs enabled. 
Make sure to use the onnxruntime-training package for training functionality.'); export const createCheckpointHandle = (checkpointData: SerializableModeldata): number => { const wasm = getInstance(); @@ -26,7 +23,7 @@ export const createCheckpointHandle = (checkpointData: SerializableModeldata): n if (wasm._OrtTrainingLoadCheckpoint) { checkpointHandle = wasm._OrtTrainingLoadCheckpoint(checkpointData[0], checkpointData[1]); } else { - throwNoTrainingFuncsError(); + throw new Error(NO_TRAIN_FUNCS_MSG); } if (checkpointHandle === 0) { @@ -66,7 +63,7 @@ export const createTrainingSessionHandle = sessionOptionsHandle, checkpointHandle, trainModelData[0], trainModelData[1], evalModelData[0], evalModelData[1], optimizerModelData[0], optimizerModelData[1]); } else { - throwNoTrainingFuncsError(); + throw new Error(NO_TRAIN_FUNCS_MSG); } if (trainingSessionHandle === 0) { @@ -117,9 +114,7 @@ export const createTrainingSessionHandle = } return [wasm.HEAP32[dataOffset / 4], wasm.HEAP32[dataOffset / 4 + 1]]; } else { - throwNoTrainingFuncsError(); - // unreachable code -- placeholder to prevent linting errors - return [0, 0]; + throw new Error(NO_TRAIN_FUNCS_MSG); } } finally { wasm.stackRestore(stack); @@ -142,7 +137,7 @@ const getTrainingNamesLoop = (trainingSessionId: number, count: number, isInput: namesUTF8Encoded.push(name); names.push(wasm.UTF8ToString(name)); } else { - throwNoTrainingFuncsError; + throw new Error(NO_TRAIN_FUNCS_MSG); } } return [names, namesUTF8Encoded]; From c745a6570e5b7c46d573e57302e52f9570984799 Mon Sep 17 00:00:00 2001 From: carzh Date: Wed, 11 Oct 2023 14:26:18 -0700 Subject: [PATCH 09/18] format --- js/common/lib/training-session-impl.ts | 17 ++++---- js/web/lib/backend-wasm-training.ts | 2 +- js/web/lib/wasm/binding/ort-wasm.d.ts | 2 +- js/web/lib/wasm/proxy-wrapper.ts | 2 +- .../lib/wasm/session-handler-for-training.ts | 5 ++- js/web/lib/wasm/session-handler.ts | 4 +- js/web/lib/wasm/wasm-training-core-impl.ts | 39 +++++++++---------- 7 
files changed, 35 insertions(+), 36 deletions(-) diff --git a/js/common/lib/training-session-impl.ts b/js/common/lib/training-session-impl.ts index bde1fb0c373ea..3cb6ec572c9a6 100644 --- a/js/common/lib/training-session-impl.ts +++ b/js/common/lib/training-session-impl.ts @@ -1,14 +1,14 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +import {resolveBackend} from './backend-impl.js'; import {TrainingSessionHandler} from './backend.js'; import {InferenceSession as InferenceSession} from './inference-session.js'; import {TrainingSession as TrainingSessionInterface, TrainingSessionCreateOptions} from './training-session.js'; -import { resolveBackend } from './backend-impl.js'; type SessionOptions = InferenceSession.SessionOptions; -const noBackendErrMsg: string = "Training backend could not be resolved. " + - "Make sure you\'re using the correct configuration & WebAssembly files."; +const noBackendErrMsg: string = 'Training backend could not be resolved. ' + + 'Make sure you\'re using the correct configuration & WebAssembly files.'; export class TrainingSession implements TrainingSessionInterface { private constructor(handler: TrainingSessionHandler) { @@ -23,7 +23,7 @@ export class TrainingSession implements TrainingSessionInterface { return this.handler.outputNames; } - static async create(trainingOptions: TrainingSessionCreateOptions,sessionOptions?: SessionOptions): + static async create(trainingOptions: TrainingSessionCreateOptions, sessionOptions?: SessionOptions): Promise { let checkpointState: string|Uint8Array = trainingOptions.checkpointState; let trainModel: string|Uint8Array = trainingOptions.trainModel; @@ -36,11 +36,10 @@ export class TrainingSession implements TrainingSessionInterface { const backendHints = eps.map(i => typeof i === 'string' ? 
i : i.name); const backend = await resolveBackend(backendHints); if (backend.createTrainingSessionHandler) { - const handler = - await backend.createTrainingSessionHandler(checkpointState, trainModel, evalModel, optimizerModel, options); - return new TrainingSession(handler); - } - else { + const handler = + await backend.createTrainingSessionHandler(checkpointState, trainModel, evalModel, optimizerModel, options); + return new TrainingSession(handler); + } else { throw new Error(noBackendErrMsg); } } diff --git a/js/web/lib/backend-wasm-training.ts b/js/web/lib/backend-wasm-training.ts index 1796259d0f335..98e40807aa29c 100644 --- a/js/web/lib/backend-wasm-training.ts +++ b/js/web/lib/backend-wasm-training.ts @@ -11,7 +11,7 @@ class OnnxruntimeTrainingWebAssemblyBackend extends OnnxruntimeWebAssemblyBacken checkpointStateUriOrBuffer: string|Uint8Array, trainModelUriOrBuffer: string|Uint8Array, evalModelUriOrBuffer: string|Uint8Array, optimizerModelUriOrBuffer: string|Uint8Array, options: InferenceSession.SessionOptions): Promise { - const handler = new OnnxruntimeWebAssemblyTrainingSessionHandler(); + const handler = new OnnxruntimeWebAssemblyTrainingSessionHandler(); await handler.createTrainingSession( checkpointStateUriOrBuffer, trainModelUriOrBuffer, evalModelUriOrBuffer, optimizerModelUriOrBuffer, options); return Promise.resolve(handler); diff --git a/js/web/lib/wasm/binding/ort-wasm.d.ts b/js/web/lib/wasm/binding/ort-wasm.d.ts index c2b2bdf628d85..060fb1e756ef9 100644 --- a/js/web/lib/wasm/binding/ort-wasm.d.ts +++ b/js/web/lib/wasm/binding/ort-wasm.d.ts @@ -102,7 +102,7 @@ export interface OrtWasmModule extends EmscriptenModule { _OrtTrainingCopyParametersFromBuffer? 
(trainingHandle: number, parametersBuffer: number, parameterCount: number, trainableOnly: boolean): number; - _OrtTrainingGetInputOutputCount?(sessionHandle: number, inputCountOffset: number, outputCountOffset: number): number; + _OrtTrainingGetInputOutputCount?(sessionHandle: number, inputCountOffset: number, outputCountOffset: number): number; _OrtTrainingGetInputOutputName?(sessionHandle: number, index: number, isInput: boolean): number; _OrtTrainingReleaseSession?(trainingHandle: number): void; diff --git a/js/web/lib/wasm/proxy-wrapper.ts b/js/web/lib/wasm/proxy-wrapper.ts index 6bb1bb733f194..55edd2106130d 100644 --- a/js/web/lib/wasm/proxy-wrapper.ts +++ b/js/web/lib/wasm/proxy-wrapper.ts @@ -264,4 +264,4 @@ export const isOrtEnvInitialized = async(): Promise => { } else { return core.isOrtEnvInitialized(); } -} +}; diff --git a/js/web/lib/wasm/session-handler-for-training.ts b/js/web/lib/wasm/session-handler-for-training.ts index 5f694371919fc..e0a9db97b951d 100644 --- a/js/web/lib/wasm/session-handler-for-training.ts +++ b/js/web/lib/wasm/session-handler-for-training.ts @@ -61,12 +61,13 @@ export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSes } async dispose(): Promise { - return releaseTrainingSessionAndCheckpoint(this.checkpointId, this.sessionId, this.inputEncodedNames, this.outputEncodedNames); + return releaseTrainingSessionAndCheckpoint( + this.checkpointId, this.sessionId, this.inputEncodedNames, this.outputEncodedNames); } async runTrainStep( _feeds: SessionHandler.FeedsType, _fetches: SessionHandler.FetchesType, _options: InferenceSession.RunOptions): Promise { - throw new Error('Method not implemented yet.'); + throw new Error('Method not implemented yet.'); } } diff --git a/js/web/lib/wasm/session-handler.ts b/js/web/lib/wasm/session-handler.ts index 2e2239b948b0f..a5017a920f38b 100644 --- a/js/web/lib/wasm/session-handler.ts +++ b/js/web/lib/wasm/session-handler.ts @@ -5,7 +5,7 @@ import {readFile} from 
'node:fs/promises'; import {env, InferenceSession, InferenceSessionHandler, SessionHandler, Tensor} from 'onnxruntime-common'; import {SerializableModeldata, TensorMetadata} from './proxy-messages'; -import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, initializeRuntime, releaseSession, run, isOrtEnvInitialized} from './proxy-wrapper'; +import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, initializeRuntime, isOrtEnvInitialized, releaseSession, run} from './proxy-wrapper'; import {isGpuBufferSupportedType} from './wasm-common'; let runtimeInitializationPromise: Promise|undefined; @@ -56,7 +56,7 @@ export class OnnxruntimeWebAssemblySessionHandler implements InferenceSessionHan } async loadModel(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): Promise { - if (!isOrtEnvInitialized()) { + if (!(await isOrtEnvInitialized())) { if (!runtimeInitializationPromise) { runtimeInitializationPromise = initializeRuntime(env); } diff --git a/js/web/lib/wasm/wasm-training-core-impl.ts b/js/web/lib/wasm/wasm-training-core-impl.ts index c55b4976c9d0d..7de1601cd73b5 100644 --- a/js/web/lib/wasm/wasm-training-core-impl.ts +++ b/js/web/lib/wasm/wasm-training-core-impl.ts @@ -3,16 +3,13 @@ // import {InferenceSession, Tensor} from 'onnxruntime-common'; import {InferenceSession} from 'onnxruntime-common'; -import {SerializableModeldata, SerializableSessionMetadata } from './proxy-messages'; -// import {setRunOptions} from './run-options'; +import {SerializableModeldata, SerializableSessionMetadata} from './proxy-messages'; import {setSessionOptions} from './session-options'; -// import {tensorDataTypeEnumToString, tensorDataTypeStringToEnum, tensorTypeToTypedArrayConstructor} from './wasm-common'; import {getInstance} from './wasm-factory'; import {checkLastError} from './wasm-utils'; -// import {allocWasmString, checkLastError} from './wasm-utils'; -// import { prepareInputOutputTensor } from 
'./wasm-core-impl'; -const NO_TRAIN_FUNCS_MSG = 'Built without training APIs enabled. Make sure to use the onnxruntime-training package for training functionality.'); +const NO_TRAIN_FUNCS_MSG = + 'Built without training APIs enabled. Make sure to use the onnxruntime-training package for training functionality.'; export const createCheckpointHandle = (checkpointData: SerializableModeldata): number => { const wasm = getInstance(); @@ -93,7 +90,7 @@ export const createTrainingSessionHandle = } }; - const getTrainingModelInputOutputNames = (trainingSessionId: number): [string[], number[], string[], number[]] => { +const getTrainingModelInputOutputNames = (trainingSessionId: number): [string[], number[], string[], number[]] => { const [inputCount, outputCount] = getTrainingModelInputOutputCount(trainingSessionId); const [inputNames, inputNamesUTF8Encoded] = getTrainingNamesLoop(trainingSessionId, inputCount, true); @@ -102,7 +99,7 @@ export const createTrainingSessionHandle = return [inputNames, inputNamesUTF8Encoded, outputNames, outputNamesUTF8Encoded]; } - const getTrainingModelInputOutputCount = (trainingSessionId: number): [number, number] => { +const getTrainingModelInputOutputCount = (trainingSessionId: number): [number, number] => { const wasm = getInstance(); const stack = wasm.stackSave(); try { @@ -143,15 +140,17 @@ const getTrainingNamesLoop = (trainingSessionId: number, count: number, isInput: return [names, namesUTF8Encoded]; } -export const releaseTrainingSessionAndCheckpoint = (checkpointId: number, sessionId: number, inputNamesUTF8Encoded: number[], outputNamesUTF8Encoded: number[]): void => { - const wasm = getInstance(); - inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); - outputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); - - if (wasm._OrtTrainingReleaseCheckpoint) { - wasm._OrtTrainingReleaseCheckpoint(checkpointId); - } - if (wasm._OrtTrainingReleaseSession) { - wasm._OrtTrainingReleaseSession(sessionId); - } -} +export const 
releaseTrainingSessionAndCheckpoint = + (checkpointId: number, sessionId: number, inputNamesUTF8Encoded: number[], outputNamesUTF8Encoded: number[]): + void => { + const wasm = getInstance(); + inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); + outputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); + + if (wasm._OrtTrainingReleaseCheckpoint) { + wasm._OrtTrainingReleaseCheckpoint(checkpointId); + } + if (wasm._OrtTrainingReleaseSession) { + wasm._OrtTrainingReleaseSession(sessionId); + } + } From beca6dc3f989bd04c81cb4b14bd02daa0f83f6a0 Mon Sep 17 00:00:00 2001 From: carzh Date: Wed, 11 Oct 2023 15:04:45 -0700 Subject: [PATCH 10/18] lint fixes --- js/common/lib/training-session-impl.ts | 11 +++++------ js/web/lib/wasm/session-handler-for-training.ts | 8 ++++---- js/web/lib/wasm/wasm-core-impl.ts | 3 ++- js/web/lib/wasm/wasm-training-core-impl.ts | 9 +++++---- 4 files changed, 16 insertions(+), 15 deletions(-) diff --git a/js/common/lib/training-session-impl.ts b/js/common/lib/training-session-impl.ts index 3cb6ec572c9a6..921730d86d136 100644 --- a/js/common/lib/training-session-impl.ts +++ b/js/common/lib/training-session-impl.ts @@ -25,11 +25,9 @@ export class TrainingSession implements TrainingSessionInterface { static async create(trainingOptions: TrainingSessionCreateOptions, sessionOptions?: SessionOptions): Promise { - let checkpointState: string|Uint8Array = trainingOptions.checkpointState; - let trainModel: string|Uint8Array = trainingOptions.trainModel; - let evalModel: string|Uint8Array = trainingOptions.evalModel ? trainingOptions.evalModel : ''; - let optimizerModel: string|Uint8Array = trainingOptions.optimizerModel ? trainingOptions.optimizerModel : ''; - let options: SessionOptions = sessionOptions ? 
sessionOptions : {}; + const evalModel: string|Uint8Array = trainingOptions.evalModel || ''; + const optimizerModel: string|Uint8Array = trainingOptions.optimizerModel || ''; + const options: SessionOptions = sessionOptions || {}; // get backend hints const eps = options.executionProviders || []; @@ -37,7 +35,8 @@ export class TrainingSession implements TrainingSessionInterface { const backend = await resolveBackend(backendHints); if (backend.createTrainingSessionHandler) { const handler = - await backend.createTrainingSessionHandler(checkpointState, trainModel, evalModel, optimizerModel, options); + await backend.createTrainingSessionHandler(trainingOptions.checkpointState, trainingOptions.trainModel, + evalModel, optimizerModel, options); return new TrainingSession(handler); } else { throw new Error(noBackendErrMsg); diff --git a/js/web/lib/wasm/session-handler-for-training.ts b/js/web/lib/wasm/session-handler-for-training.ts index e0a9db97b951d..83d133b9a5157 100644 --- a/js/web/lib/wasm/session-handler-for-training.ts +++ b/js/web/lib/wasm/session-handler-for-training.ts @@ -8,10 +8,10 @@ import {createSessionAllocate, initRuntime, isOrtEnvInitialized} from './wasm-co import {createCheckpointHandle, createTrainingSessionHandle, releaseTrainingSessionAndCheckpoint} from './wasm-training-core-impl'; export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSessionHandler { - loadParametersBuffer(array: Uint8Array, trainableOnly: boolean): Promise { + async loadParametersBuffer(_array: Uint8Array, _trainableOnly: boolean): Promise { throw new Error('Method not implemented.'); } - getContiguousParameters(trainableOnly: boolean): Promise { + async getContiguousParameters(_trainableOnly: boolean): Promise { throw new Error('Method not implemented.'); } private sessionId: number; @@ -42,8 +42,8 @@ export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSes if (!isOrtEnvInitialized()) { await initRuntime(env); } - let 
checkpointData: SerializableModeldata = await this.uriOrBufferToHeap(checkpointStateUriOrBuffer); - let trainModelData: SerializableModeldata = await this.uriOrBufferToHeap(trainModelUriOrBuffer); + const checkpointData: SerializableModeldata = await this.uriOrBufferToHeap(checkpointStateUriOrBuffer); + const trainModelData: SerializableModeldata = await this.uriOrBufferToHeap(trainModelUriOrBuffer); // 0 is supposed to be the nullptr let evalModelData: SerializableModeldata = [0, 0]; let optimizerModelData: SerializableModeldata = [0, 0]; diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index 770328934deba..7d3c8e263087d 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -10,6 +10,8 @@ import {dataLocationStringToEnum, getTensorElementSize, isGpuBufferSupportedType import {getInstance} from './wasm-factory'; import {allocWasmString, checkLastError} from './wasm-utils'; +let ortEnvInitialized = false; + /** * get the input/output count of the session. * @param sessionHandle the handle representing the session. should be non-zero. @@ -94,7 +96,6 @@ type SessionMetadata = [ ]; const activeSessions = new Map(); -let ortEnvInitialized = false; export const isOrtEnvInitialized = (): boolean => { return ortEnvInitialized; diff --git a/js/web/lib/wasm/wasm-training-core-impl.ts b/js/web/lib/wasm/wasm-training-core-impl.ts index 7de1601cd73b5..130bc06737deb 100644 --- a/js/web/lib/wasm/wasm-training-core-impl.ts +++ b/js/web/lib/wasm/wasm-training-core-impl.ts @@ -1,6 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-// import {InferenceSession, Tensor} from 'onnxruntime-common'; + import {InferenceSession} from 'onnxruntime-common'; import {SerializableModeldata, SerializableSessionMetadata} from './proxy-messages'; @@ -9,7 +9,8 @@ import {getInstance} from './wasm-factory'; import {checkLastError} from './wasm-utils'; const NO_TRAIN_FUNCS_MSG = - 'Built without training APIs enabled. Make sure to use the onnxruntime-training package for training functionality.'; + 'Built without training APIs enabled. ' + + 'Make sure to use the onnxruntime-training package for training functionality.'; export const createCheckpointHandle = (checkpointData: SerializableModeldata): number => { const wasm = getInstance(); @@ -97,7 +98,7 @@ const getTrainingModelInputOutputNames = (trainingSessionId: number): [string[], const [outputNames, outputNamesUTF8Encoded] = getTrainingNamesLoop(trainingSessionId, outputCount, false); return [inputNames, inputNamesUTF8Encoded, outputNames, outputNamesUTF8Encoded]; -} +}; const getTrainingModelInputOutputCount = (trainingSessionId: number): [number, number] => { const wasm = getInstance(); @@ -116,7 +117,7 @@ const getTrainingModelInputOutputCount = (trainingSessionId: number): [number, n } finally { wasm.stackRestore(stack); } -} +}; const getTrainingNamesLoop = (trainingSessionId: number, count: number, isInput: boolean): [string[], number[]] => { const names = []; From b37e77f83e2fa60035f107c8c3de7ad00dc81dfd Mon Sep 17 00:00:00 2001 From: carzh Date: Wed, 11 Oct 2023 22:28:53 +0000 Subject: [PATCH 11/18] lint + format --- js/common/lib/training-session-impl.ts | 5 +- js/web/lib/wasm/wasm-core-impl.ts | 4 +- js/web/lib/wasm/wasm-training-core-impl.ts | 107 ++++++++++----------- onnxruntime/wasm/api.cc | 2 +- 4 files changed, 57 insertions(+), 61 deletions(-) diff --git a/js/common/lib/training-session-impl.ts b/js/common/lib/training-session-impl.ts index 921730d86d136..47e67879e66ce 100644 --- a/js/common/lib/training-session-impl.ts +++ 
b/js/common/lib/training-session-impl.ts @@ -34,9 +34,8 @@ export class TrainingSession implements TrainingSessionInterface { const backendHints = eps.map(i => typeof i === 'string' ? i : i.name); const backend = await resolveBackend(backendHints); if (backend.createTrainingSessionHandler) { - const handler = - await backend.createTrainingSessionHandler(trainingOptions.checkpointState, trainingOptions.trainModel, - evalModel, optimizerModel, options); + const handler = await backend.createTrainingSessionHandler( + trainingOptions.checkpointState, trainingOptions.trainModel, evalModel, optimizerModel, options); return new TrainingSession(handler); } else { throw new Error(noBackendErrMsg); diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index 7d3c8e263087d..947242945c665 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -97,9 +97,7 @@ type SessionMetadata = [ const activeSessions = new Map(); -export const isOrtEnvInitialized = (): boolean => { - return ortEnvInitialized; -}; +export const isOrtEnvInitialized = (): boolean => ortEnvInitialized; /** * allocate the memory and memcpy the model bytes, preparing for creating an instance of InferenceSession. diff --git a/js/web/lib/wasm/wasm-training-core-impl.ts b/js/web/lib/wasm/wasm-training-core-impl.ts index 130bc06737deb..6baac86b2e885 100644 --- a/js/web/lib/wasm/wasm-training-core-impl.ts +++ b/js/web/lib/wasm/wasm-training-core-impl.ts @@ -8,8 +8,7 @@ import {setSessionOptions} from './session-options'; import {getInstance} from './wasm-factory'; import {checkLastError} from './wasm-utils'; -const NO_TRAIN_FUNCS_MSG = - 'Built without training APIs enabled. ' + +const NO_TRAIN_FUNCS_MSG = 'Built without training APIs enabled. 
' + 'Make sure to use the onnxruntime-training package for training functionality.'; export const createCheckpointHandle = (checkpointData: SerializableModeldata): number => { @@ -39,6 +38,56 @@ export const createCheckpointHandle = (checkpointData: SerializableModeldata): n } }; +const getTrainingModelInputOutputCount = (trainingSessionId: number): [number, number] => { + const wasm = getInstance(); + const stack = wasm.stackSave(); + try { + const dataOffset = wasm.stackAlloc(8); + if (wasm._OrtTrainingGetInputOutputCount) { + const errorCode = wasm._OrtTrainingGetInputOutputCount(trainingSessionId, dataOffset, dataOffset + 4); + if (errorCode !== 0) { + checkLastError('Can\'t get session input/output count.'); + } + return [wasm.HEAP32[dataOffset / 4], wasm.HEAP32[dataOffset / 4 + 1]]; + } else { + throw new Error(NO_TRAIN_FUNCS_MSG); + } + } finally { + wasm.stackRestore(stack); + } +}; + +const getTrainingNamesLoop = (trainingSessionId: number, count: number, isInput: boolean): [string[], number[]] => { + const names = []; + const wasm = getInstance(); + + const namesUTF8Encoded = []; + + for (let i = 0; i < count; i++) { + if (wasm._OrtTrainingGetInputOutputName) { + const name = wasm._OrtTrainingGetInputOutputName(trainingSessionId, i, isInput); + if (name === 0) { + checkLastError('Can\'t get input or output name'); + } + + namesUTF8Encoded.push(name); + names.push(wasm.UTF8ToString(name)); + } else { + throw new Error(NO_TRAIN_FUNCS_MSG); + } + } + return [names, namesUTF8Encoded]; +}; + +const getTrainingModelInputOutputNames = (trainingSessionId: number): [string[], number[], string[], number[]] => { + const [inputCount, outputCount] = getTrainingModelInputOutputCount(trainingSessionId); + + const [inputNames, inputNamesUTF8Encoded] = getTrainingNamesLoop(trainingSessionId, inputCount, true); + const [outputNames, outputNamesUTF8Encoded] = getTrainingNamesLoop(trainingSessionId, outputCount, false); + + return [inputNames, inputNamesUTF8Encoded, 
outputNames, outputNamesUTF8Encoded]; +}; + export const createTrainingSessionHandle = (checkpointHandle: number, trainModelData: SerializableModeldata, evalModelData: SerializableModeldata, optimizerModelData: SerializableModeldata, @@ -70,8 +119,8 @@ export const createTrainingSessionHandle = [inputNames, inputNamesUTF8Encoded, outputNames, outputNamesUTF8Encoded] = getTrainingModelInputOutputNames(trainingSessionHandle); - return [[trainingSessionHandle, inputNames, outputNames], inputNamesUTF8Encoded, outputNamesUTF8Encoded]; + } catch (e) { if (wasm._OrtTrainingReleaseSession && trainingSessionHandle !== 0) { wasm._OrtTrainingReleaseSession(trainingSessionHandle); @@ -91,56 +140,6 @@ export const createTrainingSessionHandle = } }; -const getTrainingModelInputOutputNames = (trainingSessionId: number): [string[], number[], string[], number[]] => { - const [inputCount, outputCount] = getTrainingModelInputOutputCount(trainingSessionId); - - const [inputNames, inputNamesUTF8Encoded] = getTrainingNamesLoop(trainingSessionId, inputCount, true); - const [outputNames, outputNamesUTF8Encoded] = getTrainingNamesLoop(trainingSessionId, outputCount, false); - - return [inputNames, inputNamesUTF8Encoded, outputNames, outputNamesUTF8Encoded]; -}; - -const getTrainingModelInputOutputCount = (trainingSessionId: number): [number, number] => { - const wasm = getInstance(); - const stack = wasm.stackSave(); - try { - const dataOffset = wasm.stackAlloc(8); - if (wasm._OrtTrainingGetInputOutputCount) { - const errorCode = wasm._OrtTrainingGetInputOutputCount(trainingSessionId, dataOffset, dataOffset + 4); - if (errorCode !== 0) { - checkLastError('Can\'t get session input/output count.'); - } - return [wasm.HEAP32[dataOffset / 4], wasm.HEAP32[dataOffset / 4 + 1]]; - } else { - throw new Error(NO_TRAIN_FUNCS_MSG); - } - } finally { - wasm.stackRestore(stack); - } -}; - -const getTrainingNamesLoop = (trainingSessionId: number, count: number, isInput: boolean): [string[], number[]] => 
{ - const names = []; - const wasm = getInstance(); - - const namesUTF8Encoded = []; - - for (let i = 0; i < count; i++) { - if (wasm._OrtTrainingGetInputOutputName) { - const name = wasm._OrtTrainingGetInputOutputName(trainingSessionId, i, isInput); - if (name === 0) { - checkLastError('Can\'t get input or output name'); - } - - namesUTF8Encoded.push(name); - names.push(wasm.UTF8ToString(name)); - } else { - throw new Error(NO_TRAIN_FUNCS_MSG); - } - } - return [names, namesUTF8Encoded]; -} - export const releaseTrainingSessionAndCheckpoint = (checkpointId: number, sessionId: number, inputNamesUTF8Encoded: number[], outputNamesUTF8Encoded: number[]): void => { @@ -154,4 +153,4 @@ export const releaseTrainingSessionAndCheckpoint = if (wasm._OrtTrainingReleaseSession) { wasm._OrtTrainingReleaseSession(sessionId); } - } + }; diff --git a/onnxruntime/wasm/api.cc b/onnxruntime/wasm/api.cc index 1f41e341e4cbd..2645d6f05222f 100644 --- a/onnxruntime/wasm/api.cc +++ b/onnxruntime/wasm/api.cc @@ -493,7 +493,7 @@ char* OrtEndProfiling(ort_session_handle_t session) { #define CHECK_TRAINING_STATUS(ORT_API_NAME, ...) \ CheckStatus(Ort::GetTrainingApi().ORT_API_NAME(__VA_ARGS__)) -#define RETURN_TRAINING_ERROR_CODE_IF_ERROR(ORT_API_NAME, ...) \ +#define RETURN_TRAINING_ERROR_CODE_IF_ERROR(ORT_API_NAME, ...) 
\ do { \ int error_code = CHECK_TRAINING_STATUS(ORT_API_NAME, __VA_ARGS__); \ if (error_code != ORT_OK) { \ From d2f3d880198dcc79d3549e68720107fcf192c9bb Mon Sep 17 00:00:00 2001 From: carzh Date: Thu, 12 Oct 2023 16:45:44 -0700 Subject: [PATCH 12/18] added isOrtEnvInitialized case to proxy wrapper --- js/web/lib/wasm/proxy-wrapper.ts | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/js/web/lib/wasm/proxy-wrapper.ts b/js/web/lib/wasm/proxy-wrapper.ts index 55edd2106130d..d7f2e957c2e0c 100644 --- a/js/web/lib/wasm/proxy-wrapper.ts +++ b/js/web/lib/wasm/proxy-wrapper.ts @@ -93,6 +93,12 @@ const onProxyWorkerMessage = (ev: MessageEvent): void => { endProfilingCallbacks.shift()![0](); } break; + case 'is-ort-env-initialized': + if (ev.data.err) { + isOrtEnvInitializedCallbacks.shift()![1](ev.data.err); + } else { + isOrtEnvInitializedCallbacks.shift()![0](ev.data.out!); + } default: } }; From 05a708fd868ad242babd5cd17909d39995a7f2d2 Mon Sep 17 00:00:00 2001 From: carzh Date: Fri, 13 Oct 2023 10:33:00 -0700 Subject: [PATCH 13/18] fixed proxy wrapper case statement fixed proxy worker run format --- js/web/lib/wasm/proxy-worker/main.ts | 10 +++++++++- js/web/lib/wasm/proxy-wrapper.ts | 1 + 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/js/web/lib/wasm/proxy-worker/main.ts b/js/web/lib/wasm/proxy-worker/main.ts index fe8bd9b11b191..1f4595818e5c0 100644 --- a/js/web/lib/wasm/proxy-worker/main.ts +++ b/js/web/lib/wasm/proxy-worker/main.ts @@ -4,7 +4,7 @@ /// import {OrtWasmMessage} from '../proxy-messages'; -import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, extractTransferableBuffers, initRuntime, releaseSession, run} from '../wasm-core-impl'; +import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, extractTransferableBuffers, initRuntime, isOrtEnvInitialized, releaseSession, run} from '../wasm-core-impl'; import {initializeWebAssembly} from '../wasm-factory'; self.onmessage = (ev: 
MessageEvent): void => { @@ -89,6 +89,14 @@ self.onmessage = (ev: MessageEvent): void => { postMessage({type: 'end-profiling', err} as OrtWasmMessage); } break; + case 'is-ort-env-initialized': + try { + const ortEnvInitialized = isOrtEnvInitialized(); + postMessage({type: 'is-ort-env-initialized', out: ortEnvInitialized} as OrtWasmMessage); + } catch (err) { + postMessage({type: 'is-ort-env-initialized', err} as OrtWasmMessage); + } + break; default: } }; diff --git a/js/web/lib/wasm/proxy-wrapper.ts b/js/web/lib/wasm/proxy-wrapper.ts index d7f2e957c2e0c..069a1fa452dbc 100644 --- a/js/web/lib/wasm/proxy-wrapper.ts +++ b/js/web/lib/wasm/proxy-wrapper.ts @@ -99,6 +99,7 @@ const onProxyWorkerMessage = (ev: MessageEvent): void => { } else { isOrtEnvInitializedCallbacks.shift()![0](ev.data.out!); } + break; default: } }; From 46a96774d8bb5538953a446f915f7bcca5f590ab Mon Sep 17 00:00:00 2001 From: carzh Date: Wed, 11 Oct 2023 16:48:54 -0700 Subject: [PATCH 14/18] working runTrainStep implementation --- js/common/lib/training-session-impl.ts | 113 +++++++++++++- .../lib/wasm/session-handler-for-training.ts | 54 ++++++- js/web/lib/wasm/session-handler.ts | 4 +- js/web/lib/wasm/wasm-core-impl.ts | 2 +- js/web/lib/wasm/wasm-training-core-impl.ts | 139 +++++++++++++++++- 5 files changed, 293 insertions(+), 19 deletions(-) diff --git a/js/common/lib/training-session-impl.ts b/js/common/lib/training-session-impl.ts index 47e67879e66ce..e2d962677f30c 100644 --- a/js/common/lib/training-session-impl.ts +++ b/js/common/lib/training-session-impl.ts @@ -5,8 +5,15 @@ import {resolveBackend} from './backend-impl.js'; import {TrainingSessionHandler} from './backend.js'; import {InferenceSession as InferenceSession} from './inference-session.js'; import {TrainingSession as TrainingSessionInterface, TrainingSessionCreateOptions} from './training-session.js'; +import { OnnxValue } from './onnx-value.js'; +import { Tensor } from './tensor.js'; type SessionOptions = 
InferenceSession.SessionOptions; +type FeedsType = InferenceSession.FeedsType; +type FetchesType = InferenceSession.FetchesType; +type ReturnType = InferenceSession.ReturnType; +type RunOptions = InferenceSession.RunOptions; + const noBackendErrMsg: string = 'Training backend could not be resolved. ' + 'Make sure you\'re using the correct configuration & WebAssembly files.'; @@ -50,14 +57,106 @@ export class TrainingSession implements TrainingSessionInterface { throw new Error('Method not implemented.'); } - runTrainStep(feeds: InferenceSession.OnnxValueMapType, options?: InferenceSession.RunOptions|undefined): - Promise; + runTrainStep(feeds: FeedsType, options?: RunOptions): Promise; runTrainStep( - feeds: InferenceSession.OnnxValueMapType, fetches: InferenceSession.FetchesType, - options?: InferenceSession.RunOptions|undefined): Promise; - async runTrainStep(_feeds: unknown, _fetches?: unknown, _options?: unknown): - Promise { - throw new Error('Method not implemented.'); + feeds: FeedsType, fetches: FetchesType, options?: RunOptions): Promise; + async runTrainStep(feeds: FeedsType, arg1?: FetchesType|RunOptions, arg2?: RunOptions): Promise { + const fetches: {[name: string]: OnnxValue|null} = {}; + let options: RunOptions = {}; + // check inputs + if (typeof feeds !== 'object' || feeds === null || feeds instanceof Tensor || Array.isArray(feeds)) { + throw new TypeError( + '\'feeds\' must be an object that use input names as keys and OnnxValue as corresponding values.'); + } + + let isFetchesEmpty = true; + // determine which override is being used + if (typeof arg1 === 'object') { + if (arg1 === null) { + throw new TypeError('Unexpected argument[1]: cannot be null.'); + } + if (arg1 instanceof Tensor) { + throw new TypeError('\'fetches\' cannot be a Tensor'); + } + + if (Array.isArray(arg1)) { + if (arg1.length === 0) { + throw new TypeError('\'fetches\' cannot be an empty array.'); + } + isFetchesEmpty = false; + // output names + for (const name of arg1) { + 
if (typeof name !== 'string') { + throw new TypeError('\'fetches\' must be a string array or an object.'); + } + if (this.outputNames.indexOf(name) === -1) { + throw new RangeError(`'fetches' contains invalid output name: ${name}.`); + } + fetches[name] = null; + } + + if (typeof arg2 === 'object' && arg2 !== null) { + options = arg2; + } else if (typeof arg2 !== 'undefined') { + throw new TypeError('\'options\' must be an object.'); + } + } else { + // decide whether arg1 is fetches or options + // if any output name is present and its value is valid OnnxValue, we consider it fetches + let isFetches = false; + const arg1Keys = Object.getOwnPropertyNames(arg1); + for (const name of this.outputNames) { + if (arg1Keys.indexOf(name) !== -1) { + const v = (arg1 as InferenceSession.NullableOnnxValueMapType)[name]; + if (v === null || v instanceof Tensor) { + isFetches = true; + isFetchesEmpty = false; + fetches[name] = v; + } + } + } + + if (isFetches) { + if (typeof arg2 === 'object' && arg2 !== null) { + options = arg2; + } else if (typeof arg2 !== 'undefined') { + throw new TypeError('\'options\' must be an object.'); + } + } else { + options = arg1 as RunOptions; + } + } + } else if (typeof arg1 !== 'undefined') { + throw new TypeError('Unexpected argument[1]: must be \'fetches\' or \'options\'.'); + } + + // check if all inputs are in feed + for (const name of this.inputNames) { + if (typeof feeds[name] === 'undefined') { + throw new Error(`input '${name}' is missing in 'feeds'.`); + } + } + + // if no fetches is specified, we use the full output names list + if (isFetchesEmpty) { + for (const name of this.outputNames) { + fetches[name] = null; + } + } + + const results = await this.handler.runTrainStep(feeds, fetches, options); + const returnValue: {[name: string]: OnnxValue} = {}; + for (const key in results) { + if (Object.hasOwnProperty.call(results, key)) { + const result = results[key]; + if (result instanceof Tensor) { + returnValue[key] = result; + } else { 
+ returnValue[key] = new Tensor(result.type, result.data, result.dims); + } + } + } + return returnValue; } async release(): Promise { diff --git a/js/web/lib/wasm/session-handler-for-training.ts b/js/web/lib/wasm/session-handler-for-training.ts index 83d133b9a5157..9aeca4a28dce4 100644 --- a/js/web/lib/wasm/session-handler-for-training.ts +++ b/js/web/lib/wasm/session-handler-for-training.ts @@ -1,11 +1,13 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {env, InferenceSession, SessionHandler, TrainingSessionHandler} from 'onnxruntime-common'; +import {env, InferenceSession, SessionHandler, TrainingSessionHandler, Tensor} from 'onnxruntime-common'; import {SerializableModeldata} from './proxy-messages'; import {createSessionAllocate, initRuntime, isOrtEnvInitialized} from './wasm-core-impl'; -import {createCheckpointHandle, createTrainingSessionHandle, releaseTrainingSessionAndCheckpoint} from './wasm-training-core-impl'; +import {createCheckpointHandle, createTrainingSessionHandle, runTrainStep, + releaseTrainingSessionAndCheckpoint} from './wasm-training-core-impl'; +import { encodeTensorMetadata, decodeTensorMetadata } from './session-handler'; export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSessionHandler { async loadParametersBuffer(_array: Uint8Array, _trainableOnly: boolean): Promise { @@ -60,14 +62,52 @@ export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSes createTrainingSessionHandle(this.checkpointId, trainModelData, evalModelData, optimizerModelData, options); } + async runTrainStep( + feeds: SessionHandler.FeedsType, fetches: SessionHandler.FetchesType, + options: InferenceSession.RunOptions): Promise { + const inputArray: Tensor[] = []; + const inputIndices: number[] = []; + Object.entries(feeds).forEach(kvp => { + const name = kvp[0]; + const tensor = kvp[1]; + const index = this.inputNames.indexOf(name); + if (index === -1) { + throw new 
Error(`invalid input '${name}'`); + } + inputArray.push(tensor); + inputIndices.push(index); + }); + + const outputArray: Array = []; + const outputIndices: number[] = []; + Object.entries(fetches).forEach(kvp => { + const name = kvp[0]; + const tensor = kvp[1]; + const index = this.outputNames.indexOf(name); + if (index === -1) { + throw new Error(`invalid output '${name}'`); + } + outputArray.push(tensor); + outputIndices.push(index); + }); + + const inputs = + inputArray.map((t, i) => encodeTensorMetadata(t, () => `input "${this.inputNames[inputIndices[i]]}"`)); + const outputs = outputArray.map( + (t, i) => t ? encodeTensorMetadata(t, () => `output "${this.outputNames[outputIndices[i]]}"`) : null); + + const results = await runTrainStep(this.sessionId, inputIndices, inputs, outputIndices, outputs, options); + + const resultMap: SessionHandler.ReturnType = {}; + for (let i = 0; i < results. length; i++) { + resultMap[this.outputNames[outputIndices[i]]] = outputArray[i] ?? decodeTensorMetadata(results[i]); + } + return resultMap; + } + async dispose(): Promise { return releaseTrainingSessionAndCheckpoint( this.checkpointId, this.sessionId, this.inputEncodedNames, this.outputEncodedNames); } - async runTrainStep( - _feeds: SessionHandler.FeedsType, _fetches: SessionHandler.FetchesType, - _options: InferenceSession.RunOptions): Promise { - throw new Error('Method not implemented yet.'); - } } diff --git a/js/web/lib/wasm/session-handler.ts b/js/web/lib/wasm/session-handler.ts index a5017a920f38b..3ca34d957c572 100644 --- a/js/web/lib/wasm/session-handler.ts +++ b/js/web/lib/wasm/session-handler.ts @@ -10,7 +10,7 @@ import {isGpuBufferSupportedType} from './wasm-common'; let runtimeInitializationPromise: Promise|undefined; -const encodeTensorMetadata = (tensor: Tensor, getName: () => string): TensorMetadata => { +export const encodeTensorMetadata = (tensor: Tensor, getName: () => string): TensorMetadata => { switch (tensor.location) { case 'cpu': return 
[tensor.type, tensor.dims, tensor.data, 'cpu']; @@ -21,7 +21,7 @@ const encodeTensorMetadata = (tensor: Tensor, getName: () => string): TensorMeta } }; -const decodeTensorMetadata = (tensor: TensorMetadata): Tensor => { +export const decodeTensorMetadata = (tensor: TensorMetadata): Tensor => { switch (tensor[3]) { case 'cpu': return new Tensor(tensor[0], tensor[2], tensor[1]); diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index 947242945c665..3aacf8f4d90e0 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -240,7 +240,7 @@ export const releaseSession = (sessionId: number): void => { activeSessions.delete(sessionId); }; -const prepareInputOutputTensor = +export const prepareInputOutputTensor = (tensor: TensorMetadata|null, tensorHandles: number[], allocs: number[], sessionId: number, index: number): void => { if (!tensor) { diff --git a/js/web/lib/wasm/wasm-training-core-impl.ts b/js/web/lib/wasm/wasm-training-core-impl.ts index 6baac86b2e885..7593bca81ffce 100644 --- a/js/web/lib/wasm/wasm-training-core-impl.ts +++ b/js/web/lib/wasm/wasm-training-core-impl.ts @@ -1,12 +1,15 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {InferenceSession} from 'onnxruntime-common'; +import {InferenceSession, Tensor} from 'onnxruntime-common'; -import {SerializableModeldata, SerializableSessionMetadata} from './proxy-messages'; +import { prepareInputOutputTensor } from './wasm-core-impl'; +import {SerializableModeldata, SerializableSessionMetadata, TensorMetadata} from './proxy-messages'; import {setSessionOptions} from './session-options'; +import { tensorDataTypeEnumToString, tensorTypeToTypedArrayConstructor } from './wasm-common'; import {getInstance} from './wasm-factory'; import {checkLastError} from './wasm-utils'; +import { setRunOptions } from './run-options'; const NO_TRAIN_FUNCS_MSG = 'Built without training APIs enabled. 
' + 'Make sure to use the onnxruntime-training package for training functionality.'; @@ -140,6 +143,138 @@ export const createTrainingSessionHandle = } }; +export const runTrainStep = async( + trainingSessionId: number, inputIndices: number[], inputTensors: TensorMetadata[], outputIndices: number[], + outputTensors: Array, options: InferenceSession.RunOptions): Promise => { + const wasm = getInstance(); + + const inputCount = inputIndices.length; + const outputCount = outputIndices.length; + + let runOptionsHandle = 0; + let runOptionsAllocs: number[] = []; + + const inputTensorHandles: number[] = []; + const outputTensorHandles: number[] = []; + const inputOutputAllocs: number[] = []; + + const beforeRunStack = wasm.stackSave(); + const inputValuesOffset = wasm.stackAlloc(inputCount * 4); + const outputValuesOffset = wasm.stackAlloc(outputCount * 4); + + try { + [runOptionsHandle, runOptionsAllocs] = setRunOptions(options); + + // TODO: + // move all input and output processing -> wasm heap to one helper method???? 
+ // can abstract out the similarities between input and output + // create input tensors + for (let i = 0; i < inputCount; i++) { + prepareInputOutputTensor(inputTensors[i], inputTensorHandles, inputOutputAllocs, trainingSessionId, inputIndices[i]); + } + + // create output tensors + for (let i = 0; i < outputCount; i++) { + prepareInputOutputTensor( + outputTensors[i], outputTensorHandles, inputOutputAllocs, trainingSessionId, inputCount + outputIndices[i]); + } + + let inputValuesIndex = inputValuesOffset / 4; + let outputValuesIndex = outputValuesOffset / 4; + for (let i = 0; i < inputCount; i++) { + wasm.HEAPU32[inputValuesIndex++] = inputTensorHandles[i]; + } + for (let i = 0; i < outputCount; i++) { + wasm.HEAPU32[outputValuesIndex++] = outputTensorHandles[i]; + } + + let errorCode: number; + + if (wasm._OrtTrainingRunTrainStep) { + errorCode = await wasm._OrtTrainingRunTrainStep(trainingSessionId, inputValuesOffset, inputCount, + outputValuesOffset, outputCount, runOptionsHandle); + } + else { + throw new Error(NO_TRAIN_FUNCS_MSG); + } + + if (errorCode !== 0) { + checkLastError('failed to call OrtTrainingRunTrainStep in the WebAssembly layer'); + } + + const output: TensorMetadata[] = []; + + for (let i = 0; i < outputCount; i++) { + const tensor = wasm.HEAPU32[outputValuesOffset / 4 + i]; + + const beforeGetTensorDataStack = wasm.stackSave(); + // stack allocate 4 pointer value + const tensorDataOffset = wasm.stackAlloc(4 * 4); + + let keepOutputTensor = false; + let type: Tensor.Type|undefined, dataOffset = 0; + try { + const errorCode = wasm._OrtGetTensorData( + tensor, tensorDataOffset, tensorDataOffset + 4, tensorDataOffset + 8, tensorDataOffset + 12); + if (errorCode !== 0) { + checkLastError(`Can't access output tensor data on index ${i}.`); + } + let tensorDataIndex = tensorDataOffset / 4; + const dataType = wasm.HEAPU32[tensorDataIndex++]; + dataOffset = wasm.HEAPU32[tensorDataIndex++]; + const dimsOffset = wasm.HEAPU32[tensorDataIndex++]; + const 
dimsLength = wasm.HEAPU32[tensorDataIndex++]; + const dims = []; + for (let i = 0; i < dimsLength; i++) { + dims.push(wasm.HEAPU32[dimsOffset / 4 + i]); + } + wasm._OrtFree(dimsOffset); + + const size = dims.reduce((a, b) => a * b, 1); + type = tensorDataTypeEnumToString(dataType); + + if (type === 'string') { + const stringData: string[] = []; + let dataIndex = dataOffset / 4; + for (let i = 0; i < size; i++) { + const offset = wasm.HEAPU32[dataIndex++]; + const maxBytesToRead = i === size - 1 ? undefined : wasm.HEAPU32[dataIndex] - offset; + stringData.push(wasm.UTF8ToString(offset, maxBytesToRead)); + } + output.push([type, dims, stringData, 'cpu']); + } else { + const typedArrayConstructor = tensorTypeToTypedArrayConstructor(type); + const data = new typedArrayConstructor(size); + new Uint8Array(data.buffer, data.byteOffset, data.byteLength) + .set(wasm.HEAPU8.subarray(dataOffset, dataOffset + data.byteLength)); + output.push([type, dims, data, 'cpu']); + } + } finally { + wasm.stackRestore(beforeGetTensorDataStack); + if (type === 'string' && dataOffset) { + wasm._free(dataOffset); + } + if (!keepOutputTensor) { + wasm._OrtReleaseTensor(tensor); + } + } + } + + return output; + } finally { + wasm.stackRestore(beforeRunStack); + + inputTensorHandles.forEach(v => wasm._OrtReleaseTensor(v)); + outputTensorHandles.forEach(v => wasm._OrtReleaseTensor(v)); + inputOutputAllocs.forEach(p => wasm._free(p)); + + if (runOptionsHandle !== 0) { + wasm._OrtReleaseRunOptions(runOptionsHandle); + } + runOptionsAllocs.forEach(p => wasm._free(p)); + } +}; + export const releaseTrainingSessionAndCheckpoint = (checkpointId: number, sessionId: number, inputNamesUTF8Encoded: number[], outputNamesUTF8Encoded: number[]): void => { From 3ca1c27955fef55f29850d485725efb9bb15e979 Mon Sep 17 00:00:00 2001 From: carzh Date: Tue, 17 Oct 2023 16:45:27 -0700 Subject: [PATCH 15/18] light refactoring renamed session-handler for inference files lint + format --- 
js/common/lib/training-session-impl.ts | 23 +- js/web/lib/backend-onnxjs.ts | 2 +- js/web/lib/backend-wasm-training.ts | 2 +- js/web/lib/backend-wasm.ts | 2 +- ...andler.ts => session-handler-inference.ts} | 0 ...andler.ts => session-handler-inference.ts} | 0 ...raining.ts => session-handler-training.ts} | 10 +- js/web/lib/wasm/wasm-training-core-impl.ts | 222 ++++++++++-------- 8 files changed, 142 insertions(+), 119 deletions(-) rename js/web/lib/onnxjs/{session-handler.ts => session-handler-inference.ts} (100%) rename js/web/lib/wasm/{session-handler.ts => session-handler-inference.ts} (100%) rename js/web/lib/wasm/{session-handler-for-training.ts => session-handler-training.ts} (91%) diff --git a/js/common/lib/training-session-impl.ts b/js/common/lib/training-session-impl.ts index e2d962677f30c..faf597931165a 100644 --- a/js/common/lib/training-session-impl.ts +++ b/js/common/lib/training-session-impl.ts @@ -4,9 +4,9 @@ import {resolveBackend} from './backend-impl.js'; import {TrainingSessionHandler} from './backend.js'; import {InferenceSession as InferenceSession} from './inference-session.js'; +import {OnnxValue} from './onnx-value.js'; +import {Tensor} from './tensor.js'; import {TrainingSession as TrainingSessionInterface, TrainingSessionCreateOptions} from './training-session.js'; -import { OnnxValue } from './onnx-value.js'; -import { Tensor } from './tensor.js'; type SessionOptions = InferenceSession.SessionOptions; type FeedsType = InferenceSession.FeedsType; @@ -49,17 +49,8 @@ export class TrainingSession implements TrainingSessionInterface { } } - async loadParametersBuffer(_array: Uint8Array, _trainableOnly: boolean): Promise { - throw new Error('Method not implemented.'); - } - - async getContiguousParameters(_trainableOnly: boolean): Promise { - throw new Error('Method not implemented.'); - } - runTrainStep(feeds: FeedsType, options?: RunOptions): Promise; - runTrainStep( - feeds: FeedsType, fetches: FetchesType, options?: RunOptions): Promise; + 
runTrainStep(feeds: FeedsType, fetches: FetchesType, options?: RunOptions): Promise; async runTrainStep(feeds: FeedsType, arg1?: FetchesType|RunOptions, arg2?: RunOptions): Promise { const fetches: {[name: string]: OnnxValue|null} = {}; let options: RunOptions = {}; @@ -159,6 +150,14 @@ export class TrainingSession implements TrainingSessionInterface { return returnValue; } + async loadParametersBuffer(_array: Uint8Array, _trainableOnly: boolean): Promise { + throw new Error('Method not implemented.'); + } + + async getContiguousParameters(_trainableOnly: boolean): Promise { + throw new Error('Method not implemented.'); + } + async release(): Promise { return this.handler.dispose(); } diff --git a/js/web/lib/backend-onnxjs.ts b/js/web/lib/backend-onnxjs.ts index 5ea7de809a495..7176823c9bf13 100644 --- a/js/web/lib/backend-onnxjs.ts +++ b/js/web/lib/backend-onnxjs.ts @@ -5,7 +5,7 @@ import {Backend, InferenceSession, InferenceSessionHandler} from 'onnxruntime-common'; import {Session} from './onnxjs/session'; -import {OnnxjsSessionHandler} from './onnxjs/session-handler'; +import {OnnxjsSessionHandler} from './onnxjs/session-handler-inference'; class OnnxjsBackend implements Backend { // eslint-disable-next-line @typescript-eslint/no-empty-function diff --git a/js/web/lib/backend-wasm-training.ts b/js/web/lib/backend-wasm-training.ts index 98e40807aa29c..09dac3a85311c 100644 --- a/js/web/lib/backend-wasm-training.ts +++ b/js/web/lib/backend-wasm-training.ts @@ -4,7 +4,7 @@ import {InferenceSession, TrainingSessionHandler} from 'onnxruntime-common'; import {OnnxruntimeWebAssemblyBackend} from './backend-wasm'; -import {OnnxruntimeWebAssemblyTrainingSessionHandler} from './wasm/session-handler-for-training'; +import {OnnxruntimeWebAssemblyTrainingSessionHandler} from './wasm/session-handler-training'; class OnnxruntimeTrainingWebAssemblyBackend extends OnnxruntimeWebAssemblyBackend { async createTrainingSessionHandler( diff --git a/js/web/lib/backend-wasm.ts 
b/js/web/lib/backend-wasm.ts index 5740263583031..78edcc90f55f9 100644 --- a/js/web/lib/backend-wasm.ts +++ b/js/web/lib/backend-wasm.ts @@ -5,7 +5,7 @@ import {cpus} from 'node:os'; import {Backend, env, InferenceSession, InferenceSessionHandler} from 'onnxruntime-common'; import {initializeWebAssemblyInstance} from './wasm/proxy-wrapper'; -import {OnnxruntimeWebAssemblySessionHandler} from './wasm/session-handler'; +import {OnnxruntimeWebAssemblySessionHandler} from './wasm/session-handler-inference'; /** * This function initializes all flags for WebAssembly. diff --git a/js/web/lib/onnxjs/session-handler.ts b/js/web/lib/onnxjs/session-handler-inference.ts similarity index 100% rename from js/web/lib/onnxjs/session-handler.ts rename to js/web/lib/onnxjs/session-handler-inference.ts diff --git a/js/web/lib/wasm/session-handler.ts b/js/web/lib/wasm/session-handler-inference.ts similarity index 100% rename from js/web/lib/wasm/session-handler.ts rename to js/web/lib/wasm/session-handler-inference.ts diff --git a/js/web/lib/wasm/session-handler-for-training.ts b/js/web/lib/wasm/session-handler-training.ts similarity index 91% rename from js/web/lib/wasm/session-handler-for-training.ts rename to js/web/lib/wasm/session-handler-training.ts index 9aeca4a28dce4..e754e0bf64282 100644 --- a/js/web/lib/wasm/session-handler-for-training.ts +++ b/js/web/lib/wasm/session-handler-training.ts @@ -1,13 +1,12 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-import {env, InferenceSession, SessionHandler, TrainingSessionHandler, Tensor} from 'onnxruntime-common'; +import {env, InferenceSession, SessionHandler, Tensor, TrainingSessionHandler} from 'onnxruntime-common'; import {SerializableModeldata} from './proxy-messages'; +import {decodeTensorMetadata, encodeTensorMetadata} from './session-handler-inference'; import {createSessionAllocate, initRuntime, isOrtEnvInitialized} from './wasm-core-impl'; -import {createCheckpointHandle, createTrainingSessionHandle, runTrainStep, - releaseTrainingSessionAndCheckpoint} from './wasm-training-core-impl'; -import { encodeTensorMetadata, decodeTensorMetadata } from './session-handler'; +import {createCheckpointHandle, createTrainingSessionHandle, releaseTrainingSessionAndCheckpoint, runTrainStep} from './wasm-training-core-impl'; export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSessionHandler { async loadParametersBuffer(_array: Uint8Array, _trainableOnly: boolean): Promise { @@ -99,7 +98,7 @@ export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSes const results = await runTrainStep(this.sessionId, inputIndices, inputs, outputIndices, outputs, options); const resultMap: SessionHandler.ReturnType = {}; - for (let i = 0; i < results. length; i++) { + for (let i = 0; i < results.length; i++) { resultMap[this.outputNames[outputIndices[i]]] = outputArray[i] ?? 
decodeTensorMetadata(results[i]); } return resultMap; @@ -109,5 +108,4 @@ export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSes return releaseTrainingSessionAndCheckpoint( this.checkpointId, this.sessionId, this.inputEncodedNames, this.outputEncodedNames); } - } diff --git a/js/web/lib/wasm/wasm-training-core-impl.ts b/js/web/lib/wasm/wasm-training-core-impl.ts index 7593bca81ffce..d8bb0fae905f0 100644 --- a/js/web/lib/wasm/wasm-training-core-impl.ts +++ b/js/web/lib/wasm/wasm-training-core-impl.ts @@ -3,13 +3,13 @@ import {InferenceSession, Tensor} from 'onnxruntime-common'; -import { prepareInputOutputTensor } from './wasm-core-impl'; import {SerializableModeldata, SerializableSessionMetadata, TensorMetadata} from './proxy-messages'; +import {setRunOptions} from './run-options'; import {setSessionOptions} from './session-options'; -import { tensorDataTypeEnumToString, tensorTypeToTypedArrayConstructor } from './wasm-common'; +import {tensorDataTypeEnumToString, tensorTypeToTypedArrayConstructor} from './wasm-common'; +import {prepareInputOutputTensor} from './wasm-core-impl'; import {getInstance} from './wasm-factory'; import {checkLastError} from './wasm-utils'; -import { setRunOptions } from './run-options'; const NO_TRAIN_FUNCS_MSG = 'Built without training APIs enabled. 
' + 'Make sure to use the onnxruntime-training package for training functionality.'; @@ -143,10 +143,113 @@ export const createTrainingSessionHandle = } }; +/** + * Prepares input and output tensors by creating the tensors in the WASM side then moving them to the heap + * @param trainingSessionId + * @param indices for each tensor, the index of the input or output name that the tensor corresponds with + * @param tensors list of TensorMetaData + * @param tensorHandles should pass in an empty list of numbers; modified in-place by this method & stores the resulting + * handles of the allocated tensors on the heap + * @param inputOutputAllocs modified in-place by this method + * @param indexAdd constant to add to the index that is passed to prepareInputOutputTensor + */ +const createAndAllocateTensors = + (trainingSessionId: number, indices: number[], tensors: Array, tensorHandles: number[], + inputOutputAllocs: number[], indexAdd: number) => { + const wasm = getInstance(); + + const count = indices.length; + const valuesOffset = wasm.stackAlloc(count * 4); + + // creates the tensors + for (let i = 0; i < count; i++) { + prepareInputOutputTensor( + tensors[i], tensorHandles, inputOutputAllocs, trainingSessionId, indexAdd + indices[i]); + } + + // moves to heap + let valuesIndex = valuesOffset / 4; + for (let i = 0; i < count; i++) { + wasm.HEAPU32[valuesIndex++] = tensorHandles[i]; + } + + return valuesOffset; + }; + +/** + * Move output tensors from the heap to an array + * @param outputValuesOffset + * @param outputCount + * @returns + */ +const moveOutputToTensorMetadataArr = + (outputValuesOffset: number, outputCount: number) => { + const wasm = getInstance(); + const output: TensorMetadata[] = []; + + for (let i = 0; i < outputCount; i++) { + const tensor = wasm.HEAPU32[outputValuesOffset / 4 + i]; + + const beforeGetTensorDataStack = wasm.stackSave(); + // stack allocate 4 pointer value + const tensorDataOffset = wasm.stackAlloc(4 * 4); + + const keepOutputTensor 
= false; + let type: Tensor.Type|undefined, dataOffset = 0; + try { + const errorCode = wasm._OrtGetTensorData( + tensor, tensorDataOffset, tensorDataOffset + 4, tensorDataOffset + 8, tensorDataOffset + 12); + if (errorCode !== 0) { + checkLastError(`Can't access output tensor data on index ${i}.`); + } + let tensorDataIndex = tensorDataOffset / 4; + const dataType = wasm.HEAPU32[tensorDataIndex++]; + dataOffset = wasm.HEAPU32[tensorDataIndex++]; + const dimsOffset = wasm.HEAPU32[tensorDataIndex++]; + const dimsLength = wasm.HEAPU32[tensorDataIndex++]; + const dims = []; + for (let i = 0; i < dimsLength; i++) { + dims.push(wasm.HEAPU32[dimsOffset / 4 + i]); + } + wasm._OrtFree(dimsOffset); + + const size = dims.reduce((a, b) => a * b, 1); + type = tensorDataTypeEnumToString(dataType); + + if (type === 'string') { + const stringData: string[] = []; + let dataIndex = dataOffset / 4; + for (let i = 0; i < size; i++) { + const offset = wasm.HEAPU32[dataIndex++]; + const maxBytesToRead = i === size - 1 ? 
undefined : wasm.HEAPU32[dataIndex] - offset; + stringData.push(wasm.UTF8ToString(offset, maxBytesToRead)); + } + output.push([type, dims, stringData, 'cpu']); + } else { + const typedArrayConstructor = tensorTypeToTypedArrayConstructor(type); + const data = new typedArrayConstructor(size); + new Uint8Array(data.buffer, data.byteOffset, data.byteLength) + .set(wasm.HEAPU8.subarray(dataOffset, dataOffset + data.byteLength)); + output.push([type, dims, data, 'cpu']); + } + } finally { + wasm.stackRestore(beforeGetTensorDataStack); + if (type === 'string' && dataOffset) { + wasm._free(dataOffset); + } + if (!keepOutputTensor) { + wasm._OrtReleaseTensor(tensor); + } + } + } + + return output; + }; + export const runTrainStep = async( - trainingSessionId: number, inputIndices: number[], inputTensors: TensorMetadata[], outputIndices: number[], - outputTensors: Array, options: InferenceSession.RunOptions): Promise => { - const wasm = getInstance(); + trainingSessionId: number, inputIndices: number[], inputTensors: TensorMetadata[], outputIndices: number[], + outputTensors: Array, options: InferenceSession.RunOptions): Promise => { + const wasm = getInstance(); const inputCount = inputIndices.length; const outputCount = outputIndices.length; @@ -159,108 +262,31 @@ export const runTrainStep = async( const inputOutputAllocs: number[] = []; const beforeRunStack = wasm.stackSave(); - const inputValuesOffset = wasm.stackAlloc(inputCount * 4); - const outputValuesOffset = wasm.stackAlloc(outputCount * 4); try { + // prepare parameters by moving them to heap [runOptionsHandle, runOptionsAllocs] = setRunOptions(options); - // TODO: - // move all input and output processing -> wasm heap to one helper method???? 
- // can abstract out the similarities between input and output - // create input tensors - for (let i = 0; i < inputCount; i++) { - prepareInputOutputTensor(inputTensors[i], inputTensorHandles, inputOutputAllocs, trainingSessionId, inputIndices[i]); - } - - // create output tensors - for (let i = 0; i < outputCount; i++) { - prepareInputOutputTensor( - outputTensors[i], outputTensorHandles, inputOutputAllocs, trainingSessionId, inputCount + outputIndices[i]); - } - - let inputValuesIndex = inputValuesOffset / 4; - let outputValuesIndex = outputValuesOffset / 4; - for (let i = 0; i < inputCount; i++) { - wasm.HEAPU32[inputValuesIndex++] = inputTensorHandles[i]; - } - for (let i = 0; i < outputCount; i++) { - wasm.HEAPU32[outputValuesIndex++] = outputTensorHandles[i]; - } - - let errorCode: number; + // handle inputs -- you don't want anything added to the index + const inputValuesOffset = createAndAllocateTensors( + trainingSessionId, inputIndices, inputTensors, inputTensorHandles, inputOutputAllocs, 0); + // handle outputs + // you want inputCount to be added to the index of every output tensor passed to prepareInputOutputTensor + const outputValuesOffset = createAndAllocateTensors( + trainingSessionId, outputIndices, outputTensors, outputTensorHandles, inputOutputAllocs, inputCount); if (wasm._OrtTrainingRunTrainStep) { - errorCode = await wasm._OrtTrainingRunTrainStep(trainingSessionId, inputValuesOffset, inputCount, - outputValuesOffset, outputCount, runOptionsHandle); - } - else { - throw new Error(NO_TRAIN_FUNCS_MSG); - } + const errorCode = wasm._OrtTrainingRunTrainStep( + trainingSessionId, inputValuesOffset, inputCount, outputValuesOffset, outputCount, runOptionsHandle); - if (errorCode !== 0) { - checkLastError('failed to call OrtTrainingRunTrainStep in the WebAssembly layer'); - } - - const output: TensorMetadata[] = []; - - for (let i = 0; i < outputCount; i++) { - const tensor = wasm.HEAPU32[outputValuesOffset / 4 + i]; - - const 
beforeGetTensorDataStack = wasm.stackSave(); - // stack allocate 4 pointer value - const tensorDataOffset = wasm.stackAlloc(4 * 4); - - let keepOutputTensor = false; - let type: Tensor.Type|undefined, dataOffset = 0; - try { - const errorCode = wasm._OrtGetTensorData( - tensor, tensorDataOffset, tensorDataOffset + 4, tensorDataOffset + 8, tensorDataOffset + 12); - if (errorCode !== 0) { - checkLastError(`Can't access output tensor data on index ${i}.`); - } - let tensorDataIndex = tensorDataOffset / 4; - const dataType = wasm.HEAPU32[tensorDataIndex++]; - dataOffset = wasm.HEAPU32[tensorDataIndex++]; - const dimsOffset = wasm.HEAPU32[tensorDataIndex++]; - const dimsLength = wasm.HEAPU32[tensorDataIndex++]; - const dims = []; - for (let i = 0; i < dimsLength; i++) { - dims.push(wasm.HEAPU32[dimsOffset / 4 + i]); - } - wasm._OrtFree(dimsOffset); - - const size = dims.reduce((a, b) => a * b, 1); - type = tensorDataTypeEnumToString(dataType); - - if (type === 'string') { - const stringData: string[] = []; - let dataIndex = dataOffset / 4; - for (let i = 0; i < size; i++) { - const offset = wasm.HEAPU32[dataIndex++]; - const maxBytesToRead = i === size - 1 ? 
undefined : wasm.HEAPU32[dataIndex] - offset; - stringData.push(wasm.UTF8ToString(offset, maxBytesToRead)); - } - output.push([type, dims, stringData, 'cpu']); - } else { - const typedArrayConstructor = tensorTypeToTypedArrayConstructor(type); - const data = new typedArrayConstructor(size); - new Uint8Array(data.buffer, data.byteOffset, data.byteLength) - .set(wasm.HEAPU8.subarray(dataOffset, dataOffset + data.byteLength)); - output.push([type, dims, data, 'cpu']); - } - } finally { - wasm.stackRestore(beforeGetTensorDataStack); - if (type === 'string' && dataOffset) { - wasm._free(dataOffset); - } - if (!keepOutputTensor) { - wasm._OrtReleaseTensor(tensor); - } + if (errorCode !== 0) { + checkLastError('failed to call OrtTrainingRunTrainStep in the WebAssembly layer'); } + } else { + throw new Error(NO_TRAIN_FUNCS_MSG); } - return output; + return moveOutputToTensorMetadataArr(outputValuesOffset, outputCount); } finally { wasm.stackRestore(beforeRunStack); From b44b70508f03718278221719c9f58ff2791f612d Mon Sep 17 00:00:00 2001 From: carzh Date: Sun, 22 Oct 2023 19:02:59 -0700 Subject: [PATCH 16/18] getContiguousParameters & loadParametersBuffer impl wrote untested getContiguousParameters method updated getInputOutputCount and getInputOutputNames signature, added more informative error message updated parameter names according to suggestions semi working getContiguousParameters impl working getContiguousParams, started writing loadParametersBuffer working version of loadParametersBuffer --- js/common/lib/backend.ts | 5 +- js/common/lib/training-session-impl.ts | 12 +- js/common/lib/training-session.ts | 13 +- js/web/lib/wasm/binding/ort-wasm.d.ts | 6 +- js/web/lib/wasm/session-handler-training.ts | 23 +- js/web/lib/wasm/wasm-training-core-impl.ts | 253 +++++++++++++++----- onnxruntime/wasm/api.cc | 50 ++-- onnxruntime/wasm/api.h | 14 +- 8 files changed, 277 insertions(+), 99 deletions(-) diff --git a/js/common/lib/backend.ts b/js/common/lib/backend.ts index 
dd04ef3f15997..fd2e8bb74bbf5 100644 --- a/js/common/lib/backend.ts +++ b/js/common/lib/backend.ts @@ -49,8 +49,9 @@ export interface TrainingSessionHandler extends SessionHandler { feeds: SessionHandler.FeedsType, fetches: SessionHandler.FetchesType, options: InferenceSession.RunOptions): Promise; - loadParametersBuffer(array: Uint8Array, trainableOnly: boolean): Promise; - getContiguousParameters(trainableOnly: boolean): Promise; + getParametersSize(trainableOnly: boolean): Promise; + loadParametersBuffer(array: Float32Array, trainableOnly: boolean): Promise; + getContiguousParameters(trainableOnly: boolean): Promise; } /** diff --git a/js/common/lib/training-session-impl.ts b/js/common/lib/training-session-impl.ts index faf597931165a..48fed4224514f 100644 --- a/js/common/lib/training-session-impl.ts +++ b/js/common/lib/training-session-impl.ts @@ -150,12 +150,16 @@ export class TrainingSession implements TrainingSessionInterface { return returnValue; } - async loadParametersBuffer(_array: Uint8Array, _trainableOnly: boolean): Promise { - throw new Error('Method not implemented.'); + async getParametersSize(trainableOnly: boolean): Promise { + return this.handler.getParametersSize(trainableOnly); } - async getContiguousParameters(_trainableOnly: boolean): Promise { - throw new Error('Method not implemented.'); + async loadParametersBuffer(array: Float32Array, trainableOnly: boolean): Promise { + return this.handler.loadParametersBuffer(array, trainableOnly); + } + + async getContiguousParameters(trainableOnly: boolean): Promise { + return this.handler.getContiguousParameters(trainableOnly); } async release(): Promise { diff --git a/js/common/lib/training-session.ts b/js/common/lib/training-session.ts index 0967d79b33434..40ea16cf05ce4 100644 --- a/js/common/lib/training-session.ts +++ b/js/common/lib/training-session.ts @@ -2,6 +2,7 @@ // Licensed under the MIT License. 
import {InferenceSession} from './inference-session.js'; +import {OnnxValue} from './onnx-value.js'; import {TrainingSession as TrainingSessionImpl} from './training-session-impl.js'; /* eslint-disable @typescript-eslint/no-redeclare */ @@ -49,13 +50,21 @@ export interface TrainingSession { // #endregion // #region copy parameters + + /** + * Retrieves the size of all parameters for the training state. + * + * @param trainableOnly skips non-trainable parameters when true. + */ + getParametersSize(trainableOnly: boolean): Promise; + /** * Copies from a buffer containing parameters to the TrainingSession parameters. * * @param buffer - buffer containing parameters * @param trainableOnly - True if trainable parameters only to be modified, false otherwise. */ - loadParametersBuffer(array: Uint8Array, trainableOnly: boolean): Promise; + loadParametersBuffer(array: Float32Array, trainableOnly: boolean): Promise; /** * Copies from the TrainingSession parameters to a buffer. @@ -63,7 +72,7 @@ export interface TrainingSession { * @param trainableOnly - True if trainable parameters only to be copied, false othrwise. * @returns A promise that resolves to a buffer of the requested parameters. */ - getContiguousParameters(trainableOnly: boolean): Promise; + getContiguousParameters(trainableOnly: boolean): Promise; // #endregion // #region release() diff --git a/js/web/lib/wasm/binding/ort-wasm.d.ts b/js/web/lib/wasm/binding/ort-wasm.d.ts index 060fb1e756ef9..def706f53fc3a 100644 --- a/js/web/lib/wasm/binding/ort-wasm.d.ts +++ b/js/web/lib/wasm/binding/ort-wasm.d.ts @@ -102,8 +102,10 @@ export interface OrtWasmModule extends EmscriptenModule { _OrtTrainingCopyParametersFromBuffer? 
(trainingHandle: number, parametersBuffer: number, parameterCount: number, trainableOnly: boolean): number; - _OrtTrainingGetInputOutputCount?(sessionHandle: number, inputCountOffset: number, outputCountOffset: number): number; - _OrtTrainingGetInputOutputName?(sessionHandle: number, index: number, isInput: boolean): number; + _OrtTrainingGetInputOutputCount? + (trainingHandle: number, inputCount: number, outputCount: number, isEvalModel: boolean): number; + _OrtTrainingGetInputOutputName? + (trainingHandle: number, index: number, isInput: boolean, isEvalModel: boolean): number; _OrtTrainingReleaseSession?(trainingHandle: number): void; // #endregion diff --git a/js/web/lib/wasm/session-handler-training.ts b/js/web/lib/wasm/session-handler-training.ts index e754e0bf64282..af8f6dc0e2dd2 100644 --- a/js/web/lib/wasm/session-handler-training.ts +++ b/js/web/lib/wasm/session-handler-training.ts @@ -1,20 +1,15 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-import {env, InferenceSession, SessionHandler, Tensor, TrainingSessionHandler} from 'onnxruntime-common'; +import {env, InferenceSession, OnnxValue, SessionHandler, Tensor, TrainingSessionHandler} from 'onnxruntime-common'; import {SerializableModeldata} from './proxy-messages'; import {decodeTensorMetadata, encodeTensorMetadata} from './session-handler-inference'; import {createSessionAllocate, initRuntime, isOrtEnvInitialized} from './wasm-core-impl'; -import {createCheckpointHandle, createTrainingSessionHandle, releaseTrainingSessionAndCheckpoint, runTrainStep} from './wasm-training-core-impl'; +import {createCheckpointHandle, createTrainingSessionHandle, getContiguousParameters, getParametersSize, + releaseTrainingSessionAndCheckpoint, runTrainStep, loadParametersBuffer} from './wasm-training-core-impl'; export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSessionHandler { - async loadParametersBuffer(_array: Uint8Array, _trainableOnly: boolean): Promise { - throw new Error('Method not implemented.'); - } - async getContiguousParameters(_trainableOnly: boolean): Promise { - throw new Error('Method not implemented.'); - } private sessionId: number; private checkpointId: number; @@ -104,6 +99,18 @@ export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSes return resultMap; } + async getParametersSize(trainableOnly: boolean): Promise { + return getParametersSize(this.sessionId, trainableOnly); + } + + async loadParametersBuffer(array: Float32Array, trainableOnly: boolean): Promise { + await loadParametersBuffer(this.sessionId, array, trainableOnly); + } + async getContiguousParameters(trainableOnly: boolean): Promise { + const tensorResult = await getContiguousParameters(this.sessionId, trainableOnly); + return decodeTensorMetadata(tensorResult); + } + async dispose(): Promise { return releaseTrainingSessionAndCheckpoint( this.checkpointId, this.sessionId, this.inputEncodedNames, this.outputEncodedNames); diff --git 
a/js/web/lib/wasm/wasm-training-core-impl.ts b/js/web/lib/wasm/wasm-training-core-impl.ts index d8bb0fae905f0..75035e4b9f694 100644 --- a/js/web/lib/wasm/wasm-training-core-impl.ts +++ b/js/web/lib/wasm/wasm-training-core-impl.ts @@ -6,22 +6,25 @@ import {InferenceSession, Tensor} from 'onnxruntime-common'; import {SerializableModeldata, SerializableSessionMetadata, TensorMetadata} from './proxy-messages'; import {setRunOptions} from './run-options'; import {setSessionOptions} from './session-options'; -import {tensorDataTypeEnumToString, tensorTypeToTypedArrayConstructor} from './wasm-common'; +import {dataLocationStringToEnum, tensorDataTypeEnumToString, tensorDataTypeStringToEnum, tensorTypeToTypedArrayConstructor} from './wasm-common'; import {prepareInputOutputTensor} from './wasm-core-impl'; import {getInstance} from './wasm-factory'; import {checkLastError} from './wasm-utils'; -const NO_TRAIN_FUNCS_MSG = 'Built without training APIs enabled. ' + - 'Make sure to use the onnxruntime-training package for training functionality.'; +const NO_TRAIN_FUNCS_MSG = + `Built without training API's enabled. Use the onnxruntime-web/training import for training \ + functionality, and make sure that all the correct artifacts are built & moved to the correct folder if \ + using a custom build. 
Check https://onnxruntime.ai/docs/build/web.html for more information.`; export const createCheckpointHandle = (checkpointData: SerializableModeldata): number => { const wasm = getInstance(); + const [checkpointDataOffset, checkpointDataLength] = checkpointData; let checkpointHandle = 0; try { if (wasm._OrtTrainingLoadCheckpoint) { - checkpointHandle = wasm._OrtTrainingLoadCheckpoint(checkpointData[0], checkpointData[1]); + checkpointHandle = wasm._OrtTrainingLoadCheckpoint(checkpointDataOffset, checkpointDataLength); } else { throw new Error(NO_TRAIN_FUNCS_MSG); } @@ -47,7 +50,7 @@ const getTrainingModelInputOutputCount = (trainingSessionId: number): [number, n try { const dataOffset = wasm.stackAlloc(8); if (wasm._OrtTrainingGetInputOutputCount) { - const errorCode = wasm._OrtTrainingGetInputOutputCount(trainingSessionId, dataOffset, dataOffset + 4); + const errorCode = wasm._OrtTrainingGetInputOutputCount(trainingSessionId, dataOffset, dataOffset + 4, false); if (errorCode !== 0) { checkLastError('Can\'t get session input/output count.'); } @@ -68,7 +71,7 @@ const getTrainingNamesLoop = (trainingSessionId: number, count: number, isInput: for (let i = 0; i < count; i++) { if (wasm._OrtTrainingGetInputOutputName) { - const name = wasm._OrtTrainingGetInputOutputName(trainingSessionId, i, isInput); + const name = wasm._OrtTrainingGetInputOutputName(trainingSessionId, i, isInput, false); if (name === 0) { checkLastError('Can\'t get input or output name'); } @@ -182,69 +185,65 @@ const createAndAllocateTensors = * @param outputCount * @returns */ -const moveOutputToTensorMetadataArr = - (outputValuesOffset: number, outputCount: number) => { - const wasm = getInstance(); - const output: TensorMetadata[] = []; +const moveOutputToTensorMetadataArr = (outputValuesOffset: number, outputCount: number) => { + const wasm = getInstance(); + const output: TensorMetadata[] = []; - for (let i = 0; i < outputCount; i++) { - const tensor = wasm.HEAPU32[outputValuesOffset / 4 + i]; 
+ for (let i = 0; i < outputCount; i++) { + const tensor = wasm.HEAPU32[outputValuesOffset / 4 + i]; - const beforeGetTensorDataStack = wasm.stackSave(); - // stack allocate 4 pointer value - const tensorDataOffset = wasm.stackAlloc(4 * 4); + const beforeGetTensorDataStack = wasm.stackSave(); + // stack allocate 4 pointer value + const tensorDataOffset = wasm.stackAlloc(4 * 4); - const keepOutputTensor = false; - let type: Tensor.Type|undefined, dataOffset = 0; - try { - const errorCode = wasm._OrtGetTensorData( - tensor, tensorDataOffset, tensorDataOffset + 4, tensorDataOffset + 8, tensorDataOffset + 12); - if (errorCode !== 0) { - checkLastError(`Can't access output tensor data on index ${i}.`); - } - let tensorDataIndex = tensorDataOffset / 4; - const dataType = wasm.HEAPU32[tensorDataIndex++]; - dataOffset = wasm.HEAPU32[tensorDataIndex++]; - const dimsOffset = wasm.HEAPU32[tensorDataIndex++]; - const dimsLength = wasm.HEAPU32[tensorDataIndex++]; - const dims = []; - for (let i = 0; i < dimsLength; i++) { - dims.push(wasm.HEAPU32[dimsOffset / 4 + i]); - } - wasm._OrtFree(dimsOffset); - - const size = dims.reduce((a, b) => a * b, 1); - type = tensorDataTypeEnumToString(dataType); - - if (type === 'string') { - const stringData: string[] = []; - let dataIndex = dataOffset / 4; - for (let i = 0; i < size; i++) { - const offset = wasm.HEAPU32[dataIndex++]; - const maxBytesToRead = i === size - 1 ? 
undefined : wasm.HEAPU32[dataIndex] - offset; - stringData.push(wasm.UTF8ToString(offset, maxBytesToRead)); - } - output.push([type, dims, stringData, 'cpu']); - } else { - const typedArrayConstructor = tensorTypeToTypedArrayConstructor(type); - const data = new typedArrayConstructor(size); - new Uint8Array(data.buffer, data.byteOffset, data.byteLength) - .set(wasm.HEAPU8.subarray(dataOffset, dataOffset + data.byteLength)); - output.push([type, dims, data, 'cpu']); - } - } finally { - wasm.stackRestore(beforeGetTensorDataStack); - if (type === 'string' && dataOffset) { - wasm._free(dataOffset); - } - if (!keepOutputTensor) { - wasm._OrtReleaseTensor(tensor); - } + let type: Tensor.Type|undefined, dataOffset = 0; + try { + const errorCode = wasm._OrtGetTensorData( + tensor, tensorDataOffset, tensorDataOffset + 4, tensorDataOffset + 8, tensorDataOffset + 12); + if (errorCode !== 0) { + checkLastError(`Can't access output tensor data on index ${i}.`); + } + let tensorDataIndex = tensorDataOffset / 4; + const dataType = wasm.HEAPU32[tensorDataIndex++]; + dataOffset = wasm.HEAPU32[tensorDataIndex++]; + const dimsOffset = wasm.HEAPU32[tensorDataIndex++]; + const dimsLength = wasm.HEAPU32[tensorDataIndex++]; + const dims = []; + for (let i = 0; i < dimsLength; i++) { + dims.push(wasm.HEAPU32[dimsOffset / 4 + i]); + } + wasm._OrtFree(dimsOffset); + + const size = dims.reduce((a, b) => a * b, 1); + type = tensorDataTypeEnumToString(dataType); + + if (type === 'string') { + const stringData: string[] = []; + let dataIndex = dataOffset / 4; + for (let i = 0; i < size; i++) { + const offset = wasm.HEAPU32[dataIndex++]; + const maxBytesToRead = i === size - 1 ? 
undefined : wasm.HEAPU32[dataIndex] - offset; + stringData.push(wasm.UTF8ToString(offset, maxBytesToRead)); } + output.push([type, dims, stringData, 'cpu']); + } else { + const typedArrayConstructor = tensorTypeToTypedArrayConstructor(type); + const data = new typedArrayConstructor(size); + new Uint8Array(data.buffer, data.byteOffset, data.byteLength) + .set(wasm.HEAPU8.subarray(dataOffset, dataOffset + data.byteLength)); + output.push([type, dims, data, 'cpu']); } + } finally { + wasm.stackRestore(beforeGetTensorDataStack); + if (type === 'string' && dataOffset) { + wasm._free(dataOffset); + } + wasm._OrtReleaseTensor(tensor); + } + } - return output; - }; + return output; +}; export const runTrainStep = async( trainingSessionId: number, inputIndices: number[], inputTensors: TensorMetadata[], outputIndices: number[], @@ -301,6 +300,134 @@ export const runTrainStep = async( } }; +export const getParametersSize = (trainingSessionId: number, trainableOnly: boolean): + number => { + const wasm = getInstance(); + const stack = wasm.stackSave(); + + try { + const sizeOffset = wasm.stackAlloc(4); + if (wasm._OrtTrainingGetParametersSize) { + const errorCode = wasm._OrtTrainingGetParametersSize(trainingSessionId, sizeOffset, trainableOnly); + + if (errorCode !== 0) { + checkLastError('Can\'t get parameters size'); + } + + return wasm.HEAP32[sizeOffset / 4]; + } else { + throw new Error(NO_TRAIN_FUNCS_MSG); + } + } finally { + wasm.stackRestore(stack); + } + }; + +export const getContiguousParameters = async(trainingSessionId: number, trainableOnly: boolean): + Promise => { + const wasm = getInstance(); + const parametersSize = getParametersSize(trainingSessionId, trainableOnly); + // alloc buffer -- assumes parameters will be of type float32 + const stack = wasm.stackSave(); + let tensor: number = 0; + + const paramsByteLength = 4 * parametersSize; + const paramsOffset = wasm.stackAlloc(paramsByteLength); + const bufferAlloc = wasm.stackAlloc(paramsOffset/4); + 
wasm.HEAPU8.set(new Float32Array(parametersSize), paramsOffset); + + // handles the dimensions-related createTensor parameters + const dimsOffset = wasm.stackAlloc(4); + const dimsIndex = dimsOffset / 4; + wasm.HEAP32[dimsIndex] = parametersSize; + try { + tensor = wasm._OrtCreateTensor( + tensorDataTypeStringToEnum('float32'), paramsOffset, paramsByteLength, dimsOffset, 1, + dataLocationStringToEnum('cpu')); + if (tensor === 0) { + checkLastError(`Can't create tensor for getContiguousParameters. session=${trainingSessionId}.`); + } + wasm.HEAPU32[bufferAlloc] = tensor; + if (wasm._OrtTrainingCopyParametersToBuffer) { + const errCode = + wasm._OrtTrainingCopyParametersToBuffer(trainingSessionId, tensor, parametersSize, trainableOnly); + if (errCode !== 0) { + checkLastError('Can\'t get contiguous parameters.'); + } + } else { + throw new Error(NO_TRAIN_FUNCS_MSG); + } + + const typedArrayConstructor = tensorTypeToTypedArrayConstructor('float32'); + const data = new typedArrayConstructor(parametersSize); + const output: TensorMetadata[] = []; + new Uint8Array(data.buffer, data.byteOffset, data.byteLength).set(wasm.HEAPU8.subarray(paramsOffset, paramsOffset + paramsByteLength)); + output.push(['float32', [parametersSize], data, 'cpu']); + if (output.length > 1 || output.length < 1) { + throw new Error( + `something unexpected happened in the getContiguousParameters function. 
Expected output length of + one, got ${output.length}`); + } else { + return output[0]; + } + } finally { + console.log('test'); + if (tensor !== 0) { + console.log('tensor is not equal to 0'); + wasm._OrtReleaseTensor(tensor); + } + console.log('test after ortReleaseTensor call but before stackRestore call'); + wasm._free(paramsOffset); + wasm._free(dimsOffset); + wasm._free(bufferAlloc); + wasm.stackRestore(stack); + } + }; + +export const loadParametersBuffer = async (trainingSessionId: number, buffer: Float32Array, trainableOnly: boolean): + Promise => { + const wasm = getInstance(); + const stack = wasm.stackSave(); + const bufferCount = buffer.length; + const bufferByteLength = bufferCount * 4; + const bufferOffset = wasm.stackAlloc(bufferByteLength); + wasm.HEAPU8.set(new Uint8Array(buffer.buffer, buffer.byteOffset, buffer.byteLength), bufferOffset); + const dimsOffset = wasm.stackAlloc(4); + wasm.HEAP32[dimsOffset / 4] = bufferCount; + const dimsLength = 1; + let tensor: number = 0; + const bufferAlloc = wasm.stackAlloc(bufferOffset/4); + + try { + tensor = wasm._OrtCreateTensor(tensorDataTypeStringToEnum('float32'), bufferOffset, bufferByteLength, dimsOffset, dimsLength, dataLocationStringToEnum('cpu')); + if (tensor === 0) { + checkLastError(`Can't create tensor for input/output. 
session=${trainingSessionId}`); + } + wasm.HEAPU32[bufferAlloc] = tensor; + + if (wasm._OrtTrainingCopyParametersFromBuffer) { + const errCode = + wasm._OrtTrainingCopyParametersFromBuffer(trainingSessionId, tensor, bufferCount, trainableOnly); + + if (errCode !== 0) { + checkLastError('Can\'t copy buffer to parameters.'); + } + + } else { + throw new Error(NO_TRAIN_FUNCS_MSG); + } + + } finally { + if (tensor !== 0) { + wasm._OrtReleaseTensor(tensor); + } + wasm.stackRestore(stack); + wasm._free(bufferAlloc); + wasm._free(bufferOffset); + wasm._free(dimsOffset); + } +} + export const releaseTrainingSessionAndCheckpoint = (checkpointId: number, sessionId: number, inputNamesUTF8Encoded: number[], outputNamesUTF8Encoded: number[]): void => { diff --git a/onnxruntime/wasm/api.cc b/onnxruntime/wasm/api.cc index 2645d6f05222f..f8375c0a77ae3 100644 --- a/onnxruntime/wasm/api.cc +++ b/onnxruntime/wasm/api.cc @@ -581,30 +581,52 @@ int EMSCRIPTEN_KEEPALIVE OrtTrainingCopyParametersFromBuffer(ort_training_sessio int EMSCRIPTEN_KEEPALIVE OrtTrainingGetInputOutputCount(ort_training_session_handle_t training_handle, size_t* input_count, - size_t* output_count) { - RETURN_TRAINING_ERROR_CODE_IF_ERROR(TrainingSessionGetTrainingModelInputCount, training_handle, input_count); - RETURN_TRAINING_ERROR_CODE_IF_ERROR(TrainingSessionGetTrainingModelOutputCount, training_handle, output_count); - return ORT_OK; + size_t* output_count, + bool isEvalModel) { + if (isEvalModel) { + RETURN_TRAINING_ERROR_CODE_IF_ERROR(TrainingSessionGetEvalModelInputCount, training_handle, input_count); + RETURN_TRAINING_ERROR_CODE_IF_ERROR(TrainingSessionGetEvalModelOutputCount, training_handle, output_count); + return ORT_OK; + } else { + RETURN_TRAINING_ERROR_CODE_IF_ERROR(TrainingSessionGetTrainingModelInputCount, training_handle, input_count); + RETURN_TRAINING_ERROR_CODE_IF_ERROR(TrainingSessionGetTrainingModelOutputCount, training_handle, output_count); + return ORT_OK; + } } char* EMSCRIPTEN_KEEPALIVE 
OrtTrainingGetInputOutputName(ort_training_session_handle_t training_handle, size_t index, - bool isInput) { + bool isInput, + bool isEvalModel) { OrtAllocator* allocator = nullptr; RETURN_NULLPTR_IF_ERROR(GetAllocatorWithDefaultOptions, &allocator); char* name = nullptr; - if (isInput) { - return (CHECK_TRAINING_STATUS(TrainingSessionGetTrainingModelInputName, training_handle, index, - allocator, &name) == ORT_OK) - ? name - : nullptr; + if (isEvalModel) { + if (isInput) { + return (CHECK_TRAINING_STATUS(TrainingSessionGetEvalModelInputName, training_handle, index, + allocator, &name) == ORT_OK) + ? name + : nullptr; + } else { + return (CHECK_TRAINING_STATUS(TrainingSessionGetEvalModelOutputName, training_handle, index, + allocator, &name) == ORT_OK) + ? name + : nullptr; + } } else { - return (CHECK_TRAINING_STATUS(TrainingSessionGetTrainingModelOutputName, training_handle, index, - allocator, &name) == ORT_OK) - ? name - : nullptr; + if (isInput) { + return (CHECK_TRAINING_STATUS(TrainingSessionGetTrainingModelInputName, training_handle, index, + allocator, &name) == ORT_OK) + ? name + : nullptr; + } else { + return (CHECK_TRAINING_STATUS(TrainingSessionGetTrainingModelOutputName, training_handle, index, + allocator, &name) == ORT_OK) + ? name + : nullptr; + } } } diff --git a/onnxruntime/wasm/api.h b/onnxruntime/wasm/api.h index ea21eb8a9e8c8..d7bc84c0f00bd 100644 --- a/onnxruntime/wasm/api.h +++ b/onnxruntime/wasm/api.h @@ -433,27 +433,33 @@ int EMSCRIPTEN_KEEPALIVE OrtTrainingCopyParametersFromBuffer(ort_training_sessio bool trainable_only); /** - * Gets the input count and output count of the training model associated with the given training handle. + * Gets the input count and output count of the training or eval model associated with the given training handle. 
* @param traning_handle handle of the traning session * @param input_count [out] a pointer to a size_t variable to accept input_count * @param output_count [out] a pointer to a size_t variable to accept output_count + * @param isEvalModel when false, returns input & output count of the training model. When true, returns input & output + * count of the eval model. * @returns ORT error code. If not zero, call OrtGetLastError() to get a detailed error message. */ int EMSCRIPTEN_KEEPALIVE OrtTrainingGetInputOutputCount(ort_training_session_handle_t training_handle, size_t* input_count, - size_t* output_count); + size_t* output_count, + bool isEvalModel); /** - * Gets the input or output name at the specified index associated with the training model from the + * Gets the input or output name at the specified index associated with the training or eval model from the * given training session. * @param traning_handle handle of the traning session * @param index the input or output index * @param isInput if true, this method retrieves an input name. If false, this method retrieves an output name. + * @param isEvalModel when false, returns input & output names of the training model. When true, returns input & output + * names of the eval model. * @returns a pointer to a buffer which contains C-style string. Caller must release the C style string after use by */ char* EMSCRIPTEN_KEEPALIVE OrtTrainingGetInputOutputName(ort_training_session_handle_t training_handle, size_t index, - bool isInput); + bool isInput, + bool isEvalModel); /** * @brief Release the specified ORT training session. 
From c74112e164a028fcdb6b9a473f9d4f33e58f93fb Mon Sep 17 00:00:00 2001 From: carzh Date: Tue, 24 Oct 2023 16:05:49 -0700 Subject: [PATCH 17/18] lint + format + added error code checking wrapper --- js/web/lib/wasm/session-handler-training.ts | 3 +- js/web/lib/wasm/wasm-training-core-impl.ts | 266 ++++++++++---------- 2 files changed, 136 insertions(+), 133 deletions(-) diff --git a/js/web/lib/wasm/session-handler-training.ts b/js/web/lib/wasm/session-handler-training.ts index af8f6dc0e2dd2..0f58d6d288378 100644 --- a/js/web/lib/wasm/session-handler-training.ts +++ b/js/web/lib/wasm/session-handler-training.ts @@ -6,8 +6,7 @@ import {env, InferenceSession, OnnxValue, SessionHandler, Tensor, TrainingSessio import {SerializableModeldata} from './proxy-messages'; import {decodeTensorMetadata, encodeTensorMetadata} from './session-handler-inference'; import {createSessionAllocate, initRuntime, isOrtEnvInitialized} from './wasm-core-impl'; -import {createCheckpointHandle, createTrainingSessionHandle, getContiguousParameters, getParametersSize, - releaseTrainingSessionAndCheckpoint, runTrainStep, loadParametersBuffer} from './wasm-training-core-impl'; +import {createCheckpointHandle, createTrainingSessionHandle, getContiguousParameters, getParametersSize, loadParametersBuffer, releaseTrainingSessionAndCheckpoint, runTrainStep} from './wasm-training-core-impl'; export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSessionHandler { private sessionId: number; diff --git a/js/web/lib/wasm/wasm-training-core-impl.ts b/js/web/lib/wasm/wasm-training-core-impl.ts index 75035e4b9f694..b63f3fd1da311 100644 --- a/js/web/lib/wasm/wasm-training-core-impl.ts +++ b/js/web/lib/wasm/wasm-training-core-impl.ts @@ -16,6 +16,22 @@ const NO_TRAIN_FUNCS_MSG = functionality, and make sure that all the correct artifacts are built & moved to the correct folder if \ using a custom build. 
Check https://onnxruntime.ai/docs/build/web.html for more information.`; +/** + * Runs the checkLastError function which will throw an error, if the provided error code matches the specified + * pattern for an error code. + * @param errCode number to be evaluated to determine whether it is an error code + * @param message message to pass into checkLastError + * @param checkNeqZero when true, treats not equal to zero as an error. + * When false, treats equal to zero as an error. + */ +const ifErrCodeCheckLastError = (errCode: number, message: string, checkNeqZero = true) => { + if (checkNeqZero && errCode !== 0) { + checkLastError(message); + } else if (!checkNeqZero && errCode === 0) { + checkLastError(message); + } +}; + export const createCheckpointHandle = (checkpointData: SerializableModeldata): number => { const wasm = getInstance(); @@ -29,9 +45,8 @@ export const createCheckpointHandle = (checkpointData: SerializableModeldata): n throw new Error(NO_TRAIN_FUNCS_MSG); } - if (checkpointHandle === 0) { - checkLastError('Error occurred when trying to create a CheckpointState.'); - } + ifErrCodeCheckLastError(checkpointHandle, 'Error occurred when trying to create a CheckpointState', false); + return checkpointHandle; } catch (e) { if (wasm._OrtTrainingReleaseCheckpoint && checkpointHandle !== 0) { @@ -51,9 +66,8 @@ const getTrainingModelInputOutputCount = (trainingSessionId: number): [number, n const dataOffset = wasm.stackAlloc(8); if (wasm._OrtTrainingGetInputOutputCount) { const errorCode = wasm._OrtTrainingGetInputOutputCount(trainingSessionId, dataOffset, dataOffset + 4, false); - if (errorCode !== 0) { - checkLastError('Can\'t get session input/output count.'); - } + ifErrCodeCheckLastError(errorCode, 'Can\'t get session input/output count.'); + return [wasm.HEAP32[dataOffset / 4], wasm.HEAP32[dataOffset / 4 + 1]]; } else { throw new Error(NO_TRAIN_FUNCS_MSG); @@ -72,9 +86,7 @@ const getTrainingNamesLoop = (trainingSessionId: number, count: number, isInput: for (let i = 0; i <
count; i++) { if (wasm._OrtTrainingGetInputOutputName) { const name = wasm._OrtTrainingGetInputOutputName(trainingSessionId, i, isInput, false); - if (name === 0) { - checkLastError('Can\'t get input or output name'); - } + ifErrCodeCheckLastError(name, `Can't get input or output name -- is input: ${isInput}, index ${i}`, false); namesUTF8Encoded.push(name); names.push(wasm.UTF8ToString(name)); @@ -119,9 +131,7 @@ export const createTrainingSessionHandle = throw new Error(NO_TRAIN_FUNCS_MSG); } - if (trainingSessionHandle === 0) { - checkLastError('Error occurred when trying to create a TrainingSession.'); - } + ifErrCodeCheckLastError(trainingSessionHandle, 'Error occurred when trying to create a TrainingSession', false); [inputNames, inputNamesUTF8Encoded, outputNames, outputNamesUTF8Encoded] = getTrainingModelInputOutputNames(trainingSessionHandle); @@ -200,9 +210,8 @@ const moveOutputToTensorMetadataArr = (outputValuesOffset: number, outputCount: try { const errorCode = wasm._OrtGetTensorData( tensor, tensorDataOffset, tensorDataOffset + 4, tensorDataOffset + 8, tensorDataOffset + 12); - if (errorCode !== 0) { - checkLastError(`Can't access output tensor data on index ${i}.`); - } + ifErrCodeCheckLastError(errorCode, `Can't access output tensor data on index ${i}.`); + let tensorDataIndex = tensorDataOffset / 4; const dataType = wasm.HEAPU32[tensorDataIndex++]; dataOffset = wasm.HEAPU32[tensorDataIndex++]; @@ -278,9 +287,7 @@ export const runTrainStep = async( const errorCode = wasm._OrtTrainingRunTrainStep( trainingSessionId, inputValuesOffset, inputCount, outputValuesOffset, outputCount, runOptionsHandle); - if (errorCode !== 0) { - checkLastError('failed to call OrtTrainingRunTrainStep in the WebAssembly layer'); - } + ifErrCodeCheckLastError(errorCode, 'failed to call OrtTrainingRunTrainStep in the WebAssembly layer'); } else { throw new Error(NO_TRAIN_FUNCS_MSG); } @@ -300,133 +307,130 @@ export const runTrainStep = async( } }; -export const 
getParametersSize = (trainingSessionId: number, trainableOnly: boolean): - number => { - const wasm = getInstance(); - const stack = wasm.stackSave(); +export const getParametersSize = (trainingSessionId: number, trainableOnly: boolean): number => { + const wasm = getInstance(); + const stack = wasm.stackSave(); - try { - const sizeOffset = wasm.stackAlloc(4); - if (wasm._OrtTrainingGetParametersSize) { - const errorCode = wasm._OrtTrainingGetParametersSize(trainingSessionId, sizeOffset, trainableOnly); + try { + const sizeOffset = wasm.stackAlloc(4); + if (wasm._OrtTrainingGetParametersSize) { + const errorCode = wasm._OrtTrainingGetParametersSize(trainingSessionId, sizeOffset, trainableOnly); + ifErrCodeCheckLastError(errorCode, 'Can\'t get parameters size'); - if (errorCode !== 0) { - checkLastError('Can\'t get parameters size'); - } + return wasm.HEAP32[sizeOffset / 4]; + } else { + throw new Error(NO_TRAIN_FUNCS_MSG); + } + } finally { + wasm.stackRestore(stack); + } +}; - return wasm.HEAP32[sizeOffset / 4]; - } else { - throw new Error(NO_TRAIN_FUNCS_MSG); - } - } finally { - wasm.stackRestore(stack); - } - }; +export const getContiguousParameters = + async(trainingSessionId: number, trainableOnly: boolean): Promise => { + const wasm = getInstance(); + const stack = wasm.stackSave(); -export const getContiguousParameters = async(trainingSessionId: number, trainableOnly: boolean): - Promise => { - const wasm = getInstance(); - const parametersSize = getParametersSize(trainingSessionId, trainableOnly); - // alloc buffer -- assumes parameters will be of type float32 - const stack = wasm.stackSave(); - let tensor: number = 0; - - const paramsByteLength = 4 * parametersSize; - const paramsOffset = wasm.stackAlloc(paramsByteLength); - const bufferAlloc = wasm.stackAlloc(paramsOffset/4); - wasm.HEAPU8.set(new Float32Array(parametersSize), paramsOffset); - - // handles the dimensions-related createTensor parameters - const dimsOffset = wasm.stackAlloc(4); - const 
dimsIndex = dimsOffset / 4; - wasm.HEAP32[dimsIndex] = parametersSize; - try { - tensor = wasm._OrtCreateTensor( - tensorDataTypeStringToEnum('float32'), paramsOffset, paramsByteLength, dimsOffset, 1, - dataLocationStringToEnum('cpu')); - if (tensor === 0) { - checkLastError(`Can't create tensor for getContiguousParameters. session=${trainingSessionId}.`); - } - wasm.HEAPU32[bufferAlloc] = tensor; - if (wasm._OrtTrainingCopyParametersToBuffer) { - const errCode = - wasm._OrtTrainingCopyParametersToBuffer(trainingSessionId, tensor, parametersSize, trainableOnly); - if (errCode !== 0) { - checkLastError('Can\'t get contiguous parameters.'); - } - } else { - throw new Error(NO_TRAIN_FUNCS_MSG); - } + const tensorTypeAsString = 'float32'; + const locationAsString = 'cpu'; + + const parametersSize = getParametersSize(trainingSessionId, trainableOnly); + let tensor = 0; - const typedArrayConstructor = tensorTypeToTypedArrayConstructor('float32'); - const data = new typedArrayConstructor(parametersSize); - const output: TensorMetadata[] = []; - new Uint8Array(data.buffer, data.byteOffset, data.byteLength).set(wasm.HEAPU8.subarray(paramsOffset, paramsOffset + paramsByteLength)); - output.push(['float32', [parametersSize], data, 'cpu']); - if (output.length > 1 || output.length < 1) { - throw new Error( - `something unexpected happened in the getContiguousParameters function. 
Expected output length of + const paramsByteLength = 4 * parametersSize; + const paramsOffset = wasm.stackAlloc(paramsByteLength); + wasm.HEAPU8.set(new Float32Array(parametersSize), paramsOffset); + + const tensorOffset = wasm.stackAlloc(paramsOffset / 4); + + // handles the dimensions-related createTensor parameters + const dims = [parametersSize]; + + const dimsOffset = wasm.stackAlloc(4); + const dimsIndex = dimsOffset / 4; + wasm.HEAP32[dimsIndex] = parametersSize; + + try { + tensor = wasm._OrtCreateTensor( + tensorDataTypeStringToEnum(tensorTypeAsString), paramsOffset, paramsByteLength, dimsOffset, dims.length, + dataLocationStringToEnum(locationAsString)); + ifErrCodeCheckLastError( + tensor, `Can't create tensor for getContiguousParameters. session=${trainingSessionId}.`, false); + + wasm.HEAPU32[tensorOffset] = tensor; + if (wasm._OrtTrainingCopyParametersToBuffer) { + const errCode = wasm._OrtTrainingCopyParametersToBuffer(trainingSessionId, tensor, parametersSize, trainableOnly); + ifErrCodeCheckLastError(errCode, 'Can\'t get contiguous parameters.'); + + } else { + throw new Error(NO_TRAIN_FUNCS_MSG); + } + + const typedArrayConstructor = tensorTypeToTypedArrayConstructor(tensorTypeAsString); + const data = new typedArrayConstructor(parametersSize); + const output: TensorMetadata[] = []; + new Uint8Array(data.buffer, data.byteOffset, data.byteLength) + .set(wasm.HEAPU8.subarray(paramsOffset, paramsOffset + paramsByteLength)); + output.push([tensorTypeAsString, dims, data, locationAsString]); + if (output.length > 1 || output.length < 1) { + throw new Error(`something unexpected happened in the getContiguousParameters function. 
Expected output length of one, got ${output.length}`); - } else { - return output[0]; - } - } finally { - console.log('test'); - if (tensor !== 0) { - console.log('tensor is not equal to 0'); - wasm._OrtReleaseTensor(tensor); - } - console.log('test after ortReleaseTensor call but before stackRestore call'); - wasm._free(paramsOffset); - wasm._free(dimsOffset); - wasm._free(bufferAlloc); - wasm.stackRestore(stack); - } - }; + } else { + return output[0]; + } + } finally { + if (tensor !== 0) { + wasm._OrtReleaseTensor(tensor); + } + wasm._free(paramsOffset); + wasm._free(dimsOffset); + wasm._free(tensorOffset); + wasm.stackRestore(stack); + } +}; -export const loadParametersBuffer = async (trainingSessionId: number, buffer: Float32Array, trainableOnly: boolean): - Promise => { - const wasm = getInstance(); - const stack = wasm.stackSave(); - const bufferCount = buffer.length; - const bufferByteLength = bufferCount * 4; - const bufferOffset = wasm.stackAlloc(bufferByteLength); - wasm.HEAPU8.set(new Uint8Array(buffer.buffer, buffer.byteOffset, buffer.byteLength), bufferOffset); - const dimsOffset = wasm.stackAlloc(4); - wasm.HEAP32[dimsOffset / 4] = bufferCount; - const dimsLength = 1; - let tensor: number = 0; - const bufferAlloc = wasm.stackAlloc(bufferOffset/4); +export const loadParametersBuffer = + async(trainingSessionId: number, buffer: Float32Array, trainableOnly: boolean): Promise => { + const wasm = getInstance(); + const stack = wasm.stackSave(); - try { - tensor = wasm._OrtCreateTensor(tensorDataTypeStringToEnum('float32'), bufferOffset, bufferByteLength, dimsOffset, dimsLength, dataLocationStringToEnum('cpu')); - if (tensor === 0) { - checkLastError(`Can't create tensor for input/output. 
session=${trainingSessionId}`); - } - wasm.HEAPU32[bufferAlloc] = tensor; + const tensorTypeAsString = 'float32'; + const locationAsString = 'cpu'; - if (wasm._OrtTrainingCopyParametersFromBuffer) { - const errCode = - wasm._OrtTrainingCopyParametersFromBuffer(trainingSessionId, tensor, bufferCount, trainableOnly); + const bufferCount = buffer.length; + const bufferByteLength = bufferCount * 4; + const bufferOffset = wasm.stackAlloc(bufferByteLength); + wasm.HEAPU8.set(new Uint8Array(buffer.buffer, buffer.byteOffset, buffer.byteLength), bufferOffset); + const dimsOffset = wasm.stackAlloc(4); + wasm.HEAP32[dimsOffset / 4] = bufferCount; + const dimsLength = 1; + let tensor = 0; + const bufferAlloc = wasm.stackAlloc(bufferOffset / 4); - if (errCode !== 0) { - checkLastError('Can\'t copy buffer to parameters.'); - } + try { + tensor = wasm._OrtCreateTensor( + tensorDataTypeStringToEnum(tensorTypeAsString), bufferOffset, bufferByteLength, dimsOffset, dimsLength, + dataLocationStringToEnum(locationAsString)); + ifErrCodeCheckLastError(tensor, `Can't create tensor for input/output. 
session=${trainingSessionId}`, false); - } else { - throw new Error(NO_TRAIN_FUNCS_MSG); - } + wasm.HEAPU32[bufferAlloc] = tensor; - } finally { - if (tensor !== 0) { + if (wasm._OrtTrainingCopyParametersFromBuffer) { + const errCode = wasm._OrtTrainingCopyParametersFromBuffer(trainingSessionId, tensor, bufferCount, trainableOnly); + ifErrCodeCheckLastError(errCode, 'Can\'t copy buffer to parameters.'); + } else { + throw new Error(NO_TRAIN_FUNCS_MSG); + } + } finally { + if (tensor !== 0) { wasm._OrtReleaseTensor(tensor); - } - wasm.stackRestore(stack); - wasm._free(bufferAlloc); - wasm._free(bufferOffset); - wasm._free(dimsOffset); } -} + wasm.stackRestore(stack); + wasm._free(bufferAlloc); + wasm._free(bufferOffset); + wasm._free(dimsOffset); + } +}; export const releaseTrainingSessionAndCheckpoint = (checkpointId: number, sessionId: number, inputNamesUTF8Encoded: number[], outputNamesUTF8Encoded: number[]): From cd2e4fa0417d5a99ac6aa4ad0acb0d6c82111ca9 Mon Sep 17 00:00:00 2001 From: carzh Date: Tue, 24 Oct 2023 17:10:09 -0700 Subject: [PATCH 18/18] enforced that loadParametersBuffer takes in a buffer that matches the number of parameters --- js/common/lib/training-session-impl.ts | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/js/common/lib/training-session-impl.ts b/js/common/lib/training-session-impl.ts index 48fed4224514f..db99f420bef3a 100644 --- a/js/common/lib/training-session-impl.ts +++ b/js/common/lib/training-session-impl.ts @@ -155,6 +155,11 @@ export class TrainingSession implements TrainingSessionInterface { } async loadParametersBuffer(array: Float32Array, trainableOnly: boolean): Promise { + const paramsSize = await this.getParametersSize(trainableOnly); + if (array.length !== paramsSize) { + throw new Error('Size of the buffer passed into loadParametersBuffer must match the number of parameters in ' + + 'the model. Please use getParametersSize method to check.'); + } return this.handler.loadParametersBuffer(array, trainableOnly); }