Skip to content

Commit

Permalink
feat(API): Don't need &mut anymore to run / create session
Browse files Browse the repository at this point in the history
  • Loading branch information
uttarayan21 committed Nov 13, 2024
1 parent 07f1dda commit 7f8698a
Show file tree
Hide file tree
Showing 11 changed files with 86 additions and 91 deletions.
6 changes: 6 additions & 0 deletions .github/workflows/build.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ jobs:
matrix: ${{fromJSON(needs.checks-matrix.outputs.matrix)}}
steps:
- uses: actions/checkout@v4
with:
lfs: true
submodules: 'recursive'
- uses: DeterminateSystems/nix-installer-action@main
- uses: DeterminateSystems/magic-nix-cache-action@main
- run: nix build -L '.#${{ matrix.attr }}'
Expand All @@ -44,6 +47,9 @@ jobs:

steps:
- uses: actions/checkout@v4
with:
lfs: true
submodules: 'recursive'
- uses: DeterminateSystems/nix-installer-action@main
- uses: DeterminateSystems/magic-nix-cache-action@main

Expand Down
2 changes: 0 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 10 additions & 6 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@ mnn = { version = "0.3.0", path = "." }
error-stack = { version = "0.5" }

[dependencies]
libc = "0.2"
mnn-sys = { version = "0.2", path = "mnn-sys", features = [] }
mnn-sys = { version = "0.2", path = "mnn-sys", default-features = false, features = [
"mnn-threadpool",
"sparse-compute",
] }
thiserror = "2"
error-stack.workspace = true
oneshot = "0.1"
tracing = { version = "0.1.40", optional = true }
dunce = "1.0.5"

Expand All @@ -32,13 +33,16 @@ opencl = ["mnn-sys/opencl"]
metal = ["mnn-sys/metal"]
coreml = ["mnn-sys/coreml"]

vulkan = [] # This is currently unimplemented

crt_static = ["mnn-sys/crt_static"]
# Disable mnn-threadpool to enable this
mnn-threadpool = ["mnn-sys/mnn-threadpool"]

tracing = ["dep:tracing"]
profile = ["tracing"]

default = ["mnn-threadpool"]
simd = ["mnn-sys/simd"]

default = ["simd"]


[dev-dependencies]
Expand Down
5 changes: 4 additions & 1 deletion mnn-sys/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,13 @@ sparse-compute = []
arm82 = []
bf16 = []
cpu-weight-dequant-gemm = []

# Disable if you don't plan to use cpu backend and want quicker compilation
sse = []
avx512 = []
neon = []
simd = ["sse", "avx512", "neon"]

low-memory = []

