Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[js/web/training] Implemented runEvalStep & runOptimizerStep #18259

Merged
merged 16 commits into from
Dec 4, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions js/common/lib/backend.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,20 @@ export interface InferenceSessionHandler extends SessionHandler {
* @ignore
*/
export interface TrainingSessionHandler extends SessionHandler {
readonly evalInputNames: readonly string[];
readonly evalOutputNames: readonly string[];

runTrainStep(
feeds: SessionHandler.FeedsType, fetches: SessionHandler.FetchesType,
options: InferenceSession.RunOptions): Promise<SessionHandler.ReturnType>;
runOptimizerStep(options: InferenceSession.RunOptions): Promise<void>;
runEvalStep(
feeds: SessionHandler.FeedsType, fetches: SessionHandler.FetchesType,
options: InferenceSession.RunOptions): Promise<SessionHandler.ReturnType>;

loadParametersBuffer(array: Uint8Array, trainableOnly: boolean): Promise<void>;
getContiguousParameters(trainableOnly: boolean): Promise<Uint8Array>;
getParametersSize(trainableOnly: boolean): Promise<number>;
loadParametersBuffer(array: Float32Array, trainableOnly: boolean): Promise<void>;
getContiguousParameters(trainableOnly: boolean): Promise<OnnxValue>;
}

/**
Expand Down
86 changes: 71 additions & 15 deletions js/common/lib/training-session-impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,37 @@ const noBackendErrMsg: string = 'Training backend could not be resolved. ' +
'Make sure you\'re using the correct configuration & WebAssembly files.';

export class TrainingSession implements TrainingSessionInterface {
private constructor(handler: TrainingSessionHandler) {
private constructor(handler: TrainingSessionHandler, hasOptimizerModel: boolean, hasEvalModel: boolean) {
this.handler = handler;
this.hasOptimizerModel = hasOptimizerModel;
this.hasEvalModel = hasEvalModel;
}
private handler: TrainingSessionHandler;
private hasOptimizerModel: boolean;
private hasEvalModel: boolean;

get inputNames(): readonly string[] {
get trainingInputNames(): readonly string[] {
return this.handler.inputNames;
}
get outputNames(): readonly string[] {
get trainingOutputNames(): readonly string[] {
return this.handler.outputNames;
}

get evalInputNames(): readonly string[] {
if (this.hasEvalModel) {
return this.handler.evalInputNames;
} else {
throw new Error('This training session has no evalModel loaded.');
}
}
get evalOutputNames(): readonly string[] {
if (this.hasEvalModel) {
return this.handler.evalOutputNames;
} else {
throw new Error('This training session has no evalModel loaded.');
}
}

static async create(trainingOptions: TrainingSessionCreateOptions, sessionOptions?: SessionOptions):
Promise<TrainingSession> {
const evalModel: string|Uint8Array = trainingOptions.evalModel || '';
Expand All @@ -43,7 +62,7 @@ export class TrainingSession implements TrainingSessionInterface {
if (backend.createTrainingSessionHandler) {
const handler = await backend.createTrainingSessionHandler(
trainingOptions.checkpointState, trainingOptions.trainModel, evalModel, optimizerModel, options);
return new TrainingSession(handler);
return new TrainingSession(handler, !!trainingOptions.optimizerModel, !!trainingOptions.evalModel);
} else {
throw new Error(noBackendErrMsg);
}
Expand All @@ -53,13 +72,18 @@ export class TrainingSession implements TrainingSessionInterface {
* Helper function for runTrainStep and future runStep methods that handles the type-narrowing conversion from
* the given parameters to SessionHandler.FetchesType and RunOptions.
*
* @param inputNames the feeds object is checked that they contain all input names in the provided list of input
* names.
* @param outputNames the fetches object is checked that their keys match up with valid names in the list of output
* names.
* @param feeds the required input
* @param arg1 narrowed & converted into the SessionHandler.FetchesType or RunOptions object
* @param arg2 optional RunOptions object.
* @returns
*/
typeNarrowingForRunStep(feeds: FeedsType, arg1?: FetchesType|RunOptions, arg2?: RunOptions):
[SessionHandler.FetchesType, RunOptions] {
typeNarrowingForRunStep(
inputNames: readonly string[], outputNames: readonly string[], feeds: FeedsType, arg1?: FetchesType|RunOptions,
arg2?: RunOptions): [SessionHandler.FetchesType, RunOptions] {
const fetches: {[name: string]: OnnxValue|null} = {};
let options: RunOptions = {};
// check inputs
Expand Down Expand Up @@ -88,7 +112,7 @@ export class TrainingSession implements TrainingSessionInterface {
if (typeof name !== 'string') {
throw new TypeError('\'fetches\' must be a string array or an object.');
}
if (this.outputNames.indexOf(name) === -1) {
if (outputNames.indexOf(name) === -1) {
throw new RangeError(`'fetches' contains invalid output name: ${name}.`);
}
fetches[name] = null;
Expand All @@ -104,7 +128,7 @@ export class TrainingSession implements TrainingSessionInterface {
// if any output name is present and its value is valid OnnxValue, we consider it fetches
let isFetches = false;
const arg1Keys = Object.getOwnPropertyNames(arg1);
for (const name of this.outputNames) {
for (const name of outputNames) {
if (arg1Keys.indexOf(name) !== -1) {
const v = (arg1 as InferenceSession.NullableOnnxValueMapType)[name];
if (v === null || v instanceof Tensor) {
Expand All @@ -130,15 +154,15 @@ export class TrainingSession implements TrainingSessionInterface {
}

// check if all inputs are in feed
for (const name of this.inputNames) {
for (const name of inputNames) {
if (typeof feeds[name] === 'undefined') {
throw new Error(`input '${name}' is missing in 'feeds'.`);
}
}

// if no fetches is specified, we use the full output names list
if (isFetchesEmpty) {
for (const name of this.outputNames) {
for (const name of outputNames) {
fetches[name] = null;
}
}
Expand Down Expand Up @@ -171,17 +195,49 @@ export class TrainingSession implements TrainingSessionInterface {
runTrainStep(feeds: FeedsType, options?: RunOptions): Promise<ReturnType>;
runTrainStep(feeds: FeedsType, fetches: FetchesType, options?: RunOptions): Promise<ReturnType>;
async runTrainStep(feeds: FeedsType, arg1?: FetchesType|RunOptions, arg2?: RunOptions): Promise<ReturnType> {
const [fetches, options] = this.typeNarrowingForRunStep(feeds, arg1, arg2);
const [fetches, options] =
this.typeNarrowingForRunStep(this.trainingInputNames, this.trainingOutputNames, feeds, arg1, arg2);
const results = await this.handler.runTrainStep(feeds, fetches, options);
return this.convertHandlerReturnTypeToMapOfTensors(results);
}

async loadParametersBuffer(_array: Uint8Array, _trainableOnly: boolean): Promise<void> {
throw new Error('Method not implemented.');
async runOptimizerStep(options?: InferenceSession.RunOptions|undefined): Promise<void> {
if (this.hasOptimizerModel) {
await this.handler.runOptimizerStep(options || {});
} else {
throw new Error('This TrainingSession has no OptimizerModel loaded.');
}
}

runEvalStep(feeds: FeedsType, options?: RunOptions|undefined): Promise<ReturnType>;
runEvalStep(feeds: FeedsType, fetches: FetchesType, options?: RunOptions|undefined): Promise<ReturnType>;
async runEvalStep(feeds: FeedsType, arg1?: FetchesType|RunOptions, arg2?: RunOptions): Promise<ReturnType> {
if (this.hasEvalModel) {
const [fetches, options] =
this.typeNarrowingForRunStep(this.evalInputNames, this.evalOutputNames, feeds, arg1, arg2);
const results = await this.handler.runEvalStep(feeds, fetches, options);
return this.convertHandlerReturnTypeToMapOfTensors(results);
} else {
throw new Error('This TrainingSession has no EvalModel loaded.');
}
}

async getParametersSize(trainableOnly: boolean): Promise<number> {
return this.handler.getParametersSize(trainableOnly);
}

async loadParametersBuffer(array: Float32Array, trainableOnly: boolean): Promise<void> {
const paramsSize = await this.getParametersSize(trainableOnly);
if (array.length !== paramsSize) {
throw new Error(
'Size of the buffer passed into loadParametersBuffer must match the number of parameters in ' +
'the model. Please use getParametersSize method to check.');
}
return this.handler.loadParametersBuffer(array, trainableOnly);
}

async getContiguousParameters(_trainableOnly: boolean): Promise<Uint8Array> {
throw new Error('Method not implemented.');
async getContiguousParameters(trainableOnly: boolean): Promise<OnnxValue> {
return this.handler.getContiguousParameters(trainableOnly);
}

async release(): Promise<void> {
Expand Down
74 changes: 64 additions & 10 deletions js/common/lib/training-session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT License.

import {InferenceSession} from './inference-session.js';
import {OnnxValue} from './onnx-value.js';
import {TrainingSession as TrainingSessionImpl} from './training-session-impl.js';

/* eslint-disable @typescript-eslint/no-redeclare */
Expand Down Expand Up @@ -38,32 +39,74 @@ export interface TrainingSession {
* @param feeds - Representation of the model input.
* @param fetches - Representation of the model output.
* detail.
* @param options - Optional. A set of options that controls the behavior of model inference.
* @param options - Optional. A set of options that controls the behavior of model training.
* @returns A promise that resolves to a map, which uses output names as keys and OnnxValue as corresponding
values.
*/
runTrainStep(
feeds: InferenceSession.FeedsType, fetches: InferenceSession.FetchesType,
options?: InferenceSession.RunOptions): Promise<InferenceSession.ReturnType>;

/**
* Runs a single optimizer step, which performs weight updates for the trainable parameters using the optimizer model.
*
* @param options - Optional. A set of options that controls the behavior of model optimizing.
*/
runOptimizerStep(options?: InferenceSession.RunOptions): Promise<void>;

/**
* Run a single eval step with the given inputs and options using the eval model.
*
* @param feeds - Representation of the model input.
* @param options - Optional. A set of options that controls the behavior of model eval step.
* @returns A promise that resolves to a map, which uses output names as keys and OnnxValue as corresponding
values.
*/
runEvalStep(feeds: InferenceSession.FeedsType, options?: InferenceSession.RunOptions):
Promise<InferenceSession.ReturnType>;

/**
* Run a single eval step with the given inputs and options using the eval model.
*
* @param feeds - Representation of the model input.
* @param fetches - Representation of the model output.
* detail.
* @param options - Optional. A set of options that controls the behavior of model eval step.
* @returns A promise that resolves to a map, which uses output names as keys and OnnxValue as corresponding
values.
*/
runEvalStep(
feeds: InferenceSession.FeedsType, fetches: InferenceSession.FetchesType,
options?: InferenceSession.RunOptions): Promise<InferenceSession.ReturnType>;

// #endregion

// #region copy parameters

/**
* Retrieves the size of all parameters for the training state. Calculates the total number of primitive (datatype of
* the parameters) elements of all the parameters in the training state.
*
* @param trainableOnly - When set to true, the size is calculated for trainable params only.
*/
getParametersSize(trainableOnly: boolean): Promise<number>;

/**
* Copies from a buffer containing parameters to the TrainingSession parameters.
* Copies parameter values from the given array to the training state.
*
* @param buffer - buffer containing parameters
* @param trainableOnly - True if trainable parameters only to be modified, false otherwise.
*/
loadParametersBuffer(array: Uint8Array, trainableOnly: boolean): Promise<void>;
loadParametersBuffer(array: Float32Array, trainableOnly: boolean): Promise<void>;

/**
* Copies from the TrainingSession parameters to a buffer.
* Copies from the TrainingSession parameters to a contiguous buffer.
*
* @param trainableOnly - True if trainable parameters only to be copied, false othrwise.
* @param trainableOnly - When set to true, only trainable parameters are copied. Trainable parameters are parameters
* for which requires_grad is set to true.
* @returns A promise that resolves to a buffer of the requested parameters.
*/
getContiguousParameters(trainableOnly: boolean): Promise<Uint8Array>;
getContiguousParameters(trainableOnly: boolean): Promise<OnnxValue>;
// #endregion

// #region release()
Expand All @@ -77,14 +120,25 @@ export interface TrainingSession {
// #region metadata

/**
* Get input names of the loaded model.
* Get input names of the loaded training model.
*/
readonly trainingInputNames: readonly string[];

/**
* Get output names of the loaded training model.
*/
readonly inputNames: readonly string[];
readonly trainingOutputNames: readonly string[];

/**
* Get output names of the loaded model.
* Get input names of the loaded eval model. Is an empty array if no eval model is loaded.
*/
readonly outputNames: readonly string[];
readonly evalInputNames: readonly string[];

/**
* Get output names of the loaded eval model. Is an empty array if no eval model is loaded.
*/
readonly evalOutputNames: readonly string[];

// #endregion
}

Expand Down
56 changes: 43 additions & 13 deletions js/web/lib/wasm/session-handler-training.ts
Original file line number Diff line number Diff line change
@@ -1,28 +1,22 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {env, InferenceSession, SessionHandler, Tensor, TrainingSessionHandler} from 'onnxruntime-common';
import {env, InferenceSession, OnnxValue, SessionHandler, Tensor, TrainingSessionHandler} from 'onnxruntime-common';

import {SerializableModeldata, TensorMetadata} from './proxy-messages';
import {decodeTensorMetadata, encodeTensorMetadata} from './session-handler-inference';
import {createSessionAllocate, initRuntime, isOrtEnvInitialized} from './wasm-core-impl';
import {createCheckpointHandle, createTrainingSessionHandle, releaseTrainingSessionAndCheckpoint, runTrainStep} from './wasm-training-core-impl';
import {createCheckpointHandle, createTrainingSessionHandle, getContiguousParameters, getModelInputOutputNames, getParametersSize, loadParametersBuffer, releaseTrainingSessionAndCheckpoint, runEvalStep, runOptimizerStep, runTrainStep} from './wasm-training-core-impl';

export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSessionHandler {
async loadParametersBuffer(_array: Uint8Array, _trainableOnly: boolean): Promise<void> {
throw new Error('Method not implemented.');
}
async getContiguousParameters(_trainableOnly: boolean): Promise<Uint8Array> {
throw new Error('Method not implemented.');
}
private sessionId: number;
private checkpointId: number;

inputNames: string[];
outputNames: string[];

inputEncodedNames: number[];
outputEncodedNames: number[];
evalInputNames: string[] = [];
evalOutputNames: string[] = [];

async uriOrBufferToHeap(uriOrBuffer: string|Uint8Array): Promise<SerializableModeldata> {
let buffer: Uint8Array;
Expand Down Expand Up @@ -57,8 +51,12 @@ export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSes
}

this.checkpointId = createCheckpointHandle(checkpointData);
[[this.sessionId, this.inputNames, this.outputNames], this.inputEncodedNames, this.outputEncodedNames] =
this.sessionId =
createTrainingSessionHandle(this.checkpointId, trainModelData, evalModelData, optimizerModelData, options);
[this.inputNames, this.outputNames] = getModelInputOutputNames(this.sessionId, false);
if (evalModelUriOrBuffer !== '') {
[this.evalInputNames, this.evalOutputNames] = getModelInputOutputNames(this.sessionId, true);
}
}

/**
Expand Down Expand Up @@ -124,8 +122,40 @@ export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSes
return this.convertTensorMetadataToReturnType(results, outputArray, outputIndices);
}

async runOptimizerStep(options: InferenceSession.RunOptions): Promise<void> {
await runOptimizerStep(this.sessionId, options);
}

async runEvalStep(
feeds: SessionHandler.FeedsType, fetches: SessionHandler.FetchesType,
options: InferenceSession.RunOptions): Promise<SessionHandler.ReturnType> {
const [, inputIndices, inputs] = this.convertMapIntoValuesArrayAndIndicesArray<Tensor, TensorMetadata>(
feeds, this.evalInputNames,
(t, i): TensorMetadata => encodeTensorMetadata(t, () => `input "${this.evalInputNames[inputIndices[i]]}"`));

const [outputArray, outputIndices, outputs] =
this.convertMapIntoValuesArrayAndIndicesArray<Tensor|null, TensorMetadata|null>(
fetches, this.evalOutputNames,
(t, i): TensorMetadata|null =>
t ? encodeTensorMetadata(t, () => `output "${this.evalOutputNames[outputIndices[i]]}"`) : null);

const results = await runEvalStep(this.sessionId, inputIndices, inputs, outputIndices, outputs, options);
return this.convertTensorMetadataToReturnType(results, outputArray, outputIndices);
}

async getParametersSize(trainableOnly: boolean): Promise<number> {
return getParametersSize(this.sessionId, trainableOnly);
}

async loadParametersBuffer(array: Float32Array, trainableOnly: boolean): Promise<void> {
await loadParametersBuffer(this.sessionId, array, trainableOnly);
}
async getContiguousParameters(trainableOnly: boolean): Promise<OnnxValue> {
const tensorResult = await getContiguousParameters(this.sessionId, trainableOnly);
return decodeTensorMetadata(tensorResult);
}

async dispose(): Promise<void> {
return releaseTrainingSessionAndCheckpoint(
this.checkpointId, this.sessionId, this.inputEncodedNames, this.outputEncodedNames);
return releaseTrainingSessionAndCheckpoint(this.checkpointId, this.sessionId);
}
}
Loading
Loading