[js/web/training] runEvalStep & runOptimizerStep implementations #18117

Closed
wants to merge 35 commits into from
35 commits
cff6a7f
added training build configuration
carzh Aug 29, 2023
7a1365d
edited wasm factory + added training backend
carzh Aug 29, 2023
71f9dbc
added package.json modification for training + minimized training art…
carzh Aug 31, 2023
5adbcf4
applied suggestions
carzh Sep 22, 2023
9068ab4
create training session implementation + supporting changes
carzh Oct 6, 2023
edbe9cc
Merge remote-tracking branch 'origin/main' into carzh/training-wasm-b…
fs-eire Oct 6, 2023
5202e3b
fixed variable names + enforced wasmBackend being a singleton with su…
carzh Oct 6, 2023
a277347
format + lint
carzh Oct 7, 2023
adfa5cf
Merge branch 'main' into carzh/training-wasm-binding
carzh Oct 11, 2023
c9931ad
Merge branch 'carzh/training-wasm-binding' into carzh/create-training…
carzh Oct 11, 2023
e775933
minor tweak to remove placeholder return statement
carzh Oct 11, 2023
c745a65
format
carzh Oct 11, 2023
beca6dc
lint fixes
carzh Oct 11, 2023
b37e77f
lint + format
carzh Oct 11, 2023
8c1e227
Merge branch 'main' into carzh/create-training-session
carzh Oct 12, 2023
d2f3d88
added isOrtEnvInitialized case to proxy wrapper
carzh Oct 12, 2023
712e73c
Merge remote-tracking branch 'refs/remotes/origin/carzh/create-traini…
carzh Oct 12, 2023
05a708f
fixed proxy wrapper case statement
carzh Oct 13, 2023
46a9677
working runTrainStep implementation
carzh Oct 11, 2023
3ca1c27
light refactoring
carzh Oct 17, 2023
daa1023
updated getInputOutputCount and getInputOutputNames signature, added …
carzh Oct 23, 2023
389f08e
updated parameter names according to suggestions
carzh Oct 23, 2023
70505c4
format + lint
carzh Oct 24, 2023
a5bc60c
lint fix -- changed multiline string to singlequote string
carzh Oct 24, 2023
b44b705
getContiguousParameters & loadParametersBuffer impl
carzh Oct 23, 2023
c74112e
lint + format + added error code checking wrapper
carzh Oct 24, 2023
f425083
added runOptimizerStep and runEvalStep to interface
carzh Oct 25, 2023
cd2e4fa
enforced that loadParametersBuffer takes in a buffer that matches the…
carzh Oct 25, 2023
fa9f9ee
Merge branch 'carzh/web-parameters-methods' into carzh/web-runstep-me…
carzh Oct 25, 2023
799f54d
wrote evalStep and optimizerStep impl + refactoring similarities
carzh Oct 25, 2023
ae936e5
Apply naming suggestions from code review
carzh Oct 25, 2023
316a6e7
implemented naming suggestions + fixed training session & checkpoint …
carzh Oct 25, 2023
0af7d50
Merge branch 'carzh/create-training-session' into carzh/web-runstep-m…
carzh Oct 25, 2023
dcdb62d
evalModel restructuring updates :(
carzh Oct 26, 2023
48c9db6
updated release method + format + lint
carzh Oct 26, 2023
13 changes: 11 additions & 2 deletions js/common/lib/backend.ts
@@ -45,12 +45,21 @@ export interface InferenceSessionHandler extends SessionHandler {
* @ignore
*/
export interface TrainingSessionHandler extends SessionHandler {
readonly evalInputNames: readonly string[];
readonly evalOutputNames: readonly string[];

runTrainStep(
feeds: SessionHandler.FeedsType, fetches: SessionHandler.FetchesType,
options: InferenceSession.RunOptions): Promise<SessionHandler.ReturnType>;

loadParametersBuffer(array: Uint8Array, trainableOnly: boolean): Promise<void>;
getContiguousParameters(trainableOnly: boolean): Promise<Uint8Array>;
getParametersSize(trainableOnly: boolean): Promise<number>;
loadParametersBuffer(array: Float32Array, trainableOnly: boolean): Promise<void>;
getContiguousParameters(trainableOnly: boolean): Promise<OnnxValue>;

runOptimizerStep(options: InferenceSession.RunOptions): Promise<void>;
runEvalStep(
feeds: SessionHandler.FeedsType, fetches: SessionHandler.FetchesType,
options: InferenceSession.RunOptions): Promise<SessionHandler.ReturnType>;
}

/**
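The parameter methods declared on the handler above can be exercised end to end as follows. This is a minimal sketch, assuming a hypothetical in-memory `FakeParamsHandler` and a hypothetical `loadChecked` helper (neither is part of onnxruntime). It copies only the three parameter signatures from this diff, and it simplifies the `OnnxValue` returned by `getContiguousParameters` to a plain `Float32Array`.

```typescript
// Hypothetical subset of TrainingSessionHandler, for illustration only.
interface ParamsHandler {
  getParametersSize(trainableOnly: boolean): Promise<number>;
  loadParametersBuffer(array: Float32Array, trainableOnly: boolean): Promise<void>;
  // The PR returns an OnnxValue here; a Float32Array keeps the sketch self-contained.
  getContiguousParameters(trainableOnly: boolean): Promise<Float32Array>;
}

// In-memory stand-in: three trainable parameters followed by two frozen ones.
class FakeParamsHandler implements ParamsHandler {
  private trainable = new Float32Array([0.1, 0.2, 0.3]);
  private frozen = new Float32Array([1, 2]);

  async getParametersSize(trainableOnly: boolean): Promise<number> {
    return trainableOnly ? this.trainable.length : this.trainable.length + this.frozen.length;
  }

  async loadParametersBuffer(array: Float32Array, trainableOnly: boolean): Promise<void> {
    this.trainable.set(array.subarray(0, this.trainable.length));
    if (!trainableOnly) {
      this.frozen.set(array.subarray(this.trainable.length));
    }
  }

  async getContiguousParameters(trainableOnly: boolean): Promise<Float32Array> {
    const out = new Float32Array(await this.getParametersSize(trainableOnly));
    out.set(this.trainable, 0);
    if (!trainableOnly) {
      out.set(this.frozen, this.trainable.length);
    }
    return out;
  }
}

// Length check in the style of loadParametersBuffer in training-session-impl.ts:
// the buffer must hold exactly as many elements as getParametersSize reports.
async function loadChecked(handler: ParamsHandler, array: Float32Array, trainableOnly: boolean): Promise<void> {
  const expected = await handler.getParametersSize(trainableOnly);
  if (array.length !== expected) {
    throw new Error(`Expected ${expected} parameters, got ${array.length}.`);
  }
  await handler.loadParametersBuffer(array, trainableOnly);
}
```

Under these assumptions, a caller would first query `getParametersSize` with the same `trainableOnly` flag it intends to pass to `loadParametersBuffer`, since the expected length differs between the two modes.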
219 changes: 201 additions & 18 deletions js/common/lib/training-session-impl.ts
@@ -1,46 +1,229 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {TrainingSessionHandler} from './backend.js';
import {resolveBackend} from './backend-impl.js';
import {SessionHandler, TrainingSessionHandler} from './backend.js';
import {InferenceSession} from './inference-session.js';
import {OnnxValue} from './onnx-value.js';
import {Tensor} from './tensor.js';
import {TrainingSession as TrainingSessionInterface, TrainingSessionCreateOptions} from './training-session.js';

type SessionOptions = InferenceSession.SessionOptions;
type FeedsType = InferenceSession.FeedsType;
type FetchesType = InferenceSession.FetchesType;
type ReturnType = InferenceSession.ReturnType;
type RunOptions = InferenceSession.RunOptions;

const noBackendErrMsg: string = 'Training backend could not be resolved. ' +
'Make sure you\'re using the correct configuration & WebAssembly files.';

export class TrainingSession implements TrainingSessionInterface {
private constructor(handler: TrainingSessionHandler) {
private constructor(handler: TrainingSessionHandler, hasOptimizerModel: boolean, hasEvalModel: boolean) {
this.handler = handler;
this.hasOptimizerModel = hasOptimizerModel;
this.hasEvalModel = hasEvalModel;
}
private handler: TrainingSessionHandler;
private hasOptimizerModel: boolean;
private hasEvalModel: boolean;

get inputNames(): readonly string[] {
get trainingInputNames(): readonly string[] {
return this.handler.inputNames;
}
get outputNames(): readonly string[] {
get trainingOutputNames(): readonly string[] {
return this.handler.outputNames;
}

static async create(_trainingOptions: TrainingSessionCreateOptions, _sessionOptions?: SessionOptions):
get evalInputNames(): readonly string[] {
if (this.hasEvalModel) {
return this.handler.evalInputNames;
} else {
throw new Error('This training session has no evalModel loaded.');
}
}
get evalOutputNames(): readonly string[] {
if (this.hasEvalModel) {
return this.handler.evalOutputNames;
} else {
throw new Error('This training session has no evalModel loaded.');
}
}

static async create(trainingOptions: TrainingSessionCreateOptions, sessionOptions?: SessionOptions):
Promise<TrainingSession> {
throw new Error('Method not implemented');
const evalModel: string|Uint8Array = trainingOptions.evalModel || '';
const optimizerModel: string|Uint8Array = trainingOptions.optimizerModel || '';
const options: SessionOptions = sessionOptions || {};

// get backend hints
const eps = options.executionProviders || [];
const backendHints = eps.map(i => typeof i === 'string' ? i : i.name);
const backend = await resolveBackend(backendHints);
if (backend.createTrainingSessionHandler) {
const handler = await backend.createTrainingSessionHandler(
trainingOptions.checkpointState, trainingOptions.trainModel, evalModel, optimizerModel, options);
return new TrainingSession(
handler, trainingOptions.optimizerModel ? true : false, trainingOptions.evalModel ? true : false);
} else {
throw new Error(noBackendErrMsg);
}
}

/**
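* Helper function for the run methods. Validates `feeds` and resolves the overloaded
* arguments into explicit fetches and run options.
*
* @param feeds - Map that uses input names as keys and OnnxValue as corresponding values.
* @param arg1 - Either the fetches (an array of output names, or a map keyed by output name) or the run options.
* @param arg2 - The run options, used only when `arg1` is the fetches.
* @returns A tuple of the resolved fetches and run options.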
*/
typeNarrowingForRunStep(feeds: FeedsType, arg1?: FetchesType|RunOptions, arg2?: RunOptions):
[SessionHandler.FetchesType, RunOptions] {
const fetches: {[name: string]: OnnxValue|null} = {};
let options: RunOptions = {};
// check inputs
if (typeof feeds !== 'object' || feeds === null || feeds instanceof Tensor || Array.isArray(feeds)) {
throw new TypeError(
'\'feeds\' must be an object that uses input names as keys and OnnxValue as corresponding values.');
}

let isFetchesEmpty = true;
// determine which overload is being used
if (typeof arg1 === 'object') {
if (arg1 === null) {
throw new TypeError('Unexpected argument[1]: cannot be null.');
}
if (arg1 instanceof Tensor) {
throw new TypeError('\'fetches\' cannot be a Tensor');
}

if (Array.isArray(arg1)) {
if (arg1.length === 0) {
throw new TypeError('\'fetches\' cannot be an empty array.');
}
isFetchesEmpty = false;
// output names
for (const name of arg1) {
if (typeof name !== 'string') {
throw new TypeError('\'fetches\' must be a string array or an object.');
}
if (this.trainingOutputNames.indexOf(name) === -1) {
throw new RangeError(`'fetches' contains invalid output name: ${name}.`);
}
fetches[name] = null;
}

if (typeof arg2 === 'object' && arg2 !== null) {
options = arg2;
} else if (typeof arg2 !== 'undefined') {
throw new TypeError('\'options\' must be an object.');
}
} else {
// decide whether arg1 is fetches or options
// if any output name is present and its value is valid OnnxValue, we consider it fetches
let isFetches = false;
const arg1Keys = Object.getOwnPropertyNames(arg1);
for (const name of this.trainingOutputNames) {
if (arg1Keys.indexOf(name) !== -1) {
const v = (arg1 as InferenceSession.NullableOnnxValueMapType)[name];
if (v === null || v instanceof Tensor) {
isFetches = true;
isFetchesEmpty = false;
fetches[name] = v;
}
}
}

if (isFetches) {
if (typeof arg2 === 'object' && arg2 !== null) {
options = arg2;
} else if (typeof arg2 !== 'undefined') {
throw new TypeError('\'options\' must be an object.');
}
} else {
options = arg1 as RunOptions;
}
}
} else if (typeof arg1 !== 'undefined') {
throw new TypeError('Unexpected argument[1]: must be \'fetches\' or \'options\'.');
}

// check if all inputs are in feed
for (const name of this.trainingInputNames) {
if (typeof feeds[name] === 'undefined') {
throw new Error(`input '${name}' is missing in 'feeds'.`);
}
}

// if no fetches is specified, we use the full output names list
if (isFetchesEmpty) {
for (const name of this.trainingOutputNames) {
fetches[name] = null;
}
}

return [fetches, options];
}

processHandlerReturnToSessionReturn(results: SessionHandler.ReturnType): ReturnType {
const returnValue: {[name: string]: OnnxValue} = {};
for (const key in results) {
if (Object.hasOwnProperty.call(results, key)) {
const result = results[key];
if (result instanceof Tensor) {
returnValue[key] = result;
} else {
returnValue[key] = new Tensor(result.type, result.data, result.dims);
}
}
}
return returnValue;
}

runTrainStep(feeds: FeedsType, options?: RunOptions): Promise<ReturnType>;
runTrainStep(feeds: FeedsType, fetches: FetchesType, options?: RunOptions): Promise<ReturnType>;
async runTrainStep(feeds: FeedsType, arg1?: FetchesType|RunOptions, arg2?: RunOptions): Promise<ReturnType> {
const [fetches, options] = this.typeNarrowingForRunStep(feeds, arg1, arg2);
const results = await this.handler.runTrainStep(feeds, fetches, options);
return this.processHandlerReturnToSessionReturn(results);
}

async runOptimizerStep(options?: InferenceSession.RunOptions|undefined): Promise<void> {
if (this.hasOptimizerModel) {
await this.handler.runOptimizerStep(options || {});
} else {
throw new Error('This TrainingSession has no OptimizerModel loaded.');
}
}

runEvalStep(feeds: FeedsType, options?: RunOptions|undefined): Promise<ReturnType>;
runEvalStep(feeds: FeedsType, fetches: FetchesType, options?: RunOptions|undefined): Promise<ReturnType>;
async runEvalStep(feeds: FeedsType, arg1?: FetchesType|RunOptions, arg2?: RunOptions): Promise<ReturnType> {
if (this.hasEvalModel) {
const [fetches, options] = this.typeNarrowingForRunStep(feeds, arg1, arg2);
const results = await this.handler.runEvalStep(feeds, fetches, options);
return this.processHandlerReturnToSessionReturn(results);
} else {
throw new Error('This TrainingSession has no EvalModel loaded.');
}
}

async loadParametersBuffer(_array: Uint8Array, _trainableOnly: boolean): Promise<void> {
throw new Error('Method not implemented.');
async getParametersSize(trainableOnly: boolean): Promise<number> {
return this.handler.getParametersSize(trainableOnly);
}

async getContiguousParameters(_trainableOnly: boolean): Promise<Uint8Array> {
throw new Error('Method not implemented.');
async loadParametersBuffer(array: Float32Array, trainableOnly: boolean): Promise<void> {
const paramsSize = await this.getParametersSize(trainableOnly);
if (array.length !== paramsSize) {
throw new Error(
'Size of the buffer passed into loadParametersBuffer must match the number of parameters in ' +
'the model. Please use getParametersSize method to check.');
}
return this.handler.loadParametersBuffer(array, trainableOnly);
}

runTrainStep(feeds: InferenceSession.OnnxValueMapType, options?: InferenceSession.RunOptions|undefined):
Promise<InferenceSession.OnnxValueMapType>;
runTrainStep(
feeds: InferenceSession.OnnxValueMapType, fetches: InferenceSession.FetchesType,
options?: InferenceSession.RunOptions|undefined): Promise<InferenceSession.OnnxValueMapType>;
async runTrainStep(_feeds: unknown, _fetches?: unknown, _options?: unknown):
Promise<InferenceSession.OnnxValueMapType> {
throw new Error('Method not implemented.');
async getContiguousParameters(trainableOnly: boolean): Promise<OnnxValue> {
return this.handler.getContiguousParameters(trainableOnly);
}

async release(): Promise<void> {
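The overload resolution in `typeNarrowingForRunStep` above is hard to follow in flattened diff form. The following is a simplified, standalone sketch of the same decision, assuming a hypothetical `FakeTensor` class in place of `Tensor` and a hypothetical free function `narrowRunArgs` that receives the output names explicitly (neither exists in onnxruntime). It deliberately omits the `feeds` validation and most of the `TypeError` checks from the diff, including the rejection of an empty fetches array.

```typescript
// Hypothetical stand-in for onnxruntime's Tensor, used only for the instanceof check.
class FakeTensor {
  constructor(public readonly data: number[]) {}
}

type Fetches = {[name: string]: FakeTensor|null};
type Options = {[key: string]: unknown};

// Decides whether the second argument is `fetches` or `options`.
function narrowRunArgs(
    outputNames: readonly string[], arg1?: string[]|Fetches|Options, arg2?: Options): [Fetches, Options] {
  const fetches: Fetches = {};
  let options: Options = {};

  if (Array.isArray(arg1)) {
    // Case 1: a list of output names, each of which must be known.
    for (const name of arg1) {
      if (outputNames.indexOf(name) === -1) {
        throw new RangeError(`'fetches' contains invalid output name: ${name}.`);
      }
      fetches[name] = null;
    }
    options = arg2 || {};
  } else if (typeof arg1 === 'object' && arg1 !== null) {
    // Case 2: an object. It counts as fetches only if at least one output name
    // maps to null or to a tensor; otherwise it is treated as the run options.
    const candidate = arg1 as {[key: string]: unknown};
    for (const name of outputNames) {
      const v = candidate[name];
      if (Object.prototype.hasOwnProperty.call(candidate, name) && (v === null || v instanceof FakeTensor)) {
        fetches[name] = v;
      }
    }
    options = Object.keys(fetches).length > 0 ? (arg2 || {}) : candidate;
  }

  // Case 3: nothing was requested, so every output is fetched.
  if (Object.keys(fetches).length === 0) {
    for (const name of outputNames) {
      fetches[name] = null;
    }
  }
  return [fetches, options];
}
```

One consequence of this heuristic, visible in the sketch, is that an options object sharing a key with an output name is only mistaken for fetches if that key's value is `null` or a tensor.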
65 changes: 58 additions & 7 deletions js/common/lib/training-session.ts
@@ -2,6 +2,7 @@
// Licensed under the MIT License.

import {InferenceSession} from './inference-session.js';
import {OnnxValue} from './onnx-value.js';
import {TrainingSession as TrainingSessionImpl} from './training-session-impl.js';

/* eslint-disable @typescript-eslint/no-redeclare */
@@ -38,32 +39,72 @@ export interface TrainingSession {
* @param feeds - Representation of the model input.
* @param fetches - Representation of the model output.
* @param options - Optional. A set of options that controls the behavior of model inference.
* @param options - Optional. A set of options that controls the behavior of model training.
* @returns A promise that resolves to a map, which uses output names as keys and OnnxValue as corresponding
values.
*/
runTrainStep(
feeds: InferenceSession.FeedsType, fetches: InferenceSession.FetchesType,
options?: InferenceSession.RunOptions): Promise<InferenceSession.ReturnType>;

/**
* Runs a single optimizer step, which performs weight updates for the trainable parameters using the optimizer model.
*
* @param options - Optional. A set of options that controls the behavior of the optimizer step.
*/
runOptimizerStep(options?: InferenceSession.RunOptions): Promise<void>;

/**
* Run a single eval step with the given inputs and options using the eval model.
*
* @param feeds - Representation of the model input.
* @param options - Optional. A set of options that controls the behavior of model eval step.
* @returns A promise that resolves to a map, which uses output names as keys and OnnxValue as corresponding
values.
*/
runEvalStep(feeds: InferenceSession.FeedsType, options?: InferenceSession.RunOptions):
Promise<InferenceSession.ReturnType>;

/**
* Run a single eval step with the given inputs and options using the eval model.
*
* @param feeds - Representation of the model input.
* @param fetches - Representation of the model output.
* @param options - Optional. A set of options that controls the behavior of model eval step.
* @returns A promise that resolves to a map, which uses output names as keys and OnnxValue as corresponding
values.
*/
runEvalStep(
feeds: InferenceSession.FeedsType, fetches: InferenceSession.FetchesType,
options?: InferenceSession.RunOptions): Promise<InferenceSession.ReturnType>;

// #endregion

// #region copy parameters

/**
* Retrieves the size of all parameters for the training state, as a number of parameters.
*
* @param trainableOnly - When true, skips non-trainable parameters.
*/
getParametersSize(trainableOnly: boolean): Promise<number>;

/**
* Copies from a buffer containing parameters to the TrainingSession parameters.
*
* @param array - Buffer containing the parameters to load.
* @param trainableOnly - True if only trainable parameters are to be modified, false otherwise.
*/
loadParametersBuffer(array: Uint8Array, trainableOnly: boolean): Promise<void>;
loadParametersBuffer(array: Float32Array, trainableOnly: boolean): Promise<void>;

/**
* Copies from the TrainingSession parameters to a buffer.
*
* @param trainableOnly - True if only trainable parameters are to be copied, false otherwise.
* @returns A promise that resolves to a buffer of the requested parameters.
*/
getContiguousParameters(trainableOnly: boolean): Promise<Uint8Array>;
getContiguousParameters(trainableOnly: boolean): Promise<OnnxValue>;
// #endregion

// #region release()
@@ -77,14 +118,24 @@ export interface TrainingSession {
// #region metadata

/**
* Get input names of the loaded model.
* Get input names of the loaded training model.
*/
readonly trainingInputNames: readonly string[];

/**
* Get output names of the loaded training model.
*/
readonly trainingOutputNames: readonly string[];

/**
* Get input names of the loaded eval model. Is an empty array if no eval model is loaded.
*/
readonly inputNames: readonly string[];
readonly evalInputNames: readonly string[];

/**
* Get output names of the loaded model.
* Get output names of the loaded eval model. Is an empty array if no eval model is loaded.
*/
readonly outputNames: readonly string[];
readonly evalOutputNames: readonly string[];
// #endregion
}

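How the three step methods documented in this interface fit together may be easier to see in a loop. This is a minimal sketch, assuming a hypothetical `StepSession` interface that copies only the step methods from this diff, a hypothetical `RecordingSession` stub that logs calls instead of running a model, and a hypothetical `trainEpochs` helper. Feeds and results are simplified to plain number maps. The ordering shown (a train step followed by an optimizer step per batch, then an eval step per epoch) is a common convention, not something this PR prescribes.

```typescript
// Hypothetical, simplified view of the step methods; not the real onnxruntime types.
type ValueMap = {[name: string]: number};

interface StepSession {
  runTrainStep(feeds: ValueMap): Promise<ValueMap>;
  runOptimizerStep(): Promise<void>;
  runEvalStep(feeds: ValueMap): Promise<ValueMap>;
}

// Stub that records the order of calls instead of running a model.
class RecordingSession implements StepSession {
  readonly calls: string[] = [];

  async runTrainStep(feeds: ValueMap): Promise<ValueMap> {
    this.calls.push('train');
    return {inputCount: Object.keys(feeds).length};
  }

  async runOptimizerStep(): Promise<void> {
    this.calls.push('optimizer');
  }

  async runEvalStep(feeds: ValueMap): Promise<ValueMap> {
    this.calls.push('eval');
    return {inputCount: Object.keys(feeds).length};
  }
}

// Runs every batch through a train step and an optimizer step,
// then evaluates once at the end of each epoch.
async function trainEpochs(
    session: StepSession, batches: ValueMap[], evalBatch: ValueMap, epochs: number): Promise<ValueMap[]> {
  const evalResults: ValueMap[] = [];
  for (let epoch = 0; epoch < epochs; epoch++) {
    for (const batch of batches) {
      await session.runTrainStep(batch);
      // Weight update for the trainable parameters, per the runOptimizerStep doc comment.
      await session.runOptimizerStep();
    }
    evalResults.push(await session.runEvalStep(evalBatch));
  }
  return evalResults;
}
```

With the real `TrainingSession` from this diff, `runOptimizerStep` and `runEvalStep` throw when the session was created without an optimizer or eval model, so a caller would need to supply both models at creation for a loop like this to work.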
2 changes: 1 addition & 1 deletion js/web/lib/backend-onnxjs.ts
@@ -5,7 +5,7 @@
import {Backend, InferenceSession, InferenceSessionHandler} from 'onnxruntime-common';

import {Session} from './onnxjs/session';
import {OnnxjsSessionHandler} from './onnxjs/session-handler';
import {OnnxjsSessionHandler} from './onnxjs/session-handler-inference';

class OnnxjsBackend implements Backend {
// eslint-disable-next-line @typescript-eslint/no-empty-function