Skip to content

Commit

Permalink
feat(mnn-sync): Use the same ScheduleConfig for all load-unloads
Browse files Browse the repository at this point in the history
  • Loading branch information
uttarayan21 committed Oct 25, 2024
1 parent 2fdec58 commit e1a436a
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 12 deletions.
21 changes: 10 additions & 11 deletions mnn-sync/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ impl SessionRunnerState {
}
}

pub fn load(&mut self) -> Result<()> {
pub fn load(&mut self, config: &ScheduleConfig) -> Result<()> {
#[cfg(feature = "tracing")]
tracing::info!("Loading session");
match core::mem::take(self) {
Expand All @@ -148,19 +148,19 @@ impl SessionRunnerState {
Ok(())
}
Self::Unloaded(net) => {
let mut sr = SessionRunner::create(net, ScheduleConfig::new())?;
let sr = SessionRunner::create(net, config.clone())?;
*self = Self::Loaded(sr);
Ok(())
}
Self::Poisoned => Self::poisoned(),
}
}

pub fn sr(&mut self) -> Result<&mut SessionRunner> {
pub fn sr(&mut self, config: &ScheduleConfig) -> Result<&mut SessionRunner> {
match self {
Self::Loaded(sr) => Ok(sr),
Self::Unloaded(net) => {
self.load()?;
Self::Unloaded(_) => {
self.load(config)?;
Ok(self.loaded_mut().ok_or_else(|| {
Report::new(ErrorKind::SyncError).attach_printable("Failed to load session")
})?)
Expand All @@ -179,11 +179,11 @@ impl SessionRunnerState {

impl SessionState {
pub fn sr(&mut self) -> Result<&mut SessionRunner> {
self.sr.sr()
self.sr.sr(&self.config)
}

pub fn load(&mut self) -> Result<()> {
self.sr.load()
self.sr.load(&self.config)
}

pub fn unload(&mut self) -> Result<()> {
Expand Down Expand Up @@ -284,7 +284,7 @@ impl SessionRunner {
}

impl SessionHandle {
pub fn new(mut interpreter: Interpreter, config: ScheduleConfig) -> Result<Self> {
pub fn new(interpreter: Interpreter, config: ScheduleConfig) -> Result<Self> {
let (sender, receiver) = flume::unbounded::<CallbackSender>();
let builder = std::thread::Builder::new().name("mnn-session-thread".to_string());
let spawner = move || -> Result<()> {
Expand Down Expand Up @@ -365,10 +365,9 @@ impl SessionHandle {
self.sender
.send(CallbackEnum::Callback(Box::new(wrapped_f)))
.map_err(|e| Report::new(ErrorKind::SyncError).attach_printable(e.to_string()))?;
Ok(rx
.recv()
rx.recv()
.change_context(ErrorKind::SyncError)
.attach_printable("Internal Error: Unable to recv message")??)
.attach_printable("Internal Error: Unable to recv message")?
}

pub async fn run_async<R: Send + Sync + 'static>(
Expand Down
2 changes: 1 addition & 1 deletion src/interpreter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -605,7 +605,7 @@ pub struct TensorInfo<'t, 'tl> {
impl core::fmt::Debug for TensorInfo<'_, '_> {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
let tensor = self.raw_tensor();
let shape = tensor.shape().clone();
let shape = tensor.shape();
f.debug_struct("TensorInfo")
.field("name", &self.name())
.field("tensor", &shape)
Expand Down
2 changes: 2 additions & 0 deletions src/schedule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ use std::{ffi::CString, mem::ManuallyDrop};

use crate::{prelude::*, BackendConfig};

/// Backend used for running the model
///
/// The `ForwardType` enum is used to specify the backend that will be used for forward computation
/// in the MNN framework. Each variant corresponds to a different backend, which may be enabled
/// or disabled based on the features enabled in the build configuration.
Expand Down

0 comments on commit e1a436a

Please sign in to comment.