diff --git a/examples/inspect.rs b/examples/inspect.rs index 03b2300..1298911 100644 --- a/examples/inspect.rs +++ b/examples/inspect.rs @@ -10,6 +10,10 @@ pub struct Cli { power: PowerMode, #[clap(short, long, default_value = "high")] precision: PrecisionMode, + #[clap(short, long, default_value = "high")] + memory: MemoryMode, + #[clap(short, long, default_value = "1")] + loops: usize, } macro_rules! time { @@ -30,40 +34,52 @@ macro_rules! time { pub fn main() -> anyhow::Result<()> { use clap::Parser; let cli = Cli::parse(); - let mut interpreter = Interpreter::from_file(cli.model)?; + let mut interpreter = Interpreter::from_file(&cli.model)?; + interpreter.set_cache_file(cli.model.with_extension("cache"), 128)?; let mut config = ScheduleConfig::new(); config.set_type(cli.forward); - let mut backend_config = BackendConfig::new(); - backend_config.set_precision_mode(PrecisionMode::High); - backend_config.set_power_mode(PowerMode::High); - - config.set_backend_config(backend_config); + // let mut backend_config = BackendConfig::new(); + // backend_config.set_precision_mode(PrecisionMode::High); + // backend_config.set_power_mode(PowerMode::High); + // config.set_backend_config(backend_config); // let handle = mnn::sync::SessionHandle::new(interpreter, config)?; - let session = time!(interpreter.create_session(config)?; "create session"); + let mut session = time!(interpreter.create_session(config)?; "create session"); + interpreter.update_cache_file(&mut session)?; + let mut input = interpreter.input::(&session, "image")?; + let mut shape = input.shape(); + shape[0] = 512; + shape[1] = 512; + shape[2] = 3; + interpreter.resize_tensor(&mut input, shape); + drop(input); + interpreter.resize_session(&mut session); + // let session = time!(interpreter.create_session(config)?; "create session"); // handle.run(|sr| { // let interpreter = sr.interpreter(); // let session = sr.session(); - let inputs = interpreter.inputs(&session); - inputs.iter().for_each(|x| { - let mut tensor = x.tensor::().expect("No tensor"); - println!("{}: {:?}", x.name(), tensor.shape()); - tensor.fill(1.0f32); - // let mut cpu_tensor = tensor.create_host_tensor_from_device(false); - // cpu_tensor.host_mut().fill(1.0f32); - // tensor - // .copy_from_host_tensor(&cpu_tensor) - // .expect("Could not copy tensor"); - }); - time!(interpreter.run_session(&session)?;"run session"); - let outputs = interpreter.outputs(&session); - outputs.iter().for_each(|x| { - let tensor = x.tensor::().expect("No tensor"); - time!(tensor.wait(ffi::MapType::MAP_TENSOR_READ, true); format!("Waiting for tensor: {}", x.name())); - println!("{}: {:?}", x.name(), tensor.shape()); - let _ = tensor.create_host_tensor_from_device(true); - // std::fs::write(format!("{}.bin", x.name()), bytemuck::cast_slice(cpu_tensor.host())).expect("Unable to write"); - }); + let mut current = 0; + time!(loop { + let inputs = interpreter.inputs(&session); + inputs.iter().for_each(|x| { + let mut tensor = x.tensor::().expect("No tensor"); + println!("{}: {:?}", x.name(), tensor.shape()); + tensor.fill(1.0f32); + }); + time!(interpreter.run_session(&session)?;"run session"); + let outputs = interpreter.outputs(&session); + outputs.iter().for_each(|x| { + let tensor = x.tensor::().expect("No tensor"); + time!(tensor.wait(ffi::MapType::MAP_TENSOR_READ, true); format!("Waiting for tensor: {}", x.name())); + println!("{}: {:?}", x.name(), tensor.shape()); + let _ = tensor.create_host_tensor_from_device(true); + // std::fs::write(format!("{}.bin", x.name()), bytemuck::cast_slice(cpu_tensor.host())).expect("Unable to write"); + }); + current += 1; + if current >= cli.loops { + break; + } + }; "run loop"); // Ok(()) // })?; Ok(()) diff --git a/mnn-sys/mnn_c/interpreter_c.cpp b/mnn-sys/mnn_c/interpreter_c.cpp index 6c1bd7f..6bfb0b4 100644 --- a/mnn-sys/mnn_c/interpreter_c.cpp +++ b/mnn-sys/mnn_c/interpreter_c.cpp @@ -53,11 +53,10 @@ void Interpreter_setExternalFile(Interpreter *interpreter, const char *file, mnn_interpreter->setExternalFile(file, flag); } ErrorCode Interpreter_updateCacheFile(Interpreter *interpreter, - Session *session, int flag) { + Session *session) { auto mnn_interpreter = reinterpret_cast(interpreter); auto mnn_session = reinterpret_cast(session); - return static_cast( - mnn_interpreter->updateCacheFile(mnn_session, flag)); + return static_cast(mnn_interpreter->updateCacheFile(mnn_session)); } void Interpreter_setSessionHint(Interpreter *interpreter, int mode, int value) { auto mnn_interpreter = reinterpret_cast(interpreter); @@ -111,17 +110,36 @@ Session *Interpreter_createSession(Interpreter *interpreter, // cppConfig.backendConfig = config->backendConfig; // return interpreter->createSession(cppConfig, *runtime); // } -Session *Interpreter_createMultiPathSession(Interpreter *interpreter, - const MNNScheduleConfig *configs, - size_t configSize) { - - auto mnn_configs = reinterpret_cast(configs); - // @todo: check if this is correct - std::vector cppConfigs(mnn_configs, - mnn_configs + configSize); +// Session *Interpreter_createMultiPathSession(Interpreter *interpreter, +// const MNNScheduleConfig *configs, +// size_t configSize) { +// +// auto mnn_configs = reinterpret_cast(configs); +// std::vector cppConfigs(mnn_configs, +// mnn_configs + configSize); +// auto mnn_interpreter = reinterpret_cast(interpreter); +// return reinterpret_cast( +// mnn_interpreter->createMultiPathSession(cppConfigs)); +// } +Session * +Interpreter_createMultiPathSession(Interpreter *interpreter, + const MNNScheduleConfig *const *configs, + size_t configSize) { + auto mnn_configs = + reinterpret_cast(configs); + std::vector s_configs; + for (size_t i = 0; i < configSize; ++i) { + s_configs.push_back(*mnn_configs[i]); + } + // std::vector cppConfigs(mnn_configs, + // mnn_configs + configSize); + // Create a std::vector from + // std::vector + // auto s_configs = + // std::vector(cppConfigs.begin(), cppConfigs.end()); auto mnn_interpreter = reinterpret_cast(interpreter); - return reinterpret_cast( - mnn_interpreter->createMultiPathSession(cppConfigs)); + MNN::Session *session = mnn_interpreter->createMultiPathSession(s_configs); + return reinterpret_cast(session); } // Session* Interpreter_createMultiPathSessionWithRuntime(Interpreter* @@ -299,6 +317,6 @@ const char *Interpreter_uuid(const Interpreter *interpreter) { } void Session_destroy(Session *session) { auto mnn_session = reinterpret_cast(session); - delete mnn_session; + // delete mnn_session; } } // extern "C" diff --git a/mnn-sys/mnn_c/interpreter_c.h b/mnn-sys/mnn_c/interpreter_c.h index 9932edf..a3c1a04 100644 --- a/mnn-sys/mnn_c/interpreter_c.h +++ b/mnn-sys/mnn_c/interpreter_c.h @@ -1,8 +1,8 @@ #ifndef INTERPRETER_C_H #define INTERPRETER_C_H #include "backend_c.h" -#include "schedule_c.h" #include "error_code_c.h" +#include "schedule_c.h" #include "tensor_c.h" #include "utils.h" #include @@ -144,7 +144,7 @@ void Interpreter_setCacheFile(Interpreter *interpreter, const char *cacheFile, void Interpreter_setExternalFile(Interpreter *interpreter, const char *file, size_t flag); ErrorCode Interpreter_updateCacheFile(Interpreter *interpreter, - Session *session, int flag); + Session *session); void Interpreter_setSessionHint(Interpreter *interpreter, int mode, int value); // RuntimeInfo *Interpreter_createRuntime(const ScheduleConfig *configs, // size_t configSize); @@ -153,9 +153,10 @@ Session *Interpreter_createSession(Interpreter *interpreter, // Session *Interpreter_createSessionWithRuntime(Interpreter *interpreter, // const ScheduleConfig *config, // const RuntimeInfo *runtime); -Session *Interpreter_createMultiPathSession(Interpreter *interpreter, - const MNNScheduleConfig *configs, - size_t configSize); +Session * +Interpreter_createMultiPathSession(Interpreter *interpreter, + const MNNScheduleConfig *const *configs, + size_t configSize); // Session *Interpreter_createMultiPathSessionWithRuntime( // Interpreter *interpreter, const ScheduleConfig *configs, size_t // configSize, const RuntimeInfo *runtime); @@ -186,11 +187,13 @@ Tensor *Interpreter_getSessionOutput(Interpreter *interpreter, const Session *session, const char *name); int Interpreter_getSessionInfo(Interpreter *interpreter, const Session *session, int code, void *ptr); -TensorInfoArray const * Interpreter_getSessionOutputAll(const Interpreter *interpreter, - const Session *session); +TensorInfoArray const * +Interpreter_getSessionOutputAll(const Interpreter *interpreter, + const Session *session); -TensorInfoArray const * Interpreter_getSessionInputAll(const Interpreter *interpreter, - const Session *session); +TensorInfoArray const * +Interpreter_getSessionInputAll(const Interpreter *interpreter, + const Session *session); void Interpreter_resizeTensor(Interpreter *interpreter, Tensor *tensor, const int *dims, size_t dimsSize); void Interpreter_resizeTensorByNCHW(Interpreter *interpreter, Tensor *tensor, diff --git a/src/interpreter.rs b/src/interpreter.rs index 3240191..b1c555d 100644 --- a/src/interpreter.rs +++ b/src/interpreter.rs @@ -1,4 +1,6 @@ -use crate::{prelude::*, AsTensorShape, Device, Ref, RefMut, Tensor, TensorType}; +use std::path::Path; + +use crate::{prelude::*, AsTensorShape, Device, Ref, RefMut, ScheduleConfig, Tensor, TensorType}; use mnn_sys::HalideType; #[derive(Debug, Copy, Clone)] @@ -59,7 +61,7 @@ pub struct Interpreter { unsafe impl Send for Interpreter {} impl Interpreter { - pub fn from_file(path: impl AsRef) -> Result { + pub fn from_file(path: impl AsRef) -> Result { let path = path.as_ref(); ensure!(path.exists(), ErrorKind::IOError; path.to_string_lossy().to_string()); let path = path.to_str().ok_or_else(|| error!(ErrorKind::AsciiError))?; @@ -134,13 +136,30 @@ impl Interpreter { assert!(!session.is_null()); Ok(crate::session::Session { inner: session, - __schedule_config: schedule, + __session_internals: crate::SessionInternals::Single(schedule), + __marker: PhantomData, + }) + }) + } + + pub fn create_multipath_session( + &mut self, + schedule: impl IntoIterator, + ) -> Result { + profile!("Creating multipath session"; { + let schedules: crate::ScheduleConfigs = schedule.into_iter().collect(); + let sc: &[_] = schedules.inner.as_ref(); + let session = unsafe { mnn_sys::Interpreter_createMultiPathSession(self.inner, sc.as_ptr(), sc.len()) }; + assert!(!session.is_null()); + Ok(crate::session::Session { + inner: session, + __session_internals: crate::SessionInternals::MultiSession(schedules), __marker: PhantomData, }) }) } - pub fn model_print_io(path: impl AsRef) -> Result<()> { + pub fn model_print_io(path: impl AsRef) -> Result<()> { let path = path.as_ref(); crate::ensure!(path.exists(), ErrorKind::IOError); let path = path.to_str().ok_or_else(|| error!(ErrorKind::AsciiError))?; @@ -166,8 +185,8 @@ impl Interpreter { }; ensure!(!input.is_null(), ErrorKind::TensorError; format!("Input tensor \"{name}\" not found")); let tensor = unsafe { Tensor::from_ptr(input) }; - let shape = tensor.shape(); - ensure!(!shape.as_ref().contains(&-1), ErrorKind::DynamicTensorError); + // let shape = tensor.shape(); + // ensure!(!shape.as_ref().contains(&-1), ErrorKind::DynamicTensorError); ensure!( tensor.is_type_of::(), ErrorKind::HalideTypeMismatch { @@ -203,12 +222,12 @@ impl Interpreter { pub fn run_session(&self, session: &crate::session::Session) -> Result<()> { profile!("Running session"; { - let ret = unsafe { mnn_sys::Interpreter_runSession(self.inner, session.inner) }; - ensure!( - ret == mnn_sys::ErrorCode::ERROR_CODE_NO_ERROR, - ErrorKind::InternalError(ret) - ); - Ok(()) + let ret = unsafe { mnn_sys::Interpreter_runSession(self.inner, session.inner) }; + ensure!( + ret == mnn_sys::ErrorCode::ERROR_CODE_NO_ERROR, + ErrorKind::InternalError(ret) + ); + Ok(()) }) } @@ -217,6 +236,20 @@ impl Interpreter { unsafe { mnn_sys::Interpreter_getSessionOutputAll(self.inner, session.inner) }; TensorList::from_ptr(outputs) } + + pub fn set_cache_file(&mut self, path: impl AsRef, key_size: usize) -> Result<()> { + let path = path.as_ref(); + let path = path.to_str().ok_or_else(|| error!(ErrorKind::AsciiError))?; + let c_path = std::ffi::CString::new(path).change_context(ErrorKind::AsciiError)?; + unsafe { mnn_sys::Interpreter_setCacheFile(self.inner, c_path.as_ptr(), key_size) } + Ok(()) + } + pub fn update_cache_file(&mut self, session: &mut crate::session::Session) -> Result<()> { + MNNError::from_error_code(unsafe { + mnn_sys::Interpreter_updateCacheFile(self.inner, session.inner) + }); + Ok(()) + } } #[repr(transparent)] diff --git a/src/schedule.rs b/src/schedule.rs index 54b45b2..e7693b7 100644 --- a/src/schedule.rs +++ b/src/schedule.rs @@ -1,5 +1,5 @@ use mnn_sys::*; -use std::ffi::CString; +use std::{ffi::CString, mem::ManuallyDrop}; use crate::{prelude::*, BackendConfig}; @@ -177,3 +177,59 @@ impl ScheduleConfig { } } } + +pub struct ScheduleConfigs { + pub(crate) inner: Vec<*const MNNScheduleConfig>, + pub(crate) backend_configs: Vec>, +} + +impl Drop for ScheduleConfigs { + fn drop(&mut self) { + unsafe { + for i in self.inner.iter() { + mnnsc_destroy(*i.cast()); + } + } + } +} + +impl ScheduleConfigs { + pub fn push(&mut self, config: ScheduleConfig) { + let mut config = ManuallyDrop::new(config); + self.inner.push(config.inner); + self.backend_configs.push(config.backend_config.take()); + } + + pub fn with_capacity(capacity: usize) -> Self { + Self { + inner: Vec::with_capacity(capacity), + backend_configs: Vec::with_capacity(capacity), + } + } + + pub const fn new() -> Self { + Self { + inner: Vec::new(), + backend_configs: Vec::new(), + } + } +} + +impl Default for ScheduleConfigs { + fn default() -> Self { + Self::new() + } +} + +impl FromIterator for ScheduleConfigs { + fn from_iter>(iter: T) -> Self { + let iter = iter.into_iter(); + let mut ret = Self::with_capacity(iter.size_hint().1.unwrap_or_default()); + iter.for_each(|item| { + ret.push(item); + }); + ret + } +} + +unsafe impl Send for ScheduleConfigs {} diff --git a/src/session.rs b/src/session.rs index 931350c..885f6d2 100644 --- a/src/session.rs +++ b/src/session.rs @@ -2,10 +2,15 @@ use crate::prelude::*; pub struct Session { pub(crate) inner: *mut mnn_sys::Session, - pub(crate) __schedule_config: crate::ScheduleConfig, + pub(crate) __session_internals: crate::SessionInternals, pub(crate) __marker: PhantomData<()>, } +pub enum SessionInternals { + Single(crate::ScheduleConfig), + MultiSession(crate::ScheduleConfigs), +} + impl Session { // pub unsafe fn from_ptr(session: *mut mnn_sys::Session) -> Self { // Self { diff --git a/src/tensor.rs b/src/tensor.rs index c4a3580..637c8ae 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -250,6 +250,10 @@ where } } + pub fn is_dynamic(&self) -> bool { + self.shape().iter().any(|&x| x == -1) + } + /// DO not use this function directly /// # Safety /// This is just provided as a 1:1 compat mostly for possible later use @@ -654,3 +658,26 @@ pub fn test_tensor_borrow_mut() { tensor.host_mut().fill(1); assert_eq!(data, &[1, 1, 1, 1, 1, 1]); } + +pub struct Dyn { + __marker: PhantomData, +} +impl seal::Sealed for Dyn {} + +impl super::TensorType for Dyn { + type H = T::H; + fn host() -> bool { + T::host() + } + fn device() -> bool { + T::device() + } + fn owned() -> bool { + T::owned() + } + fn borrowed() -> bool { + T::borrowed() + } +} + +// impl super::Tensor> {} diff --git a/tests/basic.rs b/tests/basic.rs index b2c8909..eea2df0 100644 --- a/tests/basic.rs +++ b/tests/basic.rs @@ -58,3 +58,52 @@ fn test_basic_coreml() { fn test_basic_opengl() { test_basic(ForwardType::OpenGL).unwrap(); } + +#[cfg(test)] +pub fn test_multipath_session(backend: ForwardType, backend2: ForwardType) -> Result<()> { + use mnn::BackendConfig; + + let mut net = mnn::Interpreter::from_bytes(Model::new())?; + let mut config = ScheduleConfig::new(); + config.set_type(backend); + config.set_backup_type(backend); + let mut bc = BackendConfig::new(); + bc.set_memory_mode(mnn::MemoryMode::High); + bc.set_precision_mode(mnn::PrecisionMode::High); + bc.set_power_mode(mnn::PowerMode::High); + let mut config2 = ScheduleConfig::new(); + config2.set_type(backend2); + config2.set_backup_type(backend2); + let mut bc = BackendConfig::new(); + bc.set_memory_mode(mnn::MemoryMode::High); + bc.set_precision_mode(mnn::PrecisionMode::High); + bc.set_power_mode(mnn::PowerMode::High); + config2.set_backend_config(bc); + + let session = net.create_multipath_session([config, config2])?; + let inputs = net.inputs(&session); + for input in inputs.iter() { + println!("input: {:?}", input); + input.tensor::()?.fill(0.0); + } + net.run_session(&session)?; + let outputs = net.outputs(&session); + for output in outputs.iter() { + println!("output: {:?}", output); + let tensor = output.tensor::()?; + let shape = tensor.shape(); + assert_eq!(shape.as_ref(), [1, 3, 2048, 2048]); + } + Ok(()) +} + +#[test] +fn test_multi_path_cpu_cpu() { + test_multipath_session(ForwardType::CPU, ForwardType::CPU).unwrap(); +} + +#[cfg(feature = "opencl")] +#[test] +fn test_multi_path_opencl_cpu() { + test_multipath_session(ForwardType::OpenCL, ForwardType::CPU).unwrap(); +}