Skip to content

Commit

Permalink
[js/web/training] Add CreateTrainingSession (microsoft#17891)
Browse files Browse the repository at this point in the history
### Description
* Adds TrainingSession.create() functionality following the web bindings
for training design doc
* Added 2 new training APIs to wasm/api.h:
   * OrtTrainingGetInputOutputName
   * OrtTrainingGetInputOutputCount
* Moved isOrtEnvInitialized boolean to the wasm-core-impl and added a
method that references it

### Motivation and Context
* Adding web bindings for training

#### Related work
* microsoft#16521 allowed for training artifacts to be built
* microsoft#17333 added interfaces for training
* microsoft#17474 allows for training package to be built + adds training backend
to web package **[MUST BE MERGED IN BEFORE THIS ONE]**

---------

Co-authored-by: Yulong Wang <[email protected]>
Co-authored-by: Ashwini Khade <[email protected]>
  • Loading branch information
3 people authored and kleiti committed Mar 22, 2024
1 parent 5d6a16c commit 2f97b4d
Show file tree
Hide file tree
Showing 12 changed files with 399 additions and 12 deletions.
21 changes: 19 additions & 2 deletions js/common/lib/training-session-impl.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {resolveBackend} from './backend-impl.js';
import {TrainingSessionHandler} from './backend.js';
import {InferenceSession as InferenceSession} from './inference-session.js';
import {TrainingSession as TrainingSessionInterface, TrainingSessionCreateOptions} from './training-session.js';

type SessionOptions = InferenceSession.SessionOptions;
const noBackendErrMsg: string = 'Training backend could not be resolved. ' +
'Make sure you\'re using the correct configuration & WebAssembly files.';

export class TrainingSession implements TrainingSessionInterface {
private constructor(handler: TrainingSessionHandler) {
Expand All @@ -20,9 +23,23 @@ export class TrainingSession implements TrainingSessionInterface {
return this.handler.outputNames;
}

static async create(_trainingOptions: TrainingSessionCreateOptions, _sessionOptions?: SessionOptions):
static async create(trainingOptions: TrainingSessionCreateOptions, sessionOptions?: SessionOptions):
Promise<TrainingSession> {
throw new Error('Method not implemented');
const evalModel: string|Uint8Array = trainingOptions.evalModel || '';
const optimizerModel: string|Uint8Array = trainingOptions.optimizerModel || '';
const options: SessionOptions = sessionOptions || {};

// get backend hints
const eps = options.executionProviders || [];
const backendHints = eps.map(i => typeof i === 'string' ? i : i.name);
const backend = await resolveBackend(backendHints);
if (backend.createTrainingSessionHandler) {
const handler = await backend.createTrainingSessionHandler(
trainingOptions.checkpointState, trainingOptions.trainModel, evalModel, optimizerModel, options);
return new TrainingSession(handler);
} else {
throw new Error(noBackendErrMsg);
}
}

async loadParametersBuffer(_array: Uint8Array, _trainableOnly: boolean): Promise<void> {
Expand Down
12 changes: 8 additions & 4 deletions js/web/lib/backend-wasm-training.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,17 @@
import {InferenceSession, TrainingSessionHandler} from 'onnxruntime-common';

import {OnnxruntimeWebAssemblyBackend} from './backend-wasm';
import {OnnxruntimeWebAssemblyTrainingSessionHandler} from './wasm/session-handler-for-training';

class OnnxruntimeTrainingWebAssemblyBackend extends OnnxruntimeWebAssemblyBackend {
async createTrainingSessionHandler(
_checkpointStateUriOrBuffer: string|Uint8Array, _trainModelUriOrBuffer: string|Uint8Array,
_evalModelUriOrBuffer: string|Uint8Array, _optimizerModelUriOrBuffer: string|Uint8Array,
_options: InferenceSession.SessionOptions): Promise<TrainingSessionHandler> {
throw new Error('Method not implemented yet.');
checkpointStateUriOrBuffer: string|Uint8Array, trainModelUriOrBuffer: string|Uint8Array,
evalModelUriOrBuffer: string|Uint8Array, optimizerModelUriOrBuffer: string|Uint8Array,
options: InferenceSession.SessionOptions): Promise<TrainingSessionHandler> {
const handler = new OnnxruntimeWebAssemblyTrainingSessionHandler();
await handler.createTrainingSession(
checkpointStateUriOrBuffer, trainModelUriOrBuffer, evalModelUriOrBuffer, optimizerModelUriOrBuffer, options);
return Promise.resolve(handler);
}
}

Expand Down
5 changes: 5 additions & 0 deletions js/web/lib/wasm/binding/ort-wasm.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,11 @@ export interface OrtWasmModule extends EmscriptenModule {
_OrtTrainingCopyParametersFromBuffer?
(trainingHandle: number, parametersBuffer: number, parameterCount: number, trainableOnly: boolean): number;

_OrtTrainingGetModelInputOutputCount?
(trainingHandle: number, inputCount: number, outputCount: number, isEvalModel: boolean): number;
_OrtTrainingGetModelInputOutputName?
(trainingHandle: number, index: number, isInput: boolean, isEvalModel: boolean): number;

_OrtTrainingReleaseSession?(trainingHandle: number): void;
// #endregion

Expand Down
7 changes: 6 additions & 1 deletion js/web/lib/wasm/proxy-messages.ts
Original file line number Diff line number Diff line change
Expand Up @@ -73,5 +73,10 @@ interface MesssageEndProfiling extends MessageError {
in ?: number;
}

interface MessageIsOrtEnvInitialized extends MessageError {
type: 'is-ort-env-initialized';
out?: boolean;
}

export type OrtWasmMessage = MessageInitWasm|MessageInitOrt|MessageCreateSessionAllocate|MessageCreateSessionFinalize|
MessageCreateSession|MessageReleaseSession|MessageRun|MesssageEndProfiling;
MessageCreateSession|MessageReleaseSession|MessageRun|MesssageEndProfiling|MessageIsOrtEnvInitialized;
10 changes: 9 additions & 1 deletion js/web/lib/wasm/proxy-worker/main.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
/// <reference lib="webworker" />

import {OrtWasmMessage} from '../proxy-messages';
import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, extractTransferableBuffers, initRuntime, releaseSession, run} from '../wasm-core-impl';
import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, extractTransferableBuffers, initRuntime, isOrtEnvInitialized, releaseSession, run} from '../wasm-core-impl';
import {initializeWebAssembly} from '../wasm-factory';

self.onmessage = (ev: MessageEvent<OrtWasmMessage>): void => {
Expand Down Expand Up @@ -89,6 +89,14 @@ self.onmessage = (ev: MessageEvent<OrtWasmMessage>): void => {
postMessage({type: 'end-profiling', err} as OrtWasmMessage);
}
break;
case 'is-ort-env-initialized':
try {
const ortEnvInitialized = isOrtEnvInitialized();
postMessage({type: 'is-ort-env-initialized', out: ortEnvInitialized} as OrtWasmMessage);
} catch (err) {
postMessage({type: 'is-ort-env-initialized', err} as OrtWasmMessage);
}
break;
default:
}
};
21 changes: 21 additions & 0 deletions js/web/lib/wasm/proxy-wrapper.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ const createSessionCallbacks: Array<PromiseCallbacks<SerializableSessionMetadata
const releaseSessionCallbacks: Array<PromiseCallbacks<void>> = [];
const runCallbacks: Array<PromiseCallbacks<SerializableTensorMetadata[]>> = [];
const endProfilingCallbacks: Array<PromiseCallbacks<void>> = [];
const isOrtEnvInitializedCallbacks: Array<PromiseCallbacks<boolean>> = [];

const ensureWorker = (): void => {
if (initializing || !initialized || aborted || !proxyWorker) {
Expand Down Expand Up @@ -92,6 +93,13 @@ const onProxyWorkerMessage = (ev: MessageEvent<OrtWasmMessage>): void => {
endProfilingCallbacks.shift()![0]();
}
break;
case 'is-ort-env-initialized':
if (ev.data.err) {
isOrtEnvInitializedCallbacks.shift()![1](ev.data.err);
} else {
isOrtEnvInitializedCallbacks.shift()![0](ev.data.out!);
}
break;
default:
}
};
Expand Down Expand Up @@ -251,3 +259,16 @@ export const endProfiling = async(sessionId: number): Promise<void> => {
core.endProfiling(sessionId);
}
};

export const isOrtEnvInitialized = async(): Promise<boolean> => {
if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) {
ensureWorker();
return new Promise<boolean>((resolve, reject) => {
isOrtEnvInitializedCallbacks.push([resolve, reject]);
const message: OrtWasmMessage = {type: 'is-ort-env-initialized'};
proxyWorker!.postMessage(message);
});
} else {
return core.isOrtEnvInitialized();
}
};
73 changes: 73 additions & 0 deletions js/web/lib/wasm/session-handler-for-training.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {env, InferenceSession, SessionHandler, TrainingSessionHandler} from 'onnxruntime-common';

