From c745a6570e5b7c46d573e57302e52f9570984799 Mon Sep 17 00:00:00 2001 From: carzh Date: Wed, 11 Oct 2023 14:26:18 -0700 Subject: [PATCH] format --- js/common/lib/training-session-impl.ts | 17 ++++---- js/web/lib/backend-wasm-training.ts | 2 +- js/web/lib/wasm/binding/ort-wasm.d.ts | 2 +- js/web/lib/wasm/proxy-wrapper.ts | 2 +- .../lib/wasm/session-handler-for-training.ts | 5 ++- js/web/lib/wasm/session-handler.ts | 4 +- js/web/lib/wasm/wasm-training-core-impl.ts | 39 +++++++++---------- 7 files changed, 35 insertions(+), 36 deletions(-) diff --git a/js/common/lib/training-session-impl.ts b/js/common/lib/training-session-impl.ts index bde1fb0c373ea..3cb6ec572c9a6 100644 --- a/js/common/lib/training-session-impl.ts +++ b/js/common/lib/training-session-impl.ts @@ -1,14 +1,14 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +import {resolveBackend} from './backend-impl.js'; import {TrainingSessionHandler} from './backend.js'; import {InferenceSession as InferenceSession} from './inference-session.js'; import {TrainingSession as TrainingSessionInterface, TrainingSessionCreateOptions} from './training-session.js'; -import { resolveBackend } from './backend-impl.js'; type SessionOptions = InferenceSession.SessionOptions; -const noBackendErrMsg: string = "Training backend could not be resolved. " + - "Make sure you\'re using the correct configuration & WebAssembly files."; +const noBackendErrMsg: string = 'Training backend could not be resolved. ' + + 'Make sure you\'re using the correct configuration & WebAssembly files.'; export class TrainingSession implements TrainingSessionInterface { private constructor(handler: TrainingSessionHandler) { @@ -23,7 +23,7 @@ export class TrainingSession implements TrainingSessionInterface { return this.handler.outputNames; } - static async create(trainingOptions: TrainingSessionCreateOptions,sessionOptions?: SessionOptions): + static async create(trainingOptions: TrainingSessionCreateOptions, sessionOptions?: SessionOptions): Promise { let checkpointState: string|Uint8Array = trainingOptions.checkpointState; let trainModel: string|Uint8Array = trainingOptions.trainModel; @@ -36,11 +36,10 @@ export class TrainingSession implements TrainingSessionInterface { const backendHints = eps.map(i => typeof i === 'string' ? i : i.name); const backend = await resolveBackend(backendHints); if (backend.createTrainingSessionHandler) { - const handler = - await backend.createTrainingSessionHandler(checkpointState, trainModel, evalModel, optimizerModel, options); - return new TrainingSession(handler); - } - else { + const handler = + await backend.createTrainingSessionHandler(checkpointState, trainModel, evalModel, optimizerModel, options); + return new TrainingSession(handler); + } else { throw new Error(noBackendErrMsg); } } diff --git a/js/web/lib/backend-wasm-training.ts b/js/web/lib/backend-wasm-training.ts index 1796259d0f335..98e40807aa29c 100644 --- a/js/web/lib/backend-wasm-training.ts +++ b/js/web/lib/backend-wasm-training.ts @@ -11,7 +11,7 @@ class OnnxruntimeTrainingWebAssemblyBackend extends OnnxruntimeWebAssemblyBacken checkpointStateUriOrBuffer: string|Uint8Array, trainModelUriOrBuffer: string|Uint8Array, evalModelUriOrBuffer: string|Uint8Array, optimizerModelUriOrBuffer: string|Uint8Array, options: InferenceSession.SessionOptions): Promise { - const handler = new OnnxruntimeWebAssemblyTrainingSessionHandler(); + const handler = new OnnxruntimeWebAssemblyTrainingSessionHandler(); await handler.createTrainingSession( checkpointStateUriOrBuffer, trainModelUriOrBuffer, evalModelUriOrBuffer, optimizerModelUriOrBuffer, options); return Promise.resolve(handler); diff --git a/js/web/lib/wasm/binding/ort-wasm.d.ts b/js/web/lib/wasm/binding/ort-wasm.d.ts index c2b2bdf628d85..060fb1e756ef9 100644 --- a/js/web/lib/wasm/binding/ort-wasm.d.ts +++ b/js/web/lib/wasm/binding/ort-wasm.d.ts @@ -102,7 +102,7 @@ export interface OrtWasmModule extends EmscriptenModule { _OrtTrainingCopyParametersFromBuffer? (trainingHandle: number, parametersBuffer: number, parameterCount: number, trainableOnly: boolean): number; - _OrtTrainingGetInputOutputCount?(sessionHandle: number, inputCountOffset: number, outputCountOffset: number): number; + _OrtTrainingGetInputOutputCount?(sessionHandle: number, inputCountOffset: number, outputCountOffset: number): number; _OrtTrainingGetInputOutputName?(sessionHandle: number, index: number, isInput: boolean): number; _OrtTrainingReleaseSession?(trainingHandle: number): void; diff --git a/js/web/lib/wasm/proxy-wrapper.ts b/js/web/lib/wasm/proxy-wrapper.ts index 6bb1bb733f194..55edd2106130d 100644 --- a/js/web/lib/wasm/proxy-wrapper.ts +++ b/js/web/lib/wasm/proxy-wrapper.ts @@ -264,4 +264,4 @@ export const isOrtEnvInitialized = async(): Promise => { } else { return core.isOrtEnvInitialized(); } -} +}; diff --git a/js/web/lib/wasm/session-handler-for-training.ts b/js/web/lib/wasm/session-handler-for-training.ts index 5f694371919fc..e0a9db97b951d 100644 --- a/js/web/lib/wasm/session-handler-for-training.ts +++ b/js/web/lib/wasm/session-handler-for-training.ts @@ -61,12 +61,13 @@ export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSes } async dispose(): Promise { - return releaseTrainingSessionAndCheckpoint(this.checkpointId, this.sessionId, this.inputEncodedNames, this.outputEncodedNames); + return releaseTrainingSessionAndCheckpoint( + this.checkpointId, this.sessionId, this.inputEncodedNames, this.outputEncodedNames); } async runTrainStep( _feeds: SessionHandler.FeedsType, _fetches: SessionHandler.FetchesType, _options: InferenceSession.RunOptions): Promise { - throw new Error('Method not implemented yet.'); + throw new Error('Method not implemented yet.'); } } diff --git a/js/web/lib/wasm/session-handler.ts b/js/web/lib/wasm/session-handler.ts index 2e2239b948b0f..a5017a920f38b 100644 --- a/js/web/lib/wasm/session-handler.ts +++ b/js/web/lib/wasm/session-handler.ts @@ -5,7 +5,7 @@ import {readFile} from 'node:fs/promises'; import {env, InferenceSession, InferenceSessionHandler, SessionHandler, Tensor} from 'onnxruntime-common'; import {SerializableModeldata, TensorMetadata} from './proxy-messages'; -import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, initializeRuntime, releaseSession, run, isOrtEnvInitialized} from './proxy-wrapper'; +import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, initializeRuntime, isOrtEnvInitialized, releaseSession, run} from './proxy-wrapper'; import {isGpuBufferSupportedType} from './wasm-common'; let runtimeInitializationPromise: Promise|undefined; @@ -56,7 +56,7 @@ export class OnnxruntimeWebAssemblySessionHandler implements InferenceSessionHan } async loadModel(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): Promise { - if (!isOrtEnvInitialized()) { + if (!(await isOrtEnvInitialized())) { if (!runtimeInitializationPromise) { runtimeInitializationPromise = initializeRuntime(env); } diff --git a/js/web/lib/wasm/wasm-training-core-impl.ts b/js/web/lib/wasm/wasm-training-core-impl.ts index c55b4976c9d0d..7de1601cd73b5 100644 --- a/js/web/lib/wasm/wasm-training-core-impl.ts +++ b/js/web/lib/wasm/wasm-training-core-impl.ts @@ -3,16 +3,13 @@ // import {InferenceSession, Tensor} from 'onnxruntime-common'; import {InferenceSession} from 'onnxruntime-common'; -import {SerializableModeldata, SerializableSessionMetadata } from './proxy-messages'; -// import {setRunOptions} from './run-options'; +import {SerializableModeldata, SerializableSessionMetadata} from './proxy-messages'; import {setSessionOptions} from './session-options'; -// import {tensorDataTypeEnumToString, tensorDataTypeStringToEnum, tensorTypeToTypedArrayConstructor} from './wasm-common'; import {getInstance} from './wasm-factory'; import {checkLastError} from './wasm-utils'; -// import {allocWasmString, checkLastError} from './wasm-utils'; -// import { prepareInputOutputTensor } from './wasm-core-impl'; -const NO_TRAIN_FUNCS_MSG = 'Built without training APIs enabled. Make sure to use the onnxruntime-training package for training functionality.'); +const NO_TRAIN_FUNCS_MSG = + 'Built without training APIs enabled. Make sure to use the onnxruntime-training package for training functionality.'; export const createCheckpointHandle = (checkpointData: SerializableModeldata): number => { const wasm = getInstance(); @@ -93,7 +90,7 @@ export const createTrainingSessionHandle = } }; - const getTrainingModelInputOutputNames = (trainingSessionId: number): [string[], number[], string[], number[]] => { +const getTrainingModelInputOutputNames = (trainingSessionId: number): [string[], number[], string[], number[]] => { const [inputCount, outputCount] = getTrainingModelInputOutputCount(trainingSessionId); const [inputNames, inputNamesUTF8Encoded] = getTrainingNamesLoop(trainingSessionId, inputCount, true); @@ -102,7 +99,7 @@ export const createTrainingSessionHandle = return [inputNames, inputNamesUTF8Encoded, outputNames, outputNamesUTF8Encoded]; } - const getTrainingModelInputOutputCount = (trainingSessionId: number): [number, number] => { +const getTrainingModelInputOutputCount = (trainingSessionId: number): [number, number] => { const wasm = getInstance(); const stack = wasm.stackSave(); try { @@ -143,15 +140,17 @@ const getTrainingNamesLoop = (trainingSessionId: number, count: number, isInput: return [names, namesUTF8Encoded]; } -export const releaseTrainingSessionAndCheckpoint = (checkpointId: number, sessionId: number, inputNamesUTF8Encoded: number[], outputNamesUTF8Encoded: number[]): void => { - const wasm = getInstance(); - inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); - outputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); - - if (wasm._OrtTrainingReleaseCheckpoint) { - wasm._OrtTrainingReleaseCheckpoint(checkpointId); - } - if (wasm._OrtTrainingReleaseSession) { - wasm._OrtTrainingReleaseSession(sessionId); - } -} +export const releaseTrainingSessionAndCheckpoint = + (checkpointId: number, sessionId: number, inputNamesUTF8Encoded: number[], outputNamesUTF8Encoded: number[]): + void => { + const wasm = getInstance(); + inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); + outputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); + + if (wasm._OrtTrainingReleaseCheckpoint) { + wasm._OrtTrainingReleaseCheckpoint(checkpointId); + } + if (wasm._OrtTrainingReleaseSession) { + wasm._OrtTrainingReleaseSession(sessionId); + } + }