default = ["mnn-threadpool", "sparse-compute", "sse", "neon", "opencl"]
default = ["mnn-threadpool", "sparse-compute", "opencl", "simd"]
22 changes: 11 additions & 11 deletions mnn-sys/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,11 @@ static TARGET_FEATURES: LazyLock<Vec<String>> = LazyLock::new(|| {

static TARGET_OS: LazyLock<String> =
LazyLock::new(|| std::env::var("CARGO_CFG_TARGET_OS").expect("CARGO_CFG_TARGET_OS not set"));

static TARGET_ARCH: LazyLock<String> = LazyLock::new(|| {
std::env::var("CARGO_CFG_TARGET_ARCH").expect("CARGO_CFG_TARGET_ARCH not found")
});

static EMSCRIPTEN_CACHE: LazyLock<String> = LazyLock::new(|| {
let emscripten_cache = std::process::Command::new("em-config")
.arg("CACHE")
Expand Down Expand Up @@ -180,7 +182,6 @@ fn _main() -> Result<()> {
}
}
if is_emscripten() {
// println!("cargo:rustc-link-lib=static=stdc++");
let emscripten_cache = std::process::Command::new("em-config")
.arg("CACHE")
.output()
Expand Down Expand Up @@ -526,7 +527,7 @@ fn read_dir(input: impl AsRef<Path>) -> impl Iterator<Item = PathBuf> {
ignore::WalkBuilder::new(input)
.max_depth(Some(1))
.build()
.filter_map(Result::ok)
.flatten()
.map(|e| e.into_path())
}

Expand Down Expand Up @@ -559,12 +560,11 @@ pub fn build_cpp_build(vendor: impl AsRef<Path>) -> Result<()> {

let core_files_dir = vendor.join("source").join("core");
let core_files = ignore::Walk::new(&core_files_dir)
.filter_map(Result::ok)
.flatten()
.filter(|e| e.path().extension() == Some(OsStr::new("cpp")))
.map(|e| e.into_path());
build.files(core_files);

// #[cfg(feature = "cpu")]
{
let cpu_files_dir = vendor.join("source").join("backend").join("cpu");
let cpu_files = ignore::WalkBuilder::new(&cpu_files_dir)
Expand All @@ -573,7 +573,7 @@ pub fn build_cpp_build(vendor: impl AsRef<Path>) -> Result<()> {
.add_custom_ignore_filename("CPUImageProcess.hpp")
.add_custom_ignore_filename("CPUImageProcess.cpp")
.build()
.filter_map(Result::ok)
.flatten()
.filter(|e| e.path().extension() == Some(OsStr::new("cpp")))
.map(|e| e.into_path());

Expand All @@ -597,7 +597,7 @@ pub fn build_cpp_build(vendor: impl AsRef<Path>) -> Result<()> {
{
let cv_files_dir = vendor.join("source").join("cv");
let cv_files = ignore::Walk::new(&cv_files_dir)
.filter_map(Result::ok)
.flatten()
.filter(|e| e.path().extension() == Some(OsStr::new("cpp")))
.map(|e| e.into_path());
// build.include(cv_files_dir.join("schema").join("current"));
Expand All @@ -613,8 +613,8 @@ pub fn build_cpp_build(vendor: impl AsRef<Path>) -> Result<()> {
.add(vendor.join("source").join("geometry"))
.add(vendor.join("source").join("utils"))
.build()
.filter_map(Result::ok)
.filter(|e| e.path().extension() == Some(OsStr::new("cpp")))
.flatten()
.filter(|p| cpp_filter(p.path()))
.map(|e| e.into_path());
build.files(extra_files);
}
Expand All @@ -623,7 +623,7 @@ pub fn build_cpp_build(vendor: impl AsRef<Path>) -> Result<()> {
{
let opencl_files_dir = vendor.join("source").join("backend").join("opencl");
let opencl_files = ignore::Walk::new(&opencl_files_dir)
.filter_map(Result::ok)
.flatten()
.filter(|e| e.path().extension() == Some(OsStr::new("cpp")))
.map(|e| e.into_path());
let ocl_includes = opencl_files_dir.join("schema").join("current");
Expand Down Expand Up @@ -654,7 +654,7 @@ fn arm(build: &mut cc::Build, arm_dir: impl AsRef<Path>) -> Result<&mut cc::Buil
if *TARGET_POINTER_WIDTH == 64 {
let arm64_sources_dir = arm_source_dir.join("arm64");
let arm64_sources = ignore::Walk::new(&arm64_sources_dir)
.filter_map(Result::ok)
.flatten()
.filter(|e| {
e.path().extension() == Some(OsStr::new("S"))
|| e.path().extension() == Some(OsStr::new("s"))
Expand All @@ -672,7 +672,7 @@ fn arm(build: &mut cc::Build, arm_dir: impl AsRef<Path>) -> Result<&mut cc::Buil
} else if *TARGET_POINTER_WIDTH == 32 {
let arm32_sources_dir = arm_source_dir.join("arm32");
let arm32_sources = ignore::Walk::new(&arm32_sources_dir)
.filter_map(Result::ok)
.flatten()
.filter(|e| {
e.path().extension() == Some(OsStr::new("S"))
|| e.path().extension() == Some(OsStr::new("s"))
Expand Down
36 changes: 17 additions & 19 deletions src/interpreter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -241,9 +241,9 @@ impl Interpreter {
///
/// return: the created session
pub fn create_session(
&mut self,
&self,
schedule: crate::ScheduleConfig,
) -> Result<crate::session::Session> {
) -> Result<crate::session::Session<'_>> {
profile!("Creating session"; {
let session = unsafe { mnn_sys::Interpreter_createSession(self.inner, schedule.inner) };
assert!(!session.is_null());
Expand All @@ -270,7 +270,7 @@ impl Interpreter {
///
/// return: the created session
pub fn create_multipath_session(
&mut self,
&self,
schedule: impl IntoIterator<Item = ScheduleConfig>,
) -> Result<crate::session::Session> {
profile!("Creating multipath session"; {
Expand Down Expand Up @@ -436,7 +436,7 @@ impl Interpreter {
}

/// Run a session
pub fn run_session(&mut self, session: &crate::session::Session) -> Result<()> {
pub fn run_session(&self, session: &crate::session::Session) -> Result<()> {
profile!("Running session"; {
let ret = unsafe { mnn_sys::Interpreter_runSession(self.inner, session.inner) };
ensure!(
Expand All @@ -457,7 +457,7 @@ impl Interpreter {
///
/// `sync` : synchronously wait for finish of execution or not.
pub fn run_session_with_callback(
&mut self,
&self,
session: &crate::session::Session,
before: impl Fn(&[RawTensor], OperatorInfo) -> bool + 'static,
end: impl Fn(&[RawTensor], OperatorInfo) -> bool + 'static,
Expand Down Expand Up @@ -510,7 +510,7 @@ impl Interpreter {
}

/// Update cache file
pub fn update_cache_file(&mut self, session: &mut crate::session::Session) -> Result<()> {
pub fn update_cache_file(&self, session: &mut crate::session::Session) -> Result<()> {
MNNError::from_error_code(unsafe {
mnn_sys::Interpreter_updateCacheFile(self.inner, session.inner)
});
Expand Down Expand Up @@ -820,7 +820,6 @@ impl OperatorInfo<'_> {
}

#[test]
#[ignore = "This test doesn't work in CI"]
fn test_run_session_with_callback_info_api() {
let file = Path::new("tests/assets/realesr.mnn")
.canonicalize()
Expand All @@ -838,7 +837,6 @@ fn test_run_session_with_callback_info_api() {
}

#[test]
#[ignore = "This test doesn't work in CI"]
fn check_whether_sync_actually_works() {
let file = Path::new("tests/assets/realesr.mnn")
.canonicalize()
Expand Down Expand Up @@ -868,14 +866,14 @@ fn check_whether_sync_actually_works() {
assert!((time - time2) > std::time::Duration::from_millis(50));
}

#[test]
#[ignore = "Fails on CI"]
fn try_to_drop_interpreter_before_session() {
let file = Path::new("tests/assets/realesr.mnn")
.canonicalize()
.unwrap();
let mut interpreter = Interpreter::from_file(&file).unwrap();
let session = interpreter.create_session(ScheduleConfig::new()).unwrap();
drop(interpreter);
drop(session);
}
// Impossible to compile
// #[test]
// fn try_to_drop_interpreter_before_session() {
// let file = Path::new("tests/assets/realesr.mnn")
// .canonicalize()
// .unwrap();
// let mut interpreter = Interpreter::from_file(&file).unwrap();
// let session = interpreter.create_session(ScheduleConfig::new()).unwrap();
// drop(interpreter);
// drop(session);
// }
9 changes: 0 additions & 9 deletions src/schedule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ use crate::{prelude::*, BackendConfig};
/// - `CPU`: Use the CPU for computation.
/// - `Metal`: Use the Metal backend for computation (requires the `metal` feature).
/// - `OpenCL`: Use the OpenCL backend for computation (requires the `opencl` feature).
/// - `OpenGL`: Use the OpenGL backend for computation (requires the `opengl` feature).
/// - `Vulkan`: Use the Vulkan backend for computation (requires the `vulkan` feature).
/// - `CoreML`: Use the CoreML backend for computation (requires the `coreml` feature).
///
Expand All @@ -43,8 +42,6 @@ pub enum ForwardType {
Metal,
#[cfg(feature = "opencl")]
OpenCL,
#[cfg(feature = "opengl")]
OpenGL,
#[cfg(feature = "vulkan")]
Vulkan,
#[cfg(feature = "coreml")]
Expand All @@ -62,8 +59,6 @@ impl ForwardType {
ForwardType::Metal => MNNForwardType::MNN_FORWARD_METAL,
#[cfg(feature = "opencl")]
ForwardType::OpenCL => MNNForwardType::MNN_FORWARD_OPENCL,
#[cfg(feature = "opengl")]
ForwardType::OpenGL => MNNForwardType::MNN_FORWARD_OPENGL,
#[cfg(feature = "vulkan")]
ForwardType::Vulkan => MNNForwardType::MNN_FORWARD_VULKAN,
#[cfg(feature = "coreml")]
Expand All @@ -80,8 +75,6 @@ impl ForwardType {
"metal",
#[cfg(feature = "opencl")]
"opencl",
#[cfg(feature = "opengl")]
"opengl",
#[cfg(feature = "vulkan")]
"vulkan",
#[cfg(feature = "coreml")]
Expand All @@ -102,8 +95,6 @@ impl core::str::FromStr for ForwardType {
"metal" => Ok(ForwardType::Metal),
#[cfg(feature = "opencl")]
"opencl" => Ok(ForwardType::OpenCL),
#[cfg(feature = "opengl")]
"opengl" => Ok(ForwardType::OpenGL),
#[cfg(feature = "vulkan")]
"vulkan" => Ok(ForwardType::Vulkan),
#[cfg(feature = "coreml")]
Expand Down
8 changes: 4 additions & 4 deletions src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::prelude::*;
///
/// Inference unit. multiple sessions could share one net/interpreter.
#[derive(Debug)]
pub struct Session {
pub struct Session<'i> {
/// Pointer to the underlying MNN session.
pub(crate) inner: *mut mnn_sys::Session,
/// Pointer to the underlying MNN interpreter
Expand All @@ -17,7 +17,7 @@ pub struct Session {
/// Internal session configurations.
pub(crate) __session_internals: crate::SessionInternals,
/// Marker to ensure the struct is not Send or Sync.
pub(crate) __marker: PhantomData<()>,
pub(crate) __marker: PhantomData<&'i ()>,
}

/// Enum representing the internal configurations of a session.
Expand All @@ -29,7 +29,7 @@ pub enum SessionInternals {
MultiSession(crate::ScheduleConfigs),
}

impl Session {
impl Session<'_> {
// pub unsafe fn from_ptr(session: *mut mnn_sys::Session) -> Self {
// Self {
// session,
Expand All @@ -49,7 +49,7 @@ impl Session {
}
}

impl Drop for Session {
impl Drop for Session<'_> {
/// Custom drop implementation to ensure the underlying MNN session is properly destroyed.
fn drop(&mut self) {
self.destroy();
Expand Down
10 changes: 3 additions & 7 deletions tests/basic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@ use mnn::ForwardType;
fn test_basic_cpu() {
test_basic(ForwardType::CPU).unwrap();
}

#[cfg(feature = "metal")]
#[test]
#[ignore = "Doesn't work on ci"]
fn test_basic_metal() {
test_basic(ForwardType::Metal).unwrap();
}

#[cfg(feature = "opencl")]
#[test]
#[ignore = "Doesn't work on ci"]
fn test_basic_opencl() -> Result<(), Box<dyn std::error::Error>> {
let backend = ForwardType::OpenCL;
let realesr = std::path::Path::new("tests/assets/realesr.mnn");
Expand Down Expand Up @@ -46,16 +46,12 @@ fn test_basic_opencl() -> Result<(), Box<dyn std::error::Error>> {
// drop(net);
Ok(())
}

#[cfg(feature = "coreml")]
#[test]
fn test_basic_coreml() {
test_basic(ForwardType::CoreML).unwrap();
}
#[cfg(feature = "opengl")]
#[test]
fn test_basic_opengl() {
test_basic(ForwardType::OpenGL).unwrap();
}

#[test]
#[ignore = "takes too long and unreliable on CI"]
Expand Down
4 changes: 3 additions & 1 deletion tests/resizing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@ pub fn test_resizing() -> Result<()> {
let model = std::fs::read("tests/assets/resizing.mnn").expect("No resizing model");
let mut net = Interpreter::from_bytes(&model).unwrap();
net.set_cache_file("resizing.cache", 128)?;
let config = ScheduleConfig::default();
let mut config = ScheduleConfig::default();
#[cfg(feature = "opencl")]
config.set_type(ForwardType::OpenCL);
#[cfg(not(feature = "opencl"))]
config.set_type(ForwardType::CPU);
let mut session = net.create_session(config).unwrap();
net.update_cache_file(&mut session)?;

Expand Down
Loading

0 comments on commit 7f8698a

Please sign in to comment.