Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[js/web/training] Implemented runEvalStep & runOptimizerStep #18259

Merged
merged 16 commits into from
Dec 4, 2023
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions js/common/lib/backend.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,16 @@ export interface InferenceSessionHandler extends SessionHandler {
* @ignore
*/
export interface TrainingSessionHandler extends SessionHandler {
readonly evalInputNames: readonly string[];
readonly evalOutputNames: readonly string[];

runTrainStep(
feeds: SessionHandler.FeedsType, fetches: SessionHandler.FetchesType,
options: InferenceSession.RunOptions): Promise<SessionHandler.ReturnType>;
runOptimizerStep(options: InferenceSession.RunOptions): Promise<void>;
runEvalStep(
feeds: SessionHandler.FeedsType, fetches: SessionHandler.FetchesType,
options: InferenceSession.RunOptions): Promise<SessionHandler.ReturnType>;

getParametersSize(trainableOnly: boolean): Promise<number>;
loadParametersBuffer(array: Uint8Array, trainableOnly: boolean): Promise<void>;
Expand Down
68 changes: 57 additions & 11 deletions js/common/lib/training-session-impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,37 @@ const noBackendErrMsg: string = 'Training backend could not be resolved. ' +
'Make sure you\'re using the correct configuration & WebAssembly files.';

export class TrainingSession implements TrainingSessionInterface {
private constructor(handler: TrainingSessionHandler) {
private constructor(handler: TrainingSessionHandler, hasOptimizerModel: boolean, hasEvalModel: boolean) {
this.handler = handler;
this.hasOptimizerModel = hasOptimizerModel;
this.hasEvalModel = hasEvalModel;
}
private handler: TrainingSessionHandler;
private hasOptimizerModel: boolean;
private hasEvalModel: boolean;

get inputNames(): readonly string[] {
get trainingInputNames(): readonly string[] {
return this.handler.inputNames;
}
get outputNames(): readonly string[] {
get trainingOutputNames(): readonly string[] {
return this.handler.outputNames;
}

get evalInputNames(): readonly string[] {
if (this.hasEvalModel) {
return this.handler.evalInputNames;
} else {
throw new Error('This training session has no evalModel loaded.');
}
}
get evalOutputNames(): readonly string[] {
if (this.hasEvalModel) {
return this.handler.evalOutputNames;
} else {
throw new Error('This training session has no evalModel loaded.');
}
}

static async create(trainingOptions: TrainingSessionCreateOptions, sessionOptions?: SessionOptions):
Promise<TrainingSession> {
const evalModel: string|Uint8Array = trainingOptions.evalModel || '';
Expand All @@ -43,7 +62,7 @@ export class TrainingSession implements TrainingSessionInterface {
if (backend.createTrainingSessionHandler) {
const handler = await backend.createTrainingSessionHandler(
trainingOptions.checkpointState, trainingOptions.trainModel, evalModel, optimizerModel, options);
return new TrainingSession(handler);
return new TrainingSession(handler, !!trainingOptions.optimizerModel, !!trainingOptions.evalModel);
} else {
throw new Error(noBackendErrMsg);
}
Expand All @@ -53,13 +72,18 @@ export class TrainingSession implements TrainingSessionInterface {
* Helper function for runTrainStep and future runStep methods that handles the type-narrowing conversion from
* the given parameters to SessionHandler.FetchesType and RunOptions.
*
* @param inputNames the feeds object is checked that they contain all input names in the provided list of input
* names.
* @param outputNames the fetches object is checked that their keys match up with valid names in the list of output
* names.
* @param feeds the required input
* @param arg1 narrowed & converted into the SessionHandler.FetchesType or RunOptions object
* @param arg2 optional RunOptions object.
* @returns
*/
typeNarrowingForRunStep(feeds: FeedsType, arg1?: FetchesType|RunOptions, arg2?: RunOptions):
[SessionHandler.FetchesType, RunOptions] {
typeNarrowingForRunStep(
inputNames: readonly string[], outputNames: readonly string[], feeds: FeedsType, arg1?: FetchesType|RunOptions,
arg2?: RunOptions): [SessionHandler.FetchesType, RunOptions] {
const fetches: {[name: string]: OnnxValue|null} = {};
let options: RunOptions = {};
// check inputs
Expand Down Expand Up @@ -88,7 +112,7 @@ export class TrainingSession implements TrainingSessionInterface {
if (typeof name !== 'string') {
throw new TypeError('\'fetches\' must be a string array or an object.');
}
if (this.outputNames.indexOf(name) === -1) {
if (outputNames.indexOf(name) === -1) {
throw new RangeError(`'fetches' contains invalid output name: ${name}.`);
}
fetches[name] = null;
Expand All @@ -104,7 +128,7 @@ export class TrainingSession implements TrainingSessionInterface {
// if any output name is present and its value is valid OnnxValue, we consider it fetches
let isFetches = false;
const arg1Keys = Object.getOwnPropertyNames(arg1);
for (const name of this.outputNames) {
for (const name of outputNames) {
if (arg1Keys.indexOf(name) !== -1) {
const v = (arg1 as InferenceSession.NullableOnnxValueMapType)[name];
if (v === null || v instanceof Tensor) {
Expand All @@ -130,15 +154,15 @@ export class TrainingSession implements TrainingSessionInterface {
}

// check if all inputs are in feed
for (const name of this.inputNames) {
for (const name of inputNames) {
if (typeof feeds[name] === 'undefined') {
throw new Error(`input '${name}' is missing in 'feeds'.`);
}
}

// if no fetches is specified, we use the full output names list
if (isFetchesEmpty) {
for (const name of this.outputNames) {
for (const name of outputNames) {
fetches[name] = null;
}
}
Expand Down Expand Up @@ -171,11 +195,33 @@ export class TrainingSession implements TrainingSessionInterface {
runTrainStep(feeds: FeedsType, options?: RunOptions): Promise<ReturnType>;
runTrainStep(feeds: FeedsType, fetches: FetchesType, options?: RunOptions): Promise<ReturnType>;
async runTrainStep(feeds: FeedsType, arg1?: FetchesType|RunOptions, arg2?: RunOptions): Promise<ReturnType> {
const [fetches, options] = this.typeNarrowingForRunStep(feeds, arg1, arg2);
const [fetches, options] =
this.typeNarrowingForRunStep(this.trainingInputNames, this.trainingOutputNames, feeds, arg1, arg2);
const results = await this.handler.runTrainStep(feeds, fetches, options);
return this.convertHandlerReturnTypeToMapOfTensors(results);
}

async runOptimizerStep(options?: InferenceSession.RunOptions|undefined): Promise<void> {
if (this.hasOptimizerModel) {
await this.handler.runOptimizerStep(options || {});
} else {
throw new Error('This TrainingSession has no OptimizerModel loaded.');
}
}

runEvalStep(feeds: FeedsType, options?: RunOptions|undefined): Promise<ReturnType>;
runEvalStep(feeds: FeedsType, fetches: FetchesType, options?: RunOptions|undefined): Promise<ReturnType>;
async runEvalStep(feeds: FeedsType, arg1?: FetchesType|RunOptions, arg2?: RunOptions): Promise<ReturnType> {
if (this.hasEvalModel) {
const [fetches, options] =
this.typeNarrowingForRunStep(this.evalInputNames, this.evalOutputNames, feeds, arg1, arg2);
const results = await this.handler.runEvalStep(feeds, fetches, options);
return this.convertHandlerReturnTypeToMapOfTensors(results);
} else {
throw new Error('This TrainingSession has no EvalModel loaded.');
}
}

async getParametersSize(trainableOnly = true): Promise<number> {
return this.handler.getParametersSize(trainableOnly);
}
Expand Down
53 changes: 48 additions & 5 deletions js/common/lib/training-session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,46 @@ export interface TrainingSession {
* @param feeds - Representation of the model input.
* @param fetches - Representation of the model output.
* detail.
* @param options - Optional. A set of options that controls the behavior of model inference.
* @param options - Optional. A set of options that controls the behavior of model training.
* @returns A promise that resolves to a map, which uses output names as keys and OnnxValue as corresponding
values.
*/
runTrainStep(
feeds: InferenceSession.FeedsType, fetches: InferenceSession.FetchesType,
options?: InferenceSession.RunOptions): Promise<InferenceSession.ReturnType>;

/**
* Runs a single optimizer step, which performs weight updates for the trainable parameters using the optimizer model.
*
* @param options - Optional. A set of options that controls the behavior of model optimizing.
*/
runOptimizerStep(options?: InferenceSession.RunOptions): Promise<void>;

/**
* Run a single eval step with the given inputs and options using the eval model.
*
* @param feeds - Representation of the model input.
* @param options - Optional. A set of options that controls the behavior of model eval step.
* @returns A promise that resolves to a map, which uses output names as keys and OnnxValue as corresponding
values.
*/
runEvalStep(feeds: InferenceSession.FeedsType, options?: InferenceSession.RunOptions):
Promise<InferenceSession.ReturnType>;

/**
* Run a single eval step with the given inputs and options using the eval model.
*
* @param feeds - Representation of the model input.
* @param fetches - Representation of the model output.
* detail.
* @param options - Optional. A set of options that controls the behavior of model eval step.
* @returns A promise that resolves to a map, which uses output names as keys and OnnxValue as corresponding
values.
*/
runEvalStep(
feeds: InferenceSession.FeedsType, fetches: InferenceSession.FetchesType,
options?: InferenceSession.RunOptions): Promise<InferenceSession.ReturnType>;

// #endregion

// #region copy parameters
Expand Down Expand Up @@ -90,14 +122,25 @@ export interface TrainingSession {
// #region metadata

/**
* Get input names of the loaded model.
* Get input names of the loaded training model.
*/
readonly inputNames: readonly string[];
readonly trainingInputNames: readonly string[];

/**
* Get output names of the loaded model.
* Get output names of the loaded training model.
*/
readonly outputNames: readonly string[];
readonly trainingOutputNames: readonly string[];

/**
* Get input names of the loaded eval model. Is an empty array if no eval model is loaded.
*/
readonly evalInputNames: readonly string[];

/**
* Get output names of the loaded eval model. Is an empty array if no eval model is loaded.
*/
readonly evalOutputNames: readonly string[];

// #endregion
}

Expand Down
38 changes: 33 additions & 5 deletions js/web/lib/wasm/session-handler-training.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@
import {SerializableModeldata, TensorMetadata} from './proxy-messages';
import {decodeTensorMetadata, encodeTensorMetadata} from './session-handler-inference';
import {createSessionAllocate, initRuntime, isOrtEnvInitialized} from './wasm-core-impl';
<<<<<<< HEAD
Fixed Show fixed Hide fixed
import {createCheckpointHandle, createTrainingSessionHandle, getContiguousParameters, getModelInputOutputNames, getParametersSize, loadParametersBuffer, releaseTrainingSessionAndCheckpoint, runEvalStep, runOptimizerStep, runTrainStep} from './wasm-training-core-impl';
=======
Fixed Show fixed Hide fixed
import {createCheckpointHandle, createTrainingSessionHandle, getContiguousParameters, getParametersSize, loadParametersBuffer, releaseTrainingSessionAndCheckpoint, runTrainStep} from './wasm-training-core-impl';
>>>>>>> main
Fixed Show fixed Hide fixed

export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSessionHandler {
private sessionId: number;
Expand All @@ -15,8 +19,8 @@
inputNames: string[];
outputNames: string[];

inputEncodedNames: number[];
outputEncodedNames: number[];
evalInputNames: string[] = [];
evalOutputNames: string[] = [];

async uriOrBufferToHeap(uriOrBuffer: string|Uint8Array): Promise<SerializableModeldata> {
let buffer: Uint8Array;
Expand Down Expand Up @@ -51,8 +55,12 @@
}

this.checkpointId = createCheckpointHandle(checkpointData);
[[this.sessionId, this.inputNames, this.outputNames], this.inputEncodedNames, this.outputEncodedNames] =
this.sessionId =
createTrainingSessionHandle(this.checkpointId, trainModelData, evalModelData, optimizerModelData, options);
[this.inputNames, this.outputNames] = getModelInputOutputNames(this.sessionId, false);
if (evalModelUriOrBuffer !== '') {
[this.evalInputNames, this.evalOutputNames] = getModelInputOutputNames(this.sessionId, true);
}
}

/**
Expand Down Expand Up @@ -118,6 +126,27 @@
return this.convertTensorMetadataToReturnType(results, outputArray, outputIndices);
}

async runOptimizerStep(options: InferenceSession.RunOptions): Promise<void> {
await runOptimizerStep(this.sessionId, options);
}

async runEvalStep(
feeds: SessionHandler.FeedsType, fetches: SessionHandler.FetchesType,
options: InferenceSession.RunOptions): Promise<SessionHandler.ReturnType> {
const [, inputIndices, inputs] = this.convertMapIntoValuesArrayAndIndicesArray<Tensor, TensorMetadata>(
feeds, this.evalInputNames,
(t, i): TensorMetadata => encodeTensorMetadata(t, () => `input "${this.evalInputNames[inputIndices[i]]}"`));

const [outputArray, outputIndices, outputs] =
this.convertMapIntoValuesArrayAndIndicesArray<Tensor|null, TensorMetadata|null>(
fetches, this.evalOutputNames,
(t, i): TensorMetadata|null =>
t ? encodeTensorMetadata(t, () => `output "${this.evalOutputNames[outputIndices[i]]}"`) : null);

const results = await runEvalStep(this.sessionId, inputIndices, inputs, outputIndices, outputs, options);
return this.convertTensorMetadataToReturnType(results, outputArray, outputIndices);
}

async getParametersSize(trainableOnly: boolean): Promise<number> {
return getParametersSize(this.sessionId, trainableOnly);
}
Expand All @@ -131,7 +160,6 @@
}

async dispose(): Promise<void> {
return releaseTrainingSessionAndCheckpoint(
this.checkpointId, this.sessionId, this.inputEncodedNames, this.outputEncodedNames);
return releaseTrainingSessionAndCheckpoint(this.checkpointId, this.sessionId);
}
}
Loading
Loading