Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
carzh committed Oct 11, 2023
1 parent e775933 commit c745a65
Show file tree
Hide file tree
Showing 7 changed files with 35 additions and 36 deletions.
17 changes: 8 additions & 9 deletions js/common/lib/training-session-impl.ts
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {resolveBackend} from './backend-impl.js';
import {TrainingSessionHandler} from './backend.js';
import {InferenceSession as InferenceSession} from './inference-session.js';
import {TrainingSession as TrainingSessionInterface, TrainingSessionCreateOptions} from './training-session.js';
import { resolveBackend } from './backend-impl.js';

type SessionOptions = InferenceSession.SessionOptions;
const noBackendErrMsg: string = "Training backend could not be resolved. " +
"Make sure you\'re using the correct configuration & WebAssembly files.";
const noBackendErrMsg: string = 'Training backend could not be resolved. ' +
'Make sure you\'re using the correct configuration & WebAssembly files.';

export class TrainingSession implements TrainingSessionInterface {
private constructor(handler: TrainingSessionHandler) {
Expand All @@ -23,7 +23,7 @@ export class TrainingSession implements TrainingSessionInterface {
return this.handler.outputNames;
}

static async create(trainingOptions: TrainingSessionCreateOptions,sessionOptions?: SessionOptions):
static async create(trainingOptions: TrainingSessionCreateOptions, sessionOptions?: SessionOptions):
Promise<TrainingSession> {
let checkpointState: string|Uint8Array = trainingOptions.checkpointState;
let trainModel: string|Uint8Array = trainingOptions.trainModel;
Expand All @@ -36,11 +36,10 @@ export class TrainingSession implements TrainingSessionInterface {
const backendHints = eps.map(i => typeof i === 'string' ? i : i.name);
const backend = await resolveBackend(backendHints);
if (backend.createTrainingSessionHandler) {
const handler =
await backend.createTrainingSessionHandler(checkpointState, trainModel, evalModel, optimizerModel, options);
return new TrainingSession(handler);
}
else {
const handler =
await backend.createTrainingSessionHandler(checkpointState, trainModel, evalModel, optimizerModel, options);
return new TrainingSession(handler);
} else {
throw new Error(noBackendErrMsg);
}
}
Expand Down
2 changes: 1 addition & 1 deletion js/web/lib/backend-wasm-training.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class OnnxruntimeTrainingWebAssemblyBackend extends OnnxruntimeWebAssemblyBacken
checkpointStateUriOrBuffer: string|Uint8Array, trainModelUriOrBuffer: string|Uint8Array,
evalModelUriOrBuffer: string|Uint8Array, optimizerModelUriOrBuffer: string|Uint8Array,
options: InferenceSession.SessionOptions): Promise<TrainingSessionHandler> {
const handler = new OnnxruntimeWebAssemblyTrainingSessionHandler();
const handler = new OnnxruntimeWebAssemblyTrainingSessionHandler();
await handler.createTrainingSession(
checkpointStateUriOrBuffer, trainModelUriOrBuffer, evalModelUriOrBuffer, optimizerModelUriOrBuffer, options);
return Promise.resolve(handler);
Expand Down
2 changes: 1 addition & 1 deletion js/web/lib/wasm/binding/ort-wasm.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ export interface OrtWasmModule extends EmscriptenModule {
_OrtTrainingCopyParametersFromBuffer?
(trainingHandle: number, parametersBuffer: number, parameterCount: number, trainableOnly: boolean): number;

_OrtTrainingGetInputOutputCount?(sessionHandle: number, inputCountOffset: number, outputCountOffset: number): number;
_OrtTrainingGetInputOutputCount?(sessionHandle: number, inputCountOffset: number, outputCountOffset: number): number;
_OrtTrainingGetInputOutputName?(sessionHandle: number, index: number, isInput: boolean): number;

_OrtTrainingReleaseSession?(trainingHandle: number): void;
Expand Down
2 changes: 1 addition & 1 deletion js/web/lib/wasm/proxy-wrapper.ts
Original file line number Diff line number Diff line change
Expand Up @@ -264,4 +264,4 @@ export const isOrtEnvInitialized = async(): Promise<boolean> => {
} else {
return core.isOrtEnvInitialized();
}
}
};
5 changes: 3 additions & 2 deletions js/web/lib/wasm/session-handler-for-training.ts
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,13 @@ export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSes
}

async dispose(): Promise<void> {
return releaseTrainingSessionAndCheckpoint(this.checkpointId, this.sessionId, this.inputEncodedNames, this.outputEncodedNames);
return releaseTrainingSessionAndCheckpoint(
this.checkpointId, this.sessionId, this.inputEncodedNames, this.outputEncodedNames);
}

async runTrainStep(
_feeds: SessionHandler.FeedsType, _fetches: SessionHandler.FetchesType,
_options: InferenceSession.RunOptions): Promise<SessionHandler.ReturnType> {
throw new Error('Method not implemented yet.');
throw new Error('Method not implemented yet.');
}
}
4 changes: 2 additions & 2 deletions js/web/lib/wasm/session-handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import {readFile} from 'node:fs/promises';
import {env, InferenceSession, InferenceSessionHandler, SessionHandler, Tensor} from 'onnxruntime-common';

