From 2785b090c6bad533685cbe729b39c32a96e37d7b Mon Sep 17 00:00:00 2001 From: Collide <44722470+TD-Sky@users.noreply.github.com> Date: Tue, 3 Dec 2024 19:16:23 +0800 Subject: [PATCH] upgrade ort to v2.0.0-rc.9 (#52) --- Cargo.toml | 5 +- src/core/ort_engine.rs | 113 +++++++++++++++++++++-------------------- 2 files changed, 60 insertions(+), 58 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index db6fc67..7db9dcf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,7 +12,7 @@ exclude = ["assets/*", "examples/*", "scripts/*", "runs/*"] [dependencies] clap = { version = "4.2.4", features = ["derive"] } ndarray = { version = "0.16.1", features = ["rayon"] } -ort = { version = "2.0.0-rc.5", default-features = false} +ort = { version = "2.0.0-rc.9", default-features = false } anyhow = { version = "1.0.75" } regex = { version = "1.5.4" } rand = { version = "0.8.5" } @@ -30,7 +30,7 @@ imageproc = { version = "0.24" } ab_glyph = "0.2.23" geo = "0.28.0" prost = "0.12.4" -fast_image_resize = { version = "4.2.1", features = ["image"]} +fast_image_resize = { version = "4.2.1", features = ["image"] } serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" tempfile = "3.12.0" @@ -50,7 +50,6 @@ default = [ "ort/cuda", "ort/tensorrt", "ort/coreml", - "ort/operator-libraries" ] auto = ["ort/download-binaries"] diff --git a/src/core/ort_engine.rs b/src/core/ort_engine.rs index 9aa5f29..d0b915d 100644 --- a/src/core/ort_engine.rs +++ b/src/core/ort_engine.rs @@ -2,7 +2,9 @@ use anyhow::Result; use half::f16; use ndarray::{Array, IxDyn}; use ort::{ - ExecutionProvider, Session, SessionBuilder, TensorElementType, TensorRTExecutionProvider, + execution_providers::{ExecutionProvider, TensorRTExecutionProvider}, + session::{builder::SessionBuilder, Session}, + tensor::TensorElementType, }; use prost::Message; use std::collections::HashSet; @@ -88,14 +90,14 @@ impl OrtEngine { // build ort::init().commit()?; - let builder = Session::builder()?; + let mut builder = Session::builder()?; let mut device = config.device.to_owned(); match device { Device::Trt(device_id) => { Self::build_trt( &inputs_attrs.names, &inputs_minoptmax, - &builder, + &mut builder, device_id, config.trt_int8_enable, config.trt_fp16_enable, @@ -103,23 +105,23 @@ impl OrtEngine { )?; } Device::Cuda(device_id) => { - Self::build_cuda(&builder, device_id).unwrap_or_else(|err| { + Self::build_cuda(&mut builder, device_id).unwrap_or_else(|err| { tracing::warn!("{err}, Using cpu"); device = Device::Cpu(0); }) } - Device::CoreML(_) => Self::build_coreml(&builder).unwrap_or_else(|err| { + Device::CoreML(_) => Self::build_coreml(&mut builder).unwrap_or_else(|err| { tracing::warn!("{err}, Using cpu"); device = Device::Cpu(0); }), Device::Cpu(_) => { - Self::build_cpu(&builder)?; + Self::build_cpu(&mut builder)?; } _ => todo!(), } let session = builder - .with_optimization_level(ort::GraphOptimizationLevel::Level3)? + .with_optimization_level(ort::session::builder::GraphOptimizationLevel::Level3)? .commit_from_file(&config.onnx_path)?; // summary @@ -149,7 +151,7 @@ impl OrtEngine { fn build_trt( names: &[String], inputs_minoptmax: &[Vec], - builder: &SessionBuilder, + builder: &mut SessionBuilder, device_id: usize, int8_enable: bool, fp16_enable: bool, @@ -205,8 +207,9 @@ impl OrtEngine { } } - fn build_cuda(builder: &SessionBuilder, device_id: usize) -> Result<()> { - let ep = ort::CUDAExecutionProvider::default().with_device_id(device_id as i32); + fn build_cuda(builder: &mut SessionBuilder, device_id: usize) -> Result<()> { + let ep = ort::execution_providers::CUDAExecutionProvider::default() + .with_device_id(device_id as i32); if ep.is_available()? && ep.register(builder).is_ok() { Ok(()) } else { @@ -214,8 +217,8 @@ impl OrtEngine { } } - fn build_coreml(builder: &SessionBuilder) -> Result<()> { - let ep = ort::CoreMLExecutionProvider::default().with_subgraphs(); //.with_ane_only(); + fn build_coreml(builder: &mut SessionBuilder) -> Result<()> { + let ep = ort::execution_providers::CoreMLExecutionProvider::default().with_subgraphs(); //.with_ane_only(); if ep.is_available()? && ep.register(builder).is_ok() { Ok(()) } else { @@ -223,8 +226,8 @@ impl OrtEngine { } } - fn build_cpu(builder: &SessionBuilder) -> Result<()> { - let ep = ort::CPUExecutionProvider::default(); + fn build_cpu(builder: &mut SessionBuilder) -> Result<()> { + let ep = ort::execution_providers::CPUExecutionProvider::default(); if ep.is_available()? && ep.register(builder).is_ok() { Ok(()) } else { @@ -292,28 +295,28 @@ impl OrtEngine { let t_pre = std::time::Instant::now(); for (idtype, x) in self.inputs_attrs.dtypes.iter().zip(xs.into_iter()) { let x_ = match &idtype { - TensorElementType::Float32 => ort::Value::from_array(x.view())?.into_dyn(), + TensorElementType::Float32 => ort::value::Value::from_array(x.view())?.into_dyn(), TensorElementType::Float16 => { - ort::Value::from_array(x.mapv(f16::from_f32).view())?.into_dyn() + ort::value::Value::from_array(x.mapv(f16::from_f32).view())?.into_dyn() } TensorElementType::Int32 => { - ort::Value::from_array(x.mapv(|x_| x_ as i32).view())?.into_dyn() + ort::value::Value::from_array(x.mapv(|x_| x_ as i32).view())?.into_dyn() } TensorElementType::Int64 => { - ort::Value::from_array(x.mapv(|x_| x_ as i64).view())?.into_dyn() + ort::value::Value::from_array(x.mapv(|x_| x_ as i64).view())?.into_dyn() } TensorElementType::Uint8 => { - ort::Value::from_array(x.mapv(|x_| x_ as u8).view())?.into_dyn() + ort::value::Value::from_array(x.mapv(|x_| x_ as u8).view())?.into_dyn() } TensorElementType::Int8 => { - ort::Value::from_array(x.mapv(|x_| x_ as i8).view())?.into_dyn() + ort::value::Value::from_array(x.mapv(|x_| x_ as i8).view())?.into_dyn() } TensorElementType::Bool => { - ort::Value::from_array(x.mapv(|x_| x_ != 0.).view())?.into_dyn() + ort::value::Value::from_array(x.mapv(|x_| x_ != 0.).view())?.into_dyn() } _ => todo!(), }; - xs_.push(Into::>::into(x_)); + xs_.push(Into::>::into(x_)); } let t_pre = t_pre.elapsed(); self.ts.add_or_push(0, t_pre); @@ -451,45 +454,45 @@ impl OrtEngine { } #[allow(dead_code)] - fn nbytes_from_onnx_dtype(x: &ort::TensorElementType) -> usize { + fn nbytes_from_onnx_dtype(x: &ort::tensor::TensorElementType) -> usize { match x { - ort::TensorElementType::Float64 - | ort::TensorElementType::Uint64 - | ort::TensorElementType::Int64 => 8, // i64, f64, u64 - ort::TensorElementType::Float32 - | ort::TensorElementType::Uint32 - | ort::TensorElementType::Int32 - | ort::TensorElementType::String => 4, // f32, i32, u32, string(1~4) - ort::TensorElementType::Float16 - | ort::TensorElementType::Bfloat16 - | ort::TensorElementType::Int16 - | ort::TensorElementType::Uint16 => 2, // f16, bf16, i16, u16 - ort::TensorElementType::Uint8 - | ort::TensorElementType::Int8 - | ort::TensorElementType::Bool => 1, // u8, i8, bool + ort::tensor::TensorElementType::Float64 + | ort::tensor::TensorElementType::Uint64 + | ort::tensor::TensorElementType::Int64 => 8, // i64, f64, u64 + ort::tensor::TensorElementType::Float32 + | ort::tensor::TensorElementType::Uint32 + | ort::tensor::TensorElementType::Int32 + | ort::tensor::TensorElementType::String => 4, // f32, i32, u32, string(1~4) + ort::tensor::TensorElementType::Float16 + | ort::tensor::TensorElementType::Bfloat16 + | ort::tensor::TensorElementType::Int16 + | ort::tensor::TensorElementType::Uint16 => 2, // f16, bf16, i16, u16 + ort::tensor::TensorElementType::Uint8 + | ort::tensor::TensorElementType::Int8 + | ort::tensor::TensorElementType::Bool => 1, // u8, i8, bool } } #[allow(dead_code)] - fn ort_dtype_from_onnx_dtype_id(value: i32) -> Option { + fn ort_dtype_from_onnx_dtype_id(value: i32) -> Option { match value { 0 => None, - 1 => Some(ort::TensorElementType::Float32), - 2 => Some(ort::TensorElementType::Uint8), - 3 => Some(ort::TensorElementType::Int8), - 4 => Some(ort::TensorElementType::Uint16), - 5 => Some(ort::TensorElementType::Int16), - 6 => Some(ort::TensorElementType::Int32), - 7 => Some(ort::TensorElementType::Int64), - 8 => Some(ort::TensorElementType::String), - 9 => Some(ort::TensorElementType::Bool), - 10 => Some(ort::TensorElementType::Float16), - 11 => Some(ort::TensorElementType::Float64), - 12 => Some(ort::TensorElementType::Uint32), - 13 => Some(ort::TensorElementType::Uint64), + 1 => Some(ort::tensor::TensorElementType::Float32), + 2 => Some(ort::tensor::TensorElementType::Uint8), + 3 => Some(ort::tensor::TensorElementType::Int8), + 4 => Some(ort::tensor::TensorElementType::Uint16), + 5 => Some(ort::tensor::TensorElementType::Int16), + 6 => Some(ort::tensor::TensorElementType::Int32), + 7 => Some(ort::tensor::TensorElementType::Int64), + 8 => Some(ort::tensor::TensorElementType::String), + 9 => Some(ort::tensor::TensorElementType::Bool), + 10 => Some(ort::tensor::TensorElementType::Float16), + 11 => Some(ort::tensor::TensorElementType::Float64), + 12 => Some(ort::tensor::TensorElementType::Uint32), + 13 => Some(ort::tensor::TensorElementType::Uint64), 14 => None, // COMPLEX64 15 => None, // COMPLEX128 - 16 => Some(ort::TensorElementType::Bfloat16), + 16 => Some(ort::tensor::TensorElementType::Bfloat16), _ => None, } } @@ -499,7 +502,7 @@ impl OrtEngine { value_info: &[onnx::ValueInfoProto], ) -> Result { let mut dimss: Vec> = Vec::new(); - let mut dtypes: Vec = Vec::new(); + let mut dtypes: Vec = Vec::new(); let mut names: Vec = Vec::new(); for v in value_info.iter() { if initializer_names.contains(v.name.as_str()) { @@ -569,7 +572,7 @@ impl OrtEngine { &self.outputs_attrs.names } - pub fn odtypes(&self) -> &Vec { + pub fn odtypes(&self) -> &Vec { &self.outputs_attrs.dtypes } @@ -585,7 +588,7 @@ impl OrtEngine { &self.inputs_attrs.names } - pub fn idtypes(&self) -> &Vec { + pub fn idtypes(&self) -> &Vec { &self.inputs_attrs.dtypes }