From 8361b7eae9ff9dcf8ef9d54965ea954be6326d0d Mon Sep 17 00:00:00 2001 From: carzh Date: Thu, 2 Nov 2023 10:07:07 -0700 Subject: [PATCH 01/12] implemented parameters methods + cleaned up error handling --- js/common/lib/backend.ts | 5 +- js/common/lib/training-session-impl.ts | 18 ++- js/common/lib/training-session.ts | 13 +- js/web/lib/wasm/session-handler-training.ts | 20 ++- js/web/lib/wasm/wasm-training-core-impl.ts | 169 +++++++++++++++++--- 5 files changed, 190 insertions(+), 35 deletions(-) diff --git a/js/common/lib/backend.ts b/js/common/lib/backend.ts index dd04ef3f15997..fd2e8bb74bbf5 100644 --- a/js/common/lib/backend.ts +++ b/js/common/lib/backend.ts @@ -49,8 +49,9 @@ export interface TrainingSessionHandler extends SessionHandler { feeds: SessionHandler.FeedsType, fetches: SessionHandler.FetchesType, options: InferenceSession.RunOptions): Promise; - loadParametersBuffer(array: Uint8Array, trainableOnly: boolean): Promise; - getContiguousParameters(trainableOnly: boolean): Promise; + getParametersSize(trainableOnly: boolean): Promise; + loadParametersBuffer(array: Float32Array, trainableOnly: boolean): Promise; + getContiguousParameters(trainableOnly: boolean): Promise; } /** diff --git a/js/common/lib/training-session-impl.ts b/js/common/lib/training-session-impl.ts index ee6d26b22b1f6..a9a9d42e2a594 100644 --- a/js/common/lib/training-session-impl.ts +++ b/js/common/lib/training-session-impl.ts @@ -176,12 +176,22 @@ export class TrainingSession implements TrainingSessionInterface { return this.convertHandlerReturnTypeToMapOfTensors(results); } - async loadParametersBuffer(_array: Uint8Array, _trainableOnly: boolean): Promise { - throw new Error('Method not implemented.'); + async getParametersSize(trainableOnly: boolean): Promise { + return this.handler.getParametersSize(trainableOnly); } - async getContiguousParameters(_trainableOnly: boolean): Promise { - throw new Error('Method not implemented.'); + async loadParametersBuffer(array: Float32Array, trainableOnly: boolean): Promise { + const paramsSize = await this.getParametersSize(trainableOnly); + if (array.length !== paramsSize) { + throw new Error( + 'Size of the buffer passed into loadParametersBuffer must match the number of parameters in ' + + 'the model. Please use getParametersSize method to check.'); + } + return this.handler.loadParametersBuffer(array, trainableOnly); + } + + async getContiguousParameters(trainableOnly: boolean): Promise { + return this.handler.getContiguousParameters(trainableOnly); } async release(): Promise { diff --git a/js/common/lib/training-session.ts b/js/common/lib/training-session.ts index 0967d79b33434..40ea16cf05ce4 100644 --- a/js/common/lib/training-session.ts +++ b/js/common/lib/training-session.ts @@ -2,6 +2,7 @@ // Licensed under the MIT License. import {InferenceSession} from './inference-session.js'; +import {OnnxValue} from './onnx-value.js'; import {TrainingSession as TrainingSessionImpl} from './training-session-impl.js'; /* eslint-disable @typescript-eslint/no-redeclare */ @@ -49,13 +50,21 @@ export interface TrainingSession { // #endregion // #region copy parameters + + /** + * Retrieves the size of all parameters for the training state. + * + * @param trainableOnly skips non-trainable parameters when true. + */ + getParametersSize(trainableOnly: boolean): Promise; + /** * Copies from a buffer containing parameters to the TrainingSession parameters. * * @param buffer - buffer containing parameters * @param trainableOnly - True if trainable parameters only to be modified, false otherwise. */ - loadParametersBuffer(array: Uint8Array, trainableOnly: boolean): Promise; + loadParametersBuffer(array: Float32Array, trainableOnly: boolean): Promise; /** * Copies from the TrainingSession parameters to a buffer. @@ -63,7 +72,7 @@ export interface TrainingSession { * @param trainableOnly - True if trainable parameters only to be copied, false othrwise. * @returns A promise that resolves to a buffer of the requested parameters. */ - getContiguousParameters(trainableOnly: boolean): Promise; + getContiguousParameters(trainableOnly: boolean): Promise; // #endregion // #region release() diff --git a/js/web/lib/wasm/session-handler-training.ts b/js/web/lib/wasm/session-handler-training.ts index 09d91591128d1..3fe0eefb830af 100644 --- a/js/web/lib/wasm/session-handler-training.ts +++ b/js/web/lib/wasm/session-handler-training.ts @@ -6,15 +6,9 @@ import {env, InferenceSession, SessionHandler, Tensor, TrainingSessionHandler} f import {SerializableModeldata, TensorMetadata} from './proxy-messages'; import {decodeTensorMetadata, encodeTensorMetadata} from './session-handler-inference'; import {createSessionAllocate, initRuntime, isOrtEnvInitialized} from './wasm-core-impl'; -import {createCheckpointHandle, createTrainingSessionHandle, releaseTrainingSessionAndCheckpoint, runTrainStep} from './wasm-training-core-impl'; +import {createCheckpointHandle, createTrainingSessionHandle, getContiguousParameters, getParametersSize, loadParametersBuffer, releaseTrainingSessionAndCheckpoint, runTrainStep} from './wasm-training-core-impl'; export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSessionHandler { - async loadParametersBuffer(_array: Uint8Array, _trainableOnly: boolean): Promise { - throw new Error('Method not implemented.'); - } - async getContiguousParameters(_trainableOnly: boolean): Promise { - throw new Error('Method not implemented.'); - } private sessionId: number; private checkpointId: number; @@ -124,6 +118,18 @@ export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSes return this.convertTensorMetadataToReturnType(results, outputArray, outputIndices); } + async getParametersSize(trainableOnly: boolean): Promise { + return getParametersSize(this.sessionId, trainableOnly); + } + + async loadParametersBuffer(array: Float32Array, trainableOnly: boolean): Promise { + await loadParametersBuffer(this.sessionId, array, trainableOnly); + } + async getContiguousParameters(trainableOnly: boolean): Promise { + const tensorResult = await getContiguousParameters(this.sessionId, trainableOnly); + return decodeTensorMetadata(tensorResult); + } + async dispose(): Promise { return releaseTrainingSessionAndCheckpoint( this.checkpointId, this.sessionId, this.inputEncodedNames, this.outputEncodedNames); diff --git a/js/web/lib/wasm/wasm-training-core-impl.ts b/js/web/lib/wasm/wasm-training-core-impl.ts index a35d285346db4..dfad93099ea0d 100644 --- a/js/web/lib/wasm/wasm-training-core-impl.ts +++ b/js/web/lib/wasm/wasm-training-core-impl.ts @@ -6,7 +6,7 @@ import {InferenceSession, Tensor} from 'onnxruntime-common'; import {SerializableModeldata, SerializableSessionMetadata, TensorMetadata} from './proxy-messages'; import {setRunOptions} from './run-options'; import {setSessionOptions} from './session-options'; -import {tensorDataTypeEnumToString, tensorTypeToTypedArrayConstructor} from './wasm-common'; +import {dataLocationStringToEnum, tensorDataTypeEnumToString, tensorDataTypeStringToEnum, tensorTypeToTypedArrayConstructor} from './wasm-common'; import {prepareInputOutputTensor} from './wasm-core-impl'; import {getInstance} from './wasm-factory'; import {checkLastError} from './wasm-utils'; @@ -16,6 +16,22 @@ const NO_TRAIN_FUNCS_MSG = 'functionality, and make sure that all the correct artifacts are built & moved to the correct folder if ' + 'using a custom build. Check https://onnxruntime.ai/docs/build/web.html for more information.'; +/** + * Runs the checkLastError function which will throw an error, if the provided error code matches the specified + * pattern for an error code. + * @param errCode number to evaluated for if it's an erro + * @param message message to pass into checkLastError + * @param checkNeqZero when true, treats not equal to zero as an error. + * When false, treats equal to zero as an error. + */ +const ifErrCodeCheckLastError = (errCode: number, message: string, checkNeqZero = true) => { + if (checkNeqZero && errCode !== 0) { + checkLastError(message); + } else if (!checkNeqZero && errCode === 0) { + checkLastError(message); + } +}; + export const createCheckpointHandle = (checkpointData: SerializableModeldata): number => { const wasm = getInstance(); @@ -29,9 +45,7 @@ export const createCheckpointHandle = (checkpointData: SerializableModeldata): n throw new Error(NO_TRAIN_FUNCS_MSG); } - if (checkpointHandle === 0) { - checkLastError('Error occurred when trying to create a CheckpointState.'); - } + ifErrCodeCheckLastError(checkpointHandle, 'Error occurred when trying to create a CheckpointState', false); return checkpointHandle; } catch (e) { if (wasm._OrtTrainingReleaseCheckpoint && checkpointHandle !== 0) { @@ -52,9 +66,7 @@ const getModelInputOutputCount = (trainingSessionId: number, isEvalModel: boolea if (wasm._OrtTrainingGetModelInputOutputCount) { const errorCode = wasm._OrtTrainingGetModelInputOutputCount(trainingSessionId, dataOffset, dataOffset + 4, isEvalModel); - if (errorCode !== 0) { - checkLastError('Can\'t get session input/output count.'); - } + ifErrCodeCheckLastError(errorCode, 'Can\'t get session input/output count.'); return [wasm.HEAP32[dataOffset / 4], wasm.HEAP32[dataOffset / 4 + 1]]; } else { throw new Error(NO_TRAIN_FUNCS_MSG); @@ -74,9 +86,7 @@ const getModelInputOutputNamesLoop = for (let i = 0; i < count; i++) { if (wasm._OrtTrainingGetModelInputOutputName) { const name = wasm._OrtTrainingGetModelInputOutputName(trainingSessionId, i, isInput, isEvalModel); - if (name === 0) { - checkLastError('Can\'t get input or output name'); - } + ifErrCodeCheckLastError(name, `Can't get input or output name -- is input: ${isInput}, index ${i}`, false); namesUTF8Encoded.push(name); names.push(wasm.UTF8ToString(name)); @@ -122,9 +132,7 @@ export const createTrainingSessionHandle = throw new Error(NO_TRAIN_FUNCS_MSG); } - if (trainingSessionHandle === 0) { - checkLastError('Error occurred when trying to create a TrainingSession.'); - } + ifErrCodeCheckLastError(trainingSessionHandle, 'Error occurred when trying to create a TrainingSession', false); [inputNames, inputNamesUTF8Encoded, outputNames, outputNamesUTF8Encoded] = getTrainingModelInputOutputNames(trainingSessionHandle); @@ -213,9 +221,8 @@ const moveOutputToTensorMetadataArr = try { const errorCode = wasm._OrtGetTensorData( tensor, tensorDataOffset, tensorDataOffset + 4, tensorDataOffset + 8, tensorDataOffset + 12); - if (errorCode !== 0) { - checkLastError(`Can't access output tensor data on index ${i}.`); - } + ifErrCodeCheckLastError(errorCode, `Can't access output tensor data on index ${i}.`); + let tensorDataIndex = tensorDataOffset / 4; const dataType = wasm.HEAPU32[tensorDataIndex++]; dataOffset = wasm.HEAPU32[tensorDataIndex++]; @@ -290,10 +297,7 @@ export const runTrainStep = async( if (wasm._OrtTrainingRunTrainStep) { const errorCode = wasm._OrtTrainingRunTrainStep( trainingSessionId, inputValuesOffset, inputCount, outputValuesOffset, outputCount, runOptionsHandle); - - if (errorCode !== 0) { - checkLastError('failed to call OrtTrainingRunTrainStep in the WebAssembly layer'); - } + ifErrCodeCheckLastError(errorCode, 'failed to call OrtTrainingRunTrainStep in the WebAssembly layer'); } else { throw new Error(NO_TRAIN_FUNCS_MSG); } @@ -313,6 +317,131 @@ export const runTrainStep = async( } }; +export const getParametersSize = (trainingSessionId: number, trainableOnly: boolean): number => { + const wasm = getInstance(); + const stack = wasm.stackSave(); + + try { + const sizeOffset = wasm.stackAlloc(4); + if (wasm._OrtTrainingGetParametersSize) { + const errorCode = wasm._OrtTrainingGetParametersSize(trainingSessionId, sizeOffset, trainableOnly); + ifErrCodeCheckLastError(errorCode, 'Can\'t get parameters size'); + + return wasm.HEAP32[sizeOffset / 4]; + } else { + throw new Error(NO_TRAIN_FUNCS_MSG); + } + } finally { + wasm.stackRestore(stack); + } +}; + +export const getContiguousParameters = + async(trainingSessionId: number, trainableOnly: boolean): Promise => { + const wasm = getInstance(); + const stack = wasm.stackSave(); + + const tensorTypeAsString = 'float32'; + const locationAsString = 'cpu'; + + const parametersSize = getParametersSize(trainingSessionId, trainableOnly); + let tensor = 0; + + const paramsByteLength = 4 * parametersSize; + const paramsOffset = wasm.stackAlloc(paramsByteLength); + wasm.HEAPU8.set(new Float32Array(parametersSize), paramsOffset); + + const tensorOffset = wasm.stackAlloc(paramsOffset / 4); + + // handles the dimensions-related createTensor parameters + const dims = [parametersSize]; + + const dimsOffset = wasm.stackAlloc(4); + const dimsIndex = dimsOffset / 4; + wasm.HEAP32[dimsIndex] = parametersSize; + + try { + tensor = wasm._OrtCreateTensor( + tensorDataTypeStringToEnum(tensorTypeAsString), paramsOffset, paramsByteLength, dimsOffset, dims.length, + dataLocationStringToEnum(locationAsString)); + ifErrCodeCheckLastError( + tensor, `Can't create tensor for getContiguousParameters. session=${trainingSessionId}.`, false); + + wasm.HEAPU32[tensorOffset] = tensor; + if (wasm._OrtTrainingCopyParametersToBuffer) { + const errCode = wasm._OrtTrainingCopyParametersToBuffer(trainingSessionId, tensor, parametersSize, trainableOnly); + ifErrCodeCheckLastError(errCode, 'Can\'t get contiguous parameters.'); + + } else { + throw new Error(NO_TRAIN_FUNCS_MSG); + } + + const typedArrayConstructor = tensorTypeToTypedArrayConstructor(tensorTypeAsString); + const data = new typedArrayConstructor(parametersSize); + const output: TensorMetadata[] = []; + new Uint8Array(data.buffer, data.byteOffset, data.byteLength) + .set(wasm.HEAPU8.subarray(paramsOffset, paramsOffset + paramsByteLength)); + output.push([tensorTypeAsString, dims, data, locationAsString]); + if (output.length > 1 || output.length < 1) { + throw new Error(`something unexpected happened in the getContiguousParameters function. Expected output length of + one, got ${output.length}`); + } else { + return output[0]; + } + } finally { + if (tensor !== 0) { + wasm._OrtReleaseTensor(tensor); + } + wasm._free(paramsOffset); + wasm._free(dimsOffset); + wasm._free(tensorOffset); + wasm.stackRestore(stack); + } +}; + +export const loadParametersBuffer = + async(trainingSessionId: number, buffer: Float32Array, trainableOnly: boolean): Promise => { + const wasm = getInstance(); + const stack = wasm.stackSave(); + + const tensorTypeAsString = 'float32'; + const locationAsString = 'cpu'; + + const bufferCount = buffer.length; + const bufferByteLength = bufferCount * 4; + const bufferOffset = wasm.stackAlloc(bufferByteLength); + wasm.HEAPU8.set(new Uint8Array(buffer.buffer, buffer.byteOffset, buffer.byteLength), bufferOffset); + const dimsOffset = wasm.stackAlloc(4); + wasm.HEAP32[dimsOffset / 4] = bufferCount; + const dimsLength = 1; + let tensor = 0; + const bufferAlloc = wasm.stackAlloc(bufferOffset / 4); + + try { + tensor = wasm._OrtCreateTensor( + tensorDataTypeStringToEnum(tensorTypeAsString), bufferOffset, bufferByteLength, dimsOffset, dimsLength, + dataLocationStringToEnum(locationAsString)); + ifErrCodeCheckLastError(tensor, `Can't create tensor for input/output. session=${trainingSessionId}`, false); + + wasm.HEAPU32[bufferAlloc] = tensor; + + if (wasm._OrtTrainingCopyParametersFromBuffer) { + const errCode = wasm._OrtTrainingCopyParametersFromBuffer(trainingSessionId, tensor, bufferCount, trainableOnly); + ifErrCodeCheckLastError(errCode, 'Can\'t copy buffer to parameters.'); + } else { + throw new Error(NO_TRAIN_FUNCS_MSG); + } + } finally { + if (tensor !== 0) { + wasm._OrtReleaseTensor(tensor); + } + wasm.stackRestore(stack); + wasm._free(bufferAlloc); + wasm._free(bufferOffset); + wasm._free(dimsOffset); + } +}; + export const releaseTrainingSessionAndCheckpoint = (checkpointId: number, sessionId: number, inputNamesUTF8Encoded: number[], outputNamesUTF8Encoded: number[]): void => { From 354f7adfd7b613da5470ccdabdad143c796e8f77 Mon Sep 17 00:00:00 2001 From: carzh Date: Thu, 2 Nov 2023 13:36:53 -0700 Subject: [PATCH 02/12] applied suggested doc fixes --- js/common/lib/training-session.ts | 12 +++++++----- js/web/lib/wasm/wasm-training-core-impl.ts | 2 +- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/js/common/lib/training-session.ts b/js/common/lib/training-session.ts index 40ea16cf05ce4..c71bc12a7d53f 100644 --- a/js/common/lib/training-session.ts +++ b/js/common/lib/training-session.ts @@ -52,14 +52,15 @@ export interface TrainingSession { // #region copy parameters /** - * Retrieves the size of all parameters for the training state. + * Retrieves the size of all parameters for the training state. Calculates the total number of primitive (datatype of + * the parameters) elements of all the parameters in the training state. * - * @param trainableOnly skips non-trainable parameters when true. + * @param trainableOnly - When set to true, the size is calculated for trainable params only. */ getParametersSize(trainableOnly: boolean): Promise; /** - * Copies from a buffer containing parameters to the TrainingSession parameters. + * Copies parameter values from the given array to the training state. * * @param buffer - buffer containing parameters * @param trainableOnly - True if trainable parameters only to be modified, false otherwise. @@ -67,9 +68,10 @@ export interface TrainingSession { loadParametersBuffer(array: Float32Array, trainableOnly: boolean): Promise; /** - * Copies from the TrainingSession parameters to a buffer. + * Copies from the TrainingSession parameters to a contiguous buffer. * - * @param trainableOnly - True if trainable parameters only to be copied, false othrwise. + * @param trainableOnly - When set to true, only trainable parameters are copied. Trainable parameters are parameters + * for which requires_grad is set to true. * @returns A promise that resolves to a buffer of the requested parameters. */ getContiguousParameters(trainableOnly: boolean): Promise; diff --git a/js/web/lib/wasm/wasm-training-core-impl.ts b/js/web/lib/wasm/wasm-training-core-impl.ts index dfad93099ea0d..80e6aac40feca 100644 --- a/js/web/lib/wasm/wasm-training-core-impl.ts +++ b/js/web/lib/wasm/wasm-training-core-impl.ts @@ -19,7 +19,7 @@ const NO_TRAIN_FUNCS_MSG = /** * Runs the checkLastError function which will throw an error, if the provided error code matches the specified * pattern for an error code. - * @param errCode number to evaluated for if it's an erro + * @param errCode number to evaluated for if it's an error * @param message message to pass into checkLastError * @param checkNeqZero when true, treats not equal to zero as an error. * When false, treats equal to zero as an error. From 3cbb01a652deb03c292637cb9af4bb28c76f5bf4 Mon Sep 17 00:00:00 2001 From: carzh Date: Thu, 2 Nov 2023 15:42:52 -0700 Subject: [PATCH 03/12] added runEvalStep and runOptimizerStep --- js/common/lib/backend.ts | 7 + js/common/lib/training-session-impl.ts | 68 ++++++++-- js/common/lib/training-session.ts | 53 +++++++- js/web/lib/wasm/session-handler-training.ts | 38 +++++- js/web/lib/wasm/wasm-training-core-impl.ts | 139 ++++++++++++++------ 5 files changed, 243 insertions(+), 62 deletions(-) diff --git a/js/common/lib/backend.ts b/js/common/lib/backend.ts index fd2e8bb74bbf5..3f3efdaa140eb 100644 --- a/js/common/lib/backend.ts +++ b/js/common/lib/backend.ts @@ -45,9 +45,16 @@ export interface InferenceSessionHandler extends SessionHandler { * @ignore */ export interface TrainingSessionHandler extends SessionHandler { + readonly evalInputNames: readonly string[]; + readonly evalOutputNames: readonly string[]; + runTrainStep( feeds: SessionHandler.FeedsType, fetches: SessionHandler.FetchesType, options: InferenceSession.RunOptions): Promise; + runOptimizerStep(options: InferenceSession.RunOptions): Promise; + runEvalStep( + feeds: SessionHandler.FeedsType, fetches: SessionHandler.FetchesType, + options: InferenceSession.RunOptions): Promise; getParametersSize(trainableOnly: boolean): Promise; loadParametersBuffer(array: Float32Array, trainableOnly: boolean): Promise; diff --git a/js/common/lib/training-session-impl.ts b/js/common/lib/training-session-impl.ts index a9a9d42e2a594..77abc204b1997 100644 --- a/js/common/lib/training-session-impl.ts +++ b/js/common/lib/training-session-impl.ts @@ -18,18 +18,37 @@ const noBackendErrMsg: string = 'Training backend could not be resolved. ' + 'Make sure you\'re using the correct configuration & WebAssembly files.'; export class TrainingSession implements TrainingSessionInterface { - private constructor(handler: TrainingSessionHandler) { + private constructor(handler: TrainingSessionHandler, hasOptimizerModel: boolean, hasEvalModel: boolean) { this.handler = handler; + this.hasOptimizerModel = hasOptimizerModel; + this.hasEvalModel = hasEvalModel; } private handler: TrainingSessionHandler; + private hasOptimizerModel: boolean; + private hasEvalModel: boolean; - get inputNames(): readonly string[] { + get trainingInputNames(): readonly string[] { return this.handler.inputNames; } - get outputNames(): readonly string[] { + get trainingOutputNames(): readonly string[] { return this.handler.outputNames; } + get evalInputNames(): readonly string[] { + if (this.hasEvalModel) { + return this.handler.evalInputNames; + } else { + throw new Error('This training session has no evalModel loaded.'); + } + } + get evalOutputNames(): readonly string[] { + if (this.hasEvalModel) { + return this.handler.evalOutputNames; + } else { + throw new Error('This training session has no evalModel loaded.'); + } + } + static async create(trainingOptions: TrainingSessionCreateOptions, sessionOptions?: SessionOptions): Promise { const evalModel: string|Uint8Array = trainingOptions.evalModel || ''; @@ -43,7 +62,7 @@ export class TrainingSession implements TrainingSessionInterface { if (backend.createTrainingSessionHandler) { const handler = await backend.createTrainingSessionHandler( trainingOptions.checkpointState, trainingOptions.trainModel, evalModel, optimizerModel, options); - return new TrainingSession(handler); + return new TrainingSession(handler, !!trainingOptions.optimizerModel, !!trainingOptions.evalModel); } else { throw new Error(noBackendErrMsg); } @@ -53,13 +72,18 @@ export class TrainingSession implements TrainingSessionInterface { * Helper function for runTrainStep and future runStep methods that handles the type-narrowing conversion from * the given parameters to SessionHandler.FetchesType and RunOptions. * + * @param inputNames the feeds object is checked that they contain all input names in the provided list of input + * names. + * @param outputNames the fetches object is checked that their keys match up with valid names in the list of output + * names. * @param feeds the required input * @param arg1 narrowed & converted into the SessionHandler.FetchesType or RunOptions object * @param arg2 optional RunOptions object. * @returns */ - typeNarrowingForRunStep(feeds: FeedsType, arg1?: FetchesType|RunOptions, arg2?: RunOptions): - [SessionHandler.FetchesType, RunOptions] { + typeNarrowingForRunStep( + inputNames: readonly string[], outputNames: readonly string[], feeds: FeedsType, arg1?: FetchesType|RunOptions, + arg2?: RunOptions): [SessionHandler.FetchesType, RunOptions] { const fetches: {[name: string]: OnnxValue|null} = {}; let options: RunOptions = {}; // check inputs @@ -88,7 +112,7 @@ export class TrainingSession implements TrainingSessionInterface { if (typeof name !== 'string') { throw new TypeError('\'fetches\' must be a string array or an object.'); } - if (this.outputNames.indexOf(name) === -1) { + if (outputNames.indexOf(name) === -1) { throw new RangeError(`'fetches' contains invalid output name: ${name}.`); } fetches[name] = null; @@ -104,7 +128,7 @@ export class TrainingSession implements TrainingSessionInterface { // if any output name is present and its value is valid OnnxValue, we consider it fetches let isFetches = false; const arg1Keys = Object.getOwnPropertyNames(arg1); - for (const name of this.outputNames) { + for (const name of outputNames) { if (arg1Keys.indexOf(name) !== -1) { const v = (arg1 as InferenceSession.NullableOnnxValueMapType)[name]; if (v === null || v instanceof Tensor) { @@ -130,7 +154,7 @@ export class TrainingSession implements TrainingSessionInterface { } // check if all inputs are in feed - for (const name of this.inputNames) { + for (const name of inputNames) { if (typeof feeds[name] === 'undefined') { throw new Error(`input '${name}' is missing in 'feeds'.`); } @@ -138,7 +162,7 @@ export class TrainingSession implements TrainingSessionInterface { // if no fetches is specified, we use the full output names list if (isFetchesEmpty) { - for (const name of this.outputNames) { + for (const name of outputNames) { fetches[name] = null; } } @@ -171,11 +195,33 @@ export class TrainingSession implements TrainingSessionInterface { runTrainStep(feeds: FeedsType, options?: RunOptions): Promise; runTrainStep(feeds: FeedsType, fetches: FetchesType, options?: RunOptions): Promise; async runTrainStep(feeds: FeedsType, arg1?: FetchesType|RunOptions, arg2?: RunOptions): Promise { - const [fetches, options] = this.typeNarrowingForRunStep(feeds, arg1, arg2); + const [fetches, options] = + this.typeNarrowingForRunStep(this.trainingInputNames, this.trainingOutputNames, feeds, arg1, arg2); const results = await this.handler.runTrainStep(feeds, fetches, options); return this.convertHandlerReturnTypeToMapOfTensors(results); } + async runOptimizerStep(options?: InferenceSession.RunOptions|undefined): Promise { + if (this.hasOptimizerModel) { + await this.handler.runOptimizerStep(options || {}); + } else { + throw new Error('This TrainingSession has no OptimizerModel loaded.'); + } + } + + runEvalStep(feeds: FeedsType, options?: RunOptions|undefined): Promise; + runEvalStep(feeds: FeedsType, fetches: FetchesType, options?: RunOptions|undefined): Promise; + async runEvalStep(feeds: FeedsType, arg1?: FetchesType|RunOptions, arg2?: RunOptions): Promise { + if (this.hasEvalModel) { + const [fetches, options] = + this.typeNarrowingForRunStep(this.evalInputNames, this.evalOutputNames, feeds, arg1, arg2); + const results = await this.handler.runEvalStep(feeds, fetches, options); + return this.convertHandlerReturnTypeToMapOfTensors(results); + } else { + throw new Error('This TrainingSession has no EvalModel loaded.'); + } + } + async getParametersSize(trainableOnly: boolean): Promise { return this.handler.getParametersSize(trainableOnly); } diff --git a/js/common/lib/training-session.ts b/js/common/lib/training-session.ts index c71bc12a7d53f..d1387e44184b2 100644 --- a/js/common/lib/training-session.ts +++ b/js/common/lib/training-session.ts @@ -39,7 +39,7 @@ export interface TrainingSession { * @param feeds - Representation of the model input. * @param fetches - Representation of the model output. * detail. - * @param options - Optional. A set of options that controls the behavior of model inference. + * @param options - Optional. A set of options that controls the behavior of model training. * @returns A promise that resolves to a map, which uses output names as keys and OnnxValue as corresponding values. */ @@ -47,6 +47,38 @@ export interface TrainingSession { feeds: InferenceSession.FeedsType, fetches: InferenceSession.FetchesType, options?: InferenceSession.RunOptions): Promise; + /** + * Runs a single optimizer step, which performs weight updates for the trainable parameters using the optimizer model. + * + * @param options - Optional. A set of options that controls the behavior of model optimizing. + */ + runOptimizerStep(options?: InferenceSession.RunOptions): Promise; + + /** + * Run a single eval step with the given inputs and options using the eval model. + * + * @param feeds - Representation of the model input. + * @param options - Optional. A set of options that controls the behavior of model eval step. + * @returns A promise that resolves to a map, which uses output names as keys and OnnxValue as corresponding + values. + */ + runEvalStep(feeds: InferenceSession.FeedsType, options?: InferenceSession.RunOptions): + Promise; + + /** + * Run a single eval step with the given inputs and options using the eval model. + * + * @param feeds - Representation of the model input. + * @param fetches - Representation of the model output. + * detail. + * @param options - Optional. A set of options that controls the behavior of model eval step. + * @returns A promise that resolves to a map, which uses output names as keys and OnnxValue as corresponding + values. + */ + runEvalStep( + feeds: InferenceSession.FeedsType, fetches: InferenceSession.FetchesType, + options?: InferenceSession.RunOptions): Promise; + // #endregion // #region copy parameters @@ -88,14 +120,25 @@ export interface TrainingSession { // #region metadata /** - * Get input names of the loaded model. + * Get input names of the loaded training model. */ - readonly inputNames: readonly string[]; + readonly trainingInputNames: readonly string[]; /** - * Get output names of the loaded model. + * Get output names of the loaded training model. */ - readonly outputNames: readonly string[]; + readonly trainingOutputNames: readonly string[]; + + /** + * Get input names of the loaded eval model. Is an empty array if no eval model is loaded. + */ + readonly evalInputNames: readonly string[]; + + /** + * Get output names of the loaded eval model. Is an empty array if no eval model is loaded. + */ + readonly evalOutputNames: readonly string[]; + // #endregion } diff --git a/js/web/lib/wasm/session-handler-training.ts b/js/web/lib/wasm/session-handler-training.ts index 3fe0eefb830af..84b999675ea2c 100644 --- a/js/web/lib/wasm/session-handler-training.ts +++ b/js/web/lib/wasm/session-handler-training.ts @@ -1,12 +1,12 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {env, InferenceSession, SessionHandler, Tensor, TrainingSessionHandler} from 'onnxruntime-common'; +import {env, InferenceSession, OnnxValue, SessionHandler, Tensor, TrainingSessionHandler} from 'onnxruntime-common'; import {SerializableModeldata, TensorMetadata} from './proxy-messages'; import {decodeTensorMetadata, encodeTensorMetadata} from './session-handler-inference'; import {createSessionAllocate, initRuntime, isOrtEnvInitialized} from './wasm-core-impl'; -import {createCheckpointHandle, createTrainingSessionHandle, getContiguousParameters, getParametersSize, loadParametersBuffer, releaseTrainingSessionAndCheckpoint, runTrainStep} from './wasm-training-core-impl'; +import {createCheckpointHandle, createTrainingSessionHandle, getContiguousParameters, getModelInputOutputNames, getParametersSize, loadParametersBuffer, releaseTrainingSessionAndCheckpoint, runEvalStep, runOptimizerStep, runTrainStep} from './wasm-training-core-impl'; export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSessionHandler { private sessionId: number; @@ -15,8 +15,8 @@ export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSes inputNames: string[]; outputNames: string[]; - inputEncodedNames: number[]; - outputEncodedNames: number[]; + evalInputNames: string[] = []; + evalOutputNames: string[] = []; async uriOrBufferToHeap(uriOrBuffer: string|Uint8Array): Promise { let buffer: Uint8Array; @@ -51,8 +51,12 @@ export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSes } this.checkpointId = createCheckpointHandle(checkpointData); - [[this.sessionId, this.inputNames, this.outputNames], this.inputEncodedNames, this.outputEncodedNames] = + this.sessionId = createTrainingSessionHandle(this.checkpointId, trainModelData, evalModelData, optimizerModelData, options); + [this.inputNames, this.outputNames] = getModelInputOutputNames(this.sessionId, false); + if (evalModelUriOrBuffer !== '') { + [this.evalInputNames, this.evalOutputNames] = getModelInputOutputNames(this.sessionId, true); + } } /** @@ -118,6 +122,27 @@ export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSes return this.convertTensorMetadataToReturnType(results, outputArray, outputIndices); } + async runOptimizerStep(options: InferenceSession.RunOptions): Promise { + await runOptimizerStep(this.sessionId, options); + } + + async runEvalStep( + feeds: SessionHandler.FeedsType, fetches: SessionHandler.FetchesType, + options: InferenceSession.RunOptions): Promise { + const [, inputIndices, inputs] = this.convertMapIntoValuesArrayAndIndicesArray( + feeds, this.evalInputNames, + (t, i): TensorMetadata => encodeTensorMetadata(t, () => `input "${this.evalInputNames[inputIndices[i]]}"`)); + + const [outputArray, outputIndices, outputs] = + this.convertMapIntoValuesArrayAndIndicesArray( + fetches, this.evalOutputNames, + (t, i): TensorMetadata|null => + t ? encodeTensorMetadata(t, () => `output "${this.evalOutputNames[outputIndices[i]]}"`) : null); + + const results = await runEvalStep(this.sessionId, inputIndices, inputs, outputIndices, outputs, options); + return this.convertTensorMetadataToReturnType(results, outputArray, outputIndices); + } + async getParametersSize(trainableOnly: boolean): Promise { return getParametersSize(this.sessionId, trainableOnly); } @@ -131,7 +156,6 @@ export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSes } async dispose(): Promise { - return releaseTrainingSessionAndCheckpoint( - this.checkpointId, this.sessionId, this.inputEncodedNames, this.outputEncodedNames); + return releaseTrainingSessionAndCheckpoint(this.checkpointId, this.sessionId); } } diff --git a/js/web/lib/wasm/wasm-training-core-impl.ts b/js/web/lib/wasm/wasm-training-core-impl.ts index 80e6aac40feca..37e226aff2d9c 100644 --- a/js/web/lib/wasm/wasm-training-core-impl.ts +++ b/js/web/lib/wasm/wasm-training-core-impl.ts @@ -3,7 +3,7 @@ import {InferenceSession, Tensor} from 'onnxruntime-common'; -import {SerializableModeldata, SerializableSessionMetadata, TensorMetadata} from './proxy-messages'; +import {SerializableModeldata, TensorMetadata} from './proxy-messages'; import {setRunOptions} from './run-options'; import {setSessionOptions} from './session-options'; import {dataLocationStringToEnum, tensorDataTypeEnumToString, tensorDataTypeStringToEnum, tensorTypeToTypedArrayConstructor} from './wasm-common'; @@ -77,50 +77,43 @@ const getModelInputOutputCount = (trainingSessionId: number, isEvalModel: boolea }; const getModelInputOutputNamesLoop = - (trainingSessionId: number, count: number, isInput: boolean, isEvalModel: boolean): [string[], number[]] => { + (trainingSessionId: number, count: number, isInput: boolean, isEvalModel: boolean): string[] => { const names = []; const wasm = getInstance(); - const namesUTF8Encoded = []; - for (let i = 0; i < count; i++) { if (wasm._OrtTrainingGetModelInputOutputName) { const name = wasm._OrtTrainingGetModelInputOutputName(trainingSessionId, i, isInput, isEvalModel); ifErrCodeCheckLastError(name, `Can't get input or output name -- is input: ${isInput}, index ${i}`, false); - namesUTF8Encoded.push(name); names.push(wasm.UTF8ToString(name)); } else { throw new Error(NO_TRAIN_FUNCS_MSG); } } - return [names, namesUTF8Encoded]; + return names; }; -const getTrainingModelInputOutputNames = (trainingSessionId: number): [string[], number[], string[], number[]] => { - const [inputCount, outputCount] = getModelInputOutputCount(trainingSessionId, false); +export const getModelInputOutputNames = (trainingSessionId: number, isEvalModel: boolean): [string[], string[]] => { + let inputNames: string[] = []; + let outputNames: string[] = []; + + const [inputCount, outputCount] = getModelInputOutputCount(trainingSessionId, isEvalModel); - const [inputNames, inputNamesUTF8Encoded] = getModelInputOutputNamesLoop(trainingSessionId, inputCount, true, false); - const [outputNames, outputNamesUTF8Encoded] = - getModelInputOutputNamesLoop(trainingSessionId, outputCount, false, false); + inputNames = getModelInputOutputNamesLoop(trainingSessionId, inputCount, true, isEvalModel); + outputNames = getModelInputOutputNamesLoop(trainingSessionId, outputCount, false, isEvalModel); - return [inputNames, inputNamesUTF8Encoded, outputNames, outputNamesUTF8Encoded]; + return [inputNames, outputNames]; }; export const createTrainingSessionHandle = (checkpointHandle: number, trainModelData: SerializableModeldata, evalModelData: SerializableModeldata, - optimizerModelData: SerializableModeldata, - options: InferenceSession.SessionOptions): [SerializableSessionMetadata, number[], number[]] => { + optimizerModelData: SerializableModeldata, options: InferenceSession.SessionOptions): number => { const wasm = getInstance(); let trainingSessionHandle = 0; let sessionOptionsHandle = 0; let allocs: number[] = []; - let inputNamesUTF8Encoded: number[] = []; - let outputNamesUTF8Encoded: number[] = []; - - let inputNames: string[] = []; - let outputNames: string[] = []; try { [sessionOptionsHandle, allocs] = setSessionOptions(options); @@ -133,11 +126,7 @@ export const createTrainingSessionHandle = } ifErrCodeCheckLastError(trainingSessionHandle, 'Error occurred when trying to create a TrainingSession', false); - - [inputNames, inputNamesUTF8Encoded, outputNames, outputNamesUTF8Encoded] = - getTrainingModelInputOutputNames(trainingSessionHandle); - return [[trainingSessionHandle, inputNames, outputNames], inputNamesUTF8Encoded, outputNamesUTF8Encoded]; - + return trainingSessionHandle; } catch (e) { if (wasm._OrtTrainingReleaseSession && trainingSessionHandle !== 0) { wasm._OrtTrainingReleaseSession(trainingSessionHandle); @@ -152,8 +141,6 @@ export const createTrainingSessionHandle = wasm._OrtReleaseSessionOptions(sessionOptionsHandle); } allocs.forEach(alloc => wasm._free(alloc)); - inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); - outputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); } }; @@ -317,6 +304,83 @@ export const runTrainStep = async( } }; +export const runOptimizerStep = + async(trainingSessionId: number, options: InferenceSession.RunOptions): Promise => { + const wasm = getInstance(); + + let runOptionsHandle = 0; + let runOptionsAllocs: number[] = []; + + try { + [runOptionsHandle, runOptionsAllocs] = setRunOptions(options); + + if (wasm._OrtTrainingOptimizerStep) { + const errCode = wasm._OrtTrainingOptimizerStep(trainingSessionId, runOptionsHandle); + ifErrCodeCheckLastError(errCode, 'Failed to call OrtTrainingOptimizerStep in the WebAssembly layer'); + } else { + throw new Error(NO_TRAIN_FUNCS_MSG); + } + } finally { + if (runOptionsHandle !== 0) { + wasm._OrtReleaseRunOptions(runOptionsHandle); + } + runOptionsAllocs.forEach(p => wasm._free(p)); + } +}; + +export const runEvalStep = async( + trainingSessionId: number, inputIndices: number[], inputTensors: TensorMetadata[], outputIndices: number[], + outputTensors: Array, options: InferenceSession.RunOptions): Promise => { + const wasm = getInstance(); + + const inputCount = inputIndices.length; + const outputCount = outputIndices.length; + + let runOptionsHandle = 0; + let runOptionsAllocs: number[] = []; + + const inputTensorHandles: number[] = []; + const outputTensorHandles: number[] = []; + const inputOutputAllocs: number[] = []; + + const beforeRunStack = wasm.stackSave(); + + try { + // prepare parameters by moving them to heap + [runOptionsHandle, runOptionsAllocs] = setRunOptions(options); + + // handle inputs -- you don't want anything added to the index + const inputValuesOffset = createAndAllocateTensors( + trainingSessionId, inputIndices, inputTensors, inputTensorHandles, inputOutputAllocs, 0); + // handle outputs + // you want inputCount to be added to the index of every output tensor passed to prepareInputOutputTensor + const outputValuesOffset = createAndAllocateTensors( + trainingSessionId, outputIndices, outputTensors, outputTensorHandles, inputOutputAllocs, inputCount); + + if (wasm._OrtTrainingEvalStep) { + const errorCode = wasm._OrtTrainingEvalStep( + trainingSessionId, inputValuesOffset, inputCount, outputValuesOffset, outputCount, runOptionsHandle); + + ifErrCodeCheckLastError(errorCode, 'failed to call OrtTrainingEvalStep in the WebAssembly layer'); + } else { + throw new Error(NO_TRAIN_FUNCS_MSG); + } + + return moveOutputToTensorMetadataArr(outputValuesOffset, outputCount, outputTensorHandles, outputTensors); + } finally { + wasm.stackRestore(beforeRunStack); + + inputTensorHandles.forEach(v => wasm._OrtReleaseTensor(v)); + outputTensorHandles.forEach(v => wasm._OrtReleaseTensor(v)); + inputOutputAllocs.forEach(p => wasm._free(p)); + + if (runOptionsHandle !== 0) { + wasm._OrtReleaseRunOptions(runOptionsHandle); + } + runOptionsAllocs.forEach(p => wasm._free(p)); + } +}; + export const getParametersSize = (trainingSessionId: number, trainableOnly: boolean): number => { const wasm = getInstance(); const stack = wasm.stackSave(); @@ -443,16 +507,13 @@ export const loadParametersBuffer = }; export const releaseTrainingSessionAndCheckpoint = - (checkpointId: number, sessionId: number, inputNamesUTF8Encoded: number[], outputNamesUTF8Encoded: number[]): - void => { - const wasm = getInstance(); - inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); - outputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); - - if (wasm._OrtTrainingReleaseSession) { - wasm._OrtTrainingReleaseSession(sessionId); - } - if (wasm._OrtTrainingReleaseCheckpoint) { - wasm._OrtTrainingReleaseCheckpoint(checkpointId); - } - }; + (checkpointId: number, sessionId: number): void => { + const wasm = getInstance(); + + if (wasm._OrtTrainingReleaseSession) { + wasm._OrtTrainingReleaseSession(sessionId); + } + if (wasm._OrtTrainingReleaseCheckpoint) { + wasm._OrtTrainingReleaseCheckpoint(checkpointId); + } + }; From 423560876e3df3f2a90523c2167590b3faf4e201 Mon Sep 17 00:00:00 2001 From: carzh Date: Thu, 2 Nov 2023 15:53:03 -0700 Subject: [PATCH 04/12] made sure to free the encoded name --- js/web/lib/wasm/wasm-training-core-impl.ts | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/js/web/lib/wasm/wasm-training-core-impl.ts b/js/web/lib/wasm/wasm-training-core-impl.ts index 37e226aff2d9c..dd1b38d09fe81 100644 --- a/js/web/lib/wasm/wasm-training-core-impl.ts +++ b/js/web/lib/wasm/wasm-training-core-impl.ts @@ -87,6 +87,7 @@ const getModelInputOutputNamesLoop = ifErrCodeCheckLastError(name, `Can't get input or output name -- is input: ${isInput}, index ${i}`, false); names.push(wasm.UTF8ToString(name)); + wasm._free(name); } else { throw new Error(NO_TRAIN_FUNCS_MSG); } @@ -506,14 +507,13 @@ export const loadParametersBuffer = } }; -export const releaseTrainingSessionAndCheckpoint = - (checkpointId: number, sessionId: number): void => { - const wasm = getInstance(); +export const releaseTrainingSessionAndCheckpoint = (checkpointId: number, sessionId: number): void => { + const wasm = getInstance(); - if (wasm._OrtTrainingReleaseSession) { - wasm._OrtTrainingReleaseSession(sessionId); - } - if (wasm._OrtTrainingReleaseCheckpoint) { - wasm._OrtTrainingReleaseCheckpoint(checkpointId); - } - }; + if (wasm._OrtTrainingReleaseSession) { + wasm._OrtTrainingReleaseSession(sessionId); + } + if (wasm._OrtTrainingReleaseCheckpoint) { + wasm._OrtTrainingReleaseCheckpoint(checkpointId); + } +}; From bbf70c5ebf224d46b710cf4565eb300918c537bd Mon Sep 17 00:00:00 2001 From: carzh Date: Thu, 2 Nov 2023 16:13:52 -0700 Subject: [PATCH 05/12] added default value of true for parameters methods --- js/common/lib/training-session-impl.ts | 6 +++--- js/common/lib/training-session.ts | 6 +++--- js/web/lib/wasm/session-handler-training.ts | 2 +- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/js/common/lib/training-session-impl.ts b/js/common/lib/training-session-impl.ts index a9a9d42e2a594..8e3e9b6eaa1eb 100644 --- a/js/common/lib/training-session-impl.ts +++ b/js/common/lib/training-session-impl.ts @@ -176,11 +176,11 @@ export class TrainingSession implements TrainingSessionInterface { return this.convertHandlerReturnTypeToMapOfTensors(results); } - async getParametersSize(trainableOnly: boolean): Promise { + async getParametersSize(trainableOnly: boolean = true): Promise { return this.handler.getParametersSize(trainableOnly); } - async loadParametersBuffer(array: Float32Array, trainableOnly: boolean): Promise { + async loadParametersBuffer(array: Float32Array, trainableOnly: boolean = true): Promise { const paramsSize = await this.getParametersSize(trainableOnly); if (array.length !== paramsSize) { throw new Error( @@ -190,7 +190,7 @@ export class TrainingSession implements TrainingSessionInterface { return this.handler.loadParametersBuffer(array, trainableOnly); } - async getContiguousParameters(trainableOnly: boolean): Promise { + async getContiguousParameters(trainableOnly: boolean = true): Promise { return this.handler.getContiguousParameters(trainableOnly); } diff --git a/js/common/lib/training-session.ts b/js/common/lib/training-session.ts index c71bc12a7d53f..a7919b81b22ae 100644 --- a/js/common/lib/training-session.ts +++ b/js/common/lib/training-session.ts @@ -55,7 +55,7 @@ export interface TrainingSession { * Retrieves the size of all parameters for the training state. Calculates the total number of primitive (datatype of * the parameters) elements of all the parameters in the training state. * - * @param trainableOnly - When set to true, the size is calculated for trainable params only. + * @param trainableOnly - When set to true, the size is calculated for trainable params only. Default value is true. */ getParametersSize(trainableOnly: boolean): Promise; @@ -63,7 +63,7 @@ export interface TrainingSession { * Copies parameter values from the given array to the training state. * * @param buffer - buffer containing parameters - * @param trainableOnly - True if trainable parameters only to be modified, false otherwise. + * @param trainableOnly - True if trainable parameters only to be modified, false otherwise. Default value is true. */ loadParametersBuffer(array: Float32Array, trainableOnly: boolean): Promise; @@ -71,7 +71,7 @@ export interface TrainingSession { * Copies from the TrainingSession parameters to a contiguous buffer. * * @param trainableOnly - When set to true, only trainable parameters are copied. Trainable parameters are parameters - * for which requires_grad is set to true. + * for which requires_grad is set to true. Default value is true. * @returns A promise that resolves to a buffer of the requested parameters. */ getContiguousParameters(trainableOnly: boolean): Promise; diff --git a/js/web/lib/wasm/session-handler-training.ts b/js/web/lib/wasm/session-handler-training.ts index 3fe0eefb830af..d09e21842b337 100644 --- a/js/web/lib/wasm/session-handler-training.ts +++ b/js/web/lib/wasm/session-handler-training.ts @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {env, InferenceSession, SessionHandler, Tensor, TrainingSessionHandler} from 'onnxruntime-common'; +import {env, InferenceSession, OnnxValue, SessionHandler, Tensor, TrainingSessionHandler} from 'onnxruntime-common'; import {SerializableModeldata, TensorMetadata} from './proxy-messages'; import {decodeTensorMetadata, encodeTensorMetadata} from './session-handler-inference'; From 5e175529da29fc4cda954e98e91520d1ea26cdb0 Mon Sep 17 00:00:00 2001 From: carzh Date: Thu, 2 Nov 2023 16:16:39 -0700 Subject: [PATCH 06/12] lint --- js/common/lib/training-session-impl.ts | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/js/common/lib/training-session-impl.ts b/js/common/lib/training-session-impl.ts index 8e3e9b6eaa1eb..e4c40cc9d44fb 100644 --- a/js/common/lib/training-session-impl.ts +++ b/js/common/lib/training-session-impl.ts @@ -176,11 +176,11 @@ export class TrainingSession implements TrainingSessionInterface { return this.convertHandlerReturnTypeToMapOfTensors(results); } - async getParametersSize(trainableOnly: boolean = true): Promise { + async getParametersSize(trainableOnly = true): Promise { return this.handler.getParametersSize(trainableOnly); } - async loadParametersBuffer(array: Float32Array, trainableOnly: boolean = true): Promise { + async loadParametersBuffer(array: Float32Array, trainableOnly = true): Promise { const paramsSize = await this.getParametersSize(trainableOnly); if (array.length !== paramsSize) { throw new Error( @@ -190,7 +190,7 @@ export class TrainingSession implements TrainingSessionInterface { return this.handler.loadParametersBuffer(array, trainableOnly); } - async getContiguousParameters(trainableOnly: boolean = true): Promise { + async getContiguousParameters(trainableOnly = true): Promise { return this.handler.getContiguousParameters(trainableOnly); } From f20757abc533f818dce0c5344951be956b80606a Mon Sep 17 00:00:00 2001 From: Caroline Zhu Date: Mon, 20 Nov 2023 13:26:01 -0800 Subject: [PATCH 07/12] add suggested addition to interface doc Co-authored-by: Ashwini Khade --- js/common/lib/training-session.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/js/common/lib/training-session.ts b/js/common/lib/training-session.ts index a7919b81b22ae..04814b181a94a 100644 --- a/js/common/lib/training-session.ts +++ b/js/common/lib/training-session.ts @@ -68,7 +68,7 @@ export interface TrainingSession { loadParametersBuffer(array: Float32Array, trainableOnly: boolean): Promise; /** - * Copies from the TrainingSession parameters to a contiguous buffer. + * Copies the model parameters to a contiguous buffer. Usually used in the context of Federated Learning * * @param trainableOnly - When set to true, only trainable parameters are copied. Trainable parameters are parameters * for which requires_grad is set to true. Default value is true. From 43aaeff53d6842bf8e3ef919fd93704aad2281f2 Mon Sep 17 00:00:00 2001 From: carzh Date: Mon, 20 Nov 2023 17:01:03 -0800 Subject: [PATCH 08/12] changed from stackalloc to malloc & added comments --- js/web/lib/wasm/wasm-training-core-impl.ts | 23 +++++++++++----------- 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/js/web/lib/wasm/wasm-training-core-impl.ts b/js/web/lib/wasm/wasm-training-core-impl.ts index 80e6aac40feca..f5cf031c398b8 100644 --- a/js/web/lib/wasm/wasm-training-core-impl.ts +++ b/js/web/lib/wasm/wasm-training-core-impl.ts @@ -347,11 +347,11 @@ export const getContiguousParameters = const parametersSize = getParametersSize(trainingSessionId, trainableOnly); let tensor = 0; + // allocates a Float32Array of the correct size on the WASM heap const paramsByteLength = 4 * parametersSize; - const paramsOffset = wasm.stackAlloc(paramsByteLength); - wasm.HEAPU8.set(new Float32Array(parametersSize), paramsOffset); - - const tensorOffset = wasm.stackAlloc(paramsOffset / 4); + const paramsOffset = wasm._malloc(paramsByteLength); + const arr = new Float32Array(parametersSize); + wasm.HEAPU8.set(arr, paramsOffset); // handles the dimensions-related createTensor parameters const dims = [parametersSize]; @@ -361,13 +361,13 @@ export const getContiguousParameters = wasm.HEAP32[dimsIndex] = parametersSize; try { + // wraps allocated array in a tensor tensor = wasm._OrtCreateTensor( tensorDataTypeStringToEnum(tensorTypeAsString), paramsOffset, paramsByteLength, dimsOffset, dims.length, dataLocationStringToEnum(locationAsString)); ifErrCodeCheckLastError( tensor, `Can't create tensor for getContiguousParameters. session=${trainingSessionId}.`, false); - wasm.HEAPU32[tensorOffset] = tensor; if (wasm._OrtTrainingCopyParametersToBuffer) { const errCode = wasm._OrtTrainingCopyParametersToBuffer(trainingSessionId, tensor, parametersSize, trainableOnly); ifErrCodeCheckLastError(errCode, 'Can\'t get contiguous parameters.'); @@ -376,13 +376,14 @@ export const getContiguousParameters = throw new Error(NO_TRAIN_FUNCS_MSG); } + // copies from WASM memory to a JavaScript typed array, which is then put into a TensorMetadata object const typedArrayConstructor = tensorTypeToTypedArrayConstructor(tensorTypeAsString); const data = new typedArrayConstructor(parametersSize); const output: TensorMetadata[] = []; new Uint8Array(data.buffer, data.byteOffset, data.byteLength) .set(wasm.HEAPU8.subarray(paramsOffset, paramsOffset + paramsByteLength)); output.push([tensorTypeAsString, dims, data, locationAsString]); - if (output.length > 1 || output.length < 1) { + if (output.length !== 1) { throw new Error(`something unexpected happened in the getContiguousParameters function. Expected output length of one, got ${output.length}`); } else { @@ -394,7 +395,6 @@ export const getContiguousParameters = } wasm._free(paramsOffset); wasm._free(dimsOffset); - wasm._free(tensorOffset); wasm.stackRestore(stack); } }; @@ -407,15 +407,17 @@ export const loadParametersBuffer = const tensorTypeAsString = 'float32'; const locationAsString = 'cpu'; + // allocates & copies JavaScript buffer to WASM heap const bufferCount = buffer.length; const bufferByteLength = bufferCount * 4; - const bufferOffset = wasm.stackAlloc(bufferByteLength); + const bufferOffset = wasm._malloc(bufferByteLength); wasm.HEAPU8.set(new Uint8Array(buffer.buffer, buffer.byteOffset, buffer.byteLength), bufferOffset); + + // allocates and handles moving dimensions information to WASM memory const dimsOffset = wasm.stackAlloc(4); wasm.HEAP32[dimsOffset / 4] = bufferCount; const dimsLength = 1; let tensor = 0; - const bufferAlloc = wasm.stackAlloc(bufferOffset / 4); try { tensor = wasm._OrtCreateTensor( @@ -423,8 +425,6 @@ export const loadParametersBuffer = dataLocationStringToEnum(locationAsString)); ifErrCodeCheckLastError(tensor, `Can't create tensor for input/output. session=${trainingSessionId}`, false); - wasm.HEAPU32[bufferAlloc] = tensor; - if (wasm._OrtTrainingCopyParametersFromBuffer) { const errCode = wasm._OrtTrainingCopyParametersFromBuffer(trainingSessionId, tensor, bufferCount, trainableOnly); ifErrCodeCheckLastError(errCode, 'Can\'t copy buffer to parameters.'); @@ -436,7 +436,6 @@ export const loadParametersBuffer = wasm._OrtReleaseTensor(tensor); } wasm.stackRestore(stack); - wasm._free(bufferAlloc); wasm._free(bufferOffset); wasm._free(dimsOffset); } From 242cebe246b41d2e4073e8b0d0fb60d867f48e8e Mon Sep 17 00:00:00 2001 From: carzh Date: Wed, 22 Nov 2023 10:26:49 -0800 Subject: [PATCH 09/12] removed unnecessary float32 array --- js/web/lib/wasm/wasm-training-core-impl.ts | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/js/web/lib/wasm/wasm-training-core-impl.ts b/js/web/lib/wasm/wasm-training-core-impl.ts index f5cf031c398b8..1e4cc7e6a0d4e 100644 --- a/js/web/lib/wasm/wasm-training-core-impl.ts +++ b/js/web/lib/wasm/wasm-training-core-impl.ts @@ -347,11 +347,9 @@ export const getContiguousParameters = const parametersSize = getParametersSize(trainingSessionId, trainableOnly); let tensor = 0; - // allocates a Float32Array of the correct size on the WASM heap + // allocates a buffer of the correct size on the WASM heap const paramsByteLength = 4 * parametersSize; const paramsOffset = wasm._malloc(paramsByteLength); - const arr = new Float32Array(parametersSize); - wasm.HEAPU8.set(arr, paramsOffset); // handles the dimensions-related createTensor parameters const dims = [parametersSize]; From a4ff61cb980a4cfd094992fb88e5d7b83205ba53 Mon Sep 17 00:00:00 2001 From: carzh Date: Wed, 22 Nov 2023 14:09:45 -0800 Subject: [PATCH 10/12] updated loadParametersBuffer to take in a Uint8Array & removed redundant copy operation --- js/common/lib/training-session-impl.ts | 6 ++++-- js/common/lib/training-session.ts | 12 +++++++----- js/web/lib/wasm/session-handler-training.ts | 2 +- js/web/lib/wasm/wasm-training-core-impl.ts | 8 ++++---- 4 files changed, 16 insertions(+), 12 deletions(-) diff --git a/js/common/lib/training-session-impl.ts b/js/common/lib/training-session-impl.ts index e4c40cc9d44fb..03694738387f2 100644 --- a/js/common/lib/training-session-impl.ts +++ b/js/common/lib/training-session-impl.ts @@ -180,9 +180,11 @@ export class TrainingSession implements TrainingSessionInterface { return this.handler.getParametersSize(trainableOnly); } - async loadParametersBuffer(array: Float32Array, trainableOnly = true): Promise { + async loadParametersBuffer(array: Uint8Array, trainableOnly = true): Promise { const paramsSize = await this.getParametersSize(trainableOnly); - if (array.length !== paramsSize) { + // checking that the size of the Uint8Array is equivalent to the byte length of a Float32Array of the number + // of parameters + if (array.length !== 4 * paramsSize) { throw new Error( 'Size of the buffer passed into loadParametersBuffer must match the number of parameters in ' + 'the model. Please use getParametersSize method to check.'); diff --git a/js/common/lib/training-session.ts b/js/common/lib/training-session.ts index 04814b181a94a..810ec2a8583b3 100644 --- a/js/common/lib/training-session.ts +++ b/js/common/lib/training-session.ts @@ -60,19 +60,21 @@ export interface TrainingSession { getParametersSize(trainableOnly: boolean): Promise; /** - * Copies parameter values from the given array to the training state. + * Copies parameter values from the given array to the training state. Currently, only supporting models with + * parameters of type Float32. * - * @param buffer - buffer containing parameters + * @param buffer - Float32 buffer containing parameters converted to a Uint8Array. * @param trainableOnly - True if trainable parameters only to be modified, false otherwise. Default value is true. */ - loadParametersBuffer(array: Float32Array, trainableOnly: boolean): Promise; + loadParametersBuffer(array: Uint8Array, trainableOnly: boolean): Promise; /** - * Copies the model parameters to a contiguous buffer. Usually used in the context of Federated Learning + * Copies the model parameters to a contiguous buffer. Usually used in the context of Federated Learning. + * Currently, only supporting models with parameters of type Float32. * * @param trainableOnly - When set to true, only trainable parameters are copied. Trainable parameters are parameters * for which requires_grad is set to true. Default value is true. - * @returns A promise that resolves to a buffer of the requested parameters. + * @returns A promise that resolves to a Float32 OnnxValue of the requested parameters. */ getContiguousParameters(trainableOnly: boolean): Promise; // #endregion diff --git a/js/web/lib/wasm/session-handler-training.ts b/js/web/lib/wasm/session-handler-training.ts index d09e21842b337..7de3f4dc2c89e 100644 --- a/js/web/lib/wasm/session-handler-training.ts +++ b/js/web/lib/wasm/session-handler-training.ts @@ -122,7 +122,7 @@ export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSes return getParametersSize(this.sessionId, trainableOnly); } - async loadParametersBuffer(array: Float32Array, trainableOnly: boolean): Promise { + async loadParametersBuffer(array: Uint8Array, trainableOnly: boolean): Promise { await loadParametersBuffer(this.sessionId, array, trainableOnly); } async getContiguousParameters(trainableOnly: boolean): Promise { diff --git a/js/web/lib/wasm/wasm-training-core-impl.ts b/js/web/lib/wasm/wasm-training-core-impl.ts index 1e4cc7e6a0d4e..251bc612ab085 100644 --- a/js/web/lib/wasm/wasm-training-core-impl.ts +++ b/js/web/lib/wasm/wasm-training-core-impl.ts @@ -398,7 +398,7 @@ export const getContiguousParameters = }; export const loadParametersBuffer = - async(trainingSessionId: number, buffer: Float32Array, trainableOnly: boolean): Promise => { + async(trainingSessionId: number, buffer: Uint8Array, trainableOnly: boolean): Promise => { const wasm = getInstance(); const stack = wasm.stackSave(); @@ -406,10 +406,10 @@ export const loadParametersBuffer = const locationAsString = 'cpu'; // allocates & copies JavaScript buffer to WASM heap - const bufferCount = buffer.length; - const bufferByteLength = bufferCount * 4; + const bufferCount = getParametersSize(trainingSessionId, trainableOnly); + const bufferByteLength = buffer.length; const bufferOffset = wasm._malloc(bufferByteLength); - wasm.HEAPU8.set(new Uint8Array(buffer.buffer, buffer.byteOffset, buffer.byteLength), bufferOffset); + wasm.HEAPU8.set(buffer, bufferOffset); // allocates and handles moving dimensions information to WASM memory const dimsOffset = wasm.stackAlloc(4); From fa9f54576acdd5a9cf07e34d4e46ddea9595ded3 Mon Sep 17 00:00:00 2001 From: carzh Date: Wed, 22 Nov 2023 14:25:27 -0800 Subject: [PATCH 11/12] added suggestion --- js/common/lib/backend.ts | 2 +- js/web/lib/wasm/wasm-training-core-impl.ts | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/js/common/lib/backend.ts b/js/common/lib/backend.ts index fd2e8bb74bbf5..67d283b694955 100644 --- a/js/common/lib/backend.ts +++ b/js/common/lib/backend.ts @@ -50,7 +50,7 @@ export interface TrainingSessionHandler extends SessionHandler { options: InferenceSession.RunOptions): Promise; getParametersSize(trainableOnly: boolean): Promise; - loadParametersBuffer(array: Float32Array, trainableOnly: boolean): Promise; + loadParametersBuffer(array: Uint8Array, trainableOnly: boolean): Promise; getContiguousParameters(trainableOnly: boolean): Promise; } diff --git a/js/web/lib/wasm/wasm-training-core-impl.ts b/js/web/lib/wasm/wasm-training-core-impl.ts index 251bc612ab085..c0a4235113148 100644 --- a/js/web/lib/wasm/wasm-training-core-impl.ts +++ b/js/web/lib/wasm/wasm-training-core-impl.ts @@ -406,8 +406,8 @@ export const loadParametersBuffer = const locationAsString = 'cpu'; // allocates & copies JavaScript buffer to WASM heap - const bufferCount = getParametersSize(trainingSessionId, trainableOnly); const bufferByteLength = buffer.length; + const bufferCount = bufferByteLength / 4; const bufferOffset = wasm._malloc(bufferByteLength); wasm.HEAPU8.set(buffer, bufferOffset); From 491383d65714afe29f318d7d72ab9d53df151b8d Mon Sep 17 00:00:00 2001 From: carzh Date: Wed, 29 Nov 2023 11:24:44 -0800 Subject: [PATCH 12/12] missed a merge conflict marker oops --- js/web/lib/wasm/session-handler-training.ts | 4 ---- 1 file changed, 4 deletions(-) diff --git a/js/web/lib/wasm/session-handler-training.ts b/js/web/lib/wasm/session-handler-training.ts index 195147a0b9c2e..721669b2fc0a6 100644 --- a/js/web/lib/wasm/session-handler-training.ts +++ b/js/web/lib/wasm/session-handler-training.ts @@ -6,11 +6,7 @@ import {env, InferenceSession, OnnxValue, SessionHandler, Tensor, TrainingSessio import {SerializableModeldata, TensorMetadata} from './proxy-messages'; import {decodeTensorMetadata, encodeTensorMetadata} from './session-handler-inference'; import {createSessionAllocate, initRuntime, isOrtEnvInitialized} from './wasm-core-impl'; -<<<<<<< HEAD import {createCheckpointHandle, createTrainingSessionHandle, getContiguousParameters, getModelInputOutputNames, getParametersSize, loadParametersBuffer, releaseTrainingSessionAndCheckpoint, runEvalStep, runOptimizerStep, runTrainStep} from './wasm-training-core-impl'; -======= -import {createCheckpointHandle, createTrainingSessionHandle, getContiguousParameters, getParametersSize, loadParametersBuffer, releaseTrainingSessionAndCheckpoint, runTrainStep} from './wasm-training-core-impl'; ->>>>>>> main export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSessionHandler { private sessionId: number;