Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[js/web/training] Add CreateTrainingSession #17891

Merged
merged 24 commits into from
Oct 26, 2023
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
cff6a7f
added training build configuration
carzh Aug 29, 2023
7a1365d
edited wasm factory + added training backend
carzh Aug 29, 2023
71f9dbc
added package.json modification for training + minimized training art…
carzh Aug 31, 2023
5adbcf4
applied suggestions
carzh Sep 22, 2023
9068ab4
create training session implementation + supporting changes
carzh Oct 6, 2023
edbe9cc
Merge remote-tracking branch 'origin/main' into carzh/training-wasm-b…
fs-eire Oct 6, 2023
5202e3b
fixed variable names + enforced wasmBackend being a singleton with su…
carzh Oct 6, 2023
a277347
format + lint
carzh Oct 7, 2023
adfa5cf
Merge branch 'main' into carzh/training-wasm-binding
carzh Oct 11, 2023
c9931ad
Merge branch 'carzh/training-wasm-binding' into carzh/create-training…
carzh Oct 11, 2023
e775933
minor tweak to remove placeholder return statement
carzh Oct 11, 2023
c745a65
format
carzh Oct 11, 2023
beca6dc
lint fixes
carzh Oct 11, 2023
b37e77f
lint + format
carzh Oct 11, 2023
8c1e227
Merge branch 'main' into carzh/create-training-session
carzh Oct 12, 2023
d2f3d88
added isOrtEnvInitialized case to proxy wrapper
carzh Oct 12, 2023
712e73c
Merge remote-tracking branch 'refs/remotes/origin/carzh/create-traini…
carzh Oct 12, 2023
05a708f
fixed proxy wrapper case statement
carzh Oct 13, 2023
daa1023
updated getInputOutputCount and getInputOutputNames signature, added …
carzh Oct 23, 2023
389f08e
updated parameter names according to suggestions
carzh Oct 23, 2023
70505c4
format + lint
carzh Oct 24, 2023
a5bc60c
lint fix -- changed multiline string to singlequote string
carzh Oct 24, 2023
ae936e5
Apply naming suggestions from code review
carzh Oct 25, 2023
316a6e7
implemented naming suggestions + fixed training session & checkpoint …
carzh Oct 25, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions js/common/lib/training-session-impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@
import {TrainingSessionHandler} from './backend.js';
import {InferenceSession as InferenceSession} from './inference-session.js';
import {TrainingSession as TrainingSessionInterface, TrainingSessionCreateOptions} from './training-session.js';
import { resolveBackend } from './backend-impl.js';

type SessionOptions = InferenceSession.SessionOptions;
const noBackendErrMsg: string = "Training backend could not be resolved. " +
"Make sure you\'re using the correct configuration & WebAssembly files.";

