Skip to content

Commit

Permalink
Add "glue" between training WASM artifacts and training web (microsof…
Browse files Browse the repository at this point in the history
…t#17474)

### Description
* follows the packaging approach according to the design document
    * adds `ENABLE_TRAINING` boolean flag to `BUILD_DEFS`
    * modifies `package.json` to include training submodule
* modifies build script to handle, validate, and minimize training WASM
artifacts
* adds the binding for the new backend with training enabled & the new
training artifacts
    * adds training backend
    * edits `index.ts` to use training backend depending on `BUILD_DEFS`
    * edits `wasm-factory.ts` to use the training artifacts if necessary

### Motivation and Context
* we are in the process of adding web bindings to enable training. 
* Adding the "glue" to allow onnxruntime-web to use the training WASM
artifacts is required for this work.
* Since BUILD_DEFS is defined and used at build time, I thought that it
made sense to bundle the changes to building in the same PR.
#### Related work
* microsoft#16521 allowed for training artifacts to be built
* microsoft#17333 must be merged in before this one

---------

Co-authored-by: Yulong Wang <[email protected]>
  • Loading branch information
2 people authored and kleiti committed Mar 22, 2024
1 parent 254b87c commit c5aac0e
Show file tree
Hide file tree
Showing 9 changed files with 77 additions and 12 deletions.
5 changes: 5 additions & 0 deletions js/web/lib/backend-wasm-inference.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {OnnxruntimeWebAssemblyBackend} from './backend-wasm';
export const wasmBackend = new OnnxruntimeWebAssemblyBackend();
17 changes: 17 additions & 0 deletions js/web/lib/backend-wasm-training.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {InferenceSession, TrainingSessionHandler} from 'onnxruntime-common';

import {OnnxruntimeWebAssemblyBackend} from './backend-wasm';

class OnnxruntimeTrainingWebAssemblyBackend extends OnnxruntimeWebAssemblyBackend {
async createTrainingSessionHandler(
_checkpointStateUriOrBuffer: string|Uint8Array, _trainModelUriOrBuffer: string|Uint8Array,
_evalModelUriOrBuffer: string|Uint8Array, _optimizerModelUriOrBuffer: string|Uint8Array,
_options: InferenceSession.SessionOptions): Promise<TrainingSessionHandler> {
throw new Error('Method not implemented yet.');
}
}

export const wasmBackend = new OnnxruntimeTrainingWebAssemblyBackend();
4 changes: 1 addition & 3 deletions js/web/lib/backend-wasm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ export const initializeFlags = (): void => {
}
};

