Skip to content

Commit

Permalink
[js/web/training] runTrainStep implementation (microsoft#18006)
Browse files Browse the repository at this point in the history
### Description
* Implemented TrainingSession.runTrainStep, based on the design document and
following the implementation of InferenceSession's run method

### Motivation and Context
* Adding web bindings for training

#### Related work
* microsoft#16521 allowed for training artifacts to be built
* microsoft#17333 added interfaces for training
* microsoft#17474 allowed for training package to be built + added training
backend to web package
* microsoft#17891 added the implementation for createTrainingSession on the TypeScript side
**[SHOULD BE MERGED IN BEFORE THIS PR]**

---------

Co-authored-by: Yulong Wang <[email protected]>
Co-authored-by: Ashwini Khade <[email protected]>
  • Loading branch information
3 people authored and kleiti committed Mar 22, 2024
1 parent e0a68a7 commit 2cffc72
Show file tree
Hide file tree
Showing 10 changed files with 441 additions and 92 deletions.
146 changes: 135 additions & 11 deletions js/common/lib/training-session-impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,18 @@
// Licensed under the MIT License.

import {resolveBackend} from './backend-impl.js';
import {TrainingSessionHandler} from './backend.js';
import {SessionHandler, TrainingSessionHandler} from './backend.js';
import {InferenceSession as InferenceSession} from './inference-session.js';
import {OnnxValue} from './onnx-value.js';
import {Tensor} from './tensor.js';
import {TrainingSession as TrainingSessionInterface, TrainingSessionCreateOptions} from './training-session.js';

// Local shorthands for the InferenceSession namespace types used in this module.
type SessionOptions = InferenceSession.SessionOptions;
type FeedsType = InferenceSession.FeedsType;
type FetchesType = InferenceSession.FetchesType;
// NOTE(review): within this module this alias shadows TypeScript's built-in `ReturnType<F>` utility type.
type ReturnType = InferenceSession.ReturnType;
type RunOptions = InferenceSession.RunOptions;

// Error message text for the case where no training backend could be resolved.
// NOTE(review): the code that uses this constant is outside the visible part of the file.
const noBackendErrMsg: string = 'Training backend could not be resolved. ' +
'Make sure you\'re using the correct configuration & WebAssembly files.';

Expand Down Expand Up @@ -42,21 +49,138 @@ export class TrainingSession implements TrainingSessionInterface {
}
}

