[js/web] revise backend registration #18715

Merged (14 commits) on Dec 20, 2023
js/common/lib/backend-impl.ts (1 addition, 1 deletion)
@@ -82,7 +82,7 @@ export const resolveBackend = async(backendHints: readonly string[]): Promise<Ba
       const isInitializing = !!backendInfo.initPromise;
       try {
         if (!isInitializing) {
-          backendInfo.initPromise = backendInfo.backend.init();
+          backendInfo.initPromise = backendInfo.backend.init(backendName);
         }
         await backendInfo.initPromise;
         backendInfo.initialized = true;
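For context, the surrounding `resolveBackend` logic keeps a per-backend `initPromise` so that `init(backendName)` runs at most once per registered name, even when several sessions resolve the same backend concurrently. A minimal sketch of that memoization pattern; `BackendInfo`, `backendsMap`, and `ensureInitialized` are illustrative names, not the actual identifiers:

```ts
// Hypothetical sketch: run a backend's async init() at most once per registered name.
interface BackendInfo {
  backend: {init(backendName: string): Promise<void>};
  initPromise?: Promise<void>;
  initialized?: boolean;
}

const backendsMap = new Map<string, BackendInfo>();

async function ensureInitialized(backendName: string): Promise<void> {
  const info = backendsMap.get(backendName);
  if (!info || info.initialized) {
    return;
  }
  // Reuse the in-flight promise if another caller already kicked off init().
  if (!info.initPromise) {
    info.initPromise = info.backend.init(backendName);
  }
  await info.initPromise;
  info.initialized = true;
}
```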
js/common/lib/backend.ts (1 addition, 1 deletion)
@@ -71,7 +71,7 @@ export interface Backend {
   /**
    * Initialize the backend asynchronously. Should throw when failed.
    */
-  init(): Promise<void>;
+  init(backendName: string): Promise<void>;

   createInferenceSessionHandler(uriOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions):
       Promise<InferenceSessionHandler>;
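Because a single Backend instance can be registered under several names (as the js/web changes below do for 'webgpu' and 'cpu'), the new `backendName` parameter tells the implementation which registered name is being initialized. A hedged sketch of a custom backend that uses it; the class name, EP names, and priority values are illustrative, while `Backend`, `registerBackend`, and the handler types come from onnxruntime-common:

```ts
import {Backend, InferenceSession, InferenceSessionHandler, registerBackend} from 'onnxruntime-common';

// Hypothetical backend that serves two registered EP names with one instance.
class MyDualNameBackend implements Backend {
  async init(backendName: string): Promise<void> {
    // backendName is whichever registered name was resolved: 'my-gpu' or 'my-cpu' here.
    if (backendName === 'my-gpu') {
      // ...GPU-specific setup would go here...
    }
    // ...shared runtime setup would go here...
  }

  createInferenceSessionHandler(_uriOrBuffer: string|Uint8Array, _options?: InferenceSession.SessionOptions):
      Promise<InferenceSessionHandler> {
    throw new Error('not implemented in this sketch');
  }
}

const myBackend = new MyDualNameBackend();
registerBackend('my-gpu', myBackend, 11);
registerBackend('my-cpu', myBackend, 10);
```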
js/web/lib/backend-wasm.ts (14 additions, 3 deletions)
@@ -4,7 +4,7 @@
 import {cpus} from 'node:os';
 import {Backend, env, InferenceSession, InferenceSessionHandler} from 'onnxruntime-common';

-import {initializeWebAssemblyInstance} from './wasm/proxy-wrapper';
+import {initializeOrtEp, initializeWebAssemblyAndOrtRuntime} from './wasm/proxy-wrapper';
 import {OnnxruntimeWebAssemblySessionHandler} from './wasm/session-handler-inference';

 /**
@@ -33,12 +33,23 @@ export const initializeFlags = (): void => {
 };

 export class OnnxruntimeWebAssemblyBackend implements Backend {
-  async init(): Promise<void> {
+  /**
+   * This function initializes the WebAssembly backend.
+   *
+   * This function will be called only once for each backend name. It will be called the first time when
+   * `ort.InferenceSession.create()` is called with a registered backend name.
+   *
+   * @param backendName - the registered backend name.
+   */
+  async init(backendName: string): Promise<void> {
     // populate wasm flags
     initializeFlags();

     // init wasm
-    await initializeWebAssemblyInstance();
+    await initializeWebAssemblyAndOrtRuntime();
+
+    // perform EP-specific initialization
+    await initializeOrtEp(backendName);
   }
   createInferenceSessionHandler(path: string, options?: InferenceSession.SessionOptions):
       Promise<InferenceSessionHandler>;
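The comment added above pins down the call timing: `init('webgpu')`, `init('cpu')`, and so on run once, lazily, the first time `ort.InferenceSession.create()` resolves to that backend name; later sessions reuse the already-initialized backend. A hedged usage sketch (the model paths are illustrative):

```ts
import * as ort from 'onnxruntime-web';

async function main() {
  // The first create() that resolves to the WebGPU EP triggers
  // OnnxruntimeWebAssemblyBackend.init('webgpu') exactly once.
  const first = await ort.InferenceSession.create('./model-a.onnx', {executionProviders: ['webgpu']});

  // A second WebGPU session reuses the initialized backend; init() is not called again.
  const second = await ort.InferenceSession.create('./model-b.onnx', {executionProviders: ['webgpu']});

  return [first, second];
}
```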
js/web/lib/index.ts (1 addition, 1 deletion)
@@ -21,7 +21,7 @@ if (!BUILD_DEFS.DISABLE_WEBGL) {
 if (!BUILD_DEFS.DISABLE_WASM) {
   const wasmBackend = BUILD_DEFS.DISABLE_TRAINING ? require('./backend-wasm-inference').wasmBackend :
                                                     require('./backend-wasm-training').wasmBackend;
-  if (!BUILD_DEFS.DISABLE_WEBGPU && typeof navigator !== 'undefined' && navigator.gpu) {
+  if (!BUILD_DEFS.DISABLE_WEBGPU) {
     registerBackend('webgpu', wasmBackend, 5);
   }
   registerBackend('cpu', wasmBackend, 10);
js/web/lib/wasm/jsep/backend-webgpu.ts (1 addition, 11 deletions)
@@ -144,17 +144,7 @@ export class WebGpuBackend {
    */
  sessionExternalDataMapping: Map<number, Map<number, [number, GPUBuffer]>> = new Map();

-  async initialize(env: Env): Promise<void> {
-    if (!navigator.gpu) {
-      // WebGPU is not available.
-      throw new Error('WebGpuBackend: WebGPU is not available.');
-    }
-
-    const adapter = await navigator.gpu.requestAdapter();
-    if (!adapter) {
-      throw new Error('WebGpuBackend: Failed to get GPU adapter.');
-    }
-
+  async initialize(env: Env, adapter: GPUAdapter): Promise<void> {
     this.env = env;
     const requiredFeatures: GPUFeatureName[] = [];
     const deviceDescriptor: GPUDeviceDescriptor = {
js/web/lib/wasm/jsep/init.ts (18 additions, 4 deletions)
@@ -130,17 +130,31 @@ class ComputeContextImpl implements ComputeContext {
   }
 }

+/**
+ * Initialize JSEP with WebGPU backend.
+ *
+ * This function will be called only once after the WebAssembly module is loaded and initialized ("_OrtInit" is called).
+ * This function is expected to be available only when WebGPU is enabled in the build (BUILD_DEFS.DISABLE_WEBGPU === false).
+ * If WebGPU is not available in the current environment, this function should not throw, but simply return.
+ *
+ * @param module - the ORT WebAssembly module
+ * @param env - the ORT environment variable (ort.env)
+ */
 export const init = async(module: OrtWasmModule, env: Env): Promise<void> => {
-  const init = module.jsepInit;
-  if (init && navigator.gpu) {
+  const jsepInit = module.jsepInit;
+  if (!jsepInit) {
+    return;
+  }
+  const gpuAdapter = await navigator.gpu?.requestAdapter();
+  if (gpuAdapter) {
     if (!env.wasm.simd) {
       throw new Error(
           'Not supported for WebGPU=ON and SIMD=OFF. Please set `env.wasm.simd` to true when using WebGPU EP');
     }
     const backend = new WebGpuBackend();
-    await backend.initialize(env);
+    await backend.initialize(env, gpuAdapter);

-    init(
+    jsepInit(
         // backend
         backend,

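Two consequences of this change are visible above: WebGPU availability is now probed here with `navigator.gpu?.requestAdapter()` instead of at backend-registration time, and the wasm build must keep SIMD enabled when the WebGPU EP is used. An application-side sketch of the same checks, with an illustrative fallback order:

```ts
import * as ort from 'onnxruntime-web';

async function createSession(modelUrl: string): Promise<ort.InferenceSession> {
  // Feature-detect WebGPU the same way init() does: optional chaining keeps this
  // safe in environments where navigator.gpu does not exist at all.
  const webgpuAvailable = typeof navigator !== 'undefined' && !!(await navigator.gpu?.requestAdapter());

  // The WebGPU EP requires the SIMD-enabled wasm build (SIMD is on by default).
  ort.env.wasm.simd = true;

  return ort.InferenceSession.create(modelUrl, {
    executionProviders: webgpuAvailable ? ['webgpu', 'wasm'] : ['wasm'],
  });
}
```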
js/web/lib/wasm/proxy-messages.ts (31 additions, 22 deletions)
@@ -3,6 +3,9 @@

import type {Env, InferenceSession, Tensor} from 'onnxruntime-common';

/**
* Among all the tensor locations, only 'cpu' is serializable.
*/
export type SerializableTensorMetadata =
[dataType: Tensor.Type, dims: readonly number[], data: Tensor.DataType, location: 'cpu'];

@@ -12,51 +15,61 @@ export type GpuBufferMetadata = {
dispose?: () => void;
};

/**
* Tensors on location 'cpu-pinned' and 'gpu-buffer' are not serializable.
*/
export type UnserializableTensorMetadata =
[dataType: Tensor.Type, dims: readonly number[], data: GpuBufferMetadata, location: 'gpu-buffer']|
[dataType: Tensor.Type, dims: readonly number[], data: Tensor.DataType, location: 'cpu-pinned'];

/**
* Tensor metadata is a tuple of [dataType, dims, data, location], where
* - dataType: tensor data type
* - dims: tensor dimensions
* - data: tensor data, which can be one of the following depending on the location:
* - cpu: Uint8Array
* - cpu-pinned: Uint8Array
* - gpu-buffer: GpuBufferMetadata
* - location: tensor data location
*/
export type TensorMetadata = SerializableTensorMetadata|UnserializableTensorMetadata;

export type SerializableSessionMetadata = [sessionHandle: number, inputNames: string[], outputNames: string[]];

export type SerializableModeldata = [modelDataOffset: number, modelDataLength: number];
export type SerializableInternalBuffer = [bufferOffset: number, bufferLength: number];

interface MessageError {
err?: string;
}

interface MessageInitWasm extends MessageError {
type: 'init-wasm';
in ?: Env.WebAssemblyFlags;
}

interface MessageInitOrt extends MessageError {
type: 'init-ort';
in ?: Env;
out?: never;
}

interface MessageCreateSessionAllocate extends MessageError {
type: 'create_allocate';
in ?: {model: Uint8Array};
out?: SerializableModeldata;
interface MessageInitEp extends MessageError {
type: 'init-ep';
in ?: {env: Env; epName: string};
out?: never;
}

interface MessageCreateSessionFinalize extends MessageError {
type: 'create_finalize';
in ?: {modeldata: SerializableModeldata; options?: InferenceSession.SessionOptions};
out?: SerializableSessionMetadata;
interface MessageCopyFromExternalBuffer extends MessageError {
type: 'copy-from';
in ?: {buffer: Uint8Array};
out?: SerializableInternalBuffer;
}

interface MessageCreateSession extends MessageError {
type: 'create';
in ?: {model: Uint8Array; options?: InferenceSession.SessionOptions};
in ?: {model: SerializableInternalBuffer|Uint8Array; options?: InferenceSession.SessionOptions};
out?: SerializableSessionMetadata;
}

interface MessageReleaseSession extends MessageError {
type: 'release';
in ?: number;
out?: never;
}

interface MessageRun extends MessageError {
Expand All @@ -71,12 +84,8 @@ interface MessageRun extends MessageError {
interface MesssageEndProfiling extends MessageError {
type: 'end-profiling';
in ?: number;
out?: never;
}

interface MessageIsOrtEnvInitialized extends MessageError {
type: 'is-ort-env-initialized';
out?: boolean;
}

export type OrtWasmMessage = MessageInitWasm|MessageInitOrt|MessageCreateSessionAllocate|MessageCreateSessionFinalize|
MessageCreateSession|MessageReleaseSession|MessageRun|MesssageEndProfiling|MessageIsOrtEnvInitialized;
export type OrtWasmMessage = MessageInitWasm|MessageInitEp|MessageCopyFromExternalBuffer|MessageCreateSession|
MessageReleaseSession|MessageRun|MesssageEndProfiling;
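The serializable/unserializable split documented above is what the proxy worker relies on: only 'cpu' tensors can be posted across the worker boundary, so the worker code below rejects any output whose location is not 'cpu'. A small sketch using the exported types; the relative import path is assumed, and construction of the GPU-resident tensor is elided because it needs a live GPUBuffer:

```ts
import type {SerializableTensorMetadata, TensorMetadata, UnserializableTensorMetadata} from './proxy-messages';

// A 'cpu' tensor: its typed array can be cloned or transferred to/from the worker.
const cpuTensor: SerializableTensorMetadata = ['float32', [2, 2], new Float32Array(4), 'cpu'];

// A 'gpu-buffer' tensor wraps a live GPUBuffer and must stay in the same thread.
declare const gpuTensor: UnserializableTensorMetadata;

const outputs: TensorMetadata[] = [cpuTensor, gpuTensor];

// Mirror of the worker-side guard: only all-'cpu' outputs may cross the boundary.
const canPostToMainThread = outputs.every(o => o[3] === 'cpu');
```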
js/web/lib/wasm/proxy-worker/main.ts (57 additions, 79 deletions)
@@ -36,104 +36,82 @@ declare global {
}

import {OrtWasmMessage, SerializableTensorMetadata} from '../proxy-messages';
import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, extractTransferableBuffers, initRuntime, isOrtEnvInitialized, releaseSession, run} from '../wasm-core-impl';
import {createSession, copyFromExternalBuffer, endProfiling, extractTransferableBuffers, initEp, initRuntime, releaseSession, run} from '../wasm-core-impl';
import {initializeWebAssembly} from '../wasm-factory';

self.onmessage = (ev: MessageEvent<OrtWasmMessage>): void => {
switch (ev.data.type) {
case 'init-wasm':
try {
initializeWebAssembly(ev.data.in!)
const {type, in : message} = ev.data;
try {
switch (type) {
case 'init-wasm':
initializeWebAssembly(message!.wasm)
.then(
() => postMessage({type: 'init-wasm'} as OrtWasmMessage),
err => postMessage({type: 'init-wasm', err} as OrtWasmMessage));
} catch (err) {
postMessage({type: 'init-wasm', err} as OrtWasmMessage);
}
break;
case 'init-ort':
try {
initRuntime(ev.data.in!).then(() => postMessage({type: 'init-ort'} as OrtWasmMessage), err => postMessage({
type: 'init-ort',
err
} as OrtWasmMessage));
} catch (err) {
postMessage({type: 'init-ort', err} as OrtWasmMessage);
}
break;
case 'create_allocate':
try {
const {model} = ev.data.in!;
const modeldata = createSessionAllocate(model);
postMessage({type: 'create_allocate', out: modeldata} as OrtWasmMessage);
} catch (err) {
postMessage({type: 'create_allocate', err} as OrtWasmMessage);
() => {
initRuntime(message!).then(
() => {
postMessage({type});
},
err => {
postMessage({type, err});
});
},
err => {
postMessage({type, err});
});
break;
case 'init-ep': {
const {epName, env} = message!;
initEp(env, epName)
.then(
() => {
postMessage({type});
},
err => {
postMessage({type, err});
});
break;
}
break;
case 'create_finalize':
try {
const {modeldata, options} = ev.data.in!;
const sessionMetadata = createSessionFinalize(modeldata, options);
postMessage({type: 'create_finalize', out: sessionMetadata} as OrtWasmMessage);
} catch (err) {
postMessage({type: 'create_finalize', err} as OrtWasmMessage);
case 'copy-from': {
const {buffer} = message!;
const bufferData = copyFromExternalBuffer(buffer);
postMessage({type, out: bufferData} as OrtWasmMessage);
break;
}
break;
case 'create':
try {
const {model, options} = ev.data.in!;
case 'create': {
const {model, options} = message!;
const sessionMetadata = createSession(model, options);
postMessage({type: 'create', out: sessionMetadata} as OrtWasmMessage);
} catch (err) {
postMessage({type: 'create', err} as OrtWasmMessage);
postMessage({type, out: sessionMetadata} as OrtWasmMessage);
break;
}
break;
case 'release':
try {
releaseSession(ev.data.in!);
postMessage({type: 'release'} as OrtWasmMessage);
} catch (err) {
postMessage({type: 'release', err} as OrtWasmMessage);
}
break;
case 'run':
try {
const {sessionId, inputIndices, inputs, outputIndices, options} = ev.data.in!;
case 'release':
releaseSession(message!);
postMessage({type});
break;
case 'run': {
const {sessionId, inputIndices, inputs, outputIndices, options} = message!;
run(sessionId, inputIndices, inputs, outputIndices, new Array(outputIndices.length).fill(null), options)
.then(
outputs => {
if (outputs.some(o => o[3] !== 'cpu')) {
postMessage({type: 'run', err: 'Proxy does not support non-cpu tensor location.'});
postMessage({type, err: 'Proxy does not support non-cpu tensor location.'});
} else {
postMessage(
{type: 'run', out: outputs} as OrtWasmMessage,
{type, out: outputs} as OrtWasmMessage,
extractTransferableBuffers(outputs as SerializableTensorMetadata[]));
}
},
err => {
postMessage({type: 'run', err} as OrtWasmMessage);
postMessage({type, err});
});
} catch (err) {
postMessage({type: 'run', err} as OrtWasmMessage);
}
break;
case 'end-profiling':
try {
const handler = ev.data.in!;
endProfiling(handler);
postMessage({type: 'end-profiling'} as OrtWasmMessage);
} catch (err) {
postMessage({type: 'end-profiling', err} as OrtWasmMessage);
}
break;
case 'is-ort-env-initialized':
try {
const ortEnvInitialized = isOrtEnvInitialized();
postMessage({type: 'is-ort-env-initialized', out: ortEnvInitialized} as OrtWasmMessage);
} catch (err) {
postMessage({type: 'is-ort-env-initialized', err} as OrtWasmMessage);
break;
}
break;
default:
case 'end-profiling':
endProfiling(message!);
postMessage({type});
break;
default:
}
} catch (err) {
postMessage({type, err} as OrtWasmMessage);
}
};
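For completeness, the other side of this protocol (the proxy wrapper on the main thread) posts an OrtWasmMessage and waits for the echo of the same `type`, with `err` set on failure. The actual implementation lives in proxy-wrapper.ts and is not part of this hunk; the sketch below only illustrates the 'init-ep' round trip implied by the message types, and worker construction is elided:

```ts
import type {Env} from 'onnxruntime-common';
import type {OrtWasmMessage} from '../proxy-messages';

// Hedged sketch: resolve when the worker answers 'init-ep', reject if it reports an error.
function initEpViaProxy(worker: Worker, env: Env, epName: string): Promise<void> {
  return new Promise((resolve, reject) => {
    const onMessage = (ev: MessageEvent<OrtWasmMessage>): void => {
      if (ev.data.type === 'init-ep') {
        worker.removeEventListener('message', onMessage);
        if (ev.data.err) {
          reject(new Error(ev.data.err));
        } else {
          resolve();
        }
      }
    };
    worker.addEventListener('message', onMessage);
    worker.postMessage({type: 'init-ep', in: {env, epName}} as OrtWasmMessage);
  });
}
```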