Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[js/web/training] added parameters methods implementations #18086

Closed
wants to merge 23 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
cff6a7f
added training build configuration
carzh Aug 29, 2023
7a1365d
edited wasm factory + added training backend
carzh Aug 29, 2023
71f9dbc
added package.json modification for training + minimized training art…
carzh Aug 31, 2023
5adbcf4
applied suggestions
carzh Sep 22, 2023
9068ab4
create training session implementation + supporting changes
carzh Oct 6, 2023
edbe9cc
Merge remote-tracking branch 'origin/main' into carzh/training-wasm-b…
fs-eire Oct 6, 2023
5202e3b
fixed variable names + enforced wasmBackend being a singleton with su…
carzh Oct 6, 2023
a277347
format + lint
carzh Oct 7, 2023
adfa5cf
Merge branch 'main' into carzh/training-wasm-binding
carzh Oct 11, 2023
c9931ad
Merge branch 'carzh/training-wasm-binding' into carzh/create-training…
carzh Oct 11, 2023
e775933
minor tweak to remove placeholder return statement
carzh Oct 11, 2023
c745a65
format
carzh Oct 11, 2023
beca6dc
lint fixes
carzh Oct 11, 2023
b37e77f
lint + format
carzh Oct 11, 2023
8c1e227
Merge branch 'main' into carzh/create-training-session
carzh Oct 12, 2023
d2f3d88
added isOrtEnvInitialized case to proxy wrapper
carzh Oct 12, 2023
712e73c
Merge remote-tracking branch 'refs/remotes/origin/carzh/create-traini…
carzh Oct 12, 2023
05a708f
fixed proxy wrapper case statement
carzh Oct 13, 2023
46a9677
working runTrainStep implementation
carzh Oct 11, 2023
3ca1c27
light refactoring
carzh Oct 17, 2023
b44b705
getContiguousParameters & loadParametersBuffer impl
carzh Oct 23, 2023
c74112e
lint + format + added error code checking wrapper
carzh Oct 24, 2023
cd2e4fa
enforced that loadParametersBuffer takes in a buffer that matches the…
carzh Oct 25, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions js/common/lib/backend.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,9 @@ export interface TrainingSessionHandler extends SessionHandler {
feeds: SessionHandler.FeedsType, fetches: SessionHandler.FetchesType,
options: InferenceSession.RunOptions): Promise<SessionHandler.ReturnType>;

loadParametersBuffer(array: Uint8Array, trainableOnly: boolean): Promise<void>;
getContiguousParameters(trainableOnly: boolean): Promise<Uint8Array>;
getParametersSize(trainableOnly: boolean): Promise<number>;
loadParametersBuffer(array: Float32Array, trainableOnly: boolean): Promise<void>;
getContiguousParameters(trainableOnly: boolean): Promise<OnnxValue>;
}

/**
Expand Down
152 changes: 138 additions & 14 deletions js/common/lib/training-session-impl.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,21 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {resolveBackend} from './backend-impl.js';
import {TrainingSessionHandler} from './backend.js';
import {InferenceSession as InferenceSession} from './inference-session.js';
import {OnnxValue} from './onnx-value.js';
import {Tensor} from './tensor.js';
import {TrainingSession as TrainingSessionInterface, TrainingSessionCreateOptions} from './training-session.js';

type SessionOptions = InferenceSession.SessionOptions;
type FeedsType = InferenceSession.FeedsType;
type FetchesType = InferenceSession.FetchesType;
type ReturnType = InferenceSession.ReturnType;
type RunOptions = InferenceSession.RunOptions;

/**
 * Message of the Error thrown by TrainingSession.create() when the resolved backend does not provide
 * createTrainingSessionHandler.
 */
const noBackendErrMsg = [
  'Training backend could not be resolved.',
  'Make sure you\'re using the correct configuration & WebAssembly files.',
].join(' ');

