From 453902f37527402b1f3303ffef4ce43950c57d03 Mon Sep 17 00:00:00 2001 From: Caroline Zhu Date: Thu, 26 Oct 2023 09:22:10 -0700 Subject: [PATCH] [js/web/training] Add CreateTrainingSession (#17891) ### Description * Adds TrainingSession.create() functionality following the web bindings for training design doc * Added 2 new training APIs to wasm/api.h: * OrtTrainingGetInputOutputName * OrtTrainingGetInputOutputCount * Moved isOrtEnvInitialized boolean to the wasm-core-impl and added a method that references it ### Motivation and Context * Adding web bindings for training #### Related work * #16521 allowed for training artifacts to be built * #17333 added interfaces for training * #17474 allows for training package to be built + adds training backend to web package **[MUST BE MERGED IN BEFORE THIS ONE]** --------- Co-authored-by: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Co-authored-by: Ashwini Khade --- common/lib/training-session-impl.ts | 21 ++- web/lib/backend-wasm-training.ts | 12 +- web/lib/wasm/binding/ort-wasm.d.ts | 5 + web/lib/wasm/proxy-messages.ts | 7 +- web/lib/wasm/proxy-worker/main.ts | 10 +- web/lib/wasm/proxy-wrapper.ts | 21 +++ web/lib/wasm/session-handler-for-training.ts | 73 +++++++++ web/lib/wasm/session-handler.ts | 6 +- web/lib/wasm/wasm-core-impl.ts | 6 + web/lib/wasm/wasm-training-core-impl.ts | 162 +++++++++++++++++++ 10 files changed, 311 insertions(+), 12 deletions(-) create mode 100644 web/lib/wasm/session-handler-for-training.ts create mode 100644 web/lib/wasm/wasm-training-core-impl.ts diff --git a/common/lib/training-session-impl.ts b/common/lib/training-session-impl.ts index f06d06bda035f..47e67879e66ce 100644 --- a/common/lib/training-session-impl.ts +++ b/common/lib/training-session-impl.ts @@ -1,11 +1,14 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +import {resolveBackend} from './backend-impl.js'; import {TrainingSessionHandler} from './backend.js'; import {InferenceSession as InferenceSession} from './inference-session.js'; import {TrainingSession as TrainingSessionInterface, TrainingSessionCreateOptions} from './training-session.js'; type SessionOptions = InferenceSession.SessionOptions; +const noBackendErrMsg: string = 'Training backend could not be resolved. ' + + 'Make sure you\'re using the correct configuration & WebAssembly files.'; export class TrainingSession implements TrainingSessionInterface { private constructor(handler: TrainingSessionHandler) { @@ -20,9 +23,23 @@ export class TrainingSession implements TrainingSessionInterface { return this.handler.outputNames; } - static async create(_trainingOptions: TrainingSessionCreateOptions, _sessionOptions?: SessionOptions): + static async create(trainingOptions: TrainingSessionCreateOptions, sessionOptions?: SessionOptions): Promise { - throw new Error('Method not implemented'); + const evalModel: string|Uint8Array = trainingOptions.evalModel || ''; + const optimizerModel: string|Uint8Array = trainingOptions.optimizerModel || ''; + const options: SessionOptions = sessionOptions || {}; + + // get backend hints + const eps = options.executionProviders || []; + const backendHints = eps.map(i => typeof i === 'string' ? i : i.name); + const backend = await resolveBackend(backendHints); + if (backend.createTrainingSessionHandler) { + const handler = await backend.createTrainingSessionHandler( + trainingOptions.checkpointState, trainingOptions.trainModel, evalModel, optimizerModel, options); + return new TrainingSession(handler); + } else { + throw new Error(noBackendErrMsg); + } } async loadParametersBuffer(_array: Uint8Array, _trainableOnly: boolean): Promise { diff --git a/web/lib/backend-wasm-training.ts b/web/lib/backend-wasm-training.ts index af5b575c87a7f..98e40807aa29c 100644 --- a/web/lib/backend-wasm-training.ts +++ b/web/lib/backend-wasm-training.ts @@ -4,13 +4,17 @@ import {InferenceSession, TrainingSessionHandler} from 'onnxruntime-common'; import {OnnxruntimeWebAssemblyBackend} from './backend-wasm'; +import {OnnxruntimeWebAssemblyTrainingSessionHandler} from './wasm/session-handler-for-training'; class OnnxruntimeTrainingWebAssemblyBackend extends OnnxruntimeWebAssemblyBackend { async createTrainingSessionHandler( - _checkpointStateUriOrBuffer: string|Uint8Array, _trainModelUriOrBuffer: string|Uint8Array, - _evalModelUriOrBuffer: string|Uint8Array, _optimizerModelUriOrBuffer: string|Uint8Array, - _options: InferenceSession.SessionOptions): Promise { - throw new Error('Method not implemented yet.'); + checkpointStateUriOrBuffer: string|Uint8Array, trainModelUriOrBuffer: string|Uint8Array, + evalModelUriOrBuffer: string|Uint8Array, optimizerModelUriOrBuffer: string|Uint8Array, + options: InferenceSession.SessionOptions): Promise { + const handler = new OnnxruntimeWebAssemblyTrainingSessionHandler(); + await handler.createTrainingSession( + checkpointStateUriOrBuffer, trainModelUriOrBuffer, evalModelUriOrBuffer, optimizerModelUriOrBuffer, options); + return Promise.resolve(handler); } } diff --git a/web/lib/wasm/binding/ort-wasm.d.ts b/web/lib/wasm/binding/ort-wasm.d.ts index b7b2ff4537095..00431a4e86d5b 100644 --- a/web/lib/wasm/binding/ort-wasm.d.ts +++ b/web/lib/wasm/binding/ort-wasm.d.ts @@ -102,6 +102,11 @@ export interface OrtWasmModule extends EmscriptenModule { _OrtTrainingCopyParametersFromBuffer? (trainingHandle: number, parametersBuffer: number, parameterCount: number, trainableOnly: boolean): number; + _OrtTrainingGetModelInputOutputCount? + (trainingHandle: number, inputCount: number, outputCount: number, isEvalModel: boolean): number; + _OrtTrainingGetModelInputOutputName? + (trainingHandle: number, index: number, isInput: boolean, isEvalModel: boolean): number; + _OrtTrainingReleaseSession?(trainingHandle: number): void; // #endregion diff --git a/web/lib/wasm/proxy-messages.ts b/web/lib/wasm/proxy-messages.ts index 7aa866773bcb1..efeb086256cf3 100644 --- a/web/lib/wasm/proxy-messages.ts +++ b/web/lib/wasm/proxy-messages.ts @@ -73,5 +73,10 @@ interface MesssageEndProfiling extends MessageError { in ?: number; } +interface MessageIsOrtEnvInitialized extends MessageError { + type: 'is-ort-env-initialized'; + out?: boolean; +} + export type OrtWasmMessage = MessageInitWasm|MessageInitOrt|MessageCreateSessionAllocate|MessageCreateSessionFinalize| - MessageCreateSession|MessageReleaseSession|MessageRun|MesssageEndProfiling; + MessageCreateSession|MessageReleaseSession|MessageRun|MesssageEndProfiling|MessageIsOrtEnvInitialized; diff --git a/web/lib/wasm/proxy-worker/main.ts b/web/lib/wasm/proxy-worker/main.ts index fe8bd9b11b191..1f4595818e5c0 100644 --- a/web/lib/wasm/proxy-worker/main.ts +++ b/web/lib/wasm/proxy-worker/main.ts @@ -4,7 +4,7 @@ /// import {OrtWasmMessage} from '../proxy-messages'; -import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, extractTransferableBuffers, initRuntime, releaseSession, run} from '../wasm-core-impl'; +import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, extractTransferableBuffers, initRuntime, isOrtEnvInitialized, releaseSession, run} from '../wasm-core-impl'; import {initializeWebAssembly} from '../wasm-factory'; self.onmessage = (ev: MessageEvent): void => { @@ -89,6 +89,14 @@ self.onmessage = (ev: MessageEvent): void => { postMessage({type: 'end-profiling', err} as OrtWasmMessage); } break; + case 'is-ort-env-initialized': + try { + const ortEnvInitialized = isOrtEnvInitialized(); + postMessage({type: 'is-ort-env-initialized', out: ortEnvInitialized} as OrtWasmMessage); + } catch (err) { + postMessage({type: 'is-ort-env-initialized', err} as OrtWasmMessage); + } + break; default: } }; diff --git a/web/lib/wasm/proxy-wrapper.ts b/web/lib/wasm/proxy-wrapper.ts index a3e4a1ef1fc75..069a1fa452dbc 100644 --- a/web/lib/wasm/proxy-wrapper.ts +++ b/web/lib/wasm/proxy-wrapper.ts @@ -24,6 +24,7 @@ const createSessionCallbacks: Array> = []; const runCallbacks: Array> = []; const endProfilingCallbacks: Array> = []; +const isOrtEnvInitializedCallbacks: Array> = []; const ensureWorker = (): void => { if (initializing || !initialized || aborted || !proxyWorker) { @@ -92,6 +93,13 @@ const onProxyWorkerMessage = (ev: MessageEvent): void => { endProfilingCallbacks.shift()![0](); } break; + case 'is-ort-env-initialized': + if (ev.data.err) { + isOrtEnvInitializedCallbacks.shift()![1](ev.data.err); + } else { + isOrtEnvInitializedCallbacks.shift()![0](ev.data.out!); + } + break; default: } }; @@ -251,3 +259,16 @@ export const endProfiling = async(sessionId: number): Promise => { core.endProfiling(sessionId); } }; + +export const isOrtEnvInitialized = async(): Promise => { + if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) { + ensureWorker(); + return new Promise((resolve, reject) => { + isOrtEnvInitializedCallbacks.push([resolve, reject]); + const message: OrtWasmMessage = {type: 'is-ort-env-initialized'}; + proxyWorker!.postMessage(message); + }); + } else { + return core.isOrtEnvInitialized(); + } +}; diff --git a/web/lib/wasm/session-handler-for-training.ts b/web/lib/wasm/session-handler-for-training.ts new file mode 100644 index 0000000000000..83d133b9a5157 --- /dev/null +++ b/web/lib/wasm/session-handler-for-training.ts @@ -0,0 +1,73 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {env, InferenceSession, SessionHandler, TrainingSessionHandler} from 'onnxruntime-common'; + +import {SerializableModeldata} from './proxy-messages'; +import {createSessionAllocate, initRuntime, isOrtEnvInitialized} from './wasm-core-impl'; +import {createCheckpointHandle, createTrainingSessionHandle, releaseTrainingSessionAndCheckpoint} from './wasm-training-core-impl'; + +export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSessionHandler { + async loadParametersBuffer(_array: Uint8Array, _trainableOnly: boolean): Promise { + throw new Error('Method not implemented.'); + } + async getContiguousParameters(_trainableOnly: boolean): Promise { + throw new Error('Method not implemented.'); + } + private sessionId: number; + private checkpointId: number; + + inputNames: string[]; + outputNames: string[]; + + inputEncodedNames: number[]; + outputEncodedNames: number[]; + + async uriOrBufferToHeap(uriOrBuffer: string|Uint8Array): Promise { + let buffer: Uint8Array; + if (typeof uriOrBuffer === 'string') { + const response = await fetch(uriOrBuffer); + const arrayBuffer = await response.arrayBuffer(); + buffer = new Uint8Array(arrayBuffer); + } else { + buffer = uriOrBuffer; + } + return createSessionAllocate(buffer); + } + + async createTrainingSession( + checkpointStateUriOrBuffer: string|Uint8Array, trainModelUriOrBuffer: string|Uint8Array, + evalModelUriOrBuffer: string|Uint8Array, optimizerModelUriOrBuffer: string|Uint8Array, + options: InferenceSession.SessionOptions) { + if (!isOrtEnvInitialized()) { + await initRuntime(env); + } + const checkpointData: SerializableModeldata = await this.uriOrBufferToHeap(checkpointStateUriOrBuffer); + const trainModelData: SerializableModeldata = await this.uriOrBufferToHeap(trainModelUriOrBuffer); + // 0 is supposed to be the nullptr + let evalModelData: SerializableModeldata = [0, 0]; + let optimizerModelData: SerializableModeldata = [0, 0]; + + if (evalModelUriOrBuffer !== '') { + evalModelData = await this.uriOrBufferToHeap(evalModelUriOrBuffer); + } + if (optimizerModelUriOrBuffer !== '') { + optimizerModelData = await this.uriOrBufferToHeap(optimizerModelUriOrBuffer); + } + + this.checkpointId = createCheckpointHandle(checkpointData); + [[this.sessionId, this.inputNames, this.outputNames], this.inputEncodedNames, this.outputEncodedNames] = + createTrainingSessionHandle(this.checkpointId, trainModelData, evalModelData, optimizerModelData, options); + } + + async dispose(): Promise { + return releaseTrainingSessionAndCheckpoint( + this.checkpointId, this.sessionId, this.inputEncodedNames, this.outputEncodedNames); + } + + async runTrainStep( + _feeds: SessionHandler.FeedsType, _fetches: SessionHandler.FetchesType, + _options: InferenceSession.RunOptions): Promise { + throw new Error('Method not implemented yet.'); + } +} diff --git a/web/lib/wasm/session-handler.ts b/web/lib/wasm/session-handler.ts index d1760e37c93f7..a5017a920f38b 100644 --- a/web/lib/wasm/session-handler.ts +++ b/web/lib/wasm/session-handler.ts @@ -5,10 +5,9 @@ import {readFile} from 'node:fs/promises'; import {env, InferenceSession, InferenceSessionHandler, SessionHandler, Tensor} from 'onnxruntime-common'; import {SerializableModeldata, TensorMetadata} from './proxy-messages'; -import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, initializeRuntime, releaseSession, run} from './proxy-wrapper'; +import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, initializeRuntime, isOrtEnvInitialized, releaseSession, run} from './proxy-wrapper'; import {isGpuBufferSupportedType} from './wasm-common'; -let runtimeInitialized: boolean; let runtimeInitializationPromise: Promise|undefined; const encodeTensorMetadata = (tensor: Tensor, getName: () => string): TensorMetadata => { @@ -57,13 +56,12 @@ export class OnnxruntimeWebAssemblySessionHandler implements InferenceSessionHan } async loadModel(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): Promise { - if (!runtimeInitialized) { + if (!(await isOrtEnvInitialized())) { if (!runtimeInitializationPromise) { runtimeInitializationPromise = initializeRuntime(env); } await runtimeInitializationPromise; runtimeInitializationPromise = undefined; - runtimeInitialized = true; } if (typeof pathOrBuffer === 'string') { diff --git a/web/lib/wasm/wasm-core-impl.ts b/web/lib/wasm/wasm-core-impl.ts index 5b49a1d4202e3..947242945c665 100644 --- a/web/lib/wasm/wasm-core-impl.ts +++ b/web/lib/wasm/wasm-core-impl.ts @@ -10,6 +10,8 @@ import {dataLocationStringToEnum, getTensorElementSize, isGpuBufferSupportedType import {getInstance} from './wasm-factory'; import {allocWasmString, checkLastError} from './wasm-utils'; +let ortEnvInitialized = false; + /** * get the input/output count of the session. * @param sessionHandle the handle representing the session. should be non-zero. @@ -57,6 +59,8 @@ export const initRuntime = async(env: Env): Promise => { const initJsep = require('./jsep/init').init; await initJsep(getInstance(), env); } + + ortEnvInitialized = true; }; /** @@ -93,6 +97,8 @@ type SessionMetadata = [ const activeSessions = new Map(); +export const isOrtEnvInitialized = (): boolean => ortEnvInitialized; + /** * allocate the memory and memcpy the model bytes, preparing for creating an instance of InferenceSession. * @returns a 2-elements tuple - the pointer and size of the allocated buffer diff --git a/web/lib/wasm/wasm-training-core-impl.ts b/web/lib/wasm/wasm-training-core-impl.ts new file mode 100644 index 0000000000000..4830b5d2b5e80 --- /dev/null +++ b/web/lib/wasm/wasm-training-core-impl.ts @@ -0,0 +1,162 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {InferenceSession} from 'onnxruntime-common'; + +import {SerializableModeldata, SerializableSessionMetadata} from './proxy-messages'; +import {setSessionOptions} from './session-options'; +import {getInstance} from './wasm-factory'; +import {checkLastError} from './wasm-utils'; + +const NO_TRAIN_FUNCS_MSG = + 'Built without training API\'s enabled. Use the onnxruntime-web/training import for training ' + + 'functionality, and make sure that all the correct artifacts are built & moved to the correct folder if ' + + 'using a custom build. Check https://onnxruntime.ai/docs/build/web.html for more information.'; + +export const createCheckpointHandle = (checkpointData: SerializableModeldata): number => { + const wasm = getInstance(); + + const [checkpointDataOffset, checkpointDataLength] = checkpointData; + let checkpointHandle = 0; + + try { + if (wasm._OrtTrainingLoadCheckpoint) { + checkpointHandle = wasm._OrtTrainingLoadCheckpoint(checkpointDataOffset, checkpointDataLength); + } else { + throw new Error(NO_TRAIN_FUNCS_MSG); + } + + if (checkpointHandle === 0) { + checkLastError('Error occurred when trying to create a CheckpointState.'); + } + return checkpointHandle; + } catch (e) { + if (wasm._OrtTrainingReleaseCheckpoint && checkpointHandle !== 0) { + wasm._OrtTrainingReleaseCheckpoint(checkpointHandle); + } + throw e; + } finally { + // free buffer from wasm heap + wasm._OrtFree(checkpointData[0]); + } +}; + +const getModelInputOutputCount = (trainingSessionId: number, isEvalModel: boolean): [number, number] => { + const wasm = getInstance(); + const stack = wasm.stackSave(); + try { + const dataOffset = wasm.stackAlloc(8); + if (wasm._OrtTrainingGetModelInputOutputCount) { + const errorCode = + wasm._OrtTrainingGetModelInputOutputCount(trainingSessionId, dataOffset, dataOffset + 4, isEvalModel); + if (errorCode !== 0) { + checkLastError('Can\'t get session input/output count.'); + } + return [wasm.HEAP32[dataOffset / 4], wasm.HEAP32[dataOffset / 4 + 1]]; + } else { + throw new Error(NO_TRAIN_FUNCS_MSG); + } + } finally { + wasm.stackRestore(stack); + } +}; + +const getModelInputOutputNamesLoop = + (trainingSessionId: number, count: number, isInput: boolean, isEvalModel: boolean): [string[], number[]] => { + const names = []; + const wasm = getInstance(); + + const namesUTF8Encoded = []; + + for (let i = 0; i < count; i++) { + if (wasm._OrtTrainingGetModelInputOutputName) { + const name = wasm._OrtTrainingGetModelInputOutputName(trainingSessionId, i, isInput, isEvalModel); + if (name === 0) { + checkLastError('Can\'t get input or output name'); + } + + namesUTF8Encoded.push(name); + names.push(wasm.UTF8ToString(name)); + } else { + throw new Error(NO_TRAIN_FUNCS_MSG); + } + } + return [names, namesUTF8Encoded]; + }; + +const getTrainingModelInputOutputNames = (trainingSessionId: number): [string[], number[], string[], number[]] => { + const [inputCount, outputCount] = getModelInputOutputCount(trainingSessionId, false); + + const [inputNames, inputNamesUTF8Encoded] = getModelInputOutputNamesLoop(trainingSessionId, inputCount, true, false); + const [outputNames, outputNamesUTF8Encoded] = + getModelInputOutputNamesLoop(trainingSessionId, outputCount, false, false); + + return [inputNames, inputNamesUTF8Encoded, outputNames, outputNamesUTF8Encoded]; +}; + +export const createTrainingSessionHandle = + (checkpointHandle: number, trainModelData: SerializableModeldata, evalModelData: SerializableModeldata, + optimizerModelData: SerializableModeldata, + options: InferenceSession.SessionOptions): [SerializableSessionMetadata, number[], number[]] => { + const wasm = getInstance(); + + let trainingSessionHandle = 0; + let sessionOptionsHandle = 0; + let allocs: number[] = []; + let inputNamesUTF8Encoded: number[] = []; + let outputNamesUTF8Encoded: number[] = []; + + let inputNames: string[] = []; + let outputNames: string[] = []; + + try { + [sessionOptionsHandle, allocs] = setSessionOptions(options); + if (wasm._OrtTrainingCreateSession) { + trainingSessionHandle = wasm._OrtTrainingCreateSession( + sessionOptionsHandle, checkpointHandle, trainModelData[0], trainModelData[1], evalModelData[0], + evalModelData[1], optimizerModelData[0], optimizerModelData[1]); + } else { + throw new Error(NO_TRAIN_FUNCS_MSG); + } + + if (trainingSessionHandle === 0) { + checkLastError('Error occurred when trying to create a TrainingSession.'); + } + + [inputNames, inputNamesUTF8Encoded, outputNames, outputNamesUTF8Encoded] = + getTrainingModelInputOutputNames(trainingSessionHandle); + return [[trainingSessionHandle, inputNames, outputNames], inputNamesUTF8Encoded, outputNamesUTF8Encoded]; + + } catch (e) { + if (wasm._OrtTrainingReleaseSession && trainingSessionHandle !== 0) { + wasm._OrtTrainingReleaseSession(trainingSessionHandle); + } + throw e; + } finally { + wasm._free(trainModelData[0]); + wasm._free(evalModelData[0]); + wasm._free(optimizerModelData[0]); + + if (sessionOptionsHandle !== 0) { + wasm._OrtReleaseSessionOptions(sessionOptionsHandle); + } + allocs.forEach(alloc => wasm._free(alloc)); + inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); + outputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); + } + }; + +export const releaseTrainingSessionAndCheckpoint = + (checkpointId: number, sessionId: number, inputNamesUTF8Encoded: number[], outputNamesUTF8Encoded: number[]): + void => { + const wasm = getInstance(); + inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); + outputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); + + if (wasm._OrtTrainingReleaseSession) { + wasm._OrtTrainingReleaseSession(sessionId); + } + if (wasm._OrtTrainingReleaseCheckpoint) { + wasm._OrtTrainingReleaseCheckpoint(checkpointId); + } + };