Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add "glue" between training WASM artifacts and training web #17474

Merged
merged 10 commits into from
Oct 12, 2023
5 changes: 5 additions & 0 deletions js/web/lib/backend-wasm-inference.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {OnnxruntimeWebAssemblyBackend} from './backend-wasm';
export const wasmBackend = new OnnxruntimeWebAssemblyBackend();
17 changes: 17 additions & 0 deletions js/web/lib/backend-wasm-training.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {InferenceSession, TrainingSessionHandler} from 'onnxruntime-common';

import {OnnxruntimeWebAssemblyBackend} from './backend-wasm';

class OnnxruntimeTrainingWebAssemblyBackend extends OnnxruntimeWebAssemblyBackend {
async createTrainingSessionHandler(
_checkpointStateUriOrBuffer: string|Uint8Array, _trainModelUriOrBuffer: string|Uint8Array,
_evalModelUriOrBuffer: string|Uint8Array, _optimizerModelUriOrBuffer: string|Uint8Array,
_options: InferenceSession.SessionOptions): Promise<TrainingSessionHandler> {
throw new Error('Method not implemented yet.');
}
}

export const wasmBackend = new OnnxruntimeTrainingWebAssemblyBackend();
4 changes: 1 addition & 3 deletions js/web/lib/backend-wasm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ export const initializeFlags = (): void => {
}
};

