From bc002c03c94c9fc2835900d4cb12b52dbd9659b2 Mon Sep 17 00:00:00 2001
From: "yanlong.wang"
Date: Thu, 19 Dec 2024 18:50:29 +0800
Subject: [PATCH] [js/node] allow arenaExtendStrategy and gpuMemLimit for cuda

Add the optional `arenaExtendStrategy` (0|1) and `gpuMemLimit` fields to
CudaExecutionProviderOption and, when the node binding is built with
USE_CUDA, forward them to OrtCUDAProviderOptionsV2 as
`arena_extend_strategy` and `gpu_mem_limit`.

When a field is omitted the binding falls back to kNextPowerOfTwo and to
the maximum size_t value respectively.

`gpuMemLimit` is read with Int64Value() rather than Uint32Value(): the
32-bit accessor keeps only the low 32 bits of the JS number, so any limit
of 4 GiB or more would silently wrap (e.g. 8 GiB would become 0).
---
 js/common/lib/inference-session.ts    |  2 ++
 js/node/src/session_options_helper.cc | 16 ++++++++++++++++
 2 files changed, 18 insertions(+)

diff --git a/js/common/lib/inference-session.ts b/js/common/lib/inference-session.ts
index e62c6579e8333..ebd98f2aa9b65 100644
--- a/js/common/lib/inference-session.ts
+++ b/js/common/lib/inference-session.ts
@@ -223,6 +223,8 @@ export declare namespace InferenceSession {
   export interface CudaExecutionProviderOption extends ExecutionProviderOption {
     readonly name: 'cuda';
     deviceId?: number;
+    arenaExtendStrategy?: 0|1;
+    gpuMemLimit?: number;
   }
   export interface DmlExecutionProviderOption extends ExecutionProviderOption {
     readonly name: 'dml';
diff --git a/js/node/src/session_options_helper.cc b/js/node/src/session_options_helper.cc
index 8c1d7ca06b8c3..548cc631a2cfb 100644
--- a/js/node/src/session_options_helper.cc
+++ b/js/node/src/session_options_helper.cc
@@ -41,6 +41,10 @@ void ParseExecutionProviders(const Napi::Array epList, Ort::SessionOptions& sess
     Napi::Value epValue = epList[i];
     std::string name;
     int deviceId = 0;
+#ifdef USE_CUDA
+    onnxruntime::ArenaExtendStrategy arenaExtendStrategy = onnxruntime::ArenaExtendStrategy::kNextPowerOfTwo;
+    size_t gpuMemLimit = std::numeric_limits<size_t>::max();
+#endif
 #ifdef USE_COREML
     int coreMlFlags = 0;
 #endif
@@ -59,6 +63,16 @@ void ParseExecutionProviders(const Napi::Array epList, Ort::SessionOptions& sess
       if (obj.Has("deviceId")) {
         deviceId = obj.Get("deviceId").As<Napi::Number>();
       }
+#ifdef USE_CUDA
+      if (obj.Has("arenaExtendStrategy")) {
+        arenaExtendStrategy = static_cast<onnxruntime::ArenaExtendStrategy>(
+            obj.Get("arenaExtendStrategy").As<Napi::Number>().Uint32Value());
+      }
+      if (obj.Has("gpuMemLimit")) {
+        gpuMemLimit = static_cast<size_t>(
+            obj.Get("gpuMemLimit").As<Napi::Number>().Int64Value());
+      }
+#endif
 #ifdef USE_COREML
       if (obj.Has("coreMlFlags")) {
         coreMlFlags = obj.Get("coreMlFlags").As<Napi::Number>();
@@ -86,6 +100,8 @@ void ParseExecutionProviders(const Napi::Array epList, Ort::SessionOptions& sess
       OrtCUDAProviderOptionsV2* options;
       Ort::GetApi().CreateCUDAProviderOptions(&options);
       options->device_id = deviceId;
+      options->arena_extend_strategy = arenaExtendStrategy;
+      options->gpu_mem_limit = gpuMemLimit;
       sessionOptions.AppendExecutionProvider_CUDA_V2(*options);
       Ort::GetApi().ReleaseCUDAProviderOptions(options);
 #endif