import {SerializableModeldata, TensorMetadata} from './proxy-messages';
import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, initializeRuntime, releaseSession, run, isOrtEnvInitialized} from './proxy-wrapper';
import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, initializeRuntime, isOrtEnvInitialized, releaseSession, run} from './proxy-wrapper';
import {isGpuBufferSupportedType} from './wasm-common';

let runtimeInitializationPromise: Promise<void>|undefined;
Expand Down Expand Up @@ -56,7 +56,7 @@ export class OnnxruntimeWebAssemblySessionHandler implements InferenceSessionHan
}

async loadModel(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): Promise<void> {
if (!isOrtEnvInitialized()) {
if (!(await isOrtEnvInitialized())) {
if (!runtimeInitializationPromise) {
runtimeInitializationPromise = initializeRuntime(env);
}
Expand Down
39 changes: 19 additions & 20 deletions js/web/lib/wasm/wasm-training-core-impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,13 @@
// import {InferenceSession, Tensor} from 'onnxruntime-common';
import {InferenceSession} from 'onnxruntime-common';

import {SerializableModeldata, SerializableSessionMetadata } from './proxy-messages';
// import {setRunOptions} from './run-options';
import {SerializableModeldata, SerializableSessionMetadata} from './proxy-messages';
import {setSessionOptions} from './session-options';
// import {tensorDataTypeEnumToString, tensorDataTypeStringToEnum, tensorTypeToTypedArrayConstructor} from './wasm-common';
import {getInstance} from './wasm-factory';
import {checkLastError} from './wasm-utils';
// import {allocWasmString, checkLastError} from './wasm-utils';
// import { prepareInputOutputTensor } from './wasm-core-impl';

const NO_TRAIN_FUNCS_MSG = 'Built without training APIs enabled. Make sure to use the onnxruntime-training package for training functionality.');
const NO_TRAIN_FUNCS_MSG =
'Built without training APIs enabled. Make sure to use the onnxruntime-training package for training functionality.';

export const createCheckpointHandle = (checkpointData: SerializableModeldata): number => {
const wasm = getInstance();
Expand Down Expand Up @@ -93,7 +90,7 @@ export const createTrainingSessionHandle =
}
};

const getTrainingModelInputOutputNames = (trainingSessionId: number): [string[], number[], string[], number[]] => {
const getTrainingModelInputOutputNames = (trainingSessionId: number): [string[], number[], string[], number[]] => {
const [inputCount, outputCount] = getTrainingModelInputOutputCount(trainingSessionId);

const [inputNames, inputNamesUTF8Encoded] = getTrainingNamesLoop(trainingSessionId, inputCount, true);
Expand All @@ -102,7 +99,7 @@ export const createTrainingSessionHandle =
return [inputNames, inputNamesUTF8Encoded, outputNames, outputNamesUTF8Encoded];
}

const getTrainingModelInputOutputCount = (trainingSessionId: number): [number, number] => {
const getTrainingModelInputOutputCount = (trainingSessionId: number): [number, number] => {
const wasm = getInstance();
const stack = wasm.stackSave();
try {
Expand Down Expand Up @@ -143,15 +140,17 @@ const getTrainingNamesLoop = (trainingSessionId: number, count: number, isInput:
return [names, namesUTF8Encoded];
}

export const releaseTrainingSessionAndCheckpoint = (checkpointId: number, sessionId: number, inputNamesUTF8Encoded: number[], outputNamesUTF8Encoded: number[]): void => {
const wasm = getInstance();
inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf));
outputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf));

if (wasm._OrtTrainingReleaseCheckpoint) {
wasm._OrtTrainingReleaseCheckpoint(checkpointId);
}
if (wasm._OrtTrainingReleaseSession) {
wasm._OrtTrainingReleaseSession(sessionId);
}
}
export const releaseTrainingSessionAndCheckpoint =
(checkpointId: number, sessionId: number, inputNamesUTF8Encoded: number[], outputNamesUTF8Encoded: number[]):
void => {
const wasm = getInstance();
inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf));
outputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf));

if (wasm._OrtTrainingReleaseCheckpoint) {
wasm._OrtTrainingReleaseCheckpoint(checkpointId);
}
if (wasm._OrtTrainingReleaseSession) {
wasm._OrtTrainingReleaseSession(sessionId);
}
}

0 comments on commit c745a65

Please sign in to comment.