class OnnxruntimeWebAssemblyBackend implements Backend {
export class OnnxruntimeWebAssemblyBackend implements Backend {
async init(): Promise<void> {
// populate wasm flags
initializeFlags();
Expand All @@ -51,5 +51,3 @@ class OnnxruntimeWebAssemblyBackend implements Backend {
return Promise.resolve(handler);
}
}

export const wasmBackend = new OnnxruntimeWebAssemblyBackend();
4 changes: 4 additions & 0 deletions js/web/lib/build-def.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ interface BuildDefinitions {
* defines whether to disable multi-threading feature in WebAssembly backend in the build.
*/
readonly DISABLE_WASM_THREAD: boolean;
/**
* defines whether to disable training APIs in WebAssembly backend.
*/
readonly DISABLE_TRAINING: boolean;
}

declare const BUILD_DEFS: BuildDefinitions;
9 changes: 6 additions & 3 deletions js/web/lib/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,17 @@ if (!BUILD_DEFS.DISABLE_WEBGL) {
}

if (!BUILD_DEFS.DISABLE_WASM) {
const wasmBackend = require('./backend-wasm').wasmBackend;
const wasmBackend = BUILD_DEFS.DISABLE_TRAINING ? require('./backend-wasm-inference').wasmBackend :
require('./backend-wasm-training').wasmBackend;
if (!BUILD_DEFS.DISABLE_WEBGPU && typeof navigator !== 'undefined' && navigator.gpu) {
registerBackend('webgpu', wasmBackend, 5);
}
registerBackend('cpu', wasmBackend, 10);
registerBackend('wasm', wasmBackend, 10);
registerBackend('xnnpack', wasmBackend, 9);
registerBackend('webnn', wasmBackend, 9);
if (BUILD_DEFS.DISABLE_TRAINING) {
registerBackend('xnnpack', wasmBackend, 9);
registerBackend('webnn', wasmBackend, 9);
}
}

Object.defineProperty(env.versions, 'web', {value: version, enumerable: true});
19 changes: 14 additions & 5 deletions js/web/lib/wasm/wasm-factory.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,14 @@ import {OrtWasmModule} from './binding/ort-wasm';
import {OrtWasmThreadedModule} from './binding/ort-wasm-threaded';

/* eslint-disable @typescript-eslint/no-require-imports */
const ortWasmFactory: EmscriptenModuleFactory<OrtWasmModule> =
BUILD_DEFS.DISABLE_WEBGPU ? require('./binding/ort-wasm.js') : require('./binding/ort-wasm-simd.jsep.js');
let ortWasmFactory: EmscriptenModuleFactory<OrtWasmModule>;

if (!BUILD_DEFS.DISABLE_TRAINING) {
ortWasmFactory = require('./binding/ort-training-wasm-simd.js');
} else {
ortWasmFactory =
BUILD_DEFS.DISABLE_WEBGPU ? require('./binding/ort-wasm.js') : require('./binding/ort-wasm-simd.jsep.js');
}

const ortWasmFactoryThreaded: EmscriptenModuleFactory<OrtWasmModule> = !BUILD_DEFS.DISABLE_WASM_THREAD ?
(BUILD_DEFS.DISABLE_WEBGPU ? require('./binding/ort-wasm-threaded.js') :
Expand Down Expand Up @@ -72,10 +78,13 @@ const isSimdSupported = (): boolean => {
};

const getWasmFileName = (useSimd: boolean, useThreads: boolean) => {
if (useThreads) {
return useSimd ? 'ort-wasm-simd-threaded.wasm' : 'ort-wasm-threaded.wasm';
if (useSimd) {
if (!BUILD_DEFS.DISABLE_TRAINING) {
return 'ort-training-wasm-simd.wasm';
}
return useThreads ? 'ort-wasm-simd-threaded.wasm' : 'ort-wasm-simd.wasm';
} else {
return useSimd ? 'ort-wasm-simd.wasm' : 'ort-wasm.wasm';
return useThreads ? 'ort-wasm-threaded.wasm' : 'ort-wasm.wasm';
}
};

Expand Down
14 changes: 14 additions & 0 deletions js/web/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,20 @@
"development": "./dist/ort.webgpu.js",
"default": "./dist/ort.webgpu.min.js"
}
},
"./training": {
"import": {
"development": "./dist/esm/ort.training.wasm.js",
"default": "./dist/esm/ort.training.wasm.min.js"
},
"require": {
"development": "./dist/cjs/ort.training.wasm.js",
"default": "./dist/cjs/ort.training.wasm.min.js"
},
"default": {
"development": "./dist/ort.training.wasm.js",
"default": "./dist/ort.training.wasm.min.js"
}
}
},
"types": "./types.d.ts",
Expand Down
13 changes: 12 additions & 1 deletion js/web/script/build.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ const DEFAULT_DEFINE = {
'BUILD_DEFS.DISABLE_WASM': 'false',
'BUILD_DEFS.DISABLE_WASM_PROXY': 'false',
'BUILD_DEFS.DISABLE_WASM_THREAD': 'false',
'BUILD_DEFS.DISABLE_TRAINING': 'true',
};

const COPYRIGHT_HEADER = `/*!
Expand Down Expand Up @@ -407,7 +408,7 @@ async function main() {
});
// ort.wasm-core[.min].js
await addAllWebBuildTasks({
outputBundleName: 'ort.wasm-core.min',
outputBundleName: 'ort.wasm-core',
define: {
...DEFAULT_DEFINE,
'BUILD_DEFS.DISABLE_WEBGPU': 'true',
Expand All @@ -416,6 +417,16 @@ async function main() {
'BUILD_DEFS.DISABLE_WASM_THREAD': 'true',
},
});
// ort.training.wasm[.min].js
await addAllWebBuildTasks({
outputBundleName: 'ort.training.wasm',
define: {
...DEFAULT_DEFINE,
'BUILD_DEFS.DISABLE_TRAINING': 'false',
'BUILD_DEFS.DISABLE_WEBGPU': 'true',
'BUILD_DEFS.DISABLE_WEBGL': 'true',
},
});
}

if (BUNDLE_MODE === 'dev' || BUNDLE_MODE === 'perf') {
Expand Down
4 changes: 4 additions & 0 deletions js/web/types.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,7 @@ declare module 'onnxruntime-web/webgl' {
declare module 'onnxruntime-web/webgpu' {
export * from 'onnxruntime-web';
}

declare module 'onnxruntime-web/training' {
export * from 'onnxruntime-web';
}