Skip to content

Commit

Permalink
fix(build): Possibly fix build issue on windows devices
Browse files Browse the repository at this point in the history
The halide_type_t_64.patch file is now applied manually in build.rs
instead of via the diffy crate

chore(tests): Added more tests for resizing, and test cases where MNN
segfaults when given wrong input
  • Loading branch information
uttarayan21 committed Oct 16, 2024
1 parent b54976f commit 5cf6509
Show file tree
Hide file tree
Showing 14 changed files with 578 additions and 63 deletions.
325 changes: 313 additions & 12 deletions Cargo.lock

Large diffs are not rendered by default.

6 changes: 6 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,9 @@ default = ["mnn-threadpool"]
anyhow = "1.0"
bytemuck = "1.17"
clap = { version = "4.5", features = ["derive"] }
divan = "0.1.14"
miette = { version = "7.2.0", features = ["fancy"] }

[[bench]]
name = "mnn-bench"
harness = false
33 changes: 33 additions & 0 deletions benches/mnn-bench.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
use divan::*;
/// Benchmarks running RealESR inference with the input filled with ones.
#[divan::bench_group(sample_size = 5, sample_count = 5)]
mod mnn_realesr_bench_with_ones {
    use divan::*;
    use mnn::*;

    /// One full forward pass per iteration on the CPU backend.
    #[divan::bench]
    pub fn mnn_benchmark_cpu(bencher: Bencher) {
        let mut interpreter = Interpreter::from_file("tests/assets/realesr.mnn").unwrap();
        let mut sched = ScheduleConfig::new();
        sched.set_type(ForwardType::CPU);
        let session = interpreter.create_session(sched).unwrap();
        bencher.bench_local(|| {
            let mut input = interpreter.input(&session, "data").unwrap();
            input.fill(1f32);
            interpreter.run_session(&session).unwrap();
        });
    }

