From 0f6f9816c396e8446a5c6fa7e2445a00e8b2ba2c Mon Sep 17 00:00:00 2001
From: narugo1992
Date: Fri, 29 Sep 2023 12:57:27 +0800
Subject: [PATCH] dev(narugo): use both

---
 imgutils/utils/onnxruntime.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/imgutils/utils/onnxruntime.py b/imgutils/utils/onnxruntime.py
index 455ee892afe..22a37a11b9d 100644
--- a/imgutils/utils/onnxruntime.py
+++ b/imgutils/utils/onnxruntime.py
@@ -63,14 +63,18 @@ def get_onnx_provider(provider: Optional[str] = None):
                          f'but unsupported provider {provider!r} found.')
 
 
-def _open_onnx_model(ckpt: str, provider: str) -> InferenceSession:
+def _open_onnx_model(ckpt: str, provider: str, use_cpu: bool = True) -> InferenceSession:
     options = SessionOptions()
     options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
     if provider == "CPUExecutionProvider":
         options.intra_op_num_threads = os.cpu_count()
 
+    providers = [provider]
+    if use_cpu and "CPUExecutionProvider" not in providers:
+        providers.append("CPUExecutionProvider")
+
     logging.info(f'Model {ckpt!r} loaded with provider {provider!r}')
-    return InferenceSession(ckpt, options, [provider])
+    return InferenceSession(ckpt, options, providers=providers)
 
 
 def open_onnx_model(ckpt: str, mode: str = None) -> InferenceSession: