Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[js/web/training] Add CreateTrainingSession #17891

Merged
merged 24 commits into from
Oct 26, 2023
Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
cff6a7f
added training build configuration
carzh Aug 29, 2023
7a1365d
edited wasm factory + added training backend
carzh Aug 29, 2023
71f9dbc
added package.json modification for training + minimized training art…
carzh Aug 31, 2023
5adbcf4
applied suggestions
carzh Sep 22, 2023
9068ab4
create training session implementation + supporting changes
carzh Oct 6, 2023
edbe9cc
Merge remote-tracking branch 'origin/main' into carzh/training-wasm-b…
fs-eire Oct 6, 2023
5202e3b
fixed variable names + enforced wasmBackend being a singleton with su…
carzh Oct 6, 2023
a277347
format + lint
carzh Oct 7, 2023
adfa5cf
Merge branch 'main' into carzh/training-wasm-binding
carzh Oct 11, 2023
c9931ad
Merge branch 'carzh/training-wasm-binding' into carzh/create-training…
carzh Oct 11, 2023
e775933
minor tweak to remove placeholder return statement
carzh Oct 11, 2023
c745a65
format
carzh Oct 11, 2023
beca6dc
lint fixes
carzh Oct 11, 2023
b37e77f
lint + format
carzh Oct 11, 2023
8c1e227
Merge branch 'main' into carzh/create-training-session
carzh Oct 12, 2023
d2f3d88
added isOrtEnvInitialized case to proxy wrapper
carzh Oct 12, 2023
712e73c
Merge remote-tracking branch 'refs/remotes/origin/carzh/create-traini…
carzh Oct 12, 2023
05a708f
fixed proxy wrapper case statement
carzh Oct 13, 2023
daa1023
updated getInputOutputCount and getInputOutputNames signature, added …
carzh Oct 23, 2023
389f08e
updated parameter names according to suggestions
carzh Oct 23, 2023
70505c4
format + lint
carzh Oct 24, 2023
a5bc60c
lint fix -- changed multiline string to singlequote string
carzh Oct 24, 2023
ae936e5
Apply naming suggestions from code review
carzh Oct 25, 2023
316a6e7
implemented naming suggestions + fixed training session & checkpoint …
carzh Oct 25, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions js/common/lib/training-session-impl.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {resolveBackend} from './backend-impl.js';
import {TrainingSessionHandler} from './backend.js';
import {InferenceSession as InferenceSession} from './inference-session.js';
import {TrainingSession as TrainingSessionInterface, TrainingSessionCreateOptions} from './training-session.js';

type SessionOptions = InferenceSession.SessionOptions;
const noBackendErrMsg: string = 'Training backend could not be resolved. ' +
'Make sure you\'re using the correct configuration & WebAssembly files.';

export class TrainingSession implements TrainingSessionInterface {
private constructor(handler: TrainingSessionHandler) {
Expand All @@ -20,9 +23,23 @@ export class TrainingSession implements TrainingSessionInterface {
return this.handler.outputNames;
}

static async create(_trainingOptions: TrainingSessionCreateOptions, _sessionOptions?: SessionOptions):
static async create(trainingOptions: TrainingSessionCreateOptions, sessionOptions?: SessionOptions):
Promise<TrainingSession> {
throw new Error('Method not implemented');
const evalModel: string|Uint8Array = trainingOptions.evalModel || '';
const optimizerModel: string|Uint8Array = trainingOptions.optimizerModel || '';
const options: SessionOptions = sessionOptions || {};

// get backend hints
const eps = options.executionProviders || [];
const backendHints = eps.map(i => typeof i === 'string' ? i : i.name);
const backend = await resolveBackend(backendHints);
if (backend.createTrainingSessionHandler) {
const handler = await backend.createTrainingSessionHandler(
trainingOptions.checkpointState, trainingOptions.trainModel, evalModel, optimizerModel, options);
askhade marked this conversation as resolved.
Show resolved Hide resolved
return new TrainingSession(handler);
} else {
throw new Error(noBackendErrMsg);
}
}

async loadParametersBuffer(_array: Uint8Array, _trainableOnly: boolean): Promise<void> {
Expand Down
12 changes: 8 additions & 4 deletions js/web/lib/backend-wasm-training.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,17 @@
import {InferenceSession, TrainingSessionHandler} from 'onnxruntime-common';

import {OnnxruntimeWebAssemblyBackend} from './backend-wasm';
import {OnnxruntimeWebAssemblyTrainingSessionHandler} from './wasm/session-handler-for-training';

class OnnxruntimeTrainingWebAssemblyBackend extends OnnxruntimeWebAssemblyBackend {
async createTrainingSessionHandler(
_checkpointStateUriOrBuffer: string|Uint8Array, _trainModelUriOrBuffer: string|Uint8Array,
_evalModelUriOrBuffer: string|Uint8Array, _optimizerModelUriOrBuffer: string|Uint8Array,
_options: InferenceSession.SessionOptions): Promise<TrainingSessionHandler> {
throw new Error('Method not implemented yet.');
checkpointStateUriOrBuffer: string|Uint8Array, trainModelUriOrBuffer: string|Uint8Array,
evalModelUriOrBuffer: string|Uint8Array, optimizerModelUriOrBuffer: string|Uint8Array,
options: InferenceSession.SessionOptions): Promise<TrainingSessionHandler> {
const handler = new OnnxruntimeWebAssemblyTrainingSessionHandler();
await handler.createTrainingSession(
checkpointStateUriOrBuffer, trainModelUriOrBuffer, evalModelUriOrBuffer, optimizerModelUriOrBuffer, options);
return Promise.resolve(handler);
}
}

Expand Down
5 changes: 5 additions & 0 deletions js/web/lib/wasm/binding/ort-wasm.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,11 @@ export interface OrtWasmModule extends EmscriptenModule {
_OrtTrainingCopyParametersFromBuffer?
(trainingHandle: number, parametersBuffer: number, parameterCount: number, trainableOnly: boolean): number;

_OrtTrainingGetInputOutputCount?
(trainingHandle: number, inputCount: number, outputCount: number, isEvalModel: boolean): number;
_OrtTrainingGetInputOutputName?
(trainingHandle: number, index: number, isInput: boolean, isEvalModel: boolean): number;

_OrtTrainingReleaseSession?(trainingHandle: number): void;
// #endregion

Expand Down
7 changes: 6 additions & 1 deletion js/web/lib/wasm/proxy-messages.ts
Original file line number Diff line number Diff line change
Expand Up @@ -73,5 +73,10 @@ interface MesssageEndProfiling extends MessageError {
in ?: number;
}

interface MessageIsOrtEnvInitialized extends MessageError {
askhade marked this conversation as resolved.
Show resolved Hide resolved
type: 'is-ort-env-initialized';
out?: boolean;
}

export type OrtWasmMessage = MessageInitWasm|MessageInitOrt|MessageCreateSessionAllocate|MessageCreateSessionFinalize|
MessageCreateSession|MessageReleaseSession|MessageRun|MesssageEndProfiling;
MessageCreateSession|MessageReleaseSession|MessageRun|MesssageEndProfiling|MessageIsOrtEnvInitialized;
10 changes: 9 additions & 1 deletion js/web/lib/wasm/proxy-worker/main.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
/// <reference lib="webworker" />

import {OrtWasmMessage} from '../proxy-messages';
import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, extractTransferableBuffers, initRuntime, releaseSession, run} from '../wasm-core-impl';
import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, extractTransferableBuffers, initRuntime, isOrtEnvInitialized, releaseSession, run} from '../wasm-core-impl';
import {initializeWebAssembly} from '../wasm-factory';

self.onmessage = (ev: MessageEvent<OrtWasmMessage>): void => {
Expand Down Expand Up @@ -89,6 +89,14 @@ self.onmessage = (ev: MessageEvent<OrtWasmMessage>): void => {
postMessage({type: 'end-profiling', err} as OrtWasmMessage);
}
break;
case 'is-ort-env-initialized':
try {
const ortEnvInitialized = isOrtEnvInitialized();
postMessage({type: 'is-ort-env-initialized', out: ortEnvInitialized} as OrtWasmMessage);
} catch (err) {
postMessage({type: 'is-ort-env-initialized', err} as OrtWasmMessage);
}
break;
default:
}
};
21 changes: 21 additions & 0 deletions js/web/lib/wasm/proxy-wrapper.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ const createSessionCallbacks: Array<PromiseCallbacks<SerializableSessionMetadata
const releaseSessionCallbacks: Array<PromiseCallbacks<void>> = [];
const runCallbacks: Array<PromiseCallbacks<SerializableTensorMetadata[]>> = [];
const endProfilingCallbacks: Array<PromiseCallbacks<void>> = [];
const isOrtEnvInitializedCallbacks: Array<PromiseCallbacks<boolean>> = [];

const ensureWorker = (): void => {
if (initializing || !initialized || aborted || !proxyWorker) {
Expand Down Expand Up @@ -92,6 +93,13 @@ const onProxyWorkerMessage = (ev: MessageEvent<OrtWasmMessage>): void => {
endProfilingCallbacks.shift()![0]();
}
break;
case 'is-ort-env-initialized':
if (ev.data.err) {
isOrtEnvInitializedCallbacks.shift()![1](ev.data.err);
} else {
isOrtEnvInitializedCallbacks.shift()![0](ev.data.out!);
}
break;
default:
}
};
Expand Down Expand Up @@ -251,3 +259,16 @@ export const endProfiling = async(sessionId: number): Promise<void> => {
core.endProfiling(sessionId);
}
};