import {SerializableModeldata} from './proxy-messages';
import {createSessionAllocate, initRuntime, isOrtEnvInitialized} from './wasm-core-impl';
import {createCheckpointHandle, createTrainingSessionHandle, releaseTrainingSessionAndCheckpoint} from './wasm-training-core-impl';

export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSessionHandler {
async loadParametersBuffer(_array: Uint8Array, _trainableOnly: boolean): Promise<void> {
throw new Error('Method not implemented.');
}
async getContiguousParameters(_trainableOnly: boolean): Promise<Uint8Array> {
throw new Error('Method not implemented.');
}
private sessionId: number;
private checkpointId: number;

inputNames: string[];
outputNames: string[];

inputEncodedNames: number[];
outputEncodedNames: number[];

async uriOrBufferToHeap(uriOrBuffer: string|Uint8Array): Promise<SerializableModeldata> {
let buffer: Uint8Array;
if (typeof uriOrBuffer === 'string') {
const response = await fetch(uriOrBuffer);
const arrayBuffer = await response.arrayBuffer();
buffer = new Uint8Array(arrayBuffer);
} else {
buffer = uriOrBuffer;
}
return createSessionAllocate(buffer);
}

async createTrainingSession(
checkpointStateUriOrBuffer: string|Uint8Array, trainModelUriOrBuffer: string|Uint8Array,
evalModelUriOrBuffer: string|Uint8Array, optimizerModelUriOrBuffer: string|Uint8Array,
options: InferenceSession.SessionOptions) {
if (!isOrtEnvInitialized()) {
await initRuntime(env);
}
const checkpointData: SerializableModeldata = await this.uriOrBufferToHeap(checkpointStateUriOrBuffer);
const trainModelData: SerializableModeldata = await this.uriOrBufferToHeap(trainModelUriOrBuffer);
// 0 is supposed to be the nullptr
let evalModelData: SerializableModeldata = [0, 0];
let optimizerModelData: SerializableModeldata = [0, 0];

if (evalModelUriOrBuffer !== '') {
evalModelData = await this.uriOrBufferToHeap(evalModelUriOrBuffer);
}
if (optimizerModelUriOrBuffer !== '') {
optimizerModelData = await this.uriOrBufferToHeap(optimizerModelUriOrBuffer);
}

this.checkpointId = createCheckpointHandle(checkpointData);
[[this.sessionId, this.inputNames, this.outputNames], this.inputEncodedNames, this.outputEncodedNames] =
createTrainingSessionHandle(this.checkpointId, trainModelData, evalModelData, optimizerModelData, options);
}

async dispose(): Promise<void> {
return releaseTrainingSessionAndCheckpoint(
this.checkpointId, this.sessionId, this.inputEncodedNames, this.outputEncodedNames);
}

async runTrainStep(
_feeds: SessionHandler.FeedsType, _fetches: SessionHandler.FetchesType,
_options: InferenceSession.RunOptions): Promise<SessionHandler.ReturnType> {
throw new Error('Method not implemented yet.');
}
}
6 changes: 2 additions & 4 deletions js/web/lib/wasm/session-handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@ import {readFile} from 'node:fs/promises';
import {env, InferenceSession, InferenceSessionHandler, SessionHandler, Tensor} from 'onnxruntime-common';

