feat: Added multi-path session and caching of models
uttarayan21 committed Sep 10, 2024
1 parent 803bd64 commit 58f6e52
Showing 8 changed files with 271 additions and 64 deletions.
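
The commit is exercised by the reworked examples/inspect.rs below: an interpreter can now persist backend cache data to a file and flush it after session creation, and a session can be scheduled over several configs at once. A minimal sketch of the cache flow, assuming the crate is used as `mnn`, with a hypothetical model path (the key size of 128 mirrors the example):

use mnn::{Interpreter, ScheduleConfig};

fn main() -> anyhow::Result<()> {
    // Hypothetical model path; the example derives the cache path from the model path.
    let model = std::path::PathBuf::from("model.mnn");
    let mut interpreter = Interpreter::from_file(&model)?;
    // Ask MNN to read/write its cache data next to the model file.
    interpreter.set_cache_file(model.with_extension("cache"), 128)?;

    let config = ScheduleConfig::new();
    let mut session = interpreter.create_session(config)?;
    // Write any cache entries generated while creating the session back to disk.
    interpreter.update_cache_file(&mut session)?;

    interpreter.run_session(&session)?;
    Ok(())
}
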
70 changes: 43 additions & 27 deletions examples/inspect.rs
@@ -10,6 +10,10 @@ pub struct Cli {
power: PowerMode,
#[clap(short, long, default_value = "high")]
precision: PrecisionMode,
#[clap(short, long, default_value = "high")]
memory: MemoryMode,
#[clap(short, long, default_value = "1")]
loops: usize,
}

macro_rules! time {
@@ -30,40 +34,52 @@ macro_rules! time {
pub fn main() -> anyhow::Result<()> {
use clap::Parser;
let cli = Cli::parse();
let mut interpreter = Interpreter::from_file(cli.model)?;
let mut interpreter = Interpreter::from_file(&cli.model)?;
interpreter.set_cache_file(cli.model.with_extension("cache"), 128)?;

let mut config = ScheduleConfig::new();
config.set_type(cli.forward);
let mut backend_config = BackendConfig::new();
backend_config.set_precision_mode(PrecisionMode::High);
backend_config.set_power_mode(PowerMode::High);

config.set_backend_config(backend_config);
// let mut backend_config = BackendConfig::new();
// backend_config.set_precision_mode(PrecisionMode::High);
// backend_config.set_power_mode(PowerMode::High);
// config.set_backend_config(backend_config);
// let handle = mnn::sync::SessionHandle::new(interpreter, config)?;
let session = time!(interpreter.create_session(config)?; "create session");
let mut session = time!(interpreter.create_session(config)?; "create session");
interpreter.update_cache_file(&mut session)?;
let mut input = interpreter.input::<f32>(&session, "image")?;
let mut shape = input.shape();
shape[0] = 512;
shape[1] = 512;
shape[2] = 3;
interpreter.resize_tensor(&mut input, shape);
drop(input);
interpreter.resize_session(&mut session);
// let session = time!(interpreter.create_session(config)?; "create session");
// handle.run(|sr| {
// let interpreter = sr.interpreter();
// let session = sr.session();
let inputs = interpreter.inputs(&session);
inputs.iter().for_each(|x| {
let mut tensor = x.tensor::<f32>().expect("No tensor");
println!("{}: {:?}", x.name(), tensor.shape());
tensor.fill(1.0f32);
// let mut cpu_tensor = tensor.create_host_tensor_from_device(false);
// cpu_tensor.host_mut().fill(1.0f32);
// tensor
// .copy_from_host_tensor(&cpu_tensor)
// .expect("Could not copy tensor");
});
time!(interpreter.run_session(&session)?;"run session");
let outputs = interpreter.outputs(&session);
outputs.iter().for_each(|x| {
let tensor = x.tensor::<f32>().expect("No tensor");
time!(tensor.wait(ffi::MapType::MAP_TENSOR_READ, true); format!("Waiting for tensor: {}", x.name()));
println!("{}: {:?}", x.name(), tensor.shape());
let _ = tensor.create_host_tensor_from_device(true);
// std::fs::write(format!("{}.bin", x.name()), bytemuck::cast_slice(cpu_tensor.host())).expect("Unable to write");
});
let mut current = 0;
time!(loop {
let inputs = interpreter.inputs(&session);
inputs.iter().for_each(|x| {
let mut tensor = x.tensor::<f32>().expect("No tensor");
println!("{}: {:?}", x.name(), tensor.shape());
tensor.fill(1.0f32);
});
time!(interpreter.run_session(&session)?;"run session");
let outputs = interpreter.outputs(&session);
outputs.iter().for_each(|x| {
let tensor = x.tensor::<u8>().expect("No tensor");
time!(tensor.wait(ffi::MapType::MAP_TENSOR_READ, true); format!("Waiting for tensor: {}", x.name()));
println!("{}: {:?}", x.name(), tensor.shape());
let _ = tensor.create_host_tensor_from_device(true);
// std::fs::write(format!("{}.bin", x.name()), bytemuck::cast_slice(cpu_tensor.host())).expect("Unable to write");
});
current += 1;
if current >= cli.loops {
break;
}
}; "run loop");
// Ok(())
// })?;
Ok(())
46 changes: 32 additions & 14 deletions mnn-sys/mnn_c/interpreter_c.cpp
@@ -53,11 +53,10 @@ void Interpreter_setExternalFile(Interpreter *interpreter, const char *file,
mnn_interpreter->setExternalFile(file, flag);
}
ErrorCode Interpreter_updateCacheFile(Interpreter *interpreter,
Session *session, int flag) {
Session *session) {
auto mnn_interpreter = reinterpret_cast<MNN::Interpreter *>(interpreter);
auto mnn_session = reinterpret_cast<MNN::Session *>(session);
return static_cast<ErrorCode>(
mnn_interpreter->updateCacheFile(mnn_session, flag));
return static_cast<ErrorCode>(mnn_interpreter->updateCacheFile(mnn_session));
}
void Interpreter_setSessionHint(Interpreter *interpreter, int mode, int value) {
auto mnn_interpreter = reinterpret_cast<MNN::Interpreter *>(interpreter);
@@ -111,17 +110,36 @@ Session *Interpreter_createSession(Interpreter *interpreter,
// cppConfig.backendConfig = config->backendConfig;
// return interpreter->createSession(cppConfig, *runtime);
// }
Session *Interpreter_createMultiPathSession(Interpreter *interpreter,
const MNNScheduleConfig *configs,
size_t configSize) {

auto mnn_configs = reinterpret_cast<const MNN::ScheduleConfig *>(configs);
// @todo: check if this is correct
std::vector<MNN::ScheduleConfig> cppConfigs(mnn_configs,
mnn_configs + configSize);
// Session *Interpreter_createMultiPathSession(Interpreter *interpreter,
// const MNNScheduleConfig *configs,
// size_t configSize) {
//
// auto mnn_configs = reinterpret_cast<const MNN::ScheduleConfig *>(configs);
// std::vector<MNN::ScheduleConfig> cppConfigs(mnn_configs,
// mnn_configs + configSize);
// auto mnn_interpreter = reinterpret_cast<MNN::Interpreter *>(interpreter);
// return reinterpret_cast<Session *>(
// mnn_interpreter->createMultiPathSession(cppConfigs));
// }
Session *
Interpreter_createMultiPathSession(Interpreter *interpreter,
const MNNScheduleConfig *const *configs,
size_t configSize) {
auto mnn_configs =
reinterpret_cast<const MNN::ScheduleConfig *const *>(configs);
std::vector<MNN::ScheduleConfig> s_configs;
for (size_t i = 0; i < configSize; ++i) {
s_configs.push_back(*mnn_configs[i]);
}
// std::vector<MNN::ScheduleConfig *> cppConfigs(mnn_configs,
// mnn_configs + configSize);
// Create a std::vector<MMN::ScheduleConfig> from
// std::vector<MNN::ScheduleConfig *>
// auto s_configs =
// std::vector<MNN::ScheduleConfig>(cppConfigs.begin(), cppConfigs.end());
auto mnn_interpreter = reinterpret_cast<MNN::Interpreter *>(interpreter);
return reinterpret_cast<Session *>(
mnn_interpreter->createMultiPathSession(cppConfigs));
MNN::Session *session = mnn_interpreter->createMultiPathSession(s_configs);
return reinterpret_cast<Session *>(session);
}

// Session* Interpreter_createMultiPathSessionWithRuntime(Interpreter*
@@ -299,6 +317,6 @@ const char *Interpreter_uuid(const Interpreter *interpreter) {
}
void Session_destroy(Session *session) {
auto mnn_session = reinterpret_cast<MNN::Session *>(session);
delete mnn_session;
// delete mnn_session;
}
} // extern "C"
21 changes: 12 additions & 9 deletions mnn-sys/mnn_c/interpreter_c.h
@@ -1,8 +1,8 @@
#ifndef INTERPRETER_C_H
#define INTERPRETER_C_H
#include "backend_c.h"
#include "schedule_c.h"
#include "error_code_c.h"
#include "schedule_c.h"
#include "tensor_c.h"
#include "utils.h"
#include <MNN/HalideRuntime.h>
@@ -144,7 +144,7 @@ void Interpreter_setCacheFile(Interpreter *interpreter, const char *cacheFile,
void Interpreter_setExternalFile(Interpreter *interpreter, const char *file,
size_t flag);
ErrorCode Interpreter_updateCacheFile(Interpreter *interpreter,
Session *session, int flag);
Session *session);
void Interpreter_setSessionHint(Interpreter *interpreter, int mode, int value);
// RuntimeInfo *Interpreter_createRuntime(const ScheduleConfig *configs,
// size_t configSize);
@@ -153,9 +153,10 @@ Session *Interpreter_createSession(Interpreter *interpreter,
// Session *Interpreter_createSessionWithRuntime(Interpreter *interpreter,
// const ScheduleConfig *config,
// const RuntimeInfo *runtime);
Session *Interpreter_createMultiPathSession(Interpreter *interpreter,
const MNNScheduleConfig *configs,
size_t configSize);
Session *
Interpreter_createMultiPathSession(Interpreter *interpreter,
const MNNScheduleConfig *const *configs,
size_t configSize);
// Session *Interpreter_createMultiPathSessionWithRuntime(
// Interpreter *interpreter, const ScheduleConfig *configs, size_t
// configSize, const RuntimeInfo *runtime);
@@ -186,11 +187,13 @@ Tensor *Interpreter_getSessionOutput(Interpreter *interpreter,
const Session *session, const char *name);
int Interpreter_getSessionInfo(Interpreter *interpreter, const Session *session,
int code, void *ptr);
TensorInfoArray const * Interpreter_getSessionOutputAll(const Interpreter *interpreter,
const Session *session);
TensorInfoArray const *
Interpreter_getSessionOutputAll(const Interpreter *interpreter,
const Session *session);

TensorInfoArray const * Interpreter_getSessionInputAll(const Interpreter *interpreter,
const Session *session);
TensorInfoArray const *
Interpreter_getSessionInputAll(const Interpreter *interpreter,
const Session *session);
void Interpreter_resizeTensor(Interpreter *interpreter, Tensor *tensor,
const int *dims, size_t dimsSize);
void Interpreter_resizeTensorByNCHW(Interpreter *interpreter, Tensor *tensor,
57 changes: 45 additions & 12 deletions src/interpreter.rs
@@ -1,4 +1,6 @@
use crate::{prelude::*, AsTensorShape, Device, Ref, RefMut, Tensor, TensorType};
use std::path::Path;

use crate::{prelude::*, AsTensorShape, Device, Ref, RefMut, ScheduleConfig, Tensor, TensorType};
use mnn_sys::HalideType;

#[derive(Debug, Copy, Clone)]
@@ -59,7 +61,7 @@ pub struct Interpreter {
unsafe impl Send for Interpreter {}

impl Interpreter {
pub fn from_file(path: impl AsRef<std::path::Path>) -> Result<Self> {
pub fn from_file(path: impl AsRef<Path>) -> Result<Self> {
let path = path.as_ref();
ensure!(path.exists(), ErrorKind::IOError; path.to_string_lossy().to_string());
let path = path.to_str().ok_or_else(|| error!(ErrorKind::AsciiError))?;
Expand Down Expand Up @@ -134,13 +136,30 @@ impl Interpreter {
assert!(!session.is_null());
Ok(crate::session::Session {
inner: session,
__schedule_config: schedule,
__session_internals: crate::SessionInternals::Single(schedule),
__marker: PhantomData,
})
})
}

pub fn create_multipath_session(
&mut self,
schedule: impl IntoIterator<Item = ScheduleConfig>,
) -> Result<crate::session::Session> {
profile!("Creating multipath session"; {
let schedules: crate::ScheduleConfigs = schedule.into_iter().collect();
let sc: &[_] = schedules.inner.as_ref();
let session = unsafe { mnn_sys::Interpreter_createMultiPathSession(self.inner, sc.as_ptr(), sc.len()) };
assert!(!session.is_null());
Ok(crate::session::Session {
inner: session,
__session_internals: crate::SessionInternals::MultiSession(schedules),
__marker: PhantomData,
})
})
}

pub fn model_print_io(path: impl AsRef<std::path::Path>) -> Result<()> {
pub fn model_print_io(path: impl AsRef<Path>) -> Result<()> {
let path = path.as_ref();
crate::ensure!(path.exists(), ErrorKind::IOError);
let path = path.to_str().ok_or_else(|| error!(ErrorKind::AsciiError))?;
@@ -166,8 +185,8 @@ impl Interpreter {
};
ensure!(!input.is_null(), ErrorKind::TensorError; format!("Input tensor \"{name}\" not found"));
let tensor = unsafe { Tensor::from_ptr(input) };
let shape = tensor.shape();
ensure!(!shape.as_ref().contains(&-1), ErrorKind::DynamicTensorError);
// let shape = tensor.shape();
// ensure!(!shape.as_ref().contains(&-1), ErrorKind::DynamicTensorError);
ensure!(
tensor.is_type_of::<H>(),
ErrorKind::HalideTypeMismatch {
@@ -203,12 +222,12 @@ impl Interpreter {

pub fn run_session(&self, session: &crate::session::Session) -> Result<()> {
profile!("Running session"; {
let ret = unsafe { mnn_sys::Interpreter_runSession(self.inner, session.inner) };
ensure!(
ret == mnn_sys::ErrorCode::ERROR_CODE_NO_ERROR,
ErrorKind::InternalError(ret)
);
Ok(())
let ret = unsafe { mnn_sys::Interpreter_runSession(self.inner, session.inner) };
ensure!(
ret == mnn_sys::ErrorCode::ERROR_CODE_NO_ERROR,
ErrorKind::InternalError(ret)
);
Ok(())
})
}

@@ -217,6 +236,20 @@ impl Interpreter {
unsafe { mnn_sys::Interpreter_getSessionOutputAll(self.inner, session.inner) };
TensorList::from_ptr(outputs)
}

pub fn set_cache_file(&mut self, path: impl AsRef<Path>, key_size: usize) -> Result<()> {
let path = path.as_ref();
let path = path.to_str().ok_or_else(|| error!(ErrorKind::AsciiError))?;
let c_path = std::ffi::CString::new(path).change_context(ErrorKind::AsciiError)?;
unsafe { mnn_sys::Interpreter_setCacheFile(self.inner, c_path.as_ptr(), key_size) }
Ok(())
}
pub fn update_cache_file(&mut self, session: &mut crate::session::Session) -> Result<()> {
MNNError::from_error_code(unsafe {
mnn_sys::Interpreter_updateCacheFile(self.inner, session.inner)
});
Ok(())
}
}

#[repr(transparent)]
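
For the new multi-path API, anything that iterates into ScheduleConfig works; the configs are collected into a ScheduleConfigs (see src/schedule.rs below) before being handed to the FFI. A hedged usage sketch, with crate paths and the model path assumed:

use mnn::{Interpreter, ScheduleConfig};

fn main() -> anyhow::Result<()> {
    let mut interpreter = Interpreter::from_file("model.mnn")?; // hypothetical path
    // Two default configs for illustration; each one could target a different backend.
    let configs = vec![ScheduleConfig::new(), ScheduleConfig::new()];
    let session = interpreter.create_multipath_session(configs)?;
    interpreter.run_session(&session)?;
    Ok(())
}
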
58 changes: 57 additions & 1 deletion src/schedule.rs
@@ -1,5 +1,5 @@
use mnn_sys::*;
use std::ffi::CString;
use std::{ffi::CString, mem::ManuallyDrop};

use crate::{prelude::*, BackendConfig};

@@ -177,3 +177,59 @@ impl ScheduleConfig {
}
}
}

pub struct ScheduleConfigs {
pub(crate) inner: Vec<*const MNNScheduleConfig>,
pub(crate) backend_configs: Vec<Option<BackendConfig>>,
}

impl Drop for ScheduleConfigs {
fn drop(&mut self) {
unsafe {
for i in self.inner.iter() {
mnnsc_destroy(*i.cast());
}
}
}
}

impl ScheduleConfigs {
pub fn push(&mut self, config: ScheduleConfig) {
let mut config = ManuallyDrop::new(config);
self.inner.push(config.inner);
self.backend_configs.push(config.backend_config.take());
}

pub fn with_capacity(capacity: usize) -> Self {
Self {
inner: Vec::with_capacity(capacity),
backend_configs: Vec::with_capacity(capacity),
}
}

pub const fn new() -> Self {
Self {
inner: Vec::new(),
backend_configs: Vec::new(),
}
}
}

impl Default for ScheduleConfigs {
fn default() -> Self {
Self::new()
}
}

impl FromIterator<ScheduleConfig> for ScheduleConfigs {
fn from_iter<T: IntoIterator<Item = ScheduleConfig>>(iter: T) -> Self {
let iter = iter.into_iter();
let mut ret = Self::with_capacity(iter.size_hint().1.unwrap_or_default());
iter.for_each(|item| {
ret.push(item);
});
ret
}
}

unsafe impl Send for ScheduleConfigs {}
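
ScheduleConfigs owns the raw config pointers (plus their backend configs) and frees them on drop; since it implements FromIterator<ScheduleConfig>, it can be built with collect() or by pushing into a pre-sized value. A small sketch, assuming both types are re-exported at the crate root as the diff's crate::ScheduleConfigs usage suggests:

use mnn::{ScheduleConfig, ScheduleConfigs};

fn build_configs() -> ScheduleConfigs {
    // collect() route, mirroring what create_multipath_session does internally.
    let _collected: ScheduleConfigs = (0..2).map(|_| ScheduleConfig::new()).collect();

    // push() route; ownership of each ScheduleConfig moves into the collection.
    let mut configs = ScheduleConfigs::with_capacity(2);
    configs.push(ScheduleConfig::new());
    configs.push(ScheduleConfig::new());
    configs
}
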
7 changes: 6 additions & 1 deletion src/session.rs
@@ -2,10 +2,15 @@ use crate::prelude::*;

pub struct Session {
pub(crate) inner: *mut mnn_sys::Session,
pub(crate) __schedule_config: crate::ScheduleConfig,
pub(crate) __session_internals: crate::SessionInternals,
pub(crate) __marker: PhantomData<()>,
}

pub enum SessionInternals {
Single(crate::ScheduleConfig),
MultiSession(crate::ScheduleConfigs),
}

impl Session {
// pub unsafe fn from_ptr(session: *mut mnn_sys::Session) -> Self {
// Self {