export class TrainingSession implements TrainingSessionInterface {
private constructor(handler: TrainingSessionHandler) {
Expand All @@ -20,27 +30,141 @@ export class TrainingSession implements TrainingSessionInterface {
return this.handler.outputNames;
}

static async create(_trainingOptions: TrainingSessionCreateOptions, _sessionOptions?: SessionOptions):
static async create(trainingOptions: TrainingSessionCreateOptions, sessionOptions?: SessionOptions):
Promise<TrainingSession> {
throw new Error('Method not implemented');
const evalModel: string|Uint8Array = trainingOptions.evalModel || '';
const optimizerModel: string|Uint8Array = trainingOptions.optimizerModel || '';
const options: SessionOptions = sessionOptions || {};

// get backend hints
const eps = options.executionProviders || [];
const backendHints = eps.map(i => typeof i === 'string' ? i : i.name);
const backend = await resolveBackend(backendHints);
if (backend.createTrainingSessionHandler) {
const handler = await backend.createTrainingSessionHandler(
trainingOptions.checkpointState, trainingOptions.trainModel, evalModel, optimizerModel, options);
return new TrainingSession(handler);
} else {
throw new Error(noBackendErrMsg);
}
}

/**
 * Validates the arguments, delegates to the handler's runTrainStep, and returns the results keyed by name with every
 * value being a Tensor instance.
 *
 * Two call shapes are accepted: (feeds, options?) and (feeds, fetches, options?). Because both put an object in the
 * second position, the implementation inspects that argument at run time to decide which shape was used.
 *
 * @param feeds - object keyed by input name; every name in this.inputNames must be present.
 * @param arg1 - either fetches (a non-empty string array of output names, or an object keyed by output name whose
 *     values are null or Tensor) or run options.
 * @param arg2 - run options; only consulted when arg1 was recognised as fetches.
 * @returns the handler's results, with any non-Tensor value rebuilt as a Tensor.
 * @throws TypeError when feeds, fetches or options have an unexpected type or fetches is an empty array.
 * @throws RangeError when a name in a fetches array is not one of this.outputNames.
 * @throws Error when an input name is missing from feeds.
 */
runTrainStep(feeds: FeedsType, options?: RunOptions): Promise<ReturnType>;
runTrainStep(feeds: FeedsType, fetches: FetchesType, options?: RunOptions): Promise<ReturnType>;
async runTrainStep(feeds: FeedsType, arg1?: FetchesType|RunOptions, arg2?: RunOptions): Promise<ReturnType> {
const fetches: {[name: string]: OnnxValue|null} = {};
let options: RunOptions = {};
// check inputs
// feeds must be a plain keyed object: null, a Tensor or an array are all rejected
if (typeof feeds !== 'object' || feeds === null || feeds instanceof Tensor || Array.isArray(feeds)) {
throw new TypeError(
'\'feeds\' must be an object that use input names as keys and OnnxValue as corresponding values.');
}

// stays true until at least one fetch entry has been recorded from arg1
let isFetchesEmpty = true;
// determine which override is being used
if (typeof arg1 === 'object') {
if (arg1 === null) {
throw new TypeError('Unexpected argument[1]: cannot be null.');
}
if (arg1 instanceof Tensor) {
throw new TypeError('\'fetches\' cannot be a Tensor');
}

if (Array.isArray(arg1)) {
// fetches given as a list of output names
if (arg1.length === 0) {
throw new TypeError('\'fetches\' cannot be an empty array.');
}
isFetchesEmpty = false;
// output names
for (const name of arg1) {
if (typeof name !== 'string') {
throw new TypeError('\'fetches\' must be a string array or an object.');
}
if (this.outputNames.indexOf(name) === -1) {
throw new RangeError(`'fetches' contains invalid output name: ${name}.`);
}
// null marks the output as requested without supplying a value for it
fetches[name] = null;
}

// arg1 was fetches, so arg2 (if given) has to be the options object
if (typeof arg2 === 'object' && arg2 !== null) {
options = arg2;
} else if (typeof arg2 !== 'undefined') {
throw new TypeError('\'options\' must be an object.');
}
} else {
// decide whether arg1 is fetches or options
// if any output name is present and its value is valid OnnxValue, we consider it fetches
let isFetches = false;
const arg1Keys = Object.getOwnPropertyNames(arg1);
for (const name of this.outputNames) {
if (arg1Keys.indexOf(name) !== -1) {
const v = (arg1 as InferenceSession.NullableOnnxValueMapType)[name];
// only null or a Tensor counts; any other value under an output name is ignored here
if (v === null || v instanceof Tensor) {
isFetches = true;
isFetchesEmpty = false;
fetches[name] = v;
}
}
}

if (isFetches) {
if (typeof arg2 === 'object' && arg2 !== null) {
options = arg2;
} else if (typeof arg2 !== 'undefined') {
throw new TypeError('\'options\' must be an object.');
}
} else {
// no output name matched with a null/Tensor value, so the object is treated as run options
options = arg1 as RunOptions;
}
}
} else if (typeof arg1 !== 'undefined') {
throw new TypeError('Unexpected argument[1]: must be \'fetches\' or \'options\'.');
}

// check if all inputs are in feed
for (const name of this.inputNames) {
if (typeof feeds[name] === 'undefined') {
throw new Error(`input '${name}' is missing in 'feeds'.`);
}
}

// if no fetches is specified, we use the full output names list
if (isFetchesEmpty) {
for (const name of this.outputNames) {
fetches[name] = null;
}
}

const results = await this.handler.runTrainStep(feeds, fetches, options);
const returnValue: {[name: string]: OnnxValue} = {};
// copy own properties only; values that are not already Tensor instances are rebuilt from type/data/dims
for (const key in results) {
if (Object.hasOwnProperty.call(results, key)) {
const result = results[key];
if (result instanceof Tensor) {
returnValue[key] = result;
} else {
returnValue[key] = new Tensor(result.type, result.data, result.dims);
}
}
}
return returnValue;
}

async loadParametersBuffer(_array: Uint8Array, _trainableOnly: boolean): Promise<void> {
throw new Error('Method not implemented.');
/**
 * Forwards the parameter-size query to the session handler.
 *
 * @param trainableOnly - passed through to the handler unchanged.
 * @returns whatever the handler's getParametersSize resolves to.
 */
async getParametersSize(trainableOnly: boolean): Promise<number> {
  const {handler} = this;
  return handler.getParametersSize(trainableOnly);
}

async getContiguousParameters(_trainableOnly: boolean): Promise<Uint8Array> {
throw new Error('Method not implemented.');
/**
 * Checks the length of `array` against getParametersSize(trainableOnly) and, when they agree, hands the array to
 * the session handler's loadParametersBuffer.
 *
 * @param array - buffer whose length must equal the value resolved by getParametersSize(trainableOnly).
 * @param trainableOnly - forwarded to both the size query and the handler call.
 * @throws Error when array.length differs from the reported parameters size.
 */
async loadParametersBuffer(array: Float32Array, trainableOnly: boolean): Promise<void> {
  const expectedLength = await this.getParametersSize(trainableOnly);
  const lengthMatches = array.length === expectedLength;
  if (lengthMatches) {
    return this.handler.loadParametersBuffer(array, trainableOnly);
  }
  throw new Error(
      'Size of the buffer passed into loadParametersBuffer must match the number of parameters in ' +
      'the model. Please use getParametersSize method to check.');
}

runTrainStep(feeds: InferenceSession.OnnxValueMapType, options?: InferenceSession.RunOptions|undefined):
Promise<InferenceSession.OnnxValueMapType>;
runTrainStep(
feeds: InferenceSession.OnnxValueMapType, fetches: InferenceSession.FetchesType,
options?: InferenceSession.RunOptions|undefined): Promise<InferenceSession.OnnxValueMapType>;
async runTrainStep(_feeds: unknown, _fetches?: unknown, _options?: unknown):
Promise<InferenceSession.OnnxValueMapType> {
throw new Error('Method not implemented.');
/**
 * Forwards the request for the contiguous parameters to the session handler.
 *
 * @param trainableOnly - passed through to the handler unchanged.
 * @returns whatever the handler's getContiguousParameters resolves to.
 */
async getContiguousParameters(trainableOnly: boolean): Promise<OnnxValue> {
  const {handler} = this;
  return handler.getContiguousParameters(trainableOnly);
}

async release(): Promise<void> {
Expand Down
13 changes: 11 additions & 2 deletions js/common/lib/training-session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT License.

import {InferenceSession} from './inference-session.js';
import {OnnxValue} from './onnx-value.js';
import {TrainingSession as TrainingSessionImpl} from './training-session-impl.js';

/* eslint-disable @typescript-eslint/no-redeclare */
Expand Down Expand Up @@ -49,21 +50,29 @@ export interface TrainingSession {
// #endregion

// #region copy parameters

/**
* Retrieves the size of all parameters for the training state.
*
* @param trainableOnly skips non-trainable parameters when true.
*/
getParametersSize(trainableOnly: boolean): Promise<number>;

/**
* Copies from a buffer containing parameters to the TrainingSession parameters.
*
* @param array - buffer containing parameters
* @param trainableOnly - True if trainable parameters only to be modified, false otherwise.
*/
loadParametersBuffer(array: Uint8Array, trainableOnly: boolean): Promise<void>;
loadParametersBuffer(array: Float32Array, trainableOnly: boolean): Promise<void>;

/**
* Copies from the TrainingSession parameters to a buffer.
*
* @param trainableOnly - True if trainable parameters only to be copied, false otherwise.
* @returns A promise that resolves to a buffer of the requested parameters.
*/
getContiguousParameters(trainableOnly: boolean): Promise<Uint8Array>;
getContiguousParameters(trainableOnly: boolean): Promise<OnnxValue>;
// #endregion

// #region release()
Expand Down
2 changes: 1 addition & 1 deletion js/web/lib/backend-onnxjs.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import {Backend, InferenceSession, InferenceSessionHandler} from 'onnxruntime-common';

import {Session} from './onnxjs/session';
import {OnnxjsSessionHandler} from './onnxjs/session-handler';
import {OnnxjsSessionHandler} from './onnxjs/session-handler-inference';

class OnnxjsBackend implements Backend {
// eslint-disable-next-line @typescript-eslint/no-empty-function
Expand Down
12 changes: 8 additions & 4 deletions js/web/lib/backend-wasm-training.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,17 @@
import {InferenceSession, TrainingSessionHandler} from 'onnxruntime-common';

import {OnnxruntimeWebAssemblyBackend} from './backend-wasm';
import {OnnxruntimeWebAssemblyTrainingSessionHandler} from './wasm/session-handler-training';

class OnnxruntimeTrainingWebAssemblyBackend extends OnnxruntimeWebAssemblyBackend {
async createTrainingSessionHandler(
_checkpointStateUriOrBuffer: string|Uint8Array, _trainModelUriOrBuffer: string|Uint8Array,
_evalModelUriOrBuffer: string|Uint8Array, _optimizerModelUriOrBuffer: string|Uint8Array,
_options: InferenceSession.SessionOptions): Promise<TrainingSessionHandler> {
throw new Error('Method not implemented yet.');
checkpointStateUriOrBuffer: string|Uint8Array, trainModelUriOrBuffer: string|Uint8Array,
evalModelUriOrBuffer: string|Uint8Array, optimizerModelUriOrBuffer: string|Uint8Array,
options: InferenceSession.SessionOptions): Promise<TrainingSessionHandler> {
const handler = new OnnxruntimeWebAssemblyTrainingSessionHandler();
await handler.createTrainingSession(
checkpointStateUriOrBuffer, trainModelUriOrBuffer, evalModelUriOrBuffer, optimizerModelUriOrBuffer, options);
return Promise.resolve(handler);
}
}

Expand Down
2 changes: 1 addition & 1 deletion js/web/lib/backend-wasm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import {cpus} from 'node:os';
import {Backend, env, InferenceSession, InferenceSessionHandler} from 'onnxruntime-common';

import {initializeWebAssemblyInstance} from './wasm/proxy-wrapper';
import {OnnxruntimeWebAssemblySessionHandler} from './wasm/session-handler';
import {OnnxruntimeWebAssemblySessionHandler} from './wasm/session-handler-inference';

/**
* This function initializes all flags for WebAssembly.
Expand Down
5 changes: 5 additions & 0 deletions js/web/lib/wasm/binding/ort-wasm.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,11 @@ export interface OrtWasmModule extends EmscriptenModule {
_OrtTrainingCopyParametersFromBuffer?
(trainingHandle: number, parametersBuffer: number, parameterCount: number, trainableOnly: boolean): number;

_OrtTrainingGetInputOutputCount?
(trainingHandle: number, inputCount: number, outputCount: number, isEvalModel: boolean): number;
_OrtTrainingGetInputOutputName?
(trainingHandle: number, index: number, isInput: boolean, isEvalModel: boolean): number;

_OrtTrainingReleaseSession?(trainingHandle: number): void;
// #endregion

Expand Down
7 changes: 6 additions & 1 deletion js/web/lib/wasm/proxy-messages.ts
Original file line number Diff line number Diff line change
Expand Up @@ -73,5 +73,10 @@ interface MesssageEndProfiling extends MessageError {
in ?: number;
}

/**
 * Proxy message tagged 'is-ort-env-initialized'.
 *
 * `out` is optional: it is absent on the request and carries the boolean answer on a successful reply.
 */
interface MessageIsOrtEnvInitialized extends MessageError {
type: 'is-ort-env-initialized';
out?: boolean;
}

export type OrtWasmMessage = MessageInitWasm|MessageInitOrt|MessageCreateSessionAllocate|MessageCreateSessionFinalize|
MessageCreateSession|MessageReleaseSession|MessageRun|MesssageEndProfiling;
MessageCreateSession|MessageReleaseSession|MessageRun|MesssageEndProfiling|MessageIsOrtEnvInitialized;
10 changes: 9 additions & 1 deletion js/web/lib/wasm/proxy-worker/main.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
/// <reference lib="webworker" />

import {OrtWasmMessage} from '../proxy-messages';
import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, extractTransferableBuffers, initRuntime, releaseSession, run} from '../wasm-core-impl';
import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, extractTransferableBuffers, initRuntime, isOrtEnvInitialized, releaseSession, run} from '../wasm-core-impl';
import {initializeWebAssembly} from '../wasm-factory';

self.onmessage = (ev: MessageEvent<OrtWasmMessage>): void => {
Expand Down Expand Up @@ -89,6 +89,14 @@ self.onmessage = (ev: MessageEvent<OrtWasmMessage>): void => {
postMessage({type: 'end-profiling', err} as OrtWasmMessage);
}
break;
case 'is-ort-env-initialized':
try {
const ortEnvInitialized = isOrtEnvInitialized();
postMessage({type: 'is-ort-env-initialized', out: ortEnvInitialized} as OrtWasmMessage);
} catch (err) {
postMessage({type: 'is-ort-env-initialized', err} as OrtWasmMessage);
}
break;
default:
}
};
21 changes: 21 additions & 0 deletions js/web/lib/wasm/proxy-wrapper.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ const createSessionCallbacks: Array<PromiseCallbacks<SerializableSessionMetadata
const releaseSessionCallbacks: Array<PromiseCallbacks<void>> = [];
const runCallbacks: Array<PromiseCallbacks<SerializableTensorMetadata[]>> = [];
const endProfilingCallbacks: Array<PromiseCallbacks<void>> = [];
const isOrtEnvInitializedCallbacks: Array<PromiseCallbacks<boolean>> = [];

const ensureWorker = (): void => {
if (initializing || !initialized || aborted || !proxyWorker) {
Expand Down Expand Up @@ -92,6 +93,13 @@ const onProxyWorkerMessage = (ev: MessageEvent<OrtWasmMessage>): void => {
endProfilingCallbacks.shift()![0]();
}
break;
case 'is-ort-env-initialized':
if (ev.data.err) {
isOrtEnvInitializedCallbacks.shift()![1](ev.data.err);
} else {
isOrtEnvInitializedCallbacks.shift()![0](ev.data.out!);
}
break;
default:
}
};
Expand Down Expand Up @@ -251,3 +259,16 @@ export const endProfiling = async(sessionId: number): Promise<void> => {
core.endProfiling(sessionId);
}
};

/**
 * Resolves to the result of the 'is ORT env initialized' query.
 *
 * Without the proxy (disabled at build time, or isProxy() is false) the value comes directly from
 * core.isOrtEnvInitialized(). With the proxy, an 'is-ort-env-initialized' message is posted to the proxy worker and
 * the returned promise is settled through the resolve/reject pair queued in isOrtEnvInitializedCallbacks.
 */
export const isOrtEnvInitialized = async(): Promise<boolean> => {
  // Non-proxy path first; the build-time flag is tested before isProxy() is ever called.
  if (BUILD_DEFS.DISABLE_WASM_PROXY || !isProxy()) {
    return core.isOrtEnvInitialized();
  }

  ensureWorker();
  return new Promise<boolean>((onResolved, onRejected) => {
    // Queue the callbacks before posting so the reply always finds its pair.
    isOrtEnvInitializedCallbacks.push([onResolved, onRejected]);
    const request: OrtWasmMessage = {type: 'is-ort-env-initialized'};
    proxyWorker!.postMessage(request);
  });
};
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,12 @@ import {readFile} from 'node:fs/promises';
import {env, InferenceSession, InferenceSessionHandler, SessionHandler, Tensor} from 'onnxruntime-common';

import {SerializableModeldata, TensorMetadata} from './proxy-messages';
import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, initializeRuntime, releaseSession, run} from './proxy-wrapper';
import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, initializeRuntime, isOrtEnvInitialized, releaseSession, run} from './proxy-wrapper';
import {isGpuBufferSupportedType} from './wasm-common';

let runtimeInitialized: boolean;
let runtimeInitializationPromise: Promise<void>|undefined;

const encodeTensorMetadata = (tensor: Tensor, getName: () => string): TensorMetadata => {
export const encodeTensorMetadata = (tensor: Tensor, getName: () => string): TensorMetadata => {
switch (tensor.location) {
case 'cpu':
return [tensor.type, tensor.dims, tensor.data, 'cpu'];
Expand All @@ -22,7 +21,7 @@ const encodeTensorMetadata = (tensor: Tensor, getName: () => string): TensorMeta
}
};

const decodeTensorMetadata = (tensor: TensorMetadata): Tensor => {
export const decodeTensorMetadata = (tensor: TensorMetadata): Tensor => {
switch (tensor[3]) {
case 'cpu':
return new Tensor(tensor[0], tensor[2], tensor[1]);
Expand Down Expand Up @@ -57,13 +56,12 @@ export class OnnxruntimeWebAssemblySessionHandler implements InferenceSessionHan
}

async loadModel(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): Promise<void> {
if (!runtimeInitialized) {
if (!(await isOrtEnvInitialized())) {
if (!runtimeInitializationPromise) {
runtimeInitializationPromise = initializeRuntime(env);
}
await runtimeInitializationPromise;
runtimeInitializationPromise = undefined;
runtimeInitialized = true;
}

if (typeof pathOrBuffer === 'string') {
Expand Down
Loading
Loading