diff --git a/js/web/lib/backend-wasm-inference.ts b/js/web/lib/backend-wasm-inference.ts new file mode 100644 index 0000000000000..475a0243ebd3d --- /dev/null +++ b/js/web/lib/backend-wasm-inference.ts @@ -0,0 +1,5 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {OnnxruntimeWebAssemblyBackend} from './backend-wasm'; +export const wasmBackend = new OnnxruntimeWebAssemblyBackend(); diff --git a/js/web/lib/backend-wasm-training.ts b/js/web/lib/backend-wasm-training.ts new file mode 100644 index 0000000000000..af5b575c87a7f --- /dev/null +++ b/js/web/lib/backend-wasm-training.ts @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {InferenceSession, TrainingSessionHandler} from 'onnxruntime-common'; + +import {OnnxruntimeWebAssemblyBackend} from './backend-wasm'; + +class OnnxruntimeTrainingWebAssemblyBackend extends OnnxruntimeWebAssemblyBackend { + async createTrainingSessionHandler( + _checkpointStateUriOrBuffer: string|Uint8Array, _trainModelUriOrBuffer: string|Uint8Array, + _evalModelUriOrBuffer: string|Uint8Array, _optimizerModelUriOrBuffer: string|Uint8Array, + _options: InferenceSession.SessionOptions): Promise { + throw new Error('Method not implemented yet.'); + } +} + +export const wasmBackend = new OnnxruntimeTrainingWebAssemblyBackend(); diff --git a/js/web/lib/backend-wasm.ts b/js/web/lib/backend-wasm.ts index 29649a1645e9c..5740263583031 100644 --- a/js/web/lib/backend-wasm.ts +++ b/js/web/lib/backend-wasm.ts @@ -32,7 +32,7 @@ export const initializeFlags = (): void => { } }; -class OnnxruntimeWebAssemblyBackend implements Backend { +export class OnnxruntimeWebAssemblyBackend implements Backend { async init(): Promise { // populate wasm flags initializeFlags(); @@ -51,5 +51,3 @@ class OnnxruntimeWebAssemblyBackend implements Backend { return Promise.resolve(handler); } } - -export const wasmBackend = new OnnxruntimeWebAssemblyBackend(); diff --git a/js/web/lib/build-def.d.ts b/js/web/lib/build-def.d.ts index 928be2004acf6..8b14b57acc062 100644 --- a/js/web/lib/build-def.d.ts +++ b/js/web/lib/build-def.d.ts @@ -30,6 +30,10 @@ interface BuildDefinitions { * defines whether to disable multi-threading feature in WebAssembly backend in the build. */ readonly DISABLE_WASM_THREAD: boolean; + /** + * defines whether to disable training APIs in WebAssembly backend. + */ + readonly DISABLE_TRAINING: boolean; } declare const BUILD_DEFS: BuildDefinitions; diff --git a/js/web/lib/index.ts b/js/web/lib/index.ts index d8bda10c7a0c7..c5c27a4318049 100644 --- a/js/web/lib/index.ts +++ b/js/web/lib/index.ts @@ -16,14 +16,17 @@ if (!BUILD_DEFS.DISABLE_WEBGL) { } if (!BUILD_DEFS.DISABLE_WASM) { - const wasmBackend = require('./backend-wasm').wasmBackend; + const wasmBackend = BUILD_DEFS.DISABLE_TRAINING ? require('./backend-wasm-inference').wasmBackend : + require('./backend-wasm-training').wasmBackend; if (!BUILD_DEFS.DISABLE_WEBGPU && typeof navigator !== 'undefined' && navigator.gpu) { registerBackend('webgpu', wasmBackend, 5); } registerBackend('cpu', wasmBackend, 10); registerBackend('wasm', wasmBackend, 10); - registerBackend('xnnpack', wasmBackend, 9); - registerBackend('webnn', wasmBackend, 9); + if (BUILD_DEFS.DISABLE_TRAINING) { + registerBackend('xnnpack', wasmBackend, 9); + registerBackend('webnn', wasmBackend, 9); + } } Object.defineProperty(env.versions, 'web', {value: version, enumerable: true}); diff --git a/js/web/lib/wasm/wasm-factory.ts b/js/web/lib/wasm/wasm-factory.ts index d953361fe715f..2b7d492cc70ba 100644 --- a/js/web/lib/wasm/wasm-factory.ts +++ b/js/web/lib/wasm/wasm-factory.ts @@ -8,8 +8,14 @@ import {OrtWasmModule} from './binding/ort-wasm'; import {OrtWasmThreadedModule} from './binding/ort-wasm-threaded'; /* eslint-disable @typescript-eslint/no-require-imports */ -const ortWasmFactory: EmscriptenModuleFactory = - BUILD_DEFS.DISABLE_WEBGPU ? require('./binding/ort-wasm.js') : require('./binding/ort-wasm-simd.jsep.js'); +let ortWasmFactory: EmscriptenModuleFactory; + +if (!BUILD_DEFS.DISABLE_TRAINING) { + ortWasmFactory = require('./binding/ort-training-wasm-simd.js'); +} else { + ortWasmFactory = + BUILD_DEFS.DISABLE_WEBGPU ? require('./binding/ort-wasm.js') : require('./binding/ort-wasm-simd.jsep.js'); +} const ortWasmFactoryThreaded: EmscriptenModuleFactory = !BUILD_DEFS.DISABLE_WASM_THREAD ? (BUILD_DEFS.DISABLE_WEBGPU ? require('./binding/ort-wasm-threaded.js') : @@ -72,10 +78,13 @@ const isSimdSupported = (): boolean => { }; const getWasmFileName = (useSimd: boolean, useThreads: boolean) => { - if (useThreads) { - return useSimd ? 'ort-wasm-simd-threaded.wasm' : 'ort-wasm-threaded.wasm'; + if (useSimd) { + if (!BUILD_DEFS.DISABLE_TRAINING) { + return 'ort-training-wasm-simd.wasm'; + } + return useThreads ? 'ort-wasm-simd-threaded.wasm' : 'ort-wasm-simd.wasm'; } else { - return useSimd ? 'ort-wasm-simd.wasm' : 'ort-wasm.wasm'; + return useThreads ? 'ort-wasm-threaded.wasm' : 'ort-wasm.wasm'; } }; diff --git a/js/web/package.json b/js/web/package.json index 997860055ff50..0bdc529df704d 100644 --- a/js/web/package.json +++ b/js/web/package.json @@ -151,6 +151,20 @@ "development": "./dist/ort.webgpu.js", "default": "./dist/ort.webgpu.min.js" } + }, + "./training": { + "import": { + "development": "./dist/esm/ort.training.wasm.js", + "default": "./dist/esm/ort.training.wasm.min.js" + }, + "require": { + "development": "./dist/cjs/ort.training.wasm.js", + "default": "./dist/cjs/ort.training.wasm.min.js" + }, + "default": { + "development": "./dist/ort.training.wasm.js", + "default": "./dist/ort.training.wasm.min.js" + } } }, "types": "./types.d.ts", diff --git a/js/web/script/build.ts b/js/web/script/build.ts index 9743cbe8ec326..5151f27582c1f 100644 --- a/js/web/script/build.ts +++ b/js/web/script/build.ts @@ -47,6 +47,7 @@ const DEFAULT_DEFINE = { 'BUILD_DEFS.DISABLE_WASM': 'false', 'BUILD_DEFS.DISABLE_WASM_PROXY': 'false', 'BUILD_DEFS.DISABLE_WASM_THREAD': 'false', + 'BUILD_DEFS.DISABLE_TRAINING': 'true', }; const COPYRIGHT_HEADER = `/*! @@ -407,7 +408,7 @@ async function main() { }); // ort.wasm-core[.min].js await addAllWebBuildTasks({ - outputBundleName: 'ort.wasm-core.min', + outputBundleName: 'ort.wasm-core', define: { ...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_WEBGPU': 'true', @@ -416,6 +417,16 @@ async function main() { 'BUILD_DEFS.DISABLE_WASM_THREAD': 'true', }, }); + // ort.training.wasm[.min].js + await addAllWebBuildTasks({ + outputBundleName: 'ort.training.wasm', + define: { + ...DEFAULT_DEFINE, + 'BUILD_DEFS.DISABLE_TRAINING': 'false', + 'BUILD_DEFS.DISABLE_WEBGPU': 'true', + 'BUILD_DEFS.DISABLE_WEBGL': 'true', + }, + }); } if (BUNDLE_MODE === 'dev' || BUNDLE_MODE === 'perf') { diff --git a/js/web/types.d.ts b/js/web/types.d.ts index 2cb4578d99687..b9d12cf47b5c5 100644 --- a/js/web/types.d.ts +++ b/js/web/types.d.ts @@ -24,3 +24,7 @@ declare module 'onnxruntime-web/webgl' { declare module 'onnxruntime-web/webgpu' { export * from 'onnxruntime-web'; } + +declare module 'onnxruntime-web/training' { + export * from 'onnxruntime-web'; +} diff --git a/onnxruntime/core/providers/dml/dml_provider_factory.cc b/onnxruntime/core/providers/dml/dml_provider_factory.cc index df94326e88045..cd8bc8fe909dc 100644 --- a/onnxruntime/core/providers/dml/dml_provider_factory.cc +++ b/onnxruntime/core/providers/dml/dml_provider_factory.cc @@ -514,7 +514,7 @@ ORT_API_STATUS_IMPL(FreeGPUAllocation, _In_ void* ptr) { API_IMPL_END } -ORT_API_STATUS_IMPL(SessionOptionsAppendExecutionProvider_DML2, _In_ OrtSessionOptions* options, OrtDmlDeviceOptions* device_options) { +ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_DML2, _In_ OrtSessionOptions* options, OrtDmlDeviceOptions* device_options) { API_IMPL_BEGIN #ifdef USE_DML auto factory = onnxruntime::DMLProviderFactoryCreator::CreateFromOptions(device_options); @@ -547,7 +547,8 @@ static constexpr OrtDmlApi ort_dml_api_10_to_x = { &OrtSessionOptionsAppendExecutionProviderEx_DML, &CreateGPUAllocationFromD3DResource, &FreeGPUAllocation, - &GetD3D12ResourceFromAllocation + &GetD3D12ResourceFromAllocation, + &OrtSessionOptionsAppendExecutionProvider_DML2, }; const OrtDmlApi* GetOrtDmlApi(_In_ uint32_t /*version*/) NO_EXCEPTION { diff --git a/onnxruntime/test/onnx/TestCase.cc b/onnxruntime/test/onnx/TestCase.cc index 7239e5242543d..cb9633ff049a5 100644 --- a/onnxruntime/test/onnx/TestCase.cc +++ b/onnxruntime/test/onnx/TestCase.cc @@ -1230,6 +1230,7 @@ std::unique_ptr> GetBrokenTests(const std::string& provider broken_tests->insert({"candy", "Temporarily disabled pending investigation"}); broken_tests->insert({"BERT_Squad", "Temporarily disabled pending investigation"}); broken_tests->insert({"LSTM_Seq_lens_unpacked", "The parameter is incorrect"}); + broken_tests->insert({"mlperf_ssd_resnet34_1200", "The parameter is incorrect"}); broken_tests->insert({"resize_downsample_scales_linear", "DML uses half_pixel and this test assumed \"asymmetric\" but does not include \"mode\""});