async loadParametersBuffer(_array: Uint8Array, _trainableOnly: boolean): Promise<void> {
throw new Error('Method not implemented.');
/**
 * Shared argument normalisation for runTrainStep (and other run-step style methods).
 *
 * Resolves the overloaded `(feeds, options?)` / `(feeds, fetches, options?)` call shapes into an explicit
 * fetches map and a RunOptions object, validating the arguments along the way.
 *
 * @param feeds map from input name to value; must be a non-null, non-array, non-Tensor object that has an
 *     entry for every name in `this.inputNames`.
 * @param arg1 either the fetches (an array of output names, or an object keyed by output names whose values are
 *     `null` or a Tensor) or, when no such key is found, the RunOptions object.
 * @param arg2 the RunOptions object; only consulted when `arg1` was recognised as fetches.
 * @returns a tuple `[fetches, options]`. When no fetches were given, every name in `this.outputNames` is mapped
 *     to `null`; when no options were given, `options` is an empty object.
 * @throws TypeError when `feeds`, `arg1` or `arg2` has an unsupported type or shape.
 * @throws RangeError when a fetches array names an output that is not in `this.outputNames`.
 * @throws Error when an input listed in `this.inputNames` is missing from `feeds`.
 */
typeNarrowingForRunStep(feeds: FeedsType, arg1?: FetchesType|RunOptions, arg2?: RunOptions):
    [SessionHandler.FetchesType, RunOptions] {
  const requested: {[name: string]: OnnxValue|null} = {};
  let runOptions: RunOptions = {};

  // Validates the trailing options argument; it is only meaningful once arg1 is known to be fetches.
  const resolveTrailingOptions = (): RunOptions => {
    if (typeof arg2 === 'object' && arg2 !== null) {
      return arg2;
    }
    if (typeof arg2 !== 'undefined') {
      throw new TypeError('\'options\' must be an object.');
    }
    return {};
  };

  // feeds has to be a plain name -> value map
  const feedsIsMap =
      typeof feeds === 'object' && feeds !== null && !(feeds instanceof Tensor) && !Array.isArray(feeds);
  if (!feedsIsMap) {
    throw new TypeError(
        '\'feeds\' must be an object that use input names as keys and OnnxValue as corresponding values.');
  }

  // work out which overload the caller used
  let hasExplicitFetches = false;
  if (typeof arg1 === 'object') {
    if (arg1 === null) {
      throw new TypeError('Unexpected argument[1]: cannot be null.');
    }
    if (arg1 instanceof Tensor) {
      throw new TypeError('\'fetches\' cannot be a Tensor');
    }

    if (Array.isArray(arg1)) {
      // fetches given as a list of output names
      if (arg1.length === 0) {
        throw new TypeError('\'fetches\' cannot be an empty array.');
      }
      hasExplicitFetches = true;
      for (const outputName of arg1) {
        if (typeof outputName !== 'string') {
          throw new TypeError('\'fetches\' must be a string array or an object.');
        }
        if (this.outputNames.indexOf(outputName) === -1) {
          throw new RangeError(`'fetches' contains invalid output name: ${outputName}.`);
        }
        requested[outputName] = null;
      }
      runOptions = resolveTrailingOptions();
    } else {
      // An object is treated as fetches only if at least one of its own keys is a known output name
      // holding null or a Tensor; otherwise it is taken to be the options object.
      const ownKeys = Object.getOwnPropertyNames(arg1);
      for (const outputName of this.outputNames) {
        if (ownKeys.indexOf(outputName) === -1) {
          continue;
        }
        const candidate = (arg1 as InferenceSession.NullableOnnxValueMapType)[outputName];
        if (candidate === null || candidate instanceof Tensor) {
          hasExplicitFetches = true;
          requested[outputName] = candidate;
        }
      }
      runOptions = hasExplicitFetches ? resolveTrailingOptions() : (arg1 as RunOptions);
    }
  } else if (typeof arg1 !== 'undefined') {
    throw new TypeError('Unexpected argument[1]: must be \'fetches\' or \'options\'.');
  }

  // every declared input must be present in feeds
  for (const inputName of this.inputNames) {
    if (typeof feeds[inputName] === 'undefined') {
      throw new Error(`input '${inputName}' is missing in 'feeds'.`);
    }
  }

  // default to fetching every output when the caller did not pick any
  if (!hasExplicitFetches) {
    for (const outputName of this.outputNames) {
      requested[outputName] = null;
    }
  }

  return [requested, runOptions];
}

async getContiguousParameters(_trainableOnly: boolean): Promise<Uint8Array> {
/**
 * Shared result conversion for runTrainStep (and other run-step style methods).
 *
 * Copies every own enumerable entry of the handler result into a new map. Values that are already Tensor
 * instances are passed through unchanged; any other value is wrapped in a new Tensor built from its
 * `type`, `data` and `dims` fields.
 *
 * @param results the result map returned by the session handler
 * @returns a new map from output name to Tensor
 */
convertHandlerReturnTypeToMapOfTensors(results: SessionHandler.ReturnType): ReturnType {
  const tensorsByName: {[name: string]: OnnxValue} = {};
  for (const outputName in results) {
    // skip anything inherited through the prototype chain
    if (!Object.hasOwnProperty.call(results, outputName)) {
      continue;
    }
    const raw = results[outputName];
    tensorsByName[outputName] = raw instanceof Tensor ? raw : new Tensor(raw.type, raw.data, raw.dims);
  }
  return tensorsByName;
}

/**
 * Runs a single train step by delegating to the session handler.
 *
 * The arguments are first normalised by typeNarrowingForRunStep, the handler's runTrainStep is awaited with
 * the resulting fetches and options, and its result is converted into a map of Tensors.
 *
 * @param feeds map from input name to value
 * @param arg1 either the fetches or the run options, depending on the overload used
 * @param arg2 the run options when fetches were passed as the second argument
 * @returns a promise resolving to a map from output name to Tensor
 */
runTrainStep(feeds: FeedsType, options?: RunOptions): Promise<ReturnType>;
runTrainStep(feeds: FeedsType, fetches: FetchesType, options?: RunOptions): Promise<ReturnType>;
async runTrainStep(feeds: FeedsType, arg1?: FetchesType|RunOptions, arg2?: RunOptions): Promise<ReturnType> {
  const narrowed = this.typeNarrowingForRunStep(feeds, arg1, arg2);
  const handlerOutputs = await this.handler.runTrainStep(feeds, narrowed[0], narrowed[1]);
  return this.convertHandlerReturnTypeToMapOfTensors(handlerOutputs);
}

/**
 * Placeholder: not implemented yet. The returned promise always rejects.
 *
 * @param _array unused
 * @param _trainableOnly unused
 * @throws Error always, with the message 'Method not implemented.'
 */
async loadParametersBuffer(_array: Uint8Array, _trainableOnly: boolean): Promise<void> {
  const notImplemented = new Error('Method not implemented.');
  throw notImplemented;
}

runTrainStep(feeds: InferenceSession.OnnxValueMapType, options?: InferenceSession.RunOptions|undefined):
Promise<InferenceSession.OnnxValueMapType>;
runTrainStep(
feeds: InferenceSession.OnnxValueMapType, fetches: InferenceSession.FetchesType,
options?: InferenceSession.RunOptions|undefined): Promise<InferenceSession.OnnxValueMapType>;
async runTrainStep(_feeds: unknown, _fetches?: unknown, _options?: unknown):
Promise<InferenceSession.OnnxValueMapType> {
/**
 * Placeholder: not implemented yet. The returned promise always rejects.
 *
 * @param _trainableOnly unused
 * @throws Error always, with the message 'Method not implemented.'
 */
async getContiguousParameters(_trainableOnly: boolean): Promise<Uint8Array> {
  const notImplemented = new Error('Method not implemented.');
  throw notImplemented;
}

Expand Down
2 changes: 1 addition & 1 deletion js/web/lib/backend-onnxjs.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import {Backend, InferenceSession, InferenceSessionHandler} from 'onnxruntime-common';

import {Session} from './onnxjs/session';
import {OnnxjsSessionHandler} from './onnxjs/session-handler';
import {OnnxjsSessionHandler} from './onnxjs/session-handler-inference';

class OnnxjsBackend implements Backend {
// eslint-disable-next-line @typescript-eslint/no-empty-function
Expand Down
2 changes: 1 addition & 1 deletion js/web/lib/backend-wasm-training.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import {InferenceSession, TrainingSessionHandler} from 'onnxruntime-common';

import {OnnxruntimeWebAssemblyBackend} from './backend-wasm';
import {OnnxruntimeWebAssemblyTrainingSessionHandler} from './wasm/session-handler-for-training';
import {OnnxruntimeWebAssemblyTrainingSessionHandler} from './wasm/session-handler-training';

class OnnxruntimeTrainingWebAssemblyBackend extends OnnxruntimeWebAssemblyBackend {
async createTrainingSessionHandler(
Expand Down
2 changes: 1 addition & 1 deletion js/web/lib/backend-wasm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import {cpus} from 'node:os';
import {Backend, env, InferenceSession, InferenceSessionHandler} from 'onnxruntime-common';

import {initializeWebAssemblyInstance} from './wasm/proxy-wrapper';
import {OnnxruntimeWebAssemblySessionHandler} from './wasm/session-handler';
import {OnnxruntimeWebAssemblySessionHandler} from './wasm/session-handler-inference';

/**
* This function initializes all flags for WebAssembly.
Expand Down
File renamed without changes.
73 changes: 0 additions & 73 deletions js/web/lib/wasm/session-handler-for-training.ts

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import {isGpuBufferSupportedType} from './wasm-common';

let runtimeInitializationPromise: Promise<void>|undefined;

const encodeTensorMetadata = (tensor: Tensor, getName: () => string): TensorMetadata => {
export const encodeTensorMetadata = (tensor: Tensor, getName: () => string): TensorMetadata => {
switch (tensor.location) {
case 'cpu':
return [tensor.type, tensor.dims, tensor.data, 'cpu'];
Expand All @@ -21,7 +21,7 @@ const encodeTensorMetadata = (tensor: Tensor, getName: () => string): TensorMeta
}
};

const decodeTensorMetadata = (tensor: TensorMetadata): Tensor => {
export const decodeTensorMetadata = (tensor: TensorMetadata): Tensor => {
switch (tensor[3]) {
case 'cpu':
return new Tensor(tensor[0], tensor[2], tensor[1]);
Expand Down
Loading

0 comments on commit 2cffc72

Please sign in to comment.