import {SerializableModeldata, TensorMetadata} from './proxy-messages';
import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, initializeRuntime, releaseSession, run} from './proxy-wrapper';
import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, initializeRuntime, isOrtEnvInitialized, releaseSession, run} from './proxy-wrapper';
import {isGpuBufferSupportedType} from './wasm-common';

let runtimeInitialized: boolean;
let runtimeInitializationPromise: Promise<void>|undefined;

const encodeTensorMetadata = (tensor: Tensor, getName: () => string): TensorMetadata => {
Expand Down Expand Up @@ -57,13 +56,12 @@ export class OnnxruntimeWebAssemblySessionHandler implements InferenceSessionHan
}

async loadModel(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): Promise<void> {
if (!runtimeInitialized) {
if (!(await isOrtEnvInitialized())) {
if (!runtimeInitializationPromise) {
runtimeInitializationPromise = initializeRuntime(env);
}
await runtimeInitializationPromise;
runtimeInitializationPromise = undefined;
runtimeInitialized = true;
}

if (typeof pathOrBuffer === 'string') {
Expand Down
6 changes: 6 additions & 0 deletions js/web/lib/wasm/wasm-core-impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import {dataLocationStringToEnum, getTensorElementSize, isGpuBufferSupportedType
import {getInstance} from './wasm-factory';
import {allocWasmString, checkLastError} from './wasm-utils';

let ortEnvInitialized = false;

/**
* get the input/output count of the session.
* @param sessionHandle the handle representing the session. should be non-zero.
Expand Down Expand Up @@ -57,6 +59,8 @@ export const initRuntime = async(env: Env): Promise<void> => {
const initJsep = require('./jsep/init').init;
await initJsep(getInstance(), env);
}

ortEnvInitialized = true;
};

/**
Expand Down Expand Up @@ -93,6 +97,8 @@ type SessionMetadata = [

const activeSessions = new Map<number, SessionMetadata>();

export const isOrtEnvInitialized = (): boolean => ortEnvInitialized;

/**
* allocate the memory and memcpy the model bytes, preparing for creating an instance of InferenceSession.
* @returns a 2-elements tuple - the pointer and size of the allocated buffer
Expand Down
Loading

0 comments on commit 2f97b4d

Please sign in to comment.