Merge branch 'main' into yufeng/turnon_neural_speed
yufenglee committed Mar 18, 2024
2 parents f12a66d + 4d31076 commit e48225f
Showing 33 changed files with 563 additions and 320 deletions.
121 changes: 90 additions & 31 deletions js/common/lib/backend-impl.ts
@@ -2,6 +2,7 @@
 // Licensed under the MIT License.
 
 import {Backend} from './backend.js';
+import {InferenceSession} from './inference-session.js';
 
 interface BackendInfo {
   backend: Backend;
@@ -10,6 +11,7 @@ interface BackendInfo {
   initPromise?: Promise<void>;
   initialized?: boolean;
   aborted?: boolean;
+  error?: string;
 }
 
 const backends: Map<string, BackendInfo> = new Map();
@@ -60,43 +62,100 @@ export const registerBackend = (name: string, backend: Backend, priority: number
 };
 
 /**
- * Resolve backend by specified hints.
+ * Try to resolve and initialize a backend.
  *
- * @param backendHints - a list of execution provider names to lookup. If omitted use registered backends as list.
- * @returns a promise that resolves to the backend.
+ * @param backendName - the name of the backend.
+ * @returns the backend instance if resolved and initialized successfully, or an error message if failed.
+ */
+const tryResolveAndInitializeBackend = async(backendName: string): Promise<Backend|string> => {
+  const backendInfo = backends.get(backendName);
+  if (!backendInfo) {
+    return 'backend not found.';
+  }
+
+  if (backendInfo.initialized) {
+    return backendInfo.backend;
+  } else if (backendInfo.aborted) {
+    return backendInfo.error!;
+  } else {
+    const isInitializing = !!backendInfo.initPromise;
+    try {
+      if (!isInitializing) {
+        backendInfo.initPromise = backendInfo.backend.init(backendName);
+      }
+      await backendInfo.initPromise;
+      backendInfo.initialized = true;
+      return backendInfo.backend;
+    } catch (e) {
+      if (!isInitializing) {
+        backendInfo.error = `${e}`;
+        backendInfo.aborted = true;
+      }
+      return backendInfo.error!;
+    } finally {
+      delete backendInfo.initPromise;
+    }
+  }
+};
+
+/**
+ * Resolve execution providers from the specific session options.
+ *
+ * @param options - the session options object.
+ * @returns a promise that resolves to a tuple of an initialized backend instance and a session options object with
+ * filtered EP list.
  *
  * @ignore
  */
-export const resolveBackend = async(backendHints: readonly string[]): Promise<Backend> => {
-  const backendNames = backendHints.length === 0 ? backendsSortedByPriority : backendHints;
-  const errors = [];
-  for (const backendName of backendNames) {
-    const backendInfo = backends.get(backendName);
-    if (backendInfo) {
-      if (backendInfo.initialized) {
-        return backendInfo.backend;
-      } else if (backendInfo.aborted) {
-        continue;  // current backend is unavailable; try next
-      }
-
-      const isInitializing = !!backendInfo.initPromise;
-      try {
-        if (!isInitializing) {
-          backendInfo.initPromise = backendInfo.backend.init(backendName);
-        }
-        await backendInfo.initPromise;
-        backendInfo.initialized = true;
-        return backendInfo.backend;
-      } catch (e) {
-        if (!isInitializing) {
-          errors.push({name: backendName, err: e});
-        }
-        backendInfo.aborted = true;
-      } finally {
-        delete backendInfo.initPromise;
-      }
-    }
-  }
-
-  throw new Error(`no available backend found. ERR: ${errors.map(e => `[${e.name}] ${e.err}`).join(', ')}`);
-};
+export const resolveBackendAndExecutionProviders = async(options: InferenceSession.SessionOptions):
+    Promise<[backend: Backend, options: InferenceSession.SessionOptions]> => {
+      // extract backend hints from session options
+      const eps = options.executionProviders || [];
+      const backendHints = eps.map(i => typeof i === 'string' ? i : i.name);
+      const backendNames = backendHints.length === 0 ? backendsSortedByPriority : backendHints;
+
+      // try to resolve and initialize all requested backends
+      let backend: Backend|undefined;
+      const errors = [];
+      const availableBackendNames = new Set<string>();
+      for (const backendName of backendNames) {
+        const resolveResult = await tryResolveAndInitializeBackend(backendName);
+        if (typeof resolveResult === 'string') {
+          errors.push({name: backendName, err: resolveResult});
+        } else {
+          if (!backend) {
+            backend = resolveResult;
+          }
+          if (backend === resolveResult) {
+            availableBackendNames.add(backendName);
+          }
+        }
+      }
+
+      // if no backend is available, throw error.
+      if (!backend) {
+        throw new Error(`no available backend found. ERR: ${errors.map(e => `[${e.name}] ${e.err}`).join(', ')}`);
+      }
+
+      // for each explicitly requested backend, if it's not available, output warning message.
+      for (const {name, err} of errors) {
+        if (backendHints.includes(name)) {
+          // eslint-disable-next-line no-console
+          console.warn(`removing requested execution provider "${
+              name}" from session options because it is not available: ${err}`);
+        }
+      }
+
+      const filteredEps = eps.filter(i => availableBackendNames.has(typeof i === 'string' ? i : i.name));
+
      return [
        backend, new Proxy(options, {
          get: (target, prop) => {
            if (prop === 'executionProviders') {
              return filteredEps;
            }
            return Reflect.get(target, prop);
          }
        })
      ];
    };
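
The Proxy returned at the end is a notable design choice: instead of cloning or mutating the caller's session options, the resolver hands back a read-through view that overrides only `executionProviders`. A minimal standalone sketch of the same technique (the helper name and option shapes are illustrative, not part of the library):

```ts
// Sketch of the Proxy technique used above: expose a filtered
// `executionProviders` list without mutating the caller's options object.
type ExecutionProvider = string|{name: string};

interface Options {
  executionProviders?: readonly ExecutionProvider[];
  logSeverityLevel?: number;
}

const withFilteredEps = (options: Options, available: ReadonlySet<string>): Options =>
    new Proxy(options, {
      get: (target, prop) => {
        if (prop === 'executionProviders') {
          return (target.executionProviders ?? [])
              .filter(ep => available.has(typeof ep === 'string' ? ep : ep.name));
        }
        return Reflect.get(target, prop);
      }
    });

// The original object is untouched; the view reports only the available EP.
const original: Options = {executionProviders: ['webgpu', 'wasm'], logSeverityLevel: 2};
const view = withFilteredEps(original, new Set(['wasm']));
console.log(view.executionProviders);      // ['wasm']
console.log(original.executionProviders);  // ['webgpu', 'wasm']
```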
6 changes: 3 additions & 3 deletions js/common/lib/backend.ts
@@ -58,7 +58,7 @@ export interface TrainingSessionHandler extends SessionHandler {
       options: InferenceSession.RunOptions): Promise<SessionHandler.ReturnType>;
 
   getParametersSize(trainableOnly: boolean): Promise<number>;
-  loadParametersBuffer(array: Uint8Array, trainableOnly: boolean): Promise<void>;
+  loadParametersBuffer(buffer: Uint8Array, trainableOnly: boolean): Promise<void>;
   getContiguousParameters(trainableOnly: boolean): Promise<OnnxValue>;
 }
 
@@ -77,8 +77,8 @@ export interface Backend {
       Promise<InferenceSessionHandler>;
 
   createTrainingSessionHandler?
-      (checkpointStateUriOrBuffer: TrainingSession.URIorBuffer, trainModelUriOrBuffer: TrainingSession.URIorBuffer,
-       evalModelUriOrBuffer: TrainingSession.URIorBuffer, optimizerModelUriOrBuffer: TrainingSession.URIorBuffer,
+      (checkpointStateUriOrBuffer: TrainingSession.UriOrBuffer, trainModelUriOrBuffer: TrainingSession.UriOrBuffer,
+       evalModelUriOrBuffer: TrainingSession.UriOrBuffer, optimizerModelUriOrBuffer: TrainingSession.UriOrBuffer,
       options: InferenceSession.SessionOptions): Promise<TrainingSessionHandler>;
 }
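
For context on how a `Backend` implementation enters the resolution flow shown in the previous file: it registers itself by name and priority via `registerBackend`. A hedged skeleton, assuming only the interfaces shown in this diff (the backend name 'my-ep' and the stub behavior are illustrative):

```ts
import {Backend, registerBackend} from 'onnxruntime-common';

// Skeleton backend: init() is where platform probing belongs. Throwing from
// init() marks the backend as aborted, and resolution falls through to the
// next candidate in the hint list.
const myBackend: Backend = {
  async init(backendName: string): Promise<void> {
    console.log(`initializing backend "${backendName}"`);
  },
  async createInferenceSessionHandler(): Promise<never> {
    throw new Error('session creation not implemented in this sketch');
  },
};

// The name is what callers put in `executionProviders`; the priority orders
// default candidates when no explicit EP list is given.
registerBackend('my-ep', myBackend, 1);
```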
4 changes: 2 additions & 2 deletions js/common/lib/env.ts
@@ -173,7 +173,7 @@ export declare namespace Env {
      * When use with TypeScript, the type of this property is `GPUAdapter` defined in "@webgpu/types".
      * Use `const adapter = env.webgpu.adapter as GPUAdapter;` in TypeScript to access this property with correct type.
      *
-     * see comments on {@link GpuBufferType}
+     * see comments on {@link Tensor.GpuBufferType}
      */
     readonly adapter: unknown;
     /**
@@ -184,7 +184,7 @@
      * When use with TypeScript, the type of this property is `GPUDevice` defined in "@webgpu/types".
      * Use `const device = env.webgpu.device as GPUDevice;` in TypeScript to access this property with correct type.
      *
-     * see comments on {@link GpuBufferType} for more details about why not use types defined in "@webgpu/types".
+     * see comments on {@link Tensor.GpuBufferType} for more details about why not use types defined in "@webgpu/types".
      */
     readonly device: unknown;
     /**
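
The doc comments above already show the intended usage pattern; spelled out as a tiny sketch (it assumes "@webgpu/types" is in the TypeScript config and that a WebGPU-backed session has already been created):

```ts
import {env} from 'onnxruntime-common';

// Both properties are typed `unknown` to avoid a hard dependency on
// "@webgpu/types"; cast at the use site to get the real types.
const adapter = env.webgpu.adapter as GPUAdapter;
const device = env.webgpu.device as GPUDevice;
console.log(adapter.features, device.limits);
```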
5 changes: 4 additions & 1 deletion js/common/lib/index.ts
@@ -11,7 +11,7 @@
  *   - [onnxruntime-react-native](https://www.npmjs.com/package/onnxruntime-react-native)
  *
  * See also:
- * - [Get Started](https://onnxruntime.ai/docs/get-started/with-javascript.html)
+ * - [Get Started](https://onnxruntime.ai/docs/get-started/with-javascript/)
  * - [Inference examples](https://github.com/microsoft/onnxruntime-inference-examples/tree/main/js)
  *
  * @packageDocumentation
@@ -21,6 +21,9 @@ export * from './backend.js';
 export * from './env.js';
 export * from './inference-session.js';
 export * from './tensor.js';
+export * from './tensor-conversion.js';
+export * from './tensor-factory.js';
+export * from './trace.js';
 export * from './onnx-model.js';
 export * from './onnx-value.js';
 export * from './training-session.js';
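
One consequence of the added `export *` lines, as a sketch: types declared in `tensor-factory.ts` (such as the `TensorFactory` interface edited later in this commit) now resolve from the package root instead of a deep import path.

```ts
// Works from the package root once './tensor-factory.js' is re-exported.
import type {TensorFactory} from 'onnxruntime-common';

// Illustrative use of the type: a function accepting anything that
// implements the factory surface.
const describeFactory = (factory: TensorFactory): string => typeof factory.fromImage;
```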
10 changes: 4 additions & 6 deletions js/common/lib/inference-session-impl.ts
@@ -1,7 +1,7 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
-import {resolveBackend} from './backend-impl.js';
+import {resolveBackendAndExecutionProviders} from './backend-impl.js';
 import {InferenceSessionHandler} from './backend.js';
 import {InferenceSession as InferenceSessionInterface} from './inference-session.js';
 import {OnnxValue} from './onnx-value.js';
@@ -195,11 +195,9 @@ export class InferenceSession implements InferenceSessionInterface {
       throw new TypeError('Unexpected argument[0]: must be \'path\' or \'buffer\'.');
     }
 
-    // get backend hints
-    const eps = options.executionProviders || [];
-    const backendHints = eps.map(i => typeof i === 'string' ? i : i.name);
-    const backend = await resolveBackend(backendHints);
-    const handler = await backend.createInferenceSessionHandler(filePathOrUint8Array, options);
+    // resolve backend, update session options with validated EPs, and create session handler
+    const [backend, optionsWithValidatedEPs] = await resolveBackendAndExecutionProviders(options);
+    const handler = await backend.createInferenceSessionHandler(filePathOrUint8Array, optionsWithValidatedEPs);
     TRACE_FUNC_END();
     return new InferenceSession(handler);
   }
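
The user-visible effect of this change, as a hedged sketch (the model path is a placeholder, and it assumes a bundle such as onnxruntime-web that registers both backends): a requested EP that fails to initialize no longer fails session creation outright; it is dropped with a console warning and the handler receives the filtered list.

```ts
import {InferenceSession} from 'onnxruntime-web';

// If 'webgpu' cannot initialize, the resolver logs
//   removing requested execution provider "webgpu" from session options
//   because it is not available: ...
// and the session is created with 'wasm' instead of throwing.
const session = await InferenceSession.create('./model.onnx', {
  executionProviders: ['webgpu', 'wasm'],
});
console.log(session.inputNames);
```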
43 changes: 35 additions & 8 deletions js/common/lib/inference-session.ts
@@ -186,22 +186,22 @@ export declare namespace InferenceSession {
   // #region execution providers
 
   // Currently, we have the following backends to support execution providers:
-  // Backend Node.js binding: supports 'cpu' and 'cuda'.
+  // Backend Node.js binding: supports 'cpu', 'dml' (win32), 'coreml' (macOS) and 'cuda' (linux).
   // Backend WebAssembly: supports 'cpu', 'wasm', 'webgpu' and 'webnn'.
   // Backend ONNX.js: supports 'webgl'.
   // Backend React Native: supports 'cpu', 'xnnpack', 'coreml' (iOS), 'nnapi' (Android).
   interface ExecutionProviderOptionMap {
+    coreml: CoreMLExecutionProviderOption;
     cpu: CpuExecutionProviderOption;
-    coreml: CoreMlExecutionProviderOption;
     cuda: CudaExecutionProviderOption;
     dml: DmlExecutionProviderOption;
+    nnapi: NnapiExecutionProviderOption;
     tensorrt: TensorRtExecutionProviderOption;
     wasm: WebAssemblyExecutionProviderOption;
     webgl: WebGLExecutionProviderOption;
+    xnnpack: XnnpackExecutionProviderOption;
     webgpu: WebGpuExecutionProviderOption;
     webnn: WebNNExecutionProviderOption;
-    nnapi: NnapiExecutionProviderOption;
-    xnnpack: XnnpackExecutionProviderOption;
   }
 
   type ExecutionProviderName = keyof ExecutionProviderOptionMap;
@@ -219,10 +219,6 @@
     readonly name: 'cuda';
     deviceId?: number;
   }
-  export interface CoreMlExecutionProviderOption extends ExecutionProviderOption {
-    readonly name: 'coreml';
-    coreMlFlags?: number;
-  }
   export interface DmlExecutionProviderOption extends ExecutionProviderOption {
     readonly name: 'dml';
     deviceId?: number;
@@ -253,8 +249,39 @@
   }
   export interface CoreMLExecutionProviderOption extends ExecutionProviderOption {
     readonly name: 'coreml';
+    /**
+     * The bit flags for CoreML execution provider.
+     *
+     * ```
+     * COREML_FLAG_USE_CPU_ONLY = 0x001
+     * COREML_FLAG_ENABLE_ON_SUBGRAPH = 0x002
+     * COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE = 0x004
+     * COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES = 0x008
+     * COREML_FLAG_CREATE_MLPROGRAM = 0x010
+     * ```
+     *
+     * See include/onnxruntime/core/providers/coreml/coreml_provider_factory.h for more details.
+     *
+     * This flag is available only in ONNXRuntime (Node.js binding).
+     */
     coreMlFlags?: number;
+    /**
+     * Specify whether to use CPU only in CoreML EP.
+     *
+     * This setting is available only in ONNXRuntime (react-native).
+     */
     useCPUOnly?: boolean;
+    /**
+     * Specify whether to enable CoreML EP on subgraph.
+     *
+     * This setting is available only in ONNXRuntime (react-native).
+     */
     enableOnSubgraph?: boolean;
+    /**
+     * Specify whether to only enable CoreML EP for Apple devices with ANE (Apple Neural Engine).
+     *
+     * This setting is available only in ONNXRuntime (react-native).
+     */
     onlyEnableDeviceWithANE?: boolean;
   }
   export interface NnapiExecutionProviderOption extends ExecutionProviderOption {
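
To make the platform split concrete, a hedged sketch of both shapes of `CoreMLExecutionProviderOption` (the flag combination is illustrative; values follow the table in the doc comment above):

```ts
import {InferenceSession} from 'onnxruntime-common';

// Node.js binding: behavior is selected via bit flags.
const nodeOptions: InferenceSession.SessionOptions = {
  executionProviders: [{
    name: 'coreml',
    // COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES | COREML_FLAG_CREATE_MLPROGRAM
    coreMlFlags: 0x008 | 0x010,
  }],
};

// React Native: the same knobs are expressed as individual booleans.
const reactNativeOptions: InferenceSession.SessionOptions = {
  executionProviders: [{
    name: 'coreml',
    useCPUOnly: false,
    onlyEnableDeviceWithANE: true,
  }],
};
```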
2 changes: 1 addition & 1 deletion js/common/lib/onnx-value.ts
@@ -3,7 +3,7 @@
 
 import {Tensor} from './tensor.js';
 
-type NonTensorType = never;
+export type NonTensorType = never;
 
 /**
  * Type OnnxValue Represents both tensors and non-tensors value for model's inputs/outputs.
2 changes: 1 addition & 1 deletion js/common/lib/tensor-factory.ts
@@ -253,7 +253,7 @@ export interface TensorFactory {
   /**
    * create a tensor from an ImageBitmap object
    *
-   * @param bitMap - the ImageBitmap object to create tensor from
+   * @param bitmap - the ImageBitmap object to create tensor from
    * @param options - An optional object representing options for creating tensor from URL.
    *
    * The following default settings will be applied:
4 changes: 2 additions & 2 deletions js/common/lib/tensor.ts
@@ -160,7 +160,7 @@ export interface Tensor extends TypedTensorBase<Tensor.Type>, TypedTensorUtils<Tensor.Type> {
 /**
  * type TensorConstructor defines the constructors of 'Tensor' to create CPU tensor instances.
  */
-export interface TensorConstructor {
+export interface TensorConstructor extends TensorFactory {
   // #region CPU tensor - specify element type
   /**
    * Construct a new string tensor object from the given type, data and dims.
@@ -326,4 +326,4 @@
 }
 
 // eslint-disable-next-line @typescript-eslint/naming-convention
-export const Tensor = TensorImpl as (TensorConstructor & TensorFactory);
+export const Tensor = TensorImpl as TensorConstructor;
9 changes: 9 additions & 0 deletions js/common/lib/trace.ts
@@ -3,6 +3,9 @@
 
 import {env} from './env-impl.js';
 
+/**
+ * @ignore
+ */
 export const TRACE = (deviceType: string, label: string) => {
   if (typeof env.trace === 'undefined' ? !env.wasm.trace : !env.trace) {
     return;
@@ -29,13 +32,19 @@ const TRACE_FUNC = (msg: string, extraMsg?: string) => {
   }
 };
 
+/**
+ * @ignore
+ */
 export const TRACE_FUNC_BEGIN = (extraMsg?: string) => {
   if (typeof env.trace === 'undefined' ? !env.wasm.trace : !env.trace) {
     return;
   }
   TRACE_FUNC('BEGIN', extraMsg);
 };
 
+/**
+ * @ignore
+ */
 export const TRACE_FUNC_END = (extraMsg?: string) => {
   if (typeof env.trace === 'undefined' ? !env.wasm.trace : !env.trace) {
     return;
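
These helpers are gated on a single flag; note the guard's precedence: when `env.trace` is defined it wins, otherwise the older `env.wasm.trace` is consulted. A hedged usage sketch:

```ts
import {env} from 'onnxruntime-common';

// Turn tracing on for the whole library; TRACE/TRACE_FUNC_BEGIN/TRACE_FUNC_END
// stop returning early and emit their timing marks.
env.trace = true;

// Explicitly setting it to false disables tracing even if env.wasm.trace is
// still true, because a defined env.trace takes precedence in the guard.
env.trace = false;
```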
11 changes: 5 additions & 6 deletions js/common/lib/training-session-impl.ts
@@ -1,7 +1,7 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
-import {resolveBackend} from './backend-impl.js';
+import {resolveBackendAndExecutionProviders} from './backend-impl.js';
 import {SessionHandler, TrainingSessionHandler} from './backend.js';
 import {InferenceSession as InferenceSession} from './inference-session.js';
 import {OnnxValue} from './onnx-value.js';
@@ -55,13 +55,12 @@ export class TrainingSession implements TrainingSessionInterface {
     const optimizerModel: string|Uint8Array = trainingOptions.optimizerModel || '';
     const options: SessionOptions = sessionOptions || {};
 
-    // get backend hints
-    const eps = options.executionProviders || [];
-    const backendHints = eps.map(i => typeof i === 'string' ? i : i.name);
-    const backend = await resolveBackend(backendHints);
+    // resolve backend, update session options with validated EPs, and create session handler
+    const [backend, optionsWithValidatedEPs] = await resolveBackendAndExecutionProviders(options);
     if (backend.createTrainingSessionHandler) {
       const handler = await backend.createTrainingSessionHandler(
-          trainingOptions.checkpointState, trainingOptions.trainModel, evalModel, optimizerModel, options);
+          trainingOptions.checkpointState, trainingOptions.trainModel, evalModel, optimizerModel,
+          optionsWithValidatedEPs);
       return new TrainingSession(handler, !!trainingOptions.optimizerModel, !!trainingOptions.evalModel);
     } else {
       throw new Error(noBackendErrMsg);
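
A hedged usage sketch of the updated path (artifact file names are placeholders, and a backend with training support must be registered): the training artifacts go in the first argument, and the session options in the second now flow through the same EP validation as inference sessions.

```ts
import {TrainingSession} from 'onnxruntime-common';

// Training artifacts are produced offline by onnxruntime training tooling;
// the paths below are placeholders.
const trainingSession = await TrainingSession.create(
    {
      checkpointState: './checkpoint',
      trainModel: './training_model.onnx',
      evalModel: './eval_model.onnx',
      optimizerModel: './optimizer_model.onnx',
    },
    {executionProviders: ['cpu']});

console.log(trainingSession.trainingInputNames);
```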