[js] enable external data loading for ort-web (#19087)
### Description
enable external data loading for ort-web.

### Why
The ORT external data design depends heavily on the file system, especially on synchronous file I/O APIs, which are not available on the web platform. Extra code is needed to make external data work on the web.

### How
Since there is no file system on the web, the web implementation supports external data through pre-loaded data. Assume a model file a.onnx includes initializers linked to ./b.bin; we require users to pass a full list of data files when creating the session. The user code will look like this (a pre-loaded-data variant is sketched after the example):
```js
const mySess = await ort.InferenceSession.create('./path/model/a.onnx', {
  // session options
  externalData: [
    {
      // relative or absolute path/URL of the file,
      // or a pre-loaded Uint8Array containing the data of the external data file
      data: './path/data/b.bin', 

      // the relative path of the external data. Should match initializers' "location" value defined in the model file
      path: './b.bin'
    },
    // add more entries like the one above if the model uses multiple external data files
  ]
});
```
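
The `data` field may also be a pre-loaded `Uint8Array` instead of a URL or path. Below is a minimal sketch of that variant, reusing the placeholder file names from the example above and assuming it runs inside an async function or module context:

```ts
import * as ort from 'onnxruntime-web';

// Load the external data file up front and hand the raw bytes to the session.
// File names and URLs are the same placeholders as in the example above.
const response = await fetch('./path/data/b.bin');
const externalDataBytes = new Uint8Array(await response.arrayBuffer());

const mySess = await ort.InferenceSession.create('./path/model/a.onnx', {
  externalData: [
    {
      // a pre-loaded Uint8Array containing the contents of the external data file
      data: externalDataBytes,

      // must match the initializers' "location" value defined in the model file
      path: './b.bin'
    }
  ]
});
```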

Currently, this feature only works with the JSEP build enabled.
fs-eire authored Jan 13, 2024
1 parent e5eacc6 commit 07cfc56
Showing 19 changed files with 420 additions and 108 deletions.
3 changes: 2 additions & 1 deletion js/common/lib/inference-session.ts
@@ -2,6 +2,7 @@
// Licensed under the MIT License.

import {InferenceSession as InferenceSessionImpl} from './inference-session-impl.js';
import {OnnxModelOptions} from './onnx-model.js';
import {OnnxValue, OnnxValueDataLocation} from './onnx-value.js';

/* eslint-disable @typescript-eslint/no-redeclare */
@@ -43,7 +44,7 @@ export declare namespace InferenceSession {
/**
* A set of configurations for session behavior.
*/
export interface SessionOptions {
export interface SessionOptions extends OnnxModelOptions {
/**
* An array of execution provider options.
*
57 changes: 57 additions & 0 deletions js/common/lib/onnx-model.ts
@@ -0,0 +1,57 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

/**
* A string that represents a file's URL or path.
*
* Path is available only in onnxruntime-node or onnxruntime-web running in Node.js.
*/
export type FileUrlOrPath = string;

/**
* A Blob object that represents a file.
*/
export type FileBlob = Blob;

/**
* A Uint8Array, ArrayBuffer or SharedArrayBuffer object that represents the content of a file.
*
* When it is an ArrayBuffer or SharedArrayBuffer, the whole buffer is assumed to be the file content.
*/
export type FileData = Uint8Array|ArrayBufferLike;

/**
* Represents a file that can be loaded by the ONNX Runtime JavaScript API.
*/
export type FileType = FileUrlOrPath|FileBlob|FileData;

/**
* Represents an external data file.
*/
export interface ExternalDataFileDescription {
/**
* Specify the external data file.
*/
data: FileType;
/**
* Specify the file path.
*/
path: string;
}

/**
* Represents an external data file.
*
* When using a string, it should be a file URL or path that is in the same directory as the model file.
*/
export type ExternalDataFileType = ExternalDataFileDescription|FileUrlOrPath;

/**
* Options for model loading.
*/
export interface OnnxModelOptions {
/**
* Specifies a list of files that represent the external data.
*/
externalData?: readonly ExternalDataFileType[];
}
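
For reference, a short sketch of the two shapes that `ExternalDataFileType` accepts, assuming these types are re-exported from `onnxruntime-common`; the file names are placeholders:

```ts
import type {ExternalDataFileType} from 'onnxruntime-common';

const externalData: ExternalDataFileType[] = [
  // descriptor form: `data` may be a URL/path, a Blob, or raw bytes, and is exposed
  // at `path`, which must match the initializers' "location" value in the model file
  {data: './path/data/b.bin', path: './b.bin'},

  // string shorthand: a URL or path in the same directory as the model file;
  // the string is used both for loading the file and as its path
  './c.bin',
];
```
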
5 changes: 5 additions & 0 deletions js/web/lib/wasm/binding/ort-wasm.d.ts
@@ -115,6 +115,11 @@ export interface OrtWasmModule extends EmscriptenModule {
mainScriptUrlOrBlob?: string|Blob;
// #endregion

// #region external data API
mountExternalData?(externalDataFilePath: string, externalDataFileData: Uint8Array): void;
unmountExternalData?(): void;
// #endregion

// #region JSEP
/**
* This is the entry of JSEP initialization. This function is called once when initializing ONNX Runtime.
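
For orientation, here is a hypothetical helper `withExternalData` (not part of this commit) that condenses how the `mountExternalData` / `unmountExternalData` hooks declared above are driven by `createSession` in `wasm-core-impl.ts` further below: every external data file is loaded and mounted before the session is created, and everything is unmounted afterwards.

```ts
import type {InferenceSession} from 'onnxruntime-common';
import type {OrtWasmModule} from './binding/ort-wasm';
import {loadFile} from './wasm-utils-load-file';

// Hypothetical helper; condenses the createSession() changes shown later in this commit.
async function withExternalData(
    wasm: OrtWasmModule, options: InferenceSession.SessionOptions|undefined,
    createAndSetup: () => void): Promise<void> {
  if (options?.externalData && wasm.mountExternalData) {
    // load all external data files in parallel and mount each one by its path
    await Promise.all(options.externalData.map(async file => {
      const path = typeof file === 'string' ? file : file.path;
      const data = await loadFile(typeof file === 'string' ? file : file.data);
      wasm.mountExternalData!(path, data);
    }));
  }
  try {
    createAndSetup();  // e.g. wasm._OrtCreateSession(...) and input/output name setup
  } finally {
    // unmount external data if necessary
    wasm.unmountExternalData?.();
  }
}
```
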
10 changes: 8 additions & 2 deletions js/web/lib/wasm/proxy-worker/main.ts
@@ -79,8 +79,14 @@ self.onmessage = (ev: MessageEvent<OrtWasmMessage>): void => {
}
case 'create': {
const {model, options} = message!;
const sessionMetadata = createSession(model, options);
postMessage({type, out: sessionMetadata} as OrtWasmMessage);
createSession(model, options)
.then(
sessionMetadata => {
postMessage({type, out: sessionMetadata} as OrtWasmMessage);
},
err => {
postMessage({type, err});
});
break;
}
case 'release':
14 changes: 4 additions & 10 deletions js/web/lib/wasm/session-handler-inference.ts
@@ -1,12 +1,12 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {readFile} from 'node:fs/promises';
import {InferenceSession, InferenceSessionHandler, SessionHandler, Tensor, TRACE_FUNC_BEGIN, TRACE_FUNC_END} from 'onnxruntime-common';

import {SerializableInternalBuffer, TensorMetadata} from './proxy-messages';
import {copyFromExternalBuffer, createSession, endProfiling, releaseSession, run} from './proxy-wrapper';
import {isGpuBufferSupportedType} from './wasm-common';
import {loadFile} from './wasm-utils-load-file';

export const encodeTensorMetadata = (tensor: Tensor, getName: () => string): TensorMetadata => {
switch (tensor.location) {
@@ -43,14 +43,8 @@ export class OnnxruntimeWebAssemblySessionHandler implements InferenceSessionHan
outputNames: string[];

async fetchModelAndCopyToWasmMemory(path: string): Promise<SerializableInternalBuffer> {
// fetch model from url and move to wasm heap. The arraybufffer that held the http
// response is freed once we return
const response = await fetch(path);
if (response.status !== 200) {
throw new Error(`failed to load model: ${path}`);
}
const arrayBuffer = await response.arrayBuffer();
return copyFromExternalBuffer(new Uint8Array(arrayBuffer));
// fetch model from url and move to wasm heap.
return copyFromExternalBuffer(await loadFile(path));
}

async loadModel(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): Promise<void> {
@@ -60,7 +54,7 @@ export class OnnxruntimeWebAssemblySessionHandler implements InferenceSessionHan
if (typeof pathOrBuffer === 'string') {
if (typeof process !== 'undefined' && process.versions && process.versions.node) {
// node
model = await readFile(pathOrBuffer);
model = await loadFile(pathOrBuffer);
} else {
// browser
// fetch model and copy to wasm heap.
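
The `loadFile` helper used above comes from the new `js/web/lib/wasm/wasm-utils-load-file.ts`, which is among the changed files but not expanded in this view. A hypothetical sketch of its contract (accept any `FileType` value and resolve to a `Uint8Array`), assuming it branches on URLs, Node.js paths, Blobs, and raw buffers:

```ts
// Hypothetical sketch only; the actual implementation lives in
// js/web/lib/wasm/wasm-utils-load-file.ts and may differ in detail.
export const loadFile = async(file: string|Blob|ArrayBufferLike|Uint8Array): Promise<Uint8Array> => {
  if (typeof file === 'string') {
    if (typeof process !== 'undefined' && process.versions && process.versions.node) {
      // Node.js: read the file from the file system
      const {readFile} = await import('node:fs/promises');
      return new Uint8Array(await readFile(file));
    }
    // browser: fetch the file from its URL
    const response = await fetch(file);
    if (!response.ok) {
      throw new Error(`failed to load external data file: ${file}`);
    }
    return new Uint8Array(await response.arrayBuffer());
  }
  if (file instanceof Blob) {
    return new Uint8Array(await file.arrayBuffer());
  }
  return file instanceof Uint8Array ? file : new Uint8Array(file);
};
```
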
199 changes: 108 additions & 91 deletions js/web/lib/wasm/wasm-core-impl.ts
@@ -9,6 +9,7 @@ import {setSessionOptions} from './session-options';
import {dataLocationStringToEnum, getTensorElementSize, isGpuBufferSupportedType, logLevelStringToEnum, tensorDataTypeEnumToString, tensorDataTypeStringToEnum, tensorTypeToTypedArrayConstructor} from './wasm-common';
import {getInstance} from './wasm-factory';
import {allocWasmString, checkLastError} from './wasm-utils';
import {loadFile} from './wasm-utils-load-file';

// #region Initializations

@@ -187,108 +188,124 @@ export const copyFromExternalBuffer = (model: Uint8Array): [number, number] => {
* @param options an optional session options object.
* @returns a 3-elements tuple containing [session handle, input names, output names]
*/
export const createSession =
(modelData: Uint8Array|SerializableInternalBuffer,
options?: InferenceSession.SessionOptions): SerializableSessionMetadata => {
let modelDataOffset: number, modelDataLength: number;
const wasm = getInstance();

if (Array.isArray(modelData)) {
// if model data is an array, it must be a 2-elements tuple containing the pointer and size of the model data
[modelDataOffset, modelDataLength] = modelData;
} else if (modelData.buffer === wasm.HEAPU8.buffer) {
// if model data uses the same buffer as the WASM heap, we don't need to copy it.
[modelDataOffset, modelDataLength] = [modelData.byteOffset, modelData.byteLength];
} else {
// otherwise, copy the model data to the WASM heap.
[modelDataOffset, modelDataLength] = copyFromExternalBuffer(modelData);
}
export const createSession = async(
modelData: Uint8Array|SerializableInternalBuffer,
options?: InferenceSession.SessionOptions): Promise<SerializableSessionMetadata> => {
let modelDataOffset: number, modelDataLength: number;
const wasm = getInstance();

let sessionHandle = 0;
let sessionOptionsHandle = 0;
let ioBindingHandle = 0;
let allocs: number[] = [];
const inputNamesUTF8Encoded = [];
const outputNamesUTF8Encoded = [];
if (Array.isArray(modelData)) {
// if model data is an array, it must be a 2-elements tuple containing the pointer and size of the model data
[modelDataOffset, modelDataLength] = modelData;
} else if (modelData.buffer === wasm.HEAPU8.buffer) {
// if model data uses the same buffer as the WASM heap, we don't need to copy it.
[modelDataOffset, modelDataLength] = [modelData.byteOffset, modelData.byteLength];
} else {
// otherwise, copy the model data to the WASM heap.
[modelDataOffset, modelDataLength] = copyFromExternalBuffer(modelData);
}

try {
[sessionOptionsHandle, allocs] = setSessionOptions(options);
let sessionHandle = 0;
let sessionOptionsHandle = 0;
let ioBindingHandle = 0;
let allocs: number[] = [];
const inputNamesUTF8Encoded = [];
const outputNamesUTF8Encoded = [];

sessionHandle = wasm._OrtCreateSession(modelDataOffset, modelDataLength, sessionOptionsHandle);
if (sessionHandle === 0) {
checkLastError('Can\'t create a session.');
}
try {
[sessionOptionsHandle, allocs] = setSessionOptions(options);

if (options?.externalData && wasm.mountExternalData) {
const loadingPromises = [];
for (const file of options.externalData) {
const path = typeof file === 'string' ? file : file.path;
loadingPromises.push(loadFile(typeof file === 'string' ? file : file.data).then(data => {
wasm.mountExternalData!(path, data);
}));
}

const [inputCount, outputCount] = getSessionInputOutputCount(sessionHandle);
// wait for all external data files to be loaded
await Promise.all(loadingPromises);
}

const inputNames = [];
const outputNames = [];
const outputPreferredLocations: SupportedTensorDataLocationForInputOutput[] = [];
for (let i = 0; i < inputCount; i++) {
const name = wasm._OrtGetInputName(sessionHandle, i);
if (name === 0) {
checkLastError('Can\'t get an input name.');
}
inputNamesUTF8Encoded.push(name);
inputNames.push(wasm.UTF8ToString(name));
}
for (let i = 0; i < outputCount; i++) {
const name = wasm._OrtGetOutputName(sessionHandle, i);
if (name === 0) {
checkLastError('Can\'t get an output name.');
}
outputNamesUTF8Encoded.push(name);
const nameString = wasm.UTF8ToString(name);
outputNames.push(nameString);

if (!BUILD_DEFS.DISABLE_WEBGPU) {
const location = typeof options?.preferredOutputLocation === 'string' ?
options.preferredOutputLocation :
options?.preferredOutputLocation?.[nameString] ?? 'cpu';
if (location !== 'cpu' && location !== 'cpu-pinned' && location !== 'gpu-buffer') {
throw new Error(`Not supported preferred output location: ${location}.`);
}
outputPreferredLocations.push(location);
}
}
sessionHandle = wasm._OrtCreateSession(modelDataOffset, modelDataLength, sessionOptionsHandle);
if (sessionHandle === 0) {
checkLastError('Can\'t create a session.');
}

// use IO binding only when at least one output is preffered to be on GPU.
let bindingState: IOBindingState|null = null;
if (!BUILD_DEFS.DISABLE_WEBGPU && outputPreferredLocations.some(l => l === 'gpu-buffer')) {
ioBindingHandle = wasm._OrtCreateBinding(sessionHandle);
if (ioBindingHandle === 0) {
checkLastError('Can\'t create IO binding.');
}
const [inputCount, outputCount] = getSessionInputOutputCount(sessionHandle);

bindingState = {
handle: ioBindingHandle,
outputPreferredLocations,
outputPreferredLocationsEncoded: outputPreferredLocations.map(l => dataLocationStringToEnum(l)),
};
const inputNames = [];
const outputNames = [];
const outputPreferredLocations: SupportedTensorDataLocationForInputOutput[] = [];
for (let i = 0; i < inputCount; i++) {
const name = wasm._OrtGetInputName(sessionHandle, i);
if (name === 0) {
checkLastError('Can\'t get an input name.');
}
inputNamesUTF8Encoded.push(name);
inputNames.push(wasm.UTF8ToString(name));
}
for (let i = 0; i < outputCount; i++) {
const name = wasm._OrtGetOutputName(sessionHandle, i);
if (name === 0) {
checkLastError('Can\'t get an output name.');
}
outputNamesUTF8Encoded.push(name);
const nameString = wasm.UTF8ToString(name);
outputNames.push(nameString);

if (!BUILD_DEFS.DISABLE_WEBGPU) {
const location = typeof options?.preferredOutputLocation === 'string' ?
options.preferredOutputLocation :
options?.preferredOutputLocation?.[nameString] ?? 'cpu';
if (location !== 'cpu' && location !== 'cpu-pinned' && location !== 'gpu-buffer') {
throw new Error(`Not supported preferred output location: ${location}.`);
}
outputPreferredLocations.push(location);
}
}

activeSessions.set(sessionHandle, [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, bindingState]);
return [sessionHandle, inputNames, outputNames];
} catch (e) {
inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf));
outputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf));
// use IO binding only when at least one output is preffered to be on GPU.
let bindingState: IOBindingState|null = null;
if (!BUILD_DEFS.DISABLE_WEBGPU && outputPreferredLocations.some(l => l === 'gpu-buffer')) {
ioBindingHandle = wasm._OrtCreateBinding(sessionHandle);
if (ioBindingHandle === 0) {
checkLastError('Can\'t create IO binding.');
}

if (ioBindingHandle !== 0) {
wasm._OrtReleaseBinding(ioBindingHandle);
}
bindingState = {
handle: ioBindingHandle,
outputPreferredLocations,
outputPreferredLocationsEncoded: outputPreferredLocations.map(l => dataLocationStringToEnum(l)),
};
}

if (sessionHandle !== 0) {
wasm._OrtReleaseSession(sessionHandle);
}
throw e;
} finally {
wasm._free(modelDataOffset);
if (sessionOptionsHandle !== 0) {
wasm._OrtReleaseSessionOptions(sessionOptionsHandle);
}
allocs.forEach(alloc => wasm._free(alloc));
}
};
activeSessions.set(sessionHandle, [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, bindingState]);
return [sessionHandle, inputNames, outputNames];
} catch (e) {
inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf));
outputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf));

if (ioBindingHandle !== 0) {
wasm._OrtReleaseBinding(ioBindingHandle);
}

if (sessionHandle !== 0) {
wasm._OrtReleaseSession(sessionHandle);
}
throw e;
} finally {
wasm._free(modelDataOffset);
if (sessionOptionsHandle !== 0) {
wasm._OrtReleaseSessionOptions(sessionOptionsHandle);
}
allocs.forEach(alloc => wasm._free(alloc));

// unmount external data if necessary
wasm.unmountExternalData?.();
}
};

export const releaseSession = (sessionId: number): void => {
const wasm = getInstance();