    /// One full forward pass per iteration on the OpenCL backend.
    /// Calls `wait` after the run — presumably the GPU backend executes
    /// asynchronously and the wait is needed for a fair timing; confirm.
    #[cfg(feature = "opencl")]
    #[divan::bench]
    pub fn mnn_benchmark_opencl(bencher: Bencher) {
        let mut interpreter = Interpreter::from_file("tests/assets/realesr.mnn").unwrap();
        let mut sched = ScheduleConfig::new();
        sched.set_type(ForwardType::OpenCL);
        let session = interpreter.create_session(sched).unwrap();
        bencher.bench_local(|| {
            let mut input = interpreter.input(&session, "data").unwrap();
            input.fill(1f32);
            interpreter.run_session(&session).unwrap();
            interpreter.wait(&session);
        });
    }
}
8 changes: 8 additions & 0 deletions examples/inspect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,13 @@ pub fn main() -> anyhow::Result<()> {
interpreter.update_cache_file(&mut session)?;

let mut current = 0;
println!("--------------------------------Info--------------------------------");
let mem = interpreter.memory(&session)?;
let flops = interpreter.flops(&session)?;
println!("Memory: {:?}MiB", mem);
println!("Flops : {:?}M", flops);
println!("ResizeStatus : {:?}", interpreter.resize_status(&session)?);

time!(loop {
println!("--------------------------------Inputs--------------------------------");
interpreter.inputs(&session).iter().for_each(|x| {
Expand All @@ -75,6 +82,7 @@ pub fn main() -> anyhow::Result<()> {
},
};
});

println!("Running session");
interpreter.run_session(&session)?;
println!("--------------------------------Outputs--------------------------------");
Expand Down
6 changes: 3 additions & 3 deletions flake.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 1 addition & 5 deletions flake.nix
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,7 @@
craneLib = (crane.mkLib pkgs).overrideToolchain stableToolchain;
craneLibLLvmTools = (crane.mkLib pkgs).overrideToolchain stableToolchainWithLLvmTools;

mnnFilters = path: type: (craneLib.filterCargoSources path type) || (lib.hasSuffix ".patch" path || lib.hasSuffix ".mnn" path || lib.hasSuffix ".h" path || lib.hasSuffix ".cpp" path || lib.hasSuffix ".svg" path);
src = lib.cleanSourceWith {
filter = mnnFilters;
src = ./.;
};
src = lib.sources.sourceFilesBySuffices ./. [".rs" ".toml" ".patch" ".mnn" ".h" ".cpp" ".svg" "lock"];
MNN_SRC = mnn-src;
commonArgs =
{
Expand Down
30 changes: 26 additions & 4 deletions mnn-sys/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ use anyhow::*;
#[cfg(unix)]
use std::os::unix::fs::PermissionsExt;
use std::{
fs::Permissions,
path::{Path, PathBuf},
sync::LazyLock,
};
Expand All @@ -27,6 +26,13 @@ static EMSCRIPTEN_CACHE: LazyLock<String> = LazyLock::new(|| {
emscripten_cache
});

// Replacement strings implementing patches/halide_type_t_64.patch by plain
// text substitution (both are replaced with "" in HalideRuntime.h), instead
// of applying the patch with the diffy crate.
// NOTE(review): this matching is byte/whitespace-sensitive — if the vendored
// HalideRuntime.h is reformatted upstream, the `replace` calls silently
// become no-ops and the header is left unpatched. Consider asserting that
// the replacement actually changed the contents.
const HALIDE_PATCH_1: &str = r#"#if __cplusplus >= 201103L"#;
// Removing both the `#if` guard above and this `#else`…`#endif` block leaves
// the C++11 declaration of the field unconditionally in place.
const HALIDE_PATCH_2: &str = r#"
#else
HALIDE_ATTRIBUTE_ALIGN(1) uint8_t code; // halide_type_code_t
#endif
"#;

fn ensure_vendor_exists(vendor: impl AsRef<Path>) -> Result<()> {
if vendor
.as_ref()
Expand Down Expand Up @@ -65,9 +71,15 @@ fn main() -> Result<()> {
.context("Failed to copy vendor")?;
let intptr = vendor.join("include").join("MNN").join("HalideRuntime.h");
#[cfg(unix)]
std::fs::set_permissions(&intptr, Permissions::from_mode(0o644))?;
try_patch_file("patches/halide_type_t_64.patch", intptr)
.context("Failed to patch vendor")?;
std::fs::set_permissions(&intptr, std::fs::Permissions::from_mode(0o644))?;
// try_patch_file("patches/halide_type_t_64.patch", intptr)
// .context("Failed to patch vendor")?;

let intptr_contents = std::fs::read_to_string(&intptr)?;
let patched = intptr_contents
.replace(HALIDE_PATCH_1, "")
.replace(HALIDE_PATCH_2, "");
std::fs::write(intptr, patched)?;
}

let install_dir = out_dir.join("mnn-install");
Expand Down Expand Up @@ -258,6 +270,7 @@ pub fn build_cmake(path: impl AsRef<Path>, install: impl AsRef<Path>) -> Result<
config.define("MNN_COREML", CxxOption::COREML.cmake_value());
config.define("MNN_OPENCL", CxxOption::OPENCL.cmake_value());
config.define("MNN_OPENGL", CxxOption::OPENGL.cmake_value());
config.define("CMAKE_CXX_FLAGS", "-O0");
// #[cfg(windows)]
if *TARGET_OS == "windows" {
config.define("CMAKE_CXX_FLAGS", "-DWIN32=1");
Expand Down Expand Up @@ -444,3 +457,12 @@ impl CxxOption {
}
}
}

// mod cc_build {
// use super::*;
// pub fn build(source: impl AsRef<Path>) -> Result<PathBuf> {
// let mut builder = cc::Build::new();
// builder.std("c++11").cpp(true);
// todo!()
// }
// }
3 changes: 2 additions & 1 deletion mnn-sys/mnn_c/interpreter_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -275,8 +275,9 @@ int Interpreter_getSessionInfo(Interpreter *interpreter, const Session *session,
int code, void *ptr) {
auto mnn_interpreter = reinterpret_cast<MNN::Interpreter *>(interpreter);
auto mnn_session = reinterpret_cast<const MNN::Session *>(session);
return mnn_interpreter->getSessionInfo(
auto ret = mnn_interpreter->getSessionInfo(
mnn_session, static_cast<MNN::Interpreter::SessionInfoCode>(code), ptr);
return static_cast<int>(ret);
}
TensorInfoArray const *
Interpreter_getSessionOutputAll(const Interpreter *interpreter,
Expand Down
84 changes: 75 additions & 9 deletions src/interpreter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,10 @@ impl Interpreter {
unsafe { mnn_sys::Interpreter_resizeSession(self.inner, session.inner) }
}

/// Resize the session, additionally forcing its buffers to be reallocated.
///
/// Delegates to `Interpreter_resizeSessionWithFlag` with flag `1`.
/// NOTE(review): the flag semantics come from mnn-sys/MNN — confirm that
/// `1` indeed means "reallocate" and not some other resize mode.
pub fn resize_session_reallocate(&self, session: &mut crate::Session) {
    unsafe { mnn_sys::Interpreter_resizeSessionWithFlag(self.inner, session.inner, 1i32) }
}

pub fn resize_tensor<T: TensorType>(&self, tensor: &mut Tensor<T>, dims: impl AsTensorShape) {
let dims = dims.as_tensor_shape();
let dims_len = dims.size;
Expand Down Expand Up @@ -420,12 +424,74 @@ impl Interpreter {
Ok(())
}

// /// Wait for all output tensors to be ready after computation
// pub fn wait(&self, session: &crate::session::Session) {
// self.outputs(session).iter().for_each(|tinfo| {
// tinfo.raw_tensor().wait_read(true);
// });
// }
/// Wait until every output tensor of the session is ready to be read.
pub fn wait(&self, session: &crate::session::Session) {
    for info in self.outputs(session).iter() {
        // Block until this output tensor can be mapped for reading.
        info.raw_tensor().wait(mnn_sys::MapType::MAP_TENSOR_READ, true);
    }
}

/// Query the session's memory usage (printed as MiB by callers).
///
/// # Errors
/// Returns an `InterpreterError` when the underlying
/// `Interpreter_getSessionInfo` call reports failure.
pub fn memory(&self, session: &crate::session::Session) -> Result<f32> {
    let mut mem = 0.0f32;
    let code = unsafe {
        mnn_sys::Interpreter_getSessionInfo(
            self.inner,
            session.inner,
            0, // session-info code 0: memory usage
            (&mut mem as *mut f32).cast(),
        )
    };
    ensure!(
        code == 1,
        ErrorKind::InterpreterError;
        "Failed to get memory usage"
    );
    Ok(mem)
}

/// Query the session's estimated computation cost (printed as megaflops
/// by callers).
///
/// # Errors
/// Returns an `InterpreterError` when the underlying
/// `Interpreter_getSessionInfo` call reports failure.
pub fn flops(&self, session: &crate::Session) -> Result<f32> {
    let mut flops = 0.0f32;
    let code = unsafe {
        mnn_sys::Interpreter_getSessionInfo(
            self.inner,
            session.inner,
            1, // session-info code 1: flops
            (&mut flops as *mut f32).cast::<libc::c_void>(),
        )
    };
    ensure!(
        code == 1,
        ErrorKind::InterpreterError;
        "Failed to get flops"
    );
    Ok(flops)
}

/// Query whether the session needs a (re)allocation or resize before the
/// next run, decoded into a [`ResizeStatus`].
///
/// # Errors
/// Returns an `InterpreterError` when the underlying call fails or when
/// the backend reports an unknown status code.
pub fn resize_status(&self, session: &crate::Session) -> Result<ResizeStatus> {
    let mut raw_status = 0i32;
    let code = unsafe {
        mnn_sys::Interpreter_getSessionInfo(
            self.inner,
            session.inner,
            2, // session-info code 2: resize status
            (&mut raw_status as *mut i32).cast(),
        )
    };
    ensure!(
        code == 1,
        ErrorKind::InterpreterError;
        "Failed to get resize status"
    );
    match raw_status {
        0 => Ok(ResizeStatus::None),
        1 => Ok(ResizeStatus::NeedMalloc),
        2 => Ok(ResizeStatus::NeedResize),
        // Any other value is unexpected from the C API; surface an error
        // rather than guessing.
        _ => Err(error!(ErrorKind::InterpreterError)),
    }
}
}

/// Resize state of a session as reported by `Interpreter_getSessionInfo`
/// with info code `2` (see `resize_status`).
///
/// Discriminants mirror the raw integer codes returned by the C API.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[repr(C)]
pub enum ResizeStatus {
    /// No resize or reallocation is needed.
    None = 0,
    /// The session needs its buffers (re)allocated.
    NeedMalloc = 1,
    /// The session needs to be resized.
    NeedResize = 2,
}

#[repr(transparent)]
Expand All @@ -436,11 +502,11 @@ pub struct TensorInfo<'t, 'tl> {

impl core::fmt::Debug for TensorInfo<'_, '_> {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
// let tensor = self.raw_tensor();
// let shape = tensor.shape().clone();
let tensor = self.raw_tensor();
let shape = tensor.shape().clone();
f.debug_struct("TensorInfo")
.field("name", &self.name())
// .field("tensor", &shape)
.field("tensor", &shape)
.finish()
}
}
Expand Down
48 changes: 20 additions & 28 deletions src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -537,25 +537,10 @@ pub trait AsTensorShape {
impl<T: AsRef<[i32]>> AsTensorShape for T {
fn as_tensor_shape(&self) -> TensorShape {
let this = self.as_ref();
let len = this.len();
if len > 4 {
TensorShape {
shape: this[..4].try_into().expect("Impossible"),
size: 4,
}
} else {
TensorShape {
shape: this
.iter()
.chain(std::iter::repeat(&1))
.take(4)
.copied()
.collect::<Vec<i32>>()
.try_into()
.expect("Impossible"),
size: len,
}
}
let size = std::cmp::min(this.len(), 4);
let mut shape = [1; 4];
shape[..size].copy_from_slice(&this[..size]);
TensorShape { shape, size }
}
}

Expand Down Expand Up @@ -684,25 +669,33 @@ impl<T: super::TensorType> super::TensorType for Dyn<T> {
}
}

/// A raw tensor handle with no ownership or type guarantees.
///
/// NOTE(review): the original doc said this is "unconditionally dropped",
/// but the `Drop` impl is commented out below — the underlying tensor is
/// only freed via an explicit [`RawTensor::destroy`] call. Update callers
/// (and this doc) once the ownership story is settled.
#[repr(transparent)]
pub struct RawTensor<'r> {
    // Raw pointer to the underlying mnn_sys tensor.
    pub(crate) inner: *mut mnn_sys::Tensor,
    // Ties lifetime 'r to whatever produced this handle.
    pub(crate) __marker: PhantomData<&'r ()>,
}

impl<'r> core::ops::Drop for RawTensor<'r> {
fn drop(&mut self) {
unsafe {
mnn_sys::Tensor_destroy(self.inner);
}
}
}
// impl<'r> core::ops::Drop for RawTensor<'r> {
// fn drop(&mut self) {
// unsafe {
// mnn_sys::Tensor_destroy(self.inner);
// }
// }
// }

impl<'r> RawTensor<'r> {
/// The tensor's shape, converted from the raw mnn-sys representation.
pub fn shape(&self) -> TensorShape {
    let raw_shape = unsafe { mnn_sys::Tensor_shape(self.inner) };
    raw_shape.into()
}

/// Explicitly free the underlying tensor.
///
/// NOTE(review): since `RawTensor` has no active `Drop` impl (it is
/// commented out above), skipping this call leaks the tensor, while calling
/// it on a tensor that something else also frees would double-free —
/// confirm ownership at each call site.
pub fn destroy(self) {
    unsafe {
        mnn_sys::Tensor_destroy(self.inner);
    }
}

/// Number of dimensions of the tensor, as reported by mnn-sys
/// (cast from the C integer return type).
pub fn dimensions(&self) -> usize {
    unsafe { mnn_sys::Tensor_dimensions(self.inner) as usize }
}
Expand Down Expand Up @@ -735,8 +728,7 @@ impl<'r> RawTensor<'r> {
where
T::H: HalideType,
{
let this = core::mem::ManuallyDrop::new(self);
super::Tensor::from_ptr(this.inner)
super::Tensor::from_ptr(self.inner)
}

pub(crate) fn from_ptr(inner: *mut mnn_sys::Tensor) -> Self {
Expand Down
3 changes: 3 additions & 0 deletions tests/assets/resizing.mnn
Git LFS file not shown
Loading

0 comments on commit 5cf6509

Please sign in to comment.