class OnnxruntimeWebAssemblyBackend implements Backend {
export class OnnxruntimeWebAssemblyBackend implements Backend {
async init(): Promise<void> {
// populate wasm flags
initializeFlags();
Expand All @@ -51,5 +51,3 @@ class OnnxruntimeWebAssemblyBackend implements Backend {
return Promise.resolve(handler);
}
}

export const wasmBackend = new OnnxruntimeWebAssemblyBackend();
4 changes: 4 additions & 0 deletions js/web/lib/build-def.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ interface BuildDefinitions {
* defines whether to disable multi-threading feature in WebAssembly backend in the build.
*/
readonly DISABLE_WASM_THREAD: boolean;
/**
* defines whether to disable training APIs in WebAssembly backend.
*/
readonly DISABLE_TRAINING: boolean;
}

declare const BUILD_DEFS: BuildDefinitions;
9 changes: 6 additions & 3 deletions js/web/lib/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,17 @@ if (!BUILD_DEFS.DISABLE_WEBGL) {
}

if (!BUILD_DEFS.DISABLE_WASM) {
const wasmBackend = require('./backend-wasm').wasmBackend;
const wasmBackend = BUILD_DEFS.DISABLE_TRAINING ? require('./backend-wasm-inference').wasmBackend :
require('./backend-wasm-training').wasmBackend;
if (!BUILD_DEFS.DISABLE_WEBGPU && typeof navigator !== 'undefined' && navigator.gpu) {
registerBackend('webgpu', wasmBackend, 5);
}
registerBackend('cpu', wasmBackend, 10);
registerBackend('wasm', wasmBackend, 10);
registerBackend('xnnpack', wasmBackend, 9);
registerBackend('webnn', wasmBackend, 9);
if (BUILD_DEFS.DISABLE_TRAINING) {
registerBackend('xnnpack', wasmBackend, 9);
registerBackend('webnn', wasmBackend, 9);
}
}

Object.defineProperty(env.versions, 'web', {value: version, enumerable: true});
19 changes: 14 additions & 5 deletions js/web/lib/wasm/wasm-factory.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,14 @@ import {OrtWasmModule} from './binding/ort-wasm';
import {OrtWasmThreadedModule} from './binding/ort-wasm-threaded';

/* eslint-disable @typescript-eslint/no-require-imports */
const ortWasmFactory: EmscriptenModuleFactory<OrtWasmModule> =
BUILD_DEFS.DISABLE_WEBGPU ? require('./binding/ort-wasm.js') : require('./binding/ort-wasm-simd.jsep.js');
let ortWasmFactory: EmscriptenModuleFactory<OrtWasmModule>;

if (!BUILD_DEFS.DISABLE_TRAINING) {
ortWasmFactory = require('./binding/ort-training-wasm-simd.js');
} else {
ortWasmFactory =
BUILD_DEFS.DISABLE_WEBGPU ? require('./binding/ort-wasm.js') : require('./binding/ort-wasm-simd.jsep.js');
}

const ortWasmFactoryThreaded: EmscriptenModuleFactory<OrtWasmModule> = !BUILD_DEFS.DISABLE_WASM_THREAD ?
(BUILD_DEFS.DISABLE_WEBGPU ? require('./binding/ort-wasm-threaded.js') :
Expand Down Expand Up @@ -72,10 +78,13 @@ const isSimdSupported = (): boolean => {
};

const getWasmFileName = (useSimd: boolean, useThreads: boolean) => {
if (useThreads) {
return useSimd ? 'ort-wasm-simd-threaded.wasm' : 'ort-wasm-threaded.wasm';
if (useSimd) {
if (!BUILD_DEFS.DISABLE_TRAINING) {
return 'ort-training-wasm-simd.wasm';
}
return useThreads ? 'ort-wasm-simd-threaded.wasm' : 'ort-wasm-simd.wasm';
} else {
return useSimd ? 'ort-wasm-simd.wasm' : 'ort-wasm.wasm';
return useThreads ? 'ort-wasm-threaded.wasm' : 'ort-wasm.wasm';
}
};

Expand Down
14 changes: 14 additions & 0 deletions js/web/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,20 @@
"development": "./dist/ort.webgpu.js",
"default": "./dist/ort.webgpu.min.js"
}
},
"./training": {
"import": {
"development": "./dist/esm/ort.training.wasm.js",
"default": "./dist/esm/ort.training.wasm.min.js"
},
"require": {
"development": "./dist/cjs/ort.training.wasm.js",
"default": "./dist/cjs/ort.training.wasm.min.js"
},
"default": {
"development": "./dist/ort.training.wasm.js",
"default": "./dist/ort.training.wasm.min.js"
}
}
},
"types": "./types.d.ts",
Expand Down
13 changes: 12 additions & 1 deletion js/web/script/build.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ const DEFAULT_DEFINE = {
'BUILD_DEFS.DISABLE_WASM': 'false',
'BUILD_DEFS.DISABLE_WASM_PROXY': 'false',
'BUILD_DEFS.DISABLE_WASM_THREAD': 'false',
'BUILD_DEFS.DISABLE_TRAINING': 'true',
};

const COPYRIGHT_HEADER = `/*!
Expand Down Expand Up @@ -407,7 +408,7 @@ async function main() {
});
// ort.wasm-core[.min].js
await addAllWebBuildTasks({
outputBundleName: 'ort.wasm-core.min',
outputBundleName: 'ort.wasm-core',
define: {
...DEFAULT_DEFINE,
'BUILD_DEFS.DISABLE_WEBGPU': 'true',
Expand All @@ -416,6 +417,16 @@ async function main() {
'BUILD_DEFS.DISABLE_WASM_THREAD': 'true',
},
});
// ort.training.wasm[.min].js
await addAllWebBuildTasks({
outputBundleName: 'ort.training.wasm',
define: {
...DEFAULT_DEFINE,
'BUILD_DEFS.DISABLE_TRAINING': 'false',
'BUILD_DEFS.DISABLE_WEBGPU': 'true',
'BUILD_DEFS.DISABLE_WEBGL': 'true',
},
});
}

if (BUNDLE_MODE === 'dev' || BUNDLE_MODE === 'perf') {
Expand Down
4 changes: 4 additions & 0 deletions js/web/types.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,7 @@ declare module 'onnxruntime-web/webgl' {
declare module 'onnxruntime-web/webgpu' {
export * from 'onnxruntime-web';
}

declare module 'onnxruntime-web/training' {
export * from 'onnxruntime-web';
}

0 comments on commit c5aac0e

Please sign in to comment.