diff --git a/common/lib/backend.ts b/common/lib/backend.ts index 67d283b694955..20dca8942d387 100644 --- a/common/lib/backend.ts +++ b/common/lib/backend.ts @@ -45,9 +45,16 @@ export interface InferenceSessionHandler extends SessionHandler { * @ignore */ export interface TrainingSessionHandler extends SessionHandler { + readonly evalInputNames: readonly string[]; + readonly evalOutputNames: readonly string[]; + runTrainStep( feeds: SessionHandler.FeedsType, fetches: SessionHandler.FetchesType, options: InferenceSession.RunOptions): Promise; + runOptimizerStep(options: InferenceSession.RunOptions): Promise; + runEvalStep( + feeds: SessionHandler.FeedsType, fetches: SessionHandler.FetchesType, + options: InferenceSession.RunOptions): Promise; getParametersSize(trainableOnly: boolean): Promise; loadParametersBuffer(array: Uint8Array, trainableOnly: boolean): Promise; diff --git a/common/lib/training-session-impl.ts b/common/lib/training-session-impl.ts index 03694738387f2..5260b54b69221 100644 --- a/common/lib/training-session-impl.ts +++ b/common/lib/training-session-impl.ts @@ -18,18 +18,37 @@ const noBackendErrMsg: string = 'Training backend could not be resolved. ' + 'Make sure you\'re using the correct configuration & WebAssembly files.'; export class TrainingSession implements TrainingSessionInterface { - private constructor(handler: TrainingSessionHandler) { + private constructor(handler: TrainingSessionHandler, hasOptimizerModel: boolean, hasEvalModel: boolean) { this.handler = handler; + this.hasOptimizerModel = hasOptimizerModel; + this.hasEvalModel = hasEvalModel; } private handler: TrainingSessionHandler; + private hasOptimizerModel: boolean; + private hasEvalModel: boolean; - get inputNames(): readonly string[] { + get trainingInputNames(): readonly string[] { return this.handler.inputNames; } - get outputNames(): readonly string[] { + get trainingOutputNames(): readonly string[] { return this.handler.outputNames; } + get evalInputNames(): readonly string[] { + if (this.hasEvalModel) { + return this.handler.evalInputNames; + } else { + throw new Error('This training session has no evalModel loaded.'); + } + } + get evalOutputNames(): readonly string[] { + if (this.hasEvalModel) { + return this.handler.evalOutputNames; + } else { + throw new Error('This training session has no evalModel loaded.'); + } + } + static async create(trainingOptions: TrainingSessionCreateOptions, sessionOptions?: SessionOptions): Promise { const evalModel: string|Uint8Array = trainingOptions.evalModel || ''; @@ -43,7 +62,7 @@ export class TrainingSession implements TrainingSessionInterface { if (backend.createTrainingSessionHandler) { const handler = await backend.createTrainingSessionHandler( trainingOptions.checkpointState, trainingOptions.trainModel, evalModel, optimizerModel, options); - return new TrainingSession(handler); + return new TrainingSession(handler, !!trainingOptions.optimizerModel, !!trainingOptions.evalModel); } else { throw new Error(noBackendErrMsg); } @@ -53,13 +72,18 @@ export class TrainingSession implements TrainingSessionInterface { * Helper function for runTrainStep and future runStep methods that handles the type-narrowing conversion from * the given parameters to SessionHandler.FetchesType and RunOptions. * + * @param inputNames the feeds object is checked that they contain all input names in the provided list of input + * names. + * @param outputNames the fetches object is checked that their keys match up with valid names in the list of output + * names. * @param feeds the required input * @param arg1 narrowed & converted into the SessionHandler.FetchesType or RunOptions object * @param arg2 optional RunOptions object. * @returns */ - typeNarrowingForRunStep(feeds: FeedsType, arg1?: FetchesType|RunOptions, arg2?: RunOptions): - [SessionHandler.FetchesType, RunOptions] { + typeNarrowingForRunStep( + inputNames: readonly string[], outputNames: readonly string[], feeds: FeedsType, arg1?: FetchesType|RunOptions, + arg2?: RunOptions): [SessionHandler.FetchesType, RunOptions] { const fetches: {[name: string]: OnnxValue|null} = {}; let options: RunOptions = {}; // check inputs @@ -88,7 +112,7 @@ export class TrainingSession implements TrainingSessionInterface { if (typeof name !== 'string') { throw new TypeError('\'fetches\' must be a string array or an object.'); } - if (this.outputNames.indexOf(name) === -1) { + if (outputNames.indexOf(name) === -1) { throw new RangeError(`'fetches' contains invalid output name: ${name}.`); } fetches[name] = null; @@ -104,7 +128,7 @@ export class TrainingSession implements TrainingSessionInterface { // if any output name is present and its value is valid OnnxValue, we consider it fetches let isFetches = false; const arg1Keys = Object.getOwnPropertyNames(arg1); - for (const name of this.outputNames) { + for (const name of outputNames) { if (arg1Keys.indexOf(name) !== -1) { const v = (arg1 as InferenceSession.NullableOnnxValueMapType)[name]; if (v === null || v instanceof Tensor) { @@ -130,7 +154,7 @@ export class TrainingSession implements TrainingSessionInterface { } // check if all inputs are in feed - for (const name of this.inputNames) { + for (const name of inputNames) { if (typeof feeds[name] === 'undefined') { throw new Error(`input '${name}' is missing in 'feeds'.`); } @@ -138,7 +162,7 @@ export class TrainingSession implements TrainingSessionInterface { // if no fetches is specified, we use the full output names list if (isFetchesEmpty) { - for (const name of this.outputNames) { + for (const name of outputNames) { fetches[name] = null; } } @@ -171,11 +195,33 @@ export class TrainingSession implements TrainingSessionInterface { runTrainStep(feeds: FeedsType, options?: RunOptions): Promise; runTrainStep(feeds: FeedsType, fetches: FetchesType, options?: RunOptions): Promise; async runTrainStep(feeds: FeedsType, arg1?: FetchesType|RunOptions, arg2?: RunOptions): Promise { - const [fetches, options] = this.typeNarrowingForRunStep(feeds, arg1, arg2); + const [fetches, options] = + this.typeNarrowingForRunStep(this.trainingInputNames, this.trainingOutputNames, feeds, arg1, arg2); const results = await this.handler.runTrainStep(feeds, fetches, options); return this.convertHandlerReturnTypeToMapOfTensors(results); } + async runOptimizerStep(options?: InferenceSession.RunOptions|undefined): Promise { + if (this.hasOptimizerModel) { + await this.handler.runOptimizerStep(options || {}); + } else { + throw new Error('This TrainingSession has no OptimizerModel loaded.'); + } + } + + runEvalStep(feeds: FeedsType, options?: RunOptions|undefined): Promise; + runEvalStep(feeds: FeedsType, fetches: FetchesType, options?: RunOptions|undefined): Promise; + async runEvalStep(feeds: FeedsType, arg1?: FetchesType|RunOptions, arg2?: RunOptions): Promise { + if (this.hasEvalModel) { + const [fetches, options] = + this.typeNarrowingForRunStep(this.evalInputNames, this.evalOutputNames, feeds, arg1, arg2); + const results = await this.handler.runEvalStep(feeds, fetches, options); + return this.convertHandlerReturnTypeToMapOfTensors(results); + } else { + throw new Error('This TrainingSession has no EvalModel loaded.'); + } + } + async getParametersSize(trainableOnly = true): Promise { return this.handler.getParametersSize(trainableOnly); } diff --git a/common/lib/training-session.ts b/common/lib/training-session.ts index 810ec2a8583b3..0cd35ee6c4087 100644 --- a/common/lib/training-session.ts +++ b/common/lib/training-session.ts @@ -39,7 +39,7 @@ export interface TrainingSession { * @param feeds - Representation of the model input. * @param fetches - Representation of the model output. * detail. - * @param options - Optional. A set of options that controls the behavior of model inference. + * @param options - Optional. A set of options that controls the behavior of model training. * @returns A promise that resolves to a map, which uses output names as keys and OnnxValue as corresponding values. */ @@ -47,6 +47,38 @@ export interface TrainingSession { feeds: InferenceSession.FeedsType, fetches: InferenceSession.FetchesType, options?: InferenceSession.RunOptions): Promise; + /** + * Runs a single optimizer step, which performs weight updates for the trainable parameters using the optimizer model. + * + * @param options - Optional. A set of options that controls the behavior of model optimizing. + */ + runOptimizerStep(options?: InferenceSession.RunOptions): Promise; + + /** + * Run a single eval step with the given inputs and options using the eval model. + * + * @param feeds - Representation of the model input. + * @param options - Optional. A set of options that controls the behavior of model eval step. + * @returns A promise that resolves to a map, which uses output names as keys and OnnxValue as corresponding + values. + */ + runEvalStep(feeds: InferenceSession.FeedsType, options?: InferenceSession.RunOptions): + Promise; + + /** + * Run a single eval step with the given inputs and options using the eval model. + * + * @param feeds - Representation of the model input. + * @param fetches - Representation of the model output. + * detail. + * @param options - Optional. A set of options that controls the behavior of model eval step. + * @returns A promise that resolves to a map, which uses output names as keys and OnnxValue as corresponding + values. + */ + runEvalStep( + feeds: InferenceSession.FeedsType, fetches: InferenceSession.FetchesType, + options?: InferenceSession.RunOptions): Promise; + // #endregion // #region copy parameters @@ -90,14 +122,25 @@ export interface TrainingSession { // #region metadata /** - * Get input names of the loaded model. + * Get input names of the loaded training model. */ - readonly inputNames: readonly string[]; + readonly trainingInputNames: readonly string[]; /** - * Get output names of the loaded model. + * Get output names of the loaded training model. */ - readonly outputNames: readonly string[]; + readonly trainingOutputNames: readonly string[]; + + /** + * Get input names of the loaded eval model. Is an empty array if no eval model is loaded. + */ + readonly evalInputNames: readonly string[]; + + /** + * Get output names of the loaded eval model. Is an empty array if no eval model is loaded. + */ + readonly evalOutputNames: readonly string[]; + // #endregion } diff --git a/web/lib/wasm/session-handler-training.ts b/web/lib/wasm/session-handler-training.ts index 7de3f4dc2c89e..721669b2fc0a6 100644 --- a/web/lib/wasm/session-handler-training.ts +++ b/web/lib/wasm/session-handler-training.ts @@ -6,7 +6,7 @@ import {env, InferenceSession, OnnxValue, SessionHandler, Tensor, TrainingSessio import {SerializableModeldata, TensorMetadata} from './proxy-messages'; import {decodeTensorMetadata, encodeTensorMetadata} from './session-handler-inference'; import {createSessionAllocate, initRuntime, isOrtEnvInitialized} from './wasm-core-impl'; -import {createCheckpointHandle, createTrainingSessionHandle, getContiguousParameters, getParametersSize, loadParametersBuffer, releaseTrainingSessionAndCheckpoint, runTrainStep} from './wasm-training-core-impl'; +import {createCheckpointHandle, createTrainingSessionHandle, getContiguousParameters, getModelInputOutputNames, getParametersSize, loadParametersBuffer, releaseTrainingSessionAndCheckpoint, runEvalStep, runOptimizerStep, runTrainStep} from './wasm-training-core-impl'; export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSessionHandler { private sessionId: number; @@ -15,8 +15,8 @@ export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSes inputNames: string[]; outputNames: string[]; - inputEncodedNames: number[]; - outputEncodedNames: number[]; + evalInputNames: string[] = []; + evalOutputNames: string[] = []; async uriOrBufferToHeap(uriOrBuffer: string|Uint8Array): Promise { let buffer: Uint8Array; @@ -51,8 +51,12 @@ export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSes } this.checkpointId = createCheckpointHandle(checkpointData); - [[this.sessionId, this.inputNames, this.outputNames], this.inputEncodedNames, this.outputEncodedNames] = + this.sessionId = createTrainingSessionHandle(this.checkpointId, trainModelData, evalModelData, optimizerModelData, options); + [this.inputNames, this.outputNames] = getModelInputOutputNames(this.sessionId, false); + if (evalModelUriOrBuffer !== '') { + [this.evalInputNames, this.evalOutputNames] = getModelInputOutputNames(this.sessionId, true); + } } /** @@ -118,6 +122,27 @@ export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSes return this.convertTensorMetadataToReturnType(results, outputArray, outputIndices); } + async runOptimizerStep(options: InferenceSession.RunOptions): Promise { + await runOptimizerStep(this.sessionId, options); + } + + async runEvalStep( + feeds: SessionHandler.FeedsType, fetches: SessionHandler.FetchesType, + options: InferenceSession.RunOptions): Promise { + const [, inputIndices, inputs] = this.convertMapIntoValuesArrayAndIndicesArray( + feeds, this.evalInputNames, + (t, i): TensorMetadata => encodeTensorMetadata(t, () => `input "${this.evalInputNames[inputIndices[i]]}"`)); + + const [outputArray, outputIndices, outputs] = + this.convertMapIntoValuesArrayAndIndicesArray( + fetches, this.evalOutputNames, + (t, i): TensorMetadata|null => + t ? encodeTensorMetadata(t, () => `output "${this.evalOutputNames[outputIndices[i]]}"`) : null); + + const results = await runEvalStep(this.sessionId, inputIndices, inputs, outputIndices, outputs, options); + return this.convertTensorMetadataToReturnType(results, outputArray, outputIndices); + } + async getParametersSize(trainableOnly: boolean): Promise { return getParametersSize(this.sessionId, trainableOnly); } @@ -131,7 +156,6 @@ export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSes } async dispose(): Promise { - return releaseTrainingSessionAndCheckpoint( - this.checkpointId, this.sessionId, this.inputEncodedNames, this.outputEncodedNames); + return releaseTrainingSessionAndCheckpoint(this.checkpointId, this.sessionId); } } diff --git a/web/lib/wasm/wasm-training-core-impl.ts b/web/lib/wasm/wasm-training-core-impl.ts index c0a4235113148..3aea4e308ea6e 100644 --- a/web/lib/wasm/wasm-training-core-impl.ts +++ b/web/lib/wasm/wasm-training-core-impl.ts @@ -3,7 +3,7 @@ import {InferenceSession, Tensor} from 'onnxruntime-common'; -import {SerializableModeldata, SerializableSessionMetadata, TensorMetadata} from './proxy-messages'; +import {SerializableModeldata, TensorMetadata} from './proxy-messages'; import {setRunOptions} from './run-options'; import {setSessionOptions} from './session-options'; import {dataLocationStringToEnum, tensorDataTypeEnumToString, tensorDataTypeStringToEnum, tensorTypeToTypedArrayConstructor} from './wasm-common'; @@ -77,50 +77,44 @@ const getModelInputOutputCount = (trainingSessionId: number, isEvalModel: boolea }; const getModelInputOutputNamesLoop = - (trainingSessionId: number, count: number, isInput: boolean, isEvalModel: boolean): [string[], number[]] => { + (trainingSessionId: number, count: number, isInput: boolean, isEvalModel: boolean): string[] => { const names = []; const wasm = getInstance(); - const namesUTF8Encoded = []; - for (let i = 0; i < count; i++) { if (wasm._OrtTrainingGetModelInputOutputName) { const name = wasm._OrtTrainingGetModelInputOutputName(trainingSessionId, i, isInput, isEvalModel); ifErrCodeCheckLastError(name, `Can't get input or output name -- is input: ${isInput}, index ${i}`, false); - namesUTF8Encoded.push(name); names.push(wasm.UTF8ToString(name)); + wasm._free(name); } else { throw new Error(NO_TRAIN_FUNCS_MSG); } } - return [names, namesUTF8Encoded]; + return names; }; -const getTrainingModelInputOutputNames = (trainingSessionId: number): [string[], number[], string[], number[]] => { - const [inputCount, outputCount] = getModelInputOutputCount(trainingSessionId, false); +export const getModelInputOutputNames = (trainingSessionId: number, isEvalModel: boolean): [string[], string[]] => { + let inputNames: string[] = []; + let outputNames: string[] = []; + + const [inputCount, outputCount] = getModelInputOutputCount(trainingSessionId, isEvalModel); - const [inputNames, inputNamesUTF8Encoded] = getModelInputOutputNamesLoop(trainingSessionId, inputCount, true, false); - const [outputNames, outputNamesUTF8Encoded] = - getModelInputOutputNamesLoop(trainingSessionId, outputCount, false, false); + inputNames = getModelInputOutputNamesLoop(trainingSessionId, inputCount, true, isEvalModel); + outputNames = getModelInputOutputNamesLoop(trainingSessionId, outputCount, false, isEvalModel); - return [inputNames, inputNamesUTF8Encoded, outputNames, outputNamesUTF8Encoded]; + return [inputNames, outputNames]; }; export const createTrainingSessionHandle = (checkpointHandle: number, trainModelData: SerializableModeldata, evalModelData: SerializableModeldata, - optimizerModelData: SerializableModeldata, - options: InferenceSession.SessionOptions): [SerializableSessionMetadata, number[], number[]] => { + optimizerModelData: SerializableModeldata, options: InferenceSession.SessionOptions): number => { const wasm = getInstance(); let trainingSessionHandle = 0; let sessionOptionsHandle = 0; let allocs: number[] = []; - let inputNamesUTF8Encoded: number[] = []; - let outputNamesUTF8Encoded: number[] = []; - - let inputNames: string[] = []; - let outputNames: string[] = []; try { [sessionOptionsHandle, allocs] = setSessionOptions(options); @@ -133,11 +127,7 @@ export const createTrainingSessionHandle = } ifErrCodeCheckLastError(trainingSessionHandle, 'Error occurred when trying to create a TrainingSession', false); - - [inputNames, inputNamesUTF8Encoded, outputNames, outputNamesUTF8Encoded] = - getTrainingModelInputOutputNames(trainingSessionHandle); - return [[trainingSessionHandle, inputNames, outputNames], inputNamesUTF8Encoded, outputNamesUTF8Encoded]; - + return trainingSessionHandle; } catch (e) { if (wasm._OrtTrainingReleaseSession && trainingSessionHandle !== 0) { wasm._OrtTrainingReleaseSession(trainingSessionHandle); @@ -152,8 +142,6 @@ export const createTrainingSessionHandle = wasm._OrtReleaseSessionOptions(sessionOptionsHandle); } allocs.forEach(alloc => wasm._free(alloc)); - inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); - outputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); } }; @@ -317,6 +305,83 @@ export const runTrainStep = async( } }; +export const runOptimizerStep = + async(trainingSessionId: number, options: InferenceSession.RunOptions): Promise => { + const wasm = getInstance(); + + let runOptionsHandle = 0; + let runOptionsAllocs: number[] = []; + + try { + [runOptionsHandle, runOptionsAllocs] = setRunOptions(options); + + if (wasm._OrtTrainingOptimizerStep) { + const errCode = wasm._OrtTrainingOptimizerStep(trainingSessionId, runOptionsHandle); + ifErrCodeCheckLastError(errCode, 'Failed to call OrtTrainingOptimizerStep in the WebAssembly layer'); + } else { + throw new Error(NO_TRAIN_FUNCS_MSG); + } + } finally { + if (runOptionsHandle !== 0) { + wasm._OrtReleaseRunOptions(runOptionsHandle); + } + runOptionsAllocs.forEach(p => wasm._free(p)); + } +}; + +export const runEvalStep = async( + trainingSessionId: number, inputIndices: number[], inputTensors: TensorMetadata[], outputIndices: number[], + outputTensors: Array, options: InferenceSession.RunOptions): Promise => { + const wasm = getInstance(); + + const inputCount = inputIndices.length; + const outputCount = outputIndices.length; + + let runOptionsHandle = 0; + let runOptionsAllocs: number[] = []; + + const inputTensorHandles: number[] = []; + const outputTensorHandles: number[] = []; + const inputOutputAllocs: number[] = []; + + const beforeRunStack = wasm.stackSave(); + + try { + // prepare parameters by moving them to heap + [runOptionsHandle, runOptionsAllocs] = setRunOptions(options); + + // handle inputs -- you don't want anything added to the index + const inputValuesOffset = createAndAllocateTensors( + trainingSessionId, inputIndices, inputTensors, inputTensorHandles, inputOutputAllocs, 0); + // handle outputs + // you want inputCount to be added to the index of every output tensor passed to prepareInputOutputTensor + const outputValuesOffset = createAndAllocateTensors( + trainingSessionId, outputIndices, outputTensors, outputTensorHandles, inputOutputAllocs, inputCount); + + if (wasm._OrtTrainingEvalStep) { + const errorCode = wasm._OrtTrainingEvalStep( + trainingSessionId, inputValuesOffset, inputCount, outputValuesOffset, outputCount, runOptionsHandle); + + ifErrCodeCheckLastError(errorCode, 'failed to call OrtTrainingEvalStep in the WebAssembly layer'); + } else { + throw new Error(NO_TRAIN_FUNCS_MSG); + } + + return moveOutputToTensorMetadataArr(outputValuesOffset, outputCount, outputTensorHandles, outputTensors); + } finally { + wasm.stackRestore(beforeRunStack); + + inputTensorHandles.forEach(v => wasm._OrtReleaseTensor(v)); + outputTensorHandles.forEach(v => wasm._OrtReleaseTensor(v)); + inputOutputAllocs.forEach(p => wasm._free(p)); + + if (runOptionsHandle !== 0) { + wasm._OrtReleaseRunOptions(runOptionsHandle); + } + runOptionsAllocs.forEach(p => wasm._free(p)); + } +}; + export const getParametersSize = (trainingSessionId: number, trainableOnly: boolean): number => { const wasm = getInstance(); const stack = wasm.stackSave(); @@ -439,17 +504,13 @@ export const loadParametersBuffer = } }; -export const releaseTrainingSessionAndCheckpoint = - (checkpointId: number, sessionId: number, inputNamesUTF8Encoded: number[], outputNamesUTF8Encoded: number[]): - void => { - const wasm = getInstance(); - inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); - outputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); +export const releaseTrainingSessionAndCheckpoint = (checkpointId: number, sessionId: number): void => { + const wasm = getInstance(); - if (wasm._OrtTrainingReleaseSession) { - wasm._OrtTrainingReleaseSession(sessionId); - } - if (wasm._OrtTrainingReleaseCheckpoint) { - wasm._OrtTrainingReleaseCheckpoint(checkpointId); - } - }; + if (wasm._OrtTrainingReleaseSession) { + wasm._OrtTrainingReleaseSession(sessionId); + } + if (wasm._OrtTrainingReleaseCheckpoint) { + wasm._OrtTrainingReleaseCheckpoint(checkpointId); + } +};