Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[js/web] revise backend registration #18715

Merged
merged 14 commits into from
Dec 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion js/common/lib/backend-impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ export const resolveBackend = async(backendHints: readonly string[]): Promise<Ba
const isInitializing = !!backendInfo.initPromise;
try {
if (!isInitializing) {
backendInfo.initPromise = backendInfo.backend.init();
backendInfo.initPromise = backendInfo.backend.init(backendName);
}
await backendInfo.initPromise;
backendInfo.initialized = true;
Expand Down
2 changes: 1 addition & 1 deletion js/common/lib/backend.ts
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ export interface Backend {
/**
* Initialize the backend asynchronously. Should throw when failed.
*/
init(): Promise<void>;
init(backendName: string): Promise<void>;

createInferenceSessionHandler(uriOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions):
Promise<InferenceSessionHandler>;
Expand Down
17 changes: 14 additions & 3 deletions js/web/lib/backend-wasm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import {cpus} from 'node:os';
import {Backend, env, InferenceSession, InferenceSessionHandler} from 'onnxruntime-common';

import {initializeWebAssemblyInstance} from './wasm/proxy-wrapper';
import {initializeOrtEp, initializeWebAssemblyAndOrtRuntime} from './wasm/proxy-wrapper';
import {OnnxruntimeWebAssemblySessionHandler} from './wasm/session-handler-inference';

/**
Expand Down Expand Up @@ -33,12 +33,23 @@ export const initializeFlags = (): void => {
};

export class OnnxruntimeWebAssemblyBackend implements Backend {
async init(): Promise<void> {
/**
 * This function initializes the WebAssembly backend.
 *
 * This function will be called only once for each backend name. It will be called the first time when
 * `ort.InferenceSession.create()` is called with a registered backend name.
 *
 * @param backendName - the registered backend name (see the `registerBackend` calls in index.ts).
 */
async init(backendName: string): Promise<void> {
  // populate wasm flags from the environment (ort.env.wasm)
  initializeFlags();

  // load and initialize the WebAssembly module and the ONNX Runtime
  await initializeWebAssemblyAndOrtRuntime();

  // perform EP-specific initialization for the requested backend name
  await initializeOrtEp(backendName);
}
createInferenceSessionHandler(path: string, options?: InferenceSession.SessionOptions):
Promise<InferenceSessionHandler>;
Expand Down
2 changes: 1 addition & 1 deletion js/web/lib/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ if (!BUILD_DEFS.DISABLE_WEBGL) {
if (!BUILD_DEFS.DISABLE_WASM) {
const wasmBackend = BUILD_DEFS.DISABLE_TRAINING ? require('./backend-wasm-inference').wasmBackend :
require('./backend-wasm-training').wasmBackend;
if (!BUILD_DEFS.DISABLE_WEBGPU && typeof navigator !== 'undefined' && navigator.gpu) {
if (!BUILD_DEFS.DISABLE_WEBGPU) {
registerBackend('webgpu', wasmBackend, 5);
}
registerBackend('cpu', wasmBackend, 10);
Expand Down
12 changes: 1 addition & 11 deletions js/web/lib/wasm/jsep/backend-webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -144,17 +144,7 @@ export class WebGpuBackend {
*/
sessionExternalDataMapping: Map<number, Map<number, [number, GPUBuffer]>> = new Map();

async initialize(env: Env): Promise<void> {
if (!navigator.gpu) {
// WebGPU is not available.
throw new Error('WebGpuBackend: WebGPU is not available.');
}

const adapter = await navigator.gpu.requestAdapter();
if (!adapter) {
throw new Error('WebGpuBackend: Failed to get GPU adapter.');
}

async initialize(env: Env, adapter: GPUAdapter): Promise<void> {
this.env = env;
const requiredFeatures: GPUFeatureName[] = [];
const deviceDescriptor: GPUDeviceDescriptor = {
Expand Down
130 changes: 71 additions & 59 deletions js/web/lib/wasm/jsep/init.ts
Original file line number Diff line number Diff line change
Expand Up @@ -130,64 +130,76 @@ class ComputeContextImpl implements ComputeContext {
}
}

export const init = async(module: OrtWasmModule, env: Env): Promise<void> => {
const init = module.jsepInit;
if (init && navigator.gpu) {
if (!env.wasm.simd) {
throw new Error(
'Not supported for WebGPU=ON and SIMD=OFF. Please set `env.wasm.simd` to true when using WebGPU EP');
}
const backend = new WebGpuBackend();
await backend.initialize(env);

init(
// backend
backend,

// jsepAlloc()
(size: number) => backend.alloc(size),

// jsepFree()
(ptr: number) => backend.free(ptr),

// jsepCopy(src, dst, size, isSourceGpu)
(src: number, dst: number, size: number, isSourceGpu = false) => {
if (isSourceGpu) {
LOG_DEBUG('verbose', () => `[WebGPU] jsepCopyGpuToGpu: src=${src}, dst=${dst}, size=${size}`);
backend.memcpy(src, dst);
} else {
LOG_DEBUG('verbose', () => `[WebGPU] jsepCopyCpuToGpu: dataOffset=${src}, gpuDataId=${dst}, size=${size}`);
const data = module.HEAPU8.subarray(src, src + size);
backend.upload(dst, data);
}
},

// jsepCopyAsync(src, dst, size)
async(gpuDataId: number, dataOffset: number, size: number):
Promise<void> => {
LOG_DEBUG(
'verbose',
() => `[WebGPU] jsepCopyGpuToCpu: gpuDataId=${gpuDataId}, dataOffset=${dataOffset}, size=${size}`);

await backend.download(gpuDataId, () => module.HEAPU8.subarray(dataOffset, dataOffset + size));
},

// jsepCreateKernel
(name: string, kernel: number, attribute: unknown) => backend.createKernel(
name, kernel, attribute,
env.debug || backend.isQueryEnabled() ? module.UTF8ToString(module._JsepGetNodeName(kernel)) : `${kernel}`),

// jsepReleaseKernel
(kernel: number) => backend.releaseKernel(kernel),

// jsepRun
(kernel: number, contextDataOffset: number, sessionHandle: number, errors: Array<Promise<string|null>>) => {
LOG_DEBUG(
'verbose',
() => `[WebGPU] jsepRun: sessionHandle=${sessionHandle}, kernel=${kernel}, contextDataOffset=${
contextDataOffset}`);
const context = new ComputeContextImpl(module, backend, contextDataOffset);
return backend.computeKernel(kernel, context, errors);
});
/**
 * Initialize JSEP with the WebGPU backend.
 *
 * This function will be called only once after the WebAssembly module is loaded and initialized
 * ("_OrtInit" is called). It expects:
 * - WebGPU is enabled in build (BUILD_DEFS.DISABLE_WEBGPU === false).
 * - WebGPU is available in the current environment (a valid GPUAdapter is passed in).
 * If the WebAssembly module is not built with JSEP support, this function will throw an error,
 * which invalidates the 'webgpu' backend.
 *
 * @param module - the ORT WebAssembly module
 * @param env - the ORT environment variable (ort.env)
 * @param gpuAdapter - the pre-created GPU adapter
 */
export const init = async(module: OrtWasmModule, env: Env, gpuAdapter: GPUAdapter): Promise<void> => {
  const jsepInit = module.jsepInit;
  if (!jsepInit) {
    throw new Error('Failed to initialize JSEP. The WebAssembly module is not built with JSEP support.');
  }

  // Create the WebGPU backend and initialize it with the pre-created adapter.
  const backend = new WebGpuBackend();
  await backend.initialize(env, gpuAdapter);

  // Allocate `size` bytes of GPU data; returns its GPU data ID.
  const jsepAlloc = (size: number) => backend.alloc(size);

  // Release GPU data previously allocated via jsepAlloc.
  const jsepFree = (ptr: number) => backend.free(ptr);

  // Copy data, either GPU-to-GPU or from the WebAssembly heap to GPU, depending on `isSourceGpu`.
  const jsepCopy = (src: number, dst: number, size: number, isSourceGpu = false) => {
    if (isSourceGpu) {
      LOG_DEBUG('verbose', () => `[WebGPU] jsepCopyGpuToGpu: src=${src}, dst=${dst}, size=${size}`);
      backend.memcpy(src, dst);
    } else {
      LOG_DEBUG('verbose', () => `[WebGPU] jsepCopyCpuToGpu: dataOffset=${src}, gpuDataId=${dst}, size=${size}`);
      backend.upload(dst, module.HEAPU8.subarray(src, src + size));
    }
  };

  // Asynchronously copy data from GPU back into the WebAssembly heap.
  const jsepCopyAsync = async(gpuDataId: number, dataOffset: number, size: number): Promise<void> => {
    LOG_DEBUG(
        'verbose',
        () => `[WebGPU] jsepCopyGpuToCpu: gpuDataId=${gpuDataId}, dataOffset=${dataOffset}, size=${size}`);

    await backend.download(gpuDataId, () => module.HEAPU8.subarray(dataOffset, dataOffset + size));
  };

  // Create a kernel. The display name is resolved from the graph node only when debugging/profiling
  // needs it; otherwise the numeric kernel ID is used to avoid the string round-trip.
  const jsepCreateKernel = (name: string, kernel: number, attribute: unknown) => backend.createKernel(
      name, kernel, attribute,
      env.debug || backend.isQueryEnabled() ? module.UTF8ToString(module._JsepGetNodeName(kernel)) : `${kernel}`);

  // Release a kernel previously created via jsepCreateKernel.
  const jsepReleaseKernel = (kernel: number) => backend.releaseKernel(kernel);

  // Run a kernel with a compute context constructed from the given context data offset.
  const jsepRun =
      (kernel: number, contextDataOffset: number, sessionHandle: number, errors: Array<Promise<string|null>>) => {
        LOG_DEBUG(
            'verbose',
            () => `[WebGPU] jsepRun: sessionHandle=${sessionHandle}, kernel=${kernel}, contextDataOffset=${
                contextDataOffset}`);
        return backend.computeKernel(kernel, new ComputeContextImpl(module, backend, contextDataOffset), errors);
      };

  // Hand everything over to the WebAssembly side. The parameter order is part of the jsepInit
  // contract and must not be changed.
  jsepInit(backend, jsepAlloc, jsepFree, jsepCopy, jsepCopyAsync, jsepCreateKernel, jsepReleaseKernel, jsepRun);
};
53 changes: 31 additions & 22 deletions js/web/lib/wasm/proxy-messages.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@

import type {Env, InferenceSession, Tensor} from 'onnxruntime-common';

/**
* Among all the tensor locations, only 'cpu' is serializable.
*/
export type SerializableTensorMetadata =
[dataType: Tensor.Type, dims: readonly number[], data: Tensor.DataType, location: 'cpu'];

Expand All @@ -12,51 +15,61 @@ export type GpuBufferMetadata = {
dispose?: () => void;
};

/**
* Tensors on location 'cpu-pinned' and 'gpu-buffer' are not serializable.
*/
export type UnserializableTensorMetadata =
[dataType: Tensor.Type, dims: readonly number[], data: GpuBufferMetadata, location: 'gpu-buffer']|
[dataType: Tensor.Type, dims: readonly number[], data: Tensor.DataType, location: 'cpu-pinned'];

/**
* Tensor metadata is a tuple of [dataType, dims, data, location], where
* - dataType: tensor data type
* - dims: tensor dimensions
* - data: tensor data, which can be one of the following depending on the location:
* - cpu: Uint8Array
* - cpu-pinned: Uint8Array
* - gpu-buffer: GpuBufferMetadata
* - location: tensor data location
*/
export type TensorMetadata = SerializableTensorMetadata|UnserializableTensorMetadata;

/**
 * Metadata of a created session: the native session handle plus the model's input and output names.
 */
export type SerializableSessionMetadata = [sessionHandle: number, inputNames: string[], outputNames: string[]];

export type SerializableModeldata = [modelDataOffset: number, modelDataLength: number];
/**
 * A buffer inside the WebAssembly heap, described as [byte offset, byte length].
 */
export type SerializableInternalBuffer = [bufferOffset: number, bufferLength: number];
fs-eire marked this conversation as resolved.
Show resolved Hide resolved

/**
 * Base shape shared by all proxy messages: an optional error string, set when the operation failed.
 */
interface MessageError {
  err?: string;
}

/**
 * Message requesting initialization of the WebAssembly instance with the given flags.
 */
interface MessageInitWasm extends MessageError {
  type: 'init-wasm';
  in ?: Env.WebAssemblyFlags;
}

/**
 * Message requesting initialization of the ONNX Runtime with the given environment. No output.
 */
interface MessageInitOrt extends MessageError {
  type: 'init-ort';
  in ?: Env;
  out?: never;
}

interface MessageCreateSessionAllocate extends MessageError {
type: 'create_allocate';
in ?: {model: Uint8Array};
out?: SerializableModeldata;
/**
 * Message requesting execution-provider specific initialization for the given EP name, using the
 * given environment. No output.
 */
interface MessageInitEp extends MessageError {
  type: 'init-ep';
  in ?: {env: Env; epName: string};
  out?: never;
}

interface MessageCreateSessionFinalize extends MessageError {
type: 'create_finalize';
in ?: {modeldata: SerializableModeldata; options?: InferenceSession.SessionOptions};
out?: SerializableSessionMetadata;
/**
 * Message requesting a copy of a user-provided buffer into the WebAssembly heap; the output
 * describes the internal copy as [offset, length].
 */
interface MessageCopyFromExternalBuffer extends MessageError {
  type: 'copy-from';
  in ?: {buffer: Uint8Array};
  out?: SerializableInternalBuffer;
}

interface MessageCreateSession extends MessageError {
type: 'create';
in ?: {model: Uint8Array; options?: InferenceSession.SessionOptions};
in ?: {model: SerializableInternalBuffer|Uint8Array; options?: InferenceSession.SessionOptions};
out?: SerializableSessionMetadata;
}

/**
 * Message requesting release of a session. The input number is presumably the session handle
 * (matches `sessionHandle` in SerializableSessionMetadata) — confirm against the worker handler.
 */
interface MessageReleaseSession extends MessageError {
  type: 'release';
  in ?: number;
  out?: never;
}

interface MessageRun extends MessageError {
Expand All @@ -71,12 +84,8 @@ interface MessageRun extends MessageError {
/**
 * Message requesting end of profiling for a session (input is a session handle-like number).
 * NOTE(review): the interface name carries a typo ("Messsage" with triple 's'); renaming it would
 * also require updating the OrtWasmMessage union, so it is left as-is here.
 */
interface MesssageEndProfiling extends MessageError {
  type: 'end-profiling';
  in ?: number;
  out?: never;
}

interface MessageIsOrtEnvInitialized extends MessageError {
type: 'is-ort-env-initialized';
out?: boolean;
}

export type OrtWasmMessage = MessageInitWasm|MessageInitOrt|MessageCreateSessionAllocate|MessageCreateSessionFinalize|
MessageCreateSession|MessageReleaseSession|MessageRun|MesssageEndProfiling|MessageIsOrtEnvInitialized;
/**
 * Discriminated union (on `type`) of all messages exchanged with the proxy worker.
 */
export type OrtWasmMessage = MessageInitWasm|MessageInitEp|MessageCopyFromExternalBuffer|MessageCreateSession|
    MessageReleaseSession|MessageRun|MesssageEndProfiling;
Loading
Loading