Skip to content

Commit

Permalink
Merge pull request #46 from nbigaouette/42-tie-session-lifetime-to-en…
Browse files Browse the repository at this point in the history
…vironment

Store a reference to environment in session to tie its lifetime
  • Loading branch information
nbigaouette authored Dec 27, 2020
2 parents 694bb7c + 7198ab1 commit 3b804d4
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 25 deletions.
2 changes: 1 addition & 1 deletion onnxruntime/src/environment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ impl Environment {
/// Create a new [`SessionBuilder`](../session/struct.SessionBuilder.html)
/// used to create a new ONNX session.
pub fn new_session_builder(&self) -> Result<SessionBuilder> {
SessionBuilder::new(self.clone())
SessionBuilder::new(self)
}
}

Expand Down
48 changes: 24 additions & 24 deletions onnxruntime/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,15 +62,15 @@ use crate::{download::AvailableOnnxModel, error::OrtDownloadError};
/// # }
/// ```
#[derive(Debug)]
pub struct SessionBuilder {
env: Environment,
pub struct SessionBuilder<'a> {
env: &'a Environment,
session_options_ptr: *mut sys::OrtSessionOptions,

allocator: AllocatorType,
memory_type: MemType,
}

impl Drop for SessionBuilder {
impl<'a> Drop for SessionBuilder<'a> {
#[tracing::instrument]
fn drop(&mut self) {
debug!("Dropping the session options.");
Expand All @@ -79,8 +79,8 @@ impl Drop for SessionBuilder {
}
}

impl SessionBuilder {
pub(crate) fn new(env: Environment) -> Result<SessionBuilder> {
impl<'a> SessionBuilder<'a> {
pub(crate) fn new(env: &'a Environment) -> Result<SessionBuilder<'a>> {
let mut session_options_ptr: *mut sys::OrtSessionOptions = std::ptr::null_mut();
let status = unsafe { g_ort().CreateSessionOptions.unwrap()(&mut session_options_ptr) };

Expand All @@ -97,7 +97,7 @@ impl SessionBuilder {
}

/// Configure the session to use a number of threads
pub fn with_number_threads(self, num_threads: i16) -> Result<SessionBuilder> {
pub fn with_number_threads(self, num_threads: i16) -> Result<SessionBuilder<'a>> {
// FIXME: Pre-built binaries use OpenMP, set env variable instead

// We use a u16 in the builder to cover the 16-bits positive values of a i32.
Expand All @@ -113,7 +113,7 @@ impl SessionBuilder {
pub fn with_optimization_level(
self,
opt_level: GraphOptimizationLevel,
) -> Result<SessionBuilder> {
) -> Result<SessionBuilder<'a>> {
// Sets graph optimization level
unsafe {
g_ort().SetSessionGraphOptimizationLevel.unwrap()(
Expand All @@ -127,47 +127,44 @@ impl SessionBuilder {
/// Set the session's allocator
///
/// Defaults to [`AllocatorType::Arena`](../enum.AllocatorType.html#variant.Arena)
pub fn with_allocator(mut self, allocator: AllocatorType) -> Result<SessionBuilder> {
pub fn with_allocator(mut self, allocator: AllocatorType) -> Result<SessionBuilder<'a>> {
self.allocator = allocator;
Ok(self)
}

/// Set the session's memory type
///
/// Defaults to [`MemType::Default`](../enum.MemType.html#variant.Default)
pub fn with_memory_type(mut self, memory_type: MemType) -> Result<SessionBuilder> {
pub fn with_memory_type(mut self, memory_type: MemType) -> Result<SessionBuilder<'a>> {
self.memory_type = memory_type;
Ok(self)
}

/// Download an ONNX pre-trained model from the [ONNX Model Zoo](https://github.com/onnx/models) and commit the session
#[cfg(feature = "model-fetching")]
pub fn with_model_downloaded<M>(self, model: M) -> Result<Session>
pub fn with_model_downloaded<M>(self, model: M) -> Result<Session<'a>>
where
M: Into<AvailableOnnxModel>,
{
self.with_model_downloaded_monomorphized(model.into())
}

#[cfg(feature = "model-fetching")]
fn with_model_downloaded_monomorphized(self, model: AvailableOnnxModel) -> Result<Session> {
fn with_model_downloaded_monomorphized(self, model: AvailableOnnxModel) -> Result<Session<'a>> {
let download_dir = env::current_dir().map_err(OrtDownloadError::IoError)?;
let downloaded_path = model.download_to(download_dir)?;
self.with_model_from_file_monomorphized(downloaded_path.as_ref())
self.with_model_from_file(downloaded_path)
}

// TODO: Add all functions changing the options.
// See all OrtApi methods taking a `options: *mut OrtSessionOptions`.

/// Load an ONNX graph from a file and commit the session
pub fn with_model_from_file<P>(self, model_filepath: P) -> Result<Session>
pub fn with_model_from_file<P>(self, model_filepath_ref: P) -> Result<Session<'a>>
where
P: AsRef<Path>,
P: AsRef<Path> + 'a,
{
self.with_model_from_file_monomorphized(model_filepath.as_ref())
}

fn with_model_from_file_monomorphized(self, model_filepath: &Path) -> Result<Session> {
let model_filepath = model_filepath_ref.as_ref();
let mut session_ptr: *mut sys::OrtSession = std::ptr::null_mut();

if !model_filepath.exists() {
Expand Down Expand Up @@ -224,6 +221,7 @@ impl SessionBuilder {
.collect::<Result<Vec<Output>>>()?;

Ok(Session {
env: self.env,
session_ptr,
allocator_ptr,
memory_info,
Expand All @@ -233,14 +231,14 @@ impl SessionBuilder {
}

/// Load an ONNX graph from memory and commit the session
pub fn with_model_from_memory<B>(self, model_bytes: B) -> Result<Session>
pub fn with_model_from_memory<B>(self, model_bytes: B) -> Result<Session<'a>>
where
B: AsRef<[u8]>,
{
self.with_model_from_memory_monomorphized(model_bytes.as_ref())
}

fn with_model_from_memory_monomorphized(self, model_bytes: &[u8]) -> Result<Session> {
fn with_model_from_memory_monomorphized(self, model_bytes: &[u8]) -> Result<Session<'a>> {
let mut session_ptr: *mut sys::OrtSession = std::ptr::null_mut();

let env_ptr: *const sys::OrtEnv = self.env.env_ptr();
Expand Down Expand Up @@ -279,6 +277,7 @@ impl SessionBuilder {
.collect::<Result<Vec<Output>>>()?;

Ok(Session {
env: self.env,
session_ptr,
allocator_ptr,
memory_info,
Expand All @@ -290,7 +289,8 @@ impl SessionBuilder {

/// Type storing the session information, built from an [`Environment`](environment/struct.Environment.html)
#[derive(Debug)]
pub struct Session {
pub struct Session<'a> {
env: &'a Environment,
session_ptr: *mut sys::OrtSession,
allocator_ptr: *mut sys::OrtAllocator,
memory_info: MemoryInfo,
Expand Down Expand Up @@ -348,7 +348,7 @@ impl Output {
}
}

impl Drop for Session {
impl<'a> Drop for Session<'a> {
#[tracing::instrument]
fn drop(&mut self) {
debug!("Dropping the session.");
Expand All @@ -360,7 +360,7 @@ impl Drop for Session {
}
}

impl Session {
impl<'a> Session<'a> {
/// Run the input data through the ONNX graph, performing inference.
///
/// Note that ONNX models can have multiple inputs; a `Vec<_>` is thus
Expand Down Expand Up @@ -562,7 +562,7 @@ impl Session {

/// This module contains dangerous functions working on raw pointers.
/// Those functions are only to be used from inside the
/// `SessionBuilder::with_model_from_file_monomorphized()` method.
/// `SessionBuilder::with_model_from_file()` method.
mod dangerous {
use super::*;

Expand Down

0 comments on commit 3b804d4

Please sign in to comment.