Skip to content

Commit

Permalink
applied suggestions
Browse files Browse the repository at this point in the history
  • Loading branch information
carzh committed Sep 22, 2023
1 parent d3c8f7e commit 9885866
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 8 deletions.
8 changes: 3 additions & 5 deletions js/web/lib/backend-wasm-with-training.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,12 @@ import {InferenceSession, TrainingSessionHandler} from 'onnxruntime-common';
import {OnnxruntimeWebAssemblyBackend} from './backend-wasm';

class OnnxruntimeTrainingWebAssemblyBackend extends OnnxruntimeWebAssemblyBackend {
/* eslint-disable @typescript-eslint/no-unused-vars */
async createTrainingSessionHandler(
checkpointStateUriOrBuffer: string|Uint8Array, trainModelUriOrBuffer: string|Uint8Array,
evalModelUriOrBuffer: string|Uint8Array, optimizerModelUriOrBuffer: string|Uint8Array,
options: InferenceSession.SessionOptions): Promise<TrainingSessionHandler> {
_checkpointStateUriOrBuffer: string|Uint8Array, _trainModelUriOrBuffer: string|Uint8Array,
_evalModelUriOrBuffer: string|Uint8Array, _optimizerModelUriOrBuffer: string|Uint8Array,
_options: InferenceSession.SessionOptions): Promise<TrainingSessionHandler> {
throw new Error('Method not implemented yet.');
}
/* eslint-enable @typescript-eslint/no-unused-vars */
}

export const wasmBackend = new OnnxruntimeTrainingWebAssemblyBackend();
6 changes: 3 additions & 3 deletions js/web/lib/wasm/wasm-factory.ts
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,9 @@ const isSimdSupported = (): boolean => {
}
};

const getWasmFileName = (useSimd: boolean, useThreads: boolean, useTraining: boolean) => {
const getWasmFileName = (useSimd: boolean, useThreads: boolean) => {
if (useSimd) {
if (useTraining) {
if (BUILD_DEFS.ENABLE_TRAINING) {
return 'ort-training-wasm-simd.wasm';
}
return useThreads ? 'ort-wasm-simd-threaded.wasm' : 'ort-wasm-simd.wasm';
Expand Down Expand Up @@ -111,7 +111,7 @@ export const initializeWebAssembly = async(flags: Env.WebAssemblyFlags): Promise

const wasmPaths = flags.wasmPaths;
const wasmPrefixOverride = typeof wasmPaths === 'string' ? wasmPaths : undefined;
const wasmFileName = getWasmFileName(useSimd, useThreads, BUILD_DEFS.ENABLE_TRAINING);
const wasmFileName = getWasmFileName(useSimd, useThreads);
const wasmPathOverride = typeof wasmPaths === 'object' ? wasmPaths[wasmFileName] : undefined;

let isTimeout = false;
Expand Down
2 changes: 2 additions & 0 deletions js/web/webpack.config.js
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,8 @@ module.exports = () => {
build_defs: {
...DEFAULT_BUILD_DEFS,
ENABLE_TRAINING: true,
DISABLE_WEBGL: true,
DISABLE_WEBGPU: true,
}
}),
);
Expand Down

0 comments on commit 9885866

Please sign in to comment.