Skip to content
This repository has been archived by the owner on Nov 16, 2023. It is now read-only.

wasm: update backend to consume latest ONNX Runtime #270

Open
wants to merge 4 commits into master
Changes from 3 commits
5 changes: 4 additions & 1 deletion docs/development.md
@@ -21,7 +21,10 @@
Please follow these steps to run tests:

1. run `npm ci` in the root folder of the repo.
2. (Optional) run `npm run build` in the root folder of the repo to enable WebAssembly features.
2. (Optional) build the WebAssembly backend:
1. build ONNX Runtime WebAssembly and copy the files "onnxruntime_wasm.\*" to /dist/.
2. if building ONNX Runtime WebAssembly with multi-threads support, also copy the files "onnxruntime_wasm_threads.\*" to /dist/.
3. run `npm run build` in the root folder of the repo to enable WebAssembly features.
3. run `npm test` to run suite0 test cases and check the console output.
- if step (2) is not run, please run `npm test -- -b=cpu,webgl` to skip the WebAssembly tests
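For reference, a minimal sketch of exercising the WebAssembly backend once the artifacts are in /dist/, assuming the public onnx.js API (the model path and tensor values are illustrative only, not part of this PR):

import {InferenceSession, Tensor} from 'onnxjs';

// inside an async function: hint the resolver to pick the WebAssembly backend
const session = new InferenceSession({backendHint: 'wasm'});
await session.loadModel('./model.onnx'); // hypothetical model file
const input = new Tensor(new Float32Array([1, 2, 3, 4]), 'float32', [2, 2]);
const outputs = await session.run([input]); // resolves to a Map<string, Tensor>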

12 changes: 10 additions & 2 deletions karma.conf.js
@@ -48,10 +48,18 @@ module.exports = function (config) {
{ pattern: 'test/data/**/*', included: false, nocache: true },
{ pattern: 'deps/data/data/test/**/*', included: false, nocache: true },
{ pattern: 'deps/onnx/onnx/backend/test/data/**/*', included: false, nocache: true },
{ pattern: 'dist/onnx-wasm.wasm', included: false },
{ pattern: 'dist/onnxruntime_wasm.js', included: true },
{ pattern: 'dist/onnxruntime_wasm.wasm', included: false },
{ pattern: 'dist/onnxruntime_wasm_threads.js', included: true },
{ pattern: 'dist/onnxruntime_wasm_threads.wasm', included: false },
{ pattern: 'dist/onnxruntime_wasm_threads.worker.js', included: false },
],
proxies: {
'/onnx-wasm.wasm': '/base/dist/onnx-wasm.wasm',
'/onnxruntime_wasm.js': '/base/dist/onnxruntime_wasm.js',
'/onnxruntime_wasm.wasm': '/base/dist/onnxruntime_wasm.wasm',
'/onnxruntime_wasm_threads.js': '/base/dist/onnxruntime_wasm_threads.js',
'/onnxruntime_wasm_threads.wasm': '/base/dist/onnxruntime_wasm_threads.wasm',
'/onnxruntime_wasm_threads.worker.js': '/base/dist/onnxruntime_wasm_threads.worker.js',
'/onnx-worker.js': '/base/test/onnx-worker.js',
},
plugins: karmaPlugins,
2 changes: 1 addition & 1 deletion lib/api/inference-session-impl.ts
@@ -50,7 +50,7 @@ export class InferenceSession implements InferenceSessionInterface {
output = await this.session.run(modelInputFeed);
} else if (Array.isArray(inputFeed)) {
const modelInputFeed: InternalTensor[] = [];
inputFeed.forEach((value) => {
inputFeed.forEach((value: ApiTensor) => {
modelInputFeed.push(value.internalTensor);
});
output = await this.session.run(modelInputFeed);
184 changes: 173 additions & 11 deletions lib/backends/wasm/session-handler.ts
@@ -4,28 +4,190 @@
import {Backend, InferenceHandler, SessionHandler} from '../../backend';
import {Graph} from '../../graph';
import {Operator} from '../../operators';
import {OpSet, resolveOperator} from '../../opset';
import {OpSet} from '../../opset';
import {Session} from '../../session';
import {CPU_OP_RESOLVE_RULES} from '../cpu/op-resolve-rules';
import {Tensor} from '../../tensor';
import {ProtoUtil} from '../../util';
import {getInstance} from '../../wasm-binding-core';

import {WasmInferenceHandler} from './inference-handler';
import {WASM_OP_RESOLVE_RULES} from './op-resolve-rules';

export class WasmSessionHandler implements SessionHandler {
private opResolveRules: ReadonlyArray<OpSet.ResolveRule>;
constructor(readonly backend: Backend, readonly context: Session.Context, fallbackToCpuOps: boolean) {
this.opResolveRules = fallbackToCpuOps ? WASM_OP_RESOLVE_RULES.concat(CPU_OP_RESOLVE_RULES) : WASM_OP_RESOLVE_RULES;
constructor(readonly backend: Backend, readonly context: Session.Context, fallbackToCpuOps: boolean) {}
resolve(node: Graph.Node, opsets: readonly OpSet[], graph: Graph): Operator {
throw new Error('Method not implemented.');
}

createInferenceHandler(): InferenceHandler {
return new WasmInferenceHandler(this, this.context.profiler);
}

dispose(): void {}
// vNEXT latest:
ortInit: boolean;
sessionHandle: number;

resolve(node: Graph.Node, opsets: ReadonlyArray<OpSet>, graph: Graph): Operator {
const op = resolveOperator(node, opsets, this.opResolveRules);
op.initialize(node.attributes, node, graph);
return op;
inputNames: string[];
inputNamesUTF8Encoded: number[];
outputNames: string[];
outputNamesUTF8Encoded: number[];

loadModel(model: Uint8Array) {
const wasm = getInstance();
if (!this.ortInit) {
wasm._OrtInit();
this.ortInit = true;
}

const modelDataOffset = wasm._malloc(model.byteLength);
try {
wasm.HEAPU8.set(model, modelDataOffset);
this.sessionHandle = wasm._OrtCreateSession(modelDataOffset, model.byteLength);
} finally {
wasm._free(modelDataOffset);
}

const inputCount = wasm._OrtGetInputCount(this.sessionHandle);
const outputCount = wasm._OrtGetOutputCount(this.sessionHandle);

this.inputNames = [];
this.inputNamesUTF8Encoded = [];
this.outputNames = [];
this.outputNamesUTF8Encoded = [];
for (let i = 0; i < inputCount; i++) {
const name = wasm._OrtGetInputName(this.sessionHandle, i);
this.inputNamesUTF8Encoded.push(name);
this.inputNames.push(wasm.UTF8ToString(name));
}
for (let i = 0; i < outputCount; i++) {
const name = wasm._OrtGetOutputName(this.sessionHandle, i);
this.outputNamesUTF8Encoded.push(name);
this.outputNames.push(wasm.UTF8ToString(name));
}
}
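// Note: the encoded name pointers are deliberately kept for the session's lifetime;
// run() passes them straight to _OrtRun, and dispose() releases them via _OrtFree.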

run(inputs: Map<string, Tensor>|Tensor[]): Map<string, Tensor> {
const wasm = getInstance();

let inputIndices: number[] = [];
if (!Array.isArray(inputs)) {
const inputArray: Tensor[] = [];
inputs.forEach((tensor, name) => {
const index = this.inputNames.indexOf(name);
if (index === -1) {
throw new Error(`invalid input '${name}'`);
}
inputArray.push(tensor);
inputIndices.push(index);
});
inputs = inputArray;
} else {
inputIndices = inputs.map((t, i) => i);
}

const inputCount = inputs.length;
const outputCount = this.outputNames.length;

const inputValues: number[] = [];
const inputDataOffsets: number[] = [];
// create input tensors
for (let i = 0; i < inputCount; i++) {
const data = inputs[i].numberData;
const dataOffset = wasm._malloc(data.byteLength);
inputDataOffsets.push(dataOffset);
wasm.HEAPU8.set(new Uint8Array(data.buffer, data.byteOffset, data.byteLength), dataOffset);

const dims = inputs[i].dims;

const stack = wasm.stackSave();
const dimsOffset = wasm.stackAlloc(4 * dims.length);
try {
let dimIndex = dimsOffset / 4;
dims.forEach(d => wasm.HEAP32[dimIndex++] = d);
const tensor = wasm._OrtCreateTensor(
ProtoUtil.tensorDataTypeStringToEnum(inputs[i].type), dataOffset, data.byteLength, dimsOffset, dims.length);
inputValues.push(tensor);
} finally {
wasm.stackRestore(stack);
}
}

const beforeRunStack = wasm.stackSave();
const inputValuesOffset = wasm.stackAlloc(inputCount * 4);
const inputNamesOffset = wasm.stackAlloc(inputCount * 4);
const outputValuesOffset = wasm.stackAlloc(outputCount * 4);
const outputNamesOffset = wasm.stackAlloc(outputCount * 4);
try {
let inputValuesIndex = inputValuesOffset / 4;
let inputNamesIndex = inputNamesOffset / 4;
let outputValuesIndex = outputValuesOffset / 4;
let outputNamesIndex = outputNamesOffset / 4;
for (let i = 0; i < inputCount; i++) {
wasm.HEAP32[inputValuesIndex++] = inputValues[i];
// use inputIndices so inputs supplied via a Map in arbitrary order still pair each name with its value
wasm.HEAP32[inputNamesIndex++] = this.inputNamesUTF8Encoded[inputIndices[i]];
}
for (let i = 0; i < outputCount; i++) {
wasm.HEAP32[outputValuesIndex++] = 0;
wasm.HEAP32[outputNamesIndex++] = this.outputNamesUTF8Encoded[i];
}

wasm._OrtRun(
this.sessionHandle, inputNamesOffset, inputValuesOffset, inputCount, outputNamesOffset, outputCount,
outputValuesOffset);

const output = new Map<string, Tensor>();

for (let i = 0; i < outputCount; i++) {
const tensor = wasm.HEAPU32[outputValuesOffset / 4 + i];

const beforeGetTensorDataStack = wasm.stackSave();
// stack-allocate space for 4 pointer values
const tensorDataOffset = wasm.stackAlloc(4 * 4);
try {
wasm._OrtGetTensorData(
tensor, tensorDataOffset, tensorDataOffset + 4, tensorDataOffset + 8, tensorDataOffset + 12);
let tensorDataIndex = tensorDataOffset / 4;
const dataType = wasm.HEAPU32[tensorDataIndex++];
const dataOffset = wasm.HEAPU32[tensorDataIndex++];
const dimsOffset = wasm.HEAPU32[tensorDataIndex++];
const dimsLength = wasm.HEAPU32[tensorDataIndex++];
const dims = [];
// use a distinct loop variable to avoid shadowing the outer output index `i`
for (let d = 0; d < dimsLength; d++) {
dims.push(wasm.HEAPU32[dimsOffset / 4 + d]);
}
wasm._OrtFree(dimsOffset);

const t = new Tensor(dims, ProtoUtil.tensorDataTypeFromProto(dataType));
new Uint8Array(t.numberData.buffer, t.numberData.byteOffset, t.numberData.byteLength)
.set(wasm.HEAPU8.subarray(dataOffset, dataOffset + t.numberData.byteLength));
output.set(this.outputNames[i], t);
} finally {
wasm.stackRestore(beforeGetTensorDataStack);
}

wasm._OrtReleaseTensor(tensor);
}

inputValues.forEach(t => wasm._OrtReleaseTensor(t));
inputDataOffsets.forEach(i => wasm._free(i));

return output;
} finally {
wasm.stackRestore(beforeRunStack);
}
}
dispose() {
const wasm = getInstance();
if (this.inputNamesUTF8Encoded) {
this.inputNamesUTF8Encoded.forEach(str => wasm._OrtFree(str));
this.inputNamesUTF8Encoded = [];
}
if (this.outputNamesUTF8Encoded) {
this.outputNamesUTF8Encoded.forEach(str => wasm._OrtFree(str));
this.outputNamesUTF8Encoded = [];
}
if (this.sessionHandle) {
wasm._OrtReleaseSession(this.sessionHandle);
this.sessionHandle = 0;
}
}
}
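The handler drives ONNX Runtime through a set of `_Ort*` exports on the Emscripten module returned by `getInstance()`. A sketch of the module surface this code assumes, with signatures inferred from the call sites above (the actual typings live in wasm-binding-core and may differ):

interface OrtWasmModule {
  HEAPU8: Uint8Array;
  HEAP32: Int32Array;
  HEAPU32: Uint32Array;
  _malloc(byteLength: number): number;
  _free(offset: number): void;
  stackSave(): number;
  stackAlloc(byteLength: number): number;
  stackRestore(stack: number): void;
  UTF8ToString(offset: number): string;
  _OrtInit(): void;
  _OrtCreateSession(modelOffset: number, modelLength: number): number;
  _OrtReleaseSession(sessionHandle: number): void;
  _OrtGetInputCount(sessionHandle: number): number;
  _OrtGetOutputCount(sessionHandle: number): number;
  _OrtGetInputName(sessionHandle: number, index: number): number;  // returns a char* the caller frees
  _OrtGetOutputName(sessionHandle: number, index: number): number; // returns a char* the caller frees
  _OrtCreateTensor(dataType: number, dataOffset: number, dataByteLength: number,
                   dimsOffset: number, dimsLength: number): number;
  _OrtGetTensorData(tensor: number, dataTypeOut: number, dataOut: number,
                    dimsOut: number, dimsLengthOut: number): void;
  _OrtReleaseTensor(tensor: number): void;
  _OrtFree(offset: number): void;
  _OrtRun(sessionHandle: number, inputNamesOffset: number, inputValuesOffset: number, inputCount: number,
          outputNamesOffset: number, outputCount: number, outputValuesOffset: number): void;
}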
10 changes: 10 additions & 0 deletions lib/session.ts
@@ -5,6 +5,7 @@ import {readFile} from 'fs';
import {promisify} from 'util';

import {Backend, SessionHandlerType} from './backend';
import {WasmSessionHandler} from './backends/wasm/session-handler';
import {ExecutionPlan} from './execution-plan';
import {Graph} from './graph';
import {Profiler} from './instrument';
@@ -79,6 +80,11 @@ export class Session {
}

this.profiler.event('session', 'Session.initialize', () => {
if ((this.sessionHandler as {run?: unknown}).run) {
(this.sessionHandler as WasmSessionHandler).loadModel(modelProtoBlob);
return;
}

// load graph
const graphInitializer =
this.sessionHandler.transformGraph ? this.sessionHandler as Graph.Initializer : undefined;
@@ -104,6 +110,10 @@
}

return this.profiler.event('session', 'Session.run', async () => {
if ((this.sessionHandler as {run?: unknown}).run) {
return (this.sessionHandler as WasmSessionHandler).run(inputs);
}

const inputTensors = this.normalizeAndValidateInputs(inputs);

const outputTensors = await this._executionPlan.execute(this.sessionHandler, inputTensors);
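Both branches above detect the wasm handler by probing for a `run` method. A possible refinement, not part of this PR, is a user-defined type guard that keeps the cast in one place and narrows the type at both call sites:

function isWasmSessionHandler(handler: SessionHandler): handler is WasmSessionHandler {
  // only the wasm session handler implements run() on the handler itself
  return typeof (handler as {run?: unknown}).run === 'function';
}

// usage: if (isWasmSessionHandler(this.sessionHandler)) { return this.sessionHandler.run(inputs); }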
32 changes: 32 additions & 0 deletions lib/util.ts
@@ -406,6 +406,38 @@ export class ProtoUtil {
}
}

static tensorDataTypeStringToEnum(type: string): onnx.TensorProto.DataType {
switch (type) {
case 'int8':
return onnx.TensorProto.DataType.INT8;
case 'uint8':
return onnx.TensorProto.DataType.UINT8;
case 'bool':
return onnx.TensorProto.DataType.BOOL;
case 'int16':
return onnx.TensorProto.DataType.INT16;
case 'uint16':
return onnx.TensorProto.DataType.UINT16;
case 'int32':
return onnx.TensorProto.DataType.INT32;
case 'uint32':
return onnx.TensorProto.DataType.UINT32;
case 'float32':
return onnx.TensorProto.DataType.FLOAT;
case 'float64':
return onnx.TensorProto.DataType.DOUBLE;
case 'string':
return onnx.TensorProto.DataType.STRING;
case 'int64':
return onnx.TensorProto.DataType.INT64;
case 'uint64':
return onnx.TensorProto.DataType.UINT64;

default:
throw new Error(`unsupported data type: ${type}`);
}
}

static tensorDimsFromProto(dims: Array<number|Long>): number[] {
// get rid of Long type for dims
return dims.map(d => Long.isLong(d) ? d.toNumber() : d);
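A quick sketch of the round trip these helpers enable in the wasm session handler (tensorDataTypeFromProto is the pre-existing inverse, used above when reading outputs):

const dataType = ProtoUtil.tensorDataTypeStringToEnum('float32'); // onnx.TensorProto.DataType.FLOAT
const typeName = ProtoUtil.tensorDataTypeFromProto(dataType);     // 'float32'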