diff --git a/onnxruntime/src/environment.rs b/onnxruntime/src/environment.rs index 68a2a4a1..3e26866a 100644 --- a/onnxruntime/src/environment.rs +++ b/onnxruntime/src/environment.rs @@ -135,7 +135,7 @@ impl Environment { /// Create a new [`SessionBuilder`](../session/struct.SessionBuilder.html) /// used to create a new ONNX session. pub fn new_session_builder(&self) -> Result { - SessionBuilder::new(self.clone()) + SessionBuilder::new(self) } } diff --git a/onnxruntime/src/session.rs b/onnxruntime/src/session.rs index 2cef12e5..39cb558a 100644 --- a/onnxruntime/src/session.rs +++ b/onnxruntime/src/session.rs @@ -62,15 +62,15 @@ use crate::{download::AvailableOnnxModel, error::OrtDownloadError}; /// # } /// ``` #[derive(Debug)] -pub struct SessionBuilder { - env: Environment, +pub struct SessionBuilder<'a> { + env: &'a Environment, session_options_ptr: *mut sys::OrtSessionOptions, allocator: AllocatorType, memory_type: MemType, } -impl Drop for SessionBuilder { +impl<'a> Drop for SessionBuilder<'a> { #[tracing::instrument] fn drop(&mut self) { debug!("Dropping the session options."); @@ -79,8 +79,8 @@ impl Drop for SessionBuilder { } } -impl SessionBuilder { - pub(crate) fn new(env: Environment) -> Result { +impl<'a> SessionBuilder<'a> { + pub(crate) fn new(env: &'a Environment) -> Result> { let mut session_options_ptr: *mut sys::OrtSessionOptions = std::ptr::null_mut(); let status = unsafe { g_ort().CreateSessionOptions.unwrap()(&mut session_options_ptr) }; @@ -97,7 +97,7 @@ impl SessionBuilder { } /// Configure the session to use a number of threads - pub fn with_number_threads(self, num_threads: i16) -> Result { + pub fn with_number_threads(self, num_threads: i16) -> Result> { // FIXME: Pre-built binaries use OpenMP, set env variable instead // We use a u16 in the builder to cover the 16-bits positive values of a i32. @@ -113,7 +113,7 @@ impl SessionBuilder { pub fn with_optimization_level( self, opt_level: GraphOptimizationLevel, - ) -> Result { + ) -> Result> { // Sets graph optimization level unsafe { g_ort().SetSessionGraphOptimizationLevel.unwrap()( @@ -127,7 +127,7 @@ impl SessionBuilder { /// Set the session's allocator /// /// Defaults to [`AllocatorType::Arena`](../enum.AllocatorType.html#variant.Arena) - pub fn with_allocator(mut self, allocator: AllocatorType) -> Result { + pub fn with_allocator(mut self, allocator: AllocatorType) -> Result> { self.allocator = allocator; Ok(self) } @@ -135,14 +135,14 @@ impl SessionBuilder { /// Set the session's memory type /// /// Defaults to [`MemType::Default`](../enum.MemType.html#variant.Default) - pub fn with_memory_type(mut self, memory_type: MemType) -> Result { + pub fn with_memory_type(mut self, memory_type: MemType) -> Result> { self.memory_type = memory_type; Ok(self) } /// Download an ONNX pre-trained model from the [ONNX Model Zoo](https://github.com/onnx/models) and commit the session #[cfg(feature = "model-fetching")] - pub fn with_model_downloaded(self, model: M) -> Result + pub fn with_model_downloaded(self, model: M) -> Result> where M: Into, { @@ -150,24 +150,21 @@ impl SessionBuilder { } #[cfg(feature = "model-fetching")] - fn with_model_downloaded_monomorphized(self, model: AvailableOnnxModel) -> Result { + fn with_model_downloaded_monomorphized(self, model: AvailableOnnxModel) -> Result> { let download_dir = env::current_dir().map_err(OrtDownloadError::IoError)?; let downloaded_path = model.download_to(download_dir)?; - self.with_model_from_file_monomorphized(downloaded_path.as_ref()) + self.with_model_from_file(downloaded_path) } // TODO: Add all functions changing the options. // See all OrtApi methods taking a `options: *mut OrtSessionOptions`. /// Load an ONNX graph from a file and commit the session - pub fn with_model_from_file

(self, model_filepath: P) -> Result + pub fn with_model_from_file

(self, model_filepath_ref: P) -> Result> where - P: AsRef, + P: AsRef + 'a, { - self.with_model_from_file_monomorphized(model_filepath.as_ref()) - } - - fn with_model_from_file_monomorphized(self, model_filepath: &Path) -> Result { + let model_filepath = model_filepath_ref.as_ref(); let mut session_ptr: *mut sys::OrtSession = std::ptr::null_mut(); if !model_filepath.exists() { @@ -224,6 +221,7 @@ impl SessionBuilder { .collect::>>()?; Ok(Session { + env: self.env, session_ptr, allocator_ptr, memory_info, @@ -233,14 +231,14 @@ impl SessionBuilder { } /// Load an ONNX graph from memory and commit the session - pub fn with_model_from_memory(self, model_bytes: B) -> Result + pub fn with_model_from_memory(self, model_bytes: B) -> Result> where B: AsRef<[u8]>, { self.with_model_from_memory_monomorphized(model_bytes.as_ref()) } - fn with_model_from_memory_monomorphized(self, model_bytes: &[u8]) -> Result { + fn with_model_from_memory_monomorphized(self, model_bytes: &[u8]) -> Result> { let mut session_ptr: *mut sys::OrtSession = std::ptr::null_mut(); let env_ptr: *const sys::OrtEnv = self.env.env_ptr(); @@ -279,6 +277,7 @@ impl SessionBuilder { .collect::>>()?; Ok(Session { + env: self.env, session_ptr, allocator_ptr, memory_info, @@ -290,7 +289,8 @@ impl SessionBuilder { /// Type storing the session information, built from an [`Environment`](environment/struct.Environment.html) #[derive(Debug)] -pub struct Session { +pub struct Session<'a> { + env: &'a Environment, session_ptr: *mut sys::OrtSession, allocator_ptr: *mut sys::OrtAllocator, memory_info: MemoryInfo, @@ -348,7 +348,7 @@ impl Output { } } -impl Drop for Session { +impl<'a> Drop for Session<'a> { #[tracing::instrument] fn drop(&mut self) { debug!("Dropping the session."); @@ -360,7 +360,7 @@ impl Drop for Session { } } -impl Session { +impl<'a> Session<'a> { /// Run the input data through the ONNX graph, performing inference. /// /// Note that ONNX models can have multiple inputs; a `Vec<_>` is thus @@ -562,7 +562,7 @@ impl Session { /// This module contains dangerous functions working on raw pointers. /// Those functions are only to be used from inside the -/// `SessionBuilder::with_model_from_file_monomorphized()` method. +/// `SessionBuilder::with_model_from_file()` method. mod dangerous { use super::*;