export const isOrtEnvInitialized = async(): Promise<boolean> => {
if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) {
ensureWorker();
return new Promise<boolean>((resolve, reject) => {
isOrtEnvInitializedCallbacks.push([resolve, reject]);
const message: OrtWasmMessage = {type: 'is-ort-env-initialized'};
proxyWorker!.postMessage(message);
});
} else {
return core.isOrtEnvInitialized();
}
};
73 changes: 73 additions & 0 deletions js/web/lib/wasm/session-handler-for-training.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {env, InferenceSession, SessionHandler, TrainingSessionHandler} from 'onnxruntime-common';

import {SerializableModeldata} from './proxy-messages';
import {createSessionAllocate, initRuntime, isOrtEnvInitialized} from './wasm-core-impl';
import {createCheckpointHandle, createTrainingSessionHandle, releaseTrainingSessionAndCheckpoint} from './wasm-training-core-impl';

export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSessionHandler {
async loadParametersBuffer(_array: Uint8Array, _trainableOnly: boolean): Promise<void> {
throw new Error('Method not implemented.');
}
async getContiguousParameters(_trainableOnly: boolean): Promise<Uint8Array> {
throw new Error('Method not implemented.');
}
private sessionId: number;
private checkpointId: number;

inputNames: string[];
outputNames: string[];

inputEncodedNames: number[];
outputEncodedNames: number[];

async uriOrBufferToHeap(uriOrBuffer: string|Uint8Array): Promise<SerializableModeldata> {
let buffer: Uint8Array;
if (typeof uriOrBuffer === 'string') {
const response = await fetch(uriOrBuffer);
const arrayBuffer = await response.arrayBuffer();
buffer = new Uint8Array(arrayBuffer);
} else {
buffer = uriOrBuffer;
}
return createSessionAllocate(buffer);
}

async createTrainingSession(
checkpointStateUriOrBuffer: string|Uint8Array, trainModelUriOrBuffer: string|Uint8Array,
evalModelUriOrBuffer: string|Uint8Array, optimizerModelUriOrBuffer: string|Uint8Array,
options: InferenceSession.SessionOptions) {
if (!isOrtEnvInitialized()) {
await initRuntime(env);
}
const checkpointData: SerializableModeldata = await this.uriOrBufferToHeap(checkpointStateUriOrBuffer);
const trainModelData: SerializableModeldata = await this.uriOrBufferToHeap(trainModelUriOrBuffer);
// 0 is supposed to be the nullptr
let evalModelData: SerializableModeldata = [0, 0];
let optimizerModelData: SerializableModeldata = [0, 0];

if (evalModelUriOrBuffer !== '') {
evalModelData = await this.uriOrBufferToHeap(evalModelUriOrBuffer);
}
if (optimizerModelUriOrBuffer !== '') {
optimizerModelData = await this.uriOrBufferToHeap(optimizerModelUriOrBuffer);
}

this.checkpointId = createCheckpointHandle(checkpointData);
[[this.sessionId, this.inputNames, this.outputNames], this.inputEncodedNames, this.outputEncodedNames] =
createTrainingSessionHandle(this.checkpointId, trainModelData, evalModelData, optimizerModelData, options);
}

async dispose(): Promise<void> {
return releaseTrainingSessionAndCheckpoint(
this.checkpointId, this.sessionId, this.inputEncodedNames, this.outputEncodedNames);
}

async runTrainStep(
_feeds: SessionHandler.FeedsType, _fetches: SessionHandler.FetchesType,
_options: InferenceSession.RunOptions): Promise<SessionHandler.ReturnType> {
throw new Error('Method not implemented yet.');
}
}
6 changes: 2 additions & 4 deletions js/web/lib/wasm/session-handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@ import {readFile} from 'node:fs/promises';
import {env, InferenceSession, InferenceSessionHandler, SessionHandler, Tensor} from 'onnxruntime-common';

