diff --git a/js/web/lib/backend-wasm-with-training.ts b/js/web/lib/backend-wasm-with-training.ts new file mode 100644 index 0000000000000..6d31861fb1bcd --- /dev/null +++ b/js/web/lib/backend-wasm-with-training.ts @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {InferenceSession, TrainingSessionHandler} from 'onnxruntime-common'; + +import {OnnxruntimeWebAssemblyBackend} from './backend-wasm' + +class OnnxruntimeTrainingWebAssemblyBackend extends OnnxruntimeWebAssemblyBackend { + async createTrainingSessionHandler( + checkpointStateUriOrBuffer: string|Uint8Array, trainModelUriOrBuffer: string|Uint8Array, + evalModelUriOrBuffer: string|Uint8Array, optimizerModelUriOrBuffer: string|Uint8Array, + options: InferenceSession.SessionOptions): Promise { + throw new Error('Method not implemented yet.'); + } +} + +export const wasmBackend = new OnnxruntimeTrainingWebAssemblyBackend(); diff --git a/js/web/lib/backend-wasm.ts b/js/web/lib/backend-wasm.ts index 04108c2ad0f66..3f06eea5329a0 100644 --- a/js/web/lib/backend-wasm.ts +++ b/js/web/lib/backend-wasm.ts @@ -32,7 +32,7 @@ export const initializeFlags = (): void => { } }; -class OnnxruntimeWebAssemblyBackend implements Backend { +export class OnnxruntimeWebAssemblyBackend implements Backend { async init(): Promise { // populate wasm flags initializeFlags(); diff --git a/js/web/lib/index.ts b/js/web/lib/index.ts index d5ed536034f3e..03e1108b31d12 100644 --- a/js/web/lib/index.ts +++ b/js/web/lib/index.ts @@ -7,7 +7,7 @@ // So we import code inside the if-clause to allow terser remove the code safely. export * from 'onnxruntime-common'; -import {registerBackend, env} from 'onnxruntime-common'; +import {registerBackend, env, Backend} from 'onnxruntime-common'; import {version} from './version'; if (!BUILD_DEFS.DISABLE_WEBGL) { @@ -16,14 +16,17 @@ if (!BUILD_DEFS.DISABLE_WEBGL) { } if (!BUILD_DEFS.DISABLE_WASM) { - const wasmBackend = require('./backend-wasm').wasmBackend; + const wasmBackend = BUILD_DEFS.ENABLE_TRAINING ? require('./backend-wasm-with-training').wasmBackend : + require('./backend-wasm').wasmBackend; if (!BUILD_DEFS.DISABLE_WEBGPU && typeof navigator !== 'undefined' && navigator.gpu) { registerBackend('webgpu', wasmBackend, 5); } registerBackend('cpu', wasmBackend, 10); registerBackend('wasm', wasmBackend, 10); - registerBackend('xnnpack', wasmBackend, 9); - registerBackend('webnn', wasmBackend, 9); + if (!BUILD_DEFS.ENABLE_TRAINING) { + registerBackend('xnnpack', wasmBackend, 9); + registerBackend('webnn', wasmBackend, 9); + } } Object.defineProperty(env.versions, 'web', {value: version, enumerable: true}); diff --git a/js/web/lib/wasm/wasm-factory.ts b/js/web/lib/wasm/wasm-factory.ts index 4ccb06e281b9d..1ce06b0f63b88 100644 --- a/js/web/lib/wasm/wasm-factory.ts +++ b/js/web/lib/wasm/wasm-factory.ts @@ -12,9 +12,9 @@ let ortWasmFactory: EmscriptenModuleFactory; if (BUILD_DEFS.ENABLE_TRAINING) { ortWasmFactory = require('./binding/ort-training-wasm-simd.js'); -} -else { - ortWasmFactory = BUILD_DEFS.DISABLE_WEBGPU ? require('./binding/ort-wasm.js') : require('./binding/ort-wasm-simd.jsep.js'); +} else { + ortWasmFactory = + BUILD_DEFS.DISABLE_WEBGPU ? require('./binding/ort-wasm.js') : require('./binding/ort-wasm-simd.jsep.js'); } const ortWasmFactoryThreaded: EmscriptenModuleFactory = !BUILD_DEFS.DISABLE_WASM_THREAD ? @@ -78,18 +78,14 @@ const isSimdSupported = (): boolean => { }; const getWasmFileName = (useSimd: boolean, useThreads: boolean, useTraining: boolean) => { - let wasmArtifact : string = 'ort'; - if (useTraining) { - wasmArtifact += '-training'; - } - wasmArtifact += '-wasm'; if (useSimd) { - wasmArtifact += '-simd'; - } - if (useThreads) { - wasmArtifact += '-threaded'; + if (useTraining) { + return 'ort-training-wasm-simd.wasm'; + } + return useThreads ? 'ort-wasm-simd-threaded.wasm' : 'ort-wasm-simd.wasm'; + } else { + return useThreads ? 'ort-wasm-threaded.wasm' : 'ort-wasm.wasm'; } - return wasmArtifact + '.wasm'; }; export const initializeWebAssembly = async(flags: Env.WebAssemblyFlags): Promise => {