From cff6a7f11e5a687b359b3d5c28063afb37fa81dc Mon Sep 17 00:00:00 2001 From: carzh <wolfivyaura@gmail.com> Date: Tue, 29 Aug 2023 21:50:18 +0000 Subject: [PATCH 1/8] added training build configuration --- js/web/lib/build-def.d.ts | 4 ++++ js/web/lib/wasm/wasm-factory.ts | 27 ++++++++++++++++++++------- js/web/script/build.ts | 4 ++++ js/web/webpack.config.js | 10 ++++++++++ 4 files changed, 38 insertions(+), 7 deletions(-) diff --git a/js/web/lib/build-def.d.ts b/js/web/lib/build-def.d.ts index 2049b2663ead3..939486926f319 100644 --- a/js/web/lib/build-def.d.ts +++ b/js/web/lib/build-def.d.ts @@ -30,6 +30,10 @@ interface BuildDefinitions { * defines whether to disable multi-threading feature in WebAssembly backend in the build. */ DISABLE_WASM_THREAD: boolean; + /** + * defines whether to enable training APIs in WebAssembly backend. + */ + ENABLE_TRAINING: boolean; } declare let BUILD_DEFS: BuildDefinitions; diff --git a/js/web/lib/wasm/wasm-factory.ts b/js/web/lib/wasm/wasm-factory.ts index 7648f0c473f07..4ccb06e281b9d 100644 --- a/js/web/lib/wasm/wasm-factory.ts +++ b/js/web/lib/wasm/wasm-factory.ts @@ -8,8 +8,14 @@ import {OrtWasmModule} from './binding/ort-wasm'; import {OrtWasmThreadedModule} from './binding/ort-wasm-threaded'; /* eslint-disable @typescript-eslint/no-require-imports */ -const ortWasmFactory: EmscriptenModuleFactory<OrtWasmModule> = - BUILD_DEFS.DISABLE_WEBGPU ? require('./binding/ort-wasm.js') : require('./binding/ort-wasm-simd.jsep.js'); +let ortWasmFactory: EmscriptenModuleFactory<OrtWasmModule>; + +if (BUILD_DEFS.ENABLE_TRAINING) { + ortWasmFactory = require('./binding/ort-training-wasm-simd.js'); +} +else { + ortWasmFactory = BUILD_DEFS.DISABLE_WEBGPU ? require('./binding/ort-wasm.js') : require('./binding/ort-wasm-simd.jsep.js'); +} const ortWasmFactoryThreaded: EmscriptenModuleFactory<OrtWasmModule> = !BUILD_DEFS.DISABLE_WASM_THREAD ? (BUILD_DEFS.DISABLE_WEBGPU ? require('./binding/ort-wasm-threaded.js') : @@ -71,12 +77,19 @@ const isSimdSupported = (): boolean => { } }; -const getWasmFileName = (useSimd: boolean, useThreads: boolean) => { +const getWasmFileName = (useSimd: boolean, useThreads: boolean, useTraining: boolean) => { + let wasmArtifact : string = 'ort'; + if (useTraining) { + wasmArtifact += '-training'; + } + wasmArtifact += '-wasm'; + if (useSimd) { + wasmArtifact += '-simd'; + } if (useThreads) { - return useSimd ? 'ort-wasm-simd-threaded.wasm' : 'ort-wasm-threaded.wasm'; - } else { - return useSimd ? 'ort-wasm-simd.wasm' : 'ort-wasm.wasm'; + wasmArtifact += '-threaded'; } + return wasmArtifact + '.wasm'; }; export const initializeWebAssembly = async(flags: Env.WebAssemblyFlags): Promise<void> => { @@ -102,7 +115,7 @@ export const initializeWebAssembly = async(flags: Env.WebAssemblyFlags): Promise const wasmPaths = flags.wasmPaths; const wasmPrefixOverride = typeof wasmPaths === 'string' ? wasmPaths : undefined; - const wasmFileName = getWasmFileName(useSimd, useThreads); + const wasmFileName = getWasmFileName(useSimd, useThreads, BUILD_DEFS.ENABLE_TRAINING); const wasmPathOverride = typeof wasmPaths === 'object' ? wasmPaths[wasmFileName] : undefined; let isTimeout = false; diff --git a/js/web/script/build.ts b/js/web/script/build.ts index 03510ae86b85f..bda7da0ab2a8b 100644 --- a/js/web/script/build.ts +++ b/js/web/script/build.ts @@ -33,6 +33,7 @@ const FILTER = args.f || args.filter; const ROOT_FOLDER = path.join(__dirname, '..'); const WASM_BINDING_FOLDER = path.join(ROOT_FOLDER, 'lib', 'wasm', 'binding'); const WASM_BINDING_JS_PATH = path.join(WASM_BINDING_FOLDER, 'ort-wasm.js'); +const TRAINING_WASM_BINDING_SIMD_JS_PATH = path.join(WASM_BINDING_FOLDER, 'ort-training-wasm-simd.js'); const WASM_BINDING_THREADED_JS_PATH = path.join(WASM_BINDING_FOLDER, 'ort-wasm-threaded.js'); const WASM_BINDING_SIMD_THREADED_JSEP_JS_PATH = path.join(WASM_BINDING_FOLDER, 'ort-wasm-simd-threaded.jsep.js'); const WASM_BINDING_THREADED_WORKER_JS_PATH = path.join(WASM_BINDING_FOLDER, 'ort-wasm-threaded.worker.js'); @@ -45,6 +46,7 @@ const WASM_DIST_FOLDER = path.join(ROOT_FOLDER, 'dist'); const WASM_WASM_PATH = path.join(WASM_DIST_FOLDER, 'ort-wasm.wasm'); const WASM_THREADED_WASM_PATH = path.join(WASM_DIST_FOLDER, 'ort-wasm-threaded.wasm'); const WASM_SIMD_WASM_PATH = path.join(WASM_DIST_FOLDER, 'ort-wasm-simd.wasm'); +const TRAINING_WASM_SIMD_WASM_PATH = path.join(WASM_DIST_FOLDER, 'ort-training-wasm-simd.wasm'); const WASM_SIMD_THREADED_WASM_PATH = path.join(WASM_DIST_FOLDER, 'ort-wasm-simd-threaded.wasm'); const WASM_SIMD_JSEP_WASM_PATH = path.join(WASM_DIST_FOLDER, 'ort-wasm-simd.jsep.wasm'); const WASM_SIMD_THREADED_JSEP_WASM_PATH = path.join(WASM_DIST_FOLDER, 'ort-wasm-simd-threaded.jsep.wasm'); @@ -75,6 +77,8 @@ if (WASM) { validateFile(WASM_SIMD_THREADED_WASM_PATH); validateFile(WASM_SIMD_JSEP_WASM_PATH); validateFile(WASM_SIMD_THREADED_JSEP_WASM_PATH); + validateFile(TRAINING_WASM_BINDING_SIMD_JS_PATH); + validateFile(TRAINING_WASM_SIMD_WASM_PATH); } catch (e) { npmlog.error('Build', `WebAssembly files are not ready. build WASM first. ERR: ${e}`); throw e; diff --git a/js/web/webpack.config.js b/js/web/webpack.config.js index 81c69ffdcf6bf..7115a3e3b4258 100644 --- a/js/web/webpack.config.js +++ b/js/web/webpack.config.js @@ -57,6 +57,7 @@ const DEFAULT_BUILD_DEFS = { DISABLE_WASM: false, DISABLE_WASM_PROXY: false, DISABLE_WASM_THREAD: false, + ENABLE_TRAINING: false }; // common config for release bundle @@ -289,6 +290,15 @@ module.exports = () => { buildOrtWebConfig({suffix: '.es6.min', target: 'es6'}), // ort-web.es5.min.js buildOrtWebConfig({suffix: '.es5.min', target: 'es5'}), + + // ort.wasm.min.js + buildOrtConfig({ + suffix: '.training.wasm.min', + build_defs: { + ...DEFAULT_BUILD_DEFS, + ENABLE_TRAINING: true, + } + }), ); case 'node': From 7a1365d93272567feda3a6682ad6a2cb79422898 Mon Sep 17 00:00:00 2001 From: carzh <wolfivyaura@gmail.com> Date: Tue, 29 Aug 2023 22:13:34 +0000 Subject: [PATCH 2/8] edited wasm factory + added training backend --- js/web/lib/backend-wasm-with-training.ts | 17 +++++++++++++++++ js/web/lib/backend-wasm.ts | 2 +- js/web/lib/index.ts | 11 +++++++---- js/web/lib/wasm/wasm-factory.ts | 22 +++++++++------------- 4 files changed, 34 insertions(+), 18 deletions(-) create mode 100644 js/web/lib/backend-wasm-with-training.ts diff --git a/js/web/lib/backend-wasm-with-training.ts b/js/web/lib/backend-wasm-with-training.ts new file mode 100644 index 0000000000000..6d31861fb1bcd --- /dev/null +++ b/js/web/lib/backend-wasm-with-training.ts @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {InferenceSession, TrainingSessionHandler} from 'onnxruntime-common'; + +import {OnnxruntimeWebAssemblyBackend} from './backend-wasm' + +class OnnxruntimeTrainingWebAssemblyBackend extends OnnxruntimeWebAssemblyBackend { + async createTrainingSessionHandler( + checkpointStateUriOrBuffer: string|Uint8Array, trainModelUriOrBuffer: string|Uint8Array, + evalModelUriOrBuffer: string|Uint8Array, optimizerModelUriOrBuffer: string|Uint8Array, + options: InferenceSession.SessionOptions): Promise<TrainingSessionHandler> { + throw new Error('Method not implemented yet.'); + } +} + +export const wasmBackend = new OnnxruntimeTrainingWebAssemblyBackend(); diff --git a/js/web/lib/backend-wasm.ts b/js/web/lib/backend-wasm.ts index 04108c2ad0f66..3f06eea5329a0 100644 --- a/js/web/lib/backend-wasm.ts +++ b/js/web/lib/backend-wasm.ts @@ -32,7 +32,7 @@ export const initializeFlags = (): void => { } }; -class OnnxruntimeWebAssemblyBackend implements Backend { +export class OnnxruntimeWebAssemblyBackend implements Backend { async init(): Promise<void> { // populate wasm flags initializeFlags(); diff --git a/js/web/lib/index.ts b/js/web/lib/index.ts index d5ed536034f3e..03e1108b31d12 100644 --- a/js/web/lib/index.ts +++ b/js/web/lib/index.ts @@ -7,7 +7,7 @@ // So we import code inside the if-clause to allow terser remove the code safely. export * from 'onnxruntime-common'; -import {registerBackend, env} from 'onnxruntime-common'; +import {registerBackend, env, Backend} from 'onnxruntime-common'; import {version} from './version'; if (!BUILD_DEFS.DISABLE_WEBGL) { @@ -16,14 +16,17 @@ if (!BUILD_DEFS.DISABLE_WEBGL) { } if (!BUILD_DEFS.DISABLE_WASM) { - const wasmBackend = require('./backend-wasm').wasmBackend; + const wasmBackend = BUILD_DEFS.ENABLE_TRAINING ? require('./backend-wasm-with-training').wasmBackend : + require('./backend-wasm').wasmBackend; if (!BUILD_DEFS.DISABLE_WEBGPU && typeof navigator !== 'undefined' && navigator.gpu) { registerBackend('webgpu', wasmBackend, 5); } registerBackend('cpu', wasmBackend, 10); registerBackend('wasm', wasmBackend, 10); - registerBackend('xnnpack', wasmBackend, 9); - registerBackend('webnn', wasmBackend, 9); + if (!BUILD_DEFS.ENABLE_TRAINING) { + registerBackend('xnnpack', wasmBackend, 9); + registerBackend('webnn', wasmBackend, 9); + } } Object.defineProperty(env.versions, 'web', {value: version, enumerable: true}); diff --git a/js/web/lib/wasm/wasm-factory.ts b/js/web/lib/wasm/wasm-factory.ts index 4ccb06e281b9d..1ce06b0f63b88 100644 --- a/js/web/lib/wasm/wasm-factory.ts +++ b/js/web/lib/wasm/wasm-factory.ts @@ -12,9 +12,9 @@ let ortWasmFactory: EmscriptenModuleFactory<OrtWasmModule>; if (BUILD_DEFS.ENABLE_TRAINING) { ortWasmFactory = require('./binding/ort-training-wasm-simd.js'); -} -else { - ortWasmFactory = BUILD_DEFS.DISABLE_WEBGPU ? require('./binding/ort-wasm.js') : require('./binding/ort-wasm-simd.jsep.js'); +} else { + ortWasmFactory = + BUILD_DEFS.DISABLE_WEBGPU ? require('./binding/ort-wasm.js') : require('./binding/ort-wasm-simd.jsep.js'); } const ortWasmFactoryThreaded: EmscriptenModuleFactory<OrtWasmModule> = !BUILD_DEFS.DISABLE_WASM_THREAD ? @@ -78,18 +78,14 @@ const isSimdSupported = (): boolean => { }; const getWasmFileName = (useSimd: boolean, useThreads: boolean, useTraining: boolean) => { - let wasmArtifact : string = 'ort'; - if (useTraining) { - wasmArtifact += '-training'; - } - wasmArtifact += '-wasm'; if (useSimd) { - wasmArtifact += '-simd'; - } - if (useThreads) { - wasmArtifact += '-threaded'; + if (useTraining) { + return 'ort-training-wasm-simd.wasm'; + } + return useThreads ? 'ort-wasm-simd-threaded.wasm' : 'ort-wasm-simd.wasm'; + } else { + return useThreads ? 'ort-wasm-threaded.wasm' : 'ort-wasm.wasm'; } - return wasmArtifact + '.wasm'; }; export const initializeWebAssembly = async(flags: Env.WebAssemblyFlags): Promise<void> => { From 71f9dbc94735908d4093e4fa114b62867e29306c Mon Sep 17 00:00:00 2001 From: carzh <wolfivyaura@gmail.com> Date: Thu, 31 Aug 2023 21:02:16 +0000 Subject: [PATCH 3/8] added package.json modification for training + minimized training artifact --- js/web/lib/backend-wasm-with-training.ts | 4 +++- js/web/lib/index.ts | 2 +- js/web/package.json | 4 ++++ js/web/script/build.ts | 26 ++++++++++++++++++++++++ js/web/webpack.config.js | 5 +++-- 5 files changed, 37 insertions(+), 4 deletions(-) diff --git a/js/web/lib/backend-wasm-with-training.ts b/js/web/lib/backend-wasm-with-training.ts index 6d31861fb1bcd..5fac10da7791d 100644 --- a/js/web/lib/backend-wasm-with-training.ts +++ b/js/web/lib/backend-wasm-with-training.ts @@ -3,15 +3,17 @@ import {InferenceSession, TrainingSessionHandler} from 'onnxruntime-common'; -import {OnnxruntimeWebAssemblyBackend} from './backend-wasm' +import {OnnxruntimeWebAssemblyBackend} from './backend-wasm'; class OnnxruntimeTrainingWebAssemblyBackend extends OnnxruntimeWebAssemblyBackend { + /* eslint-disable @typescript-eslint/no-unused-vars */ async createTrainingSessionHandler( checkpointStateUriOrBuffer: string|Uint8Array, trainModelUriOrBuffer: string|Uint8Array, evalModelUriOrBuffer: string|Uint8Array, optimizerModelUriOrBuffer: string|Uint8Array, options: InferenceSession.SessionOptions): Promise<TrainingSessionHandler> { throw new Error('Method not implemented yet.'); } + /* eslint-enable @typescript-eslint/no-unused-vars */ } export const wasmBackend = new OnnxruntimeTrainingWebAssemblyBackend(); diff --git a/js/web/lib/index.ts b/js/web/lib/index.ts index 03e1108b31d12..30d0af5dc9e47 100644 --- a/js/web/lib/index.ts +++ b/js/web/lib/index.ts @@ -7,7 +7,7 @@ // So we import code inside the if-clause to allow terser remove the code safely. export * from 'onnxruntime-common'; -import {registerBackend, env, Backend} from 'onnxruntime-common'; +import {registerBackend, env} from 'onnxruntime-common'; import {version} from './version'; if (!BUILD_DEFS.DISABLE_WEBGL) { diff --git a/js/web/package.json b/js/web/package.json index 8ae5b733e5f21..798434bf5d574 100644 --- a/js/web/package.json +++ b/js/web/package.json @@ -78,6 +78,10 @@ "./webgpu": { "types": "./types.d.ts", "default": "./dist/ort.webgpu.min.js" + }, + "./training": { + "types": "./types.d.ts", + "default": "./dist/ort-web.training.wasm.min.js" } }, "types": "./types.d.ts", diff --git a/js/web/script/build.ts b/js/web/script/build.ts index bda7da0ab2a8b..ff94d1a806a09 100644 --- a/js/web/script/build.ts +++ b/js/web/script/build.ts @@ -34,6 +34,7 @@ const ROOT_FOLDER = path.join(__dirname, '..'); const WASM_BINDING_FOLDER = path.join(ROOT_FOLDER, 'lib', 'wasm', 'binding'); const WASM_BINDING_JS_PATH = path.join(WASM_BINDING_FOLDER, 'ort-wasm.js'); const TRAINING_WASM_BINDING_SIMD_JS_PATH = path.join(WASM_BINDING_FOLDER, 'ort-training-wasm-simd.js'); +const TRAINING_WASM_BINDING_SIMD_MIN_JS_PATH = path.join(WASM_BINDING_FOLDER, 'ort-training-wasm-simd.min.js'); const WASM_BINDING_THREADED_JS_PATH = path.join(WASM_BINDING_FOLDER, 'ort-wasm-threaded.js'); const WASM_BINDING_SIMD_THREADED_JSEP_JS_PATH = path.join(WASM_BINDING_FOLDER, 'ort-wasm-simd-threaded.jsep.js'); const WASM_BINDING_THREADED_WORKER_JS_PATH = path.join(WASM_BINDING_FOLDER, 'ort-wasm-threaded.worker.js'); @@ -93,6 +94,31 @@ if (WASM) { */ `; + npmlog.info('Build', 'Minimizing file "ort-training-wasm-simd.js"...'); + try { + const terser = spawnSync( + 'npx', + [ + 'terser', TRAINING_WASM_BINDING_SIMD_JS_PATH, '--compress', 'passes=2', '--format', 'comments=false', + '--mangle', 'reserved=[_scriptDir]', '--module' + ], + {shell: true, encoding: 'utf-8', cwd: ROOT_FOLDER}); + if (terser.status !== 0) { + console.error(terser.error); + process.exit(terser.status === null ? undefined : terser.status); + } + + fs.writeFileSync(TRAINING_WASM_BINDING_SIMD_MIN_JS_PATH, terser.stdout); + fs.writeFileSync(TRAINING_WASM_BINDING_SIMD_JS_PATH, `${COPYRIGHT_BANNER}${terser.stdout}`); + + validateFile(TRAINING_WASM_BINDING_SIMD_MIN_JS_PATH); + validateFile(TRAINING_WASM_BINDING_SIMD_JS_PATH); + } catch (e) { + npmlog.error('Build', `Failed to run terser on ort-training-wasm-simd.js. ERR: ${e}`); + throw e; + } + npmlog.info('Build', 'Minimizing file "ort-training-wasm-simd.js"... DONE'); + npmlog.info('Build', 'Minimizing file "ort-wasm-threaded.js"...'); try { const terser = spawnSync( diff --git a/js/web/webpack.config.js b/js/web/webpack.config.js index 7115a3e3b4258..1ac421b5246d9 100644 --- a/js/web/webpack.config.js +++ b/js/web/webpack.config.js @@ -104,6 +104,7 @@ function buildConfig({filename, format, target, mode, devtool, build_defs}) { config.resolve.alias['./binding/ort-wasm-threaded.js'] = './binding/ort-wasm-threaded.min.js'; config.resolve.alias['./binding/ort-wasm-threaded-simd.jsep.js'] = './binding/ort-wasm-threaded-simd.jsep.min.js'; config.resolve.alias['./binding/ort-wasm-threaded.worker.js'] = './binding/ort-wasm-threaded.min.worker.js'; + config.resolve.alias['./binding/ort-training-wasm-simd.js'] = './binding/ort-training-wasm-simd.min.js'; const options = defaultTerserPluginOptions(target); options.terserOptions.format.preamble = COPYRIGHT_BANNER; @@ -291,8 +292,8 @@ module.exports = () => { // ort-web.es5.min.js buildOrtWebConfig({suffix: '.es5.min', target: 'es5'}), - // ort.wasm.min.js - buildOrtConfig({ + // ort-web.training.wasm.min.js + buildOrtWebConfig({ suffix: '.training.wasm.min', build_defs: { ...DEFAULT_BUILD_DEFS, From 5adbcf407e5c5b2c2797ff9568b3a4d80ccf2a74 Mon Sep 17 00:00:00 2001 From: carzh <wolfivyaura@gmail.com> Date: Fri, 22 Sep 2023 21:21:42 +0000 Subject: [PATCH 4/8] applied suggestions --- js/web/lib/backend-wasm-with-training.ts | 8 +++----- js/web/lib/wasm/wasm-factory.ts | 6 +++--- js/web/webpack.config.js | 2 ++ 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/js/web/lib/backend-wasm-with-training.ts b/js/web/lib/backend-wasm-with-training.ts index 5fac10da7791d..af5b575c87a7f 100644 --- a/js/web/lib/backend-wasm-with-training.ts +++ b/js/web/lib/backend-wasm-with-training.ts @@ -6,14 +6,12 @@ import {InferenceSession, TrainingSessionHandler} from 'onnxruntime-common'; import {OnnxruntimeWebAssemblyBackend} from './backend-wasm'; class OnnxruntimeTrainingWebAssemblyBackend extends OnnxruntimeWebAssemblyBackend { - /* eslint-disable @typescript-eslint/no-unused-vars */ async createTrainingSessionHandler( - checkpointStateUriOrBuffer: string|Uint8Array, trainModelUriOrBuffer: string|Uint8Array, - evalModelUriOrBuffer: string|Uint8Array, optimizerModelUriOrBuffer: string|Uint8Array, - options: InferenceSession.SessionOptions): Promise<TrainingSessionHandler> { + _checkpointStateUriOrBuffer: string|Uint8Array, _trainModelUriOrBuffer: string|Uint8Array, + _evalModelUriOrBuffer: string|Uint8Array, _optimizerModelUriOrBuffer: string|Uint8Array, + _options: InferenceSession.SessionOptions): Promise<TrainingSessionHandler> { throw new Error('Method not implemented yet.'); } - /* eslint-enable @typescript-eslint/no-unused-vars */ } export const wasmBackend = new OnnxruntimeTrainingWebAssemblyBackend(); diff --git a/js/web/lib/wasm/wasm-factory.ts b/js/web/lib/wasm/wasm-factory.ts index 1ce06b0f63b88..08192b4174272 100644 --- a/js/web/lib/wasm/wasm-factory.ts +++ b/js/web/lib/wasm/wasm-factory.ts @@ -77,9 +77,9 @@ const isSimdSupported = (): boolean => { } }; -const getWasmFileName = (useSimd: boolean, useThreads: boolean, useTraining: boolean) => { +const getWasmFileName = (useSimd: boolean, useThreads: boolean) => { if (useSimd) { - if (useTraining) { + if (BUILD_DEFS.ENABLE_TRAINING) { return 'ort-training-wasm-simd.wasm'; } return useThreads ? 'ort-wasm-simd-threaded.wasm' : 'ort-wasm-simd.wasm'; @@ -111,7 +111,7 @@ export const initializeWebAssembly = async(flags: Env.WebAssemblyFlags): Promise const wasmPaths = flags.wasmPaths; const wasmPrefixOverride = typeof wasmPaths === 'string' ? wasmPaths : undefined; - const wasmFileName = getWasmFileName(useSimd, useThreads, BUILD_DEFS.ENABLE_TRAINING); + const wasmFileName = getWasmFileName(useSimd, useThreads); const wasmPathOverride = typeof wasmPaths === 'object' ? wasmPaths[wasmFileName] : undefined; let isTimeout = false; diff --git a/js/web/webpack.config.js b/js/web/webpack.config.js index 1ac421b5246d9..28d30604dc8d3 100644 --- a/js/web/webpack.config.js +++ b/js/web/webpack.config.js @@ -298,6 +298,8 @@ module.exports = () => { build_defs: { ...DEFAULT_BUILD_DEFS, ENABLE_TRAINING: true, + DISABLE_WEBGL: true, + DISABLE_WEBGPU: true, } }), ); From 5202e3b9363753fd363c4265b88865671dc0d442 Mon Sep 17 00:00:00 2001 From: carzh <wolfivyaura@gmail.com> Date: Fri, 6 Oct 2023 16:26:40 -0700 Subject: [PATCH 5/8] fixed variable names + enforced wasmBackend being a singleton with suggested fix of adding backend-wasm-inference.ts --- js/web/lib/backend-wasm-inference.ts | 5 +++++ ...ckend-wasm-with-training.ts => backend-wasm-training.ts} | 0 js/web/lib/backend-wasm.ts | 2 -- js/web/lib/index.ts | 6 +++--- js/web/lib/wasm/wasm-factory.ts | 4 ++-- 5 files changed, 10 insertions(+), 7 deletions(-) create mode 100644 js/web/lib/backend-wasm-inference.ts rename js/web/lib/{backend-wasm-with-training.ts => backend-wasm-training.ts} (100%) diff --git a/js/web/lib/backend-wasm-inference.ts b/js/web/lib/backend-wasm-inference.ts new file mode 100644 index 0000000000000..059a30bd45a57 --- /dev/null +++ b/js/web/lib/backend-wasm-inference.ts @@ -0,0 +1,5 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import { OnnxruntimeWebAssemblyBackend } from "./backend-wasm"; +export const wasmBackend = new OnnxruntimeWebAssemblyBackend(); diff --git a/js/web/lib/backend-wasm-with-training.ts b/js/web/lib/backend-wasm-training.ts similarity index 100% rename from js/web/lib/backend-wasm-with-training.ts rename to js/web/lib/backend-wasm-training.ts diff --git a/js/web/lib/backend-wasm.ts b/js/web/lib/backend-wasm.ts index 60468581968e3..5740263583031 100644 --- a/js/web/lib/backend-wasm.ts +++ b/js/web/lib/backend-wasm.ts @@ -51,5 +51,3 @@ export class OnnxruntimeWebAssemblyBackend implements Backend { return Promise.resolve(handler); } } - -export const wasmBackend = new OnnxruntimeWebAssemblyBackend(); diff --git a/js/web/lib/index.ts b/js/web/lib/index.ts index 30d0af5dc9e47..09f478dc2ddbe 100644 --- a/js/web/lib/index.ts +++ b/js/web/lib/index.ts @@ -16,14 +16,14 @@ if (!BUILD_DEFS.DISABLE_WEBGL) { } if (!BUILD_DEFS.DISABLE_WASM) { - const wasmBackend = BUILD_DEFS.ENABLE_TRAINING ? require('./backend-wasm-with-training').wasmBackend : - require('./backend-wasm').wasmBackend; + const wasmBackend = BUILD_DEFS.DISABLE_TRAINING ? require('./backend-wasm-inference').wasmBackend : + require('./backend-wasm-training').wasmBackend; if (!BUILD_DEFS.DISABLE_WEBGPU && typeof navigator !== 'undefined' && navigator.gpu) { registerBackend('webgpu', wasmBackend, 5); } registerBackend('cpu', wasmBackend, 10); registerBackend('wasm', wasmBackend, 10); - if (!BUILD_DEFS.ENABLE_TRAINING) { + if (BUILD_DEFS.DISABLE_TRAINING) { registerBackend('xnnpack', wasmBackend, 9); registerBackend('webnn', wasmBackend, 9); } diff --git a/js/web/lib/wasm/wasm-factory.ts b/js/web/lib/wasm/wasm-factory.ts index 87fa0a7771d3c..2b7d492cc70ba 100644 --- a/js/web/lib/wasm/wasm-factory.ts +++ b/js/web/lib/wasm/wasm-factory.ts @@ -10,7 +10,7 @@ import {OrtWasmThreadedModule} from './binding/ort-wasm-threaded'; /* eslint-disable @typescript-eslint/no-require-imports */ let ortWasmFactory: EmscriptenModuleFactory<OrtWasmModule>; -if (BUILD_DEFS.ENABLE_TRAINING) { +if (!BUILD_DEFS.DISABLE_TRAINING) { ortWasmFactory = require('./binding/ort-training-wasm-simd.js'); } else { ortWasmFactory = @@ -79,7 +79,7 @@ const isSimdSupported = (): boolean => { const getWasmFileName = (useSimd: boolean, useThreads: boolean) => { if (useSimd) { - if (BUILD_DEFS.ENABLE_TRAINING) { + if (!BUILD_DEFS.DISABLE_TRAINING) { return 'ort-training-wasm-simd.wasm'; } return useThreads ? 'ort-wasm-simd-threaded.wasm' : 'ort-wasm-simd.wasm'; From a277347111af09bc5b7623d42215a36732c4f7e9 Mon Sep 17 00:00:00 2001 From: carzh <wolfivyaura@gmail.com> Date: Fri, 6 Oct 2023 17:08:47 -0700 Subject: [PATCH 6/8] format + lint --- js/web/lib/backend-wasm-inference.ts | 2 +- js/web/lib/index.ts | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/js/web/lib/backend-wasm-inference.ts b/js/web/lib/backend-wasm-inference.ts index 059a30bd45a57..475a0243ebd3d 100644 --- a/js/web/lib/backend-wasm-inference.ts +++ b/js/web/lib/backend-wasm-inference.ts @@ -1,5 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import { OnnxruntimeWebAssemblyBackend } from "./backend-wasm"; +import {OnnxruntimeWebAssemblyBackend} from './backend-wasm'; export const wasmBackend = new OnnxruntimeWebAssemblyBackend(); diff --git a/js/web/lib/index.ts b/js/web/lib/index.ts index 09f478dc2ddbe..4b3d874534398 100644 --- a/js/web/lib/index.ts +++ b/js/web/lib/index.ts @@ -17,7 +17,7 @@ if (!BUILD_DEFS.DISABLE_WEBGL) { if (!BUILD_DEFS.DISABLE_WASM) { const wasmBackend = BUILD_DEFS.DISABLE_TRAINING ? require('./backend-wasm-inference').wasmBackend : - require('./backend-wasm-training').wasmBackend; + require('./backend-wasm-training').wasmBackend; if (!BUILD_DEFS.DISABLE_WEBGPU && typeof navigator !== 'undefined' && navigator.gpu) { registerBackend('webgpu', wasmBackend, 5); } From 5f717d7b3282f3a2517a94797a907fd37d4655a2 Mon Sep 17 00:00:00 2001 From: carzh <wolfivyaura@gmail.com> Date: Wed, 11 Oct 2023 14:49:01 -0700 Subject: [PATCH 7/8] added training module declaration to types.d.ts --- js/web/types.d.ts | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/js/web/types.d.ts b/js/web/types.d.ts index 2cb4578d99687..b9d12cf47b5c5 100644 --- a/js/web/types.d.ts +++ b/js/web/types.d.ts @@ -24,3 +24,7 @@ declare module 'onnxruntime-web/webgl' { declare module 'onnxruntime-web/webgpu' { export * from 'onnxruntime-web'; } + +declare module 'onnxruntime-web/training' { + export * from 'onnxruntime-web'; +} From 29e148c178db59df1537aebf0774071c7b4c6cb2 Mon Sep 17 00:00:00 2001 From: carzh <wolfivyaura@gmail.com> Date: Wed, 11 Oct 2023 16:26:08 -0700 Subject: [PATCH 8/8] fixed naming prefix for training in build script --- js/web/script/build.ts | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/js/web/script/build.ts b/js/web/script/build.ts index 314fd7f6babbc..5151f27582c1f 100644 --- a/js/web/script/build.ts +++ b/js/web/script/build.ts @@ -408,7 +408,7 @@ async function main() { }); // ort.wasm-core[.min].js await addAllWebBuildTasks({ - outputBundleName: 'ort.wasm-core.min', + outputBundleName: 'ort.wasm-core', define: { ...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_WEBGPU': 'true', @@ -417,9 +417,9 @@ async function main() { 'BUILD_DEFS.DISABLE_WASM_THREAD': 'true', }, }); - // ort.training.wasm.min.js + // ort.training.wasm[.min].js await addAllWebBuildTasks({ - outputBundleName: 'ort.training.wasm.min', + outputBundleName: 'ort.training.wasm', define: { ...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_TRAINING': 'false',