diff --git a/Cargo.lock b/Cargo.lock index 18b1d0c..fbbebaf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,6 +2,21 @@ # It is not intended for manual editing. version = 3 +[[package]] +name = "addr2line" +version = "0.24.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dfbe277e56a376000877090da837660b4427aad530e3028d44e0bffe4f89a1c1" +dependencies = [ + "gimli", +] + +[[package]] +name = "adler2" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" + [[package]] name = "aho-corasick" version = "1.1.3" @@ -57,7 +72,7 @@ version = "1.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6d36fc52c7f6c869915e99412912f22093507da8d9e942ceaf66fe4b7c14422a" dependencies = [ - "windows-sys", + "windows-sys 0.52.0", ] [[package]] @@ -67,7 +82,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5bf74e1b6e971609db8ca7a9ce79fd5768ab6ae46441c572e46cf596f59e57f8" dependencies = [ "anstyle", - "windows-sys", + "windows-sys 0.52.0", ] [[package]] @@ -82,6 +97,30 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0" +[[package]] +name = "backtrace" +version = "0.3.74" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8d82cb332cdfaed17ae235a638438ac4d4839913cc2af585c3c6746e8f8bee1a" +dependencies = [ + "addr2line", + "cfg-if", + "libc", + "miniz_oxide", + "object", + "rustc-demangle", + "windows-targets 0.52.6", +] + +[[package]] +name = "backtrace-ext" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "537beee3be4a18fb023b570f80e3ae28003db9167a751266b259926e25539d50" +dependencies = [ + "backtrace", +] + [[package]] name = "bindgen" version = "0.70.1" @@ -178,6 +217,7 @@ dependencies = [ "anstyle", "clap_lex", "strsim", + "terminal_size", ] [[package]] @@ -212,6 +252,12 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d3fd119d74b830634cea2a0f58bbd0d54540518a14397557951e79340abc28c0" +[[package]] +name = "condtype" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf0a07a401f374238ab8e2f11a104d2851bf9ce711ec69804834de8af45c7af" + [[package]] name = "diffy" version = "0.4.0" @@ -221,6 +267,31 @@ dependencies = [ "nu-ansi-term", ] +[[package]] +name = "divan" +version = "0.1.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a0d567df2c9c2870a43f3f2bd65aaeb18dbce1c18f217c3e564b4fbaeb3ee56c" +dependencies = [ + "cfg-if", + "clap", + "condtype", + "divan-macros", + "libc", + "regex-lite", +] + +[[package]] +name = "divan-macros" +version = "0.1.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "27540baf49be0d484d8f0130d7d8da3011c32a44d4fc873368154f1510e574a2" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "dunce" version = "1.0.5" @@ -233,6 +304,16 @@ version = "1.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" +[[package]] +name = "errno" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "534c5cf6194dfab3db3242765c03bbe257cf92f22b38f6bc0c58d59108a820ba" +dependencies = [ + "libc", + "windows-sys 0.52.0", +] + [[package]] name = "error-stack" version = "0.5.0" @@ -272,6 +353,12 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "gimli" +version = "0.31.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" + [[package]] name = "glob" version = "0.3.1" @@ -284,6 +371,12 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" +[[package]] +name = "is_ci" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7655c9839580ee829dfacba1d1278c2b7883e50a277ff7541299489d6bdfdc45" + [[package]] name = "is_terminal_polyfill" version = "1.70.1" @@ -330,9 +423,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4979f22fdb869068da03c9f7528f8297c6fd2606bc3a4affe42e6a823fdb8da4" dependencies = [ "cfg-if", - "windows-targets", + "windows-targets 0.52.6", ] +[[package]] +name = "linux-raw-sys" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89" + [[package]] name = "lock_api" version = "0.4.12" @@ -365,12 +464,52 @@ version = "2.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" +[[package]] +name = "miette" +version = "7.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4edc8853320c2a0dab800fbda86253c8938f6ea88510dc92c5f1ed20e794afc1" +dependencies = [ + "backtrace", + "backtrace-ext", + "cfg-if", + "miette-derive", + "owo-colors", + "supports-color", + "supports-hyperlinks", + "supports-unicode", + "terminal_size", + "textwrap", + "thiserror", + "unicode-width", +] + +[[package]] +name = "miette-derive" +version = "7.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dcf09caffaac8068c346b6df2a7fc27a177fd20b39421a39ce0a211bde679a6c" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "minimal-lexical" version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" +[[package]] +name = "miniz_oxide" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2d80299ef12ff69b16a84bb182e3b9df68b5a91574d3d4fa6e41b65deec4df1" +dependencies = [ + "adler2", +] + [[package]] name = "mnn" version = "0.2.0" @@ -378,9 +517,11 @@ dependencies = [ "anyhow", "bytemuck", "clap", + "divan", "dunce", "error-stack", "libc", + "miette", "mnn-sys", "oneshot", "thiserror", @@ -461,7 +602,7 @@ version = "0.50.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d4a28e057d01f97e61255210fcff094d74ed0466038633e95017f5beb68e4399" dependencies = [ - "windows-sys", + "windows-sys 0.52.0", ] [[package]] @@ -491,6 +632,15 @@ dependencies = [ "autocfg", ] +[[package]] +name = "object" +version = "0.36.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aedf0a2d09c573ed1d8d85b30c119153926a2b36dce0ab28322c09a117a4683e" +dependencies = [ + "memchr", +] + [[package]] name = "once_cell" version = "1.19.0" @@ -503,6 +653,12 @@ version = "0.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e296cf87e61c9cfc1a61c3c63a0f7f286ed4554e0e22be84e8a38e1d264a2a29" +[[package]] +name = "owo-colors" +version = "4.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fb37767f6569cd834a413442455e0f066d0d522de8630436e2a1761d9726ba56" + [[package]] name = "pin-project-lite" version = "0.2.14" @@ -581,12 +737,24 @@ dependencies = [ "regex-syntax", ] +[[package]] +name = "regex-lite" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53a49587ad06b26609c52e423de037e7f57f20d53535d66e08c695f347df952a" + [[package]] name = "regex-syntax" version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b" +[[package]] +name = "rustc-demangle" +version = "0.1.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" + [[package]] name = "rustc-hash" version = "1.1.0" @@ -602,6 +770,19 @@ dependencies = [ "semver", ] +[[package]] +name = "rustix" +version = "0.38.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8acb788b847c24f28525660c4d7758620a7210875711f79e7f663cc152726811" +dependencies = [ + "bitflags", + "errno", + "libc", + "linux-raw-sys", + "windows-sys 0.52.0", +] + [[package]] name = "scopeguard" version = "1.2.0" @@ -620,6 +801,12 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" +[[package]] +name = "smawk" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7c388c1b5e93756d0c740965c41e8822f866621d41acbdf6336a6a168f8840c" + [[package]] name = "spin" version = "0.9.8" @@ -635,6 +822,27 @@ version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" +[[package]] +name = "supports-color" +version = "3.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8775305acf21c96926c900ad056abeef436701108518cf890020387236ac5a77" +dependencies = [ + "is_ci", +] + +[[package]] +name = "supports-hyperlinks" +version = "3.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2c0a1e5168041f5f3ff68ff7d95dcb9c8749df29f6e7e89ada40dd4c9de404ee" + +[[package]] +name = "supports-unicode" +version = "3.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7401a30af6cb5818bb64852270bb722533397edcfc7344954a38f420819ece2" + [[package]] name = "syn" version = "2.0.77" @@ -652,6 +860,27 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" +[[package]] +name = "terminal_size" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "21bebf2b7c9e0a515f6e0f8c51dc0f8e4696391e6f1ff30379559f8365fb0df7" +dependencies = [ + "rustix", + "windows-sys 0.48.0", +] + +[[package]] +name = "textwrap" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "23d434d3f8967a09480fb04132ebe0a3e088c173e6d0ee7897abbdf4eab0f8b9" +dependencies = [ + "smawk", + "unicode-linebreak", + "unicode-width", +] + [[package]] name = "thiserror" version = "1.0.64" @@ -709,6 +938,12 @@ version = "1.0.13" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e91b56cd4cadaeb79bbf1a5645f6b4f8dc5bde8834ad5894a8db35fda9efa1fe" +[[package]] +name = "unicode-linebreak" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b09c83c3c29d37506a3e260c08c03743a6bb66a9cd432c6934ab501a190571f" + [[package]] name = "unicode-width" version = "0.1.14" @@ -804,13 +1039,37 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" +[[package]] +name = "windows-sys" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" +dependencies = [ + "windows-targets 0.48.5", +] + [[package]] name = "windows-sys" version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" dependencies = [ - "windows-targets", + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-targets" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" +dependencies = [ + "windows_aarch64_gnullvm 0.48.5", + "windows_aarch64_msvc 0.48.5", + "windows_i686_gnu 0.48.5", + "windows_i686_msvc 0.48.5", + "windows_x86_64_gnu 0.48.5", + "windows_x86_64_gnullvm 0.48.5", + "windows_x86_64_msvc 0.48.5", ] [[package]] @@ -819,28 +1078,46 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" dependencies = [ - "windows_aarch64_gnullvm", - "windows_aarch64_msvc", - "windows_i686_gnu", + "windows_aarch64_gnullvm 0.52.6", + "windows_aarch64_msvc 0.52.6", + "windows_i686_gnu 0.52.6", "windows_i686_gnullvm", - "windows_i686_msvc", - "windows_x86_64_gnu", - "windows_x86_64_gnullvm", - "windows_x86_64_msvc", + "windows_i686_msvc 0.52.6", + "windows_x86_64_gnu 0.52.6", + "windows_x86_64_gnullvm 0.52.6", + "windows_x86_64_msvc 0.52.6", ] +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" + [[package]] name = "windows_aarch64_gnullvm" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" +[[package]] +name = "windows_aarch64_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" + [[package]] name = "windows_aarch64_msvc" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" +[[package]] +name = "windows_i686_gnu" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" + [[package]] name = "windows_i686_gnu" version = "0.52.6" @@ -853,24 +1130,48 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" +[[package]] +name = "windows_i686_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" + [[package]] name = "windows_i686_msvc" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" +[[package]] +name = "windows_x86_64_gnu" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" + [[package]] name = "windows_x86_64_gnu" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" + [[package]] name = "windows_x86_64_gnullvm" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" +[[package]] +name = "windows_x86_64_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" + [[package]] name = "windows_x86_64_msvc" version = "0.52.6" diff --git a/Cargo.toml b/Cargo.toml index ba0cca1..e599a1a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -42,3 +42,9 @@ default = ["mnn-threadpool"] anyhow = "1.0" bytemuck = "1.17" clap = { version = "4.5", features = ["derive"] } +divan = "0.1.14" +miette = { version = "7.2.0", features = ["fancy"] } + +[[bench]] +name = "mnn-bench" +harness = false diff --git a/benches/mnn-bench.rs b/benches/mnn-bench.rs new file mode 100644 index 0000000..8616f3a --- /dev/null +++ b/benches/mnn-bench.rs @@ -0,0 +1,33 @@ +use divan::*; +#[divan::bench_group(sample_size = 5, sample_count = 5)] +mod mnn_realesr_bench_with_ones { + use divan::*; + use mnn::*; + #[divan::bench] + pub fn mnn_benchmark_cpu(bencher: Bencher) { + let mut net = Interpreter::from_file("tests/assets/realesr.mnn").unwrap(); + let mut config = ScheduleConfig::new(); + config.set_type(ForwardType::CPU); + let session = net.create_session(config).unwrap(); + bencher.bench_local(|| { + let mut input = net.input(&session, "data").unwrap(); + input.fill(1f32); + net.run_session(&session).unwrap(); + }); + } + + #[cfg(feature = "opencl")] + #[divan::bench] + pub fn mnn_benchmark_opencl(bencher: Bencher) { + let mut net = Interpreter::from_file("tests/assets/realesr.mnn").unwrap(); + let mut config = ScheduleConfig::new(); + config.set_type(ForwardType::OpenCL); + let session = net.create_session(config).unwrap(); + bencher.bench_local(|| { + let mut input = net.input(&session, "data").unwrap(); + input.fill(1f32); + net.run_session(&session).unwrap(); + net.wait(&session); + }); + } +} diff --git a/examples/inspect.rs b/examples/inspect.rs index 07e0a64..f7ffb9c 100644 --- a/examples/inspect.rs +++ b/examples/inspect.rs @@ -54,6 +54,13 @@ pub fn main() -> anyhow::Result<()> { interpreter.update_cache_file(&mut session)?; let mut current = 0; + println!("--------------------------------Info--------------------------------"); + let mem = interpreter.memory(&session)?; + let flops = interpreter.flops(&session)?; + println!("Memory: {:?}MiB", mem); + println!("Flops : {:?}M", flops); + println!("ResizeStatus : {:?}", interpreter.resize_status(&session)?); + time!(loop { println!("--------------------------------Inputs--------------------------------"); interpreter.inputs(&session).iter().for_each(|x| { @@ -75,6 +82,7 @@ pub fn main() -> anyhow::Result<()> { }, }; }); + println!("Running session"); interpreter.run_session(&session)?; println!("--------------------------------Outputs--------------------------------"); diff --git a/flake.lock b/flake.lock index 3ccaf4a..616aba4 100644 --- a/flake.lock +++ b/flake.lock @@ -108,11 +108,11 @@ "mnn-src": { "flake": false, "locked": { - "lastModified": 1726130630, - "narHash": "sha256-dZNOAKPLjnR8MCkk0iJsOFHII2DMsUcQApfu7OfBRL8=", + "lastModified": 1727414197, + "narHash": "sha256-lfjkdaB4ZKL53wAd1lrEHBxzd0/AXnxZbkm2/0i/iUs=", "owner": "alibaba", "repo": "MNN", - "rev": "f0e516a4bd26855e21fedb47ed609113d86c5684", + "rev": "407a1c141d459d093f655bf2fed2a8a5e22a77ce", "type": "github" }, "original": { diff --git a/flake.nix b/flake.nix index 22f5cc1..3710e3e 100644 --- a/flake.nix +++ b/flake.nix @@ -68,11 +68,7 @@ craneLib = (crane.mkLib pkgs).overrideToolchain stableToolchain; craneLibLLvmTools = (crane.mkLib pkgs).overrideToolchain stableToolchainWithLLvmTools; - mnnFilters = path: type: (craneLib.filterCargoSources path type) || (lib.hasSuffix ".patch" path || lib.hasSuffix ".mnn" path || lib.hasSuffix ".h" path || lib.hasSuffix ".cpp" path || lib.hasSuffix ".svg" path); - src = lib.cleanSourceWith { - filter = mnnFilters; - src = ./.; - }; + src = lib.sources.sourceFilesBySuffices ./. [".rs" ".toml" ".patch" ".mnn" ".h" ".cpp" ".svg" "lock"]; MNN_SRC = mnn-src; commonArgs = { diff --git a/mnn-sys/build.rs b/mnn-sys/build.rs index 81d3af1..69342bd 100644 --- a/mnn-sys/build.rs +++ b/mnn-sys/build.rs @@ -3,7 +3,6 @@ use anyhow::*; #[cfg(unix)] use std::os::unix::fs::PermissionsExt; use std::{ - fs::Permissions, path::{Path, PathBuf}, sync::LazyLock, }; @@ -27,6 +26,13 @@ static EMSCRIPTEN_CACHE: LazyLock = LazyLock::new(|| { emscripten_cache }); +const HALIDE_PATCH_1: &str = r#"#if __cplusplus >= 201103L"#; +const HALIDE_PATCH_2: &str = r#" +#else + HALIDE_ATTRIBUTE_ALIGN(1) uint8_t code; // halide_type_code_t +#endif +"#; + fn ensure_vendor_exists(vendor: impl AsRef) -> Result<()> { if vendor .as_ref() @@ -65,9 +71,15 @@ fn main() -> Result<()> { .context("Failed to copy vendor")?; let intptr = vendor.join("include").join("MNN").join("HalideRuntime.h"); #[cfg(unix)] - std::fs::set_permissions(&intptr, Permissions::from_mode(0o644))?; - try_patch_file("patches/halide_type_t_64.patch", intptr) - .context("Failed to patch vendor")?; + std::fs::set_permissions(&intptr, std::fs::Permissions::from_mode(0o644))?; + // try_patch_file("patches/halide_type_t_64.patch", intptr) + // .context("Failed to patch vendor")?; + + let intptr_contents = std::fs::read_to_string(&intptr)?; + let patched = intptr_contents + .replace(HALIDE_PATCH_1, "") + .replace(HALIDE_PATCH_2, ""); + std::fs::write(intptr, patched)?; } let install_dir = out_dir.join("mnn-install"); @@ -258,6 +270,7 @@ pub fn build_cmake(path: impl AsRef, install: impl AsRef) -> Result< config.define("MNN_COREML", CxxOption::COREML.cmake_value()); config.define("MNN_OPENCL", CxxOption::OPENCL.cmake_value()); config.define("MNN_OPENGL", CxxOption::OPENGL.cmake_value()); + config.define("CMAKE_CXX_FLAGS", "-O0"); // #[cfg(windows)] if *TARGET_OS == "windows" { config.define("CMAKE_CXX_FLAGS", "-DWIN32=1"); @@ -444,3 +457,12 @@ impl CxxOption { } } } + +// mod cc_build { +// use super::*; +// pub fn build(source: impl AsRef) -> Result { +// let mut builder = cc::Build::new(); +// builder.std("c++11").cpp(true); +// todo!() +// } +// } diff --git a/mnn-sys/mnn_c/interpreter_c.cpp b/mnn-sys/mnn_c/interpreter_c.cpp index 5698f90..2fe130f 100644 --- a/mnn-sys/mnn_c/interpreter_c.cpp +++ b/mnn-sys/mnn_c/interpreter_c.cpp @@ -275,8 +275,9 @@ int Interpreter_getSessionInfo(Interpreter *interpreter, const Session *session, int code, void *ptr) { auto mnn_interpreter = reinterpret_cast(interpreter); auto mnn_session = reinterpret_cast(session); - return mnn_interpreter->getSessionInfo( + auto ret = mnn_interpreter->getSessionInfo( mnn_session, static_cast(code), ptr); + return static_cast(ret); } TensorInfoArray const * Interpreter_getSessionOutputAll(const Interpreter *interpreter, diff --git a/mnn-sys/vendor b/mnn-sys/vendor index ddd9a61..b521c25 160000 --- a/mnn-sys/vendor +++ b/mnn-sys/vendor @@ -1 +1 @@ -Subproject commit ddd9a61ded60d3862f203a1b9f161f1b905753c4 +Subproject commit b521c25825c0296f5da2ce512e217c10785e1dc7 diff --git a/src/interpreter.rs b/src/interpreter.rs index 552b429..6233685 100644 --- a/src/interpreter.rs +++ b/src/interpreter.rs @@ -169,6 +169,10 @@ impl Interpreter { unsafe { mnn_sys::Interpreter_resizeSession(self.inner, session.inner) } } + pub fn resize_session_reallocate(&self, session: &mut crate::Session) { + unsafe { mnn_sys::Interpreter_resizeSessionWithFlag(self.inner, session.inner, 1i32) } + } + pub fn resize_tensor(&self, tensor: &mut Tensor, dims: impl AsTensorShape) { let dims = dims.as_tensor_shape(); let dims_len = dims.size; @@ -420,12 +424,74 @@ impl Interpreter { Ok(()) } - // /// Wait for all output tensors to be ready after computation - // pub fn wait(&self, session: &crate::session::Session) { - // self.outputs(session).iter().for_each(|tinfo| { - // tinfo.raw_tensor().wait_read(true); - // }); - // } + /// Wait for all output tensors to be ready after computation + pub fn wait(&self, session: &crate::session::Session) { + self.outputs(session).iter().for_each(|tinfo| { + tinfo + .raw_tensor() + .wait(mnn_sys::MapType::MAP_TENSOR_READ, true); + }); + } + + pub fn memory(&self, session: &crate::session::Session) -> Result { + let mut memory = 0f32; + let memory_ptr = &mut memory as *mut f32; + let ret = unsafe { + mnn_sys::Interpreter_getSessionInfo(self.inner, session.inner, 0, memory_ptr.cast()) + }; + ensure!( + ret == 1, + ErrorKind::InterpreterError; + "Failed to get memory usage" + ); + Ok(memory) + } + + pub fn flops(&self, session: &crate::Session) -> Result { + let mut flop = 0.0f32; + let flop_ptr = &mut flop as *mut f32; + let ret = unsafe { + mnn_sys::Interpreter_getSessionInfo( + self.inner, + session.inner, + 1, + flop_ptr.cast::(), + ) + }; + ensure!( + ret == 1, + ErrorKind::InterpreterError; + "Failed to get flops" + ); + Ok(flop) + } + + pub fn resize_status(&self, session: &crate::Session) -> Result { + let mut resize_status = 0i32; + let ptr = &mut resize_status as *mut i32; + let ret = unsafe { + mnn_sys::Interpreter_getSessionInfo(self.inner, session.inner, 2, ptr.cast()) + }; + ensure!( + ret == 1, + ErrorKind::InterpreterError; + "Failed to get resize status" + ); + match resize_status { + 0 => Ok(ResizeStatus::None), + 1 => Ok(ResizeStatus::NeedMalloc), + 2 => Ok(ResizeStatus::NeedResize), + _ => Err(error!(ErrorKind::InterpreterError)), + } + } +} + +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +#[repr(C)] +pub enum ResizeStatus { + None = 0, + NeedMalloc = 1, + NeedResize = 2, } #[repr(transparent)] @@ -436,11 +502,11 @@ pub struct TensorInfo<'t, 'tl> { impl core::fmt::Debug for TensorInfo<'_, '_> { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - // let tensor = self.raw_tensor(); - // let shape = tensor.shape().clone(); + let tensor = self.raw_tensor(); + let shape = tensor.shape().clone(); f.debug_struct("TensorInfo") .field("name", &self.name()) - // .field("tensor", &shape) + .field("tensor", &shape) .finish() } } diff --git a/src/tensor.rs b/src/tensor.rs index 87be0cf..087e3b8 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -537,25 +537,10 @@ pub trait AsTensorShape { impl> AsTensorShape for T { fn as_tensor_shape(&self) -> TensorShape { let this = self.as_ref(); - let len = this.len(); - if len > 4 { - TensorShape { - shape: this[..4].try_into().expect("Impossible"), - size: 4, - } - } else { - TensorShape { - shape: this - .iter() - .chain(std::iter::repeat(&1)) - .take(4) - .copied() - .collect::>() - .try_into() - .expect("Impossible"), - size: len, - } - } + let size = std::cmp::min(this.len(), 4); + let mut shape = [1; 4]; + shape[..size].copy_from_slice(&this[..size]); + TensorShape { shape, size } } } @@ -684,25 +669,33 @@ impl super::TensorType for Dyn { } } +/// A raw tensor type that doesn't have any guarantees +/// and will be unconditionally dropped #[repr(transparent)] pub struct RawTensor<'r> { pub(crate) inner: *mut mnn_sys::Tensor, pub(crate) __marker: PhantomData<&'r ()>, } -impl<'r> core::ops::Drop for RawTensor<'r> { - fn drop(&mut self) { - unsafe { - mnn_sys::Tensor_destroy(self.inner); - } - } -} +// impl<'r> core::ops::Drop for RawTensor<'r> { +// fn drop(&mut self) { +// unsafe { +// mnn_sys::Tensor_destroy(self.inner); +// } +// } +// } impl<'r> RawTensor<'r> { pub fn shape(&self) -> TensorShape { unsafe { mnn_sys::Tensor_shape(self.inner) }.into() } + pub fn destroy(self) { + unsafe { + mnn_sys::Tensor_destroy(self.inner); + } + } + pub fn dimensions(&self) -> usize { unsafe { mnn_sys::Tensor_dimensions(self.inner) as usize } } @@ -735,8 +728,7 @@ impl<'r> RawTensor<'r> { where T::H: HalideType, { - let this = core::mem::ManuallyDrop::new(self); - super::Tensor::from_ptr(this.inner) + super::Tensor::from_ptr(self.inner) } pub(crate) fn from_ptr(inner: *mut mnn_sys::Tensor) -> Self { diff --git a/tests/assets/resizing.mnn b/tests/assets/resizing.mnn new file mode 100644 index 0000000..1bfcd84 --- /dev/null +++ b/tests/assets/resizing.mnn @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1fd086430d897b39d97c2de595a4004cab6db7f3e06678c595074c321846ff14 +size 8624 diff --git a/tests/resizing.rs b/tests/resizing.rs new file mode 100644 index 0000000..c2ad983 --- /dev/null +++ b/tests/resizing.rs @@ -0,0 +1,46 @@ +mod common; +use common::*; + +#[test] +pub fn test_resizing() -> Result<()> { + let model = std::fs::read("tests/assets/resizing.mnn").expect("No resizing model"); + let mut net = Interpreter::from_bytes(&model).unwrap(); + net.set_cache_file("resizing.cache", 128); + let mut config = ScheduleConfig::default(); + #[cfg(feature = "opencl")] + config.set_type(ForwardType::OpenCL); + let mut session = net.create_session(config).unwrap(); + net.update_cache_file(&mut session); + + loop { + let now = std::time::Instant::now(); + let mut mask = unsafe { net.input_unresized::(&session, "mask") }?; + net.resize_tensor(&mut mask, [2048, 2048]); + drop(mask); + + let mut og = unsafe { net.input_unresized::(&session, "original") }?; + net.resize_tensor(&mut og, [2048, 2048, 3]); + drop(og); + + let mut pain = unsafe { net.input_unresized::(&session, "inpainted") }?; + net.resize_tensor(&mut pain, [2048, 2048, 3]); + drop(pain); + + net.resize_session(&mut session); + let inputs = net.inputs(&session); + for tensor_info in inputs.iter() { + let tensor = tensor_info.tensor::().unwrap(); + println!( + "{:13}: {:>13}", + tensor_info.name(), + format!("{:?}", tensor.shape()) + ); + let mut host = tensor.create_host_tensor_from_device(false); + host.host_mut().fill(1.0); + } + drop(inputs); + net.run_session(&session).unwrap(); + println!("{:?}", now.elapsed()); + } + Ok(()) +} diff --git a/tests/segfault.rs b/tests/segfault.rs index 2be7f48..6342bcc 100644 --- a/tests/segfault.rs +++ b/tests/segfault.rs @@ -26,3 +26,44 @@ fn test_segfault_case_1_() -> Result<(), Box> { drop(net); Ok(()) } + +#[test] +pub fn test_resizing() { + use mnn::*; + let model = std::fs::read("tests/assets/resizing.mnn").expect("No resizing model"); + let mut net = Interpreter::from_bytes(&model).unwrap(); + let config = ScheduleConfig::default(); + let mut session = net.create_session(config).unwrap(); + let inputs = net.inputs(&session); + + loop { + for tensor_info in inputs.iter() { + let mut tensor = unsafe { tensor_info.tensor_unresized::() }.unwrap(); + let mut shape = tensor.shape().as_ref().to_vec(); + dbg!(&shape); + shape.iter_mut().for_each(|v| { + // if *v == -1 { + // *v = 2048; + // } + }); + dbg!(&shape); + net.resize_tensor(&mut tensor, &shape); + } + drop(inputs); + + net.resize_session(&mut session); + let inputs = net.inputs(&session); + for tensor_info in inputs.iter() { + let tensor = tensor_info.tensor::().unwrap(); + println!( + "{:13}: {:>13}", + tensor_info.name(), + format!("{:?}", tensor.shape()) + ); + let mut host = tensor.create_host_tensor_from_device(false); + host.host_mut().fill(1.0); + } + drop(inputs); + net.run_session(&session).unwrap(); + } +}