Skip to content

Commit

Permalink
edited wasm factory + added training backend
Browse files Browse the repository at this point in the history
  • Loading branch information
carzh committed Oct 2, 2023
1 parent cff6a7f commit 7a1365d
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 18 deletions.
17 changes: 17 additions & 0 deletions js/web/lib/backend-wasm-with-training.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {InferenceSession, TrainingSessionHandler} from 'onnxruntime-common';

import {OnnxruntimeWebAssemblyBackend} from './backend-wasm'

class OnnxruntimeTrainingWebAssemblyBackend extends OnnxruntimeWebAssemblyBackend {
async createTrainingSessionHandler(
checkpointStateUriOrBuffer: string|Uint8Array, trainModelUriOrBuffer: string|Uint8Array,
evalModelUriOrBuffer: string|Uint8Array, optimizerModelUriOrBuffer: string|Uint8Array,
options: InferenceSession.SessionOptions): Promise<TrainingSessionHandler> {
throw new Error('Method not implemented yet.');
}
}

export const wasmBackend = new OnnxruntimeTrainingWebAssemblyBackend();
2 changes: 1 addition & 1 deletion js/web/lib/backend-wasm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ export const initializeFlags = (): void => {
}
};

class OnnxruntimeWebAssemblyBackend implements Backend {
export class OnnxruntimeWebAssemblyBackend implements Backend {
async init(): Promise<void> {
// populate wasm flags
initializeFlags();
Expand Down
11 changes: 7 additions & 4 deletions js/web/lib/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
// So we import code inside the if-clause to allow terser remove the code safely.

export * from 'onnxruntime-common';
import {registerBackend, env} from 'onnxruntime-common';
import {registerBackend, env, Backend} from 'onnxruntime-common';
import {version} from './version';

if (!BUILD_DEFS.DISABLE_WEBGL) {
Expand All @@ -16,14 +16,17 @@ if (!BUILD_DEFS.DISABLE_WEBGL) {
}

if (!BUILD_DEFS.DISABLE_WASM) {
const wasmBackend = require('./backend-wasm').wasmBackend;
const wasmBackend = BUILD_DEFS.ENABLE_TRAINING ? require('./backend-wasm-with-training').wasmBackend :
require('./backend-wasm').wasmBackend;
if (!BUILD_DEFS.DISABLE_WEBGPU && typeof navigator !== 'undefined' && navigator.gpu) {
registerBackend('webgpu', wasmBackend, 5);
}
registerBackend('cpu', wasmBackend, 10);
registerBackend('wasm', wasmBackend, 10);
registerBackend('xnnpack', wasmBackend, 9);
registerBackend('webnn', wasmBackend, 9);
if (!BUILD_DEFS.ENABLE_TRAINING) {
registerBackend('xnnpack', wasmBackend, 9);
registerBackend('webnn', wasmBackend, 9);
}
}

Object.defineProperty(env.versions, 'web', {value: version, enumerable: true});
22 changes: 9 additions & 13 deletions js/web/lib/wasm/wasm-factory.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ let ortWasmFactory: EmscriptenModuleFactory<OrtWasmModule>;

if (BUILD_DEFS.ENABLE_TRAINING) {
ortWasmFactory = require('./binding/ort-training-wasm-simd.js');
}
else {
ortWasmFactory = BUILD_DEFS.DISABLE_WEBGPU ? require('./binding/ort-wasm.js') : require('./binding/ort-wasm-simd.jsep.js');
} else {
ortWasmFactory =
BUILD_DEFS.DISABLE_WEBGPU ? require('./binding/ort-wasm.js') : require('./binding/ort-wasm-simd.jsep.js');
}

const ortWasmFactoryThreaded: EmscriptenModuleFactory<OrtWasmModule> = !BUILD_DEFS.DISABLE_WASM_THREAD ?
Expand Down Expand Up @@ -78,18 +78,14 @@ const isSimdSupported = (): boolean => {
};

const getWasmFileName = (useSimd: boolean, useThreads: boolean, useTraining: boolean) => {
let wasmArtifact : string = 'ort';
if (useTraining) {
wasmArtifact += '-training';
}
wasmArtifact += '-wasm';
if (useSimd) {
wasmArtifact += '-simd';
}
if (useThreads) {
wasmArtifact += '-threaded';
if (useTraining) {
return 'ort-training-wasm-simd.wasm';
}
return useThreads ? 'ort-wasm-simd-threaded.wasm' : 'ort-wasm-simd.wasm';
} else {
return useThreads ? 'ort-wasm-threaded.wasm' : 'ort-wasm.wasm';
}
return wasmArtifact + '.wasm';
};

export const initializeWebAssembly = async(flags: Env.WebAssemblyFlags): Promise<void> => {
Expand Down

0 comments on commit 7a1365d

Please sign in to comment.