diff --git a/.gitignore b/.gitignore index 859f7d1..139d03b 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,4 @@ target lama *.json *.cache +result diff --git a/.ignore b/.ignore new file mode 100644 index 0000000..d88084b --- /dev/null +++ b/.ignore @@ -0,0 +1 @@ +mnn-sys/vendor diff --git a/Cargo.lock b/Cargo.lock index 557ab06..5e4bbc2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -225,7 +225,7 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5d3041965b7a63e70447ec818a46b1e5297f7fcae3058356d226c02750c4e6cb" dependencies = [ - "nu-ansi-term", + "nu-ansi-term 0.50.1", ] [[package]] @@ -359,6 +359,12 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "lazy_static" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" + [[package]] name = "libc" version = "0.2.158" @@ -397,6 +403,15 @@ version = "0.4.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" +[[package]] +name = "matchers" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8263075bb86c5a1b1427b5ae862e8889656f126e9f77c484496e8b47cf5c5558" +dependencies = [ + "regex-automata 0.1.10", +] + [[package]] name = "matrixmultiply" version = "0.3.9" @@ -455,6 +470,7 @@ dependencies = [ "mnn", "oneshot", "tracing", + "tracing-test", ] [[package]] @@ -520,6 +536,16 @@ dependencies = [ "minimal-lexical", ] +[[package]] +name = "nu-ansi-term" +version = "0.46.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84" +dependencies = [ + "overload", + "winapi", +] + [[package]] name = "nu-ansi-term" version = "0.50.1" @@ -568,6 +594,12 @@ version = "0.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e296cf87e61c9cfc1a61c3c63a0f7f286ed4554e0e22be84e8a38e1d264a2a29" +[[package]] +name = "overload" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" + [[package]] name = "pin-project-lite" version = "0.2.14" @@ -631,8 +663,17 @@ checksum = "4219d74c6b67a3654a9fbebc4b419e22126d13d2f3c4a07ee0cb61ff79a79619" dependencies = [ "aho-corasick", "memchr", - "regex-automata", - "regex-syntax", + "regex-automata 0.4.7", + "regex-syntax 0.8.4", +] + +[[package]] +name = "regex-automata" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132" +dependencies = [ + "regex-syntax 0.6.29", ] [[package]] @@ -643,7 +684,7 @@ checksum = "38caf58cc5ef2fed281f89292ef23f6365465ed9a41b7a7754eb4e26496c92df" dependencies = [ "aho-corasick", "memchr", - "regex-syntax", + "regex-syntax 0.8.4", ] [[package]] @@ -652,6 +693,12 @@ version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "53a49587ad06b26609c52e423de037e7f57f20d53535d66e08c695f347df952a" +[[package]] +name = "regex-syntax" +version = "0.6.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" + [[package]] name = "regex-syntax" version = "0.8.4" @@ -698,12 +745,27 @@ version = "1.0.23" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "61697e0a1c7e512e84a621326239844a24d8207b4669b41bc18b32ea5cbf988b" +[[package]] +name = "sharded-slab" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6" +dependencies = [ + "lazy_static", +] + [[package]] name = "shlex" version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" +[[package]] +name = "smallvec" +version = "1.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" + [[package]] name = "spin" version = "0.9.8" @@ -766,6 +828,16 @@ dependencies = [ "syn", ] +[[package]] +name = "thread_local" +version = "1.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b9ef9bad013ada3808854ceac7b46812a6465ba368859a37e2100283d2d719c" +dependencies = [ + "cfg-if", + "once_cell", +] + [[package]] name = "tracing" version = "0.1.40" @@ -795,6 +867,57 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54" dependencies = [ "once_cell", + "valuable", +] + +[[package]] +name = "tracing-log" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3" +dependencies = [ + "log", + "once_cell", + "tracing-core", +] + +[[package]] +name = "tracing-subscriber" +version = "0.3.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ad0f048c97dbd9faa9b7df56362b8ebcaa52adb06b498c050d2f4e32f90a7a8b" +dependencies = [ + "matchers", + "nu-ansi-term 0.46.0", + "once_cell", + "regex", + "sharded-slab", + "smallvec", + "thread_local", + "tracing", + "tracing-core", + "tracing-log", +] + +[[package]] +name = "tracing-test" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "557b891436fe0d5e0e363427fc7f217abf9ccd510d5136549847bdcbcd011d68" +dependencies = [ + "tracing-core", + "tracing-subscriber", + "tracing-test-macro", +] + +[[package]] +name = "tracing-test-macro" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "04659ddb06c87d233c566112c1c9c5b9e98256d9af50ec3bc9c8327f873a7568" +dependencies = [ + "quote", + "syn", ] [[package]] @@ -815,6 +938,12 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" +[[package]] +name = "valuable" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d" + [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" diff --git a/Cargo.toml b/Cargo.toml index e13fffc..c8128ee 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -43,6 +43,7 @@ anyhow = "1.0" bytemuck = "1.17" clap = { version = "4.5", features = ["derive"] } divan = "0.1.14" +# mnn-sync = { path = "mnn-sync" } [[bench]] name = "mnn-bench" diff --git a/deny.toml b/deny.toml index ee453ec..1e7ddbc 100644 --- a/deny.toml +++ b/deny.toml @@ -88,7 +88,7 @@ ignore = [ # List of explicitly allowed licenses # See https://spdx.org/licenses/ for list of possible licenses # [possible values: any SPDX 3.11 short identifier (+ optional exception)]. -allow = ["MIT", "Apache-2.0", "BSD-3-Clause", "ISC", "Unicode-DFS-2016"] +allow = ["MIT", "Apache-2.0", "BSD-3-Clause", "ISC", "Unicode-DFS-2016", "Zlib"] # The confidence threshold for detecting a license from license text. # The higher the value, the more closely the license text must be to the # canonical license text of a valid SPDX license file. diff --git a/flake.lock b/flake.lock index 8b8f1c8..229a00a 100644 --- a/flake.lock +++ b/flake.lock @@ -3,11 +3,11 @@ "advisory-db": { "flake": false, "locked": { - "lastModified": 1725883717, - "narHash": "sha256-QifFNLfu5bzKPO4iznCj1h+nHhqGZ8NR2Lo7tzh9FRc=", + "lastModified": 1730464311, + "narHash": "sha256-9xJoP1766XJSO1Qr0Lxg2P6dwPncTr3BJYlFMSXBd/E=", "owner": "rustsec", "repo": "advisory-db", - "rev": "7fbf1e630ae52b7b364791a107b5bee5ff929496", + "rev": "f3460e5ed91658ab94fa41908cfa44991f9f4f02", "type": "github" }, "original": { @@ -18,11 +18,11 @@ }, "crane": { "locked": { - "lastModified": 1725409566, - "narHash": "sha256-PrtLmqhM6UtJP7v7IGyzjBFhbG4eOAHT6LPYOFmYfbk=", + "lastModified": 1730652660, + "narHash": "sha256-+XVYfmVXAiYA0FZT7ijHf555dxCe+AoAT5A6RU+6vSo=", "owner": "ipetkov", "repo": "crane", - "rev": "7e4586bad4e3f8f97a9271def747cf58c4b68f3c", + "rev": "a4ca93905455c07cb7e3aca95d4faf7601cba458", "type": "github" }, "original": { @@ -36,11 +36,11 @@ "systems": "systems" }, "locked": { - "lastModified": 1710146030, - "narHash": "sha256-SZ5L6eA7HJ/nmkzGG7/ISclqe6oZdOZTNoesiInkXPQ=", + "lastModified": 1726560853, + "narHash": "sha256-X6rJYSESBVr3hBoH0WbKE5KvhPU5bloyZ2L4K60/fPQ=", "owner": "numtide", "repo": "flake-utils", - "rev": "b1d9ab70662946ef0850d488da1c9019f3a9752a", + "rev": "c1dfcf08411b08f6b8615f7d8971a2bfa81d5e8a", "type": "github" }, "original": { @@ -130,11 +130,11 @@ ] }, "locked": { - "lastModified": 1720066371, - "narHash": "sha256-uPlLYH2S0ACj0IcgaK9Lsf4spmJoGejR9DotXiXSBZQ=", + "lastModified": 1729742964, + "narHash": "sha256-B4mzTcQ0FZHdpeWcpDYPERtyjJd/NIuaQ9+BV1h+MpA=", "owner": "nix-community", "repo": "nix-github-actions", - "rev": "622f829f5fe69310a866c8a6cd07e747c44ef820", + "rev": "e04df33f62cdcf93d73e9a04142464753a16db67", "type": "github" }, "original": { @@ -145,11 +145,11 @@ }, "nixpkgs": { "locked": { - "lastModified": 1729665710, - "narHash": "sha256-AlcmCXJZPIlO5dmFzV3V2XF6x/OpNWUV8Y/FMPGd8Z4=", + "lastModified": 1730531603, + "narHash": "sha256-Dqg6si5CqIzm87sp57j5nTaeBbWhHFaVyG7V6L8k3lY=", "owner": "nixos", "repo": "nixpkgs", - "rev": "2768c7d042a37de65bb1b5b3268fc987e534c49d", + "rev": "7ffd9ae656aec493492b44d0ddfb28e79a1ea25d", "type": "github" }, "original": { @@ -178,11 +178,11 @@ ] }, "locked": { - "lastModified": 1726021481, - "narHash": "sha256-4J4E+Fh+77XIYnq2RVtg+ENWXpu6t74P0jKN/f2RQmI=", + "lastModified": 1730773675, + "narHash": "sha256-pULo7GryzLkqGveWvnNWVz1Kk6EJqvq+HQeSkwvr7DA=", "owner": "oxalica", "repo": "rust-overlay", - "rev": "1c2c120246c51a644c20ba2a36a33d3bd4860d70", + "rev": "e19e9d54fac1e53f73411ebe22d19f946b1ba0bd", "type": "github" }, "original": { diff --git a/flake.nix b/flake.nix index 469e248..fcc3871 100644 --- a/flake.nix +++ b/flake.nix @@ -51,7 +51,7 @@ src = mnn-src; buildConverter = true; enableVulkan = false; - enableMetal = true; + # enableMetal = true; enableOpencl = true; }; }) @@ -83,15 +83,17 @@ nativeBuildInputs = with pkgs; [ cmake llvmPackages.libclang.lib + clang ]; buildInputs = with pkgs; [] ++ (lib.optionals pkgs.stdenv.isDarwin [ - darwin.apple_sdk.frameworks.OpenCL - darwin.apple_sdk.frameworks.OpenGL - darwin.apple_sdk.frameworks.CoreML - darwin.apple_sdk.frameworks.Metal - ]); + darwin.apple_sdk.frameworks.OpenCL + ] + ++ (lib.optionals pkgs.stdenv.isAarch64 [ + darwin.apple_sdk.frameworks.Metal + darwin.apple_sdk.frameworks.CoreML + ])); } // (lib.optionalAttrs pkgs.stdenv.isLinux { BINDGEN_EXTRA_CLANG_ARGS = "-I${pkgs.llvmPackages.libclang.lib}/lib/clang/18/include"; @@ -110,9 +112,6 @@ cargoDocExtraArgs = "-p mnn -p mnn-sys"; }); mnn-fmt = craneLib.cargoFmt {inherit src;}; - mnn-toml-fmt = craneLib.taploFmt { - src = pkgs.lib.sources.sourceFilesBySuffices src [".toml"]; - }; # Audit dependencies mnn-audit = craneLib.cargoAudit { inherit src advisory-db; @@ -140,33 +139,33 @@ partitionType = "count"; cargoExtraArgs = "-p mnn-sys"; }); - mnn-asan = let - rustPlatform = pkgs.makeRustPlatform { - cargo = nightlyToolchain; - rustc = nightlyToolchain; - }; - in - rustPlatform.buildRustPackage ( - commonArgs - // { - inherit src; - name = "mnn-leaks"; - cargoLock = { - lockFile = ./Cargo.lock; - outputHashes = { - "cmake-0.1.50" = "sha256-GM2D7dpb2i2S6qYVM4HYk5B40TwKCmGQnUPfXksyf0M="; - }; - }; - - buildPhase = '' - cargo test --target aarch64-apple-darwin - ''; - RUSTFLAGS = "-Zsanitizer=address"; - ASAN_OPTIONS = "detect_leaks=1"; - # MNN_COMPILE = "NO"; - # MNN_LIB_DIR = "${pkgs.mnn}/lib"; - } - ); + # mnn-asan = let + # rustPlatform = pkgs.makeRustPlatform { + # cargo = nightlyToolchain; + # rustc = nightlyToolchain; + # }; + # in + # rustPlatform.buildRustPackage ( + # commonArgs + # // { + # inherit src; + # name = "mnn-leaks"; + # cargoLock = { + # lockFile = ./Cargo.lock; + # outputHashes = { + # "cmake-0.1.50" = "sha256-GM2D7dpb2i2S6qYVM4HYk5B40TwKCmGQnUPfXksyf0M="; + # }; + # }; + # + # buildPhase = '' + # cargo test --target aarch64-apple-darwin + # ''; + # RUSTFLAGS = "-Zsanitizer=address"; + # ASAN_OPTIONS = "detect_leaks=1"; + # # MNN_COMPILE = "NO"; + # # MNN_LIB_DIR = "${pkgs.mnn}/lib"; + # } + # ); }; packages = rec { @@ -178,7 +177,9 @@ // { inherit cargoArtifacts; pname = "inspect"; - cargoExtraArgs = "--example inspect"; + cargoExtraArgs = + "--example inspect" + + (lib.optionalString pkgs.stdenv.isDarwin " --features opencl" + lib.optionalString pkgs.stdenv.isAarch64 ",metal,coreml"); }); default = mnn; } @@ -201,10 +202,12 @@ llvm ] ++ (lib.optionals pkgs.stdenv.isDarwin [ - darwin.apple_sdk.frameworks.OpenCL - darwin.apple_sdk.frameworks.CoreML - darwin.apple_sdk.frameworks.Metal - ]); + darwin.apple_sdk.frameworks.OpenCL + ] + ++ (lib.optionals pkgs.stdenv.isAarch64 [ + darwin.apple_sdk.frameworks.Metal + darwin.apple_sdk.frameworks.CoreML + ])); # RUSTFLAGS = "-Zsanitizer=address"; # ASAN_OPTIONS = "detect_leaks=1"; }; @@ -213,7 +216,7 @@ ) // { githubActions = nix-github-actions.lib.mkGithubMatrix { - checks = nixpkgs.lib.getAttrs ["x86_64-linux"] self.checks; + checks = nixpkgs.lib.getAttrs ["x86_64-linux" "aarch64-darwin"] self.checks; }; }; } diff --git a/mnn-sync/Cargo.toml b/mnn-sync/Cargo.toml index 75ad77b..17760db 100644 --- a/mnn-sync/Cargo.toml +++ b/mnn-sync/Cargo.toml @@ -16,3 +16,6 @@ tracing = { version = "0.1", optional = true } [features] tracing = ["dep:tracing", "mnn/tracing"] + +[dev-dependencies] +tracing-test = "0.2.5" diff --git a/mnn-sync/src/lib.rs b/mnn-sync/src/lib.rs index 7d244e8..2ee066b 100644 --- a/mnn-sync/src/lib.rs +++ b/mnn-sync/src/lib.rs @@ -41,27 +41,153 @@ use error_stack::{Report, ResultExt}; use mnn::*; type Callback = Box Result<()> + Send + 'static>; + pub enum CallbackEnum { Callback(Callback), + Unload(oneshot::Sender>), + Load(oneshot::Sender>), Close, } -// type CallbackSender = (CallbackEnum, oneshot::Sender>); type CallbackSender = CallbackEnum; #[derive(Debug)] pub struct SessionHandle { #[allow(dead_code)] - pub(crate) handle: std::thread::JoinHandle>, + pub(crate) handle: Option>>, pub(crate) sender: Sender, - pub(crate) loop_handle: Receiver>, } impl Drop for SessionHandle { fn drop(&mut self) { - self.sender - .send(CallbackEnum::Close) - .expect("Failed to close SessionHandle"); - // rx.recv().expect("Failed to close SessionHandle"); + self.close().expect("Failed to close session"); + self.handle + .take() + .map(|j| j.join().expect("Failed to join thread")); + } +} + +#[derive(Debug)] +pub struct SessionState { + sr: SessionRunnerState, + receiver: Receiver, + config: ScheduleConfig, +} + +#[derive(Debug, Default)] +pub enum SessionRunnerState { + Loaded(SessionRunner), + Unloaded(Interpreter), + #[default] + Poisoned, +} + +impl SessionRunnerState { + pub fn is_loaded(&self) -> bool { + matches!(self, SessionRunnerState::Loaded(_)) + } + + pub fn is_unloaded(&self) -> bool { + matches!(self, SessionRunnerState::Unloaded(_)) + } + + pub fn is_poisoned(&self) -> bool { + matches!(self, SessionRunnerState::Poisoned) + } + + pub fn loaded(&self) -> Option<&SessionRunner> { + match self { + Self::Loaded(sr) => Some(sr), + _ => None, + } + } + + pub fn unloaded(&self) -> Option<&Interpreter> { + match self { + Self::Unloaded(net) => Some(net), + _ => None, + } + } + + pub fn loaded_mut(&mut self) -> Option<&mut SessionRunner> { + match self { + Self::Loaded(sr) => Some(sr), + _ => None, + } + } + + pub fn unloaded_mut(&mut self) -> Option<&mut Interpreter> { + match self { + Self::Unloaded(net) => Some(net), + _ => None, + } + } + + pub fn unload(&mut self) -> Result<()> { + #[cfg(feature = "tracing")] + tracing::info!("Unloading session"); + match core::mem::take(self) { + Self::Loaded(sr) => { + let net = sr.unload()?; + *self = Self::Unloaded(net); + Ok(()) + } + Self::Unloaded(u) => { + *self = Self::Unloaded(u); + Ok(()) + } + Self::Poisoned => Self::poisoned(), + } + } + + pub fn load(&mut self, config: &ScheduleConfig) -> Result<()> { + #[cfg(feature = "tracing")] + tracing::info!("Loading session"); + match core::mem::take(self) { + Self::Loaded(sr) => { + *self = Self::Loaded(sr); + Ok(()) + } + Self::Unloaded(net) => { + let sr = SessionRunner::create(net, config.clone())?; + *self = Self::Loaded(sr); + Ok(()) + } + Self::Poisoned => Self::poisoned(), + } + } + + pub fn sr(&mut self, config: &ScheduleConfig) -> Result<&mut SessionRunner> { + match self { + Self::Loaded(sr) => Ok(sr), + Self::Unloaded(_) => { + self.load(config)?; + Ok(self.loaded_mut().ok_or_else(|| { + Report::new(ErrorKind::SyncError).attach_printable("Failed to load session") + })?) + } + Self::Poisoned => { + Err(Report::new(ErrorKind::SyncError).attach_printable("Poisoned Session"))? + } + } + } + + fn poisoned() -> Result<()> { + Err(Report::new(ErrorKind::SyncError).attach_printable("Poisoned Session"))?; + Ok(()) + } +} + +impl SessionState { + pub fn sr(&mut self) -> Result<&mut SessionRunner> { + self.sr.sr(&self.config) + } + + pub fn load(&mut self) -> Result<()> { + self.sr.load(&self.config) + } + + pub fn unload(&mut self) -> Result<()> { + self.sr.unload() } } @@ -72,77 +198,161 @@ pub struct SessionRunner { pub session: Session, } -impl SessionHandle { - pub fn new(mut interpreter: Interpreter, config: ScheduleConfig) -> Result { - let (sender, receiver) = flume::unbounded::(); +impl SessionRunner { + pub fn new(interpreter: Interpreter, session: Session) -> Self { + Self { + interpreter, + session, + } + } - let builder = std::thread::Builder::new().name("mnn-session-thread".to_string()); - let (tx, rx) = flume::unbounded(); - let handle = builder - .spawn(move || -> Result<()> { - #[cfg(feature = "tracing")] - tracing::trace!("Initializing mnn session thread"); - let mut session = interpreter.create_session(config)?; - #[cfg(feature = "tracing")] - tracing::trace!("Updating mnn cache file"); - interpreter.update_cache_file(&mut session)?; - let mut session_runner = SessionRunner { - interpreter, - session, + pub fn create(mut net: Interpreter, config: ScheduleConfig) -> Result { + #[cfg(feature = "tracing")] + tracing::trace!("Creating session"); + #[cfg(feature = "tracing")] + let now = std::time::Instant::now(); + let mut session = net.create_session(config)?; + net.update_cache_file(&mut session)?; + #[cfg(feature = "tracing")] + tracing::trace!("Session created in {:?}", now.elapsed()); + Ok(Self { + interpreter: net, + session, + }) + } + + pub fn unload(self) -> Result { + let session = self.session; + let net = self.interpreter; + drop(session); + Ok(net) + } + + pub fn run_session(&mut self) -> Result<()> { + self.interpreter.run_session(&self.session) + } + + pub fn both_mut(&mut self) -> (&mut Interpreter, &mut Session) { + (&mut self.interpreter, &mut self.session) + } + + pub fn resize_session(&mut self) -> Result<()> { + self.interpreter.resize_session(&mut self.session); + Ok(()) + } + + pub fn interpreter(&self) -> &Interpreter { + &self.interpreter + } + + pub fn interpreter_mut(&mut self) -> &mut Interpreter { + &mut self.interpreter + } + + pub fn session(&self) -> &Session { + &self.session + } + + pub fn session_mut(&mut self) -> &mut Session { + &mut self.session + } + + fn run_callback(&mut self, f: Callback) -> Result<()> { + #[cfg(feature = "tracing")] + tracing::trace!("Running callback"); + #[cfg(feature = "tracing")] + let now = std::time::Instant::now(); + let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| f(self))) + .unwrap_or_else(|e| { + let mut err = + Report::new(ErrorKind::SyncError).attach_printable(format!("{:?}", e)); + if let Some(location) = e.downcast_ref::() { + err = err.attach_printable(format!("{:?}", location)); }; + if let Some(backtrace) = e.downcast_ref::() { + err = err.attach_printable(format!("{:?}", backtrace)); + }; + let ret = Err(MNNError::from(err)); #[cfg(feature = "tracing")] - tracing::trace!("Initializing mnn session loop"); - loop { - let f = receiver - .recv() - .change_context(ErrorKind::SyncError) - .attach_printable("Internal Error: Unable to recv (Sender Dropped)")?; - let f = match f { - CallbackEnum::Callback(f) => f, - CallbackEnum::Close => break, - }; - let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { - f(&mut session_runner) - })) - .unwrap_or_else(|e| { - let mut err = - Report::new(ErrorKind::SyncError).attach_printable(format!("{:?}", e)); - if let Some(location) = e.downcast_ref::() { - err = err.attach_printable(format!("{:?}", location)); - }; - if let Some(backtrace) = e.downcast_ref::() { - err = err.attach_printable(format!("{:?}", backtrace)); - }; - let ret = Err(MNNError::from(err)); - #[cfg(feature = "tracing")] - tracing::error!("Panic in session thread: {:?}", ret); - ret - }); - tx.send(result) - .change_context(ErrorKind::SyncError) - .attach_printable( - "Internal Error: Failed to send result via oneshot channel", - )?; + tracing::error!("Panic in session thread: {:?}", ret); + ret + }); + #[cfg(feature = "tracing")] + tracing::trace!("Callback took: {:?}", now.elapsed()); + result + } +} + +impl SessionHandle { + pub fn new(interpreter: Interpreter, config: ScheduleConfig) -> Result { + let (sender, receiver) = flume::unbounded::(); + let builder = std::thread::Builder::new().name("mnn-session-thread".to_string()); + let spawner = move || -> Result<()> { + let mut ss = SessionState { + sr: SessionRunnerState::Unloaded(interpreter), + receiver, + config, + }; + + loop { + let cmd = ss + .receiver + .recv() + .change_context(ErrorKind::SyncError) + .attach_printable("Internal Error: Unable to recv (Sender possibly dropped without calling close)")?; + match cmd { + CallbackEnum::Callback(f) => { + let sr = ss.sr()?; + sr.run_callback(f) + .map_err(|e| e.into_inner()) + .attach_printable("Failure running the callback")?; + } + CallbackEnum::Unload(tx) => { + let res = ss.unload(); + tx.send(res) + .change_context(ErrorKind::SyncError) + .attach_printable("Internal Error: Failed to send unload message")?; + } + CallbackEnum::Load(tx) => { + let res = ss.load(); + tx.send(res) + .change_context(ErrorKind::SyncError) + .attach_printable("Internal Error: Failed to send load message")?; + } + CallbackEnum::Close => { + break; + } } - Ok(()) - }) + } + Ok(()) + }; + let handle = builder + .spawn(spawner) .change_context(ErrorKind::SyncError) - .attach_printable("Internal Error: Failed to create session thread")?; - // rx.recv() - // .change_context(ErrorKind::SyncError) - // .attach_printable("Internal Error: Unable to recv message")??; + .attach_printable("Internal Error: Failed to spawn thread")?; Ok(Self { - handle, + handle: Some(handle), sender, - loop_handle: rx, }) } + fn is_running(&self) -> bool { + self.handle.as_ref().is_some_and(|j| !j.is_finished()) + } + + fn ensure_running(&self) -> Result<()> { + if !self.is_running() { + Err(Report::new(ErrorKind::SyncError).attach_printable("Session thread is not running"))? + } + Ok(()) + } + pub fn run( &self, f: impl FnOnce(&mut SessionRunner) -> Result + Send + Sync + 'static, ) -> Result { + self.ensure_running()?; let f = f; let (tx, rx) = oneshot::channel(); let wrapped_f = move |sr: &mut SessionRunner| -> Result<()> { @@ -155,16 +365,16 @@ impl SessionHandle { self.sender .send(CallbackEnum::Callback(Box::new(wrapped_f))) .map_err(|e| Report::new(ErrorKind::SyncError).attach_printable(e.to_string()))?; - Ok(rx - .recv() + rx.recv() .change_context(ErrorKind::SyncError) - .attach_printable("Internal Error: Unable to recv message")??) + .attach_printable("Internal Error: Unable to recv message")? } pub async fn run_async( &self, f: impl FnOnce(&mut SessionRunner) -> Result + Send + Sync + 'static, ) -> Result { + self.ensure_running()?; let f = f; let (tx, rx) = oneshot::channel(); let wrapped_f = move |sr: &mut SessionRunner| -> Result<()> { @@ -183,37 +393,53 @@ impl SessionHandle { .attach_printable("Internal Error: Unable to recv message")?) } - pub fn panicked(&self) -> bool { - self.loop_handle - .try_recv() - .map(|p| p.is_err()) - .unwrap_or(false) - } -} - -impl SessionRunner { - pub fn run_session(&mut self) -> Result<()> { - self.interpreter.run_session(&self.session) + pub fn unload(&self) -> Result<()> { + let (tx, rx) = oneshot::channel(); + self.sender + .send(CallbackEnum::Unload(tx)) + .map_err(|e| Report::new(ErrorKind::SyncError).attach_printable(e.to_string()))?; + rx.recv() + .change_context(ErrorKind::SyncError) + .attach_printable("Internal Error: Failed to recv unload message")? } - pub fn resize_session(&mut self) -> Result<()> { - self.interpreter.resize_session(&mut self.session); - Ok(()) + pub async fn unload_async(&self) -> Result<()> { + let (tx, rx) = oneshot::channel(); + self.sender + .send(CallbackEnum::Unload(tx)) + .map_err(|e| Report::new(ErrorKind::SyncError).attach_printable(e.to_string()))?; + rx.await + .change_context(ErrorKind::SyncError) + .attach_printable("Internal Error: Failed to recv unload message")? } - pub fn interpreter(&self) -> &Interpreter { - &self.interpreter + pub fn load(&self) -> Result<()> { + self.ensure_running()?; + let (tx, rx) = oneshot::channel(); + self.sender + .send(CallbackEnum::Load(tx)) + .map_err(|e| Report::new(ErrorKind::SyncError).attach_printable(e.to_string()))?; + rx.recv() + .change_context(ErrorKind::SyncError) + .attach_printable("Internal Error: Failed to recv load message")? } - pub fn interpreter_mut(&mut self) -> &mut Interpreter { - &mut self.interpreter + pub async fn load_async(&self) -> Result<()> { + self.ensure_running()?; + let (tx, rx) = oneshot::channel(); + self.sender + .send(CallbackEnum::Load(tx)) + .map_err(|e| Report::new(ErrorKind::SyncError).attach_printable(e.to_string()))?; + rx.await + .change_context(ErrorKind::SyncError) + .attach_printable("Internal Error: Failed to recv load message")? } - pub fn session(&self) -> &Session { - &self.session - } - pub fn session_mut(&mut self) -> &mut Session { - &mut self.session + pub fn close(&self) -> Result<()> { + self.sender + .send(CallbackEnum::Close) + .map_err(|e| Report::new(ErrorKind::SyncError).attach_printable(e.to_string()))?; + Ok(()) } } @@ -257,8 +483,8 @@ pub fn test_sync_api() { #[test] #[ignore = "This test is not reliable on CI"] pub fn test_sync_api_race() { - let interpreter = - Interpreter::from_file("tests/assets/realesr.mnn").expect("Failed to create interpreter"); + let interpreter = Interpreter::from_file("../tests/assets/realesr.mnn") + .expect("Failed to create interpreter"); let session_handle = SessionHandle::new(interpreter, ScheduleConfig::new()) .expect("Failed to create session handle"); session_handle @@ -332,7 +558,7 @@ pub fn test_sync_api_race() { let cpu_output = output.create_host_tensor_from_device(true); Ok(cpu_output.host().to_vec()) }) - .expect("Sed"); + .expect("failed to copy output"); } #[test] @@ -340,3 +566,34 @@ pub fn test_sync_api_is_send_sync() { fn is_send_sync() {} is_send_sync::(); } + +#[test] +#[cfg_attr(feature = "tracing", tracing_test::traced_test)] +pub fn test_load_unload() { + let interpreter = Interpreter::from_file("../tests/assets/realesr.mnn") + .expect("Failed to create interpreter"); + let session_handle = SessionHandle::new(interpreter, ScheduleConfig::new()) + .expect("Failed to create session handle"); + session_handle + .run(|sr| { + for input in sr.interpreter.inputs(sr.session()).iter() { + input + .tensor::() + .expect("Failed to get tensor") + .fill(1.0f32); + } + Ok(()) + }) + .expect("Failed to run"); + session_handle.load().expect("Failed to load"); + session_handle.unload().expect("Failed to unload"); + session_handle.load().expect("Failed to load"); + session_handle.unload().expect("Failed to unload"); + session_handle + .run(|sr| { + sr.run_session()?; + Ok(()) + }) + .expect("Failed to run"); + session_handle.unload().expect("Failed to unload"); +} diff --git a/mnn-sys/build.rs b/mnn-sys/build.rs index 525318b..2a5e158 100644 --- a/mnn-sys/build.rs +++ b/mnn-sys/build.rs @@ -90,10 +90,6 @@ fn main() -> Result<()> { .find_position(|line| line.contains(HALIDE_SEARCH)) { // remove the last line and the next 3 lines - // patched.remove(idx - 1); - // patched.remove(idx); - // patched.remove(idx); - // patched.remove(idx); let patched = patched .into_iter() .enumerate() @@ -111,12 +107,10 @@ fn main() -> Result<()> { "cargo:rustc-link-search=native={}", install_dir.join("lib").display() ); + } else if let core::result::Result::Ok(lib_dir) = std::env::var("MNN_LIB_DIR") { + println!("cargo:rustc-link-search=native={}", lib_dir); } else { - if let Some(lib_dir) = std::env::var("MNN_LIB_DIR").ok() { - println!("cargo:rustc-link-search=native={}", lib_dir); - } else { - panic!("MNN_LIB_DIR not set while MNN_COMPILE is false"); - } + panic!("MNN_LIB_DIR not set while MNN_COMPILE is false"); } mnn_c_build(PathBuf::from(MANIFEST_DIR).join("mnn_c"), &vendor) @@ -230,7 +224,7 @@ pub fn mnn_cpp_bindgen(vendor: impl AsRef, out: impl AsRef) -> Resul let vendor = vendor.as_ref(); let bindings = bindgen::Builder::default() .clang_args(["-x", "c++"]) - .clang_args(["-std=c++11"]) + .clang_args(["-std=c++14"]) .clang_arg(CxxOption::VULKAN.cxx()) .clang_arg(CxxOption::METAL.cxx()) .clang_arg(CxxOption::COREML.cxx()) @@ -310,7 +304,7 @@ pub fn mnn_c_build(path: impl AsRef, vendor: impl AsRef) -> Result<( pub fn build_cmake(path: impl AsRef, install: impl AsRef) -> Result<()> { let threads = std::thread::available_parallelism()?; cmake::Config::new(path) - .cxxflag("-std=c++14") + .define("CMAKE_CXX_STANDARD", "14") .parallel(threads.get() as u8) .define("MNN_BUILD_SHARED_LIBS", "OFF") .define("MNN_SEP_BUILD", "OFF") diff --git a/mnn-sys/mnn_c/backend_c.cpp b/mnn-sys/mnn_c/backend_c.cpp index e8fbbd4..ff802d2 100644 --- a/mnn-sys/mnn_c/backend_c.cpp +++ b/mnn-sys/mnn_c/backend_c.cpp @@ -5,6 +5,11 @@ MNNBackendConfig *mnnbc_create() { return reinterpret_cast(new MNN::BackendConfig()); } +MNNBackendConfig *mnnbc_clone(const MNNBackendConfig *config) { + return reinterpret_cast(new MNN::BackendConfig( + *reinterpret_cast(config))); +} + void mnnbc_destroy(MNNBackendConfig *config) { delete reinterpret_cast(config); } @@ -18,7 +23,7 @@ void mnnbc_set_power_mode(MNNBackendConfig *config, PowerMode power_mode) { static_cast(power_mode); } void mnnbc_set_precision_mode(MNNBackendConfig *config, - PrecisionMode precision_mode) { + PrecisionMode precision_mode) { reinterpret_cast(config)->precision = static_cast(precision_mode); } diff --git a/mnn-sys/mnn_c/backend_c.h b/mnn-sys/mnn_c/backend_c.h index cd5a4a4..9f2ae4e 100644 --- a/mnn-sys/mnn_c/backend_c.h +++ b/mnn-sys/mnn_c/backend_c.h @@ -27,15 +27,16 @@ typedef struct MNNBackendConfig MNNBackendConfig; // }; MNNBackendConfig *mnnbc_create(); +MNNBackendConfig *mnnbc_clone(const MNNBackendConfig *config); void mnnbc_destroy(MNNBackendConfig *config); void mnnbc_set_memory_mode(MNNBackendConfig *config, MemoryMode memory_mode); void mnnbc_set_power_mode(MNNBackendConfig *config, PowerMode power_mode); -void mnnbc_set_precision_mode(MNNBackendConfig *config, PrecisionMode precision_mode); +void mnnbc_set_precision_mode(MNNBackendConfig *config, + PrecisionMode precision_mode); void mnnbc_set_shared_context(MNNBackendConfig *config, void *shared_context); void mnnbc_set_flags(MNNBackendConfig *config, size_t flags); void mnnbc_reset(MNNBackendConfig *config); - #ifdef __cplusplus } #endif diff --git a/mnn-sys/mnn_c/schedule_c.cpp b/mnn-sys/mnn_c/schedule_c.cpp index 9e2e62b..5fc10b4 100644 --- a/mnn-sys/mnn_c/schedule_c.cpp +++ b/mnn-sys/mnn_c/schedule_c.cpp @@ -8,6 +8,12 @@ MNNScheduleConfig *mnnsc_create() { return reinterpret_cast(mnnsc); } +MNNScheduleConfig *mnnsc_clone(const MNNScheduleConfig *from) { + auto mnn_from = reinterpret_cast(from); + auto mnn_to = new MNN::ScheduleConfig(*mnn_from); + return reinterpret_cast(mnn_to); +} + void mnnsc_destroy(MNNScheduleConfig *config) { auto mnn_config = reinterpret_cast(config); delete mnn_config; diff --git a/mnn-sys/mnn_c/schedule_c.h b/mnn-sys/mnn_c/schedule_c.h index cedb764..048f8bb 100644 --- a/mnn-sys/mnn_c/schedule_c.h +++ b/mnn-sys/mnn_c/schedule_c.h @@ -11,6 +11,7 @@ extern "C" { typedef struct MNNScheduleConfig MNNScheduleConfig; MNNScheduleConfig *mnnsc_create(); +MNNScheduleConfig *mnnsc_clone(const MNNScheduleConfig *from); void mnnsc_destroy(MNNScheduleConfig *config); void mnnsc_set_save_tensors(MNNScheduleConfig *config, const char *const *saveTensors, diff --git a/mnn-sys/mnn_c/session_c.cpp b/mnn-sys/mnn_c/session_c.cpp index 9f8fb7e..665b6dc 100644 --- a/mnn-sys/mnn_c/session_c.cpp +++ b/mnn-sys/mnn_c/session_c.cpp @@ -9,7 +9,7 @@ class Session { } // namespace MNN void Session_destroy(Session *session) { auto mnn_session = reinterpret_cast(session); - // delete mnn_session; + delete mnn_session; } int Session_hasAsyncWork(Session *session) { @@ -17,6 +17,3 @@ int Session_hasAsyncWork(Session *session) { return mnn_session->hasAsyncWork(); // return true; } - - - diff --git a/mnn-sys/vendor b/mnn-sys/vendor index 407a1c1..a74551b 160000 --- a/mnn-sys/vendor +++ b/mnn-sys/vendor @@ -1 +1 @@ -Subproject commit 407a1c141d459d093f655bf2fed2a8a5e22a77ce +Subproject commit a74551b4f34b46ce7027c64e800d49fcab497261 diff --git a/sgconfig.yml b/sgconfig.yml new file mode 100644 index 0000000..57747fa --- /dev/null +++ b/sgconfig.yml @@ -0,0 +1,6 @@ +ruleDirs: +- ./tools/sg-lints/lints +utilDirs: +- ./tools/sg-lints/utils +ignores: + - mnn-sys/vendor diff --git a/src/backend.rs b/src/backend.rs index 1895ffb..b71cb34 100644 --- a/src/backend.rs +++ b/src/backend.rs @@ -10,6 +10,18 @@ pub struct BackendConfig { __marker: core::marker::PhantomData<()>, } +impl Clone for BackendConfig { + fn clone(&self) -> Self { + unsafe { + let inner = mnn_sys::mnnbc_clone(self.inner); + Self { + inner, + __marker: core::marker::PhantomData, + } + } + } +} + impl Drop for BackendConfig { fn drop(&mut self) { unsafe { diff --git a/src/interpreter.rs b/src/interpreter.rs index 2abe0b3..80c1a25 100644 --- a/src/interpreter.rs +++ b/src/interpreter.rs @@ -249,6 +249,7 @@ impl Interpreter { assert!(!session.is_null()); Ok(crate::session::Session { inner: session, + net: self.inner, __session_internals: crate::SessionInternals::Single(schedule), __marker: PhantomData, }) @@ -279,6 +280,7 @@ impl Interpreter { assert!(!session.is_null()); Ok(crate::session::Session { inner: session, + net: self.inner, __session_internals: crate::SessionInternals::MultiSession(schedules), __marker: PhantomData, }) @@ -290,7 +292,7 @@ impl Interpreter { let path = path.as_ref(); crate::ensure!(path.exists(), ErrorKind::IOError); let path = path.to_str().ok_or_else(|| error!(ErrorKind::AsciiError))?; - let c_path = std::ffi::CString::new(path).unwrap(); + let c_path = std::ffi::CString::new(path).change_context(ErrorKind::AsciiError)?; unsafe { mnn_sys::modelPrintIO(c_path.as_ptr()) } Ok(()) } @@ -376,13 +378,16 @@ impl Interpreter { /// # Safety /// Very **unsafe** since it doesn't check the type of the tensor /// as well as the shape of the tensor + /// + /// **Panics** if the name is not ascii + /// **Undefined Behavior** if the tensor is not of type `H` pub unsafe fn input_unchecked<'s, H: HalideType>( &self, session: &'s crate::Session, name: impl AsRef, ) -> Tensor>> { let name = name.as_ref(); - let c_name = std::ffi::CString::new(name).unwrap(); + let c_name = std::ffi::CString::new(name).expect("Input tensor name is not ascii"); let input = mnn_sys::Interpreter_getSessionInput(self.inner, session.inner, c_name.as_ptr()); Tensor::from_ptr(input) @@ -603,7 +608,7 @@ pub struct TensorInfo<'t, 'tl> { impl core::fmt::Debug for TensorInfo<'_, '_> { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { let tensor = self.raw_tensor(); - let shape = tensor.shape().clone(); + let shape = tensor.shape(); f.debug_struct("TensorInfo") .field("name", &self.name()) .field("tensor", &shape) @@ -864,7 +869,7 @@ fn check_whether_sync_actually_works() { } #[test] -#[ignore = "This test doesn't work in CI"] +#[ignore = "Fails on CI"] fn try_to_drop_interpreter_before_session() { let file = Path::new("tests/assets/realesr.mnn") .canonicalize() diff --git a/src/schedule.rs b/src/schedule.rs index 3defc0c..875ec67 100644 --- a/src/schedule.rs +++ b/src/schedule.rs @@ -3,6 +3,8 @@ use std::{ffi::CString, mem::ManuallyDrop}; use crate::{prelude::*, BackendConfig}; +/// Backend used for running the model +/// /// The `ForwardType` enum is used to specify the backend that will be used for forward computation /// in the MNN framework. Each variant corresponds to a different backend, which may be enabled /// or disabled based on the features enabled in the build configuration. @@ -21,7 +23,7 @@ use crate::{prelude::*, BackendConfig}; /// # Example /// /// ```rust -/// use mnn_rs::schedule::ForwardType; +/// use mnn::schedule::ForwardType; /// /// let forward_type = ForwardType::Auto; /// println!("Selected forward type: {:?}", forward_type); @@ -127,7 +129,7 @@ impl core::str::FromStr for ForwardType { /// # Example /// /// ```rust -/// use mnn_rs::schedule::{ScheduleConfig, ForwardType}; +/// use mnn::schedule::{ScheduleConfig, ForwardType}; /// /// let mut config = ScheduleConfig::new(); /// config.set_type(ForwardType::Auto); @@ -168,6 +170,19 @@ pub struct ScheduleConfig { pub(crate) __marker: core::marker::PhantomData<()>, } +impl Clone for ScheduleConfig { + fn clone(&self) -> Self { + unsafe { + let inner = mnnsc_clone(self.inner); + Self { + inner, + backend_config: self.backend_config.clone(), + __marker: core::marker::PhantomData, + } + } + } +} + impl Drop for ScheduleConfig { fn drop(&mut self) { unsafe { diff --git a/src/session.rs b/src/session.rs index 179cac6..5be36db 100644 --- a/src/session.rs +++ b/src/session.rs @@ -7,6 +7,13 @@ use crate::prelude::*; pub struct Session { /// Pointer to the underlying MNN session. pub(crate) inner: *mut mnn_sys::Session, + /// Pointer to the underlying MNN interpreter + /// # Safety Note + /// Since the interpreter is actually not owned by session but it is a shared resource we can + /// reasonably assume that the interpreter will outlive the session. (This is not a compile + /// time gurantee yet) + /// TODO: Add a proper lifetime bound to ensure the interpreter outlives the session. + pub(crate) net: *mut mnn_sys::Interpreter, /// Internal session configurations. pub(crate) __session_internals: crate::SessionInternals, /// Marker to ensure the struct is not Send or Sync. @@ -33,11 +40,18 @@ impl Session { // pub fn as_ptr_mut(&self) -> *mut mnn_sys::Session { // self.session // } + // + pub fn destroy(&mut self) { + unsafe { + mnn_sys::Interpreter_releaseSession(self.net, self.inner); + } + // unsafe { mnn_sys::Session_destroy(self.inner) } + } } impl Drop for Session { /// Custom drop implementation to ensure the underlying MNN session is properly destroyed. fn drop(&mut self) { - unsafe { mnn_sys::Session_destroy(self.inner) } + self.destroy(); } } diff --git a/tests/basic.rs b/tests/basic.rs index 071894f..420ab95 100644 --- a/tests/basic.rs +++ b/tests/basic.rs @@ -3,12 +3,12 @@ use common::*; use mnn::ForwardType; #[test] -#[ignore = "takes too long"] fn test_basic_cpu() { test_basic(ForwardType::CPU).unwrap(); } #[cfg(feature = "metal")] #[test] +#[ignore = "Doesn't work on ci"] fn test_basic_metal() { test_basic(ForwardType::Metal).unwrap(); } diff --git a/tests/common.rs b/tests/common.rs index 30b44e0..57a51cb 100644 --- a/tests/common.rs +++ b/tests/common.rs @@ -25,6 +25,7 @@ impl AsRef<[u8]> for Model { } } +#[allow(dead_code)] pub fn test_basic(backend: ForwardType) -> Result<()> { let mut net = mnn::Interpreter::from_file("tests/assets/realesr.mnn")?; let mut config = ScheduleConfig::new(); @@ -46,6 +47,7 @@ pub fn test_basic(backend: ForwardType) -> Result<()> { Ok(()) } +#[allow(dead_code)] pub fn test_multipath_session(backend: ForwardType, backend2: ForwardType) -> Result<()> { use mnn::BackendConfig; diff --git a/tests/resizing.rs b/tests/resizing.rs index c2ad983..f3e4753 100644 --- a/tests/resizing.rs +++ b/tests/resizing.rs @@ -5,42 +5,40 @@ use common::*; pub fn test_resizing() -> Result<()> { let model = std::fs::read("tests/assets/resizing.mnn").expect("No resizing model"); let mut net = Interpreter::from_bytes(&model).unwrap(); - net.set_cache_file("resizing.cache", 128); - let mut config = ScheduleConfig::default(); + net.set_cache_file("resizing.cache", 128)?; + let config = ScheduleConfig::default(); #[cfg(feature = "opencl")] config.set_type(ForwardType::OpenCL); let mut session = net.create_session(config).unwrap(); - net.update_cache_file(&mut session); + net.update_cache_file(&mut session)?; - loop { - let now = std::time::Instant::now(); - let mut mask = unsafe { net.input_unresized::(&session, "mask") }?; - net.resize_tensor(&mut mask, [2048, 2048]); - drop(mask); + let now = std::time::Instant::now(); + let mut mask = unsafe { net.input_unresized::(&session, "mask") }?; + net.resize_tensor(&mut mask, [2048, 2048]); + drop(mask); - let mut og = unsafe { net.input_unresized::(&session, "original") }?; - net.resize_tensor(&mut og, [2048, 2048, 3]); - drop(og); + let mut og = unsafe { net.input_unresized::(&session, "original") }?; + net.resize_tensor(&mut og, [2048, 2048, 3]); + drop(og); - let mut pain = unsafe { net.input_unresized::(&session, "inpainted") }?; - net.resize_tensor(&mut pain, [2048, 2048, 3]); - drop(pain); + let mut pain = unsafe { net.input_unresized::(&session, "inpainted") }?; + net.resize_tensor(&mut pain, [2048, 2048, 3]); + drop(pain); - net.resize_session(&mut session); - let inputs = net.inputs(&session); - for tensor_info in inputs.iter() { - let tensor = tensor_info.tensor::().unwrap(); - println!( - "{:13}: {:>13}", - tensor_info.name(), - format!("{:?}", tensor.shape()) - ); - let mut host = tensor.create_host_tensor_from_device(false); - host.host_mut().fill(1.0); - } - drop(inputs); - net.run_session(&session).unwrap(); - println!("{:?}", now.elapsed()); + net.resize_session(&mut session); + let inputs = net.inputs(&session); + for tensor_info in inputs.iter() { + let tensor = tensor_info.tensor::().unwrap(); + println!( + "{:13}: {:>13}", + tensor_info.name(), + format!("{:?}", tensor.shape()) + ); + let mut host = tensor.create_host_tensor_from_device(false); + host.host_mut().fill(1.0); } + drop(inputs); + net.run_session(&session).unwrap(); + println!("{:?}", now.elapsed()); Ok(()) } diff --git a/tests/segfault.rs b/tests/segfault.rs index d93f5c3..d061941 100644 --- a/tests/segfault.rs +++ b/tests/segfault.rs @@ -28,6 +28,7 @@ fn test_segfault_case_1_() -> Result<(), Box> { } #[test] +#[ignore] pub fn test_resizing() { use mnn::*; let model = std::fs::read("tests/assets/resizing.mnn").expect("No resizing model"); @@ -42,9 +43,9 @@ pub fn test_resizing() { let mut shape = tensor.shape().as_ref().to_vec(); dbg!(&shape); shape.iter_mut().for_each(|v| { - // if *v == -1 { - // *v = 2048; - // } + if *v == -1 { + *v = 3; + } }); dbg!(&shape); net.resize_tensor(&mut tensor, &shape); diff --git a/tools/sg-lints/lints/no-println.yaml b/tools/sg-lints/lints/no-println.yaml new file mode 100644 index 0000000..b053ef3 --- /dev/null +++ b/tools/sg-lints/lints/no-println.yaml @@ -0,0 +1,22 @@ +id: no-println +message: Do not use println! use `tracing::info`/`tracing::trace`/`tracing::debug` instead +severity: warning +language: Rust +rule: + kind: macro_invocation + pattern: println!($$$ITEMS) + not: + inside: + stopBy: end + matches: is-test + +fix: tracing::info!($$$ITEMS) +files: + - src/**/*.rs + - mnn-sync/src/*.rs + - mnn-sys/src/*.rs + - mnn-bridge/src/**/*.rs +ignores: + - build.rs + - mnn-sys/build.rs + - mnn-sys/vendor/**/*.rs diff --git a/tools/sg-lints/lints/no-unwrap.yml b/tools/sg-lints/lints/no-unwrap.yml new file mode 100644 index 0000000..39d395a --- /dev/null +++ b/tools/sg-lints/lints/no-unwrap.yml @@ -0,0 +1,19 @@ +id: no-unwrap +message: Do not use unwrap +severity: error +language: Rust +rule: + pattern: $ITEM.unwrap() + not: + inside: + stopBy: end + matches: is-test +files: + - src/**/*.rs + - mnn-sync/src/*.rs + - mnn-sys/src/*.rs + - mnn-bridge/src/**/*.rs +ignores: + - build.rs + - mnn-sys/vendor/**/*.rs + diff --git a/tools/sg-lints/utils/is-test.yml b/tools/sg-lints/utils/is-test.yml new file mode 100644 index 0000000..23df265 --- /dev/null +++ b/tools/sg-lints/utils/is-test.yml @@ -0,0 +1,23 @@ +id: is-test +language: Rust + +rule: + all: + - kind: function_item + - follows: + stopBy: + kind: function_item + matches: test-token + +utils: + test-token: + kind: attribute_item + has: + kind: attribute + has: + any: + - pattern: test + - pattern: tokio::test + +ignores: + - mnn-sys/vendor/**/*.rs