support WebGPU EP in Node.js binding (microsoft#22660)
### Description

This change enhances the Node.js binding with the following features:
- support WebGPU EP
- lazy initialization of `OrtEnv`
- initialize ORT with the default log level read from `ort.env.logLevel`.
- session options:
  - `enableProfiling` and `profileFilePrefix`: support profiling.
  - `externalData`: explicit external data (optional in Node.js binding)
  - `optimizedModelFilePath`: allow dumping the optimized model for diagnostic purposes
  - `preferredOutputLocation`: support IO binding.

======================================================
`Tensor.download()` is not implemented in this PR.
Build pipeline update is not included in this PR.
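
For illustration, a minimal usage sketch of the new options (assuming the package is consumed as `onnxruntime-node`; the model path, input name, and shape below are hypothetical and not part of this change):

```ts
import * as ort from 'onnxruntime-node';

// Picked up when OrtEnv is lazily created on the first session construction.
ort.env.logLevel = 'verbose';

async function main() {
  const session = await ort.InferenceSession.create('model.onnx', {
    executionProviders: ['webgpu'], // requires a build with --use_webgpu
    enableProfiling: true, // profiling is enabled at load time; startProfiling() stays a no-op
    profileFilePrefix: 'ort_profile',
    optimizedModelFilePath: 'model_optimized.onnx', // dump the optimized model for diagnosis
    preferredOutputLocation: 'gpu-buffer', // IO binding; reading such outputs back needs Tensor.download(), not in this PR
  });

  const feeds = { input: new ort.Tensor('float32', new Float32Array(4), [1, 4]) }; // hypothetical input name
  const results = await session.run(feeds);
  console.log(session.outputNames, Object.keys(results));

  session.endProfiling(); // writes the profile file via the native endProfiling binding
  await session.release();
}

main().catch((err) => {
  console.error(err);
  process.exit(1);
});
```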
fs-eire authored and ankitm3k committed Dec 11, 2024
1 parent 703e951 commit dfada06
Showing 10 changed files with 479 additions and 63 deletions.
6 changes: 5 additions & 1 deletion js/node/CMakeLists.txt
@@ -2,7 +2,7 @@ cmake_minimum_required(VERSION 3.11)

project (onnxruntime-node)

set(CMAKE_CXX_STANDARD 14)
set(CMAKE_CXX_STANDARD 17)

add_compile_definitions(NAPI_VERSION=${napi_build_version})
add_compile_definitions(ORT_API_MANUAL_INIT)
@@ -34,6 +34,7 @@ include_directories(${CMAKE_SOURCE_DIR}/node_modules/node-addon-api)

# optional providers
option(USE_DML "Build with DirectML support" OFF)
option(USE_WEBGPU "Build with WebGPU support" OFF)
option(USE_CUDA "Build with CUDA support" OFF)
option(USE_TENSORRT "Build with TensorRT support" OFF)
option(USE_COREML "Build with CoreML support" OFF)
@@ -42,6 +43,9 @@ option(USE_QNN "Build with QNN support" OFF)
if(USE_DML)
add_compile_definitions(USE_DML=1)
endif()
if(USE_WEBGPU)
add_compile_definitions(USE_WEBGPU=1)
endif()
if(USE_CUDA)
add_compile_definitions(USE_CUDA=1)
endif()
10 changes: 7 additions & 3 deletions js/node/lib/backend.ts
@@ -3,12 +3,14 @@

import { Backend, InferenceSession, InferenceSessionHandler, SessionHandler } from 'onnxruntime-common';

import { Binding, binding } from './binding';
import { Binding, binding, initOrt } from './binding';

class OnnxruntimeSessionHandler implements InferenceSessionHandler {
#inferenceSession: Binding.InferenceSession;

constructor(pathOrBuffer: string | Uint8Array, options: InferenceSession.SessionOptions) {
initOrt();

this.#inferenceSession = new binding.InferenceSession();
if (typeof pathOrBuffer === 'string') {
this.#inferenceSession.loadModel(pathOrBuffer, options);
@@ -27,10 +29,12 @@ class OnnxruntimeSessionHandler implements InferenceSessionHandler {
readonly outputNames: string[];

startProfiling(): void {
// TODO: implement profiling
// startProfiling is a no-op.
//
// if sessionOptions.enableProfiling is true, profiling will be enabled when the model is loaded.
}
endProfiling(): void {
// TODO: implement profiling
this.#inferenceSession.endProfiling();
}

async run(
35 changes: 34 additions & 1 deletion js/node/lib/binding.ts
@@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import { InferenceSession, OnnxValue } from 'onnxruntime-common';
import { InferenceSession, OnnxValue, Tensor, TensorConstructor, env } from 'onnxruntime-common';

type SessionOptions = InferenceSession.SessionOptions;
type FeedsType = {
@@ -28,6 +28,8 @@ export declare namespace Binding {

run(feeds: FeedsType, fetches: FetchesType, options: RunOptions): ReturnType;

endProfiling(): void;

dispose(): void;
}

@@ -48,4 +50,35 @@ export const binding =
// eslint-disable-next-line @typescript-eslint/naming-convention
InferenceSession: Binding.InferenceSessionConstructor;
listSupportedBackends: () => Binding.SupportedBackend[];
initOrtOnce: (logLevel: number, tensorConstructor: TensorConstructor) => void;
};

let ortInitialized = false;
export const initOrt = (): void => {
if (!ortInitialized) {
ortInitialized = true;
let logLevel = 2;
if (env.logLevel) {
switch (env.logLevel) {
case 'verbose':
logLevel = 0;
break;
case 'info':
logLevel = 1;
break;
case 'warning':
logLevel = 2;
break;
case 'error':
logLevel = 3;
break;
case 'fatal':
logLevel = 4;
break;
default:
throw new Error(`Unsupported log level: ${env.logLevel}`);
}
}
binding.initOrtOnce(logLevel, Tensor);
}
};
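
A hedged sketch of the lazy-initialization behavior implied by `initOrt()` above: `OrtEnv` is created once, on the first session construction, using the value of `ort.env.logLevel` at that moment; later changes to `ort.env.logLevel` do not re-create it (model paths are hypothetical):

```ts
import * as ort from 'onnxruntime-node';

async function demoLogLevel() {
  // First session creation triggers initOrt(): OrtEnv is created at severity 'info' (1).
  ort.env.logLevel = 'info';
  const s1 = await ort.InferenceSession.create('a.onnx'); // hypothetical model path

  // OrtEnv already exists, so this change no longer affects the native log level.
  ort.env.logLevel = 'verbose';
  const s2 = await ort.InferenceSession.create('b.onnx'); // hypothetical model path

  await Promise.all([s1.release(), s2.release()]);
}
```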
5 changes: 5 additions & 0 deletions js/node/script/build.ts
@@ -29,6 +29,8 @@ const ONNXRUNTIME_GENERATOR = buildArgs['onnxruntime-generator'];
const REBUILD = !!buildArgs.rebuild;
// --use_dml
const USE_DML = !!buildArgs.use_dml;
// --use_webgpu
const USE_WEBGPU = !!buildArgs.use_webgpu;
// --use_cuda
const USE_CUDA = !!buildArgs.use_cuda;
// --use_tensorrt
@@ -65,6 +67,9 @@ if (ONNXRUNTIME_GENERATOR && typeof ONNXRUNTIME_GENERATOR === 'string') {
if (USE_DML) {
args.push('--CDUSE_DML=ON');
}
if (USE_WEBGPU) {
args.push('--CDUSE_WEBGPU=ON');
}
if (USE_CUDA) {
args.push('--CDUSE_CUDA=ON');
}
118 changes: 100 additions & 18 deletions js/node/src/inference_session_wrap.cc
@@ -11,7 +11,12 @@
#include "tensor_helper.h"
#include <string>

Napi::FunctionReference InferenceSessionWrap::constructor;
Napi::FunctionReference InferenceSessionWrap::wrappedSessionConstructor;
Napi::FunctionReference InferenceSessionWrap::ortTensorConstructor;

Napi::FunctionReference& InferenceSessionWrap::GetTensorConstructor() {
return InferenceSessionWrap::ortTensorConstructor;
}

Napi::Object InferenceSessionWrap::Init(Napi::Env env, Napi::Object exports) {
#if defined(USE_DML) && defined(_WIN32)
@@ -23,28 +28,51 @@ Napi::Object InferenceSessionWrap::Init(Napi::Env env, Napi::Object exports) {
Ort::Global<void>::api_ == nullptr, env,
"Failed to initialize ONNX Runtime API. It could happen when this nodejs binding was built with a higher version "
"ONNX Runtime but now runs with a lower version ONNX Runtime DLL(or shared library).");
auto ortEnv = new Ort::Env{ORT_LOGGING_LEVEL_WARNING, "onnxruntime-node"};
env.SetInstanceData(ortEnv);

// initialize binding
Napi::HandleScope scope(env);

Napi::Function func = DefineClass(
env, "InferenceSession",
{InstanceMethod("loadModel", &InferenceSessionWrap::LoadModel), InstanceMethod("run", &InferenceSessionWrap::Run),
{InstanceMethod("loadModel", &InferenceSessionWrap::LoadModel),
InstanceMethod("run", &InferenceSessionWrap::Run),
InstanceMethod("dispose", &InferenceSessionWrap::Dispose),
InstanceMethod("endProfiling", &InferenceSessionWrap::EndProfiling),
InstanceAccessor("inputNames", &InferenceSessionWrap::GetInputNames, nullptr, napi_default, nullptr),
InstanceAccessor("outputNames", &InferenceSessionWrap::GetOutputNames, nullptr, napi_default, nullptr)});

constructor = Napi::Persistent(func);
constructor.SuppressDestruct();
wrappedSessionConstructor = Napi::Persistent(func);
wrappedSessionConstructor.SuppressDestruct();
exports.Set("InferenceSession", func);

Napi::Function listSupportedBackends = Napi::Function::New(env, InferenceSessionWrap::ListSupportedBackends);
exports.Set("listSupportedBackends", listSupportedBackends);

Napi::Function initOrtOnce = Napi::Function::New(env, InferenceSessionWrap::InitOrtOnce);
exports.Set("initOrtOnce", initOrtOnce);

return exports;
}

Napi::Value InferenceSessionWrap::InitOrtOnce(const Napi::CallbackInfo& info) {
Napi::Env env = info.Env();
Napi::HandleScope scope(env);

int log_level = info[0].As<Napi::Number>().Int32Value();

Ort::Env* ortEnv = env.GetInstanceData<Ort::Env>();
if (ortEnv == nullptr) {
ortEnv = new Ort::Env{OrtLoggingLevel(log_level), "onnxruntime-node"};
env.SetInstanceData(ortEnv);
}

Napi::Function tensorConstructor = info[1].As<Napi::Function>();
ortTensorConstructor = Napi::Persistent(tensorConstructor);
ortTensorConstructor.SuppressDestruct();

return env.Undefined();
}

InferenceSessionWrap::InferenceSessionWrap(const Napi::CallbackInfo& info)
: Napi::ObjectWrap<InferenceSessionWrap>(info), initialized_(false), disposed_(false), session_(nullptr), defaultRunOptions_(nullptr) {}

@@ -118,6 +146,12 @@ Napi::Value InferenceSessionWrap::LoadModel(const Napi::CallbackInfo& info) {
? typeInfo.GetTensorTypeAndShapeInfo().GetElementType()
: ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
}

// cache preferred output locations
ParsePreferredOutputLocations(info[argsLength - 1].As<Napi::Object>(), outputNames_, preferredOutputLocations_);
if (preferredOutputLocations_.size() > 0) {
ioBinding_ = std::make_unique<Ort::IoBinding>(*session_);
}
} catch (Napi::Error const& e) {
throw e;
} catch (std::exception const& e) {
@@ -167,15 +201,16 @@ Napi::Value InferenceSessionWrap::Run(const Napi::CallbackInfo& info) {
std::vector<bool> reuseOutput;
size_t inputIndex = 0;
size_t outputIndex = 0;
OrtMemoryInfo* memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault).release();
Ort::MemoryInfo cpuMemoryInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
Ort::MemoryInfo gpuBufferMemoryInfo{"WebGPU_Buffer", OrtDeviceAllocator, 0, OrtMemTypeDefault};

try {
for (auto& name : inputNames_) {
if (feed.Has(name)) {
inputIndex++;
inputNames_cstr.push_back(name.c_str());
auto value = feed.Get(name);
inputValues.push_back(NapiValueToOrtValue(env, value, memory_info));
inputValues.push_back(NapiValueToOrtValue(env, value, cpuMemoryInfo, gpuBufferMemoryInfo));
}
}
for (auto& name : outputNames_) {
@@ -184,7 +219,7 @@
outputNames_cstr.push_back(name.c_str());
auto value = fetch.Get(name);
reuseOutput.push_back(!value.IsNull());
outputValues.emplace_back(value.IsNull() ? Ort::Value{nullptr} : NapiValueToOrtValue(env, value, memory_info));
outputValues.emplace_back(value.IsNull() ? Ort::Value{nullptr} : NapiValueToOrtValue(env, value, cpuMemoryInfo, gpuBufferMemoryInfo));
}
}

@@ -193,19 +228,47 @@
runOptions = Ort::RunOptions{};
ParseRunOptions(info[2].As<Napi::Object>(), runOptions);
}
if (preferredOutputLocations_.size() == 0) {
session_->Run(runOptions == nullptr ? *defaultRunOptions_.get() : runOptions,
inputIndex == 0 ? nullptr : &inputNames_cstr[0], inputIndex == 0 ? nullptr : &inputValues[0],
inputIndex, outputIndex == 0 ? nullptr : &outputNames_cstr[0],
outputIndex == 0 ? nullptr : &outputValues[0], outputIndex);

session_->Run(runOptions == nullptr ? *defaultRunOptions_.get() : runOptions,
inputIndex == 0 ? nullptr : &inputNames_cstr[0], inputIndex == 0 ? nullptr : &inputValues[0],
inputIndex, outputIndex == 0 ? nullptr : &outputNames_cstr[0],
outputIndex == 0 ? nullptr : &outputValues[0], outputIndex);
Napi::Object result = Napi::Object::New(env);

Napi::Object result = Napi::Object::New(env);
for (size_t i = 0; i < outputIndex; i++) {
result.Set(outputNames_[i], OrtValueToNapiValue(env, std::move(outputValues[i])));
}
return scope.Escape(result);
} else {
// IO binding
ORT_NAPI_THROW_ERROR_IF(preferredOutputLocations_.size() != outputNames_.size(), env,
"Preferred output locations must have the same size as output names.");

for (size_t i = 0; i < outputIndex; i++) {
result.Set(outputNames_[i], OrtValueToNapiValue(env, outputValues[i]));
}
for (size_t i = 0; i < inputIndex; i++) {
ioBinding_->BindInput(inputNames_cstr[i], inputValues[i]);
}
for (size_t i = 0; i < outputIndex; i++) {
// TODO: support preallocated output tensor (outputValues[i])

if (preferredOutputLocations_[i] == DATA_LOCATION_GPU_BUFFER) {
ioBinding_->BindOutput(outputNames_cstr[i], gpuBufferMemoryInfo);
} else {
ioBinding_->BindOutput(outputNames_cstr[i], cpuMemoryInfo);
}
}

session_->Run(runOptions == nullptr ? *defaultRunOptions_.get() : runOptions, *ioBinding_);

auto outputs = ioBinding_->GetOutputValues();
ORT_NAPI_THROW_ERROR_IF(outputs.size() != outputIndex, env, "Output count mismatch.");

return scope.Escape(result);
Napi::Object result = Napi::Object::New(env);
for (size_t i = 0; i < outputIndex; i++) {
result.Set(outputNames_[i], OrtValueToNapiValue(env, std::move(outputs[i])));
}
return scope.Escape(result);
}
} catch (Napi::Error const& e) {
throw e;
} catch (std::exception const& e) {
@@ -218,13 +281,29 @@ Napi::Value InferenceSessionWrap::Dispose(const Napi::CallbackInfo& info) {
ORT_NAPI_THROW_ERROR_IF(!this->initialized_, env, "Session is not initialized.");
ORT_NAPI_THROW_ERROR_IF(this->disposed_, env, "Session already disposed.");

this->ioBinding_.reset(nullptr);

this->defaultRunOptions_.reset(nullptr);
this->session_.reset(nullptr);

this->disposed_ = true;
return env.Undefined();
}

Napi::Value InferenceSessionWrap::EndProfiling(const Napi::CallbackInfo& info) {
Napi::Env env = info.Env();
ORT_NAPI_THROW_ERROR_IF(!this->initialized_, env, "Session is not initialized.");
ORT_NAPI_THROW_ERROR_IF(this->disposed_, env, "Session already disposed.");

Napi::EscapableHandleScope scope(env);

Ort::AllocatorWithDefaultOptions allocator;

auto filename = session_->EndProfilingAllocated(allocator);
Napi::String filenameValue = Napi::String::From(env, filename.get());
return scope.Escape(filenameValue);
}

Napi::Value InferenceSessionWrap::ListSupportedBackends(const Napi::CallbackInfo& info) {
Napi::Env env = info.Env();
Napi::EscapableHandleScope scope(env);
@@ -242,6 +321,9 @@ Napi::Value InferenceSessionWrap::ListSupportedBackends(const Napi::CallbackInfo
#ifdef USE_DML
result.Set(result.Length(), createObject("dml", true));
#endif
#ifdef USE_WEBGPU
result.Set(result.Length(), createObject("webgpu", true));
#endif
#ifdef USE_CUDA
result.Set(result.Length(), createObject("cuda", false));
#endif
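
To complement the IO-binding path in `Run()` above, a hedged sketch of the per-output form of `preferredOutputLocation` (output names and model path are hypothetical; `'gpu-buffer'` maps to `gpuBufferMemoryInfo`, everything else to `cpuMemoryInfo`):

```ts
import * as ort from 'onnxruntime-node';

async function createWithPerOutputLocations() {
  return ort.InferenceSession.create('model.onnx', { // hypothetical model path
    executionProviders: ['webgpu'],
    preferredOutputLocation: {
      logits: 'gpu-buffer', // kept on the device; reading it back requires Tensor.download(), not in this PR
      mask: 'cpu', // copied back to a CPU buffer as usual
    },
  });
}
```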
28 changes: 27 additions & 1 deletion js/node/src/inference_session_wrap.h
@@ -12,9 +12,22 @@
class InferenceSessionWrap : public Napi::ObjectWrap<InferenceSessionWrap> {
public:
static Napi::Object Init(Napi::Env env, Napi::Object exports);
static Napi::FunctionReference& GetTensorConstructor();

InferenceSessionWrap(const Napi::CallbackInfo& info);

private:
/**
* [sync] initialize ONNX Runtime once.
*
* This function must be called before any other functions.
*
* @param arg0 a number specifying the log level.
*
* @returns undefined
*/
static Napi::Value InitOrtOnce(const Napi::CallbackInfo& info);

/**
* [sync] list supported backend list
* @returns array with objects { "name": "cpu", requirementsInstalled: true }
@@ -63,10 +76,19 @@ class InferenceSessionWrap : public Napi::ObjectWrap<InferenceSessionWrap> {
*/
Napi::Value Dispose(const Napi::CallbackInfo& info);

/**
* [sync] end the profiling.
* @param nothing
* @returns nothing
* @throw nothing
*/
Napi::Value EndProfiling(const Napi::CallbackInfo& info);

// private members

// persistent constructor
static Napi::FunctionReference constructor;
static Napi::FunctionReference wrappedSessionConstructor;
static Napi::FunctionReference ortTensorConstructor;

// session objects
bool initialized_;
Expand All @@ -81,4 +103,8 @@ class InferenceSessionWrap : public Napi::ObjectWrap<InferenceSessionWrap> {
std::vector<std::string> outputNames_;
std::vector<ONNXType> outputTypes_;
std::vector<ONNXTensorElementDataType> outputTensorElementDataTypes_;

// preferred output locations
std::vector<int> preferredOutputLocations_;
std::unique_ptr<Ort::IoBinding> ioBinding_;
};
