Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[js] enable external data loading for ort-web #19087

Merged
merged 11 commits into from
Jan 13, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion js/common/lib/inference-session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT License.

import {InferenceSession as InferenceSessionImpl} from './inference-session-impl.js';
import {OnnxModelOptions} from './onnx-model.js';
import {OnnxValue, OnnxValueDataLocation} from './onnx-value.js';

/* eslint-disable @typescript-eslint/no-redeclare */
Expand Down Expand Up @@ -43,7 +44,7 @@ export declare namespace InferenceSession {
/**
* A set of configurations for session behavior.
*/
export interface SessionOptions {
export interface SessionOptions extends OnnxModelOptions {
/**
* An array of execution provider options.
*
Expand Down
57 changes: 57 additions & 0 deletions js/common/lib/onnx-model.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

/**
* A string that represents a file's URL or path.
*
* Path is vailable only in onnxruntime-node or onnxruntime-web running in Node.js.
*/
export type FileUrlOrPath = string;

/**
* A Blob object that represents a file.
*/
export type FileBlob = Blob;

/**
* A Uint8Array, ArrayBuffer or SharedArrayBuffer object that represents a file content.
*
* When it is an ArrayBuffer or SharedArrayBuffer, the whole buffer is assumed to be the file content.
*/
export type FileData = Uint8Array|ArrayBufferLike;

/**
* Represents a file that can be loaded by the ONNX Runtime JavaScript API.
*/
export type FileType = FileUrlOrPath|FileBlob|FileData;

/**
* Represents an external data file.
*/
export interface ExternalDataFileDescription {
/**
* Specify the external data file.
*/
data: FileType;
/**
* Specify the file path.
*/
path: string;
}

/**
* Represents an external data file.
*
* When using a string, it should be a file URL or path that in the same directory as the model file.
*/
export type ExternalDataFileType = ExternalDataFileDescription|FileUrlOrPath;

/**
* Options for model loading.
*/
export interface OnnxModelOptions {
/**
* Specifying a list of files that represents the external data.
*/
externalData?: readonly ExternalDataFileType[];
}
5 changes: 5 additions & 0 deletions js/web/lib/wasm/binding/ort-wasm.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,11 @@ export interface OrtWasmModule extends EmscriptenModule {
mainScriptUrlOrBlob?: string|Blob;
// #endregion

// #region external data API
mountExternalData(externalDataFilePath: string, externalDataFileData: Uint8Array): void;
unmountExternalData(): void;
// #endregion

// #region JSEP
/**
* This is the entry of JSEP initialization. This function is called once when initializing ONNX Runtime.
Expand Down
10 changes: 8 additions & 2 deletions js/web/lib/wasm/proxy-worker/main.ts
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,14 @@ self.onmessage = (ev: MessageEvent<OrtWasmMessage>): void => {
}
case 'create': {
const {model, options} = message!;
const sessionMetadata = createSession(model, options);
postMessage({type, out: sessionMetadata} as OrtWasmMessage);
createSession(model, options)
.then(
sessionMetadata => {
postMessage({type, out: sessionMetadata} as OrtWasmMessage);
},
err => {
postMessage({type, err});
});
break;
}
case 'release':
Expand Down
14 changes: 4 additions & 10 deletions js/web/lib/wasm/session-handler-inference.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {readFile} from 'node:fs/promises';
import {InferenceSession, InferenceSessionHandler, SessionHandler, Tensor, TRACE_FUNC_BEGIN, TRACE_FUNC_END} from 'onnxruntime-common';

import {SerializableInternalBuffer, TensorMetadata} from './proxy-messages';
import {copyFromExternalBuffer, createSession, endProfiling, releaseSession, run} from './proxy-wrapper';
import {isGpuBufferSupportedType} from './wasm-common';
import {loadFile} from './wasm-utils-load-file';

export const encodeTensorMetadata = (tensor: Tensor, getName: () => string): TensorMetadata => {
switch (tensor.location) {
Expand Down Expand Up @@ -43,14 +43,8 @@ export class OnnxruntimeWebAssemblySessionHandler implements InferenceSessionHan
outputNames: string[];

async fetchModelAndCopyToWasmMemory(path: string): Promise<SerializableInternalBuffer> {
// fetch model from url and move to wasm heap. The arraybufffer that held the http
// response is freed once we return
const response = await fetch(path);
if (response.status !== 200) {
throw new Error(`failed to load model: ${path}`);
}
const arrayBuffer = await response.arrayBuffer();
return copyFromExternalBuffer(new Uint8Array(arrayBuffer));
// fetch model from url and move to wasm heap.
return copyFromExternalBuffer(await loadFile(path));
}

async loadModel(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): Promise<void> {
Expand All @@ -60,7 +54,7 @@ export class OnnxruntimeWebAssemblySessionHandler implements InferenceSessionHan
if (typeof pathOrBuffer === 'string') {
if (typeof process !== 'undefined' && process.versions && process.versions.node) {
// node
model = await readFile(pathOrBuffer);
model = await loadFile(pathOrBuffer);
} else {
// browser
// fetch model and copy to wasm heap.
Expand Down
199 changes: 108 additions & 91 deletions js/web/lib/wasm/wasm-core-impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import {setSessionOptions} from './session-options';
import {dataLocationStringToEnum, getTensorElementSize, isGpuBufferSupportedType, logLevelStringToEnum, tensorDataTypeEnumToString, tensorDataTypeStringToEnum, tensorTypeToTypedArrayConstructor} from './wasm-common';
import {getInstance} from './wasm-factory';
import {allocWasmString, checkLastError} from './wasm-utils';
import {loadFile} from './wasm-utils-load-file';

// #region Initializations

Expand Down Expand Up @@ -187,108 +188,124 @@ export const copyFromExternalBuffer = (model: Uint8Array): [number, number] => {
* @param options an optional session options object.
* @returns a 3-elements tuple containing [session handle, input names, output names]
*/
export const createSession =
(modelData: Uint8Array|SerializableInternalBuffer,
options?: InferenceSession.SessionOptions): SerializableSessionMetadata => {
let modelDataOffset: number, modelDataLength: number;
const wasm = getInstance();

if (Array.isArray(modelData)) {
// if model data is an array, it must be a 2-elements tuple containing the pointer and size of the model data
[modelDataOffset, modelDataLength] = modelData;
} else if (modelData.buffer === wasm.HEAPU8.buffer) {
// if model data uses the same buffer as the WASM heap, we don't need to copy it.
[modelDataOffset, modelDataLength] = [modelData.byteOffset, modelData.byteLength];
} else {
// otherwise, copy the model data to the WASM heap.
[modelDataOffset, modelDataLength] = copyFromExternalBuffer(modelData);
}
export const createSession = async(
modelData: Uint8Array|SerializableInternalBuffer,
options?: InferenceSession.SessionOptions): Promise<SerializableSessionMetadata> => {
let modelDataOffset: number, modelDataLength: number;
const wasm = getInstance();

let sessionHandle = 0;
let sessionOptionsHandle = 0;
let ioBindingHandle = 0;
let allocs: number[] = [];
const inputNamesUTF8Encoded = [];
const outputNamesUTF8Encoded = [];
if (Array.isArray(modelData)) {
// if model data is an array, it must be a 2-elements tuple containing the pointer and size of the model data
[modelDataOffset, modelDataLength] = modelData;
} else if (modelData.buffer === wasm.HEAPU8.buffer) {
// if model data uses the same buffer as the WASM heap, we don't need to copy it.
[modelDataOffset, modelDataLength] = [modelData.byteOffset, modelData.byteLength];
} else {
// otherwise, copy the model data to the WASM heap.
[modelDataOffset, modelDataLength] = copyFromExternalBuffer(modelData);
}

try {
[sessionOptionsHandle, allocs] = setSessionOptions(options);
let sessionHandle = 0;
let sessionOptionsHandle = 0;
let ioBindingHandle = 0;
let allocs: number[] = [];
const inputNamesUTF8Encoded = [];
const outputNamesUTF8Encoded = [];

sessionHandle = wasm._OrtCreateSession(modelDataOffset, modelDataLength, sessionOptionsHandle);
if (sessionHandle === 0) {
checkLastError('Can\'t create a session.');
}
try {
[sessionOptionsHandle, allocs] = setSessionOptions(options);

if (options?.externalData && options.externalData.length > 0) {
const loadingPromises = [];
for (const file of options.externalData) {
const path = typeof file === 'string' ? file : file.path;
loadingPromises.push(loadFile(typeof file === 'string' ? file : file.data).then(data => {
wasm.mountExternalData(path, data);
}));
}

const [inputCount, outputCount] = getSessionInputOutputCount(sessionHandle);
// wait for all external data files to be loaded
await Promise.all(loadingPromises);
}

const inputNames = [];
const outputNames = [];
const outputPreferredLocations: SupportedTensorDataLocationForInputOutput[] = [];
for (let i = 0; i < inputCount; i++) {
const name = wasm._OrtGetInputName(sessionHandle, i);
if (name === 0) {
checkLastError('Can\'t get an input name.');
}
inputNamesUTF8Encoded.push(name);
inputNames.push(wasm.UTF8ToString(name));
}
for (let i = 0; i < outputCount; i++) {
const name = wasm._OrtGetOutputName(sessionHandle, i);
if (name === 0) {
checkLastError('Can\'t get an output name.');
}
outputNamesUTF8Encoded.push(name);
const nameString = wasm.UTF8ToString(name);
outputNames.push(nameString);

if (!BUILD_DEFS.DISABLE_WEBGPU) {
const location = typeof options?.preferredOutputLocation === 'string' ?
options.preferredOutputLocation :
options?.preferredOutputLocation?.[nameString] ?? 'cpu';
if (location !== 'cpu' && location !== 'cpu-pinned' && location !== 'gpu-buffer') {
throw new Error(`Not supported preferred output location: ${location}.`);
}
outputPreferredLocations.push(location);
}
}
sessionHandle = wasm._OrtCreateSession(modelDataOffset, modelDataLength, sessionOptionsHandle);
if (sessionHandle === 0) {
checkLastError('Can\'t create a session.');
}

// use IO binding only when at least one output is preffered to be on GPU.
let bindingState: IOBindingState|null = null;
if (!BUILD_DEFS.DISABLE_WEBGPU && outputPreferredLocations.some(l => l === 'gpu-buffer')) {
ioBindingHandle = wasm._OrtCreateBinding(sessionHandle);
if (ioBindingHandle === 0) {
checkLastError('Can\'t create IO binding.');
}
const [inputCount, outputCount] = getSessionInputOutputCount(sessionHandle);

bindingState = {
handle: ioBindingHandle,
outputPreferredLocations,
outputPreferredLocationsEncoded: outputPreferredLocations.map(l => dataLocationStringToEnum(l)),
};
const inputNames = [];
const outputNames = [];
const outputPreferredLocations: SupportedTensorDataLocationForInputOutput[] = [];
for (let i = 0; i < inputCount; i++) {
const name = wasm._OrtGetInputName(sessionHandle, i);
if (name === 0) {
checkLastError('Can\'t get an input name.');
}
inputNamesUTF8Encoded.push(name);
inputNames.push(wasm.UTF8ToString(name));
}
for (let i = 0; i < outputCount; i++) {
const name = wasm._OrtGetOutputName(sessionHandle, i);
if (name === 0) {
checkLastError('Can\'t get an output name.');
}
outputNamesUTF8Encoded.push(name);
const nameString = wasm.UTF8ToString(name);
outputNames.push(nameString);

if (!BUILD_DEFS.DISABLE_WEBGPU) {
const location = typeof options?.preferredOutputLocation === 'string' ?
options.preferredOutputLocation :
options?.preferredOutputLocation?.[nameString] ?? 'cpu';
if (location !== 'cpu' && location !== 'cpu-pinned' && location !== 'gpu-buffer') {
throw new Error(`Not supported preferred output location: ${location}.`);
}
outputPreferredLocations.push(location);
}
}

activeSessions.set(sessionHandle, [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, bindingState]);
return [sessionHandle, inputNames, outputNames];
} catch (e) {
inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf));
outputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf));
// use IO binding only when at least one output is preffered to be on GPU.
let bindingState: IOBindingState|null = null;
if (!BUILD_DEFS.DISABLE_WEBGPU && outputPreferredLocations.some(l => l === 'gpu-buffer')) {
ioBindingHandle = wasm._OrtCreateBinding(sessionHandle);
if (ioBindingHandle === 0) {
checkLastError('Can\'t create IO binding.');
}

if (ioBindingHandle !== 0) {
wasm._OrtReleaseBinding(ioBindingHandle);
}
bindingState = {
handle: ioBindingHandle,
outputPreferredLocations,
outputPreferredLocationsEncoded: outputPreferredLocations.map(l => dataLocationStringToEnum(l)),
};
}

if (sessionHandle !== 0) {
wasm._OrtReleaseSession(sessionHandle);
}
throw e;
} finally {
wasm._free(modelDataOffset);
if (sessionOptionsHandle !== 0) {
wasm._OrtReleaseSessionOptions(sessionOptionsHandle);
}
allocs.forEach(alloc => wasm._free(alloc));
}
};
activeSessions.set(sessionHandle, [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, bindingState]);
return [sessionHandle, inputNames, outputNames];
} catch (e) {
inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf));
outputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf));

if (ioBindingHandle !== 0) {
wasm._OrtReleaseBinding(ioBindingHandle);
}

if (sessionHandle !== 0) {
wasm._OrtReleaseSession(sessionHandle);
}
throw e;
} finally {
wasm._free(modelDataOffset);
if (sessionOptionsHandle !== 0) {
wasm._OrtReleaseSessionOptions(sessionOptionsHandle);
}
allocs.forEach(alloc => wasm._free(alloc));

// unmount external data if necessary
wasm.unmountExternalData();
}
};

export const releaseSession = (sessionId: number): void => {
const wasm = getInstance();
Expand Down
Loading
Loading