Skip to content

Commit

Permalink
[js/node] allow arenaExtendStrategy and gpuMemLimit for cuda
Browse files Browse the repository at this point in the history
  • Loading branch information
nomagick committed Dec 21, 2024
1 parent ae6dcc8 commit bc002c0
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 0 deletions.
2 changes: 2 additions & 0 deletions js/common/lib/inference-session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,8 @@ export declare namespace InferenceSession {
export interface CudaExecutionProviderOption extends ExecutionProviderOption {
readonly name: 'cuda';
deviceId?: number;
arenaExtendStrategy?: 0|1;
gpuMemLimit?: number;
}
export interface DmlExecutionProviderOption extends ExecutionProviderOption {
readonly name: 'dml';
Expand Down
16 changes: 16 additions & 0 deletions js/node/src/session_options_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ void ParseExecutionProviders(const Napi::Array epList, Ort::SessionOptions& sess
Napi::Value epValue = epList[i];
std::string name;
int deviceId = 0;
#ifdef USE_CUDA
onnxruntime::ArenaExtendStrategy arenaExtendStrategy = onnxruntime::ArenaExtendStrategy::kNextPowerOfTwo;
size_t gpuMemLimit = std::numeric_limits<size_t>::max();
#endif
#ifdef USE_COREML
int coreMlFlags = 0;
#endif
Expand All @@ -59,6 +63,16 @@ void ParseExecutionProviders(const Napi::Array epList, Ort::SessionOptions& sess
if (obj.Has("deviceId")) {
deviceId = obj.Get("deviceId").As<Napi::Number>();
}
#ifdef USE_CUDA
if (obj.Has("arenaExtendStrategy")) {
arenaExtendStrategy = static_cast<onnxruntime::ArenaExtendStrategy>(
obj.Get("arenaExtendStrategy").As<Napi::Number>().Uint32Value());
}
if (obj.Has("gpuMemLimit")) {
gpuMemLimit = static_cast<size_t>(
obj.Get("gpuMemLimit").As<Napi::Number>().Uint32Value());
}
#endif
#ifdef USE_COREML
if (obj.Has("coreMlFlags")) {
coreMlFlags = obj.Get("coreMlFlags").As<Napi::Number>();
Expand Down Expand Up @@ -86,6 +100,8 @@ void ParseExecutionProviders(const Napi::Array epList, Ort::SessionOptions& sess
OrtCUDAProviderOptionsV2* options;
Ort::GetApi().CreateCUDAProviderOptions(&options);
options->device_id = deviceId;
options->arena_extend_strategy = arenaExtendStrategy;
options->gpu_mem_limit = gpuMemLimit;
sessionOptions.AppendExecutionProvider_CUDA_V2(*options);
Ort::GetApi().ReleaseCUDAProviderOptions(options);
#endif
Expand Down

0 comments on commit bc002c0

Please sign in to comment.