import {SerializableModeldata, TensorMetadata} from './proxy-messages';
import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, initializeRuntime, releaseSession, run} from './proxy-wrapper';
import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, initializeRuntime, isOrtEnvInitialized, releaseSession, run} from './proxy-wrapper';
import {isGpuBufferSupportedType} from './wasm-common';

let runtimeInitialized: boolean;
let runtimeInitializationPromise: Promise<void>|undefined;

const encodeTensorMetadata = (tensor: Tensor, getName: () => string): TensorMetadata => {
Expand Down Expand Up @@ -57,13 +56,12 @@ export class OnnxruntimeWebAssemblySessionHandler implements InferenceSessionHan
}

async loadModel(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): Promise<void> {
if (!runtimeInitialized) {
if (!(await isOrtEnvInitialized())) {
if (!runtimeInitializationPromise) {
runtimeInitializationPromise = initializeRuntime(env);
}
await runtimeInitializationPromise;
runtimeInitializationPromise = undefined;
runtimeInitialized = true;
}

if (typeof pathOrBuffer === 'string') {
Expand Down
6 changes: 6 additions & 0 deletions js/web/lib/wasm/wasm-core-impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import {dataLocationStringToEnum, getTensorElementSize, isGpuBufferSupportedType
import {getInstance} from './wasm-factory';
import {allocWasmString, checkLastError} from './wasm-utils';

let ortEnvInitialized = false;

/**
* get the input/output count of the session.
* @param sessionHandle the handle representing the session. should be non-zero.
Expand Down Expand Up @@ -57,6 +59,8 @@ export const initRuntime = async(env: Env): Promise<void> => {
const initJsep = require('./jsep/init').init;
await initJsep(getInstance(), env);
}

ortEnvInitialized = true;
};

/**
Expand Down Expand Up @@ -93,6 +97,8 @@ type SessionMetadata = [

const activeSessions = new Map<number, SessionMetadata>();

export const isOrtEnvInitialized = (): boolean => ortEnvInitialized;

/**
* allocate the memory and memcpy the model bytes, preparing for creating an instance of InferenceSession.
* @returns a 2-elements tuple - the pointer and size of the allocated buffer
Expand Down
Loading
Loading