Skip to content

Commit

Permalink
feat: Make the SessionHandle send+sync
Browse files Browse the repository at this point in the history
  • Loading branch information
uttarayan21 committed Sep 22, 2024
1 parent 7fbb2cb commit 7f70a7e
Show file tree
Hide file tree
Showing 3 changed files with 150 additions and 4 deletions.
134 changes: 134 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions mnn-sync/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,9 @@ license.workspace = true

[dependencies]
error-stack = "0.5.0"
flume = { version = "0.11.0", default-features = false, features = [
"eventual-fairness",
"nanorand",
] }
mnn = { version = "0.1.0", path = ".." }
oneshot = "0.1.8"
16 changes: 12 additions & 4 deletions mnn-sync/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
//! When you run a closure it is sent to the thread and executed in that session and the result is
//! sent back to the main thread via a [oneshot::Sender]
use flume::{Receiver, Sender};

use error_stack::{Report, ResultExt};
use mnn::*;

Expand All @@ -49,8 +51,8 @@ type CallbackSender = CallbackEnum;
pub struct SessionHandle {
#[allow(dead_code)]
pub(crate) handle: std::thread::JoinHandle<Result<()>>,
pub(crate) sender: std::sync::mpsc::Sender<CallbackSender>,
pub(crate) loop_handle: std::sync::mpsc::Receiver<Result<()>>,
pub(crate) sender: Sender<CallbackSender>,
pub(crate) loop_handle: Receiver<Result<()>>,
}

impl Drop for SessionHandle {
Expand All @@ -69,10 +71,10 @@ pub struct SessionRunner {

impl SessionHandle {
pub fn new(mut interpreter: Interpreter, config: ScheduleConfig) -> Result<Self> {
let (sender, receiver) = std::sync::mpsc::channel::<CallbackSender>();
let (sender, receiver) = flume::unbounded::<CallbackSender>();

let builder = std::thread::Builder::new().name("mnn-session-thread".to_string());
let (tx, rx) = std::sync::mpsc::channel();
let (tx, rx) = flume::unbounded();
let handle = builder
.spawn(move || -> Result<()> {
let session = interpreter.create_session(config)?;
Expand Down Expand Up @@ -291,3 +293,9 @@ pub fn test_sync_api_race() {
})
.expect("Sed");
}

#[test]
pub fn test_sync_api_is_send_sync() {
fn is_send_sync<T: Send + Sync>() {}
is_send_sync::<SessionHandle>();
}

0 comments on commit 7f70a7e

Please sign in to comment.