export class TrainingSession implements TrainingSessionInterface {
private constructor(handler: TrainingSessionHandler) {
Expand All @@ -20,9 +23,26 @@ export class TrainingSession implements TrainingSessionInterface {
return this.handler.outputNames;
}

static async create(_trainingOptions: TrainingSessionCreateOptions, _sessionOptions?: SessionOptions):
static async create(trainingOptions: TrainingSessionCreateOptions,sessionOptions?: SessionOptions):
Promise<TrainingSession> {
throw new Error('Method not implemented');
let checkpointState: string|Uint8Array = trainingOptions.checkpointState;
let trainModel: string|Uint8Array = trainingOptions.trainModel;
let evalModel: string|Uint8Array = trainingOptions.evalModel ? trainingOptions.evalModel : '';
let optimizerModel: string|Uint8Array = trainingOptions.optimizerModel ? trainingOptions.optimizerModel : '';
let options: SessionOptions = sessionOptions ? sessionOptions : {};

// get backend hints
const eps = options.executionProviders || [];
const backendHints = eps.map(i => typeof i === 'string' ? i : i.name);
const backend = await resolveBackend(backendHints);
if (backend.createTrainingSessionHandler) {
const handler =
await backend.createTrainingSessionHandler(checkpointState, trainModel, evalModel, optimizerModel, options);
return new TrainingSession(handler);
}
else {
throw new Error(noBackendErrMsg);
}
}

async loadParametersBuffer(_array: Uint8Array, _trainableOnly: boolean): Promise<void> {
Expand Down
5 changes: 5 additions & 0 deletions js/web/lib/backend-wasm-inference.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {OnnxruntimeWebAssemblyBackend} from './backend-wasm';
export const wasmBackend = new OnnxruntimeWebAssemblyBackend();
21 changes: 21 additions & 0 deletions js/web/lib/backend-wasm-training.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {InferenceSession, TrainingSessionHandler} from 'onnxruntime-common';

import {OnnxruntimeWebAssemblyBackend} from './backend-wasm';
import {OnnxruntimeWebAssemblyTrainingSessionHandler} from './wasm/session-handler-for-training';

class OnnxruntimeTrainingWebAssemblyBackend extends OnnxruntimeWebAssemblyBackend {
async createTrainingSessionHandler(
checkpointStateUriOrBuffer: string|Uint8Array, trainModelUriOrBuffer: string|Uint8Array,
evalModelUriOrBuffer: string|Uint8Array, optimizerModelUriOrBuffer: string|Uint8Array,
options: InferenceSession.SessionOptions): Promise<TrainingSessionHandler> {
const handler = new OnnxruntimeWebAssemblyTrainingSessionHandler();
await handler.createTrainingSession(
checkpointStateUriOrBuffer, trainModelUriOrBuffer, evalModelUriOrBuffer, optimizerModelUriOrBuffer, options);
return Promise.resolve(handler);
}
}

export const wasmBackend = new OnnxruntimeTrainingWebAssemblyBackend();
4 changes: 1 addition & 3 deletions js/web/lib/backend-wasm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ export const initializeFlags = (): void => {
}
};

class OnnxruntimeWebAssemblyBackend implements Backend {
export class OnnxruntimeWebAssemblyBackend implements Backend {
async init(): Promise<void> {
// populate wasm flags
initializeFlags();
Expand All @@ -51,5 +51,3 @@ class OnnxruntimeWebAssemblyBackend implements Backend {
return Promise.resolve(handler);
}
}

export const wasmBackend = new OnnxruntimeWebAssemblyBackend();
4 changes: 4 additions & 0 deletions js/web/lib/build-def.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ interface BuildDefinitions {
* defines whether to disable multi-threading feature in WebAssembly backend in the build.
*/
readonly DISABLE_WASM_THREAD: boolean;
/**
* defines whether to disable training APIs in WebAssembly backend.
*/
readonly DISABLE_TRAINING: boolean;
}

declare const BUILD_DEFS: BuildDefinitions;
9 changes: 6 additions & 3 deletions js/web/lib/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,17 @@ if (!BUILD_DEFS.DISABLE_WEBGL) {
}

if (!BUILD_DEFS.DISABLE_WASM) {
const wasmBackend = require('./backend-wasm').wasmBackend;
const wasmBackend = BUILD_DEFS.DISABLE_TRAINING ? require('./backend-wasm-inference').wasmBackend :
require('./backend-wasm-training').wasmBackend;
if (!BUILD_DEFS.DISABLE_WEBGPU && typeof navigator !== 'undefined' && navigator.gpu) {
registerBackend('webgpu', wasmBackend, 5);
}
registerBackend('cpu', wasmBackend, 10);
registerBackend('wasm', wasmBackend, 10);
registerBackend('xnnpack', wasmBackend, 9);
registerBackend('webnn', wasmBackend, 9);
if (BUILD_DEFS.DISABLE_TRAINING) {
registerBackend('xnnpack', wasmBackend, 9);
registerBackend('webnn', wasmBackend, 9);
}
}

Object.defineProperty(env.versions, 'web', {value: version, enumerable: true});
3 changes: 3 additions & 0 deletions js/web/lib/wasm/binding/ort-wasm.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,9 @@ export interface OrtWasmModule extends EmscriptenModule {
_OrtTrainingCopyParametersFromBuffer?
(trainingHandle: number, parametersBuffer: number, parameterCount: number, trainableOnly: boolean): number;

_OrtTrainingGetInputOutputCount?(sessionHandle: number, inputCountOffset: number, outputCountOffset: number): number;
_OrtTrainingGetInputOutputName?(sessionHandle: number, index: number, isInput: boolean): number;
carzh marked this conversation as resolved.
Show resolved Hide resolved

_OrtTrainingReleaseSession?(trainingHandle: number): void;
// #endregion

Expand Down
7 changes: 6 additions & 1 deletion js/web/lib/wasm/proxy-messages.ts
Original file line number Diff line number Diff line change
Expand Up @@ -73,5 +73,10 @@ interface MesssageEndProfiling extends MessageError {
in ?: number;
}

interface MessageIsOrtEnvInitialized extends MessageError {
askhade marked this conversation as resolved.
Show resolved Hide resolved
type: 'is-ort-env-initialized';
out?: boolean;
}

export type OrtWasmMessage = MessageInitWasm|MessageInitOrt|MessageCreateSessionAllocate|MessageCreateSessionFinalize|
MessageCreateSession|MessageReleaseSession|MessageRun|MesssageEndProfiling;
MessageCreateSession|MessageReleaseSession|MessageRun|MesssageEndProfiling|MessageIsOrtEnvInitialized;
14 changes: 14 additions & 0 deletions js/web/lib/wasm/proxy-wrapper.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
const releaseSessionCallbacks: Array<PromiseCallbacks<void>> = [];
const runCallbacks: Array<PromiseCallbacks<SerializableTensorMetadata[]>> = [];
const endProfilingCallbacks: Array<PromiseCallbacks<void>> = [];
const isOrtEnvInitializedCallbacks: Array<PromiseCallbacks<boolean>> = [];

const ensureWorker = (): void => {
if (initializing || !initialized || aborted || !proxyWorker) {
Expand Down Expand Up @@ -251,3 +252,16 @@
core.endProfiling(sessionId);
}
};

export const isOrtEnvInitialized = async(): Promise<boolean> => {
if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) {
ensureWorker();
return new Promise<boolean>((resolve, reject) => {
isOrtEnvInitializedCallbacks.push([resolve, reject]);
const message: OrtWasmMessage = {type: 'is-ort-env-initialized'};
proxyWorker!.postMessage(message);
});
} else {
return core.isOrtEnvInitialized();
}
}
Fixed Show fixed Hide fixed
72 changes: 72 additions & 0 deletions js/web/lib/wasm/session-handler-for-training.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {env, InferenceSession, SessionHandler, TrainingSessionHandler} from 'onnxruntime-common';

import {SerializableModeldata} from './proxy-messages';
import {createSessionAllocate, initRuntime, isOrtEnvInitialized} from './wasm-core-impl';
import {createCheckpointHandle, createTrainingSessionHandle, releaseTrainingSessionAndCheckpoint} from './wasm-training-core-impl';

export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSessionHandler {
loadParametersBuffer(array: Uint8Array, trainableOnly: boolean): Promise<void> {
throw new Error('Method not implemented.');
}
getContiguousParameters(trainableOnly: boolean): Promise<Uint8Array> {
throw new Error('Method not implemented.');
}
private sessionId: number;
private checkpointId: number;

inputNames: string[];
outputNames: string[];

inputEncodedNames: number[];
outputEncodedNames: number[];

async uriOrBufferToHeap(uriOrBuffer: string|Uint8Array): Promise<SerializableModeldata> {
let buffer: Uint8Array;
if (typeof uriOrBuffer === 'string') {
const response = await fetch(uriOrBuffer);
const arrayBuffer = await response.arrayBuffer();
buffer = new Uint8Array(arrayBuffer);
} else {
buffer = uriOrBuffer;
}
return createSessionAllocate(buffer);
}

async createTrainingSession(
checkpointStateUriOrBuffer: string|Uint8Array, trainModelUriOrBuffer: string|Uint8Array,
evalModelUriOrBuffer: string|Uint8Array, optimizerModelUriOrBuffer: string|Uint8Array,
options: InferenceSession.SessionOptions) {
if (!isOrtEnvInitialized()) {
await initRuntime(env);
}
let checkpointData: SerializableModeldata = await this.uriOrBufferToHeap(checkpointStateUriOrBuffer);
let trainModelData: SerializableModeldata = await this.uriOrBufferToHeap(trainModelUriOrBuffer);
// 0 is supposed to be the nullptr
let evalModelData: SerializableModeldata = [0, 0];
let optimizerModelData: SerializableModeldata = [0, 0];

if (evalModelUriOrBuffer !== '') {
evalModelData = await this.uriOrBufferToHeap(evalModelUriOrBuffer);
}
if (optimizerModelUriOrBuffer !== '') {
optimizerModelData = await this.uriOrBufferToHeap(optimizerModelUriOrBuffer);
}

this.checkpointId = createCheckpointHandle(checkpointData);
[[this.sessionId, this.inputNames, this.outputNames], this.inputEncodedNames, this.outputEncodedNames] =
createTrainingSessionHandle(this.checkpointId, trainModelData, evalModelData, optimizerModelData, options);
}

async dispose(): Promise<void> {
return releaseTrainingSessionAndCheckpoint(this.checkpointId, this.sessionId, this.inputEncodedNames, this.outputEncodedNames);
}

async runTrainStep(
_feeds: SessionHandler.FeedsType, _fetches: SessionHandler.FetchesType,
_options: InferenceSession.RunOptions): Promise<SessionHandler.ReturnType> {
throw new Error('Method not implemented yet.');
}
}
6 changes: 2 additions & 4 deletions js/web/lib/wasm/session-handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@
import {env, InferenceSession, InferenceSessionHandler, SessionHandler, Tensor} from 'onnxruntime-common';

import {SerializableModeldata, TensorMetadata} from './proxy-messages';
import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, initializeRuntime, releaseSession, run} from './proxy-wrapper';
import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, initializeRuntime, releaseSession, run, isOrtEnvInitialized} from './proxy-wrapper';
import {isGpuBufferSupportedType} from './wasm-common';

let runtimeInitialized: boolean;
let runtimeInitializationPromise: Promise<void>|undefined;

const encodeTensorMetadata = (tensor: Tensor, getName: () => string): TensorMetadata => {
Expand Down Expand Up @@ -57,13 +56,12 @@
}

async loadModel(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): Promise<void> {
if (!runtimeInitialized) {
if (!isOrtEnvInitialized()) {
Fixed Show fixed Hide fixed
if (!runtimeInitializationPromise) {
runtimeInitializationPromise = initializeRuntime(env);
}
await runtimeInitializationPromise;
runtimeInitializationPromise = undefined;
runtimeInitialized = true;
}

if (typeof pathOrBuffer === 'string') {
Expand Down
7 changes: 7 additions & 0 deletions js/web/lib/wasm/wasm-core-impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ export const initRuntime = async(env: Env): Promise<void> => {
const initJsep = require('./jsep/init').init;
await initJsep(getInstance(), env);
}

ortEnvInitialized = true;
};

/**
Expand Down Expand Up @@ -92,6 +94,11 @@ type SessionMetadata = [
];

const activeSessions = new Map<number, SessionMetadata>();
let ortEnvInitialized = false;

export const isOrtEnvInitialized = (): boolean => {
return ortEnvInitialized;
};

/**
* allocate the memory and memcpy the model bytes, preparing for creating an instance of InferenceSession.
Expand Down
19 changes: 14 additions & 5 deletions js/web/lib/wasm/wasm-factory.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,14 @@ import {OrtWasmModule} from './binding/ort-wasm';
import {OrtWasmThreadedModule} from './binding/ort-wasm-threaded';

/* eslint-disable @typescript-eslint/no-require-imports */
const ortWasmFactory: EmscriptenModuleFactory<OrtWasmModule> =
BUILD_DEFS.DISABLE_WEBGPU ? require('./binding/ort-wasm.js') : require('./binding/ort-wasm-simd.jsep.js');
let ortWasmFactory: EmscriptenModuleFactory<OrtWasmModule>;

if (!BUILD_DEFS.DISABLE_TRAINING) {
ortWasmFactory = require('./binding/ort-training-wasm-simd.js');
} else {
ortWasmFactory =
BUILD_DEFS.DISABLE_WEBGPU ? require('./binding/ort-wasm.js') : require('./binding/ort-wasm-simd.jsep.js');
}

const ortWasmFactoryThreaded: EmscriptenModuleFactory<OrtWasmModule> = !BUILD_DEFS.DISABLE_WASM_THREAD ?
(BUILD_DEFS.DISABLE_WEBGPU ? require('./binding/ort-wasm-threaded.js') :
Expand Down Expand Up @@ -72,10 +78,13 @@ const isSimdSupported = (): boolean => {
};

const getWasmFileName = (useSimd: boolean, useThreads: boolean) => {
if (useThreads) {
return useSimd ? 'ort-wasm-simd-threaded.wasm' : 'ort-wasm-threaded.wasm';
if (useSimd) {
if (!BUILD_DEFS.DISABLE_TRAINING) {
return 'ort-training-wasm-simd.wasm';
}
return useThreads ? 'ort-wasm-simd-threaded.wasm' : 'ort-wasm-simd.wasm';
} else {
return useSimd ? 'ort-wasm-simd.wasm' : 'ort-wasm.wasm';
return useThreads ? 'ort-wasm-threaded.wasm' : 'ort-wasm.wasm';
}
};

Expand Down
Loading
Loading