diff --git a/js/common/lib/inference-session.ts b/js/common/lib/inference-session.ts index e62c6579e8333..ebd98f2aa9b65 100644 --- a/js/common/lib/inference-session.ts +++ b/js/common/lib/inference-session.ts @@ -223,6 +223,8 @@ export declare namespace InferenceSession { export interface CudaExecutionProviderOption extends ExecutionProviderOption { readonly name: 'cuda'; deviceId?: number; + arenaExtendStrategy?: 0|1; + gpuMemLimit?: number; } export interface DmlExecutionProviderOption extends ExecutionProviderOption { readonly name: 'dml'; diff --git a/js/node/src/session_options_helper.cc b/js/node/src/session_options_helper.cc index 8c1d7ca06b8c3..548cc631a2cfb 100644 --- a/js/node/src/session_options_helper.cc +++ b/js/node/src/session_options_helper.cc @@ -41,6 +41,10 @@ void ParseExecutionProviders(const Napi::Array epList, Ort::SessionOptions& sess Napi::Value epValue = epList[i]; std::string name; int deviceId = 0; +#ifdef USE_CUDA + onnxruntime::ArenaExtendStrategy arenaExtendStrategy = onnxruntime::ArenaExtendStrategy::kNextPowerOfTwo; + size_t gpuMemLimit = std::numeric_limits::max(); +#endif #ifdef USE_COREML int coreMlFlags = 0; #endif @@ -59,6 +63,16 @@ void ParseExecutionProviders(const Napi::Array epList, Ort::SessionOptions& sess if (obj.Has("deviceId")) { deviceId = obj.Get("deviceId").As(); } +#ifdef USE_CUDA + if (obj.Has("arenaExtendStrategy")) { + arenaExtendStrategy = static_cast( + obj.Get("arenaExtendStrategy").As().Uint32Value()); + } + if (obj.Has("gpuMemLimit")) { + gpuMemLimit = static_cast( + obj.Get("gpuMemLimit").As().Uint32Value()); + } +#endif #ifdef USE_COREML if (obj.Has("coreMlFlags")) { coreMlFlags = obj.Get("coreMlFlags").As(); @@ -86,6 +100,8 @@ void ParseExecutionProviders(const Napi::Array epList, Ort::SessionOptions& sess OrtCUDAProviderOptionsV2* options; Ort::GetApi().CreateCUDAProviderOptions(&options); options->device_id = deviceId; + options->arena_extend_strategy = arenaExtendStrategy; + options->gpu_mem_limit = gpuMemLimit; sessionOptions.AppendExecutionProvider_CUDA_V2(*options); Ort::GetApi().ReleaseCUDAProviderOptions(options); #endif